From 7ea47ed00e3d03d576a8d1cb89dd13aed113264d Mon Sep 17 00:00:00 2001 From: ddebowczyk Date: Wed, 2 Oct 2024 18:46:12 +0200 Subject: [PATCH] Added unified context caching to direct LLM inference --- .../examples/advanced/context_cache.mdx | 10 ++- .../examples/advanced/context_cache_llm.mdx | 76 +++++++++++++++++++ docs/mint.json | 1 + examples/A02_Advanced/ContextCacheLLM/run.php | 76 +++++++++++++++++++ examples/A02_Advanced/ContextCaching/run.php | 10 ++- src/Extras/LLM/Data/CachedContext.php | 36 +++++++++ src/Extras/LLM/Drivers/AnthropicDriver.php | 36 +++++++++ src/Extras/LLM/Drivers/CohereV1Driver.php | 13 ++++ src/Extras/LLM/Drivers/CohereV2Driver.php | 2 - src/Extras/LLM/Drivers/GeminiDriver.php | 13 ++++ src/Extras/LLM/Drivers/OpenAIDriver.php | 18 ++++- src/Extras/LLM/Inference.php | 16 +++- src/Extras/LLM/InferenceRequest.php | 8 +- 13 files changed, 301 insertions(+), 14 deletions(-) create mode 100644 docs/cookbook/examples/advanced/context_cache_llm.mdx create mode 100644 examples/A02_Advanced/ContextCacheLLM/run.php create mode 100644 src/Extras/LLM/Data/CachedContext.php diff --git a/docs/cookbook/examples/advanced/context_cache.mdx b/docs/cookbook/examples/advanced/context_cache.mdx index 6b24c48c..daf63567 100644 --- a/docs/cookbook/examples/advanced/context_cache.mdx +++ b/docs/cookbook/examples/advanced/context_cache.mdx @@ -9,7 +9,11 @@ Instructor offers a simplified way to work with LLM providers' APIs supporting c (currently only Anthropic API), so you can focus on your business logic while still being able to take advantage of lower latency and costs. -> **Note:** Context caching is only available for Anthropic API. +> **Note 1:** Instructor supports context caching for Anthropic API and OpenAI API. + +> **Note 2:** Context caching is automatic for all OpenAI API calls. Read more +> in the [OpenAI API documentation](https://platform.openai.com/docs/guides/prompt-caching). + ## Example @@ -46,7 +50,7 @@ class Project { public array $applications; #[Description('Explain the purpose of the project and the domain specific problems it solves')] public string $description; - #[Description('Example code as Markdown fragment, demonstrating domain specific application of the library')] + #[Description('Example code in Markdown demonstrating domain specific application of the library')] public string $code; } ?> @@ -93,7 +97,7 @@ which results in faster processing and lower costs. ```php respond( - messages: "Describe the project in a way compelling to my audience: lead gen software vendor.", + messages: "Describe the project in a way compelling to my audience: boutique CMS consulting company owner.", responseModel: Project::class, options: ['max_tokens' => 4096], mode: Mode::Json, diff --git a/docs/cookbook/examples/advanced/context_cache_llm.mdx b/docs/cookbook/examples/advanced/context_cache_llm.mdx new file mode 100644 index 00000000..6f66d8ad --- /dev/null +++ b/docs/cookbook/examples/advanced/context_cache_llm.mdx @@ -0,0 +1,76 @@ +--- +title: 'Context caching' +docname: 'context_cache_llm' +--- + +## Overview + +Instructor offers a simplified way to work with LLM providers' APIs supporting caching +(currently only Anthropic API), so you can focus on your business logic while still being +able to take advantage of lower latency and costs. + +> **Note 1:** Instructor supports context caching for Anthropic API and OpenAI API. + +> **Note 2:** Context caching is automatic for all OpenAI API calls. 
Read more +> in the [OpenAI API documentation](https://platform.openai.com/docs/guides/prompt-caching). + +## Example + +When you need to process multiple requests with the same context, you can use context +caching to improve performance and reduce costs. + +In our example we will be analyzing the README.md file of this Github project and +generating its summary for 2 target audiences. + + +```php +add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); + +use Cognesy\Instructor\Extras\LLM\Inference; +use Cognesy\Instructor\Utils\Str; + +$content = file_get_contents(__DIR__ . '/../../../README.md'); + +$inference = (new Inference)->withConnection('anthropic')->withCachedContext( + messages: [ + ['role' => 'user', 'content' => 'Here is content of README.md file'], + ['role' => 'user', 'content' => $content], + ['role' => 'user', 'content' => 'Generate short, very domain specific pitch of the project described in README.md'], + ['role' => 'assistant', 'content' => 'For whom do you want to generate the pitch?'], + ], +); + +$response = $inference->create( + messages: [['role' => 'user', 'content' => 'CTO of lead gen software vendor']], + options: ['max_tokens' => 256], +)->toApiResponse(); + +print("----------------------------------------\n"); +print("\n# Summary for CTO of lead gen vendor\n"); +print(" ($response->cacheReadTokens tokens read from cache)\n\n"); +print("----------------------------------------\n"); +print($response->content . "\n"); + +assert(!empty($response->content)); +assert(Str::contains($response->content, 'Instructor')); +assert(Str::contains($response->content, 'lead', false)); + +$response2 = $inference->create( + messages: [['role' => 'user', 'content' => 'CIO of insurance company']], + options: ['max_tokens' => 256], +)->toApiResponse(); + +print("----------------------------------------\n"); +print("\n# Summary for CIO of insurance company\n"); +print(" ($response2->cacheReadTokens tokens read from cache)\n\n"); +print("----------------------------------------\n"); +print($response2->content . "\n"); + +assert(!empty($response2->content)); +assert(Str::contains($response2->content, 'Instructor')); +assert(Str::contains($response2->content, 'insurance', false)); +//assert($response2->cacheReadTokens > 0); +?> +``` diff --git a/docs/mint.json b/docs/mint.json index fba011ec..8604ab42 100644 --- a/docs/mint.json +++ b/docs/mint.json @@ -145,6 +145,7 @@ { "group": "Advanced", "pages": [ + "cookbook/examples/advanced/context_cache_llm", "cookbook/examples/advanced/context_cache", "cookbook/examples/advanced/custom_client", "cookbook/examples/advanced/custom_prompts", diff --git a/examples/A02_Advanced/ContextCacheLLM/run.php b/examples/A02_Advanced/ContextCacheLLM/run.php new file mode 100644 index 00000000..6f66d8ad --- /dev/null +++ b/examples/A02_Advanced/ContextCacheLLM/run.php @@ -0,0 +1,76 @@ +--- +title: 'Context caching' +docname: 'context_cache_llm' +--- + +## Overview + +Instructor offers a simplified way to work with LLM providers' APIs supporting caching +(currently only Anthropic API), so you can focus on your business logic while still being +able to take advantage of lower latency and costs. + +> **Note 1:** Instructor supports context caching for Anthropic API and OpenAI API. + +> **Note 2:** Context caching is automatic for all OpenAI API calls. Read more +> in the [OpenAI API documentation](https://platform.openai.com/docs/guides/prompt-caching). 
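With OpenAI the same call pattern works without any explicit cache markers - the provider caches long, repeated prompt prefixes automatically and reports reuse via `cacheReadTokens` (mapped from `usage.prompt_tokens_details.cached_tokens`). A minimal sketch, assuming an `openai` connection is configured (the connection name here is illustrative):

```php
<?php
use Cognesy\Instructor\Extras\LLM\Inference;

$content = file_get_contents(__DIR__ . '/../../../README.md');

// Reuse the same long context across calls; OpenAI caches the shared prefix automatically.
$inference = (new Inference)->withConnection('openai')->withCachedContext(
    messages: [
        ['role' => 'user', 'content' => 'Here is content of README.md file'],
        ['role' => 'user', 'content' => $content],
    ],
);

$response = $inference->create(
    messages: [['role' => 'user', 'content' => 'Summarize the project in 3 sentences.']],
    options: ['max_tokens' => 256],
)->toApiResponse();

// Cached prompt tokens (if any) are surfaced on the response object.
echo $response->content . "\n";
echo "Cached tokens read: {$response->cacheReadTokens}\n";
?>
```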
+ +## Example + +When you need to process multiple requests with the same context, you can use context +caching to improve performance and reduce costs. + +In our example we will be analyzing the README.md file of this Github project and +generating its summary for 2 target audiences. + + +```php +add('Cognesy\\Instructor\\', __DIR__ . '../../src/'); + +use Cognesy\Instructor\Extras\LLM\Inference; +use Cognesy\Instructor\Utils\Str; + +$content = file_get_contents(__DIR__ . '/../../../README.md'); + +$inference = (new Inference)->withConnection('anthropic')->withCachedContext( + messages: [ + ['role' => 'user', 'content' => 'Here is content of README.md file'], + ['role' => 'user', 'content' => $content], + ['role' => 'user', 'content' => 'Generate short, very domain specific pitch of the project described in README.md'], + ['role' => 'assistant', 'content' => 'For whom do you want to generate the pitch?'], + ], +); + +$response = $inference->create( + messages: [['role' => 'user', 'content' => 'CTO of lead gen software vendor']], + options: ['max_tokens' => 256], +)->toApiResponse(); + +print("----------------------------------------\n"); +print("\n# Summary for CTO of lead gen vendor\n"); +print(" ($response->cacheReadTokens tokens read from cache)\n\n"); +print("----------------------------------------\n"); +print($response->content . "\n"); + +assert(!empty($response->content)); +assert(Str::contains($response->content, 'Instructor')); +assert(Str::contains($response->content, 'lead', false)); + +$response2 = $inference->create( + messages: [['role' => 'user', 'content' => 'CIO of insurance company']], + options: ['max_tokens' => 256], +)->toApiResponse(); + +print("----------------------------------------\n"); +print("\n# Summary for CIO of insurance company\n"); +print(" ($response2->cacheReadTokens tokens read from cache)\n\n"); +print("----------------------------------------\n"); +print($response2->content . "\n"); + +assert(!empty($response2->content)); +assert(Str::contains($response2->content, 'Instructor')); +assert(Str::contains($response2->content, 'insurance', false)); +//assert($response2->cacheReadTokens > 0); +?> +``` diff --git a/examples/A02_Advanced/ContextCaching/run.php b/examples/A02_Advanced/ContextCaching/run.php index 6b24c48c..daf63567 100644 --- a/examples/A02_Advanced/ContextCaching/run.php +++ b/examples/A02_Advanced/ContextCaching/run.php @@ -9,7 +9,11 @@ (currently only Anthropic API), so you can focus on your business logic while still being able to take advantage of lower latency and costs. -> **Note:** Context caching is only available for Anthropic API. +> **Note 1:** Instructor supports context caching for Anthropic API and OpenAI API. + +> **Note 2:** Context caching is automatic for all OpenAI API calls. Read more +> in the [OpenAI API documentation](https://platform.openai.com/docs/guides/prompt-caching). 
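Under the hood, the Anthropic driver marks the last message of the cached context with an ephemeral `cache_control` flag, so the provider knows where the cacheable prefix ends. A simplified sketch of the resulting message shape (the text value is illustrative):

```php
// Illustrative shape of the last cached message after the driver adds the cache marker:
$lastCachedMessage = [
    'role' => 'user',
    'content' => [[
        'type' => 'text',
        'text' => '...content of README.md...',
        'cache_control' => ['type' => 'ephemeral'],
    ]],
];
```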
+ ## Example @@ -46,7 +50,7 @@ class Project { public array $applications; #[Description('Explain the purpose of the project and the domain specific problems it solves')] public string $description; - #[Description('Example code as Markdown fragment, demonstrating domain specific application of the library')] + #[Description('Example code in Markdown demonstrating domain specific application of the library')] public string $code; } ?> @@ -93,7 +97,7 @@ class Project { ```php respond( - messages: "Describe the project in a way compelling to my audience: lead gen software vendor.", + messages: "Describe the project in a way compelling to my audience: boutique CMS consulting company owner.", responseModel: Project::class, options: ['max_tokens' => 4096], mode: Mode::Json, diff --git a/src/Extras/LLM/Data/CachedContext.php b/src/Extras/LLM/Data/CachedContext.php new file mode 100644 index 00000000..ff999e74 --- /dev/null +++ b/src/Extras/LLM/Data/CachedContext.php @@ -0,0 +1,36 @@ +messages = ['role' => 'user', 'content' => $messages]; + } + } + + public function merged( + string|array $messages = [], + array $tools = [], + string|array $toolChoice = [], + array $responseFormat = [], + ) { + if (is_string($messages) && !empty($messages)) { + $messages = ['role' => 'user', 'content' => $messages]; + } + return new CachedContext( + array_merge($this->messages, $messages), + empty($tools) ? $this->tools : $tools, + empty($toolChoice) ? $this->toolChoice : $toolChoice, + empty($responseFormat) ? $this->responseFormat : $responseFormat, + ); + } +} diff --git a/src/Extras/LLM/Drivers/AnthropicDriver.php b/src/Extras/LLM/Drivers/AnthropicDriver.php index 9bfed543..851174c5 100644 --- a/src/Extras/LLM/Drivers/AnthropicDriver.php +++ b/src/Extras/LLM/Drivers/AnthropicDriver.php @@ -28,6 +28,7 @@ public function __construct( // REQUEST ////////////////////////////////////////////// public function handle(InferenceRequest $request) : ResponseInterface { + $request = $this->withCachedContext($request); return $this->httpClient->handle( url: $this->getEndpointUrl($request), headers: $this->getRequestHeaders(), @@ -245,4 +246,39 @@ private function makeContent(array $data) : string { private function makeDelta(array $data) : string { return $data['delta']['text'] ?? $data['delta']['partial_json'] ?? ''; } + + private function withCachedContext(InferenceRequest $request): InferenceRequest { + if (!isset($request->cachedContext)) { + return $request; + } + + $cloned = clone $request; + + $cloned->messages = empty($request->cachedContext->messages) + ? $request->messages + : array_merge($this->setCacheMarker($request->cachedContext->messages), $request->messages); + $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools; + $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice; + $cloned->responseFormat = empty($request->responseFormat) ? $request->cachedContext->responseFormat : $request->responseFormat; + return $cloned; + } + + private function setCacheMarker(array $messages): array { + $lastIndex = count($messages) - 1; + $lastMessage = $messages[$lastIndex]; + + if (is_array($lastMessage['content'])) { + $subIndex = count($lastMessage['content']) - 1; + $lastMessage['content'][$subIndex]['cache_control'] = ["type" => "ephemeral"]; + } else { + $lastMessage['content'] = [[ + 'type' => $lastMessage['type'] ?? 'text', + 'text' => $lastMessage['content'] ?? 
'', + 'cache_control' => ["type" => "ephemeral"], + ]]; + } + + $messages[$lastIndex] = $lastMessage; + return $messages; + } } diff --git a/src/Extras/LLM/Drivers/CohereV1Driver.php b/src/Extras/LLM/Drivers/CohereV1Driver.php index f3585da0..195b5507 100644 --- a/src/Extras/LLM/Drivers/CohereV1Driver.php +++ b/src/Extras/LLM/Drivers/CohereV1Driver.php @@ -27,6 +27,7 @@ public function __construct( // REQUEST ////////////////////////////////////////////// public function handle(InferenceRequest $request) : ResponseInterface { + $request = $this->withCachedContext($request); return $this->httpClient->handle( url: $this->getEndpointUrl($request), headers: $this->getRequestHeaders(), @@ -227,4 +228,16 @@ private function makeToolNameDelta(array $data) : string { private function isStreamEnd(array $data) : bool { return 'stream_end' === ($data['event_type'] ?? ''); } + + private function withCachedContext(InferenceRequest $request): InferenceRequest { + if (!isset($request->cachedContext)) { + return $request; + } + $cloned = clone $request; + $cloned->messages = array_merge($request->cachedContext->messages, $request->messages); + $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools; + $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice; + $cloned->responseFormat = empty($request->responseFormat) ? $request->cachedContext->responseFormat : $request->responseFormat; + return $cloned; + } } diff --git a/src/Extras/LLM/Drivers/CohereV2Driver.php b/src/Extras/LLM/Drivers/CohereV2Driver.php index d738bcfb..84c710c1 100644 --- a/src/Extras/LLM/Drivers/CohereV2Driver.php +++ b/src/Extras/LLM/Drivers/CohereV2Driver.php @@ -47,8 +47,6 @@ public function toApiResponse(array $data): ApiResponse { return new ApiResponse( content: $this->makeContent($data), responseData: $data, -// toolName: $data['message']['tool_calls'][0]['function']['name'] ?? '', -// toolArgs: $data['message']['tool_calls'][0]['function']['arguments'] ?? '', toolsData: $this->makeToolsData($data), finishReason: $data['finish_reason'] ?? '', toolCalls: $this->makeToolCalls($data), diff --git a/src/Extras/LLM/Drivers/GeminiDriver.php b/src/Extras/LLM/Drivers/GeminiDriver.php index 05e0e994..b550673d 100644 --- a/src/Extras/LLM/Drivers/GeminiDriver.php +++ b/src/Extras/LLM/Drivers/GeminiDriver.php @@ -29,6 +29,7 @@ public function __construct( // REQUEST ////////////////////////////////////////////// public function handle(InferenceRequest $request) : ResponseInterface { + $request = $this->withCachedContext($request); return $this->httpClient->handle( url: $this->getEndpointUrl($request), headers: $this->getRequestHeaders(), @@ -284,4 +285,16 @@ private function makeDelta(array $data): string { ?? Json::encode($data['candidates'][0]['content']['parts'][0]['functionCall']['args'] ?? []) ?? ''; } + + private function withCachedContext(InferenceRequest $request): InferenceRequest { + if (!isset($request->cachedContext)) { + return $request; + } + $cloned = clone $request; + $cloned->messages = array_merge($request->cachedContext->messages, $request->messages); + $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools; + $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice; + $cloned->responseFormat = empty($request->responseFormat) ? 
$request->cachedContext->responseFormat : $request->responseFormat; + return $cloned; + } } diff --git a/src/Extras/LLM/Drivers/OpenAIDriver.php b/src/Extras/LLM/Drivers/OpenAIDriver.php index 612e0291..5343236b 100644 --- a/src/Extras/LLM/Drivers/OpenAIDriver.php +++ b/src/Extras/LLM/Drivers/OpenAIDriver.php @@ -26,6 +26,7 @@ public function __construct( // REQUEST ////////////////////////////////////////////// public function handle(InferenceRequest $request) : ResponseInterface { + $request = $this->withCachedContext($request); return $this->httpClient->handle( url: $this->getEndpointUrl($request), headers: $this->getRequestHeaders(), @@ -91,7 +92,7 @@ public function toApiResponse(array $data): ?ApiResponse { inputTokens: $this->makeInputTokens($data), outputTokens: $this->makeOutputTokens($data), cacheCreationTokens: 0, - cacheReadTokens: 0, + cacheReadTokens: $data['usage']['prompt_tokens_details']['cached_tokens'] ?? 0, ); } @@ -108,7 +109,7 @@ public function toPartialApiResponse(array $data) : ?PartialApiResponse { inputTokens: $this->makeInputTokens($data), outputTokens: $this->makeOutputTokens($data), cacheCreationTokens: 0, - cacheReadTokens: 0, + cacheReadTokens: $data['usage']['prompt_tokens_details']['cached_tokens'] ?? 0, ); } @@ -207,4 +208,17 @@ private function makeToolNameDelta(array $data) : string { private function makeToolArgsDelta(array $data) : string { return $data['choices'][0]['delta']['tool_calls'][0]['function']['arguments'] ?? ''; } + + private function withCachedContext(InferenceRequest $request): InferenceRequest { + if (!isset($request->cachedContext)) { + return $request; + } + + $cloned = clone $request; + $cloned->messages = array_merge($request->cachedContext->messages, $request->messages); + $cloned->tools = empty($request->tools) ? $request->cachedContext->tools : $request->tools; + $cloned->toolChoice = empty($request->toolChoice) ? $request->cachedContext->toolChoice : $request->toolChoice; + $cloned->responseFormat = empty($request->responseFormat) ? 
$request->cachedContext->responseFormat : $request->responseFormat; + return $cloned; + } } diff --git a/src/Extras/LLM/Inference.php b/src/Extras/LLM/Inference.php index 986e5f0f..8458ec42 100644 --- a/src/Extras/LLM/Inference.php +++ b/src/Extras/LLM/Inference.php @@ -8,6 +8,7 @@ use Cognesy\Instructor\Extras\Http\Contracts\CanHandleHttp; use Cognesy\Instructor\Extras\Http\HttpClient; use Cognesy\Instructor\Extras\LLM\Contracts\CanHandleInference; +use Cognesy\Instructor\Extras\LLM\Data\CachedContext; use Cognesy\Instructor\Extras\LLM\Data\LLMConfig; use Cognesy\Instructor\Extras\LLM\Drivers\AnthropicDriver; use Cognesy\Instructor\Extras\LLM\Drivers\AzureOpenAIDriver; @@ -28,6 +29,7 @@ class Inference protected CanHandleInference $driver; protected CanHandleHttp $httpClient; protected EventDispatcher $events; + protected CachedContext $cachedContext; public function __construct( string $connection = '', @@ -97,6 +99,16 @@ public function withDebug(bool $debug = true) : self { return $this; } + public function withCachedContext( + string|array $messages = [], + array $tools = [], + string|array $toolChoice = [], + array $responseFormat = [], + ): self { + $this->cachedContext = new CachedContext($messages, $tools, $toolChoice, $responseFormat); + return $this; + } + public function create( string|array $messages = [], string $model = '', @@ -106,7 +118,9 @@ public function create( array $options = [], Mode $mode = Mode::Text ): InferenceResponse { - $request = new InferenceRequest($messages, $model, $tools, $toolChoice, $responseFormat, $options, $mode); + $request = new InferenceRequest( + $messages, $model, $tools, $toolChoice, $responseFormat, $options, $mode, $this->cachedContext ?? null + ); $this->events->dispatch(new InferenceRequested($request)); return new InferenceResponse( response: $this->driver->handle($request), diff --git a/src/Extras/LLM/InferenceRequest.php b/src/Extras/LLM/InferenceRequest.php index 63cc7824..aa2c2a17 100644 --- a/src/Extras/LLM/InferenceRequest.php +++ b/src/Extras/LLM/InferenceRequest.php @@ -2,6 +2,8 @@ namespace Cognesy\Instructor\Extras\LLM; use Cognesy\Instructor\Enums\Mode; +use Cognesy\Instructor\Extras\LLM\Data\CachedContext; + //use Cognesy\Instructor\Utils\Uuid; class InferenceRequest @@ -16,6 +18,7 @@ class InferenceRequest public array $responseFormat = []; public array $options = []; public Mode $mode = Mode::Text; + public ?CachedContext $cachedContext; public function __construct( string|array $messages = [], @@ -25,10 +28,9 @@ public function __construct( array $responseFormat = [], array $options = [], Mode $mode = Mode::Text, -// array $metadata = [], + ?CachedContext $cachedContext = null, ) { -// $this->uuid = Uuid::uuid4(); -// $this->metadata = $metadata; + $this->cachedContext = $cachedContext; $this->model = $model; $this->options = $options;