Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
1f6f327
Fixed MultiSyncCollector set_seed and split_trajs issue
ParamThakkar123 Jan 19, 2026
e2aaf6b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 20, 2026
40642d5
Revert "Fixed MultiSyncCollector set_seed and split_trajs issue"
ParamThakkar123 Jan 20, 2026
efdc89c
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 21, 2026
628f44b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 23, 2026
a476a77
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 24, 2026
0f565c5
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 25, 2026
7fb086b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 26, 2026
ff72793
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 28, 2026
69001ed
Added Support for index_select in TensorSpec
ParamThakkar123 Jan 28, 2026
4ab13be
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 29, 2026
2e8face
rebase
ParamThakkar123 Jan 29, 2026
56e1529
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 31, 2026
ba6a19f
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 4, 2026
8be545b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 5, 2026
54abe29
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 8, 2026
6b099f5
Add OpenEnv environments
ParamThakkar123 Feb 9, 2026
8f41393
Updates
ParamThakkar123 Feb 9, 2026
4a9a6e7
Merge branch 'main' of https://github.com/pytorch/rl into add/openenv
ParamThakkar123 Feb 10, 2026
9b8b119
Using ChatEnv as base class
ParamThakkar123 Feb 10, 2026
f790843
Merge branch 'main' of https://github.com/pytorch/rl into add/openenv
ParamThakkar123 Feb 10, 2026
cf6940c
Added History
ParamThakkar123 Feb 10, 2026
16ee592
Merge branch 'main' of https://github.com/pytorch/rl into add/openenv
ParamThakkar123 Feb 11, 2026
a6c0fda
Fixes
ParamThakkar123 Feb 11, 2026
9af50f2
Fixed circular import error
ParamThakkar123 Feb 12, 2026
78dd00a
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 12, 2026
4918107
Merge branch 'main' of https://github.com/pytorch/rl into add/openenv
ParamThakkar123 Feb 12, 2026
93c67a0
edits
vmoens Feb 13, 2026
f3eb268
Merge remote-tracking branch 'origin/main' into add/openenv
vmoens Feb 13, 2026
94fe080
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 13, 2026
e1212f2
Merge branch 'main' of https://github.com/ParamThakkar123/rl into add…
ParamThakkar123 Feb 17, 2026
8902b83
Merge branch 'main' of https://github.com/pytorch/rl into add/openenv
ParamThakkar123 Feb 17, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/unittest/linux_libs/scripts_openenv/environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Conda/pip dependencies for the OpenEnv unit-test job.
# PyTorch is not listed here: install.sh installs it after this environment is built.
# setup_env.sh appends a `python=<version>` entry to this file before running
# `conda env update`.
channels:
- pytorch
- defaults
dependencies:
- pip
- pip:
- hypothesis
- future
- cloudpickle
- pytest
- pytest-cov
- pytest-mock
- pytest-instafail
- pytest-rerunfailures
- pytest-error-for-skips
- expecttest
- pybind11[global]
- pyyaml
- scipy
- hydra-core
- openenv-core
60 changes: 60 additions & 0 deletions .github/unittest/linux_libs/scripts_openenv/install.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash
#
# Install PyTorch, tensordict and torchrl into the conda environment at ./env
# for the OpenEnv unit-test job.
#
# Environment variables read:
#   CU_VERSION     "cpu", a compact tag such as "cu128", or a dotted version
#                  such as "12.8"
#   TORCH_VERSION  "nightly" or "stable"; any other value aborts the script
#   RELEASE        "0" installs tensordict from its git main branch,
#                  anything else installs the PyPI release

unset PYTORCH_VERSION
# For unittest, nightly PyTorch is used as the following section,
# so no need to set PYTORCH_VERSION.
# In fact, keeping PYTORCH_VERSION forces us to hardcode PyTorch version in config.

set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

if [ "${CU_VERSION:-}" == cpu ] ; then
    version="cpu"
else
    # Accept the dotted form ("12.8") as-is; the length-based slicing below is
    # only meaningful for compact tags ("cu92", "cu128") and would otherwise
    # turn "12.8" into "..8".
    if [[ "${CU_VERSION:-}" =~ ^[0-9]+\.[0-9]+$ ]]; then
        CUDA_VERSION="${CU_VERSION}"
    elif [[ ${#CU_VERSION} -eq 4 ]]; then
        CUDA_VERSION="${CU_VERSION:2:1}.${CU_VERSION:3:1}"
    elif [[ ${#CU_VERSION} -eq 5 ]]; then
        CUDA_VERSION="${CU_VERSION:2:2}.${CU_VERSION:4:1}"
    fi
    echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
    version="$(python -c "print('.'.join(\"${CUDA_VERSION}\".split('.')[:2]))")"
fi

# submodules
git submodule sync && git submodule update --init --recursive

# Wheel flavour: CPU wheels when CU_VERSION is "cpu", otherwise the cu128 wheels.
if [ "${CU_VERSION:-}" == cpu ] ; then
    torch_build="cpu"
else
    torch_build="cu128"
fi

printf "Installing PyTorch (%s) with %s\n" "${TORCH_VERSION:-unset}" "${torch_build}"
if [[ "$TORCH_VERSION" == "nightly" ]]; then
    pip3 install --pre torch --index-url "https://download.pytorch.org/whl/nightly/${torch_build}" -U
elif [[ "$TORCH_VERSION" == "stable" ]]; then
    pip3 install torch --index-url "https://download.pytorch.org/whl/${torch_build}"
else
    # Report the offending value on stderr so the CI log shows why we bailed out.
    printf "Failed to install pytorch: unsupported TORCH_VERSION '%s'\n" "${TORCH_VERSION:-}" >&2
    exit 1
fi

# install tensordict
if [[ "$RELEASE" == 0 ]]; then
    pip3 install git+https://github.com/pytorch/tensordict.git
else
    pip3 install tensordict
fi

# smoke test
python -c "import functorch;import tensordict"

printf "* Installing torchrl\n"
python -m pip install -e . --no-build-isolation

# smoke test
python -c "import torchrl"
6 changes: 6 additions & 0 deletions .github/unittest/linux_libs/scripts_openenv/post_process.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#!/usr/bin/env bash
#
# Post-processing step of the OpenEnv unit-test job.
# It currently only activates the conda environment at ./env; no further
# commands are run here.

# Abort on the first failing command.
set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env
28 changes: 28 additions & 0 deletions .github/unittest/linux_libs/scripts_openenv/run_test.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#!/usr/bin/env bash
#
# Run the OpenEnv tests from test/test_libs.py under coverage inside the conda
# environment at ./env, then combine the coverage data and write an XML report.

# Abort on the first failing command.
set -e

eval "$(./conda/bin/conda shell.bash hook)"
conda activate ./env

apt-get update && apt-get install -y git wget cmake

export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
# Print the PyTorch/system configuration into the CI log.
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'

# NOTE(review): root_dir, env_dir and lib_dir are assigned here but not
# referenced again in this script.
root_dir="$(git rev-parse --show-toplevel)"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

conda deactivate && conda activate ./env

# this workflow only tests the libs
# Fail fast if the openenv package is not importable.
python -c "import openenv"

# `-k TestOpenEnv` restricts the pytest run to the OpenEnv test class.
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test/test_libs.py --instafail -v --durations 200 --capture no -k TestOpenEnv --error-for-skips

coverage combine -q
coverage xml -i
69 changes: 69 additions & 0 deletions .github/unittest/linux_libs/scripts_openenv/setup_env.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
#!/usr/bin/env bash

# This script sets up the environment in which the unit tests are run.
# To speed up the CI time, the resulting environment is cached.
#
# Do not install PyTorch and torchvision here, otherwise they also get cached.

# Abort on the first failing command and echo each script line as it is read.
set -e
set -v

apt-get update && apt-get upgrade -y && apt-get install -y git cmake
# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'
# System packages: build toolchain, download/unpack tools and GL/EGL libraries.
apt-get install -y wget \
gcc \
g++ \
unzip \
curl \
patchelf \
libosmesa6-dev \
libgl1-mesa-glx \
libglfw3 \
swig3.0 \
libglew-dev \
libglvnd0 \
libgl1 \
libglx0 \
libegl1 \
libgles2

# Upgrade specific package
apt-get upgrade -y libstdc++6

# this_dir: directory containing this script (environment.yml lives next to it).
# root_dir: top level of the git checkout; conda and the test env are placed under it.
this_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
root_dir="$(git rev-parse --show-toplevel)"
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"

cd "${root_dir}"

# Select the OS tag used in the Miniconda installer file name.
case "$(uname -s)" in
Darwin*) os=MacOSX;;
*) os=Linux
esac

# 1. Install conda at ./conda
# Skipped when ${conda_dir} already exists (e.g. restored from the CI cache).
if [ ! -d "${conda_dir}" ]; then
    printf "* Installing conda\n"
    # Download the installer over HTTPS from the current Anaconda host rather
    # than plain HTTP from the legacy repo.continuum.io alias, so the script
    # that is about to be executed cannot be tampered with in transit.
    wget -O miniconda.sh "https://repo.anaconda.com/miniconda/Miniconda3-latest-${os}-x86_64.sh"
    bash ./miniconda.sh -b -f -p "${conda_dir}"
fi
# Load conda's shell functions so `conda activate` works in this script.
eval "$("${conda_dir}/bin/conda" shell.bash hook)"

# 2. Create test environment at ./env
# Pass the version as an argument instead of interpolating it into the printf
# format string, so a stray '%' or '\' in the value cannot alter the output.
printf "python: %s\n" "${PYTHON_VERSION}"
if [ ! -d "${env_dir}" ]; then
    printf "* Creating a test environment\n"
    conda create --prefix "${env_dir}" -y python="$PYTHON_VERSION"
fi
conda activate "${env_dir}"

# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
# Pin the interpreter version by appending it to the dependency list, then
# print the resulting file into the CI log.
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
cat "${this_dir}/environment.yml"

pip install pip --upgrade

conda env update --file "${this_dir}/environment.yml" --prune
37 changes: 37 additions & 0 deletions .github/workflows/test-linux-libs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,43 @@ jobs:
bash .github/unittest/linux_libs/scripts_meltingpot/run_test.sh
bash .github/unittest/linux_libs/scripts_meltingpot/post_process.sh

# OpenEnv integration tests. Runs on pushes, and on pull requests labelled
# 'Environments' or 'Environments/openenv'.
# NOTE(review): the docker image is CUDA 12.4 while gpu-arch-version and the
# exported CU_VERSION are 12.8 — confirm this pairing is intended.
# NOTE(review): `=~ release/*` is an unanchored regex ("release" followed by
# zero or more "/"), so any ref containing "release" selects the stable
# branch of the script — confirm that is the intent.
unittests-openenv:
strategy:
matrix:
python_version: ["3.10"]
cuda_arch_version: ["12.8"]
if: ${{ github.event_name == 'push' || contains(github.event.pull_request.labels.*.name, 'Environments') || contains(github.event.pull_request.labels.*.name, 'Environments/openenv') }}
uses: pytorch/test-infra/.github/workflows/linux_job_v2.yml@main
with:
repository: pytorch/rl
runner: "linux.g5.4xlarge.nvidia.gpu"
gpu-arch-type: cuda
gpu-arch-version: "12.8"
docker-image: "nvidia/cuda:12.4.0-devel-ubuntu22.04"
timeout: 120
script: |
if [[ "${{ github.ref }}" =~ release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
else
export RELEASE=0
export TORCH_VERSION=nightly
fi

set -euo pipefail
export PYTHON_VERSION="3.10"
export CU_VERSION="12.8"
export TAR_OPTIONS="--no-same-owner"
export UPLOAD_CHANNEL="nightly"
export TF_CPP_MIN_LOG_LEVEL=0
export BATCHED_PIPE_TIMEOUT=60
export TD_GET_DEFAULTS_TO_NONE=1

bash .github/unittest/linux_libs/scripts_openenv/setup_env.sh
bash .github/unittest/linux_libs/scripts_openenv/install.sh
bash .github/unittest/linux_libs/scripts_openenv/run_test.sh
bash .github/unittest/linux_libs/scripts_openenv/post_process.sh

unittests-open_spiel:
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions docs/source/reference/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,7 @@ Environment Library Configurations
MeltingpotEnvConfig
MOGymEnvConfig
MultiThreadedEnvConfig
OpenEnvEnvConfig
OpenMLEnvConfig
OpenSpielEnvConfig
PettingZooEnvConfig
Expand Down
2 changes: 2 additions & 0 deletions docs/source/reference/envs_libraries.rst
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ Available wrappers
MOGymWrapper
MultiThreadedEnv
MultiThreadedEnvWrapper
OpenEnvEnv
OpenEnvWrapper
OpenMLEnv
OpenSpielWrapper
OpenSpielEnv
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ offline-data = [
]
marl = ["vmas>=1.2.10", "pettingzoo>=1.24.1", "dm-meltingpot; python_version>='3.11'"]
open_spiel = ["open_spiel>=1.5"]
openenv = ["openenv-core"]
brax = ["jax>=0.7.0; python_version>='3.11'", "brax; python_version>='3.11'"]
procgen = ["procgen"]
# Base LLM dependencies (no inference backend - use llm-vllm or llm-sglang)
Expand Down
43 changes: 43 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@
assert_allclose_td,
is_tensor_collection,
LazyStackedTensorDict,
NonTensorData,
TensorDict,
TensorDictBase,
)
from tensordict.nn import (
ProbabilisticTensorDictModule,
Expand Down Expand Up @@ -110,6 +112,7 @@
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper
from torchrl.envs.libs.openenv import _has_openenv, OpenEnvEnv
from torchrl.envs.libs.openml import OpenMLEnv
from torchrl.envs.libs.openspiel import _has_pyspiel, OpenSpielEnv, OpenSpielWrapper
from torchrl.envs.libs.pettingzoo import _has_pettingzoo, PettingZooEnv
Expand Down Expand Up @@ -4703,6 +4706,46 @@ def fn(data: TensorDict):
assert sample["pixels"].shape == torch.Size([32, 3, 64, 64])


@pytest.mark.skipif(not _has_openenv, reason="openenv not found")
class TestOpenEnv:
    """Smoke tests for ``OpenEnvEnv`` using the ``openenv/echo-env`` environment.

    Every test closes its environment in a ``finally`` clause so that a failed
    assertion or a failing ``reset``/``step`` call does not leave it open.
    """

    @staticmethod
    def _unwrap_observation(value):
        """Return ``value`` as a plain Python object.

        Tensordicts are converted with ``to_dict()``, ``NonTensorData`` is
        replaced by its ``data`` payload, and anything else is returned
        unchanged.
        """
        if isinstance(value, TensorDictBase):
            return value.to_dict()
        if isinstance(value, NonTensorData):
            return value.data
        return value

    def test_wrapper_basic(self):
        """Reset then step once; check the reward size and the done dtype."""
        env = OpenEnvEnv("openenv/echo-env", return_observation_dict=True)
        try:
            td = env.reset()
            assert "observation" in td.keys()
            action_td = TensorDict({"action": {"message": "ping"}}, batch_size=(1,))
            td_next = env.step(action_td)
            assert td_next["next", "reward"].numel() == 1
            assert td_next["next", "done"].dtype == torch.bool
        finally:
            env.close()

    def test_wrapper_action_cls(self):
        """Step with a message and check it is echoed back when exposed."""
        env = OpenEnvEnv("openenv/echo-env", return_observation_dict=True)
        try:
            env.reset()
            action_td = TensorDict({"action": {"message": "hello"}}, batch_size=(1,))
            td_next = env.step(action_td)
            obs = self._unwrap_observation(td_next["next", "observation"])
            assert obs is not None
            # The echoed message is only checked when the observation is a
            # dict that carries it.
            if isinstance(obs, dict) and "echoed_message" in obs:
                assert obs["echoed_message"] == "hello"
        finally:
            env.close()

    def test_wrapper_observation_dict(self):
        """Check the reset observation; a dict must contain ``echoed_message``."""
        env = OpenEnvEnv("openenv/echo-env", return_observation_dict=True)
        try:
            td = env.reset()
            obs = self._unwrap_observation(td["observation"])
            assert obs is not None
            if isinstance(obs, dict):
                assert "echoed_message" in obs
        finally:
            env.close()

@pytest.mark.skipif(not _has_sklearn, reason="Scikit-learn not found")
@pytest.mark.parametrize(
"dataset",
Expand Down
2 changes: 1 addition & 1 deletion torchrl/data/llm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from .history import add_chat_template, ContentBase, History
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this reorder?

from .dataset import (
create_infinite_iterator,
get_dataloader,
TensorDictTokenizer,
TokenizedDatasetLoader,
)
from .history import add_chat_template, ContentBase, History
from .prompt import PromptData, PromptTensorDictTokenizer
from .reward import PairwiseDataset, RewardData
from .topk import TopKRewardSelector
Expand Down
13 changes: 13 additions & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,8 @@
"ObservationNorm",
"ObservationTransform",
"OpenMLEnv",
"OpenEnvEnv",
"OpenEnvWrapper",
"OpenSpielEnv",
"OpenSpielWrapper",
"ParallelEnv",
Expand Down Expand Up @@ -270,3 +272,14 @@
"step_mdp",
"terminated_or_truncated",
]


def __getattr__(name):
    """Resolve ``OpenEnvEnv`` / ``OpenEnvWrapper`` lazily on attribute access.

    The first request for either name imports both classes from
    ``torchrl.envs.libs.openenv`` and stores them in this module's globals,
    so later lookups no longer reach this hook. Any other name raises
    ``AttributeError``.
    """
    if name not in ("OpenEnvEnv", "OpenEnvWrapper"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    from torchrl.envs.libs.openenv import OpenEnvEnv, OpenEnvWrapper

    resolved = {"OpenEnvEnv": OpenEnvEnv, "OpenEnvWrapper": OpenEnvWrapper}
    globals().update(resolved)
    return resolved[name]
16 changes: 16 additions & 0 deletions torchrl/envs/libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@
"MOGymWrapper",
"MeltingpotEnv",
"MeltingpotWrapper",
"OpenEnvEnv",
"OpenEnvWrapper",
"MultiThreadedEnv",
"MultiThreadedEnvWrapper",
"OpenMLEnv",
Expand All @@ -66,3 +68,17 @@
"register_gym_spec_conversion",
"set_gym_backend",
]


def __getattr__(name):
    """Resolve ``OpenEnvEnv`` / ``OpenEnvWrapper`` lazily on attribute access.

    Both classes are imported from ``.openenv`` on the first request for
    either name and stored in this module's globals, so later lookups no
    longer reach this hook. Any other name raises ``AttributeError``.
    """
    if name not in ("OpenEnvEnv", "OpenEnvWrapper"):
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Lazy import for OpenEnv to avoid circular imports: openenv imports from
    # torchrl.envs.llm which eventually imports from torchrl.data, which may
    # not be fully initialised yet.
    from .openenv import OpenEnvEnv, OpenEnvWrapper

    resolved = {"OpenEnvEnv": OpenEnvEnv, "OpenEnvWrapper": OpenEnvWrapper}
    globals().update(resolved)
    return resolved[name]
Loading