@@ -485,7 +485,9 @@ class SwinTransformer(BaseModule):
         mlp_ratio (int): Ratio of mlp hidden dim to embedding dim.
             Default: 4.
         depths (tuple[int]): Depths of each Swin Transformer stage.
-            Default: (2, 2, 6, 2).
+            Default: (2, 2, 6, 2).
+            This means that the model has 4 "stages".  <-- George comment: the number of stages is retrieved by running stages = len(depths)
+
         num_heads (tuple[int]): Parallel attention heads of each Swin
             Transformer stage. Default: (3, 6, 12, 24).
         strides (tuple[int]): The patch merging or patch embedding stride of
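Editor's note (not part of the diff): a minimal sketch of how the documented defaults determine the stage count, per the comment above. The values are the documented Swin-T defaults; everything else is illustrative.

    # Illustrative only: the documented Swin-T defaults.
    depths = (2, 2, 6, 2)        # number of Swin blocks in each stage
    num_heads = (3, 6, 12, 24)   # attention heads used by each stage

    num_stages = len(depths)     # -> 4, as the docstring comment notes
    assert num_stages == len(num_heads)

    for i, (blocks, heads) in enumerate(zip(depths, num_heads)):
        print(f'stage {i}: {blocks} blocks, {heads} heads')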
@@ -571,6 +573,10 @@ def __init__(self,
         super(SwinTransformer, self).__init__(init_cfg=init_cfg)

+        # George comment
+        # default value of depths --> (2, 2, 6, 2)
+        # therefore num_layers = 4, which then ends up being
+        # the number of stages
         num_layers = len(depths)
         self.out_indices = out_indices
         self.use_abs_pos_embed = use_abs_pos_embed
@@ -603,6 +609,7 @@ def __init__(self,
         self.stages = ModuleList()
         in_channels = embed_dims
+
         for i in range(num_layers):
             if i < num_layers - 1:
                 downsample = PatchMerging(
@@ -614,6 +621,11 @@ def __init__(self,
             else:
                 downsample = None

+            # George comment:
+            # one stage for every layer
+            # very annoying terminology switch
+            # don't see why it wouldn't just be num_stages
+            # instead of num_layers
             stage = SwinBlockSequence(
                 embed_dims=in_channels,
                 num_heads=num_heads[i],
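Editor's note (not part of the diff): a standalone sketch of the shape of the stage-building loop above, assuming (as in the default Swin configuration) that each PatchMerging step doubles the channel count. It prints the usual 96/192/384/768 progression for Swin-T; it is not the library code itself.

    # Illustrative sketch of the loop's structure, not the library code.
    embed_dims = 96
    depths = (2, 2, 6, 2)
    num_layers = len(depths)              # i.e. the number of stages

    in_channels = embed_dims
    for i in range(num_layers):
        has_downsample = i < num_layers - 1   # the last stage has no PatchMerging
        print(f'stage {i}: in_channels={in_channels}, '
              f'downsample={"yes" if has_downsample else "no"}')
        if has_downsample:
            in_channels *= 2              # assumption: PatchMerging doubles channels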
@@ -817,3 +829,59 @@ def correct_unfold_norm_order(x):
         new_ckpt['backbone.' + new_k] = new_v

     return new_ckpt
+
+
+@MODELS.register_module()
+class SwinTransformerFirst3Stages(SwinTransformer):
+
+    def forward(self, x):
+        x, hw_shape = self.patch_embed(x)
+
+        if self.use_abs_pos_embed:
+            x = x + self.absolute_pos_embed
+        x = self.drop_after_pos(x)
+
+        outs = []
+
+        # Only run the first `frozen_stages` stages; stop before any
+        # later stage would execute.
+        for i, stage in enumerate(self.stages):
+            if i >= self.frozen_stages:
+                break
+
+            x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+
+            if i in self.out_indices:
+                norm_layer = getattr(self, f'norm{i}')
+                out = norm_layer(out)
+                out = out.view(-1, *out_hw_shape,
+                               self.num_features[i]).permute(0, 3, 1,
+                                                             2).contiguous()
+                outs.append(out)
+
+        return outs
+
+
+@MODELS.register_module()
+class SwinTransformerLastStage(SwinTransformer):
+
+    def forward(self, x, hw_shape):
+        # `x` and `hw_shape` are assumed to be the token sequence and
+        # spatial shape produced by the earlier (frozen) stages.
+        stage = self.stages[-1]
+        i = len(self.stages) - 1
+
+        x, hw_shape, out, out_hw_shape = stage(x, hw_shape)
+
+        # Apply the norm layer matching the last stage, then reshape the
+        # tokens back into a feature map.
+        norm_layer = getattr(self, f'norm{i}')
+        out = norm_layer(out)
+        out = out.view(-1, *out_hw_shape,
+                       self.num_features[i]).permute(0, 3, 1,
+                                                     2).contiguous()
+
+        # Return as a list to keep the output format consistent.
+        # You'd then feed this through to the rpn and roi_heads.
+        return [out]
+
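Editor's note (not part of the diff): since both classes are registered with @MODELS.register_module(), a detector config could select them by their type name. The sketch below is hedged and illustrative; field values and the split into two backbone configs are assumptions, not taken from the diff.

    # Illustrative config fragment; values are the usual Swin-T defaults.
    backbone = dict(
        type='SwinTransformerFirst3Stages',
        embed_dims=96,
        depths=(2, 2, 6, 2),
        num_heads=(3, 6, 12, 24),
        out_indices=(0, 1, 2),
        frozen_stages=3,              # run (and output) only the first 3 stages
    )

    # The last stage would be built the same way. As edited above, its forward
    # expects the token sequence and hw_shape coming out of the earlier stages,
    # and its single output map would then go to the rpn and roi_heads.
    last_stage = dict(type='SwinTransformerLastStage', out_indices=(3, ))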