|
4 | 4 | "cell_type": "markdown", |
5 | 5 | "metadata": {}, |
6 | 6 | "source": [ |
7 | | - "# Example notebook for Image Classification (with imagenet-mini)" |
| 7 | + "# Example notebook for the Deletion/Insertion metrics" |
8 | 8 | ] |
9 | 9 | }, |
10 | 10 | { |
11 | 11 | "cell_type": "markdown", |
12 | 12 | "metadata": {}, |
13 | 13 | "source": [ |
14 | | - "In this tutorial, You will see how You can use explainable algorithms to study pre-trained model decision. We will take a pre-trained model, sample images and run several explainable methods." |
| 14 | + "In this tutorial, you will learn how to set up and use the insertion and deletion metrics for XAI. " |
15 | 15 | ] |
16 | 16 | }, |
17 | 17 | { |
|
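> For context on what these metrics measure: deletion progressively removes the pixels an attribution map ranks as most important and tracks how fast the model's confidence drops, while insertion does the reverse, revealing pixels onto a blank or blurred canvas. A minimal sketch of the deletion side, assuming a `(C, H, W)` image tensor and an `(H, W)` attribution map (`deletion_curve` is a hypothetical helper, not the foxai implementation):

```python
import torch

def deletion_curve(model, image, attribution, target, steps=20):
    # Hypothetical sketch of a deletion metric, not foxai's implementation.
    # image: (C, H, W); attribution: (H, W) importance map
    order = attribution.flatten().argsort(descending=True)
    img = image.clone()
    flat = img.view(img.shape[0], -1)   # view shares storage with img
    n = order.numel()
    probs = []
    for step in range(steps + 1):
        with torch.no_grad():
            p = torch.softmax(model(img.unsqueeze(0)), dim=-1)[0, target]
        probs.append(p.item())
        # zero out the next chunk of most-important pixels across all channels
        lo, hi = step * n // steps, (step + 1) * n // steps
        flat[:, order[lo:hi]] = 0.0
    return probs  # model confidence after each removal step
```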
273 | 273 | "%matplotlib inline" |
274 | 274 | ] |
275 | 275 | }, |
276 | | - { |
277 | | - "cell_type": "markdown", |
278 | | - "metadata": {}, |
279 | | - "source": [ |
280 | | - "Let's see how images from `ImageNet-Mini` looks like. We will display first few samples of dataset. In the following steps we will use them to explain model predictions using different explainable algorithms." |
281 | | - ] |
282 | | - }, |
283 | 276 | { |
284 | 277 | "cell_type": "markdown", |
285 | 278 | "metadata": { |
286 | 279 | "tags": [] |
287 | 280 | }, |
288 | 281 | "source": [ |
289 | | - "## Demo for general algorithms " |
| 282 | + "## Demo of the Insertion and Deletion metrics for GradCAM explanations" |
290 | 283 | ] |
291 | 284 | }, |
292 | 285 | { |
|
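> Since the demo evaluates GradCAM attributions, a quick reminder of what GradCAM computes may help: it weights a convolutional layer's feature maps by the spatially averaged gradient of the target logit and applies a ReLU to the weighted sum. A hand-rolled sketch for illustration only; the notebook itself uses foxai's `CV_LAYER_GRADCAM_EXPLAINER`:

```python
import torch
import torch.nn.functional as F

def gradcam(model, layer, x, target):
    # Illustrative Grad-CAM sketch, not foxai's implementation.
    feats, grads = [], []
    fh = layer.register_forward_hook(lambda m, i, o: feats.append(o))
    bh = layer.register_full_backward_hook(lambda m, gi, go: grads.append(go[0]))
    try:
        logits = model(x)             # x: (1, C, H, W)
        logits[0, target].backward()  # gradients of the target logit
    finally:
        fh.remove()
        bh.remove()
    w = grads[0].mean(dim=(2, 3), keepdim=True)  # (1, K, 1, 1) channel weights
    cam = F.relu((w * feats[0]).sum(dim=1))      # (1, h, w) localization map
    return cam / (cam.max() + 1e-8)              # normalize to [0, 1]
```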
295 | 288 | "metadata": {}, |
296 | 289 | "outputs": [], |
297 | 290 | "source": [ |
298 | | - "# define list of explainers we want to use\n", |
299 | | - "# full list of supported explainers is present in `Explainers` enum class.\n", |
300 | | - "explainer_list = [\n", |
301 | | - " ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER),\n", |
302 | | - " ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INPUT_X_GRADIENT_EXPLAINER),\n", |
303 | | - " ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_INTEGRATED_GRADIENTS_EXPLAINER),\n", |
304 | | - "]" |
305 | | - ] |
306 | | - }, |
307 | | - { |
308 | | - "cell_type": "markdown", |
309 | | - "metadata": { |
310 | | - "tags": [] |
311 | | - }, |
312 | | - "source": [ |
313 | | - "## Demo for layer specific algorithms" |
314 | | - ] |
315 | | - }, |
316 | | - { |
317 | | - "cell_type": "code", |
318 | | - "execution_count": null, |
319 | | - "metadata": {}, |
320 | | - "outputs": [], |
321 | | - "source": [ |
322 | | - "layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]" |
| 291 | + "from foxai.metrics import insertion, deletion\n", |
| 292 | + "from foxai.visualizer import visualize_metric" |
323 | 293 | ] |
324 | 294 | }, |
325 | 295 | { |
|
328 | 298 | "metadata": {}, |
329 | 299 | "outputs": [], |
330 | 300 | "source": [ |
331 | | - "from foxai.metrics import insertion, deletion\n", |
332 | | - "from foxai.visualizer import visualize_metric" |
| 301 | + "chosen_explainer = CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER" |
333 | 302 | ] |
334 | 303 | }, |
335 | 304 | { |
|
338 | 307 | "metadata": {}, |
339 | 308 | "outputs": [], |
340 | 309 | "source": [ |
341 | | - "type(model)" |
| 310 | + "layer = [module for module in model.modules() if isinstance(module, torch.nn.Conv2d)][-1]" |
342 | 311 | ] |
343 | 312 | }, |
344 | 313 | { |
|
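> Picking the last `Conv2d` works because GradCAM generally gives the most meaningful maps at the deepest convolutional layer, where features are class-discriminative but still retain spatial layout. For a torchvision ResNet-18 (an assumption; not necessarily this notebook's model) the comprehension resolves to a layer you could also address by name:

```python
import torch
import torchvision.models as models

# Structural check only; assumes a ResNet-18, weights don't matter here.
model = models.resnet18()
last_conv = [m for m in model.modules() if isinstance(m, torch.nn.Conv2d)][-1]
assert last_conv is model.layer4[-1].conv2  # same module via named access
```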
347 | 316 | "metadata": {}, |
348 | 317 | "outputs": [], |
349 | 318 | "source": [ |
350 | | - "\n", |
351 | 319 | "# iterate over dataloader\n", |
352 | 320 | "sample_batch = next(iter(val_dataloader))\n", |
353 | 321 | "# iterate over all samples in batch\n", |
354 | 322 | "sample, label = sample_batch[0][0], sample_batch[1][0]\n", |
| 323 | + "sample = sample.to(device)\n", |
355 | 324 | "# add batch size dimension to the data sample\n", |
356 | | - "input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2]).to(device)\n", |
| 325 | + "input_data = sample.reshape(1, sample.shape[0], sample.shape[1], sample.shape[2])\n", |
357 | 326 | "category_name = categories[label.item()]\n", |
358 | 327 | "with FoXaiExplainer(\n", |
359 | 328 | " model=model,\n", |
360 | | - " explainers=[ExplainerWithParams(explainer_name=CVClassificationExplainers.CV_LAYER_GRADCAM_EXPLAINER, layer=layer)],\n", |
| 329 | + " explainers=[ExplainerWithParams(explainer_name=chosen_explainer, layer=layer)],\n", |
361 | 330 | " target=label,\n", |
362 | 331 | ") as xai_model:\n", |
| 332 | + " chosen_explainer_key = f\"{chosen_explainer.name}_0\"\n", |
363 | 333 | " # calculate attributes for every explainer\n", |
364 | 334 | " first_output, attributes_dict = xai_model(input_data)\n", |
365 | | - " value = attributes_dict[\"CV_LAYER_GRADCAM_EXPLAINER\"]\n", |
| 335 | + " value = attributes_dict[chosen_explainer_key]\n", |
366 | 336 | " figure = mean_channels_visualization(attributions=value[0], transformed_img=sample, title= f\"Mean of channels)\")\n", |
367 | 337 | " # display the attribution figure\n", |
368 | 338 | " show_figure(figure) \n", |
369 | 339 | " \n", |
370 | | - " gradcam_maps = attributes_dict[\"CV_LAYER_GRADCAM_EXPLAINER\"]\n", |
| 340 | + " gradcam_maps = attributes_dict[chosen_explainer_key]\n", |
371 | 341 | " value = gradcam_maps[0]\n", |
372 | 342 | " chosen_class = first_output.argmax()\n", |
373 | 343 | " insertion_result, importance_lst = insertion(value, sample, model, chosen_class)\n", |
|
376 | 346 | " visualize_metric(importance_lst, deletion_result, metric_type=\"Deletion\")\n", |
377 | 347 | " \n" |
378 | 348 | ] |
379 | | - }, |
380 | | - { |
381 | | - "cell_type": "code", |
382 | | - "execution_count": null, |
383 | | - "metadata": {}, |
384 | | - "outputs": [], |
385 | | - "source": [] |
386 | 349 | } |
387 | 350 | ], |
388 | 351 | "metadata": { |
|
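> The curves returned by `insertion` and `deletion` are commonly summarized by their area under the curve: higher is better for insertion (confidence recovers quickly) and lower is better for deletion (confidence collapses quickly). Assuming the results are 1-D sequences of confidences, a trapezoidal summary could look like this (`curve_auc` is a hypothetical helper):

```python
import numpy as np

def curve_auc(scores):
    # Trapezoidal area under the curve over the [0, 1] pixel-fraction axis.
    ys = np.asarray(scores, dtype=float)
    xs = np.linspace(0.0, 1.0, len(ys))
    return float(np.sum((ys[1:] + ys[:-1]) * np.diff(xs) / 2.0))
```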
401 | 364 | "name": "python", |
402 | 365 | "nbconvert_exporter": "python", |
403 | 366 | "pygments_lexer": "ipython3", |
404 | | - "version": "3.7.13" |
| 367 | + "version": "3.8.10" |
405 | 368 | }, |
406 | 369 | "vscode": { |
407 | 370 | "interpreter": { |
|