1010import numpy as np
1111import torch
1212import tyro
13+ import wadler_lindig as wl
14+ from abstract_dataloader import spec
1315from omegaconf import DictConfig
1416from roverd .channels .utils import Prefetch
1517from roverd .sensors import DynamicSensor
1820from nrdk .framework import Result
1921
2022
class _DatasetMeta(spec.Dataset):
    """Dataset wrapper that attaches a fixed metadata entry to every sample.

    Each item fetched from the wrapped dataset is returned as a new dict
    with an additional ``"meta"`` key holding the metadata object supplied
    at construction time.
    """

    def __init__(
        self, dataset: spec.Dataset[dict[str, Any]], meta: Any
    ) -> None:
        # Keep references only; no copying of the underlying dataset.
        self.dataset = dataset
        self.meta = meta

    def __getitem__(self, index: int | np.integer) -> dict[str, Any]:
        """Return sample ``index`` augmented with the shared metadata."""
        sample = self.dataset[index]
        # NOTE: a "meta" key already present in the sample wins, since the
        # sample's entries are unpacked after the injected one.
        return {"meta": self.meta, **sample}

    def __len__(self) -> int:
        """Length is delegated unchanged to the wrapped dataset."""
        return len(self.dataset)
35+
36+
2137def _get_dataloaders (
2238 cfg : DictConfig , data_root : str , transforms : Any ,
2339 traces : list [str ] | None = None , filter : str | None = None ,
@@ -36,19 +52,28 @@ def _get_dataloaders(
3652 os .path .relpath (t , cfg ["meta" ]["dataset" ])
3753 for t in hydra .utils .instantiate (
3854 cfg ["datamodule" ]["traces" ]["test" ])]
55+
56+ _unfiltered = traces
3957 if filter is not None :
4058 traces = [t for t in traces if re .match (filter , t )]
59+ if len (traces ) == 0 :
60+ raise ValueError (
61+ f"No traces match the filter { filter } :\n "
62+ f"{ wl .pprint (_unfiltered )} " )
4163
4264 def construct (t : str ) -> torch .utils .data .DataLoader :
43- dataset = dataset_constructor (paths = [t ])
65+ dataset = _DatasetMeta (
66+ dataset_constructor (paths = [t ]),
67+ meta = {"train" : False , "split" : "test" })
68+
4469 return datamodule .dataloader (dataset , mode = "test" )
4570
4671 return {
4772 t : partial (construct , os .path .join (data_root , t )) for t in traces }
4873
4974
5075def evaluate (
51- path : str , / , sample : int | None = None ,
76+ path : str , / , output : str | None = None , sample : int | None = None ,
5277 traces : list [str ] | None = None , filter : str | None = None ,
5378 data_root : str | None = None ,
5479 device : str = "cuda:0" ,
@@ -80,6 +105,7 @@ def evaluate(
80105
81106 Args:
82107 path: path to results directory.
108+ output: if specified, write results to this directory instead.
83109 sample: number of samples to evaluate.
84110 traces: explicit list of traces to evaluate.
85111 filter: evaluate all traces matching this regex.
@@ -90,11 +116,16 @@ def evaluate(
90116 workers: number of workers for data loading.
91117 prefetch: number of batches to prefetch per worker.
92118 """
119+ torch .set_float32_matmul_precision ('high' )
120+
93121 result = Result (path )
94122 cfg = result .config ()
95123 if sample is not None :
96124 cfg ["datamodule" ]["subsample" ]["test" ] = sample
97125
126+ if output is None :
127+ output = os .path .join (path , "eval" )
128+
98129 if data_root is None :
99130 data_root = cfg ["meta" ]["dataset" ]
100131 if data_root is None :
@@ -107,6 +138,7 @@ def evaluate(
107138 cfg ["datamodule" ]["batch_size" ] = batch
108139 cfg ["datamodule" ]["num_workers" ] = workers
109140 cfg ["datamodule" ]["prefetch_factor" ] = prefetch
141+ cfg ["lightningmodule" ]["compile" ] = False
110142
111143 transforms = hydra .utils .instantiate (cfg ["transforms" ])
112144 lightningmodule = hydra .utils .instantiate (
@@ -120,7 +152,7 @@ def evaluate(
def collect_metadata(y_true):
    """Gather timestamp metadata from model targets that expose it.

    Args:
        y_true: mapping of output name to value; values may or may not
            carry a ``timestamps`` attribute.

    Returns:
        dict mapping ``meta/<name>/ts`` to that value's ``timestamps``,
        silently skipping entries without the attribute.
    """
    # hasattr guard: not every target type records timestamps, and a
    # bare attribute access would raise AttributeError for those.
    return {
        f"meta/{k}/ts": v.timestamps
        for k, v in y_true.items() if hasattr(v, "timestamps")
    }
125157
126158 for trace , dl_constructor in dataloaders .items ():
@@ -131,8 +163,7 @@ def collect_metadata(y_true):
131163 total = len (dataloader ), desc = trace )
132164
133165 output_container = DynamicSensor (
134- os .path .join (result .path , "eval" , trace ),
135- create = True , exist_ok = True )
166+ os .path .join (output , trace ), create = True , exist_ok = True )
136167 metrics = []
137168 outputs = {}
138169 for batch_metrics , vis in eval_stream :
@@ -160,7 +191,7 @@ def collect_metadata(y_true):
160191 k : np .concatenate ([m [k ] for m in metrics ], axis = 0 )
161192 for k in metrics [0 ]}
162193 np .savez_compressed (
163- os .path .join (result . path , "eval" , trace , "metrics.npz" ),
194+ os .path .join (output , trace , "metrics.npz" ),
164195 ** metrics , allow_pickle = False )
165196
166197 output_container .create ("ts" , meta = {
0 commit comments