diff --git a/benchmarks/bench_collectors.py b/benchmarks/bench_collectors.py index 26ca0384192..181a14d2fbf 100644 --- a/benchmarks/bench_collectors.py +++ b/benchmarks/bench_collectors.py @@ -12,7 +12,8 @@ 2. Collector (ParallelEnv x N) -- single-process, N envs in sub-procs 3. MultiCollector (sync, x N) -- N sub-processes, sync delivery 4. MultiCollector (async, x N) -- N sub-processes, async delivery - 5. AsyncBatchedCollector (threading) -- AsyncEnvPool + InferenceServer + 5. AsyncBatched (env=thread, pol=thread) -- threading pool + threading transport + 6. AsyncBatched (env=mp, pol=thread) -- multiprocessing pool + threading transport """ from __future__ import annotations @@ -368,33 +369,33 @@ def policy_factory(): ) ) - # 5. AsyncBatchedCollector (threading backend) + # 5. AsyncBatchedCollector (env=threading, policy=threading) results.append( bench( - f"AsyncBatchedCollector threading (x{num_envs})", + f"AsyncBatched env=thread pol=thread (x{num_envs})", lambda: AsyncBatchedCollector( create_env_fn=[make_env_fn] * num_envs, policy=policy_factory(), frames_per_batch=frames_per_batch, total_frames=-1, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ), target_frames=total_frames, ) ) - # 6. AsyncBatchedCollector (multiprocessing backend) + # 6. AsyncBatchedCollector (env=multiprocessing, policy=threading) results.append( bench( - f"AsyncBatchedCollector mp (x{num_envs})", + f"AsyncBatched env=mp pol=thread (x{num_envs})", lambda: AsyncBatchedCollector( create_env_fn=[make_env_fn] * num_envs, policy=policy_factory(), frames_per_batch=frames_per_batch, total_frames=-1, max_batch_size=num_envs, - backend="multiprocessing", + env_backend="multiprocessing", ), target_frames=total_frames, ) diff --git a/examples/collectors/async_batched_collector.py b/examples/collectors/async_batched_collector.py index 981ce9cc05f..0d1dd254c31 100644 --- a/examples/collectors/async_batched_collector.py +++ b/examples/collectors/async_batched_collector.py @@ -6,14 +6,21 @@ Architecture: - An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel - using the chosen backend (``"threading"`` or ``"multiprocessing"``). + using the chosen ``env_backend`` (``"threading"`` or ``"multiprocessing"``). - One lightweight coordinator thread per environment owns a slot in the pool and an inference client. - An :class:`~torchrl.modules.InferenceServer` batches incoming observations - and runs a single forward pass. + and runs a single forward pass. The communication layer (transport) is + controlled by ``policy_backend`` (``"threading"``, ``"multiprocessing"``, + ``"ray"``, or ``"monarch"``). - There is no global synchronisation barrier -- fast envs keep stepping while slow ones wait for inference. +Backend parameters: + - ``backend`` -- global default for both env pool and policy transport. + - ``env_backend`` -- override for the env pool (falls back to ``backend``). + - ``policy_backend`` -- override for the transport (falls back to ``backend``). 
+ The user only supplies: - A list of environment factories - A policy (or policy factory) diff --git a/test/test_inference_server.py b/test/test_inference_server.py index 39d620b0c41..cfefa92cc84 100644 --- a/test/test_inference_server.py +++ b/test/test_inference_server.py @@ -22,6 +22,7 @@ InferenceTransport, MPTransport, RayTransport, + SlotTransport, ThreadingTransport, ) from torchrl.modules.inference_server._monarch import MonarchTransport @@ -728,7 +729,7 @@ def test_basic_collection(self): frames_per_batch=frames_per_batch, total_frames=total_frames, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) total_collected = 0 for batch in collector: @@ -746,7 +747,7 @@ def test_policy_factory(self): frames_per_batch=10, total_frames=20, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) total_collected = 0 for batch in collector: @@ -786,7 +787,7 @@ def test_yield_completed_trajectories(self): total_frames=30, yield_completed_trajectories=True, max_batch_size=num_envs, - backend="threading", + env_backend="threading", ) count = 0 for batch in collector: @@ -804,7 +805,7 @@ def test_shutdown_idempotent(self): policy=policy, frames_per_batch=10, total_frames=10, - backend="threading", + env_backend="threading", ) # Consume one batch to start for _batch in collector: @@ -820,7 +821,7 @@ def test_endless_collector(self): policy=policy, frames_per_batch=10, total_frames=-1, - backend="threading", + env_backend="threading", ) collected = 0 for batch in collector: @@ -857,9 +858,126 @@ def postproc(td): frames_per_batch=10, total_frames=20, postproc=postproc, - backend="threading", + env_backend="threading", ) for _ in collector: pass collector.shutdown() assert called["count"] >= 1 + + +# ============================================================================= +# Tests: SlotTransport +# ============================================================================= + + +class TestSlotTransport: + def test_single_request(self): + transport = SlotTransport(num_slots=4) + policy = _make_policy() + with InferenceServer(policy, transport, max_batch_size=4): + client = transport.client() + td = TensorDict({"observation": torch.randn(4)}) + result = client(td) + assert "action" in result.keys() + assert result["action"].shape == (2,) + + def test_concurrent_actors(self): + """Multiple threads submit concurrently via slot clients.""" + n_actors = 4 + n_requests = 30 + transport = SlotTransport(num_slots=n_actors) + policy = _make_policy() + + results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)] + clients = [transport.client() for _ in range(n_actors)] + + def actor_fn(actor_id): + for _ in range(n_requests): + td = TensorDict({"observation": torch.randn(4)}) + result = clients[actor_id](td) + results_per_actor[actor_id].append(result) + + with InferenceServer(policy, transport, max_batch_size=n_actors): + with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool: + futs = [pool.submit(actor_fn, i) for i in range(n_actors)] + concurrent.futures.wait(futs) + for f in futs: + f.result() + + for actor_results in results_per_actor: + assert len(actor_results) == n_requests + for r in actor_results: + assert "action" in r.keys() + assert r["action"].shape == (2,) + + def test_too_many_clients_raises(self): + """Creating more clients than slots raises RuntimeError.""" + transport = SlotTransport(num_slots=2) + transport.client() + transport.client() + with pytest.raises(RuntimeError, match="slots"): + 
transport.client() + + def test_submit_raises(self): + """Direct submit() on SlotTransport is not supported.""" + transport = SlotTransport(num_slots=1) + td = TensorDict({"observation": torch.randn(4)}) + with pytest.raises(NotImplementedError): + transport.submit(td) + + def test_exception_propagates(self): + """Model exceptions propagate through SlotTransport.""" + + def bad_model(td): + raise ValueError("slot model error") + + transport = SlotTransport(num_slots=1) + with InferenceServer(bad_model, transport, max_batch_size=4): + client = transport.client() + td = TensorDict({"observation": torch.randn(4)}) + with pytest.raises(ValueError, match="slot model error"): + client(td) + + +# ============================================================================= +# Tests: min_batch_size +# ============================================================================= + + +class TestMinBatchSize: + def test_min_batch_size_accumulates(self): + """With min_batch_size > 1, the server waits for enough items.""" + min_bs = 4 + seen_sizes = [] + + def tracking_collate(items): + seen_sizes.append(len(items)) + return lazy_stack(items) + + transport = ThreadingTransport() + policy = _make_policy() + n = 8 + + with InferenceServer( + policy, + transport, + max_batch_size=16, + min_batch_size=min_bs, + collate_fn=tracking_collate, + timeout=1.0, + ): + client = transport.client() + # Submit items from threads to give the server time to accumulate + with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool: + futs = [ + pool.submit( + lambda: client(TensorDict({"observation": torch.randn(4)})) + ) + for _ in range(n) + ] + for f in futs: + f.result(timeout=10.0) + + # At least one batch should have >= min_batch_size items + assert any(s >= min_bs for s in seen_sizes) diff --git a/torchrl/collectors/_async_batched.py b/torchrl/collectors/_async_batched.py index 0e0791ae17c..5c3bca481b5 100644 --- a/torchrl/collectors/_async_batched.py +++ b/torchrl/collectors/_async_batched.py @@ -21,6 +21,45 @@ _ENV_IDX_KEY = "env_index" +_POLICY_BACKENDS = ("threading", "multiprocessing", "ray", "monarch") +_ENV_BACKENDS = ("threading", "multiprocessing") + + +def _make_transport( + policy_backend: str, num_slots: int | None = None +) -> InferenceTransport: + """Create an :class:`InferenceTransport` from a backend name. + + Args: + policy_backend: one of ``"threading"``, ``"multiprocessing"``, + ``"ray"``, or ``"monarch"``. + num_slots: when set and ``policy_backend="threading"``, a + :class:`~torchrl.modules.SlotTransport` is created instead of + the generic :class:`~torchrl.modules.ThreadingTransport`. + """ + if policy_backend == "threading": + if num_slots is not None: + from torchrl.modules.inference_server._slot import SlotTransport + + return SlotTransport(num_slots) + return ThreadingTransport() + if policy_backend == "multiprocessing": + from torchrl.modules.inference_server._mp import MPTransport + + return MPTransport() + if policy_backend == "ray": + from torchrl.modules.inference_server._ray import RayTransport + + return RayTransport() + if policy_backend == "monarch": + from torchrl.modules.inference_server._monarch import MonarchTransport + + return MonarchTransport() + raise ValueError( + f"Unknown policy_backend {policy_backend!r}. " + f"Expected one of {_POLICY_BACKENDS}." 
+ ) + def _env_loop( pool: AsyncEnvPool, @@ -47,9 +86,7 @@ def _env_loop( while not shutdown_event.is_set(): pool.async_step_and_maybe_reset_send(action_td, env_index=env_id) - cur_td, next_obs = pool.async_step_and_maybe_reset_recv( - env_index=env_id - ) + cur_td, next_obs = pool.async_step_and_maybe_reset_recv(env_index=env_id) cur_td.set(_ENV_IDX_KEY, env_id) result_queue.put(cur_td) if shutdown_event.is_set(): @@ -104,22 +141,35 @@ class AsyncBatchedCollector(BaseCollector): max_batch_size (int, optional): upper bound on the number of requests the inference server processes in a single forward pass. Defaults to ``64``. + min_batch_size (int, optional): minimum number of requests the + inference server accumulates before dispatching a batch. After + the first request arrives the server keeps draining for up to + ``server_timeout`` seconds until this many items are collected. + ``1`` (default) dispatches immediately. server_timeout (float, optional): seconds the server waits for work before dispatching a partial batch. Defaults to ``0.01``. transport (InferenceTransport, optional): a pre-built transport - backend. When ``None`` (default) a - :class:`~torchrl.modules.ThreadingTransport` is created - automatically (since worker threads always live in the main - process). Pass a :class:`~torchrl.modules.RayTransport` or - :class:`~torchrl.modules.MonarchTransport` for distributed - setups where the inference server is remote. + object. When provided, it takes precedence over + ``policy_backend``. When ``None`` (default) a transport is + created automatically from the resolved ``policy_backend``. device (torch.device or str, optional): device for policy inference. Passed to the inference server. Defaults to ``None``. - backend (str, optional): backend for the + backend (str, optional): global default backend for both + environments and policy inference. Specific overrides + ``env_backend`` and ``policy_backend`` take precedence when set. + One of ``"threading"``, ``"multiprocessing"``, ``"ray"``, or + ``"monarch"``. Defaults to ``"threading"``. + env_backend (str, optional): backend for the :class:`~torchrl.envs.AsyncEnvPool` that runs environments. One - of ``"threading"`` or ``"multiprocessing"``. The coordinator - threads are always Python threads regardless of this setting. - Defaults to ``"threading"``. + of ``"threading"`` or ``"multiprocessing"``. Falls back to + ``backend`` when ``None``. The coordinator threads are always + Python threads regardless of this setting. Defaults to ``None``. + policy_backend (str, optional): backend for the inference transport + used to communicate with the + :class:`~torchrl.modules.InferenceServer`. One of + ``"threading"``, ``"multiprocessing"``, ``"ray"``, or + ``"monarch"``. Falls back to ``backend`` when ``None``. + Defaults to ``None``. reset_at_each_iter (bool, optional): whether to reset all envs at the start of every collection batch. Defaults to ``False``. 
postproc (Callable, optional): post-processing transform applied to @@ -169,10 +219,16 @@ def __init__( frames_per_batch: int, total_frames: int = -1, max_batch_size: int = 64, + min_batch_size: int = 1, server_timeout: float = 0.01, transport: InferenceTransport | None = None, device: torch.device | str | None = None, - backend: Literal["threading", "multiprocessing"] = "threading", + backend: Literal[ + "threading", "multiprocessing", "ray", "monarch" + ] = "threading", + env_backend: Literal["threading", "multiprocessing"] | None = None, + policy_backend: Literal["threading", "multiprocessing", "ray", "monarch"] + | None = None, reset_at_each_iter: bool = False, postproc: Callable[[TensorDictBase], TensorDictBase] | None = None, yield_completed_trajectories: bool = False, @@ -196,12 +252,26 @@ def __init__( raise TypeError("create_env_fn must be a list of env factories.") self._create_env_fn = list(create_env_fn) self._num_envs = len(create_env_fn) - self._backend = backend self._create_env_kwargs = create_env_kwargs + # ---- resolve backends ------------------------------------------------- + effective_env_backend = env_backend if env_backend is not None else backend + effective_policy_backend = ( + policy_backend if policy_backend is not None else backend + ) + if effective_env_backend not in _ENV_BACKENDS: + raise ValueError( + f"env_backend={effective_env_backend!r} is not supported. " + f"Expected one of {_ENV_BACKENDS}." + ) + self._env_backend = effective_env_backend + self._policy_backend = effective_policy_backend + # ---- build transport -------------------------------------------------- if transport is None: - transport = ThreadingTransport() + transport = _make_transport( + effective_policy_backend, num_slots=self._num_envs + ) self._transport = transport # ---- build inference server ------------------------------------------- @@ -209,6 +279,7 @@ def __init__( model=policy, transport=transport, max_batch_size=max_batch_size, + min_batch_size=min_batch_size, timeout=server_timeout, device=device, weight_sync=weight_sync, @@ -252,7 +323,7 @@ def _ensure_started(self) -> None: kwargs["create_env_kwargs"] = self._create_env_kwargs self._env_pool = AsyncEnvPool( self._create_env_fn, - backend=self._backend, + backend=self._env_backend, **kwargs, ) @@ -303,9 +374,18 @@ def _rollout_frames(self) -> TensorDictBase: transitions: list[TensorDictBase] = [] while collected < self.frames_per_batch: + # Block for at least one transition td = rq.get() transitions.append(td) collected += td.numel() + # Batch-drain any additional items already in the queue + while collected < self.frames_per_batch: + try: + td = rq.get_nowait() + except queue.Empty: + break + transitions.append(td) + collected += td.numel() if self.verbose: torchrl_logger.debug( f"AsyncBatchedCollector: {collected}/{self.frames_per_batch} frames" diff --git a/torchrl/envs/async_envs.py b/torchrl/envs/async_envs.py index 19cd53dbb48..8103cc53fa8 100644 --- a/torchrl/envs/async_envs.py +++ b/torchrl/envs/async_envs.py @@ -673,9 +673,7 @@ def async_step_and_maybe_reset_send( for _env_idx, local_td in _zip_strict(env_idx, local_tds): if not _per_env: self._current_step_reset = self._current_step_reset + 1 - self.input_queue[_env_idx].put( - ("step_and_maybe_reset", local_td, _per_env) - ) + self.input_queue[_env_idx].put(("step_and_maybe_reset", local_td, _per_env)) def async_step_and_maybe_reset_recv( self, min_get: int = 1, env_index: int | None = None @@ -807,29 +805,29 @@ def _env_exec( elif msg == "batch_size": 
output_queue.put(env.batch_size) elif msg == "reset": - data = env.reset(data.copy()) + # No .copy() needed: data was deserialized from the queue + # and is not referenced after this call. + data = env.reset(data) data.set(cls._env_idx_key, NonTensorData(i)) target = per_env_reset_queue if per_env else reset_queue target.put(data) elif msg == "_reset": - data = env._reset(data.copy()) + data = env._reset(data) data.set(cls._env_idx_key, NonTensorData(i)) reset_queue.put(data) elif msg == "step_and_maybe_reset": - data, data_ = env.step_and_maybe_reset(data.copy()) + data, data_ = env.step_and_maybe_reset(data) data.set(cls._env_idx_key, NonTensorData(i)) data_.set(cls._env_idx_key, NonTensorData(i)) - target = ( - per_env_step_reset_queue if per_env else step_reset_queue - ) + target = per_env_step_reset_queue if per_env else step_reset_queue target.put((data, data_)) elif msg == "step": - data = env.step(data.copy()) + data = env.step(data) data.set(cls._env_idx_key, NonTensorData(i)) target = per_env_step_queue if per_env else step_queue target.put(data) elif msg == "_step": - data = env._step(data.copy()) + data = env._step(data) data.set(cls._env_idx_key, NonTensorData(i)) step_queue.put(data) elif msg == "shutdown": diff --git a/torchrl/modules/inference_server/__init__.py b/torchrl/modules/inference_server/__init__.py index e68f98626fd..388ba33028b 100644 --- a/torchrl/modules/inference_server/__init__.py +++ b/torchrl/modules/inference_server/__init__.py @@ -7,6 +7,7 @@ from torchrl.modules.inference_server._mp import MPTransport from torchrl.modules.inference_server._ray import RayTransport from torchrl.modules.inference_server._server import InferenceClient, InferenceServer +from torchrl.modules.inference_server._slot import SlotTransport from torchrl.modules.inference_server._threading import ThreadingTransport from torchrl.modules.inference_server._transport import InferenceTransport @@ -17,5 +18,6 @@ "MonarchTransport", "MPTransport", "RayTransport", + "SlotTransport", "ThreadingTransport", ] diff --git a/torchrl/modules/inference_server/_server.py b/torchrl/modules/inference_server/_server.py index 42ed04fcbdc..b4cbf09848b 100644 --- a/torchrl/modules/inference_server/_server.py +++ b/torchrl/modules/inference_server/_server.py @@ -5,6 +5,7 @@ from __future__ import annotations import threading +import time from collections.abc import Callable from concurrent.futures import Future @@ -32,6 +33,11 @@ class InferenceServer: Keyword Args: max_batch_size (int, optional): upper bound on the number of requests processed in a single forward pass. Default: ``64``. + min_batch_size (int, optional): minimum number of requests to + accumulate before dispatching a batch. After the first request + arrives the server keeps draining for up to ``timeout`` seconds + until at least this many items are collected. ``1`` (default) + dispatches immediately. timeout (float, optional): seconds to wait for new work before dispatching a partial batch. Default: ``0.01``. 
collate_fn (Callable, optional): function used to stack a list of @@ -70,6 +76,7 @@ def __init__( transport: InferenceTransport, *, max_batch_size: int = 64, + min_batch_size: int = 1, timeout: float = 0.01, collate_fn: Callable | None = None, device: torch.device | str | None = None, @@ -79,6 +86,7 @@ def __init__( self.model = model self.transport = transport self.max_batch_size = max_batch_size + self.min_batch_size = min_batch_size self.timeout = timeout self.collate_fn = collate_fn if collate_fn is not None else lazy_stack self.device = torch.device(device) if device is not None else None @@ -162,6 +170,20 @@ def _run(self) -> None: if not items: continue + # Accumulate up to min_batch_size (or until timeout expires) + if len(items) < self.min_batch_size: + deadline = time.monotonic() + self.timeout + while len(items) < self.min_batch_size: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + self.transport.wait_for_work(timeout=remaining) + more_items, more_cbs = self.transport.drain( + self.max_batch_size - len(items) + ) + items.extend(more_items) + callbacks.extend(more_cbs) + batch = self.collate_fn(items) if self.device is not None: batch = batch.to(self.device) diff --git a/torchrl/modules/inference_server/_slot.py b/torchrl/modules/inference_server/_slot.py new file mode 100644 index 00000000000..82781c7d3b1 --- /dev/null +++ b/torchrl/modules/inference_server/_slot.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import itertools +import threading + +from tensordict.base import TensorDictBase + +from torchrl.modules.inference_server._transport import InferenceTransport + + +class _SlotClient: + """Actor-side handle for a :class:`SlotTransport` slot. + + Each client owns a single slot (identified by ``slot_id``). Calling the + client writes the observation into the slot and blocks until the server + writes the action back. + + Args: + transport: the parent :class:`SlotTransport`. + slot_id: the slot this client owns. + """ + + def __init__(self, transport: SlotTransport, slot_id: int): + self._transport = transport + self._slot_id = slot_id + + def __call__(self, td: TensorDictBase) -> TensorDictBase: + """Submit an observation and block until the action is ready.""" + self._transport._slot_submit(self._slot_id, td) + return self._transport._slot_recv(self._slot_id) + + +class SlotTransport(InferenceTransport): + """Lock-free, in-process transport using per-env slots. + + Each actor thread owns a dedicated *slot*. Submitting an observation + writes to the slot without any lock (each slot is accessed by exactly + one writer thread). The server sweeps slots to find ready ones, collects + observations, runs the model, and writes actions back via per-slot events. + + This eliminates: + + * The shared ``threading.Lock`` that ``ThreadingTransport`` uses for + every ``submit()`` and ``drain()``. + * ``concurrent.futures.Future`` allocations (one per inference request). + + The trade-off is that the number of slots is fixed at construction time + (equal to the number of environments). + + Args: + num_slots (int): number of slots (one per environment / actor thread). + + Keyword Args: + preallocate (bool, optional): if ``True``, a contiguous observation + buffer of shape ``[num_slots, ...]`` is allocated on the first + submit. 
Subsequent submits copy into the buffer in-place + (``update_``). Defaults to ``False`` because the extra copy + into the buffer is not currently compensated by the batching + path (``lazy_stack`` still calls ``torch.stack``). + + .. note:: + This transport is only suitable for in-process threading scenarios + (the default for :class:`~torchrl.collectors.AsyncBatchedCollector` + with ``policy_backend="threading"``). + """ + + def __init__(self, num_slots: int, *, preallocate: bool = False): + self._num_slots = num_slots + self._preallocate = preallocate + self._slot_counter = itertools.count() + + # Per-slot observation storage (written by env thread, read by server) + self._obs: list[TensorDictBase | None] = [None] * num_slots + + # Per-slot readiness flag (True = observation ready for server) + # Under CPython's GIL, bool assignment is atomic. + self._obs_ready: list[bool] = [False] * num_slots + + # Per-slot action storage (written by server, read by env thread) + self._actions: list[TensorDictBase | BaseException | None] = [None] * num_slots + + # Per-slot events: server sets after writing the action + self._action_events: list[threading.Event] = [ + threading.Event() for _ in range(num_slots) + ] + + # Condition variable: env threads notify, server waits. + # Using a Condition instead of a bare Event avoids the race where + # clear() in wait_for_work drops a signal set between wait() and + # clear(). + self._work_cond = threading.Condition(threading.Lock()) + + # Pre-allocated observation buffer (lazily initialised) + self._obs_buffer: TensorDictBase | None = None + + # -- actor (env-thread) API ----------------------------------------------- + + def _slot_submit(self, slot_id: int, td: TensorDictBase) -> None: + """Write observation into the slot (no lock required).""" + if self._obs_buffer is not None: + # Copy into pre-allocated buffer (no new allocation) + self._obs_buffer[slot_id].update_(td) + else: + self._obs[slot_id] = td + self._obs_ready[slot_id] = True + with self._work_cond: + self._work_cond.notify() + + def _slot_recv(self, slot_id: int) -> TensorDictBase: + """Block until the server writes an action into the slot.""" + self._action_events[slot_id].wait() + self._action_events[slot_id].clear() + result = self._actions[slot_id] + self._actions[slot_id] = None + if isinstance(result, BaseException): + raise result + return result + + # -- InferenceTransport interface ----------------------------------------- + + def client(self) -> _SlotClient: + """Create a slot-bound client for one actor thread.""" + slot_id = next(self._slot_counter) + if slot_id >= self._num_slots: + raise RuntimeError( + f"SlotTransport has {self._num_slots} slots but " + f"client() was called {slot_id + 1} times. " + "Create a SlotTransport with more slots." + ) + return _SlotClient(self, slot_id) + + def submit(self, td: TensorDictBase): + """Not supported -- use :meth:`client` to get a slot-bound callable.""" + raise NotImplementedError( + "SlotTransport does not support submit(). " + "Use client() to obtain a slot-bound callable." 
+ ) + + def wait_for_work(self, timeout: float) -> None: + """Block until at least one slot has a ready observation.""" + with self._work_cond: + # Check if any slot is already ready before waiting + if any(self._obs_ready): + return + self._work_cond.wait(timeout=timeout) + + def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[int]]: + """Sweep slots and return (observations, slot_ids) for ready ones.""" + # Lazily initialise the pre-allocated buffer on the first drain + # that finds ready observations. + if self._preallocate and self._obs_buffer is None: + for i in range(self._num_slots): + if self._obs_ready[i] and self._obs[i] is not None: + self._obs_buffer = ( + self._obs[i] + .unsqueeze(0) + .expand(self._num_slots) + .clone() + .contiguous() + ) + break + + items: list[TensorDictBase] = [] + slot_ids: list[int] = [] + for i in range(self._num_slots): + if self._obs_ready[i]: + self._obs_ready[i] = False + if self._obs_buffer is not None: + # Flush first-time observations that arrived before the + # buffer existed into the buffer. + if self._obs[i] is not None: + self._obs_buffer[i].update_(self._obs[i]) + self._obs[i] = None + items.append(self._obs_buffer[i]) + else: + items.append(self._obs[i]) + self._obs[i] = None + slot_ids.append(i) + if len(slot_ids) >= max_items: + break + return items, slot_ids + + def resolve(self, callback: int, result: TensorDictBase) -> None: + """Write the action into the slot and wake the waiting env thread.""" + self._actions[callback] = result + self._action_events[callback].set() + + def resolve_exception(self, callback: int, exc: BaseException) -> None: + """Propagate an exception to the waiting env thread.""" + self._actions[callback] = exc + self._action_events[callback].set()
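
Usage sketch (not part of the diff): the backend split introduced above lets environments run in subprocesses while policy inference stays in the main process. The Gym environment, the linear policy, and the private import path below are placeholder assumptions for illustration only.

    from tensordict.nn import TensorDictModule
    from torch import nn

    # Import path taken from this diff; the package-level export may differ.
    from torchrl.collectors._async_batched import AsyncBatchedCollector
    from torchrl.envs import GymEnv


    def make_env():
        # Placeholder factory; any EnvBase factory works here.
        return GymEnv("Pendulum-v1")


    if __name__ == "__main__":
        # Placeholder policy: Pendulum-v1 has a 3-dim observation and a 1-dim action.
        policy = TensorDictModule(
            nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
        )
        collector = AsyncBatchedCollector(
            create_env_fn=[make_env] * 4,
            policy=policy,
            frames_per_batch=64,
            total_frames=256,
            max_batch_size=4,
            env_backend="multiprocessing",  # AsyncEnvPool runs envs in subprocesses
            policy_backend="threading",     # in-process transport to the InferenceServer
        )
        for batch in collector:
            pass  # each batch holds roughly frames_per_batch transitions
        collector.shutdown()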
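
The SlotTransport and min_batch_size additions compose: sizing the transport to the number of actor threads and setting min_batch_size to the same value encourages one forward pass per step across all actors, while the server timeout guarantees that a partial batch is still dispatched when fewer slots have submitted. A standalone sketch under those assumptions (the linear policy is a placeholder; the import path follows the inference_server __init__ updated above):

    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictModule
    from torch import nn

    from torchrl.modules.inference_server import InferenceServer, SlotTransport

    num_actors = 4
    transport = SlotTransport(num_slots=num_actors)
    # Placeholder policy: 4-dim observation in, 2-dim action out.
    policy = TensorDictModule(
        nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
    )

    with InferenceServer(
        policy,
        transport,
        max_batch_size=num_actors,
        min_batch_size=num_actors,  # accumulate up to one request per slot
        timeout=0.01,               # ...but dispatch a partial batch after ~10 ms
    ):
        client = transport.client()  # one slot-bound client per actor thread
        result = client(TensorDict({"observation": torch.randn(4)}))
        # With a single submitter the server times out and dispatches a batch of one.
        assert result["action"].shape == (2,)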
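
As noted in the updated docstring, an explicitly provided transport takes precedence over policy_backend: the collector uses it as-is instead of building a SlotTransport for the default threading backend. A minimal sketch of that override, reusing the placeholder make_env and policy from the first sketch:

    from torchrl.collectors._async_batched import AsyncBatchedCollector
    from torchrl.modules.inference_server import ThreadingTransport

    collector = AsyncBatchedCollector(
        create_env_fn=[make_env] * 4,    # placeholder factory from the first sketch
        policy=policy,                   # placeholder policy from the first sketch
        frames_per_batch=64,
        total_frames=128,
        max_batch_size=4,
        transport=ThreadingTransport(),  # pre-built transport: policy_backend is ignored
    )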