Add support for stop sequences
KaQuMiQ authored Aug 9, 2024
1 parent 02ca8b3 commit b400631
Showing 12 changed files with 35 additions and 2 deletions.
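Each provider config gains an optional stop_sequences field defaulting to MISSING, and each client forwards it to the provider only when it was explicitly set, so existing callers are unaffected. A minimal usage sketch follows; the import path and keyword-argument construction are assumptions for illustration, only the stop_sequences field itself comes from this commit.

# Hypothetical usage of the new field (import path and constructor style assumed).
from draive.openai import OpenAIChatConfig

config = OpenAIChatConfig(
    max_tokens=512,
    stop_sequences=["\nObservation:", "</answer>"],  # generation halts at the first matching sequence
)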
3 changes: 3 additions & 0 deletions src/draive/anthropic/client.py
@@ -98,6 +98,9 @@ async def completion( # noqa: PLR0913
 max_tokens=config.max_tokens,
 top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
 timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
+stop_sequences=config.stop_sequences
+if not_missing(config.stop_sequences)
+else NOT_GIVEN,
 stream=stream,
 )

1 change: 1 addition & 0 deletions src/draive/anthropic/config.py
@@ -26,3 +26,4 @@ class AnthropicConfig(DataModel):
 top_p: float | Missing = MISSING
 max_tokens: int = 2048
 timeout: float | Missing = MISSING
+stop_sequences: list[str] | Missing = MISSING
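The same guard repeats in every client in this commit: a field left at MISSING is translated into the SDK's "omit this parameter" value (NOT_GIVEN for the Anthropic and OpenAI SDKs, None for the raw-HTTP clients) instead of being sent as an explicit null. A simplified sketch of that pattern, with MISSING, not_missing, and NOT_GIVEN reduced to stand-in definitions (the real ones live in draive and the provider SDKs, not here):

from typing import Any

MISSING: Any = object()    # stand-in for draive's "never set" sentinel
NOT_GIVEN: Any = object()  # stand-in for the SDK's "omit this argument" sentinel

def not_missing(value: Any) -> bool:
    # True only when the config field was explicitly provided.
    return value is not MISSING

def anthropic_kwargs(stop_sequences: Any = MISSING) -> dict[str, Any]:
    # Mirrors the `... if not_missing(...) else NOT_GIVEN` expressions in the hunk above.
    return {"stop_sequences": stop_sequences if not_missing(stop_sequences) else NOT_GIVEN}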
3 changes: 3 additions & 0 deletions src/draive/gemini/client.py
@@ -75,6 +75,9 @@ async def generate( # noqa: PLR0913
 "maxOutputTokens": config.max_tokens,
 "responseSchema": response_schema if response_schema else None,
 "candidateCount": 1,
+"stopSequences": config.stop_sequences
+if not_missing(config.stop_sequences)
+else None,
 },
 "systemInstruction": {
 "parts": ({"text": instruction},),
1 change: 1 addition & 0 deletions src/draive/gemini/config.py
@@ -17,6 +17,7 @@ class GeminiConfig(DataModel):
 top_k: int | Missing = MISSING
 max_tokens: int = 2048
 timeout: float | Missing = MISSING
+stop_sequences: list[str] | Missing = MISSING


 class GeminiEmbeddingConfig(DataModel):
4 changes: 4 additions & 0 deletions src/draive/mistral/client.py
@@ -92,6 +92,7 @@ async def chat_completion(
 seed=config.seed if not_missing(config.seed) else None,
 tools=tools,
 tool_choice=tool_choice if tools else None,
+stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
 )

 async def embedding(
@@ -124,6 +125,7 @@ async def _create_chat_completion( # noqa: PLR0913
 messages: list[ChatMessage],
 tools: list[dict[str, object]] | None,
 tool_choice: str | None,
+stop: list[str] | None,
 ) -> ChatCompletionResponse:
 request_body: dict[str, Any] = {
 "model": model,
@@ -143,6 +145,8 @@ async def _create_chat_completion( # noqa: PLR0913
 request_body["top_p"] = top_p
 if seed is not None:
 request_body["random_seed"] = seed
+if stop:
+request_body["stop"] = stop
 if response_format is not None:
 request_body["response_format"] = response_format

1 change: 1 addition & 0 deletions src/draive/mistral/config.py
@@ -21,6 +21,7 @@ class MistralChatConfig(DataModel):
 max_tokens: int = 2048
 response_format: ResponseFormat | Missing = MISSING
 timeout: float | Missing = MISSING
+stop_sequences: list[str] | Missing = MISSING


 class MistralEmbeddingConfig(DataModel):
16 changes: 14 additions & 2 deletions src/draive/mrs/client.py
@@ -103,6 +103,9 @@ async def chat_completion(
 top_p=config.top_p if not_missing(config.top_p) else None,
 top_k=config.top_k if not_missing(config.top_k) else None,
 max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
+stop_sequences=config.stop_sequences
+if not_missing(config.stop_sequences)
+else None,
 )

 else:
@@ -113,6 +116,9 @@ async def chat_completion(
 top_p=config.top_p if not_missing(config.top_p) else None,
 top_k=config.top_k if not_missing(config.top_k) else None,
 max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
+stop_sequences=config.stop_sequences
+if not_missing(config.stop_sequences)
+else None,
 )

 async def _create_chat_completion( # noqa: PLR0913
@@ -123,6 +129,7 @@ async def _create_chat_completion( # noqa: PLR0913
 top_k: int | None,
 max_tokens: int | None,
 messages: list[dict[str, object]],
+stop_sequences: list[str] | None,
 ) -> ChatCompletionResponse:
 return await self._send_chat_completion_request(
 runner=await self._get_runner(model),
@@ -132,6 +139,7 @@ async def _create_chat_completion( # noqa: PLR0913
 top_k=top_k,
 max_tokens=max_tokens,
 messages=messages,
+stop_sequences=stop_sequences,
 )

 async def _create_chat_stream( # noqa: PLR0913
@@ -142,6 +150,7 @@ async def _create_chat_stream( # noqa: PLR0913
 top_k: int | None,
 max_tokens: int | None,
 messages: list[dict[str, object]],
+stop_sequences: list[str] | None,
 ) -> AsyncIterable[ChatCompletionChunkResponse]:
 return ctx.stream_sync(
 self._send_chat_completion_stream_request(
@@ -152,6 +161,7 @@ async def _create_chat_stream( # noqa: PLR0913
 top_k=top_k,
 max_tokens=max_tokens,
 messages=messages,
+stop_sequences=stop_sequences,
 ),
 executor=MRS_EXECUTOR,
 )
@@ -205,6 +215,7 @@ def _send_chat_completion_request( # noqa: PLR0913
 top_k: int | None,
 max_tokens: int | None,
 messages: list[dict[Any, Any]],
+stop_sequences: list[str] | None,
 ) -> ChatCompletionResponse:
 return cast(
 ChatCompletionResponse,
@@ -219,7 +230,7 @@ def _send_chat_completion_request( # noqa: PLR0913
 n_choices=1,
 presence_penalty=None,
 frequency_penalty=None,
-stop_seqs=None,
+stop_seqs=stop_sequences,
 temperature=temperature,
 top_p=top_p,
 stream=False,
@@ -240,6 +251,7 @@ def _send_chat_completion_stream_request( # noqa: PLR0913
 top_k: int | None,
 max_tokens: int | None,
 messages: list[dict[Any, Any]],
+stop_sequences: list[str] | None,
 ) -> Generator[ChatCompletionChunkResponse]:
 yield from cast(
 Iterator[ChatCompletionChunkResponse],
@@ -254,7 +266,7 @@ def _send_chat_completion_stream_request( # noqa: PLR0913
 n_choices=1,
 presence_penalty=None,
 frequency_penalty=None,
-stop_seqs=None,
+stop_seqs=stop_sequences,
 temperature=temperature,
 top_p=top_p,
 stream=True,
1 change: 1 addition & 0 deletions src/draive/mrs/config.py
@@ -12,3 +12,4 @@ class MRSChatConfig(DataModel):
 top_p: float | Missing = MISSING
 top_k: int | Missing = MISSING
 max_tokens: int = 2048
+stop_sequences: list[str] | Missing = MISSING
4 changes: 4 additions & 0 deletions src/draive/ollama/client.py
@@ -52,6 +52,7 @@ async def chat_completion(
 if not_missing(config.response_format)
 else "text",
 seed=config.seed if not_missing(config.seed) else None,
+stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
 )

 async def _create_chat_completion( # noqa: PLR0913
@@ -64,6 +65,7 @@ async def _create_chat_completion( # noqa: PLR0913
 max_tokens: int | None,
 response_format: Literal["text", "json"],
 messages: list[ChatMessage],
+stop: list[str] | None,
 ) -> ChatCompletionResponse:
 request_body: dict[str, Any] = {
 "model": model,
@@ -82,6 +84,8 @@ async def _create_chat_completion( # noqa: PLR0913
 request_body["options"]["top_p"] = top_p
 if seed is not None:
 request_body["options"]["seed"] = seed
+if stop:
+request_body["options"]["stop"] = stop
 if response_format == "json":
 request_body["format"] = "json"

1 change: 1 addition & 0 deletions src/draive/ollama/config.py
@@ -17,3 +17,4 @@ class OllamaChatConfig(DataModel):
 max_tokens: int = 2048
 response_format: Literal["text", "json"] | Missing = MISSING
 timeout: float | Missing = MISSING
+stop_sequences: list[str] | Missing = MISSING
1 change: 1 addition & 0 deletions src/draive/openai/client.py
@@ -126,6 +126,7 @@ async def chat_completion(
 top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
 timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
 stream_options={"include_usage": True} if stream else NOT_GIVEN,
+stop=config.stop_sequences if not_missing(config.stop_sequences) else NOT_GIVEN,
 )

 except OpenAIRateLimitError as exc: # retry on rate limit after delay
1 change: 1 addition & 0 deletions src/draive/openai/config.py
@@ -26,6 +26,7 @@ class OpenAIChatConfig(DataModel):
 response_format: ResponseFormat | Missing = MISSING
 vision_details: Literal["auto", "low", "high"] | Missing = MISSING
 timeout: float | Missing = MISSING
+stop_sequences: list[str] | Missing = MISSING


 class OpenAIEmbeddingConfig(DataModel):
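Taken together, the hunks above route the one shared stop_sequences config field onto differently named provider parameters; collected here for reference, with names taken directly from the diffs:

# Provider-side parameter that receives the shared stop_sequences config field in this commit.
STOP_PARAMETER_BY_PROVIDER: dict[str, str] = {
    "anthropic": "stop_sequences",  # Anthropic SDK keyword argument
    "gemini": "stopSequences",      # key inside the generationConfig section of the request
    "mistral": "stop",              # request body key
    "mrs": "stop_seqs",             # keyword on the MRS chat completion call
    "ollama": "stop",               # nested under request_body["options"]
    "openai": "stop",               # OpenAI SDK keyword argument
}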
