Skip to content

Commit

Permalink
refactor requested changes
Browse files Browse the repository at this point in the history
  • Loading branch information
quinton11 committed Jan 23, 2025
1 parent 754461d commit e6853fd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 80 deletions.
33 changes: 0 additions & 33 deletions crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@ pub(crate) struct SubOp;
pub(crate) struct MulOp;
pub(crate) struct DivOp;
pub(crate) struct RemainderOp;
// pub(crate) struct BitwiseAndOp;
// pub(crate) struct BitwiseOrOp;
// pub(crate) struct BitwiseXorOp;
// pub(crate) struct BitwiseNotOp;

/// Since Powf only works on float, but we still want to implement the numeric binary op family, we
/// set another precision in the family type to cast, when necessary, the input value to a valid
Expand Down Expand Up @@ -113,35 +109,6 @@ impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
}
}

// #[cube]
// impl<N: Int> BinaryOp<N> for BitwiseAndOp {
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
// //lhs + rhs
// lhs & rhs
// }
// }

// #[cube]
// impl<N: Int> BinaryOp<N> for BitwiseOrOp {
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
// lhs | rhs
// }
// }

// #[cube]
// impl<N: Int> BinaryOp<N> for BitwiseXorOp {
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
// lhs ^ rhs
// }
// }

// #[cube]
// impl<N: Int> BinaryOp<N> for BitwiseNotOp {
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
// lhs + rhs
// }
// }

#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
input: &Tensor<Line<C>>,
Expand Down
44 changes: 44 additions & 0 deletions crates/burn-jit/src/kernel/unary_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,47 @@ where
}
}
}

pub(crate) mod unary_basic_int {

use super::*;

pub(crate) fn launch<R, Args, I>(tensor: JitTensor<R>, args: Args) -> JitTensor<R>
where
R: JitRuntime,
for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind,
I: IntElement,
{
launch_unary_int::<R, I, BasicIntUnary, _>(tensor, |input| {
BasicIntUnaryOptionsLaunch::new(args(input))
})
}

#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
pub enum BasicIntUnaryKind {
BitwiseNot,
}

#[derive(CubeLaunch)]
struct BasicIntUnaryOptions {
#[cube(comptime)]
kind: BasicIntUnaryKind,
}
struct BasicIntUnary;

#[cube]
impl<I: Int> IntUnaryOp<I> for BasicIntUnary {
type Options = BasicIntUnaryOptions;

fn execute(input: Line<I>, options: &Self::Options) -> Line<I> {
match comptime![options.kind] {
BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input),
}
}
}

impl IntUnaryOpFamily for BasicIntUnary {
type Options<I: Int> = BasicIntUnaryOptions;
type Unary<I: Int> = Self;
}
}
26 changes: 6 additions & 20 deletions crates/burn-jit/src/ops/int_ops.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use self::unary_basic_int::BasicIntUnaryKind;

use super::{expand, numeric, permute};
use crate::kernel::{
launch_binop_int, launch_scalar_binop_int, launch_unary_int, launch_unary_numeric, reduce,
BitwiseShlOp, BitwiseShrOp, IntUnaryOp, IntUnaryOpFamily, NumericUnaryOp, NumericUnaryOpFamily,
launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily,
};
use crate::{
element::BoolElement,
Expand All @@ -11,7 +13,7 @@ use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
use cubecl::frontend::Numeric;
use cubecl::prelude::{BitwiseNot, *};
use cubecl::prelude::*;
use std::ops::Range;

impl<R, F, I, BT> IntTensorOps<Self> for JitBackend<R, F, I, BT>
Expand Down Expand Up @@ -322,23 +324,7 @@ where
}

fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
struct BitwiseNot;

#[cube]
impl<I: Int> IntUnaryOp<I> for BitwiseNot {
type Options = ();

fn execute(input: Line<I>, _options: &Self::Options) -> Line<I> {
Line::bitwise_not(input)
}
}

impl IntUnaryOpFamily for BitwiseNot {
type Options<I: Int> = ();
type Unary<I: Int> = Self;
}

launch_unary_int::<R, I, BitwiseNot, _>(tensor, |_| ())
unary_basic_int::launch::<R, _, I>(tensor, |_| &BasicIntUnaryKind::BitwiseNot)
}

fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
Expand Down
30 changes: 5 additions & 25 deletions crates/burn-router/src/runner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,31 +793,19 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
handles.register_float_tensor::<B>(&desc.out.id, output);
}
IntOperationDescription::BitwiseAnd(desc) => {
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
let rhs = handles.get_int_tensor::<B>(&desc.rhs);

let output = B::bitwise_and(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
binary_int_ops!(handles, desc, B::bitwise_and)
}
IntOperationDescription::BitwiseAndScalar(desc) => {
scalar_int_ops!(handles, desc, B::bitwise_and_scalar)
}
IntOperationDescription::BitwiseOr(desc) => {
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
let rhs = handles.get_int_tensor::<B>(&desc.rhs);

let output = B::bitwise_or(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
binary_int_ops!(handles, desc, B::bitwise_or)
}
IntOperationDescription::BitwiseOrScalar(desc) => {
scalar_int_ops!(handles, desc, B::bitwise_or_scalar)
}
IntOperationDescription::BitwiseXor(desc) => {
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
let rhs = handles.get_int_tensor::<B>(&desc.rhs);

let output = B::bitwise_xor(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
binary_int_ops!(handles, desc, B::bitwise_xor)
}
IntOperationDescription::BitwiseXorScalar(desc) => {
scalar_int_ops!(handles, desc, B::bitwise_xor_scalar)
Expand All @@ -826,18 +814,10 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
unary_int_ops!(handles, desc, B::bitwise_not)
}
IntOperationDescription::BitwiseLeftShift(desc) => {
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
let rhs = handles.get_int_tensor::<B>(&desc.rhs);

let output = B::bitwise_left_shift(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
binary_int_ops!(handles, desc, B::bitwise_left_shift)
}
IntOperationDescription::BitwiseRightShift(desc) => {
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
let rhs = handles.get_int_tensor::<B>(&desc.rhs);

let output = B::bitwise_right_shift(lhs, rhs);
handles.register_int_tensor::<B>(&desc.out.id, output);
binary_int_ops!(handles, desc, B::bitwise_right_shift)
}
IntOperationDescription::BitwiseLeftShiftScalar(desc) => {
scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar)
Expand Down
4 changes: 2 additions & 2 deletions crates/burn-tch/src/ops/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ impl TchOps {
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(),
|lhs, rhs| rhs.f_bitwise_left_shift_(lhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
)
}
Expand Down Expand Up @@ -569,7 +569,7 @@ impl TchOps {
lhs,
rhs,
|lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(),
|lhs, rhs| rhs.f_bitwise_right_shift_(lhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
)
}
Expand Down

0 comments on commit e6853fd

Please sign in to comment.