From 3291a4371540c7687bf6f09a8fc89ef3cc432a10 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Tue, 10 Feb 2026 20:30:26 -0800
Subject: [PATCH] Update

[ghstack-poisoned]
---
 .../reference/modules_inference_server.rst  |  76 +++++++++++
 test/test_inference_server.py               | 125 ++++++++++++++++++
 torchrl/modules/inference_server/_server.py |  37 +++++-
 3 files changed, 237 insertions(+), 1 deletion(-)

diff --git a/docs/source/reference/modules_inference_server.rst b/docs/source/reference/modules_inference_server.rst
index 980eb4d0849..1ef89426470 100644
--- a/docs/source/reference/modules_inference_server.rst
+++ b/docs/source/reference/modules_inference_server.rst
@@ -9,6 +9,9 @@ The inference server provides auto-batching model serving for RL actors.
 Multiple actors submit individual TensorDicts; the server transparently
 batches them, runs a single model forward pass, and routes results back.
 
+Core API
+--------
+
 .. autosummary::
     :toctree: generated/
     :template: rl_template_noinherit.rst
@@ -16,3 +19,76 @@ batches them, runs a single model forward pass, and routes results back.
     InferenceServer
     InferenceClient
     InferenceTransport
+
+Transport Backends
+------------------
+
+.. autosummary::
+    :toctree: generated/
+    :template: rl_template_noinherit.rst
+
+    ThreadingTransport
+    MPTransport
+    RayTransport
+    MonarchTransport
+
+Usage
+-----
+
+The simplest setup uses :class:`ThreadingTransport` for actors that are
+threads in the same process:
+
+.. code-block:: python
+
+    from tensordict.nn import TensorDictModule
+    from torchrl.modules.inference_server import (
+        InferenceServer,
+        ThreadingTransport,
+    )
+    import torch.nn as nn
+    import concurrent.futures
+
+    policy = TensorDictModule(
+        nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
+        in_keys=["observation"],
+        out_keys=["action"],
+    )
+
+    transport = ThreadingTransport()
+    server = InferenceServer(policy, transport, max_batch_size=32)
+    server.start()
+    client = transport.client()
+
+    # actor threads call client(td) -- batched automatically
+    with concurrent.futures.ThreadPoolExecutor(16) as pool:
+        ...
+
+    server.shutdown()
+
+Weight Synchronisation
+^^^^^^^^^^^^^^^^^^^^^^
+
+The server integrates with :class:`~torchrl.weight_update.WeightSyncScheme`
+to receive updated model weights from a trainer between inference batches:
+
+.. code-block:: python
+
+    from torchrl.weight_update import SharedMemWeightSyncScheme
+
+    weight_sync = SharedMemWeightSyncScheme()
+    # Initialise on the trainer (sender) side first
+    weight_sync.init_on_sender(model=training_model, ...)
+
+    server = InferenceServer(
+        model=inference_model,
+        transport=ThreadingTransport(),
+        weight_sync=weight_sync,
+    )
+    server.start()
+
+    # Training loop
+    for batch in dataloader:
+        loss = loss_fn(training_model(batch))
+        loss.backward()
+        optimizer.step()
+        weight_sync.send(model=training_model)  # weights pushed to the server
diff --git a/test/test_inference_server.py b/test/test_inference_server.py
index 185aa4335d4..b710223cae0 100644
--- a/test/test_inference_server.py
+++ b/test/test_inference_server.py
@@ -557,3 +557,128 @@ def test_import_without_monarch(self):
     def test_instantiation_without_monarch_raises(self):
         with pytest.raises(ImportError, match="Monarch is required"):
             MonarchTransport()
+
+
+# =============================================================================
+# Tests: WeightSyncScheme integration (Commit 6)
+# =============================================================================
+
+
+class _SimpleWeightSync:
+    """Minimal mock that mimics the WeightSyncScheme receiver interface.
+
+    Stores a queue of weight TensorDicts. ``receive(timeout=...)`` pops
+    the next one and applies it to the model via
+    ``TensorDict.to_module``.
+    """
+
+    def __init__(self):
+        self._queue: list[TensorDictBase] = []
+        self._model = None
+        self.initialized_on_receiver = False
+        self.synchronized_on_receiver = False
+
+    def init_on_receiver(self, *, model_id, model=None, worker_idx=0, **kwargs):
+        self._model = model
+        self.initialized_on_receiver = True
+
+    def connect(self, *, worker_idx=0):
+        self.synchronized_on_receiver = True
+
+    def receive(self, timeout=None):
+        if self._queue:
+            weights = self._queue.pop(0)
+            weights.to_module(self._model)
+            return weights
+        return None
+
+    def push(self, weights: TensorDictBase):
+        """Test helper: enqueue weights for the server to pick up."""
+        self._queue.append(weights)
+
+
+class TestWeightSyncIntegration:
+    def test_weight_sync_init_called(self):
+        """Server calls init_on_receiver and connect at startup."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        ws = _SimpleWeightSync()
+
+        with InferenceServer(policy, transport, weight_sync=ws):
+            # Give the worker thread a moment to start
+            import time
+
+            time.sleep(0.1)
+            assert ws.initialized_on_receiver
+            assert ws.synchronized_on_receiver
+
+    def test_weight_update_applied(self):
+        """Weights pushed via weight_sync are applied to the model."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        ws = _SimpleWeightSync()
+
+        with InferenceServer(
+            policy, transport, max_batch_size=4, weight_sync=ws
+        ) as server:
+            client = transport.client()
+
+            # Get initial prediction
+            td = TensorDict({"observation": torch.ones(4)})
+            client(td)
+
+            # Mutate the model weights externally and push via weight_sync
+            new_weights = TensorDict.from_module(policy)
+            for key in new_weights.keys(True, True):
+                new_weights[key] = torch.zeros_like(new_weights[key])
+            ws.push(new_weights)
+
+            # Give the server loop a chance to apply the update
+            import time
+
+            time.sleep(0.2)
+
+            # Now inference should reflect zero weights
+            result_after = client(td)
+            # With zero weights the linear output should be zero (bias=0 too)
+            assert torch.allclose(result_after["action"], torch.zeros(2), atol=1e-6)
+
+    def test_inference_continues_after_weight_update(self):
+        """The server keeps serving after a weight update."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        ws = _SimpleWeightSync()
+
+        with InferenceServer(policy, transport, max_batch_size=4, weight_sync=ws):
+            client = transport.client()
+
+            # Initial requests
+            for _ in range(5):
+                td = TensorDict({"observation": torch.randn(4)})
+                result = client(td)
+                assert "action" in result.keys()
+
+            # Push weight update
+            new_weights = TensorDict.from_module(policy)
+            ws.push(new_weights)
+
+            import time
+
+            time.sleep(0.1)
+
+            # Continue making requests
+            for _ in range(5):
+                td = TensorDict({"observation": torch.randn(4)})
+                result = client(td)
+                assert "action" in result.keys()
+                assert result["action"].shape == (2,)
+
+    def test_no_weight_sync(self):
+        """Server works fine when weight_sync is None."""
+        transport = ThreadingTransport()
+        policy = _make_policy()
+        with InferenceServer(policy, transport, max_batch_size=4):
+            client = transport.client()
+            td = TensorDict({"observation": torch.randn(4)})
+            result = client(td)
+            assert "action" in result.keys()
diff --git a/torchrl/modules/inference_server/_server.py b/torchrl/modules/inference_server/_server.py
index 4f2e27a86d0..42ed04fcbdc 100644
--- a/torchrl/modules/inference_server/_server.py
+++ b/torchrl/modules/inference_server/_server.py
@@ -42,6 +42,9 @@ class InferenceServer:
             :class:`~torchrl.weight_update.WeightSyncScheme` used to receive
             updated model weights from a trainer. When set, the server polls
             for new weights between inference batches.
+        weight_sync_model_id (str, optional): the model identifier used when
+            initialising the weight sync scheme on the receiver side.
+            Default: ``"policy"``.
 
     Example:
         >>> from tensordict.nn import TensorDictModule
@@ -71,6 +74,7 @@ def __init__(
         collate_fn: Callable | None = None,
         device: torch.device | str | None = None,
         weight_sync=None,
+        weight_sync_model_id: str = "policy",
     ):
         self.model = model
         self.transport = transport
@@ -79,9 +83,12 @@ def __init__(
         self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
         self.device = torch.device(device) if device is not None else None
         self.weight_sync = weight_sync
+        self._weight_sync_model_id = weight_sync_model_id
 
         self._shutdown_event = threading.Event()
         self._worker: threading.Thread | None = None
+        # Protects model access during weight updates
+        self._model_lock = threading.Lock()
 
     # -- lifecycle ------------------------------------------------------------
 
@@ -119,9 +126,36 @@ def is_alive(self) -> bool:
 
     # -- background loop ------------------------------------------------------
 
+    def _init_weight_sync(self) -> None:
+        """Initialise the weight sync scheme on the receiver (server) side."""
+        ws = self.weight_sync
+        if ws is None:
+            return
+        if not ws.initialized_on_receiver:
+            ws.init_on_receiver(
+                model_id=self._weight_sync_model_id,
+                model=self.model,
+                worker_idx=0,
+            )
+        if not ws.synchronized_on_receiver:
+            ws.connect(worker_idx=0)
+
+    def _poll_weight_update(self) -> None:
+        """Non-blocking check for fresh weights from the trainer."""
+        ws = self.weight_sync
+        if ws is None:
+            return
+        with self._model_lock:
+            ws.receive(timeout=0.0)
+
     @torch.no_grad()
     def _run(self) -> None:
+        self._init_weight_sync()
+
         while not self._shutdown_event.is_set():
+            # Poll for weight updates between batches (non-blocking)
+            self._poll_weight_update()
+
             self.transport.wait_for_work(timeout=self.timeout)
             items, callbacks = self.transport.drain(self.max_batch_size)
 
@@ -133,7 +167,8 @@ def _run(self) -> None:
                 batch = batch.to(self.device)
 
             try:
-                results = self.model(batch).unbind(0)
+                with self._model_lock:
+                    results = self.model(batch).unbind(0)
                 if len(results) != len(callbacks):
                     raise RuntimeError(
                         f"Model returned {len(results)} results for a "