-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
85 lines (75 loc) · 2.65 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import random
import numpy as np
import torch
import sys
from loguru import logger
from pathlib import Path
from viewer import CameraState
import json
from typing import Optional, List
def set_global_state(seed: int, device: str):
    """Seed the global RNGs, select the CUDA device and configure loguru.

    Args:
        seed: value passed to Python's ``random``, NumPy's global RNG and
            ``torch.manual_seed`` (in that order).
        device: device passed to ``torch.cuda.set_device``.

    Side effects:
        Calls ``logger.configure`` with a single handler that writes to
        stdout at DEBUG level through a queue (``enqueue=True``).
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    torch.cuda.set_device(device)

    # Timestamp with microseconds, padded level name, then the message.
    log_format = (
        "<green>{time:MMDD-HH:mm:ss.SSSSSS}</green>"
        " | <level>{level:5}</level>"
        " | <level>{message}</level>"
    )
    stdout_handler = {
        "sink": sys.stdout,
        "format": log_format,
        "level": "DEBUG",
        "enqueue": True,
    }
    logger.configure(handlers=[stdout_handler])  # type: ignore
def load_camera_states(path: Path) -> List[CameraState]:
    """Load camera poses and pinhole intrinsics from ``<path>/cameras.json``.

    Every JSON entry must provide ``rotation`` (the 3x3 rotation block of the
    camera-to-world matrix), ``position`` (its translation column), ``fx``,
    ``fy``, ``width`` and ``height``.

    Args:
        path: directory that contains ``cameras.json``; a ``str`` is accepted
            as well as a ``Path``.

    Returns:
        One ``CameraState`` per JSON entry, in file order. Each is built from
        the world-to-camera matrix, a float32 intrinsics matrix, and the image
        width and height.
    """
    # Read the whole file up front so the handle is closed before processing;
    # JSON is UTF-8, so don't depend on the platform's default encoding.
    with open(Path(path) / "cameras.json", "r", encoding="utf-8") as f:
        cameras = json.load(f)

    camera_states = []
    for cam in cameras:
        c2w = np.eye(4)
        c2w[:3, :3] = np.array(cam["rotation"])
        c2w[:3, 3] = np.array(cam["position"])
        # Invert the camera-to-world pose to get the world-to-camera transform.
        w2c = np.linalg.inv(c2w)
        # Pinhole intrinsics; the principal point is assumed to be the image centre.
        K = np.array(
            [
                [cam["fx"], 0, cam["width"] / 2],
                [0, cam["fy"], cam["height"] / 2],
                [0, 0, 1],
            ],
            dtype=np.float32,
        )
        camera_states.append(CameraState(w2c, K, cam["width"], cam["height"]))
    return camera_states
def _checkpoint_iteration(cpt: Path) -> Optional[int]:
    """Return the iteration number encoded in a checkpoint file name, or None.

    Uses the same rule as before, ``int(stem.split("_")[1])`` (so
    ``iterations_3000.pth`` -> 3000), but returns None instead of raising
    for names that do not follow that pattern (e.g. ``final.pth``).
    """
    parts = cpt.stem.split("_")
    if len(parts) < 2:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None


def load_gaussian_model(
    path: Path, iterations: Optional[int] = None
) -> torch.nn.Module:
    """Load a pickled Gaussian model from ``<path>/checkpoints`` onto CUDA.

    Args:
        path: run directory containing a ``checkpoints`` sub-directory of
            ``*.pth`` files (``str`` is accepted as well).
        iterations: if given, load exactly ``iterations_<iterations>.pth``;
            otherwise load the checkpoint with the highest iteration number.

    Returns:
        The unpickled module, moved to CUDA with ``.cuda()``.

    Raises:
        ValueError: if the requested checkpoint does not exist, or if no
            checkpoint with a parseable iteration number is found.
    """
    import inspect

    cpt_lst = list((Path(path) / "checkpoints").glob("*.pth"))
    if iterations is not None:
        wanted_stem = f"iterations_{iterations}"
        target_cpt = next((cpt for cpt in cpt_lst if cpt.stem == wanted_stem), None)
        if target_cpt is None:
            raise ValueError(f"cannot find checkpoint for iteration {iterations}")
    else:
        # Start below zero so an ``iterations_0`` checkpoint can be selected,
        # and skip files whose names carry no iteration number instead of
        # crashing on them.
        max_iterations = -1
        target_cpt = None
        for cpt in cpt_lst:
            cpt_iteration = _checkpoint_iteration(cpt)
            if cpt_iteration is not None and cpt_iteration > max_iterations:
                max_iterations = cpt_iteration
                target_cpt = cpt
        if target_cpt is None:
            raise ValueError("no checkpoint found")
    logger.info(f"load checkpoint from {target_cpt}")

    load_kwargs = {"map_location": "cpu"}
    # The checkpoint holds a whole pickled module (the result is used directly
    # as a module), which torch>=2.6 rejects under its default
    # weights_only=True. Older torch versions lack the argument altogether.
    # NOTE(security): full unpickling can execute arbitrary code -- only load
    # checkpoints from trusted sources.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    gaussian_model = torch.load(target_cpt, **load_kwargs).cuda()
    return gaussian_model
def save_gaussian_model(
    path: Path, gaussian_model: torch.nn.Module, save_optimizer: bool = False
):
    """Pickle ``gaussian_model`` to ``path`` with ``torch.save``.

    Unless ``save_optimizer`` is True, the model's ``optimizer`` attribute is
    set to None for the duration of the save so it is not written to disk.
    It is restored afterwards, even when ``torch.save`` raises, so a failed
    save never leaves the in-memory model without its optimizer.

    Args:
        path: destination file.
        gaussian_model: module to save. If it has no ``optimizer`` attribute
            it is saved unchanged.
        save_optimizer: keep the optimizer inside the saved checkpoint.
    """
    if save_optimizer or not hasattr(gaussian_model, "optimizer"):
        torch.save(gaussian_model, path)
        return

    tmp_optimizer = gaussian_model.optimizer
    gaussian_model.optimizer = None  # type: ignore
    try:
        torch.save(gaussian_model, path)
    finally:
        gaussian_model.optimizer = tmp_optimizer