Skip to content

Commit

Permalink
Detach the tensors on batch-norm eval. (#1702)
Browse files Browse the repository at this point in the history
* Detach the tensors on batch-norm eval.

* Fix pyo3 bindings.

* Black tweak.

* Formatting.

* Also update the pyo3-onnx formatting.

* Apply black.
  • Loading branch information
LaurentMazare authored Feb 13, 2024
1 parent 13c6722 commit ad73e93
Show file tree
Hide file tree
Showing 14 changed files with 117 additions and 27 deletions.
2 changes: 1 addition & 1 deletion candle-core/src/backprop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ impl Tensor {
// the backprop graph of the backprop itself. This would be an issue for second order
// derivatives but these are out of scope at the moment.
let do_not_detach = CANDLE_GRAD_DO_NOT_DETACH.with(|b| *b);
let grad = if do_not_detach { grad } else { grad.detach()? };
let grad = if do_not_detach { grad } else { grad.detach() };
if let Some(op) = node.op() {
match op {
Op::Binary(lhs, rhs, BinaryOp::Add) => {
Expand Down
6 changes: 3 additions & 3 deletions candle-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1882,9 +1882,9 @@ impl Tensor {
/// this new node. The storage of this tensor is shared with the initial tensor.
///
/// If the tensor is already detached from the computation graph, the same tensor is returned.
pub fn detach(&self) -> Result<Tensor> {
pub fn detach(&self) -> Tensor {
if self.op.is_none() && !self.is_variable {
Ok(self.clone())
self.clone()
} else {
let tensor_ = Tensor_ {
id: TensorId::new(),
Expand All @@ -1895,7 +1895,7 @@ impl Tensor {
dtype: self.dtype,
device: self.device.clone(),
};
Ok(Tensor(Arc::new(tensor_)))
Tensor(Arc::new(tensor_))
}
}

Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/variable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ impl Var {
Ok(Self(inner))
}

pub fn as_detached_tensor(&self) -> Tensor {
self.0.detach()
}

pub fn as_tensor(&self) -> &Tensor {
&self.0
}
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/reinforcement-learning/ddpg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,7 @@ impl DDPG<'_> {
pub fn actions(&mut self, state: &Tensor) -> Result<f32> {
let actions = self
.actor
.forward(&state.detach()?.unsqueeze(0)?)?
.forward(&state.detach().unsqueeze(0)?)?
.squeeze(0)?;
let actions = if self.train {
(actions + self.ou_noise.sample()?)?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn run() -> Result<()> {
loop {
let action = {
let action_probs: Vec<f32> =
softmax(&model.forward(&state.detach()?.unsqueeze(0)?)?, 1)?
softmax(&model.forward(&state.detach().unsqueeze(0)?)?, 1)?
.squeeze(0)?
.to_vec1()?;
weighted_sample(action_probs, &mut rng)? as i64
Expand Down Expand Up @@ -109,7 +109,7 @@ pub fn run() -> Result<()> {

let rewards = Tensor::from_vec(accumulate_rewards(&steps), batch_size, &Device::Cpu)?
.to_dtype(DType::F32)?
.detach()?;
.detach();

let actions_mask = {
let actions: Vec<i64> = steps.iter().map(|s| s.action).collect();
Expand All @@ -126,12 +126,12 @@ pub fn run() -> Result<()> {
.unwrap()
})
.collect();
Tensor::stack(&actions_mask, 0)?.detach()?
Tensor::stack(&actions_mask, 0)?.detach()
};

let states = {
let states: Vec<Tensor> = steps.into_iter().map(|s| s.state).collect();
Tensor::stack(&states, 0)?.detach()?
Tensor::stack(&states, 0)?.detach()
};

let log_probs = actions_mask
Expand Down
14 changes: 12 additions & 2 deletions candle-nn/src/batch_norm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,19 @@ impl BatchNorm {
let target_shape = target_shape.as_slice();

let x = x
.broadcast_sub(&self.running_mean.as_tensor().reshape(target_shape)?)?
.broadcast_sub(
&self
.running_mean
.as_detached_tensor()
.reshape(target_shape)?,
)?
.broadcast_div(
&(self.running_var.as_tensor().reshape(target_shape)? + self.eps)?.sqrt()?,
&(self
.running_var
.as_detached_tensor()
.reshape(target_shape)?
+ self.eps)?
.sqrt()?,
)?;

match &self.weight_and_bias {
Expand Down
Loading

0 comments on commit ad73e93

Please sign in to comment.