Rename TransformerTextualHead and USE_LOOKAHEAD config.
kdexd committed Apr 4, 2021
1 parent 11c5793 commit ff7fe24
Showing 22 changed files with 72 additions and 71 deletions.
README.md (2 changes: 1 addition & 1 deletion)
@@ -36,7 +36,7 @@ model = torch.hub.load("kdexd/virtex", "resnet50", pretrained=True)

### Note (For returning users before January 2021):

-The pretrained models in our model zoo have changed in [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0).
+The pretrained models in our model zoo have changed from [`v1.0`](https://github.com/kdexd/virtex/releases/tag/v1.0) onwards.
They are slightly better tuned than older models, and reproduce the results in our
CVPR 2021 accepted paper ([arXiv v2](https://arxiv.org/abs/2006.06666v2)).
Some training and evaluation hyperparams are changed since [`v0.9`](https://github.com/kdexd/virtex/releases/tag/v0.9).
configs/_base_bicaptioning_R_50_L1_H1024.yaml (27 changes: 15 additions & 12 deletions)
@@ -3,9 +3,9 @@
 # ResNet-50 + (L = 1, H = 1024) transformer trained for 500K iterations.
 # -----------------------------------------------------------------------------
 RANDOM_SEED: 0
-AMP: True
-CUDNN_BENCHMARK: True
-CUDNN_DETERMINISTIC: False
+AMP: true
+CUDNN_BENCHMARK: true
+CUDNN_DETERMINISTIC: false

 DATA:
   ROOT: "datasets/coco"
@@ -31,33 +31,36 @@ DATA:
     - "normalize"

   USE_PERCENTAGE: 100.0
-  USE_SINGLE_CAPTION: False
+  USE_SINGLE_CAPTION: false

 MODEL:
-  NAME: "bicaptioning"
+  NAME: "virtex"
   VISUAL:
     NAME: "torchvision::resnet50"
     PRETRAINED: false
     FROZEN: false
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H1024_A16_F4096"
+    NAME: "transdec_postnorm::L1_H1024_A16_F4096"
     DROPOUT: 0.1

 OPTIM:
   OPTIMIZER_NAME: "sgd"
   SGD_MOMENTUM: 0.9
   WEIGHT_DECAY: 0.0001
-  NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)"
-  CLIP_GRAD_NORM: 10

-  USE_LOOKAHEAD: True
-  LOOKAHEAD_ALPHA: 0.5
-  LOOKAHEAD_STEPS: 5
+  LOOKAHEAD:
+    USE: true
+    ALPHA: 0.5
+    STEPS: 5

   BATCH_SIZE: 256
   CNN_LR: 0.2
   LR: 0.001
   NUM_ITERATIONS: 500000

   WARMUP_STEPS: 10000
-  LR_DECAY_NAME: cosine
+  LR_DECAY_NAME: "cosine"

+  NO_DECAY: ".*textual.(embedding|transformer).*(norm.*|bias)"
+  CLIP_GRAD_NORM: 10.0

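As a quick sanity check after updating a local copy of this config, the renamed keys can be inspected with plain PyYAML — a minimal sketch only; the training scripts consume configs through the project's `Config` class (virtex/config.py):

```python
import yaml

# Load the updated base config and check the renamed keys shown above.
with open("configs/_base_bicaptioning_R_50_L1_H1024.yaml") as f:
    cfg = yaml.safe_load(f)

assert cfg["MODEL"]["NAME"] == "virtex"
assert cfg["MODEL"]["TEXTUAL"]["NAME"] == "transdec_postnorm::L1_H1024_A16_F4096"
assert cfg["OPTIM"]["LOOKAHEAD"] == {"USE": True, "ALPHA": 0.5, "STEPS": 5}
```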
configs/depth_ablations/bicaptioning_R_50_L2_H1024.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L2_H1024_A16_F4096"
+    NAME: "transdec_postnorm::L2_H1024_A16_F4096"
configs/depth_ablations/bicaptioning_R_50_L3_H1024.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L3_H1024_A16_F4096"
+    NAME: "transdec_postnorm::L3_H1024_A16_F4096"
configs/depth_ablations/bicaptioning_R_50_L4_H1024.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L4_H1024_A16_F4096"
+    NAME: "transdec_postnorm::L4_H1024_A16_F4096"
configs/downstream/imagenet_clf.yaml (9 changes: 5 additions & 4 deletions)
@@ -1,8 +1,8 @@
 RANDOM_SEED: 0
 # Don't need AMP to train a tiny linear layer.
-AMP: False
-CUDNN_BENCHMARK: True
-CUDNN_DETERMINISTIC: False
+AMP: false
+CUDNN_BENCHMARK: true
+CUDNN_DETERMINISTIC: false

 DATA:
   ROOT: "datasets/imagenet"
@@ -24,7 +24,8 @@ OPTIM:
   SGD_MOMENTUM: 0.9
   WEIGHT_DECAY: 0.0
   NO_DECAY: "none"
-  USE_LOOKAHEAD: False
+  LOOKAHEAD:
+    USE: false

   LR: 0.3
   WARMUP_STEPS: 0
configs/downstream/inaturalist_clf.yaml (9 changes: 5 additions & 4 deletions)
@@ -1,7 +1,7 @@
 RANDOM_SEED: 0
-AMP: True
-CUDNN_BENCHMARK: True
-CUDNN_DETERMINISTIC: False
+AMP: true
+CUDNN_BENCHMARK: true
+CUDNN_DETERMINISTIC: false

 DATA:
   ROOT: "datasets/inaturalist"
@@ -23,7 +23,8 @@ OPTIM:
   SGD_MOMENTUM: 0.9
   WEIGHT_DECAY: 0.0001
   NO_DECAY: "none"
-  USE_LOOKAHEAD: False
+  LOOKAHEAD:
+    USE: false

   LR: 0.025
   WARMUP_STEPS: 0
configs/task_ablations/bicaptioning_R_50_L1_H2048.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H2048_A32_F8192"
+    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
configs/task_ablations/captioning_R_50_L1_H2048.yaml (2 changes: 1 addition & 1 deletion)
@@ -3,4 +3,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
 MODEL:
   NAME: "captioning"
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H2048_A32_F8192"
+    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
configs/task_ablations/masked_lm_R_50_L1_H2048.yaml (2 changes: 1 addition & 1 deletion)
@@ -3,4 +3,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"
 MODEL:
   NAME: "masked_lm"
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H2048_A32_F8192"
+    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
configs/width_ablations/bicaptioning_R_50_L1_H2048.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H2048_A32_F8192"
+    NAME: "transdec_postnorm::L1_H2048_A32_F8192"
configs/width_ablations/bicaptioning_R_50_L1_H512.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H512_A8_F2048"
+    NAME: "transdec_postnorm::L1_H512_A8_F2048"
configs/width_ablations/bicaptioning_R_50_L1_H768.yaml (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@ _BASE_: "../_base_bicaptioning_R_50_L1_H1024.yaml"

 MODEL:
   TEXTUAL:
-    NAME: "transformer_postnorm::L1_H768_A12_F3072"
+    NAME: "transdec_postnorm::L1_H768_A12_F3072"
docs/conf.py (6 changes: 3 additions & 3 deletions)
@@ -24,7 +24,7 @@
 author = "Karan Desai"

 # The full version, including alpha/beta/rc tags
-release = "1.0"
+release = "1.1"


 # -- General configuration ---------------------------------------------------
@@ -62,9 +62,9 @@
 # built documents.
 #
 # This version is used underneath the title on the index page.
-version = "1.0"
+version = "1.1"
 # The following is used if you need to also include a more detailed version.
-release = "1.0"
+release = "1.1"

 # The language for content autogenerated by Sphinx. Refer to documentation
 # for a list of supported languages.
docs/virtex/config.rst (2 changes: 1 addition & 1 deletion)
@@ -14,5 +14,5 @@ Config References
 .. literalinclude:: ../../virtex/config.py
    :language: python
    :linenos:
-   :lines: 46-205
+   :lines: 46-206
    :dedent: 8
docs/virtex/models.rst (10 changes: 0 additions & 10 deletions)
@@ -5,9 +5,6 @@ virtex.models

     <hr>

-Pretraining Models
-------------------
-
 .. automodule:: virtex.models.classification

 -------------------------------------------------------------------------------
@@ -17,10 +14,3 @@ Pretraining Models
 -------------------------------------------------------------------------------

 .. automodule:: virtex.models.masked_lm
-
--------------------------------------------------------------------------------
-
-Downstream Models
------------------
-
-.. automodule:: virtex.models.downstream
scripts/eval_detectron2.py (3 changes: 2 additions & 1 deletion)
@@ -1,8 +1,9 @@
 """
 Finetune a pre-trained model on a downstream task, one of those available in
-Detectron2. Optionally use gradient checkpointing for saving memory.
+Detectron2.
+Supported downstream:
   - LVIS Instance Segmentation
   - COCO Instance Segmentation
   - Pascal VOC 2007+12 Object Detection
 Reference: https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py
setup.py (2 changes: 1 addition & 1 deletion)
@@ -41,7 +41,7 @@ def get_model_zoo_configs() -> List[str]:

 setup(
     name="virtex",
-    version="1.0.0",
+    version="1.1.0",
     author="Karan Desai and Justin Johnson",
     description="VirTex: Learning Visual Representations with Textual Annotations",
     package_data={"virtex.model_zoo": get_model_zoo_configs()},
virtex/config.py (13 changes: 7 additions & 6 deletions)
@@ -146,9 +146,9 @@ def __init__(

         _C.MODEL.TEXTUAL = CN()
         # Name of textual head. Set to "none" for MODEL.NAME = "*_classification".
-        # Possible choices: {"transformer_postnorm", "transformer_prenorm"}.
+        # Possible choices: {"transdec_postnorm", "transdec_prenorm"}.
         # Architectural hyper-parameters are specified as shown above.
-        _C.MODEL.TEXTUAL.NAME = "transformer_postnorm::L1_H2048_A32_F8192"
+        _C.MODEL.TEXTUAL.NAME = "transdec_postnorm::L1_H2048_A32_F8192"
         # L = Number of layers in the transformer.
         # H = Hidden size of the transformer (embeddings, attention features).
         # A = Number of attention heads in the transformer.
@@ -174,12 +174,13 @@ def __init__(
         # Regex pattern of params for which there will be no weight decay.
         _C.OPTIM.NO_DECAY = ".*textual.(embedding|transformer).*(norm.*|bias)"
         # Max gradient norm for clipping to avoid exploding gradients.
-        _C.OPTIM.CLIP_GRAD_NORM = 10
+        _C.OPTIM.CLIP_GRAD_NORM = 10.0

         # Wrap our optimizer with Lookahead (https://arxiv.org/abs/1907.08610).
-        _C.OPTIM.USE_LOOKAHEAD = False
-        _C.OPTIM.LOOKAHEAD_ALPHA = 0.5
-        _C.OPTIM.LOOKAHEAD_STEPS = 5
+        _C.OPTIM.LOOKAHEAD = CN()
+        _C.OPTIM.LOOKAHEAD.USE = True
+        _C.OPTIM.LOOKAHEAD.ALPHA = 0.5
+        _C.OPTIM.LOOKAHEAD.STEPS = 5

         # We set different learning rates for CNN (visual backbone) and rest of
         # the model. CNN LR is typically much higher for training from scratch.
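For command-line overrides, the flat `OPTIM.USE_LOOKAHEAD` / `OPTIM.LOOKAHEAD_*` keys are replaced by a nested node. A minimal sketch of the new access pattern, using a stand-in `CfgNode` (the `CN` alias above) rather than the full `virtex.config.Config` class:

```python
from yacs.config import CfgNode as CN

# Stand-in for the new OPTIM node defined above, not the full virtex Config.
_C = CN()
_C.OPTIM = CN()
_C.OPTIM.CLIP_GRAD_NORM = 10.0
_C.OPTIM.LOOKAHEAD = CN()
_C.OPTIM.LOOKAHEAD.USE = True
_C.OPTIM.LOOKAHEAD.ALPHA = 0.5
_C.OPTIM.LOOKAHEAD.STEPS = 5

# Overrides that previously targeted OPTIM.USE_LOOKAHEAD / OPTIM.LOOKAHEAD_STEPS
# now address the nested keys:
_C.merge_from_list(["OPTIM.LOOKAHEAD.USE", False, "OPTIM.LOOKAHEAD.STEPS", 10])
print(_C.OPTIM.LOOKAHEAD.USE, _C.OPTIM.LOOKAHEAD.STEPS)  # False 10
```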
virtex/factories.py (27 changes: 14 additions & 13 deletions)
@@ -321,11 +321,10 @@ class VisualBackboneFactory(Factory):
     Use the method name for model as in torchvision, for example,
     ``torchvision::resnet50``, ``torchvision::wide_resnet50_2`` etc.
-    Possible choices: ``{"blind", "torchvision"}``.
+    Possible choices: ``{"torchvision"}``.
     """

     PRODUCTS: Dict[str, Callable] = {
-        "blind": visual_backbones.BlindVisualBackbone,
         "torchvision": visual_backbones.TorchvisionVisualBackbone,
     }

@@ -358,23 +357,25 @@ class TextualHeadFactory(Factory):
     r"""
     Factory to create :mod:`~virtex.modules.textual_heads`. Architectural
     hyperparameters for transformers can be specified as ``name::*``.
-    For example, ``transformer_postnorm::L1_H1024_A16_F4096`` would create a
+    For example, ``transdec_postnorm::L1_H1024_A16_F4096`` would create a
     transformer textual head with ``L = 1`` layers, ``H = 1024`` hidden size,
     ``A = 16`` attention heads, and ``F = 4096`` size of feedforward layers.
     Textual head should be ``"none"`` for pretraining tasks which do not
     involve language modeling, such as ``"token_classification"``.
-    Possible choices: ``{"transformer_postnorm", "transformer_prenorm", "none"}``.
+    Possible choices: ``{"transdec_postnorm", "transdec_prenorm", "none"}``.
     """

-    # fmt: off
     PRODUCTS: Dict[str, Callable] = {
-        "transformer_prenorm": partial(textual_heads.TransformerTextualHead, norm_type="pre"),
-        "transformer_postnorm": partial(textual_heads.TransformerTextualHead, norm_type="post"),
+        "transdec_prenorm": partial(
+            textual_heads.TransformerDecoderTextualHead, norm_type="pre"
+        ),
+        "transdec_postnorm": partial(
+            textual_heads.TransformerDecoderTextualHead, norm_type="post"
+        ),
         "none": textual_heads.LinearTextualHead,
     }
-    # fmt: on

     @classmethod
     def from_config(cls, config: Config) -> nn.Module:
@@ -394,7 +395,7 @@ def from_config(cls, config: Config) -> nn.Module:
             "vocab_size": _C.DATA.VOCAB_SIZE,
         }

-        if "transformer" in _C.MODEL.TEXTUAL.NAME:
+        if "trans" in _C.MODEL.TEXTUAL.NAME:
             # Get architectural hyper-params as per name by matching regex.
             name, architecture = name.split("::")
             architecture = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture)
@@ -405,15 +406,15 @@ def from_config(cls, config: Config) -> nn.Module:
             feedforward_size = int(architecture.group(4))

             # Mask the future tokens for autoregressive captioning.
-            mask_future_positions = _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}
+            mask_future = _C.MODEL.NAME in {"virtex", "captioning", "bicaptioning"}

             kwargs.update(
                 hidden_size=hidden_size,
                 num_layers=num_layers,
                 attention_heads=attention_heads,
                 feedforward_size=feedforward_size,
                 dropout=_C.MODEL.TEXTUAL.DROPOUT,
-                mask_future_positions=mask_future_positions,
+                mask_future_positions=mask_future,
                 max_caption_length=_C.DATA.MAX_CAPTION_LENGTH,
                 padding_idx=_C.DATA.UNK_INDEX,
             )
@@ -521,9 +522,9 @@ def from_config(
         kwargs = {}

         optimizer = cls.create(_C.OPTIM.OPTIMIZER_NAME, param_groups, **kwargs)
-        if _C.OPTIM.USE_LOOKAHEAD:
+        if _C.OPTIM.LOOKAHEAD.USE:
             optimizer = Lookahead(
-                optimizer, k=_C.OPTIM.LOOKAHEAD_STEPS, alpha=_C.OPTIM.LOOKAHEAD_ALPHA
+                optimizer, k=_C.OPTIM.LOOKAHEAD.STEPS, alpha=_C.OPTIM.LOOKAHEAD.ALPHA
             )
         return optimizer
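The `name::LX_HY_AZ_FW` convention parsed in `TextualHeadFactory.from_config` above can be checked in isolation — a standalone sketch of the same regex logic (the factory additionally builds the module, which is omitted here):

```python
import re

# Parse a textual head name of the form "<base_name>::L<layers>_H<hidden>_A<heads>_F<ffn>".
name = "transdec_postnorm::L1_H1024_A16_F4096"
base_name, architecture = name.split("::")
match = re.match(r"L(\d+)_H(\d+)_A(\d+)_F(\d+)", architecture)

num_layers, hidden_size, attention_heads, feedforward_size = map(int, match.groups())
print(base_name, num_layers, hidden_size, attention_heads, feedforward_size)
# transdec_postnorm 1 1024 16 4096
```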
virtex/modules/textual_heads.py (2 changes: 1 addition & 1 deletion)
@@ -103,7 +103,7 @@ def forward(
         return output_logits


-class TransformerTextualHead(TextualHead):
+class TransformerDecoderTextualHead(TextualHead):
     r"""
     A textual head composed of four main modules: (1) input projection (linear
     layer) for visual features to match size with textual features, (2) word
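For downstream code that imported the class directly, only the name changes; the factory aliases used in configs change in lockstep (see factories.py above):

```python
# Old (pre-rename) import:
# from virtex.modules.textual_heads import TransformerTextualHead
from virtex.modules.textual_heads import TransformerDecoderTextualHead

# Config names change accordingly (handled by TextualHeadFactory):
#   old: MODEL.TEXTUAL.NAME = "transformer_postnorm::L1_H1024_A16_F4096"
#   new: MODEL.TEXTUAL.NAME = "transdec_postnorm::L1_H1024_A16_F4096"
```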
virtex/modules/visual_backbones.py (13 changes: 8 additions & 5 deletions)
@@ -47,8 +47,7 @@ def __init__(
         self.cnn = getattr(torchvision.models, name)(
             pretrained, zero_init_residual=True
         )
-        # Reove global average pooling and fc layer.
-        self.cnn.avgpool = nn.Identity()
+        # Do nothing after the final residual stage.
         self.cnn.fc = nn.Identity()

         # Freeze all weights if specified.
@@ -74,9 +73,13 @@ def forward(self, image: torch.Tensor) -> torch.Tensor:
         example it will be ``(batch_size, 2048, 7, 7)`` for ResNet-50.
         """

-        # shape: (batch_size, channels, height, width)
-        # [ResNet-50: (b, 2048, 7, 7)]
-        return self.cnn(image)
+        for idx, (name, layer) in enumerate(self.cnn.named_children()):
+            out = layer(image) if idx == 0 else layer(out)
+
+            # These are the spatial features we need.
+            if name == "layer4":
+                # shape: (batch_size, channels, height, width)
+                return out

     def detectron2_backbone_state_dict(self) -> Dict[str, Any]:
         r"""
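A quick shape check of the new forward pass, which now stops after the final residual stage instead of running the (identity) pooling and fc layers. This is only a sketch; it assumes the `TorchvisionVisualBackbone` constructor takes the torchvision model name as its first positional argument and a `pretrained` flag, as suggested by the `__init__` shown above:

```python
import torch
from virtex.modules.visual_backbones import TorchvisionVisualBackbone

# Constructor arguments assumed from the __init__ snippet above.
backbone = TorchvisionVisualBackbone("resnet50", pretrained=False)
images = torch.randn(2, 3, 224, 224)

features = backbone(images)
print(features.shape)  # expected: torch.Size([2, 2048, 7, 7]) for ResNet-50
```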
