Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Xarray: add indexes options and better define band names #764

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,73 @@
# 7.3.0 (TBD)

* add `indexes` parameter for `XarrayReader` methods. As for Rasterio, the indexes values start at `1`.

```python
data = ... # DataArray of shape (2, x, y)

# before
with XarrayReader(data) as dst:
img = dst.tile(0, 0, 0)
assert img.count == 2

# now
with XarrayReader(data) as dst:
# Select the first `band` within the data array
img = dst.tile(0, 0, 0, indexes=1)
assert img.count == 1
```

* better define `band names` for `XarrayReader` objects

* band_name for `2D` dataset is extracted form the first `non-geo` coordinates value

```python
data = xarray.DataArray(
numpy.arange(0.0, 33 * 35 * 2).reshape(2, 33, 35),
dims=("time", "y", "x"),
coords={
"x": numpy.arange(-170, 180, 10),
"y": numpy.arange(-80, 85, 5),
"time": [datetime(2022, 1, 1), datetime(2022, 1, 2)],
},
)
da = data[0]

print(da.coords["time"].data)
>> array('2022-01-01T00:00:00.000000000', dtype='datetime64[ns]'))

# before
with XarrayReader(data) as dst:
img = dst.info()
print(img.band_descriptions)[0]
>> ("b1", "value")

# now
with XarrayReader(data) as dst:
img = dst.info()
print(img.band_descriptions)[0]
>> ("b1", "2022-01-01T00:00:00.000000000")

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why have the band name returned as b1 rather than time or time1 (corresponding to <dimension-name><1-based index>)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was mostly for compatibility but I think having <dim_name><idx> would be 👍

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I will have the same issue when there are only 2 Dimensions what should the band name be? For now I would use the same code I'm using to get the band name from the first non-geo coordinate 🤷

```

* default `band_names` is changed to DataArray's name or `array` (when no available coordinates value)

```python
data = ... # DataArray of shape (x, y)

# before
with XarrayReader(data) as dst:
img = dst.info()
print(img.band_descriptions)[0]
>> ("b1", "value")

# now
with XarrayReader(data) as dst:
img = dst.info()
print(img.band_descriptions)[0]
>> ("b1", "array")
```


# 7.2.0 (2024-11-05)

* Ensure compatibility between XarrayReader and other Readers by adding `**kwargs` on class methods (https://github.com/cogeotiff/rio-tiler/pull/762)
Expand Down
89 changes: 66 additions & 23 deletions rio_tiler/io/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import warnings
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

import attr
import numpy
Expand All @@ -28,8 +28,13 @@
from rio_tiler.io.base import BaseReader
from rio_tiler.models import BandStatistics, ImageData, Info, PointData
from rio_tiler.reader import _get_width_height
from rio_tiler.types import BBox, NoData, RIOResampling, WarpResampling
from rio_tiler.utils import CRS_to_uri, _validate_shape_input, get_array_statistics
from rio_tiler.types import BBox, Indexes, NoData, RIOResampling, WarpResampling
from rio_tiler.utils import (
CRS_to_uri,
_validate_shape_input,
cast_to_sequence,
get_array_statistics,
)

try:
import xarray
Expand Down Expand Up @@ -105,6 +110,7 @@ def __attrs_post_init__(self):
for d in self.input.dims
if d not in [self.input.rio.x_dim, self.input.rio.y_dim]
]
assert len(self._dims) in [0, 1], "Can't handle >=4D DataArray"
Copy link
Member Author

@vincentsarago vincentsarago Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a check to make sure we don't have 4D arrays


@property
def minzoom(self):
Expand All @@ -118,29 +124,34 @@ def maxzoom(self):

@property
def band_names(self) -> List[str]:
"""Return list of `band names` in DataArray."""
return [str(band) for d in self._dims for band in self.input[d].values] or [
"value"
]
"""Return list of `band descriptions` in DataArray."""
vincentsarago marked this conversation as resolved.
Show resolved Hide resolved
if not self._dims:
coords_name = list(self.input.coords)
if len(coords_name) > 3 and (coord := coords_name[2]):
return [str(self.input.coords[coord].data)]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm hesitant to put something like {coord_name}={coord_value}} 🤷‍♂️

but we don't do this for the other band names

Copy link

@maxrjones maxrjones Nov 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's not super easy to understand what this code is trying to accomplish, but I think it's problematic that band names are based on dimensions if _dims is set as an attribute and the names are based on coordinates if not. I think it should always be based on non-spatial (as defined by rioxarray) dimensions. Since all dimensions have names, this should also make dealing with defaults simpler. Some documentation about how to map Xarray's data model into rio-tiler's assumptions would really help in general.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but I think it's problematic that band names are based on dimensions if _dims is set as an attribute and the names are based on coordinates if not. I think it should always be based on non-spatial (as defined by rioxarray) dimensions. Since all dimensions have names, this should also make dealing with defaults simpler. Some documentation about how to map Xarray's data model into rio-tiler's assumptions would really help in general.

@maxrjones I'm not sure to get this

The main issue is that when we pass a 2D dataarray (thinks about if you select the first time dim) then dims is empty but the coordinates has the time value which correspond to the name of the array slice (maybe I'm mistaken) so I felt we needed to have a way to surface this.

from datetime import datetime
import numpy
import xarray
import rioxarray

arr = numpy.arange(0.0, 33 * 35 * 2).reshape(2, 33, 35)
data = xarray.DataArray(
    arr,
    dims=("time", "y", "x"),
    coords={
        "x": numpy.arange(-170, 180, 10),
        "y": numpy.arange(-80, 85, 5),
        "time": [datetime(2022, 1, 1), datetime(2022, 1, 2)],
    },
)
data.attrs.update({"valid_min": arr.min(), "valid_max": arr.max()})

ds = data.to_dataset(name="dataset")
da = ds["dataset"][0]
da.rio.write_crs("epsg:4326", inplace=True)

da
# >>> <xarray.DataArray 'dataset' (y: 33, x: 35)> Size: 9kB
# array([[0.000e+00, 1.000e+00, 2.000e+00, ..., 3.200e+01, 3.300e+01,
#         3.400e+01],
#        [3.500e+01, 3.600e+01, 3.700e+01, ..., 6.700e+01, 6.800e+01,
#         6.900e+01],
#        [7.000e+01, 7.100e+01, 7.200e+01, ..., 1.020e+02, 1.030e+02,
#         1.040e+02],
#        ...,
#        [1.050e+03, 1.051e+03, 1.052e+03, ..., 1.082e+03, 1.083e+03,
#         1.084e+03],
#        [1.085e+03, 1.086e+03, 1.087e+03, ..., 1.117e+03, 1.118e+03,
#         1.119e+03],
#        [1.120e+03, 1.121e+03, 1.122e+03, ..., 1.152e+03, 1.153e+03,
#         1.154e+03]])
# Coordinates:
#   * x            (x) int64 280B -170 -160 -150 -140 -130 ... 130 140 150 160 170
#   * y            (y) int64 264B -80 -75 -70 -65 -60 -55 ... 55 60 65 70 75 80
#     time         datetime64[ns] 8B 2022-01-01
#     spatial_ref  int64 8B 0
# Attributes:
#     valid_min:  0.0
#     valid_max:  2309.0

_dims = [
    d
    for d in da.dims
    if d not in [da.rio.x_dim, da.rio.y_dim]
]
_dims
# >>> []


coords_name = list(da.coords)
coords_name
# >>> ['x', 'y', 'time', 'spatial_ref']

if len(coords_name) > 3 and (coord := coords_name[2]):
    print(str(da.coords[coord].data))
# >>> 2022-01-01T00:00:00.000000000

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@maxrjones could you double check the logic here?


return [self.input.name or "array"]

return [str(band) for d in self._dims for band in self.input[d].values]

def info(self) -> Info:
"""Return xarray.DataArray info."""
bands = [str(band) for d in self._dims for band in self.input[d].values] or [
"value"
]
metadata = [band.attrs for d in self._dims for band in self.input[d]] or [{}]

meta = {
"bounds": self.bounds,
"crs": CRS_to_uri(self.crs) or self.crs.to_wkt(),
"band_metadata": [(f"b{ix}", v) for ix, v in enumerate(metadata, 1)],
"band_descriptions": [(f"b{ix}", v) for ix, v in enumerate(bands, 1)],
"band_descriptions": [
(f"b{ix}", v) for ix, v in enumerate(self.band_names, 1)
],
"dtype": str(self.input.dtype),
"nodata_type": "Nodata" if self.input.rio.nodata is not None else "None",
"name": self.input.name,
"count": self.input.rio.count,
"width": self.input.rio.width,
"height": self.input.rio.height,
"dimensions": self.input.dims,
"attrs": {
k: (v.tolist() if isinstance(v, (numpy.ndarray, numpy.generic)) else v)
for k, v in self.input.attrs.items()
Expand All @@ -149,19 +160,43 @@ def info(self) -> Info:

return Info(**meta)

def _sel_indexes(
self, indexes: Optional[Indexes] = None
) -> Tuple[xarray.DataArray, List[str]]:
"""Select `band` indexes in DataArray."""
ds = self.input
band_names = self.band_names
if indexes := cast_to_sequence(indexes):
assert all(v > 0 for v in indexes), "Indexes value must be >= 1"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xarray won't complain when we pass data[-1] so we need this tests

if ds.ndim == 2:
if indexes != (1,):
raise ValueError(
f"Invalid indexes {indexes} for array of shape {ds.shape}"
)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for 2D array we still allow indexes=1


return ds, band_names

indexes = [idx - 1 for idx in indexes]
ds = ds[indexes]
band_names = [self.band_names[idx] for idx in indexes]

return ds, band_names

def statistics(
self,
categorical: bool = False,
categories: Optional[List[float]] = None,
percentiles: Optional[List[int]] = None,
hist_options: Optional[Dict] = None,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
**kwargs: Any,
) -> Dict[str, BandStatistics]:
"""Return statistics from a dataset."""
hist_options = hist_options or {}

ds = self.input
ds, band_names = self._sel_indexes(indexes)

if nodata is not None:
ds = ds.rio.write_nodata(nodata)

Expand All @@ -176,9 +211,7 @@ def statistics(
**hist_options,
)

return {
self.band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats)
}
return {band_names[ix]: BandStatistics(**val) for ix, val in enumerate(stats)}

def tile(
self,
Expand All @@ -189,6 +222,7 @@ def tile(
reproject_method: WarpResampling = "nearest",
auto_expand: bool = True,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
**kwargs: Any,
) -> ImageData:
"""Read a Web Map tile from a dataset.
Expand All @@ -211,7 +245,8 @@ def tile(
f"Tile(x={tile_x}, y={tile_y}, z={tile_z}) is outside bounds"
)

ds = self.input
ds, band_names = self._sel_indexes(indexes)

if nodata is not None:
ds = ds.rio.write_nodata(nodata)

Expand Down Expand Up @@ -251,7 +286,7 @@ def tile(
bounds=tile_bounds,
crs=dst_crs,
dataset_statistics=stats,
band_names=self.band_names,
band_names=band_names,
)

def part(
Expand All @@ -262,6 +297,7 @@ def part(
reproject_method: WarpResampling = "nearest",
auto_expand: bool = True,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
max_size: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down Expand Up @@ -294,7 +330,8 @@ def part(

dst_crs = dst_crs or bounds_crs

ds = self.input
ds, band_names = self._sel_indexes(indexes)

if nodata is not None:
ds = ds.rio.write_nodata(nodata)

Expand Down Expand Up @@ -339,7 +376,7 @@ def part(
bounds=ds.rio.bounds(),
crs=ds.rio.crs,
dataset_statistics=stats,
band_names=self.band_names,
band_names=band_names,
)

output_height = height or img.height
Expand All @@ -362,6 +399,7 @@ def preview(
height: Optional[int] = None,
width: Optional[int] = None,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
dst_crs: Optional[CRS] = None,
reproject_method: WarpResampling = "nearest",
resampling_method: RIOResampling = "nearest",
Expand All @@ -388,7 +426,8 @@ def preview(
UserWarning,
)

ds = self.input
ds, band_names = self._sel_indexes(indexes)

if nodata is not None:
ds = ds.rio.write_nodata(nodata)

Expand Down Expand Up @@ -427,7 +466,7 @@ def preview(
bounds=ds.rio.bounds(),
crs=ds.rio.crs,
dataset_statistics=stats,
band_names=self.band_names,
band_names=band_names,
)

output_height = height or img.height
Expand All @@ -450,6 +489,7 @@ def point(
lat: float,
coord_crs: CRS = WGS84_CRS,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
**kwargs: Any,
) -> PointData:
"""Read a pixel value from a dataset.
Expand All @@ -472,7 +512,8 @@ def point(
):
raise PointOutsideBounds("Point is outside dataset bounds")

ds = self.input
ds, band_names = self._sel_indexes(indexes)

if nodata is not None:
ds = ds.rio.write_nodata(nodata)

Expand All @@ -489,7 +530,7 @@ def point(
arr,
coordinates=(lon, lat),
crs=coord_crs,
band_names=self.band_names,
band_names=band_names,
)

def feature(
Expand All @@ -500,6 +541,7 @@ def feature(
reproject_method: WarpResampling = "nearest",
auto_expand: bool = True,
nodata: Optional[NoData] = None,
indexes: Optional[Indexes] = None,
max_size: Optional[int] = None,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down Expand Up @@ -537,6 +579,7 @@ def feature(
dst_crs=dst_crs,
bounds_crs=shape_crs,
nodata=nodata,
indexes=indexes,
max_size=max_size,
width=width,
height=height,
Expand Down
Loading
Loading