Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python client implementation #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 71 additions & 0 deletions contrib/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python
import json
import sys

# pip install websocket-client
import websocket

class ModelClient(object):
    """Minimal WebSocket client for a Petals-style inference endpoint.

    Usage: construct with the endpoint URL, call open_session() once,
    iterate generate() to stream output chunks, then close_session().
    """

    def __init__(self, endpoint_url):
        # WebSocket endpoint, e.g. "ws://localhost:8000/api/v2/generate".
        self.endpoint_url = endpoint_url
        self.ws = None      # live connection; set by open_session()
        self.model = None   # model name for the current session

    def open_session(self, model, max_length):
        """Connect to the endpoint and open an inference session.

        Args:
            model: model name understood by the server
                (e.g. "bigscience/bloom-petals").
            max_length: maximum total sequence length for this session.

        Raises:
            RuntimeError: if the server does not acknowledge the session.
        """
        self.ws = websocket.create_connection(self.endpoint_url)
        self.model = model
        payload = {
            "type": "open_inference_session",
            "model": self.model,
            "max_length": max_length,
        }
        self.ws.send(json.dumps(payload))
        response = json.loads(self.ws.recv())
        # Was `assert response['ok'] == True`: asserts are stripped under
        # `python -O`, so validate the handshake with a real exception.
        if not response.get("ok"):
            raise RuntimeError(
                response.get("traceback", "failed to open inference session"))

    def close_session(self):
        """Close the websocket connection; safe to call when never opened."""
        if self.ws is not None:
            self.ws.close()
            self.ws = None

    def generate(self, prompt, **kwargs):
        """Yield generated text chunks for `prompt` as the server streams them.

        Any keyword arguments override the defaults below (e.g. do_sample,
        temperature, top_p, max_new_tokens).

        Raises:
            RuntimeError: if the server reports a failure; the server-side
                traceback is used as the message.
        """
        payload = {
            "type": "generate",
            "inputs": prompt,
            "max_new_tokens": 1,
            "do_sample": 0,
            "temperature": 0,
            # Bloomz variants use </s> as the end-of-sequence marker.
            "stop_sequence": "</s>" if "bloomz" in self.model else "\n\n",
        }
        payload = {**payload, **kwargs}
        self.ws.send(json.dumps(payload))

        while True:
            data = json.loads(self.ws.recv())
            if not data['ok']:
                # RuntimeError (not bare Exception) per review feedback;
                # still a subclass of Exception, so existing callers that
                # catch Exception keep working.
                raise RuntimeError(data['traceback'])
            yield data['outputs']
            if data['stop']:
                break

def main():
    """Connect to a local Petals endpoint and stream a sampled completion."""
    client = ModelClient("ws://localhost:8000/api/v2/generate")
    # client = ModelClient("ws://chat.petals.ml/api/v2/generate")
    client.open_session("bigscience/bloom-petals", 128)

    default_prompt = "The SQL command to extract all the users whose name starts with A is: \n\n"
    user_supplied = len(sys.argv) > 1
    prompt = sys.argv[1] if user_supplied else default_prompt
    # Bloomz variant uses </s> instead of \n\n as an eos token
    if user_supplied and not prompt.endswith("\n\n"):
        prompt += "\n\n"
    print(f"Prompt: {prompt}")

    sampling_opts = {"do_sample": True, "temperature": 0.75, "top_p": 0.9}
    for chunk in client.generate(prompt, **sampling_opts):
        print(chunk, end="", flush=True)

    client.close_session()

if __name__ == '__main__':
    main()