76 changes: 76 additions & 0 deletions docs/source/reference/modules_inference_server.rst
@@ -9,10 +9,86 @@ The inference server provides auto-batching model serving for RL actors.
Multiple actors submit individual TensorDicts; the server transparently
batches them, runs a single model forward pass, and routes results back.

Core API
--------

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

InferenceServer
InferenceClient
InferenceTransport

Transport Backends
------------------

.. autosummary::
:toctree: generated/
:template: rl_template_noinherit.rst

ThreadingTransport
MPTransport
RayTransport
MonarchTransport

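The backends differ only in where the actors live relative to the server; the
threading backend is shown in full in the Usage section below.
:class:`MonarchTransport` can be imported without Monarch installed, but
constructing it raises an ``ImportError``. A minimal sketch of guarding that
optional dependency (assuming the zero-argument constructor is also the right
call when Monarch *is* available):

.. code-block:: python

    try:
        from torchrl.modules.inference_server import MonarchTransport

        transport = MonarchTransport()  # raises ImportError when Monarch is missing
    except ImportError:
        from torchrl.modules.inference_server import ThreadingTransport

        transport = ThreadingTransport()  # fall back to the in-process backend
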
Usage
-----

The simplest setup uses :class:`ThreadingTransport` for actors that are
threads in the same process:

.. code-block:: python

from tensordict.nn import TensorDictModule
from torchrl.modules.inference_server import (
InferenceServer,
ThreadingTransport,
)
import torch.nn as nn
import concurrent.futures

policy = TensorDictModule(
nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 4)),
in_keys=["observation"],
out_keys=["action"],
)

transport = ThreadingTransport()
server = InferenceServer(policy, transport, max_batch_size=32)
server.start()
client = transport.client()

# actor threads call client(td) -- batched automatically
with concurrent.futures.ThreadPoolExecutor(16) as pool:
...

server.shutdown()

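Filling in the actor side of the example above, a worker submitted to the pool
might look like the sketch below (run before ``server.shutdown()``). It assumes
the single ``client`` can be shared across threads, as the comment in the
example suggests; ``n_steps`` and the tensor shapes are illustrative:

.. code-block:: python

    import torch
    from tensordict import TensorDict

    def actor(client, n_steps=100):
        """Submit one observation at a time; the server batches concurrent calls."""
        for _ in range(n_steps):
            td = TensorDict({"observation": torch.randn(8)})
            result = client(td)  # blocks until the batched forward pass returns
            assert result["action"].shape == (4,)

    with concurrent.futures.ThreadPoolExecutor(16) as pool:
        futures = [pool.submit(actor, client) for _ in range(16)]
        for future in futures:
            future.result()

The server also appears usable as a context manager
(``with InferenceServer(policy, transport) as server: ...``), starting the
worker on entry and shutting it down on exit.
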
Weight Synchronisation
^^^^^^^^^^^^^^^^^^^^^^

The server integrates with :class:`~torchrl.weight_update.WeightSyncScheme`
to receive updated model weights from a trainer between inference batches:

.. code-block:: python

from torchrl.weight_update import SharedMemWeightSyncScheme

weight_sync = SharedMemWeightSyncScheme()
# Initialise on the trainer (sender) side first
weight_sync.init_on_sender(model=training_model, ...)

server = InferenceServer(
model=inference_model,
transport=ThreadingTransport(),
weight_sync=weight_sync,
)
server.start()

# Training loop
    for batch in dataloader:
        optimizer.zero_grad()
        loss = loss_fn(training_model(batch))
        loss.backward()
        optimizer.step()
        weight_sync.send(model=training_model)  # push the updated weights to the server
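On the server side the scheme is initialised automatically when the worker
thread starts, and new weights are polled for between inference batches, so no
extra receiver-side calls are needed. If the trainer registers the model under
a non-default identifier, pass it through ``weight_sync_model_id`` (the default
is ``"policy"``; the ``"actor_policy"`` id below is purely illustrative, and
whether it must match the sender side is an assumption):

.. code-block:: python

    server = InferenceServer(
        model=inference_model,
        transport=ThreadingTransport(),
        weight_sync=weight_sync,
        weight_sync_model_id="actor_policy",  # assumed to mirror the id used by the trainer
    )
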
120 changes: 120 additions & 0 deletions test/test_inference_server.py
@@ -6,6 +6,7 @@

import concurrent.futures
import threading
import time

import pytest
import torch
@@ -557,3 +558,122 @@ def test_import_without_monarch(self):
def test_instantiation_without_monarch_raises(self):
with pytest.raises(ImportError, match="Monarch is required"):
MonarchTransport()


# =============================================================================
# Tests: WeightSyncScheme integration (Commit 6)
# =============================================================================


class _SimpleWeightSync:
"""Minimal mock that mimics the WeightSyncScheme receiver interface.

Stores a queue of weight TensorDicts. ``receive(timeout=...)`` pops
    the next one and applies it to the model via ``TensorDict.to_module``.
"""

def __init__(self):
self._queue: list[TensorDictBase] = []
self._model = None
self.initialized_on_receiver = False
self.synchronized_on_receiver = False

def init_on_receiver(self, *, model_id, model=None, worker_idx=0, **kwargs):
self._model = model
self.initialized_on_receiver = True

def connect(self, *, worker_idx=0):
self.synchronized_on_receiver = True

def receive(self, timeout=None):
if self._queue:
weights = self._queue.pop(0)
weights.to_module(self._model)
return weights
return None

def push(self, weights: TensorDictBase):
"""Test helper: enqueue weights for the server to pick up."""
self._queue.append(weights)


class TestWeightSyncIntegration:
def test_weight_sync_init_called(self):
"""Server calls init_on_receiver and connect at startup."""
transport = ThreadingTransport()
policy = _make_policy()
ws = _SimpleWeightSync()

with InferenceServer(policy, transport, weight_sync=ws):
# Give the worker thread a moment to start
time.sleep(0.1)
assert ws.initialized_on_receiver
assert ws.synchronized_on_receiver

def test_weight_update_applied(self):
"""Weights pushed via weight_sync are applied to the model."""
transport = ThreadingTransport()
policy = _make_policy()
ws = _SimpleWeightSync()

with InferenceServer(
policy, transport, max_batch_size=4, weight_sync=ws
) as server:
client = transport.client()

# Get initial prediction
td = TensorDict({"observation": torch.ones(4)})
client(td)

# Mutate the model weights externally and push via weight_sync
new_weights = TensorDict.from_module(policy)
for key in new_weights.keys(True, True):
new_weights[key] = torch.zeros_like(new_weights[key])
ws.push(new_weights)

# Give the server loop a chance to apply the update
time.sleep(0.2)

# Now inference should reflect zero weights
result_after = client(td)
# With zero weights the linear output should be zero (bias=0 too)
assert torch.allclose(result_after["action"], torch.zeros(2), atol=1e-6)

def test_inference_continues_after_weight_update(self):
"""The server keeps serving after a weight update."""
transport = ThreadingTransport()
policy = _make_policy()
ws = _SimpleWeightSync()

with InferenceServer(policy, transport, max_batch_size=4, weight_sync=ws):
client = transport.client()

# Initial requests
for _ in range(5):
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()

# Push weight update
new_weights = TensorDict.from_module(policy)
ws.push(new_weights)

time.sleep(0.1)

# Continue making requests
for _ in range(5):
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()
assert result["action"].shape == (2,)

def test_no_weight_sync(self):
"""Server works fine when weight_sync is None."""
transport = ThreadingTransport()
policy = _make_policy()
with InferenceServer(policy, transport, max_batch_size=4):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()
37 changes: 36 additions & 1 deletion torchrl/modules/inference_server/_server.py
@@ -42,6 +42,9 @@ class InferenceServer:
:class:`~torchrl.weight_update.WeightSyncScheme` used to receive
updated model weights from a trainer. When set, the server polls
for new weights between inference batches.
weight_sync_model_id (str, optional): the model identifier used when
initialising the weight sync scheme on the receiver side.
Default: ``"policy"``.

Example:
>>> from tensordict.nn import TensorDictModule
@@ -71,6 +74,7 @@ def __init__(
collate_fn: Callable | None = None,
device: torch.device | str | None = None,
weight_sync=None,
weight_sync_model_id: str = "policy",
):
self.model = model
self.transport = transport
@@ -79,9 +83,12 @@ def __init__(
self.collate_fn = collate_fn if collate_fn is not None else lazy_stack
self.device = torch.device(device) if device is not None else None
self.weight_sync = weight_sync
self._weight_sync_model_id = weight_sync_model_id

self._shutdown_event = threading.Event()
self._worker: threading.Thread | None = None
# Protects model access during weight updates
self._model_lock = threading.Lock()

# -- lifecycle ------------------------------------------------------------

@@ -119,9 +126,36 @@ def is_alive(self) -> bool:

# -- background loop ------------------------------------------------------

def _init_weight_sync(self) -> None:
"""Initialise the weight sync scheme on the receiver (server) side."""
ws = self.weight_sync
if ws is None:
return
if not ws.initialized_on_receiver:
ws.init_on_receiver(
model_id=self._weight_sync_model_id,
model=self.model,
worker_idx=0,
)
if not ws.synchronized_on_receiver:
ws.connect(worker_idx=0)

def _poll_weight_update(self) -> None:
"""Non-blocking check for fresh weights from the trainer."""
ws = self.weight_sync
if ws is None:
return
with self._model_lock:
ws.receive(timeout=0.0)

@torch.no_grad()
def _run(self) -> None:
self._init_weight_sync()

while not self._shutdown_event.is_set():
# Poll for weight updates between batches (non-blocking)
self._poll_weight_update()

self.transport.wait_for_work(timeout=self.timeout)

items, callbacks = self.transport.drain(self.max_batch_size)
@@ -133,7 +167,8 @@ def _run(self) -> None:
batch = batch.to(self.device)

try:
results = self.model(batch).unbind(0)
with self._model_lock:
results = self.model(batch).unbind(0)
if len(results) != len(callbacks):
raise RuntimeError(
f"Model returned {len(results)} results for a "