@@ -635,6 +635,204 @@ def _test_linear(
         torch.testing.assert_close(db_test, db_ref, **tols)


+def _test_mlp(
+    *,
+    bias: bool = True,
+    hidden_size: int = 32,
+    local_batch_size: int = 32,
+    dtype: torch.dtype = torch.float32,
+    device: torch.device = "cuda",
+    quantization: Optional[str] = None,
+    quantized_weight: bool = False,
+    sequence_parallel: bool = False,
+) -> None:
649+ """2-layer MLP
650+
651+ MLP includes GELU activation in order to test op fusions. Model
652+ performs warmup steps in order to test inter-step logic.
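+
+    The first linear is column parallel and the second is row parallel,
+    with optional sequence parallelism.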
+
+    """
+
+    # Skip invalid configurations
+    quantized_compute = quantization is not None
+    if not quantized_compute and quantized_weight:
+        return
+
+    # Distributed process group
+    process_group = world_group()
+    rank = torch.distributed.get_rank(process_group)
+    world_size = torch.distributed.get_world_size(process_group)
+
+    # Tensor dimensions
+    mlp_size = hidden_size * world_size
+    batch_size = local_batch_size
+    if sequence_parallel:
+        batch_size *= world_size
+    in_shape = (batch_size, hidden_size)
+
+    # Random data
+    reset_rng()
+    x_ref, x_test = make_reference_and_test_tensors(
+        in_shape,
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    w1_ref, w1_test = make_reference_and_test_tensors(
+        (mlp_size, hidden_size),
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    b1_ref, b1_test = None, None
+    w2_ref, w2_test = make_reference_and_test_tensors(
+        (hidden_size, mlp_size),
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+    )
+    b2_ref, b2_test = None, None
+    if bias:
+        b1_ref, b1_test = make_reference_and_test_tensors(
+            (mlp_size,),
+            test_dtype=dtype,
+            test_device=device,
+        )
+        b2_ref, b2_test = make_reference_and_test_tensors(
+            (world_size, hidden_size),
+            test_dtype=dtype,
+            test_device=device,
+        )
+    dy_ref, dy_test = make_reference_and_test_tensors(
+        in_shape,
+        quantization=quantization,
+        test_dtype=dtype,
+        test_device=device,
+        requires_grad=False,
+    )
+
+    # Plain PyTorch implementation
+    y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
+    y_ref = torch.nn.functional.linear(y_ref, w1_ref)
+    if bias:
+        y_ref += b1_ref
+    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+    y_ref = torch.nn.functional.linear(y_ref, w2_ref)
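+    # b2 is sharded one row per rank (see the slicing below), so the
+    # reference applies the sum of all per-rank bias slices.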
+    if bias:
+        y_ref += b2_ref.sum(dim=0)
+    y_ref = torch.nn.functional.gelu(y_ref, approximate="tanh")
+    y_ref.backward(dy_ref)
+
+    # Convert to distributed tensors
+    with torch.no_grad():
+        local_mlp_size = mlp_size // world_size
+        local_mlp_slice = slice(rank * local_mlp_size, (rank + 1) * local_mlp_size)
+        dx_ref = x_ref.grad
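+        # Keep only this rank's shard: w1 is split along its output dim
+        # (column parallel) and w2 along its input dim (row parallel).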
+        dw1_ref = w1_ref.grad[local_mlp_slice, :]
+        w1_ref = w1_ref[local_mlp_slice, :]
+        w1_test = w1_test[local_mlp_slice, :]
+        dw2_ref = w2_ref.grad[:, local_mlp_slice]
+        w2_ref = w2_ref[:, local_mlp_slice]
+        w2_test = w2_test[:, local_mlp_slice]
+        if bias:
+            db1_ref = b1_ref.grad[local_mlp_slice]
+            b1_ref = b1_ref[local_mlp_slice]
+            b1_test = b1_test[local_mlp_slice]
+            db2_ref = b2_ref.grad[rank, :]
+            b2_ref = b2_ref[rank, :]
+            b2_test = b2_test[rank, :]
+        else:
+            db1_ref = None
+            db2_ref = None
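+        # With sequence parallelism, each rank keeps only its local slice of
+        # the batch dim for inputs, outputs, and gradients.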
+        if sequence_parallel:
+            local_batch_slice = slice(
+                rank * local_batch_size,
+                (rank + 1) * local_batch_size,
+            )
+            x_ref = x_ref[local_batch_slice, ...]
+            dx_ref = dx_ref[local_batch_slice, ...]
+            x_test = x_test[local_batch_slice, ...].clone()
+            y_ref = y_ref[local_batch_slice, ...]
+            dy_ref = dy_ref[local_batch_slice, ...]
+            dy_test = dy_test[local_batch_slice, ...].clone()
+            x_test.requires_grad_()
+
+    # Implementation with fusible operations
+    recipe = make_recipe(quantization)
+    with te.fp8_model_init(enabled=quantized_weight, recipe=recipe):
+        model = te_ops.Sequential(
+            te_ops.GELU(),
+            te_ops.Linear(
+                hidden_size,
+                mlp_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+                tensor_parallel_mode="column",
+                tensor_parallel_group=process_group,
+                sequence_parallel=sequence_parallel,
+            ),
+            te_ops.GELU(),
+            te_ops.Linear(
+                mlp_size,
+                hidden_size,
+                bias=bias,
+                device=device,
+                dtype=dtype,
+                tensor_parallel_mode="row",
+                tensor_parallel_group=process_group,
+                sequence_parallel=sequence_parallel,
+            ),
+            te_ops.GELU(),
+        )
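+    # model[1] and model[3] are the column- and row-parallel linear ops;
+    # load this rank's shards of the test weights and biases into them.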
+    with torch.no_grad():
+        model[1].weight.copy_(w1_test)
+        model[3].weight.copy_(w2_test)
+        if bias:
+            model[1].bias.copy_(b1_test)
+            model[3].bias.copy_(b2_test)
+    del w1_test, w2_test, b1_test, b2_test
+
+    # Warmup steps
+    for _ in range(3):
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = model(x_test)
+        y_test.backward(dy_test)
+        x_test.grad = None
+        model[1].weight.grad = None
+        model[3].weight.grad = None
+        if bias:
+            model[1].bias.grad = None
+            model[3].bias.grad = None
+
+    # Forward and backward step
+    with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+        y_test = model(x_test)
+    y_test.backward(dy_test)
+
+    # Expected numerical error
+    tols = dtype_tols(dtype)
+    if dtype == torch.float32:
+        tols = dtype_tols(torch.float16)  # TF32 GEMM
+    if quantized_compute:
+        tols = quantization_tols(quantization)
+
+    # Check results
+    y_test = y_test.to(dtype=torch.float64, device="cpu")
+    dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+    dw1_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu")
+    dw2_test = model[3].weight.grad.to(dtype=torch.float64, device="cpu")
+    torch.testing.assert_close(y_test, y_ref, **tols)
+    torch.testing.assert_close(dx_test, dx_ref, **tols)
+    torch.testing.assert_close(dw1_test, dw1_ref, **tols)
+    torch.testing.assert_close(dw2_test, dw2_ref, **tols)
+    if bias:
+        db1_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu")
+        db2_test = model[3].bias.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(db1_test, db1_ref, **tols)
+        torch.testing.assert_close(db2_test, db2_ref, **tols)
+
+
 def _test_fp8_scale_update(
     *,
     amax_history_len: int = 31,
@@ -801,16 +999,31 @@ def run_parallel_tests() -> None:
     for config in itertools.product(
         quantization_list,
         ("column", "row"),
+        (False, True),
     ):
         if rank == 0:
             print(f"Running _test_linear with {config=}")
-        quantization, tensor_parallel_mode = config
+        quantization, tensor_parallel_mode, sequence_parallel = config
         dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
         _test_linear(
             bias=True,  # bias=False is tested in _test_basic_linear
             dtype=dtype,
             quantization=quantization,
             tensor_parallel_mode=tensor_parallel_mode,
+            sequence_parallel=sequence_parallel,
+        )
+
+    # MLP
+    for config in itertools.product(quantization_list, (False, True)):
+        if rank == 0:
+            print(f"Running _test_mlp with {config=}")
+        quantization, sequence_parallel = config
+        dtype = torch.bfloat16 if is_bf16_compatible() else torch.float32
+        _test_mlp(
+            bias=True,  # bias=False is tested in _test_basic_linear
+            dtype=dtype,
+            quantization=quantization,
+            sequence_parallel=sequence_parallel,
         )

     # FP8 scale update