Skip to content

Commit

Permalink
refactor tests
Browse files Browse the repository at this point in the history
Signed-off-by: wei-chenglai <[email protected]>
  • Loading branch information
seanlaii committed Jan 16, 2025
1 parent 3f12b9a commit f9d9ef2
Show file tree
Hide file tree
Showing 10 changed files with 69 additions and 146 deletions.
Original file line number Diff line number Diff line change
@@ -1,29 +1,9 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.fixture
def mock_env_vars():
"""Fixture to set and clean up environment variables"""
original_env = dict(os.environ)

def _set_env_vars(**kwargs):
for key, value in kwargs.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
return os.environ

yield _set_env_vars

# Cleanup
os.environ.clear()
os.environ.update(original_env)
from pkg.initializer_v2.utils.utils_test import mock_env_vars # noqa: F401


@pytest.mark.parametrize(
Expand All @@ -34,7 +14,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "hf://dataset/path",
"access_token": "test_token",
"mock_config_error": False,
"mock_download_error": False,
"expected_error": None,
},
Expand All @@ -44,7 +23,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": None,
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
Expand All @@ -54,17 +32,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "invalid://dataset/path",
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Config loading failure",
{
"storage_uri": "hf://dataset/path",
"access_token": None,
"mock_config_error": True,
"mock_download_error": False,
"expected_error": Exception,
},
Expand All @@ -74,14 +41,13 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "hf://dataset/path/error",
"access_token": None,
"mock_config_error": False,
"mock_download_error": True,
"expected_error": Exception,
},
),
],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
def test_dataset_main(test_name, test_case, mock_env_vars): # noqa: F811
"""Test main script with different scenarios"""
print(f"Running test: {test_name}")

Expand All @@ -94,8 +60,6 @@ def test_dataset_main(test_name, test_case, mock_env_vars):

# Setup mock HuggingFace instance
mock_hf_instance = MagicMock()
if test_case["mock_config_error"]:
mock_hf_instance.load_config.side_effect = Exception
if test_case["mock_download_error"]:
mock_hf_instance.download_dataset.side_effect = Exception

Expand Down
Original file line number Diff line number Diff line change
@@ -1,29 +1,9 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.model.__main__ import main


@pytest.fixture
def mock_env_vars():
"""Fixture to set and clean up environment variables"""
original_env = dict(os.environ)

def _set_env_vars(**kwargs):
for key, value in kwargs.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
return os.environ

yield _set_env_vars

# Cleanup
os.environ.clear()
os.environ.update(original_env)
from pkg.initializer_v2.utils.utils_test import mock_env_vars # noqa: F401


@pytest.mark.parametrize(
Expand All @@ -34,7 +14,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "hf://model/path",
"access_token": "test_token",
"mock_config_error": False,
"mock_download_error": False,
"expected_error": None,
},
Expand All @@ -44,7 +23,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": None,
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
Expand All @@ -54,17 +32,6 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "invalid://model/path",
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Config loading failure",
{
"storage_uri": "hf://model/path",
"access_token": None,
"mock_config_error": True,
"mock_download_error": False,
"expected_error": Exception,
},
Expand All @@ -74,14 +41,13 @@ def _set_env_vars(**kwargs):
{
"storage_uri": "hf://model/path/error",
"access_token": None,
"mock_config_error": False,
"mock_download_error": True,
"expected_error": Exception,
},
),
],
)
def test_model_main(test_name, test_case, mock_env_vars):
def test_model_main(test_name, test_case, mock_env_vars): # noqa: F811
"""Test main script with different scenarios"""
print(f"Running test: {test_name}")

Expand All @@ -94,8 +60,6 @@ def test_model_main(test_name, test_case, mock_env_vars):

# Setup mock HuggingFace instance
mock_hf_instance = MagicMock()
if test_case["mock_config_error"]:
mock_hf_instance.load_config.side_effect = Exception
if test_case["mock_download_error"]:
mock_hf_instance.download_model.side_effect = Exception

Expand Down
Empty file added test/__init__.py
Empty file.
Empty file added test/integration/__init__.py
Empty file.
6 changes: 0 additions & 6 deletions test/integration/initializer_v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +0,0 @@
import os
import sys

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
)
37 changes: 5 additions & 32 deletions test/integration/initializer_v2/dataset_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import runpy
import shutil
import tempfile
from test.integration.initializer_v2.utils import setup_temp_path # noqa: F401
from test.integration.initializer_v2.utils import verify_downloaded_files

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils

Expand All @@ -13,34 +12,8 @@ class TestDatasetIntegration:
"""Integration tests for dataset initialization"""

@pytest.fixture(autouse=True)
def setup_teardown(self, monkeypatch):
"""Setup and teardown for each test"""
# Create temporary directory for dataset downloads
current_dir = os.path.dirname(os.path.abspath(__file__))
self.temp_dir = tempfile.mkdtemp(dir=current_dir)
os.environ[DATASET_PATH] = self.temp_dir

# Store original environment
self.original_env = dict(os.environ)

# Monkeypatch the constant in the module
import kubeflow.training as training

monkeypatch.setattr(training, "DATASET_PATH", self.temp_dir)

yield

# Cleanup
shutil.rmtree(self.temp_dir, ignore_errors=True)
os.environ.clear()
os.environ.update(self.original_env)

def verify_dataset_files(self, expected_files):
"""Verify downloaded dataset files"""
if expected_files:
actual_files = set(os.listdir(self.temp_dir))
missing_files = set(expected_files) - actual_files
assert not missing_files, f"Missing expected files: {missing_files}"
def setup_teardown(self, setup_temp_path): # noqa: F811
self.temp_dir = setup_temp_path("DATASET_PATH")

@pytest.mark.parametrize(
"test_name, provider, test_case",
Expand Down Expand Up @@ -97,6 +70,6 @@ def test_dataset_download(self, test_name, provider, test_case):
)
else:
runpy.run_module("pkg.initializer_v2.dataset.__main__", run_name="__main__")
self.verify_dataset_files(expected_files)
verify_downloaded_files(self.temp_dir, expected_files)

print("Test execution completed")
37 changes: 5 additions & 32 deletions test/integration/initializer_v2/model_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import os
import runpy
import shutil
import tempfile
from test.integration.initializer_v2.utils import setup_temp_path # noqa: F401
from test.integration.initializer_v2.utils import verify_downloaded_files

import pytest
from kubeflow.training import MODEL_PATH

import pkg.initializer_v2.utils.utils as utils

Expand All @@ -13,34 +12,8 @@ class TestModelIntegration:
"""Integration tests for model initialization"""

@pytest.fixture(autouse=True)
def setup_teardown(self, monkeypatch):
"""Setup and teardown for each test"""
# Create temporary directory for model downloads
current_dir = os.path.dirname(os.path.abspath(__file__))
self.temp_dir = tempfile.mkdtemp(dir=current_dir)
os.environ[MODEL_PATH] = self.temp_dir

# Store original environment
self.original_env = dict(os.environ)

# Monkeypatch the constant in the module
import kubeflow.training as training

monkeypatch.setattr(training, "MODEL_PATH", self.temp_dir)

yield

# Cleanup
shutil.rmtree(self.temp_dir, ignore_errors=True)
os.environ.clear()
os.environ.update(self.original_env)

def verify_model_files(self, expected_files):
"""Verify downloaded model files"""
if expected_files:
actual_files = set(os.listdir(self.temp_dir))
missing_files = set(expected_files) - actual_files
assert not missing_files, f"Missing expected files: {missing_files}"
def setup_teardown(self, setup_temp_path): # noqa: F811
self.temp_dir = setup_temp_path("MODEL_PATH")

@pytest.mark.parametrize(
"test_name, provider, test_case",
Expand Down Expand Up @@ -103,6 +76,6 @@ def test_model_download(self, test_name, provider, test_case):
)
else:
runpy.run_module("pkg.initializer_v2.model.__main__", run_name="__main__")
self.verify_model_files(expected_files)
verify_downloaded_files(self.temp_dir, expected_files)

print("Test execution completed")
55 changes: 55 additions & 0 deletions test/integration/initializer_v2/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import shutil
import tempfile

import pytest


@pytest.fixture
def setup_temp_path(monkeypatch):
    """Provide a scratch directory plus a hook that patches a path constant.

    Setup snapshots ``os.environ`` and creates a fresh temporary directory
    next to this file. The fixture then yields a callable; calling it with
    the name of an attribute of ``kubeflow.training`` (for example
    ``"MODEL_PATH"``) sets that attribute to the temporary directory through
    ``monkeypatch`` and returns the directory path.

    Teardown removes the temporary directory (errors are ignored) and resets
    ``os.environ`` to the snapshot taken during setup.

    Args:
        monkeypatch: pytest's built-in fixture, used to set the attribute.

    Yields:
        function: A configurator taking ``path_var`` (str) and returning the
        temporary directory path.

    Usage:
        def test_something(setup_temp_path):
            temp_dir = setup_temp_path("MODEL_PATH")
            # kubeflow.training.MODEL_PATH now points at temp_dir
    """
    # Snapshot the environment before the test gets a chance to modify it.
    env_snapshot = os.environ.copy()

    # The scratch directory lives beside this module, not in the system tmp.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    scratch_dir = tempfile.mkdtemp(dir=module_dir)

    def _point_constant_at_scratch(path_var: str):
        """Set ``kubeflow.training.<path_var>`` to the scratch directory."""
        # Imported on first use, so kubeflow is only needed when a test
        # actually asks for a constant to be patched.
        import kubeflow.training as training

        monkeypatch.setattr(training, path_var, scratch_dir)
        return scratch_dir

    yield _point_constant_at_scratch

    # Teardown: drop the scratch directory, then put the environment back.
    shutil.rmtree(scratch_dir, ignore_errors=True)
    os.environ.clear()
    os.environ.update(env_snapshot)


def verify_downloaded_files(dir_path, expected_files):
    """Assert that every expected file name is present in ``dir_path``.

    Only the direct entries of ``dir_path`` (as returned by ``os.listdir``)
    are considered; extra entries are ignored. A falsy ``expected_files``
    (``None`` or empty) skips the check, and the directory is not read.

    Args:
        dir_path: Directory to inspect.
        expected_files: Iterable of entry names expected in ``dir_path``.

    Raises:
        AssertionError: If at least one expected name is absent. The message
            lists the missing names in sorted order.
    """
    if not expected_files:
        return

    present = set(os.listdir(dir_path))
    # Sorted so the failure message is stable across runs: a raw set of
    # strings is rendered in hash order, which can differ between processes.
    missing_files = sorted(set(expected_files) - present)
    assert not missing_files, f"Missing expected files: {missing_files}"

0 comments on commit f9d9ef2

Please sign in to comment.