Skip to content

Commit 2cf9a05

Browse files
committed
Add custom cost function option.
1 parent 2268560 commit 2cf9a05

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

flybody/agents/losses_mpo.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@ class PenalizationCostRealActions():
2424
actions, and then to calculate penalization cost based on real actions.
2525
"""
2626

27-
def __init__(self, action_spec, dtype=tf.float32):
27+
def __init__(self, action_spec, dtype=tf.float32, custom_cost_func=None):
2828
"""
2929
Args:
3030
action_spec: Action specs of the original, not canonically-wrapped,
3131
environment.
3232
dtype: Action datatype.
33+
custom_cost_func: Optional callable for custom cost calculation.
3334
"""
3435
self._scale = tf.constant(action_spec.maximum - action_spec.minimum,
3536
dtype=dtype)
@@ -39,6 +40,13 @@ def __init__(self, action_spec, dtype=tf.float32):
3940
' Perhaps this action_spec is from an already wrapped'
4041
'canonical environment?'
4142
)
43+
if custom_cost_func is None:
44+
def cost_func(actions):
45+
cost = - tf.norm(actions, axis=-1)
46+
return cost
47+
self._cost_func = cost_func
48+
else:
49+
self._cost_func = custom_cost_func
4250

4351
def __call__(self, actions: 'tf.tensor'):
4452
"""Calculate penalization cost.
@@ -53,9 +61,7 @@ def __call__(self, actions: 'tf.tensor'):
5361
# Transform canonical actions to real actions.
5462
actions = 0.5 * (actions + 1) # In [0, 1].
5563
actions = actions * self._scale + self._offset
56-
# Get cost.
57-
cost = -tf.norm(actions, axis=-1)
58-
return cost
64+
return self._cost_func(actions)
5965

6066

6167
class MPO(snt.Module):

0 commit comments

Comments
 (0)