Fix multi-GPU setup
pablomlago committed Jan 2, 2025
1 parent 0a209ee commit adc5238
Showing 2 changed files with 51 additions and 33 deletions.
@@ -407,7 +407,6 @@ def offload_model(
else:
device_map = infer_auto_device_map(
model, memory_map, no_split_module_classes=model._no_split_modules)

model = dispatch_model(model, device_map)

# Fixes an asymmetric behavior in Accelerate where hooks are not attached at all when a single device is used.
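For context, the `offload_model` helper touched above builds on Accelerate's device-map utilities. A minimal, self-contained sketch of that pattern (illustrative only, not part of this commit; the model name and memory budget are placeholders):

from accelerate import dispatch_model, infer_auto_device_map
from transformers import AutoModelForCausalLM

# Placeholder model and per-device memory budget; adjust to the actual setup.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
max_memory = {0: "8GiB", "cpu": "32GiB"}

# Infer a device map that never splits the modules the model marks as indivisible
# (the same private _no_split_modules attribute used in the hunk above), then
# attach Accelerate hooks that move weights to the right device during forward.
device_map = infer_auto_device_map(
    model, max_memory=max_memory, no_split_module_classes=model._no_split_modules)
model = dispatch_model(model, device_map=device_map)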
83 changes: 51 additions & 32 deletions src/brevitas_examples/llm/main.py
@@ -3,6 +3,7 @@

import argparse
from copy import deepcopy
from functools import partial
from functools import wraps
import os
import sys
@@ -20,6 +21,7 @@
from brevitas.export import export_torch_qcdq
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.base import ModuleInstanceFuseRotationWeights
from brevitas.graph.base import Transform
from brevitas.graph.equalize import GraphRotationEqualization
from brevitas.graph.equalize import LayerwiseActivationRotation
from brevitas.graph.quantize import layerwise_quantize
@@ -56,27 +58,31 @@ def set_seed(seed):
torch.random.manual_seed(seed)


def on_process(process_index: int):
def is_main_process():
return int(os.environ.get('LOCAL_RANK', -1)) in [-1, 0]

def decorator(func: Callable):

@wraps(func)
def _wrapper(model, *args, **kwargs):
curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
def on_process(func: Callable, process_index: int):

if curr_process_index == -1 or (process_index == curr_process_index):
print(f"Applying {func.__name__} on process index {curr_process_index}")
return func(model, *args, **kwargs)
else:
print(f"Skipping function {func.__name__} on process index {curr_process_index}")
return model
@wraps(func)
def _wrapper(model, *args, **kwargs):
curr_process_index = int(os.environ.get('LOCAL_RANK', -1))

if curr_process_index == -1 or (process_index == curr_process_index):
print(f"Applying {func.__name__} on process index {curr_process_index}")
return func(model, *args, **kwargs)
else:
print(f"Skipping function {func.__name__} on process index {curr_process_index}")
return model

return _wrapper
return _wrapper

return decorator

on_main_process = partial(on_process, process_index=0)
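The refactor above replaces the decorator factory with a plain two-argument `on_process`, letting `functools.partial` bind the process index and produce a reusable decorator. A stand-alone sketch of the same pattern (the decorated `calibrate` is a hypothetical placeholder, not code from this commit):

import os
from functools import partial, wraps
from typing import Callable


def on_process(func: Callable, process_index: int):

    @wraps(func)
    def _wrapper(model, *args, **kwargs):
        # torchrun sets LOCAL_RANK per process; -1 means a non-distributed run,
        # in which case the wrapped function always executes.
        curr_process_index = int(os.environ.get('LOCAL_RANK', -1))
        if curr_process_index == -1 or process_index == curr_process_index:
            return func(model, *args, **kwargs)
        # Other ranks skip the work and hand the model back unchanged.
        return model

    return _wrapper


# Binding process_index turns the two-argument function back into a decorator.
on_main_process = partial(on_process, process_index=0)


@on_main_process
def calibrate(model):
    # Placeholder for a main-process-only step such as apply_fused_rotations.
    return model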

def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module:

@on_main_process
def apply_fused_rotations(model: torch.nn.Module, rewriters: List[Transform]) -> torch.nn.Module:
model = offload_model(model)
for r in rewriters:
if isinstance(r, ModuleInstanceFuseRotationWeights):
@@ -85,6 +91,15 @@ def apply_fused_rotations(model: torch.nn.Module, rewriters: List) -> torch.nn.Module:
return model


@on_main_process
def evaluate_model(model: torch.nn.Module, validation_loader, args, tokenizer):
model = offload_model(model)
quant_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Perplexity ({args.dataset}): {quant_ppl:.3f}")
remove_hooks(model)


# TODO: Use no_grad? The result of fusing the rotations would yield tensors with requires_grad set to False,
# which might not be a problem, as that flag is set in the appropriate QAT/PTQ algorithms.
def fused_rotation_no_fx(
@@ -282,12 +297,9 @@ def main(args, unknown_args=None):

if args.eval:
assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously"
print("Float model eval...")
model = offload_model(model)
float_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
remove_hooks(model)
print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")
print("Evaluating float model...")
evaluate_model(model, validation_loader, args, tokenizer)
print("Float evaluation done.")

if args.replace_rmsnorm:
model = replace_rmsnorm_with_torch(model, model.config)
@@ -437,12 +449,21 @@ def mock_save_pretrained_fn(*args, **kwargs):
if args.bias_corr:
model = add_zero_bias_to_linear(model)

model = offload_model(model)

with torch.no_grad():
model(**calibration_loader[0])

remove_hooks(model)
# We need to run a calibration forward pass to initialize quantization-related parameters,
# e.g. scales. In DDP, parameters are synchronized across replicas before optimization,
# so there is no need to run this pass in every process: the parameters of the main
# process will be broadcast to each replica.
if is_main_process():
model = offload_model(model)
with torch.no_grad():
model(**calibration_loader[0])
remove_hooks(model)
else:
# TODO: Generalize this logic. Currently, only ParameterFromStatsFromParameterZeroPoint
# and ParameterFromStatsFromParameterScaling have the attribute init_done
for module in model.modules():
if hasattr(module, "init_done"):
module.init_done = True
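The comment above relies on DDP's state synchronization: torch.nn.parallel.DistributedDataParallel broadcasts the wrapped module's parameters and buffers from rank 0, so quantization parameters initialized only on the main process still reach every replica. A hypothetical helper that makes the equivalent broadcast explicit (not part of this commit, sketched only to illustrate the assumption):

import torch
import torch.distributed as dist


def sync_model_state_from_main(model: torch.nn.Module) -> None:
    # No-op outside a distributed run, mirroring the LOCAL_RANK == -1 case above.
    if not (dist.is_available() and dist.is_initialized()):
        return
    # Broadcast every parameter and buffer from rank 0; assumes the tensors live on
    # devices supported by the active backend (e.g. CUDA tensors for NCCL).
    for tensor in list(model.parameters()) + list(model.buffers()):
        dist.broadcast(tensor.data, src=0)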

if args.rotation in ['fused_no_fx_optimize', 'fused_no_fx_optimize_self_attn_region']:
apply_rotation_optimization(
@@ -453,6 +474,7 @@ def mock_save_pretrained_fn(*args, **kwargs):
)

remove_hooks(model)
torch.cuda.empty_cache()

if args.act_calibration:
print("Apply act calibration...")
@@ -489,12 +511,9 @@ def mock_save_pretrained_fn(*args, **kwargs):
print("Bias correction applied.")

if args.eval and not args.no_quantize:
print("Model eval...")
model = offload_model(model)
quant_ppl = compute_perplexity(
model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
remove_hooks(model)
print("Evaluating quantized model...")
evaluate_model(model, validation_loader, args, tokenizer)
print("Quantized evaluation done.")

if args.checkpoint_name is not None:
print(f"Saving checkpoint to {args.checkpoint_name}")
