From 3dbbb441d6d41cbf1949e7b4f80ca23ea4ad5838 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Mon, 16 May 2022 17:09:08 +0530 Subject: [PATCH 1/2] remove manual no_grad and train mode --- .../models/mmdet/lightning/model_adapter.py | 13 ++++-------- .../efficientdet/lightning/model_adapter.py | 17 +++++++-------- .../torchvision/lightning_model_adapter.py | 21 +++++++++---------- 3 files changed, 22 insertions(+), 29 deletions(-) diff --git a/icevision/models/mmdet/lightning/model_adapter.py b/icevision/models/mmdet/lightning/model_adapter.py index ef28b64fe..ca1699d0c 100644 --- a/icevision/models/mmdet/lightning/model_adapter.py +++ b/icevision/models/mmdet/lightning/model_adapter.py @@ -52,12 +52,10 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): data, records = batch - self.model.eval() - with torch.no_grad(): - outputs = self.model.train_step(data=data, optimizer=None) - raw_preds = self.model.forward_test( - imgs=[data["img"]], img_metas=[data["img_metas"]] - ) + outputs = self.model.train_step(data=data, optimizer=None) + raw_preds = self.model.forward_test( + imgs=[data["img"]], img_metas=[data["img_metas"]] + ) preds = self.convert_raw_predictions( batch=data, raw_preds=raw_preds, records=records @@ -67,8 +65,5 @@ def validation_step(self, batch, batch_idx): for k, v in outputs["log_vars"].items(): self.log(f"valid/{k}", v) - # TODO: is train and eval model automatically set by lighnting? - self.model.train() - def validation_epoch_end(self, outs): self.finalize_metrics() diff --git a/icevision/models/ross/efficientdet/lightning/model_adapter.py b/icevision/models/ross/efficientdet/lightning/model_adapter.py index 0b1d5037d..99c802271 100644 --- a/icevision/models/ross/efficientdet/lightning/model_adapter.py +++ b/icevision/models/ross/efficientdet/lightning/model_adapter.py @@ -41,15 +41,14 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): (xb, yb), records = batch - with torch.no_grad(): - raw_preds = self(xb, yb) - preds = efficientdet.convert_raw_predictions( - batch=(xb, yb), - raw_preds=raw_preds["detections"], - records=records, - detection_threshold=0.0, - ) - loss = efficientdet.loss_fn(raw_preds, yb) + raw_preds = self(xb, yb) + preds = efficientdet.convert_raw_predictions( + batch=(xb, yb), + raw_preds=raw_preds["detections"], + records=records, + detection_threshold=0.0, + ) + loss = efficientdet.loss_fn(raw_preds, yb) self.accumulate_metrics(preds) diff --git a/icevision/models/torchvision/lightning_model_adapter.py b/icevision/models/torchvision/lightning_model_adapter.py index 1e14b98db..7d2427952 100644 --- a/icevision/models/torchvision/lightning_model_adapter.py +++ b/icevision/models/torchvision/lightning_model_adapter.py @@ -36,17 +36,16 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): (xb, yb), records = batch - with torch.no_grad(): - self.train() - train_preds = self(xb, yb) - loss = loss_fn(train_preds, yb) - - self.eval() - raw_preds = self(xb) - preds = self.convert_raw_predictions( - batch=batch, raw_preds=raw_preds, records=records - ) - self.accumulate_metrics(preds=preds) + self.train() + train_preds = self(xb, yb) + loss = loss_fn(train_preds, yb) + + self.eval() + raw_preds = self(xb) + preds = self.convert_raw_predictions( + batch=batch, raw_preds=raw_preds, records=records + ) + self.accumulate_metrics(preds=preds) self.log("val_loss", loss) From 13bb3e72279cdc5aa22d047c7e33d2c05ec31861 Mon Sep 17 00:00:00 2001 From: Aniket Maurya Date: Mon, 16 May 2022 17:13:57 +0530 Subject: [PATCH 2/2] remove manual no_grad and train mode --- .../yolov5/lightning/model_adapter.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/icevision/models/ultralytics/yolov5/lightning/model_adapter.py b/icevision/models/ultralytics/yolov5/lightning/model_adapter.py index f967f097e..5a9ec93d6 100644 --- a/icevision/models/ultralytics/yolov5/lightning/model_adapter.py +++ b/icevision/models/ultralytics/yolov5/lightning/model_adapter.py @@ -42,16 +42,15 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx): (xb, yb), records = batch - with torch.no_grad(): - inference_out, training_out = self(xb) - preds = yolov5.convert_raw_predictions( - batch=xb, - raw_preds=inference_out, - records=records, - detection_threshold=0.001, - nms_iou_threshold=0.6, - ) - loss = self.compute_loss(training_out, yb)[0] + inference_out, training_out = self(xb) + preds = yolov5.convert_raw_predictions( + batch=xb, + raw_preds=inference_out, + records=records, + detection_threshold=0.001, + nms_iou_threshold=0.6, + ) + loss = self.compute_loss(training_out, yb)[0] self.accumulate_metrics(preds)