
Commit

Reformat
nik committed Apr 1, 2024
1 parent cac3787 commit 3a89d63
Showing 15 changed files with 380 additions and 236 deletions.
39 changes: 25 additions & 14 deletions adala/agents/base.py
@@ -1,5 +1,12 @@
 import logging
-from pydantic import BaseModel, Field, SkipValidation, field_validator, model_validator, SerializeAsAny
+from pydantic import (
+    BaseModel,
+    Field,
+    SkipValidation,
+    field_validator,
+    model_validator,
+    SerializeAsAny,
+)
 from abc import ABC, abstractmethod
 from typing import Any, Optional, List, Dict, Union, Tuple
 from rich import print
@@ -52,15 +59,11 @@ class Agent(BaseModel, ABC):

     memory: Memory = Field(default=None)
     runtimes: Dict[str, SerializeAsAny[Union[Runtime, AsyncRuntime]]] = Field(
-        default_factory=lambda: {
-            "default": OpenAIChatRuntime(model='gpt-3.5-turbo')
-        }
+        default_factory=lambda: {"default": OpenAIChatRuntime(model="gpt-3.5-turbo")}
     )
     default_runtime: str = "default"
     teacher_runtimes: Dict[str, SerializeAsAny[Runtime]] = Field(
-        default_factory=lambda: {
-            "default": None
-        }
+        default_factory=lambda: {"default": None}
     )
     default_teacher_runtime: str = "default"

@@ -112,9 +115,11 @@ def skills_validator(cls, v) -> SkillSet:
         elif isinstance(v, list):
             return LinearSkillSet(skills=v)
         else:
-            raise ValueError(f"skills must be of type SkillSet or Skill, but received type {type(v)}")
+            raise ValueError(
+                f"skills must be of type SkillSet or Skill, but received type {type(v)}"
+            )

-    @field_validator('runtimes', mode='before')
+    @field_validator("runtimes", mode="before")
     def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
         """
         Validates and creates runtimes
@@ -127,7 +132,9 @@ def runtimes_validator(cls, v) -> Dict[str, Union[Runtime, AsyncRuntime]]:
                         f"Runtime {runtime_name} must have a 'type' field to specify the runtime type."
                     )
                 type_name = runtime_value.pop("type")
-                runtime_value = Runtime.create_from_registry(type=type_name, **runtime_value)
+                runtime_value = Runtime.create_from_registry(
+                    type=type_name, **runtime_value
+                )
             out[runtime_name] = runtime_value
         return out

@@ -200,9 +207,11 @@ def get_teacher_runtime(self, runtime: Optional[str] = None) -> Runtime:
             raise ValueError(f'Teacher Runtime "{runtime}" not found.')
         runtime = self.teacher_runtimes[runtime]
         if not runtime:
-            raise ValueError(f"Teacher Runtime is requested, but it was not set."
-                             f"Please provide a teacher runtime in the agent's constructor explicitly:"
-                             f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})")
+            raise ValueError(
+                f"Teacher Runtime is requested, but it was not set."
+                f"Please provide a teacher runtime in the agent's constructor explicitly:"
+                f"agent = Agent(..., teacher_runtimes={{'default': OpenAIChatRuntime(model='gpt-4')}})"
+            )
         return runtime

     def run(
@@ -260,7 +269,9 @@ async def arun(
         # run on the environment until it is exhausted
         while True:
            try:
-                data_batch = await self.environment.get_data_batch(batch_size=runtime.batch_size)
+                data_batch = await self.environment.get_data_batch(
+                    batch_size=runtime.batch_size
+                )
                 if data_batch.empty:
                     print_text("No more data in the environment. Exiting.")
                     break
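
The error message in get_teacher_runtime above doubles as usage guidance: student and teacher runtimes are plain dictionaries of named Runtime instances. Below is a minimal sketch of that wiring; it assumes Agent, OpenAIChatRuntime, and TransformSkill are importable from the adala package, and the skill field names are illustrative guesses rather than values taken from this diff.

    # Hypothetical skill definition: the field names below are assumptions, not from this diff.
    from adala.agents import Agent                # assumed import path
    from adala.runtimes import OpenAIChatRuntime  # assumed import path
    from adala.skills import TransformSkill       # assumed import path

    summarizer = TransformSkill(
        name="summarize",
        instructions="Summarize the input text in one sentence.",
        input_template="Text: {text}",
        output_template="Summary: {summary}",
    )

    agent = Agent(
        skills=summarizer,  # a Skill, a list of Skills, or a SkillSet (see skills_validator)
        runtimes={"default": OpenAIChatRuntime(model="gpt-3.5-turbo")},
        teacher_runtimes={"default": OpenAIChatRuntime(model="gpt-4")},
    )
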
1 change: 0 additions & 1 deletion adala/environments/base.py
@@ -133,7 +133,6 @@ class Config:


 class AsyncEnvironment(Environment, ABC):
-
     @abstractmethod
     async def initialize(self):
         """
95 changes: 63 additions & 32 deletions adala/environments/kafka.py
@@ -41,12 +41,16 @@ async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
                     msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
                     yield msg.value
                 except asyncio.TimeoutError:
-                    print_text(f"No message received within the timeout {timeout} seconds")
+                    print_text(
+                        f"No message received within the timeout {timeout} seconds"
+                    )
                     break
         finally:
             await consumer.stop()

-    async def message_sender(self, producer: AIOKafkaProducer, data: Iterable, topic: str):
+    async def message_sender(
+        self, producer: AIOKafkaProducer, data: Iterable, topic: str
+    ):
         await producer.start()
         try:
             for record in data:
@@ -70,20 +74,22 @@ async def get_data_batch(self, batch_size: Optional[int]) -> InternalDataFrame:
         consumer = AIOKafkaConsumer(
             self.kafka_input_topic,
             bootstrap_servers=self.kafka_bootstrap_servers,
-            value_deserializer=lambda v: json.loads(v.decode('utf-8')),
-            auto_offset_reset='earliest',
-            group_id='adala-consumer-group'  # TODO: make it configurable based on the environment
+            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+            auto_offset_reset="earliest",
+            group_id="adala-consumer-group",  # TODO: make it configurable based on the environment
         )

         data_stream = self.message_receiver(consumer)
         batch = await self.get_next_batch(data_stream, batch_size)
-        logger.info(f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}")
+        logger.info(
+            f"Received a batch of {len(batch)} records from Kafka topic {self.kafka_input_topic}"
+        )
         return InternalDataFrame(batch)

     async def set_predictions(self, predictions: InternalDataFrame):
         producer = AIOKafkaProducer(
             bootstrap_servers=self.kafka_bootstrap_servers,
-            value_serializer=lambda v: json.dumps(v).encode('utf-8')
+            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
         )
         predictions_iter = (r.to_dict() for _, r in predictions.iterrows())
         await self.message_sender(producer, predictions_iter, self.kafka_output_topic)
@@ -109,7 +115,7 @@ def _iter_csv_local(self, csv_file_path):
         Read data from the CSV file and push it to the kafka topic.
         """

-        with open(csv_file_path, 'r') as csv_file:
+        with open(csv_file_path, "r") as csv_file:
             csv_reader = DictReader(csv_file)
             for row in csv_reader:
                 yield row
@@ -120,9 +126,9 @@ def _iter_csv_s3(self, s3_uri):
         """
         # Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
         bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
-        s3 = boto3.client('s3')
+        s3 = boto3.client("s3")
         obj = s3.get_object(Bucket=bucket_name, Key=key)
-        data = obj['Body'].read().decode('utf-8')
+        data = obj["Body"].read().decode("utf-8")
         csv_reader = DictReader(StringIO(data))
         for row in csv_reader:
             yield row
@@ -140,7 +146,7 @@ async def initialize(self):

         producer = AIOKafkaProducer(
             bootstrap_servers=self.kafka_bootstrap_servers,
-            value_serializer=lambda v: json.dumps(v).encode('utf-8')
+            value_serializer=lambda v: json.dumps(v).encode("utf-8"),
         )

         await self.message_sender(producer, csv_reader, self.kafka_input_topic)
@@ -153,53 +159,79 @@ async def finalize(self):
         consumer = AIOKafkaConsumer(
             self.kafka_output_topic,
             bootstrap_servers=self.kafka_bootstrap_servers,
-            value_deserializer=lambda v: json.loads(v.decode('utf-8')),
-            auto_offset_reset='earliest',
-            group_id='consumer-group-output-topic'  # TODO: make it configurable based on the environment
+            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+            auto_offset_reset="earliest",
+            group_id="consumer-group-output-topic",  # TODO: make it configurable based on the environment
         )

         data_stream = self.message_receiver(consumer)

         if self.output_file.startswith("s3://"):
-            await self._write_to_s3(self.output_file, self.error_file, data_stream, self.pass_through_columns)
+            await self._write_to_s3(
+                self.output_file,
+                self.error_file,
+                data_stream,
+                self.pass_through_columns,
+            )
         else:
-            await self._write_to_local(self.output_file, self.error_file, data_stream, self.pass_through_columns)
-
-    async def _write_to_csv_fileobj(self, fileobj, error_fileobj, data_stream, column_names):
+            await self._write_to_local(
+                self.output_file,
+                self.error_file,
+                data_stream,
+                self.pass_through_columns,
+            )
+
+    async def _write_to_csv_fileobj(
+        self, fileobj, error_fileobj, data_stream, column_names
+    ):
         csv_writer, error_csv_writer = None, None
-        error_columns = ['index', 'message', 'details']
+        error_columns = ["index", "message", "details"]
         while True:
             try:
                 record = await anext(data_stream)
-                if record.get('error') == True:
+                if record.get("error") == True:
                     logger.error(f"Error occurred while processing record: {record}")
                     if error_csv_writer is None:
-                        error_csv_writer = DictWriter(error_fileobj, fieldnames=error_columns)
+                        error_csv_writer = DictWriter(
+                            error_fileobj, fieldnames=error_columns
+                        )
                         error_csv_writer.writeheader()
-                    error_csv_writer.writerow({k: record.get(k, '') for k in error_columns})
+                    error_csv_writer.writerow(
+                        {k: record.get(k, "") for k in error_columns}
+                    )
                 else:
                     if csv_writer is None:
                         if column_names is None:
                             column_names = list(record.keys())
                         csv_writer = DictWriter(fileobj, fieldnames=column_names)
                         csv_writer.writeheader()
-                    csv_writer.writerow({k: record.get(k, '') for k in column_names})
+                    csv_writer.writerow({k: record.get(k, "") for k in column_names})
             except StopAsyncIteration:
                 break

-    async def _write_to_local(self, file_path: str, error_file_path: str, data_stream, column_names):
-        with open(file_path, 'w') as csv_file, open(error_file_path, 'w') as error_file:
-            await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
-
-    async def _write_to_s3(self, s3_uri: str, s3_uri_errors: str, data_stream, column_names):
+    async def _write_to_local(
+        self, file_path: str, error_file_path: str, data_stream, column_names
+    ):
+        with open(file_path, "w") as csv_file, open(error_file_path, "w") as error_file:
+            await self._write_to_csv_fileobj(
+                csv_file, error_file, data_stream, column_names
+            )
+
+    async def _write_to_s3(
+        self, s3_uri: str, s3_uri_errors: str, data_stream, column_names
+    ):
         # Assuming s3_uri format is "s3://bucket-name/path/to/file.csv"
         bucket_name, key = s3_uri.replace("s3://", "").split("/", 1)
         error_bucket_name, error_key = s3_uri_errors.replace("s3://", "").split("/", 1)
-        s3 = boto3.client('s3')
+        s3 = boto3.client("s3")
         with StringIO() as csv_file, StringIO() as error_file:
-            await self._write_to_csv_fileobj(csv_file, error_file, data_stream, column_names)
+            await self._write_to_csv_fileobj(
+                csv_file, error_file, data_stream, column_names
+            )
             s3.put_object(Bucket=bucket_name, Key=key, Body=csv_file.getvalue())
-            s3.put_object(Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue())
+            s3.put_object(
+                Bucket=error_bucket_name, Key=error_key, Body=error_file.getvalue()
+            )

     async def get_feedback(
         self,
@@ -214,4 +246,3 @@ async def restore(self):

     async def save(self):
         raise NotImplementedError("Save is not supported in Kafka environment")
-
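
For context on the pattern that message_receiver and get_data_batch implement, here is a standalone sketch of reading JSON messages from a topic until it goes quiet, using the same aiokafka calls. The broker address, topic, and group id are placeholders, not values from this repository.

    import asyncio
    import json

    from aiokafka import AIOKafkaConsumer

    async def drain_topic(topic: str, bootstrap_servers: str, timeout: int = 3):
        """Collect JSON messages until the topic stays quiet for `timeout` seconds."""
        consumer = AIOKafkaConsumer(
            topic,
            bootstrap_servers=bootstrap_servers,
            value_deserializer=lambda v: json.loads(v.decode("utf-8")),
            auto_offset_reset="earliest",
            group_id="example-consumer-group",  # placeholder group id
        )
        await consumer.start()
        records = []
        try:
            while True:
                try:
                    # getone() returns the next message; wait_for bounds the idle time,
                    # mirroring message_receiver above.
                    msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
                    records.append(msg.value)
                except asyncio.TimeoutError:
                    break
        finally:
            await consumer.stop()
        return records

    # Usage, assuming a broker on localhost:9092 and a topic named "example-input":
    # records = asyncio.run(drain_topic("example-input", "localhost:9092"))
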
3 changes: 1 addition & 2 deletions adala/memories/vectordb.py
@@ -23,8 +23,7 @@ class VectorDBMemory(Memory):
     def init_database(self):
         self._client = chromadb.Client()
         self._embedding_function = embedding_functions.OpenAIEmbeddingFunction(
-            model_name=self.openai_embedding_model,
-            api_key=self.openai_api_key
+            model_name=self.openai_embedding_model, api_key=self.openai_api_key
         )
         self._collection = self._client.get_or_create_collection(
             name=self.db_name, embedding_function=self._embedding_function
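
init_database above wires an in-memory Chroma client to an OpenAI embedding function. A rough standalone sketch of the same chromadb calls follows; the collection name and embedding model are placeholders, and the add/query lines show generic chromadb usage rather than VectorDBMemory's own API.

    import os

    import chromadb
    from chromadb.utils import embedding_functions

    client = chromadb.Client()  # in-memory client, as in init_database
    openai_ef = embedding_functions.OpenAIEmbeddingFunction(
        api_key=os.environ["OPENAI_API_KEY"],  # assumed to be set in the environment
        model_name="text-embedding-ada-002",   # placeholder embedding model
    )
    collection = client.get_or_create_collection(
        name="memory-example",                 # placeholder collection name
        embedding_function=openai_ef,
    )

    # Generic chromadb usage, not necessarily how VectorDBMemory exposes it:
    collection.add(ids=["1", "2"], documents=["first note", "second note"])
    hits = collection.query(query_texts=["note"], n_results=1)
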
32 changes: 18 additions & 14 deletions adala/runtimes/_openai.py
@@ -53,7 +53,9 @@ async def async_create_completion(
     if not semaphore:
         semaphore = asyncio.Semaphore(1)
     if not session:
-        session = aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=default_timeout))
+        session = aiohttp.ClientSession(
+            timeout=aiohttp.ClientTimeout(total=default_timeout)
+        )
     messages = [{"role": "user", "content": user_prompt}]
     if system_prompt:
         if instruction_first:
@@ -64,21 +66,20 @@
     try:
         async with semaphore, session.post(
             DEFAULT_CREATE_COMPLETION_URL,
-            headers={"Authorization": f'Bearer {openai_api_key}'},
+            headers={"Authorization": f"Bearer {openai_api_key}"},
             json={
                 "messages": messages,
                 "model": model,
                 "max_tokens": max_tokens,
                 "temperature": temperature,
-            }
+            },
         ) as response:
-
-            response.raise_for_status()
-            response_json = await response.json()
-            completion_text = response_json["choices"][0]["message"]["content"]
-            return {
-                "text": completion_text,
-            }
+            response.raise_for_status()
+            response_json = await response.json()
+            completion_text = response_json["choices"][0]["message"]["content"]
+            return {
+                "text": completion_text,
+            }
     except aiohttp.ClientResponseError as e:
         # Handle HTTP errors
         return {
@@ -144,6 +145,7 @@ async def async_concurrent_create_completion(
     responses = await asyncio.gather(*tasks)
     return responses

+
 class OpenAIChatRuntime(Runtime):
     """
     Runtime that uses [OpenAI API](https://openai.com/) and chat completion models to perform the skill.
@@ -365,18 +367,21 @@ async def batch_to_batch(

         # parse responses, optionally match it with options
         for prompt, response in zip(prompts, responses):
-
             # check for errors - if any, append to outputs and continue
             if response.get("error"):
                 outputs.append(response)
                 if self.verbose:
-                    print_error(f"Prompt: {prompt}\nOpenAI API error: {response}")
+                    print_error(
+                        f"Prompt: {prompt}\nOpenAI API error: {response}"
+                    )
                 continue

             # otherwise, append the response to outputs
             completion_text = response["text"]
             if self.verbose:
-                print(f"Prompt: {prompt}\nOpenAI API response: {completion_text}")
+                print(
+                    f"Prompt: {prompt}\nOpenAI API response: {completion_text}"
+                )
             if name in options:
                 completion_text = match_options(completion_text, options[name])
             outputs.append({name: completion_text})
@@ -396,7 +401,6 @@ async def record_to_record(
         field_schema: Optional[Dict] = None,
         instructions_first: bool = True,
     ) -> Dict[str, str]:
-
         raise NotImplementedError("record_to_record is not implemented")


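async_create_completion above throttles concurrent requests with a shared asyncio.Semaphore and a single aiohttp session. Here is a minimal sketch of that pattern against the OpenAI chat completions endpoint; the URL is an assumption mirroring DEFAULT_CREATE_COMPLETION_URL, and error handling is omitted for brevity.

    import asyncio

    import aiohttp

    API_URL = "https://api.openai.com/v1/chat/completions"  # assumed endpoint

    async def bounded_completion(session, semaphore, api_key, prompt, model="gpt-3.5-turbo"):
        # The semaphore caps how many requests are in flight at once.
        async with semaphore, session.post(
            API_URL,
            headers={"Authorization": f"Bearer {api_key}"},
            json={"model": model, "messages": [{"role": "user", "content": prompt}]},
        ) as response:
            response.raise_for_status()
            data = await response.json()
            return data["choices"][0]["message"]["content"]

    async def run_all(api_key, prompts, concurrency=5):
        semaphore = asyncio.Semaphore(concurrency)
        timeout = aiohttp.ClientTimeout(total=60)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            tasks = [bounded_completion(session, semaphore, api_key, p) for p in prompts]
            return await asyncio.gather(*tasks)
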
12 changes: 8 additions & 4 deletions adala/skills/skillset.py
@@ -15,7 +15,7 @@
     TransformSkill,
     SampleTransformSkill,
     AnalysisSkill,
-    SynthesisSkill
+    SynthesisSkill,
 )


@@ -57,13 +57,17 @@ def skills_validator(cls, v: Union[List, Dict]) -> Dict[str, Skill]:
         elif isinstance(v[0], dict):
             # convert list of skill dictionaries to dictionary
             for skill in v:
-                if 'type' not in skill:
+                if "type" not in skill:
                     raise ValueError("Skill dictionary must contain a 'type' key")
-                skills[skill["name"]] = Skill.create_from_registry(skill.pop('type'), **skill)
+                skills[skill["name"]] = Skill.create_from_registry(
+                    skill.pop("type"), **skill
+                )
         elif isinstance(v, dict):
             skills = v
         else:
-            raise ValueError(f"skills must be a list or dictionary, but received type {type(v)}")
+            raise ValueError(
+                f"skills must be a list or dictionary, but received type {type(v)}"
+            )
         return skills

     @abstractmethod