Skip to content

Commit ddd9ee3

Browse files
committed
run with model without dp
1 parent 2abe930 commit ddd9ee3

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

predict.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,21 @@ def __init__(
6363
self.age_head.eval()
6464

6565
def create_model(self, depth, drop_ratio, net_mode, model_path, head):
66-
model = DataParallel(ResNet(depth, drop_ratio, net_mode)).to(self.device)
67-
head = DataParallel(head()).to(self.device)
66+
load_with_module = False
6867

69-
load_state(model=model, head=head, path_to_model=model_path, model_only=True)
68+
model = ResNet(depth, drop_ratio, net_mode)
69+
head = head()
70+
71+
try:
72+
load_state(model=model, head=head, path_to_model=model_path, model_only=True)
73+
except Exception:
74+
load_with_module = True
75+
76+
model = DataParallel(model).to(self.device)
77+
head = DataParallel(head).to(self.device)
78+
79+
if load_with_module:
80+
load_state(model=model, head=head, path_to_model=model_path, model_only=True)
7081

7182
model.eval()
7283
head.eval()

0 commit comments

Comments
 (0)