Skip to content

Commit 3164258

Browse files
authored
Merge pull request #37 from salesforce/revise_docs
Revise docs
2 parents 4551fe3 + d9f924d commit 3164258

12 files changed

+833
-58
lines changed

README.md

Lines changed: 79 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -47,31 +47,32 @@ OmniXAI includes a rich family of explanation methods integrated in a unified in
4747
supports multiple data types (tabular data, images, texts, time-series), multiple types of ML models
4848
(traditional ML in Scikit-learn and deep learning models in PyTorch/TensorFlow), and a range of diverse explaination
4949
methods including "model-specific" and "model-agnostic" methods (such as feature-attribution explanation,
50-
counterfactual explanation, gradient-based explanation, etc). For practitioners, OmniXAI provides an easy-to-use
50+
counterfactual explanation, gradient-based explanation, feature visualization, etc). For practitioners, OmniXAI provides an easy-to-use
5151
unified interface to generate the explanations for their applications by only writing a few lines of
5252
codes, and also a GUI dashboard for visualization for obtaining more insights about decisions.
5353

5454
The following table shows the supported explanation methods and features in our library.
55-
We will continue improving this library to make it more comprehensive in the future, e.g., supporting more
56-
explanation methods for vision, NLP and time-series tasks.
57-
58-
| Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
59-
:-----------------------:| :---: | :---: |:---:| :---: | :---: | :---: | :---:
60-
| Feature analysis | NA | Global || | | | |
61-
| Feature selection | NA | Global || | | | |
62-
| Prediction metrics | Black box | Global | |||||
63-
| Partial dependence plots | Black box | Global | || | | |
64-
| Accumulated local effects | Black box | Global | || | | |
65-
| Sensitivity analysis | Black box | Global | || | | |
66-
| LIME | Black box | Local | |||| |
67-
| SHAP | Black box* | Local | |||||
68-
| Integrated gradient | Torch or TF | Local | |||| |
69-
| Counterfactual | Black box* | Local | |||||
70-
| Contrastive explanation | Torch or TF | Local | | || | |
71-
| Grad-CAM, Grad-CAM++ | Torch or TF | Local | | || | |
72-
| Learning to explain | Black box | Local | |||| |
73-
| Linear models | Linear models | Global and Local | || | | |
74-
| Tree models | Tree models | Global and Local | || | | |
55+
We will continue improving this library to make it more comprehensive in the future.
56+
57+
| Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
58+
:-------------------------:|:-------------:|:----------------:|:---:|:-------:|:-----:| :---: | :---:
59+
| Feature analysis | NA | Global || | | | |
60+
| Feature selection | NA | Global || | | | |
61+
| Prediction metrics | Black box | Global | |||||
62+
| Partial dependence plots | Black box | Global | || | | |
63+
| Accumulated local effects | Black box | Global | || | | |
64+
| Sensitivity analysis | Black box | Global | || | | |
65+
| Feature visualization | Torch or TF | Global | | || | |
66+
| LIME | Black box | Local | |||| |
67+
| SHAP | Black box* | Local | |||||
68+
| Integrated gradient | Torch or TF | Local | |||| |
69+
| Counterfactual | Black box* | Local | |||||
70+
| Contrastive explanation | Torch or TF | Local | | || | |
71+
| Grad-CAM, Grad-CAM++ | Torch or TF | Local | | || | |
72+
| Learning to explain | Black box | Local | |||| |
73+
| Linear models | Linear models | Global and Local | || | | |
74+
| Tree models | Tree models | Global and Local | || | | |
75+
| Feature maps | Torch or TF | Local | | || | |
7576

7677
*SHAP* accepts black box models for tabular data, PyTorch/Tensorflow models for image data, transformer models
7778
for text data. *Counterfactual* accepts black box models for tabular, text and time-series data, and PyTorch/Tensorflow models for
@@ -109,22 +110,29 @@ Some examples:
109110
4. [Text classification](https://github.com/salesforce/OmniXAI/blob/main/tutorials/nlp_imdb.ipynb)
110111
5. [Time-series anomaly detection](https://github.com/salesforce/OmniXAI/blob/main/tutorials/timeseries.ipynb)
111112
6. [Vision-language tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/gradcam_vlm.ipynb)
113+
7. [Ranking tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/tabular/ranking.ipynb)
114+
8. [Feature visualization](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_visualization_torch.ipynb)
115+
9. [Check feature maps](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_map_torch.ipynb)
112116

113117
To get started, we recommend the linked tutorials in [tutorials](https://opensource.salesforce.com/OmniXAI/latest/tutorials.html).
114118
In general, we recommend using `TabularExplainer`, `VisionExplainer`,
115119
`NLPExplainer` and `TimeseriesExplainer` for tabular, vision, NLP and time-series tasks, respectively, and using
116120
`DataAnalyzer` and `PredictionAnalyzer` for feature analysis and prediction result analysis.
117-
To generate explanations, one only needs to specify
121+
These classes act as the factories of the individual explainers supported in OmniXAI, providing a simpler
122+
interface to generate multiple explanations. To generate explanations, you only need to specify
118123

119124
- **The ML model to explain**: e.g., a scikit-learn model, a tensorflow model, a pytorch model or a black-box prediction function.
120125
- **The pre-processing function**: i.e., converting raw input features into the model inputs.
121126
- **The post-processing function (optional)**: e.g., converting the model outputs into class probabilities.
122127
- **The explainers to apply**: e.g., SHAP, MACE, Grad-CAM.
123128

129+
Besides using these classes, you can also create a single explainer defined in the `omnixai.explainers` package, e.g.,
130+
`ShapTabular`, `GradCAM`, `IntegratedGradient` or `FeatureVisualizer`.
131+
124132
Let's take the income prediction task as an example.
125133
The [dataset](https://archive.ics.uci.edu/ml/datasets/adult) used in this example is for income prediction.
126134
We recommend using data class `Tabular` to represent a tabular dataset. To create a `Tabular` instance given a pandas
127-
dataframe, one needs to specify the dataframe, the categorical feature names (if exists) and the target/label
135+
dataframe, you need to specify the dataframe, the categorical feature names (if exists) and the target/label
128136
column name (if exists).
129137

130138
```python
@@ -152,8 +160,8 @@ for a `Tabular` instance. `TabularTransform` is a special transform designed for
152160
By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features.
153161
The method ``transform`` of `TabularTransform` transforms a `Tabular` instance to a numpy array.
154162
If the `Tabular` instance has a target/label column, the last column of the numpy array
155-
will be the target/label. One can also apply any customized preprocessing functions instead of using `TabularTransform`.
156-
After data preprocessing, we train a XGBoost classifier for this task.
163+
will be the target/label. You can apply any customized preprocessing functions instead of using `TabularTransform`.
164+
After data preprocessing, let's train a XGBoost classifier for this task.
157165

158166
```python
159167
from omnixai.preprocessing.tabular import TabularTransform
@@ -172,7 +180,7 @@ train_data = transformer.invert(train)
172180
test_data = transformer.invert(test)
173181
```
174182

175-
To initialize `TabularExplainer`, we need to set the following parameters:
183+
To initialize `TabularExplainer`, the following parameters need to be set:
176184

177185
- ``explainers``: The names of the explainers to apply, e.g., ["lime", "shap", "mace", "pdp"].
178186
- ``data``: The data used to initialize explainers. ``data`` is the training dataset for training the
@@ -185,8 +193,8 @@ To initialize `TabularExplainer`, we need to set the following parameters:
185193
- ``mode``: The task type, e.g., "classification" or "regression".
186194

187195
The preprocessing function takes a `Tabular` instance as its input and outputs the processed features that
188-
the ML model consumes. In this example, we simply call ``transformer.transform``. If one uses some customized transforms
189-
on pandas dataframes, the preprocess function has format: `lambda z: some_transform(z.to_pd())`. If the output of ``model``
196+
the ML model consumes. In this example, we simply call ``transformer.transform``. If you use some customized transforms
197+
on pandas dataframes, the preprocess function has this format: `lambda z: some_transform(z.to_pd())`. If the output of ``model``
190198
is not a numpy array, ``postprocess`` needs to be set to convert it into a numpy array.
191199

192200
```python
@@ -222,7 +230,7 @@ global_explanations = explainers.explain_global(
222230
```
223231

224232
Similarly, we create a `PredictionAnalyzer` for computing performance metrics for this classification task.
225-
To initialize `PredictionAnalyzer`, we set the following parameters:
233+
To initialize `PredictionAnalyzer`, the following parameters need to be set:
226234

227235
- `mode`: The task type, e.g., "classification" or "regression".
228236
- `test_data`: The test dataset, which should be a `Tabular` instance.
@@ -265,6 +273,48 @@ dashboard.show() # Launch the dashboard
265273
After opening the Dash app in the browser, we will see a dashboard showing the explanations:
266274
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif)
267275

276+
For vision tasks, the same interface is used to create explainers and generate explanations.
277+
Let's take an image classification model as an example.
278+
279+
```python
280+
from omnixai.explainers.vision import VisionExplainer
281+
from omnixai.visualization.dashboard import Dashboard
282+
283+
explainer = VisionExplainer(
284+
explainers=["gradcam", "lime", "ig", "ce", "feature_visualization"],
285+
mode="classification",
286+
model=model, # An image classification model, e.g., ResNet50
287+
preprocess=preprocess, # The preprocessing function
288+
postprocess=postprocess, # The postprocessing function
289+
params={
290+
# Set the target layer for GradCAM
291+
"gradcam": {"target_layer": model.layer4[-1]},
292+
# Set the objective for feature visualization
293+
"feature_visualization":
294+
{"objectives": [{"layer": model.layer4[-3], "type": "channel", "index": list(range(6))}]}
295+
},
296+
)
297+
# Generate explanations of GradCAM, LIME, IG and CE
298+
local_explanations = explainer.explain(test_img)
299+
# Generate explanations of feature visualization
300+
global_explanations = explainer.explain_global()
301+
# Launch the dashboard
302+
dashboard = Dashboard(
303+
instances=test_img,
304+
local_explanations=local_explanations,
305+
global_explanations=global_explanations
306+
)
307+
dashboard.show()
308+
```
309+
310+
The following figure shows the dashboard of these explanations:
311+
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_vision.gif)
312+
313+
For NLP tasks and time-series forecasting/anomaly detection, OmniXAI also provides the same interface
314+
to generate and visualize explanations. This figure shows a dashboard example of text classification
315+
and time-series anomaly detection:
316+
![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_nlp_ts.gif)
317+
268318
## How to Contribute
269319

270320
We welcome the contribution from the open-source community to improve the library!

docs/_static/demo.gif

1.04 MB
Loading

docs/_static/demo_nlp_ts.gif

783 KB
Loading

docs/_static/demo_vision.gif

1.04 MB
Loading

docs/index.rst

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ Prediction metrics Black box Global
5252
PDP Black box Global ✓
5353
ALE Black box Global ✓
5454
Sensitivity analysis Black box Global ✓
55+
Feature visualization Torch or TF Global ✓
5556
LIME Black box Local ✓ ✓ ✓
5657
SHAP Black box* Local ✓ ✓ ✓ ✓
5758
Integrated gradient Torch or TF Local ✓ ✓ ✓
@@ -61,6 +62,7 @@ Grad-CAM, Grad-CAM++ Torch or TF Local
6162
Learning to explain Black box Local ✓ ✓ ✓
6263
Linear models Linear models Global and Local ✓
6364
Tree models Tree models Global and Local ✓
65+
Feature maps Torch or TF Local ✓
6466
======================= ==================== ================ ============= ======= ======= ======= ==========
6567

6668
*SHAP* accepts black box models for tabular data, PyTorch/Tensorflow models for image data, transformer models
@@ -73,34 +75,35 @@ Comparison with Competitors
7375
The following table shows the comparison between our toolkit/library and other existing XAI toolkits/libraries
7476
in literature:
7577

76-
========== ==================== ======= =========== ====== ==== ====== ===== ========
77-
Data Type Method OmniXAI InterpretML AIX360 Eli5 Captum Alibi explainX
78-
========== ==================== ======= =========== ====== ==== ====== ===== ========
79-
Tabular LIME ✓ ✓ ✓ ✘ ✓ ✘ ✘
80-
\ SHAP ✓ ✓ ✓ ✘ ✓ ✓ ✓
81-
\ PDP ✓ ✓ ✘ ✘ ✘ ✘ ✘
82-
\ ALE ✓ ✘ ✘ ✘ ✘ ✓ ✘
83-
\ Sensitivity ✓ ✓ ✘ ✘ ✘ ✘ ✘
84-
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
85-
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
86-
\ Linear models ✓ ✓ ✓ ✓ ✘ ✓ ✓
87-
\ Tree models ✓ ✓ ✓ ✓ ✘ ✓ ✓
88-
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
89-
Image LIME ✓ ✘ ✘ ✘ ✓ ✘ ✘
90-
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
91-
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
92-
\ Grad-CAM, Grad-CAM++ ✓ ✘ ✘ ✓ ✓ ✘ ✘
93-
\ Contrastive ✓ ✘ ✓ ✘ ✘ ✓ ✘
94-
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
95-
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
96-
Text LIME ✓ ✘ ✘ ✓ ✓ ✘ ✘
97-
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
98-
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
99-
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
100-
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
101-
Timeseries SHAP ✓ ✘ ✘ ✘ ✘ ✘ ✘
102-
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
103-
========== ==================== ======= =========== ====== ==== ====== ===== ========
78+
========== ===================== ======= =========== ====== ==== ====== ===== ========
79+
Data Type Method OmniXAI InterpretML AIX360 Eli5 Captum Alibi explainX
80+
========== ===================== ======= =========== ====== ==== ====== ===== ========
81+
Tabular LIME ✓ ✓ ✓ ✘ ✓ ✘ ✘
82+
\ SHAP ✓ ✓ ✓ ✘ ✓ ✓ ✓
83+
\ PDP ✓ ✓ ✘ ✘ ✘ ✘ ✘
84+
\ ALE ✓ ✘ ✘ ✘ ✘ ✓ ✘
85+
\ Sensitivity ✓ ✓ ✘ ✘ ✘ ✘ ✘
86+
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
87+
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
88+
\ Linear models ✓ ✓ ✓ ✓ ✘ ✓ ✓
89+
\ Tree models ✓ ✓ ✓ ✓ ✘ ✓ ✓
90+
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
91+
Image LIME ✓ ✘ ✘ ✘ ✓ ✘ ✘
92+
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
93+
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
94+
\ Grad-CAM, Grad-CAM++ ✓ ✘ ✘ ✓ ✓ ✘ ✘
95+
\ Contrastive ✓ ✘ ✓ ✘ ✘ ✓ ✘
96+
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✓ ✘
97+
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
98+
\ Feature visualization ✓ ✘ ✘ ✘ ✘ ✘ ✘
99+
Text LIME ✓ ✘ ✘ ✓ ✓ ✘ ✘
100+
\ SHAP ✓ ✘ ✘ ✘ ✓ ✘ ✘
101+
\ Integrated gradient ✓ ✘ ✘ ✘ ✓ ✓ ✘
102+
\ L2X ✓ ✘ ✘ ✘ ✘ ✘ ✘
103+
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
104+
Timeseries SHAP ✓ ✘ ✘ ✘ ✘ ✘ ✘
105+
\ Counterfactual ✓ ✘ ✘ ✘ ✘ ✘ ✘
106+
========== ===================== ======= =========== ====== ==== ====== ===== ========
104107

105108
Installation
106109
############

0 commit comments

Comments
 (0)