This repository has been archived by the owner on Dec 6, 2023. It is now read-only.

bug-fix: just-in-time update in SVRG; feature: averaging for SVRG #120

Open · wants to merge 1 commit into master
7 changes: 4 additions & 3 deletions examples/plot_svrg.py
@@ -71,7 +71,8 @@ def __call__(self, clf):
print("eta =", eta)
cb = Callback(X, y)
clf = SVRGClassifier(loss="squared_hinge", alpha=1e-5, eta=eta,
n_inner=1.0, max_iter=20, random_state=0, callback=cb)
do_averaging=True, n_inner=1.0, max_iter=20,
random_state=0, callback=cb)
clf.fit(X, y)
plt.plot(cb.times, cb.obj, label="eta=" + str(eta))

@@ -85,8 +86,8 @@ def __call__(self, clf):
print("n_inner =", n_inner)
cb = Callback(X, y)
clf = SVRGClassifier(loss="squared_hinge", alpha=1e-5, eta=1e-4,
n_inner=n_inner, max_iter=20, random_state=0,
callback=cb)
do_averaging=True, n_inner=n_inner, max_iter=20,
random_state=0, callback=cb)
clf.fit(X, y)
plt.plot(cb.times, cb.obj, label="n_inner=" + str(n_inner))

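As a quick reference (not part of the diff): a minimal sketch of exercising the new `do_averaging` flag outside the benchmark script. The `lightning.classification` import path, the synthetic data, and the hyperparameters are illustrative assumptions; only the `do_averaging=True` keyword comes from this PR.

```python
# Hypothetical usage sketch: plain SVRG vs. SVRG with iterate averaging.
from sklearn.datasets import make_classification
from lightning.classification import SVRGClassifier

X, y = make_classification(n_samples=1000, n_features=50, random_state=0)

clf_last = SVRGClassifier(loss="squared_hinge", alpha=1e-5, eta=1e-4,
                          n_inner=1.0, max_iter=20, random_state=0)
clf_avg = SVRGClassifier(loss="squared_hinge", alpha=1e-5, eta=1e-4,
                         do_averaging=True, n_inner=1.0, max_iter=20,
                         random_state=0)

for name, clf in [("last iterate", clf_last), ("averaged", clf_avg)]:
    clf.fit(X, y)
    acc = (clf.predict(X) == y).mean()   # training accuracy, for illustration
    print(name, acc)
```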
21 changes: 11 additions & 10 deletions lightning/impl/svrg.py
@@ -20,10 +20,9 @@
class _BaseSVRG(object):

def _finalize_coef(self):
self.coef_ *= self.coef_scale_
self.coef_scale_.fill(1.0)
pass

def _fit(self, X, Y):
def _fit(self, X, Y, do_averaging):
n_samples, n_features = X.shape
rng = self._get_random_state()
loss = self._get_loss()
@@ -32,16 +31,16 @@ def _fit(self, X, Y):
ds = get_dataset(X, order="c")

self.coef_ = np.zeros((n_vectors, n_features), dtype=np.float64)
self.average_coef_ = np.zeros((n_vectors, n_features), dtype=np.float64)
full_grad = np.zeros_like(self.coef_)
grad = np.zeros((n_vectors, n_samples), dtype=np.float64)
self.coef_scale_ = np.ones(n_vectors, dtype=np.float64)

for i in xrange(n_vectors):
y = Y[:, i]

_svrg_fit(self, ds, y, self.coef_[i], self.coef_scale_[i:],
_svrg_fit(self, ds, y, self.coef_[i], self.average_coef_[i],
full_grad[i], grad[i], self.eta, self.alpha, loss,
self.max_iter, n_inner, self.tol, self.verbose,
self.max_iter, n_inner, self.tol, self.verbose, self.do_averaging,
self.callback, rng)

return self
@@ -59,7 +58,7 @@ class SVRGClassifier(BaseClassifier, _BaseSVRG):

def __init__(self, eta=1.0, alpha=1.0, loss="smooth_hinge", gamma=1.0,
max_iter=10, n_inner=1.0, tol=1e-3, verbose=0,
callback=None, random_state=None):
callback=None, random_state=None, do_averaging=False):
self.eta = eta
self.alpha = alpha
self.loss = loss
@@ -70,6 +69,7 @@ def __init__(self, eta=1.0, alpha=1.0, loss="smooth_hinge", gamma=1.0,
self.verbose = verbose
self.callback = callback
self.random_state = random_state
self.do_averaging = do_averaging

def _get_loss(self):
losses = {
@@ -85,7 +85,7 @@ def fit(self, X, y):
self._set_label_transformers(y)
Y = np.asfortranarray(self.label_binarizer_.transform(y),
dtype=np.float64)
return self._fit(X, Y)
return self._fit(X, Y, self.do_averaging)


class SVRGRegressor(BaseRegressor, _BaseSVRG):
@@ -100,7 +100,7 @@ class SVRGRegressor(BaseRegressor, _BaseSVRG):

def __init__(self, eta=1.0, alpha=1.0, loss="squared", gamma=1.0,
max_iter=10, n_inner=1.0, tol=1e-3, verbose=0,
callback=None, random_state=None):
callback=None, random_state=None, do_averaging=False):
self.eta = eta
self.alpha = alpha
self.loss = loss
@@ -111,6 +111,7 @@ def __init__(self, eta=1.0, alpha=1.0, loss="squared", gamma=1.0,
self.verbose = verbose
self.callback = callback
self.random_state = random_state
self.do_averaging = do_averaging

def _get_loss(self):
losses = {
@@ -122,4 +123,4 @@ def fit(self, X, y):
self.outputs_2d_ = len(y.shape) > 1
Y = y.reshape(-1, 1) if not self.outputs_2d_ else y
Y = Y.astype(np.float64)
return self._fit(X, Y)
return self._fit(X, Y, self.do_averaging)
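
For context (my reading of the change, not text from the PR): with `do_averaging=True`, the coefficients kept at the end of each outer pass are the running average of that pass's inner SVRG iterates rather than the last iterate. Writing $w_t$ for the iterate after the $t$-th inner update,

$$\bar{w} \;=\; \frac{1}{n_{\text{inner}}} \sum_{t=1}^{n_{\text{inner}}} w_t,$$

which is what `avg_coef` / `w_avg` tracks in `svrg_fast.pyx` below.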
93 changes: 60 additions & 33 deletions lightning/impl/svrg_fast.pyx
@@ -4,6 +4,7 @@
# cython: wraparound=False
#
# Author: Mathieu Blondel
# Krishna Pillutla (averaging support)
# License: BSD

import numpy as np
@@ -12,6 +13,7 @@ cimport numpy as np
ctypedef np.int64_t LONG

from libc.math cimport sqrt
from libc.math cimport pow as powc

from lightning.impl.randomkit.random_fast cimport RandomState
from lightning.impl.dataset_fast cimport RowDataset
@@ -49,7 +51,7 @@ def _svrg_fit(self,
RowDataset X,
np.ndarray[double, ndim=1]y,
np.ndarray[double, ndim=1]coef,
np.ndarray[double, ndim=1]coef_scale,
np.ndarray[double, ndim=1]avg_coef,
np.ndarray[double, ndim=1]full_grad,
np.ndarray[double, ndim=1]grad,
double eta,
@@ -59,6 +61,7 @@
int n_inner,
double tol,
int verbose,
int do_averaging,
callback,
RandomState rng):

@@ -71,23 +74,28 @@
cdef double violation, violation_init, violation_ratio
cdef double eta_avg = eta / n_samples
cdef double eta_alpha = eta * alpha
cdef double one_minus_eta_alpha = 1 - eta_alpha
cdef double one_over_eta_alpha = 1 / eta_alpha if eta_alpha > 0 else 0.0
cdef int has_callback = callback is not None
cdef double w_scale = 1.0
cdef double avg_a = 0.0, avg_b = 1.0
cdef double correction, correction_avg
cdef double mu

# Data pointers.
cdef double* data
cdef int* indices
cdef int n_nz

# Buffers and pointers.
cdef np.ndarray[int, ndim=1]last = np.zeros(n_features, dtype=np.int32)
cdef double* w = <double*>coef.data
cdef double* w_scale = <double*>coef_scale.data
cdef double* w_avg = <double*>avg_coef.data
cdef double* fg = <double*>full_grad.data
cdef double* g = <double*>grad.data

for it in xrange(max_iter):

# Reset full gradient
# Reset full gradient.
for j in xrange(n_features):
fg[j] = 0

@@ -98,7 +106,7 @@
X.get_row_ptr(i, &indices, &data, &n_nz)

# Make prediction.
y_pred = _pred(data, indices, n_nz, w) * w_scale[0]
y_pred = _pred(data, indices, n_nz, w) * w_scale

# A gradient is given by g[i] * X[i].
g[i] = -loss.get_update(y_pred, y[i])
@@ -107,7 +115,7 @@

# Compute optimality violation.
violation = 0
alpha_scaled = alpha * w_scale[0]
alpha_scaled = alpha * w_scale
for j in xrange(n_features):
tmp = fg[j] / n_samples + alpha_scaled * w[j]
violation += tmp * tmp
@@ -134,46 +142,65 @@
# Retrieve sample i.
X.get_row_ptr(i, &indices, &data, &n_nz)

# Add deterministic part, just in time.
if t > 0:
for jj in xrange(n_nz):
j = indices[jj]
w[j] -= eta_avg / w_scale[0] * (t - last[j]) * fg[j]
last[j] = t

# Make prediction.
y_pred = _pred(data, indices, n_nz, w) * w_scale[0]
# Make prediction, accounting for correction due to
# dense (deterministic) part of update.
y_pred = _pred(data, indices, n_nz, w) * w_scale
if eta_alpha > 0:
correction = (1 - powc(one_minus_eta_alpha, t)) / eta_alpha
else:
correction = t
y_pred -= _pred(data, indices, n_nz, fg) * eta_avg * correction
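# The dense part of the update is applied lazily: every inner step scales the
# true iterate by (1 - eta * alpha) and subtracts eta_avg * fg, so after t
# steps the deferred dense contribution is
#   -eta_avg * fg * (1 + (1 - eta*alpha) + ... + (1 - eta*alpha)^(t-1)),
# i.e. the geometric sum evaluated in closed form as `correction` above.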

# A gradient is given by scale * X[i].
scale = -loss.get_update(y_pred, y[i])

w_scale[0] *= (1 - eta_alpha)

# Add deterministic part.
#for j in xrange(n_features):
#w[j] -= eta_avg / w_scale * fg[j]
w_scale *= (1 - eta_alpha)

# Add stochastic part.
_add(data, indices, n_nz, eta * (g[i] - scale) / w_scale[0], w)
_add(data, indices, n_nz, eta * (g[i] - scale) / w_scale, w)

# Take care of possible underflows.
if w_scale[0] < 1e-9:
# Update average (or reset, at t = 0) of stochastic part.
if t == 0:
for j in xrange(n_features):
w[j] *= w_scale[0]
w_scale[0] = 1.0
w_avg[j] = 0.0
avg_a = w_scale
avg_b = 1.0
else:
mu = 1.0 / (t + 1)
_add(data, indices, n_nz, eta * (scale - g[i]) * avg_a / w_scale, w_avg)
avg_b /= (1.0 - mu)
avg_a += mu * avg_b * w_scale
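# The running average of the (stochastic-part) iterates is kept implicitly
# as (w_avg + avg_a * w) / avg_b, so each step only touches the non-zero
# features of X[i] instead of rewriting all of w_avg.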

# Finalize.
# Take care of possible underflows.
if w_scale < 1e-9:
for j in xrange(n_features):
w[j] *= w_scale
avg_a /= w_scale
w_scale = 1.0

# Finalize. Reconstruct w and w_avg. Add deterministic update to w and w_avg.
if eta_alpha > 0:
correction = (1.0 - powc(one_minus_eta_alpha, n_inner)) / eta_alpha
correction_avg = one_over_eta_alpha - one_minus_eta_alpha * correction / (n_inner * eta_alpha)
else:
correction = n_inner
correction_avg = (n_inner - 1.0) / 2.0
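# `correction` folds the full deferred dense term into the final iterate;
# `correction_avg` is the same geometric sum averaged over the inner
# iterates, so the dense term can likewise be folded into w_avg.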
for j in xrange(n_features):
w[j] -= eta_avg / w_scale[0] * (n_inner - last[j]) * fg[j]
last[j] = 0
w_avg[j] = (w_avg[j] + avg_a * w[j]) / avg_b
w_avg[j] -= eta_avg * fg[j] * correction_avg
w[j] *= w_scale
w[j] -= eta_avg * fg[j] * correction
w_scale = 1.0
avg_a = 0.0
avg_b = 1.0

# Update iterate, if averaging
if do_averaging:
for j in xrange(n_features):
w[j] = w_avg[j]

# Callback.
if has_callback:
ret = callback(self)
if ret is not None:
break

# Rescale coefficients.
for j in xrange(n_features):
w[j] *= w_scale[0]
w_scale[0] = 1.0
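
To make the lazy bookkeeping above easier to check, here is a dense NumPy sketch (not part of the diff) of what one outer pass computes, written without the `w_scale` / `avg_a` / `avg_b` tricks. The squared loss stands in for `loss.get_update`, and the function name `svrg_epoch_dense` is hypothetical.

```python
import numpy as np

def svrg_epoch_dense(X, y, w, eta, alpha, n_inner, rng, do_averaging=False):
    """Dense reference for one outer pass of _svrg_fit (squared loss)."""
    n_samples = X.shape[0]
    # Full gradient at the snapshot w; g[i] plays the role of grad[i] above.
    g = X.dot(w) - y                    # derivative of 0.5 * (pred - y)**2
    full_grad = X.T.dot(g)              # fg[j] = sum_i g[i] * X[i, j]

    w = w.copy()
    w_avg = np.zeros_like(w)
    for t in range(n_inner):
        i = rng.randint(n_samples)
        scale = X[i].dot(w) - y[i]      # gradient factor at the current w
        # Variance-reduced step with L2 shrinkage (1 - eta * alpha).
        w = ((1.0 - eta * alpha) * w
             - eta * (scale - g[i]) * X[i]
             - (eta / n_samples) * full_grad)
        # Running average of the inner iterates, reset at each outer pass.
        w_avg += (w - w_avg) / (t + 1.0)

    return w_avg if do_averaging else w
```

Under this reading, the Cython routine produces the same iterates but defers the dense `full_grad` and shrinkage terms, reconstructing them with the geometric-series corrections at prediction time and in the finalize step.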