Commit

Initial commit
joecummings committed Jan 29, 2025
1 parent d4465c8 commit df17170
Showing 2 changed files with 23 additions and 8 deletions.
27 changes: 19 additions & 8 deletions recipes/full_finetune_distributed.py
@@ -137,8 +137,8 @@ def __init__(self, cfg: DictConfig) -> None:
)
self._log_peak_memory_stats = False

- _, rank = utils.get_world_size_and_rank()
- self._is_rank_zero = rank == 0
+ self.world_size, self.rank = utils.get_world_size_and_rank()
+ self._is_rank_zero = self.rank == 0

# Training cfg
self._resume_from_checkpoint = cfg.resume_from_checkpoint
@@ -521,6 +521,20 @@ def _setup_model(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)

+ # Apply TP if specified
+ mesh_shape = (1, 8)
+ device_mesh = init_device_mesh(
+     "cuda", mesh_shape, mesh_dim_names=("dp", "tp")
+ )
+
+ # Use the local values of (num_heads, num_kv_heads, embed_dim) to account for tensor parallelism
+ training.prepare_mha_for_tp(model, device_mesh["tp"])
+ parallelize_module(
+     model,
+     device_mesh["tp"],
+     parallelize_plan=config.instantiate(cfg.parallelize_plan),
+ )
+
# For FSDP sharding
fsdp_shard_conditions = [
partial(
@@ -533,6 +547,7 @@ def _setup_model(
shard_conditions=fsdp_shard_conditions,
cpu_offload=fsdp_cpu_offload,
reshard_after_forward=reshard_after_forward,
+ device_mesh=device_mesh["dp"],
)

with training.set_default_dtype(self._dtype), self._device:
Expand Down Expand Up @@ -638,8 +653,6 @@ def _setup_data(
DistributedSamplers with Map-style Datasets which fit into memory. Other samplers,
iterable datasets and streaming datasets are not supported.
"""
- world_size, rank = utils.get_world_size_and_rank()
-
if isinstance(cfg_dataset, ListConfig):
datasets = [
config.instantiate(single_cfg_dataset, self._tokenizer)
@@ -657,7 +670,7 @@
collate_fn = _get_component_from_path(collate_fn)

sampler = DistributedSampler(
- ds, num_replicas=world_size, rank=rank, shuffle=shuffle, seed=0
+ ds, num_replicas=self.world_size, rank=self.rank, shuffle=shuffle, seed=0
)
dataloader = DataLoader(
dataset=ds,
Expand Down Expand Up @@ -687,8 +700,6 @@ def train(self) -> None:
# clean up before training begins
training.cleanup_before_training()

- world_size, rank = utils.get_world_size_and_rank()
-
# zero out the gradients before starting training
if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
@@ -708,7 +719,7 @@
# in case shuffle is True
self._sampler.set_epoch(curr_epoch)

- pbar = tqdm(total=self._steps_per_epoch, disable=not (rank == 0))
+ pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
for idx, batch in enumerate(self._dataloader):
if (
self.max_steps_per_epoch is not None
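The _setup_model changes above build a 2D device mesh, hand its "tp" sub-mesh to prepare_mha_for_tp and parallelize_module, and pass the "dp" sub-mesh to shard_model for FSDP. As a rough illustration, the instantiated cfg.parallelize_plan might resolve to a dict of ParallelStyle objects along these lines; the module names (layers.0.attn.q_proj, layers.0.mlp.w1, ...) are assumptions for the sketch, not torchtune's actual FQNs.

# Sketch only: a plausible tensor-parallel plan for one transformer layer.
# Module names below are hypothetical; real names come from the torchtune model/config.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)

mesh_shape = (1, 8)  # (dp, tp), matching the hard-coded shape in the recipe above
device_mesh = init_device_mesh("cuda", mesh_shape, mesh_dim_names=("dp", "tp"))

parallelize_plan = {
    # Column-parallel modules shard their output dimension across TP ranks ...
    "layers.0.attn.q_proj": ColwiseParallel(),
    "layers.0.attn.k_proj": ColwiseParallel(),
    "layers.0.attn.v_proj": ColwiseParallel(),
    "layers.0.mlp.w1": ColwiseParallel(),
    # ... row-parallel modules shard their input dimension and all-reduce the output.
    "layers.0.attn.output_proj": RowwiseParallel(),
    "layers.0.mlp.w2": RowwiseParallel(),
}

# `model` is assumed to be the transformer built earlier in the recipe.
parallelize_module(model, device_mesh["tp"], parallelize_plan=parallelize_plan)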
4 changes: 4 additions & 0 deletions torchtune/training/_distributed.py
@@ -508,6 +508,7 @@ def shard_model(
*,
cpu_offload: bool,
reshard_after_forward: bool = True,
+ device_mesh: Optional[DeviceMesh] = None,
) -> None:
"""
Utility to shard a model with FSDP using the PyTorch Distributed fully_shard API.
@@ -534,6 +535,9 @@
if cpu_offload:
fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()

+ if device_mesh is not None:
+     fsdp_kwargs["mesh"] = device_mesh
+
# Shard the model with FSDP, iterating in reverse to start with
# lowest-level modules first
num_layers_sharded = 0
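The shard_model change above threads the optional device_mesh through to FSDP2, so parameters are sharded only across the data-parallel sub-mesh while the TP dimension stays with parallelize_module. A minimal sketch of how the kwargs are consumed, with the per-module sharding loop elided; the import path is an assumption and may differ by torch version (newer releases expose torch.distributed.fsdp.fully_shard).

# Sketch only: the surrounding function body is simplified.
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

fsdp_kwargs = {"reshard_after_forward": reshard_after_forward}
if cpu_offload:
    fsdp_kwargs["offload_policy"] = CPUOffloadPolicy()
if device_mesh is not None:
    # Restrict FSDP sharding to the "dp" sub-mesh; the "tp" dimension of the
    # 2D mesh is already handled by parallelize_module in the recipe.
    fsdp_kwargs["mesh"] = device_mesh

# Matching submodules (e.g. transformer layers), then the root model, get wrapped:
fully_shard(layer, **fsdp_kwargs)
fully_shard(model, **fsdp_kwargs)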
