Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: aws auth typing and generic types #1486

Merged
merged 9 commits into from
Jan 17, 2025
Prev Previous commit
Next Next commit
style: generic list type for hinting
sbrunato committed Jan 17, 2025
commit 1b761703c65df6cf6ec6e356853ba6a1585a7758
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@
import re
from datetime import datetime
from importlib.metadata import metadata
from typing import Any, List
from typing import Any

# -- General configuration ------------------------------------------------

@@ -168,7 +168,7 @@
"custom.css",
]

html_js_files: List[Any] = []
html_js_files: list[Any] = []

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
32 changes: 16 additions & 16 deletions eodag/api/core.py
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@
import shutil
import tempfile
from operator import itemgetter
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
from typing import TYPE_CHECKING, Any, Iterator, Optional, Union

import geojson
import pkg_resources
@@ -565,7 +565,7 @@ def set_locations_conf(self, locations_conf_path: str) -> None:
main_locations_config = locations_config[main_key]

logger.info("Locations configuration loaded from %s" % locations_conf_path)
self.locations_config: List[dict[str, Any]] = main_locations_config
self.locations_config: list[dict[str, Any]] = main_locations_config
else:
logger.info(
"Could not load locations configuration from %s" % locations_conf_path
@@ -574,7 +574,7 @@ def set_locations_conf(self, locations_conf_path: str) -> None:

def list_product_types(
self, provider: Optional[str] = None, fetch_providers: bool = True
) -> List[dict[str, Any]]:
) -> list[dict[str, Any]]:
"""Lists supported product types.

:param provider: (optional) The name of a provider that must support the product
@@ -588,7 +588,7 @@ def list_product_types(
# First, update product types list if possible
self.fetch_product_types_list(provider=provider)

product_types: List[dict[str, Any]] = []
product_types: list[dict[str, Any]] = []

providers_configs = (
list(self.providers_config.values())
@@ -869,7 +869,7 @@ def update_product_types_list(
provider,
)
continue
new_product_types: List[str] = []
new_product_types: list[str] = []
for (
new_product_type,
new_product_type_conf,
@@ -932,7 +932,7 @@ def update_product_types_list(

def available_providers(
self, product_type: Optional[str] = None, by_group: bool = False
) -> List[str]:
) -> list[str]:
"""Gives the sorted list of the available providers or groups

The providers or groups are sorted first by their priority level in descending order,
@@ -1026,7 +1026,7 @@ def guess_product_type(
missionStartDate: Optional[str] = None,
missionEndDate: Optional[str] = None,
**kwargs: Any,
) -> List[str]:
) -> list[str]:
"""
Find EODAG product type IDs that best match a set of search parameters.

@@ -1084,7 +1084,7 @@ def guess_product_type(
query = p.parse(text)
results = searcher.search(query, limit=None)

guesses: List[dict[str, str]] = [dict(r) for r in results or []]
guesses: list[dict[str, str]] = [dict(r) for r in results or []]

# datetime filtering
if missionStartDate or missionEndDate:
@@ -1205,7 +1205,7 @@ def search(
items_per_page=items_per_page,
)

errors: List[tuple[str, Exception]] = []
errors: list[tuple[str, Exception]] = []
# Loop over available providers and return the first non-empty results
for i, search_plugin in enumerate(search_plugins):
search_plugin.clear()
@@ -1655,7 +1655,7 @@ def _prepare_search(
locations: Optional[dict[str, str]] = None,
provider: Optional[str] = None,
**kwargs: Any,
) -> tuple[List[Union[Search, Api]], dict[str, Any]]:
) -> tuple[list[Union[Search, Api]], dict[str, Any]]:
"""Internal method to prepare the search kwargs and get the search plugins.

Product query:
@@ -1763,7 +1763,7 @@ def _prepare_search(

preferred_provider = self.get_preferred_provider()[0]

search_plugins: List[Union[Search, Api]] = []
search_plugins: list[Union[Search, Api]] = []
for plugin in self._plugins_manager.get_search_plugins(
product_type=product_type, provider=provider
):
@@ -1833,10 +1833,10 @@ def _do_search(
max_items_per_page,
)

results: List[EOProduct] = []
results: list[EOProduct] = []
total_results: Optional[int] = 0 if count else None

errors: List[tuple[str, Exception]] = []
errors: list[tuple[str, Exception]] = []

try:
prep = PreparedSearch(count=count)
@@ -1984,7 +1984,7 @@ def crunch(self, results: SearchResult, **kwargs: Any) -> SearchResult:
return results

@staticmethod
def group_by_extent(searches: List[SearchResult]) -> List[SearchResult]:
def group_by_extent(searches: list[SearchResult]) -> list[SearchResult]:
"""Combines multiple SearchResults and return a list of SearchResults grouped
by extent (i.e. bounding box).

@@ -2015,7 +2015,7 @@ def download_all(
wait: float = DEFAULT_DOWNLOAD_WAIT,
timeout: float = DEFAULT_DOWNLOAD_TIMEOUT,
**kwargs: Unpack[DownloadConf],
) -> List[str]:
) -> list[str]:
"""Download all products resulting from a search.

:param search_result: A collection of EO products resulting from a search
@@ -2273,7 +2273,7 @@ def list_queryables(
properties, associating parameters to their annotated type, and a additional_properties attribute
"""
# only fetch providers if product type is not found
available_product_types: List[str] = [
available_product_types: list[str] = [
pt["ID"]
for pt in self.list_product_types(provider=provider, fetch_providers=False)
]
4 changes: 2 additions & 2 deletions eodag/api/product/_assets.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@

import re
from collections import UserDict
from typing import TYPE_CHECKING, Any, List, Optional
from typing import TYPE_CHECKING, Any, Optional

from eodag.utils.exceptions import NotAvailableError
from eodag.utils.repr import dict_to_html_table
@@ -56,7 +56,7 @@ def as_dict(self) -> dict[str, Any]:
"""
return {k: v.as_dict() for k, v in self.data.items()}

def get_values(self, asset_filter: str = "") -> List[Asset]:
def get_values(self, asset_filter: str = "") -> list[Asset]:
"""
retrieves the assets matching the given filter

58 changes: 24 additions & 34 deletions eodag/api/product/metadata_mapping.py
Original file line number Diff line number Diff line change
@@ -23,17 +23,7 @@
import re
from datetime import datetime, timedelta
from string import Formatter
from typing import (
TYPE_CHECKING,
Any,
AnyStr,
Callable,
Iterator,
List,
Optional,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, AnyStr, Callable, Iterator, Optional, Union, cast

import geojson
import orjson
@@ -86,8 +76,8 @@


def get_metadata_path(
map_value: Union[str, List[str]],
) -> tuple[Union[List[str], None], str]:
map_value: Union[str, list[str]],
) -> tuple[Union[list[str], None], str]:
"""Return the jsonpath or xpath to the value of a EO product metadata in a provider
search result.

@@ -135,12 +125,12 @@ def get_metadata_path(
return None, path


def get_metadata_path_value(map_value: Union[str, List[str]]) -> str:
def get_metadata_path_value(map_value: Union[str, list[str]]) -> str:
"""Get raw metadata path without converter"""
return map_value[1] if isinstance(map_value, list) else map_value


def get_search_param(map_value: List[str]) -> str:
def get_search_param(map_value: list[str]) -> str:
"""See :func:`~eodag.api.product.metadata_mapping.get_metadata_path`

:param map_value: The value originating from the definition of `metadata_mapping`
@@ -333,7 +323,7 @@ def convert_to_rounded_wkt(value: BaseGeometry) -> str:
return wkt_value

@staticmethod
def convert_to_bounds_lists(input_geom: BaseGeometry) -> List[List[float]]:
def convert_to_bounds_lists(input_geom: BaseGeometry) -> list[list[float]]:
if isinstance(input_geom, MultiPolygon):
geoms = [geom for geom in input_geom.geoms]
# sort with larger one at first (stac-browser only plots first one)
@@ -343,7 +333,7 @@ def convert_to_bounds_lists(input_geom: BaseGeometry) -> List[List[float]]:
return [list(input_geom.bounds[0:4])]

@staticmethod
def convert_to_bounds(input_geom_unformatted: Any) -> List[float]:
def convert_to_bounds(input_geom_unformatted: Any) -> list[float]:
input_geom = get_geometry_from_various(geometry=input_geom_unformatted)
if isinstance(input_geom, MultiPolygon):
geoms = [geom for geom in input_geom.geoms]
@@ -363,7 +353,7 @@ def convert_to_bounds(input_geom_unformatted: Any) -> List[float]:
return list(input_geom.bounds[0:4])

@staticmethod
def convert_to_nwse_bounds(input_geom: BaseGeometry) -> List[float]:
def convert_to_nwse_bounds(input_geom: BaseGeometry) -> list[float]:
if isinstance(input_geom, str):
input_geom = shapely.wkt.loads(input_geom)
return list(input_geom.bounds[-1:] + input_geom.bounds[:-1])
@@ -447,7 +437,7 @@ def flatten_elements(nested) -> Iterator[Any]:
else:
yield e

polygons_list: List[Polygon] = []
polygons_list: list[Polygon] = []
for elem in flatten_elements(georss[0]):
coords_list = elem.text.split()
polygon_args = [
@@ -512,8 +502,8 @@ def convert_replace_str(string: str, args: str) -> str:

@staticmethod
def convert_recursive_sub_str(
input_obj: Union[dict[Any, Any], List[Any]], args: str
) -> Union[dict[Any, Any], List[Any]]:
input_obj: Union[dict[Any, Any], list[Any]], args: str
) -> Union[dict[Any, Any], list[Any]]:
old, new = ast.literal_eval(args)
return items_recursive_apply(
input_obj,
@@ -615,7 +605,7 @@ def convert_s2msil2a_title_to_aws_productinfo(string: str) -> str:

@staticmethod
def convert_split_id_into_s1_params(product_id: str) -> dict[str, str]:
parts: List[str] = re.split(r"_(?!_)", product_id)
parts: list[str] = re.split(r"_(?!_)", product_id)
if len(parts) < 9:
logger.error(
"id %s does not match expected Sentinel-1 id format", product_id
@@ -650,7 +640,7 @@ def convert_split_id_into_s1_params(product_id: str) -> dict[str, str]:

@staticmethod
def convert_split_id_into_s3_params(product_id: str) -> dict[str, str]:
parts: List[str] = re.split(r"_(?!_)", product_id)
parts: list[str] = re.split(r"_(?!_)", product_id)
params = {"productType": product_id[4:15]}
dates = re.findall("[0-9]{8}T[0-9]{6}", product_id)
start_date = datetime.strptime(dates[0], "%Y%m%dT%H%M%S") - timedelta(
@@ -667,7 +657,7 @@ def convert_split_id_into_s3_params(product_id: str) -> dict[str, str]:

@staticmethod
def convert_split_id_into_s5p_params(product_id: str) -> dict[str, str]:
parts: List[str] = re.split(r"_(?!_)", product_id)
parts: list[str] = re.split(r"_(?!_)", product_id)
params = {
"productType": product_id[9:19],
"processingMode": parts[1],
@@ -684,7 +674,7 @@ def convert_split_id_into_s5p_params(product_id: str) -> dict[str, str]:
return params

@staticmethod
def convert_split_cop_dem_id(product_id: str) -> List[int]:
def convert_split_cop_dem_id(product_id: str) -> list[int]:
parts = product_id.split("_")
lattitude = parts[3]
longitude = parts[5]
@@ -723,7 +713,7 @@ def convert_dates_from_cmems_id(product_id: str):
@staticmethod
def convert_to_datetime_dict(
date: str, format: str
) -> dict[str, Union[List[str], str]]:
) -> dict[str, Union[list[str], str]]:
"""Convert a date (str) to a dictionary where values are in the format given in argument

date == "2021-04-21T18:27:19.123Z" and format == "list" => {
@@ -775,7 +765,7 @@ def convert_to_datetime_dict(
@staticmethod
def convert_interval_to_datetime_dict(
date: str, separator: str = "/"
) -> dict[str, List[str]]:
) -> dict[str, list[str]]:
"""Convert a date interval ('/' separated str) to a dictionary where values are lists

date == "2021-04-21/2021-04-22" => {
@@ -815,7 +805,7 @@ def convert_interval_to_datetime_dict(
}

@staticmethod
def convert_get_ecmwf_time(date: str) -> List[str]:
def convert_get_ecmwf_time(date: str) -> list[str]:
"""Get the time of a date (str) in the ECMWF format (["HH:00"])

"2021-04-21T18:27:19.123Z" => ["18:00"]
@@ -859,7 +849,7 @@ def convert_get_variables_from_path(path: str):

@staticmethod
def convert_assets_list_to_dict(
assets_list: List[dict[str, str]], asset_name_key: str = "title"
assets_list: list[dict[str, str]], asset_name_key: str = "title"
) -> dict[str, dict[str, str]]:
"""Convert a list of assets to a dictionary where keys represent
name of assets and are found among values of asset dictionaries.
@@ -887,7 +877,7 @@ def convert_assets_list_to_dict(
"asset3": {"href": "qux", "title": "qux-title", "name": "asset3"},
}
"""
asset_names: List[str] = []
asset_names: list[str] = []
assets_dict: dict[str, dict[str, str]] = {}

for asset in assets_list:
@@ -897,7 +887,7 @@ def convert_assets_list_to_dict(

# we only keep the equivalent of the path basename in the case where the
# asset name has a path pattern and this basename is only found once
immutable_asset_indexes: List[int] = []
immutable_asset_indexes: list[int] = []
for i, asset_name in enumerate(asset_names):
if i in immutable_asset_indexes:
continue
@@ -1479,7 +1469,7 @@ def _get_queryables(


def get_queryable_from_provider(
provider_queryable: str, metadata_mapping: dict[str, Union[str, List[str]]]
provider_queryable: str, metadata_mapping: dict[str, Union[str, list[str]]]
) -> Optional[str]:
"""Get EODAG configured queryable parameter from provider queryable parameter

@@ -1503,7 +1493,7 @@ def get_queryable_from_provider(


def get_provider_queryable_path(
queryable: str, metadata_mapping: dict[str, Union[str, List[str]]]
queryable: str, metadata_mapping: dict[str, Union[str, list[str]]]
) -> Optional[str]:
"""Get EODAG configured queryable path from its parameter

@@ -1521,7 +1511,7 @@ def get_provider_queryable_path(
def get_provider_queryable_key(
eodag_key: str,
provider_queryables: dict[str, Any],
metadata_mapping: dict[str, Union[List[Any], str]],
metadata_mapping: dict[str, Union[list[Any], str]],
) -> str:
"""Finds the provider queryable corresponding to the given eodag key based on the metadata mapping

Loading