diff --git a/src/lightly_train/_commands/train_task.py b/src/lightly_train/_commands/train_task.py index 8473a62df..5a7f7a64c 100644 --- a/src/lightly_train/_commands/train_task.py +++ b/src/lightly_train/_commands/train_task.py @@ -721,6 +721,8 @@ def _train_task_from_config(config: TrainTaskConfig) -> None: ) model_init_args = {} if model_init_args is None else model_init_args + if "model_name" not in model_init_args: + model_init_args["model_name"] = config.model train_transform_args, val_transform_args = helpers.get_transform_args( train_model_cls=train_model_cls, diff --git a/src/lightly_train/_data/yolo_object_detection_dataset.py b/src/lightly_train/_data/yolo_object_detection_dataset.py index 9cd8a0bb2..a2f632257 100644 --- a/src/lightly_train/_data/yolo_object_detection_dataset.py +++ b/src/lightly_train/_data/yolo_object_detection_dataset.py @@ -75,7 +75,8 @@ def __getitem__(self, index: int) -> ObjectDetectionDatasetItem: [ int(class_id) in self.class_id_to_internal_class_id for class_id in class_labels_np - ] + ], + dtype=bool, ) bboxes_np = bboxes_np[keep] class_labels_np = class_labels_np[keep] diff --git a/src/lightly_train/_task_models/picodet_object_detection/csp_pan.py b/src/lightly_train/_task_models/picodet_object_detection/csp_pan.py index 03f613302..dfd6416e8 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/csp_pan.py +++ b/src/lightly_train/_task_models/picodet_object_detection/csp_pan.py @@ -50,7 +50,7 @@ def __init__( ) self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False) self.bn = nn.BatchNorm2d(out_channels) - self.act = nn.Hardswish(inplace=True) + self.act = nn.ReLU(inplace=True) def forward(self, x: Tensor) -> Tensor: x = self.depthwise(x) @@ -83,7 +83,7 @@ def __init__( bias=False, ) self.bn = nn.BatchNorm2d(out_channels) - self.act = nn.Hardswish(inplace=True) + self.act = nn.ReLU(inplace=True) def forward(self, x: Tensor) -> Tensor: out: Tensor = self.act(self.bn(self.conv(x))) diff --git a/src/lightly_train/_task_models/picodet_object_detection/esnet.py b/src/lightly_train/_task_models/picodet_object_detection/esnet.py index b3c7bc120..291b8ab09 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/esnet.py +++ b/src/lightly_train/_task_models/picodet_object_detection/esnet.py @@ -77,7 +77,7 @@ def __init__( padding: int | None = None, groups: int = 1, bias: bool = False, - act: Literal["hardswish", "relu", "none"] = "hardswish", + act: Literal["relu", "none"] = "relu", ) -> None: super().__init__() if padding is None: @@ -93,10 +93,8 @@ def __init__( ) self.bn = nn.BatchNorm2d(out_channels) - if act == "hardswish": - self.act: nn.Module = nn.Hardswish(inplace=True) - elif act == "relu": - self.act = nn.ReLU(inplace=True) + if act == "relu": + self.act: nn.Module = nn.ReLU(inplace=True) else: self.act = nn.Identity() @@ -122,7 +120,7 @@ def __init__(self, channels: int, reduction: int = 4) -> None: def forward(self, x: Tensor) -> Tensor: scale = F.adaptive_avg_pool2d(x, 1) scale = F.relu(self.fc1(scale), inplace=True) - scale = F.hardsigmoid(self.fc2(scale), inplace=True) + scale = torch.sigmoid(self.fc2(scale)) return x * scale @@ -149,7 +147,7 @@ def __init__( super().__init__() self.conv_pw = ConvBNAct( - in_channels // 2, mid_channels // 2, kernel_size=1, act="hardswish" + in_channels // 2, mid_channels // 2, kernel_size=1, act="relu" ) self.conv_dw = ConvBNAct( mid_channels // 2, @@ -160,7 +158,7 @@ def __init__( ) self.se = SEModule(se_channels, reduction=4) self.conv_linear = ConvBNAct( - 
mid_channels, out_channels // 2, kernel_size=1, act="hardswish" + mid_channels, out_channels // 2, kernel_size=1, act="relu" ) def forward(self, x: Tensor) -> Tensor: @@ -208,11 +206,11 @@ def __init__( act="none", ) self.conv_linear_1 = ConvBNAct( - in_channels, out_channels // 2, kernel_size=1, act="hardswish" + in_channels, out_channels // 2, kernel_size=1, act="relu" ) self.conv_pw_2 = ConvBNAct( - in_channels, mid_channels // 2, kernel_size=1, act="hardswish" + in_channels, mid_channels // 2, kernel_size=1, act="relu" ) self.conv_dw_2 = ConvBNAct( mid_channels // 2, @@ -224,7 +222,7 @@ def __init__( ) self.se = SEModule(se_channels, reduction=4) self.conv_linear_2 = ConvBNAct( - mid_channels // 2, out_channels // 2, kernel_size=1, act="hardswish" + mid_channels // 2, out_channels // 2, kernel_size=1, act="relu" ) self.conv_dw_mv1 = ConvBNAct( @@ -232,10 +230,10 @@ def __init__( out_channels, kernel_size=3, groups=out_channels, - act="hardswish", + act="relu", ) self.conv_pw_mv1 = ConvBNAct( - out_channels, out_channels, kernel_size=1, act="hardswish" + out_channels, out_channels, kernel_size=1, act="relu" ) def forward(self, x: Tensor) -> Tensor: @@ -374,7 +372,7 @@ def __init__( ] self.conv1 = ConvBNAct( - in_channels, stage_out_channels[0], kernel_size=3, stride=2, act="hardswish" + in_channels, stage_out_channels[0], kernel_size=3, stride=2, act="relu" ) self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) diff --git a/src/lightly_train/_task_models/picodet_object_detection/pico_head.py b/src/lightly_train/_task_models/picodet_object_detection/pico_head.py index e40b7007d..35b9735da 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/pico_head.py +++ b/src/lightly_train/_task_models/picodet_object_detection/pico_head.py @@ -49,7 +49,7 @@ def __init__( ) self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False) self.bn = nn.BatchNorm2d(out_channels) - self.act = nn.Hardswish(inplace=True) + self.act = nn.ReLU(inplace=True) def forward(self, x: Tensor) -> Tensor: x = self.depthwise(x) @@ -95,7 +95,7 @@ def forward(self, x: Tensor) -> Tensor: x = F.softmax(x, dim=-1) # Compute expectation - project: Tensor = self.project # type: ignore[assignment] + project: Tensor = self.project.to(dtype=x.dtype, device=x.device) # type: ignore[assignment] x = F.linear(x, project.view(1, -1)).squeeze(-1) # Reshape back to (..., 4) if input was 4*(reg_max+1) @@ -242,7 +242,7 @@ def __init__( bias=False, ), nn.BatchNorm2d(feat_channels), - nn.Hardswish(inplace=True), + nn.ReLU(inplace=True), ) ) self.cls_convs.append(cls_convs) @@ -266,7 +266,7 @@ def __init__( bias=False, ), nn.BatchNorm2d(feat_channels), - nn.Hardswish(inplace=True), + nn.ReLU(inplace=True), ) ) self.reg_convs.append(reg_convs) @@ -426,3 +426,75 @@ def decode_predictions( all_decoded_bboxes = torch.cat(decoded_bboxes_list, dim=1) return all_points, all_cls_scores, all_decoded_bboxes + + +class PicoHeadO2O(nn.Module): + """One-to-one query head for PicoDet. + + This head produces a fixed number of predictions (Q) without NMS or TopK. + It uses a lightweight query pooling mechanism over a single feature map. + + Args: + in_channels: Number of input channels for the feature map. + num_classes: Number of object classes. + num_queries: Number of fixed predictions to emit. + hidden_dim: Hidden dimension for query features. 
+ """ + + def __init__( + self, + in_channels: int, + num_classes: int, + num_queries: int = 100, + hidden_dim: int | None = None, + ) -> None: + super().__init__() + hidden_dim = in_channels if hidden_dim is None else hidden_dim + self.num_queries = num_queries + self.num_classes = num_classes + + self.attn_conv = nn.Conv2d(in_channels, num_queries, kernel_size=1) + self.query_proj = nn.Linear(in_channels, hidden_dim) + self.act = nn.ReLU(inplace=True) + self.obj_head = nn.Linear(hidden_dim, 1) + self.cls_head = nn.Linear(hidden_dim, num_classes) + self.box_head = nn.Linear(hidden_dim, 4) + + self._init_weights() + + def _init_weights(self) -> None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, std=0.01) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, feat: Tensor) -> tuple[Tensor, Tensor, Tensor]: + """Forward pass. + + Args: + feat: Feature map of shape (B, C, H, W). + + Returns: + Tuple of: + - obj_logits: (B, Q) + - cls_logits: (B, Q, C) + - box_preds: (B, Q, 4) in normalized cxcywh format. + """ + batch_size, channels, height, width = feat.shape + attn = self.attn_conv(feat).reshape(batch_size, self.num_queries, -1) + attn = attn.clamp(min=-10.0, max=10.0) + attn = torch.softmax(attn, dim=-1) + + feat_flat = feat.reshape(batch_size, channels, height * width).transpose(1, 2) + pooled = torch.bmm(attn, feat_flat) + pooled = self.act(self.query_proj(pooled)) + + obj_logits = self.obj_head(pooled).squeeze(-1) + cls_logits = self.cls_head(pooled) + box_preds = self.box_head(pooled).sigmoid() + return obj_logits, cls_logits, box_preds diff --git a/src/lightly_train/_task_models/picodet_object_detection/postprocessor.py b/src/lightly_train/_task_models/picodet_object_detection/postprocessor.py index b26c6c6a5..8248e0c05 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/postprocessor.py +++ b/src/lightly_train/_task_models/picodet_object_detection/postprocessor.py @@ -63,7 +63,12 @@ def deploy(self) -> None: self.deploy_mode = True def _generate_grid_points( - self, height: int, width: int, stride: int, device: torch.device + self, + height: int, + width: int, + stride: int, + device: torch.device, + dtype: torch.dtype, ) -> Tensor: """Generate grid center points for a feature map. @@ -76,8 +81,8 @@ def _generate_grid_points( Returns: Grid points of shape (H*W, 2) as [x, y] in pixel coordinates. 
""" - y = (torch.arange(height, device=device, dtype=torch.float32) + 0.5) * stride - x = (torch.arange(width, device=device, dtype=torch.float32) + 0.5) * stride + y = (torch.arange(height, device=device, dtype=dtype) + 0.5) * stride + x = (torch.arange(width, device=device, dtype=dtype) + 0.5) * stride yy, xx = torch.meshgrid(y, x, indexing="ij") return torch.stack([xx.flatten(), yy.flatten()], dim=-1) @@ -128,7 +133,9 @@ def forward( bbox_pred[0].permute(1, 2, 0).reshape(-1, 4 * (self.reg_max + 1)) ) - points = self._generate_grid_points(height, width, stride, device) + points = self._generate_grid_points( + height, width, stride, device, dtype=cls_score.dtype + ) scores = cls_score.sigmoid().reshape(-1, num_classes) valid_mask = scores > score_thr diff --git a/src/lightly_train/_task_models/picodet_object_detection/sim_ota_assigner.py b/src/lightly_train/_task_models/picodet_object_detection/sim_ota_assigner.py index d43eafdb4..524ce76e7 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/sim_ota_assigner.py +++ b/src/lightly_train/_task_models/picodet_object_detection/sim_ota_assigner.py @@ -16,6 +16,123 @@ ) +class TaskAlignedTop1Assigner: + """Task-aligned top-1 assigner for one-to-one matching. + + This assigns each ground-truth to at most one prediction using a + GT-centric top-1 selection based on a task-aligned metric with + collision resolution. + + Args: + alpha: Power for classification score in the metric. + beta: Power for IoU in the metric. + """ + + def __init__(self, alpha: float = 0.5, beta: float = 6.0) -> None: + self.alpha = float(alpha) + self.beta = float(beta) + + @torch.no_grad() + def assign( + self, + *, + pred_boxes_xyxy: Tensor, + pred_cls_logits: Tensor, + gt_boxes_xyxy: Tensor, + gt_labels: Tensor, + prior_centers: Tensor | None = None, + ) -> tuple[Tensor, Tensor, Tensor]: + """Assign predictions to ground truth. + + Args: + pred_boxes_xyxy: Predicted boxes (N, 4) in xyxy pixel coords. + pred_cls_logits: Predicted class logits (N, C). + gt_boxes_xyxy: Ground-truth boxes (M, 4) in xyxy pixel coords. + gt_labels: Ground-truth labels (M,). + prior_centers: Optional prior centers (N, 2) in pixel coords for + spatial prior masking (center inside gt). + + Returns: + Tuple of: + - assigned_gt_index: (N,) index of matched GT or -1. + - assigned_labels: (N,) matched labels or -1. + - assigned_ious: (N,) IoU for matched predictions. 
+ """ + device = pred_boxes_xyxy.device + num_preds = pred_boxes_xyxy.shape[0] + if gt_boxes_xyxy.numel() == 0: + assigned_gt = torch.full((num_preds,), -1, device=device, dtype=torch.long) + assigned_labels = torch.full( + (num_preds,), -1, device=device, dtype=torch.long + ) + assigned_ious = torch.zeros((num_preds,), device=device) + return assigned_gt, assigned_labels, assigned_ious + + ious = box_iou(pred_boxes_xyxy, gt_boxes_xyxy) + cls_prob = pred_cls_logits.sigmoid() + gt_onehot = torch.nn.functional.one_hot( + gt_labels, num_classes=pred_cls_logits.shape[1] + ).to(dtype=cls_prob.dtype) + pair_cls = cls_prob @ gt_onehot.t() + + metric = (pair_cls.clamp(min=1e-5) ** self.alpha) * ( + ious.clamp(min=1e-5) ** self.beta + ) + if prior_centers is not None: + cx = prior_centers[:, 0].unsqueeze(1) + cy = prior_centers[:, 1].unsqueeze(1) + in_gt = (cx >= gt_boxes_xyxy[:, 0]) & (cx <= gt_boxes_xyxy[:, 2]) + in_gt = in_gt & (cy >= gt_boxes_xyxy[:, 1]) & (cy <= gt_boxes_xyxy[:, 3]) + has_center = in_gt.any(dim=0) + if not bool(has_center.all()): + gt_centers = (gt_boxes_xyxy[:, 0:2] + gt_boxes_xyxy[:, 2:4]) / 2 + diff = prior_centers[:, None, :] - gt_centers[None, :, :] + dist2 = (diff**2).sum(dim=-1) + missing = ~has_center + if bool(missing.any()): + nearest_idx = dist2[:, missing].argmin(dim=0) + in_gt[nearest_idx, missing] = True + metric = torch.where(in_gt, metric, torch.zeros_like(metric)) + + if metric.max() <= 0: + assigned_gt = torch.full((num_preds,), -1, device=device, dtype=torch.long) + assigned_labels = torch.full( + (num_preds,), -1, device=device, dtype=torch.long + ) + assigned_ious = torch.zeros((num_preds,), device=device) + return assigned_gt, assigned_labels, assigned_ious + + num_gt = gt_boxes_xyxy.shape[0] + assigned_gt = torch.full((num_preds,), -1, device=device, dtype=torch.long) + assigned_ious = torch.zeros((num_preds,), device=device) + gt_assigned = torch.zeros((num_gt,), device=device, dtype=torch.bool) + metric_work = metric.clone() + + while True: + remaining = (~gt_assigned) & (metric_work.max(dim=0).values > 0) + if not bool(remaining.any()): + break + best_scores, best_preds = metric_work[:, remaining].max(dim=0) + gt_indices = torch.nonzero(remaining, as_tuple=False).squeeze(1) + + unique_preds, inv = torch.unique(best_preds, return_inverse=True) + for idx, pred_idx in enumerate(unique_preds): + gt_mask = inv == idx + if not bool(gt_mask.any()): + continue + best_local = torch.argmax(best_scores[gt_mask]) + gt_global = gt_indices[gt_mask][best_local] + assigned_gt[pred_idx] = gt_global + assigned_ious[pred_idx] = ious[pred_idx, gt_global] + gt_assigned[gt_global] = True + metric_work[pred_idx, :] = 0 + + assigned_labels = torch.full((num_preds,), -1, device=device, dtype=torch.long) + pos_mask = assigned_gt >= 0 + assigned_labels[pos_mask] = gt_labels[assigned_gt[pos_mask]] + return assigned_gt, assigned_labels, assigned_ious + + class SimOTAAssigner: """Simplified Optimal Transport Assignment for anchor-free detection. 
diff --git a/src/lightly_train/_task_models/picodet_object_detection/task_model.py b/src/lightly_train/_task_models/picodet_object_detection/task_model.py index d2fb3e183..f87ca86a8 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/task_model.py +++ b/src/lightly_train/_task_models/picodet_object_detection/task_model.py @@ -8,10 +8,12 @@ from __future__ import annotations import logging +import os from copy import deepcopy from typing import Any, Literal import torch +import torch.nn.functional as F from packaging import version from PIL.Image import Image as PILImage from torch import Tensor @@ -23,7 +25,10 @@ from lightly_train._export import tensorrt_helpers from lightly_train._task_models.picodet_object_detection.csp_pan import CSPPAN from lightly_train._task_models.picodet_object_detection.esnet import ESNet -from lightly_train._task_models.picodet_object_detection.pico_head import PicoHead +from lightly_train._task_models.picodet_object_detection.pico_head import ( + PicoHead, + distance2bbox, +) from lightly_train._task_models.picodet_object_detection.postprocessor import ( PicoDetPostProcessor, ) @@ -41,9 +46,9 @@ "neck_out_channels": 96, "head_feat_channels": 96, }, - "picodet/l-416": { + "picodet/l-640": { "model_size": "l", - "image_size": (416, 416), + "image_size": (640, 640), "stacked_convs": 4, "neck_out_channels": 160, "head_feat_channels": 160, @@ -73,9 +78,12 @@ def __init__( score_threshold: float = 0.025, iou_threshold: float = 0.6, max_detections: int = 100, + backbone_weights: PathLike | None = None, load_weights: bool = True, ) -> None: - super().__init__(init_args=locals(), ignore_args={"load_weights"}) + super().__init__( + init_args=locals(), ignore_args={"backbone_weights", "load_weights"} + ) self.model_name = model_name self.image_size = image_size @@ -83,6 +91,7 @@ def __init__( self.num_classes = num_classes self.reg_max = reg_max self.classes = classes + self._export_decode_fp32 = False if classes is not None and len(classes) != num_classes: raise ValueError( @@ -132,6 +141,11 @@ def __init__( ) backbone_out_channels = self.backbone.out_channels + print("Attempting to load backbone weights: ", load_weights, backbone_weights) + + if load_weights and backbone_weights is not None: + self.load_backbone_weights(backbone_weights) + self.neck = CSPPAN( in_channels=backbone_out_channels, out_channels=neck_out_channels_typed, @@ -153,6 +167,17 @@ def __init__( share_cls_reg=True, use_depthwise=True, ) + self.o2o_head = PicoHead( + in_channels=neck_out_channels_typed, + num_classes=num_classes, + feat_channels=head_feat_channels_typed, + stacked_convs=stacked_convs_typed, + kernel_size=5, + reg_max=reg_max, + strides=(8, 16, 32, 64), + share_cls_reg=True, + use_depthwise=True, + ) self.postprocessor = PicoDetPostProcessor( num_classes=num_classes, @@ -162,6 +187,85 @@ def __init__( iou_threshold=iou_threshold, max_detections=max_detections, ) + self._o2o_peak_score_thresholds = (0.02, 0.04, 0.06, 0.08) + self._o2o_peak_kernels = (3, 5, 5, 5) + self._o2o_suppress_logit = -1e6 + + def _apply_o2o_peak_filter(self, cls_score: Tensor, level_idx: int) -> Tensor: + """Suppress non-peak logits to sparsify dense o2o predictions.""" + scores = cls_score.sigmoid().amax(dim=1, keepdim=True) + threshold = self._o2o_peak_score_thresholds[level_idx] + kernel = self._o2o_peak_kernels[level_idx] + pooled = F.max_pool2d( + scores, kernel_size=kernel, stride=1, padding=kernel // 2 + ) + keep = (scores >= threshold) & (scores == pooled) + suppressed = 
cls_score.new_full((), self._o2o_suppress_logit) + return torch.where(keep, cls_score, suppressed) + + def _count_o2o_peaks(self, cls_scores_list: list[Tensor]) -> Tensor: + """Return mean number of peaks per level per image for debug logging.""" + device = cls_scores_list[0].device + total_peaks = torch.zeros((len(cls_scores_list),), device=device) + for level_idx, cls_score in enumerate(cls_scores_list): + scores = cls_score.sigmoid().amax(dim=1, keepdim=True) + threshold = self._o2o_peak_score_thresholds[level_idx] + kernel = self._o2o_peak_kernels[level_idx] + pooled = F.max_pool2d( + scores, kernel_size=kernel, stride=1, padding=kernel // 2 + ) + keep = (scores >= threshold) & (scores == pooled) + total_peaks[level_idx] = keep.sum() + batch_size = cls_scores_list[0].shape[0] + return total_peaks / float(batch_size) + + def load_backbone_weights(self, path: PathLike) -> None: + """Load backbone weights from a checkpoint file. + + Args: + path: Path to a .pt file (e.g., exported_last.pt). + """ + if not os.path.exists(path): + logger.error("Checkpoint file not found: %s", path) + return + + state_dict = torch.load(path, map_location="cpu", weights_only=False) + if isinstance(state_dict, dict): + for key in ("state_dict", "model", "model_state_dict", "student"): + if key in state_dict and isinstance(state_dict[key], dict): + state_dict = state_dict[key] + break + + if isinstance(state_dict, dict): + if all(key.startswith("module.") for key in state_dict): + state_dict = { + key[len("module.") :]: value for key, value in state_dict.items() + } + + prefixes = ("_model.", "model.", "backbone.") + if all(key.startswith(prefixes) for key in state_dict): + state_dict = { + key.split(".", 1)[1]: value for key, value in state_dict.items() + } + elif any(key.startswith(prefixes) for key in state_dict): + state_dict = { + key.split(".", 1)[1]: value + for key, value in state_dict.items() + if key.startswith(prefixes) + } + + missing, unexpected = self.backbone.load_state_dict(state_dict, strict=False) + total_backbone_keys = len(self.backbone.state_dict()) + loaded_keys = total_backbone_keys - len(missing) + logger.info( + "Backbone weights loaded: %d/%d keys matched.", + loaded_keys, + total_backbone_keys, + ) + if missing: + logger.warning("Missing keys when loading backbone: %s", missing) + if unexpected: + logger.warning("Unexpected keys when loading backbone: %s", unexpected) @classmethod def list_model_names(cls) -> list[str]: @@ -212,7 +316,7 @@ def load_train_state_dict( return self.load_state_dict(new_state_dict, strict=strict, assign=assign) - def _forward_train(self, images: Tensor) -> dict[str, list[Tensor]]: + def _forward_train(self, images: Tensor) -> dict[str, Tensor | list[Tensor]]: """Forward pass returning raw per-level predictions. 
Args: @@ -226,12 +330,79 @@ def _forward_train(self, images: Tensor) -> dict[str, list[Tensor]]: feats = self.backbone(images) feats = self.neck(feats) cls_scores, bbox_preds = self.head(feats) - return {"cls_scores": cls_scores, "bbox_preds": bbox_preds} + o2o_cls_scores, o2o_bbox_preds = self.o2o_head(feats) + return { + "cls_scores": cls_scores, + "bbox_preds": bbox_preds, + "o2o_cls_scores": o2o_cls_scores, + "o2o_bbox_preds": o2o_bbox_preds, + } + + def _decode_o2o_predictions( + self, + *, + cls_scores_list: list[Tensor], + bbox_preds_list: list[Tensor], + image_size: tuple[int, int], + input_size: tuple[int, int], + ) -> tuple[Tensor, Tensor]: + batch_size = cls_scores_list[0].shape[0] + device = cls_scores_list[0].device + decode_bbox_preds_pixel: list[Tensor] = [] + flatten_cls_preds: list[Tensor] = [] + decode_dtype = ( + torch.float32 if self._export_decode_fp32 else cls_scores_list[0].dtype + ) + + for level_idx, (cls_score, bbox_pred) in enumerate( + zip(cls_scores_list, bbox_preds_list) + ): + stride = self.o2o_head.strides[level_idx] + _, _, h, w = cls_score.shape + num_points = h * w + + cls_score = self._apply_o2o_peak_filter(cls_score, level_idx) + y = (torch.arange(h, device=device, dtype=decode_dtype) + 0.5) * stride + x = (torch.arange(w, device=device, dtype=decode_dtype) + 0.5) * stride + yy, xx = torch.meshgrid(y, x, indexing="ij") + points = torch.stack([xx.flatten(), yy.flatten()], dim=-1) + + center_in_feature = points / stride + bbox_pred_flat = bbox_pred.permute(0, 2, 3, 1).reshape( + batch_size, num_points, 4 * (self.reg_max + 1) + ) + if self._export_decode_fp32: + bbox_pred_flat = bbox_pred_flat.to(dtype=decode_dtype) + pred_corners = self.o2o_head.integral(bbox_pred_flat) + decode_bbox_pred = distance2bbox( + center_in_feature.unsqueeze(0).expand(batch_size, -1, -1), pred_corners + ) + decode_bbox_preds_pixel.append(decode_bbox_pred * stride) + + cls_pred_flat = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, num_points, self.num_classes + ) + flatten_cls_preds.append(cls_pred_flat) + + boxes_xyxy = torch.cat(decode_bbox_preds_pixel, dim=1) + cls_logits = torch.cat(flatten_cls_preds, dim=1) + + input_h, input_w = input_size + orig_h, orig_w = image_size + if (orig_h, orig_w) != (input_h, input_w): + scale = boxes_xyxy.new_tensor( + [orig_w / input_w, orig_h / input_h, orig_w / input_w, orig_h / input_h] + ) + boxes_xyxy = boxes_xyxy * scale + + scale_limit = boxes_xyxy.new_tensor([orig_w, orig_h, orig_w, orig_h]) + boxes_xyxy = torch.min(boxes_xyxy, scale_limit).clamp(min=0) + return boxes_xyxy, cls_logits def forward( self, images: Tensor, orig_target_size: Tensor | None = None ) -> tuple[Tensor, Tensor, Tensor]: - """Forward pass returning final predictions for inference/ONNX. + """Forward pass returning o2o predictions for inference/ONNX. Args: images: Input tensor of shape (B, C, H, W). @@ -239,9 +410,9 @@ def forward( Returns: Tuple of: - - labels: Tensor of shape (B, N) with class indices. - - boxes: Tensor of shape (B, N, 4) in xyxy format. - - scores: Tensor of shape (B, N) with confidence scores. + - boxes_xyxy: Tensor of shape (B, N, 4) in xyxy pixel format. + - obj_logits: Tensor of shape (B, N) with objectness logits. + - cls_logits: Tensor of shape (B, N, C) with class logits. 
""" if orig_target_size is None: orig_h, orig_w = images.shape[-2:] @@ -253,41 +424,18 @@ def forward( orig_target_size_ = orig_target_size_[0] orig_h, orig_w = int(orig_target_size_[0]), int(orig_target_size_[1]) - outputs = self._forward_train(images) - result = self.postprocessor( - cls_scores=[cs[:1] for cs in outputs["cls_scores"]], - bbox_preds=[bp[:1] for bp in outputs["bbox_preds"]], - original_size=(orig_h, orig_w), - score_threshold=0.0, - ) - - max_detections = self.postprocessor.max_detections - labels_out = torch.full( - (1, max_detections), - -1, - device=images.device, - dtype=torch.long, - ) - boxes_out = torch.zeros( - (1, max_detections, 4), - device=images.device, - dtype=result["bboxes"].dtype, - ) - scores_out = torch.zeros( - (1, max_detections), - device=images.device, - dtype=result["scores"].dtype, + feats = self.backbone(images) + feats = self.neck(feats) + cls_scores_list, bbox_preds_list = self.o2o_head(feats) + input_size = (int(images.shape[-2]), int(images.shape[-1])) + boxes_xyxy, cls_logits = self._decode_o2o_predictions( + cls_scores_list=cls_scores_list, + bbox_preds_list=bbox_preds_list, + image_size=(orig_h, orig_w), + input_size=input_size, ) - - # PicoDet postprocessing returns variable-length outputs, so we pad to - # fixed shapes for ONNX; LTDETR already returns fixed-size tensors. - labels = self.internal_class_to_class[result["labels"]] - num_detections = labels.shape[0] - labels_out[0, :num_detections] = labels - boxes_out[0, :num_detections] = result["bboxes"] - scores_out[0, :num_detections] = result["scores"] - - return labels_out, boxes_out, scores_out + obj_logits = cls_logits.max(dim=-1).values + return boxes_xyxy, obj_logits, cls_logits @torch.no_grad() def predict( @@ -323,16 +471,26 @@ def predict( x = transforms_functional.resize(x, list(self.image_size)) x = x.unsqueeze(0) - outputs = self._forward_train(x) - results = self.postprocessor( - cls_scores=outputs["cls_scores"], - bbox_preds=outputs["bbox_preds"], - original_size=(orig_h, orig_w), - score_threshold=threshold, + feats = self.backbone(x) + feats = self.neck(feats) + cls_scores_list, bbox_preds_list = self.o2o_head(feats) + boxes_xyxy, cls_logits = self._decode_o2o_predictions( + cls_scores_list=cls_scores_list, + bbox_preds_list=bbox_preds_list, + image_size=(orig_h, orig_w), + input_size=tuple(self.image_size), ) - labels = self.internal_class_to_class[results["labels"]] - boxes = results["bboxes"] - scores = results["scores"] + boxes = boxes_xyxy[0] + internal_labels = cls_logits[0].argmax(dim=-1) + cls_for_label = cls_logits[0].gather(1, internal_labels.unsqueeze(1)).squeeze(1) + scores = torch.sigmoid(cls_for_label) + labels = self.internal_class_to_class[internal_labels] + if threshold > 0: + keep = scores >= threshold + labels = labels[keep] + boxes = boxes[keep] + scores = scores[keep] + return { "labels": labels, "bboxes": boxes, @@ -355,8 +513,7 @@ def export_onnx( The export uses a dummy input of shape (1, C, H, W) where C is inferred from the first model parameter and (H, W) come from `self.image_size`. - The ONNX graph outputs labels, boxes, and scores in the resized input - image space. + The ONNX graph outputs labels, boxes, and scores. 
Optionally simplifies the exported model in-place using onnxslim and verifies numerical closeness against a float32 CPU reference via @@ -437,6 +594,20 @@ def export_onnx( input_names = ["images"] output_names = ["labels", "boxes", "scores"] + class _PicoDetExportWrapper(torch.nn.Module): + def __init__(self, model: PicoDetObjectDetection) -> None: + super().__init__() + self.model = model + + def forward(self, images: Tensor) -> tuple[Tensor, Tensor, Tensor]: + boxes_xyxy, obj_logit, cls_logits = self.model(images) + scores = torch.sigmoid(obj_logit) + labels = cls_logits.argmax(dim=-1).to(torch.int64) + labels = self.model.internal_class_to_class[labels] + return labels, boxes_xyxy, scores + + export_model = _PicoDetExportWrapper(self) + # Older torch.onnx.export versions don't accept the "dynamo" kwarg. export_kwargs: dict[str, Any] = { "input_names": input_names, @@ -449,12 +620,17 @@ def export_onnx( if torch_version >= version.parse("2.2.0"): export_kwargs["dynamo"] = False - torch.onnx.export( - self, - (dummy_input,), - str(out), - **export_kwargs, - ) + prev_export_decode_fp32 = self._export_decode_fp32 + self._export_decode_fp32 = dtype == torch.float16 + try: + torch.onnx.export( + export_model, + (dummy_input,), + str(out), + **export_kwargs, + ) + finally: + self._export_decode_fp32 = prev_export_decode_fp32 if simplify: import onnxslim # type: ignore [import-not-found,import-untyped] @@ -473,7 +649,8 @@ def export_onnx( onnx.checker.check_model(out, full_check=True) reference_model = deepcopy(self).cpu().to(torch.float32).eval() - reference_outputs = reference_model( + reference_export_model = _PicoDetExportWrapper(reference_model) + reference_outputs = reference_export_model( dummy_input.cpu().to(torch.float32), ) @@ -504,7 +681,7 @@ def msg(s: str) -> str: check_device=False, check_dtype=False, check_layout=False, - atol=5e-3, + atol=2e-2, rtol=1e-1, ) else: diff --git a/src/lightly_train/_task_models/picodet_object_detection/train_model.py b/src/lightly_train/_task_models/picodet_object_detection/train_model.py index a442ce9db..347e53497 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/train_model.py +++ b/src/lightly_train/_task_models/picodet_object_detection/train_model.py @@ -8,10 +8,11 @@ from __future__ import annotations import copy -from typing import Any, ClassVar, Literal +from typing import Any, ClassVar, Literal, cast import torch import torch.distributed as dist +import torch.nn.functional as F from lightly.utils.scheduler import CosineWarmupScheduler from lightning_fabric import Fabric from torch import Tensor @@ -43,11 +44,9 @@ bbox2distance, distance2bbox, ) -from lightly_train._task_models.picodet_object_detection.postprocessor import ( - PicoDetPostProcessor, -) from lightly_train._task_models.picodet_object_detection.sim_ota_assigner import ( SimOTAAssigner, + TaskAlignedTop1Assigner, ) from lightly_train._task_models.picodet_object_detection.task_model import ( PicoDetObjectDetection, @@ -63,7 +62,7 @@ TrainModel, TrainModelArgs, ) -from lightly_train.types import ObjectDetectionBatch +from lightly_train.types import ObjectDetectionBatch, PathLike class PicoDetObjectDetectionTaskSaveCheckpointArgs(TaskSaveCheckpointArgs): @@ -77,6 +76,7 @@ class PicoDetObjectDetectionTrainArgs(TrainModelArgs): """Training arguments for PicoDet-S. Args: + backbone_weights: Optional path to backbone checkpoint to load. lr: Learning rate for SGD optimizer. momentum: Momentum for SGD optimizer. weight_decay: Weight decay for SGD optimizer. 
@@ -85,6 +85,7 @@ class PicoDetObjectDetectionTrainArgs(TrainModelArgs): loss_vfl_weight: Weight for varifocal loss. loss_giou_weight: Weight for GIoU loss. loss_dfl_weight: Weight for distribution focal loss. + loss_o2o_dfl_weight: Weight for the o2o DFL loss term. simota_center_radius: Center radius for SimOTA assignment. simota_candidate_topk: Top-k candidates for dynamic k in SimOTA. simota_iou_weight: IoU weight in SimOTA cost matrix. @@ -98,6 +99,8 @@ class PicoDetObjectDetectionTrainArgs(TrainModelArgs): PicoDetObjectDetectionTaskSaveCheckpointArgs ) + backbone_weights: PathLike | None = None + lr: float = 0.1 momentum: float = 0.9 weight_decay: float = 4e-5 @@ -108,6 +111,7 @@ class PicoDetObjectDetectionTrainArgs(TrainModelArgs): loss_vfl_weight: float = 1.0 loss_giou_weight: float = 2.0 loss_dfl_weight: float = 0.25 + loss_o2o_dfl_weight: float = 0.25 simota_center_radius: float = 2.5 simota_candidate_topk: int = 10 @@ -158,6 +162,7 @@ def __init__( num_classes=num_classes, classes=data_args.included_classes, image_normalize=image_normalize, + backbone_weights=model_args.backbone_weights, load_weights=load_weights, ) @@ -178,6 +183,7 @@ def __init__( cls_weight=1.0, num_classes=num_classes, ) + self.o2o_assigner = TaskAlignedTop1Assigner(alpha=0.5, beta=6.0) # EMA model setup (following LTDETR pattern for consistency) # EMA is always enabled @@ -214,28 +220,70 @@ def training_step( img_h, img_w = images.shape[-2:] outputs = self.model._forward_train(images) - cls_scores = outputs["cls_scores"] - bbox_preds = outputs["bbox_preds"] + cls_scores_list = cast(list[Tensor], outputs["cls_scores"]) + bbox_preds_list = cast(list[Tensor], outputs["bbox_preds"]) + o2o_cls_scores = cast(list[Tensor], outputs["o2o_cls_scores"]) + o2o_bbox_preds = cast(list[Tensor], outputs["o2o_bbox_preds"]) # Convert GT from YOLO format to pixel xyxy gt_boxes_xyxy_norm = _yolo_to_xyxy(gt_bboxes_yolo) sizes = [(img_w, img_h)] * batch_size gt_boxes_xyxy_list = _denormalize_xyxy_boxes(gt_boxes_xyxy_norm, sizes) - total_loss, loss_vfl, loss_giou, loss_dfl = self._compute_losses( + dense_loss, loss_vfl, loss_giou, loss_dfl = self._compute_losses( fabric=fabric, - cls_scores=cls_scores, - bbox_preds=bbox_preds, + cls_scores=cls_scores_list, + bbox_preds=bbox_preds_list, gt_boxes_xyxy_list=gt_boxes_xyxy_list, gt_labels_list=gt_labels_list, ) + ( + o2o_loss, + loss_o2o_obj, + loss_o2o_cls, + loss_o2o_box, + loss_o2o_dfl, + o2o_stats, + ) = self._compute_o2o_losses( + cls_scores=o2o_cls_scores, + bbox_preds=o2o_bbox_preds, + gt_boxes_xyxy_list=gt_boxes_xyxy_list, + gt_labels_list=gt_labels_list, + image_size=(img_h, img_w), + ) + o2o_peak_kept = self.model._count_o2o_peaks(o2o_cls_scores) + + total_loss = o2o_loss + 0.5 * dense_loss loss_dict = reduce_dict( { "train_loss": total_loss, + "train_loss/loss_o2o": o2o_loss, + "train_loss/loss_o2o_obj": loss_o2o_obj, + "train_loss/loss_o2o_cls": loss_o2o_cls, + "train_loss/loss_o2o_box": loss_o2o_box, + "train_loss/loss_o2o_dfl": loss_o2o_dfl, + "train_loss/loss_dense": dense_loss, "train_loss/loss_vfl": loss_vfl, "train_loss/loss_giou": loss_giou, "train_loss/loss_dfl": loss_dfl, + # TODO(igorsusmelj): remove o2o debug stats after investigation. 
+ "debug/o2o_num_pos": o2o_stats["o2o_num_pos"], + "debug/o2o_mean_iou": o2o_stats["o2o_mean_iou"], + "debug/o2o_cls_target_sum": o2o_stats["o2o_cls_target_sum"], + "debug/o2o_peak_kept_p3": o2o_peak_kept[0], + "debug/o2o_peak_kept_p4": o2o_peak_kept[1], + "debug/o2o_peak_kept_p5": o2o_peak_kept[2], + "debug/o2o_peak_kept_p6": o2o_peak_kept[3], + "debug/o2o_gt_small": o2o_stats["o2o_gt_small"], + "debug/o2o_gt_medium": o2o_stats["o2o_gt_medium"], + "debug/o2o_gt_large": o2o_stats["o2o_gt_large"], + "debug/o2o_gt_center_small": o2o_stats["o2o_gt_center_small"], + "debug/o2o_gt_center_medium": o2o_stats["o2o_gt_center_medium"], + "debug/o2o_gt_center_large": o2o_stats["o2o_gt_center_large"], + "debug/o2o_gt_matched_small": o2o_stats["o2o_gt_matched_small"], + "debug/o2o_gt_matched_medium": o2o_stats["o2o_gt_matched_medium"], + "debug/o2o_gt_matched_large": o2o_stats["o2o_gt_matched_large"], } ) @@ -267,41 +315,70 @@ def validation_step( with torch.no_grad(): outputs = model_to_use._forward_train(images) # type: ignore[operator] - cls_scores = outputs["cls_scores"] - bbox_preds = outputs["bbox_preds"] + cls_scores_list = cast(list[Tensor], outputs["cls_scores"]) + bbox_preds_list = cast(list[Tensor], outputs["bbox_preds"]) + o2o_cls_scores = cast(list[Tensor], outputs["o2o_cls_scores"]) + o2o_bbox_preds = cast(list[Tensor], outputs["o2o_bbox_preds"]) gt_boxes_xyxy_norm = _yolo_to_xyxy(gt_bboxes_yolo) img_h, img_w = images.shape[-2:] sizes = [(img_w, img_h)] * batch_size gt_boxes_xyxy_list = _denormalize_xyxy_boxes(gt_boxes_xyxy_norm, sizes) - total_loss, loss_vfl, loss_giou, loss_dfl = self._compute_losses( + dense_loss, loss_vfl, loss_giou, loss_dfl = self._compute_losses( fabric=fabric, - cls_scores=cls_scores, - bbox_preds=bbox_preds, + cls_scores=cls_scores_list, + bbox_preds=bbox_preds_list, gt_boxes_xyxy_list=gt_boxes_xyxy_list, gt_labels_list=gt_labels_list, ) - - postprocessor = self.model.postprocessor - assert isinstance(postprocessor, PicoDetPostProcessor) - predictions = postprocessor.forward_batch( - cls_scores=cls_scores, - bbox_preds=bbox_preds, - original_sizes=torch.tensor([[img_h, img_w]] * batch_size, device=device), - score_threshold=0.001, + ( + o2o_loss, + loss_o2o_obj, + loss_o2o_cls, + loss_o2o_box, + loss_o2o_dfl, + o2o_stats, + ) = self._compute_o2o_losses( + cls_scores=o2o_cls_scores, + bbox_preds=o2o_bbox_preds, + gt_boxes_xyxy_list=gt_boxes_xyxy_list, + gt_labels_list=gt_labels_list, + image_size=(img_h, img_w), + ) + o2o_peak_kept = model_to_use._count_o2o_peaks(o2o_cls_scores) + total_loss = o2o_loss + 0.5 * dense_loss + + boxes_xyxy, cls_logits = model_to_use._decode_o2o_predictions( + cls_scores_list=o2o_cls_scores, + bbox_preds_list=o2o_bbox_preds, + image_size=(img_h, img_w), + input_size=(int(img_h), int(img_w)), ) + cls_labels = cls_logits.argmax(dim=-1) + cls_scores = cls_logits.gather(2, cls_labels.unsqueeze(-1)).squeeze(-1) + scores = torch.sigmoid(cls_scores) + cls_labels = self.model.internal_class_to_class[cls_labels] preds = [] targets = [] + max_detections = model_to_use.postprocessor.max_detections for i in range(batch_size): - pred_boxes = predictions[i]["bboxes"].detach() - pred_scores = predictions[i]["scores"].detach() - pred_labels = predictions[i]["labels"].detach() + pred_boxes = boxes_xyxy[i].detach() + pred_scores = scores[i].detach() + pred_labels = cls_labels[i].detach() gt_boxes = gt_boxes_xyxy_list[i].to(device).detach() gt_labels_i = gt_labels_list[i].to(device).long().detach() + if pred_scores.numel() > max_detections: + 
topk_scores, topk_idx = torch.topk( + pred_scores, k=max_detections, largest=True + ) + pred_boxes = pred_boxes[topk_idx] + pred_labels = pred_labels[topk_idx] + pred_scores = topk_scores + preds.append( { "boxes": pred_boxes, @@ -322,9 +399,34 @@ def validation_step( loss=total_loss, log_dict={ "val_loss": total_loss.item(), + "val_loss/loss_o2o": o2o_loss.item(), + "val_loss/loss_o2o_obj": loss_o2o_obj.item(), + "val_loss/loss_o2o_cls": loss_o2o_cls.item(), + "val_loss/loss_o2o_box": loss_o2o_box.item(), + "val_loss/loss_o2o_dfl": loss_o2o_dfl.item(), + "val_loss/loss_dense": dense_loss.item(), "val_loss/loss_vfl": loss_vfl.item(), "val_loss/loss_giou": loss_giou.item(), "val_loss/loss_dfl": loss_dfl.item(), + # TODO(igorsusmelj): remove o2o debug stats after investigation. + "debug/o2o_num_pos": o2o_stats["o2o_num_pos"].item(), + "debug/o2o_mean_iou": o2o_stats["o2o_mean_iou"].item(), + "debug/o2o_cls_target_sum": o2o_stats["o2o_cls_target_sum"].item(), + "debug/o2o_peak_kept_p3": o2o_peak_kept[0].item(), + "debug/o2o_peak_kept_p4": o2o_peak_kept[1].item(), + "debug/o2o_peak_kept_p5": o2o_peak_kept[2].item(), + "debug/o2o_peak_kept_p6": o2o_peak_kept[3].item(), + "debug/o2o_gt_small": o2o_stats["o2o_gt_small"].item(), + "debug/o2o_gt_medium": o2o_stats["o2o_gt_medium"].item(), + "debug/o2o_gt_large": o2o_stats["o2o_gt_large"].item(), + "debug/o2o_gt_center_small": o2o_stats["o2o_gt_center_small"].item(), + "debug/o2o_gt_center_medium": o2o_stats["o2o_gt_center_medium"].item(), + "debug/o2o_gt_center_large": o2o_stats["o2o_gt_center_large"].item(), + "debug/o2o_gt_matched_small": o2o_stats["o2o_gt_matched_small"].item(), + "debug/o2o_gt_matched_medium": o2o_stats[ + "o2o_gt_matched_medium" + ].item(), + "debug/o2o_gt_matched_large": o2o_stats["o2o_gt_matched_large"].item(), "val_metric/map": self.map_metric, }, ) @@ -381,6 +483,8 @@ def _compute_losses( all_decoded_bboxes_pixel = torch.cat(decode_bbox_preds_pixel, dim=1) all_cls_preds = torch.cat(flatten_cls_preds, dim=1) all_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + assert all_cls_preds.shape[1] == all_bbox_preds.shape[1] + assert all_cls_preds.shape[1] == all_center_and_strides.shape[1] all_vfl_losses: list[Tensor] = [] all_giou_losses: list[Tensor] = [] @@ -504,6 +608,217 @@ def _compute_losses( return total_loss, loss_vfl, loss_giou, loss_dfl + def _compute_o2o_losses( + self, + *, + cls_scores: list[Tensor], + bbox_preds: list[Tensor], + gt_boxes_xyxy_list: list[Tensor], + gt_labels_list: list[Tensor], + image_size: tuple[int, int], + ) -> tuple[Tensor, Tensor, Tensor, Tensor, dict[str, Tensor]]: + batch_size = cls_scores[0].shape[0] + device = cls_scores[0].device + + total_obj = torch.zeros((), device=device) + total_cls = torch.zeros((), device=device) + total_box = torch.zeros((), device=device) + total_dfl = torch.zeros((), device=device) + total_pos = torch.zeros((), device=device) + total_iou = torch.zeros((), device=device) + total_cls_target = torch.zeros((), device=device) + total_gt_small = torch.zeros((), device=device) + total_gt_medium = torch.zeros((), device=device) + total_gt_large = torch.zeros((), device=device) + total_gt_center_small = torch.zeros((), device=device) + total_gt_center_medium = torch.zeros((), device=device) + total_gt_center_large = torch.zeros((), device=device) + total_gt_matched_small = torch.zeros((), device=device) + total_gt_matched_medium = torch.zeros((), device=device) + total_gt_matched_large = torch.zeros((), device=device) + + decode_bbox_preds_pixel: list[Tensor] = [] + 
center_and_strides: list[Tensor] = [] + flatten_cls_preds: list[Tensor] = [] + flatten_bbox_preds: list[Tensor] = [] + + for level_idx, (cls_score, bbox_pred) in enumerate(zip(cls_scores, bbox_preds)): + stride = self.strides[level_idx] + _, _, h, w = cls_score.shape + num_points = h * w + + y = (torch.arange(h, device=device, dtype=torch.float32) + 0.5) * stride + x = (torch.arange(w, device=device, dtype=torch.float32) + 0.5) * stride + yy, xx = torch.meshgrid(y, x, indexing="ij") + points = torch.stack([xx.flatten(), yy.flatten()], dim=-1) + priors = torch.cat( + [points, torch.full((num_points, 2), stride, device=device)], dim=-1 + ) + center_and_stride = priors.unsqueeze(0).expand(batch_size, -1, -1) + center_and_strides.append(center_and_stride) + + center_in_feature = points / stride + bbox_pred_flat = bbox_pred.permute(0, 2, 3, 1).reshape( + batch_size, num_points, 4 * (self.reg_max + 1) + ) + pred_corners = self.integral(bbox_pred_flat) + decode_bbox_pred = distance2bbox( + center_in_feature.unsqueeze(0).expand(batch_size, -1, -1), pred_corners + ) + decode_bbox_preds_pixel.append(decode_bbox_pred * stride) + + cls_pred_flat = cls_score.permute(0, 2, 3, 1).reshape( + batch_size, num_points, self.num_classes + ) + flatten_cls_preds.append(cls_pred_flat) + flatten_bbox_preds.append(bbox_pred_flat) + + all_center_and_strides = torch.cat(center_and_strides, dim=1) + all_decoded_bboxes_pixel = torch.cat(decode_bbox_preds_pixel, dim=1) + all_cls_preds = torch.cat(flatten_cls_preds, dim=1) + all_bbox_preds = torch.cat(flatten_bbox_preds, dim=1) + assert all_cls_preds.shape[1] == all_bbox_preds.shape[1] + assert all_cls_preds.shape[1] == all_center_and_strides.shape[1] + + img_h, img_w = image_size + scale_limit = all_decoded_bboxes_pixel.new_tensor([img_w, img_h, img_w, img_h]) + all_decoded_bboxes_pixel = torch.min( + all_decoded_bboxes_pixel, scale_limit + ).clamp(min=0) + + for img_idx in range(batch_size): + pred_boxes = all_decoded_bboxes_pixel[img_idx] + pred_cls_logits = all_cls_preds[img_idx] + pred_bbox = all_bbox_preds[img_idx] + priors = all_center_and_strides[img_idx] + gt_boxes = gt_boxes_xyxy_list[img_idx].to(device) + gt_labels = gt_labels_list[img_idx].to(device).long() + + if gt_boxes.numel() > 0: + gt_wh = (gt_boxes[:, 2:] - gt_boxes[:, :2]).clamp(min=0) + gt_area = gt_wh[:, 0] * gt_wh[:, 1] + small = gt_area < 32**2 + medium = (gt_area >= 32**2) & (gt_area < 96**2) + large = gt_area >= 96**2 + + total_gt_small = total_gt_small + small.sum() + total_gt_medium = total_gt_medium + medium.sum() + total_gt_large = total_gt_large + large.sum() + + centers = priors[:, :2] + cx = centers[:, 0].unsqueeze(1) + cy = centers[:, 1].unsqueeze(1) + in_gt = (cx >= gt_boxes[:, 0]) & (cx <= gt_boxes[:, 2]) + in_gt = in_gt & (cy >= gt_boxes[:, 1]) & (cy <= gt_boxes[:, 3]) + has_center = in_gt.any(dim=0) + total_gt_center_small = total_gt_center_small + has_center[small].sum() + total_gt_center_medium = ( + total_gt_center_medium + has_center[medium].sum() + ) + total_gt_center_large = total_gt_center_large + has_center[large].sum() + + assigned_gt, assigned_labels, assigned_ious = self.o2o_assigner.assign( + pred_boxes_xyxy=pred_boxes, + pred_cls_logits=pred_cls_logits, + gt_boxes_xyxy=gt_boxes, + gt_labels=gt_labels, + prior_centers=priors[:, :2], + ) + + pos_mask = assigned_gt >= 0 + if gt_boxes.numel() > 0 and pos_mask.any(): + matched = torch.zeros( + (gt_boxes.shape[0],), device=device, dtype=torch.bool + ) + matched[assigned_gt[pos_mask]] = True + gt_wh = (gt_boxes[:, 2:] - 
gt_boxes[:, :2]).clamp(min=0) + gt_area = gt_wh[:, 0] * gt_wh[:, 1] + small = gt_area < 32**2 + medium = (gt_area >= 32**2) & (gt_area < 96**2) + large = gt_area >= 96**2 + total_gt_matched_small = total_gt_matched_small + matched[small].sum() + total_gt_matched_medium = ( + total_gt_matched_medium + matched[medium].sum() + ) + total_gt_matched_large = total_gt_matched_large + matched[large].sum() + cls_target = pred_cls_logits.new_zeros(pred_cls_logits.shape) + num_pos = pos_mask.sum() + total_pos = total_pos + num_pos + if pos_mask.any(): + cls_target[pos_mask] = F.one_hot( + assigned_labels[pos_mask], num_classes=self.num_classes + ).to(dtype=pred_cls_logits.dtype) * assigned_ious[pos_mask].unsqueeze( + -1 + ) + total_iou = total_iou + assigned_ious[pos_mask].sum() + total_cls_target = total_cls_target + cls_target[pos_mask].sum() + + cls_loss_raw = self.vfl_loss(pred_cls_logits, cls_target) + cls_norm = cls_target.sum().clamp(min=1.0) + cls_loss = cls_loss_raw / cls_norm + + if pos_mask.any(): + target_boxes = gt_boxes[assigned_gt[pos_mask]] + weight_targets = assigned_ious[pos_mask] + weight_sum = weight_targets.sum().clamp(min=1.0) + giou_loss = self.giou_loss( + pred_boxes[pos_mask], + target_boxes, + weight=weight_targets, + ) + pos_priors = priors[pos_mask] + pos_strides = pos_priors[:, 2:3] + pos_centers = pos_priors[:, :2] + pos_centers_feature = pos_centers / pos_strides + pos_gt_bboxes_feature = target_boxes / pos_strides + pos_bbox_pred = pred_bbox[pos_mask] + pos_gt_distances = bbox2distance( + pos_centers_feature, + pos_gt_bboxes_feature, + reg_max=float(self.reg_max), + ) + dfl_weight = weight_targets.unsqueeze(-1).expand(-1, 4).reshape(-1) + dfl_loss = self.dfl_loss( + pos_bbox_pred.reshape(-1, self.reg_max + 1), + pos_gt_distances.reshape(-1), + weight=dfl_weight, + ) + dfl_loss = dfl_loss / 4.0 + giou_loss = giou_loss / weight_sum + dfl_loss = dfl_loss / weight_sum + box_loss = giou_loss + self.model_args.loss_o2o_dfl_weight * dfl_loss + else: + box_loss = torch.zeros((), device=device) + dfl_loss = torch.zeros((), device=device) + + obj_loss = torch.zeros((), device=device) + + total_obj = total_obj + obj_loss + total_cls = total_cls + cls_loss + total_box = total_box + box_loss + total_dfl = total_dfl + dfl_loss + + total_obj = total_obj / batch_size + total_cls = total_cls / batch_size + total_box = total_box / batch_size + total_loss = total_obj + total_cls + total_box + stats = { + "o2o_num_pos": total_pos, + "o2o_mean_iou": total_iou / total_pos.clamp(min=1), + "o2o_cls_target_sum": total_cls_target, + "o2o_gt_small": total_gt_small, + "o2o_gt_medium": total_gt_medium, + "o2o_gt_large": total_gt_large, + "o2o_gt_center_small": total_gt_center_small, + "o2o_gt_center_medium": total_gt_center_medium, + "o2o_gt_center_large": total_gt_center_large, + "o2o_gt_matched_small": total_gt_matched_small, + "o2o_gt_matched_medium": total_gt_matched_medium, + "o2o_gt_matched_large": total_gt_matched_large, + } + total_dfl = total_dfl / batch_size + return total_loss, total_obj, total_cls, total_box, total_dfl, stats + def get_optimizer(self, total_steps: int) -> tuple[Optimizer, LRScheduler]: """Create optimizer and learning rate scheduler. 
@@ -524,7 +839,7 @@ def get_optimizer(self, total_steps: int) -> tuple[Optimizer, LRScheduler]: weight_decay=self.model_args.weight_decay, ) - warmup_steps = self.model_args.lr_warmup_steps + warmup_steps = min(self.model_args.lr_warmup_steps, total_steps) max_steps = total_steps scheduler = CosineWarmupScheduler( optimizer=optimizer, diff --git a/src/lightly_train/_task_models/picodet_object_detection/transforms.py b/src/lightly_train/_task_models/picodet_object_detection/transforms.py index bd9458e6b..ffd4182f8 100644 --- a/src/lightly_train/_task_models/picodet_object_detection/transforms.py +++ b/src/lightly_train/_task_models/picodet_object_detection/transforms.py @@ -119,8 +119,13 @@ def resolve_auto(self, model_init_args: dict[str, Any]) -> None: super().resolve_auto(model_init_args=model_init_args) if self.image_size == "auto": - # Default to 416x416 for PicoDet - self.image_size = tuple(model_init_args.get("image_size", (416, 416))) + model_name = model_init_args.get("model_name") + default_image_size = ( + (640, 640) if model_name == "picodet/l-640" else (416, 416) + ) + self.image_size = tuple( + model_init_args.get("image_size", default_image_size) + ) height, width = self.image_size for field_name in self.__class__.model_fields: @@ -217,8 +222,13 @@ def resolve_auto(self, model_init_args: dict[str, Any]) -> None: super().resolve_auto(model_init_args=model_init_args) if self.image_size == "auto": - # Default to 416x416 for PicoDet - self.image_size = tuple(model_init_args.get("image_size", (416, 416))) + model_name = model_init_args.get("model_name") + default_image_size = ( + (640, 640) if model_name == "picodet/l-640" else (416, 416) + ) + self.image_size = tuple( + model_init_args.get("image_size", default_image_size) + ) height, width = self.image_size for field_name in self.__class__.model_fields: diff --git a/tests/_task_models/picodet_object_detection/test_task_model.py b/tests/_task_models/picodet_object_detection/test_task_model.py index 415033e7e..7bb9d0d7a 100644 --- a/tests/_task_models/picodet_object_detection/test_task_model.py +++ b/tests/_task_models/picodet_object_detection/test_task_model.py @@ -7,6 +7,7 @@ # from __future__ import annotations +import math from pathlib import Path import pytest @@ -59,12 +60,13 @@ def test_task_model_forward_shapes() -> None: ) x = torch.randn(1, 3, 416, 416) - labels, boxes, scores = model(x) + boxes, obj_logits, cls_logits = model(x) - max_detections = model.postprocessor.max_detections - assert labels.shape == (1, max_detections) - assert boxes.shape == (1, max_detections, 4) - assert scores.shape == (1, max_detections) + strides = model.o2o_head.strides + num_preds = sum(math.ceil(416 / s) ** 2 for s in strides) + assert boxes.shape == (1, num_preds, 4) + assert obj_logits.shape == (1, num_preds) + assert cls_logits.shape == (1, num_preds, 80) @pytest.mark.skipif(not RequirementCache("onnx"), reason="onnx not installed")
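The new `TaskAlignedTop1Assigner` matches each ground truth to at most one prediction by ranking candidates with the task-aligned metric `cls_prob**alpha * iou**beta` and resolving collisions GT-centrically. Below is a minimal, self-contained sketch of that matching rule; the spatial prior mask, batching, and the labels/IoUs returned by the real class are omitted, and the pairwise `cls_prob` and `ious` tensors are assumed to be precomputed.

```python
import torch


def top1_match(cls_prob: torch.Tensor, ious: torch.Tensor,
               alpha: float = 0.5, beta: float = 6.0) -> torch.Tensor:
    # cls_prob: (N, M) probability of each prediction for the class of each GT.
    # ious:     (N, M) IoU between each prediction and each GT.
    # Returns:  (N,) index of the matched GT per prediction, or -1.
    num_preds, num_gts = ious.shape
    metric = cls_prob.clamp(min=1e-5) ** alpha * ious.clamp(min=1e-5) ** beta
    assigned = torch.full((num_preds,), -1, dtype=torch.long)
    taken = torch.zeros(num_gts, dtype=torch.bool)
    work = metric.clone()
    while True:
        remaining = (~taken) & (work.max(dim=0).values > 0)
        if not remaining.any():
            break
        # Every remaining GT proposes its best prediction.
        best_scores, best_preds = work[:, remaining].max(dim=0)
        gt_idx = remaining.nonzero(as_tuple=False).squeeze(1)
        for pred in best_preds.unique():
            mask = best_preds == pred
            winner = gt_idx[mask][best_scores[mask].argmax()]
            assigned[pred] = winner  # collision: highest metric wins this prediction
            taken[winner] = True
            work[pred] = 0.0  # the prediction is consumed
    return assigned


print(top1_match(torch.rand(6, 3), torch.rand(6, 3)))
```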
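`_apply_o2o_peak_filter` sparsifies the dense o2o logits by keeping only local score maxima above a per-level threshold and pushing everything else to a large negative logit. A single-level sketch of the same idea (the model uses per-level thresholds and kernel sizes):

```python
import torch
import torch.nn.functional as F


def peak_filter(cls_logits: torch.Tensor, score_thr: float = 0.05,
                kernel: int = 3, suppress_logit: float = -1e6) -> torch.Tensor:
    # cls_logits: (B, C, H, W) per-class logits for one pyramid level.
    scores = cls_logits.sigmoid().amax(dim=1, keepdim=True)            # (B, 1, H, W)
    pooled = F.max_pool2d(scores, kernel, stride=1, padding=kernel // 2)
    keep = (scores >= score_thr) & (scores == pooled)                  # local maxima
    # Non-peaks get a large negative logit so their sigmoid scores are ~0.
    return torch.where(keep, cls_logits, torch.full_like(cls_logits, suppress_logit))


filtered = peak_filter(torch.randn(2, 80, 52, 52))
```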
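`_decode_o2o_predictions` turns the DFL distributions into boxes by taking the softmax expectation over the `reg_max + 1` bins per side and offsetting the grid centers, reusing the existing `Integral`/`distance2bbox` path of `PicoHead`. A rough standalone version, assuming the usual left/top/right/bottom ordering of the four sides:

```python
import torch
import torch.nn.functional as F


def dfl_decode(points: torch.Tensor, bbox_pred: torch.Tensor,
               reg_max: int = 7, stride: int = 8) -> torch.Tensor:
    # points:    (N, 2) grid centers in pixel coordinates.
    # bbox_pred: (N, 4 * (reg_max + 1)) raw per-side distance distributions.
    project = torch.arange(reg_max + 1, dtype=bbox_pred.dtype)
    dist = F.softmax(bbox_pred.reshape(-1, 4, reg_max + 1), dim=-1) @ project
    centers = points / stride                        # decode in feature-map units
    x1y1 = centers - dist[:, :2]                     # left, top
    x2y2 = centers + dist[:, 2:]                     # right, bottom
    return torch.cat([x1y1, x2y2], dim=-1) * stride  # back to pixels, xyxy


boxes = dfl_decode(torch.tensor([[4.0, 4.0], [12.0, 4.0]]), torch.randn(2, 32))
```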