Implementing matrix matrix products for Hessians and Jacobians (#8)
* 5x speedup observed for rank-100 Hessian mat-mat products vs. looping with unstack and stack

* updating the hessian blocking PR

* updating matmat pr

* clashing names with keras datasets

* extensive update adding derivative loss and refactoring the data object to work as a dictionary-based iterator

* updating hessianlearn model unit test

* hessian blocking implemented for GANs

* in the middle of updating many things in model, changing test to validation and making the validation frequency a model setting

* test to validation now

* updating some logging aspects of model

* updating model

* updating lrsfn to log eigenvalues, no longer logging rq stds by default

* more adjustments to logging

* remove the print

* updating unit test

* updating readme

* checking markdown

* updating readme

* updating readme again

* updating readme

* updating readme

* updating readme

* updating readme

* updating readme

* updating tutorial and readme before merge
tomoleary authored Mar 16, 2021
1 parent 2e78913 commit f0cdbaf
Showing 23 changed files with 975 additions and 681 deletions.
115 changes: 111 additions & 4 deletions README.md
@@ -1,7 +1,6 @@
<!-- # hessianlearn -->


Hessian based stochastic optimization in tensorflow and keras

___ ___ ___ ___ ___ ___
/__/\ / /\ / /\ / /\ ___ / /\ /__/\
@@ -31,8 +30,116 @@



https://arxiv.org/abs/1905.06738
https://arxiv.org/abs/2002.02881


[![Build Status](https://travis-ci.com/tomoleary/hessianlearn.svg?branch=master)](https://travis-ci.com/tomoleary/hessianlearn)
[![License](https://img.shields.io/github/license/tomoleary/hessianlearn)](./LICENSE.md)
[![Top language](https://img.shields.io/github/languages/top/tomoleary/hessianlearn)](https://www.python.org)
![Code size](https://img.shields.io/github/languages/code-size/tomoleary/hessianlearn)
[![Issues](https://img.shields.io/github/issues/tomoleary/hessianlearn)](https://github.com/tomoleary/hessianlearn/issues)
[![Latest commit](https://img.shields.io/github/last-commit/tomoleary/hessianlearn)](https://github.com/tomoleary/hessianlearn/commits/master)

# Hessian-based stochastic optimization in TensorFlow and keras

This code implements Hessian-based stochastic optimization in TensorFlow and keras by exposing the matrix-free Hessian to users. It is meant to allow rapid prototyping of Hessian-based algorithms via the matrix-free Hessian action, which lets users inspect Hessian-based information for stochastic nonconvex (neural network training) optimization problems.

The Hessian action is exposed via matrix-vector products:
<p align="center">
<img src="https://latex.codecogs.com/gif.latex?H\widehat{w}=\frac{d}{dw}(g^T\widehat{w})" />
</p>

and matrix-matrix products:
<p align="center">
<img src="https://latex.codecogs.com/gif.latex?H\widehat{W}=\frac{d}{dw}(g^T\widehat{W})" />
</p>
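
As a rough illustration (not the hessianlearn API), the sketch below shows how such matrix-free Hessian products can be formed in graph-mode TensorFlow by differentiating the inner product of the gradient with a direction placeholder. The names `w`, `loss`, `w_hat`, and `W_hat` are illustrative only, and hessianlearn itself forms the matrix-matrix product more efficiently than the per-column loop shown here.

```python
import numpy as np
import tensorflow as tf

# Illustrative sketch only (not the hessianlearn API): form Hessian-vector and
# Hessian-matrix products by differentiating g^T w_hat with respect to w.
tf.compat.v1.disable_eager_execution()

w = tf.Variable(np.random.randn(5).astype(np.float32))
loss = tf.reduce_sum(w**4)                                   # stand-in loss
g = tf.gradients(loss, w)[0]                                  # gradient g

w_hat = tf.compat.v1.placeholder(tf.float32, shape=(5,))      # direction vector
Hw = tf.gradients(tf.reduce_sum(g * w_hat), w)[0]             # H w_hat

W_hat = tf.compat.v1.placeholder(tf.float32, shape=(5, 3))    # block of directions
# Naive per-column loop for H W_hat; hessianlearn batches this for speed.
HW = tf.stack([tf.gradients(tf.reduce_sum(g * W_hat[:, i]), w)[0]
               for i in range(3)], axis=1)

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    print(sess.run(Hw, feed_dict={w_hat: np.ones(5, np.float32)}))
```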

## Compatibility

The code is compatible with TensorFlow v1 and v2, but certain v2 features (such as eager execution) are disabled. This is because the Hessian matrix products in hessianlearn are implemented using `placeholders`, which are deprecated in v2. For this reason hessianlearn cannot work with data generators or other features that require eager execution. If any compatibility issues are found, please open an [issue](https://github.com/tomoleary/hessianlearn/issues).
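
For example, a minimal sketch of this constraint under TensorFlow 2 (assuming hessianlearn does not already disable eager execution for you) is:

```python
import tensorflow as tf

# Placeholder-based Hessian products need graph mode, so under TensorFlow 2
# eager execution must be disabled before any graph is built.
tf.compat.v1.disable_eager_execution()
assert not tf.executing_eagerly()
```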

## Usage
Set the `HESSIANLEARN_PATH` environment variable.

Train a keras model

```python
import os, sys
import tensorflow as tf

# HESSIANLEARN_PATH should point to the hessianlearn repository root
sys.path.append(os.environ.get('HESSIANLEARN_PATH'))
from hessianlearn import *

# Define keras neural network model
neural_network = tf.keras.models.Model(...)
```

hessianlearn implements various training [`problem`](https://github.com/tomoleary/hessianlearn/blob/master/hessianlearn/problem/problem.py) constructs (regression, classification, autoencoders, variational autoencoders, generative adversarial networks). Instantiate a `problem`, a `data` object (which takes a dictionary keyed on the corresponding `placeholders` in `problem`), and a `regularization`:

```python
# Instantiate the problem (this handles the loss function,
# construction of hessian and gradient etc.)
problem = RegressionProblem(neural_network,dtype = tf.float32)
# Instantiate the data object, this handles the train / validation split
# as well as iterating during training
# (x_data / y_data denote the training inputs and targets in this sketch)
data = Data({problem.x: x_data, problem.y_true: y_data}, train_batch_size,
            validation_data_size = validation_data_size)
# Instantiate the regularization: L2Regularization is Tikhonov,
# gamma = 0 is no regularization
regularization = L2Regularization(problem,gamma = 0)
```

Pass these objects to the `HessianlearnModel`, which handles the training:

```python
HLModel = HessianlearnModel(problem,regularization,data)
HLModel.fit()
```


## Examples

[Tutorial 0: MNIST Autoencoder](https://github.com/tomoleary/hessianlearn/blob/mat_mats/tutorial/Tutorial%200%20MNIST%20Autoencoder.ipynb)


# References

These publications motivate and use the hessianlearn library for stochastic nonconvex optimization.

- \[1\] O'Leary-Roseberry, T., Alger, N., Ghattas O.,
[**Inexact Newton Methods for Stochastic Nonconvex Optimization with Applications to Neural Network Training**](https://arxiv.org/abs/1905.06738).
arXiv:1905.06738.
([Download](https://arxiv.org/pdf/1905.06738.pdf))<details><summary>BibTeX</summary><pre>
@article{o2019inexact,
title={Inexact Newton methods for stochastic nonconvex optimization with applications to neural network training},
author={O'Leary-Roseberry, Thomas and Alger, Nick and Ghattas, Omar},
journal={arXiv preprint arXiv:1905.06738},
year={2019}
}</pre></details>

- \[2\] O'Leary-Roseberry, T., Alger, N., Ghattas O.,
[**Low Rank Saddle Free Newton**](https://arxiv.org/abs/2002.02881).
arXiv:2002.02881.
([Download](https://arxiv.org/pdf/2002.02881.pdf))<details><summary>BibTeX</summary><pre>
@article{o2020low,
title={Low Rank Saddle Free Newton: Algorithm and Analysis},
author={O'Leary-Roseberry, Thomas and Alger, Nick and Ghattas, Omar},
journal={arXiv preprint arXiv:2002.02881},
year={2020}
}</pre></details>


- \[3\] O'Leary-Roseberry, T., Villa, U., Chen P., Ghattas O.,
[**Derivative-Informed Projected Neural Networks for High-Dimensional Parametric Maps Governed by PDEs**](https://arxiv.org/abs/2011.15110).
arXiv:2011.15110.
([Download](https://arxiv.org/pdf/2011.15110.pdf))<details><summary>BibTeX</summary><pre>
@article{o2020derivative,
title={Derivative-Informed Projected Neural Networks for High-Dimensional Parametric Maps Governed by PDEs},
author={O'Leary-Roseberry, Thomas and Villa, Umberto and Chen, Peng and Ghattas, Omar},
journal={arXiv preprint arXiv:2011.15110},
year={2020}
}</pre></details>




2 changes: 1 addition & 1 deletion hessianlearn/algorithms/adam.py
@@ -31,7 +31,7 @@ def ParametersAdam(parameters = {}):
parameters['alpha'] = [1e-3, "Initial steplength, or learning rate"]
parameters['beta_1'] = [0.9, "Exponential decay rate for first moment"]
parameters['beta_2'] = [0.999, "Exponential decay rate for second moment"]
parameters['epsilon'] = [1e-8, "epsilon for denominator involving square root"]
parameters['epsilon'] = [1e-7, "epsilon for denominator involving square root"]

parameters['rel_tolerance'] = [1e-3, "Relative convergence when sqrt(g,g)/sqrt(g_0,g_0) <= rel_tolerance"]
parameters['abs_tolerance'] = [1e-4,"Absolute converge when sqrt(g,g) <= abs_tolerance"]
14 changes: 7 additions & 7 deletions hessianlearn/algorithms/cgSolver.py
@@ -86,7 +86,7 @@ def __init__(self,problem,regularization,sess = None,Aop = None,preconditioner =
self.x = x
self.parameters = parameters
if Aop is None:
self.Aop = self.problem.H_action + self.regularization.H_action
self.Aop = self.problem.Hdw + self.regularization.Hdw
else:
# be careful to note what the operator requires be passed into feed_dict
self.Aop = Aop
@@ -178,7 +178,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
self.reason_id = 0
x = np.zeros_like(b)

feed_dict[self.problem.w_hat] = x
feed_dict[self.problem.dw] = x
Ax_0 = self.sess.run(self.Aop,feed_dict = feed_dict)
# Calculate initial residual r = Ax_0 -b
r = b - Ax_0
@@ -207,7 +207,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
print( "Converged in ", self.iter, " iterations with final norm ", self.final_norm)
return x, False
# Check if the direction is negative before taking a step.
feed_dict[self.problem.w_hat] = p
feed_dict[self.problem.dw] = p
Ap = self.sess.run(self.Aop,feed_dict = feed_dict)
pAp = np.dot(p,Ap)
negative_direction = (pAp <= 0.0)
@@ -265,7 +265,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
beta = rz / rz_0
p = z + beta*p
# Check if the direction is negative, and prepare for next iteration.
feed_dict[self.problem.w_hat] = p
feed_dict[self.problem.dw] = p
Ap = self.sess.run(self.Aop,feed_dict = feed_dict)
pAp = np.dot(p,Ap)
negative_direction = (pAp <= 0.0)
@@ -315,7 +315,7 @@ def __init__(self,problem,regularization,sess = None,Aop = None,preconditioner =
self.regularization = regularization
self.parameters = parameters
if Aop is None:
self.Aop = self.problem.H_action + self.regularization.H_action
self.Aop = self.problem.Hdw + self.regularization.Hdw
else:
# be careful to note what the operator requires be passed into feed_dict
self.Aop = Aop
@@ -346,7 +346,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
self.reason_id = 0
x = np.zeros_like(b)

feed_dict[self.problem.w_hat] = x
feed_dict[self.problem.dw] = x
Ax_0 = self.sess.run(self.Aop,feed_dict = feed_dict)
# Calculate initial residual r = Ax_0 -b
r = b - Ax_0
@@ -359,7 +359,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
from scipy.sparse.linalg import LinearOperator

def Ap(p):
feed_dict[self.problem.w_hat] = p
feed_dict[self.problem.dw] = p
return self.sess.run(self.Aop,feed_dict = feed_dict)

n = self.problem.dimension
6 changes: 3 additions & 3 deletions hessianlearn/algorithms/gmresSolver.py
@@ -78,7 +78,7 @@ def __init__(self,problem,regularization,sess = None,preconditioner = None,\
self.parameters = parameters


self.Aop = self.problem.H_action + self.regularization.H_action
self.Aop = self.problem.Hdw + self.regularization.Hdw

# # Define preconditioner
# if preconditioner is None:
@@ -104,7 +104,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
self.reason_id = 0
x = np.zeros_like(b)

feed_dict[self.problem.w_hat] = x
feed_dict[self.problem.dw] = x
Ax_0 = self.sess.run(self.Aop,feed_dict = feed_dict)
# Calculate initial residual r = Ax_0 -b
r = b - Ax_0
@@ -117,7 +117,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
from scipy.sparse.linalg import LinearOperator

def Ap(p):
feed_dict[self.problem.w_hat] = p
feed_dict[self.problem.dw] = p
return self.sess.run(self.Aop,feed_dict = feed_dict)

n = self.problem.dimension
16 changes: 8 additions & 8 deletions hessianlearn/algorithms/inexactNewtonCG.py
@@ -109,21 +109,21 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None):
if hessian_feed_dict is None:
hessian_feed_dict = feed_dict

self.gradient = self.sess.run(self.grad,feed_dict = feed_dict)
gradient = self.sess.run(self.grad,feed_dict = feed_dict)



if self.parameters['globalization'] is None:
self.alpha = self.parameters['alpha']
p,on_boundary = self.cg_solver.solve(-self.gradient,hessian_feed_dict)
p,on_boundary = self.cg_solver.solve(-gradient,hessian_feed_dict)
self._sweeps += [1,2*self.cg_solver.iter]
self.p = p
update = self.alpha*p
self.sess.run(self.problem._update_ops,feed_dict = {self.problem._update_placeholder:update})

if self.parameters['globalization'] == 'line_search':
w_dir,on_boundary = self.cg_solver.solve(-self.gradient,hessian_feed_dict)
w_dir_inner_g = np.inner(w_dir,self.gradient)
w_dir,on_boundary = self.cg_solver.solve(-gradient,hessian_feed_dict)
w_dir_inner_g = np.inner(w_dir,gradient)
initial_cost = self.sess.run(self.problem.loss,feed_dict = feed_dict)
cost_at_candidate = lambda p : self._loss_at_candidate(p,feed_dict = feed_dict)
self.alpha, line_search, line_search_iter = ArmijoLineSearch(w_dir,w_dir_inner_g,\
@@ -138,14 +138,14 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None):
self.initialize_trust_region()
# Set trust region radius
self.cg_solver.set_trust_region_radius(self.trust_region.radius)
p,on_boundary = self.cg_solver.solve(-self.gradient,feed_dict)
p,on_boundary = self.cg_solver.solve(-gradient,feed_dict)
self._sweeps += [1,2*self.cg_solver.iter]
self.p = p
# Solve for candidate step
p, on_boundary = self.cg_solver.solve(-self.gradient,hessian_feed_dict)
pg = np.dot(p,self.gradient)
p, on_boundary = self.cg_solver.solve(-gradient,hessian_feed_dict)
pg = np.dot(p,gradient)
# Calculate predicted reduction
feed_dict[self.cg_solver.problem.w_hat] = p
feed_dict[self.cg_solver.problem.dw] = p
Hp = self.sess.run(self.cg_solver.Aop,feed_dict)
pHp = np.dot(p,Hp)
predicted_reduction = -pg-0.5*pHp
11 changes: 5 additions & 6 deletions hessianlearn/algorithms/lowRankSaddleFreeNewton.py
@@ -43,9 +43,9 @@ def ParametersLowRankSaddleFreeNewton(parameters = {}):
# Hessian approximation parameters
parameters['range_finding'] = [None,"Range finding, if None then r = hessian_low_rank\
Choose from None, 'arf', 'naarf', 'vn'"]
parameters['range_rel_error_tolerance'] = [100, "Error tolerance for error estimator in adaptive range finding"]
parameters['range_rel_error_tolerance'] = [1000, "Error tolerance for error estimator in adaptive range finding"]
parameters['range_abs_error_tolerance'] = [100, "Error tolerance for error estimator in adaptive range finding"]
parameters['range_block_size'] = [10, "Block size used in range finder"]
parameters['range_block_size'] = [5, "Block size used in range finder"]
parameters['rq_samples_for_naarf'] = [100, "Number of partitions for RQ variance evaluation"]
parameters['hessian_low_rank'] = [20, "Fixed rank for randomized eigenvalue decomposition"]
# Variance Nystrom Parameters
@@ -59,7 +59,7 @@ def ParametersLowRankSaddleFreeNewton(parameters = {}):
parameters['max_backtracking_iter'] = [5, 'Max backtracking iterations for armijo line search']

parameters['verbose'] = [False, "Printing"]
parameters['record_last_rq_std'] = [True, "Record the last eigenvector RQ variance"]
parameters['record_last_rq_std'] = [False, "Record the last eigenvector RQ variance"]

return ParameterList(parameters)

@@ -144,7 +144,6 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None,rq_estimator_dict =
elif self.parameters['range_finding'] == 'naarf':
norm_g = np.linalg.norm(gradient)
tolerance = self.parameters['range_rel_error_tolerance']*norm_g
noise_tolerance = 0.01*tolerance
if rq_estimator_dict is None:
rq_estimator_dict_list = self.problem._partition_dictionaries(feed_dict,self.parameters['rq_samples_for_naarf'])
elif type(rq_estimator_dict) == list:
@@ -153,7 +152,7 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None,rq_estimator_dict =
rq_estimator_dict_list = self.problem._partition_dictionaries(rq_estimator_dict,self.parameters['rq_samples_for_naarf'])
else:
raise
Q = noise_aware_adaptive_range_finder(self.H,hessian_feed_dict,rq_estimator_dict_list,block_size = self.parameters['range_block_size'],noise_tolerance = noise_tolerance,epsilon = tolerance)
Q = noise_aware_adaptive_range_finder(self.H,hessian_feed_dict,rq_estimator_dict_list,block_size = self.parameters['range_block_size'],epsilon = tolerance)
self._rank = Q.shape[1]
H = lambda x: self.H(x,hessian_feed_dict,verbose = self.parameters['verbose'])
Lmbda,U = eigensolver_from_range(H,Q)
@@ -183,8 +182,8 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None,rq_estimator_dict =
n = self.problem.dimension
self._rank = self.parameters['hessian_low_rank']
Lmbda,U = randomized_eigensolver(H, n, self._rank,verbose=False)
self.lambdas = Lmbda

self.eigenvalues = Lmbda
# Log the variance of the last eigenvector
if self.parameters['record_last_rq_std'] :
try:
6 changes: 3 additions & 3 deletions hessianlearn/algorithms/minresSolver.py
@@ -86,7 +86,7 @@ def __init__(self,problem,regularization,sess = None,preconditioner = None,\
self.parameters = parameters


self.Aop = self.problem.H_action + self.regularization.H_action
self.Aop = self.problem.Hdw + self.regularization.Hdw

# # Define preconditioner
# if preconditioner is None:
@@ -111,7 +111,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
self.reason_id = 0
x = np.zeros_like(b)

feed_dict[self.problem.w_hat] = x
feed_dict[self.problem.dw] = x
Ax_0 = self.sess.run(self.Aop,feed_dict = feed_dict)
# Calculate initial residual r = Ax_0 -b
r = b - Ax_0
@@ -124,7 +124,7 @@ def solve(self,b,feed_dict = None,x_0 = None):
from scipy.sparse.linalg import LinearOperator

def Ap(p):
feed_dict[self.problem.w_hat] = p
feed_dict[self.problem.dw] = p
return self.sess.run(self.Aop,feed_dict = feed_dict)

n = self.problem.dimension
10 changes: 5 additions & 5 deletions hessianlearn/algorithms/randomizedEigensolver.py
@@ -42,7 +42,7 @@ def randomized_eigensolver(Aop, n, k, p = None,seed = 0,verbose = False):
-----------
Aop : {Callable} n x n
Hermitian matrix operator whose eigenvalues need to be estimated
y = Aop(w_hat) is the action of A in the direction w_hat
y = Aop(dw) is the action of A in the direction dw
n : int,
number of row/columns of the operator A
@@ -69,7 +69,7 @@ def randomized_eigensolver(Aop, n, k, p = None,seed = 0,verbose = False):
>>> import numpy as np
>>> n = 100
>>> A = np.diag(0.95**np.arange(n))
>>> Aop = lambda w_hat: np.dot(A,w_hat)
>>> Aop = lambda dw: np.dot(A,dw)
>>> k = 10
>>> p = 5
>>> lmbda, U = randomized_eigensolver(Aop, n, k, p)
@@ -125,7 +125,7 @@ def eigensolver_from_range(Aop, Q,verbose = False):
-----------
Aop : {Callable} n x n
Hermitian matrix operator whose eigenvalues need to be estimated
y = Aop(w_hat) is the action of A in the direction w_hat
y = Aop(dw) is the action of A in the direction dw
Q : Array n x r
@@ -169,7 +169,7 @@ def randomized_double_pass_eigensolver(Aop, Y, k):
-----------
Aop : {Callable} n x n
Hermitian matrix operator whose eigenvalues need to be estimated
y = Aop(w_hat) is the action of A in the direction w_hat
y = Aop(dw) is the action of A in the direction dw
Y = Aop(Omega) : precomputed action of Aop on Omega, a m x n Array of (presumably) sampled Gaussian or l-percent sparse random vectors (row)
k : int,
number of eigenvalues/vectors to be estimated, 0 < k < m
Expand All @@ -190,7 +190,7 @@ def randomized_double_pass_eigensolver(Aop, Y, k):
>>> import numpy as np
>>> n = 100
>>> A = np.diag(0.95**np.arange(n))
>>> Aop = lambda w_hat: np.dot(A,w_hat)
>>> Aop = lambda dw: np.dot(A,dw)
>>> k = 10
>>> p = 5
>>> Omega = np.random.randn(n, k+p)