
Commit 31c7379

small tweaks to try to retrieve each step of inference
1 parent 84835fb commit 31c7379

1 file changed (+69, -1 lines)


mmdet/models/backbones/swin.py

@@ -485,7 +485,9 @@ class SwinTransformer(BaseModule):
         mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
             Default: 4.
         depths (tuple[int]): Depths of each Swin Transformer stage.
-            Default: (2, 2, 6, 2).
+            Default: (2, 2, 6, 2).
+            This means that the model has 4 "stages". <-- George comment.
+            The number of stages is retrieved by running stages = len(depths).
         num_heads (tuple[int]): Parallel attention heads of each Swin
             Transformer stage. Default: (3, 6, 12, 24).
         strides (tuple[int]): The patch merging or patch embedding stride of
@@ -571,6 +573,10 @@ def __init__(self,
 
         super(SwinTransformer, self).__init__(init_cfg=init_cfg)
 
+        # George comment:
+        # default value of depths --> (2, 2, 6, 2),
+        # therefore num_layers = 4, which then ends up being
+        # the number of stages
         num_layers = len(depths)
         self.out_indices = out_indices
         self.use_abs_pos_embed = use_abs_pos_embed
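
A quick standalone illustration of the relationship the two George comments describe (all names taken from the constructor itself):

    depths = (2, 2, 6, 2)     # blocks per Swin Transformer stage
    num_layers = len(depths)  # == 4, i.e. the number of stages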
@@ -603,6 +609,7 @@ def __init__(self,
 
         self.stages = ModuleList()
         in_channels = embed_dims
+
         for i in range(num_layers):
             if i < num_layers - 1:
                 downsample = PatchMerging(
@@ -614,6 +621,11 @@ def __init__(self,
             else:
                 downsample = None
 
+            # George comment:
+            # one stage for every layer.
+            # Very annoying terminology switch;
+            # don't see why it wouldn't just be num_stages
+            # instead of num_layers.
             stage = SwinBlockSequence(
                 embed_dims=in_channels,
                 num_heads=num_heads[i],
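
Reduced to a standalone sketch (plain Python; the plan list and its labels are illustrative only), the loop above creates one stage per entry of depths, and every stage except the last gets a downsample:

    depths = (2, 2, 6, 2)
    num_layers = len(depths)
    # (name, depth, has_downsample) per stage
    plan = [(f'stage{i}', depths[i], i < num_layers - 1)
            for i in range(num_layers)]
    # [('stage0', 2, True), ('stage1', 2, True),
    #  ('stage2', 6, True), ('stage3', 2, False)]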
@@ -817,3 +829,59 @@ def correct_unfold_norm_order(x):
             new_ckpt['backbone.' + new_k] = new_v
 
     return new_ckpt
+
+
+@MODELS.register_module()
+class SwinTransformerFirst3Stages(SwinTransformer):
+
+    def forward(self, x):
+        x, hw_shape = self.patch_embed(x)
+
+        if self.use_abs_pos_embed:
+            x = x + self.absolute_pos_embed
+        x = self.drop_after_pos(x)
+
+        outs = []
+
+        # Only run the first `frozen_stages` stages. The check must
+        # come before the stage call; otherwise stage `frozen_stages`
+        # would be executed and its output thrown away.
+        for i, stage in enumerate(self.stages):
+            if i >= self.frozen_stages:
+                break
+
+            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                out = norm_layer(out)
+                out = out.view(-1, *out_hw_shape,
+                               self.num_features[i]).permute(0, 3, 1,
+                                                             2).contiguous()
+                outs.append(out)
+
+        return outs
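
A minimal usage sketch for the class above (assumes torch; the 224x224 input and the shapes in the final comment are the usual Swin-T values, shown for illustration rather than guaranteed by this commit):

    import torch

    m = SwinTransformerFirst3Stages(frozen_stages=3, out_indices=(0, 1, 2))
    m.eval()
    with torch.no_grad():
        feats = m(torch.randn(1, 3, 224, 224))
    # feats holds one NCHW feature map per collected stage, e.g.
    # (1, 96, 56, 56), (1, 192, 28, 28) and (1, 384, 14, 14) for Swin-T.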
+
+
+@MODELS.register_module()
+class SwinTransformerLastStage(SwinTransformer):
+
+    def forward(self, x, hw_shape):
+        # `x` and `hw_shape` are expected to be the running hidden
+        # state and spatial shape left over after the earlier stages
+        # (what SwinTransformerFirst3Stages iterates on), not a raw
+        # image.
+        stage = self.stages[-1]
+        i = len(self.stages) - 1
+
+        x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+
+        norm_layer = getattr(self, f'norm{i}')
+        out = norm_layer(out)
+        out = out.view(-1, *out_hw_shape,
+                       self.num_features[i]).permute(0, 3, 1,
+                                                     2).contiguous()
+
+        # Return as a list to keep the output format consistent;
+        # this would then be fed through to the rpn and roi_heads.
+        return [out]
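
Since both classes are registered with @MODELS.register_module(), they can also be selected by name from a config. A minimal config sketch (the values are the stock Swin-T defaults, shown here only for illustration):

    model = dict(
        backbone=dict(
            type='SwinTransformerFirst3Stages',
            embed_dims=96,
            depths=(2, 2, 6, 2),
            num_heads=(3, 6, 12, 24),
            out_indices=(0, 1, 2),
            frozen_stages=3))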
