Commit 3e229d7
Initial single AG, but still RS every microbatch
1 parent: 9ffc7a3
2 files changed, +76 -32 lines

MaxText/layers/models.py (+20 -14)

@@ -393,21 +393,27 @@ def __call__(
     RemattedBlockLayer = self.set_remat_policy(self.decoder_layer, policy)

     if cfg.using_pipeline_parallelism:
-      key = jax.random.PRNGKey(0)
-      keys={"params": key, "dropout": key, "aqt": key}
-      weights=self.pipeline_module.init(keys,y,decoder_segment_ids, decoder_positions, deterministic, model_mode)
-
-      def get_partition_spec(pytree):
-        def _is_leaf(x):
-          return isinstance(x, nn.spmd.LogicallyPartitioned)
-
-        def get_partition_spec_leaf(leaf):
-          return leaf.get_partition_spec()
-        partition_spec_tree = jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf)
-        return partition_spec_tree
-      weights_partition_spec = get_partition_spec(weights)
-      weights_partition_spec['params'] = weights_partition_spec['params']['layers']

+      # Generate sharding spec of the weights used by the pipeline module (e.g. weights of the decoder layers)
+      # TODO(nit): Can this be plumbed into pipeline.py instead?
+      def generate_pp_weights_sharding_spec():
+        key = jax.random.PRNGKey(0)
+        keys={"params": key, "dropout": key, "aqt": key}
+        weights=self.pipeline_module.init(keys,y,decoder_segment_ids, decoder_positions, deterministic, model_mode)
+
+        def get_partition_spec(pytree):
+          def _is_leaf(x):
+            return isinstance(x, nn.spmd.LogicallyPartitioned)
+
+          def get_partition_spec_leaf(leaf):
+            return leaf.get_partition_spec()
+          partition_spec_tree = jax.tree.map(get_partition_spec_leaf, pytree, is_leaf=_is_leaf)
+          return partition_spec_tree
+        weights_partition_spec = get_partition_spec(weights)
+        weights_partition_spec['params'] = weights_partition_spec['params']['layers']
+        return weights_partition_spec
+
+      weights_partition_spec = generate_pp_weights_sharding_spec()
       y = self.pipeline_module(y,
                                decoder_segment_ids,
                                decoder_positions,
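
For context on the get_partition_spec helper above, here is a minimal standalone sketch of the same pattern reduced to a toy Flax module. TinyLayer and the "embed"/"mlp" axis names are made up for illustration; only the nn.spmd.LogicallyPartitioned / get_partition_spec() usage mirrors the diff.

# Illustrative sketch only (not MaxText code): reading logical PartitionSpecs off
# variables whose params were created with nn.with_logical_partitioning.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyLayer(nn.Module):  # hypothetical toy module
  features: int = 8

  @nn.compact
  def __call__(self, x):
    kernel = self.param(
        "kernel",
        nn.with_logical_partitioning(nn.initializers.lecun_normal(), ("embed", "mlp")),
        (x.shape[-1], self.features),
    )
    return x @ kernel

variables = TinyLayer().init(jax.random.PRNGKey(0), jnp.ones((2, 4)))

def get_partition_spec(pytree):
  # Each LogicallyPartitioned leaf boxes an array together with its logical axis names;
  # get_partition_spec() turns those names into a jax.sharding.PartitionSpec.
  def _is_leaf(x):
    return isinstance(x, nn.spmd.LogicallyPartitioned)
  return jax.tree.map(lambda leaf: leaf.get_partition_spec(), pytree, is_leaf=_is_leaf)

print(get_partition_spec(variables))  # {'params': {'kernel': PartitionSpec('embed', 'mlp')}}

In models.py the same trick is applied to the variables returned by pipeline_module.init(...), and the result (with the 'layers' level under 'params' unwrapped) becomes weights_partition_spec.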

MaxText/layers/pipeline.py (+56 -18)

@@ -72,7 +72,7 @@ def iterations_to_complete_first_microbatch(self):
         + self.iterations_to_complete_first_microbatch_one_repeat()
     )

-  def init_states(self, inputs, sharding_info):
+  def init_states(self, inputs, sharding_info=None):
     """Initialize components of state: state_io, shift, circular_storage and circular_storage_mover
     Assumes input has already been reshaped into microbatches: [num_micro_batches, micro_batch_size, sequence, embed]

@@ -139,23 +139,29 @@ def init_states(self, inputs, sharding_info):
     else:
       circ_storage_mover = None

-    # bsw holds to repeats worth of weights for every stage. It us used to implement ideal behavior with FSDP:
-    # which is that we can all-gather the weights only once per repeat (as opposed to every microbatch), since the same weights apply to all microbatches.
-    # Additionally we can reduce the gradients across the data and fsdp axis only once per repeat using this buffer. This allows us to avoid
-    # additional fsdp/data parallelism comms, without having to incur a huge memory cost of storing all of the weights and gradients.
-    # If the weights and gradients fit into memory we can instead replace any FSDP with DP - probably some combination of TP and EP as well
-    # If there is no FSDP this feature is not needed. If there is DP the XLA compiler will automatically reduce the gradients only after all microbatches.
-    # This synergy with FSDP is only available to circular pipelines - otherwise we generally have to store the FSDP-gathered weights and gradients
-    # to avoid extra comms associated with FSDP (e.g. FSDP is only sharding the optimizer state, not the live weights and grads)
+
     def grab_two_rows_of_pytree(pytree):
       def _grab_two_rows_of_array(leaf):
         all_repeats = jnp.zeros_like(leaf, dtype=inputs.dtype) # TODO: Should set to activation_dtype probably
         return all_repeats[0:2] # Buffer is of length 2 since at most 2 repeats are active across stages on most iterations on each iteration
       return jax.tree.map(_grab_two_rows_of_array, pytree)
-    bsw = grab_two_rows_of_pytree(self.layers.variables)
-    physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(sharding_info)
-    bsw = jax.lax.with_sharding_constraint(bsw, physical_constraint_no_fsdp)
-    bsw = self.ag_new_bsw(bsw, sharding_info, 0)
+
+    # bsw holds two repeats worth of weights for every stage. It is used to implement ideal behavior with FSDP.
+    # Ideal behavior is to all-gather the weights only once per repeat (as opposed to every microbatch), since the same weights apply to all microbatches.
+    # Additionally we can reduce the gradients across the data and fsdp axes only once per repeat by making use of this buffer. This allows us to avoid
+    # additional fsdp/data parallelism comms without having to incur the huge memory cost of storing all of the weights and gradients.
+    # If the weights and gradients fit into memory we can instead replace any FSDP with DP - probably some combination of TP and EP as well.
+    # If there is no FSDP this feature is not needed. If there is DP the XLA compiler will automatically reduce the gradients only after all microbatches.
+    # This synergy with FSDP is only available to circular pipelines - otherwise we generally have to store the FSDP-gathered weights and gradients
+    # to avoid the extra comms associated with FSDP (i.e. except for circular pipelines, FSDP can only shard the optimizer state, not the live weights and grads,
+    # without paying extra communication costs).
+    if self.is_initializing():
+      bsw = None
+    else:
+      bsw = grab_two_rows_of_pytree(self.layers.variables)
+      bsw = jax.lax.with_sharding_constraint(bsw, self.get_physical_spec_no_fsdp(sharding_info))
+      # bsw = self.ag_new_bsw(bsw, sharding_info, 0)  # TODO: This was needed in the old implementation since we all-gather for the first time on later stages; bsw has to be initialized
+      # to real weights. I think for this double-loop implementation it will be initialized right before the first microbatch of the first repeat.

     init_loop_state = {
         "state_io": state_io,
@@ -362,6 +368,7 @@ def _update_state_io(state_in, stream_slice, output):
        "circ_storage_mover": new_circ_storage_mover,
        "loop_iteration": loop_iteration + 1,
        "prev_outputs": new_prev_outputs,
+        'bsw': loop_state['bsw'],  # bsw is updated outside of this inner loop, only once per outer loop iteration.
    }
    return new_loop_state

@@ -381,10 +388,11 @@ def get_current_stage_weights(self, weights, bsw, loop_iteration):
    else:
      return weights

-  def get_current_weights_from_bsw(self, bsw, loop_iteration):
+  def get_current_weights_from_bsw(self, bsw, loop_iteration):
    def get_bsw_idx(loop_iteration):
      _, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iteration)
      bsw_ids = repeat_ids==repeat_ids[0] # For early repeats this might return true when it should be false (e.g. 0==0 instead of 0!=-1)
+      # TODO: Confirm/clarify that this doesn't matter because those iterations are dummy bubble iterations anyway.

      bsw_ids = bsw_ids.astype(jnp.int32)
      return bsw_ids
@@ -573,6 +581,32 @@ def get_pipeline_remat_policy(self):
      remat_policy = save_input_policy
    return remat_policy

+  def get_physical_spec_no_fsdp(self, full_logical):
+    def remove_fsdp_sharding(sharding_tree):
+      def _remove_fsdp_from_partition_spec(named_sharding):
+        if isinstance(named_sharding, jax.sharding.NamedSharding):
+          new_spec = []
+          for axis in named_sharding.spec:
+            if axis is None:
+              new_spec.append(None)
+            elif isinstance(axis, str):
+              if axis != 'fsdp':
+                new_spec.append(axis)
+              else:
+                new_spec.append(None)
+            elif isinstance(axis, (list, tuple)):  # Handle list/tuple of axes
+              new_axis = [a for a in axis if a != 'fsdp']
+              new_spec.append(tuple(new_axis))
+              # new_spec.append(tuple(new_axis) if new_axis else None)
+            else:
+              raise ValueError(f"Unsupported axis type: {type(axis)}")
+          return jax.sharding.NamedSharding(named_sharding.mesh, jax.sharding.PartitionSpec(*new_spec))
+        return named_sharding
+      return jax.tree.map(_remove_fsdp_from_partition_spec, sharding_tree)
+    physical = nn.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules)
+    physical_no_fsdp = remove_fsdp_sharding(physical)
+    return physical_no_fsdp
+
  @nn.compact
  def __call__(
      self,
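
get_physical_spec_no_fsdp strips the fsdp mesh axis from each NamedSharding so that the with_sharding_constraint on bsw keeps the weights gathered along that axis. A minimal sketch of the axis-stripping step on a bare PartitionSpec (the 'fsdp'/'tensor'/'data' axis names here are assumed for illustration):

from jax.sharding import PartitionSpec as P

def drop_fsdp(spec: P) -> P:
  # Replace a plain 'fsdp' axis with None, or drop 'fsdp' from a tuple of axes,
  # leaving every other mesh axis in place (mirrors the helper's committed behavior).
  new_axes = []
  for axis in spec:
    if axis is None or axis == 'fsdp':
      new_axes.append(None)
    elif isinstance(axis, (list, tuple)):
      new_axes.append(tuple(a for a in axis if a != 'fsdp'))
    else:
      new_axes.append(axis)
  return P(*new_axes)

print(drop_fsdp(P(('fsdp', 'tensor'), 'data', None)))  # PartitionSpec(('tensor',), 'data', None)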
@@ -581,6 +615,7 @@
      positions: jnp.ndarray,
      deterministic: bool,
      model_mode=common_types.MODEL_MODE_TRAIN,
+      sharding_info=None  # Pytree of sharding specifications of the weights (aka self.layers.variables)
  ) -> jnp.ndarray:
    """The main method that maps the series of decoder layer inputs to final layer outputs.
    Has the same signature of a single decoder layer, and expects the same shapes, e.g. the inputs should have shape [global_batch], and internally
@@ -615,7 +650,7 @@
    example_segmentation = None
    segment_idx = None

-    loop_state = self.init_states(inputs)
+    loop_state = self.init_states(inputs, sharding_info=sharding_info)

    # Each microbatch should go through each stage (with repeats) - so there is num_micro * (num_stages * repeats) compute to perform
    # Each iteration is vmapped by num_stages, so the number of iterations should be num_micro * num_stages * repeats / num_stages = num_micro * repeats
@@ -693,7 +728,7 @@ def run_iteration_scannable(model, loop_state, xs):
        prevent_cse=not self.config.scan_pipeline_iterations, # prevent_cse not used with scan
        policy=self.get_pipeline_remat_policy(),
    )
-
+
    # The scan cannot be used on init since it broadcasts the weights, which aren't yet initialized.
    if self.config.scan_pipeline_iterations:
      variable_carry = []
@@ -719,9 +754,12 @@ def run_iteration_scannable(model, loop_state, xs):
          length=self.config.num_pipeline_microbatches
      )
      # AG weights
-      cur_repeat_weights_buffer = self.ag_new_bsw(bsw, sharding_info, loop_iter)
-      for repeat in range(self.config.num_pipeline_repeats):
+      # TODO: Consider wrapping this double loop in its own method (e.g. loop_over_repeats_gather_fsdp_first), since
+      # it probably won't be called all the time - only with FSDP and this feature turned on.
+      loop_state['bsw'] = self.ag_new_bsw(loop_state['bsw'], sharding_info, loop_state['loop_iteration'])
+      for repeat_index in range(self.config.num_pipeline_repeats):
        loop_state, _ = run_one_repeat(self, loop_state, None)
+      # TODO: An identical scan is used for the repeat and flushing phases - should refactor to shared code; only the length differs.
      flush_pipeline = nn.scan(
          run_iteration_scannable,
          variable_axes={
