diff --git a/src/mistral_inference/transformer.py b/src/mistral_inference/transformer.py index 9c9aebe..484b02e 100644 --- a/src/mistral_inference/transformer.py +++ b/src/mistral_inference/transformer.py @@ -159,7 +159,7 @@ def forward_partial( if self.pipeline_rank == 0: assert self.tok_embeddings is not None - if self.vision_encoder is not None and images: + if self.vision_encoder is not None and images is not None: h = self.embed_vision_language_features(input_ids, images) else: h = self.tok_embeddings(input_ids)