Fix warmup shapes for corner cases (#136)
Co-authored-by: Karol Damaszke <[email protected]>
kdamaszk and kdamaszk authored May 6, 2024
1 parent 4169ff8 commit bad7fe7
Showing 2 changed files with 27 additions and 11 deletions.
36 changes: 25 additions & 11 deletions router/client/src/client.rs
@@ -121,7 +121,12 @@ impl Client {
         };
 
         // get all possible prefill batch sizes
-        let max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length;
+        let max_decode_batch_size: u32 = match max_batch_size {
+            Some(max_batch_size) => max_batch_size as u32,
+            None => read_env_var("PREFILL_BATCH_BUCKET_SIZE", 8)
+        };
+        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);
         let prefill_bucket_size: u32 = read_env_var("PREFILL_BATCH_BUCKET_SIZE", 4);
         let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size+1).step_by(prefill_bucket_size as usize).collect();
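For intuition, here is how the new clamp changes the bucketed prefill batch sizes. This is a self-contained sketch, not code from the repository: the numbers are illustrative assumptions and hard-coded values stand in for `read_env_var`.

```rust
use std::cmp;

fn main() {
    let max_prefill_tokens: u32 = 4096;
    let max_input_length: u32 = 1024;

    for max_batch_size in [None, Some(2usize)] {
        let mut max_prefill_batch_size: u32 = max_prefill_tokens / max_input_length; // 4
        let max_decode_batch_size: u32 = match max_batch_size {
            Some(bs) => bs as u32,
            None => 8, // stands in for the env-var lookup in the real code
        };
        // the fix: never warm up prefill batches larger than the decode limit
        max_prefill_batch_size = cmp::min(max_prefill_batch_size, max_decode_batch_size);

        let prefill_bucket_size: u32 = 4; // PREFILL_BATCH_BUCKET_SIZE default
        let batch_sizes: Vec<u32> = (prefill_bucket_size..max_prefill_batch_size + 1)
            .step_by(prefill_bucket_size as usize)
            .collect();
        println!("{:?} -> {:?}", max_batch_size, batch_sizes);
    }
    // None    -> [4]  (capped by the token budget alone)
    // Some(2) -> []   (decode limit below the bucket size: no bucketed prefill warmup)
}
```

Without the clamp, a small user-set `max_batch_size` would let warmup build prefill batches the decode phase can never run; with it, those shapes are dropped up front.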

@@ -142,23 +147,33 @@ impl Client {
             }
         }
 
-        // if max_batch_size is None, create two batches
-        let num_batches = max_batch_size.unwrap_or(2).min(2);
         let mut id_counter: u64 = 0;
         for shape in shapes.iter() {
             // create two batches in order to trigger concatenate operation
+            // in case decode bs=1 create one batch
+            let (batch_size, seq_length) = shape;
-            let batches: Vec<Batch> = vec![
+            let mut batches: Vec<Batch> = vec![
                 self.create_warmup_batch(
                     *shape,
                     &mut id_counter,
                     max_input_length,
                     max_total_tokens,
                     seq_bucket_size,
                     false,
-                );
-                num_batches
+                )
             ];
+            // if possible, create second batch in order to trigger concatenate operation
+            if *batch_size < max_decode_batch_size {
+                batches.push(
+                    self.create_warmup_batch(
+                        (1, *seq_length),
+                        &mut id_counter,
+                        max_input_length,
+                        max_total_tokens,
+                        seq_bucket_size,
+                        false,
+                    )
+                );
+            }
 
             let request = tonic::Request::new(WarmupRequest {
                 batches,
                 max_input_length,
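The rule this hunk introduces can be isolated as below. A minimal sketch: the `Batch` struct and `warmup_batches` helper are hypothetical stand-ins for the generated gRPC type and `create_warmup_batch`, with ids drawn from a running counter as in the diff.

```rust
// Hypothetical stand-in for the generated gRPC Batch type.
#[derive(Debug)]
struct Batch {
    id: u64,
    size: u32,
    seq_length: u32,
}

fn warmup_batches(shape: (u32, u32), id_counter: &mut u64, max_decode_batch_size: u32) -> Vec<Batch> {
    let (batch_size, seq_length) = shape;
    let mut new_batch = |size: u32| {
        let batch = Batch { id: *id_counter, size, seq_length };
        *id_counter += 1;
        batch
    };
    let mut batches = vec![new_batch(batch_size)];
    // if possible, add a size-1 batch so decode exercises the concatenate op
    if batch_size < max_decode_batch_size {
        batches.push(new_batch(1));
    }
    batches
}

fn main() {
    let mut id_counter: u64 = 0;
    for shape in [(2u32, 128u32), (4, 128)] {
        let batches = warmup_batches(shape, &mut id_counter, 4);
        println!("{:?} -> {} batch(es)", shape, batches.len());
    }
    // (2, 128) -> 2 batch(es): decode concatenates sizes 2 and 1
    // (4, 128) -> 1 batch(es): already at the decode limit, nothing to concatenate
}
```

Note the design change as well: the removed `vec![...; num_batches]` form repeats one evaluated batch by cloning it, whereas the explicit `push` creates a second, distinct batch with its own id, and only when it still fits under the decode limit.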
@@ -168,7 +183,7 @@ impl Client {
             let _response = self.stub.warmup(request).await?.into_inner();
         }
 
-        //Send batches with deafult params to warm up Greedy search
+        // send batches with default params to warm up Greedy search
         let mut greedy_shapes: Vec<(u32, u32)> = Vec::with_capacity(batch_sizes.len());
         for batch_size in &batch_sizes {
             greedy_shapes.push((*batch_size, seq_bucket_size.clone()));
@@ -182,8 +197,7 @@ impl Client {
                     max_total_tokens,
                     seq_bucket_size,
                     true,
-                );
-                num_batches
+                )
             ];
             let request = tonic::Request::new(WarmupRequest {
                 batches,
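The greedy warmup reuses the same bucketed batch sizes, one shape per bucket. A small sketch of the shape construction (assuming `seq_bucket_size` is a `u32`; the `.clone()` in the diff works on any `Clone` type, so this is an assumption):

```rust
// One (batch_size, seq_bucket_size) shape per bucketed batch size; each is
// then sent as a single warmup batch with default (greedy) parameters.
fn greedy_shapes(batch_sizes: &[u32], seq_bucket_size: u32) -> Vec<(u32, u32)> {
    batch_sizes.iter().map(|&bs| (bs, seq_bucket_size)).collect()
}

fn main() {
    // with the defaults from the diff above: buckets of 4 up to a cap of 8
    println!("{:?}", greedy_shapes(&[4, 8], 128)); // [(4, 128), (8, 128)]
}
```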
2 changes: 2 additions & 0 deletions server/text_generation_server/models/causal_lm.py
@@ -1114,6 +1114,8 @@ def warmup(self, batches: List[CausalLMBatch]) -> None:

         # if decode bs is 1 warmup ends here
         if len(batches) == 0:
+            while decode_batch is not None:
+                _, decode_batch, _ = self.generate_token([decode_batch])
             return
 
         # prefill
