Skip to content

Commit

Permalink
make chatmessage retriever and writer compatible with the new API (#163)
Browse files Browse the repository at this point in the history
  • Loading branch information
anakin87 authored Jan 8, 2025
1 parent 66cfc02 commit 41903c1
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 42 deletions.
30 changes: 15 additions & 15 deletions test/chat_message_stores/test_in_memory_chat_message_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,11 @@ def test_count_messages(self):
"""
store = InMemoryChatMessageStore()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.count_messages() == 1
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
assert store.count_messages() == 2
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.count_messages() == 3

def test_retrieve(self):
Expand All @@ -55,18 +55,18 @@ def test_retrieve(self):
"""
store = InMemoryChatMessageStore()
assert store.retrieve() == []
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
assert store.retrieve() == [ChatMessage.from_user(content="Hello, how can I help you?")]
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.retrieve() == [ChatMessage.from_user("Hello, how can I help you?")]
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
assert store.retrieve() == [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
]
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.retrieve() == [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?"),
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?"),
]

def test_delete_messages(self):
Expand All @@ -75,12 +75,12 @@ def test_delete_messages(self):
"""
store = InMemoryChatMessageStore()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
store.write_messages(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert store.count_messages() == 1
store.delete_messages()
assert store.count_messages() == 0
store.write_messages(messages=[ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user(content="Hola, ¿cómo puedo ayudarte?")])
store.write_messages(messages=[ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")])
store.write_messages(messages=[ChatMessage.from_user("Hola, ¿cómo puedo ayudarte?")])
assert store.count_messages() == 2
store.delete_messages()
assert store.count_messages() == 0
44 changes: 22 additions & 22 deletions test/components/retrievers/test_chat_message_retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def test_retrieve_messages(self):
Test that the ChatMessageRetriever component can retrieve messages from the message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -40,10 +40,10 @@ def test_retrieve_messages_last_k(self):
Test that the ChatMessageRetriever component can retrieve last_k messages from the message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -52,19 +52,19 @@ def test_retrieve_messages_last_k(self):

assert retriever.message_store == message_store
assert retriever.run(last_k=1) == {
"messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]}
"messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]}

assert retriever.run(last_k=2) == {
"messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

# outliers
assert retriever.run(last_k=10) == {
"messages": [ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

with pytest.raises(ValueError):
Expand All @@ -79,10 +79,10 @@ def test_retrieve_messages_last_k_init(self):
by testing the init last_k parameter and the run last_k parameter logic
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?"),
ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]

message_store = InMemoryChatMessageStore()
Expand All @@ -93,12 +93,12 @@ def test_retrieve_messages_last_k_init(self):

# last_k is 1 here from run parameter, overrides init of 2
assert retriever.run(last_k=1) == {
"messages": [ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")]}
"messages": [ChatMessage.from_user("Bonjour, comment puis-je vous aider?")]}

# last_k is 2 here from init
assert retriever.run() == {
"messages": [ChatMessage.from_user(content="Hola, como puedo ayudarte?"),
ChatMessage.from_user(content="Bonjour, comment puis-je vous aider?")
"messages": [ChatMessage.from_user("Hola, como puedo ayudarte?"),
ChatMessage.from_user("Bonjour, comment puis-je vous aider?")
]}

def test_to_dict(self):
Expand Down Expand Up @@ -157,7 +157,7 @@ def test_chat_message_retriever_pipeline(self):
Context:
{% for memory in memories %}
{{ memory.content }}
{{ memory.text }}
{% endfor %}
Question: {{ query }}
Expand All @@ -166,7 +166,7 @@ def test_chat_message_retriever_pipeline(self):
question = "What is the capital of France?"

res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}})
resulting_prompt = res["prompt_builder"]["prompt"][0].content
resulting_prompt = res["prompt_builder"]["prompt"][0].text
assert "France" in resulting_prompt
assert "how can I help you" in resulting_prompt

Expand Down
10 changes: 5 additions & 5 deletions test/components/writers/test_chat_message_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_init(self):
Test that the ChatMessageWriter component can be initialized with a valid message store.
"""
messages = [
ChatMessage.from_user(content="Hello, how can I help you?"),
ChatMessage.from_user(content="Hallo, wie kann ich Ihnen helfen?")
ChatMessage.from_user("Hello, how can I help you?"),
ChatMessage.from_user("Hallo, wie kann ich Ihnen helfen?")
]

message_store = InMemoryChatMessageStore()
Expand Down Expand Up @@ -42,7 +42,7 @@ def test_to_dict(self):
}

# write again and serialize
writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")])
data = writer.to_dict()
assert data == {
"type": "haystack_experimental.components.writers.chat_message_writer.ChatMessageWriter",
Expand Down Expand Up @@ -74,7 +74,7 @@ def test_from_dict(self):
}

# write to verify that everything is still working
results = writer.run(messages=[ChatMessage.from_user(content="Hello, how can I help you?")])
results = writer.run(messages=[ChatMessage.from_user("Hello, how can I help you?")])
assert results["messages_written"] == 1

def test_chat_message_writer_pipeline(self):
Expand All @@ -97,7 +97,7 @@ def test_chat_message_writer_pipeline(self):
res = pipe.run(data={"prompt_builder": {"template": [ChatMessage.from_user(user_prompt)], "query": question}})
assert res["writer"]["messages_written"] == 1 # only one message is written
assert len(store.retrieve()) == 1 # only one message is written
assert store.retrieve()[0].content == """
assert store.retrieve()[0].text == """
Given the following information, answer the question.
Question: What is the capital of France?
Answer:
Expand Down

0 comments on commit 41903c1

Please sign in to comment.