Skip to content

Commit 6833ba5

Browse files
committed
fix import
1 parent 2539e4a commit 6833ba5

File tree

1 file changed

+27
-6
lines changed

1 file changed

+27
-6
lines changed

swanlab/sync/wandb.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,34 @@
11
"""
2-
swanlab.init(sync_wandb=True)
2+
import wandb
3+
import random
4+
import swanlab
5+
6+
swanlab.sync_wandb()
7+
# swanlab.init(project="sync_wandb")
8+
9+
wandb.init(
10+
project="test",
11+
config={"a": 1, "b": 2},
12+
name="test",
13+
)
14+
15+
epochs = 10
16+
offset = random.random() / 5
17+
for epoch in range(2, epochs):
18+
acc = 1 - 2 ** -epoch - random.random() / epoch - offset
19+
loss = 2 ** -epoch + random.random() / epoch + offset
20+
21+
wandb.log({"acc": acc, "loss": loss})
322
"""
423
import swanlab
5-
try:
6-
import wandb
7-
from wandb import sdk as wandb_sdk
8-
except ImportError:
9-
raise ImportError("please install wandb first, command: `pip install wandb`")
1024

1125
def sync_wandb():
26+
try:
27+
import wandb
28+
from wandb import sdk as wandb_sdk
29+
except ImportError:
30+
raise ImportError("please install wandb first, command: `pip install wandb`")
31+
1232
original_init = wandb.init
1333
original_log = wandb_sdk.wandb_run.Run.log
1434
original_finish = wandb_sdk.finish
@@ -53,6 +73,7 @@ def patched_finish(*args, **kwargs):
5373

5474
if __name__ == "__main__":
5575
import random
76+
import wandb
5677

5778
# 在使用前调用sync_wandb
5879
sync_wandb()

0 commit comments

Comments
 (0)