Skip to content

Commit

Permalink
Add bitwise int ops to book + remove dead code
Browse files Browse the repository at this point in the history
  • Loading branch information
laggui committed Jan 24, 2025
1 parent e6853fd commit d0a4658
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 92 deletions.
145 changes: 78 additions & 67 deletions burn-book/src/building-blocks/tensor.md
Original file line number Diff line number Diff line change
Expand Up @@ -131,47 +131,47 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t

Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------------------------------------------- |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dtype()` | `tensor.dtype` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.unsqueeze_dims(dims)` | N/A |
| Burn | PyTorch Equivalent |
| ------------------------------------------- | ------------------------------------------------------------------------- |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.split(split_size, dim)` | `tensor.split(split_size, dim)` |
| `tensor.split_with_sizes(split_sizes, dim)` | `tensor.split([split_sizes], dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dtype()` | `tensor.dtype` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.transpose()` | `tensor.T` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| `tensor.unsqueeze_dims(dims)` | N/A |

### Numeric Operations

Expand Down Expand Up @@ -258,32 +258,32 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.

Those operations are only available for `Float` tensors.

| Burn API | PyTorch Equivalent |
| --------------------------------------------- | ---------------------------------- |
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.exp()` | `tensor.exp()` |
| `tensor.floor()` | `tensor.floor()` |
| `tensor.from_floats(floats, device)` | N/A |
| `tensor.from_full_precision(tensor)` | N/A |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.random(shape, distribution, device)` | N/A |
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.round()` | `tensor.round()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |
| Burn API | PyTorch Equivalent |
| -------------------------------------------- | ---------------------------------- |
| `tensor.cast(dtype)` | `tensor.to(dtype)` |
| `tensor.ceil()` | `tensor.ceil()` |
| `tensor.cos()` | `tensor.cos()` |
| `tensor.erf()` | `tensor.erf()` |
| `tensor.exp()` | `tensor.exp()` |
| `tensor.floor()` | `tensor.floor()` |
| `tensor.from_floats(floats, device)` | N/A |
| `tensor.from_full_precision(tensor)` | N/A |
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
| `tensor.log()` | `tensor.log()` |
| `tensor.log1p()` | `tensor.log1p()` |
| `tensor.matmul(other)` | `tensor.matmul(other)` |
| `tensor.random(shape, distribution, device)` | N/A |
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
| `tensor.recip()` | `tensor.reciprocal()` |
| `tensor.round()` | `tensor.round()` |
| `tensor.sin()` | `tensor.sin()` |
| `tensor.sqrt()` | `tensor.sqrt()` |
| `tensor.tanh()` | `tensor.tanh()` |
| `tensor.to_full_precision()` | `tensor.to(torch.float)` |
| `tensor.var(dim)` | `tensor.var(dim)` |
| `tensor.var_bias(dim)` | N/A |
| `tensor.var_mean(dim)` | N/A |
| `tensor.var_mean_bias(dim)` | N/A |

### Int Operations

Expand All @@ -293,6 +293,17 @@ Those operations are only available for `Int` tensors.
| ------------------------------------------------ | ------------------------------------------------------- |
| `Tensor::arange(5..10, device)` | `tensor.arange(start=5, end=10, device=device)` |
| `Tensor::arange_step(5..10, 2, device)` | `tensor.arange(start=5, end=10, step=2, device=device)` |
| `tensor.bitwise_and(other)` | `torch.bitwise_and(tensor, other)` |
| `tensor.bitwise_and_scalar(scalar)` | `torch.bitwise_and(tensor, scalar)` |
| `tensor.bitwise_not()` | `torch.bitwise_not(tensor)` |
| `tensor.bitwise_left_shift(other)` | `torch.bitwise_left_shift(tensor, other)` |
| `tensor.bitwise_left_shift_scalar(scalar)` | `torch.bitwise_left_shift(tensor, scalar)` |
| `tensor.bitwise_right_shift(other)` | `torch.bitwise_right_shift(tensor, other)` |
| `tensor.bitwise_right_shift_scalar(scalar)` | `torch.bitwise_right_shift(tensor, scalar)` |
| `tensor.bitwise_or(other)` | `torch.bitwise_or(tensor, other)` |
| `tensor.bitwise_or_scalar(scalar)` | `torch.bitwise_or(tensor, scalar)` |
| `tensor.bitwise_xor(other)` | `torch.bitwise_xor(tensor, other)` |
| `tensor.bitwise_xor_scalar(scalar)` | `torch.bitwise_xor(tensor, scalar)` |
| `tensor.float()` | `tensor.to(torch.float)` |
| `tensor.from_ints(ints)` | N/A |
| `tensor.int_random(shape, distribution, device)` | N/A |
Expand Down
4 changes: 0 additions & 4 deletions crates/burn-jit/src/kernel/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ impl<F: Float> BinaryOpFamily for PowOp<F> {
type BinaryOp<C: Numeric> = Self;
}

// impl BinaryOpFamily for BitwiseAndOp {
// type BinaryOp<C: Numeric> = Self;
// }

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
Expand Down
21 changes: 0 additions & 21 deletions crates/burn-tensor/src/tensor/api/numeric.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3452,27 +3452,6 @@ where
dim: usize,
descending: bool,
) -> <Int as TensorKind<B>>::Primitive;

// /// Applies logical `and` operation element-wise between two tensors.
// fn bitwise_and(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `and` operation element-wise between a tensor and a scalar.
// fn bitwise_and_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between two tensors.
// fn bitwise_or(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `or` operation element-wise between a tensor and a scalar.
// fn bitwise_or_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between two tensors.
// fn bitwise_xor(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive;

// /// Applies logical `xor` operation element-wise between a tensor and a scalar.
// fn bitwise_xor_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive;

// /// Applies logical `not` operation element-wise on a tensor.
// fn bitwise_not(tensor: Self::Primitive) -> Self::Primitive;
}

impl<B: Backend> Numeric<B> for Int {
Expand Down

0 comments on commit d0a4658

Please sign in to comment.