Support noop concat without providing full tensor
Stop storing fused buffers in linear modules.

Signed-off-by: Tim Moon <[email protected]>
timmoon10 committed Mar 13, 2024
1 parent e3d2efd commit e1e0fa2
Showing 3 changed files with 172 additions and 164 deletions.
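For context, the core idea behind the change: when tensors are adjacent views into one storage, "concatenating" them can return a view of that storage instead of copying. Below is a minimal sketch of that mechanism, not code from this commit; the helper name maybe_noop_cat is invented, and autograd handling (which _NoopCatFunc adds) is omitted.

import torch

def maybe_noop_cat(tensors, dim=0):
    # If each tensor picks up exactly where the previous one left off in the
    # same storage, concatenation can be a view; otherwise fall back to a copy.
    first = tensors[0]
    expected_ptr = first.data_ptr() + first.size(dim) * first.stride(dim) * first.element_size()
    for t in tensors[1:]:
        if t.data_ptr() != expected_ptr or t.stride() != first.stride():
            return torch.cat(tensors, dim=dim)  # memory not adjacent: out-of-place copy
        expected_ptr += t.size(dim) * t.stride(dim) * t.element_size()
    out_shape = list(first.size())
    out_shape[dim] = sum(t.size(dim) for t in tensors)
    # Alias the existing storage instead of allocating a new buffer
    out = first.new()
    out.set_(first.untyped_storage(), first.storage_offset(), out_shape, first.stride())
    return out

buf = torch.zeros(6, 4)
a, b, c = buf.split([2, 2, 2])              # views into `buf`
merged = maybe_noop_cat([a, b, c])
assert merged.data_ptr() == buf.data_ptr()  # no copy was made

The commit moves this adjacency check into _NoopCatFunc itself, so callers no longer need to pass in a pre-allocated full_tensor.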
143 changes: 78 additions & 65 deletions transformer_engine/pytorch/module/_common.py
@@ -98,97 +98,110 @@ def _apply_normalization(inputmat:torch.Tensor,


class _NoopCatFunc(torch.autograd.Function):
-    """No-op concatenate tensors along dim 0
-
-    `full_tensor` is assumed to already be the concatenation of
-    `tensors`, i.e. they occupy the same memory with the correct
-    offsets.
+    """Concatenate tensors, doing a no-op if possible
+
+    See _noop_cat.
    """

    @staticmethod
    def forward(
-        ctx,
-        split_ranges: List[Tuple[int, int]],
-        full_tensor: torch.Tensor,
+        ctx: Any,
+        dim: int,
        *tensors: Tuple[torch.Tensor, ...],
    ) -> torch.Tensor:
+        # pylint: disable=unused-argument
+
+        # Check first tensor
+        if not tensors:
+            raise ValueError("Attempted to concatenate 0 tensors")
+        num_dims = tensors[0].dim()
+        if not (-num_dims <= dim < num_dims):
+            raise ValueError(
+                "Attempted to concatenate tensor "
+                f"with shape {list(tensors[0].size())} along dim {dim}"
+            )
+        dim %= num_dims
+
+        # Check remaining tensors
+        out_shape = list(tensors[0].size())
+        split_ranges = [(0, tensors[0].size(dim))]
+        for tensor in tensors[1:]:
+            in_shape = list(tensor.size())
+            if (
+                len(in_shape) != num_dims
+                or in_shape[:dim] != out_shape[:dim]
+                or in_shape[dim+1:] != out_shape[dim+1:]
+            ):
+                raise ValueError(
+                    "Attempted to concatenate tensors with shapes "
+                    f"{[list(tensor.size()) for tensor in tensors]} "
+                    f"along dim {dim}"
+                )
+            split_start = out_shape[dim]
+            split_end = split_start + in_shape[dim]
+            out_shape[dim] = split_end
+            split_ranges.append((split_start, split_end))

        # Save state for backward
+        ctx.dim = dim
        ctx.split_ranges = split_ranges
-        assert not full_tensor.requires_grad, "Concatenated tensor should not require gradient"
-        out = full_tensor.new()
+
+        # Out-of-place concatenation if needed
+        dtype = tensors[0].dtype
+        device = tensors[0].device
+        strides = tensors[0].stride()
+        data_ptr_stride = strides[dim] * tensors[0].element_size()
+        data_ptr = tensors[0].data_ptr() + tensors[0].size(dim) * data_ptr_stride
+        for tensor in tensors[1:]:
+            if (
+                tensor.dtype != dtype
+                or tensor.device != device
+                or tensor.stride() != strides
+                or tensor.data_ptr() != data_ptr
+            ):
+                return torch.cat(tensors, dim=dim)
+            data_ptr += tensor.size(dim) * data_ptr_stride
+
+        # No-op concatenation
+        out = tensors[0].new()
        out.set_(
-            full_tensor.untyped_storage(),
-            full_tensor.storage_offset(),
-            full_tensor.size(),
-            full_tensor.stride(),
+            tensors[0].untyped_storage(),
+            tensors[0].storage_offset(),
+            out_shape,
+            strides,
        )
-        out.requires_grad = True
+        out.requires_grad = any(tensor.requires_grad for tensor in tensors)
        return out

    @staticmethod
    def backward(
        ctx,
        grad_output: torch.Tensor,
    ) -> Tuple[Optional[torch.Tensor], ...]:
-        grads = [
-            grad_output[split_start:split_end]
-            for split_start, split_end in ctx.split_ranges
-        ]
-        return None, None, *grads
+        grad_inputs = []
+        for split_start, split_end in ctx.split_ranges:
+            idxs = [slice(None)] * grad_output.dim()
+            idxs[ctx.dim] = slice(split_start, split_end)
+            grad_inputs.append(grad_output[tuple(idxs)])
+        return None, *grad_inputs


def _noop_cat(
    tensors: List[torch.Tensor],
-    full_tensor: torch.Tensor,
+    dim: int = 0,
) -> torch.Tensor:
-    """Concatenate tensors along dim 0, doing a no-op if possible
-
-    If `full_tensor` is already the concatenation of `tensors`, i.e.
-    they occupy the same memory region with the correct offsets, then
-    no copies are performed. Otherwise the buffers in all the tensors
-    are reallocated so that another call would result in a no-op.
-
-    In the backward pass, gradients to `partial_tensors` will just be
-    tensor views.
+    """Concatenate tensors, doing a no-op if possible
+
+    If tensors are already concatenated in memory, a tensor view of
+    that memory region will be returned. Otherwise the tensors will be
+    concatenated out-of-place, as usual.
    """

-    # Determine split points
-    split_ranges = []
-    full_tensor_shape = full_tensor.size()
-    offset = 0
-    for tensor in tensors:
-        tensor_shape = tensor.size()
-        if tensor_shape[1:] != full_tensor_shape[1:]:
-            raise ValueError(
-                f"Attempting to concatenate tensor with shape={list(tensor_shape)} "
-                f"into a tensor with shape={list(full_tensor_shape)}"
-            )
-        split_start = offset
-        offset += tensor_shape[0]
-        split_end = offset
-        split_ranges.append((split_start, split_end))
-    if offset != full_tensor_shape[0]:
-        raise ValueError(
-            f"Attempting to concatenate tensors with total shape[0]={offset} "
-            f"into a tensor with shape[0]={full_tensor_shape[0]}"
-        )
-
-    # Reallocate buffers if no-op concat isn't possible
-    need_to_reallocate = False
-    for tensor, (split_start, _) in zip(tensors, split_ranges):
-        if tensor.data_ptr() != full_tensor[split_start].data_ptr():
-            need_to_reallocate = True
-            break
-    if need_to_reallocate:
-        with torch.no_grad():
-            full_tensor.data = torch.cat(tensors)
-            for tensor, (split_start, split_end) in zip(tensors, split_ranges):
-                tensor.data = full_tensor[split_start:split_end]
-
-    # Perform no-op concat
-    return _NoopCatFunc.apply(split_ranges, full_tensor, *tensors)
+    if not tensors:
+        raise ValueError("Attempted to concatenate 0 tensors")
+    if len(tensors) == 1:
+        return tensors[0]
+    return _NoopCatFunc.apply(dim, *tensors)


@dataclass
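Hypothetical usage of the reworked helper, assuming _noop_cat is imported from transformer_engine.pytorch.module._common (it is a private helper, so this is for illustration only; the tensor names are made up). Parameters that are adjacent views into one buffer are fused without a copy, and the backward pass hands each input a view of grad_output.

import torch

buf = torch.randn(8, 16)
w_q = torch.nn.Parameter(buf[0:4])   # Parameters constructed from slices share `buf`'s storage
w_k = torch.nn.Parameter(buf[4:6])
w_v = torch.nn.Parameter(buf[6:8])

fused = _noop_cat([w_q, w_k, w_v], dim=0)  # no full_tensor argument anymore
assert fused.data_ptr() == buf.data_ptr()  # a view of the existing buffer, no copy

fused.sum().backward()                     # w_q.grad, w_k.grad, w_v.grad are filled per split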
96 changes: 46 additions & 50 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -814,17 +814,20 @@ def __init__(
        else:
            self.layer_norm_bias = None

-        self.weight_tensor = torch.empty(
-            self.out_features, self.in_features,
-            device=device, dtype=params_dtype)
-
+        # Contiguous buffers for params
+        weight_tensor = torch.empty(
+            self.out_features,
+            self.in_features,
+            device=device,
+            dtype=params_dtype,
+        )
+        bias_tensor = None
        if self.use_bias:
-            self.bias_tensor = torch.empty(
+            bias_tensor = torch.empty(
                self.out_features,
                device=device,
-                dtype=params_dtype)
-        else:
-            self.bias_tensor = torch.Tensor().to(dtype=params_dtype, device=device)
+                dtype=params_dtype,
+            )

        # Configure parameter splits
        self.weight_names = []
@@ -870,7 +873,11 @@ def __init__(
                )
            self.parameter_split_sizes[i] = size // self.tp_size

-        # Construct parameters from weight and bias buffers
+        # Construct weight parameters
+        # Note: Register weights together so that they are adjacent to
+        # each other in LayerNormLinear.parameters(). This makes it
+        # more likely that they will stay contiguous if the weights
+        # are manipulated externally, e.g. by FSDP.
        offset = 0
        for i, split_size in enumerate(self.parameter_split_sizes):
            split_start = offset
@@ -886,32 +893,30 @@ def __init__(
                )

            # Construct weight parameter
-            weight = self.weight_tensor
-            if is_subview:
-                weight = weight[split_start:split_end]
-            weight = torch.nn.Parameter(weight)
-            self.register_parameter(self.weight_names[i], weight,
-                                    init_fn=init_method,
-                                    get_rng_state_tracker=get_rng_state_tracker,
-                                    fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT)
-
-            # Construct bias parameter if needed
-            if self.use_bias:
-                bias = self.bias_tensor
-                if is_subview:
-                    bias = bias[split_start:split_end]
-                bias = torch.nn.Parameter(bias)
-                self.register_parameter(self.bias_names[i], bias,
-                                        init_fn=init_method_constant(0.0))
-            else:
-                bias = torch.Tensor().to(dtype=params_dtype, device=device)
-                setattr(self, self.bias_names[i], bias)
+            self.register_parameter(
+                self.weight_names[i],
+                torch.nn.Parameter(weight_tensor[split_start:split_end]),
+                init_fn=init_method,
+                get_rng_state_tracker=get_rng_state_tracker,
+                fp8_meta_index=tex.FP8FwdTensors.GEMM1_WEIGHT,
+            )

-        # Concatenated tensors are not needed if not splitting
-        # into multiple parameters
-        if not is_subview:
-            del self.weight_tensor
-            del self.bias_tensor
+        # Construct bias parameters if needed
+        if self.use_bias:
+            offset = 0
+            for i, split_size in enumerate(self.parameter_split_sizes):
+                split_start = offset
+                offset += split_size
+                split_end = offset
+                self.register_parameter(
+                    self.bias_names[i],
+                    torch.nn.Parameter(bias_tensor[split_start:split_end]),
+                    init_fn=init_method_constant(0.0),
+                )
+        else:
+            for name in self.bias_names:
+                bias = torch.Tensor().to(dtype=params_dtype, device=device)
+                setattr(self, name, bias)

        if self.primary_weights_in_fp8:
            self.init_fp8_metadata()
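The registration pattern above can be sketched outside of TransformerEngine as follows. This is a simplified illustration using plain torch.nn.Module.register_parameter (the init_fn/get_rng_state_tracker arguments in the diff are TransformerEngine-specific), and FusedWeights is an invented name.

import torch

class FusedWeights(torch.nn.Module):
    """Expose several logical weights as Parameters that are slices of one buffer."""

    def __init__(self, split_sizes, in_features):
        super().__init__()
        buf = torch.empty(sum(split_sizes), in_features)
        offset = 0
        for i, size in enumerate(split_sizes):
            # Each Parameter is a view into `buf`, so no extra memory is allocated
            # and the slices stay adjacent, which keeps a later no-op concat possible.
            self.register_parameter(f"weight{i}", torch.nn.Parameter(buf[offset:offset + size]))
            offset += size

m = FusedWeights([4, 2, 2], in_features=16)
ptrs = [p.data_ptr() for p in m.parameters()]
assert ptrs == sorted(ptrs)  # registration order matches memory order

Because the module no longer stores the fused weight_tensor/bias_tensor buffers themselves, the forward pass below rebuilds the fused view on the fly with _noop_cat.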
Expand Down Expand Up @@ -1034,24 +1039,15 @@ def forward(
"Need to run inside fp8_autocast region when weights are stored in FP8."

# Get concatenated weight and bias tensors
if len(self.parameter_split_sizes) == 1:
weight_tensor = getattr(self, self.weight_names[0])
bias_tensor = getattr(self, self.bias_names[0])
elif torch.is_grad_enabled():
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
self.weight_tensor,
weight_tensor = _noop_cat(
[getattr(self, name) for name in self.weight_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
)
if self.use_bias:
bias_tensor = _noop_cat(
[getattr(self, name) for name in self.bias_names],
self.bias_tensor,
)
else:
bias_tensor = getattr(self, self.bias_names[0]) # Unused
else:
weight_tensor = self.weight_tensor
bias_tensor = self.bias_tensor
bias_tensor = getattr(self, self.bias_names[0]) # Unused

# Fetch the fp8 weights placeholders (for linear/gemm)
weight1_fp8, weight1_t_fp8 = self.get_fp8_weights_scratchpad(