Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add async support to Module class #1988

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 56 additions & 7 deletions dspy/clients/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,23 +85,31 @@ def __init__(
max_tokens >= 5000 and temperature == 1.0
), "OpenAI's o1-* models require passing temperature=1.0 and max_tokens >= 5000 to `dspy.LM(...)`"

def _build_request(self, prompt=None, messages=None, **kwargs):
    """Assemble everything needed to issue one LM call.

    Args:
        prompt: Plain-text prompt. Only used when ``messages`` is falsy, in
            which case it is wrapped as a single ``user`` message.
        messages: Chat-style message list; takes precedence over ``prompt``.
        **kwargs: Per-call overrides merged on top of ``self.kwargs``. A
            ``cache`` entry, if present, is consumed here (it overrides
            ``self.cache``) and is not forwarded in the request payload.

    Returns:
        A dict with three keys:
            ``request``: the payload (``model``, ``messages`` and the merged
                keyword arguments),
            ``completion``: the completion callable selected for this call,
            ``num_retries``: ``self.num_retries``.
    """
    cache = kwargs.pop("cache", self.cache)
    messages = messages or [{"role": "user", "content": prompt}]
    # Per-call kwargs win over the instance-level defaults.
    kwargs = {**self.kwargs, **kwargs}

    # Select the completion function by model type, using the cached
    # variant whenever caching is enabled for this call.
    if self.model_type == "chat":
        completion = cached_litellm_completion if cache else litellm_completion
    else:
        completion = cached_litellm_text_completion if cache else litellm_text_completion

    return dict(
        request=dict(model=self.model, messages=messages, **kwargs),
        completion=completion,
        num_retries=self.num_retries,
    )

@with_callbacks
def __call__(self, prompt=None, messages=None, **kwargs):
request = self._build_request(prompt, messages, **kwargs)
# Pass required arguments explicitly instead of **request
response = request["completion"](
request=request["request"],
num_retries=request["num_retries"]
)
if kwargs.get("logprobs"):
outputs = [
{
Expand Down Expand Up @@ -216,6 +224,47 @@ def infer_adapter(self) -> Adapter:
model_type = self.model_type
return model_type_to_adapter[model_type]

async def _async_request(self, request: dict) -> list:
    """Issue one LM call asynchronously and record it in the history.

    Args:
        request: Dict with a ``"request"`` entry holding the keyword
            arguments for ``litellm.acompletion`` and, optionally, a
            ``"num_retries"`` entry with the retry budget.

    Returns:
        One item per returned choice: the text of the choice, or a dict with
        ``"text"`` and ``"logprobs"`` when the request asked for logprobs.

    NOTE(review): this always calls ``litellm.acompletion``; any
    ``"completion"`` entry in ``request`` is not consulted, so whatever
    selection that entry encodes does not apply on this path - confirm
    that this is intended.
    """
    params = request["request"]

    # Forward the retry budget so the async path retries like the sync one;
    # an explicit value already present in the payload takes precedence.
    call_kwargs = dict(params)
    num_retries = request.get("num_retries")
    if num_retries is not None:
        call_kwargs.setdefault("num_retries", num_retries)

    response = await litellm.acompletion(**call_kwargs)

    def _choice_text(choice):
        # Choices either expose `.message` or are indexable with a "text" key.
        return choice.message.content if hasattr(choice, "message") else choice["text"]

    choices = response["choices"]
    if params.get("logprobs"):
        outputs = [
            {
                "text": _choice_text(c),
                "logprobs": c.logprobs if hasattr(c, "logprobs") else c["logprobs"],
            }
            for c in choices
        ]
    else:
        outputs = [_choice_text(c) for c in choices]

    # Logging: keep `api_*` entries out of the recorded kwargs.
    kwargs = {k: v for k, v in params.items() if not k.startswith("api_")}
    entry = dict(
        # NOTE(review): the payload built by `_build_request` has no "prompt"
        # key (the prompt is folded into `messages`), so this is None there.
        prompt=params.get("prompt"),
        messages=params.get("messages"),
        kwargs=kwargs,
        response=response,
        outputs=outputs,
        usage=dict(response["usage"]),
        cost=response.get("_hidden_params", {}).get("response_cost"),
        timestamp=datetime.now().isoformat(),
        uuid=str(uuid.uuid4()),
        model=self.model,
        response_model=response["model"],
        model_type=self.model_type,
    )
    self.history.append(entry)
    self.update_global_history(entry)

    return outputs

async def __acall__(self, prompt=None, messages=None, **kwargs):
    """Asynchronous call entry point.

    Builds the request with ``self._build_request`` and awaits
    ``self._async_request`` on it, returning that coroutine's result.

    Args:
        prompt: Plain-text prompt, forwarded positionally to ``_build_request``.
        messages: Message list, forwarded positionally to ``_build_request``.
        **kwargs: Extra keyword arguments forwarded to ``_build_request``.
    """
    built = self._build_request(prompt, messages, **kwargs)
    outputs = await self._async_request(built)
    return outputs

def copy(self, **kwargs):
"""Returns a copy of the language model with possibly updated parameters."""

Expand Down
134 changes: 133 additions & 1 deletion dspy/primitives/program.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,24 @@
import magicattr
import inspect
import asyncio
from typing import Any, Union, Awaitable, TypeVar, Optional, List, Callable

# Sentinel standing in for an argument that still needs async resolution.
ASYNC_MARKER = object()


def is_async_arg(arg):
    """Return True if ``arg`` requires async resolution.

    An argument qualifies when it is the ``ASYNC_MARKER`` sentinel or any
    awaitable. ``inspect.isawaitable`` already covers coroutine objects and
    ``asyncio.Future`` instances, so no separate checks are needed for them.
    """
    return arg is ASYNC_MARKER or inspect.isawaitable(arg)

from dspy.predict.parallel import Parallel
from dspy.primitives.module import BaseModule
from dspy.utils.callback import with_callbacks

T = TypeVar('T')


class ProgramMeta(type):
    """Metaclass that currently adds nothing on top of ``type``."""
Expand All @@ -18,9 +33,126 @@ def __init__(self, callbacks=None):
self._compiled = False

@with_callbacks
def __call__(self, *args: Any, **kwargs: Any) -> Union[T, Awaitable[T]]:
    """Call the module, choosing sync or async execution automatically.

    Async execution is used when any positional or keyword argument
    satisfies ``is_async_arg`` or when ``aforward`` has been overridden
    (i.e. it is not ``Module.aforward``). Otherwise ``forward`` is called
    directly.

    Args:
        *args: Positional arguments; awaitables are resolved before
            ``aforward`` is invoked.
        **kwargs: Keyword arguments; awaitable values are resolved likewise.

    Returns:
        The result of ``forward`` (sync path), or a coroutine that resolves
        to the result of ``aforward`` (async path).

    Raises:
        ValueError: On the async path, when awaited, if an argument is the
            ``ASYNC_MARKER`` sentinel, which cannot be awaited and therefore
            can never be resolved here.
    """
    # `getattr(..., "__func__", ...)` tolerates an `aforward` that is not a
    # bound method (e.g. a plain function assigned on the instance).
    aforward_impl = getattr(self.aforward, "__func__", self.aforward)
    has_custom_aforward = aforward_impl is not Module.aforward

    # Positional values first, then keyword values: the same order in which
    # the resolved results are consumed below.
    pending = [value for value in (*args, *kwargs.values()) if is_async_arg(value)]

    if not pending and not has_custom_aforward:
        # Use sync execution
        return self.forward(*args, **kwargs)

    async def _async_call():
        # The sentinel is not awaitable; reject it up front instead of
        # letting `asyncio.gather` fail with an unrelated TypeError.
        if any(value is ASYNC_MARKER for value in args):
            raise ValueError("Unresolved async argument in args")
        if any(value is ASYNC_MARKER for value in kwargs.values()):
            raise ValueError("Unresolved async argument in kwargs")

        # Resolve all awaitable arguments concurrently.
        resolved = iter(await asyncio.gather(*pending) if pending else [])

        # Rebuild args/kwargs, substituting each awaitable with its result.
        new_args = [next(resolved) if is_async_arg(arg) else arg for arg in args]
        new_kwargs = {
            key: next(resolved) if is_async_arg(value) else value
            for key, value in kwargs.items()
        }
        return await self.aforward(*new_args, **new_kwargs)

    return _async_call()

async def aforward(self, *args: Any, **kwargs: Any) -> T:
    """Asynchronous counterpart of ``forward``, meant to be overridden.

    The base implementation does no work and always raises
    ``NotImplementedError``; subclasses provide the real async computation.

    Guidelines for implementers:
        * Declare the override with ``async def`` and ``await`` every
          asynchronous operation.
        * Do not block inside it (for example ``time.sleep``), and do not
          simply delegate to the synchronous ``self.forward``.
        * ``asyncio.create_task`` can be used to run work concurrently.
        * Use ``async with`` for asynchronous context managers.

    Example:
        ```python
        class MyAsyncModule(Module):
            async def aforward(self, x):
                result = await async_operation(x)
                return result
        ```

    Args:
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        The result of the async computation

    Raises:
        NotImplementedError: Always, unless a subclass overrides this method.
    """
    message = (
        "Subclasses must implement aforward for async operations. "
        "Do not use sync operations or blocking calls in this method."
    )
    raise NotImplementedError(message)

def forward(self, *args: Any, **kwargs: Any) -> T:
    """Run the module's computation synchronously.

    The base implementation is a placeholder: subclasses override it to
    define what the module actually computes.

    Args:
        *args: Positional arguments
        **kwargs: Keyword arguments

    Returns:
        The result of the computation

    Raises:
        NotImplementedError: Always, unless a subclass overrides this method.
    """
    message = "Subclasses must implement forward method"
    raise NotImplementedError(message)

def named_predictors(self):
from dspy.predict.predict import Predict

Expand Down
39 changes: 31 additions & 8 deletions dspy/utils/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,48 @@ def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
>>> print(value) # Print each streamed value incrementally
"""
import dspy
import inspect

if not iscoroutinefunction(program):
program = asyncify(program)

async def generator(args, kwargs, stream: MemoryObjectSendStream):
with dspy.settings.context(send_stream=stream):
prediction = await program(*args, **kwargs)
try:
with dspy.settings.context(send_stream=stream):
# Get the raw output from the program
output = program(*args, **kwargs)

# Handle both async and sync outputs
if inspect.isawaitable(output):
output = await output

await stream.send(prediction)
# If output is a generator/async generator, stream its items
if inspect.isgenerator(output) or inspect.isasyncgen(output):
async for chunk in output:
await stream.send(chunk)
else:
# For single predictions, send as a single chunk
await stream.send(output)

# Send completion marker
await stream.send(None)
finally:
await stream.aclose()

async def streamer(*args, **kwargs):
send_stream, receive_stream = create_memory_object_stream(16)
async with create_task_group() as tg, send_stream, receive_stream:
async with create_task_group() as tg:
tg.start_soon(generator, args, kwargs, send_stream)

try:
async for value in receive_stream:
if value is None: # Completion marker
break
yield value
finally:
await receive_stream.aclose()

async for value in receive_stream:
yield value
if isinstance(value, Prediction):
return
return streamer

return streamer

Expand Down
Loading