This repository was archived by the owner on Oct 24, 2025. It is now read-only.

Commit 7b62266

Improving readability of charts for insertion and deletion

1 parent 21d6220 · commit 7b62266

4 files changed: +76 -66 lines changed

docker-compose.yml

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+version: "3.8"
+services:
+  trainer:
+    env_file: .env
+    image: "rfl/foxai-trainer:latest"
+    build: "./"
+    container_name: "foxai-trainer"
+    shm_size: 32gb
+    volumes:
+      - ".:/FoXAI"
+      - "${IMG_DIR}/:/home/user/Downloads"
+    ports:
+      - 8888:8888
+      - 6006:6006
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: 1
+              capabilities: [gpu]
+    privileged: true
+    command: ["bash", "-c", "source /etc/bash.bashrc && jupyter lab --notebook-dir=/FoXAI --ip 0.0.0.0 --no-browser --allow-root"]

example/notebooks/deletion_insertion_metric_example.ipynb

Lines changed: 14 additions & 51 deletions

@@ -4,14 +4,14 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "# Example notebook for Image Classification (with imagenet-mini)"
+   "# Example notebook for usage of Deletion/Insertion"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "In this tutorial, You will see how You can use explainable algorithms to study pre-trained model decision. We will take a pre-trained model, sample images and run several explainable methods."
+   "In this tutorial, you will learn how to set up and use the insertion and deletion metrics for XAI."
   ]
  },
  {
@@ -273,20 +273,13 @@
   "%matplotlib inline"
  ]
 },
-{
- "cell_type": "markdown",
- "metadata": {},
- "source": [
-  "Let's see how images from `ImageNet-Mini` looks like. We will display first few samples of dataset. In the following steps we will use them to explain model predictions using different explainable algorithms."
- ]
-},
 {
  "cell_type": "markdown",
  "metadata": {
   "tags": []
  },
  "source": [
-  "## Demo for general algorithms "
+  "## Demo for usage of Insertion and Deletion metrics for GradCAM explanations"
 ]
 },
 {
@@ -295,31 +288,8 @@
  "metadata": {},
  "outputs": [],
  "source": [
-  "# define list of explainers we want to use\n",
-  "# full list of supported explainers is present in `Explainers` enum class.\n",
-  "explainer_list = [\n",
-  "    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER),\n",
-  "    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INPUT_X_GRADIENT_EXPLAINER),\n",
-  "    ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INTEGRATED_GRADIENTS_EXPLAINER),\n",
-  "]"
- ]
-},
-{
- "cell_type": "markdown",
- "metadata": {
-  "tags": []
- },
- "source": [
-  "## Demo for layer specific algorithms"
- ]
-},
-{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
-  "layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]"
+  "from foxai.metrics import insertion, deletion\n",
+  "from foxai.visualizer import visualize_metric"
 ]
},
{
@@ -328,8 +298,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-  "from foxai.metrics import insertion, deletion\n",
-  "from foxai.visualizer import visualize_metric"
+  "chosen_explainer = CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER"
 ]
},
{
@@ -338,7 +307,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-  "type(model)"
+  "layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]"
 ]
},
{
@@ -347,27 +316,28 @@
 "metadata": {},
 "outputs": [],
 "source": [
-  "\n",
   "# iterate over dataloader\n",
   "sample_batch = next(iter(val_dataloader))\n",
   "# iterate over all samples in batch\n",
   "sample, label = sample_batch[0][0], sample_batch[1][0]\n",
+  "sample = sample.to(device)\n",
   "# add batch size dimension to the data sample\n",
-  "input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2]).to(device)\n",
+  "input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2])\n",
   "category_name = categories[label.item()]\n",
   "with FoXaiExplainer(\n",
   "    model=model,\n",
-  "    explainers=[ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER, layer=layer)],\n",
+  "    explainers=[ExplainerWithParams(explainer_name=chosen_explainer, layer=layer)],\n",
   "    target=label,\n",
   ") as xai_model:\n",
+  "    chosen_explainer_key = f\"{chosen_explainer.name}_0\"\n",
   "    # calculate attributes for every explainer\n",
   "    first_output, attributes_dict = xai_model(input_data)\n",
-  "    value = attributes_dict[\"CV_LAYER_GRADCAM_EXPLAINER\"]\n",
+  "    value = attributes_dict[chosen_explainer_key]\n",
   "    figure = mean_channels_visualization(attributions=value[0], transformed_img=sample, title=f\"Mean of channels\")\n",
   "    # save figure to artifact directory\n",
   "    show_figure(figure)\n",
   "\n",
-  "    gradcam_maps = attributes_dict[\"CV_LAYER_GRADCAM_EXPLAINER\"]\n",
+  "    gradcam_maps = attributes_dict[chosen_explainer_key]\n",
   "    value = gradcam_maps[0]\n",
   "    chosen_class = first_output.argmax()\n",
   "    insertion_result, importance_lst = insertion(value, sample, model, chosen_class)\n",
@@ -376,13 +346,6 @@
   "    visualize_metric(importance_lst, deletion_result, metric_type=\"Deletion\")\n",
   "\n"
  ]
-},
-{
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
 }
],
"metadata": {
@@ -401,7 +364,7 @@
  "name": "python",
  "nbconvert_exporter": "python",
  "pygments_lexer": "ipython3",
-  "version": "3.7.13"
+  "version": "3.8.10"
 },
 "vscode": {
  "interpreter": {

foxai/metrics/insertion_deletion_metrics.py

Lines changed: 23 additions & 11 deletions

@@ -39,21 +39,25 @@ def _metric_calculation(
     transformed_img: torch.Tensor,
     model: ModelType,
     chosen_class: int,
-    steps_num: int = 30,
-    metric_type: Metrics = Metrics.INSERTION,
+    steps_num=30,
+    metric_type=Metrics.INSERTION,
+    kernel=(101, 101),
 ) -> Tuple[np.ndarray, List]:
-    """Calculate metric (insertion or deletion) given importance map, image, model and chosen class.
+    """Calculate metric (insertion or deletion) given importance map, image, model and chosen class.
+    The implementation of both metrics (insertion and deletion) is inspired by the paper
+    "RISE: Randomized Input Sampling for Explanation of Black-box Models": https://arxiv.org/abs/1806.07421

     Args:
         attributions: Torch Tensor corresponding to importance map.
         transformed_img: Torch Tensor corresponding to image.
         model: model which we are explaining.
         chosen_class: index of the class we are creating metric for.
         metric_type: type of metric presented using enum, supported ones are: Insertion and Deletion.
+        kernel: tuple defining the size of the blurring kernel. The default of (101, 101) produces a heavily blurred image.

     Returns:
         metric: numerical value of chosen metric for given picture and explanation.
-        importance_lst: list of numpy elements corresponding to confidence value at each step.
+        importance_list: list of (attribution threshold, confidence) tuples recorded at each step.

     Raises:
         AttributeError: if metric type is not enum of Metrics.INSERTION or Metrics.DELETION
@@ -74,7 +78,7 @@ def _metric_calculation(
     sorted_attrs: np.ndarray = np.flip(np.sort(np.unique(preprocessed_attrs)))
     stepped_attrs: np.ndarray = _get_stepped_attrs(sorted_attrs, steps_num)

-    importance_lst: List[np.ndarray] = []
+    importance_list: List[Tuple[float, float]] = []

     cuda = next(model.parameters()).is_cuda
     device = torch.device("cuda" if cuda else "cpu")
@@ -83,7 +87,7 @@ def _metric_calculation(
     removed_img_part[:] = transformed_img.mean()

     if metric_type == Metrics.INSERTION:
-        removed_img_part = gaussian_blur(transformed_img, (101, 101))
+        removed_img_part = gaussian_blur(transformed_img, kernel)

     for val in stepped_attrs:
         attributes_map_np: np.ndarray = np.expand_dims(
@@ -115,11 +119,17 @@ def _metric_calculation(

         output = model(perturbed_img.unsqueeze(dim=0))
         softmax_output: torch.Tensor = torch.nn.functional.softmax(output)[0]
-        importance_lst.append(softmax_output[chosen_class].detach().numpy())
+        importance_val: float = float(
+            softmax_output[chosen_class].detach().cpu().numpy()
+        )
+        importance_list.append((val, importance_val))

-    metric: np.ndarray = np.round(np.trapz(importance_lst) / len(importance_lst), 4)
+    # integrate over the confidence values (second tuple element)
+    confidence_values: List[float] = [elem[1] for elem in importance_list]
+    metric: np.ndarray = np.round(
+        np.trapz(confidence_values) / len(confidence_values), 4
+    )

-    return metric, importance_lst
+    return metric, importance_list


 def deletion(
@@ -138,7 +148,7 @@ def deletion(

     Returns:
         metric: numerical value of chosen metric for given picture and explanation.
-        importance_lst: list of numpy elements corresponding to confidence value at each step.
+        importance_list: list of (attribution threshold, confidence) tuples recorded at each step.

     Raises:
         AttributeError: if metric type is not enum of Metrics.INSERTION or Metrics.DELETION
@@ -153,6 +163,7 @@ def insertion(
     transformed_img: torch.Tensor,
     model: ModelType,
     chosen_class: int,
+    kernel=(101, 101),
 ) -> Tuple[np.ndarray, List]:
     """Calculate insertion metric given importance map, image, model and chosen class.

@@ -164,7 +175,7 @@ def insertion(

     Returns:
         metric: numerical value of chosen metric for given picture and explanation.
-        importance_lst: list of numpy elements corresponding to confidence value at each step.
+        importance_list: list of (attribution threshold, confidence) tuples recorded at each step.

     Raises:
         AttributeError: if metric type is not enum of Metrics.INSERTION or Metrics.DELETION
@@ -175,4 +186,5 @@ def insertion(
         model,
         chosen_class,
         metric_type=Metrics.INSERTION,
+        kernel=kernel,
     )
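
The returned metric is the normalized trapezoidal area under the recorded confidence curve. A minimal sketch of that final computation, with hypothetical confidence values:

import numpy as np

# hypothetical per-step confidences of the chosen class
confidences = [0.10, 0.35, 0.60, 0.80]

# same arithmetic as the final step of _metric_calculation
metric = np.round(np.trapz(confidences) / len(confidences), 4)
print(metric)  # 0.35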

foxai/visualizer.py

Lines changed: 16 additions & 4 deletions

@@ -244,17 +244,29 @@ def single_channel_visualization(


 def visualize_metric(
-    importance_lst: List[np.ndarray],
+    importance_list: List[Tuple[float, float]],
     metric_result: float,
     metric_type: str = "Deletion",
 ):
     """
-    Visualize graph for Insertion or deletion metric based on which area under the curve is caluclated.
+    Visualize the chart for the Insertion or Deletion metric, whose value is the area under this curve.
+
+    Args:
+        importance_list: List of intermediate results used for calculating metrics like deletion or insertion.
+        metric_result: Overall result of the metric.
+        metric_type: String name of the visualized metric (currently supported are Deletion and Insertion).
     """
+    x_vals = [elem[0] for elem in importance_list]
+    y_vals = [elem[1] for elem in importance_list]
     plt.ylim((0, 1))
-    plt.xlim((0, len(importance_lst)))
-    plt.plot(np.arange(len(importance_lst)), importance_lst)
+    plt.xlim((0, 1))
+    plt.plot(x_vals, y_vals)
     plt.title(f"{metric_type}: {metric_result}")
+    plt.ylabel("Probability of predicting chosen class")
+    if metric_type == "Deletion":
+        plt.xlabel("Percentage of image removed")
+    else:
+        plt.xlabel("Percentage of image revealed")
     plt.show()
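
For reference, a standalone call of the updated function with made-up (x, confidence) pairs standing in for real deletion() output:

from foxai.visualizer import visualize_metric

# hypothetical (x, confidence) pairs; real ones come from insertion()/deletion()
importance_list = [(0.0, 0.92), (0.25, 0.71), (0.5, 0.40), (0.75, 0.18), (1.0, 0.05)]
visualize_metric(importance_list, metric_result=0.45, metric_type="Deletion")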
