Add sharded interpolate (#83)
Add the fully replicated case.
Add support for inputs split along the batch or channel dimension.
sogartar authored Jun 28, 2024
1 parent 8f3f93d commit f577d8b
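For context, a minimal usage sketch of the new sharded interpolate, mirroring the tests added in sharktank/tests/ops/sharded_test.py below; the shapes and parameters are illustrative, and importing the ops module as "from sharktank import ops" assumes the package is importable under that name:

import torch
from sharktank import ops

# Illustrative NCHW input.
x = torch.rand(2, 6, 5, 4, dtype=torch.float32)

# Shard along the channel dimension into three shards.
x_sharded = ops.reshard_split(x, dim=1, count=3)

# Dispatches to the SplitPrimitiveTensor override added in sharded_impls.py.
y_sharded = ops.interpolate(
    x_sharded, scale_factor=2.0, mode="bilinear", align_corners=True
)

# Reassemble the shards and compare against the unsharded reference.
y = ops.unbox_tensor(ops.unshard(y_sharded))
y_ref = torch.nn.functional.interpolate(
    x, scale_factor=2.0, mode="bilinear", align_corners=True
)
torch.testing.assert_close(y, y_ref)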
Showing 5 changed files with 188 additions and 1 deletion.
1 change: 1 addition & 0 deletions sharktank/sharktank/ops/__init__.py
@@ -17,6 +17,7 @@
"""

from . import _registry
from ._registry import unbox_tensor
from .signatures import *
from .shape import *

21 changes: 21 additions & 0 deletions sharktank/sharktank/ops/default_impls.py
@@ -95,6 +95,27 @@ def group_norm_affine_default(input, weight, bias, *, num_groups, eps):
    return F.group_norm(input, num_groups=num_groups, weight=weight, bias=bias, eps=eps)


@interpolate.override(Tensor)
def interpolate_default(
    input: Tensor,
    size: Optional[int | List[int]],
    scale_factor: Optional[float | List[float]],
    mode: str,
    align_corners: Optional[bool],
    recompute_scale_factor: Optional[bool],
    antialias: bool,
) -> Tensor:
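    # Unbox to a plain torch.Tensor and delegate to PyTorch's implementation.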
    return torch.nn.functional.interpolate(
        input=unbox_tensor(input),
        size=size,
        scale_factor=scale_factor,
        mode=mode,
        align_corners=align_corners,
        recompute_scale_factor=recompute_scale_factor,
        antialias=antialias,
    )


@layer_norm.override(Tensor, Tensor, Tensor)
def layer_norm_default(input, weight, bias, *, eps):
    input = unbox_tensor(input)
53 changes: 52 additions & 1 deletion sharktank/sharktank/ops/sharded_impls.py
@@ -6,7 +6,7 @@

import torch
from torch import Tensor
from typing import List
from typing import List, Optional
import itertools

from ..types import (
@@ -279,6 +279,57 @@ def shareded_group_norm_affine(input, weight, bias, *, num_groups, eps):
    return SplitPrimitiveTensor(shard_dim=1, ts=result_shards)


@interpolate.override(ReplicatedTensor)
def interpolate_replicated(
    input: ReplicatedTensor,
    size: Optional[int | List[int]],
    scale_factor: Optional[float | List[float]],
    mode: str,
    align_corners: Optional[bool],
    recompute_scale_factor: Optional[bool],
    antialias: bool,
) -> ReplicatedTensor:
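    # Each shard is a full copy of the tensor, so interpolate every replica
    # independently; the shards remain identical afterwards.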
    shards = [
        torch.nn.functional.interpolate(
            input=unbox_tensor(shard),
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        for shard in input.shards
    ]
    return ReplicatedTensor(ts=shards)


@interpolate.override(SplitPrimitiveTensor)
def interpolate_split_batch_or_channel(
    input: SplitPrimitiveTensor,
    size: Optional[int | List[int]],
    scale_factor: Optional[float | List[float]],
    mode: str,
    align_corners: Optional[bool],
    recompute_scale_factor: Optional[bool],
    antialias: bool,
) -> SplitPrimitiveTensor:
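    # Interpolation resamples the spatial dimensions, so only batch (dim 0) or
    # channel (dim 1) sharding lets each shard be processed independently.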
    assert input.shard_dim == 0 or input.shard_dim == 1
    shards = [
        torch.nn.functional.interpolate(
            input=unbox_tensor(shard),
            size=size,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        for shard in input.shards
    ]
    return SplitPrimitiveTensor(ts=shards, shard_dim=input.shard_dim)


@layer_norm.override(SplitPrimitiveTensor, Tensor, Tensor)
def layer_norm_default(input, weight, bias, *, eps):
    assert input.shard_dim >= 0 and input.shard_dim < len(input.shape) - len(
43 changes: 43 additions & 0 deletions sharktank/sharktank/ops/signatures.py
@@ -23,6 +23,7 @@
    "equal",
    "group_norm_affine",
    "layer_norm",
    "interpolate",
    "linear",
    "matmul",
    "permute",
@@ -216,6 +217,48 @@ def _group_norm_affine_trampoline(
        d.fail(tensors)


@overridable
def interpolate(
    input: AnyTensor,
    size: Optional[int | List[int]] = None,
    scale_factor: Optional[float | List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> AnyTensor:
    """Equivalent to torch.nn.functional.interpolate"""
    raise NotImplementedError


@interpolate.trampoline
def _interpolate_trampoline(
    d: SignatureDispatcher,
    input: AnyTensor,
    size: Optional[int | List[int]] = None,
    scale_factor: Optional[float | List[float]] = None,
    mode: str = "nearest",
    align_corners: Optional[bool] = None,
    recompute_scale_factor: Optional[bool] = None,
    antialias: bool = False,
) -> AnyTensor:
    tensors = [input]
    for override in d.find_overrides(tensors):
        result = override(
            input,
            size,
            scale_factor,
            mode,
            align_corners,
            recompute_scale_factor,
            antialias,
        )
        if result is not NotImplemented:
            return override, result
    else:
        d.fail(tensors)


@overridable
def layer_norm(
    input: AnyTensor, weight: AnyTensor, bias: Optional[AnyTensor], *, eps: float
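Since interpolate is documented above as equivalent to torch.nn.functional.interpolate, the op should also act as a drop-in replacement on plain tensors, dispatching to interpolate_default; a small sketch, under the same import assumption as the sketch near the top:

import torch
from sharktank import ops

x = torch.rand(2, 6, 5, 4)
y_ref = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
y = ops.interpolate(x, scale_factor=2.0, mode="nearest")
torch.testing.assert_close(y, y_ref)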
71 changes: 71 additions & 0 deletions sharktank/tests/ops/sharded_test.py
@@ -303,6 +303,77 @@ def testNotEqualSharded(self):
        assert not ops.equal(b_sharded, a_sharded)


class InterpolateTest(unittest.TestCase):
    def testInterpolateSplitChannelDim(self):
        batches = 2
        channels = 6
        height = 5
        width = 4
        scale_factor = 2.0
        mode = "bilinear"
        align_corners = True
        recompute_scale_factor = True
        antialias = True
        input = torch.rand(batches, channels, height, width, dtype=torch.float32)
        expected_result = torch.nn.functional.interpolate(
            input=input,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        shard_count = 3
        sharded_input = ops.reshard_split(input, dim=1, count=shard_count)
        sharded_result = ops.interpolate(
            input=sharded_input,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        assert isinstance(sharded_result, SplitPrimitiveTensor)
        assert sharded_result.shard_count == shard_count
        assert sharded_result.shard_dim == 1
        actual_result = ops.unbox_tensor(ops.unshard(sharded_result))
        torch.testing.assert_close(actual_result, expected_result)

    def testInterpolateReplicated(self):
        batches = 2
        channels = 6
        height = 5
        width = 4
        scale_factor = 2.0
        mode = "bilinear"
        align_corners = True
        recompute_scale_factor = True
        antialias = True
        input = torch.rand(batches, channels, height, width, dtype=torch.float32)
        expected_result = torch.nn.functional.interpolate(
            input=input,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        shard_count = 3
        sharded_input = ops.replicate(input, count=shard_count)
        sharded_result = ops.interpolate(
            input=sharded_input,
            scale_factor=scale_factor,
            mode=mode,
            align_corners=align_corners,
            recompute_scale_factor=recompute_scale_factor,
            antialias=antialias,
        )
        assert isinstance(sharded_result, ReplicatedTensor)
        assert sharded_result.shard_count == shard_count
        actual_result = ops.unbox_tensor(ops.unshard(sharded_result))
        torch.testing.assert_close(actual_result, expected_result)


class NormalizationTest(unittest.TestCase):
    def testGroupNormShardedGroups(self):
        """Shard the channel dimension such that the group count is multiple of the
