Skip to content

Commit

Permalink
Various updates, trying to remove rasterio in favor of rioxarray and …
Browse files Browse the repository at this point in the history
…numpy in favor of xarray
  • Loading branch information
remi-braun committed Dec 11, 2024
1 parent 5fc7889 commit 66614c1
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 105 deletions.
72 changes: 39 additions & 33 deletions eoreader/products/optical/dimap_v2_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,41 +1028,47 @@ def open_mask(self, mask_str: str, **kwargs) -> gpd.GeoDataFrame:

mask.crs = WGS84
LOGGER.info(f"Orthorectifying {mask_str}")
with rasterio.open(str(self._get_tile_path())) as dim_dst:
# Rasterize mask (no transform as we have the vector in image geometry)
LOGGER.debug(f"\tRasterizing {mask_str}")
mask_raster = features.rasterize(
mask.geometry,
out_shape=(dim_dst.height, dim_dst.width),
fill=self._mask_false, # Outside vector
default_value=self._mask_true, # Inside vector
dtype=np.uint8,
)
# Check mask validity (to avoid reprojecting)
# All null
if mask_raster.max() == 0:
mask = gpd.GeoDataFrame(geometry=[], crs=crs)
# All valid
elif mask_raster.min() == 1:
pass
else:
# Reproject mask raster
LOGGER.debug(f"\tReprojecting {mask_str}")
dem_path = self._get_dem_path(**kwargs)
reproj_data = self._reproject(
mask_raster, dim_dst.meta, dim_dst.rpcs, dem_path, **kwargs
)

# Vectorize mask raster
LOGGER.debug(f"\tRevectorizing {mask_str}")
mask = rasters.vectorize(
reproj_data,
values=self._mask_true,
default_nodata=self._mask_false,
)
# Rasterize mask (no transform as we have the vector in image geometry)
LOGGER.debug(f"\tRasterizing {mask_str}")
tile = utils.read(self._get_tile_path())[0:1, ...]

mask_raster = rasters.rasterize(
tile,
mask,
default_nodata=self._mask_false, # Outside vector
default_value=self._mask_true, # Inside vector
dtype=np.uint8,
)
# Check mask validity (to avoid reprojecting)
# All null
if mask_raster.max() == 0:
mask = gpd.GeoDataFrame(geometry=[], crs=crs)
# All valid
elif mask_raster.min() == 1:
pass
else:
# Reproject mask raster
LOGGER.debug(f"\tReprojecting {mask_str}")
dem_path = self._get_dem_path(**kwargs)

# TODO: change this when available in rioxarray
# See https://github.com/corteva/rioxarray/issues/837
with rasterio.open(self._get_tile_path()) as ds:
rpcs = ds.rpcs

reproj_data = self._reproject(mask_raster, rpcs, dem_path, **kwargs)

# Vectorize mask raster
LOGGER.debug(f"\tRevectorizing {mask_str}")
mask = rasters.vectorize(
reproj_data,
values=self._mask_true,
default_nodata=self._mask_false,
)

# Do not keep pixelized mask
mask = geometry.simplify_footprint(mask, self.pixel_size)
# Do not keep pixelized mask
mask = geometry.simplify_footprint(mask, self.pixel_size)

# Sometimes the GML mask lacks crs (why ?)
elif (
Expand Down
7 changes: 4 additions & 3 deletions eoreader/products/optical/optical_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def _load_clouds(
return band_dict

def _create_mask(
self, xda: xr.DataArray, cond: np.ndarray, nodata: np.ndarray
self, xda: xr.DataArray, cond: np.ndarray, nodata: np.ndarray = None
) -> xr.DataArray:
"""
Create a mask from a conditional array and a nodata mask.
Expand All @@ -618,10 +618,11 @@ def _create_mask(
xr.DataArray: Mask as xarray
"""
# Create mask
mask = xda.copy(data=np.where(cond, self._mask_true, self._mask_false))
mask = xda.copy(data=xr.where(cond, self._mask_true, self._mask_false))

# Set nodata to mask
mask = mask.where(nodata == 0)
if nodata is not None:
mask = mask.where(nodata == 0)

return mask

Expand Down
15 changes: 3 additions & 12 deletions eoreader/products/optical/s2_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1155,23 +1155,14 @@ def _manage_nodata_lt_4_0(

if len(nodata_det) > 0:
# Rasterize nodata
mask = features.rasterize(
nodata_det.geometry,
out_shape=(band_arr.rio.height, band_arr.rio.width),
fill=self._mask_true, # Outside detector = nodata (inverted compared to the usual)
default_value=self._mask_false, # Inside detector = not nodata
transform=transform.from_bounds(
*band_arr.rio.bounds(), band_arr.rio.width, band_arr.rio.height
),
dtype=np.uint8,
)
mask = self._rasterize(band_arr, nodata_det)
else:
# Manage empty geometry: nodata is 0
LOGGER.warning(
"Empty detector footprint (DETFOO) vector. Nodata will be set where the pixels are null."
)
s2_nodata = 0
mask = np.where(band_arr == s2_nodata, 1, 0).astype(np.uint8)
mask = xr.where(band_arr == s2_nodata, 1, 0).astype(np.uint8)

return self._set_nodata_mask(band_arr, mask)

Expand Down Expand Up @@ -1500,7 +1491,7 @@ def _open_clouds(
return self._open_clouds_gt_4_0(bands, pixel_size, size, **kwargs)

def _rasterize(
self, xds: xr.DataArray, geometry: gpd.GeoDataFrame, nodata: np.ndarray
self, xds: xr.DataArray, geometry: gpd.GeoDataFrame, nodata: np.ndarray = None
) -> xr.DataArray:
"""
Rasterize a vector on a memory dataset
Expand Down
18 changes: 11 additions & 7 deletions eoreader/products/optical/sv1_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@
import xarray as xr
from lxml import etree
from rasterio import crs as riocrs
from sertit import path, rasters_rio, vectors
from sertit import path, vectors
from sertit.misc import ListEnum
from sertit.types import AnyPathType
from shapely.geometry import box

from eoreader import DATETIME_FMT, EOREADER_NAME, cache
from eoreader import DATETIME_FMT, EOREADER_NAME, cache, utils
from eoreader.bands import (
BLUE,
GREEN,
Expand Down Expand Up @@ -648,11 +648,15 @@ def _get_ortho_path(self, **kwargs) -> AnyPathType:

# Reproject and write on disk data
dem_path = self._get_dem_path(**kwargs)
with rasterio.open(str(self._get_tile_path(**kwargs))) as src:
out_arr, meta = self._reproject(
src.read(), src.meta, src.rpcs, dem_path, **kwargs
)
rasters_rio.write(out_arr, meta, ortho_path)
tile = utils.read(self._get_tile_path(**kwargs))

# TODO: change this when available in rioxarray
# See https://github.com/corteva/rioxarray/issues/837
with rasterio.open(self._get_tile_path(**kwargs)) as ds:
rpcs = ds.rpcs

out = self._reproject(tile, rpcs, dem_path, **kwargs)
utils.write(out, ortho_path)

else:
ortho_path = self._get_tile_path(**kwargs)
Expand Down
140 changes: 90 additions & 50 deletions eoreader/products/optical/vhr_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from rasterio.crs import CRS
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from sertit import AnyPath, path, rasters, rasters_rio
from sertit import AnyPath, path, rasters
from sertit.types import AnyPathStrType, AnyPathType

from eoreader import EOREADER_NAME, utils
Expand Down Expand Up @@ -162,27 +162,37 @@ def _get_ortho_path(self, **kwargs) -> AnyPathType:
LOGGER.info(
"Manually orthorectified stack not given by the user. "
"Reprojecting whole stack, this may take a while. "
"(May be inaccurate on steep terrain, depending on the DEM pixel size)"
"(Might be inaccurate on steep terrain, depending on the DEM pixel size)."
)

# Reproject and write on disk data
dem_path = self._get_dem_path(**kwargs)
with rasterio.open(str(self._get_tile_path())) as src:
if "rpcs" in kwargs:
rpcs = kwargs.pop("rpcs")
else:
rpcs = src.rpcs

if not rpcs:
raise InvalidProductError(
"Your projected VHR data doesn't have any RPC. "
"EOReader cannot orthorectify it!"
)

out_arr, meta = self._reproject(
src.read(), src.meta, rpcs, dem_path, **kwargs
tile = utils.read(self._get_tile_path())

if "rpcs" in kwargs:
rpcs = kwargs.pop("rpcs")
else:
# TODO: change this when available in rioxarray
# See https://github.com/corteva/rioxarray/issues/837
with rasterio.open(self._get_tile_path()) as ds:
rpcs = ds.rpcs
tags = ds.tags()

if not rpcs:
raise InvalidProductError(
"Your projected VHR data doesn't have any RPC. "
"EOReader cannot orthorectify it!"
)
rasters_rio.write(out_arr, meta, ortho_path, tags=src.tags())

out = self._reproject(tile, rpcs, dem_path, **kwargs)
utils.write(
out,
ortho_path,
dtype=np.float32,
nodata=self._raw_nodata,
tags=tags,
)

else:
ortho_path = self._get_tile_path()
Expand Down Expand Up @@ -270,7 +280,7 @@ def _get_dem_path(self, **kwargs) -> str:
return dem_path

def _reproject(
self, src_arr: np.ndarray, src_meta: dict, rpcs: rpc.RPC, dem_path, **kwargs
self, src_xda: xr.DataArray, rpcs: rpc.RPC, dem_path, **kwargs
) -> (np.ndarray, dict):
"""
Reproject using RPCs (cannot use another pixel size than src to ensure RPCs are valid)
Expand All @@ -284,45 +294,75 @@ def _reproject(
Returns:
(np.ndarray, dict): Reprojected array and its metadata
"""

# Set RPC keywords
LOGGER.debug(f"Orthorectifying data with {dem_path}")
kwargs = {
"RPC_DEM": dem_path,
"RPC_DEM_MISSING_VALUE": 0,
"OSR_USE_ETMERC": "YES",
"BIGTIFF": "IF_NEEDED",
}

# Reproject
# WARNING: may not give correct output pixel size
out_arr, dst_transform = warp.reproject(
src_arr,
rpcs=rpcs,
src_crs=self._get_raw_crs(),
src_nodata=self._raw_nodata,
kwargs.update(
{
"RPC_DEM": dem_path,
"RPC_DEM_MISSING_VALUE": 0,
"OSR_USE_ETMERC": "YES",
"BIGTIFF": "IF_NEEDED",
}
)

# Reproject with rioxarray
# Seems to handle the resolution well on the contrary to rasterio's reproject...
out_xda = src_xda.rio.reproject(
dst_crs=self.crs(),
dst_resolution=self.pixel_size,
dst_nodata=self._raw_nodata, # input data should be in integer
num_threads=utils.get_max_cores(),
resolution=self.pixel_size,
resampling=Resampling.bilinear,
nodata=self._raw_nodata,
num_threads=utils.get_max_cores(),
rpcs=rpcs,
dtype=src_xda.dtype,
**kwargs,
)
# Get dims
count, height, width = out_arr.shape

# Update metadata
meta = src_meta.copy()
meta["transform"] = dst_transform
meta["driver"] = "GTiff"
meta["compress"] = "lzw"
meta["nodata"] = self._raw_nodata
meta["crs"] = self.crs()
meta["width"] = width
meta["height"] = height
meta["count"] = count

return out_arr, meta
out_xda.rename(f"Reprojected stack of {self.name}")
out_xda.attrs["long_name"] = self.get_bands_names()

# Daskified reproject doesn't seem to work with RPC
# See https://github.com/opendatacube/odc-geo/issues/193
# from odc.geo import xr # noqa
# out_xda = src_xda.odc.reproject(
# how=self.crs(),
# resolution=self.pixel_size,
# resampling=Resampling.bilinear,
# dst_nodata=self._raw_nodata,
# num_threads=utils.get_max_cores(),
# rpcs=rpcs,
# dtype=src_xda.dtype,
# **kwargs
# )

# Legacy with rasterio directly
# WARNING: may not give correct output pixel size
# out_arr, dst_transform = warp.reproject(
# src_arr,
# rpcs=rpcs,
# src_crs=self._get_raw_crs(),
# src_nodata=self._raw_nodata,
# dst_crs=self.crs(),
# dst_resolution=self.pixel_size,
# dst_nodata=self._raw_nodata, # input data should be in integer
# num_threads=utils.get_max_cores(),
# resampling=Resampling.bilinear,
# **kwargs,
# )
# # Get dims
# count, height, width = out_arr.shape
#
# # Update metadata
# meta = src_meta.copy()
# meta["transform"] = dst_transform
# meta["driver"] = "GTiff"
# meta["compress"] = "lzw"
# meta["nodata"] = self._raw_nodata
# meta["crs"] = self.crs()
# meta["width"] = width
# meta["height"] = height
# meta["count"] = count

return out_xda

def _read_band(
self,
Expand Down
13 changes: 13 additions & 0 deletions eoreader/products/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -2187,3 +2187,16 @@ def _read_archived_vector(
file_list=self._get_archived_file_list(archive_path),
**kwargs,
)

def get_bands_names(self) -> list:
"""
Get the name of the bands composing the product, ordered by ID.
For example, for SPOT7: ['RED', 'GREEN', 'BLUE', 'NIR']
Returns:
list: Ordered bands names
"""
stack_bands = {
band.id: band.name for band in self.bands.values() if band is not None
}
return [id_name[1] for id_name in sorted(stack_bands.items())]

0 comments on commit 66614c1

Please sign in to comment.