From 23df712935ace481c4dfc8bec9b55ce691fbab43 Mon Sep 17 00:00:00 2001 From: Romain Hugonnet Date: Wed, 22 May 2024 11:33:18 -0800 Subject: [PATCH] Add Dask-delayed raster `subsample()`, `reproject()` and `interp_points()` (#537) Co-authored-by: Amelie --- dev-environment.yml | 3 + environment.yml | 1 + geoutils/raster/delayed.py | 791 ++++++++++++++++++++ requirements.txt | 1 + tests/{ => test_raster}/test_array.py | 0 tests/test_raster/test_delayed.py | 689 +++++++++++++++++ tests/{ => test_raster}/test_multiraster.py | 0 tests/{ => test_raster}/test_raster.py | 0 tests/{ => test_raster}/test_sampling.py | 0 tests/{ => test_raster}/test_satimg.py | 0 10 files changed, 1485 insertions(+) create mode 100644 geoutils/raster/delayed.py rename tests/{ => test_raster}/test_array.py (100%) create mode 100644 tests/test_raster/test_delayed.py rename tests/{ => test_raster}/test_multiraster.py (100%) rename tests/{ => test_raster}/test_raster.py (100%) rename tests/{ => test_raster}/test_sampling.py (100%) rename tests/{ => test_raster}/test_satimg.py (100%) diff --git a/dev-environment.yml b/dev-environment.yml index c0720f47..04256a02 100644 --- a/dev-environment.yml +++ b/dev-environment.yml @@ -12,6 +12,7 @@ dependencies: - scipy=1.* - tqdm - xarray + - dask - rioxarray=0.* # Development-specific, to mirror manually in setup.cfg [options.extras_require]. @@ -27,6 +28,8 @@ dependencies: - pyyaml - flake8 - pylint + - netcdf4 # To write synthetic data with chunksizes + - dask-memusage # Doc dependencies - sphinx diff --git a/environment.yml b/environment.yml index 5f2c5d50..92daeaba 100644 --- a/environment.yml +++ b/environment.yml @@ -12,4 +12,5 @@ dependencies: - scipy=1.* - tqdm - xarray + - dask - rioxarray=0.* diff --git a/geoutils/raster/delayed.py b/geoutils/raster/delayed.py new file mode 100644 index 00000000..7435a222 --- /dev/null +++ b/geoutils/raster/delayed.py @@ -0,0 +1,791 @@ +""" +Module for dask-delayed functions for out-of-memory raster operations. 
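+Currently covers delayed subsampling, point interpolation and reprojection.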
+""" + +from __future__ import annotations + +import warnings +from typing import Any, Literal, TypeVar + +import dask.array as da +import dask.delayed +import geopandas as gpd +import numpy as np +import pandas as pd +import rasterio as rio +from dask.utils import cached_cumsum +from scipy.interpolate import interpn + +from geoutils._typing import NDArrayBool, NDArrayNum +from geoutils.projtools import _get_bounds_projected, _get_footprint_projected + +# 1/ SUBSAMPLING +# At the date of April 2024: +# Getting an exact subsample size out-of-memory only for valid values is not supported directly by Dask/Xarray + +# It is not trivial because we don't know where valid values will be in advance, and because of ragged output (varying +# output length considerations), which prevents from using high-level functions with good efficiency +# We thus follow https://blog.dask.org/2021/07/02/ragged-output (the dask.array.map_blocks solution has a larger RAM +# usage by having to drop an axis and re-chunk along 1D of the 2D array, so we use the dask.delayed solution instead) + + +def _get_subsample_size_from_user_input( + subsample: int | float, total_nb_valids: int, silence_max_subsample: bool +) -> int: + """Get subsample size based on a user input of either integer size or fraction of the number of valid points.""" + + # If value is between 0 and 1, use a fraction + if (subsample <= 1) & (subsample > 0): + npoints = int(subsample * total_nb_valids) + # Otherwise use the value directly + elif subsample > 1: + # Use the number of valid points if larger than subsample asked by user + npoints = min(int(subsample), total_nb_valids) + if subsample > total_nb_valids: + if not silence_max_subsample: + warnings.warn( + f"Subsample value of {subsample} is larger than the number of valid pixels of {total_nb_valids}," + f"using all valid pixels as a subsample.", + category=UserWarning, + ) + else: + raise ValueError("Subsample must be > 0.") + + return npoints + + +def _get_indices_block_per_subsample( + indices_1d: NDArrayNum, num_chunks: tuple[int, int], nb_valids_per_block: list[int] +) -> list[list[int]]: + """ + Get list of 1D valid subsample indices relative to the block for each block. + + The 1D valid subsample indices correspond to the subsample index to apply for a flattened array of valid values. + Relative to the block means converted so that the block indexes for valid values starts at 0 up to the number of + valid values in that block (while the input indices go from zero to the total number of valid values in the full + array). + + :param indices_1d: Subsample 1D indexes among a total number of valid values. + :param num_chunks: Number of chunks in X and Y. + :param nb_valids_per_block: Number of valid pixels per block. + + :returns: Relative 1D valid subsample index per block. + """ + + # Apply a cumulative sum to get the first 1D total index of each block + valids_cumsum = np.cumsum(nb_valids_per_block) + + # We can write a faster algorithm by sorting + indices_1d = np.sort(indices_1d) + + # TODO: Write nested lists into array format to further save RAM? 
+ # We define a list of indices per block + relative_index_per_block = [[] for _ in range(num_chunks[0] * num_chunks[1])] + k = 0 # K is the block number + for i in indices_1d: + + # Move to the next block K where current 1D subsample index is, if not in this one + while i >= valids_cumsum[k]: + k += 1 + + # Add 1D subsample index relative to first subsample index of this block + first_index_block = valids_cumsum[k - 1] if k >= 1 else 0 # The first 1D valid subsample index of the block + relative_index = i - first_index_block + relative_index_per_block[k].append(relative_index) + + return relative_index_per_block + + +@dask.delayed # type: ignore +def _delayed_nb_valids(arr_chunk: NDArrayNum | NDArrayBool) -> NDArrayNum: + """Count number of valid values per block.""" + if arr_chunk.dtype == "bool": + return np.array([np.count_nonzero(arr_chunk)]).reshape((1, 1)) + return np.array([np.count_nonzero(np.isfinite(arr_chunk))]).reshape((1, 1)) + + +@dask.delayed # type: ignore +def _delayed_subsample_block( + arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum +) -> NDArrayNum | NDArrayBool: + """Subsample the valid values at the corresponding 1D valid indices per block.""" + + if arr_chunk.dtype == "bool": + return arr_chunk[arr_chunk][subsample_indices] + return arr_chunk[np.isfinite(arr_chunk)][subsample_indices] + + +@dask.delayed # type: ignore +def _delayed_subsample_indices_block( + arr_chunk: NDArrayNum | NDArrayBool, subsample_indices: NDArrayNum, block_id: dict[str, Any] +) -> NDArrayNum: + """Return 2D indices from the subsampled 1D valid indices per block.""" + + if arr_chunk.dtype == "bool": + ix, iy = np.unravel_index(np.argwhere(arr_chunk.flatten())[subsample_indices], shape=arr_chunk.shape) + else: + # Unravel indices of valid data to the shape of the block + ix, iy = np.unravel_index( + np.argwhere(np.isfinite(arr_chunk.flatten()))[subsample_indices], shape=arr_chunk.shape + ) + + # Convert to full-array indexes by adding the row and column starting indexes for this block + ix += block_id["xstart"] + iy += block_id["ystart"] + + return np.hstack((ix, iy)) + + +def delayed_subsample( + darr: da.Array, + subsample: int | float = 1, + return_indices: bool = False, + random_state: int | np.random.Generator | None = None, + silence_max_subsample: bool = False, +) -> NDArrayNum | tuple[NDArrayNum, NDArrayNum]: + """ + Subsample a raster at valid values on out-of-memory chunks. + + Optionally, this function can return the 2D indices of the subsample of valid values instead. + + The random subsample is distributed evenly across valid values, no matter which chunk they belong to. + First, the number of valid values in each chunk are computed out-of-memory. Then, a subsample is defined among + the total number of valid values, which are then indexed sequentially along the chunk valid values out-of-memory. + + A random state will give a fixed subsample for a delayed array with a fixed chunksize. However, the subsample + will vary with changing chunksize because the 1D delayed indexing depends on it (indexing per valid value per + flattened chunk). For this reason, a loaded array will also have a different subsample due to its direct 1D + indexing (per valid value for the entire flattened array). + + To ensure you re-use a similar subsample of valid values for several arrays, call this function with + return_indices=True, then sample your arrays out-of-memory with .vindex[indices[0], indices[1]] + (this assumes that these arrays have valid values at the same locations). 
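+
+    For example (a minimal sketch; darr1/darr2 are placeholder arrays assumed to share the same valid-value mask):
+
+        indices = delayed_subsample(darr1, subsample=100, return_indices=True)
+        sub1 = np.array(darr1.vindex[indices[0], indices[1]])
+        sub2 = np.array(darr2.vindex[indices[0], indices[1]])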
+ + Only valid values are sampled. If passing a numerical array, then only finite values are considered valid values. + If passing a boolean array, then only True values are considered valid values. + + :param darr: Input dask array. This can be a boolean or a numerical array. + :param subsample: Subsample size. If <= 1, will be considered a fraction of valid pixels to extract. + If > 1 will be considered the number of valid pixels to extract. + :param return_indices: If set to True, will return the extracted indices only. + :param random_state: Random state, or seed number to use for random calculations. + :param silence_max_subsample: Whether to silence the warning for the subsample size being larger than the total + number of valid points (warns by default). + + :return: Subsample of values from the array (optionally, their indexes). + """ + + # Get random state + rng = np.random.default_rng(random_state) + + # Compute number of valid points for each block out-of-memory + blocks = darr.to_delayed().ravel() + list_delayed_valids = [ + da.from_delayed(_delayed_nb_valids(b), shape=(1, 1), dtype=np.dtype("int32")) for b in blocks + ] + nb_valids_per_block = np.concatenate([dask.compute(*list_delayed_valids)]) + + # Sum to get total number of valid points + total_nb_valids = np.sum(nb_valids_per_block) + + # Get subsample size (depending on user input) + subsample_size = _get_subsample_size_from_user_input( + subsample=subsample, total_nb_valids=total_nb_valids, silence_max_subsample=silence_max_subsample + ) + + # Get random 1D indexes for the subsample size + indices_1d = rng.choice(total_nb_valids, subsample_size, replace=False) + + # Sort which indexes belong to which chunk + ind_per_block = _get_indices_block_per_subsample( + indices_1d, num_chunks=darr.numblocks, nb_valids_per_block=nb_valids_per_block + ) + + # To just get the subsample without indices + if not return_indices: + # Task a delayed subsample to be computed for each block, skipping blocks with no values to sample + list_subsamples = [ + _delayed_subsample_block(b, ind) + for i, (b, ind) in enumerate(zip(blocks, ind_per_block)) + if len(ind_per_block[i]) > 0 + ] + # Cast output to the right expected dtype and length, then compute and concatenate + list_subsamples_delayed = [ + da.from_delayed(s, shape=(nb_valids_per_block[i]), dtype=darr.dtype) for i, s in enumerate(list_subsamples) + ] + subsamples = np.concatenate(dask.compute(*list_subsamples_delayed), axis=0) + + return subsamples + + # To return indices + else: + # Get starting 2D index for each chunk of the full array + # (mirroring what is done in block_id of dask.array.map_blocks) + # https://github.com/dask/dask/blob/24493f58660cb933855ba7629848881a6e2458c1/dask/array/core.py#L908 + # This list also includes the last index as well (not used here) + starts = [cached_cumsum(c, initial_zero=True) for c in darr.chunks] + num_chunks = darr.numblocks + # Get the starts per 1D block ID by unravelling starting indexes for each block + indexes_xi, indexes_yi = np.unravel_index(np.arange(len(blocks)), shape=(num_chunks[0], num_chunks[1])) + block_ids = [ + {"xstart": starts[0][indexes_xi[i]], "ystart": starts[1][indexes_yi[i]]} for i in range(len(blocks)) + ] + + # Task delayed subsample indices to be computed for each block, skipping blocks with no values to sample + list_subsample_indices = [ + _delayed_subsample_indices_block(b, ind, block_id=block_ids[i]) + for i, (b, ind) in enumerate(zip(blocks, ind_per_block)) + if len(ind_per_block[i]) > 0 + ] + # Cast output to the 
right expected dtype and length, then compute and concatenate + list_subsamples_indices_delayed = [ + da.from_delayed(s, shape=(2, len(ind_per_block[i])), dtype=np.dtype("int32")) + for i, s in enumerate(list_subsample_indices) + ] + indices = np.concatenate(dask.compute(*list_subsamples_indices_delayed), axis=0) + + return indices[:, 0], indices[:, 1] + + +# 2/ POINT INTERPOLATION ON REGULAR OR EQUAL GRID +# At the date of April 2024: +# This functionality is not covered efficiently by Dask/Xarray, because they need to support rectilinear grids, which +# is difficult when interpolating in the chunked dimensions, and loads nearly all array memory when using .interp(). + +# Here we harness the fact that rasters are always on regular (or sometimes equal) grids to efficiently map +# the location of the blocks required for interpolation, which requires little memory usage. + +# Code structure inspired by https://blog.dask.org/2021/07/02/ragged-output and the "block_id" in map_blocks + + +def _get_interp_indices_per_block( + interp_x: NDArrayNum, + interp_y: NDArrayNum, + starts: list[tuple[int, ...]], + num_chunks: tuple[int, int], + chunksize: tuple[int, int], + xres: float, + yres: float, +) -> list[list[int]]: + """Map blocks where each pair of interpolation coordinates will have to be computed.""" + + # TODO 1: Check the robustness for chunksize different and X and Y + + # TODO 2: Check if computing block_i_id matricially + using an == comparison (possibly delayed) to get index + # per block is not more computationally efficient? + # (as it uses array instead of nested lists, and nested lists grow in RAM very fast) + + # The argument "starts" contains the list of chunk first X/Y index for the full array, plus the last index + + # We use one bucket per block, assuming a flattened blocks shape + ind_per_block = [[] for _ in range(num_chunks[0] * num_chunks[1])] + for i, (x, y) in enumerate(zip(interp_x, interp_y)): + # Because it is a regular grid, we know exactly in which block ID the coordinate will fall + block_i_1d = int((x - starts[0][0]) / (xres * chunksize[0])) * num_chunks[1] + int( + (y - starts[1][0]) / (yres * chunksize[1]) + ) + ind_per_block[block_i_1d].append(i) + + return ind_per_block + + +@dask.delayed # type: ignore +def _delayed_interp_points_block( + arr_chunk: NDArrayNum, block_id: dict[str, Any], interp_coords: NDArrayNum +) -> NDArrayNum: + """ + Interpolate block in 2D out-of-memory for a regular or equal grid. + """ + + # Extract information out of block_id dictionary + xs, ys, xres, yres = (block_id["xstart"], block_id["ystart"], block_id["xres"], block_id["yres"]) + + # Reconstruct the coordinates from xi/yi/xres/yres (as it has to be a regular grid) + x_coords = np.arange(xs, xs + xres * arr_chunk.shape[0], xres) + y_coords = np.arange(ys, ys + yres * arr_chunk.shape[1], yres) + + # TODO: Use scipy.map_coordinates for an equal grid as in Raster.interp_points? + + # Interpolate to points + interp_chunk = interpn(points=(x_coords, y_coords), values=arr_chunk, xi=(interp_coords[0, :], interp_coords[1, :])) + + # And return the interpolated array + return interp_chunk + + +def delayed_interp_points( + darr: da.Array, + points: tuple[list[float], list[float]], + resolution: tuple[float, float], + method: Literal["nearest", "linear", "cubic", "quintic"] = "linear", +) -> NDArrayNum: + """ + Interpolate raster at point coordinates on out-of-memory chunks. 
+ + This function harnesses the fact that a raster is defined on a regular (or equal) grid, and it is therefore + faster than Xarray.interpn (especially for small sample sizes) and uses only a fraction of the memory usage. + + :param darr: Input dask array. + :param points: Point(s) at which to interpolate raster value. If points fall outside of image, value + returned is nan. Shape should be (N,2). + :param resolution: Resolution of the raster (xres, yres). + :param method: Interpolation method, one of 'nearest', 'linear', 'cubic', or 'quintic'. For more information, + see scipy.ndimage.map_coordinates and scipy.interpolate.interpn. Default is linear. + + :return: Array of raster value(s) for the given points. + """ + + # TODO: Replace by a generic 2D point casting function accepting multiple inputs (living outside this function) + # Convert input to 2D array + points_arr = np.vstack((points[0], points[1])) + + # Map depth of overlap required for each interpolation method + # TODO: Double-check this window somewhere in SciPy's documentation + map_depth = {"nearest": 1, "linear": 2, "cubic": 3, "quintic": 5} + + # Expand dask array for overlapping computations + chunksize = darr.chunksize + expanded = da.overlap.overlap(darr, depth=map_depth[method], boundary=np.nan) + + # Get starting 2D index for each chunk of the full array + # (mirroring what is done in block_id of dask.array.map_blocks) + starts = [cached_cumsum(c, initial_zero=True) for c in darr.chunks] + num_chunks = expanded.numblocks + + # Get samples indices per blocks + ind_per_block = _get_interp_indices_per_block( + points_arr[0, :], points_arr[1, :], starts, num_chunks, chunksize, resolution[0], resolution[1] + ) + + # Create a delayed object for each block, and flatten the blocks into a 1d shape + blocks = expanded.to_delayed().ravel() + + # Build the block IDs by unravelling starting indexes for each block + indexes_xi, indexes_yi = np.unravel_index(np.arange(len(blocks)), shape=(num_chunks[0], num_chunks[1])) + block_ids = [ + { + "xstart": (starts[0][indexes_xi[i]] - map_depth[method]) * resolution[0], + "ystart": (starts[1][indexes_yi[i]] - map_depth[method]) * resolution[1], + "xres": resolution[0], + "yres": resolution[1], + } + for i in range(len(blocks)) + ] + + # Compute values delayed + list_interp = [ + _delayed_interp_points_block(data_chunk, block_ids[i], points_arr[:, ind_per_block[i]]) + for i, data_chunk in enumerate(blocks) + if len(ind_per_block[i]) > 0 + ] + + # We define the expected output shape and dtype to simplify things for Dask + list_interp_delayed = [ + da.from_delayed(p, shape=(1, len(ind_per_block[i])), dtype=darr.dtype) for i, p in enumerate(list_interp) + ] + interp_points = np.concatenate(dask.compute(*list_interp_delayed), axis=0) + + # Re-order per-block output points to match their original indices + indices = np.concatenate(ind_per_block).astype(int) + argsort = np.argsort(indices) + interp_points = np.array(interp_points)[argsort] + + return interp_points + + +# 3/ REPROJECT +# At the date of April 2024: not supported by Rioxarray +# Part of the code (defining a GeoGrid and GeoTiling classes) is inspired by +# https://github.com/opendatacube/odc-geo/pull/88, modified to be concise, stand-alone and rely only on +# Rasterio/GeoPandas +# Could be submitted as a PR to Rioxarray? 
(but not sure the dependency to GeoPandas would work there) + +# We define a GeoGrid and GeoTiling class (which composes GeoGrid) to consistently deal with georeferenced footprints +# of chunked grids +GeoGridType = TypeVar("GeoGridType", bound="GeoGrid") + + +class GeoGrid: + """ + Georeferenced grid class. + + Describes a georeferenced grid through a geotransform (one-sided bounds and resolution), shape and CRS. + """ + + def __init__(self, transform: rio.transform.Affine, shape: tuple[int, int], crs: rio.crs.CRS | None): + + self._transform = transform + self._shape = shape + self._crs = crs + + @property + def transform(self) -> rio.transform.Affine: + return self._transform + + @property + def crs(self) -> rio.crs.CRS: + return self._crs + + @property + def shape(self) -> tuple[int, int]: + return self._shape + + @property + def height(self) -> int: + return self.shape[0] + + @property + def width(self) -> int: + return self.shape[1] + + @property + def res(self) -> tuple[int, int]: + return self.transform[0], abs(self.transform[4]) + + def bounds_projected(self, crs: rio.crs.CRS = None) -> rio.coords.BoundingBox: + if crs is None: + crs = self.crs + bounds = rio.coords.BoundingBox(*rio.transform.array_bounds(self.height, self.width, self.transform)) + return _get_bounds_projected(bounds=bounds, in_crs=self.crs, out_crs=crs) + + @property + def bounds(self) -> rio.coords.BoundingBox: + return self.bounds_projected() + + def footprint_projected(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame: + if crs is None: + crs = self.crs + return _get_footprint_projected(self.bounds, in_crs=self.crs, out_crs=crs, densify_points=100) + + @property + def footprint(self) -> gpd.GeoDataFrame: + return self.footprint_projected() + + @classmethod + def from_dict(cls: type[GeoGridType], dict_meta: dict[str, Any]) -> GeoGridType: + """Create a GeoGrid from a dictionary containing transform, shape and CRS.""" + return cls(**dict_meta) + + def shift( + self: GeoGridType, + xoff: float, + yoff: float, + distance_unit: Literal["georeferenced"] | Literal["pixel"] = "pixel", + ) -> GeoGridType: + """Shift into a new geogrid (not inplace).""" + + if distance_unit not in ["georeferenced", "pixel"]: + raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.") + + # Get transform + dx, b, xmin, d, dy, ymax = list(self.transform)[:6] + + # Convert pixel offsets to georeferenced units + if distance_unit == "pixel": + # Can either multiply the offset by the resolution + # xoff *= self.res[0] + # yoff *= self.res[1] + + # Or use the boundaries instead! (maybe less floating point issues? 
doesn't seem to matter in tests) + xoff = xoff / self.shape[1] * (self.bounds.right - self.bounds.left) + yoff = yoff / self.shape[0] * (self.bounds.top - self.bounds.bottom) + + shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff) + + return self.from_dict({"transform": shifted_transform, "crs": self.crs, "shape": self.shape}) + + +def _get_block_ids_per_chunk(chunks: tuple[tuple[int, ...], tuple[int, ...]]) -> list[dict[str, int]]: + """Get location of chunks based on array shape and list of chunk sizes.""" + + # Get number of chunks + num_chunks = (len(chunks[0]), len(chunks[1])) + + # Get robust list of chunk locations (using what is done in block_id of dask.array.map_blocks) + # https://github.com/dask/dask/blob/24493f58660cb933855ba7629848881a6e2458c1/dask/array/core.py#L908 + from dask.utils import cached_cumsum + + starts = [cached_cumsum(c, initial_zero=True) for c in chunks] + nb_blocks = num_chunks[0] * num_chunks[1] + ixi, iyi = np.unravel_index(np.arange(nb_blocks), shape=(num_chunks[0], num_chunks[1])) + # Starting and ending indexes "s" and "e" for both X/Y, to place the chunk in the full array + block_ids = [ + { + "num_block": i, + "ys": starts[0][ixi[i]], + "xs": starts[1][iyi[i]], + "ye": starts[0][ixi[i] + 1], + "xe": starts[1][iyi[i] + 1], + } + for i in range(nb_blocks) + ] + + return block_ids + + +class ChunkedGeoGrid: + """ + Chunked georeferenced grid class. + + Associates a georeferenced grid to chunks (possibly of varying sizes). + """ + + def __init__(self, grid: GeoGrid, chunks: tuple[tuple[int, ...], tuple[int, ...]]): + + self._grid = grid + self._chunks = chunks + + @property + def grid(self) -> GeoGrid: + return self._grid + + @property + def chunks(self) -> tuple[tuple[int, ...], tuple[int, ...]]: + return self._chunks + + def get_block_locations(self) -> list[dict[str, int]]: + """Get block locations in 2D: xstart, xend, ystart, yend.""" + return _get_block_ids_per_chunk(self._chunks) + + def get_blocks_as_geogrids(self) -> list[GeoGrid]: + """Get blocks as geogrids with updated transform/shape.""" + + block_ids = self.get_block_locations() + + list_geogrids = [] + for bid in block_ids: + # We get the block size + block_shape = (bid["ye"] - bid["ys"], bid["xe"] - bid["xs"]) + # Build a temporary geogrid with the same transform as the full grid, but with the chunk shape + geogrid_tmp = GeoGrid(transform=self.grid.transform, crs=self.grid.crs, shape=block_shape) + # And shift it to the right location (X is positive in index direction, Y is negative) + geogrid_block = geogrid_tmp.shift(xoff=bid["xs"], yoff=-bid["ys"]) + list_geogrids.append(geogrid_block) + + return list_geogrids + + def get_block_footprints(self, crs: rio.crs.CRS = None) -> gpd.GeoDataFrame: + """Get block projected footprints as a single geodataframe.""" + + geogrids = self.get_blocks_as_geogrids() + footprints = [gg.footprint_projected(crs=crs) if crs is not None else gg.footprint for gg in geogrids] + + return pd.concat(footprints) + + +def _chunks2d_from_chunksizes_shape( + chunksizes: tuple[int, int], shape: tuple[int, int] +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """Get tuples of chunk sizes for X/Y dimensions based on chunksizes and array shape.""" + + # Chunksize is fixed, except for the last chunk depending on the shape + chunks_y = tuple( + min( + chunksizes[0], + shape[0] - i * chunksizes[0], + ) + for i in range(int(np.ceil(shape[0] / chunksizes[0]))) + ) + chunks_x = tuple( + min( + chunksizes[1], + shape[1] - i * chunksizes[1], + ) + for i 
in range(int(np.ceil(shape[1] / chunksizes[1]))) + ) + + return chunks_y, chunks_x + + +def _combined_blocks_shape_transform( + sub_block_ids: list[dict[str, int]], src_geogrid: GeoGrid +) -> tuple[dict[str, Any], list[dict[str, int]]]: + """Derive combined shape and transform from a subset of several blocks (for source input during reprojection).""" + + # Get combined shape by taking min of X/Y starting indices, max of X/Y ending indices + all_xs, all_ys, all_xe, all_ye = ([b[s] for b in sub_block_ids] for s in ["xs", "ys", "xe", "ye"]) + minmaxs = {"min_xs": np.min(all_xs), "max_xe": np.max(all_xe), "min_ys": np.min(all_ys), "max_ye": np.max(all_ye)} + combined_shape = (minmaxs["max_ye"] - minmaxs["min_ys"], minmaxs["max_xe"] - minmaxs["min_xs"]) + + # Shift source transform with start indexes to get the one for combined block location + combined_transform = src_geogrid.shift(xoff=minmaxs["min_xs"], yoff=-minmaxs["min_ys"]).transform + + # Compute relative block indexes that will be needed to reconstruct a square array in the delayed function, + # by subtracting the minimum starting indices in X/Y + relative_block_indexes = [ + {"r" + s1 + s2: b[s1 + s2] - minmaxs["min_" + s1 + "s"] for s1 in ["x", "y"] for s2 in ["s", "e"]} + for b in sub_block_ids + ] + + combined_meta = {"src_shape": combined_shape, "src_transform": tuple(combined_transform)} + + return combined_meta, relative_block_indexes + + +@dask.delayed # type: ignore +def _delayed_reproject_per_block( + *src_arrs: tuple[NDArrayNum], block_ids: list[dict[str, int]], combined_meta: dict[str, Any], **kwargs: Any +) -> NDArrayNum: + """ + Delayed reprojection per destination block (also rebuilds a square array combined from intersecting source blocks). + """ + + # If no source chunk intersects, we return a chunk of destination nodata values + if len(src_arrs) == 0: + # We can use float32 to return NaN, will be cast to other floating type later if that's not source array dtype + dst_arr = np.zeros(combined_meta["dst_shape"], dtype=np.dtype("float32")) + dst_arr[:] = kwargs["dst_nodata"] + return dst_arr + + # First, we build an empty array with the combined shape, only with nodata values + comb_src_arr = np.ones((combined_meta["src_shape"]), dtype=src_arrs[0].dtype) + comb_src_arr[:] = kwargs["src_nodata"] + + # Then fill it with the source chunks values + for i, arr in enumerate(src_arrs): + bid = block_ids[i] + comb_src_arr[bid["rys"] : bid["rye"], bid["rxs"] : bid["rxe"]] = arr + + # Now, we can simply call Rasterio! 
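+    # (one reproject call per destination block is enough: the combined source array is now a single in-memory
+    # array with its own transform, so Rasterio can treat it like any fully-loaded raster)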
+ + # We build the combined transform from tuple + src_transform = rio.transform.Affine(*combined_meta["src_transform"]) + dst_transform = rio.transform.Affine(*combined_meta["dst_transform"]) + + # Reproject + dst_arr = np.zeros(combined_meta["dst_shape"], dtype=comb_src_arr.dtype) + + _ = rio.warp.reproject( + comb_src_arr, + dst_arr, + src_transform=src_transform, + src_crs=kwargs["src_crs"], + dst_transform=dst_transform, + dst_crs=kwargs["dst_crs"], + resampling=kwargs["resampling"], + src_nodata=kwargs["src_nodata"], + dst_nodata=kwargs["dst_nodata"], + num_threads=1, # Force the number of threads to 1 to avoid Dask/Rasterio conflicting on multi-threading + ) + + return dst_arr + + +def delayed_reproject( + darr: da.Array, + src_transform: rio.transform.Affine, + src_crs: rio.crs.CRS, + dst_transform: rio.transform.Affine, + dst_shape: tuple[int, int], + dst_crs: rio.crs.CRS, + resampling: rio.enums.Resampling, + src_nodata: int | float | None = None, + dst_nodata: int | float | None = None, + dst_chunksizes: tuple[int, int] | None = None, + **kwargs: Any, +) -> da.Array: + """ + Reproject georeferenced raster on out-of-memory chunks. + + Each chunk of the destination array is mapped to one or several intersecting chunks of the source array, and + reprojection is performed using rio.warp.reproject for each mapping. + + Part of the code is inspired by https://github.com/opendatacube/odc-geo/pull/88. + + :param darr: Input dask array for source raster. + :param src_transform: Geotransform of source raster. + :param src_crs: Coordinate reference system of source raster. + :param dst_transform: Geotransform of destination raster. + :param dst_shape: Shape of destination raster. + :param dst_crs: Coordinate reference system of destination raster. + :param resampling: Resampling method. + :param src_nodata: Nodata value of source raster. + :param dst_nodata: Nodata value of destination raster. + :param dst_chunksizes: Chunksizes for destination raster. + :param kwargs: Other arguments to pass to rio.warp.reproject(). + + :return: Dask array of reprojected raster. 
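+
+    Minimal usage sketch (all names below are placeholders, not part of this module):
+
+        reproj = delayed_reproject(darr, src_transform=src_transform, src_crs=src_crs,
+                                   dst_transform=dst_transform, dst_shape=(500, 500), dst_crs=dst_crs,
+                                   resampling=rio.enums.Resampling.bilinear, dst_chunksizes=(100, 100))
+        arr = reproj.compute()  # Or write out-of-memory, e.g. via Xarray's to_netcdf(compute=False)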
+ """ + + # 1/ Define source and destination chunked georeferenced grid through simple classes storing CRS/transform/shape, + # which allow to consistently derive shape/transform for each block and their CRS-projected footprints + + # Define georeferenced grids for source/destination array + src_geogrid = GeoGrid(transform=src_transform, shape=darr.shape, crs=src_crs) + dst_geogrid = GeoGrid(transform=dst_transform, shape=dst_shape, crs=dst_crs) + + # Add the chunking + # For source, we can use the .chunks attribute + src_chunks = darr.chunks + src_geotiling = ChunkedGeoGrid(grid=src_geogrid, chunks=src_chunks) + + # For destination, we need to create the chunks based on destination chunksizes + if dst_chunksizes is None: + dst_chunksizes = darr.chunksize + dst_chunks = _chunks2d_from_chunksizes_shape(chunksizes=dst_chunksizes, shape=dst_shape) + dst_geotiling = ChunkedGeoGrid(grid=dst_geogrid, chunks=dst_chunks) + + # 2/ Get footprints of tiles in CRS of destination array, with a buffer of 2 pixels for destination ones to ensure + # overlap, then map indexes of source blocks that intersect a given destination block + src_footprints = src_geotiling.get_block_footprints(crs=dst_crs) + dst_footprints = dst_geotiling.get_block_footprints().buffer(2 * max(dst_geogrid.res)) + dest2source = [np.where(dst.intersects(src_footprints).values)[0] for dst in dst_footprints] + + # 3/ To reconstruct a square source array during chunked reprojection, we need to derive the combined shape and + # transform of each tuples of source blocks + src_block_ids = np.array(src_geotiling.get_block_locations()) + meta_params = [ + ( + _combined_blocks_shape_transform(sub_block_ids=src_block_ids[sbid], src_geogrid=src_geogrid) + if len(sbid) > 0 + else ({}, []) + ) + for sbid in dest2source + ] + # We also add the output transform/shape for this destination chunk in the combined meta + # (those are the only two that are chunk-specific) + dst_block_geogrids = dst_geotiling.get_blocks_as_geogrids() + for i, (c, _) in enumerate(meta_params): + c.update({"dst_shape": dst_block_geogrids[i].shape, "dst_transform": tuple(dst_block_geogrids[i].transform)}) + + # 4/ Call a delayed function that uses rio.warp to reproject the combined source block(s) to each destination block + + # Add fixed arguments to keywords + kwargs.update( + { + "src_nodata": src_nodata, + "dst_nodata": dst_nodata, + "resampling": resampling, + "src_crs": src_crs, + "dst_crs": dst_crs, + } + ) + + # Create a delayed object for each block, and flatten the blocks into a 1d shape + blocks = darr.to_delayed().ravel() + # Run the delayed reprojection, looping for each destination block + list_reproj = [ + _delayed_reproject_per_block( + *blocks[dest2source[i]], block_ids=meta_params[i][1], combined_meta=meta_params[i][0], **kwargs + ) + for i in range(len(dest2source)) + ] + + # We define the expected output shape and dtype to simplify things for Dask + list_reproj_delayed = [ + da.from_delayed(r, shape=dst_block_geogrids[i].shape, dtype=darr.dtype) for i, r in enumerate(list_reproj) + ] + + # Array comes out as flat blocks x chunksize0 (varying) x chunksize1 (varying), so we can't reshape directly + # We need to unravel the flattened blocks indices to align X/Y, then concatenate all columns, then rows + indexes_xi, indexes_yi = np.unravel_index( + np.arange(len(dest2source)), shape=(len(dst_chunks[0]), len(dst_chunks[1])) + ) + + lists_columns = [ + [l for i, l in enumerate(list_reproj_delayed) if j == indexes_xi[i]] for j in range(len(dst_chunks[0])) + ] + 
concat_columns = [da.concatenate(c, axis=1) for c in lists_columns]
+    concat_all = da.concatenate(concat_columns, axis=0)
+
+    return concat_all
diff --git a/requirements.txt b/requirements.txt
index 68bac6e0..149d6fe3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,5 @@ numpy==1.*
 scipy==1.*
 tqdm
 xarray
+dask
 rioxarray==0.*
diff --git a/tests/test_array.py b/tests/test_raster/test_array.py
similarity index 100%
rename from tests/test_array.py
rename to tests/test_raster/test_array.py
diff --git a/tests/test_raster/test_delayed.py b/tests/test_raster/test_delayed.py
new file mode 100644
index 00000000..e87ec91a
--- /dev/null
+++ b/tests/test_raster/test_delayed.py
@@ -0,0 +1,689 @@
+"""Tests for dask-delayed functions."""
+
+from __future__ import annotations
+
+import os
+import sys
+from tempfile import NamedTemporaryFile
+from typing import Any, Callable
+
+import dask.array as da
+import numpy as np
+import pandas as pd
+import pytest
+import rasterio as rio
+import xarray as xr
+from dask.distributed import Client, LocalCluster
+from dask_memusage import install
+from pluggy import PluggyTeardownRaisedWarning
+from pyproj import CRS
+
+from geoutils.examples import _EXAMPLES_DIRECTORY
+from geoutils.raster.delayed import (
+    delayed_interp_points,
+    delayed_reproject,
+    delayed_subsample,
+)
+
+# Ignore the teardown warning given by Dask when closing the local cluster (due to the dask-memusage plugin)
+pytestmark = pytest.mark.filterwarnings("ignore", category=PluggyTeardownRaisedWarning)
+
+
+@pytest.fixture(scope="module")  # type: ignore
+def cluster():
+    """Fixture to use a single cluster for the entire module (otherwise raises runtime errors)."""
+    # The cluster needs to be single-threaded to use dask-memusage confidently
+    dask_cluster = LocalCluster(n_workers=1, threads_per_worker=1, dashboard_address=None)
+    yield dask_cluster
+    dask_cluster.close()
+
+
+def _run_dask_measuring_memusage(
+    cluster: Any, dask_func: Callable[..., Any], *args_dask_func: Any, **kwargs_dask_func: Any
+) -> tuple[Any, float]:
+    """Run a dask function monitoring its memory usage."""
+
+    # Create a named temporary file that won't be deleted immediately
+    fn_tmp_csv = NamedTemporaryFile(suffix=".csv", delete=False).name
+
+    # Install the memory monitor on the scheduler, then run the client within a context manager for a clean shutdown
+    install(cluster.scheduler, fn_tmp_csv)
+    with Client(cluster) as _:
+        outputs = dask_func(*args_dask_func, **kwargs_dask_func)
+
+    # Read memusage file and clean it up
+    df = pd.read_csv(fn_tmp_csv)
+    os.remove(fn_tmp_csv)
+
+    # Keep only non-zero memory usage
+    ind_nonzero = df.max_memory_mb != 0
+
+    # Compute peak additional memory usage from the minimum baseline
+    memusage_mb = np.max(df.max_memory_mb[ind_nonzero]) - np.min(df.max_memory_mb[ind_nonzero])
+
+    return outputs, memusage_mb
+
+
+def _estimate_subsample_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, int], subsample_size: int) -> float:
+    """
+    Estimate the theoretical memory usage of the delayed subsampling method
+    (we don't need to be super precise, just within a factor of ~2 to check memory usage performs as expected).
+    """
+
+    # TOTAL SIZE = Single chunk operations + Subsample indexes + Metadata passed to dask + Outputs
+
+    # On top of the rest comes the Dask graph, so we multiply by a factor of 2.5 to get a good safety margin
+    fac_dask_margin = 2.5
+    num_chunks = np.prod(darr.numblocks)
+
+    # Single chunk operation = (data type bytes + boolean from np.isfinite) * chunksize**2
+    chunk_memusage = (darr.dtype.itemsize + np.dtype("bool").itemsize) * np.prod(chunksizes_in_mem)
+
+    # 1D index subsample size: integer type * subsample_size
+    sample_memusage = np.dtype("int32").itemsize * subsample_size
+
+    # Outputs: number of valid pixels + subsample
+    valids_memusage = np.dtype("int32").itemsize * num_chunks
+    subout_memusage = np.dtype(darr.dtype).itemsize * subsample_size
+    out_memusage = valids_memusage + subout_memusage
+
+    # Size of metadata passed to dask: number of blocks times its content
+    # Content of a metadata block = list (block size) of lists (subsample size) of integer indexes
+    size_index_int = 28  # Python size for int
+    list_all_blocks = 64 + 8 * num_chunks  # A list takes 64 bytes + 8 bytes per element, excluding contained elements
+    list_per_block = 64 * num_chunks + 8 * subsample_size + size_index_int * subsample_size  # Rough max estimate
+    meta_memusage = list_per_block + list_all_blocks
+
+    # Final estimate of memory usage of operation in MB
+    max_op_memusage = fac_dask_margin * (chunk_memusage + sample_memusage + out_memusage + meta_memusage) / (2**20)
+    # We add a base memory usage of ~80 MB + 10 MB per 1000 chunks (loaded in background by Dask even on tiny data)
+    max_op_memusage += 80 + 10 * (num_chunks / 1000)
+
+    return max_op_memusage
+
+
+def _estimate_interp_points_memusage(darr: da.Array, chunksizes_in_mem: tuple[int, int], ninterp: int) -> float:
+    """
+    Estimate the theoretical memory usage of the delayed interpolation method
+    (we don't need to be super precise, just within a factor of ~2 to check memory usage performs as expected).
+    """
+
+    # TOTAL SIZE = Single chunk operations + Chunk overlap + Metadata passed to dask + Outputs
+
+    # On top of the rest comes the Dask graph, so we multiply by a factor of 2.5 to get a good safety margin
+    fac_dask_margin = 2.5
+    num_chunks = np.prod(darr.numblocks)
+
+    # Single chunk operation = (data type bytes + boolean from np.isfinite + its subset) * overlapping chunksize**2
+    chunk_memusage = (darr.dtype.itemsize + 2 * np.dtype("bool").itemsize) * np.prod(chunksizes_in_mem)
+    # For interpolation, chunks have to overlap and temporarily load each neighbouring chunk:
+    # we add the 8 neighbouring chunks, and double the size due to the memory used during interpolation
+    chunk_memusage *= 9
+
+    # Outputs: pair of interpolated coordinates
+    out_memusage = np.dtype(darr.dtype).itemsize * ninterp * 2
+
+    # Size of metadata passed to dask: number of blocks times its content
+    # Content of a metadata block = list (block size) of lists (subsample size) of integers
+    size_index_int = 28  # Python size for int
+    size_index_float = 24  # Python size for float
+    list_all_blocks = 64 + 8 * num_chunks  # A list takes 64 bytes + 8 bytes per element, excluding contained elements
+    list_per_block = 64 * num_chunks + 8 * ninterp + size_index_int * ninterp  # Rough max estimate
+    # And a list for each block of dicts with 4 floats (xres, yres, xstart, ystart)
+    dict_all_blocks = 64 * num_chunks + 4 * size_index_float * 64 * num_chunks
+    meta_memusage = list_per_block + list_all_blocks + dict_all_blocks
+
+    # Final estimate of memory usage of operation in MB
+    max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2**20)
+    # We add a base memory usage of ~80 MB + 10 MB per 1000 chunks (loaded in background by Dask even on tiny data)
+    max_op_memusage += 80 + 10 * (num_chunks / 1000)
+
+    return max_op_memusage
+
+
+def _estimate_reproject_memusage(
+    darr: da.Array,
+    chunksizes_in_mem: tuple[int, int],
+    dst_chunksizes: tuple[int, int],
+    rel_res_fac: tuple[float, float],
+) -> float:
+    """
+    Estimate the theoretical memory usage of the delayed reprojection method
+    (we don't need to be super precise, just within a factor of ~2 to check memory usage performs as expected).
+    """
+
+    # TOTAL SIZE = Combined source chunk operations + Building geopandas mapping + Metadata passed to dask + Outputs
+
+    # On top of the rest comes the Dask graph, so we multiply by a factor of 2.5 to get a good safety margin
+    fac_dask_margin = 2.5
+    num_chunks = np.prod(darr.numblocks)
+
+    # THE BIG QUESTION: how many source chunks, at most, might be loaded for a single destination chunk?
+    # It depends on the relative ratio of input chunksizes to destination chunksizes, accounting for resolution change
+    x_rel_source_chunks = dst_chunksizes[0] / chunksizes_in_mem[0] * rel_res_fac[0]
+    y_rel_source_chunks = dst_chunksizes[1] / chunksizes_in_mem[1] * rel_res_fac[1]
+    # There is also some overlap needed for resampling and due to warping in different CRS, so let's multiply this by 8
+    # (all neighbouring tiles)
+    nb_source_chunks_per_dest = 8 * x_rel_source_chunks * y_rel_source_chunks
+
+    # Combined memory usage of one chunk operation = squared array made from combined chunksize + original chunks
+    total_nb = np.ceil(np.sqrt(nb_source_chunks_per_dest)) ** 2 + nb_source_chunks_per_dest
+    # We multiply the memory usage of a single chunk by the number of loaded/combined chunks
+    chunk_memusage = darr.dtype.itemsize * np.prod(chunksizes_in_mem) * total_nb
+
+    # Outputs: reprojected raster
+    out_memusage = np.dtype(darr.dtype).itemsize * np.prod(dst_chunksizes)
+
+    # Size of metadata passed to dask: number of blocks times its content
+    # For each block, we pass a dict with 4 integers per source chunk (rxs, rxe, rys, rye)
+    size_index_float = 24  # Python size for float
+    size_index_int = 28  # Python size for int
+    dict_all_blocks = (64 + 4 * size_index_int * nb_source_chunks_per_dest) * num_chunks
+    # Passing the 2 CRS, 2 transforms, resampling method and 2 nodatas
+    combined_meta = (112 + 112 + 56 + 56 + 44 + 28 + 28) * size_index_float * num_chunks
+    meta_memusage = combined_meta + dict_all_blocks
+
+    # Final estimate of memory usage of operation in MB
+    max_op_memusage = fac_dask_margin * (chunk_memusage + out_memusage + meta_memusage) / (2**20)
+    # We add a base memory usage of ~80 MB + 10 MB per 1000 chunks (loaded in background by Dask even on tiny data)
+    max_op_memusage += 80 + 10 * (num_chunks / 1000)
+
+    return max_op_memusage
+
+
+def _build_dst_transform_shifted_newres(
+    src_transform: rio.transform.Affine,
+    src_shape: tuple[int, int],
+    src_crs: CRS,
+    dst_crs: CRS,
+    bounds_rel_shift: tuple[float, float],
+    res_rel_fac: tuple[float, float],
+) -> rio.transform.Affine:
+    """
+    Build a destination transform intersecting the source transform given source/destination shapes,
+    and possibly introducing a relative shift in upper-left bound and multiplicative change in resolution.
+    """
+
+    # Get bounding box in source CRS
+    bounds = rio.coords.BoundingBox(*rio.transform.array_bounds(src_shape[0], src_shape[1], src_transform))
+
+    # Construct an aligned transform in the destination CRS assuming the same resolution
+    tmp_transform = rio.warp.calculate_default_transform(
+        src_crs,
+        dst_crs,
+        src_shape[1],
+        src_shape[0],
+        left=bounds.left,
+        right=bounds.right,
+        top=bounds.top,
+        bottom=bounds.bottom,
+        dst_width=src_shape[1],
+        dst_height=src_shape[0],
+    )[0]
+    # This allows us to get bounds and resolution in the units of the new CRS
+    tmp_res = (tmp_transform[0], abs(tmp_transform[4]))
+    tmp_bounds = rio.coords.BoundingBox(*rio.transform.array_bounds(src_shape[0], src_shape[1], tmp_transform))
+    # Now we can create a shifted/different-res destination grid
+    dst_transform = rio.transform.from_origin(
+        west=tmp_bounds.left + bounds_rel_shift[0] * tmp_res[0] * src_shape[1],
+        north=tmp_bounds.top + 150 * bounds_rel_shift[0] * tmp_res[1] * src_shape[0],
+        xsize=tmp_res[0] / res_rel_fac[0],
+        ysize=tmp_res[1] / res_rel_fac[1],
+    )
+
+    return dst_transform
+
+
+class TestDelayed:
+    """
+    Testing class for delayed functions.
+
+    We test on a first set of rasters big enough to clearly monitor the memory usage, and a second set small enough
+    to run fast to check a wide range of input parameters.
+
+    In detail:
+    Set 1. We capture memory usage during the .compute() calls and check that only the expected amount of memory that
+    we estimate independently (bytes used by one or several chunk combinations + metadata) is indeed used.
+    Set 2. We compare outputs with the in-memory function specifically for input variables that influence the delayed
+    algorithm and might lead to new errors (for example: array shape to get subsample/points locations for
+    subsample and interp_points, or destination chunksizes to map output of reproject).
+
+    We start with Set 2, as the output checks run faster when ordered first
+    (maybe due to the cluster memory monitoring afterwards).
+    """
+
+    # Define random seed for generating test data
+    rng = da.random.default_rng(seed=42)
+
+    # 1/ Set 1: Memory usage checks
+
+    # Big test files written on disk in an out-of-memory fashion,
+    # with different input shapes not necessarily aligned between themselves
+    large_shape = (10000, 10000)
+    # We can use a constant value for storage chunks, as it doesn't have any influence on the accuracy of delayed
+    # methods (it can slightly change RAM usage, but that stays pretty stable as long as chunksizes in memory are
+    # significantly bigger)
+    chunksizes_on_disk = (500, 500)
+    fn_large = os.path.join(_EXAMPLES_DIRECTORY, "test_large.nc")
+    if not os.path.exists(fn_large):
+        # Create random array in the right shape
+        data = rng.normal(size=large_shape[0] * large_shape[1]).reshape(large_shape[0], large_shape[1])
+        data_arr = xr.DataArray(data=data, dims=["x", "y"])
+        ds = xr.Dataset(data_vars={"test": data_arr})
+        encoding_kwargs = {"test": {"chunksizes": chunksizes_on_disk}}
+        # Write to disk out-of-memory
+        writer = ds.to_netcdf(fn_large, encoding=encoding_kwargs, compute=False)
+        writer.compute()
+
+    # 2/ Set 2: Output checks
+
+    # Smaller test files for fast checks, with various shapes and with/without nodata
+    list_small_shapes = [(50, 50), (51, 47)]
+    with_nodata = [False, True]
+    list_small_darr = []
+    for small_shape in list_small_shapes:
+        for w in with_nodata:
+            small_darr = rng.normal(size=small_shape[0] * small_shape[1])
+            # Add about half nodata values
+            if w:
+                ind_nodata = rng.choice(small_darr.size, size=int(small_darr.size / 2), replace=False)
+                small_darr[list(ind_nodata)] = np.nan
+            small_darr = small_darr.reshape(small_shape[0], small_shape[1])
+            list_small_darr.append(small_darr)
+
+    # List of in-memory chunksizes for small tests
+    list_small_chunksizes_in_mem = [(10, 10), (7, 19)]
+
+    # Create a corresponding boolean array for each numerical dask array:
+    # every finite numerical value (valid numerical value) corresponds to True (valid boolean value).
+    darr_bool = []
+    for small_darr in list_small_darr:
+        darr_bool.append(da.where(da.isfinite(small_darr), True, False))
+
+    @pytest.mark.parametrize("darr, darr_bool", list(zip(list_small_darr, darr_bool)))  # type: ignore
+    @pytest.mark.parametrize("chunksizes_in_mem", list_small_chunksizes_in_mem)  # type: ignore
+    @pytest.mark.parametrize("subsample_size", [2, 100, 100000])  # type: ignore
+    def test_delayed_subsample__output(
+        self, darr: da.Array, darr_bool: da.Array, chunksizes_in_mem: tuple[int, int], subsample_size: int
+    ):
+        """
+        Checks for delayed subsampling function for output accuracy.
+ Variables that influence specifically the delayed function and might lead to new errors are: + - Input chunksizes, + - Input array shape, + - Number of subsampled points. + """ + + # 1/ We run the delayed function after re-chunking + darr = darr.rechunk(chunksizes_in_mem) + sub = delayed_subsample(darr, subsample=subsample_size, random_state=42) + # 2/ Output checks + + # # The subsample should have exactly the prescribed length, with only valid values + assert len(sub) == min(subsample_size, np.count_nonzero(np.isfinite(darr))) + assert all(np.isfinite(sub)) + + # To verify the sampling works correctly, we can get its subsample indices with the argument return_indices + # And compare to the same subsample with vindex (now that we know the coordinates of valid values sampled) + indices = delayed_subsample(darr, subsample=subsample_size, random_state=42, return_indices=True) + sub2 = np.array(darr.vindex[indices[0], indices[1]]) + assert np.array_equal(sub, sub2) + + # Finally, to verify that a boolean array, with valid values at the same locations as the numerical array, + # leads to the same results, we compare the samples values and the samples indices. + darr_bool = darr_bool.rechunk(chunksizes_in_mem) + indices_bool = delayed_subsample(darr_bool, subsample=subsample_size, random_state=42, return_indices=True) + sub_bool = np.array(darr.vindex[indices_bool]) + assert np.array_equal(sub, sub_bool) + assert np.array_equal(indices, indices_bool) + + @pytest.mark.parametrize("darr", list_small_darr) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", list_small_chunksizes_in_mem) # type: ignore + @pytest.mark.parametrize("ninterp", [2, 100]) # type: ignore + @pytest.mark.parametrize("res", [(0.5, 2), (1, 1)]) # type: ignore + def test_delayed_interp_points__output( + self, darr: da.Array, chunksizes_in_mem: tuple[int, int], ninterp: int, res: tuple[float, float] + ): + """ + Checks for delayed interpolate points function. + Variables that influence specifically the delayed function are: + - Input chunksizes, + - Input array shape, + - Number of interpolated points, + - The resolution of the regular grid. 
+ """ + + # 1/ Define points to interpolate given the size and resolution + darr = darr.rechunk(chunksizes_in_mem) + rng = np.random.default_rng(seed=42) + interp_x = (rng.choice(darr.shape[0], ninterp) + rng.random(ninterp)) * res[0] + interp_y = (rng.choice(darr.shape[1], ninterp) + rng.random(ninterp)) * res[1] + + interp1 = delayed_interp_points(darr, points=(interp_x, interp_y), resolution=res) + + # 2/ Output checks + + # Interpolate directly with Xarray (loads a lot in memory) and check results are exactly the same + xx = xr.DataArray(interp_x, dims="z", name="x") + yy = xr.DataArray(interp_y, dims="z", name="y") + ds = xr.DataArray( + data=darr, + dims=["x", "y"], + coords={ + "x": np.arange(0, darr.shape[0] * res[0], res[0]), + "y": np.arange(0, darr.shape[1] * res[1], res[1]), + }, + ) + interp2 = ds.interp(x=xx, y=yy) + interp2.compute() + interp2 = np.array(interp2.values) + + assert np.array_equal(interp1, interp2, equal_nan=True) + + @pytest.mark.parametrize("darr", list_small_darr) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", list_small_chunksizes_in_mem) # type: ignore + @pytest.mark.parametrize("dst_chunksizes", list_small_chunksizes_in_mem) # type: ignore + # Shift upper left corner of output bounds (relative to projected input bounds) by fractions of the raster size + @pytest.mark.parametrize("dst_bounds_rel_shift", [(0, 0), (-0.2, 0.5)]) # type: ignore + # Modify output resolution (relative to projected input resolution) by a factor + @pytest.mark.parametrize("dst_res_rel_fac", [(1, 1), (2.1, 0.54)]) # type: ignore + # Same for shape + @pytest.mark.parametrize("dst_shape_diff", [(0, 0), (-28, 117)]) # type: ignore + def test_delayed_reproject__output( + self, + darr: da.Array, + chunksizes_in_mem: tuple[int, int], + dst_chunksizes: tuple[int, int], + dst_bounds_rel_shift: tuple[float, float], + dst_res_rel_fac: tuple[float, float], + dst_shape_diff: tuple[int, int], + ): + """ + Checks for the delayed reproject function. + Variables that influence specifically the delayed function are: + - Input/output chunksizes, + - Input array shape, + - Output geotransform relative to projected input geotransform, + - Output array shape relative to input. 
+ """ + + # Keeping this commented here if we need to redo local tests due to Rasterio errors + # darr = list_small_darr[0] + # chunksizes_in_mem = list_small_chunksizes_in_mem[0] + # dst_chunksizes = list_small_chunksizes_in_mem[0] # (2000, 2000) + # dst_bounds_rel_shift = (0, 0) + # dst_res_rel_fac = (0.45, 0.45) # (1, 1) + # dst_shape_diff = (0, 0) + # cluster = LocalCluster(n_workers=1, threads_per_worker=1, dashboard_address=None) + + # 0/ Define input parameters + + # Get input and output shape + darr = darr.rechunk(chunksizes_in_mem) + src_shape = darr.shape + dst_shape = (src_shape[0] + dst_shape_diff[0], src_shape[1] + dst_shape_diff[1]) + + # Define arbitrary input transform, as we only care about the relative difference with the output transform + src_transform = rio.transform.from_bounds(10, 10, 15, 15, src_shape[0], src_shape[1]) + + # Other arguments having (normally) no influence + src_crs = CRS(4326) + dst_crs = CRS(32630) + src_nodata = -9999 + dst_nodata = 99999 + resampling = rio.enums.Resampling.bilinear + + # Get shifted dst_transform with new resolution + dst_transform = _build_dst_transform_shifted_newres( + src_transform=src_transform, + src_crs=src_crs, + dst_crs=dst_crs, + src_shape=src_shape, + bounds_rel_shift=dst_bounds_rel_shift, + res_rel_fac=dst_res_rel_fac, + ) + + # 2/ Run delayed reproject with memory monitoring + + reproj_arr = delayed_reproject( + darr, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + dst_shape=dst_shape, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=resampling, + dst_chunksizes=dst_chunksizes, + ) + + # 3/ Outputs check: load in memory and compare with a direct Rasterio reproject + reproj_arr = np.array(reproj_arr) + + dst_arr = np.zeros(dst_shape) + _ = rio.warp.reproject( + np.array(darr), + dst_arr, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + resampling=resampling, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + ) + + # Keeping this to visualize Rasterio resampling issue + # if PLOT: + # import matplotlib.pyplot as plt + # plt.figure() + # plt.imshow((reproj_arr - dst_arr), cmap="RdYlBu", vmin=-0.2, vmax=0.2, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/diff_close_zero.png", dpi=500) + # plt.figure() + # plt.imshow(np.abs(reproj_arr - dst_arr), cmap="RdYlBu", vmin=99997, vmax=100001, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/diff_nodata.png", dpi=500) + # plt.figure() + # plt.imshow(dst_arr, cmap="RdYlBu", vmin=-1, vmax=1, interpolation="None") + # plt.colorbar() + # plt.savefig("/home/atom/ongoing/dst.png", dpi=500) + + # Due to (what appears to be) Rasterio errors, we have to remain imprecise for the checks here: + # even though some reprojections are pretty good, some can get a bit nasty + + # Check that little data (less than 10% of pixels) are significantly different + ind_signif_diff = np.abs(reproj_arr - dst_arr) > 0.5 + assert np.count_nonzero(ind_signif_diff) < 0.1 * reproj_arr.size + + # The median difference should be negligible compared to the amplitude of the signal (+/- 1 std) + assert np.nanmedian(np.abs(reproj_arr - dst_arr)) < 0.1 + + # # Replace with allclose once Rasterio issue fixed? 
For some cases we get a good match + # (less than 0.01 for all pixels) + # assert np.allclose(reproj_arr[~ind_both_nodata], dst_arr[~ind_both_nodata], atol=0.02) + + @pytest.mark.parametrize("fn", [fn_large]) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(1000, 1000), (2500, 2500)]) # type: ignore + @pytest.mark.parametrize("subsample_size", [100, 100000]) # type: ignore + def test_delayed_subsample__memusage( + self, fn: str, chunksizes_in_mem: tuple[int, int], subsample_size: int, cluster: Any + ): + """ + Checks for delayed subsampling function for memory usage on big file. + (and also runs output checks as not long or too memory intensive in this case) + Variables that influence memory usage are: + - Subsample sizes, + - Chunksizes in memory. + """ + + # Only check on linux + if sys.platform == "linux": + + # 0/ Open dataset with chunks + ds = xr.open_dataset(fn, chunks={"x": chunksizes_in_mem[0], "y": chunksizes_in_mem[1]}) + darr = ds["test"].data + + # 1/ Estimation of theoretical memory usage of the subsampling script + + max_op_memusage = _estimate_subsample_memusage( + darr=darr, chunksizes_in_mem=chunksizes_in_mem, subsample_size=subsample_size + ) + + # 2/ Run delayed subsample with dask memory usage monitoring + + # Derive subsample from delayed function + # (passed to wrapper function to measure memory usage during execution) + sub, measured_op_memusage = _run_dask_measuring_memusage( + cluster, delayed_subsample, darr, subsample=subsample_size, random_state=42 + ) + + # Check the measured memory usage is smaller than the maximum estimated one + assert measured_op_memusage < max_op_memusage + + # 3/ Output checks + # The subsample should have exactly the prescribed length, with only valid values + assert len(sub) == subsample_size + assert all(np.isfinite(sub)) + + # To verify the sampling works correctly, we can get its subsample indices with the argument return_indices + # And compare to the same subsample with vindex (now that we know the coordinates of valid values sampled) + indices = delayed_subsample(darr, subsample=subsample_size, random_state=42, return_indices=True) + sub2 = np.array(darr.vindex[indices[0], indices[1]]) + assert np.array_equal(sub, sub2) + + @pytest.mark.parametrize("fn", [fn_large]) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(2000, 2000)]) # type: ignore + @pytest.mark.parametrize("ninterp", [100, 100000]) # type: ignore + def test_delayed_interp_points__memusage( + self, fn: str, chunksizes_in_mem: tuple[int, int], ninterp: int, cluster: Any + ): + """ + Checks for delayed interpolate points function for memory usage on a big file. + Variables that influence memory usage are: + - Number of interpolated points, + - Chunksizes in memory. 
+ """ + + # Only check on linux + if sys.platform == "linux": + + # 0/ Open dataset with chunks and create random point locations to interpolate + ds = xr.open_dataset(fn, chunks={"x": chunksizes_in_mem[0], "y": chunksizes_in_mem[1]}) + darr = ds["test"].data + + rng = np.random.default_rng(seed=42) + interp_x = rng.choice(ds.x.size, ninterp) + rng.random(ninterp) + interp_y = rng.choice(ds.y.size, ninterp) + rng.random(ninterp) + + # 1/ Estimation of theoretical memory usage of the subsampling script + max_op_memusage = _estimate_interp_points_memusage( + darr=darr, chunksizes_in_mem=chunksizes_in_mem, ninterp=ninterp + ) + + # 2/ Run interpolation of random point coordinates with memory monitoring + interp1, measured_op_memusage = _run_dask_measuring_memusage( + cluster, delayed_interp_points, darr, points=(interp_x, interp_y), resolution=(1, 1) + ) + # Check the measured memory usage is smaller than the maximum estimated one + assert measured_op_memusage < max_op_memusage + + @pytest.mark.parametrize("fn", [fn_large]) # type: ignore + @pytest.mark.parametrize("chunksizes_in_mem", [(1000, 1000), (2500, 2500)]) # type: ignore + @pytest.mark.parametrize("dst_chunksizes", [(1000, 1000), (2500, 2500)]) # type: ignore + @pytest.mark.parametrize("dst_bounds_rel_shift", [(1, 1), (2, 2)]) # type: ignore + def test_delayed_reproject__memusage( + self, + fn: str, + chunksizes_in_mem: tuple[int, int], + dst_chunksizes: tuple[int, int], + dst_bounds_rel_shift: tuple[float, float], + cluster: Any, + ): + """ + Checks for the delayed reproject function for memory usage on a big file. + Variables that influence memory usage are: + - Source chunksizes in memory, + - Destination chunksizes in memory, + - Relative difference in resolution (potentially more/less source chunks to load for a destination chunk). 
+ """ + + # Only check on linux + if sys.platform == "linux": + + # We fix arbitrary changes to the destination shape/resolution/bounds + # (already checked in details in the output tests) + dst_shape_diff = (25, -25) + dst_res_rel_fac = (1.5, 0.5) + + # 0/ Open dataset with chunks and define variables + ds = xr.open_dataset(fn, chunks={"x": chunksizes_in_mem[0], "y": chunksizes_in_mem[1]}) + darr = ds["test"].data + + # Get input and output shape + src_shape = darr.shape + dst_shape = (src_shape[0], src_shape[1] + dst_shape_diff[1]) + + # Define arbitrary input/output CRS, they don't have a direct influence on the delayed method + # (as long as the input/output transforms intersect if projected in the same CRS) + src_crs = CRS(4326) + dst_crs = CRS(32630) + + # Define arbitrary input transform, as we only care about the relative difference with the output transform + src_transform = rio.transform.from_bounds(10, 10, 15, 15, src_shape[0], src_shape[1]) + + # Other arguments having no influence + src_nodata = -9999 + dst_nodata = 99999 + resampling = rio.enums.Resampling.bilinear + + # Get shifted dst_transform with new resolution + dst_transform = _build_dst_transform_shifted_newres( + src_transform=src_transform, + src_crs=src_crs, + dst_crs=dst_crs, + src_shape=src_shape, + bounds_rel_shift=dst_bounds_rel_shift, + res_rel_fac=dst_res_rel_fac, + ) + + # 1/ Estimation of theoretical memory usage of the subsampling script + + max_op_memusage = _estimate_reproject_memusage( + darr, chunksizes_in_mem=chunksizes_in_mem, dst_chunksizes=dst_chunksizes, rel_res_fac=dst_res_rel_fac + ) + + # 2/ Run delayed reproject with memory monitoring + + # We define a function where computes happens during writing to be able to measure memory usage + # (delayed_reproject returns a delayed array that might not fit in memory, unlike subsampling/interpolation) + fn_tmp_out = os.path.join(_EXAMPLES_DIRECTORY, os.path.splitext(os.path.basename(fn))[0] + "_reproj.nc") + + def reproject_and_write(*args: Any, **kwargs: Any) -> None: + # Run delayed reprojection + reproj_arr_tmp = delayed_reproject(*args, **kwargs) + + # Save file out-of-memory and compute + data_arr = xr.DataArray(data=reproj_arr_tmp, dims=["x", "y"]) + ds_out = xr.Dataset(data_vars={"test_reproj": data_arr}) + write_delayed = ds_out.to_netcdf(fn_tmp_out, compute=False) + write_delayed.compute() + + # And call this function with memory usage monitoring + _, measured_op_memusage = _run_dask_measuring_memusage( + cluster, + reproject_and_write, + darr, + src_transform=src_transform, + src_crs=src_crs, + dst_transform=dst_transform, + dst_crs=dst_crs, + dst_shape=dst_shape, + src_nodata=src_nodata, + dst_nodata=dst_nodata, + resampling=resampling, + dst_chunksizes=dst_chunksizes, + ) + + # Check the measured memory usage is smaller than the maximum estimated one + assert measured_op_memusage < max_op_memusage diff --git a/tests/test_multiraster.py b/tests/test_raster/test_multiraster.py similarity index 100% rename from tests/test_multiraster.py rename to tests/test_raster/test_multiraster.py diff --git a/tests/test_raster.py b/tests/test_raster/test_raster.py similarity index 100% rename from tests/test_raster.py rename to tests/test_raster/test_raster.py diff --git a/tests/test_sampling.py b/tests/test_raster/test_sampling.py similarity index 100% rename from tests/test_sampling.py rename to tests/test_raster/test_sampling.py diff --git a/tests/test_satimg.py b/tests/test_raster/test_satimg.py similarity index 100% rename from tests/test_satimg.py 
rename to tests/test_raster/test_satimg.py