From 9fd905a5881b37043bad9a6af61c26524eb1a0bf Mon Sep 17 00:00:00 2001 From: Erik Date: Fri, 9 Feb 2024 22:18:20 -0800 Subject: [PATCH 1/2] Remove maq dependency --- causalml/metrics/__init__.py | 1 - pyproject.toml | 1 - tests/test_metrics.py | 17 ----------------- 3 files changed, 19 deletions(-) diff --git a/causalml/metrics/__init__.py b/causalml/metrics/__init__.py index 2986680b..0bc9d187 100644 --- a/causalml/metrics/__init__.py +++ b/causalml/metrics/__init__.py @@ -32,4 +32,3 @@ SensitivitySubsetData, SensitivitySelectionBias, ) # noqa -from maq import MAQ, get_ipw_scores # noqa diff --git a/pyproject.toml b/pyproject.toml index 48e0bc69..de517a31 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,7 +45,6 @@ dependencies = [ "torch", "pyro-ppl", "graphviz", - "maq@git+https://github.com/grf-labs/maq.git@py0.2.2#egg=maq&subdirectory=python-package", ] [project.optional-dependencies] diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 8eb5c3d6..7448c491 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,8 +1,6 @@ import pandas as pd -import numpy as np from numpy import isclose from causalml.metrics.visualize import qini_score -from causalml.metrics import MAQ, get_ipw_scores def test_qini_score(): @@ -28,18 +26,3 @@ def test_qini_score(): # for each learner, its qini score should stay same no matter calling with another model or calling separately assert isclose(full_result["learner_1"], learner_1_result["learner_1"]) assert isclose(full_result["learner_2"], learner_2_result["learner_2"]) - - -def test_MAQ(): - np.random.seed(42) - n = 1000 - K = 5 - tau_hat = np.random.randn(n, K) - cost = np.random.rand(n, K) - DR_scores = np.random.randn(n, K) - - mq = MAQ(n_bootstrap=200) - mq.fit(tau_hat, cost, DR_scores) - - # (0.005729002695991717, 0.019814651108894354) - assert isclose(mq.average_gain(spend=0.1)[0], 0.005729) From 892bcbac7796c272cdcd59e450876b05ec4827c8 Mon Sep 17 00:00:00 2001 From: Erik Date: Fri, 9 Feb 2024 22:45:15 -0800 Subject: [PATCH 2/2] Update MAQ notebook --- ...ini_curves_for_costly_treatment_arms.ipynb | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/docs/examples/qini_curves_for_costly_treatment_arms.ipynb b/docs/examples/qini_curves_for_costly_treatment_arms.ipynb index 89e348c2..3ccb4eeb 100644 --- a/docs/examples/qini_curves_for_costly_treatment_arms.ipynb +++ b/docs/examples/qini_curves_for_costly_treatment_arms.ipynb @@ -7,7 +7,25 @@ "source": [ "# Qini curves with multiple costly treatment arms\n", "\n", - "This notebook gives a brief overview of Qini curves for multi-armed treatment rules and a simple simulated example." + "This notebook shows approaches to evaluating multi-armed CATE estimators from `causalML` with the Multi-Armed Qini metric available in the `maq` package (available at https://github.com/grf-labs/maq).\n", + "\n", + "\n", + "This metric is a generalization of the familiar *Qini curve* to settings where we have multiple treatment arms available, and the cost of assigning treatment can vary by both unit and treatment arm according to some known cost structure. At a high level, this metric essentially allows you to quantify the value of targeting with more treatment arms by undertaking a cost-benefit exercise that uses your CATE estimates to assign the arm to the unit that is most cost-beneficial at various budget constraints.\n", + "\n", + "This notebook gives a brief overview of the statistical setup and a walkthrough with a simple simulated example. \n", + "\n", + "\n", + "To use this functionality, you first have to install the `maq` Python package from GitHub. The latest source release can be installed with:" + ] + }, + { + "cell_type": "markdown", + "id": "0a633fa7", + "metadata": {}, + "source": [ + "```\n", + "pip install \"git+https://github.com/grf-labs/maq.git#egg=maq&subdirectory=python-package\"\n", + "```" ] }, { @@ -22,7 +40,7 @@ "from xgboost import XGBRFRegressor\n", "\n", "# Generalized Qini curves\n", - "from causalml.metrics import MAQ, get_ipw_scores\n", + "from maq import MAQ, get_ipw_scores\n", "\n", "import numpy as np\n", "np.random.seed(42)"