88"""
99
1010
11- import numpy , os
12- #import data_provider
11+ import numpy , os , cPickle
1312from PIL import Image
1413
14+ def load_mnist ():
15+ path = '.'
16+ data = cPickle .load (open (os .path .join (path ,'mnist.pkl' ), 'r' ))
17+ return data
18+
1519def scale_to_unit_interval (ndar , eps = 1e-8 ):
1620 """ Scales all values in the ndarray ndar to be between 0 and 1 """
1721 ndar = ndar .copy ()
@@ -137,8 +141,8 @@ def tile_raster_images(X, img_shape, tile_shape, tile_spacing=(0, 0),
137141 return out_array
138142
139143def visualize_mnist ():
140- train , _ , _ , _ , _ , _ = data_provider . load_mnist ()
141- design_matrix = train
144+ ( train_X , train_Y ), ( valid_X , valid_Y ), ( test_X , test_Y ) = load_mnist ()
145+ design_matrix = train_X
142146 images = design_matrix [0 :2500 , :]
143147 channel_length = 28 * 28
144148 to_visualize = images
@@ -150,195 +154,9 @@ def visualize_mnist():
150154 im_new = Image .fromarray (numpy .uint8 (image_data ))
151155 im_new .save ('samples_mnist.png' )
152156 os .system ('eog samples_mnist.png' )
153-
154- def visualize_cifar10 ():
155- import cifar10_wrapper
156- train , test = cifar10_wrapper .get_cifar10_raw ()
157-
158-
159- images = train .X [0 :2500 , :]
160- channel_length = 32 * 32
161- to_visualize = (images [:, 0 :channel_length ],
162- images [:, channel_length :channel_length * 2 ],
163- images [:,channel_length * 2 :channel_length * 3 ],
164- None )
165-
166- image_data = tile_raster_images (to_visualize ,
167- img_shape = [32 ,32 ],
168- tile_shape = [50 ,50 ],
169- tile_spacing = (2 ,2 ))
170- im_new = Image .fromarray (numpy .uint8 (image_data ))
171- im_new .save ('samples_cifar10.png' )
172- os .system ('eog samples_cifar10.png' )
173-
174- def visualize_weight_matrix_single_channel (img_shape , tile_shape , to_visualize ):
175- # the weights learned from the black-white image, e.g., MNIST
176- # W: inputs by hidden
177- # img_shape = [28,28]
178- # tile_shape = [10,10]
179-
180- image_data = tile_raster_images (to_visualize ,
181- img_shape = img_shape ,
182- tile_shape = tile_shape ,
183- tile_spacing = (2 ,2 ))
184-
185- im_new = Image .fromarray (numpy .uint8 (image_data ))
186- im_new .save ('l0_weights.png' )
187- #im_new.save('ica_weights.png')
188- #os.system('eog ica_weights.png')
189-
190- def visualize_convNet_weights (params ):
191- W = params [- 4 ]
192- a ,b ,c ,d = W .shape
193- to_visualize = W .reshape ((a * b , c * d ))
194- image_data = tile_raster_images (to_visualize ,
195- img_shape = [c ,d ],
196- tile_shape = [10 ,5 ],
197- tile_spacing = (2 ,2 ))
198- im_new = Image .fromarray (numpy .uint8 (image_data ))
199- im_new .save ('convNet_weights.png' )
200-
201- def visualize_first_layer_weights (W , dataset_name = None ):
202- imgs = W .T
203- if dataset_name == 'MNIST' :
204- img_shape = [28 ,28 ]
205- to_visualize = imgs
206- elif dataset_name == 'TFD_unsupervised' :
207- img_shape = [48 ,48 ]
208- to_visualize = imgs
209-
210- elif dataset_name == 'CIFAR10' :
211- img_shape = [32 ,32 ]
212- channel_length = 32 * 32
213- to_visualize = (imgs [:, 0 :channel_length ],
214- imgs [:, channel_length :channel_length * 2 ],
215- imgs [:,channel_length * 2 :channel_length * 3 ],
216- None )
217- else :
218- raise NotImplementedError ('%s does not support visulization of W' % self .dataset_name )
219-
220- t = int (numpy .ceil (numpy .sqrt (W .shape [1 ])))
221-
222- tile_shape = [t ,t ]
223-
224- visualize_weight_matrix_single_channel (img_shape ,
225- tile_shape , to_visualize )
226-
227- def visualize_reconstruction_quality_ae (x , x_tilde , x_reconstructed , image_shape ):
228- # to visualize the reconstruction quality of MNIST on DAE
229- assert x .shape == x_tilde .shape
230- assert x_tilde .shape == x_reconstructed .shape
231-
232- n_show = 400
233- tile_shape = [20 , 20 ]
234- tile_spacing = (2 ,2 )
235- image_shape = image_shape
236-
237- idx = range (x .shape [0 ])
238- numpy .random .shuffle (idx )
239-
240- use = idx [:n_show ]
241- channel_length = image_shape
242- to_visualize = x [use ]
243- image_data = tile_raster_images (to_visualize ,
244- img_shape = image_shape ,
245- tile_shape = tile_shape ,
246- tile_spacing = tile_spacing )
247-
248- to_visualize = x_tilde [use ]
249- image_corrupted = tile_raster_images (to_visualize ,
250- img_shape = image_shape ,
251- tile_shape = tile_shape ,
252- tile_spacing = tile_spacing )
253-
254- to_visualize = x_reconstructed [use ]
255- image_reconstructed = tile_raster_images (to_visualize ,
256- img_shape = image_shape ,
257- tile_shape = tile_shape ,
258- tile_spacing = tile_spacing )
259-
260- vertical_bar = numpy .zeros ((image_data .shape [0 ], 5 ))
261- vertical_bar [:,2 ] += 255
262-
263- image = numpy .concatenate ((image_data , vertical_bar , image_corrupted ,
264- vertical_bar , image_reconstructed ), axis = 1 )
265-
266- im_new = Image .fromarray (numpy .uint8 (image ))
267- #im_new.save('reconstruction_mnist.png')
268- #os.system('eog reconstruction_mnist.png')
269- return im_new
270-
271- def visualize_gibbs_chain (data , samples , x_noisy , x_reconstruct , jumps , image_shape ):
272- # jumps is a binary matrix
273- # randomly pick to visualize
274- #assert data.shape == samples.shape
275-
276- n_show = 400
277- tile_shape = [20 , 20 ]
278- tile_spacing = (2 ,2 )
279- image_shape = image_shape
280-
281- idx = range (data .shape [0 ])
282- numpy .random .shuffle (idx )
283-
284- use = idx [:n_show ]
285-
286- to_visualize = data [use ]
287- image_data = tile_raster_images (to_visualize ,
288- img_shape = image_shape ,
289- tile_shape = tile_shape ,
290- tile_spacing = tile_spacing )
291-
292- use = range (n_show )
293-
294- to_visualize = samples [use ]
295- image_1 = tile_raster_images (to_visualize ,
296- img_shape = image_shape ,
297- tile_shape = tile_shape ,
298- tile_spacing = tile_spacing )
299- to_visualize = x_noisy [use ]
300- image_2 = tile_raster_images (to_visualize ,
301- img_shape = image_shape ,
302- tile_shape = tile_shape ,
303- tile_spacing = tile_spacing )
304- to_visualize = x_reconstruct [use ]
305- image_3 = tile_raster_images (to_visualize ,
306- img_shape = image_shape ,
307- tile_shape = tile_shape ,
308- tile_spacing = tile_spacing )
309- # now masking those intermediate steps in the chain
310- jumps = jumps .flatten ()
311- jumps [0 ] = 0
312- mask = numpy .zeros ((n_show )) != 0
313- for idx , jump in enumerate (jumps ):
314- if jumps [idx ] == 1 and jumps [idx - 1 ]== 0 :
315- mask [idx - 1 ] = True
316-
317- to_visualize = numpy .zeros (x_reconstruct .shape )
318- for i , m in enumerate (mask ):
319- if m :
320- to_visualize [i ] = x_reconstruct [i ]
321- image_4 = tile_raster_images (to_visualize ,
322- img_shape = image_shape ,
323- tile_shape = tile_shape ,
324- tile_spacing = tile_spacing )
325-
326- vertical_bar = numpy .zeros ((image_data .shape [0 ], 5 ))
327- vertical_bar [:,2 ] += 255
328-
329- image = numpy .concatenate ((image_data , vertical_bar , image_1 ,
330- vertical_bar , image_2 , vertical_bar ,
331- image_3 , vertical_bar , image_4 ), axis = 1 )
332- im_new = Image .fromarray (numpy .uint8 (image ))
333- #im_new.save('samples_mnist%.png')
334- #os.system('eog samples_mnist.png')
335- return im_new
336157
337158if __name__ == '__main__' :
338- #visualize_mnist()
339- #visualize_cifar10()
340- #W = RAB_tools.load_pkl('convNet_saved_params.pkl')
341- #visualize_convNet_weights(W)
342- visualize_weight_matrix_single_channel (W )
159+ visualize_mnist ()
160+
343161
344162
0 commit comments