2
2
import os
3
3
import shutil
4
4
5
+ import aiohttp
5
6
6
7
from fastapi import APIRouter , HTTPException
7
- import huggingface_hub
8
+ from huggingface_hub import hf_hub_url
8
9
from serge .models .models import Families
9
10
10
11
from pathlib import Path
14
15
tags = ["model" ],
15
16
)
16
17
18
# Registry of in-flight downloads, keyed by model name; the download handler
# stores the asyncio task it creates here so other handlers can look it up.
active_downloads = {}

# Directory that holds the model weight files (final "<name>.bin" files and
# the temporary ".<name>.bin" dotfiles used while a download is in progress).
WEIGHTS = "/usr/src/app/weights/"

# Path to data/models.json, resolved two directories above this module.
# NOTE(review): presumably the source of `models_info` - confirm where it is loaded.
models_file_path = Path(__file__).parent.parent / "data" / "models.json"
34
37
# Helper functions
35
38
async def is_model_installed(model_name: str) -> bool:
    """Return True when ``<model_name>.bin`` is among the installed models.

    The candidate list comes from ``list_of_installed_models``; entries whose
    name starts with a dot are never reported as installed.
    """
    expected_file = f"{model_name}.bin"
    for installed_file in await list_of_installed_models():
        if installed_file.startswith("."):
            continue
        if installed_file == expected_file:
            return True
    return False
38
41
39
42
40
43
async def get_file_size(file_path: str) -> int:
    """Return the size, in bytes, of the file at ``file_path``.

    Raises:
        OSError: if the file does not exist or cannot be stat'ed.
    """
    size_in_bytes = os.path.getsize(file_path)
    return size_in_bytes
42
45
43
46
47
async def cleanup_model_resources(model_name: str):
    """Best-effort removal of the on-disk leftovers of a model download.

    Deletes, in this order and only if present:
      1. the temporary dotfile ``.<model_name>.bin`` in WEIGHTS,
      2. the lock directory ``.locks/models--<repo>`` in WEIGHTS,
      3. the cache directory ``models--<repo>`` in WEIGHTS.

    Failures are printed and otherwise ignored so one failing step does not
    prevent the following ones. If the model has no known repository, a
    message is printed and nothing is removed.
    """
    repo_id, _, _ = models_info.get(model_name, (None, None, None))
    if not repo_id:
        print(f"No model repo found for {model_name}, cleanup may be incomplete.")
        return

    # Both the lock and the cache directory share this folder name.
    repo_folder = f"models--{repo_id.replace('/', '--')}"
    partial_download = os.path.join(WEIGHTS, f".{model_name}.bin")
    directories = (
        ("lock", os.path.join(WEIGHTS, ".locks", repo_folder)),
        ("cache", os.path.join(WEIGHTS, repo_folder)),
    )

    # Drop the partially downloaded file, if any.
    if os.path.exists(partial_download):
        try:
            os.remove(partial_download)
        except OSError as err:
            print(f"Error removing temporary file for {model_name}: {err}")

    # Drop the lock directory first, then the cache directory.
    for label, directory in directories:
        if not os.path.exists(directory):
            continue
        try:
            shutil.rmtree(directory)
        except OSError as err:
            print(f"Error removing {label} directory for {model_name}: {err}")
77
+
78
+
79
async def download_file(
    session: aiohttp.ClientSession,
    url: str,
    path: str,
    chunk_size: int = 1024 * 1024,
) -> None:
    """Stream the response body of ``url`` into the file at ``path``.

    Args:
        session: Open aiohttp client session used to issue the GET request.
        url: Address of the file to download.
        path: Destination file; created or truncated, written in binary mode.
        chunk_size: Maximum number of bytes requested from the stream per
            iteration. Defaults to 1 MiB; model files are large, so reading
            them 1 KiB at a time costs millions of event-loop round-trips.

    Raises:
        HTTPException: with status 500 if the server does not answer 200.
    """
    async with session.get(url) as response:
        if response.status != 200:
            raise HTTPException(status_code=500, detail="Error downloading model")

        # The file writes are synchronous; large chunks keep their number low.
        with open(path, "wb") as f:
            async for chunk in response.content.iter_chunked(chunk_size):
                f.write(chunk)
91
+
92
+
44
93
# Handlers
45
94
@model_router .get ("/all" )
46
95
async def list_of_all_models ():
@@ -73,26 +122,15 @@ async def list_of_all_models():
73
122
return resp
74
123
75
124
76
- @model_router .get ("/downloadable" )
77
- async def list_of_downloadable_models ():
78
- files = os .listdir (WEIGHTS )
79
- files = list (filter (lambda x : x .endswith (".bin" ), files ))
80
-
81
- installed_models = [i .rstrip (".bin" ) for i in files ]
82
-
83
- return list (filter (lambda x : x not in installed_models , models_info .keys ()))
84
-
85
-
86
125
@model_router .get ("/installed" )
87
126
async def list_of_installed_models ():
88
- # after iterating through the WEIGHTS directory, return location and filename
127
+ # Iterate through the WEIGHTS directory and return filenames that end with .bin and do not start with a dot
89
128
files = [
90
- f" { model_location .replace (WEIGHTS , '' ) } / { bin_file } "
91
- for model_location , directory , filenames in os .walk (WEIGHTS )
129
+ os . path . join ( model_location .replace (WEIGHTS , "" ). lstrip ( "/" ), bin_file )
130
+ for model_location , _ , filenames in os .walk (WEIGHTS )
92
131
for bin_file in filenames
93
- if os . path . splitext ( bin_file )[ 1 ] == ".bin"
132
+ if bin_file . endswith ( ".bin" ) and not bin_file . startswith ( "." )
94
133
]
95
- files = [i .lstrip ("/" ) for i in files ]
96
134
return files
97
135
98
136
@@ -102,18 +140,63 @@ async def download_model(model_name: str):
102
140
raise HTTPException (status_code = 404 , detail = "Model not found" )
103
141
104
142
try :
105
- # Download file, and resume broken downloads
106
143
model_repo , filename , _ = models_info [model_name ]
107
- model_path = f"{ WEIGHTS } { model_name } .bin"
108
- await asyncio .to_thread (
109
- huggingface_hub .hf_hub_download , repo_id = model_repo , filename = filename , local_dir = WEIGHTS , cache_dir = WEIGHTS , resume_download = True
110
- )
111
- # Rename file
112
- os .rename (os .path .join (WEIGHTS , filename ), os .path .join (WEIGHTS , model_path ))
144
+ model_url = hf_hub_url (repo_id = model_repo , filename = filename )
145
+ temp_model_path = os .path .join (WEIGHTS , f".{ model_name } .bin" )
146
+ model_path = os .path .join (WEIGHTS , f"{ model_name } .bin" )
147
+
148
+ # Create an aiohttp session with timeout settings
149
+ timeout = aiohttp .ClientTimeout (total = 300 )
150
+ async with aiohttp .ClientSession (timeout = timeout ) as session :
151
+ # Start the download and add to active_downloads
152
+ download_task = asyncio .create_task (download_file (session , model_url , temp_model_path ))
153
+ active_downloads [model_name ] = download_task
154
+ await download_task
155
+
156
+ # Rename the dotfile to its final name
157
+ os .rename (temp_model_path , model_path )
158
+
159
+ # Remove the entry from active_downloads after successful download
160
+ active_downloads .pop (model_name , None )
161
+
113
162
return {"message" : f"Model { model_name } downloaded" }
163
+ except asyncio .CancelledError :
164
+ await cleanup_model_resources (model_name )
165
+ raise HTTPException (status_code = 200 , detail = "Download cancelled" )
166
+ except Exception as exc :
167
+ await cleanup_model_resources (model_name )
168
+ raise HTTPException (status_code = 500 , detail = f"Error downloading model: { exc } " )
169
+
170
+
171
@model_router.post("/{model_name}/download/cancel")
async def cancel_download(model_name: str):
    """Cancel the in-flight download of ``model_name`` and clean up after it.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 404 when no download is registered for the model;
            500 when cancelling or cleaning up fails unexpectedly.
    """
    # Look the task up outside the try block: HTTPException derives from
    # Exception, so raising the 404 inside it would be caught by the generic
    # handler below and reported to the client as a 500.
    # pop() also unregisters the download in the same step.
    task = active_downloads.pop(model_name, None)
    if task is None:
        raise HTTPException(status_code=404, detail="No active download for this model")

    try:
        # Request cancellation of the download task.
        task.cancel()

        # Wait until the task has actually finished unwinding.
        try:
            print(f"Waiting for download for {model_name} to be cancelled")
            await task
        except asyncio.CancelledError:
            # Expected outcome of awaiting a cancelled task.
            pass

        # Remove the partial file and any cache/lock directories.
        await cleanup_model_resources(model_name)

        print(f"Download for {model_name} cancelled")
        return {"message": f"Download for {model_name} cancelled"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error cancelling model download: {str(e)}")
117
200
118
201
119
202
@model_router .get ("/{model_name}/download/status" )
@@ -125,48 +208,38 @@ async def download_status(model_name: str):
125
208
model_repo , _ , _ = models_info [model_name ]
126
209
127
210
# Construct the path to the blobs directory
128
- blobs_dir = os .path .join (WEIGHTS , f"models--{ model_repo .replace ('/' , '--' )} " , "blobs" )
129
-
130
- # Check for the .incomplete file in the blobs directory
131
- if os .path .exists (os .path .join (WEIGHTS , f"{ model_name } .bin" )):
132
- currentsize = os .path .getsize (os .path .join (WEIGHTS , f"{ model_name } .bin" ))
133
- return min (round (currentsize / filesize * 100 , 1 ), 100 )
134
- elif os .path .exists (blobs_dir ):
135
- for file in os .listdir (blobs_dir ):
136
- if file .endswith (".incomplete" ):
137
- incomplete_file_path = os .path .join (blobs_dir , file )
138
- # Check if the .incomplete file exists and calculate the download status
139
- if os .path .exists (incomplete_file_path ):
140
- currentsize = os .path .getsize (incomplete_file_path )
141
- return min (round (currentsize / filesize * 100 , 1 ), 100 )
142
- return 0
211
+ temp_model_path = os .path .join (WEIGHTS , f".{ model_name } .bin" )
212
+ model_path = os .path .join (WEIGHTS , f"{ model_name } .bin" )
213
+
214
+ # Check if the model is currently being downloaded
215
+ task = active_downloads .get (model_name )
216
+
217
+ if os .path .exists (model_path ):
218
+ currentsize = os .path .getsize (model_path )
219
+ progress = min (round (currentsize / filesize * 100 , 1 ), 100 )
220
+ return progress
221
+ elif task and not task .done ():
222
+ # If the task is still running, check for incomplete files
223
+ if os .path .exists (temp_model_path ):
224
+ currentsize = os .path .getsize (temp_model_path )
225
+ return min (round (currentsize / filesize * 100 , 1 ), 100 )
226
+ # If temp_model_path doesn't exist, the download is likely just starting, progress is 0
227
+ return 0
228
+ else :
229
+ # No active download and the file does not exist
230
+ return None
143
231
144
232
145
233
@model_router.delete("/{model_name}")
async def delete_model(model_name: str):
    """Delete an installed model and its leftover download resources.

    Returns:
        A dict with a confirmation message.

    Raises:
        HTTPException: 404 when ``<model_name>.bin`` is not installed.
    """
    model_filename = f"{model_name}.bin"
    installed = await list_of_installed_models()
    if model_filename not in installed:
        raise HTTPException(status_code=404, detail="Model not found")

    # Removal is best-effort: a failure is printed and cleanup still runs.
    model_file = os.path.join(WEIGHTS, model_filename)
    try:
        os.remove(model_file)
    except OSError as err:
        print(f"Error removing model file: {err}")

    await cleanup_model_resources(model_name)

    return {"message": f"Model {model_name} deleted"}
0 commit comments