Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement FastDivmod #433

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions crates/cubecl-core/src/frontend/element/int.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
use crate::frontend::{
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit, ExpandElementTyped,
Numeric,
};
use crate::ir::{Elem, IntKind};
use crate::prelude::Not;
use crate::Runtime;
use crate::{
compute::{KernelBuilder, KernelLauncher},
prelude::{CountOnes, ReverseBits},
};
use crate::{
frontend::{
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementBaseInit,
ExpandElementTyped, Numeric,
},
prelude::MulHi,
};

use super::{
init_expand_element, Init, IntoRuntime, LaunchArgExpand, ScalarArgSettings, __expand_new,
Expand All @@ -20,6 +23,7 @@ pub trait Int:
+ CountOnes
+ ReverseBits
+ Not
+ MulHi
+ std::ops::Rem<Output = Self>
+ core::ops::Add<Output = Self>
+ core::ops::Sub<Output = Self>
Expand Down
15 changes: 15 additions & 0 deletions crates/cubecl-core/src/frontend/operation/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,21 @@ impl_binary_func!(
u32,
u64
);
impl_binary_func!(
MulHi,
mul_hi,
__expand_mul_hi,
__expand_mul_hi_method,
Operator::MulHi,
i8,
i16,
i32,
i64,
u8,
u16,
u32,
u64
);
impl_binary_func_fixed_output_vectorization!(
Dot,
dot,
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-core/src/ir/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@ pub enum Operator {
Magnitude(UnaryOperator),
Normalize(UnaryOperator),
Dot(BinaryOperator),
/// Wide multiplication returning the high bits
MulHi(BinaryOperator),
// A select statement/ternary
Select(Select),
}
Expand Down Expand Up @@ -294,6 +296,7 @@ impl Display for Operator {
}
Operator::Cast(op) => write!(f, "cast({})", op.input),
Operator::Bitcast(op) => write!(f, "bitcast({})", op.input),
Operator::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs),
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions crates/cubecl-core/src/ir/processing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ impl ScopeProcessing {
sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
}
Operator::MulHi(op) => {
sanitize_constant_scalar_ref_var(&mut op.lhs, &inst.out.unwrap());
sanitize_constant_scalar_ref_var(&mut op.rhs, &inst.out.unwrap());
}
Operator::Abs(op) => {
sanitize_constant_scalar_ref_var(&mut op.input, &inst.out.unwrap());
}
Expand Down
3 changes: 3 additions & 0 deletions crates/cubecl-cpp/src/shared/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,9 @@ impl<D: Dialect> CppCompiler<D> {
gpu::Operator::Sub(op) => {
instructions.push(Instruction::Sub(self.compile_binary(op, out)))
}
gpu::Operator::MulHi(op) => {
instructions.push(Instruction::HiMul(self.compile_binary(op, out)))
}
gpu::Operator::Slice(op) => {
if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
let input = op.input;
Expand Down
44 changes: 44 additions & 0 deletions crates/cubecl-cpp/src/shared/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,50 @@ operator!(BitwiseXor, "^");
operator!(Or, "||");
operator!(And, "&&");

/// Binary instruction formatter for "multiply and keep the high half".
pub struct HiMul;

impl<D: Dialect> Binary<D> for HiMul {
    /// Writes a call to the high-multiply intrinsic matching the element type:
    /// `__mulhi` / `__umulhi` for 32-bit and `__mul64hi` / `__umul64hi` for
    /// 64-bit signed / unsigned integers. Any other element type is unimplemented.
    fn format_scalar<Lhs: Display, Rhs: Display>(
        f: &mut std::fmt::Formatter<'_>,
        lhs: Lhs,
        rhs: Rhs,
        item: Item<D>,
    ) -> std::fmt::Result {
        let intrinsic = match item.elem {
            Elem::I32 => "__mulhi",
            Elem::U32 => "__umulhi",
            Elem::I64 => "__mul64hi",
            Elem::U64 => "__umul64hi",
            _ => unimplemented!("HiMul only supports 32 and 64 bit ints"),
        };
        write!(f, "{intrinsic}({lhs}, {rhs})")
    }

    /// Emits the vectorized form as a brace initializer: one scalar
    /// high-multiply per lane of the output item, each followed by `, `.
    fn unroll_vec(
        f: &mut Formatter<'_>,
        lhs: &Variable<D>,
        rhs: &Variable<D>,
        out: &Variable<D>,
    ) -> core::fmt::Result {
        let out_item = out.item();
        let lanes = out.item().vectorization;
        let target = out.fmt_left();

        writeln!(f, "{target} = {out_item}{{")?;
        for lane in 0..lanes {
            Self::format_scalar(f, lhs.index(lane), rhs.index(lane), out_item)?;
            f.write_str(", ")?;
        }

        f.write_str("};\n")
    }
}

pub struct Powf;

impl<D: Dialect> Binary<D> for Powf {
Expand Down
2 changes: 2 additions & 0 deletions crates/cubecl-cpp/src/shared/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pub enum Instruction<D: Dialect> {
Div(BinaryInstruction<D>),
Mul(BinaryInstruction<D>),
Sub(BinaryInstruction<D>),
HiMul(BinaryInstruction<D>),
Index(BinaryInstruction<D>),
IndexAssign(BinaryInstruction<D>),
CheckedIndex {
Expand Down Expand Up @@ -230,6 +231,7 @@ impl<D: Dialect> Display for Instruction<D> {
Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::HiMul(it) => HiMul::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Modulo(inst) => Modulo::format(f, &inst.lhs, &inst.rhs, &inst.out),
Instruction::BitwiseOr(it) => BitwiseOr::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::BitwiseAnd(it) => BitwiseAnd::format(f, &it.lhs, &it.rhs, &it.out),
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ impl Display for Instruction {
OpId::Sub => write!(f, "{} - {}", args[0], args[1]),
OpId::Mul => write!(f, "{} * {}", args[0], args[1]),
OpId::Div => write!(f, "{} / {}", args[0], args[1]),
OpId::MulHi => write!(f, "{}.mul_hi({})", args[0], args[1]),
OpId::Abs => write!(f, "{}.abs()", args[0]),
OpId::Exp => write!(f, "{}.exp()", args[0]),
OpId::Log => write!(f, "{}.log()", args[0]),
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/gvn/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,7 @@ pub enum OpId {
Sub,
Mul,
Div,
MulHi,
Abs,
Exp,
Log,
Expand Down
6 changes: 6 additions & 0 deletions crates/cubecl-opt/src/gvn/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ impl Expression {
rhs: args[1],
})
.into(),
OpId::MulHi => Operator::MulHi(BinaryOperator {
lhs: args[0],
rhs: args[1],
})
.into(),
OpId::Abs => Operator::Abs(UnaryOperator { input: args[0] }).into(),
OpId::Exp => Operator::Exp(UnaryOperator { input: args[0] }).into(),
OpId::Log => Operator::Log(UnaryOperator { input: args[0] }).into(),
Expand Down Expand Up @@ -278,6 +283,7 @@ pub fn id_of_op(op: &Operator) -> OpId {
Operator::Sub(_) => OpId::Sub,
Operator::Mul(_) => OpId::Mul,
Operator::Div(_) => OpId::Div,
Operator::MulHi(_) => OpId::MulHi,
Operator::Abs(_) => OpId::Abs,
Operator::Exp(_) => OpId::Exp,
Operator::Log(_) => OpId::Log,
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/gvn/numbering.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ impl ValueTable {
// Commutative binop
Operator::Add(op)
| Operator::Mul(op)
| Operator::MulHi(op)
| Operator::And(op)
| Operator::Or(op)
| Operator::Equal(op)
Expand Down
1 change: 1 addition & 0 deletions crates/cubecl-opt/src/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ impl Optimizer {
Operator::Add(binary_operator)
| Operator::Sub(binary_operator)
| Operator::Mul(binary_operator)
| Operator::MulHi(binary_operator)
| Operator::Div(binary_operator)
| Operator::Powf(binary_operator)
| Operator::Equal(binary_operator)
Expand Down
11 changes: 11 additions & 0 deletions crates/cubecl-spirv/src/instruction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,17 @@ impl<T: SpirvTarget> SpirvCompiler<T> {
};
});
}
Operator::MulHi(op) => {
self.compile_binary_op(op, out, |b, out_ty, ty, lhs, rhs, out| {
let struct_ty = b.type_struct([ty, ty]);
let result = match out_ty.elem() {
Elem::Int(_, false) => b.u_mul_extended(struct_ty, None, lhs, rhs).unwrap(),
Elem::Int(_, true) => b.s_mul_extended(struct_ty, None, lhs, rhs).unwrap(),
_ => unreachable!(),
};
b.composite_extract(ty, Some(out), result, [1]).unwrap();
});
}
Operator::Remainder(op) => {
self.compile_binary_op(op, out, |b, out_ty, ty, lhs, rhs, out| {
match out_ty.elem() {
Expand Down
3 changes: 2 additions & 1 deletion crates/cubecl-std/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
authors = [
"nathanielsimard <[email protected]>",
"louisfd <[email protected]>",
"maxtremblay <[email protected]>"
"maxtremblay <[email protected]>",
]
categories = ["science", "mathematics", "algorithms"]
description = "CubeCL Standard Library."
Expand All @@ -17,3 +17,4 @@ version.workspace = true
[dependencies]
cubecl-core = { path = "../cubecl-core", version = "0.5.0", default-features = false }
cubecl-runtime = { path = "../cubecl-runtime", version = "0.5.0", default-features = false }
serde = { workspace = true }
120 changes: 120 additions & 0 deletions crates/cubecl-std/src/fast_math.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
use cubecl_core as cubecl;
use cubecl_core::{prelude::*, CubeLaunch};

/// Launch argument bundling a divisor with the constants precomputed by
/// `FastDivmod::new`, so that kernels can divide by it with a high-multiply
/// and a shift (see `fast_divmod`).
#[derive(CubeLaunch)]
pub struct FastDivmod<I: Int> {
// The divisor itself; used to special-case division by one and to derive the remainder.
divisor: I,
// Multiplier computed in `new` as ceil(2^p / divisor); left at 0 when the divisor is 1.
multiplier: u32,
// Shift applied after the high-multiply, computed in `new` as `p - 32`; 0 when the divisor is 1.
shift_right: u32,
}

impl<I: Int> FastDivmod<I> {
    /// Precomputes the multiplier and shift for `divisor` on the host and packs
    /// them, together with the divisor, into a `FastDivmodLaunch` argument.
    ///
    /// For a divisor of 1 both constants are 0 (the kernel special-cases it).
    /// Otherwise, with `p = 31 + find_log2(divisor)`:
    /// `multiplier = ceil(2^p / divisor)` and `shift_right = p - 32`.
    ///
    /// # Panics
    ///
    /// Panics if `divisor` cannot be converted to `i64`, or if it is not in
    /// `1..=u32::MAX`. Zero and negative divisors have no valid multiplier, and
    /// larger divisors would make `p` reach 64 and overflow the `1u64 << p` shift.
    // Returns the launch type instead of `Self`, hence the lint allowance.
    #[allow(clippy::new_ret_no_self)]
    pub fn new<'a, R: Runtime>(divisor: I) -> FastDivmodLaunch<'a, I, R> {
        let div_int = divisor
            .to_i64()
            .expect("FastDivmod divisor must be representable as i64");
        assert!(
            div_int > 0,
            "FastDivmod requires a strictly positive divisor, got {div_int}"
        );
        assert!(
            div_int <= i64::from(u32::MAX),
            "FastDivmod only supports divisors that fit in 32 bits, got {div_int}"
        );

        let (multiplier, shift_right) = if div_int == 1 {
            (0u64, 0i64)
        } else {
            let p = 31 + find_log2(div_int);
            ((1u64 << p).div_ceil(div_int as u64), p - 32)
        };

        // With the divisor limited to 32 bits, `p <= 63`, so both values are
        // expected to fit in `u32` and the casts below should not truncate.
        FastDivmodLaunch::new(
            ScalarArg::new(divisor),
            ScalarArg::new(multiplier as u32),
            ScalarArg::new(shift_right as u32),
        )
    }
}

// Host-side reference implementations plus the static `__expand_*` entry points,
// which forward to the `*_method` variants on `FastDivmodExpand`.
impl<I: Int> FastDivmod<I> {
/// Returns the quotient `dividend / divisor`.
pub fn div(&self, dividend: I) -> I {
self.div_mod(dividend).0
}

/// Returns the remainder `dividend % divisor`.
pub fn modulo(&self, dividend: I) -> I {
self.div_mod(dividend).1
}

/// Returns `(quotient, remainder)` using the plain `/` and `%` operators;
/// the precomputed multiplier and shift are not used on this path.
pub fn div_mod(&self, dividend: I) -> (I, I) {
(dividend / self.divisor, dividend % self.divisor)
}

/// Expansion entry point for `div`; forwards to `FastDivmodExpand::__expand_div_method`.
pub fn __expand_div(
context: &mut CubeContext,
this: FastDivmodExpand<I>,
dividend: ExpandElementTyped<I>,
) -> ExpandElementTyped<I> {
this.__expand_div_method(context, dividend)
}

/// Expansion entry point for `modulo`; forwards to `FastDivmodExpand::__expand_modulo_method`.
pub fn __expand_modulo(
context: &mut CubeContext,
this: FastDivmodExpand<I>,
dividend: ExpandElementTyped<I>,
) -> ExpandElementTyped<I> {
this.__expand_modulo_method(context, dividend)
}

/// Expansion entry point for `div_mod`; forwards to `FastDivmodExpand::__expand_div_mod_method`.
pub fn __expand_div_mod(
context: &mut CubeContext,
this: FastDivmodExpand<I>,
dividend: ExpandElementTyped<I>,
) -> (ExpandElementTyped<I>, ExpandElementTyped<I>) {
this.__expand_div_mod_method(context, dividend)
}
}

// Expansion-side methods: all three go through `fast_divmod::expand`, and the
// single-result variants just pick one element of its tuple.
impl<I: Int> FastDivmodExpand<I> {
/// Expands the division and returns only the quotient.
pub fn __expand_div_method(
self,
context: &mut CubeContext,
dividend: ExpandElementTyped<I>,
) -> ExpandElementTyped<I> {
self.__expand_div_mod_method(context, dividend).0
}

/// Expands the division and returns only the remainder.
pub fn __expand_modulo_method(
self,
context: &mut CubeContext,
dividend: ExpandElementTyped<I>,
) -> ExpandElementTyped<I> {
self.__expand_div_mod_method(context, dividend).1
}

/// Expands `fast_divmod` with this value's divisor, multiplier and shift,
/// returning the `(quotient, remainder)` pair.
pub fn __expand_div_mod_method(
self,
context: &mut CubeContext,
dividend: ExpandElementTyped<I>,
) -> (ExpandElementTyped<I>, ExpandElementTyped<I>) {
fast_divmod::expand::<I>(
context,
dividend,
self.divisor,
self.multiplier,
self.shift_right,
)
}
}

/// Computes `(quotient, remainder)` of `dividend / divisor` from a precomputed
/// `multiplier` and `shift_right` instead of a division instruction.
///
/// Unless `divisor` is 1, the quotient is the high half of the 32-bit product
/// `dividend * multiplier`, shifted right by `shift_right`. The remainder is then
/// `dividend - quotient * divisor`.
// NOTE(review): the dividend is cast to `u32` before the multiply, so this
// presumably assumes non-negative dividends that fit in 32 bits - confirm.
#[cube]
pub fn fast_divmod<I: Int>(dividend: I, divisor: I, multiplier: u32, shift_right: u32) -> (I, I) {
// A divisor of 1 skips the multiply: the quotient is the dividend itself.
let quotient = if divisor != I::new(1) {
I::cast_from(u32::mul_hi(u32::cast_from(dividend), multiplier) >> shift_right)
} else {
dividend
};

let remainder = dividend - (quotient * divisor);
(quotient, remainder)
}

/// Returns `ceil(log2(x))` for a strictly positive `x`.
///
/// The result is the index of the highest set bit, plus one when `x` is not an
/// exact power of two.
fn find_log2(x: i64) -> i64 {
    debug_assert!(x > 0, "find_log2 requires a strictly positive input");

    // `leading_zeros` counts over the full 64-bit width of `i64`, so the
    // highest set bit sits at index `63 - lz`. Using 31 here underflows for
    // every value below 2^32.
    let floor = i64::from(63 - x.leading_zeros());
    let is_power_of_two = x & (x - 1) == 0;

    if is_power_of_two {
        floor
    } else {
        floor + 1
    }
}
4 changes: 4 additions & 0 deletions crates/cubecl-std/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
//! Cubecl standard library.

mod fast_math;

pub use fast_math::*;
7 changes: 6 additions & 1 deletion crates/cubecl-wgpu/src/compiler/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,10 @@ pub trait WgpuCompiler: Compiler {

#[allow(async_fn_in_trait)]
async fn request_device(adapter: &Adapter) -> (Device, Queue);
fn register_features(adapter: &Adapter, device: &Device, props: &mut DeviceProperties<Feature>);
fn register_features(
adapter: &Adapter,
device: &Device,
props: &mut DeviceProperties<Feature>,
comp_options: &mut Self::CompilationOptions,
);
}
1 change: 1 addition & 0 deletions crates/cubecl-wgpu/src/compiler/spirv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ impl WgpuCompiler for SpirvCompiler<GLCompute> {
adapter: &wgpu::Adapter,
_device: &wgpu::Device,
props: &mut cubecl_runtime::DeviceProperties<cubecl_core::Feature>,
_comp_options: &mut Self::CompilationOptions,
) {
let features = adapter.features();
unsafe {
Expand Down
Loading
Loading