diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml
index a3d30324c..0d7889a5f 100644
--- a/.github/workflows/pr-test.yml
+++ b/.github/workflows/pr-test.yml
@@ -50,7 +50,7 @@ jobs:
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 0 --range-end 5
+          python3 run_suite.py --suite minimal --range-begin 0 --range-end 6
 
   unit-test-backend-part-2:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -67,7 +67,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 5 --range-end 14
+          python3 run_suite.py --suite minimal --range-begin 6 --range-end 15
 
   unit-test-backend-part-3:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -84,7 +84,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 14 --range-end 23
+          python3 run_suite.py --suite minimal --range-begin 15 --range-end 24
 
   unit-test-backend-part-4:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
@@ -101,7 +101,7 @@
         timeout-minutes: 25
         run: |
           cd test/srt
-          python3 run_suite.py --suite minimal --range-begin 23
+          python3 run_suite.py --suite minimal --range-begin 24
 
   unit-test-backend-2-gpu-part-1:
     if: github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request'
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 971809124..f7d55ed9b 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -231,6 +231,7 @@ def __init__(
         self.tokenizer = None
         self.finished_reason = None
         self.stream = False
+        self.to_abort = False  # set by Scheduler.abort_request(); consumed in check_finished()
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -368,6 +369,10 @@ def check_finished(self):
         if self.finished():
             return
 
+        if self.to_abort:
+            self.finished_reason = FINISH_ABORT()
+            return
+
         if len(self.output_ids) >= self.sampling_params.max_new_tokens:
             self.finished_reason = FINISH_LENGTH(
                 length=self.sampling_params.max_new_tokens
diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index a327f37a2..c7e831811 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -579,6 +579,8 @@ def handle_generate_request(
                 "Image request length is longer than the KV cache pool size or "
                 "the max context length aborting because you cannot truncate the image embeds"
             )
+            req.image_inputs = None
+            req.origin_input_ids = [0]
             req.sampling_params.max_new_tokens = 0
             self.waiting_queue.append(req)
             return
@@ -1350,13 +1352,15 @@ def abort_request(self, recv_req: AbortReq):
 
         if to_del is not None:
             del self.waiting_queue[to_del]
+            logger.debug(f"Abort queued request. {req.rid=}")
+            return
 
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
-                    req.finished_reason = FINISH_ABORT()
-                    self.tree_cache.cache_finished_req(req)
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True  # defer the actual abort to Req.check_finished()
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 3f6cce23d..a1646fb5f 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -677,8 +677,14 @@ def run_and_check_memory_leak(
     enable_mixed_chunk,
     disable_overlap,
     chunked_prefill_size,
+    assert_has_abort,
 ):
-    other_args = ["--chunked-prefill-size", str(chunked_prefill_size)]
+    other_args = [
+        "--chunked-prefill-size",
+        str(chunked_prefill_size),
+        "--log-level",
+        "debug",
+    ]
     if disable_radix_cache:
         other_args += ["--disable-radix-cache"]
     if enable_mixed_chunk:
@@ -723,14 +729,19 @@
     # Assert success
     has_new_server = False
     has_leak = False
+    has_abort = False
     for line in output_lines:
         if "The server is fired" in line:
             has_new_server = True
         if "leak" in line:
             has_leak = True
+        if "Abort" in line:
+            has_abort = True
 
     assert has_new_server
     assert not has_leak
+    if assert_has_abort:
+        assert has_abort
 
 
 def run_mmlu_test(
@@ -761,6 +772,7 @@ def workload_func(base_url, model):
         enable_mixed_chunk,
         disable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
 
 
@@ -800,4 +812,5 @@ def run_one(_):
         enable_mixed_chunk,
         enable_overlap,
         chunked_prefill_size,
+        assert_has_abort=False,
     )
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 3f55eb25f..c04a1671e 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -10,6 +10,7 @@
         "models/test_lora.py",
         "models/test_reward_models.py",
         "sampling/penaltylib",
+        "test_abort.py",
         "test_chunked_prefill.py",
         "test_double_sparsity.py",
         "test_embedding_openai_server.py",
diff --git a/test/srt/test_abort.py b/test/srt/test_abort.py
new file mode 100644
index 000000000..ae27d83a8
--- /dev/null
+++ b/test/srt/test_abort.py
@@ -0,0 +1,54 @@
+import multiprocessing
+import time
+import unittest
+from concurrent.futures import ThreadPoolExecutor
+
+import requests
+
+from sglang.test.test_utils import run_and_check_memory_leak
+
+
+class TestAbort(unittest.TestCase):
+    def workload_func(self, base_url, model):
+        def process_func():
+            def run_one(_):
+                prompt = """
+                System: You are a helpful assistant.
+                User: What is the capital of France?
+                Assistant: The capital of France is
+                """
+
+                response = requests.post(
+                    f"{base_url}/generate",
+                    json={
+                        "text": prompt,
+                        "sampling_params": {
+                            "temperature": 0,
+                            "max_new_tokens": 2048,
+                        },
+                    },
+                )
+                ret = response.json()  # drain the response body
+
+            with ThreadPoolExecutor(16) as executor:
+                list(executor.map(run_one, list(range(16))))
+
+        p = multiprocessing.Process(target=process_func)
+        p.start()
+        time.sleep(0.5)
+        p.terminate()  # kill the client mid-generation to force server-side aborts
+        time.sleep(10)  # give the server time to process the aborts
+
+    def test_memory_leak(self):
+        run_and_check_memory_leak(
+            self.workload_func,
+            disable_radix_cache=False,
+            enable_mixed_chunk=False,
+            disable_overlap=False,
+            chunked_prefill_size=8192,
+            assert_has_abort=True,
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()