diff --git a/mltb2/arangodb.py b/mltb2/arangodb.py index 7af8292..19e4d70 100644 --- a/mltb2/arangodb.py +++ b/mltb2/arangodb.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Philip May +# Copyright (c) 2023-2024 Philip May # This software is distributed under the terms of the MIT license # which is available at https://opensource.org/licenses/MIT @@ -14,12 +14,13 @@ from argparse import ArgumentParser from contextlib import closing from dataclasses import dataclass -from typing import Dict, Optional, Sequence, Union +from typing import Any, Dict, Optional, Sequence, Union import jsonlines from arango import ArangoClient from arango.database import StandardDatabase from dotenv import dotenv_values +from pandas import DataFrame from tqdm import tqdm from mltb2.db import AbstractBatchDataManager @@ -36,7 +37,40 @@ def _check_config_keys(config: Dict[str, Optional[str]], expected_config_keys: S @dataclass -class ArangoBatchDataManager(AbstractBatchDataManager): +class ArangoConnectionManager: + """ArangoDB connection manager. + + Base class to manage / create ArangoDB connections. + + Args: + hosts: ArangoDB host or hosts. + db_name: ArangoDB database name. + username: ArangoDB username. + password: ArangoDB password. + """ + + hosts: Union[str, Sequence[str]] + db_name: str + username: str + password: str + + def _arango_client_factory(self) -> ArangoClient: + """Create an ArangoDB client.""" + arango_client = ArangoClient(hosts=self.hosts) + return arango_client + + def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase: + """Create an ArangoDB connection. + + Args: + arango_client: ArangoDB client. + """ + connection = arango_client.db(self.db_name, username=self.username, password=self.password) + return connection + + +@dataclass +class ArangoBatchDataManager(AbstractBatchDataManager, ArangoConnectionManager): """ArangoDB implementation of the ``AbstractBatchDataManager``. Args: @@ -52,10 +86,6 @@ class ArangoBatchDataManager(AbstractBatchDataManager): aql_overwrite: AQL string to overwrite the default. """ - hosts: Union[str, Sequence[str]] - db_name: str - username: str - password: str collection_name: str attribute_name: str batch_size: int = 20 @@ -117,20 +147,6 @@ def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None) aql_overwrite=aql_overwrite, ) - def _arango_client_factory(self) -> ArangoClient: - """Create an ArangoDB client.""" - arango_client = ArangoClient(hosts=self.hosts) - return arango_client - - def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase: - """Create an ArangoDB connection. - - Args: - arango_client: ArangoDB client. - """ - connection = arango_client.db(self.db_name, username=self.username, password=self.password) - return connection - def load_batch(self) -> Sequence: """Load a batch of data from the ArangoDB database. @@ -216,3 +232,107 @@ def arango_collection_backup() -> None: jsonlines_writer.write(doc) finally: cursor.close(ignore_missing=True) # type: ignore[union-attr] + + +@dataclass +class ArangoImportDataManager(ArangoConnectionManager): + """ArangoDB import tool to fill data into a collection. + + Args: + hosts: ArangoDB host or hosts. + db_name: ArangoDB database name. + username: ArangoDB username. + password: ArangoDB password. + """ + + @classmethod + def from_config_file(cls, config_file_name): + """Construct this from config file. + + The config file must contain at least these values: + + - ``hosts`` + - ``db_name`` + - ``username`` + - ``password`` + + Config file example: + + .. code-block:: + + hosts="https://arangodb.com" + db_name="my_ml_database" + username="my_username" + password="secret" + + Args: + config_file_name: The config file name (path). + """ + # load config file data + arango_config = dotenv_values(config_file_name) + + # check if all necessary keys are in config file + expected_config_file_keys = [ + "hosts", + "db_name", + "username", + "password", + ] + _check_config_keys(arango_config, expected_config_file_keys) + + return cls( + hosts=arango_config["hosts"], + db_name=arango_config["db_name"], + username=arango_config["username"], + password=arango_config["password"], + ) + + def import_dicts( + self, dicts: Sequence[Dict[str, Any]], collection_name: str, create_collection: bool = False + ) -> None: + """Import data to ArangoDB. + + Args: + dicts: The data to import. + collection_name: The collection name to import to. + create_collection: If ``True`` the collection is created if it does not exist. + Raises: + arango.exceptions.DocumentInsertError: If import fails. + """ + with closing(self._arango_client_factory()) as arango_client: + connection = self._connection_factory(arango_client) + + # get (or create) collection + if not connection.has_collection(collection_name): + if create_collection: + collection = connection.create_collection(collection_name) + else: + raise ValueError( + f"Collection '{collection_name}' does not exist! " + "Create it or specify 'create_collection=True'." + ) + else: + collection = connection.collection(collection_name) + + collection.import_bulk( # type: ignore[union-attr] + dicts, + halt_on_error=True, + details=False, + overwrite=False, + on_duplicate="error", + sync=True, + batch_size=100, + ) + + def import_dataframe(self, dataframe: DataFrame, collection_name: str, create_collection: bool = False) -> None: + """Import Pandas data to ArangoDB. + + Args: + dataframe: The Pandas data to import. + collection_name: The collection name to import to. + create_collection: If ``True`` the collection is created if it does not exist. + Raises: + arango.exceptions.DocumentInsertError: If import fails. + """ + dicts = dataframe.to_dict(orient="records") + self.import_dicts(dicts, collection_name, create_collection)