diff --git a/grl/utils/loss.py b/grl/utils/loss.py index 8979663d..c40bcedc 100644 --- a/grl/utils/loss.py +++ b/grl/utils/loss.py @@ -77,12 +77,17 @@ def weight_and_sum_discrep_loss(diff: jnp.ndarray, unweighted_err = (diff**2) elif error_type == 'abs': unweighted_err = jnp.abs(diff) + elif error_type == 'max': + unweighted_err = jnp.max(jnp.abs(diff), keepdims=True) else: raise NotImplementedError(f"Error {error_type} not implemented yet in mem_loss fn.") - weighted_err = weight * unweighted_err - if value_type == 'q': - weighted_err = weighted_err.sum(axis=0) + if error_type == 'max': + weighted_err = unweighted_err + else: + weighted_err = weight * unweighted_err + if value_type == 'q': + weighted_err = weighted_err.sum(axis=0) loss = weighted_err.sum() return loss