Commit
Added tests for MMD loss and MSE losses. Introduced new simple Normal model for this reason, along with corresponding tests.
1 parent 8dc8b86 · commit 06f7188
Showing 4 changed files with 173 additions and 6 deletions.
@@ -0,0 +1,28 @@
import torch

from blackbirds.models.model import Model


class Normal(Model):
    def __init__(
        self,
        n_timesteps: int,
    ):
        super().__init__()
        self.n_timesteps = n_timesteps

    def initialize(self, params):
        return torch.zeros(1).reshape(1, 1)

    def trim_time_series(self, x):
        return x[-1:]

    def step(self, params, x):
        mu, sigma = params
        assert sigma > 0, "Argument sigma must be a float greater than 0."
        return mu + sigma * torch.randn((1, 1))

    def observe(self, x):
        return [x]
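For orientation, a minimal sketch of driving this model directly. It assumes, as the tests below do, that the run method inherited from the Model base class chains initialize and step for n_timesteps; the numbers here are illustrative only:

import torch

from blackbirds.models.normal import Normal

# Simulate 10 i.i.d. N(mu, sigma^2) draws on top of the zero initial state.
normal = Normal(n_timesteps=10)
params = torch.Tensor([0.0, 1.0])  # (mu, sigma)
x = normal.run(params)             # expected shape: (n_timesteps + 1, 1)
trajectory = normal.observe(x)[0]
print(trajectory.mean().item(), trajectory.std().item())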
@@ -0,0 +1,60 @@
import torch
import torch.autograd as autograd
import numpy as np
from blackbirds.models.normal import Normal


class TestNormal:
    def test__result(self):
        torch.manual_seed(0)
        n_timesteps = 5000
        normal = Normal(n_timesteps)
        assert normal.n_timesteps == n_timesteps
        mu, sigma = 1., 3.
        params = torch.Tensor([mu, sigma])
        x = normal.run(params)
        assert x.shape == (n_timesteps + 1, 1)
        trajectory = normal.observe(x)[0]
        mean = trajectory.mean()
        std = trajectory.std()
        # Sample mean should approximately match mu
        assert np.isclose(
            mean, mu, atol=1e-1
        )
        # Sample std should approximately match sigma
        assert np.isclose(
            std, sigma, rtol=1e-1
        )
        # Should also have approximately zero skewness
        sample_skewness = torch.pow((trajectory - mean) / std, 3.).mean()
        assert np.isclose(
            sample_skewness, 0., atol=1e-1
        )

    def test__gradient(self):
        mu, sigma = 0.4, 1.5
        params = torch.Tensor([mu, sigma])
        params.requires_grad = True
        n_timesteps = 100
        normal = Normal(n_timesteps)
        assert normal.n_timesteps == n_timesteps
        x = normal.observe(normal.run(params))[0]
        assert x.shape == (n_timesteps + 1, 1)
        x.sum().backward()
        assert params.grad is not None
        assert np.isclose(params.grad[0], n_timesteps)
        grad_sigmas = []

        def mean_x(params):
            x = normal.observe(normal.run(params))[0]
            return x.mean()

        for t in range(500):
            params = torch.Tensor([mu, sigma])
            params.requires_grad = True
            sigma_grad = autograd.grad(mean_x(params), params)[-1]
            grad_sigmas.append(sigma_grad[-1])

        assert np.isclose(torch.stack(grad_sigmas).mean(), 0., atol=1e-2)
        expected_std = 1. / np.sqrt(n_timesteps)
        assert np.isclose(torch.stack(grad_sigmas).std().item(), expected_std, atol=1e-2)
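A brief check of the numbers asserted in test__gradient, assuming (as in the model above) that each step draws an independent epsilon_t ~ N(0, 1), so that x_t = mu + sigma * epsilon_t for t = 1, ..., T on top of the zero initial state:

\[
\frac{\partial}{\partial \mu} \sum_{t=0}^{T} x_t = T,
\qquad
\frac{\partial}{\partial \sigma} \bar{x}
  = \frac{1}{T+1} \sum_{t=1}^{T} \varepsilon_t
  \sim \mathcal{N}\!\left(0, \frac{T}{(T+1)^2}\right),
\]

so the gradient of the trajectory sum with respect to mu equals n_timesteps, and the sampled gradients of the trajectory mean with respect to sigma have mean 0 and standard deviation sqrt(T) / (T + 1), approximately 1 / sqrt(T), which is the expected_std asserted above.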
@@ -0,0 +1,69 @@
import numpy as np
import torch

from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD
from blackbirds.models.normal import Normal
from blackbirds.models.random_walk import RandomWalk
from blackbirds.simulate import simulate_and_observe_model


class TestMSELoss:
    def test_normal_same(self):
        T = 5000
        normal = Normal(n_timesteps=T)
        mu, sigma = 1., 2.
        params = torch.Tensor([mu, sigma])
        y = simulate_and_observe_model(normal, params)[0]
        mse_loss = SingleOutput_SimulateAndMSELoss(normal)
        loss_value = mse_loss(params, y)
        # MSE between two independent trajectories with identical parameters
        # should be close to 2 * sigma^2
        value = 2 * sigma**2
        assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)

    def test_normal_different(self):
        T = 5000
        normal = Normal(n_timesteps=T)
        mu_y, sigma_y = 1., 2.
        params = torch.Tensor([mu_y, sigma_y])
        y = simulate_and_observe_model(normal, params)[0]
        mse_loss = SingleOutput_SimulateAndMSELoss(normal)
        mu_x, sigma_x = 0., 1.
        params = torch.Tensor([mu_x, sigma_x])
        loss_value = mse_loss(params, y)
        # MSE between trajectories with different parameters should be close to
        # sigma_x^2 + sigma_y^2 + (mu_x - mu_y)^2
        value = sigma_x**2 + sigma_y**2 + (mu_x - mu_y)**2
        assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2)


class TestMMDLoss:
    def test_normal_same(self):
        T = 500
        normal = Normal(n_timesteps=T)
        mu, sigma = 1., 2.
        params = torch.Tensor([mu, sigma])
        y = simulate_and_observe_model(normal, params)[0]
        mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
        loss_value = mmd_loss(params, y)
        # MMD between samples from the same distribution should be close to 0
        assert np.isclose(loss_value, 0., atol=1e-3)

    def test_normal_different(self):
        T = 500
        normal = Normal(n_timesteps=T)
        mu, sigma = 1., 2.
        params = torch.Tensor([mu, sigma])
        y = simulate_and_observe_model(normal, params)[0]
        mmd_loss = SingleOutput_SimulateAndMMD(y, normal)
        mu, sigma = 0., 1.
        params = torch.Tensor([mu, sigma])
        loss_value = mmd_loss(params, y)
        # MMD between samples from different distributions should be positive
        assert not np.isclose(loss_value, 0.)
        assert loss_value > 0.
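The expected values used in TestMSELoss follow from a standard identity, assuming SingleOutput_SimulateAndMSELoss simulates a fresh trajectory at the supplied parameters and averages the squared pointwise difference to y: for independent X ~ N(mu_x, sigma_x^2) and Y ~ N(mu_y, sigma_y^2),

\[
\mathbb{E}\big[(X - Y)^2\big]
  = \operatorname{Var}(X) + \operatorname{Var}(Y) + \big(\mathbb{E}[X] - \mathbb{E}[Y]\big)^2
  = \sigma_x^2 + \sigma_y^2 + (\mu_x - \mu_y)^2,
\]

which reduces to 2 * sigma^2 when both trajectories share the same parameters (test_normal_same). The MMD tests make the analogous qualitative check: the estimate should be near zero only when both samples come from the same distribution, and strictly positive otherwise.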