
Commit d955895

fix: DIA-986: Upgrade OpenAI client version & pytest coverage (#78)

Authored by niklubnik and nik
Co-authored-by: nik <[email protected]>
1 parent 48ee656 · commit d955895

27 files changed: +523 −1275 lines

adala/agents/base.py (+26 −24)

@@ -1,5 +1,12 @@
 import logging
-from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator, SerializeAsAny
+from pydantic import (
+    BaseModel,
+    Field,
+    SkipValidation,
+    field_validator,
+    model_validator,
+    SerializeAsAny,
+)
 from abc import ABC, abstractmethod
 from typing import Any, Optional, List, Dict, Union, Tuple
 from rich import print
@@ -9,7 +16,6 @@
 from adala.environments.static_env import StaticEnvironment
 from adala.runtimes.base import Runtime, AsyncRuntime
 from adala.runtimes._openai import OpenAIChatRuntime
-from adala.runtimes import GuidanceRuntime
 from adala.skills._base import Skill
 from adala.memories.base import Memory
 from adala.skills.skillset import SkillSet, LinearSkillSet
@@ -53,24 +59,12 @@ class Agent(BaseModel, ABC):
 
     memory: Memory = Field(default=None)
     runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
-        default_factory=lambda: {
-            "default": GuidanceRuntime()
-            # 'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
-            # 'llama2': LLMRuntime(
-            #     llm_runtime_type=LLMRuntimeModelType.Transformers,
-            #     llm_params={
-            #         'model': 'meta-llama/Llama-2-7b',
-            #         'device': 'cuda:0',
-            #     }
-            # )
-        }
+        default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
     )
+    default_runtime: str = "default"
     teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
-        default_factory=lambda: {
-            "default": None
-        }
+        default_factory=lambda: {"default": None}
     )
-    default_runtime: str = "default"
     default_teacher_runtime: str = "default"
 
     class Config:
@@ -121,9 +115,11 @@ def skills_validator(cls, v) -> SkillSet:
         elif isinstance(v, list):
            return LinearSkillSet(skills=v)
         else:
-            raise ValueError(f"skills must be of type SkillSet or Skill, but received type {type(v)}")
+            raise ValueError(
+                f"skills must be of type SkillSet or Skill, but received type {type(v)}"
+            )
 
-    @field_validator('runtimes', mode='before')
+    @field_validator("runtimes", mode="before")
     def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
         """
         Validates and creates runtimes
@@ -136,7 +132,9 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
                     f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
                 )
                 type_name = runtime_value.pop("type")
-                runtime_value = Runtime.create_from_registry(type=type_name, **runtime_value)
+                runtime_value = Runtime.create_from_registry(
+                    type=type_name, **runtime_value
+                )
             out[runtime_name] = runtime_value
         return out
 
@@ -209,9 +207,11 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
             raise ValueError(f'Teacher Runtime "{runtime}" not found.')
         runtime = self.teacher_runtimes[runtime]
         if not runtime:
-            raise ValueError(f"Teacher Runtime is requested, but it was not set."
-                             f"Please provide a teacher runtime in the agent's constructor explicitly:"
-                             f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})")
+            raise ValueError(
+                f"Teacher Runtime is requested, but it was not set."
+                f"Please provide a teacher runtime in the agent's constructor explicitly:"
+                f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})"
+            )
         return runtime
 
     def run(
@@ -269,7 +269,9 @@ async def arun(
         # run on the environment until it is exhausted
         while True:
             try:
-                data_batch = await self.environment.get_data_batch(batch_size=runtime.batch_size)
+                data_batch = await self.environment.get_data_batch(
+                    batch_size=runtime.batch_size
+                )
                 if data_batch.empty:
                     print_text("No more data in the environment. Exiting.")
                     break

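Taken together, the new defaults mean an Agent constructed without an explicit runtimes dict now talks to OpenAI via OpenAIChatRuntime(model="gpt-3.5-turbo") instead of GuidanceRuntime, while teacher_runtimes still defaults to {"default": None} and must be set explicitly before any teacher-driven learning, as the ValueError message above spells out. A minimal sketch of that, assuming Agent is importable from adala.agents and using an empty skill list purely for illustration:

from adala.agents import Agent  # assumed import path; the class itself is defined in adala/agents/base.py
from adala.runtimes import OpenAIChatRuntime

agent = Agent(
    skills=[],  # a list is converted to a LinearSkillSet by skills_validator above; empty here only for brevity
    # If runtimes is omitted, it now defaults to {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}.
    runtimes={"default": OpenAIChatRuntime(model="gpt-3.5-turbo")},
    # teacher_runtimes still defaults to {"default": None}, so a teacher must be passed explicitly.
    teacher_runtimes={"default": OpenAIChatRuntime(model="gpt-4")},
)
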
adala/environments/base.py (−1)

@@ -133,7 +133,6 @@ class Config:
 
 
 class AsyncEnvironment(Environment, ABC):
-
     @abstractmethod
     async def initialize(self):
         """
adala/environments/kafka.py

+63-32
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,16 @@ async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
4141
msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
4242
yield msg.value
4343
except asyncio.TimeoutError:
44-
print_text(f"No message received within the timeout {timeout} seconds")
44+
print_text(
45+
f"No message received within the timeout {timeout} seconds"
46+
)
4547
break
4648
finally:
4749
await consumer.stop()
4850

49-
async def message_sender(self, producer: AIOKafkaProducer, data: Iterable, topic: str):
51+
async def message_sender(
52+
self, producer: AIOKafkaProducer, data: Iterable, topic: str
53+
):
5054
await producer.start()
5155
try:
5256
for record in data:
@@ -70,20 +74,22 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
7074
consumer = AIOKafkaConsumer(
7175
self.kafka_input_topic,
7276
bootstrap_servers=self.kafka_bootstrap_servers,
73-
value_deserializer=lambda v: json.loads(v.decode('utf-8')),
74-
auto_offset_reset='earliest',
75-
group_id='adala-consumer-group' # TODO: make it configurable based on the environment
77+
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
78+
auto_offset_reset="earliest",
79+
group_id="adala-consumer-group", # TODO: make it configurable based on the environment
7680
)
7781

7882
data_stream = self.message_receiver(consumer)
7983
batch = await self.get_next_batch(data_stream, batch_size)
80-
logger.info(f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}")
84+
logger.info(
85+
f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}"
86+
)
8187
return InternalDataFrame(batch)
8288

8389
async def set_predictions(self, predictions: InternalDataFrame):
8490
producer = AIOKafkaProducer(
8591
bootstrap_servers=self.kafka_bootstrap_servers,
86-
value_serializer=lambda v: json.dumps(v).encode('utf-8')
92+
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
8793
)
8894
predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
8995
await self.message_sender(producer, predictions_iter, self.kafka_output_topic)
@@ -109,7 +115,7 @@ def _iter_csv_local(self, csv_file_path):
109115
Read data from the CSV file and push it to the kafka topic.
110116
"""
111117

112-
with open(csv_file_path, 'r') as csv_file:
118+
with open(csv_file_path, "r") as csv_file:
113119
csv_reader = DictReader(csv_file)
114120
for row in csv_reader:
115121
yield row
@@ -120,9 +126,9 @@ def _iter_csv_s3(self, s3_uri):
120126
"""
121127
# Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
122128
bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
123-
s3 = boto3.client('s3')
129+
s3 = boto3.client("s3")
124130
obj = s3.get_object(Bucket=bucket_name, Key=key)
125-
data = obj['Body'].read().decode('utf-8')
131+
data = obj["Body"].read().decode("utf-8")
126132
csv_reader = DictReader(StringIO(data))
127133
for row in csv_reader:
128134
yield row
@@ -140,7 +146,7 @@ async def initialize(self):
140146

141147
producer = AIOKafkaProducer(
142148
bootstrap_servers=self.kafka_bootstrap_servers,
143-
value_serializer=lambda v: json.dumps(v).encode('utf-8')
149+
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
144150
)
145151

146152
await self.message_sender(producer, csv_reader, self.kafka_input_topic)
@@ -153,53 +159,79 @@ async def finalize(self):
153159
consumer = AIOKafkaConsumer(
154160
self.kafka_output_topic,
155161
bootstrap_servers=self.kafka_bootstrap_servers,
156-
value_deserializer=lambda v: json.loads(v.decode('utf-8')),
157-
auto_offset_reset='earliest',
158-
group_id='consumer-group-output-topic' # TODO: make it configurable based on the environment
162+
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
163+
auto_offset_reset="earliest",
164+
group_id="consumer-group-output-topic", # TODO: make it configurable based on the environment
159165
)
160166

161167
data_stream = self.message_receiver(consumer)
162168

163169
if self.output_file.startswith("s3://"):
164-
await self._write_to_s3(self.output_file, self.error_file, data_stream, self.pass_through_columns)
170+
await self._write_to_s3(
171+
self.output_file,
172+
self.error_file,
173+
data_stream,
174+
self.pass_through_columns,
175+
)
165176
else:
166-
await self._write_to_local(self.output_file, self.error_file, data_stream, self.pass_through_columns)
167-
168-
async def _write_to_csv_fileobj(self, fileobj, error_fileobj, data_stream, column_names):
177+
await self._write_to_local(
178+
self.output_file,
179+
self.error_file,
180+
data_stream,
181+
self.pass_through_columns,
182+
)
183+
184+
async def _write_to_csv_fileobj(
185+
self, fileobj, error_fileobj, data_stream, column_names
186+
):
169187
csv_writer, error_csv_writer = None, None
170-
error_columns = ['index', 'message', 'details']
188+
error_columns = ["index", "message", "details"]
171189
while True:
172190
try:
173191
record = await anext(data_stream)
174-
if record.get('error') == True:
192+
if record.get("error") == True:
175193
logger.error(f"Error occurred while processing record: {record}")
176194
if error_csv_writer is None:
177-
error_csv_writer = DictWriter(error_fileobj, fieldnames=error_columns)
195+
error_csv_writer = DictWriter(
196+
error_fileobj, fieldnames=error_columns
197+
)
178198
error_csv_writer.writeheader()
179-
error_csv_writer.writerow({k: record.get(k, '') for k in error_columns})
199+
error_csv_writer.writerow(
200+
{k: record.get(k, "") for k in error_columns}
201+
)
180202
else:
181203
if csv_writer is None:
182204
if column_names is None:
183205
column_names = list(record.keys())
184206
csv_writer = DictWriter(fileobj, fieldnames=column_names)
185207
csv_writer.writeheader()
186-
csv_writer.writerow({k: record.get(k, '') for k in column_names})
208+
csv_writer.writerow({k: record.get(k, "") for k in column_names})
187209
except StopAsyncIteration:
188210
break
189211

190-
async def _write_to_local(self, file_path: str, error_file_path: str, data_stream, column_names):
191-
with open(file_path, 'w') as csv_file, open(error_file_path, 'w') as error_file:
192-
await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
193-
194-
async def _write_to_s3(self, s3_uri: str, s3_uri_errors: str, data_stream, column_names):
212+
async def _write_to_local(
213+
self, file_path: str, error_file_path: str, data_stream, column_names
214+
):
215+
with open(file_path, "w") as csv_file, open(error_file_path, "w") as error_file:
216+
await self._write_to_csv_fileobj(
217+
csv_file, error_file, data_stream, column_names
218+
)
219+
220+
async def _write_to_s3(
221+
self, s3_uri: str, s3_uri_errors: str, data_stream, column_names
222+
):
195223
# Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
196224
bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
197225
error_bucket_name, error_key = s3_uri_errors.replace("s3://", "").split("/", 1)
198-
s3 = boto3.client('s3')
226+
s3 = boto3.client("s3")
199227
with StringIO() as csv_file, StringIO() as error_file:
200-
await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
228+
await self._write_to_csv_fileobj(
229+
csv_file, error_file, data_stream, column_names
230+
)
201231
s3.put_object(Bucket=bucket_name, Key=key, Body=csv_file.getvalue())
202-
s3.put_object(Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue())
232+
s3.put_object(
233+
Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue()
234+
)
203235

204236
async def get_feedback(
205237
self,
@@ -214,4 +246,3 @@ async def restore(self):
214246

215247
async def save(self):
216248
raise NotImplementedError("Save is not supported in Kafka environment")
217-

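For readers of the reformatted _write_to_csv_fileobj above: records carrying an "error": True flag are routed to a separate error CSV with the columns index, message, and details, while all other records go to the main output CSV. A standalone sketch of that routing convention (the record shapes here are illustrative, not taken from the repo):

from csv import DictWriter
from io import StringIO

# Illustrative records; in the environment they arrive from the Kafka output topic.
records = [
    {"text": "hello world", "label": "greeting"},
    {"error": True, "index": 3, "message": "runtime failure", "details": "timeout"},
]

output_csv, error_csv = StringIO(), StringIO()
writer = DictWriter(output_csv, fieldnames=["text", "label"])
error_writer = DictWriter(error_csv, fieldnames=["index", "message", "details"])
writer.writeheader()
error_writer.writeheader()

for record in records:
    if record.get("error") == True:  # same flag check as in _write_to_csv_fileobj
        error_writer.writerow({k: record.get(k, "") for k in ["index", "message", "details"]})
    else:
        writer.writerow({k: record.get(k, "") for k in ["text", "label"]})

print(output_csv.getvalue())
print(error_csv.getvalue())
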
adala/memories/__init__.py (+1)

@@ -1,2 +1,3 @@
 from .file_memory import FileMemory
+from .vectordb import VectorDBMemory
 from .base import Memory

adala/memories/vectordb.py (+3 −1)

@@ -13,6 +13,8 @@ class VectorDBMemory(Memory):
     """
 
     db_name: str = ""
+    openai_api_key: str
+    openai_embedding_model: str = "text-embedding-ada-002"
     _client = None
     _collection = None
     _embedding_function = None
@@ -21,7 +23,7 @@ class VectorDBMemory(Memory):
     def init_database(self):
         self._client = chromadb.Client()
         self._embedding_function = embedding_functions.OpenAIEmbeddingFunction(
-            model_name="text-embedding-ada-002"
+            model_name=self.openai_embedding_model, api_key=self.openai_api_key
         )
         self._collection = self._client.get_or_create_collection(
             name=self.db_name, embedding_function=self._embedding_function

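The upshot of the new fields: openai_api_key has no default, so constructing a VectorDBMemory now requires passing the key explicitly, while openai_embedding_model keeps the old text-embedding-ada-002 behaviour as its default. A minimal sketch, using the export added to adala/memories/__init__.py above (the collection name is illustrative):

import os
from adala.memories import VectorDBMemory

memory = VectorDBMemory(
    db_name="skill-feedback",                         # illustrative collection name
    openai_api_key=os.environ["OPENAI_API_KEY"],      # now a required field
    openai_embedding_model="text-embedding-ada-002",  # optional; this is the default
)
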
adala/runtimes/__init__.py (−8)

@@ -1,10 +1,2 @@
 from .base import Runtime, AsyncRuntime
 from ._openai import OpenAIChatRuntime, OpenAIVisionRuntime, AsyncOpenAIChatRuntime
-from ._guidance import GuidanceRuntime, GuidanceModelType
-from ._batch import BatchRuntime
-
-try:
-    # check if langchain is installed
-    from ._langchain import LangChainRuntime
-except ImportError:
-    pass

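After this cleanup the package exports only the base classes and the OpenAI runtimes; code importing GuidanceRuntime, BatchRuntime, or LangChainRuntime from adala.runtimes will now raise ImportError. The surviving public imports are:

from adala.runtimes import (
    Runtime,
    AsyncRuntime,
    OpenAIChatRuntime,
    OpenAIVisionRuntime,
    AsyncOpenAIChatRuntime,
)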