Skip to content

Commit a4785c4

Browse files
authored
feat: add choices for swanlab.init (#783)
1 parent baa92b3 commit a4785c4

File tree

8 files changed

+201
-56
lines changed

8 files changed

+201
-56
lines changed

swanlab/api/auth/login.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,19 @@
88
用户登录接口,输入用户的apikey,保存用户token到本地
99
进行一些交互定义和数据请求
1010
"""
11-
from swanlab.error import ValidationError, APIKeyFormatError
12-
from swankit.log import FONT
11+
import getpass
12+
import os
13+
import sys
14+
15+
import requests
1316
from swankit.env import is_windows
14-
from swanlab.package import get_user_setting_path, get_host_api
17+
from swankit.log import FONT
18+
1519
from swanlab.api.info import LoginInfo
16-
from swanlab.log import swanlog
1720
from swanlab.env import in_jupyter, SwanLabEnv
18-
import getpass
19-
import requests
20-
import sys
21-
import os
21+
from swanlab.error import ValidationError, APIKeyFormatError
22+
from swanlab.log import swanlog
23+
from swanlab.package import get_user_setting_path, get_host_api
2224

2325

2426
def login_request(api_key: str, timeout: int = 20) -> requests.Response:
@@ -64,7 +66,6 @@ def input_api_key(
6466
_t = sys.excepthook
6567
sys.excepthook = _abort_tip
6668
if not again:
67-
swanlog.info("Logging into swanlab cloud.")
6869
swanlog.info("You can find your API key at: " + FONT.yellow(get_user_setting_path()))
6970
# windows 额外打印提示信息
7071
if is_windows():

swanlab/data/callback_cloud.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
@Description:
88
云端回调
99
"""
10-
import io
1110
import json
1211
import os
1312
import sys
@@ -23,7 +22,7 @@
2322
from swanlab.api.upload.model import ColumnModel, ScalarModel, MediaModel, FileModel
2423
from swanlab.data.cloud import ThreadPool
2524
from swanlab.data.cloud import UploadType
26-
from swanlab.env import in_jupyter, SwanLabEnv
25+
from swanlab.env import in_jupyter, SwanLabEnv, is_interactive
2726
from swanlab.error import KeyFileError
2827
from swanlab.log import swanlog
2928
from swanlab.package import (
@@ -144,23 +143,19 @@ def __init__(self, public: bool):
144143
self.public = public
145144

146145
@classmethod
147-
def get_login_info(cls):
146+
def create_login_info(cls):
148147
"""
149148
发起登录,获取登录信息,执行此方法会覆盖原有的login_info
150149
"""
151150
key = None
152151
try:
153152
key = get_key()
154153
except KeyFileError:
155-
try:
156-
fd = sys.stdin.fileno()
157-
# 不是标准终端,且非jupyter环境,无法控制其回显
158-
if not os.isatty(fd) and not in_jupyter():
159-
raise KeyFileError("The key file is not found, call `swanlab.login()` or use `swanlab login` ")
160-
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
161-
# 这种情况下为用户自定义情况
162-
except io.UnsupportedOperation:
163-
pass
154+
pass
155+
if key is None and not is_interactive():
156+
raise KeyFileError(
157+
"api key not configured (no-tty), call `swanlab.login(api_key=[your_api_key])` or set `swanlab.init(mode=\"local\")`."
158+
)
164159
return terminal_login(key)
165160

166161
@staticmethod
@@ -208,12 +203,11 @@ def __str__(self):
208203

209204
def on_init(self, project: str, workspace: str, logdir: str = None, **kwargs) -> int:
210205
super(CloudRunCallback, self).on_init(project, workspace, logdir)
211-
# 检测是否有最新的版本
212-
self._get_package_latest_version()
213206
if self.login_info is None:
214207
swanlog.debug("Login info is None, get login info.")
215-
self.login_info = self.get_login_info()
216-
208+
self.login_info = self.create_login_info()
209+
# 检测是否有最新的版本
210+
self._get_package_latest_version()
217211
http = create_http(self.login_info)
218212
return http.mount_project(project, workspace, self.public).history_exp_count
219213

swanlab/data/sdk.py

Lines changed: 55 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212

1313
from swanboard import SwanBoardCallback
1414
from swankit.env import SwanLabMode
15+
from swankit.log import FONT
1516

16-
from swanlab.api import code_login
17-
from swanlab.env import SwanLabEnv
17+
from swanlab.api import code_login, terminal_login
18+
from swanlab.env import SwanLabEnv, is_interactive
1819
from swanlab.log import swanlog
1920
from .callback_cloud import CloudRunCallback
2021
from .callback_local import LocalRunCallback
@@ -27,6 +28,8 @@
2728
get_run,
2829
)
2930
from .run.helper import SwanLabRunOperator
31+
from ..error import KeyFileError
32+
from ..package import get_key, get_host_web
3033

3134

3235
def _check_proj_name(name: str) -> str:
@@ -68,7 +71,10 @@ def login(api_key: str = None):
6871
"""
6972
if SwanLabRun.is_started():
7073
raise RuntimeError("You must call swanlab.login() before using init()")
71-
CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.get_login_info()
74+
CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.create_login_info()
75+
76+
77+
MODES = Literal["disabled", "cloud", "local"]
7278

7379

7480
def init(
@@ -78,7 +84,7 @@ def init(
7884
description: str = None,
7985
config: Union[dict, str] = None,
8086
logdir: str = None,
81-
mode: Literal["disabled", "cloud", "local"] = None,
87+
mode: MODES = None,
8288
load: str = None,
8389
public: bool = None,
8490
**kwargs,
@@ -152,13 +158,9 @@ def init(
152158
project = _load_data(load_data, "project", project)
153159
workspace = _load_data(load_data, "workspace", workspace)
154160
public = _load_data(load_data, "private", public)
155-
156-
# ---------------------------------- 模式选择 ----------------------------------
157-
# for
158-
159-
# ---------------------------------- helper初始化 ----------------------------------
160-
operator, c = _create_operator(mode, public)
161161
project = _check_proj_name(project if project else os.path.basename(os.getcwd())) # 默认实验名称为当前目录名
162+
# ---------------------------------- 启动操作员 ----------------------------------
163+
operator, c = _create_operator(mode, public)
162164
exp_num = SwanLabRunOperator.parse_return(
163165
operator.on_init(project, workspace, logdir=logdir),
164166
key=c.__str__() if c else None,
@@ -238,6 +240,8 @@ def _init_mode(mode: str = None):
238240
传入的mode必须为SwanLabMode枚举中的一个值,否则报错ValueError
239241
如果环境变量和传入的mode参数都为None,则默认为cloud
240242
243+
从环境变量中提取mode参数以后,还有一步让用户选择运行模式的交互,详见issue: https://github.com/SwanHubX/SwanLab/issues/632
244+
241245
:param mode: str, optional
242246
传入的mode参数
243247
:return: str mode
@@ -252,8 +256,45 @@ def _init_mode(mode: str = None):
252256
if mode is not None and mode not in allowed:
253257
raise ValueError(f"`mode` must be one of {allowed}, but got {mode}")
254258
mode = "cloud" if mode is None else mode
259+
# 如果mode为cloud,且没找到 api key或者未登录,则提示用户输入
260+
try:
261+
get_key()
262+
no_api_key = False
263+
except KeyFileError:
264+
no_api_key = True
265+
login_info = None
266+
if mode == "cloud" and no_api_key:
267+
# 判断当前进程是否在交互模式下
268+
if is_interactive():
269+
swanlog.info(
270+
f"Using SwanLab to track your experiments. Please refer to {FONT.yellow('https://docs.swanlab.cn')} for more information."
271+
)
272+
swanlog.info("(1) Create a SwanLab account.")
273+
swanlog.info("(2) Use an existing SwanLab account.")
274+
swanlog.info("(3) Don't visualize my results.")
275+
276+
# 交互选择
277+
tip = FONT.swanlab("Enter your choice: ")
278+
code = input(tip)
279+
while code not in ["1", "2", "3"]:
280+
swanlog.warning("Invalid choice, please enter again.")
281+
code = input(tip)
282+
if code == "3":
283+
mode = "local"
284+
elif code == "2":
285+
swanlog.info("You chose 'Use an existing swanlab account'")
286+
swanlog.info("Logging into " + FONT.yellow(get_host_web()))
287+
login_info = terminal_login()
288+
elif code == "1":
289+
swanlog.info("You chose 'Create a swanlab account'")
290+
swanlog.info("Create a SwanLab account here: " + FONT.yellow(get_host_web() + "/login"))
291+
login_info = terminal_login()
292+
else:
293+
raise ValueError("Invalid choice")
294+
295+
# 如果不在就不管
255296
os.environ[mode_key] = mode
256-
return mode
297+
return mode, login_info
257298

258299

259300
def _init_config(config: Union[dict, str]):
@@ -284,7 +325,9 @@ def _create_operator(mode: str, public: bool) -> Tuple[SwanLabRunOperator, Optio
284325
:param public: 是否公开
285326
:return: SwanLabRunOperator, CloudRunCallback
286327
"""
287-
mode = _init_mode(mode)
328+
mode, login_info = _init_mode(mode)
329+
CloudRunCallback.login_info = login_info
330+
288331
if mode == SwanLabMode.DISABLED.value:
289332
swanlog.warning("SwanLab run disabled, the data will not be saved or uploaded.")
290333
return SwanLabRunOperator(), None

swanlab/env.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
除了utils和error模块,其他模块都可以使用这个模块
99
"""
1010
import enum
11+
import io
1112
import os
13+
import sys
1214
from typing import List
1315

1416
import swankit.env as E
@@ -127,3 +129,17 @@ def in_jupyter() -> bool:
127129
return True
128130
except NameError:
129131
return False
132+
133+
134+
def is_interactive():
135+
"""
136+
是否为可交互式环境(输入连接tty设备)
137+
特殊的环境:jupyter notebook
138+
"""
139+
try:
140+
fd = sys.stdin.fileno()
141+
return os.isatty(fd) or in_jupyter()
142+
# 当使用capsys、capfd或monkeypatch等fixture来捕获或修改标准输入输出时,会抛出io.UnsupportedOperation
143+
# 多为测试情况,可交互
144+
except io.UnsupportedOperation:
145+
return True

swanlab/package.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,22 @@
77
@Description:
88
用于管理swanlab的包管理器的模块,做一些封装
99
"""
10-
from .env import get_save_dir, SwanLabEnv
11-
from .error import KeyFileError
12-
from typing import Optional
13-
import requests
14-
import netrc
1510
import json
11+
import netrc
1612
import os
13+
from typing import Optional
14+
15+
import requests
16+
17+
from .env import get_save_dir, SwanLabEnv
18+
from .error import KeyFileError
1719

1820
package_path = os.path.join(os.path.dirname(__file__), "package.json")
1921

2022

2123
# ---------------------------------- 版本号相关 ----------------------------------
2224

25+
2326
def get_package_version() -> str:
2427
"""获取swanlab的版本号
2528
:return: swanlab的版本号
@@ -69,7 +72,7 @@ def get_user_setting_path() -> str:
6972
"""获取用户设置的url
7073
:return: 用户设置的url
7174
"""
72-
return get_host_web() + "/settings"
75+
return get_host_web() + "/space/~/settings"
7376

7477

7578
def get_project_url(username: str, projname: str) -> str:

test/unit/data/test_sdk.py

Lines changed: 79 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,13 @@ def test_init_error_mode(self):
4646
S._init_mode("123456") # noqa
4747

4848
@pytest.mark.parametrize("mode", ["disabled", "local", "cloud"])
49-
def test_init_mode(self, mode):
49+
def test_init_mode(self, mode, monkeypatch):
5050
"""
5151
初始化时mode参数正确
5252
"""
53+
if mode == 'cloud':
54+
mode = 'local'
55+
monkeypatch.setattr("builtins.input", lambda _: "3")
5356
S._init_mode(mode)
5457
assert os.environ[MODE] == mode
5558
del os.environ[MODE]
@@ -58,14 +61,88 @@ def test_init_mode(self, mode):
5861
# assert os.environ[MODE] == mode
5962

6063
@pytest.mark.parametrize("mode", ["disabled", "local", "cloud"])
61-
def test_overwrite_mode(self, mode):
64+
def test_overwrite_mode(self, mode, monkeypatch):
6265
"""
6366
初始化时mode参数正确,覆盖环境变量
6467
"""
68+
if mode == 'cloud':
69+
mode = 'local'
70+
monkeypatch.setattr("builtins.input", lambda _: "3")
6571
os.environ[MODE] = "123456"
6672
S._init_mode(mode)
6773
assert os.environ[MODE] == mode
6874

75+
def test_no_api_key_to_cloud(self, monkeypatch):
76+
"""
77+
初始化时mode为cloud,但是没有设置apikey
78+
"""
79+
if SwanLabEnv.API_KEY.value in os.environ:
80+
del os.environ[SwanLabEnv.API_KEY.value]
81+
monkeypatch.setattr("builtins.input", lambda _: "3")
82+
mode, login_info = S._init_mode("cloud")
83+
assert mode == "local"
84+
assert login_info is None
85+
86+
@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
87+
def test_init_cloud_with_no_api_key(self, monkeypatch):
88+
"""
89+
初始化时mode为cloud,但是没有设置apikey
90+
"""
91+
api_key = os.environ[SwanLabEnv.API_KEY.value]
92+
del os.environ[SwanLabEnv.API_KEY.value]
93+
# 在测试时默认会在交互模式下
94+
# 接下来需要模拟终端连接,使用monkeypatch
95+
# 三种选择方式:
96+
# 1. 输入api key
97+
# 2. 创建账号
98+
# 3. 使用本地版
99+
100+
# 选择第三种
101+
monkeypatch.setattr("builtins.input", lambda _: "3")
102+
mode, login_info = S._init_mode("cloud")
103+
assert mode == "local"
104+
assert login_info is None
105+
106+
# 选择第二种
107+
monkeypatch.setattr("builtins.input", lambda _: "2")
108+
monkeypatch.setattr("getpass.getpass", lambda _: api_key)
109+
mode, login_info = S._init_mode("cloud")
110+
assert mode == "cloud"
111+
assert login_info is not None
112+
113+
# 登录后会保存key,测试时需要删除
114+
os.remove(os.path.join(get_save_dir(), ".netrc"))
115+
116+
# 选择第一种
117+
monkeypatch.setattr("builtins.input", lambda _: "1")
118+
monkeypatch.setattr("getpass.getpass", lambda _: api_key)
119+
mode, login_info = S._init_mode("cloud")
120+
assert mode == "cloud"
121+
assert login_info is not None
122+
123+
# 登录后会保存key,测试时需要删除
124+
os.remove(os.path.join(get_save_dir(), ".netrc"))
125+
126+
# 选择其他的
127+
def create_other_input():
128+
first = True
129+
130+
def oi():
131+
nonlocal first
132+
if first:
133+
first = False
134+
return "123456"
135+
else:
136+
return "3"
137+
138+
return oi
139+
140+
other_input = create_other_input()
141+
monkeypatch.setattr("builtins.input", lambda _: other_input())
142+
mode, login_info = S._init_mode("cloud")
143+
assert mode == "local"
144+
assert login_info is None
145+
69146

70147
class TestInitMode:
71148
"""

0 commit comments

Comments
 (0)