-
Notifications
You must be signed in to change notification settings - Fork 1
/
demgtm2.m
194 lines (170 loc) · 5.45 KB
/
demgtm2.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
%DEMGTM2 Demonstrate GTM for visualisation.
%
% Description
% This script demonstrates the use of a GTM with a two-dimensional
% latent space to visualise data in a higher dimensional space. This is
% done through the use of the mean responsibility and magnification
% factors.
%
% See also
% DEMGTM1, GTM, GTMEM, GTMPOST
%
% Copyright (c) Ian T Nabney (1996-2001)
% Fix seeds for reproducible results
% (the legacy 'state' generator syntax is retained unchanged)
rand('state', 420);
randn('state', 420);

% Number of data points to sample. The statement ends with a semicolon so
% that the assignment is not echoed to the command window.
ndata = 300;

% Introduce the demonstration and wait for a key press
clc;
disp('This demonstration shows how a Generative Topographic Mapping')
disp('can be used to model and visualise high dimensional data. The')
disp('data is generated from a mixture of two spherical Gaussians in')
dstring = ['four dimensional space. ', num2str(ndata), ...
    ' data points are generated.'];
disp(dstring);
disp(' ');
disp('Press any key to continue.')
pause
% Create data: sample ndata labelled points from an equally weighted
% mixture of two spherical Gaussians in data_dim dimensions.
data_dim = 4;
latent_dim = 2;
mix = gmm(data_dim, 2, 'spherical');
mix.priors = [0.5 0.5];
mix.centres = [1 1 1 1; 0 0 0 0];
mix.covars = [0.1 0.1];
[data, labels] = gmmsamp(mix, ndata);

% Sizes used to build the GTM
num_rbf_centres = 16;
latent_shape = [15 15];        % Number of latent points in each dimension
nlatent = prod(latent_shape);  % Number of latent points

% Describe the model that is about to be built and wait for a key press
clc;
disp(sprintf('Next we generate and initialise the GTM. There are %d latent points', ...
    nlatent));
disp(sprintf('arranged in a square of %d points on a side. There are %d centres in the', ...
    latent_shape(1), num_rbf_centres));
disp('RBF model, which has Gaussian activation functions.')
disp(' ')
disp('Once the model is created, the latent data sample')
disp('and RBF centres are placed uniformly in the square [-1 1 -1 1].')
disp('The output weights of the RBF are computed to map the latent');
disp('space to the two dimensional PCA subspace of the data.');
disp(' ')
disp('Press any key to continue.');
pause;
% Create and initialise GTM model
net = gtm(latent_dim, nlatent, data_dim, num_rbf_centres, 'gaussian', 0.1);

% Options vector used for initialisation only.
% Element 7 sets the width factor of the RBF; element 1 = -1 presumably
% silences diagnostic output (Netlab options convention) -- confirm.
options = foptions;
options([1 7]) = [-1 1];
% 'regular' with latent_shape and [4 4]: presumably regular grids for the
% latent sample and the RBF centres (4*4 matches num_rbf_centres) -- confirm.
net = gtminit(net, options, data, 'regular', latent_shape, [4 4]);

% Fresh options vector for training: element 14 holds the iteration count
% shown in the message below; element 1 is switched to 1.
options = foptions;
options(1) = 1;
options(14) = 30;

clc;
disp(['We now train the model with ', num2str(options(14)), ' iterations of'])
disp('the EM algorithm for the GTM.')
disp(' ')
disp('Press any key to continue.')
pause;

% Train the GTM on the data
[net, options] = gtmem(net, data, options);

disp(' ')
disp('Press any key to continue.')
pause;
clc;
disp('We now visualise the data by plotting, for each data point,');
disp('the posterior mean and mode (in latent space). These give');
disp('a summary of the entire posterior distribution in latent space.')
disp('The corresponding values are joined by a line to aid the')
disp('interpretation.')
disp(' ')
disp('Press any key to continue.');
pause;

% Posterior mean and mode in latent space for every data point
means = gtmlmean(net, data);
modes = gtmlmode(net, data);

% Marker settings (PointSize is reused by the later magnification plot)
PointSize = 12;
ClassSymbol1 = 'r.';
ClassSymbol2 = 'b.';

fh1 = figure;
hold on;
title('Visualisation in latent space')

% Plot posterior means, one colour per class (label 1 versus labels > 1)
plot(means((labels==1),1), means(labels==1,2), ...
    ClassSymbol1, 'MarkerSize', PointSize)
plot(means((labels>1),1),means(labels>1,2),...
    ClassSymbol2, 'MarkerSize', PointSize)

% Plot posterior modes as circles in the matching colours
ClassSymbol1 = 'ro';
ClassSymbol2 = 'bo';
plot(modes(labels==1,1), modes(labels==1,2), ...
    ClassSymbol1)
plot(modes(labels>1,1),modes(labels>1,2),...
    ClassSymbol2)

% Join up means and modes. Each column of the 2-by-ndata matrices is drawn
% as its own line segment, so one call replaces a per-point loop.
plot([means(:,1), modes(:,1)]', [means(:,2), modes(:,2)]', 'g-')

% Place legend outside data plot. The named 'Location' option replaces the
% obsolete numeric position code -1, which current MATLAB releases reject.
legend('Mean (class 1)', 'Mean (class 2)', 'Mode (class 1)',...
    'Mode (class 2)', 'Location', 'NorthEastOutside');
% Display posterior for a data point
% Choose an interesting one with a large distance between mean and
% mode (squared Euclidean distance in latent space; only the index of the
% maximum is used below)
[distance, point] = max(sum((means-modes).^2, 2));
resp = gtmpost(net, data(point, :));

disp(' ')
disp('For more detailed information, the full posterior distribution')
disp('(or responsibility) can be plotted in latent space for a')
disp('single data point. This point has been chosen as the one')
disp('with the largest distance between mean and mode.')
disp(' ')
disp('Press any key to continue.');
pause;

% Reshape the responsibilities of the chosen point into grid form.
% (The unused reshaped coordinate grids XL and YL were removed.)
R = reshape(resp, fliplr(latent_shape));

fh2 = figure;
imagesc(net.X(:, 1), net.X(:,2), R);
hold on;
tstr = ['Responsibility for point ', num2str(point)];
title(tstr);
set(gca,'YDir','normal')
colormap(hot);
colorbar

disp(' ');
disp('Press any key to continue.')
pause
clc
disp('Finally, we visualise the data with the posterior means in')
disp('latent space as before, but superimpose the magnification')
disp('factors to highlight the separation between clusters.')
disp(' ')
disp('Note the large magnification factors down the centre of the')
disp('graph, showing that the manifold is stretched more in')
disp('this region than within each of the two clusters.')

% Class 1 is drawn in green here rather than red
ClassSymbol1 = 'g.';
ClassSymbol2 = 'b.';

fh3 = figure;
% Magnification factors evaluated at the latent sample points
mags = gtmmag(net, net.X);
% Reshape into grid form
Mags = reshape(mags, fliplr(latent_shape));
imagesc(net.X(:, 1), net.X(:,2), Mags);
hold on; % Else the magnification plot disappears
title('Dataset visualisation with magnification factors')
set(gca,'YDir','normal')
colormap(hot);
colorbar

% Superimpose the posterior means, one colour per class
plot(means(labels==1,1), means(labels==1,2), ...
    ClassSymbol1, 'MarkerSize', PointSize)
plot(means(labels>1,1), means(labels>1,2), ...
    ClassSymbol2, 'MarkerSize', PointSize)

disp(' ')
disp('Press any key to exit.')
pause

% Close the demo figures. Each handle is checked first so that a window
% the user has already closed does not raise an error and skip the cleanup.
if ishandle(fh1)
    close(fh1);
end
if ishandle(fh2)
    close(fh2);
end
if ishandle(fh3)
    close(fh3);
end

% NOTE: this clears the whole workspace, not only the variables created by
% this script (existing demo behaviour, kept unchanged).
clear all;