Skip to content

Commit 51aea94

Browse files
Dynamic memory management preset + updated wgpu buffer memory management (#1962)
--------- Co-authored-by: mepatrick73 <[email protected]>
1 parent 5236e12 commit 51aea94

File tree

6 files changed

+191
-130
lines changed

6 files changed

+191
-130
lines changed

crates/burn-compute/benches/dynamic.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,22 @@
11
use std::collections::LinkedList;
22

33
use burn_compute::{
4-
memory_management::{dynamic::DynamicMemoryManagement, MemoryManagement},
4+
memory_management::{
5+
dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
6+
MemoryManagement,
7+
},
58
storage::BytesStorage,
69
};
710

811
const MB: usize = 1024 * 1024;
12+
913
fn main() {
1014
let start = std::time::Instant::now();
1115
let storage = BytesStorage::default();
12-
let mut mm = DynamicMemoryManagement::new(storage);
16+
let mut mm = DynamicMemoryManagement::new(
17+
storage,
18+
DynamicMemoryManagementOptions::preset(2048 * MB, 32),
19+
);
1320
let mut handles = LinkedList::new();
1421
for _ in 0..100 * 2048 {
1522
if handles.len() >= 4000 {

crates/burn-compute/src/memory_management/dynamic.rs

+117-54
Original file line numberDiff line numberDiff line change
@@ -3,39 +3,107 @@ use super::memory_pool::{
33
SmallMemoryPool,
44
};
55
use crate::storage::ComputeStorage;
6+
use alloc::vec::Vec;
67

78
use super::MemoryManagement;
89

910
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
1011
pub struct DynamicMemoryManagement<Storage> {
12+
min_chunk_alignment_offset: usize,
1113
small_memory_pool: SmallMemoryPool,
12-
small_medium_memory_pool: MemoryPool,
13-
medium_memory_pool: MemoryPool,
14-
main_memory_pool: MemoryPool,
14+
pools: Vec<MemoryPool>,
15+
options: Vec<MemoryPoolOptions>,
1516
storage: Storage,
1617
}
1718

19+
/// Options to initialize a [dynamic memory management](DynamicMemoryManagement).
20+
#[derive(new, Debug)]
21+
pub struct DynamicMemoryManagementOptions {
22+
pools: Vec<MemoryPoolOptions>,
23+
min_chunk_alignment_offset: usize,
24+
}
25+
26+
/// Options to create a memory pool.
27+
#[derive(Debug)]
28+
pub struct MemoryPoolOptions {
29+
/// The amount of bytes used for each chunk in the memory pool.
30+
pub chunk_size: usize,
31+
/// The number of chunks allocated directly at creation.
32+
///
33+
/// Useful when you know in advance how much memory you'll need.
34+
pub chunk_num_prealloc: usize,
35+
/// The max size in bytes a slice can take in the pool.
36+
pub slice_max_size: usize,
37+
}
38+
39+
impl DynamicMemoryManagementOptions {
40+
/// Creates the options from device limits.
41+
pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self {
42+
// Rounding down to a factor of 8.
43+
let max_chunk_size = (max_chunk_size / 8) * 8;
44+
45+
const MB: usize = 1024 * 1024;
46+
47+
let mut pools = Vec::new();
48+
49+
pools.push(MemoryPoolOptions {
50+
chunk_size: max_chunk_size,
51+
chunk_num_prealloc: 0,
52+
slice_max_size: max_chunk_size,
53+
});
54+
55+
let mut current = max_chunk_size;
56+
57+
while current >= 32 * MB {
58+
current /= 4;
59+
60+
pools.push(MemoryPoolOptions {
61+
chunk_size: current,
62+
chunk_num_prealloc: 0,
63+
// Creating max slices lower than the chunk size reduces fragmentation.
64+
slice_max_size: current / 2usize.pow(pools.len() as u32),
65+
});
66+
}
67+
68+
Self {
69+
pools,
70+
min_chunk_alignment_offset,
71+
}
72+
}
73+
}
74+
1875
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
1976
/// Creates a new instance using the given storage, merging_strategy strategy and slice strategy.
20-
pub fn new(storage: Storage) -> Self {
21-
let main_memory_pool = MemoryPool::new(
22-
MemoryExtensionStrategy::new_period_tick(10),
23-
RoundingStrategy::FixedAmount(1024 * 1024 * 1024),
24-
);
25-
let medium_memory_pool = MemoryPool::new(
26-
MemoryExtensionStrategy::Never,
27-
RoundingStrategy::FixedAmount(1024 * 1024 * 200),
28-
);
29-
let small_medium_memory_pool = MemoryPool::new(
30-
MemoryExtensionStrategy::Never,
31-
RoundingStrategy::FixedAmount(1024 * 1024 * 2),
32-
);
33-
let small_memory_pool = SmallMemoryPool::new();
77+
pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self {
78+
options
79+
.pools
80+
.sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size));
81+
82+
let min_chunk_alignment_offset = options.min_chunk_alignment_offset;
83+
84+
let pools = options
85+
.pools
86+
.iter()
87+
.map(|option| {
88+
let mut pool = MemoryPool::new(
89+
MemoryExtensionStrategy::Never,
90+
RoundingStrategy::FixedAmount(option.chunk_size),
91+
min_chunk_alignment_offset,
92+
);
93+
94+
for _ in 0..option.chunk_num_prealloc {
95+
pool.alloc(&mut storage, option.chunk_size, || {});
96+
}
97+
98+
pool
99+
})
100+
.collect();
101+
34102
Self {
35-
small_memory_pool,
36-
small_medium_memory_pool,
37-
main_memory_pool,
38-
medium_memory_pool,
103+
min_chunk_alignment_offset,
104+
small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset),
105+
pools,
106+
options: options.pools,
39107
storage,
40108
}
41109
}
@@ -62,50 +130,45 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
62130
return handle;
63131
}
64132

65-
if let Some(handle) = self
66-
.small_medium_memory_pool
67-
.get(&mut self.storage, &binding)
68-
{
69-
return handle;
133+
for pool in &mut self.pools {
134+
if let Some(handle) = pool.get(&mut self.storage, &binding) {
135+
return handle;
136+
}
70137
}
71138

72-
if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
73-
return handle;
139+
panic!("No handle found in memory pools");
140+
}
141+
142+
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
143+
if size <= self.min_chunk_alignment_offset {
144+
return self
145+
.small_memory_pool
146+
.reserve(&mut self.storage, size, sync);
74147
}
75148

76-
if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
77-
return handle;
149+
for (index, option) in self.options.iter().enumerate() {
150+
if size <= option.slice_max_size {
151+
let pool = &mut self.pools[index];
152+
return pool.reserve(&mut self.storage, size, sync);
153+
}
78154
}
79155

80-
panic!("No handle found in the small and main memory pool");
156+
panic!("No memory pool big enough to reserve {size} bytes.");
81157
}
82158

83-
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
84-
if size <= 32 {
85-
self.small_memory_pool
86-
.reserve(&mut self.storage, size, sync)
87-
} else if size <= 2 * 1024 * 1024 {
88-
self.small_medium_memory_pool
89-
.reserve(&mut self.storage, size, sync)
90-
} else if size < 200 * 1024 * 1024 {
91-
self.medium_memory_pool
92-
.reserve(&mut self.storage, size, sync)
93-
} else {
94-
self.main_memory_pool.reserve(&mut self.storage, size, sync)
159+
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
160+
if size <= self.min_chunk_alignment_offset {
161+
return self.small_memory_pool.alloc(&mut self.storage, size, sync);
95162
}
96-
}
97163

98-
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
99-
if size <= 32 {
100-
self.small_memory_pool.alloc(&mut self.storage, size, sync)
101-
} else if size <= 2 * 1024 * 1024 {
102-
self.small_medium_memory_pool
103-
.alloc(&mut self.storage, size, sync)
104-
} else if size <= 200 * 1024 * 1024 {
105-
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
106-
} else {
107-
self.main_memory_pool.alloc(&mut self.storage, size, sync)
164+
for (index, option) in self.options.iter().enumerate() {
165+
if size <= option.slice_max_size {
166+
let pool = &mut self.pools[index];
167+
return pool.alloc(&mut self.storage, size, sync);
168+
}
108169
}
170+
171+
panic!("No memory pool big enough to alloc {size} bytes.");
109172
}
110173

111174
fn dealloc(&mut self, _binding: Self::Binding) {

0 commit comments

Comments
 (0)