From 0bb7e64c498e5c9809f7097650194f8a36c12424 Mon Sep 17 00:00:00 2001 From: Yang Zhang Date: Tue, 3 Dec 2024 01:22:49 -0800 Subject: [PATCH] refactor(sparksql): Speed up sparksql compilation by splitting function registrations (#11565) Summary: This PR aims to speed up sparksql compilation by splitting function registrations to multiple source files arranged according to function type. Adds 'velox_functions_spark' for registrations and renames previous 'velox_functions_spark' as 'velox_functions_spark_impl'. Tested the compilation time using `velox_functions_spark_test` target to mock the general development process: build -> modify cpp file -> build. The compilation time speeds up 1.5x(165s to 104s) in release mode and more in debug mode. Fixes https://github.com/facebookincubator/velox/issues/11564. Pull Request resolved: https://github.com/facebookincubator/velox/pull/11565 Reviewed By: miaoever, kagamiori Differential Revision: D66688101 Pulled By: xiaoxmeng fbshipit-source-id: 54ba372f08c4ec91062b3d07e8e2b81aabbdef59 --- pyvelox/signatures.cpp | 2 +- .../fuzzer/SparkExpressionFuzzerTest.cpp | 2 +- .../expression/tests/ExpressionRunnerTest.cpp | 2 +- velox/functions/sparksql/Bitwise.cpp | 160 ------ velox/functions/sparksql/Bitwise.h | 109 +++- velox/functions/sparksql/CMakeLists.txt | 23 +- velox/functions/sparksql/JsonObjectKeys.h | 1 + velox/functions/sparksql/Register.cpp | 543 ------------------ .../tests/CollectSetAggregateTest.cpp | 2 +- .../sparksql/benchmarks/CompareBenchmark.cpp | 2 +- .../sparksql/benchmarks/HashBenchmark.cpp | 2 +- .../benchmarks/SIMDCompareBenchmark.cpp | 2 +- .../functions/sparksql/coverage/Coverage.cpp | 2 +- .../fuzzer/tests/SparkQueryRunnerTest.cpp | 2 +- .../sparksql/registration/CMakeLists.txt | 46 ++ .../sparksql/registration/Register.cpp | 69 +++ .../sparksql/{ => registration}/Register.h | 2 +- .../sparksql/registration/RegisterArray.cpp | 131 +++++ .../sparksql/registration/RegisterBinary.cpp | 49 ++ .../sparksql/registration/RegisterBitwise.cpp | 54 ++ .../RegisterComparison.cpp} | 27 +- .../registration/RegisterDatetime.cpp | 91 +++ .../RegisterJson.cpp} | 12 +- .../sparksql/registration/RegisterMap.cpp | 45 ++ .../RegisterMath.cpp} | 49 +- .../sparksql/registration/RegisterMisc.cpp | 40 ++ .../sparksql/registration/RegisterRegexp.cpp | 36 ++ .../registration/RegisterSpecialForm.cpp | 49 ++ .../sparksql/registration/RegisterString.cpp | 147 +++++ .../RegisterUrl.cpp} | 14 +- .../functions/sparksql/tests/RegisterTest.cpp | 2 +- velox/functions/sparksql/tests/SliceTest.cpp | 2 +- .../sparksql/tests/SortArrayTest.cpp | 2 +- .../sparksql/tests/SparkCastExprTest.cpp | 2 +- .../sparksql/tests/SparkFunctionBaseTest.h | 2 +- 35 files changed, 947 insertions(+), 778 deletions(-) delete mode 100644 velox/functions/sparksql/Bitwise.cpp delete mode 100644 velox/functions/sparksql/Register.cpp create mode 100644 velox/functions/sparksql/registration/CMakeLists.txt create mode 100644 velox/functions/sparksql/registration/Register.cpp rename velox/functions/sparksql/{ => registration}/Register.h (94%) create mode 100644 velox/functions/sparksql/registration/RegisterArray.cpp create mode 100644 velox/functions/sparksql/registration/RegisterBinary.cpp create mode 100644 velox/functions/sparksql/registration/RegisterBitwise.cpp rename velox/functions/sparksql/{RegisterCompare.cpp => registration/RegisterComparison.cpp} (77%) create mode 100644 velox/functions/sparksql/registration/RegisterDatetime.cpp rename velox/functions/sparksql/{RegisterCompare.h => registration/RegisterJson.cpp} (72%) create mode 100644 velox/functions/sparksql/registration/RegisterMap.cpp rename velox/functions/sparksql/{RegisterArithmetic.cpp => registration/RegisterMath.cpp} (96%) create mode 100644 velox/functions/sparksql/registration/RegisterMisc.cpp create mode 100644 velox/functions/sparksql/registration/RegisterRegexp.cpp create mode 100644 velox/functions/sparksql/registration/RegisterSpecialForm.cpp create mode 100644 velox/functions/sparksql/registration/RegisterString.cpp rename velox/functions/sparksql/{RegisterArithmetic.h => registration/RegisterUrl.cpp} (67%) diff --git a/pyvelox/signatures.cpp b/pyvelox/signatures.cpp index 27b5674f6557..5e8f7e9a501a 100644 --- a/pyvelox/signatures.cpp +++ b/pyvelox/signatures.cpp @@ -19,8 +19,8 @@ #include "velox/functions/FunctionRegistry.h" #include "velox/functions/prestosql/aggregates/RegisterAggregateFunctions.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/functions/sparksql/Register.h" #include "velox/functions/sparksql/aggregates/Register.h" +#include "velox/functions/sparksql/registration/Register.h" namespace facebook::velox::py { diff --git a/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp b/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp index 3ddf9bce9d46..60e1af2281ed 100644 --- a/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp +++ b/velox/expression/fuzzer/SparkExpressionFuzzerTest.cpp @@ -24,12 +24,12 @@ #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/expression/fuzzer/FuzzerRunner.h" -#include "velox/functions/sparksql/Register.h" #include "velox/functions/sparksql/fuzzer/AddSubtractArgGenerator.h" #include "velox/functions/sparksql/fuzzer/DivideArgGenerator.h" #include "velox/functions/sparksql/fuzzer/MakeTimestampArgGenerator.h" #include "velox/functions/sparksql/fuzzer/MultiplyArgGenerator.h" #include "velox/functions/sparksql/fuzzer/UnscaledValueArgGenerator.h" +#include "velox/functions/sparksql/registration/Register.h" using namespace facebook::velox::functions::sparksql::fuzzer; using facebook::velox::fuzzer::ArgGenerator; diff --git a/velox/expression/tests/ExpressionRunnerTest.cpp b/velox/expression/tests/ExpressionRunnerTest.cpp index b942574d92e0..c65c7a95daa0 100644 --- a/velox/expression/tests/ExpressionRunnerTest.cpp +++ b/velox/expression/tests/ExpressionRunnerTest.cpp @@ -27,7 +27,7 @@ #include "velox/exec/fuzzer/ReferenceQueryRunner.h" #include "velox/expression/tests/ExpressionVerifier.h" #include "velox/functions/prestosql/registration/RegistrationFunctions.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/vector/VectorSaver.h" using namespace facebook::velox; diff --git a/velox/functions/sparksql/Bitwise.cpp b/velox/functions/sparksql/Bitwise.cpp deleted file mode 100644 index 056f3b83c704..000000000000 --- a/velox/functions/sparksql/Bitwise.cpp +++ /dev/null @@ -1,160 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "velox/functions/sparksql/Bitwise.h" - -namespace facebook::velox::functions::sparksql { - -template -struct BitwiseAndFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { - result = a & b; - } -}; - -template -struct BitwiseOrFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { - result = a | b; - } -}; - -template -struct BitwiseXorFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { - result = a ^ b; - } -}; - -template -struct BitwiseNotFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a) { - result = ~a; - } -}; - -template -struct ShiftLeftFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput1& result, TInput1 a, TInput2 b) { - if constexpr (std::is_same_v) { - if (b < 0) { - b = b % 32 + 32; - } - if (b >= 32) { - b = b % 32; - } - } - if constexpr (std::is_same_v) { - if (b < 0) { - b = b % 64 + 64; - } - if (b >= 64) { - b = b % 64; - } - } - result = a << b; - } -}; - -template -struct ShiftRightFunction { - template - FOLLY_ALWAYS_INLINE void call(TInput1& result, TInput1 a, TInput2 b) { - if constexpr (std::is_same_v) { - if (b < 0) { - b = b % 32 + 32; - } - if (b >= 32) { - b = b % 32; - } - } - if constexpr (std::is_same_v) { - if (b < 0) { - b = b % 64 + 64; - } - if (b >= 64) { - b = b % 64; - } - } - result = a >> b; - } -}; - -template -struct BitCountFunction { - template - FOLLY_ALWAYS_INLINE void call(int32_t& result, TInput num) { - constexpr int kMaxBits = sizeof(TInput) * CHAR_BIT; - auto value = static_cast(num); - result = bits::countBits(&value, 0, kMaxBits); - } -}; - -template -struct BitGetFunction { - template - FOLLY_ALWAYS_INLINE void call(int8_t& result, TInput num, int32_t pos) { - constexpr int kMaxBits = sizeof(TInput) * CHAR_BIT; - VELOX_USER_CHECK_GE( - pos, - 0, - "The value of 'pos' argument must be greater than or equal to zero."); - VELOX_USER_CHECK_LT( - pos, - kMaxBits, - "The value of 'pos' argument must not exceed the number of bits in 'x' - 1."); - result = (num >> pos) & 1; - } -}; - -void registerBitwiseFunctions(const std::string& prefix) { - registerBinaryIntegral({prefix + "bitwise_and"}); - registerBinaryIntegral({prefix + "bitwise_or"}); - registerBinaryIntegral({prefix + "bitwise_xor"}); - - registerUnaryIntegral({prefix + "bitwise_not"}); - - registerFunction({prefix + "bit_count"}); - registerFunction({prefix + "bit_count"}); - registerFunction({prefix + "bit_count"}); - registerFunction({prefix + "bit_count"}); - registerFunction({prefix + "bit_count"}); - - registerFunction( - {prefix + "bit_get"}); - registerFunction( - {prefix + "bit_get"}); - registerFunction( - {prefix + "bit_get"}); - registerFunction( - {prefix + "bit_get"}); - - registerFunction( - {prefix + "shiftleft"}); - registerFunction( - {prefix + "shiftleft"}); - - registerFunction( - {prefix + "shiftright"}); - registerFunction( - {prefix + "shiftright"}); -} - -} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Bitwise.h b/velox/functions/sparksql/Bitwise.h index 47f962707030..d5f8c71a9fff 100644 --- a/velox/functions/sparksql/Bitwise.h +++ b/velox/functions/sparksql/Bitwise.h @@ -15,12 +15,115 @@ */ #pragma once -#include #include "velox/functions/Macros.h" -#include "velox/functions/lib/RegistrationHelpers.h" namespace facebook::velox::functions::sparksql { -void registerBitwiseFunctions(const std::string& prefix); +template +struct BitwiseAndFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { + result = a & b; + } +}; + +template +struct BitwiseOrFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { + result = a | b; + } +}; + +template +struct BitwiseXorFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a, TInput b) { + result = a ^ b; + } +}; + +template +struct BitwiseNotFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput& result, TInput a) { + result = ~a; + } +}; + +template +struct ShiftLeftFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput1& result, TInput1 a, TInput2 b) { + if constexpr (std::is_same_v) { + if (b < 0) { + b = b % 32 + 32; + } + if (b >= 32) { + b = b % 32; + } + } + if constexpr (std::is_same_v) { + if (b < 0) { + b = b % 64 + 64; + } + if (b >= 64) { + b = b % 64; + } + } + result = a << b; + } +}; + +template +struct ShiftRightFunction { + template + FOLLY_ALWAYS_INLINE void call(TInput1& result, TInput1 a, TInput2 b) { + if constexpr (std::is_same_v) { + if (b < 0) { + b = b % 32 + 32; + } + if (b >= 32) { + b = b % 32; + } + } + if constexpr (std::is_same_v) { + if (b < 0) { + b = b % 64 + 64; + } + if (b >= 64) { + b = b % 64; + } + } + result = a >> b; + } +}; + +template +struct BitCountFunction { + template + FOLLY_ALWAYS_INLINE void call(int32_t& result, TInput num) { + constexpr int kMaxBits = sizeof(TInput) * CHAR_BIT; + auto value = static_cast(num); + result = bits::countBits(&value, 0, kMaxBits); + } +}; + +template +struct BitGetFunction { + template + FOLLY_ALWAYS_INLINE void call(int8_t& result, TInput num, int32_t pos) { + constexpr int kMaxBits = sizeof(TInput) * CHAR_BIT; + VELOX_USER_CHECK_GE( + pos, + 0, + "The value of 'pos' argument must be greater than or equal to zero."); + VELOX_USER_CHECK_LT( + pos, + kMaxBits, + "The value of 'pos' argument must not exceed the number of bits in 'x' - 1."); + result = (num >> pos) & 1; + } +}; } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/CMakeLists.txt b/velox/functions/sparksql/CMakeLists.txt index 1b940591e2a4..5e2f5ad58271 100644 --- a/velox/functions/sparksql/CMakeLists.txt +++ b/velox/functions/sparksql/CMakeLists.txt @@ -14,10 +14,9 @@ add_subdirectory(specialforms) velox_add_library( - velox_functions_spark + velox_functions_spark_impl ArrayGetFunction.cpp ArraySort.cpp - Bitwise.cpp Comparisons.cpp DecimalArithmetic.cpp DecimalCompare.cpp @@ -27,34 +26,22 @@ velox_add_library( MakeTimestamp.cpp Map.cpp RegexFunctions.cpp - Register.cpp - RegisterArithmetic.cpp - RegisterCompare.cpp Size.cpp String.cpp UnscaledValueFunction.cpp) -# GCC 12 has a bug where it does not respect "pragma ignore" directives and ends -# up failing compilation in an openssl header included by a hash-related -# function. -if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND NOT VELOX_MONO_LIBRARY) - target_compile_options(velox_functions_spark - PRIVATE -Wno-deprecated-declarations) -endif() - velox_link_libraries( - velox_functions_spark + velox_functions_spark_impl velox_functions_lib velox_functions_prestosql_impl velox_functions_spark_specialforms - velox_is_null_functions velox_functions_util Folly::folly simdjson::simdjson) if(NOT VELOX_MONO_LIBRARY) - set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE - high_memory_pool) + set_property(TARGET velox_functions_spark_impl PROPERTY JOB_POOL_COMPILE + high_memory_pool) endif() add_subdirectory(window) @@ -72,3 +59,5 @@ endif() if(${VELOX_ENABLE_BENCHMARKS}) add_subdirectory(benchmarks) endif() + +add_subdirectory(registration) diff --git a/velox/functions/sparksql/JsonObjectKeys.h b/velox/functions/sparksql/JsonObjectKeys.h index a320cbb08413..e9fa922e10e1 100644 --- a/velox/functions/sparksql/JsonObjectKeys.h +++ b/velox/functions/sparksql/JsonObjectKeys.h @@ -15,6 +15,7 @@ */ #pragma once +#include "velox/functions/Macros.h" #include "velox/functions/prestosql/json/SIMDJsonUtil.h" namespace facebook::velox::functions::sparksql { diff --git a/velox/functions/sparksql/Register.cpp b/velox/functions/sparksql/Register.cpp deleted file mode 100644 index c753eca82baa..000000000000 --- a/velox/functions/sparksql/Register.cpp +++ /dev/null @@ -1,543 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "velox/functions/sparksql/Register.h" - -#include "velox/expression/RegisterSpecialForm.h" -#include "velox/expression/RowConstructor.h" -#include "velox/expression/SpecialFormRegistry.h" -#include "velox/functions/lib/ArrayShuffle.h" -#include "velox/functions/lib/IsNull.h" -#include "velox/functions/lib/Re2Functions.h" -#include "velox/functions/lib/RegistrationHelpers.h" -#include "velox/functions/lib/Repeat.h" -#include "velox/functions/lib/Slice.h" -#include "velox/functions/prestosql/ArrayFunctions.h" -#include "velox/functions/prestosql/BinaryFunctions.h" -#include "velox/functions/prestosql/DateTimeFunctions.h" -#include "velox/functions/prestosql/StringFunctions.h" -#include "velox/functions/prestosql/URLFunctions.h" -#include "velox/functions/sparksql/ArrayFlattenFunction.h" -#include "velox/functions/sparksql/ArrayInsert.h" -#include "velox/functions/sparksql/ArrayMinMaxFunction.h" -#include "velox/functions/sparksql/ArraySort.h" -#include "velox/functions/sparksql/Bitwise.h" -#include "velox/functions/sparksql/DateTimeFunctions.h" -#include "velox/functions/sparksql/Hash.h" -#include "velox/functions/sparksql/In.h" -#include "velox/functions/sparksql/JsonObjectKeys.h" -#include "velox/functions/sparksql/LeastGreatest.h" -#include "velox/functions/sparksql/MaskFunction.h" -#include "velox/functions/sparksql/MightContain.h" -#include "velox/functions/sparksql/MonotonicallyIncreasingId.h" -#include "velox/functions/sparksql/RaiseError.h" -#include "velox/functions/sparksql/RegexFunctions.h" -#include "velox/functions/sparksql/RegisterArithmetic.h" -#include "velox/functions/sparksql/RegisterCompare.h" -#include "velox/functions/sparksql/Size.h" -#include "velox/functions/sparksql/SparkPartitionId.h" -#include "velox/functions/sparksql/Split.h" -#include "velox/functions/sparksql/String.h" -#include "velox/functions/sparksql/StringToMap.h" -#include "velox/functions/sparksql/UnscaledValueFunction.h" -#include "velox/functions/sparksql/Uuid.h" -#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" -#include "velox/functions/sparksql/specialforms/DecimalRound.h" -#include "velox/functions/sparksql/specialforms/MakeDecimal.h" -#include "velox/functions/sparksql/specialforms/SparkCastExpr.h" - -namespace facebook::velox::functions { -extern void registerElementAtFunction( - const std::string& name, - bool enableCaching); - -template -inline void registerArrayRemoveFunctions(const std::string& prefix) { - registerFunction, Array, T>( - {prefix + "array_remove"}); -} - -inline void registerArrayRemoveFunctions(const std::string& prefix) { - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions(prefix); - registerArrayRemoveFunctions>(prefix); - registerFunction< - ArrayRemoveFunctionString, - Array, - Array, - Varchar>({prefix + "array_remove"}); -} - -static void workAroundRegistrationMacro(const std::string& prefix) { - // VELOX_REGISTER_VECTOR_FUNCTION must be invoked in the same namespace as the - // vector function definition. - // Higher order functions. - VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "aggregate"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter"); - // Spark and Presto map_filter function has the same definition: - // function expression corresponds to body, arguments to signature - VELOX_REGISTER_VECTOR_FUNCTION(udf_map_filter, prefix + "map_filter"); - // Complex types. - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, prefix + "array"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, prefix + "array_contains"); - VELOX_REGISTER_VECTOR_FUNCTION( - udf_array_intersect, prefix + "array_intersect"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_except, prefix + "array_except"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_position, prefix + "array_position"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_zip_with, prefix + "zip_with"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "forall"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "exists"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_zip, prefix + "arrays_zip"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with"); - - // This is the semantics of spark.sql.ansi.enabled = false. - registerElementAtFunction(prefix + "element_at", true); - - VELOX_REGISTER_VECTOR_FUNCTION( - udf_map_allow_duplicates, prefix + "map_from_arrays"); - VELOX_REGISTER_VECTOR_FUNCTION( - udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor); - // String functions. - VELOX_REGISTER_VECTOR_FUNCTION(udf_concat, prefix + "concat"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_lower, prefix + "lower"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, prefix + "upper"); - VELOX_REGISTER_VECTOR_FUNCTION(udf_reverse, prefix + "reverse"); - // Logical. - VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not"); - registerIsNullFunction(prefix + "isnull"); - registerIsNotNullFunction(prefix + "isnotnull"); - registerArrayRemoveFunctions(prefix); -} - -namespace sparksql { - -void registerAllSpecialFormGeneralFunctions() { - exec::registerFunctionCallToSpecialForms(); - exec::registerFunctionCallToSpecialForm( - MakeDecimalCallToSpecialForm::kMakeDecimal, - std::make_unique()); - exec::registerFunctionCallToSpecialForm( - DecimalRoundCallToSpecialForm::kRoundDecimal, - std::make_unique()); - registerFunctionCallToSpecialForm( - "cast", std::make_unique()); - registerFunctionCallToSpecialForm( - "try_cast", std::make_unique()); - exec::registerFunctionCallToSpecialForm( - AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls, - std::make_unique()); -} - -namespace { -template -inline void registerArrayMinMaxFunctions(const std::string& prefix) { - registerFunction>({prefix + "array_min"}); - registerFunction>({prefix + "array_max"}); -} - -inline void registerArrayMinMaxFunctions(const std::string& prefix) { - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); - registerArrayMinMaxFunctions(prefix); -} -} // namespace - -void registerFunctions(const std::string& prefix) { - registerAllSpecialFormGeneralFunctions(); - - // Register size functions - registerSize(prefix + "size"); - - registerRegexpReplace(prefix); - - registerFunction, Varchar>( - {prefix + "json_object_keys"}); - - // Register string functions. - registerFunction({prefix + "chr"}); - registerFunction({prefix + "ascii"}); - registerFunction( - {prefix + "lpad"}); - registerFunction( - {prefix + "rpad"}); - registerFunction( - {prefix + "lpad"}); - registerFunction( - {prefix + "rpad"}); - registerFunction( - {prefix + "substring"}); - registerFunction< - sparksql::SubstrFunction, - Varchar, - Varchar, - int32_t, - int32_t>({prefix + "substring"}); - registerFunction< - sparksql::OverlayVarcharFunction, - Varchar, - Varchar, - Varchar, - int32_t, - int32_t>({prefix + "overlay"}); - registerFunction< - sparksql::OverlayVarbinaryFunction, - Varbinary, - Varbinary, - Varbinary, - int32_t, - int32_t>({prefix + "overlay"}); - - registerFunction< - sparksql::StringToMapFunction, - Map, - Varchar, - Varchar, - Varchar>({prefix + "str_to_map"}); - - registerFunction( - {prefix + "left"}); - - registerFunction( - {prefix + "bit_length"}); - registerFunction( - {prefix + "bit_length"}); - - exec::registerStatefulVectorFunction( - prefix + "instr", instrSignatures(), makeInstr); - exec::registerStatefulVectorFunction( - prefix + "length", lengthSignatures(), makeLength); - registerFunction( - {prefix + "substring_index"}); - - registerFunction({prefix + "md5"}); - registerFunction( - {prefix + "sha1"}); - registerFunction( - {prefix + "sha2"}); - registerFunction({prefix + "crc32"}); - registerFunction( - {prefix + "empty2null"}); - - exec::registerStatefulVectorFunction( - prefix + "regexp_extract", re2ExtractSignatures(), makeRegexExtract); - exec::registerStatefulVectorFunction( - prefix + "regexp_extract_all", - re2ExtractAllSignatures(), - makeRe2ExtractAll); - exec::registerStatefulVectorFunction( - prefix + "rlike", re2SearchSignatures(), makeRLike); - exec::registerStatefulVectorFunction( - prefix + "like", likeSignatures(), makeLike); - - exec::registerStatefulVectorFunction( - prefix + "least", - leastSignatures(), - makeLeast, - exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); - exec::registerStatefulVectorFunction( - prefix + "greatest", - greatestSignatures(), - makeGreatest, - exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); - exec::registerStatefulVectorFunction( - prefix + "hash", hashSignatures(), makeHash, hashMetadata()); - exec::registerStatefulVectorFunction( - prefix + "hash_with_seed", - hashWithSeedSignatures(), - makeHashWithSeed, - hashMetadata()); - exec::registerStatefulVectorFunction( - prefix + "xxhash64", xxhash64Signatures(), makeXxHash64, hashMetadata()); - exec::registerStatefulVectorFunction( - prefix + "xxhash64_with_seed", - xxhash64WithSeedSignatures(), - makeXxHash64WithSeed, - hashMetadata()); - VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map"); - - // Register 'in' functions. - registerIn(prefix); - - // These vector functions are only accessible via the - // VELOX_REGISTER_VECTOR_FUNCTION macro, which must be invoked in the same - // namespace as the function definition. - workAroundRegistrationMacro(prefix); - - // These groups of functions involve instantiating many templates. They're - // broken out into a separate compilation unit to improve build latency. - registerArithmeticFunctions(prefix); - registerCompareFunctions(prefix); - registerBitwiseFunctions(prefix); - - // String search function - registerFunction( - {prefix + "startswith"}); - registerFunction( - {prefix + "endswith"}); - registerFunction( - {prefix + "contains"}); - registerFunction( - {prefix + "locate"}); - - registerFunction({prefix + "trim"}); - registerFunction({prefix + "trim"}); - registerFunction({prefix + "ltrim"}); - registerFunction( - {prefix + "ltrim"}); - registerFunction({prefix + "rtrim"}); - registerFunction( - {prefix + "rtrim"}); - - registerFunction( - {prefix + "translate"}); - - registerFunction( - {prefix + "conv"}); - - registerFunction( - {prefix + "replace"}); - registerFunction( - {prefix + "replace"}); - - registerFunction( - {prefix + "find_in_set"}); - - registerFunction( - {prefix + "url_encode"}); - registerFunction( - {prefix + "url_decode"}); - - // Register array sort functions. - exec::registerStatefulVectorFunction( - prefix + "array_sort", arraySortSignatures(), makeArraySort); - exec::registerStatefulVectorFunction( - prefix + "sort_array", sortArraySignatures(), makeSortArray); - - exec::registerStatefulVectorFunction( - prefix + "array_repeat", - repeatSignatures(), - makeRepeatAllowNegativeCount, - repeatMetadata()); - - registerIntegerSliceFunction(prefix); - - exec::registerStatefulVectorFunction( - prefix + "shuffle", - arrayShuffleWithCustomSeedSignatures(), - makeArrayShuffleWithCustomSeed, - getMetadataForArrayShuffle()); - - VELOX_REGISTER_VECTOR_FUNCTION(udf_array_get, prefix + "get"); - - // Register date functions. - registerFunction({prefix + "year"}); - registerFunction({prefix + "year"}); - registerFunction({prefix + "week_of_year"}); - registerFunction( - {prefix + "year_of_week"}); - - registerFunction( - {prefix + "to_utc_timestamp"}); - registerFunction( - {prefix + "from_utc_timestamp"}); - - registerFunction({prefix + "unix_date"}); - - registerFunction( - {prefix + "unix_seconds"}); - - registerFunction({prefix + "unix_timestamp"}); - - registerFunction( - {prefix + "unix_timestamp", prefix + "to_unix_timestamp"}); - registerFunction< - UnixTimestampParseWithFormatFunction, - int64_t, - Varchar, - Varchar>({prefix + "unix_timestamp", prefix + "to_unix_timestamp"}); - registerFunction( - {prefix + "from_unixtime"}); - registerFunction( - {prefix + "make_date"}); - registerFunction( - {prefix + "datediff"}); - registerFunction({prefix + "last_day"}); - registerFunction( - {prefix + "add_months"}); - - registerFunction({prefix + "date_add"}); - registerFunction({prefix + "date_add"}); - registerFunction({prefix + "date_add"}); - - registerFunction( - {prefix + "date_from_unix_date"}); - - registerFunction({prefix + "date_sub"}); - registerFunction({prefix + "date_sub"}); - registerFunction({prefix + "date_sub"}); - - registerFunction( - {prefix + "day", prefix + "dayofmonth"}); - registerFunction( - {prefix + "doy", prefix + "dayofyear"}); - - registerFunction({prefix + "dayofweek"}); - - registerFunction({prefix + "weekday"}); - - registerFunction({prefix + "quarter"}); - - registerFunction({prefix + "month"}); - - registerFunction({prefix + "next_day"}); - - registerFunction( - {prefix + "get_timestamp"}); - - registerFunction({prefix + "hour"}); - - registerFunction({prefix + "minute"}); - - registerFunction({prefix + "second"}); - - registerFunction( - {prefix + "make_ym_interval"}); - registerFunction( - {prefix + "make_ym_interval"}); - registerFunction( - {prefix + "make_ym_interval"}); - - VELOX_REGISTER_VECTOR_FUNCTION(udf_make_timestamp, prefix + "make_timestamp"); - - registerFunction( - {prefix + "unix_micros"}); - registerUnaryIntegralWithTReturn( - {prefix + "timestamp_micros"}); - registerFunction( - {prefix + "unix_millis"}); - registerUnaryIntegralWithTReturn( - {prefix + "timestamp_millis"}); - - // Register bloom filter function - registerFunction( - {prefix + "might_contain"}); - - registerArrayMinMaxFunctions(prefix); - - // Register decimal vector functions. - exec::registerVectorFunction( - prefix + "unscaled_value", - unscaledValueSignatures(), - makeUnscaledValue()); - - registerFunction( - {prefix + "spark_partition_id"}); - - registerFunction( - {prefix + "monotonically_increasing_id"}); - - registerFunction>({prefix + "uuid"}); - - registerFunction< - ArrayFlattenFunction, - Array>, - Array>>>({prefix + "flatten"}); - - registerFunction( - {prefix + "repeat"}); - - registerFunction({prefix + "soundex"}); - - registerFunction( - {prefix + "raise_error"}); - - registerFunction< - LevenshteinDistanceFunction, - int32_t, - Varchar, - Varchar, - int32_t>({prefix + "levenshtein"}); - registerFunction( - {prefix + "levenshtein"}); - - registerFunction< - ArrayInsert, - Array>, - Array>, - int32_t, - Generic, - bool>({prefix + "array_insert"}); - - registerFunction, Varchar, Varchar>({prefix + "split"}); - registerFunction, Varchar, Varchar, int32_t>( - {prefix + "split"}); - - registerFunction({prefix + "mask"}); - registerFunction({prefix + "mask"}); - registerFunction( - {prefix + "mask"}); - registerFunction( - {prefix + "mask"}); - registerFunction< - MaskFunction, - Varchar, - Varchar, - Varchar, - Varchar, - Varchar, - Varchar>({prefix + "mask"}); -} - -std::vector listFunctionNames() { - std::vector names = - exec::specialFormRegistry().getSpecialFormNames(); - - const auto& simpleFunctions = exec::simpleFunctions().getFunctionNames(); - names.insert(names.end(), simpleFunctions.begin(), simpleFunctions.end()); - - exec::vectorFunctionFactories().withRLock([&](const auto& map) { - names.reserve(names.size() + map.size()); - for (const auto& [name, _] : map) { - names.push_back(name); - } - }); - - return names; -} - -} // namespace sparksql -} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp index 4a06bb10ac8e..95dc45207fd5 100644 --- a/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp @@ -15,8 +15,8 @@ */ #include "velox/functions/lib/aggregates/tests/utils/AggregationTestBase.h" -#include "velox/functions/sparksql/Register.h" #include "velox/functions/sparksql/aggregates/Register.h" +#include "velox/functions/sparksql/registration/Register.h" using namespace facebook::velox::functions::aggregate::test; diff --git a/velox/functions/sparksql/benchmarks/CompareBenchmark.cpp b/velox/functions/sparksql/benchmarks/CompareBenchmark.cpp index b0b7efdfc3cb..a7207d193fb1 100644 --- a/velox/functions/sparksql/benchmarks/CompareBenchmark.cpp +++ b/velox/functions/sparksql/benchmarks/CompareBenchmark.cpp @@ -18,7 +18,7 @@ #include #include "velox/benchmarks/ExpressionBenchmarkBuilder.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" using namespace facebook; diff --git a/velox/functions/sparksql/benchmarks/HashBenchmark.cpp b/velox/functions/sparksql/benchmarks/HashBenchmark.cpp index b197cadf5965..74b32e1f0d19 100644 --- a/velox/functions/sparksql/benchmarks/HashBenchmark.cpp +++ b/velox/functions/sparksql/benchmarks/HashBenchmark.cpp @@ -18,7 +18,7 @@ #include #include "velox/benchmarks/ExpressionBenchmarkBuilder.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" using namespace facebook; diff --git a/velox/functions/sparksql/benchmarks/SIMDCompareBenchmark.cpp b/velox/functions/sparksql/benchmarks/SIMDCompareBenchmark.cpp index 8dda0d515596..a9c02bf0ff01 100644 --- a/velox/functions/sparksql/benchmarks/SIMDCompareBenchmark.cpp +++ b/velox/functions/sparksql/benchmarks/SIMDCompareBenchmark.cpp @@ -18,7 +18,7 @@ #include #include "velox/benchmarks/ExpressionBenchmarkBuilder.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/vector/fuzzer/VectorFuzzer.h" using namespace facebook; diff --git a/velox/functions/sparksql/coverage/Coverage.cpp b/velox/functions/sparksql/coverage/Coverage.cpp index 91ba43e21396..5a01b9b7d0b6 100644 --- a/velox/functions/sparksql/coverage/Coverage.cpp +++ b/velox/functions/sparksql/coverage/Coverage.cpp @@ -17,8 +17,8 @@ #include #include "velox/exec/Aggregate.h" #include "velox/functions/CoverageUtil.h" -#include "velox/functions/sparksql/Register.h" #include "velox/functions/sparksql/aggregates/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/functions/sparksql/window/WindowFunctionsRegistration.h" DEFINE_bool(all, false, "Generate coverage map for all Spark functions"); diff --git a/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp b/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp index 5feaa79beebf..7c53376e0ad8 100644 --- a/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp +++ b/velox/functions/sparksql/fuzzer/tests/SparkQueryRunnerTest.cpp @@ -21,9 +21,9 @@ #include "velox/dwio/parquet/RegisterParquetWriter.h" #include "velox/exec/tests/utils/AssertQueryBuilder.h" #include "velox/exec/tests/utils/PlanBuilder.h" -#include "velox/functions/sparksql/Register.h" #include "velox/functions/sparksql/aggregates/Register.h" #include "velox/functions/sparksql/fuzzer/SparkQueryRunner.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/parse/TypeResolver.h" #include "velox/vector/tests/utils/VectorTestBase.h" diff --git a/velox/functions/sparksql/registration/CMakeLists.txt b/velox/functions/sparksql/registration/CMakeLists.txt new file mode 100644 index 000000000000..c7ca8ed7c18f --- /dev/null +++ b/velox/functions/sparksql/registration/CMakeLists.txt @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +velox_add_library( + velox_functions_spark + Register.cpp + RegisterArray.cpp + RegisterBinary.cpp + RegisterBitwise.cpp + RegisterComparison.cpp + RegisterDatetime.cpp + RegisterJson.cpp + RegisterMap.cpp + RegisterMath.cpp + RegisterMisc.cpp + RegisterRegexp.cpp + RegisterSpecialForm.cpp + RegisterString.cpp + RegisterUrl.cpp) + +# GCC 12 has a bug where it does not respect "pragma ignore" directives and ends +# up failing compilation in an openssl header included by a hash-related +# function. +if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND NOT VELOX_MONO_LIBRARY) + target_compile_options(velox_functions_spark + PRIVATE -Wno-deprecated-declarations) +endif() + +velox_link_libraries(velox_functions_spark velox_functions_spark_impl + velox_is_null_functions simdjson::simdjson) + +if(NOT VELOX_MONO_LIBRARY) + set_property(TARGET velox_functions_spark PROPERTY JOB_POOL_COMPILE + high_memory_pool) +endif() diff --git a/velox/functions/sparksql/registration/Register.cpp b/velox/functions/sparksql/registration/Register.cpp new file mode 100644 index 000000000000..aacbfacc4828 --- /dev/null +++ b/velox/functions/sparksql/registration/Register.cpp @@ -0,0 +1,69 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/sparksql/registration/Register.h" +#include "velox/expression/SimpleFunctionRegistry.h" +#include "velox/expression/SpecialFormRegistry.h" + +namespace facebook::velox::functions::sparksql { + +extern void registerArrayFunctions(const std::string& prefix); +extern void registerBinaryFunctions(const std::string& prefix); +extern void registerBitwiseFunctions(const std::string& prefix); +extern void registerCompareFunctions(const std::string& prefix); +extern void registerDatetimeFunctions(const std::string& prefix); +extern void registerJsonFunctions(const std::string& prefix); +extern void registerMapFunctions(const std::string& prefix); +extern void registerMathFunctions(const std::string& prefix); +extern void registerMiscFunctions(const std::string& prefix); +extern void registerRegexpFunctions(const std::string& prefix); +extern void registerSpecialFormGeneralFunctions(const std::string& prefix); +extern void registerStringFunctions(const std::string& prefix); +extern void registerUrlFunctions(const std::string& prefix); + +void registerFunctions(const std::string& prefix) { + registerArrayFunctions(prefix); + registerBinaryFunctions(prefix); + registerBitwiseFunctions(prefix); + registerCompareFunctions(prefix); + registerDatetimeFunctions(prefix); + registerJsonFunctions(prefix); + registerMapFunctions(prefix); + registerMathFunctions(prefix); + registerMiscFunctions(prefix); + registerRegexpFunctions(prefix); + registerSpecialFormGeneralFunctions(prefix); + registerStringFunctions(prefix); + registerUrlFunctions(prefix); +} + +std::vector listFunctionNames() { + std::vector names = + exec::specialFormRegistry().getSpecialFormNames(); + + const auto& simpleFunctions = exec::simpleFunctions().getFunctionNames(); + names.insert(names.end(), simpleFunctions.begin(), simpleFunctions.end()); + + exec::vectorFunctionFactories().withRLock([&](const auto& map) { + names.reserve(names.size() + map.size()); + for (const auto& [name, _] : map) { + names.push_back(name); + } + }); + + return names; +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/Register.h b/velox/functions/sparksql/registration/Register.h similarity index 94% rename from velox/functions/sparksql/Register.h rename to velox/functions/sparksql/registration/Register.h index 1333a25bf452..2f3da2d4624a 100644 --- a/velox/functions/sparksql/Register.h +++ b/velox/functions/sparksql/registration/Register.h @@ -20,7 +20,7 @@ namespace facebook::velox::functions::sparksql { -void registerFunctions(const std::string& prefix); +void registerFunctions(const std::string& prefix = ""); /// Return all the registered scalar function names include simple functions, /// vector functions and special forms. diff --git a/velox/functions/sparksql/registration/RegisterArray.cpp b/velox/functions/sparksql/registration/RegisterArray.cpp new file mode 100644 index 000000000000..38cc02c452ee --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterArray.cpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/ArrayShuffle.h" +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/lib/Repeat.h" +#include "velox/functions/lib/Slice.h" +#include "velox/functions/prestosql/ArrayFunctions.h" +#include "velox/functions/sparksql/ArrayFlattenFunction.h" +#include "velox/functions/sparksql/ArrayInsert.h" +#include "velox/functions/sparksql/ArrayMinMaxFunction.h" +#include "velox/functions/sparksql/ArraySort.h" + +namespace facebook::velox::functions { + +// VELOX_REGISTER_VECTOR_FUNCTION must be invoked in the same namespace as the +// vector function definition. +// Higher order functions. +void registerSparkArrayFunctions(const std::string& prefix) { + VELOX_REGISTER_VECTOR_FUNCTION(udf_transform, prefix + "transform"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_reduce, prefix + "aggregate"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_constructor, prefix + "array"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_contains, prefix + "array_contains"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_distinct, prefix + "array_distinct"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_except, prefix + "array_except"); + VELOX_REGISTER_VECTOR_FUNCTION( + udf_array_intersect, prefix + "array_intersect"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_position, prefix + "array_position"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_zip, prefix + "arrays_zip"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_any_match, prefix + "exists"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_filter, prefix + "filter"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_all_match, prefix + "forall"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_zip_with, prefix + "zip_with"); +} + +namespace sparksql { + +template +inline void registerArrayMinMaxFunctions(const std::string& prefix) { + registerFunction>({prefix + "array_min"}); + registerFunction>({prefix + "array_max"}); +} + +inline void registerArrayMinMaxFunctions(const std::string& prefix) { + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); + registerArrayMinMaxFunctions(prefix); +} + +template +inline void registerArrayRemoveFunctions(const std::string& prefix) { + registerFunction, Array, T>( + {prefix + "array_remove"}); +} + +inline void registerArrayRemoveFunctions(const std::string& prefix) { + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerArrayRemoveFunctions>(prefix); + registerFunction< + ArrayRemoveFunctionString, + Array, + Array, + Varchar>({prefix + "array_remove"}); +} + +void registerArrayFunctions(const std::string& prefix) { + registerArrayMinMaxFunctions(prefix); + registerArrayRemoveFunctions(prefix); + registerSparkArrayFunctions(prefix); + // Register array sort functions. + exec::registerStatefulVectorFunction( + prefix + "array_sort", arraySortSignatures(), makeArraySort); + exec::registerStatefulVectorFunction( + prefix + "sort_array", sortArraySignatures(), makeSortArray); + exec::registerStatefulVectorFunction( + prefix + "array_repeat", + repeatSignatures(), + makeRepeatAllowNegativeCount, + repeatMetadata()); + registerFunction< + ArrayFlattenFunction, + Array>, + Array>>>({prefix + "flatten"}); + registerFunction< + ArrayInsert, + Array>, + Array>, + int32_t, + Generic, + bool>({prefix + "array_insert"}); + VELOX_REGISTER_VECTOR_FUNCTION(udf_array_get, prefix + "get"); + exec::registerStatefulVectorFunction( + prefix + "shuffle", + arrayShuffleWithCustomSeedSignatures(), + makeArrayShuffleWithCustomSeed, + getMetadataForArrayShuffle()); + registerIntegerSliceFunction(prefix); +} + +} // namespace sparksql +} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/registration/RegisterBinary.cpp b/velox/functions/sparksql/registration/RegisterBinary.cpp new file mode 100644 index 000000000000..92b5bbaf8c77 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterBinary.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/prestosql/BinaryFunctions.h" +#include "velox/functions/sparksql/Hash.h" +#include "velox/functions/sparksql/MightContain.h" +#include "velox/functions/sparksql/String.h" + +namespace facebook::velox::functions::sparksql { + +void registerBinaryFunctions(const std::string& prefix) { + registerFunction({prefix + "crc32"}); + exec::registerStatefulVectorFunction( + prefix + "hash", hashSignatures(), makeHash, hashMetadata()); + exec::registerStatefulVectorFunction( + prefix + "hash_with_seed", + hashWithSeedSignatures(), + makeHashWithSeed, + hashMetadata()); + exec::registerStatefulVectorFunction( + prefix + "xxhash64", xxhash64Signatures(), makeXxHash64, hashMetadata()); + exec::registerStatefulVectorFunction( + prefix + "xxhash64_with_seed", + xxhash64WithSeedSignatures(), + makeXxHash64WithSeed, + hashMetadata()); + registerFunction({prefix + "md5"}); + registerFunction( + {prefix + "might_contain"}); + registerFunction( + {prefix + "sha1"}); + registerFunction( + {prefix + "sha2"}); +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterBitwise.cpp b/velox/functions/sparksql/registration/RegisterBitwise.cpp new file mode 100644 index 000000000000..cdefc5b16853 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterBitwise.cpp @@ -0,0 +1,54 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/Bitwise.h" + +namespace facebook::velox::functions::sparksql { + +void registerBitwiseFunctions(const std::string& prefix) { + registerBinaryIntegral({prefix + "bitwise_and"}); + registerBinaryIntegral({prefix + "bitwise_or"}); + registerBinaryIntegral({prefix + "bitwise_xor"}); + + registerUnaryIntegral({prefix + "bitwise_not"}); + + registerFunction({prefix + "bit_count"}); + registerFunction({prefix + "bit_count"}); + registerFunction({prefix + "bit_count"}); + registerFunction({prefix + "bit_count"}); + registerFunction({prefix + "bit_count"}); + + registerFunction( + {prefix + "bit_get"}); + registerFunction( + {prefix + "bit_get"}); + registerFunction( + {prefix + "bit_get"}); + registerFunction( + {prefix + "bit_get"}); + + registerFunction( + {prefix + "shiftleft"}); + registerFunction( + {prefix + "shiftleft"}); + + registerFunction( + {prefix + "shiftright"}); + registerFunction( + {prefix + "shiftright"}); +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterCompare.cpp b/velox/functions/sparksql/registration/RegisterComparison.cpp similarity index 77% rename from velox/functions/sparksql/RegisterCompare.cpp rename to velox/functions/sparksql/registration/RegisterComparison.cpp index f9d3f7e2af13..fdb3223c1139 100644 --- a/velox/functions/sparksql/RegisterCompare.cpp +++ b/velox/functions/sparksql/registration/RegisterComparison.cpp @@ -13,14 +13,21 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/functions/sparksql/RegisterCompare.h" - +#include "velox/functions/lib/IsNull.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/sparksql/Comparisons.h" +#include "velox/functions/sparksql/LeastGreatest.h" -namespace facebook::velox::functions::sparksql { +namespace facebook::velox::functions { +void registerSparkCompareFunctions(const std::string& prefix) { + VELOX_REGISTER_VECTOR_FUNCTION(udf_not, prefix + "not"); + registerIsNullFunction(prefix + "isnull"); + registerIsNotNullFunction(prefix + "isnotnull"); +} +namespace sparksql { void registerCompareFunctions(const std::string& prefix) { + registerSparkCompareFunctions(prefix); exec::registerStatefulVectorFunction( prefix + "equalto", comparisonSignatures(), makeEqualTo); registerFunction, Generic>( @@ -31,6 +38,11 @@ void registerCompareFunctions(const std::string& prefix) { prefix + "greaterthan", comparisonSignatures(), makeGreaterThan); exec::registerStatefulVectorFunction( prefix + "lessthanorequal", comparisonSignatures(), makeLessThanOrEqual); + exec::registerStatefulVectorFunction( + prefix + "least", + leastSignatures(), + makeLeast, + exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); exec::registerStatefulVectorFunction( prefix + "greaterthanorequal", comparisonSignatures(), @@ -43,6 +55,11 @@ void registerCompareFunctions(const std::string& prefix) { exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); registerFunction, Generic>( {prefix + "equalnullsafe"}); + exec::registerStatefulVectorFunction( + prefix + "greatest", + greatestSignatures(), + makeGreatest, + exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build()); registerFunction( {prefix + "between"}); registerFunction( @@ -69,5 +86,5 @@ void registerCompareFunctions(const std::string& prefix) { VELOX_REGISTER_VECTOR_FUNCTION( udf_decimal_neq, prefix + "decimal_notequalto"); } - -} // namespace facebook::velox::functions::sparksql +} // namespace sparksql +} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/registration/RegisterDatetime.cpp b/velox/functions/sparksql/registration/RegisterDatetime.cpp new file mode 100644 index 000000000000..334a021457b0 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterDatetime.cpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/prestosql/DateTimeFunctions.h" +#include "velox/functions/sparksql/DateTimeFunctions.h" + +namespace facebook::velox::functions::sparksql { + +void registerDatetimeFunctions(const std::string& prefix) { + registerFunction({prefix + "year"}); + registerFunction({prefix + "year"}); + registerFunction({prefix + "week_of_year"}); + registerFunction( + {prefix + "year_of_week"}); + registerFunction( + {prefix + "to_utc_timestamp"}); + registerFunction( + {prefix + "from_utc_timestamp"}); + registerFunction({prefix + "unix_date"}); + registerFunction( + {prefix + "unix_seconds"}); + registerFunction({prefix + "unix_timestamp"}); + registerFunction( + {prefix + "unix_timestamp", prefix + "to_unix_timestamp"}); + registerFunction< + UnixTimestampParseWithFormatFunction, + int64_t, + Varchar, + Varchar>({prefix + "unix_timestamp", prefix + "to_unix_timestamp"}); + registerFunction( + {prefix + "from_unixtime"}); + registerFunction( + {prefix + "make_date"}); + registerFunction( + {prefix + "datediff"}); + registerFunction({prefix + "last_day"}); + registerFunction( + {prefix + "add_months"}); + registerFunction({prefix + "date_add"}); + registerFunction({prefix + "date_add"}); + registerFunction({prefix + "date_add"}); + registerFunction( + {prefix + "date_from_unix_date"}); + registerFunction({prefix + "date_sub"}); + registerFunction({prefix + "date_sub"}); + registerFunction({prefix + "date_sub"}); + registerFunction( + {prefix + "day", prefix + "dayofmonth"}); + registerFunction( + {prefix + "doy", prefix + "dayofyear"}); + registerFunction({prefix + "dayofweek"}); + registerFunction({prefix + "weekday"}); + registerFunction({prefix + "quarter"}); + registerFunction({prefix + "month"}); + registerFunction({prefix + "next_day"}); + registerFunction( + {prefix + "get_timestamp"}); + registerFunction({prefix + "hour"}); + registerFunction({prefix + "minute"}); + registerFunction({prefix + "second"}); + registerFunction( + {prefix + "make_ym_interval"}); + registerFunction( + {prefix + "make_ym_interval"}); + registerFunction( + {prefix + "make_ym_interval"}); + VELOX_REGISTER_VECTOR_FUNCTION(udf_make_timestamp, prefix + "make_timestamp"); + registerFunction( + {prefix + "unix_micros"}); + registerUnaryIntegralWithTReturn( + {prefix + "timestamp_micros"}); + registerFunction( + {prefix + "unix_millis"}); + registerUnaryIntegralWithTReturn( + {prefix + "timestamp_millis"}); +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/RegisterCompare.h b/velox/functions/sparksql/registration/RegisterJson.cpp similarity index 72% rename from velox/functions/sparksql/RegisterCompare.h rename to velox/functions/sparksql/registration/RegisterJson.cpp index 72efff91188b..e98052563f8e 100644 --- a/velox/functions/sparksql/RegisterCompare.h +++ b/velox/functions/sparksql/registration/RegisterJson.cpp @@ -13,10 +13,14 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once - -#include +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/JsonObjectKeys.h" namespace facebook::velox::functions::sparksql { -void registerCompareFunctions(const std::string& prefix); + +void registerJsonFunctions(const std::string& prefix) { + registerFunction, Varchar>( + {prefix + "json_object_keys"}); +} + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterMap.cpp b/velox/functions/sparksql/registration/RegisterMap.cpp new file mode 100644 index 000000000000..663d7de459a4 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterMap.cpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/Size.h" + +namespace facebook::velox::functions { +extern void registerElementAtFunction( + const std::string& name, + bool enableCaching); + +void registerSparkMapFunctions(const std::string& prefix) { + VELOX_REGISTER_VECTOR_FUNCTION( + udf_map_allow_duplicates, prefix + "map_from_arrays"); + // Spark and Presto map_filter function has the same definition: + // function expression corresponds to body, arguments to signature + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_filter, prefix + "map_filter"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_entries, prefix + "map_entries"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with"); +} + +namespace sparksql { +void registerMapFunctions(const std::string& prefix) { + registerSparkMapFunctions(prefix); + VELOX_REGISTER_VECTOR_FUNCTION(udf_map, prefix + "map"); + // This is the semantics of spark.sql.ansi.enabled = false. + registerElementAtFunction(prefix + "element_at", true); + registerSize(prefix + "size"); +} +} // namespace sparksql +} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/RegisterArithmetic.cpp b/velox/functions/sparksql/registration/RegisterMath.cpp similarity index 96% rename from velox/functions/sparksql/RegisterArithmetic.cpp rename to velox/functions/sparksql/registration/RegisterMath.cpp index 684b4e110b11..0532c5a82055 100644 --- a/velox/functions/sparksql/RegisterArithmetic.cpp +++ b/velox/functions/sparksql/registration/RegisterMath.cpp @@ -13,7 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include "velox/functions/sparksql/RegisterArithmetic.h" #include "velox/functions/lib/RegistrationHelpers.h" #include "velox/functions/prestosql/Arithmetic.h" #include "velox/functions/sparksql/Arithmetic.h" @@ -22,31 +21,7 @@ namespace facebook::velox::functions::sparksql { -void registerRandFunctions(const std::string& prefix) { - registerFunction({prefix + "rand", prefix + "random"}); - registerFunction>( - {prefix + "rand", prefix + "random"}); - registerFunction>( - {prefix + "rand", prefix + "random"}); -} - -void registerArithmeticFunctions(const std::string& prefix) { - // Operators. - registerBinaryNumeric({prefix + "add"}); - registerBinaryNumeric({prefix + "subtract"}); - registerBinaryNumeric({prefix + "multiply"}); - registerFunction({prefix + "divide"}); - registerBinaryNumeric({prefix + "remainder"}); - registerUnaryNumeric({prefix + "unaryminus"}); - registerFunction< - UnaryMinusFunction, - LongDecimal, - LongDecimal>({prefix + "unaryminus"}); - registerFunction< - UnaryMinusFunction, - ShortDecimal, - ShortDecimal>({prefix + "unaryminus"}); - // Math functions. +void registerMathFunctions(const std::string& prefix) { registerUnaryNumeric({prefix + "abs"}); registerFunction< DecimalAbsFunction, @@ -112,7 +87,27 @@ void registerArithmeticFunctions(const std::string& prefix) { double, double, int64_t>({prefix + "width_bucket"}); - registerRandFunctions(prefix); + registerFunction({prefix + "rand", prefix + "random"}); + registerFunction>( + {prefix + "rand", prefix + "random"}); + registerFunction>( + {prefix + "rand", prefix + "random"}); + + // Operators. + registerBinaryNumeric({prefix + "add"}); + registerBinaryNumeric({prefix + "subtract"}); + registerBinaryNumeric({prefix + "multiply"}); + registerFunction({prefix + "divide"}); + registerBinaryNumeric({prefix + "remainder"}); + registerUnaryNumeric({prefix + "unaryminus"}); + registerFunction< + UnaryMinusFunction, + LongDecimal, + LongDecimal>({prefix + "unaryminus"}); + registerFunction< + UnaryMinusFunction, + ShortDecimal, + ShortDecimal>({prefix + "unaryminus"}); registerDecimalAdd(prefix); registerDecimalSubtract(prefix); diff --git a/velox/functions/sparksql/registration/RegisterMisc.cpp b/velox/functions/sparksql/registration/RegisterMisc.cpp new file mode 100644 index 000000000000..9c08846fc1e8 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterMisc.cpp @@ -0,0 +1,40 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/SpecialFormRegistry.h" +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/In.h" +#include "velox/functions/sparksql/MonotonicallyIncreasingId.h" +#include "velox/functions/sparksql/RaiseError.h" +#include "velox/functions/sparksql/SparkPartitionId.h" +#include "velox/functions/sparksql/UnscaledValueFunction.h" +#include "velox/functions/sparksql/Uuid.h" + +namespace facebook::velox::functions::sparksql { +void registerMiscFunctions(const std::string& prefix) { + registerFunction( + {prefix + "monotonically_increasing_id"}); + registerFunction( + {prefix + "raise_error"}); + registerFunction( + {prefix + "spark_partition_id"}); + registerIn(prefix); + exec::registerVectorFunction( + prefix + "unscaled_value", + unscaledValueSignatures(), + makeUnscaledValue()); + registerFunction>({prefix + "uuid"}); +} +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterRegexp.cpp b/velox/functions/sparksql/registration/RegisterRegexp.cpp new file mode 100644 index 000000000000..debc2b7b31f0 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterRegexp.cpp @@ -0,0 +1,36 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/Re2Functions.h" +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/sparksql/RegexFunctions.h" + +namespace facebook::velox::functions::sparksql { + +void registerRegexpFunctions(const std::string& prefix) { + exec::registerStatefulVectorFunction( + prefix + "regexp_extract", re2ExtractSignatures(), makeRegexExtract); + exec::registerStatefulVectorFunction( + prefix + "regexp_extract_all", + re2ExtractAllSignatures(), + makeRe2ExtractAll); + exec::registerStatefulVectorFunction( + prefix + "rlike", re2SearchSignatures(), makeRLike); + exec::registerStatefulVectorFunction( + prefix + "like", likeSignatures(), makeLike); + registerRegexpReplace(prefix); +} + +} // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/registration/RegisterSpecialForm.cpp b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp new file mode 100644 index 000000000000..d9f12abe4f80 --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterSpecialForm.cpp @@ -0,0 +1,49 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/expression/RegisterSpecialForm.h" +#include "velox/expression/RowConstructor.h" +#include "velox/expression/SpecialFormRegistry.h" +#include "velox/functions/sparksql/specialforms/AtLeastNNonNulls.h" +#include "velox/functions/sparksql/specialforms/DecimalRound.h" +#include "velox/functions/sparksql/specialforms/MakeDecimal.h" +#include "velox/functions/sparksql/specialforms/SparkCastExpr.h" + +namespace facebook::velox::functions { +void registerSparkSpecialFormFunctions() { + VELOX_REGISTER_VECTOR_FUNCTION( + udf_concat_row, exec::RowConstructorCallToSpecialForm::kRowConstructor); +} + +namespace sparksql { +void registerSpecialFormGeneralFunctions(const std::string& prefix) { + exec::registerFunctionCallToSpecialForms(); + exec::registerFunctionCallToSpecialForm( + MakeDecimalCallToSpecialForm::kMakeDecimal, + std::make_unique()); + exec::registerFunctionCallToSpecialForm( + DecimalRoundCallToSpecialForm::kRoundDecimal, + std::make_unique()); + exec::registerFunctionCallToSpecialForm( + AtLeastNNonNullsCallToSpecialForm::kAtLeastNNonNulls, + std::make_unique()); + registerSparkSpecialFormFunctions(); + registerFunctionCallToSpecialForm( + "cast", std::make_unique()); + registerFunctionCallToSpecialForm( + "try_cast", std::make_unique()); +} +} // namespace sparksql +} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/registration/RegisterString.cpp b/velox/functions/sparksql/registration/RegisterString.cpp new file mode 100644 index 000000000000..38a01bc4941d --- /dev/null +++ b/velox/functions/sparksql/registration/RegisterString.cpp @@ -0,0 +1,147 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/lib/Re2Functions.h" +#include "velox/functions/prestosql/StringFunctions.h" +#include "velox/functions/prestosql/URLFunctions.h" +#include "velox/functions/sparksql/MaskFunction.h" +#include "velox/functions/sparksql/Split.h" +#include "velox/functions/sparksql/String.h" +#include "velox/functions/sparksql/StringToMap.h" + +namespace facebook::velox::functions { +void registerSparkStringFunctions(const std::string& prefix) { + VELOX_REGISTER_VECTOR_FUNCTION(udf_concat, prefix + "concat"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_lower, prefix + "lower"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_upper, prefix + "upper"); + VELOX_REGISTER_VECTOR_FUNCTION(udf_reverse, prefix + "reverse"); +} + +namespace sparksql { +void registerStringFunctions(const std::string& prefix) { + registerSparkStringFunctions(prefix); + registerFunction( + {prefix + "startswith"}); + registerFunction( + {prefix + "endswith"}); + registerFunction( + {prefix + "contains"}); + registerFunction( + {prefix + "locate"}); + registerFunction({prefix + "trim"}); + registerFunction({prefix + "trim"}); + registerFunction({prefix + "ltrim"}); + registerFunction( + {prefix + "ltrim"}); + registerFunction({prefix + "rtrim"}); + registerFunction( + {prefix + "rtrim"}); + registerFunction( + {prefix + "translate"}); + registerFunction( + {prefix + "conv"}); + registerFunction( + {prefix + "replace"}); + registerFunction( + {prefix + "replace"}); + registerFunction( + {prefix + "find_in_set"}); + registerFunction( + {prefix + "url_encode"}); + registerFunction( + {prefix + "url_decode"}); + registerFunction({prefix + "chr"}); + registerFunction({prefix + "ascii"}); + registerFunction( + {prefix + "lpad"}); + registerFunction( + {prefix + "rpad"}); + registerFunction( + {prefix + "lpad"}); + registerFunction( + {prefix + "rpad"}); + registerFunction( + {prefix + "substring"}); + registerFunction< + sparksql::SubstrFunction, + Varchar, + Varchar, + int32_t, + int32_t>({prefix + "substring"}); + registerFunction< + sparksql::OverlayVarcharFunction, + Varchar, + Varchar, + Varchar, + int32_t, + int32_t>({prefix + "overlay"}); + registerFunction< + sparksql::OverlayVarbinaryFunction, + Varbinary, + Varbinary, + Varbinary, + int32_t, + int32_t>({prefix + "overlay"}); + registerFunction< + sparksql::StringToMapFunction, + Map, + Varchar, + Varchar, + Varchar>({prefix + "str_to_map"}); + registerFunction( + {prefix + "left"}); + registerFunction( + {prefix + "bit_length"}); + registerFunction( + {prefix + "bit_length"}); + exec::registerStatefulVectorFunction( + prefix + "instr", instrSignatures(), makeInstr); + exec::registerStatefulVectorFunction( + prefix + "length", lengthSignatures(), makeLength); + registerFunction( + {prefix + "substring_index"}); + registerFunction( + {prefix + "empty2null"}); + registerFunction< + LevenshteinDistanceFunction, + int32_t, + Varchar, + Varchar, + int32_t>({prefix + "levenshtein"}); + registerFunction( + {prefix + "levenshtein"}); + registerFunction( + {prefix + "repeat"}); + registerFunction({prefix + "soundex"}); + registerFunction, Varchar, Varchar>({prefix + "split"}); + registerFunction, Varchar, Varchar, int32_t>( + {prefix + "split"}); + registerFunction({prefix + "mask"}); + registerFunction({prefix + "mask"}); + registerFunction( + {prefix + "mask"}); + registerFunction( + {prefix + "mask"}); + registerFunction< + MaskFunction, + Varchar, + Varchar, + Varchar, + Varchar, + Varchar, + Varchar>({prefix + "mask"}); +} +} // namespace sparksql +} // namespace facebook::velox::functions diff --git a/velox/functions/sparksql/RegisterArithmetic.h b/velox/functions/sparksql/registration/RegisterUrl.cpp similarity index 67% rename from velox/functions/sparksql/RegisterArithmetic.h rename to velox/functions/sparksql/registration/RegisterUrl.cpp index fcd6bca397f3..cce789c134ee 100644 --- a/velox/functions/sparksql/RegisterArithmetic.h +++ b/velox/functions/sparksql/registration/RegisterUrl.cpp @@ -13,10 +13,16 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#pragma once - -#include +#include "velox/functions/lib/RegistrationHelpers.h" +#include "velox/functions/prestosql/URLFunctions.h" namespace facebook::velox::functions::sparksql { -void registerArithmeticFunctions(const std::string& prefix); + +void registerUrlFunctions(const std::string& prefix) { + registerFunction( + {prefix + "url_encode"}); + registerFunction( + {prefix + "url_decode"}); +} + } // namespace facebook::velox::functions::sparksql diff --git a/velox/functions/sparksql/tests/RegisterTest.cpp b/velox/functions/sparksql/tests/RegisterTest.cpp index fa046cb9cbfc..436a2e9c5104 100644 --- a/velox/functions/sparksql/tests/RegisterTest.cpp +++ b/velox/functions/sparksql/tests/RegisterTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include diff --git a/velox/functions/sparksql/tests/SliceTest.cpp b/velox/functions/sparksql/tests/SliceTest.cpp index 1d4feff3c636..cfd480bb72dd 100644 --- a/velox/functions/sparksql/tests/SliceTest.cpp +++ b/velox/functions/sparksql/tests/SliceTest.cpp @@ -14,7 +14,7 @@ * limitations under the License. */ #include "velox/functions/lib/tests/SliceTestBase.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" namespace facebook::velox::functions::sparksql::test { diff --git a/velox/functions/sparksql/tests/SortArrayTest.cpp b/velox/functions/sparksql/tests/SortArrayTest.cpp index 83a70dc72e7a..fa0449a37f9e 100644 --- a/velox/functions/sparksql/tests/SortArrayTest.cpp +++ b/velox/functions/sparksql/tests/SortArrayTest.cpp @@ -19,7 +19,7 @@ #include #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/functions/sparksql/tests/ArraySortTestData.h" #include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h" #include "velox/vector/ComplexVector.h" diff --git a/velox/functions/sparksql/tests/SparkCastExprTest.cpp b/velox/functions/sparksql/tests/SparkCastExprTest.cpp index 800710d053b8..4e820a2af15a 100644 --- a/velox/functions/sparksql/tests/SparkCastExprTest.cpp +++ b/velox/functions/sparksql/tests/SparkCastExprTest.cpp @@ -15,7 +15,7 @@ */ #include "velox/functions/prestosql/tests/CastBaseTest.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/parse/TypeResolver.h" using namespace facebook::velox; diff --git a/velox/functions/sparksql/tests/SparkFunctionBaseTest.h b/velox/functions/sparksql/tests/SparkFunctionBaseTest.h index 1b527708807b..b141be0d86fc 100644 --- a/velox/functions/sparksql/tests/SparkFunctionBaseTest.h +++ b/velox/functions/sparksql/tests/SparkFunctionBaseTest.h @@ -16,7 +16,7 @@ #pragma once #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" -#include "velox/functions/sparksql/Register.h" +#include "velox/functions/sparksql/registration/Register.h" #include "velox/parse/TypeResolver.h" namespace facebook::velox::functions::sparksql::test {