Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement] Add Trino HLL Function Compatibility Mapping and last_day_of_month Support(StarRocks#40894) (backport #47529) #47538

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package com.starrocks.connector.parser.trino;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.starrocks.analysis.CaseExpr;
Expand Down Expand Up @@ -81,234 +80,250 @@ private static void registerAllFunctionTransformer() {
registerUnicodeFunctionTransformer();
registerMapFunctionTransformer();
registerBinaryFunctionTransformer();
registerHLLFunctionTransformer();
// todo: support more function transform
}

private static void registerAggregateFunctionTransformer() {
// 1.approx_distinct
registerFunctionTransformer("approx_distinct", 1,
"approx_count_distinct", ImmutableList.of(Expr.class));
"approx_count_distinct", List.of(Expr.class));

// 2. arbitrary
registerFunctionTransformer("arbitrary", 1,
"any_value", ImmutableList.of(Expr.class));
"any_value", List.of(Expr.class));

// 3. approx_percentile
registerFunctionTransformer("approx_percentile", 2,
"percentile_approx", ImmutableList.of(Expr.class, Expr.class));
"percentile_approx", List.of(Expr.class, Expr.class));

// 4. stddev
registerFunctionTransformer("stddev", 1,
"stddev_samp", ImmutableList.of(Expr.class));
"stddev_samp", List.of(Expr.class));

// 5. stddev_pop
registerFunctionTransformer("stddev_pop", 1,
"stddev", ImmutableList.of(Expr.class));
"stddev", List.of(Expr.class));

// 6. variance
registerFunctionTransformer("variance", 1,
"var_samp", ImmutableList.of(Expr.class));
"var_samp", List.of(Expr.class));

// 7. var_pop
registerFunctionTransformer("var_pop", 1,
"variance", ImmutableList.of(Expr.class));
"variance", List.of(Expr.class));

// 8. count_if(x) -> count(case when x then 1 end)
registerFunctionTransformer("count_if", 1, new FunctionCallExpr("count",
ImmutableList.of(new CaseExpr(null, ImmutableList.of(new CaseWhenClause(
List.of(new CaseExpr(null, List.of(new CaseWhenClause(
new PlaceholderExpr(1, Expr.class), new IntLiteral(1L))), null))));
}

private static void registerArrayFunctionTransformer() {
// array_union -> array_distinct(array_concat(x, y))
registerFunctionTransformer("array_union", 2, new FunctionCallExpr("array_distinct",
ImmutableList.of(new FunctionCallExpr("array_concat", ImmutableList.of(
List.of(new FunctionCallExpr("array_concat", List.of(
new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class)
)))));
// contains -> array_contains
registerFunctionTransformer("contains", 2, "array_contains",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
// slice -> array_slice
registerFunctionTransformer("slice", 3, "array_slice",
ImmutableList.of(Expr.class, Expr.class, Expr.class));
List.of(Expr.class, Expr.class, Expr.class));
// filter(array, lambda) -> array_filter(array, lambda)
registerFunctionTransformer("filter", 2, "array_filter",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
// contains_sequence -> array_contains_seq
registerFunctionTransformer("contains_sequence", 2, "array_contains_seq",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
}

private static void registerDateFunctionTransformer() {
// to_unixtime -> unix_timestamp
registerFunctionTransformer("to_unixtime", 1, "unix_timestamp",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// date_parse -> str_to_date
registerFunctionTransformer("date_parse", 2, "str_to_date",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// day_of_week -> dayofweek
registerFunctionTransformer("day_of_week", 1, "dayofweek_iso",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// dow -> dayofweek
registerFunctionTransformer("dow", 1, "dayofweek_iso",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// day_of_month -> dayofmonth
registerFunctionTransformer("day_of_month", 1, "dayofmonth",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// day_of_year -> dayofyear
registerFunctionTransformer("day_of_year", 1, "dayofyear",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// doy -> dayofyear
registerFunctionTransformer("doy", 1, "dayofyear",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// week_of_year -> week_iso
registerFunctionTransformer("week_of_year", 1, "week_iso",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// week -> week_iso
registerFunctionTransformer("week", 1, "week_iso",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// format_datetime -> jodatime_format
registerFunctionTransformer("format_datetime", 2, "jodatime_format",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// to_char -> jodatime_format
registerFunctionTransformer("to_char", 2, "jodatime_format",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// parse_datetime -> str_to_jodatime
registerFunctionTransformer("parse_datetime", 2, "str_to_jodatime",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// to_date -> to_tera_date
registerFunctionTransformer("to_date", 2, "to_tera_date",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// to_timestamp -> to_tera_timestamp
registerFunctionTransformer("to_timestamp", 2, "to_tera_timestamp",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// last_day_of_month(x) -> last_day(x,'month')
registerFunctionTransformer("last_day_of_month", 1, new FunctionCallExpr("last_day",
List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("month"))));
}

private static void registerStringFunctionTransformer() {
// chr -> char
registerFunctionTransformer("chr", 1, "char", ImmutableList.of(Expr.class));
registerFunctionTransformer("chr", 1, "char", List.of(Expr.class));

// codepoint -> ascii
registerFunctionTransformer("codepoint", 1, "ascii", ImmutableList.of(Expr.class));
registerFunctionTransformer("codepoint", 1, "ascii", List.of(Expr.class));

// strpos -> locate
registerFunctionTransformer("strpos", 2, new FunctionCallExpr("locate",
ImmutableList.of(new PlaceholderExpr(2, Expr.class), new PlaceholderExpr(1, Expr.class))));
List.of(new PlaceholderExpr(2, Expr.class), new PlaceholderExpr(1, Expr.class))));

// length -> char_length
registerFunctionTransformer("length", 1, "char_length", ImmutableList.of(Expr.class));
registerFunctionTransformer("length", 1, "char_length", List.of(Expr.class));

// str_to_map(str, del1, del2) -> str_to_map(split(str, del1), del2)
registerFunctionTransformer("str_to_map", 3, new FunctionCallExpr("str_to_map",
ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of(
List.of(new FunctionCallExpr("split", List.of(
new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class)
)), new PlaceholderExpr(3, Expr.class))));

// str_to_map(str) -> str_to_map(split(str, del1), del2)
registerFunctionTransformer("str_to_map", 2, new FunctionCallExpr("str_to_map",
ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of(
List.of(new FunctionCallExpr("split", List.of(
new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class)
)), new StringLiteral(":"))));

// str_to_map(str) -> str_to_map(split(str, del1), del2)
registerFunctionTransformer("str_to_map", 1, new FunctionCallExpr("str_to_map",
ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of(
List.of(new FunctionCallExpr("split", List.of(
new PlaceholderExpr(1, Expr.class), new StringLiteral(","))), new StringLiteral(":"))));

// replace(string, search) -> replace(string, search, '')
registerFunctionTransformer("replace", 2, new FunctionCallExpr("replace",
ImmutableList.of(new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class),
List.of(new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class),
new StringLiteral(""))));

registerFunctionTransformer("index", 2, "instr",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
}

private static void registerRegexpFunctionTransformer() {
// regexp_like -> regexp
registerFunctionTransformer("regexp_like", 2, "regexp",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
}

private static void registerURLFunctionTransformer() {
// url_extract_path('https://www.starrocks.io/showcase') -> parse_url('https://www.starrocks.io/showcase', 'PATH')
registerFunctionTransformer("url_extract_path", 1, new FunctionCallExpr("parse_url",
ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("PATH"))));
List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("PATH"))));
}

private static void registerJsonFunctionTransformer() {
// json_array_length -> json_length
registerFunctionTransformer("json_array_length", 1, "json_length",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// json_parse -> parse_json
registerFunctionTransformer("json_parse", 1, "parse_json",
ImmutableList.of(Expr.class));
List.of(Expr.class));

// json_extract -> get_json_string
registerFunctionTransformer("json_extract", 2, "get_json_string",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));

// json_size -> json_length
registerFunctionTransformer("json_size", 2, "json_length",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
}

private static void registerBitwiseFunctionTransformer() {
// bitwise_and -> bitand
registerFunctionTransformer("bitwise_and", 2, "bitand", ImmutableList.of(Expr.class, Expr.class));
registerFunctionTransformer("bitwise_and", 2, "bitand", List.of(Expr.class, Expr.class));

// bitwise_not -> bitnot
registerFunctionTransformer("bitwise_not", 1, "bitnot", ImmutableList.of(Expr.class));
registerFunctionTransformer("bitwise_not", 1, "bitnot", List.of(Expr.class));

// bitwise_or -> bitor
registerFunctionTransformer("bitwise_or", 2, "bitor", ImmutableList.of(Expr.class, Expr.class));
registerFunctionTransformer("bitwise_or", 2, "bitor", List.of(Expr.class, Expr.class));

// bitwise_xor -> bitxor
registerFunctionTransformer("bitwise_xor", 2, "bitxor", ImmutableList.of(Expr.class, Expr.class));
registerFunctionTransformer("bitwise_xor", 2, "bitxor", List.of(Expr.class, Expr.class));

// bitwise_left_shift -> bit_shift_left
registerFunctionTransformer("bitwise_left_shift", 2, "bit_shift_left", ImmutableList.of(Expr.class, Expr.class));
registerFunctionTransformer("bitwise_left_shift", 2, "bit_shift_left", List.of(Expr.class, Expr.class));

// bitwise_right_shift -> bit_shift_right
registerFunctionTransformer("bitwise_right_shift", 2, "bit_shift_right", ImmutableList.of(Expr.class, Expr.class));
registerFunctionTransformer("bitwise_right_shift", 2, "bit_shift_right", List.of(Expr.class, Expr.class));
}

private static void registerUnicodeFunctionTransformer() {
// to_utf8 -> to_binary
registerFunctionTransformer("to_utf8", 1, new FunctionCallExpr("to_binary",
ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8"))));
List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8"))));

// from_utf8 -> from_binary
registerFunctionTransformer("from_utf8", 1, new FunctionCallExpr("from_binary",
ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8"))));
List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8"))));
}

private static void registerMapFunctionTransformer() {
// map(array, array) -> map_from_arrays
registerFunctionTransformer("map", 2, "map_from_arrays",
ImmutableList.of(Expr.class, Expr.class));
List.of(Expr.class, Expr.class));
}

private static void registerBinaryFunctionTransformer() {
// to_hex -> hex
registerFunctionTransformer("to_hex", 1, "hex", ImmutableList.of(Expr.class));
registerFunctionTransformer("to_hex", 1, "hex", List.of(Expr.class));

// from_hex -> hex_decode_binary
registerFunctionTransformer("from_hex", 1, "hex_decode_binary", ImmutableList.of(Expr.class));
registerFunctionTransformer("from_hex", 1, "hex_decode_binary", List.of(Expr.class));
}

private static void registerHLLFunctionTransformer() {
// approx_set -> HLL_HASH
registerFunctionTransformer("approx_set", 1, "hll_hash", List.of(Expr.class));

// empty_approx_set -> HLL_EMPTY
registerFunctionTransformer("empty_approx_set", "hll_empty");

// merge -> HLL_RAW_AGG
registerFunctionTransformer("merge", 1, "hll_raw_agg", List.of(Expr.class));
}

private static void registerFunctionTransformer(String trinoFnName, int trinoFnArgNums, String starRocksFnName,
Expand All @@ -317,6 +332,11 @@ private static void registerFunctionTransformer(String trinoFnName, int trinoFnA
registerFunctionTransformer(trinoFnName, trinoFnArgNums, starRocksFunctionCall);
}

private static void registerFunctionTransformer(String trinoFnName, String starRocksFnName) {
FunctionCallExpr starRocksFunctionCall = buildStarRocksFunctionCall(starRocksFnName, Lists.newArrayList());
registerFunctionTransformer(trinoFnName, 0, starRocksFunctionCall);
}

private static void registerFunctionTransformerWithVarArgs(String trinoFnName, String starRocksFnName,
List<Class<? extends Expr>> starRocksArgumentsClass) {
Preconditions.checkState(starRocksArgumentsClass.size() == 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,4 +441,17 @@ public void testUtilityFunction() throws Exception {
sql = "select current_schema";
assertPlanContains(sql, "<slot 2> : 'test'");
}


@Test
public void testHllFunction() throws Exception {
String sql = "select empty_approx_set()";
assertPlanContains(sql, "<slot 2> : HLL_EMPTY()");

sql = "select approx_set(\"tc\") from tall";
assertPlanContains(sql, "<slot 12> : hll_hash(CAST(3: tc AS VARCHAR))");

sql = "select merge(approx_set(\"tc\")) from tall";
assertPlanContains(sql, "hll_raw_agg(hll_hash(CAST(3: tc AS VARCHAR)))");
}
}
Loading