Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for BEYOND [ICML-2024] #2489

Open
wants to merge 1 commit into
base: dev_1.19.0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions art/defences/detector/evasion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,5 @@
from art.defences.detector.evasion.binary_input_detector import BinaryInputDetector
from art.defences.detector.evasion.binary_activation_detector import BinaryActivationDetector
from art.defences.detector.evasion.subsetscanning.detector import SubsetScanningDetector
from art.defences.detector.evasion.beyond_detector import BeyondDetector

163 changes: 163 additions & 0 deletions art/defences/detector/evasion/beyond_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024

#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""
This module implements the abstract base class for all evasion detectors.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
This module implements the abstract base class for all evasion detectors.
This module implements the BEYOND detector for adversarial examples detection.
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3

"""
from __future__ import absolute_import, division, print_function, unicode_literals, annotations

import abc

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'abc' is not used.
from typing import Any

Check notice

Code scanning / CodeQL

Unused import Note

Import of 'Any' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from typing import Any

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from typing import Any


import numpy as np

from art.defences.detector.evasion.evasion_detector import EvasionDetector

class BeyondDetector(EvasionDetector):
"""
BEYOND detector for adversarial samples detection.
This detector uses a combination of SSL and target model predictions to detect adversarial samples.
"""
Comment on lines +31 to +34
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
BEYOND detector for adversarial samples detection.
This detector uses a combination of SSL and target model predictions to detect adversarial samples.
"""
"""
BEYOND detector for adversarial samples detection.
This detector uses a combination of SSL and target model predictions to detect adversarial examples.
| Paper link: https://openreview.net/pdf?id=S4LqI6CcJ3
"""


defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"]

def __init__(self,
target_model,
ssl_model,
augmentations=None,
aug_num=50,
alpha=0.8,
K=20,
percentile=5) -> None:
Comment on lines +39 to +45
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add typing to all arguments.

"""
Initialize the BEYOND detector.

:param target_model: The target model to be protected
:param ssl_model: The self-supervised learning model used for feature extraction
:param augmentation: data augmentations for generating neighborhoods
:param aug_num: Number of augmentations to apply to each sample (default: 50)
:param alpha: Weight factor for combining label and representation similarities (default: 0.8)
:param K: Number of top similarities to consider (default: 20)
:param percentile: using to calculate the threshold
"""
super().__init__()
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Check warning on line 58 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L57-L58

Added lines #L57 - L58 were not covered by tests

self.target_model = target_model.to(self.device)
self.ssl_model = ssl_model.to(self.device)
self.aug_num = aug_num
self.alpha = alpha
self.K = K

Check warning on line 64 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L60-L64

Added lines #L60 - L64 were not covered by tests

self.backbone = ssl_model.backbone
self.classifier = ssl_model.classifier
self.projector = ssl_model.projector

Check warning on line 68 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L66-L68

Added lines #L66 - L68 were not covered by tests

self.img_augmentations = augmentations

Check warning on line 70 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L70

Added line #L70 was not covered by tests

self.percentile = percentile # determinate the threshold
self.threshold = None

Check warning on line 73 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L72-L73

Added lines #L72 - L73 were not covered by tests



def _multi_transform(self, img):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pleas add typing to all arguments.

return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1)

def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
"""
Calculate similarities that combining label consistency and representation similarity for given samples
:param x: Input samples
:param batch_size: Batch size for processing
:return: A report similarities
"""
Comment on lines +80 to +86
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
"""
Calculate similarities that combining label consistency and representation similarity for given samples
:param x: Input samples
:param batch_size: Batch size for processing
:return: A report similarities
"""
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]:
"""
Calculate similarities that combining label consistency and representation similarity for given samples
:param x: Input samples
:param batch_size: Batch size for processing
:return: A report similarities
"""

samples = torch.from_numpy(x).to(self.device)

Check warning on line 87 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L87

Added line #L87 was not covered by tests

self.target_model.eval()
self.backbone.eval()
self.classifier.eval()
self.projector.eval()

Check warning on line 92 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L89-L92

Added lines #L89 - L92 were not covered by tests

number_batch = int(math.ceil(len(samples) / batch_size))

Check warning on line 94 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L94

Added line #L94 was not covered by tests

similarities = []

Check warning on line 96 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L96

Added line #L96 was not covered by tests

with torch.no_grad():
for index in range(number_batch):
start = index * batch_size
end = min((index + 1) * batch_size, len(samples))

Check warning on line 101 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L100-L101

Added lines #L100 - L101 were not covered by tests

batch_samples = samples[start:end]
b, c, h, w = batch_samples.shape

Check warning on line 104 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L103-L104

Added lines #L103 - L104 were not covered by tests

trans_images = self._multi_transform(batch_samples).to(self.device)
ssl_backbone_out = self.backbone(batch_samples)

Check warning on line 107 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L106-L107

Added lines #L106 - L107 were not covered by tests

ssl_repre = self.projector(ssl_backbone_out)
ssl_pred = self.classifier(ssl_backbone_out)
ssl_label = torch.max(ssl_pred, -1)[1]

Check warning on line 111 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L109-L111

Added lines #L109 - L111 were not covered by tests

aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w))
aug_repre = self.projector(aug_backbone_out)
aug_pred = self.classifier(aug_backbone_out)
aug_pred = aug_pred.reshape(b, self.aug_num, -1)

Check warning on line 116 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L113-L116

Added lines #L113 - L116 were not covered by tests

sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2)
sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2)

Check warning on line 119 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L118-L119

Added lines #L118 - L119 were not covered by tests

similarities.append((self.alpha * sim_preds + (1-self.alpha)*sim_repre).sort(descending=True)[0].cpu().numpy())

Check warning on line 121 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L121

Added line #L121 was not covered by tests

similarities = np.concatenate(similarities, axis=0)

Check warning on line 123 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L123

Added line #L123 was not covered by tests

return similarities

Check warning on line 125 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L125

Added line #L125 was not covered by tests


def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None:
"""
Determine a threshold that covers 95% of clean samples.
:param x: Clean sample data
:param y: Clean sample labels (not used in this method)
:param batch_size: Batch size for processing
:param nb_epochs: Number of training epochs (not used in this method)
"""
clean_similarities = self._get_metrics(x, batch_size)

Check warning on line 136 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L136

Added line #L136 was not covered by tests

Check notice

Code scanning / CodeQL

Unused local variable Note

Variable clean_similarities is not used.

# 使用第K-1列的值来确定阈值
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a translation in English.

k_minus_one_metrics = clean_metrics[:, self.K-1]

Check warning on line 139 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L139

Added line #L139 was not covered by tests

# 计算95%分位数作为阈值
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a translation in English.

self.threshold = np.percentile(k_minus_one_metrics, self.threshold)

Check warning on line 142 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L142

Added line #L142 was not covered by tests

print(f"Threshold set to: {self.threshold}")

Check warning on line 144 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L144

Added line #L144 was not covered by tests
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace print with ART's logger.


def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]:
"""
Detect whether given samples are adversarial
:param x: Input samples
:param batch_size: Batch size for processing
:return: (report, is_adversarial):
where report containing detection results
where is_adversarial is a boolean list indicating whether samples are adversarial or not
"""
Comment on lines +147 to +154
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"""
Detect whether given samples are adversarial
:param x: Input samples
:param batch_size: Batch size for processing
:return: (report, is_adversarial):
where report containing detection results
where is_adversarial is a boolean list indicating whether samples are adversarial or not
"""
"""
Detect whether given samples are adversarial
:param x: Input samples
:param batch_size: Batch size for processing
:return: (report, is_adversarial):
where report containing detection results
where is_adversarial is a boolean list indicating whether samples are adversarial or not
"""

if self.threshold is None:
raise ValueError("Detector has not been fitted. Call fit() before detect().")

Check warning on line 156 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L156

Added line #L156 was not covered by tests

similarities = self._get_metrics(x, batch_size)

Check warning on line 158 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L158

Added line #L158 was not covered by tests

report = similarities[:, self.K-1]
is_adversarial = report < self.threshold

Check warning on line 161 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L160-L161

Added lines #L160 - L161 were not covered by tests

return report, is_adversarial

Check warning on line 163 in art/defences/detector/evasion/beyond_detector.py

View check run for this annotation

Codecov / codecov/patch

art/defences/detector/evasion/beyond_detector.py#L163

Added line #L163 was not covered by tests
180 changes: 180 additions & 0 deletions tests/defences/detector/evasion/test_beyond_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# MIT License
#
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2024

#
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the
# Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import pytest
import numpy as np

import sys

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'sys' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import sys

import os

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'os' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
import os


from art.attacks.evasion.fast_gradient import FastGradientMethod
from art.estimators.classification import PyTorchClassifier

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'PyTorchClassifier' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from art.estimators.classification import PyTorchClassifier

from art.defences.detector.evasion import BeyondDetector
from art.utils import load_dataset, get_file

Check notice

Code scanning / CodeQL

Unused import Note test

Import of 'load_dataset' is not used.
Import of 'get_file' is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from art.utils import load_dataset, get_file



from tests.utils import ARTTestException

logger = logging.getLogger(__name__)

import torch.nn as nn
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please move third-part imports before the imports from tests and art.

from torchvision import models, transforms

class SimSiamWithCls(nn.Module):
'''
SimSiam with Classifier
'''
def __init__(self, arch='resnet18', feat_dim=2048, num_proj_layers=2):

super(SimSiamWithCls, self).__init__()
self.backbone = models.resnet18()
out_dim = self.backbone.fc.weight.shape[1]
self.backbone.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=2, bias=False
)
self.backbone.maxpool = nn.Identity()
self.backbone.fc = nn.Identity()
self.classifier = nn.Linear(out_dim, 10)

pred_hidden_dim = int(feat_dim / 4)

self.projector = nn.Sequential(
nn.Linear(out_dim, feat_dim, bias=False),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
nn.Linear(feat_dim, feat_dim, bias=False),
nn.BatchNorm1d(feat_dim),
nn.ReLU(),
nn.Linear(feat_dim, feat_dim),
nn.BatchNorm1d(feat_dim, affine=False),
)
self.projector[6].bias.requires_grad = False

self.predictor = nn.Sequential(
nn.Linear(feat_dim, pred_hidden_dim, bias=False),
nn.BatchNorm1d(pred_hidden_dim),
nn.ReLU(),
nn.Linear(pred_hidden_dim, feat_dim),
)

def forward(self, img, im_aug1=None, im_aug2=None):

r_ori = self.backbone(img)
if im_aug1 is None and im_aug2 is None:
cls = self.classifier(r_ori)
rep = self.projector(r_ori)
return {'cls': cls, 'rep':rep}
else:

r1 = self.backbone(im_aug1)
r2 = self.backbone(im_aug2)

z1 = self.projector(r1)
z2 = self.projector(r2)
# print("shape of z:", z1.shape)

p1 = self.predictor(z1)
p2 = self.predictor(z2)
# print("shape of p:", p1.shape)

return {'z1': z1, 'z2': z2, 'p1': p1, 'p2': p2}

@pytest.fixture
def get_cifar10():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we can use load_cifar10() directly and remove get_cifar10()(.

"""
Loads CIFAR10 dataset.
"""
(x_train, y_train), (x_test, y_test), min_, max_ = load_cifar10()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add import for fixture load_cifar10().

return (x_train, y_train), (x_test, y_test), min_, max_


@pytest.fixture
def get_ssl_model(weights_path):
"""
Loads the SSL model (SimSiamWithCls).
"""
model = SimSiamWithCls()
model.load_state_dict(torch.load(weights_path))
return model

@pytest.mark.only_with_platform("pytorch")
def test_beyond_detector(art_warning, get_cifar10, get_ssl_model):
try:
# Load CIFAR10 data
(x_train, y_train), (x_test, y_test), min_, max_ = get_cifar10

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable min_ is not used.

Check notice

Code scanning / CodeQL

Unused local variable Note test

Variable max_ is not used.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
(x_train, y_train), (x_test, y_test), min_, max_ = get_cifar10
(x_train, y_train), (x_test, y_test), _, _ = get_cifar10


# Load models
# Download pretrained weights from https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How large are the downloaded files? Can we store them in the ART repo?

target_model = models.resnet18()
target_model.load_state_dict(torch.load("./resnet_c10.pth"))
ssl_model = get_ssl_model()
ssl_model.load_state_dict(torch.load("./simsiam_c10.pth"))


# Generate adversarial samples
attack = FastGradientMethod(estimator=target_model, eps=0.05)
x_test_adv = attack.generate(x_test)

img_augmentations = transforms.Compose([
transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
transforms.RandomHorizontalFlip(),
transforms.RandomApply([
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened
], p=0.8),
transforms.RandomGrayscale(p=0.2),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])

# Initialize BeyondDetector
detector = BeyondDetector(
target_model=target_model,
ssl_model=ssl_model,
img_augmentation=img_augmentations,
aug_num=50,
alpha=0.8,
K=20,
percentile=5
)
Comment on lines +146 to +154

Check failure

Code scanning / CodeQL

Wrong name for an argument in a class instantiation Error test

Keyword argument 'img_augmentation' is not a supported parameter name of
BeyondDetector.__init__
.

# Fit the detector
detector.fit(x_train, y_train, batch_size=128)

# Apply detector on clean and adversarial test data
_, test_detection = detector.detect(x_test)
_, test_adv_detection = detector.detect(x_test_adv)

# Assert there is at least one true positive and negative
nb_true_positives = np.sum(test_adv_detection)
nb_true_negatives = len(test_detection) - np.sum(test_detection)
assert nb_true_positives > 0
assert nb_true_negatives > 0
Comment on lines +166 to +167
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to make these assertions more accurate?


# Calculate and print detection accuracy
clean_accuracy = 1 - np.mean(test_detection)
adv_accuracy = np.mean(test_adv_detection)
print(f"Clean Detection Accuracy: {clean_accuracy:.4f}")
print(f"Adversarial Detection Accuracy: {adv_accuracy:.4f}")
Comment on lines +172 to +173
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please replace print with logger.


except ARTTestException as e:
art_warning(e)

if __name__ == "__main__":

test_beyond_detector()
Loading