From d65fd151075d304bc7b28dfcb736e179b6020fca Mon Sep 17 00:00:00 2001
From: lucidrains
Date: Sun, 12 Jan 2025 07:40:38 -0800
Subject: [PATCH] hyper connect the audio attention models

---
 README.md                          | 11 ++++++++++
 audiolm_pytorch/audiolm_pytorch.py | 32 ++++++++++++++++++++++--------
 audiolm_pytorch/version.py         |  2 +-
 setup.py                           |  1 +
 4 files changed, 37 insertions(+), 9 deletions(-)

diff --git a/README.md b/README.md
index 3546faf..ded8219 100644
--- a/README.md
+++ b/README.md
@@ -599,3 +599,14 @@ $ accelerate launch train.py
     url     = {https://api.semanticscholar.org/CorpusID:275405495}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title   = {Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2409.19606},
+    url     = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
diff --git a/audiolm_pytorch/audiolm_pytorch.py b/audiolm_pytorch/audiolm_pytorch.py
index 7f0f462..57210b2 100644
--- a/audiolm_pytorch/audiolm_pytorch.py
+++ b/audiolm_pytorch/audiolm_pytorch.py
@@ -21,6 +21,8 @@
 
 from audiolm_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 from torchaudio.functional import resample
 
 from audiolm_pytorch.soundstream import SoundStream
@@ -421,6 +423,7 @@ def __init__(
         rel_pos_bias = True,
         flash_attn = False,
         add_value_residual = True,
+        num_residual_streams = 4,
         **kwargs
     ):
         super().__init__()
@@ -438,11 +441,17 @@ def __init__(
 
         self.rel_pos_bias = RelativePositionBias(dim = dim // 2, heads = heads) if rel_pos_bias else None
 
+        # hyper connections
+
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
+        # layers
+
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs),
-                Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs) if cross_attend else None,
-                FeedForward(dim = dim, dropout = ff_dropout)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, flash = flash_attn, causal = True, **kwargs)),
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, heads = heads, dropout = attn_dropout, dim_context = dim_context, flash = flash_attn, num_null_kv = 1, norm_context = True, **kwargs)) if cross_attend else None,
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, dropout = ff_dropout))
             ]))
 
         self.norm = LayerNorm(dim)
@@ -510,6 +519,10 @@ def forward(
         self_attn_value_residual = None
         cross_attn_value_residual = None
 
+        # expand residual streams
+
+        x = self.expand_streams(x)
+
         # transformer layers
 
         for attn, cross_attn, ff in self.layers:
@@ -523,18 +536,21 @@ def forward(
 
             new_kv_cache.append(layer_kv_cache)
 
-            x = x + residual
-
             if exists(cross_attn):
                 assert exists(context)
-                cross_attend_out, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
-                x = cross_attend_out + x
+                x, values = cross_attn(x, context = context, mask = context_mask, return_values = True, value_residual = cross_attn_value_residual)
 
                 if self.add_value_residual:
                     cross_attn_value_residual = default(cross_attn_value_residual, values)
 
-            x = ff(x) + x
+            x = ff(x)
+
+        # reduce residual streams
+
+        x = self.reduce_streams(x)
+
+        # final norm
 
         x = self.norm(x)
 
diff --git a/audiolm_pytorch/version.py b/audiolm_pytorch/version.py
index 1c4ddd3..ba9b913 100644
--- a/audiolm_pytorch/version.py
+++ b/audiolm_pytorch/version.py
@@ -1 +1 @@
-__version__ = '2.3.1'
+__version__ = '2.4.0'
diff --git a/setup.py b/setup.py
index ec174b8..121d565 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
     'fairseq',
     'wandb',
     'gateloop-transformer>=0.2.3',
+    'hyper-connections>=0.1.8',
     'joblib',
     'local-attention>=1.9.0',
     'pytorch-warmup',
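
For reference, below is a minimal sketch of the pattern this patch adopts, assuming the `hyper-connections` package exposes the interface used in the diff above: `get_init_and_expand_reduce_stream_functions` returns an `init_hyper_conn` wrapper (which owns the residual add around each branch) plus `expand_streams` / `reduce_streams`. The `TinyTransformer` and `SelfAttentionBranch` names are illustrative only and are not part of this repository or the library.

```python
import torch
from torch import nn

# same import the patch adds to audiolm_pytorch.py
from hyper_connections import get_init_and_expand_reduce_stream_functions

class SelfAttentionBranch(nn.Module):
    # illustrative pre-norm self-attention branch: takes one tensor, returns one tensor,
    # so it can be wrapped by init_hyper_conn, which handles the residual itself
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)

    def forward(self, x):
        x = self.norm(x)
        out, _ = self.attn(x, x, x)
        return out

class TinyTransformer(nn.Module):
    # illustrative stack mirroring the changes this patch makes to the audio attention models
    def __init__(self, dim = 512, depth = 4, num_residual_streams = 4):
        super().__init__()

        # with a single stream, hyper connections fall back to a plain residual
        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)

        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                init_hyper_conn(dim = dim, branch = SelfAttentionBranch(dim)),
                init_hyper_conn(dim = dim, branch = nn.Sequential(
                    nn.LayerNorm(dim),
                    nn.Linear(dim, dim * 4),
                    nn.GELU(),
                    nn.Linear(dim * 4, dim)
                ))
            ]))

        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        # widen the residual into multiple streams before the layer stack
        x = self.expand_streams(x)

        for attn, ff in self.layers:
            x = attn(x)  # no explicit `x = attn(x) + x` - the hyper connection owns the residual
            x = ff(x)

        # collapse the streams back to one before the final norm
        x = self.reduce_streams(x)
        return self.norm(x)

if __name__ == '__main__':
    model = TinyTransformer()
    tokens = torch.randn(1, 128, 512)
    print(model(tokens).shape)  # (1, 128, 512)
```

Passing `disable = num_residual_streams == 1`, as the patch does, keeps single-stream configurations equivalent to the previous plain-residual behavior while letting `num_residual_streams = 4` turn on hyper connections by default.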