fix: DIA-986: Upgrade OpenAI client version & pytest coverage #78

Merged 6 commits on Apr 1, 2024
50 changes: 26 additions & 24 deletions adala/agents/base.py
@@ -1,5 +1,12 @@
import logging
from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator, SerializeAsAny
from pydantic import (
BaseModel,
Field,
SkipValidation,
field_validator,
model_validator,
SerializeAsAny,
)
from abc import ABC, abstractmethod
from typing import Any, Optional, List, Dict, Union, Tuple
from rich import print
@@ -9,7 +16,6 @@
from adala.environments.static_env import StaticEnvironment
from adala.runtimes.base import Runtime, AsyncRuntime
from adala.runtimes._openai import OpenAIChatRuntime
from adala.runtimes import GuidanceRuntime
from adala.skills._base import Skill
from adala.memories.base import Memory
from adala.skills.skillset import SkillSet, LinearSkillSet
@@ -53,24 +59,12 @@ class Agent(BaseModel, ABC):

memory: Memory = Field(default=None)
runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
default_factory=lambda: {
"default": GuidanceRuntime()
# 'openai': OpenAIChatRuntime(model='gpt-3.5-turbo'),
# 'llama2': LLMRuntime(
# llm_runtime_type=LLMRuntimeModelType.Transformers,
# llm_params={
# 'model': 'meta-llama/Llama-2-7b',
# 'device': 'cuda:0',
# }
# )
}
default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
)
default_runtime: str = "default"
teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
default_factory=lambda: {
"default": None
}
default_factory=lambda: {"default": None}
)
default_runtime: str = "default"
default_teacher_runtime: str = "default"

class Config:
@@ -121,9 +115,11 @@ def skills_validator(cls, v) -> SkillSet:
elif isinstance(v, list):
return LinearSkillSet(skills=v)
else:
raise ValueError(f"skills must be of type SkillSet or Skill, but received type {type(v)}")
raise ValueError(
f"skills must be of type SkillSet or Skill, but received type {type(v)}"
)

@field_validator('runtimes', mode='before')
@field_validator("runtimes", mode="before")
def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
"""
Validates and creates runtimes
@@ -136,7 +132,9 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
)
type_name = runtime_value.pop("type")
runtime_value = Runtime.create_from_registry(type=type_name, **runtime_value)
runtime_value = Runtime.create_from_registry(
type=type_name, **runtime_value
)
out[runtime_name] = runtime_value
return out

@@ -209,9 +207,11 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
raise ValueError(f'Teacher Runtime "{runtime}" not found.')
runtime = self.teacher_runtimes[runtime]
if not runtime:
raise ValueError(f"Teacher Runtime is requested, but it was not set."
f"Please provide a teacher runtime in the agent's constructor explicitly:"
f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})")
raise ValueError(
f"Teacher Runtime is requested, but it was not set."
f"Please provide a teacher runtime in the agent's constructor explicitly:"
f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})"
)
return runtime

def run(
@@ -269,7 +269,9 @@ async def arun(
# run on the environment until it is exhausted
while True:
try:
data_batch = await self.environment.get_data_batch(batch_size=runtime.batch_size)
data_batch = await self.environment.get_data_batch(
batch_size=runtime.batch_size
)
if data_batch.empty:
print_text("No more data in the environment. Exiting.")
break
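The main functional change in this file is the default runtime swap: the removed GuidanceRuntime default is replaced by OpenAIChatRuntime. Below is a minimal sketch, not part of the diff, of the two runtime mappings this change implies; it assumes the adala package is installed and an OpenAI API key is available to the runtime.

```python
from adala.runtimes import OpenAIChatRuntime

# New "default" runtime produced by the runtimes default_factory in this diff.
# Assumption: the API key is supplied via the environment or runtime config.
runtimes = {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}

# Explicit teacher runtime, as suggested by the get_teacher_runtime error message.
teacher_runtimes = {"default": OpenAIChatRuntime(model="gpt-4")}
```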
1 change: 0 additions & 1 deletion adala/environments/base.py
@@ -133,7 +133,6 @@ class Config:


class AsyncEnvironment(Environment, ABC):

@abstractmethod
async def initialize(self):
"""
95 changes: 63 additions & 32 deletions adala/environments/kafka.py
@@ -41,12 +41,16 @@ async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
yield msg.value
except asyncio.TimeoutError:
print_text(f"No message received within the timeout {timeout} seconds")
print_text(
f"No message received within the timeout {timeout} seconds"
)
break
finally:
await consumer.stop()

async def message_sender(self, producer: AIOKafkaProducer, data: Iterable, topic: str):
async def message_sender(
self, producer: AIOKafkaProducer, data: Iterable, topic: str
):
await producer.start()
try:
for record in data:
@@ -70,20 +74,22 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
consumer = AIOKafkaConsumer(
self.kafka_input_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode('utf-8')),
auto_offset_reset='earliest',
group_id='adala-consumer-group' # TODO: make it configurable based on the environment
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
group_id="adala-consumer-group", # TODO: make it configurable based on the environment
)

data_stream = self.message_receiver(consumer)
batch = await self.get_next_batch(data_stream, batch_size)
logger.info(f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}")
logger.info(
f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}"
)
return InternalDataFrame(batch)

async def set_predictions(self, predictions: InternalDataFrame):
producer = AIOKafkaProducer(
bootstrap_servers=self.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode('utf-8')
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
)
predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
await self.message_sender(producer, predictions_iter, self.kafka_output_topic)
@@ -109,7 +115,7 @@ def _iter_csv_local(self, csv_file_path):
Read data from the CSV file and push it to the kafka topic.
"""

with open(csv_file_path, 'r') as csv_file:
with open(csv_file_path, "r") as csv_file:
csv_reader = DictReader(csv_file)
for row in csv_reader:
yield row
@@ -120,9 +126,9 @@ def _iter_csv_s3(self, s3_uri):
"""
# Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
s3 = boto3.client('s3')
s3 = boto3.client("s3")
obj = s3.get_object(Bucket=bucket_name, Key=key)
data = obj['Body'].read().decode('utf-8')
data = obj["Body"].read().decode("utf-8")
csv_reader = DictReader(StringIO(data))
for row in csv_reader:
yield row
@@ -140,7 +146,7 @@ async def initialize(self):

producer = AIOKafkaProducer(
bootstrap_servers=self.kafka_bootstrap_servers,
value_serializer=lambda v: json.dumps(v).encode('utf-8')
value_serializer=lambda v: json.dumps(v).encode("utf-8"),
)

await self.message_sender(producer, csv_reader, self.kafka_input_topic)
@@ -153,53 +159,79 @@ async def finalize(self):
consumer = AIOKafkaConsumer(
self.kafka_output_topic,
bootstrap_servers=self.kafka_bootstrap_servers,
value_deserializer=lambda v: json.loads(v.decode('utf-8')),
auto_offset_reset='earliest',
group_id='consumer-group-output-topic' # TODO: make it configurable based on the environment
value_deserializer=lambda v: json.loads(v.decode("utf-8")),
auto_offset_reset="earliest",
group_id="consumer-group-output-topic", # TODO: make it configurable based on the environment
)

data_stream = self.message_receiver(consumer)

if self.output_file.startswith("s3://"):
await self._write_to_s3(self.output_file, self.error_file, data_stream, self.pass_through_columns)
await self._write_to_s3(
self.output_file,
self.error_file,
data_stream,
self.pass_through_columns,
)
else:
await self._write_to_local(self.output_file, self.error_file, data_stream, self.pass_through_columns)

async def _write_to_csv_fileobj(self, fileobj, error_fileobj, data_stream, column_names):
await self._write_to_local(
self.output_file,
self.error_file,
data_stream,
self.pass_through_columns,
)

async def _write_to_csv_fileobj(
self, fileobj, error_fileobj, data_stream, column_names
):
csv_writer, error_csv_writer = None, None
error_columns = ['index', 'message', 'details']
error_columns = ["index", "message", "details"]
while True:
try:
record = await anext(data_stream)
if record.get('error') == True:
if record.get("error") == True:
logger.error(f"Error occurred while processing record: {record}")
if error_csv_writer is None:
error_csv_writer = DictWriter(error_fileobj, fieldnames=error_columns)
error_csv_writer = DictWriter(
error_fileobj, fieldnames=error_columns
)
error_csv_writer.writeheader()
error_csv_writer.writerow({k: record.get(k, '') for k in error_columns})
error_csv_writer.writerow(
{k: record.get(k, "") for k in error_columns}
)
else:
if csv_writer is None:
if column_names is None:
column_names = list(record.keys())
csv_writer = DictWriter(fileobj, fieldnames=column_names)
csv_writer.writeheader()
csv_writer.writerow({k: record.get(k, '') for k in column_names})
csv_writer.writerow({k: record.get(k, "") for k in column_names})
except StopAsyncIteration:
break

async def _write_to_local(self, file_path: str, error_file_path: str, data_stream, column_names):
with open(file_path, 'w') as csv_file, open(error_file_path, 'w') as error_file:
await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)

async def _write_to_s3(self, s3_uri: str, s3_uri_errors: str, data_stream, column_names):
async def _write_to_local(
self, file_path: str, error_file_path: str, data_stream, column_names
):
with open(file_path, "w") as csv_file, open(error_file_path, "w") as error_file:
await self._write_to_csv_fileobj(
csv_file, error_file, data_stream, column_names
)

async def _write_to_s3(
self, s3_uri: str, s3_uri_errors: str, data_stream, column_names
):
# Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
error_bucket_name, error_key = s3_uri_errors.replace("s3://", "").split("/", 1)
s3 = boto3.client('s3')
s3 = boto3.client("s3")
with StringIO() as csv_file, StringIO() as error_file:
await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
await self._write_to_csv_fileobj(
csv_file, error_file, data_stream, column_names
)
s3.put_object(Bucket=bucket_name, Key=key, Body=csv_file.getvalue())
s3.put_object(Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue())
s3.put_object(
Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue()
)

async def get_feedback(
self,
Expand All @@ -214,4 +246,3 @@ async def restore(self):

async def save(self):
raise NotImplementedError("Save is not supported in Kafka environment")
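The kafka.py changes visible here are formatting-only (double quotes and Black-style line wrapping); behavior is unchanged. For reference, the producer and consumer (de)serializers used throughout the file form a plain JSON round trip, sketched here outside of Kafka with a hypothetical record shape:

```python
import json

# Producer side (value_serializer in this file):
serialize = lambda v: json.dumps(v).encode("utf-8")
# Consumer side (value_deserializer in this file):
deserialize = lambda v: json.loads(v.decode("utf-8"))

record = {"index": 0, "error": False, "message": ""}  # hypothetical record shape
assert deserialize(serialize(record)) == record
```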

1 change: 1 addition & 0 deletions adala/memories/__init__.py
@@ -1,2 +1,3 @@
from .file_memory import FileMemory
from .vectordb import VectorDBMemory
from .base import Memory
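With Memory re-exported here, the base class can be imported from the package root alongside the concrete memories:

```python
from adala.memories import Memory, FileMemory, VectorDBMemory
```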
4 changes: 3 additions & 1 deletion adala/memories/vectordb.py
@@ -13,6 +13,8 @@ class VectorDBMemory(Memory):
"""

db_name: str = ""
openai_api_key: str
openai_embedding_model: str = "text-embedding-ada-002"
_client = None
_collection = None
_embedding_function = None
@@ -21,7 +23,7 @@ class VectorDBMemory(Memory):
def init_database(self):
self._client = chromadb.Client()
self._embedding_function = embedding_functions.OpenAIEmbeddingFunction(
model_name="text-embedding-ada-002"
model_name=self.openai_embedding_model, api_key=self.openai_api_key
)
self._collection = self._client.get_or_create_collection(
name=self.db_name, embedding_function=self._embedding_function
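VectorDBMemory now requires an explicit OpenAI API key and accepts a configurable embedding model. A minimal construction sketch, not part of the diff, with placeholder values; it assumes chromadb and the OpenAI client are installed:

```python
from adala.memories import VectorDBMemory

memory = VectorDBMemory(
    db_name="adala-memory",                           # hypothetical collection name
    openai_api_key="sk-...",                          # placeholder; new required field
    openai_embedding_model="text-embedding-ada-002",  # default shown in the diff
)
```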
8 changes: 0 additions & 8 deletions adala/runtimes/__init__.py
@@ -1,10 +1,2 @@
from .base import Runtime, AsyncRuntime
from ._openai import OpenAIChatRuntime, OpenAIVisionRuntime, AsyncOpenAIChatRuntime
from ._guidance import GuidanceRuntime, GuidanceModelType
from ._batch import BatchRuntime

try:
# check if langchain is installed
from ._langchain import LangChainRuntime
except ImportError:
pass
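After this change only the base and OpenAI runtimes are exported, so imports of GuidanceRuntime, BatchRuntime, or LangChainRuntime from adala.runtimes will fail. A sketch of the surviving imports:

```python
# Still exported after this PR:
from adala.runtimes import (
    Runtime,
    AsyncRuntime,
    OpenAIChatRuntime,
    AsyncOpenAIChatRuntime,
    OpenAIVisionRuntime,
)

# No longer available (removed in this diff):
# from adala.runtimes import GuidanceRuntime, GuidanceModelType, BatchRuntime, LangChainRuntime
```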