Add support for stop sequences #153

Merged
1 commit merged on Aug 9, 2024

3 changes: 3 additions & 0 deletions src/draive/anthropic/client.py
@@ -98,6 +98,9 @@ async def completion( # noqa: PLR0913
max_tokens=config.max_tokens,
top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
stop_sequences=config.stop_sequences
if not_missing(config.stop_sequences)
else NOT_GIVEN,
stream=stream,
)

1 change: 1 addition & 0 deletions src/draive/anthropic/config.py
@@ -26,3 +26,4 @@ class AnthropicConfig(DataModel):
top_p: float | Missing = MISSING
max_tokens: int = 2048
timeout: float | Missing = MISSING
stop_sequences: list[str] | Missing = MISSING
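
The pattern above repeats for every provider in this PR: the config gains an optional stop_sequences field defaulting to MISSING, and the client forwards it only when not_missing(...) holds. As a minimal usage sketch (not part of the diff; the import path and keyword construction of the config are assumptions):

# A minimal sketch, assuming the config is exposed from draive.anthropic
# and accepts keyword construction like other DataModel configs.
from draive.anthropic import AnthropicConfig  # import path assumed

config = AnthropicConfig(
    max_tokens=1024,
    # Generation halts as soon as the model emits any of these strings.
    stop_sequences=["\nObservation:", "</answer>"],
)
# Leaving stop_sequences at its MISSING default makes the client pass NOT_GIVEN
# to the Anthropic SDK, so existing configs keep their previous behaviour.
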
3 changes: 3 additions & 0 deletions src/draive/gemini/client.py
@@ -75,6 +75,9 @@ async def generate( # noqa: PLR0913
"maxOutputTokens": config.max_tokens,
"responseSchema": response_schema if response_schema else None,
"candidateCount": 1,
"stopSequences": config.stop_sequences
if not_missing(config.stop_sequences)
else None,
},
"systemInstruction": {
"parts": ({"text": instruction},),
1 change: 1 addition & 0 deletions src/draive/gemini/config.py
@@ -17,6 +17,7 @@ class GeminiConfig(DataModel):
top_k: int | Missing = MISSING
max_tokens: int = 2048
timeout: float | Missing = MISSING
stop_sequences: list[str] | Missing = MISSING


class GeminiEmbeddingConfig(DataModel):
4 changes: 4 additions & 0 deletions src/draive/mistral/client.py
@@ -92,6 +92,7 @@ async def chat_completion(
seed=config.seed if not_missing(config.seed) else None,
tools=tools,
tool_choice=tool_choice if tools else None,
stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
)

async def embedding(
@@ -124,6 +125,7 @@ async def _create_chat_completion( # noqa: PLR0913
messages: list[ChatMessage],
tools: list[dict[str, object]] | None,
tool_choice: str | None,
stop: list[str] | None,
) -> ChatCompletionResponse:
request_body: dict[str, Any] = {
"model": model,
@@ -143,6 +145,8 @@ async def _create_chat_completion( # noqa: PLR0913
request_body["top_p"] = top_p
if seed is not None:
request_body["random_seed"] = seed
if stop:
request_body["stop"] = stop
if response_format is not None:
request_body["response_format"] = response_format

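
For Mistral, the configured sequences travel as a top-level "stop" key in the JSON request body, and only when the list is non-empty (the `if stop:` guard above). A rough sketch of the resulting payload (model name and message are illustrative, not from the PR):

# Illustrative request body once stop sequences are configured (values assumed):
request_body = {
    "model": "mistral-small-latest",  # assumed model name
    "messages": [{"role": "user", "content": "List three colours."}],
    "stop": ["\n\n"],  # key omitted entirely when stop_sequences is MISSING or empty
}
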
1 change: 1 addition & 0 deletions src/draive/mistral/config.py
@@ -21,6 +21,7 @@ class MistralChatConfig(DataModel):
max_tokens: int = 2048
response_format: ResponseFormat | Missing = MISSING
timeout: float | Missing = MISSING
stop_sequences: list[str] | Missing = MISSING


class MistralEmbeddingConfig(DataModel):
16 changes: 14 additions & 2 deletions src/draive/mrs/client.py
@@ -103,6 +103,9 @@ async def chat_completion(
top_p=config.top_p if not_missing(config.top_p) else None,
top_k=config.top_k if not_missing(config.top_k) else None,
max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
stop_sequences=config.stop_sequences
if not_missing(config.stop_sequences)
else None,
)

else:
@@ -113,6 +116,9 @@ async def chat_completion(
top_p=config.top_p if not_missing(config.top_p) else None,
top_k=config.top_k if not_missing(config.top_k) else None,
max_tokens=config.max_tokens if not_missing(config.max_tokens) else None,
stop_sequences=config.stop_sequences
if not_missing(config.stop_sequences)
else None,
)

async def _create_chat_completion( # noqa: PLR0913
@@ -123,6 +129,7 @@ async def _create_chat_completion( # noqa: PLR0913
top_k: int | None,
max_tokens: int | None,
messages: list[dict[str, object]],
stop_sequences: list[str] | None,
) -> ChatCompletionResponse:
return await self._send_chat_completion_request(
runner=await self._get_runner(model),
@@ -132,6 +139,7 @@ async def _create_chat_completion( # noqa: PLR0913
top_k=top_k,
max_tokens=max_tokens,
messages=messages,
stop_sequences=stop_sequences,
)

async def _create_chat_stream( # noqa: PLR0913
@@ -142,6 +150,7 @@ async def _create_chat_stream( # noqa: PLR0913
top_k: int | None,
max_tokens: int | None,
messages: list[dict[str, object]],
stop_sequences: list[str] | None,
) -> AsyncIterable[ChatCompletionChunkResponse]:
return ctx.stream_sync(
self._send_chat_completion_stream_request(
@@ -152,6 +161,7 @@ async def _create_chat_stream( # noqa: PLR0913
top_k=top_k,
max_tokens=max_tokens,
messages=messages,
stop_sequences=stop_sequences,
),
executor=MRS_EXECUTOR,
)
@@ -205,6 +215,7 @@ def _send_chat_completion_request( # noqa: PLR0913
top_k: int | None,
max_tokens: int | None,
messages: list[dict[Any, Any]],
stop_sequences: list[str] | None,
) -> ChatCompletionResponse:
return cast(
ChatCompletionResponse,
@@ -219,7 +230,7 @@ def _send_chat_completion_request( # noqa: PLR0913
n_choices=1,
presence_penalty=None,
frequency_penalty=None,
- stop_seqs=None,
+ stop_seqs=stop_sequences,
temperature=temperature,
top_p=top_p,
stream=False,
@@ -240,6 +251,7 @@ def _send_chat_completion_stream_request( # noqa: PLR0913
top_k: int | None,
max_tokens: int | None,
messages: list[dict[Any, Any]],
stop_sequences: list[str] | None,
) -> Generator[ChatCompletionChunkResponse]:
yield from cast(
Iterator[ChatCompletionChunkResponse],
@@ -254,7 +266,7 @@ def _send_chat_completion_stream_request( # noqa: PLR0913
n_choices=1,
presence_penalty=None,
frequency_penalty=None,
- stop_seqs=None,
+ stop_seqs=stop_sequences,
temperature=temperature,
top_p=top_p,
stream=True,
1 change: 1 addition & 0 deletions src/draive/mrs/config.py
@@ -12,3 +12,4 @@ class MRSChatConfig(DataModel):
top_p: float | Missing = MISSING
top_k: int | Missing = MISSING
max_tokens: int = 2048
stop_sequences: list[str] | Missing = MISSING
4 changes: 4 additions & 0 deletions src/draive/ollama/client.py
@@ -52,6 +52,7 @@ async def chat_completion(
if not_missing(config.response_format)
else "text",
seed=config.seed if not_missing(config.seed) else None,
stop=config.stop_sequences if not_missing(config.stop_sequences) else None,
)

async def _create_chat_completion( # noqa: PLR0913
@@ -64,6 +65,7 @@ async def _create_chat_completion( # noqa: PLR0913
max_tokens: int | None,
response_format: Literal["text", "json"],
messages: list[ChatMessage],
stop: list[str] | None,
) -> ChatCompletionResponse:
request_body: dict[str, Any] = {
"model": model,
@@ -82,6 +84,8 @@ async def _create_chat_completion( # noqa: PLR0913
request_body["options"]["top_p"] = top_p
if seed is not None:
request_body["options"]["seed"] = seed
if stop:
request_body["options"]["stop"] = stop
if response_format == "json":
request_body["format"] = "json"

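
Ollama differs slightly from Mistral: the stop list is nested under the request's "options" object instead of being sent at the top level, again only when non-empty. Roughly (values are illustrative, not from the PR):

# Illustrative Ollama request body with stop sequences (values assumed):
request_body = {
    "model": "llama3",  # assumed model name
    "messages": [{"role": "user", "content": "Say hello."}],
    "options": {
        "temperature": 0.7,
        "stop": ["\n\n"],  # skipped when stop_sequences is MISSING or empty
    },
}
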
1 change: 1 addition & 0 deletions src/draive/ollama/config.py
@@ -17,3 +17,4 @@ class OllamaChatConfig(DataModel):
max_tokens: int = 2048
response_format: Literal["text", "json"] | Missing = MISSING
timeout: float | Missing = MISSING
stop_sequences: list[str] | Missing = MISSING
1 change: 1 addition & 0 deletions src/draive/openai/client.py
@@ -126,6 +126,7 @@ async def chat_completion(
top_p=config.top_p if not_missing(config.top_p) else NOT_GIVEN,
timeout=config.timeout if not_missing(config.timeout) else NOT_GIVEN,
stream_options={"include_usage": True} if stream else NOT_GIVEN,
stop=config.stop_sequences if not_missing(config.stop_sequences) else NOT_GIVEN,
)

except OpenAIRateLimitError as exc: # retry on rate limit after delay
1 change: 1 addition & 0 deletions src/draive/openai/config.py
@@ -26,6 +26,7 @@ class OpenAIChatConfig(DataModel):
response_format: ResponseFormat | Missing = MISSING
vision_details: Literal["auto", "low", "high"] | Missing = MISSING
timeout: float | Missing = MISSING
stop_sequences: list[str] | Missing = MISSING


class OpenAIEmbeddingConfig(DataModel):