From f978e7ba47660c88d1fad5742fc8d360f568b79d Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Fri, 24 Jan 2025 15:51:00 -0500
Subject: [PATCH] Add new FromData operation description (#2735)

* Add new FromData operation description

* Only hash tensor desc
---
 crates/burn-fusion/src/ops/boolean.rs    | 69 ++++++++++++++++++------
 crates/burn-fusion/src/ops/float.rs      | 69 ++++++++++++++++++------
 crates/burn-fusion/src/ops/int.rs        | 68 +++++++++++++++++------
 crates/burn-fusion/src/ops/qtensor.rs    | 47 +++++++++++-----
 crates/burn-fusion/src/stream/context.rs |  6 +++
 crates/burn-router/src/ops/op_bool.rs    | 19 +++++--
 crates/burn-router/src/ops/op_float.rs   | 27 +++++++---
 crates/burn-router/src/ops/op_int.rs     | 27 +++++++---
 crates/burn-router/src/runner.rs         | 12 +++++
 crates/burn-tensor/src/repr/operation.rs | 27 ++++++++--
 10 files changed, 286 insertions(+), 85 deletions(-)

diff --git a/crates/burn-fusion/src/ops/boolean.rs b/crates/burn-fusion/src/ops/boolean.rs
index baa5169db3..658907bf3e 100644
--- a/crates/burn-fusion/src/ops/boolean.rs
+++ b/crates/burn-fusion/src/ops/boolean.rs
@@ -1,5 +1,6 @@
 use burn_tensor::{
     ops::{binary_ops_shape, FloatTensor, IntTensor},
+    repr::{FromDataOperationDescription, TensorDescription},
     DType, Element, TensorData,
 };
 use std::marker::PhantomData;
@@ -24,15 +25,32 @@ use burn_tensor::{
 impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
+        #[derive(new)]
+        struct EmptyOps<B: FusionBackend> {
+            desc: TensorDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::bool_empty(Shape::from(&self.desc.shape), &self.device);
+                handles.register_bool_tensor::<B>(&self.desc.id, output);
+            }
+        }
+
+        let stream = StreamId::current();
         let client = get_client::<B>(&device.clone());
-        let tensor = B::bool_empty(shape.clone(), device);
+        let out = client.tensor_uninitialized(shape.dims.clone(), DType::Bool);
 
-        client.register_tensor(
-            B::bool_tensor_handle(tensor),
-            shape.dims,
-            StreamId::current(),
-            DType::Bool,
-        )
+        let desc = out.to_description_out();
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseBool(BaseOperationDescription::Empty(desc.clone())),
+            EmptyOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     async fn bool_into_data(tensor: BoolTensor<Self>) -> TensorData {
@@ -40,16 +58,35 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
     }
 
     fn bool_from_data(data: burn_tensor::TensorData, device: &Device<Self>) -> BoolTensor<Self> {
+        #[derive(new)]
+        struct FromDataOps<B: FusionBackend> {
+            desc: FromDataOperationDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::bool_from_data(self.desc.data, &self.device);
+                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = StreamId::current();
         let client = get_client::<B>(&device.clone());
-        let tensor = B::bool_from_data(data, device);
-        let shape = burn_tensor::TensorMetadata::shape(&tensor);
-
-        client.register_tensor(
-            B::bool_tensor_handle(tensor),
-            shape.dims,
-            StreamId::current(),
-            DType::Bool,
-        )
+        let out = client.tensor_uninitialized(data.shape.clone(), DType::Bool);
+
+        let desc = FromDataOperationDescription {
+            out: out.to_description_out(),
+            data,
+        };
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseBool(BaseOperationDescription::FromData(desc.clone())),
+            FromDataOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
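Note: the boolean.rs hunk above introduces the pattern the rest of this patch repeats for float, int and quantized tensors. Instead of eagerly calling the backend and registering the resulting handle, the op now reserves an uninitialized output tensor, builds an operation description, and registers that description on the current stream together with a deferred Operation; the actual backend call (B::bool_empty / B::bool_from_data) only runs when the stream executes. The sketch below is a minimal, self-contained illustration of that control flow; the Description, DeferredOp and Stream types are simplified stand-ins invented for the example, not Burn's actual API.

    // Toy model of "record a description now, run the backend call later".
    #[derive(Clone, Debug, PartialEq)]
    struct Description {
        out_id: usize,
        data: Vec<bool>,
    }

    trait DeferredOp {
        fn execute(self: Box<Self>, handles: &mut Vec<(usize, Vec<bool>)>);
    }

    struct FromDataOp {
        desc: Description,
    }

    impl DeferredOp for FromDataOp {
        fn execute(self: Box<Self>, handles: &mut Vec<(usize, Vec<bool>)>) {
            // The "backend call" happens only here, when the stream is drained.
            let FromDataOp { desc } = *self;
            handles.push((desc.out_id, desc.data));
        }
    }

    #[derive(Default)]
    struct Stream {
        queued: Vec<(Description, Box<dyn DeferredOp>)>,
    }

    impl Stream {
        // Analogue of `client.register(...)`: keep the description (what to do)
        // next to the deferred op (how to do it).
        fn register(&mut self, desc: Description, op: Box<dyn DeferredOp>) {
            self.queued.push((desc, op));
        }

        // Analogue of stream execution: run every deferred op in order.
        fn execute(&mut self) -> Vec<(usize, Vec<bool>)> {
            let mut handles = Vec::new();
            for (_desc, op) in self.queued.drain(..) {
                op.execute(&mut handles);
            }
            handles
        }
    }

    fn main() {
        let mut stream = Stream::default();
        let desc = Description { out_id: 0, data: vec![true, false] };
        stream.register(desc.clone(), Box::new(FromDataOp { desc }));
        // Nothing has touched the "backend" yet; execution happens on demand.
        assert_eq!(stream.execute(), vec![(0, vec![true, false])]);
    }

In Burn itself the queued descriptions are also what the fusion optimizer inspects, which is why each op registers both the serializable description and the boxed Operation that can fall back to executing it directly.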
diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs
index b3e2a80432..1ba2717bfb 100644
--- a/crates/burn-fusion/src/ops/float.rs
+++ b/crates/burn-fusion/src/ops/float.rs
@@ -16,16 +16,35 @@ use std::{marker::PhantomData, ops::Range};
 impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
+        #[derive(new)]
+        struct FromDataOps<B: FusionBackend> {
+            desc: FromDataOperationDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::float_from_data(self.desc.data, &self.device);
+                handles.register_float_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let stream = StreamId::current();
         let client = get_client::<B>(&device.clone());
-        let tensor = B::float_from_data(data, device);
-        let shape = burn_tensor::TensorMetadata::shape(&tensor);
-
-        client.register_tensor(
-            B::float_tensor_handle(tensor),
-            shape.dims,
-            StreamId::current(),
-            B::FloatElem::dtype(),
-        )
+        let out = client.tensor_uninitialized(data.shape.clone(), B::FloatElem::dtype());
+
+        let desc = FromDataOperationDescription {
+            out: out.to_description_out(),
+            data,
+        };
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseFloat(BaseOperationDescription::FromData(desc.clone())),
+            FromDataOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     fn float_random(
@@ -233,16 +252,32 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
     }
 
     fn float_empty(shape: Shape, device: &Device<Self>) -> FloatTensor<Self> {
-        let client = get_client::<B>(&device.clone());
+        #[derive(new)]
+        struct EmptyOps<B: FusionBackend> {
+            desc: TensorDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::float_empty(Shape::from(&self.desc.shape), &self.device);
+                handles.register_float_tensor::<B>(&self.desc.id, output);
+            }
+        }
+
         let stream = StreamId::current();
-        let tensor = B::float_empty(shape.clone(), device);
+        let client = get_client::<B>(&device.clone());
+        let out = client.tensor_uninitialized(shape.dims.clone(), B::FloatElem::dtype());
 
-        client.register_tensor(
-            B::float_tensor_handle(tensor),
-            shape.dims,
-            stream,
-            B::FloatElem::dtype(),
-        )
+        let desc = out.to_description_out();
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseFloat(BaseOperationDescription::Empty(desc.clone())),
+            EmptyOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     fn float_add(lhs: FloatTensor<Self>, rhs: FloatTensor<Self>) -> FloatTensor<Self> {
diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs
index e2115cbf6a..bf88bbd25b 100644
--- a/crates/burn-fusion/src/ops/int.rs
+++ b/crates/burn-fusion/src/ops/int.rs
@@ -15,16 +15,32 @@ use std::marker::PhantomData;
 impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
-        let client = get_client::<B>(&device.clone());
-        let tensor = B::int_empty(shape.clone(), device);
+        #[derive(new)]
+        struct EmptyOps<B: FusionBackend> {
+            desc: TensorDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for EmptyOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::int_empty(Shape::from(&self.desc.shape), &self.device);
+                handles.register_int_tensor::<B>(&self.desc.id, output);
+            }
+        }
+
         let stream = StreamId::current();
+        let client = get_client::<B>(&device.clone());
+        let out = client.tensor_uninitialized(shape.dims.clone(), B::IntElem::dtype());
 
-        client.register_tensor(
-            B::int_tensor_handle(tensor),
-            shape.dims,
-            stream,
-            B::IntElem::dtype(),
-        )
+        let desc = out.to_description_out();
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseInt(BaseOperationDescription::Empty(desc.clone())),
+            EmptyOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
@@ -32,17 +48,35 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     }
 
     fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
-        let client = get_client::<B>(&device.clone());
-        let tensor = B::int_from_data(data, device);
-        let shape = burn_tensor::TensorMetadata::shape(&tensor);
+        #[derive(new)]
+        struct FromDataOps<B: FusionBackend> {
+            desc: FromDataOperationDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::int_from_data(self.desc.data, &self.device);
+                handles.register_int_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
         let stream = StreamId::current();
+        let client = get_client::<B>(&device.clone());
+        let out = client.tensor_uninitialized(data.shape.clone(), B::IntElem::dtype());
 
-        client.register_tensor(
-            B::int_tensor_handle(tensor),
-            shape.dims,
-            stream,
-            B::IntElem::dtype(),
-        )
+        let desc = FromDataOperationDescription {
+            out: out.to_description_out(),
+            data,
+        };
+
+        client.register(
+            vec![stream],
+            OperationDescription::BaseInt(BaseOperationDescription::FromData(desc.clone())),
+            FromDataOps::<B>::new(desc, device.clone()),
+        );
+
+        out
     }
 
     fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
diff --git a/crates/burn-fusion/src/ops/qtensor.rs b/crates/burn-fusion/src/ops/qtensor.rs
index 41bc7ccde6..1449a485af 100644
--- a/crates/burn-fusion/src/ops/qtensor.rs
+++ b/crates/burn-fusion/src/ops/qtensor.rs
@@ -4,8 +4,9 @@ use burn_tensor::{
     ops::{FloatElem, FloatTensor, IntTensor, QTensorOps, QuantizedTensor},
     quantization::{QuantizationParametersPrimitive, QuantizationScheme},
     repr::{
-        DequantizeOperationDescription, FloatOperationDescription, HandleContainer,
-        OperationDescription, QuantizationParametersDescription, QuantizeOperationDescription,
+        BaseOperationDescription, DequantizeOperationDescription, FloatOperationDescription,
+        FromDataOperationDescription, HandleContainer, OperationDescription,
+        QuantizationParametersDescription, QuantizeOperationDescription,
     },
     DType, Device, Element, Shape, TensorData,
 };
@@ -19,19 +20,41 @@ use crate::{
 impl<B: FusionBackend> QTensorOps<Self> for Fusion<B> {
     fn q_from_data(data: TensorData, device: &Device<Self>) -> QuantizedTensor<Self> {
+        #[derive(new)]
+        struct FromDataOps<B: FusionBackend> {
+            desc: FromDataOperationDescription,
+            device: Device<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for FromDataOps<B> {
+            fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
+                let output = B::q_from_data(self.desc.data, &self.device);
+                handles.register_quantized_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
         match data.dtype {
             DType::QFloat(_scheme) => {
                 let dtype = data.dtype;
-                let client = get_client::<B>(device);
-                let tensor = B::q_from_data(data, device);
-                let shape = burn_tensor::TensorMetadata::shape(&tensor);
-
-                client.register_tensor(
-                    B::quantized_tensor_handle(tensor),
-                    shape.dims,
-                    StreamId::current(),
-                    dtype,
-                )
+
+                let stream = StreamId::current();
+                let client = get_client::<B>(&device.clone());
+                let out = client.tensor_uninitialized(data.shape.clone(), dtype);
+
+                let desc = FromDataOperationDescription {
+                    out: out.to_description_out(),
+                    data,
+                };
+
+                client.register(
+                    vec![stream],
+                    OperationDescription::BaseFloat(BaseOperationDescription::FromData(
+                        desc.clone(),
+                    )),
+                    FromDataOps::<B>::new(desc, device.clone()),
+                );
+
+                out
             }
             _ => panic!(
                 "Invalid dtype (expected DType::QFloat, got {:?})",
diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs
index d85e06cc09..ed1a1902f8 100644
--- a/crates/burn-fusion/src/stream/context.rs
+++ b/crates/burn-fusion/src/stream/context.rs
@@ -1210,6 +1210,12 @@ impl RelativeOps for BaseOperationDescription {
             BaseOperationDescription::Empty(desc) => {
                 BaseOperationDescription::Empty(desc.to_relative(converter))
             }
+            BaseOperationDescription::FromData(desc) => {
+                BaseOperationDescription::FromData(FromDataOperationDescription {
+                    data: desc.data.clone(),
+                    out: desc.out.to_relative(converter),
+                })
+            }
         }
     }
 }
diff --git a/crates/burn-router/src/ops/op_bool.rs b/crates/burn-router/src/ops/op_bool.rs
index 25c46ae854..5d01ddd7d3 100644
--- a/crates/burn-router/src/ops/op_bool.rs
+++ b/crates/burn-router/src/ops/op_bool.rs
@@ -4,9 +4,9 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntEle
 use burn_tensor::repr::{
     BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
     CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
-    OperationDescription, PermuteOperationDescription, RepeatDimOperationDescription,
-    ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription,
-    SwapDimsDescription, UnaryOperationDescription,
+    FromDataOperationDescription, OperationDescription, PermuteOperationDescription,
+    RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
+    SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
 };
 use burn_tensor::{DType, Device, Element, Shape, TensorData, TensorMetadata};
 
@@ -31,7 +31,18 @@ impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
 
     fn bool_from_data(data: TensorData, device: &Device<Self>) -> BoolTensor<Self> {
         let client = get_client::<R>(device);
-        client.register_tensor_data(data.convert::())
+        let out = client.register_empty_tensor(data.shape.clone(), DType::Bool);
+
+        let desc = FromDataOperationDescription {
+            data,
+            out: out.to_description_out(),
+        };
+
+        client.register(OperationDescription::BaseBool(
+            BaseOperationDescription::FromData(desc),
+        ));
+
+        out
     }
 
     fn bool_into_int(tensor: BoolTensor<Self>) -> IntTensor<Self> {
diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs
index dda01990e0..1cf211701c 100644
--- a/crates/burn-router/src/ops/op_float.rs
+++ b/crates/burn-router/src/ops/op_float.rs
@@ -8,13 +8,13 @@ use burn_tensor::ops::{
 use burn_tensor::repr::{
     BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
     ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
-    FloatOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
-    MaskWhereOperationDescription, NumericOperationDescription, OperationDescription,
-    PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
-    RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription,
-    ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
-    SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
-    UnaryOperationDescription,
+    FloatOperationDescription, FromDataOperationDescription, GatherOperationDescription,
+    MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
+    OperationDescription, PermuteOperationDescription, RandomOperationDescription,
+    ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription,
+    ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
+    SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+    SwapDimsDescription, UnaryOperationDescription,
 };
 use burn_tensor::{
     DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata,
@@ -25,7 +25,18 @@ use crate::{get_client, BackendRouter, RunnerChannel, RunnerClient};
 impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         let client = get_client::<R>(device);
-        client.register_tensor_data(data.convert::<::FloatElem>())
+        let out = client.register_empty_tensor(data.shape.clone(), FloatElem::<Self>::dtype());
+
+        let desc = FromDataOperationDescription {
+            data,
+            out: out.to_description_out(),
+        };
+
+        client.register(OperationDescription::BaseFloat(
+            BaseOperationDescription::FromData(desc),
+        ));
+
+        out
     }
 
     fn float_random(
diff --git a/crates/burn-router/src/ops/op_int.rs b/crates/burn-router/src/ops/op_int.rs
index 5d84131e32..997bf5b9e6 100644
--- a/crates/burn-router/src/ops/op_int.rs
+++ b/crates/burn-router/src/ops/op_int.rs
@@ -8,13 +8,13 @@ use burn_tensor::ops::{
 use burn_tensor::repr::{
     BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
     ClampOperationDescription, ExpandOperationDescription, FlipOperationDescription,
-    GatherOperationDescription, IntOperationDescription, MaskFillOperationDescription,
-    MaskWhereOperationDescription, NumericOperationDescription, OperationDescription,
-    PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
-    RepeatDimOperationDescription, ReshapeDescription, ScalarOperationDescription,
-    ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
-    SliceAssignOperationDescription, SliceOperationDescription, SwapDimsDescription,
-    UnaryOperationDescription,
+    FromDataOperationDescription, GatherOperationDescription, IntOperationDescription,
+    MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
+    OperationDescription, PermuteOperationDescription, RandomOperationDescription,
+    ReduceDimWithIndicesDescription, RepeatDimOperationDescription, ReshapeDescription,
+    ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
+    SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
+    SwapDimsDescription, UnaryOperationDescription,
 };
 use burn_tensor::{
     DType, Device, Distribution, Element, ElementConversion, Shape, TensorData, TensorMetadata,
@@ -45,7 +45,18 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
 
     fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
         let client = get_client::<R>(device);
-        client.register_tensor_data(data.convert::<::IntElem>())
+        let out = client.register_empty_tensor(data.shape.clone(), IntElem::<Self>::dtype());
+
+        let desc = FromDataOperationDescription {
+            data,
+            out: out.to_description_out(),
+        };
+
+        client.register(OperationDescription::BaseInt(
+            BaseOperationDescription::FromData(desc),
+        ));
+
+        out
     }
 
     fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs
index 9521cf66ae..7443be94f9 100644
--- a/crates/burn-router/src/runner.rs
+++ b/crates/burn-router/src/runner.rs
@@ -245,6 +245,10 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
                     let output = B::float_empty(shape, &self.device);
                     handles.register_float_tensor::<B>(&desc.id, output);
                 }
+                BaseOperationDescription::FromData(desc) => {
+                    let output = B::float_from_data(desc.data.clone(), &self.device);
+                    handles.register_float_tensor::<B>(&desc.out.id, output);
+                }
             },
             OperationDescription::BaseInt(op) => match op {
                 BaseOperationDescription::ToDevice(_) => unreachable!(),
@@ -316,6 +320,10 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
                     let output = B::int_empty(shape, &self.device);
                     handles.register_int_tensor::<B>(&desc.id, output);
                 }
+                BaseOperationDescription::FromData(desc) => {
+                    let output = B::int_from_data(desc.data.clone(), &self.device);
+                    handles.register_int_tensor::<B>(&desc.out.id, output);
+                }
             },
             OperationDescription::BaseBool(op) => match op {
                 BaseOperationDescription::ToDevice(_) => unreachable!(),
@@ -391,6 +399,10 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
                     let output = B::bool_empty(shape, &self.device);
                     handles.register_bool_tensor::<B>(&desc.id, output);
                 }
+                BaseOperationDescription::FromData(desc) => {
+                    let output = B::bool_from_data(desc.data.clone(), &self.device);
+                    handles.register_bool_tensor::<B>(&desc.out.id, output);
+                }
             },
             OperationDescription::NumericFloat(_dtype, op) => match op {
                 NumericOperationDescription::Add(desc) => {
diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs
index 0d7fe2493b..e4b0f3ccaf 100644
--- a/crates/burn-tensor/src/repr/operation.rs
+++ b/crates/burn-tensor/src/repr/operation.rs
@@ -6,6 +6,7 @@ use alloc::borrow::ToOwned;
 use alloc::boxed::Box;
 use alloc::{string::String, vec, vec::Vec};
 
+use crate::TensorData;
 use crate::{
     ops::{
         ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
@@ -197,6 +198,12 @@ pub enum ModuleOperationDescription {
 /// Basic operations that can be done on any tensor type.
 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 pub enum BaseOperationDescription {
+    /// Operation corresponding to:
+    ///
+    /// Float => [from_data](crate::ops::FloatTensorOps::float_from_data).
+    /// Int => [from_data](crate::ops::IntTensorOps::int_from_data).
+    /// Bool => [from_data](crate::ops::BoolTensorOps::bool_from_data).
+    FromData(FromDataOperationDescription),
     /// Operation corresponding to:
     ///
    /// Float => [to device](crate::ops::FloatTensorOps::float_to_device).
@@ -272,9 +279,9 @@ pub enum BaseOperationDescription {
     /// Operation corresponding to:
     ///
-    /// Float => [equal](crate::ops::FloatTensorOps::float_empty).
-    /// Int => [equal](crate::ops::IntTensorOps::int_empty).
-    /// Bool => [equal](crate::ops::BoolTensorOps::bool_empty).
+    /// Float => [empty](crate::ops::FloatTensorOps::float_empty).
+    /// Int => [empty](crate::ops::IntTensorOps::int_empty).
+    /// Bool => [empty](crate::ops::BoolTensorOps::bool_empty).
     Empty(TensorDescription),
 }
 
@@ -630,6 +637,13 @@ pub struct RandomOperationDescription {
     pub distribution: Distribution,
 }
 
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
+#[allow(missing_docs)]
+pub struct FromDataOperationDescription {
+    pub out: TensorDescription,
+    pub data: TensorData,
+}
+
 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 #[allow(missing_docs)]
 pub struct ReshapeDescription {
@@ -1408,6 +1422,7 @@ impl BaseOperationDescription {
             BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
             BaseOperationDescription::Cast(desc) => vec![&desc.input, &desc.out],
             BaseOperationDescription::Empty(desc) => vec![desc],
+            BaseOperationDescription::FromData(desc) => vec![&desc.out],
         }
     }
 }
@@ -1754,6 +1769,12 @@ impl ModuleOperationDescription {
     }
 }
 
+impl core::hash::Hash for FromDataOperationDescription {
+    fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
+        self.out.hash(state);
+    }
+}
+
 impl core::hash::Hash for RandomOperationDescription {
     fn hash<H: core::hash::Hasher>(&self, state: &mut H) {
         self.out.hash(state);
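Note on the second commit ("Only hash tensor desc"): FromDataOperationDescription derives Clone, Debug, PartialEq, Serialize and Deserialize, but Hash is written by hand so that only the `out` TensorDescription is hashed and the potentially large `data` payload is skipped, mirroring the existing RandomOperationDescription impl shown above. Skipping a field still satisfies the Hash/PartialEq contract, because values that compare equal also hash equal. The snippet below shows the same pattern on a made-up Record type; it is an illustration, not Burn code.

    // Hash only the cheap, identifying field and skip the heavy payload.
    use std::collections::HashSet;
    use std::hash::{Hash, Hasher};

    #[derive(Clone, PartialEq, Eq)]
    struct Record {
        id: u64,          // stands in for the output TensorDescription
        payload: Vec<u8>, // stands in for the raw TensorData bytes
    }

    impl Hash for Record {
        fn hash<H: Hasher>(&self, state: &mut H) {
            // Skipping `payload` keeps hashing O(1); equal records
            // (same id and payload) still hash to the same value.
            self.id.hash(state);
        }
    }

    fn main() {
        let record = Record { id: 1, payload: vec![0; 1 << 20] };
        let mut seen = HashSet::new();
        seen.insert(record.clone());
        assert!(seen.contains(&record));
    }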