diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/README.md b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md
new file mode 100644
index 0000000000..3bb276b1d4
--- /dev/null
+++ b/examples/scripts/reinforcement_learning/simple_cart_pole/README.md
@@ -0,0 +1,37 @@
+# Example for Reinforcement Learning (RL) With Gazebo
+
+This demo world shows how you can use SDFormat, Ray RLlib, and Gazebo to perform RL with Python.
+We start with a very simple cart-pole world, defined in the SDF file `cart_pole.sdf`. It is
+analogous to the classic CartPole control task.
+
+## Create a VENV
+
+First, create a virtual environment using Python:
+```
+python3 -m venv venv
+```
+Activate it:
+```
+. venv/bin/activate
+```
+
+Then install the dependencies. Besides Ray RLlib and PyTorch, the example script also uses
+`gymnasium` and `stable-baselines3`:
+```
+pip install "ray[rllib]" torch gymnasium stable-baselines3
+```
+
+In the same terminal, add your Gazebo Python install directory to the `PYTHONPATH`.
+If you built Gazebo from source in the current working directory, this would be:
+```
+export PYTHONPATH=$PYTHONPATH:install/lib/python
+```
+
+You will also need to set `PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION` to `python` due to
+protobuf version mismatches:
+```
+export PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
+```
+
+## Exploring the environment
+
+You can see the environment by running `gz sim cart_pole.sdf`.
diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole.sdf b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole.sdf
new file mode 100644
index 0000000000..6c1a1cd0de
--- /dev/null
+++ b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole.sdf
@@ -0,0 +1,368 @@
+[cart_pole.sdf (368 lines, full XML omitted): an SDF world with 1 ms physics steps,
+a sun light, and a ground plane, plus a "vehicle_green" model made of chassis, pole,
+pole_mass, left_wheel, right_wheel, and caster links. The links are connected by
+revolute wheel joints, a ball joint for the caster, a revolute chassis-to-pole joint
+about the y axis, and a joint attaching pole_mass to the pole.]
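Before moving on to the training script, it can help to confirm that the `PYTHONPATH` and protobuf settings from the README are picked up. Below is a minimal sanity-check sketch (not part of the example itself), reusing the binding modules and Gazebo version suffixes that `cart_pole_env.py` imports; adjust them to your installed release.

```python
# Confirm the Gazebo Python bindings are importable. The version suffixes
# (common6, sim9, transport14, msgs11) match the ones used by cart_pole_env.py;
# change them to match the Gazebo release you have installed.
from gz.common6 import set_verbosity
from gz.sim9 import TestFixture
from gz.transport14 import Node
from gz.msgs11.world_control_pb2 import WorldControl

set_verbosity(4)  # maximum verbosity so binding problems surface immediately
print("Gazebo Python bindings imported successfully")
```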
diff --git a/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py
new file mode 100644
index 0000000000..b45fa9599c
--- /dev/null
+++ b/examples/scripts/reinforcement_learning/simple_cart_pole/cart_pole_env.py
@@ -0,0 +1,120 @@
+import os
+
+import gymnasium as gym
+import numpy as np
+
+from gz.common6 import set_verbosity
+from gz.sim9 import TestFixture, World, world_entity, Model, Link
+from gz.math8 import Vector3d
+from gz.transport14 import Node
+from gz.msgs11.world_control_pb2 import WorldControl
+from gz.msgs11.world_reset_pb2 import WorldReset
+from gz.msgs11.boolean_pb2 import Boolean
+
+from stable_baselines3 import A2C
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+
+
+class GzRewardScorer:
+    """Runs the cart_pole.sdf world through gz-sim's TestFixture and scores it.
+
+    on_pre_update applies a lateral force on the chassis according to the last
+    discrete command, and on_post_update reads the pole angle, pole angular
+    velocity, cart position, and cart velocity to build the observation.
+    """
+
+    def __init__(self):
+        self.fixture = TestFixture(os.path.join(file_path, 'cart_pole.sdf'))
+        self.fixture.on_pre_update(self.on_pre_update)
+        self.fixture.on_post_update(self.on_post_update)
+        #self.fixture.on_configure(self.on_configure)
+        self.command = None
+        self.first_time = True  # Workaround: the configure callback is unreliable, so set up lazily
+        self.fixture.finalize()
+        self.server = self.fixture.server()
+        self.terminated = False
+
+    def on_pre_update(self, info, ecm):
+        if self.first_time:
+            print("Enabling checks")
+            # First call: look up the model and its links, and enable velocity checks.
+            world = World(world_entity(ecm))
+            self.model = Model(world.model_by_name(ecm, "vehicle_green"))
+            self.pole_entity = self.model.link_by_name(ecm, "pole")
+            self.chassis_entity = self.model.link_by_name(ecm, "chassis")
+            self.pole = Link(self.pole_entity)
+            self.pole.enable_velocity_checks(ecm)
+            self.chassis = Link(self.chassis_entity)
+            self.chassis.enable_velocity_checks(ecm)
+            self.first_time = False
+        # Apply the force corresponding to the last discrete action.
+        if self.command == 1:
+            self.chassis.add_world_force(Vector3d(0, 100, 0))
+        elif self.command == 0:
+            self.chassis.add_world_force(Vector3d(0, -100, 0))
+
+    def on_post_update(self, info, ecm):
+        pole_pose = self.pole.world_pose(ecm).rot().euler().y()
+        if self.pole.world_angular_velocity(ecm) is not None:
+            pole_angular_vel = self.pole.world_angular_velocity(ecm).y()
+        else:
+            pole_angular_vel = 0
+            print("Warning: failed to get angular velocity")
+        cart_pose = self.chassis.world_pose(ecm).pos().x()
+        cart_vel = self.chassis.world_linear_velocity(ecm)
+
+        if cart_vel is not None:
+            cart_vel = cart_vel.x()
+        else:
+            cart_vel = 0
+            print("Warning: failed to get cart velocity")
+
+        self.state = np.array([cart_pose, cart_vel, pole_pose, pole_angular_vel], dtype=np.float32)
+        # Terminate once the pole tips too far or the cart drifts out of bounds.
+        if not self.terminated:
+            self.terminated = pole_pose > 0.24 or pole_pose < -0.24 or cart_pose > 4.8 or cart_pose < -4.8
+
+        # Reward 1.0 for every non-terminal step.
+        self.reward = 0.0 if self.terminated else 1.0
+
+    def step(self, action, paused=False):
+        self.command = action
+        self.server.run(True, 1, paused)
+        obs = self.state
+        reward = self.reward
+        return obs, reward, self.terminated, False, {}
+
+    def reset(self):
+        print("Resetting")
+        self.server.reset_all()
+        self.first_time = True
+        self.command = None
+        self.terminated = False
+        obs, _reward, _terminated, _truncated, _info = self.step(None, paused=False)
+        return obs, {}
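+
+
+# CustomCartPole adapts GzRewardScorer to the Gymnasium interface so that an
+# off-the-shelf RL library (stable-baselines3 here) can train against the
+# Gazebo simulation. The observation is [cart position, cart velocity,
+# pole angle, pole angular velocity], with bounds close to those of the
+# classic CartPole task.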
float("inf"), 0.418, 3.4028235e+38]), + (4,), np.float32) + + def reset(self, seed=123): + return self.env.reset() + + def step(self, action): + obs, reward, done, truncated, info = self.env.step(action) + return obs, reward, done, truncated, info + + +env = CustomCartPole({}) +model = A2C("MlpPolicy", env, verbose=1) +model.learn(total_timesteps=10_000) + +vec_env = model.get_env() +obs = vec_env.reset() +for i in range(5000): + action, _state = model.predict(obs, deterministic=True) + obs, reward, done, info = vec_env.step(action) + # Nice to have spawn a gz sim client \ No newline at end of file