From 50f9bf4143302cfdfa184692ddb58e15951b6f0f Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Mon, 16 Feb 2026 10:33:16 +0000 Subject: [PATCH] Update [ghstack-poisoned] --- benchmarks/bench_collectors.py | 15 +- .../collectors/async_batched_collector.py | 11 +- test/test_inference_server.py | 12 +- torchrl/collectors/_async_batched.py | 114 +++++++++-- torchrl/envs/async_envs.py | 20 +- torchrl/modules/inference_server/__init__.py | 2 + torchrl/modules/inference_server/_server.py | 22 +++ torchrl/modules/inference_server/_slot.py | 182 ++++++++++++++++++ 8 files changed, 335 insertions(+), 43 deletions(-) create mode 100644 torchrl/modules/inference_server/_slot.py diff --git a/benchmarks/bench_collectors.py b/benchmarks/bench_collectors.py index 26ca0384192..181a14d2fbf 100644 --- a/benchmarks/bench_collectors.py +++ b/benchmarks/bench_collectors.py @@ -12,7 +12,8 @@ 2. Collector (ParallelEnv x N) -- single-process, N envs in sub-procs 3. MultiCollector (sync, x N) -- N sub-processes, sync delivery 4. MultiCollector (async, x N) -- N sub-processes, async delivery - 5. AsyncBatchedCollector (threading) -- AsyncEnvPool + InferenceServer + 5. AsyncBatched (env=thread, pol=thread) -- threading pool + threading transport + 6. AsyncBatched (env=mp, pol=thread) -- multiprocessing pool + threading transport """ from __future__ import annotations @@ -368,33 +369,33 @@ def policy_factory(): ) ) - # 5. AsyncBatchedCollector (threading backend) + # 5. AsyncBatchedCollector (env=threading, policy=threading) results.append( bench( - f"AsyncBatchedCollector threading (x{num_envs})", + f"AsyncBatched env=thread pol=thread (x{num_envs})", lambda: AsyncBatchedCollector( create_env_fn=[make_env_fn] * num_envs, policy=policy_factory(), frames_per_batch=frames_per_batch, total_frames=-1, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ), target_frames=total_frames, ) ) - # 6. AsyncBatchedCollector (multiprocessing backend) + # 6. AsyncBatchedCollector (env=multiprocessing, policy=threading) results.append( bench( - f"AsyncBatchedCollector mp (x{num_envs})", + f"AsyncBatched env=mp pol=thread (x{num_envs})", lambda: AsyncBatchedCollector( create_env_fn=[make_env_fn] * num_envs, policy=policy_factory(), frames_per_batch=frames_per_batch, total_frames=-1, max_batch_size=num_envs, - backend="multiprocessing", + env_backend="multiprocessing", ), target_frames=total_frames, ) diff --git a/examples/collectors/async_batched_collector.py b/examples/collectors/async_batched_collector.py index 981ce9cc05f..0d1dd254c31 100644 --- a/examples/collectors/async_batched_collector.py +++ b/examples/collectors/async_batched_collector.py @@ -6,14 +6,21 @@ Architecture: - An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel - using the chosen backend (``"threading"`` or ``"multiprocessing"``). + using the chosen ``env_backend`` (``"threading"`` or ``"multiprocessing"``). - One lightweight coordinator thread per environment owns a slot in the pool and an inference client. - An :class:`~torchrl.modules.InferenceServer` batches incoming observations - and runs a single forward pass. + and runs a single forward pass. The communication layer (transport) is + controlled by ``policy_backend`` (``"threading"``, ``"multiprocessing"``, + ``"ray"``, or ``"monarch"``). - There is no global synchronisation barrier -- fast envs keep stepping while slow ones wait for inference. +Backend parameters: + - ``backend`` -- global default for both env pool and policy transport. 
+ - ``env_backend`` -- override for the env pool (falls back to ``backend``). + - ``policy_backend`` -- override for the transport (falls back to ``backend``). + The user only supplies: - A list of environment factories - A policy (or policy factory) diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 9ec30f7a27e..4e11b5793bc 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -733,7 +733,7 @@ def test_basic_collection(self): frames_per_batch=frames_per_batch, total_frames=total_frames, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) total_collected = 0 for batch in collector: @@ -751,7 +751,7 @@ def test_policy_factory(self): frames_per_batch=10, total_frames=20, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) total_collected = 0 for batch in collector: @@ -791,7 +791,7 @@ def test_yield_completed_trajectories(self): total_frames=30, yield_completed_trajectories=True, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) count = 0 for batch in collector: @@ -809,7 +809,7 @@ def test_shutdown_idempotent(self): policy=policy, frames_per_batch=10, total_frames=10, - backend="threading", + env_backend="threading", ) # Consume one batch to start for _batch in collector: @@ -825,7 +825,7 @@ def test_endless_collector(self): policy=policy, frames_per_batch=10, total_frames=-1, - backend="threading", + env_backend="threading", ) collected = 0 for batch in collector: @@ -862,7 +862,7 @@ def postproc(td): frames_per_batch=10, total_frames=20, postproc=postproc, - backend="threading", + env_backend="threading", ) for _ in collector: pass diff --git a/torchrl/collectors/_async_batched.py b/torchrl/collectors/_async_batched.py index 0e0791ae17c..5c3bca481b5 100644 --- a/torchrl/collectors/_async_batched.py +++ b/torchrl/collectors/_async_batched.py @@ -21,6 +21,45 @@ _ENV_IDX_KEY = "env_index" +_POLICY_BACKENDS = ("threading", "multiprocessing", "ray", "monarch") +_ENV_BACKENDS = ("threading", "multiprocessing") + + +def _make_transport( + policy_backend: str, num_slots: int | None = None +) -> InferenceTransport: + """Create an :class:`InferenceTransport` from a backend name. + + Args: + policy_backend: one of ``"threading"``, ``"multiprocessing"``, + ``"ray"``, or ``"monarch"``. + num_slots: when set and ``policy_backend="threading"``, a + :class:`~torchrl.modules.SlotTransport` is created instead of + the generic :class:`~torchrl.modules.ThreadingTransport`. + """ + if policy_backend == "threading": + if num_slots is not None: + from torchrl.modules.inference_server._slot import SlotTransport + + return SlotTransport(num_slots) + return ThreadingTransport() + if policy_backend == "multiprocessing": + from torchrl.modules.inference_server._mp import MPTransport + + return MPTransport() + if policy_backend == "ray": + from torchrl.modules.inference_server._ray import RayTransport + + return RayTransport() + if policy_backend == "monarch": + from torchrl.modules.inference_server._monarch import MonarchTransport + + return MonarchTransport() + raise ValueError( + f"Unknown policy_backend {policy_backend!r}. " + f"Expected one of {_POLICY_BACKENDS}." 
+ ) + def _env_loop( pool: AsyncEnvPool, @@ -47,9 +86,7 @@ def _env_loop( while not shutdown_event.is_set(): pool.async_step_and_maybe_reset_send(action_td, env_index=env_id) - cur_td, next_obs = pool.async_step_and_maybe_reset_recv( - env_index=env_id - ) + cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id) cur_td.set(_ENV_IDX_KEY, env_id) result_queue.put(cur_td) if shutdown_event.is_set(): @@ -104,22 +141,35 @@ class AsyncBatchedCollector(BaseCollector): max_batch_size (int, optional): upper bound on the number of requests the inference server processes in a single forward pass. Defaults to ``64``. + min_batch_size (int, optional): minimum number of requests the + inference server accumulates before dispatching a batch. After + the first request arrives the server keeps draining for up to + ``server_timeout`` seconds until this many items are collected. + ``1`` (default) dispatches immediately. server_timeout (float, optional): seconds the server waits for work before dispatching a partial batch. Defaults to ``0.01``. transport (InferenceTransport, optional): a pre-built transport - backend. When ``None`` (default) a - :class:`~torchrl.modules.ThreadingTransport` is created - automatically (since worker threads always live in the main - process). Pass a :class:`~torchrl.modules.RayTransport` or - :class:`~torchrl.modules.MonarchTransport` for distributed - setups where the inference server is remote. + object. When provided, it takes precedence over + ``policy_backend``. When ``None`` (default) a transport is + created automatically from the resolved ``policy_backend``. device (torch.device or str, optional): device for policy inference. Passed to the inference server. Defaults to ``None``. - backend (str, optional): backend for the + backend (str, optional): global default backend for both + environments and policy inference. Specific overrides + ``env_backend`` and ``policy_backend`` take precedence when set. + One of ``"threading"``, ``"multiprocessing"``, ``"ray"``, or + ``"monarch"``. Defaults to ``"threading"``. + env_backend (str, optional): backend for the :class:`~torchrl.envs.AsyncEnvPool` that runs environments. One - of ``"threading"`` or ``"multiprocessing"``. The coordinator - threads are always Python threads regardless of this setting. - Defaults to ``"threading"``. + of ``"threading"`` or ``"multiprocessing"``. Falls back to + ``backend`` when ``None``. The coordinator threads are always + Python threads regardless of this setting. Defaults to ``None``. + policy_backend (str, optional): backend for the inference transport + used to communicate with the + :class:`~torchrl.modules.InferenceServer`. One of + ``"threading"``, ``"multiprocessing"``, ``"ray"``, or + ``"monarch"``. Falls back to ``backend`` when ``None``. + Defaults to ``None``. reset_at_each_iter (bool, optional): whether to reset all envs at the start of every collection batch. Defaults to ``False``. 
postproc (Callable, optional): post-processing transform applied to @@ -169,10 +219,16 @@ def __init__( frames_per_batch: int, total_frames: int = -1, max_batch_size: int = 64, + min_batch_size: int = 1, server_timeout: float = 0.01, transport: InferenceTransport | None = None, device: torch.device | str | None = None, - backend: Literal["threading", "multiprocessing"] = "threading", + backend: Literal[ + "threading", "multiprocessing", "ray", "monarch" + ] = "threading", + env_backend: Literal["threading", "multiprocessing"] | None = None, + policy_backend: Literal["threading", "multiprocessing", "ray", "monarch"] + | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, yield_completed_trajectories: bool = False, @@ -196,12 +252,26 @@ def __init__( raise TypeError("create_env_fn must be a list of env factories.") self._create_env_fn = list(create_env_fn) self._num_envs = len(create_env_fn) - self._backend = backend self._create_env_kwargs = create_env_kwargs + # ---- resolve backends ------------------------------------------------- + effective_env_backend = env_backend if env_backend is not None else backend + effective_policy_backend = ( + policy_backend if policy_backend is not None else backend + ) + if effective_env_backend not in _ENV_BACKENDS: + raise ValueError( + f"env_backend={effective_env_backend!r} is not supported. " + f"Expected one of {_ENV_BACKENDS}." + ) + self._env_backend = effective_env_backend + self._policy_backend = effective_policy_backend + # ---- build transport -------------------------------------------------- if transport is None: - transport = ThreadingTransport() + transport = _make_transport( + effective_policy_backend, num_slots=self._num_envs + ) self._transport = transport # ---- build inference server ------------------------------------------- @@ -209,6 +279,7 @@ def __init__( model=policy, transport=transport, max_batch_size=max_batch_size, + min_batch_size=min_batch_size, timeout=server_timeout, device=device, weight_sync=weight_sync, @@ -252,7 +323,7 @@ def _ensure_started(self) -> None: kwargs["create_env_kwargs"] = self._create_env_kwargs self._env_pool = AsyncEnvPool( self._create_env_fn, - backend=self._backend, + backend=self._env_backend, **kwargs, ) @@ -303,9 +374,18 @@ def _rollout_frames(self) -> TensorDictBase: transitions: list[TensorDictBase] = [] while collected < self.frames_per_batch: + # Block for at least one transition td = rq.get() transitions.append(td) collected += td.numel() + # Batch-drain any additional items already in the queue + while collected < self.frames_per_batch: + try: + td = rq.get_nowait() + except queue.Empty: + break + transitions.append(td) + collected += td.numel() if self.verbose: torchrl_logger.debug( f"AsyncBatchedCollector: {collected}/{self.frames_per_batch} frames" diff --git a/torchrl/envs/async_envs.py b/torchrl/envs/async_envs.py index 19cd53dbb48..8103cc53fa8 100644 --- a/torchrl/envs/async_envs.py +++ b/torchrl/envs/async_envs.py @@ -673,9 +673,7 @@ def async_step_and_maybe_reset_send( for _env_idx, local_td in _zip_strict(env_idx, local_tds): if not _per_env: self._current_step_reset = self._current_step_reset + 1 - self.input_queue[_env_idx].put( - ("step_and_maybe_reset", local_td, _per_env) - ) + self.input_queue[_env_idx].put(("step_and_maybe_reset", local_td, _per_env)) def async_step_and_maybe_reset_recv( self, min_get: int = 1, env_index: int | None = None @@ -807,29 +805,29 @@ def _env_exec( elif msg == "batch_size": 
output_queue.put(env.batch_size) elif msg == "reset": - data = env.reset(data.copy()) + # No .copy() needed: data was deserialized from the queue + # and is not referenced after this call. + data = env.reset(data) data.set(cls._env_idx_key, NonTensorData(i)) target = per_env_reset_queue if per_env else reset_queue target.put(data) elif msg == "_reset": - data = env._reset(data.copy()) + data = env._reset(data) data.set(cls._env_idx_key, NonTensorData(i)) reset_queue.put(data) elif msg == "step_and_maybe_reset": - data, data_ = env.step_and_maybe_reset(data.copy()) + data, data_ = env.step_and_maybe_reset(data) data.set(cls._env_idx_key, NonTensorData(i)) data_.set(cls._env_idx_key, NonTensorData(i)) - target = ( - per_env_step_reset_queue if per_env else step_reset_queue - ) + target = per_env_step_reset_queue if per_env else step_reset_queue target.put((data, data_)) elif msg == "step": - data = env.step(data.copy()) + data = env.step(data) data.set(cls._env_idx_key, NonTensorData(i)) target = per_env_step_queue if per_env else step_queue target.put(data) elif msg == "_step": - data = env._step(data.copy()) + data = env._step(data) data.set(cls._env_idx_key, NonTensorData(i)) step_queue.put(data) elif msg == "shutdown": diff --git a/torchrl/modules/inference_server/__init__.py b/torchrl/modules/inference_server/__init__.py index e68f98626fd..388ba33028b 100644 --- a/torchrl/modules/inference_server/__init__.py +++ b/torchrl/modules/inference_server/__init__.py @@ -7,6 +7,7 @@ from torchrl.modules.inference_server._mp import MPTransport from torchrl.modules.inference_server._ray import RayTransport from torchrl.modules.inference_server._server import InferenceClient, InferenceServer +from torchrl.modules.inference_server._slot import SlotTransport from torchrl.modules.inference_server._threading import ThreadingTransport from torchrl.modules.inference_server._transport import InferenceTransport @@ -17,5 +18,6 @@ "MonarchTransport", "MPTransport", "RayTransport", + "SlotTransport", "ThreadingTransport", ] diff --git a/torchrl/modules/inference_server/_server.py b/torchrl/modules/inference_server/_server.py index 42ed04fcbdc..b4cbf09848b 100644 --- a/torchrl/modules/inference_server/_server.py +++ b/torchrl/modules/inference_server/_server.py @@ -5,6 +5,7 @@ from __future__ import annotations import threading +import time from collections.abc import Callable from concurrent.futures import Future @@ -32,6 +33,11 @@ class InferenceServer: Keyword Args: max_batch_size (int, optional): upper bound on the number of requests processed in a single forward pass. Default: ``64``. + min_batch_size (int, optional): minimum number of requests to + accumulate before dispatching a batch. After the first request + arrives the server keeps draining for up to ``timeout`` seconds + until at least this many items are collected. ``1`` (default) + dispatches immediately. timeout (float, optional): seconds to wait for new work before dispatching a partial batch. Default: ``0.01``. 
collate_fn (Callable, optional): function used to stack a list of @@ -70,6 +76,7 @@ def __init__( transport: InferenceTransport, *, max_batch_size: int = 64, + min_batch_size: int = 1, timeout: float = 0.01, collate_fn: Callable | None = None, device: torch.device | str | None = None, @@ -79,6 +86,7 @@ def __init__( self.model = model self.transport = transport self.max_batch_size = max_batch_size + self.min_batch_size = min_batch_size self.timeout = timeout self.collate_fn = collate_fn if collate_fn is not None else lazy_stack self.device = torch.device(device) if device is not None else None @@ -162,6 +170,20 @@ def _run(self) -> None: if not items: continue + # Accumulate up to min_batch_size (or until timeout expires) + if len(items) < self.min_batch_size: + deadline = time.monotonic() + self.timeout + while len(items) < self.min_batch_size: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + self.transport.wait_for_work(timeout=remaining) + more_items, more_cbs = self.transport.drain( + self.max_batch_size - len(items) + ) + items.extend(more_items) + callbacks.extend(more_cbs) + batch = self.collate_fn(items) if self.device is not None: batch = batch.to(self.device) diff --git a/torchrl/modules/inference_server/_slot.py b/torchrl/modules/inference_server/_slot.py new file mode 100644 index 00000000000..eaaa7313b3f --- /dev/null +++ b/torchrl/modules/inference_server/_slot.py @@ -0,0 +1,182 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import threading + +from tensordict.base import TensorDictBase + +from torchrl.modules.inference_server._transport import InferenceTransport + + +class _SlotClient: + """Actor-side handle for a :class:`SlotTransport` slot. + + Each client owns a single slot (identified by ``slot_id``). Calling the + client writes the observation into the slot and blocks until the server + writes the action back. + + Args: + transport: the parent :class:`SlotTransport`. + slot_id: the slot this client owns. + """ + + def __init__(self, transport: SlotTransport, slot_id: int): + self._transport = transport + self._slot_id = slot_id + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + """Submit an observation and block until the action is ready.""" + self._transport._slot_submit(self._slot_id, td) + return self._transport._slot_recv(self._slot_id) + + +class SlotTransport(InferenceTransport): + """Lock-free, in-process transport using per-env slots. + + Each actor thread owns a dedicated *slot*. Submitting an observation + writes to the slot without any lock (each slot is accessed by exactly + one writer thread). The server sweeps slots to find ready ones, collects + observations, runs the model, and writes actions back via per-slot events. + + This eliminates: + + * The shared ``threading.Lock`` that ``ThreadingTransport`` uses for + every ``submit()`` and ``drain()``. + * ``concurrent.futures.Future`` allocations (one per inference request). + + The trade-off is that the number of slots is fixed at construction time + (equal to the number of environments). + + Args: + num_slots (int): number of slots (one per environment / actor thread). + + Keyword Args: + preallocate (bool, optional): if ``True``, a contiguous observation + buffer of shape ``[num_slots, ...]`` is allocated on the first + submit. Subsequent submits copy into the buffer in-place + (``update_``). 
Defaults to ``False`` because the extra copy + into the buffer is not currently compensated by the batching + path (``lazy_stack`` still calls ``torch.stack``). + + .. note:: + This transport is only suitable for in-process threading scenarios + (the default for :class:`~torchrl.collectors.AsyncBatchedCollector` + with ``policy_backend="threading"``). + """ + + def __init__(self, num_slots: int, *, preallocate: bool = False): + self._num_slots = num_slots + self._preallocate = preallocate + self._next_slot = 0 + + # Per-slot observation storage (written by env thread, read by server) + self._obs: list[TensorDictBase | None] = [None] * num_slots + + # Per-slot readiness flag (True = observation ready for server) + # Under CPython's GIL, bool assignment is atomic. + self._obs_ready: list[bool] = [False] * num_slots + + # Per-slot action storage (written by server, read by env thread) + self._actions: list[TensorDictBase | BaseException | None] = [None] * num_slots + + # Per-slot events: server sets after writing the action + self._action_events: list[threading.Event] = [ + threading.Event() for _ in range(num_slots) + ] + + # Global event: any env thread sets to wake the server + self._has_work = threading.Event() + + # Pre-allocated observation buffer (lazily initialised) + self._obs_buffer: TensorDictBase | None = None + + # -- actor (env-thread) API ----------------------------------------------- + + def _slot_submit(self, slot_id: int, td: TensorDictBase) -> None: + """Write observation into the slot (no lock required).""" + if self._obs_buffer is not None: + # Copy into pre-allocated buffer (no new allocation) + self._obs_buffer[slot_id].update_(td) + else: + self._obs[slot_id] = td + self._obs_ready[slot_id] = True + self._has_work.set() + + def _slot_recv(self, slot_id: int) -> TensorDictBase: + """Block until the server writes an action into the slot.""" + self._action_events[slot_id].wait() + self._action_events[slot_id].clear() + result = self._actions[slot_id] + self._actions[slot_id] = None + if isinstance(result, BaseException): + raise result + return result + + # -- InferenceTransport interface ----------------------------------------- + + def client(self) -> _SlotClient: + """Create a slot-bound client for one actor thread.""" + slot_id = self._next_slot + self._next_slot += 1 + return _SlotClient(self, slot_id) + + def submit(self, td: TensorDictBase): + """Not supported -- use :meth:`client` to get a slot-bound callable.""" + raise NotImplementedError( + "SlotTransport does not support submit(). " + "Use client() to obtain a slot-bound callable." + ) + + def wait_for_work(self, timeout: float) -> None: + """Block until at least one slot has a ready observation.""" + self._has_work.wait(timeout=timeout) + self._has_work.clear() + + def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[int]]: + """Sweep slots and return (observations, slot_ids) for ready ones.""" + # Lazily initialise the pre-allocated buffer on the first drain + # that finds ready observations. 
+ if self._preallocate and self._obs_buffer is None: + for i in range(self._num_slots): + if self._obs_ready[i] and self._obs[i] is not None: + self._obs_buffer = ( + self._obs[i] + .unsqueeze(0) + .expand(self._num_slots) + .clone() + .contiguous() + ) + break + + items: list[TensorDictBase] = [] + slot_ids: list[int] = [] + for i in range(self._num_slots): + if self._obs_ready[i]: + self._obs_ready[i] = False + if self._obs_buffer is not None: + # Flush first-time observations that arrived before the + # buffer existed into the buffer. + if self._obs[i] is not None: + self._obs_buffer[i].update_(self._obs[i]) + self._obs[i] = None + items.append(self._obs_buffer[i]) + else: + items.append(self._obs[i]) + self._obs[i] = None + slot_ids.append(i) + if len(slot_ids) >= max_items: + break + return items, slot_ids + + def resolve(self, callback: int, result: TensorDictBase) -> None: + """Write the action into the slot and wake the waiting env thread.""" + self._actions[callback] = result + self._action_events[callback].set() + + def resolve_exception(self, callback: int, exc: BaseException) -> None: + """Propagate an exception to the waiting env thread.""" + self._actions[callback] = exc + self._action_events[callback].set()
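
Reviewer note -- usage sketch (not part of the patch). The snippet below shows
how the new ``env_backend`` / ``policy_backend`` split, ``min_batch_size`` and
the automatic :class:`SlotTransport` selection introduced above are meant to be
used together. The import paths, the ``Pendulum-v1`` environment and the toy
policy are assumptions made purely for illustration; only the keyword arguments
come from this patch.

    import torch
    from tensordict.nn import TensorDictModule

    from torchrl.collectors import AsyncBatchedCollector  # assumed public import path
    from torchrl.envs import GymEnv


    def make_env():
        # Any single-env factory works; Pendulum-v1 keeps the example small.
        return GymEnv("Pendulum-v1")


    def make_policy():
        # Toy deterministic policy: observation (3,) -> action (1,).
        net = torch.nn.Sequential(
            torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 1)
        )
        return TensorDictModule(net, in_keys=["observation"], out_keys=["action"])


    collector = AsyncBatchedCollector(
        create_env_fn=[make_env] * 4,
        policy=make_policy(),
        frames_per_batch=64,
        total_frames=256,
        max_batch_size=4,            # cap on the server's forward-pass batch
        min_batch_size=2,            # wait up to server_timeout for >= 2 requests
        server_timeout=0.01,
        env_backend="threading",     # AsyncEnvPool backend
        policy_backend="threading",  # threading transport -> SlotTransport(num_envs)
    )
    for batch in collector:
        print(batch.numel(), "frames collected")
    collector.shutdown()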