1212
1313from swanboard import SwanBoardCallback
1414from swankit .env import SwanLabMode
15+ from swankit .log import FONT
1516
16- from swanlab .api import code_login
17- from swanlab .env import SwanLabEnv
17+ from swanlab .api import code_login , terminal_login
18+ from swanlab .env import SwanLabEnv , is_interactive
1819from swanlab .log import swanlog
1920from .callback_cloud import CloudRunCallback
2021from .callback_local import LocalRunCallback
2728 get_run ,
2829)
2930from .run .helper import SwanLabRunOperator
31+ from ..error import KeyFileError
32+ from ..package import get_key , get_host_web
3033
3134
3235def _check_proj_name (name : str ) -> str :
@@ -68,7 +71,10 @@ def login(api_key: str = None):
6871 """
6972 if SwanLabRun .is_started ():
7073 raise RuntimeError ("You must call swanlab.login() before using init()" )
71- CloudRunCallback .login_info = code_login (api_key ) if api_key else CloudRunCallback .get_login_info ()
74+ CloudRunCallback .login_info = code_login (api_key ) if api_key else CloudRunCallback .create_login_info ()
75+
76+
77+ MODES = Literal ["disabled" , "cloud" , "local" ]
7278
7379
7480def init (
@@ -78,7 +84,7 @@ def init(
7884 description : str = None ,
7985 config : Union [dict , str ] = None ,
8086 logdir : str = None ,
81- mode : Literal [ "disabled" , "cloud" , "local" ] = None ,
87+ mode : MODES = None ,
8288 load : str = None ,
8389 public : bool = None ,
8490 ** kwargs ,
@@ -152,13 +158,9 @@ def init(
152158 project = _load_data (load_data , "project" , project )
153159 workspace = _load_data (load_data , "workspace" , workspace )
154160 public = _load_data (load_data , "private" , public )
155-
156- # ---------------------------------- 模式选择 ----------------------------------
157- # for
158-
159- # ---------------------------------- helper初始化 ----------------------------------
160- operator , c = _create_operator (mode , public )
161161 project = _check_proj_name (project if project else os .path .basename (os .getcwd ())) # 默认实验名称为当前目录名
162+ # ---------------------------------- 启动操作员 ----------------------------------
163+ operator , c = _create_operator (mode , public )
162164 exp_num = SwanLabRunOperator .parse_return (
163165 operator .on_init (project , workspace , logdir = logdir ),
164166 key = c .__str__ () if c else None ,
@@ -238,6 +240,8 @@ def _init_mode(mode: str = None):
238240 传入的mode必须为SwanLabMode枚举中的一个值,否则报错ValueError
239241 如果环境变量和传入的mode参数都为None,则默认为cloud
240242
243+ 从环境变量中提取mode参数以后,还有一步让用户选择运行模式的交互,详见issue: https://github.com/SwanHubX/SwanLab/issues/632
244+
241245 :param mode: str, optional
242246 传入的mode参数
243247 :return: str mode
@@ -252,8 +256,45 @@ def _init_mode(mode: str = None):
252256 if mode is not None and mode not in allowed :
253257 raise ValueError (f"`mode` must be one of { allowed } , but got { mode } " )
254258 mode = "cloud" if mode is None else mode
259+ # 如果mode为cloud,且没找到 api key或者未登录,则提示用户输入
260+ try :
261+ get_key ()
262+ no_api_key = False
263+ except KeyFileError :
264+ no_api_key = True
265+ login_info = None
266+ if mode == "cloud" and no_api_key :
267+ # 判断当前进程是否在交互模式下
268+ if is_interactive ():
269+ swanlog .info (
270+ f"Using SwanLab to track your experiments. Please refer to { FONT .yellow ('https://docs.swanlab.cn' )} for more information."
271+ )
272+ swanlog .info ("(1) Create a SwanLab account." )
273+ swanlog .info ("(2) Use an existing SwanLab account." )
274+ swanlog .info ("(3) Don't visualize my results." )
275+
276+ # 交互选择
277+ tip = FONT .swanlab ("Enter your choice: " )
278+ code = input (tip )
279+ while code not in ["1" , "2" , "3" ]:
280+ swanlog .warning ("Invalid choice, please enter again." )
281+ code = input (tip )
282+ if code == "3" :
283+ mode = "local"
284+ elif code == "2" :
285+ swanlog .info ("You chose 'Use an existing swanlab account'" )
286+ swanlog .info ("Logging into " + FONT .yellow (get_host_web ()))
287+ login_info = terminal_login ()
288+ elif code == "1" :
289+ swanlog .info ("You chose 'Create a swanlab account'" )
290+ swanlog .info ("Create a SwanLab account here: " + FONT .yellow (get_host_web () + "/login" ))
291+ login_info = terminal_login ()
292+ else :
293+ raise ValueError ("Invalid choice" )
294+
295+ # 如果不在就不管
255296 os .environ [mode_key ] = mode
256- return mode
297+ return mode , login_info
257298
258299
259300def _init_config (config : Union [dict , str ]):
@@ -284,7 +325,9 @@ def _create_operator(mode: str, public: bool) -> Tuple[SwanLabRunOperator, Optio
284325 :param public: 是否公开
285326 :return: SwanLabRunOperator, CloudRunCallback
286327 """
287- mode = _init_mode (mode )
328+ mode , login_info = _init_mode (mode )
329+ CloudRunCallback .login_info = login_info
330+
288331 if mode == SwanLabMode .DISABLED .value :
289332 swanlog .warning ("SwanLab run disabled, the data will not be saved or uploaded." )
290333 return SwanLabRunOperator (), None
0 commit comments