Skip to content

Commit

Permalink
Let's try this
Browse files Browse the repository at this point in the history
  • Loading branch information
wesdottoday committed Oct 17, 2023
1 parent 55a139f commit b734844
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 0 deletions.
1 change: 1 addition & 0 deletions services/api/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
.env
21 changes: 21 additions & 0 deletions services/api/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Dockerfile for the FastAPI server.
# Base image: https://hub.docker.com/_/python
FROM python:3.8-slim

# All application code lives under /app inside the image.
ENV APP_HOME /app
WORKDIR $APP_HOME

# Build-time commit hash, exposed to the app (reported by the /test route).
ARG COMMITHASH
ENV COMMITHASH=$COMMITHASH

# Copy only requirements first so the dependency layer is cached across
# builds that change app code but not dependencies.
COPY requirements.txt ./

# Install production dependencies in a single layer; --no-cache-dir keeps
# the image smaller by not persisting pip's download cache.
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

COPY . ./

# Run the FastAPI service via uvicorn on all interfaces, port 4000.
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "4000"]
2 changes: 2 additions & 0 deletions services/api/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# Private LLM API

7 changes: 7 additions & 0 deletions services/api/build_locally.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/bin/sh
# Build the API image locally and run it with the local .env secrets.

# `.` (dot) instead of `source`: `source` is a bash-ism and fails under
# strict POSIX sh (e.g. dash), which the shebang above selects.
. ./.env

REPOSITORY=private-llm-fastapi-server
IMAGE_TAG=latest

# Bake a dummy commit hash into local test builds.
docker build --build-arg="COMMITHASH=localtest" -t "$REPOSITORY:$IMAGE_TAG" .

# Map host port 4001 to the container's 4000; pass credentials via .env.
docker run --rm -p 4001:4000 --env-file ./.env "$REPOSITORY:$IMAGE_TAG"
41 changes: 41 additions & 0 deletions services/api/contextualize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
contextualize.py
Maintainer: Wes Kennedy
Description: The contextualize module allows us to write app specific queries to help build our database.
"""

import db
from sqlalchemy import *

class Contextualizer():
    """App-specific query helpers used to build customer context for the LLM.

    All lookup methods are stubs except ``customer_previous_orders``.
    Every method now takes ``self`` (the originals omitted it, so any
    bound call on an instance raised ``TypeError``).
    """

    def __init__(self):
        # No state yet; kept so callers can instantiate the class.
        pass

    def customer_lookup_byid(self, customer_id):
        """Take a customer id and return the customer's name. (Not implemented.)"""
        pass

    def customer_lookup_byname(self, customer_name):
        """Take a customer name and return the customer's id. (Not implemented.)"""
        pass

    def customer_lookup_byemail(self, customer_email):
        """Take a customer email and return the customer's id. (Not implemented.)"""
        pass

    def customer_previous_orders(self, db_conn, customer_id):
        """Take a customer id and return the customer's previous orders.

        Uses a bound parameter instead of f-string interpolation so the
        customer id cannot be used for SQL injection, and returns the rows
        actually fetched (the original discarded the query result and
        returned an always-empty list).
        """
        query = text(
            "SELECT * FROM orders WHERE customer_id = :customer_id"
        ).bindparams(customer_id=customer_id)
        response = db.query_wrapper(db_conn, query)
        return list(response)
54 changes: 54 additions & 0 deletions services/api/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import os
# add the ability to query a mysql server using sql_alchemy
from sqlalchemy import *
from sqlalchemy_utils import database_exists

MYSQL_HOST = os.getenv("MYSQL_HOST")
MYSQL_PORT = os.getenv("MYSQL_PORT")
MYSQL_USER = os.getenv("MYSQL_USER")
MYSQL_PASSWORD = os.getenv("MYSQL_PASSWORD")
MYSQL_DATABASE = os.getenv("MYSQL_DATABASE")



def init_connect():
    """Create a SQLAlchemy engine for the MySQL server, creating the
    configured database first if it does not exist yet.

    Returns an Engine bound to the *server* (no default schema);
    init()/use_db() selects the schema afterwards.
    """
    # Server-level URL (no database) vs. the full URL including the database.
    connection_string = f"mysql+pymysql://{MYSQL_USER}:{MYSQL_PASSWORD}@{MYSQL_HOST}:{MYSQL_PORT}"
    full_engine = f"{connection_string}/{MYSQL_DATABASE}"
    engine = create_engine(connection_string)
    if not database_exists(full_engine):
        # Pass the Engine itself: init_db() calls .connect() on its argument,
        # so handing it an already-open Connection (as the original did)
        # broke instead of creating the database.
        init_db(engine)
    return engine

# Placeholder for a bare engine connection without the create-db logic.
def simple_connect():
    """Not implemented yet; reserved for a plain connection helper."""
    return None

def init():
    """Module entry point (called by main.py at import time).

    Builds the engine (creating the database if needed via init_connect),
    switches the session to the configured database, and ensures the chat
    `messages` table exists. Returns the engine for later queries.
    """
    db_conn = init_connect()
    use_db(db_conn)
    init_chat_table(db_conn)
    return db_conn

def init_db(db_conn):
    """Create the configured application database on the MySQL server."""
    ddl = text(f"CREATE DATABASE {MYSQL_DATABASE}") #create db
    with db_conn.connect() as connection:
        connection.execute(ddl)

def use_db(db_conn):
    """Point the current session at the configured application database."""
    # Renamed the local so it no longer shadows this function's own name.
    # NOTE(review): `USE` applies only to the single pooled connection opened
    # here — confirm later queries qualify table names or select the schema.
    select_schema = text(f"USE {MYSQL_DATABASE}")
    with db_conn.connect() as connection:
        connection.execute(select_schema)

def init_chat_table(db_conn):
    """Create the `messages` chat-history table if it does not exist yet."""
    ddl = text(f"CREATE TABLE IF NOT EXISTS messages (_id INT AUTO_INCREMENT,conversation_id CHAR(255),message TEXT,sender VARCHAR(50),timestamp TIMESTAMP,chat_context JSON,user_context TEXT,embedding BLOB NOT NULL, PRIMARY KEY (_id));")
    with db_conn.connect() as connection:
        connection.execute(ddl)

def query_wrapper(db_conn, query):
    """Execute *query* on a fresh connection and return its result.

    For row-returning statements the rows are fetched *inside* the `with`
    block and returned as a list — the original returned the CursorResult
    after the connection had already closed, so callers could no longer
    fetch from it. Non-row statements still return the raw result object.
    """
    with db_conn.connect() as conn:
        result = conn.execute(query)
        if result.returns_rows:
            return result.fetchall()
        return result
181 changes: 181 additions & 0 deletions services/api/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
# Create a FastAPI server and define the endpoints: /embedding, /chat as POST requests
# /embedding: takes a text input and returns the embedding from a remote api call using requests
# /chat: takes a text input and context and returns a response from a remote api call using requests
# Note: the remote api calls are defined in the config file

import os
import sys
import json
import pandas as pd
from typing import List, Optional, Dict
from fastapi import FastAPI, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from langchain.prompts import (
ChatPromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.embeddings import SagemakerEndpointEmbeddings
from langchain.embeddings.sagemaker_endpoint import EmbeddingsContentHandler
from langchain.llms import SagemakerEndpoint
from langchain.chains import RetrievalQA
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from contextualize import Contextualizer
from langchain.chains import LLMChain
from langchain.memory import ConversationSummaryMemory
import db
from langchain.memory import ConversationBufferMemory

memory = ConversationBufferMemory()

# Open the database connection (creates the database/table on first run).
db_conn = db.init()

### System Prompt
system_prompt = """
You are a helpful customer service agent working for Kai Shoes. \n
You will be chatting with a customer. \n
Use context from their previous orders to help them make decisions.
"""

# Add the parent directory to the path to import the config file
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

SAGEMAKER_ENDPOINT = os.getenv("SAGEMAKER_ENDPOINT")
SAGEMAKER_ROLE = os.getenv("SAGEMAKER_ROLE")
SAGEMAKER_REGION = os.getenv("SAGEMAKER_REGION")
AWS_ACCESS_KEY_ID = os.getenv("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
COMMITHASH = os.getenv("COMMITHASH")

# Log the build identifier only. The original also printed
# AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY, leaking live credentials
# into container stdout/logs — never print secrets.
print(COMMITHASH)


# Initialize the FastAPI server
app = FastAPI(
    title="LLM API",
    description="API for the LLM project",
    version="0.1.0",
    # Serve the interactive OpenAPI docs at the root path instead of /docs.
    docs_url="/",
)

## FastAPI Routes
### /chat route

"""
Expected JSON:
{
    "text": "this is my message",
    "cust_id": "1234"
}
"""
@app.post("/chat")
async def chat(request: Request):
    """Chat endpoint: take a customer message plus customer id, return the LLM reply.

    Expected JSON body: {"text": "...", "cust_id": "..."}.
    Responds 400 when the required "text" field is missing, instead of
    silently passing None through to the model as the original did.
    (The original also `print(body)`-ed every request, logging raw
    customer messages — removed.)
    """
    body = await request.json()

    question = body.get("text")
    if question is None:
        raise HTTPException(status_code=400, detail='Missing required field "text"')

    # cust_id is used as conversation context for the prompt (may be None).
    context = body.get("cust_id")

    # Run the question through the SageMaker-backed LLM chain.
    response = llm_prompt_run(context, question)
    return {"response": response}

@app.get("/test")
async def root():
    """Health/version check: report the commit hash the image was built from."""
    # f-string instead of str.format, and fix the "runnin" typo in the message.
    return {"message": f"Hello World, I'm running on commit {COMMITHASH}"}



# SageMaker Endpoint Handler
class ContentHandler(LLMContentHandler):
    """Serialize prompts to, and deserialize completions from, the SageMaker endpoint.

    The original carried a large commented-out alternative payload format,
    a redundant intermediate variable, and debug `print`s of every full
    request payload (leaking chat content into logs) — all removed; the
    wire format is unchanged.
    """

    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        """Wrap the prompt and generation parameters as the endpoint's JSON body."""
        # NOTE(review): despite the base-class annotation this returns a JSON
        # *str* (as the original did); boto3 accepts str bodies — confirm the
        # endpoint does not require explicit UTF-8 bytes before changing it.
        return json.dumps({"inputs": ''.join(prompt), "parameters": model_kwargs})

    def transform_output(self, output: bytes) -> str:
        """Decode the endpoint's streaming body and return the parsed JSON response."""
        return json.loads(output.read().decode("utf-8"))

content_handler = ContentHandler()

# # SageMaker Embeddings
# sagemaker_embeddings = SagemakerEndpointEmbeddings(
# endpoint_name=SAGEMAKER_ENDPOINT,
# region_name=SAGEMAKER_REGION,
# content_handler=content_handler,
# )

# query_result = sagemaker_embeddings.embed_query("foo")


def llm_prompt_run(user_context, question):
    """Run one chat turn through the SageMaker-hosted LLM and return its reply.

    Args:
        user_context: customer context (currently the raw cust_id from /chat),
            injected into the prompt as a human message.
        question: the customer's message text.

    Returns:
        The model's text response.
    """
    prompt = ChatPromptTemplate(
        messages=[
            SystemMessagePromptTemplate.from_template(
                "You are a friendly support rep at Kai Shoes. Use the following pieces of information to answer the user's question. If you don't know the answer, just say that you don't know, don't try to make up an answer."
            ),
            MessagesPlaceholder(variable_name="chat_history"),
            HumanMessagePromptTemplate.from_template("{context}"),
            HumanMessagePromptTemplate.from_template("{question}")
        ]
    )

    # SageMaker LLM: use the configured region rather than the hard-coded
    # "us-west-2" the original shipped with; fall back to the old value so
    # behavior is unchanged when SAGEMAKER_REGION is unset.
    llm = SagemakerEndpoint(
        endpoint_name=SAGEMAKER_ENDPOINT,
        region_name=SAGEMAKER_REGION or "us-west-2",
        model_kwargs={"max_new_tokens": 700, "top_p": 0.9, "temperature": 0.6},
        endpoint_kwargs={"CustomAttributes": 'accept_eula=true'},
        content_handler=content_handler,
    )

    # NOTE(review): a fresh memory is built on every call, so the
    # module-level `memory` is never used and history does not persist
    # across requests — confirm whether cross-request memory was intended.
    chat_history = []
    turn_memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    chain = LLMChain(llm=llm,
                     prompt=prompt,
                     memory=turn_memory,
                     )

    llm_resp = chain.run({'context': user_context, 'question': question, 'chat_history': chat_history})

    print(llm_resp)
    return llm_resp
13 changes: 13 additions & 0 deletions services/api/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
config==0.5.1
fastapi==0.103.2
loguru==0.7.2
pandas==1.5.3

Requests==2.31.0
langchain==0.0.313
PyMySQL==1.1.0
sqlalchemy
sqlalchemy_utils

uvicorn
boto3

0 comments on commit b734844

Please sign in to comment.