diff --git a/docker-compose.yml b/docker-compose.yml index f2c679e9..07ada6d0 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -31,8 +31,8 @@ services: redis: condition: service_healthy environment: - - REDIS_URL=redis://redis:6379/0 - KAFKA_BOOTSTRAP_SERVERS=kafka:9093 + - REDIS_URL=redis://redis:6379/0 command: ["poetry", "run", "uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"] worker: diff --git a/server/app.py b/server/app.py index a8171e0e..3320ca90 100644 --- a/server/app.py +++ b/server/app.py @@ -19,8 +19,13 @@ from log_middleware import LogMiddleware from tasks.process_file import app as celery_app -from tasks.process_file import process_file, process_file_streaming -from utils import get_input_topic +from tasks.process_file import ( + process_file, + process_file_streaming, + process_streaming_output, +) +from utils import get_input_topic, ResultHandler, Settings + logger = logging.getLogger(__name__) @@ -147,6 +152,7 @@ class SubmitStreamingRequest(BaseModel): """ agent: Agent + result_handler: str task_name: str = "process_file_streaming" @@ -204,10 +210,19 @@ async def submit_streaming(request: SubmitStreamingRequest): serialized_agent = pickle.dumps(request.agent) logger.info(f"Submitting task {task.name} with agent {serialized_agent}") - result = task.delay(serialized_agent=serialized_agent) - print(f"Task {task.name} submitted with job_id {result.id}") + input_result = task.delay(serialized_agent=serialized_agent) + input_job_id = input_result.id + logger.info(f"Task {task.name} submitted with job_id {input_job_id}") + + task = process_streaming_output + logger.info(f"Submitting task {task.name}") + output_result = task.delay( + job_id=input_job_id, result_handler=request.result_handler + ) + output_job_id = output_result.id + logger.info(f"Task {task.name} submitted with job_id {output_job_id}") - return Response[JobCreated](data=JobCreated(job_id=result.id)) + return Response[JobCreated](data=JobCreated(job_id=input_job_id)) 
@app.post("/jobs/submit-batch", response_model=Response)
diff --git a/server/tasks/process_file.py b/server/tasks/process_file.py
index b0d110ce..0aa9b6c6 100644
--- a/server/tasks/process_file.py
+++ b/server/tasks/process_file.py
@@ -1,9 +1,13 @@
 import asyncio
+import json
 import pickle
 import os
 import logging
-from celery import Celery
-from server.utils import get_input_topic, get_output_topic
+
+from aiokafka import AIOKafkaConsumer
+from celery import Celery, states
+from celery.exceptions import Ignore
+from server.utils import get_input_topic, get_output_topic, ResultHandler, Settings
 
 logger = logging.getLogger(__name__)
 
@@ -38,3 +42,98 @@ def process_file_streaming(self, serialized_agent: bytes):
 
     # Run the agent
     asyncio.run(agent.arun())
+
+
+async def async_process_streaming_output(
+    input_job_id: str, result_handler: str, batch_size: int
+):
+    """Forward the output of a streaming job from Kafka to a result handler.
+
+    Polls the output topic of ``input_job_id`` and hands every non-empty batch
+    of messages to the handler named by ``result_handler``. Polling goes on
+    until the input job reaches a terminal Celery state and a poll comes back
+    empty, so records still queued when the job finishes are not dropped.
+
+    Args:
+        input_job_id: Celery id of the ``process_file_streaming`` job.
+        result_handler: Name of a callable attribute of ``ResultHandler``.
+        batch_size: Maximum number of records requested per poll.
+
+    Raises:
+        KeyError: If ``result_handler`` does not name a usable handler.
+    """
+    logger.info(f"Polling for results {input_job_id=}")
+
+    # The name arrives in the request payload: refuse private/dunder entries
+    # (e.g. "__module__") and resolve through getattr so staticmethod and
+    # classmethod attributes come back as real callables.
+    handler = None
+    if not result_handler.startswith("_") and result_handler in vars(ResultHandler):
+        handler = getattr(ResultHandler, result_handler)
+    if not callable(handler):
+        logger.error(f"{result_handler} is not a valid ResultHandler")
+        raise KeyError(result_handler)
+
+    topic = get_output_topic(input_job_id)
+    settings = Settings()
+
+    consumer = AIOKafkaConsumer(
+        topic,
+        bootstrap_servers=settings.kafka_bootstrap_servers,
+        value_deserializer=lambda v: json.loads(v.decode("utf-8")),
+        auto_offset_reset="earliest",
+    )
+    await consumer.start()
+    logger.info(f"consumer started {input_job_id=}")
+
+    try:
+        while True:
+            data = await consumer.getmany(timeout_ms=3000, max_records=batch_size)
+            received = False
+            for tp, messages in data.items():
+                if messages:
+                    handler(messages)
+                    received = True
+                else:
+                    logger.info(f"No messages in topic {tp.topic}")
+
+            job = process_file_streaming.AsyncResult(input_job_id)
+            if job.status not in states.READY_STATES:
+                logger.info("Input job still running, keeping output job running")
+            elif received:
+                # The producer is done but the topic may hold a backlog
+                # (polls are capped by batch_size); keep reading until empty.
+                logger.info("Input job done, draining remaining output")
+            else:
+                logger.info("Input job done, stopping output job")
+                break
+    finally:
+        # Release the consumer even when a poll or the handler raised.
+        await consumer.stop()
+
+
+@app.task(name="process_streaming_output", track_started=True, bind=True)
+def process_streaming_output(
+    self, job_id: str, result_handler: str, batch_size: int = 2
+):
+    """Celery task: stream the output of job ``job_id`` to ``result_handler``.
+
+    Args:
+        job_id: Celery id of the ``process_file_streaming`` job to follow.
+        result_handler: Name of the ``ResultHandler`` attribute to call.
+        batch_size: Maximum number of records requested per poll.
+
+    Raises:
+        Ignore: After a ``KeyError``, once the task state is set to FAILURE.
+    """
+    try:
+        asyncio.run(async_process_streaming_output(job_id, result_handler, batch_size))
+    except KeyError:
+        # Log the traceback here; the ignored task result will not carry it.
+        logger.exception(f"Output job for {job_id=} failed")
+        # Set own status to failure
+        self.update_state(state=states.FAILURE)
+
+        # Raising Ignore makes the worker record no further state for this
+        # task, so the FAILURE stored above is the one that is kept.
+        raise Ignore()