Skip to content

Commit ff844b1

Browse files
Fix candle backend sync (#1579)
* Fix candle backend sync * tch mps sync * clippy --------- Co-authored-by: louisfd <[email protected]>
1 parent fb1da53 commit ff844b1

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

crates/burn-candle/src/backend.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::marker::PhantomData;
22

3-
use burn_tensor::backend::Backend;
3+
use burn_tensor::{backend::Backend, Device};
44
use candle_core::DeviceLocation;
55

66
use crate::{
@@ -91,4 +91,21 @@ impl<F: FloatCandleElement, I: IntCandleElement> Backend for Candle<F, I> {
9191
// TODO submit an issue at Candle
9292
panic!("Manual seed not supported by Candle. ")
9393
}
94+
95+
fn sync(device: &Device<Self>) {
96+
let device: candle_core::Device = (*device).into();
97+
98+
match device {
99+
candle_core::Device::Cpu => (),
100+
candle_core::Device::Cuda(device) => {
101+
#[cfg(feature = "cuda")]
102+
device.synchronize().unwrap();
103+
}
104+
candle_core::Device::Metal(device) => {
105+
// For some reason, device.wait_until_completed() does not seem to work,
106+
// and neither does writing and reading a value with into_data
107+
panic!("Device synchronization unavailable with Metal device on Candle backend")
108+
}
109+
}
110+
}
94111
}

crates/burn-tch/src/backend.rs

+15-4
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use crate::PrecisionBridge;
33
use super::element::TchElement;
44
use super::TchTensor;
55
use burn_tensor::backend::Backend;
6+
use burn_tensor::ops::IntTensorOps;
7+
use burn_tensor::{Int, Tensor};
68

79
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
810
/// The device struct when using the `tch` backend.
@@ -102,10 +104,19 @@ impl<E: TchElement> Backend for LibTorch<E> {
102104
}
103105

104106
fn sync(device: &Self::Device) {
105-
if let LibTorchDevice::Cuda(index) = device {
106-
tch::Cuda::synchronize(*index as i64);
107-
} else if let LibTorchDevice::Mps = device {
108-
panic!("Can't sync MPS device")
107+
match device {
108+
LibTorchDevice::Cpu => (),
109+
LibTorchDevice::Cuda(index) => {
110+
tch::Cuda::synchronize(*index as i64);
111+
}
112+
_ => {
113+
// When there is no explicit way to synchronize, we write and read one value to sync
114+
Tensor::<Self, 1, Int>::from_primitive(<Self as IntTensorOps<Self>>::int_zeros(
115+
[1].into(),
116+
device,
117+
))
118+
.into_data();
119+
}
109120
}
110121
}
111122
}

0 commit comments

Comments
 (0)