Skip to content

Commit

Permalink
[ENH] Speed up build_mask
Browse files Browse the repository at this point in the history
Signed-off-by: chenhe <[email protected]>
  • Loading branch information
chAwater committed Apr 12, 2023
1 parent 6cc98fc commit 4d2d2d0
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions visualizations/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,20 @@


def build_mask(s: int, margin: float = 2., dtype=torch.float32):
mask = torch.zeros(1, 1, s, s, dtype=dtype)
mask = torch.ones(1, 1, s, s, dtype=dtype)
c = (s - 1) / 2
t = (c - margin / 100. * c) ** 2
sig = 2.
for x in range(s):
for y in range(s):
r = (x - c) ** 2 + (y - c) ** 2
if r > t:
mask[..., x, y] = np.exp((t - r) / sig ** 2)
else:
mask[..., x, y] = 1.
y, x = np.ogrid[:s, :s]
r = (x - c) ** 2 + (y - c) ** 2
# r > t
outer_mask = ((t - r) / sig ** 2)
outer_mask = outer_mask ** (r > t) # To prevent overflow
outer_mask = (r > t) * np.exp(outer_mask)
# r <= t
inner_mask = (r <= t)
mask = mask * outer_mask + mask * inner_mask

return mask


Expand Down

0 comments on commit 4d2d2d0

Please sign in to comment.