diff --git a/code/keras_classification_train.py b/code/keras_classification_train.py index da985d3..5f80864 100755 --- a/code/keras_classification_train.py +++ b/code/keras_classification_train.py @@ -26,7 +26,7 @@ parser.add_argument('--batch-size', type=int, default=32) parser.add_argument('--epochs', type=int, default=20) parser.add_argument('--optimizer', type=str, choices=['sgd', 'adam'], default='adam') -parser.add_argument('--lr-adjust-freq' , type=int, default=10, help='How many epochs per LR adjustment (*=0.1)') +parser.add_argument('--lr-adjust-freq', type=int, default=10, help='How many epochs per LR adjustment (*=0.1)') # For Adam: # Should use lr = 0.001 # For SGD: @@ -46,19 +46,21 @@ num_classes = 2 if args.label == 'gender' else 4 gen_to_cls = {'male': 0, 'female': 1} if args.label == 'gender' else {'noble': 0, 'warrior': 1, 'incarnation': 2, 'commoner': 3} + def lr_scheduler(epoch): lr = args.lr * (0.1**(epoch // args.lr_adjust_freq)) return lr + class ModelSequence(k.utils.Sequence): def __init__(self, df, batch_size): self.x = df['image'].apply(lambda x: f'{image_dir}/{x}').values self.u = df[args.label].values self.batch_size = batch_size - + def __len__(self): return int(np.ceil(len(self.x) / float(self.batch_size))) - + def load(self, fn): img = lycon.load(fn) img = lycon.resize(img, args.image_size, args.image_size) @@ -80,6 +82,7 @@ def __getitem__(self, idx): seq_train = ModelSequence(df[df.set == 'train'], batch_size=args.batch_size) seq_valid = ModelSequence(df[df.set == 'dev'], batch_size=args.batch_size) + def build_model(): input_tensor = Input(shape=(args.image_size, args.image_size, 3)) @@ -89,7 +92,6 @@ def build_model(): # input_shape=(args.image_size, args.image_size, 3), classes=num_classes, pooling='avg') - output_tensor = Dense(num_classes, activation='softmax')(base_model.output) model = Model(inputs=input_tensor, outputs=output_tensor) @@ -110,7 +112,7 @@ def build_model(): model.fit_generator(seq_train, epochs=args.epochs, verbose=1, validation_data=seq_valid, callbacks=[LearningRateScheduler(lr_scheduler)]) -print('Dev set: ' , model.evaluate_generator(seq_valid, verbose=1)) +print('Dev set: ', model.evaluate_generator(seq_valid, verbose=1)) seq_test = ModelSequence(df[df.set == 'test'], batch_size=args.batch_size) -print('Test set: ' , model.evaluate_generator(seq_test, verbose=1)) +print('Test set: ', model.evaluate_generator(seq_test, verbose=1)) diff --git a/download.py b/download.py index 082ff84..b84f747 100755 --- a/download.py +++ b/download.py @@ -31,6 +31,7 @@ # ssl._create_default_https_context = ssl._create_unverified_context # dirty fix script_dir = dirname(realpath(__file__)) + def load_urls(urls_file): urls = list(open(urls_file).readlines()) urls = [_.strip('\r\n') for _ in urls] # strip linebreaks @@ -38,6 +39,7 @@ def load_urls(urls_file): iurls = [(index, url) for index, url in enumerate(urls) if url] return iurls + def download_and_check_image(iurl): index, url = iurl save_file = join(images_dir, '%08d.jpg' % index) @@ -65,7 +67,7 @@ def download_and_check_image(iurl): redownloading_warning = True except KeyboardInterrupt: print('KeyboardInterrupt: Exiting early!') - sys.exit(130) # Avoid humungous backtraces when ctrl+c is pressed + sys.exit(130) # Avoid humungous backtraces when ctrl+c is pressed except Exception as e: print('Download failed with {} for {}-th image from {}'.format(index, e, url)) @@ -113,11 +115,11 @@ def download_and_check_image(iurl): pool = multiprocessing.Pool(args.threads) - if tqdm: # Use tqdm progressbar + if tqdm: # Use tqdm progressbar bar = tqdm(total=len(iurls)) for i, _ in enumerate(pool.imap_unordered(download_and_check_image, iurls)): bar.update() - else: # Use a basic status print + else: # Use a basic status print for i, _ in enumerate(pool.imap_unordered(download_and_check_image, iurls)): print('Download images: %7d / %d Done' % (i + 1, len(iurls)), end='\r', flush=True) print()