Skip to content

Commit 34b3239

Browse files
committed
Draft
1 parent 778cd10 commit 34b3239

File tree

6 files changed

+221
-42
lines changed

6 files changed

+221
-42
lines changed

outlines/generator.py

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
get_regex_logits_processor,
2222
)
2323
from outlines.backends.base import LogitsProcessorType
24+
from outlines.outputs import Output
25+
from outlines.tools import get_formatted_tools, ToolsInput
2426
from outlines.types import CFG, JsonSchema
2527
from outlines.types.dsl import python_types_to_terms, to_regex
2628

@@ -35,7 +37,13 @@ class BlackBoxGenerator:
3537
"""
3638
output_type: Optional[Any]
3739

38-
def __init__(self, model: BlackBoxModel, output_type: Optional[Any]):
40+
def __init__(
41+
self,
42+
model: BlackBoxModel,
43+
output_type: Optional[Any],
44+
*,
45+
tools: Optional[ToolsInput] = None,
46+
):
3947
"""
4048
Parameters
4149
----------
@@ -47,8 +55,9 @@ def __init__(self, model: BlackBoxModel, output_type: Optional[Any]):
4755
"""
4856
self.model = model
4957
self.output_type = output_type
58+
self.tools = get_formatted_tools(tools)
5059

51-
def __call__(self, prompt: Any, **inference_kwargs) -> Any:
60+
def __call__(self, prompt: Any, **inference_kwargs) -> Output:
5261
"""Generate a response from the model.
5362
5463
Parameters
@@ -65,10 +74,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
6574
6675
"""
6776
return self.model.generate(
68-
prompt, self.output_type, **inference_kwargs
77+
prompt, self.output_type, tools=self.tools, **inference_kwargs
6978
)
7079

71-
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
80+
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Output]:
7281
"""Generate a batch of responses from the model.
7382
7483
Parameters
@@ -85,7 +94,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
8594
8695
"""
8796
return self.model.generate_batch(
88-
prompts, self.output_type, **inference_kwargs
97+
prompts, self.output_type, tools=self.tools, **inference_kwargs
8998
)
9099

91100
def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
@@ -105,7 +114,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
105114
106115
"""
107116
return self.model.generate_stream(
108-
prompt, self.output_type, **inference_kwargs
117+
prompt, self.output_type, tools=self.tools, **inference_kwargs
109118
)
110119

111120

@@ -119,7 +128,13 @@ class AsyncBlackBoxGenerator:
119128
"""
120129
output_type: Optional[Any]
121130

122-
def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]):
131+
def __init__(
132+
self,
133+
model: AsyncBlackBoxModel,
134+
output_type: Optional[Any],
135+
*,
136+
tools: Optional[ToolsInput] = None,
137+
):
123138
"""
124139
Parameters
125140
----------
@@ -131,8 +146,9 @@ def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]):
131146
"""
132147
self.model = model
133148
self.output_type = output_type
149+
self.tools = get_formatted_tools(tools)
134150

135-
async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
151+
async def __call__(self, prompt: Any, **inference_kwargs) -> Output:
136152
"""Generate a response from the model.
137153
138154
Parameters
@@ -149,10 +165,10 @@ async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
149165
150166
"""
151167
return await self.model.generate(
152-
prompt, self.output_type, **inference_kwargs
168+
prompt, self.output_type, tools=self.tools, **inference_kwargs
153169
)
154170

155-
async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
171+
async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Output]:
156172
"""Generate a batch of responses from the model.
157173
158174
Parameters
@@ -169,7 +185,7 @@ async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
169185
170186
"""
171187
return await self.model.generate_batch(
172-
prompts, self.output_type, **inference_kwargs
188+
prompts, self.output_type, tools=self.tools, **inference_kwargs
173189
)
174190

175191
async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
@@ -189,7 +205,7 @@ async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
189205
190206
"""
191207
async for chunk in self.model.generate_stream( # pragma: no cover
192-
prompt, self.output_type, **inference_kwargs
208+
prompt, self.output_type, tools=self.tools, **inference_kwargs
193209
):
194210
yield chunk
195211

@@ -218,6 +234,8 @@ def __init__(
218234
model: SteerableModel,
219235
output_type: Optional[Any],
220236
backend_name: Optional[str] = None,
237+
*,
238+
tools: Optional[ToolsInput] = None,
221239
):
222240
"""
223241
Parameters
@@ -231,6 +249,7 @@ def __init__(
231249
232250
"""
233251
self.model = model
252+
self.tools = get_formatted_tools(tools)
234253
if output_type is None:
235254
self.logits_processor = None
236255
else:
@@ -258,7 +277,11 @@ def __init__(
258277

259278
@classmethod
260279
def from_processor(
261-
cls, model: SteerableModel, processor: LogitsProcessorType
280+
cls,
281+
model: SteerableModel,
282+
processor: LogitsProcessorType,
283+
*,
284+
tools: Optional[ToolsInput] = None,
262285
):
263286
"""Create a generator from a logits processor.
264287
@@ -270,13 +293,12 @@ def from_processor(
270293
An instance of a logits processor.
271294
272295
"""
273-
instance = cls.__new__(cls)
274-
instance.model = model
296+
instance = cls(model, None, tools=tools)
275297
instance.logits_processor = processor
276298

277299
return instance
278300

279-
def __call__(self, prompt: Any, **inference_kwargs) -> Any:
301+
def __call__(self, prompt: Any, **inference_kwargs) -> Output:
280302
"""Generate a response from the model.
281303
282304
Parameters
@@ -295,10 +317,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
295317
if self.logits_processor is not None:
296318
self.logits_processor.reset()
297319
return self.model.generate(
298-
prompt, self.logits_processor, **inference_kwargs
320+
prompt, self.logits_processor, tools=self.tools, **inference_kwargs
299321
)
300322

301-
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
323+
def batch(self, prompts: List[Any], **inference_kwargs) -> List[Output]:
302324
"""Generate a batch of responses from the model.
303325
304326
Parameters
@@ -317,7 +339,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
317339
if self.logits_processor is not None:
318340
self.logits_processor.reset()
319341
return self.model.generate_batch(
320-
prompts, self.logits_processor, **inference_kwargs
342+
prompts, self.logits_processor, tools=self.tools, **inference_kwargs
321343
)
322344

323345
def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
@@ -339,7 +361,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
339361
if self.logits_processor is not None:
340362
self.logits_processor.reset()
341363
return self.model.generate_stream(
342-
prompt, self.logits_processor, **inference_kwargs
364+
prompt, self.logits_processor, tools=self.tools, **inference_kwargs
343365
)
344366

345367

@@ -348,6 +370,7 @@ def Generator(
348370
output_type: Optional[Any] = None,
349371
backend: Optional[str] = None,
350372
*,
373+
tools: Optional[ToolsInput] = None,
351374
processor: Optional[LogitsProcessorType] = None,
352375
) -> Union[SteerableGenerator, BlackBoxGenerator, AsyncBlackBoxGenerator]:
353376
"""Create a generator for the given model and output parameters.
@@ -387,18 +410,18 @@ def Generator(
387410

388411
if isinstance(model, SteerableModel): # type: ignore
389412
if processor is not None:
390-
return SteerableGenerator.from_processor(model, processor) # type: ignore
413+
return SteerableGenerator.from_processor(model, processor, tools=tools) # type: ignore
391414
else:
392-
return SteerableGenerator(model, output_type, backend) # type: ignore
415+
return SteerableGenerator(model, output_type, backend, tools=tools) # type: ignore
393416
else:
394417
if processor is not None:
395418
raise NotImplementedError(
396419
"This model does not support logits processors"
397420
)
398421
if isinstance(model, AsyncBlackBoxModel): # type: ignore
399-
return AsyncBlackBoxGenerator(model, output_type) # type: ignore
422+
return AsyncBlackBoxGenerator(model, output_type, tools=tools) # type: ignore
400423
elif isinstance(model, BlackBoxModel): # type: ignore
401-
return BlackBoxGenerator(model, output_type) # type: ignore
424+
return BlackBoxGenerator(model, output_type, tools=tools) # type: ignore
402425
else:
403426
raise ValueError(
404427
"The model argument must be an instance of "

0 commit comments

Comments
 (0)