Skip to content

JAX and PZ RL envs + algorithms for swarms of CrazyFlies

License

Notifications You must be signed in to change notification settings

ffelten/CrazyRL

Repository files navigation

The CPU-based environments have been moved into MOMAland for long-term maintenance. I suggest using these if you do not need the Jax-based implementations.

Swarm

pre-commit Code style: black Imports: isort Test: pytest YouTube

CrazyRL

A hardware-accelerated library for doing Multi-Agent Reinforcement Learning with Crazyflie drones. A video showing the results with real drones in our lab is available on YouTube.

It has:

  • ⚡️ A lightweight and fast simulator that is good enough to control Crazyflies in practice;
  • 🤝 A set of environments implemented in Python and Numpy, under the PettingZoo parallel API;
  • 🚀 The same environments implemented in Jax, that can be run fully on GPU;
  • 🧠 MARL algorithms implemented in Jax, both for PettingZoo and for full Jax environments;
  • 🚁 A set of utilities based on the cflib to control actual Crazyflies;
  • ✅ Good quality, tested and documented Python code;

The real-life example shown in the video is the result of executing the policies in real-life after learning in the lightweight simulator. Once the environment trained it can be displayed on simulation environment or in reality with the Crazyflies.

Environments

The red balls represent the position of the controlled drones.

Circle

The drones learn to perform a coordinated circle.

Circle

The yellow balls represent the target position of the drones.

Available in Numpy and JAX version.

Surround

The drones learn to surround a fixed target point.

Surround

The yellow ball represents the target the drones have to surround.

Available in Numpy and JAX version.

Escort

The drones learn to escort a target moving straight to one point to another.

Escort

The yellow ball represents the target the drones have to surround.

Available in Numpy and JAX version.

Catch

The drones learn to catch a target trying to escape.

Catch

The yellow ball represents the target the drones have to surround.

Available in Numpy and JAX version.

Learning

We provide implementations of MAPPO [1] both compatible with a CPU env (PettingZoo parallel API), and a GPU env (our JAX API). These implementations should be very close to each others in terms of sample efficiency but the GPU version is immensely faster in terms of time. We also have a multi-agent version of SAC, MASAC, which is compatible with the CPU envs.

In the above image, we can see that sample efficiency of both MAPPO versions are very close, but the JAX version is much faster in terms of time. Notice that the Jax version can be improved further by relying on vectorized envs.

Multi-Objective Multi-Agent RL

When vmapping over a set of weight vectors to perform MOMARL learning, we achieve sublinear scaling w.r.t. the number of Pareto optimal policies we aim at learning:

Learning time

API

There are examples of usage in the test files and main methods of the environments. Moreover, the learning folder contains examples of MARL algorithms.

Python/Numpy version

Basic version which can be used for training, simulation and the real drones. It follows the PettingZoo parallel API.

Execution :

from crazy_rl.multi_agent.numpy.circle.circle import Circle

env: ParallelEnv = Circle(
    drone_ids=np.array([0, 1]),
    render_mode="human",    # or real, or None
    init_flying_pos=np.array([[0, 0, 1], [2, 2, 1]]),
)

obs, info = env.reset()

done = False
while not done:
    # Execute policy for each agent
    actions: Dict[str, np.ndarray] = {}
    for agent_id in env.possible_agents:
        actions[agent_id] = actor.get_action(obs[agent_id], agent_id)

    obs, _, terminated, truncated, info = env.step(actions)
    done = terminated or truncated

You can have a look at the learning/ folder to see how we execute pre-trained policies.

JAX version

This version is specifically optimized for GPU usage and intended for agent training purposes. However, simulation and real-world functionalities are not available in this version.

Moreover, it is not compliant with the PettingZoo API as it heavily relies on functional programming. We sacrificed the API compatibility for huge performance gains.

Some functionalities are automatically done by wrappers, such as vmap, enabling parallelized training, allowing to leverage all the cores on the GPU. While it offers faster performance on GPUs, it may exhibit slower execution on CPUs.

You can find other wrappers you may need defined in jax_wrappers.

Execution:

from jax import random
from crazy_rl.multi_agent.jax.circle.circle import Circle

parallel_env = Circle(
        num_drones=5,
        init_flying_pos=jnp.array([[0.0, 0.0, 1.0], [2.0, 1.0, 1.0], [0.0, 1.0, 1.0], [2.0, 2.0, 1.0], [1.0, 0.0, 1.0]]),
        num_intermediate_points=100,
    )

num_envs = 3  # number of envs in parallel
seed = 5  # PRNG seed
key = random.PRNGKey(seed)
key, subkeys = random.split(key)
subkeys = random.split(subkeys, num_envs)

# Wrappers
env = AutoReset(env)  # Auto reset the env when done, stores additional info in the dict
env = VecEnv(env)  # Vectorizes the env public methods

obs, info, state = env.reset(subkeys)

# Example of stepping through the 5 parallel environments
for i in range(301):
    actions = jnp.zeros((num_envs, parallel_env.num_drones, parallel_env.action_space(0).shape[0]))
    for env_id, obs in enumerate(obs):
        for agent_id in range(parallel_env.num_drones):
            key, subkey = random.split(key)
            actions[env_id, agent_id] = actor.get_action(obs, agent_id, subkey) # YOUR POLICY HERE

    key, *subkeys = random.split(key, num_envs + 1)
    obs, rewards, term, trunc, info, state = env.step(state, actions, jnp.stack(subkeys))

    # where you would learn or add to buffer

Install & run

Numpy version

poetry install
poetry run python crazy_rl/multi_agent/numpy/circle/circle.py

JAX on CPU

poetry install
poetry run python crazy_rl/multi_agent/jax/circle/circle.py

JAX on GPU

JAX GPU support is not included in the pyproject.toml file, as JAX CPU is the default option. Therefore, you need to manually install JAX GPU and disregard the poetry requirements for this purpose.

poetry install
poetry shell
pip install --upgrade pip

# Using CUDA 12
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Or using CUDA 11
pip install --upgrade "jax[cuda11_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

python crazy_rl/learning/mappo.py

Please refer to the JAX installation GitHub page for the specific CUDA version requirements.

After installation, the JAX version automatically utilizes the GPU as the default device. However, if you prefer to switch to the CPU without reinstalling, you can manually set the device using the following command:

jax.config.update("jax_platform_name", "cpu")

Modes

Simulation

render_mode = "human"

The simulation is a simple particle representation on a 3D cartesian reference based on Crazyflie lighthouse reference frame. It is sufficient since the control of the CrazyFlies is high-level and precise enough.

Available in the Numpy version.

Real

render_mode = "real"

In our experiments, positioning was managed by Lighthouse positioning. It can probably be deployed with other positioning systems too.

Available in the Numpy version.

Guidelines

Firstly configuration of the positioning system has to be saved in a config file using the cfclient app. We have a script which does that in geometry.py. You have to run it for each drone id, e.g. python geometry.py geometry.yaml 1,2,4 0.

Secondly place the turned on drones on your environment, on the ground below the positions given to init_flying_pos in your code. Be careful to put your drones at their right place depending on their id to avoid any crash at start up.

Tips

Verify also that the LEDs on drones aren't red: it means the drone have not enough battery to pursue the mission.

The LED on lighthouse deck have to be green to ensure a good reception of lighthouse positioning.

Dev infos

Structure

The project consists of two versions, each with corresponding files located in the JAX directory and the Numpy directory, respectively.

In the Numpy version, the switch between real environment and simulation is specified through the render_mode option, can be "real", "human" or None.

BaseParallelEnv is the base class for the environment in both versions. It contains the basic methods to interact with the environment. From there, child classes allow to specify specific tasks such as Circle or Hover. utils/ contains the basic functions to interact with the drones, OpenGL stuff for rendering and wrappers which add automatic behaviours to JAX version.

You can explore the test files to gain examples of usage and make comparisons between the Numpy and JAX versions.

Env design

The envs often try to minimize the distance towards the target of each drone. While we initially modelled this as the negative distance, it seems that PPO doesn't like having only negative reward signals. Thus, we opted for potential based rewards [2] instead.

In some cases, an additional conflicting reward is also needed: maximizing the distance towards the other drones. Both rewards are then linearly combined using weights which pre-defined. To find the weights, we used a multi-objective technique consisting in exposing the rewards as vectors and let the learning algorithm try multiple weights (in the Jax version, it is trivially performed by vmapping the learning loop under a few weights). While this seems very simple, it is blazing fast because there is no coordination needed between threads.

Related projects

  • MORL-Baselines and MO-Gymnasium: Algorithms and environments for multi-objective RL, but not multi-agent :-);
  • JaxMARL: Multi-agent RL environments in Jax, published at the same time as this work, but not multi-objective :-);
  • gymnax: RL environments in Jax, but not multi-agent;
  • PureJaxRL: End-to-end RL in Jax, but not multi-agent;
  • PettingZoo: MARL API and environments;
  • MOMAland: MOMARL API and environments, including CrazyRLs, under construction;
  • cflib: Crazyflie Python library;
  • CrazyFlyt: Simulation and real life control of Crazyflies, the main difference with this project is that the simulator is an actual, heavyweight simulator (Pybullet). Hence, it does not have a full jax version. It is in practice more fit for learning controllers, while our project focuses on learning swarm formation.

Citation

If you use this code for your research, please cite this using:

@phdthesis{felten_multi-objective_2024,
	type = {{PhD} {Thesis}},
	title = {Multi-{Objective} {Reinforcement} {Learning}},
	url = {https://hdl.handle.net/10993/61488},
	language = {English},
	school = {Unilu - Université du Luxembourg [FSTM], Luxembourg},
	author = {Felten, Florian},
	month = jun,
	year = {2024},
}

Bibliography

[1] C. Yu et al., “The Surprising Effectiveness of PPO in Cooperative Multi-Agent Games,” presented at the Thirty-sixth Conference on Neural Information Processing Systems Datasets and Benchmarks Track, Jun. 2022. Accessed: Sep. 05, 2023. [Online]. Available: https://openreview.net/forum?id=YVXaxB6L2Pl

[2] A. Ng, D. Harada, and S. J. Russell, “Policy Invariance Under Reward Transformations: Theory and Application to Reward Shaping,” presented at the International Conference on Machine Learning, Jun. 1999. Accessed: Aug. 10, 2023. [Online]. Available: https://www.semanticscholar.org/paper/Policy-Invariance-Under-Reward-Transformations%3A-and-Ng-Harada/94066dc12fe31e96af7557838159bde598cb4f10

About

JAX and PZ RL envs + algorithms for swarms of CrazyFlies

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages