Optimize parameters processing #73

Merged
merged 1 commit on Jun 3, 2024
16 changes: 7 additions & 9 deletions constraints
@@ -15,7 +15,7 @@ certifi==2024.2.2
    # requests
charset-normalizer==3.3.2
    # via requests
-coverage==7.5.2
+coverage==7.5.3
    # via pytest-cov
distro==1.9.0
    # via openai
@@ -38,27 +38,27 @@ markdown-it-py==3.0.0
    # via rich
mdurl==0.1.2
    # via markdown-it-py
-nodeenv==1.8.0
+nodeenv==1.9.0
    # via pyright
numpy==1.26.4
    # via draive (pyproject.toml)
-openai==1.30.3
+openai==1.30.4
    # via draive (pyproject.toml)
packaging==24.0
    # via pytest
pbr==6.0.0
    # via stevedore
pluggy==1.5.0
    # via pytest
-pydantic==2.7.1
+pydantic==2.7.2
    # via
    #   draive (pyproject.toml)
    #   openai
-pydantic-core==2.18.2
+pydantic-core==2.18.3
    # via pydantic
pygments==2.18.0
    # via rich
-pyright==1.1.364
+pyright==1.1.365
    # via draive (pyproject.toml)
pytest==7.4.4
    # via
@@ -77,10 +77,8 @@ requests==2.32.2
    # via tiktoken
rich==13.7.1
    # via bandit
-ruff==0.4.5
+ruff==0.4.6
    # via draive (pyproject.toml)
-setuptools==70.0.0
-    # via nodeenv
sniffio==1.3.1
    # via
    #   anyio
22 changes: 14 additions & 8 deletions examples/BasicConversation.ipynb
@@ -43,14 +43,20 @@
   "metadata": {},
   "outputs": [
    {
-     "ename": "TypeError",
-     "evalue": "OpenAIChatConfig.__init__() missing 8 required keyword-only arguments: 'temperature', 'top_p', 'frequency_penalty', 'max_tokens', 'seed', 'response_format', 'vision_details', and 'timeout'",
-     "output_type": "error",
-     "traceback": [
-      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
-      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
-      "Cell \u001b[0;32mIn[3], line 17\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdraive\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 2\u001b[0m LMM,\n\u001b[1;32m 3\u001b[0m ConversationMessage,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m openai_lmm_invocation,\n\u001b[1;32m 10\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# initialize dependencies and configuration\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m ctx\u001b[38;5;241m.\u001b[39mnew(\n\u001b[1;32m 14\u001b[0m dependencies\u001b[38;5;241m=\u001b[39m[OpenAIClient], \u001b[38;5;66;03m# use OpenAI client\u001b[39;00m\n\u001b[1;32m 15\u001b[0m state\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 16\u001b[0m LMM(invocation\u001b[38;5;241m=\u001b[39mopenai_lmm_invocation), \u001b[38;5;66;03m# define used LMM\u001b[39;00m\n\u001b[0;32m---> 17\u001b[0m \u001b[43mOpenAIChatConfig\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgpt-3.5-turbo-0125\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m, \u001b[38;5;66;03m# configure OpenAI model\u001b[39;00m\n\u001b[1;32m 18\u001b[0m ],\n\u001b[1;32m 19\u001b[0m ):\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# request conversation completion\u001b[39;00m\n\u001b[1;32m 21\u001b[0m response: ConversationMessage \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m conversation_completion( \u001b[38;5;66;03m# noqa: PLE1142\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# provide a prompt instruction\u001b[39;00m\n\u001b[1;32m 23\u001b[0m instruction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are a helpful assistant.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 29\u001b[0m ),\n\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28mprint\u001b[39m(response)\n",
-      "\u001b[0;31mTypeError\u001b[0m: OpenAIChatConfig.__init__() missing 8 required keyword-only arguments: 'temperature', 'top_p', 'frequency_penalty', 'max_tokens', 'seed', 'response_format', 'vision_details', and 'timeout'"
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{\n",
+      " \"identifier\": \"d5b98674fd024c0b977a49bfd9d3caa7\",\n",
+      " \"role\": \"model\",\n",
+      " \"author\": null,\n",
+      " \"created\": \"2024-05-29T15:26:35.346493+00:00\",\n",
+      " \"content\": {\n",
+      " \"elements\": [\n",
+      " \"The current UTC time and date is Wednesday, 29 May 2024, 15:26:34.\"\n",
+      " ]\n",
+      " }\n",
+      "}\n"
     ]
    }
   ],
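The output change above reflects that OpenAIChatConfig no longer requires every field at construction: the removed traceback shows the old failure, where eight keyword-only arguments were mandatory. For reference, a minimal sketch of the notebook cell as reconstructed from that traceback; the conversation `input` argument is an assumption, since the original cell elides those lines:

```python
from draive import (
    LMM,
    ConversationMessage,
    OpenAIChatConfig,
    OpenAIClient,
    conversation_completion,
    ctx,
    openai_lmm_invocation,
)

# initialize dependencies and configuration
async with ctx.new(
    dependencies=[OpenAIClient],  # use OpenAI client
    state=[
        LMM(invocation=openai_lmm_invocation),  # define used LMM
        OpenAIChatConfig(model="gpt-3.5-turbo-0125"),  # remaining fields now have defaults
    ],
):
    # request conversation completion
    response: ConversationMessage = await conversation_completion(  # noqa: PLE1142
        instruction="You are a helpful assistant.",
        # assumed argument; the original cell elides its remaining lines
        input="What is the current UTC time and date?",
    )
    print(response)
```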
21 changes: 18 additions & 3 deletions src/draive/__init__.py
@@ -70,7 +70,18 @@
    openai_lmm_invocation,
    openai_tokenize_text,
)
-from draive.parameters import Argument, BasicValue, DataModel, Field, ParameterPath, State
+from draive.parameters import (
+    Argument,
+    BasicValue,
+    DataModel,
+    Field,
+    ParameterDefaultFactory,
+    ParameterPath,
+    ParameterValidationContext,
+    ParameterValidator,
+    ParameterVerifier,
+    State,
+)
from draive.scope import (
    ScopeDependencies,
    ScopeDependency,
@@ -119,8 +130,8 @@
    getenv_float,
    getenv_int,
    getenv_str,
+    is_missing,
    load_env,
-    missing,
    not_missing,
    setup_logging,
    split_sequence,
@@ -197,7 +208,7 @@
    "MetricsTrace",
    "MetricsTraceReport",
    "MetricsTraceReporter",
-    "missing",
+    "is_missing",
    "Missing",
    "MISSING",
    "mistral_embed_text",
@@ -252,4 +263,8 @@
    "VideoContent",
    "VideoDataContent",
    "VideoURLContent",
+    "ParameterDefaultFactory",
+    "ParameterValidationContext",
+    "ParameterValidator",
+    "ParameterVerifier",
]
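Alongside the new parameter-related re-exports, the `missing` predicate is renamed `is_missing` in the public API. A minimal usage sketch of the renamed helper, assuming its semantics are unchanged:

```python
from draive import MISSING, Missing, is_missing

def describe(value: int | Missing) -> str:
    # is_missing replaces the former `missing` predicate for MISSING checks
    if is_missing(value):
        return "no value provided"

    return f"value: {value}"

print(describe(MISSING))  # no value provided
print(describe(42))  # value: 42
```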
5 changes: 5 additions & 0 deletions src/draive/conversation/lmm.py
@@ -164,6 +164,7 @@ async def _lmm_conversation_completion(
            **extra,
        ):
            case LMMCompletion() as completion:
+                ctx.log_debug("Received conversation result")
                response_message: ConversationMessage = ConversationMessage(
                    role="model",
                    created=datetime.now(UTC),
@@ -176,6 +177,7 @@ async def _lmm_conversation_completion(
                return response_message

            case LMMToolRequests() as tool_requests:
+                ctx.log_debug("Received conversation tool calls")
                context.append(tool_requests)
                responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

@@ -221,6 +223,7 @@ async def _lmm_conversation_completion_stream(
        ):
            match part:
                case LMMCompletionChunk() as chunk:
+                    ctx.log_debug("Received conversation result chunk")
                    response_content = response_content.extending(chunk.content)

                    yield ConversationMessageChunk(
@@ -230,6 +233,7 @@ async def _lmm_conversation_completion_stream(
                    # keep yielding parts

                case LMMToolRequests() as tool_requests:
+                    ctx.log_debug("Received conversation tool calls")
                    assert (  # nosec: B101
                        not response_content
                    ), "Tools and completion message should not be used at the same time"
@@ -264,6 +268,7 @@ async def _lmm_conversation_completion_stream(
                    break  # exit the loop with result

    if response_content:
+        ctx.log_debug("Remembering conversation result")
        # remember messages when finishing stream
        await conversation_memory.remember(
            request_message,
8 changes: 5 additions & 3 deletions src/draive/generation/model/lmm.py
@@ -21,7 +21,7 @@
]


-async def lmm_generate_model[Generated: DataModel](  # noqa: PLR0913
+async def lmm_generate_model[Generated: DataModel](  # noqa: PLR0913, C901
    generated: type[Generated],
    /,
    *,
@@ -57,7 +57,7 @@ async def lmm_generate_model[Generated: DataModel](  # noqa: PLR0913
    if variable := schema_variable:
        instruction_message = LMMInstruction.of(
            generation_instruction.updated(
-                **{variable: generated.json_schema()},
+                **{variable: generated.json_schema(indent=2)},
            ),
        )

@@ -66,7 +66,7 @@ async def lmm_generate_model[Generated: DataModel](  # noqa: PLR0913
            generation_instruction.extended(
                DEFAULT_INSTRUCTION_EXTENSION,
                joiner="\n\n",
-                schema=generated.json_schema(),
+                schema=generated.json_schema(indent=2),
            )
        )

@@ -93,9 +93,11 @@ async def lmm_generate_model[Generated: DataModel](  # noqa: PLR0913
        **extra,
    ):
        case LMMCompletion() as completion:
+            ctx.log_debug("Received model generation result")
            return generated.from_json(completion.content.as_string())

        case LMMToolRequests() as tool_requests:
+            ctx.log_debug("Received model generation tool calls")
            context.append(tool_requests)
            responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

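Both schema-injection paths above now call `json_schema(indent=2)`, so the schema embedded in the generation instruction is pretty-printed rather than a single dense line. A quick illustration of the difference; the `Recipe` model here is hypothetical:

```python
from draive import DataModel

class Recipe(DataModel):
    name: str
    servings: int

# lmm_generate_model now embeds an indented, human-readable schema
# like the one printed here into its instruction text
print(Recipe.json_schema(indent=2))
```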
2 changes: 2 additions & 0 deletions src/draive/generation/text/lmm.py
@@ -63,9 +63,11 @@ async def lmm_generate_text(
        **extra,
    ):
        case LMMCompletion() as completion:
+            ctx.log_debug("Received text generation result")
            return completion.content.as_string()

        case LMMToolRequests() as tool_requests:
+            ctx.log_debug("Received text generation tool calls")
            context.append(tool_requests)
            responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

56 changes: 39 additions & 17 deletions src/draive/lmm/tool.py
@@ -7,9 +7,15 @@
)
from uuid import uuid4

+from draive.lmm.errors import ToolException
from draive.lmm.state import ToolCallContext, ToolStatusStream
from draive.metrics import ArgumentsTrace, ResultTrace
-from draive.parameters import Function, ParametrizedFunction, ToolSpecification
+from draive.parameters import (
+    Function,
+    ParameterSpecification,
+    ParametrizedFunction,
+    ToolSpecification,
+)
from draive.scope import ctx
from draive.types import MultimodalContent, MultimodalContentElement
from draive.utils import freeze, not_missing
@@ -23,7 +29,7 @@


class ToolAvailabilityCheck(Protocol):
-    def __call__(self) -> None: ...
+    def __call__(self) -> bool: ...


@final
@@ -41,23 +47,38 @@ def __init__(  # noqa: PLR0913
        direct_result: bool = False,
    ) -> None:
        super().__init__(function=function)
-        if not_missing(self._parameters.specification):
-            self.specification: ToolSpecification = {
-                "type": "function",
-                "function": {
-                    "name": name,
-                    "parameters": self._parameters.specification,
-                    "description": description or "",
-                },
-            }
+        aliased_required: list[str] = []
+        parameters: dict[str, ParameterSpecification] = {}
+        for parameter in self._parameters.values():
+            if not_missing(parameter.specification):
+                parameters[parameter.alias or parameter.name] = parameter.specification

-        else:
-            raise TypeError(f"{function.__qualname__} can't be represented as a tool")
+            else:
+                raise TypeError(
+                    f"{function.__qualname__} can't be represented as a tool"
+                    f" - argument '{parameter.name}' is missing specification."
+                )
+
+            if not (parameter.has_default and parameter.allows_missing):
+                aliased_required.append(parameter.alias or parameter.name)
+
+        self.specification: ToolSpecification = {
+            "type": "function",
+            "function": {
+                "name": name,
+                "parameters": {
+                    "type": "object",
+                    "properties": parameters,
+                    "required": aliased_required,
+                },
+                "description": description or "",
+            },
+        }

        self.name: str = name
        self._direct_result: bool = direct_result
        self._check_availability: ToolAvailabilityCheck = availability_check or (
-            lambda: None  # available by default
+            lambda: True  # available by default
        )
        self.format_result: Callable[[Result], MultimodalContent | MultimodalContentElement] = (
            format_result
@@ -71,8 +92,7 @@ def __init__(  # noqa: PLR0913
    @property
    def available(self) -> bool:
        try:
-            self._check_availability()
-            return True
+            return self._check_availability()

        except Exception:
            return False
@@ -101,7 +121,9 @@ async def _toolbox_call(
        call_context.report("STARTED")

        try:
-            self._check_availability()
+            if not self.available:
+                raise ToolException(f"{self.name} is not available!")
+
            result: Result = await super().__call__(**arguments)  # pyright: ignore[reportCallIssue]
            ctx.record(ResultTrace.of(result))

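With this change, Tool.__init__ assembles the function specification itself: each parameter's specification is collected under its alias (falling back to its name), and a parameter is listed as required unless it both has a default and allows a missing value. For a hypothetical tool `check_weather(location: str, unit: str | Missing = MISSING)`, the resulting specification would look roughly like this sketch (not captured library output):

```python
# Approximate shape of the specification built by the new Tool.__init__
specification = {
    "type": "function",
    "function": {
        "name": "check_weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string"},
                "unit": {"type": "string"},
            },
            # "unit" has a default and allows MISSING, so only "location" is required
            "required": ["location"],
        },
        "description": "Check current weather for a location",
    },
}
```

Note also that ToolAvailabilityCheck now returns a bool instead of signalling availability by raising, and calling an unavailable tool surfaces a ToolException.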
4 changes: 2 additions & 2 deletions src/draive/metrics/log_reporter.py
@@ -3,7 +3,7 @@

from draive.metrics.reporter import MetricsTraceReport, MetricsTraceReporter
from draive.parameters import ParametrizedData
-from draive.utils import missing
+from draive.utils import is_missing

__all__ = [
    "metrics_log_reporter",
@@ -150,7 +150,7 @@ def _raw_value_report(
    list_items_limit: int | None,
    item_character_limit: int | None,
) -> str | None:
-    if missing(value):
+    if is_missing(value):
        return None  # skip missing

    # workaround for pydantic models
1 change: 1 addition & 0 deletions src/draive/mistral/lmm.py
@@ -86,6 +86,7 @@ async def mistral_lmm_invocation(
            ),
        ],
    ):
+        ctx.log_debug("Requested Mistral lmm")
        client: MistralClient = ctx.dependency(MistralClient)
        config: MistralChatConfig = ctx.state(MistralChatConfig).updated(**extra)
        match output:
3 changes: 2 additions & 1 deletion src/draive/openai/lmm.py
@@ -104,6 +104,7 @@ async def openai_lmm_invocation(
            ),
        ],
    ):
+        ctx.log_debug("Requested OpenAI lmm")
        client: OpenAIClient = ctx.dependency(OpenAIClient)
        config: OpenAIChatConfig = ctx.state(OpenAIChatConfig).updated(**extra)
        match output:
@@ -138,7 +139,7 @@ async def openai_lmm_invocation(
    )


-def _convert_content_element(
+def _convert_content_element(  # noqa: C901
    element: MultimodalContentElement,
    config: OpenAIChatConfig,
) -> ChatCompletionContentPartParam: