
Commit 40a26bd

Feat/backend bridge (#1529)
Parent: 28233d9

File tree: 36 files changed, +483 -220 lines

crates/burn-autodiff/src/backend.rs (+2 -2)

@@ -3,6 +3,7 @@ use crate::{
     grads::Gradients,
     graph::backward::backward,
     tensor::AutodiffTensor,
+    AutodiffBridge,
 };
 use burn_tensor::backend::{AutodiffBackend, Backend};
 use core::marker::PhantomData;
@@ -20,8 +21,7 @@ pub struct Autodiff<B, C = NoCheckpointing> {
 impl<B: Backend, C: CheckpointStrategy> Backend for Autodiff<B, C> {
     type Device = B::Device;
 
-    type FullPrecisionElem = B::FullPrecisionElem;
-    type FullPrecisionBackend = Autodiff<B::FullPrecisionBackend>;
+    type FullPrecisionBridge = AutodiffBridge<B::FullPrecisionBridge>;
 
     type FloatTensorPrimitive<const D: usize> = AutodiffTensor<B, D>;
     type FloatElem = B::FloatElem;

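The `FullPrecisionElem` / `FullPrecisionBackend` pair on the `Backend` trait is replaced by a single `FullPrecisionBridge` associated type, so the bridge rather than the backend now owns the precision-conversion logic. The sketch below shows the shape of the `BackendBridge` trait as it is used by this commit; the authoritative definition lives in `burn-tensor` and is not part of this excerpt, so extra bounds or items may differ:

```rust
// Sketch only: BackendBridge as inferred from how into_target / from_target
// are called in the diffs below. The real trait is defined in burn-tensor.
use burn_tensor::{backend::Backend, ops::FloatTensor, Device};

pub trait BackendBridge<Origin: Backend> {
    /// Backend reached after crossing the bridge (e.g. the f32 flavour of `Origin`).
    type Target: Backend;

    /// Move a float tensor from `Origin` into the target backend.
    fn into_target<const D: usize>(
        tensor: FloatTensor<Origin, D>,
        device: Option<Device<Self::Target>>,
    ) -> FloatTensor<Self::Target, D>;

    /// Move a float tensor from the target backend back into `Origin`.
    fn from_target<const D: usize>(
        tensor: FloatTensor<Self::Target, D>,
        device: Option<Device<Origin>>,
    ) -> FloatTensor<Origin, D>;
}
```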
crates/burn-autodiff/src/bridge.rs (new file, +155)

use std::marker::PhantomData;

use burn_tensor::{
    backend::{Backend, BackendBridge},
    ops::FloatTensor,
};

use crate::{
    checkpoint::{
        base::Checkpointer, retro_forward::RetroForward, state::BackwardStates,
        strategy::CheckpointStrategy,
    },
    grads::Gradients,
    ops::{unary_different_backend, Backward, Ops},
    Autodiff, NodeID,
};

/// Enable autodiff on a [backend bridge](BackendBridge).
#[derive(Debug)]
pub struct AutodiffBridge<Bridge> {
    _p: PhantomData<Bridge>,
}

impl<B, C, Bridge> BackendBridge<Autodiff<B, C>> for AutodiffBridge<Bridge>
where
    B: Backend,
    C: CheckpointStrategy,
    Bridge: BackendBridge<B> + 'static,
{
    type Target = Autodiff<Bridge::Target, C>;

    fn into_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Autodiff<B, C>, D>,
        _device: Option<burn_tensor::Device<Self::Target>>,
    ) -> burn_tensor::ops::FloatTensor<Self::Target, D> {
        #[derive(Debug)]
        struct IntoTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        #[derive(new, Debug)]
        struct RetroIntoTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroIntoTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                let tensor: FloatTensor<B, D> = states.get_state(&self.tensor_id);
                let out = Bridge::into_target(tensor, Default::default());
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<Bridge::Target, D, 1> for IntoTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                unary_different_backend::<B, Bridge::Target, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::from_target(grad, None),
                );
            }
        }

        IntoTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroIntoTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::into_target(tensor.primitive, None))
    }

    fn from_target<const D: usize>(
        tensor: burn_tensor::ops::FloatTensor<Self::Target, D>,
        _device: Option<burn_tensor::Device<Autodiff<B, C>>>,
    ) -> burn_tensor::ops::FloatTensor<Autodiff<B, C>, D> {
        #[derive(Debug)]
        struct FromTarget<B: Backend, Bridge: BackendBridge<B>> {
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        #[derive(new, Debug)]
        struct RetroFromTarget<B: Backend, Bridge: BackendBridge<B>, const D: usize> {
            tensor_id: NodeID,
            _backend: PhantomData<B>,
            _bridge: PhantomData<Bridge>,
        }

        impl<B, Bridge, const D: usize> RetroForward for RetroFromTarget<B, Bridge, D>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
                let tensor: FloatTensor<Bridge::Target, D> = states.get_state(&self.tensor_id);
                let out = Bridge::from_target(tensor, None);
                states.save(out_node, out)
            }
        }

        impl<B, Bridge, const D: usize> Backward<B, D, 1> for FromTarget<B, Bridge>
        where
            B: Backend,
            Bridge: BackendBridge<B> + 'static,
        {
            type State = ();

            fn backward(
                self,
                ops: Ops<Self::State, 1>,
                grads: &mut Gradients,
                _checkpointer: &mut Checkpointer,
            ) {
                unary_different_backend::<Bridge::Target, B, D, D, _>(
                    ops.parents,
                    ops.node,
                    grads,
                    |grad| Bridge::into_target(grad, None),
                );
            }
        }

        FromTarget::<B, Bridge> {
            _backend: PhantomData,
            _bridge: PhantomData,
        }
        .prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
        .memory_bound()
        .retro_forward(RetroFromTarget::<B, Bridge, D>::new(tensor.node.id.clone()))
        .parents([&tensor])
        .stateless(Bridge::from_target(tensor.primitive, None))
    }
}

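`AutodiffBridge` wraps an inner bridge and keeps the conversion inside the autodiff graph: the backward step of `into_target` maps the incoming gradient back with `Bridge::from_target`, and `from_target` does the opposite, so crossing the bridge never detaches a tensor. Below is a minimal backend-level sketch of a round trip, assuming the updated `Backend` trait bounds `FullPrecisionBridge: BackendBridge<Self>`, which the impls above rely on:

```rust
// Hedged sketch: convert a float tensor into the full-precision backend and back,
// using only the bridge API shown above. Assumes B::FullPrecisionBridge
// implements BackendBridge<B> as the new Backend trait appears to require.
use burn_tensor::{
    backend::{Backend, BackendBridge},
    ops::FloatTensor,
};

fn full_precision_roundtrip<B: Backend, const D: usize>(
    tensor: FloatTensor<B, D>,
) -> FloatTensor<B, D> {
    // Cross into the full-precision backend; None leaves device selection to the bridge.
    let full = <B::FullPrecisionBridge as BackendBridge<B>>::into_target(tensor, None);
    // ...and come straight back. With an autodiff backend, gradients keep flowing.
    <B::FullPrecisionBridge as BackendBridge<B>>::from_target(full, None)
}
```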
crates/burn-autodiff/src/lib.rs (+3)

@@ -26,7 +26,10 @@ pub(crate) mod tensor;
 pub(crate) mod utils;
 
 mod backend;
+mod bridge;
+
 pub use backend::*;
+pub use bridge::*;
 
 #[cfg(feature = "export_tests")]
 mod tests;

crates/burn-autodiff/src/ops/tensor.rs (+2 -103)

@@ -7,7 +7,7 @@ use crate::{
     },
     grads::Gradients,
     graph::{ComputingProperty, NodeID, NodeRef, Requirement, Step},
-    ops::{binary, broadcast_shape, unary, unary_different_backend, Backward, Ops, OpsKind},
+    ops::{binary, broadcast_shape, unary, Backward, Ops, OpsKind},
     retro_binary, retro_unary, retro_unary_scalar,
     tensor::AutodiffTensor,
     utils::duplicate,
@@ -16,7 +16,7 @@ use crate::{
 
 use burn_tensor::{
     backend::Backend,
-    ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, FullPrecisionBackend, IntTensor},
+    ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor},
     Data, Device, ElementConversion, Reader, Shape, Tensor,
 };
 
@@ -1621,107 +1621,6 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         }
     }
 
-    fn float_to_full_precision<const D: usize>(
-        tensor: &FloatTensor<Self, D>,
-    ) -> FloatTensor<FullPrecisionBackend<Self>, D> {
-        #[derive(Debug)]
-        struct ToFullPrecision<B: Backend> {
-            phantom: PhantomData<B>,
-        }
-
-        #[derive(new, Debug)]
-        struct RetroToFullPrecision<B: Backend, const D: usize> {
-            tensor_id: NodeID,
-            _backend: PhantomData<B>,
-        }
-
-        impl<B: Backend, const D: usize> RetroForward for RetroToFullPrecision<B, D> {
-            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
-                let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
-                let out = B::float_to_full_precision(&tensor);
-                states.save(out_node, out)
-            }
-        }
-
-        impl<B: Backend, const D: usize> Backward<B::FullPrecisionBackend, D, 1> for ToFullPrecision<B> {
-            type State = ();
-
-            fn backward(
-                self,
-                ops: Ops<Self::State, 1>,
-                grads: &mut Gradients,
-                _checkpointer: &mut Checkpointer,
-            ) {
-                unary_different_backend::<B, B::FullPrecisionBackend, D, D, _>(
-                    ops.parents,
-                    ops.node,
-                    grads,
-                    |grad| B::float_from_full_precision(grad),
-                );
-            }
-        }
-
-        let ops = ToFullPrecision::<B> {
-            phantom: PhantomData,
-        };
-        ops.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
-            .memory_bound()
-            .retro_forward(RetroToFullPrecision::<B, D>::new(tensor.node.id.clone()))
-            .parents([tensor])
-            .stateless(B::float_to_full_precision(&tensor.primitive))
-    }
-
-    fn float_from_full_precision<const D: usize>(
-        tensor: FloatTensor<FullPrecisionBackend<Self>, D>,
-    ) -> FloatTensor<Self, D> {
-        #[derive(Debug)]
-        struct FromFullPrecision<B: Backend> {
-            phantom: PhantomData<B>,
-        }
-
-        #[derive(new, Debug)]
-        struct RetroFromFullPrecision<B: Backend, const D: usize> {
-            tensor_id: NodeID,
-            _backend: PhantomData<B>,
-        }
-
-        impl<B: Backend, const D: usize> RetroForward for RetroFromFullPrecision<B, D> {
-            fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
-                let tensor = states.get_state::<<<B as Backend>::FullPrecisionBackend as Backend>::FloatTensorPrimitive<D>>(&self.tensor_id);
-                let out = B::float_from_full_precision(tensor);
-                states.save(out_node, out)
-            }
-        }
-
-        impl<B: Backend, const D: usize> Backward<B, D, 1> for FromFullPrecision<B::FullPrecisionBackend> {
-            type State = ();
-
-            fn backward(
-                self,
-                ops: Ops<Self::State, 1>,
-                grads: &mut Gradients,
-                _checkpointer: &mut Checkpointer,
-            ) {
-                unary_different_backend::<B::FullPrecisionBackend, B, D, D, _>(
-                    ops.parents,
-                    ops.node,
-                    grads,
-                    |grad| B::float_to_full_precision(&grad),
-                );
-            }
-        }
-
-        let ops = FromFullPrecision::<B::FullPrecisionBackend> {
-            phantom: PhantomData,
-        };
-
-        ops.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
-            .memory_bound()
-            .retro_forward(RetroFromFullPrecision::<B, D>::new(tensor.node.id.clone()))
-            .parents([&tensor])
-            .stateless(B::float_from_full_precision(tensor.primitive))
-    }
-
     fn float_argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<B, D> {
         B::float_argmax(tensor.primitive, dim)
     }
crates/burn-autodiff/src/tests/bridge.rs (new file, +29)

#[burn_tensor_testgen::testgen(bridge)]
mod tests {
    use super::*;
    use burn_tensor::{backend::Backend, module::embedding, Data, Distribution, Int, Tensor};

    #[test]
    fn test_full_precision() {
        let device = Default::default();
        let x1 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
            .require_grad();
        let x2 = Tensor::<TestAutodiffBackend, 2>::random([32, 32], Distribution::Default, &device)
            .require_grad();

        let x3 = x1.clone().into_full_precision();
        let x4 = x2.clone().into_full_precision();

        let x5 = x3.matmul(x4);
        let x6 = Tensor::<TestAutodiffBackend, 2>::from_full_precision(x5);
        let x7 = x6 * x1.clone() / x2.clone();

        let mut grads = x7.backward();

        let x1_grad = x1.grad(&mut grads);
        let x2_grad = x2.grad(&mut grads);

        assert!(x1_grad.is_some());
        assert!(x2_grad.is_some());
    }
}

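The test exercises the bridge through the user-facing `into_full_precision` / `from_full_precision` tensor API and checks that gradients still reach both inputs. For reference, the same pattern as a hypothetical standalone program, assuming `burn-ndarray` and its `NdArray` backend are available (any backend with a full-precision bridge would do):

```rust
// Hypothetical standalone variant of the test above, using burn-ndarray
// instead of the TestAutodiffBackend alias from the test harness.
use burn_autodiff::Autodiff;
use burn_ndarray::NdArray;
use burn_tensor::{Distribution, Tensor};

type B = Autodiff<NdArray<f32>>;

fn main() {
    let device = Default::default();
    let x = Tensor::<B, 2>::random([8, 8], Distribution::Default, &device).require_grad();

    // Round-trip through the full-precision backend; the bridge keeps the
    // conversion in the autodiff graph, so gradients still reach `x`.
    let y = Tensor::<B, 2>::from_full_precision(x.clone().into_full_precision());
    let mut grads = (y * x.clone()).backward();

    assert!(x.grad(&mut grads).is_some());
}
```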
crates/burn-autodiff/src/tests/mod.rs (+2)

@@ -8,6 +8,7 @@ mod aggregation;
 mod avgpool1d;
 mod avgpool2d;
 mod backward;
+mod bridge;
 mod broadcast;
 mod cat;
 mod checkpoint;
@@ -64,6 +65,7 @@ macro_rules! testgen_all {
         // Behavior
         burn_autodiff::testgen_ad_broadcast!();
         burn_autodiff::testgen_gradients!();
+        burn_autodiff::testgen_bridge!();
         burn_autodiff::testgen_checkpoint!();
 
         // Activation

crates/burn-candle/src/backend.rs (+2 -3)

@@ -5,7 +5,7 @@ use candle_core::DeviceLocation;
 
 use crate::{
     element::{CandleElement, FloatCandleElement, IntCandleElement},
-    CandleTensor,
+    CandleTensor, PrecisionBridge,
 };
 
 /// Tensor backend that uses the [candle](candle_core) crate for executing tensor operations.
@@ -69,8 +69,7 @@ impl Default for CandleDevice {
 impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
     type Device = CandleDevice;
 
-    type FullPrecisionBackend = Candle<Self::FullPrecisionElem, Self::IntElem>;
-    type FullPrecisionElem = f32;
+    type FullPrecisionBridge = PrecisionBridge<f32>;
 
     type FloatTensorPrimitive<const D: usize> = CandleTensor<Self::FloatElem, D>;
     type FloatElem = F;

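Put together with the autodiff changes above, the associated types compose as sketched below. This is illustrative only: `PrecisionBridge` is the candle-side bridge added elsewhere in this commit (its definition is not shown in this excerpt and is assumed to be exported at the crate root with a single element parameter), and the element choices are arbitrary:

```rust
// Illustrative composition of the bridge types from this commit (assumption:
// PrecisionBridge<E> is burn-candle's bridge type, public at the crate root).
use burn_autodiff::{Autodiff, AutodiffBridge};
use burn_candle::{Candle, PrecisionBridge};
use burn_tensor::backend::{Backend, BackendBridge};

type Base = Candle<f32, i64>; // base backend (already f32 here, so the bridge is effectively a no-op)
type Diff = Autodiff<Base>;   // autodiff wrapper around it

// Compile-time witness: Autodiff's full-precision bridge is the autodiff wrapper
// around the base backend's bridge, i.e. AutodiffBridge<PrecisionBridge<f32>>.
fn _bridge_witness(
    bridge: AutodiffBridge<PrecisionBridge<f32>>,
) -> <Diff as Backend>::FullPrecisionBridge {
    bridge
}

// Crossing that bridge lands on an autodiff-wrapped full-precision Candle backend;
// with an f16 base backend it would map to the f32 variant instead.
type _FullPrecision = <<Diff as Backend>::FullPrecisionBridge as BackendBridge<Diff>>::Target;
```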