diff --git a/gymnasium/envs/mjx/__init__.py b/gymnasium/envs/mjx/__init__.py new file mode 100644 index 000000000..ad6ec9f00 --- /dev/null +++ b/gymnasium/envs/mjx/__init__.py @@ -0,0 +1,3 @@ +"""Contains the base class and environments for MJX.""" + +from gymnasium.envs.mjx.mjx_env import MJXEnv diff --git a/gymnasium/envs/mjx/ant.py b/gymnasium/envs/mjx/ant.py new file mode 100644 index 000000000..80d0f12e4 --- /dev/null +++ b/gymnasium/envs/mjx/ant.py @@ -0,0 +1,189 @@ +"""Contains the class for the `Ant` environment.""" + +import gymnasium + + +try: + import jax + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + +from typing import Dict, Tuple + +import numpy as np + +from gymnasium.envs.mjx.mjx_env import MJXEnv +from gymnasium.envs.mujoco.ant_v5 import DEFAULT_CAMERA_CONFIG + + +class Ant_MJXEnv(MJXEnv): + # NOTE: MJX does not yet support cfrc_ext and therefore this class can not be instantiated + """Class for Ant.""" + + def __init__( + self, + params: Dict[str, any], + ): + """Sets the `obveration_space`.""" + MJXEnv.__init__(self, params=params) + + self.observation_structure = { + "skipped_qpos": 2 * params["exclude_current_positions_from_observation"], + "qpos": self.mjx_model.nq + - 2 * params["exclude_current_positions_from_observation"], + "qvel": self.mjx_model.nv, + "cfrc_ext": (self.mjx_model.nbody - 1) + * 6 + * params["include_cfrc_ext_in_observation"], + } + + obs_size = self.observation_structure["qpos"] + obs_size += self.observation_structure["qvel"] + obs_size += self.observation_space["cfrc_ext"] + + self.observation_space = gymnasium.spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64 + ) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) from a CUD and `qvel` (velocity elements) from a gaussian.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = params["reset_noise_scale"] * jax.random.normal( + key=rng, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) and `cfrc_ext` (external contact forces) of the robot.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + + if params["exclude_current_positions_from_observation"]: + position = position[2:] + + if params["include_cfrc_ext_in_observation"] is True: + external_contact_forces = self._get_contact_forces(mjx_data, params) + else: + external_contact_forces = jnp.array([]) + + observation = jnp.concatenate((position, velocity, external_contact_forces)) + + return observation + + def _get_contact_forces(self, mjx_data: mjx.Data, params: Dict[str, any]): + """Get External Contact Forces (`cfrc_ext`) clipped by `contact_force_range`.""" + raw_contact_forces = mjx_data.cfrc_ext + min_value, max_value = params["contact_force_range"] + contact_forces = jnp.clip(raw_contact_forces, min_value, max_value) + return contact_forces + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> 
Tuple[jnp.ndarray, Dict]:
+        """Reward = forward_reward + healthy_reward - ctrl_cost - contact_cost."""
+        mjx_data_old = state
+        mjx_data_new = next_state
+
+        xy_position_before = mjx_data_old.xpos[params["main_body"], :2]
+        xy_position_after = mjx_data_new.xpos[params["main_body"], :2]
+
+        xy_velocity = (xy_position_after - xy_position_before) / self.dt(params)
+        x_velocity, y_velocity = xy_velocity
+
+        forward_reward = x_velocity * params["forward_reward_weight"]
+        healthy_reward = (
+            self._gen_is_healthy(mjx_data_new, params) * params["healthy_reward"]
+        )
+        rewards = forward_reward + healthy_reward
+
+        ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action))
+        contact_cost = params["contact_cost_weight"] * jnp.sum(
+            jnp.square(self._get_contact_forces(mjx_data_new, params))
+        )
+        costs = ctrl_cost + contact_cost
+
+        reward = rewards - costs
+
+        reward_info = {
+            "reward_forward": forward_reward,
+            "reward_ctrl": -ctrl_cost,
+            "reward_contact": -contact_cost,
+            "reward_survive": healthy_reward,
+        }
+
+        return reward, reward_info
+
+    def _gen_is_healthy(self, state: mjx.Data, params: Dict[str, any]) -> jnp.ndarray:
+        """Checks if the robot is in a healthy position."""
+        mjx_data = state
+
+        z = mjx_data.qpos[2]
+        min_z, max_z = params["healthy_z_range"]
+        is_healthy = jnp.logical_and(
+            jnp.isfinite(
+                jnp.concatenate((mjx_data.qpos, mjx_data.qvel, mjx_data.act))
+            ).all(),
+            jnp.logical_and(min_z <= z, z <= max_z),
+        )
+        return is_healthy
+
+    def state_info(self, state: mjx.Data, params: Dict[str, any]) -> Dict[str, float]:
+        """Includes state information excluded from `observation()`."""
+        mjx_data = state
+
+        info = {
+            "x_position": mjx_data.qpos[0],
+            "y_position": mjx_data.qpos[1],
+            "distance_from_origin": jnp.linalg.norm(mjx_data.qpos[0:2], ord=2),
+        }
+        return info
+
+    def terminal(
+        self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any]
+    ) -> bool:
+        """Terminates if unhealthy."""
+        return jnp.logical_and(
+            jnp.logical_not(self._gen_is_healthy(state, params)),
+            params["terminate_when_unhealthy"],
+        )
+
+    def get_default_params(**kwargs) -> Dict[str, any]:
+        """Get the default parameters for the Ant environment."""
+        default = {
+            "xml_file": "ant.xml",
+            "frame_skip": 5,
+            "default_camera_config": DEFAULT_CAMERA_CONFIG,
+            "forward_reward_weight": 1,
+            "ctrl_cost_weight": 0.5,
+            "contact_cost_weight": 5e-4,
+            "healthy_reward": 1.0,
+            "main_body": 1,
+            "terminate_when_unhealthy": True,
+            "healthy_z_range": (0.2, 1.0),
+            "contact_force_range": (-1.0, 1.0),
+            "reset_noise_scale": 0.1,
+            "exclude_current_positions_from_observation": True,
+            "include_cfrc_ext_in_observation": True,
+        }
+        return {**MJXEnv.get_default_params(), **default, **kwargs}
diff --git a/gymnasium/envs/mjx/assets b/gymnasium/envs/mjx/assets
new file mode 120000
index 000000000..ca8ce1e5b
--- /dev/null
+++ b/gymnasium/envs/mjx/assets
@@ -0,0 +1 @@
+../mujoco/assets
\ No newline at end of file
diff --git a/gymnasium/envs/mjx/humanoid.py b/gymnasium/envs/mjx/humanoid.py
new file mode 100644
index 000000000..1c46ec689
--- /dev/null
+++ b/gymnasium/envs/mjx/humanoid.py
@@ -0,0 +1,292 @@
+"""Contains the classes for the humanoid environments, `Humanoid` and `HumanoidStandup`."""
+
+import gymnasium
+
+
+try:
+    import jax
+    from jax import numpy as jnp
+    from mujoco import mjx
+except ImportError as e:
+    MJX_IMPORT_ERROR = e
+else:
+    MJX_IMPORT_ERROR = None
+
+from typing import Dict, Tuple
+
+import numpy as np
+
+from gymnasium.envs.mjx.mjx_env import MJXEnv
+from gymnasium.envs.mujoco.humanoid_v5 import (
DEFAULT_CAMERA_CONFIG as HUMANOID_DEFAULT_CAMERA_CONFIG, +) +from gymnasium.envs.mujoco.humanoidstandup_v5 import ( + DEFAULT_CAMERA_CONFIG as HUMANOIDSTANDUP_DEFAULT_CAMERA_CONFIG, +) + + +class BaseHumanoid_MJXEnv(MJXEnv): + # NOTE: MJX does not yet support many features therefore this class can not be instantiated + """Base environment class for humanoid environments such as Humanoid, & HumanoidStandup.""" + + def __init__( + self, + params: Dict[str, any], + ): + """Sets the `obveration_space`.""" + MJXEnv.__init__(self, params=params) + + self.observation_structure = { + "skipped_qpos": 2 * params["exclude_current_positions_from_observation"], + "qpos": self.mjx_model.nq + - 2 * params["exclude_current_positions_from_observation"], + "qvel": self.mjx_model.nv, + "cinert": (self.mjx_model.nbody - 1) + * 10 + * params["include_cinert_in_observation"], + "cvel": (self.mjx_model.nbody - 1) + * 6 + * params["include_cvel_in_observation"], + "qfrc_actuator": (self.mjx_model.nv - 6) + * params["include_qfrc_actuator_in_observation"], + "cfrc_ext": (self.mjx_model.nbody - 1) + * 6 + * params["include_cfrc_ext_in_observation"], + "ten_lenght": 0, + "ten_velocity": 0, + } + + obs_size = self.observation_structure["qpos"] + obs_size += self.observation_structure["qvel"] + obs_size += self.observation_structure["cinert"] + obs_size += self.observation_structure["cvel"] + obs_size += self.observation_structure["qfrc_actuator"] + obs_size += self.observation_space["cfrc_ext"] + + self.observation_space = gymnasium.spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64 + ) + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) and `cfrc_ext` (external contact forces) of the robot.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + + if params["exclude_current_positions_from_observation"]: + position = position[2:] + + if params["include_cinert_in_observation"] is True: + com_inertia = mjx_data.cinert[1:].flatten() + else: + com_inertia = jnp.array([]) + if params["include_cvel_in_observation"] is True: + com_velocity = mjx_data.cvel[1:].flatten() + else: + com_velocity = jnp.array([]) + + if params["include_qfrc_actuator_in_observation"] is True: + actuator_forces = mjx_data.qfrc_actuator[6:].flatten() + else: + actuator_forces = jnp.array([]) + if params["include_cfrc_ext_in_observation"] is True: + external_contact_forces = mjx_data.cfrc_ext[1:].flatten() + else: + external_contact_forces = jnp.array([]) + + observation = jnp.concatenate( + ( + position, + velocity, + com_inertia, + com_velocity, + actuator_forces, + external_contact_forces, + ) + ) + return observation + + def state_info(self, state: mjx.Data, params: Dict[str, any]) -> Dict[str, float]: + """Includes state information exclueded from `observation()`.""" + mjx_data = state + + info = { + "x_position": mjx_data.qpos[0], + "y_position": mjx_data.qpos[1], + "tendon_lenght": mjx_data.ten_length, + "tendon_velocity": mjx_data.ten_velocity, + "distance_from_origin": jnp.linalg.norm(mjx_data.qpos[0:2], ord=2), + } + return info + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) and `qvel` (velocity elements) form a CUD.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = 
self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + +class HumanoidMJXEnv(BaseHumanoid_MJXEnv): + """Class for Humanoid.""" + + def mass_center(self, mjx_data): + """Calculates the xpos based center of mass.""" + mass = np.expand_dims(self.mjx_model.body_mass, axis=1) + xpos = mjx_data.xipos + return (jnp.sum(mass * xpos, axis=0) / jnp.sum(mass))[0:2] + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + """Reward = forward_reward + healthy_reward - ctrl_cost - contact_cost.""" + mjx_data_old = state + mjx_data_new = next_state + + xy_position_before = self.mass_center(mjx_data_old) + xy_position_after = self.mass_center(mjx_data_new) + + xy_velocity = (xy_position_after - xy_position_before) / self.dt + x_velocity, y_velocity = xy_velocity + + forward_reward = params["forward_reward_weight"] * x_velocity + healthy_reward = params["healthy_reward"] * self._gen_is_healty( + mjx_data_new, params + ) + rewards = forward_reward + healthy_reward + + ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action)) + contact_cost = self._get_conctact_cost(mjx_data_new, params) + costs = ctrl_cost + contact_cost + + reward = rewards - costs + + reward_info = { + "reward_survive": healthy_reward, + "reward_forward": forward_reward, + "reward_ctrl": -ctrl_cost, + "reward_contact": -contact_cost, + } + + return reward, reward_info + + def _get_conctact_cost(self, mjx_data: mjx.Data, params: Dict[str, any]): + contact_forces = mjx_data.cfrc_ext + contact_cost = params["contact_cost_weight"] * jnp.sum( + jnp.square(contact_forces) + ) + min_cost, max_cost = params["contact_cost_range"] + contact_cost = jnp.clip(contact_cost, min_cost, max_cost) + return contact_cost + + def _gen_is_healty(self, state: mjx.Data, params: Dict[str, any]): + """Checks if the robot is in a healthy potision.""" + mjx_data = state + + min_z, max_z = params["healthy_z_range"] + is_healthy = min_z < mjx_data.qpos[2] < max_z + + return is_healthy + + def terminal( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> bool: + """Terminates if unhealthy.""" + return jnp.logical_and( + jnp.logical_not(self._gen_is_healty(state, params)), + params["terminate_when_unhealthy"], + ) + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Humanoid environment.""" + default = { + "xml_file": "humanoid.xml", + "frame_skip": 5, + "default_camera_config": HUMANOID_DEFAULT_CAMERA_CONFIG, + "forward_reward_weight": 1.25, + "ctrl_cost_weight": 0.1, + "contact_cost_weight": 5e-7, + "contact_cost_range": (-np.inf, 10.0), + "healthy_reward": 5.0, + "terminate_when_unhealthy": True, + "healthy_z_range": (1.0, 2.0), + "reset_noise_scale": 1e-2, + "exclude_current_positions_from_observation": True, + "include_cinert_in_observation": True, + "include_cvel_in_observation": True, + "include_qfrc_actuator_in_observation": True, + "include_cfrc_ext_in_observation": True, + } + return {**MJXEnv.get_default_params(), **default, **kwargs} + + +class HumanoidStandupMJXEnv(BaseHumanoid_MJXEnv): + """Class for HumanoidStandup.""" + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + 
) -> Tuple[jnp.ndarray, Dict]: + """Reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1.""" + mjx_data_new = next_state + + pos_after = mjx_data_new.qpos[2] + + uph_cost = (pos_after - 0) / self.mjx_model.opt.timestep + + quad_ctrl_cost = params["ctrl_cost_weight"] * jnp.square(action).sum() + + quad_impact_cost = ( + params["impact_cost_weight"] * jnp.square(mjx_data_new.cfrc_ext).sum() + ) + min_impact_cost, max_impact_cost = params["impact_cost_range"] + quad_impact_cost = np.clip(quad_impact_cost, min_impact_cost, max_impact_cost) + + reward = uph_cost - quad_ctrl_cost - quad_impact_cost + 1 + + reward_info = { + "reward_linup": uph_cost, + "reward_quadctrl": -quad_ctrl_cost, + "reward_impact": -quad_impact_cost, + } + + return reward, reward_info + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Humanoid environment.""" + default = { + "xml_file": "humanoidstandup.xml", + "frame_skip": 5, + "default_camera_config": HUMANOIDSTANDUP_DEFAULT_CAMERA_CONFIG, + "uph_cost_weight": 1, + "ctrl_cost_weight": 0.1, + "impact_cost_weight": 0.5e-6, + "impact_cost_range": (-np.inf, 10.0), + "reset_noise_scale": 1e-2, + "exclude_current_positions_from_observation": True, + "include_cinert_in_observation": True, + "include_cvel_in_observation": True, + "include_qfrc_actuator_in_observation": True, + "include_cfrc_ext_in_observation": True, + } + return {**MJXEnv.get_default_params(), **default, **kwargs} diff --git a/gymnasium/envs/mjx/locomotion_2d.py b/gymnasium/envs/mjx/locomotion_2d.py new file mode 100644 index 000000000..ad52f0695 --- /dev/null +++ b/gymnasium/envs/mjx/locomotion_2d.py @@ -0,0 +1,265 @@ +"""Contains the classes for the 2d locomotion environments, `HalfCheetah`, `Hopper` and `Walker2D`.""" + +import gymnasium + + +try: + import jax + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + +from typing import Dict, Tuple + +import numpy as np + +from gymnasium.envs.mjx.mjx_env import MJXEnv +from gymnasium.envs.mujoco.half_cheetah_v5 import ( + DEFAULT_CAMERA_CONFIG as HALFCHEETAH_DEFAULT_CAMERA_CONFIG, +) +from gymnasium.envs.mujoco.hopper_v5 import ( + DEFAULT_CAMERA_CONFIG as HOPPER_DEFAULT_CAMERA_CONFIG, +) +from gymnasium.envs.mujoco.walker2d_v5 import ( + DEFAULT_CAMERA_CONFIG as WALKER2D_DEFAULT_CAMERA_CONFIG, +) + + +class Locomotion_2d_MJXEnv(MJXEnv): + """Base environment class for 2d locomotion environments such as HalfCheetah, Hopper & Walker2d.""" + + def __init__( + self, + params: Dict[str, any], # NOTE not API compliant (yet?) 
+ ): + """Sets the `obveration.shape`.""" + MJXEnv.__init__(self, params=params) + + self.observation_structure = { + "skipped_qpos": 1 * params["exclude_current_positions_from_observation"], + "qpos": self.mjx_model.nq + - 1 * params["exclude_current_positions_from_observation"], + "qvel": self.mjx_model.nv, + } + + obs_size = self.observation_structure["qpos"] + obs_size += self.observation_structure["qvel"] + + self.observation_space = gymnasium.spaces.Box( # TODO use jnp when and if `Box` supports jax natively + low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float32 + ) + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) of the robot.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + + if params["exclude_current_positions_from_observation"]: + position = position[1:] + + observation = jnp.concatenate((position, velocity)) + return observation + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + """Reward = foward_reward + healty_reward - control_cost.""" + mjx_data_old = state + mjx_data_new = next_state + + x_position_before = mjx_data_old.qpos[0] + x_position_after = mjx_data_new.qpos[0] + x_velocity = (x_position_after - x_position_before) / self.dt(params) + + forward_reward = params["forward_reward_weight"] * x_velocity + healthy_reward = params["healthy_reward"] * self._gen_is_healty( + mjx_data_new, params + ) + rewards = forward_reward + healthy_reward + + costs = ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action)) + + reward = rewards - costs + reward_info = { + "reward_survive": healthy_reward, # TODO? 
make optional + "reward_forward": forward_reward, + "reward_ctrl": -ctrl_cost, + "x_velocity": x_velocity, + } + + return reward, reward_info + + def terminal( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> bool: + """Terminates if unhealthy.""" + return jnp.logical_and( + jnp.logical_not(self._gen_is_healty(state, params)), + params["terminate_when_unhealthy"], + ) + + def state_info(self, state: mjx.Data, params: Dict[str, any]) -> Dict[str, float]: + """Includes state information exclueded from `observation()`.""" + mjx_data = state + + info = { + "x_position": mjx_data.qpos[0], + "z_distance_from_origin": mjx_data.qpos[1] - self.mjx_model.qpos0[1], + } + return info + + def _gen_is_healty(self, state: mjx.Data, params: Dict[str, any]): + """Checks if the robot is a healthy potision.""" + mjx_data = state + + z, angle = mjx_data.qpos[1:3] + physics_state = jnp.concatenate( + (mjx_data.qpos[2:], mjx_data.qvel, mjx_data.act) + ) + + min_state, max_state = params["healthy_state_range"] + min_z, max_z = params["healthy_z_range"] + min_angle, max_angle = params["healthy_angle_range"] + + healthy_state = jnp.all( + jnp.logical_and(min_state < physics_state, physics_state < max_state) + ) + healthy_z = jnp.logical_and(min_z < z, z < max_z) + healthy_angle = jnp.logical_and(min_angle < angle, angle < max_angle) + + # NOTE there is probably a clearer way to write this + is_healthy = jnp.logical_and( + jnp.logical_and(healthy_state, healthy_z), healthy_angle + ) + + return is_healthy + + +# The following could maybe be implemented as **kwargs in register() +class HalfCheetahMJXEnv(Locomotion_2d_MJXEnv): + """Class for HalfCheetah.""" + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) from a CUD and `qvel` (velocity elements) from a gaussian.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = params["reset_noise_scale"] * jax.random.normal( + key=rng, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the HalfCheetah environment.""" + default = { + "xml_file": "half_cheetah.xml", + "frame_skip": 5, + "default_camera_config": HALFCHEETAH_DEFAULT_CAMERA_CONFIG, + "forward_reward_weight": 1.0, + "ctrl_cost_weight": 0.1, + "healthy_reward": 0, + "terminate_when_unhealthy": True, + "healthy_state_range": (-jnp.inf, jnp.inf), + "healthy_z_range": (-jnp.inf, jnp.inf), + "healthy_angle_range": (-jnp.inf, jnp.inf), + "reset_noise_scale": 0.1, + "exclude_current_positions_from_observation": True, + } + return {**Locomotion_2d_MJXEnv.get_default_params(), **default, **kwargs} + + +class HopperMJXEnv(Locomotion_2d_MJXEnv): + # NOTE: MJX does not yet support condim=1 and therefore this class can not be instantiated + """Class for Hopper.""" + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) and `qvel` (velocity elements) form a CUD.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + 
qvel = jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Hopper environment.""" + default = { + "xml_file": "hopper.xml", + "frame_skip": 4, + "default_camera_config": HOPPER_DEFAULT_CAMERA_CONFIG, + "forward_reward_weight": 1.0, + "ctrl_cost_weight": 1e-3, + "healthy_reward": 1.0, + "terminate_when_unhealthy": True, + "healthy_state_range": (-100.0, 100.0), + "healthy_z_range": (0.7, jnp.inf), + "healthy_angle_range": (-0.2, 0.2), + "reset_noise_scale": 5e-3, + "exclude_current_positions_from_observation": True, + } + return {**Locomotion_2d_MJXEnv.get_default_params(), **default, **kwargs} + + +class Walker2dMJXEnv(Locomotion_2d_MJXEnv): + """Class for Walker2d.""" + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) and `qvel` (velocity elements) form a CUD.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Walker2d environment.""" + default = { + "xml_file": "walker2d_v5.xml", + "frame_skip": 4, + "default_camera_config": WALKER2D_DEFAULT_CAMERA_CONFIG, + "forward_reward_weight": 1.0, + "ctrl_cost_weight": 1e-3, + "healthy_reward": 1.0, + "terminate_when_unhealthy": True, + "healthy_state_range": (-jnp.inf, jnp.inf), + "healthy_z_range": (0.8, 2.0), + "healthy_angle_range": (-1.0, 1.0), + "reset_noise_scale": 5e-3, + "exclude_current_positions_from_observation": True, + } + return {**Locomotion_2d_MJXEnv.get_default_params(), **default, **kwargs} diff --git a/gymnasium/envs/mjx/manipulation.py b/gymnasium/envs/mjx/manipulation.py new file mode 100644 index 000000000..fbf5f0b85 --- /dev/null +++ b/gymnasium/envs/mjx/manipulation.py @@ -0,0 +1,251 @@ +"""Contains the classes for the manipulation environments, `Pusher`, `Reacher`.""" + +import gymnasium + + +try: + import jax + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + +from typing import Dict, Tuple + +import numpy as np + +from gymnasium.envs.mjx.mjx_env import MJXEnv +from gymnasium.envs.mujoco.pusher_v5 import ( + DEFAULT_CAMERA_CONFIG as PUSHER_DEFAULT_CAMERA_CONFIG, +) +from gymnasium.envs.mujoco.reacher_v5 import ( + DEFAULT_CAMERA_CONFIG as REACHER_HOPPER_DEFAULT_CAMERA_CONFIG, +) + + +class Reacher_MJXEnv(MJXEnv): + """Class for Reacher.""" + + def __init__( + self, + params: Dict[str, any], + ): + """Sets the `obveration_space`.""" + MJXEnv.__init__(self, params=params) + + self.observation_space = gymnasium.spaces.Box( # TODO use jnp when and if `Box` supports jax natively + low=-np.inf, high=np.inf, shape=(10,), dtype=np.float32 + ) + + def _body_goal(self, goal, rng): + goal = jax.random.uniform(key=rng, minval=-0.2, maxval=0.2, shape=(2,)) + return goal + + def _validate_goal(self, goal): + """Check if the `goal` is within a circle of radius 0.2 meters.""" + return 
jnp.less(jnp.linalg.norm(goal), jnp.array(0.2)) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `arm.qpos` (positional elements) and `arm.qvel` (velocity elements) from a CUD and the `goal.qpos` from a cicrular uniform distribution.""" + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=-0.1, maxval=0.1, shape=(self.mjx_model.nq,) + ) + + goal = jax.lax.while_loop( + self._validate_goal, + lambda goal: self._body_goal(goal, rng), + init_val=jnp.array((10.0, 0.0)), + ) + qpos.at[-2:].set(goal) + + qvel = jax.random.uniform( + key=rng, minval=-0.005, maxval=0.005, shape=(self.mjx_model.nv,) + ) + qvel.at[-2:].set(jnp.zeros(2)) + + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def _get_goal(self, mjx_data: mjx.Data) -> jnp.ndarray: + return mjx_data.qpos[-2:] + + def _set_goal(self, mjx_data: mjx.Data, goal: jnp.ndarray) -> mjx.Data: + """Add the coordinate of `goal` to `mjx_data`.""" + mjx_data = mjx_data.replace(qpos=mjx_data.qpos.at[-2:].set(goal)) + return mjx_data + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `sin(theta)` & `cos(theta)` & `qpos` & `qvel` & 'fingertip - target' distance.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + theta = position[:2] + + fingertip_position = mjx_data.xpos[3] # TODO make this dynamic + target_position = mjx_data.xpos[4] # TODO make this dynamic + observation = jnp.concatenate( + ( + jnp.cos(theta), + jnp.sin(theta), + position[2:], + velocity[:2], + (fingertip_position - target_position)[:2], + ) + ) + + return observation + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + """Reward = reward_dist + reward_ctrl.""" + mjx_data = next_state + + fingertip_position = mjx_data.xpos[3] # TODO make this dynamic + target_position = mjx_data.xpos[4] # TODO make this dynamic + + vec = fingertip_position - target_position + reward_dist = -jnp.linalg.norm(vec) * params["reward_dist_weight"] + reward_ctrl = -jnp.square(action).sum() * params["reward_control_weight"] + + reward = reward_dist + reward_ctrl + + reward_info = { + "reward_dist": reward_dist, + "reward_ctrl": reward_ctrl, + } + + return reward, reward_info + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Reacher environment.""" + default = { + "xml_file": "reacher.xml", + "frame_skip": 2, + "default_camera_config": REACHER_HOPPER_DEFAULT_CAMERA_CONFIG, + "reward_dist_weight": 1, + "reward_control_weight": 1, + } + return {**MJXEnv.get_default_params(), **default, **kwargs} + + +class Pusher_MJXEnv(MJXEnv): + # NOTE: MJX does not yet support condim=1 and therefore this class can not be instantiated + """Class for Pusher.""" + + def __init__( + self, + params: Dict[str, any], + ): + """Sets the `obveration_space`.""" + MJXEnv.__init__(self, params=params) + + self.observation_space = gymnasium.spaces.Box( # TODO use jnp when and if `Box` supports jax natively + low=-np.inf, high=np.inf, shape=(23,), dtype=np.float32 + ) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `arm.qpos` (positional elements) and `arm.qvel` (velocity elements) from a CUD and the `goal.qpos` from a cicrular uniform distribution.""" + qpos = 
self.mjx_model.qpos0
+
+        goal_pos = jnp.zeros(2)
+        while True:
+            cylinder_pos = np.concatenate(
+                [
+                    jax.random.uniform(key=rng, minval=-0.3, maxval=0.3, shape=(1,)),
+                    jax.random.uniform(key=rng, minval=-0.2, maxval=0.2, shape=(1,)),
+                ]
+            )
+            if jnp.linalg.norm(cylinder_pos - goal_pos) > 0.17:
+                break
+
+        qpos = qpos.at[-4:-2].set(cylinder_pos)
+        qpos = qpos.at[-2:].set(goal_pos)
+        qvel = jax.random.uniform(
+            key=rng, minval=-0.005, maxval=0.005, shape=(self.mjx_model.nv,)
+        )
+        qvel = qvel.at[-4:].set(0)
+
+        act = jnp.empty(self.mjx_model.na)
+
+        return qpos, qvel, act
+
+    def observation(
+        self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any]
+    ) -> jnp.ndarray:
+        """Observes `qpos` & `qvel` & the `tips_arm`, `object` and `goal` positions."""
+        mjx_data = state
+
+        position = mjx_data.qpos.flatten()
+        velocity = mjx_data.qvel.flatten()
+        tips_arm_position = mjx_data.xpos[10]  # TODO make this dynamic
+        object_position = mjx_data.xpos[11]  # TODO make this dynamic
+        goal_position = mjx_data.xpos[12]  # TODO make this dynamic
+
+        observation = jnp.concatenate(
+            (
+                position[:7],
+                velocity[:7],
+                tips_arm_position,
+                object_position,
+                goal_position,
+            )
+        )
+
+        return observation
+
+    def _get_reward(
+        self,
+        state: mjx.Data,
+        action: jnp.ndarray,
+        next_state: mjx.Data,
+        params: Dict[str, any],
+    ) -> Tuple[jnp.ndarray, Dict]:
+        """Reward = reward_dist + reward_ctrl + reward_near."""
+        mjx_data = next_state
+        tips_arm_position = mjx_data.xpos[10]  # TODO make this dynamic
+        object_position = mjx_data.xpos[11]  # TODO make this dynamic
+        goal_position = mjx_data.xpos[12]  # TODO make this dynamic
+
+        vec_1 = object_position - tips_arm_position
+        vec_2 = object_position - goal_position
+
+        reward_near = -jnp.linalg.norm(vec_1) * params["reward_near_weight"]
+        reward_dist = -jnp.linalg.norm(vec_2) * params["reward_dist_weight"]
+        reward_ctrl = -jnp.square(action).sum() * params["reward_control_weight"]
+
+        reward = reward_dist + reward_ctrl + reward_near
+
+        reward_info = {
+            "reward_dist": reward_dist,
+            "reward_ctrl": reward_ctrl,
+            "reward_near": reward_near,
+        }
+
+        return reward, reward_info
+
+    def get_default_params(**kwargs) -> Dict[str, any]:
+        """Get the default parameters for the Pusher environment."""
+        default = {
+            "xml_file": "pusher.xml",
+            "frame_skip": 5,
+            "default_camera_config": PUSHER_DEFAULT_CAMERA_CONFIG,
+            "reward_near_weight": 0.5,
+            "reward_dist_weight": 1,
+            "reward_control_weight": 0.1,
+        }
+        return {**MJXEnv.get_default_params(), **default, **kwargs}
diff --git a/gymnasium/envs/mjx/mjx_env.py b/gymnasium/envs/mjx/mjx_env.py
new file mode 100644
index 000000000..3098b55de
--- /dev/null
+++ b/gymnasium/envs/mjx/mjx_env.py
@@ -0,0 +1,266 @@
+"""Contains the base class for MJX based robot environments.
+
+Note: This is expected to be used by `gymnasium`, `gymnasium-robotics`, `metaworld` and 3rd party libraries.
+""" + +from typing import Dict, Tuple, Union + +import numpy as np + +import gymnasium +from gymnasium.envs.mujoco import MujocoRenderer +from gymnasium.envs.mujoco.mujoco_env import expand_model_path +from gymnasium.experimental.functional import FuncEnv + + +try: + import jax + import mujoco + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + + +""" +# TODO unit test these +def mjx_get_physics_state(mjx_data: mjx.Data) -> jnp.ndarray: + ""Get physics state of `mjx_data` similar to mujoco.get_state."" + return jnp.concatenate([mjx_data.qpos, mjx_data.qvel, mjx_data.act]) + + +def mjx_set_physics_state(mjx_data: mjx.Data, mjx_physics_state) -> mjx.Data: + ""Sets the physics state in `mjx_data`."" + qpos_end_index = mjx_data.qpos.size + qvel_end_index = qpos_end_index + mjx_data.qvel.size + + qpos = mjx_physics_state[:qpos_end_index] + qvel = mjx_physics_state[qpos_end_index: qvel_end_index] + act = mjx_physics_state[qvel_end_index:] + assert qpos.size == mjx_data.qpos.size + assert qvel.size == mjx_data.qvel.size + assert act.size == mjx_data.act.size + + return mjx_data.replace(qpos=qpos, qvel=qvel, act=act) +""" + + +# TODO add type hint to `params` +# TODO add render `metadata` +# TODO add init_qvel +# TODO create pip install gymnasium[mjx] +class MJXEnv( + FuncEnv[ + mjx.Data, + jnp.ndarray, + jnp.ndarray, + jnp.ndarray, + bool, + MujocoRenderer, + Dict[str, any], + ] +): + """The Base class for MJX Environments in Gymnasium. + + `observation`, `terminal`, and `state_info` should be defined in sub-classes. + """ + + def __init__(self, params: Dict[str, any]): + """Create the `mjx.Model` of the robot defined in `params["xml_file"]`. + + Keep `mujoco.MjModel` of model for rendering purposes. + The Sub-class environments are expected to define `self.observation_space` + """ + if MJX_IMPORT_ERROR is not None: + raise gymnasium.error.DependencyNotInstalled( + f"{MJX_IMPORT_ERROR}. " + "(HINT: you need to install mujoco-mjx, run `pip install gymnasium[mjx]`.)" + ) + + fullpath = expand_model_path(params["xml_file"]) + + self.model = mujoco.MjModel.from_xml_path(fullpath) + self.mjx_model = mjx.put_model(self.model) + + # observation_space: gymnasium.spaces.Box # set by subclass + self.action_space = gymnasium.spaces.Box( + low=self.model.actuator_ctrlrange.T[0], + high=self.model.actuator_ctrlrange.T[1], + dtype=np.float32, + ) + # TODO change bounds and types when and if `Box` supports JAX nativly + # self.action_space = gymnasium.spaces.Box(low=self.mjx_model.actuator_ctrlrange.T[0], high=self.mjx_model.actuator_ctrlrange.T[1], dtype=np.float32) + + def initial(self, rng: jax.random.PRNGKey, params: Dict[str, any]) -> mjx.Data: + """Initializes and returns the `mjx.Data`.""" + # TODO? find a more performant alternative that does not allocate? 
+        mjx_data = mjx.make_data(self.model)
+        qpos, qvel, act = self._gen_init_physics_state(rng, params)
+        mjx_data = mjx_data.replace(qpos=qpos, qvel=qvel, act=act)
+        mjx_data = mjx.forward(self.mjx_model, mjx_data)
+
+        return mjx_data
+
+    def transition(
+        self,
+        state: mjx.Data,
+        action: jnp.ndarray,
+        rng: jax.random.PRNGKey,
+        params: Dict[str, any],
+    ) -> mjx.Data:
+        """Step through the simulator using `action` for `self.dt(params)` (note: the `rng` argument is ignored)."""
+        mjx_data = state
+
+        mjx_data = mjx_data.replace(ctrl=action)
+        """
+        mjx_data = jax.lax.fori_loop(
+            0, params["frame_skip"], lambda _, x: mjx.step(self.mjx_model, x), mjx_data
+        )
+        """
+        # """
+        mjx_data, _ = jax.lax.scan(
+            lambda x, _: (mjx.step(self.mjx_model, x), None),
+            init=mjx_data,
+            xs=None,
+            length=params["frame_skip"],
+        )
+        # """
+
+        # TODO fix sensors with MJX>=3.2
+        return mjx_data
+
+    def reward(
+        self,
+        state: mjx.Data,
+        action: jnp.ndarray,
+        next_state: mjx.Data,
+        rng: jax.random.PRNGKey,
+        params: Dict[str, any],
+    ) -> jnp.ndarray:
+        """Returns the reward."""
+        return self._get_reward(state, action, next_state, params)[0]
+
+    def transition_info(
+        self,
+        state: mjx.Data,
+        action: jnp.ndarray,
+        next_state: mjx.Data,
+        params: Dict[str, any],
+    ) -> Dict:
+        """Includes just reward info."""
+        return self._get_reward(state, action, next_state, params)[1]
+
+    def render_init(
+        self,
+        params: Dict[str, any],
+    ) -> MujocoRenderer:
+        """Returns a `MujocoRenderer` object."""
+        return MujocoRenderer(
+            self.model,
+            None,  # no MuJoCo DATA
+            params["default_camera_config"],
+            params["width"],
+            params["height"],
+            params["max_geom"],
+            params["camera_id"],
+            params["camera_name"],
+        )
+
+    def render_image(
+        self,
+        state: mjx.Data,
+        render_state: MujocoRenderer,
+        params: Dict[str, any],
+    ) -> Tuple[MujocoRenderer, Union[np.ndarray, None]]:
+        """Renders the `mujoco` frame of the environment by converting `mjx.Data` to `mujoco.MjData`.
+
+        NOTE: this function can not be jitted.
+        """
+        mjx_data = state
+        mujoco_renderer = render_state
+
+        data = mjx.get_data(self.model, mjx_data)
+        mujoco.mj_forward(self.model, data)
+
+        mujoco_renderer.data = data
+
+        frame = mujoco_renderer.render(params["render_mode"])
+
+        return mujoco_renderer, frame
+
+    def render_close(
+        self, render_state: MujocoRenderer, params: Dict[str, any]
+    ) -> None:
+        """Closes the `MujocoRenderer` object."""
+        mujoco_renderer = render_state
+        if mujoco_renderer is not None:
+            mujoco_renderer.close()
+
+    def dt(self, params: Dict[str, any]) -> float:
+        """Returns the duration between timesteps (`dt`)."""
+        return self.mjx_model.opt.timestep * params["frame_skip"]
+
+    def _gen_init_physics_state(
+        self, rng, params: Dict[str, any]
+    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
+        """Generates the initial physics state.
+
+        `MJXEnv` equivalent of `MujocoEnv.reset_model`.
+
+        Returns: `(qpos, qvel, act)`
+        """
+        raise NotImplementedError
+
+    def _get_reward(
+        self,
+        state: mjx.Data,
+        action: jnp.ndarray,
+        next_state: mjx.Data,
+        params: Dict[str, any],
+    ) -> Tuple[jnp.ndarray, Dict[str, float]]:
+        """Generates `reward` and `transition_info`; we rely on the JIT's CSE (common subexpression elimination) to optimize it.
+ + Returns: `(reward, reward_info)` + """ + raise NotImplementedError + + def terminal( + self, + state: mjx.Data, + rng: jax.random.PRNGKey, + params: Dict[str, any] | None = None, + ) -> jnp.ndarray: + """Should be overwritten if the sub-class environment terminates.""" + return jnp.array(False) + + def get_default_params(**kwargs) -> Dict[str, any]: + """Generate the default parameters for rendering.""" + default = { + "default_camera_config": {}, + "camera_id": None, + "camera_name": None, + "max_geom": 1000, + "width": 480, + "height": 480, + "render_mode": None, + } + return default + + """ + def mjx_get_physics_state_put_version(self, mjx_data: mjx.Data) -> np.ndarray: + ""version based on @btaba suggestion"" + # data = mujoco.MjData(self.model) + # mjx.device_get_into(data, mjx_data) + data = mjx.get_data(self.model, mjx_data) + state = np.empty(mujoco.mj_stateSize(self.model, mujoco.mjtState.mjSTATE_PHYSICS)) + mujoco.mj_getState(self.model, data, state, spec=mujoco.mjtState.mjSTATE_PHYSICS) + + return state + """ + + +# TODO add vector environment +# TODO consider requirement of `metaworld` & `gymansium_robotics.RobotEnv` & `mo-gymnasium` +# TODO unit testing diff --git a/gymnasium/envs/mjx/pendulum.py b/gymnasium/envs/mjx/pendulum.py new file mode 100644 index 000000000..b62eede13 --- /dev/null +++ b/gymnasium/envs/mjx/pendulum.py @@ -0,0 +1,218 @@ +"""Contains the classes for the Inverted Pendulum environments, `InvertedPendulum`, `InvertedDoublePendulum`.""" + +import gymnasium + + +try: + import jax + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + +from typing import Dict, Tuple + +import numpy as np + +from gymnasium.envs.mjx.mjx_env import MJXEnv +from gymnasium.envs.mujoco.inverted_double_pendulum_v5 import ( + DEFAULT_CAMERA_CONFIG as INVERTED_DOUBLE_PENDULUM_DEFAULT_CAMERA_CONFIG, +) +from gymnasium.envs.mujoco.inverted_pendulum_v5 import ( + DEFAULT_CAMERA_CONFIG as INVERTED_PENDULUM_DEFAULT_CAMERA_CONFIG, +) + + +class InvertedDoublePendulumMJXEnv(MJXEnv): + """Class for InvertedDoublePendulum.""" + + def __init__( + self, + params: Dict[str, any], # NOTE not API compliant (yet?) 
+ ): + """Sets the `obveration_space.shape`.""" + MJXEnv.__init__(self, params=params) + + # TODO use jnp when and if `Box` supports jax natively + self.observation_space = gymnasium.spaces.Box( + low=-np.inf, high=np.inf, shape=(9,), dtype=np.float32 + ) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) from a CUD and `qvel` (velocity elements) from a gaussian.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = params["reset_noise_scale"] * jax.random.normal( + key=rng, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) of the robot.""" + mjx_data = state + + velocity = mjx_data.qvel.flatten() + + observation = jnp.concatenate( + ( + mjx_data.qpos.flatten()[:1], # `cart` x-position + jnp.sin(mjx_data.qpos[1:]), + jnp.cos(mjx_data.qpos[1:]), + jnp.clip(velocity, -10, 10), + jnp.clip(mjx_data.qfrc_constraint, -10, 10)[:1], + ) + ) + return observation + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + """Reward = alive_bonus - dist_penalty - vel_penalty.""" + mjx_data_new = next_state + + v = mjx_data_new.qvel[1:3] + x, _, y = mjx_data_new.site_xpos[0] + + dist_penalty = 0.01 * x**2 + (y - 2) ** 2 + vel_penalty = jnp.array([1e-3, 5e-3]).T * jnp.square(v) + alive_bonus = params["healthy_reward"] * self._gen_is_healty(mjx_data_new) + + reward = alive_bonus - dist_penalty - vel_penalty + + reward_info = { + "reward_survive": alive_bonus, + "distance_penalty": -dist_penalty, + "velocity_penalty": -vel_penalty, + } + + return reward, reward_info + + def _gen_is_healty(self, state: mjx.Data): + """Checks if the pendulum is upright.""" + mjx_data = state + + y = mjx_data.site_xpos[0][2] + + return jnp.array(y > 1) + + def terminal( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> bool: + """Terminates if unhealty.""" + return jnp.logical_not(self._gen_is_healty(state)) + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the parameters for the InvertedDoublePendulum environment.""" + default = { + "xml_file": "inverted_double_pendulum.xml", + "frame_skip": 5, + "default_camera_config": INVERTED_DOUBLE_PENDULUM_DEFAULT_CAMERA_CONFIG, + "healthy_reward": 10.0, + "reset_noise_scale": 0.1, + } + return {**MJXEnv.get_default_params(), **default, **kwargs} + + +class InvertedPendulumMJXEnv(MJXEnv): + """Class for InvertedPendulum.""" + + def __init__( + self, + params: Dict[str, any], # NOTE not API compliant (yet?) 
+ ): + """Sets the `obveration_space.shape`.""" + MJXEnv.__init__(self, params=params) + + self.observation_structure = { + "qpos": self.mjx_model.nq, + "qvel": self.mjx_model.nv, + } + + obs_size = self.observation_structure["qpos"] + obs_size += self.observation_structure["qvel"] + + # TODO use jnp when and if `Box` supports jax natively + self.observation_space = gymnasium.spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float32 + ) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) and `qvel` (velocity elements) form a CUD.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) of the robot.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + + observation = jnp.concatenate((position, velocity)) + return observation + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + reward = jnp.array(self._gen_is_healty(next_state), dtype=jnp.float32) + reward_info = {"reward_survive": reward} + return reward, reward_info + + def _gen_is_healty(self, state: mjx.Data): + """Checks if the pendulum is upright.""" + mjx_data = state + + angle = mjx_data.qpos[1] + + return jnp.abs(angle) <= 0.2 + + def terminal( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> bool: + """Terminates if unhealty.""" + return jnp.logical_not(self._gen_is_healty(state)) + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the parameters for the InvertedPendulum environment.""" + default = { + "xml_file": "inverted_pendulum.xml", + "frame_skip": 2, + "default_camera_config": INVERTED_PENDULUM_DEFAULT_CAMERA_CONFIG, + "reset_noise_scale": 0.01, + } + + return {**MJXEnv.get_default_params(), **default, **kwargs} diff --git a/gymnasium/envs/mjx/swimmer.py b/gymnasium/envs/mjx/swimmer.py new file mode 100644 index 000000000..b3f321f44 --- /dev/null +++ b/gymnasium/envs/mjx/swimmer.py @@ -0,0 +1,128 @@ +"""Contains the class for the `Swimmer` environment.""" + +import gymnasium + + +try: + import jax + from jax import numpy as jnp + from mujoco import mjx +except ImportError as e: + MJX_IMPORT_ERROR = e +else: + MJX_IMPORT_ERROR = None + +from typing import Dict, Tuple + +import numpy as np + +from gymnasium.envs.mjx.mjx_env import MJXEnv + + +class Swimmer_MJXEnv(MJXEnv): + # NOTE: MJX does not yet support condim=1 and therefore this class can not be instantiated + """Class for Swimmer.""" + + def __init__( + self, + params: Dict[str, any], + ): + """Sets the `obveration_space`.""" + MJXEnv.__init__(self, params=params) + + self.observation_structure = { + "skipped_qpos": 2 * params["exclude_current_positions_from_observation"], + "qpos": self.mjx_model.nq + - 2 * params["exclude_current_positions_from_observation"], + "qvel": self.mjx_model.nv, + } + + obs_size = 
self.observation_structure["qpos"] + obs_size += self.observation_structure["qvel"] + + self.observation_space = gymnasium.spaces.Box( + low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64 + ) + + def _gen_init_physics_state( + self, rng, params: Dict[str, any] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """Sets `qpos` (positional elements) and `qvel` (velocity elements) form a CUD.""" + noise_low = -params["reset_noise_scale"] + noise_high = params["reset_noise_scale"] + + qpos = self.mjx_model.qpos0 + jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nq,) + ) + qvel = jax.random.uniform( + key=rng, minval=noise_low, maxval=noise_high, shape=(self.mjx_model.nv,) + ) + act = jnp.empty(self.mjx_model.na) + + return qpos, qvel, act + + def observation( + self, state: mjx.Data, rng: jax.random.PRNGKey, params: Dict[str, any] + ) -> jnp.ndarray: + """Observes the `qpos` (posional elements) and `qvel` (velocity elements) of the robot.""" + mjx_data = state + + position = mjx_data.qpos.flatten() + velocity = mjx_data.qvel.flatten() + + if params["exclude_current_positions_from_observation"]: + position = position[2:] + + observation = jnp.concatenate((position, velocity)) + return observation + + def _get_reward( + self, + state: mjx.Data, + action: jnp.ndarray, + next_state: mjx.Data, + params: Dict[str, any], + ) -> Tuple[jnp.ndarray, Dict]: + """Reward = reward_dist + reward_ctrl.""" + mjx_data_old = state + mjx_data_new = next_state + + x_position_before = mjx_data_old.qpos[0] + x_position_after = mjx_data_new.qpos[0] + x_velocity = (x_position_after - x_position_before) / self.dt(params) + + forward_reward = params["forward_reward_weight"] * x_velocity + ctrl_cost = params["ctrl_cost_weight"] * jnp.sum(jnp.square(action)) + + reward = forward_reward - ctrl_cost + + reward_info = { + "reward_forward": forward_reward, + "reward_ctrl": -ctrl_cost, + } + + return reward, reward_info + + def state_info(self, state: mjx.Data, params: Dict[str, any]) -> Dict[str, float]: + """Includes state information exclueded from `observation()`.""" + mjx_data = state + + info = { + "x_position": mjx_data.qpos[0], + "y_position": mjx_data.qpos[1], + "distance_from_origin": jnp.linalg.norm(mjx_data.qpos[0:2], ord=2), + } + return info + + def get_default_params(**kwargs) -> Dict[str, any]: + """Get the default parameter for the Swimmer environment.""" + default = { + "xml_file": "swimmer.xml", + "frame_skip": 4, + "default_camera_config": {}, + "forward_reward_weight": 1.0, + "ctrl_cost_weight": 1e-4, + "reset_noise_scale": 0.1, + "exclude_current_positions_from_observation": True, + } + return {**MJXEnv.get_default_params(), **default, **kwargs}