Skip to content

Commit e6853fd

Browse files
committed
refactor requested changes
1 parent 754461d commit e6853fd

File tree

5 files changed

+57
-80
lines changed

5 files changed

+57
-80
lines changed

crates/burn-jit/src/kernel/binary.rs

-33
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,6 @@ pub(crate) struct SubOp;
2424
pub(crate) struct MulOp;
2525
pub(crate) struct DivOp;
2626
pub(crate) struct RemainderOp;
27-
// pub(crate) struct BitwiseAndOp;
28-
// pub(crate) struct BitwiseOrOp;
29-
// pub(crate) struct BitwiseXorOp;
30-
// pub(crate) struct BitwiseNotOp;
3127

3228
/// Since Powf only works on float, but we still want to implement the numeric binary op family, we
3329
/// set another precision in the family type to cast, when necessary, the input value to a valid
@@ -113,35 +109,6 @@ impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
113109
}
114110
}
115111

116-
// #[cube]
117-
// impl<N: Int> BinaryOp<N> for BitwiseAndOp {
118-
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
119-
// //lhs + rhs
120-
// lhs & rhs
121-
// }
122-
// }
123-
124-
// #[cube]
125-
// impl<N: Int> BinaryOp<N> for BitwiseOrOp {
126-
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
127-
// lhs | rhs
128-
// }
129-
// }
130-
131-
// #[cube]
132-
// impl<N: Int> BinaryOp<N> for BitwiseXorOp {
133-
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
134-
// lhs ^ rhs
135-
// }
136-
// }
137-
138-
// #[cube]
139-
// impl<N: Int> BinaryOp<N> for BitwiseNotOp {
140-
// fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
141-
// lhs + rhs
142-
// }
143-
// }
144-
145112
#[cube(launch_unchecked)]
146113
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
147114
input: &Tensor<Line<C>>,

crates/burn-jit/src/kernel/unary_int.rs

+44
Original file line numberDiff line numberDiff line change
@@ -102,3 +102,47 @@ where
102102
}
103103
}
104104
}
105+
106+
pub(crate) mod unary_basic_int {
107+
108+
use super::*;
109+
110+
pub(crate) fn launch<R, Args, I>(tensor: JitTensor<R>, args: Args) -> JitTensor<R>
111+
where
112+
R: JitRuntime,
113+
for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind,
114+
I: IntElement,
115+
{
116+
launch_unary_int::<R, I, BasicIntUnary, _>(tensor, |input| {
117+
BasicIntUnaryOptionsLaunch::new(args(input))
118+
})
119+
}
120+
121+
#[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)]
122+
pub enum BasicIntUnaryKind {
123+
BitwiseNot,
124+
}
125+
126+
#[derive(CubeLaunch)]
127+
struct BasicIntUnaryOptions {
128+
#[cube(comptime)]
129+
kind: BasicIntUnaryKind,
130+
}
131+
struct BasicIntUnary;
132+
133+
#[cube]
134+
impl<I: Int> IntUnaryOp<I> for BasicIntUnary {
135+
type Options = BasicIntUnaryOptions;
136+
137+
fn execute(input: Line<I>, options: &Self::Options) -> Line<I> {
138+
match comptime![options.kind] {
139+
BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input),
140+
}
141+
}
142+
}
143+
144+
impl IntUnaryOpFamily for BasicIntUnary {
145+
type Options<I: Int> = BasicIntUnaryOptions;
146+
type Unary<I: Int> = Self;
147+
}
148+
}

crates/burn-jit/src/ops/int_ops.rs

+6-20
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
use self::unary_basic_int::BasicIntUnaryKind;
2+
13
use super::{expand, numeric, permute};
24
use crate::kernel::{
3-
launch_binop_int, launch_scalar_binop_int, launch_unary_int, launch_unary_numeric, reduce,
4-
BitwiseShlOp, BitwiseShrOp, IntUnaryOp, IntUnaryOpFamily, NumericUnaryOp, NumericUnaryOpFamily,
5+
launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int,
6+
BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily,
57
};
68
use crate::{
79
element::BoolElement,
@@ -11,7 +13,7 @@ use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
1113
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
1214
use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
1315
use cubecl::frontend::Numeric;
14-
use cubecl::prelude::{BitwiseNot, *};
16+
use cubecl::prelude::*;
1517
use std::ops::Range;
1618

1719
impl<R, F, I, BT> IntTensorOps<Self> for JitBackend<R, F, I, BT>
@@ -322,23 +324,7 @@ where
322324
}
323325

324326
fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
325-
struct BitwiseNot;
326-
327-
#[cube]
328-
impl<I: Int> IntUnaryOp<I> for BitwiseNot {
329-
type Options = ();
330-
331-
fn execute(input: Line<I>, _options: &Self::Options) -> Line<I> {
332-
Line::bitwise_not(input)
333-
}
334-
}
335-
336-
impl IntUnaryOpFamily for BitwiseNot {
337-
type Options<I: Int> = ();
338-
type Unary<I: Int> = Self;
339-
}
340-
341-
launch_unary_int::<R, I, BitwiseNot, _>(tensor, |_| ())
327+
unary_basic_int::launch::<R, _, I>(tensor, |_| &BasicIntUnaryKind::BitwiseNot)
342328
}
343329

344330
fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {

crates/burn-router/src/runner.rs

+5-25
Original file line numberDiff line numberDiff line change
@@ -793,31 +793,19 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
793793
handles.register_float_tensor::<B>(&desc.out.id, output);
794794
}
795795
IntOperationDescription::BitwiseAnd(desc) => {
796-
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
797-
let rhs = handles.get_int_tensor::<B>(&desc.rhs);
798-
799-
let output = B::bitwise_and(lhs, rhs);
800-
handles.register_int_tensor::<B>(&desc.out.id, output);
796+
binary_int_ops!(handles, desc, B::bitwise_and)
801797
}
802798
IntOperationDescription::BitwiseAndScalar(desc) => {
803799
scalar_int_ops!(handles, desc, B::bitwise_and_scalar)
804800
}
805801
IntOperationDescription::BitwiseOr(desc) => {
806-
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
807-
let rhs = handles.get_int_tensor::<B>(&desc.rhs);
808-
809-
let output = B::bitwise_or(lhs, rhs);
810-
handles.register_int_tensor::<B>(&desc.out.id, output);
802+
binary_int_ops!(handles, desc, B::bitwise_or)
811803
}
812804
IntOperationDescription::BitwiseOrScalar(desc) => {
813805
scalar_int_ops!(handles, desc, B::bitwise_or_scalar)
814806
}
815807
IntOperationDescription::BitwiseXor(desc) => {
816-
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
817-
let rhs = handles.get_int_tensor::<B>(&desc.rhs);
818-
819-
let output = B::bitwise_xor(lhs, rhs);
820-
handles.register_int_tensor::<B>(&desc.out.id, output);
808+
binary_int_ops!(handles, desc, B::bitwise_xor)
821809
}
822810
IntOperationDescription::BitwiseXorScalar(desc) => {
823811
scalar_int_ops!(handles, desc, B::bitwise_xor_scalar)
@@ -826,18 +814,10 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
826814
unary_int_ops!(handles, desc, B::bitwise_not)
827815
}
828816
IntOperationDescription::BitwiseLeftShift(desc) => {
829-
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
830-
let rhs = handles.get_int_tensor::<B>(&desc.rhs);
831-
832-
let output = B::bitwise_left_shift(lhs, rhs);
833-
handles.register_int_tensor::<B>(&desc.out.id, output);
817+
binary_int_ops!(handles, desc, B::bitwise_left_shift)
834818
}
835819
IntOperationDescription::BitwiseRightShift(desc) => {
836-
let lhs = handles.get_int_tensor::<B>(&desc.lhs);
837-
let rhs = handles.get_int_tensor::<B>(&desc.rhs);
838-
839-
let output = B::bitwise_right_shift(lhs, rhs);
840-
handles.register_int_tensor::<B>(&desc.out.id, output);
820+
binary_int_ops!(handles, desc, B::bitwise_right_shift)
841821
}
842822
IntOperationDescription::BitwiseLeftShiftScalar(desc) => {
843823
scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar)

crates/burn-tch/src/ops/base.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -541,7 +541,7 @@ impl TchOps {
541541
lhs,
542542
rhs,
543543
|lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(),
544-
|lhs, rhs| rhs.f_bitwise_left_shift_(lhs).unwrap(),
544+
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
545545
|lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(),
546546
)
547547
}
@@ -569,7 +569,7 @@ impl TchOps {
569569
lhs,
570570
rhs,
571571
|lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(),
572-
|lhs, rhs| rhs.f_bitwise_right_shift_(lhs).unwrap(),
572+
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
573573
|lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(),
574574
)
575575
}

0 commit comments

Comments
 (0)