Skip to content

Commit

Permalink
Support skip/limit options for pandas scan (#4662)
Browse files Browse the repository at this point in the history
  • Loading branch information
royi-luo authored Jan 7, 2025
1 parent 59ed5a8 commit 6caaad7
Show file tree
Hide file tree
Showing 9 changed files with 177 additions and 59 deletions.
1 change: 1 addition & 0 deletions tools/python_api/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pybind11_add_module(_kuzu
src_cpp/py_prepared_statement.cpp
src_cpp/py_query_result.cpp
src_cpp/py_query_result_converter.cpp
src_cpp/py_scan_config.cpp
src_cpp/py_udf.cpp
src_cpp/py_conversion.cpp
src_cpp/pyarrow/pyarrow_bind.cpp
Expand Down
21 changes: 9 additions & 12 deletions tools/python_api/src_cpp/include/pandas/pandas_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "function/table/scan_functions.h"
#include "function/table_functions.h"
#include "pandas_bind.h"
#include "py_scan_config.h"

namespace kuzu {

Expand All @@ -15,9 +16,10 @@ struct PandasScanLocalState final : public function::TableFuncLocalState {
};

struct PandasScanSharedState final : public function::BaseScanSharedStateWithNumRows {
explicit PandasScanSharedState(uint64_t numRows)
: BaseScanSharedStateWithNumRows{numRows}, numRowsRead{0} {}
PandasScanSharedState(uint64_t startRow, uint64_t numRows)
: BaseScanSharedStateWithNumRows{numRows}, startRow(startRow), numRowsRead{0} {}

uint64_t startRow;
uint64_t numRowsRead;
};

Expand All @@ -31,23 +33,19 @@ struct PandasScanFunction {
struct PandasScanFunctionData : public function::TableFuncBindData {
py::handle df;
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData;
common::FileScanInfo fileScanInfo;
PyScanConfig scanConfig;

PandasScanFunctionData(binder::expression_vector columns, py::handle df, uint64_t numRows,
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData,
common::FileScanInfo fileScanInfo)
std::vector<std::unique_ptr<PandasColumnBindData>> columnBindData, PyScanConfig scanConfig)
: TableFuncBindData{std::move(columns), 0 /* numWarningDataColumns */, numRows}, df{df},
columnBindData{std::move(columnBindData)}, fileScanInfo(std::move(fileScanInfo)) {}
columnBindData{std::move(columnBindData)}, scanConfig(scanConfig) {}

~PandasScanFunctionData() override {
py::gil_scoped_acquire acquire;
columnBindData.clear();
}

bool getIgnoreErrorsOption() const override {
return fileScanInfo.getOption(common::CopyConstants::IGNORE_ERRORS_OPTION_NAME,
common::CopyConstants::DEFAULT_IGNORE_ERRORS);
}
bool getIgnoreErrorsOption() const override { return scanConfig.ignoreErrors; }

std::vector<std::unique_ptr<PandasColumnBindData>> copyColumnBindData() const;

Expand All @@ -57,11 +55,10 @@ struct PandasScanFunctionData : public function::TableFuncBindData {

private:
PandasScanFunctionData(const PandasScanFunctionData& other)
: TableFuncBindData{other}, df{other.df} {
: TableFuncBindData{other}, df{other.df}, scanConfig(other.scanConfig) {
for (const auto& i : other.columnBindData) {
columnBindData.push_back(i->copy());
}
fileScanInfo = other.fileScanInfo.copy();
}
};

Expand Down
16 changes: 16 additions & 0 deletions tools/python_api/src_cpp/include/py_scan_config.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
#pragma once

#include "common/case_insensitive_map.h"
#include "common/types/value/value.h"

namespace kuzu {

// Parsed scan options (SKIP / LIMIT / IGNORE_ERRORS) for scans over Python objects.
struct PyScanConfig {
    // Builds the config from the user-supplied option map.
    // `numRows` is the row count of the object being scanned; it is used as the upper bound
    // for `skipNum`. Throws common::BinderException when an option has the wrong type, a
    // negative SKIP/LIMIT value, or an unrecognized name.
    explicit PyScanConfig(const common::case_insensitive_map_t<common::Value>& options,
        uint64_t numRows);

    // Number of leading rows to skip; never larger than the `numRows` given to the constructor.
    uint64_t skipNum;
    // Maximum number of rows to read; the uint64_t maximum when no LIMIT option is given.
    uint64_t limitNum;
    // Value of the IGNORE_ERRORS option, or CopyConstants::DEFAULT_IGNORE_ERRORS when absent.
    bool ignoreErrors;
};

} // namespace kuzu
7 changes: 0 additions & 7 deletions tools/python_api/src_cpp/include/pyarrow/pyarrow_scan.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,6 @@

namespace kuzu {

struct PyArrowScanConfig {
uint64_t skipNum;
uint64_t limitNum;
bool ignoreErrors;
explicit PyArrowScanConfig(const common::case_insensitive_map_t<common::Value>& options);
};

struct PyArrowTableScanLocalState final : public function::TableFuncLocalState {
ArrowArrayWrapper* arrowArray;

Expand Down
19 changes: 12 additions & 7 deletions tools/python_api/src_cpp/pandas/pandas_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "numpy/numpy_scan.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "py_scan_config.h"
#include "pyarrow/pyarrow_scan.h"
#include "pybind11/pytypes.h"

Expand All @@ -32,10 +33,13 @@ std::unique_ptr<TableFuncBindData> bindFunc(ClientContext* /*context*/,
auto getFunc = df.attr("__getitem__");
auto numRows = py::len(getFunc(columns[0]));
auto returnColumns = input->binder->createVariables(names, returnTypes);
auto scanConfig =
input->extraInput->constPtrCast<ExtraScanTableFuncBindInput>()->fileScanInfo.copy();
return std::make_unique<PandasScanFunctionData>(std::move(returnColumns), df, numRows,
std::move(columnBindData), std::move(scanConfig));
auto scanConfig = PyScanConfig{
input->extraInput->constPtrCast<ExtraScanTableFuncBindInput>()->fileScanInfo.options,
numRows};
KU_ASSERT(numRows >= scanConfig.skipNum);
return std::make_unique<PandasScanFunctionData>(std::move(returnColumns), df,
std::min(numRows - scanConfig.skipNum, scanConfig.limitNum), std::move(columnBindData),
scanConfig);
}

bool sharedStateNext(const TableFuncBindData* /*bindData*/, PandasScanLocalState* localState,
Expand All @@ -45,11 +49,11 @@ bool sharedStateNext(const TableFuncBindData* /*bindData*/, PandasScanLocalState
if (pandasSharedState->numRowsRead >= pandasSharedState->numRows) {
return false;
}
localState->start = pandasSharedState->numRowsRead;
localState->start = pandasSharedState->startRow + pandasSharedState->numRowsRead;
pandasSharedState->numRowsRead +=
std::min(pandasSharedState->numRows - pandasSharedState->numRowsRead,
CopyConstants::PANDAS_PARTITION_COUNT);
localState->end = pandasSharedState->numRowsRead;
localState->end = pandasSharedState->startRow + pandasSharedState->numRowsRead;
return true;
}

Expand All @@ -67,7 +71,8 @@ std::unique_ptr<TableFuncSharedState> initSharedState(const TableFunctionInitInp
}
// LCOV_EXCL_STOP
auto scanBindData = ku_dynamic_cast<PandasScanFunctionData*>(input.bindData);
return std::make_unique<PandasScanSharedState>(scanBindData->cardinality);
return std::make_unique<PandasScanSharedState>(scanBindData->scanConfig.skipNum,
scanBindData->cardinality);
}

void pandasBackendScanSwitch(PandasColumnBindData* bindData, uint64_t count, uint64_t offset,
Expand Down
39 changes: 39 additions & 0 deletions tools/python_api/src_cpp/py_scan_config.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include "py_scan_config.h"

#include "common/constants.h"
#include "common/exception/binder.h"
#include "function/cast/functions/numeric_limits.h"

namespace kuzu {

PyScanConfig::PyScanConfig(const common::case_insensitive_map_t<common::Value>& options,
uint64_t numRows) {
skipNum = 0;
limitNum = function::NumericLimits<uint64_t>::maximum();
ignoreErrors = common::CopyConstants::DEFAULT_IGNORE_ERRORS;
for (const auto& i : options) {
if (i.first == "SKIP") {
if (i.second.getDataType().getLogicalTypeID() != common::LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw common::BinderException("SKIP Option must be a positive integer literal.");
}
skipNum = std::min(numRows, static_cast<uint64_t>(i.second.val.int64Val));
} else if (i.first == "LIMIT") {
if (i.second.getDataType().getLogicalTypeID() != common::LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw common::BinderException("LIMIT Option must be a positive integer literal.");
}
limitNum = i.second.val.int64Val;
} else if (i.first == common::CopyConstants::IGNORE_ERRORS_OPTION_NAME) {
if (i.second.getDataType().getLogicalTypeID() != common::LogicalTypeID::BOOL) {
throw common::BinderException("IGNORE_ERRORS Option must be a boolean.");
}
ignoreErrors = i.second.val.booleanVal;
} else {
throw common::BinderException(
common::stringFormat("{} Option not recognized by pyArrow scanner.", i.first));
}
}
}

} // namespace kuzu
32 changes: 2 additions & 30 deletions tools/python_api/src_cpp/pyarrow/pyarrow_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "function/table/bind_input.h"
#include "processor/execution_context.h"
#include "py_connection.h"
#include "py_scan_config.h"
#include "pyarrow/pyarrow_bind.h"
#include "pybind11/pytypes.h"

Expand All @@ -16,35 +17,6 @@ using namespace kuzu::catalog;

namespace kuzu {

PyArrowScanConfig::PyArrowScanConfig(const case_insensitive_map_t<Value>& options) {
skipNum = 0;
limitNum = NumericLimits<uint64_t>::maximum();
ignoreErrors = CopyConstants::DEFAULT_IGNORE_ERRORS;
for (const auto& i : options) {
if (i.first == "SKIP") {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw BinderException("SKIP Option must be a positive integer literal.");
}
skipNum = i.second.val.int64Val;
} else if (i.first == "LIMIT") {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::INT64 ||
i.second.val.int64Val < 0) {
throw BinderException("LIMIT Option must be a positive integer literal.");
}
limitNum = i.second.val.int64Val;
} else if (i.first == CopyConstants::IGNORE_ERRORS_OPTION_NAME) {
if (i.second.getDataType().getLogicalTypeID() != LogicalTypeID::BOOL) {
throw BinderException("IGNORE_ERRORS Option must be a boolean.");
}
ignoreErrors = i.second.val.booleanVal;
} else {
throw BinderException(
stringFormat("{} Option not recognized by pyArrow scanner.", i.first));
}
}
}

template<typename T>
static bool moduleIsLoaded() {
auto dict = pybind11::module_::import("sys").attr("modules");
Expand Down Expand Up @@ -73,7 +45,7 @@ static std::unique_ptr<TableFuncBindData> bindFunc(ClientContext*,
}
auto numRows = py::len(table);
auto schema = Pyarrow::bind(table, returnTypes, names);
auto config = PyArrowScanConfig(scanInput->fileScanInfo.options);
auto config = PyScanConfig(scanInput->fileScanInfo.options, numRows);
// The following python operations are zero copy as defined in pyarrow docs.
if (config.skipNum != 0) {
table = table.attr("slice")(config.skipNum);
Expand Down
89 changes: 86 additions & 3 deletions tools/python_api/test/test_scan_pandas.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,8 @@ def test_scan_pandas(tmp_path: Path) -> None:
"INT32": np.array([-100, -200, -300, -400], dtype=np.int32),
"INT64": np.array([-1000, -2000, -3000, -4000], dtype=np.int64),
"FLOAT_32": np.array(
[-0.5199999809265137, float("nan"), -3.299999952316284, 4.400000095367432], dtype=np.float32
[-0.5199999809265137, float("nan"), -3.299999952316284, 4.400000095367432],
dtype=np.float32,
),
"FLOAT_64": np.array([5132.12321, 24.222, float("nan"), 4.444], dtype=np.float64),
"datetime_microseconds": np.array([
Expand Down Expand Up @@ -312,8 +313,18 @@ def test_pandas_scan_demo(tmp_path: Path) -> None:
"height_in_inch RETURN s"
).get_as_df()
assert len(result) == 2
assert result["s"][0] == {"ID": 0, "_id": {"offset": 0, "table": 0}, "_label": "student", "height": 70}
assert result["s"][1] == {"ID": 4, "_id": {"offset": 2, "table": 0}, "_label": "student", "height": 67}
assert result["s"][0] == {
"ID": 0,
"_id": {"offset": 0, "table": 0},
"_label": "student",
"height": 70,
}
assert result["s"][1] == {
"ID": 4,
"_id": {"offset": 2, "table": 0},
"_label": "student",
"height": 67,
}

conn.execute("CREATE NODE TABLE person(ID INT64, age UINT16, height UINT32, is_student BOOLean, PRIMARY KEY(ID))")
conn.execute("LOAD FROM person CREATE (p:person {ID: id, age: age, height: height, is_student: is_student})")
Expand Down Expand Up @@ -402,6 +413,78 @@ def test_copy_from_pandas_object(tmp_path: Path) -> None:
assert result.has_next() is False


def test_copy_from_pandas_object_skip(tmp_path: Path) -> None:
    """COPY FROM a DataFrame with SKIP drops the leading rows, for node and rel tables."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The COPY statements refer to the frame by the name `df`, so the local keeps that name.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(SKIP=2);")
    people = connection.execute("match (p:Person) return p.*")
    # Only the last two of the four rows should have been copied.
    for expected_row in (["Zhang", "50"], ["Noura", "25"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False

    df = pd.DataFrame({"f": ["Adam", "Noura"], "t": ["Zhang", "Zhang"]})
    connection.execute("CREATE REL TABLE Knows(FROM Person TO Person);")
    connection.execute("COPY Knows FROM df(SKIP=1)")
    sources = connection.execute("match (p:Person)-[]->(:Person {name: 'Zhang'}) return p.*")
    # The first edge row (from Adam) is skipped, leaving only the edge from Noura.
    assert sources.get_next() == ["Noura", "25"]
    assert sources.has_next() is False


def test_copy_from_pandas_object_limit(tmp_path: Path) -> None:
    """COPY FROM a DataFrame with LIMIT keeps only the leading rows, for node and rel tables."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The COPY statements refer to the frame by the name `df`, so the local keeps that name.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(LIMIT=2);")
    people = connection.execute("match (p:Person) return p.*")
    # Only the first two of the four rows should have been copied.
    for expected_row in (["Adam", "30"], ["Karissa", "40"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False

    df = pd.DataFrame({"f": ["Adam", "Zhang"], "t": ["Karissa", "Karissa"]})
    connection.execute("CREATE REL TABLE Knows(FROM Person TO Person);")
    connection.execute("COPY Knows FROM df(LIMIT=1)")
    sources = connection.execute("match (p:Person)-[]->(:Person {name: 'Karissa'}) return p.*")
    # Only the first edge row (from Adam) falls inside the limit.
    assert sources.get_next() == ["Adam", "30"]
    assert sources.has_next() is False


def test_copy_from_pandas_object_skip_and_limit(tmp_path: Path) -> None:
    """SKIP and LIMIT combined select a window from the middle of the DataFrame."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The COPY statement refers to the frame by the name `df`, so the local keeps that name.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    connection.execute("COPY Person FROM df(SKIP=1, LIMIT=2);")
    people = connection.execute("match (p:Person) return p.*")
    # One row skipped, then two rows taken: rows 2 and 3 of the frame.
    for expected_row in (["Karissa", "40"], ["Zhang", "50"]):
        assert people.get_next() == expected_row
    assert people.has_next() is False


def test_copy_from_pandas_object_skip_bounds_check(tmp_path: Path) -> None:
    """A SKIP larger than the DataFrame's row count copies nothing."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The COPY statement refers to the frame by the name `df`, so the local keeps that name.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    # SKIP=10 exceeds the four available rows.
    connection.execute("COPY Person FROM df(SKIP=10);")
    people = connection.execute("match (p:Person) return p.*")
    assert people.has_next() is False


def test_copy_from_pandas_object_limit_bounds_check(tmp_path: Path) -> None:
    """A LIMIT larger than the DataFrame's row count copies every row."""
    database = kuzu.Database(tmp_path)
    connection = kuzu.Connection(database)

    # The COPY statement refers to the frame by the name `df`, so the local keeps that name.
    df = pd.DataFrame({"name": ["Adam", "Karissa", "Zhang", "Noura"], "age": [30, 40, 50, 25]})
    connection.execute("CREATE NODE TABLE Person(name STRING, age STRING, PRIMARY KEY (name));")
    # LIMIT=10 exceeds the four available rows.
    connection.execute("COPY Person FROM df(LIMIT=10);")
    people = connection.execute("match (p:Person) return p.*")
    all_rows = (
        ["Adam", "30"],
        ["Karissa", "40"],
        ["Zhang", "50"],
        ["Noura", "25"],
    )
    for expected_row in all_rows:
        assert people.get_next() == expected_row
    assert people.has_next() is False


def test_copy_from_pandas_date(tmp_path: Path) -> None:
db = kuzu.Database(tmp_path)
conn = kuzu.Connection(db)
Expand Down
12 changes: 12 additions & 0 deletions tools/python_api/test/test_scan_pandas_pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,18 @@ def test_pyarrow_skip_limit(conn_db_readonly: ConnDB) -> None:
assert result["col1"].to_pylist() == expected["col1"].to_pylist()
assert result["col2"].to_pylist() == expected["col2"].to_pylist()

# skip bounds check
result = conn.execute("LOAD FROM df (SKIP=500000, LIMIT=5000) RETURN * ORDER BY index").get_as_arrow()
assert len(result) == 0

# limit bounds check
result = conn.execute("LOAD FROM df (SKIP=0, LIMIT=500000) RETURN * ORDER BY index").get_as_arrow()
expected = pa.Table.from_pandas(df)
assert result["index"].to_pylist() == expected["index"].to_pylist()
assert result["col0"].to_pylist() == expected["col0"].to_pylist()
assert result["col1"].to_pylist() == expected["col1"].to_pylist()
assert result["col2"].to_pylist() == expected["col2"].to_pylist()


def test_pyarrow_invalid_skip_limit(conn_db_readonly: ConnDB) -> None:
conn, db = conn_db_readonly
Expand Down

0 comments on commit 6caaad7

Please sign in to comment.