Skip to content

Commit

Permalink
Addressed comments and added another test.
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmadsharif1 committed Feb 12, 2024
1 parent 78db2f6 commit 990fc31
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
14 changes: 13 additions & 1 deletion test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5004,7 +5004,6 @@ def test_image_correctness(self, fn):
actual = fn(image)
expected = F.to_image(F.grayscale_to_rgb(F.to_pil_image(image)))

print(f"ahmad: {expected.shape=} {actual.shape=}")
assert_equal(actual, expected, rtol=0, atol=1)

def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
Expand All @@ -5015,6 +5014,19 @@ def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
output_image[0][0][0] = output_image[0][0][0] + 1
assert output_image[0][0][0] != output_image[1][0][0]

def test_rgb_image_is_unchanged(self):
image = make_image(dtype=torch.uint8, device="cpu", color_space="RGB")
assert_equal(image.shape[-3], 3)
image[0][0][0] = 0
image[1][0][0] = 100
image[2][0][0] = 200
output_image = F.grayscale_to_rgb(image)
assert output_image[0][0][0] == 0
assert output_image[1][0][0] == 100
assert output_image[2][0][0] == 200
print(image)
print(output_image)


class TestRandomZoomOut:
# Tests are light because this largely relies on the already tested `pad` kernels.
Expand Down
6 changes: 4 additions & 2 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,16 @@ def grayscale_to_rgb(inpt: torch.Tensor) -> torch.Tensor:
@_register_kernel_internal(grayscale_to_rgb, torch.Tensor)
@_register_kernel_internal(grayscale_to_rgb, tv_tensors.Image)
def grayscale_to_rgb_image(image: torch.Tensor) -> torch.Tensor:
if image.shape[-3] >= 3:
# Image already has RGB channels. We don't need to do anything.
return image
# rgb_to_grayscale can be used to add channels so we reuse that function.
return _rgb_to_grayscale_image(image, num_output_channels=3, preserve_dtype=True)


@_register_kernel_internal(grayscale_to_rgb, PIL.Image.Image)
def grayscale_to_rgb_image_pil(image: PIL.Image.Image) -> PIL.Image.Image:
# to_grayscale can expand channels from 1 to 3 so we reuse that function.
return _FP.to_grayscale(image, num_output_channels=3)
return image.convert(mode="RGB")


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
Expand Down

0 comments on commit 990fc31

Please sign in to comment.