Skip to content

Commit

Permalink
Fix #1797 and #1798
Browse files (browse the repository at this point in the history)
  • Loading branch information
MaartenGr committed Feb 12, 2024
1 parent 4cd1d9e commit 1dc3f96
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3237,7 +3237,7 @@ def merge_models(cls, models, min_similarity: float = .7, embedding_model=None):
merged_topics["topic_aspects"][str(new_topic_val)] = selected_topics["topic_aspects"][str(new_topic)]

# Add new embeddings
- new_tensors = tensors[new_topic - selected_topics["_outliers"]]
+ new_tensors = tensors[new_topic + selected_topics["_outliers"]]
merged_tensors = np.vstack([merged_tensors, new_tensors])

# Topic Mapper
Expand Down Expand Up @@ -3663,6 +3663,32 @@ def _combine_zeroshot_topics(self,
self.__dict__.clear()
self.__dict__.update(merged_model.__dict__)
logger.info("Zeroshot Step 3 - Completed \u2713")

# Move -1 topic back to position 0 if it exists
if self._outliers:
nr_zeroshot_topics = len(set(y))

# Re-map the topics such that the -1 topic is at position 0
new_mappings = {}
for topic in self.topics_:
if topic < nr_zeroshot_topics:
new_mappings[topic] = topic
elif topic == nr_zeroshot_topics:
new_mappings[topic] = -1
else:
new_mappings[topic] = topic - 1

# Re-map the topics including all representations (labels, sizes, embeddings, etc.)
self.topics_ = [new_mappings[topic] for topic in self.topics_]
self.topic_representations_ = {new_mappings[topic]: repr for topic, repr in self.topic_representations_.items()}
self.topic_labels_ = {new_mappings[topic]: label for topic, label in self.topic_labels_.items()}
self.topic_sizes_ = collections.Counter(self.topics_)
self.topic_embeddings_ = np.vstack([
self.topic_embeddings_[nr_zeroshot_topics],
self.topic_embeddings_[:nr_zeroshot_topics],
self.topic_embeddings_[nr_zeroshot_topics+1:]
])

return self.topics_

def _guided_topic_modeling(self, embeddings: np.ndarray) -> Tuple[List[int], np.array]:
Expand Down

0 comments on commit 1dc3f96

Please sign in to comment.