generated from fofr/cog-comfyui
-
Notifications
You must be signed in to change notification settings - Fork 65
/
weights_downloader.py
63 lines (54 loc) · 2.42 KB
/
weights_downloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import subprocess
import time
import os
from weights_manifest import WeightsManifest
# Root URL on weights.replicate.delivery under which the ComfyUI weight
# files are served; individual URLs are built by appending a relative path.
BASE_URL: str = "https://weights.replicate.delivery/default/comfy-ui"
class WeightsDownloader:
    """Download model weights listed in a ``WeightsManifest`` using the ``pget`` CLI."""

    def __init__(self):
        # weights_map maps a weight name to a dict that carries at least the
        # "url" and "dest" keys (both are read in download_weights).
        self.weights_manifest = WeightsManifest()
        self.weights_map = self.weights_manifest.weights_map

    def download_weights(self, weight_str):
        """Download a named weight from the manifest unless it is already present.

        Args:
            weight_str: Manifest key of the weight. If it contains "/", the part
                before the last "/" becomes a subfolder of the destination.

        Raises:
            ValueError: If ``weight_str`` is not listed in the manifest.
        """
        if weight_str not in self.weights_map:
            raise ValueError(
                f"{weight_str} unavailable. View the list of available weights: https://github.com/fofr/cog-comfyui/blob/main/supported_weights.md"
            )

        if self.weights_manifest.is_non_commercial_only(weight_str):
            print(
                f"⚠️ {weight_str} is for non-commercial use only. Unless you have obtained a commercial license.\nDetails: https://github.com/fofr/cog-comfyui/blob/main/weights_licenses.md"
            )

        entry = self.weights_map[weight_str]
        self.download_if_not_exists(weight_str, entry["url"], entry["dest"])

    def download_torch_checkpoints(self):
        """Fetch the MobileNetV2 checkpoint into the torch hub cache if it is missing."""
        # The URL points at a .tar archive; pget's -x flag extracts it into
        # dest, presumably producing mobilenet_v2-b0353104.pth (the name that
        # the existence check looks for).
        self.download_if_not_exists(
            "mobilenet_v2-b0353104.pth",
            f"{BASE_URL}/custom_nodes/comfyui_controlnet_aux/mobilenet_v2-b0353104.pth.tar",
            "/root/.cache/torch/hub/checkpoints/",
        )

    def download_if_not_exists(self, weight_str, url, dest):
        """Call ``download`` unless ``dest/weight_str`` already exists on disk."""
        # os.path.join avoids a doubled separator when dest ends with "/".
        if not os.path.exists(os.path.join(dest, weight_str)):
            self.download(weight_str, url, dest)

    def download(self, weight_str, url, dest):
        """Download ``url`` into ``dest`` with pget and log timing and size.

        If ``weight_str`` contains "/", everything before the last "/" is
        appended to ``dest`` as a subfolder, which is created first.

        Raises:
            subprocess.CalledProcessError: If pget exits with a non-zero status.
        """
        if "/" in weight_str:
            subfolder = weight_str.rsplit("/", 1)[0]
            dest = os.path.join(dest, subfolder)
            os.makedirs(dest, exist_ok=True)

        print(f"⏳ Downloading {weight_str} to {dest}")
        # monotonic() is unaffected by wall-clock adjustments during long downloads.
        start = time.monotonic()
        # -x extracts the downloaded tar archive into dest; -f forces overwrite.
        # close_fds=False is kept from the original; its reason is not visible here.
        subprocess.check_call(
            ["pget", "--log-level", "warn", "-xf", url, dest], close_fds=False
        )
        elapsed_time = time.monotonic() - start

        try:
            file_size_bytes = self._path_size_bytes(
                os.path.join(dest, os.path.basename(weight_str))
            )
        except FileNotFoundError:
            # The archive may extract to a name other than weight_str's basename.
            print(f"⌛️ Downloaded {weight_str} in {elapsed_time:.2f}s")
            return

        file_size_megabytes = file_size_bytes / (1024 * 1024)
        print(
            f"⌛️ Downloaded {weight_str} in {elapsed_time:.2f}s, size: {file_size_megabytes:.2f}MB"
        )

    @staticmethod
    def _path_size_bytes(path):
        """Return the size of ``path`` in bytes, summing files recursively for directories.

        Raises:
            FileNotFoundError: If ``path`` (or a file found beneath it) does not exist.
        """
        if not os.path.isdir(path):
            return os.path.getsize(path)
        total = 0
        for root, _dirs, files in os.walk(path):
            for name in files:
                total += os.path.getsize(os.path.join(root, name))
        return total