Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unpickleable fields due to openai client #79

Merged
merged 5 commits into from
Apr 2, 2024
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions adala/runtimes/_openai.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,13 @@

from typing import Optional, Dict, Any, List
from openai import OpenAI, NotFoundError
from pydantic import model_validator, field_validator, ValidationInfo, Field
from pydantic import (
model_validator,
field_validator,
ValidationInfo,
Field,
computed_field,
)
from .base import Runtime, AsyncRuntime
from adala.utils.logs import print_error
from adala.utils.internal_data import InternalDataFrame, InternalSeries
@@ -156,19 +162,21 @@ class OpenAIChatRuntime(Runtime):
max_tokens: Maximum number of tokens to generate. Defaults to 1000.
"""

class Config:
arbitrary_types_allowed = True # for @computed_field

openai_model: str = Field(alias="model")
openai_api_key: Optional[str] = Field(
default=os.getenv("OPENAI_API_KEY"), alias="api_key"
)
max_tokens: Optional[int] = 1000
splitter: Optional[str] = None

_client: OpenAI = None
@computed_field
def _client(self) -> OpenAI:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need computed_field here but create a client on the fly in the Async version? let's use the unique approach if possible

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the sync runtime actually uses its _client, but the async runtime does not, it directly sends post requests in async_create_completion. Even if it did use the python client lib instead of requests, it wouldn't make sense for the class to carry a single client instance because it would create them on demand to use them concurrently

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@niklub if you feel strongly about this I'll change it, but I think it's a meaningful difference that the async openai runtime doesn't need a self._client

return OpenAI(api_key=self.openai_api_key)

def init_runtime(self) -> "Runtime":
if self._client is None:
self._client = OpenAI(api_key=self.openai_api_key)

# check model availability
try:
self._client.models.retrieve(self.openai_model)
@@ -282,15 +290,12 @@ class AsyncOpenAIChatRuntime(AsyncRuntime):
concurrent_clients: Optional[int] = 10
timeout: Optional[int] = 10

_client: OpenAI = None

def init_runtime(self) -> "Runtime":
if self._client is None:
self._client = OpenAI(api_key=self.openai_api_key)

# check model availability
try:
self._client.models.retrieve(self.openai_model)
_client = OpenAI(api_key=self.openai_api_key)
_client.models.retrieve(self.openai_model)
except NotFoundError:
raise ValueError(
f'Requested model "{self.openai_model}" is not available in your OpenAI account.'
Loading