Skip to content

Commit

Permalink
Use range-checked crate
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Dec 4, 2023
1 parent a186d85 commit 7c6f343
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 12 deletions.
6 changes: 6 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ env_logger = "0.10.1"
tracing = "0.1.40"
tch = { git = "https://github.com/EricLBuehler/tch-rs.git", branch = "copy_derive", version = "0.14.0" } # pending on LaurentMazare/tch-rs#823
torch-sys = { git = "https://github.com/EricLBuehler/tch-rs.git", branch = "copy_derive", version = "0.14.0" } # pending on LaurentMazare/tch-rs#823
+ range-checked = { git = "https://github.com/EricLBuehler/range-checked.git", version = "0.1.0" }

[features]
default = []
Expand All @@ -43,4 +44,4 @@ cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:c
cudnn = ["candle-core/cudnn"]
flash-attn = ["cuda", "candle-transformers/flash-attn"]
mkl = ["dep:intel-mkl-src", "candle-core/mkl", "candle-nn/mkl", "candle-transformers/mkl"]
- nccl = ["cuda", "cudarc/nccl", "dep:half"]
+ nccl = ["cuda", "cudarc/nccl", "dep:half"]
18 changes: 7 additions & 11 deletions src/paged_attention/cache_engine.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use candle_core::{DType, Device, Tensor};

+ use range_checked::F64Bounded;

use crate::openai::{models::ConfigLike, responses::APIError};

const _GB: usize = 1 << 30;
Expand All @@ -16,24 +18,18 @@ pub(crate) struct CacheConfig {
impl CacheConfig {
pub(crate) fn new(
block_size: usize,
- gpu_mem_utilization: f64,
+ gpu_mem_utilization: F64Bounded<0, 1, false>,
swap_space_bytes: usize,
sliding_window: Option<usize>,
- ) -> Result<Self, APIError> {
- if gpu_mem_utilization > 1.0 {
- return Err(APIError::new(format!(
- "GPU memory utilization must be less that 1.0. Got {gpu_mem_utilization}"
- )));
- }
-
- Ok(Self {
+ ) -> Self {
+ Self {
block_size,
- gpu_mem_utilization,
+ gpu_mem_utilization: *gpu_mem_utilization,
swap_space_bytes: swap_space_bytes * _GB,
sliding_window,
num_gpu_blocks: None,
num_cpu_blocks: None,
- })
+ }
}
}

Expand Down

0 comments on commit 7c6f343

Please sign in to comment.