|
9 | 9 | import os |
10 | 10 | import sys |
11 | 11 | from argparse import Namespace |
12 | | -from time import time |
13 | | -from typing import Iterable, Tuple |
| 12 | +from typing import Iterable |
14 | 13 | import logging |
15 | 14 | import torch |
16 | 15 | from tqdm import tqdm |
|
23 | 22 |
|
24 | 23 | from utils import disable_logging |
25 | 24 | from utils.checkpoints import mammoth_load_checkpoint, save_mammoth_checkpoint |
26 | | -from utils.loggers import log_extra_metrics, log_accs, Logger |
| 25 | +from utils.loggers import log_extra_metrics, Logger |
27 | 26 | from utils.schedulers import get_scheduler |
28 | 27 | from utils.stats import track_system_stats |
29 | 28 |
|
@@ -80,10 +79,7 @@ def train_single_epoch(model: ContinualModel, |
80 | 79 | the number of iterations performed in the current epoch |
81 | 80 | """ |
82 | 81 | train_iter = iter(train_loader) |
83 | | - epoch_len = len(train_loader) if hasattr(train_loader, "__len__") else None |
84 | | - |
85 | 82 | i = 0 |
86 | | - previous_time = time() |
87 | 83 |
|
88 | 84 | while True: |
89 | 85 | try: |
@@ -116,13 +112,7 @@ def train_single_epoch(model: ContinualModel, |
116 | 112 | system_tracker() |
117 | 113 | i += 1 |
118 | 114 |
|
119 | | - time_diff = time() - previous_time |
120 | | - previous_time = time() |
121 | | - bar_log = {'loss': loss, 'lr': model.opt.param_groups[0]['lr']} |
122 | | - if epoch_len: |
123 | | - ep_h = 3600 / (epoch_len * time_diff) |
124 | | - bar_log['ep/h'] = ep_h |
125 | | - pbar.set_postfix(bar_log, refresh=False) |
| 115 | + pbar.set_postfix({'loss': loss, 'lr': model.opt.param_groups[0]['lr']}, refresh=False) |
126 | 116 | pbar.update() |
127 | 117 |
|
128 | 118 | if scheduler is not None and args.scheduler_mode == 'epoch': |
@@ -207,8 +197,8 @@ def train(model: ContinualModel, dataset: ContinualDataset, |
207 | 197 | random_res_class, random_res_task = dataset.evaluate(model, dataset, last=True) # the ugliness of this line is for backward compatibility |
208 | 198 | random_results_class.append(random_res_class) |
209 | 199 | random_results_task.append(random_res_task) |
210 | | - except Exception as e: |
211 | | - logging.info(f"Could not evaluate before `begin_task`, will try after") |
| 200 | + except Exception: |
| 201 | + logging.info("Could not evaluate before `begin_task`, will try after") |
212 | 202 | # will try after the begin_task in case the model needs to setup something |
213 | 203 | can_compute_fwd_beforetask = False |
214 | 204 |
|
|
0 commit comments