Skip to content

Commit

Permalink
Update _langchain.py with [KEYWORDS] tag option (#1871)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcantimmy authored Apr 1, 2024
1 parent 4a522ab commit 424cefc
Showing 1 changed file with 26 additions and 6 deletions.
32 changes: 26 additions & 6 deletions bertopic/representation/_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ class LangChain(BaseRepresentation):
Output key must be `output_text`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
nr_docs: The number of documents to pass to LangChain if a prompt
with the `["DOCUMENTS"]` tag is used.
NOTE: Use `"[KEYWORDS]"` in the prompt
to decide where the keywords need to be
inserted. Keywords are not included unless
this tag is present. Unlike other representation models,
LangChain does not use the `"[DOCUMENTS]"` tag
to insert documents into the prompt. Instead, the chain returned by
`load_qa_chain` formats the representative documents within the prompt.
nr_docs: The number of documents to pass to LangChain
diversity: The diversity of documents to pass to LangChain.
Accepts values between 0 and 1. A higher
value results in passing more diverse documents
Expand Down Expand Up @@ -185,10 +191,24 @@ def extract_topics(self,
]

# `self.chain` must take `input_documents` and `question` as input keys
inputs = [
{"input_documents": docs, "question": self.prompt}
for docs in chain_docs
]
# Use a custom prompt that leverages keywords, using the tag: [KEYWORDS]
if "[KEYWORDS]" in self.prompt:
prompts = []
for topic in topics:
keywords = list(zip(*topics[topic]))[0]
prompt = self.prompt.replace("[KEYWORDS]", ", ".join(keywords))
prompts.append(prompt)

inputs = [
{"input_documents": docs, "question": prompt}
for docs, prompt in zip(chain_docs, prompts)
]

else:
inputs = [
{"input_documents": docs, "question": self.prompt}
for docs in chain_docs
]

# `self.chain` must return a dict with an `output_text` key
# same output key as the `StuffDocumentsChain` returned by `load_qa_chain`
Expand Down

0 comments on commit 424cefc

Please sign in to comment.