Fix: Manage the import of torch.amp to be compatible with all PyTorch versions #13487

Open · wants to merge 2 commits into base: master
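
For context: newer PyTorch releases expose the device-agnostic torch.amp namespace, while older releases only ship torch.cuda.amp, so the change below guards the import and binds whichever module is available to the same amp alias. A minimal, self-contained sketch of that pattern (not a drop-in for train.py; the print line is only illustrative):

import torch

try:
    import torch.amp as amp  # newer PyTorch: device-agnostic AMP namespace
except ImportError:
    import torch.cuda.amp as amp  # older PyTorch: CUDA-specific AMP namespace

# Report which namespace the fallback resolved; train.py goes on to build
# amp.GradScaler and amp.autocast from this alias.
print(f"torch {torch.__version__}: AMP utilities resolved from {amp.__name__}")
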
train.py: 18 changes (13 additions, 5 deletions)

@@ -36,6 +36,12 @@
 import torch.nn as nn
 import yaml
 from torch.optim import lr_scheduler
+
+try:
+    import torch.amp as amp
+except ImportError:
+    import torch.cuda.amp as amp
+
 from tqdm import tqdm

 FILE = Path(__file__).resolve()
@@ -221,7 +227,7 @@ def train(hyp, opt, device, callbacks):
         LOGGER.info(f"Transferred {len(csd)}/{len(model.state_dict())} items from {weights}") # report
     else:
         model = Model(cfg, ch=3, nc=nc, anchors=hyp.get("anchors")).to(device) # create
-    amp = check_amp(model) # check AMP
+    use_amp = check_amp(model) # check AMP

     # Freeze
     freeze = [f"model.{x}." for x in (freeze if len(freeze) > 1 else range(freeze[0]))] # layers to freeze
@@ -238,7 +244,7 @@ def train(hyp, opt, device, callbacks):

     # Batch size
     if RANK == -1 and batch_size == -1: # single-GPU only, estimate best batch size
-        batch_size = check_train_batch_size(model, imgsz, amp)
+        batch_size = check_train_batch_size(model, imgsz, use_amp)
         loggers.on_params_update({"batch_size": batch_size})

     # Optimizer
@@ -352,7 +358,8 @@ def lf(x):
     maps = np.zeros(nc) # mAP per class
     results = (0, 0, 0, 0, 0, 0, 0) # P, R, mAP@.5, mAP@.5-.95, val_loss(box, obj, cls)
     scheduler.last_epoch = start_epoch - 1 # do not move
-    scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    # scaler = torch.cuda.amp.GradScaler(enabled=amp)
+    scaler = amp.GradScaler(enabled=use_amp)
     stopper, stop = EarlyStopping(patience=opt.patience), False
     compute_loss = ComputeLoss(model) # init loss class
     callbacks.run("on_train_start")
@@ -409,7 +416,8 @@ def lf(x):
                     imgs = nn.functional.interpolate(imgs, size=ns, mode="bilinear", align_corners=False)

             # Forward
-            with torch.cuda.amp.autocast(amp):
+            # with torch.cuda.amp.autocast(amp):
+            with amp.autocast(enabled=use_amp, device_type=device.type):
                 pred = model(imgs) # forward
                 loss, loss_items = compute_loss(pred, targets.to(device)) # loss scaled by batch_size
                 if RANK != -1:
@@ -458,7 +466,7 @@ def lf(x):
                     data_dict,
                     batch_size=batch_size // WORLD_SIZE * 2,
                     imgsz=imgsz,
-                    half=amp,
+                    half=use_amp,
                     model=ema.ema,
                     single_cls=single_cls,
                     dataloader=val_loader,
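
For reference, a stripped-down sketch of how the GradScaler/autocast pair above drives one mixed-precision step. It assumes a PyTorch recent enough to expose torch.amp.GradScaler; the toy model, optimizer, and use_amp flag are hypothetical stand-ins for YOLOv5's model, optimizer, and the check_amp(model) result:

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 1).to(device)  # toy stand-in for the detection model
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
use_amp = device.type == "cuda"  # stand-in for check_amp(model)

scaler = torch.amp.GradScaler(enabled=use_amp)
for _ in range(3):  # a few toy iterations
    imgs = torch.randn(4, 10, device=device)
    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
        loss = model(imgs).square().mean()  # placeholder loss
    scaler.scale(loss).backward()  # backward pass on the scaled loss
    scaler.step(optimizer)  # unscales gradients, then steps the optimizer
    scaler.update()  # adjusts the loss scale for the next iteration
    optimizer.zero_grad()

Note that torch.amp.autocast expects an explicit device_type, which is what the new call in the diff above passes via device.type.
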

utils/dataloaders.py: 6 changes (3 additions, 3 deletions)

@@ -355,9 +355,9 @@ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vi
             self._new_video(videos[0]) # new video
         else:
             self.cap = None
-        assert self.nf > 0, (
-            f"No images or videos found in {p}. Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"
-        )
+        assert (
+            self.nf > 0
+        ), f"No images or videos found in {p}. Supported formats are:\nimages: {IMG_FORMATS}\nvideos: {VID_FORMATS}"

     def __iter__(self):
         """Initializes iterator by resetting count and returns the iterator object itself."""

utils/general.py: 6 changes (3 additions, 3 deletions)

@@ -495,9 +495,9 @@ def check_file(file, suffix=""):
             assert Path(file).exists() and Path(file).stat().st_size > 0, f"File download failed: {url}" # check
         return file
     elif file.startswith("clearml://"): # ClearML Dataset ID
-        assert "clearml" in sys.modules, (
-            "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
-        )
+        assert (
+            "clearml" in sys.modules
+        ), "ClearML is not installed, so cannot use ClearML dataset. Try running 'pip install clearml'."
         return file
     else: # search
         files = []

utils/loggers/clearml/clearml_utils.py: 8 changes (5 additions, 3 deletions)

@@ -41,9 +41,11 @@ def construct_dataset(clearml_info_string):
     with open(yaml_filenames[0]) as f:
         dataset_definition = yaml.safe_load(f)

-    assert set(dataset_definition.keys()).issuperset({"train", "test", "val", "nc", "names"}), (
-        "The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"
-    )
+    assert set(
+        dataset_definition.keys()
+    ).issuperset(
+        {"train", "test", "val", "nc", "names"}
+    ), "The right keys were not found in the yaml file, make sure it at least has the following keys: ('train', 'test', 'val', 'nc', 'names')"

     data_dict = {
         "train": (

utils/torch_utils.py: 6 changes (3 additions, 3 deletions)

@@ -121,9 +121,9 @@ def select_device(device="", batch_size=0, newline=True):
         os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
     elif device: # non-cpu device requested
         os.environ["CUDA_VISIBLE_DEVICES"] = device # set environment variable - must be before assert is_available()
-        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(",", "")), (
-            f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
-        )
+        assert torch.cuda.is_available() and torch.cuda.device_count() >= len(
+            device.replace(",", "")
+        ), f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"

     if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
         devices = device.split(",") if device else "0" # range(torch.cuda.device_count()) # i.e. 0,1,6,7