
Commit 04b1126

[None][feat] Hang detection for executor loop and worker. (#10480)
Signed-off-by: Yuxian Qiu <[email protected]>
1 parent 50c22b8 commit 04b1126

File tree

6 files changed: +182 -12 lines changed

tensorrt_llm/_common.py
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
tensorrt_llm/_torch/pyexecutor/hang_detector.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/_utils.py
tensorrt_llm/executor/worker.py

tensorrt_llm/_common.py

Lines changed: 15 additions & 1 deletion
@@ -16,6 +16,7 @@
 import ctypes
 import os
 import platform
+import threading
 import time
 from functools import wraps
 from pathlib import Path
@@ -34,7 +35,7 @@
 else:
     Network = None
 
-from ._utils import str_dtype_to_trt
+from ._utils import print_all_stacks, str_dtype_to_trt
 from .bindings import MpiComm
 from .logger import logger
 from .plugin import _load_plugin_lib
@@ -82,6 +83,19 @@ def _init(log_level: object = None) -> None:
 
     MpiComm.local_init()
 
+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(os.getenv("TRTLLM_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks, daemon=True)
+        print_stacks_thread.start()
+
     logger.info("TensorRT LLM inited.")

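Usage note: the period is read once, inside _init(), which appears to run when tensorrt_llm is first imported (the "TensorRT LLM inited." log above marks that point), so the variable must be set before the import. A minimal sketch under that assumption; the 300-second period is arbitrary:

import os

os.environ["TRTLLM_PRINT_STACKS_PERIOD"] = "300"  # must be set before the import

import tensorrt_llm  # _init() reads the variable and starts the daemon dump thread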
tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 18 additions & 7 deletions
@@ -14,6 +14,7 @@
 from tensorrt_llm.mapping import CpType
 
 from ..distributed import Distributed
+from .hang_detector import HangDetector
 from .llm_request import (ExecutorRequest, LlmRequest,
                           executor_request_to_llm_request)
@@ -47,10 +48,17 @@ def is_control_request(self):
 class ExecutorRequestQueue:
     """Handles fetching and processing of new requests from the request queue."""
 
-    def __init__(self, dist: Distributed, enable_attention_dp: bool,
-                 max_batch_size: int, max_beam_width: int,
-                 max_num_active_requests: int, enable_iter_perf_stats: bool,
-                 batch_wait_timeout_ms: float):
+    def __init__(
+        self,
+        dist: Distributed,
+        enable_attention_dp: bool,
+        max_batch_size: int,
+        max_beam_width: int,
+        max_num_active_requests: int,
+        enable_iter_perf_stats: bool,
+        batch_wait_timeout_ms: float,
+        hang_detector: Optional[HangDetector] = None,
+    ):
         self.dist = dist
         self.request_queue: queue.Queue[RequestQueueItem] = queue.Queue()
         self.waiting_queue: deque[RequestQueueItem] = deque()
@@ -66,6 +74,7 @@ def __init__(self, dist: Distributed, enable_attention_dp: bool,
         self.active = True
         self.batch_wait_timeout_ms = batch_wait_timeout_ms
         self.send_requests_handler = None
+        self.hang_detector = hang_detector or HangDetector()
 
         # State tracking
         self.num_fetch_requests = 0
@@ -303,7 +312,8 @@ def _fetch_and_process_requests(
                 self.request_accumulated.clear()
             # Reset timeout to 0 to avoid hanging when no new requests are available
             timeout = datetime.timedelta(0)
-        new_requests.extend(self._get_from_request_queue(timeout))
+        with self.hang_detector.pause():
+            new_requests.extend(self._get_from_request_queue(timeout))
 
         # Broadcast requests and handle Python objects
         new_requests, py_request_objects = self._handle_request_broadcasting(
@@ -477,8 +487,9 @@ def _handle_request_broadcasting(self,
             # Preserve original `new_requests` on rank 0
             _ = self._broadcast_new_requests(new_requests, py_request_objects)
         else:
-            new_requests, py_request_objects = self._broadcast_new_requests(
-                new_requests, py_request_objects)
+            with self.hang_detector.pause():
+                new_requests, py_request_objects = self._broadcast_new_requests(
+                    new_requests, py_request_objects)
 
         return new_requests, py_request_objects
tensorrt_llm/_torch/pyexecutor/hang_detector.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
+import asyncio
+import threading
+from contextlib import contextmanager
+from typing import Callable, Optional
+
+from tensorrt_llm._utils import print_all_stacks
+from tensorrt_llm.logger import logger
+
+
+class HangDetector:
+    def __init__(
+        self, timeout: Optional[int] = None, on_detected: Optional[Callable[[], None]] = None
+    ):
+        self.timeout = timeout if timeout is not None else 300
+        assert self.timeout > 0, "timeout must be greater than 0"
+        self.on_detected = on_detected or (lambda: None)
+        self.task = None
+        self.loop = None
+        self.loop_thread = None
+        self.lock = threading.Lock()
+        self.active = False
+        self._detected = False
+
+    def start(self):
+        """Enable hang detection."""
+
+        def run_loop():
+            asyncio.set_event_loop(self.loop)
+            self.loop.run_forever()
+
+        self.active = True
+        self.loop = asyncio.new_event_loop()
+        self.loop_thread = threading.Thread(target=run_loop, daemon=True, name="hang_detector_loop")
+        self.loop_thread.start()
+
+    async def _detect_hang(self):
+        await asyncio.sleep(self.timeout)
+        with self.lock:
+            self._detected = True
+        logger.error(f"Hang detected after {self.timeout} seconds.")
+        print_all_stacks()
+        self.on_detected()
+
+    def detected(self):
+        """Return True if hang is detected."""
+        with self.lock:
+            return self._detected
+
+    def checkpoint(self):
+        """Reset hang detection timer."""
+        self.cancel_task()
+        if self.active:
+            self.task = asyncio.run_coroutine_threadsafe(self._detect_hang(), self.loop)
+
+    def cancel_task(self):
+        """Cancel the hang detection task."""
+        if self.task is not None and not self.task.done():
+            self.task.cancel()
+        self.task = None
+
+    @contextmanager
+    def pause(self):
+        """Pause hang detection in scope."""
+        try:
+            self.cancel_task()
+            yield
+        finally:
+            self.checkpoint()
+
+    def stop(self):
+        """Stop hang detection."""
+        self.active = False
+        self.cancel_task()
+        if self.loop is not None:
+            # Cancel all pending tasks before stopping the loop
+            def cancel_all_tasks():
+                for task in asyncio.all_tasks(self.loop):
+                    if not task.done():
+                        task.cancel()
+                self.loop.call_soon(self.loop.stop)
+
+            self.loop.call_soon_threadsafe(cancel_all_tasks)
+
+            if self.loop_thread is not None and self.loop_thread.is_alive():
+                self.loop_thread.join()
+
+        self.loop = None
+        self.loop_thread = None
+
+    def __enter__(self):
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.stop()
+        return False
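Taken together, the lifecycle is: start() (or __enter__) spins up a private asyncio loop in a daemon thread; each checkpoint() replaces the pending _detect_hang timer; if no checkpoint arrives within timeout seconds, the timer fires, dumps every thread's stack, and invokes on_detected. A runnable sketch of that lifecycle; the 2-second timeout and the callback are illustrative, not defaults:

import time

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

def on_hang():
    print("loop looks stuck; requesting shutdown")

with HangDetector(timeout=2, on_detected=on_hang) as detector:
    for _ in range(3):
        detector.checkpoint()  # fast iterations keep resetting the timer
        time.sleep(0.5)

    detector.checkpoint()
    time.sleep(3)  # one slow "iteration": the timer fires after 2 seconds,
                   # logs all thread stacks, and calls on_hang()
    assert detector.detected()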

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 26 additions & 3 deletions
@@ -46,6 +46,7 @@
 from .guided_decoder import GuidedDecoder
 from .handle_additional_outputs import HandleAdditionalOutputs
 from .handle_logits import HandleLogits
+from .hang_detector import HangDetector
 from .kv_cache_connector import KvCacheConnectorManager
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
@@ -137,6 +138,7 @@ def __init__(self,
                  max_seq_len: Optional[int] = None,
                  peft_cache_config: Optional[PeftCacheConfig] = None,
                  virtual_memory_pools: Optional[dict] = None,
+                 hang_detection_timeout: Optional[int] = None,
                  execution_stream: Optional[torch.cuda.Stream] = None):
         super(PyExecutor, self).__init__()
         self.device_id = torch.cuda.current_device()
@@ -280,6 +282,15 @@ def __init__(self,
         self.adp_ctx_batching_wait_iters_count = 0
         self.batch_wait_iters_count = 0
 
+        def on_detected():
+            self._handle_errors(
+                f"Hang detected on rank {self.global_rank} in PyExecutor.")
+            self.shutdown_event.set()
+            self.is_shutdown = True
+
+        self.hang_detector = HangDetector(timeout=hang_detection_timeout,
+                                          on_detected=on_detected)
+
         # request fetcher initialization
         self._set_global_steady_clock_offset()
         self.executor_request_queue = ExecutorRequestQueue(
@@ -290,6 +301,7 @@ def __init__(self,
             max_num_active_requests=self.max_num_active_requests,
             enable_iter_perf_stats=self.enable_iter_perf_stats,
             batch_wait_timeout_ms=self.batch_wait_timeout_ms,
+            hang_detector=self.hang_detector,
         )
         self.executor_request_queue.set_exclude_last_generation_logits(
             self.disable_overlap_scheduler, self.dist.pp_size)
@@ -476,6 +488,14 @@ def shutdown(self):
         """
         self.executor_request_queue.enqueue_shutdown_request()
         self.shutdown_event.wait()
+        if self.hang_detector.detected():
+            # Early return here to avoid waiting for hanging threads.
+            # Since `on_detected` has sent the error message as a response,
+            # this worker will be asked to shut down immediately.
+            # Since the whole process shuts down after this `shutdown` call,
+            # all threads and memory pools will be freed properly.
+            logger.error("Hang detected, shutting down immediately.")
+            return
         self.worker_thread.join()
         self.worker_started = False
         for manager in self.resource_manager.resource_managers.values():
@@ -960,10 +980,11 @@ def _executor_loop_pp(self):
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
         microbatch_id = 0
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1349,11 +1370,12 @@ def _executor_loop(self):
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             sample_state = None
             iter_start_time = time.time()
             iter_stats = None
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
@@ -1551,13 +1573,14 @@ def _executor_loop_overlap(self):
         torch.cuda.set_device(self.device_id)
         # ensure the context is created, otherwise, some MPI calls will fail.
         CUASSERT(cudart.cudaSetDevice(self.device_id))
-        with self._profiler() as profile_step:
+        with self._profiler() as profile_step, self.hang_detector:
             iter_start_time = time.time()
             iter_stats = None
             target_inputs = None
             previous_tensors_device = None
             can_forward = False if self.benchmark_req_queues_size > 0 and self.kv_cache_transceiver else True
             while True:
+                self.hang_detector.checkpoint()
                 profile_step()
                 if self.enable_iter_perf_stats:
                     iter_start_time = time.time()
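All three executor loops share one pattern: checkpoint() at the top of every iteration means a hang is defined as a single iteration exceeding the timeout, and the on_detected callback turns detection into an error report plus a shutdown signal so anything blocked on the shutdown event is unblocked. A hedged sketch of that wiring; the event and error list stand in for PyExecutor's internals and are not its real attributes:

import threading

from tensorrt_llm._torch.pyexecutor.hang_detector import HangDetector

shutdown_event = threading.Event()
errors = []

def on_detected():
    errors.append("Hang detected in executor loop.")  # report the error
    shutdown_event.set()                              # unblock shutdown waiters

with HangDetector(timeout=1, on_detected=on_detected) as detector:
    detector.checkpoint()           # arm the timer for this "iteration"
    shutdown_event.wait(timeout=5)  # the iteration stalls; fires after ~1s

assert errors and shutdown_event.is_set()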

tensorrt_llm/_utils.py

Lines changed: 9 additions & 0 deletions
@@ -21,8 +21,10 @@
 import os
 import socket
 import struct
+import sys
 import tempfile
 import trace
+import traceback
 import weakref
 from contextlib import contextmanager
 from enum import EnumMeta
@@ -761,6 +763,13 @@ def is_sm_100f(sm_version=None):
     return sm_version == 100 or sm_version == 103
 
 
+def print_all_stacks():
+    """Print stack traces for all threads"""
+    for thread_id, frame in sys._current_frames().items():
+        logger.error(f"Thread {thread_id} stack trace:\n" +
+                     "".join(traceback.format_stack(frame)))
+
+
 def is_trace_enabled(env_var: str):
     value = os.environ.get(env_var, "-1")
     if value == "ALL":
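sys._current_frames() is CPython-specific: it returns a snapshot mapping each thread id to that thread's topmost frame, and traceback.format_stack(frame) renders the call stack leading to it. A small self-contained demonstration of the same mechanism, using only the standard library:

import sys
import threading
import time
import traceback

def sleeper():
    time.sleep(10)  # a thread "stuck" in a blocking call

threading.Thread(target=sleeper, daemon=True, name="sleeper").start()
time.sleep(0.1)  # give the thread time to enter time.sleep

for thread_id, frame in sys._current_frames().items():
    print(f"Thread {thread_id} stack trace:\n" +
          "".join(traceback.format_stack(frame)))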

tensorrt_llm/executor/worker.py

Lines changed: 18 additions & 1 deletion
@@ -1,5 +1,7 @@
 import gc
 import os
+import threading
+import time
 import traceback
 from concurrent.futures import ProcessPoolExecutor
 from pathlib import Path
@@ -9,7 +11,7 @@
 
 from tensorrt_llm.logger import logger
 
-from .._utils import mpi_comm, mpi_rank
+from .._utils import mpi_comm, mpi_rank, print_all_stacks
 from ..bindings import executor as tllm
 from ..builder import Engine
 from ..llmapi.llm_args import BaseLlmArgs
@@ -153,6 +155,21 @@ def worker_main(
         hmac_key: Optional[bytes] = None,
 ) -> None:
 
+    def _print_stacks():
+        counter = 0
+        while True:
+            time.sleep(print_stacks_period)
+            counter += 1
+            logger.error(f"Printing stacks {counter} times")
+            print_all_stacks()
+
+    print_stacks_period = int(
+        os.getenv("TRTLLM_WORKER_PRINT_STACKS_PERIOD", "-1"))
+    if print_stacks_period > 0:
+        print_stacks_thread = threading.Thread(target=_print_stacks,
+                                               daemon=True)
+        print_stacks_thread.start()
+
     mpi_comm().barrier()
 
     if llm_args is not None and llm_args.env_overrides:
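Unlike the hook in _init(), this watchdog is gated by a separate variable, TRTLLM_WORKER_PRINT_STACKS_PERIOD, and is read inside worker_main, i.e. in each worker process, before the MPI barrier. One way to enable it, assuming workers inherit the launcher's environment (true for typical spawn/fork launch flows; an MPI launcher may need explicit export):

import os

# Set in the launching process; worker processes that inherit the environment
# will start the daemon stack-dump thread at the top of worker_main().
os.environ["TRTLLM_WORKER_PRINT_STACKS_PERIOD"] = "600"  # period is arbitrary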
