From 9a45358a452401429f2bf73558cd4889b238c421 Mon Sep 17 00:00:00 2001 From: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:11:06 +0800 Subject: [PATCH 1/2] feat: delete launch feat: delete task runtime chore: delete rich from requirements --- requirements.txt | 1 - swanlab/cli/commands/__init__.py | 5 +- swanlab/cli/commands/launcher/__init__.py | 44 -- .../cli/commands/launcher/parser/__init__.py | 30 -- swanlab/cli/commands/launcher/parser/model.py | 134 ------ .../commands/launcher/parser/v1/__init__.py | 13 - .../cli/commands/launcher/parser/v1/folder.py | 231 ----------- swanlab/cli/commands/task/__init__.py | 32 -- swanlab/cli/commands/task/list.py | 207 ---------- swanlab/cli/commands/task/search.py | 75 ---- swanlab/cli/commands/task/stop.py | 29 -- swanlab/cli/commands/task/utils.py | 114 ----- swanlab/cli/commands/uploader/__init__.py | 153 ------- swanlab/cli/main.py | 14 +- swanlab/cli/utils.py | 182 -------- swanlab/data/callback_cloud.py | 11 +- swanlab/env.py | 2 +- test/unit/cli/test_cli_launch.py | 389 ------------------ test/unit/cli/test_cli_task.py | 39 -- test/unit/cli/test_cli_utils.py | 31 -- test/unit/data/test_sdk.py | 14 +- 21 files changed, 15 insertions(+), 1735 deletions(-) delete mode 100644 swanlab/cli/commands/launcher/__init__.py delete mode 100644 swanlab/cli/commands/launcher/parser/__init__.py delete mode 100644 swanlab/cli/commands/launcher/parser/model.py delete mode 100644 swanlab/cli/commands/launcher/parser/v1/__init__.py delete mode 100644 swanlab/cli/commands/launcher/parser/v1/folder.py delete mode 100644 swanlab/cli/commands/task/__init__.py delete mode 100644 swanlab/cli/commands/task/list.py delete mode 100644 swanlab/cli/commands/task/search.py delete mode 100644 swanlab/cli/commands/task/stop.py delete mode 100644 swanlab/cli/commands/task/utils.py delete mode 100644 swanlab/cli/commands/uploader/__init__.py delete mode 100644 swanlab/cli/utils.py delete mode 100644 test/unit/cli/test_cli_launch.py delete mode 100644 test/unit/cli/test_cli_task.py delete mode 100644 test/unit/cli/test_cli_utils.py diff --git a/requirements.txt b/requirements.txt index 2649b1658..405d08333 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,3 @@ click pyyaml psutil>=5.0.0 pynvml -rich diff --git a/swanlab/cli/commands/__init__.py b/swanlab/cli/commands/__init__.py index 4b4e59f65..ff146cb0b 100644 --- a/swanlab/cli/commands/__init__.py +++ b/swanlab/cli/commands/__init__.py @@ -8,8 +8,5 @@ 暴露子命令 """ from .auth import login, logout -from .dashboard import watch from .converter import convert -from .task import task -from .launcher import launch -from .uploader import upload +from .dashboard import watch diff --git a/swanlab/cli/commands/launcher/__init__.py b/swanlab/cli/commands/launcher/__init__.py deleted file mode 100644 index f648bc6ab..000000000 --- a/swanlab/cli/commands/launcher/__init__.py +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:27 -@File: __init__.py -@IDE: pycharm -@Description: - 专注于任务启动 -""" -import click -import yaml -from .parser import parse -import os - -__all__ = ['launch'] - - -@click.command() -@click.option( - '--file', - '-f', - default='swanlab.yml', - type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True), - help='Designated file to launch', -) -@click.option( - '--dry-run', - is_flag=True, - default=False, - help='Execute commands without applying changes, only outputting the operations that would be performed.', -) -def launch(file: str, dry_run: bool): - """ - Launch a task - """ - file = os.path.abspath(file) - config = yaml.safe_load(open(file, 'r', encoding='utf-8')) - if not isinstance(config, dict): - raise click.FileError(file, hint='Invalid configuration file') - p = parse(config, file) - p.parse() - if dry_run: - return p.dry_run() - p.run() diff --git a/swanlab/cli/commands/launcher/parser/__init__.py b/swanlab/cli/commands/launcher/parser/__init__.py deleted file mode 100644 index 28620baf3..000000000 --- a/swanlab/cli/commands/launcher/parser/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:29 -@File: __init__.py.py -@IDE: pycharm -@Description: - 解析配置文件 -""" -from . import v1 -from .model import LaunchParser -from typing import Dict, List -import click - -__all__ = ['parse'] - -parsers: Dict[str, List[LaunchParser.__class__]] = { - 'swanlab/v1': v1.parsers -} - - -def parse(config: dict, path: str) -> LaunchParser: - version = config.get("apiVersion") - if not parsers.get(version): - raise click.UsageError(f"Unknown api version: {version}") - kind = config.get("kind") - for parser in parsers[version]: - if parser.__type__() == kind: - return parser(config, path) - raise click.UsageError(f"Unknown kind: {kind}") diff --git a/swanlab/cli/commands/launcher/parser/model.py b/swanlab/cli/commands/launcher/parser/model.py deleted file mode 100644 index c3c99f453..000000000 --- a/swanlab/cli/commands/launcher/parser/model.py +++ /dev/null @@ -1,134 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:52 -@File: model.py -@IDE: pycharm -@Description: - 基础模型 -""" -import click -from typing import Type, Any -from abc import ABC, abstractmethod -import os - - -class LaunchParser(ABC): - - def __init__(self, config: dict, path: str): - self.config = config - """ - 配置文件内容 - """ - self.dirpath = os.path.dirname(path) - """ - 配置文件所在目录 - """ - - @staticmethod - def should_be(n: str, v: Any, t: Type, none: bool = False) -> Any: - """ - 检查参数是否符合预期 - :param n: 参数名 - :param v: 参数值 - :param t: 预期类型 - :param none: 是否允许为None - """ - if none and v is None: - return None - if v is None and not none: - raise click.BadParameter(f'{n} should not be None') - if not isinstance(v, t): - raise click.BadParameter(f'{n} should be {t}, not {type(v)}') - return v - - def should_file_exist(self, n: str, p: str): - """ - 检查文件是否存在,必须是文件 - :param n: 参数名 - :param p: 文件路径 - """ - p = os.path.join(self.dirpath, p) - if not os.path.exists(p): - raise click.FileError(p, hint=f'{n} not found: {p}') - if not os.path.isfile(p): - raise click.FileError(p, hint=f'{n} should be a file: {p}') - - @staticmethod - def should_in_values(n: str, v: Any, ls: list) -> Any: - """ - 检查参数是否在指定范围内 - :param n: 参数名 - :param v: 参数值 - :param ls: 范围 - """ - if v not in ls: - raise click.BadParameter(f'{n} should be in {ls}, not {v}') - return v - - @staticmethod - def should_equal_keys(n: str, v: dict, keys: list) -> Any: - """ - 确保字典的key在指定范围内 - :param n: 参数名 - :param v: 参数值 - :param keys: 范围 - """ - for k in v.keys(): - if k not in keys: - raise click.BadParameter(f'Unknown key: {n}.{k}') - return v - - @classmethod - @abstractmethod - def __type__(cls): - """ - 返回当前解析器的名称,对应到kind参数 - """ - pass - - @abstractmethod - def __dict__(self) -> dict: - """ - 最终向后端发布任务的data数据 - """ - pass - - @abstractmethod - def parse_spec(self, spec: dict): - """ - 解析spec数据 - """ - pass - - @abstractmethod - def parse_metadata(self, metadata: dict): - """ - 解析metadata数据 - """ - pass - - def parse(self): - """ - 解析整个配置文件 - """ - metadata = self.should_be('metadata', self.config.get('metadata'), dict) - self.should_equal_keys('metadata', metadata, ['name', 'desc', 'combo']) - self.parse_metadata(metadata) - spec = self.should_be('spec', self.config.get('spec'), dict) - self.should_equal_keys('spec', spec, ['python', 'entry', 'volumes', 'exclude']) - self.parse_spec(spec) - - @abstractmethod - def run(self): - """ - 执行具体的操作 - """ - pass - - @abstractmethod - def dry_run(self): - """ - 单纯向用户展示即将执行的操作,而不实际应用 - """ - pass diff --git a/swanlab/cli/commands/launcher/parser/v1/__init__.py b/swanlab/cli/commands/launcher/parser/v1/__init__.py deleted file mode 100644 index e44357e8f..000000000 --- a/swanlab/cli/commands/launcher/parser/v1/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 15:30 -@File: __init__.py.py -@IDE: pycharm -@Description: - 解析:v1版本 -""" - -from .folder import FolderParser - -parsers = [FolderParser] diff --git a/swanlab/cli/commands/launcher/parser/v1/folder.py b/swanlab/cli/commands/launcher/parser/v1/folder.py deleted file mode 100644 index 47fbeb1b8..000000000 --- a/swanlab/cli/commands/launcher/parser/v1/folder.py +++ /dev/null @@ -1,231 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:52 -@File: folder.py -@IDE: pycharm -@Description: - 文件夹上传模型 -""" -from typing import List, Tuple -import click -from ..model import LaunchParser -from swanlab.error import ApiError -from swanlab.cli.utils import login_init_sid, UseTaskHttp, CosUploader, UploadBytesIO -import zipfile -from rich.progress import ( - BarColumn, - Progress, - TextColumn, - TimeRemainingColumn, -) -from swankit.log import FONT -from rich import print as rprint -from rich.filesize import decimal -from rich.text import Text -import time -import glob -import os -import io - - -class FolderParser(LaunchParser): - """ - 承担了上传文件夹的一系列任务 - """ - - def __init__(self, config: dict, path: str): - super().__init__(config, path) - self.metadata = {"name": '', "desc": '', "combo": ''} - self.spec = {"python": '', "entry": '', "volumes": [], "exclude": []} - self.key = None - """ - 上传到cos的路径 - """ - self.api_key = None - """ - 用户的api_key - """ - - @classmethod - def __type__(cls): - return 'Folder' - - def __dict__(self) -> dict: - data = { - "src": self.key, - "index": self.spec['entry'], - "python": self.spec['python'], - "conf": {"key": self.api_key}, - "name": self.metadata['name'], - } - if self.metadata.get("desc"): - data["desc"] = self.metadata['desc'] - if self.metadata.get("combo"): - data["combo"] = self.metadata['combo'] - if len(self.spec["volumes"]) > 0: - data['datasets'] = [v['id'] for v in self.spec['volumes']] - return data - - def parse_metadata(self, metadata: dict): - name = self.should_be('metadata.name', metadata.get('name'), str) - desc = self.should_be('metadata.desc', metadata.get('desc'), str, none=True) - combo = self.should_be('metadata.combo', metadata.get('combo'), str, none=True) - self.metadata['name'] = name - self.metadata['desc'] = desc - self.metadata['combo'] = combo - - def parse_spec(self, spec: dict): - python = self.should_be('spec.python', spec.get('python'), str, none=True) or '3.10' - python = self.should_in_values('spec.python', python, ['3.11', '3.10', '3.9', '3.8']) - entry = self.should_be('spec.entry', spec.get('entry'), str, none=True) or 'train.py' - self.should_file_exist('spec.entry', entry) - volumes = self.should_be('spec.volumes', spec.get('volumes'), list, none=True) or [] - # NOTE 当前后端只支持一个volume - import click - - if len(volumes) > 1: - raise click.BadParameter('Only one volume is supported') - - for volume in volumes: - self.should_equal_keys('volume', volume, ['name', 'id']) - self.should_be('volume.name', volume.get('name'), str, none=True) - self.should_be('volume.id', volume.get('id'), str) - exclude = self.should_be('spec.exclude', spec.get('exclude'), list, none=True) or [] - [self.should_be('exclude', e, str) for e in exclude] - - self.spec['python'] = 'python' + python - self.spec['entry'] = entry - self.spec['volumes'] = volumes - self.spec['exclude'] = exclude - - def walk(self, path: str = None) -> Tuple[List[str], List[str]]: - """ - 遍历path,生成文件列表,注意排除exclude中的文件 - 此函数为递归调用函数 - 返回所有命中的文件列表和排除的文件列表 - """ - path = path or self.dirpath - all_files = glob.glob(os.path.join(path, '**')) - exclude_files = [] - for g in self.spec['exclude']: - efs = glob.glob(os.path.join(path, g)) - exclude_files.extend(efs) - exclude_files = list(set(exclude_files)) - files = [] - for f in all_files: - if os.path.isdir(f): - if f in exclude_files: - continue - fs, efs = self.walk(f) - files.extend(fs) - exclude_files.extend(efs) - else: - if f in exclude_files: - continue - files.append(f) - return files, exclude_files - - def zip(self, files: List[str]) -> io.BytesIO: - """ - 将walk得到的文件列表压缩到memory_file中 - """ - memory_file = io.BytesIO() - progress = Progress( - TextColumn("{task.description}", justify="left"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - TimeRemainingColumn(), - ) - z = zipfile.ZipFile(memory_file, "w", zipfile.ZIP_DEFLATED) - with progress: - for i in progress.track(range(len(files)), description=FONT.swanlab("Packing... ")): - arcname = os.path.relpath(files[i], start=self.dirpath) - z.write(files[i], arcname) - memory_file.seek(0) - return memory_file - - def upload(self, memory_file: io.BytesIO): - """ - 上传压缩文件 - """ - val = memory_file.getvalue() - client, sts = CosUploader.create() - self.key = sts['prefix'] + "/tasks/" + f"{int(time.time() * 1000)}.zip" - with UploadBytesIO(FONT.swanlab("Uploading..."), val) as buffer: - client.upload_file_from_buffer( - Bucket=sts['bucket'], - Key=self.key, - Body=buffer, - MAXThread=5, - MaxBufferSize=5, - PartSize=1, - ) - - def run(self): - # 剔除、压缩、上传、发布任务 - files, _ = self.walk() - if len(files) == 0: - raise click.BadParameter(self.dirpath + " is empty") - login_info = login_init_sid() - print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!") - self.api_key = login_info.api_key - memory_file = self.zip(files) - self.upload(memory_file) - with UseTaskHttp() as http: - try: - http.post("/task", data=self.__dict__()) - except ApiError as e: - if e.resp.status_code not in [404, 401]: - raise e - elif e.resp.status_code == 404: - raise click.BadParameter("The dataset does not exist") - else: - raise click.BadParameter("The combo does not exist") - print( - FONT.swanlab( - f"Launch task successfully, use {FONT.bold(FONT.default('swanlab task list'))} to view the task" - ) - ) - - def dry_run(self): - # 剔除、显示即将发布的任务的相关信息 - # 1. 任务名称 - # 2. 任务描述 - # 3. 任务套餐 - # 4. 任务python版本 - # 5. 任务入口文件路径 - # 6. 上传的任务文件夹路径 - # 7. 上传的任务文件夹中忽略的文件列表 - # 8. 绑定的数据卷信息 - _, exclude_files = self.walk() - print(FONT.swanlab("This task will be launched:")) - rprint("[bold]Name: [/bold]" + self.metadata['name']) - rprint("[bold]Description: [/bold]" + self.metadata['desc']) - rprint("[bold]Combo: [/bold]" + self.metadata['combo'] or 'Default') - rprint("[bold]Python: [/bold]" + self.spec['python']) - rprint("[bold]Entry: [/bold]" + self.spec['entry']) - rprint("[bold]Folder: [/bold]" + self.dirpath) - rprint("[bold]Excluded files: [/bold]") - files = [] - for ef in exclude_files: - if os.path.isdir(ef): - suffix = ' 📁 ' - index = 1 # 文件夹在前 - size = 0 - for path, dirs, ef_files in os.walk(ef): - for _ in ef_files: - fp = os.path.join(path, _) - size += os.path.getsize(fp) - else: - suffix = " 🐍 " if ef.endswith('.py') else " 📄 " - index = 2 - size = os.path.getsize(ef) - ef = os.path.relpath(ef, start=self.dirpath) - files.append({'str': f'{ef} ({decimal(size)})', 'index': index, 'icon': Text(suffix)}) - files.sort(key=lambda x: x['str']) - files.sort(key=lambda x: x['index']) - for f in files: - rprint(f['icon'] + f['str']) - rprint("[bold]Volumes: [/bold]" + (str(self.spec['volumes']) if len(self.spec['volumes']) > 0 else 'None')) diff --git a/swanlab/cli/commands/task/__init__.py b/swanlab/cli/commands/task/__init__.py deleted file mode 100644 index 9a1c6a5ea..000000000 --- a/swanlab/cli/commands/task/__init__.py +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/7/17 17:16 -@File: __init__.py.py -@IDE: pycharm -@Description: - 启动! - beta版 -""" -from .list import list -from .search import search -from .stop import stop -import click - -__all__ = ["task"] - - -@click.group() -def task(): - """ - Beta Function: List, modify, query task information. - """ - pass - - -# noinspection PyTypeChecker -task.add_command(list) -# noinspection PyTypeChecker -task.add_command(search) -# noinspection PyTypeChecker -task.add_command(stop) diff --git a/swanlab/cli/commands/task/list.py b/swanlab/cli/commands/task/list.py deleted file mode 100644 index e924c8a5b..000000000 --- a/swanlab/cli/commands/task/list.py +++ /dev/null @@ -1,207 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/7/19 14:09 -@File: status.py -@IDE: pycharm -@Description: - 列出任务状态 -""" -import time -from datetime import datetime -from typing import List - -import click -from rich.layout import Layout -from rich.live import Live -from rich.markdown import Markdown -from rich.panel import Panel -from rich.table import Table - -from swanlab.cli.utils import login_init_sid, UseTaskHttp -from .utils import TaskModel - - -@click.command() -@click.option( - "--max-num", - "-m", - default=10, - nargs=1, - type=click.IntRange(1, 100), - help="The maximum number of tasks to display, default by 10, maximum by 100", -) -def list(max_num: int): # noqa - """ - List tasks - """ - # 获取访问凭证,生成http会话对象 - login_info = login_init_sid() - # 获取任务列表 - ltm = ListTasksModel(num=max_num, username=login_info.username) - aqm = AskQueueModel() - layout = ListTaskLayout(ltm, aqm) - layout.show() - - -class AskQueueModel: - def __init__(self): - self.num = None - - def ask(self): - with UseTaskHttp() as http: - queue_info = http.get("/task/queuing") - self.num = queue_info["sum"] - - def table(self): - qi = Table( - expand=True, - show_header=False, - header_style="bold", - title="[blue][b]Now Global Queue[/b]", - highlight=True, - border_style="blue", - ) - qi.add_column("Queue Info", "Queue Info") - self.ask() - qi.add_row(f"[b]Task Queuing count: {self.num}[/b]") - return qi - - -class ListTasksModel: - def __init__(self, num: int, username: str): - """ - :param num: 最大显示的任务数 - """ - self.num = num - self.username = username - - def __dict__(self): - return {"num": self.num} - - def list(self) -> List[TaskModel]: - with UseTaskHttp() as http: - tasks = http.get("/task", self.__dict__()) - return [TaskModel(self.username, task) for task in tasks] - - def table(self): - st = Table( - expand=True, - show_header=True, - title="[magenta][b]Now Task[/b]", - highlight=True, - border_style="magenta", - show_lines=True, - ) - st.add_column("Task ID", justify="right", ratio=1) - st.add_column("Task Name", justify="center", ratio=1) - st.add_column("Status", justify="center", ratio=1) - st.add_column("URL", justify="center", no_wrap=True, ratio=2) - st.add_column("Output URL", justify="center", ratio=1) - st.add_column("Output Size", justify="center", ratio=1) - st.add_column("Started Time", justify="center", ratio=2) - st.add_column("Finished Time", justify="center", ratio=2) - for tlm in self.list(): - status = tlm.status - if status == "COMPLETED": - status = f"[green]{status}[/green]" - elif status == "CRASHED": - status = f"[red]{status}[/red]" - output_url = tlm.output.output_url - if output_url is not None: - output_url = Markdown(f"[{tlm.output.path}]({output_url})", justify="center") - url = tlm.url - if url is not None: - url = Markdown(f"[{tlm.project_name}]({url})", justify="center") - st.add_row( - tlm.cuid, - tlm.name, - status, - url, - output_url, - tlm.output.size, - tlm.started_at, - tlm.finished_at, - ) - return st - - -class ListTaskHeader: - """ - Display header with clock. - """ - - @staticmethod - def __rich__() -> Panel: - grid = Table.grid(expand=True) - grid.add_column(justify="center", ratio=1) - grid.add_column(justify="right") - grid.add_row( - "[b]SwanLab[/b] task dashboard", - datetime.now().ctime().replace(":", "[blink]:[/]"), - ) - return Panel(grid, style="red on black") - - -class ListTaskLayout: - """ - 任务列表展示例如,如果有 3 列,总比率为 6,而比率=2,那么列的大小将是可用大小的三分之一。 - """ - - def __init__(self, ltm: ListTasksModel, aqm: AskQueueModel): - self.layout = Layout() - self.layout.split(Layout(name="header", size=3), Layout(name="main")) - self.layout["main"].split_row(Layout(name="task_table", ratio=16), Layout(name="info_side", ratio=7)) - self.layout["info_side"].split_column(Layout(name="queue_info", ratio=1), Layout(name="datasets_list", ratio=5)) - self.layout["header"].update(ListTaskHeader()) - self.layout["task_table"].update(Panel(ltm.table(), border_style="magenta")) - self.layout["queue_info"].update(Panel(aqm.table(), border_style="blue")) - self.ltm = ltm - self.aqm = aqm - self.redraw_datasets_list() - - @property - def datasets_list(self): - to = Table( - expand=True, - show_header=True, - header_style="bold", - title="[blue][b]Datasets[/b]", - highlight=True, - show_lines=True, - border_style="blue", - ) - to.add_column("Dataset ID", justify="right") - to.add_column("Dataset Name", justify="center") - to.add_column("Dataset Desc", justify="center") - to.add_column("Created Time", justify="center") - return to - - def redraw_datasets_list(self): - datasets_list = self.datasets_list - with UseTaskHttp() as http: - datasets = http.get("/task/datasets") - for dataset in datasets: - datasets_list.add_row( - dataset["cuid"], - dataset["name"], - dataset.get("desc", ""), - TaskModel.fmt_time(dataset["createdAt"]), - ) - self.layout["datasets_list"].update(Panel(datasets_list, border_style="blue")) - - def show(self): - with Live(self.layout, refresh_per_second=10, screen=True) as live: - search_now = time.time() - queue_now = time.time() - while True: - time.sleep(1) - self.layout["header"].update(ListTaskHeader()) - if time.time() - search_now > 5: - search_now = time.time() - self.layout["task_table"].update(Panel(self.ltm.table(), border_style="magenta")) - self.redraw_datasets_list() - if time.time() - queue_now > 3: - queue_now = time.time() - self.layout["queue_info"].update(Panel(self.aqm.table(), border_style="blue")) - live.refresh() diff --git a/swanlab/cli/commands/task/search.py b/swanlab/cli/commands/task/search.py deleted file mode 100644 index 87673372b..000000000 --- a/swanlab/cli/commands/task/search.py +++ /dev/null @@ -1,75 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/7/26 17:22 -@File: detail.py -@IDE: pycharm -@Description: - 根据cuid获取任务详情 -""" -import click -from rich.markdown import Markdown -from rich.syntax import Console, Syntax - -from swanlab.cli.utils import login_init_sid, UseTaskHttp -from swanlab.error import ApiError -from .utils import TaskModel - - -def validate_six_char_string(_, __, value): - if value is None: - raise click.BadParameter('Parameter is required') - if not isinstance(value, str): - raise click.BadParameter('Value must be a string') - if len(value) != 6: - raise click.BadParameter('String must be exactly 6 characters long') - return value - - -@click.command() -@click.argument("cuid", type=str, callback=validate_six_char_string) -def search(cuid): - """ - Get task detail by cuid - """ - login_info = login_init_sid() - with UseTaskHttp() as http: - try: - data = http.get(f"/task/{cuid}") - except ApiError as e: - if e.resp.status_code == 404: - raise click.BadParameter("Task not found") - tm = TaskModel(login_info.username, data) - """ - 任务名称,python版本,入口文件,任务状态,URL,创建时间,执行时间,结束时间,错误信息 - """ - console = Console() - print("") - console.print("[bold]Task Info[/bold]") - console.print(f"[bold]Task Name:[/bold] [yellow]{tm.name}[/yellow]") - console.print(f"[bold]Python Version:[/bold] [white]{tm.python}[white]") - console.print(f"[bold]Entry File:[/bold] [white]{tm.index}[white]") - icon = '✅' - if tm.status == 'CRASHED': - icon = '❌' - elif tm.status == 'STOPPED': - icon = '🛑' - elif tm.status != 'COMPLETED': - icon = '🏃' - console.print(f"[bold]Status:[/bold] {icon} {tm.status}") - console.print(f"[bold]Combo:[/bold] [white]{tm.combo}[/white]") - # dataset - for dataset in data.get("datasets", []): - console.print(f"[bold]Dataset ID:[/bold] [white]{dataset['cuid']}[/white]") - - tm.url is not None and console.print(Markdown(f"**SwanLab URL:** [{tm.project_name}]({tm.url})")) - if tm.output.path is not None: - console.print(Markdown(f"**Output URL**: [{tm.output.path}]({tm.output.output_url})")) - console.print(f"[bold]Output Size:[/bold] {tm.output.size}") - console.print(f"[bold]Created At:[/bold] {tm.created_at}") - tm.started_at is not None and console.print(f"[bold]Started At:[/bold] {tm.started_at}") - tm.finished_at is not None and console.print(f"[bold]Finished At:[/bold] {tm.finished_at}") - if tm.status == 'CRASHED': - console.print(f"[bold][red]Task Error[/red]:[/bold]\n") - console.print(Syntax(tm.msg, 'python')) - print("") # 加一行空行,与开头一致 diff --git a/swanlab/cli/commands/task/stop.py b/swanlab/cli/commands/task/stop.py deleted file mode 100644 index fe8fd0c24..000000000 --- a/swanlab/cli/commands/task/stop.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/7/30 16:13 -@File: stop.py -@IDE: pycharm -@Description: - 停止任务 -""" -import click -from swanlab.cli.utils import login_init_sid, UseTaskHttp -from .utils import validate_six_char_string -from swanlab.error import ApiError - - -@click.command() -@click.argument("cuid", type=str, callback=validate_six_char_string) -def stop(cuid): - """ - Stop a task by cuid - """ - login_init_sid() - with UseTaskHttp() as http: - try: - http.patch(f"/task/status", {"cuid": cuid, "status": "STOPPED", "msg": "User stopped by sdk"}) - except ApiError as e: - if e.resp.status_code == 404: - raise click.BadParameter("Task not found") - click.echo("Task stopped successfully, there may be a few minutes of delay online.") diff --git a/swanlab/cli/commands/task/utils.py b/swanlab/cli/commands/task/utils.py deleted file mode 100644 index 1c5439ce6..000000000 --- a/swanlab/cli/commands/task/utils.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/7/19 14:30 -@File: utils.py -@IDE: pycharm -@Description: - 任务相关工具函数 -""" -import time -from datetime import datetime, timedelta - -import click - -from swanlab.cli.utils import CosUploader -from swanlab.package import get_experiment_url - - -def validate_six_char_string(_, __, value): - if value is None: - raise click.BadParameter('Parameter is required') - if not isinstance(value, str): - raise click.BadParameter('Value must be a string') - if len(value) != 6: - raise click.BadParameter('String must be exactly 6 characters long') - return value - - -class TaskModel: - """ - 获取到的任务列表模型 - """ - - def __init__(self, username: str, task: dict): - self.cuid = task["cuid"] - self.username = username - self.name = task["name"] - """ - 任务名称 - """ - self.python = task["python"] - """ - 任务的python版本 - """ - self.index = task["index"] - """ - 任务的入口文件 - """ - self.project_name = task.get("pName", None) - """ - 项目名称 - """ - self.experiment_id = task.get("eId", None) - """ - 实验ID - """ - self.created_at = self.fmt_time(task["createdAt"]) - self.started_at = self.fmt_time(task.get("startedAt", None)) - self.finished_at = self.fmt_time(task.get("finishedAt", None)) - self.status = task["status"] - self.msg = task.get("msg", None) - self.combo = task["combo"] - self.output = OutputModel(self.cuid, task.get("output", {})) - - @property - def url(self): - if self.project_name is None or self.experiment_id is None: - return None - return get_experiment_url(self.username, self.project_name, self.experiment_id) - - @staticmethod - def fmt_time(date: str = None): - if date is None: - return None - date = date.replace("Z", "+00:00") - # 获取当前计算机时区的时差(以秒为单位) - local_time_offset = time.altzone if time.localtime().tm_isdst else time.timezone - local_time_offset = timedelta(seconds=-local_time_offset) - local_time = datetime.fromisoformat(date) + local_time_offset - # 将 UTC 时间转换为本地时间 - return local_time.strftime("%Y-%m-%d %H:%M:%S") - - -class OutputModel: - """ - 任务输出模型 - """ - - def __init__(self, cuid: str, output: dict): - self.cuid = cuid - self.path = output.get('path', None) - self.size = self.fmt_size(output.get('size', None)) - - @property - def output_url(self): - """获取预签名的输出下载 url (过期时间 1 小时)""" - if self.path is None: - return None - uploader = CosUploader() - key = f"{uploader.prefix}/outputs/{self.path}" - return uploader.client.get_presigned_download_url( - Bucket=uploader.bucket, Key=key, Params={'x-cos-security-token': uploader.token}, Expired=3600 - ) - - @staticmethod - def fmt_size(size: int = None): - if size is None: - return None - units = ['Byte', 'KB', 'MB', 'GB', 'TB'] - unit = 0 - while size >= 1024: - size /= 1024 - unit += 1 - return f"{size:.2f} {units[unit]}" diff --git a/swanlab/cli/commands/uploader/__init__.py b/swanlab/cli/commands/uploader/__init__.py deleted file mode 100644 index 1bc8cfe2a..000000000 --- a/swanlab/cli/commands/uploader/__init__.py +++ /dev/null @@ -1,153 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/26 15:02 -@File: __init__.py.py -@IDE: pycharm -@Description: - 上传模块 -""" -import os -import pathlib -import sys -import threading -import time - -import click -# noinspection PyPackageRequirements -from qcloud_cos.cos_exception import CosClientError, CosServiceError -# noinspection PyPackageRequirements -from qcloud_cos.cos_threadpool import SimpleThreadPool -from rich.progress import ( - BarColumn, - Progress, - TextColumn, - TimeRemainingColumn, -) -from swankit.log import FONT - -from swanlab.cli.utils import login_init_sid, UseTaskHttp, CosUploader - - -class FolderProgress: - def __init__(self, total_size: int): - self.progress = Progress( - TextColumn("{task.description}", justify="left"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - TimeRemainingColumn(), - ) - self.current = 0 - self.total_size = total_size - self.running = True - - def start(self, description: str): - with self.progress as progress: - for i in progress.track(range(self.total_size), description=description): - if not self.running: - break - if self.current > i: - continue - time.sleep(0.5) - while True: - if self.current > i or not self.running: - break - - def increase(self): - self.current += 1 - - def stop(self): - self.running = False - - -class UploadFolderHandler: - def __init__(self, uploader: CosUploader, progress: FolderProgress, retry=10): - self.uploader = uploader - self.progress = progress - self.retry = retry - self.error = None - - def __call__(self, path: str, key: str): - if self.error is not None: - return - if self.uploader.should_refresh: - self.uploader.refresh() - error = None - for i in range(self.retry): - try: - self.uploader.client.upload_file( - Bucket=self.uploader.bucket, - Key=key, - LocalFilePath=path, - ) - return self.progress.increase() - except CosClientError or CosServiceError as e: - error = e - continue - except Exception as e: - error = e - break - if error is not None: - self.error = error - raise error - - -@click.command() -@click.argument("path", type=click.Path(exists=True, file_okay=False, dir_okay=True, readable=True)) -@click.option("--name", "-n", type=str, help="Name of the dataset to be uploaded.") -@click.option("--desc", "-d", type=str, help="Description of the dataset to be uploaded.") -def upload(path, name: str, desc: str): - """ - Upload your 'dataset' to the cloud to - accelerate task initiation speed. - """ - path = os.path.abspath(path) - name = name or os.path.basename(path) - login_init_sid() - - # 创建数据集索引 - with UseTaskHttp() as http: - data = {"name": name, "desc": desc} if desc else {"name": name} - dataset = http.post("/task/dataset", data=data) - - # 创建上传对象 - uploader = CosUploader() - cuid = dataset['cuid'] - prefix = f"{uploader.prefix}/datasets/{cuid}/{name}" - - # 上传文件夹 - if os.path.isdir(path): - total_size = 0 - for root, dirs, files in os.walk(path): - total_size += len(files) - progress = FolderProgress(total_size) - t = threading.Thread(target=progress.start, args=("Uploading...",)) - t.start() - pool = SimpleThreadPool(5) - handler = UploadFolderHandler(uploader, progress) - # 遍历,添加任务 - for root, dirs, files in os.walk(path): - for file in files: - local_path = os.path.join(root, file).__str__() - # 生成key - key = prefix - tmp = os.path.relpath(local_path, start=path) - path_obj = pathlib.Path(tmp.__str__()) - folders = [parent for parent in path_obj.parents] - folders.reverse() - for folder in folders[1:]: - key += "/" + folder.name - key += "/" + path_obj.name - pool.add_task(handler, local_path, key) - pool.wait_completion() - result = pool.get_result() - progress.stop() - t.join() - if not result['success_all']: - print("Not all files upload success. you should retry, Error: {}".format(handler.error), file=sys.stderr) - status = "FAILURE" - else: - status = "SUCCESS" - print(FONT.swanlab("Upload success, dataset id: {}".format(FONT.bold(cuid)))) - return http.patch("/task/dataset/status", {"cuid": cuid, "status": status}) diff --git a/swanlab/cli/main.py b/swanlab/cli/main.py index c693b15cb..76a3cbe6e 100644 --- a/swanlab/cli/main.py +++ b/swanlab/cli/main.py @@ -7,10 +7,11 @@ @Description: swanlab脚本命令的主入口 """ -from swanlab.package import get_package_version -import swanlab.cli.commands as C import click +import swanlab.cli.commands as C +from swanlab.package import get_package_version + @click.group(invoke_without_command=True) @click.version_option(get_package_version(), "--version", "-v", message="SwanLab %(version)s") @@ -32,15 +33,6 @@ def cli(): # noinspection PyTypeChecker cli.add_command(C.convert) # 转换命令,用于转换其他实验跟踪工具 -# noinspection PyTypeChecker -cli.add_command(C.task) # 列出、停止、查询任务接口(beta,后续可能删除) - -# noinspection PyTypeChecker -cli.add_command(C.launch) # 启动任务 - -# noinspection PyTypeChecker -cli.add_command(C.upload) # 上传文件 - if __name__ == "__main__": cli() diff --git a/swanlab/cli/utils.py b/swanlab/cli/utils.py deleted file mode 100644 index d9eb12df0..000000000 --- a/swanlab/cli/utils.py +++ /dev/null @@ -1,182 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/25 14:33 -@File: utils.py -@IDE: pycharm -@Description: - 一些工具函数 -""" -import io -import sys -import threading -import time -from datetime import datetime, timedelta -from typing import Optional, Tuple - -# noinspection PyPackageRequirements -from qcloud_cos import CosConfig, CosS3Client -from rich.progress import ( - BarColumn, - DownloadColumn, - Progress, - TextColumn, - TimeRemainingColumn, - TransferSpeedColumn, -) - -from swanlab.api import terminal_login, create_http, LoginInfo, get_http -from swanlab.error import KeyFileError, ApiError -from swanlab.log import swanlog -from swanlab.package import get_key - - -class UseTaskHttp: - """ - 主要用于检测http响应是否为3xx字段,如果是则要求用户更新版本 - 使用此类之前需要先调用login_init_sid()函数完成全局http对象的初始化 - """ - - def __init__(self): - self.http = get_http() - - def __enter__(self): - return self.http - - def __exit__(self, exc_type, exc_val: Optional[ApiError], exc_tb): - if exc_type is ApiError: - # api已过期,需要更新swanlab版本 - if exc_val.resp.status_code // 100 == 3: - swanlog.info("SwanLab in your environment is outdated. Upgrade: `pip install -U swanlab`") - sys.exit(3) - return False - - -def login_init_sid() -> LoginInfo: - key = None - try: - key = get_key() - except KeyFileError: - pass - login_info = terminal_login(key) - create_http(login_info) - return login_info - - -class UploadBytesIO(io.BytesIO): - """ - 封装BytesIO,使其可以在上传文件时显示进度条 - """ - - class UploadProgressBar: - def __init__(self, total_size: int): - """ - :param total_size: 总大小(bytes) - """ - self.total_size = total_size - self.current = 0 - self.progress = Progress( - TextColumn("{task.description}", justify="left"), - BarColumn(), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - DownloadColumn(), - "•", - TransferSpeedColumn(), - "•", - TimeRemainingColumn(), - ) - - def update(self, *args): - self.current += args[0] - - def start(self, description: str): - with self.progress as progress: - for i in progress.track(range(self.total_size), description=description): - if self.current > i: - continue - time.sleep(0.5) - while True: - if self.current > i: - break - - def __init__(self, description: str, *args, **kwargs): - super().__init__(*args, **kwargs) - self.upload_progress = self.UploadProgressBar(len(self.getvalue())) - self.t = None - self.description = description - - def read(self, *args): - self.upload_progress.update(*args) - return super().read(*args) - - def __enter__(self): - self.t = threading.Thread(target=self.upload_progress.start, args=(self.description,)) - self.t.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.t.join() - return False - - -class CosUploader: - REFRESH_TIME = 60 * 60 * 1.5 # 1.5小时 - - def __init__(self): - """初始化 cos""" - client, sts = self.create() - self.__expired_time = datetime.fromtimestamp(sts["expiredTime"]) - self.prefix = sts["prefix"] - self.bucket = sts["bucket"] - self.client = client - self.__updating = False - """ - 标记是否正在更新sts - """ - self.__token = sts["credentials"]["sessionToken"] # 临时密钥使用的 token - - @property - def token(self): - if self.should_refresh: - self.refresh() - return self.__token - - @property - def should_refresh(self): - # cos传递的是北京时间,需要添加8小时 - now = datetime.utcnow() + timedelta(hours=8) - # 过期时间减去当前时间小于刷新时间,需要注意为负数的情况 - if self.__expired_time < now: - return True - return (self.__expired_time - now).seconds < self.REFRESH_TIME - - @staticmethod - def create() -> Tuple[CosS3Client, dict]: - with UseTaskHttp() as http: - sts = http.get("/user/codes/sts") - region = sts["region"] - token = sts["credentials"]["sessionToken"] - secret_id = sts["credentials"]["tmpSecretId"] - secret_key = sts["credentials"]["tmpSecretKey"] - config = CosConfig(Region=region, SecretId=secret_id, SecretKey=secret_key, Token=token, Scheme="https") - client = CosS3Client(config) - return client, sts - - def refresh(self): - """ - 更新sts - """ - # 防止多线程更新sts - if self.__updating: - while self.__updating: - time.sleep(1) - return - - self.__updating = True - client, sts = self.create() - self.client = client - self.__expired_time = datetime.fromtimestamp(sts["expiredTime"]) - self.prefix = sts["prefix"] - self.bucket = sts["bucket"] - self.__token = sts["credentials"]["sessionToken"] diff --git a/swanlab/data/callback_cloud.py b/swanlab/data/callback_cloud.py index 2235d9473..985e107e3 100644 --- a/swanlab/data/callback_cloud.py +++ b/swanlab/data/callback_cloud.py @@ -8,7 +8,6 @@ 云端回调 """ import json -import os import sys from swankit.callback.models import RuntimeInfo, MetricInfo, ColumnInfo @@ -22,7 +21,7 @@ from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel from swanlab.data.cloud import ThreadPool from swanlab.data.cloud import UploadType -from swanlab.env import in_jupyter, SwanLabEnv, is_interactive +from swanlab.env import in_jupyter, is_interactive from swanlab.error import KeyFileError from swanlab.log import swanlog from swanlab.package import ( @@ -33,6 +32,7 @@ get_key, ) from .callback_local import LocalRunCallback, get_run, SwanLabRunState +from ..api.http import reset_http def show_button_html(experiment_url): @@ -241,12 +241,6 @@ def _write_call_call(message): if in_jupyter(): show_button_html(experiment_url) - # task环境下,同步实验信息回调 - if os.environ.get(SwanLabEnv.RUNTIME.value) == "task": - cuid = os.environ["SWANLAB_TASK_ID"] - info = {"cuid": cuid, "pId": http.proj_id, "eId": http.exp_id, "pName": http.projname} - http.patch("/task/experiment", info) - def on_runtime_info_update(self, r: RuntimeInfo): # 执行local逻辑,保存文件到本地 super(CloudRunCallback, self).on_runtime_info_update(r) @@ -335,6 +329,7 @@ def _(): FONT.loading("Waiting for uploading complete", _) get_http().update_state(state == SwanLabRunState.SUCCESS) + reset_http() # 取消注册系统回调 self._unregister_sys_callback() self.exiting = False diff --git a/swanlab/env.py b/swanlab/env.py index a51c8f806..13072f845 100644 --- a/swanlab/env.py +++ b/swanlab/env.py @@ -63,7 +63,7 @@ class SwanLabEnv(enum.Enum): """ RUNTIME = "SWANLAB_RUNTIME" """ - swanlab的运行时环境,"user" "develop" "test" "test-no-cloud" "task" + swanlab的运行时环境,"user" "develop" "test" "test-no-cloud" """ WEBHOOK = "SWANLAB_WEBHOOK" """ diff --git a/test/unit/cli/test_cli_launch.py b/test/unit/cli/test_cli_launch.py deleted file mode 100644 index 65bc922c1..000000000 --- a/test/unit/cli/test_cli_launch.py +++ /dev/null @@ -1,389 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:56 -@File: test_cli_launch.py -@IDE: pycharm -@Description: - 测试启动 - 主要是配置文件的解析 -""" -from swanlab.cli.commands import launcher as L -import tutils as T -import pytest -import click -import os - - -class TestParse: - """ - 测试解析配置文件,对应parse函数 - """ - - def test_error_api_version(self): - """ - 测试错误的apiVersion - """ - config = { - 'apiVersion': '12345', - } - with pytest.raises(click.UsageError) as e: - L.parse(config, '') - assert str(e.value) == 'Unknown api version: 12345' - - def test_error_kind(self): - """ - 测试错误的kind - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'folder', # 大小写敏感 - } - with pytest.raises(click.UsageError) as e: - L.parse(config, '') - assert str(e.value) == 'Unknown kind: folder' - - def test_get_folder_v1(self): - """ - 测试成功 - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - } - parser = L.parse(config, '') - assert isinstance(parser, L.parser.v1.FolderParser) - - -class TestFolderParserMetadata: - """ - 测试解析器解析metadata部分 - """ - - def test_metadata_error_type(self): - """ - 测试错误的类型 - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - 'metadata': 12345, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, '') - parser.parse() - assert str(e.value) == 'metadata should be , not ' - - def test_metadata_error_key(self): - """ - 测试错误的key - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - 'metadata': { - 'name': 'test', - 'error': 'test', - }, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, '') - parser.parse() - assert str(e.value) == 'Unknown key: metadata.error' - - def test_no_name(self): - """ - 测试没有name - """ - config = {'apiVersion': 'swanlab/v1', 'kind': 'Folder', 'metadata': {}} - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, '') - parser.parse() - assert str(e.value) == 'metadata.name should not be None' - - def test_no_combo(self): - """ - 测试没有combo - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - 'metadata': { - 'name': 'test', - }, - } - parser = L.parse(config, '') - parser.parse_metadata(config['metadata']) - assert parser.metadata['combo'] == None # noqa - - def test_no_desc(self): - """ - 测试没有desc - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - 'metadata': { - 'name': 'test', - 'combo': 'test', - }, - } - parser = L.parse(config, '') - parser.parse_metadata(config['metadata']) - assert parser.metadata['desc'] is None # noqa - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - 'metadata': { - 'name': 'test', - 'combo': 'test', - 'desc': 'test', - }, - } - parser = L.parse(config, '') - parser.parse_metadata(config['metadata']) - assert parser.metadata['desc'] == 'test' # noqa - - -class TestFolderParserSpec: - """ - 测试解析器解析spec部分 - """ - - def test_spec_error_type(self): - """ - 测试错误的类型 - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': 12345, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, '') - parser.parse() - assert str(e.value) == 'spec should be , not ' - - def test_spec_error_key(self): - """ - 测试错误的key - """ - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'name': 'test', - }, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, '') - parser.parse() - assert str(e.value) == 'Unknown key: spec.name' - - @staticmethod - def mock_entry(name='train.py') -> str: - """ - 模拟一个文件 - """ - with open(os.path.join(T.TEMP_PATH, name), 'w') as f: - f.write('print("hello")') - return os.path.join(T.TEMP_PATH, 'swanlab.yaml') - - def test_python(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'python': '3.8', - }, - } - parser = L.parse(config, f) - parser.parse() - assert parser.spec['python'] == 'python3.8' # noqa - - def test_error_python(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'python': '3.7', - }, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'spec.python should be in [\'3.11\', \'3.10\', \'3.9\', \'3.8\'], not 3.7' - - def test_no_python(self): - """ - 测试没有python - """ - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - }, - } - parser = L.parse(config, f) - parser.parse() - assert parser.spec['python'] == 'python3.10' # noqa - - def test_no_entry(self): - """ - 测试没有entry - """ - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'python': '3.8', - }, - } - parser = L.parse(config, f) - parser.parse() - parser.spec['entry'] == 'train.py' # noqa - - def test_no_entry_file(self): - """ - 测试entry文件不存在 - """ - f = self.mock_entry('test.py') - err_entry = '1234' - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': err_entry, - }, - } - with pytest.raises(click.FileError) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == f'spec.entry not found: {os.path.join(T.TEMP_PATH, err_entry)}' - - def test_entry_not_file(self): - """ - 测试entry不是文件 - """ - f = self.mock_entry() - err_entry = 'test' - os.mkdir(os.path.join(T.TEMP_PATH, err_entry)) - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': err_entry, - }, - } - with pytest.raises(click.FileError) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == f'spec.entry should be a file: {os.path.join(T.TEMP_PATH, err_entry)}' - - def test_volumes(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'volumes': [ - {'name': 'test', 'id': 'test'}, - ], - }, - } - parser = L.parse(config, f) - parser.parse() - assert parser.spec['volumes'] == [{'name': 'test', 'id': 'test'}] # noqa - config['spec']['volumes'].append({'name': 'test', 'id': '2'}) - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'Only one volume is supported' - config['spec']['volumes'] = [{'id': 'test'}] # noqa - parser = L.parse(config, f) - parser.parse() - assert parser.spec['volumes'] == [{'id': 'test'}] # noqa - config['spec']['volumes'] = [{'name': 'test'}] # noqa - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'volume.id should not be None' - - def test_error_volumes_type(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'volumes': 'test', - }, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'spec.volumes should be , not ' - config['spec']['volumes'] = [{'name': 'test', 'id': 123}] # noqa - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'volume.id should be , not ' - config['spec']['volumes'] = [{'name': 'test', 'id': 'test', 'error': 1}] # noqa - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'Unknown key: volume.error' - - def test_exclude(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'exclude': ['test'], - }, - } - parser = L.parse(config, f) - parser.parse() - assert parser.spec['exclude'] == ['test'] # noqa - - def test_exclude_error_type(self): - f = self.mock_entry() - config = { - 'apiVersion': 'swanlab/v1', - 'kind': 'Folder', - "metadata": {"name": "test", "desc": "test", "combo": "test"}, - 'spec': { - 'entry': 'train.py', - 'exclude': 'test', - }, - } - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'spec.exclude should be , not ' - config['spec']['exclude'] = ['test', 123] # noqa - with pytest.raises(click.BadParameter) as e: - parser = L.parse(config, f) - parser.parse() - assert str(e.value) == 'exclude should be , not ' diff --git a/test/unit/cli/test_cli_task.py b/test/unit/cli/test_cli_task.py deleted file mode 100644 index a842918e8..000000000 --- a/test/unit/cli/test_cli_task.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/9/6 16:34 -@File: test_cli_task.py -@IDE: pycharm -@Description: - 测试 swanlab task 相关函数 -""" -import pytest - -import tutils -from swanlab.cli.commands.task.utils import OutputModel - - -@pytest.mark.skipif(tutils.is_skip_cloud_test, reason="skip cloud test") -def test_output_model_none(): - om = OutputModel("123456", {}) - assert om.cuid == "123456" - assert om.path is None - assert om.size is None - assert om.output_url is None - - -@pytest.mark.skipif(tutils.is_skip_cloud_test, reason="skip cloud test") -def test_output_model_ok(): - om = OutputModel("123456", {"path": "nothing.zip", "size": 123}) - assert om.cuid == "123456" - assert om.path == "nothing.zip" - assert om.size == OutputModel.fmt_size(123) - - -@pytest.mark.skipif(tutils.is_skip_cloud_test, reason="skip cloud test") -def test_output_model_fmt_size(): - assert OutputModel.fmt_size(1) == "1.00 Byte" - assert OutputModel.fmt_size(1024) == "1.00 KB" - assert OutputModel.fmt_size(1024 * 1024) == "1.00 MB" - assert OutputModel.fmt_size(1024 * 1024 * 1024) == "1.00 GB" - assert OutputModel.fmt_size(1024 * 1024 * 1024 * 1024) == "1.00 TB" diff --git a/test/unit/cli/test_cli_utils.py b/test/unit/cli/test_cli_utils.py deleted file mode 100644 index b4d5be80a..000000000 --- a/test/unit/cli/test_cli_utils.py +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -r""" -@DATE: 2024/8/22 14:30 -@File: test_cli_utils.py -@IDE: pycharm -@Description: - 测试cli相关工具函数 -""" -import pytest -from swanlab.cli.utils import UseTaskHttp -import tutils.setup as SU - - -def test_use_task_http_ok(): - with SU.UseMocker() as m: - m.post("/test", text="mock") - with SU.UseSetupHttp(): - with UseTaskHttp() as http: - text = http.post("/test") - assert text == "mock" - - -def test_use_task_http_abandon(): - with pytest.raises(SystemExit) as p: - with SU.UseMocker() as m: - m.post("/test", status_code=301, reason="Abandon") - with SU.UseSetupHttp(): - with UseTaskHttp() as http: - http.post("/test") - assert p.value.code == 3 diff --git a/test/unit/data/test_sdk.py b/test/unit/data/test_sdk.py index 9d555e329..39a75a732 100644 --- a/test/unit/data/test_sdk.py +++ b/test/unit/data/test_sdk.py @@ -149,6 +149,13 @@ class TestInitMode: 测试init时函数的mode参数设置行为 """ + def test_init_local(self): + run = S.init(mode="local") + assert os.environ[MODE] == "local" + run.log({"TestInitMode": 1}) # 不会报错 + assert get_run() is not None + assert run.public.cloud.project_name is None + def test_init_disabled(self): logdir = os.path.join(T.TEMP_PATH, generate()).__str__() run = S.init(mode="disabled", logdir=logdir) @@ -159,13 +166,6 @@ def test_init_disabled(self): assert not os.path.exists(a) assert get_run() is not None - def test_init_local(self): - run = S.init(mode="local") - assert os.environ[MODE] == "local" - run.log({"TestInitMode": 1}) # 不会报错 - assert get_run() is not None - assert run.public.cloud.project_name is None - @pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") def test_init_cloud(self): S.login(T.is_skip_cloud_test) From f5b46fe8a9008078108da52b885c8cdd98fc6fe0 Mon Sep 17 00:00:00 2001 From: KAAANG <79990647+SAKURA-CAT@users.noreply.github.com> Date: Fri, 17 Jan 2025 12:01:35 +0800 Subject: [PATCH 2/2] fix: test error in cloud mode --- swanlab/data/run/main.py | 3 ++- swanlab/data/run/public.py | 14 ++++++++------ test/unit/data/test_sdk.py | 16 ++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/swanlab/data/run/main.py b/swanlab/data/run/main.py index 84fddd579..8cbd0099e 100644 --- a/swanlab/data/run/main.py +++ b/swanlab/data/run/main.py @@ -86,6 +86,7 @@ def __init__( should_save=not self.__operator.disabled, version=get_package_version(), ) + self.__mode = get_mode() self.__public = SwanLabPublicConfig(self.__project_name, self.__settings) self.__operator.before_run(self.__settings) # ---------------------------------- 初始化日志记录器 ---------------------------------- @@ -216,7 +217,7 @@ def public(self): @property def mode(self) -> str: - return get_mode() + return self.__mode @property def state(self) -> SwanLabRunState: diff --git a/swanlab/data/run/public.py b/swanlab/data/run/public.py index 39997f5c2..522c577df 100644 --- a/swanlab/data/run/public.py +++ b/swanlab/data/run/public.py @@ -1,6 +1,7 @@ from swankit.core import SwanLabSharedSettings from swanlab.api import get_http +from swanlab.env import get_mode from swanlab.package import get_project_url, get_experiment_url @@ -11,6 +12,12 @@ class SwanlabCloudConfig: def __init__(self): self.__http = None + if get_mode() == "cloud": + try: + self.__http = get_http() + except ValueError: + pass + self.__available = self.__http is not None def __get_property_from_http(self, name: str): """ @@ -27,12 +34,7 @@ def available(self): """ Whether the SwanLab is running in cloud mode. """ - try: - if self.__http is None: - self.__http = get_http() - return True - except ValueError: - return False + return self.__available @property def project_name(self): diff --git a/test/unit/data/test_sdk.py b/test/unit/data/test_sdk.py index 39a75a732..04a0d4a45 100644 --- a/test/unit/data/test_sdk.py +++ b/test/unit/data/test_sdk.py @@ -26,6 +26,9 @@ def setup_function(): 在当前测试文件下的每个测试函数执行前后执行 """ swanlog.disable_log() + run = get_run() + if run is not None: + run.finish() yield run = get_run() if run is not None: @@ -153,6 +156,8 @@ def test_init_local(self): run = S.init(mode="local") assert os.environ[MODE] == "local" run.log({"TestInitMode": 1}) # 不会报错 + assert run.mode == "local" + assert run.public.cloud.available is False assert get_run() is not None assert run.public.cloud.project_name is None @@ -181,6 +186,17 @@ def test_init_error(self): S.init(mode="123456") # noqa assert get_run() is None + @pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test") + def test_init_multiple(self): + # 先初始化cloud + self.test_init_cloud() + get_run().finish() + # 再初始化local + self.test_init_local() + get_run().finish() + # 再初始化disabled + self.test_init_disabled() + # ---------------------------------- 测试环境变量输入 ---------------------------------- def test_init_disabled_env(self):