diff --git a/.travis.yml b/.travis.yml index b97655c..d782653 100644 --- a/.travis.yml +++ b/.travis.yml @@ -17,8 +17,8 @@ language: python python: - - "3.5" - "3.6" + - "3.7" install: - sudo apt-get update - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh -O miniconda.sh; @@ -31,7 +31,7 @@ install: # # Useful for debugging any issues with conda # - conda info -a # Replace dep1 dep2 ... with your dependencies - - conda create -n hessianlearn2 python=$TRAVIS_PYTHON_VERSION tensorflow scipy + - conda create -n hessianlearn2 python=$TRAVIS_PYTHON_VERSION tensorflow=2.0.0 scipy - conda activate hessianlearn2 # # - python setup.py install script: diff --git a/README.md b/README.md index b8607e2..df8a755 100644 --- a/README.md +++ b/README.md @@ -63,12 +63,15 @@ Set `HESSIANLEARN_PATH` environmental variable Train a keras model ```python +import os,sys import tensorflow as tf sys.path.append( os.environ.get('HESSIANLEARN_PATH')) from hessianlearn import * # Define keras neural network model neural_network = tf.keras.models.Model(...) +# Define loss function and compile model +neural_network.compile(loss = ...) ``` @@ -77,7 +80,9 @@ hessianlearn implements various training [`problem`](https://github.com/tomolear ```python # Instantiate the problem (this handles the loss function, # construction of hessian and gradient etc.) -problem = RegressionProblem(neural_network,dtype = tf.float32) +# KerasModelProblem extracts loss function and metrics from +# a compiled keras model +problem = KerasModelProblem(neural_network) # Instantiate the data object, this handles the train / validation split # as well as iterating during training data = Data({problem.x:x_data,problem.y_true:y_data},train_batch_size,\ @@ -94,6 +99,40 @@ HLModel = HessianlearnModel(problem,regularization,data) HLModel.fit() ``` +### Alternative Usage (More like Keras Interface) +The example above uses the original optimizer interface in hessianlearn. To better mimic the keras interface and to allow more rapid end-user prototyping of the optimizer used to fit the data, the following interface was added in December 2021. + +```python +import os,sys +import tensorflow as tf +sys.path.append( os.environ.get('HESSIANLEARN_PATH')) +from hessianlearn import * + +# Define keras neural network model +neural_network = tf.keras.models.Model(...) +# Define loss function and compile model +neural_network.compile(loss = ...) +# Instantiate the keras model wrapper, which deals with the +# construction of the `problem`, which in turn handles the +# Hessian computational graph and variables +HLModel = KerasModelWrapper(neural_network) +# Then the end user can pass in an optimizer +# (e.g. a custom end-user optimizer) +optimizer = LowRankSaddleFreeNewton # The class constructor, not an instance +optparameters = LowRankSaddleFreeNewtonParameters() +optparameters['hessian_low_rank'] = 40 +HLModel.set_optimizer(optimizer,optimizer_parameters = optparameters) +# The data object still needs to key on the specific computational +# graph variables that the data will be fed into. +# Note that data can naturally handle multiple input and output data, +# in which case problem.x, problem.y_true are lists corresponding to +# neural_network.inputs, neural_network.outputs +problem = HLModel.problem +data = Data({problem.x:x_data,problem.y_true:y_data},train_batch_size,\ + validation_data_size = validation_data_size) +# And finally one can call fit!
+HLModel.fit(data) +``` ## Examples @@ -108,7 +147,7 @@ These publications motivate and use the hessianlearn library for stochastic nonc [**Inexact Newton Methods for Stochastic Nonconvex Optimization with Applications to Neural Network Training**](https://arxiv.org/abs/1905.06738). arXiv:1905.06738. ([Download](https://arxiv.org/pdf/1905.06738.pdf))
BibTeX
-@article{o2019inexact,
+@article{OLearyRoseberryAlgerGhattas2019,
   title={Inexact Newton methods for stochastic nonconvex optimization with applications to neural network training},
   author={O'Leary-Roseberry, Thomas and Alger, Nick and Ghattas, Omar},
   journal={arXiv preprint arXiv:1905.06738},
@@ -117,10 +156,10 @@ arXiv:1905.06738.
 }
- \[2\] O'Leary-Roseberry, T., Alger, N., Ghattas O., -[**Low Rank Saddle Free Newton**](https://arxiv.org/abs/2002.02881). +[**Low Rank Saddle Free Newton: A Scalable Method for Stochastic Nonconvex Optimization**](https://arxiv.org/abs/2002.02881). arXiv:2002.02881. ([Download](https://arxiv.org/pdf/2002.02881.pdf))
BibTeX
-@article{o2020low,
+@article{OLearyRoseberryAlgerGhattas2020,
   title={Low Rank Saddle Free Newton: Algorithm and Analysis},
   author={O'Leary-Roseberry, Thomas and Alger, Nick and Ghattas, Omar},
   journal={arXiv preprint arXiv:2002.02881},
@@ -133,11 +172,14 @@ arXiv:2002.02881.
 [**Derivative-Informed Projected Neural Networks for High-Dimensional Parametric Maps Governed by PDEs**](https://arxiv.org/abs/2011.15110).
 arXiv:2011.15110.
 ([Download](https://arxiv.org/pdf/2011.15110.pdf))
BibTeX
-@article{o2020derivative,
-  title={Derivative-Informed Projected Neural Networks for High-Dimensional Parametric Maps Governed by PDEs},
-  author={O'Leary-Roseberry, Thomas and Villa, Umberto and Chen, Peng and Ghattas, Omar},
-  journal={arXiv preprint arXiv:2011.15110},
-  year={2020}
+@article{OLearyRoseberryVillaChenEtAl2022,
+  title={Derivative-informed projected neural networks for high-dimensional parametric maps governed by {PDE}s},
+  author={O'Leary-Roseberry, Thomas and Villa, Umberto and Chen, Peng and Ghattas, Omar},
+  journal={Computer Methods in Applied Mechanics and Engineering},
+  volume={388},
+  pages={114199},
+  year={2022},
+  publisher={Elsevier}
 }
 }
diff --git a/applications/transfer_learning/imagenet_cifar100_classification_evaluate_test.py b/applications/transfer_learning/imagenet_cifar100_classification_evaluate_test.py index 8721722..64746ec 100644 --- a/applications/transfer_learning/imagenet_cifar100_classification_evaluate_test.py +++ b/applications/transfer_learning/imagenet_cifar100_classification_evaluate_test.py @@ -114,7 +114,6 @@ pretrained_resnet50 = tf.keras.applications.resnet50.ResNet50(weights = 'imagenet',include_top=False,input_tensor=input_tensor) - for layer in pretrained_resnet50.layers[:143]: layer.trainable = False diff --git a/applications/transfer_learning/imagenet_cifar10_classification_evaluate_test.py b/applications/transfer_learning/imagenet_cifar10_classification_evaluate_test.py new file mode 100644 index 0000000..e2160e9 --- /dev/null +++ b/applications/transfer_learning/imagenet_cifar10_classification_evaluate_test.py @@ -0,0 +1,180 @@ +# This file is part of the hessianlearn package +# +# hessianlearn is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or any later version. +# +# hessianlearn is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# If not, see . +# +# Author: Tom O'Leary-Roseberry +# Contact: tom.olearyroseberry@utexas.edu + + +import numpy as np +import os +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ['KMP_DUPLICATE_LIB_OK']='True' +os.environ["KMP_WARNINGS"] = "FALSE" +# os.environ['CUDA_VISIBLE_DEVICES'] = '1' +import pickle +import tensorflow as tf +import time, datetime +# if int(tf.__version__[0]) > 1: +# import tensorflow.compat.v1 as tf +# tf.disable_v2_behavior() + + +# Memory issue with GPUs +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +for device in gpu_devices: + tf.config.experimental.set_memory_growth(device, True) +# Load hessianlearn library +import sys +sys.path.append( os.environ.get('HESSIANLEARN_PATH', "../../")) +from hessianlearn import * + +# Parse run specifications +from argparse import ArgumentParser + +parser = ArgumentParser(add_help=True) +parser.add_argument("-optimizer", dest='optimizer',required=False, default = 'lrsfn', help="optimizer type",type=str) +parser.add_argument('-fixed_step',dest = 'fixed_step',\ + required= False,default = 1,help='boolean for fixed step vs globalization',type = int) +parser.add_argument('-alpha',dest = 'alpha',required = False,default = 1e-4,help= 'learning rate alpha',type=float) +parser.add_argument('-hessian_low_rank',dest = 'hessian_low_rank',required= False,default = 40,help='low rank for sfn',type = int) +parser.add_argument('-record_spectrum',dest = 'record_spectrum',\ + required= False,default = 0,help='boolean for recording spectrum',type = int) +# parser.add_argument('-weight_burn_in',dest = 'weight_burn_in',\ +# required= False,default = 0,help='',type = int) + +# parser.add_argument('-data_seed',dest = 'data_seed',\ +# required= False,default = 0,help='',type = int) + +parser.add_argument('-batch_size',dest = 'batch_size',required= False,default = 32,help='batch size',type = int) +parser.add_argument('-hess_batch_size',dest = 'hess_batch_size',required= False,default = 
8,help='hess batch size',type = int) +parser.add_argument('-keras_epochs',dest = 'keras_epochs',required= False,default = 50,help='keras_epochs',type = int) +parser.add_argument("-keras_opt", dest='keras_opt',required=False, default = 'adam', help="optimizer type for keras",type=str) +parser.add_argument('-keras_alpha',dest = 'keras_alpha',required= False,default = 1e-3,help='keras learning rate',type = float) +parser.add_argument('-max_sweeps',dest = 'max_sweeps',required= False,default = 1,help='max sweeps',type = float) +parser.add_argument('-weights_file',dest = 'weights_file',required= False,default = 'None',help='weight file pickle',type = str) + +args = parser.parse_args() + +try: + tf.set_random_seed(0) +except: + tf.random.set_seed(0) + +# GPU Environment Details +gpu_availabe = tf.test.is_gpu_available() +built_with_cuda = tf.test.is_built_with_cuda() +print(80*'#') +print(('IS GPU AVAILABLE: '+str(gpu_availabe)).center(80)) +print(('IS BUILT WITH CUDA: '+str(built_with_cuda)).center(80)) +print(80*'#') + +settings = {} +# Set run specifications +# Data specs +settings['batch_size'] = args.batch_size +settings['hess_batch_size'] = args.hess_batch_size + + +################################################################################ +# Instantiate data +(x_train, y_train), (_x_test, _y_test) = tf.keras.datasets.cifar10.load_data() + +# # Normalize the data +# x_train = x_train.astype('float32') / 255. +# x_test = x_test.astype('float32') / 255. + +x_train = tf.keras.applications.resnet50.preprocess_input(x_train) +x_test_full = tf.keras.applications.resnet50.preprocess_input(_x_test) +x_val = x_test_full[:2000] +x_test = x_test_full[2000:] + +y_train = tf.keras.utils.to_categorical(y_train) +y_test_full = tf.keras.utils.to_categorical(_y_test) +y_val = y_test_full[:2000] +y_test = y_test_full[2000:] + +################################################################################ +# Create the neural network in keras + +# tf.keras.backend.set_floatx('float64') + +resnet_input_shape = (200,200,3) +input_tensor = tf.keras.Input(shape = resnet_input_shape) + +pretrained_resnet50 = tf.keras.applications.resnet50.ResNet50(weights = 'imagenet',include_top=False,input_tensor=input_tensor) + +for layer in pretrained_resnet50.layers[:143]: + layer.trainable = False + +classifier = tf.keras.models.Sequential() +classifier.add(tf.keras.layers.Input(shape=(32,32,3))) +classifier.add(tf.keras.layers.Lambda(lambda image: tf.image.resize(image, resnet_input_shape[:2]))) +classifier.add(pretrained_resnet50) +classifier.add(tf.keras.layers.Flatten()) +classifier.add(tf.keras.layers.BatchNormalization()) +classifier.add(tf.keras.layers.Dense(64, activation='relu')) +classifier.add(tf.keras.layers.Dropout(0.5)) +classifier.add(tf.keras.layers.BatchNormalization()) +classifier.add(tf.keras.layers.Dense(10, activation='softmax')) + + +if args.keras_opt == 'adam': + optimizer = tf.keras.optimizers.Adam(learning_rate = args.keras_alpha,epsilon = 1e-8) +elif args.keras_opt == 'sgd': + optimizer = tf.keras.optimizers.SGD(learning_rate=args.keras_alpha) +else: + raise + +classifier.compile(optimizer=optimizer, + loss=tf.keras.losses.CategoricalCrossentropy(from_logits = True), + metrics=['accuracy']) + +loss_test_0, acc_test_0 = classifier.evaluate(x_test,y_test,verbose=2) +print('acc_test = ',acc_test_0) +loss_val_0, acc_val_0 = classifier.evaluate(x_val,y_val,verbose=2) +print('acc_val = ',acc_val_0) + + +if args.weights_file is not 'None': + try: + logger = open(args.weights_file, 'rb') + 
best_weights = pickle.load(logger)['best_weights'] + for layer_name,weight in best_weights.items(): + classifier.get_layer(layer_name).set_weights(weight) + except: + print('Issue loading best weights') + +loss_test_final, acc_test_final = classifier.evaluate(x_test,y_test,verbose=2) +print('acc_test final = ',acc_test_final) +loss_val_final, acc_val_final = classifier.evaluate(x_val,y_val,verbose=2) +print('acc_val final = ',acc_val_final) + +################################################################################ +# Evaluate again on all the data. +(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() + +# # Normalize the data +# x_train = x_train.astype('float32') / 255. +# x_test = x_test.astype('float32') / 255. + +x_train = tf.keras.applications.resnet50.preprocess_input(x_train) +x_test = tf.keras.applications.resnet50.preprocess_input(x_test) + +y_train = tf.keras.utils.to_categorical(y_train) +y_test = tf.keras.utils.to_categorical(y_test) + +loss_test_total, acc_test_total = classifier.evaluate(x_test,y_test,verbose=2) +print(80*'#') +print('After hessianlearn training'.center(80)) +print('acc_test_total = ',acc_test_total) diff --git a/hessianlearn/algorithms/adam.py b/hessianlearn/algorithms/adam.py index 54058f5..aaae4f6 100644 --- a/hessianlearn/algorithms/adam.py +++ b/hessianlearn/algorithms/adam.py @@ -89,14 +89,15 @@ def minimize(self,feed_dict = None): gradient = self.sess.run(self.grad,feed_dict = feed_dict) self.m = self.parameters['beta_1']*self.m + (1-self.parameters['beta_1'])*gradient - # m_hat = [m/(1 - self.parameters['beta_1']**self.iter) for m in self.m] + m_hat = self.m / (1.0 - self.parameters['beta_1']**self._iter) g_sq_vec = np.square(gradient) self.v = self.parameters['beta_2']*self.v + (1-self.parameters['beta_2'])*g_sq_vec - v_root = np.sqrt(self.v) + v_hat = self.v / (1.0 - self.parameters['beta_2']**self._iter) + v_root = np.sqrt(v_hat) - update = -alpha*self.m/(v_root +self.parameters['epsilon']) + update = -alpha*m_hat/(v_root +self.parameters['epsilon']) self.p = update self._sweeps += [1,0] self.sess.run(self.problem._update_ops,feed_dict = {self.problem._update_placeholder:update}) diff --git a/hessianlearn/algorithms/inexactNewtonCG.py b/hessianlearn/algorithms/inexactNewtonCG.py index 5c009e1..f03d95d 100644 --- a/hessianlearn/algorithms/inexactNewtonCG.py +++ b/hessianlearn/algorithms/inexactNewtonCG.py @@ -137,13 +137,12 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None): if not self.trust_region_initialized: self.initialize_trust_region() # Set trust region radius - self.cg_solver.set_trust_region_radius(self.trust_region.radius) - p,on_boundary = self.cg_solver.solve(-gradient,feed_dict) - self._sweeps += [1,2*self.cg_solver.iter] - self.p = p + self.cg_solver.set_trust_region_radius(self.trust_region.radius) # Solve for candidate step p, on_boundary = self.cg_solver.solve(-gradient,hessian_feed_dict) pg = np.dot(p,gradient) + self._sweeps += [1,2*self.cg_solver.iter] + self.p = p # Calculate predicted reduction feed_dict[self.cg_solver.problem.dw] = p Hp = self.sess.run(self.cg_solver.Aop,feed_dict) diff --git a/hessianlearn/algorithms/inexactNewtonMINRES.py b/hessianlearn/algorithms/inexactNewtonMINRES.py index fb1e518..c6d53bc 100644 --- a/hessianlearn/algorithms/inexactNewtonMINRES.py +++ b/hessianlearn/algorithms/inexactNewtonMINRES.py @@ -118,6 +118,5 @@ def minimize(self,feed_dict = None,hessian_feed_dict = None): - \ No newline at end of file diff --git 
a/hessianlearn/algorithms/lowRankSaddleFreeNewton.py b/hessianlearn/algorithms/lowRankSaddleFreeNewton.py index 791a935..8f03b8c 100644 --- a/hessianlearn/algorithms/lowRankSaddleFreeNewton.py +++ b/hessianlearn/algorithms/lowRankSaddleFreeNewton.py @@ -35,7 +35,7 @@ def ParametersLowRankSaddleFreeNewton(parameters = {}): - parameters['alpha'] = [1e0, "Initial steplength, or learning rate"] + parameters['alpha'] = [1e-3, "Initial steplength, or learning rate"] parameters['rel_tolerance'] = [1e-3, "Relative convergence when sqrt(g,g)/sqrt(g_0,g_0) <= rel_tolerance"] parameters['abs_tolerance'] = [1e-4,"Absolute converge when sqrt(g,g) <= abs_tolerance"] parameters['default_damping'] = [1e-3, "Levenberg-Marquardt damping when no regularization is used"] @@ -95,6 +95,8 @@ def __init__(self,problem,regularization = None,sess = None,parameters = Paramet self._rq_std = 0.0 + self.eigenvalues = None + @property def rank(self): return self._rank diff --git a/hessianlearn/algorithms/optimizer.py b/hessianlearn/algorithms/optimizer.py index 29fd4c2..e5a5e6e 100644 --- a/hessianlearn/algorithms/optimizer.py +++ b/hessianlearn/algorithms/optimizer.py @@ -88,6 +88,18 @@ def iter(self): def regularization(self): return self._regularization + @property + def set_sess(self): + return self._set_sess + + + def _set_sess(self,sess): + r""" + Sets the tf.Session() + """ + self._sess = sess + if 'H' in dir(self): + self.H._sess = sess def minimize(self): r""" diff --git a/hessianlearn/model/__init__.py b/hessianlearn/model/__init__.py index 5c4504a..d518603 100644 --- a/hessianlearn/model/__init__.py +++ b/hessianlearn/model/__init__.py @@ -16,3 +16,5 @@ # Contact: tom.olearyroseberry@utexas.edu from .model import HessianlearnModel, HessianlearnModelSettings + +from .kerasModelWrapper import KerasModelWrapper, KerasModelWrapperSettings diff --git a/hessianlearn/model/kerasModelWrapper.py b/hessianlearn/model/kerasModelWrapper.py new file mode 100644 index 0000000..0e125fe --- /dev/null +++ b/hessianlearn/model/kerasModelWrapper.py @@ -0,0 +1,592 @@ +# This file is part of the hessianlearn package +# +# hessianlearn is free software: you can redistribute it and/or modify +# it under the terms of the GNU Lesser General Public License as published by +# the Free Software Foundation, either version 3 of the License, or any later version. +# +# hessianlearn is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public License +# If not, see . 
+# +# Author: Tom O'Leary-Roseberry +# Contact: tom.olearyroseberry@utexas.edu + + +from __future__ import absolute_import, division, print_function +import numpy as np +import tensorflow as tf +# tf.compat.v1.enable_eager_execution() +if int(tf.__version__[0]) > 1: + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() + # tf.enable_eager_execution() + +from abc import ABC, abstractmethod +import warnings + +import sys, os, pickle, time, datetime + + +from ..utilities.parameterList import ParameterList + +from ..problem.problem import KerasModelProblem + +# from ..algorithms import * + +from ..algorithms.adam import Adam +from ..algorithms.gradientDescent import GradientDescent +# from ..algorithms.cgSolver import CGSolver +from ..algorithms.inexactNewtonCG import InexactNewtonCG +# from ..algorithms.gmresSolver import GMRESSolver +# from ..algorithms.inexactNewtonGMRES import InexactNewtonGMRES +# from ..algorithms.minresSolver import MINRESSolver +# from ..algorithms.inexactNewtonMINRES import InexactNewtonMINRES +from ..algorithms.randomizedEigensolver import * +from ..problem.regularization import L2Regularization +from ..algorithms.lowRankSaddleFreeNewton import LowRankSaddleFreeNewton + +from ..problem.hessian import Hessian, HessianWrapper +from ..algorithms.varianceBasedNystrom import variance_based_nystrom + + + +def KerasModelWrapperSettings(settings = {}): + settings['problem_name'] = ['', "string for name used in file naming"] + settings['title'] = [None, "string for name used in plotting"] + settings['logger_outname'] = [None, "string for name used in logger file naming"] + settings['printing_items'] = [{'sweeps':'sweeps','Loss':'train_loss','acc ':'train_acc',\ + '||g||':'||g||','Lossval':'val_loss','accval':'val_acc',\ + 'maxacc':'max_val_acc','alpha':'alpha'},\ + "Dictionary of items for printing"] + settings['printing_sweep_frequency'] = [1, "Print only every this many sweeps"] + settings['validate_frequency'] = [1, "Only compute validation quantities every X sweeps"] + settings['save_weights'] = [True, "Whether or not to save the best weights"] + settings['max_sweeps'] = [10,"Maximum number of times through the data (measured in epoch equivalents"] + + + settings['verbose'] = [True, "Boolean for printing"] + + settings['intra_threads'] = [2, "Setting for intra op parallelism"] + settings['inter_threads'] = [2, "Setting for inter op parallelism"] + + + + # Initial weights for specific layers + settings['layer_weights'] = [{},"Dictionary of layer name key and weight \ + values for weights set after global variable initialization "] + + # Settings for recording spectral information during training + settings['record_spectrum'] = [False, "Boolean for recording spectrum during training"] + settings['target_rank'] = [100,"Target rank for randomized eigenvalue solver"] + settings['oversample'] = [10,"Oversampling for randomized eigenvalue solver"] + + return ParameterList(settings) + + +class KerasModelWrapper(ABC): + def __init__(self,kerasModel,regularization= None,optimizer = None,\ + optimizer_parameters = None,hessian_block_size = None, settings = KerasModelWrapperSettings({})): + warnings.warn('Experimental Class! Be Wary') + # Check hessian blocking condition here? 
+ if optimizer_parameters is not None: + if ('hessian_low_rank' in optimizer_parameters.data.keys()) and (hessian_block_size is not None): + hessian_block_size = max(optimizer_parameters['hessian_low_rank'],hessian_block_size) + + + self._problem = KerasModelProblem(kerasModel,hessian_block_size = hessian_block_size) + if regularization is None: + # If regularization is not passed in, default to zero Tikhonov + self._regularization = L2Regularization(self._problem, 0.0) + else: + self._regularization = regularization + + self.settings = settings + + if optimizer is not None: + if optimizer_parameters is None: + self.set_optimizer(optimizer,regularization = self.regularization) + else: + self.set_optimizer(optimizer,regularization = self.regularization,parameters = optimizer_parameters) + + + + @property + def sess(self): + return self._sess + + @property + def optimizer(self): + return self._optimizer + + @property + def fit(self): + return self._fit + + @property + def problem(self): + return self._problem + + @property + def regularization(self): + return self._regularization + + @property + def set_optimizer(self): + return self._set_optimizer + + + + @property + def logger(self): + return self._logger + + def _set_optimizer(self,optimizer,parameters = None): + if parameters is None: + self._optimizer = optimizer(self.problem, regularization = self.regularization,sess = None) + else: + self._optimizer = optimizer(self.problem, regularization = self.regularization,sess = None,parameters = parameters) + # If larger Hessian spectrum is requested, reinitialize blocking for faster Hessian evaluations + if 'hessian_low_rank' in self._optimizer.parameters.data.keys(): + if self.problem._hessian_block_size is None: + self.problem._initialize_hessian_blocking(self.optimizer.parameters['hessian_low_rank']) + elif self.problem._hessian_block_size < self.optimizer.parameters['hessian_low_rank']: + self.problem._initialize_hessian_blocking(self.optimizer.parameters['hessian_low_rank']) + + + + + def _initialize_logging(self): + # Initialize Logging + logger = {} + logger['dimension'] = self.problem.dimension + logger['problem_name'] = self.settings['problem_name'] + logger['title'] = self.settings['title'] + logger['batch_size'] = self.data._batch_size + logger['hessian_batch_size'] = self.data._hessian_batch_size + logger['train_loss'] = {} + logger['val_loss'] = {} + logger['||g||'] ={} + logger['sweeps'] = {} + logger['total_time'] = {} + logger['time'] = {} + logger['best_weights'] = None + logger['optimizer'] = None + logger['alpha'] = None + logger['globalization'] = None + logger['hessian_low_rank'] = {} + + logger['val_acc'] = {} + logger['train_acc'] = {} + + logger['max_val_acc'] = {} + logger['alpha'] = {} + + if hasattr(self.problem, 'metric_dict'): + for metric_name in self.problem.metric_dict.keys(): + logger[metric_name] = {} + + if self.settings['record_spectrum']: + logger['full_train_eigenvalues'] = {} + logger['train_eigenvalues'] = {} + logger['val_eigenvalues'] = {} + + elif 'eigenvalues' in dir(self._optimizer): + logger['train_eigenvalues'] = {} + + + self._logger = logger + + os.makedirs(self.settings['problem_name']+'_logging/',exist_ok = True) + os.makedirs(self.settings['problem_name']+'_best_weights/',exist_ok = True) + + + # Set outname for logging file + if self.settings['logger_outname'] is None: + logger_outname = str(datetime.date.today())+'-dW='+str(self.problem.dimension) + else: + logger_outname = self.settings['logger_outname'] + self.logger_outname = 
logger_outname + + + + def _fit(self,data,options = None, w_0 = None): + self.data = data + if self.settings['verbose']: + print(80*'#') + print(('Size of configuration space: '+str(self.problem.dimension)).center(80)) + print(('Size of training data: '+str(self.data.train_data_size)).center(80)) + print(('Approximate data cardinality needed: '\ + +str(int(float(self.problem.dimension)/self.problem.output_dimension ))).center(80)) + print(80*'#') + + with tf.Session(config=tf.ConfigProto(intra_op_parallelism_threads=self.settings['intra_threads'],\ + inter_op_parallelism_threads=self.settings['inter_threads'])) as sess: + # Re initialize data + self.data.reset() + # Initialize logging: + self._initialize_logging() + # Initialize the optimizer + self._optimizer.set_sess(sess) + # After optimizer is instantiated, we call the global variables initializer + sess.run(tf.global_variables_initializer()) + ################################################################################ + # Load initial guess if requested: + if w_0 is not None: + if type(w_0) is list: + self._problem._NN.set_weights(w_0) + else: + try: + sess.run(self.problem._assignment_ops,feed_dict = {self.problem._assignment_placeholder:w_0}) + except: + print(80*'#') + print('Issue setting weights manually'.center(80)) + print('tf.global_variables_initializer() used to initial instead'.center(80)) + # This handles a corner case for weights that are not trainable, + # but still get set by the tf.global_variables_initializer() + for layer_name,weight in self.settings['layer_weights'].items(): + self.problem._NN.get_layer(layer_name).set_weights(weight) + ################################################################################ + # First print + if self.settings['verbose']: + self.print(first_print = True) + ################################################################################ + # Load validation data + val_dict = next(iter(self.data.validation)) + if self.problem.is_autoencoder: + assert not hasattr(self.problem,'y_true') + elif self.problem.is_gan: + random_state_gan = np.random.RandomState(seed = 0) + # Should the first dimension here agree with the size of the validation data? 
+ noise = random_state_gan.normal(size = (self.data.batch_size, self.problem.noise_dimension)) + val_dict[self.problem.noise] = noise + + ################################################################################ + # Prepare for iteration + max_sweeps = self.settings['max_sweeps'] + train_data = iter(self.data.train) + sweeps = 0 + min_val_loss = np.inf + max_val_acc = -np.inf + validation_duration = 0.0 + t0 = time.time() + for iteration, (data_g,data_H) in enumerate(zip(self.data.train,self.data.hess_train)): + ################################################################################ + # Unpack data pairs and update dictionary as needed + assert type(data_g) is dict and type(data_H) is dict, 'Old hessianlearn data object has been deprecated, use dictionary iterator now' + train_dict = data_g + hess_dict = data_H + if self.problem.is_autoencoder: + assert not hasattr(self.problem,'y_true') + elif self.problem.is_gan: + assert not hasattr(self.problem,'y_true') + noise = random_state_gan.normal(size = (self.data.batch_size, self.problem.noise_dimension)) + train_dict[self.problem.noise] = noise + noise_hess = random_state_gan.normal(size = (self.data.hessian_batch_size, self.problem.noise_dimension)) + hess_dict[self.problem.noise] = noise_hess + + try: + self.problem.NN.reset_metrics() + except: + pass + + # metric_names = [metric.name for metric in self.problem.NN.metrics] + # metric_evals = sess.run(self.problem.metrics_list,train_dict) + + # for name,evalu in zip(metric_names,metric_evals): + # print('For metric',name,' we have: ',evalu) + + metric_names = list(self.problem.metric_dict.keys()) + metric_evals = sess.run(list(self.problem.metric_dict.values()),train_dict) + + ################################################################################ + # Log time / sweep number + # Every element of dictionary is + # keyed by the optimization iteration + self._logger['total_time'][iteration] = time.time() - t0 - validation_duration + self._logger['sweeps'][iteration] = sweeps + if iteration-1 not in self._logger['time'].keys(): + self._logger['time'][iteration] = self._logger['total_time'][iteration] + else: + self._logger['time'][iteration] = self._logger['total_time'][iteration] - self._logger['total_time'][iteration-1] + self._logger['sweeps'][iteration] = sweeps + # Log information for training data + # Much more efficient to have the actual optimizer / minimize() function + # return this information since it has to query the graph + # This is a place to cut down on computational graph queries + try: + self.problem.NN.reset_metrics() + except: + pass + if hasattr(self.problem,'accuracy'): + norm_g, train_loss, train_acc = sess.run([self.problem.norm_g,self.problem.loss,self.problem.accuracy],train_dict) + self._logger['train_acc'][iteration] = train_acc + else: + norm_g, train_loss = sess.run([self.problem.norm_g,self.problem.loss],train_dict) + self._logger['||g||'][iteration] = norm_g + self._logger['train_loss'][iteration] = train_loss + # Logging of optimization hyperparameters + # These can change at each iteration when using adaptive range finding + # or globalization like line search + self._logger['alpha'][iteration] = self.optimizer.alpha + if hasattr(self.optimizer,'rank'): + self._logger['hessian_low_rank'][iteration] = self.optimizer.rank + + # Update the sweeps + sweeps = np.dot(self.data.batch_factor,self.optimizer.sweeps) + ################################################################################ + # Log for validation data + 
validate_this_iteration = False + validate_frequency = self.settings['validate_frequency'] + if self.settings['validate_frequency'] is None or iteration == 0: + validate_this_iteration = True + else: + validate_this_iteration = self._check_sweep_remainder_condition(iteration,self.settings['validate_frequency']) + try: + self.problem.NN.reset_metrics() + except: + pass + if hasattr(self.problem,'accuracy'): + if validate_this_iteration: + validation_start = time.time() + if hasattr(self.problem,'metric_dict'): + metric_names = list(self.problem.metric_dict.keys()) + metric_values = sess.run(list(self.problem.metric_dict.values()),train_dict) + for metric_name,metric_value in zip(metric_names,metric_values): + self.logger[metric_name][iteration] = metric_value + if hasattr(self.problem,'_variance_reduction'): + if self.problem.has_derivative_loss: + val_loss, val_acc, val_h1_acc, val_var_red =\ + sess.run([self.problem.loss,self.problem.accuracy,\ + self.problem.h1_accuracy,self.problem.variance_reduction],val_dict) + self._logger['val_h1_acc'][iteration] = val_h1_acc + else: + val_loss, val_acc, val_var_red =\ + sess.run([self.problem.loss,self.problem.accuracy,self.problem.variance_reduction],val_dict) + self._logger['val_variance_reduction'][iteration] = val_var_red + else: + if self.problem.has_derivative_loss: + val_loss, val_acc, val_h1_acc = sess.run([self.problem.loss,self.problem.accuracy,\ + self.problem.h1_accuracy],val_dict) + self._logger['val_h1_acc'][iteration] = val_h1_acc + else: + val_loss, val_acc = sess.run([self.problem.loss,self.problem.accuracy],val_dict) + self._logger['val_acc'][iteration] = val_acc + self._logger['val_loss'][iteration] = val_loss + max_val_acc = max(max_val_acc,val_acc) + min_val_loss = min(min_val_loss,val_loss) + validation_duration += time.time() - validation_start + self._logger['max_val_acc'][iteration] = max_val_acc + else: + if validate_this_iteration: + validation_start = time.time() + val_loss = sess.run(self.problem.loss,val_dict) + validation_duration += time.time() - validation_start + min_val_loss = min(min_val_loss,val_loss) + self._logger['val_loss'][iteration] = val_loss + + ################################################################################ + # Save the best weights based on validation accuracy or loss + if hasattr(self.problem,'accuracy') and val_acc == max_val_acc: + weight_dictionary = {} + for layer in self.problem._NN.layers: + weight_dictionary[layer.name] = self.problem._NN.get_layer(layer.name).get_weights() + self._best_weights = weight_dictionary + if self.settings['save_weights']: + # Save the weights individually, not in the logger + + self._logger['best_weights'] = weight_dictionary + elif val_loss == min_val_loss: + weight_dictionary = {} + if self.problem.is_gan: + weight_dictionary['generator'] = {} + for layer in self.problem._generator.layers: + weight_dictionary['generator'][layer.name] = self.problem._generator.get_layer(layer.name).get_weights() + weight_dictionary['discriminator'] = {} + for layer in self.problem._discriminator.layers: + weight_dictionary['discriminator'][layer.name] = self.problem._discriminator.get_layer(layer.name).get_weights() + else: + for layer in self.problem._NN.layers: + weight_dictionary[layer.name] = self.problem._NN.get_layer(layer.name).get_weights() + self._best_weights = weight_dictionary + if self.settings['save_weights']: + self._logger['best_weights'] = weight_dictionary + ################################################################################ + # 
Printing + if self.settings['verbose']: + # Print once each epoch + self.print(iteration = iteration) + ################################################################################ + # Checking for nans! + if np.isnan(train_loss) or np.isnan(norm_g): + print(80*'#') + print('Encountered nan, exiting'.center(80)) + print(80*'#') + break + ################################################################################ + # Actual optimization takes place here + try: + self.optimizer.minimize(train_dict,hessian_feed_dict=hess_dict) + except: + self.optimizer.minimize(train_dict) + ################################################################################ + # Recording the spectrum + if not self.settings['record_spectrum'] and 'eigenvalues' in dir(self._optimizer): + try: + self._logger['train_eigenvalues'][iteration] = self.optimizer.eigenvalues + except: + pass + elif self.settings['record_spectrum'] and iteration%self.settings['spec_frequency'] ==0: + self._record_spectrum(iteration) + with open(self.settings['problem_name']+'_logging/'+ self.logger_outname +'.pkl', 'wb+') as f: + pickle.dump(self.logger, f, pickle.HIGHEST_PROTOCOL) + with open(self.settings['problem_name']+'_best_weights/'+ self.logger_outname +'.pkl', 'wb+') as f: + pickle.dump(self._best_weights, f, pickle.HIGHEST_PROTOCOL) + ################################################################################ + # Check if max_sweeps condition has been met + if sweeps > max_sweeps: + # One last print + self.print(iteration = iteration,force_print = True) + break + ################################################################################ + # Post optimization + # The weights need to be manually set once the session scope is closed. + try: + if self.problem.is_gan: + for layer_name in self._best_weights['generator']: + self.problem._generator.get_layer(layer_name).set_weights(self._best_weights['generator'][layer_name]) + for layer_name in self._best_weights['discriminator']: + self.problem._discriminator.get_layer(layer_name).set_weights(self._best_weights['discriminator'][layer_name]) + else: + for layer_name in self._best_weights: + self._problem._NN.get_layer(layer_name).set_weights(self._best_weights[layer_name]) + except: + print('Error setting the weights after training') + + + def _record_spectrum(self,iteration): + k_rank = self.settings['target_rank'] + p_oversample = self.settings['oversample'] + + if self.settings['rayleigh_quotients']: + print('It is working') + my_t0 = time.time() + + train_data = self.data.train._data + val_data = self.data.val._data + + if self.problem.is_autoencoder: + if not (type(self.problem.x) is list): + full_train_dict = {self.problem.x:train_data[self.problem.x]} + full_val_dict = {self.problem.x:val_data[self.problem.x]} + else: + full_train_dict,full_val_dict = {},{} + for input_key in self.problem.x: + full_train_dict[input_key] = train_data[input_key] + full_val_dict[input_key] = val_data[input_key] + else: + if not (type(self.problem.x) is list) and not (type(self.problem.y_true) is list): + full_train_dict = {self.problem.x:train_data[self.problem.x],self.problem.y_true:train_data[self.problem.y_true]} + full_val_dict = {self.problem.x:val_data[self.problem.x],self.problem.y_true:val_data[self.problem.y_true]} + else: + full_train_dict, full_val_dict = {}, {} + if type(self.problem.x) is list: + for input_key in self.problem.x: + full_train_dict[input_key] = train_data[input_key] + full_val_dict[input_key] = val_data[input_key] + else: + 
full_train_dict[self.problem.x] = train_data[self.problem.x] + full_val_dict[self.problem.x] = val_data[self.problem.x] + if type(self.problem.y_true) is list: + for output_key in self.problem.y_true: + full_train_dict[output_key] = train_data[output_key] + full_val_dict[output_key] = val_data[output_key] + else: + full_train_dict[self.problem.y_true] = train_data[self.problem.y_true] + full_val_dict[self.problem.y_true] = val_data[self.problem.y_true] + + + + + d_full_train, U_full_train = low_rank_hessian(self.optimizer,full_train_dict,k_rank,p_oversample,verbose=True) + self._logger['full_train_eigenvalues'][iteration] = d_full_train + + else: + d_full,_ = low_rank_hessian(self.optimizer,train_dict,k_rank,p_oversample) + self._logger['train_eigenvalues'][iteration] = d_full + d_val,_ = low_rank_hessian(self.optimizer,val_dict,k) + self._logger['val_eigenvalues'][iteration] = d_val + + + + def print(self,first_print = False,iteration = None,force_print = False): + ################################################################################ + # Check to make sure everything requested to print exists + for key in self.settings['printing_items'].keys(): + assert self.settings['printing_items'][key] in self._logger.keys(), 'item '+str(self.settings['printing_items'][key])+' not in logger' + ################################################################################ + # First print : column names + if first_print: + print(80*'#') + format_string = '' + for i in range(len(self.settings['printing_items'].keys())): + if i == 0: + format_string += '{0:7} ' + else: + format_string += '{'+str(i)+':7} ' + string_tuples = (print_string.center(8) for print_string in self.settings['printing_items'].keys()) + print(format_string.format(*string_tuples)) + ################################################################################ + # Iteration prints + else: + format_string = '' + for i,key in enumerate(self.settings['printing_items'].keys()): + if iteration not in self._logger[self.settings['printing_items'][key]].keys(): + format_string += '{'+str(i)+':7} ' + elif 'sweeps' in key: + format_string += '{'+str(i)+':^8.2f} ' + elif 'acc' in key: + value = self._logger[self.settings['printing_items'][key]][iteration] + if value < 0.0: + format_string += '{'+str(i)+':.2%} ' + else: + format_string += '{'+str(i)+':.3%} ' + elif 'rank' in key: + format_string += '{'+str(i)+':5} ' + else: + format_string += '{'+str(i)+':1.2e} ' + ################################################################################ + # Check sweep remainder condition here + every_sweep = self.settings['printing_sweep_frequency'] + if every_sweep is None or not ('sweeps' in self.settings['printing_items'].keys()): + print_this_time = True + elif iteration == 0 or force_print: + print_this_time = True + elif every_sweep is not None: + assert iteration is not None + print_this_time = self._check_sweep_remainder_condition(iteration,every_sweep) + ################################################################################ + # Actual printing + if print_this_time: + value_list = [] + for item in self.settings['printing_items']: + if iteration in self._logger[self.settings['printing_items'][item]]: + value_list.append(self._logger[self.settings['printing_items'][item]][iteration]) + else: + value_list.append(8*' ') + # value_tuples = (self._logger[self.settings['printing_items'][item]][iteration] for item in self.settings['printing_items']) + print(format_string.format(*value_list)) + + + def 
_check_sweep_remainder_condition(self,iteration, sweeps_divisor): + + last_sweep_floor_div,last_sweep_rem = np.divmod(self._logger['sweeps'][iteration-1], sweeps_divisor) + this_sweep_floor_div,this_sweep_rem = np.divmod(self._logger['sweeps'][iteration], sweeps_divisor) + + if this_sweep_floor_div > last_sweep_floor_div: + return True + else: + return False diff --git a/hessianlearn/model/model.py b/hessianlearn/model/model.py index 9c6d83f..28f1c68 100644 --- a/hessianlearn/model/model.py +++ b/hessianlearn/model/model.py @@ -72,7 +72,7 @@ def HessianlearnModelSettings(settings = {}): # Optimizer settings settings['optimizer'] = ['lrsfn', "String to denote choice of optimizer"] - settings['alpha'] = [5e-2, "Initial steplength, or learning rate"] + settings['alpha'] = [1e-3, "Initial steplength, or learning rate"] settings['hessian_low_rank'] = [20, "Low rank to be used for LRSFN / SFN"] settings['globalization'] = [None, "None means steps of length alpha will be taken at each iteration"] settings['max_backtrack'] = [10, "Maximum number of backtracking iterations for each line search"] @@ -135,9 +135,6 @@ def __init__(self,problem,regularization,data,derivative_data = None,settings = self._optimizer = None - - - @property def sess(self): return self._sess @@ -321,6 +318,10 @@ def _initialize_logging(self): if hasattr(self.problem,'_variance_reduction'): logger['val_variance_reduction'] = {} + if hasattr(self.problem, 'metric_dict'): + for metric_name in self.problem.metric_dict.keys(): + logger[metric_name] = {} + if self.problem.has_derivative_loss: logger['train_h1_loss'] = {} logger['val_h1_loss'] = {} @@ -404,7 +405,6 @@ def _fit(self,options = None, w_0 = None): val_dict = next(iter(self.data.validation)) if self.problem.is_autoencoder: assert not hasattr(self.problem,'y_true') - # val_dict = {self.problem.x: val_data[self.problem.x]} elif self.problem.is_gan: random_state_gan = np.random.RandomState(seed = 0) # Should the first dimension here agree with the size of the validation data? 
@@ -449,6 +449,10 @@ def _fit(self,options = None, w_0 = None): # Much more efficient to have the actual optimizer / minimize() function # return this information since it has to query the graph # This is a place to cut down on computational graph queries + try: + self.problem.NN.reset_metrics() + except: + pass if hasattr(self.problem,'accuracy'): if self.problem.has_derivative_loss: norm_g, train_loss, train_acc, train_h1_acc = sess.run([self.problem.norm_g,self.problem.loss,\ @@ -478,10 +482,18 @@ def _fit(self,options = None, w_0 = None): validate_this_iteration = True else: validate_this_iteration = self._check_sweep_remainder_condition(iteration,self.settings['validate_frequency']) - + try: + self.problem.NN.reset_metrics() + except: + pass if hasattr(self.problem,'accuracy'): if validate_this_iteration: validation_start = time.time() + if hasattr(self.problem,'metric_dict'): + metric_names = list(self.problem.metric_dict.keys()) + metric_values = sess.run(list(self.problem.metric_dict.values()),train_dict) + for metric_name,metric_value in zip(metric_names,metric_values): + self.logger[metric_name][iteration] = metric_value if hasattr(self.problem,'_variance_reduction'): if self.problem.has_derivative_loss: val_loss, val_acc, val_h1_acc, val_var_red =\ @@ -592,7 +604,6 @@ def _record_spectrum(self,iteration): k_rank = self.settings['target_rank'] p_oversample = self.settings['oversample'] - if self.settings['rayleigh_quotients']: print('It is working') my_t0 = time.time() @@ -601,11 +612,37 @@ def _record_spectrum(self,iteration): val_data = self.data.val._data if self.problem.is_autoencoder: - full_train_dict = {self.problem.x:train_data[self.problem.x]} - full_val_dict = {self.problem.x:val_data[self.problem.x]} + if not (type(self.problem.x) is list): + full_train_dict = {self.problem.x:train_data[self.problem.x]} + full_val_dict = {self.problem.x:val_data[self.problem.x]} + else: + full_train_dict,full_val_dict = {},{} + for input_key in self.problem.x: + full_train_dict[input_key] = train_data[input_key] + full_val_dict[input_key] = val_data[input_key] else: - full_train_dict = {self.problem.x:train_data[self.problem.x],self.problem.y_true:train_data[self.problem.y_true]} - full_val_dict = {self.problem.x:val_data[self.problem.x],self.problem.y_true:val_data[self.problem.y_true]} + if not (type(self.problem.x) is list) and not (type(self.problem.y_true) is list): + full_train_dict = {self.problem.x:train_data[self.problem.x],self.problem.y_true:train_data[self.problem.y_true]} + full_val_dict = {self.problem.x:val_data[self.problem.x],self.problem.y_true:val_data[self.problem.y_true]} + else: + full_train_dict, full_val_dict = {}, {} + if type(self.problem.x) is list: + for input_key in self.problem.x: + full_train_dict[input_key] = train_data[input_key] + full_val_dict[input_key] = val_data[input_key] + else: + full_train_dict[self.problem.x] = train_data[self.problem.x] + full_val_dict[self.problem.x] = val_data[self.problem.x] + if type(self.problem.y_true) is list: + for output_key in self.problem.y_true: + full_train_dict[output_key] = train_data[output_key] + full_val_dict[output_key] = val_data[output_key] + else: + full_train_dict[self.problem.y_true] = train_data[self.problem.y_true] + full_val_dict[self.problem.y_true] = val_data[self.problem.y_true] + + + d_full_train, U_full_train = low_rank_hessian(self.optimizer,full_train_dict,k_rank,p_oversample,verbose=True) self._logger['full_train_eigenvalues'][iteration] = d_full_train diff --git 
a/hessianlearn/problem/__init__.py b/hessianlearn/problem/__init__.py index e3ded50..65cd88b 100644 --- a/hessianlearn/problem/__init__.py +++ b/hessianlearn/problem/__init__.py @@ -17,7 +17,7 @@ from __future__ import absolute_import, division, print_function -from .problem import Problem, ClassificationProblem, RegressionProblem, H1RegressionProblem,\ +from .problem import Problem, ClassificationProblem, KerasModelProblem, RegressionProblem, H1RegressionProblem,\ AutoencoderProblem,VariationalAutoencoderProblem, GenerativeAdversarialNetworkProblem from .hessian import Hessian, HessianWrapper diff --git a/hessianlearn/problem/preconditioner.py b/hessianlearn/problem/preconditioner.py index a94f186..9208a19 100644 --- a/hessianlearn/problem/preconditioner.py +++ b/hessianlearn/problem/preconditioner.py @@ -19,11 +19,11 @@ from __future__ import absolute_import, division, print_function import numpy as np import tensorflow as tf -tf.compat.v1.enable_eager_execution() +# tf.compat.v1.enable_eager_execution() if int(tf.__version__[0]) > 1: import tensorflow.compat.v1 as tf - # tf.disable_v2_behavior() - tf.enable_eager_execution() + tf.disable_v2_behavior() + # tf.enable_eager_execution() class Preconditioner(object): """ diff --git a/hessianlearn/problem/problem.py b/hessianlearn/problem/problem.py index 55a5d34..92bc731 100644 --- a/hessianlearn/problem/problem.py +++ b/hessianlearn/problem/problem.py @@ -59,7 +59,7 @@ class Problem(ABC): It takes a neural network model and defines loss function and derivatives Also defines update operations. """ - def __init__(self,NeuralNetwork,hessian_block_size = None,dtype = tf.float32): + def __init__(self,NeuralNetwork,hessian_block_size = None,dtype = None): """ The Problem parent class constructor takes a neural network model (typically from tf.keras.Model) Children class implement different loss functions which are implemented by the method _initialize_loss @@ -74,7 +74,10 @@ def __init__(self,NeuralNetwork,hessian_block_size = None,dtype = tf.float32): # Hessian block size self._hessian_block_size = hessian_block_size # Data type - self._dtype = dtype + if dtype is None: + self._dtype = NeuralNetwork.inputs[0].dtype + else: + self._dtype = dtype # Initialize the neural network(s) self._initialize_network(NeuralNetwork) @@ -181,6 +184,9 @@ def _initialize_network(self,NeuralNetwork): Must set member variable self._output_shape """ self._NN = NeuralNetwork + assert len(self.NN.inputs) == 1 and len(self.NN.outputs) == 1,\ + 'This class only supports single input / output networks. For multi input / output look\ + at hessianlearn.problem.KerasModelProblem' self.x = self.NN.inputs[0] self.y_prediction = self.NN(self.x) @@ -348,6 +354,109 @@ def _partition_dictionaries(self,data_dictionary,n_partitions): raise NotImplementedError("Child class should implement method _partition_dictionaries") +class KerasModelProblem(Problem): + """ + This class implements an hessianlearn Problem that inherits losses and metrics from tf.keras.. + + """ + def __init__(self,NeuralNetwork,hessian_block_size = None): + """ + The constructor for this class takes: + -NeuralNetwork: the neural network represented as a tf.keras Model + + """ + assert NeuralNetwork._is_compiled, 'Must first compile the network before passing it in.' 
+ # Assertion about data type conformity: + dtype = NeuralNetwork.inputs[0].dtype + # This may be a redundant check since the tf.keras.Model.compile method may + # enforce type conformity anyways + for input_i in NeuralNetwork.inputs[1:]: + assert input_i.dtype == dtype + for output_i in NeuralNetwork.outputs: + assert output_i.dtype == dtype + + super(KerasModelProblem,self).__init__(NeuralNetwork,hessian_block_size = hessian_block_size,dtype = dtype) + + @property + def metric_dict(self): + return self._metric_dict + + + def _initialize_network(self,NeuralNetwork): + """ + This method defines the neural network model + -NeuralNetwork: the neural network as a tf.keras.model.Model + + Must set member variable self._output_shape + """ + self._NN = NeuralNetwork + + if len(self.NN.inputs) == 1: + self.x = self.NN.inputs[0] + else: + self.x = self.NN.inputs + + self.y_prediction = self.NN(self.x) + + if len(self.NN.outputs) == 1: + # Simply do things the old way + # Why does the following not work: + # self.y_true = self.NN.outputs[0] + # Instead I have to use a placeholder for true output + output_shape = self.NN.output_shape + + self.y_true = tf.placeholder(self.dtype, output_shape,name='output_placeholder') + + if len(self.y_prediction.shape) > 2: + self._output_dimension = 1. + for shape in self.y_prediction.shape[1:]: + self._output_dimension *= shape.value + else: + self._output_dimension = self.y_prediction.shape[-1].value + else: + # Why does the following not work: + # self.y_true = self.NN.outputs[0] + # Instead I have to use a placeholder for true output + + self.y_true = [] + for i,shape in enumerate(self.NN.output_shape): + self.y_true.append(tf.placeholder(self.dtype, shape,name='output_placeholder_'+str(i))) + + self._output_dimension = 0 + for prediction in self.y_prediction: + if len(prediction.shape) > 2: + _output_dimension = 1. + for shape in prediction.shape[1:]: + _output_dimension *= shape.value + else: + _output_dimension = prediction.shape[-1].value + self._output_dimension += _output_dimension + + def _initialize_loss(self): + """ + This method defines the least squares loss function as well as relative error and accuracy + """ + with tf.name_scope('loss'): + if len(self._NN.loss_functions) == 1: + self._loss = self._NN.loss_functions[0](self.y_true,self.y_prediction) + else: + weights_and_losses = zip(self.NN._loss_weights_list,self._NN.loss_functions) + self._loss = sum([weight_i*loss_i(self.y_true,self.y_prediction) for weight_i, loss_i in weights_and_losses]) + + with tf.name_scope('accuracy'): + # The current convention is to pull out the first metric to be used + # an an 'accuracy' in printing. All metrics will be logged, however, + # and the end-user can specify what gets printed at each iteration + # by specifying printing items in the settings for the model class + m_0 = self.NN.metrics[0] + self._accuracy = self.NN.metrics[0](self.y_true,self.y_prediction) + + + with tf.name_scope('metrics'): + self._metric_dict = {} + for metric in self.NN.metrics: + self._metric_dict[metric.name] = metric(self.y_true,self.y_prediction) + class ClassificationProblem(Problem): """ @@ -439,7 +548,6 @@ def _partition_dictionaries(self,data_dictionary,n_partitions): return dictionary_partitions - class RegressionProblem(Problem): """ This class implements the description of basic regression problems. 
diff --git a/hessianlearn/test/test_HessianlearnModel.py b/hessianlearn/test/test_HessianlearnModel.py index 10be0ee..e56689c 100644 --- a/hessianlearn/test/test_HessianlearnModel.py +++ b/hessianlearn/test/test_HessianlearnModel.py @@ -18,16 +18,15 @@ import unittest import numpy as np -import tensorflow as tf -if int(tf.__version__[0]) > 1: - import tensorflow.compat.v1 as tf - tf.disable_v2_behavior() - import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' os.environ['KMP_DUPLICATE_LIB_OK']='True' os.environ["KMP_WARNINGS"] = "FALSE" +import tensorflow as tf +if int(tf.__version__[0]) > 1: + import tensorflow.compat.v1 as tf + tf.disable_v2_behavior() import sys @@ -59,19 +58,21 @@ def one_hot_vectors(labels_temp): tf.keras.layers.Dense(10) ]) # Instantiate the problem, regularization. - problem = ClassificationProblem(classifier,loss_type = 'least_squares',dtype=tf.float32) - regularization = L2Regularization(problem,gamma =0.001) + problem = ClassificationProblem(classifier,loss_type = 'cross_entropy',dtype=tf.float32) + regularization = L2Regularization(problem,gamma =0.) # Instante the data object train_dict = {problem.x:x_train, problem.y_true:y_train} validation_dict = {problem.x:x_test, problem.y_true:y_test} - data = Data(train_dict,256,validation_data = validation_dict,hessian_batch_size = 32) + data = Data(train_dict,32,validation_data = validation_dict,hessian_batch_size = 8) # Instantiate the model object HLModelSettings = HessianlearnModelSettings() HLModelSettings['max_sweeps'] = 1. HLModel = HessianlearnModel(problem,regularization,data,settings = HLModelSettings) - for optimizer in ['lrsfn','adam','gd','incg','sgd']: + for optimizer in ['lrsfn','adam','gd','sgd','incg']: HLModel.settings['optimizer'] = optimizer + if optimizer == 'incg': + HLModel.settings['alpha'] = 1e-4 HLModel.fit() first_loss = HLModel.logger['train_loss'][0] last_iteration = max(HLModel.logger['train_loss'].keys())
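
Aside on the adam.py hunk above: the patch replaces the raw first and second moment estimates with bias-corrected ones (m_hat and v_hat) before forming the update. The snippet below is a minimal, self-contained NumPy sketch of that bias-corrected update, written for illustration only; it is not code from the patch, and the hyperparameter values shown (beta_1 = 0.9, beta_2 = 0.999, epsilon = 1e-8) are standard Adam defaults assumed here rather than hessianlearn's own settings.

```python
import numpy as np

def adam_step(w, g, m, v, t, alpha=1e-3, beta_1=0.9, beta_2=0.999, epsilon=1e-8):
    # Exponential moving averages of the gradient and its elementwise square
    m = beta_1 * m + (1.0 - beta_1) * g
    v = beta_2 * v + (1.0 - beta_2) * np.square(g)
    # Bias correction: without it m and v are biased toward zero for small t,
    # which makes the earliest steps much smaller than alpha suggests
    m_hat = m / (1.0 - beta_1 ** t)
    v_hat = v / (1.0 - beta_2 ** t)
    # Parameter update
    w = w - alpha * m_hat / (np.sqrt(v_hat) + epsilon)
    return w, m, v

# With bias correction the very first step (t = 1) has componentwise
# magnitude close to alpha, independent of the gradient's scale.
w = np.zeros(3)
m = np.zeros(3)
v = np.zeros(3)
w, m, v = adam_step(w, np.array([1e-3, 2e-3, -1e-3]), m, v, t=1)
print(w)  # approximately [-1e-3, -1e-3, 1e-3]
```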