diff --git a/fe/fe-core/src/main/java/com/starrocks/connector/parser/trino/Trino2SRFunctionCallTransformer.java b/fe/fe-core/src/main/java/com/starrocks/connector/parser/trino/Trino2SRFunctionCallTransformer.java index f33b5f3c412d2..787e3a801a995 100644 --- a/fe/fe-core/src/main/java/com/starrocks/connector/parser/trino/Trino2SRFunctionCallTransformer.java +++ b/fe/fe-core/src/main/java/com/starrocks/connector/parser/trino/Trino2SRFunctionCallTransformer.java @@ -15,7 +15,6 @@ package com.starrocks.connector.parser.trino; import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.starrocks.analysis.CaseExpr; @@ -81,234 +80,250 @@ private static void registerAllFunctionTransformer() { registerUnicodeFunctionTransformer(); registerMapFunctionTransformer(); registerBinaryFunctionTransformer(); + registerHLLFunctionTransformer(); // todo: support more function transform } private static void registerAggregateFunctionTransformer() { // 1.approx_distinct registerFunctionTransformer("approx_distinct", 1, - "approx_count_distinct", ImmutableList.of(Expr.class)); + "approx_count_distinct", List.of(Expr.class)); // 2. arbitrary registerFunctionTransformer("arbitrary", 1, - "any_value", ImmutableList.of(Expr.class)); + "any_value", List.of(Expr.class)); // 3. approx_percentile registerFunctionTransformer("approx_percentile", 2, - "percentile_approx", ImmutableList.of(Expr.class, Expr.class)); + "percentile_approx", List.of(Expr.class, Expr.class)); // 4. stddev registerFunctionTransformer("stddev", 1, - "stddev_samp", ImmutableList.of(Expr.class)); + "stddev_samp", List.of(Expr.class)); // 5. stddev_pop registerFunctionTransformer("stddev_pop", 1, - "stddev", ImmutableList.of(Expr.class)); + "stddev", List.of(Expr.class)); // 6. variance registerFunctionTransformer("variance", 1, - "var_samp", ImmutableList.of(Expr.class)); + "var_samp", List.of(Expr.class)); // 7. var_pop registerFunctionTransformer("var_pop", 1, - "variance", ImmutableList.of(Expr.class)); + "variance", List.of(Expr.class)); // 8. count_if(x) -> count(case when x then 1 end) registerFunctionTransformer("count_if", 1, new FunctionCallExpr("count", - ImmutableList.of(new CaseExpr(null, ImmutableList.of(new CaseWhenClause( + List.of(new CaseExpr(null, List.of(new CaseWhenClause( new PlaceholderExpr(1, Expr.class), new IntLiteral(1L))), null)))); } private static void registerArrayFunctionTransformer() { // array_union -> array_distinct(array_concat(x, y)) registerFunctionTransformer("array_union", 2, new FunctionCallExpr("array_distinct", - ImmutableList.of(new FunctionCallExpr("array_concat", ImmutableList.of( + List.of(new FunctionCallExpr("array_concat", List.of( new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class) ))))); // contains -> array_contains registerFunctionTransformer("contains", 2, "array_contains", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // slice -> array_slice registerFunctionTransformer("slice", 3, "array_slice", - ImmutableList.of(Expr.class, Expr.class, Expr.class)); + List.of(Expr.class, Expr.class, Expr.class)); // filter(array, lambda) -> array_filter(array, lambda) registerFunctionTransformer("filter", 2, "array_filter", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // contains_sequence -> array_contains_seq registerFunctionTransformer("contains_sequence", 2, "array_contains_seq", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); } private static void registerDateFunctionTransformer() { // to_unixtime -> unix_timestamp registerFunctionTransformer("to_unixtime", 1, "unix_timestamp", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // date_parse -> str_to_date registerFunctionTransformer("date_parse", 2, "str_to_date", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // day_of_week -> dayofweek registerFunctionTransformer("day_of_week", 1, "dayofweek_iso", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // dow -> dayofweek registerFunctionTransformer("dow", 1, "dayofweek_iso", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // day_of_month -> dayofmonth registerFunctionTransformer("day_of_month", 1, "dayofmonth", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // day_of_year -> dayofyear registerFunctionTransformer("day_of_year", 1, "dayofyear", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // doy -> dayofyear registerFunctionTransformer("doy", 1, "dayofyear", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // week_of_year -> week_iso registerFunctionTransformer("week_of_year", 1, "week_iso", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // week -> week_iso registerFunctionTransformer("week", 1, "week_iso", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // format_datetime -> jodatime_format registerFunctionTransformer("format_datetime", 2, "jodatime_format", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // to_char -> jodatime_format registerFunctionTransformer("to_char", 2, "jodatime_format", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // parse_datetime -> str_to_jodatime registerFunctionTransformer("parse_datetime", 2, "str_to_jodatime", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // to_date -> to_tera_date registerFunctionTransformer("to_date", 2, "to_tera_date", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // to_timestamp -> to_tera_timestamp registerFunctionTransformer("to_timestamp", 2, "to_tera_timestamp", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); + + // last_day_of_month(x) -> last_day(x,'month') + registerFunctionTransformer("last_day_of_month", 1, new FunctionCallExpr("last_day", + List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("month")))); } private static void registerStringFunctionTransformer() { // chr -> char - registerFunctionTransformer("chr", 1, "char", ImmutableList.of(Expr.class)); + registerFunctionTransformer("chr", 1, "char", List.of(Expr.class)); // codepoint -> ascii - registerFunctionTransformer("codepoint", 1, "ascii", ImmutableList.of(Expr.class)); + registerFunctionTransformer("codepoint", 1, "ascii", List.of(Expr.class)); // strpos -> locate registerFunctionTransformer("strpos", 2, new FunctionCallExpr("locate", - ImmutableList.of(new PlaceholderExpr(2, Expr.class), new PlaceholderExpr(1, Expr.class)))); + List.of(new PlaceholderExpr(2, Expr.class), new PlaceholderExpr(1, Expr.class)))); // length -> char_length - registerFunctionTransformer("length", 1, "char_length", ImmutableList.of(Expr.class)); + registerFunctionTransformer("length", 1, "char_length", List.of(Expr.class)); // str_to_map(str, del1, del2) -> str_to_map(split(str, del1), del2) registerFunctionTransformer("str_to_map", 3, new FunctionCallExpr("str_to_map", - ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of( + List.of(new FunctionCallExpr("split", List.of( new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class) )), new PlaceholderExpr(3, Expr.class)))); // str_to_map(str) -> str_to_map(split(str, del1), del2) registerFunctionTransformer("str_to_map", 2, new FunctionCallExpr("str_to_map", - ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of( + List.of(new FunctionCallExpr("split", List.of( new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class) )), new StringLiteral(":")))); // str_to_map(str) -> str_to_map(split(str, del1), del2) registerFunctionTransformer("str_to_map", 1, new FunctionCallExpr("str_to_map", - ImmutableList.of(new FunctionCallExpr("split", ImmutableList.of( + List.of(new FunctionCallExpr("split", List.of( new PlaceholderExpr(1, Expr.class), new StringLiteral(","))), new StringLiteral(":")))); // replace(string, search) -> replace(string, search, '') registerFunctionTransformer("replace", 2, new FunctionCallExpr("replace", - ImmutableList.of(new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class), + List.of(new PlaceholderExpr(1, Expr.class), new PlaceholderExpr(2, Expr.class), new StringLiteral("")))); registerFunctionTransformer("index", 2, "instr", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); } private static void registerRegexpFunctionTransformer() { // regexp_like -> regexp registerFunctionTransformer("regexp_like", 2, "regexp", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); } private static void registerURLFunctionTransformer() { // url_extract_path('https://www.starrocks.io/showcase') -> parse_url('https://www.starrocks.io/showcase', 'PATH') registerFunctionTransformer("url_extract_path", 1, new FunctionCallExpr("parse_url", - ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("PATH")))); + List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("PATH")))); } private static void registerJsonFunctionTransformer() { // json_array_length -> json_length registerFunctionTransformer("json_array_length", 1, "json_length", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // json_parse -> parse_json registerFunctionTransformer("json_parse", 1, "parse_json", - ImmutableList.of(Expr.class)); + List.of(Expr.class)); // json_extract -> get_json_string registerFunctionTransformer("json_extract", 2, "get_json_string", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); // json_size -> json_length registerFunctionTransformer("json_size", 2, "json_length", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); } private static void registerBitwiseFunctionTransformer() { // bitwise_and -> bitand - registerFunctionTransformer("bitwise_and", 2, "bitand", ImmutableList.of(Expr.class, Expr.class)); + registerFunctionTransformer("bitwise_and", 2, "bitand", List.of(Expr.class, Expr.class)); // bitwise_not -> bitnot - registerFunctionTransformer("bitwise_not", 1, "bitnot", ImmutableList.of(Expr.class)); + registerFunctionTransformer("bitwise_not", 1, "bitnot", List.of(Expr.class)); // bitwise_or -> bitor - registerFunctionTransformer("bitwise_or", 2, "bitor", ImmutableList.of(Expr.class, Expr.class)); + registerFunctionTransformer("bitwise_or", 2, "bitor", List.of(Expr.class, Expr.class)); // bitwise_xor -> bitxor - registerFunctionTransformer("bitwise_xor", 2, "bitxor", ImmutableList.of(Expr.class, Expr.class)); + registerFunctionTransformer("bitwise_xor", 2, "bitxor", List.of(Expr.class, Expr.class)); // bitwise_left_shift -> bit_shift_left - registerFunctionTransformer("bitwise_left_shift", 2, "bit_shift_left", ImmutableList.of(Expr.class, Expr.class)); + registerFunctionTransformer("bitwise_left_shift", 2, "bit_shift_left", List.of(Expr.class, Expr.class)); // bitwise_right_shift -> bit_shift_right - registerFunctionTransformer("bitwise_right_shift", 2, "bit_shift_right", ImmutableList.of(Expr.class, Expr.class)); + registerFunctionTransformer("bitwise_right_shift", 2, "bit_shift_right", List.of(Expr.class, Expr.class)); } private static void registerUnicodeFunctionTransformer() { // to_utf8 -> to_binary registerFunctionTransformer("to_utf8", 1, new FunctionCallExpr("to_binary", - ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8")))); + List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8")))); // from_utf8 -> from_binary registerFunctionTransformer("from_utf8", 1, new FunctionCallExpr("from_binary", - ImmutableList.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8")))); + List.of(new PlaceholderExpr(1, Expr.class), new StringLiteral("utf8")))); } private static void registerMapFunctionTransformer() { // map(array, array) -> map_from_arrays registerFunctionTransformer("map", 2, "map_from_arrays", - ImmutableList.of(Expr.class, Expr.class)); + List.of(Expr.class, Expr.class)); } private static void registerBinaryFunctionTransformer() { // to_hex -> hex - registerFunctionTransformer("to_hex", 1, "hex", ImmutableList.of(Expr.class)); + registerFunctionTransformer("to_hex", 1, "hex", List.of(Expr.class)); // from_hex -> hex_decode_binary - registerFunctionTransformer("from_hex", 1, "hex_decode_binary", ImmutableList.of(Expr.class)); + registerFunctionTransformer("from_hex", 1, "hex_decode_binary", List.of(Expr.class)); + } + + private static void registerHLLFunctionTransformer() { + // approx_set -> HLL_HASH + registerFunctionTransformer("approx_set", 1, "hll_hash", List.of(Expr.class)); + + // empty_approx_set -> HLL_EMPTY + registerFunctionTransformer("empty_approx_set", "hll_empty"); + + // merge -> HLL_RAW_AGG + registerFunctionTransformer("merge", 1, "hll_raw_agg", List.of(Expr.class)); } private static void registerFunctionTransformer(String trinoFnName, int trinoFnArgNums, String starRocksFnName, @@ -317,6 +332,11 @@ private static void registerFunctionTransformer(String trinoFnName, int trinoFnA registerFunctionTransformer(trinoFnName, trinoFnArgNums, starRocksFunctionCall); } + private static void registerFunctionTransformer(String trinoFnName, String starRocksFnName) { + FunctionCallExpr starRocksFunctionCall = buildStarRocksFunctionCall(starRocksFnName, Lists.newArrayList()); + registerFunctionTransformer(trinoFnName, 0, starRocksFunctionCall); + } + private static void registerFunctionTransformerWithVarArgs(String trinoFnName, String starRocksFnName, List> starRocksArgumentsClass) { Preconditions.checkState(starRocksArgumentsClass.size() == 1); diff --git a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java index 519c2388d9317..166e537d48290 100644 --- a/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java +++ b/fe/fe-core/src/test/java/com/starrocks/connector/parser/trino/TrinoFunctionTransformTest.java @@ -441,4 +441,17 @@ public void testUtilityFunction() throws Exception { sql = "select current_schema"; assertPlanContains(sql, " : 'test'"); } + + + @Test + public void testHllFunction() throws Exception { + String sql = "select empty_approx_set()"; + assertPlanContains(sql, " : HLL_EMPTY()"); + + sql = "select approx_set(\"tc\") from tall"; + assertPlanContains(sql, " : hll_hash(CAST(3: tc AS VARCHAR))"); + + sql = "select merge(approx_set(\"tc\")) from tall"; + assertPlanContains(sql, "hll_raw_agg(hll_hash(CAST(3: tc AS VARCHAR)))"); + } }