diff --git a/README.md b/README.md
index 446026e3..307b0505 100644
--- a/README.md
+++ b/README.md
@@ -36,7 +36,7 @@ LaMa generalizes surprisingly well to much higher resolutions (~2k❗️) than i

-- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](geomagical.com))
+- [Feature Refinement to Improve High Resolution Image Inpainting](https://arxiv.org/abs/2206.13644) / [video](https://www.youtube.com/watch?v=gEukhOheWgE) / code https://github.com/advimman/lama/pull/112 / by Geomagical Labs ([geomagical.com](https://www.geomagical.com))

diff --git a/saicinpainting/evaluation/refinement.py b/saicinpainting/evaluation/refinement.py
index d9d3cbac..fa298bfd 100644
--- a/saicinpainting/evaluation/refinement.py
+++ b/saicinpainting/evaluation/refinement.py
@@ -83,6 +83,26 @@ def _l1_loss(
         loss += torch.mean(torch.abs(pred_downscaled[mask_downscaled>=1e-8] - ref[mask_downscaled>=1e-8]))
     return loss
 
+def feats_type_to_list(feats, feats_type):
+    """unpacks the tuple of features into a list"""
+    if feats_type == tuple:
+        feats = list(feats)
+    elif feats_type == torch.Tensor:
+        feats = [feats]
+    else:
+        raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+    return feats
+
+def list_to_feats_type(feats, feats_type):
+    """packs the list of features into the original feature type"""
+    if feats_type == tuple:
+        feats = tuple(feats)
+    elif feats_type == torch.Tensor:
+        feats = feats[0]
+    else:
+        raise NotImplementedError("Expected the output of forward_front to be a tuple or a tensor!")
+    return feats
+
 def _infer(
         image : torch.Tensor, mask : torch.Tensor,
         forward_front : nn.Module, forward_rears : nn.Module,
@@ -125,27 +145,30 @@ def _infer(
     if ref_lower_res is not None:
         ref_lower_res = ref_lower_res.detach()
     with torch.no_grad():
-        z1,z2 = forward_front(masked_image)
+        z_feats = forward_front(masked_image)
+        z_feats_type = type(z_feats)
+        z_feats = feats_type_to_list(z_feats, z_feats_type)
     # Inference
     mask = mask.to(devices[-1])
     ekernel = torch.from_numpy(cv2.getStructuringElement(cv2.MORPH_ELLIPSE,(15,15)).astype(bool)).float()
     ekernel = ekernel.to(devices[-1])
     image = image.to(devices[-1])
-    z1, z2 = z1.detach().to(devices[0]), z2.detach().to(devices[0])
-    z1.requires_grad, z2.requires_grad = True, True
+    z_feats = [z_feat.detach().to(devices[0]) for z_feat in z_feats]
+    for z_feat in z_feats:
+        z_feat.requires_grad = True
 
-    optimizer = Adam([z1,z2], lr=lr)
+    optimizer = Adam(z_feats, lr=lr)
 
     pbar = tqdm(range(n_iters), leave=False)
    for idi in pbar:
        optimizer.zero_grad()
-        input_feat = (z1,z2)
+        input_feat = list_to_feats_type(z_feats, z_feats_type)
         for idd, forward_rear in enumerate(forward_rears):
             output_feat = forward_rear(input_feat)
             if idd < len(devices) - 1:
-                midz1, midz2 = output_feat
-                midz1, midz2 = midz1.to(devices[idd+1]), midz2.to(devices[idd+1])
-                input_feat = (midz1, midz2)
+                mid_z_feats = feats_type_to_list(output_feat, z_feats_type)
+                mid_z_feats = [mid_z_feat.to(devices[idd+1]) for mid_z_feat in mid_z_feats]
+                input_feat = list_to_feats_type(mid_z_feats, z_feats_type)
             else:
                 pred = output_feat