From d2d2a712eb97a48ddf705382cef46e759d217132 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Fri, 8 Nov 2024 13:48:05 +0530 Subject: [PATCH 01/10] Added support for target_feature atomics on wasm32 The target_feature atomics enables multithreading on wasm32. The devices in wgpu are not Send or Sync when the target_feature atomics is enabled as they cannot be shared across threads. This PR enables cubecl to run wgpu on a dedicated thread and then communicate to this thread using channels. --- Cargo.toml | 2 + .../src/memory_management/base.rs | 1 + crates/cubecl-wgpu/Cargo.toml | 4 + crates/cubecl-wgpu/src/compiler/base.rs | 10 +- .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 11 +- crates/cubecl-wgpu/src/compute/poll.rs | 4 +- crates/cubecl-wgpu/src/compute/server.rs | 550 +++++++++++++++++- crates/cubecl-wgpu/src/compute/storage.rs | 292 +++++++--- crates/cubecl-wgpu/src/lib.rs | 12 + crates/cubecl-wgpu/src/runtime.rs | 112 +++- 10 files changed, 863 insertions(+), 135 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 205aac918..38b97d290 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ serde_json = { version = "1.0.119", default-features = false } dashmap = "5.5.3" hashbrown = "0.14.5" spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } +rayon = "1" getrandom = { version = "0.2.15", default-features = false } rand = { version = "0.8.5", default-features = false, features = [ @@ -73,6 +74,7 @@ pretty_assertions = "1.4" # Async embassy-futures = { version = "0.1.1" } # for no-std futures-lite = { version = "2.3.0", default-features = false } +futures = "0.3.31" [profile.dev] opt-level = 2 diff --git a/crates/cubecl-runtime/src/memory_management/base.rs b/crates/cubecl-runtime/src/memory_management/base.rs index 798e670be..feaa986b4 100644 --- a/crates/cubecl-runtime/src/memory_management/base.rs +++ b/crates/cubecl-runtime/src/memory_management/base.rs @@ -4,6 +4,7 @@ use alloc::{format, string::String}; /// Amount of memory in 
use by this allocator /// and statistics on how much memory is reserved and /// wasted in total. +#[derive(Debug)] pub struct MemoryUsage { /// The number of allocations currently active. pub number_allocs: u64, diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index 1f03dda59..b6cddd113 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -45,6 +45,10 @@ web-time = { workspace = true } cfg-if = { workspace = true } +[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies] +futures = { workspace = true } +rayon = { workspace = true } + [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ "export_tests", diff --git a/crates/cubecl-wgpu/src/compiler/base.rs b/crates/cubecl-wgpu/src/compiler/base.rs index e040190de..a924dcca3 100644 --- a/crates/cubecl-wgpu/src/compiler/base.rs +++ b/crates/cubecl-wgpu/src/compiler/base.rs @@ -1,25 +1,23 @@ -use std::sync::Arc; - use cubecl_core::{ prelude::CompiledKernel, server::ComputeServer, Compiler, ExecutionMode, Feature, }; use cubecl_runtime::DeviceProperties; use wgpu::{Adapter, ComputePipeline, Device, Queue}; -use crate::WgpuServer; +use crate::{Pdrc, WgpuServer, WgpuServerInner}; pub trait WgpuCompiler: Compiler { fn compile( - server: &mut WgpuServer, + server: &mut WgpuServerInner, kernel: as ComputeServer>::Kernel, mode: ExecutionMode, ) -> CompiledKernel; fn create_pipeline( - server: &mut WgpuServer, + server: &mut WgpuServerInner, kernel: CompiledKernel, mode: ExecutionMode, - ) -> Arc; + ) -> Pdrc; #[allow(async_fn_in_trait)] async fn request_device(adapter: &Adapter) -> (Device, Queue); diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 5060466e9..58ac40325 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -1,4 +1,4 @@ -use std::{borrow::Cow, sync::Arc}; +use 
std::borrow::Cow; use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory}; use super::{LocalArray, Subgroup}; @@ -6,6 +6,7 @@ use crate::{ compiler::{base::WgpuCompiler, wgsl}, WgpuServer, }; +use crate::{Pdrc, WgpuServerInner}; use cubecl_core::{ ir::{self as cube, HybridAllocator, UIntKind}, prelude::CompiledKernel, @@ -70,10 +71,10 @@ impl cubecl_core::Compiler for WgslCompiler { impl WgpuCompiler for WgslCompiler { fn create_pipeline( - server: &mut WgpuServer, + server: &mut WgpuServerInner, kernel: CompiledKernel, mode: ExecutionMode, - ) -> Arc { + ) -> Pdrc { let source = &kernel.source; let repr = kernel.repr.unwrap(); let module = match mode { @@ -118,7 +119,7 @@ impl WgpuCompiler for WgslCompiler { push_constant_ranges: &[], }); - Arc::new( + Pdrc::new( server .device .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { @@ -136,7 +137,7 @@ impl WgpuCompiler for WgslCompiler { } fn compile( - _server: &mut WgpuServer, + _server: &mut WgpuServerInner, kernel: as ComputeServer>::Kernel, mode: ExecutionMode, ) -> CompiledKernel { diff --git a/crates/cubecl-wgpu/src/compute/poll.rs b/crates/cubecl-wgpu/src/compute/poll.rs index 1bc253c29..cf8f5d19a 100644 --- a/crates/cubecl-wgpu/src/compute/poll.rs +++ b/crates/cubecl-wgpu/src/compute/poll.rs @@ -58,10 +58,12 @@ mod _impl { // On Wasm, the browser handles the polling loop, so we don't need anything. 
#[cfg(target_family = "wasm")] mod _impl { + use crate::Pdrc; + #[derive(Debug)] pub struct WgpuPoll {} impl WgpuPoll { - pub fn new(_device: alloc::sync::Arc) -> Self { + pub fn new(_device: Pdrc) -> Self { Self {} } pub fn start_polling(&self) -> alloc::sync::Arc<()> { diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 0efe93291..00aa7dde5 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -1,41 +1,125 @@ +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +use std::rc::Rc; use std::{future::Future, marker::PhantomData, num::NonZero, pin::Pin, time::Duration}; use super::poll::WgpuPoll; use super::WgpuStorage; -use crate::compiler::base::WgpuCompiler; -use alloc::sync::Arc; +use crate::{compiler::base::WgpuCompiler, Pdrc}; use cubecl_common::future; use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId}; +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +use cubecl_runtime::storage::{StorageHandle, StorageId, StorageUtilization}; use cubecl_runtime::{ debug::{DebugLogger, ProfileLevel}, - memory_management::{MemoryHandle, MemoryLock, MemoryManagement}, + memory_management::{MemoryHandle, MemoryLock, MemoryManagement, MemoryUsage}, server::{self, ComputeServer}, storage::{BindingResource, ComputeStorage}, ExecutionMode, TimestampsError, TimestampsResult, }; use hashbrown::HashMap; use web_time::Instant; -use wgpu::{CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType}; +use wgpu::{ + CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType, + WasmNotSend, +}; /// Wgpu compute server. 
#[derive(Debug)] -pub struct WgpuServer { +pub struct WgpuServerInner { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] memory_management: MemoryManagement, - pub(crate) device: Arc, - queue: Arc, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + memory_management: MemoryManagement>, + pub(crate) device: Pdrc, + queue: Pdrc, encoder: CommandEncoder, current_pass: Option>, tasks_count: usize, - pipelines: HashMap>, + pipelines: HashMap>, tasks_max: usize, logger: DebugLogger, poll: WgpuPoll, storage_locked: MemoryLock, duration_profiled: Option, timestamps: KernelTimestamps, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + memory: HashMap>, _compiler: PhantomData, } +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +pub type WgpuServer = WgpuServerInner; + +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +#[derive(Debug)] +pub struct WgpuServer { + tx: std::sync::mpsc::Sender>, +} + +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +impl WgpuServer { + pub fn new(tx: std::sync::mpsc::Sender>) -> Self { + WgpuServer { tx } + } +} + +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +pub enum ServerCommand { + Read { + tx: futures::channel::oneshot::Sender>, + binding: server::Binding, + }, + GetResource { + tx: futures::channel::oneshot::Sender>>, + binding: server::Binding, + }, + Create { + tx: futures::channel::oneshot::Sender, + data: &'static [u8], + }, + Empty { + tx: futures::channel::oneshot::Sender, + size: usize, + }, + Execute { + tx: futures::channel::oneshot::Sender<()>, + kernel: as ComputeServer>::Kernel, + count: CubeCount, + bindings: Vec, + mode: ExecutionMode, + }, + Flush { + tx: futures::channel::oneshot::Sender<()>, + }, + Sync { + tx: futures::channel::oneshot::Sender<()>, + }, + SyncElapsed { + tx: futures::channel::oneshot::Sender, + }, + MemoryUsage { + tx: futures::channel::oneshot::Sender, + }, + EnableTimestamps { + 
tx: futures::channel::oneshot::Sender<()>, + }, + DisableTimestamps { + tx: futures::channel::oneshot::Sender<()>, + }, + Alloc { + tx: futures::channel::oneshot::Sender, + size: u64, + }, + PerformDeallocations { + tx: futures::channel::oneshot::Sender<()>, + deallocations: Vec, + }, +} + +trait FutureWasmNotSend: Future + WasmNotSend {} + +impl + WasmNotSend> FutureWasmNotSend for T {} + #[derive(Debug)] enum KernelTimestamps { Native { query_set: QuerySet, init: bool }, @@ -78,12 +162,15 @@ fn create_encoder(device: &wgpu::Device) -> CommandEncoder { }) } -impl WgpuServer { +impl WgpuServerInner { /// Create a new server. pub fn new( + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] memory_management: MemoryManagement, - device: Arc, - queue: Arc, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + memory_management: MemoryManagement>, + device: Pdrc, + queue: Pdrc, tasks_max: usize, ) -> Self { let logger = DebugLogger::default(); @@ -107,15 +194,17 @@ impl WgpuServer { poll: WgpuPoll::new(device.clone()), duration_profiled: None, timestamps, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + memory: HashMap::default(), _compiler: PhantomData, } } fn pipeline( &mut self, - kernel: ::Kernel, + kernel: as ComputeServer>::Kernel, mode: ExecutionMode, - ) -> Arc { + ) -> Pdrc { let mut kernel_id = kernel.id(); kernel_id.mode(mode); @@ -192,7 +281,7 @@ impl WgpuServer { } } - fn sync_queue(&mut self) -> Pin + Send + 'static>> { + fn sync_queue(&mut self) -> Pin + 'static>> { self.flush(); #[cfg(target_family = "wasm")] @@ -220,7 +309,7 @@ impl WgpuServer { fn sync_queue_elapsed( &mut self, - ) -> Pin + Send + 'static>> { + ) -> Pin + 'static>> { self.clear_compute_pass(); enum TimestampMethod { @@ -294,21 +383,28 @@ impl WgpuServer { } } } -} - -impl ComputeServer for WgpuServer { - type Kernel = Box>; - type Storage = WgpuStorage; - type Feature = Feature; - fn read(&mut self, binding: server::Binding) -> impl 
Future> + Send + 'static { + fn read( + &mut self, + binding: server::Binding, + ) -> impl Future> + WasmNotSend + 'static { let rb = self.get_resource(binding); let resource = rb.resource(); self.clear_compute_pass(); - self.read_wgpu_buffer(&resource.buffer, resource.offset(), resource.size()) + + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + let buffer = &resource.buffer; + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + let buffer = &self + .memory + .get(&resource.buffer) + .expect("Buffer does not exist in the wgpu server memory") + .clone(); + + self.read_wgpu_buffer(buffer, resource.offset(), resource.size()) } - fn get_resource(&mut self, binding: server::Binding) -> BindingResource { + fn get_resource(&mut self, binding: server::Binding) -> BindingResource> { // Keep track of any buffer that might be used in the wgpu queue, as we cannot copy into them // after they have any outstanding compute work. Calling get_resource repeatedly // will add duplicates to this, but that is ok. @@ -357,7 +453,16 @@ impl ComputeServer for WgpuServer { // Write to the staging buffer. Next queue submission this will copy the data to the GPU. 
self.queue - .write_buffer_with(&resource.buffer, resource.offset(), len) + .write_buffer_with( + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + &resource.buffer, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + self.memory + .get(&resource.buffer) + .expect("Buffer does not exist in the wgpu server memory"), + resource.offset(), + len, + ) .expect("Failed to write to staging buffer.")[0..data.len()] .copy_from_slice(data); } @@ -376,7 +481,7 @@ impl ComputeServer for WgpuServer { unsafe fn execute( &mut self, - kernel: Self::Kernel, + kernel: as ComputeServer>::Kernel, count: CubeCount, bindings: Vec, mode: ExecutionMode, @@ -415,7 +520,10 @@ impl ComputeServer for WgpuServer { .enumerate() .map(|(i, r)| wgpu::BindGroupEntry { binding: i as u32, - resource: r.resource().as_wgpu_bind_resource(), + resource: r.resource().as_wgpu_bind_resource( + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + &self.memory, + ), }) .collect::>(); let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { @@ -470,7 +578,12 @@ impl ComputeServer for WgpuServer { CubeCount::Dynamic(_) => { let binding_resource = dispatch_br.as_ref().unwrap(); pass.dispatch_workgroups_indirect( + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] &binding_resource.resource().buffer, + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + self.memory + .get(&binding_resource.resource().buffer) + .expect("Buffer does not exist in the wgpu server memory"), binding_resource.resource().offset(), ); } @@ -571,4 +684,387 @@ impl ComputeServer for WgpuServer { self.timestamps.disable(); } } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + pub fn handle_commands(&mut self, rx: std::sync::mpsc::Receiver>) -> ! 
{ + loop { + match rx.recv() { + Ok(command) => match command { + ServerCommand::Read { tx, binding } => tx + .send(futures::executor::block_on(self.read(binding))) + .expect("Failed to send response"), + ServerCommand::GetResource { tx, binding } => { + if tx.send(self.get_resource(binding)).is_err() { + panic!("Failed to send response") + } + } + ServerCommand::Create { tx, data } => { + tx.send(self.create(data)).expect("Failed to send response") + } + ServerCommand::Empty { tx, size } => { + tx.send(self.empty(size)).expect("Failed to send response") + } + ServerCommand::Execute { + tx, + kernel, + count, + bindings, + mode, + } => tx + .send(unsafe { self.execute(kernel, count, bindings, mode) }) + .expect("Failed to send response"), + ServerCommand::Flush { tx } => { + tx.send(self.flush()).expect("Failed to send response") + } + ServerCommand::Sync { tx } => tx + .send(futures::executor::block_on(self.sync())) + .expect("Failed to send response"), + ServerCommand::SyncElapsed { tx } => tx + .send(futures::executor::block_on(self.sync_elapsed())) + .expect("Failed to send response"), + ServerCommand::MemoryUsage { tx } => tx + .send(self.memory_usage()) + .expect("Failed to send response"), + ServerCommand::EnableTimestamps { tx } => tx + .send(self.enable_timestamps()) + .expect("Failed to send response"), + ServerCommand::DisableTimestamps { tx } => tx + .send(self.disable_timestamps()) + .expect("Failed to send response"), + ServerCommand::Alloc { tx, size } => { + let id = StorageId::new(); + let buffer = self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC + | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + }); + + self.memory.insert(id, Rc::new(buffer)); + + tx.send(StorageHandle::new( + id, + StorageUtilization { offset: 0, size }, + )) + .expect("Failed to send response"); + } + ServerCommand::PerformDeallocations { 
+ tx, + mut deallocations, + } => { + for id in deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() + } + } + + tx.send(()).expect("Failed to send response"); + } + }, + Err(err) => log::error!("Failed to receive command: {err}"), + } + } + } +} + +impl ComputeServer for WgpuServer { + type Kernel = Box>; + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + type Storage = WgpuStorage; + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + type Storage = WgpuStorage; + type Feature = Feature; + + fn read(&mut self, binding: server::Binding) -> impl Future> + Send + 'static { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.read(binding) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Read { tx, binding }) + .expect("Failed to send the message to the WgpuServerInner"); + + async move { + rx.await + .expect("Failed to receive the response from the WgpuServerInner") + } + } + } + + fn get_resource(&mut self, binding: server::Binding) -> BindingResource { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.get_resource(binding) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::GetResource { tx, binding }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(binding) => { + if let Some(binding) = binding { + return binding; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + /// When we create a new handle from existing data, we use custom allocations so that we don't + /// have to execute the current pending tasks. 
+ /// + /// This is important, otherwise the compute passes are going to be too small and we won't be able to + /// fully utilize the GPU. + fn create(&mut self, data: &[u8]) -> server::Handle { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.create(data) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Create { + tx, + // Safety: Since we wait for the execution of the command to finish below + // we can be sure that this data will not disappear + data: unsafe { std::mem::transmute(data) }, + }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(handle) => { + if let Some(handle) = handle { + return handle; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + fn empty(&mut self, size: usize) -> server::Handle { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.empty(size) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Empty { tx, size }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(handle) => { + if let Some(handle) = handle { + return handle; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + unsafe fn execute( + &mut self, + kernel: Self::Kernel, + count: CubeCount, + bindings: Vec, + mode: ExecutionMode, + ) { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.execute(kernel, count, bindings, mode) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Execute { + tx, + kernel, + count, + bindings, + 
mode, + }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(response) => { + if response.is_some() { + break; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + fn flush(&mut self) { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.flush() + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Flush { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(response) => { + if response.is_some() { + break; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + /// Returns the total time of GPU work this sync completes. + fn sync(&mut self) -> impl Future + 'static { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.sync() + } + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Sync { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + async move { + rx.await + .expect("Failed to receive the response from the WgpuServerInner") + } + } + } + + /// Returns the total time of GPU work this sync completes. 
+ fn sync_elapsed(&mut self) -> impl Future + 'static { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.sync_elapsed() + } + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::SyncElapsed { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + async move { + rx.await + .expect("Failed to receive the response from the WgpuServerInner") + } + } + } + + fn memory_usage(&self) -> MemoryUsage { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.memory_usage() + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::MemoryUsage { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(memory_usage) => { + if let Some(memory_usage) = memory_usage { + return memory_usage; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + fn enable_timestamps(&mut self) { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.enable_timestamps() + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::EnableTimestamps { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(response) => { + if response.is_some() { + break; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } + + fn disable_timestamps(&mut self) { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + self.disable_timestamps() + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx, mut rx) = futures::channel::oneshot::channel(); + 
self.tx + .send(ServerCommand::DisableTimestamps { tx }) + .expect("Failed to send the message to the WgpuServerInner"); + + loop { + match rx.try_recv() { + Ok(response) => { + if response.is_some() { + break; + } + } + Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), + } + } + } + } } diff --git a/crates/cubecl-wgpu/src/compute/storage.rs b/crates/cubecl-wgpu/src/compute/storage.rs index 307d047f0..df80289fb 100644 --- a/crates/cubecl-wgpu/src/compute/storage.rs +++ b/crates/cubecl-wgpu/src/compute/storage.rs @@ -1,106 +1,242 @@ -use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; +use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId}; use hashbrown::HashMap; -use std::{num::NonZeroU64, sync::Arc}; +use std::num::NonZeroU64; -/// Buffer storage for wgpu. -pub struct WgpuStorage { - memory: HashMap>, - deallocations: Vec, - device: Arc, -} +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +mod _impl { + use std::sync::Arc; + + use cubecl_runtime::storage::StorageUtilization; -impl core::fmt::Debug for WgpuStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) + use super::*; + + /// Buffer storage for wgpu. + pub struct WgpuStorage { + memory: HashMap>, + deallocations: Vec, + device: Arc, } -} -/// The memory resource that can be allocated for wgpu. -#[derive(new)] -pub struct WgpuResource { - /// The wgpu buffer. - pub buffer: Arc, + impl core::fmt::Debug for WgpuStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) + } + } - offset: u64, - size: u64, -} + /// Keeps actual wgpu buffer references in a hashmap with ids as key. + impl WgpuStorage { + /// Create a new storage on the given [device](wgpu::Device). 
+ pub fn new(device: Arc) -> Self { + Self { + memory: HashMap::new(), + deallocations: Vec::new(), + device, + } + } -impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource { - let binding = wgpu::BufferBinding { - buffer: &self.buffer, - offset: self.offset, - size: Some( - NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), - ), - }; - wgpu::BindingResource::Buffer(binding) + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() + } + } + } } - /// Return the buffer size. - pub fn size(&self) -> u64 { - self.size + impl ComputeStorage for WgpuStorage { + type Resource = WgpuResource; + + // 32 bytes is enough to handle a double4 worth of alignment. + // See: https://github.com/gfx-rs/wgpu/issues/3508 + // NB: cudamalloc and co. actually align to _256_ bytes. Worth + // trying this in the future to see if it reduces memory coalescing. + const ALIGNMENT: u64 = 32; + + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let buffer = self.memory.get(&handle.id).unwrap(); + WgpuResource::new(buffer.clone(), handle.offset(), handle.size()) + } + + fn alloc(&mut self, size: u64) -> StorageHandle { + let id = StorageId::new(); + let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC + | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + })); + + self.memory.insert(id, buffer); + StorageHandle::new(id, StorageUtilization { offset: 0, size }) + } + + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); + } } - /// Return the buffer offset. 
- pub fn offset(&self) -> u64 { - self.offset + /// The memory resource that can be allocated for wgpu. + #[derive(new)] + pub struct WgpuResource { + /// The wgpu buffer. + pub buffer: Arc, + + offset: u64, + size: u64, } -} -/// Keeps actual wgpu buffer references in a hashmap with ids as key. -impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). - pub fn new(device: Arc) -> Self { - Self { - memory: HashMap::new(), - deallocations: Vec::new(), - device, + impl WgpuResource { + /// Return the binding view of the buffer. + pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource { + let binding = wgpu::BufferBinding { + buffer: &self.buffer, + offset: self.offset, + size: Some( + NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), + ), + }; + wgpu::BindingResource::Buffer(binding) + } + + /// Return the buffer size. + pub fn size(&self) -> u64 { + self.size + } + + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + self.offset } } +} + +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +mod _impl { + use std::rc::Rc; + + use crate::{compiler::base::WgpuCompiler, compute::server::ServerCommand}; + + use super::*; + + /// Buffer storage for wgpu. + pub struct WgpuStorage { + deallocations: Vec, + tx: std::sync::mpsc::Sender>, + } - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - for id in self.deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() + /// Keeps actual wgpu buffer references in a hashmap with ids as key. + impl WgpuStorage { + /// Create a new storage on the given [device](wgpu::Device). + pub fn new(tx: std::sync::mpsc::Sender>) -> Self { + Self { + deallocations: Vec::new(), + tx, + } + } + + /// Actually deallocates buffers tagged to be deallocated. 
+ pub fn perform_deallocations(&mut self) { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::PerformDeallocations { + tx, + deallocations: self.deallocations.drain(..).collect(), + }) + .expect("Failed to send command to the wgpu server"); + + loop { + match rx.try_recv() { + Ok(response) => { + if response.is_some() { + break; + } + } + Err(err) => { + panic!("Failed to receive the response from the wgpu server: {err}") + } + } } } } -} -impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; + impl ComputeStorage for WgpuStorage { + type Resource = WgpuResource; + + // 32 bytes is enough to handle a double4 worth of alignment. + // See: https://github.com/gfx-rs/wgpu/issues/3508 + // NB: cudamalloc and co. actually align to _256_ bytes. Worth + // trying this in the future to see if it reduces memory coalescing. + const ALIGNMENT: u64 = 32; - // 32 bytes is enough to handle a double4 worth of alignment. - // See: https://github.com/gfx-rs/wgpu/issues/3508 - // NB: cudamalloc and co. actually align to _256_ bytes. Worth - // trying this in the future to see if it reduces memory coalescing. 
- const ALIGNMENT: u64 = 32; + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + WgpuResource::new(handle.id, handle.offset(), handle.size()) + } + + fn alloc(&mut self, size: u64) -> StorageHandle { + let (tx, mut rx) = futures::channel::oneshot::channel(); + self.tx + .send(ServerCommand::Alloc { tx, size }) + .expect("Failed to send command to the wgpu server"); + + loop { + match rx.try_recv() { + Ok(handle) => { + if let Some(handle) = handle { + return handle; + } + } + Err(err) => { + panic!("Failed to receive the response from the wgpu server: {err}") + } + } + } + } - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let buffer = self.memory.get(&handle.id).unwrap(); - WgpuResource::new(buffer.clone(), handle.offset(), handle.size()) + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); + } } - fn alloc(&mut self, size: u64) -> StorageHandle { - let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC - | wgpu::BufferUsages::INDIRECT, - mapped_at_creation: false, - })); - - self.memory.insert(id, buffer); - StorageHandle::new(id, StorageUtilization { offset: 0, size }) + /// The memory resource that can be allocated for wgpu. + #[derive(new)] + pub struct WgpuResource { + /// The storage id. + pub buffer: StorageId, + offset: u64, + size: u64, } - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); + impl WgpuResource { + /// Return the binding view of the buffer. 
+ pub fn as_wgpu_bind_resource<'a>( + &self, + buffers: &'a HashMap>, + ) -> wgpu::BindingResource<'a> { + let buffer = buffers.get(&self.buffer).expect("Buffer does not exist"); + let binding = wgpu::BufferBinding { + buffer, + offset: self.offset, + size: Some( + NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), + ), + }; + wgpu::BindingResource::Buffer(binding) + } + + /// Return the buffer size. + pub fn size(&self) -> u64 { + self.size + } + + /// Return the buffer offset. + pub fn offset(&self) -> u64 { + self.offset + } } } + +pub use _impl::*; diff --git a/crates/cubecl-wgpu/src/lib.rs b/crates/cubecl-wgpu/src/lib.rs index 4b9c444df..e1dcf676a 100644 --- a/crates/cubecl-wgpu/src/lib.rs +++ b/crates/cubecl-wgpu/src/lib.rs @@ -20,6 +20,18 @@ pub use runtime::*; #[cfg(feature = "spirv")] pub use compiler::spirv; +/// Platform dependent reference counting. Uses [`alloc::sync::Arc`] on all platforms except +/// `wasm32` when the feature `atomics` is enabled. Uses [`alloc::rc::Rc`] instead when on +/// `wasm32` and with the `atomics` feature enabled. +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +type Pdrc = alloc::sync::Arc; + +/// Platform dependent reference counting. Uses [`alloc::sync::Arc`] on all platforms except +/// `wasm32` when the feature `atomics` is enabled. Uses [`alloc::rc::Rc`] instead when on +/// `wasm32` and with the `atomics` feature enabled. 
+#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +type Pdrc = alloc::rc::Rc; + #[cfg(test)] mod tests { pub type TestRuntime = crate::WgpuRuntime; diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 224a9d108..eda0225c9 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -3,9 +3,8 @@ use std::marker::PhantomData; use crate::{ compiler::{base::WgpuCompiler, wgsl::WgslCompiler}, compute::{WgpuServer, WgpuStorage}, - AutoGraphicsApi, GraphicsApi, WgpuDevice, + AutoGraphicsApi, GraphicsApi, Pdrc, WgpuDevice, }; -use alloc::sync::Arc; use cubecl_common::future; use cubecl_core::{Feature, Runtime}; pub use cubecl_runtime::memory_management::MemoryConfiguration; @@ -31,18 +30,13 @@ static RUNTIME: ComputeRuntime> impl Runtime for WgpuRuntime { type Compiler = WgslCompiler; - type Server = WgpuServer; + type Server = Server; - type Channel = MutexComputeChannel>; + type Channel = MutexComputeChannel; type Device = WgpuDevice; fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || { - let setup = future::block_on(create_setup_for_device::( - device, - )); - create_client_on_setup(setup, RuntimeOptions::default()) - }) + RUNTIME.client(device, move || create_client(device)) } fn name() -> &'static str { @@ -89,17 +83,96 @@ impl Default for RuntimeOptions { #[derive(Clone, Debug)] pub struct WgpuSetup { /// The underlying wgpu instance. - pub instance: Arc, + pub instance: Pdrc, /// The selected 'adapter'. This corresponds to a physical device. - pub adapter: Arc, + pub adapter: Pdrc, /// The wgpu device Burn will use. Nb: There can only be one device per adapter. - pub device: Arc, + pub device: Pdrc, /// The queue Burn commands will be submitted to. 
- pub queue: Arc, + pub queue: Pdrc, +} + +pub fn create_client( + device: &WgpuDevice, +) -> ComputeClient, MutexComputeChannel>> { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + let setup = future::block_on(create_setup_for_device::( + device, + )); + create_client_on_setup(setup, RuntimeOptions::default()) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let (tx_once, mut rx_once) = futures::channel::oneshot::channel(); + + { + let device = device.clone(); + + rayon::spawn(move || { + let (tx, rx) = std::sync::mpsc::channel(); + let setup = future::block_on(create_setup_for_device::< + AutoGraphicsApi, + WgslCompiler, + >(&device)); + + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::::ALIGNMENT + .max(limits.min_storage_buffer_offset_alignment as u64), + }; + + let options = RuntimeOptions::default(); + let memory_management = { + let mem_props = mem_props.clone(); + let config = options.memory_config; + let storage = WgpuStorage::new(tx.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let mut server = crate::compute::WgpuServerInner::new( + memory_management, + setup.device.clone(), + setup.queue, + options.tasks_max, + ); + let channel = MutexComputeChannel::new(WgpuServer::new(tx)); + + let features = setup.adapter.features(); + let mut device_props = DeviceProperties::new(&[], mem_props); + + if features.contains(wgpu::Features::SUBGROUP) + && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu + { + device_props.register_feature(Feature::Subcube); + } + C::register_features(&setup.adapter, &setup.device, &mut device_props); + + tx_once + .send(ComputeClient::new(channel, device_props)) + .expect("Failed to send back client to the calling thread"); + + server.handle_commands(rx) + }); + } + + loop { + match rx_once.try_recv() { + 
Ok(client) => { + if let Some(client) = client { + return client; + } + } + Err(_) => panic!("Failed to get the client from the wgpu thread"), + } + } + } } /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`]. /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries. +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice { let device_id = WgpuDevice::Existing(setup.device.as_ref().global_id()); let client = create_client_on_setup(setup, options); @@ -109,6 +182,7 @@ pub fn init_device(setup: WgpuSetup, options: RuntimeOptions) -> WgpuDevice { /// Like [`init_setup_async`], but synchronous. /// On wasm, it is necessary to use [`init_setup_async`] instead. +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] pub fn init_setup(device: &WgpuDevice, options: RuntimeOptions) -> WgpuSetup { cfg_if::cfg_if! { if #[cfg(target_family = "wasm")] { @@ -123,6 +197,7 @@ pub fn init_setup(device: &WgpuDevice, options: RuntimeOptions) /// Initialize a client on the given device with the given options. /// This function is useful to configure the runtime options /// or to pick a different graphics API. 
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] pub async fn init_setup_async( device: &WgpuDevice, options: RuntimeOptions, @@ -134,6 +209,7 @@ pub async fn init_setup_async( return_setup } +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] pub(crate) fn create_client_on_setup( setup: WgpuSetup, options: RuntimeOptions, @@ -185,10 +261,10 @@ pub(crate) async fn create_setup_for_device( ); WgpuSetup { - instance: Arc::new(instance), - adapter: Arc::new(adapter), - device: Arc::new(device), - queue: Arc::new(queue), + instance: Pdrc::new(instance), + adapter: Pdrc::new(adapter), + device: Pdrc::new(device), + queue: Pdrc::new(queue), } } From afbf438169153917959d1cb646d632682eb6af3d Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Fri, 8 Nov 2024 14:33:57 +0530 Subject: [PATCH 02/10] Fixed use of `TuneCacheResult::Unchecked` when `autotune_persistent_cache` is disabled --- crates/cubecl-runtime/src/tune/tune_cache.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/crates/cubecl-runtime/src/tune/tune_cache.rs b/crates/cubecl-runtime/src/tune/tune_cache.rs index a2c5e0d5f..223d46b63 100644 --- a/crates/cubecl-runtime/src/tune/tune_cache.rs +++ b/crates/cubecl-runtime/src/tune/tune_cache.rs @@ -117,11 +117,12 @@ impl TuneCache { } => { if cfg!(autotune_persistent_cache) { match checksum_matches { - None => TuneCacheResult::Unchecked, // Don't know yet. - Some(false) => TuneCacheResult::Miss, // Can't use this. + #[cfg(autotune_persistent_cache)] + None => TuneCacheResult::Unchecked, // Don't know yet. Some(true) => TuneCacheResult::Hit { fastest_index: *fastest_index, }, + _ => TuneCacheResult::Miss, // Some(false) or None so we can't use this. 
} } else { let _ = checksum_matches; From 388a9a3d3368b4644426524c510705b7c6d54a46 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Wed, 13 Nov 2024 13:33:48 +0530 Subject: [PATCH 03/10] Switched to futures block_on to wait for response from the server --- crates/cubecl-wgpu/src/compute/server.rs | 112 +++++++--------------- crates/cubecl-wgpu/src/compute/storage.rs | 32 ++----- crates/cubecl-wgpu/src/runtime.rs | 13 +-- 3 files changed, 44 insertions(+), 113 deletions(-) diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 00aa7dde5..059de10c5 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -806,20 +806,15 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::GetResource { tx, binding }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(binding) => { - if let Some(binding) = binding { - return binding; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if let Ok(binding) = futures::executor::block_on(rx) { + binding + } else { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -837,7 +832,7 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::Create { tx, @@ -847,15 +842,10 @@ impl ComputeServer for WgpuServer { }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(handle) => { - if let Some(handle) = handle { - return handle; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), 
- } + if let Ok(handle) = futures::executor::block_on(rx) { + handle + } else { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -868,20 +858,15 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::Empty { tx, size }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(handle) => { - if let Some(handle) = handle { - return handle; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if let Ok(handle) = futures::executor::block_on(rx) { + handle + } else { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -900,7 +885,7 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::Execute { tx, @@ -911,15 +896,8 @@ impl ComputeServer for WgpuServer { }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(response) => { - if response.is_some() { - break; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if futures::executor::block_on(rx).is_err() { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -932,20 +910,13 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::Flush { tx }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(response) => { - if response.is_some() { - break; - } - } - Err(_) => 
panic!("Failed to receive the response from the WgpuServerInner"), - } + if futures::executor::block_on(rx).is_err() { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -998,20 +969,15 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::MemoryUsage { tx }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(memory_usage) => { - if let Some(memory_usage) = memory_usage { - return memory_usage; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if let Ok(memory_usage) = futures::executor::block_on(rx) { + memory_usage + } else { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -1024,20 +990,13 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::EnableTimestamps { tx }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(response) => { - if response.is_some() { - break; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if futures::executor::block_on(rx).is_err() { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -1050,20 +1009,13 @@ impl ComputeServer for WgpuServer { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::DisableTimestamps { tx }) .expect("Failed to send the message to the WgpuServerInner"); - loop { - match rx.try_recv() { - Ok(response) => { - if 
response.is_some() { - break; - } - } - Err(_) => panic!("Failed to receive the response from the WgpuServerInner"), - } + if futures::executor::block_on(rx).is_err() { + panic!("Failed to receive the response from the WgpuServerInner") } } } diff --git a/crates/cubecl-wgpu/src/compute/storage.rs b/crates/cubecl-wgpu/src/compute/storage.rs index df80289fb..a7cda6117 100644 --- a/crates/cubecl-wgpu/src/compute/storage.rs +++ b/crates/cubecl-wgpu/src/compute/storage.rs @@ -140,7 +140,7 @@ mod _impl { /// Actually deallocates buffers tagged to be deallocated. pub fn perform_deallocations(&mut self) { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::PerformDeallocations { tx, @@ -148,17 +148,8 @@ mod _impl { }) .expect("Failed to send command to the wgpu server"); - loop { - match rx.try_recv() { - Ok(response) => { - if response.is_some() { - break; - } - } - Err(err) => { - panic!("Failed to receive the response from the wgpu server: {err}") - } - } + if futures::executor::block_on(rx).is_err() { + panic!("Failed to receive the response from the WgpuServerInner") } } } @@ -177,22 +168,15 @@ mod _impl { } fn alloc(&mut self, size: u64) -> StorageHandle { - let (tx, mut rx) = futures::channel::oneshot::channel(); + let (tx, rx) = futures::channel::oneshot::channel(); self.tx .send(ServerCommand::Alloc { tx, size }) .expect("Failed to send command to the wgpu server"); - loop { - match rx.try_recv() { - Ok(handle) => { - if let Some(handle) = handle { - return handle; - } - } - Err(err) => { - panic!("Failed to receive the response from the wgpu server: {err}") - } - } + if let Ok(handle) = futures::executor::block_on(rx) { + handle + } else { + panic!("Failed to receive the response from the WgpuServerInner") } } diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index eda0225c9..54e9bfdcc 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ 
b/crates/cubecl-wgpu/src/runtime.rs @@ -157,15 +157,10 @@ pub fn create_client( }); } - loop { - match rx_once.try_recv() { - Ok(client) => { - if let Some(client) = client { - return client; - } - } - Err(_) => panic!("Failed to get the client from the wgpu thread"), - } + if let Ok(client) = futures::executor::block_on(rx_once) { + client + } else { + panic!("Failed to get the client from the wgpu thread") } } } From feab17f95c08018c7ac96faa6d51af6865ac645e Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Thu, 14 Nov 2024 11:35:59 +0530 Subject: [PATCH 04/10] Switched to thread local runtime --- crates/cubecl-runtime/src/channel/base.rs | 8 +- crates/cubecl-runtime/src/channel/cell.rs | 7 +- crates/cubecl-runtime/src/channel/mpsc.rs | 6 +- crates/cubecl-runtime/src/channel/mutex.rs | 2 +- crates/cubecl-runtime/src/server.rs | 8 +- crates/cubecl-runtime/src/storage/base.rs | 4 +- crates/cubecl-runtime/src/tune/tuner.rs | 2 +- crates/cubecl-wgpu/src/compiler/base.rs | 6 +- .../cubecl-wgpu/src/compiler/wgsl/compiler.rs | 7 +- crates/cubecl-wgpu/src/compute/server.rs | 467 +----------------- crates/cubecl-wgpu/src/compute/storage.rs | 273 +++------- crates/cubecl-wgpu/src/runtime.rs | 244 ++++++--- 12 files changed, 281 insertions(+), 753 deletions(-) diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index 2ab8a9759..d7b069e81 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -10,9 +10,9 @@ use alloc::vec::Vec; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety -pub trait ComputeChannel: Clone + core::fmt::Debug + Send + Sync { +pub trait ComputeChannel: Clone + core::fmt::Debug { /// Given a binding, returns owned resource as bytes - fn read(&self, binding: Binding) -> impl Future> + Send; + fn read(&self, binding: Binding) -> impl Future>; /// Given a resource handle, return the storage resource. 
fn get_resource(&self, binding: Binding) -> BindingResource; @@ -40,12 +40,12 @@ pub trait ComputeChannel: Clone + core::fmt::Debug + Send fn flush(&self); /// Wait for the completion of every task in the server. - fn sync(&self) -> impl Future + Send; + fn sync(&self) -> impl Future; /// Wait for the completion of every task in the server. /// /// Returns the (approximate) total amount of GPU work done since the last sync. - fn sync_elapsed(&self) -> impl Future + Send; + fn sync_elapsed(&self) -> impl Future; /// Get the current memory usage of the server. fn memory_usage(&self) -> crate::memory_management::MemoryUsage; diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs index 83e99a282..f0e2b5b9d 100644 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ b/crates/cubecl-runtime/src/channel/cell.rs @@ -42,7 +42,7 @@ where impl ComputeChannel for RefCellComputeChannel where - Server: ComputeServer + Send, + Server: ComputeServer, { async fn read(&self, binding: Binding) -> Vec { let future = { @@ -108,8 +108,3 @@ where self.server.borrow_mut().disable_timestamps(); } } - -/// This is unsafe, since no concurrency is supported by the `RefCell` channel. -/// However using this channel should only be done in single threaded environments such as `no-std`. 
-unsafe impl Send for RefCellComputeChannel {} -unsafe impl Sync for RefCellComputeChannel {} diff --git a/crates/cubecl-runtime/src/channel/mpsc.rs b/crates/cubecl-runtime/src/channel/mpsc.rs index 1b1a9e546..66a3cda6a 100644 --- a/crates/cubecl-runtime/src/channel/mpsc.rs +++ b/crates/cubecl-runtime/src/channel/mpsc.rs @@ -6,7 +6,7 @@ use super::ComputeChannel; use crate::{ memory_management::MemoryUsage, server::{Binding, ComputeServer, CubeCount, Handle}, - storage::BindingResource, + storage::{BindingResource, ComputeStorage}, ExecutionMode, }; @@ -50,7 +50,8 @@ where impl MpscComputeChannel where - Server: ComputeServer + 'static, + Server: ComputeServer + Send + 'static, + ::Resource: Send, { /// Create a new mpsc compute channel. pub fn new(mut server: Server) -> Self { @@ -123,6 +124,7 @@ impl Clone for MpscComputeChannel { impl ComputeChannel for MpscComputeChannel where Server: ComputeServer + 'static, + ::Resource: Send, { async fn read(&self, binding: Binding) -> Vec { let sender = self.state.sender.clone(); diff --git a/crates/cubecl-runtime/src/channel/mutex.rs b/crates/cubecl-runtime/src/channel/mutex.rs index 8d4dff53f..adbc2a07d 100644 --- a/crates/cubecl-runtime/src/channel/mutex.rs +++ b/crates/cubecl-runtime/src/channel/mutex.rs @@ -35,7 +35,7 @@ where impl ComputeChannel for MutexComputeChannel where - Server: ComputeServer, + Server: ComputeServer + Send, { async fn read(&self, handle: Binding) -> Vec { // Nb: The order here is really important - the mutex guard has to be dropped before diff --git a/crates/cubecl-runtime/src/server.rs b/crates/cubecl-runtime/src/server.rs index 45a06a8c9..027cd69fd 100644 --- a/crates/cubecl-runtime/src/server.rs +++ b/crates/cubecl-runtime/src/server.rs @@ -14,7 +14,7 @@ use cubecl_common::benchmark::TimestampsResult; /// /// Everything in the server is mutable, therefore it should be solely accessed through the /// [compute channel](crate::channel::ComputeChannel) for thread safety. 
-pub trait ComputeServer: Send + core::fmt::Debug +pub trait ComputeServer: core::fmt::Debug where Self: Sized, { @@ -26,7 +26,7 @@ where type Feature: Ord + Copy + Debug + Send + Sync; /// Given a handle, returns the owned resource as bytes. - fn read(&mut self, binding: Binding) -> impl Future> + Send + 'static; + fn read(&mut self, binding: Binding) -> impl Future> + 'static; /// Given a resource handle, returns the storage resource. fn get_resource(&mut self, binding: Binding) -> BindingResource; @@ -57,12 +57,12 @@ where fn flush(&mut self); /// Wait for the completion of every task in the server. - fn sync(&mut self) -> impl Future + Send + 'static; + fn sync(&mut self) -> impl Future + 'static; /// Wait for the completion of every task in the server. /// /// Returns the (approximate) total amount of GPU work done since the last sync. - fn sync_elapsed(&mut self) -> impl Future + Send + 'static; + fn sync_elapsed(&mut self) -> impl Future + 'static; /// The current memory usage of the server. fn memory_usage(&self) -> MemoryUsage; diff --git a/crates/cubecl-runtime/src/storage/base.rs b/crates/cubecl-runtime/src/storage/base.rs index 0f50eb8da..9737404c2 100644 --- a/crates/cubecl-runtime/src/storage/base.rs +++ b/crates/cubecl-runtime/src/storage/base.rs @@ -63,10 +63,10 @@ impl StorageHandle { } /// Storage types are responsible for allocating and deallocating memory. -pub trait ComputeStorage: Send { +pub trait ComputeStorage { /// The resource associated type determines the way data is implemented and how /// it can be accessed by kernels. - type Resource: Send; + type Resource; /// The alignment memory is allocated with in this storage. 
const ALIGNMENT: u64; diff --git a/crates/cubecl-runtime/src/tune/tuner.rs b/crates/cubecl-runtime/src/tune/tuner.rs index b3aacc331..4854bc9d7 100644 --- a/crates/cubecl-runtime/src/tune/tuner.rs +++ b/crates/cubecl-runtime/src/tune/tuner.rs @@ -224,7 +224,7 @@ impl Tuner { } } -fn spawn_benchmark_task(future: impl Future + Send + 'static) { +fn spawn_benchmark_task(future: impl Future + 'static) { // On wasm, spawn the tuning as a detached task. #[cfg(target_family = "wasm")] wasm_bindgen_futures::spawn_local(future); diff --git a/crates/cubecl-wgpu/src/compiler/base.rs b/crates/cubecl-wgpu/src/compiler/base.rs index a924dcca3..05ca30eb9 100644 --- a/crates/cubecl-wgpu/src/compiler/base.rs +++ b/crates/cubecl-wgpu/src/compiler/base.rs @@ -4,17 +4,17 @@ use cubecl_core::{ use cubecl_runtime::DeviceProperties; use wgpu::{Adapter, ComputePipeline, Device, Queue}; -use crate::{Pdrc, WgpuServer, WgpuServerInner}; +use crate::{Pdrc, WgpuServer}; pub trait WgpuCompiler: Compiler { fn compile( - server: &mut WgpuServerInner, + server: &mut WgpuServer, kernel: as ComputeServer>::Kernel, mode: ExecutionMode, ) -> CompiledKernel; fn create_pipeline( - server: &mut WgpuServerInner, + server: &mut WgpuServer, kernel: CompiledKernel, mode: ExecutionMode, ) -> Pdrc; diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs index 58ac40325..a00395db9 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs @@ -4,9 +4,8 @@ use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory}; use super::{LocalArray, Subgroup}; use crate::{ compiler::{base::WgpuCompiler, wgsl}, - WgpuServer, + Pdrc, WgpuServer, }; -use crate::{Pdrc, WgpuServerInner}; use cubecl_core::{ ir::{self as cube, HybridAllocator, UIntKind}, prelude::CompiledKernel, @@ -71,7 +70,7 @@ impl cubecl_core::Compiler for WgslCompiler { impl WgpuCompiler for WgslCompiler { fn create_pipeline( - server: 
&mut WgpuServerInner, + server: &mut WgpuServer, kernel: CompiledKernel, mode: ExecutionMode, ) -> Pdrc { @@ -137,7 +136,7 @@ impl WgpuCompiler for WgslCompiler { } fn compile( - _server: &mut WgpuServerInner, + _server: &mut WgpuServer, kernel: as ComputeServer>::Kernel, mode: ExecutionMode, ) -> CompiledKernel { diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 059de10c5..7fc23f980 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -1,5 +1,3 @@ -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -use std::rc::Rc; use std::{future::Future, marker::PhantomData, num::NonZero, pin::Pin, time::Duration}; use super::poll::WgpuPoll; @@ -7,11 +5,9 @@ use super::WgpuStorage; use crate::{compiler::base::WgpuCompiler, Pdrc}; use cubecl_common::future; use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId}; -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -use cubecl_runtime::storage::{StorageHandle, StorageId, StorageUtilization}; use cubecl_runtime::{ debug::{DebugLogger, ProfileLevel}, - memory_management::{MemoryHandle, MemoryLock, MemoryManagement, MemoryUsage}, + memory_management::{MemoryHandle, MemoryLock, MemoryManagement}, server::{self, ComputeServer}, storage::{BindingResource, ComputeStorage}, ExecutionMode, TimestampsError, TimestampsResult, @@ -25,11 +21,8 @@ use wgpu::{ /// Wgpu compute server. 
#[derive(Debug)] -pub struct WgpuServerInner { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +pub struct WgpuServer { memory_management: MemoryManagement, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - memory_management: MemoryManagement>, pub(crate) device: Pdrc, queue: Pdrc, encoder: CommandEncoder, @@ -42,80 +35,9 @@ pub struct WgpuServerInner { storage_locked: MemoryLock, duration_profiled: Option, timestamps: KernelTimestamps, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - memory: HashMap>, _compiler: PhantomData, } -#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] -pub type WgpuServer = WgpuServerInner; - -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -#[derive(Debug)] -pub struct WgpuServer { - tx: std::sync::mpsc::Sender>, -} - -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -impl WgpuServer { - pub fn new(tx: std::sync::mpsc::Sender>) -> Self { - WgpuServer { tx } - } -} - -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -pub enum ServerCommand { - Read { - tx: futures::channel::oneshot::Sender>, - binding: server::Binding, - }, - GetResource { - tx: futures::channel::oneshot::Sender>>, - binding: server::Binding, - }, - Create { - tx: futures::channel::oneshot::Sender, - data: &'static [u8], - }, - Empty { - tx: futures::channel::oneshot::Sender, - size: usize, - }, - Execute { - tx: futures::channel::oneshot::Sender<()>, - kernel: as ComputeServer>::Kernel, - count: CubeCount, - bindings: Vec, - mode: ExecutionMode, - }, - Flush { - tx: futures::channel::oneshot::Sender<()>, - }, - Sync { - tx: futures::channel::oneshot::Sender<()>, - }, - SyncElapsed { - tx: futures::channel::oneshot::Sender, - }, - MemoryUsage { - tx: futures::channel::oneshot::Sender, - }, - EnableTimestamps { - tx: futures::channel::oneshot::Sender<()>, - }, - DisableTimestamps { - tx: futures::channel::oneshot::Sender<()>, - }, - Alloc { - tx: 
futures::channel::oneshot::Sender, - size: u64, - }, - PerformDeallocations { - tx: futures::channel::oneshot::Sender<()>, - deallocations: Vec, - }, -} - trait FutureWasmNotSend: Future + WasmNotSend {} impl + WasmNotSend> FutureWasmNotSend for T {} @@ -162,13 +84,10 @@ fn create_encoder(device: &wgpu::Device) -> CommandEncoder { }) } -impl WgpuServerInner { +impl WgpuServer { /// Create a new server. pub fn new( - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] memory_management: MemoryManagement, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - memory_management: MemoryManagement>, device: Pdrc, queue: Pdrc, tasks_max: usize, @@ -194,8 +113,6 @@ impl WgpuServerInner { poll: WgpuPoll::new(device.clone()), duration_profiled: None, timestamps, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - memory: HashMap::default(), _compiler: PhantomData, } } @@ -383,25 +300,19 @@ impl WgpuServerInner { } } } +} - fn read( - &mut self, - binding: server::Binding, - ) -> impl Future> + WasmNotSend + 'static { +impl ComputeServer for WgpuServer { + type Kernel = Box>; + type Storage = WgpuStorage; + type Feature = Feature; + + fn read(&mut self, binding: server::Binding) -> impl Future> + 'static { let rb = self.get_resource(binding); let resource = rb.resource(); self.clear_compute_pass(); - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - let buffer = &resource.buffer; - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - let buffer = &self - .memory - .get(&resource.buffer) - .expect("Buffer does not exist in the wgpu server memory") - .clone(); - - self.read_wgpu_buffer(buffer, resource.offset(), resource.size()) + self.read_wgpu_buffer(&resource.buffer, resource.offset(), resource.size()) } fn get_resource(&mut self, binding: server::Binding) -> BindingResource> { @@ -453,16 +364,7 @@ impl WgpuServerInner { // Write to the staging buffer. 
Next queue submission this will copy the data to the GPU. self.queue - .write_buffer_with( - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - &resource.buffer, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - self.memory - .get(&resource.buffer) - .expect("Buffer does not exist in the wgpu server memory"), - resource.offset(), - len, - ) + .write_buffer_with(&resource.buffer, resource.offset(), len) .expect("Failed to write to staging buffer.")[0..data.len()] .copy_from_slice(data); } @@ -520,10 +422,7 @@ impl WgpuServerInner { .enumerate() .map(|(i, r)| wgpu::BindGroupEntry { binding: i as u32, - resource: r.resource().as_wgpu_bind_resource( - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - &self.memory, - ), + resource: r.resource().as_wgpu_bind_resource(), }) .collect::>(); let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { @@ -578,12 +477,7 @@ impl WgpuServerInner { CubeCount::Dynamic(_) => { let binding_resource = dispatch_br.as_ref().unwrap(); pass.dispatch_workgroups_indirect( - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] &binding_resource.resource().buffer, - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - self.memory - .get(&binding_resource.resource().buffer) - .expect("Buffer does not exist in the wgpu server memory"), binding_resource.resource().offset(), ); } @@ -684,339 +578,4 @@ impl WgpuServerInner { self.timestamps.disable(); } } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - pub fn handle_commands(&mut self, rx: std::sync::mpsc::Receiver>) -> ! 
{ - loop { - match rx.recv() { - Ok(command) => match command { - ServerCommand::Read { tx, binding } => tx - .send(futures::executor::block_on(self.read(binding))) - .expect("Failed to send response"), - ServerCommand::GetResource { tx, binding } => { - if tx.send(self.get_resource(binding)).is_err() { - panic!("Failed to send response") - } - } - ServerCommand::Create { tx, data } => { - tx.send(self.create(data)).expect("Failed to send response") - } - ServerCommand::Empty { tx, size } => { - tx.send(self.empty(size)).expect("Failed to send response") - } - ServerCommand::Execute { - tx, - kernel, - count, - bindings, - mode, - } => tx - .send(unsafe { self.execute(kernel, count, bindings, mode) }) - .expect("Failed to send response"), - ServerCommand::Flush { tx } => { - tx.send(self.flush()).expect("Failed to send response") - } - ServerCommand::Sync { tx } => tx - .send(futures::executor::block_on(self.sync())) - .expect("Failed to send response"), - ServerCommand::SyncElapsed { tx } => tx - .send(futures::executor::block_on(self.sync_elapsed())) - .expect("Failed to send response"), - ServerCommand::MemoryUsage { tx } => tx - .send(self.memory_usage()) - .expect("Failed to send response"), - ServerCommand::EnableTimestamps { tx } => tx - .send(self.enable_timestamps()) - .expect("Failed to send response"), - ServerCommand::DisableTimestamps { tx } => tx - .send(self.disable_timestamps()) - .expect("Failed to send response"), - ServerCommand::Alloc { tx, size } => { - let id = StorageId::new(); - let buffer = self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC - | wgpu::BufferUsages::INDIRECT, - mapped_at_creation: false, - }); - - self.memory.insert(id, Rc::new(buffer)); - - tx.send(StorageHandle::new( - id, - StorageUtilization { offset: 0, size }, - )) - .expect("Failed to send response"); - } - ServerCommand::PerformDeallocations { 
- tx, - mut deallocations, - } => { - for id in deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() - } - } - - tx.send(()).expect("Failed to send response"); - } - }, - Err(err) => log::error!("Failed to receive command: {err}"), - } - } - } -} - -impl ComputeServer for WgpuServer { - type Kernel = Box>; - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - type Storage = WgpuStorage; - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - type Storage = WgpuStorage; - type Feature = Feature; - - fn read(&mut self, binding: server::Binding) -> impl Future> + Send + 'static { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.read(binding) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Read { tx, binding }) - .expect("Failed to send the message to the WgpuServerInner"); - - async move { - rx.await - .expect("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn get_resource(&mut self, binding: server::Binding) -> BindingResource { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.get_resource(binding) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::GetResource { tx, binding }) - .expect("Failed to send the message to the WgpuServerInner"); - - if let Ok(binding) = futures::executor::block_on(rx) { - binding - } else { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - /// When we create a new handle from existing data, we use custom allocations so that we don't - /// have to execute the current pending tasks. - /// - /// This is important, otherwise the compute passes are going to be too small and we won't be able to - /// fully utilize the GPU. 
- fn create(&mut self, data: &[u8]) -> server::Handle { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.create(data) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Create { - tx, - // Safety: Since we wait for the execution of the command to finish below - // we can be sure that this data will not disappear - data: unsafe { std::mem::transmute(data) }, - }) - .expect("Failed to send the message to the WgpuServerInner"); - - if let Ok(handle) = futures::executor::block_on(rx) { - handle - } else { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn empty(&mut self, size: usize) -> server::Handle { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.empty(size) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Empty { tx, size }) - .expect("Failed to send the message to the WgpuServerInner"); - - if let Ok(handle) = futures::executor::block_on(rx) { - handle - } else { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - unsafe fn execute( - &mut self, - kernel: Self::Kernel, - count: CubeCount, - bindings: Vec, - mode: ExecutionMode, - ) { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.execute(kernel, count, bindings, mode) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Execute { - tx, - kernel, - count, - bindings, - mode, - }) - .expect("Failed to send the message to the WgpuServerInner"); - - if futures::executor::block_on(rx).is_err() { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn flush(&mut self) { - 
#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.flush() - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Flush { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - if futures::executor::block_on(rx).is_err() { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - /// Returns the total time of GPU work this sync completes. - fn sync(&mut self) -> impl Future + 'static { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.sync() - } - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Sync { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - async move { - rx.await - .expect("Failed to receive the response from the WgpuServerInner") - } - } - } - - /// Returns the total time of GPU work this sync completes. 
- fn sync_elapsed(&mut self) -> impl Future + 'static { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.sync_elapsed() - } - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::SyncElapsed { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - async move { - rx.await - .expect("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn memory_usage(&self) -> MemoryUsage { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.memory_usage() - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::MemoryUsage { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - if let Ok(memory_usage) = futures::executor::block_on(rx) { - memory_usage - } else { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn enable_timestamps(&mut self) { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.enable_timestamps() - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::EnableTimestamps { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - if futures::executor::block_on(rx).is_err() { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - fn disable_timestamps(&mut self) { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - self.disable_timestamps() - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::DisableTimestamps { tx }) - .expect("Failed to send the message to the WgpuServerInner"); - - if 
futures::executor::block_on(rx).is_err() { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } } diff --git a/crates/cubecl-wgpu/src/compute/storage.rs b/crates/cubecl-wgpu/src/compute/storage.rs index a7cda6117..b4359893c 100644 --- a/crates/cubecl-wgpu/src/compute/storage.rs +++ b/crates/cubecl-wgpu/src/compute/storage.rs @@ -1,226 +1,107 @@ -use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId}; +use cubecl_runtime::storage::{ComputeStorage, StorageHandle, StorageId, StorageUtilization}; use hashbrown::HashMap; use std::num::NonZeroU64; -#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] -mod _impl { - use std::sync::Arc; +use crate::Pdrc; - use cubecl_runtime::storage::StorageUtilization; - - use super::*; +/// Buffer storage for wgpu. +pub struct WgpuStorage { + memory: HashMap>, + deallocations: Vec, + device: Pdrc, +} - /// Buffer storage for wgpu. - pub struct WgpuStorage { - memory: HashMap>, - deallocations: Vec, - device: Arc, +impl core::fmt::Debug for WgpuStorage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) } +} - impl core::fmt::Debug for WgpuStorage { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(format!("WgpuStorage {{ device: {:?} }}", self.device).as_str()) +/// Keeps actual wgpu buffer references in a hashmap with ids as key. +impl WgpuStorage { + /// Create a new storage on the given [device](wgpu::Device). + pub fn new(device: Pdrc) -> Self { + Self { + memory: HashMap::new(), + deallocations: Vec::new(), + device, } } - /// Keeps actual wgpu buffer references in a hashmap with ids as key. - impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). 
- pub fn new(device: Arc) -> Self { - Self { - memory: HashMap::new(), - deallocations: Vec::new(), - device, - } - } - - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - for id in self.deallocations.drain(..) { - if let Some(buffer) = self.memory.remove(&id) { - buffer.destroy() - } + /// Actually deallocates buffers tagged to be deallocated. + pub fn perform_deallocations(&mut self) { + for id in self.deallocations.drain(..) { + if let Some(buffer) = self.memory.remove(&id) { + buffer.destroy() } } } +} - impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; - - // 32 bytes is enough to handle a double4 worth of alignment. - // See: https://github.com/gfx-rs/wgpu/issues/3508 - // NB: cudamalloc and co. actually align to _256_ bytes. Worth - // trying this in the future to see if it reduces memory coalescing. - const ALIGNMENT: u64 = 32; - - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - let buffer = self.memory.get(&handle.id).unwrap(); - WgpuResource::new(buffer.clone(), handle.offset(), handle.size()) - } - - fn alloc(&mut self, size: u64) -> StorageHandle { - let id = StorageId::new(); - let buffer = Arc::new(self.device.create_buffer(&wgpu::BufferDescriptor { - label: None, - size, - usage: wgpu::BufferUsages::COPY_DST - | wgpu::BufferUsages::STORAGE - | wgpu::BufferUsages::COPY_SRC - | wgpu::BufferUsages::INDIRECT, - mapped_at_creation: false, - })); +impl ComputeStorage for WgpuStorage { + type Resource = WgpuResource; - self.memory.insert(id, buffer); - StorageHandle::new(id, StorageUtilization { offset: 0, size }) - } + // 32 bytes is enough to handle a double4 worth of alignment. + // See: https://github.com/gfx-rs/wgpu/issues/3508 + // NB: cudamalloc and co. actually align to _256_ bytes. Worth + // trying this in the future to see if it reduces memory coalescing. 
+ const ALIGNMENT: u64 = 32; - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); - } + fn get(&mut self, handle: &StorageHandle) -> Self::Resource { + let buffer = self.memory.get(&handle.id).unwrap(); + WgpuResource::new(buffer.clone(), handle.offset(), handle.size()) } - /// The memory resource that can be allocated for wgpu. - #[derive(new)] - pub struct WgpuResource { - /// The wgpu buffer. - pub buffer: Arc, - - offset: u64, - size: u64, + fn alloc(&mut self, size: u64) -> StorageHandle { + let id = StorageId::new(); + let buffer = Pdrc::new(self.device.create_buffer(&wgpu::BufferDescriptor { + label: None, + size, + usage: wgpu::BufferUsages::COPY_DST + | wgpu::BufferUsages::STORAGE + | wgpu::BufferUsages::COPY_SRC + | wgpu::BufferUsages::INDIRECT, + mapped_at_creation: false, + })); + + self.memory.insert(id, buffer); + StorageHandle::new(id, StorageUtilization { offset: 0, size }) } - impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource { - let binding = wgpu::BufferBinding { - buffer: &self.buffer, - offset: self.offset, - size: Some( - NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), - ), - }; - wgpu::BindingResource::Buffer(binding) - } - - /// Return the buffer size. - pub fn size(&self) -> u64 { - self.size - } - - /// Return the buffer offset. - pub fn offset(&self) -> u64 { - self.offset - } + fn dealloc(&mut self, id: StorageId) { + self.deallocations.push(id); } } -#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] -mod _impl { - use std::rc::Rc; - - use crate::{compiler::base::WgpuCompiler, compute::server::ServerCommand}; - - use super::*; - - /// Buffer storage for wgpu. - pub struct WgpuStorage { - deallocations: Vec, - tx: std::sync::mpsc::Sender>, - } - - /// Keeps actual wgpu buffer references in a hashmap with ids as key. - impl WgpuStorage { - /// Create a new storage on the given [device](wgpu::Device). 
- pub fn new(tx: std::sync::mpsc::Sender>) -> Self { - Self { - deallocations: Vec::new(), - tx, - } - } - - /// Actually deallocates buffers tagged to be deallocated. - pub fn perform_deallocations(&mut self) { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::PerformDeallocations { - tx, - deallocations: self.deallocations.drain(..).collect(), - }) - .expect("Failed to send command to the wgpu server"); - - if futures::executor::block_on(rx).is_err() { - panic!("Failed to receive the response from the WgpuServerInner") - } - } - } - - impl ComputeStorage for WgpuStorage { - type Resource = WgpuResource; - - // 32 bytes is enough to handle a double4 worth of alignment. - // See: https://github.com/gfx-rs/wgpu/issues/3508 - // NB: cudamalloc and co. actually align to _256_ bytes. Worth - // trying this in the future to see if it reduces memory coalescing. - const ALIGNMENT: u64 = 32; - - fn get(&mut self, handle: &StorageHandle) -> Self::Resource { - WgpuResource::new(handle.id, handle.offset(), handle.size()) - } - - fn alloc(&mut self, size: u64) -> StorageHandle { - let (tx, rx) = futures::channel::oneshot::channel(); - self.tx - .send(ServerCommand::Alloc { tx, size }) - .expect("Failed to send command to the wgpu server"); - - if let Ok(handle) = futures::executor::block_on(rx) { - handle - } else { - panic!("Failed to receive the response from the WgpuServerInner") - } - } +/// The memory resource that can be allocated for wgpu. +#[derive(new)] +pub struct WgpuResource { + /// The wgpu buffer. + pub buffer: Pdrc, + offset: u64, + size: u64, +} - fn dealloc(&mut self, id: StorageId) { - self.deallocations.push(id); - } +impl WgpuResource { + /// Return the binding view of the buffer. 
+ pub fn as_wgpu_bind_resource(&self) -> wgpu::BindingResource { + let binding = wgpu::BufferBinding { + buffer: &self.buffer, + offset: self.offset, + size: Some( + NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), + ), + }; + wgpu::BindingResource::Buffer(binding) } - /// The memory resource that can be allocated for wgpu. - #[derive(new)] - pub struct WgpuResource { - /// The storage id. - pub buffer: StorageId, - offset: u64, - size: u64, + /// Return the buffer size. + pub fn size(&self) -> u64 { + self.size } - impl WgpuResource { - /// Return the binding view of the buffer. - pub fn as_wgpu_bind_resource<'a>( - &self, - buffers: &'a HashMap>, - ) -> wgpu::BindingResource<'a> { - let buffer = buffers.get(&self.buffer).expect("Buffer does not exist"); - let binding = wgpu::BufferBinding { - buffer, - offset: self.offset, - size: Some( - NonZeroU64::new(self.size).expect("0 size resources are not yet supported."), - ), - }; - wgpu::BindingResource::Buffer(binding) - } - - /// Return the buffer size. - pub fn size(&self) -> u64 { - self.size - } - - /// Return the buffer offset. - pub fn offset(&self) -> u64 { - self.offset - } + /// Return the buffer offset. 
+ pub fn offset(&self) -> u64 { + self.offset } } - -pub use _impl::*; diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 54e9bfdcc..8db55462b 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -1,4 +1,6 @@ use std::marker::PhantomData; +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +use std::{cell::RefCell, rc::Rc}; use crate::{ compiler::{base::WgpuCompiler, wgsl::WgslCompiler}, @@ -6,10 +8,12 @@ use crate::{ AutoGraphicsApi, GraphicsApi, Pdrc, WgpuDevice, }; use cubecl_common::future; -use cubecl_core::{Feature, Runtime}; +use cubecl_core::{channel::ComputeChannel, server::ComputeServer, Feature, Runtime}; +use cubecl_runtime::client::ComputeClient; pub use cubecl_runtime::memory_management::MemoryConfiguration; use cubecl_runtime::DeviceProperties; -use cubecl_runtime::{channel::MutexComputeChannel, client::ComputeClient, ComputeRuntime}; +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +use cubecl_runtime::{channel::MutexComputeChannel, ComputeRuntime}; use cubecl_runtime::{ memory_management::{MemoryDeviceProperties, MemoryManagement}, storage::ComputeStorage, @@ -24,7 +28,13 @@ pub struct WgpuRuntime(PhantomData); type Server = WgpuServer; +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +thread_local! { + static LOCAL_RUNTIME: RefCell>>> = RefCell::new(hashbrown::HashMap::default()); +} + /// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
+#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] static RUNTIME: ComputeRuntime> = ComputeRuntime::new(); @@ -32,11 +42,78 @@ impl Runtime for WgpuRuntime { type Compiler = WgslCompiler; type Server = Server; + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] type Channel = MutexComputeChannel; + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + type Channel = ThreadLocalChannel; type Device = WgpuDevice; fn client(device: &Self::Device) -> ComputeClient { - RUNTIME.client(device, move || create_client(device)) + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { + RUNTIME.client(device, move || { + let setup = future::block_on(create_setup_for_device::< + AutoGraphicsApi, + WgslCompiler, + >(device)); + create_client_on_setup(setup, RuntimeOptions::default()) + }) + } + + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] + { + let setup = future::block_on(create_setup_for_device::( + device, + )); + + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT + .max(limits.min_storage_buffer_offset_alignment as u64), + }; + + let options = RuntimeOptions::default(); + let memory_management = { + let mem_props = mem_props.clone(); + let config = options.memory_config; + let storage = WgpuStorage::new(setup.device.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let server = crate::compute::WgpuServer::new( + memory_management, + setup.device.clone(), + setup.queue, + options.tasks_max, + ); + + LOCAL_RUNTIME.with(|runtime| { + runtime + .borrow_mut() + .insert(device.clone(), Rc::new(RefCell::new(server))); + }); + + let features = setup.adapter.features(); + let mut device_props = DeviceProperties::new(&[], mem_props); + + if features.contains(wgpu::Features::SUBGROUP) + && setup.adapter.get_info().device_type != 
wgpu::DeviceType::Cpu + { + device_props.register_feature(Feature::Subcube); + } + ::register_features( + &setup.adapter, + &setup.device, + &mut device_props, + ); + + ComputeClient::new( + ThreadLocalChannel { + device: device.clone(), + }, + device_props, + ) + } } fn name() -> &'static str { @@ -92,79 +169,6 @@ pub struct WgpuSetup { pub queue: Pdrc, } -pub fn create_client( - device: &WgpuDevice, -) -> ComputeClient, MutexComputeChannel>> { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - let setup = future::block_on(create_setup_for_device::( - device, - )); - create_client_on_setup(setup, RuntimeOptions::default()) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let (tx_once, mut rx_once) = futures::channel::oneshot::channel(); - - { - let device = device.clone(); - - rayon::spawn(move || { - let (tx, rx) = std::sync::mpsc::channel(); - let setup = future::block_on(create_setup_for_device::< - AutoGraphicsApi, - WgslCompiler, - >(&device)); - - let limits = setup.device.limits(); - let mem_props = MemoryDeviceProperties { - max_page_size: limits.max_storage_buffer_binding_size as u64, - alignment: WgpuStorage::::ALIGNMENT - .max(limits.min_storage_buffer_offset_alignment as u64), - }; - - let options = RuntimeOptions::default(); - let memory_management = { - let mem_props = mem_props.clone(); - let config = options.memory_config; - let storage = WgpuStorage::new(tx.clone()); - MemoryManagement::from_configuration(storage, mem_props, config) - }; - let mut server = crate::compute::WgpuServerInner::new( - memory_management, - setup.device.clone(), - setup.queue, - options.tasks_max, - ); - let channel = MutexComputeChannel::new(WgpuServer::new(tx)); - - let features = setup.adapter.features(); - let mut device_props = DeviceProperties::new(&[], mem_props); - - if features.contains(wgpu::Features::SUBGROUP) - && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu - { - 
device_props.register_feature(Feature::Subcube); - } - C::register_features(&setup.adapter, &setup.device, &mut device_props); - - tx_once - .send(ComputeClient::new(channel, device_props)) - .expect("Failed to send back client to the calling thread"); - - server.handle_commands(rx) - }); - } - - if let Ok(client) = futures::executor::block_on(rx_once) { - client - } else { - panic!("Failed to get the client from the wgpu thread") - } - } -} - /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`]. /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries. #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] @@ -409,3 +413,91 @@ fn get_device_override() -> Option { override_device }) } + +#[derive(Debug, Clone)] +pub struct ThreadLocalChannel { + device: WgpuDevice, +} + +impl ComputeChannel for ThreadLocalChannel { + fn read( + &self, + binding: cubecl_core::server::Binding, + ) -> impl std::future::Future> { + LOCAL_RUNTIME.with(|runtime| { + let server = runtime.borrow()[&self.device].clone(); + async move { server.borrow_mut().read(binding).await } + }) + } + + fn get_resource( + &self, + binding: cubecl_core::server::Binding, + ) -> cubecl_runtime::storage::BindingResource { + LOCAL_RUNTIME.with(|runtime| { + runtime.borrow()[&self.device] + .borrow_mut() + .get_resource(binding) + }) + } + + fn create(&self, data: &[u8]) -> cubecl_core::server::Handle { + LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().create(data)) + } + + fn empty(&self, size: usize) -> cubecl_core::server::Handle { + LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().empty(size)) + } + + unsafe fn execute( + &self, + kernel: ::Kernel, + count: cubecl_core::CubeCount, + bindings: Vec, + mode: cubecl_core::ExecutionMode, + ) { + LOCAL_RUNTIME.with(|runtime| { + let runtime = runtime.borrow(); + let mut server = runtime[&self.device].borrow_mut(); + unsafe { server.execute(kernel, count, 
bindings, mode) } + }) + } + + fn flush(&self) { + LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().flush()) + } + + fn sync(&self) -> impl std::future::Future { + LOCAL_RUNTIME.with(|runtime| { + let server = runtime.borrow()[&self.device].clone(); + async move { server.borrow_mut().sync().await } + }) + } + + fn sync_elapsed(&self) -> impl std::future::Future { + LOCAL_RUNTIME.with(|runtime| { + let server = runtime.borrow()[&self.device].clone(); + async move { server.borrow_mut().sync_elapsed().await } + }) + } + + fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage { + LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().memory_usage()) + } + + fn enable_timestamps(&self) { + LOCAL_RUNTIME.with(|runtime| { + runtime.borrow()[&self.device] + .borrow_mut() + .enable_timestamps() + }) + } + + fn disable_timestamps(&self) { + LOCAL_RUNTIME.with(|runtime| { + runtime.borrow()[&self.device] + .borrow_mut() + .disable_timestamps() + }) + } +} From cfe243e2d14d1c2644831d2249c605d5376098f0 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Thu, 14 Nov 2024 12:14:06 +0530 Subject: [PATCH 05/10] Added back the send bound --- crates/cubecl-runtime/src/channel/base.rs | 2 +- crates/cubecl-runtime/src/channel/cell.rs | 5 + crates/cubecl-wgpu/src/runtime.rs | 110 +++++++++++++++++++--- 3 files changed, 104 insertions(+), 13 deletions(-) diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index d7b069e81..60a4aa2f2 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -10,7 +10,7 @@ use alloc::vec::Vec; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety -pub trait ComputeChannel: Clone + core::fmt::Debug { +pub trait ComputeChannel: Send + Clone + core::fmt::Debug { /// Given a binding, returns owned resource as bytes fn read(&self, binding: Binding) -> impl 
Future>; diff --git a/crates/cubecl-runtime/src/channel/cell.rs b/crates/cubecl-runtime/src/channel/cell.rs index f0e2b5b9d..98a8bb83a 100644 --- a/crates/cubecl-runtime/src/channel/cell.rs +++ b/crates/cubecl-runtime/src/channel/cell.rs @@ -108,3 +108,8 @@ where self.server.borrow_mut().disable_timestamps(); } } + +/// This is unsafe, since no concurrency is supported by the `RefCell` channel. +/// However using this channel should only be done in single threaded environments such as `no-std`. +unsafe impl Send for RefCellComputeChannel {} +unsafe impl Sync for RefCellComputeChannel {} diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 8db55462b..2d50b2efc 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -414,18 +414,56 @@ fn get_device_override() -> Option { }) } +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] #[derive(Debug, Clone)] pub struct ThreadLocalChannel { device: WgpuDevice, } +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +impl ThreadLocalChannel { + fn make_server(device: &WgpuDevice) -> Rc> { + let setup = future::block_on(create_setup_for_device::( + device, + )); + + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT + .max(limits.min_storage_buffer_offset_alignment as u64), + }; + + let options = RuntimeOptions::default(); + let memory_management = { + let mem_props = mem_props.clone(); + let config = options.memory_config; + let storage = WgpuStorage::new(setup.device.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let server = crate::compute::WgpuServer::new( + memory_management, + setup.device.clone(), + setup.queue, + options.tasks_max, + ); + + Rc::new(RefCell::new(server)) + } +} + +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] impl ComputeChannel for 
ThreadLocalChannel { fn read( &self, binding: cubecl_core::server::Binding, ) -> impl std::future::Future> { LOCAL_RUNTIME.with(|runtime| { - let server = runtime.borrow()[&self.device].clone(); + let server = runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .clone(); async move { server.borrow_mut().read(binding).await } }) } @@ -435,18 +473,35 @@ impl ComputeChannel for ThreadLocalChannel { binding: cubecl_core::server::Binding, ) -> cubecl_runtime::storage::BindingResource { LOCAL_RUNTIME.with(|runtime| { - runtime.borrow()[&self.device] + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) .borrow_mut() .get_resource(binding) }) } fn create(&self, data: &[u8]) -> cubecl_core::server::Handle { - LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().create(data)) + LOCAL_RUNTIME.with(|runtime| { + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .borrow_mut() + .create(data) + }) } fn empty(&self, size: usize) -> cubecl_core::server::Handle { - LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().empty(size)) + LOCAL_RUNTIME.with(|runtime| { + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .borrow_mut() + .empty(size) + }) } unsafe fn execute( @@ -457,37 +512,65 @@ impl ComputeChannel for ThreadLocalChannel { mode: cubecl_core::ExecutionMode, ) { LOCAL_RUNTIME.with(|runtime| { - let runtime = runtime.borrow(); - let mut server = runtime[&self.device].borrow_mut(); + let mut runtime = runtime.borrow_mut(); + let mut server = runtime + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .borrow_mut(); unsafe { server.execute(kernel, count, bindings, mode) } }) } fn flush(&self) { - LOCAL_RUNTIME.with(|runtime| 
runtime.borrow()[&self.device].borrow_mut().flush()) + LOCAL_RUNTIME.with(|runtime| { + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .borrow_mut() + .flush() + }) } fn sync(&self) -> impl std::future::Future { LOCAL_RUNTIME.with(|runtime| { - let server = runtime.borrow()[&self.device].clone(); + let server = runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .clone(); async move { server.borrow_mut().sync().await } }) } fn sync_elapsed(&self) -> impl std::future::Future { LOCAL_RUNTIME.with(|runtime| { - let server = runtime.borrow()[&self.device].clone(); + let server = runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .clone(); async move { server.borrow_mut().sync_elapsed().await } }) } fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage { - LOCAL_RUNTIME.with(|runtime| runtime.borrow()[&self.device].borrow_mut().memory_usage()) + LOCAL_RUNTIME.with(|runtime| { + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) + .borrow_mut() + .memory_usage() + }) } fn enable_timestamps(&self) { LOCAL_RUNTIME.with(|runtime| { - runtime.borrow()[&self.device] + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) .borrow_mut() .enable_timestamps() }) @@ -495,7 +578,10 @@ impl ComputeChannel for ThreadLocalChannel { fn disable_timestamps(&self) { LOCAL_RUNTIME.with(|runtime| { - runtime.borrow()[&self.device] + runtime + .borrow_mut() + .entry(self.device.clone()) + .or_insert_with(|| Self::make_server(&self.device)) .borrow_mut() .disable_timestamps() }) From 33484dbc317d5d53210f5b1825ea57c110f5f153 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Thu, 14 Nov 2024 12:30:15 +0530 Subject: [PATCH 06/10] Added sync --- crates/cubecl-runtime/src/channel/base.rs | 2 
+- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/cubecl-runtime/src/channel/base.rs b/crates/cubecl-runtime/src/channel/base.rs index 60a4aa2f2..5f5710776 100644 --- a/crates/cubecl-runtime/src/channel/base.rs +++ b/crates/cubecl-runtime/src/channel/base.rs @@ -10,7 +10,7 @@ use alloc::vec::Vec; /// The ComputeChannel trait links the ComputeClient to the ComputeServer /// while ensuring thread-safety -pub trait ComputeChannel: Send + Clone + core::fmt::Debug { +pub trait ComputeChannel: Sync + Send + Clone + core::fmt::Debug { /// Given a binding, returns owned resource as bytes fn read(&self, binding: Binding) -> impl Future>; From d1173f315cf1b65e3a33c10967cf19c3fe3da376 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Thu, 14 Nov 2024 17:52:06 +0530 Subject: [PATCH 07/10] Switched to futures_lite for blocking --- Cargo.toml | 3 +- crates/cubecl-wgpu/Cargo.toml | 4 +- crates/cubecl-wgpu/src/compute/server.rs | 14 +- crates/cubecl-wgpu/src/runtime.rs | 211 ++++++++++------------- 4 files changed, 102 insertions(+), 130 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 38b97d290..8c2353b1c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,7 +27,6 @@ serde_json = { version = "1.0.119", default-features = false } dashmap = "5.5.3" hashbrown = "0.14.5" spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] } -rayon = "1" getrandom = { version = "0.2.15", default-features = false } rand = { version = "0.8.5", default-features = false, features = [ @@ -38,6 +37,7 @@ async-channel = "2.3" dirs = "5.0.1" md5 = "0.7.0" sanitize-filename = "0.5" +wasm-bindgen = "0.2" wasm-bindgen-futures = "0.4.45" weak-table = "0.3" web-time = "1.1.0" @@ -74,7 +74,6 @@ pretty_assertions = "1.4" # Async embassy-futures = { version = "0.1.1" } # for no-std futures-lite = { version = "2.3.0", default-features = false } -futures = "0.3.31" [profile.dev] opt-level = 2 diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index 
b6cddd113..ef96295df 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -39,6 +39,7 @@ wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] } async-channel = { workspace = true } derive-new = { workspace = true } +futures-lite = { workspace = true } hashbrown = { workspace = true } log = { workspace = true } web-time = { workspace = true } @@ -46,8 +47,7 @@ web-time = { workspace = true } cfg-if = { workspace = true } [target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies] -futures = { workspace = true } -rayon = { workspace = true } +wasm-bindgen = { workspace = true } [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 7fc23f980..322072560 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -25,6 +25,7 @@ pub struct WgpuServer { memory_management: MemoryManagement, pub(crate) device: Pdrc, queue: Pdrc, + pub(crate) adapter: Pdrc, encoder: CommandEncoder, current_pass: Option>, tasks_count: usize, @@ -90,6 +91,7 @@ impl WgpuServer { memory_management: MemoryManagement, device: Pdrc, queue: Pdrc, + adapter: Pdrc, tasks_max: usize, ) -> Self { let logger = DebugLogger::default(); @@ -99,18 +101,22 @@ impl WgpuServer { timestamps.enable(&device); } + let encoder = create_encoder(&device); + let poll = WgpuPoll::new(device.clone()); + Self { memory_management, - device: device.clone(), - queue: queue.clone(), - encoder: create_encoder(&device), + device, + queue, + adapter, + encoder, current_pass: None, tasks_count: 0, storage_locked: MemoryLock::default(), pipelines: HashMap::new(), tasks_max, logger, - poll: WgpuPoll::new(device.clone()), + poll, duration_profiled: None, timestamps, _compiler: PhantomData, diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs 
index 2d50b2efc..3782b159b 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -7,16 +7,14 @@ use crate::{ compute::{WgpuServer, WgpuStorage}, AutoGraphicsApi, GraphicsApi, Pdrc, WgpuDevice, }; -use cubecl_common::future; use cubecl_core::{channel::ComputeChannel, server::ComputeServer, Feature, Runtime}; -use cubecl_runtime::client::ComputeClient; -pub use cubecl_runtime::memory_management::MemoryConfiguration; -use cubecl_runtime::DeviceProperties; #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] -use cubecl_runtime::{channel::MutexComputeChannel, ComputeRuntime}; +use cubecl_runtime::channel::MutexComputeChannel; use cubecl_runtime::{ - memory_management::{MemoryDeviceProperties, MemoryManagement}, + client::ComputeClient, + memory_management::{MemoryConfiguration, MemoryDeviceProperties, MemoryManagement}, storage::ComputeStorage, + ComputeRuntime, DeviceProperties, }; use wgpu::RequestAdapterOptions; @@ -30,9 +28,11 @@ type Server = WgpuServer; #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] thread_local! { - static LOCAL_RUNTIME: RefCell>>> = RefCell::new(hashbrown::HashMap::default()); + static LOCAL_DEVICE: RefCell>>> = RefCell::new(hashbrown::HashMap::default()); } +static RUNTIME: ComputeRuntime = ComputeRuntime::new(); + /// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] static RUNTIME: ComputeRuntime> = @@ -49,71 +49,55 @@ impl Runtime for WgpuRuntime { type Device = WgpuDevice; fn client(device: &Self::Device) -> ComputeClient { - #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] - { - RUNTIME.client(device, move || { + RUNTIME.client(device, move || { + #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] + { let setup = future::block_on(create_setup_for_device::< AutoGraphicsApi, WgslCompiler, >(device)); create_client_on_setup(setup, RuntimeOptions::default()) - }) - } - - #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] - { - let setup = future::block_on(create_setup_for_device::( - device, - )); - - let limits = setup.device.limits(); - let mem_props = MemoryDeviceProperties { - max_page_size: limits.max_storage_buffer_binding_size as u64, - alignment: WgpuStorage::ALIGNMENT - .max(limits.min_storage_buffer_offset_alignment as u64), - }; - - let options = RuntimeOptions::default(); - let memory_management = { - let mem_props = mem_props.clone(); - let config = options.memory_config; - let storage = WgpuStorage::new(setup.device.clone()); - MemoryManagement::from_configuration(storage, mem_props, config) - }; - let server = crate::compute::WgpuServer::new( - memory_management, - setup.device.clone(), - setup.queue, - options.tasks_max, - ); - - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() - .insert(device.clone(), Rc::new(RefCell::new(server))); - }); - - let features = setup.adapter.features(); - let mut device_props = DeviceProperties::new(&[], mem_props); + } - if features.contains(wgpu::Features::SUBGROUP) - && setup.adapter.get_info().device_type != wgpu::DeviceType::Cpu + #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] { - device_props.register_feature(Feature::Subcube); + let server = LOCAL_DEVICE.with_borrow_mut(|runtime| { + runtime + .entry(device.clone()) + 
.or_insert_with(|| ThreadLocalChannel::make_server(device)) + .clone() + }); + let server = server.borrow(); + + let limits = server.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT + .max(limits.min_storage_buffer_offset_alignment as u64), + }; + + let features = server.device.features(); + let mut device_props = DeviceProperties::new(&[], mem_props); + + if features.contains(wgpu::Features::SUBGROUP) + && server.adapter.get_info().device_type != wgpu::DeviceType::Cpu + { + device_props.register_feature(Feature::Subcube); + } + ::register_features( + &server.adapter, + &server.device, + &mut device_props, + ); + + ComputeClient::new( + ThreadLocalChannel { + device: device.clone(), + }, + device_props, + ) } - ::register_features( - &setup.adapter, - &setup.device, - &mut device_props, - ); - - ComputeClient::new( - ThreadLocalChannel { - device: device.clone(), - }, - device_props, - ) - } + }) } fn name() -> &'static str { @@ -423,9 +407,10 @@ pub struct ThreadLocalChannel { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] impl ThreadLocalChannel { fn make_server(device: &WgpuDevice) -> Rc> { - let setup = future::block_on(create_setup_for_device::( - device, - )); + let setup = futures_lite::future::block_on(create_setup_for_device::< + AutoGraphicsApi, + WgslCompiler, + >(device)); let limits = setup.device.limits(); let mem_props = MemoryDeviceProperties { @@ -443,8 +428,9 @@ impl ThreadLocalChannel { }; let server = crate::compute::WgpuServer::new( memory_management, - setup.device.clone(), + setup.device, setup.queue, + setup.adapter, options.tasks_max, ); @@ -458,9 +444,8 @@ impl ComputeChannel for ThreadLocalChannel { &self, binding: cubecl_core::server::Binding, ) -> impl std::future::Future> { - LOCAL_RUNTIME.with(|runtime| { + LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .borrow_mut() 
.entry(self.device.clone()) .or_insert_with(|| Self::make_server(&self.device)) .clone(); @@ -472,35 +457,29 @@ impl ComputeChannel for ThreadLocalChannel { &self, binding: cubecl_core::server::Binding, ) -> cubecl_runtime::storage::BindingResource { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .get_resource(binding) + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().get_resource(binding) }) } fn create(&self, data: &[u8]) -> cubecl_core::server::Handle { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .create(data) + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().create(data) }) } fn empty(&self, size: usize) -> cubecl_core::server::Handle { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .empty(size) + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().empty(size) }) } @@ -511,31 +490,26 @@ impl ComputeChannel for ThreadLocalChannel { bindings: Vec, mode: cubecl_core::ExecutionMode, ) { - LOCAL_RUNTIME.with(|runtime| { - let mut runtime = runtime.borrow_mut(); - let mut server = runtime + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut(); - unsafe { server.execute(kernel, count, bindings, mode) } + .or_insert_with(|| Self::make_server(&self.device)); + unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) } }) } fn flush(&self) { - 
LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .flush() + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().flush() }) } fn sync(&self) -> impl std::future::Future { - LOCAL_RUNTIME.with(|runtime| { + LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .borrow_mut() .entry(self.device.clone()) .or_insert_with(|| Self::make_server(&self.device)) .clone(); @@ -544,9 +518,8 @@ impl ComputeChannel for ThreadLocalChannel { } fn sync_elapsed(&self) -> impl std::future::Future { - LOCAL_RUNTIME.with(|runtime| { + LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .borrow_mut() .entry(self.device.clone()) .or_insert_with(|| Self::make_server(&self.device)) .clone(); @@ -555,35 +528,29 @@ impl ComputeChannel for ThreadLocalChannel { } fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .memory_usage() + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().memory_usage() }) } fn enable_timestamps(&self) { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - .borrow_mut() - .enable_timestamps() + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().enable_timestamps() }) } fn disable_timestamps(&self) { - LOCAL_RUNTIME.with(|runtime| { - runtime - .borrow_mut() + LOCAL_DEVICE.with_borrow_mut(|runtime| { + let server = runtime .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) - 
.borrow_mut() - .disable_timestamps() + .or_insert_with(|| Self::make_server(&self.device)); + server.borrow_mut().disable_timestamps() }) } } From 5412ef05fc0568ef1daaa91e348edf85f4b02b0a Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Thu, 14 Nov 2024 19:29:09 +0530 Subject: [PATCH 08/10] Fixed wasm32 non atomic --- crates/cubecl-wgpu/src/compute/server.rs | 1 + crates/cubecl-wgpu/src/runtime.rs | 8 +++++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/crates/cubecl-wgpu/src/compute/server.rs b/crates/cubecl-wgpu/src/compute/server.rs index 322072560..6a3dbec4e 100644 --- a/crates/cubecl-wgpu/src/compute/server.rs +++ b/crates/cubecl-wgpu/src/compute/server.rs @@ -25,6 +25,7 @@ pub struct WgpuServer { memory_management: MemoryManagement, pub(crate) device: Pdrc, queue: Pdrc, + #[allow(unused)] pub(crate) adapter: Pdrc, encoder: CommandEncoder, current_pass: Option>, diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 3782b159b..95fe52a6b 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -7,7 +7,11 @@ use crate::{ compute::{WgpuServer, WgpuStorage}, AutoGraphicsApi, GraphicsApi, Pdrc, WgpuDevice, }; -use cubecl_core::{channel::ComputeChannel, server::ComputeServer, Feature, Runtime}; +#[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] +use cubecl_core::future; +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +use cubecl_core::{channel::ComputeChannel, server::ComputeServer}; +use cubecl_core::{Feature, Runtime}; #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] use cubecl_runtime::channel::MutexComputeChannel; use cubecl_runtime::{ @@ -31,6 +35,7 @@ thread_local! 
{ static LOCAL_DEVICE: RefCell>>> = RefCell::new(hashbrown::HashMap::default()); } +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] static RUNTIME: ComputeRuntime = ComputeRuntime::new(); /// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). @@ -214,6 +219,7 @@ pub(crate) fn create_client_on_setup( memory_management, setup.device.clone(), setup.queue, + setup.adapter.clone(), options.tasks_max, ); let channel = MutexComputeChannel::new(server); From 617227075870a4c25467b45a30a4eb348f289389 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Fri, 15 Nov 2024 17:35:17 +0530 Subject: [PATCH 09/10] Make a async init function required for each thread --- crates/cubecl-wgpu/Cargo.toml | 3 - crates/cubecl-wgpu/src/runtime.rs | 100 +++++++++++++++++++++++------- 2 files changed, 78 insertions(+), 25 deletions(-) diff --git a/crates/cubecl-wgpu/Cargo.toml b/crates/cubecl-wgpu/Cargo.toml index ef96295df..1169a727e 100644 --- a/crates/cubecl-wgpu/Cargo.toml +++ b/crates/cubecl-wgpu/Cargo.toml @@ -46,9 +46,6 @@ web-time = { workspace = true } cfg-if = { workspace = true } -[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies] -wasm-bindgen = { workspace = true } - [dev-dependencies] cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [ "export_tests", diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 95fe52a6b..898a3c229 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -68,8 +68,8 @@ impl Runtime for WgpuRuntime { { let server = LOCAL_DEVICE.with_borrow_mut(|runtime| { runtime - .entry(device.clone()) - .or_insert_with(|| ThreadLocalChannel::make_server(device)) + .get(device) + .expect(&format!("The wgpu server for {device:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread")) .clone() }); let server = server.borrow(); @@ -158,6 +158,32 @@ pub struct WgpuSetup { pub queue: Pdrc, } +#[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] +pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) { + let setup = create_setup_for_device::(&device).await; + + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64), + }; + let memory_management = { + let mem_props = mem_props.clone(); + let config = options.memory_config; + let storage = WgpuStorage::new(setup.device.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let server = crate::compute::WgpuServer::new( + memory_management, + setup.device, + setup.queue, + setup.adapter, + options.tasks_max, + ); + + LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server)))); +} + /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`]. /// Useful when you want to share a device between CubeCL and other wgpu-dependent libraries. #[cfg(not(all(target_arch = "wasm32", target_feature = "atomics")))] @@ -465,8 +491,11 @@ impl ComputeChannel for ThreadLocalChannel { ) -> cubecl_runtime::storage::BindingResource { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().get_resource(binding) }) } @@ -474,8 +503,11 @@ impl ComputeChannel for ThreadLocalChannel { fn create(&self, data: &[u8]) -> cubecl_core::server::Handle { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().create(data) }) } @@ -483,8 +515,11 @@ impl ComputeChannel for ThreadLocalChannel { fn empty(&self, size: usize) -> cubecl_core::server::Handle { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().empty(size) }) } @@ -498,8 +533,11 @@ impl ComputeChannel for ThreadLocalChannel { ) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); unsafe { server.borrow_mut().execute(kernel, count, bindings, mode) } }) } @@ -507,8 +545,11 @@ impl ComputeChannel for ThreadLocalChannel { fn flush(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().flush() }) } @@ -516,8 +557,11 @@ impl ComputeChannel for ThreadLocalChannel { fn sync(&self) -> impl std::future::Future { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )) .clone(); async move { server.borrow_mut().sync().await } }) @@ -526,8 +570,11 @@ impl ComputeChannel for ThreadLocalChannel { fn sync_elapsed(&self) -> impl std::future::Future { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)) + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )) .clone(); async move { server.borrow_mut().sync_elapsed().await } }) @@ -536,8 +583,11 @@ impl ComputeChannel for ThreadLocalChannel { fn memory_usage(&self) -> cubecl_runtime::memory_management::MemoryUsage { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().memory_usage() }) } @@ -545,8 +595,11 @@ impl ComputeChannel for ThreadLocalChannel { fn enable_timestamps(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. `init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().enable_timestamps() }) } @@ -554,8 +607,11 @@ impl ComputeChannel for ThreadLocalChannel { fn disable_timestamps(&self) { LOCAL_DEVICE.with_borrow_mut(|runtime| { let server = runtime - .entry(self.device.clone()) - .or_insert_with(|| Self::make_server(&self.device)); + .get(&self.device) + .expect(&format!( + "The wgpu server for {:?} was not initialized with `init_thread_server`. 
`init_thread_server` needs to be called once on each thread before any computation is done on that thread", + self.device, + )); server.borrow_mut().disable_timestamps() }) } From 7a0c979b7e190b2dad4305ad60e8430f756e9dd4 Mon Sep 17 00:00:00 2001 From: Sarthak Singh Date: Fri, 15 Nov 2024 17:38:58 +0530 Subject: [PATCH 10/10] Added check so that we don't re init the server --- crates/cubecl-wgpu/src/runtime.rs | 43 +++++++++++++++++-------------- 1 file changed, 23 insertions(+), 20 deletions(-) diff --git a/crates/cubecl-wgpu/src/runtime.rs b/crates/cubecl-wgpu/src/runtime.rs index 898a3c229..584660588 100644 --- a/crates/cubecl-wgpu/src/runtime.rs +++ b/crates/cubecl-wgpu/src/runtime.rs @@ -160,28 +160,31 @@ pub struct WgpuSetup { #[cfg(all(target_arch = "wasm32", target_feature = "atomics"))] pub async fn init_thread_server(device: WgpuDevice, options: RuntimeOptions) { - let setup = create_setup_for_device::(&device).await; + if !LOCAL_DEVICE.with_borrow(|map| map.contains_key(&device)) { + let setup = create_setup_for_device::(&device).await; - let limits = setup.device.limits(); - let mem_props = MemoryDeviceProperties { - max_page_size: limits.max_storage_buffer_binding_size as u64, - alignment: WgpuStorage::ALIGNMENT.max(limits.min_storage_buffer_offset_alignment as u64), - }; - let memory_management = { - let mem_props = mem_props.clone(); - let config = options.memory_config; - let storage = WgpuStorage::new(setup.device.clone()); - MemoryManagement::from_configuration(storage, mem_props, config) - }; - let server = crate::compute::WgpuServer::new( - memory_management, - setup.device, - setup.queue, - setup.adapter, - options.tasks_max, - ); + let limits = setup.device.limits(); + let mem_props = MemoryDeviceProperties { + max_page_size: limits.max_storage_buffer_binding_size as u64, + alignment: WgpuStorage::ALIGNMENT + .max(limits.min_storage_buffer_offset_alignment as u64), + }; + let memory_management = { + let mem_props = mem_props.clone(); + let 
config = options.memory_config; + let storage = WgpuStorage::new(setup.device.clone()); + MemoryManagement::from_configuration(storage, mem_props, config) + }; + let server = crate::compute::WgpuServer::new( + memory_management, + setup.device, + setup.queue, + setup.adapter, + options.tasks_max, + ); - LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server)))); + LOCAL_DEVICE.with_borrow_mut(|map| map.insert(device, Rc::new(RefCell::new(server)))); + } } /// Create a [`WgpuDevice`] on an existing [`WgpuSetup`].