@@ -30,7 +30,7 @@ def __init__(self):
        self.embedding_layer_inputs = None

    def compute_integrated_gradients(
-        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50
+        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50, batch_size=8
    ):
        import torch
@@ -48,26 +48,32 @@ def compute_integrated_gradients(
            hooks.append(embedding_layer.register_forward_hook(self._embedding_hook))
            model(*all_inputs)
            baselines = np.zeros(self.embeddings.shape)
+            hooks.append(embedding_layer.register_forward_hook(self._embedding_layer_hook))

            # Build the inputs for computing integrated gradient
            alphas = np.linspace(start=0.0, stop=1.0, num=steps, endpoint=True)
-            self.embedding_layer_inputs = torch.tensor(
-                np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0]) for a in alphas]),
-                dtype=torch.get_default_dtype(),
-                device=device,
-                requires_grad=True,
-            )
-            all_inputs = self._repeat(all_inputs, num_reps=self.embedding_layer_inputs.shape[0])
+            gradients = []
+            for k in range(0, len(alphas), batch_size):
+                self.embedding_layer_inputs = torch.tensor(
+                    np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0])
+                              for a in alphas[k:k + batch_size]]),
+                    dtype=torch.get_default_dtype(),
+                    device=device,
+                    requires_grad=True,
+                )
+                repeated_inputs = self._repeat(all_inputs, num_reps=self.embedding_layer_inputs.shape[0])

-            # Compute gradients
-            hooks.append(embedding_layer.register_forward_hook(self._embedding_layer_hook))
-            predictions = model(*all_inputs)
-            if len(predictions.shape) > 1:
-                assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
-                predictions = predictions[:, output_index]
-            gradients = (
-                torch.autograd.grad(torch.unbind(predictions), self.embedding_layer_inputs)[0].detach().cpu().numpy()
-            )
+                # Compute gradients
+                predictions = model(*repeated_inputs)
+                if len(predictions.shape) > 1:
+                    assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
+                    predictions = predictions[:, output_index]
+                grad = (
+                    torch.autograd.grad(
+                        torch.unbind(predictions), self.embedding_layer_inputs)[0].detach().cpu().numpy()
+                )
+                gradients.append(grad)
+            gradients = np.concatenate(gradients, axis=0)
        finally:
            for hook in hooks:
                hook.remove()
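The shape of this change: the old code materialized all `steps` interpolated embeddings in one tensor and ran a single forward/backward pass over them, so peak memory grew linearly with `steps`. The new loop slices `alphas` into chunks of `batch_size`, accumulates the per-chunk gradients, and concatenates at the end; the result is the same, but at most `batch_size` interpolation points are in flight at once. A minimal, self-contained sketch of the same pattern on a toy differentiable function (no hooks or embedding layers; `f`, `x`, and `baseline` are illustrative names, not part of the library):

```python
import numpy as np
import torch

def batched_ig_gradients(f, x, baseline, steps=50, batch_size=8):
    # Gradients of f at `steps` points interpolated between baseline and x,
    # computed in chunks of `batch_size` to bound peak memory.
    alphas = np.linspace(start=0.0, stop=1.0, num=steps, endpoint=True)
    gradients = []
    for k in range(0, len(alphas), batch_size):
        z = torch.tensor(
            np.stack([baseline + a * (x - baseline) for a in alphas[k:k + batch_size]]),
            dtype=torch.get_default_dtype(),
            requires_grad=True,
        )
        out = f(z)  # one scalar per interpolation point, shape (chunk,)
        grad = torch.autograd.grad(torch.unbind(out), z)[0].detach().cpu().numpy()
        gradients.append(grad)
    return np.concatenate(gradients, axis=0)  # shape (steps,) + x.shape

# Example: f sums squared coordinates, so the gradient at each point is 2z.
grads = batched_ig_gradients(
    lambda z: (z ** 2).sum(dim=-1), x=np.ones(4), baseline=np.zeros(4)
)
```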
@@ -90,7 +96,7 @@ def __init__(self):
        self.embedding_layer_inputs = None

    def compute_integrated_gradients(
-        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50
+        self, model, embedding_layer, inputs, output_index, additional_inputs=None, steps=50, batch_size=8
    ):
        import tensorflow as tf
@@ -107,22 +113,28 @@ def compute_integrated_gradients(

            # Build the inputs for computing integrated gradient
            alphas = np.linspace(start=0.0, stop=1.0, num=steps, endpoint=True)
-            self.embedding_layer_inputs = tf.convert_to_tensor(
-                np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0]) for a in alphas]),
-                dtype=tf.keras.backend.floatx(),
-            )
-            all_inputs = [
-                tf.tile(x, (self.embedding_layer_inputs.shape[0],) + (1,) * (len(x.shape) - 1)) for x in all_inputs
-            ]
-
            # Compute gradients
-            with tf.GradientTape() as tape:
-                self._embedding_layer_hook(embedding_layer, tape)
-                predictions = model(*all_inputs)
-                if len(predictions.shape) > 1:
-                    assert output_index is not None, "The model has multiple outputs, the output index cannot be None"
-                    predictions = predictions[:, output_index]
-                gradients = tape.gradient(predictions, embedding_layer.res).numpy()
+            gradients = []
+            for k in range(0, len(alphas), batch_size):
+                with tf.GradientTape() as tape:
+                    self._embedding_layer_hook(embedding_layer, tape)
+                    self.embedding_layer_inputs = tf.convert_to_tensor(
+                        np.stack([baselines[0] + a * (self.embeddings[0] - baselines[0])
+                                  for a in alphas[k:k + batch_size]]),
+                        dtype=tf.keras.backend.floatx(),
+                    )
+                    repeated_inputs = [
+                        tf.tile(x, (self.embedding_layer_inputs.shape[0],) + (1,) * (len(x.shape) - 1))
+                        for x in all_inputs
+                    ]
+                    predictions = model(*repeated_inputs)
+                    if len(predictions.shape) > 1:
+                        assert output_index is not None, \
+                            "The model has multiple outputs, the output index cannot be None"
+                        predictions = predictions[:, output_index]
+                    grad = tape.gradient(predictions, embedding_layer.res).numpy()
+                    gradients.append(grad)
+            gradients = np.concatenate(gradients, axis=0)
        finally:
            self._remove_hook(embedding_layer, original_call)
        return _calculate_integral(self.embeddings[0], baselines[0], gradients)
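`_calculate_integral` itself is untouched by this PR and its body is not shown here. In standard integrated gradients it is the completion step: average the path gradients (typically with the trapezoidal rule) and scale by the input-minus-baseline delta. A sketch of that conventional form, under the assumption that this is what the helper computes, not a copy of the repository's implementation:

```python
import numpy as np

def calculate_integral_sketch(embeddings, baselines, gradients):
    # gradients has shape (steps,) + embeddings.shape, one slice per alpha.
    # Trapezoidal rule: average adjacent steps, then average along the path.
    avg_grads = (gradients[:-1] + gradients[1:]) / 2.0
    avg_grads = avg_grads.mean(axis=0)
    # Scale by the straight-line displacement in embedding space.
    return (embeddings - baselines) * avg_grads
```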
@@ -164,13 +176,13 @@ class IntegratedGradientText(ExplainerBase):
    alias = ["ig", "integrated_gradient"]

    def __init__(
-        self,
-        model,
-        embedding_layer,
-        preprocess_function: Callable,
-        mode: str = "classification",
-        id2token: Dict = None,
-        **kwargs,
+        self,
+        model,
+        embedding_layer,
+        preprocess_function: Callable,
+        mode: str = "classification",
+        id2token: Dict = None,
+        **kwargs,
    ):
        """
        :param model: The model to explain, whose type can be `tf.keras.Model` or `torch.nn.Module`.
@@ -245,6 +257,7 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
        :return: The explanations for all the instances, e.g., word/token importance scores.
        """
        steps = kwargs.get("steps", 50)
+        batch_size = kwargs.get("batch_size", 16)
        explanations = WordImportance(mode=self.mode)

        inputs = self._preprocess(X)
@@ -275,6 +288,7 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
            output_index=output_index,
            additional_inputs=None if len(inputs) == 1 else inputs[1:],
            steps=steps,
+            batch_size=batch_size
        )
        tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" else inputs[0].numpy()
        explanations.add(
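From the caller's side, the new knob threads through `explain` via `**kwargs` with a default of 16 and on into `compute_integrated_gradients`. A hypothetical usage sketch (the model, embedding-layer attribute, and preprocessing function are placeholders; only the explainer class and the `steps`/`batch_size` parameters come from this PR):

```python
explainer = IntegratedGradientText(
    model=model,                      # tf.keras.Model or torch.nn.Module
    embedding_layer=model.embedding,  # placeholder attribute name
    preprocess_function=preprocess,   # placeholder: raw text -> model inputs
    mode="classification",
    id2token=id2token,
)
# `steps` sets how many interpolation points are used; `batch_size` caps how
# many of them go through the model per forward/backward pass. Smaller values
# trade speed for lower peak memory.
explanations = explainer.explain(X, steps=50, batch_size=16)
```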