
Fix beam search bug, add nucleus sampling support.

Released by @kdexd on 15 Jul, 21:44

Bug Fix: Beam Search

The beam search implementation adapted from AllenNLP was better suited to recurrent models (LSTM/GRU), which carry a fixed-size hidden state between steps, than to autoregressive transformers, which must re-attend over the full decoded prefix at every step.
This version removes the "backpointer" trick from the AllenNLP implementation and improves captioning results for all VirTex models. In the comparison below, "Old" metrics are v1.1 (ArXiv v2) and "New" metrics are v1.2 (ArXiv v3).

[Image: captioning metrics for all VirTex models, "Old" (v1.1) vs. "New" (v1.2)]

This bug does not affect pre-training or other downstream task results. Thanks to Nicolas Carion (@alcinos) and Aishwarya Kamath (@ashkamath) for spotting this issue and helping me to fix it!
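For intuition, here is a minimal sketch of the transformer-friendly approach (in PyTorch; the function `beam_search_step` is a hypothetical illustration, not the actual VirTex implementation). Each beam carries its full token prefix, so surviving prefixes are simply re-gathered after every expansion:

```python
import torch

def beam_search_step(step_logits, sequences, scores, beam_size):
    """One expansion step of beam search that carries full token
    prefixes per beam, as suits a transformer decoder.

    step_logits: (beam_size, vocab_size) next-token log-probs, obtained
                 by re-running the transformer on each full prefix.
    sequences:   (beam_size, cur_len) token ids decoded so far.
    scores:      (beam_size,) cumulative log-probs of each beam.
    """
    vocab_size = step_logits.size(-1)
    # Every (beam, token) pair competes jointly on cumulative score.
    total = scores.unsqueeze(1) + step_logits        # (beam, vocab)
    top_scores, flat_idx = total.view(-1).topk(beam_size)
    beam_idx = flat_idx // vocab_size                # source beam of each winner
    token_idx = flat_idx % vocab_size                # token that extends it
    # Re-gather the entire surviving prefix for each beam; the transformer
    # attends over the whole sequence, so no backpointers are required.
    new_sequences = torch.cat([sequences[beam_idx],
                               token_idx.unsqueeze(1)], dim=1)
    return new_sequences, top_scores
```

For a recurrent model it suffices to keep only the last token and a hidden state per beam and to reconstruct sequences from backpointers afterwards; a transformer needs the full prefix at every step, which is why the backpointer variant was a poor fit here.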

Feature: Nucleus Sampling

This codebase now supports decoding through nucleus sampling, introduced in The Curious Case of Neural Text Degeneration. Try running the captioning evaluation script with --config-override MODEL.DECODER.NAME nucleus_sampling MODEL.DECODER.NUCLEUS_SIZE 0.9! For consistent behavior with prior versions, the default decoding method remains beam search with 5 beams.

Note: Nucleus sampling tends to score worse specifically on COCO Captions metrics, but produces more interesting-sounding language with larger transformers trained on much more data than COCO Captions.
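For reference, a minimal sketch of the top-p filtering step for a single decoding position (the function name `nucleus_sample` is a hypothetical illustration, not the VirTex decoder itself):

```python
import torch
import torch.nn.functional as F

def nucleus_sample(logits, nucleus_size=0.9):
    """Sample the next token from the smallest set of tokens whose
    cumulative probability exceeds `nucleus_size` (top-p sampling).
    `logits` has shape (vocab_size,).
    """
    probs = F.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = probs.sort(descending=True)
    cumulative = sorted_probs.cumsum(dim=-1)
    # Drop tokens outside the nucleus, shifting the mask by one so the
    # token that first crosses the threshold is kept (the nucleus is
    # never empty).
    outside = cumulative > nucleus_size
    outside[1:] = outside[:-1].clone()
    outside[0] = False
    sorted_probs[outside] = 0.0
    sorted_probs /= sorted_probs.sum()   # renormalize within the nucleus
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx[choice]
```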

New config arguments to support this:

MODEL:
  DECODER:
    # What algorithm to use for decoding. Supported values: {"beam_search",
    # "nucleus_sampling"}.
    NAME: "beam_search"

    # Number of beams to decode (1 = greedy decoding). Ignored when decoding
    # through nucleus sampling.
    BEAM_SIZE: 5

    # Size of nucleus for sampling predictions. Ignored when decoding through
    # beam search.
    NUCLEUS_SIZE: 0.9

    # Maximum length of decoded caption. Decoding may end earlier when [EOS]
    # token is sampled.
    MAX_DECODING_STEPS: 50  # Same as DATA.MAX_CAPTION_LENGTH
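To make the semantics of these keys concrete, here is a small self-contained sketch of how they might be validated; `DecoderConfig` and `validate` are hypothetical illustration names, not VirTex internals:

```python
from dataclasses import dataclass

@dataclass
class DecoderConfig:
    # Mirrors MODEL.DECODER.* from the config above (illustrative only).
    NAME: str = "beam_search"
    BEAM_SIZE: int = 5
    NUCLEUS_SIZE: float = 0.9
    MAX_DECODING_STEPS: int = 50

def validate(cfg: DecoderConfig) -> None:
    supported = {"beam_search", "nucleus_sampling"}
    if cfg.NAME not in supported:
        raise ValueError(f"MODEL.DECODER.NAME must be one of {supported}.")
    if cfg.NAME == "nucleus_sampling" and not 0.0 < cfg.NUCLEUS_SIZE <= 1.0:
        raise ValueError("MODEL.DECODER.NUCLEUS_SIZE must lie in (0, 1].")

# e.g. the override shown in the command above:
validate(DecoderConfig(NAME="nucleus_sampling", NUCLEUS_SIZE=0.9))
```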