Updates
joecummings committed Jan 30, 2025
1 parent df17170 commit 6e8041d
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions recipes/full_finetune_distributed.py
@@ -137,8 +137,11 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False
 
+        # Distributed variables
         self.world_size, self.rank = utils.get_world_size_and_rank()
         self._is_rank_zero = self.rank == 0
+        self.nnodes = dist.get_local_size()
+        self.enable_tensor_parallel = cfg.get("enable_tensor_parallel", False)
 
         # Training cfg
         self._resume_from_checkpoint = cfg.resume_from_checkpoint
@@ -521,21 +524,22 @@ def _setup_model(
                 model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
             )
 
-        # Apply TP if specified
-        mesh_shape = (1, 8)
-        device_mesh = init_device_mesh(
-            "cuda", tp_mesh_shape, mesh_dim_names=("dp", "tp")
-        )
-
-        # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
-        training.prepare_mha_for_tp(model, device_mesh["tp"])
-        parallelize_module(
-            model,
-            device_mesh["tp"],
-            parallelize_plan=config.instantiate(cfg.parallelize_plan),
-        )
+        device_mesh = {}
+        if self.enable_tensor_parallel:
+            mesh_shape = (self.nnodes, self.world_size // self.nnodes)
+            device_mesh = init_device_mesh(
+                "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+            )
+            # Use the local number (num_heads, num_kv_heads, embed_dim) to account for tensor parallel
+            training.prepare_mha_for_tp(model, device_mesh["tp"])
+            # Apply tensor parallelism to the model
+            parallelize_module(
+                model,
+                device_mesh["tp"],
+                parallelize_plan=config.instantiate(cfg.parallelize_plan),
+            )
 
-        # For FSDP sharding
+        # Shard the model
         fsdp_shard_conditions = [
             partial(
                 training.get_shard_conditions,
@@ -547,7 +551,7 @@
             shard_conditions=fsdp_shard_conditions,
             cpu_offload=fsdp_cpu_offload,
             reshard_after_forward=reshard_after_forward,
-            device_mesh=device_mesh["dp"],
+            dp_device_mesh=device_mesh.get("dp"),
         )
 
         with training.set_default_dtype(self._dtype), self._device:
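For readers following the change: the sketch below is not part of the commit; it shows one way the new 2D ("dp", "tp") mesh shape used in _setup_model could be derived under torchrun. The build_mesh name, the LOCAL_WORLD_SIZE-based node count, and the 2-node example are illustrative assumptions; the commit itself takes the node count from dist.get_local_size() and reads enable_tensor_parallel from the recipe config.

# Illustrative sketch only -- not from this commit.
import os

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh


def build_mesh(enable_tensor_parallel: bool):
    """Return a 2D ("dp", "tp") DeviceMesh, or None when TP is disabled."""
    if not enable_tensor_parallel:
        return None
    world_size = dist.get_world_size()
    # torchrun sets LOCAL_WORLD_SIZE; the commit derives the node count differently.
    local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
    nnodes = world_size // local_world_size
    # Data parallelism across nodes, tensor parallelism across the GPUs of a node,
    # e.g. 2 nodes x 4 GPUs -> mesh_shape == (2, 4).
    mesh_shape = (nnodes, world_size // nnodes)
    return init_device_mesh("cuda", mesh_shape, mesh_dim_names=("dp", "tp"))

Launched with torchrun --nnodes 2 --nproc-per-node 4, mesh["tp"] would group the ranks within each node, which is the sub-mesh the recipe hands to prepare_mha_for_tp and parallelize_module, while the "dp" dimension (or None when TP is disabled, via device_mesh.get("dp")) feeds the FSDP sharding path.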
