Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

KEP-2170: Add unit and Integration tests for model and dataset initializers #2323

Merged
merged 2 commits into from
Jan 18, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,13 @@ jobs:
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

- name: Run initializer_v2 integration tests for Python 3.11+
if: ${{ matrix.python-version == '3.11' }}
run: |
pip install -r ./cmd/initializer_v2/dataset/requirements.txt
pip install -U './sdk_v2'
pytest ./test/integration/initializer_v2

- name: Collect volcano logs
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
run: |
Expand Down
11 changes: 10 additions & 1 deletion .github/workflows/test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,13 @@ jobs:
pip install -U './sdk/python[huggingface]'
- name: Run unit test for training sdk
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
run: |
pytest ./sdk/python/kubeflow/training/api/training_client_test.py
- name: Run Python unit tests for v2
run: |
pip install -U './sdk_v2'
export PYTHONPATH="${{ github.workspace }}:$PYTHONPATH"
seanlaii marked this conversation as resolved.
Show resolved Hide resolved
pytest ./pkg/initializer_v2/model
pytest ./pkg/initializer_v2/dataset
pytest ./pkg/initializer_v2/utils
2 changes: 1 addition & 1 deletion cmd/initializer_v2/model/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
huggingface_hub==0.23.4
huggingface-hub>=0.27.0,<0.28
23 changes: 23 additions & 0 deletions pkg/initializer_v2/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os

import pytest


@pytest.fixture
def mock_env_vars():
    """Yield a setter that overrides environment variables for one test.

    Passing ``key=None`` removes the variable; any other value is stringified
    and stored. The entire pre-test environment is restored on teardown.
    """
    saved_environment = os.environ.copy()

    def _apply(**overrides):
        for name, value in overrides.items():
            if value is None:
                # None means "ensure the variable is absent".
                os.environ.pop(name, None)
            else:
                os.environ[name] = str(value)
        return os.environ

    yield _apply

    # Teardown: restore the environment exactly as it was before the test.
    os.environ.clear()
    os.environ.update(saved_environment)
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/dataset/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting dataset initialization")

try:
Expand All @@ -29,3 +30,7 @@
case _:
logging.error("STORAGE_URI must have the valid dataset provider")
raise Exception


if __name__ == "__main__":
main()
95 changes: 95 additions & 0 deletions pkg/initializer_v2/dataset/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
    "test_name, test_config, expected",
    [
        (
            "Full config with token",
            {"storage_uri": "hf://dataset/path", "access_token": "test_token"},
            {"storage_uri": "hf://dataset/path", "access_token": "test_token"},
        ),
        (
            "Minimal config without token",
            {"storage_uri": "hf://dataset/path"},
            {"storage_uri": "hf://dataset/path", "access_token": None},
        ),
    ],
)
def test_load_config(test_name, test_config, expected):
    """Verify HuggingFace.load_config picks up the env-provided config."""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()

    # Feed the scenario's config dict in place of the real env lookup.
    with patch.object(utils, "get_config_from_env", return_value=test_config):
        initializer.load_config()
        assert initializer.config.storage_uri == expected["storage_uri"]
        assert initializer.config.access_token == expected["access_token"]

    print("Test execution completed")


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with token",
            {
                "config": {
                    "storage_uri": "hf://username/dataset-name",
                    "access_token": "test_token",
                },
                "should_login": True,
                "expected_repo_id": "username/dataset-name",
            },
        ),
        (
            "Successful download without token",
            {
                "config": {"storage_uri": "hf://org/dataset-v1", "access_token": None},
                "should_login": False,
                "expected_repo_id": "org/dataset-v1",
            },
        ),
    ],
)
def test_download_dataset(test_name, test_case):
    """Verify download_dataset logs in only when a token is set and calls snapshot_download correctly."""

    print(f"Running test: {test_name}")

    initializer = HuggingFace()
    initializer.config = MagicMock(**test_case["config"])

    with patch("huggingface_hub.login") as login_mock:
        with patch("huggingface_hub.snapshot_download") as download_mock:
            initializer.download_dataset()

            # A login must happen exactly when a token was configured.
            if not test_case["should_login"]:
                login_mock.assert_not_called()
            else:
                login_mock.assert_called_once_with(
                    test_case["config"]["access_token"]
                )

            # The repo id is derived from the hf:// URI; target dir is fixed.
            download_mock.assert_called_once_with(
                repo_id=test_case["expected_repo_id"],
                local_dir=DATASET_PATH,
                repo_type="dataset",
            )
    print("Test execution completed")
71 changes: 71 additions & 0 deletions pkg/initializer_v2/dataset/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with HuggingFace provider",
            {
                "storage_uri": "hf://dataset/path",
                "access_token": "test_token",
                "expected_error": None,
            },
        ),
        (
            "Missing storage URI environment variable",
            {
                "storage_uri": None,
                "access_token": None,
                "expected_error": Exception,
            },
        ),
        (
            "Invalid storage URI scheme",
            {
                "storage_uri": "invalid://dataset/path",
                "access_token": None,
                "expected_error": Exception,
            },
        ),
    ],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
    """Exercise the dataset initializer entry point across env-var scenarios."""
    print(f"Running test: {test_name}")

    # Stage the environment for this scenario (None unsets the variable).
    mock_env_vars(
        STORAGE_URI=test_case["storage_uri"],
        ACCESS_TOKEN=test_case["access_token"],
    )

    hf_instance = MagicMock()

    with patch(
        "pkg.initializer_v2.dataset.__main__.HuggingFace",
        return_value=hf_instance,
    ) as hf_class:

        expected_error = test_case["expected_error"]
        if expected_error is None:
            main()

            # Happy path: both initializer steps run exactly once.
            hf_instance.load_config.assert_called_once()
            hf_instance.download_dataset.assert_called_once()
        else:
            with pytest.raises(expected_error):
                main()

        # Only an hf:// URI should cause the provider to be instantiated.
        storage_uri = test_case["storage_uri"]
        if storage_uri and storage_uri.startswith("hf://"):
            hf_class.assert_called_once()

    print("Test execution completed")
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/model/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting pre-trained model initialization")

try:
Expand All @@ -31,3 +32,7 @@
f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
)
raise Exception


if __name__ == "__main__":
main()
93 changes: 93 additions & 0 deletions pkg/initializer_v2/model/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import MODEL_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
    "test_name, test_config, expected",
    [
        (
            "Full config with token",
            {"storage_uri": "hf://model/path", "access_token": "test_token"},
            {"storage_uri": "hf://model/path", "access_token": "test_token"},
        ),
        (
            "Minimal config without token",
            {"storage_uri": "hf://model/path"},
            {"storage_uri": "hf://model/path", "access_token": None},
        ),
    ],
)
def test_load_config(test_name, test_config, expected):
    """Verify HuggingFace.load_config picks up the env-provided config."""
    print(f"Running test: {test_name}")

    initializer = HuggingFace()

    # Feed the scenario's config dict in place of the real env lookup.
    with patch.object(utils, "get_config_from_env", return_value=test_config):
        initializer.load_config()
        assert initializer.config.storage_uri == expected["storage_uri"]
        assert initializer.config.access_token == expected["access_token"]

    print("Test execution completed")


@pytest.mark.parametrize(
    "test_name, test_case",
    [
        (
            "Successful download with token",
            {
                "config": {
                    "storage_uri": "hf://username/model-name",
                    "access_token": "test_token",
                },
                "should_login": True,
                "expected_repo_id": "username/model-name",
            },
        ),
        (
            "Successful download without token",
            {
                "config": {"storage_uri": "hf://org/model-v1", "access_token": None},
                "should_login": False,
                "expected_repo_id": "org/model-v1",
            },
        ),
    ],
)
def test_download_model(test_name, test_case):
    """Verify download_model logs in only when a token is set and calls snapshot_download correctly."""

    print(f"Running test: {test_name}")

    initializer = HuggingFace()
    initializer.config = MagicMock(**test_case["config"])

    with patch("huggingface_hub.login") as login_mock:
        with patch("huggingface_hub.snapshot_download") as download_mock:
            initializer.download_model()

            # A login must happen exactly when a token was configured.
            if not test_case["should_login"]:
                login_mock.assert_not_called()
            else:
                login_mock.assert_called_once_with(
                    test_case["config"]["access_token"]
                )

            # Only weight/config file patterns are fetched into MODEL_PATH.
            download_mock.assert_called_once_with(
                repo_id=test_case["expected_repo_id"],
                local_dir=MODEL_PATH,
                allow_patterns=["*.json", "*.safetensors", "*.model"],
                ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
            )
    print("Test execution completed")
Loading
Loading