diff --git a/candle-nn/tests/kv_cache.rs b/candle-nn/tests/kv_cache.rs index c692a5baf..cc016ff1c 100644 --- a/candle-nn/tests/kv_cache.rs +++ b/candle-nn/tests/kv_cache.rs @@ -9,19 +9,24 @@ use candle::{Device, Result, Tensor}; #[test] fn kv_cache() -> Result<()> { let mut cache = candle_nn::kv_cache::Cache::new(0, 16); - let data = cache.current_data()?; - assert!(data.is_none()); - let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::()?, [1., 2., 3.]); - let t = Tensor::new(&[4f32], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); - let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; - cache.append(&t)?; - let data = cache.current_data()?.unwrap(); - assert_eq!(data.to_vec1::()?, [1., 2., 3., 4., 0., 5., 6., 7.]); + for _ in [0, 1] { + assert_eq!(cache.current_seq_len(), 0); + let data = cache.current_data()?; + assert!(data.is_none()); + let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3.]); + let t = Tensor::new(&[4f32], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4.]); + let t = Tensor::new(&[0f32, 5., 6., 7.], &Device::Cpu)?; + cache.append(&t)?; + let data = cache.current_data()?.unwrap(); + assert_eq!(data.to_vec1::()?, [1., 2., 3., 4., 0., 5., 6., 7.]); + assert_eq!(cache.current_seq_len(), 8); + cache.reset(); + } Ok(()) }