From a1c2d76f48758ed313efc94b6a578aa097c6511a Mon Sep 17 00:00:00 2001 From: ddebowczyk Date: Thu, 14 Nov 2024 13:12:49 +0100 Subject: [PATCH] Tool use addon --- config/llm.php | 282 ++++++------ evals/LLMModes/CompanyEval.php | 13 +- examples/A02_Advanced/CustomPrompts/run.php | 137 +++--- .../A02_Advanced/ProvidingExamples/run.php | 115 ++--- examples/A05_Extras/LLMToolUse/run.php | 209 +++++++++ prompts/system/mode_tools.twig | 44 ++ src/Extras/Prompt/Prompt.php | 36 +- .../Traits/ChatTemplate/HandlesSections.php | 166 +++---- src/Features/Core/PartialsGenerator.php | 410 +++++++++--------- src/Features/Core/RequestHandler.php | 308 ++++++------- src/Features/LLM/Data/LLMResponse.php | 254 ++++++----- src/Features/LLM/Data/PartialLLMResponse.php | 129 +++--- src/Features/LLM/Data/ToolCall.php | 156 +++++-- src/Features/LLM/Data/ToolCalls.php | 194 +++++---- src/Features/LLM/Drivers/AnthropicDriver.php | 141 ++++-- .../LLM/Drivers/AzureOpenAIDriver.php | 68 ++- src/Features/LLM/Drivers/CohereV1Driver.php | 42 +- src/Features/LLM/Drivers/CohereV2Driver.php | 353 +++++++-------- src/Features/LLM/Drivers/GeminiDriver.php | 136 ++++-- src/Features/LLM/Drivers/GrokDriver.php | 14 + src/Features/LLM/Drivers/MistralDriver.php | 130 +++--- .../LLM/Drivers/OpenAICompatibleDriver.php | 138 +++--- src/Features/LLM/Drivers/OpenAIDriver.php | 71 ++- src/Features/LLM/Enums/LLMFinishReason.php | 86 ++-- src/Features/LLM/Enums/LLMProviderType.php | 57 +-- src/Features/LLM/Inference.php | 4 +- src/Features/LLM/InferenceRequest.php | 160 ++++--- src/Utils/Json/PartialJsonParser.php | 407 ++++++++--------- src/Utils/Messages/Message.php | 55 ++- .../Messages/Traits/Message/HandlesAccess.php | 100 +++-- .../Traits/Message/HandlesCreation.php | 10 +- .../Traits/Message/HandlesTransformation.php | 97 ++--- .../Traits/Messages/HandlesConversion.php | 181 ++++---- .../Traits/Messages/HandlesCreation.php | 182 ++++---- src/Utils/Str.php | 26 +- src/Utils/TemplateUtil.php | 278 ++++++------ src/Utils/Web/Webpage.php | 68 +-- 37 files changed, 2978 insertions(+), 2279 deletions(-) create mode 100644 examples/A05_Extras/LLMToolUse/run.php create mode 100644 src/Features/LLM/Drivers/GrokDriver.php diff --git a/config/llm.php b/config/llm.php index 55b77ba2..85a1465c 100644 --- a/config/llm.php +++ b/config/llm.php @@ -1,137 +1,145 @@ - false, - 'defaultConnection' => 'openai', - - 'defaultToolName' => 'extracted_data', - 'defaultToolDescription' => 'Function call based on user instructions.', - 'defaultRetryPrompt' => "JSON generated incorrectly, fix following errors:\n", - 'defaultMdJsonPrompt' => "Response must validate against this JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object within a ```json {} ``` codeblock.\n", - 'defaultJsonPrompt' => "Response must follow JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object.\n", - 'defaultToolsPrompt' => "Extract correct and accurate data from the input using provided tools.\n", - - 'connections' => [ - 'anthropic' => [ - 'providerType' => LLMProviderType::Anthropic->value, - 'apiUrl' => 'https://api.anthropic.com/v1', - 'apiKey' => Env::get('ANTHROPIC_API_KEY', ''), - 'endpoint' => '/messages', - 'metadata' => [ - 'apiVersion' => '2023-06-01', - 'beta' => 'prompt-caching-2024-07-31', - ], - 'defaultModel' => 'claude-3-haiku-20240307', - 'defaultMaxTokens' => 1024, - ], - 'azure' => [ - 'providerType' => LLMProviderType::Azure->value, - 'apiUrl' => 'https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}', - 'apiKey' => Env::get('AZURE_OPENAI_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'metadata' => [ - 'apiVersion' => '2024-08-01-preview', - 'resourceName' => 'instructor-dev', - 'deploymentId' => 'gpt-4o-mini', - ], - 'defaultModel' => 'gpt-4o-mini', - 'defaultMaxTokens' => 1024, - ], - 'cohere1' => [ - 'providerType' => LLMProviderType::CohereV1->value, - 'apiUrl' => 'https://api.cohere.ai/v1', - 'apiKey' => Env::get('COHERE_API_KEY', ''), - 'endpoint' => '/chat', - 'defaultModel' => 'command-r-plus-08-2024', - 'defaultMaxTokens' => 1024, - ], - 'cohere2' => [ - 'providerType' => LLMProviderType::CohereV2->value, - 'apiUrl' => 'https://api.cohere.ai/v2', - 'apiKey' => Env::get('COHERE_API_KEY', ''), - 'endpoint' => '/chat', - 'defaultModel' => 'command-r-plus-08-2024', - 'defaultMaxTokens' => 1024, - ], - 'fireworks' => [ - 'providerType' => LLMProviderType::Fireworks->value, - 'apiUrl' => 'https://api.fireworks.ai/inference/v1', - 'apiKey' => Env::get('FIREWORKS_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'accounts/fireworks/models/mixtral-8x7b-instruct', - 'defaultMaxTokens' => 1024, - ], - 'gemini' => [ - 'providerType' => LLMProviderType::Gemini->value, - 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta', - 'apiKey' => Env::get('GEMINI_API_KEY', ''), - 'endpoint' => '/models/{model}:generateContent', - 'defaultModel' => 'gemini-1.5-flash-latest', - 'defaultMaxTokens' => 1024, - ], - 'grok' => [ - 'providerType' => LLMProviderType::Grok->value, - 'apiUrl' => 'https://api.x.ai/v1', - 'apiKey' => Env::get('GROK_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'grok-beta', - 'defaultMaxTokens' => 1024, - ], - 'groq' => [ - 'providerType' => LLMProviderType::Groq->value, - 'apiUrl' => 'https://api.groq.com/openai/v1', - 'apiKey' => Env::get('GROQ_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'llama3-groq-8b-8192-tool-use-preview', // 'gemma2-9b-it', - 'defaultMaxTokens' => 1024, - ], - 'mistral' => [ - 'providerType' => LLMProviderType::Mistral->value, - 'apiUrl' => 'https://api.mistral.ai/v1', - 'apiKey' => Env::get('MISTRAL_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'mistral-small-latest', - 'defaultMaxTokens' => 1024, - ], - 'ollama' => [ - 'providerType' => LLMProviderType::Ollama->value, - 'apiUrl' => 'http://localhost:11434/v1', - 'apiKey' => Env::get('OLLAMA_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'qwen2.5:0.5b', //'gemma2:2b', - 'defaultMaxTokens' => 1024, - 'httpClient' => 'http-ollama', - ], - 'openai' => [ - 'providerType' => LLMProviderType::OpenAI->value, - 'apiUrl' => 'https://api.openai.com/v1', - 'apiKey' => Env::get('OPENAI_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'metadata' => [ - 'organization' => '', - 'project' => '', - ], - 'defaultModel' => 'gpt-4o-mini', - 'defaultMaxTokens' => 1024, - ], - 'openrouter' => [ - 'providerType' => LLMProviderType::OpenRouter->value, - 'apiUrl' => 'https://openrouter.ai/api/v1', - 'apiKey' => Env::get('OPENROUTER_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'qwen/qwen-2.5-72b-instruct', //'microsoft/phi-3.5-mini-128k-instruct', - 'defaultMaxTokens' => 1024, - ], - 'together' => [ - 'providerType' => LLMProviderType::Together->value, - 'apiUrl' => 'https://api.together.xyz/v1', - 'apiKey' => Env::get('TOGETHER_API_KEY', ''), - 'endpoint' => '/chat/completions', - 'defaultModel' => 'mistralai/Mixtral-8x7B-Instruct-v0.1', - 'defaultMaxTokens' => 1024, - ], - ], -]; + false, + 'defaultConnection' => 'openai', + + 'defaultToolName' => 'extracted_data', + 'defaultToolDescription' => 'Function call based on user instructions.', + 'defaultRetryPrompt' => "JSON generated incorrectly, fix following errors:\n", + 'defaultMdJsonPrompt' => "Response must validate against this JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object within a ```json {} ``` codeblock.\n", + 'defaultJsonPrompt' => "Response must follow JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object.\n", + 'defaultToolsPrompt' => "Extract correct and accurate data from the input using provided tools.\n", + + 'connections' => [ + 'anthropic' => [ + 'providerType' => LLMProviderType::Anthropic->value, + 'apiUrl' => 'https://api.anthropic.com/v1', + 'apiKey' => Env::get('ANTHROPIC_API_KEY', ''), + 'endpoint' => '/messages', + 'metadata' => [ + 'apiVersion' => '2023-06-01', + 'beta' => 'prompt-caching-2024-07-31', + ], + 'defaultModel' => 'claude-3-haiku-20240307', + 'defaultMaxTokens' => 1024, + ], + 'azure' => [ + 'providerType' => LLMProviderType::Azure->value, + 'apiUrl' => 'https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}', + 'apiKey' => Env::get('AZURE_OPENAI_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'metadata' => [ + 'apiVersion' => '2024-08-01-preview', + 'resourceName' => 'instructor-dev', + 'deploymentId' => 'gpt-4o-mini', + ], + 'defaultModel' => 'gpt-4o-mini', + 'defaultMaxTokens' => 1024, + ], + 'cohere1' => [ + 'providerType' => LLMProviderType::CohereV1->value, + 'apiUrl' => 'https://api.cohere.ai/v1', + 'apiKey' => Env::get('COHERE_API_KEY', ''), + 'endpoint' => '/chat', + 'defaultModel' => 'command-r-plus-08-2024', + 'defaultMaxTokens' => 1024, + ], + 'cohere2' => [ + 'providerType' => LLMProviderType::CohereV2->value, + 'apiUrl' => 'https://api.cohere.ai/v2', + 'apiKey' => Env::get('COHERE_API_KEY', ''), + 'endpoint' => '/chat', + 'defaultModel' => 'command-r-plus-08-2024', + 'defaultMaxTokens' => 1024, + ], + 'fireworks' => [ + 'providerType' => LLMProviderType::Fireworks->value, + 'apiUrl' => 'https://api.fireworks.ai/inference/v1', + 'apiKey' => Env::get('FIREWORKS_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'accounts/fireworks/models/mixtral-8x7b-instruct', + 'defaultMaxTokens' => 1024, + ], + 'gemini' => [ + 'providerType' => LLMProviderType::Gemini->value, + 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta', + 'apiKey' => Env::get('GEMINI_API_KEY', ''), + 'endpoint' => '/models/{model}:generateContent', + 'defaultModel' => 'gemini-1.5-flash-latest', + 'defaultMaxTokens' => 1024, + ], + 'gemini-oai' => [ + 'providerType' => LLMProviderType::GeminiOAI->value, + 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta/openai', + 'apiKey' => Env::get('GEMINI_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'gemini-1.5-flash', + 'defaultMaxTokens' => 1024, + ], + 'grok' => [ + 'providerType' => LLMProviderType::Grok->value, + 'apiUrl' => 'https://api.x.ai/v1', + 'apiKey' => Env::get('GROK_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'grok-beta', + 'defaultMaxTokens' => 1024, + ], + 'groq' => [ + 'providerType' => LLMProviderType::Groq->value, + 'apiUrl' => 'https://api.groq.com/openai/v1', + 'apiKey' => Env::get('GROQ_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'llama3-groq-8b-8192-tool-use-preview', // 'gemma2-9b-it', + 'defaultMaxTokens' => 1024, + ], + 'mistral' => [ + 'providerType' => LLMProviderType::Mistral->value, + 'apiUrl' => 'https://api.mistral.ai/v1', + 'apiKey' => Env::get('MISTRAL_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'mistral-small-latest', + 'defaultMaxTokens' => 1024, + ], + 'ollama' => [ + 'providerType' => LLMProviderType::Ollama->value, + 'apiUrl' => 'http://localhost:11434/v1', + 'apiKey' => Env::get('OLLAMA_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'qwen2.5-coder:3b', //'gemma2:2b', + 'defaultMaxTokens' => 1024, + 'httpClient' => 'http-ollama', + ], + 'openai' => [ + 'providerType' => LLMProviderType::OpenAI->value, + 'apiUrl' => 'https://api.openai.com/v1', + 'apiKey' => Env::get('OPENAI_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'metadata' => [ + 'organization' => '', + 'project' => '', + ], + 'defaultModel' => 'gpt-4o-mini', + 'defaultMaxTokens' => 1024, + ], + 'openrouter' => [ + 'providerType' => LLMProviderType::OpenRouter->value, + 'apiUrl' => 'https://openrouter.ai/api/v1', + 'apiKey' => Env::get('OPENROUTER_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'qwen/qwen-2.5-72b-instruct', //'microsoft/phi-3.5-mini-128k-instruct', + 'defaultMaxTokens' => 1024, + ], + 'together' => [ + 'providerType' => LLMProviderType::Together->value, + 'apiUrl' => 'https://api.together.xyz/v1', + 'apiKey' => Env::get('TOGETHER_API_KEY', ''), + 'endpoint' => '/chat/completions', + 'defaultModel' => 'mistralai/Mixtral-8x7B-Instruct-v0.1', + 'defaultMaxTokens' => 1024, + ], + ], +]; diff --git a/evals/LLMModes/CompanyEval.php b/evals/LLMModes/CompanyEval.php index b70ef744..29687592 100644 --- a/evals/LLMModes/CompanyEval.php +++ b/evals/LLMModes/CompanyEval.php @@ -6,6 +6,7 @@ use Cognesy\Instructor\Extras\Evals\Contracts\CanGenerateObservations; use Cognesy\Instructor\Extras\Evals\Execution; use Cognesy\Instructor\Extras\Evals\Observation; +use Cognesy\Instructor\Features\LLM\Data\ToolCall; use Cognesy\Instructor\Utils\Str; class CompanyEval implements CanGenerateObservations @@ -50,10 +51,14 @@ public function correctness(Execution $execution): Observation { } private function validateToolsData(Execution $execution) : bool { - $data = $execution->get('response')->toolsData[0] ?? []; - return 'store_company' === ($data['name'] ?? '') - && 'ACME' === ($data['arguments']['name'] ?? '') - && 2020 === (int) ($data['arguments']['year'] ?? 0); + /** @var ToolCall $toolCall */ + $toolCall = $execution->get('response')->toolCalls?->first(); + if (null === $toolCall) { + return false; + } + return 'store_company' === $toolCall->name() + && 'ACME' === $toolCall->stringValue('name') + && 2020 === $toolCall->intValue('year'); } private function validateDefault(Execution $execution) : bool { diff --git a/examples/A02_Advanced/CustomPrompts/run.php b/examples/A02_Advanced/CustomPrompts/run.php index cc0ffec8..b15f3312 100644 --- a/examples/A02_Advanced/CustomPrompts/run.php +++ b/examples/A02_Advanced/CustomPrompts/run.php @@ -1,68 +1,69 @@ ---- -title: 'Custom prompts' -docname: 'custom_prompts' ---- - -## Overview - -In case you want to take control over the prompts sent by Instructor -to LLM for different modes, you can use the `prompt` parameter in the -`request()` or `respond()` methods. - -It will override the default Instructor prompts, allowing you to fully -customize how LLM is instructed to process the input. - -## Example - -```php -add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); - -use Cognesy\Instructor\Enums\Mode; -use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM; -use Cognesy\Instructor\Instructor; - -class User { - public int $age; - public string $name; -} - -$instructor = (new Instructor) - ->onEvent(RequestSentToLLM::class, fn(RequestSentToLLM $event)=>dump($event->request->toMessages())); - -print("\n# Request for Mode::Tools:\n\n"); -$user = $instructor - ->respond( - messages: "Our user Jason is 25 years old.", - responseModel: User::class, - prompt: "\nYour task is to extract correct and accurate data from the messages using provided tools.\n", - mode: Mode::Tools - ); -echo "\nRESPONSE:\n"; -dump($user); - -print("\n# Request for Mode::Json:\n\n"); -$user = $instructor - ->respond( - messages: "Our user Jason is 25 years old.", - responseModel: User::class, - prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n<|json_schema|>\n", - mode: Mode::Json - ); -echo "\nRESPONSE:\n"; -dump($user); - -print("\n# Request for Mode::MdJson:\n\n"); -$user = $instructor - ->respond( - messages: "Our user Jason is 25 years old.", - responseModel: User::class, - prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n", - mode: Mode::MdJson - ); -echo "\nRESPONSE:\n"; -dump($user); - -?> -``` +--- +title: 'Custom prompts' +docname: 'custom_prompts' +--- + +## Overview + +In case you want to take control over the prompts sent by Instructor +to LLM for different modes, you can use the `prompt` parameter in the +`request()` or `respond()` methods. + +It will override the default Instructor prompts, allowing you to fully +customize how LLM is instructed to process the input. + +## Example + +```php +add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); + +use Cognesy\Instructor\Enums\Mode; +use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM; +use Cognesy\Instructor\Instructor; + +class User { + public int $age; + public string $name; +} + +$instructor = (new Instructor) + // let's dump the request data to see how customized prompts look like in requests + ->onEvent(RequestSentToLLM::class, fn(RequestSentToLLM $event) => dump($event)); + +print("\n# Request for Mode::Tools:\n\n"); +$user = $instructor + ->respond( + messages: "Our user Jason is 25 years old.", + responseModel: User::class, + prompt: "\nYour task is to extract correct and accurate data from the messages using provided tools.\n", + mode: Mode::Tools + ); +echo "\nRESPONSE:\n"; +dump($user); + +print("\n# Request for Mode::Json:\n\n"); +$user = $instructor + ->respond( + messages: "Our user Jason is 25 years old.", + responseModel: User::class, + prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n<|json_schema|>\n", + mode: Mode::Json + ); +echo "\nRESPONSE:\n"; +dump($user); + +print("\n# Request for Mode::MdJson:\n\n"); +$user = $instructor + ->respond( + messages: "Our user Jason is 25 years old.", + responseModel: User::class, + prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n", + mode: Mode::MdJson + ); +echo "\nRESPONSE:\n"; +dump($user); + +?> +``` diff --git a/examples/A02_Advanced/ProvidingExamples/run.php b/examples/A02_Advanced/ProvidingExamples/run.php index 2a9e4128..017be1dc 100644 --- a/examples/A02_Advanced/ProvidingExamples/run.php +++ b/examples/A02_Advanced/ProvidingExamples/run.php @@ -1,57 +1,58 @@ ---- -title: 'Providing example inputs and outputs' -docname: 'demonstrations' ---- - -## Overview - -To improve the results of LLM inference you can provide examples of the expected output. -This will help LLM to understand the context and the expected structure of the output. - -It is typically useful in the `Mode::Json` and `Mode::MdJson` modes, where the output -is expected to be a JSON object. - - -## Example - -```php -add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); - -use Cognesy\Instructor\Enums\Mode; -use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM; -use Cognesy\Instructor\Features\Core\Data\Example; -use Cognesy\Instructor\Instructor; - -class User { - public int $age; - public string $name; -} - -echo "\nREQUEST:\n"; -$user = (new Instructor) - ->onEvent(RequestSentToLLM::class, fn($event)=>dump($event->request->toMessages())) - ->request( - messages: "Our user Jason is 25 years old.", - responseModel: User::class, - examples: [ - new Example( - input: "John is 50 and works as a teacher.", - output: ['name' => 'John', 'age' => 50] - ), - new Example( - input: "We have recently hired Ian, who is 27 years old.", - output: ['name' => 'Ian', 'age' => 27], - template: "example input:\n<|input|>\noutput:\n```json\n<|output|>\n```\n", - ), - ], - mode: Mode::Json) - ->get(); - -echo "\nOUTPUT:\n"; -dump($user); -assert($user->name === 'Jason'); -assert($user->age === 25); -?> -``` +--- +title: 'Providing example inputs and outputs' +docname: 'demonstrations' +--- + +## Overview + +To improve the results of LLM inference you can provide examples of the expected output. +This will help LLM to understand the context and the expected structure of the output. + +It is typically useful in the `Mode::Json` and `Mode::MdJson` modes, where the output +is expected to be a JSON object. + + +## Example + +```php +add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); + +use Cognesy\Instructor\Enums\Mode; +use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM; +use Cognesy\Instructor\Features\Core\Data\Example; +use Cognesy\Instructor\Instructor; + +class User { + public int $age; + public string $name; +} + +echo "\nREQUEST:\n"; +$user = (new Instructor) + // let's dump the request data to see how examples are used in requests + ->onEvent(RequestSentToLLM::class, fn($event) => dump($event)) + ->request( + messages: "Our user Jason is 25 years old.", + responseModel: User::class, + examples: [ + new Example( + input: "John is 50 and works as a teacher.", + output: ['name' => 'John', 'age' => 50] + ), + new Example( + input: "We have recently hired Ian, who is 27 years old.", + output: ['name' => 'Ian', 'age' => 27], + template: "example input:\n<|input|>\noutput:\n```json\n<|output|>\n```\n", + ), + ], + mode: Mode::Json) + ->get(); + +echo "\nOUTPUT:\n"; +dump($user); +assert($user->name === 'Jason'); +assert($user->age === 25); +?> +``` diff --git a/examples/A05_Extras/LLMToolUse/run.php b/examples/A05_Extras/LLMToolUse/run.php new file mode 100644 index 00000000..17e8e213 --- /dev/null +++ b/examples/A05_Extras/LLMToolUse/run.php @@ -0,0 +1,209 @@ +--- +title: 'Inference and tool use' +docname: 'llm_tool_use' +--- + +## Overview + + +## Example + + +```php +add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); + +function add_numbers($a, $b) : int { + return $a + $b; +} + +function subtract_numbers($a, $b) : int { + return $a - $b; +} + +$tools = [ + [ + 'type' => 'function', + 'function' => [ + 'name' => 'add_numbers', + 'description' => 'Add two numbers', + 'parameters' => [ + 'type' => 'object', + 'description' => 'Numbers to add', + 'properties' => [ + 'a' => [ + 'type' => 'integer', + 'description' => 'First number', + ], + 'b' => [ + 'type' => 'integer', + 'description' => 'Second number', + ], + ], + 'required' => ['a', 'b'], + 'additionalProperties' => false, + ], + ], + ], + [ + 'type' => 'function', + 'function' => [ + 'name' => 'subtract_numbers', + 'description' => 'Subtract two numbers', + 'parameters' => [ + 'type' => 'object', + 'description' => 'Numbers to subtract', + 'properties' => [ + 'a' => [ + 'type' => 'integer', + 'description' => 'First number', + ], + 'b' => [ + 'type' => 'integer', + 'description' => 'Second number', + ], + ], + 'required' => ['a', 'b'], + 'additionalProperties' => false, + ], + ], + ] +]; + +$chat = (new Inference) + ->withConnection('openrouter'); + +$prompt = 'Add 2455 and 3558 then subtract 4344 from the result.'; + +$messages = [ + ['role' => 'user', 'content' => $prompt] +]; + +Debug::enable(); + +$toolUse = new ToolUse($chat, $tools); +$response = $toolUse->response($messages); + +class ToolUse { + private Inference $inference; + private array $tools; + private int $maxDepth; + private array $options; + private bool $parallelToolCalls; + private string|array $toolChoice; + /** @var LLMResponse[] */ + private array $responses = []; + + public function __construct( + Inference $chat, + array $tools, + int $maxDepth = 3, + array $options = [], + string|array $toolChoice = 'auto', + bool $parallelToolCalls = false + ) { + $this->inference = $chat; + $this->tools = $tools; + $this->maxDepth = $maxDepth; + $this->parallelToolCalls = $parallelToolCalls; + $this->toolChoice = $toolChoice; + $this->options = $options; + } + + public function response(string|array $messages) : LLMResponse { + $messages = match(true) { + is_string($messages) => [['role' => 'user', 'content' => $messages]], + is_array($messages) => $messages, + default => [] + }; + $chat = (new Messages)->appendMessages($messages); + return $this->tryGetResponse($chat); + } + + /** + * @return LLMResponse[] + */ + public function responses() : array { + return $this->responses; + } + + // INTERNAL ////////////////////////////////////////////// + + private function tryGetResponse(Messages $chat) : LLMResponse { + $response = $this->getResponse($chat->toArray()); + $this->responses[] = $response; + $depth = 0; + while ($this->tryContinue($response)) { + if ($depth++ >= $this->maxDepth) { + break; + } + $responseMessages = $this->makeToolsResponseMessages($response); + $chat->appendMessages($responseMessages); + $response = $this->getResponse($chat->toArray()); + $this->responses[] = $response; + } + return $response; + } + + private function getResponse(string|array $messages) : LLMResponse { + return $this->inference + ->create( + messages: $messages, + tools: $this->tools, + toolChoice: $this->toolChoice, + options: array_merge($this->options, ['parallel_tool_calls' => $this->parallelToolCalls]), + mode: Mode::Tools, + )->response(); + } + + private function makeToolsResponseMessages(LLMResponse $response) : array { + $messages = []; + $toolCalls = $response->toolCalls(); + $count = 0; + foreach ($toolCalls->all() as $toolCall) { + $function = $toolCall->name(); + $args = $toolCall->args(); + $result = $function(...$args); + $resultString = match(true) { + is_string($result) => $result, + is_array($result) => Json::encode($result), + is_object($result) => Json::encode($result), + default => (string) $result, + }; + $messages[] = [ + 'role' => 'assistant', + '_metadata' => [ + 'tool_calls' => [$toolCall->toToolCallArray()] + ] + ]; + $messages[] = [ + 'role' => 'tool', + 'content' => $resultString, + '_metadata' => [ + 'tool_call_id' => $toolCall->id(), + 'tool_name' => $toolCall->name(), + 'result' => $result + ] + ]; + $count++; + if ($this->parallelToolCalls && $count >= 0) { + break; + } + } + return $messages; + } + + private function tryContinue(LLMResponse $response) : bool { + return $response->hasToolCalls() || LLMFinishReason::ToolCalls->equals($response->finishReason()); + } +} \ No newline at end of file diff --git a/prompts/system/mode_tools.twig b/prompts/system/mode_tools.twig index 8b137891..3936d7e1 100644 --- a/prompts/system/mode_tools.twig +++ b/prompts/system/mode_tools.twig @@ -1 +1,45 @@ + + + {{ request.system }} + + {% if request.context_cache.prompt is not empty %} + + TASK: + {{ request.context_cache.prompt }} + + {% endif %} + + {% if request.context_cache.examples is not empty %} + + EXAMPLES: + {{ request.context_cache.examples }} + + {% endif %} + + {% if request.context_cache.input is not empty %} + + INPUT: + {{ request.context_cache.input }} + + {% endif %} + + {% if request.context_cache is not empty %} + + + INSTRUCTIONS: + + + {% endif %} + + + {% for attempt in request.attempts %} + + {{ attempt.llmResponse.content }} + + + Settings::get('llm', 'defaultRetryPrompt') + {{ attempt.errors }} + + {% endfor %} + \ No newline at end of file diff --git a/src/Extras/Prompt/Prompt.php b/src/Extras/Prompt/Prompt.php index 7e904578..6579c3e4 100644 --- a/src/Extras/Prompt/Prompt.php +++ b/src/Extras/Prompt/Prompt.php @@ -244,18 +244,44 @@ private function getMessageContent(XmlElement $element) : array { if ($child->tag() !== 'content') { continue; } - // check if content type is text, image or audio $type = $child->attribute('type', 'text'); $content[] = match($type) { - 'image' => ['type' => 'image_url', 'image_url' => ['url' => $child->content()]], - 'audio' => ['type' => 'input_audio', 'input_audio' => ['data' => $child->content(), 'format' => $child->attribute('format', 'mp3')]], - 'text' => ['type' => 'text', 'text' => $child->content()], - default => throw new InvalidArgumentException("Invalid content type: $type"), + 'image' => $this->makeImageContent($child), + 'audio' => $this->makeAudioContent($child), + default => $this->makeTextContent($child), }; } return $content; } + private function makeTextContent(XmlElement $child) : array { + $hasCacheControl = $child->attribute('cache', false); + return array_filter([ + 'type' => 'text', + 'text' => $child->content(), + 'cache_control' => $hasCacheControl ? ['type' => 'ephemeral'] : [] + ]); + } + + private function makeImageContent(XmlElement $child) : array { + return [ + 'type' => 'image_url', + 'image_url' => [ + 'url' => $child->content() + ] + ]; + } + + private function makeAudioContent(XmlElement $child) : array { + return [ + 'type' => 'input_audio', + 'input_audio' => [ + 'data' => $child->content(), + 'format' => $child->attribute('format', 'mp3') + ] + ]; + } + private function validateVariables(array $infoVars, array $templateVars, array $valueKeys) : array { $messages = []; foreach($infoVars as $var) { diff --git a/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php b/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php index afdb5b6d..25464fa5 100644 --- a/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php +++ b/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php @@ -1,84 +1,84 @@ -appendMessage(['role' => 'system', 'content' => $system]); - } - - // EXTRACT SYSTEM ROLE FROM MESSAGES - until first non-system message - foreach ($messages as $message) { - if ($message['role'] !== 'system') { - break; - } - $output->appendMessage(['role' => 'system', 'content' => $message['content']]); - } - - return $output; - } - - protected function makeMessages(string|array $messages) : Messages { - $output = new Messages(); - if (empty($messages)) { - return $output; - } - - // skip system messages - $index = 0; - foreach ($messages as $message) { - if ($message['role'] !== 'system') { - break; - } - $index++; - } - $output->appendMessages(array_slice($messages, $index)); - return $output; - } - - protected function makeExamples(array $examples) : Messages { - $messages = new Messages(); - if (empty($examples)) { - return $messages; - } - foreach ($examples as $item) { - $example = match(true) { - is_array($item) => Example::fromArray($item), - is_string($item) => Example::fromJson($item), - $item instanceof Example => $item, - default => throw new Exception('Invalid example type'), - }; - $messages->appendMessages($example->toMessages()); - } - return $messages; - } - - protected function makePrompt(string $prompt) : Message { - return new Message( - role: 'user', - content: $prompt - ?: $this->defaultPrompts[$this->request->mode()->value] - ?? '' - ); - } - - protected function makeInput(array|object|string $input) : Messages { - if (empty($input)) { - return new Messages(); - } - return Messages::fromInput($input); - } - - protected function makeJsonSchema() : array { - return $this->request->responseModel()?->toJsonSchema(); - } +appendMessage(['role' => 'system', 'content' => $system]); + } + + // EXTRACT SYSTEM ROLE FROM MESSAGES - until first non-system message + foreach ($messages as $message) { + if ($message['role'] !== 'system') { + break; + } + $output->appendMessage($message); + } + + return $output; + } + + protected function makeMessages(string|array $messages) : Messages { + $output = new Messages(); + if (empty($messages)) { + return $output; + } + + // skip system messages + $index = 0; + foreach ($messages as $message) { + if ($message['role'] !== 'system') { + break; + } + $index++; + } + $output->appendMessages(array_slice($messages, $index)); + return $output; + } + + protected function makeExamples(array $examples) : Messages { + $messages = new Messages(); + if (empty($examples)) { + return $messages; + } + foreach ($examples as $item) { + $example = match(true) { + is_array($item) => Example::fromArray($item), + is_string($item) => Example::fromJson($item), + $item instanceof Example => $item, + default => throw new Exception('Invalid example type'), + }; + $messages->appendMessages($example->toMessages()); + } + return $messages; + } + + protected function makePrompt(string $prompt) : Message { + return new Message( + role: 'user', + content: $prompt + ?: $this->defaultPrompts[$this->request->mode()->value] + ?? '' + ); + } + + protected function makeInput(array|object|string $input) : Messages { + if (empty($input)) { + return new Messages(); + } + return Messages::fromInput($input); + } + + protected function makeJsonSchema() : array { + return $this->request->responseModel()?->toJsonSchema(); + } } \ No newline at end of file diff --git a/src/Features/Core/PartialsGenerator.php b/src/Features/Core/PartialsGenerator.php index 5ecc4c31..de179afa 100644 --- a/src/Features/Core/PartialsGenerator.php +++ b/src/Features/Core/PartialsGenerator.php @@ -1,206 +1,206 @@ -toolCalls = new ToolCalls(); - $this->sequenceableHandler = new SequenceableHandler($events); - } - - public function resetPartialResponse() : void { - $this->previousHash = ''; - $this->responseText = ''; - $this->responseJson = ''; - $this->sequenceableHandler->reset(); - $this->toolCalls->reset(); - } - - /** - * @param Generator $stream - * @param ResponseModel $responseModel - * @return Generator - */ - public function getPartialResponses(Generator $stream, ResponseModel $responseModel) : Generator { - // reset state - $this->resetPartialResponse(); - - // receive data - /** @var PartialLLMResponse $partialResponse */ - foreach($stream as $partialResponse) { - $this->events->dispatch(new StreamedResponseReceived($partialResponse)); - // store partial response - $this->partialResponses[] = $partialResponse; - - // situation 1: new function call - $maybeToolName = $partialResponse->toolName; - // create next FC only if JSON buffer is not empty (which is the case for 1st iteration) - if ($maybeToolName) { - if (empty($this->responseJson)) { - $this->newToolCall($response->toolName ?? $responseModel->toolName()); - } else { - $this->finalizeToolCall($this->responseJson, $responseModel->toolName()); - $this->responseJson = ''; // reset json buffer - } - } - - // situation 2: new delta - $maybeArgumentChunk = $partialResponse->contentDelta; - if (empty($maybeArgumentChunk)) { - continue; - } - $this->events->dispatch(new ChunkReceived($maybeArgumentChunk)); - $this->responseText .= $maybeArgumentChunk; - $this->responseJson = Json::fromPartial($this->responseText)->toString(); - if (empty($this->responseJson)) { - continue; - } - if ($this->toolCalls->empty()) { - $this->newToolCall($responseModel->toolName()); - } - $result = $this->handleDelta($this->responseJson, $responseModel); - if ($result->isFailure()) { - continue; - } - $this->events->dispatch(new PartialJsonReceived($this->responseJson)); - - yield $partialResponse - ->withValue($result->unwrap()) - ->withContent($this->responseText); - } - $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse())); - - // finalize last function call - // check if there are any toolCalls - if ($this->toolCalls->count() === 0) { - throw new Exception('No tool calls found in the response'); - } - // finalize last function call - if ($this->toolCalls->count() > 0) { - $this->finalizeToolCall(Json::from($this->responseText)->toString(), $responseModel->toolName()); - } - // finalize sequenceable - $this->sequenceableHandler->finalize(); - } - - // INTERNAL //////////////////////////////////////////////////////// - - protected function handleDelta( - string $partialJson, - ResponseModel $responseModel - ) : Result { - return Chain::make() - ->through(fn() => $this->validatePartialResponse($partialJson, $responseModel, $this->preventJsonSchema, $this->matchToExpectedFields)) - ->tap(fn() => $this->events->dispatch(new PartialJsonReceived($partialJson))) - ->tap(fn() => $this->updateToolCall($partialJson, $responseModel->toolName())) - ->through(fn() => $this->tryGetPartialObject($partialJson, $responseModel)) - ->onFailure(fn($result) => $this->events->dispatch( - new PartialResponseGenerationFailed(Arrays::asArray($result->error())) - )) - ->then(fn($result) => $this->getChangedOnly($result)) - ->result(); - } - - protected function tryGetPartialObject( - string $partialJsonData, - ResponseModel $responseModel, - ) : Result { - return Chain::from(fn() => Json::fromPartial($partialJsonData)->toString()) - ->through(fn($json) => $this->responseDeserializer->deserialize($json, $responseModel, $this?->toolCalls->last()->name)) - ->through(fn($object) => $this->responseTransformer->transform($object)) - ->result(); - } - - protected function getChangedOnly(Result $result) : ?Result { - if ($result->isFailure()) { - return $result; - } - $partialObject = $result->unwrap(); - // we only want to send partial response if it's different from the previous one - $currentHash = hash('xxh3', Json::encode($partialObject)); - if ($this->previousHash == $currentHash) { - return Result::failure('No changes detected'); - } - $this->events->dispatch(new PartialResponseGenerated($partialObject)); - if (($partialObject instanceof Sequenceable)) { - $this->sequenceableHandler->update($partialObject); - } - $this->previousHash = $currentHash; - return $result; - } - - public function getCompleteResponse() : LLMResponse { - return LLMResponse::fromPartialResponses($this->partialResponses); - } - - public function lastPartialResponse() : PartialLLMResponse { - $index = count($this->partialResponses) - 1; - return $this->partialResponses[$index]; - } - - public function partialResponses() : array { - return $this->partialResponses; - } - - protected function newToolCall(string $name) : ToolCall { - $newToolCall = $this->toolCalls->create($name); - $this->events->dispatch(new StreamedToolCallStarted($newToolCall)); - return $newToolCall; - } - - protected function updateToolCall(string $responseJson, string $defaultName) : ToolCall { - $updatedToolCall = $this->toolCalls->updateLast($responseJson, $defaultName); - $this->events->dispatch(new StreamedToolCallUpdated($updatedToolCall)); - return $updatedToolCall; - } - - protected function finalizeToolCall(string $responseJson, string $defaultName) : ToolCall { - $finalizedToolCall = $this->toolCalls->finalizeLast($responseJson, $defaultName); - $this->events->dispatch(new StreamedToolCallCompleted($finalizedToolCall)); - return $finalizedToolCall; - } +toolCalls = new ToolCalls(); + $this->sequenceableHandler = new SequenceableHandler($events); + } + + public function resetPartialResponse() : void { + $this->previousHash = ''; + $this->responseText = ''; + $this->responseJson = ''; + $this->sequenceableHandler->reset(); + $this->toolCalls->reset(); + } + + /** + * @param Generator $stream + * @param ResponseModel $responseModel + * @return Generator + */ + public function getPartialResponses(Generator $stream, ResponseModel $responseModel) : Generator { + // reset state + $this->resetPartialResponse(); + + // receive data + /** @var PartialLLMResponse $partialResponse */ + foreach($stream as $partialResponse) { + $this->events->dispatch(new StreamedResponseReceived($partialResponse)); + // store partial response + $this->partialResponses[] = $partialResponse; + + // situation 1: new function call + $maybeToolName = $partialResponse->toolName; + // create next FC only if JSON buffer is not empty (which is the case for 1st iteration) + if ($maybeToolName) { + if (empty($this->responseJson)) { + $this->newToolCall($response->toolName ?? $responseModel->toolName()); + } else { + $this->finalizeToolCall($this->responseJson, $responseModel->toolName()); + $this->responseJson = ''; // reset json buffer + } + } + + // situation 2: new delta + $maybeArgumentChunk = $partialResponse->contentDelta; + if (empty($maybeArgumentChunk)) { + continue; + } + $this->events->dispatch(new ChunkReceived($maybeArgumentChunk)); + $this->responseText .= $maybeArgumentChunk; + $this->responseJson = Json::fromPartial($this->responseText)->toString(); + if (empty($this->responseJson)) { + continue; + } + if ($this->toolCalls->empty()) { + $this->newToolCall($responseModel->toolName()); + } + $result = $this->handleDelta($this->responseJson, $responseModel); + if ($result->isFailure()) { + continue; + } + $this->events->dispatch(new PartialJsonReceived($this->responseJson)); + + yield $partialResponse + ->withValue($result->unwrap()) + ->withContent($this->responseText); + } + $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse())); + + // finalize last function call + // check if there are any toolCalls + if ($this->toolCalls->count() === 0) { + throw new Exception('No tool calls found in the response'); + } + // finalize last function call + if ($this->toolCalls->count() > 0) { + $this->finalizeToolCall(Json::from($this->responseText)->toString(), $responseModel->toolName()); + } + // finalize sequenceable + $this->sequenceableHandler->finalize(); + } + + // INTERNAL //////////////////////////////////////////////////////// + + protected function handleDelta( + string $partialJson, + ResponseModel $responseModel + ) : Result { + return Chain::make() + ->through(fn() => $this->validatePartialResponse($partialJson, $responseModel, $this->preventJsonSchema, $this->matchToExpectedFields)) + ->tap(fn() => $this->events->dispatch(new PartialJsonReceived($partialJson))) + ->tap(fn() => $this->updateToolCall($partialJson, $responseModel->toolName())) + ->through(fn() => $this->tryGetPartialObject($partialJson, $responseModel)) + ->onFailure(fn($result) => $this->events->dispatch( + new PartialResponseGenerationFailed(Arrays::asArray($result->error())) + )) + ->then(fn($result) => $this->getChangedOnly($result)) + ->result(); + } + + protected function tryGetPartialObject( + string $partialJsonData, + ResponseModel $responseModel, + ) : Result { + return Chain::from(fn() => Json::fromPartial($partialJsonData)->toString()) + ->through(fn($json) => $this->responseDeserializer->deserialize($json, $responseModel, $this->toolCalls->last()?->name())) + ->through(fn($object) => $this->responseTransformer->transform($object)) + ->result(); + } + + protected function getChangedOnly(Result $result) : ?Result { + if ($result->isFailure()) { + return $result; + } + $partialObject = $result->unwrap(); + // we only want to send partial response if it's different from the previous one + $currentHash = hash('xxh3', Json::encode($partialObject)); + if ($this->previousHash == $currentHash) { + return Result::failure('No changes detected'); + } + $this->events->dispatch(new PartialResponseGenerated($partialObject)); + if (($partialObject instanceof Sequenceable)) { + $this->sequenceableHandler->update($partialObject); + } + $this->previousHash = $currentHash; + return $result; + } + + public function getCompleteResponse() : LLMResponse { + return LLMResponse::fromPartialResponses($this->partialResponses); + } + + public function lastPartialResponse() : PartialLLMResponse { + $index = count($this->partialResponses) - 1; + return $this->partialResponses[$index]; + } + + public function partialResponses() : array { + return $this->partialResponses; + } + + protected function newToolCall(string $name) : ToolCall { + $newToolCall = $this->toolCalls->add($name); + $this->events->dispatch(new StreamedToolCallStarted($newToolCall)); + return $newToolCall; + } + + protected function updateToolCall(string $responseJson, string $defaultName) : ToolCall { + $updatedToolCall = $this->toolCalls->updateLast($responseJson, $defaultName); + $this->events->dispatch(new StreamedToolCallUpdated($updatedToolCall)); + return $updatedToolCall; + } + + protected function finalizeToolCall(string $responseJson, string $defaultName) : ToolCall { + $finalizedToolCall = $this->toolCalls->finalizeLast($responseJson, $defaultName); + $this->events->dispatch(new StreamedToolCallCompleted($finalizedToolCall)); + return $finalizedToolCall; + } } \ No newline at end of file diff --git a/src/Features/Core/RequestHandler.php b/src/Features/Core/RequestHandler.php index ba91ea79..5d4356a2 100644 --- a/src/Features/Core/RequestHandler.php +++ b/src/Features/Core/RequestHandler.php @@ -1,154 +1,154 @@ -events = $events; - $this->retries = 0; - $this->errors = []; - } - - // PUBLIC ////////////////////////////////////////////////////////////// - - /** - * Generates response value - */ - public function responseFor(Request $request) : LLMResponse { - $processingResult = Result::failure("No response generated"); - while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { - $llmResponse = $this->getInference($request)->response(); - $llmResponse->content = match($request->mode()) { - Mode::Text => $llmResponse->content, - Mode::Tools => $llmResponse->toolCalls->first()?->args - ?? $llmResponse->content // fallback if no tool calls - some LLMs return just a string - ?? '', - default => Json::from($llmResponse->content)->toString(), - }; - $partialResponses = []; - $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); - } - - $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); - - return $llmResponse->withValue($value); - } - - /** - * Yields response value versions based on streamed responses - * @param Request $request - * @return Generator - */ - public function streamResponseFor(Request $request) : Generator { - $processingResult = Result::failure("No response generated"); - while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { - $stream = $this->getInference($request)->stream()->responses(); - yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel()); - - $llmResponse = $this->partialsGenerator->getCompleteResponse(); - $partialResponses = $this->partialsGenerator->partialResponses(); - $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); - } - - $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); - - yield $llmResponse->withValue($value); - } - - // INTERNAL /////////////////////////////////////////////////////////// - - protected function getInference(Request $request) : InferenceResponse { - $inference = new Inference( - connection: $this->connection, - httpClient: $this->httpClient, - driver: $this->driver, - events: $this->events, - ); - return $inference - ->create( - $request->toMessages(), - $request->model(), - $request->toolCallSchema(), - $request->toolChoice(), - $request->responseFormat(), - $request->options(), - $request->mode() - ); - } - - protected function processResponse(Request $request, LLMResponse $llmResponse, array $partialResponses) : Result { - // we have LLMResponse here - let's process it: deserialize, validate, transform - $processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel()); - - if ($processingResult->isFailure()) { - // retry - we have not managed to deserialize, validate or transform the response - $this->handleError($processingResult, $request, $llmResponse, $partialResponses); - } - - return $processingResult; - } - - protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed { - if ($processingResult->isFailure()) { - $this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors)); - throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors)); - } - - // get final value - $value = $processingResult->unwrap(); - // store response - $request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts - // notify on response generation - $this->events->dispatch(new ResponseGenerated($value)); - - return $value; - } - - protected function handleError(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : void { - $error = $processingResult->error(); - $this->errors = is_array($error) ? $error : [$error]; - - // store failed response - $request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts - $this->retries++; - if (!$this->maxRetriesReached($request)) { - $this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors)); - } - } - - protected function maxRetriesReached(Request $request) : bool { - return $this->retries > $request->maxRetries(); - } -} +events = $events; + $this->retries = 0; + $this->errors = []; + } + + // PUBLIC ////////////////////////////////////////////////////////////// + + /** + * Generates response value + */ + public function responseFor(Request $request) : LLMResponse { + $processingResult = Result::failure("No response generated"); + while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { + $llmResponse = $this->getInference($request)->response(); + $llmResponse->content = match($request->mode()) { + Mode::Text => $llmResponse->content, + Mode::Tools => $llmResponse->toolCalls()->first()?->argsAsJson() + ?? $llmResponse->content // fallback if no tool calls - some LLMs return just a string + ?? '', + default => Json::from($llmResponse->content)->toString(), + }; + $partialResponses = []; + $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); + } + + $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); + + return $llmResponse->withValue($value); + } + + /** + * Yields response value versions based on streamed responses + * @param Request $request + * @return Generator + */ + public function streamResponseFor(Request $request) : Generator { + $processingResult = Result::failure("No response generated"); + while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { + $stream = $this->getInference($request)->stream()->responses(); + yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel()); + + $llmResponse = $this->partialsGenerator->getCompleteResponse(); + $partialResponses = $this->partialsGenerator->partialResponses(); + $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); + } + + $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); + + yield $llmResponse->withValue($value); + } + + // INTERNAL /////////////////////////////////////////////////////////// + + protected function getInference(Request $request) : InferenceResponse { + $inference = new Inference( + connection: $this->connection, + httpClient: $this->httpClient, + driver: $this->driver, + events: $this->events, + ); + return $inference + ->create( + $request->toMessages(), + $request->model(), + $request->toolCallSchema(), + $request->toolChoice(), + $request->responseFormat(), + $request->options(), + $request->mode() + ); + } + + protected function processResponse(Request $request, LLMResponse $llmResponse, array $partialResponses) : Result { + // we have LLMResponse here - let's process it: deserialize, validate, transform + $processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel()); + + if ($processingResult->isFailure()) { + // retry - we have not managed to deserialize, validate or transform the response + $this->handleError($processingResult, $request, $llmResponse, $partialResponses); + } + + return $processingResult; + } + + protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed { + if ($processingResult->isFailure()) { + $this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors)); + throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors)); + } + + // get final value + $value = $processingResult->unwrap(); + // store response + $request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts + // notify on response generation + $this->events->dispatch(new ResponseGenerated($value)); + + return $value; + } + + protected function handleError(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : void { + $error = $processingResult->error(); + $this->errors = is_array($error) ? $error : [$error]; + + // store failed response + $request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts + $this->retries++; + if (!$this->maxRetriesReached($request)) { + $this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors)); + } + } + + protected function maxRetriesReached(Request $request) : bool { + return $this->retries > $request->maxRetries(); + } +} diff --git a/src/Features/LLM/Data/LLMResponse.php b/src/Features/LLM/Data/LLMResponse.php index e60ed294..00aeb91f 100644 --- a/src/Features/LLM/Data/LLMResponse.php +++ b/src/Features/LLM/Data/LLMResponse.php @@ -1,129 +1,125 @@ -usage = $usage ?? new Usage(); - } - - // STATIC //////////////////////////////////////////////// - - public static function fromPartialResponses(array $partialResponses) : LLMResponse { - return (new self)->makeFromPartialResponses($partialResponses); - } - - // PUBLIC //////////////////////////////////////////////// - - public function hasValue() : bool { - return $this->value !== null; - } - - public function withValue(mixed $value) : self { - $this->value = $value; - return $this; - } - - public function value() : mixed { - return $this->value; - } - - public function hasContent() : bool { - return $this->content !== ''; - } - - public function content() : string { - return $this->content; - } - - public function json(): Json { - return match(true) { - // TODO: what about tool calls? - $this->hasContent() => Json::from($this->content), - default => Json::none(), - }; - } - - public function hasToolCalls() : bool { - return !empty($this->toolCalls); - } - - public function usage() : Usage { - return $this->usage ?? new Usage(); - } - - // INTERNAL ////////////////////////////////////////////// - - /** - * @param PartialLLMResponse[] $partialResponses - * @return LLMResponse - */ - private function makeFromPartialResponses(array $partialResponses = []) : self { - if (empty($partialResponses)) { - return $this; - } - - $content = ''; - foreach($partialResponses as $partialResponse) { - if ($partialResponse === null) { - continue; - } - $content .= $partialResponse->contentDelta; - $this->responseData[] = $partialResponse->responseData; - $this->usage()->accumulate($partialResponse->usage); - $this->finishReason = $partialResponse->finishReason; - } - $this->content = $content; - - $tools = $this->makeTools($partialResponses); - if (!empty($tools)) { - $this->toolsData = $this->makeToolsData($tools); - $this->toolCalls = ToolCalls::fromArray($this->toolsData); - } - return $this; - } - - private function makeTools(array $partialResponses): array { - $tools = []; - $currentTool = ''; - foreach ($partialResponses as $partialResponse) { - if ($partialResponse === null) { - continue; - } - if (('' !== ($partialResponse->toolName ?? '')) - && ($currentTool !== ($partialResponse->toolName ?? ''))) { - $currentTool = $partialResponse->toolName ?? ''; - $tools[$currentTool] = ''; - } - if ('' !== $currentTool) { - if (('' !== ($partialResponse->toolArgs ?? ''))) { - $tools[$currentTool] .= $partialResponse->toolArgs ?? ''; - } - } - } - return $tools; - } - - private function makeToolsData(array $tools) : array { - $data = []; - foreach($tools as $tool => $args) { - $data[] = [ - 'name' => $tool, - 'arguments' => '' !== $args ? Json::decode($args) : [], - ]; - } - return $data; - } -} +usage = $usage ?? new Usage(); + } + + // STATIC //////////////////////////////////////////////// + + public static function fromPartialResponses(array $partialResponses) : LLMResponse { + return (new self)->makeFromPartialResponses($partialResponses); + } + + // PUBLIC //////////////////////////////////////////////// + + public function hasValue() : bool { + return $this->value !== null; + } + + public function withValue(mixed $value) : self { + $this->value = $value; + return $this; + } + + public function value() : mixed { + return $this->value; + } + + public function hasContent() : bool { + return $this->content !== ''; + } + + public function content() : string { + return $this->content; + } + + public function json(): Json { + return match(true) { + // TODO: what about tool calls? + $this->hasContent() => Json::from($this->content), + default => Json::none(), + }; + } + + public function hasToolCalls() : bool { + return !$this->toolCalls?->empty(); + } + + public function usage() : Usage { + return $this->usage ?? new Usage(); + } + + public function toolCalls() : ToolCalls { + return $this->toolCalls ?? new ToolCalls(); + } + + public function finishReason() : LLMFinishReason { + return LLMFinishReason::fromText($this->finishReason); + } + + // INTERNAL ////////////////////////////////////////////// + + /** + * @param PartialLLMResponse[] $partialResponses + * @return LLMResponse + */ + private function makeFromPartialResponses(array $partialResponses = []) : self { + if (empty($partialResponses)) { + return $this; + } + + $content = ''; + foreach($partialResponses as $partialResponse) { + if ($partialResponse === null) { + continue; + } + $content .= $partialResponse->contentDelta; + $this->responseData[] = $partialResponse->responseData; + $this->usage()->accumulate($partialResponse->usage); + $this->finishReason = $partialResponse->finishReason; + } + $this->content = $content; + + $tools = $this->makeTools($partialResponses); + if (!empty($tools)) { + $this->toolCalls = ToolCalls::fromArray($tools); + } + return $this; + } + + private function makeTools(array $partialResponses): array { + $tools = []; + $currentTool = ''; + foreach ($partialResponses as $partialResponse) { + if ($partialResponse === null) { + continue; + } + if (('' !== ($partialResponse->toolName ?? '')) + && ($currentTool !== ($partialResponse->toolName ?? ''))) { + $currentTool = $partialResponse->toolName ?? ''; + $tools[$currentTool] = ''; + } + if ('' !== $currentTool) { + if (('' !== ($partialResponse->toolArgs ?? ''))) { + $tools[$currentTool] .= $partialResponse->toolArgs ?? ''; + } + } + } + return $tools; + } +} diff --git a/src/Features/LLM/Data/PartialLLMResponse.php b/src/Features/LLM/Data/PartialLLMResponse.php index 9904b357..99683f58 100644 --- a/src/Features/LLM/Data/PartialLLMResponse.php +++ b/src/Features/LLM/Data/PartialLLMResponse.php @@ -1,64 +1,65 @@ -value !== null; - } - - public function withValue(mixed $value) : self { - $this->value = $value; - return $this; - } - - public function value() : mixed { - return $this->value; - } - - public function hasContent() : bool { - return $this->content !== ''; - } - - public function withContent(string $content) : self { - $this->content = $content; - return $this; - } - - public function content() : string { - return $this->content; - } - - public function json(): string { - if (!$this->hasContent()) { - return ''; - } - return Json::fromPartial($this->content)->toString(); - } - - public function withFinishReason(string $finishReason) : self { - $this->finishReason = $finishReason; - return $this; - } - - public function usage() : Usage { - return $this->usage ?? new Usage(); - } -} +value !== null; + } + + public function withValue(mixed $value) : self { + $this->value = $value; + return $this; + } + + public function value() : mixed { + return $this->value; + } + + public function hasContent() : bool { + return $this->content !== ''; + } + + public function withContent(string $content) : self { + $this->content = $content; + return $this; + } + + public function content() : string { + return $this->content; + } + + public function json(): string { + if (!$this->hasContent()) { + return ''; + } + return Json::fromPartial($this->content)->toString(); + } + + public function withFinishReason(string $finishReason) : self { + $this->finishReason = $finishReason; + return $this; + } + + public function usage() : Usage { + return $this->usage ?? new Usage(); + } +} diff --git a/src/Features/LLM/Data/ToolCall.php b/src/Features/LLM/Data/ToolCall.php index 10d1523a..69358d6e 100644 --- a/src/Features/LLM/Data/ToolCall.php +++ b/src/Features/LLM/Data/ToolCall.php @@ -1,25 +1,131 @@ - empty($toolCall['arguments']) ? '' : Json::encode($toolCall['arguments']), - is_string($toolCall['arguments'] ?? false) => $toolCall['arguments'] ?? '', - is_null($toolCall['arguments']) => '', - default => throw new \InvalidArgumentException('ToolCall args must be a string or an array') - } - ); - } -} +id = $id; + $this->name = $name; + $this->arguments = match(true) { + is_array($args) => $args, + is_string($args) => Json::from($args)->toArray(), + default => [] + }; + } + + public static function fromArray(array $toolCall) : ?ToolCall { + if (empty($toolCall['name'])) { + return null; + } + return new ToolCall( + name: $toolCall['name'] ?? '', + args: match(true) { + is_array($toolCall['arguments'] ?? false) => $toolCall['arguments'], + is_string($toolCall['arguments'] ?? false) => $toolCall['arguments'], + is_null($toolCall['arguments'] ?? null) => [], + default => throw new \InvalidArgumentException('ToolCall args must be a string or an array') + }, + id: $toolCall['id'] ?? '' + ); + } + + public function withId(string $id) : self { + $this->id = $id; + return $this; + } + + public function withName(string $name) : self { + $this->name = $name; + return $this; + } + + public function withArgs(string|array $args) : self { + $this->arguments = match(true) { + is_array($args) => $args, + is_string($args) => Json::from($args)->toArray(), + default => [] + }; + return $this; + } + + public function hasArgs() : bool { + return !empty($this->arguments); + } + + public function id() : string { + return $this->id; + } + + public function name() : string { + return $this->name; + } + + public function args() : array { + return $this->arguments; + } + + public function argsAsJson() : string { + return Json::encode($this->arguments); + } + + public function hasValue(string $key) : bool { + return isset($this->arguments[$key]); + } + + public function value(string $key, mixed $default = null) : mixed { + return $this->arguments[$key] ?? $default; + } + + public function intValue(string $key, int $default = 0) : int { + return (int) ($this->arguments[$key] ?? $default); + } + + public function boolValue(string $key, bool $default = false) : bool { + return (bool) ($this->arguments[$key] ?? $default); + } + + public function stringValue(string $key, string $default = '') : string { + return (string) ($this->arguments[$key] ?? $default); + } + + public function arrayValue(string $key, array $default = []) : array { + return (array) ($this->arguments[$key] ?? $default); + } + + public function objectValue(string $key, object $default = null) : object { + return (object) ($this->arguments[$key] ?? $default); + } + + public function floatValue(string $key, float $default = 0.0) : float { + return (float) ($this->arguments[$key] ?? $default); + } + + public function toArray() : array { + return [ + 'name' => $this->name, + 'arguments' => $this->arguments, + ]; + } + + public function toToolCallArray() : array { + return [ + 'id' => $this->id, + 'type' => 'function', + 'function' => [ + 'name' => $this->name, + 'arguments' => Json::encode($this->arguments), + ] + ]; + } +} diff --git a/src/Features/LLM/Data/ToolCalls.php b/src/Features/LLM/Data/ToolCalls.php index b658f2d1..bae8c556 100644 --- a/src/Features/LLM/Data/ToolCalls.php +++ b/src/Features/LLM/Data/ToolCalls.php @@ -1,86 +1,110 @@ -toolCalls = $toolCalls; - } - - public static function fromArray(array $toolCalls) : ToolCalls { - $list = []; - foreach ($toolCalls as $toolCall) { - $list[] = ToolCall::fromArray($toolCall); - } - return new ToolCalls($list); - } - - public static function fromMapper(array $toolCalls, callable $mapper) : ToolCalls { - $list = []; - foreach ($toolCalls as $toolCall) { - $list[] = $mapper($toolCall); - } - return new ToolCalls($list); - } - - public function count() : int { - return count($this->toolCalls); - } - - public function first() : ?ToolCall { - return $this->toolCalls[0] ?? null; - } - - public function last() : ?ToolCall { - if (empty($this->toolCalls)) { - return null; - } - return $this->toolCalls[count($this->toolCalls) - 1]; - } - - public function empty() : bool { - return empty($this->toolCalls); - } - - public function all() : array { - return $this->toolCalls; - } - - public function reset() : void { - $this->toolCalls = []; - } - - public function create(string $toolName, string $args = '') : ToolCall { - $newToolCall = new ToolCall( - name: $toolName, - args: $args - ); - $this->toolCalls[] = $newToolCall; - return $newToolCall; - } - - public function updateLast(string $responseJson, string $defaultName) : ToolCall { - $last = $this->last(); - if (empty($last)) { - return $this->create($defaultName, $responseJson); - } - $last->name = $last->name ?? $defaultName; - $last->args = $responseJson; - return $this->last(); - } - - public function finalizeLast(string $responseJson, string $defaultName) : ToolCall { - $last = $this->last(); - if (empty($last)) { - return $this->create($defaultName, $responseJson); - } - $last->name = $last->name ?? $defaultName; - $last->args = $responseJson; - return $this->last(); - } +toolCalls = $toolCalls; + } + + public static function fromArray(array $toolCalls) : ToolCalls { + $list = []; + foreach ($toolCalls as $key => $toolCall) { + $list[] = match(true) { + is_array($toolCall) => ToolCall::fromArray($toolCall), + is_object($toolCall) && $toolCall instanceof ToolCall => $toolCall, + is_string($toolCall) => new ToolCall($key, $toolCall), + default => throw new InvalidArgumentException('Cannot create ToolCall from provided data: ' . print_r($toolCall, true)) + }; + } + return new ToolCalls($list); + } + + public static function fromMapper(array $toolCalls, callable $mapper) : ToolCalls { + $list = []; + foreach ($toolCalls as $item) { + $toolCall = $mapper($item); + if ($toolCall instanceof ToolCall) { + $list[] = $toolCall; + } + } + return new ToolCalls($list); + } + + public function count() : int { + return count($this->toolCalls); + } + + public function first() : ?ToolCall { + return $this->toolCalls[0] ?? null; + } + + public function last() : ?ToolCall { + if (empty($this->toolCalls)) { + return null; + } + return $this->toolCalls[count($this->toolCalls) - 1]; + } + + public function empty() : bool { + return empty($this->toolCalls); + } + + /** + * @return ToolCall[] + */ + public function all() : array { + return $this->toolCalls; + } + + public function reset() : void { + $this->toolCalls = []; + } + + public function add(string $toolName, string $args = '') : ToolCall { + $newToolCall = new ToolCall( + name: $toolName, + args: $args + ); + $this->toolCalls[] = $newToolCall; + return $newToolCall; + } + + public function updateLast(string $responseJson, string $defaultName) : ToolCall { + $last = $this->last(); + if (empty($last)) { + return $this->add($defaultName, $responseJson); + } + $last->withName($last->name() ?? $defaultName); + $last->withArgs($responseJson); + return $this->last(); + } + + public function finalizeLast(string $responseJson, string $defaultName) : ToolCall { + $last = $this->last(); + if (empty($last)) { + return $this->add($defaultName, $responseJson); + } + $last->withName($last->name() ?? $defaultName); + $last->withArgs($responseJson); + return $this->last(); + } + + public function toArray() : array { + $list = []; + foreach ($this->toolCalls as $toolCall) { + $list[] = $toolCall->toArray(); + } + return $list; + } } \ No newline at end of file diff --git a/src/Features/LLM/Drivers/AnthropicDriver.php b/src/Features/LLM/Drivers/AnthropicDriver.php index b7b728de..38ab25f8 100644 --- a/src/Features/LLM/Drivers/AnthropicDriver.php +++ b/src/Features/LLM/Drivers/AnthropicDriver.php @@ -20,6 +20,8 @@ class AnthropicDriver implements CanHandleInference { + private bool $parallelToolCalls = false; + public function __construct( protected LLMConfig $config, protected ?CanHandleHttp $httpClient = null, @@ -72,18 +74,22 @@ public function getRequestBody( array $options = [], Mode $mode = Mode::Text, ) : array { - $request = array_filter(array_merge([ + $this->parallelToolCalls = $options['parallel_tool_calls'] ?? false; + unset($options['parallel_tool_calls']); + + $request = array_merge(array_filter([ 'model' => $model ?: $this->config->model, 'max_tokens' => $options['max_tokens'] ?? $this->config->maxTokens, 'system' => Messages::fromArray($messages) ->forRoles(['system']) ->toString(), - 'messages' => $this->toNativeMessages(Messages::fromArray($messages) - ->exceptRoles(['system']) - ->toMergedPerRole() - ->toArray() + 'messages' => $this->toNativeMessages( + Messages::fromArray($messages) + ->exceptRoles(['system']) + //->toMergedPerRole() + ->toArray() ), - ], $options)); + ]), $options); return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); } @@ -94,9 +100,6 @@ public function toLLMResponse(array $data): ?LLMResponse { return new LLMResponse( content: $this->makeContent($data), responseData: $data, -// toolName: $data['content'][0]['name'] ?? '', -// toolArgs: Json::encode($data['content'][0]['input'] ?? ''), - toolsData: $this->mapToolsData($data), finishReason: $data['stop_reason'] ?? '', toolCalls: $this->makeToolCalls($data), usage: $this->makeUsage($data), @@ -110,6 +113,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse { return new PartialLLMResponse( contentDelta: $this->makeContentDelta($data), responseData: $data, + toolId: $data['content_block']['id'] ?? '', toolName: $data['content_block']['name'] ?? '', toolArgs: $data['delta']['partial_json'] ?? '', finishReason: $data['delta']['stop_reason'] ?? $data['message']['stop_reason'] ?? '', @@ -159,27 +163,55 @@ private function toTools(array $tools) : array { private function toToolChoice(string|array $toolChoice, array $tools) : array|string { return match(true) { empty($tools) => '', + empty($toolChoice) => [ + 'type' => 'auto', + 'disable_parallel_tool_use' => !$this->parallelToolCalls, + ], is_array($toolChoice) => [ 'type' => 'tool', 'name' => $toolChoice['function']['name'], - ], - empty($toolChoice) => [ - 'type' => 'auto', + 'disable_parallel_tool_use' => !$this->parallelToolCalls, ], default => [ - 'type' => $toolChoice, + 'type' => $this->mapToolChoice($toolChoice), + 'disable_parallel_tool_use' => !$this->parallelToolCalls, ], }; } + protected function mapToolChoice(string $choice) : string { + return match($choice) { + 'auto' => 'auto', + 'required' => 'any', + default => 'auto', + }; + } + private function toNativeMessages(array $messages) : array { - return array_map( - fn($message) => [ - 'role' => $this->mapRole($message['role'] ?? 'user'), - 'content' => $this->toNativeContent($message['content']), - ], - $messages - ); + $list = []; + foreach ($messages as $message) { + $nativeMessage = $this->mapMessage($message); + if (empty($nativeMessage)) { + continue; + } + $list[] = $nativeMessage; + } + return $list; + } + + private function mapMessage(array $message) : array { + return match(true) { + ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message), + ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message), + default => $this->toNativeTextMessage($message), + }; + } + + private function toNativeTextMessage(array $message) : array { + return [ + 'role' => $this->mapRole($message['role'] ?? 'user'), + 'content' => $this->toNativeContent($message['content']), + ]; } private function mapRole(string $role) : string { @@ -201,10 +233,18 @@ private function toNativeContent(string|array $content) : string|array { private function contentPartToNative(array $contentPart) : array { $type = $contentPart['type'] ?? 'text'; - if ($type === 'image_url') { - $contentPart = $this->toNativeImage($contentPart); - } - return $contentPart; + return match($type) { + 'text' => $this->toNativeTextContent($contentPart), + 'image_url' => $this->toNativeImage($contentPart), + default => $contentPart, + }; + } + + private function toNativeTextContent(array $contentPart) : array { + return [ + 'type' => 'text', + 'text' => $contentPart['text'] ?? '', + ]; } private function toNativeImage(array $contentPart) : array { @@ -220,21 +260,42 @@ private function toNativeImage(array $contentPart) : array { ]; return $contentPart; } + + private function toNativeToolCall(array $message) : array { + return [ + 'role' => 'assistant', + 'content' => [[ + 'type' => 'tool_use', + 'id' => $message['_metadata']['tool_calls'][0]['id'] ?? '', + 'name' => $message['_metadata']['tool_calls'][0]['function']['name'] ?? '', + 'input' => Json::from($message['_metadata']['tool_calls'][0]['function']['arguments'] ?? '')->toArray(), + ]] + ]; + } + + private function toNativeToolResult(array $message) : array { + return [ + 'role' => 'user', + 'content' => [[ + 'type' => 'tool_result', + 'tool_use_id' => $message['_metadata']['tool_call_id'] ?? '', + 'content' => $message['content'] ?? '', + //'is_error' => false, + ]] + ]; + } + private function makeToolCalls(array $data) : ToolCalls { return ToolCalls::fromMapper(array_map( callback: fn(array $call) => $call, - array: $data['content'] ?? [] - ), fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['input'] ?? ''])); - } - - private function mapToolsData(array $data) : array { - return array_map( - fn($tool) => [ - 'name' => $tool['name'] ?? '', - 'arguments' => $tool['input'] ?? '', - ], - array_filter($data['content'] ?? [], fn($part) => 'tool_use' === ($part['type'] ?? '')) - ); + array: array_filter( + array: $data['content'] ?? [], + callback: fn($part) => 'tool_use' === ($part['type'] ?? '')) + ), fn($call) => ToolCall::fromArray([ + 'id' => $call['id'] ?? '', + 'name' => $call['name'] ?? '', + 'arguments' => $call['input'] ?? '' + ])); } private function makeContent(array $data) : string { @@ -297,4 +358,12 @@ private function makeUsage(array $data) : Usage { reasoningTokens: 0, ); } + + protected function excludeUnderscoredKeys(array $messages) : array { + $list = []; + foreach ($messages as $message) { + $list[] = array_filter($message, fn($value, $key) => !Str::startsWith($key, '_'), ARRAY_FILTER_USE_BOTH); + } + return $list; + } } diff --git a/src/Features/LLM/Drivers/AzureOpenAIDriver.php b/src/Features/LLM/Drivers/AzureOpenAIDriver.php index e56172f4..50064a90 100644 --- a/src/Features/LLM/Drivers/AzureOpenAIDriver.php +++ b/src/Features/LLM/Drivers/AzureOpenAIDriver.php @@ -1,36 +1,34 @@ - "{".$key."}", array_keys($this->config->metadata)), - replace: array_values($this->config->metadata), - subject: "{$this->config->apiUrl}{$this->config->endpoint}" - ) . $this->getUrlParams(); - } - - protected function getUrlParams(): string { - $params = array_filter([ - 'api-version' => $this->config->metadata['apiVersion'] ?? '', - ]); - if (!empty($params)) { - return '?' . http_build_query($params); - } - return ''; - } - - public function getRequestHeaders(): array { - return [ - 'Api-Key' => $this->config->apiKey, - 'Content-Type' => 'application/json', - ]; - } + "{".$key."}", array_keys($this->config->metadata)), + replace: array_values($this->config->metadata), + subject: "{$this->config->apiUrl}{$this->config->endpoint}" + ) . $this->getUrlParams(); + } + + protected function getUrlParams(): string { + $params = array_filter([ + 'api-version' => $this->config->metadata['apiVersion'] ?? '', + ]); + if (!empty($params)) { + return '?' . http_build_query($params); + } + return ''; + } + + public function getRequestHeaders(): array { + return [ + 'Api-Key' => $this->config->apiKey, + 'Content-Type' => 'application/json', + ]; + } } \ No newline at end of file diff --git a/src/Features/LLM/Drivers/CohereV1Driver.php b/src/Features/LLM/Drivers/CohereV1Driver.php index bf783d77..13693fb8 100644 --- a/src/Features/LLM/Drivers/CohereV1Driver.php +++ b/src/Features/LLM/Drivers/CohereV1Driver.php @@ -16,6 +16,7 @@ use Cognesy\Instructor\Features\LLM\InferenceRequest; use Cognesy\Instructor\Utils\Json\Json; use Cognesy\Instructor\Utils\Messages\Messages; +use Cognesy\Instructor\Utils\Str; class CohereV1Driver implements CanHandleInference { @@ -68,15 +69,18 @@ public function getRequestBody( array $options = [], Mode $mode = Mode::Text, ) : array { + unset($options['parallel_tool_calls']); + $system = ''; $chatHistory = []; + $messages = $this->excludeUnderscoredKeys($messages); - $request = array_filter(array_merge([ + $request = array_merge(array_filter([ 'model' => $model ?: $this->config->model, 'preamble' => $system, 'chat_history' => $chatHistory, 'message' => Messages::asString($messages), - ], $options)); + ]), $options); return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); } @@ -87,7 +91,7 @@ public function toLLMResponse(array $data): LLMResponse { return new LLMResponse( content: $this->makeContent($data), responseData: $data, - toolsData: $this->mapToolsData($data), + //: $this->map($data), finishReason: $data['finish_reason'] ?? '', toolCalls: $this->makeToolCalls($data), usage: $this->makeUsage($data), @@ -98,6 +102,7 @@ public function toPartialLLMResponse(array $data) : PartialLLMResponse { return new PartialLLMResponse( contentDelta: $this->makeContentDelta($data), responseData: $data, + toolId: $this->makeToolId($data), toolName: $this->makeToolNameDelta($data), toolArgs: $this->makeToolArgsDelta($data), finishReason: $data['response']['finish_reason'] ?? $data['delta']['finish_reason'] ?? '', @@ -184,16 +189,6 @@ private function makeToolCalls(array $data) : ToolCalls { ); } - private function mapToolsData(array $data) : array { - return array_map( - fn($tool) => [ - 'name' => $tool['name'] ?? '', - 'arguments' => $tool['parameters'] ?? '', - ], - $data['tool_calls'] ?? [] - ); - } - private function makeContent(array $data) : string { return ($data['text'] ?? '') . (!empty($data['tool_calls']) ? ("\n" . Json::encode($data['tool_calls'])) @@ -208,12 +203,11 @@ private function makeContentDelta(array $data) : string { return $data['tool_call_delta']['parameters'] ?? $data['text'] ?? ''; } - private function makeToolArgsDelta(array $data) : string { + private function makeToolId(array $data) { if (!$this->isStreamChunk($data)) { return ''; } - $toolArgs = $data['tool_calls'][0]['parameters'] ?? ''; - return ('' === $toolArgs) ? '' : Json::encode($toolArgs); + return $data['tool_calls'][0]['id'] ?? ''; } private function makeToolNameDelta(array $data) : string { @@ -223,6 +217,14 @@ private function makeToolNameDelta(array $data) : string { return $data['tool_calls'][0]['name'] ?? ''; } + private function makeToolArgsDelta(array $data) : string { + if (!$this->isStreamChunk($data)) { + return ''; + } + $toolArgs = $data['tool_calls'][0]['parameters'] ?? ''; + return ('' === $toolArgs) ? '' : Json::encode($toolArgs); + } + private function isStreamChunk(array $data) : bool { return in_array(($data['event_type'] ?? ''), ['text-generation', 'tool-calls-chunk']); } @@ -254,4 +256,12 @@ private function makeUsage(array $data) : Usage { reasoningTokens: 0, ); } + + protected function excludeUnderscoredKeys(array $messages) : array { + $list = []; + foreach ($messages as $message) { + $list[] = array_filter($message, fn($value, $key) => !Str::startsWith($key, '_'), ARRAY_FILTER_USE_BOTH); + } + return $list; + } } diff --git a/src/Features/LLM/Drivers/CohereV2Driver.php b/src/Features/LLM/Drivers/CohereV2Driver.php index 5ebad31e..95b08598 100644 --- a/src/Features/LLM/Drivers/CohereV2Driver.php +++ b/src/Features/LLM/Drivers/CohereV2Driver.php @@ -1,174 +1,179 @@ - $model ?: $this->config->model, - 'max_tokens' => $this->config->maxTokens, - 'messages' => $messages, - ], $options)); - - return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); - } - - public function getRequestHeaders(): array { - $optional = [ - 'X-Client-Name' => $this->config->metadata['client_name'] ?? '', - ]; - return array_merge([ - 'Authorization' => "Bearer {$this->config->apiKey}", - 'Content-Type' => 'application/json', - ], $optional); - } - - // RESPONSE ////////////////////////////////////////////// - - public function toLLMResponse(array $data): LLMResponse { - return new LLMResponse( - content: $this->makeContent($data), - responseData: $data, - toolsData: $this->makeToolsData($data), - finishReason: $data['finish_reason'] ?? '', - toolCalls: $this->makeToolCalls($data), - usage: $this->makeUsage($data), - ); - } - - public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse { - if (empty($data)) { - return null; - } - return new PartialLLMResponse( - contentDelta: $this->makeContentDelta($data), - responseData: $data, - toolName: $data['delta']['message']['tool_calls']['function']['name'] ?? '', - toolArgs: $data['delta']['message']['tool_calls']['function']['arguments'] ?? '', - finishReason: $data['delta']['finish_reason'] ?? '', - usage: $this->makeUsage($data), - ); - } - - public function getData(string $data): string|bool { - if (!str_starts_with($data, 'data:')) { - return ''; - } - $data = trim(substr($data, 5)); - return match(true) { - $data === '[DONE]' => false, - default => $data, - }; - } - - // OVERRIDES - HELPERS /////////////////////////////////// - - protected function applyMode( - array $request, - Mode $mode, - array $tools, - string|array $toolChoice, - array $responseFormat - ) : array { - switch($mode) { - case Mode::Tools: - $request['tools'] = $this->removeDisallowedEntries($tools); - break; - case Mode::Json: - $request['response_format'] = $responseFormat; - break; - case Mode::JsonSchema: - $request['response_format'] = [ - 'type' => 'json_object', - 'schema' => $responseFormat['json_schema']['schema'], - ]; - break; - } - return $request; - } - - protected function removeDisallowedEntries(array $jsonSchema) : array { - return Arrays::removeRecursively($jsonSchema, [ - 'x-title', - 'x-php-class', - 'additionalProperties', - ]); - } - - private function makeContent(array $data): string { - $contentMsg = $data['message']['content'][0]['text'] ?? ''; - $contentFnArgs = $data['message']['tool_calls'][0]['function']['arguments'] ?? ''; - return match(true) { - !empty($contentMsg) => $contentMsg, - !empty($contentFnArgs) => $contentFnArgs, - default => '' - }; - } - - private function makeToolCalls(array $data) : ToolCalls { - return ToolCalls::fromArray(array_map( - callback: fn(array $call) => $call['function'] ?? [], - array: $data['message']['tool_calls'] ?? [] - )); - } - - private function makeToolsData(array $data) : array { - return array_map( - fn($tool) => [ - 'name' => $tool['function']['name'] ?? '', - 'arguments' => Json::decode($tool['function']['arguments']) ?? '', - ], - $data['message']['tool_calls'] ?? [] - ); - } - - private function makeContentDelta(array $data): string { - $deltaContent = match(true) { - ([] !== ($data['delta']['message']['content'] ?? [])) => $this->normalizeContent($data['delta']['message']['content']), - default => '', - }; - $deltaFnArgs = $data['delta']['message']['tool_calls']['function']['arguments'] ?? ''; - return match(true) { - '' !== $deltaContent => $deltaContent, - '' !== $deltaFnArgs => $deltaFnArgs, - default => '' - }; - } - - private function normalizeContent(array|string $content) : string { - return is_array($content) ? $content['text'] : $content; - } - - private function makeUsage(array $data) : Usage { - return new Usage( - inputTokens: $data['usage']['billed_units']['input_tokens'] - ?? $data['delta']['usage']['billed_units']['input_tokens'] - ?? 0, - outputTokens: $data['usage']['billed_units']['output_tokens'] - ?? $data['delta']['usage']['billed_units']['output_tokens'] - ?? 0, - cacheWriteTokens: 0, - cacheReadTokens: 0, - reasoningTokens: 0, - ); - } -} + $model ?: $this->config->model, + 'max_tokens' => $this->config->maxTokens, + 'messages' => $this->toNativeMessages($messages), + ]), $options); + + return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); + } + + public function getRequestHeaders(): array { + $optional = [ + 'X-Client-Name' => $this->config->metadata['client_name'] ?? '', + ]; + return array_merge([ + 'Authorization' => "Bearer {$this->config->apiKey}", + 'Content-Type' => 'application/json', + ], $optional); + } + + // RESPONSE ////////////////////////////////////////////// + + public function toLLMResponse(array $data): LLMResponse { + return new LLMResponse( + content: $this->makeContent($data), + responseData: $data, + finishReason: $data['finish_reason'] ?? '', + toolCalls: $this->makeToolCalls($data), + usage: $this->makeUsage($data), + ); + } + + public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse { + if (empty($data)) { + return null; + } + return new PartialLLMResponse( + contentDelta: $this->makeContentDelta($data), + responseData: $data, + toolId: $data['delta']['message']['tool_calls']['function']['id'] ?? '', + toolName: $data['delta']['message']['tool_calls']['function']['name'] ?? '', + toolArgs: $data['delta']['message']['tool_calls']['function']['arguments'] ?? '', + finishReason: $data['delta']['finish_reason'] ?? '', + usage: $this->makeUsage($data), + ); + } + + public function getData(string $data): string|bool { + if (!str_starts_with($data, 'data:')) { + return ''; + } + $data = trim(substr($data, 5)); + return match(true) { + $data === '[DONE]' => false, + default => $data, + }; + } + + // OVERRIDES - HELPERS /////////////////////////////////// + + protected function applyMode( + array $request, + Mode $mode, + array $tools, + string|array $toolChoice, + array $responseFormat + ) : array { + switch($mode) { + case Mode::Tools: + $request['tools'] = $this->removeDisallowedEntries($tools); + break; + case Mode::Json: + $request['response_format'] = $responseFormat; + break; + case Mode::JsonSchema: + $request['response_format'] = [ + 'type' => 'json_object', + 'schema' => $responseFormat['json_schema']['schema'], + ]; + break; + } + return $request; + } + + protected function removeDisallowedEntries(array $jsonSchema) : array { + return Arrays::removeRecursively($jsonSchema, [ + 'x-title', + 'x-php-class', + 'additionalProperties', + ]); + } + + private function makeContent(array $data): string { + $contentMsg = $data['message']['content'][0]['text'] ?? ''; + $contentFnArgs = $data['message']['tool_calls'][0]['function']['arguments'] ?? ''; + return match(true) { + !empty($contentMsg) => $contentMsg, + !empty($contentFnArgs) => $contentFnArgs, + default => '' + }; + } + + private function makeToolCalls(array $data) : ToolCalls { + return ToolCalls::fromArray(array_map( + callback: fn(array $call) => $this->makeToolCall($call), + array: $data['message']['tool_calls'] ?? [], + )); + } + + private function makeToolCall(array $data) : ?ToolCall { + if (empty($data)) { + return null; + } + if (!isset($data['function'])) { + return null; + } + if (!isset($data['id'])) { + return null; + } + return ToolCall::fromArray($data['function'] ?? [])->withId($data['id'] ?? ''); + } + + private function makeContentDelta(array $data): string { + $deltaContent = match(true) { + ([] !== ($data['delta']['message']['content'] ?? [])) => $this->normalizeContent($data['delta']['message']['content']), + default => '', + }; + $deltaFnArgs = $data['delta']['message']['tool_calls']['function']['arguments'] ?? ''; + return match(true) { + '' !== $deltaContent => $deltaContent, + '' !== $deltaFnArgs => $deltaFnArgs, + default => '' + }; + } + + private function normalizeContent(array|string $content) : string { + return is_array($content) ? $content['text'] : $content; + } + + private function makeUsage(array $data) : Usage { + return new Usage( + inputTokens: $data['usage']['billed_units']['input_tokens'] + ?? $data['delta']['usage']['billed_units']['input_tokens'] + ?? 0, + outputTokens: $data['usage']['billed_units']['output_tokens'] + ?? $data['delta']['usage']['billed_units']['output_tokens'] + ?? 0, + cacheWriteTokens: 0, + cacheReadTokens: 0, + reasoningTokens: 0, + ); + } +} diff --git a/src/Features/LLM/Drivers/GeminiDriver.php b/src/Features/LLM/Drivers/GeminiDriver.php index d218bb56..0b4fb0ec 100644 --- a/src/Features/LLM/Drivers/GeminiDriver.php +++ b/src/Features/LLM/Drivers/GeminiDriver.php @@ -100,9 +100,6 @@ public function toLLMResponse(array $data): ?LLMResponse { return new LLMResponse( content: $this->makeContent($data), responseData: $data, -// toolName: $data['candidates'][0]['content']['parts'][0]['functionCall']['name'] ?? '', -// toolArgs: Json::encode($data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? []), - toolsData: $this->mapToolsData($data), finishReason: $data['candidates'][0]['finishReason'] ?? '', toolCalls: $this->makeToolCalls($data), usage: $this->makeUsage($data), @@ -116,6 +113,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse { return new PartialLLMResponse( contentDelta: $this->makeContentDelta($data), responseData: $data, + toolId: $data['candidates'][0]['id'] ?? '', toolName: $this->makeToolName($data), toolArgs: $this->makeToolArgs($data), finishReason: $data['candidates'][0]['finishReason'] ?? '', @@ -145,10 +143,12 @@ private function toSystem(array $messages) : array { } private function toMessages(array $messages) : array { - return $this->toNativeMessages(Messages::fromArray($messages) - ->exceptRoles(['system']) - //->toMergedPerRole() - ->toArray()); + return $this->toNativeMessages( + Messages::fromArray($messages) + ->exceptRoles(['system']) + //->toMergedPerRole() + ->toArray() + ); } protected function toOptions( @@ -172,12 +172,13 @@ protected function toTools(array $tools) : array { )]; } - protected function toToolChoice(array $toolChoice): string|array { + protected function toToolChoice(string|array $toolChoice): string|array { return match(true) { empty($toolChoice) => ["function_calling_config" => ["mode" => "ANY"]], + is_string($toolChoice) => ["function_calling_config" => ["mode" => $this->mapToolChoice($toolChoice)]], is_array($toolChoice) => [ "function_calling_config" => array_filter([ - "mode" => "ANY", + "mode" => $this->mapToolChoice($toolChoice['mode'] ?? "ANY"), "allowed_function_names" => $toolChoice['function']['name'] ?? [], ]), ], @@ -185,6 +186,15 @@ protected function toToolChoice(array $toolChoice): string|array { }; } + protected function mapToolChoice(string $choice) : string { + return match($choice) { + 'auto' => 'AUTO', + 'required' => 'ANY', + 'none' => 'NONE', + default => 'ANY', + }; + } + protected function toResponseMimeType(Mode $mode): string { return match($mode) { Mode::Text => "text/plain", @@ -211,18 +221,69 @@ protected function removeDisallowedEntries(array $jsonSchema) : array { ]); } - protected function toNativeMessages(string|array $messages) : array { - if (is_string($messages)) { - return [["text" => $messages]]; - } - $transformed = []; + private function toNativeMessages(array $messages) : array { + $list = []; foreach ($messages as $message) { - $transformed[] = [ - 'role' => $this->mapRole($message['role']), - 'parts' => $this->contentPartsToNative($message['content']), - ]; + $nativeMessage = $this->mapMessage($message); + if (empty($nativeMessage)) { + continue; + } + $list[] = $nativeMessage; } - return $transformed; + return $list; + } + + private function mapMessage(array $message) : array { + return match(true) { + ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message), + ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message), + default => $this->toNativeTextMessage($message), + }; + } + + private function toNativeTextMessage(array $message) : array { + return [ + 'role' => $this->mapRole($message['role'] ?? 'user'), + 'parts' => $this->toNativeContentParts($message['content']), + ]; + } + + private function toNativeToolCall(array $message) : array { + return [ + 'role' => 'model', + 'parts' => array_map( + callback: fn($call) => $this->toNativeToolCallPart($call), + array: $message['_metadata']['tool_calls'] ?? [] + ), + ]; + } + + private function toNativeToolCallPart(array $call) : array { + return [ + 'functionCall' => [ + 'name' => $call['function']['name'] ?? '', + 'args' => Json::from($call['function']['arguments'])->toArray() ?? [], + ] + ]; + } + + private function toNativeToolResult(array $message) : array { + $content = match(true) { + is_array($message['_metadata']['result'] ?? '') => Json::from($message['_metadata']['result'] ?? '')->toArray(), + default => $message['content'], + }; + return [ + 'role' => 'user', + 'parts' => [[ + 'functionResponse' => [ + 'name' => $message['_metadata']['tool_name'] ?? '', + 'response' => [ + 'name' => $message['_metadata']['tool_name'] ?? '', + 'content' => $content, + ], + ], + ]], + ]; } protected function mapRole(string $role) : string { @@ -230,7 +291,7 @@ protected function mapRole(string $role) : string { return $roles[$role] ?? $role; } - protected function contentPartsToNative(string|array $contentParts) : array { + protected function toNativeContentParts(string|array $contentParts) : array { if (is_string($contentParts)) { return [["text" => $contentParts]]; } @@ -243,18 +304,27 @@ protected function contentPartsToNative(string|array $contentParts) : array { protected function contentPartToNative(array $contentPart) : array { $type = $contentPart['type'] ?? 'text'; - return match($type) { - 'text' => ['text' => $contentPart['text'] ?? ''], - 'image_url' => [ - 'inlineData' => [ - 'mimeType' => Str::between($contentPart['image_url']['url'], 'data:', ';base64,'), - 'data' => Str::after($contentPart['image_url']['url'], ';base64,'), - ], - ], + return match(true) { + ($type === 'text') => $this->makeTextContentPart($contentPart), + ($type === 'image_url') => $this->makeImageUrlContentPart($contentPart), + default => $contentPart, }; } + private function makeTextContentPart(array $contentPart) : array { + return ['text' => $contentPart['text'] ?? '']; + } + + private function makeImageUrlContentPart(array $contentPart) : array { + return [ + 'inlineData' => [ + 'mimeType' => Str::between($contentPart['image_url']['url'], 'data:', ';base64,'), + 'data' => Str::after($contentPart['image_url']['url'], ';base64,'), + ], + ]; + } + private function makeToolCalls(array $data) : ToolCalls { return ToolCalls::fromMapper(array_map( callback: fn(array $call) => $call['functionCall'] ?? [], @@ -262,16 +332,6 @@ private function makeToolCalls(array $data) : ToolCalls { ), fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['args'] ?? ''])); } - private function mapToolsData(array $data) : array { - return array_map( - fn($tool) => [ - 'name' => $tool['functionCall']['name'] ?? '', - 'arguments' => $tool['functionCall']['args'] ?? '', - ], - $data['candidates'][0]['content']['parts'] ?? [] - ); - } - private function makeContent(array $data) : string { return $data['candidates'][0]['content']['parts'][0]['text'] ?? Json::encode($data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? '') diff --git a/src/Features/LLM/Drivers/GrokDriver.php b/src/Features/LLM/Drivers/GrokDriver.php new file mode 100644 index 00000000..21d2755c --- /dev/null +++ b/src/Features/LLM/Drivers/GrokDriver.php @@ -0,0 +1,14 @@ + 'assistant', + 'content' => $message['content'] ?? 'I\'m calling tool: ' . $message['_metadata']['tool_calls'][0]['function']['name'], + 'tool_calls' => $message['_metadata']['tool_calls'] ?? [], + ]; + } +} \ No newline at end of file diff --git a/src/Features/LLM/Drivers/MistralDriver.php b/src/Features/LLM/Drivers/MistralDriver.php index f8099ccc..3a51af92 100644 --- a/src/Features/LLM/Drivers/MistralDriver.php +++ b/src/Features/LLM/Drivers/MistralDriver.php @@ -1,59 +1,73 @@ - $model ?: $this->config->model, - 'max_tokens' => $this->config->maxTokens, - 'messages' => $messages, - ], $options)); - - return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); - } - - // PRIVATE ////////////////////////////////////////////// - - private function applyMode( - array $request, - Mode $mode, - array $tools, - string|array $toolChoice, - array $responseFormat - ) : array { - switch($mode) { - case Mode::Tools: - $request['tools'] = $this->removeDisallowedEntries($tools); - $request['tool_choice'] = 'any'; - break; - case Mode::Json: - case Mode::JsonSchema: - $request['response_format'] = ['type' => 'json_object']; - break; - } - return $request; - } - - private function removeDisallowedEntries(array $jsonSchema) : array { - return Arrays::removeRecursively($jsonSchema, [ - 'x-title', - //'description', - 'x-php-class', - 'additionalProperties', - ]); - } + $model ?: $this->config->model, + 'max_tokens' => $this->config->maxTokens, + 'messages' => $this->toNativeMessages($messages), + ]), $options); + + return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); + } + + // PRIVATE ////////////////////////////////////////////// + + private function applyMode( + array $request, + Mode $mode, + array $tools, + string|array $toolChoice, + array $responseFormat + ) : array { + switch($mode) { + case Mode::Tools: + $request['tools'] = $this->removeDisallowedEntries($tools); + $request['tool_choice'] = $this->toToolChoice($tools, $toolChoice); + break; + case Mode::Json: + case Mode::JsonSchema: + $request['response_format'] = ['type' => 'json_object']; + break; + } + return $request; + } + + private function toToolChoice(array $tools, array|string $toolChoice) : array|string { + return match(true) { + empty($tools) => '', + empty($toolChoice) => 'auto', + is_array($toolChoice) => [ + 'type' => 'function', + 'name' => $toolChoice['function']['name'], + ], + default => $toolChoice, + }; + } + + private function removeDisallowedEntries(array $jsonSchema) : array { + return Arrays::removeRecursively($jsonSchema, [ + 'x-title', + //'description', + 'x-php-class', + 'additionalProperties', + ]); + } } \ No newline at end of file diff --git a/src/Features/LLM/Drivers/OpenAICompatibleDriver.php b/src/Features/LLM/Drivers/OpenAICompatibleDriver.php index 7a367299..adb64e2a 100644 --- a/src/Features/LLM/Drivers/OpenAICompatibleDriver.php +++ b/src/Features/LLM/Drivers/OpenAICompatibleDriver.php @@ -1,63 +1,75 @@ - $model ?: $this->config->model, - 'max_tokens' => $this->config->maxTokens, - 'messages' => $messages, - ], $options)); - - return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); - } - - // OVERRIDES - HELPERS /////////////////////////////////// - - protected function applyMode( - array $request, - Mode $mode, - array $tools, - string|array $toolChoice, - array $responseFormat - ) : array { - switch($mode) { - case Mode::Tools: - $request['tools'] = $this->removeDisallowedEntries($tools); - $request['tool_choice'] = 'auto'; - break; - case Mode::Json: - $request['response_format'] = $responseFormat; - break; - case Mode::JsonSchema: - $request['response_format'] = [ - 'type' => 'json_object', - 'schema' => $responseFormat['json_schema']['schema'], - ]; - break; - } - return $request; - } - - protected function removeDisallowedEntries(array $jsonSchema) : array { - return Arrays::removeRecursively($jsonSchema, [ - 'title', - 'x-php-class', - 'additionalProperties', - ]); - } -} + $model ?: $this->config->model, + 'max_tokens' => $this->config->maxTokens, + 'messages' => $this->toNativeMessages($messages), + ]), $options); + + return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat); + } + + // OVERRIDES - HELPERS /////////////////////////////////// + + protected function applyMode( + array $request, + Mode $mode, + array $tools, + string|array $toolChoice, + array $responseFormat + ) : array { + switch($mode) { + case Mode::Tools: + $request['tools'] = $this->removeDisallowedEntries($tools); + $request['tool_choice'] = $this->toToolChoice($tools, $toolChoice); + break; + case Mode::Json: + $request['response_format'] = $responseFormat; + break; + case Mode::JsonSchema: + $request['response_format'] = [ + 'type' => 'json_object', + 'schema' => $responseFormat['json_schema']['schema'], + ]; + break; + } + return $request; + } + + protected function removeDisallowedEntries(array $jsonSchema) : array { + return Arrays::removeRecursively($jsonSchema, [ + 'title', + 'x-php-class', + 'additionalProperties', + ]); + } + + private function toToolChoice(array $tools, array|string $toolChoice) : array|string { + return match(true) { + empty($tools) => '', + empty($toolChoice) => 'auto', + is_array($toolChoice) => [ + 'type' => 'function', + 'name' => $toolChoice['function']['name'], + ], + default => $toolChoice, + }; + } +} diff --git a/src/Features/LLM/Drivers/OpenAIDriver.php b/src/Features/LLM/Drivers/OpenAIDriver.php index cdb0afd1..9bb4ee45 100644 --- a/src/Features/LLM/Drivers/OpenAIDriver.php +++ b/src/Features/LLM/Drivers/OpenAIDriver.php @@ -11,10 +11,11 @@ use Cognesy\Instructor\Features\LLM\Data\LLMConfig; use Cognesy\Instructor\Features\LLM\Data\LLMResponse; use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse; +use Cognesy\Instructor\Features\LLM\Data\ToolCall; use Cognesy\Instructor\Features\LLM\Data\ToolCalls; use Cognesy\Instructor\Features\LLM\Data\Usage; use Cognesy\Instructor\Features\LLM\InferenceRequest; -use Cognesy\Instructor\Utils\Json\Json; +use Cognesy\Instructor\Utils\Str; class OpenAIDriver implements CanHandleInference { @@ -71,11 +72,11 @@ public function getRequestBody( array $options = [], Mode $mode = Mode::Text, ) : array { - $request = array_filter(array_merge([ + $request = array_merge(array_filter([ 'model' => $model ?: $this->config->model, 'max_tokens' => $this->config->maxTokens, - 'messages' => $messages, - ], $options)); + 'messages' => $this->toNativeMessages($messages), + ]), $options); if ($options['stream'] ?? false) { $request['stream_options']['include_usage'] = true; @@ -90,7 +91,6 @@ public function toLLMResponse(array $data): ?LLMResponse { return new LLMResponse( content: $this->makeContent($data), responseData: $data, - toolsData: $this->makeToolsData($data), finishReason: $data['choices'][0]['finish_reason'] ?? '', toolCalls: $this->makeToolCalls($data), usage: $this->makeUsage($data), @@ -104,6 +104,7 @@ public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse { return new PartialLLMResponse( contentDelta: $this->makeContentDelta($data), responseData: $data, + toolId: $this->makeToolId($data), toolName: $this->makeToolNameDelta($data), toolArgs: $this->makeToolArgsDelta($data), finishReason: $data['choices'][0]['finish_reason'] ?? '', @@ -165,19 +166,22 @@ private function withCachedContext(InferenceRequest $request): InferenceRequest private function makeToolCalls(array $data) : ToolCalls { return ToolCalls::fromArray(array_map( - callback: fn(array $call) => $call['function'] ?? [], + callback: fn(array $call) => $this->makeToolCall($call), array: $data['choices'][0]['message']['tool_calls'] ?? [] )); } - private function makeToolsData(array $data) : array { - return array_map( - fn($tool) => [ - 'name' => $tool['function']['name'] ?? '', - 'arguments' => Json::decode($tool['function']['arguments']) ?? '', - ], - $data['choices'][0]['message']['tool_calls'] ?? [] - ); + private function makeToolCall(array $data) : ?ToolCall { + if (empty($data)) { + return null; + } + if (!isset($data['function'])) { + return null; + } + if (!isset($data['id'])) { + return null; + } + return ToolCall::fromArray($data['function'])?->withId($data['id']); } private function makeContent(array $data): string { @@ -200,6 +204,10 @@ private function makeContentDelta(array $data): string { }; } + private function makeToolId(array $data) : string { + return $data['choices'][0]['delta']['tool_calls'][0]['id'] ?? ''; + } + private function makeToolNameDelta(array $data) : string { return $data['choices'][0]['delta']['tool_calls'][0]['function']['name'] ?? ''; } @@ -221,4 +229,39 @@ private function makeUsage(array $data): Usage { reasoningTokens: $data['usage']['prompt_tokens_details']['reasoning_tokens'] ?? 0, ); } + + protected function toNativeMessages(array $messages) : array { + $list = []; + foreach ($messages as $message) { + $nativeMessage = $this->mapMessage($message); + if (empty($nativeMessage)) { + continue; + } + $list[] = $nativeMessage; + } + return $list; + } + + protected function mapMessage(array $message) : array { + return match(true) { + ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message), + ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message), + default => $message, + }; + } + + protected function toNativeToolCall(array $message) : array { + return [ + 'role' => 'assistant', + 'tool_calls' => $message['_metadata']['tool_calls'] ?? [], + ]; + } + + protected function toNativeToolResult(array $message) : array { + return [ + 'role' => 'tool', + 'tool_call_id' => $message['_metadata']['tool_call_id'] ?? '', + 'content' => $message['content'] ?? '', + ]; + } } diff --git a/src/Features/LLM/Enums/LLMFinishReason.php b/src/Features/LLM/Enums/LLMFinishReason.php index 40f2801f..8988c231 100644 --- a/src/Features/LLM/Enums/LLMFinishReason.php +++ b/src/Features/LLM/Enums/LLMFinishReason.php @@ -1,38 +1,48 @@ - self::Stop, - 'COMPLETE' => self::Stop, - 'stop' => self::Stop, - 'STOP' => self::Stop, - 'max_tokens' => self::Length, - 'MAX_TOKENS' => self::Length, - 'length' => self::Length, - 'model_length' => self::Length, - 'SAFETY' => self::ContentFilter, - 'RECITATION' => self::ContentFilter, - 'LANGUAGE' => self::ContentFilter, - 'BLOCKLIST' => self::ContentFilter, - 'PROHIBITED_CONTENT' => self::ContentFilter, - 'SPII' => self::ContentFilter, - 'error' => self::Error, - 'MALFORMED_FUNCTION_CALL' => self::Error, - 'tool_calls' => self::ToolCalls, - 'FINISH_REASON_UNSPECIFIED' => self::Other, - 'OTHER' => self::Other, - default => self::Other, - }; - } -} + ($this->value === $reason->value), + is_string($reason) => ($this->value === $reason), + default => false, + }; + } + + public static function fromText(string $text) : LLMFinishReason { + return match ($text) { + 'BLOCKLIST' => self::ContentFilter, + 'COMPLETE' => self::Stop, + 'error' => self::Error, + 'FINISH_REASON_UNSPECIFIED' => self::Other, + 'LANGUAGE' => self::ContentFilter, + 'length' => self::Length, + 'MALFORMED_FUNCTION_CALL' => self::Error, + 'max_tokens' => self::Length, + 'MAX_TOKENS' => self::Length, + 'model_length' => self::Length, + 'OTHER' => self::Other, + 'PROHIBITED_CONTENT' => self::ContentFilter, + 'RECITATION' => self::ContentFilter, + 'SAFETY' => self::ContentFilter, + 'SPII' => self::ContentFilter, + 'stop' => self::Stop, + 'STOP' => self::Stop, + 'stop_sequence' => self::Stop, + 'TOOL_CALL' => self::ToolCalls, + 'tool_calls' => self::ToolCalls, + 'tool_use' => self::ToolCalls, + default => self::Other, + }; + } +} diff --git a/src/Features/LLM/Enums/LLMProviderType.php b/src/Features/LLM/Enums/LLMProviderType.php index bf9ff8b0..62e18e94 100644 --- a/src/Features/LLM/Enums/LLMProviderType.php +++ b/src/Features/LLM/Enums/LLMProviderType.php @@ -1,28 +1,29 @@ -value === $clientType->value; - } -} +value === $clientType->value; + } +} diff --git a/src/Features/LLM/Inference.php b/src/Features/LLM/Inference.php index 7c6d5e7b..7103ff7b 100644 --- a/src/Features/LLM/Inference.php +++ b/src/Features/LLM/Inference.php @@ -14,6 +14,7 @@ use Cognesy\Instructor\Features\LLM\Drivers\CohereV1Driver; use Cognesy\Instructor\Features\LLM\Drivers\CohereV2Driver; use Cognesy\Instructor\Features\LLM\Drivers\GeminiDriver; +use Cognesy\Instructor\Features\LLM\Drivers\GrokDriver; use Cognesy\Instructor\Features\LLM\Drivers\MistralDriver; use Cognesy\Instructor\Features\LLM\Drivers\OpenAICompatibleDriver; use Cognesy\Instructor\Features\LLM\Drivers\OpenAIDriver; @@ -232,10 +233,11 @@ protected function makeDriver(LLMConfig $config, CanHandleHttp $httpClient): Can LLMProviderType::CohereV1 => new CohereV1Driver($config, $httpClient, $this->events), LLMProviderType::CohereV2 => new CohereV2Driver($config, $httpClient, $this->events), LLMProviderType::Gemini => new GeminiDriver($config, $httpClient, $this->events), + LLMProviderType::Grok => new GrokDriver($config, $httpClient, $this->events), LLMProviderType::Mistral => new MistralDriver($config, $httpClient, $this->events), LLMProviderType::OpenAI => new OpenAIDriver($config, $httpClient, $this->events), LLMProviderType::Fireworks, - LLMProviderType::Grok, + LLMProviderType::GeminiOAI, LLMProviderType::Groq, LLMProviderType::Ollama, LLMProviderType::OpenAICompatible, diff --git a/src/Features/LLM/InferenceRequest.php b/src/Features/LLM/InferenceRequest.php index e0aa2695..a72fac52 100644 --- a/src/Features/LLM/InferenceRequest.php +++ b/src/Features/LLM/InferenceRequest.php @@ -1,81 +1,79 @@ -cachedContext = $cachedContext; - - $this->model = $model; - $this->options = $options; - $this->mode = $mode; - - $this->messages = match(true) { - is_string($messages) => [['role' => 'user', 'content' => $messages]], - default => $messages, - }; - - if ($mode->is(Mode::Tools)) { - $this->tools = $tools; - $this->toolChoice = $toolChoice; - } elseif ($mode->is(Mode::Json)) { - $this->responseFormat = [ - 'type' => 'json_object', - 'schema' => $responseFormat['schema'] ?? [], - ]; - } elseif ($mode->is(Mode::JsonSchema)) { - $this->responseFormat = [ - 'type' => 'json_schema', - 'json_schema' => [ - 'name' => $responseFormat['json_schema']['name'] ?? 'schema', - 'schema' => $responseFormat['json_schema']['schema'] ?? [], - 'strict' => $responseFormat['json_schema']['strict'] ?? true, - ], - ]; - } elseif ($mode->is([Mode::Text, Mode::MdJson])) { - $this->tools = []; - $this->toolChoice = []; - $this->responseFormat = []; - } - } - - public function toArray() : array { - return [ -// 'uuid' => $this->uuid, -// 'metadata' => $this->metadata, - 'messages' => $this->messages, - 'model' => $this->model, - 'tools' => $this->tools, - 'tool_choice' => $this->toolChoice, - 'response_format' => $this->responseFormat, - 'options' => $this->options, - 'mode' => $this->mode->value, - ]; - } -} +cachedContext = $cachedContext; + + $this->model = $model; + $this->options = $options; + $this->mode = $mode; + + $this->messages = match(true) { + is_string($messages) => [['role' => 'user', 'content' => $messages]], + default => $messages, + }; + + if ($mode->is(Mode::Tools)) { + $this->tools = $tools; + $this->toolChoice = $toolChoice; + } elseif ($mode->is(Mode::Json)) { + $this->responseFormat = [ + 'type' => 'json_object', + 'schema' => $responseFormat['schema'] ?? [], + ]; + } elseif ($mode->is(Mode::JsonSchema)) { + $this->responseFormat = [ + 'type' => 'json_schema', + 'json_schema' => [ + 'name' => $responseFormat['json_schema']['name'] ?? 'schema', + 'schema' => $responseFormat['json_schema']['schema'] ?? [], + 'strict' => $responseFormat['json_schema']['strict'] ?? true, + ], + ]; + } elseif ($mode->is([Mode::Text, Mode::MdJson])) { + $this->tools = []; + $this->toolChoice = []; + $this->responseFormat = []; + } + } + + public function toArray() : array { + return [ +// 'uuid' => $this->uuid, +// 'metadata' => $this->metadata, + 'messages' => $this->messages, + 'model' => $this->model, + 'tools' => $this->tools, + 'tool_choice' => $this->toolChoice, + 'response_format' => $this->responseFormat, + 'options' => $this->options, + 'mode' => $this->mode->value, + ]; + } +} diff --git a/src/Utils/Json/PartialJsonParser.php b/src/Utils/Json/PartialJsonParser.php index fcc70546..fa95bbe8 100644 --- a/src/Utils/Json/PartialJsonParser.php +++ b/src/Utils/Json/PartialJsonParser.php @@ -1,204 +1,205 @@ -parsers = array_fill_keys([' ', "\r", "\n", "\t"], $this->parseSpace(...)); - $this->parsers['['] = $this->parseArray(...); - $this->parsers['{'] = $this->parseObject(...); - $this->parsers['"'] = $this->parseString(...); - $this->parsers['t'] = $this->parseTrue(...); - $this->parsers['f'] = $this->parseFalse(...); - $this->parsers['n'] = $this->parseNull(...); - - foreach (str_split('0123456789.-') as $char) { - $this->parsers[$char] = $this->parseNumber(...); - } - - $this->onExtraToken = function ($text, $data, $reminding) { - // we're just skipping extra tokens - //$message = 'Parsed JSON with extra tokens: ' . Json::encode(['text' => $text, 'data' => $data, 'reminding' => $reminding]); - //throw new JSONParsingException($message, $text); - }; - } - - public function fix(string $partialJson) : string { - return Json::encode($this->parse($partialJson)); - } - - public function parse(string $json, bool $associative = true) : array|object - { - if (strlen($json) >= 1) { - try { - return json_decode($json, $associative, 512, JSON_THROW_ON_ERROR); - } catch (JsonException $e) { - list($data, $reminding) = $this->parseAny($json, $e); - $this->lastParseReminding = $reminding; - if ($this->onExtraToken && $reminding && !$this->skipExtraTokens) { - ($this->onExtraToken)($json, $data, $reminding); - } - return match(true) { - is_string($data) => [], - default => $data, - }; - } catch (Exception $e) { - throw new JsonParsingException('Unexpected error: ' . $e->getMessage()); - } - } else { - return json_decode('{}', $associative); - } - } - - public function reminder() : string { - return $this->lastParseReminding; - } - - private function parseAny(string $json, Throwable $e) : array|object { - if (!$json) { - throw $e; - } - $parser = $this->parsers[$json[0]] ?? null; - if (!$parser) { - throw $e; - } - return $parser($json, $e); - } - - private function parseSpace(string $json, Throwable $e) : array|object { - return $this->parseAny(trim($json), $e); - } - - private function parseArray(string $json, Throwable $e) : array { - $json = substr($json, 1); // skip starting '[' - $acc = []; - $json = trim($json); - while ($json) { - if ($json[0] === ']') { - $json = substr($json, 1); // skip ending ']' - break; - } - list($res, $json) = $this->parseAny($json, $e); - $acc[] = $res; - $json = trim($json); - if (strpos($json, ',') === 0) { - $json = substr($json, 1); - $json = trim($json); - } - } - return [$acc, $json]; - } - - private function parseObject(string $json, Throwable $e) : array { - $json = substr($json, 1); // skip starting '{' - $acc = []; - $json = trim($json); - while ($json) { - if ($json[0] === '}') { - $json = substr($json, 1); // skip ending '}' - break; - } - list($key, $json) = $this->parseAny($json, $e); - $json = trim($json); - - if (!$json || $json[0] === '}') { - $acc[$key] = null; - break; - } - - if ($json[0] !== ':') { - throw $e; - } - - $json = substr($json, 1); // skip ':' - $json = trim($json); - - if (!$json || in_array($json[0], [',', '}'])) { - $acc[$key] = null; - if (strpos($json, ',') === 0) { - $json = substr($json, 1); - } - break; - } - - list($value, $json) = $this->parseAny($json, $e); - $acc[$key] = $value; - $json = trim($json); - if (strpos($json, ',') === 0) { - $json = substr($json, 1); - $json = trim($json); - } - } - return [$acc, $json]; - } - - private function parseString(string $json, Throwable $e) : array { - $end = strpos($json, '"', 1); - while ($end !== false && $json[$end - 1] === '\\') { // Handle escaped quotes - $end = strpos($json, '"', $end + 1); - } - if ($end === false) { - // Return the incomplete string without the opening quote - return [substr($json, 1), ""]; - } - $strVal = substr($json, 0, $end + 1); - $json = substr($json, $end + 1); - return [json_decode($strVal), $json]; - } - - private function parseNumber(string $json, Throwable $e) : array { - $i = 0; - while ($i < strlen($json) && strpos('0123456789.-', $json[$i]) !== false) { - $i++; - } - $numStr = substr($json, 0, $i); - $json = substr($json, $i); - if ($numStr == '' || substr($numStr, -1) === '.' || substr($numStr, -1) === '-') { - // Return the incomplete number as is - return [$numStr, ""]; - } - if (strpos($numStr, '.') !== false || strpos($numStr, 'e') !== false || strpos($numStr, 'E') !== false) { - $num = (float) $numStr; - } else { - $num = (int) $numStr; - } - return [$num, $json]; - } - - private function parseTrue(string $json, Throwable $e) : array { - if (substr($json, 0, 4) === 'true') { - return [true, substr($json, 4)]; - } - throw $e; - } - - private function parseFalse(string $json, Throwable $e) : array { - if (substr($json, 0, 5) === 'false') { - return [false, substr($json, 5)]; - } - throw $e; - } - - private function parseNull(string $json, Throwable $e) : array { - if (substr($json, 0, 4) === 'null') { - return [null, substr($json, 4)]; - } - throw $e; - } +parsers = array_fill_keys([' ', "\r", "\n", "\t"], $this->parseSpace(...)); + $this->parsers['['] = $this->parseArray(...); + $this->parsers['{'] = $this->parseObject(...); + $this->parsers['"'] = $this->parseString(...); + $this->parsers['t'] = $this->parseTrue(...); + $this->parsers['f'] = $this->parseFalse(...); + $this->parsers['n'] = $this->parseNull(...); + + foreach (str_split('0123456789.-') as $char) { + $this->parsers[$char] = $this->parseNumber(...); + } + + $this->onExtraToken = function ($text, $data, $reminding) { + // we're just skipping extra tokens + + //$message = 'Parsed JSON with extra tokens: ' . Json::encode(['text' => $text, 'data' => $data, 'reminding' => $reminding]); + //throw new JSONParsingException($message, $text); + }; + } + + public function fix(string $partialJson) : string { + return Json::encode($this->parse($partialJson)); + } + + public function parse(string $json, bool $associative = true) : array|object + { + if (strlen($json) >= 1) { + try { + return json_decode($json, $associative, 512, JSON_THROW_ON_ERROR); + } catch (JsonException $e) { + list($data, $reminding) = $this->parseAny($json, $e); + $this->lastParseReminding = $reminding; + if ($this->onExtraToken && $reminding && !$this->skipExtraTokens) { + ($this->onExtraToken)($json, $data, $reminding); + } + return match(true) { + is_string($data) => [], + default => $data, + }; + } catch (Exception $e) { + throw new JsonParsingException('Unexpected error: ' . $e->getMessage()); + } + } else { + return json_decode('{}', $associative); + } + } + + public function reminder() : string { + return $this->lastParseReminding; + } + + private function parseAny(string $json, Throwable $e) : array|object { + if (!$json) { + throw $e; + } + $parser = $this->parsers[$json[0]] ?? null; + if (!$parser) { + throw $e; + } + return $parser($json, $e); + } + + private function parseSpace(string $json, Throwable $e) : array|object { + return $this->parseAny(trim($json), $e); + } + + private function parseArray(string $json, Throwable $e) : array { + $json = substr($json, 1); // skip starting '[' + $acc = []; + $json = trim($json); + while ($json) { + if ($json[0] === ']') { + $json = substr($json, 1); // skip ending ']' + break; + } + list($res, $json) = $this->parseAny($json, $e); + $acc[] = $res; + $json = trim($json); + if (strpos($json, ',') === 0) { + $json = substr($json, 1); + $json = trim($json); + } + } + return [$acc, $json]; + } + + private function parseObject(string $json, Throwable $e) : array { + $json = substr($json, 1); // skip starting '{' + $acc = []; + $json = trim($json); + while ($json) { + if ($json[0] === '}') { + $json = substr($json, 1); // skip ending '}' + break; + } + list($key, $json) = $this->parseAny($json, $e); + $json = trim($json); + + if (!$json || $json[0] === '}') { + $acc[$key] = null; + break; + } + + if ($json[0] !== ':') { + throw $e; + } + + $json = substr($json, 1); // skip ':' + $json = trim($json); + + if (!$json || in_array($json[0], [',', '}'])) { + $acc[$key] = null; + if (strpos($json, ',') === 0) { + $json = substr($json, 1); + } + break; + } + + list($value, $json) = $this->parseAny($json, $e); + $acc[$key] = $value; + $json = trim($json); + if (strpos($json, ',') === 0) { + $json = substr($json, 1); + $json = trim($json); + } + } + return [$acc, $json]; + } + + private function parseString(string $json, Throwable $e) : array { + $end = strpos($json, '"', 1); + while ($end !== false && $json[$end - 1] === '\\') { // Handle escaped quotes + $end = strpos($json, '"', $end + 1); + } + if ($end === false) { + // Return the incomplete string without the opening quote + return [substr($json, 1), ""]; + } + $strVal = substr($json, 0, $end + 1); + $json = substr($json, $end + 1); + return [json_decode($strVal), $json]; + } + + private function parseNumber(string $json, Throwable $e) : array { + $i = 0; + while ($i < strlen($json) && strpos('0123456789.-', $json[$i]) !== false) { + $i++; + } + $numStr = substr($json, 0, $i); + $json = substr($json, $i); + if ($numStr == '' || substr($numStr, -1) === '.' || substr($numStr, -1) === '-') { + // Return the incomplete number as is + return [$numStr, ""]; + } + if (strpos($numStr, '.') !== false || strpos($numStr, 'e') !== false || strpos($numStr, 'E') !== false) { + $num = (float) $numStr; + } else { + $num = (int) $numStr; + } + return [$num, $json]; + } + + private function parseTrue(string $json, Throwable $e) : array { + if (substr($json, 0, 4) === 'true') { + return [true, substr($json, 4)]; + } + throw $e; + } + + private function parseFalse(string $json, Throwable $e) : array { + if (substr($json, 0, 5) === 'false') { + return [false, substr($json, 5)]; + } + throw $e; + } + + private function parseNull(string $json, Throwable $e) : array { + if (substr($json, 0, 4) === 'null') { + return [null, substr($json, 4)]; + } + throw $e; + } } \ No newline at end of file diff --git a/src/Utils/Messages/Message.php b/src/Utils/Messages/Message.php index 1d94ecbc..cd79aaf4 100644 --- a/src/Utils/Messages/Message.php +++ b/src/Utils/Messages/Message.php @@ -1,20 +1,35 @@ - $content - */ - public function __construct( - public string $role = '', - public string|array $content = '', - ) {} -} + $content + */ + public function __construct( + string|MessageRole $role = '', + string|array|null $content = '', + array $metadata = [], + ) { + $this->role = match(true) { + $role instanceof MessageRole => $role->value, + ($role === '') => self::DEFAULT_ROLE, + default => $role, + }; + $this->content = $content ?? ''; + $this->metadata = $metadata; + } +} diff --git a/src/Utils/Messages/Traits/Message/HandlesAccess.php b/src/Utils/Messages/Traits/Message/HandlesAccess.php index eca87b1e..c17c0035 100644 --- a/src/Utils/Messages/Traits/Message/HandlesAccess.php +++ b/src/Utils/Messages/Traits/Message/HandlesAccess.php @@ -1,36 +1,66 @@ -role); - } - - public function content() : string|array { - return $this->content; - } - - public function isEmpty() : bool { - return empty($this->content); - } - - public function isNull() : bool { - return ($this->role === '' && $this->content === ''); - } - - public function isComposite() : bool { - return is_array($this->content); - } +role); + } + + public function content() : string|array { + return $this->content; + } + + public function isEmpty() : bool { + return empty($this->content) && !$this->hasMeta(); + } + + public function isNull() : bool { + return ($this->role === '' && $this->content === ''); + } + + public function isComposite() : bool { + return is_array($this->content); + } + + public function hasMeta(string $key = null) : bool { + return match(true) { + $key === null => !empty($this->metadata), + default => isset($this->metadata[$key]), + }; + } + + public function meta(string $key = null) : mixed { + return match(true) { + $key === null => $this->metadata, + default => $this->metadata[$key] ?? null, + }; + } + + public function metaKeys() : array { + return array_keys($this->metadata); + } + + public function withMeta(array $metadata) : self { + $this->metadata = $metadata; + return $this; + } + + public function withMetaValue(string $key, mixed $value) : self { + $this->metadata[$key] = $value; + return $this; + } } \ No newline at end of file diff --git a/src/Utils/Messages/Traits/Message/HandlesCreation.php b/src/Utils/Messages/Traits/Message/HandlesCreation.php index c1ea12a6..6502faaa 100644 --- a/src/Utils/Messages/Traits/Message/HandlesCreation.php +++ b/src/Utils/Messages/Traits/Message/HandlesCreation.php @@ -6,7 +6,6 @@ use Cognesy\Instructor\Utils\Messages\Message; use Cognesy\Instructor\Utils\Messages\Utils\Text; use Exception; -use InvalidArgumentException; trait HandlesCreation { @@ -19,10 +18,11 @@ public static function fromString(string $content) : static { } public static function fromArray(array $message) : static { - if (!self::hasRoleAndContent($message)) { - throw new InvalidArgumentException('Message array must contain "role" and "content" keys'); - } - return new static($message['role'], $message['content']); + return new static( + role: $message['role'] ?? 'user', + content: $message['content'] ?? '', + metadata: $message['_metadata'] ?? [], + ); } public static function fromContent(string $role, string|array $content) : static { diff --git a/src/Utils/Messages/Traits/Message/HandlesTransformation.php b/src/Utils/Messages/Traits/Message/HandlesTransformation.php index a053765e..82ec8d30 100644 --- a/src/Utils/Messages/Traits/Message/HandlesTransformation.php +++ b/src/Utils/Messages/Traits/Message/HandlesTransformation.php @@ -1,51 +1,46 @@ - $this->role, 'content' => $this->content]; - } - - public function toString() : string { - if (!$this->isComposite()) { - return $this->content; - } - // flatten composite message to text - $text = ''; - foreach($this->content as $part) { - if ($part['type'] !== 'text') { - throw new RuntimeException('Message contains non-text parts and cannot be flattened to text'); - } - $text .= $part['text']; - } - return $text; - } - - public function toRoleString() : string { - return $this->role . ': ' . $this->toString(); - } - - public function toCompositeMessage() : Message { - return Message::fromArray($this->toCompositeArray()); - } - - public function toCompositeArray() : array { - return match($this->isComposite()) { - true => [ - 'role' => $this->role, - 'content' => $this->content, - ], - default => [ - 'role' => $this->role, - 'content' => [[ - 'type' => 'text', - 'text' => $this->content, - ]] - ] - }; - } -} + $this->role, + 'content' => $this->content, + '_metadata' => $this->metadata, + ]); + } + + public function toString() : string { + if (!$this->isComposite()) { + return $this->content; + } + // flatten composite message to text + $text = ''; + foreach($this->content as $part) { + if ($part['type'] !== 'text') { + throw new RuntimeException('Message contains non-text parts and cannot be flattened to text'); + } + $text .= $part['text']; + } + return $text; + } + + public function toRoleString() : string { + return $this->role . ': ' . $this->toString(); + } + + public function toCompositeMessage() : Message { + return Message::fromArray([ + 'role' => $this->role, + 'content' => match(true) { + $this->isComposite() => $this->content, + default => [['type' => 'text', 'text' => $this->content]] + }, + '_metadata' => $this->metadata, + ]); + } +} diff --git a/src/Utils/Messages/Traits/Messages/HandlesConversion.php b/src/Utils/Messages/Traits/Messages/HandlesConversion.php index fa64f8ad..b4bf3205 100644 --- a/src/Utils/Messages/Traits/Messages/HandlesConversion.php +++ b/src/Utils/Messages/Traits/Messages/HandlesConversion.php @@ -1,91 +1,92 @@ -> $messages - * @return array> - */ - public static function asPerRoleArray(array $messages) : array { - if (empty($messages)) { - return ['role' => 'system', 'content' => '']; - } - - $role = 'user'; - $merged = new Messages(); - $content = []; - foreach ($messages as $message) { - if ($role !== $message['role'] || Message::becomesComposite($message)) { - $merged->appendMessage(new Message( - role: $role, - content: implode("\n\n", array_filter($content)), - )); - $role = $message['role']; - $content = []; - - if (Message::becomesComposite($message)) { - $merged->appendMessage($message); - continue; - } - } - $content[] = $message['content']; - } - // append remaining content - if (!empty($content)) { - $merged->appendMessage(new Message( - role: $role, - content: implode("\n", array_filter($content)), // TODO: see above - )); - } - return $merged->toArray(); - } - - public static function asString( - array $messages, - string $separator = "\n", - callable $renderer = null - ) : string { - $result = ''; - foreach ($messages as $message) { - if (empty($message) || !is_array($message) || empty($message['content'])) { - continue; - } - $rendered = match(true) { - !is_null($renderer) => $renderer($message), - default => match(true) { - Message::becomesComposite($message) => throw new RuntimeException('Array contains composite messages, cannot be converted to string.'), - default => $message['content'] . $separator, - } - }; - $result .= $rendered; - } - return $result; - } - - /** - * @return array - */ - public function toArray() : array { - $result = []; - foreach ($this->messages as $message) { - if ($message->isEmpty()) { - continue; - } - $result[] = $message->toArray(); - } - return $result; - } - - public function toString(string $separator = "\n") : string { - if ($this->hasComposites()) { - throw new RuntimeException('Collection contains composite messages and cannot be converted to string.'); - } - return self::asString($this->toArray(), $separator); - } +> $messages + * @return array> + */ + public static function asPerRoleArray(array $messages) : array { + if (empty($messages)) { + return ['role' => 'user', 'content' => '']; + } + + $role = 'user'; + $merged = new Messages(); + $content = []; + foreach ($messages as $message) { + if ($role !== $message['role'] || Message::becomesComposite($message)) { + $merged->appendMessage(new Message( + role: $role, + content: implode("\n\n", array_filter($content)), + )); + $role = $message['role']; + $content = []; + + if (Message::becomesComposite($message)) { + $merged->appendMessage($message); + continue; + } + } + $content[] = $message['content']; + } + // append remaining content + if (!empty($content)) { + $merged->appendMessage(new Message( + role: $role, + content: implode("\n", array_filter($content)), // TODO: see above + metadata: $message['_metadata'], + )); + } + return $merged->toArray(); + } + + public static function asString( + array $messages, + string $separator = "\n", + callable $renderer = null + ) : string { + $result = ''; + foreach ($messages as $message) { + if (empty($message) || !is_array($message) || empty($message['content'])) { + continue; + } + $rendered = match(true) { + !is_null($renderer) => $renderer($message), + default => match(true) { + Message::becomesComposite($message) => throw new RuntimeException('Array contains composite messages, cannot be converted to string.'), + default => $message['content'] . $separator, + } + }; + $result .= $rendered; + } + return $result; + } + + /** + * @return array + */ + public function toArray() : array { + $result = []; + foreach ($this->messages as $message) { + if ($message->isEmpty()) { + continue; + } + $result[] = $message->toArray(); + } + return $result; + } + + public function toString(string $separator = "\n") : string { + if ($this->hasComposites()) { + throw new RuntimeException('Collection contains composite messages and cannot be converted to string.'); + } + return self::asString($this->toArray(), $separator); + } } \ No newline at end of file diff --git a/src/Utils/Messages/Traits/Messages/HandlesCreation.php b/src/Utils/Messages/Traits/Messages/HandlesCreation.php index 11e29b21..863534c9 100644 --- a/src/Utils/Messages/Traits/Messages/HandlesCreation.php +++ b/src/Utils/Messages/Traits/Messages/HandlesCreation.php @@ -1,91 +1,91 @@ -appendMessage(Message::fromString($content)); - } - - /** - * @param array $messages - */ - static public function fromArray(array $messages) : Messages { - $instance = new self(); - foreach ($messages as $message) { - $instance->messages[] = match(true) { - is_string($message) => Message::fromString($message), - Message::hasRoleAndContent($message) => new Message($message['role'], $message['content']), - default => throw new Exception('Invalid message array - missing role or content keys'), - }; - } - return $instance; - } - - /** - * @param Messages[] $messages - */ - static public function fromMessages(array|Message|Messages ...$arrayOfMessages) : Messages { - $instance = new self(); - foreach ($arrayOfMessages as $message) { - if ($message instanceof Messages) { - $instance->appendMessages($message); - } elseif ($message instanceof Message) { - $instance->appendMessage($message); - } elseif (is_array($message)) { - $instance->appendMessage(Message::fromArray($message)); - } else { - throw new InvalidArgumentException('Invalid type for message'); - } - } - return $instance; - } - - public static function fromAnyArray(array $messages) : Messages { - if (Message::hasRoleAndContent($messages)) { - return self::fromArray([$messages]); - } - $normalized = new self(); - foreach ($messages as $message) { - $normalized->appendMessage(match(true) { - is_array($message) => match(true) { - Message::hasRoleAndContent($message) => new Message($message['role'], $message['content']), - default => throw new Exception('Invalid message array - missing role or content keys'), - }, - is_string($message) => new Message('user', $message), - $message instanceof Message => $message, - default => throw new Exception('Invalid message type'), - }); - } - return $normalized; - } - - public static function fromAny(string|array|Message|Messages $messages) : Messages { - return match(true) { - is_string($messages) => self::fromString($messages), - is_array($messages) => self::fromAnyArray($messages), - $messages instanceof Message => (new Messages)->appendMessage($messages), - $messages instanceof Messages => $messages, - default => throw new Exception('Invalid message type'), - }; - } - - public static function fromInput(string|array|object $input) : static { - return match(true) { - $input instanceof Messages => $input, - $input instanceof CanProvideMessages => $input->toMessages(), - $input instanceof Message => (new Messages)->appendMessage($input), - $input instanceof CanProvideMessage => (new Messages)->appendMessage($input->toMessage()), - default => (new Messages)->appendMessage(new Message('user', Text::fromAny($input))), - }; - } -} +appendMessage(Message::fromString($content)); + } + + /** + * @param array $messages + */ + static public function fromArray(array $messages) : Messages { + $instance = new self(); + foreach ($messages as $message) { + $instance->messages[] = match(true) { + is_string($message) => Message::fromString($message), + Message::hasRoleAndContent($message) => Message::fromArray($message), + default => throw new Exception('Invalid message array - missing role or content keys'), + }; + } + return $instance; + } + + /** + * @param Messages[] $messages + */ + static public function fromMessages(array|Message|Messages ...$arrayOfMessages) : Messages { + $instance = new self(); + foreach ($arrayOfMessages as $message) { + if ($message instanceof Messages) { + $instance->appendMessages($message); + } elseif ($message instanceof Message) { + $instance->appendMessage($message); + } elseif (is_array($message)) { + $instance->appendMessage(Message::fromArray($message)); + } else { + throw new InvalidArgumentException('Invalid type for message'); + } + } + return $instance; + } + + public static function fromAnyArray(array $messages) : Messages { + if (Message::hasRoleAndContent($messages)) { + return self::fromArray([$messages]); + } + $normalized = new self(); + foreach ($messages as $message) { + $normalized->appendMessage(match(true) { + is_array($message) => match(true) { + Message::hasRoleAndContent($message) => Message::fromArray($message), + default => throw new Exception('Invalid message array - missing role or content keys'), + }, + is_string($message) => new Message('user', $message), + $message instanceof Message => $message, + default => throw new Exception('Invalid message type'), + }); + } + return $normalized; + } + + public static function fromAny(string|array|Message|Messages $messages) : Messages { + return match(true) { + is_string($messages) => self::fromString($messages), + is_array($messages) => self::fromAnyArray($messages), + $messages instanceof Message => (new Messages)->appendMessage($messages), + $messages instanceof Messages => $messages, + default => throw new Exception('Invalid message type'), + }; + } + + public static function fromInput(string|array|object $input) : static { + return match(true) { + $input instanceof Messages => $input, + $input instanceof CanProvideMessages => $input->toMessages(), + $input instanceof Message => (new Messages)->appendMessage($input), + $input instanceof CanProvideMessage => (new Messages)->appendMessage($input->toMessage()), + default => (new Messages)->appendMessage(new Message('user', Text::fromAny($input))), + }; + } +} diff --git a/src/Utils/Str.php b/src/Utils/Str.php index a7fc5e8f..7e8375aa 100644 --- a/src/Utils/Str.php +++ b/src/Utils/Str.php @@ -94,34 +94,34 @@ static public function containsAny(string $haystack, string|array $needles, bool return false; } - public static function startsWith(string $url, string $string) : bool { - return substr($url, 0, strlen($string)) === $string; + public static function startsWith(string $text, string $prefix) : bool { + return substr($text, 0, strlen($prefix)) === $prefix; } - public static function endsWith(string $url, string $string) : bool { - return substr($url, -strlen($string)) === $string; + public static function endsWith(string $text, string $suffix) : bool { + return substr($text, -strlen($suffix)) === $suffix; } - public static function between(mixed $url, string $string, string $string1) : string { - $start = strpos($url, $string); + public static function between(mixed $text, string $firstNeedle, string $nextNeedle) : string { + $start = strpos($text, $firstNeedle); if ($start === false) { return ''; } - $start += strlen($string); - $end = strpos($url, $string1, $start); + $start += strlen($firstNeedle); + $end = strpos($text, $nextNeedle, $start); if ($end === false) { return ''; } - return substr($url, $start, $end - $start); + return substr($text, $start, $end - $start); } - public static function after(mixed $url, string $string) : string { - $start = strpos($url, $string); + public static function after(mixed $text, string $needle) : string { + $start = strpos($text, $needle); if ($start === false) { return ''; } - $start += strlen($string); - return substr($url, $start); + $start += strlen($needle); + return substr($text, $start); } public static function when(bool $condition, string $onTrue, string $onFalse) : string { diff --git a/src/Utils/TemplateUtil.php b/src/Utils/TemplateUtil.php index 48675382..60d620a9 100644 --- a/src/Utils/TemplateUtil.php +++ b/src/Utils/TemplateUtil.php @@ -1,139 +1,139 @@ -clearUnknownParams = $clearUnknownParams; - $this->parameters = $parameters; - if (empty($parameters)) { - return; - } - // remove keys starting with @ - these are used for section templates - $filteredParameters = array_filter( - $parameters, - fn($key) => substr($key, 0, 1) !== '@', - ARRAY_FILTER_USE_KEY - ); - $materializedParameters = $this->materializeParameters($filteredParameters); - $this->parameterValues = array_values($materializedParameters); - $this->parameterKeys = array_map( - fn($key) => $this->varPattern($key), - array_keys($materializedParameters) - ); - } - - public function getParameters() : array { - return $this->parameters; - } - - public static function cleanVarMarkers(string $template) : string { - return str_replace(['<|', '|>'], '', $template); - } - - public static function render( - string $template, - array $parameters, - bool $clearUnknownParams = true, - ) : string { - return (new TemplateUtil( - $parameters, - $clearUnknownParams - ))->renderString($template); - } - - public function renderString(string $template): string { - // find all keys in the template - $keys = $this->findVars($template); - if ($this->clearUnknownParams) { - // find keys missing from $this->keys - $missingKeys = array_diff($keys, $this->parameterKeys); - // remove missing key strings from the template - $template = str_replace($missingKeys, '', $template); - } - // render values - return str_replace($this->parameterKeys, $this->parameterValues, $template); - } - - public function renderArray( - array $rows, - string $field = 'content' - ): array { - return array_map( - fn($item) => $this->renderString($item[$field] ?? ''), - $rows - ); - } - - public function renderMessage(array|Message $message) : array { - $normalized = match(true) { - is_array($message) => Message::fromArray($message), - $message instanceof Message => $message, - default => throw new InvalidArgumentException('Invalid message type'), - }; - - // skip rendering if content is an array - it may contain non-text data - if (is_array($normalized->content)) { - return ['role' => $normalized->role, 'content' => $normalized->content]; - } - - return ['role' => $normalized->role, 'content' => $this->renderString($normalized->content)]; - } - - public function renderMessages(array|Messages $messages) : array { - return array_map( - fn($message) => $this->renderMessage($message), - is_array($messages) ? $messages : $messages->toArray() - ); - } - - // OVERRIDEABLE ////////////////////////////////////////////////////////////// - - protected function varPattern(string $key) : string { - return '<|' . $key . '|>'; - } - - protected function findVars(string $template) : array { - $matches = []; - // replace {xxx} pattern with <|xxx|> pattern match - preg_match_all('/<\|([^|]+)\|>/', $template, $matches); - return $matches[0]; - } - - // INTERNAL ////////////////////////////////////////////////////////////////// - - private function materializeParameters(array $parameters) : array { - // TODO: is there a way to consolidate value rendering? - $parameterValues = []; - foreach ($parameters as $key => $value) { - $value = match (true) { - is_scalar($value) => $value, - is_array($value) => Json::encode($value), - is_callable($value) => $value($key, $parameters), - is_object($value) && method_exists($value, 'toString') => $value->toString(), - is_object($value) && method_exists($value, 'toJson') => $value->toJson(), - is_object($value) && method_exists($value, 'toArray') => Json::encode($value->toArray()), - is_object($value) && method_exists($value, 'toSchema') => Json::encode($value->toSchema()), - is_object($value) && method_exists($value, 'toOutputSchema') => Json::encode($value->toOutputSchema()), - is_object($value) && property_exists($value, 'value') => $value->value(), - is_object($value) => Json::encode($value), - default => $value, - }; - $parameterValues[$key] = $value; - } - return $parameterValues; - } -} +clearUnknownParams = $clearUnknownParams; + $this->parameters = $parameters; + if (empty($parameters)) { + return; + } + // remove keys starting with @ - these are used for section templates + $filteredParameters = array_filter( + $parameters, + fn($key) => substr($key, 0, 1) !== '@', + ARRAY_FILTER_USE_KEY + ); + $materializedParameters = $this->materializeParameters($filteredParameters); + $this->parameterValues = array_values($materializedParameters); + $this->parameterKeys = array_map( + fn($key) => $this->varPattern($key), + array_keys($materializedParameters) + ); + } + + public function getParameters() : array { + return $this->parameters; + } + + public static function cleanVarMarkers(string $template) : string { + return str_replace(['<|', '|>'], '', $template); + } + + public static function render( + string $template, + array $parameters, + bool $clearUnknownParams = true, + ) : string { + return (new TemplateUtil( + $parameters, + $clearUnknownParams + ))->renderString($template); + } + + public function renderString(string $template): string { + // find all keys in the template + $keys = $this->findVars($template); + if ($this->clearUnknownParams) { + // find keys missing from $this->keys + $missingKeys = array_diff($keys, $this->parameterKeys); + // remove missing key strings from the template + $template = str_replace($missingKeys, '', $template); + } + // render values + return str_replace($this->parameterKeys, $this->parameterValues, $template); + } + + public function renderArray( + array $rows, + string $field = 'content' + ): array { + return array_map( + fn($item) => $this->renderString($item[$field] ?? ''), + $rows + ); + } + + public function renderMessage(array|Message $message) : array { + $normalized = match(true) { + is_array($message) => Message::fromArray($message), + $message instanceof Message => $message, + default => throw new InvalidArgumentException('Invalid message type'), + }; + + // skip rendering if content is an array - it may contain non-text data + if (is_array($normalized->content())) { + return ['role' => $normalized->role(), 'content' => $normalized->content()]; + } + + return ['role' => $normalized->role(), 'content' => $this->renderString($normalized->content())]; + } + + public function renderMessages(array|Messages $messages) : array { + return array_map( + fn($message) => $this->renderMessage($message), + is_array($messages) ? $messages : $messages->toArray() + ); + } + + // OVERRIDEABLE ////////////////////////////////////////////////////////////// + + protected function varPattern(string $key) : string { + return '<|' . $key . '|>'; + } + + protected function findVars(string $template) : array { + $matches = []; + // replace {xxx} pattern with <|xxx|> pattern match + preg_match_all('/<\|([^|]+)\|>/', $template, $matches); + return $matches[0]; + } + + // INTERNAL ////////////////////////////////////////////////////////////////// + + private function materializeParameters(array $parameters) : array { + // TODO: is there a way to consolidate value rendering? + $parameterValues = []; + foreach ($parameters as $key => $value) { + $value = match (true) { + is_scalar($value) => $value, + is_array($value) => Json::encode($value), + is_callable($value) => $value($key, $parameters), + is_object($value) && method_exists($value, 'toString') => $value->toString(), + is_object($value) && method_exists($value, 'toJson') => $value->toJson(), + is_object($value) && method_exists($value, 'toArray') => Json::encode($value->toArray()), + is_object($value) && method_exists($value, 'toSchema') => Json::encode($value->toSchema()), + is_object($value) && method_exists($value, 'toOutputSchema') => Json::encode($value->toOutputSchema()), + is_object($value) && property_exists($value, 'value') => $value->value(), + is_object($value) => Json::encode($value), + default => $value, + }; + $parameterValues[$key] = $value; + } + return $parameterValues; + } +} diff --git a/src/Utils/Web/Webpage.php b/src/Utils/Web/Webpage.php index 7d22a35e..777dd2ed 100644 --- a/src/Utils/Web/Webpage.php +++ b/src/Utils/Web/Webpage.php @@ -1,34 +1,34 @@ -scraper = $scraper ?? new BasicReader(); - $this->htmlProcessor = new HtmlProcessor(); - } - - public function toMessage(): Message { - return new Message(content: $this->asMarkdown()); - } -} +scraper = $scraper ?? new BasicReader(); + $this->htmlProcessor = new HtmlProcessor(); + } + + public function toMessage(): Message { + return new Message(content: $this->asMarkdown()); + } +}