@@ -588,20 +588,50 @@ def reset(key):
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
-        state = None
+        # This implementation is working around a few issues:
+        #
+        # (1) PyTorch's "extra state" infrastructure might be able to
+        #     support any picklable type, but they make no guarantees.
+        #     We have experienced problems (e.g. in ONNX export) with
+        #     non-tensor extra state.
+        # (2) PyTorch's checkpointing infrastructure does not remap
+        #     devices for "extra state" like it does for "state dict".
+        #     Thus, we want to avoid putting extra state on the GPU
+        #     since it may be loaded on the wrong device.
+        # (3) The extra state consists of many small tensors. If we
+        #     want to copy them all to CPU, then we need to avoid the
+        #     overhead of many GPU-CPU memory transfers.
+        #
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+        def to_cpu(src: torch.Tensor) -> torch.Tensor:
+            """Helper function to make CPU copy of tensor
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst = torch.empty_like(src, device="cpu")
+            dst.copy_(src, non_blocking=True)
+            return dst
+
+        # Store FP8 state if needed
+        state = None
         fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-
         if fp8_checkpoint:
+
+            # Copy tensors to CPU and store
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-
-            # Store other pickelable values.
+            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+            state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
+            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+            state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
+
+            # Store other picklable values
             extra = {}
             for k, v in self.fp8_meta.items():
                 if k != "buffer_index_and_autocast_key" and isinstance(
@@ -610,22 +640,23 @@ def get_extra_state(self) -> torch.Tensor:
                     extra[k] = v
             state["extra_fp8_variables"] = extra

-        if is_in_onnx_export_mode():
-            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-        else:
-            state_serialized = io.BytesIO()
-            torch.save(state, state_serialized)
-
+        # Serialize state into byte tensor
+        torch.cuda.synchronize()
+        state_serialized = bytearray(pickle.dumps(state))
+        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
         return state_serialized

     def set_extra_state(self, state: torch.Tensor) -> None:
         """Load previous state."""
         if state is None:
             return

+        # Load state
         if isinstance(state, torch.Tensor):
+            # Default format: byte tensor with pickled data
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
+            # Deprecated format with io.BytesIO
             state.seek(0)
             state = torch.load(state, map_location="cuda")
         else:
@@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None:
         if state is None:
             return

-        # Load extra items.
+        # Load extra items
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
         if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
             del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]

-        # Initialize before loading.
+        # Initialize before loading
         self.init_fp8_meta_tensors()
-        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
-        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
-        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
-        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+
+        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+            """Helper function to copy tensor from CPU
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst.copy_(src, non_blocking=True)
+
+        # Load tensors
+        copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
+        copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
+        copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
+        copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
+        copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
+        copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
+        torch.cuda.synchronize()

     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""