Commit

hyper connect the audio attention models
lucidrains committed Jan 12, 2025
1 parent 2154a74 commit d65fd15
Showing 4 changed files with 37 additions and 9 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -599,3 +599,14 @@ $ accelerate launch train.py
url = {https://api.semanticscholar.org/CorpusID:275405495}
}
```

```bibtex
@article{Zhu2024HyperConnections,
title = {Hyper-Connections},
author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
journal = {ArXiv},
year = {2024},
volume = {abs/2409.19606},
url = {https://api.semanticscholar.org/CorpusID:272987528}
}
```
32 changes: 24 additions & 8 deletions audiolm_pytorch/audiolm_pytorch.py
@@ -21,6 +21,8 @@

from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME

from hyper_connections import get_init_and_expand_reduce_stream_functions

from torchaudio.functional import resample

from audiolm_pytorch.soundstream import SoundStream
@@ -421,6 +423,7 @@ def __init__(
rel_pos_bias = True,
flash_attn = False,
add_value_residual = True,
num_residual_streams = 4,
**kwargs
):
super().__init__()
@@ -438,11 +441,17 @@ def __init__(

self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None

# hyper connections

init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

# layers

for _ in range(depth):
self.layers.append(nn.ModuleList([
Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs),
Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
FeedForward(dim = dim, dropout = ff_dropout)
init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs)),
init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs)) if cross_attend else None,
init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, dropout = ff_dropout))
]))

self.norm = LayerNorm(dim)
@@ -510,6 +519,10 @@ def forward(
self_attn_value_residual = None
cross_attn_value_residual = None

# expand residual streams

x = self.expand_streams(x)

# transformer layers

for attn, cross_attn, ff in self.layers:
@@ -523,18 +536,21 @@

new_kv_cache.append(layer_kv_cache)

x = x + residual

if exists(cross_attn):
assert exists(context)

cross_attend_out, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
x = cross_attend_out + x
x, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)

if self.add_value_residual:
cross_attn_value_residual = default(cross_attn_value_residual, values)

x = ff(x) + x
x = ff(x)

# reduce residual streams

x = self.reduce_streams(x)

# final norm

x = self.norm(x)

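Note: the wiring above follows the usage pattern of the `hyper-connections` package as seen in this diff: wrap every residual branch with `init_hyper_conn`, expand the residual into `num_residual_streams` streams before the layer stack, and reduce back to a single stream afterwards. Below is a minimal, self-contained sketch of that pattern on a toy feedforward-only stack. The `SimpleFeedForward` and `TinyStack` modules and their hyperparameters are illustrative placeholders, not code from this repository, and the package is assumed to behave as the calls in the diff suggest.

```python
import torch
from torch import nn

from hyper_connections import get_init_and_expand_reduce_stream_functions

class SimpleFeedForward(nn.Module):
    # placeholder branch module, standing in for the repo's FeedForward / Attention blocks
    def __init__(self, dim, mult = 4):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, dim * mult),
            nn.GELU(),
            nn.Linear(dim * mult, dim)
        )

    def forward(self, x):
        return self.net(x)

class TinyStack(nn.Module):
    def __init__(self, dim, depth, num_residual_streams = 4):
        super().__init__()

        # one wrapper factory for the branches, plus the expand / reduce stream functions,
        # disabled when only a single residual stream is requested (as in the diff)
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(
            num_residual_streams,
            disable = num_residual_streams == 1
        )

        # each residual branch is wrapped, so the wrapper owns the residual update
        self.layers = nn.ModuleList([
            init_hyper_conn(dim = dim, branch = SimpleFeedForward(dim))
            for _ in range(depth)
        ])

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # expand the single residual into multiple streams before the layer stack
        x = self.expand_streams(x)

        for layer in self.layers:
            # no explicit `x = branch(x) + x` here - the hyper-connection wrapper handles it
            x = layer(x)

        # collapse the streams back into one before the final norm
        x = self.reduce_streams(x)
        return self.norm(x)

model = TinyStack(dim = 64, depth = 2)
out = model(torch.randn(1, 128, 64))    # (batch, seq_len, dim) -> same shape
```

With a single stream the wrappers fall back to ordinary residual connections, matching the `disable = num_residual_streams == 1` guard this commit adds.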
2 changes: 1 addition & 1 deletion audiolm_pytorch/version.py
@@ -1 +1 @@
__version__ = '2.3.1'
__version__ = '2.4.0'
1 change: 1 addition & 0 deletions setup.py
@@ -27,6 +27,7 @@
'fairseq',
'wandb',
'gateloop-transformer>=0.2.3',
'hyper-connections>=0.1.8',
'joblib',
'local-attention>=1.9.0',
'pytorch-warmup',
