Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions etna/core/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
import hashlib
import inspect
import json
import os
import pathlib
import warnings
import zipfile
from copy import deepcopy
from functools import wraps
from typing import Any
from typing import Callable
from typing import Dict
from typing import Optional
from urllib import request

from hydra_slayer import get_factory

Expand Down Expand Up @@ -72,6 +78,116 @@ def wrapper(*args, **kwargs):
return wrapper


# Known model hashes for integrity verification
# To add a hash for a model URL, download the file and compute its MD5 hash
KNOWN_MODEL_HASHES: Dict[str, str] = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this belongs here: with this change, the core utils module effectively depends on models.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed! Moved all file integrity utilities from etna.core.utils to etna.models.utils to maintain proper separation of concerns. Core utilities now remain general-purpose and domain-agnostic, while model-specific integrity checking stays with the models where it belongs. Commit: 073898c

# Add known model URL -> hash mappings here
# Example: "http://example.com/model.ckpt": "abcd1234...",
}


def verify_file_hash(file_path: str, expected_hash: Optional[str] = None) -> bool:
    """
    Verify file integrity using MD5 hash.

    Parameters
    ----------
    file_path:
        Path to the file to verify
    expected_hash:
        Expected MD5 hash as a hexadecimal string. Letter case and surrounding
        whitespace are ignored. If None, verification is skipped.

    Returns
    -------
    :
        True if hash matches or no expected hash provided, False otherwise
        (including when the file is missing or cannot be read)
    """
    if expected_hash is None:
        return True

    try:
        hash_md5 = hashlib.md5()
        # Open directly instead of checking existence first: a missing file,
        # a directory or a permission problem all surface as OSError.
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(65536), b""):
                hash_md5.update(chunk)
    except (OSError, ValueError):
        # OSError: the file cannot be read; ValueError: MD5 is unavailable
        # in this interpreter build. Either way the file is not verified.
        return False

    # hexdigest() is always lowercase, so normalise the expected value to avoid
    # reporting a valid file as corrupted because of an uppercase checksum.
    return hash_md5.hexdigest() == expected_hash.strip().lower()


def download_with_integrity_check(
    url: str,
    destination_path: str,
    expected_hash: Optional[str] = None,
    force_redownload: bool = False,
) -> None:
    """
    Download a file with integrity verification.

    An existing file at ``destination_path`` is reused when it passes verification
    and ``force_redownload`` is False. Otherwise the file is downloaded and verified;
    a download that fails verification is removed.

    Parameters
    ----------
    url:
        URL to download from
    destination_path:
        Local path to save the file. May be a bare file name, in which case
        the file is saved relative to the current working directory.
    expected_hash:
        Expected MD5 hash for verification. If None, no verification is performed.
    force_redownload:
        If True, download even if file exists and passes verification

    Raises
    ------
    RuntimeError:
        If download fails integrity check
    """
    if os.path.exists(destination_path) and not force_redownload:
        if verify_file_hash(destination_path, expected_hash):
            return  # File exists and is valid

        # Verification is skipped when there is no expected hash, so the
        # warning is only meaningful when one was supplied.
        if expected_hash is not None:
            warnings.warn(
                "Local file hash does not match expected hash. "
                f"This may indicate a corrupted download. Re-downloading from {url}"
            )
        os.remove(destination_path)

    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create the parent when there is one.
    parent_dir = os.path.dirname(destination_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    request.urlretrieve(url=url, filename=destination_path)

    # Verify the downloaded file and do not leave a corrupted copy behind.
    if not verify_file_hash(destination_path, expected_hash):
        if expected_hash is not None:
            os.remove(destination_path)
        raise RuntimeError(
            f"Downloaded file from {url} failed integrity check. "
            "This may indicate a network issue or corrupted download."
        )


def get_known_hash(url: str) -> Optional[str]:
    """
    Look up the registered hash for a URL.

    Parameters
    ----------
    url:
        URL whose hash is requested

    Returns
    -------
    :
        Hash stored in ``KNOWN_MODEL_HASHES`` for the URL, or None if the URL is not registered
    """
    known_hash = KNOWN_MODEL_HASHES.get(url, None)
    return known_hash


def create_type_with_init_collector(type_: type) -> type:
"""Create type with init decorated with init_collector."""
previous_frame = inspect.stack()[1]
Expand Down
33 changes: 32 additions & 1 deletion etna/experimental/classification/predictability.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import os
import warnings
from typing import Dict
from typing import List
from typing import Optional
from urllib import request

import numpy as np
from sklearn.base import ClassifierMixin

from etna.core.utils import get_known_hash, verify_file_hash
from etna.datasets import TSDataset
from etna.experimental.classification.classification import TimeSeriesBinaryClassifier
from etna.experimental.classification.feature_extraction.base import BaseTimeSeriesFeatureExtractor
Expand Down Expand Up @@ -98,7 +102,34 @@ def download_model(model_name: str, dataset_freq: str, path: str):
If the model does not exist in s3.
"""
url = f"http://etna-github-prod.cdn-tinkoff.ru/series_classification/22_11_2022/{dataset_freq}/{model_name}.pickle"
expected_hash = get_known_hash(url)

# Check if file exists and verify integrity
if os.path.exists(path):
if verify_file_hash(path, expected_hash):
return # File exists and is valid (or no hash to check)
else:
# File exists but hash doesn't match, re-download
if expected_hash is not None:
warnings.warn(
f"Local model file hash does not match expected hash. "
f"This may indicate a corrupted download. Re-downloading {model_name} from {url}"
)
os.remove(path)

# Download the file
try:
request.urlretrieve(url=url, filename=path)
except Exception:

# Verify the downloaded file
if not verify_file_hash(path, expected_hash):
if expected_hash is not None:
os.remove(path)
raise RuntimeError(
f"Downloaded model file {model_name} from {url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)
except Exception as e:
if expected_hash is not None and "integrity check" in str(e):
Comment on lines +128 to +133
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String matching on exception messages is fragile and error-prone. Consider creating a custom exception class for integrity check failures or checking the exception type instead.

Suggested change
raise RuntimeError(
f"Downloaded model file {model_name} from {url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)
except Exception as e:
if expected_hash is not None and "integrity check" in str(e):
raise IntegrityCheckError(
f"Downloaded model file {model_name} from {url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)
except Exception as e:
if expected_hash is not None and isinstance(e, IntegrityCheckError):

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

String matching on exception messages is fragile and error-prone. Consider catching specific exception types or using a custom exception class for integrity check failures instead of parsing the exception message.

Copilot uses AI. Check for mistakes.
raise # Re-raise integrity check errors
raise ValueError("Model not found! Check the list of available models!")
57 changes: 48 additions & 9 deletions etna/models/nn/chronos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from pathlib import Path
from typing import Dict
from typing import List
from typing import Optional
from typing import Sequence
from typing import Union
from urllib import request

import pandas as pd

from etna import SETTINGS
from etna.core.utils import get_known_hash, verify_file_hash
from etna.datasets import TSDataset
from etna.distributions import BaseDistribution
from etna.models.base import PredictionIntervalContextRequiredAbstractModel
Expand Down Expand Up @@ -84,18 +86,55 @@ def _is_url(self):
return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://")

def _download_model_from_url(self) -> str:
"""Download model from url to local cache_dir."""
"""Download model from url to local cache_dir with integrity verification."""
model_file = self.path_or_url.split("/")[-1]
model_dir = model_file.split(".zip")[0]
full_model_path = f"{self.cache_dir}/{model_dir}"
if not os.path.exists(full_model_path):
try:
request.urlretrieve(url=self.path_or_url, filename=model_file)

with zipfile.ZipFile(model_file, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
finally:
os.remove(model_file)
zip_file_path = f"{self.cache_dir}/{model_file}"
expected_hash = get_known_hash(self.path_or_url)

# Check if extracted model directory exists and verify ZIP file integrity if it still exists
if os.path.exists(full_model_path):
if os.path.exists(zip_file_path):
if verify_file_hash(zip_file_path, expected_hash):
return full_model_path
else:
# ZIP file exists but hash doesn't match, re-download
if expected_hash is not None:
warnings.warn(
f"Local model ZIP file hash does not match expected hash. "
f"This may indicate a corrupted download. Re-downloading from {self.path_or_url}"
)
# Remove both ZIP and extracted directory for clean re-download
os.remove(zip_file_path)
import shutil
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import statements should be placed at the top of the file, not inside functions. Move import shutil to the top of the file with other imports.

Suggested change
import shutil

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Jul 16, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Import statements should be placed at the top of the file rather than within function bodies. Move import shutil to the import section at the beginning of the file.

Suggested change
import shutil

Copilot uses AI. Check for mistakes.
shutil.rmtree(full_model_path)
else:
# Extracted directory exists but no ZIP file - assume it's valid
# (ZIP was cleaned up after successful extraction)
return full_model_path

# Download and extract the file
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
try:
request.urlretrieve(url=self.path_or_url, filename=zip_file_path)

# Verify the downloaded file
if not verify_file_hash(zip_file_path, expected_hash):
if expected_hash is not None:
os.remove(zip_file_path)
raise RuntimeError(
f"Downloaded model file from {self.path_or_url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)

with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(self.cache_dir)
finally:
# Clean up ZIP file after successful extraction
if os.path.exists(zip_file_path):
os.remove(zip_file_path)

return full_model_path

@property
Expand Down
34 changes: 30 additions & 4 deletions etna/models/nn/timesfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd

from etna import SETTINGS
from etna.core.utils import get_known_hash, verify_file_hash
from etna.datasets import TSDataset
from etna.distributions import BaseDistribution
from etna.models.base import NonPredictionIntervalContextRequiredAbstractModel
Expand Down Expand Up @@ -152,12 +153,37 @@ def _is_url(self):
return self.path_or_url.startswith("https://") or self.path_or_url.startswith("http://")

def _download_model_from_url(self) -> str:
"""Download model from url to local cache_dir."""
"""Download model from url to local cache_dir with integrity verification."""
model_file = self.path_or_url.split("/")[-1]
full_model_path = f"{self.cache_dir}/{model_file}"
if not os.path.exists(full_model_path):
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
request.urlretrieve(url=self.path_or_url, filename=full_model_path)
expected_hash = get_known_hash(self.path_or_url)

# Check if file exists and verify integrity
if os.path.exists(full_model_path):
if verify_file_hash(full_model_path, expected_hash):
return full_model_path
else:
# File exists but hash doesn't match, re-download
if expected_hash is not None:
warnings.warn(
f"Local model file hash does not match expected hash. "
f"This may indicate a corrupted download. Re-downloading from {self.path_or_url}"
)
os.remove(full_model_path)

# Download the file
Path(self.cache_dir).mkdir(parents=True, exist_ok=True)
request.urlretrieve(url=self.path_or_url, filename=full_model_path)

# Verify the downloaded file
if not verify_file_hash(full_model_path, expected_hash):
if expected_hash is not None:
os.remove(full_model_path)
raise RuntimeError(
f"Downloaded model file from {self.path_or_url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)

return full_model_path

@property
Expand Down
3 changes: 3 additions & 0 deletions etna/models/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import os
import warnings
from typing import Optional
from typing import Union
from urllib import request

import pandas as pd

Expand Down
34 changes: 29 additions & 5 deletions etna/transforms/embeddings/models/ts2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import numpy as np

from etna import SETTINGS
from etna.core.utils import get_known_hash, verify_file_hash
from etna.transforms.embeddings.models import BaseEmbeddingModel

if SETTINGS.torch_required:
Expand Down Expand Up @@ -307,16 +308,39 @@ def load(cls, path: Optional[pathlib.Path] = None, model_name: Optional[str] = N
if model_name is not None:
if path is None:
path = _DOWNLOAD_PATH / f"{model_name}.zip"

url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip"
expected_hash = get_known_hash(url)

# Check if file exists and verify integrity
if os.path.exists(path):
warnings.warn(
f"Path {path} already exists. Model {model_name} will not be downloaded. Loading existing local model."
)
else:
if verify_file_hash(str(path), expected_hash):
# File exists and is valid (or no hash to check)
pass
else:
# File exists but hash doesn't match, re-download
if expected_hash is not None:
warnings.warn(
f"Local model file hash does not match expected hash. "
f"This may indicate a corrupted download. Re-downloading {model_name} from {url}"
)
os.remove(path)

# Download if file doesn't exist (or was removed due to hash mismatch)
if not os.path.exists(path):
Path(path).parent.mkdir(exist_ok=True, parents=True)

if model_name in cls.list_models():
url = f"http://etna-github-prod.cdn-tinkoff.ru/embeddings/ts2vec/{model_name}.zip"
request.urlretrieve(url=url, filename=path)

# Verify the downloaded file
if not verify_file_hash(str(path), expected_hash):
if expected_hash is not None:
os.remove(path)
raise RuntimeError(
f"Downloaded model file {model_name} from {url} failed integrity check. "
f"This may indicate a network issue or corrupted download."
)
else:
raise NotImplementedError(
f"Model {model_name} is not available. To get list of available models use `list_models` method."
Expand Down
Loading
Loading