diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index 8a7c01bbc9..410d531d74 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`. -| Burn | PyTorch Equivalent | -| ------------------------------------- | ------------------------------------------------------------------------- | -| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | -| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | -| `Tensor::from_primitive(primitive)` | N/A | -| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | -| `tensor.all()` | `tensor.all()` | -| `tensor.all_dim(dim)` | `tensor.all(dim)` | -| `tensor.any()` | `tensor.any()` | -| `tensor.any_dim(dim)` | `tensor.any(dim)` | -| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | -| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | -| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | -| `tensor.device()` | `tensor.device` | -| `tensor.dtype()` | `tensor.dtype` | -| `tensor.dims()` | `tensor.size()` | -| `tensor.equal(other)` | `x == y` | -| `tensor.expand(shape)` | `tensor.expand(shape)` | -| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | -| `tensor.flip(axes)` | `tensor.flip(axes)` | -| `tensor.into_data()` | N/A | -| `tensor.into_primitive()` | N/A | -| `tensor.into_scalar()` | `tensor.item()` | -| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | -| `tensor.not_equal(other)` | `x != y` | -| `tensor.permute(axes)` | `tensor.permute(axes)` | -| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | -| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | -| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | -| `tensor.reshape(shape)` | `tensor.view(shape)` | -| `tensor.shape()` | `tensor.shape` | -| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | -| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | -| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | -| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | -| `tensor.to_data()` | N/A | -| `tensor.to_device(device)` | `tensor.to(device)` | -| `tensor.transpose()` | `tensor.T` | -| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | -| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | -| `tensor.unsqueeze_dims(dims)` | N/A | +| Burn | PyTorch Equivalent | +| ------------------------------------------- | ------------------------------------------------------------------------- | +| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` | +| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` | +| `Tensor::from_primitive(primitive)` | N/A | +| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` | +| `tensor.all()` | `tensor.all()` | +| `tensor.all_dim(dim)` | `tensor.all(dim)` | +| `tensor.any()` | `tensor.any()` | +| `tensor.any_dim(dim)` | `tensor.any(dim)` | +| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` | +| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` | +| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` | +| `tensor.device()` | `tensor.device` | +| `tensor.dtype()` | `tensor.dtype` | +| 
`tensor.dims()` | `tensor.size()` | +| `tensor.equal(other)` | `x == y` | +| `tensor.expand(shape)` | `tensor.expand(shape)` | +| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` | +| `tensor.flip(axes)` | `tensor.flip(axes)` | +| `tensor.into_data()` | N/A | +| `tensor.into_primitive()` | N/A | +| `tensor.into_scalar()` | `tensor.item()` | +| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` | +| `tensor.not_equal(other)` | `x != y` | +| `tensor.permute(axes)` | `tensor.permute(axes)` | +| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` | +| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` | +| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` | +| `tensor.reshape(shape)` | `tensor.view(shape)` | +| `tensor.shape()` | `tensor.shape` | +| `tensor.slice(ranges)` | `tensor[(*ranges,)]` | +| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` | +| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` | +| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | +| `tensor.to_data()` | N/A | +| `tensor.to_device(device)` | `tensor.to(device)` | +| `tensor.transpose()` | `tensor.T` | +| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` | +| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` | +| `tensor.unsqueeze_dims(dims)` | N/A | ### Numeric Operations @@ -258,32 +258,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`. Those operations are only available for `Float` tensors. -| Burn API | PyTorch Equivalent | -| --------------------------------------------- | ---------------------------------- | -| `tensor.cast(dtype)` | `tensor.to(dtype)` | -| `tensor.ceil()` | `tensor.ceil()` | -| `tensor.cos()` | `tensor.cos()` | -| `tensor.erf()` | `tensor.erf()` | -| `tensor.exp()` | `tensor.exp()` | -| `tensor.floor()` | `tensor.floor()` | -| `tensor.from_floats(floats, device)` | N/A | -| `tensor.from_full_precision(tensor)` | N/A | -| `tensor.int()` | Similar to `tensor.to(torch.long)` | -| `tensor.log()` | `tensor.log()` | -| `tensor.log1p()` | `tensor.log1p()` | -| `tensor.matmul(other)` | `tensor.matmul(other)` | -| `tensor.random(shape, distribution, device)` | N/A | -| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | -| `tensor.recip()` | `tensor.reciprocal()` | -| `tensor.round()` | `tensor.round()` | -| `tensor.sin()` | `tensor.sin()` | -| `tensor.sqrt()` | `tensor.sqrt()` | -| `tensor.tanh()` | `tensor.tanh()` | -| `tensor.to_full_precision()` | `tensor.to(torch.float)` | -| `tensor.var(dim)` | `tensor.var(dim)` | -| `tensor.var_bias(dim)` | N/A | -| `tensor.var_mean(dim)` | N/A | -| `tensor.var_mean_bias(dim)` | N/A | +| Burn API | PyTorch Equivalent | +| -------------------------------------------- | ---------------------------------- | +| `tensor.cast(dtype)` | `tensor.to(dtype)` | +| `tensor.ceil()` | `tensor.ceil()` | +| `tensor.cos()` | `tensor.cos()` | +| `tensor.erf()` | `tensor.erf()` | +| `tensor.exp()` | `tensor.exp()` | +| `tensor.floor()` | `tensor.floor()` | +| `tensor.from_floats(floats, device)` | N/A | +| `tensor.from_full_precision(tensor)` | N/A | +| `tensor.int()` | Similar to `tensor.to(torch.long)` | +| `tensor.log()` | `tensor.log()` | +| `tensor.log1p()` | `tensor.log1p()` | +| `tensor.matmul(other)` | `tensor.matmul(other)` | +| `tensor.random(shape, distribution, device)` | N/A | +| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | +| `tensor.recip()` | 
`tensor.reciprocal()`               |
+| `tensor.round()`                             | `tensor.round()`                    |
+| `tensor.sin()`                               | `tensor.sin()`                      |
+| `tensor.sqrt()`                              | `tensor.sqrt()`                     |
+| `tensor.tanh()`                              | `tensor.tanh()`                     |
+| `tensor.to_full_precision()`                 | `tensor.to(torch.float)`            |
+| `tensor.var(dim)`                            | `tensor.var(dim)`                   |
+| `tensor.var_bias(dim)`                       | N/A                                 |
+| `tensor.var_mean(dim)`                       | N/A                                 |
+| `tensor.var_mean_bias(dim)`                  | N/A                                 |
 
 ### Int Operations
 
@@ -293,6 +293,17 @@ Those operations are only available for `Int` tensors.
 | ------------------------------------------------ | ------------------------------------------------------- |
 | `Tensor::arange(5..10, device)`                  | `tensor.arange(start=5, end=10, device=device)`         |
 | `Tensor::arange_step(5..10, 2, device)`          | `tensor.arange(start=5, end=10, step=2, device=device)` |
+| `tensor.bitwise_and(other)`                      | `torch.bitwise_and(tensor, other)`                      |
+| `tensor.bitwise_and_scalar(scalar)`              | `torch.bitwise_and(tensor, scalar)`                     |
+| `tensor.bitwise_not()`                           | `torch.bitwise_not(tensor)`                             |
+| `tensor.bitwise_left_shift(other)`               | `torch.bitwise_left_shift(tensor, other)`               |
+| `tensor.bitwise_left_shift_scalar(scalar)`       | `torch.bitwise_left_shift(tensor, scalar)`              |
+| `tensor.bitwise_right_shift(other)`              | `torch.bitwise_right_shift(tensor, other)`              |
+| `tensor.bitwise_right_shift_scalar(scalar)`      | `torch.bitwise_right_shift(tensor, scalar)`             |
+| `tensor.bitwise_or(other)`                       | `torch.bitwise_or(tensor, other)`                       |
+| `tensor.bitwise_or_scalar(scalar)`               | `torch.bitwise_or(tensor, scalar)`                      |
+| `tensor.bitwise_xor(other)`                      | `torch.bitwise_xor(tensor, other)`                      |
+| `tensor.bitwise_xor_scalar(scalar)`              | `torch.bitwise_xor(tensor, scalar)`                     |
 | `tensor.float()`                                 | `tensor.to(torch.float)`                                |
 | `tensor.from_ints(ints)`                         | N/A                                                     |
 | `tensor.int_random(shape, distribution, device)` | N/A                                                     |
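The table above maps one-to-one onto the new methods on `Tensor<B, D, Int>` added later in this diff. As a quick orientation before the backend-by-backend plumbing below, here is a minimal usage sketch (not part of the diff; it assumes the `NdArray` backend and reuses values from this PR's test suite, but any backend implementing the new ops behaves the same way):

```rust
use burn_ndarray::NdArray;
use burn_tensor::{backend::Backend, Int, Tensor};

fn main() {
    type B = NdArray;
    let device = <B as Backend>::Device::default();

    let a = Tensor::<B, 2, Int>::from_ints([[3, 4, 5], [9, 3, 8]], &device);
    let b = Tensor::<B, 2, Int>::from_ints([[6, 7, 8], [9, 10, 15]], &device);

    // Element-wise AND: [[3 & 6, 4 & 7, 5 & 8], [9 & 9, 3 & 10, 8 & 15]] == [[2, 4, 0], [9, 2, 8]]
    let and = a.clone().bitwise_and(b);

    // Scalar variants broadcast the scalar to every element.
    let masked = a.clone().bitwise_and_scalar(0b0101);

    // Shift amounts come from a second tensor or a scalar; x << 1 doubles each element.
    let doubled = a.bitwise_left_shift_scalar(1);

    println!("{and}\n{masked}\n{doubled}");
}
```

Mirroring PyTorch (right column of the table), the same results come from `torch.bitwise_and`, `torch.bitwise_left_shift`, and friends, which is what makes the table a direct porting guide.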
diff --git a/crates/burn-autodiff/src/ops/int_tensor.rs b/crates/burn-autodiff/src/ops/int_tensor.rs
index 4aad98bb46..f3439d1cad 100644
--- a/crates/burn-autodiff/src/ops/int_tensor.rs
+++ b/crates/burn-autodiff/src/ops/int_tensor.rs
@@ -348,4 +348,48 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
         B::int_argsort(tensor, dim, descending)
     }
+
+    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_and(lhs, rhs)
+    }
+
+    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
+        B::bitwise_and_scalar(lhs, rhs)
+    }
+
+    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_or(lhs, rhs)
+    }
+
+    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
+        B::bitwise_or_scalar(lhs, rhs)
+    }
+
+    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_xor(lhs, rhs)
+    }
+
+    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
+        B::bitwise_xor_scalar(lhs, rhs)
+    }
+
+    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_not(tensor)
+    }
+
+    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_left_shift(lhs, rhs)
+    }
+
+    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
+        B::bitwise_left_shift_scalar(lhs, rhs)
+    }
+
+    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
+        B::bitwise_right_shift(lhs, rhs)
+    }
+
+    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
+        B::bitwise_right_shift_scalar(lhs, rhs)
+    }
 }
diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs
index 4ae0c53de7..08b84251fa 100644
--- a/crates/burn-candle/src/ops/int_tensor.rs
+++ b/crates/burn-candle/src/ops/int_tensor.rs
@@ -372,4 +372,47 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
     ) -> IntTensor<Self> {
         sign(tensor)
     }
+    fn bitwise_and(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_and is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_and_scalar(_lhs: IntTensor<Self>, _rhs: IntElem<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_and_scalar is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_or(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_or is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_or_scalar(_lhs: IntTensor<Self>, _rhs: IntElem<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_or_scalar is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_xor(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_xor is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_xor_scalar(_lhs: IntTensor<Self>, _rhs: IntElem<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_xor_scalar is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_not(_tensor: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_not is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_left_shift(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_left_shift is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_right_shift(_lhs: IntTensor<Self>, _rhs: IntTensor<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_right_shift is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_left_shift_scalar(_lhs: IntTensor<Self>, _rhs: IntElem<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_left_shift_scalar is not implemented for Candle IntTensor");
+    }
+
+    fn bitwise_right_shift_scalar(_lhs: IntTensor<Self>, _rhs: IntElem<Self>) -> IntTensor<Self> {
+        unimplemented!("bitwise_right_shift_scalar is not implemented for Candle IntTensor");
+    }
 }
diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs
index bdb47df02c..e2115cbf6a 100644
--- a/crates/burn-fusion/src/ops/int.rs
+++ b/crates/burn-fusion/src/ops/int.rs
@@ -1819,4 +1819,267 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         out
     }
+
+    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
+        binary_int_ops!(BitwiseAndOps, B::bitwise_and);
+
+        let stream_1 = lhs.stream;
+        let stream_2 = rhs.stream;
+        let out = lhs.client.tensor_uninitialized(
+            binary_ops_shape(&lhs.shape, &rhs.shape),
+ B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOr(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseOrOps, B::bitwise_or_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseOrScalar(desc.clone())), + BitwiseOrOps::::new(desc), + ); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseXorOps, B::bitwise_xor); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXor(desc.clone())), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseXorOps, B::bitwise_xor_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseXorScalar( + desc.clone(), + )), + BitwiseXorOps::::new(desc), + ); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_int_ops!(BitwiseNotOps, B::bitwise_not); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::IntElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseNot(desc.clone())), + BitwiseNotOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShift( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseLeftShiftOps, B::bitwise_left_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc 
= ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseLeftShiftScalar( + desc.clone(), + )), + BitwiseLeftShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + binary_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift); + + let stream_1 = lhs.stream; + let stream_2 = rhs.stream; + let out = lhs.client.tensor_uninitialized( + binary_ops_shape(&lhs.shape, &rhs.shape), + B::IntElem::dtype(), + ); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream_1, stream_2], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShift( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + scalar_int_ops!(BitwiseRightShiftOps, B::bitwise_right_shift_scalar); + + let stream = lhs.stream; + let out = lhs + .client + .tensor_uninitialized(lhs.shape.clone(), B::IntElem::dtype()); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + repr::OperationDescription::Int(IntOperationDescription::BitwiseRightShiftScalar( + desc.clone(), + )), + BitwiseRightShiftOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 0f9c75fc94..d85e06cc09 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -694,6 +694,82 @@ impl RelativeOps for IntOperationDescription { out: desc.out.to_relative(converter), }) } + IntOperationDescription::BitwiseAnd(desc) => { + IntOperationDescription::BitwiseAnd(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + IntOperationDescription::BitwiseAndScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOr(desc) => { + IntOperationDescription::BitwiseOr(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + IntOperationDescription::BitwiseOrScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXor(desc) => { + IntOperationDescription::BitwiseXor(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + IntOperationDescription::BitwiseXorScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseNot(desc) => { + IntOperationDescription::BitwiseNot(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), 
+ }) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + IntOperationDescription::BitwiseLeftShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + IntOperationDescription::BitwiseLeftShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShift(desc) => { + IntOperationDescription::BitwiseRightShift(BinaryOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + IntOperationDescription::BitwiseRightShiftScalar(ScalarOperationDescription { + lhs: desc.lhs.to_relative(converter), + rhs: desc.rhs, + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index f0e15352cf..a1bbab7f5f 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -57,6 +57,7 @@ impl IntElement for i64 {} impl IntElement for i32 {} impl IntElement for i16 {} impl IntElement for i8 {} +impl IntElement for u32 {} impl BoolElement for u8 {} impl BoolElement for u32 {} diff --git a/crates/burn-jit/src/kernel/binary_int.rs b/crates/burn-jit/src/kernel/binary_int.rs new file mode 100644 index 0000000000..06706a7d28 --- /dev/null +++ b/crates/burn-jit/src/kernel/binary_int.rs @@ -0,0 +1,276 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use burn_tensor::Shape; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +use super::into_contiguous; + +pub(crate) trait BinaryOpIntFamily: Send + Sync + 'static { + type BinaryOp: BinaryOpInt; +} + +#[cube] +pub(crate) trait BinaryOpInt: 'static + Send + Sync { + /// Execute a binary operation. 
+ fn execute(lhs: Line, rhs: Line) -> Line; +} + +pub(crate) struct BitwiseAndOp; +pub(crate) struct BitwiseOrOp; +pub(crate) struct BitwiseXorOp; +pub(crate) struct BitwiseShrOp; +pub(crate) struct BitwiseShlOp; + +impl BinaryOpIntFamily for BitwiseAndOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseOrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseXorOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShrOp { + type BinaryOp = Self; +} + +impl BinaryOpIntFamily for BitwiseShlOp { + type BinaryOp = Self; +} + +#[cube] +impl BinaryOpInt for BitwiseAndOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs & rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseOrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs | rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseXorOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs ^ rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShrOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs >> rhs + } +} + +#[cube] +impl BinaryOpInt for BitwiseShlOp { + fn execute(lhs: Line, rhs: Line) -> Line { + lhs << rhs + } +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_scalar_binop_int( + input: &Tensor>, + scalar: C, + output: &mut Tensor>, +) { + if ABSOLUTE_POS >= output.len() { + return; + } + + output[ABSOLUTE_POS] = O::BinaryOp::::execute(input[ABSOLUTE_POS], Line::new(scalar)); +} + +#[cube(launch_unchecked)] +pub(crate) fn kernel_binop_int( + lhs: &Tensor>, + rhs: &Tensor>, + out: &mut Tensor>, + #[comptime] rank: Option, + #[comptime] to_contiguous_lhs: bool, + #[comptime] to_contiguous_rhs: bool, +) { + let offset_out = ABSOLUTE_POS; + let mut offset_lhs = ABSOLUTE_POS; + let mut offset_rhs = ABSOLUTE_POS; + + if offset_out >= out.len() { + return; + } + + if to_contiguous_lhs { + offset_lhs = index_offset_with_layout::( + lhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + if to_contiguous_rhs { + offset_rhs = index_offset_with_layout::( + rhs, + out, + offset_out, + 0, + rank.unwrap_or_else(|| out.rank()), + rank.is_some(), + ); + } + + out[offset_out] = O::BinaryOp::::execute(lhs[offset_lhs], rhs[offset_rhs]); +} + +pub(crate) fn launch_binop_int( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + let ndims = lhs.shape.num_dims(); + let line_size_lhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &lhs.shape.dims, + &lhs.strides, + ndims - 1, + ); + let line_size_rhs = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &rhs.shape.dims, + &rhs.strides, + ndims - 1, + ); + let line_size = Ord::min(line_size_lhs, line_size_rhs); + + let mut shape_out = vec![0; ndims]; + lhs.shape + .dims + .iter() + .zip(rhs.shape.dims.iter()) + .enumerate() + .for_each(|(index, (dim_lhs, dim_rhs))| { + shape_out[index] = usize::max(*dim_lhs, *dim_rhs); + }); + + let shape_out = Shape::from(shape_out); + let client = lhs.client.clone(); + let num_elems = shape_out.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if lhs.can_mut_broadcast(&rhs) { + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(0), + None, + false, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + ); + + lhs + } else if rhs.can_mut_broadcast(&lhs) { + 
kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + TensorArg::alias(1), + None, + rhs.strides != lhs.strides || rhs.shape != lhs.shape, + false, + ); + + rhs + } else { + let output = empty_device::(lhs.client.clone(), lhs.device.clone(), shape_out); + let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape; + let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape; + + kernel_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + lhs.as_tensor_arg::(line_size), + rhs.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + None, + to_contiguous_lhs, + to_contiguous_rhs, + ); + + output + } + } +} + +pub(crate) fn launch_scalar_binop_int( + mut tensor: JitTensor, + scalar: E, +) -> JitTensor { + if !tensor.is_contiguous_buffer() { + tensor = into_contiguous(tensor); + } + + // Vectorization is only enabled when the last dimension is contiguous. + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + + unsafe { + if tensor.can_mut() { + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + TensorArg::alias(0), + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + kernel_scalar_binop_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + ScalarArg::new(scalar), + output.as_tensor_arg::(line_size), + ); + + output + } + } +} diff --git a/crates/burn-jit/src/kernel/mod.rs b/crates/burn-jit/src/kernel/mod.rs index fd23cd2e2d..93d2833976 100644 --- a/crates/burn-jit/src/kernel/mod.rs +++ b/crates/burn-jit/src/kernel/mod.rs @@ -1,4 +1,5 @@ mod binary; +mod binary_int; mod cast; mod clamp; mod comparison; @@ -6,13 +7,16 @@ mod contiguous; mod index; mod mask; mod unary_float; +mod unary_int; mod unary_numeric; pub(crate) use binary::*; +pub(crate) use binary_int::*; pub use cast::*; pub use contiguous::*; pub use mask::*; pub(crate) use unary_float::*; +pub(crate) use unary_int::*; pub(crate) use unary_numeric::*; pub use burn_common::PLANE_DIM_APPROX; diff --git a/crates/burn-jit/src/kernel/unary_int.rs b/crates/burn-jit/src/kernel/unary_int.rs new file mode 100644 index 0000000000..5e60898699 --- /dev/null +++ b/crates/burn-jit/src/kernel/unary_int.rs @@ -0,0 +1,148 @@ +use crate::{ops::numeric::empty_device, tensor::JitTensor, IntElement, JitRuntime}; +use cubecl::{ + calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*, + tensor_line_size_parallel, +}; + +pub(crate) trait IntUnaryOpFamily: 'static + Send + Sync { + type Options: LaunchArg; + type Unary: IntUnaryOp>; +} + +#[cube] +pub(crate) trait IntUnaryOp: 'static + Send + Sync { + type Options: LaunchArg; + + fn execute(input: Line, options: &Self::Options) -> Line; +} + +#[cube(launch_unchecked)] +pub(crate) fn unary_int( + input: &Tensor>, + output: &mut Tensor>, + options: &O::Options, + #[comptime] rank: Option, + #[comptime] to_contiguous: bool, +) { + 
let offset_output = ABSOLUTE_POS; + + if offset_output >= output.len() { + return; + } + + if comptime![to_contiguous] { + let offset_input = index_offset_with_layout::( + input, + output, + offset_output, + 0, + rank.unwrap_or_else(|| output.rank()), + rank.is_some(), + ); + + output[offset_output] = O::Unary::::execute(input[offset_input], options); + } else { + output[offset_output] = O::Unary::::execute(input[offset_output], options); + } +} + +pub(crate) fn launch_unary_int(tensor: JitTensor, args: Args) -> JitTensor +where + for<'a> Args: FnOnce(&'a ()) -> RuntimeArg<'a, O::Options, R>, + R: JitRuntime, + E: IntElement + Int, + O: IntUnaryOpFamily, +{ + let ndims = tensor.shape.num_dims(); + let line_size = tensor_line_size_parallel( + R::line_size_elem(&E::as_elem_native_unchecked()), + &tensor.shape.dims, + &tensor.strides, + ndims - 1, + ); + let client = tensor.client.clone(); + let num_elems = tensor.shape.num_elements(); + + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim); + let is_contiguous = tensor.is_contiguous(); + + unsafe { + if tensor.can_mut() && tensor.is_contiguous_buffer() { + unary_int::launch_unchecked::( + &client, + cube_count, + cube_dim, + tensor.as_tensor_arg::(line_size), + TensorArg::alias(0), + args(&()), + None, + false, + ); + + tensor + } else { + let output = empty_device::( + tensor.client.clone(), + tensor.device.clone(), + tensor.shape.clone(), + ); + + unary_int::launch_unchecked::( + &client, + cube_count, + CubeDim::default(), + tensor.as_tensor_arg::(line_size), + output.as_tensor_arg::(line_size), + args(&()), + Some(ndims as u32), + !is_contiguous, + ); + output + } + } +} + +pub(crate) mod unary_basic_int { + + use super::*; + + pub(crate) fn launch(tensor: JitTensor, args: Args) -> JitTensor + where + R: JitRuntime, + for<'a> Args: FnOnce(&'a ()) -> &'a BasicIntUnaryKind, + I: IntElement, + { + launch_unary_int::(tensor, |input| { + BasicIntUnaryOptionsLaunch::new(args(input)) + }) + } + + #[derive(Clone, Copy, Debug, Hash, Eq, PartialEq, serde::Serialize, serde::Deserialize)] + pub enum BasicIntUnaryKind { + BitwiseNot, + } + + #[derive(CubeLaunch)] + struct BasicIntUnaryOptions { + #[cube(comptime)] + kind: BasicIntUnaryKind, + } + struct BasicIntUnary; + + #[cube] + impl IntUnaryOp for BasicIntUnary { + type Options = BasicIntUnaryOptions; + + fn execute(input: Line, options: &Self::Options) -> Line { + match comptime![options.kind] { + BasicIntUnaryKind::BitwiseNot => Line::bitwise_not(input), + } + } + } + + impl IntUnaryOpFamily for BasicIntUnary { + type Options = BasicIntUnaryOptions; + type Unary = Self; + } +} diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 5702a90849..d7b84e5e64 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,5 +1,10 @@ +use self::unary_basic_int::BasicIntUnaryKind; + use super::{expand, numeric, permute}; -use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily}; +use crate::kernel::{ + launch_binop_int, launch_scalar_binop_int, launch_unary_numeric, reduce, unary_basic_int, + BitwiseShlOp, BitwiseShrOp, NumericUnaryOp, NumericUnaryOpFamily, +}; use crate::{ element::BoolElement, kernel::prng::{random_bernoulli, random_normal, random_uniform}, @@ -293,4 +298,56 @@ where fn int_flip(tensor: IntTensor, axes: &[usize]) -> IntTensor { kernel::flip::(tensor, axes) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> 
IntTensor { + numeric::bitwise_and::(lhs, rhs) + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_and_scalar::(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_or::(lhs, rhs) + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + numeric::bitwise_xor::(lhs, rhs) + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + numeric::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + unary_basic_int::launch::(tensor, |_| &BasicIntUnaryKind::BitwiseNot) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = kernel::cast::(rhs); + launch_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = rhs.elem::(); + launch_scalar_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = kernel::cast::(rhs); + launch_binop_int::(lhs_cast, rhs_cast) + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let lhs_cast = kernel::cast::(lhs); + let rhs_cast = rhs.elem::(); + launch_scalar_binop_int::(lhs_cast, rhs_cast) + } } diff --git a/crates/burn-jit/src/ops/numeric.rs b/crates/burn-jit/src/ops/numeric.rs index d0d5be8468..2c2c7987ab 100644 --- a/crates/burn-jit/src/ops/numeric.rs +++ b/crates/burn-jit/src/ops/numeric.rs @@ -1,8 +1,9 @@ use crate::kernel::{ - launch_binop, launch_scalar_binop, AddOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, + launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int, AddOp, + BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp, }; use crate::{element::JitElement, tensor::JitTensor}; -use crate::{FloatElement, JitRuntime}; +use crate::{FloatElement, IntElement, JitRuntime}; use burn_tensor::{ElementConversion, Shape}; use cubecl::client::ComputeClient; use cubecl::tensor_vectorization_factor; @@ -139,3 +140,36 @@ pub fn remainder_scalar(lhs: JitTensor, rhs: E) pub fn pow(lhs: JitTensor, rhs: JitTensor) -> JitTensor { launch_binop::>(lhs, rhs) } + +pub fn bitwise_and( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_and_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_or( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_or_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor( + lhs: JitTensor, + rhs: JitTensor, +) -> JitTensor { + launch_binop_int::(lhs, rhs) +} + +pub fn bitwise_xor_scalar(lhs: JitTensor, rhs: E) -> JitTensor { + launch_scalar_binop_int::(lhs, rhs) +} diff --git a/crates/burn-ndarray/src/ops/int_tensor.rs b/crates/burn-ndarray/src/ops/int_tensor.rs index 9009b5c4a8..43c7cdb100 100644 --- a/crates/burn-ndarray/src/ops/int_tensor.rs +++ b/crates/burn-ndarray/src/ops/int_tensor.rs @@ -351,4 +351,71 @@ impl IntTensorOps fn int_expand(tensor: NdArrayTensor, shape: Shape) -> NdArrayTensor { NdArrayOps::expand(tensor, shape) } + + fn bitwise_and(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { 
+ NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() & (b.elem::())).elem() + }) + } + + fn bitwise_and_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() & rhs.elem::()).elem() + }) + } + + fn bitwise_or(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() | (b.elem::())).elem() + }) + } + + fn bitwise_or_scalar( + lhs: burn_tensor::ops::IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> burn_tensor::ops::IntTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() | rhs.elem::()).elem() + }) + } + + fn bitwise_xor(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() ^ (b.elem::())).elem() + }) + } + + fn bitwise_xor_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() ^ rhs.elem::()).elem() + }) + } + + fn bitwise_not(tensor: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::()).elem()) + } + + fn bitwise_left_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() << (b.elem::())).elem() + }) + } + + fn bitwise_left_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() << rhs.elem::()).elem() + }) + } + + fn bitwise_right_shift(lhs: NdArrayTensor, rhs: NdArrayTensor) -> NdArrayTensor { + NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| { + (a.elem::() >> (b.elem::())).elem() + }) + } + + fn bitwise_right_shift_scalar(lhs: NdArrayTensor, rhs: I) -> NdArrayTensor { + NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| { + (a.elem::() >> rhs.elem::()).elem() + }) + } } diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs index db81602d4f..5d84131e32 100644 --- a/crates/burn-router/src/ops/op_int.rs +++ b/crates/burn-router/src/ops/op_int.rs @@ -1173,4 +1173,201 @@ impl IntTensorOps for BackendRouter { out } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAnd(desc), + )); + + out + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOr(desc), + )); + + out + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + 
client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXor(desc), + )); + + out + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseNot(desc), + )); + + out + } + + fn bitwise_and_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseAndScalar(desc), + )); + + out + } + + fn bitwise_or_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseOrScalar(desc), + )); + + out + } + + fn bitwise_xor_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseXorScalar(desc), + )); + + out + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShift(desc), + )); + + out + } + + fn bitwise_left_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseLeftShiftScalar(desc), + )); + + out + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(binary_ops_shape(&lhs.shape, &rhs.shape), dtype); + + let desc = BinaryOperationDescription { + lhs: lhs.into_description(), + rhs: rhs.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShift(desc), + )); + + out + } + + fn bitwise_right_shift_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { + let client = lhs.client.clone(); + let dtype = lhs.dtype; + let out = client.register_empty_tensor(lhs.shape.clone(), dtype); + + let desc = ScalarOperationDescription { + lhs: 
lhs.into_description(), + rhs: rhs.elem(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Int( + IntOperationDescription::BitwiseRightShiftScalar(desc), + )); + + out + } } diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 04f93a4769..9521cf66ae 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -792,6 +792,39 @@ impl RunnerClient for Runner { let output = B::int_into_float(tensor); handles.register_float_tensor::(&desc.out.id, output); } + IntOperationDescription::BitwiseAnd(desc) => { + binary_int_ops!(handles, desc, B::bitwise_and) + } + IntOperationDescription::BitwiseAndScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_and_scalar) + } + IntOperationDescription::BitwiseOr(desc) => { + binary_int_ops!(handles, desc, B::bitwise_or) + } + IntOperationDescription::BitwiseOrScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_or_scalar) + } + IntOperationDescription::BitwiseXor(desc) => { + binary_int_ops!(handles, desc, B::bitwise_xor) + } + IntOperationDescription::BitwiseXorScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_xor_scalar) + } + IntOperationDescription::BitwiseNot(desc) => { + unary_int_ops!(handles, desc, B::bitwise_not) + } + IntOperationDescription::BitwiseLeftShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_left_shift) + } + IntOperationDescription::BitwiseRightShift(desc) => { + binary_int_ops!(handles, desc, B::bitwise_right_shift) + } + IntOperationDescription::BitwiseLeftShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_left_shift_scalar) + } + IntOperationDescription::BitwiseRightShiftScalar(desc) => { + scalar_int_ops!(handles, desc, B::bitwise_right_shift_scalar) + } }, OperationDescription::Float(_dtype, op) => match op { FloatOperationDescription::Exp(desc) => { diff --git a/crates/burn-tch/src/ops/base.rs b/crates/burn-tch/src/ops/base.rs index 7b04207871..704c6176cc 100644 --- a/crates/burn-tch/src/ops/base.rs +++ b/crates/burn-tch/src/ops/base.rs @@ -477,4 +477,118 @@ impl TchOps { pub fn argsort(tensor: TchTensor, dim: usize, descending: bool) -> TchTensor { TchTensor::new(tensor.tensor.argsort(dim as i64, descending)) } + + pub fn bitwise_and(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_and_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_and_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_and_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_and_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_and_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_and(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_or(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_or_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_or_tensor_(lhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_or_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_or_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_or_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_or(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_xor(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_xor_tensor_(rhs).unwrap(), + |lhs, rhs| rhs.f_bitwise_xor_tensor_(lhs).unwrap(), + |lhs, rhs| 
lhs.f_bitwise_xor_tensor(rhs).unwrap(), + ) + } + + pub fn bitwise_xor_scalar + Clone>(tensor: TchTensor, scalar: S) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_xor_(scalar.clone().into()).unwrap(), + |tensor| tensor.f_bitwise_xor(scalar.clone().into()).unwrap(), + ) + } + + pub fn bitwise_not(tensor: TchTensor) -> TchTensor { + tensor.unary_ops( + |mut tensor| tensor.f_bitwise_not_().unwrap(), + |tensor| tensor.f_bitwise_not().unwrap(), + ) + } + + pub fn bitwise_left_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_left_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_left_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_left_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_left_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } + + pub fn bitwise_right_shift(lhs: TchTensor, rhs: TchTensor) -> TchTensor { + TchTensor::binary_ops_tensor( + lhs, + rhs, + |lhs, rhs| lhs.f_bitwise_right_shift_(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + |lhs, rhs| lhs.f_bitwise_right_shift(rhs).unwrap(), + ) + } + + pub fn bitwise_right_shift_scalar + Clone>( + tensor: TchTensor, + scalar: S, + ) -> TchTensor { + tensor.unary_ops( + |mut tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar_(scalar.clone().into()) + .unwrap() + }, + |tensor| { + tensor + .f_bitwise_right_shift_tensor_scalar(scalar.clone().into()) + .unwrap() + }, + ) + } } diff --git a/crates/burn-tch/src/ops/int_tensor.rs b/crates/burn-tch/src/ops/int_tensor.rs index 0da31fe430..0ac829abaf 100644 --- a/crates/burn-tch/src/ops/int_tensor.rs +++ b/crates/burn-tch/src/ops/int_tensor.rs @@ -416,4 +416,63 @@ impl IntTensorOps for LibTorch { fn int_argsort(tensor: IntTensor, dim: usize, descending: bool) -> IntTensor { TchOps::argsort(tensor, dim, descending) } + + fn bitwise_and(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_and(lhs, rhs) + } + + fn bitwise_or(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_or(lhs, rhs) + } + + fn bitwise_xor(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_xor(lhs, rhs) + } + + fn bitwise_not(tensor: IntTensor) -> IntTensor { + TchOps::bitwise_not(tensor) + } + + fn bitwise_and_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_and_scalar(lhs, rhs) + } + + fn bitwise_or_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_or_scalar(lhs, rhs) + } + + fn bitwise_xor_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_xor_scalar(lhs, rhs) + } + + fn bitwise_left_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_left_shift(lhs, rhs) + } + + fn bitwise_right_shift(lhs: IntTensor, rhs: IntTensor) -> IntTensor { + TchOps::bitwise_right_shift(lhs, rhs) + } + + fn bitwise_left_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_left_shift_scalar(lhs, rhs) + } + + fn bitwise_right_shift_scalar( + lhs: IntTensor, + rhs: burn_tensor::ops::IntElem, + ) -> IntTensor { + TchOps::bitwise_right_shift_scalar(lhs, rhs) + } } diff --git a/crates/burn-tensor/src/repr/operation.rs 
b/crates/burn-tensor/src/repr/operation.rs index 001b9d6e83..0d7fe2493b 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -520,6 +520,50 @@ pub enum NumericOperationDescription { pub enum IntOperationDescription { /// Operation corresponding to [into float](crate::ops::IntTensorOps::int_into_float). IntoFloat(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and](crate::ops::IntTensorOps::bitwise_and). + BitwiseAnd(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise and scalar](crate::ops::IntTensorOps::bitwise_and_scalar). + BitwiseAndScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or](crate::ops::IntTensorOps::bitwise_or). + BitwiseOr(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise or scalar](crate::ops::IntTensorOps::bitwise_or_scalar). + BitwiseOrScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor](crate::ops::IntTensorOps::bitwise_xor). + BitwiseXor(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise xor scalar](crate::ops::IntTensorOps::bitwise_xor_scalar). + BitwiseXorScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise not](crate::ops::IntTensorOps::bitwise_not). + BitwiseNot(UnaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift](crate::ops::IntTensorOps::bitwise_left_shift). + BitwiseLeftShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise left shift scalar](crate::ops::IntTensorOps::bitwise_left_shift_scalar). + BitwiseLeftShiftScalar(ScalarOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift](crate::ops::IntTensorOps::bitwise_right_shift). + BitwiseRightShift(BinaryOperationDescription), + /// Operation corresponding to: + /// + /// Int => [bitwise right shift scalar](crate::ops::IntTensorOps::bitwise_right_shift_scalar). + BitwiseRightShiftScalar(ScalarOperationDescription), } /// Operation description specific to a bool tensor. 
@@ -1544,6 +1588,39 @@ impl IntOperationDescription {
     fn nodes(&self) -> Vec<&TensorDescription> {
         match self {
             IntOperationDescription::IntoFloat(desc) => vec![&desc.input, &desc.out],
+            IntOperationDescription::BitwiseAnd(desc) => {
+                vec![&desc.lhs, &desc.rhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseAndScalar(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseOr(desc) => {
+                vec![&desc.lhs, &desc.rhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseOrScalar(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseXor(desc) => {
+                vec![&desc.lhs, &desc.rhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseXorScalar(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseNot(desc) => {
+                vec![&desc.input, &desc.out]
+            }
+            IntOperationDescription::BitwiseLeftShift(desc) => {
+                vec![&desc.lhs, &desc.rhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseLeftShiftScalar(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseRightShift(desc) => {
+                vec![&desc.lhs, &desc.rhs, &desc.out]
+            }
+            IntOperationDescription::BitwiseRightShiftScalar(desc) => {
+                vec![&desc.lhs, &desc.out]
+            }
         }
     }
 }
diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs
index e882a107c7..5d65b68ceb 100644
--- a/crates/burn-tensor/src/tensor/api/int.rs
+++ b/crates/burn-tensor/src/tensor/api/int.rs
@@ -99,4 +99,59 @@ where
     ) -> Tensor {
         cartesian_grid::(shape, device)
     }
+
+    /// Applies the bitwise logical AND operation with another tensor, element-wise.
+    pub fn bitwise_and(self, other: Self) -> Self {
+        Self::new(B::bitwise_and(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical OR operation with another tensor, element-wise.
+    pub fn bitwise_or(self, other: Self) -> Self {
+        Self::new(B::bitwise_or(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical XOR operation with another tensor, element-wise.
+    pub fn bitwise_xor(self, other: Self) -> Self {
+        Self::new(B::bitwise_xor(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical NOT operation, element-wise.
+    pub fn bitwise_not(self) -> Self {
+        Self::new(B::bitwise_not(self.primitive))
+    }
+
+    /// Applies the bitwise logical AND operation between each element and a scalar.
+    pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_and_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise logical OR operation between each element and a scalar.
+    pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_or_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise logical XOR operation between each element and a scalar.
+    pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_xor_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise left shift operation, with shift amounts given element-wise by another tensor.
+    pub fn bitwise_left_shift(self, other: Self) -> Self {
+        Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise right shift operation, with shift amounts given element-wise by another tensor.
+    pub fn bitwise_right_shift(self, other: Self) -> Self {
+        Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise left shift operation by a scalar amount.
diff --git a/crates/burn-tensor/src/tensor/api/int.rs b/crates/burn-tensor/src/tensor/api/int.rs
index e882a107c7..5d65b68ceb 100644
--- a/crates/burn-tensor/src/tensor/api/int.rs
+++ b/crates/burn-tensor/src/tensor/api/int.rs
@@ -99,4 +99,59 @@ where
     ) -> Tensor<B, D2, Int> {
         cartesian_grid::<B, S, D2>(shape, device)
     }
+
+    /// Applies the bitwise logical AND operation elementwise with another tensor.
+    pub fn bitwise_and(self, other: Self) -> Self {
+        Self::new(B::bitwise_and(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical OR operation elementwise with another tensor.
+    pub fn bitwise_or(self, other: Self) -> Self {
+        Self::new(B::bitwise_or(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical XOR operation elementwise with another tensor.
+    pub fn bitwise_xor(self, other: Self) -> Self {
+        Self::new(B::bitwise_xor(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise logical NOT operation to each element.
+    pub fn bitwise_not(self) -> Self {
+        Self::new(B::bitwise_not(self.primitive))
+    }
+
+    /// Applies the bitwise logical AND operation between each element of the tensor and the scalar.
+    pub fn bitwise_and_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_and_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise logical OR operation between each element of the tensor and the scalar.
+    pub fn bitwise_or_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_or_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise logical XOR operation between each element of the tensor and the scalar.
+    pub fn bitwise_xor_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_xor_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise left shift operation elementwise, with shift amounts taken from `other`.
+    pub fn bitwise_left_shift(self, other: Self) -> Self {
+        Self::new(B::bitwise_left_shift(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise right shift operation elementwise, with shift amounts taken from `other`.
+    pub fn bitwise_right_shift(self, other: Self) -> Self {
+        Self::new(B::bitwise_right_shift(self.primitive, other.primitive))
+    }
+
+    /// Applies the bitwise left shift operation to each element, shifting by the scalar amount.
+    pub fn bitwise_left_shift_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_left_shift_scalar(self.primitive, other))
+    }
+
+    /// Applies the bitwise right shift operation to each element, shifting by the scalar amount.
+    pub fn bitwise_right_shift_scalar(self, other: B::IntElem) -> Self {
+        Self::new(B::bitwise_right_shift_scalar(self.primitive, other))
+    }
 }
diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
index abdd2e54ba..81b73eb2dd 100644
--- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs
+++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -1185,4 +1185,37 @@ pub trait IntTensorOps<B: Backend> {
     fn int_argsort(tensor: IntTensor<B>, dim: usize, descending: bool) -> IntTensor<B> {
         argsort::<B>(tensor, dim, descending)
     }
+
+    /// Bitwise AND operation for Int Tensors
+    fn bitwise_and(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise AND operation for Int Tensors with a scalar
+    fn bitwise_and_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
+
+    /// Bitwise OR operation for Int Tensors
+    fn bitwise_or(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise OR operation for Int Tensors with a scalar
+    fn bitwise_or_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
+
+    /// Bitwise XOR operation for Int Tensors
+    fn bitwise_xor(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise XOR operation for Int Tensors with a scalar
+    fn bitwise_xor_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
+
+    /// Bitwise NOT operation for Int Tensors
+    fn bitwise_not(tensor: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise left shift operation for Int Tensors
+    fn bitwise_left_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise left shift operation for Int Tensors with a scalar
+    fn bitwise_left_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
+
+    /// Bitwise right shift operation for Int Tensors
+    fn bitwise_right_shift(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B>;
+
+    /// Bitwise right shift operation for Int Tensors with a scalar
+    fn bitwise_right_shift_scalar(lhs: IntTensor<B>, rhs: IntElem<B>) -> IntTensor<B>;
 }
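A hedged usage sketch (not part of the PR): once a backend implements the new `IntTensorOps` methods, they compose on `Tensor<B, D, Int>` like any other op. `align_down_4` is a hypothetical helper, and the scalar conversion assumes Burn's usual `ElementConversion::elem` pattern:

```rust
use burn_tensor::backend::Backend;
use burn_tensor::{ElementConversion, Int, Tensor};

/// Hypothetical example: round each integer element down to a multiple
/// of four by clearing its two low bits, i.e. (x >> 2) << 2.
fn align_down_4<B: Backend>(offsets: Tensor<B, 1, Int>) -> Tensor<B, 1, Int> {
    offsets
        .bitwise_right_shift_scalar(2.elem())
        .bitwise_left_shift_scalar(2.elem())
}
```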
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index 8aa41ee24d..ee9aec9fe8 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -310,6 +310,7 @@ macro_rules! testgen_with_int_param {
         burn_tensor::testgen_sub!();
         burn_tensor::testgen_transpose!();
         burn_tensor::testgen_gather_scatter!();
+        burn_tensor::testgen_bitwise!();

         // test stats
         burn_tensor::testgen_eye!();
diff --git a/crates/burn-tensor/src/tests/ops/bitwise.rs b/crates/burn-tensor/src/tests/ops/bitwise.rs
new file mode 100644
index 0000000000..73702a716e
--- /dev/null
+++ b/crates/burn-tensor/src/tests/ops/bitwise.rs
@@ -0,0 +1,172 @@
+#[burn_tensor_testgen::testgen(bitwise)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Tensor, TensorData};
+
+    #[test]
+    fn should_apply_bitwise_and_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);
+
+        let output = tensor_1.bitwise_and(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[2, 4, 0], [9, 2, 8]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_and_1d() {
+        let tensor_1 = TestTensorInt::<1>::from([13, 7]);
+        let tensor_2 = TestTensorInt::from([11, 3]);
+
+        let output = tensor_1.bitwise_and(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([9, 3]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_and_scalar_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let scalar = 5;
+
+        let output = tensor_1.bitwise_and_scalar(scalar);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[1, 4, 5], [1, 1, 0]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_not_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+
+        let output = tensor_1.bitwise_not();
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[-4, -5, -6], [-10, -4, -9]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_or_scalar_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let scalar = 5;
+
+        let output = tensor_1.bitwise_or_scalar(scalar);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[7, 5, 5], [13, 7, 13]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_or_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);
+
+        let output = tensor_1.bitwise_or(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[7, 7, 13], [9, 11, 15]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_or_1d() {
+        let tensor_1 = TestTensorInt::<1>::from([13, 7]);
+        let tensor_2 = TestTensorInt::from([11, 3]);
+
+        let output = tensor_1.bitwise_or(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([15, 7]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_xor_scalar_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let scalar = 5;
+
+        let output = tensor_1.bitwise_xor_scalar(scalar);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[6, 1, 0], [12, 6, 13]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_xor_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let tensor_2 = TestTensorInt::from([[6, 7, 8], [9, 10, 15]]);
+
+        let output = tensor_1.bitwise_xor(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[5, 3, 13], [0, 9, 7]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_xor_1d() {
+        let tensor_1 = TestTensorInt::<1>::from([13, 7]);
+        let tensor_2 = TestTensorInt::from([11, 3]);
+
+        let output = tensor_1.bitwise_xor(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([6, 4]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_left_shift_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]);
+
+        let output = tensor_1.bitwise_left_shift(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[6, 16, 40], [144, 96, 512]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_left_shift_scalar_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let scalar = 2;
+
+        let output = tensor_1.bitwise_left_shift_scalar(scalar);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[12, 16, 20], [36, 12, 32]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_right_shift_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let tensor_2 = TestTensorInt::from([[1, 2, 3], [4, 5, 6]]);
+
+        let output = tensor_1.bitwise_right_shift(tensor_2);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[1, 1, 0], [0, 0, 0]]), false);
+    }
+
+    #[test]
+    fn should_apply_bitwise_right_shift_scalar_2d() {
+        let tensor_1 = TestTensorInt::<2>::from([[3, 4, 5], [9, 3, 8]]);
+        let scalar = 2;
+
+        let output = tensor_1.bitwise_right_shift_scalar(scalar);
+
+        output
+            .into_data()
+            .assert_eq(&TensorData::from([[0, 1, 1], [2, 0, 2]]), false);
+    }
+}
diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs
index b1096e0216..32bdd0f4ba 100644
--- a/crates/burn-tensor/src/tests/ops/mod.rs
+++ b/crates/burn-tensor/src/tests/ops/mod.rs
@@ -7,6 +7,7 @@ mod arange;
 mod arange_step;
 mod arg;
 mod argwhere_nonzero;
+mod bitwise;
 mod bool;
 mod cartesian_grid;
 mod cast;
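One non-obvious expectation in the tests above: `should_apply_bitwise_not_2d` asserts negative outputs for positive inputs. That follows from two's-complement arithmetic, where `!x == -x - 1`. A quick standalone sanity check of that identity:

```rust
fn main() {
    // Two's complement identity behind the expected [[-4, -5, -6], [-10, -4, -9]].
    for x in [3i32, 4, 5, 9, 8] {
        assert_eq!(!x, -x - 1);
    }
    assert_eq!([3i32, 4, 5].map(|x| !x), [-4, -5, -6]);
}
```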