Skip to content

Commit

Permalink
fix KeyError bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mrwyattii committed Nov 8, 2023
1 parent 38029b4 commit 352c9be
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions mii/batching/ragged_batching.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,8 +717,11 @@ def put_request(self,
kwargs: Dict,
session_id: Union[str,
None] = None) -> int:
if not self.is_rank_0:
return
# TODO: We should avoid any request/response work with non-rank 0, but
# this requires some refactoring how we do the put and request in
# `ModelResponse`
#if not self.is_rank_0:
# return
if self.stop_thread:
raise RuntimeError("The request queue was shutdown.")

Expand All @@ -733,9 +736,18 @@ def put_request(self,
for r in self.make_request(uid, input_tokens, kwargs):
self.request_queue.put(r)

# Temporary hack to avoid non-rank 0 processes not shutting down. See related TODO above.
if self.is_rank_0:
self.request_queue.empty()

return uid

def get_response(self, uid: int) -> List[Response]:
# TODO: We should avoid any request/response work with non-rank 0, but
# this requires some refactoring how we do the put and request in
# `ModelResponse`
#if not self.is_rank_0:
# return
result = self.result_queues[uid].get()
generated_token_ids = result[0]
if len(generated_token_ids) == 0:
Expand Down Expand Up @@ -783,4 +795,4 @@ def destroy_session(self,
stop_criterion=None,
stream=None,
))
self.uids.remove(uid)
self.uids.remove(uid)

0 comments on commit 352c9be

Please sign in to comment.