Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Parquet sink support #483

Merged
merged 15 commits into from
Apr 11, 2024
3 changes: 2 additions & 1 deletion kgx/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from kgx.config import get_logger, get_config
from kgx.cli.cli_utils import (
get_input_file_types,
get_output_file_types,
parse_source,
apply_operations,
graph_summary,
Expand Down Expand Up @@ -229,7 +230,7 @@ def validate_wrapper(
"--output-format",
"-f",
required=True,
help=f"The output format. Can be one of {get_input_file_types()}",
help=f"The output format. Can be one of {get_output_file_types()}",
)
@click.option(
"--output-compression", "-d", required=False, help="The output compression type"
Expand Down
9 changes: 5 additions & 4 deletions kgx/sink/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from .sink import Sink
from .tsv_sink import TsvSink
from .graph_sink import GraphSink
from .json_sink import JsonSink
from .jsonl_sink import JsonlSink
from .neo_sink import NeoSink
from .rdf_sink import RdfSink
from .graph_sink import GraphSink
from .null_sink import NullSink
from .sql_sink import SqlSink
from .parquet_sink import ParquetSink
from .rdf_sink import RdfSink
from .sql_sink import SqlSink
from .tsv_sink import TsvSink
115 changes: 115 additions & 0 deletions kgx/sink/parquet_sink.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
'''Sink for Parquet format.'''

from pathlib import Path
from typing import Any

import pandas as pd
from pyarrow import Table
from pyarrow.parquet import write_table

from kgx.sink.sink import Sink


# Node columns that ParquetSink registers when the caller does not pass
# a ``node_properties`` keyword argument.
DEFAULT_NODE_COLUMNS = {
    "id",
    "name",
    "category",
    "description",
    "provided_by",
}

# Edge columns that ParquetSink registers when the caller does not pass
# an ``edge_properties`` keyword argument.
DEFAULT_EDGE_COLUMNS = {
    "id",
    "subject",
    "predicate",
    "object",
    "relation",
    "category",
    "knowledge_source",
}


class ParquetSink(Sink):
    """
    Sink that buffers node and edge records in memory and, on
    ``finalize``, writes them out as two Parquet files placed next to
    each other: ``<stem>_nodes.parquet`` and ``<stem>_edges.parquet``,
    where ``<stem>`` is the stem of the resolved ``filename``.

    Parameters
    ----------
    owner: Transformer
        Transformer to which the ParquetSink belongs
    filename: str
        Path whose parent directory and stem determine where the
        node and edge Parquet files are written
    kwargs: Any
        Any additional arguments. ``node_properties`` and
        ``edge_properties`` (iterables of column names) are honoured;
        other keys are ignored by this class.
    """

    def __init__(
        self,
        owner,
        filename: str,
        **kwargs: Any
    ):
        super().__init__(owner)
        self.filename = filename

        # Resolve once and derive the directory / stem from the result.
        resolved = Path(filename).resolve()
        target_dir = resolved.parent
        stem = resolved.stem

        self.file_path = resolved
        self.dirname = target_dir
        self.basename = stem
        self.nodes_file_basename = f"{stem}_nodes.parquet"
        self.edges_file_basename = f"{stem}_edges.parquet"

        # Make sure the output directory exists before anything is written.
        target_dir.mkdir(parents=True, exist_ok=True)

        self.nodes_file_name = target_dir / self.nodes_file_basename
        self.edges_file_name = target_dir / self.edges_file_basename

        # NOTE(review): node_properties / edge_properties are presumably sets
        # created by Sink.__init__ -- confirm against kgx/sink/sink.py.
        node_columns = (
            set(kwargs["node_properties"])
            if "node_properties" in kwargs
            else DEFAULT_NODE_COLUMNS
        )
        self.node_properties.update(node_columns)

        edge_columns = (
            set(kwargs["edge_properties"])
            if "edge_properties" in kwargs
            else DEFAULT_EDGE_COLUMNS
        )
        self.edge_properties.update(edge_columns)

        # In-memory buffers; flushed to disk by finalize().
        self.nodes = []
        self.edges = []

    def write_node(self, record) -> None:
        """
        Buffer a node record; nothing is written to disk until
        ``finalize`` is called.

        Parameters
        ----------
        record: Any
            A node record

        """
        self.nodes.append(record)

    def write_edge(self, record) -> None:
        """
        Buffer an edge record; nothing is written to disk until
        ``finalize`` is called.

        Parameters
        ----------
        record: Any
            An edge record

        """
        self.edges.append(record)

    def finalize(self) -> None:
        """
        Convert the buffered node and edge records to Arrow tables,
        write them to the node and edge Parquet files, and reset the
        buffers.
        """
        # Build both frames, then both tables, before touching the disk so
        # that a conversion failure leaves no partially written output.
        frames = [pd.DataFrame(batch) for batch in (self.nodes, self.edges)]
        node_table, edge_table = [Table.from_pandas(frame) for frame in frames]

        write_table(node_table, self.nodes_file_name)
        write_table(edge_table, self.edges_file_name)

        self.nodes, self.edges = [], []
10 changes: 6 additions & 4 deletions kgx/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@
from kgx.sink import (
Sink,
GraphSink,
TsvSink,
JsonSink,
JsonlSink,
NeoSink,
RdfSink,
NullSink,
RdfSink,
SqlSink,
TsvSink,
ParquetSink
)
from kgx.utils.kgx_utils import (
apply_graph_operations,
Expand All @@ -52,15 +53,16 @@
}

SINK_MAP = {
"tsv": TsvSink,
"csv": TsvSink,
"graph": GraphSink,
"json": JsonSink,
"jsonl": JsonlSink,
"neo4j": NeoSink,
"nt": RdfSink,
"null": NullSink,
"sql": SqlSink
"sql": SqlSink,
"tsv": TsvSink,
"parquet": ParquetSink,
}


Expand Down
Loading
Loading