use tid instead of uid in get_response
mrwyattii committed Nov 9, 2023
1 parent 7e7301f · commit 26db3a4
Showing 2 changed files with 77 additions and 59 deletions.
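
In short: the result queues move from being keyed by request uid to being keyed by the calling thread's id (threading.get_ident()), and each queued result now carries its uid as the first tuple element, so get_response can hand back the uid together with the Response. A minimal sketch of the pattern, not MII code, with the queue payload shape inferred from the diff below:

import queue
import threading

result_queues = {}  # one queue per calling thread, keyed by threading.get_ident()

def put_request(uid, prompt):
    tid = threading.get_ident()
    if tid not in result_queues:
        result_queues[tid] = queue.Queue()
    # the batching loop is expected to later push something like:
    #   result_queues[tid].put((uid, tokens, prompt_len, gen_len, finish_reason))
    return uid

def get_response():
    tid = threading.get_ident()
    # the uid travels with the result, so the caller can match and flush it
    uid, tokens, prompt_len, gen_len, finish_reason = result_queues[tid].get()
    return uid, tokens
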
mii/batching/ragged_batching.py (117 changes: 72 additions & 45 deletions)
@@ -11,7 +11,7 @@
from collections import deque, defaultdict
from dataclasses import dataclass, asdict, field
from functools import cached_property
from typing import Dict, Tuple, List, Any, Iterator, Union, DefaultDict, Set
from typing import Dict, Tuple, List, Any, Iterator, Union, DefaultDict
from typing_extensions import Self

import torch
@@ -131,6 +131,7 @@ def from_msg(msg: Dict[str, int]) -> Self:

@dataclass
class RaggedRequest:
tid: int
uid: int
input_tokens: torch.Tensor
prompt_length: int
@@ -457,6 +458,7 @@ def _generate_output(self, r: RaggedRequest) -> bool:
outputs = []
if r.stream:
outputs.append((
r.uid,
[r.next_token],
r.prompt_length,
r.num_generated_tokens,
@@ -469,13 +471,14 @@ def _generate_output(self, r: RaggedRequest) -> bool:
output_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens],
dim=0)
outputs.append((
r.uid,
output_tokens,
r.prompt_length,
r.num_generated_tokens,
r.finish_reason,
))
for output in outputs:
self.result_queues[r.uid].put_nowait(output)
self.result_queues[r.tid].put_nowait(output)

def _do_schedule_requests(self, requests: List[RaggedRequest]) -> None:

@@ -558,6 +561,7 @@ def schedule_requests(self) -> None:
[r for r in self.buffer if id(r) not in scheduled_requests_ids])

def make_request(self,
tid: int,
uid: int,
input_tokens: torch.Tensor,
kwargs: Dict) -> RaggedRequest:
@@ -603,6 +607,7 @@ def make_request(self,
assert kwargs == {}, f"Unknown keyword arguments {kwargs}"

return RaggedRequest(
tid=tid,
uid=uid,
input_tokens=input_tokens,
prompt_length=prompt_length,
@@ -634,60 +639,62 @@ def flush(self, uids: List[int]) -> None:


class MIIPipeline(RaggedBatchBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tid = threading.get_ident()

def __call__(self, inputs: Union[str, List[str]], **kwargs) -> ResponseBatch:
if isinstance(inputs, str):
inputs = [inputs]
outputs: ResponseBatch = ResponseBatch([])
uids: List[int] = list(range(len(inputs)))
flushed_uids: Set[int] = set()
uids_running: List[int] = list(range(len(inputs)))
uids_complete_order: List[int] = []

for uid, input in zip(uids, inputs):
for uid, input in zip(uids_running, inputs):
request_kwargs = kwargs.copy()
self._enqueue_request(uid, input, request_kwargs)
self.schedule_requests()
self._put_request(uid, input, request_kwargs)

while self.scheduled_requests:
self.generate()
# Make sure we flush uids as they are done generating
for uid, result_queue in self.result_queues.items():
if (not result_queue.empty()) and uid not in flushed_uids:
flushed_uids.add(uid)
self.request_queue.put_nowait(
RaggedRequest(
uid=uid,
input_tokens=None,
prompt_length=None,
seq_length=None,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
))
self.schedule_requests()

if self.is_rank_0:
while uids_running:
self.generate()
while not self.result_queues[self.tid].empty():
uid, response = self._get_response()
outputs.append(response)
self._flush_uid(uid)
uids_complete_order.append(uids_running.index(uid))
uids_running.remove(uid)
# To kick ranks 1 -> n out of the while loop
self._bcast_requests(force=True)
else:
while self.scheduled_requests:
self.generate()

for uid in range(len(inputs)):
outputs.append(self._dequeue_response(uid))
outputs = ResponseBatch([
r for idx,
r in sorted(zip(uids_complete_order,
outputs),
key=lambda pair: pair[0])
])

if self.model_config.all_rank_output:
outputs = self._bcast_responses(outputs)

return outputs

def _enqueue_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None:
self.result_queues[uid] = queue.Queue()
def _put_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None:
self.result_queues[self.tid] = queue.Queue()
input_tokens = self.tokenizer.encode(input)
request = self.make_request(uid, input_tokens, kwargs)
request = self.make_request(self.tid, uid, input_tokens, kwargs)
self.request_queue.put(request)

def _dequeue_response(self, uid: int) -> Response:
result = self.result_queues[uid].get()
generated_tokens = self.tokenizer.decode(result[0])
response = self.make_response(generated_tokens, result[1], result[2], result[3])
return response
def _get_response(self) -> Tuple[int, Response]:
result = self.result_queues[self.tid].get()
uid = result[0]
generated_tokens = self.tokenizer.decode(result[1])
response = self.make_response(generated_tokens, result[2], result[3], result[4])
return uid, response

def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch:
if self.is_rank_0:
@@ -700,6 +707,22 @@ def _bcast_responses(self, responses: ResponseBatch) -> ResponseBatch:
responses = ResponseBatch([Response.from_msg(msg) for msg in data_dicts])
return responses

def _flush_uid(self, uid: int) -> None:
if self.is_rank_0:
self.request_queue.put_nowait(
RaggedRequest(
tid=None,
uid=uid,
input_tokens=None,
prompt_length=None,
seq_length=None,
max_length=None,
max_new_tokens=None,
last_in_prompt=None,
post_processing=None,
stream=None,
))


class MIIAsyncPipeline(RaggedBatchBase):
def __init__(self, *args, **kwargs):
@@ -742,17 +765,18 @@ def put_request(self, prompt: str, kwargs: Dict) -> int:

uid = self._get_uid()

with self.lock:
if uid not in self.result_queues:
self.result_queues[uid] = queue.Queue()

# Temporary hack to avoid non-rank 0 processes not shutting down. See
# related TODO above.
if not self.is_rank_0:
return uid

tid = threading.get_ident()
with self.lock:
if tid not in self.result_queues:
self.result_queues[tid] = queue.Queue()

input_tokens = self.tokenizer.encode(prompt)
request = self.make_request(uid, input_tokens, kwargs)
request = self.make_request(tid, uid, input_tokens, kwargs)
self.request_queue.put(request)

return uid
@@ -762,7 +786,7 @@ def is_response_ready(self, uid: int) -> bool:
return True
return not self.result_queues[uid].empty()

def get_response(self, uid: int) -> List[Response]:
def get_response(self) -> Tuple[int, Response]:
# TODO: We should avoid any request/response work with non-rank 0, but
# this requires some refactoring how we do the put and request in
# `ModelResponse`
@@ -771,14 +795,16 @@ def get_response(self, uid: int) -> List[Response]:
prompt_length=None,
generated_length=None,
finish_reason=None)
result = self.result_queues[uid].get()
generated_token_ids = result[0]
tid = threading.get_ident()
result = self.result_queues[tid].get()
uid = result[0]
generated_token_ids = result[1]
if len(generated_token_ids) == 0:
generated_text = ""
else:
generated_text = self.tokenizer.decode(generated_token_ids)
response = self.make_response(generated_text, result[1], result[2], result[3])
return response
response = self.make_response(generated_text, result[2], result[3], result[4])
return uid, response

def start(self) -> None:
self.thread = threading.Thread(target=self, daemon=True)
@@ -797,6 +823,7 @@ def flush_uid(self, uid: int) -> None:
if self.is_rank_0:
self.request_queue.put_nowait(
RaggedRequest(
tid=None,
uid=uid,
input_tokens=None,
prompt_length=None,
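
Taken together, the async API now reads like this from the caller's side: put_request still hands back a uid, but get_response takes no argument, blocks on the calling thread's own queue, and returns the uid alongside the Response. A hedged usage sketch (assumes `pipeline` is an already started MIIAsyncPipeline with a single outstanding request on this thread):

uid = pipeline.put_request("Hello, world", {})
finished_uid, response = pipeline.get_response()  # blocks on this thread's result queue
assert finished_uid == uid                        # only one request in flight here
pipeline.flush_uid(finished_uid)                  # free the uid so queued requests can run
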
mii/grpc_related/modelresponse_server.py (19 changes: 5 additions & 14 deletions)
@@ -77,20 +77,11 @@ def _run_inference(self, method_name, request_proto):
# Get responses from the pipeline as they are ready, flush finished uids
# so new requests can be processed
while uids_running:
for uid in uids_running:
# If the response is not ready, move to next uid
if not self.inference_pipeline.is_response_ready(uid):
continue
if self.inference_pipeline.is_rank_0:
print("RESPONSE READY")

# If a response is ready, get it and flush the uid so any queued requests can be processed
response = self.inference_pipeline.get_response(uid)
responses.append(response)
self.inference_pipeline.flush_uid(uid)
uids_complete_order.append(uids_running.index(uid))
uids_running.remove(uid)
time.sleep(0.001) # So we don't spin too much
uid, response = self.inference_pipeline.get_response()
responses.append(response)
self.inference_pipeline.flush_uid(uid)
uids_complete_order.append(uids_running.index(uid))
uids_running.remove(uid)
end = time.time()

# Sort responses in the order of prompts
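
Keying the result queues by thread id is what lets the handler above drop its poll-and-sleep loop: requests are put and retrieved on the caller's own thread, so each handler blocks only on its own queue and concurrent requests served from different threads cannot consume one another's results. A toy illustration of that isolation, not MII code and with all names invented:

import queue
import threading

result_queues = {}            # tid -> queue, one per handler thread
registry_lock = threading.Lock()

def fake_batching_loop_reply(tid, uid):
    # stands in for the batching loop pushing a finished result to the right thread
    result_queues[tid].put((uid, f"response-for-{uid}"))

def handler(uid):
    tid = threading.get_ident()
    with registry_lock:
        result_queues[tid] = queue.Queue()
    fake_batching_loop_reply(tid, uid)
    got_uid, _response = result_queues[tid].get()
    assert got_uid == uid     # a handler only ever sees results for its own request

threads = [threading.Thread(target=handler, args=(i,)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
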
