Skip to content

Commit 6eee172

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Fix convert_bounding_box_format when passing strings (#8265)
Reviewed By: vmoens Differential Revision: D55062797 fbshipit-source-id: 776e1cd156ad5e4a857e7ea1ecbf0b7933a35f87
1 parent eb68391 commit 6eee172

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

test/test_transforms_v2.py

+17
Original file line numberDiff line numberDiff line change
@@ -3398,6 +3398,23 @@ def test_transform(self, old_format, new_format, format_type):
33983398
make_bounding_boxes(format=old_format),
33993399
)
34003400

3401+
@pytest.mark.parametrize(("old_format", "new_format"), old_new_formats)
3402+
def test_strings(self, old_format, new_format):
3403+
# Non-regression test for https://github.com/pytorch/vision/issues/8258
3404+
input = tv_tensors.BoundingBoxes(torch.tensor([[10, 10, 20, 20]]), format=old_format, canvas_size=(50, 50))
3405+
expected = self._reference_convert_bounding_box_format(input, new_format)
3406+
3407+
old_format = old_format.name
3408+
new_format = new_format.name
3409+
3410+
out_functional = F.convert_bounding_box_format(input, new_format=new_format)
3411+
out_functional_tensor = F.convert_bounding_box_format(
3412+
input.as_subclass(torch.Tensor), old_format=old_format, new_format=new_format
3413+
)
3414+
out_transform = transforms.ConvertBoundingBoxFormat(new_format)(input)
3415+
for out in (out_functional, out_functional_tensor, out_transform):
3416+
assert_equal(out, expected)
3417+
34013418
def _reference_convert_bounding_box_format(self, bounding_boxes, new_format):
34023419
return tv_tensors.wrap(
34033420
torchvision.ops.box_convert(

torchvision/transforms/v2/_meta.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,10 @@ class ConvertBoundingBoxFormat(Transform):
1717

1818
def __init__(self, format: Union[str, tv_tensors.BoundingBoxFormat]) -> None:
1919
super().__init__()
20-
if isinstance(format, str):
21-
format = tv_tensors.BoundingBoxFormat[format]
2220
self.format = format
2321

2422
def _transform(self, inpt: tv_tensors.BoundingBoxes, params: Dict[str, Any]) -> tv_tensors.BoundingBoxes:
25-
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value]
23+
return F.convert_bounding_box_format(inpt, new_format=self.format) # type: ignore[return-value, arg-type]
2624

2725

2826
class ClampBoundingBoxes(Transform):

torchvision/transforms/v2/functional/_meta.py

+5
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,11 @@ def convert_bounding_box_format(
214214
if not torch.jit.is_scripting():
215215
_log_api_usage_once(convert_bounding_box_format)
216216

217+
if isinstance(old_format, str):
218+
old_format = BoundingBoxFormat[old_format.upper()]
219+
if isinstance(new_format, str):
220+
new_format = BoundingBoxFormat[new_format.upper()]
221+
217222
if torch.jit.is_scripting() or is_pure_tensor(inpt):
218223
if old_format is None:
219224
raise ValueError("For pure tensor inputs, `old_format` has to be passed.")

0 commit comments

Comments
 (0)