-
Notifications
You must be signed in to change notification settings - Fork 45
/
Copy pathtest_model.py
82 lines (60 loc) · 2.46 KB
/
test_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from imread_from_url import imread_from_url
from nets import Model
# Device used for the model and its input tensors. Prefer the GPU, but fall
# back to the CPU so the script still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
def inference(left, right, model, n_iter=20):
    """Run a two-pass (half-resolution, then full-resolution) prediction.

    The model is first called on half-size copies of both images with no
    initial flow; its output is then passed as ``flow_init`` to a second call
    on the full-size images.

    Args:
        left: Left image as an ``(H, W, C)`` NumPy array.
        right: Right image as an ``(H, W, C)`` NumPy array. It is resized to
            half of ``left``'s height/width for the first pass, so it is
            expected to have the same spatial size as ``left``.
        model: Callable invoked as
            ``model(imgL, imgR, iters=n_iter, flow_init=...)`` that returns a
            tensor indexed as ``[batch, channel, row, col]``.
        n_iter: Value forwarded to the model's ``iters`` argument in both
            passes.

    Returns:
        Channel 0 of the full-resolution prediction as a NumPy array, with
        singleton dimensions squeezed out.
    """
    print("Model Forwarding...")

    # Put the inputs on whatever device the model's weights live on, so the
    # call works regardless of where the caller placed the model. Models that
    # expose no parameters fall back to the module-level default device.
    try:
        run_device = next(model.parameters()).device
    except (AttributeError, TypeError, StopIteration):
        run_device = device

    # HWC -> CHW, then add a leading batch axis of size 1.
    imgL = np.ascontiguousarray(left.transpose(2, 0, 1)[None, :, :, :])
    imgR = np.ascontiguousarray(right.transpose(2, 0, 1)[None, :, :, :])

    imgL = torch.tensor(imgL.astype("float32")).to(run_device)
    imgR = torch.tensor(imgR.astype("float32")).to(run_device)

    # Both half-resolution inputs use the left image's size so that they
    # always match each other.
    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(
        imgL,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )

    with torch.inference_mode():
        # Coarse pass: no initial flow.
        pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
        # Fine pass: seeded with the coarse prediction.
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    # Keep only channel 0 and drop the batch axis.
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    return pred_disp
if __name__ == '__main__':
    # Demo: download the CREStereo sample stereo pair, run the model on it,
    # then display and save a color-mapped visualization of the prediction.
    left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
    right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")

    in_h, in_w = left_img.shape[:2]

    # Resize image in case the GPU memory overflows
    # (lower eval_h / eval_w here; by default the input size is used as-is).
    eval_h, eval_w = (in_h, in_w)
    # Explicit raises instead of asserts so the checks survive `python -O`.
    if eval_h % 8 != 0:
        raise ValueError("input height should be divisible by 8")
    if eval_w % 8 != 0:
        raise ValueError("input width should be divisible by 8")

    imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
    imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)

    model_path = "models/crestereo_eth3d.pth"

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    # map_location lets a checkpoint saved on one device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device), strict=True)
    model.to(device)
    model.eval()

    pred = inference(imgL, imgR, model, n_iter=20)

    # Scale the prediction back to the original width: resizing the map also
    # requires rescaling its values by the same horizontal factor.
    t = float(in_w) / float(eval_w)
    disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t

    # Min-max normalize to [0, 255]; a constant map would divide by zero, so
    # render it as all zeros instead.
    disp_min = disp.min()
    disp_range = disp.max() - disp_min
    if disp_range > 0:
        disp_vis = (disp - disp_min) / disp_range * 255.0
    else:
        disp_vis = np.zeros_like(disp)
    disp_vis = disp_vis.astype("uint8")
    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)

    combined_img = np.hstack((left_img, disp_vis))
    cv2.namedWindow("output", cv2.WINDOW_NORMAL)
    cv2.imshow("output", combined_img)
    cv2.imwrite("output.jpg", disp_vis)
    cv2.waitKey(0)
    # Release the HighGUI window once the user dismisses it.
    cv2.destroyAllWindows()