Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
264fb0e
Fix bug when loading images with no bounding boxes
IgorSusmelj Jan 20, 2026
cfd67c6
Use ReLU. Use NMS free head.
IgorSusmelj Jan 20, 2026
6e173fb
Update tests
IgorSusmelj Jan 20, 2026
88f2597
Stabilise for fp16 and use pos only for loss
IgorSusmelj Jan 21, 2026
10ea159
Stabilise for fp16 training
IgorSusmelj Jan 21, 2026
b4dcf65
Aligned with yolo10 loss
IgorSusmelj Jan 21, 2026
80e1736
Add debug statements
IgorSusmelj Jan 22, 2026
6da7a3d
Use p3 instead of p6
IgorSusmelj Jan 22, 2026
94f83a4
Change head to not use spatial information
IgorSusmelj Jan 22, 2026
bc39c62
Use topk to reduce memory pressure during validation
IgorSusmelj Jan 22, 2026
2c25587
Make export more fp16 stable
IgorSusmelj Jan 23, 2026
03bd316
Add L1 loss for stabilising the box coords
IgorSusmelj Jan 23, 2026
2691005
Add spatial priors to help with training speed
IgorSusmelj Jan 23, 2026
7427926
Remove l1 loss
IgorSusmelj Jan 25, 2026
6c768e6
Rearrange onnx export
IgorSusmelj Jan 30, 2026
acb3012
Handle small boxes better
IgorSusmelj Jan 30, 2026
23e19c6
Add debug message for matching stats across box sizes
IgorSusmelj Jan 30, 2026
92a8a83
Add dfl loss to o2o head
IgorSusmelj Jan 30, 2026
08b4e7d
Switch to better matching algo
IgorSusmelj Jan 30, 2026
77098f6
Remove duplicate code
IgorSusmelj Jan 30, 2026
4de145c
Update losses
IgorSusmelj Jan 30, 2026
490b4ac
Updated code and add shape check
IgorSusmelj Jan 30, 2026
b08c561
Remove coord channels and support pretrained backbones
IgorSusmelj Jan 31, 2026
980117f
Add backbone weights to training args
IgorSusmelj Jan 31, 2026
bc60181
Format code and fix wrong pathlike import
IgorSusmelj Jan 31, 2026
b04afeb
Be more flexible for loading backbone weights
IgorSusmelj Jan 31, 2026
8ed1ed2
Add debugging info when loading backbone weights
IgorSusmelj Jan 31, 2026
2850bce
Make sure we pass backbone_weights to model
IgorSusmelj Jan 31, 2026
591f387
Changes for small objects
IgorSusmelj Feb 3, 2026
40330aa
Switch back to non nms for eval and fix normalisation issue in the loss.
IgorSusmelj Feb 3, 2026
a19f591
Suppress low conf boxes and pool them
IgorSusmelj Feb 4, 2026
bb08544
Add per level thresholding and kernel
IgorSusmelj Feb 5, 2026
e242b1d
Switch to 640x640 for picodet l model
IgorSusmelj Feb 7, 2026
9feb55b
Make sure we can export to fp16
IgorSusmelj Feb 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/lightly_train/_commands/train_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,8 @@ def _train_task_from_config(config: TrainTaskConfig) -> None:
)

model_init_args = {} if model_init_args is None else model_init_args
if "model_name" not in model_init_args:
model_init_args["model_name"] = config.model

train_transform_args, val_transform_args = helpers.get_transform_args(
train_model_cls=train_model_cls,
Expand Down
3 changes: 2 additions & 1 deletion src/lightly_train/_data/yolo_object_detection_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,8 @@ def __getitem__(self, index: int) -> ObjectDetectionDatasetItem:
[
int(class_id) in self.class_id_to_internal_class_id
for class_id in class_labels_np
]
],
dtype=bool,
)
bboxes_np = bboxes_np[keep]
class_labels_np = class_labels_np[keep]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.Hardswish(inplace=True)
self.act = nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
x = self.depthwise(x)
Expand Down Expand Up @@ -83,7 +83,7 @@ def __init__(
bias=False,
)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.Hardswish(inplace=True)
self.act = nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
out: Tensor = self.act(self.bn(self.conv(x)))
Expand Down
26 changes: 12 additions & 14 deletions src/lightly_train/_task_models/picodet_object_detection/esnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(
padding: int | None = None,
groups: int = 1,
bias: bool = False,
act: Literal["hardswish", "relu", "none"] = "hardswish",
act: Literal["relu", "none"] = "relu",
) -> None:
super().__init__()
if padding is None:
Expand All @@ -93,10 +93,8 @@ def __init__(
)
self.bn = nn.BatchNorm2d(out_channels)

if act == "hardswish":
self.act: nn.Module = nn.Hardswish(inplace=True)
elif act == "relu":
self.act = nn.ReLU(inplace=True)
if act == "relu":
self.act: nn.Module = nn.ReLU(inplace=True)
else:
self.act = nn.Identity()

Expand All @@ -122,7 +120,7 @@ def __init__(self, channels: int, reduction: int = 4) -> None:
def forward(self, x: Tensor) -> Tensor:
scale = F.adaptive_avg_pool2d(x, 1)
scale = F.relu(self.fc1(scale), inplace=True)
scale = F.hardsigmoid(self.fc2(scale), inplace=True)
scale = torch.sigmoid(self.fc2(scale))
return x * scale


Expand All @@ -149,7 +147,7 @@ def __init__(
super().__init__()

self.conv_pw = ConvBNAct(
in_channels // 2, mid_channels // 2, kernel_size=1, act="hardswish"
in_channels // 2, mid_channels // 2, kernel_size=1, act="relu"
)
self.conv_dw = ConvBNAct(
mid_channels // 2,
Expand All @@ -160,7 +158,7 @@ def __init__(
)
self.se = SEModule(se_channels, reduction=4)
self.conv_linear = ConvBNAct(
mid_channels, out_channels // 2, kernel_size=1, act="hardswish"
mid_channels, out_channels // 2, kernel_size=1, act="relu"
)

def forward(self, x: Tensor) -> Tensor:
Expand Down Expand Up @@ -208,11 +206,11 @@ def __init__(
act="none",
)
self.conv_linear_1 = ConvBNAct(
in_channels, out_channels // 2, kernel_size=1, act="hardswish"
in_channels, out_channels // 2, kernel_size=1, act="relu"
)

self.conv_pw_2 = ConvBNAct(
in_channels, mid_channels // 2, kernel_size=1, act="hardswish"
in_channels, mid_channels // 2, kernel_size=1, act="relu"
)
self.conv_dw_2 = ConvBNAct(
mid_channels // 2,
Expand All @@ -224,18 +222,18 @@ def __init__(
)
self.se = SEModule(se_channels, reduction=4)
self.conv_linear_2 = ConvBNAct(
mid_channels // 2, out_channels // 2, kernel_size=1, act="hardswish"
mid_channels // 2, out_channels // 2, kernel_size=1, act="relu"
)

self.conv_dw_mv1 = ConvBNAct(
out_channels,
out_channels,
kernel_size=3,
groups=out_channels,
act="hardswish",
act="relu",
)
self.conv_pw_mv1 = ConvBNAct(
out_channels, out_channels, kernel_size=1, act="hardswish"
out_channels, out_channels, kernel_size=1, act="relu"
)

def forward(self, x: Tensor) -> Tensor:
Expand Down Expand Up @@ -374,7 +372,7 @@ def __init__(
]

self.conv1 = ConvBNAct(
in_channels, stage_out_channels[0], kernel_size=3, stride=2, act="hardswish"
in_channels, stage_out_channels[0], kernel_size=3, stride=2, act="relu"
)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def __init__(
)
self.pointwise = nn.Conv2d(in_channels, out_channels, 1, bias=False)
self.bn = nn.BatchNorm2d(out_channels)
self.act = nn.Hardswish(inplace=True)
self.act = nn.ReLU(inplace=True)

def forward(self, x: Tensor) -> Tensor:
x = self.depthwise(x)
Expand Down Expand Up @@ -95,7 +95,7 @@ def forward(self, x: Tensor) -> Tensor:
x = F.softmax(x, dim=-1)

# Compute expectation
project: Tensor = self.project # type: ignore[assignment]
project: Tensor = self.project.to(dtype=x.dtype, device=x.device) # type: ignore[assignment]
x = F.linear(x, project.view(1, -1)).squeeze(-1)

# Reshape back to (..., 4) if input was 4*(reg_max+1)
Expand Down Expand Up @@ -242,7 +242,7 @@ def __init__(
bias=False,
),
nn.BatchNorm2d(feat_channels),
nn.Hardswish(inplace=True),
nn.ReLU(inplace=True),
)
)
self.cls_convs.append(cls_convs)
Expand All @@ -266,7 +266,7 @@ def __init__(
bias=False,
),
nn.BatchNorm2d(feat_channels),
nn.Hardswish(inplace=True),
nn.ReLU(inplace=True),
)
)
self.reg_convs.append(reg_convs)
Expand Down Expand Up @@ -426,3 +426,75 @@ def decode_predictions(
all_decoded_bboxes = torch.cat(decoded_bboxes_list, dim=1)

return all_points, all_cls_scores, all_decoded_bboxes


class PicoHeadO2O(nn.Module):
    """One-to-one query head for PicoDet.

    Emits a fixed number of predictions (Q) directly, so no NMS or TopK
    post-processing is required. Each query attends over a single feature
    map via a lightweight softmax-pooling mechanism.

    Args:
        in_channels: Number of input channels of the feature map.
        num_classes: Number of object classes.
        num_queries: Fixed number of predictions to emit.
        hidden_dim: Hidden dimension for the pooled query features.
            Defaults to ``in_channels`` when not given.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        num_queries: int = 100,
        hidden_dim: int | None = None,
    ) -> None:
        super().__init__()
        if hidden_dim is None:
            hidden_dim = in_channels
        self.num_queries = num_queries
        self.num_classes = num_classes

        # One 1x1 conv channel per query: produces a per-query spatial
        # attention logit map over the feature grid.
        self.attn_conv = nn.Conv2d(in_channels, num_queries, kernel_size=1)
        self.query_proj = nn.Linear(in_channels, hidden_dim)
        self.act = nn.ReLU(inplace=True)
        self.obj_head = nn.Linear(hidden_dim, 1)
        self.cls_head = nn.Linear(hidden_dim, num_classes)
        self.box_head = nn.Linear(hidden_dim, 4)

        self._init_weights()

    def _init_weights(self) -> None:
        # Kaiming init for convs (ReLU fan-out), small normal for linears.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(
                    module.weight, mode="fan_out", nonlinearity="relu"
                )
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, feat: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Forward pass.

        Args:
            feat: Feature map of shape (B, C, H, W).

        Returns:
            Tuple of:
                - obj_logits: (B, Q)
                - cls_logits: (B, Q, C)
                - box_preds: (B, Q, 4) in normalized cxcywh format.
        """
        batch, _, _, _ = feat.shape
        # Per-query attention over flattened spatial positions:
        # (B, Q, H, W) -> (B, Q, H*W).
        attn_logits = self.attn_conv(feat).reshape(batch, self.num_queries, -1)
        # Clamp logits before softmax for numerical stability (fp16-safe).
        weights = attn_logits.clamp(min=-10.0, max=10.0).softmax(dim=-1)

        # Weighted pooling: (B, Q, H*W) @ (B, H*W, C) -> (B, Q, C).
        flat_feat = feat.flatten(2).transpose(1, 2)
        queries = self.act(self.query_proj(torch.bmm(weights, flat_feat)))

        obj_logits = self.obj_head(queries).squeeze(-1)
        cls_logits = self.cls_head(queries)
        # Sigmoid keeps box coordinates in [0, 1] (normalized cxcywh).
        box_preds = self.box_head(queries).sigmoid()
        return obj_logits, cls_logits, box_preds
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,12 @@ def deploy(self) -> None:
self.deploy_mode = True

def _generate_grid_points(
self, height: int, width: int, stride: int, device: torch.device
self,
height: int,
width: int,
stride: int,
device: torch.device,
dtype: torch.dtype,
) -> Tensor:
"""Generate grid center points for a feature map.

Expand All @@ -76,8 +81,8 @@ def _generate_grid_points(
Returns:
Grid points of shape (H*W, 2) as [x, y] in pixel coordinates.
"""
y = (torch.arange(height, device=device, dtype=torch.float32) + 0.5) * stride
x = (torch.arange(width, device=device, dtype=torch.float32) + 0.5) * stride
y = (torch.arange(height, device=device, dtype=dtype) + 0.5) * stride
x = (torch.arange(width, device=device, dtype=dtype) + 0.5) * stride
yy, xx = torch.meshgrid(y, x, indexing="ij")
return torch.stack([xx.flatten(), yy.flatten()], dim=-1)

Expand Down Expand Up @@ -128,7 +133,9 @@ def forward(
bbox_pred[0].permute(1, 2, 0).reshape(-1, 4 * (self.reg_max + 1))
)

points = self._generate_grid_points(height, width, stride, device)
points = self._generate_grid_points(
height, width, stride, device, dtype=cls_score.dtype
)

scores = cls_score.sigmoid().reshape(-1, num_classes)
valid_mask = scores > score_thr
Expand Down
Loading