Skip to content

Commit

Permalink
add log_f1 argument
Browse files Browse the repository at this point in the history
  • Loading branch information
ahmedbesbes committed Nov 1, 2019
1 parent b8080f6 commit ba37aa4
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,6 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
lr = optimizer.state_dict()["param_groups"][0]["lr"]

if (iter % print_every == 0) and (iter > 0):
intermediate_report = classification_report(
y_true, y_pred, output_dict=True)

f1_by_class = 'F1 Scores by class: '
for class_name in class_names:
f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |"

print("[Training - Epoch: {}], LR: {} , Iteration: {}/{} , Loss: {}, Accuracy: {}".format(
epoch + 1,
lr,
Expand All @@ -93,7 +86,16 @@ def train(model, training_generator, optimizer, criterion, epoch, writer, log_fi
losses.avg,
accuracies.avg
))
print(f1_by_class)

if bool(args.log_f1):
intermediate_report = classification_report(
y_true, y_pred, output_dict=True)

f1_by_class = 'F1 Scores by class: '
for class_name in class_names:
f1_by_class += f"{class_name} : {np.round(intermediate_report[class_name]['f1-score'], 4)} |"

print(f1_by_class)

f1_train = f1_score(y_true, y_pred, average='weighted')

Expand Down Expand Up @@ -403,6 +405,7 @@ def run(args, both_cases=False):
parser.add_argument('--workers', type=int, default=1)
parser.add_argument('--log_path', type=str, default='./logs/')
parser.add_argument('--log_every', type=int, default=100)
parser.add_argument('--log_f1', type=int, default=1, choices=[0, 1])
parser.add_argument('--flush_history', type=int,
default=1, choices=[0, 1])
parser.add_argument('--output', type=str, default='./models/')
Expand Down

0 comments on commit ba37aa4

Please sign in to comment.