@@ -72,7 +72,7 @@ def iterations_to_complete_first_microbatch(self):
72
72
+ self .iterations_to_complete_first_microbatch_one_repeat ()
73
73
)
74
74
75
- def init_states (self , inputs , sharding_info ):
75
+ def init_states (self , inputs , sharding_info = None ):
76
76
"""Initialize components of state: state_io, shift, circular_storage and circular_storage_mover
77
77
Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed]
78
78
@@ -139,23 +139,29 @@ def init_states(self, inputs, sharding_info):
139
139
else :
140
140
circ_storage_mover = None
141
141
142
- # bsw holds to repeats worth of weights for every stage. It us used to implement ideal behavior with FSDP:
143
- # which is that we can all-gather the weights only once per repeat (as opposed to every microbatch), since the same weights apply to all microbatches.
144
- # Additionally we can reduce the gradients across the data and fsdp axis only once per repeat using this buffer. This allows us to avoid
145
- # additional fsdp/data parallelism comms, without having to incur a huge memory cost of storing all of the weights and gradients.
146
- # If the weights and gradients fit into memory we can instead replace any FSDP with DP - probably some combination of TP and EP as well
147
- # If there is no FSDP this feature is not needed. If there is DP the XLA compiler will automatically reduce the gradients only after all microbatches.
148
- # This synergy with FSDP is only available to circular pipelines - otherwise we generally have to store the FSDP-gathered weights and gradients
149
- # to avoid extra comms associated with FSDP (e.g. FSDP is only sharding the optimizer state, not the live weights and grads)
142
+
150
143
def grab_two_rows_of_pytree (pytree ):
151
144
def _grab_two_rows_of_array (leaf ):
152
145
all_repeats = jnp .zeros_like (leaf , dtype = inputs .dtype ) # TODO: Should set to activation_dtype probably
153
146
return all_repeats [0 :2 ] # Buffer is of length 2 since at most 2 repeats are active across stages on most iterations on each iteration
154
147
return jax .tree .map (_grab_two_rows_of_array , pytree )
155
- bsw = grab_two_rows_of_pytree (self .layers .variables )
156
- physical_constraint_no_fsdp = self .get_physical_spec_no_fsdp (sharding_info )
157
- bsw = jax .lax .with_sharding_constraint (bsw , physical_constraint_no_fsdp )
158
- bsw = self .ag_new_bsw (bsw , sharding_info , 0 )
148
+
149
+ # bsw holds two repeats worth of weights for every stage. It us used to implement ideal behavior with FSDP.
150
+ # Ideal behavior is to all-gather the weights only once per repeat (as opposed to every microbatch), since the same weights apply to all microbatches.
151
+ # Additionally we can reduce the gradients across the data and fsdp axis only once per repeat by making use of this buffer. This allows us to avoid
152
+ # additional fsdp/data parallelism comms, without having to incur a huge memory cost of storing all of the weights and gradients.
153
+ # If the weights and gradients fit into memory we can instead replace any FSDP with DP - probably some combination of TP and EP as well.
154
+ # If there is no FSDP this feature is not needed. If there is DP the XLA compiler will automatically reduce the gradients only after all microbatches.
155
+ # This synergy with FSDP is only available to circular pipelines - otherwise we generally have to store the FSDP-gathered weights and gradients
156
+ # to avoid extra comms associated with FSDP (e.g. except for ciruclar pipelines, FSDP is only sharding the optimizer state, not the live weights and grads,
157
+ # without paying for extra communication costs)
158
+ if self .is_initializing ():
159
+ bsw = None
160
+ else :
161
+ bsw = grab_two_rows_of_pytree (self .layers .variables )
162
+ bsw = jax .lax .with_sharding_constraint (bsw , self .get_physical_spec_no_fsdp (sharding_info ))
163
+ # bsw = self.ag_new_bsw(bsw, sharding_info, 0) TODO(This was needed in the old implementation since we all gather for first time on later stages), this has to be initialized
164
+ # to real weights. I think for this double loop implementation it will be initialized right before the first microbatch of the first repeat.
159
165
160
166
init_loop_state = {
161
167
"state_io" : state_io ,
@@ -362,6 +368,7 @@ def _update_state_io(state_in, stream_slice, output):
362
368
"circ_storage_mover" : new_circ_storage_mover ,
363
369
"loop_iteration" : loop_iteration + 1 ,
364
370
"prev_outputs" : new_prev_outputs ,
371
+ 'bsw' : loop_state ['bsw' ], #bsw is updated outside of this inner loop, only once per outer loop iteration.
365
372
}
366
373
return new_loop_state
367
374
@@ -381,10 +388,11 @@ def get_current_stage_weights(self, weights, bsw, loop_iteration):
381
388
else :
382
389
return weights
383
390
384
- def get_current_weights_from_bsw (self , bsw , loop_iteration ):
391
+ def get_current_weights_from_bsw (self , bsw , loop_iteration ):
385
392
def get_bsw_idx (loop_iteration ):
386
393
_ , repeat_ids = self .get_microbatch_and_repeat_ids (loop_iteration )
387
394
bsw_ids = repeat_ids == repeat_ids [0 ] # For early repeats this might return true when it should be false( e.g. 0==0 instead of 0!=-1)
395
+ # TODO(confirm/clarify it doesn't matter because dummy bubble for this iterations anyway)
388
396
389
397
bsw_ids = bsw_ids .astype (jnp .int32 )
390
398
return bsw_ids
@@ -573,6 +581,32 @@ def get_pipeline_remat_policy(self):
573
581
remat_policy = save_input_policy
574
582
return remat_policy
575
583
584
+ def get_physical_spec_no_fsdp (self , full_logical ):
585
+ def remove_fsdp_sharding (sharding_tree ):
586
+ def _remove_fsdp_from_partition_spec (named_sharding ):
587
+ if isinstance (named_sharding , jax .sharding .NamedSharding ):
588
+ new_spec = []
589
+ for axis in named_sharding .spec :
590
+ if axis is None :
591
+ new_spec .append (None )
592
+ elif isinstance (axis , str ):
593
+ if axis != 'fsdp' :
594
+ new_spec .append (axis )
595
+ else :
596
+ new_spec .append (None )
597
+ elif isinstance (axis , (list , tuple )): # Handle list/tuple of axes
598
+ new_axis = [a for a in axis if a != 'fsdp' ]
599
+ new_spec .append (tuple (new_axis ))
600
+ #new_spec.append(tuple(new_axis) if new_axis else None)
601
+ else :
602
+ raise ValueError (f"Unsupported axis type: { type (axis )} " )
603
+ return jax .sharding .NamedSharding (named_sharding .mesh , jax .sharding .PartitionSpec (* new_spec ))
604
+ return named_sharding
605
+ return jax .tree .map (_remove_fsdp_from_partition_spec , sharding_tree )
606
+ physical = nn .logical_to_mesh_sharding (full_logical , mesh = self .mesh , rules = self .config .logical_axis_rules )
607
+ physical_no_fsdp = remove_fsdp_sharding (physical )
608
+ return physical_no_fsdp
609
+
576
610
@nn .compact
577
611
def __call__ (
578
612
self ,
@@ -581,6 +615,7 @@ def __call__(
581
615
positions : jnp .ndarray ,
582
616
deterministic : bool ,
583
617
model_mode = common_types .MODEL_MODE_TRAIN ,
618
+ sharding_info = None # Pytree of sharding specifications of the weights (aka self.layers.variables)
584
619
) -> jnp .ndarray :
585
620
"""The main method that maps the series of decoder layer inputs to final layer outputs.
586
621
Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape [global_batch], and internally
@@ -615,7 +650,7 @@ def __call__(
615
650
example_segmentation = None
616
651
segment_idx = None
617
652
618
- loop_state = self .init_states (inputs )
653
+ loop_state = self .init_states (inputs , sharding_info = sharding_info )
619
654
620
655
# Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) compute to perform
621
656
# Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages * repeats / num_stages = num_micro * repeats
@@ -693,7 +728,7 @@ def run_iteration_scannable(model, loop_state, xs):
693
728
prevent_cse = not self .config .scan_pipeline_iterations , # prevent_cse not used with scan
694
729
policy = self .get_pipeline_remat_policy (),
695
730
)
696
-
731
+
697
732
# The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
698
733
if self .config .scan_pipeline_iterations :
699
734
variable_carry = []
@@ -719,9 +754,12 @@ def run_iteration_scannable(model, loop_state, xs):
719
754
length = self .config .num_pipeline_microbatches
720
755
)
721
756
# AG weights
722
- cur_repeat_weights_buffer = self .ag_new_bsw (bsw , sharding_info , loop_iter )
723
- for repeat in range (self .config .num_pipeline_repeats ):
757
+ # TODO(Consider wrapping this double loop in its own method - loop_over_repeats_gather_fsdp_first), since
758
+ # it probably won't be called all the time - only with FSDP and feature turned on
759
+ loop_state ['bsw' ] = self .ag_new_bsw (loop_state ['bsw' ], sharding_info , loop_state ['loop_iteration' ])
760
+ for repeat_index in range (self .config .num_pipeline_repeats ):
724
761
loop_state , _ = run_one_repeat (self , loop_state , None )
762
+ # TODO: Identical scan is used for repeat and flushing - should refactor to shared, only length differs
725
763
flush_pipeline = nn .scan (
726
764
run_iteration_scannable ,
727
765
variable_axes = {
0 commit comments