Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow extending api with kwargs #65

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions constraints
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ certifi==2024.2.2
# requests
charset-normalizer==3.3.2
# via requests
coverage==7.5.0
coverage==7.5.1
# via pytest-cov
distro==1.9.0
# via openai
Expand All @@ -38,7 +38,7 @@ mdurl==0.1.2
nodeenv==1.8.0
# via pyright
numpy==1.26.4
openai==1.23.6
openai==1.25.2
packaging==24.0
# via pytest
pbr==6.0.0
Expand All @@ -49,9 +49,9 @@ pydantic==2.7.1
# via openai
pydantic-core==2.18.2
# via pydantic
pygments==2.17.2
pygments==2.18.0
# via rich
pyright==1.1.360
pyright==1.1.361
pytest==7.4.4
# via
# pytest-asyncio
Expand All @@ -60,13 +60,13 @@ pytest-asyncio==0.23.6
pytest-cov==4.1.0
pyyaml==6.0.1
# via bandit
regex==2024.4.16
regex==2024.4.28
# via tiktoken
requests==2.31.0
# via tiktoken
rich==13.7.1
# via bandit
ruff==0.4.2
ruff==0.4.3
setuptools==69.5.1
# via nodeenv
sniffio==1.3.1
Expand All @@ -77,7 +77,7 @@ sniffio==1.3.1
stevedore==5.2.0
# via bandit
tiktoken==0.6.0
tqdm==4.66.2
tqdm==4.66.4
# via openai
typing-extensions==4.11.0
# via
Expand Down
7 changes: 6 additions & 1 deletion src/draive/embedding/call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from typing import Any

from draive.embedding.embedded import Embedded
from draive.embedding.state import Embedding
Expand All @@ -11,5 +12,9 @@

async def embed_text(
values: Iterable[str],
**extra: Any,
) -> list[Embedded[str]]:
return await ctx.state(Embedding).embed_text(values=values)
return await ctx.state(Embedding).embed_text(
values=values,
**extra,
)
3 changes: 2 additions & 1 deletion src/draive/embedding/embedder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable

from draive.embedding.embedded import Embedded

Expand All @@ -13,4 +13,5 @@ class Embedder[Value](Protocol):
async def __call__(
self,
values: Iterable[Value],
**extra: Any,
) -> list[Embedded[Value]]: ...
4 changes: 4 additions & 0 deletions src/draive/generation/image/call.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any

from draive.generation.image.state import ImageGeneration
from draive.scope import ctx
from draive.types import ImageContent
Expand All @@ -10,7 +12,9 @@
async def generate_image(
*,
instruction: str,
**extra: Any,
) -> ImageContent:
return await ctx.state(ImageGeneration).generate(
instruction=instruction,
**extra,
)
6 changes: 3 additions & 3 deletions src/draive/generation/image/generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable

from draive.types import ImageContent

Expand All @@ -13,5 +13,5 @@ async def __call__(
self,
*,
instruction: str,
) -> ImageContent:
...
**extra: Any,
) -> ImageContent: ...
8 changes: 6 additions & 2 deletions src/draive/generation/model/call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from typing import Any

from draive.generation.model.state import ModelGeneration
from draive.scope import ctx
Expand All @@ -11,18 +12,21 @@


async def generate_model[Generated: Model](
model: type[Generated],
generated: type[Generated],
/,
*,
instruction: str,
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, Generated]] | None = None,
**extra: Any,
) -> Generated:
model_generation: ModelGeneration = ctx.state(ModelGeneration)
return await model_generation.generate(
model,
generated,
instruction=instruction,
input=input,
tools=tools or model_generation.tools,
examples=examples,
**extra,
)
6 changes: 4 additions & 2 deletions src/draive/generation/model/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable

from draive.tools import Toolbox
from draive.types import Model, MultimodalContent
Expand All @@ -13,10 +13,12 @@
class ModelGenerator(Protocol):
async def __call__[Generated: Model]( # noqa: PLR0913
self,
model: type[Generated],
generated: type[Generated],
/,
*,
instruction: str,
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, Generated]] | None = None,
**extra: Any,
) -> Generated: ...
12 changes: 8 additions & 4 deletions src/draive/generation/model/lmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from typing import Any

from draive.lmm import LMMCompletionMessage, lmm_completion
from draive.tools import Toolbox
Expand All @@ -10,18 +11,20 @@


async def lmm_generate_model[Generated: Model](
model: type[Generated],
generated: type[Generated],
/,
*,
instruction: str,
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, Generated]] | None = None,
**extra: Any,
) -> Generated:
system_message: LMMCompletionMessage = LMMCompletionMessage(
role="system",
content=INSTRUCTION.format(
instruction=instruction,
format=model.specification(),
format=generated.specification(),
),
)
input_message: LMMCompletionMessage = LMMCompletionMessage(
Expand Down Expand Up @@ -61,10 +64,11 @@ async def lmm_generate_model[Generated: Model](
context=context,
tools=tools,
output="json",
stream=False,
**extra,
)
generated: Generated = model.from_json(completion.content_string)

return generated
return generated.from_json(completion.content_string)


INSTRUCTION: str = """\
Expand Down
3 changes: 3 additions & 0 deletions src/draive/generation/text/call.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from typing import Any

from draive.generation.text.state import TextGeneration
from draive.scope import ctx
Expand All @@ -16,11 +17,13 @@ async def generate_text(
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, str]] | None = None,
**extra: Any,
) -> str:
text_generation: TextGeneration = ctx.state(TextGeneration)
return await text_generation.generate(
instruction=instruction,
input=input,
tools=tools or text_generation.tools,
examples=examples,
**extra,
)
6 changes: 3 additions & 3 deletions src/draive/generation/text/generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Iterable
from typing import Protocol, runtime_checkable
from typing import Any, Protocol, runtime_checkable

from draive.tools import Toolbox
from draive.types import MultimodalContent
Expand All @@ -18,5 +18,5 @@ async def __call__(
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, str]] | None = None,
) -> str:
...
**extra: Any,
) -> str: ...
4 changes: 4 additions & 0 deletions src/draive/generation/text/lmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from collections.abc import Iterable
from typing import Any

from draive.lmm import LMMCompletionMessage, lmm_completion
from draive.tools import Toolbox
Expand All @@ -15,6 +16,7 @@ async def lmm_generate_text(
input: MultimodalContent, # noqa: A002
tools: Toolbox | None = None,
examples: Iterable[tuple[MultimodalContent, str]] | None = None,
**extra: Any,
) -> str:
system_message: LMMCompletionMessage = LMMCompletionMessage(
role="system",
Expand Down Expand Up @@ -57,6 +59,8 @@ async def lmm_generate_text(
context=context,
tools=tools,
output="text",
stream=False,
**extra,
)
generated: str = completion.content_string

Expand Down
13 changes: 12 additions & 1 deletion src/draive/lmm/call.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Literal, overload
from typing import Any, Literal, overload

from draive.lmm.completion import LMMCompletionStream
from draive.lmm.message import (
Expand All @@ -21,6 +21,7 @@ async def lmm_completion(
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
stream: Literal[True],
**extra: Any,
) -> LMMCompletionStream: ...


Expand All @@ -30,6 +31,7 @@ async def lmm_completion(
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
stream: Callable[[LMMCompletionStreamingUpdate], None],
**extra: Any,
) -> LMMCompletionMessage: ...


Expand All @@ -39,6 +41,8 @@ async def lmm_completion(
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
output: Literal["text", "json"] = "text",
stream: Literal[False] = False,
**extra: Any,
) -> LMMCompletionMessage: ...


Expand All @@ -48,23 +52,30 @@ async def lmm_completion(
tools: Toolbox | None = None,
output: Literal["text", "json"] = "text",
stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False,
**extra: Any,
) -> LMMCompletionStream | LMMCompletionMessage:
match stream:
case False:
return await ctx.state(LMM).completion(
context=context,
tools=tools,
output=output,
stream=False,
**extra,
)
case True:
return await ctx.state(LMM).completion(
context=context,
tools=tools,
output=output,
stream=True,
**extra,
)
case progress:
return await ctx.state(LMM).completion(
context=context,
tools=tools,
output=output,
stream=progress,
**extra,
)
17 changes: 12 additions & 5 deletions src/draive/lmm/completion.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Literal, Protocol, Self, overload, runtime_checkable
from typing import Any, Literal, Protocol, Self, overload, runtime_checkable

from draive.lmm.message import (
LMMCompletionMessage,
Expand All @@ -26,26 +26,32 @@ async def __call__(
self,
*,
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
tools: Toolbox | None,
output: Literal["text", "json"],
stream: Literal[True],
**extra: Any,
) -> LMMCompletionStream: ...

@overload
async def __call__(
self,
*,
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
tools: Toolbox | None,
output: Literal["text", "json"],
stream: Callable[[LMMCompletionStreamingUpdate], None],
**extra: Any,
) -> LMMCompletionMessage: ...

@overload
async def __call__(
self,
*,
context: list[LMMCompletionMessage],
tools: Toolbox | None = None,
output: Literal["text", "json"] = "text",
tools: Toolbox | None,
output: Literal["text", "json"],
stream: Literal[False],
**extra: Any,
) -> LMMCompletionMessage: ...

async def __call__(
Expand All @@ -55,4 +61,5 @@ async def __call__(
tools: Toolbox | None = None,
output: Literal["text", "json"] = "text",
stream: Callable[[LMMCompletionStreamingUpdate], None] | bool = False,
**extra: Any,
) -> LMMCompletionStream | LMMCompletionMessage: ...
Loading
Loading