diff --git a/src/draive/anthropic/client.py b/src/draive/anthropic/client.py
index c3deb2f..a2eed99 100644
--- a/src/draive/anthropic/client.py
+++ b/src/draive/anthropic/client.py
@@ -98,6 +98,9 @@ async def completion(  # noqa: PLR0913
             max_tokens=config.max_tokens,
             top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
             timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
+            stop_sequences=config.stop_sequences
+            if not_missing(config.stop_sequences)
+            else NOT_GIVEN,
             stream=stream,
         )
 
diff --git a/src/draive/anthropic/config.py b/src/draive/anthropic/config.py
index 35c64fb..08d3bf2 100644
--- a/src/draive/anthropic/config.py
+++ b/src/draive/anthropic/config.py
@@ -26,3 +26,4 @@ class AnthropicConfig(DataModel):
     top_p: float | Missing = MISSING
     max_tokens: int = 2048
     timeout: float | Missing = MISSING
+    stop_sequences: list[str] | Missing = MISSING
diff --git a/src/draive/gemini/client.py b/src/draive/gemini/client.py
index 02f79c9..27bda5a 100644
--- a/src/draive/gemini/client.py
+++ b/src/draive/gemini/client.py
@@ -75,6 +75,9 @@ async def generate(  # noqa: PLR0913
                 "maxOutputTokens": config.max_tokens,
                 "responseSchema": response_schema if response_schema else None,
                 "candidateCount": 1,
+                "stopSequences": config.stop_sequences
+                if not_missing(config.stop_sequences)
+                else None,
             },
             "systemInstruction": {
                 "parts": ({"text": instruction},),
diff --git a/src/draive/gemini/config.py b/src/draive/gemini/config.py
index 2dc2757..02195c5 100644
--- a/src/draive/gemini/config.py
+++ b/src/draive/gemini/config.py
@@ -17,6 +17,7 @@ class GeminiConfig(DataModel):
     top_k: int | Missing = MISSING
     max_tokens: int = 2048
     timeout: float | Missing = MISSING
+    stop_sequences: list[str] | Missing = MISSING
 
 
 class GeminiEmbeddingConfig(DataModel):
diff --git a/src/draive/mistral/client.py b/src/draive/mistral/client.py
index 080b4a1..590f512 100644
--- a/src/draive/mistral/client.py
+++ b/src/draive/mistral/client.py
@@ -92,6 +92,7 @@ async def chat_completion(
             seed=config.seed if not_missing(config.seed) else None,
             tools=tools,
             tool_choice=tool_choice if tools else None,
+            stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
         )
 
     async def embedding(
@@ -124,6 +125,7 @@ async def _create_chat_completion(  # noqa: PLR0913
         messages: list[ChatMessage],
         tools: list[dict[str, object]] | None,
         tool_choice: str | None,
+        stop: list[str] | None,
     ) -> ChatCompletionResponse:
         request_body: dict[str, Any] = {
             "model": model,
@@ -143,6 +145,8 @@ async def _create_chat_completion(  # noqa: PLR0913
             request_body["top_p"] = top_p
         if seed is not None:
             request_body["random_seed"] = seed
+        if stop:
+            request_body["stop"] = stop
         if response_format is not None:
             request_body["response_format"] = response_format
 
diff --git a/src/draive/mistral/config.py b/src/draive/mistral/config.py
index 67e57b6..8211f1b 100644
--- a/src/draive/mistral/config.py
+++ b/src/draive/mistral/config.py
@@ -21,6 +21,7 @@ class MistralChatConfig(DataModel):
     max_tokens: int = 2048
     response_format: ResponseFormat | Missing = MISSING
     timeout: float | Missing = MISSING
+    stop_sequences: list[str] | Missing = MISSING
 
 
 class MistralEmbeddingConfig(DataModel):
diff --git a/src/draive/mrs/client.py b/src/draive/mrs/client.py
index 0f8d03a..d642935 100644
--- a/src/draive/mrs/client.py
+++ b/src/draive/mrs/client.py
@@ -103,6 +103,9 @@ async def chat_completion(
                 top_p=config.top_p if not_missing(config.top_p) else None,
                 top_k=config.top_k if not_missing(config.top_k) else None,
                 max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
+                stop_sequences=config.stop_sequences
+                if not_missing(config.stop_sequences)
+                else None,
             )
 
         else:
@@ -113,6 +116,9 @@ async def chat_completion(
                 top_p=config.top_p if not_missing(config.top_p) else None,
                 top_k=config.top_k if not_missing(config.top_k) else None,
                 max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
+                stop_sequences=config.stop_sequences
+                if not_missing(config.stop_sequences)
+                else None,
             )
 
     async def _create_chat_completion(  # noqa: PLR0913
@@ -123,6 +129,7 @@ async def _create_chat_completion(  # noqa: PLR0913
         top_k: int | None,
         max_tokens: int | None,
         messages: list[dict[str, object]],
+        stop_sequences: list[str] | None,
    ) -> ChatCompletionResponse:
         return await self._send_chat_completion_request(
             runner=await self._get_runner(model),
@@ -132,6 +139,7 @@ async def _create_chat_completion(  # noqa: PLR0913
             top_k=top_k,
             max_tokens=max_tokens,
             messages=messages,
+            stop_sequences=stop_sequences,
         )
 
     async def _create_chat_stream(  # noqa: PLR0913
@@ -142,6 +150,7 @@ async def _create_chat_stream(  # noqa: PLR0913
         top_k: int | None,
         max_tokens: int | None,
         messages: list[dict[str, object]],
+        stop_sequences: list[str] | None,
     ) -> AsyncIterable[ChatCompletionChunkResponse]:
         return ctx.stream_sync(
             self._send_chat_completion_stream_request(
@@ -152,6 +161,7 @@ async def _create_chat_stream(  # noqa: PLR0913
                 top_k=top_k,
                 max_tokens=max_tokens,
                 messages=messages,
+                stop_sequences=stop_sequences,
             ),
             executor=MRS_EXECUTOR,
         )
@@ -205,6 +215,7 @@ def _send_chat_completion_request(  # noqa: PLR0913
         top_k: int | None,
         max_tokens: int | None,
         messages: list[dict[Any, Any]],
+        stop_sequences: list[str] | None,
     ) -> ChatCompletionResponse:
         return cast(
             ChatCompletionResponse,
@@ -219,7 +230,7 @@ def _send_chat_completion_request(  # noqa: PLR0913
                 n_choices=1,
                 presence_penalty=None,
                 frequency_penalty=None,
-                stop_seqs=None,
+                stop_seqs=stop_sequences,
                 temperature=temperature,
                 top_p=top_p,
                 stream=False,
@@ -240,6 +251,7 @@ def _send_chat_completion_stream_request(  # noqa: PLR0913
         top_k: int | None,
         max_tokens: int | None,
         messages: list[dict[Any, Any]],
+        stop_sequences: list[str] | None,
     ) -> Generator[ChatCompletionChunkResponse]:
         yield from cast(
             Iterator[ChatCompletionChunkResponse],
@@ -254,7 +266,7 @@ def _send_chat_completion_stream_request(  # noqa: PLR0913
                 n_choices=1,
                 presence_penalty=None,
                 frequency_penalty=None,
-                stop_seqs=None,
+                stop_seqs=stop_sequences,
                 temperature=temperature,
                 top_p=top_p,
                 stream=True,
diff --git a/src/draive/mrs/config.py b/src/draive/mrs/config.py
index 11cbe59..0ab57ba 100644
--- a/src/draive/mrs/config.py
+++ b/src/draive/mrs/config.py
@@ -12,3 +12,4 @@ class MRSChatConfig(DataModel):
     top_p: float | Missing = MISSING
     top_k: int | Missing = MISSING
     max_tokens: int = 2048
+    stop_sequences: list[str] | Missing = MISSING
diff --git a/src/draive/ollama/client.py b/src/draive/ollama/client.py
index 74cd7b3..cb7c600 100644
--- a/src/draive/ollama/client.py
+++ b/src/draive/ollama/client.py
@@ -52,6 +52,7 @@ async def chat_completion(
             if not_missing(config.response_format)
             else "text",
             seed=config.seed if not_missing(config.seed) else None,
+            stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
         )
 
     async def _create_chat_completion(  # noqa: PLR0913
@@ -64,6 +65,7 @@ async def _create_chat_completion(  # noqa: PLR0913
         max_tokens: int | None,
         response_format: Literal["text", "json"],
         messages: list[ChatMessage],
+        stop: list[str] | None,
     ) -> ChatCompletionResponse:
         request_body: dict[str, Any] = {
             "model": model,
@@ -82,6 +84,8 @@ async def _create_chat_completion(  # noqa: PLR0913
             request_body["options"]["top_p"] = top_p
         if seed is not None:
             request_body["options"]["seed"] = seed
+        if stop:
+            request_body["options"]["stop"] = stop
         if response_format == "json":
             request_body["format"] = "json"
 
diff --git a/src/draive/ollama/config.py b/src/draive/ollama/config.py
index 20a53ef..5a6946e 100644
--- a/src/draive/ollama/config.py
+++ b/src/draive/ollama/config.py
@@ -17,3 +17,4 @@ class OllamaChatConfig(DataModel):
     max_tokens: int = 2048
     response_format: Literal["text", "json"] | Missing = MISSING
     timeout: float | Missing = MISSING
+    stop_sequences: list[str] | Missing = MISSING
diff --git a/src/draive/openai/client.py b/src/draive/openai/client.py
index 2befcca..ebca0ab 100644
--- a/src/draive/openai/client.py
+++ b/src/draive/openai/client.py
@@ -126,6 +126,7 @@ async def chat_completion(
                 top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
                 timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
                 stream_options={"include_usage": True} if stream else NOT_GIVEN,
+                stop=config.stop_sequences if not_missing(config.stop_sequences) else NOT_GIVEN,
             )
 
         except OpenAIRateLimitError as exc:  # retry on rate limit after delay
diff --git a/src/draive/openai/config.py b/src/draive/openai/config.py
index 4dd2fb7..1ebd226 100644
--- a/src/draive/openai/config.py
+++ b/src/draive/openai/config.py
@@ -26,6 +26,7 @@ class OpenAIChatConfig(DataModel):
     response_format: ResponseFormat | Missing = MISSING
     vision_details: Literal["auto", "low", "high"] | Missing = MISSING
     timeout: float | Missing = MISSING
+    stop_sequences: list[str] | Missing = MISSING
 
 
 class OpenAIEmbeddingConfig(DataModel):
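Usage note: a minimal sketch of how the new field is meant to be set, not part of the patch. The import path and the stop strings are illustrative assumptions; only fields visible in the configs above are used. When stop_sequences is left unset it keeps its MISSING default, so every client still sends NOT_GIVEN or None exactly as before this change.

# Hypothetical usage sketch; values and import path are examples, not defaults.
from draive.openai import OpenAIChatConfig  # path assumed from src/draive/openai/config.py

config = OpenAIChatConfig(
    max_tokens=512,
    # Forwarded to the OpenAI "stop" parameter by chat_completion; if omitted,
    # the field stays MISSING and the client passes NOT_GIVEN as before.
    stop_sequences=["\n\nObservation:", "</answer>"],
)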