Skip to content

Commit b5d031d

Browse files
authored
Merge pull request #154 from Maxxen/dev
fix broken join optimization
2 parents bacf0b2 + 0037249 commit b5d031d

File tree

4 files changed

+117
-30
lines changed

4 files changed

+117
-30
lines changed

spatial/src/spatial/core/CMakeLists.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ set(EXTENSION_SOURCES
44
${EXTENSION_SOURCES}
55
${CMAKE_CURRENT_SOURCE_DIR}/module.cpp
66
${CMAKE_CURRENT_SOURCE_DIR}/types.cpp
7-
${CMAKE_CURRENT_SOURCE_DIR}/optimizers.cpp
7+
${CMAKE_CURRENT_SOURCE_DIR}/optimizer_rules.cpp
88
PARENT_SCOPE
99
)

spatial/src/spatial/core/module.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
#include "spatial/core/module.hpp"
33

44
#include "spatial/common.hpp"
5-
#include "spatial/core/functions/scalar.hpp"
6-
#include "spatial/core/functions/cast.hpp"
75
#include "spatial/core/functions/aggregate.hpp"
6+
#include "spatial/core/functions/cast.hpp"
7+
#include "spatial/core/functions/scalar.hpp"
88
#include "spatial/core/functions/table.hpp"
9-
#include "spatial/core/optimizers.hpp"
9+
#include "spatial/core/optimizer_rules.hpp"
1010
#include "spatial/core/types.hpp"
1111

1212
namespace spatial {

spatial/src/spatial/core/optimizers.cpp spatial/src/spatial/core/optimizer_rules.cpp

+113-26
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1+
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
2+
#include "duckdb/execution/expression_executor.hpp"
13
#include "duckdb/optimizer/optimizer_extension.hpp"
4+
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
5+
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
6+
#include "duckdb/planner/expression/bound_function_expression.hpp"
27
#include "duckdb/planner/logical_operator.hpp"
3-
#include "duckdb/planner/operator/logical_join.hpp"
48
#include "duckdb/planner/operator/logical_any_join.hpp"
59
#include "duckdb/planner/operator/logical_comparison_join.hpp"
610
#include "duckdb/planner/operator/logical_filter.hpp"
711
#include "duckdb/planner/operator/logical_get.hpp"
8-
#include "duckdb/execution/expression_executor.hpp"
9-
#include "duckdb/planner/expression/bound_function_expression.hpp"
10-
#include "duckdb/planner/expression/bound_comparison_expression.hpp"
11-
#include "duckdb/planner/expression/bound_conjunction_expression.hpp"
12-
#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
12+
#include "duckdb/planner/operator/logical_join.hpp"
1313
#include "spatial/common.hpp"
14-
#include "spatial/core/optimizers.hpp"
14+
#include "spatial/core/optimizer_rules.hpp"
1515

1616
namespace spatial {
1717

@@ -23,7 +23,8 @@ namespace core {
2323
//
2424
// Rewrites joins on spatial predicates to range joins on their bounding boxes
2525
// combined with a spatial predicate filter. This turns the joins from a
26-
// blockwise-nested loop join into a inequality join, which is much faster.
26+
// blockwise-nested loop join into a inequality join + filter, which is much
27+
// faster.
2728
//
2829
// All spatial predicates (except st_disjoint) imply an intersection of the
2930
// bounding boxes of the two geometries.
@@ -43,7 +44,44 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
4344
join->conditions.push_back(std::move(cmp));
4445
}
4546

46-
static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
47+
48+
static bool IsTableRefsDisjoint(unordered_set<idx_t> &left_table_indexes,
49+
unordered_set<idx_t> &right_table_indexes,
50+
unordered_set<idx_t> &left_bindings,
51+
unordered_set<idx_t> &right_bindings) {
52+
53+
// Check that all the left-side bindings reference the left-side tables of the join,
54+
// as well as that all the right-side bindings reference the right-side tables of the join.
55+
// and that the left and right side bindings are disjoint.
56+
57+
for(auto &left_binding : left_bindings) {
58+
if(right_bindings.find(left_binding) != right_bindings.end()) {
59+
// The left side bindings reference the right side tables of the join.
60+
return false;
61+
}
62+
// Also check that the left side bindings are on the left side of the join
63+
if(left_table_indexes.find(left_binding) == left_table_indexes.end()) {
64+
// The left side bindings are not on the left side of the join.
65+
return false;
66+
}
67+
}
68+
69+
for(auto &right_binding : right_bindings) {
70+
if(left_bindings.find(right_binding) != left_bindings.end()) {
71+
// The right side bindings reference the left side tables of the join.
72+
return false;
73+
}
74+
// Also check that the right side bindings are on the right side of the join
75+
if(right_table_indexes.find(right_binding) == right_table_indexes.end()) {
76+
// The right side bindings are not on the right side of the join.
77+
return false;
78+
}
79+
}
80+
81+
return true;
82+
}
83+
84+
static void TryOptimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
4785

4886
auto &op = *plan;
4987

@@ -53,7 +91,8 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
5391

5492
// Check if the join condition is a spatial predicate and the join type is INNER
5593
if (any_join.condition->type == ExpressionType::BOUND_FUNCTION && any_join.join_type == JoinType::INNER) {
56-
auto &bound_function = any_join.condition->Cast<BoundFunctionExpression>();
94+
auto bound_func_expr = any_join.condition->Copy();
95+
auto &bound_function = bound_func_expr->Cast<BoundFunctionExpression>();
5796

5897
// Note that we cant perform this optimization for st_disjoint as all comparisons have to be AND'd
5998
case_insensitive_set_t predicates = {"st_equals", "st_intersects", "st_touches", "st_crosses",
@@ -65,37 +104,76 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
65104

66105
// Convert this into a comparison join on st_xmin, st_xmax, st_ymin, st_ymax of the two input
67106
// geometries
68-
auto &left_pred_expr = bound_function.children[0];
69-
auto &right_pred_expr = bound_function.children[1];
107+
auto left_pred_expr = std::move(bound_function.children[0]);
108+
auto right_pred_expr = std::move(bound_function.children[1]);
109+
110+
// We need to place the left side of the predicate on the left side of the join
111+
// and the right side of the predicate on the right side of the join
112+
// So look at the table indexes of the left and right side of the predicate
113+
unordered_set<idx_t> left_table_indexes;
114+
LogicalJoin::GetTableReferences(*any_join.children[0], left_table_indexes);
115+
116+
unordered_set<idx_t> right_table_indexes;
117+
LogicalJoin::GetTableReferences(*any_join.children[1], right_table_indexes);
118+
119+
unordered_set<idx_t> left_pred_bindings;
120+
LogicalJoin::GetExpressionBindings(*left_pred_expr, left_pred_bindings);
121+
122+
unordered_set<idx_t> right_pred_bindings;
123+
LogicalJoin::GetExpressionBindings(*right_pred_expr, right_pred_bindings);
124+
125+
// Check if we can optimize this join
126+
// We need to make sure that the left and right side of the predicate are disjoint
127+
// e.g.
128+
// a JOIN b ON st_intersects(a.geom, b.geom) => OK
129+
// a JOIN b ON st_intersects(b.geom, a.geom) => OK
130+
// a JOIN b ON st_intersects(a.geom, st_union(a.geom, b.geom)) => NOT OK
131+
auto can_split = IsTableRefsDisjoint(left_table_indexes, right_table_indexes, left_pred_bindings, right_pred_bindings);
132+
if(!can_split) {
133+
// Try again with the left and right side of the predicate swapped
134+
// We can safely swap because the intersection operation we encode with the comparison join
135+
// is symmetric, so the order of the arguments wont matter in the "new" join condition we're
136+
// about to create.
137+
can_split = IsTableRefsDisjoint(left_table_indexes, right_table_indexes, right_pred_bindings, left_pred_bindings);
138+
if(!can_split) {
139+
// We cant optimize this join
140+
return;
141+
}
142+
// Swap the left and right side of the predicate
143+
std::swap(left_pred_expr, right_pred_expr);
144+
}
70145

71146
// Lookup the st_xmin, st_xmax, st_ymin, st_ymax functions in the catalog
72147
auto &catalog = Catalog::GetSystemCatalog(context);
73-
auto &xmin_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, "", "st_xmin")
148+
auto &xmin_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_xmin")
74149
.Cast<ScalarFunctionCatalogEntry>();
75-
auto &xmax_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, "", "st_xmax")
150+
auto &xmax_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_xmax")
76151
.Cast<ScalarFunctionCatalogEntry>();
77-
auto &ymin_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, "", "st_ymin")
152+
auto &ymin_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_ymin")
78153
.Cast<ScalarFunctionCatalogEntry>();
79-
auto &ymax_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, "", "st_ymax")
154+
auto &ymax_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_ymax")
80155
.Cast<ScalarFunctionCatalogEntry>();
81156

157+
auto &left_arg_type = left_pred_expr->return_type;
158+
auto &right_arg_type = right_pred_expr->return_type;
159+
82160
auto xmin_func_left =
83-
xmin_func_set.functions.GetFunctionByArguments(context, {left_pred_expr->return_type});
161+
xmin_func_set.functions.GetFunctionByArguments(context, {left_arg_type});
84162
auto xmax_func_left =
85-
xmax_func_set.functions.GetFunctionByArguments(context, {left_pred_expr->return_type});
163+
xmax_func_set.functions.GetFunctionByArguments(context, {left_arg_type});
86164
auto ymin_func_left =
87-
ymin_func_set.functions.GetFunctionByArguments(context, {left_pred_expr->return_type});
165+
ymin_func_set.functions.GetFunctionByArguments(context, {left_arg_type});
88166
auto ymax_func_left =
89-
ymax_func_set.functions.GetFunctionByArguments(context, {left_pred_expr->return_type});
167+
ymax_func_set.functions.GetFunctionByArguments(context, {left_arg_type});
90168

91169
auto xmin_func_right =
92-
xmin_func_set.functions.GetFunctionByArguments(context, {right_pred_expr->return_type});
170+
xmin_func_set.functions.GetFunctionByArguments(context, {right_arg_type});
93171
auto xmax_func_right =
94-
xmax_func_set.functions.GetFunctionByArguments(context, {right_pred_expr->return_type});
172+
xmax_func_set.functions.GetFunctionByArguments(context, {right_arg_type});
95173
auto ymin_func_right =
96-
ymin_func_set.functions.GetFunctionByArguments(context, {right_pred_expr->return_type});
174+
ymin_func_set.functions.GetFunctionByArguments(context, {right_arg_type});
97175
auto ymax_func_right =
98-
ymax_func_set.functions.GetFunctionByArguments(context, {right_pred_expr->return_type});
176+
ymax_func_set.functions.GetFunctionByArguments(context, {right_arg_type});
99177

100178
// Create the new join condition
101179

@@ -151,16 +229,25 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
151229
ExpressionType::COMPARE_LESSTHANOREQUALTO);
152230
AddComparison(new_join, std::move(a_y_max), std::move(b_y_min),
153231
ExpressionType::COMPARE_GREATERTHANOREQUALTO);
232+
154233
new_join->children = std::move(any_join.children);
234+
if(any_join.has_estimated_cardinality) {
235+
new_join->estimated_cardinality = any_join.estimated_cardinality;
236+
new_join->has_estimated_cardinality = true;
237+
}
155238

156-
// Also, we need to create a filter with the original predicate
157-
auto filter = make_uniq<LogicalFilter>(any_join.condition->Copy());
239+
auto filter = make_uniq<LogicalFilter>(std::move(any_join.condition));
158240
filter->children.push_back(std::move(new_join));
159241

160242
plan = std::move(filter);
161243
}
162244
}
163245
}
246+
}
247+
248+
static void Optimize(ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
249+
250+
TryOptimize(context, info, plan);
164251

165252
// Recursively optimize the children
166253
for (auto &child : plan->children) {

0 commit comments

Comments
 (0)