Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1f6f327
Fixed MultiSyncCollector set_seed and split_trajs issue
ParamThakkar123 Jan 19, 2026
e2aaf6b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 20, 2026
40642d5
Revert "Fixed MultiSyncCollector set_seed and split_trajs issue"
ParamThakkar123 Jan 20, 2026
efdc89c
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 21, 2026
628f44b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 23, 2026
a476a77
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 24, 2026
0f565c5
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 25, 2026
7fb086b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 26, 2026
ff72793
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 28, 2026
69001ed
Added Support for index_select in TensorSpec
ParamThakkar123 Jan 28, 2026
4ab13be
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 29, 2026
2e8face
rebase
ParamThakkar123 Jan 29, 2026
56e1529
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Jan 31, 2026
ba6a19f
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 4, 2026
8be545b
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 5, 2026
54abe29
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 8, 2026
78dd00a
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 12, 2026
94fe080
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 13, 2026
1619008
Merge branch 'main' of https://github.com/pytorch/rl
ParamThakkar123 Feb 18, 2026
1747204
Add num_workers support to isaaclab environments
ParamThakkar123 Feb 18, 2026
bbe6a38
Fixes
ParamThakkar123 Feb 18, 2026
e03d16e
Fixes
ParamThakkar123 Feb 18, 2026
84aebd4
Fixes
ParamThakkar123 Feb 18, 2026
64cafbb
Fixes
ParamThakkar123 Feb 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/reference/envs_libraries.rst
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Available wrappers
HabitatEnv
IsaacGymEnv
IsaacGymWrapper
IsaacLabEnv
IsaacLabWrapper
JumanjiEnv
JumanjiWrapper
Expand Down
22 changes: 21 additions & 1 deletion docs/source/reference/isaaclab.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@ This guide covers how to use TorchRL components with
For general IsaacLab installation and cluster setup (not specific to TorchRL), see the
`knowledge_base/ISAACLAB.md <https://github.com/pytorch/rl/blob/main/knowledge_base/ISAACLAB.md>`_ file.

IsaacLabEnv
-----------

Use :class:`~torchrl.envs.libs.isaac_lab.IsaacLabEnv` to build IsaacLab
environments directly from their gymnasium ID:

.. code-block:: python

from torchrl.envs.libs.isaac_lab import IsaacLabEnv

env = IsaacLabEnv("Isaac-Ant-v0", cfg=env_cfg)

``IsaacLabEnv`` supports a ``num_workers`` keyword argument, following the same
lazy-construction behavior as the other TorchRL environment libraries:

.. code-block:: python

env = IsaacLabEnv("Isaac-Ant-v0", cfg=env_cfg, num_workers=2)
# env is a lazy ParallelEnv until first reset/step/spec query

IsaacLabWrapper
---------------

Expand Down Expand Up @@ -52,7 +72,7 @@ Collector
---------

Because IsaacLab environments are **pre-vectorized** (a single ``gym.make``
creates ~4096 parallel environments on the GPU), use a single
creates ~4096 parallel environments on the GPU), most workloads can use a single
:class:`~torchrl.collectors.Collector` — there is no need for
``ParallelEnv`` or ``MultiCollector``:

Expand Down
12 changes: 12 additions & 0 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
set_gym_backend,
)
from torchrl.envs.libs.habitat import _has_habitat, HabitatEnv
from torchrl.envs.libs.isaac_lab import IsaacLabEnv
from torchrl.envs.libs.jumanji import _has_jumanji, JumanjiEnv
from torchrl.envs.libs.meltingpot import MeltingpotEnv, MeltingpotWrapper
from torchrl.envs.libs.openml import OpenMLEnv
Expand Down Expand Up @@ -5568,6 +5569,17 @@ def test_render(self, rollout_steps):

@pytest.mark.skipif(not _has_isaaclab, reason="Isaaclab not found")
class TestIsaacLab:
def test_num_workers_returns_lazy_parallel_env(self):
    """Requesting num_workers=2 must return a lazy, still-configurable ParallelEnv."""
    env = IsaacLabEnv("Isaac-Ant-v0", num_workers=2)
    try:
        # The metaclass dispatches to ParallelEnv instead of a plain env.
        assert isinstance(env, ParallelEnv)
        assert env.num_workers == 2
        # Lazy: workers must not have been started at construction time.
        assert env.is_closed
        # A lazy ParallelEnv can still be re-configured before first use.
        env.configure_parallel(use_buffers=False)
        assert env._use_buffers is False
    finally:
        env.close()

@pytest.fixture(scope="class")
def env(self):
env = torchrl.testing.env_helper.make_isaac_env()
Expand Down
1 change: 0 additions & 1 deletion torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6828,7 +6828,6 @@ def _stack_specs(list_of_spec, dim=0, out=None):
else:
raise NotImplementedError


@TensorSpec.implements_for_spec(torch.index_select)
@Composite.implements_for_spec(torch.index_select)
def _index_select_spec(input: TensorSpec, dim: int, index: torch.Tensor) -> TensorSpec:
Expand Down
2 changes: 2 additions & 0 deletions torchrl/envs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
HabitatEnv,
IsaacGymEnv,
IsaacGymWrapper,
IsaacLabEnv,
IsaacLabWrapper,
JumanjiEnv,
JumanjiWrapper,
Expand Down Expand Up @@ -139,6 +140,7 @@
"ActionDiscretizer",
"ActionMask",
"VecNormV2",
"IsaacLabEnv",
"IsaacLabWrapper",
"AutoResetEnv",
"AutoResetTransform",
Expand Down
3 changes: 2 additions & 1 deletion torchrl/envs/libs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
set_gym_backend,
)
from .habitat import HabitatEnv
from .isaac_lab import IsaacLabWrapper
from .isaac_lab import IsaacLabEnv, IsaacLabWrapper
from .isaacgym import IsaacGymEnv, IsaacGymWrapper
from .jumanji import JumanjiEnv, JumanjiWrapper
from .meltingpot import MeltingpotEnv, MeltingpotWrapper
Expand All @@ -39,6 +39,7 @@
"HabitatEnv",
"IsaacGymEnv",
"IsaacGymWrapper",
"IsaacLabEnv",
"IsaacLabWrapper",
"JumanjiEnv",
"JumanjiWrapper",
Expand Down
135 changes: 115 additions & 20 deletions torchrl/envs/libs/isaac_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,85 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import functools
import importlib.util

import torch
from torchrl.envs.libs.gym import GymWrapper
from torchrl.envs.libs.gym import _GymAsyncMeta, GymEnv, GymWrapper
from torchrl.envs.utils import _classproperty

_has_isaaclab = importlib.util.find_spec("isaaclab") is not None


def _raise_isaaclab_import_error():
raise ImportError(
"IsaacLab could not be loaded. Consider installing it and importing/launching "
"IsaacLab before creating an environment. Refer to TorchRL's knowledge base in "
"the documentation to debug IsaacLab installation."
)


def _wrap_import_error(fun):
    """Decorator that intercepts calls when the ``isaaclab`` package is absent.

    The wrapped callable runs unchanged whenever IsaacLab is importable;
    otherwise a descriptive ImportError is raised before ``fun`` executes.
    """

    @functools.wraps(fun)
    def guarded(*args, **kwargs):
        if _has_isaaclab:
            return fun(*args, **kwargs)
        _raise_isaaclab_import_error()

    return guarded


@_wrap_import_error
def _get_available_envs():
    """Yield the gymnasium IDs of registered IsaacLab environments.

    IsaacLab tasks register under IDs prefixed with ``"Isaac"``; every other
    gymnasium environment is filtered out.
    """
    yield from (
        env_id for env_id in GymEnv.available_envs if env_id.startswith("Isaac")
    )


class _IsaacLabMeta(_GymAsyncMeta):
    """Metaclass for IsaacLabEnv that returns a lazy ParallelEnv when num_workers > 1.

    ``num_workers`` is intercepted here, before ``__init__`` runs, so the
    constructor call can return a :class:`~torchrl.envs.ParallelEnv` instead of
    a plain ``IsaacLabEnv`` instance.
    """

    def __call__(cls, *args, num_workers: int | None = None, **kwargs):
        # ``num_workers`` is keyword-only in this signature, so it can never
        # also appear inside ``kwargs`` — the previous double pop was dead code.
        num_workers = 1 if num_workers is None else int(num_workers)

        # Only the concrete ``IsaacLabEnv`` class dispatches to ParallelEnv.
        # NOTE(review): subclasses with a different __name__ bypass this branch
        # and ignore num_workers — confirm this is intended.
        if getattr(cls, "__name__", None) == "IsaacLabEnv" and num_workers > 1:
            from torchrl.envs import ParallelEnv

            # ``env_name`` may arrive positionally or as a keyword; normalize.
            env_name = args[0] if args else kwargs.get("env_name")
            env_kwargs = {k: v for k, v in kwargs.items() if k != "env_name"}
            # Each worker constructs a single (num_workers=1) copy of the env.
            make_env = functools.partial(cls, env_name, num_workers=1, **env_kwargs)
            return ParallelEnv(num_workers, make_env)

        return super().__call__(*args, **kwargs)


class IsaacLabWrapper(GymWrapper):
class _IsaacLabMixin:
def seed(self, seed: int | None):
self._set_seed(seed)

def _output_transform(self, step_outputs_tuple): # noqa: F811
# IsaacLab will modify the `terminated` and `truncated` tensors in-place.
# We clone them here to make sure data doesn't inadvertently get modified.
# The variable naming follows torchrl's convention here.
observations, reward, terminated, truncated, info = step_outputs_tuple
done = terminated | truncated
reward = reward.unsqueeze(-1) # to get to (num_envs, 1)
return (
observations,
reward,
terminated.clone(),
truncated.clone(),
done.clone(),
info,
)


class IsaacLabWrapper(_IsaacLabMixin, GymWrapper):
"""A wrapper for IsaacLab environments.

Args:
Expand All @@ -19,7 +93,7 @@ class IsaacLabWrapper(GymWrapper):
Defaults to ``False``.
allow_done_after_reset (bool, optional): if ``True``, it is tolerated
for envs to be ``done`` just after :meth:`reset` is called.
Defaults to ``False``.
Defaults to ``True``.

For other arguments, see the :class:`torchrl.envs.GymWrapper` documentation.

Expand Down Expand Up @@ -67,21 +141,42 @@ def __init__(
**kwargs,
)

def seed(self, seed: int | None):
self._set_seed(seed)

def _output_transform(self, step_outputs_tuple): # noqa: F811
# IsaacLab will modify the `terminated` and `truncated` tensors
# in-place. We clone them here to make sure data doesn't inadvertently get modified.
# The variable naming follows torchrl's convention here.
observations, reward, terminated, truncated, info = step_outputs_tuple
done = terminated | truncated
reward = reward.unsqueeze(-1) # to get to (num_envs, 1)
return (
observations,
reward,
terminated.clone(),
truncated.clone(),
done.clone(),
info,
)
class IsaacLabEnv(_IsaacLabMixin, GymEnv, metaclass=_IsaacLabMeta):
    """IsaacLab environment built directly from a gymnasium environment ID.

    Behaves like :class:`~torchrl.envs.GymEnv`, with IsaacLab-specific
    defaults and output processing applied.

    Args:
        env_name (str): environment ID registered in gymnasium.

    Keyword Args:
        num_workers (int, optional): when greater than 1, a lazy
            :class:`torchrl.envs.ParallelEnv` holding ``num_workers`` copies
            of ``IsaacLabEnv`` is returned instead. Defaults to ``1``.
        allow_done_after_reset (bool, optional): defaults to ``True`` for
            IsaacLab compatibility.
        convert_actions_to_numpy (bool, optional): defaults to ``False`` so
            actions stay as tensors.
        device (torch.device, optional): defaults to ``torch.device("cuda:0")``.

    For other keyword arguments, see :class:`~torchrl.envs.GymEnv`.
    """

    @_classproperty
    def available_envs(cls):
        # Without isaaclab installed there is nothing to enumerate.
        return list(_get_available_envs()) if _has_isaaclab else []

    @_wrap_import_error
    def __init__(self, env_name: str, **kwargs):
        # IsaacLab-friendly defaults; caller-supplied values always win.
        for key, default in (
            ("backend", "gymnasium"),
            ("allow_done_after_reset", True),
            ("convert_actions_to_numpy", False),
        ):
            kwargs.setdefault(key, default)
        # An explicit ``device=None`` is treated the same as an absent device.
        if kwargs.get("device") is None:
            kwargs["device"] = torch.device("cuda:0")
        super().__init__(env_name=env_name, **kwargs)
7 changes: 2 additions & 5 deletions torchrl/testing/env_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,10 @@ def make_isaac_env(env_name: str = "Isaac-Ant-v0"):
AppLauncher(args_cli)

# Imports and env
import gymnasium as gym
import isaaclab_tasks # noqa: F401
from isaaclab_tasks.manager_based.classic.ant.ant_env_cfg import AntEnvCfg
from torchrl.envs.libs.isaac_lab import IsaacLabWrapper
from torchrl.envs.libs.isaac_lab import IsaacLabEnv

torchrl_logger.info("Making IsaacLab env...")
env = gym.make(env_name, cfg=AntEnvCfg())
torchrl_logger.info("Wrapping IsaacLab env...")
env = IsaacLabWrapper(env)
env = IsaacLabEnv(env_name, cfg=AntEnvCfg())
return env