Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for reading and writing tiff gcps #460

Merged
merged 15 commits into from
Feb 9, 2022
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ History
Latest
------
- DEP: Drop Python 3.7 support (issue #451)
- ENH: Add GCPs reading and writing (issue #376)

0.9.1
------
Expand Down
6 changes: 6 additions & 0 deletions rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,6 +880,10 @@ def open_rasterio(
coord_name = "band"
coords[coord_name] = np.asarray(riods.indexes)

has_gcps = hasattr(riods, "gcps") and riods.gcps != ([], None)
if has_gcps:
parse_coordinates = False

# Get geospatial coordinates
if parse_coordinates:
coords.update(
Expand Down Expand Up @@ -937,6 +941,8 @@ def open_rasterio(
result.rio.write_transform(_rio_transform(riods), inplace=True)
if hasattr(riods, "crs") and riods.crs:
result.rio.write_crs(riods.crs, inplace=True)
if has_gcps:
result.rio.write_gcps(*riods.gcps, inplace=True)

if chunks is not None:
result = _prepare_dask(result, riods, filename, chunks)
Expand Down
41 changes: 27 additions & 14 deletions rioxarray/raster_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,10 @@ def reproject(
"CRS not found. Please set the CRS with 'rio.write_crs()'."
f"{_get_data_var_message(self._obj)}"
)
gcps = self.get_gcps()
if gcps:
kwargs.setdefault("gcps", gcps)

src_affine = None if "gcps" in kwargs else self.transform(recalc=True)
if transform is None:
dst_affine, dst_width, dst_height = _make_dst_affine(
Expand All @@ -404,21 +408,9 @@ def reproject(
else:
dst_height, dst_width = self.shape

extra_dim = self._check_dimensions()
if extra_dim:
dst_data = np.zeros(
(self._obj[extra_dim].size, dst_height, dst_width),
dtype=self._obj.dtype.type,
)
else:
dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type)
dst_data = self._create_dst_data(dst_height, dst_width)

default_nodata = (
_NODATA_DTYPE_MAP[dtype_rev[self._obj.dtype.name]]
if self.nodata is None
else self.nodata
)
dst_nodata = default_nodata if nodata is None else nodata
dst_nodata = self._get_dst_nodata(nodata)

rasterio.warp.reproject(
source=self._obj.values,
Expand Down Expand Up @@ -456,6 +448,26 @@ def reproject(
xda.rio.write_coordinate_system(inplace=True)
return xda

def _get_dst_nodata(self, nodata):
default_nodata = (
_NODATA_DTYPE_MAP[dtype_rev[self._obj.dtype.name]]
if self.nodata is None
else self.nodata
)
dst_nodata = default_nodata if nodata is None else nodata
return dst_nodata

def _create_dst_data(self, dst_height, dst_width):
extra_dim = self._check_dimensions()
if extra_dim:
dst_data = np.zeros(
(self._obj[extra_dim].size, dst_height, dst_width),
dtype=self._obj.dtype.type,
)
else:
dst_data = np.zeros((dst_height, dst_width), dtype=self._obj.dtype.type)
return dst_data

def reproject_match(
self, match_data_array, resampling=Resampling.nearest, **reproject_kwargs
):
Expand Down Expand Up @@ -1011,6 +1023,7 @@ def to_raster(
dtype=dtype,
crs=self.crs,
transform=self.transform(recalc=recalc_transform),
gcps=self.get_gcps(),
nodata=rio_nodata,
windowed=windowed,
lock=lock,
Expand Down
89 changes: 89 additions & 0 deletions rioxarray/rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import rasterio.windows
import xarray
from affine import Affine
from rasterio.control import GroundControlPoint
from rasterio.crs import CRS

from rioxarray._options import EXPORT_GRID_MAPPING, get_option
Expand Down Expand Up @@ -1078,3 +1079,91 @@ def transform_bounds(self, dst_crs, densify_pts=21, recalc=False):
return rasterio.warp.transform_bounds(
self.crs, dst_crs, *self.bounds(recalc=recalc), densify_pts=densify_pts
)

def write_gcps(self, gcps, gcp_crs, grid_mapping_name=None, inplace=False):
"""
Write the GroundControlPoints to the dataset.

https://rasterio.readthedocs.io/en/latest/topics/georeferencing.html#ground-control-points

Parameters
----------
gcp: list of :obj:`rasterio.control.GroundControlPoints`
The Ground Control Points to integrate to the dataset.
gcp_crs: str, :obj:`rasterio.crs.CRS`, or dict
Coordinate reference system for the GCPs.
grid_mapping_name: str, optional
Name of the grid_mapping coordinate to store the GCPs information in.
Default is the grid_mapping name of the dataset.
inplace: bool, optional
If True, it will write to the existing dataset. Default is False.

Returns
-------
:obj:`xarray.Dataset` | :obj:`xarray.DataArray`:
Modified dataset with Ground Control Points written.
"""
grid_mapping_name = (
self.grid_mapping if grid_mapping_name is None else grid_mapping_name
)
data_obj = self._get_obj(inplace=True)

data_obj = data_obj.rio.write_crs(
gcp_crs, grid_mapping_name=grid_mapping_name, inplace=inplace
)
geojson_gcps = _convert_gcps_to_geojson(gcps)
data_obj.coords[grid_mapping_name].attrs["gcps"] = geojson_gcps
return data_obj

def get_gcps(self):
"""
Get the GroundControlPoints from the dataset.

https://rasterio.readthedocs.io/en/latest/topics/georeferencing.html#ground-control-points

Returns
-------
list of :obj:`rasterio.control.GroundControlPoints` or None
The Ground Control Points from the dataset or None if not applicable
"""
try:
geojson_gcps = self._obj.coords[self.grid_mapping].attrs["gcps"]
except (KeyError, AttributeError):
return None

gcps = [
GroundControlPoint(
x=gcp["geometry"]["coordinates"][0],
y=gcp["geometry"]["coordinates"][1],
z=gcp["geometry"]["coordinates"][2],
row=gcp["properties"]["row"],
col=gcp["properties"]["col"],
id=gcp["properties"]["id"],
info=gcp["properties"]["info"],
)
for gcp in geojson_gcps["features"]
]
return gcps


def _convert_gcps_to_geojson(gcps):
snowman2 marked this conversation as resolved.
Show resolved Hide resolved
"""
Convert GCPs to geojson.

Parameters
----------
gcps: The list of GroundControlPoint instances.

Returns
-------
A FeatureCollection dict.
"""
features = [
dict(
type="Feature",
properties=dict(id=gcp.id, info=gcp.info, row=gcp.row, col=gcp.col),
geometry=dict(type="Point", coordinates=[gcp.x, gcp.y, gcp.z]),
)
for gcp in gcps
]
return dict(type="FeatureCollection", features=features)
57 changes: 56 additions & 1 deletion test/integration/test_integration__io.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
TEST_INPUT_DATA_DIR,
_assert_xarrays_equal,
)
from test.integration.test_integration_rioxarray import (
_check_rio_gcps,
_create_gdal_gcps,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -760,7 +764,7 @@ def test_ENVI_tags():


def test_no_mftime():
# rasterio can accept "filename" urguments that are actually urls,
# rasterio can accept "filename" arguments that are actually urls,
# including paths to remote files.
# In issue #1816, we found that these caused dask to break, because
# the modification time was used to determine the dask token. This
Expand Down Expand Up @@ -1236,3 +1240,54 @@ def test_cint16_promote_dtype(tmp_path):
assert "complex128" in riofh.dtypes
assert riofh.nodata == 0
assert data.dtype == "complex128"


def test_reading_gcps(tmp_path):
"""
Test reading gcps from a tiff file.
"""
tiffname = tmp_path / "test.tif"

gdal_gcps = _create_gdal_gcps()

with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=np.uint8,
driver="GTiff",
) as source:
source.gcps = gdal_gcps

with rioxarray.open_rasterio(tiffname) as darr:
_check_rio_gcps(darr, *gdal_gcps)


def test_writing_gcps(tmp_path):
"""
Test writing gcps to a tiff file.
"""
tiffname = tmp_path / "test.tif"
tiffname2 = tmp_path / "test_written.tif"

gdal_gcps = _create_gdal_gcps()

with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=np.uint8,
driver="GTiff",
) as source:
source.gcps = gdal_gcps

with rioxarray.open_rasterio(tiffname) as darr:
darr.rio.to_raster(tiffname2, driver="GTIFF")

with rioxarray.open_rasterio(tiffname2) as darr:
assert "gcps" in darr.coords["spatial_ref"].attrs
_check_rio_gcps(darr, *gdal_gcps)
110 changes: 110 additions & 0 deletions test/integration/test_integration_rioxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -2772,3 +2772,113 @@ def test_interpolate_na_missing_nodata():
test_da.rio.interpolate_na()
with pytest.raises(RioXarrayError, match=match):
test_da.to_dataset().rio.interpolate_na()


def test_rio_write_gcps():
"""
Test setting gcps in dataarray.
"""
gdal_gcps, gcp_crs = _create_gdal_gcps()

darr = xarray.DataArray(1)
darr.rio.write_gcps(gdal_gcps, gcp_crs, inplace=True)

_check_rio_gcps(darr, gdal_gcps, gcp_crs)


def _create_gdal_gcps():
src_gcps = [
GroundControlPoint(
row=0, col=0, x=0.0, y=0.0, z=12.0, id="1", info="the first gcp"
),
GroundControlPoint(
row=0, col=800, x=10.0, y=0.0, z=1.0, id="2", info="the second gcp"
),
GroundControlPoint(
row=800, col=800, x=10.0, y=10.0, z=3.5, id="3", info="the third gcp"
),
GroundControlPoint(
row=800, col=0, x=0.0, y=10.0, z=5.5, id="4", info="the fourth gcp"
),
]
crs = CRS.from_epsg(4326)
gdal_gcps = (src_gcps, crs)
return gdal_gcps


def _check_rio_gcps(darr, src_gcps, crs):
assert "x" not in darr.coords
assert "y" not in darr.coords
assert darr.rio.crs == crs
assert "gcps" in darr.spatial_ref.attrs
gcps = darr.spatial_ref.attrs["gcps"]
assert gcps["type"] == "FeatureCollection"
assert len(gcps["features"]) == len(src_gcps)
for feature, gcp in zip(gcps["features"], src_gcps):
assert feature["type"] == "Feature"
assert feature["properties"]["id"] == gcp.id
# info seems to be lost when rasterio writes?
# assert feature["properties"]["info"] == gcp.info
assert feature["properties"]["row"] == gcp.row
assert feature["properties"]["col"] == gcp.col
assert feature["geometry"]["type"] == "Point"
assert feature["geometry"]["coordinates"] == [gcp.x, gcp.y, gcp.z]


def test_rio_get_gcps():
"""
Test setting gcps in dataarray.
"""
gdal_gcps, gdal_crs = _create_gdal_gcps()

darr = xarray.DataArray(1)
darr.rio.write_gcps(gdal_gcps, gdal_crs, inplace=True)

gcps = darr.rio.get_gcps()
for gcp, gdal_gcp in zip(gcps, gdal_gcps):
assert gcp.row == gdal_gcp.row
assert gcp.col == gdal_gcp.col
assert gcp.x == gdal_gcp.x
assert gcp.y == gdal_gcp.y
assert gcp.z == gdal_gcp.z
assert gcp.id == gdal_gcp.id
assert gcp.info == gdal_gcp.info


def test_reproject__gcps_file(tmp_path):
tiffname = tmp_path / "test.tif"
src_gcps = [
GroundControlPoint(row=0, col=0, x=156113, y=2818720, z=0),
GroundControlPoint(row=0, col=800, x=338353, y=2785790, z=0),
GroundControlPoint(row=800, col=800, x=297939, y=2618518, z=0),
GroundControlPoint(row=800, col=0, x=115698, y=2651448, z=0),
]
crs = CRS.from_epsg(32618)
with rasterio.open(
tiffname,
mode="w",
height=800,
width=800,
count=3,
dtype=numpy.uint8,
driver="GTiff",
) as source:
source.gcps = (src_gcps, crs)

rds = rioxarray.open_rasterio(tiffname)
rds = rds.rio.reproject(
crs,
)
assert rds.rio.height == 923
assert rds.rio.width == 1027
assert rds.rio.crs == crs
assert rds.rio.transform().almost_equals(
Affine(
216.8587081056465,
0.0,
115698.25,
0.0,
-216.8587081056465,
2818720.0,
)
)