Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Added support for target_feature atomics on wasm32 #239

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ serde_json = { version = "1.0.119", default-features = false }
dashmap = "5.5.3"
hashbrown = "0.14.5"
spin = { version = "0.9.8", features = ["mutex", "spin_mutex"] }
rayon = "1"

getrandom = { version = "0.2.15", default-features = false }
rand = { version = "0.8.5", default-features = false, features = [
Expand Down Expand Up @@ -73,6 +74,7 @@ pretty_assertions = "1.4"
# Async
embassy-futures = { version = "0.1.1" } # for no-std
futures-lite = { version = "2.3.0", default-features = false }
futures = "0.3.31"

[profile.dev]
opt-level = 2
1 change: 1 addition & 0 deletions crates/cubecl-runtime/src/memory_management/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use alloc::{format, string::String};
/// Amount of memory in use by this allocator
/// and statistics on how much memory is reserved and
/// wasted in total.
#[derive(Debug)]
pub struct MemoryUsage {
/// The number of allocations currently active.
pub number_allocs: u64,
Expand Down
5 changes: 3 additions & 2 deletions crates/cubecl-runtime/src/tune/tune_cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ impl<K: AutotuneKey> TuneCache<K> {
} => {
if cfg!(autotune_persistent_cache) {
match checksum_matches {
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(false) => TuneCacheResult::Miss, // Can't use this.
#[cfg(autotune_persistent_cache)]
None => TuneCacheResult::Unchecked, // Don't know yet.
Some(true) => TuneCacheResult::Hit {
fastest_index: *fastest_index,
},
_ => TuneCacheResult::Miss, // Some(false) or None so we can't use this.
}
} else {
let _ = checksum_matches;
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-wgpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ web-time = { workspace = true }

cfg-if = { workspace = true }

# Dependencies enabled only when compiling for wasm32 with the `atomics`
# target feature turned on; on every other target this table is ignored.
[target.'cfg(all(target_arch = "wasm32", target_feature = "atomics"))'.dependencies]
futures = { workspace = true }
rayon = { workspace = true }

[dev-dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.4.0", features = [
"export_tests",
Expand Down
10 changes: 4 additions & 6 deletions crates/cubecl-wgpu/src/compiler/base.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
use std::sync::Arc;

use cubecl_core::{
prelude::CompiledKernel, server::ComputeServer, Compiler, ExecutionMode, Feature,
};
use cubecl_runtime::DeviceProperties;
use wgpu::{Adapter, ComputePipeline, Device, Queue};

use crate::WgpuServer;
use crate::{Pdrc, WgpuServer, WgpuServerInner};

pub trait WgpuCompiler: Compiler {
fn compile(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
mode: ExecutionMode,
) -> CompiledKernel<Self>;

fn create_pipeline(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline>;
) -> Pdrc<ComputePipeline>;

#[allow(async_fn_in_trait)]
async fn request_device(adapter: &Adapter) -> (Device, Queue);
Expand Down
11 changes: 6 additions & 5 deletions crates/cubecl-wgpu/src/compiler/wgsl/compiler.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
use std::{borrow::Cow, sync::Arc};
use std::borrow::Cow;

use super::{shader::ComputeShader, ConstantArray, Item, SharedMemory};
use super::{LocalArray, Subgroup};
use crate::{
compiler::{base::WgpuCompiler, wgsl},
WgpuServer,
};
use crate::{Pdrc, WgpuServerInner};
use cubecl_core::{
ir::{self as cube, HybridAllocator, UIntKind},
prelude::CompiledKernel,
Expand Down Expand Up @@ -70,10 +71,10 @@ impl cubecl_core::Compiler for WgslCompiler {

impl WgpuCompiler for WgslCompiler {
fn create_pipeline(
server: &mut WgpuServer<Self>,
server: &mut WgpuServerInner<Self>,
kernel: CompiledKernel<Self>,
mode: ExecutionMode,
) -> Arc<ComputePipeline> {
) -> Pdrc<ComputePipeline> {
let source = &kernel.source;
let repr = kernel.repr.unwrap();
let module = match mode {
Expand Down Expand Up @@ -118,7 +119,7 @@ impl WgpuCompiler for WgslCompiler {
push_constant_ranges: &[],
});

Arc::new(
Pdrc::new(
server
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
Expand All @@ -136,7 +137,7 @@ impl WgpuCompiler for WgslCompiler {
}

fn compile(
_server: &mut WgpuServer<Self>,
_server: &mut WgpuServerInner<Self>,
kernel: <WgpuServer<Self> as ComputeServer>::Kernel,
mode: ExecutionMode,
) -> CompiledKernel<Self> {
Expand Down
4 changes: 3 additions & 1 deletion crates/cubecl-wgpu/src/compute/poll.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,12 @@ mod _impl {
// On Wasm, the browser handles the polling loop, so we don't need anything.
#[cfg(target_family = "wasm")]
mod _impl {
use crate::Pdrc;

/// Poller stand-in used when building for the Wasm target family.
///
/// The struct carries no fields: on Wasm the browser handles the polling
/// loop, so there is no state to keep here.
#[derive(Debug)]
pub struct WgpuPoll {}
impl WgpuPoll {
pub fn new(_device: alloc::sync::Arc<wgpu::Device>) -> Self {
pub fn new(_device: Pdrc<wgpu::Device>) -> Self {
Self {}
}
pub fn start_polling(&self) -> alloc::sync::Arc<()> {
Expand Down
Loading
Loading