Skip to content

Commit 1bd63a1

Browse files
Merge pull request #9 from RadarML/dev/bugfix
Miscellaneous Bugfixes and Improvements
2 parents d68aae4 + 36e3a7b commit 1bd63a1

File tree

5 files changed

+78
-29
lines changed

5 files changed

+78
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "nrdk"
7-
version = "0.1.2"
7+
version = "0.1.3"
88
authors = [
99
{ name="Tianshu Huang", email="tianshu2@andrew.cmu.edu" },
1010
]

src/nrdk/framework/result.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,41 @@ class Result:
3333
└── events.out.tfevents...
3434
```
3535
36+
!!! danger
37+
38+
Since `.hydra` is technically a hidden folder, many file operations
39+
(e.g., `mv experiment_result/*`) will skip or hide this folder by
40+
default.
41+
3642
Args:
3743
path: path to results directory.
44+
validate: check that the path exists, and that it matches the expected
45+
structure.
3846
"""
3947

40-
def __init__(self, path: str) -> None:
48+
def __init__(self, path: str, validate: bool = True) -> None:
4149
self.path = path
50+
if validate:
51+
self.validate(path)
52+
53+
@staticmethod
54+
def validate(path: str) -> None:
55+
"""Validate that the given path is a valid results directory.
56+
57+
Raises:
58+
ValueError: if the path does not exist, or does not have the
59+
expected structure.
60+
"""
61+
if not os.path.exists(path):
62+
raise ValueError(f"Result {path} does not exist.")
63+
if not os.path.exists(os.path.join(path, ".hydra", "config.yaml")):
64+
raise ValueError(
65+
f"Result {path} exists, but does not have an associated "
66+
f"hydra configuration {path}/.hydra/config.yaml.")
67+
if not os.path.exists(os.path.join(path, "checkpoints.yaml")):
68+
raise ValueError(
69+
f"Result {path} exists, does not have a checkpoint index "
70+
f"file {path}/checkpoints.yaml.")
4271

4372
@overload
4473
def config(self, omegaconf: Literal[True] = True) -> DictConfig: ...

src/nrdk/roverd/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def datamodule(
2929
traces: Mapping[str, Sequence[str]],
3030
transforms: spec.Pipeline,
3131
batch_size: int = 32, samples: int | Sequence[int] = 0,
32-
num_workers: int = 32, prefetch_factor: int = 2,
32+
num_workers: int = 32, prefetch_factor: int | None = 2,
3333
subsample: Mapping[str, int | float | None] = {},
3434
ptrain: float = 0.8, pval: float = 0.2
3535
) -> ADLDataModule:

src/nrdk/roverd/lidar.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,17 @@
1212
from .transforms import SpectrumData
1313

1414

15+
def _n_range(
16+
data: types.XWRRadarIQ[np.ndarray] | SpectrumData[np.ndarray]
17+
) -> int:
18+
if isinstance(data, SpectrumData):
19+
_batch, _t, _doppler, _el, _az, n_rng, _ch = data.spectrum.shape
20+
return n_rng
21+
else: # XWRRadarIQ
22+
_batch, _t, _n_slow, _tx, _rx, n_fast = data.iq.shape
23+
return n_fast // 2
24+
25+
1526
@dataclass
1627
class Occupancy3DData(Generic[TArray]):
1728
"""3D occupancy data.
@@ -62,7 +73,7 @@ def __init__(
6273

6374
def __call__(
6475
self, lidar: types.OSDepth[np.ndarray],
65-
radar: SpectrumData[np.ndarray],
76+
radar: SpectrumData[np.ndarray] | types.XWRRadarIQ[np.ndarray],
6677
aug: Mapping[str, Any] = {}
6778
) -> Occupancy3DData[np.ndarray]:
6879
"""Create 3D occupancy map from Lidar depth data.
@@ -87,7 +98,7 @@ def __call__(
8798
if aug.get("azimuth_flip", False):
8899
rng = np.flip(rng, axis=3)
89100

90-
_batch, _t, _doppler, _el, _az, n_rng, _ch = radar.spectrum.shape
101+
n_rng = _n_range(radar)
91102
n_bins = n_rng // self.d_rng
92103
bin = (rng // (radar.range_resolution * self.d_rng)).astype(np.uint16)
93104
bin[bin >= n_bins] = 0
@@ -177,7 +188,7 @@ def __init__(
177188

178189
def __call__(
179190
self, lidar: types.OSDepth[np.ndarray],
180-
radar: SpectrumData[np.ndarray],
191+
radar: SpectrumData[np.ndarray] | types.XWRRadarIQ[np.ndarray],
181192
aug: Mapping[str, Any] = {}
182193
) -> Occupancy2DData[np.ndarray]:
183194
"""Create 2D occupancy map from Lidar depth data.
@@ -216,7 +227,7 @@ def __call__(
216227
r[(z > self.z_max) | (z < self.z_min)] = 0
217228

218229
# Create map
219-
_batch, _t, _doppler, _el, _az, n_rng, _ch = radar.spectrum.shape
230+
n_rng = _n_range(radar)
220231
bin: UInt16[np.ndarray, "T El Az"] = (
221232
r // radar.range_resolution).astype(np.uint16)
222233
bin[bin >= n_rng] = 0

src/nrdk/tss/_cli.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,23 @@
1010

1111
def _cli(
1212
path: str, /,
13-
pattern: str = r"^(?P<experiment>(.*)).npz$",
14-
key: str = "loss", timestamps: str | None = None,
15-
experiments: list[str] | None = None,
16-
baseline: str | None = None,
17-
follow_symlinks: bool = False,
18-
cut: float | None = None,
19-
t_max: int | None = None,
13+
pattern: str | None = None, key: str | None = None,
14+
timestamps: str | None = None, experiments: list[str] | None = None,
15+
baseline: str | None = None, follow_symlinks: bool = False,
16+
cut: float | None = None, t_max: int | None = None,
2017
config: str | None = None,
2118
) -> int:
2219
"""Calculate statistics for time series metrics.
2320
24-
- pipe `tss ... > results.csv` to save the results to a file
25-
- use `--config config.yaml` to avoid having to specify all these arguments
21+
- pipe `tss ... > results.csv` to save the results to a file.
22+
- use `--config config.yaml` to avoid having to specify all these
23+
arguments; any arguments which are explicitly provided will override
24+
the values in the config file.
25+
26+
!!! warning
27+
28+
`path` (and `--follow_symlinks`, if specified) are required to be
29+
passed via the command line, and cannot be specified via the config.
2630
2731
Args:
2832
path: directory to find evaluations in.
@@ -43,18 +47,23 @@ def _cli(
4347
if config is not None:
4448
with open(config) as f:
4549
cfg = yaml.safe_load(f)
46-
return _cli(
47-
path,
48-
pattern=cfg.get("pattern", r"^(?P<experiment>(.*)).npz$"),
49-
experiments=cfg.get("experiments", None),
50-
key=cfg.get("key", "loss"),
51-
timestamps=cfg.get("timestamps", None),
52-
baseline=cfg.get("baseline", None),
53-
cut=cfg.get("cut", None),
54-
t_max=cfg.get("t_max", None),
55-
follow_symlinks=follow_symlinks)
56-
57-
index = api.index(path, pattern=pattern, follow_symlinks=follow_symlinks)
50+
else:
51+
cfg = {}
52+
53+
def setdefault(value, param, default):
54+
if value is None:
55+
value = cfg.get(param, default)
56+
return value
57+
58+
pattern = setdefault(pattern, "pattern", r"^(?P<experiment>(.*)).npz$")
59+
key = setdefault(key, "key", "loss")
60+
timestamps = setdefault(timestamps, "timestamps", None)
61+
baseline = setdefault(baseline, "baseline", None)
62+
cut = setdefault(cut, "cut", None)
63+
t_max = setdefault(t_max, "t_max", None)
64+
65+
index = api.index(
66+
path, pattern=pattern, follow_symlinks=follow_symlinks) # type: ignore
5867

5968
if len(index) == 0:
6069
print("No result files found!")
@@ -64,7 +73,7 @@ def _cli(
6473
return -1
6574

6675
df = api.dataframe_from_index(
67-
index, key=key, baseline=baseline,
76+
index, key=key, baseline=baseline, # type: ignore
6877
experiments=experiments, cut=cut, t_max=t_max, timestamps=timestamps)
6978

7079
buf = StringIO()

0 commit comments

Comments
 (0)