diff --git a/apps/admin/depends.py b/apps/admin/dependencies.py similarity index 62% rename from apps/admin/depends.py rename to apps/admin/dependencies.py index 8e9e590b7..8efcced9b 100644 --- a/apps/admin/depends.py +++ b/apps/admin/dependencies.py @@ -7,9 +7,10 @@ from fastapi import Header, HTTPException from fastapi.requests import Request from core.settings import settings +from apps.admin.services import FileService, ConfigService, LocalFileService -async def admin_required(authorization: Union[str, None] = Header(default=None), request: Request = None): +async def admin_required(authorization: str = Header(default=None), request: Request = None): is_admin = authorization == str(settings.admin_token) if request.url.path.startswith('/share/'): if not settings.openUpload and not is_admin: @@ -17,3 +18,16 @@ async def admin_required(authorization: Union[str, None] = Header(default=None), else: if not is_admin: raise HTTPException(status_code=401, detail='未授权或授权校验失败') + return is_admin + + +async def get_file_service(): + return FileService() + + +async def get_config_service(): + return ConfigService() + + +async def get_local_file_service(): + return LocalFileService() diff --git a/apps/admin/pydantics.py b/apps/admin/pydantics.py deleted file mode 100644 index d46f1b7c5..000000000 --- a/apps/admin/pydantics.py +++ /dev/null @@ -1,5 +0,0 @@ -from pydantic import BaseModel - - -class IDData(BaseModel): - id: int diff --git a/apps/admin/schemas.py b/apps/admin/schemas.py new file mode 100644 index 000000000..c0028348e --- /dev/null +++ b/apps/admin/schemas.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + + +class IDData(BaseModel): + id: int + + +class ConfigUpdateData(BaseModel): + admin_token: str + + +class ShareItem(BaseModel): + expire_value: int + expire_style: str = 'day' + filename: str + + +class DeleteItem(BaseModel): + filename: str diff --git a/apps/admin/services.py b/apps/admin/services.py new file mode 100644 index 000000000..f4d753de2 --- /dev/null +++ b/apps/admin/services.py @@ -0,0 +1,123 @@ +import os +import time + +from core.response import APIResponse +from core.storage import FileStorageInterface, storages +from core.settings import settings +from apps.base.models import FileCodes, KeyValue +from apps.base.utils import get_expire_info, get_file_path_name +from fastapi import HTTPException +from core.settings import data_root + + +class FileService: + def __init__(self): + self.file_storage: FileStorageInterface = storages[settings.file_storage]() + + async def delete_file(self, file_id: int): + file_code = await FileCodes.get(id=file_id) + await self.file_storage.delete_file(file_code) + await file_code.delete() + + async def list_files(self, page: int, size: int): + offset = (page - 1) * size + files = await FileCodes.all().limit(size).offset(offset) + total = await FileCodes.all().count() + return files, total + + async def download_file(self, file_id: int): + file_code = await FileCodes.filter(id=file_id).first() + if not file_code: + raise HTTPException(status_code=404, detail='文件不存在') + if file_code.text: + return APIResponse(detail=file_code.text) + else: + return await self.file_storage.get_file_response(file_code) + + async def share_local_file(self, item): + local_file = LocalFileClass(item.filename) + if not await local_file.exists(): + raise HTTPException(status_code=404, detail='文件不存在') + + text = await local_file.read() + expired_at, expired_count, used_count, code = await get_expire_info(item.expire_value, item.expire_style) + path, suffix, prefix, uuid_file_name, save_path = await get_file_path_name(item) + + await self.file_storage.save_file(text, save_path) + + await FileCodes.create( + code=code, + prefix=prefix, + suffix=suffix, + uuid_file_name=uuid_file_name, + file_path=path, + size=local_file.size, + expired_at=expired_at, + expired_count=expired_count, + used_count=used_count, + ) + + return { + 'code': code, + 'name': local_file.file, + } + + +class ConfigService: + def get_config(self): + return settings.items() + + async def update_config(self, data: dict): + admin_token = data.get('admin_token') + if admin_token is None or admin_token == '': + raise HTTPException(status_code=400, detail='管理员密码不能为空') + + for key, value in data.items(): + if key not in settings.default_config: + continue + if key in ['errorCount', 'errorMinute', 'max_save_seconds', 'onedrive_proxy', 'openUpload', 'port', 's3_proxy', 'uploadCount', 'uploadMinute', 'uploadSize']: + data[key] = int(value) + elif key in ['opacity']: + data[key] = float(value) + else: + data[key] = value + + await KeyValue.filter(key='settings').update(value=data) + for k, v in data.items(): + settings.__setattr__(k, v) + + +class LocalFileService: + async def list_files(self): + files = [] + for file in os.listdir(data_root / 'local'): + files.append(LocalFileClass(file)) + return files + + async def delete_file(self, filename: str): + file = LocalFileClass(filename) + if await file.exists(): + await file.delete() + return '删除成功' + raise HTTPException(status_code=404, detail='文件不存在') + + +class LocalFileClass: + def __init__(self, file): + self.file = file + self.path = data_root / 'local' / file + self.ctime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.path.getctime(self.path))) + self.size = os.path.getsize(self.path) + + async def read(self): + return open(self.path, 'rb') + + async def write(self, data): + with open(self.path, 'w') as f: + f.write(data) + + async def delete(self): + os.remove(self.path) + + async def exists(self): + return os.path.exists(self.path) diff --git a/apps/admin/views.py b/apps/admin/views.py index 1ce4a6170..9063bd140 100644 --- a/apps/admin/views.py +++ b/apps/admin/views.py @@ -2,180 +2,99 @@ # @Author : Lan # @File : views.py # @Software: PyCharm -import math -import os -import time -from fastapi import APIRouter, Depends, Form -from pydantic import BaseModel - -from apps.admin.depends import admin_required -from apps.admin.pydantics import IDData -from apps.base.models import FileCodes, KeyValue -from apps.base.utils import get_expire_info, get_file_path_name +from fastapi import APIRouter, Depends +from apps.admin.services import FileService, ConfigService, LocalFileService +from apps.admin.dependencies import admin_required, get_file_service, get_config_service, get_local_file_service +from apps.admin.schemas import IDData, ConfigUpdateData, ShareItem, DeleteItem from core.response import APIResponse -from core.settings import settings, data_root -from core.storage import FileStorageInterface, storages -admin_api = APIRouter( - prefix='/admin', - tags=['管理'], -) +admin_api = APIRouter(prefix='/admin', tags=['管理']) -@admin_api.post('/login', dependencies=[Depends(admin_required)]) -async def login(): +@admin_api.post('/login') +async def login(admin: bool = Depends(admin_required)): return APIResponse() -@admin_api.delete('/file/delete', dependencies=[Depends(admin_required)]) -async def file_delete(data: IDData): - file_storage: FileStorageInterface = storages[settings.file_storage]() - file_code = await FileCodes.get(id=data.id) - await file_storage.delete_file(file_code) - await file_code.delete() +@admin_api.delete('/file/delete') +async def file_delete( + data: IDData, + file_service: FileService = Depends(get_file_service), + admin: bool = Depends(admin_required) +): + await file_service.delete_file(data.id) return APIResponse() -@admin_api.get('/file/list', dependencies=[Depends(admin_required)]) -async def file_list(page: float = 1, size: int = 10): +@admin_api.get('/file/list') +async def file_list( + page: int = 1, + size: int = 10, + file_service: FileService = Depends(get_file_service), + admin: bool = Depends(admin_required) +): + files, total = await file_service.list_files(page, size) return APIResponse(detail={ 'page': page, 'size': size, - 'data': await FileCodes.all().limit(size).offset((math.ceil(page) - 1) * size), - 'total': await FileCodes.all().count(), + 'data': files, + 'total': total, }) -@admin_api.get('/config/get', dependencies=[Depends(admin_required)]) -async def get_config(): - return APIResponse(detail=settings.items()) - - -@admin_api.patch('/config/update', dependencies=[Depends(admin_required)]) -async def update_config(data: dict): - admin_token = data.get('admin_token') - for key, value in data.items(): - if key not in settings.default_config: - continue - if key in ['errorCount', 'errorMinute', 'max_save_seconds', 'onedrive_proxy', 'openUpload', 'port', 's3_proxy', 'uploadCount', 'uploadMinute', 'uploadSize']: - data[key] = int(value) - elif key in ['opacity']: - data[key] = float(value) - else: - data[key] = value - if admin_token is None or admin_token == '': - return APIResponse(code=400, detail='管理员密码不能为空') - await KeyValue.filter(key='settings').update(value=data) - for k, v in data.items(): - settings.__setattr__(k, v) +@admin_api.get('/config/get') +async def get_config( + config_service: ConfigService = Depends(get_config_service), + admin: bool = Depends(admin_required) +): + return APIResponse(detail=config_service.get_config()) + + +@admin_api.patch('/config/update') +async def update_config( + data: ConfigUpdateData, + config_service: ConfigService = Depends(get_config_service), + admin: bool = Depends(admin_required) +): + await config_service.update_config(data) return APIResponse() -# 根据code获取文件 -async def get_file_by_id(id): - # 查询文件 - file_code = await FileCodes.filter(id=id).first() - # 检查文件是否存在 - if not file_code: - return False, '文件不存在' - return True, file_code - - -@admin_api.get('/file/download', dependencies=[Depends(admin_required)]) -async def file_download(id: int): - file_storage: FileStorageInterface = storages[settings.file_storage]() - has, file_code = await get_file_by_id(id) - # 检查文件是否存在 - if not has: - # 返回API响应 - return APIResponse(code=404, detail='文件不存在') - # 如果文件是文本,返回文本内容,否则返回文件响应 - if file_code.text: - return APIResponse(detail=file_code.text) - else: - return await file_storage.get_file_response(file_code) - - -class LocalFileClass: - def __init__(self, file): - self.file = file - self.path = data_root / 'local' / file - self.ctime = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(os.path.getctime(self.path))) - self.size = os.path.getsize(self.path) - - async def read(self): - return open(self.path, 'rb') - - async def write(self, data): - with open(self.path, 'w') as f: - f.write(data) - - async def delete(self): - os.remove(self.path) - - async def exists(self): - return os.path.exists(self.path) - - -@admin_api.get('/local/lists', dependencies=[Depends(admin_required)]) -async def get_local_lists(): - files = [] - for file in os.listdir(data_root / 'local'): - files.append(LocalFileClass(file)) +@admin_api.get('/file/download') +async def file_download( + id: int, + file_service: FileService = Depends(get_file_service), + admin: bool = Depends(admin_required) +): + file_content = await file_service.download_file(id) + return file_content + + +@admin_api.get('/local/lists') +async def get_local_lists( + local_file_service: LocalFileService = Depends(get_local_file_service), + admin: bool = Depends(admin_required) +): + files = await local_file_service.list_files() return APIResponse(detail=files) -class DeleteItem(BaseModel): - filename: str - - -@admin_api.delete('/local/delete', dependencies=[Depends(admin_required)]) -async def delete_local_file(item: DeleteItem): - file = LocalFileClass(item.filename) - if await file.exists(): - await file.delete() - return APIResponse(detail='删除成功') - return APIResponse(code=404, detail='文件不存在') - - -class ShareItem(BaseModel): - expire_value: int - expire_style: str = 'day' - filename: str - - -class File: - def __init__(self, file): - self.file = file - - -@admin_api.post('/local/share', dependencies=[Depends(admin_required)]) -async def share_local_file(item: ShareItem): - file = LocalFileClass(item.filename) - if not await file.exists(): - return APIResponse(code=404, detail='文件不存在') - text = File(await file.read()) - expired_at, expired_count, used_count, code = await get_expire_info(item.expire_value, item.expire_style) - # 获取文件路径和名称 - path, suffix, prefix, uuid_file_name, save_path = await get_file_path_name(item) - # 保存文件 - file_storage: FileStorageInterface = storages[settings.file_storage]() - await file_storage.save_file(text, save_path) - # 创建一个新的FileCodes实例 - await FileCodes.create( - code=code, - prefix=prefix, - suffix=suffix, - uuid_file_name=uuid_file_name, - file_path=path, - size=file.size, - expired_at=expired_at, - expired_count=expired_count, - used_count=used_count, - ) - # 返回API响应 - return APIResponse(detail={ - 'code': code, - 'name': file.file, - }) +@admin_api.delete('/local/delete') +async def delete_local_file( + item: DeleteItem, + local_file_service: LocalFileService = Depends(get_local_file_service), + admin: bool = Depends(admin_required) +): + result = await local_file_service.delete_file(item.filename) + return APIResponse(detail=result) + + +@admin_api.post('/local/share') +async def share_local_file( + item: ShareItem, + file_service: FileService = Depends(get_file_service), + admin: bool = Depends(admin_required) +): + share_info = await file_service.share_local_file(item) + return APIResponse(detail=share_info) diff --git a/apps/base/dependencies.py b/apps/base/dependencies.py new file mode 100644 index 000000000..614eb5e2a --- /dev/null +++ b/apps/base/dependencies.py @@ -0,0 +1,37 @@ +from typing import Dict, Union +from datetime import datetime, timedelta +from fastapi import HTTPException, Request + + +class IPRateLimit: + def __init__(self, count: int, minutes: int): + self.ips: Dict[str, Dict[str, Union[int, datetime]]] = {} + self.count = count + self.minutes = minutes + + def check_ip(self, ip: str) -> bool: + if ip in self.ips: + ip_info = self.ips[ip] + if ip_info['count'] >= self.count: + if ip_info['time'] + timedelta(minutes=self.minutes) > datetime.now(): + return False + self.ips.pop(ip) + return True + + def add_ip(self, ip: str) -> int: + ip_info = self.ips.get(ip, {'count': 0, 'time': datetime.now()}) + ip_info['count'] += 1 + ip_info['time'] = datetime.now() + self.ips[ip] = ip_info + return ip_info['count'] + + async def remove_expired_ip(self) -> None: + now = datetime.now() + expiration = timedelta(minutes=self.minutes) + self.ips = {ip: info for ip, info in self.ips.items() if info['time'] + expiration >= now} + + def __call__(self, request: Request) -> str: + ip = request.headers.get('X-Real-IP') or request.headers.get('X-Forwarded-For') or request.client.host + if not self.check_ip(ip): + raise HTTPException(status_code=423, detail="请求次数过多,请稍后再试") + return ip diff --git a/apps/base/depends.py b/apps/base/depends.py deleted file mode 100644 index d76305166..000000000 --- a/apps/base/depends.py +++ /dev/null @@ -1,45 +0,0 @@ -# @Time : 2023/8/14 12:20 -# @Author : Lan -# @File : depends.py -# @Software: PyCharm -from typing import Union -from datetime import datetime, timedelta - -from fastapi import Header, HTTPException, Request - -from core.response import APIResponse - - -class IPRateLimit: - def __init__(self, count, minutes): - self.ips = {} - self.count = count - self.minutes = minutes - - def check_ip(self, ip): - # 检查ip是否被禁止 - if ip in self.ips: - if int(self.ips[ip]['count']) >= int(self.count): - if self.ips[ip]['time'] + timedelta(minutes=self.minutes) > datetime.now(): - return False - else: - self.ips.pop(ip) - return True - - def add_ip(self, ip): - ip_info = self.ips.get(ip, {'count': 0, 'time': datetime.now()}) - ip_info['count'] += 1 - ip_info['time'] = datetime.now() - self.ips[ip] = ip_info - return ip_info['count'] - - async def remove_expired_ip(self): - for ip in list(self.ips.keys()): - if self.ips[ip]['time'] + timedelta(minutes=self.minutes) < datetime.now(): - self.ips.pop(ip) - - def __call__(self, request: Request): - ip = request.headers.get('X-Real-IP', request.headers.get('X-Forwarded-For', request.client.host)) - if not self.check_ip(ip): - raise HTTPException(status_code=423, detail=f"请求次数过多,请稍后再试") - return ip diff --git a/apps/base/pydantics.py b/apps/base/schemas.py similarity index 100% rename from apps/base/pydantics.py rename to apps/base/schemas.py diff --git a/apps/base/utils.py b/apps/base/utils.py index 5ba36a845..3224935c1 100644 --- a/apps/base/utils.py +++ b/apps/base/utils.py @@ -1,85 +1,67 @@ -# @Time : 2023/8/14 01:10 -# @Author : Lan -# @File : utils.py -# @Software: PyCharm import datetime import uuid import os from fastapi import UploadFile, HTTPException +from typing import Tuple, Optional -from apps.base.depends import IPRateLimit +from apps.base.dependencies import IPRateLimit from apps.base.models import FileCodes from core.settings import settings from core.utils import get_random_num, get_random_string, max_save_times_desc -async def get_file_path_name(file: UploadFile): - """ - 获取文件路径和文件名 - :param file: - :return: { - 'path': 'share/data/2021/08/13', - 'suffix': '.jpg', - 'prefix': 'test', - 'file_uuid': '44a83bbd70e04c8aa7fd93bfd8c88249', - 'uuid_file_name': '44a83bbd70e04c8aa7fd93bfd8c88249.jpg', - 'save_path': 'share/data/2021/08/13/44a83bbd70e04c8aa7fd93bfd8c88249.jpg' - } - """ +async def get_file_path_name(file: UploadFile) -> Tuple[str, str, str, str, str]: + """获取文件路径和文件名""" today = datetime.datetime.now() path = f"share/data/{today.strftime('%Y/%m/%d')}" prefix, suffix = os.path.splitext(file.filename) - file_uuid = f"{uuid.uuid4().hex}" + file_uuid = uuid.uuid4().hex uuid_file_name = f"{file_uuid}{suffix}" save_path = f"{path}/{uuid_file_name}" return path, suffix, prefix, uuid_file_name, save_path -async def get_expire_info(expire_value: int, expire_style: str): - """ - 获取过期信息 - :param expire_value: - :param expire_style: - :return: expired_at 过期时间, expired_count 可用次数, used_count 已用次数, code 随机码 - """ - expired_count, used_count, now, code = -1, 0, datetime.datetime.now(), None - if int(settings.max_save_seconds) > 0: - max_timedelta = datetime.timedelta(seconds=settings.max_save_seconds) - detail = await max_save_times_desc(settings.max_save_seconds) - detail = f'限制最长时间为 {detail[0]},可换用其他方式' - else: - max_timedelta = datetime.timedelta(days=7) - detail = '限制最长时间为 7天,可换用其他方式' - if expire_style == 'day': - if datetime.timedelta(days=expire_value) > max_timedelta: - raise HTTPException(status_code=403, detail=detail) - expired_at = now + datetime.timedelta(days=expire_value) - elif expire_style == 'hour': - if datetime.timedelta(hours=expire_value) > max_timedelta: - raise HTTPException(status_code=403, detail=detail) - expired_at = now + datetime.timedelta(hours=expire_value) - elif expire_style == 'minute': - if datetime.timedelta(minutes=expire_value) > max_timedelta: +async def get_expire_info(expire_value: int, expire_style: str) -> Tuple[Optional[datetime.datetime], int, int, str]: + """获取过期信息""" + expired_count, used_count = -1, 0 + now = datetime.datetime.now() + code = None + + max_timedelta = datetime.timedelta(seconds=settings.max_save_seconds) if settings.max_save_seconds > 0 else datetime.timedelta(days=7) + detail = await max_save_times_desc(settings.max_save_seconds) if settings.max_save_seconds > 0 else '7天' + detail = f'限制最长时间为 {detail[0]},可换用其他方式' + + expire_styles = { + 'day': lambda: now + datetime.timedelta(days=expire_value), + 'hour': lambda: now + datetime.timedelta(hours=expire_value), + 'minute': lambda: now + datetime.timedelta(minutes=expire_value), + 'count': lambda: (now + datetime.timedelta(days=1), expire_value), + 'forever': lambda: (None, None), # 修改这里 + } + + if expire_style in expire_styles: + result = expire_styles[expire_style]() + if isinstance(result, tuple): + expired_at, extra = result + if expire_style == 'count': + expired_count = extra + elif expire_style == 'forever': + code = await get_random_code(style='string') # 移动到这里 + else: + expired_at = result + if expired_at and expired_at - now > max_timedelta: raise HTTPException(status_code=403, detail=detail) - expired_at = now + datetime.timedelta(minutes=expire_value) - elif expire_style == 'count': - expired_at = now + datetime.timedelta(days=1) - expired_count = expire_value - elif expire_style == 'forever': - expired_at = None - code = await get_random_code(style='string') else: expired_at = now + datetime.timedelta(days=1) + if not code: code = await get_random_code() + return expired_at, expired_count, used_count, code -async def get_random_code(style='num'): - """ - 获取随机字符串 - :return: - """ +async def get_random_code(style='num') -> str: + """获取随机字符串""" while True: code = await get_random_num() if style == 'num' else await get_random_string() if not await FileCodes.filter(code=code).exists(): diff --git a/apps/base/views.py b/apps/base/views.py index 8e553ebf2..bd22b6636 100644 --- a/apps/base/views.py +++ b/apps/base/views.py @@ -1,38 +1,40 @@ -# @Time : 2023/8/14 03:59 -# @Author : Lan -# @File : views.py -# @Software: PyCharm -# 导入所需的库和模块 from fastapi import APIRouter, Form, UploadFile, File, Depends, HTTPException -from apps.admin.depends import admin_required +from apps.admin.dependencies import admin_required from apps.base.models import FileCodes -from apps.base.pydantics import SelectFileModel +from apps.base.schemas import SelectFileModel from apps.base.utils import get_expire_info, get_file_path_name, ip_limit from core.response import APIResponse from core.settings import settings from core.storage import storages, FileStorageInterface from core.utils import get_select_token -# 创建一个API路由 -share_api = APIRouter( - prefix='/share', # 路由前缀 - tags=['分享'], # 标签 -) +share_api = APIRouter(prefix='/share', tags=['分享']) + + +async def validate_file_size(file: UploadFile, max_size: int): + if file.size > max_size: + max_size_mb = max_size / (1024 * 1024) + raise HTTPException(status_code=403, detail=f'大小超过限制,最大为{max_size_mb:.2f} MB') + + +async def create_file_code(code, **kwargs): + return await FileCodes.create(code=code, **kwargs) -# 分享文本的API @share_api.post('/text/', dependencies=[Depends(admin_required)]) -async def share_text(text: str = Form(...), expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), ip: str = Depends(ip_limit['upload'])): - # 获取大小 +async def share_text( + text: str = Form(...), + expire_value: int = Form(default=1, gt=0), + expire_style: str = Form(default='day'), + ip: str = Depends(ip_limit['upload']) +): text_size = len(text.encode('utf-8')) - # 限制 222KB - max_txt_size = 222 * 1024 # 转换为字节 + max_txt_size = 222 * 1024 if text_size > max_txt_size: - raise HTTPException(status_code=403, detail=f'内容过多,建议采用文件形式') - # 获取过期信息 + raise HTTPException(status_code=403, detail='内容过多,建议采用文件形式') + expired_at, expired_count, used_count, code = await get_expire_info(expire_value, expire_style) - # 创建一个新的FileCodes实例 - await FileCodes.create( + await create_file_code( code=code, text=text, expired_at=expired_at, @@ -41,33 +43,29 @@ async def share_text(text: str = Form(...), expire_value: int = Form(default=1, size=len(text), prefix='文本分享' ) - # 添加IP到限制列表 ip_limit['upload'].add_ip(ip) - # 返回API响应 - return APIResponse(detail={ - 'code': code, - }) + return APIResponse(detail={'code': code}) -# 分享文件的API @share_api.post('/file/', dependencies=[Depends(admin_required)]) -async def share_file(expire_value: int = Form(default=1, gt=0), expire_style: str = Form(default='day'), file: UploadFile = File(...), - ip: str = Depends(ip_limit['upload'])): - if file.size > settings.uploadSize: - # 转换为 MB 并格式化输出 - max_size_mb = settings.uploadSize / (1024 * 1024) - raise HTTPException(status_code=403, detail=f'大小超过限制,最大为{max_size_mb:.2f} MB') - # 获取过期信息 +async def share_file( + expire_value: int = Form(default=1, gt=0), + expire_style: str = Form(default='day'), + file: UploadFile = File(...), + ip: str = Depends(ip_limit['upload']) +): + await validate_file_size(file, settings.uploadSize) + if expire_style not in settings.expireStyle: raise HTTPException(status_code=400, detail='过期时间类型错误') + expired_at, expired_count, used_count, code = await get_expire_info(expire_value, expire_style) - # 获取文件路径和名称 path, suffix, prefix, uuid_file_name, save_path = await get_file_path_name(file) - # 保存文件 + file_storage: FileStorageInterface = storages[settings.file_storage]() await file_storage.save_file(file, save_path) - # 创建一个新的FileCodes实例 - await FileCodes.create( + + await create_file_code( code=code, prefix=prefix, suffix=suffix, @@ -78,69 +76,47 @@ async def share_file(expire_value: int = Form(default=1, gt=0), expire_style: st expired_count=expired_count, used_count=used_count, ) - # 添加IP到限制列表 ip_limit['upload'].add_ip(ip) - # 返回API响应 - return APIResponse(detail={ - 'code': code, - 'name': file.filename, - }) + return APIResponse(detail={'code': code, 'name': file.filename}) -# 根据code获取文件 async def get_code_file_by_code(code, check=True): - # 查询文件 file_code = await FileCodes.filter(code=code).first() - # 检查文件是否存在 if not file_code: return False, '文件不存在' - # 检查文件是否过期 if await file_code.is_expired() and check: - return False, '文件已过期', + return False, '文件已过期' return True, file_code -# 获取文件的API +async def update_file_usage(file_code): + file_code.used_count += 1 + if file_code.expired_count > 0: + file_code.expired_count -= 1 + await file_code.save() + + @share_api.get('/select/') async def get_code_file(code: str, ip: str = Depends(ip_limit['error'])): file_storage: FileStorageInterface = storages[settings.file_storage]() - # 获取文件 has, file_code = await get_code_file_by_code(code) - # 检查文件是否存在 if not has: - # 添加IP到限制列表 ip_limit['error'].add_ip(ip) - # 返回API响应 return APIResponse(code=404, detail=file_code) - # 更新文件的使用次数和过期次数 - file_code.used_count += 1 - if file_code.expired_count > 0: - file_code.expired_count -= 1 - # 保存文件 - await file_code.save() - # 返回文件响应 + + await update_file_usage(file_code) return await file_storage.get_file_response(file_code) -# 选择文件的API @share_api.post('/select/') async def select_file(data: SelectFileModel, ip: str = Depends(ip_limit['error'])): file_storage: FileStorageInterface = storages[settings.file_storage]() - # 获取文件 has, file_code = await get_code_file_by_code(data.code) - # 检查文件是否存在 if not has: - # 添加IP到限制列表 ip_limit['error'].add_ip(ip) - # 返回API响应 return APIResponse(code=404, detail=file_code) - # 更新文件的使用次数和过期次数 - file_code.used_count += 1 - if file_code.expired_count > 0: - file_code.expired_count -= 1 - # 保存文件 - await file_code.save() - # 返回API响应 + + await update_file_usage(file_code) return APIResponse(detail={ 'code': file_code.code, 'name': file_code.prefix + file_code.suffix, @@ -149,23 +125,14 @@ async def select_file(data: SelectFileModel, ip: str = Depends(ip_limit['error'] }) -# 下载文件的API @share_api.get('/download') async def download_file(key: str, code: str, ip: str = Depends(ip_limit['error'])): file_storage: FileStorageInterface = storages[settings.file_storage]() - # 检查token是否有效 - is_valid = await get_select_token(code) == key - if not is_valid: - # 添加IP到限制列表 + if await get_select_token(code) != key: ip_limit['error'].add_ip(ip) - # 获取文件 + has, file_code = await get_code_file_by_code(code, False) - # 检查文件是否存在 if not has: - # 返回API响应 return APIResponse(code=404, detail='文件不存在') - # 如果文件是文本,返回文本内容,否则返回文件响应 - if file_code.text: - return APIResponse(detail=file_code.text) - else: - return await file_storage.get_file_response(file_code) + + return APIResponse(detail=file_code.text) if file_code.text else await file_storage.get_file_response(file_code) diff --git a/core/utils.py b/core/utils.py index 5fded2acb..e5812e0ff 100644 --- a/core/utils.py +++ b/core/utils.py @@ -8,7 +8,7 @@ import string import time -from apps.base.depends import IPRateLimit +from apps.base.dependencies import IPRateLimit async def get_random_num(): diff --git a/main.py b/main.py index 0217abf9b..1d0a1e050 100644 --- a/main.py +++ b/main.py @@ -38,28 +38,30 @@ async def lifespan(app: FastAPI): # 初始化数据库 await init_db() - # 启动后台任务,不定时删除过期文件 + # 启动后台任务 task = asyncio.create_task(delete_expire_files()) - # 读取用户配置 - user_config, created = await KeyValue.get_or_create(key='settings', defaults={'value': DEFAULT_CONFIG}) + + # 加载配置 + await load_config() + + try: + yield + finally: + # 清理操作 + task.cancel() + await asyncio.gather(task, return_exceptions=True) + await Tortoise.close_connections() + + +async def load_config(): + user_config, _ = await KeyValue.get_or_create(key='settings', defaults={'value': DEFAULT_CONFIG}) settings.user_config = user_config.value + # 更新 ip_limit 配置 ip_limit['error'].minutes = settings.errorMinute ip_limit['error'].count = settings.errorCount ip_limit['upload'].minutes = settings.uploadMinute ip_limit['upload'].count = settings.uploadCount - yield - - # 清理操作 - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - # 关闭数据库连接 - await Tortoise.close_connections() - app = FastAPI(lifespan=lifespan)