Skip to content

Commit

Permalink
Updates:
Browse files Browse the repository at this point in the history
- chains.py: Supporting "tool_choice".
- Update cookbook examples.
- Improve promopts, for "force" mode.
  • Loading branch information
unclecode committed Mar 17, 2024
1 parent d828608 commit 18f35c9
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 11 deletions.
35 changes: 27 additions & 8 deletions app/libs/chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from importlib import import_module
import json
import uuid
import traceback
from fastapi import Request
from fastapi.responses import JSONResponse
from providers import BaseProvider
from prompts import SYSTEM_MESSAGE, SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide
from prompts import SYSTEM_MESSAGE, ENFORCED_SYSTAME_MESSAE, SUFFIX, FORCE_CALL_SUFFIX, CLEAN_UP_MESSAGE, get_func_result_guide, get_forced_tool_suffix
from providers import GroqProvider
import importlib
from utils import get_tool_call_response, create_logger
Expand All @@ -19,8 +20,10 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]):
self.provider = provider
self.body = body
self.response = None

# extract all keys from body except messages and tools and set in params
self.params = {k: v for k, v in body.items() if k not in ["messages", "tools"]}

# self.no_tool_behaviour = self.params.get("no_tool_behaviour", "return")
self.no_tool_behaviour = self.params.get("no_tool_behaviour", "forward")
self.params.pop("no_tool_behaviour", None)
Expand Down Expand Up @@ -50,8 +53,6 @@ def __init__(self, request: Request, provider: str, body: Dict[str, Any]):
bt['extra'] = self.params.get("extra", {})
self.params.pop("extra", None)



self.client : BaseProvider = None

@property
Expand All @@ -60,7 +61,7 @@ def last_message(self):

@property
def is_tool_call(self):
return bool(self.last_message["role"] == "user" and self.tools)
return bool(self.last_message["role"] == "user" and self.tools and self.params.get("tool_choice", "none") != "none")

@property
def is_tool_response(self):
Expand Down Expand Up @@ -88,6 +89,7 @@ async def handle(self, context: Context):
return await self._next_handler.handle(context)
except Exception as e:
_exception_handler: "Handler" = ExceptionHandler()
# Extract the stack trace and log the exception
return await _exception_handler.handle(context, e)


Expand Down Expand Up @@ -130,19 +132,35 @@ class ToolExtractionHandler(Handler):
async def handle(self, context: Context):
body = context.body
if context.is_tool_call:

# Prepare the messages and tools for the tool extraction
messages = [
f"{m['role'].title()}: {m['content']}"
for m in context.messages
if m["role"] != "system"
]

tools_json = json.dumps([t["function"] for t in context.tools], indent=4)

# Process the tool_choice
tool_choice = context.params.get("tool_choice", "auto")
forced_mode = False
if type(tool_choice) == dict and tool_choice.get("type", None) == "function":
tool_choice = tool_choice["function"].get("name", None)
if not tool_choice:
raise ValueError("Invalid tool choice. 'tool_choice' is set to a dictionary with 'type' as 'function', but 'function' does not have a 'name' key.")
forced_mode = True

# Regenerate the string tool_json and keep only the forced tool
tools_json = json.dumps([t["function"] for t in context.tools if t["function"]["name"] == tool_choice], indent=4)

system_message = SYSTEM_MESSAGE if not forced_mode else ENFORCED_SYSTAME_MESSAE
suffix = SUFFIX if not forced_mode else get_forced_tool_suffix(tool_choice)

new_messages = [
{"role": "system", "content": SYSTEM_MESSAGE},
{"role": "system", "content": system_message},
{
"role": "system",
"content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{SUFFIX}",
"content": f"Conversation History:\n{''.join(messages)}\n\nTools: \n{tools_json}\n\n{suffix}",
},
]

Expand Down Expand Up @@ -309,4 +327,5 @@ async def handle(self, context: Context):
class ExceptionHandler(Handler):
async def handle(self, context: Context, exception: Exception):
print(f"Error processing the request: {exception}")
return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500)
print(traceback.format_exc())
return JSONResponse(content={"error": "An unexpected error occurred. " + str(exception)}, status_code=500)
33 changes: 32 additions & 1 deletion app/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,44 @@
** If no tools are required, then return an empty list for "tool_calls". **
**Wrap the JSON response between ```json and ```**.
**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**.
**Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response, do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user.**"""


ENFORCED_SYSTAME_MESSAE = """A history of conversations between an AI assistant and the user, plus the last user's message, is given to you.
You have access to a specific tool that the AI assistant must use to provide a proper answer. The tool is a function that requires a set of parameters, which are provided in a JSON schema to explain what parameters the tool needs. Your task is to extract the values for these parameters from the user's last message and the conversation history.
Your job is to closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response.
You should think step by step, provide your reasoning for your response, then add the JSON response at the end following the below schema:
{
"tool_calls": [{
"name": "function_name",
"arguments": {
"arg1": "value1",
"arg2": "value2",
...
}]
}
}
**Wrap the JSON response between ```json and ```, and rememebr "tool_calls" is a list.**.
Whenever a message starts with 'SYSTEM MESSAGE', that is a guide and help information for you to generate your next response. Do not consider them a message from the user, and do not reply to them at all. Just use the information and continue your conversation with the user."""

CLEAN_UP_MESSAGE = "When I tried to extract the content between ```json and ``` and parse the content to valid JSON object, I faced with the abovr error. Remember, you are supposed to wrap the schema between ```json and ```, and do this only one time. First find out what went wrong, that I couldn't extract the JSON between ```json and ```, and also faced error when trying to parse it, then regenerate the your last message and fix the issue."

SUFFIX = """Think step by step and justify your response. Make sure to not miss in case to answer user query we need multiple tools, in that case detect all that we need, then generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME."""

FORCE_CALL_SUFFIX = """For this task, you HAVE to choose the tool (function) {tool_name}, and ignore other rools. Therefore think step by step and justify your response, then closely examine the user's last message and the history of the conversation, then extract the necessary parameter values for the given tool based on the provided JSON schema. Remember that you must use the specified tool to generate the response. Finally generate a JSON response wrapped between "```json" and "```". Remember to USE THIS JSON WRAPPER ONLY ONE TIME."""

def get_forced_tool_suffix(tool_name : str) -> str:
return FORCE_CALL_SUFFIX.format(tool_name=tool_name)

def get_func_result_guide(function_call_result : str) -> str:
return f"SYSTEM MESSAGE: \n```json\n{function_call_result}\n```\n\nThe above is the result after functions are called. Use the result to answer the user's last question.\n\n"
137 changes: 137 additions & 0 deletions cookbook/function_call_force_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@

from duckduckgo_search import DDGS
import requests, os
import json

api_key=os.environ["GROQ_API_KEY"]
header = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
proxy_url = "https://groqcall.ai/proxy/groq/v1/chat/completions"

# or "http://localhost:8000/proxy/groq/v1/chat/completions" if running locally
# proxy_url = "http://localhost:8000/proxy/groq/v1/chat/completions"


def duckduckgo_search(query, max_results=None):
"""
Use this function to search DuckDuckGo for a query.
"""
with DDGS() as ddgs:
return [r for r in ddgs.text(query, safesearch='off', max_results=max_results)]

def duckduckgo_news(query, max_results=None):
"""
Use this function to get the latest news from DuckDuckGo.
"""
with DDGS() as ddgs:
return [r for r in ddgs.news(query, safesearch='off', max_results=max_results)]

function_map = {
"duckduckgo_search": duckduckgo_search,
"duckduckgo_news": duckduckgo_news,
}

request = {
"messages": [
{
"role": "system",
"content": "YOU MUST FOLLOW THESE INSTRUCTIONS CAREFULLY.\n<instructions>\n1. Use markdown to format your answers.\n</instructions>"
},
{
"role": "user",
"content": "Whats happening in France? Summarize top stories with sources, very short and concise."
}
],
"model": "mixtral-8x7b-32768",
# "tool_choice": "auto",
# "tool_choice": "none",
"tool_choice": {"type": "function", "function": {"name": "duckduckgo_search"}},
"tools": [
{
"type": "function",
"function": {
"name": "duckduckgo_search",
"description": "Use this function to search DuckDuckGo for a query.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The result from DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string"
},
"max_results": {
"type": [
"number",
"null"
]
}
}
}
}
},
{
"type": "function",
"function": {
"name": "duckduckgo_news",
"description": "Use this function to get the latest news from DuckDuckGo.\n\nArgs:\n query(str): The query to search for.\n max_results (optional, default=5): The maximum number of results to return.\n\nReturns:\n The latest news from DuckDuckGo.",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string"
},
"max_results": {
"type": [
"number",
"null"
]
}
}
}
}
}
]
}

response = requests.post(
proxy_url,
headers=header,
json=request
)
# Check if the request was successful
if response.status_code == 200:
# Process the response data (if needed)
res = response.json()
message = res['choices'][0]['message']
tools_response_messages = []
if not message['content'] and 'tool_calls' in message:
for tool_call in message['tool_calls']:
tool_name = tool_call['function']['name']
tool_args = tool_call['function']['arguments']
tool_args = json.loads(tool_args)
if tool_name not in function_map:
print(f"Error: {tool_name} is not a valid function name.")
continue
tool_func = function_map[tool_name]
tool_response = tool_func(**tool_args)
tools_response_messages.append({
"role": "tool", "content": json.dumps(tool_response)
})

if tools_response_messages:
request['messages'] += tools_response_messages
response = requests.post(
proxy_url,
headers=header,
json=request
)
if response.status_code == 200:
res = response.json()
print(res['choices'][0]['message']['content'])
else:
print("Error:", response.status_code, response.text)
else:
print(message['content'])
else:
print("Error:", response.status_code, response.text)
3 changes: 2 additions & 1 deletion cookbook/function_call_with_schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@

from duckduckgo_search import DDGS
import requests, os
api_key=os.environ["GROQ_API_KEY"]
import json

api_key=os.environ["GROQ_API_KEY"]
header = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
Expand Down
4 changes: 3 additions & 1 deletion cookbook/function_call_without_schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import requests
import json
import os

api_key = "YOUR_GROQ_API_KEY"
api_key=os.environ["GROQ_API_KEY"],
header = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
Expand Down

0 comments on commit 18f35c9

Please sign in to comment.