Optimize parameters processing
KaQuMiQ committed May 29, 2024
1 parent 2816886 commit 384b3d5
Showing 24 changed files with 677 additions and 499 deletions.
16 changes: 7 additions & 9 deletions constraints
@@ -15,7 +15,7 @@ certifi==2024.2.2
# requests
charset-normalizer==3.3.2
# via requests
-coverage==7.5.2
+coverage==7.5.3
# via pytest-cov
distro==1.9.0
# via openai
@@ -38,27 +38,27 @@ markdown-it-py==3.0.0
# via rich
mdurl==0.1.2
# via markdown-it-py
-nodeenv==1.8.0
+nodeenv==1.9.0
# via pyright
numpy==1.26.4
# via draive (pyproject.toml)
-openai==1.30.3
+openai==1.30.4
# via draive (pyproject.toml)
packaging==24.0
# via pytest
pbr==6.0.0
# via stevedore
pluggy==1.5.0
# via pytest
-pydantic==2.7.1
+pydantic==2.7.2
# via
# draive (pyproject.toml)
# openai
-pydantic-core==2.18.2
+pydantic-core==2.18.3
# via pydantic
pygments==2.18.0
# via rich
-pyright==1.1.364
+pyright==1.1.365
# via draive (pyproject.toml)
pytest==7.4.4
# via
@@ -77,10 +77,8 @@ requests==2.32.2
# via tiktoken
rich==13.7.1
# via bandit
-ruff==0.4.5
+ruff==0.4.6
# via draive (pyproject.toml)
-setuptools==70.0.0
-# via nodeenv
sniffio==1.3.1
# via
# anyio
22 changes: 14 additions & 8 deletions examples/BasicConversation.ipynb
@@ -43,14 +43,20 @@
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "OpenAIChatConfig.__init__() missing 8 required keyword-only arguments: 'temperature', 'top_p', 'frequency_penalty', 'max_tokens', 'seed', 'response_format', 'vision_details', and 'timeout'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 17\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mdraive\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (\n\u001b[1;32m 2\u001b[0m LMM,\n\u001b[1;32m 3\u001b[0m ConversationMessage,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 9\u001b[0m openai_lmm_invocation,\n\u001b[1;32m 10\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[38;5;66;03m# initialize dependencies and configuration\u001b[39;00m\n\u001b[1;32m 13\u001b[0m \u001b[38;5;28;01masync\u001b[39;00m \u001b[38;5;28;01mwith\u001b[39;00m ctx\u001b[38;5;241m.\u001b[39mnew(\n\u001b[1;32m 14\u001b[0m dependencies\u001b[38;5;241m=\u001b[39m[OpenAIClient], \u001b[38;5;66;03m# use OpenAI client\u001b[39;00m\n\u001b[1;32m 15\u001b[0m state\u001b[38;5;241m=\u001b[39m[\n\u001b[1;32m 16\u001b[0m LMM(invocation\u001b[38;5;241m=\u001b[39mopenai_lmm_invocation), \u001b[38;5;66;03m# define used LMM\u001b[39;00m\n\u001b[0;32m---> 17\u001b[0m \u001b[43mOpenAIChatConfig\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mgpt-3.5-turbo-0125\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m, \u001b[38;5;66;03m# configure OpenAI model\u001b[39;00m\n\u001b[1;32m 18\u001b[0m ],\n\u001b[1;32m 19\u001b[0m ):\n\u001b[1;32m 20\u001b[0m \u001b[38;5;66;03m# request conversation completion\u001b[39;00m\n\u001b[1;32m 21\u001b[0m response: ConversationMessage \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mawait\u001b[39;00m conversation_completion( \u001b[38;5;66;03m# noqa: PLE1142\u001b[39;00m\n\u001b[1;32m 22\u001b[0m \u001b[38;5;66;03m# provide a prompt instruction\u001b[39;00m\n\u001b[1;32m 23\u001b[0m instruction\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mYou are a helpful assistant.\u001b[39m\u001b[38;5;124m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 29\u001b[0m ),\n\u001b[1;32m 30\u001b[0m )\n\u001b[1;32m 31\u001b[0m \u001b[38;5;28mprint\u001b[39m(response)\n",
"\u001b[0;31mTypeError\u001b[0m: OpenAIChatConfig.__init__() missing 8 required keyword-only arguments: 'temperature', 'top_p', 'frequency_penalty', 'max_tokens', 'seed', 'response_format', 'vision_details', and 'timeout'"
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"identifier\": \"d5b98674fd024c0b977a49bfd9d3caa7\",\n",
" \"role\": \"model\",\n",
" \"author\": null,\n",
" \"created\": \"2024-05-29T15:26:35.346493+00:00\",\n",
" \"content\": {\n",
" \"elements\": [\n",
" \"The current UTC time and date is Wednesday, 29 May 2024, 15:26:34.\"\n",
" ]\n",
" }\n",
"}\n"
]
}
],
21 changes: 18 additions & 3 deletions src/draive/__init__.py
@@ -70,7 +70,18 @@
openai_lmm_invocation,
openai_tokenize_text,
)
-from draive.parameters import Argument, BasicValue, DataModel, Field, ParameterPath, State
+from draive.parameters import (
+    Argument,
+    BasicValue,
+    DataModel,
+    Field,
+    ParameterDefaultFactory,
+    ParameterPath,
+    ParameterValidationContext,
+    ParameterValidator,
+    ParameterVerifier,
+    State,
+)
from draive.scope import (
ScopeDependencies,
ScopeDependency,
@@ -119,8 +130,8 @@
getenv_float,
getenv_int,
getenv_str,
+is_missing,
load_env,
-missing,
not_missing,
setup_logging,
split_sequence,
@@ -197,7 +208,7 @@
"MetricsTrace",
"MetricsTraceReport",
"MetricsTraceReporter",
"missing",
"is_missing",
"Missing",
"MISSING",
"mistral_embed_text",
@@ -252,4 +263,8 @@
"VideoContent",
"VideoDataContent",
"VideoURLContent",
"ParameterDefaultFactory",
"ParameterValidationContext",
"ParameterValidator",
"ParameterVerifier",
]
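The commit re-exports the parameter-validation types from the package root, next to the existing ParameterPath and State. A minimal sketch of importing them, assuming the draive package is installed (the names below match the __all__ entries added above; how each type is used is defined elsewhere in the library):

    from draive import (
        ParameterDefaultFactory,
        ParameterPath,
        ParameterValidationContext,
        ParameterValidator,
        ParameterVerifier,
        is_missing,
    )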
5 changes: 5 additions & 0 deletions src/draive/conversation/lmm.py
@@ -164,6 +164,7 @@ async def _lmm_conversation_completion(
**extra,
):
case LMMCompletion() as completion:
ctx.log_debug("Received conversation result")
response_message: ConversationMessage = ConversationMessage(
role="model",
created=datetime.now(UTC),
@@ -176,6 +177,7 @@ async def _lmm_conversation_completion(
return response_message

case LMMToolRequests() as tool_requests:
ctx.log_debug("Received conversation tool calls")
context.append(tool_requests)
responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

@@ -221,6 +223,7 @@ async def _lmm_conversation_completion_stream(
):
match part:
case LMMCompletionChunk() as chunk:
ctx.log_debug("Received conversation result chunk")
response_content = response_content.extending(chunk.content)

yield ConversationMessageChunk(
@@ -230,6 +233,7 @@
# keep yielding parts

case LMMToolRequests() as tool_requests:
ctx.log_debug("Received conversation tool calls")
assert ( # nosec: B101
not response_content
), "Tools and completion message should not be used at the same time"
@@ -264,6 +268,7 @@ async def _lmm_conversation_completion_stream(
break # exit the loop with result

if response_content:
ctx.log_debug("Remembering conversation result")
# remember messages when finishing stream
await conversation_memory.remember(
request_message,
8 changes: 5 additions & 3 deletions src/draive/generation/model/lmm.py
@@ -21,7 +21,7 @@
]


-async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913
+async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913, C901
generated: type[Generated],
/,
*,
@@ -57,7 +57,7 @@ async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913
if variable := schema_variable:
instruction_message = LMMInstruction.of(
generation_instruction.updated(
-**{variable: generated.json_schema()},
+**{variable: generated.json_schema(indent=2)},
),
)

@@ -66,7 +66,7 @@ async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913
generation_instruction.extended(
DEFAULT_INSTRUCTION_EXTENSION,
joiner="\n\n",
-schema=generated.json_schema(),
+schema=generated.json_schema(indent=2),
)
)

@@ -93,9 +93,11 @@ async def lmm_generate_model[Generated: DataModel]( # noqa: PLR0913
**extra,
):
case LMMCompletion() as completion:
ctx.log_debug("Received model generation result")
return generated.from_json(completion.content.as_string())

case LMMToolRequests() as tool_requests:
ctx.log_debug("Received model generation tool calls")
context.append(tool_requests)
responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

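Both schema-injection paths above now render the model schema with indent=2 before splicing it into the instruction. A small illustration of the difference, using the standard-library json module as a stand-in for whatever serialization DataModel.json_schema performs internally (an assumption; the diff only shows that an indent keyword is accepted):

    import json

    # Hypothetical schema, for illustration only.
    schema = {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    }

    print(json.dumps(schema))            # one line, nesting hard to see
    print(json.dumps(schema, indent=2))  # pretty-printed, as the updated calls produce

The indented form makes the schema's nesting explicit in the instruction text handed to the model.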
2 changes: 2 additions & 0 deletions src/draive/generation/text/lmm.py
@@ -63,9 +63,11 @@ async def lmm_generate_text(
**extra,
):
case LMMCompletion() as completion:
ctx.log_debug("Received text generation result")
return completion.content.as_string()

case LMMToolRequests() as tool_requests:
ctx.log_debug("Received text generation tool calls")
context.append(tool_requests)
responses: list[LMMToolResponse] = await toolbox.respond(tool_requests)

56 changes: 39 additions & 17 deletions src/draive/lmm/tool.py
@@ -7,9 +7,15 @@
)
from uuid import uuid4

+from draive.lmm.errors import ToolException
from draive.lmm.state import ToolCallContext, ToolStatusStream
from draive.metrics import ArgumentsTrace, ResultTrace
from draive.parameters import Function, ParametrizedFunction, ToolSpecification
from draive.parameters import (
Function,
ParameterSpecification,
ParametrizedFunction,
ToolSpecification,
)
from draive.scope import ctx
from draive.types import MultimodalContent, MultimodalContentElement
from draive.utils import freeze, not_missing
@@ -23,7 +29,7 @@


class ToolAvailabilityCheck(Protocol):
-    def __call__(self) -> None: ...
+    def __call__(self) -> bool: ...


@final
@@ -41,23 +47,38 @@ def __init__( # noqa: PLR0913
direct_result: bool = False,
) -> None:
super().__init__(function=function)
-        if not_missing(self._parameters.specification):
-            self.specification: ToolSpecification = {
-                "type": "function",
-                "function": {
-                    "name": name,
-                    "parameters": self._parameters.specification,
-                    "description": description or "",
-                },
-            }
-
-        else:
-            raise TypeError(f"{function.__qualname__} can't be represented as a tool")
+        aliased_required: list[str] = []
+        parameters: dict[str, ParameterSpecification] = {}
+        for parameter in self._parameters.values():
+            if not_missing(parameter.specification):
+                parameters[parameter.alias or parameter.name] = parameter.specification
+
+            else:
+                raise TypeError(
+                    f"{function.__qualname__} can't be represented as a tool"
+                    f" - argument '{parameter.name}' is missing specification."
+                )
+
+            if not (parameter.has_default and parameter.allows_missing):
+                aliased_required.append(parameter.alias or parameter.name)
+
+        self.specification: ToolSpecification = {
+            "type": "function",
+            "function": {
+                "name": name,
+                "parameters": {
+                    "type": "object",
+                    "properties": parameters,
+                    "required": aliased_required,
+                },
+                "description": description or "",
+            },
+        }

self.name: str = name
self._direct_result: bool = direct_result
self._check_availability: ToolAvailabilityCheck = availability_check or (
-            lambda: None  # available by default
+            lambda: True  # available by default
)
self.format_result: Callable[[Result], MultimodalContent | MultimodalContentElement] = (
format_result
@@ -71,8 +92,7 @@ def __init__( # noqa: PLR0913
@property
def available(self) -> bool:
try:
-            self._check_availability()
-            return True
+            return self._check_availability()

except Exception:
return False
@@ -101,7 +121,9 @@ async def _toolbox_call(
call_context.report("STARTED")

try:
-            self._check_availability()
+            if not self.available:
+                raise ToolException(f"{self.name} is not available!")

result: Result = await super().__call__(**arguments) # pyright: ignore[reportCallIssue]
ctx.record(ResultTrace.of(result))

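Three related changes above tighten the availability contract: ToolAvailabilityCheck now returns a bool instead of signalling only by raising, Tool.available forwards the check's result (still treating a raised exception as unavailable), and _toolbox_call raises ToolException up front when the tool is unavailable. A self-contained sketch of the new contract (the Protocol is copied from the diff; the harness around it is illustrative, not draive's code):

    from typing import Protocol


    class ToolAvailabilityCheck(Protocol):
        def __call__(self) -> bool: ...


    def available(check: ToolAvailabilityCheck) -> bool:
        # Mirrors the updated Tool.available: return the check's verdict,
        # and keep treating a raised exception as "not available".
        try:
            return check()
        except Exception:
            return False


    print(available(lambda: True))   # True - matches the new default check
    print(available(lambda: False))  # False - no exception required anymore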
4 changes: 2 additions & 2 deletions src/draive/metrics/log_reporter.py
@@ -3,7 +3,7 @@

from draive.metrics.reporter import MetricsTraceReport, MetricsTraceReporter
from draive.parameters import ParametrizedData
-from draive.utils import missing
+from draive.utils import is_missing

__all__ = [
"metrics_log_reporter",
@@ -150,7 +150,7 @@ def _raw_value_report(
list_items_limit: int | None,
item_character_limit: int | None,
) -> str | None:
-    if missing(value):
+    if is_missing(value):
return None # skip missing

# workaround for pydantic models
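The rename from missing to is_missing (mirrored in the top-level exports earlier in this commit) reads as a predicate at call sites. A standalone sketch of the sentinel pattern behind it; the Missing, MISSING, and not_missing names match the package's exports, but this stub implementation is an assumption, not draive's actual code:

    from typing import Any, Final


    class Missing:
        # Illustrative singleton sentinel: distinguishes "no value given"
        # from an explicit None.
        _instance: "Missing | None" = None

        def __new__(cls) -> "Missing":
            if cls._instance is None:
                cls._instance = super().__new__(cls)
            return cls._instance


    MISSING: Final[Missing] = Missing()


    def is_missing(value: Any) -> bool:
        return value is MISSING


    def not_missing(value: Any) -> bool:
        return value is not MISSING


    print(is_missing(MISSING), is_missing(None))  # True False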
1 change: 1 addition & 0 deletions src/draive/mistral/lmm.py
@@ -86,6 +86,7 @@ async def mistral_lmm_invocation(
),
],
):
ctx.log_debug("Requested Mistral lmm")
client: MistralClient = ctx.dependency(MistralClient)
config: MistralChatConfig = ctx.state(MistralChatConfig).updated(**extra)
match output:
3 changes: 2 additions & 1 deletion src/draive/openai/lmm.py
@@ -104,6 +104,7 @@ async def openai_lmm_invocation(
),
],
):
ctx.log_debug("Requested OpenAI lmm")
client: OpenAIClient = ctx.dependency(OpenAIClient)
config: OpenAIChatConfig = ctx.state(OpenAIChatConfig).updated(**extra)
match output:
@@ -138,7 +139,7 @@ async def openai_lmm_invocation(
)


-def _convert_content_element(
+def _convert_content_element( # noqa: C901
element: MultimodalContentElement,
config: OpenAIChatConfig,
) -> ChatCompletionContentPartParam:
Expand Down