Skip to content

Commit 808ce1e

Browse files
authored
update configs (#28)
1 parent a3cf14b commit 808ce1e

13 files changed

+36
-410
lines changed

configs/cosmos-reason1/cosmos-reason1-7b-tp4-sft.toml renamed to configs/cosmos-reason1/cosmos-reason1-7b-fsdp2-sft.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ redis = "12800"
33
[train]
44
resume = false
55
epoch = 1
6-
output_dir = "./outputs/cosmos-reason1-7b-tp4-sft"
6+
output_dir = "./outputs/cosmos-reason1-7b-fsdp2-sft"
77
epsilon = 1e-6
88
optm_name = "AdamW"
9-
optm_lr = 1e-6
9+
optm_lr = 2e-6
1010
optm_impl = "fused"
1111
optm_weight_decay = 0.01
1212
optm_betas = [ 0.9, 0.999,]
@@ -18,7 +18,7 @@ param_dtype = "bfloat16"
1818
fsdp_reduce_dtype = "float32"
1919
fsdp_offload = false
2020
fsdp_reshard_after_forward = "default"
21-
train_batch_per_replica = 4
21+
train_batch_per_replica = 32
2222
sync_weight_interval = 1
2323
enable_validation = true
2424
validation_step = 30
@@ -44,6 +44,7 @@ enable_dataset_cache = false
4444
dataloader_num_workers = 4
4545
dataloader_prefetch_factor = 4
4646
conversation_column_name = "conversations"
47+
mini_batch = 4
4748

4849
[train.ckpt]
4950
enable_checkpoint = true
@@ -52,9 +53,9 @@ save_mode = "async"
5253

5354
[policy.parallelism]
5455
n_init_replicas = 1
55-
tp_size = 4
56+
tp_size = 1
5657
cp_size = 1
57-
dp_shard_size = -1
58+
dp_shard_size = 2
5859
pp_size = 1
5960
dp_replicate_size = 1
6061
cp_rotate_method = "allgather"

configs/cosmos-reason1/cosmos-reason1-7b-tp2-fsdp-sft.toml renamed to configs/cosmos-reason1/cosmos-reason1-7b-fsdp4-sft.toml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ redis = "12800"
33
[train]
44
resume = false
55
epoch = 1
6-
output_dir = "./outputs/cosmos-reason1-7b-tp2-dpx-sft"
6+
output_dir = "./outputs/cosmos-reason1-7b-fsdp4-sft"
77
epsilon = 1e-6
88
optm_name = "AdamW"
9-
optm_lr = 1e-6
9+
optm_lr = 2e-6
1010
optm_impl = "fused"
1111
optm_weight_decay = 0.01
1212
optm_betas = [ 0.9, 0.999,]
@@ -18,7 +18,7 @@ param_dtype = "bfloat16"
1818
fsdp_reduce_dtype = "float32"
1919
fsdp_offload = false
2020
fsdp_reshard_after_forward = "default"
21-
train_batch_per_replica = 4
21+
train_batch_per_replica = 32
2222
sync_weight_interval = 1
2323
enable_validation = true
2424
validation_step = 30
@@ -44,6 +44,7 @@ enable_dataset_cache = false
4444
dataloader_num_workers = 4
4545
dataloader_prefetch_factor = 4
4646
conversation_column_name = "conversations"
47+
mini_batch = 4
4748

4849
[train.ckpt]
4950
enable_checkpoint = true
@@ -52,9 +53,9 @@ save_mode = "async"
5253

5354
[policy.parallelism]
5455
n_init_replicas = 1
55-
tp_size = 2
56+
tp_size = 1
5657
cp_size = 1
57-
dp_shard_size = -1
58+
dp_shard_size = 4
5859
pp_size = 1
5960
dp_replicate_size = 1
6061
cp_rotate_method = "allgather"

configs/cosmos-reason1/cosmos-reason1-7b-tp2-sft-profile.toml renamed to configs/cosmos-reason1/cosmos-reason1-7b-fsdp8-sft.toml

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ redis = "12800"
33
[train]
44
resume = false
55
epoch = 1
6-
output_dir = "./outputs/cosmos-reason1-7b-tp2-sft-profile"
6+
output_dir = "./outputs/cosmos-reason1-7b-fsdp8-sft"
77
epsilon = 1e-6
88
optm_name = "AdamW"
9-
optm_lr = 1e-6
9+
optm_lr = 2e-6
1010
optm_impl = "fused"
1111
optm_weight_decay = 0.01
1212
optm_betas = [ 0.9, 0.999,]
@@ -18,10 +18,10 @@ param_dtype = "bfloat16"
1818
fsdp_reduce_dtype = "float32"
1919
fsdp_offload = false
2020
fsdp_reshard_after_forward = "default"
21-
train_batch_per_replica = 4
21+
train_batch_per_replica = 32
2222
sync_weight_interval = 1
2323
enable_validation = true
24-
validation_step = 100
24+
validation_step = 30
2525
validation_batch_per_replica = 2
2626

2727
[policy]
@@ -34,17 +34,6 @@ logger = ['console', 'wandb']
3434
project_name = "cosmos_reason1"
3535
experiment_name = "cosmos-reason1-sft"
3636

37-
[profiler]
38-
enable_profiler = true
39-
40-
[profiler.sub_profiler_config]
41-
active_steps = 2
42-
rank_filter = [0]
43-
record_shape = true
44-
profile_memory = false
45-
with_stack = false
46-
with_modules = false
47-
4837
[train.train_policy]
4938
type = "sft"
5039
dataset.name = "nvidia/Cosmos-Reason1-SFT-Dataset"
@@ -55,18 +44,18 @@ enable_dataset_cache = false
5544
dataloader_num_workers = 4
5645
dataloader_prefetch_factor = 4
5746
conversation_column_name = "conversations"
47+
mini_batch = 4
5848

5949
[train.ckpt]
6050
enable_checkpoint = true
61-
save_freq = 100
62-
max_keep = 2
51+
save_freq = 30
6352
save_mode = "async"
6453

6554
[policy.parallelism]
6655
n_init_replicas = 1
67-
tp_size = 2
56+
tp_size = 1
6857
cp_size = 1
69-
dp_shard_size = 1
58+
dp_shard_size = 8
7059
pp_size = 1
7160
dp_replicate_size = 1
7261
cp_rotate_method = "allgather"

configs/cosmos-reason1/cosmos-reason1-7b-p-fsdp1-tp2-r-tp1-pp2-grpo.toml

Lines changed: 0 additions & 84 deletions
This file was deleted.

configs/cosmos-reason1/cosmos-reason1-7b-p-fsdp1-tp2-r-tp2-pp1-grpo-fp8.toml

Lines changed: 0 additions & 90 deletions
This file was deleted.

configs/cosmos-reason1/cosmos-reason1-7b-p-fsdp1-tp2-r-tp2-pp1-grpo.toml

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ redis = "12800"
22

33
[train]
44
resume = false
5-
epoch = 1
5+
epoch = 80
66
output_dir = "./outputs/cosmos-reason1-7b-p-fsdp1-tp2-r-tp2-pp1-grpo"
77
epsilon = 1e-6
88
optm_name = "AdamW"
@@ -18,7 +18,7 @@ param_dtype = "bfloat16"
1818
fsdp_reduce_dtype = "float32"
1919
fsdp_offload = false
2020
fsdp_reshard_after_forward = "default"
21-
train_batch_per_replica = 8
21+
train_batch_per_replica = 128
2222
sync_weight_interval = 1
2323

2424
[rollout]
@@ -29,13 +29,6 @@ n_generation = 8
2929
batch_size = 4
3030
quantization = "none"
3131

32-
33-
[rollout.sampling_config]
34-
temperature = 0.6
35-
top_p = 0.95
36-
top_k = 50
37-
repetition_penalty = 1.05
38-
3932
[policy]
4033
model_name_or_path = "nvidia/Cosmos-Reason1-7B"
4134
model_max_length = 10240
@@ -62,7 +55,7 @@ epsilon_high = 0.2
6255
kl_beta = 0.0
6356
mu_iterations = 1
6457
min_filter_prefix_tokens = 1
65-
mini_batch = 1
58+
mini_batch = 4
6659

6760
[train.ckpt]
6861
enable_checkpoint = true

0 commit comments

Comments
 (0)