1
+ #include " duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
2
+ #include " duckdb/execution/expression_executor.hpp"
1
3
#include " duckdb/optimizer/optimizer_extension.hpp"
4
+ #include " duckdb/planner/expression/bound_comparison_expression.hpp"
5
+ #include " duckdb/planner/expression/bound_conjunction_expression.hpp"
6
+ #include " duckdb/planner/expression/bound_function_expression.hpp"
2
7
#include " duckdb/planner/logical_operator.hpp"
3
- #include " duckdb/planner/operator/logical_join.hpp"
4
8
#include " duckdb/planner/operator/logical_any_join.hpp"
5
9
#include " duckdb/planner/operator/logical_comparison_join.hpp"
6
10
#include " duckdb/planner/operator/logical_filter.hpp"
7
11
#include " duckdb/planner/operator/logical_get.hpp"
8
- #include " duckdb/execution/expression_executor.hpp"
9
- #include " duckdb/planner/expression/bound_function_expression.hpp"
10
- #include " duckdb/planner/expression/bound_comparison_expression.hpp"
11
- #include " duckdb/planner/expression/bound_conjunction_expression.hpp"
12
- #include " duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
12
+ #include " duckdb/planner/operator/logical_join.hpp"
13
13
#include " spatial/common.hpp"
14
- #include " spatial/core/optimizers .hpp"
14
+ #include " spatial/core/optimizer_rules .hpp"
15
15
16
16
namespace spatial {
17
17
@@ -23,7 +23,8 @@ namespace core {
23
23
//
24
24
// Rewrites joins on spatial predicates to range joins on their bounding boxes
25
25
// combined with a spatial predicate filter. This turns the joins from a
26
- // blockwise-nested loop join into a inequality join, which is much faster.
26
+ // blockwise-nested loop join into an inequality join + filter, which is much
27
+ // faster.
27
28
//
28
29
// All spatial predicates (except st_disjoint) imply an intersection of the
29
30
// bounding boxes of the two geometries.
@@ -43,7 +44,44 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
43
44
join->conditions .push_back (std::move (cmp));
44
45
}
45
46
46
- static void Optimize (ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
47
+
48
// Decides whether a two-argument predicate can be split across the two sides of a join.
//
// Parameters:
//   left_table_indexes  - table indexes reachable from the left child of the join
//   right_table_indexes - table indexes reachable from the right child of the join
//   left_bindings       - table indexes referenced by the predicate's left argument
//   right_bindings      - table indexes referenced by the predicate's right argument
//
// Returns true only if all three conditions hold:
//   1. no table index appears in both left_bindings and right_bindings,
//   2. every entry of left_bindings is contained in left_table_indexes,
//   3. every entry of right_bindings is contained in right_table_indexes.
// Empty binding sets satisfy all three conditions trivially.
//
// None of the sets are modified, so they are taken by const reference.
static bool IsTableRefsDisjoint(const unordered_set<idx_t> &left_table_indexes,
                                const unordered_set<idx_t> &right_table_indexes,
                                const unordered_set<idx_t> &left_bindings,
                                const unordered_set<idx_t> &right_bindings) {

	for (const auto &left_binding : left_bindings) {
		// A table index referenced by both arguments means the predicate cannot be split.
		// Overlap is symmetric, so testing it from the left side alone is sufficient.
		if (right_bindings.count(left_binding) != 0) {
			return false;
		}
		// The left argument must only reference tables on the left side of the join.
		if (left_table_indexes.count(left_binding) == 0) {
			return false;
		}
	}

	for (const auto &right_binding : right_bindings) {
		// The right argument must only reference tables on the right side of the join.
		if (right_table_indexes.count(right_binding) == 0) {
			return false;
		}
	}

	return true;
}
83
+
84
+ static void TryOptimize (ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
47
85
48
86
auto &op = *plan;
49
87
@@ -53,7 +91,8 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
53
91
54
92
// Check if the join condition is a spatial predicate and the join type is INNER
55
93
if (any_join.condition ->type == ExpressionType::BOUND_FUNCTION && any_join.join_type == JoinType::INNER) {
56
- auto &bound_function = any_join.condition ->Cast <BoundFunctionExpression>();
94
+ auto bound_func_expr = any_join.condition ->Copy ();
95
+ auto &bound_function = bound_func_expr->Cast <BoundFunctionExpression>();
57
96
58
97
// Note that we cant perform this optimization for st_disjoint as all comparisons have to be AND'd
59
98
case_insensitive_set_t predicates = {" st_equals" , " st_intersects" , " st_touches" , " st_crosses" ,
@@ -65,37 +104,76 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
65
104
66
105
// Convert this into a comparison join on st_xmin, st_xmax, st_ymin, st_ymax of the two input
67
106
// geometries
68
- auto &left_pred_expr = bound_function.children [0 ];
69
- auto &right_pred_expr = bound_function.children [1 ];
107
+ auto left_pred_expr = std::move (bound_function.children [0 ]);
108
+ auto right_pred_expr = std::move (bound_function.children [1 ]);
109
+
110
+ // We need to place the left side of the predicate on the left side of the join
111
+ // and the right side of the predicate on the right side of the join
112
+ // So look at the table indexes of the left and right side of the predicate
113
+ unordered_set<idx_t > left_table_indexes;
114
+ LogicalJoin::GetTableReferences (*any_join.children [0 ], left_table_indexes);
115
+
116
+ unordered_set<idx_t > right_table_indexes;
117
+ LogicalJoin::GetTableReferences (*any_join.children [1 ], right_table_indexes);
118
+
119
+ unordered_set<idx_t > left_pred_bindings;
120
+ LogicalJoin::GetExpressionBindings (*left_pred_expr, left_pred_bindings);
121
+
122
+ unordered_set<idx_t > right_pred_bindings;
123
+ LogicalJoin::GetExpressionBindings (*right_pred_expr, right_pred_bindings);
124
+
125
+ // Check if we can optimize this join
126
+ // We need to make sure that the left and right side of the predicate are disjoint
127
+ // e.g.
128
+ // a JOIN b ON st_intersects(a.geom, b.geom) => OK
129
+ // a JOIN b ON st_intersects(b.geom, a.geom) => OK
130
+ // a JOIN b ON st_intersects(a.geom, st_union(a.geom, b.geom)) => NOT OK
131
+ auto can_split = IsTableRefsDisjoint (left_table_indexes, right_table_indexes, left_pred_bindings, right_pred_bindings);
132
+ if (!can_split) {
133
+ // Try again with the left and right side of the predicate swapped
134
+ // We can safely swap because the intersection operation we encode with the comparison join
135
+ // is symmetric, so the order of the arguments wont matter in the "new" join condition we're
136
+ // about to create.
137
+ can_split = IsTableRefsDisjoint (left_table_indexes, right_table_indexes, right_pred_bindings, left_pred_bindings);
138
+ if (!can_split) {
139
+ // We cant optimize this join
140
+ return ;
141
+ }
142
+ // Swap the left and right side of the predicate
143
+ std::swap (left_pred_expr, right_pred_expr);
144
+ }
70
145
71
146
// Lookup the st_xmin, st_xmax, st_ymin, st_ymax functions in the catalog
72
147
auto &catalog = Catalog::GetSystemCatalog (context);
73
- auto &xmin_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, " " , " st_xmin" )
148
+ auto &xmin_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA , " st_xmin" )
74
149
.Cast <ScalarFunctionCatalogEntry>();
75
- auto &xmax_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, " " , " st_xmax" )
150
+ auto &xmax_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA , " st_xmax" )
76
151
.Cast <ScalarFunctionCatalogEntry>();
77
- auto &ymin_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, " " , " st_ymin" )
152
+ auto &ymin_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA , " st_ymin" )
78
153
.Cast <ScalarFunctionCatalogEntry>();
79
- auto &ymax_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, " " , " st_ymax" )
154
+ auto &ymax_func_set = catalog.GetEntry (context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA , " st_ymax" )
80
155
.Cast <ScalarFunctionCatalogEntry>();
81
156
157
+ auto &left_arg_type = left_pred_expr->return_type ;
158
+ auto &right_arg_type = right_pred_expr->return_type ;
159
+
82
160
auto xmin_func_left =
83
- xmin_func_set.functions .GetFunctionByArguments (context, {left_pred_expr-> return_type });
161
+ xmin_func_set.functions .GetFunctionByArguments (context, {left_arg_type });
84
162
auto xmax_func_left =
85
- xmax_func_set.functions .GetFunctionByArguments (context, {left_pred_expr-> return_type });
163
+ xmax_func_set.functions .GetFunctionByArguments (context, {left_arg_type });
86
164
auto ymin_func_left =
87
- ymin_func_set.functions .GetFunctionByArguments (context, {left_pred_expr-> return_type });
165
+ ymin_func_set.functions .GetFunctionByArguments (context, {left_arg_type });
88
166
auto ymax_func_left =
89
- ymax_func_set.functions .GetFunctionByArguments (context, {left_pred_expr-> return_type });
167
+ ymax_func_set.functions .GetFunctionByArguments (context, {left_arg_type });
90
168
91
169
auto xmin_func_right =
92
- xmin_func_set.functions .GetFunctionByArguments (context, {right_pred_expr-> return_type });
170
+ xmin_func_set.functions .GetFunctionByArguments (context, {right_arg_type });
93
171
auto xmax_func_right =
94
- xmax_func_set.functions .GetFunctionByArguments (context, {right_pred_expr-> return_type });
172
+ xmax_func_set.functions .GetFunctionByArguments (context, {right_arg_type });
95
173
auto ymin_func_right =
96
- ymin_func_set.functions .GetFunctionByArguments (context, {right_pred_expr-> return_type });
174
+ ymin_func_set.functions .GetFunctionByArguments (context, {right_arg_type });
97
175
auto ymax_func_right =
98
- ymax_func_set.functions .GetFunctionByArguments (context, {right_pred_expr-> return_type });
176
+ ymax_func_set.functions .GetFunctionByArguments (context, {right_arg_type });
99
177
100
178
// Create the new join condition
101
179
@@ -151,16 +229,25 @@ class RangeJoinSpatialPredicateRewriter : public OptimizerExtension {
151
229
ExpressionType::COMPARE_LESSTHANOREQUALTO);
152
230
AddComparison (new_join, std::move (a_y_max), std::move (b_y_min),
153
231
ExpressionType::COMPARE_GREATERTHANOREQUALTO);
232
+
154
233
new_join->children = std::move (any_join.children );
234
+ if (any_join.has_estimated_cardinality ) {
235
+ new_join->estimated_cardinality = any_join.estimated_cardinality ;
236
+ new_join->has_estimated_cardinality = true ;
237
+ }
155
238
156
- // Also, we need to create a filter with the original predicate
157
- auto filter = make_uniq<LogicalFilter>(any_join.condition ->Copy ());
239
+ auto filter = make_uniq<LogicalFilter>(std::move (any_join.condition ));
158
240
filter->children .push_back (std::move (new_join));
159
241
160
242
plan = std::move (filter);
161
243
}
162
244
}
163
245
}
246
+ }
247
+
248
+ static void Optimize (ClientContext &context, OptimizerExtensionInfo *info, unique_ptr<LogicalOperator> &plan) {
249
+
250
+ TryOptimize (context, info, plan);
164
251
165
252
// Recursively optimize the children
166
253
for (auto &child : plan->children ) {
0 commit comments