21
21
get_regex_logits_processor ,
22
22
)
23
23
from outlines .backends .base import LogitsProcessorType
24
+ from outlines .outputs import Output
25
+ from outlines .tools import get_formatted_tools , ToolsInput
24
26
from outlines .types import CFG , JsonSchema
25
27
from outlines .types .dsl import python_types_to_terms , to_regex
26
28
@@ -35,7 +37,13 @@ class BlackBoxGenerator:
35
37
"""
36
38
output_type : Optional [Any ]
37
39
38
- def __init__ (self , model : BlackBoxModel , output_type : Optional [Any ]):
40
+ def __init__ (
41
+ self ,
42
+ model : BlackBoxModel ,
43
+ output_type : Optional [Any ],
44
+ * ,
45
+ tools : Optional [ToolsInput ] = None ,
46
+ ):
39
47
"""
40
48
Parameters
41
49
----------
@@ -47,8 +55,9 @@ def __init__(self, model: BlackBoxModel, output_type: Optional[Any]):
47
55
"""
48
56
self .model = model
49
57
self .output_type = output_type
58
+ self .tools = get_formatted_tools (tools )
50
59
51
- def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
60
+ def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
52
61
"""Generate a response from the model.
53
62
54
63
Parameters
@@ -65,10 +74,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
65
74
66
75
"""
67
76
return self .model .generate (
68
- prompt , self .output_type , ** inference_kwargs
77
+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
69
78
)
70
79
71
- def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
80
+ def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
72
81
"""Generate a batch of responses from the model.
73
82
74
83
Parameters
@@ -85,7 +94,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
85
94
86
95
"""
87
96
return self .model .generate_batch (
88
- prompts , self .output_type , ** inference_kwargs
97
+ prompts , self .output_type , tools = self . tools , ** inference_kwargs
89
98
)
90
99
91
100
def stream (self , prompt : Any , ** inference_kwargs ) -> Iterator [Any ]:
@@ -105,7 +114,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
105
114
106
115
"""
107
116
return self .model .generate_stream (
108
- prompt , self .output_type , ** inference_kwargs
117
+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
109
118
)
110
119
111
120
@@ -119,7 +128,13 @@ class AsyncBlackBoxGenerator:
119
128
"""
120
129
output_type : Optional [Any ]
121
130
122
- def __init__ (self , model : AsyncBlackBoxModel , output_type : Optional [Any ]):
131
+ def __init__ (
132
+ self ,
133
+ model : AsyncBlackBoxModel ,
134
+ output_type : Optional [Any ],
135
+ * ,
136
+ tools : Optional [ToolsInput ] = None ,
137
+ ):
123
138
"""
124
139
Parameters
125
140
----------
@@ -131,8 +146,9 @@ def __init__(self, model: AsyncBlackBoxModel, output_type: Optional[Any]):
131
146
"""
132
147
self .model = model
133
148
self .output_type = output_type
149
+ self .tools = get_formatted_tools (tools )
134
150
135
- async def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
151
+ async def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
136
152
"""Generate a response from the model.
137
153
138
154
Parameters
@@ -149,10 +165,10 @@ async def __call__(self, prompt: Any, **inference_kwargs) -> Any:
149
165
150
166
"""
151
167
return await self .model .generate (
152
- prompt , self .output_type , ** inference_kwargs
168
+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
153
169
)
154
170
155
- async def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
171
+ async def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
156
172
"""Generate a batch of responses from the model.
157
173
158
174
Parameters
@@ -169,7 +185,7 @@ async def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
169
185
170
186
"""
171
187
return await self .model .generate_batch (
172
- prompts , self .output_type , ** inference_kwargs
188
+ prompts , self .output_type , tools = self . tools , ** inference_kwargs
173
189
)
174
190
175
191
async def stream (self , prompt : Any , ** inference_kwargs ) -> AsyncIterator [Any ]:
@@ -189,7 +205,7 @@ async def stream(self, prompt: Any, **inference_kwargs) -> AsyncIterator[Any]:
189
205
190
206
"""
191
207
async for chunk in self .model .generate_stream ( # pragma: no cover
192
- prompt , self .output_type , ** inference_kwargs
208
+ prompt , self .output_type , tools = self . tools , ** inference_kwargs
193
209
):
194
210
yield chunk
195
211
@@ -218,6 +234,8 @@ def __init__(
218
234
model : SteerableModel ,
219
235
output_type : Optional [Any ],
220
236
backend_name : Optional [str ] = None ,
237
+ * ,
238
+ tools : Optional [ToolsInput ] = None ,
221
239
):
222
240
"""
223
241
Parameters
@@ -231,6 +249,7 @@ def __init__(
231
249
232
250
"""
233
251
self .model = model
252
+ self .tools = get_formatted_tools (tools )
234
253
if output_type is None :
235
254
self .logits_processor = None
236
255
else :
@@ -258,7 +277,11 @@ def __init__(
258
277
259
278
@classmethod
260
279
def from_processor (
261
- cls , model : SteerableModel , processor : LogitsProcessorType
280
+ cls ,
281
+ model : SteerableModel ,
282
+ processor : LogitsProcessorType ,
283
+ * ,
284
+ tools : Optional [ToolsInput ] = None ,
262
285
):
263
286
"""Create a generator from a logits processor.
264
287
@@ -270,13 +293,12 @@ def from_processor(
270
293
An instance of a logits processor.
271
294
272
295
"""
273
- instance = cls .__new__ (cls )
274
- instance .model = model
296
+ instance = cls (model , None , tools = tools )
275
297
instance .logits_processor = processor
276
298
277
299
return instance
278
300
279
- def __call__ (self , prompt : Any , ** inference_kwargs ) -> Any :
301
+ def __call__ (self , prompt : Any , ** inference_kwargs ) -> Output :
280
302
"""Generate a response from the model.
281
303
282
304
Parameters
@@ -295,10 +317,10 @@ def __call__(self, prompt: Any, **inference_kwargs) -> Any:
295
317
if self .logits_processor is not None :
296
318
self .logits_processor .reset ()
297
319
return self .model .generate (
298
- prompt , self .logits_processor , ** inference_kwargs
320
+ prompt , self .logits_processor , tools = self . tools , ** inference_kwargs
299
321
)
300
322
301
- def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Any ]:
323
+ def batch (self , prompts : List [Any ], ** inference_kwargs ) -> List [Output ]:
302
324
"""Generate a batch of responses from the model.
303
325
304
326
Parameters
@@ -317,7 +339,7 @@ def batch(self, prompts: List[Any], **inference_kwargs) -> List[Any]:
317
339
if self .logits_processor is not None :
318
340
self .logits_processor .reset ()
319
341
return self .model .generate_batch (
320
- prompts , self .logits_processor , ** inference_kwargs
342
+ prompts , self .logits_processor , tools = self . tools , ** inference_kwargs
321
343
)
322
344
323
345
def stream (self , prompt : Any , ** inference_kwargs ) -> Iterator [Any ]:
@@ -339,7 +361,7 @@ def stream(self, prompt: Any, **inference_kwargs) -> Iterator[Any]:
339
361
if self .logits_processor is not None :
340
362
self .logits_processor .reset ()
341
363
return self .model .generate_stream (
342
- prompt , self .logits_processor , ** inference_kwargs
364
+ prompt , self .logits_processor , tools = self . tools , ** inference_kwargs
343
365
)
344
366
345
367
@@ -348,6 +370,7 @@ def Generator(
348
370
output_type : Optional [Any ] = None ,
349
371
backend : Optional [str ] = None ,
350
372
* ,
373
+ tools : Optional [ToolsInput ] = None ,
351
374
processor : Optional [LogitsProcessorType ] = None ,
352
375
) -> Union [SteerableGenerator , BlackBoxGenerator , AsyncBlackBoxGenerator ]:
353
376
"""Create a generator for the given model and output parameters.
@@ -387,18 +410,18 @@ def Generator(
387
410
388
411
if isinstance (model , SteerableModel ): # type: ignore
389
412
if processor is not None :
390
- return SteerableGenerator .from_processor (model , processor ) # type: ignore
413
+ return SteerableGenerator .from_processor (model , processor , tools = tools ) # type: ignore
391
414
else :
392
- return SteerableGenerator (model , output_type , backend ) # type: ignore
415
+ return SteerableGenerator (model , output_type , backend , tools = tools ) # type: ignore
393
416
else :
394
417
if processor is not None :
395
418
raise NotImplementedError (
396
419
"This model does not support logits processors"
397
420
)
398
421
if isinstance (model , AsyncBlackBoxModel ): # type: ignore
399
- return AsyncBlackBoxGenerator (model , output_type ) # type: ignore
422
+ return AsyncBlackBoxGenerator (model , output_type , tools = tools ) # type: ignore
400
423
elif isinstance (model , BlackBoxModel ): # type: ignore
401
- return BlackBoxGenerator (model , output_type ) # type: ignore
424
+ return BlackBoxGenerator (model , output_type , tools = tools ) # type: ignore
402
425
else :
403
426
raise ValueError (
404
427
"The model argument must be an instance of "
0 commit comments