From 6a29c363f77375597b6adc88b04b3bae03ac2e73 Mon Sep 17 00:00:00 2001 From: "Alan D. Snow" Date: Mon, 4 Nov 2024 08:38:44 -0600 Subject: [PATCH] BUG: Properly handle encoding/decoding scales and offsets (#821) --- rioxarray/_io.py | 4 ++ rioxarray/merge.py | 19 ++++---- rioxarray/raster_dataset.py | 43 +++++++++++++---- rioxarray/raster_writer.py | 19 ++++---- .../integration/test_integration_rioxarray.py | 47 +++++++++++++++++-- 5 files changed, 102 insertions(+), 30 deletions(-) diff --git a/rioxarray/_io.py b/rioxarray/_io.py index 26efbc3b..96e32501 100644 --- a/rioxarray/_io.py +++ b/rioxarray/_io.py @@ -957,8 +957,12 @@ def _handle_encoding( variables.pop_to( result.attrs, result.encoding, "scale_factor", name=da_name ) + if "scales" in result.attrs: + variables.pop_to(result.attrs, result.encoding, "scales", name=da_name) if "add_offset" in result.attrs: variables.pop_to(result.attrs, result.encoding, "add_offset", name=da_name) + if "offsets" in result.attrs: + variables.pop_to(result.attrs, result.encoding, "offsets", name=da_name) if masked: if "_FillValue" in result.attrs: variables.pop_to(result.attrs, result.encoding, "_FillValue", name=da_name) diff --git a/rioxarray/merge.py b/rioxarray/merge.py index 58d483b7..25cfbf4e 100644 --- a/rioxarray/merge.py +++ b/rioxarray/merge.py @@ -42,12 +42,16 @@ def __init__(self, xds: DataArray): "crs": self.crs, "nodata": self.nodatavals[0], } - self._scale_factor = self._xds.encoding.get("scale_factor", 1.0) - self._add_offset = self._xds.encoding.get("add_offset", 0.0) + valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any( + scale != 1 for scale in self._xds.encoding.get("scales", (1,)) + ) + valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any( + offset != 0 for offset in self._xds.encoding.get("offsets", (0,)) + ) self._mask_and_scale = ( self._xds.rio.encoded_nodata is not None - or self._scale_factor != 1 - or self._add_offset != 0 + or valid_scale_factor + or valid_offset or self._xds.encoding.get("_Unsigned") is not None ) @@ -70,10 +74,9 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray: kwargs["masked"] = True out = dataset.read(*args, **kwargs) if self._mask_and_scale: - if self._scale_factor != 1: - out = out * self._scale_factor - if self._add_offset != 0: - out = out + self._add_offset + out = out.astype(self._xds.dtype) + for iii in range(self.count): + out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii] return out diff --git a/rioxarray/raster_dataset.py b/rioxarray/raster_dataset.py index 6f3aa287..52dc593d 100644 --- a/rioxarray/raster_dataset.py +++ b/rioxarray/raster_dataset.py @@ -509,35 +509,58 @@ def to_raster( is True. Otherwise None is returned. """ + # pylint: disable=too-many-locals variable_dim = f"band_{uuid4()}" data_array = self._obj.to_array(dim=variable_dim) # ensure raster metadata preserved - scales = [] - offsets = [] - nodatavals = [] + attr_scales = [] + attr_offsets = [] + attr_nodatavals = [] + encoded_scales = [] + encoded_offsets = [] + encoded_nodatavals = [] band_tags = [] long_name = [] for data_var in data_array[variable_dim].values: - scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0)) - offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0)) + try: + encoded_scales.append(self._obj[data_var].encoding["scale_factor"]) + except KeyError: + attr_scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0)) + try: + encoded_offsets.append(self._obj[data_var].encoding["add_offset"]) + except KeyError: + attr_offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0)) long_name.append(self._obj[data_var].attrs.get("long_name", data_var)) - nodatavals.append(self._obj[data_var].rio.nodata) + if self._obj[data_var].rio.encoded_nodata is not None: + encoded_nodatavals.append(self._obj[data_var].rio.encoded_nodata) + else: + attr_nodatavals.append(self._obj[data_var].rio.nodata) band_tags.append(self._obj[data_var].attrs.copy()) - data_array.attrs["scales"] = scales - data_array.attrs["offsets"] = offsets + if encoded_scales: + data_array.encoding["scales"] = encoded_scales + else: + data_array.attrs["scales"] = attr_scales + if encoded_offsets: + data_array.encoding["offsets"] = encoded_offsets + else: + data_array.attrs["offsets"] = attr_offsets data_array.attrs["band_tags"] = band_tags data_array.attrs["long_name"] = long_name + use_encoded_nodatavals = bool(encoded_nodatavals) + nodatavals = encoded_nodatavals if use_encoded_nodatavals else attr_nodatavals nodata = nodatavals[0] if ( all(nodataval == nodata for nodataval in nodatavals) or numpy.isnan(nodatavals).all() ): - data_array.rio.write_nodata(nodata, inplace=True) + data_array.rio.write_nodata( + nodata, inplace=True, encoded=use_encoded_nodatavals + ) else: raise RioXarrayError( "All nodata values must be the same when exporting to raster. " - f"Current values: {nodatavals}" + f"Current values: {attr_nodatavals}" ) if self.crs is not None: data_array.rio.write_crs(self.crs, inplace=True) diff --git a/rioxarray/raster_writer.py b/rioxarray/raster_writer.py index 408a2672..2cf87f9a 100644 --- a/rioxarray/raster_writer.py +++ b/rioxarray/raster_writer.py @@ -94,20 +94,23 @@ def _write_metatata_to_raster(*, raster_handle, xarray_dataset, tags): ) # write scales and offsets - try: - raster_handle.scales = tags["scales"] - except KeyError: + scales = tags.get("scales", xarray_dataset.encoding.get("scales")) + if scales is None: scale_factor = tags.get( "scale_factor", xarray_dataset.encoding.get("scale_factor") ) if scale_factor is not None: - raster_handle.scales = (scale_factor,) * raster_handle.count - try: - raster_handle.offsets = tags["offsets"] - except KeyError: + scales = (scale_factor,) * raster_handle.count + if scales is not None: + raster_handle.scales = scales + + offsets = tags.get("offsets", xarray_dataset.encoding.get("offsets")) + if offsets is None: add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset")) if add_offset is not None: - raster_handle.offsets = (add_offset,) * raster_handle.count + offsets = (add_offset,) * raster_handle.count + if offsets is not None: + raster_handle.offsets = offsets _write_tags(raster_handle=raster_handle, tags=tags) _write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset) diff --git a/test/integration/test_integration_rioxarray.py b/test/integration/test_integration_rioxarray.py index 113089a6..ad6cbbe0 100644 --- a/test/integration/test_integration_rioxarray.py +++ b/test/integration/test_integration_rioxarray.py @@ -1771,7 +1771,8 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir): tmp_raster = tmpdir.join("air_temp_offset.tif") with rioxarray.open_rasterio( - os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks + os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), + chunks=chunks, ) as rds: rds = _ensure_dataset(rds) attrs = dict(rds.air_temperature.attrs) @@ -1795,6 +1796,38 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir): assert rds.rio.nodata == 32767.0 +@pytest.mark.parametrize("mask_and_scale", [True, False]) +def test_to_raster__scales__offsets(mask_and_scale, tmpdir): + tmp_raster = tmpdir.join("air_temp_offset.tif") + + with rioxarray.open_rasterio( + os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), + mask_and_scale=mask_and_scale, + ) as rds: + rds = _ensure_dataset(rds) + rds["air_temperature_2"] = rds.air_temperature.copy() + if mask_and_scale: + rds.air_temperature_2.encoding["scale_factor"] = 0.2 + rds.air_temperature_2.encoding["add_offset"] = 110.0 + else: + rds.air_temperature_2.attrs["scale_factor"] = 0.2 + rds.air_temperature_2.attrs["add_offset"] = 110.0 + rds.squeeze(dim="band", drop=True).rio.to_raster(str(tmp_raster)) + + with rasterio.open(str(tmp_raster)) as rds: + assert rds.scales == (0.1, 0.2) + assert rds.offsets == (220.0, 110.0) + + # test roundtrip + with rioxarray.open_rasterio(str(tmp_raster), mask_and_scale=mask_and_scale) as rds: + if mask_and_scale: + assert rds.encoding["scales"] == (0.1, 0.2) + assert rds.encoding["offsets"] == (220.0, 110.0) + else: + assert rds.attrs["scales"] == (0.1, 0.2) + assert rds.attrs["offsets"] == (220.0, 110.0) + + def test_to_raster__custom_description__wrong(tmpdir): tmp_raster = tmpdir.join("planet_3d_raster.tif") with xarray.open_dataset( @@ -1857,11 +1890,14 @@ def test_to_raster__dataset(tmpdir): assert numpy.isnan(rdscompare.rio.nodata) +@pytest.mark.parametrize("mask_and_scale", [True, False]) @pytest.mark.parametrize("chunks", [True, None]) -def test_to_raster__dataset__mask_and_scale(chunks, tmpdir): +def test_to_raster__dataset__mask_and_scale(chunks, mask_and_scale, tmpdir): output_raster = tmpdir.join("tmmx_20190121.tif") with rioxarray.open_rasterio( - os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks + os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), + chunks=chunks, + mask_and_scale=mask_and_scale, ) as rds: rds = _ensure_dataset(rds) rds.isel(band=0).rio.to_raster(str(output_raster)) @@ -1871,7 +1907,10 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir): assert rdscompare.add_offset == 220.0 assert rdscompare.long_name == "tmmx" assert rdscompare.rio.crs == rds.rio.crs - assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata + if mask_and_scale: + assert rdscompare.rio.nodata == rds.air_temperature.rio.encoded_nodata + else: + assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata def test_to_raster__dataset__different_crs(tmpdir):