Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some codes follow PEP8 #3

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions code/keras_classification_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
parser.add_argument('--batch-size', type=int, default=32)
parser.add_argument('--epochs', type=int, default=20)
parser.add_argument('--optimizer', type=str, choices=['sgd', 'adam'], default='adam')
parser.add_argument('--lr-adjust-freq' , type=int, default=10, help='How many epochs per LR adjustment (*=0.1)')
parser.add_argument('--lr-adjust-freq', type=int, default=10, help='How many epochs per LR adjustment (*=0.1)')
# For Adam:
# Should use lr = 0.001
# For SGD:
Expand All @@ -46,19 +46,21 @@
num_classes = 2 if args.label == 'gender' else 4
gen_to_cls = {'male': 0, 'female': 1} if args.label == 'gender' else {'noble': 0, 'warrior': 1, 'incarnation': 2, 'commoner': 3}


def lr_scheduler(epoch):
    """Return the learning rate to use for ``epoch`` (step decay).

    The base rate ``args.lr`` is scaled by a factor of 0.1 once for every
    ``args.lr_adjust_freq`` epochs that have fully elapsed.
    """
    # Integer division counts how many whole adjustment periods have passed.
    completed_adjustments = epoch // args.lr_adjust_freq
    return args.lr * (0.1**completed_adjustments)


class ModelSequence(k.utils.Sequence):
def __init__(self, df, batch_size):
self.x = df['image'].apply(lambda x: f'{image_dir}/{x}').values
self.u = df[args.label].values
self.batch_size = batch_size

def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))

def load(self, fn):
img = lycon.load(fn)
img = lycon.resize(img, args.image_size, args.image_size)
Expand All @@ -80,6 +82,7 @@ def __getitem__(self, idx):
seq_train = ModelSequence(df[df.set == 'train'], batch_size=args.batch_size)
seq_valid = ModelSequence(df[df.set == 'dev'], batch_size=args.batch_size)


def build_model():
input_tensor = Input(shape=(args.image_size, args.image_size, 3))

Expand All @@ -89,7 +92,6 @@ def build_model():
# input_shape=(args.image_size, args.image_size, 3),
classes=num_classes,
pooling='avg')


output_tensor = Dense(num_classes, activation='softmax')(base_model.output)
model = Model(inputs=input_tensor, outputs=output_tensor)
Expand All @@ -110,7 +112,7 @@ def build_model():
model.fit_generator(seq_train, epochs=args.epochs, verbose=1, validation_data=seq_valid,
callbacks=[LearningRateScheduler(lr_scheduler)])

print('Dev set: ' , model.evaluate_generator(seq_valid, verbose=1))
print('Dev set: ', model.evaluate_generator(seq_valid, verbose=1))

seq_test = ModelSequence(df[df.set == 'test'], batch_size=args.batch_size)
print('Test set: ' , model.evaluate_generator(seq_test, verbose=1))
print('Test set: ', model.evaluate_generator(seq_test, verbose=1))
8 changes: 5 additions & 3 deletions download.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,15 @@
# ssl._create_default_https_context = ssl._create_unverified_context # dirty fix
script_dir = dirname(realpath(__file__))


def load_urls(urls_file):
    """Read a URL list file and return ``(index, url)`` pairs.

    Args:
        urls_file: Path to a text file containing one URL per line.

    Returns:
        A list of ``(index, url)`` tuples, one per non-empty line, where
        ``index`` is the zero-based line number in the file.
    """
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(urls_file) as url_fh:
        urls = [line.strip('\r\n') for line in url_fh]  # strip linebreaks
    # Drop empty lines without shifting line indices, which determine filenames.
    return [(index, url) for index, url in enumerate(urls) if url]


def download_and_check_image(iurl):
index, url = iurl
save_file = join(images_dir, '%08d.jpg' % index)
Expand Down Expand Up @@ -65,7 +67,7 @@ def download_and_check_image(iurl):
redownloading_warning = True
except KeyboardInterrupt:
print('KeyboardInterrupt: Exiting early!')
sys.exit(130) # Avoid humungous backtraces when ctrl+c is pressed
sys.exit(130) # Avoid humungous backtraces when ctrl+c is pressed
except Exception as e:
print('Download failed with {} for {}-th image from {}'.format(index, e, url))

Expand Down Expand Up @@ -113,11 +115,11 @@ def download_and_check_image(iurl):

pool = multiprocessing.Pool(args.threads)

if tqdm: # Use tqdm progressbar
if tqdm: # Use tqdm progressbar
bar = tqdm(total=len(iurls))
for i, _ in enumerate(pool.imap_unordered(download_and_check_image, iurls)):
bar.update()
else: # Use a basic status print
else: # Use a basic status print
for i, _ in enumerate(pool.imap_unordered(download_and_check_image, iurls)):
print('Download images: %7d / %d Done' % (i + 1, len(iurls)), end='\r', flush=True)
print()
Expand Down