Skip to content

Commit

Permalink
Make a async init function required for each thread
Browse files Browse the repository at this point in the history
  • Loading branch information
SarthakSingh31 committed Nov 15, 2024
1 parent 5412ef0 commit 6172270
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 25 deletions.
3 changes: 0 additions & 3 deletions crates/cubecl-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,6 @@ web-time = { workspace = true }

cfg-if = { workspace = true }

[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies]
wasm-bindgen = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
"export_tests",
Expand Down
100 changes: 78 additions & 22 deletions crates/cubecl-wgpu/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ impl Runtime for WgpuRuntime<WgslCompiler> {
{
let server = LOCAL_DEVICE.with_borrow_mut(|runtime| {
runtime
.entry(device.clone())
.or_insert_with(|| ThreadLocalChannel::make_server(device))
.get(device)
.expect(&format!("The wgpu server for {device:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread"))
.clone()
});
let server = server.borrow();
Expand Down Expand Up @@ -158,6 +158,32 @@ pub struct WgpuSetup {
pub queue: Pdrc<wgpu::Queue>,
}

#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))]
pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) {
let setup = create_setup_for_device::<AutoGraphicsApi, WgslCompiler>(&device).await;

let limits = setup.device.limits();
let mem_props = MemoryDeviceProperties {
max_page_size: limits.max_storage_buffer_binding_size as u64,
alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64),
};
let memory_management = {
let mem_props = mem_props.clone();
let config = options.memory_config;
let storage = WgpuStorage::new(setup.device.clone());
MemoryManagement::from_configuration(storage, mem_props, config)
};
let server = crate::compute::WgpuServer::new(
memory_management,
setup.device,
setup.queue,
setup.adapter,
options.tasks_max,
);

LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server))));
}

/// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].
/// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries.
#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))]
Expand Down Expand Up @@ -465,26 +491,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
) -> cubecl_runtime::storage::BindingResource<Server> {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().get_resource(binding)
})
}

fn create(&self, data: &[u8]) -> cubecl_core::server::Handle {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().create(data)
})
}

fn empty(&self, size: usize) -> cubecl_core::server::Handle {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().empty(size)
})
}
Expand All @@ -498,26 +533,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
) {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) }
})
}

fn flush(&self) {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().flush()
})
}

fn sync(&self) -> impl std::future::Future<Output = ()> {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device))
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
))
.clone();
async move { server.borrow_mut().sync().await }
})
Expand All @@ -526,8 +570,11 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
fn sync_elapsed(&self) -> impl std::future::Future<Output = cubecl_runtime::TimestampsResult> {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device))
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
))
.clone();
async move { server.borrow_mut().sync_elapsed().await }
})
Expand All @@ -536,26 +583,35 @@ impl ComputeChannel<Server> for ThreadLocalChannel {
fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().memory_usage()
})
}

fn enable_timestamps(&self) {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().enable_timestamps()
})
}

fn disable_timestamps(&self) {
LOCAL_DEVICE.with_borrow_mut(|runtime| {
let server = runtime
.entry(self.device.clone())
.or_insert_with(|| Self::make_server(&self.device));
.get(&self.device)
.expect(&format!(
"The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread",
self.device,
));
server.borrow_mut().disable_timestamps()
})
}
Expand Down

0 comments on commit 6172270

Please sign in to comment.