diff --git a/notes/releases/r_0_8.md b/notes/releases/r_0_8.md index b958ccab..37c7c04d 100644 --- a/notes/releases/r_0_8.md +++ b/notes/releases/r_0_8.md @@ -1,6 +1,8 @@ + - Composite language programs with Module classes (inspired by DSPy) + - Consolidated message building logic to support formats required by different APIs + - Script and Sections for better control over complex chat message sequences - Support for new Anthropic tool calls mode - Support for Cohere API - - Composite language programs with Module classes (inspired by DSPy) - `FunctionCall` helper class for extracting arguments for callable objects - - Consolidated message building logic to support formats required by different APIs + - Refactored and simplified API client classes - Additions to docs and examples diff --git a/src/ApiClient/ApiClient.php b/src/ApiClient/ApiClient.php index 2717cdca..ecbe2b80 100644 --- a/src/ApiClient/ApiClient.php +++ b/src/ApiClient/ApiClient.php @@ -1,11 +1,13 @@ defaultMaxTokens; } - $mode = match(true) { - !empty($tools) => Mode::Tools, - !empty($responseFormat) => Mode::Json, - default => Mode::MdJson, - }; - $this->apiRequest = $this->apiRequestFactory->makeRequest( - requestClass: $this->getModeRequestClass($mode), - messages: $messages, - tools: $tools, - toolChoice: $toolChoice, - responseFormat: $responseFormat, - model: $this->getModel($model), - options: $options - ); - return $this; - } - - public function chatCompletion(array $messages, string $model = '', array $options = []): static { - if (!isset($options['max_tokens'])) { - $options['max_tokens'] = $this->defaultMaxTokens; + if (!isset($body['max_tokens'])) { + $body['max_tokens'] = $this->defaultMaxTokens; } - $this->apiRequest = $this->apiRequestFactory->makeChatCompletionRequest( - requestClass: $this->getModeRequestClass(Mode::MdJson), - messages: $messages, - model: $this->getModel($model), - options: $options - ); - return $this; - } - - public function jsonCompletion(array $messages, array $responseFormat, string $model = '', array $options = []): static { - if (!isset($options['max_tokens'])) { - $options['max_tokens'] = $this->defaultMaxTokens; - } - $this->apiRequest = $this->apiRequestFactory->makeJsonCompletionRequest( - requestClass: $this->getModeRequestClass(Mode::Json), - messages: $messages, - responseFormat: $responseFormat, - model: $this->getModel($model), - options: $options - ); - return $this; - } - - public function toolsCall(array $messages, array $tools, array $toolChoice, string $model = '', array $options = []): static { - if (!isset($options['max_tokens'])) { - $options['max_tokens'] = $this->defaultMaxTokens; + if (!isset($body['model'])) { + $body['model'] = $this->defaultModel(); } - $this->apiRequest = $this->apiRequestFactory->makeToolsCallRequest( - requestClass: $this->getModeRequestClass(Mode::Tools), - messages: $messages, - tools: $tools, - toolChoice: $toolChoice, - model: $this->getModel($model), - options: $options - ); + $this->apiRequest = $this->apiRequestFactory->makeRequest($this->getModeRequestClass(), $body, $endpoint, $method, $options, $data); return $this; } diff --git a/src/ApiClient/Contracts/CanCallApi.php b/src/ApiClient/Contracts/CanCallApi.php index df8094a3..27ba7aba 100644 --- a/src/ApiClient/Contracts/CanCallApi.php +++ b/src/ApiClient/Contracts/CanCallApi.php @@ -7,16 +7,17 @@ use Cognesy\Instructor\Enums\Mode; use Generator; use GuzzleHttp\Promise\PromiseInterface; +use Saloon\Enums\Method; interface CanCallApi { - public function request(array $messages, array $tools = [], array $toolChoice = [], array $responseFormat = [], string $model = '', array $options = []): static; + public function request(array $body, string $endpoint, Method $method): static; - public function chatCompletion(array $messages, string $model = '', array $options = []): static; - - public function jsonCompletion(array $messages, array $responseFormat, string $model = '', array $options = []): static; - - public function toolsCall(array $messages, array $tools, array $toolChoice, string $model = '', array $options = []): static; +// public function chatCompletion(array $messages, string $model = '', array $options = []): static; +// +// public function jsonCompletion(array $messages, array $responseFormat, string $model = '', array $options = []): static; +// +// public function toolsCall(array $messages, array $tools, array $toolChoice, string $model = '', array $options = []): static; public function get() : ApiResponse; diff --git a/src/ApiClient/Factories/ApiRequestFactory.php b/src/ApiClient/Factories/ApiRequestFactory.php index a1cac658..79350c52 100644 --- a/src/ApiClient/Factories/ApiRequestFactory.php +++ b/src/ApiClient/Factories/ApiRequestFactory.php @@ -4,6 +4,7 @@ use Cognesy\Instructor\ApiClient\Context\ApiRequestContext; use Cognesy\Instructor\ApiClient\Requests\ApiRequest; +use Saloon\Enums\Method; class ApiRequestFactory { @@ -13,95 +14,15 @@ public function __construct( /** * @param class-string $requestClass - */ + */ public function makeRequest( string $requestClass, - array $messages, - array $tools, - array $toolChoice, - array $responseFormat, - string $model = '', - array $options = [] + array $body, + string $endpoint = '', + Method $method = Method::POST, + array $options = [], + array $data = [], ): ApiRequest { - /** @var ApiRequest $apiRequest */ - $apiRequest = new $requestClass(...[ - 'messages' => $messages, - 'tools' => $tools, - 'toolChoice' => $toolChoice, - 'responseFormat' => $responseFormat, - 'model' => $model, - 'options' => $options, - ]); - $apiRequest->withContext($this->context); - return $apiRequest; - } - - public function makeChatCompletionRequest( - string $requestClass, - array $messages, - string $model = '', - array $options = [] - ): ApiRequest { - return $this->fromClass( - requestClass: $requestClass, - args: [ - 'messages' => $messages, - 'tools' => [], - 'toolChoice' => [], - 'responseFormat' => [], - 'model' => $model, - 'options' => $options, - ] - ); - } - - public function makeJsonCompletionRequest( - string $requestClass, - array $messages, - array $responseFormat, - string $model = '', - array $options = [] - ): ApiRequest { - return $this->fromClass( - requestClass: $requestClass, - args: [ - 'messages' => $messages, - 'tools' => [], - 'toolChoice' => [], - 'responseFormat' => $responseFormat, - 'model' => $model, - 'options' => $options, - ] - ); - } - - public function makeToolsCallRequest( - string $requestClass, - array $messages, - array $tools, - array $toolChoice, - string $model = '', - array $options = [] - ): ApiRequest { - return $this->fromClass( - requestClass: $requestClass, - args: [ - 'messages' => $messages, - 'tools' => $tools, - 'toolChoice' => $toolChoice, - 'responseFormat' => [], - 'model' => $model, - 'options' => $options, - ] - ); - } - - /// INTERNAL //////////////////////////////////////////////////////////////////////////////////////////// - - protected function fromClass(string $requestClass, array $args) : ApiRequest { - /** @var ApiRequest $apiRequest */ - $apiRequest = new $requestClass(...$args); - $apiRequest->withContext($this->context); - return $apiRequest; + return new $requestClass($body, $endpoint, $method, $this->context, $options, $data); } } diff --git a/src/ApiClient/Requests/ApiRequest.php b/src/ApiClient/Requests/ApiRequest.php index a452f65b..111496aa 100644 --- a/src/ApiClient/Requests/ApiRequest.php +++ b/src/ApiClient/Requests/ApiRequest.php @@ -2,9 +2,9 @@ namespace Cognesy\Instructor\ApiClient\Requests; +use Cognesy\Instructor\ApiClient\Context\ApiRequestContext; use Cognesy\Instructor\ApiClient\Responses\ApiResponse; use Cognesy\Instructor\ApiClient\Responses\PartialApiResponse; -use Override; use Saloon\CachePlugin\Contracts\Cacheable; use Saloon\Contracts\Body\HasBody; use Saloon\Enums\Method; @@ -22,21 +22,29 @@ abstract class ApiRequest extends Request implements HasBody, Cacheable use Traits\HandlesDebug; protected Method $method = Method::POST; + protected array $options = []; + protected array $requestBody = []; + protected array $data = []; + + // TO BE DEPRECATED? + public array $messages = []; + public array $tools = []; + public string|array $toolChoice = []; + public string|array $responseFormat = []; + public string $model = ''; public function __construct( - public array $messages = [], - public array $tools = [], - public string|array $toolChoice = [], - public string|array $responseFormat = [], - public string $model = '', - public array $options = [], - public string $endpoint = '', + array $body = [], + string $endpoint = '', + Method $method = Method::POST, + // + ApiRequestContext $context = null, + array $options = [], // to consolidate into $context? + array $data = [], // to consolidate into $context? ) { + $this->context = $context; $this->debug = $this->options['debug'] ?? false; - unset($this->options['debug']); - $this->cachingEnabled = $this->options['cache'] ?? false; - unset($this->options['cache']); if ($this->cachingEnabled) { if ($this->isStreamed()) { @@ -44,22 +52,38 @@ public function __construct( } } + $this->options = $options; + $this->endpoint = $endpoint; + $this->method = $method; + $this->requestBody = $body; + $this->data = $data; + + // maybe replace them with $requestBody + $this->messages = $body['messages'] ?? []; + $this->tools = $body['tools'] ?? []; + $this->toolChoice = $body['tool_choice'] ?? []; + $this->responseFormat = $body['response_format'] ?? []; + $this->model = $body['model'] ?? ''; + $this->body()->setJsonFlags(JSON_UNESCAPED_SLASHES); } public function isStreamed(): bool { - return $this->options['stream'] ?? false; + return $this->requestBody['stream'] ?? false; } protected function defaultBody(): array { return array_filter( - array_merge([ - 'messages' => $this->messages(), - 'model' => $this->model, - 'tools' => $this->tools(), - 'tool_choice' => $this->getToolChoice(), - 'response_format' => $this->getResponseFormat(), - ], $this->options) + array_merge( + $this->requestBody, + [ + 'messages' => $this->messages(), + 'model' => $this->model, + 'tools' => $this->tools(), + 'tool_choice' => $this->getToolChoice(), + 'response_format' => $this->getResponseFormat(), + ] + ) ); } diff --git a/src/ApiClient/Requests/Traits/HandlesEndpoint.php b/src/ApiClient/Requests/Traits/HandlesEndpoint.php index a348f1a1..dbe54d3f 100644 --- a/src/ApiClient/Requests/Traits/HandlesEndpoint.php +++ b/src/ApiClient/Requests/Traits/HandlesEndpoint.php @@ -4,6 +4,7 @@ trait HandlesEndpoint { + protected ?string $endpoint = null; protected string $defaultEndpoint = '/chat/completions'; public function resolveEndpoint() : string { diff --git a/src/ApiClient/Traits/HandlesDefaultModel.php b/src/ApiClient/Traits/HandlesDefaultModel.php index 85e9e1f2..d99c6ac0 100644 --- a/src/ApiClient/Traits/HandlesDefaultModel.php +++ b/src/ApiClient/Traits/HandlesDefaultModel.php @@ -9,8 +9,4 @@ trait HandlesDefaultModel public function defaultModel() : string { return $this->defaultModel; } - - protected function getModel(string $model) : string { - return $model ?: $this->defaultModel(); - } } \ No newline at end of file diff --git a/src/Clients/Anthropic/AnthropicApiRequest.php b/src/Clients/Anthropic/AnthropicApiRequest.php index 2ee559b8..b396da48 100644 --- a/src/Clients/Anthropic/AnthropicApiRequest.php +++ b/src/Clients/Anthropic/AnthropicApiRequest.php @@ -1,5 +1,4 @@ requestBody, [ 'messages' => $this->messages(), 'model' => $this->model, 'tools' => $this->tools(), 'tool_choice' => $this->getToolChoice(), - ], - $this->options + ] ) ); } diff --git a/src/Clients/Anthropic/AnthropicClient.php b/src/Clients/Anthropic/AnthropicClient.php index a98a765f..27472469 100644 --- a/src/Clients/Anthropic/AnthropicClient.php +++ b/src/Clients/Anthropic/AnthropicClient.php @@ -35,7 +35,7 @@ public function __construct( } #[Override] - public function getModeRequestClass(Mode $mode) : string { + public function getModeRequestClass(Mode $mode = null) : string { return AnthropicApiRequest::class; } } \ No newline at end of file diff --git a/src/Clients/Anyscale/AnyscaleApiRequest.php b/src/Clients/Anyscale/AnyscaleApiRequest.php index 5e26f545..f60482a7 100644 --- a/src/Clients/Anyscale/AnyscaleApiRequest.php +++ b/src/Clients/Anyscale/AnyscaleApiRequest.php @@ -1,5 +1,4 @@ isStreamed()) { - $options['stream_options']['include_usage'] = true; + $body['stream_options']['include_usage'] = true; } + parent::__construct( - messages: $messages, - tools: $tools, - toolChoice: $toolChoice, - responseFormat: $responseFormat, - model: $model, + body: $body, + endpoint: $endpoint, + method: $method, + context: $context, options: $options, - endpoint: $endpoint + data: $data, ); } } \ No newline at end of file diff --git a/src/Clients/Azure/AzureClient.php b/src/Clients/Azure/AzureClient.php index bfc5fa4d..a1c20ddf 100644 --- a/src/Clients/Azure/AzureClient.php +++ b/src/Clients/Azure/AzureClient.php @@ -42,7 +42,7 @@ public function __construct( } #[Override] - public function getModeRequestClass(Mode $mode) : string { + public function getModeRequestClass(Mode $mode = null) : string { return AzureApiRequest::class; } } diff --git a/src/Clients/Cohere/CohereApiRequest.php b/src/Clients/Cohere/CohereApiRequest.php index 7f705ed9..5e7ae90b 100644 --- a/src/Clients/Cohere/CohereApiRequest.php +++ b/src/Clients/Cohere/CohereApiRequest.php @@ -1,5 +1,4 @@ isStreamed()) { - $options['stream_options']['include_usage'] = true; + $body['stream_options']['include_usage'] = true; } parent::__construct( - messages: $messages, - tools: $tools, - toolChoice: $toolChoice, - responseFormat: $responseFormat, - model: $model, + body: $body, + endpoint: $endpoint, + method: $method, + context: $context, options: $options, - endpoint: $endpoint + data: $data, ); } } \ No newline at end of file diff --git a/src/Clients/OpenAI/OpenAIClient.php b/src/Clients/OpenAI/OpenAIClient.php index 6affe4e1..652fd684 100644 --- a/src/Clients/OpenAI/OpenAIClient.php +++ b/src/Clients/OpenAI/OpenAIClient.php @@ -37,7 +37,7 @@ public function __construct( } #[Override] - public function getModeRequestClass(Mode $mode) : string { + public function getModeRequestClass(Mode $mode = null) : string { return OpenAIApiRequest::class; } } diff --git a/src/Clients/OpenRouter/OpenRouterApiRequest.php b/src/Clients/OpenRouter/OpenRouterApiRequest.php index 3a0dee55..5d460ea8 100644 --- a/src/Clients/OpenRouter/OpenRouterApiRequest.php +++ b/src/Clients/OpenRouter/OpenRouterApiRequest.php @@ -1,5 +1,4 @@ option('max_tokens'))) { $this->setOption('max_tokens', $this->client->defaultMaxTokens); } - return $this->fromRequest(); - } - - // INTERNAL //////////////////////////////////////////////////////////////////////////////////////////// - protected function fromRequest() : ApiRequest { $requestClass = $this->client->getModeRequestClass($this->mode()); - return match ($this->mode()) { - Mode::MdJson => $this->makeChatCompletionRequest($requestClass), - Mode::Json => $this->makeJsonCompletionRequest($requestClass), - Mode::Tools => $this->makeToolsCallRequest($requestClass), - default => $this->makeApiRequest($requestClass), - }; - } - - protected function makeApiRequest(string $requestClass) : ApiRequest { return $this->apiRequestFactory->makeRequest( requestClass: $requestClass, - messages: $this->messages(), - tools: $this->toolCallSchema(), - toolChoice: $this->toolChoice(), - responseFormat: $this->responseFormat(), - model: $this->modelName(), - options: $this->makeOptions(), - ); - } - - protected function makeChatCompletionRequest(string $requestClass) : ApiRequest { - return $this->apiRequestFactory->makeChatCompletionRequest( - requestClass: $requestClass, - messages: $this->messages(), - model: $this->modelName(), - options: $this->makeOptions(), - ); - } - - protected function makeJsonCompletionRequest(string $requestClass) : ApiRequest { - return $this->apiRequestFactory->makeJsonCompletionRequest( - requestClass: $requestClass, - messages: $this->messages(), - responseFormat: $this->responseFormat(), - model: $this->modelName(), - options: $this->makeOptions(), - ); - } - - protected function makeToolsCallRequest(string $requestClass) : ApiRequest { - return $this->apiRequestFactory->makeToolsCallRequest( - requestClass: $requestClass, - messages: $this->messages(), - tools: $this->toolCallSchema(), - toolChoice: $this->toolChoice(), - model: $this->modelName(), - options: $this->makeOptions(), + body: $this->toApiRequestBody(), + endpoint: $this->endpoint(), + method: $this->method(), + options: $this->options, + data: $this->data(), ); } } \ No newline at end of file diff --git a/src/Data/Traits/Request/HandlesData.php b/src/Data/Traits/Request/HandlesData.php new file mode 100644 index 00000000..1992ec4f --- /dev/null +++ b/src/Data/Traits/Request/HandlesData.php @@ -0,0 +1,12 @@ +data; + } +} \ No newline at end of file diff --git a/src/Data/Traits/Request/HandlesEndpoint.php b/src/Data/Traits/Request/HandlesEndpoint.php new file mode 100644 index 00000000..66ce35b3 --- /dev/null +++ b/src/Data/Traits/Request/HandlesEndpoint.php @@ -0,0 +1,12 @@ +endpoint; + } +} \ No newline at end of file diff --git a/src/Data/Traits/Request/HandlesMessages.php b/src/Data/Traits/Request/HandlesMessages.php index a9ea9cfb..7914a62c 100644 --- a/src/Data/Traits/Request/HandlesMessages.php +++ b/src/Data/Traits/Request/HandlesMessages.php @@ -2,9 +2,12 @@ namespace Cognesy\Instructor\Data\Traits\Request; +use Cognesy\Instructor\Core\Messages\Script; + trait HandlesMessages { private string|array $messages; + private Script $script; public function messages() : array { if (is_string($this->messages)) { diff --git a/src/Data/Traits/Request/HandlesMethod.php b/src/Data/Traits/Request/HandlesMethod.php new file mode 100644 index 00000000..253cce12 --- /dev/null +++ b/src/Data/Traits/Request/HandlesMethod.php @@ -0,0 +1,14 @@ +method; + } +} \ No newline at end of file diff --git a/src/Data/Traits/Request/HandlesPrompts.php b/src/Data/Traits/Request/HandlesPrompts.php index 637161ad..79318647 100644 --- a/src/Data/Traits/Request/HandlesPrompts.php +++ b/src/Data/Traits/Request/HandlesPrompts.php @@ -2,7 +2,6 @@ namespace Cognesy\Instructor\Data\Traits\Request; -use Cognesy\Instructor\Core\Messages\Utils\MessageBuilder; use Cognesy\Instructor\Enums\Mode; use Exception; @@ -35,19 +34,5 @@ protected function makeOptions() : array { if (empty($this->client())) { throw new Exception('Client is required to render request body.'); } - - $body = MessageBuilder::requestBody( - clientClass: get_class($this->client()), - mode: $this->mode(), - messages: $this->messages(), - responseModel: $this->responseModel(), - dataAcknowledgedPrompt: $this->dataAcknowledgedPrompt, - prompt: $this->prompt(), - examples: $this->examples(), - ); - return array_merge( - $this->options, - $body, - ); } } \ No newline at end of file diff --git a/src/Data/Traits/Request/HandlesRequestBody.php b/src/Data/Traits/Request/HandlesRequestBody.php new file mode 100644 index 00000000..080e9759 --- /dev/null +++ b/src/Data/Traits/Request/HandlesRequestBody.php @@ -0,0 +1,39 @@ +client()) => throw new Exception('Client not set'), + default => get_class($this->client()), + }, + mode: $this->mode(), + messages: $this->messages(), + responseModel: $this->responseModel(), + dataAcknowledgedPrompt: $this->dataAcknowledgedPrompt, + prompt: $this->prompt(), + examples: $this->examples(), + ); + + $body['model'] = $this->model(); + if (Mode::Tools == $this->mode()) { + $body['tools'] = $this->toolCallSchema(); + $body['tool_choice'] = $this->toolChoice(); + } elseif (Mode::Json == $this->mode()) { + $body['response_format'] = $this->responseFormat(); + } + $body = array_merge($body, $this->options()); + + return array_merge( + $this->options, + $body, + ); + } +} \ No newline at end of file diff --git a/src/Instructor.php b/src/Instructor.php index f1b0019f..f109cb32 100644 --- a/src/Instructor.php +++ b/src/Instructor.php @@ -4,26 +4,18 @@ use Cognesy\Instructor\ApiClient\Contracts\CanCallApi; use Cognesy\Instructor\ApiClient\Factories\ApiClientFactory; use Cognesy\Instructor\Configuration\Configuration; -use Cognesy\Instructor\Contracts\CanHandleRequest; -use Cognesy\Instructor\Contracts\CanHandleStreamRequest; use Cognesy\Instructor\Core\Factories\RequestFactory; use Cognesy\Instructor\Core\Factories\ResponseModelFactory; -use Cognesy\Instructor\Core\RequestHandler; -use Cognesy\Instructor\Core\StreamRequestHandler; -use Cognesy\Instructor\Data\Request; use Cognesy\Instructor\Enums\Mode; -use Cognesy\Instructor\Events\Event; use Cognesy\Instructor\Events\EventDispatcher; use Cognesy\Instructor\Events\Instructor\InstructorDone; use Cognesy\Instructor\Events\Instructor\InstructorReady; use Cognesy\Instructor\Events\Instructor\InstructorStarted; use Cognesy\Instructor\Events\Instructor\RequestReceived; -use Cognesy\Instructor\Events\Instructor\ResponseGenerated; use Cognesy\Instructor\Logging\EventLogger; use Cognesy\Instructor\Utils\Env; use Exception; use Psr\Log\LoggerInterface; -use Throwable; /** * Main access point to Instructor. @@ -33,19 +25,19 @@ class Instructor { use Events\Traits\HandlesEvents; use Events\Traits\HandlesEventListeners; + + use Traits\HandlesApiClient; use Traits\HandlesConfig; use Traits\HandlesDebug; + use Traits\HandlesEnv; use Traits\HandlesErrors; use Traits\HandlesPartialUpdates; use Traits\HandlesQueuedEvents; + use Traits\HandlesRequest; use Traits\HandlesSchema; use Traits\HandlesSequenceUpdates; use Traits\HandlesTimer; - protected ?Request $request = null; - protected RequestFactory $requestFactory; - protected ApiClientFactory $clientFactory; - protected ResponseModelFactory $responseModelFactory; private LoggerInterface $logger; private EventLogger $eventLogger; @@ -65,29 +57,6 @@ public function __construct(array $config = []) { $this->queueEvent(new InstructorReady($this->config)); } - /// INITIALIZATION ENDPOINTS ////////////////////////////////////////////// - - /** - * Sets the environment variables configuration file paths and names - * - * @param string|array $paths - * @param string|array $names - * @return $this - */ - public function withEnv(string|array $paths, string|array $names = '') : self { - Env::set($paths, $names); - return $this; - } - - public function client() : CanCallApi { - return $this->clientFactory->getDefault(); - } - - public function withClient(CanCallApi $client) : self { - $this->clientFactory->setDefault($client); - return $this; - } - /// EXTRACTION EXECUTION ENDPOINTS //////////////////////////////////////// /** @@ -189,39 +158,4 @@ public function stream() : Stream { } return new Stream($this->handleStreamRequest(), $this->events()); } - - /// INTERNAL ////////////////////////////////////////////////////////////// - - protected function handleRequest() : mixed { - try { - /** @var RequestHandler $requestHandler */ - $requestHandler = $this->config()->get(CanHandleRequest::class); - $this->startTimer(); - $response = $requestHandler->respondTo($this->getRequest()); - $this->stopTimer(); - $this->events->dispatch(new ResponseGenerated($response)); - return $response; - } catch (Throwable $error) { - return $this->handleError($error); - } - } - - protected function handleStreamRequest() : Iterable { - try { - /** @var StreamRequestHandler $streamHandler */ - $streamHandler = $this->config()->get(CanHandleStreamRequest::class); - $this->startTimer(); - yield from $streamHandler->respondTo($this->getRequest()); - $this->stopTimer(); - } catch (Throwable $error) { - return $this->handleError($error); - } - } - - protected function getRequest() : Request { - if ($this->debug()) { - $this->request->setOption('debug', true); - } - return $this->requestFactory->fromRequest($this->request); - } } diff --git a/src/Traits/HandlesApiClient.php b/src/Traits/HandlesApiClient.php index 1f7250ae..4f4026a7 100644 --- a/src/Traits/HandlesApiClient.php +++ b/src/Traits/HandlesApiClient.php @@ -3,17 +3,18 @@ namespace Cognesy\Instructor\Traits; use Cognesy\Instructor\ApiClient\Contracts\CanCallApi; +use Cognesy\Instructor\ApiClient\Factories\ApiClientFactory; trait HandlesApiClient { - protected CanCallApi $client; + protected ApiClientFactory $clientFactory; public function client() : CanCallApi { - return $this->client; + return $this->clientFactory->getDefault(); } public function withClient(CanCallApi $client) : self { - $this->client = $client->withEventDispatcher($this->events); + $this->clientFactory->setDefault($client); return $this; } } diff --git a/src/Traits/HandlesEnv.php b/src/Traits/HandlesEnv.php new file mode 100644 index 00000000..ae33e124 --- /dev/null +++ b/src/Traits/HandlesEnv.php @@ -0,0 +1,20 @@ +config()->get(CanHandleRequest::class); + $this->startTimer(); + $response = $requestHandler->respondTo($this->getRequest()); + $this->stopTimer(); + $this->events->dispatch(new ResponseGenerated($response)); + return $response; + } catch (Throwable $error) { + return $this->handleError($error); + } + } + + protected function handleStreamRequest() : Iterable { + try { + /** @var StreamRequestHandler $streamHandler */ + $streamHandler = $this->config()->get(CanHandleStreamRequest::class); + $this->startTimer(); + yield from $streamHandler->respondTo($this->getRequest()); + $this->stopTimer(); + } catch (Throwable $error) { + return $this->handleError($error); + } + } + + protected function getRequest() : Request { + if ($this->debug()) { + $this->request->setOption('debug', true); + } + return $this->requestFactory->fromRequest($this->request); + } +} \ No newline at end of file diff --git a/src/Traits/HandlesSchema.php b/src/Traits/HandlesSchema.php index a65293bb..e1db4540 100644 --- a/src/Traits/HandlesSchema.php +++ b/src/Traits/HandlesSchema.php @@ -2,10 +2,13 @@ namespace Cognesy\Instructor\Traits; +use Cognesy\Instructor\Core\Factories\ResponseModelFactory; use Cognesy\Instructor\Data\ResponseModel; trait HandlesSchema { + protected ResponseModelFactory $responseModelFactory; + public function createResponseModel(string|array|object $responseModel) : ResponseModel { return $this->responseModelFactory->fromAny($responseModel); } diff --git a/tests/Feature/Script/ScriptTest.php b/tests/Feature/Script/ScriptTest.php index ce7f590e..a096bd58 100644 --- a/tests/Feature/Script/ScriptTest.php +++ b/tests/Feature/Script/ScriptTest.php @@ -17,11 +17,11 @@ $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); - $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 {key-1}']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); - $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 {key-2}']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); $messages = $script->toArray(); @@ -53,7 +53,7 @@ $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); - $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 {key-1}']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); @@ -61,7 +61,7 @@ $script->section('section-3')->add(['role' => 'user', 'content' => 'content-7']); $script->section('section-3')->add(['role' => 'assistant', 'content' => 'content-8']); - $script->section('section-3')->add(['role' => 'user', 'content' => 'content-9 {key-2}']); + $script->section('section-3')->add(['role' => 'user', 'content' => 'content-9 <|key-2|>']); $messages = $script->select(['section-3', 'section-1'])->toArray(); @@ -92,11 +92,11 @@ ]); $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); - $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 {key-1}']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); - $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 {key-2}']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); $messages = $script->select(['section-2', 'section-1'])->toNativeArray(ClientType::Cohere, ['section-2', 'section-1']); @@ -126,11 +126,11 @@ ]); $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); - $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 {key-1}']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); - $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 {key-2}']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); $messages = $script->select(['section-2', 'section-1'])->toNativeArray(ClientType::Anthropic); @@ -160,11 +160,11 @@ ]); $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); - $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 {key-1}']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); - $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 {key-2}']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); $text = $script->select(['section-2', 'section-1'])->toString(); expect($text)->toBe("content-4\ncontent-5\ncontent-6 value-2\ncontent-1\ncontent-2\ncontent-3 value-1\n"); diff --git a/tests/MockLLM.php b/tests/MockLLM.php index 0cf89418..cdc60be9 100644 --- a/tests/MockLLM.php +++ b/tests/MockLLM.php @@ -20,6 +20,7 @@ static public function get(array $args) : CanCallApi { $mockLLM->shouldReceive('defaultModel')->andReturn('openai:gpt-4o'); $mockLLM->shouldReceive('getModeRequestClass')->andReturn(OpenAIApiRequest::class); $mockLLM->shouldReceive('get')->andReturnUsing(...$list); + $mockLLM->shouldReceive('request')->andReturn($mockLLM); $mockLLM->shouldReceive('toolsCall')->andReturn($mockLLM); $mockLLM->shouldReceive('withApiRequest')->andReturn($mockLLM); $mockLLM->shouldReceive('withApiRequestFactory')->andReturn($mockLLM);