-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added support for BEYOND [ICML-2024] #2489
base: dev_1.19.0
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,163 @@ | ||||||||||||||||||||||||||||||||||||
# MIT License | ||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023 | ||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated | ||||||||||||||||||||||||||||||||||||
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the | ||||||||||||||||||||||||||||||||||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit | ||||||||||||||||||||||||||||||||||||
# persons to whom the Software is furnished to do so, subject to the following conditions: | ||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the | ||||||||||||||||||||||||||||||||||||
# Software. | ||||||||||||||||||||||||||||||||||||
# | ||||||||||||||||||||||||||||||||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | ||||||||||||||||||||||||||||||||||||
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||||||||||||||||||||||||||||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||||||||||||||||||||||||||||||||||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||||||||||||||||||||||||||||||||
# SOFTWARE. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
This module implements the abstract base class for all evasion detectors. | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
from __future__ import absolute_import, division, print_function, unicode_literals, annotations | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import abc | ||||||||||||||||||||||||||||||||||||
Check notice Code scanning / CodeQL Unused import Note
Import of 'abc' is not used.
|
||||||||||||||||||||||||||||||||||||
from typing import Any | ||||||||||||||||||||||||||||||||||||
Check notice Code scanning / CodeQL Unused import Note
Import of 'Any' is not used.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
import numpy as np | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
from art.defences.detector.evasion.evasion_detector import EvasionDetector | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
class BeyondDetector(EvasionDetector): | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
BEYOND detector for adversarial samples detection. | ||||||||||||||||||||||||||||||||||||
This detector uses a combination of SSL and target model predictions to detect adversarial samples. | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Comment on lines
+31
to
+34
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
defence_params = ["target_model", "ssl_model", "augmentations", "aug_num", "alpha", "K", "percentile"] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def __init__(self, | ||||||||||||||||||||||||||||||||||||
target_model, | ||||||||||||||||||||||||||||||||||||
ssl_model, | ||||||||||||||||||||||||||||||||||||
augmentations=None, | ||||||||||||||||||||||||||||||||||||
aug_num=50, | ||||||||||||||||||||||||||||||||||||
alpha=0.8, | ||||||||||||||||||||||||||||||||||||
K=20, | ||||||||||||||||||||||||||||||||||||
percentile=5) -> None: | ||||||||||||||||||||||||||||||||||||
Comment on lines
+39
to
+45
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add typing to all arguments. |
||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Initialize the BEYOND detector. | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
:param target_model: The target model to be protected | ||||||||||||||||||||||||||||||||||||
:param ssl_model: The self-supervised learning model used for feature extraction | ||||||||||||||||||||||||||||||||||||
:param augmentation: data augmentations for generating neighborhoods | ||||||||||||||||||||||||||||||||||||
:param aug_num: Number of augmentations to apply to each sample (default: 50) | ||||||||||||||||||||||||||||||||||||
:param alpha: Weight factor for combining label and representation similarities (default: 0.8) | ||||||||||||||||||||||||||||||||||||
:param K: Number of top similarities to consider (default: 20) | ||||||||||||||||||||||||||||||||||||
:param percentile: using to calculate the threshold | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
super().__init__() | ||||||||||||||||||||||||||||||||||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self.target_model = target_model.to(self.device) | ||||||||||||||||||||||||||||||||||||
self.ssl_model = ssl_model.to(self.device) | ||||||||||||||||||||||||||||||||||||
self.aug_num = aug_num | ||||||||||||||||||||||||||||||||||||
self.alpha = alpha | ||||||||||||||||||||||||||||||||||||
self.K = K | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self.backbone = ssl_model.backbone | ||||||||||||||||||||||||||||||||||||
self.classifier = ssl_model.classifier | ||||||||||||||||||||||||||||||||||||
self.projector = ssl_model.projector | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self.img_augmentations = augmentations | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self.percentile = percentile # determinate the threshold | ||||||||||||||||||||||||||||||||||||
self.threshold = None | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _multi_transform(self, img): | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Pleas add typing to all arguments. |
||||||||||||||||||||||||||||||||||||
return torch.stack([self.img_augmentations(img) for _ in range(self.aug_num)], dim=1) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def _get_metrics(self, x: np.ndarray, batch_size: int = 128) -> tuple[dict, np.ndarray]: | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Calculate similarities that combining label consistency and representation similarity for given samples | ||||||||||||||||||||||||||||||||||||
:param x: Input samples | ||||||||||||||||||||||||||||||||||||
:param batch_size: Batch size for processing | ||||||||||||||||||||||||||||||||||||
:return: A report similarities | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Comment on lines
+80
to
+86
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
samples = torch.from_numpy(x).to(self.device) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
self.target_model.eval() | ||||||||||||||||||||||||||||||||||||
self.backbone.eval() | ||||||||||||||||||||||||||||||||||||
self.classifier.eval() | ||||||||||||||||||||||||||||||||||||
self.projector.eval() | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
number_batch = int(math.ceil(len(samples) / batch_size)) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
similarities = [] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
with torch.no_grad(): | ||||||||||||||||||||||||||||||||||||
for index in range(number_batch): | ||||||||||||||||||||||||||||||||||||
start = index * batch_size | ||||||||||||||||||||||||||||||||||||
end = min((index + 1) * batch_size, len(samples)) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
batch_samples = samples[start:end] | ||||||||||||||||||||||||||||||||||||
b, c, h, w = batch_samples.shape | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
trans_images = self._multi_transform(batch_samples).to(self.device) | ||||||||||||||||||||||||||||||||||||
ssl_backbone_out = self.backbone(batch_samples) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
ssl_repre = self.projector(ssl_backbone_out) | ||||||||||||||||||||||||||||||||||||
ssl_pred = self.classifier(ssl_backbone_out) | ||||||||||||||||||||||||||||||||||||
ssl_label = torch.max(ssl_pred, -1)[1] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
aug_backbone_out = self.backbone(trans_images.reshape(-1, c, h, w)) | ||||||||||||||||||||||||||||||||||||
aug_repre = self.projector(aug_backbone_out) | ||||||||||||||||||||||||||||||||||||
aug_pred = self.classifier(aug_backbone_out) | ||||||||||||||||||||||||||||||||||||
aug_pred = aug_pred.reshape(b, self.aug_num, -1) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
sim_repre = F.cosine_similarity(ssl_repre.unsqueeze(dim=1), aug_repre.reshape(b, self.aug_num, -1), dim=2) | ||||||||||||||||||||||||||||||||||||
sim_preds = F.cosine_similarity(F.one_hot(torch.argmax(ssl_label, dim=1), num_classes=ssl_pred.shape[-1]).unsqueeze(dim=1), aug_pred, dim=2) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
similarities.append((self.alpha * sim_preds + (1-self.alpha)*sim_repre).sort(descending=True)[0].cpu().numpy()) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
similarities = np.concatenate(similarities, axis=0) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return similarities | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def fit(self, x: np.ndarray, y: np.ndarray, batch_size: int = 128, nb_epochs: int = 20, **kwargs) -> None: | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Determine a threshold that covers 95% of clean samples. | ||||||||||||||||||||||||||||||||||||
:param x: Clean sample data | ||||||||||||||||||||||||||||||||||||
:param y: Clean sample labels (not used in this method) | ||||||||||||||||||||||||||||||||||||
:param batch_size: Batch size for processing | ||||||||||||||||||||||||||||||||||||
:param nb_epochs: Number of training epochs (not used in this method) | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
clean_similarities = self._get_metrics(x, batch_size) | ||||||||||||||||||||||||||||||||||||
Check notice Code scanning / CodeQL Unused local variable Note
Variable clean_similarities is not used.
|
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# 使用第K-1列的值来确定阈值 | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a translation in English. |
||||||||||||||||||||||||||||||||||||
k_minus_one_metrics = clean_metrics[:, self.K-1] | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
# 计算95%分位数作为阈值 | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add a translation in English. |
||||||||||||||||||||||||||||||||||||
self.threshold = np.percentile(k_minus_one_metrics, self.threshold) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
print(f"Threshold set to: {self.threshold}") | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please replace |
||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
def detect(self, x: np.ndarray, batch_size: int = 128, **kwargs) -> tuple[dict, np.ndarray]: | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Detect whether given samples are adversarial | ||||||||||||||||||||||||||||||||||||
:param x: Input samples | ||||||||||||||||||||||||||||||||||||
:param batch_size: Batch size for processing | ||||||||||||||||||||||||||||||||||||
:return: (report, is_adversarial): | ||||||||||||||||||||||||||||||||||||
where report containing detection results | ||||||||||||||||||||||||||||||||||||
where is_adversarial is a boolean list indicating whether samples are adversarial or not | ||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||
Comment on lines
+147
to
+154
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||
if self.threshold is None: | ||||||||||||||||||||||||||||||||||||
raise ValueError("Detector has not been fitted. Call fit() before detect().") | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
similarities = self._get_metrics(x, batch_size) | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
report = similarities[:, self.K-1] | ||||||||||||||||||||||||||||||||||||
is_adversarial = report < self.threshold | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
return report, is_adversarial | ||||||||||||||||||||||||||||||||||||
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -0,0 +1,180 @@ | ||||||
# MIT License | ||||||
# | ||||||
# Copyright (C) The Adversarial Robustness Toolbox (ART) Authors 2023 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
# | ||||||
# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated | ||||||
# documentation files (the "Software"), to deal in the Software without restriction, including without limitation the | ||||||
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit | ||||||
# persons to whom the Software is furnished to do so, subject to the following conditions: | ||||||
# | ||||||
# The above copyright notice and this permission notice shall be included in all copies or substantial portions of the | ||||||
# Software. | ||||||
# | ||||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE | ||||||
# WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, | ||||||
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||||||
# SOFTWARE. | ||||||
from __future__ import absolute_import, division, print_function, unicode_literals | ||||||
|
||||||
import logging | ||||||
import pytest | ||||||
import numpy as np | ||||||
|
||||||
import sys | ||||||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'sys' is not used.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
import os | ||||||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'os' is not used.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
from art.attacks.evasion.fast_gradient import FastGradientMethod | ||||||
from art.estimators.classification import PyTorchClassifier | ||||||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'PyTorchClassifier' is not used.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
from art.defences.detector.evasion import BeyondDetector | ||||||
from art.utils import load_dataset, get_file | ||||||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'load_dataset' is not used.
Import of 'get_file' is not used. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
|
||||||
from tests.utils import ARTTestException | ||||||
|
||||||
logger = logging.getLogger(__name__) | ||||||
|
||||||
import torch.nn as nn | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please move third-part imports before the imports from |
||||||
from torchvision import models, transforms | ||||||
|
||||||
class SimSiamWithCls(nn.Module): | ||||||
''' | ||||||
SimSiam with Classifier | ||||||
''' | ||||||
def __init__(self, arch='resnet18', feat_dim=2048, num_proj_layers=2): | ||||||
|
||||||
super(SimSiamWithCls, self).__init__() | ||||||
self.backbone = models.resnet18() | ||||||
out_dim = self.backbone.fc.weight.shape[1] | ||||||
self.backbone.conv1 = nn.Conv2d( | ||||||
3, 64, kernel_size=3, stride=1, padding=2, bias=False | ||||||
) | ||||||
self.backbone.maxpool = nn.Identity() | ||||||
self.backbone.fc = nn.Identity() | ||||||
self.classifier = nn.Linear(out_dim, 10) | ||||||
|
||||||
pred_hidden_dim = int(feat_dim / 4) | ||||||
|
||||||
self.projector = nn.Sequential( | ||||||
nn.Linear(out_dim, feat_dim, bias=False), | ||||||
nn.BatchNorm1d(feat_dim), | ||||||
nn.ReLU(), | ||||||
nn.Linear(feat_dim, feat_dim, bias=False), | ||||||
nn.BatchNorm1d(feat_dim), | ||||||
nn.ReLU(), | ||||||
nn.Linear(feat_dim, feat_dim), | ||||||
nn.BatchNorm1d(feat_dim, affine=False), | ||||||
) | ||||||
self.projector[6].bias.requires_grad = False | ||||||
|
||||||
self.predictor = nn.Sequential( | ||||||
nn.Linear(feat_dim, pred_hidden_dim, bias=False), | ||||||
nn.BatchNorm1d(pred_hidden_dim), | ||||||
nn.ReLU(), | ||||||
nn.Linear(pred_hidden_dim, feat_dim), | ||||||
) | ||||||
|
||||||
def forward(self, img, im_aug1=None, im_aug2=None): | ||||||
|
||||||
r_ori = self.backbone(img) | ||||||
if im_aug1 is None and im_aug2 is None: | ||||||
cls = self.classifier(r_ori) | ||||||
rep = self.projector(r_ori) | ||||||
return {'cls': cls, 'rep':rep} | ||||||
else: | ||||||
|
||||||
r1 = self.backbone(im_aug1) | ||||||
r2 = self.backbone(im_aug2) | ||||||
|
||||||
z1 = self.projector(r1) | ||||||
z2 = self.projector(r2) | ||||||
# print("shape of z:", z1.shape) | ||||||
|
||||||
p1 = self.predictor(z1) | ||||||
p2 = self.predictor(z2) | ||||||
# print("shape of p:", p1.shape) | ||||||
|
||||||
return {'z1': z1, 'z2': z2, 'p1': p1, 'p2': p2} | ||||||
|
||||||
@pytest.fixture | ||||||
def get_cifar10(): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we can use |
||||||
""" | ||||||
Loads CIFAR10 dataset. | ||||||
""" | ||||||
(x_train, y_train), (x_test, y_test), min_, max_ = load_cifar10() | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add import for fixture |
||||||
return (x_train, y_train), (x_test, y_test), min_, max_ | ||||||
|
||||||
|
||||||
@pytest.fixture | ||||||
def get_ssl_model(weights_path): | ||||||
""" | ||||||
Loads the SSL model (SimSiamWithCls). | ||||||
""" | ||||||
model = SimSiamWithCls() | ||||||
model.load_state_dict(torch.load(weights_path)) | ||||||
return model | ||||||
|
||||||
@pytest.mark.only_with_platform("pytorch") | ||||||
def test_beyond_detector(art_warning, get_cifar10, get_ssl_model): | ||||||
try: | ||||||
# Load CIFAR10 data | ||||||
(x_train, y_train), (x_test, y_test), min_, max_ = get_cifar10 | ||||||
Check notice Code scanning / CodeQL Unused local variable Note test
Variable min_ is not used.
Check notice Code scanning / CodeQL Unused local variable Note test
Variable max_ is not used.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
# Load models | ||||||
# Download pretrained weights from https://drive.google.com/drive/folders/1ieEdd7hOj2CIl1FQfu4-3RGZmEj-mesi?usp=sharing | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How large are the downloaded files? Can we store them in the ART repo? |
||||||
target_model = models.resnet18() | ||||||
target_model.load_state_dict(torch.load("./resnet_c10.pth")) | ||||||
ssl_model = get_ssl_model() | ||||||
ssl_model.load_state_dict(torch.load("./simsiam_c10.pth")) | ||||||
|
||||||
|
||||||
# Generate adversarial samples | ||||||
attack = FastGradientMethod(estimator=target_model, eps=0.05) | ||||||
x_test_adv = attack.generate(x_test) | ||||||
|
||||||
img_augmentations = transforms.Compose([ | ||||||
transforms.RandomResizedCrop(32, scale=(0.2, 1.)), | ||||||
transforms.RandomHorizontalFlip(), | ||||||
transforms.RandomApply([ | ||||||
transforms.ColorJitter(0.4, 0.4, 0.4, 0.1) # not strengthened | ||||||
], p=0.8), | ||||||
transforms.RandomGrayscale(p=0.2), | ||||||
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) | ||||||
]) | ||||||
|
||||||
# Initialize BeyondDetector | ||||||
detector = BeyondDetector( | ||||||
target_model=target_model, | ||||||
ssl_model=ssl_model, | ||||||
img_augmentation=img_augmentations, | ||||||
aug_num=50, | ||||||
alpha=0.8, | ||||||
K=20, | ||||||
percentile=5 | ||||||
) | ||||||
Comment on lines
+146
to
+154
Check failure Code scanning / CodeQL Wrong name for an argument in a class instantiation Error test
Keyword argument 'img_augmentation' is not a supported parameter name of
BeyondDetector.__init__ Error loading related location Loading |
||||||
|
||||||
# Fit the detector | ||||||
detector.fit(x_train, y_train, batch_size=128) | ||||||
|
||||||
# Apply detector on clean and adversarial test data | ||||||
_, test_detection = detector.detect(x_test) | ||||||
_, test_adv_detection = detector.detect(x_test_adv) | ||||||
|
||||||
# Assert there is at least one true positive and negative | ||||||
nb_true_positives = np.sum(test_adv_detection) | ||||||
nb_true_negatives = len(test_detection) - np.sum(test_detection) | ||||||
assert nb_true_positives > 0 | ||||||
assert nb_true_negatives > 0 | ||||||
Comment on lines
+166
to
+167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be possible to make these assertions more accurate? |
||||||
|
||||||
# Calculate and print detection accuracy | ||||||
clean_accuracy = 1 - np.mean(test_detection) | ||||||
adv_accuracy = np.mean(test_adv_detection) | ||||||
print(f"Clean Detection Accuracy: {clean_accuracy:.4f}") | ||||||
print(f"Adversarial Detection Accuracy: {adv_accuracy:.4f}") | ||||||
Comment on lines
+172
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please replace |
||||||
|
||||||
except ARTTestException as e: | ||||||
art_warning(e) | ||||||
|
||||||
if __name__ == "__main__": | ||||||
|
||||||
test_beyond_detector() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.