|
13 | 13 | Generic, |
14 | 14 | Protocol, |
15 | 15 | TypeVar, |
| 16 | + cast, |
16 | 17 | runtime_checkable, |
17 | 18 | ) |
18 | 19 |
|
|
43 | 44 | 'StateDeps', |
44 | 45 | ] |
45 | 46 |
|
46 | | - |
47 | 47 | RunInputT = TypeVar('RunInputT') |
48 | 48 | """Type variable for protocol-specific run input types.""" |
49 | 49 |
|
|
53 | 53 | EventT = TypeVar('EventT') |
54 | 54 | """Type variable for protocol-specific event types.""" |
55 | 55 |
|
56 | | - |
57 | 56 | StateT = TypeVar('StateT', bound=BaseModel) |
58 | 57 | """Type variable for the state type, which must be a subclass of `BaseModel`.""" |
59 | 58 |
|
| 59 | +DispatchDepsT = TypeVar('DispatchDepsT') |
| 60 | +"""TypeVar for deps to avoid awkwardness with unbound classvar deps.""" |
| 61 | + |
60 | 62 |
|
61 | 63 | @runtime_checkable |
62 | 64 | class StateHandler(Protocol): |
@@ -328,18 +330,18 @@ async def dispatch_request( |
328 | 330 | cls, |
329 | 331 | request: Request, |
330 | 332 | *, |
331 | | - agent: AbstractAgent[AgentDepsT, OutputDataT], |
| 333 | + agent: AbstractAgent[DispatchDepsT, OutputDataT], |
332 | 334 | message_history: Sequence[ModelMessage] | None = None, |
333 | 335 | deferred_tool_results: DeferredToolResults | None = None, |
334 | 336 | model: Model | KnownModelName | str | None = None, |
335 | | - instructions: Instructions[AgentDepsT] = None, |
336 | | - deps: AgentDepsT = None, |
| 337 | + instructions: Instructions[DispatchDepsT] = None, |
| 338 | + deps: DispatchDepsT = None, |
337 | 339 | output_type: OutputSpec[Any] | None = None, |
338 | 340 | model_settings: ModelSettings | None = None, |
339 | 341 | usage_limits: UsageLimits | None = None, |
340 | 342 | usage: RunUsage | None = None, |
341 | 343 | infer_name: bool = True, |
342 | | - toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, |
| 344 | + toolsets: Sequence[AbstractToolset[DispatchDepsT]] | None = None, |
343 | 345 | builtin_tools: Sequence[AbstractBuiltinTool] | None = None, |
344 | 346 | on_complete: OnCompleteFunc[EventT] | None = None, |
345 | 347 | ) -> Response: |
@@ -376,7 +378,11 @@ async def dispatch_request( |
376 | 378 | ) from e |
377 | 379 |
|
378 | 380 | try: |
379 | | - adapter = await cls.from_request(request, agent=agent) |
| 381 | + # The DepsT comes from `agent`, not from `cls`; the cast is necessary to explain this to pyright |
| 382 | + adapter = cast( |
| 383 | + UIAdapter[RunInputT, MessageT, EventT, DispatchDepsT, OutputDataT], |
| 384 | + await cls.from_request(request, agent=agent), |
| 385 | + ) |
380 | 386 | except ValidationError as e: # pragma: no cover |
381 | 387 | return Response( |
382 | 388 | content=e.json(), |
|
0 commit comments