Skip to content

Commit af71ada

Browse files
committed
feat: host params for login cli
1 parent de87581 commit af71ada

File tree

6 files changed

+121
-64
lines changed

6 files changed

+121
-64
lines changed

swanlab/cli/commands/auth/login.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,15 @@
77
@Description:
88
登录模块
99
"""
10+
import os
11+
import sys
12+
1013
import click
1114
from swankit.log import FONT
1215

16+
from swanlab import SwanLabEnv
1317
from swanlab.api.auth import terminal_login
14-
from swanlab.package import has_api_key
18+
from swanlab.package import has_api_key, HostFormatter
1519

1620

1721
@click.command()
@@ -30,13 +34,36 @@
3034
help="If you prefer not to engage in commands-line interaction to input the api key, "
3135
"this will allow automatic login.",
3236
)
33-
def login(api_key: str, relogin: bool):
37+
@click.option(
38+
"--host",
39+
"-h",
40+
default=None,
41+
type=str,
42+
help="The host of the swanlab server.",
43+
)
44+
@click.option(
45+
"--web-host",
46+
"-w",
47+
default=None,
48+
type=str,
49+
help="The web host of the swanlab cloud front-end.",
50+
)
51+
def login(api_key: str, relogin: bool, host: str, web_host: str):
3452
"""Login to the swanlab cloud."""
3553
if not relogin and has_api_key():
3654
# 此时代表token已经获取,需要打印一条信息:已经登录
3755
command = FONT.bold("swanlab login --relogin")
3856
tip = FONT.swanlab("You are already logged in. Use `" + command + "` to force relogin.")
3957
return print(tip)
58+
# 清除环境变量
59+
if relogin:
60+
del os.environ[SwanLabEnv.API_HOST.value]
61+
del os.environ[SwanLabEnv.WEB_HOST.value]
62+
try:
63+
HostFormatter(host, web_host)()
64+
except ValueError as e:
65+
click.BadParameter(str(e))
66+
return sys.exit(1)
4067
# 进行登录,此时将直接覆盖本地token文件
4168
login_info = terminal_login(api_key)
4269
print(FONT.swanlab("Login successfully. Hi, " + FONT.bold(FONT.default(login_info.username))) + "!")

swanlab/data/sdk.py

Lines changed: 4 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
在此处封装swanlab在日志记录模式下的各种接口
99
"""
1010
import os
11-
import re
1211
from typing import Optional, Union, Dict, Tuple, Literal
1312

1413
from swanboard import SwanBoardCallback
@@ -30,7 +29,7 @@
3029
)
3130
from .run.helper import SwanLabRunOperator
3231
from ..error import KeyFileError
33-
from ..package import get_key, get_host_web
32+
from ..package import get_key, get_host_web, HostFormatter
3433

3534

3635
def _check_proj_name(name: str) -> str:
@@ -58,31 +57,6 @@ def _check_proj_name(name: str) -> str:
5857
return _name
5958

6059

61-
class HostFormatter:
62-
def __init__(self):
63-
# 定义正则模式,匹配协议、主机、端口三部分
64-
self.pattern = re.compile(
65-
r'^(?:(https?)://)?' # 可选协议 http 或 https
66-
r'([a-zA-Z0-9.-]+\.[a-zA-Z]{2,63})' # 必填 主机名必须包含顶级域名(TLD)
67-
r'(?::(\d{1,5}))?$' # 可选端口号(1~5位数字)
68-
)
69-
70-
def __call__(self, input_str: str) -> str:
71-
match = self.pattern.match(input_str.rstrip("/"))
72-
if match:
73-
protocol = match.group(1) or "https" # 默认协议为 https
74-
host = match.group(2)
75-
port = match.group(3)
76-
77-
# 构建标准化的 URL 输出
78-
result = f"{protocol}://{host}"
79-
if port:
80-
result += f":{port}"
81-
return result
82-
else:
83-
raise ValueError("Invalid host format")
84-
85-
8660
def login(api_key: str = None, host: str = None, web_host: str = None):
8761
"""
8862
Login to SwanLab Cloud. If you already have logged in, you can use this function to relogin.
@@ -100,19 +74,9 @@ def login(api_key: str = None, host: str = None, web_host: str = None):
10074
"""
10175
if SwanLabRun.is_started():
10276
raise RuntimeError("You must call swanlab.login() before using init()")
103-
# ---------------------------------- 检查host是否合法,并格式化 ----------------------------------
104-
formater = HostFormatter()
105-
if host:
106-
try:
107-
os.environ[SwanLabEnv.API_HOST.value] = formater(host) + "/api"
108-
except ValueError:
109-
raise ValueError("Invalid host: {}".format(host))
110-
if web_host:
111-
try:
112-
os.environ[SwanLabEnv.WEB_HOST.value] = formater(web_host)
113-
except ValueError:
114-
raise ValueError("Invalid web_host: {}".format(web_host))
115-
# ---------------------------------- 登录,初始化http对象 ----------------------------------
77+
# 检查host是否合法,并格式化,注入到环境变量中
78+
HostFormatter(host, web_host)()
79+
# 登录,初始化http对象
11680
CloudRunCallback.login_info = code_login(api_key) if api_key else CloudRunCallback.create_login_info()
11781

11882

swanlab/package.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import json
1111
import netrc
1212
import os
13+
import re
1314
from typing import Optional
1415

1516
import requests
@@ -54,6 +55,50 @@ def get_package_latest_version(timeout=0.5) -> Optional[str]:
5455
# ---------------------------------- 云端相关 ----------------------------------
5556

5657

58+
class HostFormatter:
59+
def __init__(self, host: str = None, web_host: str = None):
60+
# 定义正则模式,匹配协议、主机、端口三部分
61+
self.pattern = re.compile(
62+
r'^(?:(https?)://)?' # 可选协议 http 或 https
63+
r'([a-zA-Z0-9.-]+\.[a-zA-Z]{2,63})' # 必填 主机名必须包含顶级域名(TLD)
64+
r'(?::(\d{1,5}))?$' # 可选端口号(1~5位数字)
65+
)
66+
self.host = host
67+
self.web_host = web_host
68+
69+
def fmt(self, input_str: str) -> str:
70+
match = self.pattern.match(input_str.rstrip("/"))
71+
if match:
72+
protocol = match.group(1) or "https" # 默认协议为 https
73+
host = match.group(2)
74+
port = match.group(3)
75+
76+
# 构建标准化的 URL 输出
77+
result = f"{protocol}://{host}"
78+
if port:
79+
result += f":{port}"
80+
return result
81+
else:
82+
raise ValueError("Invalid host format")
83+
84+
def __call__(self):
85+
"""
86+
如果host或web_host不为空,格式化并设置环境变量
87+
:raises ValueError: host或web_host格式不正确
88+
"""
89+
if self.host:
90+
try:
91+
os.environ[SwanLabEnv.API_HOST.value] = self.fmt(self.host) + "/api"
92+
except ValueError:
93+
raise ValueError("Invalid host: {}".format(self.host))
94+
self.web_host = self.host
95+
if self.web_host:
96+
try:
97+
os.environ[SwanLabEnv.WEB_HOST.value] = self.fmt(self.web_host)
98+
except ValueError:
99+
raise ValueError("Invalid web_host: {}".format(self.web_host))
100+
101+
57102
def get_host_web() -> str:
58103
"""获取swanlab网站网址
59104
:return: swanlab网站的网址

test/unit/cli/test_cli_login.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,17 @@
77
@Description:
88
测试命令登录
99
"""
10+
import os
11+
1012
import nanoid
11-
from swanlab.package import get_key
13+
import pytest
1214
from click.testing import CliRunner
15+
16+
import tutils as T
1317
from swanlab.cli.main import cli
18+
from swanlab.env import SwanLabEnv
1419
from swanlab.error import ValidationError, APIKeyFormatError
15-
import tutils as T
16-
import pytest
20+
from swanlab.package import get_key
1721

1822

1923
# noinspection PyTypeChecker
@@ -35,3 +39,20 @@ def test_login_fail():
3539
result = runner.invoke(cli, ["login", "--api-key", "wrong-key"])
3640
assert result.exit_code == 1
3741
assert isinstance(result.exception, APIKeyFormatError)
42+
43+
44+
# noinspection PyTypeChecker
45+
@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
46+
def test_login_host():
47+
"""
48+
测试登录时指定host
49+
"""
50+
runner = CliRunner()
51+
del os.environ[SwanLabEnv.API_HOST.value] # 删除环境变量
52+
del os.environ[SwanLabEnv.WEB_HOST.value] # 删除环境变量
53+
result = runner.invoke(cli, ["login", "--api-key", T.API_KEY, "--host", T.API_HOST.rstrip("/api")])
54+
assert result.exit_code == 0
55+
del os.environ[SwanLabEnv.API_HOST.value] # 删除环境变量
56+
del os.environ[SwanLabEnv.WEB_HOST.value] # 删除环境变量
57+
result = runner.invoke(cli, ["login", "--api-key", T.API_KEY, "--host", "http://wrong-host"])
58+
assert result.exit_code == 1

test/unit/data/test_sdk.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -306,25 +306,6 @@ def test_init_logdir_env(self):
306306
assert run.public.swanlog_dir == logdir
307307

308308

309-
class TestHostFormatter:
310-
def test_ok(self):
311-
formatter = S.HostFormatter()
312-
assert formatter("swanlab.cn") == "https://swanlab.cn"
313-
assert formatter("https://swanlab.cn") == "https://swanlab.cn"
314-
assert formatter("http://swanlab.cn") == "http://swanlab.cn" # noqa
315-
assert formatter("https://swanlab.cn:8443/") == "https://swanlab.cn:8443"
316-
assert formatter("abc.example.com") == "https://abc.example.com"
317-
318-
def test_value_err(self):
319-
formatter = S.HostFormatter()
320-
with pytest.raises(ValueError):
321-
formatter("test")
322-
with pytest.raises(ValueError):
323-
formatter("https://test")
324-
with pytest.raises(ValueError):
325-
formatter("http://test") # noqa
326-
327-
328309
@pytest.mark.skipif(T.is_skip_cloud_test, reason="skip cloud test")
329310
class TestLogin:
330311
"""

test/unit/test_package.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,22 @@ def test_wrong_host(self):
239239
self.save_api_key()
240240
os.environ[SwanLabEnv.API_HOST.value] = nanoid.generate()
241241
assert not P.has_api_key()
242+
243+
244+
class TestHostFormatter:
245+
def test_ok(self):
246+
formatter = P.HostFormatter()
247+
assert formatter.fmt("swanlab.cn") == "https://swanlab.cn"
248+
assert formatter.fmt("https://swanlab.cn") == "https://swanlab.cn"
249+
assert formatter.fmt("http://swanlab.cn") == "http://swanlab.cn" # noqa
250+
assert formatter.fmt("https://swanlab.cn:8443/") == "https://swanlab.cn:8443"
251+
assert formatter.fmt("abc.example.com") == "https://abc.example.com"
252+
253+
def test_value_err(self):
254+
formatter = P.HostFormatter()
255+
with pytest.raises(ValueError):
256+
formatter.fmt("test")
257+
with pytest.raises(ValueError):
258+
formatter.fmt("https://test")
259+
with pytest.raises(ValueError):
260+
formatter.fmt("http://test") # noqa

0 commit comments

Comments
 (0)