From 6866fea44aa7d0d4897ed32f6dd04540396d48c8 Mon Sep 17 00:00:00 2001
From: Prakhar Kulshreshtha
Date: Thu, 4 Jan 2024 09:43:42 -0800
Subject: [PATCH 1/4] make the latent vector flexible so that it can be both
tuple and tensor
---
saicinpainting/evaluation/refinement.py | 21 ++++++++++++++++-----
1 file changed, 16 insertions(+), 5 deletions(-)
diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py
index d9d3cbac..fe5c42e4 100644
--- a/saicinpainting/evaluation/refinement.py
+++ b/saicinpainting/evaluation/refinement.py
@@ -125,21 +125,32 @@ def _infer(
if ref_lower_res is not None:
ref_lower_res = ref_lower_res.detach()
with torch.no_grad():
- z1,z2 = forward_front(masked_image)
+ z_feats = forward_front(masked_image)
+ z_feat_type = type(z_feats)
+ if z_feat_type == tuple:
+ z_feats = list(z_feats)
+ elif z_feat_type == torch.Tensor:
+ z_feats = [z_feats]
+ else:
+ raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
# Inference
mask = mask.to(devices[-1])
ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
ekernel = ekernel.to(devices[-1])
image = image.to(devices[-1])
- z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
- z1.requires_grad, z2.requires_grad = True, True
+ z_feats = [z_feat.detach().to(devices[0]) for z_feat in z_feats]
+ for z_feat in z_feats:
+ z_feat.requires_grad = True
- optimizer = Adam([z1,z2], lr=lr)
+ optimizer = Adam(z_feats, lr=lr)
pbar = tqdm(range(n_iters), leave=False)
for idi in pbar:
optimizer.zero_grad()
- input_feat = (z1,z2)
+ if z_feat_type == tuple:
+ input_feat = tuple(z_feats)
+ elif z_feat_type == torch.Tensor:
+ input_feat = z_feats[0]
for idd, forward_rear in enumerate(forward_rears):
output_feat = forward_rear(input_feat)
if idd < len(devices) - 1:
From 34a1c54f62019e65931a3c9b23ccbf3c3d61cb96 Mon Sep 17 00:00:00 2001
From: Prakhar Kulshreshtha
Date: Fri, 5 Jan 2024 09:57:43 -0800
Subject: [PATCH 2/4] type conversion is repeating, so make it a function
---
saicinpainting/evaluation/refinement.py | 41 ++++++++++++++++---------
1 file changed, 27 insertions(+), 14 deletions(-)
diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py
index fe5c42e4..f4d085f8 100644
--- a/saicinpainting/evaluation/refinement.py
+++ b/saicinpainting/evaluation/refinement.py
@@ -83,6 +83,27 @@ def _l1_loss(
loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))
return loss
+def feats_type_to_list(feats, feats_type):
+ """unpacks the tuple of features into a list"""
+ if feats_type == tuple:
+ feats = list(feats)
+ elif feats_type == torch.Tensor:
+ feats = [feats]
+ else:
+ raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+ return feats
+
+def list_to_feats_type(feats, feats_type):
+ """packs the list of features into the original feature type"""
+ if feats_type == tuple:
+ feats = tuple(feats)
+ elif feats_type == torch.Tensor:
+ feats = feats[0]
+ else:
+ raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+ return feats
+
+
def _infer(
image : torch.Tensor, mask : torch.Tensor,
forward_front : nn.Module, forward_rears : nn.Module,
@@ -126,13 +147,8 @@ def _infer(
ref_lower_res = ref_lower_res.detach()
with torch.no_grad():
z_feats = forward_front(masked_image)
- z_feat_type = type(z_feats)
- if z_feat_type == tuple:
- z_feats = list(z_feats)
- elif z_feat_type == torch.Tensor:
- z_feats = [z_feats]
- else:
- raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+ z_feats_type = type(z_feats)
+ z_feats = feats_type_to_list(z_feats, z_feats_type)
# Inference
mask = mask.to(devices[-1])
ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
@@ -147,16 +163,13 @@ def _infer(
pbar = tqdm(range(n_iters), leave=False)
for idi in pbar:
optimizer.zero_grad()
- if z_feat_type == tuple:
- input_feat = tuple(z_feats)
- elif z_feat_type == torch.Tensor:
- input_feat = z_feats[0]
+ input_feat = list_to_feats_type(z_feats, z_feats_type)
for idd, forward_rear in enumerate(forward_rears):
output_feat = forward_rear(input_feat)
if idd < len(devices) - 1:
- midz1, midz2 = output_feat
- midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
- input_feat = (midz1, midz2)
+ mid_z_feats = feats_type_to_list(output_feat, z_feats_type)
+ mid_z_feats = [mid_z_feat.to(devices[idd+1]) for mid_z_feat in mid_z_feats]
+ input_feat = list_to_feats_type(mid_z_feats, z_feats_type)
else:
pred = output_feat
From 385eb143c29c394290c5e4a880794e79f75291b1 Mon Sep 17 00:00:00 2001
From: Prakhar Kulshreshtha
Date: Fri, 5 Jan 2024 10:11:36 -0800
Subject: [PATCH 3/4] blank space
---
saicinpainting/evaluation/refinement.py | 1 -
1 file changed, 1 deletion(-)
diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py
index f4d085f8..fa298bfd 100644
--- a/saicinpainting/evaluation/refinement.py
+++ b/saicinpainting/evaluation/refinement.py
@@ -103,7 +103,6 @@ def list_to_feats_type(feats, feats_type):
raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
return feats
-
def _infer(
image : torch.Tensor, mask : torch.Tensor,
forward_front : nn.Module, forward_rears : nn.Module,
From 1f081fcf347b2da231ccc5098f36bb29202b9d69 Mon Sep 17 00:00:00 2001
From: Prakhar Kulshreshtha
Date: Fri, 5 Jan 2024 10:59:12 -0800
Subject: [PATCH 4/4] correct the link to Geomagical Labs in README
---
README.md | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/README.md b/README.md
index 446026e3..307b0505 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ LaMa generalizes surprisingly well to much higher resolutions (~2k❗️) than i
-- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](geomagical.com))
+- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](https://www.geomagical.com))