diff --git a/test/test_inference_server.py b/test/test_inference_server.py
index 9fe8dac757d..d7e3aeaf736 100644
--- a/test/test_inference_server.py
+++ b/test/test_inference_server.py
@@ -19,6 +19,7 @@
     InferenceClient,
     InferenceServer,
     InferenceTransport,
+    ThreadingTransport,
 )
 
 
@@ -214,3 +215,104 @@ def test_submit_returns_future(self):
         assert isinstance(fut, concurrent.futures.Future)
         result = fut.result(timeout=5.0)
         assert "action" in result.keys()
+
+
+# =============================================================================
+# Tests: ThreadingTransport (Commit 2)
+# =============================================================================
+
+
+class TestThreadingTransport:
+    def test_single_request(self):
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            result = client(td)
+            assert "action" in result.keys()
+            assert result["action"].shape == (2,)
+
+    def test_concurrent_actors(self):
+        """Multiple threads submit concurrently; all get correct results."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        n_actors = 8
+        n_requests = 50
+
+        results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
+
+        def actor_fn(actor_id, client):
+            for _ in range(n_requests):
+                td = TensorDict({"observation": torch.randn(4)})
+                result = client(td)
+                results_per_actor[actor_id].append(result)
+
+        with InferenceServer(policy, transport, max_batch_size=16):
+            client = transport.client()
+            with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
+                futs = [pool.submit(actor_fn, i, client) for i in range(n_actors)]
+                concurrent.futures.wait(futs)
+                # re-raise any exceptions
+                for f in futs:
+                    f.result()
+
+        for actor_results in results_per_actor:
+            assert len(actor_results) == n_requests
+            for r in actor_results:
+                assert "action" in r.keys()
+                assert r["action"].shape == (2,)
+
+    def test_timeout_fires_partial_batch(self):
+        """A single request should be processed even below max_batch_size."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        # max_batch_size is large, but timeout should still fire
+        with InferenceServer(policy, transport, max_batch_size=1024, timeout=0.05):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            result = client(td)
+            assert "action" in result.keys()
+
+    def test_max_batch_size_threading(self):
+        """Verify max_batch_size is respected with real threading transport."""
+        max_bs = 4
+        seen_sizes = []
+
+        def tracking_collate(items):
+            seen_sizes.append(len(items))
+            return lazy_stack(items)
+
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        n = 20
+
+        # Submit many before starting so they queue up
+        futures = [
+            transport.submit(TensorDict({"observation": torch.randn(4)}))
+            for _ in range(n)
+        ]
+        with InferenceServer(
+            policy,
+            transport,
+            max_batch_size=max_bs,
+            collate_fn=tracking_collate,
+        ):
+            for f in futures:
+                f.result(timeout=5.0)
+
+        for s in seen_sizes:
+            assert s <= max_bs
+
+    def test_model_exception_propagates(self):
+        """If the model raises, the exception propagates to the caller."""
+
+        def bad_model(td):
+            raise ValueError("model error")
+
+        transport = ThreadingTransport()
+        with InferenceServer(bad_model, transport, max_batch_size=4):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            with pytest.raises(ValueError, match="model error"):
+                client(td)
diff --git a/torchrl/modules/inference_server/__init__.py b/torchrl/modules/inference_server/__init__.py
index 352246737b7..b6dd5287f48 100644
--- a/torchrl/modules/inference_server/__init__.py
+++ b/torchrl/modules/inference_server/__init__.py
@@ -4,10 +4,12 @@
 # LICENSE file in the root directory of this source tree.
 
 from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
+from torchrl.modules.inference_server._threading import ThreadingTransport
 from torchrl.modules.inference_server._transport import InferenceTransport
 
 __all__ = [
     "InferenceClient",
     "InferenceServer",
     "InferenceTransport",
+    "ThreadingTransport",
 ]
diff --git a/torchrl/modules/inference_server/_threading.py b/torchrl/modules/inference_server/_threading.py
new file mode 100644
index 00000000000..1cf552b9b6e
--- /dev/null
+++ b/torchrl/modules/inference_server/_threading.py
@@ -0,0 +1,61 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+from __future__ import annotations
+
+import threading
+from concurrent.futures import Future
+
+from tensordict.base import TensorDictBase
+
+from torchrl.modules.inference_server._transport import InferenceTransport
+
+
+class ThreadingTransport(InferenceTransport):
+    """In-process transport for actors that are threads.
+
+    Uses a shared list protected by a :class:`threading.Lock` as the request
+    queue and :class:`~concurrent.futures.Future` objects for response routing.
+
+    This is the simplest backend and is appropriate when all actors live in the
+    same process (e.g. running in a :class:`~concurrent.futures.ThreadPoolExecutor`).
+    """
+
+    def __init__(self):
+        self._queue: list[TensorDictBase] = []
+        self._futures: list[Future] = []
+        self._lock = threading.Lock()
+        self._event = threading.Event()
+
+    def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
+        """Enqueue a request and return a Future for the result."""
+        fut: Future[TensorDictBase] = Future()
+        with self._lock:
+            self._queue.append(td)
+            self._futures.append(fut)
+        self._event.set()
+        return fut
+
+    def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
+        """Dequeue up to *max_items* pending requests."""
+        with self._lock:
+            n = min(len(self._queue), max_items)
+            items = self._queue[:n]
+            futs = self._futures[:n]
+            del self._queue[:n]
+            del self._futures[:n]
+        return items, futs
+
+    def wait_for_work(self, timeout: float) -> None:
+        """Block until at least one request is enqueued or *timeout* elapses."""
+        self._event.wait(timeout=timeout)
+        self._event.clear()
+
+    def resolve(self, callback: Future, result: TensorDictBase) -> None:
+        """Set the result on the actor's Future."""
+        callback.set_result(result)
+
+    def resolve_exception(self, callback: Future, exc: BaseException) -> None:
+        """Set an exception on the actor's Future."""
+        callback.set_exception(exc)
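Usage sketch (not part of the patch): a minimal end-to-end example of the new transport, wired the same way the tests above wire it. Only `ThreadingTransport()`, `transport.client()`, `client(td)`, and the `InferenceServer(policy, transport, max_batch_size=..., timeout=...)` context-manager usage are taken from this diff; the `toy_policy` callable, key names, and tensor shapes are illustrative assumptions.

import torch
from tensordict import TensorDict
from torchrl.modules.inference_server import InferenceServer, ThreadingTransport

# Stand-in policy: any callable mapping a (possibly batched) TensorDict with an
# "observation" entry to one with an "action" entry should work here.
def toy_policy(td):
    return td.set("action", torch.tanh(td["observation"][..., :2]))

transport = ThreadingTransport()
# The server context manager batches queued requests (up to max_batch_size, or
# when the timeout fires) and routes each result back via its per-request Future.
with InferenceServer(toy_policy, transport, max_batch_size=8, timeout=0.05):
    client = transport.client()
    out = client(TensorDict({"observation": torch.randn(4)}))
    print(out["action"].shape)  # expected: torch.Size([2]), mirroring the tests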