Skip to content

Commit

Permalink
Merge branch 'fix-nomask' into 'main'
Browse files: browse the repository at this point in the history
fix vit mask

See merge request ADLR/megatron-lm!1803
  • Loading branch information
ko3n1g committed Aug 15, 2024
2 parents 101c08e + 20abc85 commit e8f8e63
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
10 changes: 7 additions & 3 deletions megatron/core/models/vision/vit_layer_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,16 @@

# Use this spec to use lower level Transformer Engine modules (required for fp8 training)
def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:
'''
Returns ViT layer spec with Transformer Engine layers
'''
mlp = _get_mlp_module_spec(use_te=True)
return ModuleSpec(
module=TransformerLayer,
submodules=TransformerLayerSubmodules(
self_attention=ModuleSpec(
module=SelfAttention,
params={
"attn_mask_type": AttnMaskType.causal
}, # TODO: This should be no_mask when CI is upgraded
params={"attn_mask_type": AttnMaskType.no_mask},
submodules=SelfAttentionSubmodules(
linear_qkv=TELayerNormColumnParallelLinear,
core_attention=TEDotProductAttention,
Expand All @@ -57,6 +58,9 @@ def get_vit_layer_with_transformer_engine_spec() -> ModuleSpec:


def get_vit_layer_with_local_spec() -> ModuleSpec:
'''
Returns ViT layer spec with Mcore local layers
'''
mlp = _get_mlp_module_spec(use_te=False)
return ModuleSpec(
module=TransformerLayer,
Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.13455, 9.13251, 9.12855, 9.11268, 9.05516, 9.04352, 8.98424, 8.9352, 8.8928, 8.79364]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3478602.0, 3585025.0, 3475914.0, 3384266.0, 3700151.0, 3480265.0, 3398670.0, 3454930.0, 3426119.0, 3585909.0]}, "iteration_timing_avg": 0.2253964705882353}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.13442, 9.13256, 9.12852, 9.11273, 9.05533, 9.04358, 8.98427, 8.93519, 8.89295, 8.79396]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3478477.0, 3585145.0, 3475635.0, 3384010.0, 3700478.0, 3480110.0, 3398548.0, 3454436.0, 3425849.0, 3585758.0]},"iteration_timing_avg": 0.2253964705882353}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.16216, 9.16272, 9.15753, 9.14108, 9.09527, 9.07229, 9.01583, 8.96745, 8.92202, 8.83118]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3558559.0, 3664672.0, 3555664.0, 3463897.0, 3780688.0, 3560220.0, 3478422.0, 3535024.0, 3506032.0, 3666249.0]}, "iteration_timing_avg": 0.2253964705882353}
{"num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3558381.0, 3664861.0, 3555505.0, 3463866.0, 3780904.0, 3560200.0, 3478189.0, 3534510.0, 3506002.0, 3665772.0]},"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.16219, 9.16263, 9.15739, 9.1412, 9.09523, 9.07236, 9.01592, 8.96749, 8.92204, 8.8314]}}
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19795, 9.20023, 9.19544, 9.17244, 9.11854, 9.1031, 9.04185, 8.98723, 8.94423, 8.84517]}, "num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3718669.0, 3825107.0, 3715731.0, 3623999.0, 3940369.0, 3720312.0, 3638182.0, 3695283.0, 3666175.0, 3826111.0]}, "iteration_timing_avg": 0.5847132352941178}
{"lm loss": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [9.19789, 9.20022, 9.19547, 9.17248, 9.11862, 9.10315, 9.0418, 8.98727, 8.9443, 8.84512]},"num-zeros": {"start_step": 0, "end_step": 50, "step_interval": 5, "values": [3718539.0, 3825032.0, 3715374.0, 3623934.0, 3940675.0, 3720162.0, 3638165.0, 3695121.0, 3666164.0, 3825842.0]}, "iteration_timing_avg": 0.5847132352941178}

0 comments on commit e8f8e63

Please sign in to comment.