Add prefill to completions
KaQuMiQ committed Aug 22, 2024
1 parent 5717680 commit b55ff10
Showing 17 changed files with 149 additions and 33 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -5,7 +5,7 @@ build-backend = "setuptools.build_meta"
[project]
name = "draive"
description = "Framework designed to simplify and accelerate the development of LLM-based applications."
version = "0.26.0"
version = "0.27.0"
readme = "README.md"
maintainers = [
{ name = "Kacper Kaliński", email = "[email protected]" },
22 changes: 20 additions & 2 deletions src/draive/anthropic/lmm.py
@@ -332,7 +332,21 @@ async def _completion( # noqa: PLR0913, PLR0912, C901
),
)

message_parts: list[TextBlock] = []
message_parts: list[TextBlock]
match messages[-1]:
case {"role": "assistant", "content": str() as content_text}:
message_parts = [TextBlock(type="text", text=content_text)]

case {"role": "assistant", "content": content_parts}:
message_parts = [ # currently supporting only text prefills
TextBlock(type="text", text=part.text)
for part in content_parts
if isinstance(part, TextBlock)
]

case _:
message_parts = []

tool_calls: list[ToolUseBlock] = []
for part in completion.content:
match part:
Expand Down Expand Up @@ -376,8 +390,12 @@ async def _completion( # noqa: PLR0913, PLR0912, C901

else:
ctx.record(ResultTrace.of(message_parts))

return LMMCompletion.of(
MultimodalContent.of(*[TextContent(text=part.text) for part in message_parts])
MultimodalContent.of(
*[TextContent(text=part.text) for part in message_parts],
merge_text=True,
)
)

case other:
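The Anthropic change seeds `message_parts` from a trailing assistant message because Claude treats a final assistant turn as a prefill: the model continues from that text but does not repeat it in its response, so the adapter has to stitch it back onto the result. A minimal standalone sketch of the same pattern, assuming the official `anthropic` Python SDK, an `ANTHROPIC_API_KEY` in the environment, and an illustrative model id:

```python
import anthropic

client = anthropic.Anthropic()  # reads ANTHROPIC_API_KEY from the environment

prefill = "["  # the trailing assistant turn acts as a prefill
response = client.messages.create(
    model="claude-3-5-sonnet-20240620",  # illustrative model id
    max_tokens=128,
    messages=[
        {"role": "user", "content": "List three prime numbers as a JSON array."},
        {"role": "assistant", "content": prefill},
    ],
)

# The completion continues from the prefill without echoing it, so prepend it,
# just as the commit prepends the prefill TextBlocks to message_parts.
completion_text = prefill + "".join(
    block.text for block in response.content if block.type == "text"
)
print(completion_text)
```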
4 changes: 3 additions & 1 deletion src/draive/choice/call.py
@@ -13,11 +13,12 @@
]


async def choice_completion(
async def choice_completion( # noqa: PLR0913
*,
instruction: Instruction | str,
options: Iterable[ChoiceOption | Multimodal],
input: Multimodal, # noqa: A002
prefill: str | None = None,
tools: Toolbox | Sequence[AnyTool] | None = None,
examples: Iterable[tuple[Multimodal, ChoiceOption]] | None = None,
**extra: Any,
@@ -40,6 +41,7 @@ async def choice_completion(
for option in options
],
input=input,
prefill=prefill,
toolbox=toolbox,
examples=examples,
**extra,
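For callers, the new `prefill` argument simply threads through `choice_completion` into the LMM-backed implementation. A hypothetical call based only on the signature above (import path and plain-string options are assumptions; the draive scope and model-provider setup that must already be in place are omitted):

```python
from draive import choice_completion  # import path assumed


async def pick_resolution() -> None:
    # Assumes an LMM provider is already configured in the current draive scope.
    selected = await choice_completion(
        instruction="Select the option that best matches the customer's complaint.",
        options=["refund", "replacement", "store_credit"],
        input="My package arrived damaged and I want my money back.",
        prefill="<CHOICE>",  # steers the model straight into the expected answer tag
    )
    print(selected)
```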
3 changes: 2 additions & 1 deletion src/draive/choice/completion.py
@@ -13,12 +13,13 @@

@runtime_checkable
class ChoiceCompletion(Protocol):
async def __call__(
async def __call__( # noqa: PLR0913
self,
*,
instruction: Instruction | str,
options: Sequence[ChoiceOption],
input: Multimodal, # noqa: A002
prefill: str | None,
toolbox: Toolbox,
examples: Iterable[tuple[Multimodal, ChoiceOption]] | None,
**extra: Any,
19 changes: 12 additions & 7 deletions src/draive/choice/lmm.py
@@ -23,19 +23,19 @@
]


async def lmm_choice_completion( # noqa: C901
async def lmm_choice_completion( # noqa: C901, PLR0912, PLR0913
*,
instruction: Instruction | str,
options: Sequence[ChoiceOption],
input: Multimodal, # noqa: A002
prefill: str | None,
toolbox: Toolbox,
examples: Iterable[tuple[Multimodal, ChoiceOption]] | None = None,
examples: Iterable[tuple[Multimodal, ChoiceOption]] | None,
**extra: Any,
) -> ChoiceOption:
with ctx.nested(
"lmm_choice_completion",
):
assert "select" in str(instruction).lower(), "Instruction have to contain a word 'select'" # nosec: B101
assert options, "Choice options cannot be empty" # nosec: B101
assert all( # nosec: B101
example[1] in options for example in examples or []
Expand Down Expand Up @@ -78,6 +78,9 @@ async def lmm_choice_completion( # noqa: C901
),
]

if prefill := prefill:
context.append(LMMCompletion.of(prefill))

recursion_level: int = 0
while recursion_level <= toolbox.recursion_limit:
match await lmm_invocation(
@@ -91,7 +94,7 @@ async def lmm_choice_completion( # noqa: C901
):
case LMMCompletion() as completion:
ctx.log_debug("Received choice results")
if selection := xml_tag("SELECTION", source=completion.content.as_string()):
if selection := xml_tag("CHOICE", source=completion.content.as_string()):
if option := options_map.get(selection):
return option

@@ -105,7 +108,7 @@ async def lmm_choice_completion( # noqa: C901
response.content for response in responses if response.direct
]:
if selection := xml_tag(
"SELECTION",
"CHOICE",
source=MultimodalContent.of(*direct_content).as_string(),
):
if option := options_map.get(selection):
Expand Down Expand Up @@ -161,6 +164,8 @@ def _format_example(


INSTRUCTION_EXTENSION: str = """\
Selection HAVE to contain an identifier of a chosen option inside a `SELECTION` \
xml tag within the result i.e. `<SELECTION>identifier</SELECTION>`.
<FORMAT>
Place identifier of the final choice inside a <CHOICE> XML tag within the result, \
like this: `<CHOICE>identifier</CHOICE>`.
</FORMAT>
"""
29 changes: 27 additions & 2 deletions src/draive/gemini/lmm.py
@@ -39,6 +39,7 @@
LMMToolRequest,
LMMToolRequests,
LMMToolResponse,
MultimodalContent,
MultimodalContentElement,
TextContent,
VideoBase64Content,
@@ -359,6 +360,24 @@ async def _generate( # noqa: PLR0913, C901, PLR0912, PLR0915

converted_tools.append(tool_function)

prefill: str = ""
match messages[-1]:
case {"role": "model", "parts": content_parts}:
if config.response_format == "application/json":
del messages[-1] # for json mode ignore prefill

else:
for part in content_parts:
match part: # currently supporting only text prefills
case {"text": str() as text}:
prefill += text

case _:
continue

case _:
pass

match tool_selection:
case "auto":
result = await client.generate(
Expand Down Expand Up @@ -445,7 +464,8 @@ async def _generate( # noqa: PLR0913, C901, PLR0912, PLR0915

message_parts: list[
GeminiTextMessageContent | GeminiDataReferenceMessageContent | GeminiDataMessageContent
] = []
] = [GeminiTextMessageContent(text=prefill)] if prefill else []

tool_calls: list[GeminiFunctionCallMessageContent] = []
for part in result_message.content:
match part:
Expand Down Expand Up @@ -483,7 +503,12 @@ async def _generate( # noqa: PLR0913, C901, PLR0912, PLR0915

elif message_parts:
ctx.record(ResultTrace.of(message_parts))
return LMMCompletion.of(*[_convert_content_part(part) for part in message_parts])
return LMMCompletion.of(
MultimodalContent.of(
*[_convert_content_part(part) for part in message_parts],
merge_text=True,
)
)

else:
raise GeminiException("Invalid Gemini completion", result)
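The Gemini adapter is not handed a dedicated prefill flag here; instead it recovers prefill text from a trailing `model`-role message, drops that message entirely in JSON mode, and later re-injects the text as the first part of the completion (merged with the generated text). A simplified, dependency-free sketch of that selection logic, using dicts shaped like the messages in the diff (names are illustrative):

```python
def extract_prefill(messages: list[dict], response_format: str | None) -> str:
    """Return prefill text from a trailing model-role message; drop it in JSON mode."""
    if not messages or messages[-1].get("role") != "model":
        return ""

    if response_format == "application/json":
        del messages[-1]  # JSON mode ignores the prefill entirely
        return ""

    return "".join(
        part["text"]
        for part in messages[-1].get("parts", [])
        if isinstance(part.get("text"), str)  # only text prefills are supported
    )


history = [
    {"role": "user", "parts": [{"text": "Summarize the quarterly report."}]},
    {"role": "model", "parts": [{"text": "Summary: "}]},
]
print(extract_prefill(history, response_format=None))  # -> "Summary: "
```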
9 changes: 5 additions & 4 deletions src/draive/generation/model/lmm.py
@@ -89,6 +89,7 @@ async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913, C901, PLR0
]
],
LMMInput.of(input),
LMMCompletion.of("{"), # prefill with json opening
]

recursion_level: int = 0
@@ -146,10 +147,10 @@ async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913, C901, PLR0

DEFAULT_INSTRUCTION_EXTENSION: str = """\
<FORMAT>
The result have to be a JSON object conforming to the following schema:
```
Provide the result using a single raw valid JSON object that adheres strictly to the given \
SCHEMA without any comments, formatting, or additional elements.
<SCHEMA>
{schema}
```
Provide ONLY a single, raw, valid JSON without any comments, formatting or additional elements.
</SCHEMA>
</FORMAT>
"""
2 changes: 2 additions & 0 deletions src/draive/generation/text/call.py
@@ -16,13 +16,15 @@ async def generate_text(
*,
instruction: Instruction | str,
input: MultimodalContent | MultimodalContentConvertible, # noqa: A002
prefill: str | None = None,
tools: Toolbox | Sequence[AnyTool] | None = None,
examples: Iterable[tuple[MultimodalContent | MultimodalContentConvertible, str]] | None = None,
**extra: Any,
) -> str:
return await ctx.state(TextGeneration).generate(
instruction=instruction,
input=input,
prefill=prefill,
tools=tools,
examples=examples,
**extra,
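On the text-generation side the new argument flows unchanged from `generate_text` down to the LMM context. A hypothetical call, again assuming the import path and that a provider is already configured in the surrounding draive scope:

```python
from draive import generate_text  # import path assumed


async def continue_haiku() -> str:
    # Assumes an LMM provider is already configured in the current draive scope.
    return await generate_text(
        instruction="Write a haiku about autumn.",
        input="Topic: falling leaves",
        prefill="Crimson leaves drift down",  # the model continues from this first line
    )
```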
1 change: 1 addition & 0 deletions src/draive/generation/text/generator.py
@@ -17,6 +17,7 @@ async def __call__(
*,
instruction: Instruction | str,
input: MultimodalContent | MultimodalContentConvertible, # noqa: A002
prefill: str | None = None,
tools: Toolbox | Sequence[AnyTool] | None = None,
examples: Iterable[tuple[MultimodalContent | MultimodalContentConvertible, str]]
| None = None,
6 changes: 6 additions & 0 deletions src/draive/generation/text/lmm.py
@@ -23,6 +23,7 @@ async def lmm_generate_text(
*,
instruction: Instruction | str,
input: MultimodalContent | MultimodalContentConvertible, # noqa: A002
prefill: str | None = None,
tools: Toolbox | Sequence[AnyTool] | None = None,
examples: Iterable[tuple[MultimodalContent | MultimodalContentConvertible, str]] | None = None,
**extra: Any,
@@ -51,6 +52,11 @@ async def lmm_generate_text(
LMMInput.of(input),
]

if prefill := prefill:
context.append(
LMMCompletion.of(prefill),
)

recursion_level: int = 0
while recursion_level <= toolbox.recursion_limit:
match await lmm_invocation(
7 changes: 7 additions & 0 deletions src/draive/mistral/client.py
@@ -82,6 +82,13 @@ async def chat_completion(
raise NotImplementedError("Mistral streaming is not supported yet")

else:
if messages[-1]["role"] == "assistant":
if config.response_format == {"type": "json_object"}:
del messages[-1] # for json mode ignore prefill

else:
messages[-1]["prefix"] = True # add prefill parameter indicator

return await self._create_chat_completion(
messages=messages,
model=config.model,
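Mistral's prefill goes through a `prefix` flag on the trailing assistant message, which is exactly what the client change above toggles (and what it drops when JSON mode is active). A sketch of the resulting request payload, built from plain dicts shaped like the `ChatMessage` TypedDict added later in this commit (prompt text is illustrative):

```python
# Plain-dict messages mirroring the ChatMessage TypedDict from this commit.
messages = [
    {"role": "user", "content": "Give me a one-line tagline for a coffee shop."},
    # Trailing assistant message: "prefix": True marks it as a prefill the model
    # should continue from, not as a turn it already produced.
    {"role": "assistant", "content": "Tagline: ", "prefix": True},
]

# As in the client change, a JSON-mode request would drop the prefill instead:
# if response_format == {"type": "json_object"}: del messages[-1]
```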
1 change: 1 addition & 0 deletions src/draive/mistral/lmm.py
@@ -200,6 +200,7 @@ async def _chat_completion(
),
tool_choice="auto",
)

case "none":
completion = await client.chat_completion(
config=config,
1 change: 1 addition & 0 deletions src/draive/mistral/models.py
@@ -55,6 +55,7 @@ class ChatMessage(TypedDict, total=False):
content: Required[str | list[str]]
name: NotRequired[str]
tool_calls: NotRequired[list[ChatToolCallRequest]]
prefix: NotRequired[bool]


class ChatFunctionCallResponse(DataModel):
18 changes: 11 additions & 7 deletions src/draive/ollama/lmm.py
@@ -7,7 +7,6 @@
from draive.metrics.tokens import TokenUsage
from draive.ollama.client import OllamaClient
from draive.ollama.config import OllamaChatConfig
from draive.ollama.errors import OllamaException
from draive.ollama.models import ChatCompletionResponse, ChatMessage
from draive.scope import ctx
from draive.types import (
@@ -162,6 +161,14 @@ async def _chat_completion(
config: OllamaChatConfig,
messages: list[ChatMessage],
) -> LMMOutput:
prefill: str = ""
if messages[-1].role == "assistant":
if config.response_format == "json":
del messages[-1] # for json mode ignore prefill

else:
prefill = messages[-1].content

completion: ChatCompletionResponse = await client.chat_completion(
config=config,
messages=messages,
@@ -175,12 +182,9 @@ async def _chat_completion(
),
)

if message := completion.message.content:
ctx.record(ResultTrace.of(message))
return LMMCompletion.of(message)

else:
raise OllamaException("Invalid Ollama completion", completion)
completion_message: str = prefill + completion.message.content
ctx.record(ResultTrace.of(completion_message))
return LMMCompletion.of(completion_message)


async def _chat_completion_stream(