diff --git a/fastapi_amis_admin/globals/__init__.py b/fastapi_amis_admin/globals/__init__.py new file mode 100644 index 0000000..c8c38e2 --- /dev/null +++ b/fastapi_amis_admin/globals/__init__.py @@ -0,0 +1,2 @@ +from .db import async_db, sync_db +from .sites import site diff --git a/fastapi_amis_admin/globals/db.py b/fastapi_amis_admin/globals/db.py new file mode 100644 index 0000000..2992310 --- /dev/null +++ b/fastapi_amis_admin/globals/db.py @@ -0,0 +1,77 @@ +from collections import namedtuple +from typing import Dict, Union + +from lazy_object_proxy import Proxy +from sqlalchemy_database import AsyncDatabase, Database +from typing_extensions import overload + +from fastapi_amis_admin.globals.sites import exists_site, get_site + +DBs = namedtuple("DBs", ["sync", "async_"]) + +__dbs__: Dict[str, DBs] = {} + + +@overload +def get_db(alias: str = "default") -> Database: + ... + + +@overload +def get_db(alias: str = "default", is_async: bool = True) -> AsyncDatabase: + ... + + +def get_async_db(alias: str = "default") -> AsyncDatabase: + """获取异步数据库""" + return get_db(alias, is_async=True) + + +def get_db(alias: str = "default", is_async: bool = False) -> Union[Database, AsyncDatabase]: + """获取数据库""" + if alias in __dbs__: + dbs = __dbs__[alias] + if is_async and dbs.async_: + return dbs.async_ + elif not is_async and dbs.sync: + return dbs.sync + if exists_site(alias): + db = get_site(alias).db + if is_async and isinstance(db, AsyncDatabase): + return db + elif not is_async and isinstance(db, Database): + return db + raise ValueError(f"db[{alias}] not found, please call `set_db` first") + + +def set_db(db: Union[Database, AsyncDatabase], alias: str = "default") -> None: + """设置数据库""" + if alias not in __dbs__: + __dbs__[alias] = DBs(sync=None, async_=None) + if isinstance(db, AsyncDatabase): + if __dbs__[alias].async_ is not None: + raise ValueError(f"async db[{alias}] already exists") + __dbs__[alias] = __dbs__[alias]._replace(async_=db) + elif isinstance(db, Database): + if __dbs__[alias].sync is not None: + raise ValueError(f"sync db[{alias}] already exists") + __dbs__[alias] = __dbs__[alias]._replace(sync=db) + else: + raise ValueError(f"db[{alias}] must be Database or AsyncDatabase") + + +def exists_db(alias: str = "default", is_async: bool = False) -> bool: + """判断数据库是否存在""" + if alias in __dbs__: + dbs = __dbs__[alias] + if is_async and dbs.async_: + return True + elif not is_async and dbs.sync: + return True + return False + + +# 默认同步数据库.需要在项目启动时调用`set_db`设置 +sync_db: Database = Proxy(get_db) +# 默认异步数据库.需要在项目启动时调用`set_db`设置 +async_db: AsyncDatabase = Proxy(get_async_db) diff --git a/fastapi_amis_admin/globals/deps.py b/fastapi_amis_admin/globals/deps.py new file mode 100644 index 0000000..8064879 --- /dev/null +++ b/fastapi_amis_admin/globals/deps.py @@ -0,0 +1,9 @@ +from fastapi import Depends +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session +from typing_extensions import Annotated + +from fastapi_amis_admin.globals.db import async_db, sync_db + +SyncSess = Annotated[Session, Depends(sync_db.session_generator)] +AsyncSess = Annotated[AsyncSession, Depends(async_db.session_generator)] diff --git a/fastapi_amis_admin/globals/sites.py b/fastapi_amis_admin/globals/sites.py new file mode 100644 index 0000000..1706b84 --- /dev/null +++ b/fastapi_amis_admin/globals/sites.py @@ -0,0 +1,30 @@ +from typing import Dict + +from lazy_object_proxy import Proxy + +from fastapi_amis_admin.admin import AdminSite + +__sites__: Dict[str, AdminSite] = {} + + +def get_site(alias: str = "default") -> AdminSite: + """获取站点""" + if alias not in __sites__: + raise ValueError(f"site[{alias}] not found, please call `set_site` first") + return __sites__[alias] + + +def set_site(site: AdminSite, alias: str = "default") -> None: + """设置站点""" + if alias in __sites__: + raise ValueError(f"site[{alias}] already exists") + __sites__[alias] = site + + +def exists_site(alias: str = "default") -> bool: + """判断站点是否存在""" + return alias in __sites__ + + +# 默认站点.需要在项目启动时调用`set_site`设置 +site: AdminSite = Proxy(get_site) diff --git a/pyproject.toml b/pyproject.toml index 346e8c7..d39feaf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dependencies = [ "python-multipart>=0.0.5", "sqlalchemy-database>=0.1.0,<0.2.0", "aiofiles>=0.17.0", + "lazy-object-proxy>=1.9.0", ] diff --git a/tests/test_globals/__init__.py b/tests/test_globals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_globals/test_db.py b/tests/test_globals/test_db.py new file mode 100644 index 0000000..2c8b885 --- /dev/null +++ b/tests/test_globals/test_db.py @@ -0,0 +1,38 @@ +import pytest + +from fastapi_amis_admin.globals import db, sites +from tests.conftest import sync_db + + +def test_sites(site): + # 清空sites,db,确保测试环境 + sites.__sites__.clear() + db.__dbs__.clear() + # 没有设置db,返回False + assert db.exists_db() is False + assert db.exists_db(is_async=True) is False + with pytest.raises(ValueError): + assert db.sync_db + with pytest.raises(ValueError): + assert db.async_db + # 设置site,site.db为异步数据库 + sites.set_site(site) + # 判断db是否存在 + assert db.exists_db() is False + assert db.exists_db(is_async=True) is False + # 获取db + with pytest.raises(ValueError): + assert db.sync_db + assert db.async_db == site.db # 读取site.db + assert isinstance(db.async_db, db.AsyncDatabase) + # 设置同步db + db.set_db(sync_db) + # 判断db是否存在 + assert db.exists_db() is True + assert db.exists_db(is_async=True) is False + # 获取db + assert db.sync_db == sync_db + assert isinstance(db.sync_db, db.Database) + # 重复设置db + with pytest.raises(ValueError): + db.set_db(sync_db) diff --git a/tests/test_globals/test_sites.py b/tests/test_globals/test_sites.py new file mode 100644 index 0000000..3cb996f --- /dev/null +++ b/tests/test_globals/test_sites.py @@ -0,0 +1,26 @@ +import pytest + +from fastapi_amis_admin.globals import sites + + +def test_sites(site): + # 清空sites,确保测试环境 + sites.__sites__.clear() + # 没有设置站点,返回False + assert sites.exists_site() is False + # 没有设置站点,抛出异常 + with pytest.raises(ValueError): + assert sites.get_site() + with pytest.raises(ValueError): + assert sites.site + # 设置站点 + sites.set_site(site) + # 判断站点是否存在 + assert sites.exists_site() is True + # 获取站点 + assert sites.get_site() == site + assert sites.site == site + assert isinstance(sites.site, sites.AdminSite) + # 重复设置站点 + with pytest.raises(ValueError): + sites.set_site(site)