
Commit 6dfb215: Memory estimate fixes (#720)
1 parent ddd1c2d

File tree: 5 files changed, +43 -25 lines

.pre-commit-config.yaml (+15, -19)
@@ -1,20 +1,16 @@
 repos:
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v3.4.0  # Use the latest revision
-    hooks:
-      - id: trailing-whitespace
-      - id: end-of-file-fixer
-      - id: check-yaml
-  - repo: https://github.com/psf/black
-    rev: 24.2.0
-    hooks:
-      - id: black
-        name: Format code
-        args:
-          - --line-length=120
-  - repo: https://github.com/pycqa/flake8
-    rev: 7.0.0
-    hooks:
-      - id: flake8
-        name: flake8
-        args: ['--max-line-length=120']
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v3.4.0  # Use the latest revision
+    hooks:
+      - id: trailing-whitespace
+      - id: end-of-file-fixer
+      - id: check-yaml
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version.
+    rev: v0.8.3
+    hooks:
+      # Run the linter.
+      - id: ruff
+        args: [--fix]
+      # Run the formatter.
+      - id: ruff-format

launcher/src/main.rs (+14, -1)
@@ -353,6 +353,10 @@ struct Args {
     #[clap(default_value = "64", long, env)]
     compile_max_rank: usize,

+    // The initial batch size for model CUDA compilations
+    #[clap(default_value = "32", long, env)]
+    compile_batch_size: usize,
+
     /// The number of speculative tokens to generate in the model per step.
     /// Defaults to 0, meaning no speculative decoding.
     #[clap(long, env)]

@@ -633,7 +637,7 @@ struct Args {
     #[clap(long, env)]
     disable_sgmv: bool,

-    #[clap(default_value = "0.8", long, env)]
+    #[clap(default_value = "0.9", long, env)]
     memory_wiggle_room: f32,
 }

@@ -654,6 +658,7 @@ fn shard_manager(
     compile: bool,
     compile_max_batch_size: usize,
     compile_max_rank: usize,
+    compile_batch_size: usize,
     speculative_tokens: Option<usize>,
     speculation_max_batch_size: usize,
     preloaded_adapter_ids: Vec<String>,

@@ -832,6 +837,12 @@ fn shard_manager(
         compile_max_rank.to_string().into(),
     ));

+    // Compile initial batch size
+    envs.push((
+        "LORAX_COMPILE_BATCH_SIZE".into(),
+        compile_batch_size.to_string().into(),
+    ));
+
     // Speculative decoding max batch size
     envs.push((
         "LORAX_SPECULATION_MAX_BATCH_SIZE".into(),

@@ -1294,6 +1305,7 @@ fn spawn_shards(
     let compile = args.compile;
     let compile_max_batch_size = args.compile_max_batch_size;
     let compile_max_rank = args.compile_max_rank;
+    let compile_batch_size = args.compile_batch_size;
     let speculative_tokens = args.speculative_tokens;
     let speculation_max_batch_size = args.speculation_max_batch_size;
     let preloaded_adapter_ids = args.preloaded_adapter_ids.clone();

@@ -1325,6 +1337,7 @@ fn spawn_shards(
         compile,
         compile_max_batch_size,
         compile_max_rank,
+        compile_batch_size,
         speculative_tokens,
         speculation_max_batch_size,
         preloaded_adapter_ids,
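For context, the new flag is handed to each shard process as an environment variable and read back on the Python side (see graph.py at the end of this commit). A minimal sketch of that handoff, using a placeholder shard entry point rather than the real shard command:

import os
import subprocess

compile_batch_size = 32  # the launcher's default for --compile-batch-size

# Launcher side: forward the value to the shard process through its environment.
shard_env = dict(os.environ)
shard_env["LORAX_COMPILE_BATCH_SIZE"] = str(compile_batch_size)

# "my_shard_entrypoint" is a placeholder module name, not the real shard command.
subprocess.Popen(["python", "-m", "my_shard_entrypoint"], env=shard_env)

# Shard side (what graph.py does): read it back, falling back to the same default.
COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))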

server/lorax_server/models/flash_causal_lm.py (+1, -1)
@@ -1350,7 +1350,7 @@ def warmup(self, batch: FlashCausalLMBatch, max_new_tokens: int, embedding_model
         total_cache_size = self.num_layers * cache_block_size * 2 * dtype_size

         free_memory = get_cuda_free_memory(self.device, MEMORY_FRACTION - ADAPTER_MEMORY_FRACTION)
-        free_memory -= graph_cache_memory
+        free_memory = max(0, free_memory - graph_cache_memory)
         logger.info("Memory remaining for kv cache: {} MB", free_memory / 1024 / 1024)

         batch_num_blocks = batch.num_blocks if batch is not None else 0
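This one-line change guards against the case where the CUDA graph cache is estimated to need more memory than remains after the kv-cache fraction and adapter reservation; without the clamp the budget goes negative and the subsequent block-count math is nonsense. A tiny illustration with made-up numbers:

GIB = 1024**3

free_memory = 2 * GIB          # hypothetical free memory after MEMORY_FRACTION and adapters
graph_cache_memory = 3 * GIB   # hypothetical CUDA graph cache estimate

old_budget = free_memory - graph_cache_memory           # -1 GiB: poisons later block math
new_budget = max(0, free_memory - graph_cache_memory)   # clamped to 0

print(old_budget / GIB, new_budget / GIB)  # -1.0 0.0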

server/lorax_server/utils/dist.py (+1, -1)
@@ -11,7 +11,7 @@
 # CUDA memory fraction
 MEMORY_FRACTION = float(os.getenv("CUDA_MEMORY_FRACTION", "1.0"))

-MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.8"))
+MEMORY_WIGGLE_ROOM = float(os.getenv("MEMORY_WIGGLE_ROOM", "0.9"))


 class FakeBarrier:

server/lorax_server/utils/graph.py (+12, -3)
@@ -26,7 +26,8 @@
 from lorax_server.models.model import Model


-MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 128))
+MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
+COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
 MAX_RANK = int(os.environ.get("LORAX_COMPILE_MAX_RANK", 64))

 SLOT_PAD_VALUE = -1

@@ -40,7 +41,7 @@
 CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
     BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
 ]
-CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= MAX_BATCH_SIZE]
+CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]

 # Include 0 to ensure we can use cuda graphs without adapters
 # TODO(travis): use padding to allow for more ranks without increasing memory usage

@@ -472,6 +473,7 @@ def __init__(
         self.sliding_window_blocks = sliding_window_blocks
         self.layer_to_lora_weights = layer_to_lora_weights
         self.punica_wrapper = punica_wrapper
+        self.batch_size = COMPILE_BATCH_SIZE

     def can_use_graph(
         self,

@@ -603,7 +605,13 @@ def forward(

         key = (batch_size, max_rank)
         graph = self.cache.get(key)
-        if graph is None or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names()):
+        if (
+            graph is None
+            or not graph.input_state.traced_adapter_layer_names.issuperset(adapter_data.layer_names())
+            # This is the case where COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE so
+            # we just retrace the graph for that new size
+            or batch_size > self.batch_size
+        ):
             current_traced_adapter_layer_names = (
                 graph.input_state.traced_adapter_layer_names if graph is not None else set()
             )

@@ -631,6 +639,7 @@ def forward(
                 self.punica_wrapper,
             )
             self.cache[key] = graph
+            self.batch_size = batch_size

         output_states = graph.forward(
             input_ids=input_ids,
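Taken together, these changes mean only batch sizes up to COMPILE_BATCH_SIZE are traced during warmup, while anything between COMPILE_BATCH_SIZE and MAX_BATCH_SIZE is traced lazily the first time it is seen, via the new batch_size > self.batch_size branch. A self-contained sketch of that interaction (not LoRAX code; BATCH_SIZE_INCREMENT = 32 is an assumption made for the example, and the adapter-layer check is omitted):

import os

# Mirrors the constants in graph.py; BATCH_SIZE_INCREMENT is assumed for this example.
MAX_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_MAX_BATCH_SIZE", 256))
COMPILE_BATCH_SIZE = int(os.environ.get("LORAX_COMPILE_BATCH_SIZE", 32))
BATCH_SIZE_INCREMENT = 32

CACHED_BATCH_SIZES = [1, 2, 3, 4, 8, 16] + [
    BATCH_SIZE_INCREMENT * (i + 1) for i in range(MAX_BATCH_SIZE // BATCH_SIZE_INCREMENT)
]
# Previously filtered by MAX_BATCH_SIZE; now only sizes up to COMPILE_BATCH_SIZE are
# traced up front, so the warmup-time graph cache is much smaller.
CACHED_BATCH_SIZES = [b for b in CACHED_BATCH_SIZES if b <= COMPILE_BATCH_SIZE]
print(CACHED_BATCH_SIZES)  # [1, 2, 3, 4, 8, 16, 32] with the defaults above


class GraphCacheSketch:
    """Toy stand-in for the graph cache's retrace decision, not the real class."""

    def __init__(self):
        self.cache = {}
        self.batch_size = COMPILE_BATCH_SIZE  # largest batch size traced so far

    def forward(self, batch_size, max_rank):
        key = (batch_size, max_rank)
        graph = self.cache.get(key)
        if graph is None or batch_size > self.batch_size:
            # COMPILE_BATCH_SIZE < batch_size <= MAX_BATCH_SIZE: retrace on first use.
            graph = f"graph traced for batch_size={batch_size}"
            self.cache[key] = graph
            self.batch_size = batch_size
        return graph


cache = GraphCacheSketch()
cache.forward(16, 64)  # cache miss: traced, but self.batch_size stays at COMPILE_BATCH_SIZE
cache.forward(64, 64)  # larger than any traced size: traced lazily, self.batch_size becomes 64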

0 commit comments