Skip to content

Commit

Permalink
Merge pull request #41 from salesforce/fft
Browse files Browse the repository at this point in the history
Fft
  • Loading branch information
yangwenz authored Sep 16, 2022
2 parents 6b39ea7 + d56f057 commit 2adc636
Show file tree
Hide file tree
Showing 15 changed files with 366 additions and 54 deletions.
17 changes: 14 additions & 3 deletions omnixai/explainers/nlp/specific/ig.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ def __init__(
preprocess_function: Callable,
mode: str = "classification",
id2token: Dict = None,
tokenizer: Callable = None,
**kwargs,
):
"""
Expand All @@ -196,7 +197,8 @@ def __init__(
into the inputs of ``model``. The first output of ``preprocess_function`` must
be the token ids.
:param mode: The task type, e.g., `classification` or `regression`.
:param id2token: The mapping from token ids to tokens.
:param id2token: The mapping from token ids to tokens. If `tokenizer` is set, `id2token` will be ignored.
:param tokenizer: The tokenizer for processing text inputs, i.e., tokenizers in HuggingFace.
"""
super().__init__()
assert preprocess_function is not None, (
Expand All @@ -207,6 +209,7 @@ def __init__(
self.embedding_layer = embedding_layer
self.preprocess_function = preprocess_function
self.id2token = id2token
self.tokenizer = tokenizer

ig_class = None
if is_torch_available():
Expand Down Expand Up @@ -293,11 +296,19 @@ def explain(self, X: Text, y=None, **kwargs) -> WordImportance:
steps=steps,
batch_size=batch_size
)
tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" else inputs[0].numpy()
tokens = inputs[0].detach().cpu().numpy() if self.model_type == "torch" \
else inputs[0].numpy()

if self.tokenizer is not None:
input_tokens = [self.tokenizer.decode([t]) for t in tokens[0]]
elif self.id2token is not None:
input_tokens = [self.id2token[t] for t in tokens[0]]
else:
input_tokens = tokens[0]
explanations.add(
instance=instance.to_str(),
target_label=y[i] if y is not None else None,
tokens=tokens[0] if self.id2token is None else [self.id2token[t] for t in tokens[0]],
tokens=input_tokens,
importance_scores=scores,
)
return explanations
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from typing import Union, List
from collections import defaultdict
from ..utils import Objective, FeatureOptimizerMixin
from ..utils import fft_inputs, fft_scale
from .preprocess import fft_images


class FeatureOptimizer(FeatureOptimizerMixin):
Expand Down Expand Up @@ -171,6 +173,8 @@ def optimize(
value_normalizer="sigmoid",
value_range=(0.05, 0.95),
init_std=0.01,
use_fft=False,
fft_decay=1.0,
normal_color=False,
save_all_images=False,
verbose=True,
Expand All @@ -186,16 +190,36 @@ def optimize(
if not isinstance(regularizers, list):
regularizers = [regularizers]
regularizers = [self._regularize(reg, w) for reg, w in regularizers]
if use_fft:
# Using "normal color" for FFT preconditioning
normal_color = True

device = next(self.model.parameters()).device
inputs = torch.tensor(
np.random.randn(*(self.num_combinations, 3, *image_shape)) * init_std,
dtype=torch.float32,
requires_grad=True,
device=device
)
shape = (self.num_combinations, 3, *image_shape)
if not use_fft:
inputs = torch.tensor(
np.random.randn(*shape) * init_std,
dtype=torch.float32,
requires_grad=True,
device=device
)
normalize = lambda x: self._normalize(
x, value_normalizer, value_range, normal_color)
else:
inputs = torch.tensor(
fft_inputs(*shape, mode="torch", std=init_std),
dtype=torch.float32,
requires_grad=True,
device=device
)
scales = fft_scale(
image_shape[0], image_shape[1], mode="torch", decay_power=fft_decay)
scales = torch.tensor(scales, dtype=torch.complex64, device=device)
normalize = lambda x: self._normalize(
fft_images(image_shape[0], image_shape[1], inputs, scales),
value_normalizer, value_range, normal_color
)
optimizer = torch.optim.Adam([inputs], lr=learning_rate)
normalize = lambda x: self._normalize(x, value_normalizer, value_range, normal_color)

results = []
for i in range(num_iterations):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#
import torch
import torchvision
from packaging import version
from omnixai.preprocessing.base import TransformBase


Expand Down Expand Up @@ -121,3 +122,14 @@ def transform(self, x):

def invert(self, x):
raise RuntimeError("`Padding` doesn't support the `invert` function.")


def fft_images(width, height, inputs, scale):
spectrum = torch.complex(inputs[0], inputs[1]) * scale[None, None, :, :]
# Torch 1.7
if version.parse(torch.__version__) < version.parse("1.8"):
x = torch.cat([spectrum.real.unsqueeze(dim=-1), spectrum.imag.unsqueeze(dim=-1)], dim=-1)
image = torch.irfft(x, signal_ndim=2, normalized=False, onesided=False)
else:
image = torch.fft.ifft2(spectrum)
return image / 4.0
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import tensorflow as tf
from typing import Union, List
from ..utils import Objective, FeatureOptimizerMixin
from ..utils import fft_inputs, fft_scale
from .preprocess import fft_images


class FeatureOptimizer(FeatureOptimizerMixin):
Expand Down Expand Up @@ -195,6 +197,8 @@ def optimize(
value_normalizer="sigmoid",
value_range=(0.05, 0.95),
init_std=0.01,
use_fft=False,
fft_decay=1.0,
normal_color=False,
save_all_images=False,
verbose=True,
Expand All @@ -212,10 +216,25 @@ def optimize(
if not isinstance(regularizers, list):
regularizers = [regularizers]
regularizers = [self._regularize(reg, w) for reg, w in regularizers]
if use_fft:
# Using "normal color" for FFT preconditioning
normal_color = True

inputs = tf.Variable(
tf.random.normal(shape, stddev=init_std, dtype=tf.float32), trainable=True)
normalize = lambda x: self._normalize(x, value_normalizer, value_range, normal_color)
if not use_fft:
inputs = tf.Variable(
tf.random.normal(shape, stddev=init_std, dtype=tf.float32), trainable=True)
normalize = lambda x: self._normalize(
x, value_normalizer, value_range, normal_color)
else:
inputs = tf.Variable(
fft_inputs(shape[0], shape[3], shape[1], shape[2], mode="tf", std=init_std),
trainable=True)
scales = fft_scale(shape[1], shape[2], mode="tf", decay_power=fft_decay)
scales = tf.convert_to_tensor(scales, dtype=tf.complex64)
normalize = lambda x: self._normalize(
fft_images(shape[1], shape[2], inputs, scales),
value_normalizer, value_range, normal_color
)
optimizer = tf.keras.optimizers.Adam(learning_rate)

@tf.function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,3 +150,11 @@ def transform(self, x):

def invert(self, x):
raise RuntimeError("`Padding` doesn't support the `invert` function.")


def fft_images(width, height, inputs, scale):
spectrum = tf.complex(inputs[0], inputs[1]) * scale
image = tf.signal.irfft2d(spectrum)
image = tf.transpose(image, (0, 2, 3, 1))
image = image[:, :width, :height, :]
return image / 4.0
24 changes: 24 additions & 0 deletions omnixai/explainers/vision/specific/feature_visualization/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,27 @@ def _process_objectives(objectives):
labels.append({"type": r["type"], "layer_name": layer_name, "index": indices[i, j]})
names.append(labels)
return results, indices.shape[0], names


def fft_freq(width, height, mode):
freq_x = np.fft.fftfreq(width)[:, None]
if mode == "tf":
cut_off = int(height % 2 == 1)
freq_y = np.fft.fftfreq(height)[:height // 2 + 1 + cut_off]
return np.sqrt(freq_y ** 2 + freq_x ** 2)
else:
freq_y = np.fft.fftfreq(height)
return np.sqrt(freq_y ** 2 + freq_x ** 2)


def fft_scale(width, height, mode, decay_power=1.0):
frequencies = fft_freq(width, height, mode)
scale = 1.0 / np.maximum(frequencies, 1.0 / max(width, height)) ** decay_power
scale = scale * np.sqrt(width * height)
return scale


def fft_inputs(batch_size, channel, width, height, mode, std=0.01):
freq = fft_freq(width, height, mode)
inputs = np.random.randn(*((2, batch_size, channel) + freq.shape)) * std
return inputs.astype(np.float32)
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def explain(
transformers: Pipeline = None,
regularizers: List = None,
image_shape: Tuple = None,
use_fft=False,
fft_decay=1.0,
normal_color: bool = False,
verbose: bool = True,
**kwargs
Expand All @@ -113,6 +115,8 @@ def explain(
:param regularizers: A list of regularizers applied on images. Each regularizer is a tupe
`(regularizer_type, weight)` where `regularizer_type` is "l1", "l2" or "tv".
:param image_shape: The customized image shape. If None, the default shape is (224, 224).
:param use_fft: Whether to use fourier preconditioning.
:param fft_decay: The value controlling the allowed energy of the high frequency.
:param normal_color: Whether to map uncorrelated colors to normal colors.
:param verbose: Whether to print the optimization progress.
:return: The optimized images for the objectives.
Expand Down Expand Up @@ -155,6 +159,8 @@ def explain(
value_normalizer=value_normalizer,
value_range=value_range,
init_std=init_std,
use_fft=use_fft,
fft_decay=fft_decay,
normal_color=normal_color,
save_all_images=False,
verbose=verbose
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _plot(x):
def test_layer(self):
objectives = [
Objective(
layer=self.model.features[20],
layer=self.model.features[-6],
channel_indices=list(range(5))
)
]
Expand All @@ -39,7 +39,8 @@ def test_layer(self):
)
results, names = optimizer.optimize(
num_iterations=300,
image_shape=(224, 224)
image_shape=(224, 224),
use_fft=True
)
for res, name in zip(results[-1], names):
print(name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ class TestExplainer(unittest.TestCase):

def setUp(self) -> None:
device = "cuda" if torch.cuda.is_available() else "cpu"
# self.model = models.vgg16(pretrained=True).to(device)
# self.target_layer = self.model.features[20]
self.model = vgg16.VGG16()
self.target_layer = self.model.layers[15]
self.model = models.vgg16(pretrained=True).to(device)
self.target_layer = self.model.features[-6]
# self.model = vgg16.VGG16()
# self.target_layer = self.model.layers[15]

def test(self):
optimizer = FeatureVisualizer(
Expand All @@ -27,7 +27,8 @@ def test(self):
)
explanations = optimizer.explain(
num_iterations=300,
image_shape=(224, 224)
image_shape=(224, 224),
use_fft=True
)
explanations.ipython_plot()

Expand Down
80 changes: 80 additions & 0 deletions omnixai/tests/explainers/feature_visualization/test_fft.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
#
# Copyright (c) 2022 salesforce.com, inc.
# All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
#
import unittest
import torch
import tensorflow as tf
from omnixai.explainers.vision.specific.feature_visualization.utils import \
fft_freq, fft_scale, fft_inputs
from omnixai.explainers.vision.specific.feature_visualization.tf.preprocess import \
fft_images as fft_images_tf
from omnixai.explainers.vision.specific.feature_visualization.pytorch.preprocess import \
fft_images as fft_images_torch


class TestFFT(unittest.TestCase):

def test_1(self):
batch_size = 5
channel = 3
width = 10
height = 7
mode = "torch"

freq = fft_freq(width, height, mode)
scale = fft_scale(width, height, mode)
inputs = fft_inputs(batch_size, channel, width, height, mode)
self.assertEqual(freq.shape, (10, 7))
self.assertEqual(scale.shape, (10, 7))
self.assertEqual(inputs.shape, (2, 5, 3, 10, 7))

def test_2(self):
batch_size = 5
channel = 3
width = 10
height = 7
mode = "tf"

freq = fft_freq(width, height, mode)
scale = fft_scale(width, height, mode)
inputs = fft_inputs(batch_size, channel, width, height, mode)
self.assertEqual(freq.shape, (10, 5))
self.assertEqual(scale.shape, (10, 5))
self.assertEqual(inputs.shape, (2, 5, 3, 10, 5))

def test_3(self):
batch_size = 5
channel = 3
width = 10
height = 7
mode = "tf"

scale = fft_scale(width, height, mode)
scale = tf.convert_to_tensor(scale, dtype=tf.complex64)
inputs = fft_inputs(batch_size, channel, width, height, mode)
inputs = tf.convert_to_tensor(inputs)

images = fft_images_tf(width, height, inputs, scale)
self.assertEqual(images.shape, (5, 10, 7, 3))

def test_4(self):
batch_size = 5
channel = 3
width = 10
height = 7
mode = "torch"

scale = fft_scale(width, height, mode)
scale = torch.tensor(scale, dtype=torch.complex64)
inputs = fft_inputs(batch_size, channel, width, height, mode)
inputs = torch.tensor(inputs, dtype=torch.float32)

images = fft_images_torch(width, height, inputs, scale)
self.assertEqual(images.shape, (5, 3, 10, 7))


if __name__ == "__main__":
unittest.main()
6 changes: 4 additions & 2 deletions omnixai/tests/explainers/ig/nlp_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,8 @@ def _preprocess(X: Text):
self.evaluate()

def train(self):
Trainer(optimizer_class=torch.optim.AdamW, learning_rate=1e-3, batch_size=128, num_epochs=10).train(
Trainer(optimizer_class=torch.optim.AdamW, learning_rate=1e-3,
batch_size=128, num_epochs=10).train(
model=self.model,
loss_func=nn.CrossEntropyLoss(),
train_x=self.transform.transform(self.x_train),
Expand Down Expand Up @@ -137,7 +138,8 @@ def evaluate(self):
outputs.append(y.detach().cpu().numpy())
outputs = np.concatenate(outputs, axis=0)
predictions = np.argmax(outputs, axis=1)
print("Test accuracy: {}".format(sklearn.metrics.f1_score(self.y_test, predictions, average="binary")))
print("Test accuracy: {}".format(
sklearn.metrics.f1_score(self.y_test, predictions, average="binary")))

def test_explain(self):
idx = 83
Expand Down
Loading

0 comments on commit 2adc636

Please sign in to comment.