From 36d0e3e64c16103d9ba6e75af935c9a884d99cc5 Mon Sep 17 00:00:00 2001 From: Mantas <56790921+mantasu@users.noreply.github.com> Date: Tue, 6 Feb 2024 11:10:19 +0000 Subject: [PATCH] Allow 2D numpy arrays as inputs for `to_image` (#8256) Co-authored-by: Nicolas Hug --- test/test_transforms_v2.py | 5 +++++ torchvision/transforms/v2/functional/_type_conversion.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index b40d04fffdd..458f83f01c3 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -5182,6 +5182,11 @@ def test_functional_and_transform(self, make_input, fn): if isinstance(input, torch.Tensor): assert output.data_ptr() == input.data_ptr() + def test_2d_np_array(self): + # Non-regression test for https://github.com/pytorch/vision/issues/8255 + input = np.random.rand(10, 10) + assert F.to_image(input).shape == (1, 10, 10) + def test_functional_error(self): with pytest.raises(TypeError, match="Input can either be a pure Tensor, a numpy array, or a PIL image"): F.to_image(object()) diff --git a/torchvision/transforms/v2/functional/_type_conversion.py b/torchvision/transforms/v2/functional/_type_conversion.py index 9ac357315b2..c5a731fe143 100644 --- a/torchvision/transforms/v2/functional/_type_conversion.py +++ b/torchvision/transforms/v2/functional/_type_conversion.py @@ -11,7 +11,7 @@ def to_image(inpt: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> tv_tensors.Image: """See :class:`~torchvision.transforms.v2.ToImage` for details.""" if isinstance(inpt, np.ndarray): - output = torch.from_numpy(inpt).permute((2, 0, 1)).contiguous() + output = torch.from_numpy(np.atleast_3d(inpt)).permute((2, 0, 1)).contiguous() elif isinstance(inpt, PIL.Image.Image): output = pil_to_tensor(inpt) elif isinstance(inpt, torch.Tensor):