From 2e351ce2be9eb8041a6794e6194121e7a5e1bbfe Mon Sep 17 00:00:00 2001 From: Robin Huang Date: Thu, 8 Aug 2024 13:27:12 -0700 Subject: [PATCH] Revert "Use pydantic." This reverts commit 7461e8eb0073add315c65c6f5e361f0891bffc7d. --- model_filemanager/download_models.py | 73 ++++++++----------- requirements.txt | 1 - .../download_models_test.py | 17 +---- 3 files changed, 34 insertions(+), 57 deletions(-) diff --git a/model_filemanager/download_models.py b/model_filemanager/download_models.py index 406e3588344..fd3ec1dbced 100644 --- a/model_filemanager/download_models.py +++ b/model_filemanager/download_models.py @@ -8,25 +8,35 @@ from typing import Callable, Any, Optional, Awaitable, Dict from enum import Enum import time -from pydantic import BaseModel, Field +from dataclasses import dataclass -class DownloadStatusType(str, Enum): + +class DownloadStatusType(Enum): PENDING = "pending" IN_PROGRESS = "in_progress" COMPLETED = "completed" ERROR = "error" -class DownloadModelStatus(BaseModel): - status: DownloadStatusType - progress_percentage: float = Field(ge=0, le=100) +@dataclass +class DownloadModelStatus(): + status: str + progress_percentage: float message: str already_existed: bool = False - class Config: - use_enum_values = True - + def __init__(self, status: DownloadStatusType, progress_percentage: float, message: str, already_existed: bool): + self.status = status.value # Store the string value of the Enum + self.progress_percentage = progress_percentage + self.message = message + self.already_existed = already_existed + def to_dict(self) -> Dict[str, Any]: - return self.model_dump() + return { + "status": self.status, + "progress_percentage": self.progress_percentage, + "message": self.message, + "already_existed": self.already_existed + } async def download_model(model_download_request: Callable[[str], Awaitable[aiohttp.ClientResponse]], model_name: str, @@ -55,10 +65,10 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht """ if not validate_model_subdirectory(model_sub_directory): return DownloadModelStatus( - status=DownloadStatusType.ERROR, - progress_percentage=0, - message="Invalid model subdirectory", - already_existed=False + DownloadStatusType.ERROR, + 0, + "Invalid model subdirectory", + False ) file_path, relative_path = create_model_path(model_name, model_sub_directory, models_dir) @@ -67,25 +77,16 @@ async def download_model(model_download_request: Callable[[str], Awaitable[aioht return existing_file try: - status = DownloadModelStatus(status=DownloadStatusType.PENDING, - progress_percentage=0, - message=f"Starting download of {model_name}", - already_existed=False) + status = DownloadModelStatus(DownloadStatusType.PENDING, 0, f"Starting download of {model_name}", False) await progress_callback(relative_path, status) response = await model_download_request(model_url) if response.status != 200: error_message = f"Failed to download {model_name}. Status code: {response.status}" logging.error(error_message) - status = DownloadModelStatus(status=DownloadStatusType.ERROR, - progress_percentage= 0, - message=error_message, - already_existed= False) + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) - return DownloadModelStatus(status=DownloadStatusType.ERROR, - progress_percentage=0, - message= error_message, - already_existed=False) + return DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) return await track_download_progress(response, file_path, model_name, progress_callback, relative_path, progress_interval) @@ -106,11 +107,7 @@ async def check_file_exists(file_path: str, progress_callback: Callable[[str, DownloadModelStatus], Awaitable[Any]], relative_path: str) -> Optional[DownloadModelStatus]: if os.path.exists(file_path): - status = DownloadModelStatus( - status=DownloadStatusType.COMPLETED, - progress_percentage=100, - message= f"{model_name} already exists", - already_existed=True) + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"{model_name} already exists", True) await progress_callback(relative_path, status) return status return None @@ -130,10 +127,7 @@ async def track_download_progress(response: aiohttp.ClientResponse, async def update_progress(): nonlocal last_update_time progress = (downloaded / total_size) * 100 if total_size > 0 else 0 - status = DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, - progress_percentage=progress, - message=f"Downloading {model_name}", - already_existed=False) + status = DownloadModelStatus(DownloadStatusType.IN_PROGRESS, progress, f"Downloading {model_name}", False) await progress_callback(relative_path, status) last_update_time = time.time() @@ -153,11 +147,7 @@ async def update_progress(): await update_progress() logging.info(f"Successfully downloaded {model_name}. Total downloaded: {downloaded}") - status = DownloadModelStatus( - status=DownloadStatusType.COMPLETED, - progress_percentage=100, - message=f"Successfully downloaded {model_name}", - already_existed=False) + status = DownloadModelStatus(DownloadStatusType.COMPLETED, 100, f"Successfully downloaded {model_name}", False) await progress_callback(relative_path, status) return status @@ -171,10 +161,7 @@ async def handle_download_error(e: Exception, progress_callback: Callable[[str, DownloadModelStatus], Any], relative_path: str) -> DownloadModelStatus: error_message = f"Error downloading {model_name}: {str(e)}" - status = DownloadModelStatus(status=DownloadStatusType.ERROR, - progress_percentage=0, - message=error_message, - already_existed=False) + status = DownloadModelStatus(DownloadStatusType.ERROR, 0, error_message, False) await progress_callback(relative_path, status) return status diff --git a/requirements.txt b/requirements.txt index ce9db8c195a..4c2c0b2b221 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ Pillow scipy tqdm psutil -pydantic~=2.8 #non essential dependencies: kornia>=0.7.1 diff --git a/tests-unit/prompt_server_test/download_models_test.py b/tests-unit/prompt_server_test/download_models_test.py index f90c09a140a..26dd94d4cce 100644 --- a/tests-unit/prompt_server_test/download_models_test.py +++ b/tests-unit/prompt_server_test/download_models_test.py @@ -84,19 +84,13 @@ async def test_download_model_success(): # Check initial call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadModelStatus(status=DownloadStatusType.PENDING, - progress_percentage= 0, - message="Starting download of model.bin", - already_existed= False) + DownloadModelStatus(DownloadStatusType.PENDING, 0, "Starting download of model.bin", False) ) # Check final call mock_progress_callback.assert_any_call( 'checkpoints/model.bin', - DownloadModelStatus(status=DownloadStatusType.COMPLETED, - progress_percentage=100, - message="Successfully downloaded model.bin", - already_existed= False) + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "Successfully downloaded model.bin", False) ) # Verify file writing @@ -210,10 +204,7 @@ async def test_check_file_exists_when_file_exists(tmp_path): mock_callback.assert_called_once_with( "test/existing_model.bin", - DownloadModelStatus(status=DownloadStatusType.COMPLETED, - progress_percentage=100, - message="existing_model.bin already exists", - already_existed=True) + DownloadModelStatus(DownloadStatusType.COMPLETED, 100, "existing_model.bin already exists", already_existed=True) ) @pytest.mark.asyncio @@ -246,7 +237,7 @@ async def test_track_download_progress_no_content_length(): # Check that progress was reported even without knowing the total size mock_callback.assert_any_call( 'models/model.bin', - DownloadModelStatus(status=DownloadStatusType.IN_PROGRESS, progress_percentage= 0, message="Downloading model.bin", already_existed=False) + DownloadModelStatus(DownloadStatusType.IN_PROGRESS, 0, "Downloading model.bin", already_existed=False) ) @pytest.mark.asyncio