diff --git a/test_latency.py b/test_latency.py index f67c20a..e5c8d96 100644 --- a/test_latency.py +++ b/test_latency.py @@ -9,7 +9,7 @@ import torchvision import time -import models.vanillanet +from models.vanillanet import * if __name__ == "__main__": @@ -22,6 +22,7 @@ net = vanillanet_5().cuda() net.eval() + net.switch_to_deploy() print(net) for img, target in data_loader_val: img = img.cuda()