Commit bd1a58e

Fix broken paths from standalone model reorganization
1 parent dc9e72e commit bd1a58e

File tree: 5 files changed (+9 additions, -22 deletions)

CHANGELOG.md

Lines changed: 2 additions & 7 deletions

```diff
@@ -54,14 +54,9 @@ To address differences between models trained on earlier versions and the current
 #### Models and Features
 - Updated version of S4 module, including new measures and theory from [[How to Train Your HiPPO](https://arxiv.org/abs/2206.12037)] (https://github.com/HazyResearch/state-spaces/issues/21, https://github.com/HazyResearch/state-spaces/issues/54)
 - Complete version of S4D module from [[On the Parameterization and Initialization of Diagonal State Space Models](https://arxiv.org/abs/2206.11893)]
-- [State forwarding](src/models/s4/README.md#state-forwarding) (https://github.com/HazyResearch/state-spaces/issues/49, https://github.com/HazyResearch/state-spaces/issues/56)
-- Support for S4 variants including DSS and GSS ([documentation](src/models/s4/README.md#other-variants))
+- [State forwarding](models/s4/README.md#state-forwarding) (https://github.com/HazyResearch/state-spaces/issues/49, https://github.com/HazyResearch/state-spaces/issues/56)
+- Support for S4 variants including DSS and GSS ([documentation](models/s4/README.md#other-variants))
 
-<!--
-#### Compilation of additional resources
-- Recommended resources for understanding S4-style models, including the [Simplifying S4 blog](https://hazyresearch.stanford.edu/blog/2022-06-11-simplifying-s4) ([code](https://github.com/HazyResearch/state-spaces/tree/simple/src/models/sequence/ss/s4_simple)) and a minimal pedagogical version of S4D ([code](src/models/s4/s4d.py))
-- Tips & Tricks page for getting started with tuning S4
--->
 
 #### Bug fixes and library compatibility issues
 - PyTorch 1.11 had a [Dropout bug](https://github.com/pytorch/pytorch/issues/77081) which is now avoided with a custom Dropout implementation (https://github.com/HazyResearch/state-spaces/issues/42, https://github.com/HazyResearch/state-spaces/issues/22)
```

README.md

Lines changed: 2 additions & 2 deletions

```diff
@@ -68,7 +68,7 @@ Installation usually works out of the box with `pip install pykeops cmake` which
 
 ### S4 Module
 
-Self-contained files for the S4 layer and variants can be found in [src/models/s4/](./src/models/s4/),
+Self-contained files for the S4 layer and variants can be found in [models/s4/](./models/s4/),
 which includes instructions for calling the module.
 
 See [notebooks/](notebooks/) for visualizations explaining some concepts behind HiPPO and S4.
@@ -97,7 +97,7 @@ One important feature of this codebase is supporting parameters that require dif
 In particular, the SSM kernel is particularly sensitive to the $(A, B)$ (and sometimes $\Delta$ parameters),
 so the learning rate on these parameters is sometimes lowered and the weight decay is always set to $0$.
 
-See the method `register` in the model (e.g. [s4d.py](src/models/s4/s4d.py)) and the function `setup_optimizer` in the training script (e.g. [example.py](example.py)) for an examples of how to implement this in external repos.
+See the method `register` in the model (e.g. [s4d.py](models/s4/s4d.py)) and the function `setup_optimizer` in the training script (e.g. [example.py](example.py)) for an examples of how to implement this in external repos.
 
 <!--
 Our logic for setting these parameters can be found in the `OptimModule` class under `src/models/sequence/ss/kernel.py` and the corresponding optimizer hook in `SequenceLightningModule.configure_optimizers` under `train.py`
```
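The per-parameter scheme described above (a custom learning rate and zero weight decay on the SSM kernel parameters, turned into optimizer parameter groups) can be sketched in plain PyTorch. This is an illustrative reimplementation, not the repo's actual `register`/`setup_optimizer` code; the `_optim` attribute, `TinySSM`, and `setup_optimizer_sketch` names are assumptions:

```python
import torch
import torch.nn as nn

class TinySSM(nn.Module):
    """Toy module standing in for an S4-style layer (illustrative only)."""
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(4, 4)  # ordinary parameters
        # Mimic the repo's register(): tag the kernel parameter A with
        # its own learning rate and zero weight decay.
        A = nn.Parameter(torch.randn(4))
        A._optim = {"lr": 1e-3, "weight_decay": 0.0}
        self.A = A

def setup_optimizer_sketch(model, lr=1e-2, weight_decay=0.01):
    """Put tagged parameters into their own optimizer param groups."""
    normal = [p for p in model.parameters() if not hasattr(p, "_optim")]
    special = [p for p in model.parameters() if hasattr(p, "_optim")]
    groups = [{"params": normal}]
    groups += [{"params": [p], **p._optim} for p in special]
    return torch.optim.AdamW(groups, lr=lr, weight_decay=weight_decay)
```

Each tagged parameter lands in its own group, and a group's `lr`/`weight_decay` keys override the optimizer-wide defaults.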

example.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -3,7 +3,7 @@
 This code borrows heavily from https://github.com/kuangliu/pytorch-cifar.
 
 This file only depends on the standalone S4 layer
-available in src/models/s4/
+available in /models/s4/
 
 * Train standard sequential CIFAR:
 python -m example
```

models/sashimi/sashimi.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -2,7 +2,7 @@
 SaShiMi backbone.
 
 Use this backbone in your own models. You'll also need to copy over the
-standalone S4 layer, which can be found at `state-spaces/src/models/s4/`
+standalone S4 layer, which can be found at `state-spaces/models/s4/`
 
 It's Raw! Audio Generation with State-Space Models
 Karan Goel, Albert Gu, Chris Donahue, Christopher Re.
@@ -16,7 +16,7 @@
 
 from einops import rearrange
 
-from src.models.s4.s4 import LinearActivation, S4
+from models.s4.s4 import LinearActivation, S4Block as S4
 
 class DownPool(nn.Module):
     def __init__(self, d_input, expand, pool):
```
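The import change above points at the reorganized layout and aliases `S4Block` back to the old `S4` name. Code that must survive both layouts can probe the candidate module paths at runtime; the `import_s4` helper below is a hypothetical sketch of that idea, not part of the repo:

```python
import importlib

def import_s4(candidates=("models.s4.s4", "src.models.s4.s4")):
    """Return the standalone S4 layer class from the first importable path."""
    for name in candidates:
        try:
            module = importlib.import_module(name)
        except ImportError:
            continue
        # Newer standalone files export S4Block; older ones exported S4.
        cls = getattr(module, "S4Block", None) or getattr(module, "S4", None)
        if cls is not None:
            return cls
    raise ImportError("standalone S4 layer not found in: %s" % (candidates,))
```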

src/utils/registry.py

Lines changed: 2 additions & 10 deletions

```diff
@@ -36,7 +36,7 @@
     "model": "src.models.sequence.backbones.model.SequenceModel",
     "unet": "src.models.sequence.backbones.unet.SequenceUNet",
     "sashimi": "src.models.sequence.backbones.sashimi.Sashimi",
-    "sashimi_standalone": "sashimi.sashimi.Sashimi",
+    "sashimi_standalone": "models.sashimi.sashimi.Sashimi",
     # Baseline RNNs
     "lstm": "src.models.baselines.lstm.TorchLSTM",
     "gru": "src.models.baselines.gru.TorchGRU",
@@ -46,12 +46,8 @@
     "stackedrnn": "src.models.baselines.samplernn.StackedRNN",
     "stackedrnn_baseline": "src.models.baselines.samplernn.StackedRNNBaseline",
     "samplernn": "src.models.baselines.samplernn.SampleRNN",
-    "dcgru": "src.models.baselines.dcgru.DCRNNModel_classification",
-    "dcgru_ss": "src.models.baselines.dcgru.DCRNNModel_nextTimePred",
     # Baseline CNNs
     "ckconv": "src.models.baselines.ckconv.ClassificationCKCNN",
-    "wavegan": "src.models.baselines.wavegan.WaveGANDiscriminator", # DEPRECATED
-    "denseinception": "src.models.baselines.dense_inception.DenseInception",
     "wavenet": "src.models.baselines.wavenet.WaveNetModel",
     "torch/resnet2d": "src.models.baselines.resnet.TorchVisionResnet", # 2D ResNet
     # Nonaka 1D CNN baselines
@@ -69,15 +65,13 @@
     "timm/convnext_micro": "src.models.baselines.convnext_timm.convnext_micro",
     "timm/resnet50": "src.models.baselines.resnet_timm.resnet50", # Can also register many other variants in resnet_timm
     "timm/convnext_tiny_3d": "src.models.baselines.convnext_timm.convnext3d_tiny",
-    # Segmentation models
-    "convnext_unet_tiny": "src.models.segmentation.convnext_unet.convnext_tiny_unet",
 }
 
 layer = {
     "id": "src.models.sequence.base.SequenceIdentity",
     "lstm": "src.models.baselines.lstm.TorchLSTM",
     "standalone": "models.s4.s4.S4Block",
-    "s4d": "src.models.s4.s4d.S4D",
+    "s4d": "models.s4.s4d.S4D",
     "ffn": "src.models.sequence.modules.ffn.FFN",
     "sru": "src.models.sequence.rnns.sru.SRURNN",
     "rnn": "src.models.sequence.rnns.rnn.RNN", # General RNN wrapper
@@ -90,8 +84,6 @@
     "s4": "src.models.sequence.modules.s4block.S4Block",
     "s4nd": "src.models.sequence.modules.s4nd.S4ND",
     "mega": "src.models.sequence.modules.mega.MegaBlock",
-    "h3": "src.models.sequence.experimental.h3.H3",
-    "h4": "src.models.sequence.experimental.h4.H4",
     # 'packedrnn': 'models.sequence.rnns.packedrnn.PackedRNN',
 }
```
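The registry edited above maps short names to dotted `module.ClassName` strings that are imported lazily, which is why the string paths had to change along with the directory layout. A minimal, self-contained sketch of that lookup (the `instantiate` helper and the stdlib demo target are illustrative, not the repo's actual utility):

```python
import importlib

# Hypothetical slice of a registry: short names -> "module.path.ClassName"
registry = {
    "sashimi_standalone": "models.sashimi.sashimi.Sashimi",
    "s4d": "models.s4.s4d.S4D",
}

def instantiate(path, *args, **kwargs):
    """Import the module named by `path`, fetch the class, and construct it."""
    module_name, cls_name = path.rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(*args, **kwargs)

# Demo with a stdlib class, since the repo's modules aren't importable here:
od = instantiate("collections.OrderedDict")
print(type(od).__name__)  # OrderedDict
```

A stale string in the registry only fails at `instantiate` time, which is why a path reorganization like this commit's has to touch the registry as well as the imports.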