Skip to content

Commit

Permalink
Add some 1b configs.
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare committed Dec 5, 2024
1 parent 67f594d commit 85b3c10
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 1 deletion.
2 changes: 1 addition & 1 deletion moshi_mlx/moshi_mlx/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
Models for EnCodec, AudioGen, MusicGen, as well as the generic LMModel.
"""

from .lm import Lm, LmConfig, config_v0_1
from .lm import Lm, LmConfig, config_v0_1, config1b_202412
from .generate import LmGen
59 changes: 59 additions & 0 deletions moshi_mlx/moshi_mlx/models/lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,65 @@ def warmup(self):
for c in self.transformer_cache:
c.reset()

def config1b_202412() -> LmConfig:
    """Build the LmConfig for the 1B model (December 2024 release).

    The main transformer is a 16-layer, 2048-dim causal model with rotary
    positional embeddings and a context of 750; the depformer uses a smaller
    6-layer, 1024-dim transformer with no positional embedding, a context
    of 8, and 8 slices. Both share the same attention/FFN settings.
    """
    main_dim = 2048
    dep_dim = 1024

    # Keyword arguments common to the main transformer and the depformer's
    # inner transformer; only sizes, context, period and positional embedding
    # differ between the two.
    shared = dict(
        num_heads=16,
        causal=True,
        norm_first=True,
        bias_ff=False,
        bias_attn=False,
        layer_scale=None,
        use_conv_block=False,
        use_conv_bias=True,
        cross_attention=False,
        gating=True,
        norm="rms_norm",
        conv_layout=False,
        conv_kernel_size=3,
        kv_repeat=1,
        max_seq_len=4096,
    )
    main_transformer = TransformerConfig(
        d_model=main_dim,
        num_layers=16,
        dim_feedforward=4 * main_dim,  # dim * hidden_scale
        context=750,
        max_period=100000,
        positional_embedding="rope",
        **shared,
    )
    dep_transformer = TransformerConfig(
        d_model=dep_dim,
        num_layers=6,
        dim_feedforward=4 * dep_dim,  # dim * hidden_scale
        context=8,
        max_period=10000,
        positional_embedding="none",
        **shared,
    )
    return LmConfig(
        transformer=main_transformer,
        depformer=DepFormerConfig(transformer=dep_transformer, num_slices=8),
        audio_vocab_size=2049,
        text_in_vocab_size=48001,
        text_out_vocab_size=48000,
        audio_codebooks=16,
        # Delay pattern [0, 1, 1, 1, 1, 1, 1, 1] repeated for the two groups
        # of 8 codebooks (16 total).
        audio_delays=([0] + [1] * 7) * 2,
    )

def config_v0_1() -> LmConfig:
transformer = TransformerConfig(
Expand Down

0 comments on commit 85b3c10

Please sign in to comment.