Skip to content

Commit

Permalink
Merge branch 'main' into issue-8203
Browse files Browse the repository at this point in the history
  • Loading branch information
bmmtstb authored Feb 5, 2024
2 parents e31964d + 806dba6 commit c88511f
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 20 deletions.
3 changes: 2 additions & 1 deletion .github/scripts/unittest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ set -euo pipefail
eval "$($(which conda) shell.bash hook)" && conda deactivate && conda activate ci

echo '::group::Install testing utilities'
pip install --progress-bar=off pytest pytest-mock pytest-cov expecttest!=0.2.0
# TODO: remove the <8 constraint on pytest when https://github.com/pytorch/vision/issues/8238 is closed
pip install --progress-bar=off "pytest<8" pytest-mock pytest-cov expecttest!=0.2.0
echo '::endgroup::'

python test/smoke_test.py
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,8 @@ jobs:
echo '::endgroup::'
echo '::group::Install testing utilities'
pip install --progress-bar=off pytest
# TODO: remove the <8 constraint on pytest when https://github.com/pytorch/vision/issues/8238 is closed
pip install --progress-bar=off "pytest<8"
echo '::endgroup::'
echo '::group::Run extended unittests'
Expand Down
6 changes: 6 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ ENDFOREACH()
add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE ${TORCH_LIBRARIES})

if(WITH_MPS)
find_library(metal NAMES Metal)
find_library(foundation NAMES Foundation)
target_link_libraries(${PROJECT_NAME} PRIVATE ${metal} ${foundation})
endif()

if (WITH_PNG)
target_link_libraries(${PROJECT_NAME} PRIVATE ${PNG_LIBRARY})
endif()
Expand Down
13 changes: 11 additions & 2 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4935,15 +4935,24 @@ def test_transform(self, transform, make_input):
check_transform(transform, make_input())

@pytest.mark.parametrize("num_output_channels", [1, 3])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
@pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
def test_image_correctness(self, num_output_channels, fn):
image = make_image(dtype=torch.uint8, device="cpu")
def test_image_correctness(self, num_output_channels, color_space, fn):
image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)

actual = fn(image, num_output_channels=num_output_channels)
expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))

assert_equal(actual, expected, rtol=0, atol=1)

def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")

output_image = F.rgb_to_grayscale(image, num_output_channels=3)
assert_equal(output_image[0][0][0], output_image[1][0][0])
output_image[0][0][0] = output_image[0][0][0] + 1
assert output_image[0][0][0] != output_image[1][0][0]

@pytest.mark.parametrize("num_input_channels", [1, 3])
def test_random_transform_correctness(self, num_input_channels):
image = make_image(
Expand Down
16 changes: 8 additions & 8 deletions torchvision/datasets/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class VideoClips:
video_paths (List[str]): paths to the video files
clip_length_in_frames (int): size of a clip in number of frames
frames_between_clips (int): step (in frames) between each clip
frame_rate (int, optional): if specified, it will resample the video
frame_rate (float, optional): if specified, it will resample the video
so that it has `frame_rate`, and then the clips will be defined
on the resampled video
num_workers (int): how many subprocesses to use for data loading.
Expand All @@ -102,7 +102,7 @@ def __init__(
video_paths: List[str],
clip_length_in_frames: int = 16,
frames_between_clips: int = 1,
frame_rate: Optional[int] = None,
frame_rate: Optional[float] = None,
_precomputed_metadata: Optional[Dict[str, Any]] = None,
num_workers: int = 0,
_video_width: int = 0,
Expand Down Expand Up @@ -136,7 +136,7 @@ def __init__(

def _compute_frame_pts(self) -> None:
self.video_pts = [] # len = num_videos. Each entry is a tensor of shape (num_frames_in_video,)
self.video_fps: List[int] = [] # len = num_videos
self.video_fps: List[float] = [] # len = num_videos

# strategy: use a DataLoader to parallelize read_video_timestamps
# so need to create a dummy dataset first
Expand Down Expand Up @@ -203,15 +203,15 @@ def subset(self, indices: List[int]) -> "VideoClips":

@staticmethod
def compute_clips_for_video(
video_pts: torch.Tensor, num_frames: int, step: int, fps: int, frame_rate: Optional[int] = None
video_pts: torch.Tensor, num_frames: int, step: int, fps: Optional[float], frame_rate: Optional[float] = None
) -> Tuple[torch.Tensor, Union[List[slice], torch.Tensor]]:
if fps is None:
# if for some reason the video doesn't have fps (because doesn't have a video stream)
# set the fps to 1. The value doesn't matter, because video_pts is empty anyway
fps = 1
if frame_rate is None:
frame_rate = fps
total_frames = len(video_pts) * (float(frame_rate) / fps)
total_frames = len(video_pts) * frame_rate / fps
_idxs = VideoClips._resample_video_idx(int(math.floor(total_frames)), fps, frame_rate)
video_pts = video_pts[_idxs]
clips = unfold(video_pts, num_frames, step)
Expand All @@ -227,7 +227,7 @@ def compute_clips_for_video(
idxs = unfold(_idxs, num_frames, step)
return clips, idxs

def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[int] = None) -> None:
def compute_clips(self, num_frames: int, step: int, frame_rate: Optional[float] = None) -> None:
"""
Compute all consecutive sequences of clips from video_pts.
Always returns clips of size `num_frames`, meaning that the
Expand Down Expand Up @@ -275,8 +275,8 @@ def get_clip_location(self, idx: int) -> Tuple[int, int]:
return video_idx, clip_idx

@staticmethod
def _resample_video_idx(num_frames: int, original_fps: int, new_fps: int) -> Union[slice, torch.Tensor]:
step = float(original_fps) / new_fps
def _resample_video_idx(num_frames: int, original_fps: float, new_fps: float) -> Union[slice, torch.Tensor]:
step = original_fps / new_fps
if step.is_integer():
# optimization: if step is integer, don't need to perform
# advanced indexing
Expand Down
3 changes: 1 addition & 2 deletions torchvision/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,7 @@ class FasterRCNN(GeneralizedRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down
3 changes: 1 addition & 2 deletions torchvision/models/detection/keypoint_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class KeypointRCNN(FasterRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down
3 changes: 1 addition & 2 deletions torchvision/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ class MaskRCNN(FasterRCNN):
for computing the loss
rpn_positive_fraction (float): proportion of positive anchors in a mini-batch during training
of the RPN
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
rpn_score_thresh (float): only return proposals with an objectness score greater than rpn_score_thresh
box_roi_pool (MultiScaleRoIAlign): the module which crops and resizes the feature maps in
the locations indicated by the bounding boxes
box_head (nn.Module): module that takes the cropped feature maps as input
Expand Down
1 change: 1 addition & 0 deletions torchvision/models/detection/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ class RegionProposalNetwork(torch.nn.Module):
contain two fields: training and testing, to allow for different values depending
on training or evaluation
nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
score_thresh (float): only return proposals with an objectness score greater than score_thresh
"""

Expand Down
8 changes: 6 additions & 2 deletions torchvision/transforms/v2/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
def _rgb_to_grayscale_image(
image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
) -> torch.Tensor:
if image.shape[-3] == 1:
# TODO: Maybe move the validation that num_output_channels is 1 or 3 to this function instead of callers.
if image.shape[-3] == 1 and num_output_channels == 1:
return image.clone()

if image.shape[-3] == 1 and num_output_channels == 3:
s = [1] * len(image.shape)
s[-3] = 3
return image.repeat(s)
r, g, b = image.unbind(dim=-3)
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
l_img = l_img.unsqueeze(dim=-3)
Expand Down

0 comments on commit c88511f

Please sign in to comment.