@@ -73,31 +73,37 @@ def build_model(self):
7373 if self .y_dim :
7474 self .y = tf .placeholder (tf .float32 , [self .batch_size , self .y_dim ], name = 'y' )
7575
76- image_dims = [self . output_height , self . output_width , self .c_dim ]
76+ image_dims = [None , None , self .c_dim ]
7777
7878 self .inputs = tf .placeholder (
79- tf .float32 , [self .batch_size ] + image_dims name = 'real_images' )
79+ tf .float32 , [self .batch_size ] + image_dims , name = 'real_images' )
8080 self .sample_inputs = tf .placeholder (
8181 tf .float32 , [self .sample_num ] + image_dims , name = 'sample_inputs' )
8282
83- inputs = tf .image .resize_images (
84- self .inputs , [self .output_height , self .output_width ])
85- sample_inputs = tf .image .resize_images (
86- self .sample_inputs , [self .output_height , self .output_width ])
83+ if not self .is_crop :
84+ inputs = tf .image .resize_images (
85+ self .inputs , [self .output_height , self .output_width ])
86+ sample_inputs = tf .image .resize_images (
87+ self .sample_inputs , [self .output_height , self .output_width ])
88+ else :
89+ inputs = self .inputs
90+ sample_inputs = self .sample_inputs
8791
8892 self .z = tf .placeholder (
8993 tf .float32 , [None , self .z_dim ], name = 'z' )
9094 self .z_sum = histogram_summary ("z" , self .z )
9195
9296 if self .y_dim :
9397 self .G = self .generator (self .z , self .y )
94- self .D , self .D_logits = self .discriminator (self .inputs , self .y , reuse = False )
98+ self .D , self .D_logits = \
99+ self .discriminator (inputs , self .y , reuse = False )
95100
96101 self .sampler = self .sampler (self .z , self .y )
97- self .D_ , self .D_logits_ = self .discriminator (self .G , self .y , reuse = True )
102+ self .D_ , self .D_logits_ = \
103+ self .discriminator (self .G , self .y , reuse = True )
98104 else :
99105 self .G = self .generator (self .z )
100- self .D , self .D_logits = self .discriminator (self . inputs )
106+ self .D , self .D_logits = self .discriminator (inputs )
101107
102108 self .sampler = self .sampler (self .z )
103109 self .D_ , self .D_logits_ = self .discriminator (self .G , reuse = True )
@@ -144,10 +150,9 @@ def train(self, config):
144150 g_optim = tf .train .AdamOptimizer (config .learning_rate , beta1 = config .beta1 ) \
145151 .minimize (self .g_loss , var_list = self .g_vars )
146152 try :
147- tf .initialize_all_variables ().run ()
153+ tf .global_variables_initializer ().run ()
148154 except :
149- init_op = tf .global_variables_initializer ()
150- self .sess .run (init_op )
155+ tf .initialize_all_variables ().run ()
151156
152157 self .g_sum = merge_summary ([self .z_sum , self .d__sum ,
153158 self .G_sum , self .d_loss_fake_sum , self .g_loss_sum ])
@@ -198,8 +203,8 @@ def train(self, config):
198203 batch_files = data [idx * config .batch_size :(idx + 1 )* config .batch_size ]
199204 batch = [
200205 get_image (batch_file ,
201- self .image_height ,
202- self .image_width ,
206+ image_height = self .image_height ,
207+ image_width = self .image_width ,
203208 resize_height = self .output_height ,
204209 resize_width = self .output_width ,
205210 is_crop = self .is_crop ,
@@ -263,8 +268,8 @@ def train(self, config):
263268 feed_dict = { self .z : batch_z })
264269 self .writer .add_summary (summary_str , counter )
265270
266- errD_fake = self .d_loss_fake .eval ({self .z : batch_z })
267- errD_real = self .d_loss_real .eval ({self .inputs : batch_images })
271+ errD_fake = self .d_loss_fake .eval ({ self .z : batch_z })
272+ errD_real = self .d_loss_real .eval ({ self .inputs : batch_images })
268273 errG = self .g_loss .eval ({self .z : batch_z })
269274
270275 counter += 1
@@ -325,10 +330,10 @@ def discriminator(self, image, y=None, reuse=False):
325330
326331 h1 = lrelu (self .d_bn1 (conv2d (h0 , self .df_dim + self .y_dim , name = 'd_h1_conv' )))
327332 h1 = tf .reshape (h1 , [self .batch_size , - 1 ])
328- h1 = tf .concat ( 1 , [h1 , y ])
333+ h1 = tf .concat_v2 ( [h1 , y ], 1 )
329334
330335 h2 = lrelu (self .d_bn2 (linear (h1 , self .dfc_dim , 'd_h2_lin' )))
331- h2 = tf .concat ( 1 , [h2 , y ])
336+ h2 = tf .concat_v2 ( [h2 , y ], 1 )
332337
333338 h3 = linear (h2 , 1 , 'd_h3_lin' )
334339
@@ -337,100 +342,114 @@ def discriminator(self, image, y=None, reuse=False):
337342 def generator (self , z , y = None ):
338343 with tf .variable_scope ("generator" ) as scope :
339344 if not self .y_dim :
340- s = self .output_size
341- s2 , s4 , s8 , s16 = int (s / 2 ), int (s / 4 ), int (s / 8 ), int (s / 16 )
345+ s_h , s_w = self .output_height , self .output_width
346+ s_h2 , s_h4 , s_h8 , s_h16 = \
347+ int (s_h / 2 ), int (s_h / 4 ), int (s_h / 8 ), int (s_h / 16 )
348+ s_w2 , s_w4 , s_w8 , s_w16 = \
349+ int (s_w / 2 ), int (s_w / 4 ), int (s_w / 8 ), int (s_w / 16 )
342350
343351 # project `z` and reshape
344- self .z_ , self .h0_w , self .h0_b = linear (z , self .gf_dim * 8 * s16 * s16 , 'g_h0_lin' , with_w = True )
352+ self .z_ , self .h0_w , self .h0_b = linear (
353+ z , self .gf_dim * 8 * s_h16 * s_w16 , 'g_h0_lin' , with_w = True )
345354
346- self .h0 = tf .reshape (self .z_ , [- 1 , s16 , s16 , self .gf_dim * 8 ])
355+ self .h0 = tf .reshape (
356+ self .z_ , [- 1 , s_h16 , s_w16 , self .gf_dim * 8 ])
347357 h0 = tf .nn .relu (self .g_bn0 (self .h0 ))
348358
349- self .h1 , self .h1_w , self .h1_b = deconv2d (h0 ,
350- [self .batch_size , s8 , s8 , self .gf_dim * 4 ], name = 'g_h1' , with_w = True )
359+ self .h1 , self .h1_w , self .h1_b = deconv2d (
360+ h0 , [self .batch_size , s_h8 , s_w8 , self .gf_dim * 4 ], name = 'g_h1' , with_w = True )
351361 h1 = tf .nn .relu (self .g_bn1 (self .h1 ))
352362
353- h2 , self .h2_w , self .h2_b = deconv2d (h1 ,
354- [self .batch_size , s4 , s4 , self .gf_dim * 2 ], name = 'g_h2' , with_w = True )
363+ h2 , self .h2_w , self .h2_b = deconv2d (
364+ h1 , [self .batch_size , s_h4 , s_w4 , self .gf_dim * 2 ], name = 'g_h2' , with_w = True )
355365 h2 = tf .nn .relu (self .g_bn2 (h2 ))
356366
357- h3 , self .h3_w , self .h3_b = deconv2d (h2 ,
358- [self .batch_size , s2 , s2 , self .gf_dim * 1 ], name = 'g_h3' , with_w = True )
367+ h3 , self .h3_w , self .h3_b = deconv2d (
368+ h2 , [self .batch_size , s_h2 , s_w2 , self .gf_dim * 1 ], name = 'g_h3' , with_w = True )
359369 h3 = tf .nn .relu (self .g_bn3 (h3 ))
360370
361- h4 , self .h4_w , self .h4_b = deconv2d (h3 ,
362- [self .batch_size , s , s , self .c_dim ], name = 'g_h4' , with_w = True )
371+ h4 , self .h4_w , self .h4_b = deconv2d (
372+ h3 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h4' , with_w = True )
363373
364374 return tf .nn .tanh (h4 )
365375 else :
366- s = self .output_size
367- s2 , s4 = int (s / 2 ), int (s / 4 )
376+ s_h , s_w = self .output_height , self .output_width
377+ s_h2 , s_h4 = int (s_h / 2 ), int (s_h / 4 )
378+ s_w2 , s_w4 = int (s_w / 2 ), int (s_w / 4 )
368379
369380 # yb = tf.expand_dims(tf.expand_dims(y, 1),2)
370381 yb = tf .reshape (y , [self .batch_size , 1 , 1 , self .y_dim ])
371- z = tf .concat ( 1 , [z , y ])
382+ z = tf .concat_v2 ( [z , y ], 1 )
372383
373- h0 = tf .nn .relu (self .g_bn0 (linear (z , self .gfc_dim , 'g_h0_lin' )))
374- h0 = tf .concat (1 , [h0 , y ])
384+ h0 = tf .nn .relu (
385+ self .g_bn0 (linear (z , self .gfc_dim , 'g_h0_lin' )))
386+ h0 = tf .concat_v2 ([h0 , y ], 1 )
375387
376- h1 = tf .nn .relu (self .g_bn1 (linear (h0 , self .gf_dim * 2 * s4 * s4 , 'g_h1_lin' )))
377- h1 = tf .reshape (h1 , [self .batch_size , s4 , s4 , self .gf_dim * 2 ])
388+ h1 = tf .nn .relu (self .g_bn1 (
389+ linear (h0 , self .gf_dim * 2 * s_h4 * s_w4 , 'g_h1_lin' )))
390+ h1 = tf .reshape (h1 , [self .batch_size , s_h4 , s_w4 , self .gf_dim * 2 ])
378391
379392 h1 = conv_cond_concat (h1 , yb )
380393
381394 h2 = tf .nn .relu (self .g_bn2 (deconv2d (h1 ,
382- [self .batch_size , s2 , s2 , self .gf_dim * 2 ], name = 'g_h2' )))
395+ [self .batch_size , s_h2 , s_w2 , self .gf_dim * 2 ], name = 'g_h2' )))
383396 h2 = conv_cond_concat (h2 , yb )
384397
385398 return tf .nn .sigmoid (
386- deconv2d (h2 , [self .batch_size , s , s , self .c_dim ], name = 'g_h3' ))
399+ deconv2d (h2 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h3' ))
387400
388401 def sampler (self , z , y = None ):
389402 with tf .variable_scope ("generator" ) as scope :
390403 scope .reuse_variables ()
391404
392405 if not self .y_dim :
393406
394- s = self .output_size
395- s2 , s4 , s8 , s16 = int (s / 2 ), int (s / 4 ), int (s / 8 ), int (s / 16 )
407+ s_h , s_w = self .output_height , self .output_width
408+ s_h2 , s_h4 , s_h8 , s_h16 = \
409+ int (s_h / 2 ), int (s_h / 4 ), int (s_h / 8 ), int (s_h / 16 )
410+ s_w2 , s_w4 , s_w8 , s_w16 = \
411+ int (s_w / 2 ), int (s_w / 4 ), int (s_w / 8 ), int (s_w / 16 )
396412
397413 # project `z` and reshape
398- h0 = tf .reshape (linear (z , self .gf_dim * 8 * s16 * s16 , 'g_h0_lin' ),
399- [- 1 , s16 , s16 , self .gf_dim * 8 ])
414+ h0 = tf .reshape (
415+ linear (z , self .gf_dim * 8 * s_h16 * s_w16 , 'g_h0_lin' ),
416+ [- 1 , s_h16 , s_w16 , self .gf_dim * 8 ])
400417 h0 = tf .nn .relu (self .g_bn0 (h0 , train = False ))
401418
402- h1 = deconv2d (h0 , [self .batch_size , s8 , s8 , self .gf_dim * 4 ], name = 'g_h1' )
419+ h1 = deconv2d (h0 , [self .batch_size , s_h8 , s_w8 , self .gf_dim * 4 ], name = 'g_h1' )
403420 h1 = tf .nn .relu (self .g_bn1 (h1 , train = False ))
404421
405- h2 = deconv2d (h1 , [self .batch_size , s4 , s4 , self .gf_dim * 2 ], name = 'g_h2' )
422+ h2 = deconv2d (h1 , [self .batch_size , s_h4 , s_w4 , self .gf_dim * 2 ], name = 'g_h2' )
406423 h2 = tf .nn .relu (self .g_bn2 (h2 , train = False ))
407424
408- h3 = deconv2d (h2 , [self .batch_size , s2 , s2 , self .gf_dim * 1 ], name = 'g_h3' )
425+ h3 = deconv2d (h2 , [self .batch_size , s_h2 , s_w2 , self .gf_dim * 1 ], name = 'g_h3' )
409426 h3 = tf .nn .relu (self .g_bn3 (h3 , train = False ))
410427
411- h4 = deconv2d (h3 , [self .batch_size , s , s , self .c_dim ], name = 'g_h4' )
428+ h4 = deconv2d (h3 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h4' )
412429
413430 return tf .nn .tanh (h4 )
414431 else :
415- s = self .output_size
416- s2 , s4 = int (s / 2 ), int (s / 4 )
432+ s_h , s_w = self .output_height , self .output_width
433+ s_h2 , s_h4 = int (s_h / 2 ), int (s_h / 4 )
434+ s_w2 , s_w4 = int (s_w / 2 ), int (s_w / 4 )
417435
418436 # yb = tf.reshape(y, [-1, 1, 1, self.y_dim])
419437 yb = tf .reshape (y , [self .batch_size , 1 , 1 , self .y_dim ])
420- z = tf .concat ( 1 , [z , y ])
438+ z = tf .concat_v2 ( [z , y ], 1 )
421439
422440 h0 = tf .nn .relu (self .g_bn0 (linear (z , self .gfc_dim , 'g_h0_lin' )))
423- h0 = tf .concat ( 1 , [h0 , y ])
441+ h0 = tf .concat_v2 ( [h0 , y ], 1 )
424442
425- h1 = tf .nn .relu (self .g_bn1 (linear (h0 , self .gf_dim * 2 * s4 * s4 , 'g_h1_lin' ), train = False ))
426- h1 = tf .reshape (h1 , [self .batch_size , s4 , s4 , self .gf_dim * 2 ])
443+ h1 = tf .nn .relu (self .g_bn1 (
444+ linear (h0 , self .gf_dim * 2 * s_h4 * s_w4 , 'g_h1_lin' ), train = False ))
445+ h1 = tf .reshape (h1 , [self .batch_size , s_h4 , s_w4 , self .gf_dim * 2 ])
427446 h1 = conv_cond_concat (h1 , yb )
428447
429448 h2 = tf .nn .relu (self .g_bn2 (
430- deconv2d (h1 , [self .batch_size , s2 , s2 , self .gf_dim * 2 ], name = 'g_h2' ), train = False ))
449+ deconv2d (h1 , [self .batch_size , s_h2 , s_w2 , self .gf_dim * 2 ], name = 'g_h2' ), train = False ))
431450 h2 = conv_cond_concat (h2 , yb )
432451
433- return tf .nn .sigmoid (deconv2d (h2 , [self .batch_size , s , s , self .c_dim ], name = 'g_h3' ))
452+ return tf .nn .sigmoid (deconv2d (h2 , [self .batch_size , s_h , s_w , self .c_dim ], name = 'g_h3' ))
434453
435454 def load_mnist (self ):
436455 data_dir = os .path .join ("./data" , self .dataset_name )
@@ -468,11 +487,16 @@ def load_mnist(self):
468487 y_vec [i ,y [i ]] = 1.0
469488
470489 return X / 255. ,y_vec
490+
491+ @property
492+ def model_dir (self ):
493+ return "{}_{}_{}_{}" .format (
494+ self .dataset_name , self .batch_size ,
495+ self .output_height , self .output_width )
471496
472497 def save (self , checkpoint_dir , step ):
473498 model_name = "DCGAN.model"
474- model_dir = "%s_%s_%s" % (self .dataset_name , self .batch_size , self .output_size )
475- checkpoint_dir = os .path .join (checkpoint_dir , model_dir )
499+ checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
476500
477501 if not os .path .exists (checkpoint_dir ):
478502 os .makedirs (checkpoint_dir )
@@ -483,9 +507,7 @@ def save(self, checkpoint_dir, step):
483507
484508 def load (self , checkpoint_dir ):
485509 print (" [*] Reading checkpoints..." )
486-
487- model_dir = "%s_%s_%s" % (self .dataset_name , self .batch_size , self .output_size )
488- checkpoint_dir = os .path .join (checkpoint_dir , model_dir )
510+ checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
489511
490512 ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
491513 if ckpt and ckpt .model_checkpoint_path :
0 commit comments