-
Notifications
You must be signed in to change notification settings - Fork 35
/
utils.py
45 lines (34 loc) · 1.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Misc. utility and helper functions
"""
import copy
from dataclasses import field
from functools import partial
def _update_from_signal(
state_variable,
signal_key,
params,
substep,
state_history,
previous_state,
policy_input,
):
return state_variable, policy_input[signal_key]
def update_from_signal(state_variable, signal_key=None):
    """Build a generic State Update Function that sets a State Variable from a Policy Signal.

    Args:
        state_variable (str): State Variable key
        signal_key (str, optional): Policy Signal key. Defaults to None; when it
            is None (or any other falsy value) ``state_variable`` is used as the
            signal key as well.

    Returns:
        Callable: a ``functools.partial`` of ``_update_from_signal`` with
        ``state_variable`` and the resolved signal key bound positionally.
    """
    # `or` reproduces the "falsy -> fall back to the state variable name" rule.
    resolved_key = signal_key or state_variable
    return partial(_update_from_signal, state_variable, resolved_key)
def local_variables(_locals):
    """Return a new dict with every entry whose key contains ``"__"`` removed.

    The filter is a substring test, so keys such as ``"__name__"`` and also
    ``"a__b"`` are dropped; everything else is kept in its original order.

    Args:
        _locals (Mapping[str, Any]): Namespace to filter, for example the dict
            returned by ``locals()``. Keys are expected to be strings.

    Returns:
        dict: The remaining key/value pairs.
    """
    # One pass over items(): no intermediate key list and no second lookup per key.
    return {key: value for key, value in _locals.items() if "__" not in key}
def default(obj, *, deep=False):
    """Return a dataclass ``field`` whose default is a fresh copy of ``obj``.

    Intended for mutable defaults (lists, dicts, ...), which dataclasses reject
    as plain defaults: the factory runs once per new instance, so instances
    never share the top-level object.

    Args:
        obj: Template value copied for every new instance.
        deep (bool, optional): Keyword-only. When False (the default, and the
            original behaviour) ``copy.copy`` is used, so objects *nested*
            inside ``obj`` (e.g. the lists in a dict of lists) are still shared
            between instances. Pass True to use ``copy.deepcopy`` and give each
            instance fully independent data.

    Returns:
        dataclasses.Field: a field with ``default_factory`` set.
    """
    copier = copy.deepcopy if deep else copy.copy
    return field(default_factory=lambda: copier(obj))