-
Notifications
You must be signed in to change notification settings - Fork 163
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add the globals module for easy sharing of global databases and…
… sites
- Loading branch information
Showing
8 changed files
with
183 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .db import async_db, sync_db | ||
from .sites import site |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from collections import namedtuple | ||
from typing import Dict, Union | ||
|
||
from lazy_object_proxy import Proxy | ||
from sqlalchemy_database import AsyncDatabase, Database | ||
from typing_extensions import overload | ||
|
||
from fastapi_amis_admin.globals.sites import exists_site, get_site | ||
|
||
DBs = namedtuple("DBs", ["sync", "async_"]) | ||
|
||
__dbs__: Dict[str, DBs] = {} | ||
|
||
|
||
@overload | ||
def get_db(alias: str = "default") -> Database: | ||
... | ||
|
||
|
||
@overload | ||
def get_db(alias: str = "default", is_async: bool = True) -> AsyncDatabase: | ||
... | ||
|
||
|
||
def get_async_db(alias: str = "default") -> AsyncDatabase: | ||
"""获取异步数据库""" | ||
return get_db(alias, is_async=True) | ||
|
||
|
||
def get_db(alias: str = "default", is_async: bool = False) -> Union[Database, AsyncDatabase]: | ||
"""获取数据库""" | ||
if alias in __dbs__: | ||
dbs = __dbs__[alias] | ||
if is_async and dbs.async_: | ||
return dbs.async_ | ||
elif not is_async and dbs.sync: | ||
return dbs.sync | ||
if exists_site(alias): | ||
db = get_site(alias).db | ||
if is_async and isinstance(db, AsyncDatabase): | ||
return db | ||
elif not is_async and isinstance(db, Database): | ||
return db | ||
raise ValueError(f"db[{alias}] not found, please call `set_db` first") | ||
|
||
|
||
def set_db(db: Union[Database, AsyncDatabase], alias: str = "default") -> None: | ||
"""设置数据库""" | ||
if alias not in __dbs__: | ||
__dbs__[alias] = DBs(sync=None, async_=None) | ||
if isinstance(db, AsyncDatabase): | ||
if __dbs__[alias].async_ is not None: | ||
raise ValueError(f"async db[{alias}] already exists") | ||
__dbs__[alias] = __dbs__[alias]._replace(async_=db) | ||
elif isinstance(db, Database): | ||
if __dbs__[alias].sync is not None: | ||
raise ValueError(f"sync db[{alias}] already exists") | ||
__dbs__[alias] = __dbs__[alias]._replace(sync=db) | ||
else: | ||
raise ValueError(f"db[{alias}] must be Database or AsyncDatabase") | ||
|
||
|
||
def exists_db(alias: str = "default", is_async: bool = False) -> bool: | ||
"""判断数据库是否存在""" | ||
if alias in __dbs__: | ||
dbs = __dbs__[alias] | ||
if is_async and dbs.async_: | ||
return True | ||
elif not is_async and dbs.sync: | ||
return True | ||
return False | ||
|
||
|
||
# 默认同步数据库.需要在项目启动时调用`set_db`设置 | ||
sync_db: Database = Proxy(get_db) | ||
# 默认异步数据库.需要在项目启动时调用`set_db`设置 | ||
async_db: AsyncDatabase = Proxy(get_async_db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from fastapi import Depends | ||
from sqlalchemy.ext.asyncio import AsyncSession | ||
from sqlalchemy.orm import Session | ||
from typing_extensions import Annotated | ||
|
||
from fastapi_amis_admin.globals.db import async_db, sync_db | ||
|
||
SyncSess = Annotated[Session, Depends(sync_db.session_generator)] | ||
AsyncSess = Annotated[AsyncSession, Depends(async_db.session_generator)] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
from typing import Dict | ||
|
||
from lazy_object_proxy import Proxy | ||
|
||
from fastapi_amis_admin.admin import AdminSite | ||
|
||
__sites__: Dict[str, AdminSite] = {} | ||
|
||
|
||
def get_site(alias: str = "default") -> AdminSite: | ||
"""获取站点""" | ||
if alias not in __sites__: | ||
raise ValueError(f"site[{alias}] not found, please call `set_site` first") | ||
return __sites__[alias] | ||
|
||
|
||
def set_site(site: AdminSite, alias: str = "default") -> None: | ||
"""设置站点""" | ||
if alias in __sites__: | ||
raise ValueError(f"site[{alias}] already exists") | ||
__sites__[alias] = site | ||
|
||
|
||
def exists_site(alias: str = "default") -> bool: | ||
"""判断站点是否存在""" | ||
return alias in __sites__ | ||
|
||
|
||
# 默认站点.需要在项目启动时调用`set_site`设置 | ||
site: AdminSite = Proxy(get_site) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
import pytest | ||
|
||
from fastapi_amis_admin.globals import db, sites | ||
from tests.conftest import sync_db | ||
|
||
|
||
def test_sites(site): | ||
# 清空sites,db,确保测试环境 | ||
sites.__sites__.clear() | ||
db.__dbs__.clear() | ||
# 没有设置db,返回False | ||
assert db.exists_db() is False | ||
assert db.exists_db(is_async=True) is False | ||
with pytest.raises(ValueError): | ||
assert db.sync_db | ||
with pytest.raises(ValueError): | ||
assert db.async_db | ||
# 设置site,site.db为异步数据库 | ||
sites.set_site(site) | ||
# 判断db是否存在 | ||
assert db.exists_db() is False | ||
assert db.exists_db(is_async=True) is False | ||
# 获取db | ||
with pytest.raises(ValueError): | ||
assert db.sync_db | ||
assert db.async_db == site.db # 读取site.db | ||
assert isinstance(db.async_db, db.AsyncDatabase) | ||
# 设置同步db | ||
db.set_db(sync_db) | ||
# 判断db是否存在 | ||
assert db.exists_db() is True | ||
assert db.exists_db(is_async=True) is False | ||
# 获取db | ||
assert db.sync_db == sync_db | ||
assert isinstance(db.sync_db, db.Database) | ||
# 重复设置db | ||
with pytest.raises(ValueError): | ||
db.set_db(sync_db) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import pytest | ||
|
||
from fastapi_amis_admin.globals import sites | ||
|
||
|
||
def test_sites(site): | ||
# 清空sites,确保测试环境 | ||
sites.__sites__.clear() | ||
# 没有设置站点,返回False | ||
assert sites.exists_site() is False | ||
# 没有设置站点,抛出异常 | ||
with pytest.raises(ValueError): | ||
assert sites.get_site() | ||
with pytest.raises(ValueError): | ||
assert sites.site | ||
# 设置站点 | ||
sites.set_site(site) | ||
# 判断站点是否存在 | ||
assert sites.exists_site() is True | ||
# 获取站点 | ||
assert sites.get_site() == site | ||
assert sites.site == site | ||
assert isinstance(sites.site, sites.AdminSite) | ||
# 重复设置站点 | ||
with pytest.raises(ValueError): | ||
sites.set_site(site) |