A collection of simple implementations of long-range sequence models, including LRU, S5, and S4. More implementations to come.
$ pip install long-range-models
Every module and layer in this library is documented in detail. Models are built by composing pieces: a sequence layer (such as LRULayer), an architecture module (such as S4Module), and a top-level model. Check out the examples below.
Consider a language model built with an LRU sequence layer and the architecture proposed in the S4 paper:
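from functools import partial
import jax.random as jrandom
from long_range_models import SequenceModel, S4Module, LRULayer
# Use separate keys for the dummy input and for parameter initialization.
data_rng, init_rng = jrandom.split(jrandom.PRNGKey(0))
# S4-style architecture with LRU as its sequence layer.
model = SequenceModel(
    num_tokens=1000,
    module=S4Module(
        sequence_layer=partial(LRULayer, state_dim=256),
        dim=128,
        depth=6,
    ),
)
x = jrandom.randint(data_rng, (1, 1024), 0, 1000)  # token ids, shape (1, 1024)
variables = model.init(init_rng, x)
model.apply(variables, x)  # output shape: (1, 1024, 1000)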
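As a rough sketch of how the output above might be used for training, the snippet below computes a next-token cross-entropy loss. It assumes the `(1, 1024, 1000)` output holds unnormalized per-position logits over the vocabulary, which this README does not state explicitly. The helper name `next_token_loss` is hypothetical and not part of this library, and random stand-in logits are used so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp


def next_token_loss(logits, tokens):
    """Mean cross-entropy of predicting tokens[:, t + 1] from logits[:, t]."""
    # The last position has no next token, and the first token is never a target.
    log_probs = jax.nn.log_softmax(logits[:, :-1], axis=-1)  # (batch, seq - 1, vocab)
    targets = tokens[:, 1:]  # (batch, seq - 1)
    # Pick the log-probability assigned to each target token.
    picked = jnp.take_along_axis(log_probs, targets[..., None], axis=-1)
    return -picked.mean()


# Stand-in data with the shapes used above (assumption: swap `logits`
# for the result of model.apply(variables, x) in real use).
key_tokens, key_logits = jax.random.split(jax.random.PRNGKey(0))
tokens = jax.random.randint(key_tokens, (1, 1024), 0, 1000)
logits = jax.random.normal(key_logits, (1, 1024, 1000))
loss = next_token_loss(logits, tokens)
```

With uniform logits this loss equals the log of the vocabulary size, which is a handy sanity check before training.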
For sequences with continuous values, the setup looks as follows:
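from functools import partial
import jax.random as jrandom
from long_range_models import ContinuousSequenceModel, S4Module, LRULayer
# Use separate keys for the dummy input and for parameter initialization.
data_rng, init_rng = jrandom.split(jrandom.PRNGKey(0))
# Same architecture as before, but with a continuous-valued input and a 10-dimensional output.
model = ContinuousSequenceModel(
    out_dim=10,
    module=S4Module(
        sequence_layer=partial(LRULayer, state_dim=256),
        dim=128,
        depth=6,
    ),
)
x = jrandom.normal(data_rng, (1, 1024, 32))  # 32 features per step, shape (1, 1024, 32)
variables = model.init(init_rng, x)
model.apply(variables, x)  # output shape: (1, 1024, 10)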
Note: both model types offer several customization options. Make sure to check out their documentation.
Planned additions:
- More implementations: Extend the library with models like S4D, S4Liquid, BiGS, Hyena, RetNet, SGConv, H3, and others.
- Customization: Give users more control over the currently implemented layers and architectures (e.g., activation functions and initialization).
- Sequential API: Let recurrent models run one step at a time, which makes inference efficient.
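To illustrate what the sequential item refers to, here is a minimal sketch of a diagonal linear recurrence `h_t = lam * h_{t-1} + b_t`, the kind of state update used by LRU-style layers, evaluated in two ways: step by step, and in parallel with an associative scan. The function names `run_sequential` and `run_parallel` are hypothetical, and this is not this library's API or implementation, only a sketch under those assumptions.

```python
import jax
import jax.numpy as jnp


def run_sequential(lam, b):
    """Evaluate h_t = lam * h_{t-1} + b_t one step at a time, starting from zero."""
    def step(h, b_t):
        h = lam * h + b_t
        return h, h  # carry the new state and also emit it

    h0 = jnp.zeros_like(b[0])
    _, hs = jax.lax.scan(step, h0, b)
    return hs


def run_parallel(lam, b):
    """Evaluate the same recurrence with an associative scan over the time axis."""
    def combine(left, right):
        a_l, b_l = left
        a_r, b_r = right
        # Compose h -> a_l * h + b_l followed by h -> a_r * h + b_r.
        return a_r * a_l, a_r * b_l + b_r

    a = jnp.broadcast_to(lam, b.shape)
    _, hs = jax.lax.associative_scan(combine, (a, b))
    return hs


# Toy sizes chosen for illustration: 16 time steps, 4 complex state channels.
key_re, key_im = jax.random.split(jax.random.PRNGKey(0))
lam = 0.9 * jnp.exp(1j * jnp.linspace(0.0, 1.0, 4))  # stable: |lam| < 1
b = jax.random.normal(key_re, (16, 4)) + 1j * jax.random.normal(key_im, (16, 4))

hs_seq = run_sequential(lam, b)
hs_par = run_parallel(lam, b)
```

Both functions return the same states up to floating-point error. The parallel form suits training on whole sequences, while the step function only needs the previous state, which is what makes token-by-token inference cheap.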