Skip to content

Commit

Permalink
Add new FromData operation description (#2735)
Browse files Browse the repository at this point in the history
* Add new FromData operation description

* Only hash tensor desc
  • Loading branch information
laggui authored Jan 24, 2025
1 parent 7ddb5af commit f978e7b
Show file tree
Hide file tree
Showing 10 changed files with 286 additions and 85 deletions.
69 changes: 53 additions & 16 deletions crates/burn-fusion/src/ops/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use burn_tensor::{
ops::{binary_ops_shape, FloatTensor, IntTensor},
repr::{FromDataOperationDescription, TensorDescription},
DType, Element, TensorData,
};
use std::marker::PhantomData;
Expand All @@ -24,32 +25,68 @@ use burn_tensor::{

impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
#[derive(new)]
struct EmptyOps<B: FusionBackend> {
desc: TensorDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device);
handles.register_bool_tensor::<B>(&self.desc.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let tensor = B::bool_empty(shape.clone(), device);
let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool);

client.register_tensor(
B::bool_tensor_handle(tensor),
shape.dims,
StreamId::current(),
DType::Bool,
)
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())),
EmptyOps::<B>::new(desc, device.clone()),
);

out
}

async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
tensor.bool_into_data::<B>().await
}

fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::bool_from_data(self.desc.data, &self.device);
handles.register_bool_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let tensor = B::bool_from_data(data, device);
let shape = burn_tensor::TensorMetadata::shape(&tensor);

client.register_tensor(
B::bool_tensor_handle(tensor),
shape.dims,
StreamId::current(),
DType::Bool,
)
let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool);

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};

client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
);

out
}

fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
Expand Down
69 changes: 52 additions & 17 deletions crates/burn-fusion/src/ops/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range};

impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::float_from_data(self.desc.data, &self.device);
handles.register_float_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let tensor = B::float_from_data(data, device);
let shape = burn_tensor::TensorMetadata::shape(&tensor);

client.register_tensor(
B::float_tensor_handle(tensor),
shape.dims,
StreamId::current(),
B::FloatElem::dtype(),
)
let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype());

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};

client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
);

out
}

fn float_random(
Expand Down Expand Up @@ -233,16 +252,32 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}

fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
let client = get_client::<B>(&device.clone());
#[derive(new)]
struct EmptyOps<B: FusionBackend> {
desc: TensorDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::float_empty(Shape::from(&self.desc.shape), &self.device);
handles.register_float_tensor::<B>(&self.desc.id, output);
}
}

let stream = StreamId::current();
let tensor = B::float_empty(shape.clone(), device);
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype());

client.register_tensor(
B::float_tensor_handle(tensor),
shape.dims,
stream,
B::FloatElem::dtype(),
)
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Empty(desc.clone())),
EmptyOps::<B>::new(desc, device.clone()),
);

out
}

fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
Expand Down
68 changes: 51 additions & 17 deletions crates/burn-fusion/src/ops/int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,68 @@ use std::marker::PhantomData;

impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
let client = get_client::<B>(&device.clone());
let tensor = B::int_empty(shape.clone(), device);
#[derive(new)]
struct EmptyOps<B: FusionBackend> {
desc: TensorDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::int_empty(Shape::from(&self.desc.shape), &self.device);
handles.register_int_tensor::<B>(&self.desc.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(shape.dims.clone(), B::IntElem::dtype());

client.register_tensor(
B::int_tensor_handle(tensor),
shape.dims,
stream,
B::IntElem::dtype(),
)
let desc = out.to_description_out();

client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Empty(desc.clone())),
EmptyOps::<B>::new(desc, device.clone()),
);

out
}

async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
tensor.int_into_data::<B>().await
}

fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
let client = get_client::<B>(&device.clone());
let tensor = B::int_from_data(data, device);
let shape = burn_tensor::TensorMetadata::shape(&tensor);
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::int_from_data(self.desc.data, &self.device);
handles.register_int_tensor::<B>(&self.desc.out.id, output);
}
}

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype());

client.register_tensor(
B::int_tensor_handle(tensor),
shape.dims,
stream,
B::IntElem::dtype(),
)
let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};

client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())),
FromDataOps::<B>::new(desc, device.clone()),
);

out
}

fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
Expand Down
47 changes: 35 additions & 12 deletions crates/burn-fusion/src/ops/qtensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use burn_tensor::{
ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
quantization::{QuantizationParametersPrimitive, QuantizationScheme},
repr::{
DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription,
FromDataOperationDescription, HandleContainer, OperationDescription,
QuantizationParametersDescription, QuantizeOperationDescription,
},
DType, Device, Element, Shape, TensorData,
};
Expand All @@ -19,19 +20,41 @@ use crate::{

impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
#[derive(new)]
struct FromDataOps<B: FusionBackend> {
desc: FromDataOperationDescription,
device: Device<B>,
}

impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let output = B::q_from_data(self.desc.data, &self.device);
handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
}
}

match data.dtype {
DType::QFloat(_scheme) => {
let dtype = data.dtype;
let client = get_client::<B>(device);
let tensor = B::q_from_data(data, device);
let shape = burn_tensor::TensorMetadata::shape(&tensor);

client.register_tensor(
B::quantized_tensor_handle(tensor),
shape.dims,
StreamId::current(),
dtype,
)

let stream = StreamId::current();
let client = get_client::<B>(&device.clone());
let out = client.tensor_uninitialized(data.shape.clone(), dtype);

let desc = FromDataOperationDescription {
out: out.to_description_out(),
data,
};

client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::FromData(
desc.clone(),
)),
FromDataOps::<B>::new(desc, device.clone()),
);

out
}
_ => panic!(
"Invalid dtype (expected DType::QFloat, got {:?})",
Expand Down
6 changes: 6 additions & 0 deletions crates/burn-fusion/src/stream/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1210,6 +1210,12 @@ impl RelativeOps for BaseOperationDescription {
BaseOperationDescription::Empty(desc) => {
BaseOperationDescription::Empty(desc.to_relative(converter))
}
BaseOperationDescription::FromData(desc) => {
BaseOperationDescription::FromData(FromDataOperationDescription {
data: desc.data.clone(),
out: desc.out.to_relative(converter),
})
}
}
}
}
Expand Down
19 changes: 15 additions & 4 deletions crates/burn-router/src/ops/op_bool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle
use burn_tensor::repr::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
FromDataOperationDescription, OperationDescription, PermuteOperationDescription,
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
};
use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata};

Expand All @@ -31,7 +31,18 @@ impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {

fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
let client = get_client::<R>(device);
client.register_tensor_data(data.convert::<bool>())
let out = client.register_empty_tensor(data.shape.clone(), DType::Bool);

let desc = FromDataOperationDescription {
data,
out: out.to_description_out(),
};

client.register(OperationDescription::BaseBool(
BaseOperationDescription::FromData(desc),
));

out
}

fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
Expand Down
Loading

0 comments on commit f978e7b

Please sign in to comment.