Commit 94335e3

Yushi Homma authored and committed
Added documentation around Network Volumes for speeding up cold starts and updated the predict.py script
1 parent b765dce commit 94335e3

6 files changed: +62, -41 lines

.dockerignore

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@
 .github
 predict.py
 README.md
+artifacts

README.md

Lines changed: 41 additions & 25 deletions
@@ -7,8 +7,11 @@ This Docker image runs a Llama model on a serverless RunPod instance using the o
 
 ## Set Up
 1. Create a RunPod account and navigate to the [RunPod Serverless Console](https://www.runpod.io/console/serverless).
-2. Navigate to `My Templates` and click on the `New Template` button.
-3. Enter in the following fields and click on the `Save Template` button:
+2. (Optional) Create a Network Volume to cache your model and speed up cold starts (note that the volume incurs a per-hour storage cost).
+   - *Note: Only certain Network Volume regions are compatible with certain instance types on RunPod. If your Network Volume makes your desired instance type Unavailable, try other regions for your Network Volume.*
+![70B Network Volume Configuration Example](artifacts/yh_runpod_network_volume_screenshot.png)
+3. Navigate to `My Templates` and click on the `New Template` button.
+4. Enter the following fields and click on the `Save Template` button:
 
 | Template Field | Value |
 | --- | --- |
@@ -34,7 +37,9 @@ This Docker image runs a Llama model on a serverless RunPod instance using the o
 | (Optional) `PROMPT_SUFFIX` | `"ASSISTANT: "` |
 | (Optional) `MAX_SEQ_LEN` | `4096` |
 | (Optional) `ALPHA_VALUE` | `1` |
-
+| (If using Network Volumes) `HUGGINGFACE_HUB_CACHE` | `/runpod-volume/hub` |
+| (If using Network Volumes) `TRANSFORMERS_CACHE` | `/runpod-volume/hub` |
+![Airoboros 70B Template Configuration Example](artifacts/yh_airoboros_70b_template_screenshot)
 4. Now click on `My Endpoints` and click on the `New Endpoint` button.
 5. Fill in the following fields and click on the `Create` button:
 | Endpoint Field | Value |
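The two cache rows added above work because `huggingface_hub` and `transformers` consult the `HUGGINGFACE_HUB_CACHE` and `TRANSFORMERS_CACHE` environment variables when choosing where to store downloaded weights, so pointing both at the Network Volume lets the first cold start populate a cache that every later worker reuses. A minimal sketch of the idea, assuming the volume is mounted at `/runpod-volume` as in the table above; the model id is hypothetical:

```python
import os

# Point the Hugging Face caches at the mounted Network Volume so weights
# downloaded on the first cold start persist for every later worker.
os.environ.setdefault("HUGGINGFACE_HUB_CACHE", "/runpod-volume/hub")
os.environ.setdefault("TRANSFORMERS_CACHE", "/runpod-volume/hub")

# Import only after the env vars are set, since the libraries read them at import.
from huggingface_hub import snapshot_download

# The first call downloads to the volume; later cold starts find it cached.
model_path = snapshot_download("TheBloke/airoboros-7B-GPTQ")  # hypothetical model id
print(model_path)
```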
@@ -45,8 +50,9 @@ This Docker image runs a Llama model on a serverless RunPod instance using the o
 | Max Workers | `1` |
 | Idle Timeout | `5` seconds |
 | FlashBoot | Checked/Enabled |
-| GPU Type(s) | Use the `Container Disk` section of step 3 to determine the smallest GPU that can load the entire 4 bit model. In our example's case, use 16 GB GPU. |
-
+| GPU Type(s) | Use the `Container Disk` section of step 3 to determine the smallest GPU that can load the entire 4-bit model. In our example's case, use a 16 GB GPU. The Container Disk can be made smaller if the model is stored on a Network Volume instead. |
+| (Optional) Network Volume | `airoboros-7b` |
+![Airoboros 70B Template Configuration Example](artifacts/yh_airoboros_70b_template_screenshot)
 ## Inference Usage
 See the `predict.py` file for an example. For convenience we also copy the code below.
 
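As a rough sanity check for the `GPU Type(s)` guidance in the endpoint table above: 4-bit quantized weights take about half a byte per parameter, so the VRAM floor can be estimated before picking a GPU. An illustrative calculation (not code from this repo):

```python
# Back-of-envelope VRAM needed just for 4-bit quantized weights.
# KV cache, activations, and fragmentation add several GB on top, so round up.
def weights_vram_gb(params_billion: float, bits: int = 4) -> float:
    return params_billion * 1e9 * (bits / 8) / 1024**3

print(f"7B:  ~{weights_vram_gb(7):.1f} GB of weights")   # ~3.3 GB
print(f"70B: ~{weights_vram_gb(70):.1f} GB of weights")  # ~32.6 GB
```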
@@ -57,12 +63,13 @@ from time import sleep
 import logging
 import argparse
 import sys
+import json
 
 endpoint_id = os.environ["RUNPOD_ENDPOINT_ID"]
 URI = f"https://api.runpod.ai/v2/{endpoint_id}/run"
 
 
-def run(prompt, stream=False):
+def run(prompt, params={}, stream=False):
     request = {
         'prompt': prompt,
         'max_new_tokens': 1800,
@@ -74,6 +81,8 @@ def run(prompt, stream=False):
         'stream': stream
     }
 
+    request.update(params)
+
     response = requests.post(URI, json=dict(input=request), headers = {
         "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
     })
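For reference, `request.update(params)` overlays the caller's generation parameters onto the defaults before the request is wrapped in `input` and posted. A sketch of the resulting payload for `--params_json '{"temperature": 0.3}'`, showing only the default keys visible in the snippet above (the real request carries a few more defaults):

```python
# Illustrative shape of what requests.post(URI, json=dict(input=request))
# sends after request.update(params); the values here are assumptions.
payload = {
    "input": {
        "prompt": "USER: ...\nASSISTANT: ",
        "max_new_tokens": 1800,  # default from run()
        "temperature": 0.3,      # supplied via --params_json
        "stream": False,
    }
}
```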
@@ -85,35 +94,35 @@
 
 
 def stream_output(task_id, stream=False):
-    try:
-        url = f"https://api.runpod.ai/v2/{endpoint_id}/stream/{task_id}"
-        headers = {
-            "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
-        }
+    # try:
+    url = f"https://api.runpod.ai/v2/{endpoint_id}/stream/{task_id}"
+    headers = {
+        "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
+    }
 
-        previous_output = ''
+    previous_output = ''
 
+    try:
         while True:
             response = requests.get(url, headers=headers)
             if response.status_code == 200:
                 data = response.json()
-                if stream:
-                    if len(data['stream']) > 0:
-                        new_output = data['stream'][0]['output']
-
-                        sys.stdout.write(new_output[len(previous_output):])
-                        sys.stdout.flush()
-                        previous_output = new_output
-                elif len(data['stream']) > 0:
-                    return data['stream'][0]['output']
+                if len(data['stream']) > 0:
+                    new_output = data['stream'][0]['output']
+
+                    sys.stdout.write(new_output[len(previous_output):])
+                    sys.stdout.flush()
+                    previous_output = new_output
 
                 if data.get('status') == 'COMPLETED':
+                    if not stream:
+                        return previous_output
                     break
 
             elif response.status_code >= 400:
-                logging.error(response.json())
+                print(response)
             # Sleep for 0.1 seconds between each request
-            sleep(0.1 if stream else 0.5)
+            sleep(0.1 if stream else 1)
     except Exception as e:
         print(e)
         cancel_task(task_id)
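The restructured loop tracks `previous_output` in both modes because each poll of the `/stream/{task_id}` endpoint returns the cumulative output so far: streaming mode prints only the new suffix, and non-streaming mode returns the final accumulation once the status is `COMPLETED`. The delta-printing idea in isolation, with made-up poll results:

```python
import sys

# Each poll returns the cumulative output, so print only what is new.
polls = ["The pati", "The patient shou", "The patient should rest."]
previous_output = ""
for full_output in polls:
    sys.stdout.write(full_output[len(previous_output):])
    sys.stdout.flush()
    previous_output = full_output
print()
```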
@@ -131,6 +140,7 @@ def cancel_task(task_id):
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Runpod AI CLI')
     parser.add_argument('-s', '--stream', action='store_true', help='Stream output')
+    parser.add_argument('-p', '--params_json', type=str, help='JSON string of generation params')
 
     prompt = """Given the following clinical notes, what tests, diagnoses, and recommendations should the I give? Provide your answer as a detailed report with labeled sections "Diagnostic Tests", "Possible Diagnoses", and "Patient Recommendations".
 
@@ -143,7 +153,13 @@ if __name__ == '__main__':
 -fh:father had MI recently,mother has thyroid dz
 -sh:non-smoker,mariguana 5-6 months ago,3 beers on the weekend, basketball at school
 -sh:no std,no other significant medical conditions."""
-    print(run(prompt, stream=parser.parse_args().stream))
+    args = parser.parse_args()
+    params = json.loads(args.params_json) if args.params_json else {}
+    import time
+    start = time.time()
+    print(run(prompt, params=params, stream=args.stream))
+    print("Time taken: ", time.time() - start, " seconds")
+
 ```
 
 Run the above code using the following command in terminal with the RunPod endpoint id assigned to your endpoint in step 5.
@@ -152,6 +168,6 @@ RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.p
 ```
 To run with streaming enabled, use the `--stream` option. To set generation parameters, use the `--params_json` option to pass a JSON string of parameters:
 ```bash
-RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.py --stream --params_json '{"temperature": 0.9, "max_new_tokens": 2048}'
+RUNPOD_AI_API_KEY='**************' RUNPOD_ENDPOINT_ID='*******' python predict.py --params_json '{"temperature": 0.3, "max_tokens": 1000, "prompt_prefix": "USER: ", "prompt_suffix": "ASSISTANT: "}'
 ```
 You can generate the API key [here](https://www.runpod.io/console/serverless/user/settings) under API Keys.
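The same helpers can also be driven from another script instead of the CLI. A minimal sketch, assuming `predict.py` is on the import path, both environment variables are already exported, and the prompt and parameter values shown are illustrative:

```python
# predict.py reads RUNPOD_ENDPOINT_ID at import time, so make sure the
# environment variables are exported before this import.
from predict import run

report = run(
    "USER: Summarize the clinical notes.\nASSISTANT: ",
    params={"temperature": 0.3, "max_new_tokens": 512},  # illustrative values
)
print(report)
```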
artifacts: 3 new screenshot images (241 KB, 220 KB, 99.7 KB); binary files not shown

predict.py

Lines changed: 20 additions & 16 deletions
@@ -35,34 +35,35 @@ def run(prompt, params={}, stream=False):
 
 
 def stream_output(task_id, stream=False):
-    try:
-        url = f"https://api.runpod.ai/v2/{endpoint_id}/stream/{task_id}"
-        headers = {
-            "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
-        }
+    # try:
+    url = f"https://api.runpod.ai/v2/{endpoint_id}/stream/{task_id}"
+    headers = {
+        "Authorization": f"Bearer {os.environ['RUNPOD_AI_API_KEY']}"
+    }
 
-        previous_output = ''
+    previous_output = ''
 
+    try:
         while True:
             response = requests.get(url, headers=headers)
             if response.status_code == 200:
                 data = response.json()
-                if stream:
-                    if len(data['stream']) > 0:
-                        new_output = data['stream'][0]['output']
-
-                        sys.stdout.write(new_output[len(previous_output):])
-                        sys.stdout.flush()
-                        previous_output = new_output
-                return data['stream'][0]['output']
+                if len(data['stream']) > 0:
+                    new_output = data['stream'][0]['output']
+
+                    sys.stdout.write(new_output[len(previous_output):])
+                    sys.stdout.flush()
+                    previous_output = new_output
 
                 if data.get('status') == 'COMPLETED':
+                    if not stream:
+                        return previous_output
                     break
 
             elif response.status_code >= 400:
-                logging.error(response.json())
+                print(response)
             # Sleep for 0.1 seconds between each request
-            sleep(0.1 if stream else 0.5)
+            sleep(0.1 if stream else 1)
     except Exception as e:
         print(e)
         cancel_task(task_id)
@@ -95,4 +96,7 @@ def cancel_task(task_id):
 -sh:no std,no other significant medical conditions."""
     args = parser.parse_args()
     params = json.loads(args.params_json) if args.params_json else {}
+    import time
+    start = time.time()
     print(run(prompt, params=params, stream=args.stream))
+    print("Time taken: ", time.time() - start, " seconds")
