diff --git a/.example.env b/.example.env index 520a1085..c3fb4b62 100644 --- a/.example.env +++ b/.example.env @@ -25,6 +25,7 @@ VEDA_RASTER_EXPORT_ASSUME_ROLE_CREDS_AS_ENVS=False VEDA_RASTER_ROOT_PATH= VEDA_STAC_ROOT_PATH= +VEDA_STAC_ENABLE_TRANSACTIONS=FALSE VEDA_USERPOOL_ID= VEDA_CLIENT_ID= diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index ccad6ebc..03b7fbad 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -77,9 +77,15 @@ jobs: - name: Install reqs for ingest api run: python -m pip install -r ingest_api/runtime/requirements_dev.txt + - name: Install veda auth for ingest api + run: python -m pip install common/auth + - name: Ingest unit tests run: NO_PYDANTIC_SSM_SETTINGS=1 python -m pytest ingest_api/runtime/tests/ -vv -s + # - name: Stac-api transactions unit tests + # run: python -m pytest stac_api/runtime/tests/ -vv -s + - name: Stop services run: docker compose stop diff --git a/common/__init__.py b/common/__init__.py new file mode 100644 index 00000000..bde69941 --- /dev/null +++ b/common/__init__.py @@ -0,0 +1 @@ +"""common utils shared by veda stacks""" diff --git a/common/auth/setup.py b/common/auth/setup.py new file mode 100644 index 00000000..a076c143 --- /dev/null +++ b/common/auth/setup.py @@ -0,0 +1,17 @@ +"""Setup veda_auth +""" + +from setuptools import find_packages, setup + +inst_reqs = ["cryptography>=42.0.5", "pyjwt>=2.8.0", "fastapi", "pydantic<2"] + +setup( + name="veda_auth", + version="0.0.1", + description="", + python_requires=">=3.7", + packages=find_packages(), + zip_safe=False, + install_requires=inst_reqs, + include_package_data=True, +) diff --git a/common/auth/veda_auth/__init__.py b/common/auth/veda_auth/__init__.py new file mode 100644 index 00000000..1948ba93 --- /dev/null +++ b/common/auth/veda_auth/__init__.py @@ -0,0 +1,5 @@ +""" + VEDA cognito auth +""" + +from veda_auth.main import VedaAuth # noqa: F401 diff --git a/common/auth/veda_auth/main.py b/common/auth/veda_auth/main.py new file mode 100644 index 00000000..ce19e65f --- /dev/null +++ b/common/auth/veda_auth/main.py @@ -0,0 +1,128 @@ +"""Authentication handler for veda.stac and veda.ingest""" + +import base64 +import hashlib +import hmac +import logging +from typing import Annotated, Any, Dict + +import boto3 +import jwt + +from fastapi import Depends, HTTPException, Security, security, status + +logger = logging.getLogger(__name__) + + +class VedaAuth: + """Class for handling authentication""" + + def __init__(self, settings) -> None: + """ + Args: + settings: pydantic settings object containing cognito details + Returns: + None + + """ + self.oauth2_scheme = security.OAuth2AuthorizationCodeBearer( + authorizationUrl=settings.cognito_authorization_url, + tokenUrl=settings.cognito_token_url, + refreshUrl=settings.cognito_token_url, + ) + + self.jwks_client = jwt.PyJWKClient(settings.jwks_url) # Caches JWKS + + def validated_token( + token_str: Annotated[str, Security(self.oauth2_scheme)], + required_scopes: security.SecurityScopes, + ) -> Dict: + # Parse & validate token + logger.info(f"\nToken String {token_str}") + try: + token = jwt.decode( + token_str, + self.jwks_client.get_signing_key_from_jwt(token_str).key, + algorithms=["RS256"], + ) + except jwt.exceptions.InvalidTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) from e + + # Validate scopes (if required) + for scope in required_scopes.scopes: + if scope not in token["scope"]: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Not enough permissions", + headers={ + "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' + }, + ) + + return token + + self.validated_token = validated_token + + def get_username( + token: Annotated[Dict[Any, Any], Depends(self.validated_token)] + ) -> str: + result = token["username"] if "username" in token else str(token.get("sub")) + return result + + self.get_username = get_username + + def _get_secret_hash( + self, username: str, client_id: str, client_secret: str + ) -> str: + # A keyed-hash message authentication code (HMAC) calculated using + # the secret key of a user pool client and username plus the client + # ID in the message. + message = username + client_id + dig = hmac.new( + bytearray(client_secret, "utf-8"), + msg=message.encode("UTF-8"), + digestmod=hashlib.sha256, + ).digest() + return base64.b64encode(dig).decode() + + def authenticate_and_get_token( + self, + username: str, + password: str, + user_pool_id: str, + app_client_id: str, + app_client_secret: str, + ) -> Dict: + """Authenticates the credentials and returns token""" + client = boto3.client("cognito-idp") + if app_client_secret: + auth_params = { + "USERNAME": username, + "PASSWORD": password, + "SECRET_HASH": self._get_secret_hash( + username, app_client_id, app_client_secret + ), + } + else: + auth_params = { + "USERNAME": username, + "PASSWORD": password, + } + try: + resp = client.admin_initiate_auth( + UserPoolId=user_pool_id, + ClientId=app_client_id, + AuthFlow="ADMIN_USER_PASSWORD_AUTH", + AuthParameters=auth_params, + ) + except client.exceptions.NotAuthorizedException: + return { + "message": "Login failed, please make sure the credentials are correct." + } + except Exception as e: + return {"message": f"Login failed with exception {e}"} + return resp["AuthenticationResult"] diff --git a/ingest_api/infrastructure/construct.py b/ingest_api/infrastructure/construct.py index 1e763d09..69b3265c 100644 --- a/ingest_api/infrastructure/construct.py +++ b/ingest_api/infrastructure/construct.py @@ -107,12 +107,6 @@ def __init__( value=self.api.url, ) - register_ssm_parameter( - self, - name="jwks_url", - value=self.jwks_url, - description="JWKS URL for Cognito user pool", - ) register_ssm_parameter( self, name="dynamodb_table", diff --git a/ingest_api/runtime/Dockerfile b/ingest_api/runtime/Dockerfile index c2955171..201c48ce 100644 --- a/ingest_api/runtime/Dockerfile +++ b/ingest_api/runtime/Dockerfile @@ -2,6 +2,10 @@ FROM public.ecr.aws/sam/build-python3.9:latest WORKDIR /tmp +COPY common/auth /tmp/common/auth +RUN pip install /tmp/common/auth -t /asset +RUN rm -rf /tmp/common + COPY ingest_api/runtime/requirements.txt /tmp/ingestor/requirements.txt RUN pip install -r /tmp/ingestor/requirements.txt -t /asset --no-binary pydantic uvicorn RUN rm -rf /tmp/ingestor diff --git a/ingest_api/runtime/requirements.txt b/ingest_api/runtime/requirements.txt index de984f61..66537327 100644 --- a/ingest_api/runtime/requirements.txt +++ b/ingest_api/runtime/requirements.txt @@ -8,7 +8,6 @@ orjson>=3.6.8 psycopg[binary,pool]>=3.0.15 pydantic_ssm_settings>=0.2.0 pydantic>=1.10.12 -pyjwt>=2.8.0 pypgstac==0.7.4 python-multipart==0.0.7 requests>=2.27.1 diff --git a/ingest_api/runtime/src/auth.py b/ingest_api/runtime/src/auth.py deleted file mode 100644 index 294bec64..00000000 --- a/ingest_api/runtime/src/auth.py +++ /dev/null @@ -1,106 +0,0 @@ -import base64 -import hashlib -import hmac -import logging -from typing import Annotated, Any, Dict - -import boto3 -import jwt -from src.config import settings - -from fastapi import Depends, HTTPException, Security, security, status - -logger = logging.getLogger(__name__) - -oauth2_scheme = security.OAuth2AuthorizationCodeBearer( - authorizationUrl=settings.cognito_authorization_url, - tokenUrl=settings.cognito_token_url, - refreshUrl=settings.cognito_token_url, -) - -jwks_client = jwt.PyJWKClient(settings.jwks_url) # Caches JWKS - - -def validated_token( - token_str: Annotated[str, Security(oauth2_scheme)], - required_scopes: security.SecurityScopes, -) -> Dict: - # Parse & validate token - try: - token = jwt.decode( - token_str, - jwks_client.get_signing_key_from_jwt(token_str).key, - algorithms=["RS256"], - ) - except jwt.exceptions.InvalidTokenError as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) from e - - # Validate scopes (if required) - for scope in required_scopes.scopes: - if scope not in token["scope"]: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Not enough permissions", - headers={ - "WWW-Authenticate": f'Bearer scope="{required_scopes.scope_str}"' - }, - ) - - return token - - -def get_username(token: Annotated[Dict[Any, Any], Depends(validated_token)]) -> str: - result = token["username"] if "username" in token else str(token.get("sub")) - return result - - -def _get_secret_hash(username: str, client_id: str, client_secret: str) -> str: - # A keyed-hash message authentication code (HMAC) calculated using - # the secret key of a user pool client and username plus the client - # ID in the message. - message = username + client_id - dig = hmac.new( - bytearray(client_secret, "utf-8"), - msg=message.encode("UTF-8"), - digestmod=hashlib.sha256, - ).digest() - return base64.b64encode(dig).decode() - - -def authenticate_and_get_token( - username: str, - password: str, - user_pool_id: str, - app_client_id: str, - app_client_secret: str, -) -> Dict: - client = boto3.client("cognito-idp") - if app_client_secret: - auth_params = { - "USERNAME": username, - "PASSWORD": password, - "SECRET_HASH": _get_secret_hash(username, app_client_id, app_client_secret), - } - else: - auth_params = { - "USERNAME": username, - "PASSWORD": password, - } - try: - resp = client.admin_initiate_auth( - UserPoolId=user_pool_id, - ClientId=app_client_id, - AuthFlow="ADMIN_USER_PASSWORD_AUTH", - AuthParameters=auth_params, - ) - except client.exceptions.NotAuthorizedException: - return { - "message": "Login failed, please make sure the credentials are correct." - } - except Exception as e: - return {"message": f"Login failed with exception {e}"} - return resp["AuthenticationResult"] diff --git a/ingest_api/runtime/src/config.py b/ingest_api/runtime/src/config.py index ff51f680..0b9e1c0b 100644 --- a/ingest_api/runtime/src/config.py +++ b/ingest_api/runtime/src/config.py @@ -4,6 +4,7 @@ from pydantic import AnyHttpUrl, BaseSettings, Field, constr from pydantic_ssm_settings import AwsSsmSourceConfig +from veda_auth import VedaAuth AwsArn = constr(regex=r"^arn:aws:iam::\d{12}:role/.+") @@ -63,3 +64,5 @@ def from_ssm(cls, stack: str): ), ) ) + +auth = VedaAuth(settings) diff --git a/ingest_api/runtime/src/dependencies.py b/ingest_api/runtime/src/dependencies.py index 901f338c..5c772c1f 100644 --- a/ingest_api/runtime/src/dependencies.py +++ b/ingest_api/runtime/src/dependencies.py @@ -1,9 +1,8 @@ import logging import boto3 -import src.auth as auth -import src.config as config import src.services as services +from src.config import auth, settings from fastapi import Depends, HTTPException, security @@ -14,7 +13,7 @@ def get_table(): client = boto3.resource("dynamodb") - return client.Table(config.settings.dynamodb_table) + return client.Table(settings.dynamodb_table) def get_db(table=Depends(get_table)) -> services.Database: diff --git a/ingest_api/runtime/src/main.py b/ingest_api/runtime/src/main.py index 9b5fa6d2..8074c860 100644 --- a/ingest_api/runtime/src/main.py +++ b/ingest_api/runtime/src/main.py @@ -1,12 +1,11 @@ from typing import Dict -import src.auth as auth import src.dependencies as dependencies import src.schemas as schemas import src.services as services from aws_lambda_powertools.metrics import MetricUnit from src.collection_publisher import CollectionPublisher, ItemPublisher -from src.config import settings +from src.config import auth, settings from src.doc import DESCRIPTION from src.monitoring import LoggerRouteHandler, logger, metrics, tracer diff --git a/local/Dockerfile.ingest b/local/Dockerfile.ingest index bcce1482..ccc5741e 100644 --- a/local/Dockerfile.ingest +++ b/local/Dockerfile.ingest @@ -7,6 +7,8 @@ RUN pip install -r /tmp/ingestor/requirements.txt --no-binary pydantic uvicorn RUN rm -rf /tmp/ingestor # TODO this is temporary until we use a real packaging system like setup.py or poetry COPY ingest_api/runtime/src /asset/src +COPY common/auth /tmp/common/auth +RUN pip install /tmp/common/auth # # Reduce package size and remove useless files RUN cd /asset && find . -type f -name '*.pyc' | while read f; do n=$(echo $f | sed 's/__pycache__\///' | sed 's/.cpython-[2-3][0-9]//'); cp $f $n; done; diff --git a/local/Dockerfile.stac b/local/Dockerfile.stac index 4a3b1af1..71463001 100644 --- a/local/Dockerfile.stac +++ b/local/Dockerfile.stac @@ -4,12 +4,14 @@ FROM ghcr.io/vincentsarago/uvicorn-gunicorn:${PYTHON_VERSION} ENV CURL_CA_BUNDLE /etc/ssl/certs/ca-certificates.crt -RUN pip install boto3 - -COPY stac_api/runtime /tmp/stac # Installing boto3, which isn't needed in the lambda container instance # since lambda execution environment includes boto3 by default RUN pip install boto3 + +COPY stac_api/runtime /tmp/stac + +COPY common/auth /tmp/stac/common/auth +RUN pip install /tmp/stac/common/auth RUN pip install /tmp/stac RUN rm -rf /tmp/stac diff --git a/stac_api/infrastructure/config.py b/stac_api/infrastructure/config.py index 65c04eda..2196a430 100644 --- a/stac_api/infrastructure/config.py +++ b/stac_api/infrastructure/config.py @@ -2,7 +2,7 @@ from typing import Dict, Optional -from pydantic import BaseSettings, Field +from pydantic import AnyHttpUrl, BaseSettings, Field, root_validator class vedaSTACSettings(BaseSettings): @@ -44,6 +44,37 @@ class vedaSTACSettings(BaseSettings): description="Description of the STAC Catalog", ) + userpool_id: Optional[str] = Field( + description="The Cognito Userpool used for authentication" + ) + cognito_domain: Optional[AnyHttpUrl] = Field( + description="The base url of the Cognito domain for authorization and token urls" + ) + client_id: Optional[str] = Field(description="The Cognito APP client ID") + client_secret: Optional[str] = Field( + "", description="The Cognito APP client secret" + ) + stac_enable_transactions: bool = Field( + False, description="Whether to enable transactions endpoints" + ) + + @root_validator + def check_transaction_fields(cls, values): + """ + Validates the existence of auth env vars in case stac_enable_transactions is True + """ + if values.get("stac_enable_transactions"): + missing_fields = [ + field + for field in ["userpool_id", "cognito_domain", "client_id"] + if not values.get(field) + ] + if missing_fields: + raise ValueError( + f"When 'stac_enable_transactions' is True, the following fields must be provided: {', '.join(missing_fields)}" + ) + return values + class Config: """model config""" diff --git a/stac_api/infrastructure/construct.py b/stac_api/infrastructure/construct.py index 272cddca..95be8f23 100644 --- a/stac_api/infrastructure/construct.py +++ b/stac_api/infrastructure/construct.py @@ -43,6 +43,22 @@ def __init__( # TODO config stack_name = Stack.of(self).stack_name + lambda_env = { + "VEDA_STAC_PROJECT_NAME": veda_stac_settings.project_name, + "VEDA_STAC_PROJECT_DESCRIPTION": veda_stac_settings.project_description, + "VEDA_STAC_ROOT_PATH": veda_stac_settings.stac_root_path, + "VEDA_STAC_STAGE": stage, + "VEDA_STAC_USERPOOL_ID": veda_stac_settings.userpool_id, + "VEDA_STAC_CLIENT_ID": veda_stac_settings.client_id, + "VEDA_STAC_COGNITO_DOMAIN": veda_stac_settings.cognito_domain, + "VEDA_STAC_ENABLE_TRANSACTIONS": str( + veda_stac_settings.stac_enable_transactions + ), + "DB_MIN_CONN_SIZE": "0", + "DB_MAX_CONN_SIZE": "1", + **{k.upper(): v for k, v in veda_stac_settings.env.items()}, + } + lambda_function = aws_lambda.Function( self, "lambda", @@ -56,15 +72,7 @@ def __init__( allow_public_subnet=True, memory_size=veda_stac_settings.memory, timeout=Duration.seconds(veda_stac_settings.timeout), - environment={ - **{k.upper(): v for k, v in veda_stac_settings.env.items()}, - "DB_MIN_CONN_SIZE": "0", - "DB_MAX_CONN_SIZE": "1", - "VEDA_STAC_ROOT_PATH": veda_stac_settings.stac_root_path, - "VEDA_STAC_STAGE": stage, - "VEDA_STAC_PROJECT_NAME": veda_stac_settings.project_name, - "VEDA_STAC_PROJECT_DESCRIPTION": veda_stac_settings.project_description, - }, + environment=lambda_env, log_retention=aws_logs.RetentionDays.ONE_WEEK, tracing=aws_lambda.Tracing.ACTIVE, ) diff --git a/stac_api/runtime/Dockerfile b/stac_api/runtime/Dockerfile index 5dcae19c..c8a8c852 100644 --- a/stac_api/runtime/Dockerfile +++ b/stac_api/runtime/Dockerfile @@ -3,7 +3,10 @@ FROM --platform=linux/amd64 public.ecr.aws/sam/build-python3.9:latest WORKDIR /tmp COPY stac_api/runtime /tmp/stac + RUN pip install "mangum>=0.14,<0.15" "plpygis>=0.2.1" /tmp/stac -t /asset --no-binary pydantic +COPY common/auth /tmp/stac/common/auth +RUN pip install /tmp/stac/common/auth -t /asset RUN rm -rf /tmp/stac # Reduce package size and remove useless files diff --git a/stac_api/runtime/handler.py b/stac_api/runtime/handler.py index b04844cf..176be13b 100644 --- a/stac_api/runtime/handler.py +++ b/stac_api/runtime/handler.py @@ -4,11 +4,8 @@ from mangum import Mangum from src.app import app -from src.config import ApiSettings from src.monitoring import logger, metrics, tracer -settings = ApiSettings() - logging.getLogger("mangum.lifespan").setLevel(logging.ERROR) logging.getLogger("mangum.http").setLevel(logging.ERROR) diff --git a/stac_api/runtime/setup.py b/stac_api/runtime/setup.py index 6dfa0ff0..604e25dd 100644 --- a/stac_api/runtime/setup.py +++ b/stac_api/runtime/setup.py @@ -18,6 +18,7 @@ "pygeoif<=0.8", # newest release (1.0+ / 09-22-2022) breaks a number of other geo libs "aws-lambda-powertools>=1.18.0", "aws_xray_sdk>=2.6.0,<3", + "pystac[validation]==1.10.1", ] extra_reqs = { diff --git a/stac_api/runtime/src/app.py b/stac_api/runtime/src/app.py index 0e30fa82..61dac75c 100644 --- a/stac_api/runtime/src/app.py +++ b/stac_api/runtime/src/app.py @@ -1,14 +1,16 @@ """FastAPI application using PGStac. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac """ + from aws_lambda_powertools.metrics import MetricUnit -from src.config import ApiSettings, TilesApiSettings +from src.config import TilesApiSettings, api_settings from src.config import extensions as PgStacExtensions from src.config import get_request_model as GETModel from src.config import post_request_model as POSTModel from src.extension import TiTilerExtension from fastapi import APIRouter, FastAPI +from fastapi.params import Depends from fastapi.responses import ORJSONResponse from stac_fastapi.pgstac.db import close_db_connection, connect_to_db from starlette.middleware.cors import CORSMiddleware @@ -20,6 +22,8 @@ from .api import VedaStacApi from .core import VedaCrudClient from .monitoring import LoggerRouteHandler, logger, metrics, tracer +from .routes import add_route_dependencies +from .validation import ValidationMiddleware try: from importlib.resources import files as resources_files # type: ignore @@ -30,7 +34,6 @@ templates = Jinja2Templates(directory=str(resources_files(__package__) / "templates")) # type: ignore -api_settings = ApiSettings() tiles_settings = TilesApiSettings() api = VedaStacApi( @@ -39,6 +42,15 @@ openapi_url="/openapi.json", docs_url="/docs", root_path=api_settings.root_path, + swagger_ui_init_oauth=( + { + "appName": "Cognito", + "clientId": api_settings.client_id, + "usePkceWithAuthorizationCodeGrant": True, + } + if api_settings.client_id + else {} + ), ), title=f"{api_settings.project_name} STAC API", description=api_settings.project_description, @@ -48,7 +60,7 @@ search_get_request_model=GETModel, search_post_request_model=POSTModel, response_class=ORJSONResponse, - middlewares=[CompressionMiddleware], + middlewares=[CompressionMiddleware, ValidationMiddleware], router=APIRouter(route_class=LoggerRouteHandler), ) app = api.app @@ -62,10 +74,45 @@ CORSMiddleware, allow_origins=api_settings.cors_origins, allow_credentials=True, - allow_methods=["GET", "POST", "OPTIONS"], + allow_methods=["GET", "POST", "PUT", "OPTIONS"], allow_headers=["*"], ) +if api_settings.enable_transactions: + from veda_auth import VedaAuth + + auth = VedaAuth(api_settings) + # Require auth for all endpoints that create, modify or delete data. + add_route_dependencies( + app.router.routes, + [ + {"path": "/collections", "method": "POST", "type": "http"}, + {"path": "/collections/{collectionId}", "method": "PUT", "type": "http"}, + {"path": "/collections/{collectionId}", "method": "DELETE", "type": "http"}, + { + "path": "/collections/{collectionId}/items", + "method": "POST", + "type": "http", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "PUT", + "type": "http", + }, + { + "path": "/collections/{collectionId}/items/{itemId}", + "method": "DELETE", + "type": "http", + }, + { + "path": "/collections/{collectionId}/bulk_items", + "method": "POST", + "type": "http", + }, + ], + [Depends(auth.validated_token)], + ) + if tiles_settings.titiler_endpoint: # Register to the TiTiler extension to the api extension = TiTilerExtension() diff --git a/stac_api/runtime/src/config.py b/stac_api/runtime/src/config.py index 0a70e13f..78eccf94 100644 --- a/stac_api/runtime/src/config.py +++ b/stac_api/runtime/src/config.py @@ -1,13 +1,15 @@ """API settings. Based on https://github.com/developmentseed/eoAPI/tree/master/src/eoapi/stac""" + import base64 import json from functools import lru_cache from typing import Optional import boto3 -import pydantic +from pydantic import AnyHttpUrl, BaseSettings, Field, root_validator, validator +from fastapi.responses import ORJSONResponse from stac_fastapi.api.models import create_get_request_model, create_post_request_model # from stac_fastapi.pgstac.extensions import QueryExtension @@ -18,8 +20,11 @@ QueryExtension, SortExtension, TokenPaginationExtension, + TransactionExtension, ) +from stac_fastapi.extensions.third_party import BulkTransactionExtension from stac_fastapi.pgstac.config import Settings +from stac_fastapi.pgstac.transactions import BulkTransactionsClient, TransactionsClient from stac_fastapi.pgstac.types.search import PgstacSearch @@ -47,7 +52,7 @@ def get_secret_dict(secret_name: str): return json.loads(base64.b64decode(get_secret_value_response["SecretBinary"])) -class _ApiSettings(pydantic.BaseSettings): +class _ApiSettings(BaseSettings): """API settings""" project_name: Optional[str] = "veda" @@ -59,7 +64,54 @@ class _ApiSettings(pydantic.BaseSettings): pgstac_secret_arn: Optional[str] stage: Optional[str] = None - @pydantic.validator("cors_origins") + userpool_id: Optional[str] = Field( + "", description="The Cognito Userpool used for authentication" + ) + cognito_domain: Optional[AnyHttpUrl] = Field( + description="The base url of the Cognito domain for authorization and token urls" + ) + client_id: Optional[str] = Field(description="The Cognito APP client ID") + client_secret: Optional[str] = Field( + "", description="The Cognito APP client secret" + ) + enable_transactions: bool = Field( + False, description="Whether to enable transactions" + ) + + @root_validator + def check_transaction_fields(cls, values): + enable_transactions = values.get("enable_transactions") + + if enable_transactions: + missing_fields = [ + field + for field in ["userpool_id", "cognito_domain", "client_id"] + if not values.get(field) + ] + if missing_fields: + raise ValueError( + f"When 'enable_transactions' is True, the following fields must be provided: {', '.join(missing_fields)}" + ) + return values + + @property + def jwks_url(self) -> AnyHttpUrl: + """JWKS url""" + if self.userpool_id: + region = self.userpool_id.split("_")[0] + return f"https://cognito-idp.{region}.amazonaws.com/{self.userpool_id}/.well-known/jwks.json" + + @property + def cognito_authorization_url(self) -> AnyHttpUrl: + """Cognito user pool authorization url""" + return f"{self.cognito_domain}/oauth2/authorize" + + @property + def cognito_token_url(self) -> AnyHttpUrl: + """Cognito user pool token and refresh url""" + return f"{self.cognito_domain}/oauth2/token" + + @validator("cors_origins") def parse_cors_origin(cls, v): """Parse CORS origins.""" return [origin.strip() for origin in v.split(",")] @@ -101,7 +153,10 @@ def ApiSettings() -> _ApiSettings: return _ApiSettings() -class _TilesApiSettings(pydantic.BaseSettings): +api_settings = ApiSettings() + + +class _TilesApiSettings(BaseSettings): """Tile API settings""" titiler_endpoint: Optional[str] @@ -123,12 +178,24 @@ def TilesApiSettings() -> _TilesApiSettings: extensions = [ + ContextExtension(), + FieldsExtension(), FilterExtension(), QueryExtension(), SortExtension(), - FieldsExtension(), TokenPaginationExtension(), - ContextExtension(), ] + +if api_settings.enable_transactions: + extensions.extend( + [ + BulkTransactionExtension(client=BulkTransactionsClient()), + TransactionExtension( + client=TransactionsClient(), + settings=ApiSettings().load_postgres_settings(), + response_class=ORJSONResponse, + ), + ] + ) post_request_model = create_post_request_model(extensions, base_model=PgstacSearch) get_request_model = create_get_request_model(extensions) diff --git a/stac_api/runtime/src/routes.py b/stac_api/runtime/src/routes.py new file mode 100644 index 00000000..ae0bae08 --- /dev/null +++ b/stac_api/runtime/src/routes.py @@ -0,0 +1,25 @@ +"""Dependency injection in to fastapi routes""" + +from typing import List + +from fastapi.dependencies.utils import get_parameterless_sub_dependant +from fastapi.params import Depends +from fastapi.routing import APIRoute +from starlette.routing import Match +from starlette.types import Scope + + +def add_route_dependencies( + routes: List[APIRoute], scopes: List[Scope], dependencies: List[Depends] +): + """Inject dependencies to routes""" + for route in routes: + if not any(route.matches(scope)[0] == Match.FULL for scope in scopes): + continue + + route.dependant.dependencies = [ + # Mimicking how APIRoute handles dependencies: + # https://github.com/tiangolo/fastapi/blob/1760da0efa55585c19835d81afa8ca386036c325/fastapi/routing.py#L408-L412 + get_parameterless_sub_dependant(depends=depends, path=route.path_format) + for depends in dependencies + ] + route.dependant.dependencies diff --git a/stac_api/runtime/src/validation.py b/stac_api/runtime/src/validation.py new file mode 100644 index 00000000..9f429e3c --- /dev/null +++ b/stac_api/runtime/src/validation.py @@ -0,0 +1,60 @@ +"""Middleware for validating transaction endpoints""" + +import json +import re +from typing import Dict + +from pydantic import BaseModel, Field +from pystac import STACObjectType +from pystac.errors import STACValidationError +from pystac.validation import validate_dict +from src.config import api_settings + +from fastapi import Request +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware + +path_prefix = api_settings.root_path or "" + + +class BulkItems(BaseModel): + """Validation model for bulk-items endpoint request""" + + items: Dict[str, dict] + method: str = Field(default="insert") + + +class ValidationMiddleware(BaseHTTPMiddleware): + """Middleware that handles STAC collection and item validation in transaction endpoints""" + + async def dispatch(self, request: Request, call_next): + """Middleware dispatch""" + if request.method in ("POST", "PUT"): + try: + body = await request.body() + request_data = json.loads(body) + if re.match( + f"^{path_prefix}/collections(?:/[^/]+)?$", + request.url.path, + ): + validate_dict(request_data, STACObjectType.COLLECTION) + elif re.match( + f"^{path_prefix}/collections/[^/]+/items(?:/[^/]+)?$", + request.url.path, + ): + validate_dict(request_data, STACObjectType.ITEM) + elif re.match( + f"^{path_prefix}/collections/[^/]+/bulk-items$", + request.url.path, + ): + bulk_items = BulkItems(**request_data) + for item_data in bulk_items.items.values(): + validate_dict(item_data, STACObjectType.ITEM) + except STACValidationError as e: + return JSONResponse( + status_code=422, + content={"detail": "Validation Error", "errors": str(e)}, + ) + + response = await call_next(request) + return response diff --git a/stac_api/runtime/tests/__init__.py b/stac_api/runtime/tests/__init__.py new file mode 100644 index 00000000..2d9078d5 --- /dev/null +++ b/stac_api/runtime/tests/__init__.py @@ -0,0 +1 @@ +"""STAC API tests""" diff --git a/stac_api/runtime/tests/conftest.py b/stac_api/runtime/tests/conftest.py new file mode 100644 index 00000000..e584bde6 --- /dev/null +++ b/stac_api/runtime/tests/conftest.py @@ -0,0 +1,338 @@ +""" +Test fixtures and data for STAC Transactions API testing. + +This module contains fixtures and mock data used for testing the STAC API. +It includes valid and invalid STAC collections and items, as well as environment +setup for testing with mock AWS and PostgreSQL configurations. +""" + +import os + +import pytest + +from fastapi.testclient import TestClient + +VALID_COLLECTION = { + "id": "CMIP245-winter-median-pr", + "type": "Collection", + "title": "Projected changes to winter (January, February, and March) cumulative daily precipitation", + "links": [], + "description": "Differences in winter (January, February, and March) cumulative daily precipitation between a historical period (1995 - 2014) and multiple 20-year periods from an ensemble of CMIP6 climate projections (SSP2-4.5) downscaled by NASA Earth Exchange (NEX-GDDP-CMIP6)", + "extent": { + "spatial": {"bbox": [[-126, 30, -104, 51]]}, + "temporal": {"interval": [["2025-01-01T00:00:00Z", "2085-03-31T12:00:00Z"]]}, + }, + "license": "MIT", + "stac_extensions": [ + "https://stac-extensions.github.io/render/v1.0.0/schema.json", + "https://stac-extensions.github.io/item-assets/v1.0.0/schema.json", + ], + "item_assets": { + "cog_default": { + "type": "image/tiff; application=geotiff; profile=cloud-optimized", + "roles": ["data", "layer"], + "title": "Default COG Layer", + "description": "Cloud optimized default layer to display on map", + } + }, + "dashboard:is_periodic": False, + "dashboard:time_density": "year", + "stac_version": "1.0.0", + "renders": { + "dashboard": { + "resampling": "bilinear", + "bidx": [1], + "nodata": "nan", + "colormap_name": "rdbu", + "rescale": [[-60, 60]], + "assets": ["cog_default"], + "title": "VEDA Dashboard Render Parameters", + } + }, + "providers": [ + { + "name": "NASA Center for Climate Simulation (NCCS)", + "url": "https://www.nccs.nasa.gov/services/data-collections/land-based-products/nex-gddp-cmip6", + "roles": ["producer", "processor", "licensor"], + }, + { + "name": "NASA VEDA", + "url": "https://www.earthdata.nasa.gov/dashboard/", + "roles": ["host"], + }, + ], + "assets": { + "thumbnail": { + "title": "Thumbnail", + "description": "Photo by Justin Pflug (Photo of Nisqually glacier)", + "href": "https://thumbnails.openveda.cloud/CMIP-winter-median.jpeg", + "type": "image/jpeg", + "roles": ["thumbnail"], + } + }, +} + +VALID_ITEM = { + "id": "OMI_trno2_0.10x0.10_2023_Col3_V4", + "bbox": [-180.0, -90.0, 180.0, 90.0], + "type": "Feature", + "links": [ + { + "rel": "collection", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr", + }, + { + "rel": "parent", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr", + }, + { + "rel": "root", + "type": "application/json", + "href": "https://dev.openveda.cloud/api/stac/", + }, + { + "rel": "self", + "type": "application/geo+json", + "href": "https://dev.openveda.cloud/api/stac/collections/CMIP245-winter-median-pr/items/OMI_trno2_0.10x0.10_2023_Col3_V4", + }, + { + "title": "Map of Item", + "href": "https://dev.openveda.cloud/api/raster/stac/map?collection=CMIP245-winter-median-pr&item=OMI_trno2_0.10x0.10_2023_Col3_V4&assets=cog_default&rescale=0%2C3000000000000000&colormap_name=reds", + "rel": "preview", + "type": "text/html", + }, + ], + "assets": { + "no2": { + "href": "s3://veda-data-store-staging/OMI_trno2-COG/OMI_trno2_0.10x0.10_2023_Col3_V4.tif", + "type": "image/tiff; application=geotiff", + "roles": ["data", "layer"], + "title": "NO2 values", + "proj:bbox": [-180.0, -90.0, 180.0, 90.0], + "proj:epsg": 4326, + "proj:wkt2": 'GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AXIS["Latitude",NORTH],AXIS["Longitude",EAST],AUTHORITY["EPSG","4326"]]', + "proj:shape": [1800, 3600], + "description": "description", + "raster:bands": [ + { + "scale": 1.0, + "nodata": -1.2676506002282294e30, + "offset": 0.0, + "sampling": "area", + "data_type": "float32", + "histogram": { + "max": 14863169193246720, + "min": -2293753591103488.0, + "count": 11, + "buckets": [57, 484234, 23295, 2552, 694, 318, 230, 79, 42, 12], + }, + "statistics": { + "mean": 365095923477877.9, + "stddev": 569167954388057.0, + "maximum": 14863169193246720, + "minimum": -2293753591103488.0, + "valid_percent": 97.56336212158203, + }, + } + ], + "proj:geometry": { + "type": "Polygon", + "coordinates": [ + [ + [-180.0, -90.0], + [180.0, -90.0], + [180.0, 90.0], + [-180.0, 90.0], + [-180.0, -90.0], + ] + ], + }, + "proj:projjson": { + "id": {"code": 4326, "authority": "EPSG"}, + "name": "WGS 84", + "type": "GeographicCRS", + "datum": { + "name": "World Geodetic System 1984", + "type": "GeodeticReferenceFrame", + "ellipsoid": { + "name": "WGS 84", + "semi_major_axis": 6378137, + "inverse_flattening": 298.257223563, + }, + }, + "$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", + "coordinate_system": { + "axis": [ + { + "name": "Geodetic latitude", + "unit": "degree", + "direction": "north", + "abbreviation": "Lat", + }, + { + "name": "Geodetic longitude", + "unit": "degree", + "direction": "east", + "abbreviation": "Lon", + }, + ], + "subtype": "ellipsoidal", + }, + }, + "proj:transform": [0.1, 0.0, -180.0, 0.0, -0.1, 90.0, 0.0, 0.0, 1.0], + }, + "rendered_preview": { + "title": "Rendered preview", + "href": "https://dev.openveda.cloud/api/raster/stac/preview.png?collection=CMIP245-winter-median-pr&item=OMI_trno2_0.10x0.10_2023_Col3_V4&assets=cog_default&rescale=0%2C3000000000000000&colormap_name=reds", + "rel": "preview", + "roles": ["overview"], + "type": "image/png", + }, + }, + "geometry": { + "type": "Polygon", + "coordinates": [[[-180, -90], [180, -90], [180, 90], [-180, 90], [-180, -90]]], + }, + "collection": "CMIP245-winter-median-pr", + "properties": { + "end_datetime": "2023-12-31T00:00:00+00:00", + "start_datetime": "2023-01-01T00:00:00+00:00", + "datetime": None, + }, + "stac_version": "1.0.0", + "stac_extensions": [ + "https://stac-extensions.github.io/raster/v1.1.0/schema.json", + "https://stac-extensions.github.io/projection/v1.1.0/schema.json", + ], +} + + +@pytest.fixture +def test_environ(): + """ + Set up the test environment with mocked AWS and PostgreSQL credentials. + + This fixture sets environment variables to mock AWS credentials and + PostgreSQL database configuration for testing purposes. + """ + # Mocked AWS Credentials for moto (best practice recommendation from moto) + os.environ["AWS_ACCESS_KEY_ID"] = "testing" + os.environ["AWS_SECRET_ACCESS_KEY"] = "testing" + os.environ["AWS_SECURITY_TOKEN"] = "testing" + os.environ["AWS_SESSION_TOKEN"] = "testing" + os.environ["AWS_REGION"] = "us-west-2" + os.environ["VEDA_STAC_USERPOOL_ID"] = "us-west-2_FAKEUSERPOOL" + os.environ["VEDA_STAC_CLIENT_ID"] = "Xdjkfghadsfkdsadfjas" + os.environ["VEDA_STAC_CLIENT_SECRET"] = "dsakfjdsalfkjadslfjalksfj" + os.environ[ + "VEDA_STAC_COGNITO_DOMAIN" + ] = "https://fake.auth.us-west-2.amazoncognito.com" + os.environ["VEDA_STAC_ENABLE_TRANSACTIONS"] = "TRUE" + + # Config mocks + os.environ["POSTGRES_USER"] = "username" + os.environ["POSTGRES_PASS"] = "password" + os.environ["POSTGRES_DBNAME"] = "postgis" + os.environ["POSTGRES_HOST_READER"] = "database" + os.environ["POSTGRES_HOST_WRITER"] = "database" + os.environ["POSTGRES_PORT"] = "5432" + + +def override_validated_token(): + """ + Mock function to override validated token dependency. + + Returns: + str: A fake token to bypass authorization in tests. + """ + return "fake_token" + + +@pytest.fixture +def app(test_environ): + """ + Fixture to initialize the FastAPI application. + + This fixture imports and returns the FastAPI application instance + for testing purposes. + + Args: + test_environ: A fixture setting up the test environment. + + Returns: + FastAPI: The FastAPI application instance. + """ + from src.app import app + + return app + + +@pytest.fixture +def api_client(app): + """ + Fixture to initialize the API client for making requests. + + This fixture creates a TestClient instance for interacting with the + FastAPI application, and sets up dependency overrides for testing. + + Args: + app: A fixture providing the FastAPI application instance. + + Yields: + TestClient: The TestClient instance for API testing. + """ + from src.app import auth + + app.dependency_overrides[auth.validated_token] = override_validated_token + yield TestClient(app) + app.dependency_overrides.clear() + + +@pytest.fixture +def valid_stac_collection(): + """ + Fixture providing a valid STAC collection for testing. + + Returns: + dict: A valid STAC collection. + """ + return VALID_COLLECTION + + +@pytest.fixture +def invalid_stac_collection(): + """ + Fixture providing an invalid STAC collection for testing. + + Returns: + dict: An invalid STAC collection with the 'extent' field removed. + """ + invalid = VALID_COLLECTION.copy() + invalid.pop("extent") + return invalid + + +@pytest.fixture +def valid_stac_item(): + """ + Fixture providing a valid STAC item for testing. + + Returns: + dict: A valid STAC item. + """ + return VALID_ITEM + + +@pytest.fixture +def invalid_stac_item(): + """ + Fixture providing an invalid STAC item for testing. + + Returns: + dict: An invalid STAC item with the 'properties' field removed. + """ + invalid_item = VALID_ITEM.copy() + invalid_item.pop("properties") + return invalid_item diff --git a/stac_api/runtime/tests/test_transactions.py b/stac_api/runtime/tests/test_transactions.py new file mode 100644 index 00000000..6a5cd8e3 --- /dev/null +++ b/stac_api/runtime/tests/test_transactions.py @@ -0,0 +1,137 @@ +""" +Test suite for STAC (SpatioTemporal Asset Catalog) Transactions API endpoints. + +This module contains tests for the collection and item endpoints of the STAC API. +It verifies the behavior of the API when posting valid and invalid STAC collections and items, +as well as bulk items. + +Endpoints tested: +- /collections +- /collections/{}/items +- /collections/{}/bulk_items +""" + +import pytest + +collections_endpoint = "/collections" +items_endpoint = "/collections/{}/items" +bulk_endpoint = "/collections/{}/bulk_items" + + +class TestList: + """ + Test cases for STAC API's collection and item endpoints. + + This class contains tests to ensure that the STAC API correctly handles + posting valid and invalid STAC collections and items, both individually + and in bulk. It uses pytest fixtures to set up the test environment with + necessary data. + """ + + @pytest.fixture(autouse=True) + def setup( + self, + api_client, + valid_stac_collection, + valid_stac_item, + invalid_stac_collection, + invalid_stac_item, + ): + """ + Set up the test environment with the required fixtures. + + Args: + api_client: The API client for making requests. + valid_stac_collection: A valid STAC collection for testing. + valid_stac_item: A valid STAC item for testing. + invalid_stac_collection: An invalid STAC collection for testing. + invalid_stac_item: An invalid STAC item for testing. + """ + self.api_client = api_client + self.valid_stac_collection = valid_stac_collection + self.valid_stac_item = valid_stac_item + self.invalid_stac_collection = invalid_stac_collection + self.invalid_stac_item = invalid_stac_item + + def test_post_invalid_collection(self): + """ + Test the API's response to posting an invalid STAC collection. + + Asserts that the response status code is 422 and the detail + is "Validation Error". + """ + response = self.api_client.post( + collections_endpoint, json=self.invalid_stac_collection + ) + assert response.json()["detail"] == "Validation Error" + assert response.status_code == 422 + + def test_post_valid_collection(self): + """ + Test the API's response to posting a valid STAC collection. + + Asserts that the response status code is 200. + """ + response = self.api_client.post( + collections_endpoint, json=self.valid_stac_collection + ) + # assert response.json() == {} + assert response.status_code == 200 + + def test_post_invalid_item(self): + """ + Test the API's response to posting an invalid STAC item. + + Asserts that the response status code is 422 and the detail + is "Validation Error". + """ + response = self.api_client.post( + items_endpoint.format(self.invalid_stac_item["collection"]), + json=self.invalid_stac_item, + ) + assert response.json()["detail"] == "Validation Error" + assert response.status_code == 422 + + def test_post_valid_item(self): + """ + Test the API's response to posting a valid STAC item. + + Asserts that the response status code is 200. + """ + response = self.api_client.post( + items_endpoint.format(self.valid_stac_item["collection"]), + json=self.valid_stac_item, + ) + # assert response.json() == {} + assert response.status_code == 200 + + def test_post_invalid_bulk_items(self): + """ + Test the API's response to posting invalid bulk STAC items. + + Asserts that the response status code is 422. + """ + item_id = self.invalid_stac_item["id"] + collection_id = self.invalid_stac_item["collection"] + invalid_request = { + "items": {item_id: self.invalid_stac_item}, + "method": "upsert", + } + response = self.api_client.post( + bulk_endpoint.format(collection_id), json=invalid_request + ) + assert response.status_code == 422 + + def test_post_valid_bulk_items(self): + """ + Test the API's response to posting valid bulk STAC items. + + Asserts that the response status code is 200. + """ + item_id = self.valid_stac_item["id"] + collection_id = self.valid_stac_item["collection"] + valid_request = {"items": {item_id: self.valid_stac_item}, "method": "upsert"} + response = self.api_client.post( + bulk_endpoint.format(collection_id), json=valid_request + ) + assert response.status_code == 200