From d20cf911615dae31e74fb34ba94ad4bf6f9d6dd2 Mon Sep 17 00:00:00 2001 From: ddebowczyk Date: Sat, 5 Oct 2024 09:12:07 +0200 Subject: [PATCH] Code cleanup - RequestHandler --- examples/A01_Basics/Validation/run.php | 5 +- examples/A01_Basics/ValidationCustom/run.php | 2 +- src/Core/PartialsGenerator.php | 6 +- src/Core/RequestHandler.php | 63 ++++++------------- .../Traits/ChatTemplate/HandlesRetries.php | 25 ++++---- src/Extras/LLM/Inference.php | 1 + 6 files changed, 38 insertions(+), 64 deletions(-) diff --git a/examples/A01_Basics/Validation/run.php b/examples/A01_Basics/Validation/run.php index 4bbc62fc..a99729f7 100644 --- a/examples/A01_Basics/Validation/run.php +++ b/examples/A01_Basics/Validation/run.php @@ -41,13 +41,12 @@ class UserDetails messages: [['role' => 'user', 'content' => "you can reply to me via mail -- Jason"]], responseModel: UserDetails::class, )->get(); + dump($user); } catch(Exception $e) { $caughtException = true; } -dump($user); - -assert($user === null); +assert(!isset($user)); assert($caughtException === true); ?> ``` diff --git a/examples/A01_Basics/ValidationCustom/run.php b/examples/A01_Basics/ValidationCustom/run.php index 9c416c69..43788dd4 100644 --- a/examples/A01_Basics/ValidationCustom/run.php +++ b/examples/A01_Basics/ValidationCustom/run.php @@ -28,7 +28,7 @@ class UserDetails #[Assert\Callback] public function validateName(ExecutionContextInterface $context, mixed $payload) { if ($this->name !== strtoupper($this->name)) { - $context->buildViolation("Name must be in all uppercase letters.") + $context->buildViolation("Name must be all uppercase.") ->atPath('name') ->setInvalidValue($this->name) ->addViolation(); diff --git a/src/Core/PartialsGenerator.php b/src/Core/PartialsGenerator.php index c5819e84..6357760a 100644 --- a/src/Core/PartialsGenerator.php +++ b/src/Core/PartialsGenerator.php @@ -70,7 +70,7 @@ public function getPartialResponses(Generator $stream, ResponseModel $responseMo $this->resetPartialResponse(); // receive data - /** @var \Cognesy\Instructor\Extras\LLM\Data\PartialLLMResponse $partialResponse */ + /** @var PartialLLMResponse $partialResponse */ foreach($stream as $partialResponse) { $this->events->dispatch(new StreamedResponseReceived($partialResponse)); // store partial response @@ -108,10 +108,6 @@ public function getPartialResponses(Generator $stream, ResponseModel $responseMo } $this->events->dispatch(new PartialJsonReceived($this->responseJson)); -// yield new PartialProcessedResponse( -// result: $result, -// partialLLMResponse: $partialResponse, -// ); yield $result->unwrap(); } $this->events->dispatch(new StreamedResponseFinished($this->lastPartialResponse())); diff --git a/src/Core/RequestHandler.php b/src/Core/RequestHandler.php index 65e17066..0b353baf 100644 --- a/src/Core/RequestHandler.php +++ b/src/Core/RequestHandler.php @@ -4,7 +4,6 @@ use Cognesy\Instructor\Contracts\CanGeneratePartials; use Cognesy\Instructor\Contracts\CanGenerateResponse; use Cognesy\Instructor\Data\Request; -use Cognesy\Instructor\Data\ResponseModel; use Cognesy\Instructor\Enums\Mode; use Cognesy\Instructor\Events\EventDispatcher; use Cognesy\Instructor\Events\Instructor\InstructorDone; @@ -30,9 +29,7 @@ class RequestHandler protected EventDispatcher $events; protected int $retries = 0; - protected array $messages = []; protected array $errors = []; - protected ?ResponseModel $responseModel; public function __construct( protected Request $request, @@ -76,16 +73,21 @@ public function stream() : Stream { * Generates response value */ protected function responseFor(Request $request) : mixed { - $this->init($request); + $this->init(); $processingResult = Result::failure("No response generated"); while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { - $llmResponse = $this->getLLMResponse($request); + $llmResponse = $this->getInference($request)->toLLMResponse(); + + $llmResponse->content = match($request->mode()) { + Mode::Text => $llmResponse->content, + default => Json::from($llmResponse->content)->toString(), + }; $partialResponses = []; $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); } - $value = $this->processResult($processingResult, $request, $llmResponse, $partialResponses); + $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); return $value; } @@ -96,61 +98,36 @@ protected function responseFor(Request $request) : mixed { * @return Generator */ protected function streamResponseFor(Request $request) : Generator { - $this->init($request); + $this->init(); $processingResult = Result::failure("No response generated"); while ($processingResult->isFailure() && !$this->maxRetriesReached($request)) { - yield from $this->getStreamedLLMResponses($request); + $stream = $this->getInference($request)->toPartialLLMResponses(); + yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel()); $llmResponse = $this->partialsGenerator->getCompleteResponse(); $partialResponses = $this->partialsGenerator->partialResponses(); $processingResult = $this->processResponse($request, $llmResponse, $partialResponses); } - $value = $this->processResult($processingResult, $request, $llmResponse, $partialResponses); + $value = $this->finalizeResult($processingResult, $request, $llmResponse, $partialResponses); yield $value; } - protected function init(Request $request) : void { - $this->responseModel = $request->responseModel(); - if ($this->responseModel === null) { - throw new Exception("Request does not have a response model"); - } - + protected function init() : void { $this->retries = 0; - $this->messages = $request->messages(); // TODO: tx messages to Scripts $this->errors = []; } - protected function getLLMResponse(Request $request) : LLMResponse { + protected function getInference(Request $request) : InferenceResponse { + $this->events->dispatch(new RequestSentToLLM($request)); try { - $this->events->dispatch(new RequestSentToLLM($request)); - $llmResponse = $this->makeInference($request)->toLLMResponse(); - $llmResponse->content = match($request->mode()) { - Mode::Text => $llmResponse->content, - default => Json::from($llmResponse->content)->toString(), - }; + return $this->makeInference($request); } catch (Exception $e) { $this->events->dispatch(new RequestToLLMFailed($request, $e->getMessage())); throw $e; } - return $llmResponse; - } - - /** - * @param Request $request - * @return Generator - */ - protected function getStreamedLLMResponses(Request $request) : Generator { - try { - $this->events->dispatch(new RequestSentToLLM($request)); - $stream = $this->makeInference($request)->toPartialLLMResponses(); - yield from $this->partialsGenerator->getPartialResponses($stream, $request->responseModel()); - } catch(Exception $e) { - $this->events->dispatch(new RequestToLLMFailed($request, $e->getMessage())); - throw $e; - } } protected function makeInference(Request $request) : InferenceResponse { @@ -176,7 +153,7 @@ protected function processResponse(Request $request, LLMResponse $llmResponse, a $this->events->dispatch(new ResponseReceivedFromLLM($llmResponse)); // we have LLMResponse here - let's process it: deserialize, validate, transform - $processingResult = $this->responseGenerator->makeResponse($llmResponse, $this->responseModel); + $processingResult = $this->responseGenerator->makeResponse($llmResponse, $request->responseModel()); if ($processingResult->isFailure()) { // retry - we have not managed to deserialize, validate or transform the response @@ -186,7 +163,7 @@ protected function processResponse(Request $request, LLMResponse $llmResponse, a return $processingResult; } - protected function processResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed { + protected function finalizeResult(Result $processingResult, Request $request, LLMResponse $llmResponse, array $partialResponses) : mixed { if ($processingResult->isFailure()) { $this->events->dispatch(new ValidationRecoveryLimitReached($this->retries, $this->errors)); throw new Exception("Validation recovery attempts limit reached after {$this->retries} attempts due to: ".implode(", ", $this->errors)); @@ -195,7 +172,7 @@ protected function processResult(Result $processingResult, Request $request, LLM // get final value $value = $processingResult->unwrap(); // store response - $request->setResponse($this->messages, $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts + $request->setResponse($request->messages(), $llmResponse, $partialResponses, $value); // TODO: tx messages to Scripts // notify on response generation $this->events->dispatch(new ResponseGenerated($value)); @@ -207,7 +184,7 @@ protected function handleError(Result $processingResult, Request $request, LLMRe $this->errors = is_array($error) ? $error : [$error]; // store failed response - $request->addFailedResponse($this->messages, $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts + $request->addFailedResponse($request->messages(), $llmResponse, $partialResponses, $this->errors); // TODO: tx messages to Scripts $this->retries++; if (!$this->maxRetriesReached($request)) { $this->events->dispatch(new NewValidationRecoveryAttempt($this->retries, $this->errors)); diff --git a/src/Data/Traits/ChatTemplate/HandlesRetries.php b/src/Data/Traits/ChatTemplate/HandlesRetries.php index 51d6f31e..2e8d721d 100644 --- a/src/Data/Traits/ChatTemplate/HandlesRetries.php +++ b/src/Data/Traits/ChatTemplate/HandlesRetries.php @@ -5,6 +5,19 @@ trait HandlesRetries { + protected function addRetryMessages() : void { + $failedResponse = $this->request->lastFailedResponse(); + if (!$failedResponse || !$this->request->hasLastResponseFailed()) { + return; + } + foreach($this->request->attempts() as $attempt) { + $messages = $this->makeRetryMessages( + [], $attempt->llmResponse()->content, $attempt->errors() + ); + $this->script->section('retries')->appendMessages($messages); + } + } + protected function makeRetryMessages( array $messages, string $jsonData, @@ -19,16 +32,4 @@ protected function makeRetryMessages( protected function makeRetryPrompt() : string { return $this->request->retryPrompt() ?: $this->defaultRetryPrompt; } - - protected function addRetryMessages() { - $failedResponse = $this->request->lastFailedResponse(); - if (!$failedResponse || !$this->request->hasLastResponseFailed()) { - return; - } - $this->script->section('retries')->appendMessages( - $this->makeRetryMessages( - [], $failedResponse->llmResponse()->content, $failedResponse->errors() - ) - ); - } } \ No newline at end of file diff --git a/src/Extras/LLM/Inference.php b/src/Extras/LLM/Inference.php index 8458ec42..4404d35f 100644 --- a/src/Extras/LLM/Inference.php +++ b/src/Extras/LLM/Inference.php @@ -121,6 +121,7 @@ public function create( $request = new InferenceRequest( $messages, $model, $tools, $toolChoice, $responseFormat, $options, $mode, $this->cachedContext ?? null ); +dump($request); $this->events->dispatch(new InferenceRequested($request)); return new InferenceResponse( response: $this->driver->handle($request),