@@ -878,8 +878,13 @@ def from_acq_h5_fast(cls, fname, comm=None, freq_sel=None, start=None, stop=None
878878 # Datasets read by andata (should be small)
879879 DSET_CORE = ["flags/inputs" , "flags/frac_lost" , "flags/dataset_id" ]
880880 # Datasets read directly and then inserted after the fact
881- # (should have an input/product/stack axis, as axis=1)
882- DSETS_DIRECT = ["vis" , "gain" , "flags/vis_weight" ]
881+ # Dictionary entries are dataset name : axis to distribute over,
882+ # and are specified here based on performance tests
883+ DSETS_DIRECT = {
884+ "vis" : 1 ,
885+ "gain" : 0 if (freq_sel is None ) or (freq_sel == slice (None )) else 1 ,
886+ "flags/vis_weight" : 1 ,
887+ }
883888
884889 if comm is None :
885890 comm = MPI .COMM_WORLD
@@ -912,14 +917,13 @@ def from_acq_h5_fast(cls, fname, comm=None, freq_sel=None, start=None, stop=None
912917 sel = (freq_sel , slice (None ), time_sel )
913918
914919 with misc .open_h5py_mpi (fname , "r" , comm = comm ) as fh :
915- for ds_name in DSETS_DIRECT :
920+ for ds_name , ds_axis in DSETS_DIRECT . items () :
916921 if ds_name not in fh :
917922 continue
918923
919- # Read dataset directly (distributed over input/product/stack axis) and
920- # add to container
924+ # Read dataset directly and add to container
921925 arr = mpiarray .MPIArray .from_hdf5 (
922- fh , ds_name , comm = comm , axis = 1 , sel = sel
926+ fh , ds_name , comm = comm , axis = ds_axis , sel = sel
923927 )
924928 arr = arr .redistribute (axis = 0 )
925929 dset = ad .create_dataset (ds_name , data = arr , distributed = True )
0 commit comments