Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: DIA-953: Stream results from adala inference server into LSE #75

Merged
merged 21 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions adala/environments/kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,36 @@ class AsyncKafkaEnvironment(AsyncEnvironment):
kafka_input_topic: str
kafka_output_topic: str

async def initialize(self):
pakelley marked this conversation as resolved.
Show resolved Hide resolved
# claim kafka topic from shared pool here?
pass

async def finalize(self):
# release kafka topic to shared pool here?
pass

async def get_feedback(
self,
skills: SkillSet,
predictions: InternalDataFrame,
num_feedbacks: Optional[int] = None,
) -> EnvironmentFeedback:
raise NotImplementedError("Feedback is not supported in Kafka environment")

async def restore(self):
raise NotImplementedError("Restore is not supported in Kafka environment")

async def save(self):
raise NotImplementedError("Save is not supported in Kafka environment")

async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
await consumer.start()
try:
while True:
try:
# Wait for the next message with a timeout
msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
print_text(f"Received message: {msg.value}")
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
yield msg.value
except asyncio.TimeoutError:
print_text(
Expand All @@ -55,8 +78,10 @@ async def message_sender(
try:
for record in data:
await producer.send_and_wait(topic, value=record)
print_text(f"Sent message: {record} to {topic=}")
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
finally:
await producer.stop()
print_text(f"No more messages for {topic=}")

async def get_next_batch(self, data_iterator, batch_size: int) -> List[Dict]:
batch = []
Expand Down
216 changes: 215 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ fastapi = "^0.104.1"
celery = {version = "^5.3.6", extras = ["redis"]}
uvicorn = "*"
pydantic-settings = "^2.2.1"
label-studio-sdk = "^0.0.32"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did we land on making this an optional dependency?
I know poetry doesn't directly support this, but imo it's probably better to document that this dep is necessary if you want to use the LSE result handler (and then we can include it in our deployment, but not force users to have it to use adala)

Copy link
Contributor Author

@matt-bernstein matt-bernstein Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

from comment above:

# these are for the server
# they would be installed as `extras` if poetry supported version strings for extras, but it doesn't
# https://github.com/python-poetry/poetry/issues/834
# they also can't be installed as a `group`, because those are for dev dependencies only and could not be included if this package was pip-installed

Are we ok with everyone (including external users/contributors) needing to clone this repo in order to use the server in stead of pip installing adala? I could see it going either way.

Copy link
Contributor

@pakelley pakelley Apr 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed that the server dependencies are questionable here too, and our deps are in a questionable state no matte what, but this one at least seems to be stemming from a different issue:

  • the server deps should only be required by users of the server. These be in their own section (or repo/package), and the fix for that would eventually be breaking the server out into it's own repo/package
  • the LSE SDK dep should be an optional dep even for users of the server, and even if/when we break that into its own repo, we'll still want it to be separate. (Granted, we'd still want to use extras from poetry, which we won't be able to. But since it's currently only one additional dep to install for the LSE handler, I think it's preferable to just document that you need to have label-studio-sdk installed to use the handler than it is to force anyone who wants the server to have label-studio-sdk installed)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see what you mean - can do something like put the label_studio_sdk import inside this class, catch and surface the error from it not being available, and update our deployment setup to manually add the dep.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, not sure what needs to be done to install adala like this in deployment env:
poetry install --with label-studio

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to tell plate about it probably as well update our own dockerfile

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added to dockerfile

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@farioas we're adding the label studio SDK as an optional dep to Adala, is there anything that needs to be done manually here for our dev instance of adala to keep working? The dockerfile is already updated in this branch.


[tool.poetry.dev-dependencies]
pytest = "^7.4.3"
Expand Down
46 changes: 23 additions & 23 deletions server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from aiokafka import AIOKafkaProducer
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from pydantic import BaseModel, SerializeAsAny, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic.functional_validators import AfterValidator
from typing_extensions import Annotated
Expand All @@ -24,25 +24,13 @@
process_file_streaming,
process_streaming_output,
)
from utils import get_input_topic, ResultHandler, Settings
from utils import get_input_topic, Settings
from server.handlers.result_handlers import ResultHandler


logger = logging.getLogger(__name__)


class Settings(BaseSettings):
"""
Can hardcode settings here, read from env file, or pass as env vars
https://docs.pydantic.dev/latest/concepts/pydantic_settings/#field-value-priority
"""

kafka_bootstrap_servers: Union[str, List[str]]

model_config = SettingsConfigDict(
env_file=".env",
)


settings = Settings()

app = fastapi.FastAPI()
Expand Down Expand Up @@ -148,13 +136,25 @@ class SubmitRequest(BaseModel):
class SubmitStreamingRequest(BaseModel):
"""
Request model for submitting a streaming job.
Only difference from SubmitRequest is the task_name
"""

agent: Agent
result_handler: str
# SerializeAsAny is for allowing subclasses of ResultHandler
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
result_handler: SerializeAsAny[ResultHandler]
task_name: str = "process_file_streaming"

@field_validator("result_handler", mode="before")
def validate_result_handler(cls, value: Dict) -> ResultHandler:
"""
Allows polymorphism for ResultHandlers created from a dict; same implementation as the Skills, Environment, and Runtime within an Agent
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
"""
if "type" not in value:
raise HTTPException(
status_code=400, detail="Missing type in result_handler"
)
result_handler = ResultHandler.create_from_registry(value.pop("type"), **value)
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
return result_handler


class BatchData(BaseModel):
"""
Expand Down Expand Up @@ -184,10 +184,10 @@ async def submit(request: SubmitRequest):

# TODO: get task by name, e.g. request.task_name
task = process_file
serialized_agent = pickle.dumps(request.agent)
agent = request.agent

logger.debug(f"Submitting task {task.name} with agent {serialized_agent}")
result = task.delay(serialized_agent=serialized_agent)
logger.debug(f"Submitting task {task.name} with agent {agent}")
pakelley marked this conversation as resolved.
Show resolved Hide resolved
result = task.delay(agent=agent)
logger.debug(f"Task {task.name} submitted with job_id {result.id}")

return Response[JobCreated](data=JobCreated(job_id=result.id))
Expand All @@ -207,10 +207,10 @@ async def submit_streaming(request: SubmitStreamingRequest):

# TODO: get task by name, e.g. request.task_name
task = process_file_streaming
serialized_agent = pickle.dumps(request.agent)
agent = request.agent

logger.info(f"Submitting task {task.name} with agent {serialized_agent}")
input_result = task.delay(serialized_agent=serialized_agent)
logger.info(f"Submitting task {task.name} with agent {agent}")
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
input_result = task.delay(agent=agent)
input_job_id = input_result.id
logger.info(f"Task {task.name} submitted with job_id {input_job_id}")

Expand Down
Empty file added server/handlers/__init__.py
Empty file.
77 changes: 77 additions & 0 deletions server/handlers/result_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Optional
import logging
import json
from abc import abstractmethod
from pydantic import computed_field, ConfigDict, model_validator

from adala.utils.registry import BaseModelInRegistry
from label_studio_sdk import Client


logger = logging.getLogger(__name__)


class ResultHandler(BaseModelInRegistry):
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
@abstractmethod
def __call__(self, batch):
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
"""
Callable to do something with a batch of results.
"""
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved


class DummyHandler(ResultHandler):
"""
Dummy handler to test streaming output flow
Can delete once we have a real handler
"""

def __call__(self, batch):
logger.info(f"\n\nHandler received batch: {batch}\n\n")


class LSEHandler(ResultHandler):
"""
Handler to use the Label Studio SDK to load a batch of results back into a Label Studio project
"""

model_config = ConfigDict(arbitrary_types_allowed=True) # for @computed_field

api_key: str
url: str
modelrun_id: int

@computed_field
def client(self) -> Client:
_client = Client(
api_key=self.api_key,
url=self.url,
)
# Need this to make POST requests using the SDK client
# TODO headers can only be set in this function, since client is a computed field. Need to rethink approach if we make non-POST requests, should probably just make a PR in label_studio_sdk to allow setting this in make_request()
_client.headers.update(
{
"accept": "application/json",
"Content-Type": "application/json",
}
)
return _client

@model_validator(mode="after")
def ready(self):
conn = self.client.check_connection()
assert conn["status"] == "UP", "Label Studio is not available"

return self

def __call__(self, batch):
logger.info(f"\n\nHandler received batch: {batch}\n\n")
self.client.make_request(
"POST",
"/api/model-run/batch-predictions",
data=json.dumps(
{
"modelrun_id": self.modelrun_id,
pakelley marked this conversation as resolved.
Show resolved Hide resolved
"results": batch,
}
),
)
50 changes: 28 additions & 22 deletions server/tasks/process_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,25 @@
import os
import logging

from adala.agents import Agent

from aiokafka import AIOKafkaConsumer
from celery import Celery, states
from celery.exceptions import Ignore
from server.utils import get_input_topic, get_output_topic, ResultHandler, Settings
from server.utils import get_input_topic, get_output_topic, Settings
from server.handlers.result_handlers import ResultHandler


logger = logging.getLogger(__name__)

REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379/0")
app = Celery("worker", broker=REDIS_URL, backend=REDIS_URL)
app = Celery(
"worker", broker=REDIS_URL, backend=REDIS_URL, accept_content=["json", "pickle"]
)


@app.task(name="process_file", track_started=True)
def process_file(serialized_agent: bytes):
# Load the agent
agent = pickle.loads(serialized_agent)
@app.task(name="process_file", track_started=True, serializer="pickle")
def process_file(agent: Agent):
# # Read data from a file and send it to the Kafka input topic
asyncio.run(agent.environment.initialize())

Expand All @@ -30,11 +33,10 @@ def process_file(serialized_agent: bytes):
asyncio.run(agent.environment.finalize())


@app.task(name="process_file_streaming", track_started=True, bind=True)
def process_file_streaming(self, serialized_agent: bytes):
# Load the agent
agent = pickle.loads(serialized_agent)

@app.task(
name="process_file_streaming", track_started=True, bind=True, serializer="pickle"
)
def process_file_streaming(self, agent: Agent):
# Get own job ID to set Consumer topic accordingly
job_id = self.request.id
agent.environment.kafka_input_topic = get_input_topic(job_id)
Expand All @@ -45,16 +47,10 @@ def process_file_streaming(self, serialized_agent: bytes):


async def async_process_streaming_output(
input_job_id: str, result_handler: str, batch_size: int
input_job_id: str, result_handler: ResultHandler, batch_size: int
):
logger.info(f"Polling for results {input_job_id=}")

try:
result_handler = ResultHandler.__dict__[result_handler]
except KeyError as e:
logger.error(f"{result_handler} is not a valid ResultHandler")
raise e

topic = get_output_topic(input_job_id)
settings = Settings()

Expand All @@ -74,11 +70,17 @@ async def async_process_streaming_output(
data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
for tp, messages in data.items():
if messages:
result_handler(messages)
logger.info(f"Handling {messages=} in topic {tp.topic}")
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
data = [msg.value for msg in messages]
result_handler(data)
logger.info(f"Handled {len(messages)} messages in topic {tp.topic}")
else:
logger.info(f"No messages in topic {tp.topic}")
if not data:
logger.info(f"No messages in any topic")
finally:
job = process_file_streaming.AsyncResult(input_job_id)
# TODO no way to recover here if connection to main app is lost, job will be stuck at "PENDING" so this will loop forever
pakelley marked this conversation as resolved.
Show resolved Hide resolved
if job.status in ["SUCCESS", "FAILURE", "REVOKED"]:
input_job_running = False
logger.info(f"Input job done, stopping output job")
Expand All @@ -88,16 +90,20 @@ async def async_process_streaming_output(
await consumer.stop()


@app.task(name="process_streaming_output", track_started=True, bind=True)
@app.task(
name="process_streaming_output", track_started=True, bind=True, serializer="pickle"
)
def process_streaming_output(
self, job_id: str, result_handler: str, batch_size: int = 2
self, job_id: str, result_handler: ResultHandler, batch_size: int = 2
):
try:
asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
except KeyError:
except Exception as e:
# Set own status to failure
self.update_state(state=states.FAILURE)

logger.log(level=logging.ERROR, msg=e)
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved

# Ignore the task so no other state is recorded
# TODO check if this updates state to Ignored, or keeps Failed
raise Ignore()
23 changes: 4 additions & 19 deletions server/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import logging

from enum import Enum
# from enum import Enum
from pydantic_settings import BaseSettings, SettingsConfigDict
matt-bernstein marked this conversation as resolved.
Show resolved Hide resolved
from typing import List, Union

logger = logging.getLogger(__name__)
from pathlib import Path


class Settings(BaseSettings):
Expand All @@ -16,23 +13,11 @@ class Settings(BaseSettings):
kafka_bootstrap_servers: Union[str, List[str]]

model_config = SettingsConfigDict(
env_file=".env",
# have to use an absolute path here so celery workers can find it
env_file=(Path(__file__).parent / ".env"),
)


def dummy_handler(batch):
"""
Dummy handler to test streaming output flow
Can delete once we have a real handler
"""

logger.info(f"\n\nHandler received batch: {batch}\n\n")


class ResultHandler(Enum):
DUMMY = dummy_handler


def get_input_topic(job_id: str):
return f"adala-input-{job_id}"

Expand Down
Loading