Alternative Optimizers #1164

Closed · wants to merge 4 commits

10 changes: 9 additions & 1 deletion candle-core/src/backprop.rs
@@ -471,7 +471,15 @@ impl Tensor {
            Op::Unary(_, UnaryOp::Round) => {
                Err(Error::BackwardNotSupported { op: "round" })?
            }
            Op::Unary(_, UnaryOp::Gelu) => Err(Error::BackwardNotSupported { op: "gelu" })?,
            Op::Unary(arg, UnaryOp::Gelu) => {
                let sum_grad = grads.or_insert(arg)?;
                let cube = arg.powf(3.)?;
                let tanh = (0.0356774 * &cube + (0.797885 * arg)?)?.tanh()?;
                let gelu_grad = (((0.5 * &tanh)?
                    + (0.0535161 * cube + (0.398942 * arg)?)? * (1. - tanh.powf(2.)?))?
                    + 0.5)?;
                *sum_grad = sum_grad.add(&(&grad * gelu_grad)?)?
            }
            Op::Unary(_, UnaryOp::Erf) => Err(Error::BackwardNotSupported { op: "erf" })?,
            Op::Unary(_, UnaryOp::GeluErf) => {
                Err(Error::BackwardNotSupported { op: "gelu-erf" })?
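
For reference, the new `Gelu` backward arm differentiates the tanh approximation used by the forward op. Writing $t = \tanh(0.797885\,x + 0.0356774\,x^3)$, the constants in the diff are consistent with the chain rule applied to

$$\mathrm{gelu}(x) = \tfrac{1}{2}\,x\,\bigl(1 + \tanh\bigl(\sqrt{2/\pi}\,(x + 0.044715\,x^3)\bigr)\bigr),$$

$$\frac{d}{dx}\,\mathrm{gelu}(x) = \tfrac{1}{2}(1 + t) + \bigl(0.398942\,x + 0.0535161\,x^3\bigr)\bigl(1 - t^2\bigr),$$

with $\sqrt{2/\pi} \approx 0.797885$, $0.044715\sqrt{2/\pi} \approx 0.0356774$, $\tfrac{1}{2}\sqrt{2/\pi} \approx 0.398942$, and $\tfrac{3}{2}\cdot 0.044715\sqrt{2/\pi} \approx 0.0535161$. `gelu_grad` computes exactly this expression before it is multiplied by the incoming gradient.
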
13 changes: 13 additions & 0 deletions candle-core/tests/grad_tests.rs
@@ -192,6 +192,19 @@ fn unary_grad(device: &Device) -> Result<()> {
        test_utils::to_vec1_round(grad_x, 2)?,
        [0.01, 0.42, 0.0, 0.98],
    );

    // Results checked against PyTorch nn.GELU(approximate = 'tanh').
    let y = x.gelu()?;
    let grads = y.backward()?;
    let grad_x = grads.get(&x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [2.9964, 0.8412, 3.9999, 0.0839]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [1.0116, 1.0830, 1.0003, 0.6188],
    );
    Ok(())
}

2 changes: 1 addition & 1 deletion candle-nn/src/lib.rs
@@ -28,7 +28,7 @@ pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use optim::{AdamW, NesterovSGD, Optimizer, ParamsAdamW, ParamsNesterovSGD, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
pub use sequential::{seq, Sequential};
pub use var_builder::VarBuilder;
22 changes: 22 additions & 0 deletions candle-nn/src/loss.rs
@@ -48,3 +48,25 @@ pub fn cross_entropy(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
pub fn mse(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
(inp - target)?.sqr()?.mean_all()
}

/// The binary cross-entropy with logit loss.
///
/// Arguments
///
/// * [inp]: The input tensor of dimensions `N, C` where `N` is the batch size and `C` the number
/// of categories. This is expected to contain raw logits.
/// * [target]: The ground truth labels as a tensor of u32 of dimension `N, C` where `N` is the batch size and `C` the number
/// of categories.
///
/// The resulting tensor is a scalar containing the average value over the batch.
pub fn binary_cross_entropy_with_logit(inp: &Tensor, target: &Tensor) -> Result<Tensor> {
    let inp = crate::ops::sigmoid(inp)?;

    let left_side = target * inp.log()?;
    let right_side = (target.affine(-1., 1.))? * inp.affine(-1., 1.)?.log()?;

    let loss = left_side? + right_side?;
    let loss = loss?.neg()?.mean_all()?;

    Ok(loss)
}
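
For reference, this helper computes the standard binary cross-entropy on sigmoid outputs, averaged over all $N \cdot C$ elements; the `affine(-1., 1.)` calls implement `1 - target` and `1 - sigmoid(x)`:

$$\mathcal{L}(x, y) = -\frac{1}{N\,C}\sum_{n,c}\Bigl[\,y_{nc}\,\log\sigma(x_{nc}) + (1 - y_{nc})\,\log\bigl(1 - \sigma(x_{nc})\bigr)\Bigr].$$

Note that it applies `sigmoid` and then `log` directly rather than the log-sum-exp formulation PyTorch uses in `binary_cross_entropy_with_logits`, so it may be less numerically stable for large-magnitude logits, even though it matches PyTorch's value on the test case below.
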
78 changes: 77 additions & 1 deletion candle-nn/src/optim.rs
@@ -1,5 +1,7 @@
//! Various optimization algorithms.
use candle::{Result, Tensor, Var};
use std::collections::HashMap;

use candle::{Result, Tensor, TensorId, Var};

/// The interface optimizers should implement.
pub trait Optimizer: Sized {
@@ -79,6 +81,80 @@ impl SGD {
}
}

/// Optimizer for Stochastic Gradient Descent with Nesterov momentum.
///
/// Similar to PyTorch SGD but without weight decay.
pub struct NesterovSGD {
    vars: Vec<Var>,
    params: ParamsNesterovSGD,
    prev_step: HashMap<TensorId, Tensor>,
}

#[derive(Clone, Debug)]
pub struct ParamsNesterovSGD {
    pub learning_rate: f64,
    pub momentum: f64,
}

impl Optimizer for NesterovSGD {
    type Config = ParamsNesterovSGD;

    fn new(vars: Vec<Var>, params: ParamsNesterovSGD) -> Result<Self> {
        let vars = vars
            .into_iter()
            .filter(|var| var.dtype().is_float())
            .collect();
        Ok(Self {
            vars,
            params,
            prev_step: HashMap::new(),
        })
    }

    fn learning_rate(&self) -> f64 {
        self.params.learning_rate
    }

    fn step(&mut self, grads: &candle::backprop::GradStore) -> Result<()> {
        for var in self.vars.iter() {
            if let Some(grad) = grads.get(var) {
                let gt;
                let bt;
                if let Some(prev_step) = self.prev_step.get(&var.id()) {
                    // Momentum buffer exists: b_t = mu * b_{t-1} + g_t (dampening tau = 0).
                    bt = ((prev_step * self.params.momentum)? + grad)?;
                    gt = (grad + (self.params.momentum * &bt)?)?;
                } else {
                    // First step: the momentum buffer is initialized with the raw gradient.
                    bt = (1. * grad)?;
                    gt = (grad + (self.params.momentum * &bt)?)?;
                }
                self.prev_step.insert(var.id(), bt);
                var.set(&var.sub(&(gt * self.params.learning_rate)?)?)?;
            }
        }
        Ok(())
    }

    fn set_learning_rate(&mut self, lr: f64) {
        self.params.learning_rate = lr
    }
}

impl NesterovSGD {
    pub fn into_inner(self) -> Vec<Var> {
        self.vars
    }

    pub fn push(&mut self, var: &Var) {
        self.vars.push(var.clone())
    }
}

#[derive(Clone, Debug)]
pub struct ParamsAdamW {
pub lr: f64,
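
For reference, the `step` implementation of `NesterovSGD` above follows the PyTorch SGD update with `nesterov=True`, zero dampening, and no weight decay. For a parameter $\theta$ with gradient $g_t$, learning rate $\gamma$, and momentum $\mu$:

$$b_t = \mu\, b_{t-1} + g_t, \qquad \tilde{g}_t = g_t + \mu\, b_t, \qquad \theta_t = \theta_{t-1} - \gamma\, \tilde{g}_t,$$

with $b_1 = g_1$ on the first step, which is what the empty-`prev_step` branch implements.
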
47 changes: 47 additions & 0 deletions candle-nn/tests/loss.rs
@@ -39,3 +39,50 @@ fn nll_and_cross_entropy() -> Result<()> {
assert_eq!(to_vec0_round(&loss, 4)?, 1.1312);
Ok(())
}

/* Equivalent python code:
import torch
import torch.nn.functional as F

inp = torch.Tensor([[ 2.3611, -0.8813, -0.5006, -0.2178],
[ 0.0419, 0.0763, -1.0457, -1.6692],
[-1.0494, 0.8111, 1.5723, 1.2315],
[ 1.3081, 0.6641, 1.1802, -0.2547],
[ 0.5292, 0.7636, 0.3692, -0.8318]])

target = torch.Tensor([[0., 1., 0., 0.],
[0., 1., 0., 0.],
[0., 0., 0., 1.],
[1., 0., 0., 0.],
[0., 0., 1., 0.]])

print(F.binary_cross_entropy_with_logits(inp, target))
*/
#[test]
fn binary_cross_entropy_with_logit() -> Result<()> {
    let cpu = Device::Cpu;

    let inp = [
        [2.3611f32, -0.8813, -0.5006, -0.2178],
        [0.0419, 0.0763, -1.0457, -1.6692],
        [-1.0494, 0.8111, 1.5723, 1.2315],
        [1.3081, 0.6641, 1.1802, -0.2547],
        [0.5292, 0.7636, 0.3692, -0.8318],
    ];

    let target = [
        [0.0f32, 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
    ];

    let inp = Tensor::new(&inp, &cpu)?;
    let target = Tensor::new(&target, &cpu)?;

    let loss = candle_nn::loss::binary_cross_entropy_with_logit(&inp, &target)?;

    assert_eq!(to_vec0_round(&loss, 4)?, 0.8224);
    Ok(())
}
56 changes: 55 additions & 1 deletion candle-nn/tests/optim.rs
@@ -8,7 +8,9 @@ use candle::test_utils::{to_vec0_round, to_vec2_round};

use anyhow::Result;
use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};
use candle_nn::{
AdamW, Linear, Module, NesterovSGD, Optimizer, ParamsAdamW, ParamsNesterovSGD, SGD,
};

#[test]
fn sgd_optim() -> Result<()> {
@@ -71,6 +73,58 @@ fn sgd_linear_regression() -> Result<()> {
Ok(())
}

/* The results of this test have been checked against the following PyTorch code.
import torch
from torch import optim

w_gen = torch.tensor([[3., 1.]])
b_gen = torch.tensor([-2.])

sample_xs = torch.tensor([[2., 1.], [7., 4.], [-4., 12.], [5., 8.]])
sample_ys = sample_xs.matmul(w_gen.t()) + b_gen

m = torch.nn.Linear(2, 1)
with torch.no_grad():
    m.weight.zero_()
    m.bias.zero_()
optimizer = optim.SGD(m.parameters(), lr=0.004, momentum=0.1, nesterov=True)
for _step in range(100):
    optimizer.zero_grad()
    ys = m(sample_xs)
    loss = ((ys - sample_ys)**2).sum()
    loss.backward()
    optimizer.step()
print(m.weight)
print(m.bias)
*/
#[test]
fn nesterov_sgd_linear_regression() -> Result<()> {
    // Generate some linear data, y = 3.x1 + x2 - 2.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    let params = ParamsNesterovSGD {
        learning_rate: 0.004,
        momentum: 0.1,
    };
    // Now use backprop to run a linear regression between samples and get the coefficients back.
    let w = Var::new(&[[0f32, 0.]], &Device::Cpu)?;
    let b = Var::new(0f32, &Device::Cpu)?;
    let mut n_sgd = NesterovSGD::new(vec![w.clone(), b.clone()], params)?;
    let lin = Linear::new(w.as_tensor().clone(), Some(b.as_tensor().clone()));
    for _step in 0..100 {
        let ys = lin.forward(&sample_xs)?;
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        n_sgd.backward_step(&loss)?;
    }
    assert_eq!(w.to_vec2::<f32>()?, &[[1.07495, -9.90416]]);
    assert_eq!(b.to_scalar::<f32>()?, -1.8961483);
    Ok(())
}

/* The following test returns the same values as the PyTorch code below.
import torch
from torch import optim