OmniXAI includes a rich family of explanation methods integrated in a unified interface, which supports multiple data types (tabular data, images, texts, time-series), multiple types of ML models (traditional ML in Scikit-learn and deep learning models in PyTorch/TensorFlow), and a range of diverse explanation methods, both "model-specific" and "model-agnostic" (such as feature-attribution explanation, counterfactual explanation, gradient-based explanation, feature visualization, etc.). For practitioners, OmniXAI provides an easy-to-use unified interface to generate explanations for their applications with only a few lines of code, as well as a GUI dashboard for visualization that offers further insight into model decisions.

The following table shows the supported explanation methods and features in our library.
We will continue improving this library to make it more comprehensive in the future.

| Method | Model Type | Explanation Type | EDA | Tabular | Image | Text | Timeseries |
| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Feature analysis | NA | Global | ✅ | | | | |
| Feature selection | NA | Global | ✅ | | | | |
| Prediction metrics | Black box | Global | | ✅ | ✅ | ✅ | ✅ |
| Partial dependence plots | Black box | Global | | ✅ | | | |
| Accumulated local effects | Black box | Global | | ✅ | | | |
| Sensitivity analysis | Black box | Global | | ✅ | | | |
| Feature visualization | Torch or TF | Global | | | ✅ | | |
| LIME | Black box | Local | | ✅ | ✅ | ✅ | |
| SHAP | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
| Integrated gradient | Torch or TF | Local | | ✅ | ✅ | ✅ | |
| Counterfactual | Black box* | Local | | ✅ | ✅ | ✅ | ✅ |
| Contrastive explanation | Torch or TF | Local | | | ✅ | | |
| Grad-CAM, Grad-CAM++ | Torch or TF | Local | | | ✅ | | |
| Learning to explain | Black box | Local | | ✅ | ✅ | ✅ | |
| Linear models | Linear models | Global and Local | | ✅ | | | |
| Tree models | Tree models | Global and Local | | ✅ | | | |
| Feature maps | Torch or TF | Local | | | ✅ | | |

*SHAP* accepts black box models for tabular data, PyTorch/TensorFlow models for image data, and transformer models for text data. *Counterfactual* accepts black box models for tabular, text and time-series data, and PyTorch/TensorFlow models for image data.

Some examples:

4. [Text classification](https://github.com/salesforce/OmniXAI/blob/main/tutorials/nlp_imdb.ipynb)
5. [Time-series anomaly detection](https://github.com/salesforce/OmniXAI/blob/main/tutorials/timeseries.ipynb)
6. [Vision-language tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/gradcam_vlm.ipynb)
7. [Ranking tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/tabular/ranking.ipynb)
8. [Feature visualization](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_visualization_torch.ipynb)
9. [Check feature maps](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/feature_map_torch.ipynb)

To get started, we recommend the linked tutorials in [tutorials](https://opensource.salesforce.com/OmniXAI/latest/tutorials.html). In general, we recommend using `TabularExplainer`, `VisionExplainer`, `NLPExplainer` and `TimeseriesExplainer` for tabular, vision, NLP and time-series tasks, respectively, and using `DataAnalyzer` and `PredictionAnalyzer` for feature analysis and prediction result analysis. These classes act as factories for the individual explainers supported in OmniXAI, providing a simpler interface for generating multiple explanations at once. To generate explanations, you only need to specify

- **The ML model to explain**: e.g., a scikit-learn model, a TensorFlow model, a PyTorch model or a black-box prediction function.
- **The pre-processing function**: i.e., converting raw input features into the model inputs.
- **The post-processing function (optional)**: e.g., converting the model outputs into class probabilities.
- **The explainers to apply**: e.g., SHAP, MACE, Grad-CAM.

Besides using these classes, you can also create a single explainer defined in the `omnixai.explainers` package, e.g., `ShapTabular`, `GradCAM`, `IntegratedGradient` or `FeatureVisualizer`.
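
For instance, here is a minimal sketch of creating a standalone SHAP explainer for tabular data; `tabular_data`, `predict_function` and `test_instances` are placeholders for your training data, your model's prediction function and your test instances:

```python
from omnixai.explainers.tabular import ShapTabular

# Create a single SHAP explainer directly, without the TabularExplainer factory.
# `tabular_data` is the training data as a `Tabular` instance, and
# `predict_function` is assumed to map a `Tabular` instance to class probabilities.
explainer = ShapTabular(
    training_data=tabular_data,
    predict_function=predict_function
)
# Generate local explanations for some test instances
explanations = explainer.explain(test_instances)
```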

Let's take the income prediction task as an example. The [dataset](https://archive.ics.uci.edu/ml/datasets/adult) used in this example is for income prediction. We recommend using the data class `Tabular` to represent a tabular dataset. To create a `Tabular` instance given a pandas dataframe, you need to specify the dataframe, the categorical feature names (if any) and the target/label column name (if any).
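
As a minimal sketch of this step (the file path and the choice of categorical columns below are illustrative for the Adult dataset):

```python
import numpy as np
import pandas as pd
from omnixai.data.tabular import Tabular

# Load the Adult dataset into a pandas dataframe
feature_names = [
    "Age", "Workclass", "fnlwgt", "Education", "Education-Num", "Marital Status",
    "Occupation", "Relationship", "Race", "Sex", "Capital Gain", "Capital Loss",
    "Hours per week", "Country", "label"
]
df = pd.DataFrame(
    np.genfromtxt("adult.data", delimiter=", ", dtype=str),
    columns=feature_names
)
# Create a Tabular instance, specifying the categorical columns and the label column
tabular_data = Tabular(
    df,
    categorical_columns=[feature_names[i] for i in [1, 3, 5, 6, 7, 8, 9, 13]],
    target_column="label"
)
```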

The package `omnixai.preprocessing` provides several useful preprocessing functions for a `Tabular` instance. `TabularTransform` is a special transform designed for tabular data.
By default, it converts categorical features into one-hot encoding, and keeps continuous-valued features.
The method ``transform`` of `TabularTransform` transforms a `Tabular` instance to a numpy array. If the `Tabular` instance has a target/label column, the last column of the numpy array will be the target/label. You can apply any customized preprocessing functions instead of using `TabularTransform`. After data preprocessing, let's train an XGBoost classifier for this task.

```python
from sklearn.model_selection import train_test_split
import xgboost
from omnixai.preprocessing.tabular import TabularTransform
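
# Fit the transform and convert the Tabular data to a numpy array
transformer = TabularTransform().fit(tabular_data)
x = transformer.transform(tabular_data)
# Split into training and test sets; the last column of `x` is the target/label
# (the 80/20 split and the XGBoost hyperparameters are illustrative choices)
train, test, train_labels, test_labels = \
    train_test_split(x[:, :-1], x[:, -1], train_size=0.80)
# Train an XGBoost classifier on the transformed features
model = xgboost.XGBClassifier(n_estimators=300, max_depth=5)
model.fit(train, train_labels)
# Convert the transformed numpy arrays back into `Tabular` instances
train_data = transformer.invert(train)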
test_data = transformer.invert(test)
```

To initialize `TabularExplainer`, the following parameters need to be set:

- ``explainers``: The names of the explainers to apply, e.g., ["lime", "shap", "mace", "pdp"].
- ``data``: The data used to initialize explainers. ``data`` is the training dataset for training the machine learning model.
- ``mode``: The task type, e.g., "classification" or "regression".

The preprocessing function takes a `Tabular` instance as its input and outputs the processed features that the ML model consumes. In this example, we simply call ``transformer.transform``. If you use some customized transforms on pandas dataframes, the preprocess function has this format: `lambda z: some_transform(z.to_pd())`. If the output of ``model`` is not a numpy array, ``postprocess`` needs to be set to convert it into a numpy array.

```python
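from omnixai.explainers.tabular import TabularExplainer

# Initialize a TabularExplainer following the parameters described above
# (the explainer list and the PDP features below are illustrative choices)
explainers = TabularExplainer(
    explainers=["lime", "shap", "mace", "pdp"],
    mode="classification",
    data=train_data,
    model=model,
    preprocess=lambda z: transformer.transform(z)
)
# Generate local explanations for a few test instances
test_instances = test_data[:5]
local_explanations = explainers.explain(X=test_instances)
# Generate global explanations, e.g., partial dependence plots
global_explanations = explainers.explain_global(
    params={"pdp": {"features": ["Age", "Education-Num", "Capital Gain"]}}
)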
```

Similarly, we create a `PredictionAnalyzer` for computing performance metrics for this classification task. To initialize `PredictionAnalyzer`, the following parameters need to be set:

- `mode`: The task type, e.g., "classification" or "regression".
- `test_data`: The test dataset, which should be a `Tabular` instance.
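
A minimal sketch of creating the analyzer and launching the dashboard (`test_labels` comes from the train/test split above, and the `Dashboard` arguments mirror the vision example below):

```python
from omnixai.explainers.prediction import PredictionAnalyzer
from omnixai.visualization.dashboard import Dashboard

# Compute performance metrics on the test set
analyzer = PredictionAnalyzer(
    mode="classification",
    test_data=test_data,
    test_targets=test_labels,
    model=model,
    preprocess=lambda z: transformer.transform(z)
)
prediction_explanations = analyzer.explain()

# Collect the generated explanations in a dashboard
dashboard = Dashboard(
    instances=test_instances,
    local_explanations=local_explanations,
    global_explanations=global_explanations
)
dashboard.show()  # Launch the dashboard
```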
After opening the Dash app in the browser, we will see a dashboard showing the explanations:

![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo.gif)

For vision tasks, the same interface is used to create explainers and generate explanations. Let's take an image classification model as an example.

```python
from omnixai.explainers.vision import VisionExplainer
from omnixai.visualization.dashboard import Dashboard

explainer = VisionExplainer(
    explainers=["gradcam", "lime", "ig", "ce", "feature_visualization"],
    mode="classification",
    model=model,              # An image classification model, e.g., ResNet50
    preprocess=preprocess,    # The preprocessing function
    postprocess=postprocess,  # The postprocessing function
    params={
        # Set the target layer for Grad-CAM
        "gradcam": {"target_layer": model.layer4[-1]},
        # Set the objective for feature visualization
        "feature_visualization":
            {"objectives": [{"layer": model.layer4[-3], "type": "channel", "index": list(range(6))}]}
    },
)
# Generate explanations of GradCAM, LIME, IG and CE
local_explanations = explainer.explain(test_img)
# Generate explanations of feature visualization
global_explanations = explainer.explain_global()
# Launch the dashboard
dashboard = Dashboard(
    instances=test_img,
    local_explanations=local_explanations,
    global_explanations=global_explanations
)
dashboard.show()
```

The following figure shows the dashboard of these explanations:

![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_vision.gif)

For NLP tasks and time-series forecasting/anomaly detection, OmniXAI also provides the same interface to generate and visualize explanations. This figure shows a dashboard example of text classification and time-series anomaly detection:

![alt text](https://github.com/salesforce/OmniXAI/raw/main/docs/_static/demo_nlp_ts.gif)
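
As an illustration, here is a sketch of the analogous NLP workflow (assuming a text classification `model` with its own `preprocess` and `postprocess` functions; the example sentences are placeholders):

```python
from omnixai.data.text import Text
from omnixai.explainers.nlp import NLPExplainer

explainer = NLPExplainer(
    explainers=["lime", "ig"],
    mode="classification",
    model=model,              # A text classification model
    preprocess=preprocess,    # Converts a `Text` instance into the model inputs
    postprocess=postprocess   # Converts the model outputs into class probabilities
)
# Generate explanations for a batch of sentences
x = Text(["it was a fantastic performance!", "the movie was boring and too long"])
local_explanations = explainer.explain(x)
```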
## How to Contribute

We welcome contributions from the open-source community to improve the library!