 
 from __future__ import print_function
 import argparse
-import sys
-import time
 
 import torch
 import torch.optim as optim
@@ -57,8 +55,7 @@ def train(model, data_loader, optimizer, epoch):
         optimizer.step()
 
         if batch_idx % args.log_interval == 0:
-            mesg = '{}\tEpoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
-                time.ctime(),
+            mesg = 'Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                 epoch,
                 batch_idx * len(data),
                 len(data_loader.dataset),
@@ -87,7 +84,7 @@ def test(model, data_loader):
     for data, target in data_loader:
         target_indices = target
         target_one_hot = utils.one_hot_encode(
-            target_indices, length=model.digits.num_units)
+            target_indices, length=model.digits.num_unit)
 
         data, target = Variable(data, volatile=True), Variable(target_one_hot)
 
@@ -133,12 +130,12 @@ def main():
                         default=128, help='testing batch size. default=128')
     parser.add_argument('--loss-threshold', type=float, default=0.0001,
                         help='stop training if loss goes below this threshold. default=0.0001')
-    parser.add_argument("--log-interval", type=int, default=1,
-                        help='number of images after which the training loss is logged, default is 1')
-    parser.add_argument('--cuda', action='store_true',
-                        help='set it to 1 for running on GPU, 0 for CPU')
+    parser.add_argument('--log-interval', type=int, default=10,
+                        help='how many batches to wait before logging training status, default=10')
+    parser.add_argument('--no-cuda', action='store_true', default=False,
+                        help='disables CUDA training, default=false')
     parser.add_argument('--threads', type=int, default=4,
-                        help='number of threads for data loader to use')
+                        help='number of threads for data loader to use, default=4')
     parser.add_argument('--seed', type=int, default=42,
                         help='random seed for training. default=42')
     parser.add_argument('--num-conv-channel', type=int, default=256,
@@ -149,20 +146,18 @@ def main():
                         default=1152, help='primary unit size. default=1152')
     parser.add_argument('--output-unit-size', type=int,
                         default=16, help='output unit size. default=16')
+    parser.add_argument('--num-routing', type=int,
+                        default=3, help='number of routing iteration. default=3')
 
     args = parser.parse_args()
 
     print(args)
 
     # Check GPU or CUDA is available
-    cuda = args.cuda
-    if cuda and not torch.cuda.is_available():
-        print(
-            "ERROR: No GPU/cuda is not available. Try running on CPU or run without --cuda")
-        sys.exit(1)
+    args.cuda = not args.no_cuda and torch.cuda.is_available()
 
     torch.manual_seed(args.seed)
-    if cuda:
+    if args.cuda:
         torch.cuda.manual_seed(args.seed)
 
     # Load data
@@ -174,10 +169,11 @@ def main():
                 num_primary_unit=args.num_primary_unit,
                 primary_unit_size=args.primary_unit_size,
                 output_unit_size=args.output_unit_size,
-                cuda=args.cuda)
+                num_routing=args.num_routing,
+                cuda_enabled=args.cuda)
 
-    if cuda:
-        model = model.cuda()
+    if args.cuda:
+        model.cuda()
 
     optimizer = optim.Adam(model.parameters(), lr=args.lr)
 
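
For reference, below is a minimal standalone sketch of the CUDA-flag pattern this diff adopts: CUDA is requested unless --no-cuda is passed, and is actually used only when torch.cuda.is_available() reports a GPU. The parse_args wrapper, the description string, and the __main__ scaffolding are illustrative additions, not part of the PR.

# Minimal standalone sketch (not part of the PR) of the --no-cuda pattern above.
from __future__ import print_function

import argparse

import torch


def parse_args():
    parser = argparse.ArgumentParser(
        description='CapsNet training (illustrative subset of options)')
    parser.add_argument('--no-cuda', action='store_true', default=False,
                        help='disables CUDA training, default=false')
    parser.add_argument('--seed', type=int, default=42,
                        help='random seed for training. default=42')
    args = parser.parse_args()
    # Derive a single boolean the rest of the script can rely on:
    # CUDA is used only when it is both requested and available.
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    return args


if __name__ == '__main__':
    args = parse_args()
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
    print(args)

Running the sketch with --no-cuda, or on a machine without a GPU, leaves args.cuda False, so the CUDA-specific seeding is simply skipped instead of aborting the script.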