SQLServer Aggregate Support (#11811)
* 40 red

* 18 Red

* 31 Red

* 20 red

* 18 Red

* 15 red

* 9 Red

* 7

* Comment out broken test for now

* Green

* Cleanup

* Changelog

* Update check_aggregate_support

* Cleanup

* Reenable test

* Fix tests

* Doc comment
AdRiley authored Dec 21, 2024
1 parent e8f781a commit 31772e3
Showing 10 changed files with 101 additions and 78 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -136,6 +136,7 @@
- [Enhance Managed_Resource to allow implementation of in-memory caches][11577]
- [Added `add_group_number` to the in-memory database.][11818]
- [The reload button clears the HTTP cache.][11673]
- [SQL Server Support for Aggregate][11811]

[11235]: https://github.com/enso-org/enso/pull/11235
[11255]: https://github.com/enso-org/enso/pull/11255
@@ -146,6 +147,7 @@
[11577]: https://github.com/enso-org/enso/pull/11577
[11818]: https://github.com/enso-org/enso/pull/11818
[11673]: https://github.com/enso-org/enso/pull/11673
[11811]: https://github.com/enso-org/enso/pull/11811

#### Enso Language & Runtime

@@ -157,6 +157,12 @@ type Redshift_Dialect
_ = [op_kind, args]
expression

## PRIVATE
Add an extra cast to adjust the output type of aggregate operations.
Some DBs do CAST(SUM(x) AS FLOAT), others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
self.cast_op_type op_kind columns (SQL_Expression.Operation op_kind (columns.map c->c.expression))

## PRIVATE
prepare_fetch_types_query : SQL_Expression -> Context -> SQL_Statement
prepare_fetch_types_query self expression context =
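
For reference, the two cast placements mentioned in the `cast_aggregate_columns` doc comment above look like this in plain SQL (the table `t` and column `x` are illustrative, not taken from the codebase). The default implementation casts the finished aggregate; the SQL Server dialect later in this diff casts the inputs instead.

    -- Cast applied to the aggregate's result (the default used here):
    SELECT CAST(SUM(x) AS FLOAT) FROM t;
    -- Cast applied to each input before aggregating (the SQL Server variant):
    SELECT SUM(CAST(x AS FLOAT)) FROM t;
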
@@ -35,7 +35,7 @@ from project.Errors import Aggregagtion_Requires_Order
make_aggregate_column : DB_Table -> Aggregate_Column -> Text -> Dialect -> (Text -> Vector -> SQL_Expression -> SQL_Type_Reference) -> Problem_Builder -> Internal_Column
make_aggregate_column table aggregate as dialect infer_return_type problem_builder -> Internal_Column =
simple_aggregate op_kind columns =
expression = dialect.cast_op_type op_kind columns (SQL_Expression.Operation op_kind (columns.map c->c.expression))
expression = dialect.cast_aggregate_columns op_kind columns
sql_type_ref = infer_return_type op_kind columns expression
Internal_Column.Value as sql_type_ref expression

@@ -179,6 +179,7 @@ type SQL_Generator
generate_select_query_sql : Dialect -> Vector (Pair Text SQL_Expression) -> Context -> SQL_Builder
generate_select_query_sql self dialect columns ctx =
gen_exprs exprs = exprs.map (expr-> dialect.generate_expression self expr for_select=False)
gen_group_exprs exprs = exprs.map (expr-> dialect.generate_expression self expr for_select=True)
gen_column pair = (dialect.generate_expression self expr=pair.second for_select=True) ++ alias dialect pair.first

generated_columns = case columns of
Expand All @@ -187,7 +188,7 @@ type SQL_Generator

from_part = self.generate_from_part dialect ctx.from_spec
where_part = (SQL_Builder.join " AND " (gen_exprs ctx.where_filters)) . prefix_if_present " WHERE "
group_part = (SQL_Builder.join ", " (gen_exprs ctx.groups)) . prefix_if_present " GROUP BY "
group_part = (SQL_Builder.join ", " (gen_group_exprs ctx.groups)) . prefix_if_present " GROUP BY "

orders = ctx.orders.map (self.generate_order dialect)
order_part = (SQL_Builder.join ", " orders) . prefix_if_present " ORDER BY "
@@ -663,14 +664,14 @@ preprocess_query (query : Query) -> Query =
column expression; it should be provided only if `has_quote` is `True` and
must not be empty then. If the quote character occurs in the expression, it
is escaped by doubling each occurrence.
make_concat make_raw_concat_expr make_contains_expr has_quote args =
make_concat make_raw_concat_expr make_contains_expr has_quote args append_char="||" =
expected_args = if has_quote then 5 else 4
if args.length != expected_args then Error.throw (Illegal_State.Error "Unexpected number of arguments for the concat operation.") else
expr = args.at 0
separator = args.at 1
prefix = args.at 2
suffix = args.at 3
append = " || "
append = " " + append_char + " "
possibly_quoted = case has_quote of
True ->
quote = args.at 4
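
The new `append_char` parameter exists because string concatenation is spelled differently across databases: `||` stays the default here, while the SQL Server dialect passes `+` (see `concat_ops` further down in this diff). For example, using literals only, no schema assumed:

    -- Standard SQL / PostgreSQL-style concatenation:
    SELECT 'ab' || 'cd';   -- 'abcd'
    -- T-SQL has no || operator; strings are concatenated with +:
    SELECT 'ab' + 'cd';    -- 'abcd'
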
@@ -231,6 +231,12 @@ type Postgres_Dialect
if cast_to.is_nothing then expression else
SQL_Expression.Operation "CAST" [expression, SQL_Expression.Literal cast_to]

## PRIVATE
Add an extra cast to adjust the output type of aggregate operations.
Some DBs do CAST(SUM(x) AS FLOAT), others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
self.cast_op_type op_kind columns (SQL_Expression.Operation op_kind (columns.map c->c.expression))

## PRIVATE
prepare_fetch_types_query : SQL_Expression -> Context -> SQL_Statement
prepare_fetch_types_query self expression context =
@@ -215,6 +215,12 @@ type SQLite_Dialect
_ = [op_kind, args]
expression

## PRIVATE
Add an extra cast to adjust the output type of aggregate operations.
Some DBs do CAST(SUM(x) AS FLOAT), others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
self.cast_op_type op_kind columns (SQL_Expression.Operation op_kind (columns.map c->c.expression))

## PRIVATE
prepare_fetch_types_query : SQL_Expression -> Context -> SQL_Statement
prepare_fetch_types_query self expression context =
@@ -212,8 +212,27 @@ type SQLServer_Dialect
is used only to override the type in cases where the default one that the
database uses is not what we want.
cast_op_type self (op_kind:Text) (args:(Vector Internal_Column)) (expression:SQL_Expression) =
_ = [op_kind, args]
expression
is_int ic =
typeid = ic.sql_type_reference.get.typeid
typeid == Java_Types.SMALLINT || typeid == Java_Types.INTEGER || typeid == Java_Types.BIGINT

cast_to = case op_kind of
"AVG" ->
if is_int (args.at 0) then "FLOAT" else Nothing
"STDDEV_POP" ->
if is_int (args.at 0) then "FLOAT" else Nothing
"STDDEV_SAMP" ->
if is_int (args.at 0) then "FLOAT" else Nothing
_ -> Nothing

if cast_to.is_nothing then expression else
SQL_Expression.Operation "CAST" [expression, SQL_Expression.Literal cast_to]

## PRIVATE
Add an extra cast to adjust the output type of aggregate operations.
Some DBs do CAST(SUM(x) AS FLOAT), others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
SQL_Expression.Operation op_kind (columns.map c->(self.cast_op_type op_kind columns (Internals_Access.column_expression c)))
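
The reason SQL Server casts each input rather than the finished aggregate is that T-SQL's AVG infers its result type from its input, so an integer column yields a truncated integer average; the code above applies the same treatment to the standard-deviation aggregates. A minimal sketch, assuming a table t(x INT) containing the values 1 and 2:

    SELECT AVG(x)                FROM t;  -- 1   (integer average, fraction lost)
    SELECT CAST(AVG(x) AS FLOAT) FROM t;  -- 1.0 (cast applied too late)
    SELECT AVG(CAST(x AS FLOAT)) FROM t;  -- 1.5 (cast the inputs first)
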

## PRIVATE
prepare_fetch_types_query : SQL_Expression -> Context -> SQL_Statement
@@ -224,10 +243,32 @@
generate_collate self collation_name:Text -> Text = Base_Generator.default_generate_collate collation_name quote_char=""

## PRIVATE
check_aggregate_support : Aggregate_Column -> Boolean ! Unsupported_Database_Operation
check_aggregate_support self aggregate =
_ = aggregate
True
check_aggregate_support self aggregate:Aggregate_Column -> Boolean ! Unsupported_Database_Operation =
unsupported name =
Error.throw (Unsupported_Database_Operation.Error name)
case aggregate of
Group_By _ _ -> True
Count _ -> True
Count_Distinct columns _ _ ->
if columns.length == 1 then True else
unsupported "Count_Distinct on multiple columns"
Count_Not_Nothing _ _ -> True
Count_Nothing _ _ -> True
Count_Not_Empty _ _ -> True
Count_Empty _ _ -> True
Percentile _ _ _ -> unsupported "Percentile"
Mode _ _ -> unsupported "Mode"
First _ _ _ _ -> unsupported "First"
Last _ _ _ _ -> unsupported "Last"
Maximum _ _ -> True
Minimum _ _ -> True
Shortest _ _ -> unsupported "Shortest"
Longest _ _ -> unsupported "Longest"
Standard_Deviation _ _ _ -> True
Concatenate _ _ _ _ _ _ -> True
Sum _ _ -> True
Average _ _ -> True
Median _ _ -> unsupported "Median"
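
Percentile and Median would map naturally onto PERCENTILE_CONT, but in T-SQL that function is an analytic (windowed) function requiring an OVER clause rather than a GROUP BY aggregate, which appears to be why these entries are rejected instead of translated. Sketch, with an assumed table t(g, x):

    -- Not valid as a grouped aggregate in T-SQL:
    -- SELECT g, PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) FROM t GROUP BY g;

    -- The windowed form is valid, but yields one value per row, not per group:
    SELECT DISTINCT g,
           PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY x) OVER (PARTITION BY g) AS median_x
    FROM t;
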

## PRIVATE
Checks if an operation is supported by the dialect.
@@ -243,6 +284,7 @@ type SQLServer_Dialect
Feature.Filter -> True
Feature.Join -> True
Feature.Union -> True
Feature.Aggregate -> True
_ -> False

## PRIVATE
@@ -401,6 +443,7 @@ private _generate_expression dialect base_gen expr expression_kind:Expression_Kind

pair final_expr null_checks_result
query : Query -> pair (base_gen.generate_sub_query dialect query) []
descriptor : Order_Descriptor -> pair (base_gen.generate_order dialect descriptor) []

## PRIVATE
type Expression_Kind
@@ -437,7 +480,7 @@ private _op_return_kind op -> Expression_Kind =
if return_bool_ops.contains op then Expression_Kind.Boolean_Condition else Expression_Kind.Value

private _op_needs_to_materialize_null_checks op -> Boolean =
["FILL_NULL", "COALESCE"].contains op
["FILL_NULL", "COALESCE", "COUNT_IS_NULL", "COUNT_EMPTY", "COUNT_NOT_EMPTY", "COUNT", "SUM", "AVG", "LONGEST", "SHORTEST", "COUNT_DISTINCT", "COUNT_DISTINCT_INCLUDE_NULL", "STDDEV_POP", "STDDEV_SAMP", "CONCAT", "CONCAT_QUOTE_IF_NEEDED", "MIN", "MAX"].contains op

## PRIVATE
make_dialect_operations =
@@ -447,13 +490,13 @@ make_dialect_operations =
arith_extensions = [floating_point_div, mod_op, decimal_div, decimal_mod, ["ROW_MIN", Base_Generator.make_function "LEAST"], ["ROW_MAX", Base_Generator.make_function "GREATEST"]]
bool = [bool_or]

stddev_pop = ["STDDEV_POP", Base_Generator.make_function "stddev_pop"]
stddev_samp = ["STDDEV_SAMP", Base_Generator.make_function "stddev_samp"]
stats = [agg_median, agg_mode, agg_percentile, stddev_pop, stddev_samp]
stddev_pop = ["STDDEV_POP", Base_Generator.make_function "STDEVP"]
stddev_samp = ["STDDEV_SAMP", Base_Generator.make_function "STDEV"]
stats = [stddev_pop, stddev_samp]
date_ops = [["year", Base_Generator.make_function "year"], make_datepart "quarter", ["month", Base_Generator.make_function "month"], make_datepart "week" "iso_week", ["day", Base_Generator.make_function "day"], make_datepart "hour", make_datepart "minute", make_datepart "day_of_year" "dayofyear", make_day_of_week, make_datepart "second", make_datepart "millisecond", make_extract_microsecond, ["date_add", make_date_add], ["date_diff", make_date_diff], ["date_trunc_to_day", make_date_trunc_to_day]]
special_overrides = [is_empty, ["IIF", _make_iif]]
other = [["RUNTIME_ERROR", make_runtime_error_op]]
my_mappings = text + counts + stats + first_last_aggregators + arith_extensions + bool + date_ops + special_overrides + other
my_mappings = text + counts + arith_extensions + bool + stats + date_ops + special_overrides + other
base = Base_Generator.base_dialect_operations . extend_with my_mappings
Base_Generator.Dialect_Operations.Value (base.operations_dict.remove "IS_IN")
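
The STDDEV_SAMP and STDDEV_POP operations are mapped onto T-SQL's native names, STDEV and STDEVP. For example, assuming a table t with group column g and numeric column x:

    SELECT g, STDEV(x) AS stddev_samp, STDEVP(x) AS stddev_pop
    FROM t
    GROUP BY g;
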

@@ -469,68 +512,29 @@ private _make_iif arguments:Vector -> SQL_Builder =

## PRIVATE
agg_count_is_null = Base_Generator.lift_unary_op "COUNT_IS_NULL" arg->
SQL_Builder.code "SUM(CASE WHEN " ++ arg.paren ++ " IS NULL THEN 1 ELSE 0 END)"
SQL_Builder.code "COALESCE(SUM(CASE WHEN " ++ arg.paren ++ " IS NULL THEN 1 ELSE 0 END), 0)"

## PRIVATE
agg_count_empty = Base_Generator.lift_unary_op "COUNT_EMPTY" arg->
SQL_Builder.code "SUM(CASE WHEN (" ++ arg.paren ++ " IS NULL) OR (" ++ arg.paren ++ " = '') THEN 1 ELSE 0 END)"
SQL_Builder.code "COALESCE(SUM(CASE WHEN (" ++ arg.paren ++ " IS NULL) OR (" ++ arg.paren ++ " = '') THEN 1 ELSE 0 END), 0)"

## PRIVATE
agg_count_not_empty = Base_Generator.lift_unary_op "COUNT_NOT_EMPTY" arg->
SQL_Builder.code "SUM(CASE WHEN (" ++ arg.paren ++ " IS NOT NULL) AND (" ++ arg.paren ++ " != '') THEN 1 ELSE 0 END)"


## PRIVATE
agg_median = Base_Generator.lift_unary_op "MEDIAN" arg->
median = SQL_Builder.code "MEDIAN(" ++ arg ++ ")"
has_nan = SQL_Builder.code "BOOLOR_AGG(" ++ arg ++ " = 'NaN'::Double)"
SQL_Builder.code "CASE WHEN " ++ has_nan ++ " THEN 'NaN'::Double ELSE " ++ median ++ " END"

## PRIVATE
agg_mode = Base_Generator.lift_unary_op "MODE" arg->
SQL_Builder.code "MODE(" ++ arg ++ ")"

## PRIVATE
agg_percentile = Base_Generator.lift_binary_op "PERCENTILE" p-> expr->
percentile = SQL_Builder.code "percentile_cont(" ++ p ++ ") WITHIN GROUP (ORDER BY " ++ expr ++ ")"
has_nan = SQL_Builder.code "BOOLOR_AGG(" ++ expr ++ " = 'NaN'::Double)"
SQL_Builder.code "CASE WHEN " ++ has_nan ++ " THEN 'NaN' ELSE " ++ percentile ++ " END"

## PRIVATE
These are written in a not most-efficient way, but a way that makes them
compatible with other group-by aggregations out-of-the-box. In the future, we
may want to consider some alternative solutions.
first_last_aggregators =
first = make_first_aggregator reverse=False ignore_null=False
first_not_null = make_first_aggregator reverse=False ignore_null=True
last = make_first_aggregator reverse=True ignore_null=False
last_not_null = make_first_aggregator reverse=True ignore_null=True
[["FIRST", first], ["FIRST_NOT_NULL", first_not_null], ["LAST", last], ["LAST_NOT_NULL", last_not_null]]

## PRIVATE
make_first_aggregator reverse ignore_null args =
if args.length < 2 then Error.throw (Illegal_State.Error "Insufficient number of arguments for the operation.") else
result_expr = args.first
order_bys = args.drop 1

method_name = if reverse then "LAST_VALUE" else "FIRST_VALUE"
filter_clause = if ignore_null then ") IGNORE NULLS OVER" else ") OVER"
order_clause = SQL_Builder.code " ORDER BY " ++ SQL_Builder.join "," order_bys
SQL_Builder.code (method_name + "(") ++ result_expr ++ filter_clause ++ order_clause
SQL_Builder.code "COALESCE(SUM(CASE WHEN (" ++ arg.paren ++ " IS NOT NULL) AND (" ++ arg.paren ++ " != '') THEN 1 ELSE 0 END), 0)"

## PRIVATE
agg_shortest = Base_Generator.lift_unary_op "SHORTEST" arg->
SQL_Builder.code "FIRST_VALUE(" ++ arg ++ ") IGNORE NULLS OVER (ORDER BY LENGTH(" ++ arg ++ "))"
SQL_Builder.code "FIRST_VALUE(" ++ arg ++ ") IGNORE NULLS OVER (ORDER BY LEN(" ++ arg ++ "))"

## PRIVATE
agg_longest = Base_Generator.lift_unary_op "LONGEST" arg->
SQL_Builder.code "FIRST_VALUE(" ++ arg ++ ") IGNORE NULLS OVER (ORDER BY LENGTH(" ++ arg ++ ") DESC)"
SQL_Builder.code "FIRST_VALUE(" ++ arg ++ ") IGNORE NULLS OVER (ORDER BY LEN(" ++ arg ++ ") DESC)"

## PRIVATE
concat_ops =
make_raw_concat_expr expr separator =
SQL_Builder.code "string_agg(" ++ expr ++ ", " ++ separator ++ ")"
concat = Base_Generator.make_concat make_raw_concat_expr make_contains_expr
concat = Base_Generator.make_concat make_raw_concat_expr make_contains_expr append_char="+"
[["CONCAT", concat (has_quote=False)], ["CONCAT_QUOTE_IF_NEEDED", concat (has_quote=True)]]

## PRIVATE
@@ -554,14 +558,7 @@ agg_count_distinct args = if args.is_empty then (Error.throw (Illegal_Argument.E
True ->
## A single null value will be skipped.
SQL_Builder.code "COUNT(DISTINCT " ++ args.first ++ ")"
False ->
## A tuple of nulls is not a null, so it will not be skipped - but
we want to ignore all-null columns. So we manually filter them
out.
count = SQL_Builder.code "COUNT(DISTINCT (" ++ SQL_Builder.join ", " args ++ "))"
are_nulls = args.map arg-> arg.paren ++ " IS NULL"
all_nulls_filter = SQL_Builder.code " FILTER (WHERE NOT (" ++ SQL_Builder.join " AND " are_nulls ++ "))"
(count ++ all_nulls_filter).paren
False -> Error.throw (Illegal_Argument.Error "COUNT_DISTINCT supports only single arguments in SQLServer.")
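
T-SQL's COUNT(DISTINCT ...) accepts a single expression only, which is why the multi-column case is rejected here and why check_aggregate_support above limits Count_Distinct to one column. For example, with an assumed table t(x, y):

    SELECT COUNT(DISTINCT x) FROM t;        -- supported
    -- SELECT COUNT(DISTINCT x, y) FROM t;  -- not valid T-SQL
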

## PRIVATE
agg_count_distinct_include_null args = case args.length == 1 of
@@ -595,12 +592,11 @@ ends_with = Base_Generator.lift_binary_op "ENDS_WITH" str-> sub->
res.paren

## PRIVATE
contains = Base_Generator.lift_binary_op "CONTAINS" str-> sub->
res = SQL_Builder.code "CHARINDEX(" ++ sub ++ ", " ++ str ++ ") > 0"
res.paren
make_contains_expr expr substring =
SQL_Builder.code "CHARINDEX(" ++ substring ++ ", " ++ expr ++ ") > 0"

## PRIVATE
make_contains_expr expr substring = contains [expr, substring]
contains = Base_Generator.lift_binary_op "CONTAINS" make_contains_expr
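
CHARINDEX(substring, string) returns the 1-based position of the first match and 0 when there is none, so comparing against 0 gives a contains test. A small self-contained example, wrapped in CASE because T-SQL has no standalone boolean values:

    SELECT CASE WHEN CHARINDEX('ell', 'Hello') > 0 THEN 1 ELSE 0 END;  -- 1
    SELECT CASE WHEN CHARINDEX('xyz', 'Hello') > 0 THEN 1 ELSE 0 END;  -- 0
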

## PRIVATE
make_case_sensitive = Base_Generator.lift_unary_op "MAKE_CASE_SENSITIVE" arg->
@@ -219,6 +219,12 @@ type Snowflake_Dialect
_ = [op_kind, args]
expression

## PRIVATE
Add an extra cast to adjust the output type of aggregate operations.
Some DBs do CAST(SUM(x) AS FLOAT), others do SUM(CAST(x AS FLOAT)).
cast_aggregate_columns self op_kind:Text columns:(Vector Internal_Column) =
self.cast_op_type op_kind columns (SQL_Expression.Operation op_kind (columns.map c->c.expression))

## PRIVATE
prepare_fetch_types_query : SQL_Expression -> Context -> SQL_Statement
prepare_fetch_types_query self expression context =
4 changes: 2 additions & 2 deletions test/Microsoft_Tests/src/SQLServer_Spec.enso
@@ -200,8 +200,8 @@ add_sqlserver_specs suite_builder create_connection_fn =
materialize = .read

common_selection = Common_Table_Operations.Main.Test_Selection.Config supported_replace_params=supported_replace_params run_advanced_edge_case_tests_by_default=True
aggregate_selection = Common_Table_Operations.Aggregate_Spec.Test_Selection.Config first_last_row_order=False aggregation_problems=False
agg_in_memory_table = (enso_project.data / "data.csv") . read
aggregate_selection = Common_Table_Operations.Aggregate_Spec.Test_Selection.Config advanced_stats=False text_shortest_longest=False first_last=False first_last_row_order=False aggregation_problems=False multi_distinct=False first_last_multi_order=False first_last_ignore_nothing=False text_concat=False
agg_in_memory_table = ((Project_Description.new enso_dev.Table_Tests).data / "data.csv") . read

agg_table_fn = _->
agg_in_memory_table.select_into_database_table default_connection.get (Name_Generator.random_name "Agg1") primary_key=Nothing temporary=True