feat: DIA-953: Stream results from adala inference server into LSE #75

Merged · 21 commits · Apr 23, 2024

Changes from all commits:
4 changes: 2 additions & 2 deletions Dockerfile.app
```diff
@@ -19,13 +19,13 @@ COPY pyproject.toml poetry.lock ./
 
 # Install dependencies
 RUN poetry config virtualenvs.create false \
-    && poetry install --no-interaction --no-ansi --no-root
+    && poetry install --no-interaction --no-ansi --no-root --with label-studio
 
 COPY . .
 
 # Install adala and the app
 RUN poetry config virtualenvs.create false \
-    && poetry install --no-interaction --no-ansi
+    && poetry install --no-interaction --no-ansi --with label-studio
 
 # Set the working directory in the container to where the app will be run from
 WORKDIR /usr/src/app/server
```
31 changes: 28 additions & 3 deletions adala/environments/kafka.py
```diff
@@ -32,18 +32,41 @@ class AsyncKafkaEnvironment(AsyncEnvironment):
     kafka_input_topic: str
     kafka_output_topic: str
 
+    async def initialize(self):
+        # claim kafka topic from shared pool here?
+        pass
+
+    async def finalize(self):
+        # release kafka topic to shared pool here?
+        pass
+
+    async def get_feedback(
+        self,
+        skills: SkillSet,
+        predictions: InternalDataFrame,
+        num_feedbacks: Optional[int] = None,
+    ) -> EnvironmentFeedback:
+        raise NotImplementedError("Feedback is not supported in Kafka environment")
+
+    async def restore(self):
+        raise NotImplementedError("Restore is not supported in Kafka environment")
+
+    async def save(self):
+        raise NotImplementedError("Save is not supported in Kafka environment")
+
     async def message_receiver(self, consumer: AIOKafkaConsumer, timeout: int = 3):
         await consumer.start()
         try:
             while True:
                 try:
                     # Wait for the next message with a timeout
                     msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
+                    # print_text(f"Received message: {msg.value}")
                     yield msg.value
                 except asyncio.TimeoutError:
-                    print_text(
-                        f"No message received within the timeout {timeout} seconds"
-                    )
+                    # print_text(
+                    #     f"No message received within the timeout {timeout} seconds"
+                    # )
                     break
         finally:
             await consumer.stop()
```
```diff
@@ -55,8 +78,10 @@ async def message_sender(
         try:
             for record in data:
                 await producer.send_and_wait(topic, value=record)
+                # print_text(f"Sent message: {record} to {topic=}")
         finally:
             await producer.stop()
+            # print_text(f"No more messages for {topic=}")
 
     async def get_next_batch(self, data_iterator, batch_size: int) -> List[Dict]:
         batch = []
```
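To make the consumption pattern concrete, here is a minimal sketch (not part of the PR) of the same timeout-based drain that `message_receiver` implements, run directly against aiokafka. The topic name, broker address, and JSON message encoding are assumptions for illustration:

```python
import asyncio
import json

from aiokafka import AIOKafkaConsumer


async def drain(topic: str = "adala-output", timeout: int = 3):
    # Mirrors message_receiver: read until the topic stays idle for `timeout` seconds
    consumer = AIOKafkaConsumer(
        topic,
        bootstrap_servers="localhost:9092",  # assumed local broker
        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
    )
    await consumer.start()
    try:
        while True:
            try:
                msg = await asyncio.wait_for(consumer.getone(), timeout=timeout)
                print(msg.value)
            except asyncio.TimeoutError:
                # No message within the timeout; treat the stream as drained
                break
    finally:
        await consumer.stop()


asyncio.run(drain())
```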
1 change: 0 additions & 1 deletion docker-compose.yml
```diff
@@ -1,5 +1,4 @@
 # docker-compose.yml
-version: "3.8"
 services:
   kafka:
     restart: always
```
216 changes: 215 additions & 1 deletion poetry.lock

Large diffs are not rendered by default.

6 changes: 6 additions & 0 deletions pyproject.toml
```diff
@@ -51,6 +51,12 @@ jupyterlab = "^4.0.10"
 jupyter-client = "8.4.0"
 matplotlib = "^3.7.4"
 
+[tool.poetry.group.label-studio]
+optional = true
+
+[tool.poetry.group.label-studio.dependencies]
+label-studio-sdk = "^0.0.32"
+
 [tool.poetry.scripts]
 adala = "adala.cli:main"
```
54 changes: 28 additions & 26 deletions server/app.py
```diff
@@ -10,7 +10,7 @@
 from aiokafka import AIOKafkaProducer
 from fastapi import HTTPException
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel
+from pydantic import BaseModel, SerializeAsAny, field_validator
 from pydantic_settings import BaseSettings, SettingsConfigDict
 from pydantic.functional_validators import AfterValidator
 from typing_extensions import Annotated
```
```diff
@@ -24,25 +24,13 @@
     process_file_streaming,
     process_streaming_output,
 )
-from utils import get_input_topic, ResultHandler, Settings
+from utils import get_input_topic, Settings
+from server.handlers.result_handlers import ResultHandler
 
 
 logger = logging.getLogger(__name__)
 
 
-class Settings(BaseSettings):
-    """
-    Can hardcode settings here, read from env file, or pass as env vars
-    https://docs.pydantic.dev/latest/concepts/pydantic_settings/#field-value-priority
-    """
-
-    kafka_bootstrap_servers: Union[str, List[str]]
-
-    model_config = SettingsConfigDict(
-        env_file=".env",
-    )
-
-
 settings = Settings()
 
 app = fastapi.FastAPI()
```
```diff
@@ -148,13 +136,27 @@ class SubmitRequest(BaseModel):
 class SubmitStreamingRequest(BaseModel):
     """
     Request model for submitting a streaming job.
-    Only difference from SubmitRequest is the task_name
     """
 
     agent: Agent
-    result_handler: str
+    # SerializeAsAny allows for subclasses of ResultHandler
+    result_handler: SerializeAsAny[ResultHandler]
     task_name: str = "process_file_streaming"
 
+    @field_validator("result_handler", mode="before")
+    def validate_result_handler(cls, value: Dict) -> ResultHandler:
+        """
+        Allows polymorphism for ResultHandlers created from a dict; same implementation as the Skills, Environment, and Runtime within an Agent
+
+        "type" is the name of the subclass of ResultHandler being used. Currently available subclasses: LSEHandler, DummyHandler
+        Look in server/handlers/result_handlers.py for available subclasses
+        """
+        if "type" not in value:
+            raise HTTPException(
+                status_code=400, detail="Missing type in result_handler"
+            )
+        result_handler = ResultHandler.create_from_registry(value.pop("type"), **value)
+        return result_handler
 
 
 class BatchData(BaseModel):
     """
```
```diff
@@ -184,10 +186,10 @@ async def submit(request: SubmitRequest):
 
     # TODO: get task by name, e.g. request.task_name
     task = process_file
-    serialized_agent = pickle.dumps(request.agent)
+    agent = request.agent
 
-    logger.debug(f"Submitting task {task.name} with agent {serialized_agent}")
-    result = task.delay(serialized_agent=serialized_agent)
+    logger.info(f"Submitting task {task.name} with agent {agent}")
+    result = task.delay(agent=agent)
     logger.debug(f"Task {task.name} submitted with job_id {result.id}")
 
     return Response[JobCreated](data=JobCreated(job_id=result.id))
```
```diff
@@ -207,20 +209,20 @@ async def submit_streaming(request: SubmitStreamingRequest):
 
     # TODO: get task by name, e.g. request.task_name
     task = process_file_streaming
-    serialized_agent = pickle.dumps(request.agent)
+    agent = request.agent
 
-    logger.info(f"Submitting task {task.name} with agent {serialized_agent}")
-    input_result = task.delay(serialized_agent=serialized_agent)
+    logger.info(f"Submitting task {task.name} with agent {agent}")
+    input_result = task.delay(agent=agent)
     input_job_id = input_result.id
-    logger.info(f"Task {task.name} submitted with job_id {input_job_id}")
+    logger.debug(f"Task {task.name} submitted with job_id {input_job_id}")
 
     task = process_streaming_output
-    logger.info(f"Submitting task {task.name}")
+    logger.debug(f"Submitting task {task.name}")
     output_result = task.delay(
         job_id=input_job_id, result_handler=request.result_handler
     )
     output_job_id = output_result.id
-    logger.info(f"Task {task.name} submitted with job_id {output_job_id}")
+    logger.debug(f"Task {task.name} submitted with job_id {output_job_id}")
 
     return Response[JobCreated](data=JobCreated(job_id=input_job_id))
```

Empty file added server/handlers/__init__.py
107 changes: 107 additions & 0 deletions server/handlers/result_handlers.py
````python
from typing import Optional
import logging
import json
from abc import abstractmethod
from pydantic import computed_field, ConfigDict, model_validator

from adala.utils.registry import BaseModelInRegistry


logger = logging.getLogger(__name__)

try:
    from label_studio_sdk import Client as LSEClient
except ImportError:
    logger.warning(
        "Label Studio SDK not found. LSEHandler will not be available. Run `poetry install --with label-studio` to fix"
    )

    class LSEClient:
        def __init__(self, *args, **kwargs):
            logger.error(
                "Label Studio SDK not found. LSEHandler is not available. Run `poetry install --with label-studio` to fix"
            )


class ResultHandler(BaseModelInRegistry):
    """
    Abstract base class for a result handler.

    This is a callable that is instantiated in `/submit-streaming` with any
    arguments that are needed, and then is called on each batch of results when
    it is finished being processed by the Agent (it consumes from the Kafka
    topic that the Agent produces to).

    It can be used as a connector to load results into a file or external
    service. If a ResultHandler is not used, the results will be discarded.

    Subclasses must implement the `__call__` method.

    The BaseModelInRegistry base class implements a factory pattern, allowing
    the "type" parameter to specify which subclass of ResultHandler to
    instantiate. For example:
    ```json
    result_handler: {
        "type": "DummyHandler",
        "other_model_field": "other_model_value",
        ...
    }
    ```
    """

    @abstractmethod
    def __call__(self, result_batch: list[dict]) -> None:
        """
        Callable to do something with a batch of results.
        """
        pass


class DummyHandler(ResultHandler):
    """
    Dummy handler to test streaming output flow
    """

    def __call__(self, batch):
        logger.info(f"\n\nHandler received batch: {batch}\n\n")


class LSEHandler(ResultHandler):
    """
    Handler to use the Label Studio SDK to load a batch of results back into a
    Label Studio project
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)  # for @computed_field

    api_key: str
    url: str
    modelrun_id: int

    @computed_field
    def client(self) -> LSEClient:
        _client = LSEClient(
            api_key=self.api_key,
            url=self.url,
        )
        # Need this to make POST requests using the SDK client
        # TODO headers can only be set in this function, since client is a
        # computed field. Need to rethink approach if we make non-POST requests;
        # should probably just make a PR in label_studio_sdk to allow setting
        # this in make_request()
        _client.headers.update(
            {
                "accept": "application/json",
                "Content-Type": "application/json",
            }
        )
        return _client

    @model_validator(mode="after")
    def ready(self):
        conn = self.client.check_connection()
        assert conn["status"] == "UP", "Label Studio is not available"

        return self

    def __call__(self, result_batch):
        logger.info(f"\n\nHandler received batch: {result_batch}\n\n")
        self.client.make_request(
            "POST",
            "/api/model-run/batch-predictions",
            data=json.dumps(
                {
                    "modelrun_id": self.modelrun_id,
                    "results": result_batch,
                }
            ),
        )
````
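As a usage note, adding another handler should only require subclassing ResultHandler. The sketch below is a hypothetical JSONLFileHandler, not part of this PR, and assumes subclasses are auto-registered under their class name the same way DummyHandler and LSEHandler are:

```python
import json


class JSONLFileHandler(ResultHandler):
    # Hypothetical handler for illustration: appends each result batch
    # to a local JSONL file instead of posting it to an external service.
    path: str = "results.jsonl"

    def __call__(self, result_batch: list[dict]) -> None:
        with open(self.path, "a") as f:
            for record in result_batch:
                f.write(json.dumps(record) + "\n")
```

It could then be selected from a request body via the registry, e.g. `{"type": "JSONLFileHandler", "path": "/tmp/results.jsonl"}`.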