Fix #1668 (#1670)
MaartenGr authored Dec 22, 2023
1 parent 4572a99 commit 6c0d3f2
Showing 1 changed file with 44 additions and 14 deletions.
58 changes: 44 additions & 14 deletions bertopic/backend/_multimodal.py
@@ -15,7 +15,12 @@ class MultiModalBackend(BaseEmbedder):
     generating word, document, and image embeddings.
 
     Arguments:
-        embedding_model: A sentence-transformers embedding model
+        embedding_model: A sentence-transformers embedding model that
+                         can either embed both images and text or only text.
+                         If it only embeds text, then `image_model` needs
+                         to be used to embed the images.
+        image_model: A sentence-transformers embedding model that is used
+                     to embed only images.
         batch_size: The sizes of image batches to pass
 
     Examples:
@@ -40,10 +45,12 @@ class MultiModalBackend(BaseEmbedder):
     """
     def __init__(self,
                  embedding_model: Union[str, SentenceTransformer],
+                 image_model: Union[str, SentenceTransformer] = None,
                  batch_size: int = 32):
         super().__init__()
         self.batch_size = batch_size
 
+        # Text or Text+Image model
         if isinstance(embedding_model, SentenceTransformer):
             self.embedding_model = embedding_model
         elif isinstance(embedding_model, str):
@@ -52,8 +59,25 @@ def __init__(self,
             raise ValueError("Please select a correct SentenceTransformers model: \n"
                              "`from sentence_transformers import SentenceTransformer` \n"
                              "`model = SentenceTransformer('clip-ViT-B-32')`")
 
+        # Image Model
+        self.image_model = None
+        if image_model is not None:
+            if isinstance(image_model, SentenceTransformer):
+                self.image_model = image_model
+            elif isinstance(image_model, str):
+                self.image_model = SentenceTransformer(image_model)
+            else:
+                raise ValueError("Please select a correct SentenceTransformers model: \n"
+                                 "`from sentence_transformers import SentenceTransformer` \n"
+                                 "`model = SentenceTransformer('clip-ViT-B-32')`")
+
-        self.tokenizer = self.embedding_model._first_module().processor.tokenizer
+        try:
+            self.tokenizer = self.embedding_model._first_module().processor.tokenizer
+        except AttributeError:
+            self.tokenizer = self.embedding_model.tokenizer
+        except:
+            self.tokenizer = None
 
     def embed(self,
               documents: List[str],
@@ -136,7 +160,10 @@ def embed_images(self, images, verbose):
                 end_index = (i * self.batch_size) + self.batch_size
 
                 images_to_embed = [Image.open(image) if isinstance(image, str) else image for image in images[start_index:end_index]]
-                img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+                if self.image_model is not None:
+                    img_emb = self.image_model.encode(images_to_embed)
+                else:
+                    img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
                 embeddings.extend(img_emb.tolist())
 
                 # Close images
@@ -146,19 +173,22 @@ def embed_images(self, images, verbose):
             embeddings = np.array(embeddings)
         else:
             images_to_embed = [Image.open(filepath) for filepath in images]
-            embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
+            if self.image_model is not None:
+                embeddings = self.image_model.encode(images_to_embed)
+            else:
+                embeddings = self.embedding_model.encode(images_to_embed, show_progress_bar=False)
         return embeddings
 
     def _truncate_document(self, document):
-        tokens = self.tokenizer.encode(document)
+        if self.tokenizer:
+            tokens = self.tokenizer.encode(document)
 
-        if len(tokens) > 77:
-            # Skip the starting token, only include 75 tokens
-            truncated_tokens = tokens[1:76]
-            document = self.tokenizer.decode(truncated_tokens)
+            if len(tokens) > 77:
+                # Skip the starting token, only include 75 tokens
+                truncated_tokens = tokens[1:76]
+                document = self.tokenizer.decode(truncated_tokens)
 
-            # Recursive call here, because the encode(decode()) can have different result
-            return self._truncate_document(document)
+                # Recursive call here, because the encode(decode()) can have different result
+                return self._truncate_document(document)
 
-        else:
-            return document
+        return document
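
To make the effect of this change concrete, here is a minimal usage sketch. It is not taken from the repository's documentation: the model names "all-MiniLM-L6-v2" (text-only) and "clip-ViT-B-32" (image-capable) are public sentence-transformers checkpoints chosen for illustration, and the image paths are hypothetical.

from bertopic.backend import MultiModalBackend

# A text-only model for documents plus a separate model for images,
# which the new `image_model` argument makes possible. With a single
# CLIP-style model that embeds both modalities, `image_model` can be omitted.
backend = MultiModalBackend(
    embedding_model="all-MiniLM-L6-v2",  # embeds text only
    image_model="clip-ViT-B-32",         # embeds the images
    batch_size=32,
)

# Hypothetical image paths; `embed_images` opens the files, embeds them in
# batches, and routes them to `image_model` whenever it is set.
image_embeddings = backend.embed_images(["cat.jpg", "dog.jpg"], verbose=False)

Because "all-MiniLM-L6-v2" exposes no CLIP processor, the new try/except block falls back to the model's own tokenizer (or to None, in which case `_truncate_document` returns documents unchanged) instead of failing with an AttributeError during construction.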
