diff --git a/config/llm.php b/config/llm.php
index 55b77ba2..85a1465c 100644
--- a/config/llm.php
+++ b/config/llm.php
@@ -1,137 +1,145 @@
- false,
- 'defaultConnection' => 'openai',
-
- 'defaultToolName' => 'extracted_data',
- 'defaultToolDescription' => 'Function call based on user instructions.',
- 'defaultRetryPrompt' => "JSON generated incorrectly, fix following errors:\n",
- 'defaultMdJsonPrompt' => "Response must validate against this JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object within a ```json {} ``` codeblock.\n",
- 'defaultJsonPrompt' => "Response must follow JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object.\n",
- 'defaultToolsPrompt' => "Extract correct and accurate data from the input using provided tools.\n",
-
- 'connections' => [
- 'anthropic' => [
- 'providerType' => LLMProviderType::Anthropic->value,
- 'apiUrl' => 'https://api.anthropic.com/v1',
- 'apiKey' => Env::get('ANTHROPIC_API_KEY', ''),
- 'endpoint' => '/messages',
- 'metadata' => [
- 'apiVersion' => '2023-06-01',
- 'beta' => 'prompt-caching-2024-07-31',
- ],
- 'defaultModel' => 'claude-3-haiku-20240307',
- 'defaultMaxTokens' => 1024,
- ],
- 'azure' => [
- 'providerType' => LLMProviderType::Azure->value,
- 'apiUrl' => 'https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}',
- 'apiKey' => Env::get('AZURE_OPENAI_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'metadata' => [
- 'apiVersion' => '2024-08-01-preview',
- 'resourceName' => 'instructor-dev',
- 'deploymentId' => 'gpt-4o-mini',
- ],
- 'defaultModel' => 'gpt-4o-mini',
- 'defaultMaxTokens' => 1024,
- ],
- 'cohere1' => [
- 'providerType' => LLMProviderType::CohereV1->value,
- 'apiUrl' => 'https://api.cohere.ai/v1',
- 'apiKey' => Env::get('COHERE_API_KEY', ''),
- 'endpoint' => '/chat',
- 'defaultModel' => 'command-r-plus-08-2024',
- 'defaultMaxTokens' => 1024,
- ],
- 'cohere2' => [
- 'providerType' => LLMProviderType::CohereV2->value,
- 'apiUrl' => 'https://api.cohere.ai/v2',
- 'apiKey' => Env::get('COHERE_API_KEY', ''),
- 'endpoint' => '/chat',
- 'defaultModel' => 'command-r-plus-08-2024',
- 'defaultMaxTokens' => 1024,
- ],
- 'fireworks' => [
- 'providerType' => LLMProviderType::Fireworks->value,
- 'apiUrl' => 'https://api.fireworks.ai/inference/v1',
- 'apiKey' => Env::get('FIREWORKS_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'accounts/fireworks/models/mixtral-8x7b-instruct',
- 'defaultMaxTokens' => 1024,
- ],
- 'gemini' => [
- 'providerType' => LLMProviderType::Gemini->value,
- 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta',
- 'apiKey' => Env::get('GEMINI_API_KEY', ''),
- 'endpoint' => '/models/{model}:generateContent',
- 'defaultModel' => 'gemini-1.5-flash-latest',
- 'defaultMaxTokens' => 1024,
- ],
- 'grok' => [
- 'providerType' => LLMProviderType::Grok->value,
- 'apiUrl' => 'https://api.x.ai/v1',
- 'apiKey' => Env::get('GROK_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'grok-beta',
- 'defaultMaxTokens' => 1024,
- ],
- 'groq' => [
- 'providerType' => LLMProviderType::Groq->value,
- 'apiUrl' => 'https://api.groq.com/openai/v1',
- 'apiKey' => Env::get('GROQ_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'llama3-groq-8b-8192-tool-use-preview', // 'gemma2-9b-it',
- 'defaultMaxTokens' => 1024,
- ],
- 'mistral' => [
- 'providerType' => LLMProviderType::Mistral->value,
- 'apiUrl' => 'https://api.mistral.ai/v1',
- 'apiKey' => Env::get('MISTRAL_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'mistral-small-latest',
- 'defaultMaxTokens' => 1024,
- ],
- 'ollama' => [
- 'providerType' => LLMProviderType::Ollama->value,
- 'apiUrl' => 'http://localhost:11434/v1',
- 'apiKey' => Env::get('OLLAMA_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'qwen2.5:0.5b', //'gemma2:2b',
- 'defaultMaxTokens' => 1024,
- 'httpClient' => 'http-ollama',
- ],
- 'openai' => [
- 'providerType' => LLMProviderType::OpenAI->value,
- 'apiUrl' => 'https://api.openai.com/v1',
- 'apiKey' => Env::get('OPENAI_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'metadata' => [
- 'organization' => '',
- 'project' => '',
- ],
- 'defaultModel' => 'gpt-4o-mini',
- 'defaultMaxTokens' => 1024,
- ],
- 'openrouter' => [
- 'providerType' => LLMProviderType::OpenRouter->value,
- 'apiUrl' => 'https://openrouter.ai/api/v1',
- 'apiKey' => Env::get('OPENROUTER_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'qwen/qwen-2.5-72b-instruct', //'microsoft/phi-3.5-mini-128k-instruct',
- 'defaultMaxTokens' => 1024,
- ],
- 'together' => [
- 'providerType' => LLMProviderType::Together->value,
- 'apiUrl' => 'https://api.together.xyz/v1',
- 'apiKey' => Env::get('TOGETHER_API_KEY', ''),
- 'endpoint' => '/chat/completions',
- 'defaultModel' => 'mistralai/Mixtral-8x7B-Instruct-v0.1',
- 'defaultMaxTokens' => 1024,
- ],
- ],
-];
+ false,
+ 'defaultConnection' => 'openai',
+
+ 'defaultToolName' => 'extracted_data',
+ 'defaultToolDescription' => 'Function call based on user instructions.',
+ 'defaultRetryPrompt' => "JSON generated incorrectly, fix following errors:\n",
+ 'defaultMdJsonPrompt' => "Response must validate against this JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object within a ```json {} ``` codeblock.\n",
+ 'defaultJsonPrompt' => "Response must follow JSON Schema:\n<|json_schema|>\n. Respond correctly with strict JSON object.\n",
+ 'defaultToolsPrompt' => "Extract correct and accurate data from the input using provided tools.\n",
+
+ 'connections' => [
+ 'anthropic' => [
+ 'providerType' => LLMProviderType::Anthropic->value,
+ 'apiUrl' => 'https://api.anthropic.com/v1',
+ 'apiKey' => Env::get('ANTHROPIC_API_KEY', ''),
+ 'endpoint' => '/messages',
+ 'metadata' => [
+ 'apiVersion' => '2023-06-01',
+ 'beta' => 'prompt-caching-2024-07-31',
+ ],
+ 'defaultModel' => 'claude-3-haiku-20240307',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'azure' => [
+ 'providerType' => LLMProviderType::Azure->value,
+ 'apiUrl' => 'https://{resourceName}.openai.azure.com/openai/deployments/{deploymentId}',
+ 'apiKey' => Env::get('AZURE_OPENAI_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'metadata' => [
+ 'apiVersion' => '2024-08-01-preview',
+ 'resourceName' => 'instructor-dev',
+ 'deploymentId' => 'gpt-4o-mini',
+ ],
+ 'defaultModel' => 'gpt-4o-mini',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'cohere1' => [
+ 'providerType' => LLMProviderType::CohereV1->value,
+ 'apiUrl' => 'https://api.cohere.ai/v1',
+ 'apiKey' => Env::get('COHERE_API_KEY', ''),
+ 'endpoint' => '/chat',
+ 'defaultModel' => 'command-r-plus-08-2024',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'cohere2' => [
+ 'providerType' => LLMProviderType::CohereV2->value,
+ 'apiUrl' => 'https://api.cohere.ai/v2',
+ 'apiKey' => Env::get('COHERE_API_KEY', ''),
+ 'endpoint' => '/chat',
+ 'defaultModel' => 'command-r-plus-08-2024',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'fireworks' => [
+ 'providerType' => LLMProviderType::Fireworks->value,
+ 'apiUrl' => 'https://api.fireworks.ai/inference/v1',
+ 'apiKey' => Env::get('FIREWORKS_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'accounts/fireworks/models/mixtral-8x7b-instruct',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'gemini' => [
+ 'providerType' => LLMProviderType::Gemini->value,
+ 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta',
+ 'apiKey' => Env::get('GEMINI_API_KEY', ''),
+ 'endpoint' => '/models/{model}:generateContent',
+ 'defaultModel' => 'gemini-1.5-flash-latest',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'gemini-oai' => [
+ 'providerType' => LLMProviderType::GeminiOAI->value,
+ 'apiUrl' => 'https://generativelanguage.googleapis.com/v1beta/openai',
+ 'apiKey' => Env::get('GEMINI_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'gemini-1.5-flash',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'grok' => [
+ 'providerType' => LLMProviderType::Grok->value,
+ 'apiUrl' => 'https://api.x.ai/v1',
+ 'apiKey' => Env::get('GROK_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'grok-beta',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'groq' => [
+ 'providerType' => LLMProviderType::Groq->value,
+ 'apiUrl' => 'https://api.groq.com/openai/v1',
+ 'apiKey' => Env::get('GROQ_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'llama3-groq-8b-8192-tool-use-preview', // 'gemma2-9b-it',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'mistral' => [
+ 'providerType' => LLMProviderType::Mistral->value,
+ 'apiUrl' => 'https://api.mistral.ai/v1',
+ 'apiKey' => Env::get('MISTRAL_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'mistral-small-latest',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'ollama' => [
+ 'providerType' => LLMProviderType::Ollama->value,
+ 'apiUrl' => 'http://localhost:11434/v1',
+ 'apiKey' => Env::get('OLLAMA_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'qwen2.5-coder:3b', //'gemma2:2b',
+ 'defaultMaxTokens' => 1024,
+ 'httpClient' => 'http-ollama',
+ ],
+ 'openai' => [
+ 'providerType' => LLMProviderType::OpenAI->value,
+ 'apiUrl' => 'https://api.openai.com/v1',
+ 'apiKey' => Env::get('OPENAI_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'metadata' => [
+ 'organization' => '',
+ 'project' => '',
+ ],
+ 'defaultModel' => 'gpt-4o-mini',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'openrouter' => [
+ 'providerType' => LLMProviderType::OpenRouter->value,
+ 'apiUrl' => 'https://openrouter.ai/api/v1',
+ 'apiKey' => Env::get('OPENROUTER_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'qwen/qwen-2.5-72b-instruct', //'microsoft/phi-3.5-mini-128k-instruct',
+ 'defaultMaxTokens' => 1024,
+ ],
+ 'together' => [
+ 'providerType' => LLMProviderType::Together->value,
+ 'apiUrl' => 'https://api.together.xyz/v1',
+ 'apiKey' => Env::get('TOGETHER_API_KEY', ''),
+ 'endpoint' => '/chat/completions',
+ 'defaultModel' => 'mistralai/Mixtral-8x7B-Instruct-v0.1',
+ 'defaultMaxTokens' => 1024,
+ ],
+ ],
+];
diff --git a/evals/LLMModes/CompanyEval.php b/evals/LLMModes/CompanyEval.php
index b70ef744..29687592 100644
--- a/evals/LLMModes/CompanyEval.php
+++ b/evals/LLMModes/CompanyEval.php
@@ -6,6 +6,7 @@
use Cognesy\Instructor\Extras\Evals\Contracts\CanGenerateObservations;
use Cognesy\Instructor\Extras\Evals\Execution;
use Cognesy\Instructor\Extras\Evals\Observation;
+use Cognesy\Instructor\Features\LLM\Data\ToolCall;
use Cognesy\Instructor\Utils\Str;
class CompanyEval implements CanGenerateObservations
@@ -50,10 +51,14 @@ public function correctness(Execution $execution): Observation {
}
private function validateToolsData(Execution $execution) : bool {
- $data = $execution->get('response')->toolsData[0] ?? [];
- return 'store_company' === ($data['name'] ?? '')
- && 'ACME' === ($data['arguments']['name'] ?? '')
- && 2020 === (int) ($data['arguments']['year'] ?? 0);
+ /** @var ToolCall $toolCall */
+ $toolCall = $execution->get('response')->toolCalls?->first();
+ if (null === $toolCall) {
+ return false;
+ }
+ return 'store_company' === $toolCall->name()
+ && 'ACME' === $toolCall->stringValue('name')
+ && 2020 === $toolCall->intValue('year');
}
private function validateDefault(Execution $execution) : bool {
diff --git a/examples/A02_Advanced/CustomPrompts/run.php b/examples/A02_Advanced/CustomPrompts/run.php
index cc0ffec8..b15f3312 100644
--- a/examples/A02_Advanced/CustomPrompts/run.php
+++ b/examples/A02_Advanced/CustomPrompts/run.php
@@ -1,68 +1,69 @@
----
-title: 'Custom prompts'
-docname: 'custom_prompts'
----
-
-## Overview
-
-In case you want to take control over the prompts sent by Instructor
-to LLM for different modes, you can use the `prompt` parameter in the
-`request()` or `respond()` methods.
-
-It will override the default Instructor prompts, allowing you to fully
-customize how LLM is instructed to process the input.
-
-## Example
-
-```php
-add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
-
-use Cognesy\Instructor\Enums\Mode;
-use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM;
-use Cognesy\Instructor\Instructor;
-
-class User {
- public int $age;
- public string $name;
-}
-
-$instructor = (new Instructor)
- ->onEvent(RequestSentToLLM::class, fn(RequestSentToLLM $event)=>dump($event->request->toMessages()));
-
-print("\n# Request for Mode::Tools:\n\n");
-$user = $instructor
- ->respond(
- messages: "Our user Jason is 25 years old.",
- responseModel: User::class,
- prompt: "\nYour task is to extract correct and accurate data from the messages using provided tools.\n",
- mode: Mode::Tools
- );
-echo "\nRESPONSE:\n";
-dump($user);
-
-print("\n# Request for Mode::Json:\n\n");
-$user = $instructor
- ->respond(
- messages: "Our user Jason is 25 years old.",
- responseModel: User::class,
- prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n<|json_schema|>\n",
- mode: Mode::Json
- );
-echo "\nRESPONSE:\n";
-dump($user);
-
-print("\n# Request for Mode::MdJson:\n\n");
-$user = $instructor
- ->respond(
- messages: "Our user Jason is 25 years old.",
- responseModel: User::class,
- prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n",
- mode: Mode::MdJson
- );
-echo "\nRESPONSE:\n";
-dump($user);
-
-?>
-```
+---
+title: 'Custom prompts'
+docname: 'custom_prompts'
+---
+
+## Overview
+
+In case you want to take control over the prompts sent by Instructor
+to LLM for different modes, you can use the `prompt` parameter in the
+`request()` or `respond()` methods.
+
+It will override the default Instructor prompts, allowing you to fully
+customize how LLM is instructed to process the input.
+
+## Example
+
+```php
+add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
+
+use Cognesy\Instructor\Enums\Mode;
+use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM;
+use Cognesy\Instructor\Instructor;
+
+class User {
+ public int $age;
+ public string $name;
+}
+
+$instructor = (new Instructor)
+ // let's dump the request data to see how customized prompts look like in requests
+ ->onEvent(RequestSentToLLM::class, fn(RequestSentToLLM $event) => dump($event));
+
+print("\n# Request for Mode::Tools:\n\n");
+$user = $instructor
+ ->respond(
+ messages: "Our user Jason is 25 years old.",
+ responseModel: User::class,
+ prompt: "\nYour task is to extract correct and accurate data from the messages using provided tools.\n",
+ mode: Mode::Tools
+ );
+echo "\nRESPONSE:\n";
+dump($user);
+
+print("\n# Request for Mode::Json:\n\n");
+$user = $instructor
+ ->respond(
+ messages: "Our user Jason is 25 years old.",
+ responseModel: User::class,
+ prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n<|json_schema|>\n",
+ mode: Mode::Json
+ );
+echo "\nRESPONSE:\n";
+dump($user);
+
+print("\n# Request for Mode::MdJson:\n\n");
+$user = $instructor
+ ->respond(
+ messages: "Our user Jason is 25 years old.",
+ responseModel: User::class,
+ prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n",
+ mode: Mode::MdJson
+ );
+echo "\nRESPONSE:\n";
+dump($user);
+
+?>
+```
diff --git a/examples/A02_Advanced/ProvidingExamples/run.php b/examples/A02_Advanced/ProvidingExamples/run.php
index 2a9e4128..017be1dc 100644
--- a/examples/A02_Advanced/ProvidingExamples/run.php
+++ b/examples/A02_Advanced/ProvidingExamples/run.php
@@ -1,57 +1,58 @@
----
-title: 'Providing example inputs and outputs'
-docname: 'demonstrations'
----
-
-## Overview
-
-To improve the results of LLM inference you can provide examples of the expected output.
-This will help LLM to understand the context and the expected structure of the output.
-
-It is typically useful in the `Mode::Json` and `Mode::MdJson` modes, where the output
-is expected to be a JSON object.
-
-
-## Example
-
-```php
-add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
-
-use Cognesy\Instructor\Enums\Mode;
-use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM;
-use Cognesy\Instructor\Features\Core\Data\Example;
-use Cognesy\Instructor\Instructor;
-
-class User {
- public int $age;
- public string $name;
-}
-
-echo "\nREQUEST:\n";
-$user = (new Instructor)
- ->onEvent(RequestSentToLLM::class, fn($event)=>dump($event->request->toMessages()))
- ->request(
- messages: "Our user Jason is 25 years old.",
- responseModel: User::class,
- examples: [
- new Example(
- input: "John is 50 and works as a teacher.",
- output: ['name' => 'John', 'age' => 50]
- ),
- new Example(
- input: "We have recently hired Ian, who is 27 years old.",
- output: ['name' => 'Ian', 'age' => 27],
- template: "example input:\n<|input|>\noutput:\n```json\n<|output|>\n```\n",
- ),
- ],
- mode: Mode::Json)
- ->get();
-
-echo "\nOUTPUT:\n";
-dump($user);
-assert($user->name === 'Jason');
-assert($user->age === 25);
-?>
-```
+---
+title: 'Providing example inputs and outputs'
+docname: 'demonstrations'
+---
+
+## Overview
+
+To improve the results of LLM inference you can provide examples of the expected output.
+This will help LLM to understand the context and the expected structure of the output.
+
+It is typically useful in the `Mode::Json` and `Mode::MdJson` modes, where the output
+is expected to be a JSON object.
+
+
+## Example
+
+```php
+add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
+
+use Cognesy\Instructor\Enums\Mode;
+use Cognesy\Instructor\Events\HttpClient\RequestSentToLLM;
+use Cognesy\Instructor\Features\Core\Data\Example;
+use Cognesy\Instructor\Instructor;
+
+class User {
+ public int $age;
+ public string $name;
+}
+
+echo "\nREQUEST:\n";
+$user = (new Instructor)
+ // let's dump the request data to see how examples are used in requests
+ ->onEvent(RequestSentToLLM::class, fn($event) => dump($event))
+ ->request(
+ messages: "Our user Jason is 25 years old.",
+ responseModel: User::class,
+ examples: [
+ new Example(
+ input: "John is 50 and works as a teacher.",
+ output: ['name' => 'John', 'age' => 50]
+ ),
+ new Example(
+ input: "We have recently hired Ian, who is 27 years old.",
+ output: ['name' => 'Ian', 'age' => 27],
+ template: "example input:\n<|input|>\noutput:\n```json\n<|output|>\n```\n",
+ ),
+ ],
+ mode: Mode::Json)
+ ->get();
+
+echo "\nOUTPUT:\n";
+dump($user);
+assert($user->name === 'Jason');
+assert($user->age === 25);
+?>
+```
diff --git a/examples/A05_Extras/LLMToolUse/run.php b/examples/A05_Extras/LLMToolUse/run.php
new file mode 100644
index 00000000..17e8e213
--- /dev/null
+++ b/examples/A05_Extras/LLMToolUse/run.php
@@ -0,0 +1,209 @@
+---
+title: 'Inference and tool use'
+docname: 'llm_tool_use'
+---
+
+## Overview
+
+
+## Example
+
+
+```php
+add('Cognesy\\Instructor\\', __DIR__ . '../../src/');
+
+function add_numbers($a, $b) : int {
+ return $a + $b;
+}
+
+function subtract_numbers($a, $b) : int {
+ return $a - $b;
+}
+
+$tools = [
+ [
+ 'type' => 'function',
+ 'function' => [
+ 'name' => 'add_numbers',
+ 'description' => 'Add two numbers',
+ 'parameters' => [
+ 'type' => 'object',
+ 'description' => 'Numbers to add',
+ 'properties' => [
+ 'a' => [
+ 'type' => 'integer',
+ 'description' => 'First number',
+ ],
+ 'b' => [
+ 'type' => 'integer',
+ 'description' => 'Second number',
+ ],
+ ],
+ 'required' => ['a', 'b'],
+ 'additionalProperties' => false,
+ ],
+ ],
+ ],
+ [
+ 'type' => 'function',
+ 'function' => [
+ 'name' => 'subtract_numbers',
+ 'description' => 'Subtract two numbers',
+ 'parameters' => [
+ 'type' => 'object',
+ 'description' => 'Numbers to subtract',
+ 'properties' => [
+ 'a' => [
+ 'type' => 'integer',
+ 'description' => 'First number',
+ ],
+ 'b' => [
+ 'type' => 'integer',
+ 'description' => 'Second number',
+ ],
+ ],
+ 'required' => ['a', 'b'],
+ 'additionalProperties' => false,
+ ],
+ ],
+ ]
+];
+
+$chat = (new Inference)
+ ->withConnection('openrouter');
+
+$prompt = 'Add 2455 and 3558 then subtract 4344 from the result.';
+
+$messages = [
+ ['role' => 'user', 'content' => $prompt]
+];
+
+Debug::enable();
+
+$toolUse = new ToolUse($chat, $tools);
+$response = $toolUse->response($messages);
+
+class ToolUse {
+ private Inference $inference;
+ private array $tools;
+ private int $maxDepth;
+ private array $options;
+ private bool $parallelToolCalls;
+ private string|array $toolChoice;
+ /** @var LLMResponse[] */
+ private array $responses = [];
+
+ public function __construct(
+ Inference $chat,
+ array $tools,
+ int $maxDepth = 3,
+ array $options = [],
+ string|array $toolChoice = 'auto',
+ bool $parallelToolCalls = false
+ ) {
+ $this->inference = $chat;
+ $this->tools = $tools;
+ $this->maxDepth = $maxDepth;
+ $this->parallelToolCalls = $parallelToolCalls;
+ $this->toolChoice = $toolChoice;
+ $this->options = $options;
+ }
+
+ public function response(string|array $messages) : LLMResponse {
+ $messages = match(true) {
+ is_string($messages) => [['role' => 'user', 'content' => $messages]],
+ is_array($messages) => $messages,
+ default => []
+ };
+ $chat = (new Messages)->appendMessages($messages);
+ return $this->tryGetResponse($chat);
+ }
+
+ /**
+ * @return LLMResponse[]
+ */
+ public function responses() : array {
+ return $this->responses;
+ }
+
+ // INTERNAL //////////////////////////////////////////////
+
+ private function tryGetResponse(Messages $chat) : LLMResponse {
+ $response = $this->getResponse($chat->toArray());
+ $this->responses[] = $response;
+ $depth = 0;
+ while ($this->tryContinue($response)) {
+ if ($depth++ >= $this->maxDepth) {
+ break;
+ }
+ $responseMessages = $this->makeToolsResponseMessages($response);
+ $chat->appendMessages($responseMessages);
+ $response = $this->getResponse($chat->toArray());
+ $this->responses[] = $response;
+ }
+ return $response;
+ }
+
+ private function getResponse(string|array $messages) : LLMResponse {
+ return $this->inference
+ ->create(
+ messages: $messages,
+ tools: $this->tools,
+ toolChoice: $this->toolChoice,
+ options: array_merge($this->options, ['parallel_tool_calls' => $this->parallelToolCalls]),
+ mode: Mode::Tools,
+ )->response();
+ }
+
+ private function makeToolsResponseMessages(LLMResponse $response) : array {
+ $messages = [];
+ $toolCalls = $response->toolCalls();
+ $count = 0;
+ foreach ($toolCalls->all() as $toolCall) {
+ $function = $toolCall->name();
+ $args = $toolCall->args();
+ $result = $function(...$args);
+ $resultString = match(true) {
+ is_string($result) => $result,
+ is_array($result) => Json::encode($result),
+ is_object($result) => Json::encode($result),
+ default => (string) $result,
+ };
+ $messages[] = [
+ 'role' => 'assistant',
+ '_metadata' => [
+ 'tool_calls' => [$toolCall->toToolCallArray()]
+ ]
+ ];
+ $messages[] = [
+ 'role' => 'tool',
+ 'content' => $resultString,
+ '_metadata' => [
+ 'tool_call_id' => $toolCall->id(),
+ 'tool_name' => $toolCall->name(),
+ 'result' => $result
+ ]
+ ];
+ $count++;
+ if ($this->parallelToolCalls && $count >= 0) {
+ break;
+ }
+ }
+ return $messages;
+ }
+
+ private function tryContinue(LLMResponse $response) : bool {
+ return $response->hasToolCalls() || LLMFinishReason::ToolCalls->equals($response->finishReason());
+ }
+}
\ No newline at end of file
diff --git a/prompts/system/mode_tools.twig b/prompts/system/mode_tools.twig
index 8b137891..3936d7e1 100644
--- a/prompts/system/mode_tools.twig
+++ b/prompts/system/mode_tools.twig
@@ -1 +1,45 @@
+
+
+ {{ request.system }}
+
+ {% if request.context_cache.prompt is not empty %}
+
+ TASK:
+ {{ request.context_cache.prompt }}
+
+ {% endif %}
+
+ {% if request.context_cache.examples is not empty %}
+
+ EXAMPLES:
+ {{ request.context_cache.examples }}
+
+ {% endif %}
+
+ {% if request.context_cache.input is not empty %}
+
+ INPUT:
+ {{ request.context_cache.input }}
+
+ {% endif %}
+
+ {% if request.context_cache is not empty %}
+
+
+ INSTRUCTIONS:
+
+
+ {% endif %}
+
+
+ {% for attempt in request.attempts %}
+
+ {{ attempt.llmResponse.content }}
+
+
+ Settings::get('llm', 'defaultRetryPrompt')
+ {{ attempt.errors }}
+
+ {% endfor %}
+
\ No newline at end of file
diff --git a/src/Extras/Prompt/Prompt.php b/src/Extras/Prompt/Prompt.php
index 7e904578..6579c3e4 100644
--- a/src/Extras/Prompt/Prompt.php
+++ b/src/Extras/Prompt/Prompt.php
@@ -244,18 +244,44 @@ private function getMessageContent(XmlElement $element) : array {
if ($child->tag() !== 'content') {
continue;
}
- // check if content type is text, image or audio
$type = $child->attribute('type', 'text');
$content[] = match($type) {
- 'image' => ['type' => 'image_url', 'image_url' => ['url' => $child->content()]],
- 'audio' => ['type' => 'input_audio', 'input_audio' => ['data' => $child->content(), 'format' => $child->attribute('format', 'mp3')]],
- 'text' => ['type' => 'text', 'text' => $child->content()],
- default => throw new InvalidArgumentException("Invalid content type: $type"),
+ 'image' => $this->makeImageContent($child),
+ 'audio' => $this->makeAudioContent($child),
+ default => $this->makeTextContent($child),
};
}
return $content;
}
+ private function makeTextContent(XmlElement $child) : array {
+ $hasCacheControl = $child->attribute('cache', false);
+ return array_filter([
+ 'type' => 'text',
+ 'text' => $child->content(),
+ 'cache_control' => $hasCacheControl ? ['type' => 'ephemeral'] : []
+ ]);
+ }
+
+ private function makeImageContent(XmlElement $child) : array {
+ return [
+ 'type' => 'image_url',
+ 'image_url' => [
+ 'url' => $child->content()
+ ]
+ ];
+ }
+
+ private function makeAudioContent(XmlElement $child) : array {
+ return [
+ 'type' => 'input_audio',
+ 'input_audio' => [
+ 'data' => $child->content(),
+ 'format' => $child->attribute('format', 'mp3')
+ ]
+ ];
+ }
+
private function validateVariables(array $infoVars, array $templateVars, array $valueKeys) : array {
$messages = [];
foreach($infoVars as $var) {
diff --git a/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php b/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php
index afdb5b6d..25464fa5 100644
--- a/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php
+++ b/src/Features/Core/Data/Traits/ChatTemplate/HandlesSections.php
@@ -1,84 +1,84 @@
-appendMessage(['role' => 'system', 'content' => $system]);
- }
-
- // EXTRACT SYSTEM ROLE FROM MESSAGES - until first non-system message
- foreach ($messages as $message) {
- if ($message['role'] !== 'system') {
- break;
- }
- $output->appendMessage(['role' => 'system', 'content' => $message['content']]);
- }
-
- return $output;
- }
-
- protected function makeMessages(string|array $messages) : Messages {
- $output = new Messages();
- if (empty($messages)) {
- return $output;
- }
-
- // skip system messages
- $index = 0;
- foreach ($messages as $message) {
- if ($message['role'] !== 'system') {
- break;
- }
- $index++;
- }
- $output->appendMessages(array_slice($messages, $index));
- return $output;
- }
-
- protected function makeExamples(array $examples) : Messages {
- $messages = new Messages();
- if (empty($examples)) {
- return $messages;
- }
- foreach ($examples as $item) {
- $example = match(true) {
- is_array($item) => Example::fromArray($item),
- is_string($item) => Example::fromJson($item),
- $item instanceof Example => $item,
- default => throw new Exception('Invalid example type'),
- };
- $messages->appendMessages($example->toMessages());
- }
- return $messages;
- }
-
- protected function makePrompt(string $prompt) : Message {
- return new Message(
- role: 'user',
- content: $prompt
- ?: $this->defaultPrompts[$this->request->mode()->value]
- ?? ''
- );
- }
-
- protected function makeInput(array|object|string $input) : Messages {
- if (empty($input)) {
- return new Messages();
- }
- return Messages::fromInput($input);
- }
-
- protected function makeJsonSchema() : array {
- return $this->request->responseModel()?->toJsonSchema();
- }
+appendMessage(['role' => 'system', 'content' => $system]);
+ }
+
+ // EXTRACT SYSTEM ROLE FROM MESSAGES - until first non-system message
+ foreach ($messages as $message) {
+ if ($message['role'] !== 'system') {
+ break;
+ }
+ $output->appendMessage($message);
+ }
+
+ return $output;
+ }
+
+ protected function makeMessages(string|array $messages) : Messages {
+ $output = new Messages();
+ if (empty($messages)) {
+ return $output;
+ }
+
+ // skip system messages
+ $index = 0;
+ foreach ($messages as $message) {
+ if ($message['role'] !== 'system') {
+ break;
+ }
+ $index++;
+ }
+ $output->appendMessages(array_slice($messages, $index));
+ return $output;
+ }
+
+ protected function makeExamples(array $examples) : Messages {
+ $messages = new Messages();
+ if (empty($examples)) {
+ return $messages;
+ }
+ foreach ($examples as $item) {
+ $example = match(true) {
+ is_array($item) => Example::fromArray($item),
+ is_string($item) => Example::fromJson($item),
+ $item instanceof Example => $item,
+ default => throw new Exception('Invalid example type'),
+ };
+ $messages->appendMessages($example->toMessages());
+ }
+ return $messages;
+ }
+
+ protected function makePrompt(string $prompt) : Message {
+ return new Message(
+ role: 'user',
+ content: $prompt
+ ?: $this->defaultPrompts[$this->request->mode()->value]
+ ?? ''
+ );
+ }
+
+ protected function makeInput(array|object|string $input) : Messages {
+ if (empty($input)) {
+ return new Messages();
+ }
+ return Messages::fromInput($input);
+ }
+
+ protected function makeJsonSchema() : array {
+ return $this->request->responseModel()?->toJsonSchema();
+ }
}
\ No newline at end of file
diff --git a/src/Features/Core/PartialsGenerator.php b/src/Features/Core/PartialsGenerator.php
index 5ecc4c31..de179afa 100644
--- a/src/Features/Core/PartialsGenerator.php
+++ b/src/Features/Core/PartialsGenerator.php
@@ -1,206 +1,206 @@
-toolCalls = new ToolCalls();
- $this->sequenceableHandler = new SequenceableHandler($events);
- }
-
- public function resetPartialResponse() : void {
- $this->previousHash = '';
- $this->responseText = '';
- $this->responseJson = '';
- $this->sequenceableHandler->reset();
- $this->toolCalls->reset();
- }
-
- /**
- * @param Generator $stream
- * @param ResponseModel $responseModel
- * @return Generator
- */
- public function getPartialResponses(Generator $stream, ResponseModel $responseModel) : Generator {
- // reset state
- $this->resetPartialResponse();
-
- // receive data
- /** @var PartialLLMResponse $partialResponse */
- foreach($stream as $partialResponse) {
- $this->events->dispatch(new StreamedResponseReceived($partialResponse));
- // store partial response
- $this->partialResponses[] = $partialResponse;
-
- // situation 1: new function call
- $maybeToolName = $partialResponse->toolName;
- // create next FC only if JSON buffer is not empty (which is the case for 1st iteration)
- if ($maybeToolName) {
- if (empty($this->responseJson)) {
- $this->newToolCall($response->toolName ?? $responseModel->toolName());
- } else {
- $this->finalizeToolCall($this->responseJson, $responseModel->toolName());
- $this->responseJson = ''; // reset json buffer
- }
- }
-
- // situation 2: new delta
- $maybeArgumentChunk = $partialResponse->contentDelta;
- if (empty($maybeArgumentChunk)) {
- continue;
- }
- $this->events->dispatch(new ChunkReceived($maybeArgumentChunk));
- $this->responseText .= $maybeArgumentChunk;
- $this->responseJson = Json::fromPartial($this->responseText)->toString();
- if (empty($this->responseJson)) {
- continue;
- }
- if ($this->toolCalls->empty()) {
- $this->newToolCall($responseModel->toolName());
- }
- $result = $this->handleDelta($this->responseJson, $responseModel);
- if ($result->isFailure()) {
- continue;
- }
- $this->events->dispatch(new PartialJsonReceived($this->responseJson));
-
- yield $partialResponse
- ->withValue($result->unwrap())
- ->withContent($this->responseText);
- }
- $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse()));
-
- // finalize last function call
- // check if there are any toolCalls
- if ($this->toolCalls->count() === 0) {
- throw new Exception('No tool calls found in the response');
- }
- // finalize last function call
- if ($this->toolCalls->count() > 0) {
- $this->finalizeToolCall(Json::from($this->responseText)->toString(), $responseModel->toolName());
- }
- // finalize sequenceable
- $this->sequenceableHandler->finalize();
- }
-
- // INTERNAL ////////////////////////////////////////////////////////
-
- protected function handleDelta(
- string $partialJson,
- ResponseModel $responseModel
- ) : Result {
- return Chain::make()
- ->through(fn() => $this->validatePartialResponse($partialJson, $responseModel, $this->preventJsonSchema, $this->matchToExpectedFields))
- ->tap(fn() => $this->events->dispatch(new PartialJsonReceived($partialJson)))
- ->tap(fn() => $this->updateToolCall($partialJson, $responseModel->toolName()))
- ->through(fn() => $this->tryGetPartialObject($partialJson, $responseModel))
- ->onFailure(fn($result) => $this->events->dispatch(
- new PartialResponseGenerationFailed(Arrays::asArray($result->error()))
- ))
- ->then(fn($result) => $this->getChangedOnly($result))
- ->result();
- }
-
- protected function tryGetPartialObject(
- string $partialJsonData,
- ResponseModel $responseModel,
- ) : Result {
- return Chain::from(fn() => Json::fromPartial($partialJsonData)->toString())
- ->through(fn($json) => $this->responseDeserializer->deserialize($json, $responseModel, $this?->toolCalls->last()->name))
- ->through(fn($object) => $this->responseTransformer->transform($object))
- ->result();
- }
-
- protected function getChangedOnly(Result $result) : ?Result {
- if ($result->isFailure()) {
- return $result;
- }
- $partialObject = $result->unwrap();
- // we only want to send partial response if it's different from the previous one
- $currentHash = hash('xxh3', Json::encode($partialObject));
- if ($this->previousHash == $currentHash) {
- return Result::failure('No changes detected');
- }
- $this->events->dispatch(new PartialResponseGenerated($partialObject));
- if (($partialObject instanceof Sequenceable)) {
- $this->sequenceableHandler->update($partialObject);
- }
- $this->previousHash = $currentHash;
- return $result;
- }
-
- public function getCompleteResponse() : LLMResponse {
- return LLMResponse::fromPartialResponses($this->partialResponses);
- }
-
- public function lastPartialResponse() : PartialLLMResponse {
- $index = count($this->partialResponses) - 1;
- return $this->partialResponses[$index];
- }
-
- public function partialResponses() : array {
- return $this->partialResponses;
- }
-
- protected function newToolCall(string $name) : ToolCall {
- $newToolCall = $this->toolCalls->create($name);
- $this->events->dispatch(new StreamedToolCallStarted($newToolCall));
- return $newToolCall;
- }
-
- protected function updateToolCall(string $responseJson, string $defaultName) : ToolCall {
- $updatedToolCall = $this->toolCalls->updateLast($responseJson, $defaultName);
- $this->events->dispatch(new StreamedToolCallUpdated($updatedToolCall));
- return $updatedToolCall;
- }
-
- protected function finalizeToolCall(string $responseJson, string $defaultName) : ToolCall {
- $finalizedToolCall = $this->toolCalls->finalizeLast($responseJson, $defaultName);
- $this->events->dispatch(new StreamedToolCallCompleted($finalizedToolCall));
- return $finalizedToolCall;
- }
+toolCalls = new ToolCalls();
+ $this->sequenceableHandler = new SequenceableHandler($events);
+ }
+
+ public function resetPartialResponse() : void {
+ $this->previousHash = '';
+ $this->responseText = '';
+ $this->responseJson = '';
+ $this->sequenceableHandler->reset();
+ $this->toolCalls->reset();
+ }
+
+ /**
+ * @param Generator $stream
+ * @param ResponseModel $responseModel
+ * @return Generator
+ */
+ public function getPartialResponses(Generator $stream, ResponseModel $responseModel) : Generator {
+ // reset state
+ $this->resetPartialResponse();
+
+ // receive data
+ /** @var PartialLLMResponse $partialResponse */
+ foreach($stream as $partialResponse) {
+ $this->events->dispatch(new StreamedResponseReceived($partialResponse));
+ // store partial response
+ $this->partialResponses[] = $partialResponse;
+
+ // situation 1: new function call
+ $maybeToolName = $partialResponse->toolName;
+ // create next FC only if JSON buffer is not empty (which is the case for 1st iteration)
+ if ($maybeToolName) {
+ if (empty($this->responseJson)) {
+ $this->newToolCall($response->toolName ?? $responseModel->toolName());
+ } else {
+ $this->finalizeToolCall($this->responseJson, $responseModel->toolName());
+ $this->responseJson = ''; // reset json buffer
+ }
+ }
+
+ // situation 2: new delta
+ $maybeArgumentChunk = $partialResponse->contentDelta;
+ if (empty($maybeArgumentChunk)) {
+ continue;
+ }
+ $this->events->dispatch(new ChunkReceived($maybeArgumentChunk));
+ $this->responseText .= $maybeArgumentChunk;
+ $this->responseJson = Json::fromPartial($this->responseText)->toString();
+ if (empty($this->responseJson)) {
+ continue;
+ }
+ if ($this->toolCalls->empty()) {
+ $this->newToolCall($responseModel->toolName());
+ }
+ $result = $this->handleDelta($this->responseJson, $responseModel);
+ if ($result->isFailure()) {
+ continue;
+ }
+ $this->events->dispatch(new PartialJsonReceived($this->responseJson));
+
+ yield $partialResponse
+ ->withValue($result->unwrap())
+ ->withContent($this->responseText);
+ }
+ $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse()));
+
+ // finalize last function call
+ // check if there are any toolCalls
+ if ($this->toolCalls->count() === 0) {
+ throw new Exception('No tool calls found in the response');
+ }
+ // finalize last function call
+ if ($this->toolCalls->count() > 0) {
+ $this->finalizeToolCall(Json::from($this->responseText)->toString(), $responseModel->toolName());
+ }
+ // finalize sequenceable
+ $this->sequenceableHandler->finalize();
+ }
+
+ // INTERNAL ////////////////////////////////////////////////////////
+
+ protected function handleDelta(
+ string $partialJson,
+ ResponseModel $responseModel
+ ) : Result {
+ return Chain::make()
+ ->through(fn() => $this->validatePartialResponse($partialJson, $responseModel, $this->preventJsonSchema, $this->matchToExpectedFields))
+ ->tap(fn() => $this->events->dispatch(new PartialJsonReceived($partialJson)))
+ ->tap(fn() => $this->updateToolCall($partialJson, $responseModel->toolName()))
+ ->through(fn() => $this->tryGetPartialObject($partialJson, $responseModel))
+ ->onFailure(fn($result) => $this->events->dispatch(
+ new PartialResponseGenerationFailed(Arrays::asArray($result->error()))
+ ))
+ ->then(fn($result) => $this->getChangedOnly($result))
+ ->result();
+ }
+
+ protected function tryGetPartialObject(
+ string $partialJsonData,
+ ResponseModel $responseModel,
+ ) : Result {
+ return Chain::from(fn() => Json::fromPartial($partialJsonData)->toString())
+ ->through(fn($json) => $this->responseDeserializer->deserialize($json, $responseModel, $this->toolCalls->last()?->name()))
+ ->through(fn($object) => $this->responseTransformer->transform($object))
+ ->result();
+ }
+
+ protected function getChangedOnly(Result $result) : ?Result {
+ if ($result->isFailure()) {
+ return $result;
+ }
+ $partialObject = $result->unwrap();
+ // we only want to send partial response if it's different from the previous one
+ $currentHash = hash('xxh3', Json::encode($partialObject));
+ if ($this->previousHash == $currentHash) {
+ return Result::failure('No changes detected');
+ }
+ $this->events->dispatch(new PartialResponseGenerated($partialObject));
+ if (($partialObject instanceof Sequenceable)) {
+ $this->sequenceableHandler->update($partialObject);
+ }
+ $this->previousHash = $currentHash;
+ return $result;
+ }
+
+ public function getCompleteResponse() : LLMResponse {
+ return LLMResponse::fromPartialResponses($this->partialResponses);
+ }
+
+ public function lastPartialResponse() : PartialLLMResponse {
+ $index = count($this->partialResponses) - 1;
+ return $this->partialResponses[$index];
+ }
+
+ public function partialResponses() : array {
+ return $this->partialResponses;
+ }
+
+ protected function newToolCall(string $name) : ToolCall {
+ $newToolCall = $this->toolCalls->add($name);
+ $this->events->dispatch(new StreamedToolCallStarted($newToolCall));
+ return $newToolCall;
+ }
+
+ protected function updateToolCall(string $responseJson, string $defaultName) : ToolCall {
+ $updatedToolCall = $this->toolCalls->updateLast($responseJson, $defaultName);
+ $this->events->dispatch(new StreamedToolCallUpdated($updatedToolCall));
+ return $updatedToolCall;
+ }
+
+ protected function finalizeToolCall(string $responseJson, string $defaultName) : ToolCall {
+ $finalizedToolCall = $this->toolCalls->finalizeLast($responseJson, $defaultName);
+ $this->events->dispatch(new StreamedToolCallCompleted($finalizedToolCall));
+ return $finalizedToolCall;
+ }
}
\ No newline at end of file
diff --git a/src/Features/Core/RequestHandler.php b/src/Features/Core/RequestHandler.php
index ba91ea79..5d4356a2 100644
--- a/src/Features/Core/RequestHandler.php
+++ b/src/Features/Core/RequestHandler.php
@@ -1,154 +1,154 @@
-events = $events;
- $this->retries = 0;
- $this->errors = [];
- }
-
- // PUBLIC //////////////////////////////////////////////////////////////
-
- /**
- * Generates response value
- */
- public function responseFor(Request $request) : LLMResponse {
- $processingResult = Result::failure("No response generated");
- while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
- $llmResponse = $this->getInference($request)->response();
- $llmResponse->content = match($request->mode()) {
- Mode::Text => $llmResponse->content,
- Mode::Tools => $llmResponse->toolCalls->first()?->args
- ?? $llmResponse->content // fallback if no tool calls - some LLMs return just a string
- ?? '',
- default => Json::from($llmResponse->content)->toString(),
- };
- $partialResponses = [];
- $processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
- }
-
- $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);
-
- return $llmResponse->withValue($value);
- }
-
- /**
- * Yields response value versions based on streamed responses
- * @param Request $request
- * @return Generator
- */
- public function streamResponseFor(Request $request) : Generator {
- $processingResult = Result::failure("No response generated");
- while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
- $stream = $this->getInference($request)->stream()->responses();
- yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel());
-
- $llmResponse = $this->partialsGenerator->getCompleteResponse();
- $partialResponses = $this->partialsGenerator->partialResponses();
- $processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
- }
-
- $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);
-
- yield $llmResponse->withValue($value);
- }
-
- // INTERNAL ///////////////////////////////////////////////////////////
-
- protected function getInference(Request $request) : InferenceResponse {
- $inference = new Inference(
- connection: $this->connection,
- httpClient: $this->httpClient,
- driver: $this->driver,
- events: $this->events,
- );
- return $inference
- ->create(
- $request->toMessages(),
- $request->model(),
- $request->toolCallSchema(),
- $request->toolChoice(),
- $request->responseFormat(),
- $request->options(),
- $request->mode()
- );
- }
-
- protected function processResponse(Request $request, LLMResponse $llmResponse, array $partialResponses) : Result {
- // we have LLMResponse here - let's process it: deserialize, validate, transform
- $processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel());
-
- if ($processingResult->isFailure()) {
- // retry - we have not managed to deserialize, validate or transform the response
- $this->handleError($processingResult, $request, $llmResponse, $partialResponses);
- }
-
- return $processingResult;
- }
-
- protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed {
- if ($processingResult->isFailure()) {
- $this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors));
- throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors));
- }
-
- // get final value
- $value = $processingResult->unwrap();
- // store response
- $request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts
- // notify on response generation
- $this->events->dispatch(new ResponseGenerated($value));
-
- return $value;
- }
-
- protected function handleError(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : void {
- $error = $processingResult->error();
- $this->errors = is_array($error) ? $error : [$error];
-
- // store failed response
- $request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts
- $this->retries++;
- if (!$this->maxRetriesReached($request)) {
- $this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors));
- }
- }
-
- protected function maxRetriesReached(Request $request) : bool {
- return $this->retries > $request->maxRetries();
- }
-}
+events = $events;
+ $this->retries = 0;
+ $this->errors = [];
+ }
+
+ // PUBLIC //////////////////////////////////////////////////////////////
+
+ /**
+ * Generates response value
+ */
+ public function responseFor(Request $request) : LLMResponse {
+ $processingResult = Result::failure("No response generated");
+ while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
+ $llmResponse = $this->getInference($request)->response();
+ $llmResponse->content = match($request->mode()) {
+ Mode::Text => $llmResponse->content,
+ Mode::Tools => $llmResponse->toolCalls()->first()?->argsAsJson()
+ ?? $llmResponse->content // fallback if no tool calls - some LLMs return just a string
+ ?? '',
+ default => Json::from($llmResponse->content)->toString(),
+ };
+ $partialResponses = [];
+ $processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
+ }
+
+ $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);
+
+ return $llmResponse->withValue($value);
+ }
+
+ /**
+ * Yields response value versions based on streamed responses
+ * @param Request $request
+ * @return Generator
+ */
+ public function streamResponseFor(Request $request) : Generator {
+ $processingResult = Result::failure("No response generated");
+ while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) {
+ $stream = $this->getInference($request)->stream()->responses();
+ yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel());
+
+ $llmResponse = $this->partialsGenerator->getCompleteResponse();
+ $partialResponses = $this->partialsGenerator->partialResponses();
+ $processingResult = $this->processResponse($request, $llmResponse, $partialResponses);
+ }
+
+ $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses);
+
+ yield $llmResponse->withValue($value);
+ }
+
+ // INTERNAL ///////////////////////////////////////////////////////////
+
+ protected function getInference(Request $request) : InferenceResponse {
+ $inference = new Inference(
+ connection: $this->connection,
+ httpClient: $this->httpClient,
+ driver: $this->driver,
+ events: $this->events,
+ );
+ return $inference
+ ->create(
+ $request->toMessages(),
+ $request->model(),
+ $request->toolCallSchema(),
+ $request->toolChoice(),
+ $request->responseFormat(),
+ $request->options(),
+ $request->mode()
+ );
+ }
+
+ protected function processResponse(Request $request, LLMResponse $llmResponse, array $partialResponses) : Result {
+ // we have LLMResponse here - let's process it: deserialize, validate, transform
+ $processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel());
+
+ if ($processingResult->isFailure()) {
+ // retry - we have not managed to deserialize, validate or transform the response
+ $this->handleError($processingResult, $request, $llmResponse, $partialResponses);
+ }
+
+ return $processingResult;
+ }
+
+ protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed {
+ if ($processingResult->isFailure()) {
+ $this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors));
+ throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors));
+ }
+
+ // get final value
+ $value = $processingResult->unwrap();
+ // store response
+ $request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts
+ // notify on response generation
+ $this->events->dispatch(new ResponseGenerated($value));
+
+ return $value;
+ }
+
+ protected function handleError(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : void {
+ $error = $processingResult->error();
+ $this->errors = is_array($error) ? $error : [$error];
+
+ // store failed response
+ $request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts
+ $this->retries++;
+ if (!$this->maxRetriesReached($request)) {
+ $this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors));
+ }
+ }
+
+ protected function maxRetriesReached(Request $request) : bool {
+ return $this->retries > $request->maxRetries();
+ }
+}
diff --git a/src/Features/LLM/Data/LLMResponse.php b/src/Features/LLM/Data/LLMResponse.php
index e60ed294..00aeb91f 100644
--- a/src/Features/LLM/Data/LLMResponse.php
+++ b/src/Features/LLM/Data/LLMResponse.php
@@ -1,129 +1,125 @@
-usage = $usage ?? new Usage();
- }
-
- // STATIC ////////////////////////////////////////////////
-
- public static function fromPartialResponses(array $partialResponses) : LLMResponse {
- return (new self)->makeFromPartialResponses($partialResponses);
- }
-
- // PUBLIC ////////////////////////////////////////////////
-
- public function hasValue() : bool {
- return $this->value !== null;
- }
-
- public function withValue(mixed $value) : self {
- $this->value = $value;
- return $this;
- }
-
- public function value() : mixed {
- return $this->value;
- }
-
- public function hasContent() : bool {
- return $this->content !== '';
- }
-
- public function content() : string {
- return $this->content;
- }
-
- public function json(): Json {
- return match(true) {
- // TODO: what about tool calls?
- $this->hasContent() => Json::from($this->content),
- default => Json::none(),
- };
- }
-
- public function hasToolCalls() : bool {
- return !empty($this->toolCalls);
- }
-
- public function usage() : Usage {
- return $this->usage ?? new Usage();
- }
-
- // INTERNAL //////////////////////////////////////////////
-
- /**
- * @param PartialLLMResponse[] $partialResponses
- * @return LLMResponse
- */
- private function makeFromPartialResponses(array $partialResponses = []) : self {
- if (empty($partialResponses)) {
- return $this;
- }
-
- $content = '';
- foreach($partialResponses as $partialResponse) {
- if ($partialResponse === null) {
- continue;
- }
- $content .= $partialResponse->contentDelta;
- $this->responseData[] = $partialResponse->responseData;
- $this->usage()->accumulate($partialResponse->usage);
- $this->finishReason = $partialResponse->finishReason;
- }
- $this->content = $content;
-
- $tools = $this->makeTools($partialResponses);
- if (!empty($tools)) {
- $this->toolsData = $this->makeToolsData($tools);
- $this->toolCalls = ToolCalls::fromArray($this->toolsData);
- }
- return $this;
- }
-
- private function makeTools(array $partialResponses): array {
- $tools = [];
- $currentTool = '';
- foreach ($partialResponses as $partialResponse) {
- if ($partialResponse === null) {
- continue;
- }
- if (('' !== ($partialResponse->toolName ?? ''))
- && ($currentTool !== ($partialResponse->toolName ?? ''))) {
- $currentTool = $partialResponse->toolName ?? '';
- $tools[$currentTool] = '';
- }
- if ('' !== $currentTool) {
- if (('' !== ($partialResponse->toolArgs ?? ''))) {
- $tools[$currentTool] .= $partialResponse->toolArgs ?? '';
- }
- }
- }
- return $tools;
- }
-
- private function makeToolsData(array $tools) : array {
- $data = [];
- foreach($tools as $tool => $args) {
- $data[] = [
- 'name' => $tool,
- 'arguments' => '' !== $args ? Json::decode($args) : [],
- ];
- }
- return $data;
- }
-}
+usage = $usage ?? new Usage();
+ }
+
+ // STATIC ////////////////////////////////////////////////
+
+ public static function fromPartialResponses(array $partialResponses) : LLMResponse {
+ return (new self)->makeFromPartialResponses($partialResponses);
+ }
+
+ // PUBLIC ////////////////////////////////////////////////
+
+ public function hasValue() : bool {
+ return $this->value !== null;
+ }
+
+ public function withValue(mixed $value) : self {
+ $this->value = $value;
+ return $this;
+ }
+
+ public function value() : mixed {
+ return $this->value;
+ }
+
+ public function hasContent() : bool {
+ return $this->content !== '';
+ }
+
+ public function content() : string {
+ return $this->content;
+ }
+
+ public function json(): Json {
+ return match(true) {
+ // TODO: what about tool calls?
+ $this->hasContent() => Json::from($this->content),
+ default => Json::none(),
+ };
+ }
+
+ public function hasToolCalls() : bool {
+ return !$this->toolCalls?->empty();
+ }
+
+ public function usage() : Usage {
+ return $this->usage ?? new Usage();
+ }
+
+ public function toolCalls() : ToolCalls {
+ return $this->toolCalls ?? new ToolCalls();
+ }
+
+ public function finishReason() : LLMFinishReason {
+ return LLMFinishReason::fromText($this->finishReason);
+ }
+
+ // INTERNAL //////////////////////////////////////////////
+
+ /**
+ * @param PartialLLMResponse[] $partialResponses
+ * @return LLMResponse
+ */
+ private function makeFromPartialResponses(array $partialResponses = []) : self {
+ if (empty($partialResponses)) {
+ return $this;
+ }
+
+ $content = '';
+ foreach($partialResponses as $partialResponse) {
+ if ($partialResponse === null) {
+ continue;
+ }
+ $content .= $partialResponse->contentDelta;
+ $this->responseData[] = $partialResponse->responseData;
+ $this->usage()->accumulate($partialResponse->usage);
+ $this->finishReason = $partialResponse->finishReason;
+ }
+ $this->content = $content;
+
+ $tools = $this->makeTools($partialResponses);
+ if (!empty($tools)) {
+ $this->toolCalls = ToolCalls::fromArray($tools);
+ }
+ return $this;
+ }
+
+ private function makeTools(array $partialResponses): array {
+ $tools = [];
+ $currentTool = '';
+ foreach ($partialResponses as $partialResponse) {
+ if ($partialResponse === null) {
+ continue;
+ }
+ if (('' !== ($partialResponse->toolName ?? ''))
+ && ($currentTool !== ($partialResponse->toolName ?? ''))) {
+ $currentTool = $partialResponse->toolName ?? '';
+ $tools[$currentTool] = '';
+ }
+ if ('' !== $currentTool) {
+ if (('' !== ($partialResponse->toolArgs ?? ''))) {
+ $tools[$currentTool] .= $partialResponse->toolArgs ?? '';
+ }
+ }
+ }
+ return $tools;
+ }
+}
diff --git a/src/Features/LLM/Data/PartialLLMResponse.php b/src/Features/LLM/Data/PartialLLMResponse.php
index 9904b357..99683f58 100644
--- a/src/Features/LLM/Data/PartialLLMResponse.php
+++ b/src/Features/LLM/Data/PartialLLMResponse.php
@@ -1,64 +1,65 @@
-value !== null;
- }
-
- public function withValue(mixed $value) : self {
- $this->value = $value;
- return $this;
- }
-
- public function value() : mixed {
- return $this->value;
- }
-
- public function hasContent() : bool {
- return $this->content !== '';
- }
-
- public function withContent(string $content) : self {
- $this->content = $content;
- return $this;
- }
-
- public function content() : string {
- return $this->content;
- }
-
- public function json(): string {
- if (!$this->hasContent()) {
- return '';
- }
- return Json::fromPartial($this->content)->toString();
- }
-
- public function withFinishReason(string $finishReason) : self {
- $this->finishReason = $finishReason;
- return $this;
- }
-
- public function usage() : Usage {
- return $this->usage ?? new Usage();
- }
-}
+value !== null;
+ }
+
+ public function withValue(mixed $value) : self {
+ $this->value = $value;
+ return $this;
+ }
+
+ public function value() : mixed {
+ return $this->value;
+ }
+
+ public function hasContent() : bool {
+ return $this->content !== '';
+ }
+
+ public function withContent(string $content) : self {
+ $this->content = $content;
+ return $this;
+ }
+
+ public function content() : string {
+ return $this->content;
+ }
+
+ public function json(): string {
+ if (!$this->hasContent()) {
+ return '';
+ }
+ return Json::fromPartial($this->content)->toString();
+ }
+
+ public function withFinishReason(string $finishReason) : self {
+ $this->finishReason = $finishReason;
+ return $this;
+ }
+
+ public function usage() : Usage {
+ return $this->usage ?? new Usage();
+ }
+}
diff --git a/src/Features/LLM/Data/ToolCall.php b/src/Features/LLM/Data/ToolCall.php
index 10d1523a..69358d6e 100644
--- a/src/Features/LLM/Data/ToolCall.php
+++ b/src/Features/LLM/Data/ToolCall.php
@@ -1,25 +1,131 @@
- empty($toolCall['arguments']) ? '' : Json::encode($toolCall['arguments']),
- is_string($toolCall['arguments'] ?? false) => $toolCall['arguments'] ?? '',
- is_null($toolCall['arguments']) => '',
- default => throw new \InvalidArgumentException('ToolCall args must be a string or an array')
- }
- );
- }
-}
+id = $id;
+ $this->name = $name;
+ $this->arguments = match(true) {
+ is_array($args) => $args,
+ is_string($args) => Json::from($args)->toArray(),
+ default => []
+ };
+ }
+
+ public static function fromArray(array $toolCall) : ?ToolCall {
+ if (empty($toolCall['name'])) {
+ return null;
+ }
+ return new ToolCall(
+ name: $toolCall['name'] ?? '',
+ args: match(true) {
+ is_array($toolCall['arguments'] ?? false) => $toolCall['arguments'],
+ is_string($toolCall['arguments'] ?? false) => $toolCall['arguments'],
+ is_null($toolCall['arguments'] ?? null) => [],
+ default => throw new \InvalidArgumentException('ToolCall args must be a string or an array')
+ },
+ id: $toolCall['id'] ?? ''
+ );
+ }
+
+ public function withId(string $id) : self {
+ $this->id = $id;
+ return $this;
+ }
+
+ public function withName(string $name) : self {
+ $this->name = $name;
+ return $this;
+ }
+
+ public function withArgs(string|array $args) : self {
+ $this->arguments = match(true) {
+ is_array($args) => $args,
+ is_string($args) => Json::from($args)->toArray(),
+ default => []
+ };
+ return $this;
+ }
+
+ public function hasArgs() : bool {
+ return !empty($this->arguments);
+ }
+
+ public function id() : string {
+ return $this->id;
+ }
+
+ public function name() : string {
+ return $this->name;
+ }
+
+ public function args() : array {
+ return $this->arguments;
+ }
+
+ public function argsAsJson() : string {
+ return Json::encode($this->arguments);
+ }
+
+ public function hasValue(string $key) : bool {
+ return isset($this->arguments[$key]);
+ }
+
+ public function value(string $key, mixed $default = null) : mixed {
+ return $this->arguments[$key] ?? $default;
+ }
+
+ public function intValue(string $key, int $default = 0) : int {
+ return (int) ($this->arguments[$key] ?? $default);
+ }
+
+ public function boolValue(string $key, bool $default = false) : bool {
+ return (bool) ($this->arguments[$key] ?? $default);
+ }
+
+ public function stringValue(string $key, string $default = '') : string {
+ return (string) ($this->arguments[$key] ?? $default);
+ }
+
+ public function arrayValue(string $key, array $default = []) : array {
+ return (array) ($this->arguments[$key] ?? $default);
+ }
+
+ public function objectValue(string $key, object $default = null) : object {
+ return (object) ($this->arguments[$key] ?? $default);
+ }
+
+ public function floatValue(string $key, float $default = 0.0) : float {
+ return (float) ($this->arguments[$key] ?? $default);
+ }
+
+ public function toArray() : array {
+ return [
+ 'name' => $this->name,
+ 'arguments' => $this->arguments,
+ ];
+ }
+
+ public function toToolCallArray() : array {
+ return [
+ 'id' => $this->id,
+ 'type' => 'function',
+ 'function' => [
+ 'name' => $this->name,
+ 'arguments' => Json::encode($this->arguments),
+ ]
+ ];
+ }
+}
diff --git a/src/Features/LLM/Data/ToolCalls.php b/src/Features/LLM/Data/ToolCalls.php
index b658f2d1..bae8c556 100644
--- a/src/Features/LLM/Data/ToolCalls.php
+++ b/src/Features/LLM/Data/ToolCalls.php
@@ -1,86 +1,110 @@
-toolCalls = $toolCalls;
- }
-
- public static function fromArray(array $toolCalls) : ToolCalls {
- $list = [];
- foreach ($toolCalls as $toolCall) {
- $list[] = ToolCall::fromArray($toolCall);
- }
- return new ToolCalls($list);
- }
-
- public static function fromMapper(array $toolCalls, callable $mapper) : ToolCalls {
- $list = [];
- foreach ($toolCalls as $toolCall) {
- $list[] = $mapper($toolCall);
- }
- return new ToolCalls($list);
- }
-
- public function count() : int {
- return count($this->toolCalls);
- }
-
- public function first() : ?ToolCall {
- return $this->toolCalls[0] ?? null;
- }
-
- public function last() : ?ToolCall {
- if (empty($this->toolCalls)) {
- return null;
- }
- return $this->toolCalls[count($this->toolCalls) - 1];
- }
-
- public function empty() : bool {
- return empty($this->toolCalls);
- }
-
- public function all() : array {
- return $this->toolCalls;
- }
-
- public function reset() : void {
- $this->toolCalls = [];
- }
-
- public function create(string $toolName, string $args = '') : ToolCall {
- $newToolCall = new ToolCall(
- name: $toolName,
- args: $args
- );
- $this->toolCalls[] = $newToolCall;
- return $newToolCall;
- }
-
- public function updateLast(string $responseJson, string $defaultName) : ToolCall {
- $last = $this->last();
- if (empty($last)) {
- return $this->create($defaultName, $responseJson);
- }
- $last->name = $last->name ?? $defaultName;
- $last->args = $responseJson;
- return $this->last();
- }
-
- public function finalizeLast(string $responseJson, string $defaultName) : ToolCall {
- $last = $this->last();
- if (empty($last)) {
- return $this->create($defaultName, $responseJson);
- }
- $last->name = $last->name ?? $defaultName;
- $last->args = $responseJson;
- return $this->last();
- }
+toolCalls = $toolCalls;
+ }
+
+ public static function fromArray(array $toolCalls) : ToolCalls {
+ $list = [];
+ foreach ($toolCalls as $key => $toolCall) {
+ $list[] = match(true) {
+ is_array($toolCall) => ToolCall::fromArray($toolCall),
+ is_object($toolCall) && $toolCall instanceof ToolCall => $toolCall,
+ is_string($toolCall) => new ToolCall($key, $toolCall),
+ default => throw new InvalidArgumentException('Cannot create ToolCall from provided data: ' . print_r($toolCall, true))
+ };
+ }
+ return new ToolCalls($list);
+ }
+
+ public static function fromMapper(array $toolCalls, callable $mapper) : ToolCalls {
+ $list = [];
+ foreach ($toolCalls as $item) {
+ $toolCall = $mapper($item);
+ if ($toolCall instanceof ToolCall) {
+ $list[] = $toolCall;
+ }
+ }
+ return new ToolCalls($list);
+ }
+
+ public function count() : int {
+ return count($this->toolCalls);
+ }
+
+ public function first() : ?ToolCall {
+ return $this->toolCalls[0] ?? null;
+ }
+
+ public function last() : ?ToolCall {
+ if (empty($this->toolCalls)) {
+ return null;
+ }
+ return $this->toolCalls[count($this->toolCalls) - 1];
+ }
+
+ public function empty() : bool {
+ return empty($this->toolCalls);
+ }
+
+ /**
+ * @return ToolCall[]
+ */
+ public function all() : array {
+ return $this->toolCalls;
+ }
+
+ public function reset() : void {
+ $this->toolCalls = [];
+ }
+
+ public function add(string $toolName, string $args = '') : ToolCall {
+ $newToolCall = new ToolCall(
+ name: $toolName,
+ args: $args
+ );
+ $this->toolCalls[] = $newToolCall;
+ return $newToolCall;
+ }
+
+ public function updateLast(string $responseJson, string $defaultName) : ToolCall {
+ $last = $this->last();
+ if (empty($last)) {
+ return $this->add($defaultName, $responseJson);
+ }
+ $last->withName($last->name() ?? $defaultName);
+ $last->withArgs($responseJson);
+ return $this->last();
+ }
+
+ public function finalizeLast(string $responseJson, string $defaultName) : ToolCall {
+ $last = $this->last();
+ if (empty($last)) {
+ return $this->add($defaultName, $responseJson);
+ }
+ $last->withName($last->name() ?? $defaultName);
+ $last->withArgs($responseJson);
+ return $this->last();
+ }
+
+ public function toArray() : array {
+ $list = [];
+ foreach ($this->toolCalls as $toolCall) {
+ $list[] = $toolCall->toArray();
+ }
+ return $list;
+ }
}
\ No newline at end of file
diff --git a/src/Features/LLM/Drivers/AnthropicDriver.php b/src/Features/LLM/Drivers/AnthropicDriver.php
index b7b728de..38ab25f8 100644
--- a/src/Features/LLM/Drivers/AnthropicDriver.php
+++ b/src/Features/LLM/Drivers/AnthropicDriver.php
@@ -20,6 +20,8 @@
class AnthropicDriver implements CanHandleInference
{
+ private bool $parallelToolCalls = false;
+
public function __construct(
protected LLMConfig $config,
protected ?CanHandleHttp $httpClient = null,
@@ -72,18 +74,22 @@ public function getRequestBody(
array $options = [],
Mode $mode = Mode::Text,
) : array {
- $request = array_filter(array_merge([
+ $this->parallelToolCalls = $options['parallel_tool_calls'] ?? false;
+ unset($options['parallel_tool_calls']);
+
+ $request = array_merge(array_filter([
'model' => $model ?: $this->config->model,
'max_tokens' => $options['max_tokens'] ?? $this->config->maxTokens,
'system' => Messages::fromArray($messages)
->forRoles(['system'])
->toString(),
- 'messages' => $this->toNativeMessages(Messages::fromArray($messages)
- ->exceptRoles(['system'])
- ->toMergedPerRole()
- ->toArray()
+ 'messages' => $this->toNativeMessages(
+ Messages::fromArray($messages)
+ ->exceptRoles(['system'])
+ //->toMergedPerRole()
+ ->toArray()
),
- ], $options));
+ ]), $options);
return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
}
@@ -94,9 +100,6 @@ public function toLLMResponse(array $data): ?LLMResponse {
return new LLMResponse(
content: $this->makeContent($data),
responseData: $data,
-// toolName: $data['content'][0]['name'] ?? '',
-// toolArgs: Json::encode($data['content'][0]['input'] ?? ''),
- toolsData: $this->mapToolsData($data),
finishReason: $data['stop_reason'] ?? '',
toolCalls: $this->makeToolCalls($data),
usage: $this->makeUsage($data),
@@ -110,6 +113,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse {
return new PartialLLMResponse(
contentDelta: $this->makeContentDelta($data),
responseData: $data,
+ toolId: $data['content_block']['id'] ?? '',
toolName: $data['content_block']['name'] ?? '',
toolArgs: $data['delta']['partial_json'] ?? '',
finishReason: $data['delta']['stop_reason'] ?? $data['message']['stop_reason'] ?? '',
@@ -159,27 +163,55 @@ private function toTools(array $tools) : array {
private function toToolChoice(string|array $toolChoice, array $tools) : array|string {
return match(true) {
empty($tools) => '',
+ empty($toolChoice) => [
+ 'type' => 'auto',
+ 'disable_parallel_tool_use' => !$this->parallelToolCalls,
+ ],
is_array($toolChoice) => [
'type' => 'tool',
'name' => $toolChoice['function']['name'],
- ],
- empty($toolChoice) => [
- 'type' => 'auto',
+ 'disable_parallel_tool_use' => !$this->parallelToolCalls,
],
default => [
- 'type' => $toolChoice,
+ 'type' => $this->mapToolChoice($toolChoice),
+ 'disable_parallel_tool_use' => !$this->parallelToolCalls,
],
};
}
+ protected function mapToolChoice(string $choice) : string {
+ return match($choice) {
+ 'auto' => 'auto',
+ 'required' => 'any',
+ default => 'auto',
+ };
+ }
+
private function toNativeMessages(array $messages) : array {
- return array_map(
- fn($message) => [
- 'role' => $this->mapRole($message['role'] ?? 'user'),
- 'content' => $this->toNativeContent($message['content']),
- ],
- $messages
- );
+ $list = [];
+ foreach ($messages as $message) {
+ $nativeMessage = $this->mapMessage($message);
+ if (empty($nativeMessage)) {
+ continue;
+ }
+ $list[] = $nativeMessage;
+ }
+ return $list;
+ }
+
+ private function mapMessage(array $message) : array {
+ return match(true) {
+ ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message),
+ ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message),
+ default => $this->toNativeTextMessage($message),
+ };
+ }
+
+ private function toNativeTextMessage(array $message) : array {
+ return [
+ 'role' => $this->mapRole($message['role'] ?? 'user'),
+ 'content' => $this->toNativeContent($message['content']),
+ ];
}
private function mapRole(string $role) : string {
@@ -201,10 +233,18 @@ private function toNativeContent(string|array $content) : string|array {
private function contentPartToNative(array $contentPart) : array {
$type = $contentPart['type'] ?? 'text';
- if ($type === 'image_url') {
- $contentPart = $this->toNativeImage($contentPart);
- }
- return $contentPart;
+ return match($type) {
+ 'text' => $this->toNativeTextContent($contentPart),
+ 'image_url' => $this->toNativeImage($contentPart),
+ default => $contentPart,
+ };
+ }
+
+ private function toNativeTextContent(array $contentPart) : array {
+ return [
+ 'type' => 'text',
+ 'text' => $contentPart['text'] ?? '',
+ ];
}
private function toNativeImage(array $contentPart) : array {
@@ -220,21 +260,42 @@ private function toNativeImage(array $contentPart) : array {
];
return $contentPart;
}
+
+ private function toNativeToolCall(array $message) : array {
+ return [
+ 'role' => 'assistant',
+ 'content' => [[
+ 'type' => 'tool_use',
+ 'id' => $message['_metadata']['tool_calls'][0]['id'] ?? '',
+ 'name' => $message['_metadata']['tool_calls'][0]['function']['name'] ?? '',
+ 'input' => Json::from($message['_metadata']['tool_calls'][0]['function']['arguments'] ?? '')->toArray(),
+ ]]
+ ];
+ }
+
+ private function toNativeToolResult(array $message) : array {
+ return [
+ 'role' => 'user',
+ 'content' => [[
+ 'type' => 'tool_result',
+ 'tool_use_id' => $message['_metadata']['tool_call_id'] ?? '',
+ 'content' => $message['content'] ?? '',
+ //'is_error' => false,
+ ]]
+ ];
+ }
+
private function makeToolCalls(array $data) : ToolCalls {
return ToolCalls::fromMapper(array_map(
callback: fn(array $call) => $call,
- array: $data['content'] ?? []
- ), fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['input'] ?? '']));
- }
-
- private function mapToolsData(array $data) : array {
- return array_map(
- fn($tool) => [
- 'name' => $tool['name'] ?? '',
- 'arguments' => $tool['input'] ?? '',
- ],
- array_filter($data['content'] ?? [], fn($part) => 'tool_use' === ($part['type'] ?? ''))
- );
+ array: array_filter(
+ array: $data['content'] ?? [],
+ callback: fn($part) => 'tool_use' === ($part['type'] ?? ''))
+ ), fn($call) => ToolCall::fromArray([
+ 'id' => $call['id'] ?? '',
+ 'name' => $call['name'] ?? '',
+ 'arguments' => $call['input'] ?? ''
+ ]));
}
private function makeContent(array $data) : string {
@@ -297,4 +358,12 @@ private function makeUsage(array $data) : Usage {
reasoningTokens: 0,
);
}
+
+ protected function excludeUnderscoredKeys(array $messages) : array {
+ $list = [];
+ foreach ($messages as $message) {
+ $list[] = array_filter($message, fn($value, $key) => !Str::startsWith($key, '_'), ARRAY_FILTER_USE_BOTH);
+ }
+ return $list;
+ }
}
diff --git a/src/Features/LLM/Drivers/AzureOpenAIDriver.php b/src/Features/LLM/Drivers/AzureOpenAIDriver.php
index e56172f4..50064a90 100644
--- a/src/Features/LLM/Drivers/AzureOpenAIDriver.php
+++ b/src/Features/LLM/Drivers/AzureOpenAIDriver.php
@@ -1,36 +1,34 @@
- "{".$key."}", array_keys($this->config->metadata)),
- replace: array_values($this->config->metadata),
- subject: "{$this->config->apiUrl}{$this->config->endpoint}"
- ) . $this->getUrlParams();
- }
-
- protected function getUrlParams(): string {
- $params = array_filter([
- 'api-version' => $this->config->metadata['apiVersion'] ?? '',
- ]);
- if (!empty($params)) {
- return '?' . http_build_query($params);
- }
- return '';
- }
-
- public function getRequestHeaders(): array {
- return [
- 'Api-Key' => $this->config->apiKey,
- 'Content-Type' => 'application/json',
- ];
- }
+ "{".$key."}", array_keys($this->config->metadata)),
+ replace: array_values($this->config->metadata),
+ subject: "{$this->config->apiUrl}{$this->config->endpoint}"
+ ) . $this->getUrlParams();
+ }
+
+ protected function getUrlParams(): string {
+ $params = array_filter([
+ 'api-version' => $this->config->metadata['apiVersion'] ?? '',
+ ]);
+ if (!empty($params)) {
+ return '?' . http_build_query($params);
+ }
+ return '';
+ }
+
+ public function getRequestHeaders(): array {
+ return [
+ 'Api-Key' => $this->config->apiKey,
+ 'Content-Type' => 'application/json',
+ ];
+ }
}
\ No newline at end of file
diff --git a/src/Features/LLM/Drivers/CohereV1Driver.php b/src/Features/LLM/Drivers/CohereV1Driver.php
index bf783d77..13693fb8 100644
--- a/src/Features/LLM/Drivers/CohereV1Driver.php
+++ b/src/Features/LLM/Drivers/CohereV1Driver.php
@@ -16,6 +16,7 @@
use Cognesy\Instructor\Features\LLM\InferenceRequest;
use Cognesy\Instructor\Utils\Json\Json;
use Cognesy\Instructor\Utils\Messages\Messages;
+use Cognesy\Instructor\Utils\Str;
class CohereV1Driver implements CanHandleInference
{
@@ -68,15 +69,18 @@ public function getRequestBody(
array $options = [],
Mode $mode = Mode::Text,
) : array {
+ unset($options['parallel_tool_calls']);
+
$system = '';
$chatHistory = [];
+ $messages = $this->excludeUnderscoredKeys($messages);
- $request = array_filter(array_merge([
+ $request = array_merge(array_filter([
'model' => $model ?: $this->config->model,
'preamble' => $system,
'chat_history' => $chatHistory,
'message' => Messages::asString($messages),
- ], $options));
+ ]), $options);
return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
}
@@ -87,7 +91,7 @@ public function toLLMResponse(array $data): LLMResponse {
return new LLMResponse(
content: $this->makeContent($data),
responseData: $data,
- toolsData: $this->mapToolsData($data),
+ //: $this->map($data),
finishReason: $data['finish_reason'] ?? '',
toolCalls: $this->makeToolCalls($data),
usage: $this->makeUsage($data),
@@ -98,6 +102,7 @@ public function toPartialLLMResponse(array $data) : PartialLLMResponse {
return new PartialLLMResponse(
contentDelta: $this->makeContentDelta($data),
responseData: $data,
+ toolId: $this->makeToolId($data),
toolName: $this->makeToolNameDelta($data),
toolArgs: $this->makeToolArgsDelta($data),
finishReason: $data['response']['finish_reason'] ?? $data['delta']['finish_reason'] ?? '',
@@ -184,16 +189,6 @@ private function makeToolCalls(array $data) : ToolCalls {
);
}
- private function mapToolsData(array $data) : array {
- return array_map(
- fn($tool) => [
- 'name' => $tool['name'] ?? '',
- 'arguments' => $tool['parameters'] ?? '',
- ],
- $data['tool_calls'] ?? []
- );
- }
-
private function makeContent(array $data) : string {
return ($data['text'] ?? '') . (!empty($data['tool_calls'])
? ("\n" . Json::encode($data['tool_calls']))
@@ -208,12 +203,11 @@ private function makeContentDelta(array $data) : string {
return $data['tool_call_delta']['parameters'] ?? $data['text'] ?? '';
}
- private function makeToolArgsDelta(array $data) : string {
+ private function makeToolId(array $data) {
if (!$this->isStreamChunk($data)) {
return '';
}
- $toolArgs = $data['tool_calls'][0]['parameters'] ?? '';
- return ('' === $toolArgs) ? '' : Json::encode($toolArgs);
+ return $data['tool_calls'][0]['id'] ?? '';
}
private function makeToolNameDelta(array $data) : string {
@@ -223,6 +217,14 @@ private function makeToolNameDelta(array $data) : string {
return $data['tool_calls'][0]['name'] ?? '';
}
+ private function makeToolArgsDelta(array $data) : string {
+ if (!$this->isStreamChunk($data)) {
+ return '';
+ }
+ $toolArgs = $data['tool_calls'][0]['parameters'] ?? '';
+ return ('' === $toolArgs) ? '' : Json::encode($toolArgs);
+ }
+
private function isStreamChunk(array $data) : bool {
return in_array(($data['event_type'] ?? ''), ['text-generation', 'tool-calls-chunk']);
}
@@ -254,4 +256,12 @@ private function makeUsage(array $data) : Usage {
reasoningTokens: 0,
);
}
+
+ protected function excludeUnderscoredKeys(array $messages) : array {
+ $list = [];
+ foreach ($messages as $message) {
+ $list[] = array_filter($message, fn($value, $key) => !Str::startsWith($key, '_'), ARRAY_FILTER_USE_BOTH);
+ }
+ return $list;
+ }
}
diff --git a/src/Features/LLM/Drivers/CohereV2Driver.php b/src/Features/LLM/Drivers/CohereV2Driver.php
index 5ebad31e..95b08598 100644
--- a/src/Features/LLM/Drivers/CohereV2Driver.php
+++ b/src/Features/LLM/Drivers/CohereV2Driver.php
@@ -1,174 +1,179 @@
- $model ?: $this->config->model,
- 'max_tokens' => $this->config->maxTokens,
- 'messages' => $messages,
- ], $options));
-
- return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
- }
-
- public function getRequestHeaders(): array {
- $optional = [
- 'X-Client-Name' => $this->config->metadata['client_name'] ?? '',
- ];
- return array_merge([
- 'Authorization' => "Bearer {$this->config->apiKey}",
- 'Content-Type' => 'application/json',
- ], $optional);
- }
-
- // RESPONSE //////////////////////////////////////////////
-
- public function toLLMResponse(array $data): LLMResponse {
- return new LLMResponse(
- content: $this->makeContent($data),
- responseData: $data,
- toolsData: $this->makeToolsData($data),
- finishReason: $data['finish_reason'] ?? '',
- toolCalls: $this->makeToolCalls($data),
- usage: $this->makeUsage($data),
- );
- }
-
- public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse {
- if (empty($data)) {
- return null;
- }
- return new PartialLLMResponse(
- contentDelta: $this->makeContentDelta($data),
- responseData: $data,
- toolName: $data['delta']['message']['tool_calls']['function']['name'] ?? '',
- toolArgs: $data['delta']['message']['tool_calls']['function']['arguments'] ?? '',
- finishReason: $data['delta']['finish_reason'] ?? '',
- usage: $this->makeUsage($data),
- );
- }
-
- public function getData(string $data): string|bool {
- if (!str_starts_with($data, 'data:')) {
- return '';
- }
- $data = trim(substr($data, 5));
- return match(true) {
- $data === '[DONE]' => false,
- default => $data,
- };
- }
-
- // OVERRIDES - HELPERS ///////////////////////////////////
-
- protected function applyMode(
- array $request,
- Mode $mode,
- array $tools,
- string|array $toolChoice,
- array $responseFormat
- ) : array {
- switch($mode) {
- case Mode::Tools:
- $request['tools'] = $this->removeDisallowedEntries($tools);
- break;
- case Mode::Json:
- $request['response_format'] = $responseFormat;
- break;
- case Mode::JsonSchema:
- $request['response_format'] = [
- 'type' => 'json_object',
- 'schema' => $responseFormat['json_schema']['schema'],
- ];
- break;
- }
- return $request;
- }
-
- protected function removeDisallowedEntries(array $jsonSchema) : array {
- return Arrays::removeRecursively($jsonSchema, [
- 'x-title',
- 'x-php-class',
- 'additionalProperties',
- ]);
- }
-
- private function makeContent(array $data): string {
- $contentMsg = $data['message']['content'][0]['text'] ?? '';
- $contentFnArgs = $data['message']['tool_calls'][0]['function']['arguments'] ?? '';
- return match(true) {
- !empty($contentMsg) => $contentMsg,
- !empty($contentFnArgs) => $contentFnArgs,
- default => ''
- };
- }
-
- private function makeToolCalls(array $data) : ToolCalls {
- return ToolCalls::fromArray(array_map(
- callback: fn(array $call) => $call['function'] ?? [],
- array: $data['message']['tool_calls'] ?? []
- ));
- }
-
- private function makeToolsData(array $data) : array {
- return array_map(
- fn($tool) => [
- 'name' => $tool['function']['name'] ?? '',
- 'arguments' => Json::decode($tool['function']['arguments']) ?? '',
- ],
- $data['message']['tool_calls'] ?? []
- );
- }
-
- private function makeContentDelta(array $data): string {
- $deltaContent = match(true) {
- ([] !== ($data['delta']['message']['content'] ?? [])) => $this->normalizeContent($data['delta']['message']['content']),
- default => '',
- };
- $deltaFnArgs = $data['delta']['message']['tool_calls']['function']['arguments'] ?? '';
- return match(true) {
- '' !== $deltaContent => $deltaContent,
- '' !== $deltaFnArgs => $deltaFnArgs,
- default => ''
- };
- }
-
- private function normalizeContent(array|string $content) : string {
- return is_array($content) ? $content['text'] : $content;
- }
-
- private function makeUsage(array $data) : Usage {
- return new Usage(
- inputTokens: $data['usage']['billed_units']['input_tokens']
- ?? $data['delta']['usage']['billed_units']['input_tokens']
- ?? 0,
- outputTokens: $data['usage']['billed_units']['output_tokens']
- ?? $data['delta']['usage']['billed_units']['output_tokens']
- ?? 0,
- cacheWriteTokens: 0,
- cacheReadTokens: 0,
- reasoningTokens: 0,
- );
- }
-}
+ $model ?: $this->config->model,
+ 'max_tokens' => $this->config->maxTokens,
+ 'messages' => $this->toNativeMessages($messages),
+ ]), $options);
+
+ return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
+ }
+
+ public function getRequestHeaders(): array {
+ $optional = [
+ 'X-Client-Name' => $this->config->metadata['client_name'] ?? '',
+ ];
+ return array_merge([
+ 'Authorization' => "Bearer {$this->config->apiKey}",
+ 'Content-Type' => 'application/json',
+ ], $optional);
+ }
+
+ // RESPONSE //////////////////////////////////////////////
+
+ public function toLLMResponse(array $data): LLMResponse {
+ return new LLMResponse(
+ content: $this->makeContent($data),
+ responseData: $data,
+ finishReason: $data['finish_reason'] ?? '',
+ toolCalls: $this->makeToolCalls($data),
+ usage: $this->makeUsage($data),
+ );
+ }
+
+ public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse {
+ if (empty($data)) {
+ return null;
+ }
+ return new PartialLLMResponse(
+ contentDelta: $this->makeContentDelta($data),
+ responseData: $data,
+ toolId: $data['delta']['message']['tool_calls']['function']['id'] ?? '',
+ toolName: $data['delta']['message']['tool_calls']['function']['name'] ?? '',
+ toolArgs: $data['delta']['message']['tool_calls']['function']['arguments'] ?? '',
+ finishReason: $data['delta']['finish_reason'] ?? '',
+ usage: $this->makeUsage($data),
+ );
+ }
+
+ public function getData(string $data): string|bool {
+ if (!str_starts_with($data, 'data:')) {
+ return '';
+ }
+ $data = trim(substr($data, 5));
+ return match(true) {
+ $data === '[DONE]' => false,
+ default => $data,
+ };
+ }
+
+ // OVERRIDES - HELPERS ///////////////////////////////////
+
+ protected function applyMode(
+ array $request,
+ Mode $mode,
+ array $tools,
+ string|array $toolChoice,
+ array $responseFormat
+ ) : array {
+ switch($mode) {
+ case Mode::Tools:
+ $request['tools'] = $this->removeDisallowedEntries($tools);
+ break;
+ case Mode::Json:
+ $request['response_format'] = $responseFormat;
+ break;
+ case Mode::JsonSchema:
+ $request['response_format'] = [
+ 'type' => 'json_object',
+ 'schema' => $responseFormat['json_schema']['schema'],
+ ];
+ break;
+ }
+ return $request;
+ }
+
+ protected function removeDisallowedEntries(array $jsonSchema) : array {
+ return Arrays::removeRecursively($jsonSchema, [
+ 'x-title',
+ 'x-php-class',
+ 'additionalProperties',
+ ]);
+ }
+
+ private function makeContent(array $data): string {
+ $contentMsg = $data['message']['content'][0]['text'] ?? '';
+ $contentFnArgs = $data['message']['tool_calls'][0]['function']['arguments'] ?? '';
+ return match(true) {
+ !empty($contentMsg) => $contentMsg,
+ !empty($contentFnArgs) => $contentFnArgs,
+ default => ''
+ };
+ }
+
+ private function makeToolCalls(array $data) : ToolCalls {
+ return ToolCalls::fromArray(array_map(
+ callback: fn(array $call) => $this->makeToolCall($call),
+ array: $data['message']['tool_calls'] ?? [],
+ ));
+ }
+
+ private function makeToolCall(array $data) : ?ToolCall {
+ if (empty($data)) {
+ return null;
+ }
+ if (!isset($data['function'])) {
+ return null;
+ }
+ if (!isset($data['id'])) {
+ return null;
+ }
+ return ToolCall::fromArray($data['function'] ?? [])->withId($data['id'] ?? '');
+ }
+
+ private function makeContentDelta(array $data): string {
+ $deltaContent = match(true) {
+ ([] !== ($data['delta']['message']['content'] ?? [])) => $this->normalizeContent($data['delta']['message']['content']),
+ default => '',
+ };
+ $deltaFnArgs = $data['delta']['message']['tool_calls']['function']['arguments'] ?? '';
+ return match(true) {
+ '' !== $deltaContent => $deltaContent,
+ '' !== $deltaFnArgs => $deltaFnArgs,
+ default => ''
+ };
+ }
+
+ private function normalizeContent(array|string $content) : string {
+ return is_array($content) ? $content['text'] : $content;
+ }
+
+ private function makeUsage(array $data) : Usage {
+ return new Usage(
+ inputTokens: $data['usage']['billed_units']['input_tokens']
+ ?? $data['delta']['usage']['billed_units']['input_tokens']
+ ?? 0,
+ outputTokens: $data['usage']['billed_units']['output_tokens']
+ ?? $data['delta']['usage']['billed_units']['output_tokens']
+ ?? 0,
+ cacheWriteTokens: 0,
+ cacheReadTokens: 0,
+ reasoningTokens: 0,
+ );
+ }
+}
diff --git a/src/Features/LLM/Drivers/GeminiDriver.php b/src/Features/LLM/Drivers/GeminiDriver.php
index d218bb56..0b4fb0ec 100644
--- a/src/Features/LLM/Drivers/GeminiDriver.php
+++ b/src/Features/LLM/Drivers/GeminiDriver.php
@@ -100,9 +100,6 @@ public function toLLMResponse(array $data): ?LLMResponse {
return new LLMResponse(
content: $this->makeContent($data),
responseData: $data,
-// toolName: $data['candidates'][0]['content']['parts'][0]['functionCall']['name'] ?? '',
-// toolArgs: Json::encode($data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? []),
- toolsData: $this->mapToolsData($data),
finishReason: $data['candidates'][0]['finishReason'] ?? '',
toolCalls: $this->makeToolCalls($data),
usage: $this->makeUsage($data),
@@ -116,6 +113,7 @@ public function toPartialLLMResponse(array $data) : ?PartialLLMResponse {
return new PartialLLMResponse(
contentDelta: $this->makeContentDelta($data),
responseData: $data,
+ toolId: $data['candidates'][0]['id'] ?? '',
toolName: $this->makeToolName($data),
toolArgs: $this->makeToolArgs($data),
finishReason: $data['candidates'][0]['finishReason'] ?? '',
@@ -145,10 +143,12 @@ private function toSystem(array $messages) : array {
}
private function toMessages(array $messages) : array {
- return $this->toNativeMessages(Messages::fromArray($messages)
- ->exceptRoles(['system'])
- //->toMergedPerRole()
- ->toArray());
+ return $this->toNativeMessages(
+ Messages::fromArray($messages)
+ ->exceptRoles(['system'])
+ //->toMergedPerRole()
+ ->toArray()
+ );
}
protected function toOptions(
@@ -172,12 +172,13 @@ protected function toTools(array $tools) : array {
)];
}
- protected function toToolChoice(array $toolChoice): string|array {
+ protected function toToolChoice(string|array $toolChoice): string|array {
return match(true) {
empty($toolChoice) => ["function_calling_config" => ["mode" => "ANY"]],
+ is_string($toolChoice) => ["function_calling_config" => ["mode" => $this->mapToolChoice($toolChoice)]],
is_array($toolChoice) => [
"function_calling_config" => array_filter([
- "mode" => "ANY",
+ "mode" => $this->mapToolChoice($toolChoice['mode'] ?? "ANY"),
"allowed_function_names" => $toolChoice['function']['name'] ?? [],
]),
],
@@ -185,6 +186,15 @@ protected function toToolChoice(array $toolChoice): string|array {
};
}
+ protected function mapToolChoice(string $choice) : string {
+ return match($choice) {
+ 'auto' => 'AUTO',
+ 'required' => 'ANY',
+ 'none' => 'NONE',
+ default => 'ANY',
+ };
+ }
+
protected function toResponseMimeType(Mode $mode): string {
return match($mode) {
Mode::Text => "text/plain",
@@ -211,18 +221,69 @@ protected function removeDisallowedEntries(array $jsonSchema) : array {
]);
}
- protected function toNativeMessages(string|array $messages) : array {
- if (is_string($messages)) {
- return [["text" => $messages]];
- }
- $transformed = [];
+ private function toNativeMessages(array $messages) : array {
+ $list = [];
foreach ($messages as $message) {
- $transformed[] = [
- 'role' => $this->mapRole($message['role']),
- 'parts' => $this->contentPartsToNative($message['content']),
- ];
+ $nativeMessage = $this->mapMessage($message);
+ if (empty($nativeMessage)) {
+ continue;
+ }
+ $list[] = $nativeMessage;
}
- return $transformed;
+ return $list;
+ }
+
+ private function mapMessage(array $message) : array {
+ return match(true) {
+ ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message),
+ ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message),
+ default => $this->toNativeTextMessage($message),
+ };
+ }
+
+ private function toNativeTextMessage(array $message) : array {
+ return [
+ 'role' => $this->mapRole($message['role'] ?? 'user'),
+ 'parts' => $this->toNativeContentParts($message['content']),
+ ];
+ }
+
+ private function toNativeToolCall(array $message) : array {
+ return [
+ 'role' => 'model',
+ 'parts' => array_map(
+ callback: fn($call) => $this->toNativeToolCallPart($call),
+ array: $message['_metadata']['tool_calls'] ?? []
+ ),
+ ];
+ }
+
+ private function toNativeToolCallPart(array $call) : array {
+ return [
+ 'functionCall' => [
+ 'name' => $call['function']['name'] ?? '',
+ 'args' => Json::from($call['function']['arguments'])->toArray() ?? [],
+ ]
+ ];
+ }
+
+ private function toNativeToolResult(array $message) : array {
+ $content = match(true) {
+ is_array($message['_metadata']['result'] ?? '') => Json::from($message['_metadata']['result'] ?? '')->toArray(),
+ default => $message['content'],
+ };
+ return [
+ 'role' => 'user',
+ 'parts' => [[
+ 'functionResponse' => [
+ 'name' => $message['_metadata']['tool_name'] ?? '',
+ 'response' => [
+ 'name' => $message['_metadata']['tool_name'] ?? '',
+ 'content' => $content,
+ ],
+ ],
+ ]],
+ ];
}
protected function mapRole(string $role) : string {
@@ -230,7 +291,7 @@ protected function mapRole(string $role) : string {
return $roles[$role] ?? $role;
}
- protected function contentPartsToNative(string|array $contentParts) : array {
+ protected function toNativeContentParts(string|array $contentParts) : array {
if (is_string($contentParts)) {
return [["text" => $contentParts]];
}
@@ -243,18 +304,27 @@ protected function contentPartsToNative(string|array $contentParts) : array {
protected function contentPartToNative(array $contentPart) : array {
$type = $contentPart['type'] ?? 'text';
- return match($type) {
- 'text' => ['text' => $contentPart['text'] ?? ''],
- 'image_url' => [
- 'inlineData' => [
- 'mimeType' => Str::between($contentPart['image_url']['url'], 'data:', ';base64,'),
- 'data' => Str::after($contentPart['image_url']['url'], ';base64,'),
- ],
- ],
+ return match(true) {
+ ($type === 'text') => $this->makeTextContentPart($contentPart),
+ ($type === 'image_url') => $this->makeImageUrlContentPart($contentPart),
+
default => $contentPart,
};
}
+ private function makeTextContentPart(array $contentPart) : array {
+ return ['text' => $contentPart['text'] ?? ''];
+ }
+
+ private function makeImageUrlContentPart(array $contentPart) : array {
+ return [
+ 'inlineData' => [
+ 'mimeType' => Str::between($contentPart['image_url']['url'], 'data:', ';base64,'),
+ 'data' => Str::after($contentPart['image_url']['url'], ';base64,'),
+ ],
+ ];
+ }
+
private function makeToolCalls(array $data) : ToolCalls {
return ToolCalls::fromMapper(array_map(
callback: fn(array $call) => $call['functionCall'] ?? [],
@@ -262,16 +332,6 @@ private function makeToolCalls(array $data) : ToolCalls {
), fn($call) => ToolCall::fromArray(['name' => $call['name'] ?? '', 'arguments' => $call['args'] ?? '']));
}
- private function mapToolsData(array $data) : array {
- return array_map(
- fn($tool) => [
- 'name' => $tool['functionCall']['name'] ?? '',
- 'arguments' => $tool['functionCall']['args'] ?? '',
- ],
- $data['candidates'][0]['content']['parts'] ?? []
- );
- }
-
private function makeContent(array $data) : string {
return $data['candidates'][0]['content']['parts'][0]['text']
?? Json::encode($data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? '')
diff --git a/src/Features/LLM/Drivers/GrokDriver.php b/src/Features/LLM/Drivers/GrokDriver.php
new file mode 100644
index 00000000..21d2755c
--- /dev/null
+++ b/src/Features/LLM/Drivers/GrokDriver.php
@@ -0,0 +1,14 @@
+ 'assistant',
+ 'content' => $message['content'] ?? 'I\'m calling tool: ' . $message['_metadata']['tool_calls'][0]['function']['name'],
+ 'tool_calls' => $message['_metadata']['tool_calls'] ?? [],
+ ];
+ }
+}
\ No newline at end of file
diff --git a/src/Features/LLM/Drivers/MistralDriver.php b/src/Features/LLM/Drivers/MistralDriver.php
index f8099ccc..3a51af92 100644
--- a/src/Features/LLM/Drivers/MistralDriver.php
+++ b/src/Features/LLM/Drivers/MistralDriver.php
@@ -1,59 +1,73 @@
- $model ?: $this->config->model,
- 'max_tokens' => $this->config->maxTokens,
- 'messages' => $messages,
- ], $options));
-
- return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
- }
-
- // PRIVATE //////////////////////////////////////////////
-
- private function applyMode(
- array $request,
- Mode $mode,
- array $tools,
- string|array $toolChoice,
- array $responseFormat
- ) : array {
- switch($mode) {
- case Mode::Tools:
- $request['tools'] = $this->removeDisallowedEntries($tools);
- $request['tool_choice'] = 'any';
- break;
- case Mode::Json:
- case Mode::JsonSchema:
- $request['response_format'] = ['type' => 'json_object'];
- break;
- }
- return $request;
- }
-
- private function removeDisallowedEntries(array $jsonSchema) : array {
- return Arrays::removeRecursively($jsonSchema, [
- 'x-title',
- //'description',
- 'x-php-class',
- 'additionalProperties',
- ]);
- }
+ $model ?: $this->config->model,
+ 'max_tokens' => $this->config->maxTokens,
+ 'messages' => $this->toNativeMessages($messages),
+ ]), $options);
+
+ return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
+ }
+
+ // PRIVATE //////////////////////////////////////////////
+
+ private function applyMode(
+ array $request,
+ Mode $mode,
+ array $tools,
+ string|array $toolChoice,
+ array $responseFormat
+ ) : array {
+ switch($mode) {
+ case Mode::Tools:
+ $request['tools'] = $this->removeDisallowedEntries($tools);
+ $request['tool_choice'] = $this->toToolChoice($tools, $toolChoice);
+ break;
+ case Mode::Json:
+ case Mode::JsonSchema:
+ $request['response_format'] = ['type' => 'json_object'];
+ break;
+ }
+ return $request;
+ }
+
+ private function toToolChoice(array $tools, array|string $toolChoice) : array|string {
+ return match(true) {
+ empty($tools) => '',
+ empty($toolChoice) => 'auto',
+ is_array($toolChoice) => [
+ 'type' => 'function',
+ 'name' => $toolChoice['function']['name'],
+ ],
+ default => $toolChoice,
+ };
+ }
+
+ private function removeDisallowedEntries(array $jsonSchema) : array {
+ return Arrays::removeRecursively($jsonSchema, [
+ 'x-title',
+ //'description',
+ 'x-php-class',
+ 'additionalProperties',
+ ]);
+ }
}
\ No newline at end of file
diff --git a/src/Features/LLM/Drivers/OpenAICompatibleDriver.php b/src/Features/LLM/Drivers/OpenAICompatibleDriver.php
index 7a367299..adb64e2a 100644
--- a/src/Features/LLM/Drivers/OpenAICompatibleDriver.php
+++ b/src/Features/LLM/Drivers/OpenAICompatibleDriver.php
@@ -1,63 +1,75 @@
- $model ?: $this->config->model,
- 'max_tokens' => $this->config->maxTokens,
- 'messages' => $messages,
- ], $options));
-
- return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
- }
-
- // OVERRIDES - HELPERS ///////////////////////////////////
-
- protected function applyMode(
- array $request,
- Mode $mode,
- array $tools,
- string|array $toolChoice,
- array $responseFormat
- ) : array {
- switch($mode) {
- case Mode::Tools:
- $request['tools'] = $this->removeDisallowedEntries($tools);
- $request['tool_choice'] = 'auto';
- break;
- case Mode::Json:
- $request['response_format'] = $responseFormat;
- break;
- case Mode::JsonSchema:
- $request['response_format'] = [
- 'type' => 'json_object',
- 'schema' => $responseFormat['json_schema']['schema'],
- ];
- break;
- }
- return $request;
- }
-
- protected function removeDisallowedEntries(array $jsonSchema) : array {
- return Arrays::removeRecursively($jsonSchema, [
- 'title',
- 'x-php-class',
- 'additionalProperties',
- ]);
- }
-}
+ $model ?: $this->config->model,
+ 'max_tokens' => $this->config->maxTokens,
+ 'messages' => $this->toNativeMessages($messages),
+ ]), $options);
+
+ return $this->applyMode($request, $mode, $tools, $toolChoice, $responseFormat);
+ }
+
+ // OVERRIDES - HELPERS ///////////////////////////////////
+
+ protected function applyMode(
+ array $request,
+ Mode $mode,
+ array $tools,
+ string|array $toolChoice,
+ array $responseFormat
+ ) : array {
+ switch($mode) {
+ case Mode::Tools:
+ $request['tools'] = $this->removeDisallowedEntries($tools);
+ $request['tool_choice'] = $this->toToolChoice($tools, $toolChoice);
+ break;
+ case Mode::Json:
+ $request['response_format'] = $responseFormat;
+ break;
+ case Mode::JsonSchema:
+ $request['response_format'] = [
+ 'type' => 'json_object',
+ 'schema' => $responseFormat['json_schema']['schema'],
+ ];
+ break;
+ }
+ return $request;
+ }
+
+ protected function removeDisallowedEntries(array $jsonSchema) : array {
+ return Arrays::removeRecursively($jsonSchema, [
+ 'title',
+ 'x-php-class',
+ 'additionalProperties',
+ ]);
+ }
+
+ private function toToolChoice(array $tools, array|string $toolChoice) : array|string {
+ return match(true) {
+ empty($tools) => '',
+ empty($toolChoice) => 'auto',
+ is_array($toolChoice) => [
+ 'type' => 'function',
+ 'name' => $toolChoice['function']['name'],
+ ],
+ default => $toolChoice,
+ };
+ }
+}
diff --git a/src/Features/LLM/Drivers/OpenAIDriver.php b/src/Features/LLM/Drivers/OpenAIDriver.php
index cdb0afd1..9bb4ee45 100644
--- a/src/Features/LLM/Drivers/OpenAIDriver.php
+++ b/src/Features/LLM/Drivers/OpenAIDriver.php
@@ -11,10 +11,11 @@
use Cognesy\Instructor\Features\LLM\Data\LLMConfig;
use Cognesy\Instructor\Features\LLM\Data\LLMResponse;
use Cognesy\Instructor\Features\LLM\Data\PartialLLMResponse;
+use Cognesy\Instructor\Features\LLM\Data\ToolCall;
use Cognesy\Instructor\Features\LLM\Data\ToolCalls;
use Cognesy\Instructor\Features\LLM\Data\Usage;
use Cognesy\Instructor\Features\LLM\InferenceRequest;
-use Cognesy\Instructor\Utils\Json\Json;
+use Cognesy\Instructor\Utils\Str;
class OpenAIDriver implements CanHandleInference
{
@@ -71,11 +72,11 @@ public function getRequestBody(
array $options = [],
Mode $mode = Mode::Text,
) : array {
- $request = array_filter(array_merge([
+ $request = array_merge(array_filter([
'model' => $model ?: $this->config->model,
'max_tokens' => $this->config->maxTokens,
- 'messages' => $messages,
- ], $options));
+ 'messages' => $this->toNativeMessages($messages),
+ ]), $options);
if ($options['stream'] ?? false) {
$request['stream_options']['include_usage'] = true;
@@ -90,7 +91,6 @@ public function toLLMResponse(array $data): ?LLMResponse {
return new LLMResponse(
content: $this->makeContent($data),
responseData: $data,
- toolsData: $this->makeToolsData($data),
finishReason: $data['choices'][0]['finish_reason'] ?? '',
toolCalls: $this->makeToolCalls($data),
usage: $this->makeUsage($data),
@@ -104,6 +104,7 @@ public function toPartialLLMResponse(array|null $data) : ?PartialLLMResponse {
return new PartialLLMResponse(
contentDelta: $this->makeContentDelta($data),
responseData: $data,
+ toolId: $this->makeToolId($data),
toolName: $this->makeToolNameDelta($data),
toolArgs: $this->makeToolArgsDelta($data),
finishReason: $data['choices'][0]['finish_reason'] ?? '',
@@ -165,19 +166,22 @@ private function withCachedContext(InferenceRequest $request): InferenceRequest
private function makeToolCalls(array $data) : ToolCalls {
return ToolCalls::fromArray(array_map(
- callback: fn(array $call) => $call['function'] ?? [],
+ callback: fn(array $call) => $this->makeToolCall($call),
array: $data['choices'][0]['message']['tool_calls'] ?? []
));
}
- private function makeToolsData(array $data) : array {
- return array_map(
- fn($tool) => [
- 'name' => $tool['function']['name'] ?? '',
- 'arguments' => Json::decode($tool['function']['arguments']) ?? '',
- ],
- $data['choices'][0]['message']['tool_calls'] ?? []
- );
+ private function makeToolCall(array $data) : ?ToolCall {
+ if (empty($data)) {
+ return null;
+ }
+ if (!isset($data['function'])) {
+ return null;
+ }
+ if (!isset($data['id'])) {
+ return null;
+ }
+ return ToolCall::fromArray($data['function'])?->withId($data['id']);
}
private function makeContent(array $data): string {
@@ -200,6 +204,10 @@ private function makeContentDelta(array $data): string {
};
}
+ private function makeToolId(array $data) : string {
+ return $data['choices'][0]['delta']['tool_calls'][0]['id'] ?? '';
+ }
+
private function makeToolNameDelta(array $data) : string {
return $data['choices'][0]['delta']['tool_calls'][0]['function']['name'] ?? '';
}
@@ -221,4 +229,39 @@ private function makeUsage(array $data): Usage {
reasoningTokens: $data['usage']['prompt_tokens_details']['reasoning_tokens'] ?? 0,
);
}
+
+ protected function toNativeMessages(array $messages) : array {
+ $list = [];
+ foreach ($messages as $message) {
+ $nativeMessage = $this->mapMessage($message);
+ if (empty($nativeMessage)) {
+ continue;
+ }
+ $list[] = $nativeMessage;
+ }
+ return $list;
+ }
+
+ protected function mapMessage(array $message) : array {
+ return match(true) {
+ ($message['role'] ?? '') === 'assistant' && !empty($message['_metadata']['tool_calls'] ?? []) => $this->toNativeToolCall($message),
+ ($message['role'] ?? '') === 'tool' => $this->toNativeToolResult($message),
+ default => $message,
+ };
+ }
+
+ protected function toNativeToolCall(array $message) : array {
+ return [
+ 'role' => 'assistant',
+ 'tool_calls' => $message['_metadata']['tool_calls'] ?? [],
+ ];
+ }
+
+ protected function toNativeToolResult(array $message) : array {
+ return [
+ 'role' => 'tool',
+ 'tool_call_id' => $message['_metadata']['tool_call_id'] ?? '',
+ 'content' => $message['content'] ?? '',
+ ];
+ }
}
diff --git a/src/Features/LLM/Enums/LLMFinishReason.php b/src/Features/LLM/Enums/LLMFinishReason.php
index 40f2801f..8988c231 100644
--- a/src/Features/LLM/Enums/LLMFinishReason.php
+++ b/src/Features/LLM/Enums/LLMFinishReason.php
@@ -1,38 +1,48 @@
- self::Stop,
- 'COMPLETE' => self::Stop,
- 'stop' => self::Stop,
- 'STOP' => self::Stop,
- 'max_tokens' => self::Length,
- 'MAX_TOKENS' => self::Length,
- 'length' => self::Length,
- 'model_length' => self::Length,
- 'SAFETY' => self::ContentFilter,
- 'RECITATION' => self::ContentFilter,
- 'LANGUAGE' => self::ContentFilter,
- 'BLOCKLIST' => self::ContentFilter,
- 'PROHIBITED_CONTENT' => self::ContentFilter,
- 'SPII' => self::ContentFilter,
- 'error' => self::Error,
- 'MALFORMED_FUNCTION_CALL' => self::Error,
- 'tool_calls' => self::ToolCalls,
- 'FINISH_REASON_UNSPECIFIED' => self::Other,
- 'OTHER' => self::Other,
- default => self::Other,
- };
- }
-}
+ ($this->value === $reason->value),
+ is_string($reason) => ($this->value === $reason),
+ default => false,
+ };
+ }
+
+ public static function fromText(string $text) : LLMFinishReason {
+ return match ($text) {
+ 'BLOCKLIST' => self::ContentFilter,
+ 'COMPLETE' => self::Stop,
+ 'error' => self::Error,
+ 'FINISH_REASON_UNSPECIFIED' => self::Other,
+ 'LANGUAGE' => self::ContentFilter,
+ 'length' => self::Length,
+ 'MALFORMED_FUNCTION_CALL' => self::Error,
+ 'max_tokens' => self::Length,
+ 'MAX_TOKENS' => self::Length,
+ 'model_length' => self::Length,
+ 'OTHER' => self::Other,
+ 'PROHIBITED_CONTENT' => self::ContentFilter,
+ 'RECITATION' => self::ContentFilter,
+ 'SAFETY' => self::ContentFilter,
+ 'SPII' => self::ContentFilter,
+ 'stop' => self::Stop,
+ 'STOP' => self::Stop,
+ 'stop_sequence' => self::Stop,
+ 'TOOL_CALL' => self::ToolCalls,
+ 'tool_calls' => self::ToolCalls,
+ 'tool_use' => self::ToolCalls,
+ default => self::Other,
+ };
+ }
+}
diff --git a/src/Features/LLM/Enums/LLMProviderType.php b/src/Features/LLM/Enums/LLMProviderType.php
index bf9ff8b0..62e18e94 100644
--- a/src/Features/LLM/Enums/LLMProviderType.php
+++ b/src/Features/LLM/Enums/LLMProviderType.php
@@ -1,28 +1,29 @@
-value === $clientType->value;
- }
-}
+value === $clientType->value;
+ }
+}
diff --git a/src/Features/LLM/Inference.php b/src/Features/LLM/Inference.php
index 7c6d5e7b..7103ff7b 100644
--- a/src/Features/LLM/Inference.php
+++ b/src/Features/LLM/Inference.php
@@ -14,6 +14,7 @@
use Cognesy\Instructor\Features\LLM\Drivers\CohereV1Driver;
use Cognesy\Instructor\Features\LLM\Drivers\CohereV2Driver;
use Cognesy\Instructor\Features\LLM\Drivers\GeminiDriver;
+use Cognesy\Instructor\Features\LLM\Drivers\GrokDriver;
use Cognesy\Instructor\Features\LLM\Drivers\MistralDriver;
use Cognesy\Instructor\Features\LLM\Drivers\OpenAICompatibleDriver;
use Cognesy\Instructor\Features\LLM\Drivers\OpenAIDriver;
@@ -232,10 +233,11 @@ protected function makeDriver(LLMConfig $config, CanHandleHttp $httpClient): Can
LLMProviderType::CohereV1 => new CohereV1Driver($config, $httpClient, $this->events),
LLMProviderType::CohereV2 => new CohereV2Driver($config, $httpClient, $this->events),
LLMProviderType::Gemini => new GeminiDriver($config, $httpClient, $this->events),
+ LLMProviderType::Grok => new GrokDriver($config, $httpClient, $this->events),
LLMProviderType::Mistral => new MistralDriver($config, $httpClient, $this->events),
LLMProviderType::OpenAI => new OpenAIDriver($config, $httpClient, $this->events),
LLMProviderType::Fireworks,
- LLMProviderType::Grok,
+ LLMProviderType::GeminiOAI,
LLMProviderType::Groq,
LLMProviderType::Ollama,
LLMProviderType::OpenAICompatible,
diff --git a/src/Features/LLM/InferenceRequest.php b/src/Features/LLM/InferenceRequest.php
index e0aa2695..a72fac52 100644
--- a/src/Features/LLM/InferenceRequest.php
+++ b/src/Features/LLM/InferenceRequest.php
@@ -1,81 +1,79 @@
-cachedContext = $cachedContext;
-
- $this->model = $model;
- $this->options = $options;
- $this->mode = $mode;
-
- $this->messages = match(true) {
- is_string($messages) => [['role' => 'user', 'content' => $messages]],
- default => $messages,
- };
-
- if ($mode->is(Mode::Tools)) {
- $this->tools = $tools;
- $this->toolChoice = $toolChoice;
- } elseif ($mode->is(Mode::Json)) {
- $this->responseFormat = [
- 'type' => 'json_object',
- 'schema' => $responseFormat['schema'] ?? [],
- ];
- } elseif ($mode->is(Mode::JsonSchema)) {
- $this->responseFormat = [
- 'type' => 'json_schema',
- 'json_schema' => [
- 'name' => $responseFormat['json_schema']['name'] ?? 'schema',
- 'schema' => $responseFormat['json_schema']['schema'] ?? [],
- 'strict' => $responseFormat['json_schema']['strict'] ?? true,
- ],
- ];
- } elseif ($mode->is([Mode::Text, Mode::MdJson])) {
- $this->tools = [];
- $this->toolChoice = [];
- $this->responseFormat = [];
- }
- }
-
- public function toArray() : array {
- return [
-// 'uuid' => $this->uuid,
-// 'metadata' => $this->metadata,
- 'messages' => $this->messages,
- 'model' => $this->model,
- 'tools' => $this->tools,
- 'tool_choice' => $this->toolChoice,
- 'response_format' => $this->responseFormat,
- 'options' => $this->options,
- 'mode' => $this->mode->value,
- ];
- }
-}
+cachedContext = $cachedContext;
+
+ $this->model = $model;
+ $this->options = $options;
+ $this->mode = $mode;
+
+ $this->messages = match(true) {
+ is_string($messages) => [['role' => 'user', 'content' => $messages]],
+ default => $messages,
+ };
+
+ if ($mode->is(Mode::Tools)) {
+ $this->tools = $tools;
+ $this->toolChoice = $toolChoice;
+ } elseif ($mode->is(Mode::Json)) {
+ $this->responseFormat = [
+ 'type' => 'json_object',
+ 'schema' => $responseFormat['schema'] ?? [],
+ ];
+ } elseif ($mode->is(Mode::JsonSchema)) {
+ $this->responseFormat = [
+ 'type' => 'json_schema',
+ 'json_schema' => [
+ 'name' => $responseFormat['json_schema']['name'] ?? 'schema',
+ 'schema' => $responseFormat['json_schema']['schema'] ?? [],
+ 'strict' => $responseFormat['json_schema']['strict'] ?? true,
+ ],
+ ];
+ } elseif ($mode->is([Mode::Text, Mode::MdJson])) {
+ $this->tools = [];
+ $this->toolChoice = [];
+ $this->responseFormat = [];
+ }
+ }
+
+ public function toArray() : array {
+ return [
+// 'uuid' => $this->uuid,
+// 'metadata' => $this->metadata,
+ 'messages' => $this->messages,
+ 'model' => $this->model,
+ 'tools' => $this->tools,
+ 'tool_choice' => $this->toolChoice,
+ 'response_format' => $this->responseFormat,
+ 'options' => $this->options,
+ 'mode' => $this->mode->value,
+ ];
+ }
+}
diff --git a/src/Utils/Json/PartialJsonParser.php b/src/Utils/Json/PartialJsonParser.php
index fcc70546..fa95bbe8 100644
--- a/src/Utils/Json/PartialJsonParser.php
+++ b/src/Utils/Json/PartialJsonParser.php
@@ -1,204 +1,205 @@
-parsers = array_fill_keys([' ', "\r", "\n", "\t"], $this->parseSpace(...));
- $this->parsers['['] = $this->parseArray(...);
- $this->parsers['{'] = $this->parseObject(...);
- $this->parsers['"'] = $this->parseString(...);
- $this->parsers['t'] = $this->parseTrue(...);
- $this->parsers['f'] = $this->parseFalse(...);
- $this->parsers['n'] = $this->parseNull(...);
-
- foreach (str_split('0123456789.-') as $char) {
- $this->parsers[$char] = $this->parseNumber(...);
- }
-
- $this->onExtraToken = function ($text, $data, $reminding) {
- // we're just skipping extra tokens
- //$message = 'Parsed JSON with extra tokens: ' . Json::encode(['text' => $text, 'data' => $data, 'reminding' => $reminding]);
- //throw new JSONParsingException($message, $text);
- };
- }
-
- public function fix(string $partialJson) : string {
- return Json::encode($this->parse($partialJson));
- }
-
- public function parse(string $json, bool $associative = true) : array|object
- {
- if (strlen($json) >= 1) {
- try {
- return json_decode($json, $associative, 512, JSON_THROW_ON_ERROR);
- } catch (JsonException $e) {
- list($data, $reminding) = $this->parseAny($json, $e);
- $this->lastParseReminding = $reminding;
- if ($this->onExtraToken && $reminding && !$this->skipExtraTokens) {
- ($this->onExtraToken)($json, $data, $reminding);
- }
- return match(true) {
- is_string($data) => [],
- default => $data,
- };
- } catch (Exception $e) {
- throw new JsonParsingException('Unexpected error: ' . $e->getMessage());
- }
- } else {
- return json_decode('{}', $associative);
- }
- }
-
- public function reminder() : string {
- return $this->lastParseReminding;
- }
-
- private function parseAny(string $json, Throwable $e) : array|object {
- if (!$json) {
- throw $e;
- }
- $parser = $this->parsers[$json[0]] ?? null;
- if (!$parser) {
- throw $e;
- }
- return $parser($json, $e);
- }
-
- private function parseSpace(string $json, Throwable $e) : array|object {
- return $this->parseAny(trim($json), $e);
- }
-
- private function parseArray(string $json, Throwable $e) : array {
- $json = substr($json, 1); // skip starting '['
- $acc = [];
- $json = trim($json);
- while ($json) {
- if ($json[0] === ']') {
- $json = substr($json, 1); // skip ending ']'
- break;
- }
- list($res, $json) = $this->parseAny($json, $e);
- $acc[] = $res;
- $json = trim($json);
- if (strpos($json, ',') === 0) {
- $json = substr($json, 1);
- $json = trim($json);
- }
- }
- return [$acc, $json];
- }
-
- private function parseObject(string $json, Throwable $e) : array {
- $json = substr($json, 1); // skip starting '{'
- $acc = [];
- $json = trim($json);
- while ($json) {
- if ($json[0] === '}') {
- $json = substr($json, 1); // skip ending '}'
- break;
- }
- list($key, $json) = $this->parseAny($json, $e);
- $json = trim($json);
-
- if (!$json || $json[0] === '}') {
- $acc[$key] = null;
- break;
- }
-
- if ($json[0] !== ':') {
- throw $e;
- }
-
- $json = substr($json, 1); // skip ':'
- $json = trim($json);
-
- if (!$json || in_array($json[0], [',', '}'])) {
- $acc[$key] = null;
- if (strpos($json, ',') === 0) {
- $json = substr($json, 1);
- }
- break;
- }
-
- list($value, $json) = $this->parseAny($json, $e);
- $acc[$key] = $value;
- $json = trim($json);
- if (strpos($json, ',') === 0) {
- $json = substr($json, 1);
- $json = trim($json);
- }
- }
- return [$acc, $json];
- }
-
- private function parseString(string $json, Throwable $e) : array {
- $end = strpos($json, '"', 1);
- while ($end !== false && $json[$end - 1] === '\\') { // Handle escaped quotes
- $end = strpos($json, '"', $end + 1);
- }
- if ($end === false) {
- // Return the incomplete string without the opening quote
- return [substr($json, 1), ""];
- }
- $strVal = substr($json, 0, $end + 1);
- $json = substr($json, $end + 1);
- return [json_decode($strVal), $json];
- }
-
- private function parseNumber(string $json, Throwable $e) : array {
- $i = 0;
- while ($i < strlen($json) && strpos('0123456789.-', $json[$i]) !== false) {
- $i++;
- }
- $numStr = substr($json, 0, $i);
- $json = substr($json, $i);
- if ($numStr == '' || substr($numStr, -1) === '.' || substr($numStr, -1) === '-') {
- // Return the incomplete number as is
- return [$numStr, ""];
- }
- if (strpos($numStr, '.') !== false || strpos($numStr, 'e') !== false || strpos($numStr, 'E') !== false) {
- $num = (float) $numStr;
- } else {
- $num = (int) $numStr;
- }
- return [$num, $json];
- }
-
- private function parseTrue(string $json, Throwable $e) : array {
- if (substr($json, 0, 4) === 'true') {
- return [true, substr($json, 4)];
- }
- throw $e;
- }
-
- private function parseFalse(string $json, Throwable $e) : array {
- if (substr($json, 0, 5) === 'false') {
- return [false, substr($json, 5)];
- }
- throw $e;
- }
-
- private function parseNull(string $json, Throwable $e) : array {
- if (substr($json, 0, 4) === 'null') {
- return [null, substr($json, 4)];
- }
- throw $e;
- }
+parsers = array_fill_keys([' ', "\r", "\n", "\t"], $this->parseSpace(...));
+ $this->parsers['['] = $this->parseArray(...);
+ $this->parsers['{'] = $this->parseObject(...);
+ $this->parsers['"'] = $this->parseString(...);
+ $this->parsers['t'] = $this->parseTrue(...);
+ $this->parsers['f'] = $this->parseFalse(...);
+ $this->parsers['n'] = $this->parseNull(...);
+
+ foreach (str_split('0123456789.-') as $char) {
+ $this->parsers[$char] = $this->parseNumber(...);
+ }
+
+ $this->onExtraToken = function ($text, $data, $reminding) {
+ // we're just skipping extra tokens
+
+ //$message = 'Parsed JSON with extra tokens: ' . Json::encode(['text' => $text, 'data' => $data, 'reminding' => $reminding]);
+ //throw new JSONParsingException($message, $text);
+ };
+ }
+
+ public function fix(string $partialJson) : string {
+ return Json::encode($this->parse($partialJson));
+ }
+
+ public function parse(string $json, bool $associative = true) : array|object
+ {
+ if (strlen($json) >= 1) {
+ try {
+ return json_decode($json, $associative, 512, JSON_THROW_ON_ERROR);
+ } catch (JsonException $e) {
+ list($data, $reminding) = $this->parseAny($json, $e);
+ $this->lastParseReminding = $reminding;
+ if ($this->onExtraToken && $reminding && !$this->skipExtraTokens) {
+ ($this->onExtraToken)($json, $data, $reminding);
+ }
+ return match(true) {
+ is_string($data) => [],
+ default => $data,
+ };
+ } catch (Exception $e) {
+ throw new JsonParsingException('Unexpected error: ' . $e->getMessage());
+ }
+ } else {
+ return json_decode('{}', $associative);
+ }
+ }
+
+ public function reminder() : string {
+ return $this->lastParseReminding;
+ }
+
+ private function parseAny(string $json, Throwable $e) : array|object {
+ if (!$json) {
+ throw $e;
+ }
+ $parser = $this->parsers[$json[0]] ?? null;
+ if (!$parser) {
+ throw $e;
+ }
+ return $parser($json, $e);
+ }
+
+ private function parseSpace(string $json, Throwable $e) : array|object {
+ return $this->parseAny(trim($json), $e);
+ }
+
+ private function parseArray(string $json, Throwable $e) : array {
+ $json = substr($json, 1); // skip starting '['
+ $acc = [];
+ $json = trim($json);
+ while ($json) {
+ if ($json[0] === ']') {
+ $json = substr($json, 1); // skip ending ']'
+ break;
+ }
+ list($res, $json) = $this->parseAny($json, $e);
+ $acc[] = $res;
+ $json = trim($json);
+ if (strpos($json, ',') === 0) {
+ $json = substr($json, 1);
+ $json = trim($json);
+ }
+ }
+ return [$acc, $json];
+ }
+
+ private function parseObject(string $json, Throwable $e) : array {
+ $json = substr($json, 1); // skip starting '{'
+ $acc = [];
+ $json = trim($json);
+ while ($json) {
+ if ($json[0] === '}') {
+ $json = substr($json, 1); // skip ending '}'
+ break;
+ }
+ list($key, $json) = $this->parseAny($json, $e);
+ $json = trim($json);
+
+ if (!$json || $json[0] === '}') {
+ $acc[$key] = null;
+ break;
+ }
+
+ if ($json[0] !== ':') {
+ throw $e;
+ }
+
+ $json = substr($json, 1); // skip ':'
+ $json = trim($json);
+
+ if (!$json || in_array($json[0], [',', '}'])) {
+ $acc[$key] = null;
+ if (strpos($json, ',') === 0) {
+ $json = substr($json, 1);
+ }
+ break;
+ }
+
+ list($value, $json) = $this->parseAny($json, $e);
+ $acc[$key] = $value;
+ $json = trim($json);
+ if (strpos($json, ',') === 0) {
+ $json = substr($json, 1);
+ $json = trim($json);
+ }
+ }
+ return [$acc, $json];
+ }
+
+ private function parseString(string $json, Throwable $e) : array {
+ $end = strpos($json, '"', 1);
+ while ($end !== false && $json[$end - 1] === '\\') { // Handle escaped quotes
+ $end = strpos($json, '"', $end + 1);
+ }
+ if ($end === false) {
+ // Return the incomplete string without the opening quote
+ return [substr($json, 1), ""];
+ }
+ $strVal = substr($json, 0, $end + 1);
+ $json = substr($json, $end + 1);
+ return [json_decode($strVal), $json];
+ }
+
+ private function parseNumber(string $json, Throwable $e) : array {
+ $i = 0;
+ while ($i < strlen($json) && strpos('0123456789.-', $json[$i]) !== false) {
+ $i++;
+ }
+ $numStr = substr($json, 0, $i);
+ $json = substr($json, $i);
+ if ($numStr == '' || substr($numStr, -1) === '.' || substr($numStr, -1) === '-') {
+ // Return the incomplete number as is
+ return [$numStr, ""];
+ }
+ if (strpos($numStr, '.') !== false || strpos($numStr, 'e') !== false || strpos($numStr, 'E') !== false) {
+ $num = (float) $numStr;
+ } else {
+ $num = (int) $numStr;
+ }
+ return [$num, $json];
+ }
+
+ private function parseTrue(string $json, Throwable $e) : array {
+ if (substr($json, 0, 4) === 'true') {
+ return [true, substr($json, 4)];
+ }
+ throw $e;
+ }
+
+ private function parseFalse(string $json, Throwable $e) : array {
+ if (substr($json, 0, 5) === 'false') {
+ return [false, substr($json, 5)];
+ }
+ throw $e;
+ }
+
+ private function parseNull(string $json, Throwable $e) : array {
+ if (substr($json, 0, 4) === 'null') {
+ return [null, substr($json, 4)];
+ }
+ throw $e;
+ }
}
\ No newline at end of file
diff --git a/src/Utils/Messages/Message.php b/src/Utils/Messages/Message.php
index 1d94ecbc..cd79aaf4 100644
--- a/src/Utils/Messages/Message.php
+++ b/src/Utils/Messages/Message.php
@@ -1,20 +1,35 @@
- $content
- */
- public function __construct(
- public string $role = '',
- public string|array $content = '',
- ) {}
-}
+ $content
+ */
+ public function __construct(
+ string|MessageRole $role = '',
+ string|array|null $content = '',
+ array $metadata = [],
+ ) {
+ $this->role = match(true) {
+ $role instanceof MessageRole => $role->value,
+ ($role === '') => self::DEFAULT_ROLE,
+ default => $role,
+ };
+ $this->content = $content ?? '';
+ $this->metadata = $metadata;
+ }
+}
diff --git a/src/Utils/Messages/Traits/Message/HandlesAccess.php b/src/Utils/Messages/Traits/Message/HandlesAccess.php
index eca87b1e..c17c0035 100644
--- a/src/Utils/Messages/Traits/Message/HandlesAccess.php
+++ b/src/Utils/Messages/Traits/Message/HandlesAccess.php
@@ -1,36 +1,66 @@
-role);
- }
-
- public function content() : string|array {
- return $this->content;
- }
-
- public function isEmpty() : bool {
- return empty($this->content);
- }
-
- public function isNull() : bool {
- return ($this->role === '' && $this->content === '');
- }
-
- public function isComposite() : bool {
- return is_array($this->content);
- }
+role);
+ }
+
+ public function content() : string|array {
+ return $this->content;
+ }
+
+ public function isEmpty() : bool {
+ return empty($this->content) && !$this->hasMeta();
+ }
+
+ public function isNull() : bool {
+ return ($this->role === '' && $this->content === '');
+ }
+
+ public function isComposite() : bool {
+ return is_array($this->content);
+ }
+
+ public function hasMeta(string $key = null) : bool {
+ return match(true) {
+ $key === null => !empty($this->metadata),
+ default => isset($this->metadata[$key]),
+ };
+ }
+
+ public function meta(string $key = null) : mixed {
+ return match(true) {
+ $key === null => $this->metadata,
+ default => $this->metadata[$key] ?? null,
+ };
+ }
+
+ public function metaKeys() : array {
+ return array_keys($this->metadata);
+ }
+
+ public function withMeta(array $metadata) : self {
+ $this->metadata = $metadata;
+ return $this;
+ }
+
+ public function withMetaValue(string $key, mixed $value) : self {
+ $this->metadata[$key] = $value;
+ return $this;
+ }
}
\ No newline at end of file
diff --git a/src/Utils/Messages/Traits/Message/HandlesCreation.php b/src/Utils/Messages/Traits/Message/HandlesCreation.php
index c1ea12a6..6502faaa 100644
--- a/src/Utils/Messages/Traits/Message/HandlesCreation.php
+++ b/src/Utils/Messages/Traits/Message/HandlesCreation.php
@@ -6,7 +6,6 @@
use Cognesy\Instructor\Utils\Messages\Message;
use Cognesy\Instructor\Utils\Messages\Utils\Text;
use Exception;
-use InvalidArgumentException;
trait HandlesCreation
{
@@ -19,10 +18,11 @@ public static function fromString(string $content) : static {
}
public static function fromArray(array $message) : static {
- if (!self::hasRoleAndContent($message)) {
- throw new InvalidArgumentException('Message array must contain "role" and "content" keys');
- }
- return new static($message['role'], $message['content']);
+ return new static(
+ role: $message['role'] ?? 'user',
+ content: $message['content'] ?? '',
+ metadata: $message['_metadata'] ?? [],
+ );
}
public static function fromContent(string $role, string|array $content) : static {
diff --git a/src/Utils/Messages/Traits/Message/HandlesTransformation.php b/src/Utils/Messages/Traits/Message/HandlesTransformation.php
index a053765e..82ec8d30 100644
--- a/src/Utils/Messages/Traits/Message/HandlesTransformation.php
+++ b/src/Utils/Messages/Traits/Message/HandlesTransformation.php
@@ -1,51 +1,46 @@
- $this->role, 'content' => $this->content];
- }
-
- public function toString() : string {
- if (!$this->isComposite()) {
- return $this->content;
- }
- // flatten composite message to text
- $text = '';
- foreach($this->content as $part) {
- if ($part['type'] !== 'text') {
- throw new RuntimeException('Message contains non-text parts and cannot be flattened to text');
- }
- $text .= $part['text'];
- }
- return $text;
- }
-
- public function toRoleString() : string {
- return $this->role . ': ' . $this->toString();
- }
-
- public function toCompositeMessage() : Message {
- return Message::fromArray($this->toCompositeArray());
- }
-
- public function toCompositeArray() : array {
- return match($this->isComposite()) {
- true => [
- 'role' => $this->role,
- 'content' => $this->content,
- ],
- default => [
- 'role' => $this->role,
- 'content' => [[
- 'type' => 'text',
- 'text' => $this->content,
- ]]
- ]
- };
- }
-}
+ $this->role,
+ 'content' => $this->content,
+ '_metadata' => $this->metadata,
+ ]);
+ }
+
+ public function toString() : string {
+ if (!$this->isComposite()) {
+ return $this->content;
+ }
+ // flatten composite message to text
+ $text = '';
+ foreach($this->content as $part) {
+ if ($part['type'] !== 'text') {
+ throw new RuntimeException('Message contains non-text parts and cannot be flattened to text');
+ }
+ $text .= $part['text'];
+ }
+ return $text;
+ }
+
+ public function toRoleString() : string {
+ return $this->role . ': ' . $this->toString();
+ }
+
+ public function toCompositeMessage() : Message {
+ return Message::fromArray([
+ 'role' => $this->role,
+ 'content' => match(true) {
+ $this->isComposite() => $this->content,
+ default => [['type' => 'text', 'text' => $this->content]]
+ },
+ '_metadata' => $this->metadata,
+ ]);
+ }
+}
diff --git a/src/Utils/Messages/Traits/Messages/HandlesConversion.php b/src/Utils/Messages/Traits/Messages/HandlesConversion.php
index fa64f8ad..b4bf3205 100644
--- a/src/Utils/Messages/Traits/Messages/HandlesConversion.php
+++ b/src/Utils/Messages/Traits/Messages/HandlesConversion.php
@@ -1,91 +1,92 @@
-> $messages
- * @return array>
- */
- public static function asPerRoleArray(array $messages) : array {
- if (empty($messages)) {
- return ['role' => 'system', 'content' => ''];
- }
-
- $role = 'user';
- $merged = new Messages();
- $content = [];
- foreach ($messages as $message) {
- if ($role !== $message['role'] || Message::becomesComposite($message)) {
- $merged->appendMessage(new Message(
- role: $role,
- content: implode("\n\n", array_filter($content)),
- ));
- $role = $message['role'];
- $content = [];
-
- if (Message::becomesComposite($message)) {
- $merged->appendMessage($message);
- continue;
- }
- }
- $content[] = $message['content'];
- }
- // append remaining content
- if (!empty($content)) {
- $merged->appendMessage(new Message(
- role: $role,
- content: implode("\n", array_filter($content)), // TODO: see above
- ));
- }
- return $merged->toArray();
- }
-
- public static function asString(
- array $messages,
- string $separator = "\n",
- callable $renderer = null
- ) : string {
- $result = '';
- foreach ($messages as $message) {
- if (empty($message) || !is_array($message) || empty($message['content'])) {
- continue;
- }
- $rendered = match(true) {
- !is_null($renderer) => $renderer($message),
- default => match(true) {
- Message::becomesComposite($message) => throw new RuntimeException('Array contains composite messages, cannot be converted to string.'),
- default => $message['content'] . $separator,
- }
- };
- $result .= $rendered;
- }
- return $result;
- }
-
- /**
- * @return array
- */
- public function toArray() : array {
- $result = [];
- foreach ($this->messages as $message) {
- if ($message->isEmpty()) {
- continue;
- }
- $result[] = $message->toArray();
- }
- return $result;
- }
-
- public function toString(string $separator = "\n") : string {
- if ($this->hasComposites()) {
- throw new RuntimeException('Collection contains composite messages and cannot be converted to string.');
- }
- return self::asString($this->toArray(), $separator);
- }
+> $messages
+ * @return array>
+ */
+ public static function asPerRoleArray(array $messages) : array {
+ if (empty($messages)) {
+ return ['role' => 'user', 'content' => ''];
+ }
+
+ $role = 'user';
+ $merged = new Messages();
+ $content = [];
+ foreach ($messages as $message) {
+ if ($role !== $message['role'] || Message::becomesComposite($message)) {
+ $merged->appendMessage(new Message(
+ role: $role,
+ content: implode("\n\n", array_filter($content)),
+ ));
+ $role = $message['role'];
+ $content = [];
+
+ if (Message::becomesComposite($message)) {
+ $merged->appendMessage($message);
+ continue;
+ }
+ }
+ $content[] = $message['content'];
+ }
+ // append remaining content
+ if (!empty($content)) {
+ $merged->appendMessage(new Message(
+ role: $role,
+ content: implode("\n", array_filter($content)), // TODO: see above
+ metadata: $message['_metadata'],
+ ));
+ }
+ return $merged->toArray();
+ }
+
+ public static function asString(
+ array $messages,
+ string $separator = "\n",
+ callable $renderer = null
+ ) : string {
+ $result = '';
+ foreach ($messages as $message) {
+ if (empty($message) || !is_array($message) || empty($message['content'])) {
+ continue;
+ }
+ $rendered = match(true) {
+ !is_null($renderer) => $renderer($message),
+ default => match(true) {
+ Message::becomesComposite($message) => throw new RuntimeException('Array contains composite messages, cannot be converted to string.'),
+ default => $message['content'] . $separator,
+ }
+ };
+ $result .= $rendered;
+ }
+ return $result;
+ }
+
+ /**
+ * @return array
+ */
+ public function toArray() : array {
+ $result = [];
+ foreach ($this->messages as $message) {
+ if ($message->isEmpty()) {
+ continue;
+ }
+ $result[] = $message->toArray();
+ }
+ return $result;
+ }
+
+ public function toString(string $separator = "\n") : string {
+ if ($this->hasComposites()) {
+ throw new RuntimeException('Collection contains composite messages and cannot be converted to string.');
+ }
+ return self::asString($this->toArray(), $separator);
+ }
}
\ No newline at end of file
diff --git a/src/Utils/Messages/Traits/Messages/HandlesCreation.php b/src/Utils/Messages/Traits/Messages/HandlesCreation.php
index 11e29b21..863534c9 100644
--- a/src/Utils/Messages/Traits/Messages/HandlesCreation.php
+++ b/src/Utils/Messages/Traits/Messages/HandlesCreation.php
@@ -1,91 +1,91 @@
-appendMessage(Message::fromString($content));
- }
-
- /**
- * @param array $messages
- */
- static public function fromArray(array $messages) : Messages {
- $instance = new self();
- foreach ($messages as $message) {
- $instance->messages[] = match(true) {
- is_string($message) => Message::fromString($message),
- Message::hasRoleAndContent($message) => new Message($message['role'], $message['content']),
- default => throw new Exception('Invalid message array - missing role or content keys'),
- };
- }
- return $instance;
- }
-
- /**
- * @param Messages[] $messages
- */
- static public function fromMessages(array|Message|Messages ...$arrayOfMessages) : Messages {
- $instance = new self();
- foreach ($arrayOfMessages as $message) {
- if ($message instanceof Messages) {
- $instance->appendMessages($message);
- } elseif ($message instanceof Message) {
- $instance->appendMessage($message);
- } elseif (is_array($message)) {
- $instance->appendMessage(Message::fromArray($message));
- } else {
- throw new InvalidArgumentException('Invalid type for message');
- }
- }
- return $instance;
- }
-
- public static function fromAnyArray(array $messages) : Messages {
- if (Message::hasRoleAndContent($messages)) {
- return self::fromArray([$messages]);
- }
- $normalized = new self();
- foreach ($messages as $message) {
- $normalized->appendMessage(match(true) {
- is_array($message) => match(true) {
- Message::hasRoleAndContent($message) => new Message($message['role'], $message['content']),
- default => throw new Exception('Invalid message array - missing role or content keys'),
- },
- is_string($message) => new Message('user', $message),
- $message instanceof Message => $message,
- default => throw new Exception('Invalid message type'),
- });
- }
- return $normalized;
- }
-
- public static function fromAny(string|array|Message|Messages $messages) : Messages {
- return match(true) {
- is_string($messages) => self::fromString($messages),
- is_array($messages) => self::fromAnyArray($messages),
- $messages instanceof Message => (new Messages)->appendMessage($messages),
- $messages instanceof Messages => $messages,
- default => throw new Exception('Invalid message type'),
- };
- }
-
- public static function fromInput(string|array|object $input) : static {
- return match(true) {
- $input instanceof Messages => $input,
- $input instanceof CanProvideMessages => $input->toMessages(),
- $input instanceof Message => (new Messages)->appendMessage($input),
- $input instanceof CanProvideMessage => (new Messages)->appendMessage($input->toMessage()),
- default => (new Messages)->appendMessage(new Message('user', Text::fromAny($input))),
- };
- }
-}
+appendMessage(Message::fromString($content));
+ }
+
+ /**
+ * @param array $messages
+ */
+ static public function fromArray(array $messages) : Messages {
+ $instance = new self();
+ foreach ($messages as $message) {
+ $instance->messages[] = match(true) {
+ is_string($message) => Message::fromString($message),
+ Message::hasRoleAndContent($message) => Message::fromArray($message),
+ default => throw new Exception('Invalid message array - missing role or content keys'),
+ };
+ }
+ return $instance;
+ }
+
+ /**
+ * @param Messages[] $messages
+ */
+ static public function fromMessages(array|Message|Messages ...$arrayOfMessages) : Messages {
+ $instance = new self();
+ foreach ($arrayOfMessages as $message) {
+ if ($message instanceof Messages) {
+ $instance->appendMessages($message);
+ } elseif ($message instanceof Message) {
+ $instance->appendMessage($message);
+ } elseif (is_array($message)) {
+ $instance->appendMessage(Message::fromArray($message));
+ } else {
+ throw new InvalidArgumentException('Invalid type for message');
+ }
+ }
+ return $instance;
+ }
+
+ public static function fromAnyArray(array $messages) : Messages {
+ if (Message::hasRoleAndContent($messages)) {
+ return self::fromArray([$messages]);
+ }
+ $normalized = new self();
+ foreach ($messages as $message) {
+ $normalized->appendMessage(match(true) {
+ is_array($message) => match(true) {
+ Message::hasRoleAndContent($message) => Message::fromArray($message),
+ default => throw new Exception('Invalid message array - missing role or content keys'),
+ },
+ is_string($message) => new Message('user', $message),
+ $message instanceof Message => $message,
+ default => throw new Exception('Invalid message type'),
+ });
+ }
+ return $normalized;
+ }
+
+ public static function fromAny(string|array|Message|Messages $messages) : Messages {
+ return match(true) {
+ is_string($messages) => self::fromString($messages),
+ is_array($messages) => self::fromAnyArray($messages),
+ $messages instanceof Message => (new Messages)->appendMessage($messages),
+ $messages instanceof Messages => $messages,
+ default => throw new Exception('Invalid message type'),
+ };
+ }
+
+ public static function fromInput(string|array|object $input) : static {
+ return match(true) {
+ $input instanceof Messages => $input,
+ $input instanceof CanProvideMessages => $input->toMessages(),
+ $input instanceof Message => (new Messages)->appendMessage($input),
+ $input instanceof CanProvideMessage => (new Messages)->appendMessage($input->toMessage()),
+ default => (new Messages)->appendMessage(new Message('user', Text::fromAny($input))),
+ };
+ }
+}
diff --git a/src/Utils/Str.php b/src/Utils/Str.php
index a7fc5e8f..7e8375aa 100644
--- a/src/Utils/Str.php
+++ b/src/Utils/Str.php
@@ -94,34 +94,34 @@ static public function containsAny(string $haystack, string|array $needles, bool
return false;
}
- public static function startsWith(string $url, string $string) : bool {
- return substr($url, 0, strlen($string)) === $string;
+ public static function startsWith(string $text, string $prefix) : bool {
+ return substr($text, 0, strlen($prefix)) === $prefix;
}
- public static function endsWith(string $url, string $string) : bool {
- return substr($url, -strlen($string)) === $string;
+ public static function endsWith(string $text, string $suffix) : bool {
+ return substr($text, -strlen($suffix)) === $suffix;
}
- public static function between(mixed $url, string $string, string $string1) : string {
- $start = strpos($url, $string);
+ public static function between(mixed $text, string $firstNeedle, string $nextNeedle) : string {
+ $start = strpos($text, $firstNeedle);
if ($start === false) {
return '';
}
- $start += strlen($string);
- $end = strpos($url, $string1, $start);
+ $start += strlen($firstNeedle);
+ $end = strpos($text, $nextNeedle, $start);
if ($end === false) {
return '';
}
- return substr($url, $start, $end - $start);
+ return substr($text, $start, $end - $start);
}
- public static function after(mixed $url, string $string) : string {
- $start = strpos($url, $string);
+ public static function after(mixed $text, string $needle) : string {
+ $start = strpos($text, $needle);
if ($start === false) {
return '';
}
- $start += strlen($string);
- return substr($url, $start);
+ $start += strlen($needle);
+ return substr($text, $start);
}
public static function when(bool $condition, string $onTrue, string $onFalse) : string {
diff --git a/src/Utils/TemplateUtil.php b/src/Utils/TemplateUtil.php
index 48675382..60d620a9 100644
--- a/src/Utils/TemplateUtil.php
+++ b/src/Utils/TemplateUtil.php
@@ -1,139 +1,139 @@
-clearUnknownParams = $clearUnknownParams;
- $this->parameters = $parameters;
- if (empty($parameters)) {
- return;
- }
- // remove keys starting with @ - these are used for section templates
- $filteredParameters = array_filter(
- $parameters,
- fn($key) => substr($key, 0, 1) !== '@',
- ARRAY_FILTER_USE_KEY
- );
- $materializedParameters = $this->materializeParameters($filteredParameters);
- $this->parameterValues = array_values($materializedParameters);
- $this->parameterKeys = array_map(
- fn($key) => $this->varPattern($key),
- array_keys($materializedParameters)
- );
- }
-
- public function getParameters() : array {
- return $this->parameters;
- }
-
- public static function cleanVarMarkers(string $template) : string {
- return str_replace(['<|', '|>'], '', $template);
- }
-
- public static function render(
- string $template,
- array $parameters,
- bool $clearUnknownParams = true,
- ) : string {
- return (new TemplateUtil(
- $parameters,
- $clearUnknownParams
- ))->renderString($template);
- }
-
- public function renderString(string $template): string {
- // find all keys in the template
- $keys = $this->findVars($template);
- if ($this->clearUnknownParams) {
- // find keys missing from $this->keys
- $missingKeys = array_diff($keys, $this->parameterKeys);
- // remove missing key strings from the template
- $template = str_replace($missingKeys, '', $template);
- }
- // render values
- return str_replace($this->parameterKeys, $this->parameterValues, $template);
- }
-
- public function renderArray(
- array $rows,
- string $field = 'content'
- ): array {
- return array_map(
- fn($item) => $this->renderString($item[$field] ?? ''),
- $rows
- );
- }
-
- public function renderMessage(array|Message $message) : array {
- $normalized = match(true) {
- is_array($message) => Message::fromArray($message),
- $message instanceof Message => $message,
- default => throw new InvalidArgumentException('Invalid message type'),
- };
-
- // skip rendering if content is an array - it may contain non-text data
- if (is_array($normalized->content)) {
- return ['role' => $normalized->role, 'content' => $normalized->content];
- }
-
- return ['role' => $normalized->role, 'content' => $this->renderString($normalized->content)];
- }
-
- public function renderMessages(array|Messages $messages) : array {
- return array_map(
- fn($message) => $this->renderMessage($message),
- is_array($messages) ? $messages : $messages->toArray()
- );
- }
-
- // OVERRIDEABLE //////////////////////////////////////////////////////////////
-
- protected function varPattern(string $key) : string {
- return '<|' . $key . '|>';
- }
-
- protected function findVars(string $template) : array {
- $matches = [];
- // replace {xxx} pattern with <|xxx|> pattern match
- preg_match_all('/<\|([^|]+)\|>/', $template, $matches);
- return $matches[0];
- }
-
- // INTERNAL //////////////////////////////////////////////////////////////////
-
- private function materializeParameters(array $parameters) : array {
- // TODO: is there a way to consolidate value rendering?
- $parameterValues = [];
- foreach ($parameters as $key => $value) {
- $value = match (true) {
- is_scalar($value) => $value,
- is_array($value) => Json::encode($value),
- is_callable($value) => $value($key, $parameters),
- is_object($value) && method_exists($value, 'toString') => $value->toString(),
- is_object($value) && method_exists($value, 'toJson') => $value->toJson(),
- is_object($value) && method_exists($value, 'toArray') => Json::encode($value->toArray()),
- is_object($value) && method_exists($value, 'toSchema') => Json::encode($value->toSchema()),
- is_object($value) && method_exists($value, 'toOutputSchema') => Json::encode($value->toOutputSchema()),
- is_object($value) && property_exists($value, 'value') => $value->value(),
- is_object($value) => Json::encode($value),
- default => $value,
- };
- $parameterValues[$key] = $value;
- }
- return $parameterValues;
- }
-}
+clearUnknownParams = $clearUnknownParams;
+ $this->parameters = $parameters;
+ if (empty($parameters)) {
+ return;
+ }
+ // remove keys starting with @ - these are used for section templates
+ $filteredParameters = array_filter(
+ $parameters,
+ fn($key) => substr($key, 0, 1) !== '@',
+ ARRAY_FILTER_USE_KEY
+ );
+ $materializedParameters = $this->materializeParameters($filteredParameters);
+ $this->parameterValues = array_values($materializedParameters);
+ $this->parameterKeys = array_map(
+ fn($key) => $this->varPattern($key),
+ array_keys($materializedParameters)
+ );
+ }
+
+ public function getParameters() : array {
+ return $this->parameters;
+ }
+
+ public static function cleanVarMarkers(string $template) : string {
+ return str_replace(['<|', '|>'], '', $template);
+ }
+
+ public static function render(
+ string $template,
+ array $parameters,
+ bool $clearUnknownParams = true,
+ ) : string {
+ return (new TemplateUtil(
+ $parameters,
+ $clearUnknownParams
+ ))->renderString($template);
+ }
+
+ public function renderString(string $template): string {
+ // find all keys in the template
+ $keys = $this->findVars($template);
+ if ($this->clearUnknownParams) {
+ // find keys missing from $this->keys
+ $missingKeys = array_diff($keys, $this->parameterKeys);
+ // remove missing key strings from the template
+ $template = str_replace($missingKeys, '', $template);
+ }
+ // render values
+ return str_replace($this->parameterKeys, $this->parameterValues, $template);
+ }
+
+ public function renderArray(
+ array $rows,
+ string $field = 'content'
+ ): array {
+ return array_map(
+ fn($item) => $this->renderString($item[$field] ?? ''),
+ $rows
+ );
+ }
+
+ public function renderMessage(array|Message $message) : array {
+ $normalized = match(true) {
+ is_array($message) => Message::fromArray($message),
+ $message instanceof Message => $message,
+ default => throw new InvalidArgumentException('Invalid message type'),
+ };
+
+ // skip rendering if content is an array - it may contain non-text data
+ if (is_array($normalized->content())) {
+ return ['role' => $normalized->role(), 'content' => $normalized->content()];
+ }
+
+ return ['role' => $normalized->role(), 'content' => $this->renderString($normalized->content())];
+ }
+
+ public function renderMessages(array|Messages $messages) : array {
+ return array_map(
+ fn($message) => $this->renderMessage($message),
+ is_array($messages) ? $messages : $messages->toArray()
+ );
+ }
+
+ // OVERRIDEABLE //////////////////////////////////////////////////////////////
+
+ protected function varPattern(string $key) : string {
+ return '<|' . $key . '|>';
+ }
+
+ protected function findVars(string $template) : array {
+ $matches = [];
+ // replace {xxx} pattern with <|xxx|> pattern match
+ preg_match_all('/<\|([^|]+)\|>/', $template, $matches);
+ return $matches[0];
+ }
+
+ // INTERNAL //////////////////////////////////////////////////////////////////
+
+ private function materializeParameters(array $parameters) : array {
+ // TODO: is there a way to consolidate value rendering?
+ $parameterValues = [];
+ foreach ($parameters as $key => $value) {
+ $value = match (true) {
+ is_scalar($value) => $value,
+ is_array($value) => Json::encode($value),
+ is_callable($value) => $value($key, $parameters),
+ is_object($value) && method_exists($value, 'toString') => $value->toString(),
+ is_object($value) && method_exists($value, 'toJson') => $value->toJson(),
+ is_object($value) && method_exists($value, 'toArray') => Json::encode($value->toArray()),
+ is_object($value) && method_exists($value, 'toSchema') => Json::encode($value->toSchema()),
+ is_object($value) && method_exists($value, 'toOutputSchema') => Json::encode($value->toOutputSchema()),
+ is_object($value) && property_exists($value, 'value') => $value->value(),
+ is_object($value) => Json::encode($value),
+ default => $value,
+ };
+ $parameterValues[$key] = $value;
+ }
+ return $parameterValues;
+ }
+}
diff --git a/src/Utils/Web/Webpage.php b/src/Utils/Web/Webpage.php
index 7d22a35e..777dd2ed 100644
--- a/src/Utils/Web/Webpage.php
+++ b/src/Utils/Web/Webpage.php
@@ -1,34 +1,34 @@
-scraper = $scraper ?? new BasicReader();
- $this->htmlProcessor = new HtmlProcessor();
- }
-
- public function toMessage(): Message {
- return new Message(content: $this->asMarkdown());
- }
-}
+scraper = $scraper ?? new BasicReader();
+ $this->htmlProcessor = new HtmlProcessor();
+ }
+
+ public function toMessage(): Message {
+ return new Message(content: $this->asMarkdown());
+ }
+}