-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
defaults.py
1295 lines (951 loc) · 40 KB
/
defaults.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Configs."""
import math
from fvcore.common.config import CfgNode
from . import custom_config
# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------
# Root config node. Every option group below is attached to it as a child
# CfgNode (e.g. `_C.TRAIN = CfgNode()`).
_C = CfgNode()
# -----------------------------------------------------------------------------
# Contrastive Model (for MoCo, SimCLR, SwAV, BYOL)
# -----------------------------------------------------------------------------
_C.CONTRASTIVE = CfgNode()
# Temperature used for contrastive losses.
_C.CONTRASTIVE.T = 0.07
# Output dimension for the loss.
_C.CONTRASTIVE.DIM = 128
# Number of training samples (for the kNN bank).
_C.CONTRASTIVE.LENGTH = 239975
# Length of the MoCo and memory-bank queues.
_C.CONTRASTIVE.QUEUE_LEN = 65536
# Momentum for momentum-encoder updates.
_C.CONTRASTIVE.MOMENTUM = 0.5
# Whether to anneal the momentum to the value above with a cosine schedule.
_C.CONTRASTIVE.MOMENTUM_ANNEALING = False
# Contrastive method: memory bank, MoCo, SimCLR, BYOL or SwAV.
# NOTE(review): the default is "mem", so the memory-bank option appears to be
# spelled "mem" (not "memorybank"); confirm the accepted strings in the model code.
_C.CONTRASTIVE.TYPE = "mem"
# Whether to interpolate the memory bank in time.
_C.CONTRASTIVE.INTERP_MEMORY = False
# Memory layout: "1d" or "2d" (+temporal).
_C.CONTRASTIVE.MEM_TYPE = "1d"
# Number of classes for online kNN evaluation.
_C.CONTRASTIVE.NUM_CLASSES_DOWNSTREAM = 400
# Number of layers in the MLP projection.
_C.CONTRASTIVE.NUM_MLP_LAYERS = 1
# Hidden dimension of the projection and predictor MLPs.
_C.CONTRASTIVE.MLP_DIM = 2048
# Whether to use BN in the projection/prediction MLP.
_C.CONTRASTIVE.BN_MLP = False
# Whether to use synchronized BN in the projection/prediction MLP.
_C.CONTRASTIVE.BN_SYNC_MLP = False
# Shuffle BN only locally (True) vs. across machines (False).
_C.CONTRASTIVE.LOCAL_SHUFFLE_BN = True
# Whether to fill multiple clips (or just the first) into the queue.
_C.CONTRASTIVE.MOCO_MULTI_VIEW_QUEUE = False
# When sampling multiple clips per video, they must be at least this many
# frames apart. The default of -inf imposes no lower bound.
_C.CONTRASTIVE.DELTA_CLIPS_MIN = -math.inf
# When sampling multiple clips per video, they can be at most this many frames
# apart. The default of +inf imposes no upper bound.
_C.CONTRASTIVE.DELTA_CLIPS_MAX = math.inf
# If non-empty, use predictors with the depths specified.
_C.CONTRASTIVE.PREDICTOR_DEPTHS = []
# Whether to process multiple clips sequentially (lower memory usage) or to
# batch them.
_C.CONTRASTIVE.SEQUENTIAL = False
# Whether to compute the SimCLR loss across machines (or only locally).
_C.CONTRASTIVE.SIMCLR_DIST_ON = True
# Length of the queue used in SwAV. NOTE: the key is spelled "QEUE"; it is
# kept as-is because renaming a config key would break existing config files.
_C.CONTRASTIVE.SWAV_QEUE_LEN = 0
# Whether to run online kNN evaluation during training.
_C.CONTRASTIVE.KNN_ON = True
# ---------------------------------------------------------------------------- #
# Batch norm options
# ---------------------------------------------------------------------------- #
_C.BN = CfgNode()
# Whether to compute precise BN statistics.
_C.BN.USE_PRECISE_STATS = False
# Number of batches used to compute precise BN statistics.
_C.BN.NUM_BATCHES_PRECISE = 200
# Weight decay value applied to BN parameters.
_C.BN.WEIGHT_DECAY = 0.0
# Norm type; options include `batchnorm`, `sub_batchnorm`, `sync_batchnorm`.
_C.BN.NORM_TYPE = "batchnorm"
# Parameter for SubBatchNorm: the batch dimension is split into NUM_SPLITS
# splits, and BN is run on each split independently.
_C.BN.NUM_SPLITS = 1
# Parameter for NaiveSyncBatchNorm: stats are synchronized across
# `NUM_SYNC_DEVICES` devices. `NUM_SYNC_DEVICES` cannot be larger than the
# number of devices per machine; if global sync is desired, set `GLOBAL_SYNC`.
# By default this ONLY applies to NaiveSyncBatchNorm3d; consider also setting
# CONTRASTIVE.BN_SYNC_MLP if appropriate.
_C.BN.NUM_SYNC_DEVICES = 1
# Parameter for NaiveSyncBatchNorm. Setting `GLOBAL_SYNC` to True synchronizes
# stats across all devices on all machines; in this case `NUM_SYNC_DEVICES`
# must be set to None.
# By default this ONLY applies to NaiveSyncBatchNorm3d; consider also setting
# CONTRASTIVE.BN_SYNC_MLP if appropriate.
_C.BN.GLOBAL_SYNC = False
# ---------------------------------------------------------------------------- #
# Training options.
# ---------------------------------------------------------------------------- #
_C.TRAIN = CfgNode()
# If True, train the model; otherwise skip training.
_C.TRAIN.ENABLE = True
# Kill training if the loss explodes over this ratio relative to the previous
# 5 measurements. Only enforced if > 0.0 (disabled by default).
_C.TRAIN.KILL_LOSS_EXPLOSION_FACTOR = 0.0
# Dataset.
_C.TRAIN.DATASET = "kinetics"
# Total mini-batch size.
_C.TRAIN.BATCH_SIZE = 64
# Evaluate the model on test data every EVAL_PERIOD epochs.
_C.TRAIN.EVAL_PERIOD = 10
# Save a model checkpoint every CHECKPOINT_PERIOD epochs.
_C.TRAIN.CHECKPOINT_PERIOD = 10
# Resume training from the latest checkpoint in the output directory.
_C.TRAIN.AUTO_RESUME = True
# Path to the checkpoint used to load the initial weights.
_C.TRAIN.CHECKPOINT_FILE_PATH = ""
# Checkpoint types include `caffe2` or `pytorch`.
_C.TRAIN.CHECKPOINT_TYPE = "pytorch"
# If True, perform inflation when loading the checkpoint.
_C.TRAIN.CHECKPOINT_INFLATE = False
# If True, reset epochs when loading the checkpoint.
_C.TRAIN.CHECKPOINT_EPOCH_RESET = False
# If set, clear all layer names according to the pattern(s) provided.
_C.TRAIN.CHECKPOINT_CLEAR_NAME_PATTERN = () # e.g. ("backbone.",)
# If True, use FP16 for activations (mixed-precision training).
_C.TRAIN.MIXED_PRECISION = False
# If True, inflate some parameters from an ImageNet model.
_C.TRAIN.CHECKPOINT_IN_INIT = False
# ---------------------------------------------------------------------------- #
# Augmentation options.
# ---------------------------------------------------------------------------- #
_C.AUG = CfgNode()
# Whether to enable RandAug.
_C.AUG.ENABLE = False
# Number of repeated augmentations to use during training.
# If this is greater than 1, the actual batch size is
# TRAIN.BATCH_SIZE * AUG.NUM_SAMPLE.
_C.AUG.NUM_SAMPLE = 1
# Color jitter strength. Not used when RandAug is enabled.
_C.AUG.COLOR_JITTER = 0.4
# RandAug policy specification string.
_C.AUG.AA_TYPE = "rand-m9-mstd0.5-inc1"
# Interpolation method.
_C.AUG.INTERPOLATION = "bicubic"
# Probability of random erasing.
_C.AUG.RE_PROB = 0.25
# Random erasing mode.
_C.AUG.RE_MODE = "pixel"
# Random erase count.
_C.AUG.RE_COUNT = 1
# If True, do not random-erase the first (clean) augmentation split.
_C.AUG.RE_SPLIT = False
# Whether to generate the input mask during image processing (in the loader).
_C.AUG.GEN_MASK_LOADER = False
# If True, the masking mode is "tube". Default is "cube".
_C.AUG.MASK_TUBE = False
# If True, the masking mode is "frame". Default is "cube".
_C.AUG.MASK_FRAMES = False
# The size of generated masks.
# NOTE(review): presumably (T, H, W) in tokens; confirm against the mask generator.
_C.AUG.MASK_WINDOW_SIZE = [8, 7, 7]
# The ratio of masked tokens out of all tokens. Also applies to MViT supervised
# training.
_C.AUG.MASK_RATIO = 0.0
# The maximum number of patches in a masked block. None means no maximum limit.
# (Used only in image MaskFeat.)
_C.AUG.MAX_MASK_PATCHES_PER_BLOCK = None
# ---------------------------------------------------------------------------- #
# Masked pretraining visualization options.
# ---------------------------------------------------------------------------- #
_C.VIS_MASK = CfgNode()
# Whether to run masked-pretraining visualization.
_C.VIS_MASK.ENABLE = False
# ---------------------------------------------------------------------------- #
# MixUp options.
# ---------------------------------------------------------------------------- #
_C.MIXUP = CfgNode()
# Whether to use mixup.
_C.MIXUP.ENABLE = False
# Mixup alpha.
_C.MIXUP.ALPHA = 0.8
# Cutmix alpha.
_C.MIXUP.CUTMIX_ALPHA = 1.0
# Probability of performing mixup or cutmix when either/both is enabled.
_C.MIXUP.PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix are enabled.
_C.MIXUP.SWITCH_PROB = 0.5
# Label smoothing value.
_C.MIXUP.LABEL_SMOOTH_VALUE = 0.1
# ---------------------------------------------------------------------------- #
# Testing options
# ---------------------------------------------------------------------------- #
_C.TEST = CfgNode()
# If True, test the model; otherwise skip testing.
_C.TEST.ENABLE = True
# Dataset for testing.
_C.TEST.DATASET = "kinetics"
# Total mini-batch size.
_C.TEST.BATCH_SIZE = 8
# Path to the checkpoint used to load the initial weights.
_C.TEST.CHECKPOINT_FILE_PATH = ""
# Number of clips to sample uniformly from a video when aggregating the
# prediction results.
_C.TEST.NUM_ENSEMBLE_VIEWS = 10
# Number of crops to sample spatially from a frame when aggregating the
# prediction results.
_C.TEST.NUM_SPATIAL_CROPS = 3
# Checkpoint types include `caffe2` or `pytorch`.
_C.TEST.CHECKPOINT_TYPE = "pytorch"
# Path for saving the prediction results file.
_C.TEST.SAVE_RESULTS_PATH = ""
# List of temporal clip counts; empty by default.
# TODO(review): semantics are not shown in this file; document how the test
# loop consumes this list (and how it relates to NUM_ENSEMBLE_VIEWS).
_C.TEST.NUM_TEMPORAL_CLIPS = []
# -----------------------------------------------------------------------------
# ResNet options
# -----------------------------------------------------------------------------
_C.RESNET = CfgNode()
# Transformation function.
_C.RESNET.TRANS_FUNC = "bottleneck_transform"
# Number of groups: 1 for ResNet, larger than 1 for ResNeXt.
_C.RESNET.NUM_GROUPS = 1
# Width of each group (64 -> ResNet; 4 -> ResNeXt).
_C.RESNET.WIDTH_PER_GROUP = 64
# Apply ReLU in place.
_C.RESNET.INPLACE_RELU = True
# Apply stride to the 1x1 conv.
_C.RESNET.STRIDE_1X1 = False
# If True, initialize the gamma of the final BN of each block to zero.
_C.RESNET.ZERO_INIT_FINAL_BN = False
# If True, initialize the final conv layer of each block to zero.
_C.RESNET.ZERO_INIT_FINAL_CONV = False
# Number of weight layers.
_C.RESNET.DEPTH = 50
# If the current stage has more than NUM_BLOCK_TEMP_KERNEL blocks, use a
# temporal kernel of 1 for the remaining blocks.
_C.RESNET.NUM_BLOCK_TEMP_KERNEL = [[3], [4], [6], [3]]
# Size of the stride in each res stage.
_C.RESNET.SPATIAL_STRIDES = [[1], [2], [2], [2]]
# Size of the dilation in each res stage.
_C.RESNET.SPATIAL_DILATIONS = [[1], [1], [1], [1]]
# ---------------------------------------------------------------------------- #
# X3D options
# See https://arxiv.org/abs/2004.04730 for details about X3D Networks.
# ---------------------------------------------------------------------------- #
_C.X3D = CfgNode()
# Width expansion factor.
_C.X3D.WIDTH_FACTOR = 1.0
# Depth expansion factor.
_C.X3D.DEPTH_FACTOR = 1.0
# Bottleneck expansion factor for the 3x3x3 conv.
_C.X3D.BOTTLENECK_FACTOR = 1.0
# Dimensions of the last linear layer before classification.
_C.X3D.DIM_C5 = 2048
# Dimensions of the first 3x3 conv layer.
_C.X3D.DIM_C1 = 12
# Whether to scale the width of Res2; default is False.
_C.X3D.SCALE_RES2 = False
# Whether to use a BatchNorm (BN) layer before the classifier; default is False.
_C.X3D.BN_LIN5 = False
# Whether to use channelwise (=depthwise) convolution in the center (3x3x3)
# convolution of the residual blocks.
_C.X3D.CHANNELWISE_3x3x3 = True
# -----------------------------------------------------------------------------
# Nonlocal options
# -----------------------------------------------------------------------------
_C.NONLOCAL = CfgNode()
# Index of each stage and block at which to add nonlocal layers.
_C.NONLOCAL.LOCATION = [[[]], [[]], [[]], [[]]]
# Number of groups for nonlocal in each stage.
_C.NONLOCAL.GROUP = [[1], [1], [1], [1]]
# Instantiation to use for the non-local layer.
_C.NONLOCAL.INSTANTIATION = "dot_product"
# Size of the pooling layers used in Non-Local, per res stage.
_C.NONLOCAL.POOL = [
# Res2
[[1, 2, 2], [1, 2, 2]],
# Res3
[[1, 2, 2], [1, 2, 2]],
# Res4
[[1, 2, 2], [1, 2, 2]],
# Res5
[[1, 2, 2], [1, 2, 2]],
]
# -----------------------------------------------------------------------------
# Model options
# -----------------------------------------------------------------------------
_C.MODEL = CfgNode()
# Model architecture.
_C.MODEL.ARCH = "slowfast"
# Model name.
_C.MODEL.MODEL_NAME = "SlowFast"
# The number of classes to predict for the model.
_C.MODEL.NUM_CLASSES = 400
# Loss function.
_C.MODEL.LOSS_FUNC = "cross_entropy"
# Model architectures that have a single pathway.
_C.MODEL.SINGLE_PATHWAY_ARCH = [
"2d",
"c2d",
"i3d",
"slow",
"x3d",
"mvit",
"maskmvit",
]
# Model architectures that have multiple pathways.
_C.MODEL.MULTI_PATHWAY_ARCH = ["slowfast"]
# Dropout rate before the final projection in the backbone.
_C.MODEL.DROPOUT_RATE = 0.5
# Drop-connect rate for residual blocks, linearly increased from res2 to res5.
_C.MODEL.DROPCONNECT_RATE = 0.0
# The std used to initialize the fc layer(s).
_C.MODEL.FC_INIT_STD = 0.01
# Activation layer for the output head.
_C.MODEL.HEAD_ACT = "softmax"
# Whether activation checkpointing is enabled to save GPU memory.
_C.MODEL.ACT_CHECKPOINT = False
# If True, detach the final fc layer from the network so that only the final
# fc layer is trained.
_C.MODEL.DETACH_FINAL_FC = False
# If True, freeze batch norm stats during training.
_C.MODEL.FROZEN_BN = False
# If True, AllReduce gradients are compressed to fp16.
_C.MODEL.FP16_ALLREDUCE = False
# -----------------------------------------------------------------------------
# MViT options
# -----------------------------------------------------------------------------
_C.MVIT = CfgNode()
# Pooling mode; options include `conv`, `max`.
_C.MVIT.MODE = "conv"
# If True, perform pooling before the projection in attention.
_C.MVIT.POOL_FIRST = False
# If True, use a cls embedding in the network; otherwise no cls_embed is used
# in the transformer.
_C.MVIT.CLS_EMBED_ON = True
# Kernel size for patchification.
_C.MVIT.PATCH_KERNEL = [3, 7, 7]
# Stride size for patchification.
_C.MVIT.PATCH_STRIDE = [2, 4, 4]
# Padding size for patchification.
_C.MVIT.PATCH_PADDING = [2, 4, 4]
# If True, use a 2d patch; otherwise use a 3d patch.
_C.MVIT.PATCH_2D = False
# Base embedding dimension for the transformer.
_C.MVIT.EMBED_DIM = 96
# Base number of heads for the transformer.
_C.MVIT.NUM_HEADS = 1
# Ratio of the MLP hidden dimension to the embedding dimension.
# NOTE(review): previously described as a "reduction" ratio; with the default
# of 4.0 it is presumably an expansion. Confirm against the block implementation.
_C.MVIT.MLP_RATIO = 4.0
# If True, use a bias term in the attention fc layers.
_C.MVIT.QKV_BIAS = True
# Drop path rate for the transformer.
_C.MVIT.DROPPATH_RATE = 0.1
# The initial value of the layer scale gamma. Set 0.0 to disable layer scale.
_C.MVIT.LAYER_SCALE_INIT_VALUE = 0.0
# Depth of the transformer.
_C.MVIT.DEPTH = 16
# Normalization layer for the transformer. Only layernorm is supported now.
_C.MVIT.NORM = "layernorm"
# Dimension multiplication at layer i. If 2.0 is used, the next block will
# increase the dimension by 2 times. Format: [depth_i: mul_dim_ratio]
_C.MVIT.DIM_MUL = []
# Head number multiplication at layer i. If 2.0 is used, the next block will
# increase the number of heads by 2 times. Format: [depth_i: head_mul_ratio]
_C.MVIT.HEAD_MUL = []
# Stride size for the Pool KV at layer i.
# Format: [[i, stride_t_i, stride_h_i, stride_w_i], ...,]
_C.MVIT.POOL_KV_STRIDE = []
# Initial stride size for KV at layer 1. The stride size is further reduced by
# the ratio of MVIT.DIM_MUL. If not None, it overwrites MVIT.POOL_KV_STRIDE.
_C.MVIT.POOL_KV_STRIDE_ADAPTIVE = None
# Stride size for the Pool Q at layer i.
# Format: [[i, stride_t_i, stride_h_i, stride_w_i], ...,]
_C.MVIT.POOL_Q_STRIDE = []
# If not None, overwrite the KV_KERNEL and Q_KERNEL sizes with POOL_KVQ_KERNEL.
# Otherwise the kernel_size is [s + 1 if s > 1 else s for s in stride_size].
_C.MVIT.POOL_KVQ_KERNEL = None
# If True, perform no weight decay on the positional and cls embeddings.
_C.MVIT.ZERO_DECAY_POS_CLS = True
# If True, use a norm after the stem.
_C.MVIT.NORM_STEM = False
# If True, use separate positional embeddings.
_C.MVIT.SEP_POS_EMBED = False
# Dropout rate for the MViT backbone.
_C.MVIT.DROPOUT_RATE = 0.0
# If True, use absolute positional embedding.
_C.MVIT.USE_ABS_POS = True
# If True, use relative positional embedding for the spatial dimensions.
_C.MVIT.REL_POS_SPATIAL = False
# If True, use relative positional embedding for the temporal dimension.
_C.MVIT.REL_POS_TEMPORAL = False
# If True, initialize the relative positional embeddings with zeros.
_C.MVIT.REL_POS_ZERO_INIT = False
# If True, use the Residual Pooling connection.
_C.MVIT.RESIDUAL_POOLING = False
# If True, apply the dimension multiplication (DIM_MUL) in the qkv linear
# layers of the attention block instead of in the MLP.
_C.MVIT.DIM_MUL_IN_ATT = False
# If True, use separate linear layers for Q, K, V in attention blocks.
_C.MVIT.SEPARATE_QKV = False
# The initialization scale factor for the head parameters.
_C.MVIT.HEAD_INIT_SCALE = 1.0
# Whether to use the mean pooling of all patch tokens as the output.
_C.MVIT.USE_MEAN_POOLING = False
# If True, use a frozen sin-cos positional embedding.
_C.MVIT.USE_FIXED_SINCOS_POS = False
# -----------------------------------------------------------------------------
# Masked pretraining options
# -----------------------------------------------------------------------------
_C.MASK = CfgNode()
# Whether to enable masked-style pretraining.
_C.MASK.ENABLE = False
# Whether to enable MAE (discard encoder tokens).
_C.MASK.MAE_ON = False
# Whether to enable random masking in MAE.
_C.MASK.MAE_RND_MASK = False
# Whether to do random masking per frame in MAE.
_C.MASK.PER_FRAME_MASKING = False
# Only compute the loss on temporally strided patches, or predict the full
# time extent.
_C.MASK.TIME_STRIDE_LOSS = True
# Whether to normalize the predicted-pixel loss.
_C.MASK.NORM_PRED_PIXEL = True
# Whether to scale initialization by the inverse depth of the layer for
# pretraining.
_C.MASK.SCALE_INIT_BY_DEPTH = False
# Base embedding dimension for the decoder transformer.
_C.MASK.DECODER_EMBED_DIM = 512
# If True, use separate positional embeddings in the decoder (the decoder
# counterpart of MVIT.SEP_POS_EMBED).
_C.MASK.DECODER_SEP_POS_EMBED = False
# KV kernel size used in the decoder; empty by default.
# NOTE(review): presumably the KV pooling kernel, as in MVIT.POOL_KVQ_KERNEL; confirm.
_C.MASK.DEC_KV_KERNEL = []
# KV stride used in the decoder; empty by default.
_C.MASK.DEC_KV_STRIDE = []
# The depths of the features that are inputs to the prediction head.
_C.MASK.PRETRAIN_DEPTH = [15]
# The type of masked-pretraining prediction head.
# Can be "separate", "separate_xformer".
_C.MASK.HEAD_TYPE = "separate"
# The depth of MAE's decoder.
_C.MASK.DECODER_DEPTH = 0
# Boolean flag (not a loss weight).
# NOTE(review): presumably enables HOG-feature prediction targets; confirm
# in the model code.
_C.MASK.PRED_HOG = False
# -----------------------------------------------------------------------------
# Reversible MViT options (nested under _C.MVIT, declared after the masked
# pretraining options)
# -----------------------------------------------------------------------------
_C.MVIT.REV = CfgNode()
# Enable the reversible model.
_C.MVIT.REV.ENABLE = False
# Method to fuse the reversible paths.
# See :class: `TwoStreamFusion` for all the options.
_C.MVIT.REV.RESPATH_FUSE = "concat"
# Layers at which to buffer activations
# (at least the Q-pooling layers are needed).
_C.MVIT.REV.BUFFER_LAYERS = []
# 'conv' or 'max' operator for the residual path in Q-pooling.
_C.MVIT.REV.RES_PATH = "conv"
# Method to merge hidden states before Q-pooling layers.
_C.MVIT.REV.PRE_Q_FUSION = "avg"
# -----------------------------------------------------------------------------
# SlowFast options
# -----------------------------------------------------------------------------
_C.SLOWFAST = CfgNode()
# Corresponds to the inverse of the channel reduction ratio, $\beta$, between
# the Slow and Fast pathways.
_C.SLOWFAST.BETA_INV = 8
# Corresponds to the frame rate reduction ratio, $\alpha$, between the Slow and
# Fast pathways.
_C.SLOWFAST.ALPHA = 8
# Ratio of channel dimensions between the Slow and Fast pathways (used by the
# fusion conv).
_C.SLOWFAST.FUSION_CONV_CHANNEL_RATIO = 2
# Kernel dimension used for fusing information from the Fast pathway to the
# Slow pathway.
_C.SLOWFAST.FUSION_KERNEL_SZ = 5
# -----------------------------------------------------------------------------
# Data options
# -----------------------------------------------------------------------------
_C.DATA = CfgNode()
# The path to the data directory.
_C.DATA.PATH_TO_DATA_DIR = ""
# The separator used between path and label.
_C.DATA.PATH_LABEL_SEPARATOR = " "
# Video path prefix, if any.
_C.DATA.PATH_PREFIX = ""
# The number of frames of the input clip.
_C.DATA.NUM_FRAMES = 8
# The video sampling rate of the input clip.
_C.DATA.SAMPLING_RATE = 8
# Eigenvalues for PCA jittering. Note: PCA is RGB based.
_C.DATA.TRAIN_PCA_EIGVAL = [0.225, 0.224, 0.229]
# Eigenvectors for PCA jittering.
_C.DATA.TRAIN_PCA_EIGVEC = [
[-0.5675, 0.7192, 0.4009],
[-0.5808, -0.0045, -0.8140],
[-0.5836, -0.6948, 0.4203],
]
# If an imdb has been dumped to a local file with the following format:
# `{"im_path": im_path, "class": cont_id}`
# then construction of the imdb can be skipped and it is loaded from the local file.
_C.DATA.PATH_TO_PRELOAD_IMDB = ""
# The mean value of the raw video pixels across the R G B channels.
_C.DATA.MEAN = [0.45, 0.45, 0.45]
# List of input frame channel dimensions.
_C.DATA.INPUT_CHANNEL_NUM = [3, 3]
# The std value of the raw video pixels across the R G B channels.
_C.DATA.STD = [0.225, 0.225, 0.225]
# The spatial augmentation jitter scales for training.
_C.DATA.TRAIN_JITTER_SCALES = [256, 320]
# The relative scale range of Inception-style area-based random resizing
# augmentation. If this is provided, DATA.TRAIN_JITTER_SCALES above is ignored.
_C.DATA.TRAIN_JITTER_SCALES_RELATIVE = []
# The relative aspect ratio range of Inception-style area-based random resizing
# augmentation.
_C.DATA.TRAIN_JITTER_ASPECT_RELATIVE = []
# If True, perform stride-length uniform temporal sampling.
_C.DATA.USE_OFFSET_SAMPLING = False
# Whether to apply motion shift for augmentation.
_C.DATA.TRAIN_JITTER_MOTION_SHIFT = False
# The spatial crop size for training.
_C.DATA.TRAIN_CROP_SIZE = 224
# The spatial crop size for testing.
_C.DATA.TEST_CROP_SIZE = 256
# Input videos may have different fps; convert them to the target video fps
# before frame sampling.
_C.DATA.TARGET_FPS = 30
# Jitter TARGET_FPS randomly by +/- this number.
_C.DATA.TRAIN_JITTER_FPS = 0.0
# Decoding backend; options include `pyav` or `torchvision`.
_C.DATA.DECODING_BACKEND = "torchvision"
# Resize to this short side while decoding (set to the native size for best speed).
_C.DATA.DECODING_SHORT_SIZE = 256
# If True, sample uniformly in [1 / max_scale, 1 / min_scale] and take the
# reciprocal to get the scale. If False, take a uniform sample from
# [min_scale, max_scale].
_C.DATA.INV_UNIFORM_SAMPLE = False
# If True, perform a random horizontal flip on the video frames during training.
_C.DATA.RANDOM_FLIP = True
# If True, calculate mAP as the metric (multi-label setting).
_C.DATA.MULTI_LABEL = False
# Method to perform the ensemble; options include "sum" and "max".
_C.DATA.ENSEMBLE_METHOD = "sum"
# If True, revert the default input channel order (RGB <-> BGR).
_C.DATA.REVERSE_INPUT_CHANNEL = False
# How many samples (=clips) to decode from a single video.
_C.DATA.TRAIN_CROP_NUM_TEMPORAL = 1
# How many spatial samples to crop from a single clip.
_C.DATA.TRAIN_CROP_NUM_SPATIAL = 1
# Probability of random grayscale conversion.
_C.DATA.COLOR_RND_GRAYSCALE = 0.0
# The loader can read the .csv file in chunks of this size (0 = no chunking).
_C.DATA.LOADER_CHUNK_SIZE = 0
# If LOADER_CHUNK_SIZE > 0, defines the overall length of the .csv file.
_C.DATA.LOADER_CHUNK_OVERALL_SIZE = 0
# For chunked reading, the dataloader can skip rows in a (large)
# training csv file.
_C.DATA.SKIP_ROWS = 0
# Augmentation probability of converting the raw decoded video to a
# grayscale temporal difference.
_C.DATA.TIME_DIFF_PROB = 0.0
# Apply SSL-based SimCLR / MoCo v1/v2 color augmentations,
# with the params below.
_C.DATA.SSL_COLOR_JITTER = False
# Color jitter percentage for brightness, contrast, saturation.
_C.DATA.SSL_COLOR_BRI_CON_SAT = [0.4, 0.4, 0.4]
# Color jitter percentage for hue.
_C.DATA.SSL_COLOR_HUE = 0.1
# SimCLR / MoCo v2 augmentations on/off.
_C.DATA.SSL_MOCOV2_AUG = False
# SimCLR / MoCo v2 blur augmentation minimum gaussian sigma.
_C.DATA.SSL_BLUR_SIGMA_MIN = [0.0, 0.1]
# SimCLR / MoCo v2 blur augmentation maximum gaussian sigma.
_C.DATA.SSL_BLUR_SIGMA_MAX = [0.0, 2.0]
# If True, combine the train/val splits as training data for IN-21k/22k.
_C.DATA.IN22K_TRAINVAL = False
# If non-empty, use IN-1k as the val split when training on IN-21k/22k.
# NOTE: the key uses a lowercase "k" (IN22k); kept as-is for config compatibility.
_C.DATA.IN22k_VAL_IN1K = ""
# Large-resolution models may use different crop ratios.
_C.DATA.IN_VAL_CROP_RATIO = 0.875 # 224/256 = 0.875
# If True, do not load real videos in kinetics.py (dummy loading).
_C.DATA.DUMMY_LOAD = False
# ---------------------------------------------------------------------------- #
# Optimizer options
# ---------------------------------------------------------------------------- #
_C.SOLVER = CfgNode()
# Base learning rate.
_C.SOLVER.BASE_LR = 0.1
# Learning rate policy (see utils/lr_policy.py for options and examples).
_C.SOLVER.LR_POLICY = "cosine"
# Final learning rate for the 'cosine' policy.
_C.SOLVER.COSINE_END_LR = 0.0
# Exponential decay factor.
_C.SOLVER.GAMMA = 0.1
# Step size for the 'exp' and 'cos' policies (in epochs).
_C.SOLVER.STEP_SIZE = 1
# Steps for the 'steps_' policies (in epochs).
_C.SOLVER.STEPS = []
# Learning rates for the 'steps_' policies.
_C.SOLVER.LRS = []
# Maximal number of epochs.
_C.SOLVER.MAX_EPOCH = 300
# Momentum.
_C.SOLVER.MOMENTUM = 0.9
# Momentum dampening.
_C.SOLVER.DAMPENING = 0.0
# Nesterov momentum.
_C.SOLVER.NESTEROV = True
# L2 regularization (weight decay).
_C.SOLVER.WEIGHT_DECAY = 1e-4
# Start the warm up from SOLVER.BASE_LR * SOLVER.WARMUP_FACTOR.
_C.SOLVER.WARMUP_FACTOR = 0.1
# Gradually warm up SOLVER.BASE_LR over this number of epochs.
_C.SOLVER.WARMUP_EPOCHS = 0.0
# The starting learning rate of the warm up.
_C.SOLVER.WARMUP_START_LR = 0.01
# Optimization method.
_C.SOLVER.OPTIMIZING_METHOD = "sgd"
# If True, the base learning rate is linearly scaled with NUM_SHARDS.
_C.SOLVER.BASE_LR_SCALE_NUM_SHARDS = False
# If True, start from the peak cosine learning rate after warm up.
_C.SOLVER.COSINE_AFTER_WARMUP = False
# If True, perform no weight decay on one-dimensional parameters (bias terms, etc.).
_C.SOLVER.ZERO_WD_1D_PARAM = False
# Clip gradients at this value before the optimizer update (None = disabled).
_C.SOLVER.CLIP_GRAD_VAL = None
# Clip gradients at this L2 norm before the optimizer update (None = disabled).
_C.SOLVER.CLIP_GRAD_L2NORM = None
# Whether to use the LARS optimizer.
_C.SOLVER.LARS_ON = False
# The layer-wise decay of the learning rate. Set to 1.0 to disable.
_C.SOLVER.LAYER_DECAY = 1.0
# Adam's betas.
_C.SOLVER.BETAS = (0.9, 0.999)
# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# The name of the current task; e.g. "ssl"/"sl" for (self-)supervised learning.
_C.TASK = ""
# Number of GPUs to use (applies to both training and testing).
_C.NUM_GPUS = 1
# Number of machines to use for the job.
_C.NUM_SHARDS = 1
# The index of the current machine.
_C.SHARD_ID = 0
# Output base directory.
_C.OUTPUT_DIR = "."
# Random seed. Note that non-determinism may still be present due to
# non-deterministic operator implementations in GPU operator libraries.
_C.RNG_SEED = 1
# Log period in iterations.
_C.LOG_PERIOD = 10
# If True, log the model info.
_C.LOG_MODEL_INFO = True
# Distributed backend.
_C.DIST_BACKEND = "nccl"
# ---------------------------------------------------------------------------- #
# Benchmark options
# ---------------------------------------------------------------------------- #
_C.BENCHMARK = CfgNode()
# Number of epochs for the data loading benchmark.
_C.BENCHMARK.NUM_EPOCHS = 5
# Log period in iterations for the data loading benchmark.
_C.BENCHMARK.LOG_PERIOD = 100
# If True, shuffle the dataloader each epoch during the benchmark.
_C.BENCHMARK.SHUFFLE = True
# ---------------------------------------------------------------------------- #
# Common train/test data loader options
# ---------------------------------------------------------------------------- #
_C.DATA_LOADER = CfgNode()
# Number of data loader workers per training process.
_C.DATA_LOADER.NUM_WORKERS = 8
# Load data into pinned host memory.
_C.DATA_LOADER.PIN_MEMORY = True
# Enable multi-threaded decoding.
_C.DATA_LOADER.ENABLE_MULTI_THREAD_DECODE = False
# ---------------------------------------------------------------------------- #
# Detection options.
# ---------------------------------------------------------------------------- #
_C.DETECTION = CfgNode()
# Whether to enable video detection.
_C.DETECTION.ENABLE = False
# Aligned version of RoI. More details can be found at slowfast/models/head_helper.py
_C.DETECTION.ALIGNED = True
# Spatial scale factor.
_C.DETECTION.SPATIAL_SCALE_FACTOR = 16
# RoI transformation resolution.
_C.DETECTION.ROI_XFORM_RESOLUTION = 7
# -----------------------------------------------------------------------------
# AVA Dataset options
# -----------------------------------------------------------------------------
_C.AVA = CfgNode()
# Directory path of frames.
# NOTE(review): the defaults below are absolute paths on a specific internal
# cluster; they presumably need to be overridden for any other environment.
_C.AVA.FRAME_DIR = "/mnt/fair-flash3-east/ava_trainval_frames.img/"
# Directory path for the frame-list files.
_C.AVA.FRAME_LIST_DIR = (
"/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
)
# Directory path for annotation files. By default this is the same directory
# as FRAME_LIST_DIR.
_C.AVA.ANNOTATION_DIR = (
"/mnt/vol/gfsai-flash3-east/ai-group/users/haoqifan/ava/frame_list/"
)
# Filenames of the training sample list files.
_C.AVA.TRAIN_LISTS = ["train.csv"]
# Filenames of the test sample list files.
_C.AVA.TEST_LISTS = ["val.csv"]
# Filenames of the box list files for training. Note that files containing
# predicted boxes are assumed to have the suffix "predicted_boxes" in their
# filename.
_C.AVA.TRAIN_GT_BOX_LISTS = ["ava_train_v2.2.csv"]
_C.AVA.TRAIN_PREDICT_BOX_LISTS = []
# Filenames of the box list files for test.
_C.AVA.TEST_PREDICT_BOX_LISTS = ["ava_val_predicted_boxes.csv"]
# Score threshold for which predicted boxes to use.
_C.AVA.DETECTION_SCORE_THRESH = 0.9
# If True, use BGR as the format of the input frames.
_C.AVA.BGR = False