-
Notifications
You must be signed in to change notification settings - Fork 108
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Xarray: add indexes options and better define band names #764
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,7 +3,7 @@ | |
from __future__ import annotations | ||
|
||
import warnings | ||
from typing import Any, Dict, List, Optional | ||
from typing import Any, Dict, List, Optional, Tuple | ||
|
||
import attr | ||
import numpy | ||
|
@@ -28,8 +28,13 @@ | |
from rio_tiler.io.base import BaseReader | ||
from rio_tiler.models import BandStatistics, ImageData, Info, PointData | ||
from rio_tiler.reader import _get_width_height | ||
from rio_tiler.types import BBox, NoData, RIOResampling, WarpResampling | ||
from rio_tiler.utils import CRS_to_uri, _validate_shape_input, get_array_statistics | ||
from rio_tiler.types import BBox, Indexes, NoData, RIOResampling, WarpResampling | ||
from rio_tiler.utils import ( | ||
CRS_to_uri, | ||
_validate_shape_input, | ||
cast_to_sequence, | ||
get_array_statistics, | ||
) | ||
|
||
try: | ||
import xarray | ||
|
@@ -105,6 +110,7 @@ def __attrs_post_init__(self): | |
for d in self.input.dims | ||
if d not in [self.input.rio.x_dim, self.input.rio.y_dim] | ||
] | ||
assert len(self._dims) in [0, 1], "Can't handle >=4D DataArray" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add a check to make sure we don't have 4D arrays |
||
|
||
@property | ||
def minzoom(self): | ||
|
@@ -118,29 +124,34 @@ def maxzoom(self): | |
|
||
@property | ||
def band_names(self) -> List[str]: | ||
"""Return list of `band names` in DataArray.""" | ||
return [str(band) for d in self._dims for band in self.input[d].values] or [ | ||
"value" | ||
] | ||
"""Return list of `band descriptions` in DataArray.""" | ||
vincentsarago marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if not self._dims: | ||
coords_name = list(self.input.coords) | ||
if len(coords_name) > 3 and (coord := coords_name[2]): | ||
return [str(self.input.coords[coord].data)] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm hesitant to put something like but we don't do this for the other band names There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's not super easy to understand what this code is trying to accomplish, but I think it's problematic that band names are based on dimensions if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
@maxrjones I'm not sure I get this. The main issue is that when we pass a 2D DataArray (think about what happens if you select the first time step), no non-spatial dimension is left to derive the band names from:
from datetime import datetime
import numpy
import xarray
import rioxarray
arr = numpy.arange(0.0, 33 * 35 * 2).reshape(2, 33, 35)
data = xarray.DataArray(
arr,
dims=("time", "y", "x"),
coords={
"x": numpy.arange(-170, 180, 10),
"y": numpy.arange(-80, 85, 5),
"time": [datetime(2022, 1, 1), datetime(2022, 1, 2)],
},
)
data.attrs.update({"valid_min": arr.min(), "valid_max": arr.max()})
ds = data.to_dataset(name="dataset")
da = ds["dataset"][0]
da.rio.write_crs("epsg:4326", inplace=True)
da
# >>> <xarray.DataArray 'dataset' (y: 33, x: 35)> Size: 9kB
# array([[0.000e+00, 1.000e+00, 2.000e+00, ..., 3.200e+01, 3.300e+01,
# 3.400e+01],
# [3.500e+01, 3.600e+01, 3.700e+01, ..., 6.700e+01, 6.800e+01,
# 6.900e+01],
# [7.000e+01, 7.100e+01, 7.200e+01, ..., 1.020e+02, 1.030e+02,
# 1.040e+02],
# ...,
# [1.050e+03, 1.051e+03, 1.052e+03, ..., 1.082e+03, 1.083e+03,
# 1.084e+03],
# [1.085e+03, 1.086e+03, 1.087e+03, ..., 1.117e+03, 1.118e+03,
# 1.119e+03],
# [1.120e+03, 1.121e+03, 1.122e+03, ..., 1.152e+03, 1.153e+03,
# 1.154e+03]])
# Coordinates:
# * x (x) int64 280B -170 -160 -150 -140 -130 ... 130 140 150 160 170
# * y (y) int64 264B -80 -75 -70 -65 -60 -55 ... 55 60 65 70 75 80
# time datetime64[ns] 8B 2022-01-01
# spatial_ref int64 8B 0
# Attributes:
# valid_min: 0.0
# valid_max: 2309.0
_dims = [
d
for d in da.dims
if d not in [da.rio.x_dim, da.rio.y_dim]
]
_dims
# >>> []
coords_name = list(da.coords)
coords_name
# >>> ['x', 'y', 'time', 'spatial_ref']
if len(coords_name) > 3 and (coord := coords_name[2]):
print(str(da.coords[coord].data))
# >>> 2022-01-01T00:00:00.000000000
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
@maxrjones could you double-check the logic here? |
||
|
||
return [self.input.name or "array"] | ||
|
||
return [str(band) for d in self._dims for band in self.input[d].values] | ||
|
||
def info(self) -> Info: | ||
"""Return xarray.DataArray info.""" | ||
bands = [str(band) for d in self._dims for band in self.input[d].values] or [ | ||
"value" | ||
] | ||
metadata = [band.attrs for d in self._dims for band in self.input[d]] or [{}] | ||
|
||
meta = { | ||
"bounds": self.bounds, | ||
"crs": CRS_to_uri(self.crs) or self.crs.to_wkt(), | ||
"band_metadata": [(f"b{ix}", v) for ix, v in enumerate(metadata, 1)], | ||
"band_descriptions": [(f"b{ix}", v) for ix, v in enumerate(bands, 1)], | ||
"band_descriptions": [ | ||
(f"b{ix}", v) for ix, v in enumerate(self.band_names, 1) | ||
], | ||
"dtype": str(self.input.dtype), | ||
"nodata_type": "Nodata" if self.input.rio.nodata is not None else "None", | ||
"name": self.input.name, | ||
"count": self.input.rio.count, | ||
"width": self.input.rio.width, | ||
"height": self.input.rio.height, | ||
"dimensions": self.input.dims, | ||
"attrs": { | ||
k: (v.tolist() if isinstance(v, (numpy.ndarray, numpy.generic)) else v) | ||
for k, v in self.input.attrs.items() | ||
|
@@ -149,19 +160,43 @@ def info(self) -> Info: | |
|
||
return Info(**meta) | ||
|
||
def _sel_indexes( | ||
self, indexes: Optional[Indexes] = None | ||
) -> Tuple[xarray.DataArray, List[str]]: | ||
"""Select `band` indexes in DataArray.""" | ||
ds = self.input | ||
band_names = self.band_names | ||
if indexes := cast_to_sequence(indexes): | ||
assert all(v > 0 for v in indexes), "Indexes value must be >= 1" | ||
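There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. xarray won't complain when we pass an index of 0 or below (after the 1-based to 0-based shift it would silently select from the end of the array), so we validate the values ourselves |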
||
if ds.ndim == 2: | ||
if indexes != (1,): | ||
raise ValueError( | ||
f"Invalid indexes {indexes} for array of shape {ds.shape}" | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. for 2D arrays we still allow `indexes=1` (i.e. `(1,)`) |
||
|
||
return ds, band_names | ||
|
||
indexes = [idx - 1 for idx in indexes] | ||
ds = ds[indexes] | ||
band_names = [self.band_names[idx] for idx in indexes] | ||
|
||
return ds, band_names | ||
|
||
def statistics( | ||
self, | ||
categorical: bool = False, | ||
categories: Optional[List[float]] = None, | ||
percentiles: Optional[List[int]] = None, | ||
hist_options: Optional[Dict] = None, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
**kwargs: Any, | ||
) -> Dict[str, BandStatistics]: | ||
"""Return statistics from a dataset.""" | ||
hist_options = hist_options or {} | ||
|
||
ds = self.input | ||
ds, band_names = self._sel_indexes(indexes) | ||
|
||
if nodata is not None: | ||
ds = ds.rio.write_nodata(nodata) | ||
|
||
|
@@ -176,9 +211,7 @@ def statistics( | |
**hist_options, | ||
) | ||
|
||
return { | ||
self.band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats) | ||
} | ||
return {band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats)} | ||
|
||
def tile( | ||
self, | ||
|
@@ -189,6 +222,7 @@ def tile( | |
reproject_method: WarpResampling = "nearest", | ||
auto_expand: bool = True, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
**kwargs: Any, | ||
) -> ImageData: | ||
"""Read a Web Map tile from a dataset. | ||
|
@@ -211,7 +245,8 @@ def tile( | |
f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds" | ||
) | ||
|
||
ds = self.input | ||
ds, band_names = self._sel_indexes(indexes) | ||
|
||
if nodata is not None: | ||
ds = ds.rio.write_nodata(nodata) | ||
|
||
|
@@ -251,7 +286,7 @@ def tile( | |
bounds=tile_bounds, | ||
crs=dst_crs, | ||
dataset_statistics=stats, | ||
band_names=self.band_names, | ||
band_names=band_names, | ||
) | ||
|
||
def part( | ||
|
@@ -262,6 +297,7 @@ def part( | |
reproject_method: WarpResampling = "nearest", | ||
auto_expand: bool = True, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
max_size: Optional[int] = None, | ||
height: Optional[int] = None, | ||
width: Optional[int] = None, | ||
|
@@ -294,7 +330,8 @@ def part( | |
|
||
dst_crs = dst_crs or bounds_crs | ||
|
||
ds = self.input | ||
ds, band_names = self._sel_indexes(indexes) | ||
|
||
if nodata is not None: | ||
ds = ds.rio.write_nodata(nodata) | ||
|
||
|
@@ -339,7 +376,7 @@ def part( | |
bounds=ds.rio.bounds(), | ||
crs=ds.rio.crs, | ||
dataset_statistics=stats, | ||
band_names=self.band_names, | ||
band_names=band_names, | ||
) | ||
|
||
output_height = height or img.height | ||
|
@@ -362,6 +399,7 @@ def preview( | |
height: Optional[int] = None, | ||
width: Optional[int] = None, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
dst_crs: Optional[CRS] = None, | ||
reproject_method: WarpResampling = "nearest", | ||
resampling_method: RIOResampling = "nearest", | ||
|
@@ -388,7 +426,8 @@ def preview( | |
UserWarning, | ||
) | ||
|
||
ds = self.input | ||
ds, band_names = self._sel_indexes(indexes) | ||
|
||
if nodata is not None: | ||
ds = ds.rio.write_nodata(nodata) | ||
|
||
|
@@ -427,7 +466,7 @@ def preview( | |
bounds=ds.rio.bounds(), | ||
crs=ds.rio.crs, | ||
dataset_statistics=stats, | ||
band_names=self.band_names, | ||
band_names=band_names, | ||
) | ||
|
||
output_height = height or img.height | ||
|
@@ -450,6 +489,7 @@ def point( | |
lat: float, | ||
coord_crs: CRS = WGS84_CRS, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
**kwargs: Any, | ||
) -> PointData: | ||
"""Read a pixel value from a dataset. | ||
|
@@ -472,7 +512,8 @@ def point( | |
): | ||
raise PointOutsideBounds("Point is outside dataset bounds") | ||
|
||
ds = self.input | ||
ds, band_names = self._sel_indexes(indexes) | ||
|
||
if nodata is not None: | ||
ds = ds.rio.write_nodata(nodata) | ||
|
||
|
@@ -489,7 +530,7 @@ def point( | |
arr, | ||
coordinates=(lon, lat), | ||
crs=coord_crs, | ||
band_names=self.band_names, | ||
band_names=band_names, | ||
) | ||
|
||
def feature( | ||
|
@@ -500,6 +541,7 @@ def feature( | |
reproject_method: WarpResampling = "nearest", | ||
auto_expand: bool = True, | ||
nodata: Optional[NoData] = None, | ||
indexes: Optional[Indexes] = None, | ||
max_size: Optional[int] = None, | ||
height: Optional[int] = None, | ||
width: Optional[int] = None, | ||
|
@@ -537,6 +579,7 @@ def feature( | |
dst_crs=dst_crs, | ||
bounds_crs=shape_crs, | ||
nodata=nodata, | ||
indexes=indexes, | ||
max_size=max_size, | ||
width=width, | ||
height=height, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1060,7 +1060,7 @@ def _get_reader(self, asset_info: AssetInfo) -> Tuple[Type[BaseReader], Dict]: | |
assert info["netcdf"].crs | ||
|
||
img = stac.preview(assets=["netcdf"]) | ||
assert img.band_names == ["netcdf_value"] | ||
assert img.band_names == ["netcdf_dataset"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. band value default to the DataArray's name ✨ |
||
|
||
|
||
@patch("rio_tiler.io.stac.STAC_ALTERNATE_KEY", "s3") | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why have the band name returned as `b1` rather than `time` or `time1` (corresponding to `<dimension-name><1-based index>`)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It was mostly for compatibility, but I think having `<dim_name><idx>` would be 👍
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I will have the same issue when there are only 2 dimensions: what should the band name be? For now I would use the same code I'm using to get the band name from the first non-geo coordinate 🤷