From 6866fea44aa7d0d4897ed32f6dd04540396d48c8 Mon Sep 17 00:00:00 2001 From: Prakhar Kulshreshtha Date: Thu, 4 Jan 2024 09:43:42 -0800 Subject: [PATCH 1/4] make the latent vector flexible so that it can be both tuple and tensor --- saicinpainting/evaluation/refinement.py | 21 ++++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py index d9d3cbac..fe5c42e4 100644 --- a/saicinpainting/evaluation/refinement.py +++ b/saicinpainting/evaluation/refinement.py @@ -125,21 +125,32 @@ def _infer( if ref_lower_res is not None: ref_lower_res = ref_lower_res.detach() with torch.no_grad(): - z1,z2 = forward_front(masked_image) + z_feats = forward_front(masked_image) + z_feat_type = type(z_feats) + if z_feat_type == tuple: + z_feats = list(z_feats) + elif z_feat_type == torch.Tensor: + z_feats = [z_feats] + else: + raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!") # Inference mask = mask.to(devices[-1]) ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float() ekernel = ekernel.to(devices[-1]) image = image.to(devices[-1]) - z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0]) - z1.requires_grad, z2.requires_grad = True, True + z_feats = [z_feat.detach().to(devices[0]) for z_feat in z_feats] + for z_feat in z_feats: + z_feat.requires_grad = True - optimizer = Adam([z1,z2], lr=lr) + optimizer = Adam(z_feats, lr=lr) pbar = tqdm(range(n_iters), leave=False) for idi in pbar: optimizer.zero_grad() - input_feat = (z1,z2) + if z_feat_type == tuple: + input_feat = tuple(z_feats) + elif z_feat_type == torch.Tensor: + input_feat = z_feats[0] for idd, forward_rear in enumerate(forward_rears): output_feat = forward_rear(input_feat) if idd < len(devices) - 1: From 34a1c54f62019e65931a3c9b23ccbf3c3d61cb96 Mon Sep 17 00:00:00 2001 From: Prakhar Kulshreshtha Date: Fri, 5 Jan 2024 09:57:43 -0800 Subject: [PATCH 2/4] type conversion is repeating, so make it a function --- saicinpainting/evaluation/refinement.py | 41 ++++++++++++++++--------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py index fe5c42e4..f4d085f8 100644 --- a/saicinpainting/evaluation/refinement.py +++ b/saicinpainting/evaluation/refinement.py @@ -83,6 +83,27 @@ def _l1_loss( loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8])) return loss +def feats_type_to_list(feats, feats_type): + """unpacks the tuple of features into a list""" + if feats_type == tuple: + feats = list(feats) + elif feats_type == torch.Tensor: + feats = [feats] + else: + raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!") + return feats + +def list_to_feats_type(feats, feats_type): + """packs the list of features into the original feature type""" + if feats_type == tuple: + feats = tuple(feats) + elif feats_type == torch.Tensor: + feats = feats[0] + else: + raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!") + return feats + + def _infer( image : torch.Tensor, mask : torch.Tensor, forward_front : nn.Module, forward_rears : nn.Module, @@ -126,13 +147,8 @@ def _infer( ref_lower_res = ref_lower_res.detach() with torch.no_grad(): z_feats = forward_front(masked_image) - z_feat_type = type(z_feats) - if z_feat_type == tuple: - z_feats = list(z_feats) - elif z_feat_type == torch.Tensor: - z_feats = [z_feats] - else: - raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!") + z_feats_type = type(z_feats) + z_feats = feats_type_to_list(z_feats, z_feats_type) # Inference mask = mask.to(devices[-1]) ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float() @@ -147,16 +163,13 @@ def _infer( pbar = tqdm(range(n_iters), leave=False) for idi in pbar: optimizer.zero_grad() - if z_feat_type == tuple: - input_feat = tuple(z_feats) - elif z_feat_type == torch.Tensor: - input_feat = z_feats[0] + input_feat = list_to_feats_type(z_feats, z_feats_type) for idd, forward_rear in enumerate(forward_rears): output_feat = forward_rear(input_feat) if idd < len(devices) - 1: - midz1, midz2 = output_feat - midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1]) - input_feat = (midz1, midz2) + mid_z_feats = feats_type_to_list(output_feat, z_feats_type) + mid_z_feats = [mid_z_feat.to(devices[idd+1]) for mid_z_feat in mid_z_feats] + input_feat = list_to_feats_type(mid_z_feats, z_feats_type) else: pred = output_feat From 385eb143c29c394290c5e4a880794e79f75291b1 Mon Sep 17 00:00:00 2001 From: Prakhar Kulshreshtha Date: Fri, 5 Jan 2024 10:11:36 -0800 Subject: [PATCH 3/4] blank space --- saicinpainting/evaluation/refinement.py | 1 - 1 file changed, 1 deletion(-) diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py index f4d085f8..fa298bfd 100644 --- a/saicinpainting/evaluation/refinement.py +++ b/saicinpainting/evaluation/refinement.py @@ -103,7 +103,6 @@ def list_to_feats_type(feats, feats_type): raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!") return feats - def _infer( image : torch.Tensor, mask : torch.Tensor, forward_front : nn.Module, forward_rears : nn.Module, From 1f081fcf347b2da231ccc5098f36bb29202b9d69 Mon Sep 17 00:00:00 2001 From: Prakhar Kulshreshtha Date: Fri, 5 Jan 2024 10:59:12 -0800 Subject: [PATCH 4/4] correct the link to Geomagical Labs in README --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 446026e3..307b0505 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,7 @@ LaMa generalizes surprisingly well to much higher resolutions (~2k❗️) than i

-- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](geomagical.com)) +- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](https://www.geomagical.com))