15 changes: 8 additions & 7 deletions benchmarks/bench_collectors.py
@@ -12,7 +12,8 @@
2. Collector (ParallelEnv x N) -- single-process, N envs in sub-procs
3. MultiCollector (sync, x N) -- N sub-processes, sync delivery
4. MultiCollector (async, x N) -- N sub-processes, async delivery
5. AsyncBatchedCollector (threading) -- AsyncEnvPool + InferenceServer
5. AsyncBatched (env=thread, pol=thread) -- threading pool + threading transport
6. AsyncBatched (env=mp, pol=thread) -- multiprocessing pool + threading transport
"""
from __future__ import annotations

@@ -368,33 +369,33 @@ def policy_factory():
)
)

# 5. AsyncBatchedCollector (threading backend)
# 5. AsyncBatchedCollector (env=threading, policy=threading)
results.append(
bench(
f"AsyncBatchedCollector threading (x{num_envs})",
f"AsyncBatched env=thread pol=thread (x{num_envs})",
lambda: AsyncBatchedCollector(
create_env_fn=[make_env_fn] * num_envs,
policy=policy_factory(),
frames_per_batch=frames_per_batch,
total_frames=-1,
max_batch_size=num_envs,
backend="threading",
env_backend="threading",
),
target_frames=total_frames,
)
)

# 6. AsyncBatchedCollector (multiprocessing backend)
# 6. AsyncBatchedCollector (env=multiprocessing, policy=threading)
results.append(
bench(
f"AsyncBatchedCollector mp (x{num_envs})",
f"AsyncBatched env=mp pol=thread (x{num_envs})",
lambda: AsyncBatchedCollector(
create_env_fn=[make_env_fn] * num_envs,
policy=policy_factory(),
frames_per_batch=frames_per_batch,
total_frames=-1,
max_batch_size=num_envs,
backend="multiprocessing",
env_backend="multiprocessing",
),
target_frames=total_frames,
)
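Editorial note on the hunk above: the bench() helper it calls is defined earlier in the file and elided here. A minimal sketch of the consumption loop such a helper might use follows; the timing logic and the assumption that each yielded TensorDict batch reports its frame count via numel() are illustrative, not the benchmark's actual implementation.

# Hedged sketch only -- the real bench() lives above this hunk and may differ.
import time

def bench(name, collector_fn, target_frames):
    collector = collector_fn()
    frames = 0
    start = time.perf_counter()
    try:
        for batch in collector:      # collectors are iterables of TensorDict batches
            frames += batch.numel()  # assumed: numel() == frames in this batch
            if frames >= target_frames:
                break
    finally:
        collector.shutdown()
    fps = frames / (time.perf_counter() - start)
    return name, fps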
11 changes: 9 additions & 2 deletions examples/collectors/async_batched_collector.py
@@ -6,14 +6,21 @@

Architecture:
- An :class:`~torchrl.envs.AsyncEnvPool` runs environments in parallel
using the chosen backend (``"threading"`` or ``"multiprocessing"``).
using the chosen ``env_backend`` (``"threading"`` or ``"multiprocessing"``).
- One lightweight coordinator thread per environment owns a slot in the pool
and an inference client.
- An :class:`~torchrl.modules.InferenceServer` batches incoming observations
and runs a single forward pass.
and runs a single forward pass. The communication layer (transport) is
controlled by ``policy_backend`` (``"threading"``, ``"multiprocessing"``,
``"ray"``, or ``"monarch"``).
- There is no global synchronisation barrier -- fast envs keep stepping
while slow ones wait for inference.

Backend parameters:
- ``backend`` -- global default for both env pool and policy transport.
- ``env_backend`` -- override for the env pool (falls back to ``backend``).
- ``policy_backend`` -- override for the transport (falls back to ``backend``).

The user only supplies:
- A list of environment factories
- A policy (or policy factory)
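The backend resolution described in the docstring above can be summarised with a short, hypothetical sketch: env_backend and policy_backend each fall back to backend when not given. The resolve_backends helper below is illustrative only and not part of this PR; the constructor call mirrors benchmark #6, and the numeric values are arbitrary.

# Hypothetical illustration of the documented fallback rule; not part of this PR.
def resolve_backends(backend="threading", env_backend=None, policy_backend=None):
    return (env_backend or backend, policy_backend or backend)

assert resolve_backends("threading") == ("threading", "threading")
assert resolve_backends("threading", env_backend="multiprocessing") == (
    "multiprocessing",
    "threading",
)

# Mixing backends as in benchmark #6: multiprocessing env pool, threading transport.
# Assumes AsyncBatchedCollector, make_env_fn and policy_factory as in the benchmark above.
collector = AsyncBatchedCollector(
    create_env_fn=[make_env_fn] * 4,
    policy=policy_factory(),
    frames_per_batch=64,
    total_frames=-1,
    max_batch_size=4,
    env_backend="multiprocessing",
    policy_backend="threading",
)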
130 changes: 124 additions & 6 deletions test/test_inference_server.py
@@ -22,6 +22,7 @@
InferenceTransport,
MPTransport,
RayTransport,
SlotTransport,
ThreadingTransport,
)
from torchrl.modules.inference_server._monarch import MonarchTransport
@@ -728,7 +729,7 @@ def test_basic_collection(self):
frames_per_batch=frames_per_batch,
total_frames=total_frames,
max_batch_size=num_envs,
backend="threading",
env_backend="threading",
)
total_collected = 0
for batch in collector:
@@ -746,7 +747,7 @@ def test_policy_factory(self):
frames_per_batch=10,
total_frames=20,
max_batch_size=num_envs,
backend="threading",
env_backend="threading",
)
total_collected = 0
for batch in collector:
@@ -786,7 +787,7 @@ def test_yield_completed_trajectories(self):
total_frames=30,
yield_completed_trajectories=True,
max_batch_size=num_envs,
backend="threading",
env_backend="threading",
)
count = 0
for batch in collector:
@@ -804,7 +805,7 @@ def test_shutdown_idempotent(self):
policy=policy,
frames_per_batch=10,
total_frames=10,
backend="threading",
env_backend="threading",
)
# Consume one batch to start
for _batch in collector:
@@ -820,7 +821,7 @@ def test_endless_collector(self):
policy=policy,
frames_per_batch=10,
total_frames=-1,
backend="threading",
env_backend="threading",
)
collected = 0
for batch in collector:
@@ -857,9 +858,126 @@ def postproc(td):
frames_per_batch=10,
total_frames=20,
postproc=postproc,
backend="threading",
env_backend="threading",
)
for _ in collector:
pass
collector.shutdown()
assert called["count"] >= 1


# =============================================================================
# Tests: SlotTransport
# =============================================================================


class TestSlotTransport:
def test_single_request(self):
transport = SlotTransport(num_slots=4)
policy = _make_policy()
with InferenceServer(policy, transport, max_batch_size=4):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
result = client(td)
assert "action" in result.keys()
assert result["action"].shape == (2,)

def test_concurrent_actors(self):
"""Multiple threads submit concurrently via slot clients."""
n_actors = 4
n_requests = 30
transport = SlotTransport(num_slots=n_actors)
policy = _make_policy()

results_per_actor: list[list[TensorDictBase]] = [[] for _ in range(n_actors)]
clients = [transport.client() for _ in range(n_actors)]

def actor_fn(actor_id):
for _ in range(n_requests):
td = TensorDict({"observation": torch.randn(4)})
result = clients[actor_id](td)
results_per_actor[actor_id].append(result)

with InferenceServer(policy, transport, max_batch_size=n_actors):
with concurrent.futures.ThreadPoolExecutor(max_workers=n_actors) as pool:
futs = [pool.submit(actor_fn, i) for i in range(n_actors)]
concurrent.futures.wait(futs)
for f in futs:
f.result()

for actor_results in results_per_actor:
assert len(actor_results) == n_requests
for r in actor_results:
assert "action" in r.keys()
assert r["action"].shape == (2,)

def test_too_many_clients_raises(self):
"""Creating more clients than slots raises RuntimeError."""
transport = SlotTransport(num_slots=2)
transport.client()
transport.client()
with pytest.raises(RuntimeError, match="slots"):
transport.client()

def test_submit_raises(self):
"""Direct submit() on SlotTransport is not supported."""
transport = SlotTransport(num_slots=1)
td = TensorDict({"observation": torch.randn(4)})
with pytest.raises(NotImplementedError):
transport.submit(td)

def test_exception_propagates(self):
"""Model exceptions propagate through SlotTransport."""

def bad_model(td):
raise ValueError("slot model error")

transport = SlotTransport(num_slots=1)
with InferenceServer(bad_model, transport, max_batch_size=4):
client = transport.client()
td = TensorDict({"observation": torch.randn(4)})
with pytest.raises(ValueError, match="slot model error"):
client(td)


# =============================================================================
# Tests: min_batch_size
# =============================================================================


class TestMinBatchSize:
def test_min_batch_size_accumulates(self):
"""With min_batch_size > 1, the server waits for enough items."""
min_bs = 4
seen_sizes = []

def tracking_collate(items):
seen_sizes.append(len(items))
return lazy_stack(items)

transport = ThreadingTransport()
policy = _make_policy()
n = 8

with InferenceServer(
policy,
transport,
max_batch_size=16,
min_batch_size=min_bs,
collate_fn=tracking_collate,
timeout=1.0,
):
client = transport.client()
# Submit items from threads to give the server time to accumulate
with concurrent.futures.ThreadPoolExecutor(max_workers=n) as pool:
futs = [
pool.submit(
lambda: client(TensorDict({"observation": torch.randn(4)}))
)
for _ in range(n)
]
for f in futs:
f.result(timeout=10.0)

# At least one batch should have >= min_batch_size items
assert any(s >= min_bs for s in seen_sizes)
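Based on the SlotTransport tests above, here is a hedged usage sketch outside pytest: one slot per actor thread, each actor holding its own client obtained from transport.client(). The import paths and the TensorDictModule stand-in for the tests' _make_policy() helper are assumptions inferred from this module, not verified API.

# Sketch only: import locations inferred from this test file; verify before use.
import threading
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.modules import InferenceServer
from torchrl.modules.inference_server import SlotTransport

n_actors = 4
transport = SlotTransport(num_slots=n_actors)  # one slot per actor
policy = TensorDictModule(                     # stand-in for the _make_policy() helper
    torch.nn.Linear(4, 2), in_keys=["observation"], out_keys=["action"]
)

def actor(client, steps=10):
    for _ in range(steps):
        td = client(TensorDict({"observation": torch.randn(4)}))
        assert td["action"].shape == (2,)

with InferenceServer(policy, transport, max_batch_size=n_actors):
    clients = [transport.client() for _ in range(n_actors)]  # more than num_slots raises
    threads = [threading.Thread(target=actor, args=(c,)) for c in clients]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

The TestMinBatchSize case additionally suggests that the server can be asked to wait for at least min_batch_size queued requests (subject to a timeout) before running a forward pass, trading a little latency for larger inference batches.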