Skip to content

Commit 30a8e98

Browse files
committed
BUG: Properly handle encoding/decoding scales and offsets
1 parent 3ca29fb commit 30a8e98

File tree

5 files changed

+101
-30
lines changed

5 files changed

+101
-30
lines changed

rioxarray/_io.py

+4
Original file line number | Diff line number | Diff line change
@@ -957,8 +957,12 @@ def _handle_encoding(
957957
variables.pop_to(
958958
result.attrs, result.encoding, "scale_factor", name=da_name
959959
)
960+
if "scales" in result.attrs:
961+
variables.pop_to(result.attrs, result.encoding, "scales", name=da_name)
960962
if "add_offset" in result.attrs:
961963
variables.pop_to(result.attrs, result.encoding, "add_offset", name=da_name)
964+
if "offsets" in result.attrs:
965+
variables.pop_to(result.attrs, result.encoding, "offsets", name=da_name)
962966
if masked:
963967
if "_FillValue" in result.attrs:
964968
variables.pop_to(result.attrs, result.encoding, "_FillValue", name=da_name)

rioxarray/merge.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -42,12 +42,16 @@ def __init__(self, xds: DataArray):
4242
"crs": self.crs,
4343
"nodata": self.nodatavals[0],
4444
}
45-
self._scale_factor = self._xds.encoding.get("scale_factor", 1.0)
46-
self._add_offset = self._xds.encoding.get("add_offset", 0.0)
45+
valid_scale_factor = self._xds.encoding.get("scale_factor", 1) != 1 or any(
46+
scale != 1 for scale in self._xds.encoding.get("scales", (1,))
47+
)
48+
valid_offset = self._xds.encoding.get("add_offset", 0.0) != 0 or any(
49+
offset != 0 for offset in self._xds.encoding.get("offsets", (0,))
50+
)
4751
self._mask_and_scale = (
4852
self._xds.rio.encoded_nodata is not None
49-
or self._scale_factor != 1
50-
or self._add_offset != 0
53+
or valid_scale_factor
54+
or valid_offset
5155
or self._xds.encoding.get("_Unsigned") is not None
5256
)
5357

@@ -70,10 +74,9 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
7074
kwargs["masked"] = True
7175
out = dataset.read(*args, **kwargs)
7276
if self._mask_and_scale:
73-
if self._scale_factor != 1:
74-
out = out * self._scale_factor
75-
if self._add_offset != 0:
76-
out = out + self._add_offset
77+
out = out.astype(self._xds.dtype)
78+
for iii in range(self.count):
79+
out[iii] = out[iii] * dataset.scales[iii] + dataset.offsets[iii]
7780
return out
7881

7982

rioxarray/raster_dataset.py

+32 -10
Original file line number | Diff line number | Diff line change
@@ -512,32 +512,54 @@ def to_raster(
512512
variable_dim = f"band_{uuid4()}"
513513
data_array = self._obj.to_array(dim=variable_dim)
514514
# ensure raster metadata preserved
515-
scales = []
516-
offsets = []
517-
nodatavals = []
515+
attr_scales = []
516+
attr_offsets = []
517+
attr_nodatavals = []
518+
encoded_scales = []
519+
encoded_offsets = []
520+
encoded_nodatavals = []
518521
band_tags = []
519522
long_name = []
520523
for data_var in data_array[variable_dim].values:
521-
scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
522-
offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
524+
try:
525+
encoded_scales.append(self._obj[data_var].encoding["scale_factor"])
526+
except KeyError:
527+
attr_scales.append(self._obj[data_var].attrs.get("scale_factor", 1.0))
528+
try:
529+
encoded_offsets.append(self._obj[data_var].encoding["add_offset"])
530+
except KeyError:
531+
attr_offsets.append(self._obj[data_var].attrs.get("add_offset", 0.0))
523532
long_name.append(self._obj[data_var].attrs.get("long_name", data_var))
524-
nodatavals.append(self._obj[data_var].rio.nodata)
533+
if self._obj[data_var].rio.encoded_nodata is not None:
534+
encoded_nodatavals.append(self._obj[data_var].rio.encoded_nodata)
535+
else:
536+
attr_nodatavals.append(self._obj[data_var].rio.nodata)
525537
band_tags.append(self._obj[data_var].attrs.copy())
526-
data_array.attrs["scales"] = scales
527-
data_array.attrs["offsets"] = offsets
538+
if encoded_scales:
539+
data_array.encoding["scales"] = encoded_scales
540+
else:
541+
data_array.attrs["scales"] = attr_scales
542+
if encoded_offsets:
543+
data_array.encoding["offsets"] = encoded_offsets
544+
else:
545+
data_array.attrs["offsets"] = attr_offsets
528546
data_array.attrs["band_tags"] = band_tags
529547
data_array.attrs["long_name"] = long_name
530548

549+
use_encoded_nodatavals = bool(encoded_nodatavals)
550+
nodatavals = encoded_nodatavals if use_encoded_nodatavals else attr_nodatavals
531551
nodata = nodatavals[0]
532552
if (
533553
all(nodataval == nodata for nodataval in nodatavals)
534554
or numpy.isnan(nodatavals).all()
535555
):
536-
data_array.rio.write_nodata(nodata, inplace=True)
556+
data_array.rio.write_nodata(
557+
nodata, inplace=True, encoded=use_encoded_nodatavals
558+
)
537559
else:
538560
raise RioXarrayError(
539561
"All nodata values must be the same when exporting to raster. "
540-
f"Current values: {nodatavals}"
562+
f"Current values: {attr_nodatavals}"
541563
)
542564
if self.crs is not None:
543565
data_array.rio.write_crs(self.crs, inplace=True)

rioxarray/raster_writer.py

+11 -8
Original file line number | Diff line number | Diff line change
@@ -94,20 +94,23 @@ def _write_metatata_to_raster(*, raster_handle, xarray_dataset, tags):
9494
)
9595

9696
# write scales and offsets
97-
try:
98-
raster_handle.scales = tags["scales"]
99-
except KeyError:
97+
scales = tags.get("scales", xarray_dataset.encoding.get("scales"))
98+
if scales is None:
10099
scale_factor = tags.get(
101100
"scale_factor", xarray_dataset.encoding.get("scale_factor")
102101
)
103102
if scale_factor is not None:
104-
raster_handle.scales = (scale_factor,) * raster_handle.count
105-
try:
106-
raster_handle.offsets = tags["offsets"]
107-
except KeyError:
103+
scales = (scale_factor,) * raster_handle.count
104+
if scales is not None:
105+
raster_handle.scales = scales
106+
107+
offsets = tags.get("offsets", xarray_dataset.encoding.get("offsets"))
108+
if offsets is None:
108109
add_offset = tags.get("add_offset", xarray_dataset.encoding.get("add_offset"))
109110
if add_offset is not None:
110-
raster_handle.offsets = (add_offset,) * raster_handle.count
111+
offsets = (add_offset,) * raster_handle.count
112+
if offsets is not None:
113+
raster_handle.offsets = offsets
111114

112115
_write_tags(raster_handle=raster_handle, tags=tags)
113116
_write_band_description(raster_handle=raster_handle, xarray_dataset=xarray_dataset)

test/integration/test_integration_rioxarray.py

+43 -4
Original file line number | Diff line number | Diff line change
@@ -1770,7 +1770,8 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir):
17701770
tmp_raster = tmpdir.join("air_temp_offset.tif")
17711771

17721772
with rioxarray.open_rasterio(
1773-
os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks
1773+
os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
1774+
chunks=chunks,
17741775
) as rds:
17751776
rds = _ensure_dataset(rds)
17761777
attrs = dict(rds.air_temperature.attrs)
@@ -1794,6 +1795,38 @@ def test_to_raster__offsets_and_scales(chunks, tmpdir):
17941795
assert rds.rio.nodata == 32767.0
17951796

17961797

1798+
@pytest.mark.parametrize("mask_and_scale", [True, False])
1799+
def test_to_raster__scales__offsets(mask_and_scale, tmpdir):
1800+
tmp_raster = tmpdir.join("air_temp_offset.tif")
1801+
1802+
with rioxarray.open_rasterio(
1803+
os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
1804+
mask_and_scale=mask_and_scale,
1805+
) as rds:
1806+
rds = _ensure_dataset(rds)
1807+
rds["air_temperature_2"] = rds.air_temperature.copy()
1808+
if mask_and_scale:
1809+
rds.air_temperature_2.encoding["scale_factor"] = 0.2
1810+
rds.air_temperature_2.encoding["add_offset"] = 110.0
1811+
else:
1812+
rds.air_temperature_2.attrs["scale_factor"] = 0.2
1813+
rds.air_temperature_2.attrs["add_offset"] = 110.0
1814+
rds.squeeze(dim="band", drop=True).rio.to_raster(str(tmp_raster))
1815+
1816+
with rasterio.open(str(tmp_raster)) as rds:
1817+
assert rds.scales == (0.1, 0.2)
1818+
assert rds.offsets == (220.0, 110.0)
1819+
1820+
# test roundtrip
1821+
with rioxarray.open_rasterio(str(tmp_raster), mask_and_scale=mask_and_scale) as rds:
1822+
if mask_and_scale:
1823+
assert rds.encoding["scales"] == (0.1, 0.2)
1824+
assert rds.encoding["offsets"] == (220.0, 110.0)
1825+
else:
1826+
assert rds.attrs["scales"] == (0.1, 0.2)
1827+
assert rds.attrs["offsets"] == (220.0, 110.0)
1828+
1829+
17971830
def test_to_raster__custom_description__wrong(tmpdir):
17981831
tmp_raster = tmpdir.join("planet_3d_raster.tif")
17991832
with xarray.open_dataset(
@@ -1856,11 +1889,14 @@ def test_to_raster__dataset(tmpdir):
18561889
assert numpy.isnan(rdscompare.rio.nodata)
18571890

18581891

1892+
@pytest.mark.parametrize("mask_and_scale", [True, False])
18591893
@pytest.mark.parametrize("chunks", [True, None])
1860-
def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
1894+
def test_to_raster__dataset__mask_and_scale(chunks, mask_and_scale, tmpdir):
18611895
output_raster = tmpdir.join("tmmx_20190121.tif")
18621896
with rioxarray.open_rasterio(
1863-
os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"), chunks=chunks
1897+
os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc"),
1898+
chunks=chunks,
1899+
mask_and_scale=mask_and_scale,
18641900
) as rds:
18651901
rds = _ensure_dataset(rds)
18661902
rds.isel(band=0).rio.to_raster(str(output_raster))
@@ -1870,7 +1906,10 @@ def test_to_raster__dataset__mask_and_scale(chunks, tmpdir):
18701906
assert rdscompare.add_offset == 220.0
18711907
assert rdscompare.long_name == "tmmx"
18721908
assert rdscompare.rio.crs == rds.rio.crs
1873-
assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata
1909+
if mask_and_scale:
1910+
assert rdscompare.rio.nodata == rds.air_temperature.rio.encoded_nodata
1911+
else:
1912+
assert rdscompare.rio.nodata == rds.air_temperature.rio.nodata
18741913

18751914

18761915
def test_to_raster__dataset__different_crs(tmpdir):

0 commit comments

Comments
 (0)