From d90eadcaf5885147472a97f481b3261bc3bf0815 Mon Sep 17 00:00:00 2001 From: Max Gabrielsson Date: Mon, 11 Sep 2023 18:08:23 +0200 Subject: [PATCH] update duckdb, add geos aggregates --- duckdb | 2 +- .../spatial/core/functions/aggregate.hpp | 4 +- .../include/spatial/geos/geos_wrappers.hpp | 3 + spatial/src/spatial/core/module.cpp | 2 +- .../src/spatial/geos/functions/aggregate.cpp | 185 ++++++++++++++++++ spatial/src/spatial/geos/geos_wrappers.cpp | 22 ++- 6 files changed, 208 insertions(+), 10 deletions(-) diff --git a/duckdb b/duckdb index a8ce02cc..9db510bd 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit a8ce02cc2e740d8973d26ccdb77d0068c69c9124 +Subproject commit 9db510bd1105f883747baca317a9b63adedf0d8e diff --git a/spatial/include/spatial/core/functions/aggregate.hpp b/spatial/include/spatial/core/functions/aggregate.hpp index bafc3573..3e90763a 100644 --- a/spatial/include/spatial/core/functions/aggregate.hpp +++ b/spatial/include/spatial/core/functions/aggregate.hpp @@ -6,7 +6,9 @@ namespace spatial { namespace core { struct CoreAggregateFunctions { - static void Register(ClientContext &context); +public: + static void Register(ClientContext &context) { + } }; } // namespace core diff --git a/spatial/include/spatial/geos/geos_wrappers.hpp b/spatial/include/spatial/geos/geos_wrappers.hpp index 9f651d9d..80bd1954 100644 --- a/spatial/include/spatial/geos/geos_wrappers.hpp +++ b/spatial/include/spatial/geos/geos_wrappers.hpp @@ -206,6 +206,9 @@ struct GeosContextWrapper { string_t Serialize(Vector &result, const unique_ptr> &geom); }; +GEOSGeometry *DeserializeGEOSGeometry(const string_t &blob, GEOSContextHandle_t ctx); +string_t SerializeGEOSGeometry(Vector &result, const GEOSGeometry *geom, GEOSContextHandle_t ctx); + } // namespace geos } // namespace spatial \ No newline at end of file diff --git a/spatial/src/spatial/core/module.cpp b/spatial/src/spatial/core/module.cpp index 35ba203d..db0014ef 100644 --- a/spatial/src/spatial/core/module.cpp +++ b/spatial/src/spatial/core/module.cpp @@ -18,7 +18,7 @@ void CoreModule::Register(ClientContext &context) { CoreScalarFunctions::Register(context); CoreCastFunctions::Register(context); CoreTableFunctions::Register(context); - // CoreAggregateFunctions::Register(context); + CoreAggregateFunctions::Register(context); } } // namespace core diff --git a/spatial/src/spatial/geos/functions/aggregate.cpp b/spatial/src/spatial/geos/functions/aggregate.cpp index 7c83f92f..2ab0797f 100644 --- a/spatial/src/spatial/geos/functions/aggregate.cpp +++ b/spatial/src/spatial/geos/functions/aggregate.cpp @@ -1,12 +1,197 @@ +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" #include "spatial/common.hpp" #include "spatial/geos/functions/aggregate.hpp" +#include "spatial/geos/geos_wrappers.hpp" + +#include "geos_c.h" namespace spatial { namespace geos { +struct GEOSAggState { + GEOSGeometry *geom = nullptr; + GEOSContextHandle_t context = nullptr; + + ~GEOSAggState() { + if (geom) { + GEOSGeom_destroy_r(context, geom); + geom = nullptr; + } + if (context) { + GEOS_finish_r(context); + context = nullptr; + } + } +}; + +//------------------------------------------------------------------------ +// INTERSECTION +//------------------------------------------------------------------------ +struct IntersectionAggFunction { + template + static void Initialize(STATE &state) { + state.geom = nullptr; + state.context = GEOS_init_r(); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &data) { + if (!source.geom) { + return; + } + if (!target.geom) { + target.geom = GEOSGeom_clone_r(target.context, source.geom); + return; + } + auto curr = target.geom; + target.geom = GEOSIntersection_r(target.context, curr, source.geom); + GEOSGeom_destroy_r(target.context, curr); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + if (!state.geom) { + state.geom = DeserializeGEOSGeometry(input, state.context); + } else { + auto next = DeserializeGEOSGeometry(input, state.context); + auto curr = state.geom; + state.geom = GEOSIntersection_r(state.context, curr, next); + GEOSGeom_destroy_r(state.context, next); + GEOSGeom_destroy_r(state.context, curr); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + // There is no point in doing anything else, intersection is idempotent + if (!state.geom) { + state.geom = DeserializeGEOSGeometry(input, state.context); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.geom) { + finalize_data.ReturnNull(); + } else { + target = SerializeGEOSGeometry(finalize_data.result, state.geom, state.context); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &) { + if (state.geom) { + GEOSGeom_destroy_r(state.context, state.geom); + state.geom = nullptr; + } + if (state.context) { + GEOS_finish_r(state.context); + state.context = nullptr; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +//------------------------------------------------------------------------ +// UNION +//------------------------------------------------------------------------ + +struct UnionAggFunction { + template + static void Initialize(STATE &state) { + state.geom = nullptr; + state.context = GEOS_init_r(); + } + + template + static void Combine(const STATE &source, STATE &target, AggregateInputData &data) { + if (!source.geom) { + return; + } + if (!target.geom) { + target.geom = GEOSGeom_clone_r(target.context, source.geom); + return; + } + auto curr = target.geom; + target.geom = GEOSUnion_r(target.context, curr, source.geom); + GEOSGeom_destroy_r(target.context, curr); + } + + template + static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { + if (!state.geom) { + state.geom = DeserializeGEOSGeometry(input, state.context); + } else { + auto next = DeserializeGEOSGeometry(input, state.context); + auto curr = state.geom; + state.geom = GEOSUnion_r(state.context, curr, next); + GEOSGeom_destroy_r(state.context, next); + GEOSGeom_destroy_r(state.context, curr); + } + } + + template + static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { + // There is no point in doing anything else, union is idempotent + if (!state.geom) { + state.geom = DeserializeGEOSGeometry(input, state.context); + } + } + + template + static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { + if (!state.geom) { + finalize_data.ReturnNull(); + } else { + target = SerializeGEOSGeometry(finalize_data.result, state.geom, state.context); + } + } + + template + static void Destroy(STATE &state, AggregateInputData &) { + if (state.geom) { + GEOSGeom_destroy_r(state.context, state.geom); + state.geom = nullptr; + } + if (state.context) { + GEOS_finish_r(state.context); + state.context = nullptr; + } + } + + static bool IgnoreNull() { + return true; + } +}; + +//------------------------------------------------------------------------ +// Register +//------------------------------------------------------------------------ void GeosAggregateFunctions::Register(ClientContext &context) { + + auto &catalog = Catalog::GetSystemCatalog(context); + + AggregateFunctionSet st_intersection_agg("st_intersection_agg"); + st_intersection_agg.AddFunction( + AggregateFunction::UnaryAggregateDestructor( + core::GeoTypes::GEOMETRY(), core::GeoTypes::GEOMETRY())); + CreateAggregateFunctionInfo intersection_info(std::move(st_intersection_agg)); + intersection_info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; + catalog.CreateFunction(context, intersection_info); + + AggregateFunctionSet st_union_agg("st_union_agg"); + st_union_agg.AddFunction( + AggregateFunction::UnaryAggregateDestructor( + core::GeoTypes::GEOMETRY(), core::GeoTypes::GEOMETRY())); + CreateAggregateFunctionInfo union_info(std::move(st_union_agg)); + union_info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; + catalog.CreateFunction(context, union_info); } } // namespace geos diff --git a/spatial/src/spatial/geos/geos_wrappers.cpp b/spatial/src/spatial/geos/geos_wrappers.cpp index cc5375bf..9f09273d 100644 --- a/spatial/src/spatial/geos/geos_wrappers.cpp +++ b/spatial/src/spatial/geos/geos_wrappers.cpp @@ -143,7 +143,7 @@ static GEOSGeometry *DeserializeGeometryCollection(Cursor &reader, GEOSContextHa } } -static GEOSGeometry *DeserializeGeometry(Cursor &reader, GEOSContextHandle_t ctx) { +GEOSGeometry *DeserializeGeometry(Cursor &reader, GEOSContextHandle_t ctx) { auto type = reader.Peek(); switch (type) { case GeometryType::POINT: { @@ -174,11 +174,15 @@ static GEOSGeometry *DeserializeGeometry(Cursor &reader, GEOSContextHandle_t ctx } } -GeometryPtr GeosContextWrapper::Deserialize(const string_t &blob) { +GEOSGeometry *DeserializeGEOSGeometry(const string_t &blob, GEOSContextHandle_t ctx) { Cursor reader(blob); reader.Skip(4); // Skip type, flags and hash reader.Skip(4); // Skip padding - return GeometryPtr(DeserializeGeometry(reader, ctx)); + return DeserializeGeometry(reader, ctx); +} + +GeometryPtr GeosContextWrapper::Deserialize(const string_t &blob) { + return GeometryPtr(DeserializeGEOSGeometry(blob, ctx)); } //------------------------------------------------------------------- @@ -451,8 +455,8 @@ static void SerializeGeometry(Cursor &writer, const GEOSGeometry *geom, const GE } } -string_t GeosContextWrapper::Serialize(Vector &result, const GeometryPtr &geom) { - auto size = GetSerializedSize(geom.get(), ctx); +string_t SerializeGEOSGeometry(Vector &result, const GEOSGeometry *geom, GEOSContextHandle_t ctx) { + auto size = GetSerializedSize(geom, ctx); size += sizeof(GeometryHeader); // Header size += sizeof(uint32_t); // Padding @@ -465,7 +469,7 @@ string_t GeosContextWrapper::Serialize(Vector &result, const GeometryPtr &geom) } GeometryType type; - auto geos_type = GEOSGeomTypeId_r(ctx, geom.get()); + auto geos_type = GEOSGeomTypeId_r(ctx, geom); switch (geos_type) { case GEOS_POINT: type = GeometryType::POINT; @@ -501,11 +505,15 @@ string_t GeosContextWrapper::Serialize(Vector &result, const GeometryPtr &geom) writer.Write(header); // Header writer.Write(0); // Padding - SerializeGeometry(writer, geom.get(), ctx); + SerializeGeometry(writer, geom, ctx); return blob; } +string_t GeosContextWrapper::Serialize(Vector &result, const GeometryPtr &geom) { + return SerializeGEOSGeometry(result, geom.get(), ctx); +} + } // namespace geos } // namespace spatial \ No newline at end of file