Skip to content

Commit

Permalink
Fixes in Cohere1 driver
Browse files Browse the repository at this point in the history
  • Loading branch information
ddebowczyk committed Sep 30, 2024
1 parent a661d9b commit 1c2837a
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 54 deletions.
36 changes: 18 additions & 18 deletions evals/LLMModes/run.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,29 +10,29 @@
use Cognesy\Instructor\Utils\Str;

$connections = [
'anthropic',
'azure',
'cohere2',
'fireworks',
'gemini',
'groq',
'mistral',
'ollama',
'openai',
'openrouter',
'together'
// 'anthropic',
// 'azure',
'cohere1',
// 'fireworks',
// 'gemini',
// 'groq',
// 'mistral',
// 'ollama',
// 'openai',
// 'openrouter',
// 'together'
];

$streamingModes = [
true,
false
// false
];

$modes = [
Mode::Text,
Mode::MdJson,
Mode::Json,
Mode::JsonSchema,
// Mode::Text,
// Mode::MdJson,
// Mode::Json,
// Mode::JsonSchema,
Mode::Tools,
];

Expand All @@ -57,8 +57,8 @@ function evalFn(EvalRequest $er) {

function validateToolsData(array $data) : bool {
return 'store_company' === ($data[0]['name'] ?? '')
&& 'ACME' === $data[0]['arguments']['name'] ?? ''
&& 2020 === $data[0]['arguments']['year'] ?? 0;
&& 'ACME' === ($data[0]['arguments']['name'] ?? '')
&& 2020 === ($data[0]['arguments']['year'] ?? 0);
}

(new CompareModes(
Expand Down
78 changes: 46 additions & 32 deletions src/Extras/LLM/Data/ApiResponse.php
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@ class ApiResponse
public function __construct(
public string $content = '',
public array $responseData = [],
// public string $toolName = '', // TODO: remove
// public string $toolArgs = '', // TODO: remove
public array $toolsData = [],
public string $finishReason = '',
public ?ToolCalls $toolCalls = null,
Expand All @@ -24,19 +22,49 @@ public function getJson(): string {
return Json::find($this->content);
}

public static function fromPartialResponses(
array $partialResponses,
) : ApiResponse {
$instance = new self();
public static function fromPartialResponses(array $partialResponses) : ApiResponse {
return (new self)->makeInstance($partialResponses);
}

public function hasToolCalls() : bool {
return !empty($this->toolCalls);
}

// INTERNAL //////////////////////////////////////////////

private function makeInstance(array $partialResponses) : self {
$content = '';
foreach($partialResponses as $partialResponse) {
if ($partialResponse === null) {
continue;
}
$content .= $partialResponse->delta;
$this->responseData[] = $partialResponse->responseData;
$this->inputTokens += $partialResponse->inputTokens;
$this->outputTokens += $partialResponse->outputTokens;
$this->cacheCreationTokens += $partialResponse->cacheCreationTokens;
$this->cacheReadTokens += $partialResponse->cacheReadTokens;
$this->finishReason = $partialResponse->finishReason;
}
$this->content = $content;

$tools = $this->makeTools($partialResponses);
if (!empty($tools)) {
$this->toolsData = $this->makeToolsData($tools);
$this->toolCalls = ToolCalls::fromArray($this->toolsData);
}
return $this;
}

private function makeTools(array $partialResponses) : array {
$tools = [];
$currentTool = '';
foreach($partialResponses as $partialResponse) {
if ($partialResponse === null) {
continue;
}

if (('' !== $partialResponse->toolName ?? '') && ($currentTool !== $partialResponse->toolName ?? '')) {
if (('' !== $partialResponse->toolName ?? '')
&& ($currentTool !== ($partialResponse->toolName ?? ''))) {
$currentTool = $partialResponse->toolName ?? '';
$tools[$currentTool] = '';
}
Expand All @@ -45,32 +73,18 @@ public static function fromPartialResponses(
$tools[$currentTool] .= $partialResponse->toolArgs ?? '';
}
}

$content .= $partialResponse->delta;
$instance->responseData[] = $partialResponse->responseData;
// $instance->toolName .= $partialResponse->toolName;
// $instance->toolArgs .= $partialResponse->toolArgs;
// $instance->toolsData = [];
$instance->inputTokens += $partialResponse->inputTokens;
$instance->outputTokens += $partialResponse->outputTokens;
$instance->cacheCreationTokens += $partialResponse->cacheCreationTokens;
$instance->cacheReadTokens += $partialResponse->cacheReadTokens;
$instance->finishReason = $partialResponse->finishReason;
}
return $tools;
}

if (!empty($tools)) {
$instance->toolsData = [];
foreach($tools as $tool => $args) {
$instance->toolsData[] = [
'name' => $tool,
'arguments' => Json::parse($args),
];
}
$instance->toolCalls = ToolCalls::fromArray($instance->toolsData);
private function makeToolsData(array $tools) : array {
$data = [];
foreach($tools as $tool => $args) {
$data[] = [
'name' => $tool,
'arguments' => '' !== $args ? Json::parse($args) : [],
];
}

$instance->content = $content;

return $instance;
return $data;
}
}
28 changes: 24 additions & 4 deletions src/Extras/LLM/Drivers/CohereV1Driver.php
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,6 @@ public function toApiResponse(array $data): ApiResponse {
return new ApiResponse(
content: $this->makeContent($data),
responseData: $data,
// toolName: $data['tool_calls'][0]['name'] ?? '',
// toolArgs: Json::encode($data['tool_calls'][0]['parameters'] ?? []),
toolsData: $this->mapToolsData($data),
finishReason: $data['finish_reason'] ?? '',
toolCalls: $this->makeToolCalls($data),
Expand All @@ -98,8 +96,8 @@ public function toPartialApiResponse(array $data) : PartialApiResponse {
return new PartialApiResponse(
delta: $this->makeDelta($data),
responseData: $data,
toolName: $data['tool_calls'][0]['name'] ?? '',
toolArgs: Json::encode($data['tool_calls'][0]['parameters'] ?? []),
toolName: $this->makeToolNameDelta($data),
toolArgs: $this->makeToolArgsDelta($data),
finishReason: $data['response']['finish_reason'] ?? $data['delta']['finish_reason'] ?? '',
inputTokens: $data['response']['meta']['tokens']['input_tokens'] ?? $data['delta']['tokens']['input_tokens'] ?? 0,
outputTokens: $data['response']['meta']['tokens']['output_tokens'] ?? $data['delta']['tokens']['input_tokens'] ?? 0,
Expand Down Expand Up @@ -205,6 +203,28 @@ private function makeContent(array $data) : string {
}

private function makeDelta(array $data) : string {
if ($this->isStreamEnd($data)) {
return '';
}
return $data['text'] ?? $data['tool_calls'][0]['parameters'] ?? '';
}

private function makeToolArgsDelta(array $data) : string {
if ($this->isStreamEnd($data)) {
return '';
}
$toolArgs = $data['tool_calls'][0]['parameters'] ?? '';
return ('' === $toolArgs) ? '' : Json::encode($toolArgs);
}

private function makeToolNameDelta(array $data) : string {
if ($this->isStreamEnd($data)) {
return '';
}
return $data['tool_calls'][0]['name'] ?? '';
}

private function isStreamEnd(array $data) : bool {
return 'stream_end' === ($data['event_type'] ?? '');
}
}

0 comments on commit 1c2837a

Please sign in to comment.