Added support for target_feature atomics on wasm32 #239

Open
wants to merge 10 commits into main
1 change: 1 addition & 0 deletions Cargo.toml
@@ -37,6 +37,7 @@ async-channel = "2.3"
dirs = "5.0.1"
md5 = "0.7.0"
sanitize-filename = "0.5"
wasm-bindgen = "0.2"
wasm-bindgen-futures = "0.4.45"
weak-table = "0.3"
web-time = "1.1.0"
8 changes: 4 additions & 4 deletions crates/cubecl-runtime/src/channel/base.rs
@@ -10,9 +10,9 @@ use alloc::vec::Vec;

/// The ComputeChannel trait links the ComputeClient to the ComputeServer
/// while ensuring thread-safety
pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send + Sync {
pub trait ComputeChannel<Server: ComputeServer>: Sync + Send + Clone + core::fmt::Debug {
/// Given a binding, returns owned resource as bytes
fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send;
fn read(&self, binding: Binding) -> impl Future<Output = Vec<u8>>;

/// Given a resource handle, return the storage resource.
fn get_resource(&self, binding: Binding) -> BindingResource<Server>;
@@ -40,12 +40,12 @@ pub trait ComputeChannel<Server: ComputeServer>: Clone + core::fmt::Debug + Send
fn flush(&self);

/// Wait for the completion of every task in the server.
fn sync(&self) -> impl Future<Output = ()> + Send;
fn sync(&self) -> impl Future<Output = ()>;

/// Wait for the completion of every task in the server.
///
/// Returns the (approximate) total amount of GPU work done since the last sync.
fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult> + Send;
fn sync_elapsed(&self) -> impl Future<Output = TimestampsResult>;

/// Get the current memory usage of the server.
fn memory_usage(&self) -> crate::memory_management::MemoryUsage;
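The change above drops the `Send` bound from the futures returned by `ComputeChannel`. On wasm32 with the `atomics` target feature, wgpu's objects are no longer `Send`/`Sync`, so these futures cannot promise `Send`; they are driven on the current thread instead. A minimal sketch of the calling pattern this enables — the helper and callback are hypothetical, not part of this PR:

```rust
use cubecl_runtime::channel::ComputeChannel;
use cubecl_runtime::server::{Binding, ComputeServer};

// Hypothetical helper: kick off a read without blocking the main thread.
// `spawn_local` only requires the future to be `'static`, not `Send`, which is
// exactly what removing the `+ Send` bound on `ComputeChannel::read` allows.
fn read_detached<S, C>(channel: C, binding: Binding, on_done: impl FnOnce(Vec<u8>) + 'static)
where
    S: ComputeServer,
    C: ComputeChannel<S> + 'static,
{
    wasm_bindgen_futures::spawn_local(async move {
        let bytes = channel.read(binding).await;
        on_done(bytes);
    });
}
```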
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/src/channel/cell.rs
@@ -42,7 +42,7 @@ where

impl<Server> ComputeChannel<Server> for RefCellComputeChannel<Server>
where
Server: ComputeServer + Send,
Server: ComputeServer,
{
async fn read(&self, binding: Binding) -> Vec<u8> {
let future = {
6 changes: 4 additions & 2 deletions crates/cubecl-runtime/src/channel/mpsc.rs
@@ -6,7 +6,7 @@ use super::ComputeChannel;
use crate::{
memory_management::MemoryUsage,
server::{Binding, ComputeServer, CubeCount, Handle},
storage::BindingResource,
storage::{BindingResource, ComputeStorage},
ExecutionMode,
};

@@ -50,7 +50,8 @@ where

impl<Server> MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
Server: ComputeServer + Send + 'static,
<Server::Storage as ComputeStorage>::Resource: Send,
{
/// Create a new mpsc compute channel.
pub fn new(mut server: Server) -> Self {
@@ -123,6 +124,7 @@ impl<Server: ComputeServer> Clone for MpscComputeChannel<Server> {
impl<Server> ComputeChannel<Server> for MpscComputeChannel<Server>
where
Server: ComputeServer + 'static,
<Server::Storage as ComputeStorage>::Resource: Send,
{
async fn read(&self, binding: Binding) -> Vec<u8> {
let sender = self.state.sender.clone();
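Here the `Send` requirements removed from the blanket `ComputeServer`/`ComputeStorage` traits are restated where they are actually needed: the mpsc channel moves the server onto a worker and ships resources back across that boundary. A rough illustration of why the bounds reappear — simplified, not the actual implementation (which uses async-channel rather than `std::sync::mpsc`):

```rust
use std::sync::mpsc;

// Anything moved into the worker thread, and anything sent through the channel,
// must be `Send`; hence `Server: Send` and `Resource: Send` on this impl only.
fn spawn_worker<S: Send + 'static>(server: S, rx: mpsc::Receiver<Box<dyn FnOnce(&mut S) + Send>>) {
    std::thread::spawn(move || {
        let mut server = server;
        while let Ok(task) = rx.recv() {
            task(&mut server);
        }
    });
}
```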
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/src/channel/mutex.rs
@@ -35,7 +35,7 @@ where

impl<Server> ComputeChannel<Server> for MutexComputeChannel<Server>
where
Server: ComputeServer,
Server: ComputeServer + Send,
{
async fn read(&self, handle: Binding) -> Vec<u8> {
// Nb: The order here is really important - the mutex guard has to be dropped before
1 change: 1 addition & 0 deletions crates/cubecl-runtime/src/memory_management/base.rs
@@ -4,6 +4,7 @@ use alloc::{format, string::String};
/// Amount of memory in use by this allocator
/// and statistics on how much memory is reserved and
/// wasted in total.
#[derive(Debug)]
pub struct MemoryUsage {
/// The number of allocations currently active.
pub number_allocs: u64,
8 changes: 4 additions & 4 deletions crates/cubecl-runtime/src/server.rs
@@ -14,7 +14,7 @@ use cubecl_common::benchmark::TimestampsResult;
///
/// Everything in the server is mutable, therefore it should be solely accessed through the
/// [compute channel](crate::channel::ComputeChannel) for thread safety.
pub trait ComputeServer: Send + core::fmt::Debug
pub trait ComputeServer: core::fmt::Debug
where
Self: Sized,
{
@@ -26,7 +26,7 @@ where
type Feature: Ord + Copy + Debug + Send + Sync;

/// Given a handle, returns the owned resource as bytes.
fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + Send + 'static;
fn read(&mut self, binding: Binding) -> impl Future<Output = Vec<u8>> + 'static;

/// Given a resource handle, returns the storage resource.
fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;
@@ -57,12 +57,12 @@
fn flush(&mut self);

/// Wait for the completion of every task in the server.
fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
fn sync(&mut self) -> impl Future<Output = ()> + 'static;

/// Wait for the completion of every task in the server.
///
/// Returns the (approximate) total amount of GPU work done since the last sync.
fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;
fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + 'static;

/// The current memory usage of the server.
fn memory_usage(&self) -> MemoryUsage;
4 changes: 2 additions & 2 deletions crates/cubecl-runtime/src/storage/base.rs
@@ -63,10 +63,10 @@ impl StorageHandle {
}

/// Storage types are responsible for allocating and deallocating memory.
pub trait ComputeStorage: Send {
pub trait ComputeStorage {
/// The resource associated type determines the way data is implemented and how
/// it can be accessed by kernels.
type Resource: Send;
type Resource;

/// The alignment memory is allocated with in this storage.
const ALIGNMENT: u64;
5 changes: 3 additions & 2 deletions crates/cubecl-runtime/src/tune/tune_cache.rs
@@ -117,11 +117,12 @@ impl<K: AutotuneKey> TuneCache<K> {
} => {
if cfg!(autotune_persistent_cache) {
match checksum_matches {
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(false) => TuneCacheResult::Miss, // Can't use this.
#[cfg(autotune_persistent_cache)]
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(true) => TuneCacheResult::Hit {
fastest_index: *fastest_index,
},
_ => TuneCacheResult::Miss, // Some(false) or None so we can't use this.
}
} else {
let _ = checksum_matches;
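The reshuffled match arms lean on the difference between `cfg!(...)` and a `#[cfg(...)]` attribute: `cfg!` is just a boolean, so every arm must still compile, while the attribute removes the `None` arm entirely when the persistent cache is compiled out and lets the catch-all `_` arm absorb it. A small self-contained illustration of that mechanism — the `persistent` feature name here is made up; the real code gates on the `autotune_persistent_cache` cfg:

```rust
// Both branches of `if cfg!(...)` are type-checked; a `#[cfg]` attribute on a
// match arm instead removes the arm from the program when the predicate is false.
fn lookup(checksum_matches: Option<bool>) -> &'static str {
    match checksum_matches {
        // Only exists when the (hypothetical) feature is enabled; otherwise `None`
        // falls through to the `_` arm below, mirroring the change in tune_cache.rs.
        #[cfg(feature = "persistent")]
        None => "unchecked",
        Some(true) => "hit",
        _ => "miss",
    }
}
```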
2 changes: 1 addition & 1 deletion crates/cubecl-runtime/src/tune/tuner.rs
@@ -224,7 +224,7 @@ impl<K: AutotuneKey> Tuner<K> {
}
}

fn spawn_benchmark_task(future: impl Future<Output = ()> + Send + 'static) {
fn spawn_benchmark_task(future: impl Future<Output = ()> + 'static) {
// On wasm, spawn the tuning as a detached task.
#[cfg(target_family = "wasm")]
wasm_bindgen_futures::spawn_local(future);
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/Cargo.toml
@@ -39,6 +39,7 @@ wgpu = { version = "22.0.0", features = ["fragile-send-sync-non-atomic-wasm"] }

async-channel = { workspace = true }
derive-new = { workspace = true }
futures-lite = { workspace = true }
hashbrown = { workspace = true }
log = { workspace = true }
web-time = { workspace = true }
6 changes: 2 additions & 4 deletions crates/cubecl-wgpu/src/compiler/base.rs
@@ -1,12 +1,10 @@
use std::sync::Arc;

use cubecl_core::{
prelude::CompiledKernel, server::ComputeServer, Compiler, ExecutionMode, Feature,
};
use cubecl_runtime::DeviceProperties;
use wgpu::{Adapter, ComputePipeline, Device, Queue};

use crate::WgpuServer;
use crate::{Pdrc, WgpuServer};

pub trait WgpuCompiler: Compiler {
fn compile(
@@ -19,7 +17,7 @@ pub trait WgpuCompiler: Compiler {
server: &mut WgpuServer<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline>;
) -> Pdrc<ComputePipeline>;

#[allow(async_fn_in_trait)]
async fn request_device(adapter: &Adapter) -> (Device, Queue);
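`Pdrc` is introduced by this PR but its definition is not visible in this diff. From the way it replaces `Arc` around wgpu objects, it is presumably a target-conditional reference-count alias: plain `Rc` on wasm32 builds where the wrapped wgpu types are not `Send`/`Sync`, `Arc` everywhere else. A guess at its shape — an assumption, not taken from the PR:

```rust
// Assumed definition: reference counting that only promises (and pays for)
// thread safety on targets where the wrapped wgpu objects are actually Send + Sync.
#[cfg(all(target_family = "wasm", target_feature = "atomics"))]
pub type Pdrc<T> = std::rc::Rc<T>;

#[cfg(not(all(target_family = "wasm", target_feature = "atomics")))]
pub type Pdrc<T> = std::sync::Arc<T>;
```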
8 changes: 4 additions & 4 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
@@ -1,10 +1,10 @@
use std::{borrow::Cow, sync::Arc};
use std::borrow::Cow;

use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory};
use super::{LocalArray, Subgroup};
use crate::{
compiler::{base::WgpuCompiler, wgsl},
WgpuServer,
Pdrc, WgpuServer,
};
use cubecl_core::{
ir::{self as cube, HybridAllocator, UIntKind},
@@ -73,7 +73,7 @@ impl WgpuCompiler for WgslCompiler {
server: &mut WgpuServer<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline> {
) -> Pdrc<ComputePipeline> {
let source = &kernel.source;
let repr = kernel.repr.unwrap();
let module = match mode {
@@ -118,7 +118,7 @@ impl WgpuCompiler for WgslCompiler {
push_constant_ranges: &[],
});

Arc::new(
Pdrc::new(
server
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
4 changes: 3 additions & 1 deletion crates/cubecl-wgpu/src/compute/poll.rs
@@ -58,10 +58,12 @@ mod _impl {
// On Wasm, the browser handles the polling loop, so we don't need anything.
#[cfg(target_family = "wasm")]
mod _impl {
use crate::Pdrc;

#[derive(Debug)]
pub struct WgpuPoll {}
impl WgpuPoll {
pub fn new(_device: alloc::sync::Arc<wgpu::Device>) -> Self {
pub fn new(_device: Pdrc<wgpu::Device>) -> Self {
Self {}
}
pub fn start_polling(&self) -> alloc::sync::Arc<()> {
52 changes: 33 additions & 19 deletions crates/cubecl-wgpu/src/compute/server.rs
@@ -2,8 +2,7 @@ use std::{future::Future, marker::PhantomData, num::NonZero, pin::Pin, time::Dur

use super::poll::WgpuPoll;
use super::WgpuStorage;
use crate::compiler::base::WgpuCompiler;
use alloc::sync::Arc;
use crate::{compiler::base::WgpuCompiler, Pdrc};
use cubecl_common::future;
use cubecl_core::{compute::DebugInformation, prelude::*, server::Handle, Feature, KernelId};
use cubecl_runtime::{
@@ -15,18 +14,23 @@ use cubecl_runtime::{
};
use hashbrown::HashMap;
use web_time::Instant;
use wgpu::{CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType};
use wgpu::{
CommandEncoder, ComputePass, ComputePipeline, QuerySet, QuerySetDescriptor, QueryType,
WasmNotSend,
};

/// Wgpu compute server.
#[derive(Debug)]
pub struct WgpuServer<C: WgpuCompiler> {
memory_management: MemoryManagement<WgpuStorage>,
pub(crate) device: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
pub(crate) device: Pdrc<wgpu::Device>,
queue: Pdrc<wgpu::Queue>,
#[allow(unused)]
pub(crate) adapter: Pdrc<wgpu::Adapter>,
encoder: CommandEncoder,
current_pass: Option<ComputePass<'static>>,
tasks_count: usize,
pipelines: HashMap<KernelId, Arc<ComputePipeline>>,
pipelines: HashMap<KernelId, Pdrc<ComputePipeline>>,
tasks_max: usize,
logger: DebugLogger,
poll: WgpuPoll,
Expand All @@ -36,6 +40,10 @@ pub struct WgpuServer<C: WgpuCompiler> {
_compiler: PhantomData<C>,
}

trait FutureWasmNotSend<O>: Future<Output = O> + WasmNotSend {}

impl<O, T: Future<Output = O> + WasmNotSend> FutureWasmNotSend<O> for T {}

#[derive(Debug)]
enum KernelTimestamps {
Native { query_set: QuerySet, init: bool },
@@ -82,8 +90,9 @@ impl<C: WgpuCompiler> WgpuServer<C> {
/// Create a new server.
pub fn new(
memory_management: MemoryManagement<WgpuStorage>,
device: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
device: Pdrc<wgpu::Device>,
queue: Pdrc<wgpu::Queue>,
adapter: Pdrc<wgpu::Adapter>,
tasks_max: usize,
) -> Self {
let logger = DebugLogger::default();
@@ -93,18 +102,22 @@ impl<C: WgpuCompiler> WgpuServer<C> {
timestamps.enable(&device);
}

let encoder = create_encoder(&device);
let poll = WgpuPoll::new(device.clone());

Self {
memory_management,
device: device.clone(),
queue: queue.clone(),
encoder: create_encoder(&device),
device,
queue,
adapter,
encoder,
current_pass: None,
tasks_count: 0,
storage_locked: MemoryLock::default(),
pipelines: HashMap::new(),
tasks_max,
logger,
poll: WgpuPoll::new(device.clone()),
poll,
duration_profiled: None,
timestamps,
_compiler: PhantomData,

fn pipeline(
&mut self,
kernel: <Self as ComputeServer>::Kernel,
kernel: <WgpuServer<C> as ComputeServer>::Kernel,
mode: ExecutionMode,
) -> Arc<ComputePipeline> {
) -> Pdrc<ComputePipeline> {
let mut kernel_id = kernel.id();
kernel_id.mode(mode);

@@ -192,7 +205,7 @@ impl<C: WgpuCompiler> WgpuServer<C> {
}
}

fn sync_queue(&mut self) -> Pin<Box<dyn Future<Output = ()> + Send + 'static>> {
fn sync_queue(&mut self) -> Pin<Box<dyn FutureWasmNotSend<()> + 'static>> {
self.flush();

#[cfg(target_family = "wasm")]
@@ -220,7 +233,7 @@ impl<C: WgpuCompiler> WgpuServer<C> {

fn sync_queue_elapsed(
&mut self,
) -> Pin<Box<dyn Future<Output = TimestampsResult> + Send + 'static>> {
) -> Pin<Box<dyn FutureWasmNotSend<TimestampsResult> + 'static>> {
self.clear_compute_pass();

enum TimestampMethod {
@@ -301,14 +314,15 @@ impl<C: WgpuCompiler> ComputeServer for WgpuServer<C> {
type Storage = WgpuStorage;
type Feature = Feature;

fn read(&mut self, binding: server::Binding) -> impl Future<Output = Vec<u8>> + Send + 'static {
fn read(&mut self, binding: server::Binding) -> impl Future<Output = Vec<u8>> + 'static {
let rb = self.get_resource(binding);
let resource = rb.resource();
self.clear_compute_pass();

self.read_wgpu_buffer(&resource.buffer, resource.offset(), resource.size())
}

fn get_resource(&mut self, binding: server::Binding) -> BindingResource<Self> {
fn get_resource(&mut self, binding: server::Binding) -> BindingResource<WgpuServer<C>> {
// Keep track of any buffer that might be used in the wgpu queue, as we cannot copy into them
// after they have any outstanding compute work. Calling get_resource repeatedly
// will add duplicates to this, but that is ok.
@@ -376,7 +390,7 @@

unsafe fn execute(
&mut self,
kernel: Self::Kernel,
kernel: <WgpuServer<C> as ComputeServer>::Kernel,
count: CubeCount,
bindings: Vec<server::Binding>,
mode: ExecutionMode,
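One detail from the server.rs changes above worth spelling out: the new `FutureWasmNotSend` helper trait builds on wgpu's `WasmNotSend` marker so that `sync_queue` and `sync_queue_elapsed` can keep returning boxed futures that are `Send` on native targets but merely local on wasm32. Roughly, the marker behaves as below — an approximation of wgpu's definition, not copied from it:

```rust
// Approximation of the `WasmNotSend` marker re-exported by wgpu: it requires
// `Send` on native targets and is satisfied by every type on wasm32, so
// `Pin<Box<dyn FutureWasmNotSend<()>>>` is a Send future off the web and a
// plain thread-local future on it.
#[cfg(not(target_family = "wasm"))]
pub trait WasmNotSend: Send {}
#[cfg(not(target_family = "wasm"))]
impl<T: Send> WasmNotSend for T {}

#[cfg(target_family = "wasm")]
pub trait WasmNotSend {}
#[cfg(target_family = "wasm")]
impl<T> WasmNotSend for T {}
```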