Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up model usage #8

Merged
merged 3 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 0 additions & 14 deletions src/gentrade_server/config.py

This file was deleted.

14 changes: 8 additions & 6 deletions src/gentrade_server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware

from .routers import secure, public, agent
from .routers import public, agent, admin
from .auth import get_user
from .util import check_server_time
from .config import settings
from .model import settings
from .datahub import DataHub

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')
Expand Down Expand Up @@ -51,14 +51,16 @@ def receive_signal(number, _):
public.router,
prefix="/api/v1/public"
)

app.include_router(
secure.router,
prefix="/api/v1/secure",
agent.router,
prefix="/api/v1/agent",
dependencies=[Depends(get_user)]
)

app.include_router(
agent.router,
prefix="/api/v1/agent",
admin.router,
prefix="/api/v1/admin",
dependencies=[Depends(get_user)]
)

Expand Down
72 changes: 72 additions & 0 deletions src/gentrade_server/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
Model
"""
from typing import List

from pydantic import BaseModel, Field, field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

class HealthCheck(BaseModel):
"""
Response model to validate and return when performing a health check.
"""
status: str = Field("OK")

class Settings(BaseSettings):
"""
Settings
"""
model_config = SettingsConfigDict(enable_decoding=False)

openai_api_key: str = ""
openai_api_url: str = ""
openai_api_model: str = "gpt-3.5-turbo"

ntp_servers : List[str] = Field(
"ntp.ntsc.ac.cn,ntp.sjtu.edu.cn,cn.ntp.org.cn,cn.pool.ntp.org,ntp.aliyun.com",
description="The string list of NTP server splitted via comma")

@field_validator('ntp_servers', mode='before')
@classmethod
def decode_ntp_servers(cls, v: str) -> List[str]:
"""decode function override

Args:
v (str): input string

Returns:
List[str]: splitted list for all NTP servers
"""
return v.split(',')

settings = Settings()

class Market(BaseModel):
"""
Response model to validate and return when performing a health check.
"""
name: str = Field(...)
type: str = Field(...)

class Asset(BaseModel):
"""
Asset Model_
"""
name: str = Field(...)
type: str = Field(...)
market: str = Field(...)
quote: str = Field(...)
cik: int = Field(None, description="only for US stock")
symbol: str = Field(None, description="only for crypto")
base: str = Field(None, description="only for crypto")

class OHLCV(BaseModel):
"""
OHLCV model
"""
time: int = Field(..., description="UTC timestamp in seconds")
open: float = Field(...)
high: float = Field(...)
low: float = Field(...)
close: float = Field(...)
vol: float = Field(...)
20 changes: 20 additions & 0 deletions src/gentrade_server/routers/admin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
'''
Admin portal
'''
import logging

from fastapi import APIRouter, Depends
from ..model import settings, Settings
from ..auth import get_user

LOG = logging.getLogger(__name__)

router = APIRouter()

@router.get("/settings")
async def get_settings(user: dict = Depends(get_user)) -> Settings:
"""
Get server settings
"""
LOG.info(user)
return settings
8 changes: 4 additions & 4 deletions src/gentrade_server/routers/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from openai import OpenAI
from fastapi import APIRouter, Depends
from ..auth import get_user
from ..config import settings
from ..model import settings

LOG = logging.getLogger(__name__)

router = APIRouter()

client = OpenAI(
api_key=settings.OPENAI_API_KEY,
base_url=settings.OPENAI_API_URL
api_key=settings.openai_api_key,
base_url=settings.openai_api_url
)

@router.get("/")
Expand All @@ -22,7 +22,7 @@ async def get_answer(prompt: str, user: dict = Depends(get_user)):
Prompt to OpenAI and get answer
"""
completion = client.chat.completions.create(
model=settings.OPENAI_API_MODEL,
model=settings.openai_api_model,
messages=[
{"role": "system", "content":
"You are a Lu Ken's assistant for cryptocurrency market."},
Expand Down
87 changes: 30 additions & 57 deletions src/gentrade_server/routers/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,46 +8,22 @@
import datetime
from dateutil.tz import tzlocal

from fastapi import APIRouter
from pydantic import BaseModel
from fastapi import APIRouter, HTTPException

from ..config import settings
from ..datahub import DataHub
from ..model import HealthCheck, Market, Asset, OHLCV

LOG = logging.getLogger(__name__)

router = APIRouter()

@router.get("/")
async def get_testroute():
"""
Test public interface
"""
return "OK"

class HealthCheck(BaseModel):
"""
Response model to validate and return when performing a health check.
"""

status: str = "OK"

@router.get("/health")
async def get_health() -> HealthCheck:
"""
Check health
"""
return HealthCheck(status="OK")

@router.get("/settings")
async def get_settings():
"""
Get server settings
"""
return {
'ntp_server': settings.ntp_server
}

@router.get("/server_time")
async def get_server_time():
"""
Expand All @@ -63,8 +39,8 @@ async def get_server_time():
'timestamp_server': int(curr_ts)
}

@router.get("/markets/")
async def get_markets():
@router.get("/markets")
async def get_markets() -> dict[str, Market]:
"""
Get markets
"""
Expand All @@ -76,36 +52,33 @@ async def get_markets():
}
return retval

@router.get("/assets/")
async def get_assets(market_id:str=""):
"""
Get assets
@router.get("/markets/{market_id}/assets")
async def get_assets(market_id:str="b13a4902-ad9d-11ef-a239-00155d3ba217",
start:int=0, limit:int=1000) -> list[Asset]:
"""Get assets array, The maximus lenth is 1000

Args:
market_id (str, optional): Market ID string. Defaults to
"b13a4902-ad9d-11ef-a239-00155d3ba217".
start (int, optional): Start index. Defaults to 0.

Returns:
dict[str, Asset]: _description_
"""
ret = {}
markets = []
if len(market_id) != 0 and market_id not in DataHub.inst().markets:
LOG.error("could not find the market %s", market_id)
return ret
if len(market_id) == 0:
for id_ in DataHub.inst().markets:
markets.append(id_)
else:
markets.append(market_id)

for id_ in markets:
market_inst = DataHub.inst().markets[id_]
for asset in market_inst.assets.values():
if market_inst.market_id == "b13a4902-ad9d-11ef-a239-00155d3ba217" and \
asset.asset_type == "spot":
ret[asset.name] = asset.to_dict()
elif market_inst.market_id == "5784f1f5-d8f6-401d-8d24-f685a3812f2d" and \
asset.asset_type == "stock":
ret[asset.name] = asset.to_dict()
return ret

@router.get("/asset/fetch_ohlcv/")
markets = DataHub.inst().markets

if market_id not in markets:
raise HTTPException(status_code=404, detail="Item not found")

assets = list(markets[market_id].assets.values())

if start > len(assets) - 1:
raise HTTPException(status_code=404, detail="Item not found")
return [item.to_dict() for item in assets[start:min(start + limit, len(assets))]]

@router.get("/asset/fetch_ohlcv")
async def fetch_ohlcv(assetname:str='btc_usdt', interval="1d",
since:int=-1, to:int=-1,limit:int=300):
since:int=-1, to:int=-1,limit:int=300) -> list[OHLCV]:
"""fetch ohlcv

Args:
Expand All @@ -119,7 +92,7 @@ async def fetch_ohlcv(assetname:str='btc_usdt', interval="1d",
_type_: _description_
"""
retval = {}
LOG.info("fetch_ohlcv: %s", assetname)
LOG.info("fetch_ohlcv: %s, interval: %s", assetname, interval)
asset = DataHub.inst().get_asset(assetname)
if asset is not None:
ret = asset.fetch_ohlcv(interval, since, to, limit)
Expand Down
14 changes: 0 additions & 14 deletions src/gentrade_server/routers/secure.py

This file was deleted.

Loading