From 66614c16c20b5ae26664bc598e7b9d4ccf5c2395 Mon Sep 17 00:00:00 2001 From: BRAUN REMI Date: Wed, 11 Dec 2024 19:04:40 +0100 Subject: [PATCH] Various updates, trying to remove rasterio in favor of rioxarray and numpy in favor of xarray --- eoreader/products/optical/dimap_v2_product.py | 72 ++++----- eoreader/products/optical/optical_product.py | 7 +- eoreader/products/optical/s2_product.py | 15 +- eoreader/products/optical/sv1_product.py | 18 ++- eoreader/products/optical/vhr_product.py | 140 +++++++++++------- eoreader/products/product.py | 13 ++ 6 files changed, 160 insertions(+), 105 deletions(-) diff --git a/eoreader/products/optical/dimap_v2_product.py b/eoreader/products/optical/dimap_v2_product.py index 1d9231f4..bb8fdcc3 100644 --- a/eoreader/products/optical/dimap_v2_product.py +++ b/eoreader/products/optical/dimap_v2_product.py @@ -1028,41 +1028,47 @@ def open_mask(self, mask_str: str, **kwargs) -> gpd.GeoDataFrame: mask.crs = WGS84 LOGGER.info(f"Orthorectifying {mask_str}") - with rasterio.open(str(self._get_tile_path())) as dim_dst: - # Rasterize mask (no transform as we have the vector in image geometry) - LOGGER.debug(f"\tRasterizing {mask_str}") - mask_raster = features.rasterize( - mask.geometry, - out_shape=(dim_dst.height, dim_dst.width), - fill=self._mask_false, # Outside vector - default_value=self._mask_true, # Inside vector - dtype=np.uint8, - ) - # Check mask validity (to avoid reprojecting) - # All null - if mask_raster.max() == 0: - mask = gpd.GeoDataFrame(geometry=[], crs=crs) - # All valid - elif mask_raster.min() == 1: - pass - else: - # Reproject mask raster - LOGGER.debug(f"\tReprojecting {mask_str}") - dem_path = self._get_dem_path(**kwargs) - reproj_data = self._reproject( - mask_raster, dim_dst.meta, dim_dst.rpcs, dem_path, **kwargs - ) - # Vectorize mask raster - LOGGER.debug(f"\tRevectorizing {mask_str}") - mask = rasters.vectorize( - reproj_data, - values=self._mask_true, - default_nodata=self._mask_false, - ) + # Rasterize mask (no transform as we have the vector in image geometry) + LOGGER.debug(f"\tRasterizing {mask_str}") + tile = utils.read(self._get_tile_path())[0:1, ...] + + mask_raster = rasters.rasterize( + tile, + mask, + default_nodata=self._mask_false, # Outside vector + default_value=self._mask_true, # Inside vector + dtype=np.uint8, + ) + # Check mask validity (to avoid reprojecting) + # All null + if mask_raster.max() == 0: + mask = gpd.GeoDataFrame(geometry=[], crs=crs) + # All valid + elif mask_raster.min() == 1: + pass + else: + # Reproject mask raster + LOGGER.debug(f"\tReprojecting {mask_str}") + dem_path = self._get_dem_path(**kwargs) + + # TODO: change this when available in rioxarray + # See https://github.com/corteva/rioxarray/issues/837 + with rasterio.open(self._get_tile_path()) as ds: + rpcs = ds.rpcs + + reproj_data = self._reproject(mask_raster, rpcs, dem_path, **kwargs) + + # Vectorize mask raster + LOGGER.debug(f"\tRevectorizing {mask_str}") + mask = rasters.vectorize( + reproj_data, + values=self._mask_true, + default_nodata=self._mask_false, + ) - # Do not keep pixelized mask - mask = geometry.simplify_footprint(mask, self.pixel_size) + # Do not keep pixelized mask + mask = geometry.simplify_footprint(mask, self.pixel_size) # Sometimes the GML mask lacks crs (why ?) elif ( diff --git a/eoreader/products/optical/optical_product.py b/eoreader/products/optical/optical_product.py index 51b19d0d..a9ed44ed 100644 --- a/eoreader/products/optical/optical_product.py +++ b/eoreader/products/optical/optical_product.py @@ -604,7 +604,7 @@ def _load_clouds( return band_dict def _create_mask( - self, xda: xr.DataArray, cond: np.ndarray, nodata: np.ndarray + self, xda: xr.DataArray, cond: np.ndarray, nodata: np.ndarray = None ) -> xr.DataArray: """ Create a mask from a conditional array and a nodata mask. @@ -618,10 +618,11 @@ def _create_mask( xr.DataArray: Mask as xarray """ # Create mask - mask = xda.copy(data=np.where(cond, self._mask_true, self._mask_false)) + mask = xda.copy(data=xr.where(cond, self._mask_true, self._mask_false)) # Set nodata to mask - mask = mask.where(nodata == 0) + if nodata is not None: + mask = mask.where(nodata == 0) return mask diff --git a/eoreader/products/optical/s2_product.py b/eoreader/products/optical/s2_product.py index 94290180..3943f751 100644 --- a/eoreader/products/optical/s2_product.py +++ b/eoreader/products/optical/s2_product.py @@ -1155,23 +1155,14 @@ def _manage_nodata_lt_4_0( if len(nodata_det) > 0: # Rasterize nodata - mask = features.rasterize( - nodata_det.geometry, - out_shape=(band_arr.rio.height, band_arr.rio.width), - fill=self._mask_true, # Outside detector = nodata (inverted compared to the usual) - default_value=self._mask_false, # Inside detector = not nodata - transform=transform.from_bounds( - *band_arr.rio.bounds(), band_arr.rio.width, band_arr.rio.height - ), - dtype=np.uint8, - ) + mask = self._rasterize(band_arr, nodata_det) else: # Manage empty geometry: nodata is 0 LOGGER.warning( "Empty detector footprint (DETFOO) vector. Nodata will be set where the pixels are null." ) s2_nodata = 0 - mask = np.where(band_arr == s2_nodata, 1, 0).astype(np.uint8) + mask = xr.where(band_arr == s2_nodata, 1, 0).astype(np.uint8) return self._set_nodata_mask(band_arr, mask) @@ -1500,7 +1491,7 @@ def _open_clouds( return self._open_clouds_gt_4_0(bands, pixel_size, size, **kwargs) def _rasterize( - self, xds: xr.DataArray, geometry: gpd.GeoDataFrame, nodata: np.ndarray + self, xds: xr.DataArray, geometry: gpd.GeoDataFrame, nodata: np.ndarray = None ) -> xr.DataArray: """ Rasterize a vector on a memory dataset diff --git a/eoreader/products/optical/sv1_product.py b/eoreader/products/optical/sv1_product.py index 4940b893..6f878f14 100644 --- a/eoreader/products/optical/sv1_product.py +++ b/eoreader/products/optical/sv1_product.py @@ -31,12 +31,12 @@ import xarray as xr from lxml import etree from rasterio import crs as riocrs -from sertit import path, rasters_rio, vectors +from sertit import path, vectors from sertit.misc import ListEnum from sertit.types import AnyPathType from shapely.geometry import box -from eoreader import DATETIME_FMT, EOREADER_NAME, cache +from eoreader import DATETIME_FMT, EOREADER_NAME, cache, utils from eoreader.bands import ( BLUE, GREEN, @@ -648,11 +648,15 @@ def _get_ortho_path(self, **kwargs) -> AnyPathType: # Reproject and write on disk data dem_path = self._get_dem_path(**kwargs) - with rasterio.open(str(self._get_tile_path(**kwargs))) as src: - out_arr, meta = self._reproject( - src.read(), src.meta, src.rpcs, dem_path, **kwargs - ) - rasters_rio.write(out_arr, meta, ortho_path) + tile = utils.read(self._get_tile_path(**kwargs)) + + # TODO: change this when available in rioxarray + # See https://github.com/corteva/rioxarray/issues/837 + with rasterio.open(self._get_tile_path(**kwargs)) as ds: + rpcs = ds.rpcs + + out = self._reproject(tile, rpcs, dem_path, **kwargs) + utils.write(out, ortho_path) else: ortho_path = self._get_tile_path(**kwargs) diff --git a/eoreader/products/optical/vhr_product.py b/eoreader/products/optical/vhr_product.py index a2391598..3ce32cd9 100644 --- a/eoreader/products/optical/vhr_product.py +++ b/eoreader/products/optical/vhr_product.py @@ -34,7 +34,7 @@ from rasterio.crs import CRS from rasterio.enums import Resampling from rasterio.vrt import WarpedVRT -from sertit import AnyPath, path, rasters, rasters_rio +from sertit import AnyPath, path, rasters from sertit.types import AnyPathStrType, AnyPathType from eoreader import EOREADER_NAME, utils @@ -162,27 +162,37 @@ def _get_ortho_path(self, **kwargs) -> AnyPathType: LOGGER.info( "Manually orthorectified stack not given by the user. " "Reprojecting whole stack, this may take a while. " - "(May be inaccurate on steep terrain, depending on the DEM pixel size)" + "(Might be inaccurate on steep terrain, depending on the DEM pixel size)." ) # Reproject and write on disk data dem_path = self._get_dem_path(**kwargs) - with rasterio.open(str(self._get_tile_path())) as src: - if "rpcs" in kwargs: - rpcs = kwargs.pop("rpcs") - else: - rpcs = src.rpcs - - if not rpcs: - raise InvalidProductError( - "Your projected VHR data doesn't have any RPC. " - "EOReader cannot orthorectify it!" - ) - out_arr, meta = self._reproject( - src.read(), src.meta, rpcs, dem_path, **kwargs + tile = utils.read(self._get_tile_path()) + + if "rpcs" in kwargs: + rpcs = kwargs.pop("rpcs") + else: + # TODO: change this when available in rioxarray + # See https://github.com/corteva/rioxarray/issues/837 + with rasterio.open(self._get_tile_path()) as ds: + rpcs = ds.rpcs + tags = ds.tags() + + if not rpcs: + raise InvalidProductError( + "Your projected VHR data doesn't have any RPC. " + "EOReader cannot orthorectify it!" ) - rasters_rio.write(out_arr, meta, ortho_path, tags=src.tags()) + + out = self._reproject(tile, rpcs, dem_path, **kwargs) + utils.write( + out, + ortho_path, + dtype=np.float32, + nodata=self._raw_nodata, + tags=tags, + ) else: ortho_path = self._get_tile_path() @@ -270,7 +280,7 @@ def _get_dem_path(self, **kwargs) -> str: return dem_path def _reproject( - self, src_arr: np.ndarray, src_meta: dict, rpcs: rpc.RPC, dem_path, **kwargs + self, src_xda: xr.DataArray, rpcs: rpc.RPC, dem_path, **kwargs ) -> (np.ndarray, dict): """ Reproject using RPCs (cannot use another pixel size than src to ensure RPCs are valid) @@ -284,45 +294,75 @@ def _reproject( Returns: (np.ndarray, dict): Reprojected array and its metadata """ - # Set RPC keywords LOGGER.debug(f"Orthorectifying data with {dem_path}") - kwargs = { - "RPC_DEM": dem_path, - "RPC_DEM_MISSING_VALUE": 0, - "OSR_USE_ETMERC": "YES", - "BIGTIFF": "IF_NEEDED", - } - - # Reproject - # WARNING: may not give correct output pixel size - out_arr, dst_transform = warp.reproject( - src_arr, - rpcs=rpcs, - src_crs=self._get_raw_crs(), - src_nodata=self._raw_nodata, + kwargs.update( + { + "RPC_DEM": dem_path, + "RPC_DEM_MISSING_VALUE": 0, + "OSR_USE_ETMERC": "YES", + "BIGTIFF": "IF_NEEDED", + } + ) + + # Reproject with rioxarray + # Seems to handle the resolution well on the contrary to rasterio's reproject... + out_xda = src_xda.rio.reproject( dst_crs=self.crs(), - dst_resolution=self.pixel_size, - dst_nodata=self._raw_nodata, # input data should be in integer - num_threads=utils.get_max_cores(), + resolution=self.pixel_size, resampling=Resampling.bilinear, + nodata=self._raw_nodata, + num_threads=utils.get_max_cores(), + rpcs=rpcs, + dtype=src_xda.dtype, **kwargs, ) - # Get dims - count, height, width = out_arr.shape - - # Update metadata - meta = src_meta.copy() - meta["transform"] = dst_transform - meta["driver"] = "GTiff" - meta["compress"] = "lzw" - meta["nodata"] = self._raw_nodata - meta["crs"] = self.crs() - meta["width"] = width - meta["height"] = height - meta["count"] = count - - return out_arr, meta + out_xda.rename(f"Reprojected stack of {self.name}") + out_xda.attrs["long_name"] = self.get_bands_names() + + # Daskified reproject doesn't seem to work with RPC + # See https://github.com/opendatacube/odc-geo/issues/193 + # from odc.geo import xr # noqa + # out_xda = src_xda.odc.reproject( + # how=self.crs(), + # resolution=self.pixel_size, + # resampling=Resampling.bilinear, + # dst_nodata=self._raw_nodata, + # num_threads=utils.get_max_cores(), + # rpcs=rpcs, + # dtype=src_xda.dtype, + # **kwargs + # ) + + # Legacy with rasterio directly + # WARNING: may not give correct output pixel size + # out_arr, dst_transform = warp.reproject( + # src_arr, + # rpcs=rpcs, + # src_crs=self._get_raw_crs(), + # src_nodata=self._raw_nodata, + # dst_crs=self.crs(), + # dst_resolution=self.pixel_size, + # dst_nodata=self._raw_nodata, # input data should be in integer + # num_threads=utils.get_max_cores(), + # resampling=Resampling.bilinear, + # **kwargs, + # ) + # # Get dims + # count, height, width = out_arr.shape + # + # # Update metadata + # meta = src_meta.copy() + # meta["transform"] = dst_transform + # meta["driver"] = "GTiff" + # meta["compress"] = "lzw" + # meta["nodata"] = self._raw_nodata + # meta["crs"] = self.crs() + # meta["width"] = width + # meta["height"] = height + # meta["count"] = count + + return out_xda def _read_band( self, diff --git a/eoreader/products/product.py b/eoreader/products/product.py index 6d553b66..4e35eea6 100644 --- a/eoreader/products/product.py +++ b/eoreader/products/product.py @@ -2187,3 +2187,16 @@ def _read_archived_vector( file_list=self._get_archived_file_list(archive_path), **kwargs, ) + + def get_bands_names(self) -> list: + """ + Get the name of the bands composing the product, ordered by ID. + For example, for SPOT7: ['RED', 'GREEN', 'BLUE', 'NIR'] + + Returns: + list: Ordered bands names + """ + stack_bands = { + band.id: band.name for band in self.bands.values() if band is not None + } + return [id_name[1] for id_name in sorted(stack_bands.items())]