Skip to content

Commit

Permalink
Add ArangoDB upload tool. (#141)
Browse files Browse the repository at this point in the history
* Refactor ArangoDB data manager classes and add import functionality

* Update copyright year in arangodb.py

* Refactor ArangoDataManager to ArangoConnectionManager
  • Loading branch information
PhilipMay authored Jan 6, 2024
1 parent a1e6fa0 commit 991c75f
Showing 1 changed file with 141 additions and 21 deletions.
162 changes: 141 additions & 21 deletions mltb2/arangodb.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2023 Philip May
# Copyright (c) 2023-2024 Philip May
# This software is distributed under the terms of the MIT license
# which is available at https://opensource.org/licenses/MIT

Expand All @@ -14,12 +14,13 @@
from argparse import ArgumentParser
from contextlib import closing
from dataclasses import dataclass
from typing import Dict, Optional, Sequence, Union
from typing import Any, Dict, Optional, Sequence, Union

import jsonlines
from arango import ArangoClient
from arango.database import StandardDatabase
from dotenv import dotenv_values
from pandas import DataFrame
from tqdm import tqdm

from mltb2.db import AbstractBatchDataManager
Expand All @@ -36,7 +37,40 @@ def _check_config_keys(config: Dict[str, Optional[str]], expected_config_keys: S


@dataclass
class ArangoBatchDataManager(AbstractBatchDataManager):
class ArangoConnectionManager:
"""ArangoDB connection manager.
Base class to manage / create ArangoDB connections.
Args:
hosts: ArangoDB host or hosts.
db_name: ArangoDB database name.
username: ArangoDB username.
password: ArangoDB password.
"""

hosts: Union[str, Sequence[str]]
db_name: str
username: str
password: str

def _arango_client_factory(self) -> ArangoClient:
"""Create an ArangoDB client."""
arango_client = ArangoClient(hosts=self.hosts)
return arango_client

def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase:
"""Create an ArangoDB connection.
Args:
arango_client: ArangoDB client.
"""
connection = arango_client.db(self.db_name, username=self.username, password=self.password)
return connection


@dataclass
class ArangoBatchDataManager(AbstractBatchDataManager, ArangoConnectionManager):
"""ArangoDB implementation of the ``AbstractBatchDataManager``.
Args:
Expand All @@ -52,10 +86,6 @@ class ArangoBatchDataManager(AbstractBatchDataManager):
aql_overwrite: AQL string to overwrite the default.
"""

hosts: Union[str, Sequence[str]]
db_name: str
username: str
password: str
collection_name: str
attribute_name: str
batch_size: int = 20
Expand Down Expand Up @@ -117,20 +147,6 @@ def from_config_file(cls, config_file_name, aql_overwrite: Optional[str] = None)
aql_overwrite=aql_overwrite,
)

def _arango_client_factory(self) -> ArangoClient:
"""Create an ArangoDB client."""
arango_client = ArangoClient(hosts=self.hosts)
return arango_client

def _connection_factory(self, arango_client: ArangoClient) -> StandardDatabase:
"""Create an ArangoDB connection.
Args:
arango_client: ArangoDB client.
"""
connection = arango_client.db(self.db_name, username=self.username, password=self.password)
return connection

def load_batch(self) -> Sequence:
"""Load a batch of data from the ArangoDB database.
Expand Down Expand Up @@ -216,3 +232,107 @@ def arango_collection_backup() -> None:
jsonlines_writer.write(doc)
finally:
cursor.close(ignore_missing=True) # type: ignore[union-attr]


@dataclass
class ArangoImportDataManager(ArangoConnectionManager):
"""ArangoDB import tool to fill data into a collection.
Args:
hosts: ArangoDB host or hosts.
db_name: ArangoDB database name.
username: ArangoDB username.
password: ArangoDB password.
"""

@classmethod
def from_config_file(cls, config_file_name):
"""Construct this from config file.
The config file must contain at least these values:
- ``hosts``
- ``db_name``
- ``username``
- ``password``
Config file example:
.. code-block::
hosts="https://arangodb.com"
db_name="my_ml_database"
username="my_username"
password="secret"
Args:
config_file_name: The config file name (path).
"""
# load config file data
arango_config = dotenv_values(config_file_name)

# check if all necessary keys are in config file
expected_config_file_keys = [
"hosts",
"db_name",
"username",
"password",
]
_check_config_keys(arango_config, expected_config_file_keys)

return cls(
hosts=arango_config["hosts"],
db_name=arango_config["db_name"],
username=arango_config["username"],
password=arango_config["password"],
)

def import_dicts(
self, dicts: Sequence[Dict[str, Any]], collection_name: str, create_collection: bool = False
) -> None:
"""Import data to ArangoDB.
Args:
dicts: The data to import.
collection_name: The collection name to import to.
create_collection: If ``True`` the collection is created if it does not exist.
Raises:
arango.exceptions.DocumentInsertError: If import fails.
"""
with closing(self._arango_client_factory()) as arango_client:
connection = self._connection_factory(arango_client)

# get (or create) collection
if not connection.has_collection(collection_name):
if create_collection:
collection = connection.create_collection(collection_name)
else:
raise ValueError(
f"Collection '{collection_name}' does not exist! "
"Create it or specify 'create_collection=True'."
)
else:
collection = connection.collection(collection_name)

collection.import_bulk( # type: ignore[union-attr]
dicts,
halt_on_error=True,
details=False,
overwrite=False,
on_duplicate="error",
sync=True,
batch_size=100,
)

def import_dataframe(self, dataframe: DataFrame, collection_name: str, create_collection: bool = False) -> None:
"""Import Pandas data to ArangoDB.
Args:
dataframe: The Pandas data to import.
collection_name: The collection name to import to.
create_collection: If ``True`` the collection is created if it does not exist.
Raises:
arango.exceptions.DocumentInsertError: If import fails.
"""
dicts = dataframe.to_dict(orient="records")
self.import_dicts(dicts, collection_name, create_collection)

0 comments on commit 991c75f

Please sign in to comment.