diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bcae1a52..f68c8bc9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -85,12 +85,19 @@ Before submitting a PR, make sure the change to pass all tests and test coverage $ pytest -vs tests/ --cov causalml/ ``` +To run tests that require tensorflow (i.e. DragonNet), make sure tensorflow is installed and include the `--runtf` option with the `pytest` command. For example: + +```bash +$ pytest --runtf -vs tests/test_dragonnet.py +``` + You can also run tests via make: ```bash $ make test ``` + ## Submission :tada: In your PR, please include: diff --git a/causalml/inference/tf/dragonnet.py b/causalml/inference/tf/dragonnet.py index cf86caaa..f8cd3ed3 100644 --- a/causalml/inference/tf/dragonnet.py +++ b/causalml/inference/tf/dragonnet.py @@ -21,6 +21,7 @@ from tensorflow.keras.layers import Dense, Concatenate from tensorflow.keras.optimizers import SGD, Adam from tensorflow.keras.regularizers import l2 +from tensorflow.keras.models import load_model from causalml.inference.tf.utils import ( dragonnet_loss_binarycross, @@ -290,3 +291,36 @@ def fit_predict(self, X, treatment, y, p=None, return_components=False): """ self.fit(X, treatment, y) return self.predict_tau(X) + + def save(self, h5_filepath): + """ + Save the dragonnet model as a H5 file. + + Args: + h5_filepath (H5 file path): H5 file path + """ + self.dragonnet.save(h5_filepath) + + def load(self, h5_filepath, ratio=1.0, dragonnet_loss=dragonnet_loss_binarycross): + """ + Load the dragonnet model from a H5 file. + + Args: + h5_filepath (H5 file path): H5 file path + ratio (float): weight assigned to the targeted regularization loss component + dragonnet_loss (function): a loss function + """ + self.dragonnet = load_model( + h5_filepath, + custom_objects={ + "EpsilonLayer": EpsilonLayer, + "dragonnet_loss_binarycross": dragonnet_loss_binarycross, + "tarreg_ATE_unbounded_domain_loss": make_tarreg_loss( + ratio=ratio, dragonnet_loss=dragonnet_loss + ), + "regression_loss": regression_loss, + "binary_classification_loss": binary_classification_loss, + "treatment_accuracy": treatment_accuracy, + "track_epsilon": track_epsilon, + }, + ) diff --git a/docs/installation.rst b/docs/installation.rst index f15e459b..32be82c4 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -96,6 +96,7 @@ with ``tensorflow``: pip install .[tf] +======= Windows ------- @@ -106,11 +107,19 @@ See content in https://github.com/uber/causalml/issues/678 Running Tests ------------- +Make sure pytest is installed before attempting to run tests. + Run all tests with: .. code-block:: bash pytest -vs tests/ --cov causalml/ - Add ``--runtf`` to run optional tensorflow tests which will be skipped by default. + +You can also run tests via make: + +.. code-block:: bash + + make test + diff --git a/tests/conftest.py b/tests/conftest.py index d9532fac..02988b2f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,3 +66,21 @@ def _generate_data(): return data yield _generate_data + + +def pytest_addoption(parser): + parser.addoption("--runtf", action="store_true", default=False, help="run tf tests") + + +def pytest_configure(config): + config.addinivalue_line("markers", "tf: mark test as tf to run") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--runtf"): + # --runtf given in cli: do not skip tf tests + return + skip_tf = pytest.mark.skip(reason="need --runtf option to run") + for item in items: + if "tf" in item.keywords: + item.add_marker(skip_tf) diff --git a/tests/test_dragonnet.py b/tests/test_dragonnet.py new file mode 100644 index 00000000..0788d475 --- /dev/null +++ b/tests/test_dragonnet.py @@ -0,0 +1,23 @@ +try: + from causalml.inference.tf import DragonNet +except ImportError: + pass +from causalml.dataset.regression import simulate_nuisance_and_easy_treatment +import shutil +import pytest + + +@pytest.mark.tf +def test_save_load_dragonnet(): + y, X, w, tau, b, e = simulate_nuisance_and_easy_treatment(n=1000) + + dragon = DragonNet(neurons_per_layer=200, targeted_reg=True, verbose=False) + dragon_ite = dragon.fit_predict(X, w, y, return_components=False) + dragon_ate = dragon_ite.mean() + dragon.save("smaug") + + smaug = DragonNet() + smaug.load("smaug") + shutil.rmtree("smaug") + + assert smaug.predict_tau(X).mean() == dragon_ate