@@ -24,12 +24,13 @@ class PenalizationCostRealActions():
24
24
actions, and then to calculate penalization cost based on real actions.
25
25
"""
26
26
27
- def __init__ (self , action_spec , dtype = tf .float32 ):
27
+ def __init__ (self , action_spec , dtype = tf .float32 , custom_cost_func = None ):
28
28
"""
29
29
Args:
30
30
action_spec: Action specs of the original, not canonically-wrapped,
31
31
environment.
32
32
dtype: Action datatype.
33
+ custom_cost_func: Optional callable for custom cost calculation.
33
34
"""
34
35
self ._scale = tf .constant (action_spec .maximum - action_spec .minimum ,
35
36
dtype = dtype )
@@ -39,6 +40,13 @@ def __init__(self, action_spec, dtype=tf.float32):
39
40
' Perhaps this action_spec is from an already wrapped'
40
41
'canonical environment?'
41
42
)
43
+ if custom_cost_func is None :
44
+ def cost_func (actions ):
45
+ cost = - tf .norm (actions , axis = - 1 )
46
+ return cost
47
+ self ._cost_func = cost_func
48
+ else :
49
+ self ._cost_func = custom_cost_func
42
50
43
51
def __call__ (self , actions : 'tf.tensor' ):
44
52
"""Calculate penalization cost.
@@ -53,9 +61,7 @@ def __call__(self, actions: 'tf.tensor'):
53
61
# Transform canonical actions to real actions.
54
62
actions = 0.5 * (actions + 1 ) # In [0, 1].
55
63
actions = actions * self ._scale + self ._offset
56
- # Get cost.
57
- cost = - tf .norm (actions , axis = - 1 )
58
- return cost
64
+ return self ._cost_func (actions )
59
65
60
66
61
67
class MPO (snt .Module ):
0 commit comments