
Commit 6ac467c

Merge pull request #19 from salesforce/multimodal
Multimodal

2 parents 7f76d23 + c82c056

21 files changed: +1283 −138 lines
README.md

Lines changed: 4 additions & 29 deletions
@@ -77,35 +77,9 @@ explanation methods for vision, NLP and time-series tasks.
 for text data. *Counterfactual* accepts black box models for tabular, text and time-series data, and PyTorch/Tensorflow models for
 image data.
 
-The following table shows the comparison between our toolkit/library and other existing XAI toolkits/libraries
-in literature:
-
-| Data Type | Method | OmniXAI | InterpretML | AIX360 | Eli5 | Captum | Alibi | explainX
-:---: |:--------------------:| :---: | :---: | :---: | :---: | :---: | :---: | :---:
-| Tabular | LIME |||| || | |
-| | SHAP |||| ||||
-| | PDP ||| | | | | |
-| | ALE || | | | || |
-| | Sensitivity ||| | | | | |
-| | Integrated gradient || | | ||| |
-| | Counterfactual || | | | || |
-| | Linear models ||||| |||
-| | Tree models ||||| |||
-| | L2X || | | | | | |
-| Image | LIME || | | || | |
-| | SHAP || | | || | |
-| | Integrated gradient || | | ||| |
-| | Grad-CAM, Grad-CAM++ || | ||| | |
-| | CEM || || | || |
-| | Counterfactual || | | | || |
-| | L2X || | | | | | |
-| Text | LIME || | ||| | |
-| | SHAP || | | || | |
-| | Integrated gradient || | | ||| |
-| | L2X || | | | | | |
-| | Counterfactual || | | | | | |
-| Timeseries | SHAP || | | | | | |
-| | Counterfactual || | | | | | |
+This [table](https://opensource.salesforce.com/OmniXAI/latest/index.html#comparison-with-competitors)
+shows the comparison between our toolkit/library and other existing XAI toolkits/libraries
+in literature
 
 ## Installation
 
@@ -134,6 +108,7 @@ Some examples:
 3. [Image classification](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision.ipynb)
 4. [Text classification](https://github.com/salesforce/OmniXAI/blob/main/tutorials/nlp_imdb.ipynb)
 5. [Time-series anomaly detection](https://github.com/salesforce/OmniXAI/blob/main/tutorials/timeseries.ipynb)
+6. [Vision-language tasks](https://github.com/salesforce/OmniXAI/blob/main/tutorials/vision/gradcam_vlm.ipynb)
 
 To get started, we recommend the linked tutorials in [tutorials](https://opensource.salesforce.com/OmniXAI/latest/tutorials.html).
 In general, we recommend using `TabularExplainer`, `VisionExplainer`,

omnixai/data/image.py

Lines changed: 2 additions & 2 deletions
@@ -24,7 +24,7 @@ class Image(Data):
     data_type = "image"
 
     def __init__(
-        self, data: Union[np.ndarray, PilImage.Image] = None, batched: bool = False, channel_last: bool = True
+            self, data: Union[np.ndarray, PilImage.Image] = None, batched: bool = False, channel_last: bool = True
     ):
         """
         :param data: The image data, which is either np.ndarray or PIL.Image. If ``data``
@@ -111,7 +111,7 @@ def __getitem__(self, i: Union[int, slice, list]):
         :rtype: Image
         """
         if isinstance(i, int):
-            return Image(self.data[i : i + 1], batched=True, channel_last=True)
+            return Image(self.data[i: i + 1], batched=True, channel_last=True)
         else:
            return Image(self.data[i], batched=True, channel_last=True)
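For reference, a minimal usage sketch of the `Image` container touched above; the array shape and dtype are illustrative assumptions, not values from the repo:

```python
import numpy as np
from omnixai.data.image import Image

# A batch of 4 RGB images in channel-last (H, W, C) layout.
batch = Image(data=np.random.randint(0, 255, (4, 32, 32, 3)).astype(np.uint8), batched=True)

# Indexing with an int keeps the batch dimension, which is what the
# `self.data[i: i + 1]` slice in the diff above preserves.
first = batch[0]
print(batch.num_samples(), first.num_samples())  # expected: 4 1
```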

omnixai/data/multi_inputs.py

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+#
+# Copyright (c) 2022 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+"""
+The class for multiple inputs.
+"""
+from typing import Dict, Union
+from .base import Data
+from .tabular import Tabular
+from .image import Image
+from .text import Text
+
+
+class MultiInputs(Data):
+    """
+    This data class is used for a model with multiple inputs, e.g., a visual-language model with
+    images and texts as its inputs, or a ranking model with queries and items as its inputs.
+    The data is stored in a dict, e.g., `{"image": Image(), "text": Text()}`.
+    """
+    data_type = "timeseries"
+
+    def __init__(self, **inputs):
+        """
+        :param inputs: Multiple input parameters, e.g., ``inputs = {"image": Image(), "text": Text()}``.
+        """
+        super().__init__()
+        num_samples = []
+        for key, value in inputs.items():
+            assert isinstance(value, (Tabular, Image, Text)), \
+                f"The type of input {key} must be `Tabular`, `Image` or `Text` " \
+                f"instead of {type(value)}."
+            num_samples.append(value.num_samples())
+        assert min(num_samples) == max(num_samples), \
+            f"The numbers of samples in the inputs are different: {num_samples}."
+
+        for key, value in inputs.items():
+            setattr(self, key, value)
+        self.inputs = inputs
+        self.nsamples = num_samples[0]
+
+    @property
+    def values(self) -> Dict:
+        """
+        Returns the raw values of each input.
+
+        :return: A dict containing the raw values for each input.
+        """
+        return {key: value.values for key, value in self.inputs.items()}
+
+    def num_samples(self) -> int:
+        """
+        Returns the number of samples in the inputs.
+
+        :return: The number samples in the inputs.
+        """
+        return self.nsamples
+
+    def __contains__(self, item):
+        return item in self.inputs
+
+    def __getitem__(self, i: Union[int, slice, list]):
+        """
+        Get a subset of the input samples given an index or a set of indices.
+
+        :param i: An integer index, slice or list.
+        :return: A MultiInputs object with the selected samples.
+        :rtype: MultiInputs
+        """
+        inputs = {key: value[i] for key, value in self.inputs.items()}
+        return MultiInputs(**inputs)
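A minimal sketch of how this new container could be used with paired image/text samples; the shapes and caption strings are made up for illustration and are not taken from the repo:

```python
import numpy as np
from omnixai.data.image import Image
from omnixai.data.text import Text
from omnixai.data.multi_inputs import MultiInputs

# Two paired samples for a visual-language model: one image and one caption each.
images = Image(data=np.random.randint(0, 255, (2, 224, 224, 3)).astype(np.uint8), batched=True)
texts = Text(["a dog running on the grass", "a cat sitting on a sofa"])

inputs = MultiInputs(image=images, text=texts)
print(inputs.num_samples())        # 2 (both inputs must have the same sample count)
print("image" in inputs)           # True, via __contains__
subset = inputs[0]                 # A MultiInputs holding only the first image/text pair
print(list(subset.values.keys()))  # ['image', 'text']
```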

omnixai/explainers/nlp/specific/ig.py

Lines changed: 54 additions & 40 deletions
@@ -30,7 +30,7 @@ def __init__(self):
         self.embedding_layer_inputs = None
 
     def compute_integrated_gradients(
-        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50
+            self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50, batch_size=8
     ):
         import torch
 
@@ -48,26 +48,32 @@ def compute_integrated_gradients(
             hooks.append(embedding_layer.register_forward_hook(self._embedding_hook))
             model(*all_inputs)
             baselines = np.zeros(self.embeddings.shape)
+            hooks.append(embedding_layer.register_forward_hook(self._embedding_layer_hook))
 
             # Build the inputs for computing integrated gradient
             alphas = np.linspace(start=0.0, stop=1.0, num=steps, endpoint=True)
-            self.embedding_layer_inputs = torch.tensor(
-                np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0]) for a in alphas]),
-                dtype=torch.get_default_dtype(),
-                device=device,
-                requires_grad=True,
-            )
-            all_inputs = self._repeat(all_inputs, num_reps=self.embedding_layer_inputs.shape[0])
+            gradients = []
+            for k in range(0, len(alphas), batch_size):
+                self.embedding_layer_inputs = torch.tensor(
+                    np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0])
+                              for a in alphas[k:k + batch_size]]),
+                    dtype=torch.get_default_dtype(),
+                    device=device,
+                    requires_grad=True,
+                )
+                repeated_inputs = self._repeat(all_inputs, num_reps=self.embedding_layer_inputs.shape[0])
 
-            # Compute gradients
-            hooks.append(embedding_layer.register_forward_hook(self._embedding_layer_hook))
-            predictions = model(*all_inputs)
-            if len(predictions.shape) > 1:
-                assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
-                predictions = predictions[:, output_index]
-            gradients = (
-                torch.autograd.grad(torch.unbind(predictions), self.embedding_layer_inputs)[0].detach().cpu().numpy()
-            )
+                # Compute gradients
+                predictions = model(*repeated_inputs)
+                if len(predictions.shape) > 1:
+                    assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
+                    predictions = predictions[:, output_index]
+                grad = (
+                    torch.autograd.grad(
+                        torch.unbind(predictions), self.embedding_layer_inputs)[0].detach().cpu().numpy()
+                )
+                gradients.append(grad)
+            gradients = np.concatenate(gradients, axis=0)
         finally:
             for hook in hooks:
                 hook.remove()
@@ -90,7 +96,7 @@ def __init__(self):
         self.embedding_layer_inputs = None
 
     def compute_integrated_gradients(
-        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50
+            self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50, batch_size=8
     ):
         import tensorflow as tf
 
@@ -107,22 +113,28 @@ def compute_integrated_gradients(
 
             # Build the inputs for computing integrated gradient
             alphas = np.linspace(start=0.0, stop=1.0, num=steps, endpoint=True)
-            self.embedding_layer_inputs = tf.convert_to_tensor(
-                np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0]) for a in alphas]),
-                dtype=tf.keras.backend.floatx(),
-            )
-            all_inputs = [
-                tf.tile(x, (self.embedding_layer_inputs.shape[0],) + (1,) * (len(x.shape) - 1)) for x in all_inputs
-            ]
-
             # Compute gradients
-            with tf.GradientTape() as tape:
-                self._embedding_layer_hook(embedding_layer, tape)
-                predictions = model(*all_inputs)
-                if len(predictions.shape) > 1:
-                    assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
-                    predictions = predictions[:, output_index]
-                gradients = tape.gradient(predictions, embedding_layer.res).numpy()
+            gradients = []
+            for k in range(0, len(alphas), batch_size):
+                with tf.GradientTape() as tape:
+                    self._embedding_layer_hook(embedding_layer, tape)
+                    self.embedding_layer_inputs = tf.convert_to_tensor(
+                        np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0])
+                                  for a in alphas[k:k + batch_size]]),
+                        dtype=tf.keras.backend.floatx(),
+                    )
+                    repeated_inputs = [
+                        tf.tile(x, (self.embedding_layer_inputs.shape[0],) + (1,) * (len(x.shape) - 1))
+                        for x in all_inputs
+                    ]
+                    predictions = model(*repeated_inputs)
+                    if len(predictions.shape) > 1:
+                        assert output_index is not None, \
+                            "The model has multiple outputs, the output index cannot be None"
+                        predictions = predictions[:, output_index]
+                    grad = tape.gradient(predictions, embedding_layer.res).numpy()
+                    gradients.append(grad)
+            gradients = np.concatenate(gradients, axis=0)
         finally:
             self._remove_hook(embedding_layer, original_call)
         return _calculate_integral(self.embeddings[0], baselines[0], gradients)
@@ -164,13 +176,13 @@ class IntegratedGradientText(ExplainerBase):
     alias = ["ig", "integrated_gradient"]
 
     def __init__(
-        self,
-        model,
-        embedding_layer,
-        preprocess_function: Callable,
-        mode: str = "classification",
-        id2token: Dict = None,
-        **kwargs,
+            self,
+            model,
+            embedding_layer,
+            preprocess_function: Callable,
+            mode: str = "classification",
+            id2token: Dict = None,
+            **kwargs,
     ):
         """
         :param model: The model to explain, whose type can be `tf.keras.Model` or `torch.nn.Module`.
@@ -245,6 +257,7 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
         :return: The explanations for all the instances, e.g., word/token importance scores.
         """
         steps = kwargs.get("steps", 50)
+        batch_size = kwargs.get("batch_size", 16)
        explanations = WordImportance(mode=self.mode)
 
         inputs = self._preprocess(X)
@@ -275,6 +288,7 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
                 output_index=output_index,
                 additional_inputs=None if len(inputs) == 1 else inputs[1:],
                 steps=steps,
+                batch_size=batch_size
             )
             tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" else inputs[0].numpy()
             explanations.add(
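The ig.py change above computes Integrated Gradients over the interpolation steps in mini-batches instead of pushing all `steps` interpolated embeddings through the model in one forward/backward pass. Below is a framework-free sketch of the same batching idea; the toy `grad_fn` and the plain mean over step gradients are illustrative stand-ins (the repo uses the model's real gradients and its own `_calculate_integral`), not the library's API:

```python
import numpy as np

def integrated_gradients(embedding, baseline, grad_fn, steps=50, batch_size=8):
    """Approximate IG attributions by batching the interpolation steps."""
    alphas = np.linspace(0.0, 1.0, num=steps, endpoint=True)
    gradients = []
    for k in range(0, len(alphas), batch_size):
        # Interpolate a small batch of points between the baseline and the input.
        batch = np.stack([baseline + a * (embedding - baseline) for a in alphas[k:k + batch_size]])
        gradients.append(grad_fn(batch))              # shape: (batch, *embedding.shape)
    gradients = np.concatenate(gradients, axis=0)     # shape: (steps, *embedding.shape)
    # Average the path gradients and scale by the input-minus-baseline difference.
    return (embedding - baseline) * gradients.mean(axis=0)

# Toy check: for f(x) = sum(x**2) the gradient is 2 * x and the exact IG is x**2.
embedding = np.array([1.0, -2.0, 0.5])
baseline = np.zeros_like(embedding)
print(integrated_gradients(embedding, baseline, grad_fn=lambda x: 2 * x))  # ~ [1.0, 4.0, 0.25]
```

Batching trades one large forward/backward pass for several smaller ones, which bounds peak memory when `steps` is large while producing the same concatenated gradients.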
omnixai/explainers/vision_language/__init__.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+#
+# Copyright (c) 2022 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+from .specific.gradcam import GradCAM
+
+
+__all__ = [
+    "GradCAM",
+]

omnixai/explainers/vision_language/specific/__init__.py

Whitespace-only changes.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+#
+# Copyright (c) 2022 salesforce.com, inc.
+# All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+#
+from .gradcam import GradCAM
+
+__all__ = ["GradCAM"]
