Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC] MJX environment prototype #834

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
cf153f6
Add Hopper and Walker2D models for v5
Kallinteris-Andreas May 2, 2023
bc92449
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas May 9, 2023
0cbdd72
Delete hopper_v5.xml
Kallinteris-Andreas May 9, 2023
db3734e
Delete walker2d_v5.xml
Kallinteris-Andreas May 9, 2023
a2d2e64
General MuJoCo Env Documention Cleanup
Kallinteris-Andreas May 9, 2023
f58bb5e
typofix
Kallinteris-Andreas May 9, 2023
7a4bc32
typo fix
Kallinteris-Andreas May 9, 2023
2418631
update following @pseudo-rnd-thoughts reviews
Kallinteris-Andreas May 9, 2023
3b9080b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 5, 2023
77bcb8b
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 16, 2023
7639d18
refactor `tests/env/test_mojoco.py` ->
Kallinteris-Andreas Jun 16, 2023
8eb1b11
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Jun 27, 2023
61d0848
Update setup.py
Kallinteris-Andreas Oct 23, 2023
5831a19
do nothing
Kallinteris-Andreas Oct 23, 2023
803dc49
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 3, 2023
d99cc5d
[MuJoCo] add action space figures
Kallinteris-Andreas Nov 3, 2023
f788bb3
Merge branch 'Farama-Foundation:main' into main
Kallinteris-Andreas Nov 10, 2023
450b471
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Nov 30, 2023
14fb4d8
replace `flat.copy()` with `flatten()`
Kallinteris-Andreas Dec 5, 2023
1583839
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 5, 2023
47a7059
add `MuJoCo.test_model_sensors`
Kallinteris-Andreas Dec 6, 2023
9dc31e2
`test_model_sensors` remove check for standup `v3`
Kallinteris-Andreas Dec 6, 2023
bededa3
factorize `_get_rew()` out of `step`
Kallinteris-Andreas Dec 6, 2023
999d888
some cleanup
Kallinteris-Andreas Dec 6, 2023
0f59baa
support `python==3.8`
Kallinteris-Andreas Dec 6, 2023
76f5e17
fix for real this time
Kallinteris-Andreas Dec 6, 2023
724e47f
`black`
Kallinteris-Andreas Dec 6, 2023
32c1cb8
add prototype
Kallinteris-Andreas Dec 10, 2023
e1772bc
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Dec 10, 2023
30cc231
cleanup
Kallinteris-Andreas Dec 15, 2023
925fcdc
update mjx envs
Kallinteris-Andreas Feb 1, 2024
b7f8806
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 1, 2024
08299e7
huge update
Kallinteris-Andreas Feb 5, 2024
be527c4
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 5, 2024
04ed837
`pre-commit`
Kallinteris-Andreas Feb 5, 2024
72d87ba
Merge branch 'Farama-Foundation:main' into mjx
Kallinteris-Andreas Feb 15, 2024
696b0a0
update
Kallinteris-Andreas Feb 15, 2024
a7c614a
`pre-commit`
Kallinteris-Andreas Feb 15, 2024
3e56f40
fix reacher
Kallinteris-Andreas Feb 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions gymnasium/envs/mjx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"""Contains the base class and environments for MJX."""
from gymnasium.envs.mjx.mjx_env import MJXEnv
188 changes: 188 additions & 0 deletions gymnasium/envs/mjx/ant.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Contains the class for the `Ant` environment."""
import gymnasium


try:
import jax
from jax import numpy as jnp
from mujoco import mjx
except ImportError as e:
MJX_IMPORT_ERROR = e
else:
MJX_IMPORT_ERROR = None

from typing import Dict, Tuple

import numpy as np

from gymnasium.envs.mjx.mjx_env import MJXEnv
from gymnasium.envs.mujoco.ant_v5 import DEFAULT_CAMERA_CONFIG


class Ant_MJXEnv(MJXEnv):
# NOTE: MJX does not yet support cfrc_ext and therefore this class can not be instantiated
"""Class for Ant."""

def __init__(
self,
params: Dict[str, any],
):
"""Sets the `obveration_space`."""
MJXEnv.__init__(self, params=params)

self.observation_structure = {
"skipped_qpos": 2 * params["exclude_current_positions_from_observation"],
"qpos": self.mjx_model.nq
- 2 * params["exclude_current_positions_from_observation"],
"qvel": self.mjx_model.nv,
"cfrc_ext": (self.mjx_model.nbody - 1)
* 6
* params["include_cfrc_ext_in_observation"],
}

obs_size = self.observation_structure["qpos"]
obs_size += self.observation_structure["qvel"]
obs_size += self.observation_space["cfrc_ext"]

self.observation_space = gymnasium.spaces.Box(
low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64
)

def _gen_init_physics_state(
self, rng, params: Dict[str, any]
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Sets `qpos` (positional elements) from a CUD and `qvel` (velocity elements) from a gaussian."""
noise_low = -params["reset_noise_scale"]
noise_high = params["reset_noise_scale"]

qpos = self.mjx_model.qpos0 + jax.random.uniform(
key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,)
)
qvel = params["reset_noise_scale"] * jax.random.normal(
key=rng, shape=(self.mjx_model.nv,)
)
act = jnp.empty(self.mjx_model.na)

return qpos, qvel, act

def observation(
self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any]
) -> jnp.ndarray:
"""Observes the `qpos` (posional elements) and `qvel` (velocity elements) and `cfrc_ext` (external contact forces) of the robot."""
mjx_data = state

position = mjx_data.qpos.flatten()
velocity = mjx_data.qvel.flatten()

if params["exclude_current_positions_from_observation"]:
position = position[2:]

if params["include_cfrc_ext_in_observation"] is True:
external_contact_forces = self._get_contact_forces(mjx_data, params)
else:
external_contact_forces = jnp.array([])

observation = jnp.concatenate((position, velocity, external_contact_forces))

return observation

def _get_contact_forces(self, mjx_data: mjx.Data, params: Dict[str, any]):
"""Get External Contact Forces (`cfrc_ext`) clipped by `contact_force_range`."""
raw_contact_forces = mjx_data.cfrc_ext
min_value, max_value = params["contact_force_range"]
contact_forces = jnp.clip(raw_contact_forces, min_value, max_value)
return contact_forces

def _get_reward(
self,
state: mjx.Data,
action: jnp.ndarray,
next_state: mjx.Data,
params: Dict[str, any],
) -> Tuple[jnp.ndarray, Dict]:
"""Reward = forward_reward + healthy_reward - ctrl_cost - contact cost."""
mjx_data_old = state
mjx_data_new = next_state

xy_position_before = mjx_data_old.xpos[params["main_body"], :2]
xy_position_after = mjx_data_new.xpos[params["main_body"], :2]

xy_velocity = (xy_position_after - xy_position_before) / self.dt(params)
x_velocity, y_velocity = xy_velocity

forward_reward = x_velocity * params["forward_reward_weight"]
healthy_reward = (
self._gen_is_healthy(mjx_data_new, params) * params["healthy_reward"]
)
rewards = forward_reward + healthy_reward

ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action))
contact_cost = params["contact_cost_weight"] * jnp.sum(
jnp.square(self._get_contact_forces(mjx_data_new, params))
)
costs = ctrl_cost + contact_cost

reward = rewards - costs

reward_info = {
"reward_forward": forward_reward,
"reward_ctrl": -ctrl_cost,
"reward_contact": -contact_cost,
"reward_survive": healthy_reward,
}

return reward, reward_info

def _gen_is_healty(self, state: mjx.Data, params: Dict[str, any]) -> jnp.ndarray:
"""Checks if the robot is in a healthy potision."""
mjx_data = state

z = mjx_data.qpos[2]
min_z, max_z = params["healthy_z_range"]
is_healthy = (
jnp.isfinite(
jnp.concatenate(mjx_data.qpos, mjx_data.qvel.mjx_data.act)
).all()
and min_z <= z <= max_z
)
return is_healthy

def state_info(self, state: mjx.Data, params: Dict[str, any]) -> Dict[str, float]:
"""Includes state information exclueded from `observation()`."""
mjx_data = state

info = {
"x_position": mjx_data.qpos[0],
"y_position": mjx_data.qpos[1],
"distance_from_origin": jnp.linalg.norm(mjx_data.qpos[0:2], ord=2),
}
return info

def terminal(
self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any]
) -> bool:
"""Terminates if unhealthy."""
return jnp.logical_and(
jnp.logical_not(self._gen_is_healty(state, params)),
params["terminate_when_unhealthy"],
)

def get_default_params(**kwargs) -> Dict[str, any]:
"""Get the default parameter for the Ant environment."""
default = {
"xml_file": "ant.xml",
"frame_skip": 5,
"default_camera_config": DEFAULT_CAMERA_CONFIG,
"forward_reward_weight": 1,
"ctrl_cost_weight": 0.5,
"contact_cost_weight": 5e-4,
"healthy_reward": 1.0,
"main_body": 1,
"terminate_when_unhealthy": True,
"healthy_z_range": (0.2, 1.0),
"contact_force_range": (-1.0, 1.0),
"reset_noise_scale": 0.1,
"exclude_current_positions_from_observation": True,
"include_cfrc_ext_in_observation": True,
}
return {**MJXEnv.get_default_params(), **default, **kwargs}
1 change: 1 addition & 0 deletions gymnasium/envs/mjx/assets
Loading
Loading