Skip to content

Commit

Permalink
Added tests for MMD loss and MSE losses. Introduced new simple Normal…
Browse files Browse the repository at this point in the history
… model for this reason, along with corresponding tests
  • Loading branch information
joelnmdyer committed Jul 13, 2023
1 parent 8dc8b86 commit 06f7188
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 6 deletions.
22 changes: 16 additions & 6 deletions blackbirds/losses.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from blackbirds.model import Model
from blackbirds.simulate import simulate_and_observe
from blackbirds.models.model import Model
from blackbirds.simulate import simulate_and_observe_model

import torch

class SingleOutput_SmulateAndMSELoss:
class SingleOutput_SimulateAndMSELoss:

"""
Computes MSE between observed data y and simulated data at theta (to be passed during __call__).
Expand Down Expand Up @@ -50,7 +50,12 @@ def __init__(
"""

assert isinstance(y, torch.Tensor), "y is assumed to be a torch.Tensor here"
assert len(y.shape) == 1, "This class assumes y is a single univariate time series"
try:
assert len(y.shape) == 1, "This class assumes y is a single univariate time series"
except AssertionError:
assert len(y.shape) == 2, "If not a 1D Tensor, y must be at most 2D of shape (1, T)"
assert y.shape[1] == 1, "This class assumes y is a single univariate time series. This appears to be a batch of data."
y = y.reshape(-1)

self.y = y
self.y_matrix = self.y.reshape(1,-1,1)
Expand All @@ -71,7 +76,12 @@ def __call__(
):

assert isinstance(x, torch.Tensor), "x is assumed to be a torch.Tensor here"
assert len(x.shape) == 1, "This class assumes x is a single univariate time series"
try:
assert len(x.shape) == 1, "This class assumes x is a single univariate time series"
except AssertionError:
assert len(x.shape) == 2, "If not a 1D Tensor, x must be at most 2D of shape (1, T)"
assert x.shape[1] == 1, "This class assumes x is a single univariate time series. This appears to be a batch of data."
x = x.reshape(-1)

nx = x.shape[0]
x_matrix = x.reshape(1,-1,1)
Expand Down Expand Up @@ -107,7 +117,7 @@ def __init__(
gradient_horizon: int | None = None
):

self.mmd_loss = MMDLoss(y)
self.mmd_loss = UnivariateMMDLoss(y)
self.model = model
self.gradient_horizon = gradient_horizon

Expand Down
28 changes: 28 additions & 0 deletions blackbirds/models/normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import torch

from blackbirds.models.model import Model

class Normal(Model):

def __init__(
self,
n_timesteps: int,
):

super().__init__()
self.n_timesteps = n_timesteps

def initialize(self, params):
return torch.zeros(1).reshape(1, 1)

def trim_time_series(self, x):
return x[-1:]

def step(self, params, x):

mu, sigma = params
assert sigma > 0, "Argument sigma must be a float greater than 0."
return mu + sigma * torch.randn((1,1))

def observe(self, x):
return [x]
60 changes: 60 additions & 0 deletions test/models/test_normal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import torch
import torch.autograd as autograd
import numpy as np
from blackbirds.models.normal import Normal


class TestRandomWalk:
def test__result(self):
torch.manual_seed(0)
n_timesteps = 5000
normal = Normal(n_timesteps)
assert normal.n_timesteps == n_timesteps
mu, sigma = 1., 3.
params = torch.Tensor([mu, sigma])
x = normal.run(params)
assert x.shape == (n_timesteps + 1, 1)
trajectory = normal.observe(x)[0]
mean = trajectory.mean()
std = trajectory.std()
# Sample mean should approximately match mu
assert np.isclose(
mean, mu, atol=1e-1
)
# Sample std should approximately match sigma
assert np.isclose(
std, sigma, rtol=1e-1
)
# Should also have approximately 0. skewness
sample_skewness = torch.pow((trajectory - mean)/std, 3.).mean()
assert np.isclose(
sample_skewness, 0., atol=1e-1
)

def test__gradient(self):
mu, sigma = 0.4, 1.5
params = torch.Tensor([mu, sigma])
params.requires_grad = True
n_timesteps = 100
normal = Normal(n_timesteps)
assert normal.n_timesteps == n_timesteps
x = normal.observe(normal.run(params))[0]
assert x.shape == (n_timesteps + 1, 1)
x.sum().backward()
assert params.grad is not None
assert np.isclose(params.grad[0], n_timesteps)
grad_sigmas = []

def mean_x(params):
x = normal.observe(normal.run(params))[0]
return x.mean()

for t in range(500):
params = torch.Tensor([mu, sigma])
params.requires_grad = True
sigma_grad = autograd.grad(mean_x(params), params)[-1]
grad_sigmas.append(sigma_grad[-1])

assert np.isclose(torch.stack(grad_sigmas).mean(), 0., atol=1e-2)
expected_std = 1. / np.sqrt(n_timesteps)
assert np.isclose(torch.stack(grad_sigmas).std().item(), expected_std, atol=1e-2)
69 changes: 69 additions & 0 deletions test/test_losses.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import numpy as np
import torch

from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD
from blackbirds.models.normal import Normal
from blackbirds.models.random_walk import RandomWalk
from blackbirds.simulate import simulate_and_observe_model

class TestMSELoss:

def test_normal_same(self):

T = 5000
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mse_loss = SingleOutput_SimulateAndMSELoss(normal)
loss_value = mse_loss(params, y)
# MMD between distributions of hould be close to...
value = 2 * sigma**2
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)

def test_normal_different(self):

T = 5000
normal = Normal(n_timesteps=T)
mu_y, sigma_y = 1., 2.
params = torch.Tensor([mu_y, sigma_y])
y = simulate_and_observe_model(normal, params)[0]
mse_loss = SingleOutput_SimulateAndMSELoss(normal)
mu_x, sigma_x = 0., 1.
params = torch.Tensor([mu_x, sigma_x])
loss_value = mse_loss(params, y)
# MMD between distributions of hould be close to...
value = sigma_x**2 + sigma_y**2 + (mu_x - mu_y)**2
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)


class TestMMDLoss:

def test_normal_same(self):

T = 500
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
loss_value = mmd_loss(params, y)
# MMD between distributions of hould be close to 0
assert np.isclose(loss_value, 0., atol=1e-3)

def test_normal_different(self):

T = 500
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
mu, sigma = 0., 1.
params = torch.Tensor([mu, sigma])
loss_value = mmd_loss(params, y)
# MMD between distributions of hould be close to 0
assert not np.isclose(loss_value, 0.)
assert loss_value > 0.


0 comments on commit 06f7188

Please sign in to comment.