diff --git a/blackbirds/losses.py b/blackbirds/losses.py new file mode 100644 index 0000000..e1ab5f0 --- /dev/null +++ b/blackbirds/losses.py @@ -0,0 +1,131 @@ +from blackbirds.models.model import Model +from blackbirds.simulate import simulate_and_observe_model + +import torch + +class SingleOutput_SimulateAndMSELoss: + + """ + Computes MSE between observed data y and simulated data at theta (to be passed during __call__). + + **Arguments** + + - `model`: An instance of a Model. The model that you'd like to "fit". + - `gradient_horizon`: Specifies the gradient horizon to use. None implies infinite horizon. + """ + + def __init__( + self, + model: Model, + gradient_horizon: int | None = None + ): + + self.loss = torch.nn.MSELoss() + self.model = model + self.gradient_horizon = gradient_horizon + + def __call__( + self, + theta: torch.Tensor, + y: torch.Tensor, + ): + + x = simulate_and_observe_model( + self.model, + theta, + self.gradient_horizon + )[0] + return self.loss(x, y) + +class UnivariateMMDLoss: + def __init__( + self, + y: torch.Tensor + ): + + """ + Computes MMD between data y and simulated output x (to be passed during call). + + Assumes y is a torch.Tensor consisting of a single univariate time series. + """ + + assert isinstance(y, torch.Tensor), "y is assumed to be a torch.Tensor here" + try: + assert len(y.shape) == 1, "This class assumes y is a single univariate time series" + except AssertionError: + assert len(y.shape) == 2, "If not a 1D Tensor, y must be at most 2D of shape (1, T)" + assert y.shape[1] == 1, "This class assumes y is a single univariate time series. This appears to be a batch of data." + y = y.reshape(-1) + + self.y = y + self.y_matrix = self.y.reshape(1,-1,1) + yy = torch.cdist(self.y_matrix, self.y_matrix) + yy_sqrd = torch.pow(yy, 2) + self.y_sigma = torch.median(yy_sqrd) + ny = self.y.shape[0] + self.kyy = ( + torch.exp( + -yy_sqrd / self.y_sigma + ) + - torch.eye(ny) + ).sum() / (ny * (ny - 1)) + + def __call__( + self, + x: torch.Tensor, + ): + + assert isinstance(x, torch.Tensor), "x is assumed to be a torch.Tensor here" + try: + assert len(x.shape) == 1, "This class assumes x is a single univariate time series" + except AssertionError: + assert len(x.shape) == 2, "If not a 1D Tensor, x must be at most 2D of shape (1, T)" + assert x.shape[1] == 1, "This class assumes x is a single univariate time series. This appears to be a batch of data." + x = x.reshape(-1) + + nx = x.shape[0] + x_matrix = x.reshape(1,-1,1) + kxx = torch.exp( + -torch.pow( + torch.cdist(x_matrix, x_matrix), + 2) + / self.y_sigma + ) + kxx = (kxx - torch.eye(nx)).sum() / (nx * (nx - 1)) + kxy = torch.exp( - torch.pow(torch.cdist(x_matrix, self.y_matrix), 2) / self.y_sigma ) + kxy = kxy.mean() + return kxx + self.kyy - 2 * kxy + +class SingleOutput_SimulateAndMMD: + + """ + Example implementation of a loss that simulates from the model and computes the MMD + between the model output and observed data y. (This treats the entries in y and in + the simulator output as exchangeable.) + + **Arguments** + + - `y`: torch.Tensor containing a single univariate time series. + - `model`: An instance of a Model. + - `gradient_horizon`: An integer or None. Sets horizon over which gradients are retained. If None, infinite horizon used. + """ + + def __init__( + self, + y: torch.Tensor, + model: Model, + gradient_horizon: int | None = None + ): + + self.mmd_loss = UnivariateMMDLoss(y) + self.model = model + self.gradient_horizon = gradient_horizon + + def __call__( + self, + theta: torch.Tensor, + y: torch.Tensor + ): + + x = simulate_and_observe_model(self.model, theta, self.gradient_horizon)[0] + return self.mmd_loss(x) diff --git a/blackbirds/models/normal.py b/blackbirds/models/normal.py new file mode 100644 index 0000000..500e77e --- /dev/null +++ b/blackbirds/models/normal.py @@ -0,0 +1,28 @@ +import torch + +from blackbirds.models.model import Model + +class Normal(Model): + + def __init__( + self, + n_timesteps: int, + ): + + super().__init__() + self.n_timesteps = n_timesteps + + def initialize(self, params): + return torch.zeros(1).reshape(1, 1) + + def trim_time_series(self, x): + return x[-1:] + + def step(self, params, x): + + mu, sigma = params + assert sigma > 0, "Argument sigma must be a float greater than 0." + return mu + sigma * torch.randn((1,1)) + + def observe(self, x): + return [x] diff --git a/test/models/test_normal.py b/test/models/test_normal.py new file mode 100644 index 0000000..c8ed101 --- /dev/null +++ b/test/models/test_normal.py @@ -0,0 +1,60 @@ +import torch +import torch.autograd as autograd +import numpy as np +from blackbirds.models.normal import Normal + + +class TestRandomWalk: + def test__result(self): + torch.manual_seed(0) + n_timesteps = 5000 + normal = Normal(n_timesteps) + assert normal.n_timesteps == n_timesteps + mu, sigma = 1., 3. + params = torch.Tensor([mu, sigma]) + x = normal.run(params) + assert x.shape == (n_timesteps + 1, 1) + trajectory = normal.observe(x)[0] + mean = trajectory.mean() + std = trajectory.std() + # Sample mean should approximately match mu + assert np.isclose( + mean, mu, atol=1e-1 + ) + # Sample std should approximately match sigma + assert np.isclose( + std, sigma, rtol=1e-1 + ) + # Should also have approximately 0. skewness + sample_skewness = torch.pow((trajectory - mean)/std, 3.).mean() + assert np.isclose( + sample_skewness, 0., atol=1e-1 + ) + + def test__gradient(self): + mu, sigma = 0.4, 1.5 + params = torch.Tensor([mu, sigma]) + params.requires_grad = True + n_timesteps = 100 + normal = Normal(n_timesteps) + assert normal.n_timesteps == n_timesteps + x = normal.observe(normal.run(params))[0] + assert x.shape == (n_timesteps + 1, 1) + x.sum().backward() + assert params.grad is not None + assert np.isclose(params.grad[0], n_timesteps) + grad_sigmas = [] + + def mean_x(params): + x = normal.observe(normal.run(params))[0] + return x.mean() + + for t in range(500): + params = torch.Tensor([mu, sigma]) + params.requires_grad = True + sigma_grad = autograd.grad(mean_x(params), params)[-1] + grad_sigmas.append(sigma_grad[-1]) + + assert np.isclose(torch.stack(grad_sigmas).mean(), 0., atol=1e-2) + expected_std = 1. / np.sqrt(n_timesteps) + assert np.isclose(torch.stack(grad_sigmas).std().item(), expected_std, atol=1e-2) diff --git a/test/test_losses.py b/test/test_losses.py new file mode 100644 index 0000000..793e103 --- /dev/null +++ b/test/test_losses.py @@ -0,0 +1,69 @@ +import numpy as np +import torch + +from blackbirds.losses import SingleOutput_SimulateAndMSELoss, SingleOutput_SimulateAndMMD +from blackbirds.models.normal import Normal +from blackbirds.models.random_walk import RandomWalk +from blackbirds.simulate import simulate_and_observe_model + +class TestMSELoss: + + def test_normal_same(self): + + T = 5000 + normal = Normal(n_timesteps=T) + mu, sigma = 1., 2. + params = torch.Tensor([mu, sigma]) + y = simulate_and_observe_model(normal, params)[0] + mse_loss = SingleOutput_SimulateAndMSELoss(normal) + loss_value = mse_loss(params, y) + # MMD between distributions of hould be close to... + value = 2 * sigma**2 + assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2) + + def test_normal_different(self): + + T = 5000 + normal = Normal(n_timesteps=T) + mu_y, sigma_y = 1., 2. + params = torch.Tensor([mu_y, sigma_y]) + y = simulate_and_observe_model(normal, params)[0] + mse_loss = SingleOutput_SimulateAndMSELoss(normal) + mu_x, sigma_x = 0., 1. + params = torch.Tensor([mu_x, sigma_x]) + loss_value = mse_loss(params, y) + # MMD between distributions of hould be close to... + value = sigma_x**2 + sigma_y**2 + (mu_x - mu_y)**2 + assert np.isclose(loss_value, value, atol=1e-2, rtol=5e-2) + + +class TestMMDLoss: + + def test_normal_same(self): + + T = 500 + normal = Normal(n_timesteps=T) + mu, sigma = 1., 2. + params = torch.Tensor([mu, sigma]) + y = simulate_and_observe_model(normal, params)[0] + mmd_loss = SingleOutput_SimulateAndMMD(y, normal) + loss_value = mmd_loss(params, y) + # MMD between distributions of hould be close to 0 + assert np.isclose(loss_value, 0., atol=1e-3) + + def test_normal_different(self): + + T = 500 + normal = Normal(n_timesteps=T) + mu, sigma = 1., 2. + params = torch.Tensor([mu, sigma]) + y = simulate_and_observe_model(normal, params)[0] + mmd_loss = SingleOutput_SimulateAndMMD(y, normal) + mu, sigma = 0., 1. + params = torch.Tensor([mu, sigma]) + loss_value = mmd_loss(params, y) + # MMD between distributions of hould be close to 0 + assert not np.isclose(loss_value, 0.) + assert loss_value > 0. + +