-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #38 from arnauqb/examples_losses
Added example common losses
- Loading branch information
Showing
4 changed files
with
288 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
from blackbirds.models.model import Model | ||
from blackbirds.simulate import simulate_and_observe_model | ||
|
||
import torch | ||
|
||
class SingleOutput_SimulateAndMSELoss: | ||
|
||
""" | ||
Computes MSE between observed data y and simulated data at theta (to be passed during __call__). | ||
**Arguments** | ||
- `model`: An instance of a Model. The model that you'd like to "fit". | ||
- `gradient_horizon`: Specifies the gradient horizon to use. None implies infinite horizon. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
model: Model, | ||
gradient_horizon: int | None = None | ||
): | ||
|
||
self.loss = torch.nn.MSELoss() | ||
self.model = model | ||
self.gradient_horizon = gradient_horizon | ||
|
||
def __call__( | ||
self, | ||
theta: torch.Tensor, | ||
y: torch.Tensor, | ||
): | ||
|
||
x = simulate_and_observe_model( | ||
self.model, | ||
theta, | ||
self.gradient_horizon | ||
)[0] | ||
return self.loss(x, y) | ||
|
||
class UnivariateMMDLoss: | ||
def __init__( | ||
self, | ||
y: torch.Tensor | ||
): | ||
|
||
""" | ||
Computes MMD between data y and simulated output x (to be passed during call). | ||
Assumes y is a torch.Tensor consisting of a single univariate time series. | ||
""" | ||
|
||
assert isinstance(y, torch.Tensor), "y is assumed to be a torch.Tensor here" | ||
try: | ||
assert len(y.shape) == 1, "This class assumes y is a single univariate time series" | ||
except AssertionError: | ||
assert len(y.shape) == 2, "If not a 1D Tensor, y must be at most 2D of shape (1, T)" | ||
assert y.shape[1] == 1, "This class assumes y is a single univariate time series. This appears to be a batch of data." | ||
y = y.reshape(-1) | ||
|
||
self.y = y | ||
self.y_matrix = self.y.reshape(1,-1,1) | ||
yy = torch.cdist(self.y_matrix, self.y_matrix) | ||
yy_sqrd = torch.pow(yy, 2) | ||
self.y_sigma = torch.median(yy_sqrd) | ||
ny = self.y.shape[0] | ||
self.kyy = ( | ||
torch.exp( | ||
-yy_sqrd / self.y_sigma | ||
) | ||
- torch.eye(ny) | ||
).sum() / (ny * (ny - 1)) | ||
|
||
def __call__( | ||
self, | ||
x: torch.Tensor, | ||
): | ||
|
||
assert isinstance(x, torch.Tensor), "x is assumed to be a torch.Tensor here" | ||
try: | ||
assert len(x.shape) == 1, "This class assumes x is a single univariate time series" | ||
except AssertionError: | ||
assert len(x.shape) == 2, "If not a 1D Tensor, x must be at most 2D of shape (1, T)" | ||
assert x.shape[1] == 1, "This class assumes x is a single univariate time series. This appears to be a batch of data." | ||
x = x.reshape(-1) | ||
|
||
nx = x.shape[0] | ||
x_matrix = x.reshape(1,-1,1) | ||
kxx = torch.exp( | ||
-torch.pow( | ||
torch.cdist(x_matrix, x_matrix), | ||
2) | ||
/ self.y_sigma | ||
) | ||
kxx = (kxx - torch.eye(nx)).sum() / (nx * (nx - 1)) | ||
kxy = torch.exp( - torch.pow(torch.cdist(x_matrix, self.y_matrix), 2) / self.y_sigma ) | ||
kxy = kxy.mean() | ||
return kxx + self.kyy - 2 * kxy | ||
|
||
class SingleOutput_SimulateAndMMD: | ||
|
||
""" | ||
Example implementation of a loss that simulates from the model and computes the MMD | ||
between the model output and observed data y. (This treats the entries in y and in | ||
the simulator output as exchangeable.) | ||
**Arguments** | ||
- `y`: torch.Tensor containing a single univariate time series. | ||
- `model`: An instance of a Model. | ||
- `gradient_horizon`: An integer or None. Sets horizon over which gradients are retained. If None, infinite horizon used. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
y: torch.Tensor, | ||
model: Model, | ||
gradient_horizon: int | None = None | ||
): | ||
|
||
self.mmd_loss = UnivariateMMDLoss(y) | ||
self.model = model | ||
self.gradient_horizon = gradient_horizon | ||
|
||
def __call__( | ||
self, | ||
theta: torch.Tensor, | ||
y: torch.Tensor | ||
): | ||
|
||
x = simulate_and_observe_model(self.model, theta, self.gradient_horizon)[0] | ||
return self.mmd_loss(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
import torch | ||
|
||
from blackbirds.models.model import Model | ||
|
||
class Normal(Model): | ||
|
||
def __init__( | ||
self, | ||
n_timesteps: int, | ||
): | ||
|
||
super().__init__() | ||
self.n_timesteps = n_timesteps | ||
|
||
def initialize(self, params): | ||
return torch.zeros(1).reshape(1, 1) | ||
|
||
def trim_time_series(self, x): | ||
return x[-1:] | ||
|
||
def step(self, params, x): | ||
|
||
mu, sigma = params | ||
assert sigma > 0, "Argument sigma must be a float greater than 0." | ||
return mu + sigma * torch.randn((1,1)) | ||
|
||
def observe(self, x): | ||
return [x] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import torch | ||
import torch.autograd as autograd | ||
import numpy as np | ||
from blackbirds.models.normal import Normal | ||
|
||
|
||
class TestRandomWalk: | ||
def test__result(self): | ||
torch.manual_seed(0) | ||
n_timesteps = 5000 | ||
normal = Normal(n_timesteps) | ||
assert normal.n_timesteps == n_timesteps | ||
mu, sigma = 1., 3. | ||
params = torch.Tensor([mu, sigma]) | ||
x = normal.run(params) | ||
assert x.shape == (n_timesteps + 1, 1) | ||
trajectory = normal.observe(x)[0] | ||
mean = trajectory.mean() | ||
std = trajectory.std() | ||
# Sample mean should approximately match mu | ||
assert np.isclose( | ||
mean, mu, atol=1e-1 | ||
) | ||
# Sample std should approximately match sigma | ||
assert np.isclose( | ||
std, sigma, rtol=1e-1 | ||
) | ||
# Should also have approximately 0. skewness | ||
sample_skewness = torch.pow((trajectory - mean)/std, 3.).mean() | ||
assert np.isclose( | ||
sample_skewness, 0., atol=1e-1 | ||
) | ||
|
||
def test__gradient(self): | ||
mu, sigma = 0.4, 1.5 | ||
params = torch.Tensor([mu, sigma]) | ||
params.requires_grad = True | ||
n_timesteps = 100 | ||
normal = Normal(n_timesteps) | ||
assert normal.n_timesteps == n_timesteps | ||
x = normal.observe(normal.run(params))[0] | ||
assert x.shape == (n_timesteps + 1, 1) | ||
x.sum().backward() | ||
assert params.grad is not None | ||
assert np.isclose(params.grad[0], n_timesteps) | ||
grad_sigmas = [] | ||
|
||
def mean_x(params): | ||
x = normal.observe(normal.run(params))[0] | ||
return x.mean() | ||
|
||
for t in range(500): | ||
params = torch.Tensor([mu, sigma]) | ||
params.requires_grad = True | ||
sigma_grad = autograd.grad(mean_x(params), params)[-1] | ||
grad_sigmas.append(sigma_grad[-1]) | ||
|
||
assert np.isclose(torch.stack(grad_sigmas).mean(), 0., atol=1e-2) | ||
expected_std = 1. / np.sqrt(n_timesteps) | ||
assert np.isclose(torch.stack(grad_sigmas).std().item(), expected_std, atol=1e-2) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
import numpy as np | ||
import torch | ||
|
||
from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD | ||
from blackbirds.models.normal import Normal | ||
from blackbirds.models.random_walk import RandomWalk | ||
from blackbirds.simulate import simulate_and_observe_model | ||
|
||
class TestMSELoss: | ||
|
||
def test_normal_same(self): | ||
|
||
T = 5000 | ||
normal = Normal(n_timesteps=T) | ||
mu, sigma = 1., 2. | ||
params = torch.Tensor([mu, sigma]) | ||
y = simulate_and_observe_model(normal, params)[0] | ||
mse_loss = SingleOutput_SimulateAndMSELoss(normal) | ||
loss_value = mse_loss(params, y) | ||
# MMD between distributions of hould be close to... | ||
value = 2 * sigma**2 | ||
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2) | ||
|
||
def test_normal_different(self): | ||
|
||
T = 5000 | ||
normal = Normal(n_timesteps=T) | ||
mu_y, sigma_y = 1., 2. | ||
params = torch.Tensor([mu_y, sigma_y]) | ||
y = simulate_and_observe_model(normal, params)[0] | ||
mse_loss = SingleOutput_SimulateAndMSELoss(normal) | ||
mu_x, sigma_x = 0., 1. | ||
params = torch.Tensor([mu_x, sigma_x]) | ||
loss_value = mse_loss(params, y) | ||
# MMD between distributions of hould be close to... | ||
value = sigma_x**2 + sigma_y**2 + (mu_x - mu_y)**2 | ||
assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2) | ||
|
||
|
||
class TestMMDLoss: | ||
|
||
def test_normal_same(self): | ||
|
||
T = 500 | ||
normal = Normal(n_timesteps=T) | ||
mu, sigma = 1., 2. | ||
params = torch.Tensor([mu, sigma]) | ||
y = simulate_and_observe_model(normal, params)[0] | ||
mmd_loss = SingleOutput_SimulateAndMMD(y, normal) | ||
loss_value = mmd_loss(params, y) | ||
# MMD between distributions of hould be close to 0 | ||
assert np.isclose(loss_value, 0., atol=1e-3) | ||
|
||
def test_normal_different(self): | ||
|
||
T = 500 | ||
normal = Normal(n_timesteps=T) | ||
mu, sigma = 1., 2. | ||
params = torch.Tensor([mu, sigma]) | ||
y = simulate_and_observe_model(normal, params)[0] | ||
mmd_loss = SingleOutput_SimulateAndMMD(y, normal) | ||
mu, sigma = 0., 1. | ||
params = torch.Tensor([mu, sigma]) | ||
loss_value = mmd_loss(params, y) | ||
# MMD between distributions of hould be close to 0 | ||
assert not np.isclose(loss_value, 0.) | ||
assert loss_value > 0. | ||
|
||
|