
Commit 40c69e7

timmoon10 authored and ptrendx committed
[PyTorch] Set usages for linear op quantizers before forward (#2222)
* Make sure to set usages for linear op quantizers before forward

Signed-off-by: Tim Moon <[email protected]>

* Avoid unsupported case for fused dbias+quantize kernel

Hopper does not support dbias + FP8 cast without FP8 transpose.

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
1 parent 264ab86 commit 40c69e7
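
In rough terms, the change moves the quantizer-usage setup into a hook that runs before the fused or unfused forward. A minimal, illustrative Python sketch of that policy (the wrapper function below is hypothetical; only Quantizer.set_usage mirrors the actual Transformer Engine call shown in the diff):

def configure_linear_quantizer_usages(input_q, weight_q, grad_output_q, *, weight_requires_grad: bool) -> None:
    # The quantized input is cached for the weight-gradient GEMM in backward,
    # so a column-wise copy is only needed when weight grads are required.
    input_q.set_usage(rowwise=True, columnwise=weight_requires_grad)
    # Quantized weights are discarded after forward: row-wise only.
    weight_q.set_usage(rowwise=True, columnwise=False)
    # Grad output likewise needs a column-wise copy only for the wgrad GEMM.
    grad_output_q.set_usage(rowwise=True, columnwise=weight_requires_grad)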

8 files changed (+296 -39 lines)


tests/pytorch/distributed/test_fusible_ops.py

Lines changed: 214 additions & 1 deletion
@@ -635,6 +635,204 @@ def _test_linear(
        torch.testing.assert_close(db_test, db_ref, **tols)


+def _test_mlp(
+    *,
+    bias: bool = True,
+    hidden_size: int = 32,
+    local_batch_size: int = 32,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = "cuda",
+    quantization: Optional[str] = None,
+    quantized_weight: bool = False,
+    sequence_parallel: bool = False,
+) -> None:
+    """2-layer MLP
+
+    MLP includes GELU activation in order to test op fusions. Model
+    performs warmup steps in order to test inter-step logic.
+
+    """
+
+    # Skip invalid configurations
+    quantized_compute = quantization is not None
+    if not quantized_compute and quantized_weight:
+        return
+
+    # Distributed process group
+    process_group = world_group()
+    rank = torch.distributed.get_rank(process_group)
+    world_size = torch.distributed.get_world_size(process_group)
+
+    # Tensor dimensions
+    mlp_size = hidden_size * world_size
+    batch_size = local_batch_size
+    if sequence_parallel:
+        batch_size *= world_size
+    in_shape = (batch_size, hidden_size)
+
+    # Random data
+    reset_rng()
+    x_ref, x_test = make_reference_and_test_tensors(
+        in_shape,
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    w1_ref, w1_test = make_reference_and_test_tensors(
+        (mlp_size, hidden_size),
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    b1_ref, b1_test = None, None
+    w2_ref, w2_test = make_reference_and_test_tensors(
+        (hidden_size, mlp_size),
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    b2_ref, b2_test = None, None
+    if bias:
+        b1_ref, b1_test = make_reference_and_test_tensors(
+            (mlp_size,),
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b2_ref, b2_test = make_reference_and_test_tensors(
+            (world_size, hidden_size),
+            test_dtype=dtype,
+            test_device=device,
+        )
+    dy_ref, dy_test = make_reference_and_test_tensors(
+        in_shape,
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+        requires_grad=False,
+    )
+
+    # Plain PyTorch implementation
+    y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
+    y_ref = torch.nn.functional.linear(y_ref, w1_ref)
+    if bias:
+        y_ref += b1_ref
+    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+    y_ref = torch.nn.functional.linear(y_ref, w2_ref)
+    if bias:
+        y_ref += b2_ref.sum(dim=0)
+    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+    y_ref.backward(dy_ref)
+
+    # Convert to distributed tensors
+    with torch.no_grad():
+        local_mlp_size = mlp_size // world_size
+        local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
+        dx_ref = x_ref.grad
+        dw1_ref = w1_ref.grad[local_mlp_slice, :]
+        w1_ref = w1_ref[local_mlp_slice, :]
+        w1_test = w1_test[local_mlp_slice, :]
+        dw2_ref = w2_ref.grad[:, local_mlp_slice]
+        w2_ref = w2_ref[:, local_mlp_slice]
+        w2_test = w2_test[:, local_mlp_slice]
+        if bias:
+            db1_ref = b1_ref.grad[local_mlp_slice]
+            b1_ref = b1_ref[local_mlp_slice]
+            b1_test = b1_test[local_mlp_slice]
+            db2_ref = b2_ref.grad[rank, :]
+            b2_ref = b2_ref[rank, :]
+            b2_test = b2_test[rank, :]
+        else:
+            db1_ref = None
+            db2_ref = None
+        if sequence_parallel:
+            local_batch_slice = slice(
+                rank * local_batch_size,
+                (rank + 1) * local_batch_size,
+            )
+            x_ref = x_ref[local_batch_slice, ...]
+            dx_ref = dx_ref[local_batch_slice, ...]
+            x_test = x_test[local_batch_slice, ...].clone()
+            y_ref = y_ref[local_batch_slice, ...]
+            dy_ref = dy_ref[local_batch_slice, ...]
+            dy_test = dy_test[local_batch_slice, ...].clone()
+    x_test.requires_grad_()
+
+    # Implementation with fusible operation
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        model = te_ops.Sequential(
+            te_ops.GELU(),
+            te_ops.Linear(
+                hidden_size,
+                mlp_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+                tensor_parallel_mode="column",
+                tensor_parallel_group=process_group,
+                sequence_parallel=sequence_parallel,
+            ),
+            te_ops.GELU(),
+            te_ops.Linear(
+                mlp_size,
+                hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+                tensor_parallel_mode="row",
+                tensor_parallel_group=process_group,
+                sequence_parallel=sequence_parallel,
+            ),
+            te_ops.GELU(),
+        )
+    with torch.no_grad():
+        model[1].weight.copy_(w1_test)
+        model[3].weight.copy_(w2_test)
+        if bias:
+            model[1].bias.copy_(b1_test)
+            model[3].bias.copy_(b2_test)
+        del w1_test, w2_test, b1_test, b2_test
+
+    # Warmup steps
+    for _ in range(3):
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x_test)
+        y_test.backward(dy_test)
+        x_test.grad = None
+        model[1].weight.grad = None
+        model[3].weight.grad = None
+        if bias:
+            model[1].bias.grad = None
+            model[3].bias.grad = None
+
+    # Forward and backward step
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        y_test = model(x_test)
+    y_test.backward(dy_test)
+
+    # Expected numerical error
+    tols = dtype_tols(dtype)
+    if dtype == torch.float32:
+        tols = dtype_tols(torch.float16)  # TF32 GEMM
+    if quantized_compute:
+        tols = quantization_tols(quantization)
+
+    # Check results
+    y_test = y_test.to(dtype=torch.float64, device="cpu")
+    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+    dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
+    dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
+    torch.testing.assert_close(y_test, y_ref, **tols)
+    torch.testing.assert_close(dx_test, dx_ref, **tols)
+    torch.testing.assert_close(dw1_test, dw1_ref, **tols)
+    torch.testing.assert_close(dw2_test, dw2_ref, **tols)
+    if bias:
+        db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
+        db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(db1_test, db1_ref, **tols)
+        torch.testing.assert_close(db2_test, db2_ref, **tols)
+
+
 def _test_fp8_scale_update(
     *,
     amax_history_len: int = 31,

@@ -801,16 +999,31 @@ def run_parallel_tests() -> None:
    for config in itertools.product(
        quantization_list,
        ("column", "row"),
+        (False, True),
    ):
        if rank == 0:
            print(f"Running _test_linear with {config=}")
-        quantization, tensor_parallel_mode = config
+        quantization, tensor_parallel_mode, sequence_parallel = config
        dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
        _test_linear(
            bias=True,  # bias=False is tested in _test_basic_linear
            dtype=dtype,
            quantization=quantization,
            tensor_parallel_mode=tensor_parallel_mode,
+            sequence_parallel=sequence_parallel,
+        )
+
+    # MLP
+    for config in itertools.product(quantization_list, (False, True)):
+        if rank == 0:
+            print(f"Running _test_mlp with {config=}")
+        quantization, sequence_parallel = config
+        dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
+        _test_mlp(
+            bias=True,  # bias=False is tested in _test_basic_linear
+            dtype=dtype,
+            quantization=quantization,
+            sequence_parallel=sequence_parallel,
        )

    # FP8 scale update
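
For context, the new _test_mlp above exercises a column-parallel first linear (w1 split along its output dimension) followed by a row-parallel second linear (w2 split along its input dimension), so each rank's partial output of the second GEMM sums to the unsharded result. A small single-process sketch of the identity it relies on (sizes and names here are arbitrary, not taken from the test):

import torch

torch.manual_seed(0)
world_size, hidden, mlp = 4, 32, 128
x = torch.randn(16, hidden, dtype=torch.float64)
w1 = torch.randn(mlp, hidden, dtype=torch.float64)   # column-parallel: shard dim 0
w2 = torch.randn(hidden, mlp, dtype=torch.float64)   # row-parallel: shard dim 1

# Unsharded reference
h = torch.nn.functional.gelu(torch.nn.functional.linear(x, w1), approximate="tanh")
y_ref = torch.nn.functional.linear(h, w2)

# Sum of per-rank partial outputs (what the all-reduce in row parallelism computes)
y_sum = torch.zeros_like(y_ref)
for rank in range(world_size):
    shard = slice(rank * mlp // world_size, (rank + 1) * mlp // world_size)
    h_local = torch.nn.functional.gelu(
        torch.nn.functional.linear(x, w1[shard, :]), approximate="tanh"
    )  # GELU is elementwise, so it commutes with slicing the hidden dim
    y_sum += torch.nn.functional.linear(h_local, w2[:, shard])  # partial output

torch.testing.assert_close(y_sum, y_ref)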

transformer_engine/pytorch/csrc/extensions/bias.cpp

Lines changed: 19 additions & 4 deletions
@@ -54,10 +54,25 @@ std::vector<py::object> bgrad_quantize(const at::Tensor &grad_output, py::handle
    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
  }

-  // Unfused impl if quantizer is not supported
-  const bool with_fused_dbias_quantize_kernel =
-      detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr());
-  if (!with_fused_dbias_quantize_kernel) {
+  // Check if fused kernel is supported
+  bool with_fused_kernel = false;
+  if (detail::IsFloat8Quantizers(quantizer.ptr())) {
+    auto prop = at::cuda::getCurrentDeviceProperties();
+    const size_t sm_arch = 10 * prop->major + prop->minor;
+    if (sm_arch >= 100) {
+      // Fused kernel for dbias + FP8 cast on SM arch 10.0+
+      with_fused_kernel = true;
+    } else if (quantizer_cpp->rowwise_usage && quantizer_cpp->columnwise_usage) {
+      // Fused kernel for dbias + FP8 cast + FP8 transpose
+      with_fused_kernel = true;
+    }
+  } else if (detail::IsMXFP8Quantizers(quantizer.ptr())) {
+    // Fused kernel for dbias + MXFP8 quantize
+    with_fused_kernel = true;
+  }
+
+  // Apply unfused impl if fused kernel is not supported
+  if (!with_fused_kernel) {
    at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0});
    quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte);
    return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)};
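
Restated as a rough Python sketch (illustrative only; the helper and its arguments are hypothetical, though torch.cuda.get_device_capability is the standard PyTorch call): per the commit message, Hopper does not support dbias + FP8 cast without the FP8 transpose, so on SM < 10.0 the fused path is only taken when both row-wise and column-wise usages are requested, while SM 10.0+ and MXFP8 always have a fused kernel.

import torch

def use_fused_dbias_quantize(is_fp8: bool, is_mxfp8: bool,
                             rowwise_usage: bool, columnwise_usage: bool) -> bool:
    if is_fp8:
        major, minor = torch.cuda.get_device_capability()
        sm_arch = 10 * major + minor
        if sm_arch >= 100:
            return True  # fused dbias + FP8 cast is available on SM 10.0+
        # On older archs the fused kernel also emits the FP8 transpose,
        # so it is only valid when both copies are actually wanted.
        return rowwise_usage and columnwise_usage
    if is_mxfp8:
        return True  # fused dbias + MXFP8 quantize
    return False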

transformer_engine/pytorch/ops/basic/basic_linear.py

Lines changed: 45 additions & 31 deletions
@@ -322,6 +322,20 @@ def pre_first_fuser_forward(self) -> None:
        if self.weight.device.type == "meta":
            self.reset_parameters()

+    def pre_fuser_forward(self, *, requires_grad: bool) -> None:
+        super().pre_fuser_forward(requires_grad=requires_grad)
+        if FP8GlobalStateManager.is_fp8_enabled():
+            # Configure quantizer usages
+            # Note: We cache the quantized input for backward pass,
+            # but discard the quantized weights.
+            weight_requires_grad = requires_grad and self.weight.requires_grad
+            input_quantizer = self.get_quantizer("forward", 0)
+            weight_quantizer = self.get_quantizer("forward", 1)
+            grad_output_quantizer = self.get_quantizer("backward", 0)
+            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+            weight_quantizer.set_usage(rowwise=True, columnwise=False)
+            grad_output_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
+
    def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
        super().reset_recipe_state(recipe=recipe)

@@ -352,6 +366,35 @@ def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None:
            and not getattr(self, "_with_quantized_weight", False)
        )

+        # Recipe-specific configuration
+        # Note: This function may be called in base class constructor,
+        # before any basic linear attrs have been set.
+        if recipe is not None:
+            if recipe.float8_current_scaling():
+                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
+                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
+                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_weight.power_2_scale
+                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_weight.amax_epsilon
+                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_bwd_grad.power_2_scale
+                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_bwd_grad.amax_epsilon
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+            if recipe.nvfp4():
+                if getattr(self, "sequence_parallel", False):
+                    tensor_parallel_mode = getattr(self, "tensor_parallel_mode", None)
+                    if tensor_parallel_mode == "column":
+                        input_quantizer.with_amax_reduction = True
+                        input_quantizer.amax_reduction_group = self.tensor_parallel_group
+                    elif tensor_parallel_mode == "row":
+                        grad_output_quantizer.with_amax_reduction = True
+                        grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
+
    @staticmethod
    def _functional_forward(
        input: torch.Tensor,  # pylint: disable=redefined-builtin

@@ -731,7 +774,7 @@ def _functional_backward(
        if with_quantized_compute:
            if input_quantizer is None:
                raise ValueError("Missing quantizer for input tensor")
-            input_quantizer.set_usage(columnwise=True)
+            input_quantizer.set_usage(rowwise=False, columnwise=True)
            if with_x_all_gather:
                x, x_async = gather_along_first_dim(
                    x_local,

@@ -912,42 +955,13 @@ def op_forward(
        input_requires_grad = ctx.requires_grad
        weight_requires_grad = ctx.requires_grad and self.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
        input_quantizer = self.get_quantizer("forward", 0)
        weight_quantizer = self.get_quantizer("forward", 1)
        output_quantizer = next_op_input_quantizer
        grad_output_quantizer = self.get_quantizer("backward", 0)
        grad_input_quantizer = prev_op_grad_output_quantizer
        with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled()
-        if with_quantized_compute:
-            # Configure quantizers
-            # Note: We cache the quantized input for backward pass,
-            # but discard the quantized weights.
-            input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
-            weight_quantizer.set_usage(rowwise=True, columnwise=False)
-
-            # Recipe-specific configuration
-            recipe = FP8GlobalStateManager.get_fp8_recipe()
-            if recipe.float8_current_scaling():
-                input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale
-                grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon
-                if self.sequence_parallel and self.tensor_parallel_mode == "column":
-                    input_quantizer.with_amax_reduction = True
-                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
-                if self.sequence_parallel and self.tensor_parallel_mode == "row":
-                    grad_output_quantizer.with_amax_reduction = True
-                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group
-            if recipe.nvfp4():
-                if self.sequence_parallel and self.tensor_parallel_mode == "column":
-                    input_quantizer.with_amax_reduction = True
-                    input_quantizer.amax_reduction_group = self.tensor_parallel_group
-                if self.sequence_parallel and self.tensor_parallel_mode == "row":
-                    grad_output_quantizer.with_amax_reduction = True
-                    grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group

        # Get autocast dtype if needed
        if torch.is_autocast_enabled():

transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py

Lines changed: 1 addition & 1 deletion
@@ -85,7 +85,7 @@ def fuser_forward(
        input_requires_grad = linear_op_ctx.requires_grad
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        output_quantizer = next_op_input_quantizer

transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ def fuser_forward(
        input_requires_grad = linear_op_ctx.requires_grad
        weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad

-        # FP8 metadata
+        # Quantizers
        input_quantizer = linear_op.get_quantizer("forward", 0)
        weight_quantizer = linear_op.get_quantizer("forward", 1)
        output_quantizer = None
