Skip to content

Commit

Permalink
Merge pull request #34 from marzukr/stateless-api
Browse files Browse the repository at this point in the history
Stateless API
  • Loading branch information
marzukr authored Dec 16, 2024
2 parents 86637ed + 7478ef4 commit b35bbff
Show file tree
Hide file tree
Showing 19 changed files with 267 additions and 187 deletions.
5 changes: 2 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ DD_API_KEY=PLACEHOLDER
DD_ENV=prod
MAX_THREADS=1
STRATEGY_LOCATION=PLACEHOLDER
MAX_SESSIONS=20
SESSION_ID_BYTES=8
SESSION_TIMEOUT=1800 # in seconds
SESSION_TIMEOUT=86400 # in seconds
PIGEON_EXECUTION_TIMEOUT=5 # in seconds
REDIS_MAXMEMORY=256mb
File renamed without changes.
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,12 @@ poker. Inspired by the Pluribus poker bot developed by CMU and Facebook.
```sh
python -m venv api/venv
api/venv/bin/pip install -r api/requirements.txt
cp .githooks/pre-commit .git/hooks/pre-commit
cp .githooks/pre-push .git/hooks/pre-push
git submodule update --init --recursive
cp ai/src/mccfr/hyperparameters.h.dev ai/src/mccfr/hyperparameters.h
docker compose -f docker-compose.dev.yml build ai
docker run --rm -v ./ai/out:/build/out fishbait-ai /ai/dev_blueprint.sh
docker compose -f docker-compose.dev.yml build
docker compose -f docker-compose.dev.yml up
docker compose -f docker-compose.dev.yml up --build
```

## Deployment
Expand All @@ -22,7 +21,7 @@ docker compose -f docker-compose.dev.yml up
3. `cp nginx.conf.example nginx.conf`
* Set the `server_name` property to be the deployment url of the interface
4. `cp .env.example .env` and configure
5. `docker compose -f docker-compose.prod.yml up`
5. `docker compose -f docker-compose.prod.yml up --build -d`
6. Configure HTTPS with AWS Application Load Balancer

## Testing
Expand Down
9 changes: 2 additions & 7 deletions api/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,5 @@ RUN python3 -m venv venv
COPY requirements.txt .
RUN venv/bin/pip install --no-cache-dir -r requirements.txt

COPY entrypoint.dev.sh .
RUN chmod +x entrypoint.dev.sh

COPY entrypoint.prod.sh .
RUN chmod +x entrypoint.prod.sh

RUN mkdir /api/out
COPY ./entrypoints ./entrypoints
RUN chmod +x ./entrypoints/*
7 changes: 1 addition & 6 deletions api/entrypoint.dev.sh → api/entrypoints/api.dev.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#!/bin/sh
set -e

while [ ! -f /libvol/done ]; do
echo "Waiting for AI shared library..."
sleep 2
done

rm /libvol/done
export LD_LIBRARY_PATH=/libvol/lib:$LD_LIBRARY_PATH

cd /api/src
Expand Down
7 changes: 1 addition & 6 deletions api/entrypoint.prod.sh → api/entrypoints/api.prod.sh
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
#!/bin/sh
set -e

while [ ! -f /libvol/done ]; do
echo "Waiting for AI shared library..."
sleep 2
done

rm /libvol/done
export LD_LIBRARY_PATH=/libvol/lib:$LD_LIBRARY_PATH

cd /api/src
Expand Down
7 changes: 7 additions & 0 deletions api/entrypoints/worker.dev.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh
set -e

export LD_LIBRARY_PATH=/libvol/lib:$LD_LIBRARY_PATH

cd /api/src
/api/venv/bin/celery -A tasks worker --loglevel=INFO
7 changes: 7 additions & 0 deletions api/entrypoints/worker.prod.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh
set -e

export LD_LIBRARY_PATH=/libvol/lib:$LD_LIBRARY_PATH

cd /api/src
/api/venv/bin/ddtrace-run /api/venv/bin/celery -A tasks worker --loglevel=INFO
13 changes: 13 additions & 0 deletions api/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,16 @@
amqp==5.3.1
astroid==3.3.6
attrs==23.2.0
billiard==4.2.1
bytecode==0.15.1
cattrs==23.2.3
celery==5.4.0
certifi==2024.2.2
charset-normalizer==3.3.2
click==8.1.3
click-didyoumean==0.3.1
click-plugins==1.1.1
click-repl==0.3.0
cpplint==1.6.1
datadog==0.44.0
ddsketch==2.0.4
Expand All @@ -20,22 +26,29 @@ importlib-metadata==6.11.0
isort==5.10.1
itsdangerous==2.1.2
Jinja2==3.1.2
kombu==5.4.2
lazy-object-proxy==1.7.1
MarkupSafe==2.1.1
mccabe==0.7.0
mypy==0.981
mypy-extensions==0.4.3
opentelemetry-api==1.22.0
platformdirs==2.5.2
prompt_toolkit==3.0.48
protobuf==4.25.2
pylint==3.3.2
python-dateutil==2.9.0.post0
redis==5.2.1
requests==2.31.0
setuptools==75.6.0
six==1.16.0
tomli==2.0.1
tomlkit==0.11.4
typing_extensions==4.3.0
tzdata==2024.2
urllib3==2.2.0
vine==5.1.0
wcwidth==0.2.13
Werkzeug==2.2.2
wrapt==1.14.1
xmltodict==0.13.0
Expand Down
84 changes: 18 additions & 66 deletions api/src/app.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,35 @@
"""Webserver for the FISHBAIT web interface."""

from datetime import datetime, timezone
from multiprocessing.managers import RemoteError
import re
from typing import Any
import secrets
import os

from flask import Flask, request, make_response
from werkzeug.exceptions import BadRequest
import datadog
from datadog import statsd
import redis

from pigeon import PigeonProxy, PigeonInterface
from pigeon import PigeonInterface
import settings
import error
from error import (
ApiError, ServerOverloadedError, MissingSessionIdError, UnknownSessionIdError,
ValidationError
)
from props import (
SetHandProps, ApplyProps, SetBoardProps, ResetProps, JoinEmailListProps
SetHandProps, ApplyProps, SetBoardProps, ResetProps
)
from utils import get_logger
import tasks
import sessions

app = Flask(__name__)
log = get_logger(__name__)
if os.getenv('DD_SERVICE'):
datadog.initialize()

def handle_api_error(e: ApiError):
log.exception(e)
return e.flask_tuple()

def handle_remote_error(e: RemoteError):
match = re.search(r'error\.([A-z]*?): ((.|\n)*)\n-{75}', f'{e}')
if match is None:
log.error('Could not parse PigeonProxy error')
raise ApiError() from e
error_name, error_msg = match.group(1, 2)
error_type = getattr(error, error_name)
parsed_error = error_type(error_msg)
raise parsed_error from e

def handle_bad_request(e: BadRequest):
raise ValidationError() from e

Expand All @@ -51,8 +38,6 @@ def handle_exceptions(e: Exception):
error_handler: Any = None
if isinstance(e, ApiError):
error_handler = handle_api_error
elif isinstance(e, RemoteError):
error_handler = handle_remote_error
elif isinstance(e, BadRequest):
error_handler = handle_bad_request
else:
Expand All @@ -64,11 +49,8 @@ def handle_exceptions(e: Exception):
log.exception(e)
return result
except Exception as new_exc: # pylint: disable=broad-except
if type(e) is type(new_exc):
log.exception(new_exc)
rec_er = RecursionError('Could not reduce the encountered error')
raise rec_er from new_exc
return handle_exceptions(new_exc)
log.exception('Exception while handling exception: %s', new_exc)
return ApiError().flask_tuple()

@app.before_request
def record_api_metric():
Expand All @@ -78,62 +60,31 @@ def record_api_metric():
# Session Management -----------------------------------------------------------
# ------------------------------------------------------------------------------

class Session:
def __init__(self):
self.updated = datetime.now(timezone.utc).timestamp()
self.revere = PigeonProxy()

sessions: dict[str, Session] = {}

def session_guard(route):
def guarded_route():
session_id = request.cookies.get(settings.SESSION_ID_KEY)
if session_id is None:
raise MissingSessionIdError()

session = sessions.get(session_id)
if session is None:
if not sessions.does_session_exist(session_id):
raise UnknownSessionIdError()

session.updated = datetime.now(timezone.utc).timestamp()
revere = tasks.PigeonProxy(session_id)
try:
return route(session.revere)
return route(revere)
except ApiError:
# These errors are anticipated and should be passed through
raise
except Exception:
# We should not get any other type of error. If we do, something may have
# gone horribly wrong and we need to delete this session to preserve the
# integrity of the server:
sessions.pop(session_id)
sessions.delete_session(session_id)
raise

guarded_route.__name__ = route.__name__
return guarded_route

def create_new_session():
token_candidate = secrets.token_hex(settings.SESSION_ID_BYTES)
while token_candidate in sessions:
token_candidate = secrets.token_hex(settings.SESSION_ID_BYTES)
sessions[token_candidate] = Session()
return token_candidate

def try_create_new_session():
if len(sessions) >= settings.MAX_SESSIONS:
current_time = datetime.now(timezone.utc).timestamp()
to_remove = None
for session_id, session in sessions.items():
if current_time - session.updated >= settings.SESSION_TIMEOUT:
to_remove = session_id
break
if to_remove is not None:
sessions.pop(to_remove)
return create_new_session()
else:
return None
else:
return create_new_session()

# ------------------------------------------------------------------------------
# Routes -----------------------------------------------------------------------
# ------------------------------------------------------------------------------
Expand All @@ -144,9 +95,13 @@ def api_status():

@app.route('/api/new-session', methods=['GET'])
def new_session():
new_session_id = try_create_new_session()
if new_session_id is None:
raise ServerOverloadedError()
try:
new_session_id = tasks.create_new_session.delay().get(
timeout=settings.PIGEON_EXECUTION_TIMEOUT
)
except redis.exceptions.OutOfMemoryError as e:
raise ServerOverloadedError() from e

resp = make_response()
resp.set_cookie(settings.SESSION_ID_KEY, new_session_id)
return resp
Expand Down Expand Up @@ -196,7 +151,4 @@ def reset(revere: PigeonInterface):

@app.route('/api/join-email-list', methods=['POST'])
def join_email_list():
props = JoinEmailListProps(request.get_json())
with open(settings.EMAIL_LIST_LOCATION, 'a', encoding='utf-8') as email_list:
email_list.write(f'{props.email}\n')
return make_response()
69 changes: 14 additions & 55 deletions api/src/pigeon.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
Structure,
CFUNCTYPE,
)
from multiprocessing.managers import BaseManager
import concurrent.futures
from functools import cached_property

import settings
import props
Expand Down Expand Up @@ -573,16 +572,26 @@ class Pigeon(PigeonInterface):
'''Sends messages between python clients and the fishbait C++ AI.'''

def __init__(self):
def commander_callback(data, size):
self._commander = bytes(data[:size])
self._commander_callback = CallbackFunc(commander_callback)
self._commander = b''

strategy_loc = settings.STRATEGY_LOCATION
commander_create(bytes(strategy_loc, 'utf-8'), self._commander_callback)

self._state: PigeonState = PigeonState()
self._update_state()

@cached_property
def _commander_callback(self):
def callback(data, size):
self._commander = bytes(data[:size])
return CallbackFunc(callback)

def __getstate__(self) -> object:
return (self._commander, self._state)

def __setstate__(self, state: tuple[bytes, PigeonState]):
self._commander, self._state = state

@staticmethod
def _auto_advance[**P, T](fn: Callable[Concatenate['Pigeon', P], T]):
# pylint: disable=protected-access
Expand Down Expand Up @@ -722,53 +731,3 @@ def _award_pot(self):
self._commander, self._commander_callback
)
self._update_state()

class PigeonManager(BaseManager):
pass
PigeonManager.register('Pigeon', Pigeon)

class PigeonProxy(PigeonInterface):
'''
An object that behaves like a local Pigeon to an outside observer but spawns
a Pigeon on a new process and sends messages to it.
'''

def __init__(self) -> None:
super().__init__()
self.manager = PigeonManager()
self.manager.start() # pylint: disable=consider-using-with
self.revere = self.manager.Pigeon()

def __del__(self):
log.info('Shutting down manager for %s', self)
self.manager.shutdown()
log.info('Completed manager shutdown for %s', self)

class PigeonMessage():
'''
A descriptor that applies the given function on the managed Pigeon
'''

def __set_name__(self, owner, name):
self.name = name

def __get__(self, obj, objtype):
def wrapped_fn(*args, **kwargs):
log.info(
'Calling %s for %s with args %s and kwargs %s',
self.name, obj, args, kwargs
)
with concurrent.futures.ThreadPoolExecutor() as executor:
fn = getattr(obj.revere, self.name)
future = executor.submit(fn, *args, **kwargs)
result = future.result(timeout=settings.PIGEON_EXECUTION_TIMEOUT)
log.info('Completed %s for %s', self.name, obj)
return result
return wrapped_fn

reset = PigeonMessage()
set_hand = PigeonMessage()
apply = PigeonMessage()
set_board = PigeonMessage()
state_dict = PigeonMessage()
new_hand = PigeonMessage()
Loading

0 comments on commit b35bbff

Please sign in to comment.