From b49c9df6ef7a72e91cc66137a2166e3236f2cd7d Mon Sep 17 00:00:00 2001 From: ddebowczyk Date: Fri, 7 Jun 2024 13:30:41 +0200 Subject: [PATCH] Implemented Scripts for better control of controlling chat sequences generated by Instructor --- examples/01_Basics/ValidationMixin/run.php | 8 +- examples/02_Advanced/CustomPrompts/run.php | 4 +- examples/02_Advanced/LanguagePrograms/run.php | 2 +- .../02_Advanced/LanguagePrograms2/run.php | 19 +- examples/05_APISupport/LLMSupportGroq/run.php | 8 +- notes/NOTES.md | 31 +- src/Core/MessageBuilder.php | 281 ------------------ src/Core/Messages/Enums/MessageRole.php | 20 ++ src/Core/Messages/Message.php | 20 ++ src/Core/Messages/Messages.php | 12 + src/Core/Messages/Script.php | 18 ++ src/Core/Messages/Section.php | 33 ++ .../Messages/Traits/Message/HandlesAccess.php | 25 ++ .../Traits/Message/HandlesTransformation.php | 13 + .../Traits/Messages/HandlesAccess.php | 79 +++++ .../Traits/Messages/HandlesCreation.php | 41 +++ .../Traits/Messages/HandlesMutation.php | 50 ++++ .../Traits/Messages/HandlesTransformation.php | 30 ++ src/Core/Messages/Traits/RendersTemplates.php | 46 +++ .../Messages/Traits/Script/HandlesAccess.php | 52 ++++ .../Messages/Traits/Script/HandlesContext.php | 15 + .../Traits/Script/HandlesMutation.php | 51 ++++ .../Traits/Script/HandlesReordering.php | 54 ++++ .../Traits/Script/HandlesTransformation.php | 69 +++++ .../Messages/Traits/Section/HandlesAccess.php | 23 ++ .../Traits/Section/HandlesMutation.php | 21 ++ .../Traits/Section/HandlesTransformation.php | 56 ++++ src/Core/Messages/Utils/ChatFormat.php | 37 +++ src/Core/Messages/Utils/MessageBuilder.php | 58 ++++ .../MessageBuilder/HandlesApiProviders.php | 78 +++++ .../Utils/MessageBuilder/MakesMessages.php | 65 ++++ src/Data/Example.php | 18 +- src/Data/Request.php | 19 -- .../Request/HandlesApiRequestFactory.php | 3 +- src/Data/Traits/Request/HandlesPrompts.php | 12 +- src/Extras/Module/Addons/Predict/Predict.php | 25 +- src/Utils/Arrays.php | 9 + 
src/Utils/Template.php | 97 +++++- tests/Feature/Script/ScriptTest.php | 174 +++++++++++ 39 files changed, 1327 insertions(+), 349 deletions(-) delete mode 100644 src/Core/MessageBuilder.php create mode 100644 src/Core/Messages/Enums/MessageRole.php create mode 100644 src/Core/Messages/Message.php create mode 100644 src/Core/Messages/Messages.php create mode 100644 src/Core/Messages/Script.php create mode 100644 src/Core/Messages/Section.php create mode 100644 src/Core/Messages/Traits/Message/HandlesAccess.php create mode 100644 src/Core/Messages/Traits/Message/HandlesTransformation.php create mode 100644 src/Core/Messages/Traits/Messages/HandlesAccess.php create mode 100644 src/Core/Messages/Traits/Messages/HandlesCreation.php create mode 100644 src/Core/Messages/Traits/Messages/HandlesMutation.php create mode 100644 src/Core/Messages/Traits/Messages/HandlesTransformation.php create mode 100644 src/Core/Messages/Traits/RendersTemplates.php create mode 100644 src/Core/Messages/Traits/Script/HandlesAccess.php create mode 100644 src/Core/Messages/Traits/Script/HandlesContext.php create mode 100644 src/Core/Messages/Traits/Script/HandlesMutation.php create mode 100644 src/Core/Messages/Traits/Script/HandlesReordering.php create mode 100644 src/Core/Messages/Traits/Script/HandlesTransformation.php create mode 100644 src/Core/Messages/Traits/Section/HandlesAccess.php create mode 100644 src/Core/Messages/Traits/Section/HandlesMutation.php create mode 100644 src/Core/Messages/Traits/Section/HandlesTransformation.php create mode 100644 src/Core/Messages/Utils/ChatFormat.php create mode 100644 src/Core/Messages/Utils/MessageBuilder.php create mode 100644 src/Core/Messages/Utils/MessageBuilder/HandlesApiProviders.php create mode 100644 src/Core/Messages/Utils/MessageBuilder/MakesMessages.php create mode 100644 tests/Feature/Script/ScriptTest.php diff --git a/examples/01_Basics/ValidationMixin/run.php b/examples/01_Basics/ValidationMixin/run.php index db7fba46..e403874f 
100644 --- a/examples/01_Basics/ValidationMixin/run.php +++ b/examples/01_Basics/ValidationMixin/run.php @@ -16,7 +16,6 @@ $loader = require 'vendor/autoload.php'; $loader->add('Cognesy\\Instructor\\', __DIR__.'../../src/'); -use Cognesy\Instructor\Events\Request\RequestSentToLLM; use Cognesy\Instructor\Instructor; use Cognesy\Instructor\Validation\Traits\ValidationMixin; use Cognesy\Instructor\Validation\ValidationResult; @@ -41,14 +40,15 @@ public function validate() : ValidationResult { } } -$user = (new Instructor)->respond( - messages: [['role' => 'user', 'content' => 'Jason was born in 1990 and graduated in 1010.']], +$user = (new Instructor)->wiretap(fn($e)=>$e->printDump())->respond( + messages: [['role' => 'user', 'content' => 'Jason was born in 2000 and graduated in 1923.']], responseModel: UserDetails::class, + model: 'gpt-3.5-turbo', maxRetries: 2 ); dump($user); -assert($user->graduationYear === 2010); +assert($user->graduationYear === 2023); ?> ``` \ No newline at end of file diff --git a/examples/02_Advanced/CustomPrompts/run.php b/examples/02_Advanced/CustomPrompts/run.php index 6f4bd880..4bddfc49 100644 --- a/examples/02_Advanced/CustomPrompts/run.php +++ b/examples/02_Advanced/CustomPrompts/run.php @@ -39,7 +39,7 @@ class User { ->respond( messages: "Our user Jason is 25 years old.", responseModel: User::class, - prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n{json_schema}\n", + prompt: "\nYour task is to respond correctly with JSON object. Response must follow JSONSchema:\n<|json_schema|>\n", mode: Mode::Json); print("\n# Request for Mode::MdJson:\n\n"); @@ -47,7 +47,7 @@ class User { ->respond( messages: "Our user Jason is 25 years old.", responseModel: User::class, - prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. 
Object must validate against this JSONSchema:\n{json_schema}\n", + prompt: "\nYour task is to respond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n", mode: Mode::MdJson); ?> diff --git a/examples/02_Advanced/LanguagePrograms/run.php b/examples/02_Advanced/LanguagePrograms/run.php index 728a8b5e..4175299f 100644 --- a/examples/02_Advanced/LanguagePrograms/run.php +++ b/examples/02_Advanced/LanguagePrograms/run.php @@ -1,4 +1,4 @@ -# Language programs +# Language programs with Modules Instructor provides an addon allowing to implement complex processing flows using LLM in a modular way. This addon to Instructor has been inspired by DSPy diff --git a/examples/02_Advanced/LanguagePrograms2/run.php b/examples/02_Advanced/LanguagePrograms2/run.php index 5739beaf..5cf9a43e 100644 --- a/examples/02_Advanced/LanguagePrograms2/run.php +++ b/examples/02_Advanced/LanguagePrograms2/run.php @@ -1,4 +1,4 @@ -# Language programs +# 'Structure to structure' LLM processing Instructor provides an addon allowing to implement complex processing flows using LLM in a modular way. This addon to Instructor has been inspired by DSPy @@ -25,6 +25,8 @@ ```php add('Cognesy\\Instructor\\', __DIR__ . 
'../../src/'); @@ -42,9 +45,10 @@ //#[Description('extract email details from text')] class ParsedEmail extends SignatureData { + // INPUTS #[InputField('text containing email')] public string $text; - + // OUTPUTS #[OutputField('email address of sender')] public string $senderEmail; #[OutputField('subject of the email')] @@ -54,28 +58,29 @@ class ParsedEmail extends SignatureData { } class FixedEmail extends SignatureData { + // INPUTS #[InputField('subject of the email')] public string $subject; #[InputField('body of the email')] public string $body; - + // OUTPUTS #[OutputField('subject of the email with fixed spelling mistakes')] public string $fixedSubject; #[OutputField('body of the email with fixed spelling mistakes')] public string $fixedBody; } -// Alternative way to define the module signature data without extending a class +// Alternative way to define the class signature data without extending a class class EmailTranslation implements HasInputOutputData, CanProvideSchema { use AutoSignature; - + // INPUTS #[InputField('subject of email')] public string $subject; #[InputField('body of email')] public string $body; #[InputField('language to translate to')] public string $language; - + // OUTPUTS #[OutputField('translated subject of email')] public string $translatedSubject; #[OutputField('translated body of email')] @@ -106,7 +111,7 @@ class ProcessEmail extends Module { private Predict $translate; public function __construct() { - $instructor = new Instructor(); + $instructor = (new Instructor);//->withClient(new AnthropicClient(Env::get('ANTHROPIC_API_KEY')));//->wiretap(fn($e) => $e->printDump()); $this->parse = new Predict(signature: ParsedEmail::class, instructor: $instructor); $this->fix = new Predict(signature: FixedEmail::class, instructor: $instructor); diff --git a/examples/05_APISupport/LLMSupportGroq/run.php b/examples/05_APISupport/LLMSupportGroq/run.php index 1590f2d2..00314f14 100644 --- a/examples/05_APISupport/LLMSupportGroq/run.php +++ 
b/examples/05_APISupport/LLMSupportGroq/run.php @@ -21,7 +21,7 @@ enum UserType : string { } class User { - public int $age; + public ?int $age; public string $name; public string $username; public UserType $role; @@ -45,8 +45,12 @@ class User { messages: "Jason (@jxnlco) is 25 years old and is the admin of this project. He likes playing football and reading books.", responseModel: User::class, model: 'llama3-8b-8192', + options: ['stream' => false], + examples: [[ + 'input' => 'Ive got email Frank - their developer. He asked to come back to him frank@hk.ch. Btw, he plays on drums!', + 'output' => ['age' => null, 'name' => 'Frank', 'role' => 'developer', 'hobbies' => ['playing drums'],], + ]], mode: Mode::Json, - options: ['stream' => false] ); print("Completed response model:\n\n"); diff --git a/notes/NOTES.md b/notes/NOTES.md index 907f706a..ca1cafdf 100644 --- a/notes/NOTES.md +++ b/notes/NOTES.md @@ -14,8 +14,34 @@ Examples to demonstrate use cases. Catch up with the latest additions. -## Research - +## Brain dump + +- MessageSequence to better process multi-stage chat content +- Extract APIClient to a separate package? +- example of integration with Laravel/Livewire +- DSPy next steps: eval, optimize, compile +- Finish logging support +- Documentation for logging +- Add more modules: XoT, RAG, ReAct, etc. +- RAG - how to handle multiple VDB providers? +- Better error messages +- Finish module observability via events - currently no access to this info & only 2 events supported +- Test validation in modules - provide an example +- Parallel execution of modules (e.g. a la Laravel jobs?) 
+- How to track API rate limits across multiple requests / parallel executions +- Moderation endpoint support +- Make using DocBlocks optional - it may not always be desirable to pass this info to the LLM +- Rework Events so they have toArray() method, make __toString() use it +- Document how to inject custom HTTP client +- Git/GitHub integration module to allow easy automation +- Data mapping module(s) for easier data transformations +- Add super detailed tests of Module core functionality - esp. around input/output mappings +- How to handle dynamic module graph definition + visualization + +## To research + +- Schema.org ld+json // Spatie https://github.com/spatie/schema-org // https://developers.google.com/search/docs/appearance/structured-data?hl=pl +- nette/schema https://github.com/nette/schema - Queue-based load leveling - Throttling - Circuit breaker @@ -29,4 +55,3 @@ Catch up with the latest additions. - Sparse task scheduler - Marker and sweeper - Actor model - diff --git a/src/Core/MessageBuilder.php b/src/Core/MessageBuilder.php deleted file mode 100644 index 764ce292..00000000 --- a/src/Core/MessageBuilder.php +++ /dev/null @@ -1,281 +0,0 @@ -clientClass = $clientClass; - $instance->mode = $mode; - $instance->messages = $messages; - $instance->responseModel = $responseModel; - $instance->dataAcknowledgedPrompt = $dataAcknowledgedPrompt; - $instance->prompt = $prompt; - $instance->examples = $examples; - return $instance->makeExtractionRequest(); - } - - // INTERNAL TOOLS /////////////////////////////////////////////////////////////////// - - private function makeExtractionRequest() : array { - // get body creation method based on client - $builder = $this->getBuilder($this->clientClass); - // get the parts of body specific to the client - $body = $builder(); - // filter out empty values - $body = array_filter($body); - return $body; - } - - private function getBuilder(string $clientClass) : callable { - return match($clientClass) { - 
AnthropicClient::class => fn() => $this->anthropic(), - CohereClient::class => fn() => $this->cohere(), - MistralClient::class => fn() => $this->mistral(), - OpenAIClient::class, - AzureClient::class => fn() => $this->openAI(), - AnyscaleClient::class, - FireworksAIClient::class, - GroqClient::class, - OllamaClient::class, - OpenRouterClient::class, - TogetherAIClient::class => fn() => $this->openAILike(), - default => fn() => [], - }; - } - - private function anthropic() : array { - $body = []; - $messages = $this->makeMessages(); - $body['system'] = $messages->system(); - $body['messages'] = $this->mapToTargetAPI( - type: ClientType::Anthropic->value, - messages: array_filter(array_merge( - $messages->history(), - $messages->command() - )), - ); - return $body; - } - - private function cohere() : array { - $body = []; - $messages = $this->makeMessages(); - $body['preamble'] = $messages->system(); - $body['chat_history'] = $this->mapToTargetAPI( - type: ClientType::Cohere->value, - messages: $messages->history() - ); - $body['message'] = $messages->command(true)[0]['content'] ?? ''; - return array_filter($body); - } - - private function mistral() : array { - $body = []; - $messages = $this->makeMessages(); - $body['messages'] = $this->mapToTargetAPI( - type: ClientType::Mistral->value, - messages: array_filter(array_merge( - empty($messages->system()) ? [] : ['role' => 'system', 'content' => $messages->system()], - $messages->history(), - $messages->command(withSchema: true) - )), - ); - return $body; - } - - private function openAI() : array { - $body = []; - $messages = $this->makeMessages(); - $body['messages'] = $this->mapToTargetAPI( - type: ClientType::OpenAI->value, - messages: array_filter(array_merge( - empty($messages->system()) ? 
[] : ['role' => 'system', 'content' => $messages->system()], - $messages->history(), - $messages->command(withSchema: true) - )), - ); - return $body; - } - - private function openAILike() : array { - $body = []; - $messages = $this->makeMessages(); - $body['messages'] = $this->mapToTargetAPI( - type: ClientType::OpenAICompatible->value, - messages: array_filter(array_merge( - empty($messages->system()) ? [] : ['role' => 'system', 'content' => $messages->system()], - $messages->history(), - $messages->command() - )), - ); - return $body; - } - - /** - * @return object - * @throws Exception - */ - private function makeMessages() : object { - if (empty($this->messages)) { - throw new Exception('Messages cannot be empty - you have to provide the content for processing.'); - } - - return new class( - $this->clientClass, - $this->mode, - $this->messages, - $this->prompt, - $this->examples, - $this->dataAcknowledgedPrompt, - $this->responseModel, - ) { - private array $system = []; - private array $history = []; - - public function __construct( - private string $clientClass, - private Mode $mode, - private array $messages, - private string $prompt, - private array $examples, - private string $dataAcknowledgedPrompt, - private ResponseModel $responseModel, - ) { - $index = 0; - // extract initial system messages - foreach ($this->messages as $message) { - $role = $message['role']; - $content = $message['content']; - if ($role === 'system') { - $this->system[] = $content; - $index++; - } else { - break; - } - } - // extract history - $this->history = array_slice($this->messages, $index); - } - - // BEGINNING MESSAGE - public function system() : string { - return implode("\n", $this->system); - } - - // ORIGINAL MESSAGES PASSED FOR PROCESSING - DATA - public function history() : array { - return array_merge( - $this->ensureProperSequence( - $this->normalize($this->history) - ), - ); - } - - // PROMPT AND EXAMPLES - public function command(bool $withSchema = false) : array { 
- $content = ''; - // PROMPT SECTION - if (!empty($this->prompt)) { - $content = match($withSchema) { - true => Template::render( - $this->prompt, - ['json_schema' => $this->responseModel->toJsonSchema()] - ), - default => str_replace('{json_schema}', '', $this->prompt), - }; - } - // EXAMPLES SECTION - $content .= "\n\n"; - if (!empty($this->examples)) { - foreach ($this->examples as $example) { - $content .= $example->toString() . "\n\n"; - } - } - // MERGE PROMPT AND EXAMPLES INTO SINGLE MESSAGE ENTRY - return [['role' => 'user', 'content' => $content]]; - } - - private function normalize(string|array $messages): array { - if (!is_array($messages)) { - return [['role' => 'user', 'content' => $messages]]; - } - return $messages; - } - - private function ensureProperSequence(string|array $messages): array { - // add user turn if assistant was the first to speak - if ($messages[0]['role'] === 'assistant') { - $messages = array_merge([['role' => 'user', 'content' => 'Continue']], $messages); - } - // add assistant turn if user was the last to speak - if ($messages[count($messages)-1]['role'] === 'user') { - $messages = array_merge($messages, [['role' => 'assistant', 'content' => $this->dataAcknowledgedPrompt]]); - } - return $messages; - } - }; - } - - private function mapToTargetAPI(string $type, array $messages) : array { - if (empty($messages)) { - return []; - } - $roleMap = [ - ClientType::Anthropic->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'assistant', 'tool' => 'user'], - ClientType::Cohere->value => ['user' => 'USER', 'assistant' => 'CHATBOT', 'system' => 'CHATBOT', 'tool' => 'USER'], - ClientType::Mistral->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 'tool' => 'tool'], - ClientType::OpenAI->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 'tool' => 'tool'], - ClientType::OpenAICompatible->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 
'tool' => 'tool'], - ]; - $keyMap = [ - ClientType::Anthropic->value => 'content', - ClientType::Cohere->value => 'message', - ClientType::Mistral->value => 'content', - ClientType::OpenAICompatible->value => 'content', - ClientType::OpenAI->value => 'content', - ]; - $roles = $roleMap[$type]; - $key = $keyMap[$type]; - $normalized = []; - foreach ($messages as $message) { - $normalized[] = ['role' => $roles[$message['role']], $key => $message['content']]; - } - return $normalized; - } -} diff --git a/src/Core/Messages/Enums/MessageRole.php b/src/Core/Messages/Enums/MessageRole.php new file mode 100644 index 00000000..ec20e0ee --- /dev/null +++ b/src/Core/Messages/Enums/MessageRole.php @@ -0,0 +1,20 @@ + self::System, + 'user' => self::User, + 'assistant' => self::Assistant, + 'tool' => self::Tool, + default => self::User, + }; + } +} diff --git a/src/Core/Messages/Message.php b/src/Core/Messages/Message.php new file mode 100644 index 00000000..a9925bae --- /dev/null +++ b/src/Core/Messages/Message.php @@ -0,0 +1,20 @@ + $content + */ + public function __construct( + public string $role = '', + public string|array $content = '', + ) {} + + public function clone() : Message { + return new Message($this->role, $this->content); + } +} diff --git a/src/Core/Messages/Messages.php b/src/Core/Messages/Messages.php new file mode 100644 index 00000000..0e99d2f8 --- /dev/null +++ b/src/Core/Messages/Messages.php @@ -0,0 +1,12 @@ +sections = $sections; + } +} diff --git a/src/Core/Messages/Section.php b/src/Core/Messages/Section.php new file mode 100644 index 00000000..bebd1522 --- /dev/null +++ b/src/Core/Messages/Section.php @@ -0,0 +1,33 @@ +messages = new Messages(); + } +} + +//enum StepType : string { +// case GoalStatement = 'goal'; +// case GoalAcknowledgement = 'goal_ack'; +// case ContentProvision = 'content'; +// case ContentAcknowledgement = 'content_ack'; +// case StopAndThink = 'think'; +// case ContinueCommand = 'continue'; +// case ToolsRequest = 'tools'; +// 
case ToolsResponse = 'tools'; +// case InferenceRequest = 'inference'; +// case AssistantResponse = 'response'; +// case RetryRequest = 'retry'; +// case CustomUserStep = 'custom_user'; +// case CustomAssistantStep = 'custom_assistant'; +//} diff --git a/src/Core/Messages/Traits/Message/HandlesAccess.php b/src/Core/Messages/Traits/Message/HandlesAccess.php new file mode 100644 index 00000000..0a7419d7 --- /dev/null +++ b/src/Core/Messages/Traits/Message/HandlesAccess.php @@ -0,0 +1,25 @@ +role); + } + + public function content() : string|array { + return $this->content; + } + + public function isEmpty() : bool { + return empty($this->content); + } + + public function isNull() : bool { + return ($this->role === '' && $this->content === ''); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Message/HandlesTransformation.php b/src/Core/Messages/Traits/Message/HandlesTransformation.php new file mode 100644 index 00000000..b6e89841 --- /dev/null +++ b/src/Core/Messages/Traits/Message/HandlesTransformation.php @@ -0,0 +1,13 @@ + $this->role, 'content' => $this->content]; + } + + public function toString() : string { + return $this->content; + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Messages/HandlesAccess.php b/src/Core/Messages/Traits/Messages/HandlesAccess.php new file mode 100644 index 00000000..5c69b161 --- /dev/null +++ b/src/Core/Messages/Traits/Messages/HandlesAccess.php @@ -0,0 +1,79 @@ +messages)) { + return new Message(); + } + return $this->messages[0]; + } + + public function last() : Message { + if (empty($this->messages)) { + return new Message(); + } + return $this->messages[count($this->messages)-1]; + } + + public function middle() : Messages { + if (count($this->messages) < 3) { + return new Messages(); + } + return Messages::fromMessages(array_slice($this->messages, 1, count($this->messages)-2)); + } + + public function head() : array { + if (empty($this->messages)) { + return []; + } + return 
array_slice($this->messages, 0, 1); + } + + public function tail() : array { + if (empty($this->messages)) { + return []; + } + return array_slice($this->messages, count($this->messages)-1); + } + + public function isEmpty() : bool { + return empty($this->messages); + } + + public function filter(?callable $callback = null) : Messages { + $messages = new Messages(); + foreach ($this->messages as $message) { + if ($message->isEmpty()) { + continue; + } + if ($callback === null || $callback($message)) { // a null callback keeps every non-empty message instead of failing on a non-callable + $messages->messages[] = $message->clone(); + } + } + return $messages; + } + + // CONVENIENCE METHODS /////////////////////////////////////////////////// + + public function firstRole() : MessageRole { + return $this->first()?->role(); + } + + public function lastRole() : MessageRole { + return $this->last()?->role(); + } + + public function firstContent() : string|array { + return $this->first()?->content(); + } + + public function lastContent() : string|array { + return $this->last()?->content(); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Messages/HandlesCreation.php b/src/Core/Messages/Traits/Messages/HandlesCreation.php new file mode 100644 index 00000000..fe42e4a7 --- /dev/null +++ b/src/Core/Messages/Traits/Messages/HandlesCreation.php @@ -0,0 +1,41 @@ + $messages + */ + static public function fromArray(array $messages) : Messages + { + $instance = new self(); + foreach ($messages as $message) { + $instance->messages[] = new Message($message['role'], $message['content']); + } + return $instance; + } + + /** + * @param Messages[] $messages + */ + static public function fromMessages(array|Message|Messages ...$arrayOfMessages) : Messages { + $instance = new self(); + foreach ($arrayOfMessages as $message) { + if ($message instanceof Messages) { + $instance->appendMessages($message); + } elseif ($message instanceof Message) { + $instance->add($message); + } elseif (is_array($message)) { + $instance->add(new Message($message['role'], $message['content'])); + 
} else { + throw new InvalidArgumentException('Invalid type for message'); + } + } + return $instance; + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Messages/HandlesMutation.php b/src/Core/Messages/Traits/Messages/HandlesMutation.php new file mode 100644 index 00000000..0c8ef65f --- /dev/null +++ b/src/Core/Messages/Traits/Messages/HandlesMutation.php @@ -0,0 +1,50 @@ +messages[] = match (true) { + is_array($message) => new Message($message['role'], $message['content']), + default => $message, + }; + return $this; + } + + public function appendMessages(array|Messages $messages) : static { + if ($messages instanceof Messages) { + $this->messages = array_merge($this->messages, $messages->messages); + } else { + foreach ($messages as $message) { + $this->messages[] = new Message($message['role'], $message['content']); + } + } + return $this; + } + + public function prependMessages(array|Messages $messages) : static { + if ($messages instanceof Messages) { + $this->messages = array_merge($messages->messages, $this->messages); + } else { + $prepended = []; + foreach ($messages as $message) { + $prepended[] = new Message($message['role'], $message['content']); // append to the list; plain assignment left a single Message object and array_merge() then failed + } + $this->messages = array_merge($prepended, $this->messages); + } + return $this; + } + + public function removeHead() : static { + array_shift($this->messages); + return $this; + } + + public function removeTail() : static { + array_pop($this->messages); + return $this; + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Messages/HandlesTransformation.php b/src/Core/Messages/Traits/Messages/HandlesTransformation.php new file mode 100644 index 00000000..46a9554c --- /dev/null +++ b/src/Core/Messages/Traits/Messages/HandlesTransformation.php @@ -0,0 +1,30 @@ + + */ + public function toArray() : array { + $result = []; + foreach ($this->messages as $message) { + if ($message->isEmpty()) { + continue; + } + $result[] = $message->toArray(); + } + return $result; + } + + 
public function toString(string $separator = "\n") : string { + $result = ''; + foreach ($this->messages as $message) { + if ($message->isEmpty()) { + continue; + } + $result .= $message->toString() . $separator; + } + return $result; + } +} diff --git a/src/Core/Messages/Traits/RendersTemplates.php b/src/Core/Messages/Traits/RendersTemplates.php new file mode 100644 index 00000000..a0dc3497 --- /dev/null +++ b/src/Core/Messages/Traits/RendersTemplates.php @@ -0,0 +1,46 @@ +|null $context + * @return string + */ + private function renderString(string $template, ?array $context) : string { + return match(true) { + is_null($context) => $template, + default => (new Template($context))->renderString($template), + }; + } + + /** + * @param array|Message $messages + * @param array|null $context + * @return string + */ + protected function renderMessage(array|Message $message, ?array $context) : array { + return match(true) { + is_null($context) => $message, + default => (new Template($context))->renderMessage($message), + }; + } + + /** + * @param array|Messages $messages + * @param array|null $context + * @return string + */ + protected function renderMessages(array|Messages $messages, ?array $context) : array { + return match(true) { + is_null($context) => $messages, + default => (new Template($context))->renderMessages($messages), + }; + } +} diff --git a/src/Core/Messages/Traits/Script/HandlesAccess.php b/src/Core/Messages/Traits/Script/HandlesAccess.php new file mode 100644 index 00000000..809af621 --- /dev/null +++ b/src/Core/Messages/Traits/Script/HandlesAccess.php @@ -0,0 +1,52 @@ +sectionIndex($name); + if ($index === -1) { + $this->addSection(new Section($name)); + $index = $this->sectionIndex($name); + } + return $this->sections[$index]; + } + + /** + * @param string|string[] $sections + */ + public function select(string|array $sections = []) : Script { + $names = match (true) { + empty($sections) => array_map(fn($section) => $section->name, $this->sections), 
+ is_string($sections) => [$sections], + is_array($sections) => $sections, + }; + $script = new Script(); + $script->context = $this->context; + foreach ($names as $sectionName) { + $script->append($this->section($sectionName)); + } + return $script; + } + + public function hasSection(string $name) : bool { + return $this->sectionIndex($name) !== -1; + } + + // INTERNAL //////////////////////////////////////////////////// + + private function sectionIndex(string $name) : int { + $index = -1; + foreach ($this->sections as $i => $section) { + if ($section->name === $name) { + $index = $i; + break; + } + } + return $index; + } +} diff --git a/src/Core/Messages/Traits/Script/HandlesContext.php b/src/Core/Messages/Traits/Script/HandlesContext.php new file mode 100644 index 00000000..718fc3fd --- /dev/null +++ b/src/Core/Messages/Traits/Script/HandlesContext.php @@ -0,0 +1,15 @@ +context; + } + + public function setContext(array $context) : void { + $this->context = $context; + } +} diff --git a/src/Core/Messages/Traits/Script/HandlesMutation.php b/src/Core/Messages/Traits/Script/HandlesMutation.php new file mode 100644 index 00000000..e607d278 --- /dev/null +++ b/src/Core/Messages/Traits/Script/HandlesMutation.php @@ -0,0 +1,51 @@ +append($section); + return $this; + } + + public function append(Section $section) : static { + $this->sections = $this->appendSections([$section]); + return $this; + } + + public function prepend(Section $section) : static { + $this->sections = $this->prependSections([$section]); + return $this; + } + + public function addBefore(string $name, Section $section) : static { + $this->sections = $this->insert($this->sections, $this->sectionIndex($name), [$section]); + return $this; + } + + public function addAfter(string $name, Section $section) : static { + $this->sections = $this->insert($this->sections, $this->sectionIndex($name) + 1, [$section]); + return $this; + } + + // INTERNAL //////////////////////////////////////////////////// + + 
private function insert(array $array, int $index, array $new) : array { + return array_merge( + array_slice($array, 0, $index), + $new, + array_slice($array, $index) + ); + } + + private function appendSections(array $array) : array { + return array_merge($this->sections, $array); + } + + private function prependSections(array $array) { + return array_merge($array, $this->sections); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Script/HandlesReordering.php b/src/Core/Messages/Traits/Script/HandlesReordering.php new file mode 100644 index 00000000..a8f15aa6 --- /dev/null +++ b/src/Core/Messages/Traits/Script/HandlesReordering.php @@ -0,0 +1,54 @@ +listInOrder($order); + + $script = new Script(); + $script->context = $this->context; + foreach ($sections as $section) { + $script->append($section); + } + return $script; + } + + public function reverse() : Script { + $script = new Script(); + $script->context = $this->context; + foreach ($this->listReverse() as $section) { + $script->append($section); + } + return $script; + } + + // INTERNAL //////////////////////////////////////////////////// + + /** @return Section[] */ + private function listAsIs() : array { + return $this->sections; + } + + /** @return Section[] */ + private function listReverse() : array { + return array_reverse($this->sections); + } + + /** @return Section[] */ + private function listInOrder(array $order) : array { + $ordered = []; + foreach ($order as $name) { + if (!$this->hasSection($name)) { + continue; + } + $section = $this->section($name); + $ordered[] = $section; + } + return $ordered; + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Script/HandlesTransformation.php b/src/Core/Messages/Traits/Script/HandlesTransformation.php new file mode 100644 index 00000000..e8117874 --- /dev/null +++ b/src/Core/Messages/Traits/Script/HandlesTransformation.php @@ -0,0 +1,69 @@ + $order + * @return Messages + */ + public function toMessages() : Messages 
{ + $messages = new Messages(); + foreach ($this->sections as $section) { + $messages->appendMessages($section->toMessages()); + } + return $messages; + } + + /** + * @param array|null $context extra template variables, merged over the script context (same keys override it) + * @param bool $raw when true, messages are returned without template rendering + * @return array + */ + public function toArray(array $context = null, bool $raw = false) : array { + $array = $this->toMessages()->toArray(); + return match($raw) { + false => $this->renderMessages($array, Arrays::mergeNull($this->context, $context)), + true => $array, + }; + } + + /** + * Renders the messages of all sections, then passes them to ChatFormat::mapToTargetAPI. + * @param ClientType $type + * @param array|null $context + * @return array + */ + public function toNativeArray(ClientType $type, array $context = null) : array { + $array = $this->renderMessages( + $this->toArray(raw: true), + Arrays::mergeNull($this->context, $context) + ); + return ChatFormat::mapToTargetAPI( + clientType: $type, + messages: $array, + ); + } + + /** + * Joins message contents, each followed by the separator, then renders template variables. + * @param string $separator + * @param array|null $context + * @return string + */ + public function toString(string $separator = "\n", array $context = null) : string { + $text = array_reduce( + $this->toArray(raw: true), + fn($carry, $message) => $carry . $message['content'] . 
$separator, + ); + if (empty($text)) { + return ''; + } + return $this->renderString($text, Arrays::mergeNull($this->context, $context)); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Section/HandlesAccess.php b/src/Core/Messages/Traits/Section/HandlesAccess.php new file mode 100644 index 00000000..b1404a3d --- /dev/null +++ b/src/Core/Messages/Traits/Section/HandlesAccess.php @@ -0,0 +1,23 @@ +name; + } + + public function firstRole() : MessageRole { + return $this->messages->firstRole(); + } + + public function lastRole() : MessageRole { + return $this->messages->lastRole(); + } + + public function isEmpty() : bool { + return $this->messages->isEmpty(); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Section/HandlesMutation.php b/src/Core/Messages/Traits/Section/HandlesMutation.php new file mode 100644 index 00000000..061a65fd --- /dev/null +++ b/src/Core/Messages/Traits/Section/HandlesMutation.php @@ -0,0 +1,21 @@ +messages->add($message); + } + + public function appendMessages(array|Messages $messages) : void { + $this->messages->appendMessages($messages); + } + + public function prependMessages(array|Messages $messages) : void { + $this->messages->prependMessages($messages); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Traits/Section/HandlesTransformation.php b/src/Core/Messages/Traits/Section/HandlesTransformation.php new file mode 100644 index 00000000..a81dd916 --- /dev/null +++ b/src/Core/Messages/Traits/Section/HandlesTransformation.php @@ -0,0 +1,56 @@ +messages; + } + + /** + * @param array|null $context + * @return array + */ + public function toArray(array $context = null) : array { + return $this->renderMessages( + $this->messages->toArray(), + $context + ); + } + + /** + * @param ClientType $clientType + * @param array|null $context + * @return array + */ + public function toNativeArray(ClientType $clientType, array $context = null) : array { + $array = $this->renderMessages( + 
$this->toArray(), + $context, + ); + return ChatFormat::mapToTargetAPI( + clientType: $clientType, + messages: $array, + ); + } + + /** + * @param array|null $context template variables used to render the joined text + * @param string $separator appended after the content of every message + * @return string joined and rendered contents; empty string when the section has no messages + */ + public function toString(array $context = [], string $separator = "\n") : string { + $text = array_reduce( + $this->messages->toArray(), + fn($carry, $message) => $carry . $message['content'] . $separator, + ); + return $this->renderString($text ?? '', $context); + } +} diff --git a/src/Core/Messages/Utils/ChatFormat.php b/src/Core/Messages/Utils/ChatFormat.php new file mode 100644 index 00000000..610255bd --- /dev/null +++ b/src/Core/Messages/Utils/ChatFormat.php @@ -0,0 +1,37 @@ +value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'assistant', 'tool' => 'user'], + ClientType::Cohere->value => ['user' => 'USER', 'assistant' => 'CHATBOT', 'system' => 'CHATBOT', 'tool' => 'USER'], + ClientType::Mistral->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 'tool' => 'tool'], + ClientType::OpenAI->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 'tool' => 'tool'], + ClientType::OpenAICompatible->value => ['user' => 'user', 'assistant' => 'assistant', 'system' => 'system', 'tool' => 'tool'], + ]; + + $keyMap = [ + ClientType::Anthropic->value => 'content', + ClientType::Cohere->value => 'message', + ClientType::Mistral->value => 'content', + ClientType::OpenAICompatible->value => 'content', + ClientType::OpenAI->value => 'content', + ]; + + $roles = $roleMap[$clientType->value]; + $key = $keyMap[$clientType->value]; + + return array_map(function($message) use ($roles, $key) { + return ['role' => $roles[$message['role']], $key => $message['content']]; + }, $messages); + } +} \ No newline at end of file diff --git a/src/Core/Messages/Utils/MessageBuilder.php b/src/Core/Messages/Utils/MessageBuilder.php new file mode 100644 index 00000000..99f02cfc --- /dev/null +++ 
b/src/Core/Messages/Utils/MessageBuilder.php @@ -0,0 +1,58 @@ +clientClass = $clientClass; + $instance->mode = $mode; + $instance->messages = $messages; + $instance->responseModel = $responseModel; + $instance->dataAcknowledgedPrompt = $dataAcknowledgedPrompt; + $instance->prompt = $prompt; + $instance->examples = $examples; + return $instance->makeExtractionRequest(); + } + + // INTERNAL TOOLS /////////////////////////////////////////////////////////////////// + + private function makeExtractionRequest() : array { + // get body creation method based on client + $builder = $this->getBuilder($this->clientClass); + // get the parts of body specific to the client + $script = $this->makeMessages(); + $script->setContext(['json_schema' => $this->responseModel->toJsonSchema()]); + $body = $builder($script); + + // filter out empty values + $body = array_filter($body); + return $body; + } +} diff --git a/src/Core/Messages/Utils/MessageBuilder/HandlesApiProviders.php b/src/Core/Messages/Utils/MessageBuilder/HandlesApiProviders.php new file mode 100644 index 00000000..90941713 --- /dev/null +++ b/src/Core/Messages/Utils/MessageBuilder/HandlesApiProviders.php @@ -0,0 +1,78 @@ + fn($script) => $this->anthropic($script), + CohereClient::class => fn($script) => $this->cohere($script), + MistralClient::class => fn($script) => $this->mistral($script), + OpenAIClient::class, + AzureClient::class => fn($script) => $this->openAI($script), + AnyscaleClient::class, + FireworksAIClient::class, + GroqClient::class, + OllamaClient::class, + OpenRouterClient::class, + TogetherAIClient::class => fn($script) => $this->openAILike($script), + default => fn() => [], + }; + } + + private function anthropic(Script $script) : array { + return [ + 'system' => $script->select('system')->toString(), + 'messages' => $script + ->select(['messages', 'data_ack', 'command', 'examples']) + ->toNativeArray(ClientType::Anthropic), + ]; + } + + private function cohere(Script $script) : array { + return 
array_filter([ + 'preamble' => $script->select('system')->toString(), + 'chat_history' => $script->select('messages')->toNativeArray(ClientType::Cohere), + 'message' => $script->select(['command', 'examples'])->toString(), + ]); + } + + private function mistral(Script $script) : array { + return [ + 'messages' => $script + ->select(['system', 'command', 'examples', 'messages']) + ->toNativeArray(ClientType::Mistral), + ]; + } + + private function openAI(Script $script) : array { + return [ + 'messages' => $script + ->select(['system', 'messages', 'data_ack', 'command', 'examples']) + ->toNativeArray(ClientType::OpenAI), + ]; + } + + private function openAILike(Script $script) : array { + return [ + 'messages' => $script + ->select(['system', 'messages', 'data_ack', 'command', 'examples']) + ->toNativeArray(ClientType::OpenAICompatible), + ]; + } +} diff --git a/src/Core/Messages/Utils/MessageBuilder/MakesMessages.php b/src/Core/Messages/Utils/MessageBuilder/MakesMessages.php new file mode 100644 index 00000000..678aeebc --- /dev/null +++ b/src/Core/Messages/Utils/MessageBuilder/MakesMessages.php @@ -0,0 +1,65 @@ +messages)) { + throw new Exception('Messages cannot be empty - you have to provide the content for processing.'); + } + + $script = new Script(); + $script->addSection(new Section('system', 'System messages')); + $script->addSection(new Section('messages', 'Chat messages')); + $script->addSection(new Section('command', 'Command prompt')); + $script->addSection(new Section('data_ack', 'Data acknowledged prompt')); + + // SYSTEM SECTION + $index = 0; + foreach ($this->messages as $message) { + if ($message['role'] !== 'system') { + break; + } + $script->section('system')->add(['role' => 'system', 'content' => $message['content']]); + $index++; + } + + // DATA ACK SECTION + $script->section('data_ack')->add([ + 'role' => 'assistant', + 'content' => $this->dataAcknowledgedPrompt + ]); + + // MESSAGES SECTION + 
$script->section('messages')->appendMessages(array_slice($this->messages, $index)); + + // PROMPT SECTION + if (!empty($this->prompt)) { + $script->section('command')->add([ + 'role' => 'user', + 'content' => $this->prompt + ]); + } + + // EXAMPLES SECTION + if (!empty($this->examples)) { + $script->addSection(new Section('examples', 'Inference examples')); + foreach ($this->examples as $item) { + $example = match(true) { + is_array($item) => Example::fromArray($item), + is_string($item) => Example::fromJson($item), + $item instanceof Example => $item, + }; + $script->section('examples')->appendMessages($example->toMessages()); + } + } + + return $script; + } +} diff --git a/src/Data/Example.php b/src/Data/Example.php index f661536d..a5302f07 100644 --- a/src/Data/Example.php +++ b/src/Data/Example.php @@ -55,14 +55,21 @@ static public function fromData(mixed $data, array|object $output) : self { static public function fromJson(string $json) : self { $data = Json::parse($json); + if (!isset($data['input']) || !isset($data['output'])) { + throw new Exception("Invalid JSON data for example - missing `input` or `output` fields"); + } + return self::fromArray($data); + } + + static public function fromArray(array $data) : self { if (!isset($data['input']) || !isset($data['output'])) { throw new Exception("Invalid JSON data for example - missing `input` or `output` fields"); } return new self( input: $data['input'], output: $data['output'], - uid: $data['id'] ?? null, - createdAt: new DateTimeImmutable($data['created_at']) ?? null + uid: $data['id'] ?? Uuid::uuid4(), + createdAt: isset($data['created_at']) ? 
new DateTimeImmutable($data['created_at']) : new DateTimeImmutable(), ); } @@ -85,6 +92,13 @@ public function toString() : string { ]); } + public function toMessages() : array { + return [ + ['role' => 'user', 'content' => $this->input()], + ['role' => 'assistant', 'content' => $this->outputString()], + ]; + } + public function toJson() : string { return Json::encode([ 'id' => $this->uid, diff --git a/src/Data/Request.php b/src/Data/Request.php index df9d7ed9..1da8e847 100644 --- a/src/Data/Request.php +++ b/src/Data/Request.php @@ -75,23 +75,4 @@ public function mode() : Mode { public function copy(array $messages) : self { return (clone $this)->withMessages($messages); } - - public function toCompilerInput() : CompilerInput { - return new Input( - mode: $this->mode, - model: $this->model, - messages: $this->messages, - responseModel: $this->responseModel, - options: $this->options, - examples: $this->examples, - feedback: $this->feedback, - tools: $this->tools, - toolChoice: $this->toolChoice, - prompt: $this->prompt, - retryPrompt: $this->retryPrompt, - signature: $this->signature, - inputSchema: $this->inputSchema, - inputData: $this->inputData, - ); - } } diff --git a/src/Data/Traits/Request/HandlesApiRequestFactory.php b/src/Data/Traits/Request/HandlesApiRequestFactory.php index b571359c..f4b52dc0 100644 --- a/src/Data/Traits/Request/HandlesApiRequestFactory.php +++ b/src/Data/Traits/Request/HandlesApiRequestFactory.php @@ -26,8 +26,7 @@ public function toApiRequest() : ApiRequest { if (empty($this->option('max_tokens'))) { $this->setOption('max_tokens', $this->client->defaultMaxTokens); } - $request = $this->fromRequest(); - return $request; + return $this->fromRequest(); } // INTERNAL //////////////////////////////////////////////////////////////////////////////////////////// diff --git a/src/Data/Traits/Request/HandlesPrompts.php b/src/Data/Traits/Request/HandlesPrompts.php index 26770657..637161ad 100644 --- a/src/Data/Traits/Request/HandlesPrompts.php 
+++ b/src/Data/Traits/Request/HandlesPrompts.php @@ -2,16 +2,15 @@ namespace Cognesy\Instructor\Data\Traits\Request; -use Cognesy\Instructor\Core\MessageBuilder; +use Cognesy\Instructor\Core\Messages\Utils\MessageBuilder; use Cognesy\Instructor\Enums\Mode; -use Cognesy\Instructor\Utils\Template; use Exception; trait HandlesPrompts { private array $defaultPrompts = [ - Mode::MdJson->value => "\nRespond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n{json_schema}\n", - Mode::Json->value => "\nRespond correctly with strict JSON object. Response must follow JSONSchema:\n{json_schema}\n", + Mode::MdJson->value => "\nRespond correctly with strict JSON object containing extracted data within a ```json {} ``` codeblock. Object must validate against this JSONSchema:\n<|json_schema|>\n", + Mode::Json->value => "\nRespond correctly with strict JSON object. Response must follow JSONSchema:\n<|json_schema|>\n", Mode::Tools->value => "\nExtract correct and accurate data from the input using provided tools. 
Response must be JSON object following provided tool schema.\n", ]; private string $dataAcknowledgedPrompt = "Input acknowledged."; @@ -29,7 +28,7 @@ public function withPrompt(string $prompt) : self { // INTERNAL //////////////////////////////////////////////////////////////////////////////////////////// protected function makeOptions() : array { - if (empty($this->messages)) { + if (empty($this->messages())) { throw new Exception('Messages cannot be empty - you have to provide the content for processing.'); } @@ -43,10 +42,9 @@ protected function makeOptions() : array { messages: $this->messages(), responseModel: $this->responseModel(), dataAcknowledgedPrompt: $this->dataAcknowledgedPrompt, - prompt: Template::render($this->prompt(), ['json_schema' => $this->jsonSchema()]), + prompt: $this->prompt(), examples: $this->examples(), ); - return array_merge( $this->options, $body, diff --git a/src/Extras/Module/Addons/Predict/Predict.php b/src/Extras/Module/Addons/Predict/Predict.php index df769178..b26a2040 100644 --- a/src/Extras/Module/Addons/Predict/Predict.php +++ b/src/Extras/Module/Addons/Predict/Predict.php @@ -19,14 +19,15 @@ class Predict extends DynamicModule { protected Instructor $instructor; - protected string $prompt; + protected string $predictionPrompt; + protected string $extractionPrompt; protected $options = []; protected $model = 'gpt-4o'; protected $mode = Mode::Tools; protected array $examples = []; protected int $maxRetries = 3; - protected string $defaultPrompt = 'Your job is to infer output argument values in input data based on specification: {signature} {description}'; + protected string $defaultPredictionPrompt = 'Your job is to infer output argument values in input data based on specification: {signature} {description}'; protected string|Signature|HasSignature $defaultSignature; @@ -35,12 +36,13 @@ class Predict extends DynamicModule public function __construct( string|Signature|HasSignature $signature, Instructor $instructor, - string 
$model = 'gpt-4o', + string $model = '', int $maxRetries = 3, array $options = [], array $examples = [], - string $prompt = '', - Mode $mode = Mode::Tools, + string $predictionPrompt = '', + string $extractionPrompt = '', + Mode $mode = Mode::Json, ) { if ($signature instanceof HasSignature) { $this->signatureCarrier = $signature; @@ -54,7 +56,8 @@ public function __construct( $this->maxRetries = $maxRetries; $this->options = $options; $this->examples = $examples; - $this->prompt = $prompt; + $this->predictionPrompt = $predictionPrompt; + $this->extractionPrompt = $extractionPrompt; $this->mode = $mode; } @@ -98,18 +101,18 @@ public function forward(array $args, object $targetObject): mixed { maxRetries: $this->maxRetries, options: $this->options, examples: $this->examples, - prompt: $this->prompt(), + prompt: $this->extractionPrompt, mode: $this->mode, ); return $response; } - public function prompt() : string { + protected function predictionPrompt() : string { - if (empty($this->prompt)) { + if (empty($this->predictionPrompt)) { - $this->prompt = $this->renderPrompt($this->defaultPrompt); + $this->predictionPrompt = $this->renderPrompt($this->defaultPredictionPrompt); } - return $this->prompt; + return $this->predictionPrompt; } // INTERNAL //////////////////////////////////////////////////////////////////////////////////// @@ -123,7 +126,7 @@ private function toMessages(string|array|object $input) : array { default => json_encode($input), // wrap in json }; return [ - ['role' => 'user', 'content' => $this->prompt()], + ['role' => 'user', 'content' => $this->predictionPrompt()], ['role' => 'assistant', 'content' => 'Provide input data.'], ['role' => 'user', 'content' => $content] ]; diff --git a/src/Utils/Arrays.php b/src/Utils/Arrays.php index 958602ac..c2f711ac 100644 --- a/src/Utils/Arrays.php +++ b/src/Utils/Arrays.php @@ -6,6 +6,15 @@ class Arrays { + public static function mergeNull(?array $array1, ?array $array2): ?array { + return match(true) { + is_null($array1) && is_null($array2) => null, + 
is_null($array1) => $array2, + is_null($array2) => $array1, + default => array_merge($array1, $array2), + }; + } + public static function unset(array $array, array|string $fields) : array { if (!is_array($fields)) { $fields = [$fields]; diff --git a/src/Utils/Template.php b/src/Utils/Template.php index 7f977e85..3b9e0310 100644 --- a/src/Utils/Template.php +++ b/src/Utils/Template.php @@ -2,18 +2,99 @@ namespace Cognesy\Instructor\Utils; +use Cognesy\Instructor\Core\Messages\Message; +use Cognesy\Instructor\Core\Messages\Messages; +use InvalidArgumentException; + class Template { - public static function render(string $template, array $data): string { - $keys = array_map(fn($key) => '{'.$key.'}', array_keys($data)); - $normalized = []; - foreach ($data as $key => $value) { - $normalized[$key] = match (true) { - is_array($value) || is_object($value) => Json::encode($value), + private array $values = []; + private array $keys = []; + + public function __construct( + array $context = [] + ) { + if (empty($context)) { + return; + } + $materializedContext = $this->materializeContext($context); + $this->values = array_values($materializedContext); + $this->keys = array_map( + fn($key) => $this->varPattern($key), + array_keys($materializedContext) + ); + } + + public static function render(string $template, array $context) : string { + return (new Template($context))->renderString($template); + } + + public function renderString(string $template): string { + // find all keys in the template + $keys = $this->findVars($template); + // find keys missing from $this->keys + $missingKeys = array_diff($keys, $this->keys); + // remove missing key strings from the template + $template = str_replace($missingKeys, '', $template); + return str_replace($this->keys, $this->values, $template); + } + + public function renderArray(array $rows, string $field = 'content'): array { + return array_map( + fn($item) => $this->renderString($item[$field] ?? 
''), + $rows + ); + } + + public function renderMessage(array|Message $message) : array { + $rendered = match(true) { + is_array($message) => ['role' => $message['role'], 'content' => $this->renderString($message['content'])], + $message instanceof Message => ['role' => $message->role, 'content' => $this->renderString($message->content)], + default => throw new InvalidArgumentException('Invalid message type'), + }; + return $rendered; + } + + public function renderMessages(array|Messages $messages) : array { + return array_map( + fn($message) => $this->renderMessage($message), + is_array($messages) ? $messages : $messages->toArray() + ); + } + + // OVERRIDEABLE ////////////////////////////////////////////////////////////// + + protected function varPattern(string $key) : string { + return '<|' . $key . '|>'; + } + + protected function findVars(string $template) : array { + $matches = []; + // collect every <|name|> placeholder occurring in the template (full match, delimiters included) + preg_match_all('/<\|([^|]+)\|>/', $template, $matches); + return $matches[0]; + } + + // INTERNAL ////////////////////////////////////////////////////////////////// + + private function materializeContext(array $context) : array { + $contextValues = []; + foreach ($context as $key => $value) { + $value = match (true) { + is_scalar($value) => $value, + is_array($value) => Json::encode($value), + is_callable($value) => $value($key, $context), + is_object($value) && method_exists($value, 'toString') => $value->toString(), + is_object($value) && method_exists($value, 'toJson') => $value->toJson(), + is_object($value) && method_exists($value, 'toArray') => Json::encode($value->toArray()), + is_object($value) && method_exists($value, 'toSchema') => Json::encode($value->toSchema()), + is_object($value) && method_exists($value, 'toOutputSchema') => Json::encode($value->toOutputSchema()), + is_object($value) && property_exists($value, 'value') => $value->value, + is_object($value) => Json::encode($value), default => $value, }; + 
$contextValues[$key] = $value; } - $values = array_values($normalized); - return trim(str_replace($keys, $values, $template)); + return $contextValues; } } diff --git a/tests/Feature/Script/ScriptTest.php b/tests/Feature/Script/ScriptTest.php new file mode 100644 index 00000000..ce7f590e --- /dev/null +++ b/tests/Feature/Script/ScriptTest.php @@ -0,0 +1,174 @@ +setContext([ + 'key-1' => 'value-1', + 'key-2' => 'value-2', + ]); + + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); + $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); + + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); + $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); + + $messages = $script->toArray(); + + expect(count($messages))->toBe(6); + expect($messages[0]['role'])->toBe('user'); + expect($messages[0]['content'])->toBe('content-1'); + expect($messages[1]['role'])->toBe('assistant'); + expect($messages[1]['content'])->toBe('content-2'); + expect($messages[2]['role'])->toBe('user'); + expect($messages[2]['content'])->toBe('content-3 value-1'); + expect($messages[3]['role'])->toBe('user'); + expect($messages[3]['content'])->toBe('content-4'); + expect($messages[4]['role'])->toBe('assistant'); + expect($messages[4]['content'])->toBe('content-5'); + expect($messages[5]['role'])->toBe('user'); + expect($messages[5]['content'])->toBe('content-6 value-2'); +}); + + +it('selects sections from script', function () { + $script = new Script( + new Section('section-1'), + new Section('section-2'), + ); + $script->setContext([ + 'key-1' => 'value-1', + 'key-2' => 'value-2', + ]); + + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); + 
$script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); + + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); + $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6']); + + $script->section('section-3')->add(['role' => 'user', 'content' => 'content-7']); + $script->section('section-3')->add(['role' => 'assistant', 'content' => 'content-8']); + $script->section('section-3')->add(['role' => 'user', 'content' => 'content-9 <|key-2|>']); + + $messages = $script->select(['section-3', 'section-1'])->toArray(); + + expect(count($messages))->toBe(6); + expect($messages[0]['role'])->toBe('user'); + expect($messages[0]['content'])->toBe('content-7'); + expect($messages[1]['role'])->toBe('assistant'); + expect($messages[1]['content'])->toBe('content-8'); + expect($messages[2]['role'])->toBe('user'); + expect($messages[2]['content'])->toBe('content-9 value-2'); + expect($messages[3]['role'])->toBe('user'); + expect($messages[3]['content'])->toBe('content-1'); + expect($messages[4]['role'])->toBe('assistant'); + expect($messages[4]['content'])->toBe('content-2'); + expect($messages[5]['role'])->toBe('user'); + expect($messages[5]['content'])->toBe('content-3 value-1'); +}); + + +it('translates messages to native format - Cohere', function () { + $script = new Script( + new Section('section-1'), + new Section('section-2'), + ); + $script->setContext([ + 'key-1' => 'value-1', + 'key-2' => 'value-2', + ]); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); + $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); + + $script->section('section-2')->add(['role' => 
'user', 'content' => 'content-4']); + $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); + + $messages = $script->select(['section-2', 'section-1'])->toNativeArray(ClientType::Cohere); + + expect(count($messages))->toBe(6); + expect($messages[0]['role'])->toBe('USER'); + expect($messages[0]['message'])->toBe('content-4'); + expect($messages[1]['role'])->toBe('CHATBOT'); + expect($messages[1]['message'])->toBe('content-5'); + expect($messages[2]['role'])->toBe('USER'); + expect($messages[2]['message'])->toBe('content-6 value-2'); + expect($messages[3]['role'])->toBe('USER'); + expect($messages[3]['message'])->toBe('content-1'); + expect($messages[4]['role'])->toBe('CHATBOT'); + expect($messages[4]['message'])->toBe('content-2'); + expect($messages[5]['role'])->toBe('USER'); + expect($messages[5]['message'])->toBe('content-3 value-1'); +}); + +it('translates messages to native format - Anthropic', function () { + $script = new Script( + new Section('section-1'), + new Section('section-2'), + ); + $script->setContext([ + 'key-1' => 'value-1', + 'key-2' => 'value-2', + ]); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); + $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); + + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); + $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); + + $messages = $script->select(['section-2', 'section-1'])->toNativeArray(ClientType::Anthropic); + + expect(count($messages))->toBe(6); + expect($messages[0]['role'])->toBe('user'); + 
expect($messages[0]['content'])->toBe('content-4'); + expect($messages[1]['role'])->toBe('assistant'); + expect($messages[1]['content'])->toBe('content-5'); + expect($messages[2]['role'])->toBe('user'); + expect($messages[2]['content'])->toBe('content-6 value-2'); + expect($messages[3]['role'])->toBe('user'); + expect($messages[3]['content'])->toBe('content-1'); + expect($messages[4]['role'])->toBe('assistant'); + expect($messages[4]['content'])->toBe('content-2'); + expect($messages[5]['role'])->toBe('user'); + expect($messages[5]['content'])->toBe('content-3 value-1'); +}); + +it('translates messages to string', function () { + $script = new Script( + new Section('section-1'), + new Section('section-2'), + ); + $script->setContext([ + 'key-1' => 'value-1', + 'key-2' => 'value-2', + ]); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-1']); + $script->section('section-1')->add(['role' => 'assistant', 'content' => 'content-2']); + $script->section('section-1')->add(['role' => 'user', 'content' => 'content-3 <|key-1|>']); + + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-4']); + $script->section('section-2')->add(['role' => 'assistant', 'content' => 'content-5']); + $script->section('section-2')->add(['role' => 'user', 'content' => 'content-6 <|key-2|>']); + + $text = $script->select(['section-2', 'section-1'])->toString(); + expect($text)->toBe("content-4\ncontent-5\ncontent-6 value-2\ncontent-1\ncontent-2\ncontent-3 value-1\n"); + + $text = $script->select('section-1')->toString(); + expect($text)->toBe("content-1\ncontent-2\ncontent-3 value-1\n"); +}); \ No newline at end of file