Skip to content

Commit a43c13f

Browse files
authored
Refactor of Models API, Support for calling downloads (#985)
1 parent b27b806 commit a43c13f

File tree

2 files changed

+189
-60
lines changed

2 files changed

+189
-60
lines changed

api/src/serge/routers/model.py

+130-57
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import os
33
import shutil
44

5+
import aiohttp
56

67
from fastapi import APIRouter, HTTPException
7-
import huggingface_hub
8+
from huggingface_hub import hf_hub_url
89
from serge.models.models import Families
910

1011
from pathlib import Path
@@ -14,6 +15,8 @@
1415
tags=["model"],
1516
)
1617

18+
active_downloads = {}
19+
1720
WEIGHTS = "/usr/src/app/weights/"
1821

1922
models_file_path = Path(__file__).parent.parent / "data" / "models.json"
@@ -34,13 +37,59 @@
3437
# Helper functions
3538
async def is_model_installed(model_name: str) -> bool:
3639
installed_models = await list_of_installed_models()
37-
return f"{model_name}.bin" in installed_models
40+
return any(file_name == f"{model_name}.bin" and not file_name.startswith(".") for file_name in installed_models)
3841

3942

4043
async def get_file_size(file_path: str) -> int:
4144
return os.stat(file_path).st_size
4245

4346

47+
async def cleanup_model_resources(model_name: str):
48+
model_repo, _, _ = models_info.get(model_name, (None, None, None))
49+
if not model_repo:
50+
print(f"No model repo found for {model_name}, cleanup may be incomplete.")
51+
return
52+
53+
temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
54+
lock_dir = os.path.join(WEIGHTS, ".locks", f"models--{model_repo.replace('/', '--')}")
55+
cache_dir = os.path.join(WEIGHTS, f"models--{model_repo.replace('/', '--')}")
56+
57+
# Try to cleanup temporary file if it exists
58+
if os.path.exists(temp_model_path):
59+
try:
60+
os.remove(temp_model_path)
61+
except OSError as e:
62+
print(f"Error removing temporary file for {model_name}: {e}")
63+
64+
# Remove lock file if it exists
65+
if os.path.exists(lock_dir):
66+
try:
67+
shutil.rmtree(lock_dir)
68+
except OSError as e:
69+
print(f"Error removing lock directory for {model_name}: {e}")
70+
71+
# Remove cache directory if it exists
72+
if os.path.exists(cache_dir):
73+
try:
74+
shutil.rmtree(cache_dir)
75+
except OSError as e:
76+
print(f"Error removing cache directory for {model_name}: {e}")
77+
78+
79+
async def download_file(session: aiohttp.ClientSession, url: str, path: str) -> None:
80+
async with session.get(url) as response:
81+
if response.status != 200:
82+
raise HTTPException(status_code=500, detail="Error downloading model")
83+
84+
# Write response content to file asynchronously
85+
with open(path, "wb") as f:
86+
while True:
87+
chunk = await response.content.read(1024)
88+
if not chunk:
89+
break
90+
f.write(chunk)
91+
92+
4493
# Handlers
4594
@model_router.get("/all")
4695
async def list_of_all_models():
@@ -73,26 +122,15 @@ async def list_of_all_models():
73122
return resp
74123

75124

76-
@model_router.get("/downloadable")
77-
async def list_of_downloadable_models():
78-
files = os.listdir(WEIGHTS)
79-
files = list(filter(lambda x: x.endswith(".bin"), files))
80-
81-
installed_models = [i.rstrip(".bin") for i in files]
82-
83-
return list(filter(lambda x: x not in installed_models, models_info.keys()))
84-
85-
86125
@model_router.get("/installed")
87126
async def list_of_installed_models():
88-
# after iterating through the WEIGHTS directory, return location and filename
127+
# Iterate through the WEIGHTS directory and return filenames that end with .bin and do not start with a dot
89128
files = [
90-
f"{model_location.replace(WEIGHTS, '')}/{bin_file}"
91-
for model_location, directory, filenames in os.walk(WEIGHTS)
129+
os.path.join(model_location.replace(WEIGHTS, "").lstrip("/"), bin_file)
130+
for model_location, _, filenames in os.walk(WEIGHTS)
92131
for bin_file in filenames
93-
if os.path.splitext(bin_file)[1] == ".bin"
132+
if bin_file.endswith(".bin") and not bin_file.startswith(".")
94133
]
95-
files = [i.lstrip("/") for i in files]
96134
return files
97135

98136

@@ -102,18 +140,63 @@ async def download_model(model_name: str):
102140
raise HTTPException(status_code=404, detail="Model not found")
103141

104142
try:
105-
# Download file, and resume broken downloads
106143
model_repo, filename, _ = models_info[model_name]
107-
model_path = f"{WEIGHTS}{model_name}.bin"
108-
await asyncio.to_thread(
109-
huggingface_hub.hf_hub_download, repo_id=model_repo, filename=filename, local_dir=WEIGHTS, cache_dir=WEIGHTS, resume_download=True
110-
)
111-
# Rename file
112-
os.rename(os.path.join(WEIGHTS, filename), os.path.join(WEIGHTS, model_path))
144+
model_url = hf_hub_url(repo_id=model_repo, filename=filename)
145+
temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
146+
model_path = os.path.join(WEIGHTS, f"{model_name}.bin")
147+
148+
# Create an aiohttp session with timeout settings
149+
timeout = aiohttp.ClientTimeout(total=300)
150+
async with aiohttp.ClientSession(timeout=timeout) as session:
151+
# Start the download and add to active_downloads
152+
download_task = asyncio.create_task(download_file(session, model_url, temp_model_path))
153+
active_downloads[model_name] = download_task
154+
await download_task
155+
156+
# Rename the dotfile to its final name
157+
os.rename(temp_model_path, model_path)
158+
159+
# Remove the entry from active_downloads after successful download
160+
active_downloads.pop(model_name, None)
161+
113162
return {"message": f"Model {model_name} downloaded"}
163+
except asyncio.CancelledError:
164+
await cleanup_model_resources(model_name)
165+
raise HTTPException(status_code=200, detail="Download cancelled")
166+
except Exception as exc:
167+
await cleanup_model_resources(model_name)
168+
raise HTTPException(status_code=500, detail=f"Error downloading model: {exc}")
169+
170+
171+
@model_router.post("/{model_name}/download/cancel")
172+
async def cancel_download(model_name: str):
173+
try:
174+
task = active_downloads.get(model_name)
175+
if not task:
176+
raise HTTPException(status_code=404, detail="No active download for this model")
177+
178+
# Remove the entry from active downloads after cancellation
179+
task.cancel()
180+
181+
# Remove entry from active downloads
182+
active_downloads.pop(model_name, None)
183+
184+
# Wait for the task to be cancelled
185+
try:
186+
# Wait for the task to respond to cancellation
187+
print(f"Waiting for download for {model_name} to be cancelled")
188+
await task
189+
except asyncio.CancelledError:
190+
# Handle the expected cancellation exception
191+
pass
192+
193+
# Cleanup resources
194+
await cleanup_model_resources(model_name)
195+
196+
print(f"Download for {model_name} cancelled")
197+
return {"message": f"Download for {model_name} cancelled"}
114198
except Exception as e:
115-
# Handle exceptions, possibly log them
116-
raise HTTPException(status_code=500, detail=f"Error downloading model: {str(e)}")
199+
raise HTTPException(status_code=500, detail=f"Error cancelling model download: {str(e)}")
117200

118201

119202
@model_router.get("/{model_name}/download/status")
@@ -125,48 +208,38 @@ async def download_status(model_name: str):
125208
model_repo, _, _ = models_info[model_name]
126209

127210
# Construct the path to the blobs directory
128-
blobs_dir = os.path.join(WEIGHTS, f"models--{model_repo.replace('/', '--')}", "blobs")
129-
130-
# Check for the .incomplete file in the blobs directory
131-
if os.path.exists(os.path.join(WEIGHTS, f"{model_name}.bin")):
132-
currentsize = os.path.getsize(os.path.join(WEIGHTS, f"{model_name}.bin"))
133-
return min(round(currentsize / filesize * 100, 1), 100)
134-
elif os.path.exists(blobs_dir):
135-
for file in os.listdir(blobs_dir):
136-
if file.endswith(".incomplete"):
137-
incomplete_file_path = os.path.join(blobs_dir, file)
138-
# Check if the .incomplete file exists and calculate the download status
139-
if os.path.exists(incomplete_file_path):
140-
currentsize = os.path.getsize(incomplete_file_path)
141-
return min(round(currentsize / filesize * 100, 1), 100)
142-
return 0
211+
temp_model_path = os.path.join(WEIGHTS, f".{model_name}.bin")
212+
model_path = os.path.join(WEIGHTS, f"{model_name}.bin")
213+
214+
# Check if the model is currently being downloaded
215+
task = active_downloads.get(model_name)
216+
217+
if os.path.exists(model_path):
218+
currentsize = os.path.getsize(model_path)
219+
progress = min(round(currentsize / filesize * 100, 1), 100)
220+
return progress
221+
elif task and not task.done():
222+
# If the task is still running, check for incomplete files
223+
if os.path.exists(temp_model_path):
224+
currentsize = os.path.getsize(temp_model_path)
225+
return min(round(currentsize / filesize * 100, 1), 100)
226+
# If temp_model_path doesn't exist, the download is likely just starting, progress is 0
227+
return 0
228+
else:
229+
# No active download and the file does not exist
230+
return None
143231

144232

145233
@model_router.delete("/{model_name}")
146234
async def delete_model(model_name: str):
147235
if f"{model_name}.bin" not in await list_of_installed_models():
148236
raise HTTPException(status_code=404, detail="Model not found")
149237

150-
model_repo, _, _ = models_info.get(model_name, (None, None, None))
151-
if not model_repo:
152-
raise HTTPException(status_code=404, detail="Model info not found")
153-
154-
# Remove link to model file
155238
try:
156239
os.remove(os.path.join(WEIGHTS, f"{model_name}.bin"))
157240
except OSError as e:
158241
print(f"Error removing model file: {e}")
159242

160-
# Remove lock file
161-
try:
162-
shutil.rmtree(os.path.join(WEIGHTS, ".locks", f"models--{model_repo.replace('/', '--')}"))
163-
except OSError as e:
164-
print(f"Error removing lock directory: {e}")
165-
166-
# Remove cache directory
167-
try:
168-
shutil.rmtree(os.path.join(WEIGHTS, f"models--{model_repo.replace('/', '--')}"))
169-
except OSError as e:
170-
print(f"Error removing cache directory: {e}")
243+
await cleanup_model_resources(model_name)
171244

172245
return {"message": f"Model {model_name} deleted"}

web/src/routes/models/+page.svelte

+59-3
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,15 @@
122122
* @param model - The model name.
123123
* @param isAvailable - Boolean indicating if the model is available.
124124
*/
125-
async function handleModelAction(model: string, isAvailable: boolean) {
125+
async function handleModelAction(
126+
model: string,
127+
isAvailable: boolean,
128+
isDownloading: boolean = false,
129+
) {
130+
if (isDownloading) {
131+
await cancelDownload(model);
132+
return;
133+
}
126134
const url = `/api/model/${model}${isAvailable ? "" : "/download"}`;
127135
const method = isAvailable ? "DELETE" : "POST";
128136
@@ -218,6 +226,42 @@
218226
$: downloadedOrDownloadingModels = data.models
219227
.filter((model) => model.progress > 0 || model.available)
220228
.sort((a, b) => a.name.localeCompare(b.name));
229+
230+
async function cancelDownload(modelName: string) {
231+
try {
232+
const response = await fetch(`/api/model/${modelName}/download/cancel`, {
233+
method: "POST",
234+
});
235+
236+
if (response.ok) {
237+
console.log(`Download for ${modelName} cancelled successfully.`);
238+
// Update UI based on successful cancellation
239+
const modelIndex = data.models.findIndex((m) => m.name === modelName);
240+
if (modelIndex !== -1) {
241+
data.models[modelIndex].progress = 0;
242+
data.models[modelIndex].available = false;
243+
data.models = [...data.models]; // trigger reactivity
244+
}
245+
246+
// Remove model from tracking and local storage
247+
downloadingModels.delete(modelName);
248+
const currentDownloads = JSON.parse(
249+
localStorage.getItem("downloadingModels") || "[]",
250+
);
251+
const updatedDownloads = currentDownloads.filter(
252+
(model: string) => model !== modelName,
253+
);
254+
localStorage.setItem(
255+
"downloadingModels",
256+
JSON.stringify(updatedDownloads),
257+
);
258+
} else {
259+
console.error(`Failed to cancel download for ${modelName}`);
260+
}
261+
} catch (error) {
262+
console.error(`Error cancelling download for ${modelName}:`, error);
263+
}
264+
}
221265
</script>
222266

223267
<div class="top-section">
@@ -249,13 +293,25 @@
249293
</div>
250294
{/if}
251295
{#if model.progress >= 100}
252-
<p>Size: {model.size / 1e9}GB</p>
296+
<p>Size: {(model.size / 1e9).toFixed(2)} GB</p>
253297
<button
254298
on:click={() => handleModelAction(model.name, model.available)}
255299
class="btn btn-error mt-2"
256300
>
257301
<Icon icon="mdi:trash" width="32" height="32" />
258302
</button>
303+
{:else}
304+
<button
305+
on:click={() =>
306+
handleModelAction(
307+
model.name,
308+
model.available,
309+
model.progress > 0 && model.progress < 100,
310+
)}
311+
class="btn btn-error mt-2"
312+
>
313+
<Icon icon="mdi:cancel" width="32" height="32" />
314+
</button>
259315
{/if}
260316
</div>
261317
</div>
@@ -287,7 +343,7 @@
287343
{#if models.length === 1}
288344
<h3>{truncateString(model.name, 24)}</h3>
289345
{/if}
290-
<p>Size: {model.size / 1e9}GB</p>
346+
<p>Size: {(model.size / 1e9).toFixed(2)} GB</p>
291347
<button
292348
on:click={() => handleModelAction(model.name, model.available)}
293349
class="btn btn-primary mt-2"

0 commit comments

Comments
 (0)