A Python package based on JAX for linear and nonlinear system identification of state-space models, recurrent neural network (RNN) training, and nonlinear regression.
- Linear state-space models
- Nonlinear system identification and RNNs
- [Static models and nonlinear regression](#static)
jax-sysid is a Python package based on JAX for linear and nonlinear system identification of state-space models, recurrent neural network (RNN) training, and nonlinear regression. The algorithm can handle L1-regularization and group-Lasso regularization and relies on L-BFGS optimization for accurate modeling, fast convergence, and good sparsification of model coefficients.
The package implements the approach described in the following paper:
[1] A. Bemporad, "Linear and nonlinear system identification under $\ell_1$- and group-Lasso regularization via L-BFGS-B," submitted for publication, 2024. Available on arXiv at http://arxiv.org/abs/2403.03827 (see the BibTeX entry below).
```
pip install jax-sysid
```
Given input/output training data $(u_0,y_0),\ldots,(u_{N-1},y_{N-1})$, $u_k\in\mathbb{R}^{n_u}$, $y_k\in\mathbb{R}^{n_y}$, we want to identify a state-space model in the form

$$x_{k+1}=Ax_k+Bu_k,\qquad \hat y_k=Cx_k+Du_k$$

where $k$ denotes the sample instant, $x_k\in\mathbb{R}^{n_x}$ is the vector of hidden states, and $A,B,C,D$ are matrices of appropriate dimensions to be learned.

The training problem to solve is

$$\min_{z}\ r(z)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-Cx_k-Du_k\|_2^2\qquad \mathrm{s.t.}\ \ x_{k+1}=Ax_k+Bu_k,\ \ k=0,\ldots,N-2$$

where $z=(\theta,x_0)$ and $\theta$ collects the entries of $A,B,C,D$.

The regularization term $r(z)$ combines $\ell_2$-penalties on $x_0$ and $\theta$, an optional $\ell_1$-penalty promoting sparsity of $\theta$, and an optional group-Lasso penalty:

$$r(z)=\rho_{x_0}\|x_0\|_2^2+\rho_{\theta}\|\theta\|_2^2+\tau_{\theta}\|\theta\|_1+\tau_g\sum_{i}\|I_iz\|_2$$

with $\rho_{x_0},\rho_{\theta}>0$ and $\tau_{\theta},\tau_g\geq 0$, where $I_i$ selects a predefined group of entries of $z$. The coefficients $\rho_{x_0},\rho_{\theta},\tau_{\theta},\tau_g$ correspond to the arguments `rho_x0`, `rho_th`, `tau_th`, `tau_g` of `model.loss` used below.
Let's start by training a discrete-time linear model $x_{k+1}=Ax_k+Bu_k$, $\hat y_k=Cx_k+Du_k$ on a sequence of inputs $U$ and outputs $Y$:

```python
from jax_sysid.models import LinearModel

model = LinearModel(nx, ny, nu)  # nx = number of states, ny = number of outputs, nu = number of inputs
model.loss(rho_x0=1.e-3, rho_th=1.e-2)
model.optimization(lbfgs_epochs=1000)
model.fit(Y, U)
Yhat, Xhat = model.predict(model.x0, U)
```
After identifying the model, to retrieve the resulting state-space realization you can use the following:
```python
A, B, C, D = model.ssdata()
```
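For instance, a quick sanity check of the identified dynamics is to look at the eigenvalues of $A$; the sketch below only uses the matrices returned by `ssdata()` and standard NumPy:

```python
import numpy as np

A, B, C, D = model.ssdata()
# for an asymptotically stable discrete-time model, all eigenvalues of A lie inside the unit circle
print("spectral radius of A:", np.max(np.abs(np.linalg.eigvals(np.asarray(A)))))
```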
Given a new test sequence of inputs and outputs, an initial state that is compatible with the identified model can be reconstructed by running an extended Kalman filter and Rauch–Tung–Striebel smoothing (cf. [1]) and used to simulate the model:
```python
x0_test = model.learn_x0(U_test, Y_test)
Yhat_test, Xhat_test = model.predict(x0_test, U_test)
```
R2-scores on training and test data can be computed as follows:
```python
from jax_sysid.utils import compute_scores

R2_train, R2_test, msg = compute_scores(Y, Yhat, Y_test, Yhat_test, fit='R2')
print(msg)
```
It is good practice to scale the input and output signals. To identify a model on scaled signals, you can use the following:
```python
from jax_sysid.utils import standard_scale, unscale

Ys, ymean, ygain = standard_scale(Y)
Us, umean, ugain = standard_scale(U)
model.fit(Ys, Us)
Yshat, Xhat = model.predict(model.x0, Us)
Yhat = unscale(Yshat, ymean, ygain)
```
Let us now retrain the model using L1-regularization and check the sparsity of the resulting model:
```python
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_th=0.03)
model.fit(Ys, Us)
print(model.sparsity_analysis())
```
To reduce the number of states in the model, you can use group-Lasso regularization as follows:
```python
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_g=0.1)
model.group_lasso_x()
model.fit(Ys, Us)
```
Groups in this case are entries in A,B,C,x0 related to the same state.
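To get a rough idea of how many states were effectively removed after fitting, you can inspect the state-related entries of the identified matrices. A minimal sketch (the `1e-8` threshold is an arbitrary choice, not a package setting):

```python
import numpy as np

A, B, C, D = (np.asarray(M) for M in model.ssdata())
# state i is effectively removed when its rows/columns in A, B, C are all (close to) zero
state_norms = np.array([np.linalg.norm(A[i, :]) + np.linalg.norm(A[:, i])
                        + np.linalg.norm(B[i, :]) + np.linalg.norm(C[:, i])
                        for i in range(A.shape[0])])
print("active states:", int(np.sum(state_norms > 1e-8)))
```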
Group-Lasso can also be used to try to reduce the number of inputs that are relevant in the model. You can do this as follows:
```python
model.loss(rho_x0=1.e-3, rho_th=1.e-2, tau_g=0.15)
model.group_lasso_u()
model.fit(Ys, Us)
```
Groups in this case are entries in B,D related to the same input.
jax-sysid also supports multiple training experiments. In this case, the sequences of training inputs and outputs are passed as a list of arrays. For example, if three experiments are available for training, use the following command:
```python
model.fit([Ys1, Ys2, Ys3], [Us1, Us2, Us3])
```
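When the experiments are scaled, it is usually preferable to compute a single mean/gain over all experiments rather than scaling each sequence independently. A minimal sketch, where the raw sequences `Y1, U1, ..., Y3, U3` and the concatenate-and-split logic are just an illustration, not a package feature:

```python
import numpy as np
from jax_sysid.utils import standard_scale

# scale all experiments with a common mean/gain, then split back into sequences
Y_all = np.vstack((Y1, Y2, Y3))
U_all = np.vstack((U1, U2, U3))
Ys_all, ymean, ygain = standard_scale(Y_all)
Us_all, umean, ugain = standard_scale(U_all)
N1, N2 = Y1.shape[0], Y2.shape[0]
Ys1, Ys2, Ys3 = Ys_all[:N1], Ys_all[N1:N1 + N2], Ys_all[N1 + N2:]
Us1, Us2, Us3 = Us_all[:N1], Us_all[N1:N1 + N2], Us_all[N1 + N2:]

model.fit([Ys1, Ys2, Ys3], [Us1, Us2, Us3])
```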
In case the initial state $x_0$ should not be treated as an optimization variable, pass `train_x0=False` when calling `model.loss`.
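For example:

```python
# keep the initial state fixed (not optimized) while training the remaining parameters
model.loss(rho_x0=1.e-3, rho_th=1.e-2, train_x0=False)
model.fit(Ys, Us)
```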
Given input/output training data $(u_0,y_0),\ldots,(u_{N-1},y_{N-1})$, $u_k\in\mathbb{R}^{n_u}$, $y_k\in\mathbb{R}^{n_y}$, we want to identify a nonlinear parametric state-space model in the form

$$x_{k+1}=f(x_k,u_k,\theta),\qquad \hat y_k=g(x_k,u_k,\theta)$$

where $x_k\in\mathbb{R}^{n_x}$ is the vector of hidden states and $\theta$ collects the trainable model parameters.

As for the linear case, the training problem to solve is

$$\min_{z}\ r(z)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-g(x_k,u_k,\theta)\|_2^2\qquad \mathrm{s.t.}\ \ x_{k+1}=f(x_k,u_k,\theta),\ \ k=0,\ldots,N-2$$

where $z=(\theta,x_0)$ and $r(z)$ is the same regularization term used for linear models.

For example, let us consider the following residual RNN model without input/output feedthrough:

$$x_{k+1}=Ax_k+Bu_k+f_x(x_k,u_k,\theta),\qquad \hat y_k=Cx_k+f_y(x_k,\theta)$$

where $f_x$ and $f_y$ are shallow feedforward neural networks:
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax_sysid.models import Model
from jax_sysid.utils import standard_scale, unscale

Ys, ymean, ygain = standard_scale(Y)
Us, umean, ugain = standard_scale(U)

def sigmoid(x):
    return 1. / (1. + jnp.exp(-x))

@jax.jit
def state_fcn(x, u, params):
    A, B, C, W1, W2, W3, b1, b2, W4, W5, b3, b4 = params
    return A @ x + B @ u + W3 @ sigmoid(W1 @ x + W2 @ u + b1) + b2

@jax.jit
def output_fcn(x, u, params):
    A, B, C, W1, W2, W3, b1, b2, W4, W5, b3, b4 = params
    return C @ x + W5 @ sigmoid(W4 @ x + b3) + b4

model = Model(nx, ny, nu, state_fcn=state_fcn, output_fcn=output_fcn)

nnx = 5  # number of hidden neurons in state-update function
nny = 5  # number of hidden neurons in output function

# Parameter initialization:
A = 0.5 * np.eye(nx)
B = 0.1 * np.random.randn(nx, nu)
C = 0.1 * np.random.randn(ny, nx)
W1 = 0.1 * np.random.randn(nnx, nx)
W2 = 0.5 * np.random.randn(nnx, nu)
W3 = 0.5 * np.random.randn(nx, nnx)
b1 = np.zeros(nnx)
b2 = np.zeros(nx)
W4 = 0.5 * np.random.randn(nny, nx)
W5 = 0.5 * np.random.randn(ny, nny)
b3 = np.zeros(nny)
b4 = np.zeros(ny)
model.init(params=[A, B, C, W1, W2, W3, b1, b2, W4, W5, b3, b4])

model.loss(rho_x0=1.e-4, rho_th=1.e-4)
model.optimization(adam_epochs=1000, lbfgs_epochs=1000)
model.fit(Ys, Us)
Yshat, Xshat = model.predict(model.x0, Us)
Yhat = unscale(Yshat, ymean, ygain)
```
jax-sysid also supports recurrent neural networks defined via the flax.linen library:
```python
from jax_sysid.models import RNN
from flax import linen as nn

# state-update function
class FX(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=5)(x)
        x = nn.swish(x)
        x = nn.Dense(features=5)(x)
        x = nn.swish(x)
        x = nn.Dense(features=nx)(x)
        return x

# output function
class FY(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=5)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=ny)(x)
        return x

model = RNN(nx, ny, nu, FX=FX, FY=FY, x_scaling=0.1)
model.loss(rho_x0=1.e-4, rho_th=1.e-4, tau_th=0.0001)
model.optimization(adam_epochs=0, lbfgs_epochs=2000)
model.fit(Ys, Us)
```
where the extra parameter `x_scaling` is used to scale down (when `x_scaling` < 1) the default initialization of the network weights performed by flax.
jax-sysid also supports custom loss functions penalizing the deviations of $\hat y$ from $y$. For example, when the outputs are binary, $y_k\in\{0,1\}$, the model can be trained with the cross-entropy loss

$$\mathcal{L}(\hat Y,Y)=\frac{1}{N}\sum_{k=0}^{N-1}\left(-y_k\log(\epsilon+\hat y_k)-(1-y_k)\log(\epsilon+1-\hat y_k)\right)$$

where $\epsilon>0$ is a small constant preventing the logarithm from diverging:
```python
epsil = 1.e-4  # small constant to avoid log(0)

@jax.jit
def cross_entropy_loss(Yhat, Y):
    loss = jnp.sum(-Y*jnp.log(epsil+Yhat) - (1.-Y)*jnp.log(epsil+1.-Yhat)) / Y.shape[0]
    return loss

model.loss(rho_x0=0.01, rho_th=0.001, output_loss=cross_entropy_loss)
```
By default, jax-sysid minimizes the classical mean squared error $\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-\hat y_k\|_2^2$ of the output predictions.
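For reference, here is a sketch of how that default criterion could be written explicitly as a custom `output_loss` (equivalent in form to the built-in behavior):

```python
import jax
import jax.numpy as jnp

@jax.jit
def mse_loss(Yhat, Y):
    # average squared output error over the N training samples
    return jnp.sum((Yhat - Y) ** 2) / Y.shape[0]

model.loss(rho_x0=0.01, rho_th=0.001, output_loss=mse_loss)
```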
The same optimization algorithms used to train dynamical models can be used to train static models $\hat y=f(u,\theta)$, i.e., to solve the nonlinear regression problem

$$\min_{\theta}\ r(\theta)+\frac{1}{N}\sum_{k=0}^{N-1}\|y_k-f(u_k,\theta)\|_2^2$$

where $\theta$ collects the trainable parameters of the model and $r(\theta)$ contains the same $\ell_2$-, $\ell_1$-, and group-Lasso regularization terms described above.
For example, if the model is a shallow neural network you can use the following code:
```python
import numpy as np
import jax
import jax.numpy as jnp
from jax_sysid.models import StaticModel
from jax_sysid.utils import standard_scale, unscale

@jax.jit
def output_fcn(u, params):
    W1, b1, W2, b2 = params
    y = W1 @ u.T + b1
    y = W2 @ jnp.arctan(y) + b2
    return y.T

model = StaticModel(ny, nu, output_fcn)
nn = 10  # number of neurons
model.init(params=[np.random.randn(nn, nu), np.random.randn(nn, 1),
                   np.random.randn(1, nn), np.random.randn(1, 1)])

tau_th = 1.e-4  # L1-regularization coefficient
model.loss(rho_th=1.e-4, tau_th=tau_th)
model.optimization(lbfgs_epochs=500)
model.fit(Ys, Us)
```
jax-sysid also supports feedforward neural networks defined via the flax.linen library:
```python
from jax_sysid.models import FNN
from flax import linen as nn

# output function
class FY(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(features=20)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=20)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=ny)(x)
        return x

model = FNN(ny, nu, FY)
model.loss(rho_th=1.e-4, tau_th=1.e-4)
model.optimization(lbfgs_epochs=500)
model.fit(Ys, Us)
```
This package was coded by Alberto Bemporad.
This software is distributed without any warranty. Please cite the paper below if you use this software.
```bibtex
@article{Bem24,
    author={A. Bemporad},
    title={Linear and nonlinear system identification under $\ell_1$- and group-{Lasso} regularization via {L-BFGS-B}},
    note={Submitted for publication. Also available on arXiv at \url{http://arxiv.org/abs/2403.03827}},
    year=2024
}
```
Apache 2.0
(C) 2024 A. Bemporad