Commit

KEP-2170: Add unit and integration tests for model and dataset initializers (#2323)

* KEP-2170: Add unit and integration tests for model and dataset initializers

Signed-off-by: wei-chenglai <[email protected]>

* refactor tests

Signed-off-by: wei-chenglai <[email protected]>

---------

Signed-off-by: wei-chenglai <[email protected]>
seanlaii authored Jan 18, 2025
1 parent 6d58ea9 commit e47d8f7
Showing 20 changed files with 628 additions and 4 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yaml
@@ -86,6 +86,13 @@ jobs:
          GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}
          JAX_JOB_IMAGE: kubeflow/jaxjob-dist-spmd-mnist:test

+     - name: Run initializer_v2 integration tests for Python 3.11+
+       if: ${{ matrix.python-version == '3.11' }}
+       run: |
+         pip install -r ./cmd/initializer_v2/dataset/requirements.txt
+         pip install -U './sdk_v2'
+         pytest ./test/integration/initializer_v2
      - name: Collect volcano logs
        if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
        run: |
11 changes: 10 additions & 1 deletion .github/workflows/test-python.yaml
@@ -32,4 +32,13 @@ jobs:
          pip install -U './sdk/python[huggingface]'
      - name: Run unit test for training sdk
-       run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
+       run: |
+         pytest ./sdk/python/kubeflow/training/api/training_client_test.py
+     - name: Run Python unit tests for v2
+       run: |
+         pip install -U './sdk_v2'
+         export PYTHONPATH="${{ github.workspace }}:$PYTHONPATH"
+         pytest ./pkg/initializer_v2/model
+         pytest ./pkg/initializer_v2/dataset
+         pytest ./pkg/initializer_v2/utils
2 changes: 1 addition & 1 deletion cmd/initializer_v2/model/requirements.txt
@@ -1 +1 @@
-huggingface_hub==0.23.4
+huggingface-hub>=0.27.0,<0.28
23 changes: 23 additions & 0 deletions pkg/initializer_v2/conftest.py
@@ -0,0 +1,23 @@
import os

import pytest


@pytest.fixture
def mock_env_vars():
"""Fixture to set and clean up environment variables"""
original_env = dict(os.environ)

def _set_env_vars(**kwargs):
for key, value in kwargs.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
return os.environ

yield _set_env_vars

# Cleanup
os.environ.clear()
os.environ.update(original_env)
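For reference, a minimal sketch (not part of the commit) of how a test can consume this fixture; the test name and asserted values are illustrative only:

import os


def test_mock_env_vars_round_trip(mock_env_vars):
    # Set one variable and unset another through the fixture's helper.
    env = mock_env_vars(STORAGE_URI="hf://org/repo", ACCESS_TOKEN=None)

    assert env["STORAGE_URI"] == "hf://org/repo"
    assert "ACCESS_TOKEN" not in os.environ
    # Once the test returns, the fixture restores the original environment.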
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/dataset/__main__.py
@@ -11,7 +11,8 @@
    level=logging.INFO,
)

-if __name__ == "__main__":
+
+def main():
    logging.info("Starting dataset initialization")

    try:
@@ -29,3 +30,7 @@
        case _:
            logging.error("STORAGE_URI must have the valid dataset provider")
            raise Exception
+
+
+if __name__ == "__main__":
+    main()
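The hunks above show only the edges of the refactor. As a rough sketch of how main() presumably dispatches on the storage URI (the environment-variable lookup, the hf:// scheme check, and the error message for a missing variable are assumptions inferred from the tests below, not lines shown in this diff):

import logging
import os

from pkg.initializer_v2.dataset.huggingface import HuggingFace


def main():
    logging.info("Starting dataset initialization")

    # Assumed: STORAGE_URI is required and read from the environment.
    try:
        storage_uri = os.environ["STORAGE_URI"]
    except Exception as e:
        raise Exception("STORAGE_URI env variable must be set.") from e

    match storage_uri:
        # Hypothetical scheme check; only the fallback branch appears in the hunk above.
        case _ if storage_uri.startswith("hf://"):
            hf = HuggingFace()
            hf.load_config()
            hf.download_dataset()
        case _:
            logging.error("STORAGE_URI must have the valid dataset provider")
            raise Exception


if __name__ == "__main__":
    main()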
95 changes: 95 additions & 0 deletions pkg/initializer_v2/dataset/huggingface_test.py
@@ -0,0 +1,95 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.dataset.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
"test_name, test_config, expected",
[
(
"Full config with token",
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
),
(
"Minimal config without token",
{"storage_uri": "hf://dataset/path"},
{"storage_uri": "hf://dataset/path", "access_token": None},
),
],
)
def test_load_config(test_name, test_config, expected):
"""Test config loading with different configurations"""
print(f"Running test: {test_name}")

huggingface_dataset_instance = HuggingFace()

with patch.object(utils, "get_config_from_env", return_value=test_config):
huggingface_dataset_instance.load_config()
assert (
huggingface_dataset_instance.config.storage_uri == expected["storage_uri"]
)
assert (
huggingface_dataset_instance.config.access_token == expected["access_token"]
)

print("Test execution completed")


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with token",
{
"config": {
"storage_uri": "hf://username/dataset-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/dataset-name",
},
),
(
"Successful download without token",
{
"config": {"storage_uri": "hf://org/dataset-v1", "access_token": None},
"should_login": False,
"expected_repo_id": "org/dataset-v1",
},
),
],
)
def test_download_dataset(test_name, test_case):
"""Test dataset download with different configurations"""

print(f"Running test: {test_name}")

huggingface_dataset_instance = HuggingFace()
huggingface_dataset_instance.config = MagicMock(**test_case["config"])

with patch("huggingface_hub.login") as mock_login, patch(
"huggingface_hub.snapshot_download"
) as mock_download:

# Execute download
huggingface_dataset_instance.download_dataset()

# Verify login behavior
if test_case["should_login"]:
mock_login.assert_called_once_with(test_case["config"]["access_token"])
else:
mock_login.assert_not_called()

# Verify download parameters
mock_download.assert_called_once_with(
repo_id=test_case["expected_repo_id"],
local_dir=DATASET_PATH,
repo_type="dataset",
)
print("Test execution completed")
71 changes: 71 additions & 0 deletions pkg/initializer_v2/dataset/main_test.py
@@ -0,0 +1,71 @@
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with HuggingFace provider",
{
"storage_uri": "hf://dataset/path",
"access_token": "test_token",
"expected_error": None,
},
),
(
"Missing storage URI environment variable",
{
"storage_uri": None,
"access_token": None,
"expected_error": Exception,
},
),
(
"Invalid storage URI scheme",
{
"storage_uri": "invalid://dataset/path",
"access_token": None,
"expected_error": Exception,
},
),
],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
"""Test main script with different scenarios"""
print(f"Running test: {test_name}")

# Setup mock environment variables
env_vars = {
"STORAGE_URI": test_case["storage_uri"],
"ACCESS_TOKEN": test_case["access_token"],
}
mock_env_vars(**env_vars)

# Setup mock HuggingFace instance
mock_hf_instance = MagicMock()

with patch(
"pkg.initializer_v2.dataset.__main__.HuggingFace",
return_value=mock_hf_instance,
) as mock_hf:

# Execute test
if test_case["expected_error"]:
with pytest.raises(test_case["expected_error"]):
main()
else:
main()

# Verify HuggingFace instance methods were called
mock_hf_instance.load_config.assert_called_once()
mock_hf_instance.download_dataset.assert_called_once()

# Verify HuggingFace class instantiation
if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"):
mock_hf.assert_called_once()

print("Test execution completed")
Empty file.
7 changes: 6 additions & 1 deletion pkg/initializer_v2/model/__main__.py
@@ -11,7 +11,8 @@
    level=logging.INFO,
)

-if __name__ == "__main__":
+
+def main():
    logging.info("Starting pre-trained model initialization")

    try:
@@ -31,3 +32,7 @@
                f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
            )
            raise Exception
+
+
+if __name__ == "__main__":
+    main()
93 changes: 93 additions & 0 deletions pkg/initializer_v2/model/huggingface_test.py
@@ -0,0 +1,93 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import MODEL_PATH

import pkg.initializer_v2.utils.utils as utils
from pkg.initializer_v2.model.huggingface import HuggingFace


# Test cases for config loading
@pytest.mark.parametrize(
"test_name, test_config, expected",
[
(
"Full config with token",
{"storage_uri": "hf://model/path", "access_token": "test_token"},
{"storage_uri": "hf://model/path", "access_token": "test_token"},
),
(
"Minimal config without token",
{"storage_uri": "hf://model/path"},
{"storage_uri": "hf://model/path", "access_token": None},
),
],
)
def test_load_config(test_name, test_config, expected):
"""Test config loading with different configurations"""
print(f"Running test: {test_name}")

huggingface_model_instance = HuggingFace()
with patch.object(utils, "get_config_from_env", return_value=test_config):
huggingface_model_instance.load_config()
assert huggingface_model_instance.config.storage_uri == expected["storage_uri"]
assert (
huggingface_model_instance.config.access_token == expected["access_token"]
)

print("Test execution completed")


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with token",
{
"config": {
"storage_uri": "hf://username/model-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/model-name",
},
),
(
"Successful download without token",
{
"config": {"storage_uri": "hf://org/model-v1", "access_token": None},
"should_login": False,
"expected_repo_id": "org/model-v1",
},
),
],
)
def test_download_model(test_name, test_case):
"""Test model download with different configurations"""

print(f"Running test: {test_name}")

huggingface_model_instance = HuggingFace()
huggingface_model_instance.config = MagicMock(**test_case["config"])

with patch("huggingface_hub.login") as mock_login, patch(
"huggingface_hub.snapshot_download"
) as mock_download:

# Execute download
huggingface_model_instance.download_model()

# Verify login behavior
if test_case["should_login"]:
mock_login.assert_called_once_with(test_case["config"]["access_token"])
else:
mock_login.assert_not_called()

# Verify download parameters
mock_download.assert_called_once_with(
repo_id=test_case["expected_repo_id"],
local_dir=MODEL_PATH,
allow_patterns=["*.json", "*.safetensors", "*.model"],
ignore_patterns=["*.msgpack", "*.h5", "*.bin", ".pt", ".pth"],
)
print("Test execution completed")
(The remaining changed files in this commit are not shown here.)
