Skip to content

Commit 12375d8

Browse files
committed
Fix types.
1 parent a82d7e6 commit 12375d8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

flybody/agents/actors.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -63,13 +63,13 @@ def _policy(self, observation: types.NestedArray) -> types.NestedTensor:
6363
policy = self._policy_network(batched_observation)
6464

6565
# Sample from the policy if it is stochastic.
66-
action = policy.sample() if isinstance(policy,
67-
tfd.Distribution) else policy
66+
action = policy.sample() if isinstance(
67+
policy, tfd.Distribution) else policy
6868

6969
return action
7070

7171
def select_action(self,
72-
observation: types.NestedArray) -> types.NestedArray:
72+
observation: types.NestedArray) -> np.ndarray:
7373
"""Samples from the policy and returns an action."""
7474
if self._observation_callback is not None:
7575
observation = self._observation_callback(observation)

0 commit comments

Comments
 (0)