Skip to content

Commit

Permalink
Update to tools parameter (#1175)
Browse files Browse the repository at this point in the history
* Update to tools parameter

* Use tool_choice argument

* Update response to iterate through new tool_calls property

* Allow tool_calls to be None

* Revert func signature as the local is implicitly optional and adding more typing was not the intention of this change

* Fix tests as the response structure looks different when using tool_calls
  • Loading branch information
tonybaloney authored Jan 25, 2024
1 parent bec59be commit 62c5ae8
Show file tree
Hide file tree
Showing 4 changed files with 147 additions and 22 deletions.
17 changes: 11 additions & 6 deletions app/backend/approaches/chatapproach.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,17 @@ def get_system_prompt(self, override_prompt: Optional[str], follow_up_questions_

def get_search_query(self, chat_completion: ChatCompletion, user_query: str):
response_message = chat_completion.choices[0].message
if function_call := response_message.function_call:
if function_call.name == "search_sources":
arg = json.loads(function_call.arguments)
search_query = arg.get("search_query", self.NO_RESPONSE)
if search_query != self.NO_RESPONSE:
return search_query

if response_message.tool_calls:
for tool in response_message.tool_calls:
if tool.type != "function":
continue
function = tool.function
if function.name == "search_sources":
arg = json.loads(function.arguments)
search_query = arg.get("search_query", self.NO_RESPONSE)
if search_query != self.NO_RESPONSE:
return search_query
elif query_text := response_message.content:
if query_text.strip() != self.NO_RESPONSE:
return query_text
Expand Down
32 changes: 18 additions & 14 deletions app/backend/approaches/chatreadretrieveread.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from typing import Any, Coroutine, Literal, Optional, Union, overload
from typing import Any, Coroutine, List, Literal, Optional, Union, overload

from azure.search.documents.aio import SearchClient
from azure.search.documents.models import VectorQuery
from openai import AsyncOpenAI, AsyncStream
from openai.types.chat import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionToolParam,
)

from approaches.approach import ThoughtStep
Expand Down Expand Up @@ -97,19 +98,22 @@ async def run_until_final_call(
original_user_query = history[-1]["content"]
user_query_request = "Generate search query for: " + original_user_query

functions = [
tools: List[ChatCompletionToolParam] = [
{
"name": "search_sources",
"description": "Retrieve sources from the Azure AI Search index",
"parameters": {
"type": "object",
"properties": {
"search_query": {
"type": "string",
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
}
"type": "function",
"function": {
"name": "search_sources",
"description": "Retrieve sources from the Azure AI Search index",
"parameters": {
"type": "object",
"properties": {
"search_query": {
"type": "string",
"description": "Query string to retrieve documents from azure search eg: 'Health care plan'",
}
},
"required": ["search_query"],
},
"required": ["search_query"],
},
}
]
Expand All @@ -131,8 +135,8 @@ async def run_until_final_call(
temperature=0.0,
max_tokens=100, # Setting too low risks malformed JSON, setting too high may affect performance
n=1,
functions=functions,
function_call="auto",
tools=tools,
tool_choice="auto",
)

query_text = self.get_search_query(chat_completion, original_user_query)
Expand Down
60 changes: 59 additions & 1 deletion tests/test_chatapproach.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,65 @@ def chat_approach():


def test_get_search_query(chat_approach):
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-35-turbo","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"content":"this is the query","role":"assistant","function_call":{"name":"search_sources","arguments":"{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"}},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
payload = """
{
"id": "chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM",
"object": "chat.completion",
"created": 1695324963,
"model": "gpt-35-turbo",
"prompt_filter_results": [
{
"prompt_index": 0,
"content_filter_results": {
"hate": {
"filtered": false,
"severity": "safe"
},
"self_harm": {
"filtered": false,
"severity": "safe"
},
"sexual": {
"filtered": false,
"severity": "safe"
},
"violence": {
"filtered": false,
"severity": "safe"
}
}
}
],
"choices": [
{
"index": 0,
"finish_reason": "function_call",
"message": {
"content": "this is the query",
"role": "assistant",
"tool_calls": [
{
"id": "search_sources1235",
"type": "function",
"function": {
"name": "search_sources",
"arguments": "{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"
}
}
]
},
"content_filter_results": {
}
}
],
"usage": {
"completion_tokens": 19,
"prompt_tokens": 425,
"total_tokens": 444
}
}
"""
default_query = "hello"
chatcompletions = ChatCompletion.model_validate(json.loads(payload), strict=False)
query = chat_approach.get_search_query(chatcompletions, default_query)
Expand Down
60 changes: 59 additions & 1 deletion tests/test_chatvisionapproach.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,65 @@ def test_build_filter(chat_approach):


def test_get_search_query(chat_approach):
payload = '{"id":"chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM","object":"chat.completion","created":1695324963,"model":"gpt-4v","prompt_filter_results":[{"prompt_index":0,"content_filter_results":{"hate":{"filtered":false,"severity":"safe"},"self_harm":{"filtered":false,"severity":"safe"},"sexual":{"filtered":false,"severity":"safe"},"violence":{"filtered":false,"severity":"safe"}}}],"choices":[{"index":0,"finish_reason":"function_call","message":{"content":"this is the query","role":"assistant","function_call":{"name":"search_sources","arguments":"{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"}},"content_filter_results":{}}],"usage":{"completion_tokens":19,"prompt_tokens":425,"total_tokens":444}}'
payload = """
{
"id": "chatcmpl-81JkxYqYppUkPtOAia40gki2vJ9QM",
"object": "chat.completion",
"created": 1695324963,
"model": "gpt-35-turbo",
"prompt_filter_results": [
{
"prompt_index": 0,
"content_filter_results": {
"hate": {
"filtered": false,
"severity": "safe"
},
"self_harm": {
"filtered": false,
"severity": "safe"
},
"sexual": {
"filtered": false,
"severity": "safe"
},
"violence": {
"filtered": false,
"severity": "safe"
}
}
}
],
"choices": [
{
"index": 0,
"finish_reason": "function_call",
"message": {
"content": "this is the query",
"role": "assistant",
"tool_calls": [
{
"id": "search_sources1235",
"type": "function",
"function": {
"name": "search_sources",
"arguments": "{\\n\\"search_query\\":\\"accesstelemedicineservices\\"\\n}"
}
}
]
},
"content_filter_results": {
}
}
],
"usage": {
"completion_tokens": 19,
"prompt_tokens": 425,
"total_tokens": 444
}
}
"""
default_query = "hello"
chatcompletions = ChatCompletion.model_validate(json.loads(payload), strict=False)
query = chat_approach.get_search_query(chatcompletions, default_query)
Expand Down

0 comments on commit 62c5ae8

Please sign in to comment.