Commit 8c00424

[PyTorch] Store module extra state in tensor (#1335)
Store module extra state in tensor
Signed-off-by: Tim Moon <[email protected]>
1 parent 71ada55 commit 8c00424

File tree

2 files changed: +68, -25 lines


transformer_engine/pytorch/module/base.py

+67 -24
@@ -588,20 +588,50 @@ def reset(key):
 
     def get_extra_state(self) -> torch.Tensor:
         """Save before checkpointing."""
-        state = None
 
+        # This implementation is working around a few issues:
+        #
+        # (1) PyTorch's "extra state" infrastructure might be able to
+        #     support any picklable type, but they make no guarantees.
+        #     We have experienced problems (e.g. in ONNX export) with
+        #     non-tensor extra state.
+        # (2) PyTorch's checkpointing infrastructure does not remap
+        #     devices for "extra state" like it does for "state dict".
+        #     Thus, we want to avoid putting extra state on the GPU
+        #     since it may be loaded on the wrong device.
+        # (3) The extra state consists of many small tensors. If we
+        #     want to copy them all to CPU, then we need to avoid the
+        #     overhead of many GPU-CPU memory transfers.
+        #
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/351
+        # See: https://github.com/NVIDIA/TransformerEngine/pull/363
+
+        def to_cpu(src: torch.Tensor) -> torch.Tensor:
+            """Helper function to make CPU copy of tensor
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst = torch.empty_like(src, device="cpu")
+            dst.copy_(src, non_blocking=True)
+            return dst
+
+        # Store FP8 state if needed
+        state = None
         fp8_checkpoint = self.fp8_meta["fp8_checkpoint"] or self.fp8 or self.fp8_calibration
-
         if fp8_checkpoint:
+
+            # Copy tensors to CPU and store
             state = {}
-            state["scale_fwd"] = self.fp8_meta["scaling_fwd"].scale
-            state["scale_inv_fwd"] = self.fp8_meta["scaling_fwd"].scale_inv
-            state["amax_history_fwd"] = self.fp8_meta["scaling_fwd"].amax_history
-            state["scale_bwd"] = self.fp8_meta["scaling_bwd"].scale
-            state["scale_inv_bwd"] = self.fp8_meta["scaling_bwd"].scale_inv
-            state["amax_history_bwd"] = self.fp8_meta["scaling_bwd"].amax_history
-
-            # Store other pickelable values.
+            state["scale_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale)
+            state["amax_history_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].amax_history)
+            state["scale_inv_fwd"] = to_cpu(self.fp8_meta["scaling_fwd"].scale_inv)
+            state["scale_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale)
+            state["amax_history_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].amax_history)
+            state["scale_inv_bwd"] = to_cpu(self.fp8_meta["scaling_bwd"].scale_inv)
+
+            # Store other pickelable values
             extra = {}
             for k, v in self.fp8_meta.items():
                 if k != "buffer_index_and_autocast_key" and isinstance(
@@ -610,22 +640,23 @@ def get_extra_state(self) -> torch.Tensor:
                     extra[k] = v
             state["extra_fp8_variables"] = extra
 
-        if is_in_onnx_export_mode():
-            state_serialized = torch.frombuffer(pickle.dumps(state), dtype=torch.uint8)
-        else:
-            state_serialized = io.BytesIO()
-            torch.save(state, state_serialized)
-
+        # Serialize state into byte tensor
+        torch.cuda.synchronize()
+        state_serialized = bytearray(pickle.dumps(state))
+        state_serialized = torch.frombuffer(state_serialized, dtype=torch.uint8)
         return state_serialized
 
     def set_extra_state(self, state: torch.Tensor) -> None:
         """Load previous state."""
         if state is None:
             return
 
+        # Load state
         if isinstance(state, torch.Tensor):
+            # Default format: byte tensor with pickled data
             state = pickle.loads(state.detach().cpu().numpy().tobytes())
         elif isinstance(state, io.BytesIO):
+            # Deprecated format with io.BytesIO
             state.seek(0)
             state = torch.load(state, map_location="cuda")
         else:
@@ -634,20 +665,32 @@ def set_extra_state(self, state: torch.Tensor) -> None:
         if state is None:
             return
 
-        # Load extra items.
+        # Load extra items
         self.fp8_meta.update(state["extra_fp8_variables"])
         self.fp8_meta["recipe"].amax_history_len = state["amax_history_fwd"].shape[0]
         if "global_fp8_buffer_pos_fwd_recompute" in self.fp8_meta:
             del self.fp8_meta["global_fp8_buffer_pos_fwd_recompute"]
 
-        # Initialize before loading.
+        # Initialize before loading
         self.init_fp8_meta_tensors()
-        self.fp8_meta["scaling_fwd"].scale.copy_(state["scale_fwd"])
-        self.fp8_meta["scaling_fwd"].amax_history.copy_(state["amax_history_fwd"])
-        self.fp8_meta["scaling_bwd"].scale.copy_(state["scale_bwd"])
-        self.fp8_meta["scaling_bwd"].amax_history.copy_(state["amax_history_bwd"])
-        self.fp8_meta["scaling_fwd"].scale_inv.copy_(state["scale_inv_fwd"])
-        self.fp8_meta["scaling_bwd"].scale_inv.copy_(state["scale_inv_bwd"])
+
+        def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None:
+            """Helper function to copy tensor from CPU
+
+            Memory transfer is asynchronous w.r.t. host, so GPU should
+            be synchronized before using result.
+
+            """
+            dst.copy_(src, non_blocking=True)
+
+        # Load tensors
+        copy_tensor(state["scale_fwd"], self.fp8_meta["scaling_fwd"].scale)
+        copy_tensor(state["amax_history_fwd"], self.fp8_meta["scaling_fwd"].amax_history)
+        copy_tensor(state["scale_inv_fwd"], self.fp8_meta["scaling_fwd"].scale_inv)
+        copy_tensor(state["scale_bwd"], self.fp8_meta["scaling_bwd"].scale)
+        copy_tensor(state["amax_history_bwd"], self.fp8_meta["scaling_bwd"].amax_history)
+        copy_tensor(state["scale_inv_bwd"], self.fp8_meta["scaling_bwd"].scale_inv)
+        torch.cuda.synchronize()
 
     def set_activation_dtype(self, inp: torch.Tensor) -> None:
         """Get activation data type for AMP."""

transformer_engine/pytorch/ops/op.py

+1 -1
@@ -514,7 +514,7 @@ def get_extra_state(self) -> torch.Tensor:
         #
         # (1) PyTorch's "extra state" infrastructure might be able to
         #     support any picklable type, but they make no guarantees.
-        #     It seems that ONNX export experiences issues with
+        #     We have experienced problems (e.g. in ONNX export) with
         #     non-tensor extra state.
         # (2) PyTorch's checkpointing infrastructure does not remap
         #     devices for "extra state" like it does for "state dict".
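
As background for both files, the extra state travels through torch.nn.Module's standard get_extra_state / set_extra_state hooks, which state_dict() and load_state_dict() invoke automatically in recent PyTorch versions. A minimal, hypothetical module (not part of Transformer Engine) showing that mechanism:

import torch

class ModuleWithExtraState(torch.nn.Module):
    """Toy module whose non-parameter bookkeeping is checkpointed."""

    def __init__(self):
        super().__init__()
        self.step = 0

    def get_extra_state(self) -> torch.Tensor:
        # Called by state_dict(); returning a CPU tensor sidesteps the
        # device-remapping caveat described in the comments above.
        return torch.tensor([self.step], dtype=torch.int64)

    def set_extra_state(self, state: torch.Tensor) -> None:
        # Called by load_state_dict() with whatever get_extra_state returned.
        self.step = int(state[0].item())

module = ModuleWithExtraState()
module.step = 42
checkpoint = module.state_dict()   # includes an "_extra_state" entry
restored = ModuleWithExtraState()
restored.load_state_dict(checkpoint)
assert restored.step == 42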
