Skip to content

Commit

Permalink
remove multiscale option
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago committed Oct 30, 2024
1 parent d9ea7d2 commit 80f3350
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 237 deletions.
20 changes: 3 additions & 17 deletions src/titiler/xarray/tests/fixtures/generate_fixtures.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,9 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<xarray.backends.zarr.ZarrStore at 0x118525bc0>"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"outputs": [],
"source": [
"# 3D Zarr\n",
"arr = numpy.linspace(0, 1000, 1000 * 2000 * 2).reshape(2, 1000, 2000)\n",
Expand All @@ -122,10 +111,7 @@
"metadata": {},
"outputs": [],
"source": [
"import xarray\n",
"import numpy\n",
"from datetime import datetime\n",
"\n",
"# Zarr Pyramid\n",
"def create_dataset(decimation: int = 0):\n",
" dec = decimation or 1 # make sure we don't / by 0\n",
" width = 2000 // dec\n",
Expand Down
4 changes: 0 additions & 4 deletions src/titiler/xarray/tests/test_dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,3 @@ def tiles(
response = client.get("/tiles/1/2/3", params={"variable": "yo"})
params = response.json()
assert params == {"variable": "yo"}

response = client.get("/tiles/1/2/3", params={"multiscale": True})
params = response.json()
assert params == {"group": 1}
41 changes: 23 additions & 18 deletions src/titiler/xarray/tests/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,24 +164,6 @@ def test_tiles(filename, app):
assert resp.headers["content-type"] == "application/json"


# Test Multiscale (group == zoom level)
@pytest.mark.parametrize(
"group",
[0, 1, 2],
)
def test_tiles_multiscale(group, app):
"""Test /tiles endpoints."""
resp = app.get(
f"/md/tiles/WebMercatorQuad/{group}/0/0.tif",
params={"url": zarr_pyramid, "variable": "dataset", "multiscale": True},
)
assert resp.status_code == 200
with MemoryFile(resp.content) as mem:
with mem.open() as dst:
arr = dst.read(1)
assert arr.max() == group * 2


@pytest.mark.parametrize(
"filename",
[dataset_2d_nc, dataset_3d_nc, dataset_3d_zarr],
Expand Down Expand Up @@ -316,3 +298,26 @@ def test_part(filename, app):
)
assert resp.status_code == 200
assert resp.headers["content-type"] == "image/png"


@pytest.mark.parametrize(
"group",
[0, 1, 2],
)
def test_zarr_group(group, app):
"""Test /tiles endpoints."""
resp = app.get(
f"/md/tiles/WebMercatorQuad/{group}/0/0.tif",
params={"url": zarr_pyramid, "variable": "dataset", "group": group},
)
assert resp.status_code == 200
with MemoryFile(resp.content) as mem:
with mem.open() as dst:
arr = dst.read(1)
assert arr.max() == group * 2

resp = app.get(
"/md/point/0,0",
params={"url": zarr_pyramid, "variable": "dataset", "group": group},
)
assert resp.json()["values"] == [group * 2]
95 changes: 12 additions & 83 deletions src/titiler/xarray/titiler/xarray/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import numpy
from fastapi import Query
from rio_tiler.types import RIOResampling, WarpResampling
from starlette.requests import Request
from typing_extensions import Annotated

from titiler.core.dependencies import DefaultDependency
Expand Down Expand Up @@ -73,82 +72,25 @@ class XarrayParams(XarrayIOParams, XarrayDsParams):
pass


@dataclass(init=False)
class CompatXarrayParams(DefaultDependency):
@dataclass
class CompatXarrayParams(XarrayIOParams):
"""Custom XarrayParams endpoints.
This Dependency aims to be used in a tiler where both GDAL/Xarray dataset would be supported.
By default `variable` won't be required but when using an Xarray dataset,
it would fail without the variable query-parameter set.
"""

# File IO Options
group: Optional[int] = None
reference: Optional[bool] = None
decode_times: Optional[bool] = None
consolidated: Optional[bool] = None

# Dataset Options
variable: Optional[str] = None
drop_dim: Optional[str] = None
datetime: Optional[str] = None

def __init__(
self,
request: Request,
variable: Annotated[
Optional[str], Query(description="Xarray Variable name")
] = None,
group: Annotated[
Optional[int],
Query(
description="Select a specific zarr group from a zarr hierarchy. Could be associated with a zoom level or dataset."
),
] = None,
reference: Annotated[
Optional[bool],
Query(
title="reference",
description="Whether the dataset is a kerchunk reference",
),
] = None,
decode_times: Annotated[
Optional[bool],
Query(
title="decode_times",
description="Whether to decode times",
),
] = None,
consolidated: Annotated[
Optional[bool],
Query(
title="consolidated",
description="Whether to expect and open zarr store with consolidated metadata",
),
] = None,
drop_dim: Annotated[
Optional[str],
Query(description="Dimension to drop"),
] = None,
datetime: Annotated[
Optional[str], Query(description="Slice of time to read (if available)")
] = None,
):
"""Initialize XarrayIOParamsTiles
Note: Because we don't want `z and multi-scale` to appear in the documentation we use a dataclass with a custom `__init__` method.
FastAPI will use the `__init__` method but will exclude Request in the documentation making `pool` an invisible dependency.
"""
self.variable = variable
self.group = group
self.reference = reference
self.decode_times = decode_times
self.consolidated = consolidated
self.drop_dim = drop_dim
self.datetime = datetime

if request.query_params.get("multiscale") and request.path_params.get("z"):
self.group = int(request.path_params.get("z"))
variable: Annotated[Optional[str], Query(description="Xarray Variable name")] = None

drop_dim: Annotated[
Optional[str],
Query(description="Dimension to drop"),
] = None

datetime: Annotated[
Optional[str], Query(description="Slice of time to read (if available)")
] = None


@dataclass
Expand Down Expand Up @@ -176,19 +118,6 @@ def __post_init__(self):
self.nodata = numpy.nan if self.nodata == "nan" else float(self.nodata)


@dataclass
class TileParams(DefaultDependency):
"""Custom TileParams for Xarray."""

multiscale: Annotated[
Optional[bool],
Query(
title="multiscale",
description="Whether the dataset has multiscale groups (Zoom levels)",
),
] = None


# Custom PartFeatureParams which add `resampling`
@dataclass
class PartFeatureParams(DefaultDependency):
Expand Down
119 changes: 4 additions & 115 deletions src/titiler/xarray/titiler/xarray/factory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
"""TiTiler.xarray factory."""

from typing import Callable, List, Literal, Optional, Type, Union
from typing import Callable, List, Optional, Type, Union

import rasterio
from attrs import define, field
from fastapi import Body, Depends, Path, Query
from fastapi import Body, Depends, Query
from geojson_pydantic.features import Feature, FeatureCollection
from geojson_pydantic.geometries import MultiPolygon, Polygon
from pydantic import Field
from rio_tiler.constants import WGS84_CRS
from rio_tiler.models import Info
from starlette.responses import Response
from typing_extensions import Annotated

from titiler.core.dependencies import (
Expand All @@ -23,15 +21,11 @@
StatisticsParams,
)
from titiler.core.factory import TilerFactory as BaseTilerFactory
from titiler.core.factory import img_endpoint_params
from titiler.core.models.responses import InfoGeoJSON, StatisticsGeoJSON
from titiler.core.resources.enums import ImageType
from titiler.core.resources.responses import GeoJSONResponse, JSONResponse
from titiler.core.utils import render_image
from titiler.xarray.dependencies import (
DatasetParams,
PartFeatureParams,
TileParams,
XarrayIOParams,
XarrayParams,
)
Expand All @@ -54,8 +48,8 @@ class TilerFactory(BaseTilerFactory):
# Dataset Options (nodata, reproject)
dataset_dependency: Type[DefaultDependency] = DatasetParams

# Tile/Tilejson/WMTS Dependencies (multiscale option)
tile_dependency: Type[TileParams] = TileParams
# Tile/Tilejson/WMTS Dependencies (Not used in titiler.xarray)
tile_dependency: Type[DefaultDependency] = DefaultDependency

# Statistics/Histogram Dependencies
stats_dependency: Type[DefaultDependency] = StatisticsParams
Expand Down Expand Up @@ -185,111 +179,6 @@ def info_geojson(
properties=info,
)

# custom /tiles endpoints (adds `multiscale` options)
def tile(self):
"""Register /tiles endpoint."""

@self.router.get(r"/tiles/{tileMatrixSetId}/{z}/{x}/{y}", **img_endpoint_params)
@self.router.get(
r"/tiles/{tileMatrixSetId}/{z}/{x}/{y}.{format}", **img_endpoint_params
)
@self.router.get(
r"/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x", **img_endpoint_params
)
@self.router.get(
r"/tiles/{tileMatrixSetId}/{z}/{x}/{y}@{scale}x.{format}",
**img_endpoint_params,
)
def tile(
z: Annotated[
int,
Path(
description="Identifier (Z) selecting one of the scales defined in the TileMatrixSet and representing the scaleDenominator the tile.",
),
],
x: Annotated[
int,
Path(
description="Column (X) index of the tile on the selected TileMatrix. It cannot exceed the MatrixHeight-1 for the selected TileMatrix.",
),
],
y: Annotated[
int,
Path(
description="Row (Y) index of the tile on the selected TileMatrix. It cannot exceed the MatrixWidth-1 for the selected TileMatrix.",
),
],
tileMatrixSetId: Annotated[
Literal[tuple(self.supported_tms.list())],
Path(
description="Identifier selecting one of the TileMatrixSetId supported."
),
],
scale: Annotated[
int,
Field(
gt=0, le=4, description="Tile size scale. 1=256x256, 2=512x512..."
),
] = 1,
format: Annotated[
ImageType,
"Default will be automatically defined if the output image needs a mask (png) or not (jpeg).",
] = None,
multiscale: Annotated[
Optional[bool],
Query(
title="multiscale",
description="Whether the dataset has multiscale groups (Zoom levels)",
),
] = None,
src_path=Depends(self.path_dependency),
reader_params=Depends(self.reader_dependency),
tile_params=Depends(self.tile_dependency),
layer_params=Depends(self.layer_dependency),
dataset_params=Depends(self.dataset_dependency),
post_process=Depends(self.process_dependency),
rescale=Depends(self.rescale_dependency),
color_formula=Depends(self.color_formula_dependency),
colormap=Depends(self.colormap_dependency),
render_params=Depends(self.render_dependency),
env=Depends(self.environment_dependency),
):
"""Create map tile from a dataset."""
tms = self.supported_tms.get(tileMatrixSetId)

reader_options = reader_params.as_dict()
if getattr(tile_params, "multiscale", False):
reader_options["group"] = z

with rasterio.Env(**env):
with self.reader(src_path, tms=tms, **reader_options) as src_dst:
image = src_dst.tile(
x,
y,
z,
tilesize=scale * 256,
**layer_params.as_dict(),
**dataset_params.as_dict(),
)

if post_process:
image = post_process(image)

if rescale:
image.rescale(rescale)

if color_formula:
image.apply_color_formula(color_formula)

content, media_type = render_image(
image,
output_format=format,
colormap=colormap,
**render_params.as_dict(),
)

return Response(content, media_type=media_type)

# custom /statistics endpoints (remove /statistics - GET)
def statistics(self):
"""add statistics endpoints."""
Expand Down

0 comments on commit 80f3350

Please sign in to comment.