Add option to use max norm in LD calculation
camall3n committed Aug 4, 2024
1 parent 8221a71 commit 783b661
Showing 1 changed file with 8 additions and 3 deletions.
grl/utils/loss.py: 11 changes (8 additions & 3 deletions)
@@ -77,12 +77,17 @@ def weight_and_sum_discrep_loss(diff: jnp.ndarray,
         unweighted_err = (diff**2)
     elif error_type == 'abs':
         unweighted_err = jnp.abs(diff)
+    elif error_type == 'max':
+        unweighted_err = jnp.max(jnp.abs(diff), keepdims=True)
     else:
         raise NotImplementedError(f"Error {error_type} not implemented yet in mem_loss fn.")
 
-    weighted_err = weight * unweighted_err
-    if value_type == 'q':
-        weighted_err = weighted_err.sum(axis=0)
+    if error_type == 'max':
+        weighted_err = unweighted_err
+    else:
+        weighted_err = weight * unweighted_err
+        if value_type == 'q':
+            weighted_err = weighted_err.sum(axis=0)
 
     loss = weighted_err.sum()
     return loss
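For context on the new branch: jnp.max with keepdims=True and no axis reduces the whole array to a single-element array holding the largest absolute discrepancy, so the weighting and per-action sum are skipped and the final .sum() simply returns that max. A minimal sketch of the two paths, using hypothetical toy shapes (the real diff and weight arrays come from the surrounding loss code):

import jax.numpy as jnp

# Hypothetical toy inputs: a (num_actions, num_states) discrepancy matrix
# and a per-state weight vector; the real shapes come from the caller.
diff = jnp.array([[0.1, -0.4],
                  [0.25, 0.0]])
weight = jnp.array([0.5, 0.5])

# 'abs' branch with value_type == 'q': weight element-wise,
# sum over the action axis, then sum to a scalar.
abs_loss = (weight * jnp.abs(diff)).sum(axis=0).sum()    # 0.375

# 'max' branch: reduce to the (1, 1) max-abs entry; the trailing
# .sum() in the loss function just unwraps it to a scalar.
max_err = jnp.max(jnp.abs(diff), keepdims=True)          # [[0.4]]
max_loss = max_err.sum()                                  # 0.4

Since the max norm already collapses everything to one value, applying weight would only rescale it, which is presumably why the commit routes the 'max' case around the weighting entirely.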
