Skip to content

Commit fd772b5

Browse files
Merge pull request #11 from RadarML/dev/misc
Miscellaneous improvements & initial public release
2 parents 823db82 + 8b367ab commit fd772b5

File tree

12 files changed

+139
-43
lines changed

12 files changed

+139
-43
lines changed

README.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# NRDK: Neural Radar Development Kit
22

3+
[![GitHub](https://img.shields.io/github/v/release/RadarML/nrdk)](https://github.com/RadarML/nrdk)
4+
![MIT License](https://img.shields.io/badge/license-MIT-green)
5+
![Supports Python 3.12+](https://img.shields.io/badge/python-3.12%20%7C%203.13-blue)
6+
![Typed](https://img.shields.io/badge/types-typed-limegreen)
7+
[![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io)
8+
[![CI](https://github.com/RadarML/nrdk/actions/workflows/ci.yml/badge.svg)](https://github.com/RadarML/nrdk/actions/workflows/ci.yml)
9+
![GitHub issues](https://img.shields.io/github/issues/RadarML/nrdk)
10+
311
> [!IMPORTANT]
412
> See our [documentation site](https://radarml.github.io/nrdk) for more details!
513

docs/grt/index.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ example-project/
3030

3131
Don't forget to change the `name`, `authors`, and `description`!
3232

33-
2. Set up the `nrdk` dependency.
33+
2. Set up the `nrdk` dependency (`nrdk[roverd] >= 0.1.5`).
3434

3535
!!! warning "Required Extras"
3636

@@ -71,3 +71,9 @@ The GRT template includes reference training scripts which can be used for high
7171
```python title="grt/train_minimal.py"
7272
--8<-- "grt/train_minimal.py"
7373
```
74+
75+
## Evaluation Script
76+
77+
::: evaluate.evaluate
78+
options:
79+
show_root_heading: false

docs/index.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
# NRDK: Neural Radar Development Kit
22

3+
[![GitHub](https://img.shields.io/github/v/release/RadarML/nrdk)](https://github.com/RadarML/nrdk)
4+
![MIT License](https://img.shields.io/badge/license-MIT-green)
5+
![Supports Python 3.12+](https://img.shields.io/badge/python-3.12%20%7C%203.13-blue)
6+
![Typed](https://img.shields.io/badge/types-typed-limegreen)
7+
[![bear-ified](https://raw.githubusercontent.com/beartype/beartype-assets/main/badge/bear-ified.svg)](https://beartype.readthedocs.io)
8+
[![CI](https://github.com/RadarML/nrdk/actions/workflows/ci.yml/badge.svg)](https://github.com/RadarML/nrdk/actions/workflows/ci.yml)
9+
![GitHub issues](https://img.shields.io/github/issues/RadarML/nrdk)
10+
311
The **Neural Radar Development Kit** (NRDK) is an open-source and MIT-licensed Python library and framework for developing, training, and evaluating machine learning models on radar spectrum and multimodal sensor data.
412

513
Built around typed, high modular interfaces, the NRDK is designed to reduce the barrier of entry to learning on spectrum via out-of-the-box reference implementations for [red-rover](https://radarml.github.io/red-rover/) data and the [I/Q-1M Dataset](https://radarml.github.io/red-rover/iq1m/), while also providing an easy path towards customization and extensions for other radar and data collection systems.

grt/evaluate.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import numpy as np
1111
import torch
1212
import tyro
13+
import wadler_lindig as wl
14+
from abstract_dataloader import spec
1315
from omegaconf import DictConfig
1416
from roverd.channels.utils import Prefetch
1517
from roverd.sensors import DynamicSensor
@@ -18,6 +20,20 @@
1820
from nrdk.framework import Result
1921

2022

23+
class _DatasetMeta(spec.Dataset):
24+
def __init__(
25+
self, dataset: spec.Dataset[dict[str, Any]], meta: Any
26+
) -> None:
27+
self.dataset = dataset
28+
self.meta = meta
29+
30+
def __getitem__(self, index: int | np.integer) -> dict[str, Any]:
31+
return {"meta": self.meta, **self.dataset[index]}
32+
33+
def __len__(self) -> int:
34+
return len(self.dataset)
35+
36+
2137
def _get_dataloaders(
2238
cfg: DictConfig, data_root: str, transforms: Any,
2339
traces: list[str] | None = None, filter: str | None = None,
@@ -36,19 +52,28 @@ def _get_dataloaders(
3652
os.path.relpath(t, cfg["meta"]["dataset"])
3753
for t in hydra.utils.instantiate(
3854
cfg["datamodule"]["traces"]["test"])]
55+
56+
_unfiltered = traces
3957
if filter is not None:
4058
traces = [t for t in traces if re.match(filter, t)]
59+
if len(traces) == 0:
60+
raise ValueError(
61+
f"No traces match the filter {filter}:\n"
62+
f"{wl.pprint(_unfiltered)}")
4163

4264
def construct(t: str) -> torch.utils.data.DataLoader:
43-
dataset = dataset_constructor(paths=[t])
65+
dataset = _DatasetMeta(
66+
dataset_constructor(paths=[t]),
67+
meta={"train": False, "split": "test"})
68+
4469
return datamodule.dataloader(dataset, mode="test")
4570

4671
return {
4772
t: partial(construct, os.path.join(data_root, t)) for t in traces}
4873

4974

5075
def evaluate(
51-
path: str, /, sample: int | None = None,
76+
path: str, /, output: str | None = None, sample: int | None = None,
5277
traces: list[str] | None = None, filter: str | None = None,
5378
data_root: str | None = None,
5479
device: str = "cuda:0",
@@ -80,6 +105,7 @@ def evaluate(
80105
81106
Args:
82107
path: path to results directory.
108+
output: if specified, write results to this directory instead.
83109
sample: number of samples to evaluate.
84110
traces: explicit list of traces to evaluate.
85111
filter: evaluate all traces matching this regex.
@@ -90,11 +116,16 @@ def evaluate(
90116
workers: number of workers for data loading.
91117
prefetch: number of batches to prefetch per worker.
92118
"""
119+
torch.set_float32_matmul_precision('high')
120+
93121
result = Result(path)
94122
cfg = result.config()
95123
if sample is not None:
96124
cfg["datamodule"]["subsample"]["test"] = sample
97125

126+
if output is None:
127+
output = os.path.join(path, "eval")
128+
98129
if data_root is None:
99130
data_root = cfg["meta"]["dataset"]
100131
if data_root is None:
@@ -107,6 +138,7 @@ def evaluate(
107138
cfg["datamodule"]["batch_size"] = batch
108139
cfg["datamodule"]["num_workers"] = workers
109140
cfg["datamodule"]["prefetch_factor"] = prefetch
141+
cfg["lightningmodule"]["compile"] = False
110142

111143
transforms = hydra.utils.instantiate(cfg["transforms"])
112144
lightningmodule = hydra.utils.instantiate(
@@ -120,7 +152,7 @@ def evaluate(
120152
def collect_metadata(y_true):
121153
return {
122154
f"meta/{k}/ts": getattr(v, "timestamps")
123-
for k, v in y_true.items()
155+
for k, v in y_true.items() if hasattr(v, "timestamps")
124156
}
125157

126158
for trace, dl_constructor in dataloaders.items():
@@ -131,8 +163,7 @@ def collect_metadata(y_true):
131163
total=len(dataloader), desc=trace)
132164

133165
output_container = DynamicSensor(
134-
os.path.join(result.path, "eval", trace),
135-
create=True, exist_ok=True)
166+
os.path.join(output, trace), create=True, exist_ok=True)
136167
metrics = []
137168
outputs = {}
138169
for batch_metrics, vis in eval_stream:
@@ -160,7 +191,7 @@ def collect_metadata(y_true):
160191
k: np.concatenate([m[k] for m in metrics], axis=0)
161192
for k in metrics[0]}
162193
np.savez_compressed(
163-
os.path.join(result.path, "eval", trace, "metrics.npz"),
194+
os.path.join(output, trace, "metrics.npz"),
164195
**metrics, allow_pickle=False)
165196

166197
output_container.create("ts", meta={

grt/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ description = "GRT/NRDK Reference Implementation"
88
requires-python = ">=3.12"
99

1010
dependencies = [
11-
"nrdk[roverd] >= 0.1.1",
11+
"nrdk[roverd] >= 0.1.5",
1212
"pyyaml >= 5.0.0",
1313
"hydra-core >= 1.3.0",
1414
"lightning >= 2.5.5",

grt/train.py

Lines changed: 52 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,88 @@
1-
"""GRT reference implementation training script."""
1+
"""GRT Reference implementation training script."""
22

33
import logging
44
import os
5+
from collections.abc import Mapping
56
from time import perf_counter
67
from typing import Any
78

89
import hydra
910
import torch
1011
import yaml
1112
from lightning.pytorch import callbacks
13+
from omegaconf import DictConfig
14+
from rich.logging import RichHandler
15+
16+
from nrdk.framework import Result
17+
18+
logger = logging.getLogger("train")
19+
20+
def _configure_logging(cfg: DictConfig) -> None:
21+
log_level = cfg.meta.get("verbose", logging.INFO)
22+
root = logging.getLogger()
23+
root.setLevel(log_level)
24+
root.handlers.clear()
25+
26+
rich_handler = RichHandler(markup=True)
27+
rich_handler.setFormatter(logging.Formatter(
28+
"[orange1]%(name)s:[/orange1] %(message)s"))
29+
root.addHandler(rich_handler)
30+
31+
logger.debug(f"Configured with log level: {log_level}")
32+
33+
34+
def _load_weights(lightningmodule, path: str, rename: Mapping = {}) -> None:
35+
weights = Result(path).best if os.path.isdir(path) else path
36+
lightningmodule.load_weights(weights, rename=rename)
37+
38+
39+
def _get_best(trainer) -> dict[str, Any]:
40+
for callback in trainer.callbacks:
41+
if isinstance(callback, callbacks.ModelCheckpoint):
42+
return {
43+
"best_k": {
44+
os.path.basename(k): v.item()
45+
for k, v in callback.best_k_models.items()},
46+
"best": os.path.basename(callback.best_model_path)
47+
}
48+
return {}
1249

1350

1451
@hydra.main(version_base=None, config_path="./config", config_name="default")
15-
def train(cfg):
52+
def train(cfg: DictConfig) -> None:
1653
"""Train a model using the GRT reference implementation."""
1754
torch.set_float32_matmul_precision('high')
55+
_configure_logging(cfg)
56+
57+
if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
58+
logger.error("Must set `meta.name` and `meta.version` in the config.")
59+
return
1860

1961
def _inst(path, *args, **kwargs):
2062
return hydra.utils.instantiate(
2163
cfg[path], _convert_="all", *args, **kwargs)
2264

23-
if cfg["meta"]["name"] is None or cfg["meta"]["version"] is None:
24-
logging.error("Must set `meta.name` and `meta.version` in the config.")
25-
return
26-
2765
transforms = _inst("transforms")
2866
datamodule = _inst("datamodule", transforms=transforms)
2967
lightningmodule = _inst("lightningmodule", transforms=transforms)
3068
trainer = _inst("trainer")
31-
3269
if "base" in cfg:
33-
lightningmodule.load_weights(
34-
cfg['base']['path'], rename=cfg['base'].get('rename', {}))
70+
_load_weights(lightningmodule, **cfg['base'])
3571

3672
start = perf_counter()
73+
logger.info(
74+
f"Start training @ {cfg["meta"]["results"]}/{cfg["meta"]["name"]}/"
75+
f"{cfg["meta"]["version"]} [t={start:.3f}]")
3776
trainer.fit(
3877
model=lightningmodule, datamodule=datamodule,
3978
ckpt_path=cfg['meta']['resume'])
4079
duration = perf_counter() - start
80+
logger.info(
81+
f"Training completed in {duration / 60 / 60:.2f}h (={duration:.3f}s).")
4182

4283
meta: dict[str, Any] = {"duration": duration}
43-
for callback in trainer.callbacks:
44-
if isinstance(callback, callbacks.ModelCheckpoint):
45-
meta["best_k"] = {
46-
os.path.basename(k): v.item()
47-
for k, v in callback.best_k_models.items()}
48-
meta["best"] = os.path.basename(callback.best_model_path)
49-
break
50-
84+
meta.update(_get_best(trainer))
85+
logger.info(f"Best checkpoint: {meta.get('best')}")
5186
meta_path = os.path.join(trainer.logger.log_dir, "checkpoints.yaml")
5287
with open(meta_path, 'w') as f:
5388
yaml.dump(meta, f, sort_keys=False)

grt/uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mkdocs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ repo_name: RadarML/nrdk
44
nav:
55
- nrdk:
66
- Overview: index.md
7-
- Architecture: design.md
7+
- Software Architecture: design.md
88
- Core Modules:
99
- nrdk.config: nrdk/config.md
1010
- nrdk.framework: nrdk/framework.md
@@ -14,7 +14,7 @@ nav:
1414
- nrdk.objectives: nrdk/objectives.md
1515
- nrdk.roverd: nrdk/roverd.md
1616
- nrdk.vis: nrdk/vis.md
17-
- CLI Tools: cli.md
17+
- cli: cli.md
1818
- nrdk.tss:
1919
- Overview: tss/index.md
2020
- High Level API: tss/api.md

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "nrdk"
7-
version = "0.1.4"
7+
version = "0.1.5"
88
authors = [
99
{ name="Tianshu Huang", email="tianshu2@andrew.cmu.edu" },
1010
]

src/nrdk/framework/lightningmodule.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import os
55
import re
66
import threading
7-
import warnings
87
from collections.abc import Iterable, Iterator, Mapping, Sequence
98
from functools import cache
109
from typing import Any, Callable, Generic, TypeVar, cast
@@ -91,10 +90,12 @@ def __init__(
9190
) -> None:
9291
super().__init__()
9392

93+
self._log = logging.getLogger("NRDKLightningModule")
94+
9495
if compile:
9596
jt_disable = os.environ.get("JAXTYPING_DISABLE", "0").lower()
9697
if jt_disable not in ("1", "true"):
97-
warnings.warn(
98+
self._log.warning(
9899
"torch.compile is currently incompatible with jaxtyping; "
99100
"if you see type errors, set the environment variable "
100101
"`JAXTYPING_DISABLE=1` to disable jaxtyping checks.")
@@ -108,8 +109,6 @@ def __init__(
108109
self.vis_interval = vis_interval
109110
self.vis_samples = vis_samples
110111

111-
self._log = logging.getLogger(self.__class__.__name__)
112-
113112
@torch.compiler.disable
114113
def load_weights(
115114
self, path: str, rename: Sequence[Mapping[str, str | None]] = []
@@ -153,9 +152,14 @@ def load_weights(
153152
if "model" in weights:
154153
weights = weights["model"]
155154

156-
weights = {
157-
k[6:] if k.startswith("model.")
158-
else k: v for k, v in weights.items()}
155+
def _strip_prefix(k):
156+
if k.startswith("model."):
157+
k = k[6:]
158+
if k.startswith("_orig_mod."):
159+
k = k[10:]
160+
return k
161+
162+
weights = {_strip_prefix(k): v for k, v in weights.items()}
159163
for pattern in rename:
160164
pat, sub = next(iter(pattern.items()))
161165
if sub is None:
@@ -191,7 +195,7 @@ def _make_log(
191195

192196
if len(images) > 0:
193197
if not isinstance(self.logger, LoggerWithImages):
194-
warnings.warn(
198+
self._log.warning(
195199
"Tried to log visualizations, but the logger does not "
196200
"implement the `LoggerWithImages` interface.")
197201
else:
@@ -222,6 +226,7 @@ def log_visualizations(
222226
y_pred: model output values.
223227
split: train/val split to put in the output path.
224228
"""
229+
self._log.debug(f"Logging visualizations @ i={self.global_step}")
225230
y_true, y_pred = optree.tree_map(
226231
lambda x: x[:self.vis_samples].cpu().detach(),
227232
(y_true, y_pred)) # type: ignore

0 commit comments

Comments
 (0)