new way to do rewards is now working - more cleanup to follow
mginoya committed Oct 5, 2024
1 parent 10de7c1 commit c76e896
Showing 11 changed files with 146 additions and 250 deletions.
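Note on the new reward structure: this commit routes all rewards through a Reward object (re-exported from alfredo/rewards/__init__.py below), but the new reward.py file itself is not among the diffs shown in this excerpt. Judging only from how aant.py uses it, namely add_param(name, value) calls before compute() and total_reward += reward_value[0], a minimal sketch could look like the following. The constructor signature, the attribute names, and the scale field are assumptions for illustration, not the committed implementation.

# Hypothetical sketch of the Reward wrapper (reward.py is not shown in this diff).
# It wraps one reward primitive (e.g. rTorques), a scaling factor, and a dict of
# keyword arguments that aant.py fills in each step via add_param().
from jax import numpy as jp

class Reward:
    def __init__(self, f, scale=1.0, kwargs=None):
        self.f = f                                    # underlying reward function
        self.scale = scale                            # weight applied to the raw value
        self.kwargs = dict(kwargs) if kwargs else {}

    def add_param(self, name, value):
        # per-step inputs such as 'sys', 'pipeline_state', or 'jcmd'
        self.kwargs[name] = value

    def compute(self) -> jp.ndarray:
        # returns an array whose first element is the scaled reward,
        # matching total_reward += reward_value[0] in aant.py
        return self.scale * self.f(**self.kwargs)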
1 change: 0 additions & 1 deletion alfredo/agents/A1/alfredo_1.py
@@ -12,7 +12,6 @@
from alfredo.tools import compose_scene
from alfredo.rewards import rConstant
from alfredo.rewards import rHealthy_simple_z
from alfredo.rewards import rSpeed_X
from alfredo.rewards import rControl_act_ss
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
217 changes: 86 additions & 131 deletions alfredo/agents/aant/aant.py
@@ -7,9 +7,9 @@
from jax import numpy as jp

from alfredo.tools import compose_scene
from alfredo.rewards import Reward
from alfredo.rewards import rConstant
from alfredo.rewards import rHealthy_simple_z
from alfredo.rewards import rSpeed_X
from alfredo.rewards import rControl_act_ss
from alfredo.rewards import rTorques
from alfredo.rewards import rTracking_lin_vel
@@ -22,6 +22,9 @@ class AAnt(PipelineEnv):
""" """

def __init__(self,
rewards = {},
env_xml_path = "",
agent_xml_path = "",
ctrl_cost_weight=0.5,
use_contact_forces=False,
contact_cost_weight=5e-4,
@@ -34,20 +34,24 @@ def __init__(self,
backend='generalized',
**kwargs,):

# forcing this model to need an input scene_xml_path or
# the combination of env_xml_path and agent_xml_path
# if none of these options are present, an error will be thrown
path=""

if "env_xml_path" and "agent_xml_path" in kwargs:
env_xp = kwargs["env_xml_path"]
agent_xp = kwargs["agent_xml_path"]
xml_scene = compose_scene(env_xp, agent_xp)
del kwargs["env_xml_path"]
del kwargs["agent_xml_path"]

# env_xml_path and agent_xml_path must be provided
if env_xml_path and agent_xml_path:
self._env_xml_path = env_xml_path
self._agent_xml_path = agent_xml_path

xml_scene = compose_scene(self._env_xml_path, self._agent_xml_path)
sys = mjcf.loads(xml_scene)
else:
raise Exception("env_xml_path & agent_xml_path both must be provided")

# reward dictionary must be provided
if rewards:
self._rewards = rewards
else:
raise Exception("reward_Structure must be in kwargs")

# TODO: clean this up in the future &
# make n_frames a function of input dt
n_frames = 5

if backend in ['spring', 'positional']:
@@ -64,8 +71,10 @@

kwargs['n_frames'] = kwargs.get('n_frames', n_frames)

# Initialize the superclass "PipelineEnv"
super().__init__(sys=sys, backend=backend, **kwargs)

# Setting other object parameters based on input params
self._ctrl_cost_weight = ctrl_cost_weight
self._use_contact_forces = use_contact_forces
self._contact_cost_weight = contact_cost_weight
@@ -83,151 +92,90 @@


def reset(self, rng: jax.Array) -> State:

rng, rng1, rng2, rng3 = jax.random.split(rng, 4)

low, hi = -self._reset_noise_scale, self._reset_noise_scale

jcmd = self._sample_command(rng3)
#wcmd = self._sample_waypoint(rng3)

#print(f"init_q: {self.sys.init_q}")
wcmd = jp.array([0.0, 0.0])

#q = self.sys.init_q
#qd = 0 * jax.random.normal(rng2, (self.sys.qd_size(),))

# initialize position vector with minor randomization in pose
q = self.sys.init_q + jax.random.uniform(
rng1, (self.sys.q_size(),), minval=low, maxval=hi
)


# initialize velocity vector with minor randomization
qd = hi * jax.random.normal(rng2, (self.sys.qd_size(),))

# generate sample commands
jcmd = self._sample_command(rng3)
wcmd = jp.array([0.0, 0.0])

# initialize pipeline_state (the physics state)
pipeline_state = self.pipeline_init(q, qd)

# reset values and metrics
reward, done, zero = jp.zeros(3)

state_info = {
'jcmd':jcmd,
'wcmd':wcmd,
'rewards': {k: 0.0 for k in self._rewards.keys()},
'step': 0,
}

pipeline_state = self.pipeline_init(q, qd)
obs = self._get_obs(pipeline_state, state_info)

reward, done, zero = jp.zeros(3)
metrics = {
'reward_ctrl': zero,
'reward_alive': zero,
'reward_torque': zero,
'reward_lin_vel': zero,
'reward_yaw_vel': zero,
'reward_upright': zero,
'reward_waypoint': zero,
'pos_x_world_abs': zero,
'pos_y_world_abs': zero,
'pos_z_world_abs': zero,
#'dist_goal_x': zero,
#'dist_goal_y': zero,
#'dist_goal_z': zero,
}
metrics = {'pos_x_world_abs': zero,
'pos_y_world_abs': zero,
'pos_z_world_abs': zero,}

for rn, r in self._rewards.items():
metrics[rn] = state_info['rewards'][rn]

# get initial observation vector
obs = self._get_obs(pipeline_state, state_info)

return State(pipeline_state, obs, reward, done, metrics, state_info)

def step(self, state: State, action: jax.Array) -> State:
"""Run one timestep of the environment's dynamics."""


# Save the previous physics state and step physics forward
pipeline_state0 = state.pipeline_state
pipeline_state = self.pipeline_step(pipeline_state0, action)

#print(f"wcmd: {state.info['wcmd']}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
waypoint_cost = rTracking_Waypoint(self.sys,
pipeline_state,
state.info['wcmd'],
weight=0.0,
focus_idx_range=0)

lin_vel_reward = rTracking_lin_vel(self.sys,
pipeline_state,
jp.array([0, 0, 0]), #dummy values for previous CoM
jp.array([0, 0, 0]), #dummy values for current CoM
self.dt,
state.info['jcmd'],
weight=15.0,
focus_idx_range=(0,0))

yaw_vel_reward = rTracking_yaw_vel(self.sys,
pipeline_state,
state.info['jcmd'],
weight=1.0,
focus_idx_range=(0,0))

ctrl_cost = rControl_act_ss(self.sys,
pipeline_state,
action,
weight=0.0)

torque_cost = rTorques(self.sys,
pipeline_state,
action,
weight=0.0)

upright_reward = rUpright(self.sys,
pipeline_state,
weight=0.0)

healthy_reward = rHealthy_simple_z(self.sys,
pipeline_state,
self._healthy_z_range,
early_terminate=self._terminate_when_unhealthy,
weight=0.0,
focus_idx_range=(0, 2))
#reward = 0.0
reward = healthy_reward[0]
reward += ctrl_cost
reward += torque_cost
reward += upright_reward
reward += waypoint_cost
reward += lin_vel_reward
reward += yaw_vel_reward

#print(f"lin_tracking_vel: {lin_vel_reward}")
#print(f"yaw_tracking_vel: {yaw_vel_reward}\n")
# Add all additional parameters to compute rewards
self._rewards['r_lin_vel'].add_param('jcmd', state.info['jcmd'])
self._rewards['r_yaw_vel'].add_param('jcmd', state.info['jcmd'])

# Compute all rewards and accumulate total reward
total_reward = 0.0
for rn, r in self._rewards.items():
r.add_param('sys', self.sys)
r.add_param('pipeline_state', pipeline_state)

reward_value = r.compute()
state.info['rewards'][rn] = reward_value
total_reward += reward_value[0]
# print(f'{rn} reward_val = {reward_value}\n')

# Computing additional metrics as necessary
pos_world = pipeline_state.x.pos[0]
abs_pos_world = jp.abs(pos_world)

#print(f"wcmd: {state.info['wcmd']}")
#print(f"x.pos[0]: {pipeline_state.x.pos[0]}")
#wcmd = state.info['wcmd']
#dist_goal = pos_world[0:2] - wcmd
#print(dist_goal)

#print(f'true position in world: {pos_world}')
#print(f'absolute position in world: {abs_pos_world}')
#print(f"dist_goal: {dist_goal}\n")

# Compute observations
obs = self._get_obs(pipeline_state, state.info)
# print(f"\n")
# print(f"healthy_reward? {healthy_reward}")
# print(f"\n")
#done = 1.0 - healthy_reward[1] if self._terminate_when_unhealthy else 0.0
done = 0.0

# State management
state.info['step'] += 1

state.metrics.update(state.info['rewards'])

state.metrics.update(
reward_ctrl = ctrl_cost,
reward_alive = healthy_reward[0],
reward_torque = torque_cost,
reward_upright = upright_reward,
reward_lin_vel = lin_vel_reward,
reward_yaw_vel = yaw_vel_reward,
reward_waypoint = waypoint_cost,
pos_x_world_abs = abs_pos_world[0],
pos_y_world_abs = abs_pos_world[1],
pos_z_world_abs = abs_pos_world[2],
#dist_goal_x = dist_goal[0],
#dist_goal_y = dist_goal[1],
#dist_goal_z = dist_goal[2],
)

return state.replace(
pipeline_state=pipeline_state, obs=obs, reward=reward, done=done
pipeline_state=pipeline_state, obs=obs, reward=total_reward, done=done
)

def _get_obs(self, pipeline_state, state_info) -> jax.Array:
Expand All @@ -240,16 +188,23 @@ def _get_obs(self, pipeline_state, state_info) -> jax.Array:
torso_pos = pipeline_state.x.pos[0]

jcmd = state_info['jcmd']
wcmd = state_info['wcmd']
#wcmd = state_info['wcmd']

if self._exclude_current_positions_from_observation:
qpos = pipeline_state.q[2:]

return jp.concatenate([qpos] + [qvel] + [local_rpyrate] + [jcmd]) #[jcmd])

obs = jp.concatenate([
jp.array(qpos),
jp.array(qvel),
jp.array(local_rpyrate),
jp.array(jcmd),
])

return obs

def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
x_range = [-25, 25]
y_range = [-25, 25]
x_range = [-25, 25]
y_range = [-25, 25]
z_range = [0, 2]

_, key1, key2, key3 = jax.random.split(rng, 4)
@@ -271,8 +226,8 @@ def _sample_waypoint(self, rng: jax.Array) -> jax.Array:
return wcmd

def _sample_command(self, rng: jax.Array) -> jax.Array:
lin_vel_x_range = [0.0, 0.0] #[m/s]
lin_vel_y_range = [0.0, 0.0] #[m/s]
lin_vel_x_range = [-3.0, 3.0] #[m/s]
lin_vel_y_range = [-3.0, 3.0] #[m/s]
yaw_vel_range = [-1.0, 1.0] #[rad/s]

_, key1, key2, key3 = jax.random.split(rng, 4)
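Usage note: with this change AAnt can no longer be constructed without a rewards dict, and step() expects that dict to contain at least the keys 'r_lin_vel' and 'r_yaw_vel', since it calls add_param('jcmd', ...) on those entries directly. A hypothetical construction, reusing the assumed Reward sketch above (the module path, the scale values, and any per-reward kwargs are guesses; the updated signatures of the reward primitives are not shown in this excerpt):

# Hypothetical usage of the new constructor; only the key names are taken from the diff.
from alfredo.agents.aant.aant import AAnt            # assumed module path
from alfredo.rewards import Reward, rTracking_lin_vel, rTracking_yaw_vel

rewards = {
    'r_lin_vel': Reward(rTracking_lin_vel, scale=15.0),
    'r_yaw_vel': Reward(rTracking_yaw_vel, scale=1.0),
}

env = AAnt(rewards=rewards,
           env_xml_path='path/to/env.xml',            # placeholder paths
           agent_xml_path='path/to/agent.xml',
           backend='generalized')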
3 changes: 2 additions & 1 deletion alfredo/rewards/__init__.py
@@ -1,5 +1,6 @@
from .reward import Reward

from .rConstant import *
from .rSpeed import *
from .rHealthy import *
from .rControl import *
from .rEnergy import *
7 changes: 3 additions & 4 deletions alfredo/rewards/rConstant.py
@@ -8,8 +8,7 @@
from jax import numpy as jp

def rConstant(sys: base.System,
pipeline_state: base.State,
weight=1.0,
focus_idx_range=(0, -1)) -> jp.ndarray:
pipeline_state: base.State,
focus_idx_range=(0, -1)) -> jax.Array:

return jp.array([weight])
return jp.array([1.0])
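Consistent with that structure, rConstant above drops its weight argument and always returns jp.array([1.0]); scaling presumably now lives in the wrapper. Under the same assumed Reward API, pairing the two might look like:

# hypothetical: the weight moves out of the primitive and into the wrapper
r_const = Reward(rConstant, scale=0.5, kwargs={'focus_idx_range': (0, -1)})
r_const.add_param('sys', sys)                         # a brax System
r_const.add_param('pipeline_state', pipeline_state)   # current physics state
val = r_const.compute()                               # 0.5 * jp.array([1.0]) -> jp.array([0.5])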
