
Commit 1c7780a

Feat/dynamic small pool (#1931)

1 parent f9ec2e1

File tree

3 files changed: +251 -5 lines

crates/burn-compute/src/memory_management/dynamic.rs (+18 -5)
@@ -1,13 +1,15 @@
 use super::memory_pool::{
     MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
+    SmallMemoryPool,
 };
 use crate::storage::ComputeStorage;
 
 use super::MemoryManagement;
 
 /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
 pub struct DynamicMemoryManagement<Storage> {
-    small_memory_pool: MemoryPool,
+    small_memory_pool: SmallMemoryPool,
+    medium_memory_pool: MemoryPool,
     main_memory_pool: MemoryPool,
     storage: Storage,
 }
@@ -20,14 +22,16 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
             RoundingStrategy::RoundUp,
             1024 * 1024 * 1024 * 2,
         );
-        let small_memory_pool = MemoryPool::new(
+        let medium_memory_pool = MemoryPool::new(
             MemoryExtensionStrategy::Never,
             RoundingStrategy::None,
             1024 * 1024 * 512,
         );
+        let small_memory_pool = SmallMemoryPool::new();
         Self {
-            main_memory_pool,
             small_memory_pool,
+            main_memory_pool,
+            medium_memory_pool,
             storage,
         }
     }
@@ -54,6 +58,10 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
             return handle;
         }
 
+        if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
+            return handle;
+        }
+
         if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
             return handle;
         }
@@ -62,17 +70,22 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
     }
 
     fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
-        if size < 512 {
+        if size <= 32 {
             self.small_memory_pool
                 .reserve(&mut self.storage, size, sync)
+        } else if size < 512 {
+            self.medium_memory_pool
+                .reserve(&mut self.storage, size, sync)
         } else {
            self.main_memory_pool.reserve(&mut self.storage, size, sync)
         }
     }
 
     fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
-        if size < 512 {
+        if size <= 32 {
             self.small_memory_pool.alloc(&mut self.storage, size, sync)
+        } else if size < 512 {
+            self.medium_memory_pool.alloc(&mut self.storage, size, sync)
         } else {
             self.main_memory_pool.alloc(&mut self.storage, size, sync)
         }
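
The new thresholds split requests three ways: sizes up to 32 bytes hit the SmallMemoryPool, sizes under 512 bytes hit the new medium pool, and everything else falls through to the main pool. A minimal sketch (hypothetical `pool_for` helper, not part of the commit) pinning down the boundary cases:

/// Hypothetical helper mirroring the `reserve`/`alloc` dispatch above.
fn pool_for(size: usize) -> &'static str {
    if size <= 32 {
        "small" // SmallMemoryPool: one fixed 32-byte chunk per slice
    } else if size < 512 {
        "medium" // MemoryPool with RoundingStrategy::None
    } else {
        "main" // MemoryPool with RoundingStrategy::RoundUp
    }
}

#[test]
fn dispatch_boundaries() {
    assert_eq!(pool_for(32), "small");
    assert_eq!(pool_for(33), "medium");
    assert_eq!(pool_for(511), "medium");
    assert_eq!(pool_for(512), "main");
}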

crates/burn-compute/src/memory_management/memory_pool/mod.rs (+2 -0)
@@ -3,7 +3,9 @@ mod ring;
 
 mod base;
 mod handle;
+mod small;
 
 pub use base::*;
 pub use handle::*;
 pub use ring::*;
+pub use small::*;
crates/burn-compute/src/memory_management/memory_pool/small.rs (new file, +231)
use super::{ChunkHandle, ChunkId, MemoryPoolBinding, MemoryPoolHandle, SliceHandle, SliceId};
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::HashMap;

/// A memory pool that allocates fixed-size chunks (32 bytes each) and reuses them to minimize allocations.
///
/// - Only one slice is supported per chunk, due to a WGPU limitation: small allocations cannot be bound at an offset.
/// - The pool uses a ring buffer to efficiently manage and reuse chunks.
///
/// Fields:
/// - `chunks`: A hashmap storing the allocated chunks by their IDs.
/// - `slices`: A hashmap storing the slices by their IDs.
/// - `ring_buffer`: A vector used as a ring buffer to manage chunk reuse.
/// - `index`: The current position in the ring buffer.
pub struct SmallMemoryPool {
    chunks: HashMap<ChunkId, SmallChunk>,
    slices: HashMap<SliceId, SmallSlice>,
    ring_buffer: Vec<ChunkId>,
    index: usize,
}
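
// A minimal usage sketch (illustrative only, not part of this file; assumes
// some `storage: impl ComputeStorage` value is in scope):
//
//     let mut pool = SmallMemoryPool::new();
//     // The first reservation allocates a fresh 32-byte chunk.
//     let handle = pool.reserve(&mut storage, 24, || {});
//     // Once the slice handle becomes free again, a later reservation of
//     // <= 32 bytes can reuse the same chunk via the ring buffer scan.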

#[derive(new, Debug)]
pub struct SmallChunk {
    pub storage: StorageHandle,
    #[allow(dead_code)]
    pub handle: ChunkHandle,
    pub slice: Option<SliceId>,
}

#[derive(new, Debug)]
pub struct SmallSlice {
    pub storage: StorageHandle,
    pub handle: SliceHandle,
    #[allow(dead_code)]
    pub chunk: ChunkHandle,
    pub padding: usize,
}

impl SmallSlice {
    pub fn effective_size(&self) -> usize {
        self.storage.size() + self.padding
    }
}

const BUFFER_ALIGNMENT: usize = 32;

impl SmallMemoryPool {
    pub fn new() -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            ring_buffer: Vec::new(),
            index: 0,
        }
    }

    /// Returns the resource from the storage, for the specified handle.
    pub fn get<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        binding: &MemoryPoolBinding,
    ) -> Option<Storage::Resource> {
        self.slices
            .get(binding.slice.id())
            .map(|s| &s.storage)
            .map(|h| storage.get(h))
    }

    /// Reserves memory of the specified size and returns a handle to it.
    ///
    /// Reuses a free 32-byte chunk from the ring buffer when one is available;
    /// otherwise allocates a new chunk.
    pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        assert!(size <= BUFFER_ALIGNMENT);
        let slice = self.get_free_slice(size);

        match slice {
            Some(slice) => MemoryPoolHandle {
                slice: slice.clone(),
            },
            None => self.alloc(storage, size, sync),
        }
    }

    pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        _sync: Sync,
    ) -> MemoryPoolHandle {
        assert!(size <= BUFFER_ALIGNMENT);

        self.alloc_slice(storage, size)
    }

    fn alloc_slice<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        slice_size: usize,
    ) -> MemoryPoolHandle {
        let handle_chunk = self.create_chunk(storage, BUFFER_ALIGNMENT);
        let chunk_id = *handle_chunk.id();
        let slice = self.allocate_slice(handle_chunk.clone(), slice_size);

        let handle_slice = slice.handle.clone();
        self.update_chunk_metadata(chunk_id, slice);

        MemoryPoolHandle {
            slice: handle_slice,
        }
    }

    fn allocate_slice(&self, handle_chunk: ChunkHandle, slice_size: usize) -> SmallSlice {
        let slice = self.create_slice(0, slice_size, handle_chunk.clone());

        let effective_size = slice.effective_size();
        assert_eq!(effective_size, BUFFER_ALIGNMENT);

        slice
    }

    fn update_chunk_metadata(&mut self, chunk_id: ChunkId, slice: SmallSlice) {
        let slice_id = *slice.handle.id();

        self.slices.insert(slice_id, slice);
        self.chunks.get_mut(&chunk_id).unwrap().slice = Some(slice_id);
    }

    /// Scans the ring buffer round-robin from the saved `index` and returns
    /// the first slice whose handle is no longer in use, if any.
    fn find_free_slice(&mut self) -> Option<SliceId> {
        if self.ring_buffer.is_empty() {
            return None;
        }
        for _ in 0..self.ring_buffer.len() {
            let chunk_id = self.ring_buffer.get(self.index).unwrap();
            let chunk = self.chunks.get(chunk_id).unwrap();
            let slice = self.slices.get(&chunk.slice.unwrap()).unwrap();
            self.index = (self.index + 1) % self.ring_buffer.len();
            if slice.handle.is_free() {
                return Some(*slice.handle.id());
            }
        }
        None
    }
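
    // Illustrative trace (assumed state, not from the commit): with
    // `ring_buffer = [c0, c1, c2]` and `index = 1`, the scan above visits
    // c1, c2, c0 and returns the first slice whose handle is free, leaving
    // `index` one past it so the next scan resumes from there.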

    /// Finds a free slice that can contain the given size.
    /// Returns a handle to the slice, shrunk to `size` with the remainder
    /// recorded as padding.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        let slice_id = self.find_free_slice()?;

        let slice = self.slices.get_mut(&slice_id).unwrap();
        let old_slice_size = slice.effective_size();

        let offset = match slice.storage.utilization {
            StorageUtilization::Full(_) => 0,
            StorageUtilization::Slice { offset, size: _ } => offset,
        };
        assert_eq!(offset, 0);
        slice.storage.utilization = StorageUtilization::Slice { offset, size };
        let new_padding = old_slice_size - size;
        slice.padding = new_padding;
        assert_eq!(
            slice.effective_size(),
            old_slice_size,
            "new and old slice should have the same size"
        );

        Some(slice.handle.clone())
    }
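
    // Worked example (assumed values): reusing a freed 32-byte slice with
    // `size = 20` sets its utilization to `Slice { offset: 0, size: 20 }`
    // and `padding = 12`, so the effective size stays 32.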

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> SmallSlice {
        assert_eq!(offset, 0);
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };

        let padding = calculate_padding(size);

        SmallSlice::new(storage, handle, chunk.handle.clone(), padding)
    }

    /// Creates a chunk of the given size by allocating on the storage.
    fn create_chunk<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: usize,
    ) -> ChunkHandle {
        let padding = calculate_padding(size);
        let effective_size = size + padding;

        let storage = storage.alloc(effective_size);
        let handle = ChunkHandle::new();
        let id = *handle.id();

        self.ring_buffer.push(id);

        self.chunks
            .insert(id, SmallChunk::new(storage, handle.clone(), None));

        handle
    }

    #[allow(unused)]
    fn deallocate<Storage: ComputeStorage>(&mut self, _storage: &mut Storage) {
        todo!()
    }
}

fn calculate_padding(size: usize) -> usize {
    let remainder = size % BUFFER_ALIGNMENT;
    if remainder != 0 {
        BUFFER_ALIGNMENT - remainder
    } else {
        0
    }
}
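
`calculate_padding` rounds every slice up to the 32-byte `BUFFER_ALIGNMENT`, so a slice plus its padding always fills a whole chunk. A quick sanity sketch (hypothetical test, not part of the commit) of the boundary cases:

#[test]
fn padding_rounds_up_to_alignment() {
    assert_eq!(calculate_padding(1), 31); // 1 + 31 = 32
    assert_eq!(calculate_padding(24), 8); // 24 + 8 = 32
    assert_eq!(calculate_padding(32), 0); // already aligned
}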
