%DEMEV3 Demonstrate Bayesian regression for the RBF.
%
% Description
% The problem consists of an input variable X which is sampled from a
% Gaussian distribution, and a target variable T generated by computing
% SIN(2*PI*X) and adding Gaussian noise. An RBF network with linear
% outputs is trained by minimizing a sum-of-squares error function with
% an isotropic Gaussian regularizer, using the scaled conjugate gradient
% optimizer. The hyperparameters ALPHA and BETA are re-estimated using
% the function EVIDENCE. A graph is plotted of the original function,
% the training data, the trained network function, and the error bars.
%
% See also
% DEMEV1, EVIDENCE, RBF, SCG, NETEVFWD
%
% Copyright (c) Ian T Nabney (1996-2001)
clc;
disp('This demonstration illustrates the application of Bayesian')
disp('re-estimation to determine the hyperparameters in a simple regression')
disp('problem using an RBF network. It is based on the fact that the')
disp('posterior distribution for the output weights of an RBF is Gaussian')
disp('and uses the evidence maximization framework of MacKay.')
disp(' ')
disp('First, we generate a synthetic data set consisting of a single input')
disp('variable x sampled from a Gaussian distribution, and a target variable')
disp('t obtained by evaluating sin(2*pi*x) and adding Gaussian noise.')
disp(' ')
disp('Press any key to see a plot of the data together with the sine function.')
pause;
% Generate the matrix of inputs x and targets t.
ndata = 16; % Number of data points.
noise = 0.1; % Standard deviation of noise distribution.
randn('state', 0);
rand('state', 0);
x = 0.25 + 0.07*randn(ndata, 1);
t = sin(2*pi*x) + noise*randn(size(x));
% Plot the data and the original sine function.
h = figure;
nplot = 200;
plotvals = linspace(0, 1, nplot)';
plot(x, t, 'ok')
xlabel('Input')
ylabel('Target')
hold on
axis([0 1 -1.5 1.5])
fplot(@(x) sin(2*pi*x), [0 1], '-g') % Function handle; string input to fplot is deprecated.
legend('data', 'function');
disp(' ')
disp('Press any key to continue')
pause; clc;
disp('Next we create an RBF network having 3 hidden units and one')
disp('linear output. The model assumes Gaussian target noise governed by an')
disp('inverse variance hyperparameter beta, and uses a simple Gaussian prior')
disp('distribution governed by an inverse variance hyperparameter alpha.')
disp(' ');
disp('The network weights and the hyperparameters are initialised and then')
disp('the output layer weights are optimized with the scaled conjugate gradient')
disp('algorithm using the SCG function, with the hyperparameters kept')
disp('fixed. After a maximum of 50 iterations, the hyperparameters are')
disp('re-estimated using the EVIDENCE function. The process of optimizing')
disp('the weights with fixed hyperparameters and then re-estimating the')
disp('hyperparameters is repeated for a total of 5 cycles.')
disp(' ')
disp('Press any key to train the network and determine the hyperparameters.')
pause;
% Set up network parameters.
nin = 1; % Number of inputs.
nhidden = 3; % Number of hidden units.
nout = 1; % Number of outputs.
alpha = 0.01; % Initial prior hyperparameter.
beta_init = 50.0; % Initial noise hyperparameter.
% Create and initialize network weight vector.
net = rbf(nin, nhidden, nout, 'tps', 'linear', alpha, beta_init);
[net.mask, prior] = rbfprior('tps', nin, nhidden, nout, alpha, alpha);
net = netinit(net, prior);
options = foptions;
options(14) = 5; % At most 5 EM iterations for basis functions
options(1) = -1; % Turn off all messages
net = rbfsetbf(net, options, x); % Initialise the basis functions
% Now train the network
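nouter = 5; % Number of optimization / re-estimation cycles.
ninner = 2; % Number of hyperparameter re-estimation steps per cycle.
options = foptions;
options(1) = 1; % Display error values during training.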
options(2) = 1.0e-5; % Absolute precision for weights.
options(3) = 1.0e-5; % Precision for objective function.
options(14) = 50; % Number of training cycles in inner loop.
% Train using scaled conjugate gradients, re-estimating alpha and beta.
for k = 1:nouter
  net = netopt(net, options, x, t, 'scg');
  [net, gamma] = evidence(net, x, t, ninner);
  fprintf(1, '\nRe-estimation cycle %d:\n', k);
  fprintf(1, ' alpha = %8.5f\n', net.alpha);
  fprintf(1, ' beta = %8.5f\n', net.beta);
  fprintf(1, ' gamma = %8.5f\n\n', gamma);
  disp(' ')
  disp('Press any key to continue.')
  pause;
end
fprintf(1, 'true beta: %f\n', 1/(noise*noise));
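% Illustrative addition (not part of the original demo): beta is an inverse
% variance, so 1/sqrt(beta) is the noise standard deviation implied by the
% re-estimated hyperparameter. Printing it next to NOISE gives a comparison
% on the same scale as the data. The true beta above is 1/(0.1*0.1) = 100.
fprintf(1, 'estimated noise std: %f (true: %f)\n', 1/sqrt(net.beta), noise);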
disp(' ')
disp('Network training and hyperparameter re-estimation are now complete.')
disp('Compare the final value for the hyperparameter beta with the true')
disp('value.')
disp(' ')
disp('Notice that the final error value is close to the number of data')
disp(['points (', num2str(ndata),') divided by two.'])
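% Illustrative addition (not part of the original demo): why N/2? The
% re-estimation formulas are alpha = gamma/(2*E_W) and
% beta = (N - gamma)/(2*E_D), so at a fixed point
% alpha*E_W + beta*E_D = gamma/2 + (N - gamma)/2 = N/2.
% The check below assumes that Netlab's RBFERR returns this total
% regularized error for a network carrying alpha and beta fields. The
% weights are not re-optimized after the last re-estimation, so the two
% numbers should be close rather than identical.
fprintf(1, 'total error: %f, ndata/2: %f\n', rbferr(net, x, t), ndata/2);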
disp(' ')
disp('Press any key to continue.')
pause; clc;
disp('We can now plot the function represented by the trained network. This')
disp('corresponds to the mean of the predictive distribution. We can also')
disp('plot ''error bars'' representing one standard deviation of the')
disp('predictive distribution around the mean.')
disp(' ')
disp('Press any key to add the network function and error bars to the plot.')
pause;
% Evaluate error bars.
[y, sig2] = netevfwd(netpak(net), net, x, t, plotvals);
sig = sqrt(sig2);
% Plot the data, the original function, and the trained network function.
[y, z] = rbffwd(net, plotvals);
figure(h); hold on;
plot(plotvals, y, '-r')
xlabel('Input')
ylabel('Target')
plot(plotvals, y + sig, '-b');
plot(plotvals, y - sig, '-b');
legend('data', 'function', 'network', 'error bars');
disp(' ')
disp('Notice how the confidence interval spanned by the ''error bars'' is')
disp('smaller in the region of input space where the data density is high,')
disp('and becomes larger in regions away from the data.')
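% Illustrative addition (not part of the original demo): a numerical check
% of the statement above. It compares the predictive standard deviation at
% the grid point nearest the mean of the input distribution (x = 0.25) with
% the value at x = 1, far from the data. DUMMY and INEAR are names
% introduced here for illustration only.
[dummy, inear] = min(abs(plotvals - 0.25));
fprintf(1, 'error bar at x = %4.2f: %f\n', plotvals(inear), sig(inear));
fprintf(1, 'error bar at x = %4.2f: %f\n', plotvals(end), sig(end));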
disp(' ')
disp('Press any key to end.')
pause; clc; close(h);