From 617227075870a4c25467b45a30a4eb348f289389 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Fri, 15 Nov 2024 17:35:17 +0530 Subject: [PATCH] Make an async init function required for each thread --- crates/cubecl-wgpu/Cargo.toml | 3 - crates/cubecl-wgpu/src/runtime.rs | 100 +++++++++++++++++++++++------- 2 files changed, 78 insertions(+), 25 deletions(-) diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index ef96295df..1169a727e 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -46,9 +46,6 @@ web-time = { workspace = true } cfg-if = { workspace = true } -[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies] -wasm-bindgen = { workspace = true } - [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ "export_tests", diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 95fe52a6b..898a3c229 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -68,8 +68,8 @@ impl Runtime for WgpuRuntime { { let server = LOCAL_DEVICE.with_borrow_mut(|runtime| { runtime - .entry(device.clone()) - .or_insert_with(|| ThreadLocalChannel::make_server(device)) + .get(device) + .unwrap_or_else(|| panic!("The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", device)) .clone() }); let server = server.borrow(); @@ -158,6 +158,32 @@ pub struct WgpuSetup { pub queue: Pdrc, } +/// Builds the wgpu server for `device` and registers it in this thread's `LOCAL_DEVICE` map; await once per thread before any computation. +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) { + let setup = create_setup_for_device::(&device).await; + + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64), + }; + let memory_management = { + let config = options.memory_config; + let storage = WgpuStorage::new(setup.device.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let server = crate::compute::WgpuServer::new( + memory_management, + setup.device, + setup.queue, + setup.adapter, + options.tasks_max, + ); + + LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server)))); +} + /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`]. /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries. #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] @@ -465,8 +491,11 @@ impl ComputeChannel for ThreadLocalChannel { ) -> cubecl_runtime::storage::BindingResource { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().get_resource(binding) }) } @@ -474,8 +503,11 @@ impl ComputeChannel for ThreadLocalChannel { fn create(&self, data: &[u8]) -> cubecl_core::server::Handle { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().create(data) }) } @@ -483,8 +515,11 @@ impl ComputeChannel for ThreadLocalChannel { fn empty(&self, size: usize) -> cubecl_core::server::Handle { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().empty(size) }) } @@ -498,8 +533,11 @@ impl ComputeChannel for ThreadLocalChannel { ) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) } }) } @@ -507,8 +545,11 @@ impl ComputeChannel for ThreadLocalChannel { fn flush(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().flush() }) } @@ -516,8 +557,11 @@ impl ComputeChannel for ThreadLocalChannel { fn sync(&self) -> impl std::future::Future { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )) .clone(); async move { server.borrow_mut().sync().await } }) @@ -526,8 +570,11 @@ impl ComputeChannel for ThreadLocalChannel { fn sync_elapsed(&self) -> impl std::future::Future { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )) .clone(); async move { server.borrow_mut().sync_elapsed().await } }) @@ -536,8 +583,11 @@ impl ComputeChannel for ThreadLocalChannel { fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().memory_usage() }) } @@ -545,8 +595,11 @@ impl ComputeChannel for ThreadLocalChannel { fn enable_timestamps(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().enable_timestamps() }) } @@ -554,8 +607,11 @@ impl ComputeChannel for ThreadLocalChannel { fn disable_timestamps(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .unwrap_or_else(|| panic!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().disable_timestamps() }) }