
Merge pull request #38 from arnauqb/examples_losses
Added examples of common losses
arnauqb authored Jul 13, 2023
2 parents c63dc10 + 06f7188 commit 51a94b4
Showing 4 changed files with 288 additions and 0 deletions.
131 changes: 131 additions & 0 deletions blackbirds/losses.py
@@ -0,0 +1,131 @@
from blackbirds.models.model import Model
from blackbirds.simulate import simulate_and_observe_model

import torch

class SingleOutput_SimulateAndMSELoss:

"""
    Computes the MSE between observed data y and data simulated at parameters theta (both passed to __call__).
**Arguments**
- `model`: An instance of a Model. The model that you'd like to "fit".
- `gradient_horizon`: Specifies the gradient horizon to use. None implies infinite horizon.
"""

def __init__(
self,
model: Model,
gradient_horizon: int | None = None
):

self.loss = torch.nn.MSELoss()
self.model = model
self.gradient_horizon = gradient_horizon

def __call__(
self,
theta: torch.Tensor,
y: torch.Tensor,
):

x = simulate_and_observe_model(
self.model,
theta,
self.gradient_horizon
)[0]
return self.loss(x, y)

class UnivariateMMDLoss:
def __init__(
self,
y: torch.Tensor
):

"""
    Computes the MMD between data y (fixed at construction) and simulated output x (passed to __call__).
Assumes y is a torch.Tensor consisting of a single univariate time series.
"""

assert isinstance(y, torch.Tensor), "y is assumed to be a torch.Tensor here"
try:
assert len(y.shape) == 1, "This class assumes y is a single univariate time series"
except AssertionError:
            assert len(y.shape) == 2, "If not a 1D Tensor, y must be 2D of shape (T, 1)"
            assert y.shape[1] == 1, "This class assumes y is a single univariate time series. This appears to be a batch of data."
y = y.reshape(-1)

self.y = y
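        # Gaussian (RBF) kernel with bandwidth set by the median heuristic:
        # y_sigma is the median of the squared pairwise distances within y.
        # Subtracting the identity below drops the k(y_i, y_i) = 1 diagonal
        # terms, making kyy an unbiased estimate of E[k(Y, Y')].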
        self.y_matrix = self.y.reshape(1, -1, 1)
yy = torch.cdist(self.y_matrix, self.y_matrix)
yy_sqrd = torch.pow(yy, 2)
self.y_sigma = torch.median(yy_sqrd)
ny = self.y.shape[0]
        kyy = torch.exp(-yy_sqrd / self.y_sigma) - torch.eye(ny)
        self.kyy = kyy.sum() / (ny * (ny - 1))

def __call__(
self,
x: torch.Tensor,
):

assert isinstance(x, torch.Tensor), "x is assumed to be a torch.Tensor here"
try:
assert len(x.shape) == 1, "This class assumes x is a single univariate time series"
except AssertionError:
            assert len(x.shape) == 2, "If not a 1D Tensor, x must be 2D of shape (T, 1)"
assert x.shape[1] == 1, "This class assumes x is a single univariate time series. This appears to be a batch of data."
x = x.reshape(-1)

nx = x.shape[0]
        x_matrix = x.reshape(1, -1, 1)
        kxx = torch.exp(-torch.pow(torch.cdist(x_matrix, x_matrix), 2) / self.y_sigma)
        kxx = (kxx - torch.eye(nx)).sum() / (nx * (nx - 1))
        kxy = torch.exp(-torch.pow(torch.cdist(x_matrix, self.y_matrix), 2) / self.y_sigma)
        kxy = kxy.mean()
return kxx + self.kyy - 2 * kxy

class SingleOutput_SimulateAndMMD:

"""
Example implementation of a loss that simulates from the model and computes the MMD
between the model output and observed data y. (This treats the entries in y and in
the simulator output as exchangeable.)
**Arguments**
- `y`: torch.Tensor containing a single univariate time series.
- `model`: An instance of a Model.
    - `gradient_horizon`: An integer or None. Sets the horizon over which gradients are retained. If None, an infinite horizon is used.
"""

def __init__(
self,
y: torch.Tensor,
model: Model,
gradient_horizon: int | None = None
):

self.mmd_loss = UnivariateMMDLoss(y)
self.model = model
self.gradient_horizon = gradient_horizon

def __call__(
self,
theta: torch.Tensor,
y: torch.Tensor
):
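        # Note: y is accepted here for a uniform loss interface but is not used;
        # the observed data was fixed when the UnivariateMMDLoss was constructed.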

x = simulate_and_observe_model(self.model, theta, self.gradient_horizon)[0]
return self.mmd_loss(x)
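
For illustration, a minimal usage sketch of these two losses, assuming the `Normal` model added below in `blackbirds/models/normal.py` (the variable names here are illustrative, not part of the commit):

import torch
from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD
from blackbirds.models.normal import Normal
from blackbirds.simulate import simulate_and_observe_model

model = Normal(n_timesteps=1000)
true_params = torch.Tensor([1.0, 2.0])  # (mu, sigma)
y = simulate_and_observe_model(model, true_params)[0]  # synthetic "observed" series

candidate = torch.Tensor([0.5, 1.5])
mse_loss = SingleOutput_SimulateAndMSELoss(model)
mmd_loss = SingleOutput_SimulateAndMMD(y, model)
print(mse_loss(candidate, y), mmd_loss(candidate, y))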
28 changes: 28 additions & 0 deletions blackbirds/models/normal.py
@@ -0,0 +1,28 @@
import torch

from blackbirds.models.model import Model

class Normal(Model):

def __init__(
self,
n_timesteps: int,
):

super().__init__()
self.n_timesteps = n_timesteps

def initialize(self, params):
return torch.zeros(1).reshape(1, 1)

def trim_time_series(self, x):
return x[-1:]

def step(self, params, x):

mu, sigma = params
assert sigma > 0, "Argument sigma must be a float greater than 0."
return mu + sigma * torch.randn((1,1))

def observe(self, x):
return [x]
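
As a quick sketch of how this model is driven (illustrative; `run` and `observe` follow the `Model` interface exercised by the tests below):

import torch
from blackbirds.models.normal import Normal

normal = Normal(n_timesteps=10)
params = torch.Tensor([0.0, 1.0])  # (mu, sigma); sigma must be > 0
x = normal.run(params)  # shape (n_timesteps + 1, 1); row 0 is the zero initial state
series = normal.observe(x)[0]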
60 changes: 60 additions & 0 deletions test/models/test_normal.py
@@ -0,0 +1,60 @@
import torch
import torch.autograd as autograd
import numpy as np
from blackbirds.models.normal import Normal


class TestNormal:
def test__result(self):
torch.manual_seed(0)
n_timesteps = 5000
normal = Normal(n_timesteps)
assert normal.n_timesteps == n_timesteps
mu, sigma = 1., 3.
params = torch.Tensor([mu, sigma])
x = normal.run(params)
assert x.shape == (n_timesteps + 1, 1)
trajectory = normal.observe(x)[0]
mean = trajectory.mean()
std = trajectory.std()
# Sample mean should approximately match mu
assert np.isclose(
mean, mu, atol=1e-1
)
# Sample std should approximately match sigma
assert np.isclose(
std, sigma, rtol=1e-1
)
# Should also have approximately 0. skewness
sample_skewness = torch.pow((trajectory - mean)/std, 3.).mean()
assert np.isclose(
sample_skewness, 0., atol=1e-1
)

def test__gradient(self):
mu, sigma = 0.4, 1.5
params = torch.Tensor([mu, sigma])
params.requires_grad = True
n_timesteps = 100
normal = Normal(n_timesteps)
assert normal.n_timesteps == n_timesteps
x = normal.observe(normal.run(params))[0]
assert x.shape == (n_timesteps + 1, 1)
x.sum().backward()
assert params.grad is not None
assert np.isclose(params.grad[0], n_timesteps)
grad_sigmas = []

def mean_x(params):
x = normal.observe(normal.run(params))[0]
return x.mean()

        for _ in range(500):
params = torch.Tensor([mu, sigma])
params.requires_grad = True
sigma_grad = autograd.grad(mean_x(params), params)[-1]
grad_sigmas.append(sigma_grad[-1])

assert np.isclose(torch.stack(grad_sigmas).mean(), 0., atol=1e-2)
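        # d(mean_x)/d(sigma) = (1 / (n_timesteps + 1)) * sum_t eps_t with eps_t ~ N(0, 1),
        # so its std is sqrt(n_timesteps) / (n_timesteps + 1) ~= 1 / sqrt(n_timesteps).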
expected_std = 1. / np.sqrt(n_timesteps)
assert np.isclose(torch.stack(grad_sigmas).std().item(), expected_std, atol=1e-2)
69 changes: 69 additions & 0 deletions test/test_losses.py
@@ -0,0 +1,69 @@
import numpy as np
import torch

from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD
from blackbirds.models.normal import Normal
from blackbirds.simulate import simulate_and_observe_model

class TestMSELoss:

def test_normal_same(self):

T = 5000
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mse_loss = SingleOutput_SimulateAndMSELoss(normal)
loss_value = mse_loss(params, y)
        # Expected MSE: E[(X - Y)^2] = Var(X) + Var(Y) = 2 * sigma^2 for i.i.d. X, Y
value = 2 * sigma**2
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)

def test_normal_different(self):

T = 5000
normal = Normal(n_timesteps=T)
mu_y, sigma_y = 1., 2.
params = torch.Tensor([mu_y, sigma_y])
y = simulate_and_observe_model(normal, params)[0]
mse_loss = SingleOutput_SimulateAndMSELoss(normal)
mu_x, sigma_x = 0., 1.
params = torch.Tensor([mu_x, sigma_x])
loss_value = mse_loss(params, y)
        # Expected MSE: E[(X - Y)^2] = sigma_x^2 + sigma_y^2 + (mu_x - mu_y)^2
value = sigma_x**2 + sigma_y**2 + (mu_x - mu_y)**2
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)


class TestMMDLoss:

def test_normal_same(self):

T = 500
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
loss_value = mmd_loss(params, y)
        # MMD between samples from the same distribution should be close to 0
assert np.isclose(loss_value, 0., atol=1e-3)

def test_normal_different(self):

T = 500
normal = Normal(n_timesteps=T)
mu, sigma = 1., 2.
params = torch.Tensor([mu, sigma])
y = simulate_and_observe_model(normal, params)[0]
mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
mu, sigma = 0., 1.
params = torch.Tensor([mu, sigma])
loss_value = mmd_loss(params, y)
        # MMD between samples from different distributions should be positive
assert not np.isclose(loss_value, 0.)
assert loss_value > 0.

