Skip to content

Commit

Permalink
Revert "Use pydantic."
Browse files Browse the repository at this point in the history
This reverts commit 7461e8e.
  • Loading branch information
robinjhuang committed Aug 8, 2024
1 parent 836db8c commit 2e351ce
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 57 deletions.
73 changes: 30 additions & 43 deletions model_filemanager/download_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,25 +8,35 @@
from typing import Callable, Any, Optional, Awaitable, Dict
from enum import Enum
import time
from pydantic import BaseModel, Field
from dataclasses import dataclass

class DownloadStatusType(str, Enum):

class DownloadStatusType(Enum):
PENDING = "pending"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
ERROR = "error"

class DownloadModelStatus(BaseModel):
status: DownloadStatusType
progress_percentage: float = Field(ge=0, le=100)
@dataclass
class DownloadModelStatus():
status: str
progress_percentage: float
message: str
already_existed: bool = False

class Config:
use_enum_values = True

def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool):
self.status = status.value # Store the string value of the Enum
self.progress_percentage = progress_percentage
self.message = message
self.already_existed = already_existed

def to_dict(self) -> Dict[str, Any]:
return self.model_dump()
return {
"status": self.status,
"progress_percentage": self.progress_percentage,
"message": self.message,
"already_existed": self.already_existed
}

async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]],
model_name: str,
Expand Down Expand Up @@ -55,10 +65,10 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
"""
if not validate_model_subdirectory(model_sub_directory):
return DownloadModelStatus(
status=DownloadStatusType.ERROR,
progress_percentage=0,
message="Invalid model subdirectory",
already_existed=False
DownloadStatusType.ERROR,
0,
"Invalid model subdirectory",
False
)

file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir)
Expand All @@ -67,25 +77,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht
return existing_file

try:
status = DownloadModelStatus(status=DownloadStatusType.PENDING,
progress_percentage=0,
message=f"Starting download of {model_name}",
already_existed=False)
status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False)
await progress_callback(relative_path, status)

response = await model_download_request(model_url)
if response.status != 200:
error_message = f"Failed to download {model_name}. Status code: {response.status}"
logging.error(error_message)
status = DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage= 0,
message=error_message,
already_existed= False)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage=0,
message= error_message,
already_existed=False)
return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)

return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval)

Expand All @@ -106,11 +107,7 @@ async def check_file_exists(file_path: str,
progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]],
relative_path: str) -> Optional[DownloadModelStatus]:
if os.path.exists(file_path):
status = DownloadModelStatus(
status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message= f"{model_name} already exists",
already_existed=True)
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True)
await progress_callback(relative_path, status)
return status
return None
Expand All @@ -130,10 +127,7 @@ async def track_download_progress(response: aiohttp.ClientResponse,
async def update_progress():
nonlocal last_update_time
progress = (downloaded / total_size) * 100 if total_size > 0 else 0
status = DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS,
progress_percentage=progress,
message=f"Downloading {model_name}",
already_existed=False)
status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False)
await progress_callback(relative_path, status)
last_update_time = time.time()

Expand All @@ -153,11 +147,7 @@ async def update_progress():
await update_progress()

logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}")
status = DownloadModelStatus(
status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message=f"Successfully downloaded {model_name}",
already_existed=False)
status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False)
await progress_callback(relative_path, status)

return status
Expand All @@ -171,10 +161,7 @@ async def handle_download_error(e: Exception,
progress_callback: Callable[[str, DownloadModelStatus], Any],
relative_path: str) -> DownloadModelStatus:
error_message = f"Error downloading {model_name}: {str(e)}"
status = DownloadModelStatus(status=DownloadStatusType.ERROR,
progress_percentage=0,
message=error_message,
already_existed=False)
status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False)
await progress_callback(relative_path, status)
return status

Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ Pillow
scipy
tqdm
psutil
pydantic~=2.8

#non essential dependencies:
kornia>=0.7.1
Expand Down
17 changes: 4 additions & 13 deletions tests-unit/prompt_server_test/download_models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,13 @@ async def test_download_model_success():
# Check initial call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadModelStatus(status=DownloadStatusType.PENDING,
progress_percentage= 0,
message="Starting download of model.bin",
already_existed= False)
DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False)
)

# Check final call
mock_progress_callback.assert_any_call(
'checkpoints/model.bin',
DownloadModelStatus(status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message="Successfully downloaded model.bin",
already_existed= False)
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False)
)

# Verify file writing
Expand Down Expand Up @@ -210,10 +204,7 @@ async def test_check_file_exists_when_file_exists(tmp_path):

mock_callback.assert_called_once_with(
"test/existing_model.bin",
DownloadModelStatus(status=DownloadStatusType.COMPLETED,
progress_percentage=100,
message="existing_model.bin already exists",
already_existed=True)
DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True)
)

@pytest.mark.asyncio
Expand Down Expand Up @@ -246,7 +237,7 @@ async def test_track_download_progress_no_content_length():
# Check that progress was reported even without knowing the total size
mock_callback.assert_any_call(
'models/model.bin',
DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, progress_percentage= 0, message="Downloading model.bin", already_existed=False)
DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False)
)

@pytest.mark.asyncio
Expand Down

0 comments on commit 2e351ce

Please sign in to comment.