From 9e2d0bb99dfd844681d829f29286c8df8a9b6460 Mon Sep 17 00:00:00 2001 From: Kristjan Eimre Date: Mon, 22 Sep 2025 15:51:19 +0200 Subject: [PATCH 1/9] Move global singletons to app.state to enable multiple APIs in one process - fix tests - fix tests for real mongodb - fix tests for elasticsearch --- docs/api_reference/server/create_app.md | 3 + optimade/client/client.py | 19 +- .../filtertransformers/base_transformer.py | 4 +- optimade/filtertransformers/elasticsearch.py | 2 +- optimade/server/config.py | 32 +- optimade/server/create_app.py | 270 +++++++++++ optimade/server/entry_collections/__init__.py | 8 +- .../server/entry_collections/elasticsearch.py | 40 +- .../entry_collections/entry_collections.py | 59 ++- optimade/server/entry_collections/mongo.py | 79 +++- optimade/server/exception_handlers.py | 9 +- optimade/server/index_links.json | 2 +- optimade/server/logger.py | 11 +- optimade/server/main.py | 177 +------- optimade/server/main_index.py | 145 +----- optimade/server/mappers/entries.py | 422 +++++------------- optimade/server/mappers/links.py | 3 +- optimade/server/middleware.py | 32 +- optimade/server/query_params.py | 9 +- optimade/server/routers/__init__.py | 9 - optimade/server/routers/index_info.py | 10 +- optimade/server/routers/info.py | 29 +- optimade/server/routers/landing.py | 81 ++-- optimade/server/routers/links.py | 20 +- optimade/server/routers/references.py | 24 +- optimade/server/routers/structures.py | 24 +- optimade/server/routers/utils.py | 50 ++- optimade/server/schemas.py | 10 +- optimade/utils.py | 13 +- tests/conftest.py | 25 -- tests/filtertransformers/test_base.py | 14 +- .../filtertransformers/test_elasticsearch.py | 2 +- tests/filtertransformers/test_mongo.py | 44 +- tests/models/test_links.py | 12 +- tests/models/test_references.py | 11 +- tests/models/test_structures.py | 13 +- tests/server/conftest.py | 5 +- .../test_entry_collections.py | 11 +- .../server/entry_collections/test_indexes.py | 17 +- tests/server/middleware/test_api_hint.py | 14 +- tests/server/middleware/test_query_param.py | 10 +- tests/server/middleware/test_versioned_url.py | 19 +- tests/server/middleware/test_warnings.py | 39 +- tests/server/query_params/conftest.py | 2 +- tests/server/query_params/test_filter.py | 4 +- tests/server/query_params/test_include.py | 8 +- tests/server/routers/test_structures.py | 4 +- tests/server/test_client.py | 115 ++--- tests/server/test_config.py | 81 ++-- tests/server/test_mappers.py | 53 +-- tests/server/test_schemas.py | 5 +- tests/server/test_server_validation.py | 10 +- tests/server/utils.py | 2 - 53 files changed, 978 insertions(+), 1138 deletions(-) create mode 100644 docs/api_reference/server/create_app.md create mode 100644 optimade/server/create_app.py diff --git a/docs/api_reference/server/create_app.md b/docs/api_reference/server/create_app.md new file mode 100644 index 000000000..a2333d7c0 --- /dev/null +++ b/docs/api_reference/server/create_app.md @@ -0,0 +1,3 @@ +# create_app + +::: optimade.server.create_app diff --git a/optimade/client/client.py b/optimade/client/client.py index 7ba38bc4b..60573d109 100644 --- a/optimade/client/client.py +++ b/optimade/client/client.py @@ -511,11 +511,12 @@ def _binary_search_count_async( ) ) + # if we got any data, we are below the target value + below = bool(result[base_url].data) + self._progress.disable = self.silent - window, probe = self._update_probe_and_window( - window, probe, bool(result[base_url].data) - ) + window, probe = self._update_probe_and_window(window, probe, below) if window[0] == window[1] and window[0] == probe: return probe @@ -557,16 +558,15 @@ def _update_probe_and_window( raise RuntimeError( "Invalid arguments: must provide all or none of window, last_probe and below parameters" ) - probe: int = last_probe # Exit condition: find a range of (count, count+1) values # and determine whether the probe was above or below in the last guess if window[1] is not None and window[1] - window[0] == 1: if below: - return (window[0], window[0]), window[0] - else: return (window[1], window[1]), window[1] + else: + return (window[0], window[0]), window[0] # Enclose the real value in the window, with `None` indicating an open boundary if below: @@ -578,12 +578,13 @@ def _update_probe_and_window( if window[1] is None: probe *= 10 - # Otherwise, if we're in the window and the ends of the window now have the same power of 10, take the average (102 => 108) => 105 - elif round(math.log10(window[0])) == round(math.log10(window[0])): + # Otherwise, if we're in the window and the ends of the window now have the same power of 10 (or within +-1), + # take the average (102 => 108) => 105 + elif abs(math.log10(window[1]) - math.log10(window[0])) <= 1: probe = (window[1] + window[0]) // 2 # otherwise use logarithmic average (10, 1000) => 100 else: - probe = int(10 ** (math.log10(window[1]) + math.log10(window[0]) / 2)) + probe = int(10 ** ((math.log10(window[1]) + math.log10(window[0])) / 2)) return window, probe diff --git a/optimade/filtertransformers/base_transformer.py b/optimade/filtertransformers/base_transformer.py index 3a984c29c..0acf46d02 100644 --- a/optimade/filtertransformers/base_transformer.py +++ b/optimade/filtertransformers/base_transformer.py @@ -82,7 +82,7 @@ class BaseTransformer(Transformer, abc.ABC): """ - mapper: type[BaseResourceMapper] | None = None + mapper: BaseResourceMapper | None = None operator_map: dict[str, str | None] = { "<": None, "<=": None, @@ -106,7 +106,7 @@ class BaseTransformer(Transformer, abc.ABC): _quantity_type: type[Quantity] = Quantity _quantities = None - def __init__(self, mapper: type[BaseResourceMapper] | None = None): + def __init__(self, mapper: BaseResourceMapper | None = None): """Initialise the transformer object, optionally loading in a resource mapper for use when post-processing. diff --git a/optimade/filtertransformers/elasticsearch.py b/optimade/filtertransformers/elasticsearch.py index 41c919d2b..d30272d49 100644 --- a/optimade/filtertransformers/elasticsearch.py +++ b/optimade/filtertransformers/elasticsearch.py @@ -101,7 +101,7 @@ class ElasticTransformer(BaseTransformer): def __init__( self, - mapper: type[BaseResourceMapper], + mapper: BaseResourceMapper, quantities: dict[str, Quantity] | None = None, ): if quantities is not None: diff --git a/optimade/server/config.py b/optimade/server/config.py index 38a091bf8..90d17414d 100644 --- a/optimade/server/config.py +++ b/optimade/server/config.py @@ -175,6 +175,7 @@ class ServerConfig(BaseSettings): extra="allow", env_file_encoding="utf-8", case_sensitive=False, + validate_assignment=True, ) debug: Annotated[ @@ -370,12 +371,13 @@ class ServerConfig(BaseSettings): list[str | dict[Literal["name", "type", "unit", "description"], str]], ], Field( + default_factory=dict, description=( "A list of additional fields to be served with the provider's prefix " "attached, broken down by endpoint." ), ), - ] = {} + ] aliases: Annotated[ dict[Literal["links", "references", "structures"], dict[str, str]], Field( @@ -540,32 +542,6 @@ def check_license_info(cls, value: Any) -> AnyHttpUrl | None: return value - @model_validator(mode="after") - def use_real_mongo_override(self) -> "ServerConfig": - """Overrides the `database_backend` setting with MongoDB and - raises a deprecation warning. - """ - use_real_mongo = self.use_real_mongo - - # Remove from model - del self.use_real_mongo - - # Remove from set of user-defined fields - if "use_real_mongo" in self.model_fields_set: - self.model_fields_set.remove("use_real_mongo") - - if use_real_mongo is not None: - warnings.warn( - "'use_real_mongo' is deprecated, please set the appropriate 'database_backend' " - "instead.", - DeprecationWarning, - ) - - if use_real_mongo: - self.database_backend = SupportedBackend.MONGODB - - return self - @model_validator(mode="after") def align_mongo_uri_and_mongo_database(self) -> "ServerConfig": """Prefer the value of database name if set from `mongo_uri` rather than @@ -623,7 +599,7 @@ def settings_customise_sources( ) -CONFIG: ServerConfig = ServerConfig() +# CONFIG: ServerConfig = ServerConfig() """This singleton loads the config from a hierarchy of sources (see [`customise_sources`][optimade.server.config.ServerConfig.settings_customise_sources]) and makes it importable in the server code. diff --git a/optimade/server/create_app.py b/optimade/server/create_app.py new file mode 100644 index 000000000..0c329fced --- /dev/null +++ b/optimade/server/create_app.py @@ -0,0 +1,270 @@ +"""The OPTIMADE server + +The server is based on MongoDB, using either `pymongo` or `mongomock`. + +This is an example implementation with example data. +To implement your own server see the documentation at https://optimade.org/optimade-python-tools. +""" + +import json +import os +import warnings +from pathlib import Path + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from fastapi.middleware.gzip import GZipMiddleware + +with warnings.catch_warnings(record=True) as w: + from optimade.server.config import DEFAULT_CONFIG_FILE_PATH, ServerConfig + + config_warnings = w + +from optimade import __api_version__, __version__ +from optimade.server.entry_collections import EntryCollection, create_entry_collections +from optimade.server.exception_handlers import OPTIMADE_EXCEPTIONS +from optimade.server.logger import LOGGER +from optimade.server.middleware import OPTIMADE_MIDDLEWARE +from optimade.server.routers import ( + index_info, + info, + landing, + links, + references, + structures, + versions, +) +from optimade.server.routers.utils import BASE_URL_PREFIXES, JSONAPIResponse + +MAIN_ENDPOINTS = [info, links, references, structures, landing] +INDEX_ENDPOINTS = [index_info, links] + + +def add_major_version_base_url(app: FastAPI, index: bool = False): + """Add mandatory vMajor endpoints, i.e. all except /versions.""" + for endpoint in INDEX_ENDPOINTS if index else MAIN_ENDPOINTS: + app.include_router( + endpoint.router, prefix=BASE_URL_PREFIXES["major"], include_in_schema=False + ) + + +def add_optional_versioned_base_urls(app: FastAPI, index: bool = False): + """Add the following OPTIONAL prefixes/base URLs to server: + ``` + /vMajor.Minor + /vMajor.Minor.Patch + ``` + """ + for version in ("minor", "patch"): + for endpoint in INDEX_ENDPOINTS if index else MAIN_ENDPOINTS: + app.include_router( + endpoint.router, + prefix=BASE_URL_PREFIXES[version], + include_in_schema=False, + ) + + +def insert_main_data( + config: ServerConfig, entry_collections: dict[str, EntryCollection] +): + from optimade.utils import insert_from_jsonl + + def _insert_test_data(endpoint: str | None = None): + import bson.json_util + from bson.objectid import ObjectId + + import optimade.server.data as data + from optimade.server.routers.utils import get_providers + + def load_entries(endpoint_name: str, endpoint_collection: EntryCollection): + LOGGER.debug("Loading test %s...", endpoint_name) + + endpoint_collection.insert(getattr(data, endpoint_name, [])) + if ( + config.database_backend.value in ("mongomock", "mongodb") + and endpoint_name == "links" + ): + LOGGER.debug( + "Adding Materials-Consortia providers to links from optimade.org" + ) + providers = get_providers(add_mongo_id=True) + for doc in providers: + endpoint_collection.collection.replace_one( # type: ignore[attr-defined] + filter={"_id": ObjectId(doc["_id"]["$oid"])}, + replacement=bson.json_util.loads(bson.json_util.dumps(doc)), + upsert=True, + ) + LOGGER.debug("Done inserting test %s!", endpoint_name) + + if endpoint: + load_entries(endpoint, entry_collections[endpoint]) + else: + for name, collection in entry_collections.items(): + load_entries(name, collection) + + if config.insert_from_jsonl: + jsonl_path = Path(config.insert_from_jsonl) + LOGGER.debug("Inserting data from JSONL file: %s", jsonl_path) + if not jsonl_path.exists(): + raise RuntimeError( + f"Requested JSONL file does not exist: {jsonl_path}. Please specify an absolute group." + ) + + insert_from_jsonl( + jsonl_path, + entry_collections, + create_default_index=config.create_default_index, + ) + + LOGGER.debug("Inserted data from JSONL file: %s", jsonl_path) + if config.insert_test_data: + _insert_test_data("links") + elif config.insert_test_data: + _insert_test_data() + + if config.exit_after_insert: + LOGGER.info("Exiting after inserting test data.") + import sys + + sys.exit(0) + + +def insert_index_data( + config: ServerConfig, entry_collections: dict[str, EntryCollection] +): + import bson.json_util + from bson.objectid import ObjectId + + from optimade.server.routers.utils import get_providers, mongo_id_for_database + + links_coll = entry_collections["links"] + + if len(links_coll) > 0: + LOGGER.info("Skipping index links inserct: links collection already populated.") + return + + LOGGER.debug("Loading index links...") + with open(config.index_links_path) as f: + data = json.load(f) + + processed = [] + for db in data: + db["_id"] = {"$oid": mongo_id_for_database(db["id"], db["type"])} + processed.append(db) + + LOGGER.debug( + "Inserting index links into collection from %s...", config.index_links_path + ) + + links_coll.insert(bson.json_util.loads(bson.json_util.dumps(processed))) + + if config.database_backend.value in ("mongodb", "mongomock"): + LOGGER.debug( + "Adding Materials-Consortia providers to links from optimade.org..." + ) + providers = get_providers(add_mongo_id=True) + for doc in providers: + links_coll.collection.replace_one( # type: ignore[attr-defined] + filter={"_id": ObjectId(doc["_id"]["$oid"])}, + replacement=bson.json_util.loads(bson.json_util.dumps(doc)), + upsert=True, + ) + + LOGGER.debug("Done inserting index links!") + + else: + LOGGER.warning( + "Not inserting test data for index meta-database for backend %s", + config.database_backend.value, + ) + + +DESCRIPTION_TEMPLATE = """ +The [Open Databases Integration for Materials Design (OPTIMADE) consortium](https://www.optimade.org/) aims to make materials databases interoperational by developing a common REST API. +{index_meta_text} +This specification is generated using [`optimade-python-tools`](https://github.com/Materials-Consortia/optimade-python-tools/tree/v{version}) v{version}. +""" + + +def create_app(config: ServerConfig | None = None, index: bool = False) -> FastAPI: + if config_warnings: + LOGGER.warning( + f"Invalid config file or no config file provided, running server with default settings. Errors: " + f"{[warnings.formatwarning(w.message, w.category, w.filename, w.lineno, '') for w in config_warnings]}" + ) + else: + LOGGER.info( + f"Loaded settings from {os.getenv('OPTIMADE_CONFIG_FILE', DEFAULT_CONFIG_FILE_PATH)}." + ) + + if config is None: + config = ServerConfig() + + if config.debug: # pragma: no cover + LOGGER.info("DEBUG MODE") + + title = "OPTIMADE API" if not index else "OPTIMADE API - Index meta-database" + description = """The [Open Databases Integration for Materials Design (OPTIMADE) consortium](https://www.optimade.org/) aims to make materials databases interoperational by developing a common REST API.\n""" + if index: + description += 'This is the "special" index meta-database.\n' + description += f"\nThis specification is generated using [`optimade-python-tools`](https://github.com/Materials-Consortia/optimade-python-tools/tree/v{__version__}) v{__version__}." + + if index: + config.is_index = True + + app = FastAPI( + root_path=config.root_path, + title=title, + description=description, + version=__api_version__, + docs_url=f"{BASE_URL_PREFIXES['major']}/extensions/docs", + redoc_url=f"{BASE_URL_PREFIXES['major']}/extensions/redoc", + openapi_url=f"{BASE_URL_PREFIXES['major']}/extensions/openapi.json", + default_response_class=JSONAPIResponse, + separate_input_output_schemas=False, + ) + + # Save the config in the app state for access in endpoints + app.state.config = config + + # create entry collections and save in app state for access in endpoints + entry_collections = create_entry_collections(config) + app.state.entry_collections = entry_collections + + if not index: + if config.insert_test_data or config.insert_from_jsonl: + insert_main_data(config, entry_collections) + else: + if config.insert_test_data and config.index_links_path.exists(): + insert_index_data(config, entry_collections) + + # Add CORS middleware first + app.add_middleware(CORSMiddleware, allow_origins=["*"]) + + # Then add required OPTIMADE middleware + for middleware in OPTIMADE_MIDDLEWARE: + app.add_middleware(middleware) + + # Enable GZIP after other middleware. + if config.gzip.enabled: + app.add_middleware( + GZipMiddleware, + minimum_size=config.gzip.minimum_size, + compresslevel=config.gzip.compresslevel, + ) + + # Add exception handlers + for exception, handler in OPTIMADE_EXCEPTIONS: + app.add_exception_handler(exception, handler) + + # Add various endpoints to unversioned URL + endpoints = INDEX_ENDPOINTS if index else MAIN_ENDPOINTS + endpoints += [versions] + for endpoint in endpoints: + app.include_router(endpoint.router) + + # add the versioned endpoints + add_major_version_base_url(app, index=index) + add_optional_versioned_base_urls(app, index=index) + + return app diff --git a/optimade/server/entry_collections/__init__.py b/optimade/server/entry_collections/__init__.py index 2f980147b..5e749cd8a 100644 --- a/optimade/server/entry_collections/__init__.py +++ b/optimade/server/entry_collections/__init__.py @@ -1,3 +1,7 @@ -from .entry_collections import EntryCollection, PaginationMechanism, create_collection +from .entry_collections import ( + EntryCollection, + PaginationMechanism, + create_entry_collections, +) -__all__ = ("EntryCollection", "create_collection", "PaginationMechanism") +__all__ = ("EntryCollection", "create_entry_collections", "PaginationMechanism") diff --git a/optimade/server/entry_collections/elasticsearch.py b/optimade/server/entry_collections/elasticsearch.py index 5bc52016a..b6434b358 100644 --- a/optimade/server/entry_collections/elasticsearch.py +++ b/optimade/server/entry_collections/elasticsearch.py @@ -3,20 +3,21 @@ from pathlib import Path from typing import Any, Optional +from elasticsearch import Elasticsearch + from optimade.filtertransformers.elasticsearch import ElasticTransformer from optimade.models import EntryResource -from optimade.server.config import CONFIG +from optimade.server.config import ServerConfig from optimade.server.entry_collections import EntryCollection, PaginationMechanism from optimade.server.logger import LOGGER from optimade.server.mappers import BaseResourceMapper -if CONFIG.database_backend.value == "elastic": - from elasticsearch import Elasticsearch - from elasticsearch.helpers import bulk - from elasticsearch_dsl import Search - CLIENT = Elasticsearch(hosts=CONFIG.elastic_hosts) - LOGGER.info("Using: Elasticsearch backend at %s", CONFIG.elastic_hosts) +def get_elastic_client(config: ServerConfig) -> Optional["Elasticsearch"]: + if config.database_backend.value == "elastic": + LOGGER.info("Using: Elasticsearch backend at %s", config.elastic_hosts) + return Elasticsearch(hosts=config.elastic_hosts) + return None class ElasticCollection(EntryCollection): @@ -26,7 +27,8 @@ def __init__( self, name: str, resource_cls: type[EntryResource], - resource_mapper: type[BaseResourceMapper], + resource_mapper: BaseResourceMapper, + config: ServerConfig, client: Optional["Elasticsearch"] = None, ): """Initialize the ElasticCollection for the given parameters. @@ -43,9 +45,20 @@ def __init__( resource_cls=resource_cls, resource_mapper=resource_mapper, transformer=ElasticTransformer(mapper=resource_mapper), + config=config, ) - self.client = client if client else CLIENT + self.config = config + + # Normalize: always end with a concrete Elasticsearch client + tmp_client = client if client is not None else get_elastic_client(config) + if tmp_client is None: + raise RuntimeError( + "Tried to create ElasticCollection without an Elasticsearch backend" + ) + + self.client: Elasticsearch = tmp_client + self.name = name def count(self, *args, **kwargs) -> int: @@ -92,7 +105,7 @@ def predefined_index(self) -> dict[str, Any]: @staticmethod def create_elastic_index_from_mapper( - resource_mapper: type[BaseResourceMapper], fields: Iterable[str] + resource_mapper: BaseResourceMapper, fields: Iterable[str] ) -> dict[str, Any]: """Create a fallback elastic index based on a resource mapper. @@ -114,6 +127,8 @@ def create_elastic_index_from_mapper( def __len__(self): """Returns the total number of entries in the collection.""" + from elasticsearch_dsl import Search + return Search(using=self.client, index=self.name).execute().hits.total.value def insert(self, data: list[EntryResource | dict]) -> None: @@ -141,6 +156,8 @@ def get_id(item): item.pop("_id", None) return id_ + from elasticsearch.helpers import bulk + bulk( self.client, ( @@ -167,6 +184,7 @@ def _run_db_query( entries matching the query and a boolean for whether or not there is more data available. """ + from elasticsearch_dsl import Search search = Search(using=self.client, index=self.name) @@ -176,7 +194,7 @@ def _run_db_query( page_offset = criteria.get("skip", None) page_above = criteria.get("page_above", None) - limit = criteria.get("limit", CONFIG.page_limit) + limit = criteria.get("limit", self.config.page_limit) all_aliased_fields = [ self.resource_mapper.get_backend_field(field) for field in self.all_fields diff --git a/optimade/server/entry_collections/entry_collections.py b/optimade/server/entry_collections/entry_collections.py index caa4593fe..cbcaa23a7 100644 --- a/optimade/server/entry_collections/entry_collections.py +++ b/optimade/server/entry_collections/entry_collections.py @@ -11,7 +11,7 @@ from optimade.filterparser import LarkParser from optimade.models import Attributes, EntryResource from optimade.models.types import NoneType, _get_origin_type -from optimade.server.config import CONFIG, SupportedBackend +from optimade.server.config import ServerConfig, SupportedBackend from optimade.server.mappers import BaseResourceMapper from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams from optimade.warnings import ( @@ -24,10 +24,11 @@ def create_collection( name: str, resource_cls: type[EntryResource], - resource_mapper: type[BaseResourceMapper], + resource_mapper: BaseResourceMapper, + config: ServerConfig, ) -> "EntryCollection": """Create an `EntryCollection` of the configured type, depending on the value of - `CONFIG.database_backend`. + `config.database_backend`. Arguments: name: The collection name. @@ -38,7 +39,7 @@ def create_collection( The created `EntryCollection`. """ - if CONFIG.database_backend in ( + if config.database_backend in ( SupportedBackend.MONGODB, SupportedBackend.MONGOMOCK, ): @@ -48,19 +49,21 @@ def create_collection( name=name, resource_cls=resource_cls, resource_mapper=resource_mapper, + config=config, ) - if CONFIG.database_backend is SupportedBackend.ELASTIC: + if config.database_backend is SupportedBackend.ELASTIC: from optimade.server.entry_collections.elasticsearch import ElasticCollection return ElasticCollection( name=name, resource_cls=resource_cls, resource_mapper=resource_mapper, + config=config, ) raise NotImplementedError( - f"The database backend {CONFIG.database_backend!r} is not implemented" + f"The database backend {config.database_backend!r} is not implemented" ) @@ -86,8 +89,9 @@ class EntryCollection(ABC): def __init__( self, resource_cls: type[EntryResource], - resource_mapper: type[BaseResourceMapper], + resource_mapper: BaseResourceMapper, transformer: Transformer, + config: ServerConfig, ): """Initialize the collection for the given parameters. @@ -105,11 +109,12 @@ def __init__( self.resource_cls = resource_cls self.resource_mapper = resource_mapper self.transformer = transformer + self.config = config - self.provider_prefix = CONFIG.provider.prefix + self.provider_prefix = config.provider.prefix self.provider_fields = [ field if isinstance(field, str) else field["name"] - for field in CONFIG.provider_fields.get(resource_mapper.ENDPOINT, []) + for field in config.provider_fields.get(resource_mapper.ENDPOINT, []) ] self._all_fields: set[str] = set() @@ -217,7 +222,7 @@ def find( results = results[0] if ( - CONFIG.validate_api_response + self.config.validate_api_response and data_returned is not None and data_returned > 1 ): @@ -373,13 +378,13 @@ def handle_query_params( # page_limit if getattr(params, "page_limit", False): limit = params.page_limit # type: ignore[union-attr] - if limit > CONFIG.page_limit_max: + if limit > self.config.page_limit_max: raise Forbidden( - detail=f"Max allowed page_limit is {CONFIG.page_limit_max}, you requested {limit}", + detail=f"Max allowed page_limit is {self.config.page_limit_max}, you requested {limit}", ) cursor_kwargs["limit"] = limit else: - cursor_kwargs["limit"] = CONFIG.page_limit + cursor_kwargs["limit"] = self.config.page_limit # response_fields cursor_kwargs["projection"] = { @@ -530,3 +535,31 @@ def get_next_query_params( ] return query + + +from optimade.models import LinksResource, ReferenceResource, StructureResource +from optimade.server.config import ServerConfig +from optimade.server.mappers import LinksMapper, ReferenceMapper, StructureMapper + + +def create_entry_collections(config: ServerConfig): + return { + "links": create_collection( + name=config.links_collection, + resource_cls=LinksResource, + resource_mapper=LinksMapper(config), + config=config, + ), + "references": create_collection( + name=config.references_collection, + resource_cls=ReferenceResource, + resource_mapper=ReferenceMapper(config), + config=config, + ), + "structures": create_collection( + name=config.structures_collection, + resource_cls=StructureResource, + resource_mapper=StructureMapper(config), + config=config, + ), + } diff --git a/optimade/server/entry_collections/mongo.py b/optimade/server/entry_collections/mongo.py index e4e2c1e8f..e205e5826 100644 --- a/optimade/server/entry_collections/mongo.py +++ b/optimade/server/entry_collections/mongo.py @@ -1,32 +1,59 @@ +import atexit from typing import Any +from pymongo.errors import ExecutionTimeout + from optimade.filtertransformers.mongo import MongoTransformer from optimade.models import EntryResource -from optimade.server.config import CONFIG, SupportedBackend +from optimade.server.config import ServerConfig, SupportedBackend from optimade.server.entry_collections import EntryCollection from optimade.server.logger import LOGGER from optimade.server.mappers import BaseResourceMapper from optimade.server.query_params import EntryListingQueryParams, SingleEntryQueryParams -if CONFIG.database_backend.value == "mongodb": - from pymongo import MongoClient, version_tuple - from pymongo.errors import ExecutionTimeout +_CLIENTS: dict[tuple[str, str], Any] = {} + + +def _close_all_clients(log: bool = True): + for (backend, uri), client in list(_CLIENTS.items()): + try: + client.close() + if log: + LOGGER.debug(f"Closed MongoClient for {backend} {uri}") + except Exception as exc: + if log: + LOGGER.warning(f"Failed closing MongoClient {backend} {uri}: {exc}") + finally: + _CLIENTS.pop((backend, uri), None) + + +atexit.register(lambda: _close_all_clients(log=False)) - if version_tuple[0] < 4: - LOGGER.warning( - "Support for pymongo<=3 (and thus MongoDB v3) is deprecated and will be " - "removed in the next minor release." - ) - LOGGER.info("Using: Real MongoDB (pymongo)") +def get_mongo_client(config: ServerConfig): + """Return a cached MongoClient for (backend, uri), creating it if necessary.""" + backend = config.database_backend.value + uri = config.mongo_uri + key = (backend, uri) -elif CONFIG.database_backend.value == "mongomock": - from mongomock import MongoClient + if key in _CLIENTS: + return _CLIENTS[key] - LOGGER.info("Using: Mock MongoDB (mongomock)") + if backend == "mongodb": + from pymongo import MongoClient -if CONFIG.database_backend.value in ("mongomock", "mongodb"): - CLIENT = MongoClient(CONFIG.mongo_uri) + LOGGER.info(f"Using: Real MongoDB (pymongo) @ {uri}") + client = MongoClient(uri) + elif backend == "mongomock": + from mongomock import MongoClient + + LOGGER.info(f"Using: Mock MongoDB (mongomock) @ {uri}") + client = MongoClient(uri) + else: + raise ValueError(f"Unsupported backend {backend}") + + _CLIENTS[key] = client + return client class MongoCollection(EntryCollection): @@ -39,8 +66,8 @@ def __init__( self, name: str, resource_cls: type[EntryResource], - resource_mapper: type[BaseResourceMapper], - database: str = CONFIG.mongo_database, + resource_mapper: BaseResourceMapper, + config: ServerConfig, ): """Initialize the MongoCollection for the given parameters. @@ -49,20 +76,24 @@ def __init__( resource_cls: The type of entry resource that is stored by the collection. resource_mapper: A resource mapper object that handles aliases and format changes between deserialization and response. - database: The name of the underlying MongoDB database to connect to. - """ super().__init__( resource_cls, resource_mapper, MongoTransformer(mapper=resource_mapper), + config, ) - self.collection = CLIENT[database][name] + self.config = config + database = config.mongo_database + + client = get_mongo_client(config) + + self.collection = client[database][name] # check aliases do not clash with mongo operators - self._check_aliases(self.resource_mapper.all_aliases()) - self._check_aliases(self.resource_mapper.all_length_aliases()) + self._check_aliases(self.resource_mapper.all_aliases) + self._check_aliases(self.resource_mapper.all_length_aliases) def __len__(self) -> int: """Returns the total number of entries in the collection.""" @@ -85,7 +116,7 @@ def count(self, **kwargs: Any) -> int | None: return self.collection.estimated_document_count() else: if "maxTimeMS" not in kwargs: - kwargs["maxTimeMS"] = int(1000 * CONFIG.mongo_count_timeout) + kwargs["maxTimeMS"] = int(1000 * self.config.mongo_count_timeout) try: return self.collection.count_documents(**kwargs) except ExecutionTimeout: @@ -183,7 +214,7 @@ def _run_db_query( """ results = list(self.collection.find(**criteria)) - if CONFIG.database_backend == SupportedBackend.MONGOMOCK and criteria.get( + if self.config.database_backend == SupportedBackend.MONGOMOCK and criteria.get( "projection", {} ).get("_id"): # mongomock does not support `$toString`` in projection, so we have to do it manually diff --git a/optimade/server/exception_handlers.py b/optimade/server/exception_handlers.py index 68c6d29a4..127e53b4f 100644 --- a/optimade/server/exception_handlers.py +++ b/optimade/server/exception_handlers.py @@ -9,7 +9,6 @@ from optimade.exceptions import BadRequest, OptimadeHTTPException from optimade.models import ErrorResponse, ErrorSource, OptimadeError -from optimade.server.config import CONFIG from optimade.server.logger import LOGGER from optimade.server.routers.utils import JSONAPIResponse, meta_values @@ -34,12 +33,13 @@ def general_exception( """ debug_info = {} - if CONFIG.debug: + config = request.app.state.config + if config.debug: tb = "".join( traceback.format_exception(type(exc), value=exc, tb=exc.__traceback__) ) LOGGER.error("Traceback:\n%s", tb) - debug_info[f"_{CONFIG.provider.prefix}_traceback"] = tb + debug_info[f"_{config.provider.prefix}_traceback"] = tb try: http_response_code = int(exc.status_code) # type: ignore[attr-defined] @@ -61,11 +61,12 @@ def general_exception( response = ErrorResponse( meta=meta_values( + config, url=request.url, data_returned=0, data_available=0, more_data_available=False, - schema=CONFIG.schema_url, + schema=config.schema_url, **debug_info, ), errors=errors, diff --git a/optimade/server/index_links.json b/optimade/server/index_links.json index 334a07b74..bcc613068 100644 --- a/optimade/server/index_links.json +++ b/optimade/server/index_links.json @@ -1,6 +1,6 @@ [ { - "id": "test_server", + "id": "index", "type": "links", "name": "OPTIMADE API", "description": "The [Open Databases Integration for Materials Design (OPTIMADE) consortium](https://www.optimade.org/) aims to make materials databases interoperational by developing a common REST API.", diff --git a/optimade/server/logger.py b/optimade/server/logger.py index 0f574aab4..5e7171b88 100644 --- a/optimade/server/logger.py +++ b/optimade/server/logger.py @@ -6,24 +6,27 @@ import sys from pathlib import Path +from optimade.server.config import ServerConfig + # Instantiate LOGGER LOGGER = logging.getLogger("optimade") LOGGER.setLevel(logging.DEBUG) +CONFIG = ServerConfig() + # Handler CONSOLE_HANDLER = logging.StreamHandler(sys.stdout) try: - from optimade.server.config import CONFIG - CONSOLE_HANDLER.setLevel(CONFIG.log_level.value.upper()) if CONFIG.debug: CONSOLE_HANDLER.setLevel(logging.DEBUG) - except ImportError: CONSOLE_HANDLER.setLevel(os.getenv("OPTIMADE_LOG_LEVEL", "INFO").upper()) +CONSOLE_HANDLER.setLevel(os.getenv("OPTIMADE_LOG_LEVEL", "DEBUG").upper()) + # Formatter; try to use uvicorn default, otherwise just use built-in default try: from uvicorn.logging import DefaultFormatter @@ -39,8 +42,6 @@ # Save a file with all messages (DEBUG level) try: - from optimade.server.config import CONFIG - LOGS_DIR = CONFIG.log_dir except ImportError: LOGS_DIR = Path(os.getenv("OPTIMADE_LOG_DIR", "/var/log/optimade/")).resolve() diff --git a/optimade/server/main.py b/optimade/server/main.py index 22a4337ad..87f5c3a5d 100644 --- a/optimade/server/main.py +++ b/optimade/server/main.py @@ -1,181 +1,10 @@ """The OPTIMADE server -The server is based on MongoDB, using either `pymongo` or `mongomock`. - This is an example implementation with example data. To implement your own server see the documentation at https://optimade.org/optimade-python-tools. """ -import os -import warnings -from contextlib import asynccontextmanager -from pathlib import Path - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware -from fastapi.middleware.gzip import GZipMiddleware - -with warnings.catch_warnings(record=True) as w: - from optimade.server.config import CONFIG, DEFAULT_CONFIG_FILE_PATH - - config_warnings = w - -from optimade import __api_version__, __version__ -from optimade.server.entry_collections import EntryCollection -from optimade.server.exception_handlers import OPTIMADE_EXCEPTIONS -from optimade.server.logger import LOGGER -from optimade.server.middleware import OPTIMADE_MIDDLEWARE -from optimade.server.routers import ( - info, - landing, - links, - references, - structures, - versions, -) -from optimade.server.routers.utils import BASE_URL_PREFIXES, JSONAPIResponse - -if config_warnings: - LOGGER.warn( - f"Invalid config file or no config file provided, running server with default settings. Errors: " - f"{[warnings.formatwarning(w.message, w.category, w.filename, w.lineno, '') for w in config_warnings]}" - ) -else: - LOGGER.info( - f"Loaded settings from {os.getenv('OPTIMADE_CONFIG_FILE', DEFAULT_CONFIG_FILE_PATH)}." - ) - -if CONFIG.debug: # pragma: no cover - LOGGER.info("DEBUG MODE") - - -@asynccontextmanager # type: ignore[arg-type] -async def lifespan(app: FastAPI): - """Add dynamic endpoints on startup.""" - # Add API endpoints for MANDATORY base URL `/vMAJOR` - add_major_version_base_url(app) - # Add API endpoints for OPTIONAL base URLs `/vMAJOR.MINOR` and `/vMAJOR.MINOR.PATCH` - add_optional_versioned_base_urls(app) - - # Yield so that the app can start - yield - - -app = FastAPI( - root_path=CONFIG.root_path, - title="OPTIMADE API", - description=( - f"""The [Open Databases Integration for Materials Design (OPTIMADE) consortium](https://www.optimade.org/) aims to make materials databases interoperational by developing a common REST API. - -This specification is generated using [`optimade-python-tools`](https://github.com/Materials-Consortia/optimade-python-tools/tree/v{__version__}) v{__version__}.""" - ), - version=__api_version__, - docs_url=f"{BASE_URL_PREFIXES['major']}/extensions/docs", - redoc_url=f"{BASE_URL_PREFIXES['major']}/extensions/redoc", - openapi_url=f"{BASE_URL_PREFIXES['major']}/extensions/openapi.json", - default_response_class=JSONAPIResponse, - separate_input_output_schemas=False, - lifespan=lifespan, -) - - -if CONFIG.insert_test_data or CONFIG.insert_from_jsonl: - from optimade.utils import insert_from_jsonl - - def _insert_test_data(endpoint: str | None = None): - import bson.json_util - from bson.objectid import ObjectId - - import optimade.server.data as data - from optimade.server.routers import ENTRY_COLLECTIONS - from optimade.server.routers.utils import get_providers - - def load_entries(endpoint_name: str, endpoint_collection: EntryCollection): - LOGGER.debug("Loading test %s...", endpoint_name) - - endpoint_collection.insert(getattr(data, endpoint_name, [])) - if ( - CONFIG.database_backend.value in ("mongomock", "mongodb") - and endpoint_name == "links" - ): - LOGGER.debug( - "Adding Materials-Consortia providers to links from optimade.org" - ) - providers = get_providers(add_mongo_id=True) - for doc in providers: - endpoint_collection.collection.replace_one( # type: ignore[attr-defined] - filter={"_id": ObjectId(doc["_id"]["$oid"])}, - replacement=bson.json_util.loads(bson.json_util.dumps(doc)), - upsert=True, - ) - LOGGER.debug("Done inserting test %s!", endpoint_name) - - if endpoint: - load_entries(endpoint, ENTRY_COLLECTIONS[endpoint]) - else: - for name, collection in ENTRY_COLLECTIONS.items(): - load_entries(name, collection) - - if CONFIG.insert_from_jsonl: - jsonl_path = Path(CONFIG.insert_from_jsonl) - LOGGER.debug("Inserting data from JSONL file: %s", jsonl_path) - if not jsonl_path.exists(): - raise RuntimeError( - f"Requested JSONL file does not exist: {jsonl_path}. Please specify an absolute group." - ) - - insert_from_jsonl(jsonl_path, create_default_index=CONFIG.create_default_index) - - LOGGER.debug("Inserted data from JSONL file: %s", jsonl_path) - if CONFIG.insert_test_data: - _insert_test_data("links") - elif CONFIG.insert_test_data: - _insert_test_data() - - if CONFIG.exit_after_insert: - LOGGER.info("Exiting after inserting test data.") - import sys - - sys.exit(0) - -# Add CORS middleware first -app.add_middleware(CORSMiddleware, allow_origins=["*"]) - -# Then add required OPTIMADE middleware -for middleware in OPTIMADE_MIDDLEWARE: - app.add_middleware(middleware) - -# Enable GZIP after other middleware. -if CONFIG.gzip.enabled: - app.add_middleware( - GZipMiddleware, - minimum_size=CONFIG.gzip.minimum_size, - compresslevel=CONFIG.gzip.compresslevel, - ) - - -# Add exception handlers -for exception, handler in OPTIMADE_EXCEPTIONS: - app.add_exception_handler(exception, handler) - -# Add various endpoints to unversioned URL -for endpoint in (info, links, references, structures, landing, versions): - app.include_router(endpoint.router) - - -def add_major_version_base_url(app: FastAPI): - """Add mandatory vMajor endpoints, i.e. all except versions.""" - for endpoint in (info, links, references, structures, landing): - app.include_router(endpoint.router, prefix=BASE_URL_PREFIXES["major"]) - +from optimade.server.config import ServerConfig +from optimade.server.create_app import create_app -def add_optional_versioned_base_urls(app: FastAPI): - """Add the following OPTIONAL prefixes/base URLs to server: - ``` - /vMajor.Minor - /vMajor.Minor.Patch - ``` - """ - for version in ("minor", "patch"): - for endpoint in (info, links, references, structures, landing): - app.include_router(endpoint.router, prefix=BASE_URL_PREFIXES[version]) +app = create_app(ServerConfig()) diff --git a/optimade/server/main_index.py b/optimade/server/main_index.py index eccbcfb1d..a34357ba5 100644 --- a/optimade/server/main_index.py +++ b/optimade/server/main_index.py @@ -1,149 +1,10 @@ """The OPTIMADE Index Meta-Database server -The server is based on MongoDB, using either `pymongo` or `mongomock`. - This is an example implementation with example data. To implement your own index meta-database server see the documentation at https://optimade.org/optimade-python-tools. """ -import json -import os -import warnings -from contextlib import asynccontextmanager - -from fastapi import FastAPI -from fastapi.middleware.cors import CORSMiddleware - -with warnings.catch_warnings(record=True) as w: - from optimade.server.config import CONFIG, DEFAULT_CONFIG_FILE_PATH - - config_warnings = w - -from optimade import __api_version__, __version__ -from optimade.server.exception_handlers import OPTIMADE_EXCEPTIONS -from optimade.server.logger import LOGGER -from optimade.server.middleware import OPTIMADE_MIDDLEWARE -from optimade.server.routers import index_info, links, versions -from optimade.server.routers.utils import BASE_URL_PREFIXES, JSONAPIResponse - -if config_warnings: - LOGGER.warn( - f"Invalid config file or no config file provided, running server with default settings. Errors: " - f"{[warnings.formatwarning(w.message, w.category, w.filename, w.lineno, '') for w in config_warnings]}" - ) -else: - LOGGER.info( - f"Loaded settings from {os.getenv('OPTIMADE_CONFIG_FILE', DEFAULT_CONFIG_FILE_PATH)}." - ) - - -if CONFIG.debug: # pragma: no cover - LOGGER.info("DEBUG MODE") - - -@asynccontextmanager # type: ignore[arg-type] -async def lifespan(app: FastAPI): - """Add dynamic endpoints and adjust config on startup.""" - CONFIG.is_index = True - # Add API endpoints for MANDATORY base URL `/vMAJOR` - add_major_version_base_url(app) - # Add API endpoints for OPTIONAL base URLs `/vMAJOR.MINOR` and `/vMAJOR.MINOR.PATCH` - add_optional_versioned_base_urls(app) - - # Yield so that the app can start - yield - - -app = FastAPI( - root_path=CONFIG.root_path, - title="OPTIMADE API - Index meta-database", - description=( - f"""The [Open Databases Integration for Materials Design (OPTIMADE) consortium](https://www.optimade.org/) aims to make materials databases interoperational by developing a common REST API. -This is the "special" index meta-database. - -This specification is generated using [`optimade-python-tools`](https://github.com/Materials-Consortia/optimade-python-tools/tree/v{__version__}) v{__version__}.""" - ), - version=__api_version__, - docs_url=f"{BASE_URL_PREFIXES['major']}/extensions/docs", - redoc_url=f"{BASE_URL_PREFIXES['major']}/extensions/redoc", - openapi_url=f"{BASE_URL_PREFIXES['major']}/extensions/openapi.json", - default_response_class=JSONAPIResponse, - separate_input_output_schemas=False, - lifespan=lifespan, -) - - -if CONFIG.insert_test_data and CONFIG.index_links_path.exists(): - import bson.json_util - from bson.objectid import ObjectId - - from optimade.server.routers.links import links_coll - from optimade.server.routers.utils import get_providers, mongo_id_for_database - - LOGGER.debug("Loading index links...") - with open(CONFIG.index_links_path) as f: - data = json.load(f) - - processed = [] - for db in data: - db["_id"] = {"$oid": mongo_id_for_database(db["id"], db["type"])} - processed.append(db) - - LOGGER.debug( - "Inserting index links into collection from %s...", CONFIG.index_links_path - ) - - links_coll.insert(bson.json_util.loads(bson.json_util.dumps(processed))) - - if CONFIG.database_backend.value in ("mongodb", "mongomock"): - LOGGER.debug( - "Adding Materials-Consortia providers to links from optimade.org..." - ) - providers = get_providers(add_mongo_id=True) - for doc in providers: - links_coll.collection.replace_one( # type: ignore[attr-defined] - filter={"_id": ObjectId(doc["_id"]["$oid"])}, - replacement=bson.json_util.loads(bson.json_util.dumps(doc)), - upsert=True, - ) - - LOGGER.debug("Done inserting index links!") - - else: - LOGGER.warning( - "Not inserting test data for index meta-database for backend %s", - CONFIG.database_backend.value, - ) - -# Add CORS middleware first -app.add_middleware(CORSMiddleware, allow_origins=["*"]) - -# Then add required OPTIMADE middleware -for middleware in OPTIMADE_MIDDLEWARE: - app.add_middleware(middleware) - -# Add exception handlers -for exception, handler in OPTIMADE_EXCEPTIONS: - app.add_exception_handler(exception, handler) - -# Add all endpoints to unversioned URL -for endpoint in (index_info, links, versions): - app.include_router(endpoint.router) - - -def add_major_version_base_url(app: FastAPI): - """Add mandatory endpoints to `/vMAJOR` base URL.""" - for endpoint in (index_info, links): - app.include_router(endpoint.router, prefix=BASE_URL_PREFIXES["major"]) - +from optimade.server.config import ServerConfig +from optimade.server.create_app import create_app -def add_optional_versioned_base_urls(app: FastAPI): - """Add the following OPTIONAL prefixes/base URLs to server: - ``` - /vMajor.Minor - /vMajor.Minor.Patch - ``` - """ - for version in ("minor", "patch"): - app.include_router(index_info.router, prefix=BASE_URL_PREFIXES[version]) - app.include_router(links.router, prefix=BASE_URL_PREFIXES[version]) +app = create_app(ServerConfig(), index=True) diff --git a/optimade/server/mappers/entries.py b/optimade/server/mappers/entries.py index 8f499f9b5..22806a780 100644 --- a/optimade/server/mappers/entries.py +++ b/optimade/server/mappers/entries.py @@ -1,9 +1,10 @@ import warnings from collections.abc import Iterable -from functools import lru_cache -from typing import Any +from functools import cached_property +from typing import Any, Literal from optimade.models.entries import EntryResource +from optimade.server.config import ServerConfig # A number that approximately tracks the number of types with mappers # so that the global caches can be set to the correct size. @@ -14,332 +15,156 @@ __all__ = ("BaseResourceMapper",) -class classproperty(property): - """A simple extension of the property decorator that binds to types - rather than instances. - - Modelled on this [StackOverflow answer](https://stackoverflow.com/a/5192374) - with some tweaks to allow mkdocstrings to do its thing. - - """ - - def __init__(self, func): - self.__name__ = func.__name__ - self.__module__ = func.__module__ - self.__doc__ = func.__doc__ - self.__wrapped__ = func - - def __get__(self, _: Any, owner: type | None = None) -> Any: - return self.__wrapped__(owner) - - class BaseResourceMapper: - """Generic Resource Mapper that defines and performs the mapping - between objects in the database and the resource objects defined by - the specification. - - Attributes: - ALIASES: a tuple of aliases between - OPTIMADE field names and the field names in the database , - e.g. `(("elements", "custom_elements_field"))`. - LENGTH_ALIASES: a tuple of aliases between - a field name and another field that defines its length, to be used - when querying, e.g. `(("elements", "nelements"))`. - e.g. `(("elements", "custom_elements_field"))`. - ENTRY_RESOURCE_CLASS: The entry type that this mapper corresponds to. - PROVIDER_FIELDS: a tuple of extra field names that this - mapper should support when querying with the database prefix. - TOP_LEVEL_NON_ATTRIBUTES_FIELDS: the set of top-level - field names common to all endpoints. - SUPPORTED_PREFIXES: The set of prefixes registered by this mapper. - ALL_ATTRIBUTES: The set of attributes defined across the entry - resource class and the server configuration. - ENTRY_RESOURCE_ATTRIBUTES: A dictionary of attributes and their definitions - defined by the schema of the entry resource class. - ENDPOINT: The expected endpoint name for this resource, as defined by - the `type` in the schema of the entry resource class. + """Instance-based Resource Mapper. + Create one instance per CONFIG (and optionally per providers set). + Subclasses still set class-level constants like ENTRY_RESOURCE_CLASS. """ - try: - from optimade.server.data import providers as PROVIDERS # type: ignore - except (ImportError, ModuleNotFoundError): - PROVIDERS = {} - - KNOWN_PROVIDER_PREFIXES: set[str] = { - prov["id"] for prov in PROVIDERS.get("data", []) - } + # class-level knobs remain ALIASES: tuple[tuple[str, str], ...] = () LENGTH_ALIASES: tuple[tuple[str, str], ...] = () PROVIDER_FIELDS: tuple[str, ...] = () - ENTRY_RESOURCE_CLASS: type[EntryResource] = EntryResource + ENTRY_RESOURCE_CLASS: type["EntryResource"] = EntryResource RELATIONSHIP_ENTRY_TYPES: set[str] = {"references", "structures"} TOP_LEVEL_NON_ATTRIBUTES_FIELDS: set[str] = {"id", "type", "relationships", "links"} - @classmethod - @lru_cache(maxsize=NUM_ENTRY_TYPES) - def all_aliases(cls) -> Iterable[tuple[str, str]]: - """Returns all of the associated aliases for this entry type, - including those defined by the server config. The first member - of each tuple is the OPTIMADE-compliant field name, the second - is the backend-specific field name. - - Returns: - A tuple of alias tuples. - + def __init__(self, config: ServerConfig | None = None): """ - from optimade.server.config import CONFIG - - return ( - tuple( - (f"_{CONFIG.provider.prefix}_{field}", field) - if not field.startswith("_") - else (field, field) - for field in CONFIG.provider_fields.get(cls.ENDPOINT, []) - if isinstance(field, str) - ) - + tuple( - (f"_{CONFIG.provider.prefix}_{field['name']}", field["name"]) - if not field["name"].startswith("_") - else (field["name"], field["name"]) - for field in CONFIG.provider_fields.get(cls.ENDPOINT, []) - if isinstance(field, dict) - ) - + tuple( - (f"_{CONFIG.provider.prefix}_{field}", field) - if not field.startswith("_") - else (field, field) - for field in cls.PROVIDER_FIELDS - ) - + tuple(CONFIG.aliases.get(cls.ENDPOINT, {}).items()) - + cls.ALIASES - ) - - @classproperty - @lru_cache(maxsize=1) - def SUPPORTED_PREFIXES(cls) -> set[str]: - """A set of prefixes handled by this entry type. - - !!! note - This implementation only includes the provider prefix, - but in the future this property may be extended to include other - namespaces (for serving fields from, e.g., other providers or - domain-specific terms). - + Args: + config: Server CONFIG-like object (must expose: + .provider.prefix, .provider_fields, .aliases, .length_aliases) """ - from optimade.server.config import CONFIG - - return {CONFIG.provider.prefix} - - @classproperty - def ALL_ATTRIBUTES(cls) -> set[str]: - """Returns all attributes served by this entry.""" - from optimade.server.config import CONFIG - - return ( - set(cls.ENTRY_RESOURCE_ATTRIBUTES) - .union( - cls.get_optimade_field(field) - for field in CONFIG.provider_fields.get(cls.ENDPOINT, ()) - if isinstance(field, str) - ) - .union( - cls.get_optimade_field(field["name"]) - for field in CONFIG.provider_fields.get(cls.ENDPOINT, ()) - if isinstance(field, dict) - ) - .union({cls.get_optimade_field(field) for field in cls.PROVIDER_FIELDS}) - ) - - @classproperty - def ENTRY_RESOURCE_ATTRIBUTES(cls) -> dict[str, Any]: - """Returns the dictionary of attributes defined by the underlying entry resource class.""" - from optimade.server.schemas import retrieve_queryable_properties - - return retrieve_queryable_properties(cls.ENTRY_RESOURCE_CLASS) - - @classproperty - @lru_cache(maxsize=NUM_ENTRY_TYPES) - def ENDPOINT(cls) -> str: - """Returns the expected endpoint for this mapper, corresponding - to the `type` property of the resource class. - - """ - endpoint = cls.ENTRY_RESOURCE_CLASS.model_fields["type"].default - if not endpoint and not isinstance(endpoint, str): + if config is None: + config = ServerConfig() + self.config = config + try: + from optimade.server.data import providers as PROVIDERS # type: ignore + except (ImportError, ModuleNotFoundError): + PROVIDERS = {} + self.providers = PROVIDERS + + self.KNOWN_PROVIDER_PREFIXES: set[str] = { + prov["id"] for prov in self.providers.get("data", []) + } + + # ---- Computed, cached once per instance ---- + @cached_property + def ENDPOINT(self) -> Literal["links", "references", "structures"]: + endpoint = self.ENTRY_RESOURCE_CLASS.model_fields["type"].default + if not endpoint or not isinstance(endpoint, str): raise ValueError("Type not set for this entry type!") return endpoint - @classmethod - @lru_cache(maxsize=NUM_ENTRY_TYPES) - def all_length_aliases(cls) -> tuple[tuple[str, str], ...]: - """Returns all of the associated length aliases for this class, - including those defined by the server config. - - Returns: - A tuple of length alias tuples. - - """ - from optimade.server.config import CONFIG - - return cls.LENGTH_ALIASES + tuple( - CONFIG.length_aliases.get(cls.ENDPOINT, {}).items() + @cached_property + def SUPPORTED_PREFIXES(self) -> set[str]: + return {self.config.provider.prefix} + + @cached_property + def all_aliases(self) -> tuple[tuple[str, str], ...]: + cfg = self.config + ep = self.ENDPOINT + provider_fields = cfg.provider_fields.get(ep) or [] + + provider_field_aliases_str = tuple( + (f"_{cfg.provider.prefix}_{field}", field) + if not field.startswith("_") + else (field, field) + for field in provider_fields + if isinstance(field, str) ) + provider_field_aliases_dict = tuple( + (f"_{cfg.provider.prefix}_{fd['name']}", fd["name"]) + if not fd["name"].startswith("_") + else (fd["name"], fd["name"]) + for fd in provider_fields + if isinstance(fd, dict) + ) + explicit_provider_fields = tuple( + (f"_{cfg.provider.prefix}_{field}", field) + if not field.startswith("_") + else (field, field) + for field in self.PROVIDER_FIELDS + ) + config_aliases = tuple(cfg.aliases.get(ep, {}).items()) - @classmethod - @lru_cache(maxsize=128) - def length_alias_for(cls, field: str) -> str | None: - """Returns the length alias for the particular field, - or `None` if no such alias is found. - - Parameters: - field: OPTIMADE field name. - - Returns: - Aliased field as found in [`all_length_aliases()`][optimade.server.mappers.entries.BaseResourceMapper.all_length_aliases]. - - """ - return dict(cls.all_length_aliases()).get(field, None) + return ( + provider_field_aliases_str + + provider_field_aliases_dict + + explicit_provider_fields + + config_aliases + + self.ALIASES + ) - @classmethod - @lru_cache(maxsize=128) - def get_backend_field(cls, optimade_field: str) -> str: - """Return the field name configured for the particular - underlying database for the passed OPTIMADE field name, that would - be used in an API filter. + @cached_property + def all_length_aliases(self) -> tuple[tuple[str, str], ...]: + return self.LENGTH_ALIASES + tuple( + self.config.length_aliases.get(self.ENDPOINT, {}).items() + ) - Aliases are read from - [`all_aliases()`][optimade.server.mappers.entries.BaseResourceMapper.all_aliases]. + @cached_property + def ENTRY_RESOURCE_ATTRIBUTES_MAP(self) -> dict[str, Any]: + from optimade.server.schemas import retrieve_queryable_properties - If a dot-separated OPTIMADE field is provided, e.g., `species.mass`, only the first part will be mapped. - This means for an (OPTIMADE, DB) alias of (`species`, `kinds`), `get_backend_fields("species.mass")` - will return `kinds.mass`. + return retrieve_queryable_properties(self.ENTRY_RESOURCE_CLASS) - Arguments: - optimade_field: The OPTIMADE field to attempt to map to the backend-specific field. + @cached_property + def ALL_ATTRIBUTES(self) -> set[str]: + cfg = self.config + ep = self.ENDPOINT + pf = cfg.provider_fields.get(ep, ()) - Examples: - >>> get_backend_field("chemical_formula_anonymous") - 'formula_anon' - >>> get_backend_field("formula_anon") - 'formula_anon' - >>> get_backend_field("_exmpl_custom_provider_field") - 'custom_provider_field' + attrs = set(self.ENTRY_RESOURCE_ATTRIBUTES_MAP) + attrs.update( + self.get_optimade_field(field) for field in pf if isinstance(field, str) + ) + attrs.update( + self.get_optimade_field(field["name"]) + for field in pf + if isinstance(field, dict) + ) + attrs.update(self.get_optimade_field(field) for field in self.PROVIDER_FIELDS) + return attrs - Returns: - The mapped field name to be used in the query to the backend. + # ---- Instance methods that use the cached properties ---- + def length_alias_for(self, field: str) -> str | None: + return dict(self.all_length_aliases).get(field) - """ + def get_backend_field(self, optimade_field: str) -> str: split = optimade_field.split(".") - alias = dict(cls.all_aliases()).get(split[0], None) + alias = dict(self.all_aliases).get(split[0]) if alias is not None: return alias + ("." + ".".join(split[1:]) if len(split) > 1 else "") return optimade_field - @classmethod - @lru_cache(maxsize=128) - def alias_for(cls, field: str) -> str: - """Return aliased field name. - - !!! warning "Deprecated" - This method is deprecated could be removed without further warning. Please use - [`get_backend_field()`][optimade.server.mappers.entries.BaseResourceMapper.get_backend_field]. - - Parameters: - field: OPTIMADE field name. - - Returns: - Aliased field as found in [`all_aliases()`][optimade.server.mappers.entries.BaseResourceMapper.all_aliases]. + def get_optimade_field(self, backend_field: str) -> str: + return {alias: real for real, alias in self.all_aliases}.get( + backend_field, backend_field + ) - """ + def alias_for(self, field: str) -> str: warnings.warn( - "The `.alias_for(...)` method is deprecated, please use `.get_backend_field(...)`.", + "`.alias_for(...)` is deprecated; use `.get_backend_field(...)`.", DeprecationWarning, ) - return cls.get_backend_field(field) - - @classmethod - @lru_cache(maxsize=128) - def get_optimade_field(cls, backend_field: str) -> str: - """Return the corresponding OPTIMADE field name for the underlying database field, - ready to be used to construct the OPTIMADE-compliant JSON response. - - Aliases are read from - [`all_aliases()`][optimade.server.mappers.entries.BaseResourceMapper.all_aliases]. - - Arguments: - backend_field: The backend field to attempt to map to an OPTIMADE field. + return self.get_backend_field(field) - Examples: - >>> get_optimade_field("chemical_formula_anonymous") - 'chemical_formula_anonymous' - >>> get_optimade_field("formula_anon") - 'chemical_formula_anonymous' - >>> get_optimade_field("custom_provider_field") - '_exmpl_custom_provider_field' - - Returns: - The mapped field name to be used in an OPTIMADE-compliant response. - - """ - return {alias: real for real, alias in cls.all_aliases()}.get( - backend_field, backend_field - ) - - @classmethod - @lru_cache(maxsize=128) - def alias_of(cls, field: str) -> str: - """Return de-aliased field name, if it exists, - otherwise return the input field name. - - !!! warning "Deprecated" - This method is deprecated could be removed without further warning. Please use - [`get_optimade_field()`][optimade.server.mappers.entries.BaseResourceMapper.get_optimade_field]. - - Parameters: - field: Field name to be de-aliased. - - Returns: - De-aliased field name, falling back to returning `field`. - - """ + def alias_of(self, field: str) -> str: warnings.warn( - "The `.alias_of(...)` method is deprecated, please use `.get_optimade_field(...)`.", + "`.alias_of(...)` is deprecated; use `.get_optimade_field(...)`.", DeprecationWarning, ) - return cls.get_optimade_field(field) - - @classmethod - @lru_cache(maxsize=NUM_ENTRY_TYPES) - def get_required_fields(cls) -> set: - """Get REQUIRED response fields. - - Returns: - REQUIRED response fields. - - """ - return cls.TOP_LEVEL_NON_ATTRIBUTES_FIELDS + return self.get_optimade_field(field) - @classmethod - def map_back(cls, doc: dict) -> dict: - """Map properties from MongoDB to OPTIMADE. + def get_required_fields(self) -> set[str]: + return self.TOP_LEVEL_NON_ATTRIBUTES_FIELDS - Starting from a MongoDB document `doc`, map the DB fields to the corresponding OPTIMADE fields. - Then, the fields are all added to the top-level field "attributes", - with the exception of other top-level fields, defined in `cls.TOP_LEVEL_NON_ATTRIBUTES_FIELDS`. - All fields not in `cls.TOP_LEVEL_NON_ATTRIBUTES_FIELDS` + "attributes" will be removed. - Finally, the `type` is given the value of the specified `cls.ENDPOINT`. - - Parameters: - doc: A resource object in MongoDB format. - - Returns: - A resource object in OPTIMADE format. - - """ - mapping = ((real, alias) for alias, real in cls.all_aliases()) + def map_back(self, doc: dict) -> dict: + mapping = ((real, alias) for alias, real in self.all_aliases) newdoc = {} - reals = {real for _, real in cls.all_aliases()} + reals = {real for _, real in self.all_aliases} + for key in doc: if key not in reals: newdoc[key] = doc[key] @@ -351,28 +176,19 @@ def map_back(cls, doc: dict) -> dict: raise Exception("Will overwrite doc field!") attributes = newdoc.copy() - for field in cls.TOP_LEVEL_NON_ATTRIBUTES_FIELDS: + for field in self.TOP_LEVEL_NON_ATTRIBUTES_FIELDS: value = attributes.pop(field, None) if value is not None: newdoc[field] = value for field in list(newdoc.keys()): - if field not in cls.TOP_LEVEL_NON_ATTRIBUTES_FIELDS: + if field not in self.TOP_LEVEL_NON_ATTRIBUTES_FIELDS: del newdoc[field] - newdoc["type"] = cls.ENDPOINT + newdoc["type"] = self.ENDPOINT newdoc["attributes"] = attributes - return newdoc - @classmethod - def deserialize( - cls, results: dict | Iterable[dict] - ) -> list[EntryResource] | EntryResource: - """Converts the raw database entries for this class into serialized models, - mapping the data along the way. - - """ + def deserialize(self, results: dict | Iterable[dict]): if isinstance(results, dict): - return cls.ENTRY_RESOURCE_CLASS(**cls.map_back(results)) - - return [cls.ENTRY_RESOURCE_CLASS(**cls.map_back(doc)) for doc in results] + return self.ENTRY_RESOURCE_CLASS(**self.map_back(results)) + return [self.ENTRY_RESOURCE_CLASS(**self.map_back(doc)) for doc in results] diff --git a/optimade/server/mappers/links.py b/optimade/server/mappers/links.py index f0ddbddcf..9457a457a 100644 --- a/optimade/server/mappers/links.py +++ b/optimade/server/mappers/links.py @@ -7,8 +7,7 @@ class LinksMapper(BaseResourceMapper): ENTRY_RESOURCE_CLASS = LinksResource - @classmethod - def map_back(cls, doc: dict) -> dict: + def map_back(self, doc: dict) -> dict: """Map properties from MongoDB to OPTIMADE :param doc: A resource object in MongoDB format diff --git a/optimade/server/middleware.py b/optimade/server/middleware.py index 741a1957c..71c1ba470 100644 --- a/optimade/server/middleware.py +++ b/optimade/server/middleware.py @@ -20,7 +20,7 @@ from optimade.exceptions import BadRequest, VersionNotSupported from optimade.models import Warnings -from optimade.server.config import CONFIG +from optimade.server.config import ServerConfig from optimade.server.routers.utils import BASE_URL_PREFIXES, get_base_url from optimade.warnings import ( FieldValueNotRecognized, @@ -78,7 +78,7 @@ class CheckWronglyVersionedBaseUrls(BaseHTTPMiddleware): """If a non-supported versioned base URL is supplied return `553 Version Not Supported`.""" @staticmethod - def check_url(url: StarletteURL): + def check_url(config: ServerConfig, url: StarletteURL): """Check URL path for versioned part. Parameters: @@ -89,7 +89,7 @@ def check_url(url: StarletteURL): and the version part is not supported by the implementation. """ - base_url = get_base_url(url) + base_url = get_base_url(config, url) optimade_path = f"{url.scheme}://{url.netloc}{url.path}"[len(base_url) :] match = re.match(r"^(?P/v[0-9]+(\.[0-9]+){0,2}).*", optimade_path) if match is not None: @@ -103,8 +103,9 @@ def check_url(url: StarletteURL): ) async def dispatch(self, request: Request, call_next): + config = request.app.state.config if request.url.path: - self.check_url(request.url) + self.check_url(config, request.url) response = await call_next(request) return response @@ -196,7 +197,7 @@ def handle_api_hint(api_hint: list[str]) -> None | str: ) @staticmethod - def is_versioned_base_url(url: str) -> bool: + def is_versioned_base_url(config: ServerConfig, url: str) -> bool: """Determine whether a request is for a versioned base URL. First, simply check whether a `/vMAJOR(.MINOR.PATCH)` part exists in the URL. @@ -213,14 +214,15 @@ def is_versioned_base_url(url: str) -> bool: if not re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url): return False - base_url = get_base_url(url) + base_url = get_base_url(config, url) return bool(re.findall(r"(/v[0-9]+(\.[0-9]+){0,2})", url[len(base_url) :])) async def dispatch(self, request: Request, call_next): + config = request.app.state.config parsed_query = urllib.parse.parse_qs(request.url.query, keep_blank_values=True) if "api_hint" in parsed_query: - if self.is_versioned_base_url(str(request.url)): + if self.is_versioned_base_url(config, str(request.url)): warnings.warn( QueryParamNotUsed( detail=( @@ -239,7 +241,7 @@ async def dispatch(self, request: Request, call_next): version_path = self.handle_api_hint(parsed_query["api_hint"]) if version_path: - base_url = get_base_url(request.url) + base_url = get_base_url(config, request.url) new_request = ( f"{base_url}{version_path}{str(request.url)[len(base_url) :]}" @@ -312,6 +314,13 @@ class AddWarnings(BaseHTTPMiddleware): _warnings: list[Warnings] + def __init__(self, app, config: ServerConfig | None = None): + super().__init__(app) + self._warnings = [] + if config is None: + config = ServerConfig() + self._config = config + def showwarning( self, message: Warning | str, @@ -378,7 +387,7 @@ def showwarning( except AttributeError: detail = str(message) - if CONFIG.debug: + if self._config.debug: if line is None: # All this is taken directly from the warnings library. # See 'warnings._formatwarnmsg_impl()' for the original code. @@ -397,7 +406,7 @@ def showwarning( if line: meta["line"] = line.strip() - if CONFIG.debug: + if self._config.debug: new_warning = Warnings(title=title, detail=detail, meta=meta) else: new_warning = Warnings(title=title, detail=detail) @@ -429,6 +438,9 @@ def chunk_it_up(content: str | bytes, chunk_size: int) -> Generator: async def dispatch(self, request: Request, call_next): self._warnings = [] + # Stash config so self.showwarning() can reach it + self._config = request.app.state.config + warnings.simplefilter(action="default", category=OptimadeWarning) warnings.showwarning = self.showwarning diff --git a/optimade/server/query_params.py b/optimade/server/query_params.py index e37bcd8c4..8e684746f 100644 --- a/optimade/server/query_params.py +++ b/optimade/server/query_params.py @@ -7,7 +7,6 @@ from pydantic import EmailStr from optimade.exceptions import BadRequest -from optimade.server.config import CONFIG from optimade.server.mappers import BaseResourceMapper from optimade.warnings import QueryParamNotUsed, UnknownProviderQueryParameter @@ -48,8 +47,6 @@ def check_params(self, query_params: Iterable[str]) -> None: does not have a valid prefix. """ - if not getattr(CONFIG, "validate_query_parameters", False): - return errors = [] warnings = [] unsupported_warnings = [] @@ -60,9 +57,9 @@ def check_params(self, query_params: Iterable[str]) -> None: split_param = param.split("_") if param.startswith("_") and len(split_param) > 2: prefix = split_param[1] - if prefix in BaseResourceMapper.SUPPORTED_PREFIXES: + if prefix in BaseResourceMapper().SUPPORTED_PREFIXES: errors.append(param) - elif prefix not in BaseResourceMapper.KNOWN_PROVIDER_PREFIXES: + elif prefix not in BaseResourceMapper().KNOWN_PROVIDER_PREFIXES: warnings.append(param) else: errors.append(param) @@ -220,7 +217,7 @@ def __init__( description="Sets a numerical limit on the number of entries returned.\nSee [JSON API 1.0](https://jsonapi.org/format/1.0/#fetching-pagination).\nThe API implementation MUST return no more than the number specified.\nIt MAY return fewer.\nThe database MAY have a maximum limit and not accept larger numbers (in which case an error code -- 403 Forbidden -- MUST be returned).\nThe default limit value is up to the API implementation to decide.\nExample: `http://example.com/optimade/v1/structures?page_limit=100`", ge=0, ), - ] = CONFIG.page_limit, + ] = 20, page_offset: Annotated[ int, Query( diff --git a/optimade/server/routers/__init__.py b/optimade/server/routers/__init__.py index 266518dbe..e69de29bb 100644 --- a/optimade/server/routers/__init__.py +++ b/optimade/server/routers/__init__.py @@ -1,9 +0,0 @@ -from .links import links_coll -from .references import references_coll -from .structures import structures_coll - -ENTRY_COLLECTIONS = { - "links": links_coll, - "references": references_coll, - "structures": structures_coll, -} diff --git a/optimade/server/routers/index_info.py b/optimade/server/routers/index_info.py index 310aa2ea1..59f235a8e 100644 --- a/optimade/server/routers/index_info.py +++ b/optimade/server/routers/index_info.py @@ -8,7 +8,6 @@ IndexRelationship, RelatedLinksResource, ) -from optimade.server.config import CONFIG from optimade.server.routers.utils import get_base_url, meta_values from optimade.server.schemas import ERROR_RESPONSES @@ -23,13 +22,16 @@ responses=ERROR_RESPONSES, ) def get_info(request: Request) -> IndexInfoResponse: + config = request.app.state.config + return IndexInfoResponse( meta=meta_values( + config, request.url, 1, 1, more_data_available=False, - schema=CONFIG.index_schema_url, + schema=config.index_schema_url, ), data=IndexInfoResource( id=IndexInfoResource.model_fields["id"].default, @@ -38,7 +40,7 @@ def get_info(request: Request) -> IndexInfoResponse: api_version=f"{__api_version__}", available_api_versions=[ { - "url": f"{get_base_url(request.url)}/v{__api_version__.split('.')[0]}/", + "url": f"{get_base_url(config, request.url)}/v{__api_version__.split('.')[0]}/", "version": f"{__api_version__}", } ], @@ -51,7 +53,7 @@ def get_info(request: Request) -> IndexInfoResponse: "default": IndexRelationship( data={ "type": RelatedLinksResource.model_fields["type"].default, - "id": CONFIG.default_db, + "id": config.default_db, } ) }, diff --git a/optimade/server/routers/info.py b/optimade/server/routers/info.py index 28b8cda22..b3246ddac 100644 --- a/optimade/server/routers/info.py +++ b/optimade/server/routers/info.py @@ -6,7 +6,6 @@ from optimade import __api_version__ from optimade.models import EntryInfoResource, EntryInfoResponse, InfoResponse from optimade.models.baseinfo import BaseInfoAttributes, BaseInfoResource, Link -from optimade.server.config import CONFIG from optimade.server.routers.utils import get_base_url, meta_values from optimade.server.schemas import ( ENTRY_INFO_SCHEMAS, @@ -25,6 +24,8 @@ responses=ERROR_RESPONSES, ) def get_info(request: Request) -> InfoResponse: + config = request.app.state.config + @functools.lru_cache(maxsize=1) def _generate_info_response() -> BaseInfoResource: """Cached closure that generates the info response for the implementation.""" @@ -36,7 +37,7 @@ def _generate_info_response() -> BaseInfoResource: api_version=__api_version__, available_api_versions=[ { - "url": f"{get_base_url(request.url)}/v{__api_version__.split('.')[0]}", + "url": f"{get_base_url(config, request.url)}/v{__api_version__.split('.')[0]}", "version": __api_version__, } ], @@ -44,16 +45,21 @@ def _generate_info_response() -> BaseInfoResource: available_endpoints=["info", "links"] + list(ENTRY_INFO_SCHEMAS.keys()), entry_types_by_format={"json": list(ENTRY_INFO_SCHEMAS.keys())}, is_index=False, - license=Link(href=CONFIG.license) if CONFIG.license else None, - available_licenses=[str(CONFIG.license).split("/")[-1]] - if "https://spdx.org" in str(CONFIG.license) + license=Link(href=config.license) if config.license else None, + available_licenses=[str(config.license).split("/")[-1]] + if "https://spdx.org" in str(config.license) else None, ), ) return InfoResponse( meta=meta_values( - request.url, 1, 1, more_data_available=False, schema=CONFIG.schema_url + config, + request.url, + 1, + 1, + more_data_available=False, + schema=config.schema_url, ), data=_generate_info_response(), ) @@ -67,6 +73,8 @@ def _generate_info_response() -> BaseInfoResource: responses=ERROR_RESPONSES, ) def get_entry_info(request: Request, entry: str) -> EntryInfoResponse: + config = request.app.state.config + @functools.lru_cache(maxsize=len(ENTRY_INFO_SCHEMAS)) def _generate_entry_info_response(entry: str) -> EntryInfoResource: """Cached closure that generates the entry info response for the given type. @@ -89,7 +97,7 @@ def _generate_entry_info_response(entry: str) -> EntryInfoResource: schema = ENTRY_INFO_SCHEMAS[entry] queryable_properties = {"id", "type", "attributes"} properties = retrieve_queryable_properties( - schema, queryable_properties, entry_type=entry + schema, queryable_properties, entry_type=entry, config=config ) output_fields_by_format = {"json": list(properties)} @@ -104,7 +112,12 @@ def _generate_entry_info_response(entry: str) -> EntryInfoResource: return EntryInfoResponse( meta=meta_values( - request.url, 1, 1, more_data_available=False, schema=CONFIG.schema_url + config, + request.url, + 1, + 1, + more_data_available=False, + schema=config.schema_url, ), data=_generate_entry_info_response(entry), ) diff --git a/optimade/server/routers/landing.py b/optimade/server/routers/landing.py index e71ddf4e7..6875e24b2 100644 --- a/optimade/server/routers/landing.py +++ b/optimade/server/routers/landing.py @@ -1,6 +1,5 @@ """OPTIMADE landing page router.""" -from functools import lru_cache from pathlib import Path from fastapi import Request @@ -8,30 +7,41 @@ from starlette.routing import Route, Router from optimade import __api_version__ -from optimade.server.config import CONFIG -from optimade.server.routers import ENTRY_COLLECTIONS +from optimade.server.config import ServerConfig from optimade.server.routers.utils import get_base_url, meta_values - -@lru_cache -def render_landing_page(url: str) -> HTMLResponse: - """Render and cache the landing page. - - This function uses the template file `./static/landing_page.html`, adapted - from the original Jinja template. Instead of Jinja, some basic string - replacement is used to fill out the fields from the server configuration. - - !!! warning "Careful" - The removal of Jinja means that the fields are no longer validated as - web safe before inclusion in the template. - - """ - meta = meta_values(url, 1, 1, more_data_available=False, schema=CONFIG.schema_url) +# In-process cache: {(config_id, url, custom_mtime): HTMLResponse} +_PAGE_CACHE: dict[tuple[int, str, float | None], HTMLResponse] = {} + + +def _custom_file_mtime(config: ServerConfig) -> float | None: + custom = getattr(config, "custom_landing_page", None) + if not custom: + return None + p = Path(custom) + try: + return p.resolve().stat().st_mtime + except FileNotFoundError: + return None + + +def render_landing_page( + config: ServerConfig, entry_collections, url: str +) -> HTMLResponse: + """Render and cache the landing page with a manual, hashable key.""" + cache_key = (id(config), url, _custom_file_mtime(config)) + cached = _PAGE_CACHE.get(cache_key) + if cached is not None: + return cached + + meta = meta_values( + config, url, 1, 1, more_data_available=False, schema=config.schema_url + ) major_version = __api_version__.split(".")[0] - versioned_url = f"{get_base_url(url)}/v{major_version}/" + versioned_url = f"{get_base_url(config, url)}/v{major_version}/" - if CONFIG.custom_landing_page: - html = Path(CONFIG.custom_landing_page).resolve().read_text() + if config.custom_landing_page: + html = Path(config.custom_landing_page).resolve().read_text() else: template_dir = Path(__file__).parent.joinpath("static").resolve() html = (template_dir / "landing_page.html").read_text() @@ -47,7 +57,10 @@ def render_landing_page(url: str) -> HTMLResponse: "provider.name": meta.provider.name, "provider.prefix": meta.provider.prefix, "provider.description": meta.provider.description, - "provider.homepage": str(meta.provider.homepage) or "", + # avoid "None" string leaking into HTML + "provider.homepage": str(meta.provider.homepage) + if meta.provider.homepage + else "", } ) @@ -56,34 +69,40 @@ def render_landing_page(url: str) -> HTMLResponse: { "implementation.name": meta.implementation.name or "", "implementation.version": meta.implementation.version or "", - "implementation.source_url": str(meta.implementation.source_url or ""), + "implementation.source_url": str(meta.implementation.source_url) + if meta.implementation.source_url + else "", } ) - for replacement in replacements: - html = html.replace(f"{{{{ {replacement} }}}}", replacements[replacement]) + for k, v in replacements.items(): + html = html.replace(f"{{{{ {k} }}}}", v) - # Build the list of endpoints. The template already opens and closes the `