diff --git a/etna/core/utils.py b/etna/core/utils.py index a76ce004c..7262caad6 100644 --- a/etna/core/utils.py +++ b/etna/core/utils.py @@ -1,11 +1,17 @@ +import hashlib import inspect import json +import os import pathlib +import warnings import zipfile from copy import deepcopy from functools import wraps from typing import Any from typing import Callable +from typing import Dict +from typing import Optional +from urllib import request from hydra_slayer import get_factory @@ -72,6 +78,116 @@ def wrapper(*args, **kwargs): return wrapper +# Known model hashes for integrity verification +# To add a hash for a model URL, download the file and compute its MD5 hash +KNOWN_MODEL_HASHES: Dict[str, str] = { + # Add known model URL -> hash mappings here + # Example: "http://example.com/model.ckpt": "abcd1234...", +} + + +def verify_file_hash(file_path: str, expected_hash: Optional[str] = None) -> bool: + """ + Verify file integrity using MD5 hash. + + Parameters + ---------- + file_path: + Path to the file to verify + expected_hash: + Expected MD5 hash. If None, verification is skipped. + + Returns + ------- + : + True if hash matches or no expected hash provided, False otherwise + """ + if expected_hash is None: + return True + + if not os.path.exists(file_path): + return False + + try: + hash_md5 = hashlib.md5() + with open(file_path, "rb") as f: + for chunk in iter(lambda: f.read(4096), b""): + hash_md5.update(chunk) + file_hash = hash_md5.hexdigest() + return file_hash == expected_hash + except Exception: + return False + + +def download_with_integrity_check( + url: str, + destination_path: str, + expected_hash: Optional[str] = None, + force_redownload: bool = False +) -> None: + """ + Download a file with integrity verification. + + Parameters + ---------- + url: + URL to download from + destination_path: + Local path to save the file + expected_hash: + Expected MD5 hash for verification. If None, no verification is performed. 
+ force_redownload: + If True, download even if file exists and passes verification + + Raises + ------ + RuntimeError: + If download fails integrity check + """ + # Check if file exists and verify integrity + if os.path.exists(destination_path) and not force_redownload: + if verify_file_hash(destination_path, expected_hash): + return # File exists and is valid + else: + # File exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local file hash does not match expected hash. " + f"This may indicate a corrupted download. Re-downloading from {url}" + ) + os.remove(destination_path) + + # Download the file + os.makedirs(os.path.dirname(destination_path), exist_ok=True) + request.urlretrieve(url=url, filename=destination_path) + + # Verify the downloaded file + if not verify_file_hash(destination_path, expected_hash): + if expected_hash is not None: + os.remove(destination_path) + raise RuntimeError( + f"Downloaded file from {url} failed integrity check. " + f"This may indicate a network issue or corrupted download." + ) + + +def get_known_hash(url: str) -> Optional[str]: + """ + Get known hash for a URL from the registry. 
+ + Parameters + ---------- + url: + URL to look up + + Returns + ------- + : + Known hash for the URL, or None if not found + """ + return KNOWN_MODEL_HASHES.get(url) + + def create_type_with_init_collector(type_: type) -> type: """Create type with init decorated with init_collector.""" previous_frame = inspect.stack()[1] diff --git a/etna/experimental/classification/predictability.py b/etna/experimental/classification/predictability.py index 381429de0..9d273d567 100644 --- a/etna/experimental/classification/predictability.py +++ b/etna/experimental/classification/predictability.py @@ -1,10 +1,14 @@ +import os +import warnings from typing import Dict from typing import List +from typing import Optional from urllib import request import numpy as np from sklearn.base import ClassifierMixin +from etna.core.utils import get_known_hash, verify_file_hash from etna.datasets import TSDataset from etna.experimental.classification.classification import TimeSeriesBinaryClassifier from etna.experimental.classification.feature_extraction.base import BaseTimeSeriesFeatureExtractor @@ -98,7 +102,34 @@ def download_model(model_name: str, dataset_freq: str, path: str): If the model does not exist in s3. """ url = f"http://etna-github-prod.cdn-tinkoff.ru/series_classification/22_11_2022/{dataset_freq}/{model_name}.pickle" + expected_hash = get_known_hash(url) + + # Check if file exists and verify integrity + if os.path.exists(path): + if verify_file_hash(path, expected_hash): + return # File exists and is valid (or no hash to check) + else: + # File exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local model file hash does not match expected hash. " + f"This may indicate a corrupted download. 
Re-downloading {model_name} from {url}" + ) + os.remove(path) + + # Download the file try: request.urlretrieve(url=url, filename=path) - except Exception: + + # Verify the downloaded file + if not verify_file_hash(path, expected_hash): + if expected_hash is not None: + os.remove(path) + raise RuntimeError( + f"Downloaded model file {model_name} from {url} failed integrity check. " + f"This may indicate a network issue or corrupted download." + ) + except Exception as e: + if expected_hash is not None and "integrity check" in str(e): + raise # Re-raise integrity check errors raise ValueError("Model not found! Check the list of available models!") diff --git a/etna/models/nn/chronos/base.py b/etna/models/nn/chronos/base.py index 2b53ace03..39aafcc30 100644 --- a/etna/models/nn/chronos/base.py +++ b/etna/models/nn/chronos/base.py @@ -5,6 +5,7 @@ from pathlib import Path from typing import Dict from typing import List +from typing import Optional from typing import Sequence from typing import Union from urllib import request @@ -12,6 +13,7 @@ import pandas as pd from etna import SETTINGS +from etna.core.utils import get_known_hash, verify_file_hash from etna.datasets import TSDataset from etna.distributions import BaseDistribution from etna.models.base import PredictionIntervalContextRequiredAbstractModel @@ -84,18 +86,55 @@ def _is_url(self): return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://") def _download_model_from_url(self) -> str: - """Download model from url to local cache_dir.""" + """Download model from url to local cache_dir with integrity verification.""" model_file = self.path_or_url.split("/")[-1] model_dir = model_file.split(".zip")[0] full_model_path = f"{self.cache_dir}/{model_dir}" - if not os.path.exists(full_model_path): - try: - request.urlretrieve(url=self.path_or_url, filename=model_file) - - with zipfile.ZipFile(model_file, "r") as zip_ref: - zip_ref.extractall(self.cache_dir) - finally: - 
os.remove(model_file) + zip_file_path = f"{self.cache_dir}/{model_file}" + expected_hash = get_known_hash(self.path_or_url) + + # Check if extracted model directory exists and verify ZIP file integrity if it still exists + if os.path.exists(full_model_path): + if os.path.exists(zip_file_path): + if verify_file_hash(zip_file_path, expected_hash): + return full_model_path + else: + # ZIP file exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local model ZIP file hash does not match expected hash. " + f"This may indicate a corrupted download. Re-downloading from {self.path_or_url}" + ) + # Remove both ZIP and extracted directory for clean re-download + os.remove(zip_file_path) + import shutil + shutil.rmtree(full_model_path) + else: + # Extracted directory exists but no ZIP file - assume it's valid + # (ZIP was cleaned up after successful extraction) + return full_model_path + + # Download and extract the file + Path(self.cache_dir).mkdir(parents=True, exist_ok=True) + try: + request.urlretrieve(url=self.path_or_url, filename=zip_file_path) + + # Verify the downloaded file + if not verify_file_hash(zip_file_path, expected_hash): + if expected_hash is not None: + os.remove(zip_file_path) + raise RuntimeError( + f"Downloaded model file from {self.path_or_url} failed integrity check. " + f"This may indicate a network issue or corrupted download." 
+ ) + + with zipfile.ZipFile(zip_file_path, "r") as zip_ref: + zip_ref.extractall(self.cache_dir) + finally: + # Always remove the downloaded ZIP on exit, whether download, verification or extraction succeeded or failed + if os.path.exists(zip_file_path): + os.remove(zip_file_path) + return full_model_path @property diff --git a/etna/models/nn/timesfm.py b/etna/models/nn/timesfm.py index 4ace63c0b..bd97a3129 100644 --- a/etna/models/nn/timesfm.py +++ b/etna/models/nn/timesfm.py @@ -12,6 +12,7 @@ import pandas as pd from etna import SETTINGS +from etna.core.utils import get_known_hash, verify_file_hash from etna.datasets import TSDataset from etna.distributions import BaseDistribution from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel @@ -152,12 +153,37 @@ def _is_url(self): return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://") def _download_model_from_url(self) -> str: - """Download model from url to local cache_dir.""" + """Download model from url to local cache_dir with integrity verification.""" model_file = self.path_or_url.split("/")[-1] full_model_path = f"{self.cache_dir}/{model_file}" - if not os.path.exists(full_model_path): - Path(self.cache_dir).mkdir(parents=True, exist_ok=True) - request.urlretrieve(url=self.path_or_url, filename=full_model_path) + expected_hash = get_known_hash(self.path_or_url) + + # Check if file exists and verify integrity + if os.path.exists(full_model_path): + if verify_file_hash(full_model_path, expected_hash): + return full_model_path + else: + # File exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local model file hash does not match expected hash. " + f"This may indicate a corrupted download. 
Re-downloading from {self.path_or_url}" + ) + os.remove(full_model_path) + + # Download the file + Path(self.cache_dir).mkdir(parents=True, exist_ok=True) + request.urlretrieve(url=self.path_or_url, filename=full_model_path) + + # Verify the downloaded file + if not verify_file_hash(full_model_path, expected_hash): + if expected_hash is not None: + os.remove(full_model_path) + raise RuntimeError( + f"Downloaded model file from {self.path_or_url} failed integrity check. " + f"This may indicate a network issue or corrupted download." + ) + return full_model_path @property diff --git a/etna/models/utils.py b/etna/models/utils.py index 4af07e35e..42802595d 100644 --- a/etna/models/utils.py +++ b/etna/models/utils.py @@ -1,5 +1,8 @@ +import os +import warnings from typing import Optional from typing import Union +from urllib import request import pandas as pd diff --git a/etna/transforms/embeddings/models/ts2vec.py b/etna/transforms/embeddings/models/ts2vec.py index ea38c0301..adbe76352 100644 --- a/etna/transforms/embeddings/models/ts2vec.py +++ b/etna/transforms/embeddings/models/ts2vec.py @@ -12,6 +12,7 @@ import numpy as np from etna import SETTINGS +from etna.core.utils import get_known_hash, verify_file_hash from etna.transforms.embeddings.models import BaseEmbeddingModel if SETTINGS.torch_required: @@ -307,16 +308,39 @@ def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = N if model_name is not None: if path is None: path = _DOWNLOAD_PATH / f"{model_name}.zip" + + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip" + expected_hash = get_known_hash(url) + + # Check if file exists and verify integrity if os.path.exists(path): - warnings.warn( - f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." 
- ) - else: + if verify_file_hash(str(path), expected_hash): + # File exists and is valid (or no hash to check) + pass + else: + # File exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local model file hash does not match expected hash. " + f"This may indicate a corrupted download. Re-downloading {model_name} from {url}" + ) + os.remove(path) + + # Download if file doesn't exist (or was removed due to hash mismatch) + if not os.path.exists(path): Path(path).parent.mkdir(exist_ok=True, parents=True) if model_name in cls.list_models(): - url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip" request.urlretrieve(url=url, filename=path) + + # Verify the downloaded file + if not verify_file_hash(str(path), expected_hash): + if expected_hash is not None: + os.remove(path) + raise RuntimeError( + f"Downloaded model file {model_name} from {url} failed integrity check. " + f"This may indicate a network issue or corrupted download." + ) else: raise NotImplementedError( f"Model {model_name} is not available. To get list of available models use `list_models` method." 
diff --git a/etna/transforms/embeddings/models/tstcc.py b/etna/transforms/embeddings/models/tstcc.py index f5b0e4665..f67c46063 100644 --- a/etna/transforms/embeddings/models/tstcc.py +++ b/etna/transforms/embeddings/models/tstcc.py @@ -12,6 +12,7 @@ import numpy as np from etna import SETTINGS +from etna.core.utils import get_known_hash, verify_file_hash from etna.transforms.embeddings.models import BaseEmbeddingModel if SETTINGS.torch_required: @@ -303,16 +304,39 @@ def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = N if model_name is not None: if path is None: path = _DOWNLOAD_PATH / f"{model_name}.zip" + + url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip" + expected_hash = get_known_hash(url) + + # Check if file exists and verify integrity if os.path.exists(path): - warnings.warn( - f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model." - ) - else: + if verify_file_hash(str(path), expected_hash): + # File exists and is valid (or no hash to check) + pass + else: + # File exists but hash doesn't match, re-download + if expected_hash is not None: + warnings.warn( + f"Local model file hash does not match expected hash. " + f"This may indicate a corrupted download. Re-downloading {model_name} from {url}" + ) + os.remove(path) + + # Download if file doesn't exist (or was removed due to hash mismatch) + if not os.path.exists(path): Path(path).parent.mkdir(exist_ok=True, parents=True) if model_name in cls.list_models(): - url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/tstcc/{model_name}.zip" request.urlretrieve(url=url, filename=path) + + # Verify the downloaded file + if not verify_file_hash(str(path), expected_hash): + if expected_hash is not None: + os.remove(path) + raise RuntimeError( + f"Downloaded model file {model_name} from {url} failed integrity check. " + f"This may indicate a network issue or corrupted download." 
+ ) else: raise NotImplementedError( f"Model {model_name} is not available. To get list of available models use `list_models` method." diff --git a/tests/test_core/test_utils.py b/tests/test_core/test_utils.py index 583f272f2..f9b04e77d 100644 --- a/tests/test_core/test_utils.py +++ b/tests/test_core/test_utils.py @@ -1,10 +1,14 @@ +import hashlib +import os import pathlib import tempfile +from unittest.mock import mock_open, patch import pandas as pd import pytest from etna.core import load +from etna.core.utils import download_with_integrity_check, get_known_hash, verify_file_hash, KNOWN_MODEL_HASHES from etna.models import NaiveModel from etna.pipeline import Pipeline from etna.transforms import AddConstTransform @@ -43,3 +47,212 @@ def test_load_ok_with_params(example_tsds): assert new_pipeline.ts is not None assert type(new_pipeline) == type(pipeline) pd.testing.assert_frame_equal(new_pipeline.ts.to_pandas(), example_tsds.to_pandas()) + + +class TestVerifyFileHash: + def test_verify_file_hash_no_expected_hash(self): + """Test that verification returns True when no expected hash is provided.""" + with tempfile.NamedTemporaryFile() as temp_file: + result = verify_file_hash(temp_file.name, expected_hash=None) + assert result is True + + def test_verify_file_hash_file_not_exists(self): + """Test that verification returns False when file doesn't exist.""" + result = verify_file_hash("/nonexistent/file.txt", expected_hash="dummy_hash") + assert result is False + + def test_verify_file_hash_correct_hash(self): + """Test that verification returns True when hash matches.""" + test_content = b"test content" + expected_hash = hashlib.md5(test_content).hexdigest() + + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(test_content) + temp_file.flush() + + result = verify_file_hash(temp_file.name, expected_hash=expected_hash) + assert result is True + + def test_verify_file_hash_incorrect_hash(self): + """Test that verification returns False when hash 
doesn't match.""" + test_content = b"test content" + wrong_hash = "wrong_hash" + + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(test_content) + temp_file.flush() + + result = verify_file_hash(temp_file.name, expected_hash=wrong_hash) + assert result is False + + def test_verify_file_hash_chunked_reading(self): + """Test that chunked reading works correctly for large files.""" + # Create content larger than chunk size (4096 bytes) + test_content = b"x" * 10000 + expected_hash = hashlib.md5(test_content).hexdigest() + + with tempfile.NamedTemporaryFile() as temp_file: + temp_file.write(test_content) + temp_file.flush() + + result = verify_file_hash(temp_file.name, expected_hash=expected_hash) + assert result is True + + def test_verify_file_hash_exception_handling(self): + """Test that verification returns False when an exception occurs.""" + with patch("builtins.open", side_effect=IOError("File read error")): + result = verify_file_hash("dummy_path", expected_hash="dummy_hash") + assert result is False + + +class TestGetKnownHash: + def test_get_known_hash_existing_url(self): + """Test retrieving a known hash for an existing URL.""" + test_url = "http://example.com/model.ckpt" + test_hash = "abcd1234" + + # Temporarily add to known hashes + original_hashes = KNOWN_MODEL_HASHES.copy() + KNOWN_MODEL_HASHES[test_url] = test_hash + + try: + result = get_known_hash(test_url) + assert result == test_hash + finally: + # Restore original hashes + KNOWN_MODEL_HASHES.clear() + KNOWN_MODEL_HASHES.update(original_hashes) + + def test_get_known_hash_nonexistent_url(self): + """Test retrieving hash for a non-existent URL returns None.""" + result = get_known_hash("http://nonexistent.com/model.ckpt") + assert result is None + + +class TestDownloadWithIntegrityCheck: + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('etna.core.utils.verify_file_hash') + def test_download_file_exists_and_valid(self, mock_verify, mock_exists, 
mock_urlretrieve): + """Test that download is skipped when file exists and is valid.""" + mock_exists.return_value = True + mock_verify.return_value = True + + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234" + ) + + mock_urlretrieve.assert_not_called() + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('os.remove') + @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + def test_download_file_exists_but_invalid(self, mock_makedirs, mock_verify, mock_remove, mock_exists, mock_urlretrieve): + """Test that file is re-downloaded when existing file fails verification.""" + mock_exists.return_value = True + mock_verify.side_effect = [False, True] # First call (existing file) fails, second call (after download) succeeds + + with patch('warnings.warn') as mock_warn: + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234" + ) + + mock_remove.assert_called_once_with("/path/to/model.ckpt") + mock_urlretrieve.assert_called_once_with(url="http://example.com/model.ckpt", filename="/path/to/model.ckpt") + mock_warn.assert_called_once() + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + def test_download_file_not_exists(self, mock_makedirs, mock_verify, mock_exists, mock_urlretrieve): + """Test that file is downloaded when it doesn't exist.""" + mock_exists.return_value = False + mock_verify.return_value = True + + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234" + ) + + mock_urlretrieve.assert_called_once_with(url="http://example.com/model.ckpt", filename="/path/to/model.ckpt") + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('os.remove') 
+ @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + def test_download_fails_integrity_check(self, mock_makedirs, mock_verify, mock_remove, mock_exists, mock_urlretrieve): + """Test that RuntimeError is raised when downloaded file fails integrity check.""" + mock_exists.return_value = False + mock_verify.return_value = False + + with pytest.raises(RuntimeError, match="Downloaded file from .* failed integrity check"): + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234" + ) + + mock_remove.assert_called_once_with("/path/to/model.ckpt") + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + def test_download_no_expected_hash(self, mock_makedirs, mock_verify, mock_exists, mock_urlretrieve): + """Test that download works without integrity checking when no hash is provided.""" + mock_exists.return_value = False + mock_verify.return_value = True # Should return True when no hash is provided + + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash=None + ) + + mock_urlretrieve.assert_called_once_with(url="http://example.com/model.ckpt", filename="/path/to/model.ckpt") + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + def test_force_redownload(self, mock_makedirs, mock_verify, mock_exists, mock_urlretrieve): + """Test that file is re-downloaded when force_redownload is True.""" + mock_exists.return_value = True + mock_verify.return_value = True + + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234", + force_redownload=True + ) + + mock_urlretrieve.assert_called_once_with(url="http://example.com/model.ckpt", 
filename="/path/to/model.ckpt") + + @patch('etna.core.utils.request.urlretrieve') + @patch('os.path.exists') + @patch('etna.core.utils.verify_file_hash') + @patch('os.makedirs') + @patch('os.path.dirname') + def test_creates_directory(self, mock_dirname, mock_makedirs, mock_verify, mock_exists, mock_urlretrieve): + """Test that destination directory is created if it doesn't exist.""" + mock_exists.return_value = False + mock_verify.return_value = True + mock_dirname.return_value = "/path/to" + + download_with_integrity_check( + url="http://example.com/model.ckpt", + destination_path="/path/to/model.ckpt", + expected_hash="abcd1234" + ) + + mock_makedirs.assert_called_once_with("/path/to", exist_ok=True)