
Commit e823338

Add Clone trait to the OptimizerAdaptor and Clone implementations to the optimizers (#1770)
1 parent f8a1356 commit e823338


10 files changed, +14 -1 lines changed

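The practical effect of the change is that a configured optimizer value, and the OptimizerAdaptor wrapping it, can be cloned and reused, for example across independent training runs. Below is a minimal, self-contained sketch of that usage pattern; SgdLike and train_run are hypothetical stand-ins for illustration only, not part of Burn's API.

// Illustrative sketch, not Burn's API: a hypothetical optimizer type that,
// like Burn's optimizers after this commit, derives Clone so one configured
// instance can be reused for several independent training runs.
#[derive(Clone)]
struct SgdLike {
    learning_rate: f64,
}

fn train_run(mut opt: SgdLike, steps: u32) -> f64 {
    // Stand-in training loop minimizing f(x) = x^2; mutates only its own clone.
    let mut param = 1.0_f64;
    for _ in 0..steps {
        let grad = 2.0 * param;
        param -= opt.learning_rate * grad;
        opt.learning_rate *= 0.99; // toy schedule, local to this clone
    }
    param
}

fn main() {
    let base = SgdLike { learning_rate: 0.1 };
    // Each run receives its own copy; the base configuration stays untouched.
    let short = train_run(base.clone(), 10);
    let long = train_run(base.clone(), 50);
    println!("short: {short:.4}, long: {long:.4}, base lr: {}", base.learning_rate);
}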

crates/burn-core/src/grad_clipping/base.rs (+1)

@@ -30,6 +30,7 @@ impl GradientClippingConfig {
 /// Gradient Clipping provides a way to mitigate exploding gradients
 /// by clipping every component of the gradient by value or by norm during
 /// backpropagation.
+#[derive(Clone)]
 pub enum GradientClipping {
     /// Clip the gradient by value.
     Value(f32),

crates/burn-core/src/optim/adagrad.rs (+2)

@@ -26,6 +26,7 @@ pub struct AdaGradConfig {
 }

 /// AdaGrad optimizer
+#[derive(Clone)]
 pub struct AdaGrad<B: Backend> {
     lr_decay: LrDecay,
     weight_decay: Option<WeightDecay<B>>,
@@ -105,6 +106,7 @@ pub struct LrDecayState<B: Backend, const D: usize> {
     sum: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct LrDecay {
     lr_decay: f64,
     epsilon: f32,

crates/burn-core/src/optim/adam.rs (+2)

@@ -31,6 +31,7 @@ pub struct AdamConfig {
 }

 /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
+#[derive(Clone)]
 pub struct Adam<B: Backend> {
     momentum: AdaptiveMomentum,
     weight_decay: Option<WeightDecay<B>>,
@@ -113,6 +114,7 @@ pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct AdaptiveMomentum {
     beta_1: f32,
     beta_2: f32,

crates/burn-core/src/optim/adamw.rs (+2)

@@ -30,6 +30,7 @@ pub struct AdamWConfig {
 }

 /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
+#[derive(Clone)]
 pub struct AdamW<B: Backend> {
     momentum: AdaptiveMomentumW,
     weight_decay: f32,
@@ -112,6 +113,7 @@ pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }

+#[derive(Clone)]
 struct AdaptiveMomentumW {
     beta_1: f32,
     beta_2: f32,

crates/burn-core/src/optim/decay.rs (+1)

@@ -20,6 +20,7 @@ pub struct WeightDecayState<B: Backend, const D: usize> {
 }

 /// Weight decay implementation that transforms gradients.
+#[derive(Clone)]
 pub struct WeightDecay<B: Backend> {
     penalty: B::FloatElem,
 }

crates/burn-core/src/optim/momentum.rs (+1)

@@ -27,6 +27,7 @@ pub struct MomentumState<B: Backend, const D: usize> {
 }

 /// Momemtum implementation that transforms gradients.
+#[derive(Clone)]
 pub struct Momentum<B: Backend> {
     momentum: B::FloatElem,
     dampening: f64,

crates/burn-core/src/optim/rmsprop.rs (+2)

@@ -64,6 +64,7 @@ impl RmsPropConfig {

 /// Optimizer that implements stochastic gradient descent with momentum.
 /// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
+#[derive(Clone)]
 pub struct RmsProp<B: Backend> {
     alpha: f32,
     // epsilon: f32,
@@ -251,6 +252,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {

 /// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
 /// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
+#[derive(Clone)]
 pub struct RmsPropMomentum {
     momentum: f32,
     epsilon: f32,

crates/burn-core/src/optim/sgd.rs (+1)

@@ -25,6 +25,7 @@ pub struct SgdConfig {
 /// Optimizer that implements stochastic gradient descent with momentum.
 ///
 /// The optimizer can be configured with [SgdConfig](SgdConfig).
+#[derive(Clone)]
 pub struct Sgd<B: Backend> {
     momentum: Option<Momentum<B>>,
     weight_decay: Option<WeightDecay<B>>,

crates/burn-core/src/optim/simple/adaptor.rs (+1)

@@ -11,6 +11,7 @@ use hashbrown::HashMap;

 /// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
 /// an [optimizer](Optimizer).
+#[derive(Clone)]
 pub struct OptimizerAdaptor<O, M, B>
 where
     O: SimpleOptimizer<B::InnerBackend>,

crates/burn-core/src/optim/simple/base.rs (+1 -1)

@@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, Tensor};
 ///
 /// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
 /// module parameter structure, handle tracked and untracked tensors, and the likes.
-pub trait SimpleOptimizer<B>: Send + Sync
+pub trait SimpleOptimizer<B>: Send + Sync + Clone
 where
     B: Backend,
 {
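This trait change is what makes the derive on OptimizerAdaptor usable for every optimizer: deriving Clone on a generic wrapper only yields a usable Clone implementation when the wrapped type is itself Clone, and the new supertrait bound guarantees that for any SimpleOptimizer. A minimal sketch of the pattern follows, using hypothetical names (SimpleOpt, Adaptor, PlainSgd) rather than Burn's real definitions.

// Illustrative sketch of the supertrait pattern; names are hypothetical.
trait SimpleOpt: Clone {
    fn step(&self, grad: f64, param: f64) -> f64;
}

// Like OptimizerAdaptor in this commit, the wrapper derives Clone.
#[derive(Clone)]
struct Adaptor<O: SimpleOpt> {
    inner: O,
}

#[derive(Clone)]
struct PlainSgd {
    lr: f64,
}

impl SimpleOpt for PlainSgd {
    fn step(&self, grad: f64, param: f64) -> f64 {
        param - self.lr * grad
    }
}

fn clone_for_new_run<O: SimpleOpt>(adaptor: &Adaptor<O>) -> Adaptor<O> {
    // This compiles only because O: SimpleOpt implies O: Clone,
    // which is exactly what the added `+ Clone` supertrait bound provides.
    adaptor.clone()
}

fn main() {
    let adaptor = Adaptor { inner: PlainSgd { lr: 0.01 } };
    let copy = clone_for_new_run(&adaptor);
    println!("updated param: {}", copy.inner.step(2.0, 1.0));
}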
