Skip to content

Commit

Permalink
chore: add cohere engine and model support
Browse files Browse the repository at this point in the history
  • Loading branch information
louis-jan committed Jan 14, 2025
1 parent 2ca94d3 commit c660a39
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions core/src/browser/extensions/engines/helpers/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ export function requestInference(
model.parameters?.stream === false
) {
const data = await response.json()
if (data.error) {
subscriber.error(data.error)
if (data.error || data.message) {
subscriber.error(data.error ?? data)
subscriber.complete()
return

Check warning on line 60 in core/src/browser/extensions/engines/helpers/sse.ts

View workflow job for this annotation

GitHub Actions / coverage-check

58-60 lines are not covered with tests
}
Expand Down
18 changes: 15 additions & 3 deletions extensions/engine-management-extension/engines.mjs
Original file line number Diff line number Diff line change
@@ -1,23 +1,35 @@
import anthropic from './resources/anthropic.json' with { type: 'json' }
import cohere from './resources/cohere.json' with { type: 'json' }
import openai from './resources/openai.json' with { type: 'json' }
import openrouter from './resources/openrouter.json' with { type: 'json' }
import groq from './resources/groq.json' with { type: 'json' }
import martian from './resources/martian.json' with { type: 'json' }
import mistral from './resources/mistral.json' with { type: 'json' }
import nvidia from './resources/nvidia.json' with { type: 'json' }

import openaiModels from './models/openai.json' with { type: 'json' }
import anthropicModels from './models/anthropic.json' with { type: 'json' }
import cohereModels from './models/cohere.json' with { type: 'json' }
import openaiModels from './models/openai.json' with { type: 'json' }
import openrouterModels from './models/openrouter.json' with { type: 'json' }
import groqModels from './models/groq.json' with { type: 'json' }
import martianModels from './models/martian.json' with { type: 'json' }
import mistralModels from './models/mistral.json' with { type: 'json' }
import nvidiaModels from './models/nvidia.json' with { type: 'json' }

const engines = [anthropic, openai, openrouter, groq, mistral, martian, nvidia]
const engines = [
anthropic,
openai,
cohere,
openrouter,
groq,
mistral,
martian,
nvidia,
]
const models = [
...openaiModels,
...anthropicModels,
...openaiModels,
...cohereModels,
...openrouterModels,
...groqModels,
...mistralModels,
Expand Down
4 changes: 2 additions & 2 deletions extensions/engine-management-extension/models/cohere.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"version": "1.0",
"description": "Command R+ is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. It is best suited for complex RAG workflows and multi-step tool use.",
"inference_params": {
"max_tokens": 128000,
"max_tokens": 4096,
"temperature": 0.7,
"stream": false
},
Expand All @@ -19,7 +19,7 @@
"version": "1.0",
"description": "Command R is an instruction-following conversational model that performs language tasks at a higher quality, more reliably, and with a longer context than previous models. It can be used for complex workflows like code generation, retrieval augmented generation (RAG), tool use, and agents.",
"inference_params": {
"max_tokens": 128000,
"max_tokens": 4096,
"temperature": 0.7,
"stream": false
},
Expand Down
4 changes: 2 additions & 2 deletions extensions/engine-management-extension/resources/cohere.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
"transform_req": {
"chat_completions": {
"url": "https://api.cohere.ai/v1/chat",
"template": "{ {% set first = true %} {% for key, value in input_request %} {% if key == \"messages\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} }"
"template": "{ {% for key, value in input_request %} {% if key == \"messages\" %} {% if input_request.messages.0.role == \"system\" %} \"preamble\": \"{{ input_request.messages.0.content }}\", {% if length(input_request.messages) > 2 %} \"chatHistory\": [{% for message in input_request.messages %} {% if not loop.is_first and not loop.is_last %} {\"role\": {% if message.role == \"user\" %} \"USER\" {% else %} \"CHATBOT\" {% endif %}, \"content\": \"{{ message.content }}\" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} {% endif %} {% endfor %}], {% endif %} \"message\": \"{{ last(input_request.messages).content }}\" {% else %} {% if length(input_request.messages) > 2 %} \"chatHistory\": [{% for message in input_request.messages %} {% if not loop.is_last %} { \"role\": {% if message.role == \"user\" %} \"USER\" {% else %} \"CHATBOT\" {% endif %}, \"content\": \"{{ message.content }}\" } {% if loop.index < length(input_request.messages) - 2 %},{% endif %} {% endif %} {% endfor %}],{% endif %}\"message\": \"{{ last(input_request.messages).content }}\" {% endif %}{% if not loop.is_last %},{% endif %} {% else if key == \"system\" or key == \"model\" or key == \"temperature\" or key == \"store\" or key == \"max_tokens\" or key == \"stream\" or key == \"presence_penalty\" or key == \"metadata\" or key == \"frequency_penalty\" or key == \"tools\" or key == \"tool_choice\" or key == \"logprobs\" or key == \"top_logprobs\" or key == \"logit_bias\" or key == \"n\" or key == \"modalities\" or key == \"prediction\" or key == \"response_format\" or key == \"service_tier\" or key == \"seed\" or key == \"stop\" or key == \"stream_options\" or key == \"top_p\" or key == \"parallel_tool_calls\" or key == \"user\" %} \"{{ key }}\": {{ tojson(value) }} {% if not loop.is_last %},{% endif %} {% endif %} {% endfor %} }"
}
},
"transform_resp": {
"chat_completions": {
"template": "{ {% set first = true %} {% for key, value in input_request %} {% if key == \"choices\" or key == \"created\" or key == \"model\" or key == \"service_tier\" or key == \"system_fingerprint\" or key == \"stream\" or key == \"object\" or key == \"usage\" %} {% if not first %},{% endif %} \"{{ key }}\": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} }"
"template": "{% if input_request.stream %} {\"object\": \"chat.completion.chunk\", \"model\": \"{{ input_request.model }}\", \"choices\": [{\"index\": 0, \"delta\": { {% if input_request.event_type == \"text-generation\" %} \"role\": \"assistant\", \"content\": \"{{ input_request.text }}\" {% else %} \"role\": \"assistant\", \"content\": null {% endif %} }, {% if input_request.event_type == \"stream-end\" %} \"finish_reason\": \"{{ input_request.finish_reason }}\" {% else %} \"finish_reason\": null {% endif %} }]} {% else %} {\"id\": \"{{ input_request.generation_id }}\", \"created\": null, \"object\": \"chat.completion\", \"model\": {% if input_request.model %} \"{{ input_request.model }}\" {% else %} \"command-r-plus-08-2024\" {% endif %}, \"choices\": [{ \"index\": 0, \"message\": { \"role\": \"assistant\", \"content\": {% if not input_request.text %} null {% else %} \"{{ input_request.text }}\" {% endif %}, \"refusal\": null }, \"logprobs\": null, \"finish_reason\": \"{{ input_request.finish_reason }}\" } ], \"usage\": { \"prompt_tokens\": {{ input_request.meta.tokens.input_tokens }}, \"completion_tokens\": {{ input_request.meta.tokens.output_tokens }},\"total_tokens\": {{ input_request.meta.tokens.input_tokens + input_request.meta.tokens.output_tokens }}, \"prompt_tokens_details\": { \"cached_tokens\": 0 },\"completion_tokens_details\": { \"reasoning_tokens\": 0, \"accepted_prediction_tokens\": 0, \"rejected_prediction_tokens\": 0 } }, \"system_fingerprint\": \"fp_6b68a8204b\"} {% endif %}"
}
}
}
Expand Down

0 comments on commit c660a39

Please sign in to comment.