102 changes: 102 additions & 0 deletions test/test_inference_server.py
@@ -19,6 +19,7 @@
InferenceClient,
InferenceServer,
InferenceTransport,
ThreadingTransport,
)


@@ -214,3 +215,104 @@ def test_submit_returns_future(self):
assert isinstance(fut, concurrent.futures.Future)
result = fut.result(timeout=5.0)
assert "action" in result.keys()


# =============================================================================
# Tests: ThreadingTransport (Commit 2)
# =============================================================================


class TestThreadingTransport:
def test_single_request(self):
transport = ThreadingTransport()
policy = _make_policy()
with InferenceServer(policy, transport, max_batch_size=4):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()
assert result["action"].shape == (2,)

def test_concurrent_actors(self):
"""Multiple threads submit concurrently; all get correct results."""
transport = ThreadingTransport()
policy = _make_policy()
n_actors = 8
n_requests = 50

results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]

def actor_fn(actor_id, client):
for _ in range(n_requests):
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
results_per_actor[actor_id].append(result)

with InferenceServer(policy, transport, max_batch_size=16):
client = transport.client()
with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
futs = [pool.submit(actor_fn, i, client) for i in range(n_actors)]
concurrent.futures.wait(futs)
# Re-raise any exception raised inside the actor threads
for f in futs:
f.result()

for actor_results in results_per_actor:
assert len(actor_results) == n_requests
for r in actor_results:
assert "action" in r.keys()
assert r["action"].shape == (2,)

def test_timeout_fires_partial_batch(self):
"""A single request should be processed even below max_batch_size."""
transport = ThreadingTransport()
policy = _make_policy()
# max_batch_size is large, but timeout should still fire
with InferenceServer(policy, transport, max_batch_size=1024, timeout=0.05):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()

def test_max_batch_size_threading(self):
"""Verify max_batch_size is respected with real threading transport."""
max_bs = 4
seen_sizes = []

def tracking_collate(items):
seen_sizes.append(len(items))
return lazy_stack(items)

transport = ThreadingTransport()
policy = _make_policy()
n = 20

# Submit many requests before starting the server so they queue up
futures = [
transport.submit(TensorDict({"observation": torch.randn(4)}))
for _ in range(n)
]
with InferenceServer(
policy,
transport,
max_batch_size=max_bs,
collate_fn=tracking_collate,
):
for f in futures:
f.result(timeout=5.0)

for s in seen_sizes:
assert s <= max_bs

def test_model_exception_propagates(self):
"""If the model raises, the exception propagates to the caller."""

def bad_model(td):
raise ValueError("model error")

transport = ThreadingTransport()
with InferenceServer(bad_model, transport, max_batch_size=4):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
with pytest.raises(ValueError, match="model error"):
client(td)
2 changes: 2 additions & 0 deletions torchrl/modules/inference_server/__init__.py
@@ -4,10 +4,12 @@
# LICENSE file in the root directory of this source tree.

from torchrl.modules.inference_server._server import InferenceClient, InferenceServer
from torchrl.modules.inference_server._threading import ThreadingTransport
from torchrl.modules.inference_server._transport import InferenceTransport

__all__ = [
"InferenceClient",
"InferenceServer",
"InferenceTransport",
"ThreadingTransport",
]
61 changes: 61 additions & 0 deletions torchrl/modules/inference_server/_threading.py
@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import threading
from concurrent.futures import Future

from tensordict.base import TensorDictBase

from torchrl.modules.inference_server._transport import InferenceTransport


class ThreadingTransport(InferenceTransport):
"""In-process transport for actors that are threads.

Uses a shared list protected by a :class:`threading.Lock` as the request
queue and :class:`~concurrent.futures.Future` objects for response routing.

This is the simplest backend and is appropriate when all actors live in the
same process (e.g. running in a :class:`~concurrent.futures.ThreadPoolExecutor`).
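
A minimal usage sketch (illustrative only; ``policy`` stands in for any
TensorDict-in/TensorDict-out module, as exercised in the accompanying tests):

Examples:
>>> transport = ThreadingTransport()
>>> with InferenceServer(policy, transport, max_batch_size=16):
...     client = transport.client()
...     result = client(TensorDict({"observation": torch.randn(4)}))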
"""

def __init__(self):
self._queue: list[TensorDictBase] = []
self._futures: list[Future] = []
self._lock = threading.Lock()
self._event = threading.Event()

def submit(self, td: TensorDictBase) -> Future[TensorDictBase]:
"""Enqueue a request and return a Future for the result."""
fut: Future[TensorDictBase] = Future()
with self._lock:
self._queue.append(td)
self._futures.append(fut)
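# Signal the serving loop that work is available (wakes any wait_for_work() call).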
self._event.set()
return fut

def drain(self, max_items: int) -> tuple[list[TensorDictBase], list[Future]]:
"""Dequeue up to *max_items* pending requests."""
with self._lock:
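# Slice requests and Futures in lockstep so each result is routed back to its caller.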
n = min(len(self._queue), max_items)
items = self._queue[:n]
futs = self._futures[:n]
del self._queue[:n]
del self._futures[:n]
return items, futs

def wait_for_work(self, timeout: float) -> None:
"""Block until at least one request is enqueued or *timeout* elapses."""
self._event.wait(timeout=timeout)
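# Clearing after the wait is safe: any request enqueued in the meantime is still
# in the queue and will be returned by the next drain().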
self._event.clear()

def resolve(self, callback: Future, result: TensorDictBase) -> None:
"""Set the result on the actor's Future."""
callback.set_result(result)

def resolve_exception(self, callback: Future, exc: BaseException) -> None:
"""Set an exception on the actor's Future."""
callback.set_exception(exc)
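For orientation, here is a minimal sketch of the serving loop this transport is designed to feed. The helper name and structure are hypothetical; the real loop lives in torchrl.modules.inference_server._server.InferenceServer, which is not part of this diff.

def serve_once(policy, transport, max_batch_size, timeout, collate_fn):
    # Hypothetical illustration of the wait -> drain -> collate -> resolve cycle.
    transport.wait_for_work(timeout=timeout)
    items, futs = transport.drain(max_batch_size)
    if not items:
        return  # timeout fired with nothing queued
    try:
        batch = collate_fn(items)  # e.g. tensordict.lazy_stack, as in the tests above
        out = policy(batch)
        # Route each row of the batched output back to the Future of its caller.
        for fut, result in zip(futs, out.unbind(0)):
            transport.resolve(fut, result)
    except Exception as exc:
        # Propagate model errors to every waiting caller (see test_model_exception_propagates).
        for fut in futs:
            transport.resolve_exception(fut, exc)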