Skip to content

Truncation of sequences that are beyond the model's maximum length #359

@MootezSaaD

Description

@MootezSaaD

Hi,
First, I would like to thank you for this library :-) I'm really enjoying it.

I tried to tokenize a sequence of around 4K tokens and then fed it to a RoBERTa-based model (CodeBERT). This led to the following error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
[<ipython-input-12-a373b5333f39>](https://localhost:8080/#) in <cell line: 1>()
      2    ids = input_sentence.padded_tensor(padding_id=0, pad_left=True)
      3    mask = input.attention_mask(pad_left=True)
----> 4    model_output = encoder(piece_ids=ids, attention_mask=mask)

10 frames
[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, attention_mask, positions, type_ids)
    122         type_ids: Optional[Tensor] = None,
    123     ) -> ModelOutput:
--> 124         embeddings = self.embeddings(piece_ids, positions=positions, type_ids=type_ids)
    125         layer_output = embeddings
    126 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/curated_transformers/models/roberta/embeddings.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
     96         if positions is None:
     97             positions = self._get_positions(piece_ids)
---> 98         return super().forward(
     99             piece_ids,
    100             positions=positions,

[/usr/local/lib/python3.10/dist-packages/curated_transformers/layers/transformer.py](https://localhost:8080/#) in forward(self, piece_ids, positions, type_ids)
    180             if positions is None:
    181                 positions = self._get_positions(piece_ids)
--> 182             position_embeddings = self.position_embeddings(positions)
    183             embeddings += position_embeddings
    184 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _wrapped_call_impl(self, *args, **kwargs)
   1516             return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517         else:
-> 1518             return self._call_impl(*args, **kwargs)
   1519 
   1520     def _call_impl(self, *args, **kwargs):

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *args, **kwargs)
   1525                 or _global_backward_pre_hooks or _global_backward_hooks
   1526                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527             return forward_call(*args, **kwargs)
   1528 
   1529         try:

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py](https://localhost:8080/#) in forward(self, input)
    160 
    161     def forward(self, input: Tensor) -> Tensor:
--> 162         return F.embedding(
    163             input, self.weight, self.padding_idx, self.max_norm,
    164             self.norm_type, self.scale_grad_by_freq, self.sparse)

[/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py](https://localhost:8080/#) in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
   2231         # remove once script supports set_grad_enabled
   2232         _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 2233     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
   2234 
   2235 

IndexError: index out of range in self

For reference, here is the code I was using:

MODEL_TAG = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_hf_hub(name=MODEL_TAG, revision="main")
model = RoBERTaEncoder.from_hf_hub(
    name=MODEL_TAG,
    revision="main",
)
code = [
   'void avcodec_string(char *buf, int buf_size, AVCodecContext *enc, int encode)\n\n{\n\n    const char *codec_type;\n\n    const char *codec_name;\n\n    const char *profile = NULL;\n\n    const AVCodec *p;\n\n    int64_t bitrate;\n\n    int new_line = 0;\n\n    AVRational display_aspect_ratio;\n\n    const char *separator = enc->dump_separator ? (const char *)enc->dump_separator : ", ";\n\n\n\n    if (!buf || buf_size <= 0)\n\n        return;\n\n    codec_type = av_get_media_type_string(enc->codec_type);\n\n    codec_name = avcodec_get_name(enc->codec_id);\n\n    if (enc->profile != FF_PROFILE_UNKNOWN) {\n\n        if (enc->codec)\n\n            p = enc->codec;\n\n        else\n\n            p = encode ? avcodec_find_encoder(enc->codec_id) :\n\n                        avcodec_find_decoder(enc->codec_id);\n\n        if (p)\n\n            profile = av_get_profile_name(p, enc->profile);\n\n    }\n\n\n\n    snprintf(buf, buf_size, "%s: %s", codec_type ? codec_type : "unknown",\n\n             codec_name);\n\n    buf[0] ^= \'a\' ^ \'A\'; /* first letter in uppercase */\n\n\n\n    if (enc->codec && strcmp(enc->codec->name, codec_name))\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", enc->codec->name);\n\n\n\n    if (profile)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf), " (%s)", profile);\n\n    if (   enc->codec_type == AVMEDIA_TYPE_VIDEO\n\n        && av_log_get_level() >= AV_LOG_VERBOSE\n\n        && enc->refs)\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %d reference frame%s",\n\n                 enc->refs, enc->refs > 1 ? 
"s" : "");\n\n\n\n    if (enc->codec_tag) {\n\n        char tag_buf[32];\n\n        av_get_codec_tag_string(tag_buf, sizeof(tag_buf), enc->codec_tag);\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 " (%s / 0x%04X)", tag_buf, enc->codec_tag);\n\n    }\n\n\n\n    switch (enc->codec_type) {\n\n    case AVMEDIA_TYPE_VIDEO:\n\n        {\n\n            char detail[256] = "(";\n\n\n\n            av_strlcat(buf, separator, buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 "%s", enc->pix_fmt == AV_PIX_FMT_NONE ? "none" :\n\n                     av_get_pix_fmt_name(enc->pix_fmt));\n\n            if (enc->bits_per_raw_sample && enc->pix_fmt != AV_PIX_FMT_NONE &&\n\n                enc->bits_per_raw_sample < av_pix_fmt_desc_get(enc->pix_fmt)->comp[0].depth)\n\n                av_strlcatf(detail, sizeof(detail), "%d bpc, ", enc->bits_per_raw_sample);\n\n            if (enc->color_range != AVCOL_RANGE_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_color_range_name(enc->color_range));\n\n\n\n            if (enc->colorspace != AVCOL_SPC_UNSPECIFIED ||\n\n                enc->color_primaries != AVCOL_PRI_UNSPECIFIED ||\n\n                enc->color_trc != AVCOL_TRC_UNSPECIFIED) {\n\n                if (enc->colorspace != (int)enc->color_primaries ||\n\n                    enc->colorspace != (int)enc->color_trc) {\n\n                    new_line = 1;\n\n                    av_strlcatf(detail, sizeof(detail), "%s/%s/%s, ",\n\n                                av_color_space_name(enc->colorspace),\n\n                                av_color_primaries_name(enc->color_primaries),\n\n                                av_color_transfer_name(enc->color_trc));\n\n                } else\n\n                    av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                                av_get_colorspace_name(enc->colorspace));\n\n            
}\n\n\n\n            if (av_log_get_level() >= AV_LOG_DEBUG &&\n\n                enc->chroma_sample_location != AVCHROMA_LOC_UNSPECIFIED)\n\n                av_strlcatf(detail, sizeof(detail), "%s, ",\n\n                            av_chroma_location_name(enc->chroma_sample_location));\n\n\n\n            if (strlen(detail) > 1) {\n\n                detail[strlen(detail) - 2] = 0;\n\n                av_strlcatf(buf, buf_size, "%s)", detail);\n\n            }\n\n        }\n\n\n\n        if (enc->width) {\n\n            av_strlcat(buf, new_line ? separator : ", ", buf_size);\n\n\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%dx%d",\n\n                     enc->width, enc->height);\n\n\n\n            if (av_log_get_level() >= AV_LOG_VERBOSE &&\n\n                (enc->width != enc->coded_width ||\n\n                 enc->height != enc->coded_height))\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " (%dx%d)", enc->coded_width, enc->coded_height);\n\n\n\n            if (enc->sample_aspect_ratio.num) {\n\n                av_reduce(&display_aspect_ratio.num, &display_aspect_ratio.den,\n\n                          enc->width * enc->sample_aspect_ratio.num,\n\n                          enc->height * enc->sample_aspect_ratio.den,\n\n                          1024 * 1024);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         " [SAR %d:%d DAR %d:%d]",\n\n                         enc->sample_aspect_ratio.num, enc->sample_aspect_ratio.den,\n\n                         display_aspect_ratio.num, display_aspect_ratio.den);\n\n            }\n\n            if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n                int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, 
enc->time_base.den / g);\n\n            }\n\n        }\n\n        if (encode) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", q=%d-%d", enc->qmin, enc->qmax);\n\n        } else {\n\n            if (enc->properties & FF_CODEC_PROPERTY_CLOSED_CAPTIONS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", Closed Captions");\n\n            if (enc->properties & FF_CODEC_PROPERTY_LOSSLESS)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", lossless");\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_AUDIO:\n\n        av_strlcat(buf, separator, buf_size);\n\n\n\n        if (enc->sample_rate) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     "%d Hz, ", enc->sample_rate);\n\n        }\n\n        av_get_channel_layout_string(buf + strlen(buf), buf_size - strlen(buf), enc->channels, enc->channel_layout);\n\n        if (enc->sample_fmt != AV_SAMPLE_FMT_NONE) {\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", %s", av_get_sample_fmt_name(enc->sample_fmt));\n\n        }\n\n        if (   enc->bits_per_raw_sample > 0\n\n            && enc->bits_per_raw_sample != av_get_bytes_per_sample(enc->sample_fmt) * 8)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     " (%d bit)", enc->bits_per_raw_sample);\n\n        break;\n\n    case AVMEDIA_TYPE_DATA:\n\n        if (av_log_get_level() >= AV_LOG_DEBUG) {\n\n            int g = av_gcd(enc->time_base.num, enc->time_base.den);\n\n            if (g)\n\n                snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                         ", %d/%d",\n\n                         enc->time_base.num / g, enc->time_base.den / g);\n\n        }\n\n        break;\n\n    case AVMEDIA_TYPE_SUBTITLE:\n\n        if (enc->width)\n\n            snprintf(buf + 
strlen(buf), buf_size - strlen(buf),\n\n                     ", %dx%d", enc->width, enc->height);\n\n        break;\n\n    default:\n\n        return;\n\n    }\n\n    if (encode) {\n\n        if (enc->flags & AV_CODEC_FLAG_PASS1)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 1");\n\n        if (enc->flags & AV_CODEC_FLAG_PASS2)\n\n            snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                     ", pass 2");\n\n    }\n\n    bitrate = get_bit_rate(enc);\n\n    if (bitrate != 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", %"PRId64" kb/s", bitrate / 1000);\n\n    } else if (enc->rc_max_rate > 0) {\n\n        snprintf(buf + strlen(buf), buf_size - strlen(buf),\n\n                 ", max. %"PRId64" kb/s", (int64_t)enc->rc_max_rate / 1000);\n\n    }\n\n}\n',
]
with torch.no_grad():
    input_sentence = tokenizer(code)
    ids = input_sentence.padded_tensor(padding_id=0, pad_left=False)
    mask = input_sentence.attention_mask(pad_left=False)
    model_output = model(piece_ids=ids, attention_mask=mask)

I went through the API docs and skimmed through the source code, and it appears that truncation is not supported. Note that when I manually truncated the sequence to the model's maximum length, I was able to feed it to the RoBERTa encoder without error.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions