diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
deleted file mode 100644
index f242b2770..000000000
--- a/src/duckdb/extension/core_functions/aggregate/algebraic/avg.cpp
+++ /dev/null
@@ -1,196 +0,0 @@
-#include "core_functions/aggregate/algebraic_functions.hpp"
-#include "core_functions/aggregate/sum_helpers.hpp"
-#include "duckdb/common/types/hugeint.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/function/function_set.hpp"
-#include "duckdb/planner/expression.hpp"
-
-namespace duckdb {
-
-template <class T>
-struct AvgState {
-    uint64_t count;
-    T value;
-
-    void Initialize() {
-        this->count = 0;
-    }
-
-    void Combine(const AvgState<T> &other) {
-        this->count += other.count;
-        this->value += other.value;
-    }
-};
-
-struct KahanAvgState {
-    uint64_t count;
-    double value;
-    double err;
-
-    void Initialize() {
-        this->count = 0;
-        this->err = 0.0;
-    }
-
-    void Combine(const KahanAvgState &other) {
-        this->count += other.count;
-        KahanAddInternal(other.value, this->value, this->err);
-        KahanAddInternal(other.err, this->value, this->err);
-    }
-};
-
-struct AverageDecimalBindData : public FunctionData {
-    explicit AverageDecimalBindData(double scale) : scale(scale) {
-    }
-
-    double scale;
-
-public:
-    unique_ptr<FunctionData> Copy() const override {
-        return make_uniq<AverageDecimalBindData>(scale);
-    };
-
-    bool Equals(const FunctionData &other_p) const override {
-        auto &other = other_p.Cast<AverageDecimalBindData>();
-        return scale == other.scale;
-    }
-};
-
-struct AverageSetOperation {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        state.Initialize();
-    }
-    template <class STATE>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        target.Combine(source);
-    }
-    template <class STATE>
-    static void AddValues(STATE &state, idx_t count) {
-        state.count += count;
-    }
-};
-
-template <class T>
-static T GetAverageDivident(uint64_t count, optional_ptr<FunctionData> bind_data) {
-    T divident = T(count);
-    if (bind_data) {
-        auto &avg_bind_data = bind_data->Cast<AverageDecimalBindData>();
-        divident *= avg_bind_data.scale;
-    }
-    return divident;
-}
-
-struct IntegerAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.count == 0) {
-            finalize_data.ReturnNull();
-        } else {
-            double divident = GetAverageDivident<double>(state.count, finalize_data.input.bind_data);
-            target = double(state.value) / divident;
-        }
-    }
-};
-
-struct IntegerAverageOperationHugeint : public BaseSumOperation<AverageSetOperation, AddToHugeint> {
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.count == 0) {
-            finalize_data.ReturnNull();
-        } else {
-            long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
-            target = Hugeint::Cast<long double>(state.value) / divident;
-        }
-    }
-};
-
-struct HugeintAverageOperation : public BaseSumOperation<AverageSetOperation, HugeintAdd> {
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.count == 0) {
-            finalize_data.ReturnNull();
-        } else {
-            long double divident = GetAverageDivident<long double>(state.count, finalize_data.input.bind_data);
-            target = Hugeint::Cast<long double>(state.value) / divident;
-        }
-    }
-};
-
-struct NumericAverageOperation : public BaseSumOperation<AverageSetOperation, RegularAdd> {
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.count == 0) {
-            finalize_data.ReturnNull();
-        } else {
-            target = state.value / state.count;
-        }
-    }
-};
-
-struct KahanAverageOperation : public BaseSumOperation<AverageSetOperation, KahanAdd> {
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.count == 0) {
-            finalize_data.ReturnNull();
-        } else {
-            target = (state.value / state.count) + (state.err / state.count);
-        }
-    }
-};
-
-AggregateFunction GetAverageAggregate(PhysicalType type) {
-    switch (type) {
-    case PhysicalType::INT16: {
-        return AggregateFunction::UnaryAggregate<AvgState<int64_t>, int16_t, double, IntegerAverageOperation>(
-            LogicalType::SMALLINT, LogicalType::DOUBLE);
-    }
-    case PhysicalType::INT32: {
-        return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int32_t, double, IntegerAverageOperationHugeint>(
-            LogicalType::INTEGER, LogicalType::DOUBLE);
-    }
-    case PhysicalType::INT64: {
-        return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, int64_t, double, IntegerAverageOperationHugeint>(
-            LogicalType::BIGINT, LogicalType::DOUBLE);
-    }
-    case PhysicalType::INT128: {
-        return AggregateFunction::UnaryAggregate<AvgState<hugeint_t>, hugeint_t, double, HugeintAverageOperation>(
-            LogicalType::HUGEINT, LogicalType::DOUBLE);
-    }
-    default:
-        throw InternalException("Unimplemented average aggregate");
-    }
-}
-
-unique_ptr<FunctionData> BindDecimalAvg(ClientContext &context, AggregateFunction &function,
-                                        vector<unique_ptr<Expression>> &arguments) {
-    auto decimal_type = arguments[0]->return_type;
-    function = GetAverageAggregate(decimal_type.InternalType());
-    function.name = "avg";
-    function.arguments[0] = decimal_type;
-    function.return_type = LogicalType::DOUBLE;
-    return make_uniq<AverageDecimalBindData>(
-        Hugeint::Cast<double>(Hugeint::POWERS_OF_TEN[DecimalType::GetScale(decimal_type)]));
-}
-
-AggregateFunctionSet AvgFun::GetFunctions() {
-    AggregateFunctionSet avg;
-
-    avg.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr,
-                                      nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr,
-                                      BindDecimalAvg));
-    avg.AddFunction(GetAverageAggregate(PhysicalType::INT16));
-    avg.AddFunction(GetAverageAggregate(PhysicalType::INT32));
-    avg.AddFunction(GetAverageAggregate(PhysicalType::INT64));
-    avg.AddFunction(GetAverageAggregate(PhysicalType::INT128));
-    avg.AddFunction(AggregateFunction::UnaryAggregate<AvgState<double>, double, double, NumericAverageOperation>(
-        LogicalType::DOUBLE, LogicalType::DOUBLE));
-    return avg;
-}
-
-AggregateFunction FAvgFun::GetFunction() {
-    return AggregateFunction::UnaryAggregate<KahanAvgState, double, double, KahanAverageOperation>(LogicalType::DOUBLE,
-                                                                                                   LogicalType::DOUBLE);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
deleted file mode 100644
index bf53a5ad3..000000000
--- a/src/duckdb/extension/core_functions/aggregate/algebraic/corr.cpp
+++ /dev/null
@@ -1,13 +0,0 @@
-#include "core_functions/aggregate/algebraic_functions.hpp"
-#include "core_functions/aggregate/algebraic/covar.hpp"
-#include "core_functions/aggregate/algebraic/stddev.hpp"
-#include "core_functions/aggregate/algebraic/corr.hpp"
-#include "duckdb/function/function_set.hpp"
-
-namespace duckdb {
-
-AggregateFunction CorrFun::GetFunction() {
-    return AggregateFunction::BinaryAggregate<CorrState, double, double, double, CorrOperation>(
-        LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE);
-}
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
deleted file mode 100644
index fddb9ed28..000000000
--- a/src/duckdb/extension/core_functions/aggregate/algebraic/covar.cpp
+++ /dev/null
@@ -1,17 +0,0 @@
-#include "core_functions/aggregate/algebraic_functions.hpp"
-#include "duckdb/common/types/null_value.hpp"
"duckdb/common/types/null_value.hpp" -#include "core_functions/aggregate/algebraic/covar.hpp" - -namespace duckdb { - -AggregateFunction CovarPopFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction CovarSampFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp b/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp deleted file mode 100644 index e9d14ee25..000000000 --- a/src/duckdb/extension/core_functions/aggregate/algebraic/stddev.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "core_functions/aggregate/algebraic_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/function_set.hpp" -#include "core_functions/aggregate/algebraic/stddev.hpp" -#include - -namespace duckdb { - -AggregateFunction StdDevSampFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction StdDevPopFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction VarPopFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction VarSampFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction StandardErrorOfTheMeanFun::GetFunction() { - return AggregateFunction::UnaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp deleted file mode 100644 index 37f05b208..000000000 --- a/src/duckdb/extension/core_functions/aggregate/distributive/approx_count.cpp +++ /dev/null @@ -1,99 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/hash.hpp" -#include "duckdb/common/types/hyperloglog.hpp" -#include "core_functions/aggregate/distributive_functions.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "hyperloglog.hpp" - -namespace duckdb { - -// Algorithms from -// "New cardinality estimation algorithms for HyperLogLog sketches" -// Otmar Ertl, arXiv:1702.01284 -struct ApproxDistinctCountState { - HyperLogLog hll; -}; - -struct ApproxCountDistinctFunction { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.hll.Merge(source.hll); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = UnsafeNumericCast(state.hll.Count()); - } - - static bool IgnoreNull() { - return true; - } -}; - -static void ApproxCountDistinctSimpleUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count, - data_ptr_t state, idx_t count) { - D_ASSERT(input_count == 1); - auto &input = inputs[0]; - - if (count > STANDARD_VECTOR_SIZE) { - throw InternalException("ApproxCountDistinct - count must be at most vector size"); - } - Vector hash_vec(LogicalType::HASH, count); - VectorOperations::Hash(input, hash_vec, count); - - auto 
-    auto agg_state = reinterpret_cast<ApproxDistinctCountState *>(state);
-    agg_state->hll.Update(input, hash_vec, count);
-}
-
-static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputData &, idx_t input_count,
-                                              Vector &state_vector, idx_t count) {
-    D_ASSERT(input_count == 1);
-    auto &input = inputs[0];
-    UnifiedVectorFormat idata;
-    input.ToUnifiedFormat(count, idata);
-
-    if (count > STANDARD_VECTOR_SIZE) {
-        throw InternalException("ApproxCountDistinct - count must be at most vector size");
-    }
-    Vector hash_vec(LogicalType::HASH, count);
-    VectorOperations::Hash(input, hash_vec, count);
-
-    UnifiedVectorFormat sdata;
-    state_vector.ToUnifiedFormat(count, sdata);
-    const auto states = UnifiedVectorFormat::GetDataNoConst<ApproxDistinctCountState *>(sdata);
-
-    UnifiedVectorFormat hdata;
-    hash_vec.ToUnifiedFormat(count, hdata);
-    const auto *hashes = UnifiedVectorFormat::GetData<hash_t>(hdata);
-    for (idx_t i = 0; i < count; i++) {
-        if (idata.validity.RowIsValid(idata.sel->get_index(i))) {
-            auto agg_state = states[sdata.sel->get_index(i)];
-            const auto hash = hashes[hdata.sel->get_index(i)];
-            agg_state->hll.InsertElement(hash);
-        }
-    }
-}
-
-AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) {
-    auto fun = AggregateFunction(
-        {input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize<ApproxDistinctCountState>,
-        AggregateFunction::StateInitialize<ApproxDistinctCountState, ApproxCountDistinctFunction>,
-        ApproxCountDistinctUpdateFunction,
-        AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
-        AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
-        ApproxCountDistinctSimpleUpdateFunction);
-    fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
-    return fun;
-}
-
-AggregateFunction ApproxCountDistinctFun::GetFunction() {
-    return GetApproxCountDistinctFunction(LogicalType::ANY);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
deleted file mode 100644
index 63c112b3c..000000000
--- a/src/duckdb/extension/core_functions/aggregate/distributive/arg_min_max.cpp
+++ /dev/null
@@ -1,742 +0,0 @@
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/operator/comparison_operators.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "duckdb/function/cast/cast_function_set.hpp"
-#include "duckdb/function/function_set.hpp"
-#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
-#include "duckdb/planner/expression/bound_comparison_expression.hpp"
-#include "duckdb/planner/expression_binder.hpp"
-#include "duckdb/function/create_sort_key.hpp"
-#include "duckdb/function/aggregate/minmax_n_helpers.hpp"
-
-namespace duckdb {
-
-struct ArgMinMaxStateBase {
-    ArgMinMaxStateBase() : is_initialized(false), arg_null(false) {
-    }
-
-    template <class T>
-    static inline void CreateValue(T &value) {
-    }
-
-    template <class T>
-    static inline void DestroyValue(T &value) {
-    }
-
-    template <class T>
-    static inline void AssignValue(T &target, T new_value) {
-        target = new_value;
-    }
-
-    template <class T>
-    static inline void ReadValue(Vector &result, T &arg, T &target) {
-        target = arg;
-    }
-
-    bool is_initialized;
-    bool arg_null;
-};
-
-// Out-of-line specialisations
-template <>
-void ArgMinMaxStateBase::CreateValue(string_t &value) {
-    value = string_t(uint32_t(0));
-}
-
-template <>
-void ArgMinMaxStateBase::DestroyValue(string_t &value) {
-    if (!value.IsInlined()) {
-        delete[] value.GetData();
-    }
-}
-
-template <>
-void ArgMinMaxStateBase::AssignValue(string_t &target, string_t new_value) {
-    DestroyValue(target);
-    if (new_value.IsInlined()) {
-        target = new_value;
-    } else {
-        // non-inlined string, need to allocate space for it
-        auto len = new_value.GetSize();
-        auto ptr = new char[len];
-        memcpy(ptr, new_value.GetData(), len);
-
-        target = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
-    }
-}
-
-template <>
-void ArgMinMaxStateBase::ReadValue(Vector &result, string_t &arg, string_t &target) {
-    target = StringVector::AddStringOrBlob(result, arg);
-}
-
-template <class A, class B>
-struct ArgMinMaxState : public ArgMinMaxStateBase {
-    using ARG_TYPE = A;
-    using BY_TYPE = B;
-
-    ARG_TYPE arg;
-    BY_TYPE value;
-
-    ArgMinMaxState() {
-        CreateValue(arg);
-        CreateValue(value);
-    }
-
-    ~ArgMinMaxState() {
-        if (is_initialized) {
-            DestroyValue(arg);
-            DestroyValue(value);
-            is_initialized = false;
-        }
-    }
-};
-
-template <class COMPARATOR, bool IGNORE_NULL>
-struct ArgMinMaxBase {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        new (&state) STATE;
-    }
-
-    template <class STATE>
-    static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
-        state.~STATE();
-    }
-
-    template <class A_TYPE, class B_TYPE, class STATE>
-    static void Assign(STATE &state, const A_TYPE &x, const B_TYPE &y, const bool x_null) {
-        if (IGNORE_NULL) {
-            STATE::template AssignValue<A_TYPE>(state.arg, x);
-            STATE::template AssignValue<B_TYPE>(state.value, y);
-        } else {
-            state.arg_null = x_null;
-            if (!state.arg_null) {
-                STATE::template AssignValue<A_TYPE>(state.arg, x);
-            }
-            STATE::template AssignValue<B_TYPE>(state.value, y);
-        }
-    }
-
-    template <class A_TYPE, class B_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const A_TYPE &x, const B_TYPE &y, AggregateBinaryInput &binary) {
-        if (!state.is_initialized) {
-            if (IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) {
-                Assign(state, x, y, !binary.left_mask.RowIsValid(binary.lidx));
-                state.is_initialized = true;
-            }
-        } else {
-            OP::template Execute<A_TYPE, B_TYPE, STATE>(state, x, y, binary);
-        }
-    }
-
-    template <class A_TYPE, class B_TYPE, class STATE>
-    static void Execute(STATE &state, A_TYPE x_data, B_TYPE y_data, AggregateBinaryInput &binary) {
-        if ((IGNORE_NULL || binary.right_mask.RowIsValid(binary.ridx)) && COMPARATOR::Operation(y_data, state.value)) {
-            Assign(state, x_data, y_data, !binary.left_mask.RowIsValid(binary.lidx));
-        }
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        if (!source.is_initialized) {
-            return;
-        }
-        if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) {
-            Assign(target, source.arg, source.value, source.arg_null);
-            target.is_initialized = true;
-        }
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (!state.is_initialized || state.arg_null) {
-            finalize_data.ReturnNull();
-        } else {
-            STATE::template ReadValue<T>(finalize_data.result, state.arg, target);
-        }
-    }
-
-    static bool IgnoreNull() {
-        return IGNORE_NULL;
-    }
-
-    static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
-                                         vector<unique_ptr<Expression>> &arguments) {
-        if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
-            ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
-        }
-        function.arguments[0] = arguments[0]->return_type;
-        function.return_type = arguments[0]->return_type;
-        return nullptr;
-    }
-};
-
-struct SpecializedGenericArgMinMaxState {
-    static bool CreateExtraState(idx_t count) {
-        // nop extra state
-        return false;
-    }
-
-    static void PrepareData(Vector &by, idx_t count, bool &, UnifiedVectorFormat &result) {
-        by.ToUnifiedFormat(count, result);
-    }
-};
-
-template <OrderType ORDER_TYPE>
-struct GenericArgMinMaxState {
-    static Vector CreateExtraState(idx_t count) {
-        return Vector(LogicalType::BLOB, count);
-    }
-
-    static void PrepareData(Vector &by, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) {
-        OrderModifiers modifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
-        CreateSortKeyHelpers::CreateSortKeyWithValidity(by, extra_state, modifiers, count);
-        extra_state.ToUnifiedFormat(count, result);
-    }
-};
-
-template <class COMPARATOR, bool IGNORE_NULL, OrderType ORDER_TYPE, class UPDATE_TYPE = SpecializedGenericArgMinMaxState>
-struct VectorArgMinMaxBase : ArgMinMaxBase<COMPARATOR, IGNORE_NULL> {
-    template <class STATE>
-    static void Update(Vector inputs[], AggregateInputData &, idx_t input_count, Vector &state_vector, idx_t count) {
-        auto &arg = inputs[0];
-        UnifiedVectorFormat adata;
-        arg.ToUnifiedFormat(count, adata);
-
-        using ARG_TYPE = typename STATE::ARG_TYPE;
-        using BY_TYPE = typename STATE::BY_TYPE;
-        auto &by = inputs[1];
-        UnifiedVectorFormat bdata;
-        auto extra_state = UPDATE_TYPE::CreateExtraState(count);
-        UPDATE_TYPE::PrepareData(by, count, extra_state, bdata);
-        const auto bys = UnifiedVectorFormat::GetData<BY_TYPE>(bdata);
-
-        UnifiedVectorFormat sdata;
-        state_vector.ToUnifiedFormat(count, sdata);
-
-        STATE *last_state = nullptr;
-        sel_t assign_sel[STANDARD_VECTOR_SIZE];
-        idx_t assign_count = 0;
-
-        auto states = UnifiedVectorFormat::GetData<STATE *>(sdata);
-        for (idx_t i = 0; i < count; i++) {
-            const auto bidx = bdata.sel->get_index(i);
-            if (!bdata.validity.RowIsValid(bidx)) {
-                continue;
-            }
-            const auto bval = bys[bidx];
-
-            const auto aidx = adata.sel->get_index(i);
-            const auto arg_null = !adata.validity.RowIsValid(aidx);
-            if (IGNORE_NULL && arg_null) {
-                continue;
-            }
-
-            const auto sidx = sdata.sel->get_index(i);
-            auto &state = *states[sidx];
-            if (!state.is_initialized || COMPARATOR::template Operation<BY_TYPE>(bval, state.value)) {
-                STATE::template AssignValue<BY_TYPE>(state.value, bval);
-                state.arg_null = arg_null;
-                // micro-adaptivity: it is common we overwrite the same state repeatedly
-                // e.g. when running arg_max(val, ts) and ts is sorted in ascending order
-                // this check essentially says:
-                // "if we are overriding the same state as the last row, the last write was pointless"
-                // hence we skip the last write altogether
-                if (!arg_null) {
-                    if (&state == last_state) {
-                        assign_count--;
-                    }
-                    assign_sel[assign_count++] = UnsafeNumericCast<sel_t>(i);
-                    last_state = &state;
-                }
-                state.is_initialized = true;
-            }
-        }
-        if (assign_count == 0) {
-            // no need to assign anything: nothing left to do
-            return;
-        }
-        Vector sort_key(LogicalType::BLOB);
-        auto modifiers = OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST);
-        // slice with a selection vector and generate sort keys
-        SelectionVector sel(assign_sel);
-        Vector sliced_input(arg, sel, assign_count);
-        CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key);
-        auto sort_key_data = FlatVector::GetData<string_t>(sort_key);
-
-        // now assign sort keys
-        for (idx_t i = 0; i < assign_count; i++) {
-            const auto sidx = sdata.sel->get_index(sel.get_index(i));
-            auto &state = *states[sidx];
-            STATE::template AssignValue<ARG_TYPE>(state.arg, sort_key_data[i]);
-        }
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        if (!source.is_initialized) {
-            return;
-        }
-        if (!target.is_initialized || COMPARATOR::Operation(source.value, target.value)) {
-            STATE::template AssignValue<typename STATE::BY_TYPE>(target.value, source.value);
-            target.arg_null = source.arg_null;
-            if (!target.arg_null) {
-                STATE::template AssignValue<typename STATE::ARG_TYPE>(target.arg, source.arg);
-            }
-            target.is_initialized = true;
-        }
-    }
-
-    template <class STATE>
-    static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) {
-        if (!state.is_initialized || state.arg_null) {
-            finalize_data.ReturnNull();
-        } else {
-            CreateSortKeyHelpers::DecodeSortKey(state.arg, finalize_data.result, finalize_data.result_idx,
-                                                OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST));
-        }
-    }
-
-    static unique_ptr<FunctionData> Bind(ClientContext &context, AggregateFunction &function,
-                                         vector<unique_ptr<Expression>> &arguments) {
-        if (arguments[1]->return_type.InternalType() == PhysicalType::VARCHAR) {
-            ExpressionBinder::PushCollation(context, arguments[1], arguments[1]->return_type);
-        }
-        function.arguments[0] = arguments[0]->return_type;
-        function.return_type = arguments[0]->return_type;
-        return nullptr;
-    }
-};
-
-template <class OP>
-AggregateFunction GetGenericArgMinMaxFunction() {
-    using STATE = ArgMinMaxState<string_t, string_t>;
-    return AggregateFunction(
-        {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, AggregateFunction::StateSize<STATE>,
-        AggregateFunction::StateInitialize<STATE, OP>, OP::template Update<STATE>,
-        AggregateFunction::StateCombine<STATE, OP>, AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
-        AggregateFunction::StateDestroy<STATE, OP>);
-}
-
-template <class OP, class BY_TYPE>
-AggregateFunction GetVectorArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
-#ifndef DUCKDB_SMALLER_BINARY
-    using STATE = ArgMinMaxState<string_t, BY_TYPE>;
-    return AggregateFunction({type, by_type}, type, AggregateFunction::StateSize<STATE>,
-                             AggregateFunction::StateInitialize<STATE, OP>,
-                             OP::template Update<STATE>, AggregateFunction::StateCombine<STATE, OP>,
-                             AggregateFunction::StateVoidFinalize<STATE, OP>, nullptr, OP::Bind,
-                             AggregateFunction::StateDestroy<STATE, OP>);
-#else
-    auto function = GetGenericArgMinMaxFunction<OP>();
-    function.arguments = {type, by_type};
-    function.return_type = type;
-    return function;
-#endif
-}
-
-#ifndef DUCKDB_SMALLER_BINARY
-template <class OP>
-AggregateFunction GetVectorArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) {
-    switch (by_type.InternalType()) {
-    case PhysicalType::INT32:
-        return GetVectorArgMinMaxFunctionInternal<OP, int32_t>(by_type, type);
-    case PhysicalType::INT64:
-        return GetVectorArgMinMaxFunctionInternal<OP, int64_t>(by_type, type);
-    case PhysicalType::INT128:
-        return GetVectorArgMinMaxFunctionInternal<OP, hugeint_t>(by_type, type);
-    case PhysicalType::DOUBLE:
-        return GetVectorArgMinMaxFunctionInternal<OP, double>(by_type, type);
-    case PhysicalType::VARCHAR:
-        return GetVectorArgMinMaxFunctionInternal<OP, string_t>(by_type, type);
-    default:
-        throw InternalException("Unimplemented arg_min/arg_max aggregate");
-    }
-}
-#endif
-
-static const vector<LogicalType> ArgMaxByTypes() {
-    vector<LogicalType> types = {LogicalType::INTEGER,   LogicalType::BIGINT,       LogicalType::HUGEINT,
-                                 LogicalType::DOUBLE,    LogicalType::VARCHAR,      LogicalType::DATE,
-                                 LogicalType::TIMESTAMP, LogicalType::TIMESTAMP_TZ, LogicalType::BLOB};
-    return types;
-}
-
-template <class OP>
-void AddVectorArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) {
-    auto by_types = ArgMaxByTypes();
-    for (const auto &by_type : by_types) {
-#ifndef DUCKDB_SMALLER_BINARY
-        fun.AddFunction(GetVectorArgMinMaxFunctionBy<OP>(by_type, type));
-#else
-        fun.AddFunction(GetVectorArgMinMaxFunctionInternal<OP, string_t>(by_type, type));
-#endif
-    }
-}
-
-template <class OP, class T, class BY_TYPE>
-AggregateFunction GetArgMinMaxFunctionInternal(const LogicalType &by_type, const LogicalType &type) {
-#ifndef DUCKDB_SMALLER_BINARY
-    using STATE = ArgMinMaxState<T, BY_TYPE>;
-    auto function =
-        AggregateFunction::BinaryAggregate<STATE, T, BY_TYPE, T, OP>(
-            type, by_type, type);
-    if (type.InternalType() == PhysicalType::VARCHAR || by_type.InternalType() == PhysicalType::VARCHAR) {
-        function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
-    }
-    function.bind = OP::Bind;
-#else
-    auto function = GetGenericArgMinMaxFunction<OP>();
-    function.arguments = {type, by_type};
-    function.return_type = type;
-#endif
-    return function;
-}
-
-#ifndef DUCKDB_SMALLER_BINARY
-template <class OP, class T>
-AggregateFunction GetArgMinMaxFunctionBy(const LogicalType &by_type, const LogicalType &type) {
-    switch (by_type.InternalType()) {
-    case PhysicalType::INT32:
-        return GetArgMinMaxFunctionInternal<OP, T, int32_t>(by_type, type);
-    case PhysicalType::INT64:
-        return GetArgMinMaxFunctionInternal<OP, T, int64_t>(by_type, type);
-    case PhysicalType::INT128:
-        return GetArgMinMaxFunctionInternal<OP, T, hugeint_t>(by_type, type);
-    case PhysicalType::DOUBLE:
-        return GetArgMinMaxFunctionInternal<OP, T, double>(by_type, type);
-    case PhysicalType::VARCHAR:
-        return GetArgMinMaxFunctionInternal<OP, T, string_t>(by_type, type);
-    default:
-        throw InternalException("Unimplemented arg_min/arg_max by aggregate");
-    }
-}
-#endif
-
-template <class OP, class T>
-void AddArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &type) {
-    auto by_types = ArgMaxByTypes();
-    for (const auto &by_type : by_types) {
-#ifndef DUCKDB_SMALLER_BINARY
-        fun.AddFunction(GetArgMinMaxFunctionBy<OP, T>(by_type, type));
-#else
-        fun.AddFunction(GetArgMinMaxFunctionInternal<OP, T, string_t>(by_type, type));
-#endif
-    }
-}
-
-template <class OP>
-static AggregateFunction GetDecimalArgMinMaxFunction(const LogicalType &by_type, const LogicalType &type) {
-    D_ASSERT(type.id() == LogicalTypeId::DECIMAL);
-#ifndef DUCKDB_SMALLER_BINARY
-    switch (type.InternalType()) {
-    case PhysicalType::INT16:
-        return GetArgMinMaxFunctionBy<OP, int16_t>(by_type, type);
-    case PhysicalType::INT32:
-        return GetArgMinMaxFunctionBy<OP, int32_t>(by_type, type);
-    case PhysicalType::INT64:
-        return GetArgMinMaxFunctionBy<OP, int64_t>(by_type, type);
-    default:
-        return GetArgMinMaxFunctionBy<OP, hugeint_t>(by_type, type);
-    }
-#else
-    return GetArgMinMaxFunctionInternal<OP, hugeint_t, string_t>(by_type, type);
-#endif
-}
-
-template <class OP>
-static unique_ptr<FunctionData> BindDecimalArgMinMax(ClientContext &context, AggregateFunction &function,
-                                                     vector<unique_ptr<Expression>> &arguments) {
-    auto decimal_type = arguments[0]->return_type;
-    auto by_type = arguments[1]->return_type;
-
-    // To avoid a combinatorial explosion, cast the ordering argument to one from the list
-    auto by_types = ArgMaxByTypes();
-    idx_t best_target = DConstants::INVALID_INDEX;
-    int64_t lowest_cost = NumericLimits<int64_t>::Maximum();
-    for (idx_t i = 0; i < by_types.size(); ++i) {
-        // Before falling back to casting, check for a physical type match for the by_type
-        if (by_types[i].InternalType() == by_type.InternalType()) {
-            lowest_cost = 0;
-            best_target = DConstants::INVALID_INDEX;
-            break;
-        }
-
-        auto cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(by_type, by_types[i]);
-        if (cast_cost < 0) {
-            continue;
-        }
-        if (cast_cost < lowest_cost) {
-            best_target = i;
-        }
-    }
-
-    if (best_target != DConstants::INVALID_INDEX) {
-        by_type = by_types[best_target];
-    }
-
-    auto name = std::move(function.name);
-    function = GetDecimalArgMinMaxFunction<OP>(by_type, decimal_type);
-    function.name = std::move(name);
-    function.return_type = decimal_type;
-    return nullptr;
-}
-
-template <class OP>
-void AddDecimalArgMinMaxFunctionBy(AggregateFunctionSet &fun, const LogicalType &by_type) {
-    fun.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL, by_type}, LogicalTypeId::DECIMAL, nullptr, nullptr,
-                                      nullptr, nullptr, nullptr, nullptr, BindDecimalArgMinMax<OP>));
-}
-
-template <class OP>
-void AddGenericArgMinMaxFunction(AggregateFunctionSet &fun) {
-    fun.AddFunction(GetGenericArgMinMaxFunction<OP>());
-}
-
-template <class COMPARATOR, bool IGNORE_NULL, OrderType ORDER_TYPE>
-static void AddArgMinMaxFunctions(AggregateFunctionSet &fun) {
-    using GENERIC_VECTOR_OP = VectorArgMinMaxBase<LessThan, IGNORE_NULL, ORDER_TYPE, GenericArgMinMaxState<ORDER_TYPE>>;
-#ifndef DUCKDB_SMALLER_BINARY
-    using OP = ArgMinMaxBase<COMPARATOR, IGNORE_NULL>;
-    using VECTOR_OP = VectorArgMinMaxBase<COMPARATOR, IGNORE_NULL, ORDER_TYPE>;
-#else
-    using OP = GENERIC_VECTOR_OP;
-    using VECTOR_OP = GENERIC_VECTOR_OP;
-#endif
-    AddArgMinMaxFunctionBy<OP, int32_t>(fun, LogicalType::INTEGER);
-    AddArgMinMaxFunctionBy<OP, int64_t>(fun, LogicalType::BIGINT);
-    AddArgMinMaxFunctionBy<OP, double>(fun, LogicalType::DOUBLE);
-    AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::VARCHAR);
-    AddArgMinMaxFunctionBy<OP, date_t>(fun, LogicalType::DATE);
-    AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP);
-    AddArgMinMaxFunctionBy<OP, timestamp_t>(fun, LogicalType::TIMESTAMP_TZ);
-    AddArgMinMaxFunctionBy<OP, string_t>(fun, LogicalType::BLOB);
-
-    auto by_types = ArgMaxByTypes();
-    for (const auto &by_type : by_types) {
-        AddDecimalArgMinMaxFunctionBy<OP>(fun, by_type);
-    }
-
-    AddVectorArgMinMaxFunctionBy<VECTOR_OP>(fun, LogicalType::ANY);
-
-    // we always use LessThan when using sort keys because the ORDER_TYPE takes care of selecting the lowest or highest
-    AddGenericArgMinMaxFunction<GENERIC_VECTOR_OP>(fun);
-}
-
-//------------------------------------------------------------------------------
-// ArgMinMax(N) Function
-//------------------------------------------------------------------------------
-//------------------------------------------------------------------------------
-// State
-//------------------------------------------------------------------------------
-
-template <class A, class B, class COMPARATOR>
-class ArgMinMaxNState {
-public:
-    using VAL_TYPE = A;
-    using ARG_TYPE = B;
-
-    using V = typename VAL_TYPE::TYPE;
-    using K = typename ARG_TYPE::TYPE;
-
-    BinaryAggregateHeap<K, V, COMPARATOR> heap;
-
-    bool is_initialized = false;
-    void Initialize(idx_t nval) {
-        heap.Initialize(nval);
-        is_initialized = true;
-    }
-};
-
-//------------------------------------------------------------------------------
-// Operation
-//------------------------------------------------------------------------------
-template <class STATE>
-static void ArgMinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector,
-                             idx_t count) {
-
-    auto &val_vector = inputs[0];
-    auto &arg_vector = inputs[1];
-    auto &n_vector = inputs[2];
-
-    UnifiedVectorFormat val_format;
-    UnifiedVectorFormat arg_format;
-    UnifiedVectorFormat n_format;
-    UnifiedVectorFormat state_format;
-
-    auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count);
-    auto arg_extra_state = STATE::ARG_TYPE::CreateExtraState(arg_vector, count);
-
-    STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format);
-    STATE::ARG_TYPE::PrepareData(arg_vector, count, arg_extra_state, arg_format);
-
-    n_vector.ToUnifiedFormat(count, n_format);
-    state_vector.ToUnifiedFormat(count, state_format);
-
-    auto states = UnifiedVectorFormat::GetData<STATE *>(state_format);
-
-    for (idx_t i = 0; i < count; i++) {
-        const auto arg_idx = arg_format.sel->get_index(i);
-        const auto val_idx = val_format.sel->get_index(i);
-        if (!arg_format.validity.RowIsValid(arg_idx) || !val_format.validity.RowIsValid(val_idx)) {
-            continue;
-        }
-        const auto state_idx = state_format.sel->get_index(i);
-        auto &state = *states[state_idx];
-
-        // Initialize the heap if necessary and add the input to the heap
-        if (!state.is_initialized) {
-            static constexpr int64_t MAX_N = 1000000;
-            const auto nidx = n_format.sel->get_index(i);
-            if (!n_format.validity.RowIsValid(nidx)) {
-                throw InvalidInputException("Invalid input for arg_min/arg_max: n value cannot be NULL");
-            }
-            const auto nval = UnifiedVectorFormat::GetData<int64_t>(n_format)[nidx];
-            if (nval <= 0) {
-                throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be > 0");
-            }
-            if (nval >= MAX_N) {
-                throw InvalidInputException("Invalid input for arg_min/arg_max: n value must be < %d", MAX_N);
-            }
-            state.Initialize(UnsafeNumericCast<idx_t>(nval));
-        }
-
-        // Now add the input to the heap
-        auto arg_val = STATE::ARG_TYPE::Create(arg_format, arg_idx);
-        auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx);
-
-        state.heap.Insert(aggr_input.allocator, arg_val, val_val);
-    }
-}
-
-//------------------------------------------------------------------------------
-// Bind
-//------------------------------------------------------------------------------
-template <class VAL_TYPE, class ARG_TYPE, class COMPARATOR>
-static void SpecializeArgMinMaxNFunction(AggregateFunction &function) {
-    using STATE = ArgMinMaxNState<VAL_TYPE, ARG_TYPE, COMPARATOR>;
-    using OP = MinMaxNOperation;
-
-    function.state_size = AggregateFunction::StateSize<STATE>;
-    function.initialize = AggregateFunction::StateInitialize<STATE, OP>;
-    function.combine = AggregateFunction::StateCombine<STATE, OP>;
-    function.destructor = AggregateFunction::StateDestroy<STATE, OP>;
-
-    function.finalize = MinMaxNOperation::Finalize<STATE>;
-    function.update = ArgMinMaxNUpdate<STATE>;
-}
-
-template <class VAL_TYPE, class COMPARATOR>
-static void SpecializeArgMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) {
-    switch (arg_type) {
-#ifndef DUCKDB_SMALLER_BINARY
-    case PhysicalType::VARCHAR:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxStringValue, COMPARATOR>(function);
-        break;
-    case PhysicalType::INT32:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int32_t>, COMPARATOR>(function);
-        break;
-    case PhysicalType::INT64:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<int64_t>, COMPARATOR>(function);
-        break;
-    case PhysicalType::FLOAT:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<float>, COMPARATOR>(function);
-        break;
-    case PhysicalType::DOUBLE:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFixedValue<double>, COMPARATOR>(function);
-        break;
-#endif
-    default:
-        SpecializeArgMinMaxNFunction<VAL_TYPE, MinMaxFallbackValue, COMPARATOR>(function);
-        break;
-    }
-}
-
-template <class COMPARATOR>
-static void SpecializeArgMinMaxNFunction(PhysicalType val_type, PhysicalType arg_type, AggregateFunction &function) {
-    switch (val_type) {
-#ifndef DUCKDB_SMALLER_BINARY
-    case PhysicalType::VARCHAR:
-        SpecializeArgMinMaxNFunction<MinMaxStringValue, COMPARATOR>(arg_type, function);
-        break;
-    case PhysicalType::INT32:
-        SpecializeArgMinMaxNFunction<MinMaxFixedValue<int32_t>, COMPARATOR>(arg_type, function);
-        break;
-    case PhysicalType::INT64:
-        SpecializeArgMinMaxNFunction<MinMaxFixedValue<int64_t>, COMPARATOR>(arg_type, function);
-        break;
-    case PhysicalType::FLOAT:
-        SpecializeArgMinMaxNFunction<MinMaxFixedValue<float>, COMPARATOR>(arg_type, function);
-        break;
-    case PhysicalType::DOUBLE:
-        SpecializeArgMinMaxNFunction<MinMaxFixedValue<double>, COMPARATOR>(arg_type, function);
-        break;
-#endif
-    default:
-        SpecializeArgMinMaxNFunction<MinMaxFallbackValue, COMPARATOR>(arg_type, function);
-        break;
-    }
-}
-
-template <class COMPARATOR>
-unique_ptr<FunctionData> ArgMinMaxNBind(ClientContext &context, AggregateFunction &function,
-                                        vector<unique_ptr<Expression>> &arguments) {
-    for (auto &arg : arguments) {
-        if (arg->return_type.id() == LogicalTypeId::UNKNOWN) {
-            throw ParameterNotResolvedException();
-        }
-    }
-
-    const auto val_type = arguments[0]->return_type.InternalType();
-    const auto arg_type = arguments[1]->return_type.InternalType();
-
-    // Specialize the function based on the input types
-    SpecializeArgMinMaxNFunction<COMPARATOR>(val_type, arg_type, function);
-
-    function.return_type = LogicalType::LIST(arguments[0]->return_type);
-    return nullptr;
-}
-
-template <class COMPARATOR>
-static void AddArgMinMaxNFunction(AggregateFunctionSet &set) {
-    AggregateFunction function({LogicalTypeId::ANY, LogicalTypeId::ANY, LogicalType::BIGINT},
-                               LogicalType::LIST(LogicalType::ANY), nullptr, nullptr, nullptr, nullptr, nullptr,
-                               nullptr, ArgMinMaxNBind<COMPARATOR>);
-
-    return set.AddFunction(function);
-}
-
-//------------------------------------------------------------------------------
-// Function Registration
-//------------------------------------------------------------------------------
-
-AggregateFunctionSet ArgMinFun::GetFunctions() {
-    AggregateFunctionSet fun;
-    AddArgMinMaxFunctions<LessThan, true, OrderType::ASCENDING>(fun);
-    AddArgMinMaxNFunction<LessThan>(fun);
-    return fun;
-}
-
-AggregateFunctionSet ArgMaxFun::GetFunctions() {
-    AggregateFunctionSet fun;
-    AddArgMinMaxFunctions<GreaterThan, true, OrderType::DESCENDING>(fun);
-    AddArgMinMaxNFunction<GreaterThan>(fun);
-    return fun;
-}
-
-AggregateFunctionSet ArgMinNullFun::GetFunctions() {
-    AggregateFunctionSet fun;
-    AddArgMinMaxFunctions<LessThan, false, OrderType::ASCENDING>(fun);
-    return fun;
-}
-
-AggregateFunctionSet ArgMaxNullFun::GetFunctions() {
-    AggregateFunctionSet fun;
-    AddArgMinMaxFunctions<GreaterThan, false, OrderType::DESCENDING>(fun);
-    return fun;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp
deleted file mode 100644
index 241d25692..000000000
--- a/src/duckdb/extension/core_functions/aggregate/distributive/bitagg.cpp
+++ /dev/null
@@ -1,231 +0,0 @@
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/types/null_value.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/common/vector_operations/aggregate_executor.hpp"
-#include "duckdb/common/types/bit.hpp"
-#include "duckdb/common/types/cast_helpers.hpp"
-
-namespace duckdb {
-
-template <class T>
-struct BitState {
-    using TYPE = T;
-    bool is_set;
-    T value;
-};
-
-template <class OP>
-static AggregateFunction GetBitfieldUnaryAggregate(LogicalType type) {
-    switch (type.id()) {
-    case LogicalTypeId::TINYINT:
-        return AggregateFunction::UnaryAggregate<BitState<int8_t>, int8_t, int8_t, OP>(type, type);
-    case LogicalTypeId::SMALLINT:
-        return AggregateFunction::UnaryAggregate<BitState<int16_t>, int16_t, int16_t, OP>(type, type);
-    case LogicalTypeId::INTEGER:
-        return AggregateFunction::UnaryAggregate<BitState<int32_t>, int32_t, int32_t, OP>(type, type);
-    case LogicalTypeId::BIGINT:
-        return AggregateFunction::UnaryAggregate<BitState<int64_t>, int64_t, int64_t, OP>(type, type);
-    case LogicalTypeId::HUGEINT:
-        return AggregateFunction::UnaryAggregate<BitState<hugeint_t>, hugeint_t, hugeint_t, OP>(type, type);
-    case LogicalTypeId::UTINYINT:
-        return AggregateFunction::UnaryAggregate<BitState<uint8_t>, uint8_t, uint8_t, OP>(type, type);
-    case LogicalTypeId::USMALLINT:
-        return AggregateFunction::UnaryAggregate<BitState<uint16_t>, uint16_t, uint16_t, OP>(type, type);
-    case LogicalTypeId::UINTEGER:
-        return AggregateFunction::UnaryAggregate<BitState<uint32_t>, uint32_t, uint32_t, OP>(type, type);
-    case LogicalTypeId::UBIGINT:
-        return AggregateFunction::UnaryAggregate<BitState<uint64_t>, uint64_t, uint64_t, OP>(type, type);
-    case LogicalTypeId::UHUGEINT:
-        return AggregateFunction::UnaryAggregate<BitState<uhugeint_t>, uhugeint_t, uhugeint_t, OP>(type, type);
-    default:
-        throw InternalException("Unimplemented bitfield type for unary aggregate");
-    }
-}
-
-struct BitwiseOperation {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        // If there are no matching rows, returns a null value.
-        state.is_set = false;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
-        if (!state.is_set) {
-            OP::template Assign<INPUT_TYPE, STATE>(state, input);
-            state.is_set = true;
-        } else {
-            OP::template Execute<INPUT_TYPE, STATE>(state, input);
-        }
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-    }
-
-    template <class INPUT_TYPE, class STATE>
-    static void Assign(STATE &state, INPUT_TYPE input) {
-        state.value = typename STATE::TYPE(input);
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        if (!source.is_set) {
-            // source is NULL, nothing to do.
-            return;
-        }
-        if (!target.is_set) {
-            // target is NULL, use source value directly.
-            OP::template Assign<typename STATE::TYPE, STATE>(target, source.value);
-            target.is_set = true;
-        } else {
-            OP::template Execute<typename STATE::TYPE, STATE>(target, source.value);
-        }
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (!state.is_set) {
-            finalize_data.ReturnNull();
-        } else {
-            target = T(state.value);
-        }
-    }
-
-    static bool IgnoreNull() {
-        return true;
-    }
-};
-
-struct BitAndOperation : public BitwiseOperation {
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        state.value &= typename STATE::TYPE(input);
-        ;
-    }
-};
-
-struct BitOrOperation : public BitwiseOperation {
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        state.value |= typename STATE::TYPE(input);
-        ;
-    }
-};
-
-struct BitXorOperation : public BitwiseOperation {
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        state.value ^= typename STATE::TYPE(input);
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        for (idx_t i = 0; i < count; i++) {
-            Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-        }
-    }
-};
-
-struct BitStringBitwiseOperation : public BitwiseOperation {
-    template <class STATE>
-    static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
-        if (state.is_set && !state.value.IsInlined()) {
-            delete[] state.value.GetData();
-        }
-    }
-
-    template <class INPUT_TYPE, class STATE>
-    static void Assign(STATE &state, INPUT_TYPE input) {
-        D_ASSERT(state.is_set == false);
-        if (input.IsInlined()) {
-            state.value = input;
-        } else { // non-inlined string, need to allocate space for it
-            auto len = input.GetSize();
-            auto ptr = new char[len];
-            memcpy(ptr, input.GetData(), len);
-
-            state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
-        }
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (!state.is_set) {
-            finalize_data.ReturnNull();
-        } else {
-            target = finalize_data.ReturnString(state.value);
-        }
-    }
-};
-
-struct BitStringAndOperation : public BitStringBitwiseOperation {
-
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        Bit::BitwiseAnd(input, state.value, state.value);
-    }
-};
-
-struct BitStringOrOperation : public BitStringBitwiseOperation {
-
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        Bit::BitwiseOr(input, state.value, state.value);
-    }
-};
-
-struct BitStringXorOperation : public BitStringBitwiseOperation {
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input) {
-        Bit::BitwiseXor(input, state.value, state.value);
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        for (idx_t i = 0; i < count; i++) {
-            Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-        }
-    }
-};
-
-AggregateFunctionSet BitAndFun::GetFunctions() {
-    AggregateFunctionSet bit_and;
-    for (auto &type : LogicalType::Integral()) {
-        bit_and.AddFunction(GetBitfieldUnaryAggregate<BitAndOperation>(type));
-    }
-
-    bit_and.AddFunction(
-        AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringAndOperation>(
-            LogicalType::BIT, LogicalType::BIT));
-    return bit_and;
-}
-
-AggregateFunctionSet BitOrFun::GetFunctions() {
-    AggregateFunctionSet bit_or;
-    for (auto &type : LogicalType::Integral()) {
-        bit_or.AddFunction(GetBitfieldUnaryAggregate<BitOrOperation>(type));
-    }
-    bit_or.AddFunction(
-        AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringOrOperation>(
-            LogicalType::BIT, LogicalType::BIT));
-    return bit_or;
-}
-
-AggregateFunctionSet BitXorFun::GetFunctions() {
-    AggregateFunctionSet bit_xor;
-    for (auto &type : LogicalType::Integral()) {
-        bit_xor.AddFunction(GetBitfieldUnaryAggregate<BitXorOperation>(type));
-    }
-    bit_xor.AddFunction(
-        AggregateFunction::UnaryAggregateDestructor<BitState<string_t>, string_t, string_t, BitStringXorOperation>(
-            LogicalType::BIT, LogicalType::BIT));
-    return bit_xor;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp
deleted file mode 100644
index c9e399835..000000000
--- a/src/duckdb/extension/core_functions/aggregate/distributive/bitstring_agg.cpp
+++ /dev/null
@@ -1,320 +0,0 @@
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/types/null_value.hpp"
-#include "duckdb/common/vector_operations/aggregate_executor.hpp"
-#include "duckdb/common/types/bit.hpp"
-#include "duckdb/common/types/uhugeint.hpp"
-#include "duckdb/storage/statistics/base_statistics.hpp"
-#include "duckdb/execution/expression_executor.hpp"
-#include "duckdb/common/types/cast_helpers.hpp"
-#include "duckdb/common/operator/subtract.hpp"
-#include "duckdb/common/serializer/deserializer.hpp"
-#include "duckdb/common/serializer/serializer.hpp"
-
-namespace duckdb {
-
-template <class INPUT_TYPE>
-struct BitAggState {
-    bool is_set;
-    string_t value;
-    INPUT_TYPE min;
-    INPUT_TYPE max;
-};
-
-struct BitstringAggBindData : public FunctionData {
-    Value min;
-    Value max;
-
-    BitstringAggBindData() {
-    }
-
-    BitstringAggBindData(Value min, Value max) : min(std::move(min)), max(std::move(max)) {
-    }
-
-    unique_ptr<FunctionData> Copy() const override {
-        return make_uniq<BitstringAggBindData>(*this);
-    }
-
-    bool Equals(const FunctionData &other_p) const override {
-        auto &other = other_p.Cast<BitstringAggBindData>();
-        if (min.IsNull() && other.min.IsNull() && max.IsNull() && other.max.IsNull()) {
-            return true;
-        }
-        if (Value::NotDistinctFrom(min, other.min) && Value::NotDistinctFrom(max, other.max)) {
-            return true;
-        }
-        return false;
-    }
-
-    static void Serialize(Serializer &serializer, const optional_ptr<FunctionData> bind_data_p,
-                          const AggregateFunction &) {
-        auto &bind_data = bind_data_p->Cast<BitstringAggBindData>();
-        serializer.WriteProperty(100, "min", bind_data.min);
-        serializer.WriteProperty(101, "max", bind_data.max);
-    }
-
-    static unique_ptr<FunctionData> Deserialize(Deserializer &deserializer, AggregateFunction &) {
-        Value min;
-        Value max;
-        deserializer.ReadProperty(100, "min", min);
-        deserializer.ReadProperty(101, "max", max);
-        return make_uniq<BitstringAggBindData>(min, max);
-    }
-};
-
-struct BitStringAggOperation {
-    static constexpr const idx_t MAX_BIT_RANGE = 1000000000; // for now capped at 1 billion bits
-
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        state.is_set = false;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
-        auto &bind_agg_data = unary_input.input.bind_data->template Cast<BitstringAggBindData>();
-        if (!state.is_set) {
-            if (bind_agg_data.min.IsNull() || bind_agg_data.max.IsNull()) {
-                throw BinderException(
-                    "Could not retrieve required statistics. Alternatively, try by providing the statistics "
-                    "explicitly: BITSTRING_AGG(col, min, max) ");
-            }
-            state.min = bind_agg_data.min.GetValue<INPUT_TYPE>();
-            state.max = bind_agg_data.max.GetValue<INPUT_TYPE>();
-            if (state.min > state.max) {
-                throw InvalidInputException("Invalid explicit bitstring range: Minimum (%s) > maximum (%s)",
-                                            NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
-            }
-            idx_t bit_range =
-                GetRange(bind_agg_data.min.GetValue<INPUT_TYPE>(), bind_agg_data.max.GetValue<INPUT_TYPE>());
-            if (bit_range > MAX_BIT_RANGE) {
-                throw OutOfRangeException(
-                    "The range between min and max value (%s <-> %s) is too large for bitstring aggregation",
-                    NumericHelper::ToString(state.min), NumericHelper::ToString(state.max));
-            }
-            idx_t len = Bit::ComputeBitstringLen(bit_range);
-            auto target = len > string_t::INLINE_LENGTH ? string_t(new char[len], UnsafeNumericCast<uint32_t>(len))
-                                                        : string_t(UnsafeNumericCast<uint32_t>(len));
-            Bit::SetEmptyBitString(target, bit_range);
-
-            state.value = target;
-            state.is_set = true;
-        }
-        if (input >= state.min && input <= state.max) {
-            Execute(state, input, bind_agg_data.min.GetValue<INPUT_TYPE>());
-        } else {
-            throw OutOfRangeException("Value %s is outside of provided min and max range (%s <-> %s)",
-                                      NumericHelper::ToString(input), NumericHelper::ToString(state.min),
-                                      NumericHelper::ToString(state.max));
-        }
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        OP::template Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-    }
-
-    template <class INPUT_TYPE>
-    static idx_t GetRange(INPUT_TYPE min, INPUT_TYPE max) {
-        if (min > max) {
-            throw InvalidInputException("Invalid explicit bitstring range: Minimum (%d) > maximum (%d)", min, max);
-        }
-        INPUT_TYPE result;
-        if (!TrySubtractOperator::Operation(max, min, result)) {
-            return NumericLimits<idx_t>::Maximum();
-        }
-        auto val = NumericCast<idx_t>(result);
-        if (val == NumericLimits<idx_t>::Maximum()) {
-            return val;
-        }
-        return val + 1;
-    }
-
-    template <class INPUT_TYPE, class STATE>
-    static void Execute(STATE &state, INPUT_TYPE input, INPUT_TYPE min) {
-        Bit::SetBit(state.value, UnsafeNumericCast<idx_t>(input - min), 1);
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        if (!source.is_set) {
-            return;
-        }
-        if (!target.is_set) {
-            Assign(target, source.value);
-            target.is_set = true;
-            target.min = source.min;
-            target.max = source.max;
-        } else {
-            Bit::BitwiseOr(source.value, target.value, target.value);
-        }
-    }
-
-    template <class INPUT_TYPE, class STATE>
-    static void Assign(STATE &state, INPUT_TYPE input) {
-        D_ASSERT(state.is_set == false);
-        if (input.IsInlined()) {
-            state.value = input;
-        } else { // non-inlined string, need to allocate space for it
-            auto len = input.GetSize();
-            auto ptr = new char[len];
-            memcpy(ptr, input.GetData(), len);
-            state.value = string_t(ptr, UnsafeNumericCast<uint32_t>(len));
-        }
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (!state.is_set) {
-            finalize_data.ReturnNull();
-        } else {
-            target = StringVector::AddStringOrBlob(finalize_data.result, state.value);
-        }
-    }
-
-    template <class STATE>
-    static void Destroy(STATE &state, AggregateInputData &aggr_input_data) {
-        if (state.is_set && !state.value.IsInlined()) {
-            delete[] state.value.GetData();
-        }
-    }
-
-    static bool IgnoreNull() {
-        return true;
-    }
-};
-
-template <>
-void BitStringAggOperation::Execute(BitAggState<hugeint_t> &state, hugeint_t input, hugeint_t min) {
-    idx_t val;
-    if (Hugeint::TryCast(input - min, val)) {
-        Bit::SetBit(state.value, val, 1);
-    } else {
OutOfRangeException("Range too large for bitstring aggregation"); - } -} - -template <> -idx_t BitStringAggOperation::GetRange(hugeint_t min, hugeint_t max) { - hugeint_t result; - if (!TrySubtractOperator::Operation(max, min, result)) { - return NumericLimits::Maximum(); - } - idx_t range; - if (!Hugeint::TryCast(result + 1, range) || result == NumericLimits::Maximum()) { - return NumericLimits::Maximum(); - } - return range; -} - -template <> -void BitStringAggOperation::Execute(BitAggState &state, uhugeint_t input, uhugeint_t min) { - idx_t val; - if (Uhugeint::TryCast(input - min, val)) { - Bit::SetBit(state.value, val, 1); - } else { - throw OutOfRangeException("Range too large for bitstring aggregation"); - } -} - -template <> -idx_t BitStringAggOperation::GetRange(uhugeint_t min, uhugeint_t max) { - uhugeint_t result; - if (!TrySubtractOperator::Operation(max, min, result)) { - return NumericLimits::Maximum(); - } - idx_t range; - if (!Uhugeint::TryCast(result + 1, range) || result == NumericLimits::Maximum()) { - return NumericLimits::Maximum(); - } - return range; -} - -unique_ptr BitstringPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - - if (NumericStats::HasMinMax(input.child_stats[0])) { - auto &bind_agg_data = input.bind_data->Cast(); - bind_agg_data.min = NumericStats::Min(input.child_stats[0]); - bind_agg_data.max = NumericStats::Max(input.child_stats[0]); - } - return nullptr; -} - -unique_ptr BindBitstringAgg(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments.size() == 3) { - if (!arguments[1]->IsFoldable() || !arguments[2]->IsFoldable()) { - throw BinderException("bitstring_agg requires a constant min and max argument"); - } - auto min = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - auto max = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - Function::EraseArgument(function, arguments, 2); - Function::EraseArgument(function, arguments, 1); - return make_uniq(min, max); - } - return make_uniq(); -} - -template -static void BindBitString(AggregateFunctionSet &bitstring_agg, const LogicalTypeId &type) { - auto function = - AggregateFunction::UnaryAggregateDestructor, TYPE, string_t, BitStringAggOperation>( - type, LogicalType::BIT); - function.bind = BindBitstringAgg; // create new a 'BitstringAggBindData' - function.serialize = BitstringAggBindData::Serialize; - function.deserialize = BitstringAggBindData::Deserialize; - function.statistics = BitstringPropagateStats; // stores min and max from column stats in BitstringAggBindData - bitstring_agg.AddFunction(function); // uses the BitstringAggBindData to access statistics for creating bitstring - function.arguments = {type, type, type}; - function.statistics = nullptr; // min and max are provided as arguments - bitstring_agg.AddFunction(function); -} - -void GetBitStringAggregate(const LogicalType &type, AggregateFunctionSet &bitstring_agg) { - switch (type.id()) { - case LogicalType::TINYINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::SMALLINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::INTEGER: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::BIGINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::HUGEINT: { - return BindBitString(bitstring_agg, type.id()); - } - case LogicalType::UTINYINT: { - return BindBitString(bitstring_agg, type.id()); - } - case 
-    case LogicalType::USMALLINT: {
-        return BindBitString<uint16_t>(bitstring_agg, type.id());
-    }
-    case LogicalType::UINTEGER: {
-        return BindBitString<uint32_t>(bitstring_agg, type.id());
-    }
-    case LogicalType::UBIGINT: {
-        return BindBitString<uint64_t>(bitstring_agg, type.id());
-    }
-    case LogicalType::UHUGEINT: {
-        return BindBitString<uhugeint_t>(bitstring_agg, type.id());
-    }
-    default:
-        throw InternalException("Unimplemented bitstring aggregate");
-    }
-}
-
-AggregateFunctionSet BitstringAggFun::GetFunctions() {
-    AggregateFunctionSet bitstring_agg("bitstring_agg");
-    for (auto &type : LogicalType::Integral()) {
-        GetBitStringAggregate(type, bitstring_agg);
-    }
-    return bitstring_agg;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp
deleted file mode 100644
index 9b781f848..000000000
--- a/src/duckdb/extension/core_functions/aggregate/distributive/bool.cpp
+++ /dev/null
@@ -1,110 +0,0 @@
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
-#include "duckdb/function/function_set.hpp"
-
-namespace duckdb {
-
-struct BoolState {
-    bool empty;
-    bool val;
-};
-
-struct BoolAndFunFunction {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        state.val = true;
-        state.empty = true;
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        target.val = target.val && source.val;
-        target.empty = target.empty && source.empty;
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.empty) {
-            finalize_data.ReturnNull();
-            return;
-        }
-        target = state.val;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
-        state.empty = false;
-        state.val = input && state.val;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        for (idx_t i = 0; i < count; i++) {
-            Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-        }
-    }
-    static bool IgnoreNull() {
-        return true;
-    }
-};
-
-struct BoolOrFunFunction {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        state.val = false;
-        state.empty = true;
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        target.val = target.val || source.val;
-        target.empty = target.empty && source.empty;
-    }
-
-    template <class T, class STATE>
-    static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-        if (state.empty) {
-            finalize_data.ReturnNull();
-            return;
-        }
-        target = state.val;
-    }
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
-        state.empty = false;
-        state.val = input || state.val;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        for (idx_t i = 0; i < count; i++) {
-            Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-        }
-    }
-
-    static bool IgnoreNull() {
-        return true;
-    }
-};
-
-AggregateFunction BoolOrFun::GetFunction() {
-    auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolOrFunFunction>(
-        LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
-    fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
-    fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
-    return fun;
-}
-
-AggregateFunction BoolAndFun::GetFunction() {
-    auto fun = AggregateFunction::UnaryAggregate<BoolState, bool, bool, BoolAndFunFunction>(
-        LogicalType(LogicalTypeId::BOOLEAN), LogicalType::BOOLEAN);
-    fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT;
-    fun.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT;
-    return fun;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp
deleted file mode 100644
index 4f9f6f30c..000000000
--- a/src/duckdb/extension/core_functions/aggregate/distributive/kurtosis.cpp
+++ /dev/null
@@ -1,113 +0,0 @@
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
-#include "duckdb/common/algorithm.hpp"
-
-namespace duckdb {
-
-struct KurtosisState {
-    idx_t n;
-    double sum;
-    double sum_sqr;
-    double sum_cub;
-    double sum_four;
-};
-
-struct KurtosisFlagBiasCorrection {};
-
-struct KurtosisFlagNoBiasCorrection {};
-
-template <class KURTOSIS_FLAG>
-struct KurtosisOperation {
-    template <class STATE>
-    static void Initialize(STATE &state) {
-        state.n = 0;
-        state.sum = state.sum_sqr = state.sum_cub = state.sum_four = 0.0;
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-                                  idx_t count) {
-        for (idx_t i = 0; i < count; i++) {
-            Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-        }
-    }
-
-    template <class INPUT_TYPE, class STATE, class OP>
-    static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) {
-        state.n++;
-        state.sum += input;
-        state.sum_sqr += pow(input, 2);
-        state.sum_cub += pow(input, 3);
-        state.sum_four += pow(input, 4);
-    }
-
-    template <class STATE, class OP>
-    static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-        if (source.n == 0) {
-            return;
-        }
-        target.n += source.n;
-        target.sum += source.sum;
-        target.sum_sqr += source.sum_sqr;
-        target.sum_cub += source.sum_cub;
-        target.sum_four += source.sum_four;
-    }
-
-    template <class TARGET_TYPE, class STATE>
-    static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) {
-        auto n = (double)state.n;
-        if (n <= 1) {
-            finalize_data.ReturnNull();
-            return;
-        }
-        if (std::is_same<KURTOSIS_FLAG, KurtosisFlagBiasCorrection>::value && n <= 3) {
-            finalize_data.ReturnNull();
-            return;
-        }
-        double temp = 1 / n;
This is necessary due to linux 32 bits - long double temp_aux = 1 / n; - if (state.sum_sqr - state.sum * state.sum * temp == 0 || - state.sum_sqr - state.sum * state.sum * temp_aux == 0) { - finalize_data.ReturnNull(); - return; - } - double m4 = - temp * (state.sum_four - 4 * state.sum_cub * state.sum * temp + - 6 * state.sum_sqr * state.sum * state.sum * temp * temp - 3 * pow(state.sum, 4) * pow(temp, 3)); - - double m2 = temp * (state.sum_sqr - state.sum * state.sum * temp); - if (m2 <= 0) { // m2 shouldn't be below 0 but floating points are weird - finalize_data.ReturnNull(); - return; - } - if (std::is_same::value) { - target = m4 / (m2 * m2) - 3; - } else { - target = (n - 1) * ((n + 1) * m4 / (m2 * m2) - 3 * (n - 1)) / ((n - 2) * (n - 3)); - } - if (!Value::DoubleIsFinite(target)) { - throw OutOfRangeException("Kurtosis is out of range!"); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction KurtosisFun::GetFunction() { - return AggregateFunction::UnaryAggregate>(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -AggregateFunction KurtosisPopFun::GetFunction() { - return AggregateFunction::UnaryAggregate>(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp deleted file mode 100644 index 324893f60..000000000 --- a/src/duckdb/extension/core_functions/aggregate/distributive/product.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include "core_functions/aggregate/distributive_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct ProductState { - bool empty; - double val; -}; - -struct ProductFunction { - template - static void Initialize(STATE &state) { - state.val = 1; - state.empty = true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.val *= source.val; - target.empty = target.empty && source.empty; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.empty) { - finalize_data.ReturnNull(); - return; - } - target = state.val; - } - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (state.empty) { - state.empty = false; - } - state.val *= input; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction ProductFun::GetFunction() { - return AggregateFunction::UnaryAggregate( - LogicalType(LogicalTypeId::DOUBLE), LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp deleted file mode 100644 index 12f237610..000000000 --- a/src/duckdb/extension/core_functions/aggregate/distributive/skew.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "core_functions/aggregate/distributive_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include 
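
The kurtosis finalizer derives the second and fourth central moments from the four streaming power sums, then applies either the population formula or the sample bias correction. A sketch of the population path, transcribing the same algebra into a free function (hypothetical name; the `n <= 1` and `m2 <= 0` guards are the caller's job, as in the original):

```cpp
#include <cmath>
#include <cstdint>

// Population excess kurtosis from the four streaming power sums
// s1 = sum(x), s2 = sum(x^2), s3 = sum(x^3), s4 = sum(x^4).
double ExcessKurtosisPop(uint64_t n_in, double s1, double s2, double s3, double s4) {
	const double n = static_cast<double>(n_in);
	const double t = 1.0 / n;
	const double m2 = t * (s2 - s1 * s1 * t); // second central moment
	const double m4 = t * (s4 - 4 * s3 * s1 * t + 6 * s2 * s1 * s1 * t * t -
	                       3 * std::pow(s1, 4) * std::pow(t, 3)); // fourth central moment
	return m4 / (m2 * m2) - 3.0;
}
```
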
"duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/common/algorithm.hpp" - -namespace duckdb { - -struct SkewState { - size_t n; - double sum; - double sum_sqr; - double sum_cub; -}; - -struct SkewnessOperation { - template - static void Initialize(STATE &state) { - state.n = 0; - state.sum = state.sum_sqr = state.sum_cub = 0; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - state.n++; - state.sum += input; - state.sum_sqr += pow(input, 2); - state.sum_cub += pow(input, 3); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.n == 0) { - return; - } - - target.n += source.n; - target.sum += source.sum; - target.sum_sqr += source.sum_sqr; - target.sum_cub += source.sum_cub; - } - - template - static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.n <= 2) { - finalize_data.ReturnNull(); - return; - } - double n = state.n; - double temp = 1 / n; - auto p = std::pow(temp * (state.sum_sqr - state.sum * state.sum * temp), 3); - if (p < 0) { - p = 0; // Shouldn't be below 0 but floating points are weird - } - double div = std::sqrt(p); - if (div == 0) { - target = NAN; - return; - } - double temp1 = std::sqrt(n * (n - 1)) / (n - 2); - target = temp1 * temp * - (state.sum_cub - 3 * state.sum_sqr * state.sum * temp + 2 * pow(state.sum, 3) * temp * temp) / div; - if (!Value::DoubleIsFinite(target)) { - throw OutOfRangeException("SKEW is out of range!"); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction SkewnessFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp deleted file mode 100644 index b694a2365..000000000 --- a/src/duckdb/extension/core_functions/aggregate/distributive/string_agg.cpp +++ /dev/null @@ -1,175 +0,0 @@ -#include "core_functions/aggregate/distributive_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -namespace duckdb { - -struct StringAggState { - idx_t size; - idx_t alloc_size; - char *dataptr; -}; - -struct StringAggBindData : public FunctionData { - explicit StringAggBindData(string sep_p) : sep(std::move(sep_p)) { - } - - string sep; - - unique_ptr Copy() const override { - return make_uniq(sep); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return sep == other.sep; - } -}; - -struct StringAggFunction { - template - static void Initialize(STATE &state) { - state.dataptr = nullptr; - state.alloc_size = 0; - state.size = 0; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.dataptr) { - 
finalize_data.ReturnNull(); - } else { - target = StringVector::AddString(finalize_data.result, state.dataptr, state.size); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.dataptr) { - delete[] state.dataptr; - } - } - - static bool IgnoreNull() { - return true; - } - - static inline void PerformOperation(StringAggState &state, const char *str, const char *sep, idx_t str_size, - idx_t sep_size) { - if (!state.dataptr) { - // first iteration: allocate space for the string and copy it into the state - state.alloc_size = MaxValue(8, NextPowerOfTwo(str_size)); - state.dataptr = new char[state.alloc_size]; - state.size = str_size; - memcpy(state.dataptr, str, str_size); - } else { - // subsequent iteration: first check if we have space to place the string and separator - idx_t required_size = state.size + str_size + sep_size; - if (required_size > state.alloc_size) { - // no space! allocate extra space - while (state.alloc_size < required_size) { - state.alloc_size *= 2; - } - auto new_data = new char[state.alloc_size]; - memcpy(new_data, state.dataptr, state.size); - delete[] state.dataptr; - state.dataptr = new_data; - } - // copy the separator - memcpy(state.dataptr + state.size, sep, sep_size); - state.size += sep_size; - // copy the string - memcpy(state.dataptr + state.size, str, str_size); - state.size += str_size; - } - } - - static inline void PerformOperation(StringAggState &state, string_t str, optional_ptr data_p) { - auto &data = data_p->Cast(); - PerformOperation(state, str.GetData(), data.sep.c_str(), str.GetSize(), data.sep.size()); - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - PerformOperation(state, input, unary_input.input.bind_data); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - if (!source.dataptr) { - // source is not set: skip combining - return; - } - PerformOperation(target, string_t(source.dataptr, UnsafeNumericCast(source.size)), - aggr_input_data.bind_data); - } -}; - -unique_ptr StringAggBind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments.size() == 1) { - // single argument: default to comma - return make_uniq(","); - } - D_ASSERT(arguments.size() == 2); - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("Separator argument to StringAgg must be a constant"); - } - auto separator_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - string separator_string = ","; - if (separator_val.IsNull()) { - arguments[0] = make_uniq(Value(LogicalType::VARCHAR)); - } else { - separator_string = separator_val.ToString(); - } - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(std::move(separator_string)); -} - -static void StringAggSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "separator", bind_data.sep); -} - -unique_ptr StringAggDeserialize(Deserializer &deserializer, AggregateFunction &bound_function) { - auto sep = deserializer.ReadProperty(100, 
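
`PerformOperation` above grows the state buffer geometrically, doubling until the required size fits, so appending n strings costs amortized O(total bytes) rather than O(n · total). A standalone sketch of the same growth strategy, using a `has_value` flag where the original uses a null `dataptr` (all names illustrative):

```cpp
#include <cstring>
#include <string>

// Sketch of the doubling growth used by string_agg's PerformOperation.
struct AggBuffer {
	char *data = nullptr;
	size_t size = 0, cap = 0;
	bool has_value = false;

	void Append(const std::string &sep, const char *str, size_t len) {
		const size_t sep_len = has_value ? sep.size() : 0; // no separator before the first value
		const size_t required = size + sep_len + len;
		if (required > cap || !data) {
			size_t new_cap = cap ? cap : 8;
			while (new_cap < required) {
				new_cap *= 2; // double until it fits
			}
			char *fresh = new char[new_cap];
			std::memcpy(fresh, data, size);
			delete[] data;
			data = fresh;
			cap = new_cap;
		}
		std::memcpy(data + size, sep.data(), sep_len);
		size += sep_len;
		std::memcpy(data + size, str, len);
		size += len;
		has_value = true;
	}
	~AggBuffer() {
		delete[] data;
	}
};
```
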
"separator"); - return make_uniq(std::move(sep)); -} - -AggregateFunctionSet StringAggFun::GetFunctions() { - AggregateFunctionSet string_agg; - AggregateFunction string_agg_param( - {LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR, - AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - AggregateFunction::UnaryUpdate, StringAggBind, - AggregateFunction::StateDestroy); - string_agg_param.serialize = StringAggSerialize; - string_agg_param.deserialize = StringAggDeserialize; - string_agg.AddFunction(string_agg_param); - string_agg_param.arguments.emplace_back(LogicalType::VARCHAR); - string_agg.AddFunction(string_agg_param); - return string_agg; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp b/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp deleted file mode 100644 index be37d5df1..000000000 --- a/src/duckdb/extension/core_functions/aggregate/distributive/sum.cpp +++ /dev/null @@ -1,245 +0,0 @@ -#include "core_functions/aggregate/distributive_functions.hpp" -#include "core_functions/aggregate/sum_helpers.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -namespace duckdb { - -struct SumSetOperation { - template - static void Initialize(STATE &state) { - state.Initialize(); - } - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.Combine(source); - } - template - static void AddValues(STATE &state, idx_t count) { - state.isset = true; - } -}; - -struct IntegerSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = Hugeint::Convert(state.value); - } - } -}; - -struct SumToHugeintOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -template -struct DoubleSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -using NumericSumOperation = DoubleSumOperation; -using KahanSumOperation = DoubleSumOperation; - -struct HugeintSumOperation : public BaseSumOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -unique_ptr SumNoOverflowBind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - throw BinderException("sum_no_overflow is for internal use only!"); -} - -void SumNoOverflowSerialize(Serializer &serializer, const optional_ptr bind_data, - const AggregateFunction &function) { - return; -} - -unique_ptr SumNoOverflowDeserialize(Deserializer &deserializer, AggregateFunction &function) { - function.return_type = deserializer.Get(); - return nullptr; -} - -AggregateFunction GetSumAggregateNoOverflow(PhysicalType type) { - switch (type) { - case PhysicalType::INT32: { 
- auto function = AggregateFunction::UnaryAggregate, int32_t, hugeint_t, IntegerSumOperation>( - LogicalType::INTEGER, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.bind = SumNoOverflowBind; - function.serialize = SumNoOverflowSerialize; - function.deserialize = SumNoOverflowDeserialize; - return function; - } - case PhysicalType::INT64: { - auto function = AggregateFunction::UnaryAggregate, int64_t, hugeint_t, IntegerSumOperation>( - LogicalType::BIGINT, LogicalType::HUGEINT); - function.name = "sum_no_overflow"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.bind = SumNoOverflowBind; - function.serialize = SumNoOverflowSerialize; - function.deserialize = SumNoOverflowDeserialize; - return function; - } - default: - throw BinderException("Unsupported internal type for sum_no_overflow"); - } -} - -AggregateFunction GetSumAggregateNoOverflowDecimal() { - AggregateFunction aggr({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, nullptr, - nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, SumNoOverflowBind); - aggr.serialize = SumNoOverflowSerialize; - aggr.deserialize = SumNoOverflowDeserialize; - return aggr; -} - -unique_ptr SumPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - if (input.node_stats && input.node_stats->has_max_cardinality) { - auto &numeric_stats = input.child_stats[0]; - if (!NumericStats::HasMinMax(numeric_stats)) { - return nullptr; - } - auto internal_type = numeric_stats.GetType().InternalType(); - hugeint_t max_negative; - hugeint_t max_positive; - switch (internal_type) { - case PhysicalType::INT32: - max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); - max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); - break; - case PhysicalType::INT64: - max_negative = NumericStats::Min(numeric_stats).GetValueUnsafe(); - max_positive = NumericStats::Max(numeric_stats).GetValueUnsafe(); - break; - default: - throw InternalException("Unsupported type for propagate sum stats"); - } - auto max_sum_negative = max_negative * Hugeint::Convert(input.node_stats->max_cardinality); - auto max_sum_positive = max_positive * Hugeint::Convert(input.node_stats->max_cardinality); - if (max_sum_positive >= NumericLimits::Maximum() || - max_sum_negative <= NumericLimits::Minimum()) { - // sum can potentially exceed int64_t bounds: use hugeint sum - return nullptr; - } - // total sum is guaranteed to fit in a single int64: use int64 sum instead of hugeint sum - expr.function = GetSumAggregateNoOverflow(internal_type); - } - return nullptr; -} - -AggregateFunction GetSumAggregate(PhysicalType type) { - switch (type) { - case PhysicalType::BOOL: { - auto function = AggregateFunction::UnaryAggregate, bool, hugeint_t, IntegerSumOperation>( - LogicalType::BOOLEAN, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT16: { - auto function = AggregateFunction::UnaryAggregate, int16_t, hugeint_t, IntegerSumOperation>( - LogicalType::SMALLINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - - case PhysicalType::INT32: { - auto function = - AggregateFunction::UnaryAggregate, int32_t, hugeint_t, SumToHugeintOperation>( - LogicalType::INTEGER, LogicalType::HUGEINT); - 
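
The arithmetic in `SumPropagateStats` is worth spelling out: the worst-case sum magnitude is the most extreme column value times the maximum possible row count, computed in 128-bit so the check itself cannot overflow. A sketch of the safety test, using the GCC/Clang `__int128` extension where DuckDB uses `hugeint_t` (function name is hypothetical):

```cpp
#include <cstdint>
#include <limits>

// If even the most extreme column value summed over the maximum row count
// stays strictly inside int64_t, the cheaper int64 accumulator cannot overflow.
bool Int64SumIsSafe(int64_t col_min, int64_t col_max, uint64_t max_cardinality) {
	const __int128 max_negative = static_cast<__int128>(col_min) * static_cast<__int128>(max_cardinality);
	const __int128 max_positive = static_cast<__int128>(col_max) * static_cast<__int128>(max_cardinality);
	return max_positive < std::numeric_limits<int64_t>::max() &&
	       max_negative > std::numeric_limits<int64_t>::min();
}
```
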
function.statistics = SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT64: { - auto function = - AggregateFunction::UnaryAggregate, int64_t, hugeint_t, SumToHugeintOperation>( - LogicalType::BIGINT, LogicalType::HUGEINT); - function.statistics = SumPropagateStats; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - case PhysicalType::INT128: { - auto function = - AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, HugeintSumOperation>( - LogicalType::HUGEINT, LogicalType::HUGEINT); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return function; - } - default: - throw InternalException("Unimplemented sum aggregate"); - } -} - -unique_ptr BindDecimalSum(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - function = GetSumAggregate(decimal_type.InternalType()); - function.name = "sum"; - function.arguments[0] = decimal_type; - function.return_type = LogicalType::DECIMAL(Decimal::MAX_WIDTH_DECIMAL, DecimalType::GetScale(decimal_type)); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return nullptr; -} - -AggregateFunctionSet SumFun::GetFunctions() { - AggregateFunctionSet sum; - // decimal - sum.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, FunctionNullHandling::DEFAULT_NULL_HANDLING, nullptr, - BindDecimalSum)); - sum.AddFunction(GetSumAggregate(PhysicalType::BOOL)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT16)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT32)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT64)); - sum.AddFunction(GetSumAggregate(PhysicalType::INT128)); - sum.AddFunction(AggregateFunction::UnaryAggregate, double, double, NumericSumOperation>( - LogicalType::DOUBLE, LogicalType::DOUBLE)); - return sum; -} - -AggregateFunction CountIfFun::GetFunction() { - return GetSumAggregate(PhysicalType::BOOL); -} - -AggregateFunctionSet SumNoOverflowFun::GetFunctions() { - AggregateFunctionSet sum_no_overflow; - sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT32)); - sum_no_overflow.AddFunction(GetSumAggregateNoOverflow(PhysicalType::INT64)); - sum_no_overflow.AddFunction(GetSumAggregateNoOverflowDecimal()); - return sum_no_overflow; -} - -AggregateFunction KahanSumFun::GetFunction() { - return AggregateFunction::UnaryAggregate(LogicalType::DOUBLE, - LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp deleted file mode 100644 index 4eb2b9d30..000000000 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approx_top_k.cpp +++ /dev/null @@ -1,413 +0,0 @@ -#include "core_functions/aggregate/histogram_helpers.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "duckdb/function/aggregate/sort_key_helpers.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/common/string_map_set.hpp" -#include "duckdb/common/printer.hpp" - -namespace duckdb { - -struct ApproxTopKString { - ApproxTopKString() : str(UINT32_C(0)), hash(0) { - } - ApproxTopKString(string_t str_p, hash_t hash_p) : str(str_p), hash(hash_p) { - } - - string_t str; - hash_t hash; -}; - -struct ApproxTopKHash { - std::size_t 
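
`KahanSumOperation` (exposed via `KahanSumFun`) relies on compensated addition: an error term captures the low-order bits that a plain floating-point addition discards. The classic algorithm, matching the `value`/`err` pair carried in the state (standalone sketch, not DuckDB's `KahanAddInternal` signature):

```cpp
// Classic Kahan (compensated) summation.
struct KahanAcc {
	double value = 0.0;
	double err = 0.0;

	void Add(double input) {
		const double diff = input - err;  // correct for the previously lost bits
		const double next = value + diff; // big + small: low-order bits are lost...
		err = (next - value) - diff;      // ...and recovered here
		value = next;
	}
};
```
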
operator()(const ApproxTopKString &k) const { - return k.hash; - } -}; - -struct ApproxTopKEquality { - bool operator()(const ApproxTopKString &a, const ApproxTopKString &b) const { - return Equals::Operation(a.str, b.str); - } -}; - -template -using approx_topk_map_t = unordered_map; - -// approx top k algorithm based on "A parallel space saving algorithm for frequent items and the Hurwitz zeta -// distribution" arxiv link - https://arxiv.org/pdf/1401.0702 -// together with the filter extension (Filtered Space-Saving) from "Estimating Top-k Destinations in Data Streams" -struct ApproxTopKValue { - //! The counter - idx_t count = 0; - //! Index in the values array - idx_t index = 0; - //! The string value - ApproxTopKString str_val; - //! Allocated data - char *dataptr = nullptr; - uint32_t size = 0; - uint32_t capacity = 0; -}; - -struct InternalApproxTopKState { - // the top-k data structure has two components - // a list of k values sorted on "count" (i.e. values[0] has the lowest count) - // a lookup map: string_t -> idx in "values" array - unsafe_unique_array stored_values; - unsafe_vector> values; - approx_topk_map_t> lookup_map; - unsafe_vector filter; - idx_t k = 0; - idx_t capacity = 0; - idx_t filter_mask; - - void Initialize(idx_t kval) { - static constexpr idx_t MONITORED_VALUES_RATIO = 3; - static constexpr idx_t FILTER_RATIO = 8; - - D_ASSERT(values.empty()); - D_ASSERT(lookup_map.empty()); - k = kval; - capacity = kval * MONITORED_VALUES_RATIO; - stored_values = make_unsafe_uniq_array_uninitialized(capacity); - values.reserve(capacity); - - // we scale the filter based on the amount of values we are monitoring - idx_t filter_size = NextPowerOfTwo(capacity * FILTER_RATIO); - filter_mask = filter_size - 1; - filter.resize(filter_size); - } - - static void CopyValue(ApproxTopKValue &value, const ApproxTopKString &input, AggregateInputData &input_data) { - value.str_val.hash = input.hash; - if (input.str.IsInlined()) { - // no need to copy - value.str_val = input; - return; - } - value.size = UnsafeNumericCast(input.str.GetSize()); - if (value.size > value.capacity) { - // need to re-allocate for this value - value.capacity = UnsafeNumericCast(NextPowerOfTwo(value.size)); - value.dataptr = char_ptr_cast(input_data.allocator.Allocate(value.capacity)); - } - // copy over the data - memcpy(value.dataptr, input.str.GetData(), value.size); - value.str_val.str = string_t(value.dataptr, value.size); - } - - void InsertOrReplaceEntry(const ApproxTopKString &input, AggregateInputData &aggr_input, idx_t increment = 1) { - if (values.size() < capacity) { - D_ASSERT(increment > 0); - // we can always add this entry - auto &val = stored_values[values.size()]; - val.index = values.size(); - values.push_back(val); - } - auto &value = values.back().get(); - if (value.count > 0) { - // the capacity is reached - we need to replace an entry - - // we use the filter as an early out - // based on the hash - we find a slot in the filter - // instead of monitoring the value immediately, we add to the slot in the filter - // ONLY when the value in the filter exceeds the current min value, we start monitoring the value - // this speeds up the algorithm as switching monitor values means we need to erase/insert in the hash table - auto &filter_value = filter[input.hash & filter_mask]; - if (filter_value + increment < value.count) { - // if the filter has a lower count than the current min count - // we can skip adding this entry (for now) - filter_value += increment; - return; - } - // the filter exceeds 
the min value - start monitoring this value - // erase the existing entry from the map - // and set the filter for the minimum value back to the current minimum value - filter[value.str_val.hash & filter_mask] = value.count; - lookup_map.erase(value.str_val); - } - CopyValue(value, input, aggr_input); - lookup_map.insert(make_pair(value.str_val, reference(value))); - IncrementCount(value, increment); - } - - void IncrementCount(ApproxTopKValue &value, idx_t increment = 1) { - value.count += increment; - // maintain sortedness of "values" - // swap while we have a higher count than the next entry - while (value.index > 0 && values[value.index].get().count > values[value.index - 1].get().count) { - // swap the elements around - auto &left = values[value.index]; - auto &right = values[value.index - 1]; - std::swap(left.get().index, right.get().index); - std::swap(left, right); - } - } - - void Verify() const { -#ifdef DEBUG - if (values.empty()) { - D_ASSERT(lookup_map.empty()); - return; - } - D_ASSERT(values.size() <= capacity); - for (idx_t k = 0; k < values.size(); k++) { - auto &val = values[k].get(); - D_ASSERT(val.count > 0); - // verify map exists - auto entry = lookup_map.find(val.str_val); - D_ASSERT(entry != lookup_map.end()); - // verify the index is correct - D_ASSERT(val.index == k); - if (k > 0) { - // sortedness - D_ASSERT(val.count <= values[k - 1].get().count); - } - } - // verify lookup map does not contain extra entries - D_ASSERT(lookup_map.size() == values.size()); -#endif - } -}; - -struct ApproxTopKState { - InternalApproxTopKState *state; - - InternalApproxTopKState &GetState() { - if (!state) { - state = new InternalApproxTopKState(); - } - return *state; - } - - const InternalApproxTopKState &GetState() const { - if (!state) { - throw InternalException("No state available"); - } - return *state; - } -}; - -struct ApproxTopKOperation { - template - static void Initialize(STATE &state) { - state.state = nullptr; - } - - template - static void Operation(STATE &aggr_state, const TYPE &input, AggregateInputData &aggr_input, Vector &top_k_vector, - idx_t offset, idx_t count) { - auto &state = aggr_state.GetState(); - if (state.values.empty()) { - static constexpr int64_t MAX_APPROX_K = 1000000; - // not initialized yet - initialize the K value and set all counters to 0 - UnifiedVectorFormat kdata; - top_k_vector.ToUnifiedFormat(count, kdata); - auto kidx = kdata.sel->get_index(offset); - if (!kdata.validity.RowIsValid(kidx)) { - throw InvalidInputException("Invalid input for approx_top_k: k value cannot be NULL"); - } - auto kval = UnifiedVectorFormat::GetData(kdata)[kidx]; - if (kval <= 0) { - throw InvalidInputException("Invalid input for approx_top_k: k value must be > 0"); - } - if (kval >= MAX_APPROX_K) { - throw InvalidInputException("Invalid input for approx_top_k: k value must be < %d", MAX_APPROX_K); - } - state.Initialize(UnsafeNumericCast(kval)); - } - ApproxTopKString topk_string(input, Hash(input)); - auto entry = state.lookup_map.find(topk_string); - if (entry != state.lookup_map.end()) { - // the input is monitored - increment the count - state.IncrementCount(entry->second.get()); - } else { - // the input is not monitored - replace the first entry with the current entry and increment - state.InsertOrReplaceEntry(topk_string, aggr_input); - } - } - - template - static void Combine(const STATE &aggr_source, STATE &aggr_target, AggregateInputData &aggr_input) { - if (!aggr_source.state) { - // source state is empty - return; - } - auto &source = 
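
Condensing the structure above: Filtered Space-Saving keeps `capacity` monitored counters, and the hash-indexed filter absorbs increments from unmonitored values until one of them plausibly outgrows the current minimum, at which point it evicts the minimum and inherits its count. A deliberately simplified single-threaded toy (linear minimum scan instead of the sorted `values` array, plain `std::unordered_map` instead of the arena-backed map; all names hypothetical):

```cpp
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Toy Filtered Space-Saving sketch for strings.
struct SpaceSavingSketch {
	explicit SpaceSavingSketch(size_t k) : capacity(k), filter(4096, 0) {
	}

	void Add(const std::string &v) {
		auto it = counts.find(v);
		if (it != counts.end()) { // already monitored: just count
			it->second++;
			return;
		}
		if (counts.size() < capacity) { // room left: start monitoring
			counts[v] = 1;
			return;
		}
		auto min_it = counts.begin(); // linear scan; the real code keeps counters sorted
		for (auto i = counts.begin(); i != counts.end(); ++i) {
			if (i->second < min_it->second) {
				min_it = i;
			}
		}
		auto &slot = filter[Slot(v)];
		if (slot + 1 < min_it->second) { // early out: accumulate in the filter instead
			slot++;
			return;
		}
		filter[Slot(min_it->first)] = min_it->second; // remember the evicted count
		const size_t inherited = min_it->second + 1;  // new value inherits min + 1
		counts.erase(min_it);
		counts[v] = inherited;
	}

	size_t Slot(const std::string &s) const {
		return std::hash<std::string> {}(s) & (filter.size() - 1); // filter size is a power of two
	}

	size_t capacity;
	std::unordered_map<std::string, size_t> counts;
	std::vector<size_t> filter;
};
```
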
aggr_source.GetState(); - auto &target = aggr_target.GetState(); - if (source.values.empty()) { - // source is empty - return; - } - source.Verify(); - auto min_source = source.values.back().get().count; - idx_t min_target; - if (target.values.empty()) { - min_target = 0; - target.Initialize(source.k); - } else { - if (source.k != target.k) { - throw NotImplementedException("Approx Top K - cannot combine approx_top_K with different k values. " - "K values must be the same for all entries within the same group"); - } - min_target = target.values.back().get().count; - } - // for all entries in target - // check if they are tracked in source - // if they do - add the tracked count - // if they do not - add the minimum count - for (idx_t target_idx = 0; target_idx < target.values.size(); target_idx++) { - auto &val = target.values[target_idx].get(); - auto source_entry = source.lookup_map.find(val.str_val); - idx_t increment = min_source; - if (source_entry != source.lookup_map.end()) { - increment = source_entry->second.get().count; - } - if (increment == 0) { - continue; - } - target.IncrementCount(val, increment); - } - // now for each entry in source, if it is not tracked by the target, at the target minimum - for (auto &source_entry : source.values) { - auto &source_val = source_entry.get(); - auto target_entry = target.lookup_map.find(source_val.str_val); - if (target_entry != target.lookup_map.end()) { - // already tracked - no need to add anything - continue; - } - auto new_count = source_val.count + min_target; - idx_t increment; - if (target.values.size() >= target.capacity) { - idx_t current_min = target.values.empty() ? 0 : target.values.back().get().count; - D_ASSERT(target.values.size() == target.capacity); - // target already has capacity values - // check if we should insert this entry - if (new_count <= current_min) { - // if we do not we can skip this entry - continue; - } - increment = new_count - current_min; - } else { - // target does not have capacity entries yet - // just add this entry with the full count - increment = new_count; - } - target.InsertOrReplaceEntry(source_val.str_val, aggr_input, increment); - } - // copy over the filter - D_ASSERT(source.filter.size() == target.filter.size()); - for (idx_t filter_idx = 0; filter_idx < source.filter.size(); filter_idx++) { - target.filter[filter_idx] += source.filter[filter_idx]; - } - target.Verify(); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - delete state.state; - } - - static bool IgnoreNull() { - return true; - } -}; - -template -static void ApproxTopKUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, - idx_t count) { - using STATE = ApproxTopKState; - auto &input = inputs[0]; - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto &top_k_vector = inputs[1]; - - auto extra_state = OP::CreateExtraState(count); - UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); - - auto states = UnifiedVectorFormat::GetData(sdata); - auto data = UnifiedVectorFormat::GetData(input_data); - for (idx_t i = 0; i < count; i++) { - auto idx = input_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(idx)) { - continue; - } - auto &state = *states[sdata.sel->get_index(i)]; - ApproxTopKOperation::Operation(state, data[idx], aggr_input, top_k_vector, i, count); - } -} - -template -static void ApproxTopKFinalize(Vector &state_vector, AggregateInputData &, Vector &result, 
idx_t count, idx_t offset) { - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData(sdata); - - auto &mask = FlatVector::Validity(result); - auto old_len = ListVector::GetListSize(result); - idx_t new_entries = 0; - // figure out how much space we need - for (idx_t i = 0; i < count; i++) { - auto &state = states[sdata.sel->get_index(i)]->GetState(); - if (state.values.empty()) { - continue; - } - // get up to k values for each state - // this can be less of fewer unique values were found - new_entries += MinValue(state.values.size(), state.k); - } - // reserve space in the list vector - ListVector::Reserve(result, old_len + new_entries); - auto list_entries = FlatVector::GetData(result); - auto &child_data = ListVector::GetEntry(result); - - idx_t current_offset = old_len; - for (idx_t i = 0; i < count; i++) { - const auto rid = i + offset; - auto &state = states[sdata.sel->get_index(i)]->GetState(); - if (state.values.empty()) { - mask.SetInvalid(rid); - continue; - } - auto &list_entry = list_entries[rid]; - list_entry.offset = current_offset; - for (idx_t val_idx = 0; val_idx < MinValue(state.values.size(), state.k); val_idx++) { - auto &val = state.values[val_idx].get(); - D_ASSERT(val.count > 0); - OP::template HistogramFinalize(val.str_val.str, child_data, current_offset); - current_offset++; - } - list_entry.length = current_offset - list_entry.offset; - } - D_ASSERT(current_offset == old_len + new_entries); - ListVector::SetListSize(result, current_offset); - result.Verify(count); -} - -unique_ptr ApproxTopKBind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - for (auto &arg : arguments) { - if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - } - if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { - function.update = ApproxTopKUpdate; - function.finalize = ApproxTopKFinalize; - } - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return nullptr; -} - -AggregateFunction ApproxTopKFun::GetFunction() { - using STATE = ApproxTopKState; - using OP = ApproxTopKOperation; - return AggregateFunction("approx_top_k", {LogicalTypeId::ANY, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::ANY), AggregateFunction::StateSize, - AggregateFunction::StateInitialize, ApproxTopKUpdate, - AggregateFunction::StateCombine, ApproxTopKFinalize, nullptr, ApproxTopKBind, - AggregateFunction::StateDestroy); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp deleted file mode 100644 index 23d2cf479..000000000 --- a/src/duckdb/extension/core_functions/aggregate/holistic/approximate_quantile.cpp +++ /dev/null @@ -1,444 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "t_digest.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -#include -#include -#include - -namespace duckdb { - -struct ApproxQuantileState { - duckdb_tdigest::TDigest *h; - idx_t pos; -}; - -struct ApproximateQuantileBindData : public FunctionData { - ApproximateQuantileBindData() { - } - explicit ApproximateQuantileBindData(float quantile_p) : quantiles(1, quantile_p) { - } - - 
explicit ApproximateQuantileBindData(vector quantiles_p) : quantiles(std::move(quantiles_p)) { - } - - unique_ptr Copy() const override { - return make_uniq(quantiles); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - // return quantiles == other.quantiles; - if (quantiles != other.quantiles) { - return false; - } - return true; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "quantiles", bind_data.quantiles); - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - deserializer.ReadProperty(100, "quantiles", result->quantiles); - return std::move(result); - } - - vector quantiles; -}; - -struct ApproxQuantileOperation { - using SAVE_TYPE = duckdb_tdigest::Value; - - template - static void Initialize(STATE &state) { - state.pos = 0; - state.h = nullptr; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - auto val = Cast::template Operation(input); - if (!Value::DoubleIsFinite(val)) { - return; - } - if (!state.h) { - state.h = new duckdb_tdigest::TDigest(100); - } - state.h->add(val); - state.pos++; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.pos == 0) { - return; - } - D_ASSERT(source.h); - if (!target.h) { - target.h = new duckdb_tdigest::TDigest(100); - } - target.h->merge(source.h); - target.pos += source.pos; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.h) { - delete state.h; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct ApproxQuantileScalarOperation : public ApproxQuantileOperation { - template - static void Finalize(STATE &state, TARGET_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(state.h); - D_ASSERT(finalize_data.input.bind_data); - state.h->compress(); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - // The result is approximate, so clamp instead of overflowing. 
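
The t-digest lifecycle used by this operation is: lazily construct with compression parameter 100, `add()` each non-NULL value, `merge()` partial digests in `Combine`, and `compress()` once before querying. A minimal sketch of that flow, assuming the vendored `t_digest.hpp` header is on the include path:

```cpp
#include "t_digest.hpp" // vendored header used by this file

#include <cstddef>

// Lifecycle sketch: accumulate, compress once, then query a quantile.
double ApproxMedian(const double *values, size_t count) {
	duckdb_tdigest::TDigest digest(100); // 100 = compression parameter
	for (size_t i = 0; i < count; i++) {
		digest.add(values[i]); // one call per non-NULL row
	}
	digest.compress(); // done once at finalize time
	return digest.quantile(0.5);
}
```
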
- const auto source = state.h->quantile(bind_data.quantiles[0]); - if (TryCast::Operation(source, target, false)) { - return; - } else if (source < 0) { - target = NumericLimits::Minimum(); - } else { - target = NumericLimits::Maximum(); - } - } -}; - -static AggregateFunction GetApproximateQuantileAggregateFunction(const LogicalType &type) { - // Not binary comparable - if (type == LogicalType::TIME_TZ) { - return AggregateFunction::UnaryAggregateDestructor(type, type); - } - switch (type.InternalType()) { - case PhysicalType::INT8: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor(type, type); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor(type, type); - default: - throw InternalException("Unimplemented quantile aggregate"); - } -} - -static AggregateFunction GetApproximateQuantileDecimalAggregateFunction(const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::INT8: - return GetApproximateQuantileAggregateFunction(LogicalType::TINYINT); - case PhysicalType::INT16: - return GetApproximateQuantileAggregateFunction(LogicalType::SMALLINT); - case PhysicalType::INT32: - return GetApproximateQuantileAggregateFunction(LogicalType::INTEGER); - case PhysicalType::INT64: - return GetApproximateQuantileAggregateFunction(LogicalType::BIGINT); - case PhysicalType::INT128: - return GetApproximateQuantileAggregateFunction(LogicalType::HUGEINT); - default: - throw InternalException("Unimplemented quantile decimal aggregate"); - } -} - -static float CheckApproxQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("APPROXIMATE QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < 0 || quantile > 1) { - throw BinderException("APPROXIMATE QUANTILE can only take parameters in range [0, 1]"); - } - - return quantile; -} - -unique_ptr BindApproxQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("APPROXIMATE QUANTILE can only take constant quantile parameters"); - } - Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - if (quantile_val.IsNull()) { - throw BinderException("APPROXIMATE QUANTILE parameter list cannot be NULL"); - } - - vector quantiles; - switch (quantile_val.type().id()) { - case LogicalTypeId::LIST: - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckApproxQuantile(element_val)); - } - break; - case LogicalTypeId::ARRAY: - for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckApproxQuantile(element_val)); - } - break; - default: - quantiles.push_back(CheckApproxQuantile(quantile_val)); - break; - } - - // remove the quantile argument so we can use the unary aggregate - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles); -} - -AggregateFunction 
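
Because the quantile estimate is already approximate, the finalizer saturates rather than raising a cast error when the t-digest's `double` falls outside the integer result type. One way to write the same saturation as a standalone helper (hypothetical name; the original goes through `TryCast` plus a sign check):

```cpp
#include <limits>

// Saturating cast: out-of-range doubles clamp to the result type's limits.
template <class T>
T ClampedCast(double source) {
	if (source <= static_cast<double>(std::numeric_limits<T>::min())) {
		return std::numeric_limits<T>::min();
	}
	if (source >= static_cast<double>(std::numeric_limits<T>::max())) {
		return std::numeric_limits<T>::max();
	}
	return static_cast<T>(source);
}
```
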
ApproxQuantileDecimalFunction(const LogicalType &type) { - auto function = GetApproximateQuantileDecimalAggregateFunction(type); - function.name = "approx_quantile"; - function.serialize = ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; - return function; -} - -unique_ptr BindApproxQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindApproxQuantile(context, function, arguments); - function = ApproxQuantileDecimalFunction(arguments[0]->return_type); - return bind_data; -} - -AggregateFunction GetApproximateQuantileAggregate(const LogicalType &type) { - auto fun = GetApproximateQuantileAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::FLOAT); - return fun; -} - -template -struct ApproxQuantileListOperation : public ApproxQuantileOperation { - - template - static void Finalize(STATE &state, RESULT_TYPE &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - D_ASSERT(state.h); - state.h->compress(); - - auto &entry = target; - entry.offset = ridx; - entry.length = bind_data.quantiles.size(); - for (size_t q = 0; q < entry.length; ++q) { - const auto &quantile = bind_data.quantiles[q]; - rdata[ridx + q] = Cast::template Operation(state.h->quantile(quantile)); - } - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } -}; - -template -static AggregateFunction ApproxQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { - LogicalType result_type = LogicalType::LIST(child_type); - return AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetTypedApproxQuantileListAggregateFunction(const LogicalType &type) { - using STATE = ApproxQuantileState; - using OP = ApproxQuantileListOperation; - auto fun = ApproxQuantileListAggregate(type, type); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - return fun; -} - -AggregateFunction GetApproxQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedApproxQuantileListAggregateFunction(type); - case 
LogicalTypeId::TIME_TZ: - // Not binary comparable - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedApproxQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedApproxQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedApproxQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented approximate quantile list decimal aggregate"); - } - default: - throw NotImplementedException("Unimplemented approximate quantile list aggregate"); - } -} - -AggregateFunction ApproxQuantileDecimalListFunction(const LogicalType &type) { - auto function = GetApproxQuantileListAggregateFunction(type); - function.name = "approx_quantile"; - function.serialize = ApproximateQuantileBindData::Serialize; - function.deserialize = ApproximateQuantileBindData::Deserialize; - return function; -} - -unique_ptr BindApproxQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto bind_data = BindApproxQuantile(context, function, arguments); - function = ApproxQuantileDecimalListFunction(arguments[0]->return_type); - return bind_data; -} - -AggregateFunction GetApproxQuantileListAggregate(const LogicalType &type) { - auto fun = GetApproxQuantileListAggregateFunction(type); - fun.bind = BindApproxQuantile; - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproximateQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_float = LogicalType::LIST(LogicalType::FLOAT); - fun.arguments.push_back(list_of_float); - return fun; -} - -unique_ptr ApproxQuantileDecimalDeserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = ApproximateQuantileBindData::Deserialize(deserializer, function); - auto &return_type = deserializer.Get(); - if (return_type.id() == LogicalTypeId::LIST) { - function = ApproxQuantileDecimalListFunction(function.arguments[0]); - } else { - function = ApproxQuantileDecimalFunction(function.arguments[0]); - } - return bind_data; -} - -AggregateFunction GetApproxQuantileDecimal() { - // stub function - the actual function is set during bind or deserialize - AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::FLOAT}, LogicalTypeId::DECIMAL, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, BindApproxQuantileDecimal); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproxQuantileDecimalDeserialize; - return fun; -} - -AggregateFunction GetApproxQuantileDecimalList() { - // stub function - the actual function is set during bind or deserialize - AggregateFunction fun({LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::FLOAT)}, - LogicalType::LIST(LogicalTypeId::DECIMAL), nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindApproxQuantileDecimalList); - fun.serialize = ApproximateQuantileBindData::Serialize; - fun.deserialize = ApproxQuantileDecimalDeserialize; - return fun; -} - -AggregateFunctionSet ApproxQuantileFun::GetFunctions() { - 
AggregateFunctionSet approx_quantile; - approx_quantile.AddFunction(GetApproxQuantileDecimal()); - - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::SMALLINT)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::INTEGER)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::BIGINT)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::HUGEINT)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DOUBLE)); - - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::DATE)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIME_TZ)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP)); - approx_quantile.AddFunction(GetApproximateQuantileAggregate(LogicalType::TIMESTAMP_TZ)); - - // List variants - approx_quantile.AddFunction(GetApproxQuantileDecimalList()); - - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::TINYINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::SMALLINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::INTEGER)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::BIGINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::HUGEINT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::FLOAT)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalTypeId::DOUBLE)); - - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::DATE)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIME_TZ)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP)); - approx_quantile.AddFunction(GetApproxQuantileListAggregate(LogicalType::TIMESTAMP_TZ)); - - return approx_quantile; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp deleted file mode 100644 index dedb74297..000000000 --- a/src/duckdb/extension/core_functions/aggregate/holistic/mad.cpp +++ /dev/null @@ -1,345 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/abs.hpp" -#include "core_functions/aggregate/quantile_state.hpp" - -namespace duckdb { - -struct FrameSet { - inline explicit FrameSet(const SubFrames &frames_p) : frames(frames_p) { - } - - inline idx_t Size() const { - idx_t result = 0; - for (const auto &frame : frames) { - result += frame.end - frame.start; - } - - return result; - } - - inline bool Contains(idx_t i) const { - for (idx_t f = 0; f < frames.size(); ++f) { - const auto &frame = frames[f]; - if (frame.start <= i && i < frame.end) { - return true; - } - } - return false; - } - const SubFrames &frames; -}; - -struct QuantileReuseUpdater { - idx_t *index; - idx_t j; - - inline QuantileReuseUpdater(idx_t *index, idx_t j) : index(index), j(j) { - } - - inline void Neither(idx_t begin, idx_t end) { - } - - inline void Left(idx_t begin, idx_t end) { - } - - inline void Right(idx_t begin, idx_t end) { - for 
(; begin < end; ++begin) { - index[j++] = begin; - } - } - - inline void Both(idx_t begin, idx_t end) { - } -}; - -void ReuseIndexes(idx_t *index, const SubFrames &currs, const SubFrames &prevs) { - - // Copy overlapping indices by scanning the previous set and copying down into holes. - // We copy instead of leaving gaps in case there are fewer values in the current frame. - FrameSet prev_set(prevs); - FrameSet curr_set(currs); - const auto prev_count = prev_set.Size(); - idx_t j = 0; - for (idx_t p = 0; p < prev_count; ++p) { - auto idx = index[p]; - - // Shift down into any hole - if (j != p) { - index[j] = idx; - } - - // Skip overlapping values - if (curr_set.Contains(idx)) { - ++j; - } - } - - // Insert new indices - if (j > 0) { - QuantileReuseUpdater updater(index, j); - AggregateExecutor::IntersectFrames(prevs, currs, updater); - } else { - // No overlap: overwrite with new values - for (const auto &curr : currs) { - for (auto idx = curr.start; idx < curr.end; ++idx) { - index[j++] = idx; - } - } - } -} - -//===--------------------------------------------------------------------===// -// Median Absolute Deviation -//===--------------------------------------------------------------------===// -template -struct MadAccessor { - using INPUT_TYPE = T; - using RESULT_TYPE = R; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const RESULT_TYPE delta = input - UnsafeNumericCast(median); - return TryAbsOperator::Operation(delta); - } -}; - -// hugeint_t - double => undefined -template <> -struct MadAccessor { - using INPUT_TYPE = hugeint_t; - using RESULT_TYPE = double; - using MEDIAN_TYPE = double; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = Hugeint::Cast(input) - median; - return TryAbsOperator::Operation(delta); - } -}; - -// date_t - timestamp_t => interval_t -template <> -struct MadAccessor { - using INPUT_TYPE = date_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto dt = Cast::Operation(input); - const auto delta = dt - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -// timestamp_t - timestamp_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = timestamp_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = timestamp_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -// dtime_t - dtime_t => int64_t -template <> -struct MadAccessor { - using INPUT_TYPE = dtime_t; - using RESULT_TYPE = interval_t; - using MEDIAN_TYPE = dtime_t; - const MEDIAN_TYPE &median; - explicit MadAccessor(const MEDIAN_TYPE &median_p) : median(median_p) { - } - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - const auto delta = input - median; - return Interval::FromMicro(TryAbsOperator::Operation(delta)); - } -}; - -template -struct MedianAbsoluteDeviationOperation : QuantileOperation { - template - static void Finalize(STATE &state, T 
&target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - using INPUT_TYPE = typename STATE::InputType; - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &q = bind_data.quantiles[0]; - Interpolator interp(q, state.v.size(), false); - const auto med = interp.template Operation(state.v.data(), finalize_data.result); - - MadAccessor accessor(med); - target = interp.template Operation(state.v.data(), finalize_data.result, accessor); - } - - template - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, - idx_t ridx) { - auto &state = *reinterpret_cast(l_state); - auto gstate = reinterpret_cast(g_state); - - auto &data = state.GetOrCreateWindowCursor(partition); - const auto &fmask = partition.filter_mask; - - auto rdata = FlatVector::GetData(result); - - QuantileIncluded included(fmask, data); - const auto n = FrameSize(included, frames); - - if (!n) { - auto &rmask = FlatVector::Validity(result); - rmask.Set(ridx, false); - return; - } - - // Compute the median - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - D_ASSERT(bind_data.quantiles.size() == 1); - const auto &quantile = bind_data.quantiles[0]; - auto &window_state = state.GetOrCreateWindowState(); - MEDIAN_TYPE med; - if (gstate && gstate->HasTree()) { - med = gstate->GetWindowState().template WindowScalar(data, frames, n, result, quantile); - } else { - window_state.UpdateSkip(data, frames, included); - med = window_state.template WindowScalar(data, frames, n, result, quantile); - } - - // Lazily initialise frame state - window_state.SetCount(frames.back().end - frames.front().start); - auto index2 = window_state.m.data(); - D_ASSERT(index2); - - // The replacement trick does not work on the second index because if - // the median has changed, the previous order is not correct. - // It is probably close, however, and so reuse is helpful. 
- auto &prevs = window_state.prevs; - ReuseIndexes(index2, frames, prevs); - std::partition(index2, index2 + window_state.count, included); - - Interpolator interp(quantile, n, false); - - // Compute mad from the second index - using ID = QuantileIndirect; - ID indirect(data); - - using MAD = MadAccessor; - MAD mad(med); - - using MadIndirect = QuantileComposed; - MadIndirect mad_indirect(mad, indirect); - rdata[ridx] = interp.template Operation(index2, result, mad_indirect); - - // Prev is used by both skip lists and increments - prevs = frames; - } -}; - -unique_ptr BindMAD(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); -} - -template -AggregateFunction GetTypedMedianAbsoluteDeviationAggregateFunction(const LogicalType &input_type, - const LogicalType &target_type) { - using STATE = QuantileState; - using OP = MedianAbsoluteDeviationOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.bind = BindMAD; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; -#ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; -#endif - return fun; -} - -AggregateFunction GetMedianAbsoluteDeviationAggregateFunctionInternal(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::FLOAT: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case LogicalTypeId::DOUBLE: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT32: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT64: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - case PhysicalType::INT128: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, type); - default: - throw NotImplementedException("Unimplemented Median Absolute Deviation DECIMAL aggregate"); - } - break; - - case LogicalTypeId::DATE: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction( - type, LogicalType::INTERVAL); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return GetTypedMedianAbsoluteDeviationAggregateFunction(type, - LogicalType::INTERVAL); - - default: - throw NotImplementedException("Unimplemented Median Absolute Deviation aggregate"); - } -} - -AggregateFunction GetMedianAbsoluteDeviationAggregateFunction(const LogicalType &type) { - auto result = GetMedianAbsoluteDeviationAggregateFunctionInternal(type); - result.errors = FunctionErrors::CAN_THROW_RUNTIME_ERROR; - return result; -} - -unique_ptr BindMedianAbsoluteDeviationDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetMedianAbsoluteDeviationAggregateFunction(arguments[0]->return_type); - function.name = "mad"; - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return BindMAD(context, function, arguments); -} - -AggregateFunctionSet MadFun::GetFunctions() { - AggregateFunctionSet mad("mad"); - mad.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, 
BindMedianAbsoluteDeviationDecimal));
-
-	const vector<LogicalType> MAD_TYPES = {LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE,
-	                                       LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ,
-	                                       LogicalType::TIME_TZ};
-	for (const auto &type : MAD_TYPES) {
-		mad.AddFunction(GetMedianAbsoluteDeviationAggregateFunction(type));
-	}
-	return mad;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp
deleted file mode 100644
index 8c35fc8c9..000000000
--- a/src/duckdb/extension/core_functions/aggregate/holistic/mode.cpp
+++ /dev/null
@@ -1,573 +0,0 @@
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/uhugeint.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/common/operator/comparison_operators.hpp"
-#include "duckdb/common/types/column/column_data_collection.hpp"
-#include "core_functions/aggregate/distributive_functions.hpp"
-#include "core_functions/aggregate/holistic_functions.hpp"
-#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
-#include "duckdb/common/unordered_map.hpp"
-#include "duckdb/common/owning_string_map.hpp"
-#include "duckdb/function/create_sort_key.hpp"
-#include "duckdb/function/aggregate/sort_key_helpers.hpp"
-#include "duckdb/common/algorithm.hpp"
-#include <functional>
-
-// MODE( <expr1> )
-// Returns the most frequent value for the values within expr1.
-// NULL values are ignored. If all the values are NULL, or there are 0 rows, then the function returns NULL.
-
-namespace std {} // namespace std
-
-namespace duckdb {
-
-struct ModeAttr {
-	ModeAttr() : count(0), first_row(std::numeric_limits<idx_t>::max()) {
-	}
-	size_t count;
-	idx_t first_row;
-};
-
-template <class T>
-struct ModeStandard {
-	using MAP_TYPE = unordered_map<T, ModeAttr>;
-
-	static MAP_TYPE *CreateEmpty(ArenaAllocator &) {
-		return new MAP_TYPE();
-	}
-	static MAP_TYPE *CreateEmpty(Allocator &) {
-		return new MAP_TYPE();
-	}
-
-	template <class INPUT_TYPE, class RESULT_TYPE>
-	static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
-		return RESULT_TYPE(input);
-	}
-};
-
-struct ModeString {
-	using MAP_TYPE = OwningStringMap<ModeAttr>;
-
-	static MAP_TYPE *CreateEmpty(ArenaAllocator &allocator) {
-		return new MAP_TYPE(allocator);
-	}
-	static MAP_TYPE *CreateEmpty(Allocator &allocator) {
-		return new MAP_TYPE(allocator);
-	}
-
-	template <class INPUT_TYPE, class RESULT_TYPE>
-	static RESULT_TYPE Assign(Vector &result, INPUT_TYPE input) {
-		return StringVector::AddStringOrBlob(result, input);
-	}
-};
-
-template <typename KEY_TYPE, typename TYPE_OP>
-struct ModeState {
-	using Counts = typename TYPE_OP::MAP_TYPE;
-
-	ModeState() {
-	}
-
-	SubFrames prevs;
-	Counts *frequency_map = nullptr;
-	KEY_TYPE *mode = nullptr;
-	size_t nonzero = 0;
-	bool valid = false;
-	size_t count = 0;
-
-	//! The collection being read
-	const ColumnDataCollection *inputs;
-	//! The state used for reading the collection on this thread
-	ColumnDataScanState *scan = nullptr;
-	//! The data chunk paged into
-	DataChunk page;
-	//! The data pointer
-	const KEY_TYPE *data = nullptr;
-	//!
The validity mask - const ValidityMask *validity = nullptr; - - ~ModeState() { - if (frequency_map) { - delete frequency_map; - } - if (mode) { - delete mode; - } - if (scan) { - delete scan; - } - } - - void InitializePage(const WindowPartitionInput &partition) { - if (!scan) { - scan = new ColumnDataScanState(); - } - if (page.ColumnCount() == 0) { - D_ASSERT(partition.inputs); - inputs = partition.inputs; - D_ASSERT(partition.column_ids.size() == 1); - inputs->InitializeScan(*scan, partition.column_ids); - inputs->InitializeScanChunk(*scan, page); - } - } - - inline sel_t RowOffset(idx_t row_idx) const { - D_ASSERT(RowIsVisible(row_idx)); - return UnsafeNumericCast(row_idx - scan->current_row_index); - } - - inline bool RowIsVisible(idx_t row_idx) const { - return (row_idx < scan->next_row_index && scan->current_row_index <= row_idx); - } - - inline idx_t Seek(idx_t row_idx) { - if (!RowIsVisible(row_idx)) { - D_ASSERT(inputs); - inputs->Seek(row_idx, *scan, page); - data = FlatVector::GetData(page.data[0]); - validity = &FlatVector::Validity(page.data[0]); - } - return RowOffset(row_idx); - } - - inline const KEY_TYPE &GetCell(idx_t row_idx) { - const auto offset = Seek(row_idx); - return data[offset]; - } - - inline bool RowIsValid(idx_t row_idx) { - const auto offset = Seek(row_idx); - return validity->RowIsValid(offset); - } - - void Reset() { - if (frequency_map) { - frequency_map->clear(); - } - nonzero = 0; - count = 0; - valid = false; - } - - void ModeAdd(idx_t row) { - const auto &key = GetCell(row); - auto &attr = (*frequency_map)[key]; - auto new_count = (attr.count += 1); - if (new_count == 1) { - ++nonzero; - attr.first_row = row; - } else { - attr.first_row = MinValue(row, attr.first_row); - } - if (new_count > count) { - valid = true; - count = new_count; - if (mode) { - *mode = key; - } else { - mode = new KEY_TYPE(key); - } - } - } - - void ModeRm(idx_t frame) { - const auto &key = GetCell(frame); - auto &attr = (*frequency_map)[key]; - auto old_count = attr.count; - nonzero -= size_t(old_count == 1); - - attr.count -= 1; - if (count == old_count && key == *mode) { - valid = false; - } - } - - typename Counts::const_iterator Scan() const { - //! 
Initialize control variables to first variable of the frequency map - auto highest_frequency = frequency_map->begin(); - for (auto i = highest_frequency; i != frequency_map->end(); ++i) { - // Tie break with the lowest insert position - if (i->second.count > highest_frequency->second.count || - (i->second.count == highest_frequency->second.count && - i->second.first_row < highest_frequency->second.first_row)) { - highest_frequency = i; - } - } - return highest_frequency; - } -}; - -template -struct ModeIncluded { - inline explicit ModeIncluded(const ValidityMask &fmask_p, STATE &state) : fmask(fmask_p), state(state) { - } - - inline bool operator()(const idx_t &idx) const { - return fmask.RowIsValid(idx) && state.RowIsValid(idx); - } - const ValidityMask &fmask; - STATE &state; -}; - -template -struct BaseModeFunction { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { - if (!state.frequency_map) { - state.frequency_map = TYPE_OP::CreateEmpty(input_data.allocator); - } - auto &i = (*state.frequency_map)[key]; - ++i.count; - i.first_row = MinValue(i.first_row, state.count); - ++state.count; - } - - template - static void Operation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input) { - Execute(state, key, aggr_input.input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.frequency_map) { - return; - } - if (!target.frequency_map) { - // Copy - don't destroy! Otherwise windowing will break. - target.frequency_map = new typename STATE::Counts(*source.frequency_map); - target.count = source.count; - return; - } - for (auto &val : *source.frequency_map) { - auto &i = (*target.frequency_map)[val.first]; - i.count += val.second.count; - i.first_row = MinValue(i.first_row, val.second.first_row); - } - target.count += source.count; - } - - static bool IgnoreNull() { - return true; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } -}; - -template -struct TypedModeFunction : BaseModeFunction { - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &key, AggregateUnaryInput &aggr_input, idx_t count) { - if (!state.frequency_map) { - state.frequency_map = TYPE_OP::CreateEmpty(aggr_input.input.allocator); - } - auto &i = (*state.frequency_map)[key]; - i.count += count; - i.first_row = MinValue(i.first_row, state.count); - state.count += count; - } -}; - -template -struct ModeFunction : TypedModeFunction { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.frequency_map) { - finalize_data.ReturnNull(); - return; - } - auto highest_frequency = state.Scan(); - if (highest_frequency != state.frequency_map->end()) { - target = TYPE_OP::template Assign(finalize_data.result, highest_frequency->first); - } else { - finalize_data.ReturnNull(); - } - } - - template - struct UpdateWindowState { - STATE &state; - ModeIncluded &included; - - inline UpdateWindowState(STATE &state, ModeIncluded &included) : state(state), included(included) { - } - - inline void Neither(idx_t begin, idx_t end) { - } - - inline void Left(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - state.ModeRm(begin); - } - } - } - - inline void Right(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - 
state.ModeAdd(begin); - } - } - } - - inline void Both(idx_t begin, idx_t end) { - } - }; - - template - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, - idx_t rid) { - auto &state = *reinterpret_cast(l_state); - - state.InitializePage(partition); - const auto &fmask = partition.filter_mask; - - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - auto &prevs = state.prevs; - if (prevs.empty()) { - prevs.resize(1); - } - - ModeIncluded included(fmask, state); - - if (!state.frequency_map) { - state.frequency_map = TYPE_OP::CreateEmpty(Allocator::DefaultAllocator()); - } - const size_t tau_inverse = 4; // tau==0.25 - if (state.nonzero <= (state.frequency_map->size() / tau_inverse) || prevs.back().end <= frames.front().start || - frames.back().end <= prevs.front().start) { - state.Reset(); - // for f ∈ F do - for (const auto &frame : frames) { - for (auto i = frame.start; i < frame.end; ++i) { - if (included(i)) { - state.ModeAdd(i); - } - } - } - } else { - using Updater = UpdateWindowState; - Updater updater(state, included); - AggregateExecutor::IntersectFrames(prevs, frames, updater); - } - - if (!state.valid) { - // Rescan - auto highest_frequency = state.Scan(); - if (highest_frequency != state.frequency_map->end()) { - *(state.mode) = highest_frequency->first; - state.count = highest_frequency->second.count; - state.valid = (state.count > 0); - } - } - - if (state.valid) { - rdata[rid] = TYPE_OP::template Assign(result, *state.mode); - } else { - rmask.Set(rid, false); - } - - prevs = frames; - } -}; - -template -struct ModeFallbackFunction : BaseModeFunction { - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.frequency_map) { - finalize_data.ReturnNull(); - return; - } - auto highest_frequency = state.Scan(); - if (highest_frequency != state.frequency_map->end()) { - CreateSortKeyHelpers::DecodeSortKey(highest_frequency->first, finalize_data.result, - finalize_data.result_idx, - OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); - } else { - finalize_data.ReturnNull(); - } - } -}; - -AggregateFunction GetFallbackModeFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = ModeFallbackFunction; - AggregateFunction aggr({type}, type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr); - aggr.destructor = AggregateFunction::StateDestroy; - return aggr; -} - -template > -AggregateFunction GetTypedModeFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = ModeFunction; - auto func = - AggregateFunction::UnaryAggregateDestructor( - type, type); - func.window = OP::template Window; - return func; -} - -AggregateFunction GetModeAggregate(const LogicalType &type) { - switch (type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::INT8: - return GetTypedModeFunction(type); - case PhysicalType::UINT8: - return GetTypedModeFunction(type); - case PhysicalType::INT16: - return GetTypedModeFunction(type); - case PhysicalType::UINT16: - return GetTypedModeFunction(type); - case PhysicalType::INT32: - return GetTypedModeFunction(type); - case PhysicalType::UINT32: - return GetTypedModeFunction(type); - case PhysicalType::INT64: - return GetTypedModeFunction(type); - 
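// ---------------------------------------------------------------------------
// Aside: a self-contained sketch of the incremental bookkeeping that
// ModeAdd/ModeRm and the Scan() rescan above implement, stripped of frames,
// filter masks and the tau-based reset heuristic. The SlidingMode name and
// its interface are illustrative assumptions, not DuckDB API.
#include <cstddef>
#include <unordered_map>

template <typename T>
struct SlidingMode {
	std::unordered_map<T, std::size_t> freq;
	T mode {};
	std::size_t mode_count = 0;
	bool valid = false; // false once a removal may have dethroned the mode

	void Add(const T &key) {
		const std::size_t new_count = ++freq[key];
		if (new_count > mode_count) {
			mode = key;
			mode_count = new_count;
			valid = true;
		}
	}

	void Remove(const T &key) {
		std::size_t &count = freq[key];
		if (count == mode_count && key == mode) {
			valid = false; // the reigning mode lost a vote; recompute lazily
		}
		--count;
	}

	// Assumes at least one value is currently in the window.
	const T &Get() {
		if (!valid) {
			// Rescan the whole frequency map, as ModeState::Scan does.
			mode_count = 0;
			for (const auto &entry : freq) {
				if (entry.second > mode_count) {
					mode = entry.first;
					mode_count = entry.second;
				}
			}
			valid = true;
		}
		return mode;
	}
};
// Additions keep the mode exact in O(1); only removing the current mode forces
// a rescan. That is why the Window code above falls back to a full reset when
// too few counters remain nonzero (the tau == 0.25 heuristic).
// ---------------------------------------------------------------------------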
case PhysicalType::UINT64: - return GetTypedModeFunction(type); - case PhysicalType::INT128: - return GetTypedModeFunction(type); - case PhysicalType::UINT128: - return GetTypedModeFunction(type); - case PhysicalType::FLOAT: - return GetTypedModeFunction(type); - case PhysicalType::DOUBLE: - return GetTypedModeFunction(type); - case PhysicalType::INTERVAL: - return GetTypedModeFunction(type); - case PhysicalType::VARCHAR: - return GetTypedModeFunction(type); -#endif - default: - return GetFallbackModeFunction(type); - } -} - -unique_ptr BindModeAggregate(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetModeAggregate(arguments[0]->return_type); - function.name = "mode"; - return nullptr; -} - -AggregateFunctionSet ModeFun::GetFunctions() { - AggregateFunctionSet mode("mode"); - mode.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalTypeId::ANY, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, BindModeAggregate)); - return mode; -} - -//===--------------------------------------------------------------------===// -// Entropy -//===--------------------------------------------------------------------===// -template -static double FinalizeEntropy(STATE &state) { - if (!state.frequency_map) { - return 0; - } - double count = static_cast(state.count); - double entropy = 0; - for (auto &val : *state.frequency_map) { - double val_sec = static_cast(val.second.count); - entropy += (val_sec / count) * log2(count / val_sec); - } - return entropy; -} - -template -struct EntropyFunction : TypedModeFunction { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = FinalizeEntropy(state); - } -}; - -template -struct EntropyFallbackFunction : BaseModeFunction { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = FinalizeEntropy(state); - } -}; - -template > -AggregateFunction GetTypedEntropyFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = EntropyFunction; - auto func = - AggregateFunction::UnaryAggregateDestructor( - type, LogicalType::DOUBLE); - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return func; -} - -AggregateFunction GetFallbackEntropyFunction(const LogicalType &type) { - using STATE = ModeState; - using OP = EntropyFallbackFunction; - AggregateFunction func({type}, LogicalType::DOUBLE, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateSortKeyHelpers::UnaryUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, nullptr); - func.destructor = AggregateFunction::StateDestroy; - func.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return func; -} - -AggregateFunction GetEntropyFunction(const LogicalType &type) { - switch (type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::UINT16: - return GetTypedEntropyFunction(type); - case PhysicalType::UINT32: - return GetTypedEntropyFunction(type); - case PhysicalType::UINT64: - return GetTypedEntropyFunction(type); - case PhysicalType::INT16: - return GetTypedEntropyFunction(type); - case PhysicalType::INT32: - return GetTypedEntropyFunction(type); - case PhysicalType::INT64: - return GetTypedEntropyFunction(type); - case PhysicalType::FLOAT: - return GetTypedEntropyFunction(type); - case PhysicalType::DOUBLE: - return GetTypedEntropyFunction(type); - case PhysicalType::VARCHAR: - return GetTypedEntropyFunction(type); -#endif - default: - return 
GetFallbackEntropyFunction(type); - } -} - -unique_ptr BindEntropyAggregate(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetEntropyFunction(arguments[0]->return_type); - function.name = "entropy"; - return nullptr; -} - -AggregateFunctionSet EntropyFun::GetFunctions() { - AggregateFunctionSet entropy("entropy"); - entropy.AddFunction(AggregateFunction({LogicalTypeId::ANY}, LogicalType::DOUBLE, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, BindEntropyAggregate)); - return entropy; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp deleted file mode 100644 index 98ca4d5be..000000000 --- a/src/duckdb/extension/core_functions/aggregate/holistic/quantile.cpp +++ /dev/null @@ -1,873 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "duckdb/common/enums/quantile_enum.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/abs.hpp" -#include "core_functions/aggregate/quantile_state.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/queue.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/function/aggregate/sort_key_helpers.hpp" - -namespace duckdb { - -template -struct IndirectLess { - inline explicit IndirectLess(const INPUT_TYPE *inputs_p) : inputs(inputs_p) { - } - - inline bool operator()(const idx_t &lhi, const idx_t &rhi) const { - return inputs[lhi] < inputs[rhi]; - } - - const INPUT_TYPE *inputs; -}; - -template -static inline T QuantileAbs(const T &t) { - return AbsOperator::Operation(t); -} - -template <> -inline Value QuantileAbs(const Value &v) { - const auto &type = v.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - const auto integral = IntegralValue::Get(v); - const auto width = DecimalType::GetWidth(type); - const auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(QuantileAbs(Cast::Operation(integral)), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(QuantileAbs(integral), width, scale); - default: - throw InternalException("Unknown DECIMAL type"); - } - } - default: - return Value::DOUBLE(QuantileAbs(v.GetValue())); - } -} - -//===--------------------------------------------------------------------===// -// Quantile Bind Data -//===--------------------------------------------------------------------===// -QuantileBindData::QuantileBindData() { -} - -QuantileBindData::QuantileBindData(const Value &quantile_p) - : quantiles(1, QuantileValue(QuantileAbs(quantile_p))), order(1, 0), desc(quantile_p < 0) { -} - -QuantileBindData::QuantileBindData(const vector &quantiles_p) { - vector normalised; - size_t pos = 0; - size_t neg = 0; - for (idx_t i = 0; i < quantiles_p.size(); ++i) { - const auto &q = quantiles_p[i]; - pos += (q > 0); - neg += (q < 0); - normalised.emplace_back(QuantileAbs(q)); - order.push_back(i); - } - if (pos && neg) { - throw BinderException("QUANTILE parameters must have consistent signs"); - } - desc = (neg 
> 0); - - IndirectLess lt(normalised.data()); - std::sort(order.begin(), order.end(), lt); - - for (const auto &q : normalised) { - quantiles.emplace_back(QuantileValue(q)); - } -} - -QuantileBindData::QuantileBindData(const QuantileBindData &other) : order(other.order), desc(other.desc) { - for (const auto &q : other.quantiles) { - quantiles.emplace_back(q); - } -} - -unique_ptr QuantileBindData::Copy() const { - return make_uniq(*this); -} - -bool QuantileBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return desc == other.desc && quantiles == other.quantiles && order == other.order; -} - -void QuantileBindData::Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - vector raw; - for (const auto &q : bind_data.quantiles) { - raw.emplace_back(q.val); - } - serializer.WriteProperty(100, "quantiles", raw); - serializer.WriteProperty(101, "order", bind_data.order); - serializer.WriteProperty(102, "desc", bind_data.desc); -} - -unique_ptr QuantileBindData::Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - vector raw; - deserializer.ReadProperty(100, "quantiles", raw); - deserializer.ReadProperty(101, "order", result->order); - deserializer.ReadProperty(102, "desc", result->desc); - QuantileSerializationType deserialization_type; - deserializer.ReadPropertyWithExplicitDefault(103, "quantile_type", deserialization_type, - QuantileSerializationType::NON_DECIMAL); - - if (deserialization_type != QuantileSerializationType::NON_DECIMAL) { - deserializer.ReadDeletedProperty(104, "logical_type"); - } - - for (const auto &r : raw) { - result->quantiles.emplace_back(QuantileValue(r)); - } - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Cast Interpolation -//===--------------------------------------------------------------------===// -template <> -interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result) { - return {0, 0, src.micros}; -} - -template <> -double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi) { - return lo * (1.0 - d) + hi * d; -} - -template <> -dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi) { - return dtime_t(std::llround(static_cast(lo.micros) * (1.0 - d) + static_cast(hi.micros) * d)); -} - -template <> -timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi) { - return timestamp_t(std::llround(static_cast(lo.value) * (1.0 - d) + static_cast(hi.value) * d)); -} - -template <> -hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi) { - return Hugeint::Convert(Interpolate(Hugeint::Cast(lo), d, Hugeint::Cast(hi))); -} - -static interval_t MultiplyByDouble(const interval_t &i, const double &d) { // NOLINT - D_ASSERT(d >= 0 && d <= 1); - return Interval::FromMicro(std::llround(static_cast(Interval::GetMicro(i)) * d)); -} - -inline interval_t operator+(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) + Interval::GetMicro(rhs)); -} - -inline interval_t operator-(const interval_t &lhs, const interval_t &rhs) { - return Interval::FromMicro(Interval::GetMicro(lhs) - Interval::GetMicro(rhs)); -} - -template <> -interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi) 
{ - const interval_t delta = hi - lo; - return lo + MultiplyByDouble(delta, d); -} - -template <> -string_t CastInterpolation::Cast(const string_t &src, Vector &result) { - return StringVector::AddStringOrBlob(result, src); -} - -//===--------------------------------------------------------------------===// -// Scalar Quantile -//===--------------------------------------------------------------------===// -template -struct QuantileScalarOperation : public QuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); - target = interp.template Operation(state.v.data(), finalize_data.result); - } - - template - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &result, - idx_t ridx) { - auto &state = *reinterpret_cast(l_state); - auto gstate = reinterpret_cast(g_state); - - auto &data = state.GetOrCreateWindowCursor(partition); - const auto &fmask = partition.filter_mask; - - QuantileIncluded included(fmask, data); - const auto n = FrameSize(included, frames); - - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - auto rdata = FlatVector::GetData(result); - auto &rmask = FlatVector::Validity(result); - - if (!n) { - rmask.Set(ridx, false); - return; - } - - const auto &quantile = bind_data.quantiles[0]; - if (gstate && gstate->HasTree()) { - rdata[ridx] = gstate->GetWindowState().template WindowScalar(data, frames, n, result, - quantile); - } else { - auto &window_state = state.GetOrCreateWindowState(); - - // Update the skip list - window_state.UpdateSkip(data, frames, included); - - // Find the position(s) needed - rdata[ridx] = window_state.template WindowScalar(data, frames, n, result, quantile); - - // Save the previous state for next time - window_state.prevs = frames; - } - } -}; - -struct QuantileScalarFallback : QuantileOperation { - template - static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { - state.AddElement(key, input_data); - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - D_ASSERT(bind_data.quantiles.size() == 1); - Interpolator interp(bind_data.quantiles[0], state.v.size(), bind_data.desc); - auto interpolation_result = interp.InterpolateInternal(state.v.data()); - CreateSortKeyHelpers::DecodeSortKey(interpolation_result, finalize_data.result, finalize_data.result_idx, - OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); - } -}; - -//===--------------------------------------------------------------------===// -// Quantile List -//===--------------------------------------------------------------------===// -template -struct QuantileListOperation : QuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data 
= finalize_data.input.bind_data->Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - auto v_t = state.v.data(); - D_ASSERT(v_t); - - auto &entry = target; - entry.offset = ridx; - idx_t lower = 0; - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.v.size(), bind_data.desc); - interp.begin = lower; - rdata[ridx + q] = interp.template Operation(v_t, result); - lower = interp.FRN; - } - entry.length = bind_data.quantiles.size(); - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } - - template - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &frames, Vector &list, - idx_t lidx) { - auto &state = *reinterpret_cast(l_state); - auto gstate = reinterpret_cast(g_state); - - auto &data = state.GetOrCreateWindowCursor(partition); - const auto &fmask = partition.filter_mask; - - D_ASSERT(aggr_input_data.bind_data); - auto &bind_data = aggr_input_data.bind_data->Cast(); - - QuantileIncluded included(fmask, data); - const auto n = FrameSize(included, frames); - - // Result is a constant LIST with a fixed length - if (!n) { - auto &lmask = FlatVector::Validity(list); - lmask.Set(lidx, false); - return; - } - - if (gstate && gstate->HasTree()) { - gstate->GetWindowState().template WindowList(data, frames, n, list, lidx, bind_data); - } else { - auto &window_state = state.GetOrCreateWindowState(); - window_state.UpdateSkip(data, frames, included); - window_state.template WindowList(data, frames, n, list, lidx, bind_data); - window_state.prevs = frames; - } - } -}; - -struct QuantileListFallback : QuantileOperation { - template - static void Execute(STATE &state, const INPUT_TYPE &key, AggregateInputData &input_data) { - state.AddElement(key, input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.v.empty()) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - - D_ASSERT(state.v.data()); - - auto &entry = target; - entry.offset = ridx; - idx_t lower = 0; - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, state.v.size(), bind_data.desc); - interp.begin = lower; - auto interpolation_result = interp.InterpolateInternal(state.v.data()); - CreateSortKeyHelpers::DecodeSortKey(interpolation_result, result, ridx + q, - OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); - lower = interp.FRN; - } - entry.length = bind_data.quantiles.size(); - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } -}; - -//===--------------------------------------------------------------------===// -// Discrete Quantiles -//===--------------------------------------------------------------------===// -template -AggregateFunction GetDiscreteQuantileTemplated(const LogicalType &type) { - switch (type.InternalType()) { -#ifndef 
DUCKDB_SMALLER_BINARY - case PhysicalType::INT8: - return OP::template GetFunction(type); - case PhysicalType::INT16: - return OP::template GetFunction(type); - case PhysicalType::INT32: - return OP::template GetFunction(type); - case PhysicalType::INT64: - return OP::template GetFunction(type); - case PhysicalType::INT128: - return OP::template GetFunction(type); - case PhysicalType::FLOAT: - return OP::template GetFunction(type); - case PhysicalType::DOUBLE: - return OP::template GetFunction(type); - case PhysicalType::INTERVAL: - return OP::template GetFunction(type); - case PhysicalType::VARCHAR: - return OP::template GetFunction(type); -#endif - default: - return OP::GetFallback(type); - } -} - -struct ScalarDiscreteQuantile { - template - static AggregateFunction GetFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto fun = AggregateFunction::UnaryAggregateDestructor(type, type); -#ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::Window; - fun.window_init = OP::WindowInit; -#endif - return fun; - } - - static AggregateFunction GetFallback(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileScalarFallback; - - AggregateFunction fun({type}, type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateSortKeyHelpers::UnaryUpdate, - AggregateFunction::StateCombine, - AggregateFunction::StateVoidFinalize, nullptr, nullptr, - AggregateFunction::StateDestroy); - return fun; - } -}; - -template -static AggregateFunction QuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { // NOLINT - LogicalType result_type = LogicalType::LIST(child_type); - return AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -struct ListDiscreteQuantile { - template - static AggregateFunction GetFunction(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(type, type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; -#ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; -#endif - return fun; - } - - static AggregateFunction GetFallback(const LogicalType &type) { - using STATE = QuantileState; - using OP = QuantileListFallback; - - AggregateFunction fun({type}, LogicalType::LIST(type), AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - AggregateSortKeyHelpers::UnaryUpdate, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, nullptr, nullptr, - AggregateFunction::StateDestroy); - return fun; - } -}; - -AggregateFunction GetDiscreteQuantile(const LogicalType &type) { - return GetDiscreteQuantileTemplated(type); -} - -AggregateFunction GetDiscreteQuantileList(const LogicalType &type) { - return GetDiscreteQuantileTemplated(type); -} - -//===--------------------------------------------------------------------===// -// Continuous Quantiles -//===--------------------------------------------------------------------===// -template -AggregateFunction GetContinuousQuantileTemplated(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return OP::template GetFunction(type, LogicalType::DOUBLE); - case 
LogicalTypeId::SMALLINT: - return OP::template GetFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::SQLNULL: - case LogicalTypeId::INTEGER: - return OP::template GetFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::BIGINT: - return OP::template GetFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::HUGEINT: - return OP::template GetFunction(type, LogicalType::DOUBLE); - case LogicalTypeId::FLOAT: - return OP::template GetFunction(type, type); - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::DOUBLE: - return OP::template GetFunction(LogicalType::DOUBLE, LogicalType::DOUBLE); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return OP::template GetFunction(type, type); - case PhysicalType::INT32: - return OP::template GetFunction(type, type); - case PhysicalType::INT64: - return OP::template GetFunction(type, type); - case PhysicalType::INT128: - return OP::template GetFunction(type, type); - default: - throw NotImplementedException("Unimplemented continuous quantile DECIMAL aggregate"); - } - case LogicalTypeId::DATE: - return OP::template GetFunction(type, LogicalType::TIMESTAMP); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - return OP::template GetFunction(type, type); - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return OP::template GetFunction(type, type); - default: - throw NotImplementedException("Unimplemented continuous quantile aggregate"); - } -} - -struct ScalarContinuousQuantile { - template - static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { - using STATE = QuantileState; - using OP = QuantileScalarOperation; - auto fun = - AggregateFunction::UnaryAggregateDestructor(input_type, target_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; -#ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; -#endif - return fun; - } -}; - -struct ListContinuousQuantile { - template - static AggregateFunction GetFunction(const LogicalType &input_type, const LogicalType &target_type) { - using STATE = QuantileState; - using OP = QuantileListOperation; - auto fun = QuantileListAggregate(input_type, target_type); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; -#ifndef DUCKDB_SMALLER_BINARY - fun.window = OP::template Window; - fun.window_init = OP::template WindowInit; -#endif - return fun; - } -}; - -AggregateFunction GetContinuousQuantile(const LogicalType &type) { - return GetContinuousQuantileTemplated(type); -} - -AggregateFunction GetContinuousQuantileList(const LogicalType &type) { - return GetContinuousQuantileTemplated(type); -} - -//===--------------------------------------------------------------------===// -// Quantile binding -//===--------------------------------------------------------------------===// -static const Value &CheckQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < -1 || quantile > 1) { - throw BinderException("QUANTILE can only take parameters in the range [-1, 1]"); - } - if (Value::IsNan(quantile)) { - throw 
BinderException("QUANTILE parameter cannot be NaN"); - } - - return quantile_val; -} - -unique_ptr BindQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments.size() < 2) { - throw BinderException("QUANTILE requires a range argument between [0, 1]"); - } - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("QUANTILE can only take constant parameters"); - } - Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - if (quantile_val.IsNull()) { - throw BinderException("QUANTILE argument must not be NULL"); - } - vector quantiles; - switch (quantile_val.type().id()) { - case LogicalTypeId::LIST: - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckQuantile(element_val)); - } - break; - case LogicalTypeId::ARRAY: - for (const auto &element_val : ArrayValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckQuantile(element_val)); - } - break; - default: - quantiles.push_back(CheckQuantile(quantile_val)); - break; - } - - Function::EraseArgument(function, arguments, arguments.size() - 1); - return make_uniq(quantiles); -} - -//===--------------------------------------------------------------------===// -// Function definitions -//===--------------------------------------------------------------------===// -static bool CanInterpolate(const LogicalType &type) { - if (type.HasAlias()) { - return false; - } - switch (type.id()) { - case LogicalTypeId::DECIMAL: - case LogicalTypeId::SQLNULL: - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DATE: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return true; - default: - return false; - } -} - -struct MedianFunction { - static AggregateFunction GetAggregate(const LogicalType &type) { - auto fun = CanInterpolate(type) ? 
GetContinuousQuantile(type) : GetDiscreteQuantile(type); - fun.name = "median"; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; - return fun; - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = QuantileBindData::Deserialize(deserializer, function); - - auto &input_type = function.arguments[0]; - function = GetAggregate(input_type); - return bind_data; - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetAggregate(arguments[0]->return_type); - return make_uniq(Value::DECIMAL(int16_t(5), 2, 1)); - } -}; - -struct DiscreteQuantileListFunction { - static AggregateFunction GetAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantileList(type); - fun.name = "quantile_disc"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::LIST(LogicalType::DOUBLE)); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = QuantileBindData::Deserialize(deserializer, function); - - auto &input_type = function.arguments[0]; - function = GetAggregate(input_type); - return bind_data; - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetAggregate(arguments[0]->return_type); - return BindQuantile(context, function, arguments); - } -}; - -struct DiscreteQuantileFunction { - static AggregateFunction GetAggregate(const LogicalType &type) { - auto fun = GetDiscreteQuantile(type); - fun.name = "quantile_disc"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = QuantileBindData::Deserialize(deserializer, function); - auto &quantile_data = bind_data->Cast(); - - auto &input_type = function.arguments[0]; - if (quantile_data.quantiles.size() == 1) { - function = GetAggregate(input_type); - } else { - function = DiscreteQuantileListFunction::GetAggregate(input_type); - } - return bind_data; - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetAggregate(arguments[0]->return_type); - return BindQuantile(context, function, arguments); - } -}; - -struct ContinuousQuantileFunction { - static AggregateFunction GetAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantile(type); - fun.name = "quantile_cont"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = QuantileBindData::Deserialize(deserializer, function); - - auto &input_type = function.arguments[0]; - function = 
GetAggregate(input_type); - return bind_data; - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type - : function.arguments[0]); - return BindQuantile(context, function, arguments); - } -}; - -struct ContinuousQuantileListFunction { - static AggregateFunction GetAggregate(const LogicalType &type) { - auto fun = GetContinuousQuantileList(type); - fun.name = "quantile_cont"; - fun.bind = Bind; - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto bind_data = QuantileBindData::Deserialize(deserializer, function); - - auto &input_type = function.arguments[0]; - function = GetAggregate(input_type); - return bind_data; - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetAggregate(function.arguments[0].id() == LogicalTypeId::DECIMAL ? arguments[0]->return_type - : function.arguments[0]); - return BindQuantile(context, function, arguments); - } -}; - -template -AggregateFunction EmptyQuantileFunction(LogicalType input, LogicalType result, const LogicalType &extra_arg) { - AggregateFunction fun({std::move(input)}, std::move(result), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - OP::Bind); - if (extra_arg.id() != LogicalTypeId::INVALID) { - fun.arguments.push_back(extra_arg); - } - fun.serialize = QuantileBindData::Serialize; - fun.deserialize = OP::Deserialize; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunctionSet MedianFun::GetFunctions() { - AggregateFunctionSet set("median"); - set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalTypeId::INVALID)); - return set; -} - -AggregateFunctionSet QuantileDiscFun::GetFunctions() { - AggregateFunctionSet set("quantile_disc"); - set.AddFunction( - EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::DOUBLE)); - set.AddFunction(EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, - LogicalType::LIST(LogicalType::DOUBLE))); - // this function is here for deserialization - it cannot be called by users - set.AddFunction( - EmptyQuantileFunction(LogicalType::ANY, LogicalType::ANY, LogicalType::INVALID)); - return set; -} - -vector GetContinuousQuantileTypes() { - return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT, - LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE, - LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ}; -} - -AggregateFunctionSet QuantileContFun::GetFunctions() { - AggregateFunctionSet quantile_cont("quantile_cont"); - quantile_cont.AddFunction(EmptyQuantileFunction( - LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::DOUBLE)); - quantile_cont.AddFunction(EmptyQuantileFunction( - LogicalTypeId::DECIMAL, LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE))); - for (const auto &type : GetContinuousQuantileTypes()) { - quantile_cont.AddFunction(EmptyQuantileFunction(type, type, 
LogicalType::DOUBLE)); - quantile_cont.AddFunction( - EmptyQuantileFunction(type, type, LogicalType::LIST(LogicalType::DOUBLE))); - } - return quantile_cont; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp b/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp deleted file mode 100644 index 8c332500d..000000000 --- a/src/duckdb/extension/core_functions/aggregate/holistic/reservoir_quantile.cpp +++ /dev/null @@ -1,449 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/reservoir_sample.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/common/queue.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -#include -#include - -namespace duckdb { - -template -struct ReservoirQuantileState { - T *v; - idx_t len; - idx_t pos; - BaseReservoirSampling *r_samp; - - void Resize(idx_t new_len) { - if (new_len <= len) { - return; - } - T *old_v = v; - v = (T *)realloc(v, new_len * sizeof(T)); - if (!v) { - free(old_v); - throw InternalException("Memory allocation failure"); - } - len = new_len; - } - - void ReplaceElement(T &input) { - v[r_samp->min_weighted_entry_index] = input; - r_samp->ReplaceElement(); - } - - void FillReservoir(idx_t sample_size, T element) { - if (pos < sample_size) { - v[pos++] = element; - r_samp->InitializeReservoirWeights(pos, len); - } else { - D_ASSERT(r_samp->next_index_to_sample >= r_samp->num_entries_to_skip_b4_next_sample); - if (r_samp->next_index_to_sample == r_samp->num_entries_to_skip_b4_next_sample) { - ReplaceElement(element); - } - } - } -}; - -struct ReservoirQuantileBindData : public FunctionData { - ReservoirQuantileBindData() { - } - ReservoirQuantileBindData(double quantile_p, idx_t sample_size_p) - : quantiles(1, quantile_p), sample_size(sample_size_p) { - } - - ReservoirQuantileBindData(vector quantiles_p, idx_t sample_size_p) - : quantiles(std::move(quantiles_p)), sample_size(sample_size_p) { - } - - unique_ptr Copy() const override { - return make_uniq(quantiles, sample_size); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return quantiles == other.quantiles && sample_size == other.sample_size; - } - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "quantiles", bind_data.quantiles); - serializer.WriteProperty(101, "sample_size", bind_data.sample_size); - } - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function) { - auto result = make_uniq(); - deserializer.ReadProperty(100, "quantiles", result->quantiles); - deserializer.ReadProperty(101, "sample_size", result->sample_size); - return std::move(result); - } - - vector quantiles; - idx_t sample_size; -}; - -struct ReservoirQuantileOperation { - template - static void Initialize(STATE &state) { - state.v = nullptr; - state.len = 0; - state.pos = 0; - state.r_samp = nullptr; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - auto 
&bind_data = unary_input.input.bind_data->template Cast(); - if (state.pos == 0) { - state.Resize(bind_data.sample_size); - } - if (!state.r_samp) { - state.r_samp = new BaseReservoirSampling(); - } - D_ASSERT(state.v); - state.FillReservoir(bind_data.sample_size, input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.pos == 0) { - return; - } - if (target.pos == 0) { - target.Resize(source.len); - } - if (!target.r_samp) { - target.r_samp = new BaseReservoirSampling(); - } - for (idx_t src_idx = 0; src_idx < source.pos; src_idx++) { - target.FillReservoir(target.len, source.v[src_idx]); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - if (state.v) { - free(state.v); - state.v = nullptr; - } - if (state.r_samp) { - delete state.r_samp; - state.r_samp = nullptr; - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct ReservoirQuantileScalarOperation : public ReservoirQuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - D_ASSERT(state.v); - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - auto v_t = state.v; - D_ASSERT(bind_data.quantiles.size() == 1); - auto offset = (idx_t)((double)(state.pos - 1) * bind_data.quantiles[0]); - std::nth_element(v_t, v_t + offset, v_t + state.pos); - target = v_t[offset]; - } -}; - -AggregateFunction GetReservoirQuantileAggregateFunction(PhysicalType type) { - switch (type) { - case PhysicalType::INT8: - return AggregateFunction::UnaryAggregateDestructor, int8_t, int8_t, - ReservoirQuantileScalarOperation>(LogicalType::TINYINT, - LogicalType::TINYINT); - - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregateDestructor, int16_t, int16_t, - ReservoirQuantileScalarOperation>(LogicalType::SMALLINT, - LogicalType::SMALLINT); - - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregateDestructor, int32_t, int32_t, - ReservoirQuantileScalarOperation>(LogicalType::INTEGER, - LogicalType::INTEGER); - - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregateDestructor, int64_t, int64_t, - ReservoirQuantileScalarOperation>(LogicalType::BIGINT, - LogicalType::BIGINT); - - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregateDestructor, hugeint_t, hugeint_t, - ReservoirQuantileScalarOperation>(LogicalType::HUGEINT, - LogicalType::HUGEINT); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregateDestructor, float, float, - ReservoirQuantileScalarOperation>(LogicalType::FLOAT, - LogicalType::FLOAT); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregateDestructor, double, double, - ReservoirQuantileScalarOperation>(LogicalType::DOUBLE, - LogicalType::DOUBLE); - default: - throw InternalException("Unimplemented reservoir quantile aggregate"); - } -} - -template -struct ReservoirQuantileListOperation : public ReservoirQuantileOperation { - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.pos == 0) { - finalize_data.ReturnNull(); - return; - } - - D_ASSERT(finalize_data.input.bind_data); - auto &bind_data = finalize_data.input.bind_data->template Cast(); - - auto &result = ListVector::GetEntry(finalize_data.result); - auto ridx = ListVector::GetListSize(finalize_data.result); - 
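// ---------------------------------------------------------------------------
// Aside: both the scalar Finalize above and the list Finalize here read order
// statistics out of the unsorted reservoir with std::nth_element instead of
// sorting it fully. A minimal sketch of that selection step; the
// SampleQuantile name and plain-vector interface are illustrative assumptions.
#include <algorithm>
#include <cstddef>
#include <vector>

template <typename T>
T SampleQuantile(std::vector<T> &v, double q) {
	// Same offset rule as the surrounding code: floor((n - 1) * q) for q in [0, 1].
	// Assumes a non-empty sample.
	const std::size_t offset = static_cast<std::size_t>(static_cast<double>(v.size() - 1) * q);
	// Partial ordering only: O(n) on average, versus O(n log n) for a full sort.
	std::nth_element(v.begin(), v.begin() + static_cast<std::ptrdiff_t>(offset), v.end());
	return v[offset];
}
// ---------------------------------------------------------------------------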
ListVector::Reserve(finalize_data.result, ridx + bind_data.quantiles.size()); - auto rdata = FlatVector::GetData(result); - - auto v_t = state.v; - D_ASSERT(v_t); - - auto &entry = target; - entry.offset = ridx; - entry.length = bind_data.quantiles.size(); - for (size_t q = 0; q < entry.length; ++q) { - const auto &quantile = bind_data.quantiles[q]; - auto offset = (idx_t)((double)(state.pos - 1) * quantile); - std::nth_element(v_t, v_t + offset, v_t + state.pos); - rdata[ridx + q] = v_t[offset]; - } - - ListVector::SetListSize(finalize_data.result, entry.offset + entry.length); - } -}; - -template -static AggregateFunction ReservoirQuantileListAggregate(const LogicalType &input_type, const LogicalType &child_type) { - LogicalType result_type = LogicalType::LIST(child_type); - return AggregateFunction( - {input_type}, result_type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateFunction::UnaryScatterUpdate, AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, AggregateFunction::UnaryUpdate, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetTypedReservoirQuantileListAggregateFunction(const LogicalType &type) { - using STATE = ReservoirQuantileState; - using OP = ReservoirQuantileListOperation; - auto fun = ReservoirQuantileListAggregate(type, type); - return fun; -} - -AggregateFunction GetReservoirQuantileListAggregateFunction(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::SMALLINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::INTEGER: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::BIGINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::HUGEINT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::FLOAT: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::DOUBLE: - return GetTypedReservoirQuantileListAggregateFunction(type); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT32: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT64: - return GetTypedReservoirQuantileListAggregateFunction(type); - case PhysicalType::INT128: - return GetTypedReservoirQuantileListAggregateFunction(type); - default: - throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); - } - default: - // TODO: Add quantitative temporal types - throw NotImplementedException("Unimplemented reservoir quantile list aggregate"); - } -} - -static double CheckReservoirQuantile(const Value &quantile_val) { - if (quantile_val.IsNull()) { - throw BinderException("RESERVOIR_QUANTILE QUANTILE parameter cannot be NULL"); - } - auto quantile = quantile_val.GetValue(); - if (quantile < 0 || quantile > 1) { - throw BinderException("RESERVOIR_QUANTILE can only take parameters in the range [0, 1]"); - } - return quantile; -} - -unique_ptr BindReservoirQuantile(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw BinderException("RESERVOIR_QUANTILE can only take constant quantile parameters"); - } - 
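// fold the constant quantile argument: either a single value or a LIST of values, each validated to lie in [0, 1] -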
Value quantile_val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - vector quantiles; - if (quantile_val.type().id() != LogicalTypeId::LIST) { - quantiles.push_back(CheckReservoirQuantile(quantile_val)); - } else { - for (const auto &element_val : ListValue::GetChildren(quantile_val)) { - quantiles.push_back(CheckReservoirQuantile(element_val)); - } - } - - if (arguments.size() == 2) { - // remove the quantile argument so we can use the unary aggregate - if (function.arguments.size() == 2) { - Function::EraseArgument(function, arguments, arguments.size() - 1); - } else { - arguments.pop_back(); - } - return make_uniq(quantiles, 8192ULL); - } - if (!arguments[2]->IsFoldable()) { - throw BinderException("RESERVOIR_QUANTILE can only take constant sample size parameters"); - } - Value sample_size_val = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - if (sample_size_val.IsNull()) { - throw BinderException("Size of the RESERVOIR_QUANTILE sample cannot be NULL"); - } - auto sample_size = sample_size_val.GetValue(); - - if (sample_size <= 0) { - throw BinderException("Size of the RESERVOIR_QUANTILE sample must be greater than 0"); - } - - // remove the quantile arguments so we can use the unary aggregate - if (function.arguments.size() == arguments.size()) { - Function::EraseArgument(function, arguments, arguments.size() - 1); - Function::EraseArgument(function, arguments, arguments.size() - 1); - } else { - arguments.pop_back(); - arguments.pop_back(); - } - return make_uniq(quantiles, NumericCast(sample_size)); -} - -unique_ptr BindReservoirQuantileDecimal(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetReservoirQuantileAggregateFunction(arguments[0]->return_type.InternalType()); - auto bind_data = BindReservoirQuantile(context, function, arguments); - function.name = "reservoir_quantile"; - function.serialize = ReservoirQuantileBindData::Serialize; - function.deserialize = ReservoirQuantileBindData::Deserialize; - return bind_data; -} - -AggregateFunction GetReservoirQuantileAggregate(PhysicalType type) { - auto fun = GetReservoirQuantileAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - fun.arguments.emplace_back(LogicalType::DOUBLE); - return fun; -} - -unique_ptr BindReservoirQuantileDecimalList(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function = GetReservoirQuantileListAggregateFunction(arguments[0]->return_type); - auto bind_data = BindReservoirQuantile(context, function, arguments); - function.serialize = ReservoirQuantileBindData::Serialize; - function.deserialize = ReservoirQuantileBindData::Deserialize; - function.name = "reservoir_quantile"; - return bind_data; -} - -AggregateFunction GetReservoirQuantileListAggregate(const LogicalType &type) { - auto fun = GetReservoirQuantileListAggregateFunction(type); - fun.bind = BindReservoirQuantile; - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - // temporarily push an argument so we can bind the actual quantile - auto list_of_double = LogicalType::LIST(LogicalType::DOUBLE); - fun.arguments.push_back(list_of_double); - return fun; -} - -static void DefineReservoirQuantile(AggregateFunctionSet &set, const LogicalType &type) { - // Four versions: 
type, scalar/list[, count] - auto fun = GetReservoirQuantileAggregate(type.InternalType()); - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); - - // List variants - fun = GetReservoirQuantileListAggregate(type); - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); -} - -static void GetReservoirQuantileDecimalFunction(AggregateFunctionSet &set, const vector &arguments, - const LogicalType &return_value) { - AggregateFunction fun(arguments, return_value, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - BindReservoirQuantileDecimal); - fun.serialize = ReservoirQuantileBindData::Serialize; - fun.deserialize = ReservoirQuantileBindData::Deserialize; - set.AddFunction(fun); - - fun.arguments.emplace_back(LogicalType::INTEGER); - set.AddFunction(fun); -} - -AggregateFunctionSet ReservoirQuantileFun::GetFunctions() { - AggregateFunctionSet reservoir_quantile; - - // DECIMAL - GetReservoirQuantileDecimalFunction(reservoir_quantile, {LogicalTypeId::DECIMAL, LogicalType::DOUBLE}, - LogicalTypeId::DECIMAL); - GetReservoirQuantileDecimalFunction(reservoir_quantile, - {LogicalTypeId::DECIMAL, LogicalType::LIST(LogicalType::DOUBLE)}, - LogicalType::LIST(LogicalTypeId::DECIMAL)); - - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::TINYINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::SMALLINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::INTEGER); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::BIGINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::HUGEINT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::FLOAT); - DefineReservoirQuantile(reservoir_quantile, LogicalTypeId::DOUBLE); - return reservoir_quantile; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp deleted file mode 100644 index fc184b8bb..000000000 --- a/src/duckdb/extension/core_functions/aggregate/nested/binned_histogram.cpp +++ /dev/null @@ -1,410 +0,0 @@ -#include "duckdb/function/scalar/nested_functions.hpp" -#include "core_functions/aggregate/nested_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/common/types/vector.hpp" -#include "core_functions/aggregate/histogram_helpers.hpp" -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/algorithm.hpp" - -namespace duckdb { - -template -struct HistogramBinState { - using TYPE = T; - - unsafe_vector *bin_boundaries; - unsafe_vector *counts; - - void Initialize() { - bin_boundaries = nullptr; - counts = nullptr; - } - - void Destroy() { - if (bin_boundaries) { - delete bin_boundaries; - bin_boundaries = nullptr; - } - if (counts) { - delete counts; - counts = nullptr; - } - } - - bool IsSet() { - return bin_boundaries; - } - - template - void InitializeBins(Vector &bin_vector, idx_t count, idx_t pos, AggregateInputData &aggr_input) { - bin_boundaries = new unsafe_vector(); - counts = new unsafe_vector(); - UnifiedVectorFormat bin_data; - bin_vector.ToUnifiedFormat(count, bin_data); - auto bin_counts = UnifiedVectorFormat::GetData(bin_data); - auto bin_index = bin_data.sel->get_index(pos); - auto bin_list = bin_counts[bin_index]; - if (!bin_data.validity.RowIsValid(bin_index)) { - throw BinderException("Histogram bin list 
cannot be NULL"); - } - - auto &bin_child = ListVector::GetEntry(bin_vector); - auto bin_count = ListVector::GetListSize(bin_vector); - UnifiedVectorFormat bin_child_data; - auto extra_state = OP::CreateExtraState(bin_count); - OP::PrepareData(bin_child, bin_count, extra_state, bin_child_data); - - bin_boundaries->reserve(bin_list.length); - for (idx_t i = 0; i < bin_list.length; i++) { - auto bin_child_idx = bin_child_data.sel->get_index(bin_list.offset + i); - if (!bin_child_data.validity.RowIsValid(bin_child_idx)) { - throw BinderException("Histogram bin entry cannot be NULL"); - } - bin_boundaries->push_back(OP::template ExtractValue(bin_child_data, bin_list.offset + i, aggr_input)); - } - // sort the bin boundaries - std::sort(bin_boundaries->begin(), bin_boundaries->end()); - // ensure there are no duplicate bin boundaries - for (idx_t i = 1; i < bin_boundaries->size(); i++) { - if (Equals::Operation((*bin_boundaries)[i - 1], (*bin_boundaries)[i])) { - bin_boundaries->erase_at(i); - i--; - } - } - - counts->resize(bin_list.length + 1); - } -}; - -struct HistogramBinFunction { - template - static void Initialize(STATE &state) { - state.Initialize(); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.Destroy(); - } - - static bool IgnoreNull() { - return true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.bin_boundaries) { - // nothing to combine - return; - } - if (!target.bin_boundaries) { - // target does not have bin boundaries - copy everything over - target.bin_boundaries = new unsafe_vector(); - target.counts = new unsafe_vector(); - *target.bin_boundaries = *source.bin_boundaries; - *target.counts = *source.counts; - } else { - // both source and target have bin boundaries - if (*target.bin_boundaries != *source.bin_boundaries) { - throw NotImplementedException( - "Histogram - cannot combine histograms with different bin boundaries. 
" - "Bin boundaries must be the same for all histograms within the same group"); - } - if (target.counts->size() != source.counts->size()) { - throw InternalException("Histogram combine - bin boundaries are the same but counts are different"); - } - D_ASSERT(target.counts->size() == source.counts->size()); - for (idx_t bin_idx = 0; bin_idx < target.counts->size(); bin_idx++) { - (*target.counts)[bin_idx] += (*source.counts)[bin_idx]; - } - } - } -}; - -struct HistogramRange { - static constexpr bool EXACT = false; - - template - static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { - auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); - return UnsafeNumericCast(entry - bin_boundaries.begin()); - } -}; - -struct HistogramExact { - static constexpr bool EXACT = true; - - template - static idx_t GetBin(T value, const unsafe_vector &bin_boundaries) { - auto entry = std::lower_bound(bin_boundaries.begin(), bin_boundaries.end(), value); - if (entry == bin_boundaries.end() || !(*entry == value)) { - // entry not found - return last bucket - return bin_boundaries.size(); - } - return UnsafeNumericCast(entry - bin_boundaries.begin()); - } -}; - -template -static void HistogramBinUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, - Vector &state_vector, idx_t count) { - auto &input = inputs[0]; - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto &bin_vector = inputs[1]; - - auto extra_state = OP::CreateExtraState(count); - UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); - - auto states = UnifiedVectorFormat::GetData *>(sdata); - auto data = UnifiedVectorFormat::GetData(input_data); - for (idx_t i = 0; i < count; i++) { - auto idx = input_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(idx)) { - continue; - } - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.IsSet()) { - state.template InitializeBins(bin_vector, count, i, aggr_input); - } - auto bin_entry = HIST::template GetBin(data[idx], *state.bin_boundaries); - ++(*state.counts)[bin_entry]; - } -} - -static bool SupportsOtherBucket(const LogicalType &type) { - if (type.HasAlias()) { - return false; - } - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::DATE: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - case LogicalTypeId::STRUCT: - case LogicalTypeId::LIST: - return true; - default: - return false; - } -} -static Value OtherBucketValue(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - case 
LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return Value::MaximumValue(type); - case LogicalTypeId::DATE: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return Value::Infinity(type); - case LogicalTypeId::VARCHAR: - return Value(""); - case LogicalTypeId::BLOB: - return Value::BLOB(""); - case LogicalTypeId::STRUCT: { - // for structs we can set all child members to NULL - auto &child_types = StructType::GetChildTypes(type); - child_list_t child_list; - for (auto &child_type : child_types) { - child_list.push_back(make_pair(child_type.first, Value(child_type.second))); - } - return Value::STRUCT(std::move(child_list)); - } - case LogicalTypeId::LIST: - return Value::LIST(ListType::GetChildType(type), vector()); - default: - throw InternalException("Unsupported type for other bucket"); - } -} - -static void IsHistogramOtherBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_type = args.data[0].GetType(); - if (!SupportsOtherBucket(input_type)) { - result.Reference(Value::BOOLEAN(false)); - return; - } - auto v = OtherBucketValue(input_type); - Vector ref(v); - VectorOperations::NotDistinctFrom(args.data[0], ref, result, args.size()); -} - -template -static void HistogramBinFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, - idx_t offset) { - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData *>(sdata); - - auto &mask = FlatVector::Validity(result); - auto old_len = ListVector::GetListSize(result); - idx_t new_entries = 0; - bool supports_other_bucket = SupportsOtherBucket(MapType::KeyType(result.GetType())); - // figure out how much space we need - for (idx_t i = 0; i < count; i++) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.bin_boundaries) { - continue; - } - new_entries += state.bin_boundaries->size(); - if (state.counts->back() > 0 && supports_other_bucket) { - // overflow bucket has entries - new_entries++; - } - } - // reserve space in the list vector - ListVector::Reserve(result, old_len + new_entries); - auto &keys = MapVector::GetKeys(result); - auto &values = MapVector::GetValues(result); - auto list_entries = FlatVector::GetData(result); - auto count_entries = FlatVector::GetData(values); - - idx_t current_offset = old_len; - for (idx_t i = 0; i < count; i++) { - const auto rid = i + offset; - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.bin_boundaries) { - mask.SetInvalid(rid); - continue; - } - - auto &list_entry = list_entries[rid]; - list_entry.offset = current_offset; - for (idx_t bin_idx = 0; bin_idx < state.bin_boundaries->size(); bin_idx++) { - OP::template HistogramFinalize((*state.bin_boundaries)[bin_idx], keys, current_offset); - count_entries[current_offset] = (*state.counts)[bin_idx]; - current_offset++; - } - if (state.counts->back() > 0 && supports_other_bucket) { - // add overflow bucket ("others") - // the key of the overflow bucket is the sentinel produced by OtherBucketValue, not NULL - keys.SetValue(current_offset, OtherBucketValue(keys.GetType())); - count_entries[current_offset] = state.counts->back(); - current_offset++; - } - list_entry.length = current_offset - list_entry.offset; - } - D_ASSERT(current_offset == old_len + new_entries); - ListVector::SetListSize(result, current_offset); - result.Verify(count); -} - -template 
-static AggregateFunction GetHistogramBinFunction(const LogicalType &type) { - using STATE_TYPE = HistogramBinState; - - const char *function_name = HIST::EXACT ? "histogram_exact" : "histogram"; - - auto struct_type = LogicalType::MAP(type, LogicalType::UBIGINT); - return AggregateFunction( - function_name, {type, LogicalType::LIST(type)}, struct_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, HistogramBinUpdateFunction, - AggregateFunction::StateCombine, HistogramBinFinalizeFunction, nullptr, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetHistogramBinFunction(const LogicalType &type) { - if (type.id() == LogicalTypeId::DECIMAL) { - return GetHistogramBinFunction(LogicalType::DOUBLE); - } - switch (type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::BOOL: - return GetHistogramBinFunction(type); - case PhysicalType::UINT8: - return GetHistogramBinFunction(type); - case PhysicalType::UINT16: - return GetHistogramBinFunction(type); - case PhysicalType::UINT32: - return GetHistogramBinFunction(type); - case PhysicalType::UINT64: - return GetHistogramBinFunction(type); - case PhysicalType::INT8: - return GetHistogramBinFunction(type); - case PhysicalType::INT16: - return GetHistogramBinFunction(type); - case PhysicalType::INT32: - return GetHistogramBinFunction(type); - case PhysicalType::INT64: - return GetHistogramBinFunction(type); - case PhysicalType::FLOAT: - return GetHistogramBinFunction(type); - case PhysicalType::DOUBLE: - return GetHistogramBinFunction(type); - case PhysicalType::VARCHAR: - return GetHistogramBinFunction(type); -#endif - default: - return GetHistogramBinFunction(type); - } -} - -template -unique_ptr HistogramBinBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - for (auto &arg : arguments) { - if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - } - - function = GetHistogramBinFunction(arguments[0]->return_type); - return nullptr; -} - -AggregateFunction HistogramFun::BinnedHistogramFunction() { - return AggregateFunction("histogram", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, LogicalTypeId::MAP, - nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - HistogramBinBindFunction, nullptr); -} - -AggregateFunction HistogramExactFun::GetFunction() { - return AggregateFunction("histogram_exact", {LogicalType::ANY, LogicalType::LIST(LogicalType::ANY)}, - LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, - HistogramBinBindFunction, nullptr); -} - -ScalarFunction IsHistogramOtherBinFun::GetFunction() { - return ScalarFunction("is_histogram_other_bin", {LogicalType::ANY}, LogicalType::BOOLEAN, - IsHistogramOtherBinFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp b/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp deleted file mode 100644 index 8a736f235..000000000 --- a/src/duckdb/extension/core_functions/aggregate/nested/histogram.cpp +++ /dev/null @@ -1,236 +0,0 @@ -#include "duckdb/function/scalar/nested_functions.hpp" -#include "core_functions/aggregate/nested_functions.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/string_map_set.hpp" -#include "core_functions/aggregate/histogram_helpers.hpp" -#include "duckdb/common/owning_string_map.hpp" - -namespace duckdb { - -template -struct HistogramFunction { - template - static void Initialize(STATE &state) { 
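- // the histogram map is allocated lazily, on the first value seen for each group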
- state.hist = nullptr; - } - - template - static void Destroy(STATE &state, AggregateInputData &) { - if (state.hist) { - delete state.hist; - } - } - - static bool IgnoreNull() { - return true; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.hist) { - return; - } - if (!target.hist) { - target.hist = MAP_TYPE::CreateEmpty(input_data.allocator); - } - for (auto &entry : *source.hist) { - (*target.hist)[entry.first] += entry.second; - } - } -}; - -template -struct DefaultMapType { - using MAP_TYPE = TYPE; - - static TYPE *CreateEmpty(ArenaAllocator &) { - return new TYPE(); - } -}; - -template -struct StringMapType { - using MAP_TYPE = TYPE; - - static TYPE *CreateEmpty(ArenaAllocator &allocator) { - return new TYPE(allocator); - } -}; - -template -static void HistogramUpdateFunction(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, - Vector &state_vector, idx_t count) { - - D_ASSERT(input_count == 1); - - auto &input = inputs[0]; - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - auto extra_state = OP::CreateExtraState(count); - UnifiedVectorFormat input_data; - OP::PrepareData(input, count, extra_state, input_data); - - auto states = UnifiedVectorFormat::GetData *>(sdata); - auto input_values = UnifiedVectorFormat::GetData(input_data); - for (idx_t i = 0; i < count; i++) { - auto idx = input_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(idx)) { - continue; - } - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - state.hist = MAP_TYPE::CreateEmpty(aggr_input.allocator); - } - auto &input_value = input_values[idx]; - ++(*state.hist)[input_value]; - } -} - -template -static void HistogramFinalizeFunction(Vector &state_vector, AggregateInputData &, Vector &result, idx_t count, - idx_t offset) { - using HIST_STATE = HistogramAggState; - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData(sdata); - - auto &mask = FlatVector::Validity(result); - auto old_len = ListVector::GetListSize(result); - idx_t new_entries = 0; - // figure out how much space we need - for (idx_t i = 0; i < count; i++) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - continue; - } - new_entries += state.hist->size(); - } - // reserve space in the list vector - ListVector::Reserve(result, old_len + new_entries); - auto &keys = MapVector::GetKeys(result); - auto &values = MapVector::GetValues(result); - auto list_entries = FlatVector::GetData(result); - auto count_entries = FlatVector::GetData(values); - - idx_t current_offset = old_len; - for (idx_t i = 0; i < count; i++) { - const auto rid = i + offset; - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - mask.SetInvalid(rid); - continue; - } - - auto &list_entry = list_entries[rid]; - list_entry.offset = current_offset; - for (auto &entry : *state.hist) { - OP::template HistogramFinalize(entry.first, keys, current_offset); - count_entries[current_offset] = entry.second; - current_offset++; - } - list_entry.length = current_offset - list_entry.offset; - } - D_ASSERT(current_offset == old_len + new_entries); - ListVector::SetListSize(result, current_offset); - result.Verify(count); -} - -template -static AggregateFunction GetHistogramFunction(const LogicalType &type) { - using STATE_TYPE = HistogramAggState; - using HIST_FUNC = HistogramFunction; - - auto struct_type = LogicalType::MAP(type, 
LogicalType::UBIGINT); - return AggregateFunction( - "histogram", {type}, struct_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, HistogramUpdateFunction, - AggregateFunction::StateCombine, HistogramFinalizeFunction, nullptr, - nullptr, AggregateFunction::StateDestroy); -} - -template -AggregateFunction GetMapTypeInternal(const LogicalType &type) { - return GetHistogramFunction(type); -} - -template -AggregateFunction GetMapType(const LogicalType &type) { - if (IS_ORDERED) { - return GetMapTypeInternal>>(type); - } - return GetMapTypeInternal>>(type); -} - -template -AggregateFunction GetStringMapType(const LogicalType &type) { - if (IS_ORDERED) { - return GetMapTypeInternal>>(type); - } else { - return GetMapTypeInternal>>(type); - } -} - -template -AggregateFunction GetHistogramFunction(const LogicalType &type) { - switch (type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::BOOL: - return GetMapType(type); - case PhysicalType::UINT8: - return GetMapType(type); - case PhysicalType::UINT16: - return GetMapType(type); - case PhysicalType::UINT32: - return GetMapType(type); - case PhysicalType::UINT64: - return GetMapType(type); - case PhysicalType::INT8: - return GetMapType(type); - case PhysicalType::INT16: - return GetMapType(type); - case PhysicalType::INT32: - return GetMapType(type); - case PhysicalType::INT64: - return GetMapType(type); - case PhysicalType::FLOAT: - return GetMapType(type); - case PhysicalType::DOUBLE: - return GetMapType(type); - case PhysicalType::VARCHAR: - return GetStringMapType(type); -#endif - default: - return GetStringMapType(type); - } -} - -template -unique_ptr HistogramBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - - D_ASSERT(arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - function = GetHistogramFunction(arguments[0]->return_type); - return make_uniq(function.return_type); -} - -AggregateFunctionSet HistogramFun::GetFunctions() { - AggregateFunctionSet fun; - AggregateFunction histogram_function("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, HistogramBindFunction, nullptr); - fun.AddFunction(HistogramFun::BinnedHistogramFunction()); - fun.AddFunction(histogram_function); - return fun; -} - -AggregateFunction HistogramFun::GetHistogramUnorderedMap(LogicalType &type) { - return AggregateFunction("histogram", {LogicalType::ANY}, LogicalTypeId::MAP, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, HistogramBindFunction, nullptr); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp b/src/duckdb/extension/core_functions/aggregate/nested/list.cpp deleted file mode 100644 index 7b23987d6..000000000 --- a/src/duckdb/extension/core_functions/aggregate/nested/list.cpp +++ /dev/null @@ -1,212 +0,0 @@ -#include "duckdb/common/pair.hpp" -#include "duckdb/common/types/list_segment.hpp" -#include "core_functions/aggregate/nested_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -struct ListBindData : public FunctionData { - explicit ListBindData(const LogicalType &stype_p); - ~ListBindData() override; - - LogicalType stype; - ListSegmentFunctions functions; - - unique_ptr Copy() const override { - return make_uniq(stype); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = 
other_p.Cast(); - return stype == other.stype; - } -}; - -ListBindData::ListBindData(const LogicalType &stype_p) : stype(stype_p) { - // always unnest once because the result vector is of type LIST - auto type = ListType::GetChildType(stype_p); - GetSegmentDataFunctions(functions, type); -} - -ListBindData::~ListBindData() { -} - -struct ListAggState { - LinkedList linked_list; -}; - -struct ListFunction { - template - static void Initialize(STATE &state) { - state.linked_list.total_capacity = 0; - state.linked_list.first_segment = nullptr; - state.linked_list.last_segment = nullptr; - } - static bool IgnoreNull() { - return false; - } -}; - -static void ListUpdateFunction(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, - Vector &state_vector, idx_t count) { - - D_ASSERT(input_count == 1); - auto &input = inputs[0]; - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input, count, input_data); - - UnifiedVectorFormat states_data; - state_vector.ToUnifiedFormat(count, states_data); - auto states = UnifiedVectorFormat::GetData(states_data); - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - - for (idx_t i = 0; i < count; i++) { - auto &state = *states[states_data.sel->get_index(i)]; - aggr_input_data.allocator.AlignNext(); - list_bind_data.functions.AppendRow(aggr_input_data.allocator, state.linked_list, input_data, i); - } -} - -static void ListAbsorbFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, - idx_t count) { - D_ASSERT(aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE); - - UnifiedVectorFormat states_data; - states_vector.ToUnifiedFormat(count, states_data); - auto states_ptr = UnifiedVectorFormat::GetData(states_data); - - auto combined_ptr = FlatVector::GetData(combined); - for (idx_t i = 0; i < count; i++) { - - auto &state = *states_ptr[states_data.sel->get_index(i)]; - if (state.linked_list.total_capacity == 0) { - // NULL, no need to append - // this can happen when adding a FILTER to the grouping, e.g., - // LIST(i) FILTER (WHERE i <> 3) - continue; - } - - if (combined_ptr[i]->linked_list.total_capacity == 0) { - combined_ptr[i]->linked_list = state.linked_list; - continue; - } - - // append the linked list - combined_ptr[i]->linked_list.last_segment->next = state.linked_list.first_segment; - combined_ptr[i]->linked_list.last_segment = state.linked_list.last_segment; - combined_ptr[i]->linked_list.total_capacity += state.linked_list.total_capacity; - } -} - -static void ListFinalize(Vector &states_vector, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - - UnifiedVectorFormat states_data; - states_vector.ToUnifiedFormat(count, states_data); - auto states = UnifiedVectorFormat::GetData(states_data); - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - auto &mask = FlatVector::Validity(result); - auto result_data = FlatVector::GetData(result); - size_t total_len = ListVector::GetListSize(result); - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - - // first iterate over all entries and set up the list entries, and get the newly required total length - for (idx_t i = 0; i < count; i++) { - - auto &state = *states[states_data.sel->get_index(i)]; - const auto rid = i + offset; - result_data[rid].offset = total_len; - if (state.linked_list.total_capacity == 0) { - mask.SetInvalid(rid); - result_data[rid].length = 0; - continue; - } - - // set the length and offset of this list in the result vector - 
auto total_capacity = state.linked_list.total_capacity; - result_data[rid].length = total_capacity; - total_len += total_capacity; - } - - // reserve capacity, then iterate over all entries again and copy over the data to the child vector - ListVector::Reserve(result, total_len); - auto &result_child = ListVector::GetEntry(result); - for (idx_t i = 0; i < count; i++) { - - auto &state = *states[states_data.sel->get_index(i)]; - const auto rid = i + offset; - if (state.linked_list.total_capacity == 0) { - continue; - } - - idx_t current_offset = result_data[rid].offset; - list_bind_data.functions.BuildListVector(state.linked_list, result_child, current_offset); - } - - ListVector::SetListSize(result, total_len); -} - -static void ListCombineFunction(Vector &states_vector, Vector &combined, AggregateInputData &aggr_input_data, - idx_t count) { - - // Can we use destructive combining? - if (aggr_input_data.combine_type == AggregateCombineType::ALLOW_DESTRUCTIVE) { - ListAbsorbFunction(states_vector, combined, aggr_input_data, count); - return; - } - - UnifiedVectorFormat states_data; - states_vector.ToUnifiedFormat(count, states_data); - auto states_ptr = UnifiedVectorFormat::GetData(states_data); - auto combined_ptr = FlatVector::GetData(combined); - - auto &list_bind_data = aggr_input_data.bind_data->Cast(); - auto result_type = ListType::GetChildType(list_bind_data.stype); - - for (idx_t i = 0; i < count; i++) { - auto &source = *states_ptr[states_data.sel->get_index(i)]; - auto &target = *combined_ptr[i]; - - const auto entry_count = source.linked_list.total_capacity; - Vector input(result_type, source.linked_list.total_capacity); - list_bind_data.functions.BuildListVector(source.linked_list, input, 0); - - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input, entry_count, input_data); - - for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { - aggr_input_data.allocator.AlignNext(); - list_bind_data.functions.AppendRow(aggr_input_data.allocator, target.linked_list, input_data, entry_idx); - } - } -} - -unique_ptr ListBindFunction(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - D_ASSERT(arguments.size() == 1); - D_ASSERT(function.arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - function.arguments[0] = LogicalTypeId::UNKNOWN; - function.return_type = LogicalType::SQLNULL; - return nullptr; - } - - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return make_uniq(function.return_type); -} - -AggregateFunction ListFun::GetFunction() { - auto func = - AggregateFunction({LogicalType::ANY}, LogicalTypeId::LIST, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, ListUpdateFunction, - ListCombineFunction, ListFinalize, nullptr, ListBindFunction, nullptr, nullptr, nullptr); - - return func; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp deleted file mode 100644 index b4b43af21..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_avg.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "core_functions/aggregate/regression_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { -struct RegrState { - 
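// running sum of the averaged coordinate and the count of accepted (non-NULL) pairs -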
double sum; - size_t count; -}; - -struct RegrAvgFunction { - template - static void Initialize(STATE &state) { - state.sum = 0; - state.count = 0; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target.sum += source.sum; - target.count += source.count; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - } else { - target = state.sum / (double)state.count; - } - } - static bool IgnoreNull() { - return true; - } -}; -struct RegrAvgXFunction : RegrAvgFunction { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.sum += x; - state.count++; - } -}; - -struct RegrAvgYFunction : RegrAvgFunction { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.sum += y; - state.count++; - } -}; - -AggregateFunction RegrAvgxFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction RegrAvgyFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp deleted file mode 100644 index 9215fcfb8..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_count.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "core_functions/aggregate/regression_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "core_functions/aggregate/regression/regr_count.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -AggregateFunction RegrCountFun::GetFunction() { - auto regr_count = AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::UINTEGER); - regr_count.name = "regr_count"; - regr_count.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return regr_count; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp deleted file mode 100644 index e727d2669..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_intercept.cpp +++ /dev/null @@ -1,67 +0,0 @@ -//! 
AVG(y)-REGR_SLOPE(y,x)*AVG(x) - -#include "core_functions/aggregate/regression_functions.hpp" -#include "core_functions/aggregate/regression/regr_slope.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct RegrInterceptState { - size_t count; - double sum_x; - double sum_y; - RegrSlopeState slope; -}; - -struct RegrInterceptOperation { - template - static void Initialize(STATE &state) { - state.count = 0; - state.sum_x = 0; - state.sum_y = 0; - RegrSlopeOperation::Initialize(state.slope); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - state.count++; - state.sum_x += x; - state.sum_y += y; - RegrSlopeOperation::Operation(state.slope, y, x, idata); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - target.count += source.count; - target.sum_x += source.sum_x; - target.sum_y += source.sum_y; - RegrSlopeOperation::Combine(source.slope, target.slope, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.count == 0) { - finalize_data.ReturnNull(); - return; - } - RegrSlopeOperation::Finalize(state.slope, target, finalize_data); - if (Value::IsNan(target)) { - finalize_data.ReturnNull(); - return; - } - auto x_avg = state.sum_x / state.count; - auto y_avg = state.sum_y / state.count; - target = y_avg - target * x_avg; - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrInterceptFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp deleted file mode 100644 index ba89a8a6e..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_r2.cpp +++ /dev/null @@ -1,72 +0,0 @@ -// REGR_R2(y, x) -// Returns the coefficient of determination for non-null pairs in a group. 
-// It is computed for non-null pairs using the following formula: -// null if var_pop(x) = 0, else -// 1 if var_pop(y) = 0 and var_pop(x) <> 0, else -// power(corr(y,x), 2) - -#include "core_functions/aggregate/algebraic/corr.hpp" -#include "duckdb/function/function_set.hpp" -#include "core_functions/aggregate/regression_functions.hpp" - -namespace duckdb { -struct RegrR2State { - CorrState corr; - StddevState var_pop_x; - StddevState var_pop_y; -}; - -struct RegrR2Operation { - template - static void Initialize(STATE &state) { - CorrOperation::Initialize(state.corr); - STDDevBaseOperation::Initialize(state.var_pop_x); - STDDevBaseOperation::Initialize(state.var_pop_y); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - CorrOperation::Operation(state.corr, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop_x, x); - STDDevBaseOperation::Execute(state.var_pop_y, y); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - CorrOperation::Combine(source.corr, target.corr, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop_x, target.var_pop_x, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop_y, target.var_pop_y, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - auto var_pop_x = state.var_pop_x.count > 1 ? (state.var_pop_x.dsquared / state.var_pop_x.count) : 0; - if (!Value::DoubleIsFinite(var_pop_x)) { - throw OutOfRangeException("VARPOP(X) is out of range!"); - } - if (var_pop_x == 0) { - finalize_data.ReturnNull(); - return; - } - auto var_pop_y = state.var_pop_y.count > 1 ? (state.var_pop_y.dsquared / state.var_pop_y.count) : 0; - if (!Value::DoubleIsFinite(var_pop_y)) { - throw OutOfRangeException("VARPOP(Y) is out of range!"); - } - if (var_pop_y == 0) { - target = 1; - return; - } - CorrOperation::Finalize(state.corr, target, finalize_data); - target = pow(target, 2); - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrR2Fun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp deleted file mode 100644 index c58593990..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_slope.cpp +++ /dev/null @@ -1,20 +0,0 @@ -// REGR_SLOPE(y, x) -// Returns the slope of the linear regression line for non-null pairs in a group. -// It is computed for non-null pairs using the following formula: -// COVAR_POP(x,y) / VAR_POP(x) - -//! Input : Any numeric type -//! 
Output : Double - -#include "core_functions/aggregate/regression/regr_slope.hpp" -#include "duckdb/function/function_set.hpp" -#include "core_functions/aggregate/regression_functions.hpp" - -namespace duckdb { - -AggregateFunction RegrSlopeFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp deleted file mode 100644 index 72202c2be..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxx_syy.cpp +++ /dev/null @@ -1,75 +0,0 @@ -// REGR_SXX(y, x) -// Returns REGR_COUNT(y, x) * VAR_POP(x) for non-null pairs. -// REGR_SYY(y, x) -// Returns REGR_COUNT(y, x) * VAR_POP(y) for non-null pairs. - -#include "core_functions/aggregate/regression/regr_count.hpp" -#include "duckdb/function/function_set.hpp" -#include "core_functions/aggregate/regression_functions.hpp" - -namespace duckdb { - -struct RegrSState { - size_t count; - StddevState var_pop; -}; - -struct RegrBaseOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - STDDevBaseOperation::Initialize(state.var_pop); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - RegrCountFunction::Combine(source.count, target.count, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.var_pop.count == 0) { - finalize_data.ReturnNull(); - return; - } - auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; - if (!Value::DoubleIsFinite(var_pop)) { - throw OutOfRangeException("VARPOP is out of range!"); - } - RegrCountFunction::Finalize(state.count, target, finalize_data); - target *= var_pop; - } - - static bool IgnoreNull() { - return true; - } -}; - -struct RegrSXXOperation : RegrBaseOperation { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop, x); - } -}; - -struct RegrSYYOperation : RegrBaseOperation { - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop, y); - } -}; - -AggregateFunction RegrSXXFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -AggregateFunction RegrSYYFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp b/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp deleted file mode 100644 index 1ab726e82..000000000 --- a/src/duckdb/extension/core_functions/aggregate/regression/regr_sxy.cpp +++ /dev/null @@ -1,53 +0,0 @@ -// REGR_SXY(y, x) -// Returns REGR_COUNT(expr1, expr2) * COVAR_POP(expr1, expr2) for non-null pairs. 
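-// Equivalently, the sum of cross-deviations SUM((x - AVG(x)) * (y - AVG(y))); e.g. for the pairs (1,1), (2,2), (3,3) it returns 2.0.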
- -#include "core_functions/aggregate/regression/regr_count.hpp" -#include "core_functions/aggregate/algebraic/covar.hpp" -#include "core_functions/aggregate/regression_functions.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct RegrSXyState { - size_t count; - CovarState cov_pop; -}; - -struct RegrSXYOperation { - template - static void Initialize(STATE &state) { - RegrCountFunction::Initialize(state.count); - CovarOperation::Initialize(state.cov_pop); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - RegrCountFunction::Operation(state.count, y, x, idata); - CovarOperation::Operation(state.cov_pop, y, x, idata); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); - RegrCountFunction::Combine(source.count, target.count, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - CovarPopOperation::Finalize(state.cov_pop, target, finalize_data); - auto cov_pop = target; - RegrCountFunction::Finalize(state.count, target, finalize_data); - target *= cov_pop; - } - - static bool IgnoreNull() { - return true; - } -}; - -AggregateFunction RegrSXYFun::GetFunction() { - return AggregateFunction::BinaryAggregate( - LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/core_functions_extension.cpp b/src/duckdb/extension/core_functions/core_functions_extension.cpp deleted file mode 100644 index 8bf09b800..000000000 --- a/src/duckdb/extension/core_functions/core_functions_extension.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#define DUCKDB_EXTENSION_MAIN -#include "core_functions_extension.hpp" - -#include "core_functions/function_list.hpp" -#include "duckdb/main/extension_util.hpp" -#include "duckdb/function/register_function_list_helper.hpp" -#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" - -namespace duckdb { - -template -static void FillExtraInfo(const StaticFunctionDefinition &function, T &info) { - info.internal = true; - FillFunctionDescriptions(function, info); - info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; -} - -void LoadInternal(DuckDB &db) { - auto functions = StaticFunctionDefinition::GetFunctionList(); - for (idx_t i = 0; functions[i].name; i++) { - auto &function = functions[i]; - if (function.get_function || function.get_function_set) { - // scalar function - ScalarFunctionSet result; - if (function.get_function) { - result.AddFunction(function.get_function()); - } else { - result = function.get_function_set(); - } - result.name = function.name; - CreateScalarFunctionInfo info(result); - FillExtraInfo(function, info); - ExtensionUtil::RegisterFunction(*db.instance, std::move(info)); - } else if (function.get_aggregate_function || function.get_aggregate_function_set) { - // aggregate function - AggregateFunctionSet result; - if (function.get_aggregate_function) { - result.AddFunction(function.get_aggregate_function()); - } else { - result = function.get_aggregate_function_set(); - } - result.name = function.name; - CreateAggregateFunctionInfo info(result); - FillExtraInfo(function, info); - ExtensionUtil::RegisterFunction(*db.instance, std::move(info)); - } else { - throw InternalException("Do not know 
how to register function of this type"); - } - } -} - -void CoreFunctionsExtension::Load(DuckDB &db) { - LoadInternal(db); -} - -std::string CoreFunctionsExtension::Name() { - return "core_functions"; -} - -std::string CoreFunctionsExtension::Version() const { -#ifdef EXT_VERSION_CORE_FUNCTIONS - return EXT_VERSION_CORE_FUNCTIONS; -#else - return ""; -#endif -} - -} // namespace duckdb - -extern "C" { - -DUCKDB_EXTENSION_API void core_functions_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - duckdb::LoadInternal(db_wrapper); -} - -DUCKDB_EXTENSION_API const char *core_functions_version() { - return duckdb::DuckDB::LibraryVersion(); -} -} - -#ifndef DUCKDB_EXTENSION_MAIN -#error DUCKDB_EXTENSION_MAIN not defined -#endif diff --git a/src/duckdb/extension/core_functions/function_list.cpp b/src/duckdb/extension/core_functions/function_list.cpp deleted file mode 100644 index 53d96feb0..000000000 --- a/src/duckdb/extension/core_functions/function_list.cpp +++ /dev/null @@ -1,407 +0,0 @@ -#include "core_functions/function_list.hpp" -#include "core_functions/aggregate/algebraic_functions.hpp" -#include "core_functions/aggregate/distributive_functions.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" -#include "core_functions/aggregate/nested_functions.hpp" -#include "core_functions/aggregate/regression_functions.hpp" -#include "core_functions/scalar/bit_functions.hpp" -#include "core_functions/scalar/blob_functions.hpp" -#include "core_functions/scalar/date_functions.hpp" -#include "core_functions/scalar/enum_functions.hpp" -#include "core_functions/scalar/generic_functions.hpp" -#include "core_functions/scalar/list_functions.hpp" -#include "core_functions/scalar/map_functions.hpp" -#include "core_functions/scalar/math_functions.hpp" -#include "core_functions/scalar/operators_functions.hpp" -#include "core_functions/scalar/random_functions.hpp" -#include "core_functions/scalar/secret_functions.hpp" -#include "core_functions/scalar/string_functions.hpp" -#include "core_functions/scalar/struct_functions.hpp" -#include "core_functions/scalar/union_functions.hpp" -#include "core_functions/scalar/array_functions.hpp" -#include "core_functions/scalar/debug_functions.hpp" - -namespace duckdb { - -// Scalar Function -#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::GetFunction, nullptr, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Scalar Function Set -#define DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, _PARAM::GetFunctions, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function -#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, _PARAM::GetFunction, nullptr } -#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function Set -#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { 
_NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, nullptr, _PARAM::GetFunctions } -#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -#define FINAL_FUNCTION \ - { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } - -// this list is generated by scripts/generate_functions.py -static const StaticFunctionDefinition core_functions[] = { - DUCKDB_SCALAR_FUNCTION(FactorialOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseAndFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAnyFunAlias), - DUCKDB_SCALAR_FUNCTION(PowOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDistanceFunAlias), - DUCKDB_SCALAR_FUNCTION_SET(LeftShiftFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListCosineDistanceFunAlias), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias2), - DUCKDB_SCALAR_FUNCTION_SET(RightShiftFun), - DUCKDB_SCALAR_FUNCTION_SET(AbsOperatorFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasAllFunAlias), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowOperatorFunAlias), - DUCKDB_SCALAR_FUNCTION(StartsWithOperatorFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AbsFun), - DUCKDB_SCALAR_FUNCTION(AcosFun), - DUCKDB_SCALAR_FUNCTION(AcoshFun), - DUCKDB_SCALAR_FUNCTION_SET(AgeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(AggregateFun), - DUCKDB_SCALAR_FUNCTION(AliasFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ApplyFun), - DUCKDB_AGGREGATE_FUNCTION(ApproxCountDistinctFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ApproxQuantileFun), - DUCKDB_AGGREGATE_FUNCTION(ApproxTopKFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMaxNullFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ArgMinNullFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgmaxFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArgminFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(ArrayAggFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggrFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayAggregateFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayApplyFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineDistanceFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayCosineSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayCrossProductFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayDistanceFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayDistinctFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayDotProductFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayFilterFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayGradeUpFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAllFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasAnyFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayInnerProductFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayNegativeDotProductFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayNegativeInnerProductFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayReduceFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayReverseSortFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySliceFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArraySortFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayTransformFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayUniqueFun), - DUCKDB_SCALAR_FUNCTION(ArrayValueFun), - DUCKDB_SCALAR_FUNCTION(ASCIIFun), - DUCKDB_SCALAR_FUNCTION(AsinFun), - DUCKDB_SCALAR_FUNCTION(AsinhFun), - DUCKDB_SCALAR_FUNCTION(AtanFun), - DUCKDB_SCALAR_FUNCTION(Atan2Fun), - DUCKDB_SCALAR_FUNCTION(AtanhFun), - DUCKDB_AGGREGATE_FUNCTION_SET(AvgFun), - DUCKDB_SCALAR_FUNCTION_SET(BarFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(Base64Fun), - DUCKDB_SCALAR_FUNCTION_SET(BinFun), - 
DUCKDB_AGGREGATE_FUNCTION_SET(BitAndFun), - DUCKDB_SCALAR_FUNCTION_SET(BitCountFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitOrFun), - DUCKDB_SCALAR_FUNCTION(BitPositionFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun), - DUCKDB_SCALAR_FUNCTION_SET(BitStringFun), - DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun), - DUCKDB_AGGREGATE_FUNCTION(BoolAndFun), - DUCKDB_AGGREGATE_FUNCTION(BoolOrFun), - DUCKDB_SCALAR_FUNCTION(CanCastImplicitlyFun), - DUCKDB_SCALAR_FUNCTION(CardinalityFun), - DUCKDB_SCALAR_FUNCTION(CbrtFun), - DUCKDB_SCALAR_FUNCTION_SET(CeilFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(CeilingFun), - DUCKDB_SCALAR_FUNCTION_SET(CenturyFun), - DUCKDB_SCALAR_FUNCTION(ChrFun), - DUCKDB_AGGREGATE_FUNCTION(CorrFun), - DUCKDB_SCALAR_FUNCTION(CosFun), - DUCKDB_SCALAR_FUNCTION(CoshFun), - DUCKDB_SCALAR_FUNCTION(CotFun), - DUCKDB_AGGREGATE_FUNCTION(CountIfFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(CountifFun), - DUCKDB_AGGREGATE_FUNCTION(CovarPopFun), - DUCKDB_AGGREGATE_FUNCTION(CovarSampFun), - DUCKDB_SCALAR_FUNCTION(CurrentDatabaseFun), - DUCKDB_SCALAR_FUNCTION(CurrentQueryFun), - DUCKDB_SCALAR_FUNCTION(CurrentSchemaFun), - DUCKDB_SCALAR_FUNCTION(CurrentSchemasFun), - DUCKDB_SCALAR_FUNCTION(CurrentSettingFun), - DUCKDB_SCALAR_FUNCTION(DamerauLevenshteinFun), - DUCKDB_SCALAR_FUNCTION_SET(DateDiffFun), - DUCKDB_SCALAR_FUNCTION_SET(DatePartFun), - DUCKDB_SCALAR_FUNCTION_SET(DateSubFun), - DUCKDB_SCALAR_FUNCTION_SET(DateTruncFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatediffFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatepartFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatesubFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DatetruncFun), - DUCKDB_SCALAR_FUNCTION_SET(DayFun), - DUCKDB_SCALAR_FUNCTION_SET(DayNameFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfMonthFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(DayOfYearFun), - DUCKDB_SCALAR_FUNCTION_SET(DecadeFun), - DUCKDB_SCALAR_FUNCTION(DecodeFun), - DUCKDB_SCALAR_FUNCTION(DegreesFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(Editdist3Fun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ElementAtFun), - DUCKDB_SCALAR_FUNCTION(EncodeFun), - DUCKDB_AGGREGATE_FUNCTION_SET(EntropyFun), - DUCKDB_SCALAR_FUNCTION(EnumCodeFun), - DUCKDB_SCALAR_FUNCTION(EnumFirstFun), - DUCKDB_SCALAR_FUNCTION(EnumLastFun), - DUCKDB_SCALAR_FUNCTION(EnumRangeFun), - DUCKDB_SCALAR_FUNCTION(EnumRangeBoundaryFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochMsFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochNsFun), - DUCKDB_SCALAR_FUNCTION_SET(EpochUsFun), - DUCKDB_SCALAR_FUNCTION_SET(EquiWidthBinsFun), - DUCKDB_SCALAR_FUNCTION_SET(EraFun), - DUCKDB_SCALAR_FUNCTION(EvenFun), - DUCKDB_SCALAR_FUNCTION(ExpFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FactorialFun), - DUCKDB_AGGREGATE_FUNCTION(FAvgFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FilterFun), - DUCKDB_SCALAR_FUNCTION(ListFlattenFun), - DUCKDB_SCALAR_FUNCTION_SET(FloorFun), - DUCKDB_SCALAR_FUNCTION(FormatFun), - DUCKDB_SCALAR_FUNCTION(FormatreadabledecimalsizeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FormatreadablesizeFun), - DUCKDB_SCALAR_FUNCTION(FormatBytesFun), - DUCKDB_SCALAR_FUNCTION(FromBase64Fun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FromBinaryFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(FromHexFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(FsumFun), - DUCKDB_SCALAR_FUNCTION(GammaFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GcdFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(GenRandomUuidFun), - DUCKDB_SCALAR_FUNCTION_SET(GenerateSeriesFun), - DUCKDB_SCALAR_FUNCTION(GetBitFun), - DUCKDB_SCALAR_FUNCTION(GetCurrentTimestampFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(GradeUpFun), - 
DUCKDB_SCALAR_FUNCTION_SET(GreatestFun), - DUCKDB_SCALAR_FUNCTION_SET(GreatestCommonDivisorFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(GroupConcatFun), - DUCKDB_SCALAR_FUNCTION(HammingFun), - DUCKDB_SCALAR_FUNCTION(HashFun), - DUCKDB_SCALAR_FUNCTION_SET(HexFun), - DUCKDB_AGGREGATE_FUNCTION_SET(HistogramFun), - DUCKDB_AGGREGATE_FUNCTION(HistogramExactFun), - DUCKDB_SCALAR_FUNCTION_SET(HoursFun), - DUCKDB_SCALAR_FUNCTION(InSearchPathFun), - DUCKDB_SCALAR_FUNCTION(InstrFun), - DUCKDB_SCALAR_FUNCTION(IsHistogramOtherBinFun), - DUCKDB_SCALAR_FUNCTION_SET(IsFiniteFun), - DUCKDB_SCALAR_FUNCTION_SET(IsInfiniteFun), - DUCKDB_SCALAR_FUNCTION_SET(IsNanFun), - DUCKDB_SCALAR_FUNCTION_SET(ISODayOfWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(ISOYearFun), - DUCKDB_SCALAR_FUNCTION(JaccardFun), - DUCKDB_SCALAR_FUNCTION_SET(JaroSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(JaroWinklerSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(JulianDayFun), - DUCKDB_AGGREGATE_FUNCTION(KahanSumFun), - DUCKDB_AGGREGATE_FUNCTION(KurtosisFun), - DUCKDB_AGGREGATE_FUNCTION(KurtosisPopFun), - DUCKDB_SCALAR_FUNCTION_SET(LastDayFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LcmFun), - DUCKDB_SCALAR_FUNCTION_SET(LeastFun), - DUCKDB_SCALAR_FUNCTION_SET(LeastCommonMultipleFun), - DUCKDB_SCALAR_FUNCTION(LeftFun), - DUCKDB_SCALAR_FUNCTION(LeftGraphemeFun), - DUCKDB_SCALAR_FUNCTION(LevenshteinFun), - DUCKDB_SCALAR_FUNCTION(LogGammaFun), - DUCKDB_AGGREGATE_FUNCTION(ListFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListAggrFun), - DUCKDB_SCALAR_FUNCTION(ListAggregateFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListApplyFun), - DUCKDB_SCALAR_FUNCTION_SET(ListCosineDistanceFun), - DUCKDB_SCALAR_FUNCTION_SET(ListCosineSimilarityFun), - DUCKDB_SCALAR_FUNCTION_SET(ListDistanceFun), - DUCKDB_SCALAR_FUNCTION(ListDistinctFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListDotProductFun), - DUCKDB_SCALAR_FUNCTION(ListFilterFun), - DUCKDB_SCALAR_FUNCTION_SET(ListGradeUpFun), - DUCKDB_SCALAR_FUNCTION(ListHasAllFun), - DUCKDB_SCALAR_FUNCTION(ListHasAnyFun), - DUCKDB_SCALAR_FUNCTION_SET(ListInnerProductFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListNegativeDotProductFun), - DUCKDB_SCALAR_FUNCTION_SET(ListNegativeInnerProductFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListPackFun), - DUCKDB_SCALAR_FUNCTION(ListReduceFun), - DUCKDB_SCALAR_FUNCTION_SET(ListReverseSortFun), - DUCKDB_SCALAR_FUNCTION_SET(ListSliceFun), - DUCKDB_SCALAR_FUNCTION_SET(ListSortFun), - DUCKDB_SCALAR_FUNCTION(ListTransformFun), - DUCKDB_SCALAR_FUNCTION(ListUniqueFun), - DUCKDB_SCALAR_FUNCTION(ListValueFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ListaggFun), - DUCKDB_SCALAR_FUNCTION(LnFun), - DUCKDB_SCALAR_FUNCTION_SET(LogFun), - DUCKDB_SCALAR_FUNCTION(Log10Fun), - DUCKDB_SCALAR_FUNCTION(Log2Fun), - DUCKDB_SCALAR_FUNCTION(LpadFun), - DUCKDB_SCALAR_FUNCTION_SET(LtrimFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MadFun), - DUCKDB_SCALAR_FUNCTION_SET(MakeDateFun), - DUCKDB_SCALAR_FUNCTION(MakeTimeFun), - DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampFun), - DUCKDB_SCALAR_FUNCTION_SET(MakeTimestampNsFun), - DUCKDB_SCALAR_FUNCTION(MapFun), - DUCKDB_SCALAR_FUNCTION(MapConcatFun), - DUCKDB_SCALAR_FUNCTION(MapEntriesFun), - DUCKDB_SCALAR_FUNCTION(MapExtractFun), - DUCKDB_SCALAR_FUNCTION(MapFromEntriesFun), - DUCKDB_SCALAR_FUNCTION(MapKeysFun), - DUCKDB_SCALAR_FUNCTION(MapValuesFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MaxByFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MeanFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MedianFun), - DUCKDB_SCALAR_FUNCTION_SET(MicrosecondsFun), - DUCKDB_SCALAR_FUNCTION_SET(MillenniumFun), - 
DUCKDB_SCALAR_FUNCTION_SET(MillisecondsFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(MinByFun), - DUCKDB_SCALAR_FUNCTION_SET(MinutesFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(MismatchesFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ModeFun), - DUCKDB_SCALAR_FUNCTION_SET(MonthFun), - DUCKDB_SCALAR_FUNCTION_SET(MonthNameFun), - DUCKDB_SCALAR_FUNCTION_SET(NanosecondsFun), - DUCKDB_SCALAR_FUNCTION_SET(NextAfterFun), - DUCKDB_SCALAR_FUNCTION(NormalizedIntervalFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(NowFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(OrdFun), - DUCKDB_SCALAR_FUNCTION_SET(ParseDirnameFun), - DUCKDB_SCALAR_FUNCTION_SET(ParseDirpathFun), - DUCKDB_SCALAR_FUNCTION_SET(ParseFilenameFun), - DUCKDB_SCALAR_FUNCTION_SET(ParsePathFun), - DUCKDB_SCALAR_FUNCTION(PiFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PositionFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(PowerFun), - DUCKDB_SCALAR_FUNCTION(PrintfFun), - DUCKDB_AGGREGATE_FUNCTION(ProductFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(QuantileFun), - DUCKDB_AGGREGATE_FUNCTION_SET(QuantileContFun), - DUCKDB_AGGREGATE_FUNCTION_SET(QuantileDiscFun), - DUCKDB_SCALAR_FUNCTION_SET(QuarterFun), - DUCKDB_SCALAR_FUNCTION(RadiansFun), - DUCKDB_SCALAR_FUNCTION(RandomFun), - DUCKDB_SCALAR_FUNCTION_SET(ListRangeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ReduceFun), - DUCKDB_AGGREGATE_FUNCTION(RegrAvgxFun), - DUCKDB_AGGREGATE_FUNCTION(RegrAvgyFun), - DUCKDB_AGGREGATE_FUNCTION(RegrCountFun), - DUCKDB_AGGREGATE_FUNCTION(RegrInterceptFun), - DUCKDB_AGGREGATE_FUNCTION(RegrR2Fun), - DUCKDB_AGGREGATE_FUNCTION(RegrSlopeFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSXXFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSXYFun), - DUCKDB_AGGREGATE_FUNCTION(RegrSYYFun), - DUCKDB_SCALAR_FUNCTION_SET(RepeatFun), - DUCKDB_SCALAR_FUNCTION(ReplaceFun), - DUCKDB_AGGREGATE_FUNCTION_SET(ReservoirQuantileFun), - DUCKDB_SCALAR_FUNCTION(ReverseFun), - DUCKDB_SCALAR_FUNCTION(RightFun), - DUCKDB_SCALAR_FUNCTION(RightGraphemeFun), - DUCKDB_SCALAR_FUNCTION_SET(RoundFun), - DUCKDB_SCALAR_FUNCTION(RpadFun), - DUCKDB_SCALAR_FUNCTION_SET(RtrimFun), - DUCKDB_SCALAR_FUNCTION_SET(SecondsFun), - DUCKDB_AGGREGATE_FUNCTION(StandardErrorOfTheMeanFun), - DUCKDB_SCALAR_FUNCTION(SetBitFun), - DUCKDB_SCALAR_FUNCTION(SetseedFun), - DUCKDB_SCALAR_FUNCTION_SET(SignFun), - DUCKDB_SCALAR_FUNCTION_SET(SignBitFun), - DUCKDB_SCALAR_FUNCTION(SinFun), - DUCKDB_SCALAR_FUNCTION(SinhFun), - DUCKDB_AGGREGATE_FUNCTION(SkewnessFun), - DUCKDB_SCALAR_FUNCTION(SqrtFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StartsWithFun), - DUCKDB_SCALAR_FUNCTION(StatsFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(StddevFun), - DUCKDB_AGGREGATE_FUNCTION(StdDevPopFun), - DUCKDB_AGGREGATE_FUNCTION(StdDevSampFun), - DUCKDB_AGGREGATE_FUNCTION_SET(StringAggFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StrposFun), - DUCKDB_SCALAR_FUNCTION(StructInsertFun), - DUCKDB_AGGREGATE_FUNCTION_SET(SumFun), - DUCKDB_AGGREGATE_FUNCTION_SET(SumNoOverflowFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(SumkahanFun), - DUCKDB_SCALAR_FUNCTION(TanFun), - DUCKDB_SCALAR_FUNCTION(TanhFun), - DUCKDB_SCALAR_FUNCTION_SET(TimeBucketFun), - DUCKDB_SCALAR_FUNCTION(TimeTZSortKeyFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneHourFun), - DUCKDB_SCALAR_FUNCTION_SET(TimezoneMinuteFun), - DUCKDB_SCALAR_FUNCTION_SET(ToBaseFun), - DUCKDB_SCALAR_FUNCTION(ToBase64Fun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToBinaryFun), - DUCKDB_SCALAR_FUNCTION(ToCenturiesFun), - DUCKDB_SCALAR_FUNCTION(ToDaysFun), - DUCKDB_SCALAR_FUNCTION(ToDecadesFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ToHexFun), - 
DUCKDB_SCALAR_FUNCTION(ToHoursFun), - DUCKDB_SCALAR_FUNCTION(ToMicrosecondsFun), - DUCKDB_SCALAR_FUNCTION(ToMillenniaFun), - DUCKDB_SCALAR_FUNCTION(ToMillisecondsFun), - DUCKDB_SCALAR_FUNCTION(ToMinutesFun), - DUCKDB_SCALAR_FUNCTION(ToMonthsFun), - DUCKDB_SCALAR_FUNCTION(ToQuartersFun), - DUCKDB_SCALAR_FUNCTION(ToSecondsFun), - DUCKDB_SCALAR_FUNCTION(ToTimestampFun), - DUCKDB_SCALAR_FUNCTION(ToWeeksFun), - DUCKDB_SCALAR_FUNCTION(ToYearsFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(TransactionTimestampFun), - DUCKDB_SCALAR_FUNCTION(TranslateFun), - DUCKDB_SCALAR_FUNCTION_SET(TrimFun), - DUCKDB_SCALAR_FUNCTION_SET(TruncFun), - DUCKDB_SCALAR_FUNCTION(CurrentTransactionIdFun), - DUCKDB_SCALAR_FUNCTION(TypeOfFun), - DUCKDB_SCALAR_FUNCTION(UnbinFun), - DUCKDB_SCALAR_FUNCTION(UnhexFun), - DUCKDB_SCALAR_FUNCTION(UnicodeFun), - DUCKDB_SCALAR_FUNCTION(UnionExtractFun), - DUCKDB_SCALAR_FUNCTION(UnionTagFun), - DUCKDB_SCALAR_FUNCTION(UnionValueFun), - DUCKDB_SCALAR_FUNCTION(UnpivotListFun), - DUCKDB_SCALAR_FUNCTION(UrlDecodeFun), - DUCKDB_SCALAR_FUNCTION(UrlEncodeFun), - DUCKDB_SCALAR_FUNCTION(UUIDFun), - DUCKDB_AGGREGATE_FUNCTION(VarPopFun), - DUCKDB_AGGREGATE_FUNCTION(VarSampFun), - DUCKDB_AGGREGATE_FUNCTION_ALIAS(VarianceFun), - DUCKDB_SCALAR_FUNCTION(VectorTypeFun), - DUCKDB_SCALAR_FUNCTION(VersionFun), - DUCKDB_SCALAR_FUNCTION_SET(WeekFun), - DUCKDB_SCALAR_FUNCTION_SET(WeekDayFun), - DUCKDB_SCALAR_FUNCTION_SET(WeekOfYearFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseXorFun), - DUCKDB_SCALAR_FUNCTION_SET(YearFun), - DUCKDB_SCALAR_FUNCTION_SET(YearWeekFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseOrFun), - DUCKDB_SCALAR_FUNCTION_SET(BitwiseNotFun), - FINAL_FUNCTION -}; - -const StaticFunctionDefinition *StaticFunctionDefinition::GetFunctionList() { - return core_functions; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp deleted file mode 100644 index 05cdfb145..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/corr.hpp +++ /dev/null @@ -1,70 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/algebraic/corr.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/aggregate_function.hpp" -#include "core_functions/aggregate/algebraic/covar.hpp" -#include "core_functions/aggregate/algebraic/stddev.hpp" - -namespace duckdb { - -struct CorrState { - CovarState cov_pop; - StddevState dev_pop_x; - StddevState dev_pop_y; -}; - -// Returns the correlation coefficient for non-null pairs in a group. 
-// CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))
-struct CorrOperation {
-	template <class STATE>
-	static void Initialize(STATE &state) {
-		CovarOperation::Initialize(state.cov_pop);
-		STDDevBaseOperation::Initialize(state.dev_pop_x);
-		STDDevBaseOperation::Initialize(state.dev_pop_y);
-	}
-
-	template <class A_TYPE, class B_TYPE, class STATE, class OP>
-	static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) {
-		CovarOperation::Operation<A_TYPE, B_TYPE, CovarState, OP>(state.cov_pop, y, x, idata);
-		STDDevBaseOperation::Execute(state.dev_pop_x, x);
-		STDDevBaseOperation::Execute(state.dev_pop_y, y);
-	}
-
-	template <class STATE, class OP>
-	static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) {
-		CovarOperation::Combine<CovarState, OP>(source.cov_pop, target.cov_pop, aggr_input_data);
-		STDDevBaseOperation::Combine<StddevState, OP>(source.dev_pop_x, target.dev_pop_x, aggr_input_data);
-		STDDevBaseOperation::Combine<StddevState, OP>(source.dev_pop_y, target.dev_pop_y, aggr_input_data);
-	}
-
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.cov_pop.count == 0 || state.dev_pop_x.count == 0 || state.dev_pop_y.count == 0) {
-			finalize_data.ReturnNull();
-		} else {
-			auto cov = state.cov_pop.co_moment / state.cov_pop.count;
-			auto std_x = state.dev_pop_x.count > 1 ? sqrt(state.dev_pop_x.dsquared / state.dev_pop_x.count) : 0;
-			if (!Value::DoubleIsFinite(std_x)) {
-				throw OutOfRangeException("STDDEV_POP for X is out of range!");
-			}
-			auto std_y = state.dev_pop_y.count > 1 ? sqrt(state.dev_pop_y.dsquared / state.dev_pop_y.count) : 0;
-			if (!Value::DoubleIsFinite(std_y)) {
-				throw OutOfRangeException("STDDEV_POP for Y is out of range!");
-			}
-			target = std_x * std_y != 0 ? cov / (std_x * std_y) : NAN;
-		}
-	}
-
-	static bool IgnoreNull() {
-		return true;
-	}
-};
-
-} // namespace duckdb
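For reference, the Finalize above just normalizes the accumulated co-moment by the two population standard deviations. The following self-contained sketch (illustrative names only; MiniState, update and finalize_corr are not part of the deleted headers) folds the three sub-states into one struct, applies the same update rules, and evaluates the CORR formula on a small input:

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <utility>
#include <vector>

struct MiniState {
	uint64_t count = 0;
	double meanx = 0, meany = 0; // running means
	double co_moment = 0;        // sum of (x - meanx) * (y - meany)
	double dsq_x = 0, dsq_y = 0; // Welford M2 accumulators for x and y
};

// One streaming update: CovarOperation::Operation and STDDevBaseOperation::Execute
// collapsed into a single state.
static void update(MiniState &s, double y, double x) {
	const double n = static_cast<double>(++s.count);
	const double dx = x - s.meanx;
	const double dy = y - s.meany;
	const double new_meanx = s.meanx + dx / n;
	const double new_meany = s.meany + dy / n;
	s.co_moment += dx * (y - new_meany);        // Schubert & Gertz SSDBM 2018 (4.3)
	s.dsq_x += (x - new_meanx) * (x - s.meanx); // Welford: (x - new_mean) * (x - old_mean)
	s.dsq_y += (y - new_meany) * (y - s.meany);
	s.meanx = new_meanx;
	s.meany = new_meany;
}

// CORR(y, x) = COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))
static double finalize_corr(const MiniState &s) {
	if (s.count == 0) {
		return NAN; // the real aggregate returns SQL NULL here
	}
	const double cov = s.co_moment / s.count;
	const double std_x = s.count > 1 ? std::sqrt(s.dsq_x / s.count) : 0;
	const double std_y = s.count > 1 ? std::sqrt(s.dsq_y / s.count) : 0;
	return std_x * std_y != 0 ? cov / (std_x * std_y) : NAN;
}

int main() {
	MiniState s;
	const std::vector<std::pair<double, double>> yx = {{2, 1}, {4, 2}, {6, 3}, {9, 4}};
	for (const auto &p : yx) {
		update(s, p.first, p.second);
	}
	std::printf("corr = %.6f\n", finalize_corr(s)); // ~0.994 for this nearly linear input
}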
diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp
deleted file mode 100644
index 1908dfad1..000000000
--- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/covar.hpp
+++ /dev/null
@@ -1,101 +0,0 @@
-//===----------------------------------------------------------------------===//
-//                         DuckDB
-//
-// duckdb/core_functions/aggregate/algebraic/covar.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-// COVAR_POP(y,x)
-
-#pragma once
-
-#include "duckdb/function/aggregate_function.hpp"
-
-namespace duckdb {
-
-struct CovarState {
-	uint64_t count;
-	double meanx;
-	double meany;
-	double co_moment;
-};
-
-struct CovarOperation {
-	template <class STATE>
-	static void Initialize(STATE &state) {
-		state.count = 0;
-		state.meanx = 0;
-		state.meany = 0;
-		state.co_moment = 0;
-	}
-
-	template <class A_TYPE, class B_TYPE, class STATE, class OP>
-	static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) {
-		// update running mean and d^2
-		const double n = static_cast<double>(++(state.count));
-
-		const double dx = (x - state.meanx);
-		const double meanx = state.meanx + dx / n;
-
-		const double dy = (y - state.meany);
-		const double meany = state.meany + dy / n;
-
-		// Schubert and Gertz SSDBM 2018 (4.3)
-		const double C = state.co_moment + dx * (y - meany);
-
-		state.meanx = meanx;
-		state.meany = meany;
-		state.co_moment = C;
-	}
-
-	template <class STATE, class OP>
-	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-		if (target.count == 0) {
-			target = source;
-		} else if (source.count > 0) {
-			const auto count = target.count + source.count;
-			D_ASSERT(count >= target.count); // This is a check that we are not overflowing
-			const auto target_count = static_cast<double>(target.count);
-			const auto source_count = static_cast<double>(source.count);
-			const auto total_count = static_cast<double>(count);
-			const auto meanx = (source_count * source.meanx + target_count * target.meanx) / total_count;
-			const auto meany = (source_count * source.meany + target_count * target.meany) / total_count;
-
-			// Schubert and Gertz SSDBM 2018, equation 21
-			const auto deltax = target.meanx - source.meanx;
-			const auto deltay = target.meany - source.meany;
-			target.co_moment =
-			    source.co_moment + target.co_moment + deltax * deltay * source_count * target_count / total_count;
-			target.meanx = meanx;
-			target.meany = meany;
-			target.count = count;
-		}
-	}
-
-	static bool IgnoreNull() {
-		return true;
-	}
-};
-
-struct CovarPopOperation : public CovarOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count == 0) {
-			finalize_data.ReturnNull();
-		} else {
-			target = state.co_moment / state.count;
-		}
-	}
-};
-
-struct CovarSampOperation : public CovarOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count < 2) {
-			finalize_data.ReturnNull();
-		} else {
-			target = state.co_moment / (state.count - 1);
-		}
-	}
-};
-} // namespace duckdb
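The Combine above is the parallel-merge rule from Schubert & Gertz (SSDBM 2018), equation 21: two partial states merge exactly, so partitions can be aggregated independently and combined at the end. A minimal sketch, assuming nothing beyond the standard library (CovState, add and combine are illustrative names, not DuckDB API), that checks the merged co-moment against a single sequential pass:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>

struct CovState {
	uint64_t count = 0;
	double meanx = 0, meany = 0, co_moment = 0;
};

// Streaming update, same recurrence as CovarOperation::Operation.
static void add(CovState &s, double y, double x) {
	const double n = static_cast<double>(++s.count);
	const double dx = x - s.meanx;          // delta against the old mean of x
	s.meanx += dx / n;
	s.meany += (y - s.meany) / n;
	s.co_moment += dx * (y - s.meany);      // uses the updated mean of y
}

// Pairwise merge, same algebra as CovarOperation::Combine (eq. 21).
static void combine(const CovState &src, CovState &tgt) {
	if (tgt.count == 0) {
		tgt = src;
	} else if (src.count > 0) {
		const double ns = src.count, nt = tgt.count, n = ns + nt;
		const double dx = tgt.meanx - src.meanx, dy = tgt.meany - src.meany;
		tgt.co_moment = src.co_moment + tgt.co_moment + dx * dy * ns * nt / n;
		tgt.meanx = (ns * src.meanx + nt * tgt.meanx) / n;
		tgt.meany = (ns * src.meany + nt * tgt.meany) / n;
		tgt.count += src.count;
	}
}

int main() {
	const double ys[] = {1, 4, 2, 8, 5, 7}, xs[] = {3, 1, 4, 1, 5, 9};
	CovState whole, left, right;
	for (int i = 0; i < 6; i++) add(whole, ys[i], xs[i]);
	for (int i = 0; i < 3; i++) add(left, ys[i], xs[i]);
	for (int i = 3; i < 6; i++) add(right, ys[i], xs[i]);
	combine(right, left); // left now holds the merged state
	// the split-and-merge path agrees with the sequential path to fp roundoff
	assert(std::fabs(left.co_moment - whole.co_moment) < 1e-9);
	std::printf("covar_pop = %.6f\n", left.co_moment / left.count); // 1.25 for this input
}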
diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp
deleted file mode 100644
index bdcafae95..000000000
--- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic/stddev.hpp
+++ /dev/null
@@ -1,151 +0,0 @@
-//===----------------------------------------------------------------------===//
-//                         DuckDB
-//
-// duckdb/core_functions/aggregate/algebraic/stddev.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include "duckdb/function/aggregate_function.hpp"
-#include <cmath>
-
-namespace duckdb {
-
-struct StddevState {
-	uint64_t count;  //  n
-	double mean;     //  M1
-	double dsquared; //  M2
-};
-
-// Streaming approximate standard deviation using Welford's
-// method, DOI: 10.2307/1266577
-struct STDDevBaseOperation {
-	template <class STATE>
-	static void Initialize(STATE &state) {
-		state.count = 0;
-		state.mean = 0;
-		state.dsquared = 0;
-	}
-
-	template <class INPUT_TYPE, class STATE>
-	static void Execute(STATE &state, const INPUT_TYPE &input) {
-		// update running mean and d^2
-		state.count++;
-		const double mean_differential = (input - state.mean) / state.count;
-		const double new_mean = state.mean + mean_differential;
-		const double dsquared_increment = (input - new_mean) * (input - state.mean);
-		const double new_dsquared = state.dsquared + dsquared_increment;
-
-		state.mean = new_mean;
-		state.dsquared = new_dsquared;
-	}
-
-	template <class INPUT_TYPE, class STATE, class OP>
-	static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) {
-		Execute(state, input);
-	}
-
-	template <class INPUT_TYPE, class STATE, class OP>
-	static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input,
-	                              idx_t count) {
-		for (idx_t i = 0; i < count; i++) {
-			Operation<INPUT_TYPE, STATE, OP>(state, input, unary_input);
-		}
-	}
-
-	template <class STATE, class OP>
-	static void Combine(const STATE &source, STATE &target, AggregateInputData &) {
-		if (target.count == 0) {
-			target = source;
-		} else if (source.count > 0) {
-			const auto count = target.count + source.count;
-			D_ASSERT(count >= target.count); // This is a check that we are not overflowing
-			const double target_count = static_cast<double>(target.count);
-			const double source_count = static_cast<double>(source.count);
-			const double total_count = static_cast<double>(count);
-			const auto mean = (source_count * source.mean + target_count * target.mean) / total_count;
-			const auto delta = source.mean - target.mean;
-			target.dsquared =
-			    source.dsquared + target.dsquared + delta * delta * source_count * target_count / total_count;
-			target.mean = mean;
-			target.count = count;
-		}
-	}
-
-	static bool IgnoreNull() {
-		return true;
-	}
-};
-
-struct VarSampOperation : public STDDevBaseOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count <= 1) {
-			finalize_data.ReturnNull();
-		} else {
-			target = state.dsquared / (state.count - 1);
-			if (!Value::DoubleIsFinite(target)) {
-				throw OutOfRangeException("VARSAMP is out of range!");
-			}
-		}
-	}
-};
-
-struct VarPopOperation : public STDDevBaseOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count == 0) {
-			finalize_data.ReturnNull();
-		} else {
-			target = state.count > 1 ? (state.dsquared / state.count) : 0;
-			if (!Value::DoubleIsFinite(target)) {
-				throw OutOfRangeException("VARPOP is out of range!");
-			}
-		}
-	}
-};
-
-struct STDDevSampOperation : public STDDevBaseOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count <= 1) {
-			finalize_data.ReturnNull();
-		} else {
-			target = sqrt(state.dsquared / (state.count - 1));
-			if (!Value::DoubleIsFinite(target)) {
-				throw OutOfRangeException("STDDEV_SAMP is out of range!");
-			}
-		}
-	}
-};
-
-struct STDDevPopOperation : public STDDevBaseOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count == 0) {
-			finalize_data.ReturnNull();
-		} else {
-			target = state.count > 1 ? sqrt(state.dsquared / state.count) : 0;
-			if (!Value::DoubleIsFinite(target)) {
-				throw OutOfRangeException("STDDEV_POP is out of range!");
-			}
-		}
-	}
-};
-
-struct StandardErrorOfTheMeanOperation : public STDDevBaseOperation {
-	template <class T, class STATE>
-	static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) {
-		if (state.count == 0) {
-			finalize_data.ReturnNull();
-		} else {
-			target = sqrt(state.dsquared / state.count) / sqrt((state.count));
-			if (!Value::DoubleIsFinite(target)) {
-				throw OutOfRangeException("SEM is out of range!");
-			}
-		}
-	}
-};
-} // namespace duckdb
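Welford's update keeps the running sum of squared deviations (M2) in one pass without the catastrophic cancellation of the naive sum(x^2) - sum(x)^2/n formulation. A short self-contained sketch (illustrative names, not the DuckDB API) comparing the streamed result against a two-pass reference:

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

struct WelfordState {
	uint64_t count = 0;
	double mean = 0, dsquared = 0; // M1 and M2, as in StddevState above
};

// Same recurrence as STDDevBaseOperation::Execute.
static void execute(WelfordState &s, double x) {
	s.count++;
	const double new_mean = s.mean + (x - s.mean) / s.count;
	s.dsquared += (x - new_mean) * (x - s.mean); // new mean times old mean delta
	s.mean = new_mean;
}

static double var_samp(const WelfordState &s) {
	return s.count > 1 ? s.dsquared / (s.count - 1) : NAN;
}

int main() {
	const std::vector<double> data = {2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0};
	WelfordState s;
	for (double x : data) {
		execute(s, x);
	}

	// naive two-pass reference: mean first, then sum of squared deviations
	double mean = 0;
	for (double x : data) {
		mean += x / data.size();
	}
	double m2 = 0;
	for (double x : data) {
		m2 += (x - mean) * (x - mean);
	}

	assert(std::fabs(var_samp(s) - m2 / (data.size() - 1)) < 1e-9);
	std::printf("stddev_samp = %.6f\n", std::sqrt(var_samp(s))); // ~2.138 for this data
}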
diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp
deleted file mode 100644
index da08c769a..000000000
--- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/algebraic_functions.hpp
+++ /dev/null
@@ -1,126 +0,0 @@
-//===----------------------------------------------------------------------===//
-//                         DuckDB
-//
-// core_functions/aggregate/algebraic_functions.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-// This file is automatically generated by scripts/generate_functions.py
-// Do not edit this file manually, your changes will be overwritten
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include "duckdb/function/function_set.hpp"
-
-namespace duckdb {
-
-struct AvgFun {
-	static constexpr const char *Name = "avg";
-	static constexpr const char *Parameters = "x";
-	static constexpr const char *Description = "Calculates the average value for all tuples in x.";
-	static constexpr const char *Example = "SUM(x) / COUNT(*)";
-
-	static AggregateFunctionSet GetFunctions();
-};
-
-struct MeanFun {
-	using ALIAS = AvgFun;
-
-	static constexpr const char *Name = "mean";
-};
-
-struct CorrFun {
-	static constexpr const char *Name = "corr";
-	static constexpr const char *Parameters = "y,x";
-	static constexpr const char *Description = "Returns the correlation coefficient for non-null pairs in a group.";
-	static constexpr const char *Example = "COVAR_POP(y, x) / (STDDEV_POP(x) * STDDEV_POP(y))";
-
-	static AggregateFunction GetFunction();
-};
-
-struct CovarPopFun {
-	static constexpr const char *Name = "covar_pop";
-	static constexpr const char *Parameters = "y,x";
-	static constexpr const char *Description = "Returns the population covariance of input values.";
-	static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)";
-
-	static AggregateFunction GetFunction();
-};
-
-struct CovarSampFun {
-	static constexpr const char *Name = "covar_samp";
-	static constexpr const char *Parameters = "y,x";
-	static constexpr const char *Description = "Returns the sample covariance for non-null pairs in a group.";
-	static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / (COUNT(*) - 1)";
-
-	static AggregateFunction GetFunction();
-};
-
-struct FAvgFun {
-	static constexpr const char *Name = "favg";
-	static constexpr const char *Parameters = "x";
-	static constexpr const char *Description = "Calculates the average using a more accurate floating point summation (Kahan Sum)";
-	static constexpr const char *Example = "favg(A)";
-
-	static AggregateFunction GetFunction();
-};
-
-struct StandardErrorOfTheMeanFun {
-	static constexpr const char *Name = "sem";
-	static constexpr const char *Parameters =
"x"; - static constexpr const char *Description = "Returns the standard error of the mean"; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct StdDevPopFun { - static constexpr const char *Name = "stddev_pop"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the population standard deviation."; - static constexpr const char *Example = "sqrt(var_pop(x))"; - - static AggregateFunction GetFunction(); -}; - -struct StdDevSampFun { - static constexpr const char *Name = "stddev_samp"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the sample standard deviation"; - static constexpr const char *Example = "sqrt(var_samp(x))"; - - static AggregateFunction GetFunction(); -}; - -struct StddevFun { - using ALIAS = StdDevSampFun; - - static constexpr const char *Name = "stddev"; -}; - -struct VarPopFun { - static constexpr const char *Name = "var_pop"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the population variance."; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct VarSampFun { - static constexpr const char *Name = "var_samp"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the sample variance of all input values."; - static constexpr const char *Example = "(SUM(x^2) - SUM(x)^2 / COUNT(x)) / (COUNT(x) - 1)"; - - static AggregateFunction GetFunction(); -}; - -struct VarianceFun { - using ALIAS = VarSampFun; - - static constexpr const char *Name = "variance"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp deleted file mode 100644 index 50c0197a9..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/distributive_functions.hpp +++ /dev/null @@ -1,261 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/aggregate/distributive_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct ApproxCountDistinctFun { - static constexpr const char *Name = "approx_count_distinct"; - static constexpr const char *Parameters = "any"; - static constexpr const char *Description = "Computes the approximate count of distinct elements using HyperLogLog."; - static constexpr const char *Example = "approx_count_distinct(A)"; - - static AggregateFunction GetFunction(); -}; - -struct ArgMinFun { - static constexpr const char *Name = "arg_min"; - static constexpr const char *Parameters = "arg,val"; - static constexpr const char *Description = "Finds the row with the minimum val. 
Calculates the non-NULL arg expression at that row."; - static constexpr const char *Example = "arg_min(A,B)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ArgminFun { - using ALIAS = ArgMinFun; - - static constexpr const char *Name = "argmin"; -}; - -struct MinByFun { - using ALIAS = ArgMinFun; - - static constexpr const char *Name = "min_by"; -}; - -struct ArgMinNullFun { - static constexpr const char *Name = "arg_min_null"; - static constexpr const char *Parameters = "arg,val"; - static constexpr const char *Description = "Finds the row with the minimum val. Calculates the arg expression at that row."; - static constexpr const char *Example = "arg_min_null(A,B)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ArgMaxFun { - static constexpr const char *Name = "arg_max"; - static constexpr const char *Parameters = "arg,val"; - static constexpr const char *Description = "Finds the row with the maximum val. Calculates the non-NULL arg expression at that row."; - static constexpr const char *Example = "arg_max(A,B)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ArgmaxFun { - using ALIAS = ArgMaxFun; - - static constexpr const char *Name = "argmax"; -}; - -struct MaxByFun { - using ALIAS = ArgMaxFun; - - static constexpr const char *Name = "max_by"; -}; - -struct ArgMaxNullFun { - static constexpr const char *Name = "arg_max_null"; - static constexpr const char *Parameters = "arg,val"; - static constexpr const char *Description = "Finds the row with the maximum val. Calculates the arg expression at that row."; - static constexpr const char *Example = "arg_max_null(A,B)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct BitAndFun { - static constexpr const char *Name = "bit_and"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns the bitwise AND of all bits in a given expression."; - static constexpr const char *Example = "bit_and(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct BitOrFun { - static constexpr const char *Name = "bit_or"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns the bitwise OR of all bits in a given expression."; - static constexpr const char *Example = "bit_or(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct BitXorFun { - static constexpr const char *Name = "bit_xor"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns the bitwise XOR of all bits in a given expression."; - static constexpr const char *Example = "bit_xor(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct BitstringAggFun { - static constexpr const char *Name = "bitstring_agg"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns a bitstring with bits set for each distinct value."; - static constexpr const char *Example = "bitstring_agg(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct BoolAndFun { - static constexpr const char *Name = "bool_and"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns TRUE if every input value is TRUE, otherwise FALSE."; - static constexpr const char *Example = "bool_and(A)"; - - static AggregateFunction GetFunction(); -}; - -struct BoolOrFun { - static constexpr const char *Name = "bool_or"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description 
= "Returns TRUE if any input value is TRUE, otherwise FALSE."; - static constexpr const char *Example = "bool_or(A)"; - - static AggregateFunction GetFunction(); -}; - -struct CountIfFun { - static constexpr const char *Name = "count_if"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Counts the total number of TRUE values for a boolean column"; - static constexpr const char *Example = "count_if(A)"; - - static AggregateFunction GetFunction(); -}; - -struct CountifFun { - using ALIAS = CountIfFun; - - static constexpr const char *Name = "countif"; -}; - -struct EntropyFun { - static constexpr const char *Name = "entropy"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the log-2 entropy of count input-values."; - static constexpr const char *Example = ""; - - static AggregateFunctionSet GetFunctions(); -}; - -struct KahanSumFun { - static constexpr const char *Name = "kahan_sum"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Calculates the sum using a more accurate floating point summation (Kahan Sum)."; - static constexpr const char *Example = "kahan_sum(A)"; - - static AggregateFunction GetFunction(); -}; - -struct FsumFun { - using ALIAS = KahanSumFun; - - static constexpr const char *Name = "fsum"; -}; - -struct SumkahanFun { - using ALIAS = KahanSumFun; - - static constexpr const char *Name = "sumkahan"; -}; - -struct KurtosisFun { - static constexpr const char *Name = "kurtosis"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the excess kurtosis (Fisher’s definition) of all input values, with a bias correction according to the sample size"; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct KurtosisPopFun { - static constexpr const char *Name = "kurtosis_pop"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the excess kurtosis (Fisher’s definition) of all input values, without bias correction"; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct ProductFun { - static constexpr const char *Name = "product"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Calculates the product of all tuples in arg."; - static constexpr const char *Example = "product(A)"; - - static AggregateFunction GetFunction(); -}; - -struct SkewnessFun { - static constexpr const char *Name = "skewness"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the skewness of all input values."; - static constexpr const char *Example = "skewness(A)"; - - static AggregateFunction GetFunction(); -}; - -struct StringAggFun { - static constexpr const char *Name = "string_agg"; - static constexpr const char *Parameters = "str,arg"; - static constexpr const char *Description = "Concatenates the column string values with an optional separator."; - static constexpr const char *Example = "string_agg(A, '-')"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct GroupConcatFun { - using ALIAS = StringAggFun; - - static constexpr const char *Name = "group_concat"; -}; - -struct ListaggFun { - using ALIAS = StringAggFun; - - static constexpr const char *Name = "listagg"; -}; - -struct SumFun { - static constexpr const char *Name = "sum"; - static constexpr const 
char *Parameters = "arg"; - static constexpr const char *Description = "Calculates the sum value for all tuples in arg."; - static constexpr const char *Example = "sum(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct SumNoOverflowFun { - static constexpr const char *Name = "sum_no_overflow"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Internal only. Calculates the sum value for all tuples in arg without overflow checks."; - static constexpr const char *Example = "sum_no_overflow(A)"; - - static AggregateFunctionSet GetFunctions(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp deleted file mode 100644 index 7d73a3caf..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/histogram_helpers.hpp +++ /dev/null @@ -1,99 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/histogram_helpers.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/function/create_sort_key.hpp" - -namespace duckdb { - -struct HistogramFunctor { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - FlatVector::GetData(result)[offset] = value; - } - - static bool CreateExtraState(idx_t count) { - return false; - } - - static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { - input.ToUnifiedFormat(count, result); - } - - template - static T ExtractValue(UnifiedVectorFormat &bin_data, idx_t offset, AggregateInputData &) { - return UnifiedVectorFormat::GetData(bin_data)[bin_data.sel->get_index(offset)]; - } - - static bool RequiresExtract() { - return false; - } -}; - -struct HistogramStringFunctorBase { - template - static T ExtractValue(UnifiedVectorFormat &bin_data, idx_t offset, AggregateInputData &aggr_input) { - auto &input_str = UnifiedVectorFormat::GetData(bin_data)[bin_data.sel->get_index(offset)]; - if (input_str.IsInlined()) { - // inlined strings can be inserted directly - return input_str; - } - // if the string is not inlined we need to allocate space for it - auto input_str_size = UnsafeNumericCast(input_str.GetSize()); - auto string_memory = aggr_input.allocator.Allocate(input_str_size); - // copy over the string - memcpy(string_memory, input_str.GetData(), input_str_size); - // now insert it into the histogram - string_t histogram_str(char_ptr_cast(string_memory), input_str_size); - return histogram_str; - } - - static bool RequiresExtract() { - return true; - } -}; - -struct HistogramStringFunctor : HistogramStringFunctorBase { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - FlatVector::GetData(result)[offset] = StringVector::AddStringOrBlob(result, value); - } - - static bool CreateExtraState(idx_t count) { - return false; - } - - static void PrepareData(Vector &input, idx_t count, bool &, UnifiedVectorFormat &result) { - input.ToUnifiedFormat(count, result); - } -}; - -struct HistogramGenericFunctor : HistogramStringFunctorBase { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - CreateSortKeyHelpers::DecodeSortKey(value, result, offset, - OrderModifiers(OrderType::ASCENDING, 
OrderByNullType::NULLS_LAST)); - } - - static Vector CreateExtraState(idx_t count) { - return Vector(LogicalType::BLOB, count); - } - - static void PrepareData(Vector &input, idx_t count, Vector &extra_state, UnifiedVectorFormat &result) { - OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, extra_state); - input.Flatten(count); - extra_state.Flatten(count); - FlatVector::Validity(extra_state).Initialize(FlatVector::Validity(input)); - extra_state.ToUnifiedFormat(count, result); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp deleted file mode 100644 index f8b96a166..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/holistic_functions.hpp +++ /dev/null @@ -1,96 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/aggregate/holistic_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct ApproxQuantileFun { - static constexpr const char *Name = "approx_quantile"; - static constexpr const char *Parameters = "x,pos"; - static constexpr const char *Description = "Computes the approximate quantile using T-Digest."; - static constexpr const char *Example = "approx_quantile(x, 0.5)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct MadFun { - static constexpr const char *Name = "mad"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the median absolute deviation for the values within x. NULL values are ignored. Temporal types return a positive INTERVAL. "; - static constexpr const char *Example = "mad(x)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct MedianFun { - static constexpr const char *Name = "median"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the middle value of the set. NULL values are ignored. For even value counts, quantitative values are averaged and ordinal values return the lower value."; - static constexpr const char *Example = "median(x)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ModeFun { - static constexpr const char *Name = "mode"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the most frequent value for the values within x. NULL values are ignored."; - static constexpr const char *Example = ""; - - static AggregateFunctionSet GetFunctions(); -}; - -struct QuantileDiscFun { - static constexpr const char *Name = "quantile_disc"; - static constexpr const char *Parameters = "x,pos"; - static constexpr const char *Description = "Returns the exact quantile number between 0 and 1 . 
If pos is a LIST of FLOATs, then the result is a LIST of the corresponding exact quantiles."; - static constexpr const char *Example = "quantile_disc(x, 0.5)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct QuantileFun { - using ALIAS = QuantileDiscFun; - - static constexpr const char *Name = "quantile"; -}; - -struct QuantileContFun { - static constexpr const char *Name = "quantile_cont"; - static constexpr const char *Parameters = "x,pos"; - static constexpr const char *Description = "Returns the interpolated quantile number between 0 and 1 . If pos is a LIST of FLOATs, then the result is a LIST of the corresponding interpolated quantiles. "; - static constexpr const char *Example = "quantile_cont(x, 0.5)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ReservoirQuantileFun { - static constexpr const char *Name = "reservoir_quantile"; - static constexpr const char *Parameters = "x,quantile,sample_size"; - static constexpr const char *Description = "Gives the approximate quantile using reservoir sampling, the sample size is optional and uses 8192 as a default size."; - static constexpr const char *Example = "reservoir_quantile(A,0.5,1024)"; - - static AggregateFunctionSet GetFunctions(); -}; - -struct ApproxTopKFun { - static constexpr const char *Name = "approx_top_k"; - static constexpr const char *Parameters = "val,k"; - static constexpr const char *Description = "Finds the k approximately most occurring values in the data set"; - static constexpr const char *Example = "approx_top_k(x, 5)"; - - static AggregateFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp deleted file mode 100644 index eb83e5e15..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/nested_functions.hpp +++ /dev/null @@ -1,53 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/aggregate/nested_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct HistogramFun { - static constexpr const char *Name = "histogram"; - static constexpr const char *Parameters = "arg"; - static constexpr const char *Description = "Returns a LIST of STRUCTs with the fields bucket and count."; - static constexpr const char *Example = "histogram(A)"; - - static AggregateFunctionSet GetFunctions(); - static AggregateFunction GetHistogramUnorderedMap(LogicalType &type); - static AggregateFunction BinnedHistogramFunction(); -}; - -struct HistogramExactFun { - static constexpr const char *Name = "histogram_exact"; - static constexpr const char *Parameters = "arg,bins"; - static constexpr const char *Description = "Returns a LIST of STRUCTs with the fields bucket and count matching the buckets exactly."; - static constexpr const char *Example = "histogram_exact(A, [0, 1, 2])"; - - static AggregateFunction GetFunction(); -}; - -struct ListFun { - static constexpr const char *Name = "list"; - static constexpr const char *Parameters = "arg"; - static constexpr 
const char *Description = "Returns a LIST containing all the values of a column."; - static constexpr const char *Example = "list(A)"; - - static AggregateFunction GetFunction(); -}; - -struct ArrayAggFun { - using ALIAS = ListFun; - - static constexpr const char *Name = "array_agg"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp deleted file mode 100644 index 253657f5a..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_helpers.hpp +++ /dev/null @@ -1,65 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/quantile_helpers.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/enums/quantile_enum.hpp" -#include "core_functions/aggregate/holistic_functions.hpp" - -namespace duckdb { - -// Avoid using naked Values in inner loops... -struct QuantileValue { - explicit QuantileValue(const Value &v) : val(v), dbl(v.GetValue()) { - const auto &type = val.type(); - switch (type.id()) { - case LogicalTypeId::DECIMAL: { - integral = IntegralValue::Get(v); - scaling = Hugeint::POWERS_OF_TEN[DecimalType::GetScale(type)]; - break; - } - default: - break; - } - } - - Value val; - - // DOUBLE - double dbl; - - // DECIMAL - hugeint_t integral; - hugeint_t scaling; - - inline bool operator==(const QuantileValue &other) const { - return val == other.val; - } -}; - -struct QuantileBindData : public FunctionData { - QuantileBindData(); - explicit QuantileBindData(const Value &quantile_p); - explicit QuantileBindData(const vector &quantiles_p); - QuantileBindData(const QuantileBindData &other); - - unique_ptr Copy() const override; - bool Equals(const FunctionData &other_p) const override; - - static void Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function); - - static unique_ptr Deserialize(Deserializer &deserializer, AggregateFunction &function); - - vector quantiles; - vector order; - bool desc; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp deleted file mode 100644 index a330c0a4b..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_sort_tree.hpp +++ /dev/null @@ -1,431 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/quantile_sort_tree.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "core_functions/aggregate/quantile_helpers.hpp" -#include "duckdb/execution/merge_sort_tree.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/function/window/window_index_tree.hpp" -#include -#include -#include -#include - -namespace duckdb { - -// Paged access -template -struct 
QuantileCursor { - explicit QuantileCursor(const WindowPartitionInput &partition) : inputs(*partition.inputs) { - D_ASSERT(partition.column_ids.size() == 1); - inputs.InitializeScan(scan, partition.column_ids); - inputs.InitializeScanChunk(scan, page); - - D_ASSERT(partition.all_valid.size() == 1); - all_valid = partition.all_valid[0]; - } - - inline sel_t RowOffset(idx_t row_idx) const { - D_ASSERT(RowIsVisible(row_idx)); - return UnsafeNumericCast(row_idx - scan.current_row_index); - } - - inline bool RowIsVisible(idx_t row_idx) const { - return (row_idx < scan.next_row_index && scan.current_row_index <= row_idx); - } - - inline idx_t Seek(idx_t row_idx) { - if (!RowIsVisible(row_idx)) { - inputs.Seek(row_idx, scan, page); - data = FlatVector::GetData(page.data[0]); - validity = &FlatVector::Validity(page.data[0]); - } - return RowOffset(row_idx); - } - - inline const INPUT_TYPE &operator[](idx_t row_idx) { - const auto offset = Seek(row_idx); - return data[offset]; - } - - inline bool RowIsValid(idx_t row_idx) { - const auto offset = Seek(row_idx); - return validity->RowIsValid(offset); - } - - inline bool AllValid() { - return all_valid; - } - - //! Windowed paging - const ColumnDataCollection &inputs; - //! The state used for reading the collection on this thread - ColumnDataScanState scan; - //! The data chunk paged into into - DataChunk page; - //! The data pointer - const INPUT_TYPE *data = nullptr; - //! The validity mask - const ValidityMask *validity = nullptr; - //! Paged chunks do not track this but it is really necessary for performance - bool all_valid; -}; - -// Direct access -template -struct QuantileDirect { - using INPUT_TYPE = T; - using RESULT_TYPE = T; - - inline const INPUT_TYPE &operator()(const INPUT_TYPE &x) const { - return x; - } -}; - -// Indirect access -template -struct QuantileIndirect { - using INPUT_TYPE = idx_t; - using RESULT_TYPE = T; - using CURSOR = QuantileCursor; - CURSOR &data; - - explicit QuantileIndirect(CURSOR &data_p) : data(data_p) { - } - - inline RESULT_TYPE operator()(const INPUT_TYPE &input) const { - return data[input]; - } -}; - -// Composed access -template -struct QuantileComposed { - using INPUT_TYPE = typename INNER::INPUT_TYPE; - using RESULT_TYPE = typename OUTER::RESULT_TYPE; - - const OUTER &outer; - const INNER &inner; - - explicit QuantileComposed(const OUTER &outer_p, const INNER &inner_p) : outer(outer_p), inner(inner_p) { - } - - inline RESULT_TYPE operator()(const idx_t &input) const { - return outer(inner(input)); - } -}; - -// Accessed comparison -template -struct QuantileCompare { - using INPUT_TYPE = typename ACCESSOR::INPUT_TYPE; - const ACCESSOR &accessor_l; - const ACCESSOR &accessor_r; - const bool desc; - - // Single cursor for linear operations - explicit QuantileCompare(const ACCESSOR &accessor, bool desc_p) - : accessor_l(accessor), accessor_r(accessor), desc(desc_p) { - } - - // Independent cursors for sorting - explicit QuantileCompare(const ACCESSOR &accessor_l, const ACCESSOR &accessor_r, bool desc_p) - : accessor_l(accessor_l), accessor_r(accessor_r), desc(desc_p) { - } - - inline bool operator()(const INPUT_TYPE &lhs, const INPUT_TYPE &rhs) const { - const auto lval = accessor_l(lhs); - const auto rval = accessor_r(rhs); - - return desc ? 
(rval < lval) : (lval < rval); - } -}; - -struct CastInterpolation { - template - static inline TARGET_TYPE Cast(const INPUT_TYPE &src, Vector &result) { - return Cast::Operation(src); - } - template - static inline TARGET_TYPE Interpolate(const TARGET_TYPE &lo, const double d, const TARGET_TYPE &hi) { - const auto delta = hi - lo; - return LossyNumericCast(lo + delta * d); - } -}; - -template <> -interval_t CastInterpolation::Cast(const dtime_t &src, Vector &result); -template <> -double CastInterpolation::Interpolate(const double &lo, const double d, const double &hi); -template <> -dtime_t CastInterpolation::Interpolate(const dtime_t &lo, const double d, const dtime_t &hi); -template <> -timestamp_t CastInterpolation::Interpolate(const timestamp_t &lo, const double d, const timestamp_t &hi); -template <> -hugeint_t CastInterpolation::Interpolate(const hugeint_t &lo, const double d, const hugeint_t &hi); -template <> -interval_t CastInterpolation::Interpolate(const interval_t &lo, const double d, const interval_t &hi); -template <> -string_t CastInterpolation::Cast(const string_t &src, Vector &result); - -// Continuous interpolation -template -struct Interpolator { - Interpolator(const QuantileValue &q, const idx_t n_p, const bool desc_p) - : desc(desc_p), RN((double)(n_p - 1) * q.dbl), FRN(ExactNumericCast(floor(RN))), - CRN(ExactNumericCast(ceil(RN))), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - if (lidx == hidx) { - return CastInterpolation::Cast(accessor(lidx), result); - } else { - auto lo = CastInterpolation::Cast(accessor(lidx), result); - auto hi = CastInterpolation::Cast(accessor(hidx), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - QuantileCompare comp(accessor, desc); - if (CRN == FRN) { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return CastInterpolation::Cast(accessor(v_t[FRN]), result); - } else { - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - std::nth_element(v_t + FRN, v_t + CRN, v_t + end, comp); - auto lo = CastInterpolation::Cast(accessor(v_t[FRN]), result); - auto hi = CastInterpolation::Cast(accessor(v_t[CRN]), result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - template - inline TARGET_TYPE Extract(const INPUT_TYPE *dest, Vector &result) const { - if (CRN == FRN) { - return CastInterpolation::Cast(dest[0], result); - } else { - auto lo = CastInterpolation::Cast(dest[0], result); - auto hi = CastInterpolation::Cast(dest[1], result); - return CastInterpolation::Interpolate(lo, RN - FRN, hi); - } - } - - const bool desc; - const double RN; - const idx_t FRN; - const idx_t CRN; - - idx_t begin; - idx_t end; -}; - -// Discrete "interpolation" -template <> -struct Interpolator { - static inline idx_t Index(const QuantileValue &q, const idx_t n) { - idx_t floored; - switch (q.val.type().id()) { - case LogicalTypeId::DECIMAL: { - // Integer arithmetic for accuracy - const auto integral = q.integral; - const auto scaling = q.scaling; - const auto scaled_q = - DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), integral); - const auto scaled_n = - DecimalMultiplyOverflowCheck::Operation(Hugeint::Convert(n), scaling); - floored = 
Cast::Operation((scaled_n - scaled_q) / scaling); - break; - } - default: - const auto scaled_q = double(n) * q.dbl; - floored = LossyNumericCast(floor(double(n) - scaled_q)); - break; - } - - return MaxValue(1, n - floored) - 1; - } - - Interpolator(const QuantileValue &q, const idx_t n_p, bool desc_p) - : desc(desc_p), FRN(Index(q, n_p)), CRN(FRN), begin(0), end(n_p) { - } - - template > - TARGET_TYPE Interpolate(INPUT_TYPE lidx, INPUT_TYPE hidx, Vector &result, const ACCESSOR &accessor) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - return CastInterpolation::Cast(accessor(lidx), result); - } - - template > - typename ACCESSOR::RESULT_TYPE InterpolateInternal(INPUT_TYPE *v_t, const ACCESSOR &accessor = ACCESSOR()) const { - QuantileCompare comp(accessor, desc); - std::nth_element(v_t + begin, v_t + FRN, v_t + end, comp); - return accessor(v_t[FRN]); - } - - template > - TARGET_TYPE Operation(INPUT_TYPE *v_t, Vector &result, const ACCESSOR &accessor = ACCESSOR()) const { - using ACCESS_TYPE = typename ACCESSOR::RESULT_TYPE; - return CastInterpolation::Cast(InterpolateInternal(v_t, accessor), result); - } - - template - TARGET_TYPE Extract(const INPUT_TYPE *dest, Vector &result) const { - return CastInterpolation::Cast(dest[0], result); - } - - const bool desc; - const idx_t FRN; - const idx_t CRN; - - idx_t begin; - idx_t end; -}; - -template -struct QuantileIncluded { - using CURSOR_TYPE = QuantileCursor; - - inline explicit QuantileIncluded(const ValidityMask &fmask_p, CURSOR_TYPE &dmask_p) - : fmask(fmask_p), dmask(dmask_p) { - } - - inline bool operator()(const idx_t &idx) { - return fmask.RowIsValid(idx) && dmask.RowIsValid(idx); - } - - inline bool AllValid() { - return fmask.AllValid() && dmask.AllValid(); - } - - const ValidityMask &fmask; - CURSOR_TYPE &dmask; -}; - -struct QuantileSortTree { - - unique_ptr index_tree; - - QuantileSortTree(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition) { - // TODO: Two pass parallel sorting using Build - auto &inputs = *partition.inputs; - ColumnDataScanState scan; - DataChunk sort; - inputs.InitializeScan(scan, partition.column_ids); - inputs.InitializeScanChunk(scan, sort); - - // Sort on the single argument - auto &bind_data = aggr_input_data.bind_data->Cast(); - auto order_expr = make_uniq(Value(sort.GetTypes()[0])); - auto order_type = bind_data.desc ? 
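
The discrete Interpolator maps a quantile to a single row via floored = floor(n - n*q) and index = max(1, n - floored) - 1; for DECIMAL quantiles it redoes that computation in scaled integer arithmetic so floating-point rounding cannot shift the chosen row. A sketch of both paths, with int64_t standing in for the hugeint scaling used above (truncating division equals floor here because the operands stay non-negative):

#include <cmath>
#include <cstdint>
#include <iostream>

// Floating-point path: floored = floor(n - n*q), idx = max(1, n - floored) - 1.
uint64_t DiscreteIndexDouble(uint64_t n, double q) {
	const double scaled_q = (double)n * q;
	const uint64_t floored = (uint64_t)std::floor((double)n - scaled_q);
	return (n - floored > 1 ? n - floored : 1) - 1;
}

// Exact path for a decimal quantile q = integral / scaling, e.g. 0.5 as 5/10
// (int64_t stands in for DuckDB's hugeint here).
uint64_t DiscreteIndexDecimal(uint64_t n, int64_t integral, int64_t scaling) {
	const int64_t scaled_q = (int64_t)n * integral; // n * q, scaled up
	const int64_t scaled_n = (int64_t)n * scaling;  // n, scaled up
	const uint64_t floored = (uint64_t)((scaled_n - scaled_q) / scaling);
	return (n - floored > 1 ? n - floored : 1) - 1;
}

int main() {
	// The median of 4 rows lands on row 1 (0-based) in both paths.
	std::cout << DiscreteIndexDouble(4, 0.5) << " "
	          << DiscreteIndexDecimal(4, 5, 10) << "\n"; // 1 1
}
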
OrderType::DESCENDING : OrderType::ASCENDING; - BoundOrderModifier order_bys; - order_bys.orders.emplace_back(BoundOrderByNode(order_type, OrderByNullType::NULLS_LAST, std::move(order_expr))); - vector<column_t> sort_idx(1, 0); - const auto count = partition.count; - - index_tree = make_uniq<WindowIndexTree>(partition.context, order_bys, sort_idx, count); - auto index_state = index_tree->GetLocalState(); - auto &local_state = index_state->Cast<WindowIndexTreeLocalState>(); - - // Build the indirection array by scanning the valid indices - const auto &filter_mask = partition.filter_mask; - SelectionVector filter_sel(STANDARD_VECTOR_SIZE); - while (inputs.Scan(scan, sort)) { - const auto row_idx = scan.current_row_index; - if (!filter_mask.AllValid() || !partition.all_valid[0]) { - auto &key = sort.data[0]; - auto &validity = FlatVector::Validity(key); - idx_t filtered = 0; - for (sel_t i = 0; i < sort.size(); ++i) { - if (filter_mask.RowIsValid(i + row_idx) && validity.RowIsValid(i)) { - filter_sel[filtered++] = i; - } - } - local_state.SinkChunk(sort, row_idx, filter_sel, filtered); - } else { - local_state.SinkChunk(sort, row_idx, nullptr, 0); - } - } - local_state.Sort(); - } - - inline idx_t SelectNth(const SubFrames &frames, size_t n) const { - return index_tree->SelectNth(frames, n); - } - - template <typename INPUT_TYPE, typename RESULT_TYPE, bool DISCRETE> - RESULT_TYPE WindowScalar(QuantileCursor<INPUT_TYPE> &data, const SubFrames &frames, const idx_t n, Vector &result, - const QuantileValue &q) { - D_ASSERT(n > 0); - - // Thread safe and idempotent. - index_tree->Build(); - - // Find the interpolated indices within the frame - Interpolator<DISCRETE> interp(q, n, false); - const auto lo_data = SelectNth(frames, interp.FRN); - auto hi_data = lo_data; - if (interp.CRN != interp.FRN) { - hi_data = SelectNth(frames, interp.CRN); - } - - // Interpolate indirectly - using ID = QuantileIndirect<INPUT_TYPE>; - ID indirect(data); - return interp.template Interpolate<idx_t, RESULT_TYPE, ID>(lo_data, hi_data, result, indirect); - } - - template <typename INPUT_TYPE, typename CHILD_TYPE, bool DISCRETE> - void WindowList(QuantileCursor<INPUT_TYPE> &data, const SubFrames &frames, const idx_t n, Vector &list, - const idx_t lidx, const QuantileBindData &bind_data) { - D_ASSERT(n > 0); - - // Thread safe and idempotent.
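
Building the indirection array above compacts each chunk down to the rows that pass both the frame filter and the column's validity mask before sinking it into the index tree. The same compaction pattern in standalone form (std::vector<bool> stands in for the masks, a plain vector for the SelectionVector):

#include <cstdint>
#include <iostream>
#include <vector>

// Compact the indices of rows that pass both a filter and a validity mask,
// as the sort-tree build does when filling its selection vector.
std::vector<uint32_t> BuildSelection(const std::vector<bool> &filter,
                                     const std::vector<bool> &valid) {
	std::vector<uint32_t> sel;
	sel.reserve(filter.size());
	for (uint32_t i = 0; i < filter.size(); ++i) {
		if (filter[i] && valid[i]) {
			sel.push_back(i);
		}
	}
	return sel;
}

int main() {
	std::vector<bool> filter = {true, false, true, true};
	std::vector<bool> valid = {true, true, false, true};
	for (auto i : BuildSelection(filter, valid)) std::cout << i << " "; // 0 3
	std::cout << "\n";
}
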
- index_tree->Build(); - - // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - using ID = QuantileIndirect; - ID indirect(data); - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - Interpolator interp(quantile, n, false); - - const auto lo_data = SelectNth(frames, interp.FRN); - auto hi_data = lo_data; - if (interp.CRN != interp.FRN) { - hi_data = SelectNth(frames, interp.CRN); - } - - // Interpolate indirectly - rdata[lentry.offset + q] = - interp.template Interpolate(lo_data, hi_data, result, indirect); - } - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp deleted file mode 100644 index 00f4baf77..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/quantile_state.hpp +++ /dev/null @@ -1,307 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/quantile_state.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "core_functions/aggregate/quantile_sort_tree.hpp" -#include "SkipList.h" - -namespace duckdb { - -struct QuantileOperation { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - for (idx_t i = 0; i < count; i++) { - Operation(state, input, unary_input); - } - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &aggr_input) { - state.AddElement(input, aggr_input.input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (source.v.empty()) { - return; - } - target.v.insert(target.v.end(), source.v.begin(), source.v.end()); - } - - template - static void Destroy(STATE &state, AggregateInputData &) { - state.~STATE(); - } - - static bool IgnoreNull() { - return true; - } - - template - static void WindowInit(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - data_ptr_t g_state) { - D_ASSERT(partition.inputs); - - const auto &stats = partition.stats; - - // If frames overlap significantly, then use local skip lists. 
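
Reading the frame statistics as min/max bounds over all frame starts and all frame ends, the early-out that follows estimates how much consecutive frames coincide: if every frame contains the region [max(starts), min(ends)) and that shared region exceeds 75% of the total covered span, the shared tree is skipped in favour of per-thread skip lists. A sketch of that ratio test under this reading of the stats (my interpretation; the type and field names are illustrative):

#include <iostream>

// Min/max statistics over all frame starts and all frame ends.
struct BoundStats {
	double min;
	double max;
};

// If every frame contains [starts.max, ends.min), compare that shared region
// to the total span; above the threshold, per-thread skip lists are expected
// to beat building one shared index tree.
bool PreferSkipLists(BoundStats starts, BoundStats ends) {
	if (starts.max > ends.min) {
		return false; // no region common to all frames
	}
	const double shared = ends.min - starts.max;
	const double cover = ends.max - starts.min;
	return cover > 0 && shared / cover > 0.75;
}

int main() {
	std::cout << PreferSkipLists({0, 5}, {95, 100}) << "\n";  // 1: mostly shared
	std::cout << PreferSkipLists({0, 50}, {55, 100}) << "\n"; // 0: build the tree
}
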
- if (stats[0].end <= stats[1].begin) { - // Frames can overlap - const auto overlap = double(stats[1].begin - stats[0].end); - const auto cover = double(stats[1].end - stats[0].begin); - const auto ratio = overlap / cover; - if (ratio > .75) { - return; - } - } - - // Build the tree - auto &state = *reinterpret_cast(g_state); - auto &window_state = state.GetOrCreateWindowState(); - window_state.qst = make_uniq(aggr_input_data, partition); - } - - template - static idx_t FrameSize(QuantileIncluded &included, const SubFrames &frames) { - // Count the number of valid values - idx_t n = 0; - if (included.AllValid()) { - for (const auto &frame : frames) { - n += frame.end - frame.start; - } - } else { - // NULLs or FILTERed values, - for (const auto &frame : frames) { - for (auto i = frame.start; i < frame.end; ++i) { - n += included(i); - } - } - } - - return n; - } -}; - -template -struct SkipLess { - inline bool operator()(const T &lhi, const T &rhi) const { - return lhi.second < rhi.second; - } -}; - -template -struct WindowQuantileState { - // Windowed Quantile merge sort trees - unique_ptr qst; - - // Windowed Quantile skip lists - using SkipType = pair; - using SkipListType = duckdb_skiplistlib::skip_list::HeadNode>; - SubFrames prevs; - unique_ptr s; - mutable vector skips; - - // Windowed MAD indirection - idx_t count; - vector m; - - using IncludedType = QuantileIncluded; - using CursorType = QuantileCursor; - - WindowQuantileState() : count(0) { - } - - inline void SetCount(size_t count_p) { - count = count_p; - if (count >= m.size()) { - m.resize(count); - } - } - - inline SkipListType &GetSkipList(bool reset = false) { - if (reset || !s) { - s.reset(); - s = make_uniq(); - } - return *s; - } - - struct SkipListUpdater { - SkipListType &skip; - CursorType &data; - IncludedType &included; - - inline SkipListUpdater(SkipListType &skip, CursorType &data, IncludedType &included) - : skip(skip), data(data), included(included) { - } - - inline void Neither(idx_t begin, idx_t end) { - } - - inline void Left(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - skip.remove(SkipType(begin, data[begin])); - } - } - } - - inline void Right(idx_t begin, idx_t end) { - for (; begin < end; ++begin) { - if (included(begin)) { - skip.insert(SkipType(begin, data[begin])); - } - } - } - - inline void Both(idx_t begin, idx_t end) { - } - }; - - void UpdateSkip(CursorType &data, const SubFrames &frames, IncludedType &included) { - // No overlap, or no data - if (!s || prevs.back().end <= frames.front().start || frames.back().end <= prevs.front().start) { - auto &skip = GetSkipList(true); - for (const auto &frame : frames) { - for (auto i = frame.start; i < frame.end; ++i) { - if (included(i)) { - skip.insert(SkipType(i, data[i])); - } - } - } - } else { - auto &skip = GetSkipList(); - SkipListUpdater updater(skip, data, included); - AggregateExecutor::IntersectFrames(prevs, frames, updater); - } - } - - bool HasTree() const { - return qst.get(); - } - - template - RESULT_TYPE WindowScalar(CursorType &data, const SubFrames &frames, const idx_t n, Vector &result, - const QuantileValue &q) const { - D_ASSERT(n > 0); - if (qst) { - return qst->WindowScalar(data, frames, n, result, q); - } else if (s) { - // Find the position(s) needed - try { - Interpolator interp(q, s->size(), false); - s->at(interp.FRN, interp.CRN - interp.FRN + 1, skips); - array dest; - dest[0] = skips[0].second; - if (skips.size() > 1) { - dest[1] = skips[1].second; - } - return interp.template 
Extract(dest.data(), result); - } catch (const duckdb_skiplistlib::skip_list::IndexError &idx_err) { - throw InternalException(idx_err.message()); - } - } else { - throw InternalException("No accelerator for scalar QUANTILE"); - } - } - - template - void WindowList(CursorType &data, const SubFrames &frames, const idx_t n, Vector &list, const idx_t lidx, - const QuantileBindData &bind_data) const { - D_ASSERT(n > 0); - // Result is a constant LIST with a fixed length - auto ldata = FlatVector::GetData(list); - auto &lentry = ldata[lidx]; - lentry.offset = ListVector::GetListSize(list); - lentry.length = bind_data.quantiles.size(); - - ListVector::Reserve(list, lentry.offset + lentry.length); - ListVector::SetListSize(list, lentry.offset + lentry.length); - auto &result = ListVector::GetEntry(list); - auto rdata = FlatVector::GetData(result); - - for (const auto &q : bind_data.order) { - const auto &quantile = bind_data.quantiles[q]; - rdata[lentry.offset + q] = WindowScalar(data, frames, n, result, quantile); - } - } -}; - -struct QuantileStandardType { - template - static T Operation(T input, AggregateInputData &) { - return input; - } -}; - -struct QuantileStringType { - template - static T Operation(T input, AggregateInputData &input_data) { - if (input.IsInlined()) { - return input; - } - auto string_data = input_data.allocator.Allocate(input.GetSize()); - memcpy(string_data, input.GetData(), input.GetSize()); - return string_t(char_ptr_cast(string_data), UnsafeNumericCast(input.GetSize())); - } -}; - -template -struct QuantileState { - using InputType = INPUT_TYPE; - using CursorType = QuantileCursor; - - // Regular aggregation - vector v; - - // Window Quantile State - unique_ptr> window_state; - unique_ptr window_cursor; - - void AddElement(INPUT_TYPE element, AggregateInputData &aggr_input) { - v.emplace_back(TYPE_OP::Operation(element, aggr_input)); - } - - bool HasTree() const { - return window_state && window_state->HasTree(); - } - WindowQuantileState &GetOrCreateWindowState() { - if (!window_state) { - window_state = make_uniq>(); - } - return *window_state; - } - WindowQuantileState &GetWindowState() { - return *window_state; - } - const WindowQuantileState &GetWindowState() const { - return *window_state; - } - - CursorType &GetOrCreateWindowCursor(const WindowPartitionInput &partition) { - if (!window_cursor) { - window_cursor = make_uniq(partition); - } - return *window_cursor; - } - CursorType &GetWindowCursor() { - return *window_cursor; - } - const CursorType &GetWindowCursor() const { - return *window_cursor; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp deleted file mode 100644 index 40366ef6a..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_count.hpp +++ /dev/null @@ -1,42 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/regression/regr_count.hpp -// -// -//===----------------------------------------------------------------------===// -// REGR_COUNT(y, x) - -#pragma once - -#include "duckdb/function/aggregate_function.hpp" -#include "core_functions/aggregate/algebraic/covar.hpp" -#include "core_functions/aggregate/algebraic/stddev.hpp" - -namespace duckdb { - -struct RegrCountFunction { - template - static void Initialize(STATE &state) { 
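
When the tree is skipped, the window state keeps an ordered structure and UpdateSkip touches only the rows that enter or leave between consecutive frames, rather than rebuilding from scratch. A standalone sliding-window median in the same spirit, with std::multiset standing in for the skip list:

#include <cstddef>
#include <iostream>
#include <iterator>
#include <set>
#include <vector>

// Lower median of the current window contents.
double WindowMedian(const std::multiset<double> &win) {
	auto it = win.begin();
	std::advance(it, (win.size() - 1) / 2);
	return *it;
}

int main() {
	std::vector<double> col = {5, 1, 4, 2, 3, 6};
	const std::size_t width = 3;
	std::multiset<double> win(col.begin(), col.begin() + width);
	for (std::size_t lo = 0; lo + width <= col.size(); ++lo) {
		std::cout << WindowMedian(win) << " ";
		if (lo + width < col.size()) {
			win.erase(win.find(col[lo])); // row leaving on the left
			win.insert(col[lo + width]);  // row entering on the right
		}
	}
	std::cout << "\n"; // 4 2 3 3
}
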
- state = 0; - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target += source; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = static_cast(state); - } - static bool IgnoreNull() { - return true; - } - template - static void Operation(STATE &state, const A_TYPE &, const B_TYPE &, AggregateBinaryInput &) { - state += 1; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp deleted file mode 100644 index d89af040e..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression/regr_slope.hpp +++ /dev/null @@ -1,57 +0,0 @@ -// REGR_SLOPE(y, x) -// Returns the slope of the linear regression line for non-null pairs in a group. -// It is computed for non-null pairs using the following formula: -// COVAR_POP(x,y) / VAR_POP(x) - -//! Input : Any numeric type -//! Output : Double - -#pragma once -#include "core_functions/aggregate/algebraic/stddev.hpp" -#include "core_functions/aggregate/algebraic/covar.hpp" - -namespace duckdb { - -struct RegrSlopeState { - CovarState cov_pop; - StddevState var_pop; -}; - -struct RegrSlopeOperation { - template - static void Initialize(STATE &state) { - CovarOperation::Initialize(state.cov_pop); - STDDevBaseOperation::Initialize(state.var_pop); - } - - template - static void Operation(STATE &state, const A_TYPE &y, const B_TYPE &x, AggregateBinaryInput &idata) { - CovarOperation::Operation(state.cov_pop, y, x, idata); - STDDevBaseOperation::Execute(state.var_pop, x); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - CovarOperation::Combine(source.cov_pop, target.cov_pop, aggr_input_data); - STDDevBaseOperation::Combine(source.var_pop, target.var_pop, aggr_input_data); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (state.cov_pop.count == 0 || state.var_pop.count == 0) { - finalize_data.ReturnNull(); - } else { - auto cov = state.cov_pop.co_moment / state.cov_pop.count; - auto var_pop = state.var_pop.count > 1 ? (state.var_pop.dsquared / state.var_pop.count) : 0; - if (!Value::DoubleIsFinite(var_pop)) { - throw OutOfRangeException("VARPOP is out of range!"); - } - target = var_pop != 0 ? 
cov / var_pop : NAN; - } - } - - static bool IgnoreNull() { - return true; - } -}; -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp deleted file mode 100644 index e82b9fdff..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/regression_functions.hpp +++ /dev/null @@ -1,99 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/aggregate/regression_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct RegrAvgxFun { - static constexpr const char *Name = "regr_avgx"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the average of the independent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct RegrAvgyFun { - static constexpr const char *Name = "regr_avgy"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the average of the dependent variable for non-null pairs in a group, where x is the independent variable and y is the dependent variable."; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct RegrCountFun { - static constexpr const char *Name = "regr_count"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the number of non-null number pairs in a group."; - static constexpr const char *Example = "(SUM(x*y) - SUM(x) * SUM(y) / COUNT(*)) / COUNT(*)"; - - static AggregateFunction GetFunction(); -}; - -struct RegrInterceptFun { - static constexpr const char *Name = "regr_intercept"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the intercept of the univariate linear regression line for non-null pairs in a group."; - static constexpr const char *Example = "AVG(y)-REGR_SLOPE(y,x)*AVG(x)"; - - static AggregateFunction GetFunction(); -}; - -struct RegrR2Fun { - static constexpr const char *Name = "regr_r2"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the coefficient of determination for non-null pairs in a group."; - static constexpr const char *Example = ""; - - static AggregateFunction GetFunction(); -}; - -struct RegrSlopeFun { - static constexpr const char *Name = "regr_slope"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the slope of the linear regression line for non-null pairs in a group."; - static constexpr const char *Example = "COVAR_POP(x,y) / VAR_POP(x)"; - - static AggregateFunction GetFunction(); -}; - -struct RegrSXXFun { - static constexpr const char *Name = "regr_sxx"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = ""; - static constexpr const 
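
regr_slope finalizes as COVAR_POP(x, y) / VAR_POP(x) from the co-moment and squared-deviation accumulators kept by the covar and stddev states. A compact sketch that accumulates both in one pass (Welford-style updates, which is also how such states avoid catastrophic cancellation) and applies the same zero-variance edge rule, assuming a non-empty input of equal-length vectors:

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// One-pass co-moment / squared-deviation accumulation, then
// slope = COVAR_POP(x, y) / VAR_POP(x).
double RegrSlope(const std::vector<double> &y, const std::vector<double> &x) {
	double n = 0, mean_x = 0, mean_y = 0, co_moment = 0, dsquared = 0;
	for (std::size_t i = 0; i < x.size(); ++i) {
		++n;
		const double dx = x[i] - mean_x; // deviation from the old mean
		mean_x += dx / n;
		mean_y += (y[i] - mean_y) / n;
		co_moment += dx * (y[i] - mean_y); // old dx, new mean_y
		dsquared += dx * (x[i] - mean_x);  // old dx, new mean_x
	}
	const double var_pop = dsquared / n;
	return var_pop != 0 ? (co_moment / n) / var_pop : NAN;
}

int main() {
	// y = 2x + 1 exactly, so the slope is 2.
	std::cout << RegrSlope({1, 3, 5, 7}, {0, 1, 2, 3}) << "\n"; // 2
}
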
char *Example = "REGR_COUNT(y, x) * VAR_POP(x)"; - - static AggregateFunction GetFunction(); -}; - -struct RegrSXYFun { - static constexpr const char *Name = "regr_sxy"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Returns the population covariance of input values"; - static constexpr const char *Example = "REGR_COUNT(y, x) * COVAR_POP(y, x)"; - - static AggregateFunction GetFunction(); -}; - -struct RegrSYYFun { - static constexpr const char *Name = "regr_syy"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = ""; - static constexpr const char *Example = "REGR_COUNT(y, x) * VAR_POP(y)"; - - static AggregateFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp b/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp deleted file mode 100644 index 562f61ade..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/aggregate/sum_helpers.hpp +++ /dev/null @@ -1,175 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/aggregate/sum_helpers.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/common.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -static inline void KahanAddInternal(double input, double &summed, double &err) { - double diff = input - err; - double newval = summed + diff; - err = (newval - summed) - diff; - summed = newval; -} - -template -struct SumState { - bool isset; - T value; - - void Initialize() { - this->isset = false; - } - - void Combine(const SumState &other) { - this->isset = other.isset || this->isset; - this->value += other.value; - } -}; - -struct KahanSumState { - bool isset; - double value; - double err; - - void Initialize() { - this->isset = false; - this->err = 0.0; - } - - void Combine(const KahanSumState &other) { - this->isset = other.isset || this->isset; - KahanAddInternal(other.value, this->value, this->err); - KahanAddInternal(other.err, this->value, this->err); - } -}; - -struct RegularAdd { - template - static void AddNumber(STATE &state, T input) { - state.value += input; - } - - template - static void AddConstant(STATE &state, T input, idx_t count) { - state.value += input * int64_t(count); - } -}; - -struct HugeintAdd { - template - static void AddNumber(STATE &state, T input) { - state.value = Hugeint::Add(state.value, input); - } - - template - static void AddConstant(STATE &state, T input, idx_t count) { - AddNumber(state, Hugeint::Multiply(input, UnsafeNumericCast(count))); - } -}; - -struct KahanAdd { - template - static void AddNumber(STATE &state, T input) { - KahanAddInternal(input, state.value, state.err); - } - - template - static void AddConstant(STATE &state, T input, idx_t count) { - KahanAddInternal(input * count, state.value, state.err); - } -}; - -struct AddToHugeint { - static void AddValue(hugeint_t &result, uint64_t value, int positive) { - // integer summation taken from Tim Gubner et al. 
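
KahanAddInternal above carries a compensation term so the low-order bits lost by one addition are re-injected into the next; note that combining two states even feeds the partner's err back through the same routine. A tiny demonstration of why the compensation matters:

#include <cstdio>

// Compensated (Kahan) addition: err holds the low-order bits that plain
// summation would silently drop.
static void KahanAdd(double input, double &sum, double &err) {
	const double diff = input - err;
	const double next = sum + diff;
	err = (next - sum) - diff;
	sum = next;
}

int main() {
	double naive = 1.0, sum = 1.0, err = 0.0;
	for (int i = 0; i < 10000000; ++i) {
		naive += 1e-16; // each addend vanishes against 1.0
		KahanAdd(1e-16, sum, err);
	}
	// The naive sum is still exactly 1; the compensated sum is ~1.000000001.
	std::printf("naive: %.17g\nkahan: %.17g\n", naive, sum);
}
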
- Efficient Query Processing - // with Optimistically Compressed Hash Tables & Strings in the USSR - - // add the value to the lower part of the hugeint - result.lower += value; - // now handle overflows - int overflow = result.lower < value; - // we consider two situations: - // (1) input[idx] is positive, and current value is lower than value: overflow - // (2) input[idx] is negative, and current value is higher than value: underflow - if (!(overflow ^ positive)) { - // in the case of an overflow or underflow we either increment or decrement the upper base - // positive: +1, negative: -1 - result.upper += -1 + 2 * positive; - } - } - - template - static void AddNumber(STATE &state, T input) { - AddValue(state.value, uint64_t(input), input >= 0); - } - - template - static void AddConstant(STATE &state, T input, idx_t count) { - // add a constant X number of times - // fast path: check if value * count fits into a uint64_t - // note that we check if value * VECTOR_SIZE fits in a uint64_t to avoid having to actually do a division - // this is still a pretty high number (18014398509481984) so most positive numbers will fit - if (input >= 0 && uint64_t(input) < (NumericLimits::Maximum() / STANDARD_VECTOR_SIZE)) { - // if it does just multiply it and add the value - uint64_t value = uint64_t(input) * count; - AddValue(state.value, value, 1); - } else { - // if it doesn't fit we have two choices - // either we loop over count and add the values individually - // or we convert to a hugeint and multiply the hugeint - // the problem is that hugeint multiplication is expensive - // hence we switch here: with a low count we do the loop - // with a high count we do the hugeint multiplication - if (count < 8) { - for (idx_t i = 0; i < count; i++) { - AddValue(state.value, uint64_t(input), input >= 0); - } - } else { - hugeint_t addition = hugeint_t(input) * Hugeint::Convert(count); - state.value += addition; - } - } - } -}; - -template -struct BaseSumOperation { - template - static void Initialize(STATE &state) { - state.value = 0; - STATEOP::template Initialize(state); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - STATEOP::template Combine(source, target, aggr_input_data); - } - - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &) { - STATEOP::template AddValues(state, 1); - ADDOP::template AddNumber(state, input); - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &, idx_t count) { - STATEOP::template AddValues(state, count); - ADDOP::template AddConstant(state, input, count); - } - - static bool IgnoreNull() { - return true; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp b/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp deleted file mode 100644 index dd6e29153..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/array_kernels.hpp +++ /dev/null @@ -1,107 +0,0 @@ -#pragma once -#include "duckdb/common/typedefs.hpp" -#include "duckdb/common/algorithm.hpp" -#include - -namespace duckdb { - -//------------------------------------------------------------------------- -// Folding Operations -//------------------------------------------------------------------------- -struct InnerProductOp { - static constexpr bool ALLOW_EMPTY = true; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE 
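
AddToHugeint adds each 64-bit value into the lower word and repairs the upper word whenever the unsigned addition wraps: a wrap while adding a positive value means carry +1, and no wrap while adding a negative one means borrow -1. The XOR in the source encodes exactly that rule; here is a standalone two-word version with the same logic:

#include <cstdint>
#include <iostream>

struct Int128 {
	uint64_t lower = 0;
	int64_t upper = 0;
};

// Mirror of AddToHugeint::AddValue: detect wraparound of the lower word,
// then bump the upper word by +1 (positive overflow) or -1 (negative underflow).
void AddValue(Int128 &result, uint64_t value, int positive) {
	result.lower += value;
	const int overflow = result.lower < value; // unsigned wraparound?
	if (!(overflow ^ positive)) {
		result.upper += -1 + 2 * positive; // positive: +1, negative: -1
	}
}

void AddSigned(Int128 &result, int64_t input) {
	AddValue(result, (uint64_t)input, input >= 0);
}

int main() {
	Int128 v;
	AddSigned(v, -1); // lower wraps to all ones, upper becomes -1
	AddSigned(v, 2);  // carry propagates back: upper 0, lower 1
	std::cout << v.upper << " " << v.lower << "\n"; // 0 1
}
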
*rhs_data, const idx_t count) { - - TYPE result = 0; - - auto lhs_ptr = lhs_data; - auto rhs_ptr = rhs_data; - - for (idx_t i = 0; i < count; i++) { - const auto x = *lhs_ptr++; - const auto y = *rhs_ptr++; - result += x * y; - } - - return result; - } -}; - -struct NegativeInnerProductOp { - static constexpr bool ALLOW_EMPTY = true; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - return -InnerProductOp::Operation(lhs_data, rhs_data, count); - } -}; - -struct CosineSimilarityOp { - static constexpr bool ALLOW_EMPTY = false; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - - TYPE distance = 0; - TYPE norm_l = 0; - TYPE norm_r = 0; - - auto l_ptr = lhs_data; - auto r_ptr = rhs_data; - - for (idx_t i = 0; i < count; i++) { - const auto x = *l_ptr++; - const auto y = *r_ptr++; - distance += x * y; - norm_l += x * x; - norm_r += y * y; - } - - auto similarity = distance / std::sqrt(norm_l * norm_r); - return std::max(static_cast(-1.0), std::min(similarity, static_cast(1.0))); - } -}; - -struct CosineDistanceOp { - static constexpr bool ALLOW_EMPTY = false; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - return static_cast(1.0) - CosineSimilarityOp::Operation(lhs_data, rhs_data, count); - } -}; - -struct DistanceSquaredOp { - static constexpr bool ALLOW_EMPTY = true; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - - TYPE distance = 0; - - auto l_ptr = lhs_data; - auto r_ptr = rhs_data; - - for (idx_t i = 0; i < count; i++) { - const auto x = *l_ptr++; - const auto y = *r_ptr++; - const auto diff = x - y; - distance += diff * diff; - } - - return distance; - } -}; - -struct DistanceOp { - static constexpr bool ALLOW_EMPTY = true; - - template - static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, const idx_t count) { - return std::sqrt(DistanceSquaredOp::Operation(lhs_data, rhs_data, count)); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp b/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp deleted file mode 100644 index 024ca49f8..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/function_list.hpp +++ /dev/null @@ -1,33 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/function_list.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" - -namespace duckdb { - -typedef ScalarFunction (*get_scalar_function_t)(); -typedef ScalarFunctionSet (*get_scalar_function_set_t)(); -typedef AggregateFunction (*get_aggregate_function_t)(); -typedef AggregateFunctionSet (*get_aggregate_function_set_t)(); - -struct StaticFunctionDefinition { - const char *name; - const char *parameters; - const char *description; - const char *example; - get_scalar_function_t get_function; - get_scalar_function_set_t get_function_set; - get_aggregate_function_t get_aggregate_function; - get_aggregate_function_set_t get_aggregate_function_set; - - static const StaticFunctionDefinition *GetFunctionList(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp 
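
The cosine kernel above accumulates the dot product and both squared norms in a single pass, then clamps the ratio to [-1, 1] because floating-point rounding can push it just past unity. The same computation in standalone form:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// One-pass cosine similarity with the same clamp the kernel applies:
// rounding can yield |result| slightly greater than 1 for parallel vectors.
double CosineSimilarity(const std::vector<double> &a, const std::vector<double> &b) {
	double dot = 0, norm_a = 0, norm_b = 0;
	for (std::size_t i = 0; i < a.size(); ++i) {
		dot += a[i] * b[i];
		norm_a += a[i] * a[i];
		norm_b += b[i] * b[i];
	}
	const double sim = dot / std::sqrt(norm_a * norm_b);
	return std::max(-1.0, std::min(sim, 1.0));
}

int main() {
	std::cout << CosineSimilarity({1.0, 2.0, 3.0}, {1.0, 2.0, 3.0}) << "\n"; // 1
	std::cout << CosineSimilarity({1.0, 0.0}, {0.0, 1.0}) << "\n";           // 0
}
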
b/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp deleted file mode 100644 index 561643be4..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/array_functions.hpp +++ /dev/null @@ -1,93 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/array_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct ArrayValueFun { - static constexpr const char *Name = "array_value"; - static constexpr const char *Parameters = "any,..."; - static constexpr const char *Description = "Create an ARRAY containing the argument values."; - static constexpr const char *Example = "array_value(4, 5, 6)"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayCrossProductFun { - static constexpr const char *Name = "array_cross_product"; - static constexpr const char *Parameters = "array, array"; - static constexpr const char *Description = "Compute the cross product of two arrays of size 3. The array elements can not be NULL."; - static constexpr const char *Example = "array_cross_product([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayCosineSimilarityFun { - static constexpr const char *Name = "array_cosine_similarity"; - static constexpr const char *Parameters = "array1,array2"; - static constexpr const char *Description = "Compute the cosine similarity between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; - static constexpr const char *Example = "array_cosine_similarity([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayCosineDistanceFun { - static constexpr const char *Name = "array_cosine_distance"; - static constexpr const char *Parameters = "array1,array2"; - static constexpr const char *Description = "Compute the cosine distance between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; - static constexpr const char *Example = "array_cosine_distance([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayDistanceFun { - static constexpr const char *Name = "array_distance"; - static constexpr const char *Parameters = "array1,array2"; - static constexpr const char *Description = "Compute the distance between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; - static constexpr const char *Example = "array_distance([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayInnerProductFun { - static constexpr const char *Name = "array_inner_product"; - static constexpr const char *Parameters = "array1,array2"; - static constexpr const char *Description = "Compute the inner product between two arrays of the same size. The array elements can not be NULL. 
The arrays can have any size as long as the size is the same for both arguments."; - static constexpr const char *Example = "array_inner_product([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayDotProductFun { - using ALIAS = ArrayInnerProductFun; - - static constexpr const char *Name = "array_dot_product"; -}; - -struct ArrayNegativeInnerProductFun { - static constexpr const char *Name = "array_negative_inner_product"; - static constexpr const char *Parameters = "array1,array2"; - static constexpr const char *Description = "Compute the negative inner product between two arrays of the same size. The array elements can not be NULL. The arrays can have any size as long as the size is the same for both arguments."; - static constexpr const char *Example = "array_negative_inner_product([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayNegativeDotProductFun { - using ALIAS = ArrayNegativeInnerProductFun; - - static constexpr const char *Name = "array_negative_dot_product"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp deleted file mode 100644 index e01a2fc58..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/bit_functions.hpp +++ /dev/null @@ -1,54 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/bit_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct GetBitFun { - static constexpr const char *Name = "get_bit"; - static constexpr const char *Parameters = "bitstring,index"; - static constexpr const char *Description = "Extracts the nth bit from bitstring; the first (leftmost) bit is indexed 0"; - static constexpr const char *Example = "get_bit('0110010'::BIT, 2)"; - - static ScalarFunction GetFunction(); -}; - -struct SetBitFun { - static constexpr const char *Name = "set_bit"; - static constexpr const char *Parameters = "bitstring,index,new_value"; - static constexpr const char *Description = "Sets the nth bit in bitstring to newvalue; the first (leftmost) bit is indexed 0. Returns a new bitstring"; - static constexpr const char *Example = "set_bit('0110010'::BIT, 2, 0)"; - - static ScalarFunction GetFunction(); -}; - -struct BitPositionFun { - static constexpr const char *Name = "bit_position"; - static constexpr const char *Parameters = "substring,bitstring"; - static constexpr const char *Description = "Returns first starting index of the specified substring within bits, or zero if it is not present. 
The first (leftmost) bit is indexed 1"; - static constexpr const char *Example = "bit_position('010'::BIT, '1110101'::BIT)"; - - static ScalarFunction GetFunction(); -}; - -struct BitStringFun { - static constexpr const char *Name = "bitstring"; - static constexpr const char *Parameters = "bitstring,length"; - static constexpr const char *Description = "Pads the bitstring until the specified length"; - static constexpr const char *Example = "bitstring('1010'::BIT, 7)"; - - static ScalarFunctionSet GetFunctions(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp deleted file mode 100644 index 051e212c1..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/blob_functions.hpp +++ /dev/null @@ -1,60 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/blob_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct DecodeFun { - static constexpr const char *Name = "decode"; - static constexpr const char *Parameters = "blob"; - static constexpr const char *Description = "Convert blob to varchar. Fails if blob is not valid utf-8"; - static constexpr const char *Example = "decode('\\xC3\\xBC'::BLOB)"; - - static ScalarFunction GetFunction(); -}; - -struct EncodeFun { - static constexpr const char *Name = "encode"; - static constexpr const char *Parameters = "string"; - static constexpr const char *Description = "Convert varchar to blob. 
Converts utf-8 characters into literal encoding"; - static constexpr const char *Example = "encode('my_string_with_ü')"; - - static ScalarFunction GetFunction(); -}; - -struct FromBase64Fun { - static constexpr const char *Name = "from_base64"; - static constexpr const char *Parameters = "string"; - static constexpr const char *Description = "Convert a base64 encoded string to a character string"; - static constexpr const char *Example = "from_base64('QQ==')"; - - static ScalarFunction GetFunction(); -}; - -struct ToBase64Fun { - static constexpr const char *Name = "to_base64"; - static constexpr const char *Parameters = "blob"; - static constexpr const char *Description = "Convert a blob to a base64 encoded string"; - static constexpr const char *Example = "base64('A'::blob)"; - - static ScalarFunction GetFunction(); -}; - -struct Base64Fun { - using ALIAS = ToBase64Fun; - - static constexpr const char *Name = "base64"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp deleted file mode 100644 index 7256502a9..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/date_functions.hpp +++ /dev/null @@ -1,603 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/date_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct AgeFun { - static constexpr const char *Name = "age"; - static constexpr const char *Parameters = "timestamp,timestamp"; - static constexpr const char *Description = "Subtract arguments, resulting in the time difference between the two timestamps"; - static constexpr const char *Example = "age(TIMESTAMP '2001-04-10', TIMESTAMP '1992-09-20')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct CenturyFun { - static constexpr const char *Name = "century"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the century component from a date or timestamp"; - static constexpr const char *Example = "century(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DateDiffFun { - static constexpr const char *Name = "date_diff"; - static constexpr const char *Parameters = "part,startdate,enddate"; - static constexpr const char *Description = "The number of partition boundaries between the timestamps"; - static constexpr const char *Example = "date_diff('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DatediffFun { - using ALIAS = DateDiffFun; - - static constexpr const char *Name = "datediff"; -}; - -struct DatePartFun { - static constexpr const char *Name = "date_part"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Get subfield (equivalent to extract)"; - static constexpr const char *Example = "date_part('minute', TIMESTAMP '1992-09-20 20:38:40')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DatepartFun 
{ - using ALIAS = DatePartFun; - - static constexpr const char *Name = "datepart"; -}; - -struct DateSubFun { - static constexpr const char *Name = "date_sub"; - static constexpr const char *Parameters = "part,startdate,enddate"; - static constexpr const char *Description = "The number of complete partitions between the timestamps"; - static constexpr const char *Example = "date_sub('hour', TIMESTAMPTZ '1992-09-30 23:59:59', TIMESTAMPTZ '1992-10-01 01:58:00')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DatesubFun { - using ALIAS = DateSubFun; - - static constexpr const char *Name = "datesub"; -}; - -struct DateTruncFun { - static constexpr const char *Name = "date_trunc"; - static constexpr const char *Parameters = "part,timestamp"; - static constexpr const char *Description = "Truncate to specified precision"; - static constexpr const char *Example = "date_trunc('hour', TIMESTAMPTZ '1992-09-20 20:38:40')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DatetruncFun { - using ALIAS = DateTruncFun; - - static constexpr const char *Name = "datetrunc"; -}; - -struct DayFun { - static constexpr const char *Name = "day"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the day component from a date or timestamp"; - static constexpr const char *Example = "day(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DayNameFun { - static constexpr const char *Name = "dayname"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "The (English) name of the weekday"; - static constexpr const char *Example = "dayname(TIMESTAMP '1992-03-22')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DayOfMonthFun { - static constexpr const char *Name = "dayofmonth"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the dayofmonth component from a date or timestamp"; - static constexpr const char *Example = "dayofmonth(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DayOfWeekFun { - static constexpr const char *Name = "dayofweek"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the dayofweek component from a date or timestamp"; - static constexpr const char *Example = "dayofweek(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DayOfYearFun { - static constexpr const char *Name = "dayofyear"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the dayofyear component from a date or timestamp"; - static constexpr const char *Example = "dayofyear(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct DecadeFun { - static constexpr const char *Name = "decade"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the decade component from a date or timestamp"; - static constexpr const char *Example = "decade(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct EpochFun { - static constexpr const char *Name = "epoch"; - static constexpr const char *Parameters = "temporal"; - static constexpr const char *Description = "Extract the epoch component from a temporal type"; - static constexpr const char *Example = "epoch(timestamp 
'2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct EpochMsFun { - static constexpr const char *Name = "epoch_ms"; - static constexpr const char *Parameters = "temporal"; - static constexpr const char *Description = "Extract the epoch component in milliseconds from a temporal type"; - static constexpr const char *Example = "epoch_ms(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct EpochUsFun { - static constexpr const char *Name = "epoch_us"; - static constexpr const char *Parameters = "temporal"; - static constexpr const char *Description = "Extract the epoch component in microseconds from a temporal type"; - static constexpr const char *Example = "epoch_us(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct EpochNsFun { - static constexpr const char *Name = "epoch_ns"; - static constexpr const char *Parameters = "temporal"; - static constexpr const char *Description = "Extract the epoch component in nanoseconds from a temporal type"; - static constexpr const char *Example = "epoch_ns(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct EraFun { - static constexpr const char *Name = "era"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the era component from a date or timestamp"; - static constexpr const char *Example = "era(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct GetCurrentTimestampFun { - static constexpr const char *Name = "get_current_timestamp"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the current timestamp"; - static constexpr const char *Example = "get_current_timestamp()"; - - static ScalarFunction GetFunction(); -}; - -struct NowFun { - using ALIAS = GetCurrentTimestampFun; - - static constexpr const char *Name = "now"; -}; - -struct TransactionTimestampFun { - using ALIAS = GetCurrentTimestampFun; - - static constexpr const char *Name = "transaction_timestamp"; -}; - -struct HoursFun { - static constexpr const char *Name = "hour"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the hour component from a date or timestamp"; - static constexpr const char *Example = "hour(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ISODayOfWeekFun { - static constexpr const char *Name = "isodow"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the isodow component from a date or timestamp"; - static constexpr const char *Example = "isodow(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ISOYearFun { - static constexpr const char *Name = "isoyear"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the isoyear component from a date or timestamp"; - static constexpr const char *Example = "isoyear(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct JulianDayFun { - static constexpr const char *Name = "julian"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the Julian Day number from a date or timestamp"; - static constexpr const char *Example = "julian(timestamp 
'2006-01-01 12:00')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct LastDayFun { - static constexpr const char *Name = "last_day"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Returns the last day of the month"; - static constexpr const char *Example = "last_day(TIMESTAMP '1992-03-22 01:02:03.1234')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MakeDateFun { - static constexpr const char *Name = "make_date"; - static constexpr const char *Parameters = "year,month,day\1date-struct::STRUCT(year BIGINT, month BIGINT, day BIGINT)"; - static constexpr const char *Description = "The date for the given parts\1The date for the given struct."; - static constexpr const char *Example = "make_date(1992, 9, 20)\1make_date({'year': 2024, 'month': 11, 'day': 14})"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MakeTimeFun { - static constexpr const char *Name = "make_time"; - static constexpr const char *Parameters = "hour,minute,seconds"; - static constexpr const char *Description = "The time for the given parts"; - static constexpr const char *Example = "make_time(13, 34, 27.123456)"; - - static ScalarFunction GetFunction(); -}; - -struct MakeTimestampFun { - static constexpr const char *Name = "make_timestamp"; - static constexpr const char *Parameters = "year,month,day,hour,minute,seconds"; - static constexpr const char *Description = "The timestamp for the given parts"; - static constexpr const char *Example = "make_timestamp(1992, 9, 20, 13, 34, 27.123456)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MakeTimestampNsFun { - static constexpr const char *Name = "make_timestamp_ns"; - static constexpr const char *Parameters = "nanos"; - static constexpr const char *Description = "The timestamp for the given nanoseconds since epoch"; - static constexpr const char *Example = "make_timestamp(1732117793000000000)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MicrosecondsFun { - static constexpr const char *Name = "microsecond"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the microsecond component from a date or timestamp"; - static constexpr const char *Example = "microsecond(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MillenniumFun { - static constexpr const char *Name = "millennium"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the millennium component from a date or timestamp"; - static constexpr const char *Example = "millennium(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MillisecondsFun { - static constexpr const char *Name = "millisecond"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the millisecond component from a date or timestamp"; - static constexpr const char *Example = "millisecond(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MinutesFun { - static constexpr const char *Name = "minute"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the minute component from a date or timestamp"; - static constexpr const char *Example = "minute(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MonthFun { - static constexpr const char *Name = 
"month"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the month component from a date or timestamp"; - static constexpr const char *Example = "month(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct MonthNameFun { - static constexpr const char *Name = "monthname"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "The (English) name of the month"; - static constexpr const char *Example = "monthname(TIMESTAMP '1992-09-20')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct NanosecondsFun { - static constexpr const char *Name = "nanosecond"; - static constexpr const char *Parameters = "tsns"; - static constexpr const char *Description = "Extract the nanosecond component from a date or timestamp"; - static constexpr const char *Example = "nanosecond(timestamp_ns '2021-08-03 11:59:44.123456789') => 44123456789"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct NormalizedIntervalFun { - static constexpr const char *Name = "normalized_interval"; - static constexpr const char *Parameters = "interval"; - static constexpr const char *Description = "Normalizes an INTERVAL to an equivalent interval"; - static constexpr const char *Example = "normalized_interval(INTERVAL '30 days')"; - - static ScalarFunction GetFunction(); -}; - -struct QuarterFun { - static constexpr const char *Name = "quarter"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the quarter component from a date or timestamp"; - static constexpr const char *Example = "quarter(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct SecondsFun { - static constexpr const char *Name = "second"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the second component from a date or timestamp"; - static constexpr const char *Example = "second(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct TimeBucketFun { - static constexpr const char *Name = "time_bucket"; - static constexpr const char *Parameters = "bucket_width,timestamp,origin"; - static constexpr const char *Description = "Truncate TIMESTAMPTZ by the specified interval bucket_width. Buckets are aligned relative to origin TIMESTAMPTZ. 
The origin defaults to 2000-01-03 00:00:00+00 for buckets that do not include a month or year interval, and to 2000-01-01 00:00:00+00 for month and year buckets"; - static constexpr const char *Example = "time_bucket(INTERVAL '2 weeks', TIMESTAMP '1992-04-20 15:26:00-07', TIMESTAMP '1992-04-01 00:00:00-07')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct TimezoneFun { - static constexpr const char *Name = "timezone"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the timezone component from a date or timestamp"; - static constexpr const char *Example = "timezone(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct TimezoneHourFun { - static constexpr const char *Name = "timezone_hour"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the timezone_hour component from a date or timestamp"; - static constexpr const char *Example = "timezone_hour(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct TimezoneMinuteFun { - static constexpr const char *Name = "timezone_minute"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the timezone_minute component from a date or timestamp"; - static constexpr const char *Example = "timezone_minute(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct TimeTZSortKeyFun { - static constexpr const char *Name = "timetz_byte_comparable"; - static constexpr const char *Parameters = "time_tz"; - static constexpr const char *Description = "Converts a TIME WITH TIME ZONE to an integer sort key"; - static constexpr const char *Example = "timetz_byte_comparable('18:18:16.21-07:00'::TIME_TZ)"; - - static ScalarFunction GetFunction(); -}; - -struct ToCenturiesFun { - static constexpr const char *Name = "to_centuries"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a century interval"; - static constexpr const char *Example = "to_centuries(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToDaysFun { - static constexpr const char *Name = "to_days"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a day interval"; - static constexpr const char *Example = "to_days(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToDecadesFun { - static constexpr const char *Name = "to_decades"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a decade interval"; - static constexpr const char *Example = "to_decades(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToHoursFun { - static constexpr const char *Name = "to_hours"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct an hour interval"; - static constexpr const char *Example = "to_hours(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToMicrosecondsFun { - static constexpr const char *Name = "to_microseconds"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a microsecond interval"; - static constexpr const char *Example = "to_microseconds(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToMillenniaFun { -
-struct ToMillenniaFun { - static constexpr const char *Name = "to_millennia"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a millennium interval"; - static constexpr const char *Example = "to_millennia(1)"; - - static ScalarFunction GetFunction(); -}; - -struct ToMillisecondsFun { - static constexpr const char *Name = "to_milliseconds"; - static constexpr const char *Parameters = "double"; - static constexpr const char *Description = "Construct a millisecond interval"; - static constexpr const char *Example = "to_milliseconds(5.5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToMinutesFun { - static constexpr const char *Name = "to_minutes"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a minute interval"; - static constexpr const char *Example = "to_minutes(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToMonthsFun { - static constexpr const char *Name = "to_months"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a month interval"; - static constexpr const char *Example = "to_months(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToQuartersFun { - static constexpr const char *Name = "to_quarters"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a quarter interval"; - static constexpr const char *Example = "to_quarters(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToSecondsFun { - static constexpr const char *Name = "to_seconds"; - static constexpr const char *Parameters = "double"; - static constexpr const char *Description = "Construct a second interval"; - static constexpr const char *Example = "to_seconds(5.5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToTimestampFun { - static constexpr const char *Name = "to_timestamp"; - static constexpr const char *Parameters = "sec"; - static constexpr const char *Description = "Converts seconds since the epoch to a timestamp with time zone"; - static constexpr const char *Example = "to_timestamp(1284352323.5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToWeeksFun { - static constexpr const char *Name = "to_weeks"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a week interval"; - static constexpr const char *Example = "to_weeks(5)"; - - static ScalarFunction GetFunction(); -}; - -struct ToYearsFun { - static constexpr const char *Name = "to_years"; - static constexpr const char *Parameters = "integer"; - static constexpr const char *Description = "Construct a year interval"; - static constexpr const char *Example = "to_years(5)"; - - static ScalarFunction GetFunction(); -}; - -struct WeekFun { - static constexpr const char *Name = "week"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the week component from a date or timestamp"; - static constexpr const char *Example = "week(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct WeekDayFun { - static constexpr const char *Name = "weekday"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the weekday component from a date or timestamp"; - static constexpr const char *Example = "weekday(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct WeekOfYearFun { - static constexpr const char *Name 
= "weekofyear"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the weekofyear component from a date or timestamp"; - static constexpr const char *Example = "weekofyear(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct YearFun { - static constexpr const char *Name = "year"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the year component from a date or timestamp"; - static constexpr const char *Example = "year(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct YearWeekFun { - static constexpr const char *Name = "yearweek"; - static constexpr const char *Parameters = "ts"; - static constexpr const char *Description = "Extract the yearweek component from a date or timestamp"; - static constexpr const char *Example = "yearweek(timestamp '2021-08-03 11:59:44.123456')"; - - static ScalarFunctionSet GetFunctions(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp deleted file mode 100644 index ce4debc6d..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/debug_functions.hpp +++ /dev/null @@ -1,27 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/debug_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct VectorTypeFun { - static constexpr const char *Name = "vector_type"; - static constexpr const char *Parameters = "col"; - static constexpr const char *Description = "Returns the VectorType of a given column"; - static constexpr const char *Example = "vector_type(col)"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp deleted file mode 100644 index 73791f8a5..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/enum_functions.hpp +++ /dev/null @@ -1,63 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/enum_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct EnumFirstFun { - static constexpr const char *Name = "enum_first"; - static constexpr const char *Parameters = "enum"; - static constexpr const char *Description = "Returns the first value of the input enum type"; - static constexpr const char *Example = "enum_first(NULL::mood)"; - - static ScalarFunction GetFunction(); 
-}; - -struct EnumLastFun { - static constexpr const char *Name = "enum_last"; - static constexpr const char *Parameters = "enum"; - static constexpr const char *Description = "Returns the last value of the input enum type"; - static constexpr const char *Example = "enum_last(NULL::mood)"; - - static ScalarFunction GetFunction(); -}; - -struct EnumCodeFun { - static constexpr const char *Name = "enum_code"; - static constexpr const char *Parameters = "enum"; - static constexpr const char *Description = "Returns the numeric value backing the given enum value"; - static constexpr const char *Example = "enum_code('happy'::mood)"; - - static ScalarFunction GetFunction(); -}; - -struct EnumRangeFun { - static constexpr const char *Name = "enum_range"; - static constexpr const char *Parameters = "enum"; - static constexpr const char *Description = "Returns all values of the input enum type as an array"; - static constexpr const char *Example = "enum_range(NULL::mood)"; - - static ScalarFunction GetFunction(); -}; - -struct EnumRangeBoundaryFun { - static constexpr const char *Name = "enum_range_boundary"; - static constexpr const char *Parameters = "start,end"; - static constexpr const char *Description = "Returns the range between the two given enum values as an array. The values must be of the same enum type. When the first parameter is NULL, the result starts with the first value of the enum type. When the second parameter is NULL, the result ends with the last value of the enum type"; - static constexpr const char *Example = "enum_range_boundary(NULL, 'happy'::mood)"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp deleted file mode 100644 index d874e72a9..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/generic_functions.hpp +++ /dev/null @@ -1,171 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/generic_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct AliasFun { - static constexpr const char *Name = "alias"; - static constexpr const char *Parameters = "expr"; - static constexpr const char *Description = "Returns the name of a given expression"; - static constexpr const char *Example = "alias(42 + 1)"; - - static ScalarFunction GetFunction(); -}; - -struct CurrentSettingFun { - static constexpr const char *Name = "current_setting"; - static constexpr const char *Parameters = "setting_name"; - static constexpr const char *Description = "Returns the current value of the configuration setting"; - static constexpr const char *Example = "current_setting('access_mode')"; - - static ScalarFunction GetFunction(); -}; - -struct HashFun { - static constexpr const char *Name = "hash"; - static constexpr const char *Parameters = "param"; - static constexpr const char *Description = "Returns an integer with the hash of the value. 
Note that this is not a cryptographic hash"; - static constexpr const char *Example = "hash('🦆')"; - - static ScalarFunction GetFunction(); -}; - -struct LeastFun { - static constexpr const char *Name = "least"; - static constexpr const char *Parameters = "arg1, arg2, ..."; - static constexpr const char *Description = "Returns the lowest value of the set of input parameters"; - static constexpr const char *Example = "least(42, 84)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct GreatestFun { - static constexpr const char *Name = "greatest"; - static constexpr const char *Parameters = "arg1, arg2, ..."; - static constexpr const char *Description = "Returns the highest value of the set of input parameters"; - static constexpr const char *Example = "greatest(42, 84)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct StatsFun { - static constexpr const char *Name = "stats"; - static constexpr const char *Parameters = "expression"; - static constexpr const char *Description = "Returns a string with statistics about the expression. Expression can be a column, constant, or SQL expression"; - static constexpr const char *Example = "stats(5)"; - - static ScalarFunction GetFunction(); -}; - -struct TypeOfFun { - static constexpr const char *Name = "typeof"; - static constexpr const char *Parameters = "expression"; - static constexpr const char *Description = "Returns the name of the data type of the result of the expression"; - static constexpr const char *Example = "typeof('abc')"; - - static ScalarFunction GetFunction(); -}; - -struct CanCastImplicitlyFun { - static constexpr const char *Name = "can_cast_implicitly"; - static constexpr const char *Parameters = "source_type,target_type"; - static constexpr const char *Description = "Whether or not we can implicitly cast from the source type to the target type"; - static constexpr const char *Example = "can_cast_implicitly(NULL::INTEGER, NULL::BIGINT)"; - - static ScalarFunction GetFunction(); -}; - -struct CurrentQueryFun { - static constexpr const char *Name = "current_query"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the current query as a string"; - static constexpr const char *Example = "current_query()"; - - static ScalarFunction GetFunction(); -}; - -struct CurrentSchemaFun { - static constexpr const char *Name = "current_schema"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the name of the currently active schema. Default is main"; - static constexpr const char *Example = "current_schema()"; - - static ScalarFunction GetFunction(); -}; - 
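Note the corrected example above: the function is registered as can_cast_implicitly, so the sample call has to use that name. A small sketch that runs it through the public C++ API and checks for a binder error (assuming duckdb.hpp and an in-memory database; illustrative only):

#include "duckdb.hpp"
#include <iostream>

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// typeof reports the result type; can_cast_implicitly tests implicit castability.
	auto result = con.Query("SELECT typeof('abc') AS ty, can_cast_implicitly(NULL::INTEGER, NULL::BIGINT) AS ok");
	if (result->HasError()) {
		std::cerr << result->GetError() << std::endl;
		return 1;
	}
	result->Print();
	return 0;
}
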
-struct CurrentSchemasFun { - static constexpr const char *Name = "current_schemas"; - static constexpr const char *Parameters = "include_implicit"; - static constexpr const char *Description = "Returns a list of schemas. Pass a parameter of true to include implicit schemas"; - static constexpr const char *Example = "current_schemas(true)"; - - static ScalarFunction GetFunction(); -}; - -struct CurrentDatabaseFun { - static constexpr const char *Name = "current_database"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the name of the currently active database"; - static constexpr const char *Example = "current_database()"; - - static ScalarFunction GetFunction(); -}; - -struct InSearchPathFun { - static constexpr const char *Name = "in_search_path"; - static constexpr const char *Parameters = "database_name,schema_name"; - static constexpr const char *Description = "Returns whether or not the database/schema are in the search path"; - static constexpr const char *Example = "in_search_path('memory', 'main')"; - - static ScalarFunction GetFunction(); -}; - -struct CurrentTransactionIdFun { - static constexpr const char *Name = "txid_current"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the current transaction’s ID (a BIGINT). It will assign a new one if the current transaction does not have one already"; - static constexpr const char *Example = "txid_current()"; - - static ScalarFunction GetFunction(); -}; - -struct VersionFun { - static constexpr const char *Name = "version"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the currently active version of DuckDB in this format: v0.3.2"; - static constexpr const char *Example = "version()"; - - static ScalarFunction GetFunction(); -}; - -struct EquiWidthBinsFun { - static constexpr const char *Name = "equi_width_bins"; - static constexpr const char *Parameters = "min,max,bin_count,nice_rounding"; - static constexpr const char *Description = "Generates bin_count equi-width bins between the min and max. 
If nice_rounding is enabled, the numbers are rounded to be more readable/less jagged"; - static constexpr const char *Example = "equi_width_bins(0, 10, 2, true)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct IsHistogramOtherBinFun { - static constexpr const char *Name = "is_histogram_other_bin"; - static constexpr const char *Parameters = "val"; - static constexpr const char *Description = "Whether or not the provided value is the histogram \"other\" bin (used for values not belonging to any provided bin)"; - static constexpr const char *Example = "is_histogram_other_bin(v)"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp deleted file mode 100644 index 2b9318b40..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/list_functions.hpp +++ /dev/null @@ -1,390 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/list_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct ListFlattenFun { - static constexpr const char *Name = "flatten"; - static constexpr const char *Parameters = "nested_list"; - static constexpr const char *Description = "Flatten a nested list by one level"; - static constexpr const char *Example = "flatten([[1, 2, 3], [4, 5]])"; - - static ScalarFunction GetFunction(); -}; - -struct ListAggregateFun { - static constexpr const char *Name = "list_aggregate"; - static constexpr const char *Parameters = "list,name"; - static constexpr const char *Description = "Executes the aggregate function name on the elements of list"; - static constexpr const char *Example = "list_aggregate([1, 2, NULL], 'min')"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayAggregateFun { - using ALIAS = ListAggregateFun; - - static constexpr const char *Name = "array_aggregate"; -}; - -struct ListAggrFun { - using ALIAS = ListAggregateFun; - - static constexpr const char *Name = "list_aggr"; -}; - -struct ArrayAggrFun { - using ALIAS = ListAggregateFun; - - static constexpr const char *Name = "array_aggr"; -}; - -struct AggregateFun { - using ALIAS = ListAggregateFun; - - static constexpr const char *Name = "aggregate"; -}; - 
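list_aggregate and the aliases above take the aggregate to run as a string in the second argument, which makes it easy to drive from C++. A minimal sketch against the public API (assuming an in-memory database; 'min' stands in for any registered aggregate name):

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// NULL elements are skipped by the dispatched aggregate, as with plain min().
	auto result = con.Query("SELECT list_aggregate([1, 2, NULL], 'min') AS list_min");
	result->Print();
	return 0;
}
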
-struct ListDistinctFun { - static constexpr const char *Name = "list_distinct"; - static constexpr const char *Parameters = "list"; - static constexpr const char *Description = "Removes all duplicates and NULLs from a list. Does not preserve the original order"; - static constexpr const char *Example = "list_distinct([1, 1, NULL, -3, 1, 5])"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayDistinctFun { - using ALIAS = ListDistinctFun; - - static constexpr const char *Name = "array_distinct"; -}; - -struct ListUniqueFun { - static constexpr const char *Name = "list_unique"; - static constexpr const char *Parameters = "list"; - static constexpr const char *Description = "Counts the unique elements of a list"; - static constexpr const char *Example = "list_unique([1, 1, NULL, -3, 1, 5])"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayUniqueFun { - using ALIAS = ListUniqueFun; - - static constexpr const char *Name = "array_unique"; -}; - -struct ListValueFun { - static constexpr const char *Name = "list_value"; - static constexpr const char *Parameters = "any,..."; - static constexpr const char *Description = "Create a LIST containing the argument values"; - static constexpr const char *Example = "list_value(4, 5, 6)"; - - static ScalarFunction GetFunction(); -}; - -struct ListPackFun { - using ALIAS = ListValueFun; - - static constexpr const char *Name = "list_pack"; -}; - -struct ListSliceFun { - static constexpr const char *Name = "list_slice"; - static constexpr const char *Parameters = "list,begin,end\1list,begin,end,step"; - static constexpr const char *Description = "Extract a sublist using slice conventions. Negative values are accepted.\1list_slice with added step feature."; - static constexpr const char *Example = "list_slice([4, 5, 6], 2, 3)\2array_slice('DuckDB', 3, 4)\2array_slice('DuckDB', 3, NULL)\2array_slice('DuckDB', 0, -3)\1list_slice([4, 5, 6], 1, 3, 2)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArraySliceFun { - using ALIAS = ListSliceFun; - - static constexpr const char *Name = "array_slice"; -}; - -struct ListSortFun { - static constexpr const char *Name = "list_sort"; - static constexpr const char *Parameters = "list"; - static constexpr const char *Description = "Sorts the elements of the list"; - static constexpr const char *Example = "list_sort([3, 6, 1, 2])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArraySortFun { - using ALIAS = ListSortFun; - - static constexpr const char *Name = "array_sort"; -}; - -struct ListGradeUpFun { - static constexpr const char *Name = "list_grade_up"; - static constexpr const char *Parameters = "list"; - static constexpr const char *Description = "Returns the indexes that would sort the list, rather than the sorted values themselves."; - static constexpr const char *Example = "list_grade_up([3, 6, 1, 2])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayGradeUpFun { - using ALIAS = ListGradeUpFun; - - static constexpr const char *Name = "array_grade_up"; -}; - -struct GradeUpFun { - using ALIAS = ListGradeUpFun; - - static constexpr const char *Name = "grade_up"; -}; - -struct ListReverseSortFun { - static constexpr const char *Name = "list_reverse_sort"; - static constexpr const char *Parameters = "list"; - static constexpr const char *Description = "Sorts the elements of the list in reverse order"; - static constexpr const char *Example = "list_reverse_sort([3, 6, 1, 2])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ArrayReverseSortFun { - using ALIAS = ListReverseSortFun; - - static constexpr const char *Name = "array_reverse_sort"; -}; - 
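The next group of entries (list_transform, list_filter, list_reduce) all take a lambda as their second argument. A brief sketch of the three through the C++ API (assuming an in-memory database; the lambdas are parsed from the SQL text, so the query below is illustrative only):

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// transform maps, filter selects, reduce folds left-to-right.
	auto result = con.Query("SELECT list_transform([1, 2, 3], x -> x + 1) AS mapped, "
	                        "list_filter([3, 4, 5], x -> x > 4) AS kept, "
	                        "list_reduce([1, 2, 3], (x, y) -> x + y) AS total");
	result->Print();
	return 0;
}
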
-struct ListTransformFun { - static constexpr const char *Name = "list_transform"; - static constexpr const char *Parameters = "list,lambda"; - static constexpr const char *Description = "Returns a list that is the result of applying the lambda function to each element of the input list. See the Lambda Functions section for more details"; - static constexpr const char *Example = "list_transform([1, 2, 3], x -> x + 1)"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayTransformFun { - using ALIAS = ListTransformFun; - - static constexpr const char *Name = "array_transform"; -}; - -struct ListApplyFun { - using ALIAS = ListTransformFun; - - static constexpr const char *Name = "list_apply"; -}; - -struct ArrayApplyFun { - using ALIAS = ListTransformFun; - - static constexpr const char *Name = "array_apply"; -}; - -struct ApplyFun { - using ALIAS = ListTransformFun; - - static constexpr const char *Name = "apply"; -}; - -struct ListFilterFun { - static constexpr const char *Name = "list_filter"; - static constexpr const char *Parameters = "list,lambda"; - static constexpr const char *Description = "Constructs a list from those elements of the input list for which the lambda function returns true"; - static constexpr const char *Example = "list_filter([3, 4, 5], x -> x > 4)"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayFilterFun { - using ALIAS = ListFilterFun; - - static constexpr const char *Name = "array_filter"; -}; - -struct FilterFun { - using ALIAS = ListFilterFun; - - static constexpr const char *Name = "filter"; -}; - -struct ListReduceFun { - static constexpr const char *Name = "list_reduce"; - static constexpr const char *Parameters = "list,lambda"; - static constexpr const char *Description = "Returns a single value that is the result of applying the lambda function to each element of the input list, starting with the first element and then repeatedly applying the lambda function to the result of the previous application and the next element of the list."; - static constexpr const char *Example = "list_reduce([1, 2, 3], (x, y) -> x + y)"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayReduceFun { - using ALIAS = ListReduceFun; - - static constexpr const char *Name = "array_reduce"; -}; - -struct ReduceFun { - using ALIAS = ListReduceFun; - - static constexpr const char *Name = "reduce"; -}; - -struct GenerateSeriesFun { - static constexpr const char *Name = "generate_series"; - static constexpr const char *Parameters = "start,stop,step"; - static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is inclusive"; - static constexpr const char *Example = "generate_series(2, 5, 3)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListRangeFun { - static constexpr const char *Name = "range"; - static constexpr const char *Parameters = "start,stop,step"; - static constexpr const char *Description = "Create a list of values between start and stop - the stop parameter is exclusive"; - static constexpr const char *Example = "range(2, 5, 3)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListCosineDistanceFun { - static constexpr const char *Name = "list_cosine_distance"; - static constexpr const char *Parameters = "list1,list2"; - static constexpr const char *Description = "Compute the cosine distance between two lists"; - static constexpr const char *Example = "list_cosine_distance([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListCosineDistanceFunAlias { - using ALIAS = ListCosineDistanceFun; - - static constexpr const char *Name = "<=>"; -}; - -struct ListCosineSimilarityFun { - static constexpr 
const char *Name = "list_cosine_similarity"; - static constexpr const char *Parameters = "list1,list2"; - static constexpr const char *Description = "Compute the cosine similarity between two lists"; - static constexpr const char *Example = "list_cosine_similarity([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListDistanceFun { - static constexpr const char *Name = "list_distance"; - static constexpr const char *Parameters = "list1,list2"; - static constexpr const char *Description = "Compute the distance between two lists"; - static constexpr const char *Example = "list_distance([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListDistanceFunAlias { - using ALIAS = ListDistanceFun; - - static constexpr const char *Name = "<->"; -}; - -struct ListInnerProductFun { - static constexpr const char *Name = "list_inner_product"; - static constexpr const char *Parameters = "list1,list2"; - static constexpr const char *Description = "Compute the inner product between two lists"; - static constexpr const char *Example = "list_inner_product([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListDotProductFun { - using ALIAS = ListInnerProductFun; - - static constexpr const char *Name = "list_dot_product"; -}; - -struct ListNegativeInnerProductFun { - static constexpr const char *Name = "list_negative_inner_product"; - static constexpr const char *Parameters = "list1,list2"; - static constexpr const char *Description = "Compute the negative inner product between two lists"; - static constexpr const char *Example = "list_negative_inner_product([1, 2, 3], [1, 2, 3])"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ListNegativeDotProductFun { - using ALIAS = ListNegativeInnerProductFun; - - static constexpr const char *Name = "list_negative_dot_product"; -}; - -struct UnpivotListFun { - static constexpr const char *Name = "unpivot_list"; - static constexpr const char *Parameters = "any,..."; - static constexpr const char *Description = "Identical to list_value, but generated as part of unpivot for better error messages"; - static constexpr const char *Example = "unpivot_list(4, 5, 6)"; - - static ScalarFunction GetFunction(); -}; - -struct ListHasAnyFun { - static constexpr const char *Name = "list_has_any"; - static constexpr const char *Parameters = "l1, l2"; - static constexpr const char *Description = "Returns true if the lists have any element in common. NULLs are ignored."; - static constexpr const char *Example = "list_has_any([1, 2, 3], [2, 3, 4])"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayHasAnyFun { - using ALIAS = ListHasAnyFun; - - static constexpr const char *Name = "array_has_any"; -}; - -struct ListHasAnyFunAlias { - using ALIAS = ListHasAnyFun; - - static constexpr const char *Name = "&&"; -}; - -struct ListHasAllFun { - static constexpr const char *Name = "list_has_all"; - static constexpr const char *Parameters = "l1, l2"; - static constexpr const char *Description = "Returns true if all elements of l2 are in l1. 
NULLs are ignored."; - static constexpr const char *Example = "list_has_all([1, 2, 3], [2, 3])"; - - static ScalarFunction GetFunction(); -}; - -struct ArrayHasAllFun { - using ALIAS = ListHasAllFun; - - static constexpr const char *Name = "array_has_all"; -}; - -struct ListHasAllFunAlias { - using ALIAS = ListHasAllFun; - - static constexpr const char *Name = "@>"; -}; - -struct ListHasAllFunAlias2 { - using ALIAS = ListHasAllFun; - - static constexpr const char *Name = "<@"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp deleted file mode 100644 index 0998a3156..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/map_functions.hpp +++ /dev/null @@ -1,96 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/map_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct CardinalityFun { - static constexpr const char *Name = "cardinality"; - static constexpr const char *Parameters = "map"; - static constexpr const char *Description = "Returns the size of the map (or the number of entries in the map)"; - static constexpr const char *Example = "cardinality( map([4, 2], ['a', 'b']) );"; - - static ScalarFunction GetFunction(); -}; - -struct MapFun { - static constexpr const char *Name = "map"; - static constexpr const char *Parameters = "keys,values"; - static constexpr const char *Description = "Creates a map from a set of keys and values"; - static constexpr const char *Example = "map(['key1', 'key2'], ['val1', 'val2'])"; - - static ScalarFunction GetFunction(); -}; - -struct MapEntriesFun { - static constexpr const char *Name = "map_entries"; - static constexpr const char *Parameters = "map"; - static constexpr const char *Description = "Returns the map entries as a list of keys/values"; - static constexpr const char *Example = "map_entries(map(['key'], ['val']))"; - - static ScalarFunction GetFunction(); -}; - -struct MapExtractFun { - static constexpr const char *Name = "map_extract"; - static constexpr const char *Parameters = "map,key"; - static constexpr const char *Description = "Returns a list containing the value for a given key or an empty list if the key is not contained in the map. 
The type of the key provided in the second parameter must match the type of the map’s keys else an error is returned"; - static constexpr const char *Example = "map_extract(map(['key'], ['val']), 'key')"; - - static ScalarFunction GetFunction(); -}; - -struct ElementAtFun { - using ALIAS = MapExtractFun; - - static constexpr const char *Name = "element_at"; -}; - -struct MapFromEntriesFun { - static constexpr const char *Name = "map_from_entries"; - static constexpr const char *Parameters = "map"; - static constexpr const char *Description = "Returns a map created from the entries of the array"; - static constexpr const char *Example = "map_from_entries([{k: 5, v: 'val1'}, {k: 3, v: 'val2'}]);"; - - static ScalarFunction GetFunction(); -}; - -struct MapConcatFun { - static constexpr const char *Name = "map_concat"; - static constexpr const char *Parameters = "any,..."; - static constexpr const char *Description = "Returns a map created from merging the input maps, on key collision the value is taken from the last map with that key"; - static constexpr const char *Example = "map_concat(map([1,2], ['a', 'b']), map([2,3], ['c', 'd']));"; - - static ScalarFunction GetFunction(); -}; - -struct MapKeysFun { - static constexpr const char *Name = "map_keys"; - static constexpr const char *Parameters = "map"; - static constexpr const char *Description = "Returns the keys of a map as a list"; - static constexpr const char *Example = "map_keys(map(['key'], ['val']))"; - - static ScalarFunction GetFunction(); -}; - -struct MapValuesFun { - static constexpr const char *Name = "map_values"; - static constexpr const char *Parameters = "map"; - static constexpr const char *Description = "Returns the values of a map as a list"; - static constexpr const char *Example = "map_values(map(['key'], ['val']))"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp deleted file mode 100644 index 7b8e2befd..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/math_functions.hpp +++ /dev/null @@ -1,453 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/math_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct AbsOperatorFun { - static constexpr const char *Name = "@"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Absolute value"; - static constexpr const char *Example = "abs(-17.4)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct AbsFun { - using ALIAS = AbsOperatorFun; - - static constexpr const char *Name = "abs"; -}; - -struct PowOperatorFun { - static constexpr const char *Name = "**"; - static constexpr const char *Parameters = "x,y"; - static constexpr const char *Description = "Computes x to the power of y"; - static constexpr const char *Example = "pow(2, 3)"; - - static ScalarFunction GetFunction(); -}; - -struct PowFun { - using ALIAS = PowOperatorFun; - - static constexpr const 
char *Name = "pow"; -}; - -struct PowerFun { - using ALIAS = PowOperatorFun; - - static constexpr const char *Name = "power"; -}; - -struct PowOperatorFunAlias { - using ALIAS = PowOperatorFun; - - static constexpr const char *Name = "^"; -}; - -struct FactorialOperatorFun { - static constexpr const char *Name = "!__postfix"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Factorial of x. Computes the product of the current integer and all integers below it"; - static constexpr const char *Example = "4!"; - - static ScalarFunction GetFunction(); -}; - -struct FactorialFun { - using ALIAS = FactorialOperatorFun; - - static constexpr const char *Name = "factorial"; -}; - -struct AcosFun { - static constexpr const char *Name = "acos"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the arccosine of x"; - static constexpr const char *Example = "acos(0.5)"; - - static ScalarFunction GetFunction(); -}; - -struct AsinFun { - static constexpr const char *Name = "asin"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the arcsine of x"; - static constexpr const char *Example = "asin(0.5)"; - - static ScalarFunction GetFunction(); -}; - -struct AtanFun { - static constexpr const char *Name = "atan"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the arctangent of x"; - static constexpr const char *Example = "atan(0.5)"; - - static ScalarFunction GetFunction(); -}; - -struct Atan2Fun { - static constexpr const char *Name = "atan2"; - static constexpr const char *Parameters = "y,x"; - static constexpr const char *Description = "Computes the arctangent (y, x)"; - static constexpr const char *Example = "atan2(1.0, 0.0)"; - - static ScalarFunction GetFunction(); -}; - -struct BitCountFun { - static constexpr const char *Name = "bit_count"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the number of bits that are set"; - static constexpr const char *Example = "bit_count(31)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct CbrtFun { - static constexpr const char *Name = "cbrt"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the cube root of x"; - static constexpr const char *Example = "cbrt(8)"; - - static ScalarFunction GetFunction(); -}; - -struct CeilFun { - static constexpr const char *Name = "ceil"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Rounds the number up"; - static constexpr const char *Example = "ceil(17.4)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct CeilingFun { - using ALIAS = CeilFun; - - static constexpr const char *Name = "ceiling"; -}; - -struct CosFun { - static constexpr const char *Name = "cos"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the cos of x"; - static constexpr const char *Example = "cos(90)"; - - static ScalarFunction GetFunction(); -}; - -struct CotFun { - static constexpr const char *Name = "cot"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the cotangent of x"; - static constexpr const char *Example = "cot(0.5)"; - - static ScalarFunction GetFunction(); -}; - -struct DegreesFun { - static constexpr const char *Name = "degrees"; - static constexpr 
const char *Parameters = "x"; - static constexpr const char *Description = "Converts radians to degrees"; - static constexpr const char *Example = "degrees(pi())"; - - static ScalarFunction GetFunction(); -}; - -struct EvenFun { - static constexpr const char *Name = "even"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Rounds x to next even number by rounding away from zero"; - static constexpr const char *Example = "even(2.9)"; - - static ScalarFunction GetFunction(); -}; - -struct ExpFun { - static constexpr const char *Name = "exp"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes e to the power of x"; - static constexpr const char *Example = "exp(1)"; - - static ScalarFunction GetFunction(); -}; - -struct FloorFun { - static constexpr const char *Name = "floor"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Rounds the number down"; - static constexpr const char *Example = "floor(17.4)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct IsFiniteFun { - static constexpr const char *Name = "isfinite"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns true if the floating point value is finite, false otherwise"; - static constexpr const char *Example = "isfinite(5.5)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct IsInfiniteFun { - static constexpr const char *Name = "isinf"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns true if the floating point value is infinite, false otherwise"; - static constexpr const char *Example = "isinf('Infinity'::float)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct IsNanFun { - static constexpr const char *Name = "isnan"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns true if the floating point value is not a number, false otherwise"; - static constexpr const char *Example = "isnan('NaN'::FLOAT)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct GammaFun { - static constexpr const char *Name = "gamma"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Interpolation of (x-1) factorial (so decimal inputs are allowed)"; - static constexpr const char *Example = "gamma(5.5)"; - - static ScalarFunction GetFunction(); -}; - -struct GreatestCommonDivisorFun { - static constexpr const char *Name = "greatest_common_divisor"; - static constexpr const char *Parameters = "x,y"; - static constexpr const char *Description = "Computes the greatest common divisor of x and y"; - static constexpr const char *Example = "greatest_common_divisor(42, 57)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct GcdFun { - using ALIAS = GreatestCommonDivisorFun; - - static constexpr const char *Name = "gcd"; -}; - -struct LeastCommonMultipleFun { - static constexpr const char *Name = "least_common_multiple"; - static constexpr const char *Parameters = "x,y"; - static constexpr const char *Description = "Computes the least common multiple of x and y"; - static constexpr const char *Example = "least_common_multiple(42, 57)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct LcmFun { - using ALIAS = LeastCommonMultipleFun; - - static constexpr const char *Name = "lcm"; -}; - -struct LogGammaFun { - static constexpr const char *Name = "lgamma"; - static constexpr const 
char *Parameters = "x"; - static constexpr const char *Description = "Computes the log of the gamma function"; - static constexpr const char *Example = "lgamma(2)"; - - static ScalarFunction GetFunction(); -}; - -struct LnFun { - static constexpr const char *Name = "ln"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the natural logarithm of x"; - static constexpr const char *Example = "ln(2)"; - - static ScalarFunction GetFunction(); -}; - -struct Log2Fun { - static constexpr const char *Name = "log2"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the 2-log of x"; - static constexpr const char *Example = "log2(8)"; - - static ScalarFunction GetFunction(); -}; - -struct Log10Fun { - static constexpr const char *Name = "log10"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the 10-log of x"; - static constexpr const char *Example = "log10(1000)"; - - static ScalarFunction GetFunction(); -}; - -struct LogFun { - static constexpr const char *Name = "log"; - static constexpr const char *Parameters = "b, x"; - static constexpr const char *Description = "Computes the logarithm of x to base b. b may be omitted, in which case it defaults to 10"; - static constexpr const char *Example = "log(2, 64)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct NextAfterFun { - static constexpr const char *Name = "nextafter"; - static constexpr const char *Parameters = "x, y"; - static constexpr const char *Description = "Returns the next floating point value after x in the direction of y"; - static constexpr const char *Example = "nextafter(1::float, 2::float)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct PiFun { - static constexpr const char *Name = "pi"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns the value of pi"; - static constexpr const char *Example = "pi()"; - - static ScalarFunction GetFunction(); -}; - -struct RadiansFun { - static constexpr const char *Name = "radians"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Converts degrees to radians"; - static constexpr const char *Example = "radians(90)"; - - static ScalarFunction GetFunction(); -}; - -struct RoundFun { - static constexpr const char *Name = "round"; - static constexpr const char *Parameters = "x,precision"; - static constexpr const char *Description = "Rounds x to precision decimal places"; - static constexpr const char *Example = "round(42.4332, 2)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct SignFun { - static constexpr const char *Name = "sign"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the sign of x as -1, 0 or 1"; - static constexpr const char *Example = "sign(-349)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct SignBitFun { - static constexpr const char *Name = "signbit"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns whether the signbit is set or not"; - static constexpr const char *Example = "signbit(-0.0)"; - - static ScalarFunctionSet GetFunctions(); -}; - 
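Since the corrected log entry above says the base may be omitted, a quick sketch contrasting the one- and two-argument forms through the C++ API (assuming an in-memory database; illustrative only):

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// log(1000) uses the default base 10; log(2, 64) uses base 2.
	auto result = con.Query("SELECT log(1000) AS base10, log(2, 64) AS base2, round(42.4332, 2) AS rounded");
	result->Print();
	return 0;
}
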
-struct SinFun { - static constexpr const char *Name = "sin"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the sin of x"; - static constexpr const char *Example = "sin(90)"; - - static ScalarFunction GetFunction(); -}; - -struct SqrtFun { - static constexpr const char *Name = "sqrt"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Returns the square root of x"; - static constexpr const char *Example = "sqrt(4)"; - - static ScalarFunction GetFunction(); -}; - -struct TanFun { - static constexpr const char *Name = "tan"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the tan of x"; - static constexpr const char *Example = "tan(90)"; - - static ScalarFunction GetFunction(); -}; - -struct TruncFun { - static constexpr const char *Name = "trunc"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Truncates the number"; - static constexpr const char *Example = "trunc(17.4)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct CoshFun { - static constexpr const char *Name = "cosh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the hyperbolic cos of x"; - static constexpr const char *Example = "cosh(1)"; - - static ScalarFunction GetFunction(); -}; - -struct SinhFun { - static constexpr const char *Name = "sinh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the hyperbolic sin of x"; - static constexpr const char *Example = "sinh(1)"; - - static ScalarFunction GetFunction(); -}; - -struct TanhFun { - static constexpr const char *Name = "tanh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the hyperbolic tan of x"; - static constexpr const char *Example = "tanh(1)"; - - static ScalarFunction GetFunction(); -}; - -struct AcoshFun { - static constexpr const char *Name = "acosh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the inverse hyperbolic cos of x"; - static constexpr const char *Example = "acosh(2.3)"; - - static ScalarFunction GetFunction(); -}; - -struct AsinhFun { - static constexpr const char *Name = "asinh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the inverse hyperbolic sin of x"; - static constexpr const char *Example = "asinh(0.5)"; - - static ScalarFunction GetFunction(); -}; - -struct AtanhFun { - static constexpr const char *Name = "atanh"; - static constexpr const char *Parameters = "x"; - static constexpr const char *Description = "Computes the inverse hyperbolic tan of x"; - static constexpr const char *Example = "atanh(0.5)"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp deleted file mode 100644 index 3bbfc565a..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/operators_functions.hpp +++ /dev/null @@ -1,72 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/operators_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - 
-#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct BitwiseAndFun { - static constexpr const char *Name = "&"; - static constexpr const char *Parameters = "left,right"; - static constexpr const char *Description = "Bitwise AND"; - static constexpr const char *Example = "91 & 15"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct BitwiseOrFun { - static constexpr const char *Name = "|"; - static constexpr const char *Parameters = "left,right"; - static constexpr const char *Description = "Bitwise OR"; - static constexpr const char *Example = "32 | 3"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct BitwiseNotFun { - static constexpr const char *Name = "~"; - static constexpr const char *Parameters = "input"; - static constexpr const char *Description = "Bitwise NOT"; - static constexpr const char *Example = "~15"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct LeftShiftFun { - static constexpr const char *Name = "<<"; - static constexpr const char *Parameters = "input"; - static constexpr const char *Description = "Bitwise shift left"; - static constexpr const char *Example = "1 << 4"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct RightShiftFun { - static constexpr const char *Name = ">>"; - static constexpr const char *Parameters = "input"; - static constexpr const char *Description = "Bitwise shift right"; - static constexpr const char *Example = "8 >> 2"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct BitwiseXorFun { - static constexpr const char *Name = "xor"; - static constexpr const char *Parameters = "left,right"; - static constexpr const char *Description = "Bitwise XOR"; - static constexpr const char *Example = "xor(17, 5)"; - - static ScalarFunctionSet GetFunctions(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp deleted file mode 100644 index 1002f0e44..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/random_functions.hpp +++ /dev/null @@ -1,51 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/random_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct RandomFun { - static constexpr const char *Name = "random"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns a random number between 0 and 1"; - static constexpr const char *Example = "random()"; - - static ScalarFunction GetFunction(); -}; - -struct SetseedFun { - static constexpr const char *Name = "setseed"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Sets the seed to be used for the random function"; - static constexpr const char *Example = "setseed(0.42)"; - - static ScalarFunction GetFunction(); -}; - -struct UUIDFun { - static constexpr const char *Name = "uuid"; - static constexpr const char *Parameters = ""; - static constexpr const char *Description = "Returns a random UUID 
similar to this: eeccb8c5-9943-b2bb-bb5e-222f4e14b687"; - static constexpr const char *Example = "uuid()"; - - static ScalarFunction GetFunction(); -}; - -struct GenRandomUuidFun { - using ALIAS = UUIDFun; - - static constexpr const char *Name = "gen_random_uuid"; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp deleted file mode 100644 index 17e5614e0..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/secret_functions.hpp +++ /dev/null @@ -1,27 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/core_functions/scalar/secret_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct WhichSecretFun { - static constexpr const char *Name = "which_secret"; - static constexpr const char *Parameters = "path,type"; - static constexpr const char *Description = "Print out the name of the secret that will be used for reading a path"; - static constexpr const char *Example = "which_secret('s3://some/authenticated/path.csv', 's3')"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp deleted file mode 100644 index 6a6db36da..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/string_functions.hpp +++ /dev/null @@ -1,444 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/string_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct StartsWithOperatorFun { - static constexpr const char *Name = "^@"; - static constexpr const char *Parameters = "string,search_string"; - static constexpr const char *Description = "Returns true if string begins with search_string"; - static constexpr const char *Example = "starts_with('abc','a')"; - - static ScalarFunction GetFunction(); -}; - -struct StartsWithFun { - using ALIAS = StartsWithOperatorFun; - - static constexpr const char *Name = "starts_with"; -}; - -struct ASCIIFun { - static constexpr const char *Name = "ascii"; - static constexpr const char *Parameters = "string"; - static constexpr const char *Description = "Returns an integer that represents the Unicode code point of the first character of the string"; - static constexpr const char *Example = "ascii('Ω')"; - - static ScalarFunction GetFunction(); -}; - -struct BarFun { - static constexpr const char *Name = "bar"; - static constexpr const char *Parameters = "x,min,max,width"; - static constexpr const char 
*Description = "Draws a band whose width is proportional to (x - min) and equal to width characters when x = max. width defaults to 80"; - static constexpr const char *Example = "bar(5, 0, 20, 10)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct BinFun { - static constexpr const char *Name = "bin"; - static constexpr const char *Parameters = "value"; - static constexpr const char *Description = "Converts the value to binary representation"; - static constexpr const char *Example = "bin(42)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ToBinaryFun { - using ALIAS = BinFun; - - static constexpr const char *Name = "to_binary"; -}; - -struct ChrFun { - static constexpr const char *Name = "chr"; - static constexpr const char *Parameters = "code_point"; - static constexpr const char *Description = "Returns a character which is corresponding the ASCII code value or Unicode code point"; - static constexpr const char *Example = "chr(65)"; - - static ScalarFunction GetFunction(); -}; - -struct DamerauLevenshteinFun { - static constexpr const char *Name = "damerau_levenshtein"; - static constexpr const char *Parameters = "str1,str2"; - static constexpr const char *Description = "Extension of Levenshtein distance to also include transposition of adjacent characters as an allowed edit operation. In other words, the minimum number of edit operations (insertions, deletions, substitutions or transpositions) required to change one string to another. Different case is considered different"; - static constexpr const char *Example = "damerau_levenshtein('hello', 'world')"; - - static ScalarFunction GetFunction(); -}; - -struct FormatFun { - static constexpr const char *Name = "format"; - static constexpr const char *Parameters = "format,parameters..."; - static constexpr const char *Description = "Formats a string using fmt syntax"; - static constexpr const char *Example = "format('Benchmark \"{}\" took {} seconds', 'CSV', 42)"; - - static ScalarFunction GetFunction(); -}; - -struct FormatBytesFun { - static constexpr const char *Name = "format_bytes"; - static constexpr const char *Parameters = "bytes"; - static constexpr const char *Description = "Converts bytes to a human-readable presentation (e.g. 16000 -> 15.6 KiB)"; - static constexpr const char *Example = "format_bytes(1000 * 16)"; - - static ScalarFunction GetFunction(); -}; - -struct FormatreadablesizeFun { - using ALIAS = FormatBytesFun; - - static constexpr const char *Name = "formatReadableSize"; -}; - -struct FormatreadabledecimalsizeFun { - static constexpr const char *Name = "formatReadableDecimalSize"; - static constexpr const char *Parameters = "bytes"; - static constexpr const char *Description = "Converts bytes to a human-readable presentation (e.g. 16000 -> 16.0 KB)"; - static constexpr const char *Example = "format_bytes(1000 * 16)"; - - static ScalarFunction GetFunction(); -}; - -struct HammingFun { - static constexpr const char *Name = "hamming"; - static constexpr const char *Parameters = "str1,str2"; - static constexpr const char *Description = "The number of positions with different characters for 2 strings of equal length. 
Different case is considered different"; - static constexpr const char *Example = "hamming('duck','luck')"; - - static ScalarFunction GetFunction(); -}; - -struct MismatchesFun { - using ALIAS = HammingFun; - - static constexpr const char *Name = "mismatches"; -}; - -struct HexFun { - static constexpr const char *Name = "hex"; - static constexpr const char *Parameters = "value"; - static constexpr const char *Description = "Converts the value to hexadecimal representation"; - static constexpr const char *Example = "hex(42)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ToHexFun { - using ALIAS = HexFun; - - static constexpr const char *Name = "to_hex"; -}; - -struct InstrFun { - static constexpr const char *Name = "instr"; - static constexpr const char *Parameters = "haystack,needle"; - static constexpr const char *Description = "Returns location of first occurrence of needle in haystack, counting from 1. Returns 0 if no match found"; - static constexpr const char *Example = "instr('test test','es')"; - - static ScalarFunction GetFunction(); -}; - -struct StrposFun { - using ALIAS = InstrFun; - - static constexpr const char *Name = "strpos"; -}; - -struct PositionFun { - using ALIAS = InstrFun; - - static constexpr const char *Name = "position"; -}; - -struct JaccardFun { - static constexpr const char *Name = "jaccard"; - static constexpr const char *Parameters = "str1,str2"; - static constexpr const char *Description = "The Jaccard similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; - static constexpr const char *Example = "jaccard('duck','luck')"; - - static ScalarFunction GetFunction(); -}; - -struct JaroSimilarityFun { - static constexpr const char *Name = "jaro_similarity"; - static constexpr const char *Parameters = "str1,str2,score_cutoff"; - static constexpr const char *Description = "The Jaro similarity between two strings. Different case is considered different. Returns a number between 0 and 1"; - static constexpr const char *Example = "jaro_similarity('duck', 'duckdb', 0.5)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct JaroWinklerSimilarityFun { - static constexpr const char *Name = "jaro_winkler_similarity"; - static constexpr const char *Parameters = "str1,str2,score_cutoff"; - static constexpr const char *Description = "The Jaro-Winkler similarity between two strings. Different case is considered different. 
Returns a number between 0 and 1"; - static constexpr const char *Example = "jaro_winkler_similarity('duck', 'duckdb', 0.5)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct LeftFun { - static constexpr const char *Name = "left"; - static constexpr const char *Parameters = "string,count"; - static constexpr const char *Description = "Extract the left-most count characters"; - static constexpr const char *Example = "left('Hello🦆', 2)"; - - static ScalarFunction GetFunction(); -}; - -struct LeftGraphemeFun { - static constexpr const char *Name = "left_grapheme"; - static constexpr const char *Parameters = "string,count"; - static constexpr const char *Description = "Extract the left-most count grapheme clusters"; - static constexpr const char *Example = "left_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; - - static ScalarFunction GetFunction(); -}; - -struct LevenshteinFun { - static constexpr const char *Name = "levenshtein"; - static constexpr const char *Parameters = "str1,str2"; - static constexpr const char *Description = "The minimum number of single-character edits (insertions, deletions or substitutions) required to change one string to the other. Different case is considered different"; - static constexpr const char *Example = "levenshtein('duck','db')"; - - static ScalarFunction GetFunction(); -}; - -struct Editdist3Fun { - using ALIAS = LevenshteinFun; - - static constexpr const char *Name = "editdist3"; -}; - -struct LpadFun { - static constexpr const char *Name = "lpad"; - static constexpr const char *Parameters = "string,count,character"; - static constexpr const char *Description = "Pads the string with the character from the left until it has count characters"; - static constexpr const char *Example = "lpad('hello', 10, '>')"; - - static ScalarFunction GetFunction(); -}; - -struct LtrimFun { - static constexpr const char *Name = "ltrim"; - static constexpr const char *Parameters = "string,characters"; - static constexpr const char *Description = "Removes any occurrences of any of the characters from the left side of the string"; - static constexpr const char *Example = "ltrim('>>>>test<<', '><')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ParseDirnameFun { - static constexpr const char *Name = "parse_dirname"; - static constexpr const char *Parameters = "string,separator"; - static constexpr const char *Description = "Returns the top-level directory name. separator options: system, both_slash (default), forward_slash, backslash"; - static constexpr const char *Example = "parse_dirname('path/to/file.csv', 'system')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ParseDirpathFun { - static constexpr const char *Name = "parse_dirpath"; - static constexpr const char *Parameters = "string,separator"; - static constexpr const char *Description = "Returns the head of the path similarly to Python's os.path.dirname. separator options: system, both_slash (default), forward_slash, backslash"; - static constexpr const char *Example = "parse_dirpath('path/to/file.csv', 'system')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ParseFilenameFun { - static constexpr const char *Name = "parse_filename"; - static constexpr const char *Parameters = "string,trim_extension,separator"; - static constexpr const char *Description = "Returns the last component of the path similarly to Python's os.path.basename. If trim_extension is true, the file extension will be removed (it defaults to false). 
separator options: system, both_slash (default), forward_slash, backslash"; - static constexpr const char *Example = "parse_filename('path/to/file.csv', true, 'forward_slash')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ParsePathFun { - static constexpr const char *Name = "parse_path"; - static constexpr const char *Parameters = "string,separator"; - static constexpr const char *Description = "Returns a list of the components (directories and filename) in the path similarly to Python's pathlib.PurePath::parts. separator options: system, both_slash (default), forward_slash, backslash"; - static constexpr const char *Example = "parse_path('path/to/file.csv', 'system')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct PrintfFun { - static constexpr const char *Name = "printf"; - static constexpr const char *Parameters = "format,parameters..."; - static constexpr const char *Description = "Formats a string using printf syntax"; - static constexpr const char *Example = "printf('Benchmark \"%s\" took %d seconds', 'CSV', 42)"; - - static ScalarFunction GetFunction(); -}; - -struct RepeatFun { - static constexpr const char *Name = "repeat"; - static constexpr const char *Parameters = "string,count"; - static constexpr const char *Description = "Repeats the string count number of times"; - static constexpr const char *Example = "repeat('A', 5)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct ReplaceFun { - static constexpr const char *Name = "replace"; - static constexpr const char *Parameters = "string,source,target"; - static constexpr const char *Description = "Replaces any occurrences of the source with target in string"; - static constexpr const char *Example = "replace('hello', 'l', '-')"; - - static ScalarFunction GetFunction(); -}; - -struct ReverseFun { - static constexpr const char *Name = "reverse"; - static constexpr const char *Parameters = "string"; - static constexpr const char *Description = "Reverses the string"; - static constexpr const char *Example = "reverse('hello')"; - - static ScalarFunction GetFunction(); -}; - -struct RightFun { - static constexpr const char *Name = "right"; - static constexpr const char *Parameters = "string,count"; - static constexpr const char *Description = "Extract the right-most count characters"; - static constexpr const char *Example = "right('Hello🦆', 3)"; - - static ScalarFunction GetFunction(); -}; - -struct RightGraphemeFun { - static constexpr const char *Name = "right_grapheme"; - static constexpr const char *Parameters = "string,count"; - static constexpr const char *Description = "Extract the right-most count grapheme clusters"; - static constexpr const char *Example = "right_grapheme('🤦🏼‍♂️🤦🏽‍♀️', 1)"; - - static ScalarFunction GetFunction(); -}; - -struct RpadFun { - static constexpr const char *Name = "rpad"; - static constexpr const char *Parameters = "string,count,character"; - static constexpr const char *Description = "Pads the string with the character from the right until it has count characters"; - static constexpr const char *Example = "rpad('hello', 10, '<')"; - - static ScalarFunction GetFunction(); -}; - -struct RtrimFun { - static constexpr const char *Name = "rtrim"; - static constexpr const char *Parameters = "string,characters"; - static constexpr const char *Description = "Removes any occurrences of any of the characters from the right side of the string"; - static constexpr const char *Example = "rtrim('>>>>test<<', '><')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct 
TranslateFun { - static constexpr const char *Name = "translate"; - static constexpr const char *Parameters = "string,from,to"; - static constexpr const char *Description = "Replaces each character in string that matches a character in the from set with the corresponding character in the to set. If from is longer than to, occurrences of the extra characters in from are deleted"; - static constexpr const char *Example = "translate('12345', '143', 'ax')"; - - static ScalarFunction GetFunction(); -}; - -struct TrimFun { - static constexpr const char *Name = "trim"; - static constexpr const char *Parameters = "string::VARCHAR\1string::VARCHAR,characters::VARCHAR"; - static constexpr const char *Description = "Removes any spaces from either side of the string.\1Removes any occurrences of any of the characters from either side of the string"; - static constexpr const char *Example = "trim(' test ')\1trim('>>>>test<<', '><')"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct UnbinFun { - static constexpr const char *Name = "unbin"; - static constexpr const char *Parameters = "value"; - static constexpr const char *Description = "Converts a value from binary representation to a blob"; - static constexpr const char *Example = "unbin('0110')"; - - static ScalarFunction GetFunction(); -}; - -struct FromBinaryFun { - using ALIAS = UnbinFun; - - static constexpr const char *Name = "from_binary"; -}; - -struct UnhexFun { - static constexpr const char *Name = "unhex"; - static constexpr const char *Parameters = "value"; - static constexpr const char *Description = "Converts a value from hexadecimal representation to a blob"; - static constexpr const char *Example = "unhex('2A')"; - - static ScalarFunction GetFunction(); -}; - -struct FromHexFun { - using ALIAS = UnhexFun; - - static constexpr const char *Name = "from_hex"; -}; - -struct UnicodeFun { - static constexpr const char *Name = "unicode"; - static constexpr const char *Parameters = "str"; - static constexpr const char *Description = "Returns the unicode codepoint of the first character of the string"; - static constexpr const char *Example = "unicode('ü')"; - - static ScalarFunction GetFunction(); -}; - -struct OrdFun { - using ALIAS = UnicodeFun; - - static constexpr const char *Name = "ord"; -}; - -struct ToBaseFun { - static constexpr const char *Name = "to_base"; - static constexpr const char *Parameters = "number,radix,min_length"; - static constexpr const char *Description = "Converts a value to a string in the given base radix, optionally padding with leading zeros to the minimum length"; - static constexpr const char *Example = "to_base(42, 16)"; - - static ScalarFunctionSet GetFunctions(); -}; - -struct UrlEncodeFun { - static constexpr const char *Name = "url_encode"; - static constexpr const char *Parameters = "input"; - static constexpr const char *Description = "Escapes the input string by encoding it so that it can be included in a URL query parameter."; - static constexpr const char *Example = "url_encode('this string has/ special+ characters>')"; - - static ScalarFunction GetFunction(); -}; - -struct UrlDecodeFun { - static constexpr const char *Name = "url_decode"; - static constexpr const char *Parameters = "input"; - static constexpr const char *Description = "Unescapes the URL encoded input."; - static constexpr const char *Example = "url_decode('this%20string%20is%2BFencoded')"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git 
a/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp deleted file mode 100644 index f921bf434..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/struct_functions.hpp +++ /dev/null @@ -1,27 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/struct_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct StructInsertFun { - static constexpr const char *Name = "struct_insert"; - static constexpr const char *Parameters = "struct,any"; - static constexpr const char *Description = "Adds field(s)/value(s) to an existing STRUCT with the argument values. The entry name(s) will be the bound variable name(s)"; - static constexpr const char *Example = "struct_insert({'a': 1}, b := 2)"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp b/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp deleted file mode 100644 index 766c12e8f..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions/scalar/union_functions.hpp +++ /dev/null @@ -1,45 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions/scalar/union_functions.hpp -// -// -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_functions.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct UnionExtractFun { - static constexpr const char *Name = "union_extract"; - static constexpr const char *Parameters = "union,tag"; - static constexpr const char *Description = "Extract the value with the named tags from the union. NULL if the tag is not currently selected"; - static constexpr const char *Example = "union_extract(s, 'k')"; - - static ScalarFunction GetFunction(); -}; - -struct UnionTagFun { - static constexpr const char *Name = "union_tag"; - static constexpr const char *Parameters = "union"; - static constexpr const char *Description = "Retrieve the currently selected tag of the union as an ENUM"; - static constexpr const char *Example = "union_tag(union_value(k := 'foo'))"; - - static ScalarFunction GetFunction(); -}; - -struct UnionValueFun { - static constexpr const char *Name = "union_value"; - static constexpr const char *Parameters = "tag"; - static constexpr const char *Description = "Create a single member UNION containing the argument value. 
The tag of the value will be the bound variable name"; - static constexpr const char *Example = "union_value(k := 'hello')"; - - static ScalarFunction GetFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/include/core_functions_extension.hpp b/src/duckdb/extension/core_functions/include/core_functions_extension.hpp deleted file mode 100644 index e877860f0..000000000 --- a/src/duckdb/extension/core_functions/include/core_functions_extension.hpp +++ /dev/null @@ -1,22 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// core_functions_extension.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" - -namespace duckdb { - -class CoreFunctionsExtension : public Extension { -public: - void Load(DuckDB &db) override; - std::string Name() override; - std::string Version() const override; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/lambda_functions.cpp b/src/duckdb/extension/core_functions/lambda_functions.cpp deleted file mode 100644 index b5549914a..000000000 --- a/src/duckdb/extension/core_functions/lambda_functions.cpp +++ /dev/null @@ -1,414 +0,0 @@ -#include "duckdb/function/lambda_functions.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Helper functions -//===--------------------------------------------------------------------===// - -//! LambdaExecuteInfo holds information for executing the lambda expression on an input chunk and -//! a resulting lambda chunk. -struct LambdaExecuteInfo { - LambdaExecuteInfo(ClientContext &context, const Expression &lambda_expr, const DataChunk &args, - const bool has_index, const Vector &child_vector) - : has_index(has_index) { - - expr_executor = make_uniq(context, lambda_expr); - - // get the input types for the input chunk - vector input_types; - if (has_index) { - input_types.push_back(LogicalType::BIGINT); - } - input_types.push_back(child_vector.GetType()); - for (idx_t i = 1; i < args.ColumnCount(); i++) { - input_types.push_back(args.data[i].GetType()); - } - - // get the result types - vector result_types {lambda_expr.return_type}; - - // initialize the data chunks - input_chunk.InitializeEmpty(input_types); - lambda_chunk.Initialize(Allocator::DefaultAllocator(), result_types); - }; - - //! The expression executor that executes the lambda expression - unique_ptr expr_executor; - //! The input chunk on which we execute the lambda expression - DataChunk input_chunk; - //! The chunk holding the result of executing the lambda expression - DataChunk lambda_chunk; - //! True, if this lambda expression expects an index vector in the input chunk - bool has_index; -}; - -//! A helper struct with information that is specific to the list_filter function -struct ListFilterInfo { - //! The new list lengths after filtering out elements - vector entry_lengths; - //! The length of the current list - idx_t length = 0; - //! The offset of the current list - idx_t offset = 0; - //! The current row index - idx_t row_idx = 0; - //! The length of the source list - idx_t src_length = 0; -}; - -//! 
ListTransformFunctor contains list_transform specific functionality -struct ListTransformFunctor { - static void ReserveNewLengths(vector &, const idx_t) { - // NOP - } - static void PushEmptyList(vector &) { - // NOP - } - //! Sets the list entries of the result vector - static void SetResultEntry(list_entry_t *result_entries, idx_t &offset, const list_entry_t &entry, - const idx_t row_idx, vector &) { - result_entries[row_idx].offset = offset; - result_entries[row_idx].length = entry.length; - offset += entry.length; - } - //! Appends the lambda vector to the result's child vector - static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *, - ListFilterInfo &, LambdaExecuteInfo &) { - ListVector::Append(result, lambda_vector, elem_cnt, 0); - } -}; - -//! ListFilterFunctor contains list_filter specific functionality -struct ListFilterFunctor { - //! Initializes the entry_lengths vector - static void ReserveNewLengths(vector &entry_lengths, const idx_t row_count) { - entry_lengths.reserve(row_count); - } - //! Pushes an empty list to the entry_lengths vector - static void PushEmptyList(vector &entry_lengths) { - entry_lengths.emplace_back(0); - } - //! Pushes the length of the original list to the entry_lengths vector - static void SetResultEntry(list_entry_t *, idx_t &, const list_entry_t &entry, const idx_t, - vector &entry_lengths) { - entry_lengths.push_back(entry.length); - } - //! Uses the lambda vector to filter the incoming list and to append the filtered list to the result vector - static void AppendResult(Vector &result, Vector &lambda_vector, const idx_t elem_cnt, list_entry_t *result_entries, - ListFilterInfo &info, LambdaExecuteInfo &execute_info) { - - idx_t count = 0; - SelectionVector sel(elem_cnt); - UnifiedVectorFormat lambda_data; - lambda_vector.ToUnifiedFormat(elem_cnt, lambda_data); - - auto lambda_values = UnifiedVectorFormat::GetData(lambda_data); - auto &lambda_validity = lambda_data.validity; - - // compute the new lengths and offsets, and create a selection vector - for (idx_t i = 0; i < elem_cnt; i++) { - auto entry_idx = lambda_data.sel->get_index(i); - - // set length and offset of empty lists - while (info.row_idx < info.entry_lengths.size() && !info.entry_lengths[info.row_idx]) { - result_entries[info.row_idx].offset = info.offset; - result_entries[info.row_idx].length = 0; - info.row_idx++; - } - - // found a true value - if (lambda_validity.RowIsValid(entry_idx) && lambda_values[entry_idx]) { - sel.set_index(count++, i); - info.length++; - } - - info.src_length++; - - // we traversed the entire source list - if (info.entry_lengths[info.row_idx] == info.src_length) { - // set the offset and length of the result entry - result_entries[info.row_idx].offset = info.offset; - result_entries[info.row_idx].length = info.length; - - // reset all other fields - info.offset += info.length; - info.row_idx++; - info.length = 0; - info.src_length = 0; - } - } - - // set length and offset of all remaining empty lists - while (info.row_idx < info.entry_lengths.size() && !info.entry_lengths[info.row_idx]) { - result_entries[info.row_idx].offset = info.offset; - result_entries[info.row_idx].length = 0; - info.row_idx++; - } - - // slice the input chunk's corresponding vector to get the new lists - // and append them to the result - idx_t source_list_idx = execute_info.has_index ? 
1 : 0; - Vector result_lists(execute_info.input_chunk.data[source_list_idx], sel, count); - ListVector::Append(result, result_lists, count, 0); - } -}; - -vector LambdaFunctions::GetColumnInfo(DataChunk &args, const idx_t row_count) { - vector data; - // skip the input list and then insert all remaining input vectors - for (idx_t i = 1; i < args.ColumnCount(); i++) { - data.emplace_back(args.data[i]); - args.data[i].ToUnifiedFormat(row_count, data.back().format); - } - return data; -} - -vector> -LambdaFunctions::GetMutableColumnInfo(vector &data) { - vector> inconstant_info; - for (auto &entry : data) { - if (entry.vector.get().GetVectorType() != VectorType::CONSTANT_VECTOR) { - inconstant_info.push_back(entry); - } - } - return inconstant_info; -} - -void ExecuteExpression(const idx_t elem_cnt, const LambdaFunctions::ColumnInfo &column_info, - const vector &column_infos, const Vector &index_vector, - LambdaExecuteInfo &info) { - - info.input_chunk.SetCardinality(elem_cnt); - info.lambda_chunk.SetCardinality(elem_cnt); - - // slice the child vector - Vector slice(column_info.vector, column_info.sel, elem_cnt); - - // reference the child vector (and the index vector) - if (info.has_index) { - info.input_chunk.data[0].Reference(index_vector); - info.input_chunk.data[1].Reference(slice); - } else { - info.input_chunk.data[0].Reference(slice); - } - idx_t slice_offset = info.has_index ? 2 : 1; - - // (slice and) reference the other columns - vector slices; - for (idx_t i = 0; i < column_infos.size(); i++) { - - if (column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { - // only reference constant vectors - info.input_chunk.data[i + slice_offset].Reference(column_infos[i].vector); - - } else { - // slice inconstant vectors - slices.emplace_back(column_infos[i].vector, column_infos[i].sel, elem_cnt); - info.input_chunk.data[i + slice_offset].Reference(slices.back()); - } - } - - // execute the lambda expression - info.expr_executor->Execute(info.input_chunk, info.lambda_chunk); -} - -//===--------------------------------------------------------------------===// -// ListLambdaBindData -//===--------------------------------------------------------------------===// - -unique_ptr ListLambdaBindData::Copy() const { - auto lambda_expr_copy = lambda_expr ? 
lambda_expr->Copy() : nullptr; - return make_uniq(return_type, std::move(lambda_expr_copy), has_index); -} - -bool ListLambdaBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return Expression::Equals(lambda_expr, other.lambda_expr) && return_type == other.return_type && - has_index == other.has_index; -} - -void ListLambdaBindData::Serialize(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "return_type", bind_data.return_type); - serializer.WritePropertyWithDefault(101, "lambda_expr", bind_data.lambda_expr, unique_ptr()); - serializer.WriteProperty(102, "has_index", bind_data.has_index); -} - -unique_ptr ListLambdaBindData::Deserialize(Deserializer &deserializer, ScalarFunction &) { - auto return_type = deserializer.ReadProperty(100, "return_type"); - auto lambda_expr = deserializer.ReadPropertyWithExplicitDefault>(101, "lambda_expr", - unique_ptr()); - auto has_index = deserializer.ReadProperty(102, "has_index"); - return make_uniq(return_type, std::move(lambda_expr), has_index); -} - -//===--------------------------------------------------------------------===// -// LambdaFunctions -//===--------------------------------------------------------------------===// - -LogicalType LambdaFunctions::BindBinaryLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { - switch (parameter_idx) { - case 0: - return list_child_type; - case 1: - return LogicalType::BIGINT; - default: - throw BinderException("This lambda function only supports up to two lambda parameters!"); - } -} - -LogicalType LambdaFunctions::BindTernaryLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { - switch (parameter_idx) { - case 0: - return list_child_type; - case 1: - return list_child_type; - case 2: - return LogicalType::BIGINT; - default: - throw BinderException("This lambda function only supports up to three lambda parameters!"); - } -} - -template -void ExecuteLambda(DataChunk &args, ExpressionState &state, Vector &result) { - - bool result_is_null = false; - LambdaFunctions::LambdaInfo info(args, state, result, result_is_null); - if (result_is_null) { - return; - } - - auto result_entries = FlatVector::GetData(result); - auto mutable_column_infos = LambdaFunctions::GetMutableColumnInfo(info.column_infos); - - // special-handling for the child_vector - auto child_vector_size = ListVector::GetListSize(args.data[0]); - LambdaFunctions::ColumnInfo child_info(*info.child_vector); - info.child_vector->ToUnifiedFormat(child_vector_size, child_info.format); - - // get the expression executor - LambdaExecuteInfo execute_info(state.GetContext(), *info.lambda_expr, args, info.has_index, *info.child_vector); - - // get list_filter specific info - ListFilterInfo list_filter_info; - FUNCTION_FUNCTOR::ReserveNewLengths(list_filter_info.entry_lengths, info.row_count); - - // additional index vector - Vector index_vector(LogicalType::BIGINT); - - // loop over the child entries and create chunks to be executed by the expression executor - idx_t elem_cnt = 0; - idx_t offset = 0; - for (idx_t row_idx = 0; row_idx < info.row_count; row_idx++) { - - auto list_idx = info.list_column_format.sel->get_index(row_idx); - const auto &list_entry = info.list_entries[list_idx]; - - // set the result to NULL for this row - if (!info.list_column_format.validity.RowIsValid(list_idx)) { - info.result_validity->SetInvalid(row_idx); - 
FUNCTION_FUNCTOR::PushEmptyList(list_filter_info.entry_lengths); - continue; - } - - FUNCTION_FUNCTOR::SetResultEntry(result_entries, offset, list_entry, row_idx, list_filter_info.entry_lengths); - - // empty list, nothing to execute - if (list_entry.length == 0) { - continue; - } - - // iterate the elements of the current list and create the corresponding selection vectors - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - - // reached STANDARD_VECTOR_SIZE elements - if (elem_cnt == STANDARD_VECTOR_SIZE) { - - execute_info.lambda_chunk.Reset(); - ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); - auto &lambda_vector = execute_info.lambda_chunk.data[0]; - - FUNCTION_FUNCTOR::AppendResult(result, lambda_vector, elem_cnt, result_entries, list_filter_info, - execute_info); - elem_cnt = 0; - } - - // FIXME: reuse same selection vector for inconstant rows - // adjust indexes for slicing - child_info.sel.set_index(elem_cnt, list_entry.offset + child_idx); - for (auto &entry : mutable_column_infos) { - entry.get().sel.set_index(elem_cnt, row_idx); - } - - // set the index vector - if (info.has_index) { - index_vector.SetValue(elem_cnt, Value::BIGINT(NumericCast(child_idx + 1))); - } - - elem_cnt++; - } - } - - execute_info.lambda_chunk.Reset(); - ExecuteExpression(elem_cnt, child_info, info.column_infos, index_vector, execute_info); - auto &lambda_vector = execute_info.lambda_chunk.data[0]; - - FUNCTION_FUNCTOR::AppendResult(result, lambda_vector, elem_cnt, result_entries, list_filter_info, execute_info); - - if (info.is_all_constant && !info.is_volatile) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -unique_ptr LambdaFunctions::ListLambdaPrepareBind(vector> &arguments, - ClientContext &context, - ScalarFunction &bound_function) { - // NULL list parameter - if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { - bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type, nullptr); - } - // prepared statements - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - D_ASSERT(arguments[0]->return_type.id() == LogicalTypeId::LIST); - return nullptr; -} - -unique_ptr LambdaFunctions::ListLambdaBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, - const bool has_index) { - unique_ptr bind_data = ListLambdaPrepareBind(arguments, context, bound_function); - if (bind_data) { - return bind_data; - } - - // get the lambda expression and put it in the bind info - auto &bound_lambda_expr = arguments[1]->Cast(); - auto lambda_expr = std::move(bound_lambda_expr.lambda_expr); - - return make_uniq(bound_function.return_type, std::move(lambda_expr), has_index); -} - -void LambdaFunctions::ListTransformFunction(DataChunk &args, ExpressionState &state, Vector &result) { - ExecuteLambda(args, state, result); -} - -void LambdaFunctions::ListFilterFunction(DataChunk &args, ExpressionState &state, Vector &result) { - ExecuteLambda(args, state, result); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp b/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp deleted file mode 100644 index af7d0ee03..000000000 --- a/src/duckdb/extension/core_functions/scalar/array/array_functions.cpp +++ 
/dev/null @@ -1,280 +0,0 @@ -#include "core_functions/scalar/array_functions.hpp" -#include "core_functions/array_kernels.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static unique_ptr ArrayGenericBinaryBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - const auto lhs_is_param = arguments[0]->HasParameter(); - const auto rhs_is_param = arguments[1]->HasParameter(); - - if (lhs_is_param && rhs_is_param) { - throw ParameterNotResolvedException(); - } - - const auto &lhs_type = arguments[0]->return_type; - const auto &rhs_type = arguments[1]->return_type; - - bound_function.arguments[0] = lhs_is_param ? rhs_type : lhs_type; - bound_function.arguments[1] = rhs_is_param ? lhs_type : rhs_type; - - if (bound_function.arguments[0].id() != LogicalTypeId::ARRAY || - bound_function.arguments[1].id() != LogicalTypeId::ARRAY) { - throw InvalidInputException( - StringUtil::Format("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name)); - } - - const auto lhs_size = ArrayType::GetSize(bound_function.arguments[0]); - const auto rhs_size = ArrayType::GetSize(bound_function.arguments[1]); - - if (lhs_size != rhs_size) { - throw BinderException("%s: Array arguments must be of the same size", bound_function.name); - } - - const auto &lhs_element_type = ArrayType::GetChildType(bound_function.arguments[0]); - const auto &rhs_element_type = ArrayType::GetChildType(bound_function.arguments[1]); - - // Resolve common type - LogicalType common_type; - if (!LogicalType::TryGetMaxLogicalType(context, lhs_element_type, rhs_element_type, common_type)) { - throw BinderException("%s: Cannot infer common element type (left = '%s', right = '%s')", bound_function.name, - lhs_element_type.ToString(), rhs_element_type.ToString()); - } - - // Ensure it is float or double - if (common_type.id() != LogicalTypeId::FLOAT && common_type.id() != LogicalTypeId::DOUBLE) { - throw BinderException("%s: Arguments must be arrays of FLOAT or DOUBLE", bound_function.name); - } - - // The important part is just that we resolve the size of the input arrays - bound_function.arguments[0] = LogicalType::ARRAY(common_type, lhs_size); - bound_function.arguments[1] = LogicalType::ARRAY(common_type, rhs_size); - - return nullptr; -} - -//------------------------------------------------------------------------------ -// Element-wise combine functions -//------------------------------------------------------------------------------ -// Given two arrays of the same size, combine their elements into a single array -// of the same size as the input arrays. 
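[Editor's note: the "element-wise combine" kernels in this file all share one contract: a struct exposing a static templated Operation(lhs, rhs, result, size) over raw array data. CrossProductOp below is the only kernel the file actually registers; the ElementAddOp here is a hypothetical sketch of the same contract, added purely for illustration and not part of the deleted source.]

struct ElementAddOp {
	template <class TYPE>
	static void Operation(const TYPE *lhs_data, const TYPE *rhs_data, TYPE *res_data, idx_t size) {
		// illustrative only: combine the two arrays element by element;
		// ArrayFixedCombine<ElementAddOp, N> would invoke this once per row
		for (idx_t i = 0; i < size; i++) {
			res_data[i] = lhs_data[i] + rhs_data[i];
		}
	}
};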
- -struct CrossProductOp { - template - static void Operation(const TYPE *lhs_data, const TYPE *rhs_data, TYPE *res_data, idx_t size) { - D_ASSERT(size == 3); - - auto lx = lhs_data[0]; - auto ly = lhs_data[1]; - auto lz = lhs_data[2]; - - auto rx = rhs_data[0]; - auto ry = rhs_data[1]; - auto rz = rhs_data[2]; - - res_data[0] = ly * rz - lz * ry; - res_data[1] = lz * rx - lx * rz; - res_data[2] = lx * ry - ly * rx; - } -}; - -template -static void ArrayFixedCombine(DataChunk &args, ExpressionState &state, Vector &result) { - const auto &lstate = state.Cast(); - const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.name; - - const auto count = args.size(); - auto &lhs_child = ArrayVector::GetEntry(args.data[0]); - auto &rhs_child = ArrayVector::GetEntry(args.data[1]); - auto &res_child = ArrayVector::GetEntry(result); - - const auto &lhs_child_validity = FlatVector::Validity(lhs_child); - const auto &rhs_child_validity = FlatVector::Validity(rhs_child); - - UnifiedVectorFormat lhs_format; - UnifiedVectorFormat rhs_format; - - args.data[0].ToUnifiedFormat(count, lhs_format); - args.data[1].ToUnifiedFormat(count, rhs_format); - - auto lhs_data = FlatVector::GetData(lhs_child); - auto rhs_data = FlatVector::GetData(rhs_child); - auto res_data = FlatVector::GetData(res_child); - - for (idx_t i = 0; i < count; i++) { - const auto lhs_idx = lhs_format.sel->get_index(i); - const auto rhs_idx = rhs_format.sel->get_index(i); - - if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { - FlatVector::SetNull(result, i, true); - continue; - } - - const auto left_offset = lhs_idx * N; - if (!lhs_child_validity.CheckAllValid(left_offset + N, left_offset)) { - throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); - } - - const auto right_offset = rhs_idx * N; - if (!rhs_child_validity.CheckAllValid(right_offset + N, right_offset)) { - throw InvalidInputException( - StringUtil::Format("%s: right argument can not contain NULL values", func_name)); - } - const auto result_offset = i * N; - - const auto lhs_data_ptr = lhs_data + left_offset; - const auto rhs_data_ptr = rhs_data + right_offset; - const auto res_data_ptr = res_data + result_offset; - - OP::Operation(lhs_data_ptr, rhs_data_ptr, res_data_ptr, N); - } - - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -//------------------------------------------------------------------------------ -// Generic "fold" function -//------------------------------------------------------------------------------ -// Given two arrays, combine and reduce their elements into a single scalar value. 
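[Editor's note: ArrayGenericFold below expects its OP parameter to expose a static Operation(lhs, rhs, count) that reduces two equal-length arrays to a single scalar; the real kernels come from core_functions/array_kernels.hpp, included at the top of this file. The inner-product kernel here is a hypothetical sketch of that required shape, not part of the deleted source.]

struct InnerProductSketchOp {
	template <class TYPE>
	static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data, idx_t count) {
		// illustrative reduction: ArrayGenericFold<TYPE, InnerProductSketchOp>
		// would store this scalar into the result vector for each row
		TYPE sum = 0;
		for (idx_t i = 0; i < count; i++) {
			sum += lhs_data[i] * rhs_data[i];
		}
		return sum;
	}
};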
- -template -static void ArrayGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { - const auto &lstate = state.Cast(); - const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.name; - - const auto count = args.size(); - auto &lhs_child = ArrayVector::GetEntry(args.data[0]); - auto &rhs_child = ArrayVector::GetEntry(args.data[1]); - - const auto &lhs_child_validity = FlatVector::Validity(lhs_child); - const auto &rhs_child_validity = FlatVector::Validity(rhs_child); - - UnifiedVectorFormat lhs_format; - UnifiedVectorFormat rhs_format; - - args.data[0].ToUnifiedFormat(count, lhs_format); - args.data[1].ToUnifiedFormat(count, rhs_format); - - auto lhs_data = FlatVector::GetData(lhs_child); - auto rhs_data = FlatVector::GetData(rhs_child); - auto res_data = FlatVector::GetData(result); - - const auto array_size = ArrayType::GetSize(args.data[0].GetType()); - D_ASSERT(array_size == ArrayType::GetSize(args.data[1].GetType())); - - for (idx_t i = 0; i < count; i++) { - const auto lhs_idx = lhs_format.sel->get_index(i); - const auto rhs_idx = rhs_format.sel->get_index(i); - - if (!lhs_format.validity.RowIsValid(lhs_idx) || !rhs_format.validity.RowIsValid(rhs_idx)) { - FlatVector::SetNull(result, i, true); - continue; - } - - const auto left_offset = lhs_idx * array_size; - if (!lhs_child_validity.CheckAllValid(left_offset + array_size, left_offset)) { - throw InvalidInputException(StringUtil::Format("%s: left argument can not contain NULL values", func_name)); - } - - const auto right_offset = rhs_idx * array_size; - if (!rhs_child_validity.CheckAllValid(right_offset + array_size, right_offset)) { - throw InvalidInputException( - StringUtil::Format("%s: right argument can not contain NULL values", func_name)); - } - - const auto lhs_data_ptr = lhs_data + left_offset; - const auto rhs_data_ptr = rhs_data + right_offset; - - res_data[i] = OP::Operation(lhs_data_ptr, rhs_data_ptr, array_size); - } - - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -//------------------------------------------------------------------------------ -// Function Registration -//------------------------------------------------------------------------------ -// Note: In the future we could add a wrapper with a non-type template parameter to specialize for specific array sizes -// e.g. 256, 512, 1024, 2048 etc. which may allow the compiler to vectorize the loop better. Perhaps something for an -// extension. 
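[Editor's note: a minimal sketch of the size-specialization idea raised in the registration note above, assuming a kernel like the InnerProductSketchOp shown earlier is in scope; FixedSizeFold is a hypothetical name, not the author's API. Making the array size a non-type template parameter gives the compiler a compile-time trip count, which is what would enable the unrolling/vectorization the note hopes for.]

template <class TYPE, class OP, idx_t N>
struct FixedSizeFold {
	static TYPE Operation(const TYPE *lhs_data, const TYPE *rhs_data) {
		// N is a constant expression here, unlike the runtime array_size
		// that ArrayGenericFold reads from the vector type
		return OP::template Operation<TYPE>(lhs_data, rhs_data, N);
	}
};
// a registrar could dispatch to FixedSizeFold<float, OP, 256/512/1024/2048>
// for common sizes and fall back to the generic path otherwise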
- -template -static void AddArrayFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { - const auto array = LogicalType::ARRAY(type, optional_idx()); - if (type.id() == LogicalTypeId::FLOAT) { - ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); - set.AddFunction(function); - } else if (type.id() == LogicalTypeId::DOUBLE) { - ScalarFunction function({array, array}, type, ArrayGenericFold, ArrayGenericBinaryBind); - BaseScalarFunction::SetReturnsError(function); - set.AddFunction(function); - } else { - throw NotImplementedException("Array function not implemented for type %s", type.ToString()); - } -} - -ScalarFunctionSet ArrayDistanceFun::GetFunctions() { - ScalarFunctionSet set("array_distance"); - for (auto &type : LogicalType::Real()) { - AddArrayFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ArrayInnerProductFun::GetFunctions() { - ScalarFunctionSet set("array_inner_product"); - for (auto &type : LogicalType::Real()) { - AddArrayFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ArrayNegativeInnerProductFun::GetFunctions() { - ScalarFunctionSet set("array_negative_inner_product"); - for (auto &type : LogicalType::Real()) { - AddArrayFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ArrayCosineSimilarityFun::GetFunctions() { - ScalarFunctionSet set("array_cosine_similarity"); - for (auto &type : LogicalType::Real()) { - AddArrayFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ArrayCosineDistanceFun::GetFunctions() { - ScalarFunctionSet set("array_cosine_distance"); - for (auto &type : LogicalType::Real()) { - AddArrayFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ArrayCrossProductFun::GetFunctions() { - ScalarFunctionSet set("array_cross_product"); - - auto float_array = LogicalType::ARRAY(LogicalType::FLOAT, 3); - auto double_array = LogicalType::ARRAY(LogicalType::DOUBLE, 3); - set.AddFunction( - ScalarFunction({float_array, float_array}, float_array, ArrayFixedCombine)); - set.AddFunction( - ScalarFunction({double_array, double_array}, double_array, ArrayFixedCombine)); - for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp b/src/duckdb/extension/core_functions/scalar/array/array_value.cpp deleted file mode 100644 index e7f715f75..000000000 --- a/src/duckdb/extension/core_functions/scalar/array/array_value.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include "core_functions/scalar/array_functions.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/storage/statistics/array_stats.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static void ArrayValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto array_type = result.GetType(); - - D_ASSERT(array_type.id() == LogicalTypeId::ARRAY); - D_ASSERT(args.ColumnCount() == ArrayType::GetSize(array_type)); - - auto &child_type = ArrayType::GetChildType(array_type); - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - } - - auto num_rows = args.size(); - auto num_columns = args.ColumnCount(); - - auto &child = ArrayVector::GetEntry(result); - - if 
(num_columns > 1) { - // Ensure that the child has a validity mask of the correct size - // The SetValue call below expects the validity mask to be initialized - auto &child_validity = FlatVector::Validity(child); - child_validity.Resize(num_rows * num_columns); - } - - for (idx_t i = 0; i < num_rows; i++) { - for (idx_t j = 0; j < num_columns; j++) { - auto val = args.GetValue(j, i).DefaultCastAs(child_type); - child.SetValue((i * num_columns) + j, val); - } - } - - result.Verify(args.size()); -} - -static unique_ptr ArrayValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.empty()) { - throw InvalidInputException("array_value requires at least one argument"); - } - - // construct return type - LogicalType child_type = arguments[0]->return_type; - for (idx_t i = 1; i < arguments.size(); i++) { - child_type = LogicalType::MaxLogicalType(context, child_type, arguments[i]->return_type); - } - - if (arguments.size() > ArrayType::MAX_ARRAY_SIZE) { - throw OutOfRangeException("Array size exceeds maximum allowed size"); - } - - // this is mostly for completeness - bound_function.varargs = child_type; - bound_function.return_type = LogicalType::ARRAY(child_type, arguments.size()); - return make_uniq(bound_function.return_type); -} - -unique_ptr ArrayValueStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto list_stats = ArrayStats::CreateEmpty(expr.return_type); - auto &list_child_stats = ArrayStats::GetChildStats(list_stats); - for (idx_t i = 0; i < child_stats.size(); i++) { - list_child_stats.Merge(child_stats[i]); - } - return list_stats.ToUnique(); -} - -ScalarFunction ArrayValueFun::GetFunction() { - // the arguments and return types are actually set in the binder function - ScalarFunction fun("array_value", {}, LogicalTypeId::ARRAY, ArrayValueFunction, ArrayValueBind, nullptr, - ArrayValueStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp b/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp deleted file mode 100644 index 0dbcb8ebc..000000000 --- a/src/duckdb/extension/core_functions/scalar/bit/bitstring.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include "core_functions/scalar/bit_functions.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/cast_helpers.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// BitStringFunction -//===--------------------------------------------------------------------===// -template -static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) { - if (n < 0) { - throw InvalidInputException("The bitstring length cannot be negative"); - } - idx_t input_length; - if (FROM_STRING) { - input_length = input.GetSize(); - } else { - input_length = Bit::BitLength(input); - } - if (idx_t(n) < input_length) { - throw InvalidInputException("Length must be equal to or larger than the input string"); - } - idx_t len; - if (FROM_STRING) { - Bit::TryGetBitStringSize(input, len, nullptr); // string verification - } - - len = Bit::ComputeBitstringLen(UnsafeNumericCast(n)); - string_t target = StringVector::EmptyString(result, len); - if (FROM_STRING) { - 
Bit::BitString(input, UnsafeNumericCast(n), target); - } else { - Bit::ExtendBitString(input, UnsafeNumericCast(n), target); - } - target.Finalize(); - return target; - }); -} - -ScalarFunctionSet BitStringFun::GetFunctions() { - ScalarFunctionSet bitstring; - bitstring.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); - bitstring.AddFunction( - ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction)); - for (auto &func : bitstring.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return bitstring; -} - -//===--------------------------------------------------------------------===// -// get_bit -//===--------------------------------------------------------------------===// -struct GetBitOperator { - template - static inline TR Operation(TA input, TB n) { - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { - throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), - NumericHelper::ToString(Bit::BitLength(input) - 1)); - } - return UnsafeNumericCast(Bit::GetBit(input, UnsafeNumericCast(n))); - } -}; - -ScalarFunction GetBitFun::GetFunction() { - ScalarFunction func({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::INTEGER, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(func); - return func; -} - -//===--------------------------------------------------------------------===// -// set_bit -//===--------------------------------------------------------------------===// -static void SetBitOperation(DataChunk &args, ExpressionState &state, Vector &result) { - TernaryExecutor::Execute( - args.data[0], args.data[1], args.data[2], result, args.size(), - [&](string_t input, int32_t n, int32_t new_value) { - if (new_value != 0 && new_value != 1) { - throw InvalidInputException("The new bit must be 1 or 0"); - } - if (n < 0 || (idx_t)n > Bit::BitLength(input) - 1) { - throw OutOfRangeException("bit index %s out of valid range (0..%s)", NumericHelper::ToString(n), - NumericHelper::ToString(Bit::BitLength(input) - 1)); - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - memcpy(target.GetDataWriteable(), input.GetData(), input.GetSize()); - Bit::SetBit(target, UnsafeNumericCast(n), UnsafeNumericCast(new_value)); - return target; - }); -} - -ScalarFunction SetBitFun::GetFunction() { - ScalarFunction function({LogicalType::BIT, LogicalType::INTEGER, LogicalType::INTEGER}, LogicalType::BIT, - SetBitOperation); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// bit_position -//===--------------------------------------------------------------------===// -struct BitPositionOperator { - template - static inline TR Operation(TA substring, TB input) { - if (substring.GetSize() > input.GetSize()) { - return 0; - } - return UnsafeNumericCast(Bit::BitPosition(substring, input)); - } -}; - -ScalarFunction BitPositionFun::GetFunction() { - return ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::INTEGER, - ScalarFunction::BinaryFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp b/src/duckdb/extension/core_functions/scalar/blob/base64.cpp deleted file mode 100644 index fb903fa80..000000000 --- a/src/duckdb/extension/core_functions/scalar/blob/base64.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include 
"core_functions/scalar/blob_functions.hpp" -#include "duckdb/common/types/blob.hpp" - -namespace duckdb { - -struct Base64EncodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto result_str = StringVector::EmptyString(result, Blob::ToBase64Size(input)); - Blob::ToBase64(input, result_str.GetDataWriteable()); - result_str.Finalize(); - return result_str; - } -}; - -struct Base64DecodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto result_size = Blob::FromBase64Size(input); - auto result_blob = StringVector::EmptyString(result, result_size); - Blob::FromBase64(input, data_ptr_cast(result_blob.GetDataWriteable()), result_size); - result_blob.Finalize(); - return result_blob; - } -}; - -static void Base64EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -static void Base64DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -ScalarFunction ToBase64Fun::GetFunction() { - return ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, Base64EncodeFunction); -} - -ScalarFunction FromBase64Fun::GetFunction() { - ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, Base64DecodeFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp b/src/duckdb/extension/core_functions/scalar/blob/encode.cpp deleted file mode 100644 index 66cedb0b5..000000000 --- a/src/duckdb/extension/core_functions/scalar/blob/encode.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "core_functions/scalar/blob_functions.hpp" -#include "utf8proc_wrapper.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" - -namespace duckdb { - -static void EncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // encode is essentially a nop cast from varchar to blob - // we only need to reinterpret the data using the blob type - result.Reinterpret(args.data[0]); -} - -struct BlobDecodeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - if (Utf8Proc::Analyze(input_data, input_length) == UnicodeType::INVALID) { - throw ConversionException( - "Failure in decode: could not convert blob to UTF8 string, the blob contained invalid UTF8 characters"); - } - return input; - } -}; - -static void DecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // decode is also a nop cast, but requires verification if the provided string is actually - UnaryExecutor::Execute(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction EncodeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::BLOB, EncodeFunction); -} - -ScalarFunction DecodeFun::GetFunction() { - ScalarFunction function({LogicalType::BLOB}, LogicalType::VARCHAR, DecodeFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/age.cpp 
b/src/duckdb/extension/core_functions/scalar/date/age.cpp deleted file mode 100644 index cf7281f08..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/age.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/transaction/meta_transaction.hpp" - -namespace duckdb { - -static void AgeFunctionStandard(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - // Subtract argument from current_date (at midnight) - // Theoretically, this should be TZ-sensitive, but since we have to be able to handle - // plain TZ when ICU is not loaded, we implement this in UTC (like everything else) - // To get the PG behaviour, we overload these functions in ICU for TSTZ arguments. - auto current_date = Timestamp::FromDatetime( - Timestamp::GetDate(MetaTransaction::Get(state.GetContext()).start_timestamp), dtime_t(0)); - - UnaryExecutor::ExecuteWithNulls(input.data[0], result, input.size(), - [&](timestamp_t input, ValidityMask &mask, idx_t idx) { - if (Timestamp::IsFinite(input)) { - return Interval::GetAge(current_date, input); - } else { - mask.SetInvalid(idx); - return interval_t(); - } - }); -} - -static void AgeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 2); - - BinaryExecutor::ExecuteWithNulls( - input.data[0], input.data[1], result, input.size(), - [&](timestamp_t input1, timestamp_t input2, ValidityMask &mask, idx_t idx) { - if (Timestamp::IsFinite(input1) && Timestamp::IsFinite(input2)) { - return Interval::GetAge(input1, input2); - } else { - mask.SetInvalid(idx); - return interval_t(); - } - }); -} - -ScalarFunctionSet AgeFun::GetFunctions() { - ScalarFunctionSet age("age"); - age.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunctionStandard)); - age.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, LogicalType::INTERVAL, AgeFunction)); - return age; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/current.cpp b/src/duckdb/extension/core_functions/scalar/date/current.cpp deleted file mode 100644 index 3d25ee80a..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/current.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/transaction/meta_transaction.hpp" - -namespace duckdb { - -static timestamp_t GetTransactionTimestamp(ExpressionState &state) { - return MetaTransaction::Get(state.GetContext()).start_timestamp; -} - -static void CurrentTimestampFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 0); - auto ts = GetTransactionTimestamp(state); - auto val = Value::TIMESTAMPTZ(timestamp_tz_t(ts)); - result.Reference(val); -} - -ScalarFunction GetCurrentTimestampFun::GetFunction() { - 
ScalarFunction current_timestamp({}, LogicalType::TIMESTAMP_TZ, CurrentTimestampFunction); - current_timestamp.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return current_timestamp; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp b/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp deleted file mode 100644 index c0e4ba1da..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/date_diff.cpp +++ /dev/null @@ -1,454 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -// This function is an implementation of the "period-crossing" date difference function from T-SQL -// https://docs.microsoft.com/en-us/sql/t-sql/functions/datediff-transact-sql?view=sql-server-ver15 -struct DateDiff { - template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::ExecuteWithNulls( - left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return OP::template Operation(startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - }); - } - - // We need to truncate down, not towards 0 - static inline int64_t Truncate(int64_t value, int64_t units) { - return (value + (value < 0)) / units - (value < 0); - } - static inline int64_t Diff(int64_t start, int64_t end, int64_t units) { - return Truncate(end, units) - Truncate(start, units); - } - - struct YearOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) - Date::ExtractYear(startdate); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - int32_t start_year, start_month, start_day; - Date::Convert(startdate, start_year, start_month, start_day); - int32_t end_year, end_month, end_day; - Date::Convert(enddate, end_year, end_month, end_day); - - return (end_year * 12 + end_month - 1) - (start_year * 12 + start_month - 1); - } - }; - - struct DayOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return TR(Date::EpochDays(enddate)) - TR(Date::EpochDays(startdate)); - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 10 - Date::ExtractYear(startdate) / 10; - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 100 - Date::ExtractYear(startdate) / 100; - } - }; - - struct MilleniumOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractYear(enddate) / 1000 - Date::ExtractYear(startdate) / 1000; - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - int32_t start_year, start_month, start_day; - Date::Convert(startdate, start_year, start_month, start_day); - int32_t end_year, end_month, 
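A standalone check of the Truncate/Diff helpers defined above: Truncate implements floor division (rounding toward negative infinity), which keeps the bucket counting consistent on both sides of the epoch, where plain truncating division would not:

#include <cassert>
#include <cstdint>

static int64_t Truncate(int64_t value, int64_t units) {
	return (value + (value < 0)) / units - (value < 0);
}

static int64_t Diff(int64_t start, int64_t end, int64_t units) {
	return Truncate(end, units) - Truncate(start, units);
}

int main() {
	// one microsecond before and after the epoch sit in adjacent 1s buckets
	assert(Truncate(-1, 1000000) == -1);
	assert(Truncate(1, 1000000) == 0);
	assert(Diff(-1, 1, 1000000) == 1);
	// truncating division toward zero would have put both rows in bucket 0
	assert(-1 / 1000000 == 0);
	return 0;
}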
end_day; - Date::Convert(enddate, end_year, end_month, end_day); - - return (end_year * 12 + end_month - 1) / Interval::MONTHS_PER_QUARTER - - (start_year * 12 + start_month - 1) / Interval::MONTHS_PER_QUARTER; - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - // Weeks do not count Monday crossings, just distance - return (enddate.days - startdate.days) / Interval::DAYS_PER_WEEK; - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::ExtractISOYearNumber(enddate) - Date::ExtractISOYearNumber(startdate); - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::EpochMicroseconds(enddate) - Date::EpochMicroseconds(startdate); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::EpochMicroseconds(enddate) / Interval::MICROS_PER_MSEC - - Date::EpochMicroseconds(startdate) / Interval::MICROS_PER_MSEC; - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) - Date::Epoch(startdate); - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) / Interval::SECS_PER_MINUTE - - Date::Epoch(startdate) / Interval::SECS_PER_MINUTE; - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return Date::Epoch(enddate) / Interval::SECS_PER_HOUR - Date::Epoch(startdate) / Interval::SECS_PER_HOUR; - } - }; -}; - -// TIMESTAMP specialisations -template <> -int64_t DateDiff::YearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return YearOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MonthOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return MonthOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::DayOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return DayOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::DecadeOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return DecadeOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::CenturyOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return CenturyOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MilleniumOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return MilleniumOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::QuarterOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return QuarterOperator::Operation(Timestamp::GetDate(startdate), - Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::WeekOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return WeekOperator::Operation(Timestamp::GetDate(startdate), Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::ISOYearOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - return ISOYearOperator::Operation(Timestamp::GetDate(startdate), - 
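A standalone check of the month arithmetic used by MonthOperator and QuarterOperator above: months are linearized as year * 12 + (month - 1), so the difference counts calendar-boundary crossings rather than elapsed 30-day periods:

#include <cassert>

static int MonthsSinceEpoch(int year, int month) {
	return year * 12 + (month - 1);
}

int main() {
	// 2024-01-31 -> 2024-02-01: one day apart, but one month crossing
	assert(MonthsSinceEpoch(2024, 2) - MonthsSinceEpoch(2024, 1) == 1);
	// two dates inside 2024-03: many days apart, zero crossings
	assert(MonthsSinceEpoch(2024, 3) - MonthsSinceEpoch(2024, 3) == 0);
	// quarters divide the same linearized count by 3 (MONTHS_PER_QUARTER)
	assert(MonthsSinceEpoch(2024, 4) / 3 - MonthsSinceEpoch(2024, 3) / 3 == 1);
	return 0;
}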
Timestamp::GetDate(enddate)); -} - -template <> -int64_t DateDiff::MicrosecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - const auto start = Timestamp::GetEpochMicroSeconds(startdate); - const auto end = Timestamp::GetEpochMicroSeconds(enddate); - return SubtractOperatorOverflowCheck::Operation(end, start); -} - -template <> -int64_t DateDiff::MillisecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - D_ASSERT(Timestamp::IsFinite(startdate)); - D_ASSERT(Timestamp::IsFinite(enddate)); - return Diff(startdate.value, enddate.value, Interval::MICROS_PER_MSEC); -} - -template <> -int64_t DateDiff::SecondsOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - D_ASSERT(Timestamp::IsFinite(startdate)); - D_ASSERT(Timestamp::IsFinite(enddate)); - return Diff(startdate.value, enddate.value, Interval::MICROS_PER_SEC); -} - -template <> -int64_t DateDiff::MinutesOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - D_ASSERT(Timestamp::IsFinite(startdate)); - D_ASSERT(Timestamp::IsFinite(enddate)); - return Diff(startdate.value, enddate.value, Interval::MICROS_PER_MINUTE); -} - -template <> -int64_t DateDiff::HoursOperator::Operation(timestamp_t startdate, timestamp_t enddate) { - D_ASSERT(Timestamp::IsFinite(startdate)); - D_ASSERT(Timestamp::IsFinite(enddate)); - return Diff(startdate.value, enddate.value, Interval::MICROS_PER_HOUR); -} - -// TIME specialisations -template <> -int64_t DateDiff::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DateDiff::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DateDiff::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DateDiff::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"decade\" not recognized"); -} - -template <> -int64_t DateDiff::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DateDiff::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> -int64_t DateDiff::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t DateDiff::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DateDiff::ISOYearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); -} - -template <> -int64_t DateDiff::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros - startdate.micros; -} - -template <> -int64_t DateDiff::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_MSEC - startdate.micros / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DateDiff::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_SEC - 
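A sketch of the overflow-checked subtraction that the timestamp MicrosecondsOperator above delegates to SubtractOperatorOverflowCheck; the __builtin_sub_overflow intrinsic is a GCC/Clang stand-in used here for illustration, not what DuckDB itself uses:

#include <cstdint>
#include <stdexcept>

static int64_t CheckedSub(int64_t lhs, int64_t rhs) {
	int64_t result;
	// GCC/Clang builtin; returns true when the subtraction wrapped around
	if (__builtin_sub_overflow(lhs, rhs, &result)) {
		throw std::overflow_error("timestamp difference out of range");
	}
	return result;
}

int main() {
	return CheckedSub(1000, 1) == 999 ? 0 : 1;
}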
startdate.micros / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DateDiff::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_MINUTE - startdate.micros / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DateDiff::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros / Interval::MICROS_PER_HOUR - startdate.micros / Interval::MICROS_PER_HOUR; -} - -template -static int64_t DifferenceDates(DatePartSpecifier type, TA startdate, TB enddate) { - switch (type) { - case DatePartSpecifier::YEAR: - return DateDiff::YearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MONTH: - return DateDiff::MonthOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return DateDiff::DayOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DECADE: - return DateDiff::DecadeOperator::template Operation(startdate, enddate); - case DatePartSpecifier::CENTURY: - return DateDiff::CenturyOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLENNIUM: - return DateDiff::MilleniumOperator::template Operation(startdate, enddate); - case DatePartSpecifier::QUARTER: - return DateDiff::QuarterOperator::template Operation(startdate, enddate); - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return DateDiff::WeekOperator::template Operation(startdate, enddate); - case DatePartSpecifier::ISOYEAR: - return DateDiff::ISOYearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MICROSECONDS: - return DateDiff::MicrosecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLISECONDS: - return DateDiff::MillisecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateDiff::SecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MINUTE: - return DateDiff::MinutesOperator::template Operation(startdate, enddate); - case DatePartSpecifier::HOUR: - return DateDiff::HoursOperator::template Operation(startdate, enddate); - default: - throw NotImplementedException("Specifier type not implemented for DATEDIFF"); - } -} - -struct DateDiffTernaryOperator { - template - static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return DifferenceDates(GetDatePartSpecifier(part.GetString()), startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - } -}; - -template -static void DateDiffBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { - switch (type) { - case DatePartSpecifier::YEAR: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MONTH: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DECADE: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::CENTURY: - DateDiff::BinaryExecute(left, right, result, count); - break; - case 
DatePartSpecifier::MILLENNIUM: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::QUARTER: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::ISOYEAR: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MINUTE: - DateDiff::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::HOUR: - DateDiff::BinaryExecute(left, right, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATEDIFF"); - } -} - -template -static void DateDiffFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. - if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateDiffBinaryExecutor(type, start_arg, end_arg, result, args.size()); - } - } else { - TernaryExecutor::ExecuteWithNulls( - part_arg, start_arg, end_arg, result, args.size(), - DateDiffTernaryOperator::Operation); - } -} - -ScalarFunctionSet DateDiffFun::GetFunctions() { - ScalarFunctionSet date_diff("date_diff"); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, - LogicalType::BIGINT, DateDiffFunction)); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::BIGINT, DateDiffFunction)); - date_diff.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, - LogicalType::BIGINT, DateDiffFunction)); - return date_diff; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp b/src/duckdb/extension/core_functions/scalar/date/date_part.cpp deleted file mode 100644 index 1aeb4550c..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/date_part.cpp +++ /dev/null @@ -1,2263 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/case_insensitive_map.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/date_lookup_cache.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" 
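A reduced sketch of the constant-specifier fast path in DateDiffFunction above; the part names and parse counter are hypothetical, but the shape is the same: a constant part argument is parsed once per chunk and a specialized kernel runs over the rows, while the general ternary path re-parses it per row:

#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

enum class Part { YEAR, DAY };

static Part ParsePart(const std::string &s, int &parse_count) {
	++parse_count; // count how often the specifier string is inspected
	return s == "year" ? Part::YEAR : Part::DAY;
}

int main() {
	std::vector<int> out(2048);
	int parses = 0;
	// constant path: parse once, then run a tight kernel over all rows
	const Part part = ParsePart("day", parses);
	for (std::size_t i = 0; i < out.size(); i++) {
		out[i] = (part == Part::DAY) ? 1 : 0;
	}
	assert(parses == 1); // the per-row path would have parsed 2048 times
	return 0;
}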
- -namespace duckdb { - -DatePartSpecifier GetDateTypePartSpecifier(const string &specifier, LogicalType &type) { - const auto part = GetDatePartSpecifier(specifier); - switch (type.id()) { - case LogicalType::TIMESTAMP: - case LogicalType::TIMESTAMP_TZ: - return part; - case LogicalType::DATE: - switch (part) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::ISOYEAR: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::DOY: - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::ERA: - case DatePartSpecifier::EPOCH: - case DatePartSpecifier::JULIAN_DAY: - return part; - default: - break; - } - break; - case LogicalType::TIME: - case LogicalType::TIME_TZ: - switch (part) { - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - case DatePartSpecifier::EPOCH: - case DatePartSpecifier::TIMEZONE: - case DatePartSpecifier::TIMEZONE_HOUR: - case DatePartSpecifier::TIMEZONE_MINUTE: - return part; - default: - break; - } - break; - case LogicalType::INTERVAL: - switch (part) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - case DatePartSpecifier::EPOCH: - return part; - default: - break; - } - break; - default: - break; - } - - throw NotImplementedException("\"%s\" units \"%s\" not recognized", EnumUtil::ToString(type.id()), specifier); -} - -template -static unique_ptr PropagateSimpleDatePartStatistics(vector &child_stats) { - // we can always propagate simple date part statistics - // since the min and max can never exceed these bounds - auto result = NumericStats::CreateEmpty(LogicalType::BIGINT); - result.CopyValidity(child_stats[0]); - NumericStats::SetMin(result, Value::BIGINT(MIN)); - NumericStats::SetMax(result, Value::BIGINT(MAX)); - return result.ToUnique(); -} - -template -struct DateCacheLocalState : public FunctionLocalState { - explicit DateCacheLocalState() { - } - - DateLookupCache cache; -}; - -template -unique_ptr InitDateCacheLocalState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - return make_uniq>(); -} - -struct DatePart { - template - static unique_ptr PropagateDatePartStatistics(vector &child_stats, - const LogicalType &stats_type = LogicalType::BIGINT) { - // we can only propagate complex date part stats if the child has stats - auto &nstats = child_stats[0]; - if (!NumericStats::HasMinMax(nstats)) { - return nullptr; - } - // run the operator on both the min and the max, this gives us the [min, max] bound - auto min = NumericStats::GetMin(nstats); - auto max = NumericStats::GetMax(nstats); - if (min > max) { - return nullptr; - } - // Infinities prevent us from computing generic ranges - if (!Value::IsFinite(min) || !Value::IsFinite(max)) { - return nullptr; - } - TR min_part = OP::template Operation(min); - TR max_part = OP::template Operation(max); 
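A standalone illustration of why PropagateDatePartStatistics above can evaluate the operator on just the child's min and max: parts such as year and decade are non-decreasing in their input (this monotonicity is an implicit assumption of that code), so the two endpoint results bound the whole column, and fixed-range parts (month in [1, 12], quarter in [1, 4]) skip even that:

#include <cassert>

// same formula as DecadeOperator::DecadeFromYear below
static int DecadeFromYear(int yyyy) {
	return yyyy / 10;
}

int main() {
	// child statistics: year in [1995, 2024]
	const int min_decade = DecadeFromYear(1995); // 199
	const int max_decade = DecadeFromYear(2024); // 202
	assert(min_decade == 199 && max_decade == 202);
	// a filter such as decade = 250 can now be pruned without scanning
	assert(!(250 >= min_decade && 250 <= max_decade));
	return 0;
}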
- auto result = NumericStats::CreateEmpty(stats_type); - NumericStats::SetMin(result, Value(min_part)); - NumericStats::SetMax(result, Value(max_part)); - result.CopyValidity(child_stats[0]); - return result.ToUnique(); - } - - template - struct PartOperator { - template - static inline TR Operation(TA input, ValidityMask &mask, idx_t idx, void *dataptr) { - if (Value::IsFinite(input)) { - return OP::template Operation(input); - } else { - mask.SetInvalid(idx); - return TR(); - } - } - }; - - template - static void UnaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() >= 1); - using IOP = PartOperator; - UnaryExecutor::GenericExecute(input.data[0], result, input.size(), nullptr, true); - } - - struct YearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractYear(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractMonth(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of month operator is [1, 12] - return PropagateSimpleDatePartStatistics<1, 12>(input.child_stats); - } - }; - - struct DayOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractDay(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of day operator is [1, 31] - return PropagateSimpleDatePartStatistics<1, 31>(input.child_stats); - } - }; - - struct DecadeOperator { - // From the PG docs: "The year field divided by 10" - template - static inline TR DecadeFromYear(TR yyyy) { - return yyyy / 10; - } - - template - static inline TR Operation(TA input) { - return DecadeFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct CenturyOperator { - // From the PG docs: - // "The first century starts at 0001-01-01 00:00:00 AD, although they did not know it at the time. - // This definition applies to all Gregorian calendar countries. - // There is no century number 0, you go from -1 century to 1 century. - // If you disagree with this, please write your complaint to: Pope, Cathedral Saint-Peter of Roma, Vatican." - // (To be fair, His Holiness had nothing to do with this - - // it was the lack of zero in the counting systems of the time...) 
- template - static inline TR CenturyFromYear(TR yyyy) { - if (yyyy > 0) { - return ((yyyy - 1) / 100) + 1; - } else { - return (yyyy / 100) - 1; - } - } - - template - static inline TR Operation(TA input) { - return CenturyFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct MillenniumOperator { - // See the century comment - template - static inline TR MillenniumFromYear(TR yyyy) { - if (yyyy > 0) { - return ((yyyy - 1) / 1000) + 1; - } else { - return (yyyy / 1000) - 1; - } - } - - template - static inline TR Operation(TA input) { - return MillenniumFromYear(YearOperator::Operation(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct QuarterOperator { - template - static inline TR QuarterFromMonth(TR mm) { - return (mm - 1) / Interval::MONTHS_PER_QUARTER + 1; - } - - template - static inline TR Operation(TA input) { - return QuarterFromMonth(Date::ExtractMonth(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - // min/max of quarter operator is [1, 4] - return PropagateSimpleDatePartStatistics<1, 4>(input.child_stats); - } - }; - - struct DayOfWeekOperator { - template - static inline TR DayOfWeekFromISO(TR isodow) { - // day of the week (Sunday = 0, Saturday = 6) - // turn sunday into 0 by doing mod 7 - return isodow % 7; - } - - template - static inline TR Operation(TA input) { - return DayOfWeekFromISO(Date::ExtractISODayOfTheWeek(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 6>(input.child_stats); - } - }; - - struct ISODayOfWeekOperator { - template - static inline TR Operation(TA input) { - // isodow (Monday = 1, Sunday = 7) - return Date::ExtractISODayOfTheWeek(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 7>(input.child_stats); - } - }; - - struct DayOfYearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractDayOfTheYear(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 366>(input.child_stats); - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractISOWeekNumber(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<1, 54>(input.child_stats); - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA input) { - return Date::ExtractISOYearNumber(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct YearWeekOperator { - template - static inline TR YearWeekFromParts(TR yyyy, TR ww) { - return yyyy * 100 + ((yyyy > 0) ? 
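A standalone check of the CenturyFromYear/MillenniumFromYear logic above: since there is no year 0, a plain division would be off by one around the era boundary:

#include <cassert>

// same branch structure as CenturyOperator::CenturyFromYear above
static int CenturyFromYear(int yyyy) {
	return yyyy > 0 ? (yyyy - 1) / 100 + 1 : yyyy / 100 - 1;
}

int main() {
	assert(CenturyFromYear(2000) == 20); // 2000 still belongs to the 20th
	assert(CenturyFromYear(2001) == 21);
	assert(CenturyFromYear(1) == 1);     // first year of the 1st century
	assert(CenturyFromYear(-1) == -1);   // no century 0: straight to -1
	return 0;
}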
ww : -ww); - } - - template - static inline TR Operation(TA input) { - int32_t yyyy, ww; - Date::ExtractISOYearWeek(input, yyyy, ww); - return YearWeekFromParts(yyyy, ww); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochNanosecondsOperator { - template - static inline TR Operation(TA input) { - return Timestamp::GetEpochNanoSeconds(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochMicrosecondsOperator { - template - static inline TR Operation(TA input) { - return Timestamp::GetEpochMicroSeconds(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - }; - - struct EpochMillisOperator { - template - static inline TR Operation(TA input) { - return Cast::Operation(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats); - } - - static void Inverse(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](int64_t input) { - // millisecond amounts provided to epoch_ms should never be considered infinite - // instead such values will just throw when converted to microseconds - return Timestamp::FromEpochMsPossiblyInfinite(input); - }); - } - }; - - struct NanosecondsOperator { - template - static inline TR Operation(TA input) { - return MicrosecondsOperator::Operation(input) * Interval::NANOS_PER_MICRO; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60000000000>(input.child_stats); - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60000000>(input.child_stats); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60000>(input.child_stats); - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 60>(input.child_stats); - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA input) { - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 24>(input.child_stats); - } - }; - - struct 
EpochOperator { - template - static inline TR Operation(TA input) { - return TR(Date::Epoch(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); - } - }; - - struct EraOperator { - template - static inline TR EraFromYear(TR yyyy) { - return yyyy > 0 ? 1 : 0; - } - - template - static inline TR Operation(TA input) { - return EraFromYear(Date::ExtractYear(input)); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 1>(input.child_stats); - } - }; - - struct TimezoneOperator { - template - static inline TR Operation(TA input) { - // Regular timestamps are UTC. - return 0; - } - - template - static TR Operation(TA interval, TB timetz) { - auto time = Time::NormalizeTimeTZ(timetz); - date_t date(0); - time = Interval::Add(time, interval, date); - auto offset = UnsafeNumericCast(interval.micros / Interval::MICROS_PER_SEC); - return TR(time, offset); - } - - template - static void BinaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 2); - auto &offset = input.data[0]; - auto &timetz = input.data[1]; - - auto func = DatePart::TimezoneOperator::Operation; - BinaryExecutor::Execute(offset, timetz, result, input.size(), func); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); - } - }; - - struct TimezoneHourOperator { - template - static inline TR Operation(TA input) { - // Regular timestamps are UTC. - return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); - } - }; - - struct TimezoneMinuteOperator { - template - static inline TR Operation(TA input) { - // Regular timestamps are UTC. 
- return 0; - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateSimpleDatePartStatistics<0, 0>(input.child_stats); - } - }; - - struct JulianDayOperator { - template - static inline TR Operation(TA input) { - return Timestamp::GetJulianDay(input); - } - - template - static unique_ptr PropagateStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return PropagateDatePartStatistics(input.child_stats, LogicalType::DOUBLE); - } - }; - - struct StructOperator { - using part_codes_t = vector; - using part_mask_t = uint64_t; - - enum MaskBits : uint8_t { - YMD = 1 << 0, - DOW = 1 << 1, - DOY = 1 << 2, - EPOCH = 1 << 3, - TIME = 1 << 4, - ZONE = 1 << 5, - ISO = 1 << 6, - JD = 1 << 7 - }; - - static part_mask_t GetMask(const part_codes_t &part_codes) { - part_mask_t mask = 0; - for (const auto &part_code : part_codes) { - switch (part_code) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::ERA: - mask |= YMD; - break; - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::ISOYEAR: - mask |= ISO; - break; - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - mask |= DOW; - break; - case DatePartSpecifier::DOY: - mask |= DOY; - break; - case DatePartSpecifier::EPOCH: - mask |= EPOCH; - break; - case DatePartSpecifier::JULIAN_DAY: - mask |= JD; - break; - case DatePartSpecifier::MICROSECONDS: - case DatePartSpecifier::MILLISECONDS: - case DatePartSpecifier::SECOND: - case DatePartSpecifier::MINUTE: - case DatePartSpecifier::HOUR: - mask |= TIME; - break; - case DatePartSpecifier::TIMEZONE: - case DatePartSpecifier::TIMEZONE_HOUR: - case DatePartSpecifier::TIMEZONE_MINUTE: - mask |= ZONE; - break; - case DatePartSpecifier::INVALID: - throw InternalException("Invalid DatePartSpecifier for STRUCT mask!"); - } - } - return mask; - } - - template - static inline P HasPartValue(vector
part_values, DatePartSpecifier part) { - auto idx = size_t(part); - if (IsBigintDatepart(part)) { - return part_values[idx - size_t(DatePartSpecifier::BEGIN_BIGINT)]; - } else { - return part_values[idx - size_t(DatePartSpecifier::BEGIN_DOUBLE)]; - } - } - - using bigint_vec = vector; - using double_vec = vector; - - template - static inline void Operation(bigint_vec &bigint_values, double_vec &double_values, const TA &input, - const idx_t idx, const part_mask_t mask) { - int64_t *bigint_data; - // YMD calculations - int32_t yyyy = 1970; - int32_t mm = 0; - int32_t dd = 1; - if (mask & YMD) { - Date::Convert(input, yyyy, mm, dd); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); - if (bigint_data) { - bigint_data[idx] = yyyy; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); - if (bigint_data) { - bigint_data[idx] = mm; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); - if (bigint_data) { - bigint_data[idx] = dd; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); - if (bigint_data) { - bigint_data[idx] = DecadeOperator::DecadeFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); - if (bigint_data) { - bigint_data[idx] = CenturyOperator::CenturyFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); - if (bigint_data) { - bigint_data[idx] = MillenniumOperator::MillenniumFromYear(yyyy); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); - if (bigint_data) { - bigint_data[idx] = QuarterOperator::QuarterFromMonth(mm); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ERA); - if (bigint_data) { - bigint_data[idx] = EraOperator::EraFromYear(yyyy); - } - } - - // Week calculations - if (mask & DOW) { - auto isodow = Date::ExtractISODayOfTheWeek(input); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOW); - if (bigint_data) { - bigint_data[idx] = DayOfWeekOperator::DayOfWeekFromISO(isodow); - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISODOW); - if (bigint_data) { - bigint_data[idx] = isodow; - } - } - - // ISO calculations - if (mask & ISO) { - int32_t ww = 0; - int32_t iyyy = 0; - Date::ExtractISOYearWeek(input, iyyy, ww); - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::WEEK); - if (bigint_data) { - bigint_data[idx] = ww; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::ISOYEAR); - if (bigint_data) { - bigint_data[idx] = iyyy; - } - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::YEARWEEK); - if (bigint_data) { - bigint_data[idx] = YearWeekOperator::YearWeekFromParts(iyyy, ww); - } - } - - if (mask & EPOCH) { - auto double_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (double_data) { - double_data[idx] = double(Date::Epoch(input)); - } - } - if (mask & DOY) { - bigint_data = HasPartValue(bigint_values, DatePartSpecifier::DOY); - if (bigint_data) { - bigint_data[idx] = Date::ExtractDayOfTheYear(input); - } - } - if (mask & JD) { - auto double_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); - if (double_data) { - double_data[idx] = double(Date::ExtractJulianDay(input)); - } - } - } - }; -}; - -template -static void DatePartCachedFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast>(); - UnaryExecutor::ExecuteWithNulls( - args.data[0], result, args.size(), - [&](T input, 
ValidityMask &mask, idx_t idx) { return lstate.cache.ExtractElement(input, mask, idx); }); -} - -template <> -int64_t DatePart::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::YearOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_YEAR; -} - -template <> -int64_t DatePart::YearOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DatePart::YearOperator::Operation(dtime_tz_t input) { - return YearOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::MonthOperator::Operation(interval_t input) { - return input.months % Interval::MONTHS_PER_YEAR; -} - -template <> -int64_t DatePart::MonthOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DatePart::MonthOperator::Operation(dtime_tz_t input) { - return MonthOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::DayOperator::Operation(timestamp_t input) { - return DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOperator::Operation(interval_t input) { - return input.days; -} - -template <> -int64_t DatePart::DayOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DatePart::DayOperator::Operation(dtime_tz_t input) { - return DayOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::DecadeOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_DECADE; -} - -template <> -int64_t DatePart::DecadeOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"decade\" not recognized"); -} - -template <> -int64_t DatePart::DecadeOperator::Operation(dtime_tz_t input) { - return DecadeOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::CenturyOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_CENTURY; -} - -template <> -int64_t DatePart::CenturyOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DatePart::CenturyOperator::Operation(dtime_tz_t input) { - return CenturyOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::MillenniumOperator::Operation(interval_t input) { - return input.months / Interval::MONTHS_PER_MILLENIUM; -} - -template <> -int64_t DatePart::MillenniumOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> -int64_t DatePart::MillenniumOperator::Operation(dtime_tz_t input) { - return MillenniumOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::QuarterOperator::Operation(interval_t input) { - return MonthOperator::Operation(input) / Interval::MONTHS_PER_QUARTER + 1; -} - -template <> -int64_t DatePart::QuarterOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t 
DatePart::QuarterOperator::Operation(dtime_tz_t input) { - return QuarterOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(timestamp_t input) { - return DayOfWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"dow\" not recognized"); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"dow\" not recognized"); -} - -template <> -int64_t DatePart::DayOfWeekOperator::Operation(dtime_tz_t input) { - return DayOfWeekOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(timestamp_t input) { - return ISODayOfWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"isodow\" not recognized"); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"isodow\" not recognized"); -} - -template <> -int64_t DatePart::ISODayOfWeekOperator::Operation(dtime_tz_t input) { - return ISODayOfWeekOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(timestamp_t input) { - return DayOfYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"doy\" not recognized"); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"doy\" not recognized"); -} - -template <> -int64_t DatePart::DayOfYearOperator::Operation(dtime_tz_t input) { - return DayOfYearOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::WeekOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"week\" not recognized"); -} - -template <> -int64_t DatePart::WeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DatePart::WeekOperator::Operation(dtime_tz_t input) { - return WeekOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"isoyear\" not recognized"); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"isoyear\" not recognized"); -} - -template <> -int64_t DatePart::ISOYearOperator::Operation(dtime_tz_t input) { - return ISOYearOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(timestamp_t input) { - return YearWeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(interval_t input) { - const auto yyyy = YearOperator::Operation(input); - const auto ww = WeekOperator::Operation(input); - return YearWeekOperator::YearWeekFromParts(yyyy, ww); -} - -template 
<> -int64_t DatePart::YearWeekOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"yearweek\" not recognized"); -} - -template <> -int64_t DatePart::YearWeekOperator::Operation(dtime_tz_t input) { - return YearWeekOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return Timestamp::GetEpochNanoSeconds(input); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(date_t input) { - D_ASSERT(Date::IsFinite(input)); - return Date::EpochNanoseconds(input); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(interval_t input) { - return Interval::GetNanoseconds(input); -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(dtime_t input) { - return input.micros * Interval::NANOS_PER_MICRO; -} - -template <> -int64_t DatePart::EpochNanosecondsOperator::Operation(dtime_tz_t input) { - return DatePart::EpochNanosecondsOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(date_t input) { - return Date::EpochMicroseconds(input); -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(interval_t input) { - return Interval::GetMicro(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return Cast::Operation(input).value; -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(dtime_t input) { - return input.micros; -} - -template <> -int64_t DatePart::EpochMicrosecondsOperator::Operation(dtime_tz_t input) { - return DatePart::EpochMicrosecondsOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(date_t input) { - return Date::EpochMilliseconds(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(interval_t input) { - return Interval::GetMilli(input); -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(dtime_t input) { - return input.micros / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::EpochMillisOperator::Operation(dtime_tz_t input) { - return DatePart::EpochMillisOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::NanosecondsOperator::Operation(timestamp_ns_t input) { - if (!Timestamp::IsFinite(input)) { - throw ConversionException("Can't get nanoseconds of infinite TIMESTAMP"); - } - date_t date; - dtime_t time; - int32_t nanos; - Timestamp::Convert(input, date, time, nanos); - // remove everything but the second & nanosecond part - return (time.micros % Interval::MICROS_PER_MINUTE) * Interval::NANOS_PER_MICRO + nanos; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - auto time = Timestamp::GetTime(input); - // remove everything but the second & microsecond part - return time.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(interval_t input) { - // remove everything but the second & microsecond part - return input.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(dtime_t input) { - // remove everything but the second & microsecond part - return input.micros % Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MicrosecondsOperator::Operation(dtime_tz_t input) { - return 
DatePart::MicrosecondsOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(interval_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(dtime_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DatePart::MillisecondsOperator::Operation(dtime_tz_t input) { - return DatePart::MillisecondsOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::SecondsOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(interval_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(dtime_t input) { - return MicrosecondsOperator::Operation(input) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DatePart::SecondsOperator::Operation(dtime_tz_t input) { - return DatePart::SecondsOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::MinutesOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - auto time = Timestamp::GetTime(input); - // remove the hour part, and truncate to minutes - return (time.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(interval_t input) { - // remove the hour part, and truncate to minutes - return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(dtime_t input) { - // remove the hour part, and truncate to minutes - return (input.micros % Interval::MICROS_PER_HOUR) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DatePart::MinutesOperator::Operation(dtime_tz_t input) { - return DatePart::MinutesOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::HoursOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return Timestamp::GetTime(input).micros / Interval::MICROS_PER_HOUR; -} - -template <> -int64_t DatePart::HoursOperator::Operation(interval_t input) { - return input.micros / Interval::MICROS_PER_HOUR; -} - -template <> -int64_t DatePart::HoursOperator::Operation(dtime_t input) { - return input.micros / Interval::MICROS_PER_HOUR; -} - -template <> -int64_t DatePart::HoursOperator::Operation(dtime_tz_t input) { - return DatePart::HoursOperator::Operation(input.time()); -} - -template <> -double DatePart::EpochOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return double(Timestamp::GetEpochMicroSeconds(input)) / double(Interval::MICROS_PER_SEC); -} - -template <> -double DatePart::EpochOperator::Operation(interval_t input) { - int64_t interval_years = input.months / Interval::MONTHS_PER_YEAR; - int64_t interval_days; - interval_days = Interval::DAYS_PER_YEAR * interval_years; - interval_days += Interval::DAYS_PER_MONTH * (input.months % Interval::MONTHS_PER_YEAR); - interval_days += input.days; - int64_t interval_epoch; - interval_epoch = interval_days * Interval::SECS_PER_DAY; - // we add 
0.25 days per year to sort of account for leap days - interval_epoch += interval_years * (Interval::SECS_PER_DAY / 4); - return double(interval_epoch) + double(input.micros) / double(Interval::MICROS_PER_SEC); -} - -// TODO: We can't propagate interval statistics because we can't easily compare interval_t for order. -template <> -unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, - FunctionStatisticsInput &input) { - return nullptr; -} - -template <> -double DatePart::EpochOperator::Operation(dtime_t input) { - return double(input.micros) / double(Interval::MICROS_PER_SEC); -} - -template <> -double DatePart::EpochOperator::Operation(dtime_tz_t input) { - return DatePart::EpochOperator::Operation(input.time()); -} - -template <> -unique_ptr DatePart::EpochOperator::PropagateStatistics(ClientContext &context, - FunctionStatisticsInput &input) { - auto result = NumericStats::CreateEmpty(LogicalType::DOUBLE); - result.CopyValidity(input.child_stats[0]); - NumericStats::SetMin(result, Value::DOUBLE(0)); - NumericStats::SetMax(result, Value::DOUBLE(Interval::SECS_PER_DAY)); - return result.ToUnique(); -} - -template <> -int64_t DatePart::EraOperator::Operation(timestamp_t input) { - D_ASSERT(Timestamp::IsFinite(input)); - return EraOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -int64_t DatePart::EraOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"era\" not recognized"); -} - -template <> -int64_t DatePart::EraOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"era\" not recognized"); -} - -template <> -int64_t DatePart::EraOperator::Operation(dtime_tz_t input) { - return EraOperator::Operation(input.time()); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(date_t input) { - throw NotImplementedException("\"date\" units \"timezone\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(interval_t input) { - throw NotImplementedException("\"interval\" units \"timezone\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneOperator::Operation(dtime_tz_t input) { - return input.offset(); -} - -template <> -int64_t DatePart::TimezoneHourOperator::Operation(date_t input) { - throw NotImplementedException("\"date\" units \"timezone_hour\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneHourOperator::Operation(interval_t input) { - throw NotImplementedException("\"interval\" units \"timezone_hour\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneHourOperator::Operation(dtime_tz_t input) { - return input.offset() / Interval::SECS_PER_HOUR; -} - -template <> -int64_t DatePart::TimezoneMinuteOperator::Operation(date_t input) { - throw NotImplementedException("\"date\" units \"timezone_minute\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneMinuteOperator::Operation(interval_t input) { - throw NotImplementedException("\"interval\" units \"timezone_minute\" not recognized"); -} - -template <> -int64_t DatePart::TimezoneMinuteOperator::Operation(dtime_tz_t input) { - return (input.offset() / Interval::SECS_PER_MINUTE) % Interval::MINS_PER_HOUR; -} - -template <> -double DatePart::JulianDayOperator::Operation(date_t input) { - return double(Date::ExtractJulianDay(input)); -} - -template <> -double DatePart::JulianDayOperator::Operation(interval_t input) { - throw NotImplementedException("interval units \"julian\" not recognized"); -} - -template <> -double 
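A standalone re-derivation of the interval EpochOperator above with DuckDB's constants written out (365-day years, 30-day months, plus the quarter-day-per-year leap correction); it makes explicit that epoch-of-interval is an approximation, not a calendar-exact duration:

#include <cassert>
#include <cstdint>

static double IntervalEpoch(int32_t months, int32_t days, int64_t micros) {
	const int64_t years = months / 12;
	int64_t d = 365 * years + 30 * (months % 12) + days;
	int64_t secs = d * 86400;
	secs += years * (86400 / 4); // + 0.25 days per year, as above
	return double(secs) + double(micros) / 1000000.0;
}

int main() {
	// 14 months, 3 days, 5 seconds ~= 37000805 seconds
	assert(IntervalEpoch(14, 3, 5000000) == 37000805.0);
	return 0;
}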
DatePart::JulianDayOperator::Operation(dtime_t input) { - throw NotImplementedException("\"time\" units \"julian\" not recognized"); -} - -template <> -double DatePart::JulianDayOperator::Operation(dtime_tz_t input) { - return JulianDayOperator::Operation(input.time()); -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_t &input, - const idx_t idx, const part_mask_t mask) { - int64_t *part_data; - if (mask & TIME) { - const auto micros = MicrosecondsOperator::Operation(input); - part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); - if (part_data) { - part_data[idx] = micros; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_MSEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_SEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); - if (part_data) { - part_data[idx] = MinutesOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); - if (part_data) { - part_data[idx] = HoursOperator::Operation(input); - } - } - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } - - if (mask & ZONE) { - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); - if (part_data) { - part_data[idx] = 0; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); - if (part_data) { - part_data[idx] = 0; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); - if (part_data) { - part_data[idx] = 0; - } - } -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const dtime_tz_t &input, - const idx_t idx, const part_mask_t mask) { - int64_t *part_data; - if (mask & TIME) { - const auto micros = MicrosecondsOperator::Operation(input); - part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); - if (part_data) { - part_data[idx] = micros; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_MSEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_SEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); - if (part_data) { - part_data[idx] = MinutesOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); - if (part_data) { - part_data[idx] = HoursOperator::Operation(input); - } - } - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } - - if (mask & ZONE) { - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE); - if (part_data) { - part_data[idx] = TimezoneOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_HOUR); - if (part_data) { - part_data[idx] = TimezoneHourOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::TIMEZONE_MINUTE); - if (part_data) { - part_data[idx] = TimezoneMinuteOperator::Operation(input); - } - return; - } -} - 
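A standalone check of the timezone part arithmetic used in the ZONE branch above: C++ integer division truncates toward zero, so a negative offset decomposes into a matching negative hour and minute, e.g. -09:30 gives (-9, -30):

#include <cassert>

int main() {
	const int SECS_PER_HOUR = 3600, SECS_PER_MINUTE = 60, MINS_PER_HOUR = 60;
	const int offset = -(9 * SECS_PER_HOUR + 30 * SECS_PER_MINUTE); // -09:30
	const int tz_hour = offset / SECS_PER_HOUR;
	const int tz_minute = (offset / SECS_PER_MINUTE) % MINS_PER_HOUR;
	assert(tz_hour == -9);
	assert(tz_minute == -30);
	return 0;
}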
-template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const timestamp_t &input, - const idx_t idx, const part_mask_t mask) { - D_ASSERT(Timestamp::IsFinite(input)); - date_t d; - dtime_t t; - Timestamp::Convert(input, d, t); - - // Both define epoch, and the correct value is the sum. - // So mask it out and compute it separately. - Operation(bigint_values, double_values, d, idx, mask & ~UnsafeNumericCast(EPOCH)); - Operation(bigint_values, double_values, t, idx, mask & ~UnsafeNumericCast(EPOCH)); - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } - - if (mask & JD) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::JULIAN_DAY); - if (part_data) { - part_data[idx] = JulianDayOperator::Operation(input); - } - } -} - -template <> -void DatePart::StructOperator::Operation(bigint_vec &bigint_values, double_vec &double_values, const interval_t &input, - const idx_t idx, const part_mask_t mask) { - int64_t *part_data; - if (mask & YMD) { - const auto mm = input.months % Interval::MONTHS_PER_YEAR; - part_data = HasPartValue(bigint_values, DatePartSpecifier::YEAR); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_YEAR; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MONTH); - if (part_data) { - part_data[idx] = mm; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::DAY); - if (part_data) { - part_data[idx] = input.days; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::DECADE); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_DECADE; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::CENTURY); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_CENTURY; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLENNIUM); - if (part_data) { - part_data[idx] = input.months / Interval::MONTHS_PER_MILLENIUM; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::QUARTER); - if (part_data) { - part_data[idx] = mm / Interval::MONTHS_PER_QUARTER + 1; - } - } - - if (mask & TIME) { - const auto micros = MicrosecondsOperator::Operation(input); - part_data = HasPartValue(bigint_values, DatePartSpecifier::MICROSECONDS); - if (part_data) { - part_data[idx] = micros; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MILLISECONDS); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_MSEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::SECOND); - if (part_data) { - part_data[idx] = micros / Interval::MICROS_PER_SEC; - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::MINUTE); - if (part_data) { - part_data[idx] = MinutesOperator::Operation(input); - } - part_data = HasPartValue(bigint_values, DatePartSpecifier::HOUR); - if (part_data) { - part_data[idx] = HoursOperator::Operation(input); - } - } - - if (mask & EPOCH) { - auto part_data = HasPartValue(double_values, DatePartSpecifier::EPOCH); - if (part_data) { - part_data[idx] = EpochOperator::Operation(input); - } - } -} - -template -static int64_t ExtractElement(DatePartSpecifier type, T element) { - switch (type) { - case DatePartSpecifier::YEAR: - return DatePart::YearOperator::template Operation(element); - case DatePartSpecifier::MONTH: - return DatePart::MonthOperator::template Operation(element); - case DatePartSpecifier::DAY: - return 
DatePart::DayOperator::template Operation(element); - case DatePartSpecifier::DECADE: - return DatePart::DecadeOperator::template Operation(element); - case DatePartSpecifier::CENTURY: - return DatePart::CenturyOperator::template Operation(element); - case DatePartSpecifier::MILLENNIUM: - return DatePart::MillenniumOperator::template Operation(element); - case DatePartSpecifier::QUARTER: - return DatePart::QuarterOperator::template Operation(element); - case DatePartSpecifier::DOW: - return DatePart::DayOfWeekOperator::template Operation(element); - case DatePartSpecifier::ISODOW: - return DatePart::ISODayOfWeekOperator::template Operation(element); - case DatePartSpecifier::DOY: - return DatePart::DayOfYearOperator::template Operation(element); - case DatePartSpecifier::WEEK: - return DatePart::WeekOperator::template Operation(element); - case DatePartSpecifier::ISOYEAR: - return DatePart::ISOYearOperator::template Operation(element); - case DatePartSpecifier::YEARWEEK: - return DatePart::YearWeekOperator::template Operation(element); - case DatePartSpecifier::MICROSECONDS: - return DatePart::MicrosecondsOperator::template Operation(element); - case DatePartSpecifier::MILLISECONDS: - return DatePart::MillisecondsOperator::template Operation(element); - case DatePartSpecifier::SECOND: - return DatePart::SecondsOperator::template Operation(element); - case DatePartSpecifier::MINUTE: - return DatePart::MinutesOperator::template Operation(element); - case DatePartSpecifier::HOUR: - return DatePart::HoursOperator::template Operation(element); - case DatePartSpecifier::ERA: - return DatePart::EraOperator::template Operation(element); - case DatePartSpecifier::TIMEZONE: - return DatePart::TimezoneOperator::template Operation(element); - case DatePartSpecifier::TIMEZONE_HOUR: - return DatePart::TimezoneHourOperator::template Operation(element); - case DatePartSpecifier::TIMEZONE_MINUTE: - return DatePart::TimezoneMinuteOperator::template Operation(element); - default: - throw NotImplementedException("Specifier type not implemented for DATEPART"); - } -} - -template -static void DatePartFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto &spec_arg = args.data[0]; - auto &date_arg = args.data[1]; - - BinaryExecutor::ExecuteWithNulls( - spec_arg, date_arg, result, args.size(), [&](string_t specifier, T date, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(date)) { - return ExtractElement(GetDatePartSpecifier(specifier.GetString()), date); - } else { - mask.SetInvalid(idx); - return int64_t(0); - } - }); -} - -static unique_ptr DatePartBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // If we are only looking for Julian Days for timestamps, - // then return doubles. 
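- // Bind-time rewrite: when the part specifier is a foldable constant we can
- // evaluate it here and, for "julian" and "epoch", replace the generic BIGINT
- // implementation with the dedicated DOUBLE-returning unary function (and its
- // statistics propagator) selected on the temporal argument type below.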
- if (arguments[0]->HasParameter() || !arguments[0]->IsFoldable()) { - return nullptr; - } - - Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - const auto part_name = part_value.ToString(); - switch (GetDatePartSpecifier(part_name)) { - case DatePartSpecifier::JULIAN_DAY: - arguments.erase(arguments.begin()); - bound_function.arguments.erase(bound_function.arguments.begin()); - bound_function.name = "julian"; - bound_function.return_type = LogicalType::DOUBLE; - switch (arguments[0]->return_type.id()) { - case LogicalType::TIMESTAMP: - case LogicalType::TIMESTAMP_S: - case LogicalType::TIMESTAMP_MS: - case LogicalType::TIMESTAMP_NS: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; - break; - case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::JulianDayOperator::template PropagateStatistics; - break; - default: - throw BinderException("%s can only take DATE or TIMESTAMP arguments", bound_function.name); - } - break; - case DatePartSpecifier::EPOCH: - arguments.erase(arguments.begin()); - bound_function.arguments.erase(bound_function.arguments.begin()); - bound_function.name = "epoch"; - bound_function.return_type = LogicalType::DOUBLE; - switch (arguments[0]->return_type.id()) { - case LogicalType::TIMESTAMP: - case LogicalType::TIMESTAMP_S: - case LogicalType::TIMESTAMP_MS: - case LogicalType::TIMESTAMP_NS: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::DATE: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::INTERVAL: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::TIME: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - case LogicalType::TIME_TZ: - bound_function.function = DatePart::UnaryFunction; - bound_function.statistics = DatePart::EpochOperator::template PropagateStatistics; - break; - default: - throw BinderException("%s can only take temporal arguments", bound_function.name); - } - break; - default: - break; - } - - return nullptr; -} - -template -ScalarFunctionSet GetGenericDatePartFunction(scalar_function_t date_func, scalar_function_t ts_func, - scalar_function_t interval_func, function_statistics_t date_stats, - function_statistics_t ts_stats) { - ScalarFunctionSet operator_set; - operator_set.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BIGINT, std::move(date_func), nullptr, - nullptr, date_stats, DATE_CACHE)); - operator_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BIGINT, std::move(ts_func), nullptr, - nullptr, ts_stats, DATE_CACHE)); - operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, LogicalType::BIGINT, std::move(interval_func))); - for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return operator_set; -} - -template -static ScalarFunctionSet GetDatePartFunction() { - return GetGenericDatePartFunction( - DatePart::UnaryFunction, DatePart::UnaryFunction, - ScalarFunction::UnaryFunction, OP::template PropagateStatistics, - OP::template 
PropagateStatistics); -} - -ScalarFunctionSet GetGenericTimePartFunction(const LogicalType &result_type, scalar_function_t date_func, - scalar_function_t ts_func, scalar_function_t interval_func, - scalar_function_t time_func, scalar_function_t timetz_func, - function_statistics_t date_stats, function_statistics_t ts_stats, - function_statistics_t time_stats, function_statistics_t timetz_stats) { - ScalarFunctionSet operator_set; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, result_type, std::move(date_func), nullptr, nullptr, date_stats)); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, result_type, std::move(ts_func), nullptr, nullptr, ts_stats)); - operator_set.AddFunction(ScalarFunction({LogicalType::INTERVAL}, result_type, std::move(interval_func))); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIME}, result_type, std::move(time_func), nullptr, nullptr, time_stats)); - operator_set.AddFunction( - ScalarFunction({LogicalType::TIME_TZ}, result_type, std::move(timetz_func), nullptr, nullptr, timetz_stats)); - return operator_set; -} - -template -static ScalarFunctionSet GetTimePartFunction(const LogicalType &result_type = LogicalType::BIGINT) { - return GetGenericTimePartFunction( - result_type, DatePart::UnaryFunction, DatePart::UnaryFunction, - ScalarFunction::UnaryFunction, ScalarFunction::UnaryFunction, - ScalarFunction::UnaryFunction, OP::template PropagateStatistics, - OP::template PropagateStatistics, OP::template PropagateStatistics, - OP::template PropagateStatistics); -} - -struct LastDayOperator { - template - static inline TR Operation(TA input) { - int32_t yyyy, mm, dd; - Date::Convert(input, yyyy, mm, dd); - yyyy += (mm / 12); - mm %= 12; - ++mm; - return Date::FromDate(yyyy, mm, 1) - 1; - } -}; - -template <> -date_t LastDayOperator::Operation(timestamp_t input) { - return LastDayOperator::Operation(Timestamp::GetDate(input)); -} - -struct MonthNameOperator { - template - static inline TR Operation(TA input) { - return Date::MONTH_NAMES[DatePart::MonthOperator::Operation(input) - 1]; - } -}; - -struct DayNameOperator { - template - static inline TR Operation(TA input) { - return Date::DAY_NAMES[DatePart::DayOfWeekOperator::Operation(input)]; - } -}; - -struct StructDatePart { - using part_codes_t = vector; - - struct BindData : public VariableReturnBindData { - part_codes_t part_codes; - - explicit BindData(const LogicalType &stype, const part_codes_t &part_codes_p) - : VariableReturnBindData(stype), part_codes(part_codes_p) { - } - - unique_ptr Copy() const override { - return make_uniq(stype, part_codes); - } - }; - - static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // collect names and deconflict, construct return type - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[0]->IsFoldable()) { - throw BinderException("%s can only take constant lists of part names", bound_function.name); - } - - case_insensitive_set_t name_collision_set; - child_list_t struct_children; - part_codes_t part_codes; - - Value parts_list = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if (parts_list.type().id() == LogicalTypeId::LIST) { - auto &list_children = ListValue::GetChildren(parts_list); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of part names", bound_function.name); - } - for (const auto &part_value : list_children) { - if (part_value.IsNull()) { - throw 
BinderException("NULL struct entry name in %s", bound_function.name); - } - const auto part_name = part_value.ToString(); - const auto part_code = GetDateTypePartSpecifier(part_name, arguments[1]->return_type); - if (name_collision_set.find(part_name) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\" in %s", part_name, bound_function.name); - } - name_collision_set.insert(part_name); - part_codes.emplace_back(part_code); - const auto part_type = IsBigintDatepart(part_code) ? LogicalType::BIGINT : LogicalType::DOUBLE; - struct_children.emplace_back(make_pair(part_name, part_type)); - } - } else { - throw BinderException("%s can only take constant lists of part names", bound_function.name); - } - - Function::EraseArgument(bound_function, arguments, 0); - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type, part_codes); - } - - template - static void Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - D_ASSERT(args.ColumnCount() == 1); - - const auto count = args.size(); - Vector &input = args.data[0]; - - // Type counts - const auto BIGINT_COUNT = size_t(DatePartSpecifier::BEGIN_DOUBLE) - size_t(DatePartSpecifier::BEGIN_BIGINT); - const auto DOUBLE_COUNT = size_t(DatePartSpecifier::BEGIN_INVALID) - size_t(DatePartSpecifier::BEGIN_DOUBLE); - DatePart::StructOperator::bigint_vec bigint_values(BIGINT_COUNT, nullptr); - DatePart::StructOperator::double_vec double_values(DOUBLE_COUNT, nullptr); - const auto part_mask = DatePart::StructOperator::GetMask(info.part_codes); - - auto &child_entries = StructVector::GetEntries(result); - - // The first computer of a part "owns" it - // and other requestors just reference the owner - vector owners(int(DatePartSpecifier::JULIAN_DAY) + 1, child_entries.size()); - for (size_t col = 0; col < child_entries.size(); ++col) { - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == child_entries.size()) { - owners[part_index] = col; - } - } - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - if (ConstantVector::IsNull(input)) { - ConstantVector::SetNull(result, true); - } else { - ConstantVector::SetNull(result, false); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - ConstantVector::SetNull(*child_entry, false); - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == col) { - if (IsBigintDatepart(info.part_codes[col])) { - bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = - ConstantVector::GetData(*child_entry); - } else { - double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = - ConstantVector::GetData(*child_entry); - } - } - } - auto tdata = ConstantVector::GetData(input); - if (Value::IsFinite(tdata[0])) { - DatePart::StructOperator::Operation(bigint_values, double_values, tdata[0], 0, part_mask); - } else { - for (auto &child_entry : child_entries) { - ConstantVector::SetNull(*child_entry, true); - } - } - } - } else { - UnifiedVectorFormat rdata; - input.ToUnifiedFormat(count, rdata); - - const auto &arg_valid = rdata.validity; - auto tdata = UnifiedVectorFormat::GetData(rdata); - - // Start with a valid flat vector - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &res_valid = FlatVector::Validity(result); - if (res_valid.GetData()) { - 
res_valid.SetAllValid(count); - } - - // Start with valid children - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - child_entry->SetVectorType(VectorType::FLAT_VECTOR); - auto &child_validity = FlatVector::Validity(*child_entry); - if (child_validity.GetData()) { - child_validity.SetAllValid(count); - } - - // Pre-multiplex - const auto part_index = size_t(info.part_codes[col]); - if (owners[part_index] == col) { - if (IsBigintDatepart(info.part_codes[col])) { - bigint_values[part_index - size_t(DatePartSpecifier::BEGIN_BIGINT)] = - FlatVector::GetData(*child_entry); - } else { - double_values[part_index - size_t(DatePartSpecifier::BEGIN_DOUBLE)] = - FlatVector::GetData(*child_entry); - } - } - } - - for (idx_t i = 0; i < count; ++i) { - const auto idx = rdata.sel->get_index(i); - if (arg_valid.RowIsValid(idx)) { - if (Value::IsFinite(tdata[idx])) { - DatePart::StructOperator::Operation(bigint_values, double_values, tdata[idx], i, part_mask); - } else { - for (auto &child_entry : child_entries) { - FlatVector::Validity(*child_entry).SetInvalid(i); - } - } - } else { - res_valid.SetInvalid(i); - for (auto &child_entry : child_entries) { - FlatVector::Validity(*child_entry).SetInvalid(i); - } - } - } - } - - // Reference any duplicate parts - for (size_t col = 0; col < child_entries.size(); ++col) { - const auto part_index = size_t(info.part_codes[col]); - const auto owner = owners[part_index]; - if (owner != col) { - child_entries[col]->Reference(*child_entries[owner]); - } - } - - result.Verify(count); - } - - static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - D_ASSERT(bind_data_p); - auto &info = bind_data_p->Cast(); - serializer.WriteProperty(100, "stype", info.stype); - serializer.WriteProperty(101, "part_codes", info.part_codes); - } - - static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { - auto stype = deserializer.ReadProperty(100, "stype"); - auto part_codes = deserializer.ReadProperty>(101, "part_codes"); - return make_uniq(std::move(stype), std::move(part_codes)); - } - - template - static ScalarFunction GetFunction(const LogicalType &temporal_type) { - auto part_type = LogicalType::LIST(LogicalType::VARCHAR); - auto result_type = LogicalType::STRUCT({}); - ScalarFunction result({part_type, temporal_type}, result_type, Function, Bind); - result.serialize = SerializeFunction; - result.deserialize = DeserializeFunction; - return result; - } -}; -template -ScalarFunctionSet GetCachedDatepartFunction() { - return GetGenericDatePartFunction>( - DatePartCachedFunction, DatePartCachedFunction, - ScalarFunction::UnaryFunction, OP::template PropagateStatistics, - OP::template PropagateStatistics); -} - -ScalarFunctionSet YearFun::GetFunctions() { - return GetCachedDatepartFunction(); -} - -ScalarFunctionSet MonthFun::GetFunctions() { - return GetCachedDatepartFunction(); -} - -ScalarFunctionSet DayFun::GetFunctions() { - return GetCachedDatepartFunction(); -} - -ScalarFunctionSet DecadeFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet CenturyFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet MillenniumFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet QuarterFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfWeekFun::GetFunctions() { - auto set = GetDatePartFunction(); - for (auto &func : set.functions) { 
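- // Each overload can raise (the INTERVAL variant, for instance, does not
- // accept "dow"), so flag them all as potentially returning an error.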
- BaseScalarFunction::SetReturnsError(func); - } - return set; -} - -ScalarFunctionSet ISODayOfWeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet ISOYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet EraFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet TimezoneFun::GetFunctions() { - auto operator_set = GetDatePartFunction(); - - // PG also defines timezone(INTERVAL, TIME_TZ) => TIME_TZ - ScalarFunction function({LogicalType::INTERVAL, LogicalType::TIME_TZ}, LogicalType::TIME_TZ, - DatePart::TimezoneOperator::BinaryFunction); - - operator_set.AddFunction(function); - - for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - - return operator_set; -} - -ScalarFunctionSet TimezoneHourFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet TimezoneMinuteFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet EpochFun::GetFunctions() { - return GetTimePartFunction(LogicalType::DOUBLE); -} - -struct GetEpochNanosOperator { - static int64_t Operation(timestamp_ns_t timestamp) { - return Timestamp::GetEpochNanoSeconds(timestamp); - } -}; - -static void ExecuteGetNanosFromTimestampNs(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - auto func = GetEpochNanosOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); -} - -ScalarFunctionSet EpochNsFun::GetFunctions() { - using OP = DatePart::EpochNanosecondsOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_NS}, LogicalType::BIGINT, ExecuteGetNanosFromTimestampNs)); - return operator_set; -} - -ScalarFunctionSet EpochUsFun::GetFunctions() { - using OP = DatePart::EpochMicrosecondsOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - return operator_set; -} - -ScalarFunctionSet EpochMsFun::GetFunctions() { - using OP = DatePart::EpochMillisOperator; - auto operator_set = GetTimePartFunction(); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - - // Legacy inverse BIGINT => TIMESTAMP - operator_set.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, DatePart::EpochMillisOperator::Inverse)); - - return operator_set; -} - -ScalarFunctionSet NanosecondsFun::GetFunctions() { - using OP = 
DatePart::NanosecondsOperator; - using TR = int64_t; - const LogicalType &result_type = LogicalType::BIGINT; - auto operator_set = GetTimePartFunction(); - - auto ns_func = DatePart::UnaryFunction; - auto ns_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_NS}, result_type, ns_func, nullptr, nullptr, ns_stats)); - - // TIMESTAMP WITH TIME ZONE has the same representation as TIMESTAMP so no need to defer to ICU - auto tstz_func = DatePart::UnaryFunction; - auto tstz_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BIGINT, tstz_func, nullptr, nullptr, tstz_stats)); - - return operator_set; -} - -ScalarFunctionSet MicrosecondsFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet MillisecondsFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet SecondsFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet MinutesFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet HoursFun::GetFunctions() { - return GetTimePartFunction(); -} - -ScalarFunctionSet YearWeekFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet DayOfMonthFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekDayFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet WeekOfYearFun::GetFunctions() { - return GetDatePartFunction(); -} - -ScalarFunctionSet LastDayFun::GetFunctions() { - ScalarFunctionSet last_day; - last_day.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::DATE, - DatePart::UnaryFunction)); - last_day.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DATE, - DatePart::UnaryFunction)); - return last_day; -} - -ScalarFunctionSet MonthNameFun::GetFunctions() { - ScalarFunctionSet monthname; - monthname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - monthname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - return monthname; -} - -ScalarFunctionSet DayNameFun::GetFunctions() { - ScalarFunctionSet dayname; - dayname.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - dayname.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - DatePart::UnaryFunction)); - return dayname; -} - -ScalarFunctionSet JulianDayFun::GetFunctions() { - using OP = DatePart::JulianDayOperator; - - ScalarFunctionSet operator_set; - auto date_func = DatePart::UnaryFunction; - auto date_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::DATE}, LogicalType::DOUBLE, date_func, nullptr, nullptr, date_stats)); - auto ts_func = DatePart::UnaryFunction; - auto ts_stats = OP::template PropagateStatistics; - operator_set.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::DOUBLE, ts_func, nullptr, nullptr, ts_stats)); - - return operator_set; -} - -ScalarFunctionSet DatePartFun::GetFunctions() { - ScalarFunctionSet date_part; - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME}, 
LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - date_part.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME_TZ}, LogicalType::BIGINT, - DatePartFunction, DatePartBind)); - - // struct variants - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::DATE)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIMESTAMP)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::INTERVAL)); - date_part.AddFunction(StructDatePart::GetFunction(LogicalType::TIME_TZ)); - - for (auto &func : date_part.functions) { - BaseScalarFunction::SetReturnsError(func); - } - - return date_part; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp b/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp deleted file mode 100644 index acfb2c796..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/date_sub.cpp +++ /dev/null @@ -1,454 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -struct DateSub { - static int64_t SubtractMicros(timestamp_t startdate, timestamp_t enddate) { - const auto start = Timestamp::GetEpochMicroSeconds(startdate); - const auto end = Timestamp::GetEpochMicroSeconds(enddate); - return SubtractOperatorOverflowCheck::Operation(end, start); - } - - template - static inline void BinaryExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::ExecuteWithNulls( - left, right, result, count, [&](TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return OP::template Operation(startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - }); - } - - struct MonthOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - - if (start_ts > end_ts) { - return -MonthOperator::Operation(end_ts, start_ts); - } - // The number of complete months depends on whether end_ts is on the last day of the month. - date_t end_date; - dtime_t end_time; - Timestamp::Convert(end_ts, end_date, end_time); - - int32_t yyyy, mm, dd; - Date::Convert(end_date, yyyy, mm, dd); - const auto end_days = Date::MonthDays(yyyy, mm); - if (end_days == dd) { - // Now check whether the start day is after the end day - date_t start_date; - dtime_t start_time; - Timestamp::Convert(start_ts, start_date, start_time); - Date::Convert(start_date, yyyy, mm, dd); - if (dd > end_days || (dd == end_days && start_time < end_time)) { - // Move back to the same time on the last day of the (shorter) end month - start_date = Date::FromDate(yyyy, mm, end_days); - start_ts = Timestamp::FromDatetime(start_date, start_time); - } - } - - // Our interval difference will now give the correct result. 
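- // For example, date_sub('month', DATE '2004-01-31', DATE '2004-02-29')
- // returns 1: Feb 29 is the last day of its month, so the start day is
- // clamped back to Jan 29 and the span counts as one complete month.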
- // Note that PG gives different interval subtraction results, - // so if we change this we will have to reimplement. - return Interval::GetAge(end_ts, start_ts).months; - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_QUARTER; - } - }; - - struct YearOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_YEAR; - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_DECADE; - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_CENTURY; - } - }; - - struct MilleniumOperator { - template - static inline TR Operation(TA start_ts, TB end_ts) { - return MonthOperator::Operation(start_ts, end_ts) / Interval::MONTHS_PER_MILLENIUM; - } - }; - - struct DayOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_DAY; - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_WEEK; - } - }; - - struct MicrosecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate); - } - }; - - struct MillisecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MSEC; - } - }; - - struct SecondsOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_SEC; - } - }; - - struct MinutesOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_MINUTE; - } - }; - - struct HoursOperator { - template - static inline TR Operation(TA startdate, TB enddate) { - return SubtractMicros(startdate, enddate) / Interval::MICROS_PER_HOUR; - } - }; -}; - -// DATE specialisations -template <> -int64_t DateSub::YearOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return YearOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MonthOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MonthOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::DayOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return DayOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::DecadeOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return DecadeOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::CenturyOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return CenturyOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t 
DateSub::MilleniumOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MilleniumOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::QuarterOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return QuarterOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::WeekOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return WeekOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MicrosecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MicrosecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MillisecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MillisecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::SecondsOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return SecondsOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::MinutesOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return MinutesOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -template <> -int64_t DateSub::HoursOperator::Operation(date_t startdate, date_t enddate) { - dtime_t t0(0); - return HoursOperator::Operation(Timestamp::FromDatetime(startdate, t0), - Timestamp::FromDatetime(enddate, t0)); -} - -// TIME specialisations -template <> -int64_t DateSub::YearOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"year\" not recognized"); -} - -template <> -int64_t DateSub::MonthOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"month\" not recognized"); -} - -template <> -int64_t DateSub::DayOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"day\" not recognized"); -} - -template <> -int64_t DateSub::DecadeOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"decade\" not recognized"); -} - -template <> -int64_t DateSub::CenturyOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"century\" not recognized"); -} - -template <> -int64_t DateSub::MilleniumOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"millennium\" not recognized"); -} - -template <> -int64_t DateSub::QuarterOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"quarter\" not recognized"); -} - -template <> -int64_t DateSub::WeekOperator::Operation(dtime_t startdate, dtime_t enddate) { - throw NotImplementedException("\"time\" units \"week\" not recognized"); -} - -template <> -int64_t DateSub::MicrosecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return enddate.micros - startdate.micros; -} - -template <> -int64_t DateSub::MillisecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return 
(enddate.micros - startdate.micros) / Interval::MICROS_PER_MSEC; -} - -template <> -int64_t DateSub::SecondsOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_SEC; -} - -template <> -int64_t DateSub::MinutesOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_MINUTE; -} - -template <> -int64_t DateSub::HoursOperator::Operation(dtime_t startdate, dtime_t enddate) { - return (enddate.micros - startdate.micros) / Interval::MICROS_PER_HOUR; -} - -template -static int64_t SubtractDateParts(DatePartSpecifier type, TA startdate, TB enddate) { - switch (type) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::ISOYEAR: - return DateSub::YearOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MONTH: - return DateSub::MonthOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return DateSub::DayOperator::template Operation(startdate, enddate); - case DatePartSpecifier::DECADE: - return DateSub::DecadeOperator::template Operation(startdate, enddate); - case DatePartSpecifier::CENTURY: - return DateSub::CenturyOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLENNIUM: - return DateSub::MilleniumOperator::template Operation(startdate, enddate); - case DatePartSpecifier::QUARTER: - return DateSub::QuarterOperator::template Operation(startdate, enddate); - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return DateSub::WeekOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MICROSECONDS: - return DateSub::MicrosecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MILLISECONDS: - return DateSub::MillisecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateSub::SecondsOperator::template Operation(startdate, enddate); - case DatePartSpecifier::MINUTE: - return DateSub::MinutesOperator::template Operation(startdate, enddate); - case DatePartSpecifier::HOUR: - return DateSub::HoursOperator::template Operation(startdate, enddate); - default: - throw NotImplementedException("Specifier type not implemented for DATESUB"); - } -} - -struct DateSubTernaryOperator { - template - static inline TR Operation(TS part, TA startdate, TB enddate, ValidityMask &mask, idx_t idx) { - if (Value::IsFinite(startdate) && Value::IsFinite(enddate)) { - return SubtractDateParts(GetDatePartSpecifier(part.GetString()), startdate, enddate); - } else { - mask.SetInvalid(idx); - return TR(); - } - } -}; - -template -static void DateSubBinaryExecutor(DatePartSpecifier type, Vector &left, Vector &right, Vector &result, idx_t count) { - switch (type) { - case DatePartSpecifier::YEAR: - case DatePartSpecifier::ISOYEAR: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MONTH: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::DECADE: - DateSub::BinaryExecute(left, right, result, count); - break; - case 
DatePartSpecifier::CENTURY: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLENNIUM: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::QUARTER: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::MINUTE: - DateSub::BinaryExecute(left, right, result, count); - break; - case DatePartSpecifier::HOUR: - DateSub::BinaryExecute(left, right, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATESUB"); - } -} - -template -static void DateSubFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - auto &part_arg = args.data[0]; - auto &start_arg = args.data[1]; - auto &end_arg = args.data[2]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. - if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateSubBinaryExecutor(type, start_arg, end_arg, result, args.size()); - } - } else { - TernaryExecutor::ExecuteWithNulls( - part_arg, start_arg, end_arg, result, args.size(), - DateSubTernaryOperator::Operation); - } -} - -ScalarFunctionSet DateSubFun::GetFunctions() { - ScalarFunctionSet date_sub("date_sub"); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE, LogicalType::DATE}, - LogicalType::BIGINT, DateSubFunction)); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::BIGINT, DateSubFunction)); - date_sub.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIME, LogicalType::TIME}, - LogicalType::BIGINT, DateSubFunction)); - return date_sub; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp b/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp deleted file mode 100644 index cb54e30de..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/date_trunc.cpp +++ /dev/null @@ -1,737 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/execution/expression_executor.hpp" - -namespace duckdb { - -struct DateTrunc { - template - static inline TR UnaryFunction(TA input) { - if (Value::IsFinite(input)) { - return OP::template Operation(input); - } else { - return Cast::template Operation(input); - } - } - - template - static inline void UnaryExecute(Vector &left, Vector &result, idx_t count) { - UnaryExecutor::Execute(left, result, count, UnaryFunction); - } - - struct 
MillenniumOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 1000) * 1000, 1, 1); - } - }; - - struct CenturyOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 100) * 100, 1, 1); - } - }; - - struct DecadeOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate((Date::ExtractYear(input) / 10) * 10, 1, 1); - } - }; - - struct YearOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate(Date::ExtractYear(input), 1, 1); - } - }; - - struct QuarterOperator { - template - static inline TR Operation(TA input) { - int32_t yyyy, mm, dd; - Date::Convert(input, yyyy, mm, dd); - mm = 1 + (((mm - 1) / 3) * 3); - return Date::FromDate(yyyy, mm, 1); - } - }; - - struct MonthOperator { - template - static inline TR Operation(TA input) { - return Date::FromDate(Date::ExtractYear(input), Date::ExtractMonth(input), 1); - } - }; - - struct WeekOperator { - template - static inline TR Operation(TA input) { - return Date::GetMondayOfCurrentWeek(input); - } - }; - - struct ISOYearOperator { - template - static inline TR Operation(TA input) { - date_t date = Date::GetMondayOfCurrentWeek(input); - date.days -= (Date::ExtractISOWeekNumber(date) - 1) * Interval::DAYS_PER_WEEK; - - return date; - } - }; - - struct DayOperator { - template - static inline TR Operation(TA input) { - return input; - } - }; - - struct HourOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, 0, 0, 0)); - } - }; - - struct MinuteOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, 0, 0)); - } - }; - - struct SecondOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, 0)); - } - }; - - struct MillisecondOperator { - template - static inline TR Operation(TA input) { - int32_t hour, min, sec, micros; - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - Time::Convert(time, hour, min, sec, micros); - micros -= UnsafeNumericCast(micros % Interval::MICROS_PER_MSEC); - return Timestamp::FromDatetime(date, Time::FromTime(hour, min, sec, micros)); - } - }; - - struct MicrosecondOperator { - template - static inline TR Operation(TA input) { - return input; - } - }; -}; - -// DATE specialisations -template <> -date_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { - return MillenniumOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::MillenniumOperator::Operation(date_t input) { - return Timestamp::FromDatetime(MillenniumOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::MillenniumOperator::Operation(timestamp_t input) { - return MillenniumOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { - return 
CenturyOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::CenturyOperator::Operation(date_t input) { - return Timestamp::FromDatetime(CenturyOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::CenturyOperator::Operation(timestamp_t input) { - return CenturyOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { - return DecadeOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::DecadeOperator::Operation(date_t input) { - return Timestamp::FromDatetime(DecadeOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::DecadeOperator::Operation(timestamp_t input) { - return DecadeOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::YearOperator::Operation(date_t input) { - return Timestamp::FromDatetime(YearOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::YearOperator::Operation(timestamp_t input) { - return YearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::QuarterOperator::Operation(date_t input) { - return Timestamp::FromDatetime(QuarterOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::QuarterOperator::Operation(timestamp_t input) { - return QuarterOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::MonthOperator::Operation(date_t input) { - return Timestamp::FromDatetime(MonthOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::MonthOperator::Operation(timestamp_t input) { - return MonthOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::WeekOperator::Operation(date_t input) { - return Timestamp::FromDatetime(WeekOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::WeekOperator::Operation(timestamp_t input) { - return WeekOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::ISOYearOperator::Operation(date_t input) { - return Timestamp::FromDatetime(ISOYearOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::ISOYearOperator::Operation(timestamp_t input) { - return ISOYearOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::DayOperator::Operation(timestamp_t input) { - return DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -timestamp_t DateTrunc::DayOperator::Operation(date_t input) { - return Timestamp::FromDatetime(DayOperator::Operation(input), dtime_t(0)); -} - -template <> -timestamp_t DateTrunc::DayOperator::Operation(timestamp_t input) { - return 
DayOperator::Operation(Timestamp::GetDate(input)); -} - -template <> -date_t DateTrunc::HourOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::HourOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::HourOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(HourOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MinuteOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MinuteOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::MinuteOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(HourOperator::Operation(input)); -} - -template <> -date_t DateTrunc::SecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::SecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::SecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(DayOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MillisecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MillisecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::MillisecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(MillisecondOperator::Operation(input)); -} - -template <> -date_t DateTrunc::MicrosecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -timestamp_t DateTrunc::MicrosecondOperator::Operation(date_t input) { - return DayOperator::Operation(input); -} - -template <> -date_t DateTrunc::MicrosecondOperator::Operation(timestamp_t input) { - return Timestamp::GetDate(MicrosecondOperator::Operation(input)); -} - -// INTERVAL specialisations -template <> -interval_t DateTrunc::MillenniumOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_MILLENIUM) * Interval::MONTHS_PER_MILLENIUM; - return input; -} - -template <> -interval_t DateTrunc::CenturyOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_CENTURY) * Interval::MONTHS_PER_CENTURY; - return input; -} - -template <> -interval_t DateTrunc::DecadeOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_DECADE) * Interval::MONTHS_PER_DECADE; - return input; -} - -template <> -interval_t DateTrunc::YearOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_YEAR) * Interval::MONTHS_PER_YEAR; - return input; -} - -template <> -interval_t DateTrunc::QuarterOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - input.months = (input.months / Interval::MONTHS_PER_QUARTER) * Interval::MONTHS_PER_QUARTER; - return input; -} - -template <> -interval_t DateTrunc::MonthOperator::Operation(interval_t input) { - input.days = 0; - input.micros = 0; - return input; -} - -template <> -interval_t DateTrunc::WeekOperator::Operation(interval_t input) { - input.micros = 0; - input.days = (input.days / Interval::DAYS_PER_WEEK) * 
Interval::DAYS_PER_WEEK; - return input; -} - -template <> -interval_t DateTrunc::ISOYearOperator::Operation(interval_t input) { - return YearOperator::Operation(input); -} - -template <> -interval_t DateTrunc::DayOperator::Operation(interval_t input) { - input.micros = 0; - return input; -} - -template <> -interval_t DateTrunc::HourOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_HOUR) * Interval::MICROS_PER_HOUR; - return input; -} - -template <> -interval_t DateTrunc::MinuteOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_MINUTE) * Interval::MICROS_PER_MINUTE; - return input; -} - -template <> -interval_t DateTrunc::SecondOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_SEC) * Interval::MICROS_PER_SEC; - return input; -} - -template <> -interval_t DateTrunc::MillisecondOperator::Operation(interval_t input) { - input.micros = (input.micros / Interval::MICROS_PER_MSEC) * Interval::MICROS_PER_MSEC; - return input; -} - -template <> -interval_t DateTrunc::MicrosecondOperator::Operation(interval_t input) { - return input; -} - -template -static TR TruncateElement(DatePartSpecifier type, TA element) { - if (!Value::IsFinite(element)) { - return Cast::template Operation(element); - } - - switch (type) { - case DatePartSpecifier::MILLENNIUM: - return DateTrunc::MillenniumOperator::Operation(element); - case DatePartSpecifier::CENTURY: - return DateTrunc::CenturyOperator::Operation(element); - case DatePartSpecifier::DECADE: - return DateTrunc::DecadeOperator::Operation(element); - case DatePartSpecifier::YEAR: - return DateTrunc::YearOperator::Operation(element); - case DatePartSpecifier::QUARTER: - return DateTrunc::QuarterOperator::Operation(element); - case DatePartSpecifier::MONTH: - return DateTrunc::MonthOperator::Operation(element); - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return DateTrunc::WeekOperator::Operation(element); - case DatePartSpecifier::ISOYEAR: - return DateTrunc::ISOYearOperator::Operation(element); - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return DateTrunc::DayOperator::Operation(element); - case DatePartSpecifier::HOUR: - return DateTrunc::HourOperator::Operation(element); - case DatePartSpecifier::MINUTE: - return DateTrunc::MinuteOperator::Operation(element); - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return DateTrunc::SecondOperator::Operation(element); - case DatePartSpecifier::MILLISECONDS: - return DateTrunc::MillisecondOperator::Operation(element); - case DatePartSpecifier::MICROSECONDS: - return DateTrunc::MicrosecondOperator::Operation(element); - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC"); - } -} - -struct DateTruncBinaryOperator { - template - static inline TR Operation(TA specifier, TB date) { - return TruncateElement(GetDatePartSpecifier(specifier.GetString()), date); - } -}; - -template -static void DateTruncUnaryExecutor(DatePartSpecifier type, Vector &left, Vector &result, idx_t count) { - switch (type) { - case DatePartSpecifier::MILLENNIUM: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::CENTURY: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::DECADE: - DateTrunc::UnaryExecute(left, result, count); - break; - case 
DatePartSpecifier::YEAR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::QUARTER: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MONTH: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::ISOYEAR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::HOUR: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MINUTE: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MILLISECONDS: - DateTrunc::UnaryExecute(left, result, count); - break; - case DatePartSpecifier::MICROSECONDS: - DateTrunc::UnaryExecute(left, result, count); - break; - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC"); - } -} - -template -static void DateTruncFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto &part_arg = args.data[0]; - auto &date_arg = args.data[1]; - - if (part_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // Common case of constant part. - if (ConstantVector::IsNull(part_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - const auto type = GetDatePartSpecifier(ConstantVector::GetData(part_arg)->GetString()); - DateTruncUnaryExecutor(type, date_arg, result, args.size()); - } - } else { - BinaryExecutor::ExecuteStandard(part_arg, date_arg, result, - args.size()); - } -} - -template -static unique_ptr DateTruncStatistics(vector &child_stats) { - // we can only propagate date stats if the child has stats - auto &nstats = child_stats[1]; - if (!NumericStats::HasMinMax(nstats)) { - return nullptr; - } - // run the operator on both the min and the max, this gives us the [min, max] bound - auto min = NumericStats::GetMin(nstats); - auto max = NumericStats::GetMax(nstats); - if (min > max) { - return nullptr; - } - - // Infinite values are unmodified - auto min_part = DateTrunc::UnaryFunction(min); - auto max_part = DateTrunc::UnaryFunction(max); - - auto min_value = Value::CreateValue(min_part); - auto max_value = Value::CreateValue(max_part); - auto result = NumericStats::CreateEmpty(min_value.type()); - NumericStats::SetMin(result, min_value); - NumericStats::SetMax(result, max_value); - result.CopyValidity(child_stats[0]); - return result.ToUnique(); -} - -template -static unique_ptr PropagateDateTruncStatistics(ClientContext &context, FunctionStatisticsInput &input) { - return DateTruncStatistics(input.child_stats); -} - -template -static function_statistics_t DateTruncStats(DatePartSpecifier type) { - switch (type) { - case DatePartSpecifier::MILLENNIUM: - return PropagateDateTruncStatistics; - case DatePartSpecifier::CENTURY: - return PropagateDateTruncStatistics; - case DatePartSpecifier::DECADE: - return PropagateDateTruncStatistics; - case DatePartSpecifier::YEAR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::QUARTER: - return PropagateDateTruncStatistics; - case 
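// NOTE: the statistics propagation above exploits that date_trunc is monotonic: running the
// truncation operator on the child's [min, max] yields a valid [min, max] for the output,
// e.g. child stats of [2021-03-14, 2021-07-02] truncated to 'month' become
// [2021-03-01, 2021-07-01]. Infinite inputs pass through unmodified, so the propagated
// bounds stay well-formed.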
DatePartSpecifier::MONTH: - return PropagateDateTruncStatistics; - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - return PropagateDateTruncStatistics; - case DatePartSpecifier::ISOYEAR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - return PropagateDateTruncStatistics; - case DatePartSpecifier::HOUR: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MINUTE: - return PropagateDateTruncStatistics; - case DatePartSpecifier::SECOND: - case DatePartSpecifier::EPOCH: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MILLISECONDS: - return PropagateDateTruncStatistics; - case DatePartSpecifier::MICROSECONDS: - return PropagateDateTruncStatistics; - default: - throw NotImplementedException("Specifier type not implemented for DATETRUNC statistics"); - } -} - -static unique_ptr DateTruncBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (!arguments[0]->IsFoldable()) { - return nullptr; - } - - // Rebind to return a date if we are truncating that far - Value part_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if (part_value.IsNull()) { - return nullptr; - } - const auto part_name = part_value.ToString(); - const auto part_code = GetDatePartSpecifier(part_name); - switch (part_code) { - case DatePartSpecifier::MILLENNIUM: - case DatePartSpecifier::CENTURY: - case DatePartSpecifier::DECADE: - case DatePartSpecifier::YEAR: - case DatePartSpecifier::QUARTER: - case DatePartSpecifier::MONTH: - case DatePartSpecifier::WEEK: - case DatePartSpecifier::YEARWEEK: - case DatePartSpecifier::ISOYEAR: - case DatePartSpecifier::DAY: - case DatePartSpecifier::DOW: - case DatePartSpecifier::ISODOW: - case DatePartSpecifier::DOY: - case DatePartSpecifier::JULIAN_DAY: - switch (bound_function.arguments[1].id()) { - case LogicalType::TIMESTAMP: - bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); - break; - case LogicalType::DATE: - bound_function.function = DateTruncFunction; - bound_function.statistics = DateTruncStats(part_code); - break; - default: - throw NotImplementedException("Temporal argument type for DATETRUNC"); - } - bound_function.return_type = LogicalType::DATE; - break; - default: - switch (bound_function.arguments[1].id()) { - case LogicalType::TIMESTAMP: - bound_function.statistics = DateTruncStats(part_code); - break; - case LogicalType::DATE: - bound_function.statistics = DateTruncStats(part_code); - break; - default: - throw NotImplementedException("Temporal argument type for DATETRUNC"); - } - break; - } - - return nullptr; -} - -ScalarFunctionSet DateTruncFun::GetFunctions() { - ScalarFunctionSet date_trunc("date_trunc"); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, - DateTruncFunction, DateTruncBind)); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::TIMESTAMP, - DateTruncFunction, DateTruncBind)); - date_trunc.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::INTERVAL}, LogicalType::INTERVAL, - DateTruncFunction)); - for (auto &func : date_trunc.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return date_trunc; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/epoch.cpp 
b/src/duckdb/extension/core_functions/scalar/date/epoch.cpp deleted file mode 100644 index cda3232a4..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/epoch.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" - -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" - -namespace duckdb { - -struct EpochSecOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE sec) { - int64_t result; - if (!TryCast::Operation(sec * Interval::MICROS_PER_SEC, result)) { - throw ConversionException("Epoch seconds out of range for TIMESTAMP WITH TIME ZONE"); - } - return timestamp_t(result); - } -}; - -static void EpochSecFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size()); -} - -ScalarFunction ToTimestampFun::GetFunction() { - // to_timestamp is an alias from Postgres that converts the time in seconds to a timestamp - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::TIMESTAMP_TZ, EpochSecFunction); -} - -struct NormalizedIntervalOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - return input.Normalize(); - } -}; - -static void NormalizedIntervalFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size()); -} - -ScalarFunction NormalizedIntervalFun::GetFunction() { - return ScalarFunction({LogicalType::INTERVAL}, LogicalType::INTERVAL, NormalizedIntervalFunction); -} - -struct TimeTZSortKeyOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - return input.sort_key(); - } -}; - -static void TimeTZSortKeyFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - UnaryExecutor::Execute(input.data[0], result, input.size()); -} - -ScalarFunction TimeTZSortKeyFun::GetFunction() { - return ScalarFunction({LogicalType::TIME_TZ}, LogicalType::UBIGINT, TimeTZSortKeyFunction); -} -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp b/src/duckdb/extension/core_functions/scalar/date/make_date.cpp deleted file mode 100644 index 0fe00a920..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/make_date.cpp +++ /dev/null @@ -1,181 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "duckdb/common/vector_operations/senary_executor.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" - -#include - -namespace duckdb { - -static void MakeDateFromEpoch(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - result.Reinterpret(input.data[0]); -} - -struct MakeDateOperator { - template - static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd) { - return Date::FromDate(Cast::Operation(yyyy), Cast::Operation(mm), - Cast::Operation(dd)); - } -}; - -template -static void ExecuteMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = 
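// NOTE: MakeDateOperator above narrows each BIGINT component with a checked Cast before calling
// Date::FromDate, so out-of-range components raise an error instead of wrapping. Illustrative
// queries, assuming standard DuckDB behaviour:
//   SELECT make_date(2024, 2, 29);  -- 2024-02-29, a valid leap day
//   SELECT make_date(2023, 2, 29);  -- conversion error: date out of range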
input.data[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), - MakeDateOperator::Operation); -} - -template -static date_t FromDateCast(T year, T month, T day) { - date_t result; - if (!Date::TryFromDate(Cast::Operation(year), Cast::Operation(month), - Cast::Operation(day), result)) { - throw ConversionException("Date out of range: %d-%d-%d", year, month, day); - } - return result; -} - -template -static void ExecuteStructMakeDate(DataChunk &input, ExpressionState &state, Vector &result) { - // this should be guaranteed by the binder - D_ASSERT(input.ColumnCount() == 1); - auto &vec = input.data[0]; - - auto &children = StructVector::GetEntries(vec); - D_ASSERT(children.size() == 3); - auto &yyyy = *children[0]; - auto &mm = *children[1]; - auto &dd = *children[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), FromDateCast); -} - -struct MakeTimeOperator { - template - static RESULT_TYPE Operation(HH hh, MM mm, SS ss) { - - auto hh_32 = Cast::Operation(hh); - auto mm_32 = Cast::Operation(mm); - // Have to check this separately because safe casting of DOUBLE => INT32 can round. - int32_t ss_32 = 0; - if (ss < 0 || ss > Interval::SECS_PER_MINUTE) { - ss_32 = Cast::Operation(ss); - } else { - ss_32 = LossyNumericCast(ss); - } - auto micros = LossyNumericCast(std::round((ss - ss_32) * Interval::MICROS_PER_SEC)); - - if (!Time::IsValidTime(hh_32, mm_32, ss_32, micros)) { - throw ConversionException("Time out of range: %d:%d:%d.%d", hh_32, mm_32, ss_32, micros); - } - return Time::FromTime(hh_32, mm_32, ss_32, micros); - } -}; - -template -static void ExecuteMakeTime(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 3); - auto &yyyy = input.data[0]; - auto &mm = input.data[1]; - auto &dd = input.data[2]; - - TernaryExecutor::Execute(yyyy, mm, dd, result, input.size(), - MakeTimeOperator::Operation); -} - -struct MakeTimestampOperator { - template - static RESULT_TYPE Operation(YYYY yyyy, MM mm, DD dd, HR hr, MN mn, SS ss) { - const auto d = MakeDateOperator::Operation(yyyy, mm, dd); - const auto t = MakeTimeOperator::Operation(hr, mn, ss); - return Timestamp::FromDatetime(d, t); - } - - template - static RESULT_TYPE Operation(T value) { - const auto result = RESULT_TYPE(value); - if (!Timestamp::IsFinite(result)) { - throw ConversionException("Timestamp microseconds out of range: %ld", value); - } - return RESULT_TYPE(value); - } -}; - -template -static void ExecuteMakeTimestamp(DataChunk &input, ExpressionState &state, Vector &result) { - if (input.ColumnCount() == 1) { - auto func = MakeTimestampOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); - return; - } - - D_ASSERT(input.ColumnCount() == 6); - - auto func = MakeTimestampOperator::Operation; - SenaryExecutor::Execute(input, result, func); -} - -template -static void ExecuteMakeTimestampNs(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() == 1); - - auto func = MakeTimestampOperator::Operation; - UnaryExecutor::Execute(input.data[0], result, input.size(), func); - return; -} - -ScalarFunctionSet MakeDateFun::GetFunctions() { - ScalarFunctionSet make_date("make_date"); - make_date.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::DATE, MakeDateFromEpoch)); - make_date.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::DATE, ExecuteMakeDate)); - - child_list_t make_date_children { - {"year", LogicalType::BIGINT}, 
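// NOTE: the function set assembled here also registers a STRUCT overload, so make_date can be
// called either positionally or with a single struct argument, e.g.
//   SELECT make_date({'year': 2024, 'month': 3, 'day': 14});
// Separately, MakeTimeOperator above splits a DOUBLE seconds value into whole seconds plus
// rounded microseconds; as a worked example, make_time(12, 30, 7.891) produces ss_32 = 7 and
// micros = round(0.891 * 1000000) = 891000, i.e. 12:30:07.891.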
{"month", LogicalType::BIGINT}, {"day", LogicalType::BIGINT}}; - make_date.AddFunction( - ScalarFunction({LogicalType::STRUCT(make_date_children)}, LogicalType::DATE, ExecuteStructMakeDate)); - for (auto &func : make_date.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return make_date; -} - -ScalarFunction MakeTimeFun::GetFunction() { - ScalarFunction function({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, LogicalType::TIME, - ExecuteMakeTime); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunctionSet MakeTimestampFun::GetFunctions() { - ScalarFunctionSet operator_set("make_timestamp"); - operator_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, - LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::DOUBLE}, - LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); - operator_set.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP, ExecuteMakeTimestamp)); - - for (auto &func : operator_set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return operator_set; -} - -ScalarFunctionSet MakeTimestampNsFun::GetFunctions() { - ScalarFunctionSet operator_set("make_timestamp_ns"); - operator_set.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::TIMESTAMP_NS, ExecuteMakeTimestampNs)); - return operator_set; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp b/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp deleted file mode 100644 index 726d6b54f..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/time_bucket.cpp +++ /dev/null @@ -1,373 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "core_functions/scalar/date_functions.hpp" - -namespace duckdb { - -struct TimeBucket { - - // Use 2000-01-03 00:00:00 (Monday) as origin when bucket_width is days, hours, ... for TimescaleDB compatibility - // There are 10959 days between 1970-01-01 and 2000-01-03 - constexpr static const int64_t DEFAULT_ORIGIN_MICROS = 10959 * Interval::MICROS_PER_DAY; - // Use 2000-01-01 as origin when bucket_width is months, years, ... 
for TimescaleDB compatibility - // There are 360 months between 1970-01-01 and 2000-01-01 - constexpr static const int32_t DEFAULT_ORIGIN_MONTHS = 360; - - enum struct BucketWidthType : uint8_t { CONVERTIBLE_TO_MICROS, CONVERTIBLE_TO_MONTHS, UNCLASSIFIED }; - - static inline BucketWidthType ClassifyBucketWidth(const interval_t bucket_width) { - if (bucket_width.months == 0 && Interval::GetMicro(bucket_width) > 0) { - return BucketWidthType::CONVERTIBLE_TO_MICROS; - } else if (bucket_width.months > 0 && bucket_width.days == 0 && bucket_width.micros == 0) { - return BucketWidthType::CONVERTIBLE_TO_MONTHS; - } else { - return BucketWidthType::UNCLASSIFIED; - } - } - - static inline BucketWidthType ClassifyBucketWidthErrorThrow(const interval_t bucket_width) { - if (bucket_width.months == 0) { - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - if (bucket_width_micros <= 0) { - throw NotImplementedException("Period must be greater than 0"); - } - return BucketWidthType::CONVERTIBLE_TO_MICROS; - } else if (bucket_width.months != 0 && bucket_width.days == 0 && bucket_width.micros == 0) { - if (bucket_width.months < 0) { - throw NotImplementedException("Period must be greater than 0"); - } - return BucketWidthType::CONVERTIBLE_TO_MONTHS; - } else { - throw NotImplementedException("Month intervals cannot have day or time component"); - } - } - - template - static inline int32_t EpochMonths(T ts) { - date_t ts_date = Cast::template Operation(ts); - return (Date::ExtractYear(ts_date) - 1970) * 12 + Date::ExtractMonth(ts_date) - 1; - } - - static inline timestamp_t WidthConvertibleToMicrosCommon(int64_t bucket_width_micros, int64_t ts_micros, - int64_t origin_micros) { - origin_micros %= bucket_width_micros; - ts_micros = SubtractOperatorOverflowCheck::Operation(ts_micros, origin_micros); - - int64_t result_micros = (ts_micros / bucket_width_micros) * bucket_width_micros; - if (ts_micros < 0 && ts_micros % bucket_width_micros != 0) { - result_micros = - SubtractOperatorOverflowCheck::Operation(result_micros, bucket_width_micros); - } - result_micros += origin_micros; - - return Timestamp::FromEpochMicroSeconds(result_micros); - } - - static inline date_t WidthConvertibleToMonthsCommon(int32_t bucket_width_months, int32_t ts_months, - int32_t origin_months) { - origin_months %= bucket_width_months; - ts_months = SubtractOperatorOverflowCheck::Operation(ts_months, origin_months); - - int32_t result_months = (ts_months / bucket_width_months) * bucket_width_months; - if (ts_months < 0 && ts_months % bucket_width_months != 0) { - result_months = - SubtractOperatorOverflowCheck::Operation(result_months, bucket_width_months); - } - result_months += origin_months; - - int32_t year = - (result_months < 0 && result_months % 12 != 0) ? 1970 + result_months / 12 - 1 : 1970 + result_months / 12; - int32_t month = - (result_months < 0 && result_months % 12 != 0) ? 
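// (Worked example of this flooring adjustment: result_months = -5 yields
//  year = 1970 + (-5 / 12) - 1 = 1969 and month = -5 % 12 + 13 = 8, i.e. 1969-08-01, which is
//  indeed 5 months before 1970-01-01; truncating division alone would have produced an invalid
//  month of -4.)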
result_months % 12 + 13 : result_months % 12 + 1; - - return Date::FromDate(year, month, 1); - } - - struct WidthConvertibleToMicrosBinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); - return Cast::template Operation( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS)); - } - }; - - struct WidthConvertibleToMonthsBinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(ts); - return Cast::template Operation( - WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)); - } - }; - - struct BinaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts) { - BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return WidthConvertibleToMicrosBinaryOperator::Operation(bucket_width, ts); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return WidthConvertibleToMonthsBinaryOperator::Operation(bucket_width, ts); - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - }; - - struct OffsetWidthConvertibleToMicrosTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds( - Interval::Add(Cast::template Operation(ts), Interval::Invert(offset))); - return Cast::template Operation(Interval::Add( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, DEFAULT_ORIGIN_MICROS), offset)); - } - }; - - struct OffsetWidthConvertibleToMonthsTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(Interval::Add(ts, Interval::Invert(offset))); - return Interval::Add(Cast::template Operation(WidthConvertibleToMonthsCommon( - bucket_width.months, ts_months, DEFAULT_ORIGIN_MONTHS)), - offset); - } - }; - - struct OffsetTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC offset) { - BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return OffsetWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, - offset); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return OffsetWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, - offset); - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - }; - - struct OriginWidthConvertibleToMicrosTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int64_t bucket_width_micros = Interval::GetMicro(bucket_width); - int64_t ts_micros = Timestamp::GetEpochMicroSeconds(Cast::template Operation(ts)); - int64_t origin_micros = Timestamp::GetEpochMicroSeconds(Cast::template 
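// NOTE: the ternary operator families in this file implement the TimescaleDB-style variants:
// an offset shifts the timestamp before bucketing and shifts the resulting bucket back, while
// an origin (handled here) replaces the default alignment point. Illustrative queries, with ts
// standing in for some TIMESTAMP column and standard time_bucket semantics assumed:
//   SELECT time_bucket(INTERVAL '1 day', ts, INTERVAL '2 hours');            -- buckets start at 02:00
//   SELECT time_bucket(INTERVAL '1 day', ts, TIMESTAMP '2000-01-01 06:00');  -- buckets start at 06:00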
Operation(origin)); - return Cast::template Operation( - WidthConvertibleToMicrosCommon(bucket_width_micros, ts_micros, origin_micros)); - } - }; - - struct OriginWidthConvertibleToMonthsTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin) { - if (!Value::IsFinite(ts)) { - return Cast::template Operation(ts); - } - int32_t ts_months = EpochMonths(ts); - int32_t origin_months = EpochMonths(origin); - return Cast::template Operation( - WidthConvertibleToMonthsCommon(bucket_width.months, ts_months, origin_months)); - } - }; - - struct OriginTernaryOperator { - template - static inline TR Operation(TA bucket_width, TB ts, TC origin, ValidityMask &mask, idx_t idx) { - if (!Value::IsFinite(origin)) { - mask.SetInvalid(idx); - return TR(); - } - BucketWidthType bucket_width_type = ClassifyBucketWidthErrorThrow(bucket_width); - switch (bucket_width_type) { - case BucketWidthType::CONVERTIBLE_TO_MICROS: - return OriginWidthConvertibleToMicrosTernaryOperator::Operation(bucket_width, ts, - origin); - case BucketWidthType::CONVERTIBLE_TO_MONTHS: - return OriginWidthConvertibleToMonthsTernaryOperator::Operation(bucket_width, ts, - origin); - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - }; -}; - -template -static void TimeBucketFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - - if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::WidthConvertibleToMicrosBinaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - BinaryExecutor::Execute( - bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::WidthConvertibleToMonthsBinaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::BinaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - BinaryExecutor::Execute(bucket_width_arg, ts_arg, result, args.size(), - TimeBucket::BinaryOperator::Operation); - } -} - -template -static void TimeBucketOffsetFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &offset_arg = args.data[2]; - - if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - TernaryExecutor::Execute( - bucket_width_arg, 
ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetWidthConvertibleToMicrosTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetWidthConvertibleToMonthsTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetTernaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, offset_arg, result, args.size(), - TimeBucket::OffsetTernaryOperator::Operation); - } -} - -template -static void TimeBucketOriginFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3); - - auto &bucket_width_arg = args.data[0]; - auto &ts_arg = args.data[1]; - auto &origin_arg = args.data[2]; - - if (bucket_width_arg.GetVectorType() == VectorType::CONSTANT_VECTOR && - origin_arg.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(bucket_width_arg) || ConstantVector::IsNull(origin_arg) || - !Value::IsFinite(*ConstantVector::GetData(origin_arg))) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - } else { - interval_t bucket_width = *ConstantVector::GetData(bucket_width_arg); - TimeBucket::BucketWidthType bucket_width_type = TimeBucket::ClassifyBucketWidth(bucket_width); - switch (bucket_width_type) { - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MICROS: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginWidthConvertibleToMicrosTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::CONVERTIBLE_TO_MONTHS: - TernaryExecutor::Execute( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginWidthConvertibleToMonthsTernaryOperator::Operation); - break; - case TimeBucket::BucketWidthType::UNCLASSIFIED: - TernaryExecutor::ExecuteWithNulls( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginTernaryOperator::Operation); - break; - default: - throw NotImplementedException("Bucket type not implemented for TIME_BUCKET"); - } - } - } else { - TernaryExecutor::ExecuteWithNulls( - bucket_width_arg, ts_arg, origin_arg, result, args.size(), - TimeBucket::OriginTernaryOperator::Operation); - } -} - -ScalarFunctionSet TimeBucketFun::GetFunctions() { - ScalarFunctionSet time_bucket; - time_bucket.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE}, LogicalType::DATE, TimeBucketFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP}, LogicalType::TIMESTAMP, - TimeBucketFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::INTERVAL}, - LogicalType::DATE, TimeBucketOffsetFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::TIMESTAMP, TimeBucketOffsetFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::DATE, LogicalType::DATE}, - LogicalType::DATE, TimeBucketOriginFunction)); - time_bucket.AddFunction(ScalarFunction({LogicalType::INTERVAL, LogicalType::TIMESTAMP, LogicalType::TIMESTAMP}, - LogicalType::TIMESTAMP, 
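// NOTE: the registrations in this function expose six overloads in total: (INTERVAL, DATE) and
// (INTERVAL, TIMESTAMP), each additionally available with a trailing INTERVAL offset or a
// DATE/TIMESTAMP origin. The return type always matches the bucketed argument, e.g.
//   SELECT time_bucket(INTERVAL '2 weeks', DATE '2024-03-14');  -- returns a DATE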
TimeBucketOriginFunction)); - for (auto &func : time_bucket.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return time_bucket; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp b/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp deleted file mode 100644 index c8d508883..000000000 --- a/src/duckdb/extension/core_functions/scalar/date/to_interval.cpp +++ /dev/null @@ -1,258 +0,0 @@ -#include "core_functions/scalar/date_functions.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/function/to_interval.hpp" - -namespace duckdb { - -template <> -bool TryMultiplyOperator::Operation(double left, int64_t right, int64_t &result) { - return TryCast::Operation(left * double(right), result); -} - -struct ToMillenniaOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.days = 0; - result.micros = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_MILLENIUM, - result.months)) { - throw OutOfRangeException("Interval value %s millennia out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToCenturiesOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.days = 0; - result.micros = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_CENTURY, result.months)) { - throw OutOfRangeException("Interval value %s centuries out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToDecadesOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.days = 0; - result.micros = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_DECADE, result.months)) { - throw OutOfRangeException("Interval value %s decades out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToYearsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.days = 0; - result.micros = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_YEAR, - result.months)) { - throw OutOfRangeException("Interval value %d years out of range", input); - } - return result; - } -}; - -struct ToQuartersOperator { - template - static inline TR Operation(TA input) { - interval_t result; - if (!TryMultiplyOperator::Operation(input, Interval::MONTHS_PER_QUARTER, - result.months)) { - throw OutOfRangeException("Interval value %d quarters out of range", input); - } - result.days = 0; - result.micros = 0; - return result; - } -}; - -struct ToMonthsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = input; - result.days = 0; - result.micros = 0; - return result; - } -}; - -struct ToWeeksOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - if (!TryMultiplyOperator::Operation(input, Interval::DAYS_PER_WEEK, result.days)) { - throw OutOfRangeException("Interval value %d weeks out of range", input); - } - result.micros = 0; - return result; - } -}; - -struct ToDaysOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = input; - result.micros = 0; - return result; - } -}; - -struct ToHoursOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - 
if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_HOUR, result.micros)) { - throw OutOfRangeException("Interval value %s hours out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToMinutesOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MINUTE, result.micros)) { - throw OutOfRangeException("Interval value %s minutes out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToMilliSecondsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - if (!TryMultiplyOperator::Operation(input, Interval::MICROS_PER_MSEC, result.micros)) { - throw OutOfRangeException("Interval value %s milliseconds out of range", NumericHelper::ToString(input)); - } - return result; - } -}; - -struct ToMicroSecondsOperator { - template - static inline TR Operation(TA input) { - interval_t result; - result.months = 0; - result.days = 0; - result.micros = input; - return result; - } -}; - -ScalarFunction ToMillenniaFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToCenturiesFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToDecadesFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToYearsFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToQuartersFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToMonthsFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToWeeksFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToDaysFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToHoursFun::GetFunction() { - ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToMinutesFun::GetFunction() { - ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToSecondsFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, - 
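// NOTE: to_seconds, being bound here with a DOUBLE argument, exercises the TryMultiplyOperator
// specialisation for (double, int64_t) defined at the top of this file: the product is computed
// in double precision and then range-checked via TryCast. Worked example: to_seconds(1.5)
// stores 1.5 * 1000000 = 1500000 micros, i.e. an interval of '00:00:01.5'; a value whose
// product exceeds the int64_t range presumably raises the same out-of-range error as the
// operators above.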
ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToMillisecondsFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunction ToMicrosecondsFun::GetFunction() { - ScalarFunction function({LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp b/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp deleted file mode 100644 index 627d7ac28..000000000 --- a/src/duckdb/extension/core_functions/scalar/debug/vector_type.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "core_functions/scalar/debug_functions.hpp" - -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -static void VectorTypeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto data = ConstantVector::GetData(result); - data[0] = StringVector::AddString(result, EnumUtil::ToString(input.data[0].GetVectorType())); -} - -ScalarFunction VectorTypeFun::GetFunction() { - auto vector_type_fun = ScalarFunction("vector_type", // name of the function - {LogicalType::ANY}, // argument list - LogicalType::VARCHAR, // return type - VectorTypeFunction); - vector_type_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return vector_type_fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp b/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp deleted file mode 100644 index a10ec381c..000000000 --- a/src/duckdb/extension/core_functions/scalar/enum/enum_functions.cpp +++ /dev/null @@ -1,164 +0,0 @@ -#include "core_functions/scalar/enum_functions.hpp" - -namespace duckdb { - -static void EnumFirstFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 1); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - auto val = Value(enum_vector.GetValue(0)); - result.Reference(val); -} - -static void EnumLastFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 1); - auto enum_size = EnumType::GetSize(types[0]); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - auto val = Value(enum_vector.GetValue(enum_size - 1)); - result.Reference(val); -} - -static void EnumRangeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 1); - auto enum_size = EnumType::GetSize(types[0]); - auto &enum_vector = EnumType::GetValuesInsertOrder(types[0]); - vector enum_values; - for (idx_t i = 0; i < enum_size; i++) { - enum_values.emplace_back(enum_vector.GetValue(i)); - } - auto val = Value::LIST(LogicalType::VARCHAR, enum_values); - result.Reference(val); -} - -static void EnumRangeBoundaryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto types = input.GetTypes(); - D_ASSERT(types.size() == 2); - idx_t start, end; - auto first_param = input.GetValue(0, 0); - auto second_param = 
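// NOTE: in enum_range_boundary (this function), a NULL first argument means "start from the
// first enum value" (start = 0) and a NULL second argument means "up to the last" (end = enum
// size). With a hypothetical CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy'):
//   SELECT enum_range_boundary(NULL, 'ok'::mood);  -- [sad, ok]
//   SELECT enum_range_boundary('ok'::mood, NULL);  -- [ok, happy]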
input.GetValue(1, 0); - - auto &enum_vector = - first_param.IsNull() ? EnumType::GetValuesInsertOrder(types[1]) : EnumType::GetValuesInsertOrder(types[0]); - - if (first_param.IsNull()) { - start = 0; - } else { - start = first_param.GetValue(); - } - if (second_param.IsNull()) { - end = EnumType::GetSize(types[0]); - } else { - end = second_param.GetValue() + 1; - } - vector enum_values; - for (idx_t i = start; i < end; i++) { - enum_values.emplace_back(enum_vector.GetValue(i)); - } - auto val = Value::LIST(LogicalType::VARCHAR, enum_values); - result.Reference(val); -} - -static void EnumCodeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.GetTypes().size() == 1); - result.Reinterpret(input.data[0]); -} - -static void CheckEnumParameter(const Expression &expr) { - if (expr.HasParameter()) { - throw ParameterNotResolvedException(); - } -} - -unique_ptr BindEnumFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { - throw BinderException("This function needs an ENUM as an argument"); - } - return nullptr; -} - -unique_ptr BindEnumCodeFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM) { - throw BinderException("This function needs an ENUM as an argument"); - } - - auto phy_type = EnumType::GetPhysicalType(arguments[0]->return_type); - switch (phy_type) { - case PhysicalType::UINT8: - bound_function.return_type = LogicalType(LogicalTypeId::UTINYINT); - break; - case PhysicalType::UINT16: - bound_function.return_type = LogicalType(LogicalTypeId::USMALLINT); - break; - case PhysicalType::UINT32: - bound_function.return_type = LogicalType(LogicalTypeId::UINTEGER); - break; - case PhysicalType::UINT64: - bound_function.return_type = LogicalType(LogicalTypeId::UBIGINT); - break; - default: - throw InternalException("Unsupported Enum Internal Type"); - } - - return nullptr; -} - -unique_ptr BindEnumRangeBoundaryFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - CheckEnumParameter(*arguments[0]); - CheckEnumParameter(*arguments[1]); - if (arguments[0]->return_type.id() != LogicalTypeId::ENUM && arguments[0]->return_type != LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[1]->return_type.id() != LogicalTypeId::ENUM && arguments[1]->return_type != LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[0]->return_type == LogicalType::SQLNULL && arguments[1]->return_type == LogicalType::SQLNULL) { - throw BinderException("This function needs an ENUM as an argument"); - } - if (arguments[0]->return_type.id() == LogicalTypeId::ENUM && - arguments[1]->return_type.id() == LogicalTypeId::ENUM && - arguments[0]->return_type != arguments[1]->return_type) { - throw BinderException("The parameters need to link to ONLY one enum OR be NULL "); - } - return nullptr; -} - -ScalarFunction EnumFirstFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumFirstFunction, BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumLastFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, EnumLastFunction, 
BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumCodeFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::ANY, EnumCodeFunction, BindEnumCodeFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumRangeFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), EnumRangeFunction, - BindEnumFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction EnumRangeBoundaryFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::LIST(LogicalType::VARCHAR), - EnumRangeBoundaryFunction, BindEnumRangeBoundaryFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp b/src/duckdb/extension/core_functions/scalar/generic/alias.cpp deleted file mode 100644 index 4edadcaaf..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/alias.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static void AliasFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - Value v(state.expr.GetAlias().empty() ? func_expr.children[0]->GetName() : state.expr.GetAlias()); - result.Reference(v); -} - -ScalarFunction AliasFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, AliasFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp b/src/duckdb/extension/core_functions/scalar/generic/binning.cpp deleted file mode 100644 index 83f7c0700..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/binning.cpp +++ /dev/null @@ -1,508 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/generic_executor.hpp" -#include "core_functions/scalar/generic_functions.hpp" - -namespace duckdb { - -static hugeint_t GetPreviousPowerOfTen(hugeint_t input) { - hugeint_t power_of_ten = 1; - while (power_of_ten < input) { - power_of_ten *= 10; - } - return power_of_ten / 10; -} - -enum class NiceRounding { CEILING, ROUND }; - -hugeint_t RoundToNumber(hugeint_t input, hugeint_t num, NiceRounding rounding) { - if (rounding == NiceRounding::ROUND) { - return (input + (num / 2)) / num * num; - } else { - return (input + (num - 1)) / num * num; - } -} - -hugeint_t MakeNumberNice(hugeint_t input, hugeint_t step, NiceRounding rounding) { - // we consider numbers nice if they are divisible by 2 or 5 times the power-of-ten one lower than the current - // e.g. 
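// (RoundToNumber above is the standard round-to-multiple trick: add num / 2 for nearest, or
//  num - 1 for ceiling, before the truncating divide. Worked example: RoundToNumber(101, 50,
//  ROUND) = (101 + 25) / 50 * 50 = 100, while the CEILING variant gives
//  (101 + 49) / 50 * 50 = 150.)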
120 is a nice number because it is divisible by 20 - // 122 is not a nice number -> we make it nice by turning it into 120 [/20] - // 153 is not a nice number -> we make it nice by turning it into 150 [/50] - // 1220 is not a nice number -> we turn it into 1200 [/200] - // first figure out the previous power of 10 (i.e. for 67 we return 10) - // now the power of ten is the power BELOW the current number - // i.e. for 67, it is not 10 - // now we can get the 2 or 5 divisors - hugeint_t power_of_ten = GetPreviousPowerOfTen(step); - hugeint_t two = power_of_ten * 2; - hugeint_t five = power_of_ten; - if (power_of_ten * 3 <= step) { - two *= 5; - } - if (power_of_ten * 2 <= step) { - five *= 5; - } - - // compute the closest round number by adding the divisor / 2 and truncating - // do this for both divisors - hugeint_t round_to_two = RoundToNumber(input, two, rounding); - hugeint_t round_to_five = RoundToNumber(input, five, rounding); - // now pick the closest number of the two (i.e. for 147 we pick 150, not 140) - if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { - return round_to_two; - } else { - return round_to_five; - } -} - -static double GetPreviousPowerOfTen(double input) { - double power_of_ten = 1; - if (input < 1) { - while (power_of_ten > input) { - power_of_ten /= 10; - } - return power_of_ten; - } - while (power_of_ten < input) { - power_of_ten *= 10; - } - return power_of_ten / 10; -} - -double RoundToNumber(double input, double num, NiceRounding rounding) { - double result; - if (rounding == NiceRounding::ROUND) { - result = std::round(input / num) * num; - } else { - result = std::ceil(input / num) * num; - } - if (!Value::IsFinite(result)) { - return input; - } - return result; -} - -double MakeNumberNice(double input, const double step, NiceRounding rounding) { - if (input == 0) { - return 0; - } - // now the power of ten is the power BELOW the current number - // i.e. for 67, it is not 10 - // now we can get the 2 or 5 divisors - double power_of_ten = GetPreviousPowerOfTen(step); - double two = power_of_ten * 2; - double five = power_of_ten; - if (power_of_ten * 3 <= step) { - two *= 5; - } - if (power_of_ten * 2 <= step) { - five *= 5; - } - - double round_to_two = RoundToNumber(input, two, rounding); - double round_to_five = RoundToNumber(input, five, rounding); - // now pick the closest number of the two (i.e. 
for 147 we pick 150, not 140) - if (AbsValue(input - round_to_two) < AbsValue(input - round_to_five)) { - return round_to_two; - } else { - return round_to_five; - } -} - -struct EquiWidthBinsInteger { - static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::BIGINT; - - static vector> Operation(const Expression &expr, int64_t input_min, int64_t input_max, - idx_t bin_count, bool nice_rounding) { - vector> result; - // to prevent integer truncation from affecting the bin boundaries we calculate them with numbers multiplied by - // 1000 we then divide to get the actual boundaries - const auto FACTOR = hugeint_t(1000); - auto min = hugeint_t(input_min) * FACTOR; - auto max = hugeint_t(input_max) * FACTOR; - - const hugeint_t span = max - min; - hugeint_t step = span / Hugeint::Convert(bin_count); - if (nice_rounding) { - // when doing nice rounding we try to make the max/step values nicer - hugeint_t new_step = MakeNumberNice(step, step, NiceRounding::ROUND); - hugeint_t new_max = RoundToNumber(max, new_step, NiceRounding::CEILING); - if (new_max != min && new_step != 0) { - max = new_max; - step = new_step; - } - // we allow for more bins when doing nice rounding since the bin count is approximate - bin_count *= 2; - } - for (hugeint_t bin_boundary = max; bin_boundary > min; bin_boundary -= step) { - const hugeint_t target_boundary = bin_boundary / FACTOR; - int64_t real_boundary = Hugeint::Cast(target_boundary); - if (!result.empty()) { - if (real_boundary < input_min || result.size() >= bin_count) { - // we can never generate input_min - break; - } - if (real_boundary == result.back().val) { - // we cannot generate the same value multiple times in a row - skip this step - continue; - } - } - result.push_back(real_boundary); - } - return result; - } -}; - -struct EquiWidthBinsDouble { - static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::DOUBLE; - - static vector> Operation(const Expression &expr, double min, double input_max, - idx_t bin_count, bool nice_rounding) { - double max = input_max; - if (!Value::IsFinite(min) || !Value::IsFinite(max)) { - throw InvalidInputException("equi_width_bucket does not support infinite or nan as min/max value"); - } - vector> result; - const double span = max - min; - double step; - if (!Value::IsFinite(span)) { - // max - min does not fit - step = max / static_cast(bin_count) - min / static_cast(bin_count); - } else { - step = span / static_cast(bin_count); - } - const double step_power_of_ten = GetPreviousPowerOfTen(step); - if (nice_rounding) { - // when doing nice rounding we try to make the max/step values nicer - step = MakeNumberNice(step, step, NiceRounding::ROUND); - max = RoundToNumber(input_max, step, NiceRounding::CEILING); - // we allow for more bins when doing nice rounding since the bin count is approximate - bin_count *= 2; - } - if (step == 0) { - throw InternalException("step is 0!?"); - } - - const double round_multiplication = 10 / step_power_of_ten; - for (double bin_boundary = max; bin_boundary > min; bin_boundary -= step) { - // because floating point addition adds inaccuracies, we add rounding at every step - double real_boundary = bin_boundary; - if (nice_rounding) { - real_boundary = std::round(bin_boundary * round_multiplication) / round_multiplication; - } - if (!result.empty() && result.back().val == real_boundary) { - // skip this step - continue; - } - if (real_boundary <= min || result.size() >= bin_count) { - // we can never generate below input_min - break; - } - result.push_back(real_boundary); - } - 
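// NOTE: like the integer variant above, the boundaries are produced from max downward and later
// reversed by the caller, which guarantees that the largest boundary survives even when rounding
// makes neighbouring boundaries collide. For instance, equi_width_bins(0.0, 10.0, 3, false)
// walks 10.0 -> 6.67 -> 3.33 and is ultimately returned as roughly [3.33, 6.67, 10.0].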
return result; - } -}; - -void NextMonth(int32_t &year, int32_t &month) { - month++; - if (month == 13) { - year++; - month = 1; - } -} - -void NextDay(int32_t &year, int32_t &month, int32_t &day) { - day++; - if (!Date::IsValid(year, month, day)) { - // day is out of range for month, move to next month - NextMonth(year, month); - day = 1; - } -} - -void NextHour(int32_t &year, int32_t &month, int32_t &day, int32_t &hour) { - hour++; - if (hour >= 24) { - NextDay(year, month, day); - hour = 0; - } -} - -void NextMinute(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute) { - minute++; - if (minute >= 60) { - NextHour(year, month, day, hour); - minute = 0; - } -} - -void NextSecond(int32_t &year, int32_t &month, int32_t &day, int32_t &hour, int32_t &minute, int32_t &sec) { - sec++; - if (sec >= 60) { - NextMinute(year, month, day, hour, minute); - sec = 0; - } -} - -timestamp_t MakeTimestampNice(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t minute, int32_t sec, - int32_t micros, interval_t step) { - // how to make a timestamp nice depends on the step - if (step.months >= 12) { - // if the step involves one year or more, ceil to months - // set time component to 00:00:00.00 - if (day > 1 || hour > 0 || minute > 0 || sec > 0 || micros > 0) { - // move to next month - NextMonth(year, month); - hour = minute = sec = micros = 0; - day = 1; - } - } else if (step.months > 0 || step.days >= 1) { - // if the step involves more than one day, ceil to days - if (hour > 0 || minute > 0 || sec > 0 || micros > 0) { - NextDay(year, month, day); - hour = minute = sec = micros = 0; - } - } else if (step.days > 0 || step.micros >= Interval::MICROS_PER_HOUR) { - // if the step involves more than one hour, ceil to hours - if (minute > 0 || sec > 0 || micros > 0) { - NextHour(year, month, day, hour); - minute = sec = micros = 0; - } - } else if (step.micros >= Interval::MICROS_PER_MINUTE) { - // if the step involves more than one minute, ceil to minutes - if (sec > 0 || micros > 0) { - NextMinute(year, month, day, hour, minute); - sec = micros = 0; - } - } else if (step.micros >= Interval::MICROS_PER_SEC) { - // if the step involves more than one second, ceil to seconds - if (micros > 0) { - NextSecond(year, month, day, hour, minute, sec); - micros = 0; - } - } - return Timestamp::FromDatetime(Date::FromDate(year, month, day), Time::FromTime(hour, minute, sec, micros)); -} - -int64_t RoundNumberToDivisor(int64_t number, int64_t divisor) { - return (number + (divisor / 2)) / divisor * divisor; -} - -interval_t MakeIntervalNice(interval_t interval) { - if (interval.months >= 6) { - // if we have more than 6 months, we don't care about days - interval.days = 0; - interval.micros = 0; - } else if (interval.months > 0 || interval.days >= 5) { - // if we have any months or more than 5 days, we don't care about micros - interval.micros = 0; - } else if (interval.days > 0 || interval.micros >= 6 * Interval::MICROS_PER_HOUR) { - // if we any days or more than 6 hours, we want micros to be roundable by hours at least - interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_HOUR); - } else if (interval.micros >= Interval::MICROS_PER_HOUR) { - // if we have more than an hour, we want micros to be divisible by quarter hours - interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE * 15); - } else if (interval.micros >= Interval::MICROS_PER_MINUTE * 10) { - // if we have more than 10 minutes, we want micros to be divisible by minutes - 
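// (MakeTimestampNice above ceils the maximum timestamp to the coarsest unit implied by the
//  step: with a step of two years, 2021-03-14 15:26:00 becomes 2021-04-01 00:00:00, whereas a
//  one-month step only ceils to the next day, 2021-03-15 00:00:00, and a two-hour step to the
//  next hour, 2021-03-14 16:00:00. MakeIntervalNice here applies the matching coarsening to
//  the step itself.)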
} else if (interval.micros >= Interval::MICROS_PER_MINUTE * 10) { - // if we have 10 minutes or more, we want micros to be divisible by minutes - interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_MINUTE); - } else if (interval.micros >= Interval::MICROS_PER_MINUTE) { - // if we have a minute or more, we want micros to be divisible by quarter minutes - interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC * 15); - } else if (interval.micros >= Interval::MICROS_PER_SEC * 10) { - // if we have 10 seconds or more, we want micros to be divisible by seconds - interval.micros = RoundNumberToDivisor(interval.micros, Interval::MICROS_PER_SEC); - } - return interval; -} - -void GetTimestampComponents(timestamp_t input, int32_t &year, int32_t &month, int32_t &day, int32_t &hour, - int32_t &minute, int32_t &sec, int32_t &micros) { - date_t date; - dtime_t time; - - Timestamp::Convert(input, date, time); - Date::Convert(date, year, month, day); - Time::Convert(time, hour, minute, sec, micros); -} - -struct EquiWidthBinsTimestamp { - static constexpr LogicalTypeId LOGICAL_TYPE = LogicalTypeId::TIMESTAMP; - - static vector> Operation(const Expression &expr, timestamp_t input_min, - timestamp_t input_max, idx_t bin_count, bool nice_rounding) { - if (!Value::IsFinite(input_min) || !Value::IsFinite(input_max)) { - throw InvalidInputException(expr, "equi_width_bucket does not support infinite or nan as min/max value"); - } - - if (!nice_rounding) { - // if we are not doing nice rounding it is pretty simple - just interpolate between the timestamp values - auto interpolated_values = - EquiWidthBinsInteger::Operation(expr, input_min.value, input_max.value, bin_count, false); - - vector> result; - for (auto &val : interpolated_values) { - result.push_back(timestamp_t(val.val)); - } - return result; - } - // fetch the components of the timestamps - int32_t min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros; - int32_t max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros; - GetTimestampComponents(input_min, min_year, min_month, min_day, min_hour, min_minute, min_sec, min_micros); - GetTimestampComponents(input_max, max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros); - - // get the interval differences per component - // note: these can be negative (except for the largest non-zero difference) - interval_t interval_diff; - interval_diff.months = (max_year - min_year) * Interval::MONTHS_PER_YEAR + (max_month - min_month); - interval_diff.days = max_day - min_day; - interval_diff.micros = (max_hour - min_hour) * Interval::MICROS_PER_HOUR + - (max_minute - min_minute) * Interval::MICROS_PER_MINUTE + - (max_sec - min_sec) * Interval::MICROS_PER_SEC + (max_micros - min_micros); - - double step_months = static_cast(interval_diff.months) / static_cast(bin_count); - double step_days = static_cast(interval_diff.days) / static_cast(bin_count); - double step_micros = static_cast(interval_diff.micros) / static_cast(bin_count); - // since we truncate the months/days, propagate any fractional component to the unit below (i.e.
0.2 months - // becomes 6 days) - if (step_months > 0) { - double overflow_months = step_months - std::floor(step_months); - step_days += overflow_months * Interval::DAYS_PER_MONTH; - } - if (step_days > 0) { - double overflow_days = step_days - std::floor(step_days); - step_micros += overflow_days * Interval::MICROS_PER_DAY; - } - interval_t step; - step.months = static_cast(step_months); - step.days = static_cast(step_days); - step.micros = static_cast(step_micros); - - // now we make the max and the step nice - step = MakeIntervalNice(step); - timestamp_t timestamp_val = - MakeTimestampNice(max_year, max_month, max_day, max_hour, max_minute, max_sec, max_micros, step); - if (step.months <= 0 && step.days <= 0 && step.micros <= 0) { - // interval must be at least one microsecond - step.months = step.days = 0; - step.micros = 1; - } - - vector> result; - while (timestamp_val.value >= input_min.value && result.size() < bin_count) { - result.push_back(timestamp_val); - timestamp_val = SubtractOperator::Operation(timestamp_val, step); - } - return result; - } -}; - -unique_ptr BindEquiWidthFunction(ClientContext &, ScalarFunction &bound_function, - vector> &arguments) { - // while internally the bins are computed over a unified type - // the equi_width_bins function returns the same type as the input MAX - LogicalType child_type; - switch (arguments[1]->return_type.id()) { - case LogicalTypeId::UNKNOWN: - case LogicalTypeId::SQLNULL: - return nullptr; - case LogicalTypeId::DECIMAL: - // for decimals we promote to double because the nicely rounded boundaries may not be representable in the input's decimal scale - child_type = LogicalType::DOUBLE; - break; - default: - child_type = arguments[1]->return_type; - break; - } - bound_function.return_type = LogicalType::LIST(child_type); - return nullptr; -} - -template -static void EquiWidthBinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - static constexpr int64_t MAX_BIN_COUNT = 1000000; - auto &min_arg = args.data[0]; - auto &max_arg = args.data[1]; - auto &bin_count = args.data[2]; - auto &nice_rounding = args.data[3]; - - Vector intermediate_result(LogicalType::LIST(OP::LOGICAL_TYPE)); - GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, PrimitiveType, - GenericListType>>( - min_arg, max_arg, bin_count, nice_rounding, intermediate_result, args.size(), - [&](PrimitiveType min_p, PrimitiveType max_p, PrimitiveType bins_p, - PrimitiveType nice_rounding_p) { - if (max_p.val < min_p.val) { - throw InvalidInputException(state.expr, - "Invalid input for bin function - max value is smaller than min value"); - } - if (bins_p.val <= 0) { - throw InvalidInputException(state.expr, "Invalid input for bin function - there must be > 0 bins"); - } - if (bins_p.val > MAX_BIN_COUNT) { - throw InvalidInputException(state.expr, "Invalid input for bin function - max bin count of %d exceeded", - MAX_BIN_COUNT); - } - GenericListType> result_bins; - if (max_p.val == min_p.val) { - // if max = min return a single bucket - result_bins.values.push_back(max_p.val); - } else { - result_bins.values = OP::Operation(state.expr, min_p.val, max_p.val, static_cast(bins_p.val), - nice_rounding_p.val); - // last bin should always be the input max - if (result_bins.values[0].val < max_p.val) { - result_bins.values[0].val = max_p.val; - } - std::reverse(result_bins.values.begin(), result_bins.values.end()); - } - return result_bins; - }); - VectorOperations::DefaultCast(intermediate_result, result, args.size()); -} - -static void UnsupportedEquiWidth(DataChunk &args, ExpressionState &state, Vector &) { - throw
BinderException(state.expr, "Unsupported type \"%s\" for equi_width_bins", args.data[0].GetType()); -} - -void EquiWidthBinSerialize(Serializer &, const optional_ptr, const ScalarFunction &) { - return; -} - -unique_ptr EquiWidthBinDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.return_type = deserializer.Get(); - return nullptr; -} - -ScalarFunctionSet EquiWidthBinsFun::GetFunctions() { - ScalarFunctionSet functions("equi_width_bins"); - functions.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BOOLEAN}, - LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, - BindEquiWidthFunction)); - functions.AddFunction(ScalarFunction( - {LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::BIGINT, LogicalType::BOOLEAN}, - LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, BindEquiWidthFunction)); - functions.AddFunction( - ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::BIGINT, LogicalType::BOOLEAN}, - LogicalType::LIST(LogicalType::ANY), EquiWidthBinFunction, - BindEquiWidthFunction)); - functions.AddFunction( - ScalarFunction({LogicalType::ANY_PARAMS(LogicalType::ANY, 150), LogicalType::ANY_PARAMS(LogicalType::ANY, 150), - LogicalType::BIGINT, LogicalType::BOOLEAN}, - LogicalType::LIST(LogicalType::ANY), UnsupportedEquiWidth, BindEquiWidthFunction)); - for (auto &function : functions.functions) { - function.serialize = EquiWidthBinSerialize; - function.deserialize = EquiWidthBinDeserialize; - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp b/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp deleted file mode 100644 index 5db38d601..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/can_implicitly_cast.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast_rules.hpp" - -namespace duckdb { - -bool CanCastImplicitly(ClientContext &context, const LogicalType &source, const LogicalType &target) { - return CastFunctionSet::Get(context).ImplicitCastCost(source, target) >= 0; -} - -static void CanCastImplicitlyFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &context = state.GetContext(); - bool can_cast_implicitly = CanCastImplicitly(context, args.data[0].GetType(), args.data[1].GetType()); - auto v = Value::BOOLEAN(can_cast_implicitly); - result.Reference(v); -} - -unique_ptr BindCanCastImplicitlyExpression(FunctionBindExpressionInput &input) { - auto &source_type = input.function.children[0]->return_type; - auto &target_type = input.function.children[1]->return_type; - if (source_type.id() == LogicalTypeId::UNKNOWN || source_type.id() == LogicalTypeId::SQLNULL || - target_type.id() == LogicalTypeId::UNKNOWN || target_type.id() == LogicalTypeId::SQLNULL) { - // parameter - unknown return type - return nullptr; - } - // emit a constant expression - return make_uniq( - Value::BOOLEAN(CanCastImplicitly(input.context, source_type, target_type))); -} - -ScalarFunction CanCastImplicitlyFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY, LogicalType::ANY}, LogicalType::BOOLEAN, CanCastImplicitlyFunction); 
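
BindCanCastImplicitlyExpression above is a small instance of bind-time constant folding: once both argument types are known, the whole call collapses to a precomputed constant and no per-row work remains; only unresolved parameters keep the function call alive. A minimal sketch of that pattern, using simplified illustrative types rather than DuckDB's expression API:

    #include <memory>
    #include <optional>

    struct Expr {
        virtual ~Expr() = default;
    };

    struct ConstantExpr final : Expr {
        bool value;
        explicit ConstantExpr(bool value_p) : value(value_p) {
        }
    };

    // Return a constant replacement when both input types are resolved, or
    // nullptr to leave the original call in place (e.g. prepared-statement
    // parameters whose types are still unknown at bind time).
    static std::unique_ptr<Expr> TryFoldCanCast(std::optional<int> source, std::optional<int> target,
                                                bool (*implicit_cast_ok)(int, int)) {
        if (!source || !target) {
            return nullptr;
        }
        return std::make_unique<ConstantExpr>(implicit_cast_ok(*source, *target));
    }

    int main() {
        auto folded = TryFoldCanCast(1, 2, [](int, int) { return true; });
        return folded ? 0 : 1; // both types known, so the call folds to a constant
    }
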
- fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindCanCastImplicitlyExpression; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp b/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp deleted file mode 100644 index b983b27cd..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/current_setting.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" - -#include "duckdb/main/database.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/catalog/catalog.hpp" -namespace duckdb { - -struct CurrentSettingBindData : public FunctionData { - explicit CurrentSettingBindData(Value value_p) : value(std::move(value_p)) { - } - - Value value; - -public: - unique_ptr Copy() const override { - return make_uniq(value); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return Value::NotDistinctFrom(value, other.value); - } -}; - -static void CurrentSettingFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - result.Reference(info.value); -} - -unique_ptr CurrentSettingBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto &key_child = arguments[0]; - if (key_child->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw ParserException("Key name for current_setting needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - if (key_val.IsNull() || StringValue::Get(key_val).empty()) { - throw ParserException("Key name for current_setting needs to be neither NULL nor empty"); - } - - auto key = StringUtil::Lower(StringValue::Get(key_val)); - Value val; - if (!context.TryGetCurrentSetting(key, val)) { - Catalog::AutoloadExtensionByConfigName(context, key); - // If autoloader didn't throw, the config is now available - context.TryGetCurrentSetting(key, val); - } - - bound_function.return_type = val.type(); - return make_uniq(val); -} - -ScalarFunction CurrentSettingFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::VARCHAR}, LogicalType::ANY, CurrentSettingFunction, CurrentSettingBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp b/src/duckdb/extension/core_functions/scalar/generic/hash.cpp deleted file mode 100644 index 184919447..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/hash.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" - -namespace duckdb { - -static void HashFunction(DataChunk &args, ExpressionState &state, Vector &result) { - args.Hash(result); - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunction HashFun::GetFunction() { - auto hash_fun = ScalarFunction({LogicalType::ANY}, LogicalType::HASH, HashFunction); - hash_fun.varargs = LogicalType::ANY; -
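
HashFunction above also shows a common vectorization shortcut: a deterministic function of all-constant inputs is itself constant, so only row zero has to be computed and the output can be marked as a constant vector. A simplified stand-in, with toy vector types and a toy mixer rather than DuckDB's classes:

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    enum class ToyVectorType { FLAT, CONSTANT };

    struct ToyVector {
        ToyVectorType type = ToyVectorType::FLAT;
        std::vector<uint64_t> data;
    };

    // Hash each row; for a constant input, hash row 0 once and mark the result
    // constant instead of materializing all count rows.
    static void HashColumn(const ToyVector &input, ToyVector &result, size_t count) {
        const bool constant = input.type == ToyVectorType::CONSTANT;
        const size_t rows = constant ? 1 : count;
        result.data.resize(rows);
        for (size_t i = 0; i < rows; i++) {
            result.data[i] = input.data[i] * 0x9E3779B97F4A7C15ULL; // toy mixer, not DuckDB's hash
        }
        result.type = constant ? ToyVectorType::CONSTANT : ToyVectorType::FLAT;
    }

    int main() {
        ToyVector in {ToyVectorType::CONSTANT, {42}};
        ToyVector out;
        HashColumn(in, out, 1024); // computes one row, not 1024
        return out.type == ToyVectorType::CONSTANT ? 0 : 1;
    }
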
hash_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return hash_fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/least.cpp b/src/duckdb/extension/core_functions/scalar/generic/least.cpp deleted file mode 100644 index 40a943101..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/least.cpp +++ /dev/null @@ -1,259 +0,0 @@ -#include "duckdb/common/operator/comparison_operators.hpp" -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/function/create_sort_key.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -struct LeastOp { - using OP = LessThan; - - static OrderByNullType NullOrdering() { - return OrderByNullType::NULLS_LAST; - } -}; - -struct GreaterOp { - using OP = GreaterThan; - - static OrderByNullType NullOrdering() { - return OrderByNullType::NULLS_FIRST; - } -}; - -template -struct LeastOperator { - template - static T Operation(T left, T right) { - return OP::Operation(left, right) ? left : right; - } -}; - -struct LeastGreatestSortKeyState : public FunctionLocalState { - explicit LeastGreatestSortKeyState(idx_t column_count, OrderByNullType null_ordering) - : intermediate(LogicalType::BLOB), modifiers(OrderType::ASCENDING, null_ordering) { - vector types; - // initialize sort key chunk - for (idx_t i = 0; i < column_count; i++) { - types.push_back(LogicalType::BLOB); - } - sort_keys.Initialize(Allocator::DefaultAllocator(), types); - } - - DataChunk sort_keys; - Vector intermediate; - OrderModifiers modifiers; -}; - -template -unique_ptr LeastGreatestSortKeyInit(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - return make_uniq(expr.children.size(), OP::NullOrdering()); -} - -template -struct StandardLeastGreatest { - static constexpr bool IS_STRING = STRING; - - static DataChunk &Prepare(DataChunk &args, ExpressionState &) { - return args; - } - - static Vector &TargetVector(Vector &result, ExpressionState &) { - return result; - } - - static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &) { - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < rows; i++) { - if (!result_has_value[i]) { - result_mask.SetInvalid(i); - } - } - } -}; - -struct SortKeyLeastGreatest { - static constexpr bool IS_STRING = false; - - static DataChunk &Prepare(DataChunk &args, ExpressionState &state) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - lstate.sort_keys.Reset(); - for (idx_t c_idx = 0; c_idx < args.ColumnCount(); c_idx++) { - CreateSortKeyHelpers::CreateSortKey(args.data[c_idx], args.size(), lstate.modifiers, - lstate.sort_keys.data[c_idx]); - } - lstate.sort_keys.SetCardinality(args.size()); - return lstate.sort_keys; - } - - static Vector &TargetVector(Vector &result, ExpressionState &state) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - return lstate.intermediate; - } - - static void FinalizeResult(idx_t rows, bool result_has_value[], Vector &result, ExpressionState &state) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - auto result_keys = FlatVector::GetData(lstate.intermediate); - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < rows; i++) { - if (!result_has_value[i]) { - result_mask.SetInvalid(i); - } else { - CreateSortKeyHelpers::DecodeSortKey(result_keys[i], result, i, lstate.modifiers); - } - } - } -}; - -template > -static 
void LeastGreatestFunction(DataChunk &args, ExpressionState &state, Vector &result) { - if (args.ColumnCount() == 1) { - // single input: nop - result.Reference(args.data[0]); - return; - } - auto &input = BASE_OP::Prepare(args, state); - auto &result_vector = BASE_OP::TargetVector(result, state); - - auto result_type = VectorType::CONSTANT_VECTOR; - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { - // non-constant input: result is not a constant vector - result_type = VectorType::FLAT_VECTOR; - } - if (BASE_OP::IS_STRING) { - // for string vectors we add a reference to the heap of the children - StringVector::AddHeapReference(result_vector, input.data[col_idx]); - } - } - - auto result_data = FlatVector::GetData(result_vector); - bool result_has_value[STANDARD_VECTOR_SIZE] {false}; - // perform the operation column-by-column - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - if (input.data[col_idx].GetVectorType() == VectorType::CONSTANT_VECTOR && - ConstantVector::IsNull(input.data[col_idx])) { - // ignore null vector - continue; - } - - UnifiedVectorFormat vdata; - input.data[col_idx].ToUnifiedFormat(input.size(), vdata); - - auto input_data = UnifiedVectorFormat::GetData(vdata); - if (!vdata.validity.AllValid()) { - // potential new null entries: have to check the null mask - for (idx_t i = 0; i < input.size(); i++) { - auto vindex = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(vindex)) { - // not a null entry: perform the operation and add to new set - auto ivalue = input_data[vindex]; - if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { - result_has_value[i] = true; - result_data[i] = ivalue; - } - } - } - } else { - // no new null entries: only need to perform the operation - for (idx_t i = 0; i < input.size(); i++) { - auto vindex = vdata.sel->get_index(i); - - auto ivalue = input_data[vindex]; - if (!result_has_value[i] || OP::template Operation(ivalue, result_data[i])) { - result_has_value[i] = true; - result_data[i] = ivalue; - } - } - } - } - BASE_OP::FinalizeResult(input.size(), result_has_value, result, state); - result.SetVectorType(result_type); -} - -template -unique_ptr BindLeastGreatest(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - LogicalType child_type = ExpressionBinder::GetExpressionReturnType(*arguments[0]); - for (idx_t i = 1; i < arguments.size(); i++) { - auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]); - if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) { - throw BinderException(arguments[i]->GetQueryLocation(), - "Cannot combine types of %s and %s - an explicit cast is required", - child_type.ToString(), arg_type.ToString()); - } - } - switch (child_type.id()) { - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - case LogicalTypeId::INTEGER_LITERAL: - child_type = IntegerLiteral::GetType(child_type); - break; - case LogicalTypeId::STRING_LITERAL: - child_type = LogicalType::VARCHAR; - break; - default: - break; - } - using OP = typename LEAST_GREATER_OP::OP; - switch (child_type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::BOOL: - case PhysicalType::INT8: - bound_function.function = LeastGreatestFunction; - break; - case PhysicalType::INT16: - bound_function.function = LeastGreatestFunction; - break; - case PhysicalType::INT32: - bound_function.function = 
LeastGreatestFunction; - break; - case PhysicalType::INT64: - bound_function.function = LeastGreatestFunction; - break; - case PhysicalType::INT128: - bound_function.function = LeastGreatestFunction; - break; - case PhysicalType::DOUBLE: - bound_function.function = LeastGreatestFunction; - break; - case PhysicalType::VARCHAR: - bound_function.function = LeastGreatestFunction>; - break; -#endif - default: - // fallback with sort keys - bound_function.function = LeastGreatestFunction; - bound_function.init_local_state = LeastGreatestSortKeyInit; - break; - } - bound_function.arguments[0] = child_type; - bound_function.varargs = child_type; - bound_function.return_type = child_type; - return nullptr; -} - -template -ScalarFunction GetLeastGreatestFunction() { - return ScalarFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, BindLeastGreatest, nullptr, nullptr, - nullptr, LogicalType::ANY, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING); -} - -template -static ScalarFunctionSet GetLeastGreatestFunctions() { - ScalarFunctionSet fun_set; - fun_set.AddFunction(GetLeastGreatestFunction()); - return fun_set; -} - -ScalarFunctionSet LeastFun::GetFunctions() { - return GetLeastGreatestFunctions(); -} - -ScalarFunctionSet GreatestFun::GetFunctions() { - return GetLeastGreatestFunctions(); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp b/src/duckdb/extension/core_functions/scalar/generic/stats.cpp deleted file mode 100644 index ad3f4cd09..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/stats.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -struct StatsBindData : public FunctionData { - explicit StatsBindData(string stats_p = string()) : stats(std::move(stats_p)) { - } - - string stats; - -public: - unique_ptr Copy() const override { - return make_uniq(stats); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return stats == other.stats; - } -}; - -static void StatsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - if (info.stats.empty()) { - info.stats = "No statistics"; - } - Value v(info.stats); - result.Reference(v); -} - -unique_ptr StatsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return make_uniq(); -} - -static unique_ptr StatsPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &bind_data = input.bind_data; - auto &info = bind_data->Cast(); - info.stats = child_stats[0].ToString(); - return nullptr; -} - -ScalarFunction StatsFun::GetFunction() { - ScalarFunction stats({LogicalType::ANY}, LogicalType::VARCHAR, StatsFunction, StatsBind, nullptr, - StatsPropagateStats); - stats.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - stats.stability = FunctionStability::VOLATILE; - return stats; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp b/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp deleted file mode 100644 index 5e4251c0b..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/system_functions.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include "duckdb/catalog/catalog_search_path.hpp" -#include 
"core_functions/scalar/generic_functions.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/execution/expression_executor.hpp" - -namespace duckdb { - -// current_query -static void CurrentQueryFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(state.GetContext().GetCurrentQuery()); - result.Reference(val); -} - -// current_schema -static void CurrentSchemaFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(ClientData::Get(state.GetContext()).catalog_search_path->GetDefault().schema); - result.Reference(val); -} - -// current_database -static void CurrentDatabaseFunction(DataChunk &input, ExpressionState &state, Vector &result) { - Value val(DatabaseManager::GetDefaultDatabase(state.GetContext())); - result.Reference(val); -} - -struct CurrentSchemasBindData : public FunctionData { - explicit CurrentSchemasBindData(Value result_value) : result(std::move(result_value)) { - } - - Value result; - -public: - unique_ptr Copy() const override { - return make_uniq(result); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return Value::NotDistinctFrom(result, other.result); - } -}; - -static unique_ptr CurrentSchemasBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->return_type.id() != LogicalTypeId::BOOLEAN) { - throw BinderException("current_schemas requires a boolean input"); - } - if (!arguments[0]->IsFoldable()) { - throw NotImplementedException("current_schemas requires a constant input"); - } - Value schema_value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - Value result_val; - if (schema_value.IsNull()) { - // null - result_val = Value(LogicalType::LIST(LogicalType::VARCHAR)); - } else { - auto implicit_schemas = BooleanValue::Get(schema_value); - vector schema_list; - auto &catalog_search_path = ClientData::Get(context).catalog_search_path; - auto &search_path = implicit_schemas ? 
catalog_search_path->Get() : catalog_search_path->GetSetPaths(); - std::transform(search_path.begin(), search_path.end(), std::back_inserter(schema_list), - [](const CatalogSearchEntry &s) -> Value { return Value(s.schema); }); - result_val = Value::LIST(LogicalType::VARCHAR, schema_list); - } - return make_uniq(std::move(result_val)); -} - -// current_schemas -static void CurrentSchemasFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - result.Reference(info.result); -} - -// in_search_path -static void InSearchPathFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &context = state.GetContext(); - auto &search_path = ClientData::Get(context).catalog_search_path; - BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t db_name, string_t schema_name) { - return search_path->SchemaInSearchPath(context, db_name.GetString(), schema_name.GetString()); - }); -} - -// txid_current -static void TransactionIdCurrent(DataChunk &input, ExpressionState &state, Vector &result) { - auto &context = state.GetContext(); - auto &catalog = Catalog::GetCatalog(context, DatabaseManager::GetDefaultDatabase(context)); - auto &transaction = DuckTransaction::Get(context, catalog); - auto val = Value::UBIGINT(transaction.start_time); - result.Reference(val); -} - -// version -static void VersionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto val = Value(DuckDB::LibraryVersion()); - result.Reference(val); -} - -ScalarFunction CurrentQueryFun::GetFunction() { - ScalarFunction current_query({}, LogicalType::VARCHAR, CurrentQueryFunction); - current_query.stability = FunctionStability::VOLATILE; - return current_query; -} - -ScalarFunction CurrentSchemaFun::GetFunction() { - ScalarFunction current_schema({}, LogicalType::VARCHAR, CurrentSchemaFunction); - current_schema.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return current_schema; -} - -ScalarFunction CurrentDatabaseFun::GetFunction() { - ScalarFunction current_database({}, LogicalType::VARCHAR, CurrentDatabaseFunction); - current_database.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return current_database; -} - -ScalarFunction CurrentSchemasFun::GetFunction() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - ScalarFunction current_schemas({LogicalType::BOOLEAN}, varchar_list_type, CurrentSchemasFunction, - CurrentSchemasBind); - current_schemas.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return current_schemas; -} - -ScalarFunction InSearchPathFun::GetFunction() { - ScalarFunction in_search_path({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - InSearchPathFunction); - in_search_path.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return in_search_path; -} - -ScalarFunction CurrentTransactionIdFun::GetFunction() { - ScalarFunction txid_current({}, LogicalType::UBIGINT, TransactionIdCurrent); - txid_current.stability = FunctionStability::CONSISTENT_WITHIN_QUERY; - return txid_current; -} - -ScalarFunction VersionFun::GetFunction() { - return ScalarFunction({}, LogicalType::VARCHAR, VersionFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp b/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp deleted file mode 100644 index 1f7caef84..000000000 --- a/src/duckdb/extension/core_functions/scalar/generic/typeof.cpp 
+++ /dev/null @@ -1,29 +0,0 @@ -#include "core_functions/scalar/generic_functions.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static void TypeOfFunction(DataChunk &args, ExpressionState &state, Vector &result) { - Value v(args.data[0].GetType().ToString()); - result.Reference(v); -} - -unique_ptr BindTypeOfFunctionExpression(FunctionBindExpressionInput &input) { - auto &return_type = input.function.children[0]->return_type; - if (return_type.id() == LogicalTypeId::UNKNOWN || return_type.id() == LogicalTypeId::SQLNULL) { - // parameter - unknown return type - return nullptr; - } - // emit a constant expression - return make_uniq(Value(return_type.ToString())); -} - -ScalarFunction TypeOfFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::ANY}, LogicalType::VARCHAR, TypeOfFunction); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.bind_expression = BindTypeOfFunctionExpression; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp b/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp deleted file mode 100644 index 0962b3c2e..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/array_slice.cpp +++ /dev/null @@ -1,460 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/swap.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -struct ListSliceBindData : public FunctionData { - ListSliceBindData(const LogicalType &return_type_p, bool begin_is_empty_p, bool end_is_empty_p) - : return_type(return_type_p), begin_is_empty(begin_is_empty_p), end_is_empty(end_is_empty_p) { - } - ~ListSliceBindData() override; - - LogicalType return_type; - - bool begin_is_empty; - bool end_is_empty; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; -}; - -ListSliceBindData::~ListSliceBindData() { -} - -bool ListSliceBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return return_type == other.return_type && begin_is_empty == other.begin_is_empty && - end_is_empty == other.end_is_empty; -} - -unique_ptr ListSliceBindData::Copy() const { - return make_uniq(return_type, begin_is_empty, end_is_empty); -} - -template -static idx_t CalculateSliceLength(idx_t begin, idx_t end, INDEX_TYPE step, bool svalid) { - if (step < 0) { - step = AbsValue(step); - } - if (step == 0 && svalid) { - throw InvalidInputException("Slice step cannot be zero"); - } - if (step == 1) { - return NumericCast(end - begin); - } else if (static_cast(step) >= (end - begin)) { - return 1; - } - if ((end - begin) % UnsafeNumericCast(step) != 0) { - return (end - begin) / UnsafeNumericCast(step) + 1; - } - return (end - begin) / UnsafeNumericCast(step); -} - -struct BlobSliceOperations { - static int64_t ValueLength(const string_t &value) { - return UnsafeNumericCast(value.GetSize()); - } - - static string_t SliceValue(Vector &result, string_t input, int64_t begin, 
int64_t end) { - return SubstringASCII(result, input, begin + 1, end - begin); - } - - static string_t SliceValueWithSteps(Vector &result, SelectionVector &sel, string_t input, int64_t begin, - int64_t end, int64_t step, idx_t &sel_idx) { - throw InternalException("Slicing with steps is not supported for strings"); - } -}; - -struct StringSliceOperations { - static int64_t ValueLength(const string_t &value) { - return Length(value); - } - - static string_t SliceValue(Vector &result, string_t input, int64_t begin, int64_t end) { - return SubstringUnicode(result, input, begin + 1, end - begin); - } - - static string_t SliceValueWithSteps(Vector &result, SelectionVector &sel, string_t input, int64_t begin, - int64_t end, int64_t step, idx_t &sel_idx) { - throw InternalException("Slicing with steps is not supported for strings"); - } -}; - -struct ListSliceOperations { - static int64_t ValueLength(const list_entry_t &value) { - return UnsafeNumericCast(value.length); - } - - static list_entry_t SliceValue(Vector &result, list_entry_t input, int64_t begin, int64_t end) { - input.offset = UnsafeNumericCast(UnsafeNumericCast(input.offset) + begin); - input.length = UnsafeNumericCast(end - begin); - return input; - } - - static list_entry_t SliceValueWithSteps(Vector &result, SelectionVector &sel, list_entry_t input, int64_t begin, - int64_t end, int64_t step, idx_t &sel_idx) { - if (end - begin == 0) { - input.length = 0; - input.offset = sel_idx; - return input; - } - input.length = CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, true); - idx_t child_idx = input.offset + UnsafeNumericCast(begin); - if (step < 0) { - child_idx = input.offset + UnsafeNumericCast(end) - 1; - } - input.offset = sel_idx; - for (idx_t i = 0; i < input.length; i++) { - sel.set_index(sel_idx, child_idx); - child_idx += static_cast(step); // intentional overflow?? - sel_idx++; - } - return input; - } -}; - -template -static void ClampIndex(INDEX_TYPE &index, const INPUT_TYPE &value, const INDEX_TYPE length, bool is_min) { - if (index < 0) { - index = (!is_min) ? index + 1 : index; - index = length + index; - return; - } else if (index > length) { - index = length; - } - return; -} - -template -static bool ClampSlice(const INPUT_TYPE &value, INDEX_TYPE &begin, INDEX_TYPE &end) { - // Clamp offsets - begin = (begin != 0 && begin != (INDEX_TYPE)NumericLimits::Minimum()) ? 
begin - 1 : begin; - - bool is_min = false; - if (begin == (INDEX_TYPE)NumericLimits::Minimum()) { - begin++; - is_min = true; - } - - const auto length = OP::ValueLength(value); - if (begin < 0 && -begin > length && end < 0 && end < -length) { - begin = 0; - end = 0; - return true; - } - if (begin < 0 && -begin > length) { - begin = 0; - } - ClampIndex(begin, value, length, is_min); - ClampIndex(end, value, length, false); - end = MaxValue(begin, end); - - return true; -} - -template -static void ExecuteConstantSlice(Vector &result, Vector &str_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr step_vector, const idx_t count, SelectionVector &sel, - idx_t &sel_idx, optional_ptr result_child_vector, bool begin_is_empty, - bool end_is_empty) { - - // check all this nullness early - auto str_valid = !ConstantVector::IsNull(str_vector); - auto begin_valid = !ConstantVector::IsNull(begin_vector); - auto end_valid = !ConstantVector::IsNull(end_vector); - auto step_valid = step_vector && !ConstantVector::IsNull(*step_vector); - - if (!str_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { - ConstantVector::SetNull(result, true); - return; - } - - auto result_data = ConstantVector::GetData(result); - auto str_data = ConstantVector::GetData(str_vector); - auto begin_data = ConstantVector::GetData(begin_vector); - auto end_data = ConstantVector::GetData(end_vector); - auto step_data = step_vector ? ConstantVector::GetData(*step_vector) : nullptr; - - auto str = str_data[0]; - auto begin = begin_is_empty ? 0 : begin_data[0]; - auto end = end_is_empty ? OP::ValueLength(str) : end_data[0]; - auto step = step_data ? step_data[0] : 1; - - if (step < 0) { - swap(begin, end); - begin = end_is_empty ? 0 : begin; - end = begin_is_empty ? 
OP::ValueLength(str) : end; - } - - // Clamp offsets - bool clamp_result = false; - if (step_valid || step == 1) { - clamp_result = ClampSlice(str, begin, end); - } - - idx_t sel_length = 0; - bool sel_valid = false; - if (step_valid && step != 1 && end - begin > 0) { - sel_length = - CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); - sel.Initialize(sel_length); - sel_valid = true; - } - - // Try to slice - if (!clamp_result) { - ConstantVector::SetNull(result, true); - } else if (step == 1) { - result_data[0] = OP::SliceValue(result, str, begin, end); - } else { - result_data[0] = OP::SliceValueWithSteps(result, sel, str, begin, end, step, sel_idx); - } - - if (sel_valid) { - result_child_vector->Slice(sel, sel_length); - result_child_vector->Flatten(sel_length); - ListVector::SetListSize(result, sel_length); - } -} - -template -static void ExecuteFlatSlice(Vector &result, Vector &list_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr step_vector, const idx_t count, SelectionVector &sel, idx_t &sel_idx, - optional_ptr result_child_vector, bool begin_is_empty, bool end_is_empty) { - UnifiedVectorFormat list_data, begin_data, end_data, step_data; - idx_t sel_length = 0; - - list_vector.ToUnifiedFormat(count, list_data); - begin_vector.ToUnifiedFormat(count, begin_data); - end_vector.ToUnifiedFormat(count, end_data); - if (step_vector) { - step_vector->ToUnifiedFormat(count, step_data); - sel.Initialize(ListVector::GetListSize(list_vector)); - } - - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = 0; i < count; ++i) { - auto list_idx = list_data.sel->get_index(i); - auto begin_idx = begin_data.sel->get_index(i); - auto end_idx = end_data.sel->get_index(i); - auto step_idx = step_vector ? step_data.sel->get_index(i) : 0; - - auto list_valid = list_data.validity.RowIsValid(list_idx); - auto begin_valid = begin_data.validity.RowIsValid(begin_idx); - auto end_valid = end_data.validity.RowIsValid(end_idx); - auto step_valid = step_vector && step_data.validity.RowIsValid(step_idx); - - if (!list_valid || !begin_valid || !end_valid || (step_vector && !step_valid)) { - result_mask.SetInvalid(i); - continue; - } - - auto sliced = reinterpret_cast(list_data.data)[list_idx]; - auto begin = begin_is_empty ? 0 : reinterpret_cast(begin_data.data)[begin_idx]; - auto end = end_is_empty ? OP::ValueLength(sliced) : reinterpret_cast(end_data.data)[end_idx]; - auto step = step_vector ? reinterpret_cast(step_data.data)[step_idx] : 1; - - if (step < 0) { - swap(begin, end); - begin = end_is_empty ? 0 : begin; - end = begin_is_empty ? 
OP::ValueLength(sliced) : end; - } - - bool clamp_result = false; - if (step_valid || step == 1) { - clamp_result = ClampSlice(sliced, begin, end); - } - - idx_t length = 0; - if (end - begin > 0) { - length = - CalculateSliceLength(UnsafeNumericCast(begin), UnsafeNumericCast(end), step, step_valid); - } - sel_length += length; - - if (!clamp_result) { - result_mask.SetInvalid(i); - } else if (!step_vector) { - result_data[i] = OP::SliceValue(result, sliced, begin, end); - } else { - result_data[i] = OP::SliceValueWithSteps(result, sel, sliced, begin, end, step, sel_idx); - } - } - if (step_vector) { - SelectionVector new_sel(sel_length); - for (idx_t i = 0; i < sel_length; ++i) { - new_sel.set_index(i, sel.get_index(i)); - } - result_child_vector->Slice(new_sel, sel_length); - result_child_vector->Flatten(sel_length); - ListVector::SetListSize(result, sel_length); - } -} - -template -static void ExecuteSlice(Vector &result, Vector &list_or_str_vector, Vector &begin_vector, Vector &end_vector, - optional_ptr step_vector, const idx_t count, bool begin_is_empty, bool end_is_empty) { - optional_ptr result_child_vector; - if (step_vector) { - result_child_vector = &ListVector::GetEntry(result); - } - - SelectionVector sel; - idx_t sel_idx = 0; - - if (result.GetVectorType() == VectorType::CONSTANT_VECTOR) { - ExecuteConstantSlice(result, list_or_str_vector, begin_vector, end_vector, - step_vector, count, sel, sel_idx, result_child_vector, - begin_is_empty, end_is_empty); - } else { - ExecuteFlatSlice(result, list_or_str_vector, begin_vector, end_vector, step_vector, - count, sel, sel_idx, result_child_vector, begin_is_empty, - end_is_empty); - } - result.Verify(count); -} - -static void ArraySliceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); - D_ASSERT(args.data.size() == 3 || args.data.size() == 4); - auto count = args.size(); - - Vector &list_or_str_vector = result; - // this ensures that we do not change the input chunk - VectorOperations::Copy(args.data[0], list_or_str_vector, count, 0, 0); - - if (list_or_str_vector.GetType().id() == LogicalTypeId::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - Vector &begin_vector = args.data[1]; - Vector &end_vector = args.data[2]; - - optional_ptr step_vector; - if (args.ColumnCount() == 4) { - step_vector = &args.data[3]; - } - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto begin_is_empty = info.begin_is_empty; - auto end_is_empty = info.end_is_empty; - - result.SetVectorType(args.AllConstant() ? 
VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); - switch (result.GetType().id()) { - case LogicalTypeId::LIST: { - // Share the value dictionary as we are just going to slice it - if (list_or_str_vector.GetVectorType() != VectorType::FLAT_VECTOR && - list_or_str_vector.GetVectorType() != VectorType::CONSTANT_VECTOR) { - list_or_str_vector.Flatten(count); - } - ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, - step_vector, count, begin_is_empty, end_is_empty); - break; - } - case LogicalTypeId::BLOB: - ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, - step_vector, count, begin_is_empty, end_is_empty); - break; - case LogicalTypeId::VARCHAR: - ExecuteSlice(result, list_or_str_vector, begin_vector, end_vector, - step_vector, count, begin_is_empty, end_is_empty); - break; - default: - throw NotImplementedException("Specifier type not implemented"); - } -} - -static bool CheckIfParamIsEmpty(duckdb::unique_ptr &param) { - bool is_empty = false; - if (param->return_type.id() == LogicalTypeId::LIST) { - auto empty_list = make_uniq(Value::LIST(LogicalType::INTEGER, vector())); - is_empty = param->Equals(*empty_list); - if (!is_empty) { - // if the param is not empty, the user has entered a list instead of a BIGINT - throw BinderException("The upper and lower bounds of the slice must be a BIGINT"); - } - } - return is_empty; -} - -static unique_ptr ArraySliceBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() == 3 || arguments.size() == 4); - D_ASSERT(bound_function.arguments.size() == 3 || bound_function.arguments.size() == 4); - - switch (arguments[0]->return_type.id()) { - case LogicalTypeId::ARRAY: { - // Cast to list - auto child_type = ArrayType::GetChildType(arguments[0]->return_type); - auto target_type = LogicalType::LIST(child_type); - arguments[0] = BoundCastExpression::AddCastToType(context, std::move(arguments[0]), target_type); - bound_function.return_type = arguments[0]->return_type; - } break; - case LogicalTypeId::LIST: - // The result is the same type - bound_function.return_type = arguments[0]->return_type; - break; - case LogicalTypeId::BLOB: - case LogicalTypeId::VARCHAR: - // string slice returns a string - if (bound_function.arguments.size() == 4) { - throw NotImplementedException( - "Slice with steps has not been implemented for string types, you can consider rewriting your query as " - "follows:\n SELECT array_to_string((str_split(string, ''))[begin:end:step], '');"); - } - bound_function.return_type = arguments[0]->return_type; - for (idx_t i = 1; i < 3; i++) { - if (arguments[i]->return_type.id() != LogicalTypeId::LIST) { - bound_function.arguments[i] = LogicalType::BIGINT; - } - } - break; - case LogicalTypeId::SQLNULL: - case LogicalTypeId::UNKNOWN: - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - break; - default: - throw BinderException("ARRAY_SLICE can only operate on LISTs and VARCHARs"); - } - - bool begin_is_empty = CheckIfParamIsEmpty(arguments[1]); - if (!begin_is_empty) { - bound_function.arguments[1] = LogicalType::BIGINT; - } - bool end_is_empty = CheckIfParamIsEmpty(arguments[2]); - if (!end_is_empty) { - bound_function.arguments[2] = LogicalType::BIGINT; - } - - return make_uniq(bound_function.return_type, begin_is_empty, end_is_empty); -} - -ScalarFunctionSet ListSliceFun::GetFunctions() { - // the arguments and return types are actually set in the binder function - ScalarFunction
fun({LogicalType::ANY, LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, ArraySliceFunction, - ArraySliceBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); - ScalarFunctionSet set; - set.AddFunction(fun); - fun.arguments.push_back(LogicalType::BIGINT); - set.AddFunction(fun); - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp b/src/duckdb/extension/core_functions/scalar/list/flatten.cpp deleted file mode 100644 index 849c20d16..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/flatten.cpp +++ /dev/null @@ -1,171 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/storage/statistics/list_stats.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -void ListFlattenFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - Vector &input = args.data[0]; - if (input.GetType().id() == LogicalTypeId::SQLNULL) { - result.Reference(input); - return; - } - - idx_t count = args.size(); - - // Prepare the result vector - result.SetVectorType(VectorType::FLAT_VECTOR); - // This holds the new offsets and lengths - auto result_entries = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - // The outermost list in each row - UnifiedVectorFormat row_data; - input.ToUnifiedFormat(count, row_data); - auto row_entries = UnifiedVectorFormat::GetData(row_data); - - // The list elements in each row: [HERE, ...] - auto &row_lists = ListVector::GetEntry(input); - UnifiedVectorFormat row_lists_data; - idx_t total_row_lists = ListVector::GetListSize(input); - row_lists.ToUnifiedFormat(total_row_lists, row_lists_data); - auto row_lists_entries = UnifiedVectorFormat::GetData(row_lists_data); - - if (row_lists.GetType().id() == LogicalTypeId::SQLNULL) { - for (idx_t row_cnt = 0; row_cnt < count; row_cnt++) { - auto row_idx = row_data.sel->get_index(row_cnt); - if (!row_data.validity.RowIsValid(row_idx)) { - result_validity.SetInvalid(row_cnt); - continue; - } - result_entries[row_cnt].offset = 0; - result_entries[row_cnt].length = 0; - } - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return; - } - - // The actual elements inside each row list: [[HERE, ...], []] - // This one becomes the child vector of the result. - auto &elem_vector = ListVector::GetEntry(row_lists); - - // We'll use this selection vector to slice the elem_vector. - idx_t child_elem_cnt = ListVector::GetListSize(row_lists); - SelectionVector sel(child_elem_cnt); - idx_t sel_idx = 0; - - // HERE, [[]], ... - for (idx_t row_cnt = 0; row_cnt < count; row_cnt++) { - auto row_idx = row_data.sel->get_index(row_cnt); - - if (!row_data.validity.RowIsValid(row_idx)) { - result_validity.SetInvalid(row_cnt); - continue; - } - - idx_t list_offset = sel_idx; - idx_t list_length = 0; - - // [HERE, [...], ...] 
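
The loop this sits in never copies list elements: it only records, for each element that is kept, an index into the grandchild buffer, and the selection is applied with a single Slice call after the loop. A distilled version of that gather-by-selection technique, using simplified types rather than DuckDB's vector classes:

    #include <cstddef>
    #include <vector>

    struct ListEntry {
        size_t offset;
        size_t length;
    };

    // Collect gather indices for every element of every inner list; the caller
    // then presents the original element buffer through this index list
    // (DuckDB does the equivalent with a SelectionVector plus Vector::Slice).
    static std::vector<size_t> BuildFlattenSelection(const std::vector<ListEntry> &inner_lists) {
        std::vector<size_t> sel;
        for (auto &entry : inner_lists) {
            for (size_t i = 0; i < entry.length; i++) {
                sel.push_back(entry.offset + i);
            }
        }
        return sel;
    }

    int main() {
        // two inner lists over buf[3..4] and buf[0..1] flatten to indices 3, 4, 0, 1
        auto sel = BuildFlattenSelection({{3, 2}, {0, 2}});
        return sel == std::vector<size_t> {3, 4, 0, 1} ? 0 : 1;
    }
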
- auto row_entry = row_entries[row_idx]; - for (idx_t row_lists_cnt = 0; row_lists_cnt < row_entry.length; row_lists_cnt++) { - auto row_lists_idx = row_lists_data.sel->get_index(row_entry.offset + row_lists_cnt); - - // Skip invalid lists - if (!row_lists_data.validity.RowIsValid(row_lists_idx)) { - continue; - } - - // [[HERE, ...], [.., ...]] - auto list_entry = row_lists_entries[row_lists_idx]; - list_length += list_entry.length; - - for (idx_t elem_cnt = 0; elem_cnt < list_entry.length; elem_cnt++) { - // offset of the element in the elem_vector. - idx_t offset = list_entry.offset + elem_cnt; - sel.set_index(sel_idx, offset); - sel_idx++; - } - } - - result_entries[row_cnt].offset = list_offset; - result_entries[row_cnt].length = list_length; - } - - ListVector::SetListSize(result, sel_idx); - - auto &result_child_vector = ListVector::GetEntry(result); - result_child_vector.Slice(elem_vector, sel, sel_idx); - result_child_vector.Flatten(sel_idx); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListFlattenBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 1); - - if (arguments[0]->return_type.id() == LogicalTypeId::ARRAY) { - auto child_type = ArrayType::GetChildType(arguments[0]->return_type); - if (child_type.id() == LogicalTypeId::ARRAY) { - child_type = LogicalType::LIST(ArrayType::GetChildType(child_type)); - } - arguments[0] = - BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type)); - } else if (arguments[0]->return_type.id() == LogicalTypeId::LIST) { - auto child_type = ListType::GetChildType(arguments[0]->return_type); - if (child_type.id() == LogicalTypeId::ARRAY) { - child_type = LogicalType::LIST(ArrayType::GetChildType(child_type)); - arguments[0] = - BoundCastExpression::AddCastToType(context, std::move(arguments[0]), LogicalType::LIST(child_type)); - } - } - - auto &input_type = arguments[0]->return_type; - bound_function.arguments[0] = input_type; - if (input_type.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - D_ASSERT(input_type.id() == LogicalTypeId::LIST); - - auto child_type = ListType::GetChildType(input_type); - if (child_type.id() == LogicalType::SQLNULL) { - bound_function.return_type = input_type; - return make_uniq(bound_function.return_type); - } - if (child_type.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalType(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - D_ASSERT(child_type.id() == LogicalTypeId::LIST); - - bound_function.return_type = child_type; - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListFlattenStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); - auto child_copy = list_child_stats.Copy(); - child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); - return child_copy.ToUnique(); -} - -ScalarFunction ListFlattenFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::LIST(LogicalType::ANY))}, LogicalType::LIST(LogicalType::ANY), - ListFlattenFunction, ListFlattenBind, nullptr, ListFlattenStats); -} - -} // namespace duckdb diff --git 
a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp b/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp deleted file mode 100644 index 1b2aab71d..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_aggregates.cpp +++ /dev/null @@ -1,534 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "core_functions/aggregate/nested_functions.hpp" -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/function/create_sort_key.hpp" -#include "duckdb/common/owning_string_map.hpp" - -namespace duckdb { - -// FIXME: use a local state for each thread to increase performance? -// FIXME: benchmark the use of simple_update against using update (if applicable) - -static unique_ptr ListAggregatesBindFailure(ScalarFunction &bound_function) { - bound_function.arguments[0] = LogicalType::SQLNULL; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(LogicalType::SQLNULL); -} - -struct ListAggregatesBindData : public FunctionData { - ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p); - ~ListAggregatesBindData() override; - - LogicalType stype; - unique_ptr aggr_expr; - - unique_ptr Copy() const override { - return make_uniq(stype, aggr_expr->Copy()); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return stype == other.stype && aggr_expr->Equals(*other.aggr_expr); - } - void Serialize(Serializer &serializer) const { - serializer.WriteProperty(1, "stype", stype); - serializer.WriteProperty(2, "aggr_expr", aggr_expr); - } - static unique_ptr Deserialize(Deserializer &deserializer) { - auto stype = deserializer.ReadProperty(1, "stype"); - auto aggr_expr = deserializer.ReadProperty>(2, "aggr_expr"); - auto result = make_uniq(std::move(stype), std::move(aggr_expr)); - return result; - } - - static void SerializeFunction(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - auto bind_data = dynamic_cast(bind_data_p.get()); - serializer.WritePropertyWithDefault(100, "bind_data", bind_data, (const ListAggregatesBindData *)nullptr); - } - - static unique_ptr DeserializeFunction(Deserializer &deserializer, ScalarFunction &bound_function) { - auto result = deserializer.ReadPropertyWithExplicitDefault>( - 100, "bind_data", unique_ptr(nullptr)); - if (!result) { - return ListAggregatesBindFailure(bound_function); - } - return std::move(result); - } -}; - -ListAggregatesBindData::ListAggregatesBindData(const LogicalType &stype_p, unique_ptr aggr_expr_p) - : stype(stype_p), aggr_expr(std::move(aggr_expr_p)) { -} - -ListAggregatesBindData::~ListAggregatesBindData() { -} - -struct StateVector { - StateVector(idx_t count_p, unique_ptr aggr_expr_p) - : count(count_p), aggr_expr(std::move(aggr_expr_p)), state_vector(Vector(LogicalType::POINTER, count_p)) { - } - - ~StateVector() { // NOLINT - // destroy objects within the aggregate states - auto &aggr = aggr_expr->Cast(); - if 
(aggr.function.destructor) { - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.destructor(state_vector, aggr_input_data, count); - } - } - - idx_t count; - unique_ptr aggr_expr; - Vector state_vector; -}; - -struct FinalizeValueFunctor { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - FlatVector::GetData(result)[offset] = value; - } -}; - -struct FinalizeStringValueFunctor { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - FlatVector::GetData(result)[offset] = StringVector::AddStringOrBlob(result, value); - } -}; - -struct FinalizeGenericValueFunctor { - template - static void HistogramFinalize(T value, Vector &result, idx_t offset) { - CreateSortKeyHelpers::DecodeSortKey(value, result, offset, - OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); - } -}; - -struct AggregateFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - } -}; - -struct DistinctFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData *>(sdata); - - auto old_len = ListVector::GetListSize(result); - idx_t new_entries = 0; - // figure out how much space we need - for (idx_t i = 0; i < count; i++) { - auto &state = *states[sdata.sel->get_index(i)]; - if (!state.hist) { - continue; - } - new_entries += state.hist->size(); - } - // reserve space in the list vector - ListVector::Reserve(result, old_len + new_entries); - auto &child_elements = ListVector::GetEntry(result); - auto list_entries = FlatVector::GetData(result); - - idx_t current_offset = old_len; - for (idx_t i = 0; i < count; i++) { - const auto rid = i; - auto &state = *states[sdata.sel->get_index(i)]; - auto &list_entry = list_entries[rid]; - list_entry.offset = current_offset; - if (!state.hist) { - list_entry.length = 0; - continue; - } - - for (auto &entry : *state.hist) { - OP::template HistogramFinalize(entry.first, child_elements, current_offset); - current_offset++; - } - list_entry.length = current_offset - list_entry.offset; - } - D_ASSERT(current_offset == old_len + new_entries); - ListVector::SetListSize(result, current_offset); - result.Verify(count); - } -}; - -struct UniqueFunctor { - template > - static void ListExecuteFunction(Vector &result, Vector &state_vector, idx_t count) { - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - auto states = UnifiedVectorFormat::GetData *>(sdata); - - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - - auto state = states[sdata.sel->get_index(i)]; - - if (!state->hist) { - result_data[i] = 0; - continue; - } - result_data[i] = state->hist->size(); - } - result.Verify(count); - } -}; - -template -static void ListAggregatesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - Vector &lists = args.data[0]; - - // set the result vector - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_validity = FlatVector::Validity(result); - - if (lists.GetType().id() == LogicalTypeId::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - // get the aggregate function - auto &func_expr = state.expr.Cast(); - auto &info = 
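// The finalize functors above emit results through DuckDB's list layout: one
// flat child array shared by all rows plus a {offset, length} entry per row,
// with the offset advancing as entries are appended. The same packing over
// plain STL containers (sketch):

#include <cstddef>
#include <vector>

struct SketchListEntry {
	std::size_t offset;
	std::size_t length;
};

static void PackLists(const std::vector<std::vector<int>> &rows, std::vector<int> &child,
                      std::vector<SketchListEntry> &entries) {
	for (const auto &row : rows) {
		entries.push_back({child.size(), row.size()}); // offset = space consumed so far
		child.insert(child.end(), row.begin(), row.end());
	}
}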
func_expr.bind_info->Cast(); - auto &aggr = info.aggr_expr->Cast(); - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - - D_ASSERT(aggr.function.update); - - auto lists_size = ListVector::GetListSize(lists); - auto &child_vector = ListVector::GetEntry(lists); - child_vector.Flatten(lists_size); - - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(lists_size, child_data); - - UnifiedVectorFormat lists_data; - lists.ToUnifiedFormat(count, lists_data); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - // state_buffer holds the state for each list of this chunk - idx_t size = aggr.function.state_size(aggr.function); - auto state_buffer = make_unsafe_uniq_array_uninitialized(size * count); - - // state vector for initialize and finalize - StateVector state_vector(count, info.aggr_expr->Copy()); - auto states = FlatVector::GetData(state_vector.state_vector); - - // state vector of STANDARD_VECTOR_SIZE holds the pointers to the states - Vector state_vector_update = Vector(LogicalType::POINTER); - auto states_update = FlatVector::GetData(state_vector_update); - - // selection vector pointing to the data - SelectionVector sel_vector(STANDARD_VECTOR_SIZE); - idx_t states_idx = 0; - - for (idx_t i = 0; i < count; i++) { - - // initialize the state for this list - auto state_ptr = state_buffer.get() + size * i; - states[i] = state_ptr; - aggr.function.initialize(aggr.function, states[i]); - - auto lists_index = lists_data.sel->get_index(i); - const auto &list_entry = list_entries[lists_index]; - - // nothing to do for this list - if (!lists_data.validity.RowIsValid(lists_index)) { - result_validity.SetInvalid(i); - continue; - } - - // skip empty list - if (list_entry.length == 0) { - continue; - } - - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // states vector is full, update - if (states_idx == STANDARD_VECTOR_SIZE) { - // update the aggregate state(s) - Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); - - // reset values - states_idx = 0; - } - - auto source_idx = child_data.sel->get_index(list_entry.offset + child_idx); - sel_vector.set_index(states_idx, source_idx); - states_update[states_idx] = state_ptr; - states_idx++; - } - } - - // update the remaining elements of the last list(s) - if (states_idx != 0) { - Vector slice(child_vector, sel_vector, states_idx); - aggr.function.update(&slice, aggr_input_data, 1, state_vector_update, states_idx); - } - - if (IS_AGGR) { - // finalize all the aggregate states - aggr.function.finalize(state_vector.state_vector, aggr_input_data, result, count, 0); - - } else { - // finalize manually to use the map - D_ASSERT(aggr.function.arguments.size() == 1); - auto key_type = aggr.function.arguments[0]; - - switch (key_type.InternalType()) { -#ifndef DUCKDB_SMALLER_BINARY - case PhysicalType::BOOL: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT8: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT16: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT32: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::UINT64: - 
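// The update loop above never hands the aggregate more than one vector's worth
// of rows: child-element indices accumulate in a selection vector and are
// flushed through aggr.function.update whenever STANDARD_VECTOR_SIZE entries
// (2048 by default) pile up, with a final flush for the remainder. The shape of
// that batching, reduced to STL (sketch; flush stands in for the update call):

#include <cstddef>
#include <vector>

template <class FLUSH>
static void BatchedGather(const std::vector<std::size_t> &child_indices, std::size_t batch_size, FLUSH &&flush) {
	std::vector<std::size_t> batch;
	batch.reserve(batch_size);
	for (auto idx : child_indices) {
		if (batch.size() == batch_size) {
			flush(batch); // batch is full: aggregate the gathered indices
			batch.clear();
		}
		batch.push_back(idx);
	}
	if (!batch.empty()) {
		flush(batch); // remainder of the last list(s)
	}
}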
FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT8: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT16: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT32: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::INT64: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::FLOAT: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::DOUBLE: - FUNCTION_FUNCTOR::template ListExecuteFunction( - result, state_vector.state_vector, count); - break; - case PhysicalType::VARCHAR: - FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, - count); - break; -#endif - default: - FUNCTION_FUNCTOR::template ListExecuteFunction>(result, state_vector.state_vector, - count); - break; - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void ListAggregateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() >= 2); - ListAggregatesFunction(args, state, result); -} - -static void ListDistinctFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - ListAggregatesFunction(args, state, result); -} - -static void ListUniqueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - ListAggregatesFunction(args, state, result); -} - -template -static unique_ptr -ListAggregatesBindFunction(ClientContext &context, ScalarFunction &bound_function, const LogicalType &list_child_type, - AggregateFunction &aggr_function, vector> &arguments) { - - // create the child expression and its type - vector> children; - auto expr = make_uniq(Value(list_child_type)); - children.push_back(std::move(expr)); - // push any extra arguments into the list aggregate bind - if (arguments.size() > 2) { - for (idx_t i = 2; i < arguments.size(); i++) { - children.push_back(std::move(arguments[i])); - } - arguments.resize(2); - } - - FunctionBinder function_binder(context); - auto bound_aggr_function = function_binder.BindAggregateFunction(aggr_function, std::move(children)); - bound_function.arguments[0] = LogicalType::LIST(bound_aggr_function->function.arguments[0]); - - if (IS_AGGR) { - bound_function.return_type = bound_aggr_function->function.return_type; - } - // check if the aggregate function consumed all the extra input arguments - if (bound_aggr_function->children.size() > 1) { - throw InvalidInputException( - "Aggregate function %s is not supported for list_aggr: extra arguments were not removed during bind", - bound_aggr_function->ToString()); - } - - return make_uniq(bound_function.return_type, std::move(bound_aggr_function)); -} - -template -static unique_ptr ListAggregatesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - if (arguments[0]->return_type.id() == LogicalTypeId::SQLNULL) { - return ListAggregatesBindFailure(bound_function); - } - - bool is_parameter = arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN; - LogicalType 
child_type; - if (is_parameter) { - child_type = LogicalType::ANY; - } else if (arguments[0]->return_type.id() == LogicalTypeId::LIST || - arguments[0]->return_type.id() == LogicalTypeId::MAP) { - child_type = ListType::GetChildType(arguments[0]->return_type); - } else { - // Unreachable - throw InvalidInputException("First argument of list aggregate must be a list, map or array"); - } - - string function_name = "histogram"; - if (IS_AGGR) { // get the name of the aggregate function - if (!arguments[1]->IsFoldable()) { - throw InvalidInputException("Aggregate function name must be a constant"); - } - // get the function name - Value function_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - function_name = function_value.ToString(); - } - - // look up the aggregate function in the catalog - auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, DEFAULT_SCHEMA, - function_name); - D_ASSERT(func.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - - if (is_parameter) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - return nullptr; - } - - // find a matching aggregate function - ErrorData error; - vector types; - types.push_back(child_type); - // push any extra arguments into the type list - for (idx_t i = 2; i < arguments.size(); i++) { - types.push_back(arguments[i]->return_type); - } - - FunctionBinder function_binder(context); - auto best_function_idx = function_binder.BindFunction(func.name, func.functions, types, error); - if (!best_function_idx.IsValid()) { - throw BinderException("No matching aggregate function\n%s", error.Message()); - } - - // found a matching function, bind it as an aggregate - auto best_function = func.functions.GetFunctionByOffset(best_function_idx.GetIndex()); - if (IS_AGGR) { - bound_function.errors = best_function.errors; - return ListAggregatesBindFunction(context, bound_function, child_type, best_function, arguments); - } - - // create the unordered map histogram function - D_ASSERT(best_function.arguments.size() == 1); - auto aggr_function = HistogramFun::GetHistogramUnorderedMap(child_type); - return ListAggregatesBindFunction(context, bound_function, child_type, aggr_function, arguments); -} - -static unique_ptr ListAggregateBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // the list column and the name of the aggregate function - D_ASSERT(bound_function.arguments.size() >= 2); - D_ASSERT(arguments.size() >= 2); - - return ListAggregatesBind(context, bound_function, arguments); -} - -static unique_ptr ListDistinctBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - D_ASSERT(bound_function.arguments.size() == 1); - D_ASSERT(arguments.size() == 1); - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - bound_function.return_type = arguments[0]->return_type; - - return ListAggregatesBind<>(context, bound_function, arguments); -} - -static unique_ptr ListUniqueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - D_ASSERT(bound_function.arguments.size() == 1); - D_ASSERT(arguments.size() == 1); - bound_function.return_type = LogicalType::UBIGINT; - - return ListAggregatesBind<>(context, bound_function, arguments); -} - -ScalarFunction ListAggregateFun::GetFunction() { - auto result = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, LogicalType::ANY, - ListAggregateFunction, 
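// At the SQL level this bind resolves the aggregate named by the second
// argument against the list's child type; list_aggregate, list_distinct and
// list_unique then differ only in how the states are finalized. Expected
// behaviour, assuming standard DuckDB semantics for these functions:
//
//   list_aggregate([2, 4, 8, 42], 'sum')  -- 56
//   list_aggregate([1, 2, NULL], 'min')   -- 1  (NULL is ignored by the aggregate)
//   list_distinct([1, 1, NULL, -3])       -- [1, -3]  (duplicates and NULLs removed)
//   list_unique([1, 1, NULL, -3])         -- 2  (count of distinct non-NULL values)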
ListAggregateBind); - BaseScalarFunction::SetReturnsError(result); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.varargs = LogicalType::ANY; - result.serialize = ListAggregatesBindData::SerializeFunction; - result.deserialize = ListAggregatesBindData::DeserializeFunction; - return result; -} - -ScalarFunction ListDistinctFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), - ListDistinctFunction, ListDistinctBind); -} - -ScalarFunction ListUniqueFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::UBIGINT, ListUniqueFunction, - ListUniqueBind); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp b/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp deleted file mode 100644 index 5c3513b2a..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_distance.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "core_functions/array_kernels.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -//------------------------------------------------------------------------------ -// Generic "fold" function -//------------------------------------------------------------------------------ -// Given two lists of the same size, combine and reduce their elements into a -// single scalar value. - -template -static void ListGenericFold(DataChunk &args, ExpressionState &state, Vector &result) { - const auto &lstate = state.Cast(); - const auto &expr = lstate.expr.Cast(); - const auto &func_name = expr.function.name; - - auto count = args.size(); - - auto &lhs_vec = args.data[0]; - auto &rhs_vec = args.data[1]; - - const auto lhs_count = ListVector::GetListSize(lhs_vec); - const auto rhs_count = ListVector::GetListSize(rhs_vec); - - auto &lhs_child = ListVector::GetEntry(lhs_vec); - auto &rhs_child = ListVector::GetEntry(rhs_vec); - - lhs_child.Flatten(lhs_count); - rhs_child.Flatten(rhs_count); - - D_ASSERT(lhs_child.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rhs_child.GetVectorType() == VectorType::FLAT_VECTOR); - - if (!FlatVector::Validity(lhs_child).CheckAllValid(lhs_count)) { - throw InvalidInputException("%s: left argument can not contain NULL values", func_name); - } - - if (!FlatVector::Validity(rhs_child).CheckAllValid(rhs_count)) { - throw InvalidInputException("%s: right argument can not contain NULL values", func_name); - } - - auto lhs_data = FlatVector::GetData(lhs_child); - auto rhs_data = FlatVector::GetData(rhs_child); - - BinaryExecutor::ExecuteWithNulls( - lhs_vec, rhs_vec, result, count, - [&](const list_entry_t &left, const list_entry_t &right, ValidityMask &mask, idx_t row_idx) { - if (left.length != right.length) { - throw InvalidInputException( - "%s: list dimensions must be equal, got left length '%d' and right length '%d'", func_name, - left.length, right.length); - } - - if (!OP::ALLOW_EMPTY && left.length == 0) { - mask.SetInvalid(row_idx); - return TYPE(); - } - - return OP::Operation(lhs_data + left.offset, rhs_data + right.offset, left.length); - }); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -//------------------------------------------------------------------------- -// Function Registration -//------------------------------------------------------------------------- - -template -static void 
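// Each kernel passed to ListGenericFold above exposes a static Operation over
// two equal-length spans plus an ALLOW_EMPTY flag that decides whether empty
// lists produce a value or NULL. A hedged sketch of such a kernel; the real
// ones live in core_functions/array_kernels.hpp, and this Euclidean-distance
// version is illustrative only:

#include <cmath>

struct SketchDistanceOp {
	static constexpr bool ALLOW_EMPTY = true; // this sketch maps two empty lists to 0, not NULL

	template <class TYPE>
	static TYPE Operation(const TYPE *lhs, const TYPE *rhs, idx_t count) {
		TYPE sum = 0;
		for (idx_t i = 0; i < count; i++) {
			const TYPE diff = lhs[i] - rhs[i];
			sum += diff * diff; // accumulate squared differences
		}
		return static_cast<TYPE>(std::sqrt(sum));
	}
};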
AddListFoldFunction(ScalarFunctionSet &set, const LogicalType &type) { - const auto list = LogicalType::LIST(type); - if (type.id() == LogicalTypeId::FLOAT) { - set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold)); - } else if (type.id() == LogicalTypeId::DOUBLE) { - set.AddFunction(ScalarFunction({list, list}, type, ListGenericFold)); - } else { - throw NotImplementedException("List function not implemented for type %s", type.ToString()); - } -} - -ScalarFunctionSet ListDistanceFun::GetFunctions() { - ScalarFunctionSet set("list_distance"); - for (auto &type : LogicalType::Real()) { - AddListFoldFunction(set, type); - } - for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return set; -} - -ScalarFunctionSet ListInnerProductFun::GetFunctions() { - ScalarFunctionSet set("list_inner_product"); - for (auto &type : LogicalType::Real()) { - AddListFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ListNegativeInnerProductFun::GetFunctions() { - ScalarFunctionSet set("list_negative_inner_product"); - for (auto &type : LogicalType::Real()) { - AddListFoldFunction(set, type); - } - return set; -} - -ScalarFunctionSet ListCosineSimilarityFun::GetFunctions() { - ScalarFunctionSet set("list_cosine_similarity"); - for (auto &type : LogicalType::Real()) { - AddListFoldFunction(set, type); - } - for (auto &func : set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return set; -} - -ScalarFunctionSet ListCosineDistanceFun::GetFunctions() { - ScalarFunctionSet set("list_cosine_distance"); - for (auto &type : LogicalType::Real()) { - AddListFoldFunction(set, type); - } - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp b/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp deleted file mode 100644 index 30ac79db1..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_filter.cpp +++ /dev/null @@ -1,49 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" - -#include "duckdb/function/lambda_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -static unique_ptr ListFilterBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // the list column and the bound lambda expression - D_ASSERT(arguments.size() == 2); - if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { - throw BinderException("Invalid lambda expression!"); - } - - auto &bound_lambda_expr = arguments[1]->Cast(); - - // try to cast to boolean, if the return type of the lambda filter expression is not already boolean - if (bound_lambda_expr.lambda_expr->return_type != LogicalType::BOOLEAN) { - auto cast_lambda_expr = - BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), LogicalType::BOOLEAN); - bound_lambda_expr.lambda_expr = std::move(cast_lambda_expr); - } - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - bound_function.return_type = arguments[0]->return_type; - auto has_index = bound_lambda_expr.parameter_count == 2; - return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); -} - -static LogicalType ListFilterBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { - return LambdaFunctions::BindBinaryLambda(parameter_idx, list_child_type); -} - -ScalarFunction ListFilterFun::GetFunction() { - ScalarFunction 
fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), - LambdaFunctions::ListFilterFunction, ListFilterBind, nullptr, nullptr); - - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListFilterBindLambda; - - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp b/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp deleted file mode 100644 index dd15edc93..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_has_any_or_all.cpp +++ /dev/null @@ -1,227 +0,0 @@ -#include "duckdb/function/lambda_functions.hpp" -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/function/create_sort_key.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/common/string_map_set.hpp" - -namespace duckdb { - -static unique_ptr ListHasAnyOrAllBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - arguments[1] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[1])); - - const auto lhs_is_param = arguments[0]->HasParameter(); - const auto rhs_is_param = arguments[1]->HasParameter(); - - if (lhs_is_param && rhs_is_param) { - throw ParameterNotResolvedException(); - } - - const auto &lhs_list = arguments[0]->return_type; - const auto &rhs_list = arguments[1]->return_type; - - if (lhs_is_param) { - bound_function.arguments[0] = rhs_list; - bound_function.arguments[1] = rhs_list; - return nullptr; - } - if (rhs_is_param) { - bound_function.arguments[0] = lhs_list; - bound_function.arguments[1] = lhs_list; - return nullptr; - } - - bound_function.arguments[0] = lhs_list; - bound_function.arguments[1] = rhs_list; - - const auto &lhs_child = ListType::GetChildType(bound_function.arguments[0]); - const auto &rhs_child = ListType::GetChildType(bound_function.arguments[1]); - - if (lhs_child != LogicalType::SQLNULL && rhs_child != LogicalType::SQLNULL && lhs_child != rhs_child) { - LogicalType common_child; - if (!LogicalType::TryGetMaxLogicalType(context, lhs_child, rhs_child, common_child)) { - throw BinderException("'%s' cannot compare lists of different types: '%s' and '%s'", bound_function.name, - lhs_child.ToString(), rhs_child.ToString()); - } - bound_function.arguments[0] = LogicalType::LIST(common_child); - bound_function.arguments[1] = LogicalType::LIST(common_child); - } - - return nullptr; -} - -static void ListHasAnyFunction(DataChunk &args, ExpressionState &, Vector &result) { - - auto &l_vec = args.data[0]; - auto &r_vec = args.data[1]; - - if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL || - ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(result)[0] = false; - return; - } - - const auto l_size = ListVector::GetListSize(l_vec); - const auto r_size = ListVector::GetListSize(r_vec); - - auto &l_child = ListVector::GetEntry(l_vec); - auto &r_child = ListVector::GetEntry(r_vec); - - // Setup unified formats for the list elements - UnifiedVectorFormat l_child_format; - UnifiedVectorFormat r_child_format; - - l_child.ToUnifiedFormat(l_size, l_child_format); - r_child.ToUnifiedFormat(r_size, r_child_format); - - // Create the 
sort keys for the list elements - Vector l_sortkey_vec(LogicalType::BLOB, l_size); - Vector r_sortkey_vec(LogicalType::BLOB, r_size); - - const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - - CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); - CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); - - const auto l_sortkey_ptr = FlatVector::GetData(l_sortkey_vec); - const auto r_sortkey_ptr = FlatVector::GetData(r_sortkey_vec); - - string_set_t set; - - BinaryExecutor::Execute( - l_vec, r_vec, result, args.size(), [&](const list_entry_t &l_list, const list_entry_t &r_list) { - // Short circuit if either list is empty - if (l_list.length == 0 || r_list.length == 0) { - return false; - } - - auto build_list = l_list; - auto probe_list = r_list; - - auto build_data = l_sortkey_ptr; - auto probe_data = r_sortkey_ptr; - - auto build_format = &l_child_format; - auto probe_format = &r_child_format; - - // Use the smaller list to build the set - if (r_list.length < l_list.length) { - - build_list = r_list; - probe_list = l_list; - - build_data = r_sortkey_ptr; - probe_data = l_sortkey_ptr; - - build_format = &r_child_format; - probe_format = &l_child_format; - } - - // Reset the set - set.clear(); - - // Build the set - for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { - const auto entry_idx = build_format->sel->get_index(idx); - if (build_format->validity.RowIsValid(entry_idx)) { - set.insert(build_data[entry_idx]); - } - } - // Probe the set - for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { - const auto entry_idx = probe_format->sel->get_index(idx); - if (probe_format->validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) != set.end()) { - return true; - } - } - return false; - }); -} - -static void ListHasAllFunction(DataChunk &args, ExpressionState &state, Vector &result) { - - const auto &func_expr = state.expr.Cast(); - const auto swap = func_expr.function.name == "<@"; - - auto &l_vec = args.data[swap ? 1 : 0]; - auto &r_vec = args.data[swap ? 
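// list_has_any above is a textbook build/probe set intersection: elements are
// first rendered as binary-comparable sort keys, the smaller list builds a hash
// set, and the larger one probes it, stopping at the first hit. The same idea
// over plain STL types (sketch; DuckDB hashes the generated sort keys instead):

#include <unordered_set>
#include <vector>

static bool SketchHasAny(const std::vector<int> &a, const std::vector<int> &b) {
	const auto &build = a.size() <= b.size() ? a : b; // smaller side builds the set
	const auto &probe = a.size() <= b.size() ? b : a;
	std::unordered_set<int> set(build.begin(), build.end());
	for (int value : probe) {
		if (set.count(value) != 0) {
			return true; // short circuit on the first shared element
		}
	}
	return false;
}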
0 : 1]; - - if (ListType::GetChildType(l_vec.GetType()) == LogicalType::SQLNULL && - ListType::GetChildType(r_vec.GetType()) == LogicalType::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(result)[0] = true; - return; - } - - const auto l_size = ListVector::GetListSize(l_vec); - const auto r_size = ListVector::GetListSize(r_vec); - - auto &l_child = ListVector::GetEntry(l_vec); - auto &r_child = ListVector::GetEntry(r_vec); - - // Setup unified formats for the list elements - UnifiedVectorFormat build_format; - UnifiedVectorFormat probe_format; - - l_child.ToUnifiedFormat(l_size, build_format); - r_child.ToUnifiedFormat(r_size, probe_format); - - // Create the sort keys for the list elements - Vector l_sortkey_vec(LogicalType::BLOB, l_size); - Vector r_sortkey_vec(LogicalType::BLOB, r_size); - - const OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - - CreateSortKeyHelpers::CreateSortKey(l_child, l_size, order_modifiers, l_sortkey_vec); - CreateSortKeyHelpers::CreateSortKey(r_child, r_size, order_modifiers, r_sortkey_vec); - - const auto build_data = FlatVector::GetData(l_sortkey_vec); - const auto probe_data = FlatVector::GetData(r_sortkey_vec); - - string_set_t set; - - BinaryExecutor::Execute( - l_vec, r_vec, result, args.size(), [&](const list_entry_t &build_list, const list_entry_t &probe_list) { - // Short circuit if the probe list is empty - if (probe_list.length == 0) { - return true; - } - - // Reset the set - set.clear(); - - // Build the set - for (auto idx = build_list.offset; idx < build_list.offset + build_list.length; idx++) { - const auto entry_idx = build_format.sel->get_index(idx); - if (build_format.validity.RowIsValid(entry_idx)) { - set.insert(build_data[entry_idx]); - } - } - - // Probe the set - for (auto idx = probe_list.offset; idx < probe_list.offset + probe_list.length; idx++) { - const auto entry_idx = probe_format.sel->get_index(idx); - if (probe_format.validity.RowIsValid(entry_idx) && set.find(probe_data[entry_idx]) == set.end()) { - return false; - } - } - return true; - }); -} - -ScalarFunction ListHasAnyFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, - ListHasAnyFunction, ListHasAnyOrAllBind); - return fun; -} - -ScalarFunction ListHasAllFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, LogicalType::BOOLEAN, - ListHasAllFunction, ListHasAnyOrAllBind); - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp b/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp deleted file mode 100644 index 173b52699..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_reduce.cpp +++ /dev/null @@ -1,232 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/function/lambda_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -struct ReduceExecuteInfo { - ReduceExecuteInfo(LambdaFunctions::LambdaInfo &info, ClientContext &context) - : left_slice(make_uniq(*info.child_vector)) { - SelectionVector left_vector(info.row_count); - active_rows.Resize(info.row_count); - active_rows.SetAllValid(info.row_count); - - left_sel.Initialize(info.row_count); - active_rows_sel.Initialize(info.row_count); - - idx_t 
reduced_row_idx = 0; - - for (idx_t original_row_idx = 0; original_row_idx < info.row_count; original_row_idx++) { - auto list_column_format_index = info.list_column_format.sel->get_index(original_row_idx); - if (info.list_column_format.validity.RowIsValid(list_column_format_index)) { - if (info.list_entries[list_column_format_index].length == 0) { - throw ParameterNotAllowedException("Cannot perform list_reduce on an empty input list"); - } - left_vector.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset); - reduced_row_idx++; - } else { - // Set the row as invalid and remove it from the active rows. - FlatVector::SetNull(info.result, original_row_idx, true); - active_rows.SetInvalid(original_row_idx); - } - } - left_slice->Slice(left_vector, reduced_row_idx); - - if (info.has_index) { - input_types.push_back(LogicalType::BIGINT); - } - input_types.push_back(left_slice->GetType()); - input_types.push_back(left_slice->GetType()); - for (auto &entry : info.column_infos) { - input_types.push_back(entry.vector.get().GetType()); - } - - expr_executor = make_uniq(context, *info.lambda_expr); - }; - ValidityMask active_rows; - unique_ptr left_slice; - unique_ptr expr_executor; - vector input_types; - - SelectionVector left_sel; - SelectionVector active_rows_sel; -}; - -static bool ExecuteReduce(idx_t loops, ReduceExecuteInfo &execute_info, LambdaFunctions::LambdaInfo &info, - DataChunk &result_chunk) { - idx_t original_row_idx = 0; - idx_t reduced_row_idx = 0; - idx_t valid_row_idx = 0; - - // create selection vectors for the left and right slice - auto data = execute_info.active_rows.GetData(); - - // reset right_sel each iteration to prevent referencing issues - SelectionVector right_sel; - right_sel.Initialize(info.row_count); - - idx_t bits_per_entry = sizeof(idx_t) * 8; - for (idx_t entry_idx = 0; original_row_idx < info.row_count; entry_idx++) { - if (data[entry_idx] == 0) { - original_row_idx += bits_per_entry; - continue; - } - - for (idx_t j = 0; entry_idx * bits_per_entry + j < info.row_count; j++) { - if (!execute_info.active_rows.RowIsValid(original_row_idx)) { - original_row_idx++; - continue; - } - auto list_column_format_index = info.list_column_format.sel->get_index(original_row_idx); - if (info.list_entries[list_column_format_index].length > loops + 1) { - right_sel.set_index(reduced_row_idx, info.list_entries[list_column_format_index].offset + loops + 1); - execute_info.left_sel.set_index(reduced_row_idx, valid_row_idx); - execute_info.active_rows_sel.set_index(reduced_row_idx, original_row_idx); - reduced_row_idx++; - - } else { - execute_info.active_rows.SetInvalid(original_row_idx); - auto val = execute_info.left_slice->GetValue(valid_row_idx); - info.result.SetValue(original_row_idx, val); - } - - original_row_idx++; - valid_row_idx++; - } - } - - if (reduced_row_idx == 0) { - return true; - } - - // create the index vector - Vector index_vector(Value::BIGINT(UnsafeNumericCast(loops + 2))); - - // slice the left and right slice - execute_info.left_slice->Slice(*execute_info.left_slice, execute_info.left_sel, reduced_row_idx); - Vector right_slice(*info.child_vector, right_sel, reduced_row_idx); - - // create the input chunk - DataChunk input_chunk; - input_chunk.InitializeEmpty(execute_info.input_types); - input_chunk.SetCardinality(reduced_row_idx); - - idx_t slice_offset = info.has_index ? 
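// Each pass of ExecuteReduce applies the lambda once more to every row that
// still has elements left: the accumulator starts as the list's first element,
// is combined with element i on pass i, and the row retires (its accumulator
// copied into the result) once its list is exhausted. For a single list this is
// an ordinary left fold (sketch):

#include <cstddef>
#include <functional>
#include <stdexcept>
#include <vector>

static int SketchReduce(const std::vector<int> &list, const std::function<int(int, int)> &fn) {
	if (list.empty()) {
		throw std::invalid_argument("cannot reduce an empty list"); // mirrors the check above
	}
	int acc = list.front();
	for (std::size_t i = 1; i < list.size(); i++) {
		acc = fn(acc, list[i]);
	}
	return acc;
}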
1 : 0; - if (info.has_index) { - input_chunk.data[0].Reference(index_vector); - } - input_chunk.data[slice_offset + 1].Reference(*execute_info.left_slice); - input_chunk.data[slice_offset].Reference(right_slice); - - // add the other columns - vector slices; - for (idx_t i = 0; i < info.column_infos.size(); i++) { - if (info.column_infos[i].vector.get().GetVectorType() == VectorType::CONSTANT_VECTOR) { - // only reference constant vectors - input_chunk.data[slice_offset + 2 + i].Reference(info.column_infos[i].vector); - } else { - // slice the other vectors - slices.emplace_back(info.column_infos[i].vector, execute_info.active_rows_sel, reduced_row_idx); - input_chunk.data[slice_offset + 2 + i].Reference(slices.back()); - } - } - - result_chunk.Reset(); - result_chunk.SetCardinality(reduced_row_idx); - execute_info.expr_executor->Execute(input_chunk, result_chunk); - - // We need to copy the result into left_slice to avoid data loss due to vector.Reference(...). - // Otherwise, we only keep the data of the previous iteration alive, not that of previous iterations. - execute_info.left_slice = make_uniq(result_chunk.data[0].GetType(), reduced_row_idx); - VectorOperations::Copy(result_chunk.data[0], *execute_info.left_slice, reduced_row_idx, 0, 0); - return false; -} - -void LambdaFunctions::ListReduceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // Initializes the left slice from the list entries, active rows, the expression executor and the input types - bool completed = false; - LambdaFunctions::LambdaInfo info(args, state, result, completed); - if (completed) { - return; - } - - ReduceExecuteInfo execute_info(info, state.GetContext()); - - // Since the left slice references the result chunk, we need to create two result chunks. - // This means there is always an empty result chunk for the next iteration, - // without the referenced chunk having to be reset until the current iteration is complete. - DataChunk odd_result_chunk; - odd_result_chunk.Initialize(Allocator::DefaultAllocator(), {info.lambda_expr->return_type}); - - DataChunk even_result_chunk; - even_result_chunk.Initialize(Allocator::DefaultAllocator(), {info.lambda_expr->return_type}); - - // Execute reduce until all rows are finished. - idx_t loops = 0; - bool end = false; - while (!end) { - auto &result_chunk = loops % 2 ? odd_result_chunk : even_result_chunk; - auto &spare_result_chunk = loops % 2 ? 
even_result_chunk : odd_result_chunk; - - end = ExecuteReduce(loops, execute_info, info, result_chunk); - spare_result_chunk.Reset(); - loops++; - } - - if (info.is_all_constant && !info.is_volatile) { - info.result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListReduceBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // the list column and the bound lambda expression - D_ASSERT(arguments.size() == 2); - if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { - throw BinderException("Invalid lambda expression!"); - } - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - auto &bound_lambda_expr = arguments[1]->Cast(); - if (bound_lambda_expr.parameter_count < 2 || bound_lambda_expr.parameter_count > 3) { - throw BinderException("list_reduce expects a function with 2 or 3 arguments"); - } - auto has_index = bound_lambda_expr.parameter_count == 3; - - unique_ptr bind_data = LambdaFunctions::ListLambdaPrepareBind(arguments, context, bound_function); - if (bind_data) { - return bind_data; - } - - auto list_child_type = arguments[0]->return_type; - list_child_type = ListType::GetChildType(list_child_type); - - auto cast_lambda_expr = - BoundCastExpression::AddCastToType(context, std::move(bound_lambda_expr.lambda_expr), list_child_type, false); - if (!cast_lambda_expr) { - throw BinderException("Could not cast lambda expression to list child type"); - } - bound_function.return_type = cast_lambda_expr->return_type; - return make_uniq(bound_function.return_type, std::move(cast_lambda_expr), has_index); -} - -static LogicalType ListReduceBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { - return LambdaFunctions::BindTernaryLambda(parameter_idx, list_child_type); -} - -ScalarFunction ListReduceFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::ANY, - LambdaFunctions::ListReduceFunction, ListReduceBind, nullptr, nullptr); - - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListReduceBindLambda; - - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp b/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp deleted file mode 100644 index 5ab523d20..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_sort.cpp +++ /dev/null @@ -1,416 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -struct ListSortBindData : public FunctionData { - ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, bool is_grade_up, - const LogicalType &return_type_p, const LogicalType &child_type_p, ClientContext &context_p); - ~ListSortBindData() override; - - OrderType order_type; - OrderByNullType null_order; - LogicalType return_type; - LogicalType child_type; - bool is_grade_up; - - vector types; - vector payload_types; - - ClientContext 
&context; - RowLayout payload_layout; - vector orders; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; -}; - -ListSortBindData::ListSortBindData(OrderType order_type_p, OrderByNullType null_order_p, bool is_grade_up_p, - const LogicalType &return_type_p, const LogicalType &child_type_p, - ClientContext &context_p) - : order_type(order_type_p), null_order(null_order_p), return_type(return_type_p), child_type(child_type_p), - is_grade_up(is_grade_up_p), context(context_p) { - - // get the vector types - types.emplace_back(LogicalType::USMALLINT); - types.emplace_back(child_type); - D_ASSERT(types.size() == 2); - - // get the payload types - payload_types.emplace_back(LogicalType::UINTEGER); - D_ASSERT(payload_types.size() == 1); - - // initialize the payload layout - payload_layout.Initialize(payload_types); - - // get the BoundOrderByNode - auto idx_col_expr = make_uniq_base(LogicalType::USMALLINT, 0U); - auto lists_col_expr = make_uniq_base(child_type, 1U); - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::ORDER_DEFAULT, std::move(idx_col_expr)); - orders.emplace_back(order_type, null_order, std::move(lists_col_expr)); -} - -unique_ptr ListSortBindData::Copy() const { - return make_uniq(order_type, null_order, is_grade_up, return_type, child_type, context); -} - -bool ListSortBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return order_type == other.order_type && null_order == other.null_order && is_grade_up == other.is_grade_up; -} - -ListSortBindData::~ListSortBindData() { -} - -// create the key_chunk and the payload_chunk and sink them into the local_sort_state -void SinkDataChunk(Vector *child_vector, SelectionVector &sel, idx_t offset_lists_indices, vector &types, - vector &payload_types, Vector &payload_vector, LocalSortState &local_sort_state, - bool &data_to_sort, Vector &lists_indices) { - - // slice the child vector - Vector slice(*child_vector, sel, offset_lists_indices); - - // initialize and fill key_chunk - DataChunk key_chunk; - key_chunk.InitializeEmpty(types); - key_chunk.data[0].Reference(lists_indices); - key_chunk.data[1].Reference(slice); - key_chunk.SetCardinality(offset_lists_indices); - - // initialize and fill key_chunk and payload_chunk - DataChunk payload_chunk; - payload_chunk.InitializeEmpty(payload_types); - payload_chunk.data[0].Reference(payload_vector); - payload_chunk.SetCardinality(offset_lists_indices); - - key_chunk.Verify(); - payload_chunk.Verify(); - - // sink - key_chunk.Flatten(); - local_sort_state.SinkChunk(key_chunk, payload_chunk); - data_to_sort = true; -} - -static void ListSortFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() >= 1 && args.ColumnCount() <= 3); - auto count = args.size(); - Vector &input_lists = args.data[0]; - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_validity = FlatVector::Validity(result); - - if (input_lists.GetType().id() == LogicalTypeId::SQLNULL) { - result_validity.SetInvalid(0); - return; - } - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // initialize the global and local sorting state - auto &buffer_manager = BufferManager::GetBufferManager(info.context); - GlobalSortState global_sort_state(buffer_manager, info.orders, info.payload_layout); - LocalSortState local_sort_state; - local_sort_state.Initialize(global_sort_state, buffer_manager); - - Vector sort_result_vec = info.is_grade_up ? 
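// Ordering by the (list index, element) key pair above sorts every list of the
// chunk in a single sort run: the leading USMALLINT index keeps elements
// grouped by row while the second column orders them within the row, and the
// UINTEGER payload remembers each element's original child-vector position for
// the re-slice afterwards. A dense in-memory equivalent (sketch):

#include <algorithm>
#include <cstdint>
#include <tuple>
#include <vector>

struct SketchSortRow {
	uint16_t list_idx;     // which list the element belongs to (key column 0)
	int64_t value;         // the element itself (key column 1)
	uint32_t original_pos; // payload: position in the child vector
};

static void SortAllLists(std::vector<SketchSortRow> &rows) {
	std::stable_sort(rows.begin(), rows.end(), [](const SketchSortRow &a, const SketchSortRow &b) {
		return std::tie(a.list_idx, a.value) < std::tie(b.list_idx, b.value);
	});
	// reading original_pos in order now gives the sorted selection vector
}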
Vector(input_lists.GetType()) : result; - - // this ensures that we do not change the order of the entries in the input chunk - VectorOperations::Copy(input_lists, sort_result_vec, count, 0, 0); - - // get the child vector - auto lists_size = ListVector::GetListSize(sort_result_vec); - auto &child_vector = ListVector::GetEntry(sort_result_vec); - - // get the lists data - UnifiedVectorFormat lists_data; - sort_result_vec.ToUnifiedFormat(count, lists_data); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - // create the lists_indices vector, this contains an element for each list's entry, - // the element corresponds to the list's index, e.g. for [1, 2, 4], [5, 4] - // lists_indices contains [0, 0, 0, 1, 1] - Vector lists_indices(LogicalType::USMALLINT); - auto lists_indices_data = FlatVector::GetData(lists_indices); - - // create the payload_vector, this is just a vector containing incrementing integers - // this will later be used as the 'new' selection vector of the child_vector, after - // rearranging the payload according to the sorting order - Vector payload_vector(LogicalType::UINTEGER); - auto payload_vector_data = FlatVector::GetData(payload_vector); - - // selection vector pointing to the data of the child vector, - // used for slicing the child_vector correctly - SelectionVector sel(STANDARD_VECTOR_SIZE); - - idx_t offset_lists_indices = 0; - uint32_t incr_payload_count = 0; - bool data_to_sort = false; - - for (idx_t i = 0; i < count; i++) { - auto lists_index = lists_data.sel->get_index(i); - const auto &list_entry = list_entries[lists_index]; - - // nothing to do for this list - if (!lists_data.validity.RowIsValid(lists_index)) { - result_validity.SetInvalid(i); - continue; - } - - // empty list, no sorting required - if (list_entry.length == 0) { - continue; - } - - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - // lists_indices vector is full, sink - if (offset_lists_indices == STANDARD_VECTOR_SIZE) { - SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, - local_sort_state, data_to_sort, lists_indices); - offset_lists_indices = 0; - } - - auto source_idx = list_entry.offset + child_idx; - sel.set_index(offset_lists_indices, source_idx); - lists_indices_data[offset_lists_indices] = UnsafeNumericCast(i); - payload_vector_data[offset_lists_indices] = NumericCast(source_idx); - offset_lists_indices++; - incr_payload_count++; - } - } - - if (offset_lists_indices != 0) { - SinkDataChunk(&child_vector, sel, offset_lists_indices, info.types, info.payload_types, payload_vector, - local_sort_state, data_to_sort, lists_indices); - } - - if (info.is_grade_up) { - ListVector::Reserve(result, lists_size); - ListVector::SetListSize(result, lists_size); - auto result_data = ListVector::GetData(result); - memcpy(result_data, list_entries, count * sizeof(list_entry_t)); - } - - if (data_to_sort) { - // add local state to global state, which sorts the data - global_sort_state.AddLocalState(local_sort_state); - global_sort_state.PrepareMergePhase(); - - // selection vector that is to be filled with the 'sorted' payload - SelectionVector sel_sorted(incr_payload_count); - idx_t sel_sorted_idx = 0; - - // scan the sorted row data - PayloadScanner scanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state); - for (;;) { - DataChunk result_chunk; - result_chunk.Initialize(Allocator::DefaultAllocator(), info.payload_types); - result_chunk.SetCardinality(0); - scanner.Scan(result_chunk); 
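// When is_grade_up is set, the sorted payload is not used to shuffle the child
// vector; instead each output slot receives the 1-based original position of
// the element ranked there, e.g. list_grade_up([3, 6, 1, 2]) gives [3, 4, 1, 2]
// (assuming the default ascending order). The dense argsort equivalent (sketch):

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

static std::vector<int64_t> SketchGradeUp(const std::vector<int64_t> &values) {
	std::vector<int64_t> index(values.size());
	std::iota(index.begin(), index.end(), 0); // 0, 1, 2, ...
	std::stable_sort(index.begin(), index.end(),
	                 [&](int64_t a, int64_t b) { return values[a] < values[b]; });
	for (auto &i : index) {
		i += 1; // the result is 1-based, matching the BIGINT(b + 1) below
	}
	return index;
}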
- if (result_chunk.size() == 0) { - break; - } - - // construct the selection vector with the new order from the result vectors - Vector result_vector(result_chunk.data[0]); - auto result_data = FlatVector::GetData(result_vector); - auto row_count = result_chunk.size(); - - for (idx_t i = 0; i < row_count; i++) { - sel_sorted.set_index(sel_sorted_idx, result_data[i]); - D_ASSERT(result_data[i] < lists_size); - sel_sorted_idx++; - } - } - - D_ASSERT(sel_sorted_idx == incr_payload_count); - if (info.is_grade_up) { - auto &result_entry = ListVector::GetEntry(result); - auto result_data = ListVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - if (!result_validity.RowIsValid(i)) { - continue; - } - for (idx_t j = result_data[i].offset; j < result_data[i].offset + result_data[i].length; j++) { - auto b = sel_sorted.get_index(j) - result_data[i].offset; - result_entry.SetValue(j, Value::BIGINT(UnsafeNumericCast(b + 1))); - } - } - } else { - child_vector.Slice(sel_sorted, sel_sorted_idx); - child_vector.Flatten(sel_sorted_idx); - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, OrderType &order, - OrderByNullType &null_order) { - - LogicalType child_type; - if (arguments[0]->return_type == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - child_type = bound_function.return_type; - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); - } - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - child_type = ListType::GetChildType(arguments[0]->return_type); - - bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; - - return make_uniq(order, null_order, false, bound_function.return_type, child_type, context); -} - -template -static T GetOrder(ClientContext &context, Expression &expr) { - if (!expr.IsFoldable()) { - throw InvalidInputException("Sorting order must be a constant"); - } - Value order_value = ExpressionExecutor::EvaluateScalar(context, expr); - auto order_name = StringUtil::Upper(order_value.ToString()); - return EnumUtil::FromString(order_name.c_str()); -} - -static unique_ptr ListGradeUpBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - D_ASSERT(!arguments.empty() && arguments.size() <= 3); - auto order = OrderType::ORDER_DEFAULT; - auto null_order = OrderByNullType::ORDER_DEFAULT; - - // get the sorting order - if (arguments.size() >= 2) { - order = GetOrder(context, *arguments[1]); - } - // get the null sorting order - if (arguments.size() == 3) { - null_order = GetOrder(context, *arguments[2]); - } - auto &config = DBConfig::GetConfig(context); - order = config.ResolveOrder(order); - null_order = config.ResolveNullOrder(order, null_order); - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = LogicalType::LIST(LogicalTypeId::BIGINT); - auto child_type = ListType::GetChildType(arguments[0]->return_type); - return make_uniq(order, null_order, true, bound_function.return_type, child_type, context); -} - -static unique_ptr ListNormalSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - 
D_ASSERT(!arguments.empty() && arguments.size() <= 3); - auto order = OrderType::ORDER_DEFAULT; - auto null_order = OrderByNullType::ORDER_DEFAULT; - - // get the sorting order - if (arguments.size() >= 2) { - order = GetOrder(context, *arguments[1]); - } - // get the null sorting order - if (arguments.size() == 3) { - null_order = GetOrder(context, *arguments[2]); - } - auto &config = DBConfig::GetConfig(context); - order = config.ResolveOrder(order); - null_order = config.ResolveNullOrder(order, null_order); - return ListSortBind(context, bound_function, arguments, order, null_order); -} - -static unique_ptr ListReverseSortBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto order = OrderType::ORDER_DEFAULT; - auto null_order = OrderByNullType::ORDER_DEFAULT; - - if (arguments.size() == 2) { - null_order = GetOrder(context, *arguments[1]); - } - auto &config = DBConfig::GetConfig(context); - order = config.ResolveOrder(order); - switch (order) { - case OrderType::ASCENDING: - order = OrderType::DESCENDING; - break; - case OrderType::DESCENDING: - order = OrderType::ASCENDING; - break; - default: - throw InternalException("Unexpected order type in list reverse sort"); - } - null_order = config.ResolveNullOrder(order, null_order); - return ListSortBind(context, bound_function, arguments, order, null_order); -} - -ScalarFunctionSet ListSortFun::GetFunctions() { - // one parameter: list - ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, - ListNormalSortBind); - - // two parameters: list, order - ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); - - // three parameters: list, order, null order - ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListNormalSortBind); - - ScalarFunctionSet list_sort; - list_sort.AddFunction(sort); - list_sort.AddFunction(sort_order); - list_sort.AddFunction(sort_orders); - return list_sort; -} - -ScalarFunctionSet ListGradeUpFun::GetFunctions() { - // one parameter: list - ScalarFunction sort({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), ListSortFunction, - ListGradeUpBind); - - // two parameters: list, order - ScalarFunction sort_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListGradeUpBind); - - // three parameters: list, order, null order - ScalarFunction sort_orders({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListGradeUpBind); - - ScalarFunctionSet list_grade_up; - list_grade_up.AddFunction(sort); - list_grade_up.AddFunction(sort_order); - list_grade_up.AddFunction(sort_orders); - return list_grade_up; -} - -ScalarFunctionSet ListReverseSortFun::GetFunctions() { - // one parameter: list - ScalarFunction sort_reverse({LogicalType::LIST(LogicalType::ANY)}, LogicalType::LIST(LogicalType::ANY), - ListSortFunction, ListReverseSortBind); - - // two parameters: list, null order - ScalarFunction sort_reverse_null_order({LogicalType::LIST(LogicalType::ANY), LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::ANY), ListSortFunction, ListReverseSortBind); - - ScalarFunctionSet list_reverse_sort; - 
list_reverse_sort.AddFunction(sort_reverse); - list_reverse_sort.AddFunction(sort_reverse_null_order); - return list_reverse_sort; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp b/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp deleted file mode 100644 index 26c6ad4b3..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_transform.cpp +++ /dev/null @@ -1,41 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" - -#include "duckdb/function/lambda_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -static unique_ptr ListTransformBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // the list column and the bound lambda expression - D_ASSERT(arguments.size() == 2); - if (arguments[1]->GetExpressionClass() != ExpressionClass::BOUND_LAMBDA) { - throw BinderException("Invalid lambda expression!"); - } - - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - auto &bound_lambda_expr = arguments[1]->Cast(); - bound_function.return_type = LogicalType::LIST(bound_lambda_expr.lambda_expr->return_type); - auto has_index = bound_lambda_expr.parameter_count == 2; - return LambdaFunctions::ListLambdaBind(context, bound_function, arguments, has_index); -} - -static LogicalType ListTransformBindLambda(const idx_t parameter_idx, const LogicalType &list_child_type) { - return LambdaFunctions::BindBinaryLambda(parameter_idx, list_child_type); -} - -ScalarFunction ListTransformFun::GetFunction() { - ScalarFunction fun({LogicalType::LIST(LogicalType::ANY), LogicalType::LAMBDA}, LogicalType::LIST(LogicalType::ANY), - LambdaFunctions::ListTransformFunction, ListTransformBind, nullptr, nullptr); - - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = ListLambdaBindData::Serialize; - fun.deserialize = ListLambdaBindData::Deserialize; - fun.bind_lambda = ListTransformBindLambda; - - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp b/src/duckdb/extension/core_functions/scalar/list/list_value.cpp deleted file mode 100644 index 01b342ec4..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/list_value.cpp +++ /dev/null @@ -1,203 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/storage/statistics/list_stats.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/parser/query_error_context.hpp" - -namespace duckdb { - -struct ListValueAssign { - template - static T Assign(const T &input, Vector &result) { - return input; - } -}; - -struct ListValueStringAssign { - template - static T Assign(const T &input, Vector &result) { - return StringVector::AddStringOrBlob(result, input); - } -}; - -template -static void TemplatedListValueFunction(DataChunk &args, Vector &result) { - idx_t list_size = args.ColumnCount(); - ListVector::Reserve(result, args.size() * list_size); - auto result_data = FlatVector::GetData(result); - auto &list_child = ListVector::GetEntry(result); - auto child_data = FlatVector::GetData(list_child); - auto &child_validity = 
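// TemplatedListValueFunction below builds fixed-width lists, so the child
// vector can be addressed arithmetically: the value of argument c in row r
// lands in child slot r * list_size + c, and every result entry is simply
// {offset = r * list_size, length = list_size}. The addressing in isolation
// (sketch):

#include <cstddef>

// list_value(a, b, c) packs rows as child = [a0, b0, c0, a1, b1, c1, ...]
static std::size_t ChildSlot(std::size_t row, std::size_t col, std::size_t list_size) {
	return row * list_size + col;
}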
FlatVector::Validity(list_child); - - auto unified_format = args.ToUnifiedFormat(); - for (idx_t r = 0; r < args.size(); r++) { - for (idx_t c = 0; c < list_size; c++) { - auto input_idx = unified_format[c].sel->get_index(r); - auto result_idx = r * list_size + c; - auto input_data = UnifiedVectorFormat::GetData(unified_format[c]); - if (unified_format[c].validity.RowIsValid(input_idx)) { - child_data[result_idx] = OP::template Assign(input_data[input_idx], list_child); - } else { - child_validity.SetInvalid(result_idx); - } - } - result_data[r].offset = r * list_size; - result_data[r].length = list_size; - } - ListVector::SetListSize(result, args.size() * list_size); -} - -static void TemplatedListValueFunctionFallback(DataChunk &args, Vector &result) { - auto &child_type = ListType::GetChildType(result.GetType()); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - result_data[i].offset = ListVector::GetListSize(result); - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto val = args.GetValue(col_idx, i).DefaultCastAs(child_type); - ListVector::PushBack(result, val); - } - result_data[i].length = args.ColumnCount(); - } -} - -static void ListValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - if (args.ColumnCount() == 0) { - // no columns - early out - result is a constant empty list - auto result_data = FlatVector::GetData(result); - result_data[0].length = 0; - result_data[0].offset = 0; - return; - } - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - } - } - auto &result_type = ListVector::GetEntry(result).GetType(); - switch (result_type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::INT16: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::INT32: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::INT64: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::UINT8: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::UINT16: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::UINT32: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::UINT64: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::INT128: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::UINT128: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::FLOAT: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::DOUBLE: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::INTERVAL: - TemplatedListValueFunction(args, result); - break; - case PhysicalType::VARCHAR: - TemplatedListValueFunction(args, result); - break; - default: { - TemplatedListValueFunctionFallback(args, result); - break; - } - } -} - -template -static unique_ptr ListValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // collect names and deconflict, construct return type - LogicalType child_type = - arguments.empty() ? 
LogicalType::SQLNULL : ExpressionBinder::GetExpressionReturnType(*arguments[0]); - for (idx_t i = 1; i < arguments.size(); i++) { - auto arg_type = ExpressionBinder::GetExpressionReturnType(*arguments[i]); - if (!LogicalType::TryGetMaxLogicalType(context, child_type, arg_type, child_type)) { - if (IS_UNPIVOT) { - string list_arguments = "Full list: "; - idx_t error_index = list_arguments.size(); - for (idx_t k = 0; k < arguments.size(); k++) { - if (k > 0) { - list_arguments += ", "; - } - if (k == i) { - error_index = list_arguments.size(); - } - list_arguments += arguments[k]->ToString() + " " + arguments[k]->return_type.ToString(); - } - auto error = - StringUtil::Format("Cannot unpivot columns of types %s and %s - an explicit cast is required", - child_type.ToString(), arg_type.ToString()); - throw BinderException(arguments[i]->GetQueryLocation(), - QueryErrorContext::Format(list_arguments, error, error_index, false)); - } else { - throw BinderException(arguments[i]->GetQueryLocation(), - "Cannot create a list of types %s and %s - an explicit cast is required", - child_type.ToString(), arg_type.ToString()); - } - } - } - child_type = LogicalType::NormalizeType(child_type); - - // this is more for completeness reasons - bound_function.varargs = child_type; - bound_function.return_type = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); -} - -unique_ptr ListValueStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto list_stats = ListStats::CreateEmpty(expr.return_type); - auto &list_child_stats = ListStats::GetChildStats(list_stats); - for (idx_t i = 0; i < child_stats.size(); i++) { - list_child_stats.Merge(child_stats[i]); - } - return list_stats.ToUnique(); -} - -ScalarFunction ListValueFun::GetFunction() { - // the arguments and return types are actually set in the binder function - ScalarFunction fun("list_value", {}, LogicalTypeId::LIST, ListValueFunction, ListValueBind, nullptr, - ListValueStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -ScalarFunction UnpivotListFun::GetFunction() { - auto fun = ListValueFun::GetFunction(); - fun.name = "unpivot_list"; - fun.bind = ListValueBind; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/list/range.cpp b/src/duckdb/extension/core_functions/scalar/list/range.cpp deleted file mode 100644 index 8c641d13d..000000000 --- a/src/duckdb/extension/core_functions/scalar/list/range.cpp +++ /dev/null @@ -1,281 +0,0 @@ -#include "core_functions/scalar/list_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types/timestamp.hpp" - -namespace duckdb { - -struct NumericRangeInfo { - using TYPE = int64_t; - using INCREMENT_TYPE = int64_t; - - static int64_t DefaultStart() { - return 0; - } - static int64_t DefaultIncrement() { - return 1; - } - - static uint64_t ListLength(int64_t start_value, int64_t end_value, int64_t increment_value, bool inclusive_bound) { - if (increment_value == 0) { - return 0; - } - if (start_value > end_value && increment_value > 0) { - return 0; - } - if (start_value < end_value && increment_value < 0) { - return 0; - } - hugeint_t total_diff = AbsValue(hugeint_t(end_value) - hugeint_t(start_value)); - hugeint_t increment = 
AbsValue(hugeint_t(increment_value)); - hugeint_t total_values = total_diff / increment; - if (total_diff % increment == 0) { - if (inclusive_bound) { - total_values += 1; - } - } else { - total_values += 1; - } - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - return Hugeint::Cast(total_values); - } - - static void Increment(int64_t &input, int64_t increment) { - input += increment; - } -}; -struct TimestampRangeInfo { - using TYPE = timestamp_t; - using INCREMENT_TYPE = interval_t; - - static timestamp_t DefaultStart() { - throw InternalException("Default start not implemented for timestamp range"); - } - static interval_t DefaultIncrement() { - throw InternalException("Default increment not implemented for timestamp range"); - } - static uint64_t ListLength(timestamp_t start_value, timestamp_t end_value, interval_t increment_value, - bool inclusive_bound) { - bool is_positive = increment_value.months > 0 || increment_value.days > 0 || increment_value.micros > 0; - bool is_negative = increment_value.months < 0 || increment_value.days < 0 || increment_value.micros < 0; - if (!is_negative && !is_positive) { - // interval is 0: no result - return 0; - } - // We don't allow infinite bounds because they generate errors or infinite loops - if (!Timestamp::IsFinite(start_value) || !Timestamp::IsFinite(end_value)) { - throw InvalidInputException("Interval infinite bounds not supported"); - } - - if (is_negative && is_positive) { - // we don't allow a mix of - throw InvalidInputException("Interval with mix of negative/positive entries not supported"); - } - if (start_value > end_value && is_positive) { - return 0; - } - if (start_value < end_value && is_negative) { - return 0; - } - uint64_t total_values = 0; - if (is_negative) { - // negative interval, start_value is going down - while (inclusive_bound ? start_value >= end_value : start_value > end_value) { - start_value = Interval::Add(start_value, increment_value); - total_values++; - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - } - } else { - // positive interval, start_value is going up - while (inclusive_bound ? 
start_value <= end_value : start_value < end_value) { - start_value = Interval::Add(start_value, increment_value); - total_values++; - if (total_values > NumericLimits::Maximum()) { - throw InvalidInputException("Lists larger than 2^32 elements are not supported"); - } - } - } - return total_values; - } - - static void Increment(timestamp_t &input, interval_t increment) { - input = Interval::Add(input, increment); - } -}; - -template -class RangeInfoStruct { -public: - explicit RangeInfoStruct(DataChunk &args_p) : args(args_p) { - switch (args.ColumnCount()) { - case 1: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - break; - case 2: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - args.data[1].ToUnifiedFormat(args.size(), vdata[1]); - break; - case 3: - args.data[0].ToUnifiedFormat(args.size(), vdata[0]); - args.data[1].ToUnifiedFormat(args.size(), vdata[1]); - args.data[2].ToUnifiedFormat(args.size(), vdata[2]); - break; - default: - throw InternalException("Unsupported number of parameters for range"); - } - } - - bool RowIsValid(idx_t row_idx) { - for (idx_t i = 0; i < args.ColumnCount(); i++) { - auto idx = vdata[i].sel->get_index(row_idx); - if (!vdata[i].validity.RowIsValid(idx)) { - return false; - } - } - return true; - } - - typename OP::TYPE StartListValue(idx_t row_idx) { - if (args.ColumnCount() == 1) { - return OP::DefaultStart(); - } else { - auto data = (typename OP::TYPE *)vdata[0].data; - auto idx = vdata[0].sel->get_index(row_idx); - return data[idx]; - } - } - - typename OP::TYPE EndListValue(idx_t row_idx) { - idx_t vdata_idx = args.ColumnCount() == 1 ? 0 : 1; - auto data = (typename OP::TYPE *)vdata[vdata_idx].data; - auto idx = vdata[vdata_idx].sel->get_index(row_idx); - return data[idx]; - } - - typename OP::INCREMENT_TYPE ListIncrementValue(idx_t row_idx) { - if (args.ColumnCount() < 3) { - return OP::DefaultIncrement(); - } else { - auto data = (typename OP::INCREMENT_TYPE *)vdata[2].data; - auto idx = vdata[2].sel->get_index(row_idx); - return data[idx]; - } - } - - void GetListValues(idx_t row_idx, typename OP::TYPE &start_value, typename OP::TYPE &end_value, - typename OP::INCREMENT_TYPE &increment_value) { - start_value = StartListValue(row_idx); - end_value = EndListValue(row_idx); - increment_value = ListIncrementValue(row_idx); - } - - uint64_t ListLength(idx_t row_idx) { - typename OP::TYPE start_value; - typename OP::TYPE end_value; - typename OP::INCREMENT_TYPE increment_value; - GetListValues(row_idx, start_value, end_value, increment_value); - return OP::ListLength(start_value, end_value, increment_value, INCLUSIVE_BOUND); - } - -private: - DataChunk &args; - UnifiedVectorFormat vdata[3]; -}; - -template -static void ListRangeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - RangeInfoStruct info(args); - idx_t args_size = 1; - auto result_type = VectorType::CONSTANT_VECTOR; - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - args_size = args.size(); - result_type = VectorType::FLAT_VECTOR; - break; - } - } - auto list_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - uint64_t total_size = 0; - for (idx_t i = 0; i < args_size; i++) { - if (!info.RowIsValid(i)) { - result_validity.SetInvalid(i); - list_data[i].offset = total_size; - list_data[i].length = 0; - } else { - list_data[i].offset = total_size; - list_data[i].length = info.ListLength(i); 
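// NumericRangeInfo::ListLength above divides the absolute distance by the
// absolute increment and adds one extra slot for a partial last step (or for
// an inclusive bound that lands exactly on the end). A standalone sketch of
// the same arithmetic; unlike the hugeint original, it assumes the int64
// subtraction does not overflow:
#include <cassert>
#include <cstdint>
#include <cstdlib>

uint64_t RangeLength(int64_t start, int64_t end, int64_t increment, bool inclusive_bound) {
	if (increment == 0 || (start > end && increment > 0) || (start < end && increment < 0)) {
		return 0; // empty range
	}
	uint64_t total_diff = static_cast<uint64_t>(std::llabs(end - start));
	uint64_t step = static_cast<uint64_t>(std::llabs(increment));
	uint64_t total_values = total_diff / step;
	if (total_diff % step == 0) {
		if (inclusive_bound) {
			total_values += 1; // the end value itself lands on a step
		}
	} else {
		total_values += 1; // a partial last step still yields one more value
	}
	return total_values;
}

int main() {
	assert(RangeLength(0, 10, 2, false) == 5); // range(0, 10, 2): 0 2 4 6 8
	assert(RangeLength(0, 10, 2, true) == 6);  // generate_series-style inclusive bound adds 10
	assert(RangeLength(0, 10, 3, false) == 4); // 0 3 6 9: partial last step
	return 0;
}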
- total_size += list_data[i].length; - } - } - - // now construct the child vector of the list - ListVector::Reserve(result, total_size); - auto range_data = FlatVector::GetData(ListVector::GetEntry(result)); - idx_t total_idx = 0; - for (idx_t i = 0; i < args_size; i++) { - typename OP::TYPE start_value = info.StartListValue(i); - typename OP::INCREMENT_TYPE increment = info.ListIncrementValue(i); - - typename OP::TYPE range_value = start_value; - for (idx_t range_idx = 0; range_idx < list_data[i].length; range_idx++) { - if (range_idx > 0) { - OP::Increment(range_value, increment); - } - range_data[total_idx++] = range_value; - } - } - - ListVector::SetListSize(result, total_size); - result.SetVectorType(result_type); - - result.Verify(args.size()); -} - -ScalarFunctionSet ListRangeFun::GetFunctions() { - // the arguments and return types are actually set in the binder function - ScalarFunctionSet range_set; - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - range_set.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::LIST(LogicalType::TIMESTAMP), - ListRangeFunction)); - for (auto &func : range_set.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return range_set; -} - -ScalarFunctionSet GenerateSeriesFun::GetFunctions() { - ScalarFunctionSet generate_series; - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::BIGINT), - ListRangeFunction)); - generate_series.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - LogicalType::LIST(LogicalType::TIMESTAMP), - ListRangeFunction)); - for (auto &func : generate_series.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return generate_series; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp b/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp deleted file mode 100644 index 9c81223e7..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/cardinality.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -static void CardinalityFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &map = args.data[0]; - UnifiedVectorFormat map_data; - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - map.ToUnifiedFormat(args.size(), map_data); - for 
(idx_t row = 0; row < args.size(); row++) { - auto list_entry = UnifiedVectorFormat::GetData(map_data)[map_data.sel->get_index(row)]; - result_data[row] = list_entry.length; - result_validity.Set(row, map_data.validity.RowIsValid(map_data.sel->get_index(row))); - } - - if (args.size() == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr CardinalityBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 1) { - throw BinderException("Cardinality must have exactly one argument"); - } - - if (arguments[0]->return_type.id() != LogicalTypeId::MAP) { - throw BinderException("Cardinality can only operate on MAPs"); - } - - bound_function.return_type = LogicalType::UBIGINT; - return make_uniq(bound_function.return_type); -} - -ScalarFunction CardinalityFun::GetFunction() { - ScalarFunction fun({LogicalType::ANY}, LogicalType::UBIGINT, CardinalityFunction, CardinalityBind); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map.cpp b/src/duckdb/extension/core_functions/scalar/map/map.cpp deleted file mode 100644 index b83a4a081..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/map.cpp +++ /dev/null @@ -1,223 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -static void MapFunctionEmptyInput(Vector &result, const idx_t row_count) { - // If no chunk is set in ExpressionExecutor::ExecuteExpression (args.data.empty(), e.g., in SELECT MAP()), - // then we always pass a row_count of 1.
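// A list_entry_t is plain {offset, length}; the constant empty MAP produced
// below is the single entry {0, 0} with a child list size of 0. A tiny
// standalone illustration of that representation (the struct mirrors the
// shape of DuckDB's list_entry_t, reconstructed here for the example):
#include <cassert>
#include <cstdint>

struct list_entry_t {
	uint64_t offset = 0;
	uint64_t length = 0;
};

int main() {
	list_entry_t empty_map; // SELECT MAP(): one constant row, zero key/value pairs
	assert(empty_map.offset == 0 && empty_map.length == 0);
	return 0;
}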
- result.SetVectorType(VectorType::CONSTANT_VECTOR); - ListVector::SetListSize(result, 0); - - auto result_data = ListVector::GetData(result); - result_data[0] = list_entry_t(); - result.Verify(row_count); -} - -static bool MapIsNull(DataChunk &chunk) { - if (chunk.data.empty()) { - return false; - } - D_ASSERT(chunk.data.size() == 2); - auto &keys = chunk.data[0]; - auto &values = chunk.data[1]; - - if (keys.GetType().id() == LogicalTypeId::SQLNULL) { - return true; - } - if (values.GetType().id() == LogicalTypeId::SQLNULL) { - return true; - } - return false; -} - -static void MapFunction(DataChunk &args, ExpressionState &, Vector &result) { - - // internal MAP representation - // - LIST-vector that contains STRUCTs as child entries - // - STRUCTs have exactly two fields, a key-field, and a value-field - // - key names are unique - D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); - - if (MapIsNull(args)) { - auto &validity = FlatVector::Validity(result); - validity.SetInvalid(0); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - return; - } - - auto row_count = args.size(); - - // early-out, if no data - if (args.data.empty()) { - return MapFunctionEmptyInput(result, row_count); - } - - auto &keys = args.data[0]; - auto &values = args.data[1]; - - // a LIST vector, where each row contains a LIST of KEYS - UnifiedVectorFormat keys_data; - keys.ToUnifiedFormat(row_count, keys_data); - auto keys_entries = UnifiedVectorFormat::GetData(keys_data); - - // the KEYs child vector - auto keys_child_vector = ListVector::GetEntry(keys); - UnifiedVectorFormat keys_child_data; - keys_child_vector.ToUnifiedFormat(ListVector::GetListSize(keys), keys_child_data); - - // a LIST vector, where each row contains a LIST of VALUES - UnifiedVectorFormat values_data; - values.ToUnifiedFormat(row_count, values_data); - auto values_entries = UnifiedVectorFormat::GetData(values_data); - - // the VALUEs child vector - auto values_child_vector = ListVector::GetEntry(values); - UnifiedVectorFormat values_child_data; - values_child_vector.ToUnifiedFormat(ListVector::GetListSize(values), values_child_data); - - // a LIST vector, where each row contains a MAP (LIST of STRUCTs) - UnifiedVectorFormat result_data; - result.ToUnifiedFormat(row_count, result_data); - auto result_entries = UnifiedVectorFormat::GetDataNoConst(result_data); - - auto &result_validity = FlatVector::Validity(result); - - // get the resulting size of the key/value child lists - idx_t result_child_size = 0; - for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto keys_idx = keys_data.sel->get_index(row_idx); - auto values_idx = values_data.sel->get_index(row_idx); - if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { - continue; - } - auto keys_entry = keys_entries[keys_idx]; - result_child_size += keys_entry.length; - } - - // we need to slice potential non-flat vectors - SelectionVector sel_keys(result_child_size); - SelectionVector sel_values(result_child_size); - idx_t offset = 0; - - for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - - auto keys_idx = keys_data.sel->get_index(row_idx); - auto values_idx = values_data.sel->get_index(row_idx); - auto result_idx = result_data.sel->get_index(row_idx); - - // NULL MAP - if (!keys_data.validity.RowIsValid(keys_idx) || !values_data.validity.RowIsValid(values_idx)) { - result_validity.SetInvalid(row_idx); - continue; - } - - auto keys_entry = keys_entries[keys_idx]; - auto values_entry = values_entries[values_idx]; - - if 
(keys_entry.length != values_entry.length) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NOT_ALIGNED); - } - - // set the selection vectors and perform a duplicate key check - value_set_t unique_keys; - for (idx_t child_idx = 0; child_idx < keys_entry.length; child_idx++) { - - auto key_idx = keys_child_data.sel->get_index(keys_entry.offset + child_idx); - auto value_idx = values_child_data.sel->get_index(values_entry.offset + child_idx); - - // NULL check - if (!keys_child_data.validity.RowIsValid(key_idx)) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY); - } - - // unique check - auto value = keys_child_vector.GetValue(key_idx); - auto unique = unique_keys.insert(value).second; - if (!unique) { - MapVector::EvalMapInvalidReason(MapInvalidReason::DUPLICATE_KEY); - } - - // set selection vectors - sel_keys.set_index(offset + child_idx, key_idx); - sel_values.set_index(offset + child_idx, value_idx); - } - - // keys_entry and values_entry have the same length - result_entries[result_idx].length = keys_entry.length; - result_entries[result_idx].offset = offset; - offset += keys_entry.length; - } - D_ASSERT(offset == result_child_size); - - auto &result_key_vector = MapVector::GetKeys(result); - auto &result_value_vector = MapVector::GetValues(result); - - ListVector::SetListSize(result, offset); - result_key_vector.Slice(keys_child_vector, sel_keys, offset); - result_key_vector.Flatten(offset); - result_value_vector.Slice(values_child_vector, sel_values, offset); - result_value_vector.Flatten(offset); - FlatVector::Validity(ListVector::GetEntry(result)).Resize(result_child_size); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(row_count); -} - -static unique_ptr MapBind(ClientContext &, ScalarFunction &bound_function, - vector> &arguments) { - - if (arguments.size() != 2 && !arguments.empty()) { - MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); - } - - bool is_null = false; - if (arguments.empty()) { - is_null = true; - } - if (!is_null) { - auto key_id = arguments[0]->return_type.id(); - auto value_id = arguments[1]->return_type.id(); - if (key_id == LogicalTypeId::SQLNULL || value_id == LogicalTypeId::SQLNULL) { - is_null = true; - } - } - - if (is_null) { - bound_function.return_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); - return make_uniq(bound_function.return_type); - } - - // bind a MAP with key-value pairs - D_ASSERT(arguments.size() == 2); - if (arguments[0]->return_type.id() != LogicalTypeId::LIST) { - MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); - } - if (arguments[1]->return_type.id() != LogicalTypeId::LIST) { - MapVector::EvalMapInvalidReason(MapInvalidReason::INVALID_PARAMS); - } - - auto key_type = ListType::GetChildType(arguments[0]->return_type); - auto value_type = ListType::GetChildType(arguments[1]->return_type); - - bound_function.return_type = LogicalType::MAP(key_type, value_type); - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapFun::GetFunction() { - ScalarFunction fun({}, LogicalTypeId::MAP, MapFunction, MapBind); - fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(fun); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp b/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp deleted file mode 100644 index c958f41b9..000000000 --- 
a/src/duckdb/extension/core_functions/scalar/map/map_concat.cpp +++ /dev/null @@ -1,200 +0,0 @@ -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "core_functions/scalar/map_functions.hpp" - -namespace duckdb { - -namespace { - -struct MapKeyIndexPair { - MapKeyIndexPair(idx_t map, idx_t key) : map_index(map), key_index(key) { - } - // The index of the map that this key comes from - idx_t map_index; - // The index within the maps key_list - idx_t key_index; -}; - -} // namespace - -vector GetListEntries(vector keys, vector values) { - D_ASSERT(keys.size() == values.size()); - vector entries; - for (idx_t i = 0; i < keys.size(); i++) { - child_list_t children; - children.emplace_back(make_pair("key", std::move(keys[i]))); - children.emplace_back(make_pair("value", std::move(values[i]))); - entries.push_back(Value::STRUCT(std::move(children))); - } - return entries; -} - -static void MapConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - if (result.GetType().id() == LogicalTypeId::SQLNULL) { - // All inputs are NULL, just return NULL - auto &validity = FlatVector::Validity(result); - validity.SetInvalid(0); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - return; - } - D_ASSERT(result.GetType().id() == LogicalTypeId::MAP); - auto count = args.size(); - - auto map_count = args.ColumnCount(); - vector map_formats(map_count); - for (idx_t i = 0; i < map_count; i++) { - auto &map = args.data[i]; - map.ToUnifiedFormat(count, map_formats[i]); - } - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < count; i++) { - // Loop through all the maps per list - // we cant do better because all the entries of the child vector have to be contiguous - // so we cant start the next row before we have finished the one before it - auto &result_entry = result_data[i]; - vector index_to_map; - vector keys_list; - bool all_null = true; - for (idx_t map_idx = 0; map_idx < map_count; map_idx++) { - if (args.data[map_idx].GetType().id() == LogicalTypeId::SQLNULL) { - continue; - } - - auto &map_format = map_formats[map_idx]; - auto index = map_format.sel->get_index(i); - if (!map_format.validity.RowIsValid(index)) { - continue; - } - - all_null = false; - auto &keys = MapVector::GetKeys(args.data[map_idx]); - auto entry = UnifiedVectorFormat::GetData(map_format)[index]; - - // Update the list for this row - for (idx_t list_idx = 0; list_idx < entry.length; list_idx++) { - auto key_index = entry.offset + list_idx; - auto key = keys.GetValue(key_index); - auto entry = std::find(keys_list.begin(), keys_list.end(), key); - if (entry == keys_list.end()) { - // Result list does not contain this value yet - keys_list.push_back(key); - index_to_map.emplace_back(map_idx, key_index); - } else { - // Result list already contains this, update where to find the value at - auto distance = std::distance(keys_list.begin(), entry); - auto &mapping = *(index_to_map.begin() + distance); - mapping.key_index = key_index; - mapping.map_index = map_idx; - } - } - } - - result_entry.offset = ListVector::GetListSize(result); - result_entry.length = keys_list.size(); - if (all_null) { - D_ASSERT(keys_list.empty() && index_to_map.empty()); - 
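// The key-merging loop above keeps first-occurrence ordering for keys while
// letting later maps overwrite the value slot of an existing key. A standalone
// sketch of that merge rule over simple key/value pair lists (illustrative
// names, not DuckDB API):
#include <algorithm>
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using MapPairs = std::vector<std::pair<std::string, int>>;

MapPairs MapConcat(const std::vector<MapPairs> &maps) {
	MapPairs result;
	for (const auto &map : maps) {
		for (const auto &kv : map) {
			auto it = std::find_if(result.begin(), result.end(),
			                       [&](const auto &entry) { return entry.first == kv.first; });
			if (it == result.end()) {
				result.push_back(kv); // new key: append, preserving insertion order
			} else {
				it->second = kv.second; // existing key: the later map wins
			}
		}
	}
	return result;
}

int main() {
	MapPairs a {{"k1", 1}, {"k2", 2}};
	MapPairs b {{"k2", 20}, {"k3", 3}};
	auto merged = MapConcat({a, b});
	assert(merged.size() == 3);
	assert(merged[1].first == "k2" && merged[1].second == 20); // overwritten in place
	return 0;
}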
FlatVector::SetNull(result, i, true); - continue; - } - - vector values_list; - D_ASSERT(keys_list.size() == index_to_map.size()); - // Get the values from the mapping - for (auto &mapping : index_to_map) { - auto &map = args.data[mapping.map_index]; - auto &values = MapVector::GetValues(map); - values_list.push_back(values.GetValue(mapping.key_index)); - } - D_ASSERT(values_list.size() == keys_list.size()); - auto list_entries = GetListEntries(std::move(keys_list), std::move(values_list)); - for (auto &list_entry : list_entries) { - ListVector::PushBack(result, list_entry); - } - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static bool IsEmptyMap(const LogicalType &map) { - D_ASSERT(map.id() == LogicalTypeId::MAP); - auto &key_type = MapType::KeyType(map); - auto &value_type = MapType::ValueType(map); - return key_type.id() == LogicalType::SQLNULL && value_type.id() == LogicalType::SQLNULL; -} - -static unique_ptr MapConcatBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto arg_count = arguments.size(); - if (arg_count < 2) { - throw InvalidInputException("The provided amount of arguments is incorrect, please provide 2 or more maps"); - } - - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - LogicalType expected = LogicalType::SQLNULL; - - bool is_null = true; - // Check and verify that all the maps are of the same type - for (idx_t i = 0; i < arg_count; i++) { - auto &arg = arguments[i]; - auto &map = arg->return_type; - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - if (map.id() == LogicalTypeId::SQLNULL) { - // The maps are allowed to be NULL - continue; - } - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("MAP_CONCAT only takes map arguments"); - } - is_null = false; - if (IsEmptyMap(map)) { - // Map is allowed to be empty - continue; - } - - if (expected.id() == LogicalTypeId::SQLNULL) { - expected = map; - } else if (map != expected) { - throw InvalidInputException( - "'value' type of map differs between arguments, expected '%s', found '%s' instead", expected.ToString(), - map.ToString()); - } - } - - if (expected.id() == LogicalTypeId::SQLNULL && is_null == false) { - expected = LogicalType::MAP(LogicalType::SQLNULL, LogicalType::SQLNULL); - } - bound_function.return_type = expected; - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapConcatFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun("map_concat", {}, LogicalTypeId::LIST, MapConcatFunction, MapConcatBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp deleted file mode 100644 index 487fd75fa..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/map_entries.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -// Reverse of map_from_entries -static void MapEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - - auto &map = args.data[0]; - if (map.GetType().id() == LogicalTypeId::SQLNULL) { - // Input is a constant NULL - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - MapUtil::ReinterpretMap(result, map, count); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static LogicalType CreateReturnType(const LogicalType &map) { - auto &key_type = MapType::KeyType(map); - auto &value_type = MapType::ValueType(map); - - child_list_t child_types; - child_types.push_back(make_pair("key", key_type)); - child_types.push_back(make_pair("value", value_type)); - - auto row_type = LogicalType::STRUCT(child_types); - return LogicalType::LIST(row_type); -} - -static unique_ptr MapEntriesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 1) { - throw InvalidInputException("Too many arguments provided, only expecting a single map"); - } - auto &map = arguments[0]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (map.id() == LogicalTypeId::SQLNULL) { - // Input is NULL, output is STRUCT(NULL, NULL)[] - auto map_type = LogicalType::MAP(LogicalTypeId::SQLNULL, LogicalTypeId::SQLNULL); - bound_function.return_type = CreateReturnType(map_type); - return make_uniq(bound_function.return_type); - } - - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("The provided argument is not a map"); - } - bound_function.return_type = CreateReturnType(map); - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapEntriesFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::LIST, MapEntriesFunction, MapEntriesBind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp b/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp deleted file mode 100644 index 170f2b7da..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/map_extract.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/list/contains_or_position.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static unique_ptr MapExtractBind(ClientContext &, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 2) { - throw BinderException("MAP_EXTRACT must have exactly two arguments"); - } - - auto &map_type = arguments[0]->return_type; - auto &input_type = arguments[1]->return_type; - - if (map_type.id() == LogicalTypeId::SQLNULL) { - bound_function.return_type = LogicalTypeId::SQLNULL; - return make_uniq(bound_function.return_type); - } - - if (map_type.id() != LogicalTypeId::MAP) { - throw BinderException("MAP_EXTRACT can only operate on MAPs"); - } - auto &value_type = MapType::ValueType(map_type); - - //! Here we have to construct the List Type that will be returned - bound_function.return_type = value_type; - auto key_type = MapType::KeyType(map_type); - if (key_type.id() != LogicalTypeId::SQLNULL && input_type.id() != LogicalTypeId::SQLNULL) { - bound_function.arguments[1] = MapType::KeyType(map_type); - } - return make_uniq(bound_function.return_type); -} - -static void MapExtractFunc(DataChunk &args, ExpressionState &state, Vector &result) { - const auto count = args.size(); - - auto &map_vec = args.data[0]; - auto &arg_vec = args.data[1]; - - const auto map_is_null = map_vec.GetType().id() == LogicalTypeId::SQLNULL; - const auto arg_is_null = arg_vec.GetType().id() == LogicalTypeId::SQLNULL; - - if (map_is_null || arg_is_null) { - // Short-circuit if either the map or the arg is NULL - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - result.Verify(count); - return; - } - - auto &key_vec = MapVector::GetKeys(map_vec); - auto &val_vec = MapVector::GetValues(map_vec); - - // Collect the matching positions - Vector pos_vec(LogicalType::INTEGER, count); - ListSearchOp(map_vec, key_vec, arg_vec, pos_vec, args.size()); - - UnifiedVectorFormat pos_format; - UnifiedVectorFormat lst_format; - - pos_vec.ToUnifiedFormat(count, pos_format); - map_vec.ToUnifiedFormat(count, lst_format); - - const auto pos_data = UnifiedVectorFormat::GetData(pos_format); - const auto inc_list_data = ListVector::GetData(map_vec); - - auto &result_validity = FlatVector::Validity(result); - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto lst_idx = lst_format.sel->get_index(row_idx); - if (!lst_format.validity.RowIsValid(lst_idx)) { - FlatVector::SetNull(result, row_idx, true); - continue; - } - - const auto pos_idx = pos_format.sel->get_index(row_idx); - if (!pos_format.validity.RowIsValid(pos_idx)) { - // We didnt find the key in the map, so return NULL - result_validity.SetInvalid(row_idx); - continue; - } - - // Compute the actual position of the 
value in the map value vector - const auto pos = inc_list_data[lst_idx].offset + UnsafeNumericCast(pos_data[pos_idx] - 1); - VectorOperations::Copy(val_vec, result, pos + 1, pos, row_idx); - } - - if (args.size() == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(count); -} - -ScalarFunction MapExtractFun::GetFunction() { - ScalarFunction fun({LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, MapExtractFunc, MapExtractBind); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp b/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp deleted file mode 100644 index edbe1d4fb..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/map_from_entries.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -static void MapFromEntriesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto count = args.size(); - - MapUtil::ReinterpretMap(result, args.data[0], count); - MapVector::MapConversionVerify(result, count); - result.Verify(count); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr MapFromEntriesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() != 1) { - throw InvalidInputException("The input argument must be a list of structs."); - } - auto &list = arguments[0]->return_type; - - if (list.id() == LogicalTypeId::UNKNOWN) { - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (list.id() != LogicalTypeId::LIST) { - throw InvalidInputException("The provided argument is not a list of structs"); - } - auto &elem_type = ListType::GetChildType(list); - if (elem_type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("The elements of the list must be structs"); - } - auto &children = StructType::GetChildTypes(elem_type); - if (children.size() != 2) { - throw InvalidInputException("The provided struct type should only contain 2 fields, a key and a value"); - } - - bound_function.return_type = LogicalType::MAP(elem_type); - return make_uniq(bound_function.return_type); -} - -ScalarFunction MapFromEntriesFun::GetFunction() { - //! 
the arguments and return types are actually set in the binder function - ScalarFunction fun({}, LogicalTypeId::MAP, MapFromEntriesFunction, MapFromEntriesBind); - fun.null_handling = FunctionNullHandling::DEFAULT_NULL_HANDLING; - fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(fun); - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp b/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp deleted file mode 100644 index 6d99a353e..000000000 --- a/src/duckdb/extension/core_functions/scalar/map/map_keys_values.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "core_functions/scalar/map_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -static void MapKeyValueFunction(DataChunk &args, ExpressionState &state, Vector &result, - Vector &(*get_child_vector)(Vector &)) { - auto &map = args.data[0]; - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - if (map.GetType().id() == LogicalTypeId::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - auto count = args.size(); - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); - auto child = get_child_vector(map); - - auto &entries = ListVector::GetEntry(result); - entries.Reference(child); - - UnifiedVectorFormat map_data; - map.ToUnifiedFormat(count, map_data); - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - FlatVector::SetData(result, map_data.data); - FlatVector::SetValidity(result, map_data.validity); - auto list_size = ListVector::GetListSize(map); - ListVector::SetListSize(result, list_size); - if (map.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - result.Slice(*map_data.sel, count); - } - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - result.Verify(count); -} - -static void MapKeysFunction(DataChunk &args, ExpressionState &state, Vector &result) { - MapKeyValueFunction(args, state, result, MapVector::GetKeys); -} - -static void MapValuesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - MapKeyValueFunction(args, state, result, MapVector::GetValues); -} - -static unique_ptr MapKeyValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, - const LogicalType &(*type_func)(const LogicalType &)) { - if (arguments.size() != 1) { - throw InvalidInputException("Too many arguments provided, only expecting a single map"); - } - auto &map = arguments[0]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - // Prepared statement - bound_function.arguments.emplace_back(LogicalTypeId::UNKNOWN); - bound_function.return_type = LogicalType(LogicalTypeId::SQLNULL); - return nullptr; - } - - if (map.id() == LogicalTypeId::SQLNULL) { - // Input is NULL, output is NULL[] - bound_function.return_type = LogicalType::LIST(LogicalTypeId::SQLNULL); - return make_uniq(bound_function.return_type); - } - - if (map.id() != LogicalTypeId::MAP) { - throw InvalidInputException("The provided argument is not a map"); - } - - auto &type = type_func(map); - - bound_function.return_type = LogicalType::LIST(type); - return make_uniq(bound_function.return_type); -} - -static unique_ptr 
MapKeysBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return MapKeyValueBind(context, bound_function, arguments, MapType::KeyType); -} - -static unique_ptr MapValuesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return MapKeyValueBind(context, bound_function, arguments, MapType::ValueType); -} - -ScalarFunction MapKeysFun::GetFunction() { - //! the arguments and return types are actually set in the binder function - ScalarFunction function({}, LogicalTypeId::LIST, MapKeysFunction, MapKeysBind); - function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(function); - function.varargs = LogicalType::ANY; - return function; -} - -ScalarFunction MapValuesFun::GetFunction() { - ScalarFunction function({}, LogicalTypeId::LIST, MapValuesFunction, MapValuesBind); - function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(function); - function.varargs = LogicalType::ANY; - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp b/src/duckdb/extension/core_functions/scalar/math/numeric.cpp deleted file mode 100644 index 47eed7a79..000000000 --- a/src/duckdb/extension/core_functions/scalar/math/numeric.cpp +++ /dev/null @@ -1,1469 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/likely.hpp" -#include "duckdb/common/operator/abs.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/operator/numeric_binary_operators.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/types/validity_mask.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "core_functions/scalar/math_functions.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerUnaryFunctionFixedReturn(const LogicalType &type) { - scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::UnaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunctionFixedReturn"); - } - return function; -} - -//===--------------------------------------------------------------------===// -// nextafter -//===--------------------------------------------------------------------===// -struct NextAfterOperator { - template - static inline TR Operation(TA base, TB exponent) { - throw NotImplementedException("Unimplemented type for NextAfter Function"); - } - - template - static inline double Operation(double input, double approximate_to) { - return nextafter(input, approximate_to); - } - template - static inline float Operation(float input, float approximate_to) { - return 
nextafterf(input, approximate_to); - } -}; - -ScalarFunctionSet NextAfterFun::GetFunctions() { - ScalarFunctionSet next_after_fun; - next_after_fun.AddFunction( - ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction)); - next_after_fun.AddFunction(ScalarFunction({LogicalType::FLOAT, LogicalType::FLOAT}, LogicalType::FLOAT, - ScalarFunction::BinaryFunction)); - return next_after_fun; -} - -//===--------------------------------------------------------------------===// -// abs -//===--------------------------------------------------------------------===// -static unique_ptr PropagateAbsStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - auto &lstats = child_stats[0]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(lstats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT16: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT32: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - case PhysicalType::INT64: - potential_overflow = NumericStats::Min(lstats).GetValue() == NumericLimits::Minimum(); - break; - default: - return nullptr; - } - } - if (potential_overflow) { - new_min = Value(expr.return_type); - new_max = Value(expr.return_type); - } else { - // no potential overflow - - // compute stats - auto current_min = NumericStats::Min(lstats).GetValue(); - auto current_max = NumericStats::Max(lstats).GetValue(); - - int64_t min_val, max_val; - - if (current_min < 0 && current_max < 0) { - // if both min and max are below zero, then min=abs(cur_max) and max=abs(cur_min) - min_val = AbsValue(current_max); - max_val = AbsValue(current_min); - } else if (current_min < 0) { - D_ASSERT(current_max >= 0); - // if min is below zero and max is above 0, then min=0 and max=max(cur_max, abs(cur_min)) - min_val = 0; - max_val = MaxValue(AbsValue(current_min), current_max); - } else { - // if both current_min and current_max are > 0, then the abs is a no-op and can be removed entirely - *input.expr_ptr = std::move(input.expr.children[0]); - return child_stats[0].ToUnique(); - } - new_min = Value::Numeric(expr.return_type, min_val); - new_max = Value::Numeric(expr.return_type, max_val); - expr.function.function = ScalarFunction::GetScalarUnaryFunction(expr.return_type); - } - auto stats = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(stats, new_min); - NumericStats::SetMax(stats, new_max); - stats.CopyValidity(lstats); - return stats.ToUnique(); -} - -template -unique_ptr DecimalUnaryOpBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); - break; - case PhysicalType::INT32: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); - break; - case PhysicalType::INT64: - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); - break; - default: - 
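// PropagateAbsStats above maps the input's [min, max] statistics through
// abs(): an entirely negative interval flips to [-max, -min]; an interval
// straddling zero becomes [0, max(-min, max)]; a non-negative interval is
// unchanged, so abs() is removed as a no-op. A standalone sketch of that
// interval rule (it assumes min is not INT64_MIN, the overflow case the
// original handles separately):
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> AbsRange(int64_t min_val, int64_t max_val) {
	if (max_val < 0) {
		return {-max_val, -min_val}; // entirely negative interval flips
	}
	if (min_val < 0) {
		return {0, std::max(-min_val, max_val)}; // interval straddles zero
	}
	return {min_val, max_val}; // abs is a no-op here
}

int main() {
	assert(AbsRange(-10, -2) == (std::pair<int64_t, int64_t> {2, 10}));
	assert(AbsRange(-10, 4) == (std::pair<int64_t, int64_t> {0, 10}));
	assert(AbsRange(3, 9) == (std::pair<int64_t, int64_t> {3, 9}));
	return 0;
}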
bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); - break; - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; - return nullptr; -} - -ScalarFunctionSet AbsOperatorFun::GetFunctions() { - ScalarFunctionSet abs; - for (auto &type : LogicalType::Numeric()) { - switch (type.id()) { - case LogicalTypeId::DECIMAL: - abs.AddFunction(ScalarFunction({type}, type, nullptr, DecimalUnaryOpBind)); - break; - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: { - ScalarFunction function({type}, type, ScalarFunction::GetScalarUnaryFunction(type)); - function.statistics = PropagateAbsStats; - abs.AddFunction(function); - break; - } - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::NopFunction)); - break; - default: - abs.AddFunction(ScalarFunction({type}, type, ScalarFunction::GetScalarUnaryFunction(type))); - break; - } - } - for (auto &func : abs.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return abs; -} - -//===--------------------------------------------------------------------===// -// bit_count -//===--------------------------------------------------------------------===// -struct BitCntOperator { - template - static inline TR Operation(TA input) { - using TU = typename std::make_unsigned::type; - TR count = 0; - for (auto value = TU(input); value; ++count) { - value &= (value - 1); - } - return count; - } -}; - -struct HugeIntBitCntOperator { - template - static inline TR Operation(TA input) { - using TU = typename std::make_unsigned::type; - TR count = 0; - - for (auto value = TU(input.upper); value; ++count) { - value &= (value - 1); - } - for (auto value = TU(input.lower); value; ++count) { - value &= (value - 1); - } - return count; - } -}; - -struct BitStringBitCntOperator { - template - static inline TR Operation(TA input) { - TR count = Bit::BitCount(input); - return count; - } -}; - -ScalarFunctionSet BitCountFun::GetFunctions() { - ScalarFunctionSet functions; - functions.AddFunction(ScalarFunction({LogicalType::TINYINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::SMALLINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::INTEGER}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::BIGINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::TINYINT, - ScalarFunction::UnaryFunction)); - functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - return functions; -} - -//===--------------------------------------------------------------------===// -// sign -//===--------------------------------------------------------------------===// -struct SignOperator { - template - static TR Operation(TA input) { - if (input == TA(0)) { - return 0; - } else if (input > TA(0)) { - return 1; - } else { - return -1; - } - } -}; - -template <> -int8_t SignOperator::Operation(float input) { - if (input == 0 || Value::IsNan(input)) { - return 0; - } else if (input > 0) { - return 1; - } else { - return -1; - } -} - -template <> -int8_t 
SignOperator::Operation(double input) { - if (input == 0 || Value::IsNan(input)) { - return 0; - } else if (input > 0) { - return 1; - } else { - return -1; - } -} - -ScalarFunctionSet SignFun::GetFunctions() { - ScalarFunctionSet sign; - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - continue; - } else { - sign.AddFunction( - ScalarFunction({type}, LogicalType::TINYINT, - ScalarFunction::GetScalarUnaryFunctionFixedReturn(type))); - } - } - return sign; -} - -//===--------------------------------------------------------------------===// -// ceil -//===--------------------------------------------------------------------===// -struct CeilOperator { - template - static inline TR Operation(TA left) { - return std::ceil(left); - } -}; - -template -static void GenericRoundFunctionDecimal(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - OP::template Operation(input, DecimalType::GetScale(func_expr.children[0]->return_type), result); -} - -template -unique_ptr BindGenericRoundFunctionDecimal(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // ceil essentially removes the scale - auto &decimal_type = arguments[0]->return_type; - auto scale = DecimalType::GetScale(decimal_type); - auto width = DecimalType::GetWidth(decimal_type); - if (scale == 0) { - bound_function.function = ScalarFunction::NopFunction; - } else { - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = GenericRoundFunctionDecimal; - break; - case PhysicalType::INT32: - bound_function.function = GenericRoundFunctionDecimal; - break; - case PhysicalType::INT64: - bound_function.function = GenericRoundFunctionDecimal; - break; - default: - bound_function.function = GenericRoundFunctionDecimal; - break; - } - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, 0); - return nullptr; -} - -struct CeilDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input <= 0) { - // below 0 we floor the number (e.g. 
-10.5 -> -10) - return UnsafeNumericCast(input / power_of_ten); - } else { - // above 0 we ceil the number - return UnsafeNumericCast(((input - 1) / power_of_ten) + 1); - } - }); - } -}; - -ScalarFunctionSet CeilFun::GetFunctions() { - ScalarFunctionSet ceil; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - if (type.IsIntegral()) { - // no ceil for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - default: - throw InternalException("Unimplemented numeric type for function \"ceil\""); - } - ceil.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return ceil; -} - -//===--------------------------------------------------------------------===// -// floor -//===--------------------------------------------------------------------===// -struct FloorOperator { - template - static inline TR Operation(TA left) { - return std::floor(left); - } -}; - -struct FloorDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - // below 0 we ceil the number (e.g. -10.5 -> -11) - return UnsafeNumericCast(((input + 1) / power_of_ten) - 1); - } else { - // above 0 we floor the number - return UnsafeNumericCast(input / power_of_ten); - } - }); - } -}; - -ScalarFunctionSet FloorFun::GetFunctions() { - ScalarFunctionSet floor; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - if (type.IsIntegral()) { - // no floor for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - default: - throw InternalException("Unimplemented numeric type for function \"floor\""); - } - floor.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return floor; -} - -//===--------------------------------------------------------------------===// -// trunc -//===--------------------------------------------------------------------===// -struct TruncOperator { - // Integer truncation is a NOP - template - static inline TR Operation(TA left) { - return std::trunc(left); - } -}; - -struct TruncDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - // Always floor - return UnsafeNumericCast((input / power_of_ten)); - }); - } -}; - -ScalarFunctionSet TruncFun::GetFunctions() { - ScalarFunctionSet trunc; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t func = nullptr; - bind_scalar_function_t bind_func = nullptr; - // Truncation of integers gets generated by some tools (e.g., Tableau/JDBC:Postgres) - switch (type.id()) { - case LogicalTypeId::FLOAT: - func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DOUBLE: - 
func = ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - break; - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - func = ScalarFunction::NopFunction; - break; - default: - throw InternalException("Unimplemented numeric type for function \"trunc\""); - } - trunc.AddFunction(ScalarFunction({type}, type, func, bind_func)); - } - return trunc; -} - -//===--------------------------------------------------------------------===// -// round -//===--------------------------------------------------------------------===// -struct RoundOperatorPrecision { - template - static inline TR Operation(TA input, TB precision) { - double rounded_value; - if (precision < 0) { - double modifier = std::pow(10, -TA(precision)); - rounded_value = (std::round(input / modifier)) * modifier; - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return 0; - } - } else { - double modifier = std::pow(10, TA(precision)); - rounded_value = (std::round(input * modifier)) / modifier; - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return input; - } - } - return LossyNumericCast(rounded_value); - } -}; - -struct RoundOperator { - template - static inline TR Operation(TA input) { - double rounded_value = round(input); - if (std::isinf(rounded_value) || std::isnan(rounded_value)) { - return input; - } - return LossyNumericCast(rounded_value); - } -}; - -struct RoundDecimalOperator { - template - static void Operation(DataChunk &input, uint8_t scale, Vector &result) { - T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[scale]); - T addition = power_of_ten / 2; - // regular round rounds towards the nearest number - // in case of a tie we round away from zero - // i.e. -10.5 -> -11, 10.5 -> 11 - // we implement this by adding (positive) or subtracting (negative) 0.5 - // and then flooring the number - // e.g. 
10.5 + 0.5 = 11, floor(11) = 11 - // 10.4 + 0.5 = 10.9, floor(10.9) = 10 - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return UnsafeNumericCast(input / power_of_ten); - }); - } -}; - -struct RoundPrecisionFunctionData : public FunctionData { - explicit RoundPrecisionFunctionData(int32_t target_scale) : target_scale(target_scale) { - } - - int32_t target_scale; - - unique_ptr Copy() const override { - return make_uniq(target_scale); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return target_scale == other.target_scale; - } -}; - -template -static void DecimalRoundNegativePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); - auto width = DecimalType::GetWidth(func_expr.children[0]->return_type); - if (info.target_scale <= -int32_t(width - source_scale)) { - // scale too big for width - result.SetVectorType(VectorType::CONSTANT_VECTOR); - result.SetValue(0, Value::INTEGER(0)); - return; - } - T divide_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale + source_scale]); - T multiply_power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[-info.target_scale]); - T addition = divide_power_of_ten / 2; - - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return UnsafeNumericCast(input / divide_power_of_ten * multiply_power_of_ten); - }); -} - -template -static void DecimalRoundPositivePrecisionFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto source_scale = DecimalType::GetScale(func_expr.children[0]->return_type); - T power_of_ten = UnsafeNumericCast(POWERS_OF_TEN_CLASS::POWERS_OF_TEN[source_scale - info.target_scale]); - T addition = power_of_ten / 2; - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](T input) { - if (input < 0) { - input -= addition; - } else { - input += addition; - } - return UnsafeNumericCast(input / power_of_ten); - }); -} - -unique_ptr BindDecimalRoundPrecision(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto &decimal_type = arguments[0]->return_type; - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - } - Value val = ExpressionExecutor::EvaluateScalar(context, *arguments[1]).DefaultCastAs(LogicalType::INTEGER); - if (val.IsNull()) { - throw NotImplementedException("ROUND(DECIMAL, INTEGER) with non-constant precision is not supported"); - } - // our new precision becomes the round value - // e.g. ROUND(DECIMAL(18,3), 1) -> DECIMAL(18,1) - // but ONLY if the round value is positive - // if it is negative the scale becomes zero - // i.e. 
ROUND(DECIMAL(18,3), -1) -> DECIMAL(18,0) - int32_t round_value = IntegerValue::Get(val); - uint8_t target_scale; - auto width = DecimalType::GetWidth(decimal_type); - auto scale = DecimalType::GetScale(decimal_type); - if (round_value < 0) { - target_scale = 0; - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundNegativePrecisionFunction; - break; - } - } else { - if (round_value >= (int32_t)scale) { - // if round_value is bigger than or equal to scale we do nothing - bound_function.function = ScalarFunction::NopFunction; - target_scale = scale; - } else { - target_scale = NumericCast(round_value); - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT32: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - case PhysicalType::INT64: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - default: - bound_function.function = DecimalRoundPositivePrecisionFunction; - break; - } - } - } - bound_function.arguments[0] = decimal_type; - bound_function.return_type = LogicalType::DECIMAL(width, target_scale); - return make_uniq(round_value); -} - -ScalarFunctionSet RoundFun::GetFunctions() { - ScalarFunctionSet round; - for (auto &type : LogicalType::Numeric()) { - scalar_function_t round_prec_func = nullptr; - scalar_function_t round_func = nullptr; - bind_scalar_function_t bind_func = nullptr; - bind_scalar_function_t bind_prec_func = nullptr; - if (type.IsIntegral()) { - // no round for integral numbers - continue; - } - switch (type.id()) { - case LogicalTypeId::FLOAT: - round_func = ScalarFunction::UnaryFunction; - round_prec_func = ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::DOUBLE: - round_func = ScalarFunction::UnaryFunction; - round_prec_func = ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::DECIMAL: - bind_func = BindGenericRoundFunctionDecimal; - bind_prec_func = BindDecimalRoundPrecision; - break; - default: - throw InternalException("Unimplemented numeric type for function \"floor\""); - } - round.AddFunction(ScalarFunction({type}, type, round_func, bind_func)); - round.AddFunction(ScalarFunction({type, LogicalType::INTEGER}, type, round_prec_func, bind_prec_func)); - } - return round; -} - -//===--------------------------------------------------------------------===// -// exp -//===--------------------------------------------------------------------===// -struct ExpOperator { - template - static inline TR Operation(TA left) { - return std::exp(left); - } -}; - -ScalarFunction ExpFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// pow -//===--------------------------------------------------------------------===// -struct PowOperator { - template - static inline TR Operation(TA base, TB exponent) { - return std::pow(base, exponent); - } -}; - -ScalarFunction PowOperatorFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - 
ScalarFunction::BinaryFunction); -} - -//===--------------------------------------------------------------------===// -// sqrt -//===--------------------------------------------------------------------===// -struct SqrtOperator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take square root of a negative number"); - } - return std::sqrt(input); - } -}; - -ScalarFunction SqrtFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// cbrt -//===--------------------------------------------------------------------===// -struct CbRtOperator { - template - static inline TR Operation(TA left) { - return std::cbrt(left); - } -}; - -ScalarFunction CbrtFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// ln -//===--------------------------------------------------------------------===// - -struct LnOperator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log(input); - } -}; - -ScalarFunction LnFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// log -//===--------------------------------------------------------------------===// -struct Log10Operator { - template - static inline TR Operation(TA input) { - if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log10(input); - } -}; - -ScalarFunction Log10Fun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// log with base -//===--------------------------------------------------------------------===// -struct LogBaseOperator { - template - static inline TR Operation(TA b, TB x) { - auto divisor = Log10Operator::Operation(b); - if (divisor == 0) { - throw OutOfRangeException("divison by zero in based logarithm"); - } - return Log10Operator::Operation(x) / divisor; - } -}; - -ScalarFunctionSet LogFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction)); - for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return funcs; -} - -//===--------------------------------------------------------------------===// -// log2 -//===--------------------------------------------------------------------===// -struct Log2Operator { - template - static inline TR Operation(TA 
input) { - if (input < 0) { - throw OutOfRangeException("cannot take logarithm of a negative number"); - } - if (input == 0) { - throw OutOfRangeException("cannot take logarithm of zero"); - } - return std::log2(input); - } -}; - -ScalarFunction Log2Fun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// pi -//===--------------------------------------------------------------------===// -static void PiFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - Value pi_value = Value::DOUBLE(PI); - result.Reference(pi_value); -} - -ScalarFunction PiFun::GetFunction() { - return ScalarFunction({}, LogicalType::DOUBLE, PiFunction); -} - -//===--------------------------------------------------------------------===// -// degrees -//===--------------------------------------------------------------------===// -struct DegreesOperator { - template - static inline TR Operation(TA left) { - return left * (180 / PI); - } -}; - -ScalarFunction DegreesFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// radians -//===--------------------------------------------------------------------===// -struct RadiansOperator { - template - static inline TR Operation(TA left) { - return left * (PI / 180); - } -}; - -ScalarFunction RadiansFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// isnan -//===--------------------------------------------------------------------===// -struct IsNanOperator { - template - static inline TR Operation(TA input) { - return Value::IsNan(input); - } -}; - -ScalarFunctionSet IsNanFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// signbit -//===--------------------------------------------------------------------===// -struct SignBitOperator { - template - static inline TR Operation(TA input) { - return std::signbit(input); - } -}; - -ScalarFunctionSet SignBitFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// isinf -//===--------------------------------------------------------------------===// -struct IsInfiniteOperator { - template - static inline TR Operation(TA input) { - return !Value::IsNan(input) && !Value::IsFinite(input); - } -}; - -template <> -bool IsInfiniteOperator::Operation(date_t input) { - return !Value::IsFinite(input); -} - -template <> -bool IsInfiniteOperator::Operation(timestamp_t input) { - return !Value::IsFinite(input); -} - 
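IsInfiniteOperator above must exclude NaN explicitly, because in IEEE 754 classification a NaN is neither finite nor infinite. A minimal standalone sketch of that predicate logic (plain C++ with std::isnan/std::isfinite standing in for DuckDB's Value helpers; the IsInfinite name is invented here):

#include <cassert>
#include <cmath>
#include <limits>

// Illustrative restatement of IsInfiniteOperator's predicate: "infinite"
// means "not NaN and not finite", so NaN comes out neither finite nor infinite.
static bool IsInfinite(double d) {
	return !std::isnan(d) && !std::isfinite(d);
}

int main() {
	const double inf = std::numeric_limits<double>::infinity();
	const double nan = std::numeric_limits<double>::quiet_NaN();
	assert(IsInfinite(inf) && IsInfinite(-inf));
	assert(!IsInfinite(nan));    // NaN is not infinite...
	assert(!std::isfinite(nan)); // ...and not finite either
	assert(!IsInfinite(42.0));   // ordinary values are simply finite
	return 0;
}
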
-ScalarFunctionSet IsInfiniteFun::GetFunctions() { - ScalarFunctionSet funcs("isinf"); - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// isfinite -//===--------------------------------------------------------------------===// -struct IsFiniteOperator { - template - static inline TR Operation(TA input) { - return Value::IsFinite(input); - } -}; - -ScalarFunctionSet IsFiniteFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction(ScalarFunction({LogicalType::FLOAT}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DOUBLE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::DATE}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - funcs.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_TZ}, LogicalType::BOOLEAN, - ScalarFunction::UnaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// sin -//===--------------------------------------------------------------------===// -template -struct NoInfiniteDoubleWrapper { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { - if (Value::IsNan(input)) { - return input; - } - throw OutOfRangeException("input value %lf is out of range for numeric function", input); - } - return OP::template Operation(input); - } -}; - -struct SinOperator { - template - static inline TR Operation(TA input) { - return std::sin(input); - } -}; - -ScalarFunction SinFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// cos -//===--------------------------------------------------------------------===// -struct CosOperator { - template - static inline TR Operation(TA input) { - return (double)std::cos(input); - } -}; - -ScalarFunction CosFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// tan -//===--------------------------------------------------------------------===// -struct TanOperator { - template - static inline TR Operation(TA input) { - return (double)std::tan(input); - } -}; - -ScalarFunction TanFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - 
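NoInfiniteDoubleWrapper is the interesting part of the sin/cos/tan definitions above: NaN passes through untouched, ±infinity raises, and only finite inputs reach the wrapped operator. A compile-and-run sketch of the same guard pattern under stated assumptions (standalone C++; the NoInfiniteGuard/SinOp names and the std::out_of_range error are stand-ins, not DuckDB's):

#include <cmath>
#include <cstdio>
#include <stdexcept>

// Guard template: NaN propagates, +/-infinity is rejected, and finite input
// is forwarded to the wrapped operator (sin/cos/tan have no value at infinity).
template <class OP>
struct NoInfiniteGuard {
	static double Operation(double input) {
		if (!std::isfinite(input)) {
			if (std::isnan(input)) {
				return input; // NaN in, NaN out
			}
			throw std::out_of_range("input value is out of range for numeric function");
		}
		return OP::Operation(input);
	}
};

struct SinOp {
	static double Operation(double input) {
		return std::sin(input);
	}
};

int main() {
	std::printf("%f\n", NoInfiniteGuard<SinOp>::Operation(1.0)); // prints 0.841471
	// NoInfiniteGuard<SinOp>::Operation(INFINITY) would throw std::out_of_range.
	return 0;
}
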
-//===--------------------------------------------------------------------===// -// asin -//===--------------------------------------------------------------------===// -struct ASinOperator { - template - static inline TR Operation(TA input) { - if (input < -1 || input > 1) { - throw InvalidInputException("ASIN is undefined outside [-1,1]"); - } - return (double)std::asin(input); - } -}; - -ScalarFunction AsinFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// atan -//===--------------------------------------------------------------------===// -struct ATanOperator { - template - static inline TR Operation(TA input) { - return (double)std::atan(input); - } -}; - -ScalarFunction AtanFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// atan2 -//===--------------------------------------------------------------------===// -struct ATan2 { - template - static inline TR Operation(TA left, TB right) { - return (double)std::atan2(left, right); - } -}; - -ScalarFunction Atan2Fun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::BinaryFunction); -} - -//===--------------------------------------------------------------------===// -// acos -//===--------------------------------------------------------------------===// -struct ACos { - template - static inline TR Operation(TA input) { - if (input < -1 || input > 1) { - throw InvalidInputException("ACOS is undefined outside [-1,1]"); - } - return (double)std::acos(input); - } -}; - -ScalarFunction AcosFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// cosh -//===--------------------------------------------------------------------===// -struct CoshOperator { - template - static inline TR Operation(TA input) { - return (double)std::cosh(input); - } -}; - -ScalarFunction CoshFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// acosh -//===--------------------------------------------------------------------===// -struct AcoshOperator { - template - static inline TR Operation(TA input) { - return (double)std::acosh(input); - } -}; - -ScalarFunction AcoshFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// sinh -//===--------------------------------------------------------------------===// -struct SinhOperator { - template - static inline TR Operation(TA input) { - return (double)std::sinh(input); - } -}; - -ScalarFunction SinhFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// asinh 
-//===--------------------------------------------------------------------===// -struct AsinhOperator { - template - static inline TR Operation(TA input) { - return (double)std::asinh(input); - } -}; - -ScalarFunction AsinhFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// tanh -//===--------------------------------------------------------------------===// -struct TanhOperator { - template - static inline TR Operation(TA input) { - return (double)std::tanh(input); - } -}; - -ScalarFunction TanhFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// atanh -//===--------------------------------------------------------------------===// -struct AtanhOperator { - template - static inline TR Operation(TA input) { - if (input < -1 || input > 1) { - throw InvalidInputException("ATANH is undefined outside [-1,1]"); - } - if (input == -1 || input == 1) { - return INFINITY; - } - return (double)std::atanh(input); - } -}; - -ScalarFunction AtanhFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// cot -//===--------------------------------------------------------------------===// -template -struct NoInfiniteNoZeroDoubleWrapper { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - if (DUCKDB_UNLIKELY(!Value::IsFinite(input))) { - if (Value::IsNan(input)) { - return input; - } - throw OutOfRangeException("input value %lf is out of range for numeric function", input); - } - if (DUCKDB_UNLIKELY((double)input == 0.0 || (double)input == -0.0)) { - throw OutOfRangeException("input value %lf is out of range for numeric function cotangent", input); - } - return OP::template Operation(input); - } -}; - -struct CotOperator { - template - static inline TR Operation(TA input) { - return 1.0 / (double)std::tan(input); - } -}; - -ScalarFunction CotFun::GetFunction() { - ScalarFunction function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction>); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// gamma -//===--------------------------------------------------------------------===// -struct GammaOperator { - template - static inline TR Operation(TA input) { - if (input == 0) { - throw OutOfRangeException("cannot take gamma of zero"); - } - return std::tgamma(input); - } -}; - -ScalarFunction GammaFun::GetFunction() { - auto func = ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(func); - return func; -} - -//===--------------------------------------------------------------------===// -// gamma -//===--------------------------------------------------------------------===// -struct LogGammaOperator { - template - static inline TR Operation(TA input) { - if (input == 0) { - throw OutOfRangeException("cannot take log gamma of zero"); - } - return std::lgamma(input); - } -}; - -ScalarFunction LogGammaFun::GetFunction() { - ScalarFunction 
function({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// factorial(), ! -//===--------------------------------------------------------------------===// -struct FactorialOperator { - template - static inline TR Operation(TA left) { - TR ret = 1; - for (TA i = 2; i <= left; i++) { - if (!TryMultiplyOperator::Operation(ret, TR(i), ret)) { - throw OutOfRangeException("Value out of range"); - } - } - return ret; - } -}; - -ScalarFunction FactorialOperatorFun::GetFunction() { - ScalarFunction function({LogicalType::INTEGER}, LogicalType::HUGEINT, - ScalarFunction::UnaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -//===--------------------------------------------------------------------===// -// even -//===--------------------------------------------------------------------===// -struct EvenOperator { - template - static inline TR Operation(TA left) { - double value; - if (left >= 0) { - value = std::ceil(left); - } else { - value = std::ceil(-left); - value = -value; - } - if (std::floor(value / 2) * 2 != value) { - if (left >= 0) { - return value += 1; - } - return value -= 1; - } - return value; - } -}; - -ScalarFunction EvenFun::GetFunction() { - return ScalarFunction({LogicalType::DOUBLE}, LogicalType::DOUBLE, - ScalarFunction::UnaryFunction); -} - -//===--------------------------------------------------------------------===// -// gcd -//===--------------------------------------------------------------------===// - -// should be replaced with std::gcd in a newer C++ standard -template -TA GreatestCommonDivisor(TA left, TA right) { - TA a = left; - TA b = right; - - // This protects the following modulo operations from a corner case, - // where we would get a runtime error due to an integer overflow. 
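	// Concretely: in two's complement |Minimum()| equals Maximum() + 1, so the
	// quotient Minimum() / -1 is unrepresentable, and in C++ the matching
	// remainder Minimum() % -1 is undefined behavior (on x86 it traps with
	// SIGFPE rather than returning 0). Returning 1 is not just a safe escape:
	// gcd(x, -1) is 1 for every x, so the answer is exact for both orderings.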
- if ((left == NumericLimits::Minimum() && right == -1) || - (left == -1 && right == NumericLimits::Minimum())) { - return 1; - } - - while (true) { - if (a == 0) { - return TryAbsOperator::Operation(b); - } - b %= a; - - if (b == 0) { - return TryAbsOperator::Operation(a); - } - a %= b; - } -} - -struct GreatestCommonDivisorOperator { - template - static inline TR Operation(TA left, TB right) { - return GreatestCommonDivisor(left, right); - } -}; - -ScalarFunctionSet GreatestCommonDivisorFun::GetFunctions() { - ScalarFunctionSet funcs; - funcs.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction)); - funcs.AddFunction( - ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, - ScalarFunction::BinaryFunction)); - return funcs; -} - -//===--------------------------------------------------------------------===// -// lcm -//===--------------------------------------------------------------------===// - -// should be replaced with std::lcm in a newer C++ standard -struct LeastCommonMultipleOperator { - template - static inline TR Operation(TA left, TB right) { - if (left == 0 || right == 0) { - return 0; - } - TR result; - if (!TryMultiplyOperator::Operation(left, right / GreatestCommonDivisor(left, right), result)) { - throw OutOfRangeException("lcm value is out of range"); - } - return TryAbsOperator::Operation(result); - } -}; - -ScalarFunctionSet LeastCommonMultipleFun::GetFunctions() { - ScalarFunctionSet funcs; - - funcs.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::BIGINT}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction)); - funcs.AddFunction( - ScalarFunction({LogicalType::HUGEINT, LogicalType::HUGEINT}, LogicalType::HUGEINT, - ScalarFunction::BinaryFunction)); - for (auto &function : funcs.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return funcs; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp b/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp deleted file mode 100644 index 103e0c2ab..000000000 --- a/src/duckdb/extension/core_functions/scalar/operators/bitwise.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "core_functions/scalar/operators_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/bit.hpp" - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerUnaryFunction(const LogicalType &type) { - scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UTINYINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::USMALLINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UINTEGER: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UBIGINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::UnaryFunction; - break; - case LogicalTypeId::UHUGEINT: - function = &ScalarFunction::UnaryFunction; - break; - default: - throw 
NotImplementedException("Unimplemented type for GetScalarIntegerUnaryFunction"); - } - return function; -} - -template -static scalar_function_t GetScalarIntegerBinaryFunction(const LogicalType &type) { - scalar_function_t function; - switch (type.id()) { - case LogicalTypeId::TINYINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::SMALLINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::INTEGER: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::BIGINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UTINYINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::USMALLINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UINTEGER: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UBIGINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::HUGEINT: - function = &ScalarFunction::BinaryFunction; - break; - case LogicalTypeId::UHUGEINT: - function = &ScalarFunction::BinaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarIntegerBinaryFunction"); - } - return function; -} - -//===--------------------------------------------------------------------===// -// & [bitwise_and] -//===--------------------------------------------------------------------===// -struct BitwiseANDOperator { - template - static inline TR Operation(TA left, TB right) { - return left & right; - } -}; - -static void BitwiseANDOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseAnd(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseAndFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseANDOperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -//===--------------------------------------------------------------------===// -// | [bitwise_or] -//===--------------------------------------------------------------------===// -struct BitwiseOROperator { - template - static inline TR Operation(TA left, TB right) { - return left | right; - } -}; - -static void BitwiseOROperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseOr(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseOrFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseOROperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - 
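GetScalarIntegerBinaryFunction, defined earlier in this file, is what makes these one-line AddFunction loops work: a single templated kernel is instantiated once per integer width and the right instantiation is picked at bind time. The same idiom reduced to a standalone sketch (the TypeId enum, the function-pointer signature, and the type-erased int64_t calling convention are simplifications invented here, not DuckDB's executor API):

#include <cstdint>
#include <cstdio>
#include <stdexcept>

enum class TypeId { TINYINT, INTEGER, BIGINT };

// Type-erased binary kernel; DuckDB's real executors operate on vectors.
using binary_fn_t = int64_t (*)(int64_t, int64_t);

struct BitwiseAndOp {
	template <class T>
	static int64_t Operation(int64_t left, int64_t right) {
		// narrow to the physical width, apply the operator, widen back
		return static_cast<int64_t>(static_cast<T>(left) & static_cast<T>(right));
	}
};

// Bind-time dispatch: map the logical type to the matching instantiation.
template <class OP>
static binary_fn_t GetIntegerBinaryFunction(TypeId type) {
	switch (type) {
	case TypeId::TINYINT:
		return &OP::template Operation<int8_t>;
	case TypeId::INTEGER:
		return &OP::template Operation<int32_t>;
	case TypeId::BIGINT:
		return &OP::template Operation<int64_t>;
	default:
		throw std::runtime_error("unimplemented type for integer binary function");
	}
}

int main() {
	auto fn = GetIntegerBinaryFunction<BitwiseAndOp>(TypeId::TINYINT);
	std::printf("%lld\n", static_cast<long long>(fn(0x7A, 0x0F))); // prints 10 (0x0A)
	return 0;
}
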
-//===--------------------------------------------------------------------===// -// # [bitwise_xor] -//===--------------------------------------------------------------------===// -struct BitwiseXOROperator { - template - static inline TR Operation(TA left, TB right) { - return left ^ right; - } -}; - -static void BitwiseXOROperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t rhs, string_t lhs) { - string_t target = StringVector::EmptyString(result, rhs.GetSize()); - - Bit::BitwiseXor(rhs, lhs, target); - return target; - }); -} - -ScalarFunctionSet BitwiseXorFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT, LogicalType::BIT}, LogicalType::BIT, BitwiseXOROperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -//===--------------------------------------------------------------------===// -// ~ [bitwise_not] -//===--------------------------------------------------------------------===// -struct BitwiseNotOperator { - template - static inline TR Operation(TA input) { - return ~input; - } -}; - -static void BitwiseNOTOperation(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { - string_t target = StringVector::EmptyString(result, input.GetSize()); - - Bit::BitwiseNot(input, target); - return target; - }); -} - -ScalarFunctionSet BitwiseNotFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction(ScalarFunction({type}, type, GetScalarIntegerUnaryFunction(type))); - } - functions.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIT, BitwiseNOTOperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -//===--------------------------------------------------------------------===// -// << [bitwise_left_shift] -//===--------------------------------------------------------------------===// -struct BitwiseShiftLeftOperator { - template - static inline TR Operation(TA input, TB shift) { - TA max_shift = TA(sizeof(TA) * 8) + (NumericLimits::IsSigned() ? 
0 : 1); - if (input < 0) { - throw OutOfRangeException("Cannot left-shift negative number %s", NumericHelper::ToString(input)); - } - if (shift < 0) { - throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); - } - if (shift >= max_shift) { - if (input == 0) { - return 0; - } - throw OutOfRangeException("Left-shift value %s is out of range", NumericHelper::ToString(shift)); - } - if (shift == 0) { - return input; - } - TA max_value = UnsafeNumericCast((TA(1) << (max_shift - shift - 1))); - if (input >= max_value) { - throw OutOfRangeException("Overflow in left shift (%s << %s)", NumericHelper::ToString(input), - NumericHelper::ToString(shift)); - } - return UnsafeNumericCast(input << shift); - } -}; - -static void BitwiseShiftLeftOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { - auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); - if (shift == 0) { - return input; - } - if (shift < 0) { - throw OutOfRangeException("Cannot left-shift by negative number %s", NumericHelper::ToString(shift)); - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - - if (shift >= max_shift) { - Bit::SetEmptyBitString(target, input); - return target; - } - Bit::LeftShift(input, UnsafeNumericCast(shift), target); - return target; - }); -} - -ScalarFunctionSet LeftShiftFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction( - ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftLeftOperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -//===--------------------------------------------------------------------===// -// >> [bitwise_right_shift] -//===--------------------------------------------------------------------===// -template -bool RightShiftInRange(T shift) { - return shift >= 0 && shift < T(sizeof(T) * 8); -} - -struct BitwiseShiftRightOperator { - template - static inline TR Operation(TA input, TB shift) { - return RightShiftInRange(shift) ? 
input >> shift : 0; - } -}; - -static void BitwiseShiftRightOperation(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t shift) { - auto max_shift = UnsafeNumericCast(Bit::BitLength(input)); - if (shift == 0) { - return input; - } - string_t target = StringVector::EmptyString(result, input.GetSize()); - if (shift < 0 || shift >= max_shift) { - Bit::SetEmptyBitString(target, input); - return target; - } - Bit::RightShift(input, UnsafeNumericCast(shift), target); - return target; - }); -} - -ScalarFunctionSet RightShiftFun::GetFunctions() { - ScalarFunctionSet functions; - for (auto &type : LogicalType::Integral()) { - functions.AddFunction( - ScalarFunction({type, type}, type, GetScalarIntegerBinaryFunction(type))); - } - functions.AddFunction( - ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitwiseShiftRightOperation)); - for (auto &function : functions.functions) { - BaseScalarFunction::SetReturnsError(function); - } - return functions; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/random/random.cpp b/src/duckdb/extension/core_functions/scalar/random/random.cpp deleted file mode 100644 index 3054170ff..000000000 --- a/src/duckdb/extension/core_functions/scalar/random/random.cpp +++ /dev/null @@ -1,64 +0,0 @@ -#include "core_functions/scalar/random_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/random_engine.hpp" -#include "duckdb/common/types/uuid.hpp" - -namespace duckdb { - -struct RandomLocalState : public FunctionLocalState { - explicit RandomLocalState(uint64_t seed) : random_engine(0) { - random_engine.SetSeed(seed); - } - - RandomEngine random_engine; -}; - -static void RandomFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - result_data[i] = lstate.random_engine.NextRandom(); - } -} - -static unique_ptr RandomInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - auto &random_engine = RandomEngine::Get(state.GetContext()); - lock_guard guard(random_engine.lock); - return make_uniq(random_engine.NextRandomInteger64()); -} - -ScalarFunction RandomFun::GetFunction() { - ScalarFunction random("random", {}, LogicalType::DOUBLE, RandomFunction, nullptr, nullptr, nullptr, - RandomInitLocalState); - random.stability = FunctionStability::VOLATILE; - return random; -} - -static void GenerateUUIDFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 0); - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < args.size(); i++) { - result_data[i] = UUID::GenerateRandomUUID(lstate.random_engine); - } -} - -ScalarFunction UUIDFun::GetFunction() { - ScalarFunction uuid_function({}, LogicalType::UUID, GenerateUUIDFunction, nullptr, nullptr, nullptr, - RandomInitLocalState); - // generate a random 
uuid - uuid_function.stability = FunctionStability::VOLATILE; - return uuid_function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp b/src/duckdb/extension/core_functions/scalar/random/setseed.cpp deleted file mode 100644 index ca2865284..000000000 --- a/src/duckdb/extension/core_functions/scalar/random/setseed.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "core_functions/scalar/random_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/random_engine.hpp" - -namespace duckdb { - -struct SetseedBindData : public FunctionData { - //! The client context for the function call - ClientContext &context; - - explicit SetseedBindData(ClientContext &context) : context(context) { - } - - unique_ptr Copy() const override { - return make_uniq(context); - } - - bool Equals(const FunctionData &other_p) const override { - return true; - } -}; - -static void SetSeedFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - auto &input = args.data[0]; - input.Flatten(args.size()); - - auto input_seeds = FlatVector::GetData(input); - uint32_t half_max = NumericLimits::Maximum() / 2; - - auto &random_engine = RandomEngine::Get(info.context); - for (idx_t i = 0; i < args.size(); i++) { - if (input_seeds[i] < -1.0 || input_seeds[i] > 1.0 || Value::IsNan(input_seeds[i])) { - throw InvalidInputException("SETSEED accepts seed values between -1.0 and 1.0, inclusive"); - } - auto norm_seed = LossyNumericCast((input_seeds[i] + 1.0) * half_max); - random_engine.SetSeed(norm_seed); - } - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); -} - -unique_ptr SetSeedBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return make_uniq(context); -} - -ScalarFunction SetseedFun::GetFunction() { - ScalarFunction setseed("setseed", {LogicalType::DOUBLE}, LogicalType::SQLNULL, SetSeedFunction, SetSeedBind); - setseed.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(setseed); - return setseed; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/ascii.cpp b/src/duckdb/extension/core_functions/scalar/string/ascii.cpp deleted file mode 100644 index 4083c85de..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/ascii.cpp +++ /dev/null @@ -1,24 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "utf8proc.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -struct AsciiOperator { - template - static inline TR Operation(const TA &input) { - auto str = input.GetData(); - if (Utf8Proc::Analyze(str, input.GetSize()) == UnicodeType::ASCII) { - return str[0]; - } - int utf8_bytes = 4; - return Utf8Proc::UTF8ToCodepoint(str, utf8_bytes); - } -}; - -ScalarFunction ASCIIFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER, - ScalarFunction::UnaryFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/bar.cpp b/src/duckdb/extension/core_functions/scalar/string/bar.cpp deleted file mode 100644 index 957b8c624..000000000 --- 
a/src/duckdb/extension/core_functions/scalar/string/bar.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/unicode_bar.hpp" -#include "duckdb/common/vector_operations/generic_executor.hpp" - -namespace duckdb { - -static string_t BarScalarFunction(double x, double min, double max, double max_width, string &result) { - static const char *FULL_BLOCK = UnicodeBar::FullBlock(); - static const char *const *PARTIAL_BLOCKS = UnicodeBar::PartialBlocks(); - static const idx_t PARTIAL_BLOCKS_COUNT = UnicodeBar::PartialBlocksCount(); - - if (!Value::IsFinite(max_width)) { - throw OutOfRangeException("Max bar width must not be NaN or infinity"); - } - if (max_width < 1) { - throw OutOfRangeException("Max bar width must be >= 1"); - } - if (max_width > 1000) { - throw OutOfRangeException("Max bar width must be <= 1000"); - } - - double width; - - if (Value::IsNan(x) || Value::IsNan(min) || Value::IsNan(max) || x <= min) { - width = 0; - } else if (x >= max) { - width = max_width; - } else { - width = max_width * (x - min) / (max - min); - } - - if (!Value::IsFinite(width)) { - throw OutOfRangeException("Bar width must not be NaN or infinity"); - } - - result.clear(); - idx_t used_blocks = 0; - - auto width_as_int = LossyNumericCast(width * PARTIAL_BLOCKS_COUNT); - idx_t full_blocks_count = (width_as_int / PARTIAL_BLOCKS_COUNT); - for (idx_t i = 0; i < full_blocks_count; i++) { - used_blocks++; - result += FULL_BLOCK; - } - - idx_t remaining = width_as_int % PARTIAL_BLOCKS_COUNT; - - if (remaining) { - used_blocks++; - result += PARTIAL_BLOCKS[remaining]; - } - - const idx_t integer_max_width = (idx_t)max_width; - if (used_blocks < integer_max_width) { - result += std::string(integer_max_width - used_blocks, ' '); - } - return string_t(result); -} - -static void BarFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 3 || args.ColumnCount() == 4); - auto &x_arg = args.data[0]; - auto &min_arg = args.data[1]; - auto &max_arg = args.data[2]; - string buffer; - - if (args.ColumnCount() == 3) { - GenericExecutor::ExecuteTernary, PrimitiveType, PrimitiveType, - PrimitiveType>( - x_arg, min_arg, max_arg, result, args.size(), - [&](PrimitiveType x, PrimitiveType min, PrimitiveType max) { - return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, 80, buffer)); - }); - } else { - auto &width_arg = args.data[3]; - GenericExecutor::ExecuteQuaternary, PrimitiveType, PrimitiveType, - PrimitiveType, PrimitiveType>( - x_arg, min_arg, max_arg, width_arg, result, args.size(), - [&](PrimitiveType x, PrimitiveType min, PrimitiveType max, - PrimitiveType width) { - return StringVector::AddString(result, BarScalarFunction(x.val, min.val, max.val, width.val, buffer)); - }); - } -} - -ScalarFunctionSet BarFun::GetFunctions() { - ScalarFunctionSet bar; - bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, - LogicalType::VARCHAR, BarFunction)); - bar.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE, LogicalType::DOUBLE}, - LogicalType::VARCHAR, BarFunction)); - return bar; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/chr.cpp 
b/src/duckdb/extension/core_functions/scalar/string/chr.cpp deleted file mode 100644 index bca2de6d8..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/chr.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "utf8proc.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -struct ChrOperator { - static void GetCodepoint(int32_t input, char c[], int &utf8_bytes) { - if (input < 0 || !Utf8Proc::CodepointToUtf8(input, utf8_bytes, &c[0])) { - throw InvalidInputException("Invalid UTF8 Codepoint %d", input); - } - } - - template - static inline TR Operation(const TA &input) { - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - int utf8_bytes; - GetCodepoint(input, c, utf8_bytes); - return string_t(&c[0], UnsafeNumericCast(utf8_bytes)); - } -}; - -#ifdef DUCKDB_DEBUG_NO_INLINE -// the chr function depends on the data always being inlined (which is always possible, since it outputs max 4 bytes) -// to enable chr when string inlining is disabled we create a special function here -static void ChrFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &code_vec = args.data[0]; - - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - int utf8_bytes; - UnaryExecutor::Execute(code_vec, result, args.size(), [&](int32_t input) { - ChrOperator::GetCodepoint(input, c, utf8_bytes); - return StringVector::AddString(result, &c[0], UnsafeNumericCast(utf8_bytes)); - }); -} -#endif - -ScalarFunction ChrFun::GetFunction() { - return ScalarFunction("chr", {LogicalType::INTEGER}, LogicalType::VARCHAR, -#ifdef DUCKDB_DEBUG_NO_INLINE - ChrFunction -#else - ScalarFunction::UnaryFunction -#endif - ); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp b/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp deleted file mode 100644 index 91b0fbd33..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/damerau_levenshtein.cpp +++ /dev/null @@ -1,104 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/map.hpp" -#include "duckdb/common/vector.hpp" - -namespace duckdb { - -// Using Lowrance-Wagner (LW) algorithm: https://doi.org/10.1145%2F321879.321880 -// Can't calculate as trivial modification to levenshtein algorithm -// as we need to potentially know about earlier in the string -static idx_t DamerauLevenshteinDistance(const string_t &source, const string_t &target) { - // costs associated with each type of edit, to aid readability - constexpr uint8_t COST_SUBSTITUTION = 1; - constexpr uint8_t COST_INSERTION = 1; - constexpr uint8_t COST_DELETION = 1; - constexpr uint8_t COST_TRANSPOSITION = 1; - const auto source_len = source.GetSize(); - const auto target_len = target.GetSize(); - - // If one string is empty, the distance equals the length of the other string - // either through target_len insertions - // or source_len deletions - if (source_len == 0) { - return target_len * COST_INSERTION; - } else if (target_len == 0) { - return source_len * COST_DELETION; - } - - const auto source_str = source.GetData(); - const auto target_str = target.GetData(); - - // larger than the largest possible value: - const auto inf = source_len * COST_DELETION + target_len * COST_INSERTION + 1; - // minimum edit distance from prefix of source string to prefix of target string - // same object as H in LW paper (with indices offset by 1) - vector> distance(source_len + 2, vector(target_len + 2, inf)); - // keeps track of the largest 
string indices of source string matching each character - // same as DA in LW paper - map largest_source_chr_matching; - - // initialise row/column corresponding to zero-length strings - // partial string -> empty requires a deletion for each character - for (idx_t source_idx = 0; source_idx <= source_len; source_idx++) { - distance[source_idx + 1][1] = source_idx * COST_DELETION; - } - // and empty -> partial string means simply inserting characters - for (idx_t target_idx = 1; target_idx <= target_len; target_idx++) { - distance[1][target_idx + 1] = target_idx * COST_INSERTION; - } - // loop through string indices - these are offset by 2 from distance indices - for (idx_t source_idx = 0; source_idx < source_len; source_idx++) { - // keeps track of the largest string indices of target string matching current source character - // same as DB in LW paper - idx_t largest_target_chr_matching; - largest_target_chr_matching = 0; - for (idx_t target_idx = 0; target_idx < target_len; target_idx++) { - // correspond to i1 and j1 in LW paper respectively - idx_t largest_source_chr_matching_target; - idx_t largest_target_chr_matching_source; - // cost associated with diagonal shift in distance matrix - // corresponds to d in LW paper - uint8_t cost_diagonal_shift; - largest_source_chr_matching_target = largest_source_chr_matching[target_str[target_idx]]; - largest_target_chr_matching_source = largest_target_chr_matching; - // if characters match, diagonal move costs nothing and we update our largest target index - // otherwise move is substitution and costs as such - if (source_str[source_idx] == target_str[target_idx]) { - cost_diagonal_shift = 0; - largest_target_chr_matching = target_idx + 1; - } else { - cost_diagonal_shift = COST_SUBSTITUTION; - } - distance[source_idx + 2][target_idx + 2] = MinValue( - distance[source_idx + 1][target_idx + 1] + cost_diagonal_shift, - MinValue(distance[source_idx + 2][target_idx + 1] + COST_INSERTION, - MinValue(distance[source_idx + 1][target_idx + 2] + COST_DELETION, - distance[largest_source_chr_matching_target][largest_target_chr_matching_source] + - (source_idx - largest_source_chr_matching_target) * COST_DELETION + - COST_TRANSPOSITION + - (target_idx - largest_target_chr_matching_source) * COST_INSERTION))); - } - largest_source_chr_matching[source_str[source_idx]] = source_idx + 1; - } - return distance[source_len + 1][target_len + 1]; -} - -static int64_t DamerauLevenshteinScalarFunction(Vector &result, const string_t source, const string_t target) { - return (int64_t)DamerauLevenshteinDistance(source, target); -} - -static void DamerauLevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &source_vec = args.data[0]; - auto &target_vec = args.data[1]; - - BinaryExecutor::Execute( - source_vec, target_vec, result, args.size(), - [&](string_t source, string_t target) { return DamerauLevenshteinScalarFunction(result, source, target); }); -} - -ScalarFunction DamerauLevenshteinFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, - DamerauLevenshteinFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp b/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp deleted file mode 100644 index 46db22f25..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/format_bytes.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include
"duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -template -static void FormatBytesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](int64_t bytes) { - bool is_negative = bytes < 0; - idx_t unsigned_bytes; - if (bytes < 0) { - if (bytes == NumericLimits::Minimum()) { - unsigned_bytes = idx_t(NumericLimits::Maximum()) + 1; - } else { - unsigned_bytes = idx_t(-bytes); - } - } else { - unsigned_bytes = idx_t(bytes); - } - return StringVector::AddString(result, (is_negative ? "-" : "") + - StringUtil::BytesToHumanReadableString(unsigned_bytes, MULTIPLIER)); - }); -} - -ScalarFunction FormatBytesFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction<1024>); -} - -ScalarFunction FormatreadabledecimalsizeFun::GetFunction() { - return ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, FormatBytesFunction<1000>); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/hamming.cpp b/src/duckdb/extension/core_functions/scalar/string/hamming.cpp deleted file mode 100644 index b32a80199..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/hamming.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -#include -#include - -namespace duckdb { - -static int64_t MismatchesScalarFunction(Vector &result, const string_t str, string_t tgt) { - idx_t str_len = str.GetSize(); - idx_t tgt_len = tgt.GetSize(); - - if (str_len != tgt_len) { - throw InvalidInputException("Mismatch Function: Strings must be of equal length!"); - } - if (str_len < 1) { - throw InvalidInputException("Mismatch Function: Strings must be of length > 0!"); - } - - idx_t mismatches = 0; - auto str_str = str.GetData(); - auto tgt_str = tgt.GetData(); - - for (idx_t idx = 0; idx < str_len; ++idx) { - if (str_str[idx] != tgt_str[idx]) { - mismatches++; - } - } - return (int64_t)mismatches; -} - -static void MismatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return MismatchesScalarFunction(result, str, tgt); }); -} - -ScalarFunction HammingFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, MismatchesFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/hex.cpp b/src/duckdb/extension/core_functions/scalar/string/hex.cpp deleted file mode 100644 index cbf541e1b..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/hex.cpp +++ /dev/null @@ -1,440 +0,0 @@ -#include "duckdb/common/bit_utils.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "core_functions/scalar/string_functions.hpp" - -namespace duckdb { - -static void WriteHexBytes(uint64_t x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size * 4; - - for (; offset >= 4; offset -= 4) { - uint8_t byte 
= (x >> (offset - 4)) & 0x0F; - *output = Blob::HEX_TABLE[byte]; - output++; - } -} - -template -static void WriteHugeIntHexBytes(T x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size * 4; - auto upper = x.upper; - auto lower = x.lower; - - for (; offset >= 68; offset -= 4) { - uint8_t byte = (upper >> (offset - 68)) & 0x0F; - *output = Blob::HEX_TABLE[byte]; - output++; - } - - for (; offset >= 4; offset -= 4) { - uint8_t byte = (lower >> (offset - 4)) & 0x0F; - *output = Blob::HEX_TABLE[byte]; - output++; - } -} - -static void WriteBinBytes(uint64_t x, char *&output, idx_t buffer_size) { - idx_t offset = buffer_size; - for (; offset >= 1; offset -= 1) { - *output = NumericCast(((x >> (offset - 1)) & 0x01) + '0'); - output++; - } -} - -template -static void WriteHugeIntBinBytes(T x, char *&output, idx_t buffer_size) { - auto upper = x.upper; - auto lower = x.lower; - idx_t offset = buffer_size; - - for (; offset >= 65; offset -= 1) { - *output = ((upper >> (offset - 65)) & 0x01) + '0'; - output++; - } - - for (; offset >= 1; offset -= 1) { - *output = ((lower >> (offset - 1)) & 0x01) + '0'; - output++; - } -} - -struct HexStrOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - // Allocate empty space - auto target = StringVector::EmptyString(result, size * 2); - auto output = target.GetDataWriteable(); - - for (idx_t i = 0; i < size; ++i) { - *output = Blob::HEX_TABLE[(data[i] >> 4) & 0x0F]; - output++; - *output = Blob::HEX_TABLE[data[i] & 0x0F]; - output++; - } - - target.Finalize(); - return target; - } -}; - -struct HexIntegralOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - auto num_leading_zero = CountZeros::Leading(static_cast(input)); - idx_t num_bits_to_check = 64 - num_leading_zero; - D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); - - idx_t buffer_size = (num_bits_to_check + 3) / 4; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHexBytes(static_cast(input), output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct HexHugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); - idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntHexBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct HexUhugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - idx_t num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); - idx_t buffer_size = sizeof(INPUT_TYPE) * 2 - (num_leading_zero / 4); - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); 
- auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntHexBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -template -static void ToHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); -} - -struct BinaryStrOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - // Allocate empty space - auto target = StringVector::EmptyString(result, size * 8); - auto output = target.GetDataWriteable(); - - for (idx_t i = 0; i < size; ++i) { - auto byte = static_cast(data[i]); - for (idx_t i = 8; i >= 1; --i) { - *output = ((byte >> (i - 1)) & 0x01) + '0'; - output++; - } - } - - target.Finalize(); - return target; - } -}; - -struct BinaryIntegralOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - - auto num_leading_zero = CountZeros::Leading(static_cast(input)); - idx_t num_bits_to_check = 64 - num_leading_zero; - D_ASSERT(num_bits_to_check <= sizeof(INPUT_TYPE) * 8); - - idx_t buffer_size = num_bits_to_check; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - D_ASSERT(buffer_size > 0); - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteBinBytes(static_cast(input), output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct BinaryHugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); - idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntBinBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct BinaryUhugeIntOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto num_leading_zero = CountZeros::Leading(UnsafeNumericCast(input)); - idx_t buffer_size = sizeof(INPUT_TYPE) * 8 - num_leading_zero; - - // Special case: All bits are zero - if (buffer_size == 0) { - auto target = StringVector::EmptyString(result, 1); - auto output = target.GetDataWriteable(); - *output = '0'; - target.Finalize(); - return target; - } - - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - WriteHugeIntBinBytes(input, output, buffer_size); - - target.Finalize(); - return target; - } -}; - -struct FromHexOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - if (size > NumericLimits::Maximum()) { - throw InvalidInputException("Hexadecimal input length 
larger than 2^32 is not supported"); - } - - D_ASSERT(size <= NumericLimits::Maximum()); - auto buffer_size = (size + 1) / 2; - - // Allocate empty space - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - // Odd number of hex digits: the leading digit maps to its own byte - idx_t i = 0; - if (size % 2 != 0) { - *output = static_cast(StringUtil::GetHexValue(data[i])); - i++; - output++; - } - - for (; i < size; i += 2) { - uint8_t major = StringUtil::GetHexValue(data[i]); - uint8_t minor = StringUtil::GetHexValue(data[i + 1]); - *output = static_cast((major << 4) | minor); - output++; - } - - target.Finalize(); - return target; - } -}; - -struct FromBinaryOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - if (size > NumericLimits::Maximum()) { - throw InvalidInputException("Binary input length larger than 2^32 is not supported"); - } - - D_ASSERT(size <= NumericLimits::Maximum()); - auto buffer_size = (size + 7) / 8; - - // Allocate empty space - auto target = StringVector::EmptyString(result, buffer_size); - auto output = target.GetDataWriteable(); - - // Leading bits that do not fill a full byte are packed into a single byte - idx_t i = 0; - if (size % 8 != 0) { - uint8_t byte = 0; - for (idx_t j = size % 8; j > 0; --j) { - byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); - i++; - } - *output = static_cast(byte); // write the partially-filled leading byte - output++; - } - - while (i < size) { - uint8_t byte = 0; - for (idx_t j = 8; j > 0; --j) { - byte |= StringUtil::GetBinaryValue(data[i]) << (j - 1); - i++; - } - *output = static_cast(byte); - output++; - } - - target.Finalize(); - return target; - } -}; - -template -static void ToBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - auto &input = args.data[0]; - idx_t count = args.size(); - UnaryExecutor::ExecuteString(input, result, count); -} - -static void FromBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); -} - -static void FromHexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - D_ASSERT(args.data[0].GetType().InternalType() == PhysicalType::VARCHAR); - auto &input = args.data[0]; - idx_t count = args.size(); - - UnaryExecutor::ExecuteString(input, result, count); -} - -ScalarFunctionSet HexFun::GetFunctions() { - ScalarFunctionSet to_hex; - to_hex.AddFunction( - ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, ToHexFunction)); - to_hex.AddFunction( - ScalarFunction({LogicalType::UHUGEINT}, LogicalType::VARCHAR, ToHexFunction)); - return to_hex; -} - -ScalarFunction UnhexFun::GetFunction() { - ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromHexFunction); -
BaseScalarFunction::SetReturnsError(function); - return function; -} - -ScalarFunctionSet BinFun::GetFunctions() { - ScalarFunctionSet to_binary; - - to_binary.AddFunction( - ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, ToBinaryFunction)); - to_binary.AddFunction( - ScalarFunction({LogicalType::VARINT}, LogicalType::VARCHAR, ToBinaryFunction)); - to_binary.AddFunction(ScalarFunction({LogicalType::UBIGINT}, LogicalType::VARCHAR, - ToBinaryFunction)); - to_binary.AddFunction( - ScalarFunction({LogicalType::BIGINT}, LogicalType::VARCHAR, ToBinaryFunction)); - to_binary.AddFunction(ScalarFunction({LogicalType::HUGEINT}, LogicalType::VARCHAR, - ToBinaryFunction)); - to_binary.AddFunction(ScalarFunction({LogicalType::UHUGEINT}, LogicalType::VARCHAR, - ToBinaryFunction)); - return to_binary; -} - -ScalarFunction UnbinFun::GetFunction() { - ScalarFunction function({LogicalType::VARCHAR}, LogicalType::BLOB, FromBinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/instr.cpp b/src/duckdb/extension/core_functions/scalar/string/instr.cpp deleted file mode 100644 index 77539e7c0..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/instr.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "utf8proc.hpp" - -namespace duckdb { - -struct InstrOperator { - template - static inline TR Operation(TA haystack, TB needle) { - int64_t string_position = 0; - - auto location = FindStrInStr(haystack, needle); - if (location != DConstants::INVALID_INDEX) { - auto len = (utf8proc_ssize_t)location; - auto str = reinterpret_cast(haystack.GetData()); - D_ASSERT(len <= (utf8proc_ssize_t)haystack.GetSize()); - for (++string_position; len > 0; ++string_position) { - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str, len, &codepoint); - str += bytes; - len -= bytes; - } - } - return string_position; - } -}; - -struct InstrAsciiOperator { - template - static inline TR Operation(TA haystack, TB needle) { - auto location = FindStrInStr(haystack, needle); - return UnsafeNumericCast(location == DConstants::INVALID_INDEX ? 
0U : location + 1U); - } -}; - -static unique_ptr InStrPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 2); - // can only propagate stats if the children have stats - // for strpos, we only care if the FIRST string has unicode or not - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::BinaryFunction; - } - return nullptr; -} - -ScalarFunction InstrFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction, nullptr, nullptr, - InStrPropagateStats); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp b/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp deleted file mode 100644 index eae31dc9c..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/jaccard.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "duckdb/common/map.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "core_functions/scalar/string_functions.hpp" - -#include -#include - -namespace duckdb { - -namespace { -constexpr size_t MAX_SIZE = std::numeric_limits::max() + 1; -} - -static inline std::bitset GetSet(const string_t &str) { - std::bitset array_set; - - idx_t str_len = str.GetSize(); - auto s = str.GetData(); - - for (idx_t pos = 0; pos < str_len; pos++) { - array_set.set(static_cast(s[pos])); - } - return array_set; -} - -static double JaccardSimilarity(const string_t &str, const string_t &txt) { - if (str.GetSize() < 1 || txt.GetSize() < 1) { - throw InvalidInputException("Jaccard Function: An argument too short!"); - } - std::bitset m_str, m_txt; - - m_str = GetSet(str); - m_txt = GetSet(txt); - - idx_t size_intersect = (m_str & m_txt).count(); - idx_t size_union = (m_str | m_txt).count(); - - return static_cast(size_intersect) / static_cast(size_union); -} - -static double JaccardScalarFunction(Vector &result, const string_t str, string_t tgt) { - return (double)JaccardSimilarity(str, tgt); -} - -static void JaccardFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return JaccardScalarFunction(result, str, tgt); }); -} - -ScalarFunction JaccardFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaccardFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp b/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp deleted file mode 100644 index 13db07c79..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/jaro_winkler.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "jaro_winkler.hpp" - -#include "core_functions/scalar/string_functions.hpp" - -namespace duckdb { - -static inline double JaroScalarFunction(const string_t &s1, const string_t &s2, const double_t &score_cutoff = 0.0) { - auto s1_begin = s1.GetData(); - auto s2_begin = s2.GetData(); - return duckdb_jaro_winkler::jaro_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, s2_begin + s2.GetSize(), - score_cutoff); -} - -static inline double JaroWinklerScalarFunction(const string_t &s1, const string_t &s2, - const double_t &score_cutoff = 0.0) { - auto s1_begin = s1.GetData(); - 
auto s2_begin = s2.GetData(); - return duckdb_jaro_winkler::jaro_winkler_similarity(s1_begin, s1_begin + s1.GetSize(), s2_begin, - s2_begin + s2.GetSize(), 0.1, score_cutoff); -} - -template -static void CachedFunction(Vector &constant, Vector &other, Vector &result, DataChunk &args) { - auto val = constant.GetValue(0); - idx_t count = args.size(); - if (val.IsNull()) { - auto &result_validity = FlatVector::Validity(result); - result_validity.SetAllInvalid(count); - return; - } - - auto str_val = StringValue::Get(val); - auto cached = CACHED_SIMILARITY(str_val); - - D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3); - if (args.ColumnCount() == 2) { - UnaryExecutor::Execute(other, result, count, [&](const string_t &other_str) { - auto other_str_begin = other_str.GetData(); - return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize()); - }); - } else { - auto score_cutoff = args.data[2]; - BinaryExecutor::Execute( - other, score_cutoff, result, count, [&](const string_t &other_str, const double_t score_cutoff) { - auto other_str_begin = other_str.GetData(); - return cached.similarity(other_str_begin, other_str_begin + other_str.GetSize(), score_cutoff); - }); - } -} - -template -static void TemplatedJaroWinklerFunction(DataChunk &args, Vector &result, SIMILARITY_FUNCTION fun) { - bool arg0_constant = args.data[0].GetVectorType() == VectorType::CONSTANT_VECTOR; - bool arg1_constant = args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR; - if (!(arg0_constant ^ arg1_constant)) { - // We can't optimize by caching one of the two strings - D_ASSERT(args.ColumnCount() == 2 || args.ColumnCount() == 3); - if (args.ColumnCount() == 2) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), - [&](const string_t &s1, const string_t &s2) { return fun(s1, s2, 0.0); }); - return; - } else { - TernaryExecutor::Execute(args.data[0], args.data[1], args.data[2], - result, args.size(), fun); - return; - } - } - - if (arg0_constant) { - CachedFunction(args.data[0], args.data[1], result, args); - } else { - CachedFunction(args.data[1], args.data[0], result, args); - } -} - -static void JaroFunction(DataChunk &args, ExpressionState &state, Vector &result) { - TemplatedJaroWinklerFunction>(args, result, JaroScalarFunction); -} - -static void JaroWinklerFunction(DataChunk &args, ExpressionState &state, Vector &result) { - TemplatedJaroWinklerFunction>(args, result, - JaroWinklerScalarFunction); -} - -ScalarFunctionSet JaroSimilarityFun::GetFunctions() { - ScalarFunctionSet jaro; - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroFunction); - jaro.AddFunction(fun); - - fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE, - JaroFunction); - jaro.AddFunction(fun); - return jaro; -} - -ScalarFunctionSet JaroWinklerSimilarityFun::GetFunctions() { - ScalarFunctionSet jaroWinkler; - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::DOUBLE, JaroWinklerFunction); - jaroWinkler.AddFunction(fun); - - fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::DOUBLE}, LogicalType::DOUBLE, - JaroWinklerFunction); - jaroWinkler.AddFunction(fun); - return jaroWinkler; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/left_right.cpp 
b/src/duckdb/extension/core_functions/scalar/string/left_right.cpp deleted file mode 100644 index b13ff9560..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/left_right.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/function/scalar/string_common.hpp" - -#include -#include - -namespace duckdb { - -struct LeftRightUnicode { - template - static inline TR Operation(TA input) { - return Length(input); - } - - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringUnicode(result, input, offset, length); - } -}; - -struct LeftRightGrapheme { - template - static inline TR Operation(TA input) { - return GraphemeCount(input); - } - - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringGrapheme(result, input, offset, length); - } -}; - -template -static string_t LeftScalarFunction(Vector &result, const string_t str, int64_t pos) { - if (pos >= 0) { - return OP::Substring(result, str, 1, pos); - } - - int64_t num_characters = OP::template Operation(str); - pos = MaxValue(0, num_characters + pos); - return OP::Substring(result, str, 1, pos); -} - -template -static void LeftFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return LeftScalarFunction(result, str, pos); }); -} - -ScalarFunction LeftFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - LeftFunction); -} - -ScalarFunction LeftGraphemeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - LeftFunction); -} - -template -static string_t RightScalarFunction(Vector &result, const string_t str, int64_t pos) { - int64_t num_characters = OP::template Operation(str); - if (pos >= 0) { - int64_t len = MinValue(num_characters, pos); - int64_t start = num_characters - len + 1; - return OP::Substring(result, str, start, len); - } - - int64_t len = 0; - if (pos != std::numeric_limits::min()) { - len = num_characters - MinValue(num_characters, -pos); - } - int64_t start = num_characters - len + 1; - return OP::Substring(result, str, start, len); -} - -template -static void RightFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &pos_vec = args.data[1]; - BinaryExecutor::Execute( - str_vec, pos_vec, result, args.size(), - [&](string_t str, int64_t pos) { return RightScalarFunction(result, str, pos); }); -} - -ScalarFunction RightFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - RightFunction); -} - -ScalarFunction RightGraphemeFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - RightFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp b/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp deleted file mode 100644 index 24e28b89c..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/levenshtein.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include 
"duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/string_util.hpp" - -#include -#include - -namespace duckdb { - -// See: https://www.kdnuggets.com/2020/10/optimizing-levenshtein-distance-measuring-text-similarity.html -// And: Iterative 2-row algorithm: https://en.wikipedia.org/wiki/Levenshtein_distance -// Note: A first implementation using the array algorithm version resulted in an error raised by duckdb -// (too muach memory usage) - -static idx_t LevenshteinDistance(const string_t &txt, const string_t &tgt) { - auto txt_len = txt.GetSize(); - auto tgt_len = tgt.GetSize(); - - // If one string is empty, the distance equals the length of the other string - if (txt_len == 0) { - return tgt_len; - } else if (tgt_len == 0) { - return txt_len; - } - - auto txt_str = txt.GetData(); - auto tgt_str = tgt.GetData(); - - // Create two working vectors - vector distances0(tgt_len + 1, 0); - vector distances1(tgt_len + 1, 0); - - idx_t cost_substitution = 0; - idx_t cost_insertion = 0; - idx_t cost_deletion = 0; - - // initialize distances0 vector - // edit distance for an empty txt string is just the number of characters to delete from tgt - for (idx_t pos_tgt = 0; pos_tgt <= tgt_len; pos_tgt++) { - distances0[pos_tgt] = pos_tgt; - } - - for (idx_t pos_txt = 0; pos_txt < txt_len; pos_txt++) { - // calculate distances1 (current raw distances) from the previous row - - distances1[0] = pos_txt + 1; - - for (idx_t pos_tgt = 0; pos_tgt < tgt_len; pos_tgt++) { - cost_deletion = distances0[pos_tgt + 1] + 1; - cost_insertion = distances1[pos_tgt] + 1; - cost_substitution = distances0[pos_tgt]; - - if (txt_str[pos_txt] != tgt_str[pos_tgt]) { - cost_substitution += 1; - } - - distances1[pos_tgt + 1] = MinValue(cost_deletion, MinValue(cost_substitution, cost_insertion)); - } - // copy distances1 (current row) to distances0 (previous row) for next iteration - // since data in distances1 is always invalidated, a swap without copy is more efficient - distances0 = distances1; - } - - return distances0[tgt_len]; -} - -static int64_t LevenshteinScalarFunction(Vector &result, const string_t str, string_t tgt) { - return (int64_t)LevenshteinDistance(str, tgt); -} - -static void LevenshteinFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vec = args.data[0]; - auto &tgt_vec = args.data[1]; - - BinaryExecutor::Execute( - str_vec, tgt_vec, result, args.size(), - [&](string_t str, string_t tgt) { return LevenshteinScalarFunction(result, str, tgt); }); -} - -ScalarFunction LevenshteinFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BIGINT, LevenshteinFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/pad.cpp b/src/duckdb/extension/core_functions/scalar/string/pad.cpp deleted file mode 100644 index 586e1605a..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/pad.cpp +++ /dev/null @@ -1,147 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "duckdb/common/pair.hpp" - -#include "utf8proc.hpp" - -namespace duckdb { - -static pair PadCountChars(const idx_t len, const char *data, const idx_t size) { - // Count how much of str will fit in the output - auto str = reinterpret_cast(data); - idx_t nbytes = 0; - idx_t 
nchars = 0; - for (; nchars < len && nbytes < size; ++nchars) { - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); - D_ASSERT(bytes > 0); - nbytes += UnsafeNumericCast(bytes); - } - - return pair(nbytes, nchars); -} - -static bool InsertPadding(const idx_t len, const string_t &pad, vector &result) { - // Copy the padding until the output is long enough - auto data = pad.GetData(); - auto size = pad.GetSize(); - - // Check whether we need data that we don't have - if (len > 0 && size == 0) { - return false; - } - - // Insert characters until we have all we need. - auto str = reinterpret_cast(data); - idx_t nbytes = 0; - for (idx_t nchars = 0; nchars < len; ++nchars) { - // If we are at the end of the pad, flush all of it and loop back - if (nbytes >= size) { - result.insert(result.end(), data, data + size); - nbytes = 0; - } - - // Write the next character - utf8proc_int32_t codepoint; - auto bytes = utf8proc_iterate(str + nbytes, UnsafeNumericCast(size - nbytes), &codepoint); - D_ASSERT(bytes > 0); - nbytes += UnsafeNumericCast(bytes); - } - - // Flush the remaining pad - result.insert(result.end(), data, data + nbytes); - - return true; -} - -static string_t LeftPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { - // Reuse the buffer - result.clear(); - - // Get information about the base string - auto data_str = str.GetData(); - auto size_str = str.GetSize(); - - // Count how much of str will fit in the output - auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); - - // Left pad by the number of characters still needed - if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { - throw InvalidInputException("Insufficient padding in LPAD."); - } - - // Append as much of the original string as fits - result.insert(result.end(), data_str, data_str + written.first); - - return string_t(result.data(), UnsafeNumericCast(result.size())); -} - -struct LeftPadOperator { - static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, - vector &result) { - return LeftPadFunction(str, len, pad, result); - } -}; - -static string_t RightPadFunction(const string_t &str, const int32_t len, const string_t &pad, vector &result) { - // Reuse the buffer - result.clear(); - - // Get information about the base string - auto data_str = str.GetData(); - auto size_str = str.GetSize(); - - // Count how much of str will fit in the output - auto written = PadCountChars(UnsafeNumericCast(len), data_str, size_str); - - // Append as much of the original string as fits - result.insert(result.end(), data_str, data_str + written.first); - - // Right pad by the number of characters still needed - if (!InsertPadding(UnsafeNumericCast(len) - written.second, pad, result)) { - throw InvalidInputException("Insufficient padding in RPAD."); - }; - - return string_t(result.data(), UnsafeNumericCast(result.size())); -} - -struct RightPadOperator { - static inline string_t Operation(const string_t &str, const int32_t len, const string_t &pad, - vector &result) { - return RightPadFunction(str, len, pad, result); - } -}; - -template -static void PadFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str_vector = args.data[0]; - auto &len_vector = args.data[1]; - auto &pad_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - str_vector, len_vector, pad_vector, result, args.size(), [&](string_t str, int32_t len, 
string_t pad) { - len = MaxValue(len, 0); - return StringVector::AddString(result, OP::Operation(str, len, pad, buffer)); - }); -} - -ScalarFunction LpadFun::GetFunction() { - ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, - PadFunction); - BaseScalarFunction::SetReturnsError(func); - return func; -} - -ScalarFunction RpadFun::GetFunction() { - ScalarFunction func({LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, LogicalType::VARCHAR, - PadFunction); - BaseScalarFunction::SetReturnsError(func); - return func; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp b/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp deleted file mode 100644 index 9ed926b4a..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/parse_path.cpp +++ /dev/null @@ -1,348 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/common/local_file_system.hpp" -#include - -namespace duckdb { - -static string GetSeparator(const string_t &input) { - string option = input.GetString(); - - // system's path separator - auto fs = FileSystem::CreateLocal(); - auto system_sep = fs->PathSeparator(option); - - string separator; - if (option == "system") { - separator = system_sep; - } else if (option == "forward_slash") { - separator = "/"; - } else if (option == "backslash") { - separator = "\\"; - } else { // both_slash (default) - separator = "/\\"; - } - return separator; -} - -struct SplitInput { - SplitInput(Vector &result_list, Vector &result_child, idx_t offset) - : result_list(result_list), result_child(result_child), offset(offset) { - } - - Vector &result_list; - Vector &result_child; - idx_t offset; - - void AddSplit(const char *split_data, idx_t split_size, idx_t list_idx) { - auto list_entry = offset + list_idx; - if (list_entry >= ListVector::GetListCapacity(result_list)) { - ListVector::SetListSize(result_list, offset + list_idx); - ListVector::Reserve(result_list, ListVector::GetListCapacity(result_list) * 2); - } - FlatVector::GetData(result_child)[list_entry] = - StringVector::AddString(result_child, split_data, split_size); - } -}; - -static bool IsIdxValid(const idx_t &i, const idx_t &sentence_size) { - if (i > sentence_size || i == DConstants::INVALID_INDEX) { - return false; - } - return true; -} - -static idx_t Find(const char *input_data, idx_t input_size, const string &sep_data) { - if (sep_data.empty()) { - return 0; - } - auto pos = FindStrInStr(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(&sep_data[0]), 1); - // both_slash option - if (sep_data.size() > 1) { - auto sec_pos = - FindStrInStr(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(&sep_data[1]), 1); - // choose the leftmost valid position - if (sec_pos != DConstants::INVALID_INDEX && (sec_pos < pos || pos == DConstants::INVALID_INDEX)) { - return sec_pos; - } - } - return pos; -} - -static idx_t FindLast(const char *data_ptr, idx_t input_size, const string &sep_data) { - idx_t start = 0; - while (input_size > 0) { - auto pos = Find(data_ptr, input_size, sep_data); - if (!IsIdxValid(pos, input_size)) { - break; - } - start += (pos + 1); - data_ptr += (pos + 1); - input_size -= (pos + 1); - } - if (start < 1) { - return DConstants::INVALID_INDEX; - } - return start - 1; -} - -static idx_t SplitPath(string_t input, const string &sep, SplitInput &state) { - auto input_data = 
input.GetData(); - auto input_size = input.GetSize(); - if (!input_size) { - return 0; - } - idx_t list_idx = 0; - while (input_size > 0) { - auto pos = Find(input_data, input_size, sep); - if (!IsIdxValid(pos, input_size)) { - break; - } - - D_ASSERT(input_size >= pos); - if (pos == 0) { - if (list_idx == 0) { // first character in path is separator - state.AddSplit(input_data, 1, list_idx); - list_idx++; - if (input_size == 1) { // special case: the only character in path is a separator - return list_idx; - } - } // else: separator is in the path - } else { - state.AddSplit(input_data, pos, list_idx); - list_idx++; - } - input_data += (pos + 1); - input_size -= (pos + 1); - } - if (input_size > 0) { - state.AddSplit(input_data, input_size, list_idx); - list_idx++; - } - return list_idx; -} - -static void ReadOptionalArgs(DataChunk &args, Vector &sep, Vector &trim, const bool &front_trim) { - switch (args.ColumnCount()) { - case 1: { - // use default values - break; - } - case 2: { - UnifiedVectorFormat sec_arg; - args.data[1].ToUnifiedFormat(args.size(), sec_arg); - if (sec_arg.validity.RowIsValid(0)) { // if not NULL - switch (args.data[1].GetType().id()) { - case LogicalTypeId::VARCHAR: { - sep.Reinterpret(args.data[1]); - break; - } - case LogicalTypeId::BOOLEAN: { // parse_path and parse_dirname won't get in here - trim.Reinterpret(args.data[1]); - break; - } - default: - throw InvalidInputException("Invalid argument type"); - } - } - break; - } - case 3: { - if (!front_trim) { - // set trim_extension - UnifiedVectorFormat sec_arg; - args.data[1].ToUnifiedFormat(args.size(), sec_arg); - if (sec_arg.validity.RowIsValid(0)) { - trim.Reinterpret(args.data[1]); - } - UnifiedVectorFormat third_arg; - args.data[2].ToUnifiedFormat(args.size(), third_arg); - if (third_arg.validity.RowIsValid(0)) { - sep.Reinterpret(args.data[2]); - } - } else { - throw InvalidInputException("Invalid number of arguments"); - } - break; - } - default: - throw InvalidInputException("Invalid number of arguments"); - } -} - -template -static void TrimPathFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // set default values - Vector &path = args.data[0]; - Vector separator(string_t("default")); - Vector trim_extension(Value::BOOLEAN(false)); - ReadOptionalArgs(args, separator, trim_extension, FRONT_TRIM); - - TernaryExecutor::Execute( - path, separator, trim_extension, result, args.size(), - [&](string_t &inputs, string_t input_sep, bool trim_extension) { - auto data = inputs.GetData(); - auto input_size = inputs.GetSize(); - auto sep = GetSeparator(input_sep.GetString()); - - // find the beginning idx and the size of the result string - idx_t begin = 0; - idx_t new_size = input_size; - if (FRONT_TRIM) { // left trim - auto pos = Find(data, input_size, sep); - if (pos == 0) { // path starts with separator - pos = 1; - } - new_size = (IsIdxValid(pos, input_size)) ?
pos : 0; - } else { // right trim - auto idx_last_sep = FindLast(data, input_size, sep); - if (IsIdxValid(idx_last_sep, input_size)) { - begin = idx_last_sep + 1; - } - if (trim_extension) { - auto idx_extension_sep = FindLast(data, input_size, "."); - if (begin <= idx_extension_sep && IsIdxValid(idx_extension_sep, input_size)) { - new_size = idx_extension_sep; - } - } - } - // copy the trimmed string - D_ASSERT(begin <= new_size); - auto target = StringVector::EmptyString(result, new_size - begin); - auto output = target.GetDataWriteable(); - memcpy(output, data + begin, new_size - begin); - - target.Finalize(); - return target; - }); -} - -static void ParseDirpathFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // set default values - Vector &path = args.data[0]; - Vector separator(string_t("default")); - Vector trim_extension(false); - ReadOptionalArgs(args, separator, trim_extension, true); - - BinaryExecutor::Execute( - path, separator, result, args.size(), [&](string_t input_path, string_t input_sep) { - auto path = input_path.GetData(); - auto path_size = input_path.GetSize(); - auto sep = GetSeparator(input_sep.GetString()); - - auto last_sep = FindLast(path, path_size, sep); - if (last_sep == 0 && path_size == 1) { - last_sep = 1; - } - idx_t new_size = (IsIdxValid(last_sep, path_size)) ? last_sep : 0; - - auto target = StringVector::EmptyString(result, new_size); - auto output = target.GetDataWriteable(); - memcpy(output, path, new_size); - target.Finalize(); - return StringVector::AddString(result, target); - }); -} - -static void ParsePathFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1 || args.ColumnCount() == 2); - UnifiedVectorFormat input_data; - args.data[0].ToUnifiedFormat(args.size(), input_data); - auto inputs = UnifiedVectorFormat::GetData(input_data); - - // set the separator - string input_sep = "default"; - if (args.ColumnCount() == 2) { - UnifiedVectorFormat sep_data; - args.data[1].ToUnifiedFormat(args.size(), sep_data); - if (sep_data.validity.RowIsValid(0)) { - input_sep = UnifiedVectorFormat::GetData(sep_data)->GetString(); - } - } - const string sep = GetSeparator(input_sep); - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - result.SetVectorType(VectorType::FLAT_VECTOR); - ListVector::SetListSize(result, 0); - - // set up the list entries - auto list_data = FlatVector::GetData(result); - auto &child_entry = ListVector::GetEntry(result); - auto &result_mask = FlatVector::Validity(result); - idx_t total_splits = 0; - for (idx_t i = 0; i < args.size(); i++) { - auto input_idx = input_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(input_idx)) { - result_mask.SetInvalid(i); - continue; - } - SplitInput split_input(result, child_entry, total_splits); - auto list_length = SplitPath(inputs[input_idx], sep, split_input); - list_data[i].length = list_length; - list_data[i].offset = total_splits; - total_splits += list_length; - } - ListVector::SetListSize(result, total_splits); - D_ASSERT(ListVector::GetListSize(result) == total_splits); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -ScalarFunctionSet ParseDirnameFun::GetFunctions() { - ScalarFunctionSet parse_dirname; - ScalarFunction func({LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING); - parse_dirname.AddFunction(func); - // 
separator options - func.arguments.emplace_back(LogicalType::VARCHAR); - parse_dirname.AddFunction(func); - return parse_dirname; -} - -ScalarFunctionSet ParseDirpathFun::GetFunctions() { - ScalarFunctionSet parse_dirpath; - ScalarFunction func({LogicalType::VARCHAR}, LogicalType::VARCHAR, ParseDirpathFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING); - parse_dirpath.AddFunction(func); - // separator options - func.arguments.emplace_back(LogicalType::VARCHAR); - parse_dirpath.AddFunction(func); - return parse_dirpath; -} - -ScalarFunctionSet ParseFilenameFun::GetFunctions() { - ScalarFunctionSet parse_filename; - parse_filename.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, - nullptr, nullptr, nullptr, nullptr, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - parse_filename.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, - nullptr, nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - parse_filename.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::BOOLEAN}, LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, - nullptr, nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - parse_filename.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BOOLEAN, LogicalType::VARCHAR}, - LogicalType::VARCHAR, TrimPathFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING)); - return parse_filename; -} - -ScalarFunctionSet ParsePathFun::GetFunctions() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - ScalarFunctionSet parse_path; - ScalarFunction func({LogicalType::VARCHAR}, varchar_list_type, ParsePathFunction, nullptr, nullptr, nullptr, - nullptr, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING); - parse_path.AddFunction(func); - // separator options - func.arguments.emplace_back(LogicalType::VARCHAR); - parse_path.AddFunction(func); - return parse_path; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/printf.cpp b/src/duckdb/extension/core_functions/scalar/string/printf.cpp deleted file mode 100644 index 1db25b0df..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/printf.cpp +++ /dev/null @@ -1,189 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/common/limits.hpp" -#include "fmt/format.h" -#include "fmt/printf.h" - -namespace duckdb { - -struct FMTPrintf { - template - static string OP(const char *format_str, vector> &format_args) { - return duckdb_fmt::vsprintf( - format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); - } -}; - -struct FMTFormat { - template - static string OP(const char *format_str, vector> &format_args) { - return duckdb_fmt::vformat( - format_str, duckdb_fmt::basic_format_args(format_args.data(), static_cast(format_args.size()))); - } -}; - -unique_ptr BindPrintfFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - for (idx_t i = 1; i < arguments.size(); i++) { - switch 
(arguments[i]->return_type.id()) { - case LogicalTypeId::BOOLEAN: - bound_function.arguments.emplace_back(LogicalType::BOOLEAN); - break; - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - bound_function.arguments.emplace_back(LogicalType::BIGINT); - break; - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - bound_function.arguments.emplace_back(LogicalType::UBIGINT); - break; - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - bound_function.arguments.emplace_back(LogicalType::DOUBLE); - break; - case LogicalTypeId::VARCHAR: - bound_function.arguments.push_back(LogicalType::VARCHAR); - break; - case LogicalTypeId::DECIMAL: - // decimal type: add cast to double - bound_function.arguments.emplace_back(LogicalType::DOUBLE); - break; - case LogicalTypeId::UNKNOWN: - // parameter: accept any input and rebind later - bound_function.arguments.emplace_back(LogicalType::ANY); - break; - default: - // all other types: add cast to string - bound_function.arguments.emplace_back(LogicalType::VARCHAR); - break; - } - } - return nullptr; -} - -template -static void PrintfFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &format_string = args.data[0]; - auto &result_validity = FlatVector::Validity(result); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - result_validity.Initialize(args.size()); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - switch (args.data[i].GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - if (ConstantVector::IsNull(args.data[i])) { - // constant null! result is always NULL regardless of other input - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - break; - default: - // FLAT VECTOR, we can directly OR the nullmask - args.data[i].Flatten(args.size()); - result.SetVectorType(VectorType::FLAT_VECTOR); - result_validity.Combine(FlatVector::Validity(args.data[i]), args.size()); - break; - } - } - idx_t count = result.GetVectorType() == VectorType::CONSTANT_VECTOR ? 1 : args.size(); - - auto format_data = FlatVector::GetData(format_string); - auto result_data = FlatVector::GetData(result); - for (idx_t idx = 0; idx < count; idx++) { - if (result.GetVectorType() == VectorType::FLAT_VECTOR && FlatVector::IsNull(result, idx)) { - // this entry is NULL: skip it - continue; - } - - // first fetch the format string - auto fmt_idx = format_string.GetVectorType() == VectorType::CONSTANT_VECTOR ? 0 : idx; - auto format_string = format_data[fmt_idx].GetString(); - - // now gather all the format arguments - vector> format_args; - vector> string_args; - - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &col = args.data[col_idx]; - idx_t arg_idx = col.GetVectorType() == VectorType::CONSTANT_VECTOR ? 
0 : idx; - switch (col.GetType().id()) { - case LogicalTypeId::BOOLEAN: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::TINYINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::SMALLINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::INTEGER: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::BIGINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::UBIGINT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::FLOAT: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::DOUBLE: { - auto arg_data = FlatVector::GetData(col); - format_args.emplace_back(duckdb_fmt::internal::make_arg(arg_data[arg_idx])); - break; - } - case LogicalTypeId::VARCHAR: { - auto arg_data = FlatVector::GetData(col); - auto string_view = - duckdb_fmt::basic_string_view(arg_data[arg_idx].GetData(), arg_data[arg_idx].GetSize()); - format_args.emplace_back(duckdb_fmt::internal::make_arg(string_view)); - break; - } - default: - throw InternalException("Unexpected type for printf format"); - } - } - // finally actually perform the format - string dynamic_result = FORMAT_FUN::template OP(format_string.c_str(), format_args); - result_data[idx] = StringVector::AddString(result, dynamic_result); - } -} - -ScalarFunction PrintfFun::GetFunction() { - // duckdb_fmt::printf_context, duckdb_fmt::vsprintf - ScalarFunction printf_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, - PrintfFunction, BindPrintfFunction); - printf_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(printf_fun); - return printf_fun; -} - -ScalarFunction FormatFun::GetFunction() { - // duckdb_fmt::format_context, duckdb_fmt::vformat - ScalarFunction format_fun({LogicalType::VARCHAR}, LogicalType::VARCHAR, - PrintfFunction, BindPrintfFunction); - format_fun.varargs = LogicalType::ANY; - BaseScalarFunction::SetReturnsError(format_fun); - return format_fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp b/src/duckdb/extension/core_functions/scalar/string/repeat.cpp deleted file mode 100644 index 154634f94..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/repeat.cpp +++ /dev/null @@ -1,90 +0,0 @@ -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/operator/multiply.hpp" - -namespace duckdb { - -static void RepeatFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto &str_vector = args.data[0]; - auto &cnt_vector = args.data[1]; - - BinaryExecutor::Execute( - str_vector, cnt_vector, result, args.size(), [&](string_t str, int64_t cnt) { - auto input_str = str.GetData(); - auto size_str = str.GetSize(); - idx_t copy_count = cnt <= 0 || size_str == 0 ? 
0 : UnsafeNumericCast(cnt); - - idx_t copy_size; - if (TryMultiplyOperator::Operation(size_str, copy_count, copy_size)) { - auto result_str = StringVector::EmptyString(result, copy_size); - auto result_data = result_str.GetDataWriteable(); - for (idx_t i = 0; i < copy_count; i++) { - memcpy(result_data + i * size_str, input_str, size_str); - } - result_str.Finalize(); - return result_str; - } else { - throw OutOfRangeException( - "Cannot create a string of size: '%d' * '%d', the maximum supported string size is: '%d'", size_str, - copy_count, string_t::MAX_STRING_SIZE); - } - }); -} - -unique_ptr RepeatBindFunction(ClientContext &, ScalarFunction &bound_function, - vector> &arguments) { - switch (arguments[0]->return_type.id()) { - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - case LogicalTypeId::LIST: - break; - default: - throw NotImplementedException("repeat(list, count) requires a list as parameter"); - } - bound_function.arguments[0] = arguments[0]->return_type; - bound_function.return_type = arguments[0]->return_type; - return nullptr; -} - -static void RepeatListFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto &list_vector = args.data[0]; - auto &cnt_vector = args.data[1]; - - auto &source_child = ListVector::GetEntry(list_vector); - auto &result_child = ListVector::GetEntry(result); - - idx_t current_size = ListVector::GetListSize(result); - BinaryExecutor::Execute( - list_vector, cnt_vector, result, args.size(), [&](list_entry_t list_input, int64_t cnt) { - idx_t copy_count = cnt <= 0 || list_input.length == 0 ? 0 : UnsafeNumericCast(cnt); - idx_t result_length = list_input.length * copy_count; - idx_t new_size = current_size + result_length; - ListVector::Reserve(result, new_size); - list_entry_t result_list; - result_list.offset = current_size; - result_list.length = result_length; - for (idx_t i = 0; i < copy_count; i++) { - // repeat the list contents "cnt" times - VectorOperations::Copy(source_child, result_child, list_input.offset + list_input.length, - list_input.offset, current_size); - current_size += list_input.length; - } - return result_list; - }); - ListVector::SetListSize(result, current_size); -} - -ScalarFunctionSet RepeatFun::GetFunctions() { - ScalarFunctionSet repeat; - for (const auto &type : {LogicalType::VARCHAR, LogicalType::BLOB}) { - repeat.AddFunction(ScalarFunction({type, LogicalType::BIGINT}, type, RepeatFunction)); - } - repeat.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, - LogicalType::LIST(LogicalType::ANY), RepeatListFunction, RepeatBindFunction)); - for (auto &func : repeat.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return repeat; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/replace.cpp b/src/duckdb/extension/core_functions/scalar/string/replace.cpp deleted file mode 100644 index 4702292c5..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/replace.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" - -#include -#include -#include - -namespace duckdb { - -static idx_t NextNeedle(const char *input_haystack, idx_t size_haystack, const char *input_needle, - const idx_t size_needle) { - // Needle needs something to proceed - if (size_needle > 0) { - // Haystack should be bigger 
or equal size to the needle - for (idx_t string_position = 0; (size_haystack - string_position) >= size_needle; ++string_position) { - // Compare Needle to the Haystack - if ((memcmp(input_haystack + string_position, input_needle, size_needle) == 0)) { - return string_position; - } - } - } - // Did not find the needle - return size_haystack; -} - -static string_t ReplaceScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, - vector &result) { - // Get information about the needle, the haystack and the "thread" - auto input_haystack = haystack.GetData(); - auto size_haystack = haystack.GetSize(); - - auto input_needle = needle.GetData(); - auto size_needle = needle.GetSize(); - - auto input_thread = thread.GetData(); - auto size_thread = thread.GetSize(); - - // Reuse the buffer - result.clear(); - - for (;;) { - // Append the non-matching characters - auto string_position = NextNeedle(input_haystack, size_haystack, input_needle, size_needle); - result.insert(result.end(), input_haystack, input_haystack + string_position); - input_haystack += string_position; - size_haystack -= string_position; - - // Stop when we have read the entire haystack - if (size_haystack == 0) { - break; - } - - // Replace the matching characters - result.insert(result.end(), input_thread, input_thread + size_thread); - input_haystack += size_needle; - size_haystack -= size_needle; - } - - return string_t(result.data(), UnsafeNumericCast(result.size())); -} - -static void ReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), - [&](string_t input_string, string_t needle_string, string_t thread_string) { - return StringVector::AddString(result, - ReplaceScalarFunction(input_string, needle_string, thread_string, buffer)); - }); -} - -ScalarFunction ReplaceFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - ReplaceFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/reverse.cpp b/src/duckdb/extension/core_functions/scalar/string/reverse.cpp deleted file mode 100644 index 4ff654909..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/reverse.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "utf8proc_wrapper.hpp" - -#include - -namespace duckdb { - -//! Fast ASCII string reverse, returns false if the input data is not ascii -static bool StrReverseASCII(const char *input, idx_t n, char *output) { - for (idx_t i = 0; i < n; i++) { - if (input[i] & 0x80) { - // non-ascii character - return false; - } - output[n - i - 1] = input[i]; - } - return true; -} - -//! 
Unicode string reverse using grapheme breakers -static void StrReverseUnicode(const char *input, idx_t n, char *output) { - for (auto cluster : Utf8Proc::GraphemeClusters(input, n)) { - memcpy(output + n - cluster.end, input + cluster.start, cluster.end - cluster.start); - } -} - -struct ReverseOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - - auto target = StringVector::EmptyString(result, input_length); - auto target_data = target.GetDataWriteable(); - if (!StrReverseASCII(input_data, input_length, target_data)) { - StrReverseUnicode(input_data, input_length, target_data); - } - target.Finalize(); - return target; - } -}; - -static void ReverseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -ScalarFunction ReverseFun::GetFunction() { - return ScalarFunction("reverse", {LogicalType::VARCHAR}, LogicalType::VARCHAR, ReverseFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp b/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp deleted file mode 100644 index 7ef277292..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/starts_with.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -static bool StartsWith(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t needle_size) { - D_ASSERT(needle_size > 0); - if (needle_size > haystack_size) { - // needle is bigger than haystack: haystack cannot start with needle - return false; - } - return memcmp(haystack, needle, needle_size) == 0; -} - -static bool StartsWith(const string_t &haystack_s, const string_t &needle_s) { - - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); - auto haystack_size = haystack_s.GetSize(); - auto needle = const_uchar_ptr_cast(needle_s.GetData()); - auto needle_size = needle_s.GetSize(); - if (needle_size == 0) { - // empty needle: always true - return true; - } - return StartsWith(haystack, haystack_size, needle, needle_size); -} - -struct StartsWithOperator { - template - static inline TR Operation(TA left, TB right) { - return StartsWith(left, right); - } -}; - -ScalarFunction StartsWithOperatorFun::GetFunction() { - ScalarFunction starts_with({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction); - starts_with.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return starts_with; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/to_base.cpp b/src/duckdb/extension/core_functions/scalar/string/to_base.cpp deleted file mode 100644 index f85f54be9..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/to_base.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" - -namespace duckdb { - -static const char alphabet[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"; - -static unique_ptr ToBaseBind(ClientContext &context, ScalarFunction &bound_function, - vector> 
&arguments) { - // If no min_length is specified, default to 0 - D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - if (arguments.size() == 2) { - arguments.push_back(make_uniq_base(Value::INTEGER(0))); - } - return nullptr; -} - -static void ToBaseFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - auto &radix = args.data[1]; - auto &min_length = args.data[2]; - auto count = args.size(); - - TernaryExecutor::Execute( - input, radix, min_length, result, count, [&](int64_t input, int32_t radix, int32_t min_length) { - if (input < 0) { - throw InvalidInputException("'to_base' number must be greater than or equal to 0"); - } - if (radix < 2 || radix > 36) { - throw InvalidInputException("'to_base' radix must be between 2 and 36"); - } - if (min_length > 64 || min_length < 0) { - throw InvalidInputException("'to_base' min_length must be between 0 and 64"); - } - - char buf[64]; - char *end = buf + sizeof(buf); - char *ptr = end; - do { - *--ptr = alphabet[input % radix]; - input /= radix; - } while (input > 0); - - auto length = end - ptr; - while (length < min_length) { - *--ptr = '0'; - length++; - } - - return StringVector::AddString(result, ptr, UnsafeNumericCast(end - ptr)); - }); -} - -ScalarFunctionSet ToBaseFun::GetFunctions() { - ScalarFunctionSet set("to_base"); - - set.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER}, LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); - set.AddFunction(ScalarFunction({LogicalType::BIGINT, LogicalType::INTEGER, LogicalType::INTEGER}, - LogicalType::VARCHAR, ToBaseFunction, ToBaseBind)); - - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/translate.cpp b/src/duckdb/extension/core_functions/scalar/string/translate.cpp deleted file mode 100644 index ca661cb3d..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/translate.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include "utf8proc.hpp" -#include "utf8proc_wrapper.hpp" - -#include -#include -#include -#include - -namespace duckdb { - -static string_t TranslateScalarFunction(const string_t &haystack, const string_t &needle, const string_t &thread, - vector &result) { - // Get information about the haystack, the needle and the "thread" - auto input_haystack = haystack.GetData(); - auto size_haystack = haystack.GetSize(); - - auto input_needle = needle.GetData(); - auto size_needle = needle.GetSize(); - - auto input_thread = thread.GetData(); - auto size_thread = thread.GetSize(); - - // Reuse the buffer - result.clear(); - result.reserve(size_haystack); - - idx_t i = 0, j = 0; - int sz = 0, c_sz = 0; - - // Character to be replaced - unordered_map to_replace; - while (i < size_needle && j < size_thread) { - auto codepoint_needle = Utf8Proc::UTF8ToCodepoint(input_needle, sz); - input_needle += sz; - i += UnsafeNumericCast(sz); - auto codepoint_thread = Utf8Proc::UTF8ToCodepoint(input_thread, sz); - input_thread += sz; - j += UnsafeNumericCast(sz); - // Ignore unicode character that is existed in to_replace - if (to_replace.count(codepoint_needle) == 0) { - to_replace[codepoint_needle] = codepoint_thread; - } - } - - // Character to be deleted - unordered_set to_delete; - while (i < size_needle) { - auto codepoint_needle = 
Utf8Proc::UTF8ToCodepoint(input_needle, sz); - input_needle += sz; - i += UnsafeNumericCast(sz); - // Add unicode character that will be deleted - if (to_replace.count(codepoint_needle) == 0) { - to_delete.insert(codepoint_needle); - } - } - - char c[5] = {'\0', '\0', '\0', '\0', '\0'}; - for (i = 0; i < size_haystack; i += UnsafeNumericCast(sz)) { - auto codepoint_haystack = Utf8Proc::UTF8ToCodepoint(input_haystack, sz); - if (to_replace.count(codepoint_haystack) != 0) { - Utf8Proc::CodepointToUtf8(to_replace[codepoint_haystack], c_sz, c); - result.insert(result.end(), c, c + c_sz); - } else if (to_delete.count(codepoint_haystack) == 0) { - result.insert(result.end(), input_haystack, input_haystack + sz); - } - input_haystack += sz; - } - - return string_t(result.data(), UnsafeNumericCast(result.size())); -} - -static void TranslateFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &haystack_vector = args.data[0]; - auto &needle_vector = args.data[1]; - auto &thread_vector = args.data[2]; - - vector buffer; - TernaryExecutor::Execute( - haystack_vector, needle_vector, thread_vector, result, args.size(), - [&](string_t input_string, string_t needle_string, string_t thread_string) { - return StringVector::AddString(result, - TranslateScalarFunction(input_string, needle_string, thread_string, buffer)); - }); -} - -ScalarFunction TranslateFun::GetFunction() { - return ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - TranslateFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/trim.cpp b/src/duckdb/extension/core_functions/scalar/string/trim.cpp deleted file mode 100644 index 5553d75e3..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/trim.cpp +++ /dev/null @@ -1,158 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "utf8proc.hpp" - -#include - -namespace duckdb { - -template -struct TrimOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto data = input.GetData(); - auto size = input.GetSize(); - - utf8proc_int32_t codepoint; - auto str = reinterpret_cast(data); - - // Find the first character that is not left trimmed - idx_t begin = 0; - if (LTRIM) { - while (begin < size) { - auto bytes = - utf8proc_iterate(str + begin, UnsafeNumericCast(size - begin), &codepoint); - D_ASSERT(bytes > 0); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - break; - } - begin += UnsafeNumericCast(bytes); - } - } - - // Find the last character that is not right trimmed - idx_t end; - if (RTRIM) { - end = begin; - for (auto next = begin; next < size;) { - auto bytes = utf8proc_iterate(str + next, UnsafeNumericCast(size - next), &codepoint); - D_ASSERT(bytes > 0); - next += UnsafeNumericCast(bytes); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - end = next; - } - } - } else { - end = size; - } - - // Copy the trimmed string - auto target = StringVector::EmptyString(result, end - begin); - auto output = target.GetDataWriteable(); - memcpy(output, data + begin, end - begin); - - target.Finalize(); - return target; - } -}; - -template -static void UnaryTrimFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); -} - -static void 
GetIgnoredCodepoints(string_t ignored, unordered_set &ignored_codepoints) { - auto dataptr = reinterpret_cast(ignored.GetData()); - auto size = ignored.GetSize(); - idx_t pos = 0; - while (pos < size) { - utf8proc_int32_t codepoint; - pos += UnsafeNumericCast( - utf8proc_iterate(dataptr + pos, UnsafeNumericCast(size - pos), &codepoint)); - ignored_codepoints.insert(codepoint); - } -} - -template -static void BinaryTrimFunction(DataChunk &input, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - input.data[0], input.data[1], result, input.size(), [&](string_t input, string_t ignored) { - auto data = input.GetData(); - auto size = input.GetSize(); - - unordered_set ignored_codepoints; - GetIgnoredCodepoints(ignored, ignored_codepoints); - - utf8proc_int32_t codepoint; - auto str = reinterpret_cast(data); - - // Find the first character that is not left trimmed - idx_t begin = 0; - if (LTRIM) { - while (begin < size) { - auto bytes = - utf8proc_iterate(str + begin, UnsafeNumericCast(size - begin), &codepoint); - if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { - break; - } - begin += UnsafeNumericCast(bytes); - } - } - - // Find the last character that is not right trimmed - idx_t end; - if (RTRIM) { - end = begin; - for (auto next = begin; next < size;) { - auto bytes = - utf8proc_iterate(str + next, UnsafeNumericCast(size - next), &codepoint); - D_ASSERT(bytes > 0); - next += UnsafeNumericCast(bytes); - if (ignored_codepoints.find(codepoint) == ignored_codepoints.end()) { - end = next; - } - } - } else { - end = size; - } - - // Copy the trimmed string - auto target = StringVector::EmptyString(result, end - begin); - auto output = target.GetDataWriteable(); - memcpy(output, data + begin, end - begin); - - target.Finalize(); - return target; - }); -} - -ScalarFunctionSet TrimFun::GetFunctions() { - ScalarFunctionSet trim; - trim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - - trim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return trim; -} - -ScalarFunctionSet LtrimFun::GetFunctions() { - ScalarFunctionSet ltrim; - ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - ltrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return ltrim; -} - -ScalarFunctionSet RtrimFun::GetFunctions() { - ScalarFunctionSet rtrim; - rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, UnaryTrimFunction)); - - rtrim.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - BinaryTrimFunction)); - return rtrim; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/string/unicode.cpp b/src/duckdb/extension/core_functions/scalar/string/unicode.cpp deleted file mode 100644 index 902c7c5e7..000000000 --- a/src/duckdb/extension/core_functions/scalar/string/unicode.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "core_functions/scalar/string_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "utf8proc.hpp" - -#include - -namespace duckdb { - -struct UnicodeOperator { - template - static inline TR Operation(const TA &input) { - auto str = reinterpret_cast(input.GetData()); - auto len = input.GetSize(); - 
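// [Editor's note, not in the original source] utf8proc_iterate decodes the
// first UTF-8 code point of the buffer; for an empty or invalid input it
// reports the code point as -1, which is what the SQL unicode() function
// then returns. The cast below narrows the byte length to the signed size
// type that utf8proc expects.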
utf8proc_int32_t codepoint;
-		(void)utf8proc_iterate(str, UnsafeNumericCast<utf8proc_ssize_t>(len), &codepoint);
-		return codepoint;
-	}
-};
-
-ScalarFunction UnicodeFun::GetFunction() {
-	return ScalarFunction({LogicalType::VARCHAR}, LogicalType::INTEGER,
-	                      ScalarFunction::UnaryFunction<string_t, int32_t, UnicodeOperator>);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp b/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp
deleted file mode 100644
index 17b9ad3cc..000000000
--- a/src/duckdb/extension/core_functions/scalar/string/url_encode.cpp
+++ /dev/null
@@ -1,49 +0,0 @@
-#include "duckdb/common/vector_operations/unary_executor.hpp"
-#include "core_functions/scalar/string_functions.hpp"
-#include "duckdb/common/string_util.hpp"
-
-namespace duckdb {
-
-struct URLEncodeOperator {
-	template <class INPUT_TYPE, class RESULT_TYPE>
-	static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
-		auto input_str = input.GetData();
-		auto input_size = input.GetSize();
-		idx_t result_length = StringUtil::URLEncodeSize(input_str, input_size);
-		auto result_str = StringVector::EmptyString(result, result_length);
-		StringUtil::URLEncodeBuffer(input_str, input_size, result_str.GetDataWriteable());
-		result_str.Finalize();
-		return result_str;
-	}
-};
-
-static void URLEncodeFunction(DataChunk &args, ExpressionState &state, Vector &result) {
-	UnaryExecutor::ExecuteString<string_t, string_t, URLEncodeOperator>(args.data[0], result, args.size());
-}
-
-ScalarFunction UrlEncodeFun::GetFunction() {
-	return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLEncodeFunction);
-}
-
-struct URLDecodeOperator {
-	template <class INPUT_TYPE, class RESULT_TYPE>
-	static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) {
-		auto input_str = input.GetData();
-		auto input_size = input.GetSize();
-		idx_t result_length = StringUtil::URLDecodeSize(input_str, input_size);
-		auto result_str = StringVector::EmptyString(result, result_length);
-		StringUtil::URLDecodeBuffer(input_str, input_size, result_str.GetDataWriteable());
-		result_str.Finalize();
-		return result_str;
-	}
-};
-
-static void URLDecodeFunction(DataChunk &args, ExpressionState &state, Vector &result) {
-	UnaryExecutor::ExecuteString<string_t, string_t, URLDecodeOperator>(args.data[0], result, args.size());
-}
-
-ScalarFunction UrlDecodeFun::GetFunction() {
-	return ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, URLDecodeFunction);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp b/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp
deleted file mode 100644
index c83a83e3c..000000000
--- a/src/duckdb/extension/core_functions/scalar/struct/struct_insert.cpp
+++ /dev/null
@@ -1,103 +0,0 @@
-#include "core_functions/scalar/struct_functions.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/parser/expression/bound_expression.hpp"
-#include "duckdb/function/scalar/nested_functions.hpp"
-#include "duckdb/common/case_insensitive_map.hpp"
-#include "duckdb/storage/statistics/struct_stats.hpp"
-#include "duckdb/planner/expression_binder.hpp"
-
-namespace duckdb {
-
-static void StructInsertFunction(DataChunk &args, ExpressionState &state, Vector &result) {
-	auto &starting_vec = args.data[0];
-	starting_vec.Verify(args.size());
-
-	auto &starting_child_entries = StructVector::GetEntries(starting_vec);
-	auto &result_child_entries = StructVector::GetEntries(result);
-
-	// Assign the original child entries to the STRUCT.
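// [Editor's note, not in the original source] Reference() aliases the existing
// child vectors rather than copying their data, so struct_insert keeps the
// original members zero-copy and only appends the newly named children after
// them in the loop below.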
- for (idx_t i = 0; i < starting_child_entries.size(); i++) { - auto &starting_child = starting_child_entries[i]; - result_child_entries[i]->Reference(*starting_child); - } - - // Assign the new children to the result vector. - for (idx_t i = 1; i < args.ColumnCount(); i++) { - result_child_entries[starting_child_entries.size() + i - 1]->Reference(args.data[i]); - } - - result.Verify(args.size()); - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr StructInsertBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.empty()) { - throw InvalidInputException("Missing required arguments for struct_insert function."); - } - if (LogicalTypeId::STRUCT != arguments[0]->return_type.id()) { - throw InvalidInputException("The first argument to struct_insert must be a STRUCT"); - } - if (arguments.size() < 2) { - throw InvalidInputException("Can't insert nothing into a STRUCT"); - } - - case_insensitive_set_t name_collision_set; - child_list_t new_children; - auto &existing_children = StructType::GetChildTypes(arguments[0]->return_type); - - for (idx_t i = 0; i < existing_children.size(); i++) { - auto &child = existing_children[i]; - name_collision_set.insert(child.first); - new_children.push_back(make_pair(child.first, child.second)); - } - - // Loop through the additional arguments (name/value pairs) - for (idx_t i = 1; i < arguments.size(); i++) { - auto &child = arguments[i]; - if (child->GetAlias().empty()) { - throw BinderException("Need named argument for struct insert, e.g., a := b"); - } - if (name_collision_set.find(child->GetAlias()) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\"", child->GetAlias()); - } - name_collision_set.insert(child->GetAlias()); - new_children.push_back(make_pair(child->GetAlias(), arguments[i]->return_type)); - } - - bound_function.return_type = LogicalType::STRUCT(new_children); - return make_uniq(bound_function.return_type); -} - -unique_ptr StructInsertStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto new_stats = StructStats::CreateUnknown(expr.return_type); - - auto existing_count = StructType::GetChildCount(child_stats[0].GetType()); - auto existing_stats = StructStats::GetChildStats(child_stats[0]); - for (idx_t i = 0; i < existing_count; i++) { - StructStats::SetChildStats(new_stats, i, existing_stats[i]); - } - - auto new_count = StructType::GetChildCount(expr.return_type); - auto offset = new_count - child_stats.size(); - for (idx_t i = 1; i < child_stats.size(); i++) { - StructStats::SetChildStats(new_stats, offset + i, child_stats[i]); - } - return new_stats.ToUnique(); -} - -ScalarFunction StructInsertFun::GetFunction() { - ScalarFunction fun({}, LogicalTypeId::STRUCT, StructInsertFunction, StructInsertBind, nullptr, StructInsertStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.varargs = LogicalType::ANY; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp b/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp deleted file mode 100644 index 2a5371076..000000000 --- a/src/duckdb/extension/core_functions/scalar/union/union_extract.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include 
"core_functions/scalar/union_functions.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" - -namespace duckdb { - -struct UnionExtractBindData : public FunctionData { - UnionExtractBindData(string key, idx_t index, LogicalType type) - : key(std::move(key)), index(index), type(std::move(type)) { - } - - string key; - idx_t index; - LogicalType type; - -public: - unique_ptr Copy() const override { - return make_uniq(key, index, type); - } - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return key == other.key && index == other.index && type == other.type; - } -}; - -static void UnionExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // this should be guaranteed by the binder - auto &vec = args.data[0]; - vec.Verify(args.size()); - - D_ASSERT(info.index < UnionType::GetMemberCount(vec.GetType())); - auto &member = UnionVector::GetMember(vec, info.index); - result.Reference(member); - result.Verify(args.size()); -} - -static unique_ptr UnionExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - if (arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - if (arguments[0]->return_type.id() != LogicalTypeId::UNION) { - throw BinderException("union_extract can only take a union parameter"); - } - idx_t union_member_count = UnionType::GetMemberCount(arguments[0]->return_type); - if (union_member_count == 0) { - throw InternalException("Can't extract something from an empty union"); - } - bound_function.arguments[0] = arguments[0]->return_type; - - auto &key_child = arguments[1]; - if (key_child->HasParameter()) { - throw ParameterNotResolvedException(); - } - - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw BinderException("Key name for union_extract needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - auto &key_str = StringValue::Get(key_val); - if (key_val.IsNull() || key_str.empty()) { - throw BinderException("Key name for union_extract needs to be neither NULL nor empty"); - } - string key = StringUtil::Lower(key_str); - - LogicalType return_type; - idx_t key_index = 0; - bool found_key = false; - - for (size_t i = 0; i < union_member_count; i++) { - auto &member_name = UnionType::GetMemberName(arguments[0]->return_type, i); - if (StringUtil::Lower(member_name) == key) { - found_key = true; - key_index = i; - return_type = UnionType::GetMemberType(arguments[0]->return_type, i); - break; - } - } - - if (!found_key) { - vector candidates; - candidates.reserve(union_member_count); - for (idx_t i = 0; i < union_member_count; i++) { - candidates.push_back(UnionType::GetMemberName(arguments[0]->return_type, i)); - } - auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); - auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in union\n%s", key, message); - } - - bound_function.return_type = return_type; - return make_uniq(key, key_index, return_type); -} - 
-ScalarFunction UnionExtractFun::GetFunction() {
-	// the arguments and return types are actually set in the binder function
-	return ScalarFunction({LogicalTypeId::UNION, LogicalType::VARCHAR}, LogicalType::ANY, UnionExtractFunction,
-	                      UnionExtractBind, nullptr, nullptr);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp b/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp
deleted file mode 100644
index 173e36d6c..000000000
--- a/src/duckdb/extension/core_functions/scalar/union/union_tag.cpp
+++ /dev/null
@@ -1,58 +0,0 @@
-#include "core_functions/scalar/union_functions.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/execution/expression_executor.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-#include "duckdb/planner/expression/bound_parameter_expression.hpp"
-
-namespace duckdb {
-
-static unique_ptr<FunctionData> UnionTagBind(ClientContext &context, ScalarFunction &bound_function,
-                                             vector<unique_ptr<Expression>> &arguments) {
-
-	if (arguments.empty()) {
-		throw BinderException("Missing required arguments for union_tag function.");
-	}
-
-	if (LogicalTypeId::UNKNOWN == arguments[0]->return_type.id()) {
-		throw ParameterNotResolvedException();
-	}
-
-	if (LogicalTypeId::UNION != arguments[0]->return_type.id()) {
-		throw BinderException("First argument to union_tag function must be a union type.");
-	}
-
-	if (arguments.size() > 1) {
-		throw BinderException("Too many arguments, union_tag takes at most one argument.");
-	}
-
-	auto member_count = UnionType::GetMemberCount(arguments[0]->return_type);
-	if (member_count == 0) {
-		// this should never happen, empty unions are not allowed
-		throw InternalException("Can't get tags from an empty union");
-	}
-
-	bound_function.arguments[0] = arguments[0]->return_type;
-
-	auto varchar_vector = Vector(LogicalType::VARCHAR, member_count);
-	for (idx_t i = 0; i < member_count; i++) {
-		auto str = string_t(UnionType::GetMemberName(arguments[0]->return_type, i));
-		FlatVector::GetData<string_t>(varchar_vector)[i] =
-		    str.IsInlined() ? str : StringVector::AddString(varchar_vector, str);
-	}
-	auto enum_type = LogicalType::ENUM(varchar_vector, member_count);
-	bound_function.return_type = enum_type;
-
-	return nullptr;
-}
-
-static void UnionTagFunction(DataChunk &args, ExpressionState &state, Vector &result) {
-	D_ASSERT(result.GetType().id() == LogicalTypeId::ENUM);
-	result.Reinterpret(UnionVector::GetTags(args.data[0]));
-}
-
-ScalarFunction UnionTagFun::GetFunction() {
-	return ScalarFunction({LogicalTypeId::UNION}, LogicalTypeId::ANY, UnionTagFunction, UnionTagBind, nullptr,
-	                      nullptr); // TODO: Statistics?
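// [Editor's note, not in the original source] The ENUM return type is built in
// the bind step above, so execution can simply Reinterpret() the union's tag
// vector: the i-th member name becomes the i-th enum entry, mapping tags to
// names without any per-row work.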
-} - -} // namespace duckdb diff --git a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp b/src/duckdb/extension/core_functions/scalar/union/union_value.cpp deleted file mode 100644 index 655003da9..000000000 --- a/src/duckdb/extension/core_functions/scalar/union/union_value.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "core_functions/scalar/union_functions.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" - -namespace duckdb { - -struct UnionValueBindData : public FunctionData { - UnionValueBindData() { - } - -public: - unique_ptr Copy() const override { - return make_uniq(); - } - bool Equals(const FunctionData &other_p) const override { - return true; - } -}; - -static void UnionValueFunction(DataChunk &args, ExpressionState &state, Vector &result) { - // Assign the new entries to the result vector - UnionVector::GetMember(result, 0).Reference(args.data[0]); - - // Set the result tag vector to a constant value - auto &tag_vector = UnionVector::GetTags(result); - tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(tag_vector)[0] = 0; - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(args.size()); -} - -static unique_ptr UnionValueBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - if (arguments.size() != 1) { - throw BinderException("union_value takes exactly one argument"); - } - auto &child = arguments[0]; - - if (child->GetAlias().empty()) { - throw BinderException("Need named argument for union tag, e.g. 
UNION_VALUE(a := b)"); - } - - child_list_t union_members; - - union_members.push_back(make_pair(child->GetAlias(), child->return_type)); - - bound_function.return_type = LogicalType::UNION(std::move(union_members)); - return make_uniq(bound_function.return_type); -} - -ScalarFunction UnionValueFun::GetFunction() { - ScalarFunction fun("union_value", {}, LogicalTypeId::UNION, UnionValueFunction, UnionValueBind, nullptr, nullptr); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/column_reader.cpp b/src/duckdb/extension/parquet/column_reader.cpp deleted file mode 100644 index 8a2112c35..000000000 --- a/src/duckdb/extension/parquet/column_reader.cpp +++ /dev/null @@ -1,1732 +0,0 @@ -#include "column_reader.hpp" - -#include "boolean_column_reader.hpp" -#include "brotli/decode.h" -#include "callback_column_reader.hpp" -#include "cast_column_reader.hpp" -#include "duckdb.hpp" -#include "expression_column_reader.hpp" -#include "list_column_reader.hpp" -#include "lz4.hpp" -#include "miniz_wrapper.hpp" -#include "null_column_reader.hpp" -#include "parquet_decimal_utils.hpp" -#include "parquet_reader.hpp" -#include "parquet_timestamp.hpp" -#include "row_number_column_reader.hpp" -#include "snappy.h" -#include "string_column_reader.hpp" -#include "struct_column_reader.hpp" -#include "templated_column_reader.hpp" -#include "utf8proc_wrapper.hpp" -#include "zstd.h" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/blob.hpp" -#endif - -namespace duckdb { - -using duckdb_parquet::CompressionCodec; -using duckdb_parquet::ConvertedType; -using duckdb_parquet::Encoding; -using duckdb_parquet::PageType; -using duckdb_parquet::Type; - -const uint64_t ParquetDecodeUtils::BITPACK_MASKS[] = {0, - 1, - 3, - 7, - 15, - 31, - 63, - 127, - 255, - 511, - 1023, - 2047, - 4095, - 8191, - 16383, - 32767, - 65535, - 131071, - 262143, - 524287, - 1048575, - 2097151, - 4194303, - 8388607, - 16777215, - 33554431, - 67108863, - 134217727, - 268435455, - 536870911, - 1073741823, - 2147483647, - 4294967295, - 8589934591, - 17179869183, - 34359738367, - 68719476735, - 137438953471, - 274877906943, - 549755813887, - 1099511627775, - 2199023255551, - 4398046511103, - 8796093022207, - 17592186044415, - 35184372088831, - 70368744177663, - 140737488355327, - 281474976710655, - 562949953421311, - 1125899906842623, - 2251799813685247, - 4503599627370495, - 9007199254740991, - 18014398509481983, - 36028797018963967, - 72057594037927935, - 144115188075855871, - 288230376151711743, - 576460752303423487, - 1152921504606846975, - 2305843009213693951, - 4611686018427387903, - 9223372036854775807, - 18446744073709551615ULL}; - -const uint64_t ParquetDecodeUtils::BITPACK_MASKS_SIZE = sizeof(ParquetDecodeUtils::BITPACK_MASKS) / sizeof(uint64_t); - -const uint8_t ParquetDecodeUtils::BITPACK_DLEN = 8; - -ColumnReader::ColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : schema(schema_p), file_idx(file_idx_p), max_define(max_define_p), max_repeat(max_repeat_p), reader(reader), - type(std::move(type_p)), page_rows_available(0), dictionary_selection_vector(STANDARD_VECTOR_SIZE), - dictionary_size(0) { - - // dummies for Skip() - 
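// [Editor's note, not in the original source] Skip() is implemented further
// down (ApplyPendingSkips) as a Read() into these throwaway define/repeat
// buffers, so each must hold one full vector of levels (STANDARD_VECTOR_SIZE).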
dummy_define.resize(reader.allocator, STANDARD_VECTOR_SIZE); - dummy_repeat.resize(reader.allocator, STANDARD_VECTOR_SIZE); -} - -ColumnReader::~ColumnReader() { -} - -Allocator &ColumnReader::GetAllocator() { - return reader.allocator; -} - -ParquetReader &ColumnReader::Reader() { - return reader; -} - -const LogicalType &ColumnReader::Type() const { - return type; -} - -const SchemaElement &ColumnReader::Schema() const { - return schema; -} - -optional_ptr ColumnReader::GetParentSchema() const { - return parent_schema; -} - -void ColumnReader::SetParentSchema(const SchemaElement &parent_schema_p) { - parent_schema = &parent_schema_p; -} - -idx_t ColumnReader::FileIdx() const { - return file_idx; -} - -idx_t ColumnReader::MaxDefine() const { - return max_define; -} - -idx_t ColumnReader::MaxRepeat() const { - return max_repeat; -} - -void ColumnReader::RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) { - if (chunk) { - uint64_t size = chunk->meta_data.total_compressed_size; - transport.RegisterPrefetch(FileOffset(), size, allow_merge); - } -} - -uint64_t ColumnReader::TotalCompressedSize() { - if (!chunk) { - return 0; - } - - return chunk->meta_data.total_compressed_size; -} - -// Note: It's not trivial to determine where all Column data is stored. Chunk->file_offset -// apparently is not the first page of the data. Therefore we determine the address of the first page by taking the -// minimum of all page offsets. -idx_t ColumnReader::FileOffset() const { - if (!chunk) { - throw std::runtime_error("FileOffset called on ColumnReader with no chunk"); - } - auto min_offset = NumericLimits::Maximum(); - if (chunk->meta_data.__isset.dictionary_page_offset) { - min_offset = MinValue(min_offset, chunk->meta_data.dictionary_page_offset); - } - if (chunk->meta_data.__isset.index_page_offset) { - min_offset = MinValue(min_offset, chunk->meta_data.index_page_offset); - } - min_offset = MinValue(min_offset, chunk->meta_data.data_page_offset); - - return min_offset; -} - -idx_t ColumnReader::GroupRowsAvailable() { - return group_rows_available; -} - -unique_ptr ColumnReader::Stats(idx_t row_group_idx_p, const vector &columns) { - return ParquetStatisticsUtils::TransformColumnStatistics(*this, columns); -} - -void ColumnReader::Plain(shared_ptr plain_data, uint8_t *defines, idx_t num_values, // NOLINT - parquet_filter_t *filter, idx_t result_offset, Vector &result) { - throw NotImplementedException("Plain"); -} - -void ColumnReader::PrepareDeltaLengthByteArray(ResizeableBuffer &buffer) { - throw std::runtime_error("DELTA_LENGTH_BYTE_ARRAY encoding is only supported for text or binary data"); -} - -void ColumnReader::PrepareDeltaByteArray(ResizeableBuffer &buffer) { - throw std::runtime_error("DELTA_BYTE_ARRAY encoding is only supported for text or binary data"); -} - -void ColumnReader::DeltaByteArray(uint8_t *defines, idx_t num_values, // NOLINT - parquet_filter_t &filter, idx_t result_offset, Vector &result) { - throw NotImplementedException("DeltaByteArray"); -} - -void ColumnReader::PlainReference(shared_ptr, Vector &result) { // NOLINT -} - -void ColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) { - D_ASSERT(file_idx < columns.size()); - chunk = &columns[file_idx]; - protocol = &protocol_p; - D_ASSERT(chunk); - D_ASSERT(chunk->__isset.meta_data); - - if (chunk->__isset.file_path) { - throw std::runtime_error("Only inlined data files are supported (no references)"); - } - - // ugh. sometimes there is an extra offset for the dict. 
sometimes it's wrong. - chunk_read_offset = chunk->meta_data.data_page_offset; - if (chunk->meta_data.__isset.dictionary_page_offset && chunk->meta_data.dictionary_page_offset >= 4) { - // this assumes the data pages follow the dict pages directly. - chunk_read_offset = chunk->meta_data.dictionary_page_offset; - } - group_rows_available = chunk->meta_data.num_values; -} - -void ColumnReader::PrepareRead(parquet_filter_t &filter) { - dict_decoder.reset(); - defined_decoder.reset(); - bss_decoder.reset(); - block.reset(); - PageHeader page_hdr; - reader.Read(page_hdr, *protocol); - // some basic sanity check - if (page_hdr.compressed_page_size < 0 || page_hdr.uncompressed_page_size < 0) { - throw std::runtime_error("Page sizes can't be < 0"); - } - - switch (page_hdr.type) { - case PageType::DATA_PAGE_V2: - PreparePageV2(page_hdr); - PrepareDataPage(page_hdr); - break; - case PageType::DATA_PAGE: - PreparePage(page_hdr); - PrepareDataPage(page_hdr); - break; - case PageType::DICTIONARY_PAGE: { - PreparePage(page_hdr); - if (page_hdr.dictionary_page_header.num_values < 0) { - throw std::runtime_error("Invalid dictionary page header (num_values < 0)"); - } - auto old_dict_size = dictionary_size; - // we use the first value in the dictionary to keep a NULL - dictionary_size = page_hdr.dictionary_page_header.num_values; - if (!dictionary) { - dictionary = make_uniq(type, dictionary_size + 1); - } else if (dictionary_size > old_dict_size) { - dictionary->Resize(old_dict_size, dictionary_size + 1); - } - dictionary_id = reader.file_name + "_" + schema.name + "_" + std::to_string(chunk_read_offset); - // we use the first entry as a NULL, dictionary vectors don't have a separate validity mask - FlatVector::Validity(*dictionary).SetInvalid(0); - PlainReference(block, *dictionary); - Plain(block, nullptr, dictionary_size, nullptr, 1, *dictionary); - break; - } - default: - break; // ignore INDEX page type and any other custom extensions - } - ResetPage(); -} - -void ColumnReader::ResetPage() { -} - -void ColumnReader::PreparePageV2(PageHeader &page_hdr) { - D_ASSERT(page_hdr.type == PageType::DATA_PAGE_V2); - auto &trans = reinterpret_cast(*protocol->getTransport()); - - AllocateBlock(page_hdr.uncompressed_page_size + 1); - bool uncompressed = false; - if (page_hdr.data_page_header_v2.__isset.is_compressed && !page_hdr.data_page_header_v2.is_compressed) { - uncompressed = true; - } - if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { - if (page_hdr.compressed_page_size != page_hdr.uncompressed_page_size) { - throw std::runtime_error("Page size mismatch"); - } - uncompressed = true; - } - if (uncompressed) { - reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size); - return; - } - - // copy repeats & defines as-is because FOR SOME REASON they are uncompressed - auto uncompressed_bytes = page_hdr.data_page_header_v2.repetition_levels_byte_length + - page_hdr.data_page_header_v2.definition_levels_byte_length; - if (uncompressed_bytes > page_hdr.uncompressed_page_size) { - throw std::runtime_error("Page header inconsistency, uncompressed_page_size needs to be larger than " - "repetition_levels_byte_length + definition_levels_byte_length"); - } - trans.read(block->ptr, uncompressed_bytes); - - auto compressed_bytes = page_hdr.compressed_page_size - uncompressed_bytes; - - AllocateCompressed(compressed_bytes); - reader.ReadData(*protocol, compressed_buffer.ptr, compressed_bytes); - - DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, compressed_bytes, block->ptr + 
uncompressed_bytes, - page_hdr.uncompressed_page_size - uncompressed_bytes); -} - -void ColumnReader::AllocateBlock(idx_t size) { - if (!block) { - block = make_shared_ptr(GetAllocator(), size); - } else { - block->resize(GetAllocator(), size); - } -} - -void ColumnReader::AllocateCompressed(idx_t size) { - compressed_buffer.resize(GetAllocator(), size); -} - -void ColumnReader::PreparePage(PageHeader &page_hdr) { - AllocateBlock(page_hdr.uncompressed_page_size + 1); - if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) { - if (page_hdr.compressed_page_size != page_hdr.uncompressed_page_size) { - throw std::runtime_error("Page size mismatch"); - } - reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size); - return; - } - - AllocateCompressed(page_hdr.compressed_page_size + 1); - reader.ReadData(*protocol, compressed_buffer.ptr, page_hdr.compressed_page_size); - - DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, page_hdr.compressed_page_size, block->ptr, - page_hdr.uncompressed_page_size); -} - -void ColumnReader::DecompressInternal(CompressionCodec::type codec, const_data_ptr_t src, idx_t src_size, - data_ptr_t dst, idx_t dst_size) { - switch (codec) { - case CompressionCodec::UNCOMPRESSED: - throw InternalException("Parquet data unexpectedly uncompressed"); - case CompressionCodec::GZIP: { - MiniZStream s; - s.Decompress(const_char_ptr_cast(src), src_size, char_ptr_cast(dst), dst_size); - break; - } - case CompressionCodec::LZ4_RAW: { - auto res = - duckdb_lz4::LZ4_decompress_safe(const_char_ptr_cast(src), char_ptr_cast(dst), - UnsafeNumericCast(src_size), UnsafeNumericCast(dst_size)); - if (res != NumericCast(dst_size)) { - throw std::runtime_error("LZ4 decompression failure"); - } - break; - } - case CompressionCodec::SNAPPY: { - { - size_t uncompressed_size = 0; - auto res = duckdb_snappy::GetUncompressedLength(const_char_ptr_cast(src), src_size, &uncompressed_size); - if (!res) { - throw std::runtime_error("Snappy decompression failure"); - } - if (uncompressed_size != dst_size) { - throw std::runtime_error("Snappy decompression failure: Uncompressed data size mismatch"); - } - } - auto res = duckdb_snappy::RawUncompress(const_char_ptr_cast(src), src_size, char_ptr_cast(dst)); - if (!res) { - throw std::runtime_error("Snappy decompression failure"); - } - break; - } - case CompressionCodec::ZSTD: { - auto res = duckdb_zstd::ZSTD_decompress(dst, dst_size, src, src_size); - if (duckdb_zstd::ZSTD_isError(res) || res != dst_size) { - throw std::runtime_error("ZSTD Decompression failure"); - } - break; - } - case CompressionCodec::BROTLI: { - auto state = duckdb_brotli::BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); - size_t total_out = 0; - auto src_size_size_t = NumericCast(src_size); - auto dst_size_size_t = NumericCast(dst_size); - - auto res = duckdb_brotli::BrotliDecoderDecompressStream(state, &src_size_size_t, &src, &dst_size_size_t, &dst, - &total_out); - if (res != duckdb_brotli::BROTLI_DECODER_RESULT_SUCCESS) { - throw std::runtime_error("Brotli Decompression failure"); - } - duckdb_brotli::BrotliDecoderDestroyInstance(state); - break; - } - - default: { - std::stringstream codec_name; - codec_name << codec; - throw std::runtime_error("Unsupported compression codec \"" + codec_name.str() + - "\". 
Supported options are uncompressed, brotli, gzip, lz4_raw, snappy or zstd");
-	}
-	}
-}
-
-void ColumnReader::PrepareDataPage(PageHeader &page_hdr) {
-	if (page_hdr.type == PageType::DATA_PAGE && !page_hdr.__isset.data_page_header) {
-		throw std::runtime_error("Missing data page header from data page");
-	}
-	if (page_hdr.type == PageType::DATA_PAGE_V2 && !page_hdr.__isset.data_page_header_v2) {
-		throw std::runtime_error("Missing data page header from data page v2");
-	}
-
-	bool is_v1 = page_hdr.type == PageType::DATA_PAGE;
-	bool is_v2 = page_hdr.type == PageType::DATA_PAGE_V2;
-	auto &v1_header = page_hdr.data_page_header;
-	auto &v2_header = page_hdr.data_page_header_v2;
-
-	page_rows_available = is_v1 ? v1_header.num_values : v2_header.num_values;
-	auto page_encoding = is_v1 ? v1_header.encoding : v2_header.encoding;
-
-	if (HasRepeats()) {
-		uint32_t rep_length = is_v1 ? block->read<uint32_t>() : v2_header.repetition_levels_byte_length;
-		block->available(rep_length);
-		repeated_decoder = make_uniq<RleBpDecoder>(block->ptr, rep_length, RleBpDecoder::ComputeBitWidth(max_repeat));
-		block->inc(rep_length);
-	} else if (is_v2 && v2_header.repetition_levels_byte_length > 0) {
-		block->inc(v2_header.repetition_levels_byte_length);
-	}
-
-	if (HasDefines()) {
-		uint32_t def_length = is_v1 ? block->read<uint32_t>() : v2_header.definition_levels_byte_length;
-		block->available(def_length);
-		defined_decoder = make_uniq<RleBpDecoder>(block->ptr, def_length, RleBpDecoder::ComputeBitWidth(max_define));
-		block->inc(def_length);
-	} else if (is_v2 && v2_header.definition_levels_byte_length > 0) {
-		block->inc(v2_header.definition_levels_byte_length);
-	}
-
-	switch (page_encoding) {
-	case Encoding::RLE_DICTIONARY:
-	case Encoding::PLAIN_DICTIONARY: {
-		// where is it otherwise??
-		auto dict_width = block->read<uint8_t>();
-		// TODO somehow dict_width can be 0 ?
-		dict_decoder = make_uniq<RleBpDecoder>(block->ptr, block->len, dict_width);
-		block->inc(block->len);
-		break;
-	}
-	case Encoding::RLE: {
-		if (type.id() != LogicalTypeId::BOOLEAN) {
-			throw std::runtime_error("RLE encoding is only supported for boolean data");
-		}
-		block->inc(sizeof(uint32_t));
-		rle_decoder = make_uniq<RleBpDecoder>(block->ptr, block->len, 1);
-		break;
-	}
-	case Encoding::DELTA_BINARY_PACKED: {
-		dbp_decoder = make_uniq<DbpDecoder>(block->ptr, block->len);
-		block->inc(block->len);
-		break;
-	}
-	case Encoding::DELTA_LENGTH_BYTE_ARRAY: {
-		PrepareDeltaLengthByteArray(*block);
-		break;
-	}
-	case Encoding::DELTA_BYTE_ARRAY: {
-		PrepareDeltaByteArray(*block);
-		break;
-	}
-	case Encoding::BYTE_STREAM_SPLIT: {
-		// Subtract 1 from length as the block is allocated with 1 extra byte,
-		// but the byte stream split encoder needs to know the correct data size.
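// [Editor's note, not in the original source] BYTE_STREAM_SPLIT stores the
// k-th byte of every value contiguously: for floats f0, f1, ... the page
// holds all byte 0s, then all byte 1s, and so on. The transposed layout
// compresses better, and the decoder below reassembles the original values.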
- bss_decoder = make_uniq(block->ptr, block->len - 1); - block->inc(block->len); - break; - } - case Encoding::PLAIN: - // nothing to do here, will be read directly below - break; - - default: - throw std::runtime_error("Unsupported page encoding"); - } -} - -void ColumnReader::ConvertDictToSelVec(uint32_t *offsets, uint8_t *defines, parquet_filter_t &filter, idx_t read_now, - idx_t result_offset) { - D_ASSERT(read_now <= STANDARD_VECTOR_SIZE); - idx_t offset_idx = 0; - for (idx_t row_idx = 0; row_idx < read_now; row_idx++) { - if (HasDefines() && defines[row_idx + result_offset] != max_define) { - dictionary_selection_vector.set_index(row_idx, 0); // dictionary entry 0 is NULL - continue; // we don't have a dict entry for NULLs - } - if (filter.test(row_idx + result_offset)) { - auto offset = offsets[offset_idx++]; - if (offset >= dictionary_size) { - throw std::runtime_error("Parquet file is likely corrupted, dictionary offset out of range"); - } - dictionary_selection_vector.set_index(row_idx, offset + 1); - } else { - dictionary_selection_vector.set_index(row_idx, 0); // just set NULL if the filter excludes this row - offset_idx++; - } - } -#ifdef DEBUG - dictionary_selection_vector.Verify(read_now, dictionary_size + 1); -#endif -} - -idx_t ColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result) { - // we need to reset the location because multiple column readers share the same protocol - auto &trans = reinterpret_cast(*protocol->getTransport()); - trans.SetLocation(chunk_read_offset); - - // Perform any skips that were not applied yet. - if (pending_skips > 0) { - ApplyPendingSkips(pending_skips); - } - - idx_t result_offset = 0; - auto to_read = num_values; - D_ASSERT(to_read <= STANDARD_VECTOR_SIZE); - - while (to_read > 0) { - while (page_rows_available == 0) { - PrepareRead(filter); - } - - D_ASSERT(block); - auto read_now = MinValue(to_read, page_rows_available); - - D_ASSERT(read_now + result_offset <= STANDARD_VECTOR_SIZE); - - if (HasRepeats()) { - D_ASSERT(repeated_decoder); - repeated_decoder->GetBatch(repeat_out + result_offset, read_now); - } - - if (HasDefines()) { - D_ASSERT(defined_decoder); - defined_decoder->GetBatch(define_out + result_offset, read_now); - } - - idx_t null_count = 0; - - if ((dict_decoder || dbp_decoder || rle_decoder || bss_decoder) && HasDefines()) { - // we need the null count because the dictionary offsets have no entries for nulls - for (idx_t i = result_offset; i < result_offset + read_now; i++) { - null_count += (define_out[i] != max_define); - } - } - - if (result_offset != 0 && result.GetVectorType() != VectorType::FLAT_VECTOR) { - result.Flatten(result_offset); - result.Resize(result_offset, STANDARD_VECTOR_SIZE); - } - - if (dict_decoder) { - if ((!dictionary || dictionary_size == 0) && null_count < read_now) { - throw std::runtime_error("Parquet file is likely corrupted, missing dictionary"); - } - offset_buffer.resize(reader.allocator, sizeof(uint32_t) * (read_now - null_count)); - dict_decoder->GetBatch(offset_buffer.ptr, read_now - null_count); - ConvertDictToSelVec(reinterpret_cast(offset_buffer.ptr), - reinterpret_cast(define_out), filter, read_now, result_offset); - if (result_offset == 0) { - result.Dictionary(*dictionary, dictionary_size + 1, dictionary_selection_vector, read_now); - DictionaryVector::SetDictionaryId(result, dictionary_id); - D_ASSERT(result.GetVectorType() == VectorType::DICTIONARY_VECTOR); - } else { - D_ASSERT(result.GetVectorType() == 
VectorType::FLAT_VECTOR); - VectorOperations::Copy(*dictionary, result, dictionary_selection_vector, read_now, 0, result_offset); - } - } else if (dbp_decoder) { - // TODO keep this in the state - auto read_buf = make_shared_ptr(); - - switch (schema.type) { - case duckdb_parquet::Type::INT32: - read_buf->resize(reader.allocator, sizeof(int32_t) * (read_now - null_count)); - dbp_decoder->GetBatch(read_buf->ptr, read_now - null_count); - - break; - case duckdb_parquet::Type::INT64: - read_buf->resize(reader.allocator, sizeof(int64_t) * (read_now - null_count)); - dbp_decoder->GetBatch(read_buf->ptr, read_now - null_count); - break; - - default: - throw std::runtime_error("DELTA_BINARY_PACKED should only be INT32 or INT64"); - } - // Plain() will put NULLs in the right place - Plain(read_buf, define_out, read_now, &filter, result_offset, result); - } else if (rle_decoder) { - // RLE encoding for boolean - D_ASSERT(type.id() == LogicalTypeId::BOOLEAN); - auto read_buf = make_shared_ptr(); - read_buf->resize(reader.allocator, sizeof(bool) * (read_now - null_count)); - rle_decoder->GetBatch(read_buf->ptr, read_now - null_count); - PlainTemplated>(read_buf, define_out, read_now, &filter, - result_offset, result); - } else if (byte_array_data) { - // DELTA_BYTE_ARRAY or DELTA_LENGTH_BYTE_ARRAY - DeltaByteArray(define_out, read_now, filter, result_offset, result); - } else if (bss_decoder) { - auto read_buf = make_shared_ptr(); - - switch (schema.type) { - case duckdb_parquet::Type::FLOAT: - read_buf->resize(reader.allocator, sizeof(float) * (read_now - null_count)); - bss_decoder->GetBatch(read_buf->ptr, read_now - null_count); - break; - case duckdb_parquet::Type::DOUBLE: - read_buf->resize(reader.allocator, sizeof(double) * (read_now - null_count)); - bss_decoder->GetBatch(read_buf->ptr, read_now - null_count); - break; - default: - throw std::runtime_error("BYTE_STREAM_SPLIT encoding is only supported for FLOAT or DOUBLE data"); - } - - Plain(read_buf, define_out, read_now, &filter, result_offset, result); - } else { - PlainReference(block, result); - Plain(block, define_out, read_now, &filter, result_offset, result); - } - - result_offset += read_now; - page_rows_available -= read_now; - to_read -= read_now; - } - group_rows_available -= num_values; - chunk_read_offset = trans.GetLocation(); - - return num_values; -} - -void ColumnReader::Skip(idx_t num_values) { - pending_skips += num_values; -} - -void ColumnReader::ApplyPendingSkips(idx_t num_values) { - pending_skips -= num_values; - - dummy_define.zero(); - dummy_repeat.zero(); - - // TODO this can be optimized, for example we dont actually have to bitunpack offsets - Vector base_result(type, nullptr); - - idx_t remaining = num_values; - idx_t read = 0; - - while (remaining) { - Vector dummy_result(base_result); - idx_t to_read = MinValue(remaining, STANDARD_VECTOR_SIZE); - read += Read(to_read, none_filter, dummy_define.ptr, dummy_repeat.ptr, dummy_result); - remaining -= to_read; - } - - if (read != num_values) { - throw std::runtime_error("Row count mismatch when skipping rows"); - } -} - -//===--------------------------------------------------------------------===// -// String Column Reader -//===--------------------------------------------------------------------===// -StringColumnReader::StringColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, - idx_t schema_idx_p, idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, - max_define_p, 
max_repeat_p) { - fixed_width_string_length = 0; - if (schema_p.type == Type::FIXED_LEN_BYTE_ARRAY) { - D_ASSERT(schema_p.__isset.type_length); - fixed_width_string_length = schema_p.type_length; - } -} - -uint32_t StringColumnReader::VerifyString(const char *str_data, uint32_t str_len, const bool is_varchar) { - if (!is_varchar) { - return str_len; - } - // verify if a string is actually UTF8, and if there are no null bytes in the middle of the string - // technically Parquet should guarantee this, but reality is often disappointing - UnicodeInvalidReason reason; - size_t pos; - auto utf_type = Utf8Proc::Analyze(str_data, str_len, &reason, &pos); - if (utf_type == UnicodeType::INVALID) { - throw InvalidInputException("Invalid string encoding found in Parquet file: value \"" + - Blob::ToString(string_t(str_data, str_len)) + "\" is not valid UTF8!"); - } - return str_len; -} - -uint32_t StringColumnReader::VerifyString(const char *str_data, uint32_t str_len) { - return VerifyString(str_data, str_len, Type() == LogicalTypeId::VARCHAR); -} - -static shared_ptr<ResizeableBuffer> ReadDbpData(Allocator &allocator, ResizeableBuffer &buffer, idx_t &value_count) { - auto decoder = make_uniq<DbpDecoder>(buffer.ptr, buffer.len); - value_count = decoder->TotalValues(); - auto result = make_shared_ptr<ResizeableBuffer>(); - result->resize(allocator, sizeof(uint32_t) * value_count); - decoder->GetBatch<uint32_t>(result->ptr, value_count); - decoder->Finalize(); - buffer.inc(buffer.len - decoder->BufferPtr().len); - return result; -} - -void StringColumnReader::PrepareDeltaLengthByteArray(ResizeableBuffer &buffer) { - idx_t value_count; - auto length_buffer = ReadDbpData(reader.allocator, buffer, value_count); - if (value_count == 0) { - // no values - byte_array_data = make_uniq<Vector>(LogicalType::VARCHAR, nullptr); - return; - } - auto length_data = reinterpret_cast<uint32_t *>(length_buffer->ptr); - byte_array_data = make_uniq<Vector>(LogicalType::VARCHAR, value_count); - byte_array_count = value_count; - delta_offset = 0; - auto string_data = FlatVector::GetData<string_t>(*byte_array_data); - for (idx_t i = 0; i < value_count; i++) { - auto str_len = length_data[i]; - buffer.available(str_len); - string_data[i] = StringVector::EmptyString(*byte_array_data, str_len); - auto result_data = string_data[i].GetDataWriteable(); - memcpy(result_data, buffer.ptr, length_data[i]); - buffer.inc(length_data[i]); - string_data[i].Finalize(); - } -} - -void StringColumnReader::PrepareDeltaByteArray(ResizeableBuffer &buffer) { - idx_t prefix_count, suffix_count; - auto prefix_buffer = ReadDbpData(reader.allocator, buffer, prefix_count); - auto suffix_buffer = ReadDbpData(reader.allocator, buffer, suffix_count); - if (prefix_count != suffix_count) { - throw std::runtime_error("DELTA_BYTE_ARRAY - prefix and suffix counts are different - corrupt file?"); - } - if (prefix_count == 0) { - // no values - byte_array_data = make_uniq<Vector>(LogicalType::VARCHAR, nullptr); - return; - } - auto prefix_data = reinterpret_cast<uint32_t *>(prefix_buffer->ptr); - auto suffix_data = reinterpret_cast<uint32_t *>(suffix_buffer->ptr); - byte_array_data = make_uniq<Vector>(LogicalType::VARCHAR, prefix_count); - byte_array_count = prefix_count; - delta_offset = 0; - auto string_data = FlatVector::GetData<string_t>(*byte_array_data); - for (idx_t i = 0; i < prefix_count; i++) { - auto str_len = prefix_data[i] + suffix_data[i]; - buffer.available(suffix_data[i]); - string_data[i] = StringVector::EmptyString(*byte_array_data, str_len); - auto result_data = string_data[i].GetDataWriteable(); - if (prefix_data[i] > 0) { - if (i == 0 || prefix_data[i] > string_data[i -
1].GetSize()) { - throw std::runtime_error("DELTA_BYTE_ARRAY - prefix is out of range - corrupt file?"); - } - memcpy(result_data, string_data[i - 1].GetData(), prefix_data[i]); - } - memcpy(result_data + prefix_data[i], buffer.ptr, suffix_data[i]); - buffer.inc(suffix_data[i]); - string_data[i].Finalize(); - } -} - -void StringColumnReader::DeltaByteArray(uint8_t *defines, idx_t num_values, parquet_filter_t &filter, - idx_t result_offset, Vector &result) { - if (!byte_array_data) { - throw std::runtime_error("Internal error - DeltaByteArray called but there was no byte_array_data set"); - } - auto result_ptr = FlatVector::GetData<string_t>(result); - auto &result_mask = FlatVector::Validity(result); - auto string_data = FlatVector::GetData<string_t>(*byte_array_data); - for (idx_t row_idx = 0; row_idx < num_values; row_idx++) { - if (HasDefines() && defines[row_idx + result_offset] != max_define) { - result_mask.SetInvalid(row_idx + result_offset); - continue; - } - if (filter.test(row_idx + result_offset)) { - if (delta_offset >= byte_array_count) { - throw IOException("DELTA_BYTE_ARRAY - length mismatch between values and byte array lengths (attempted " - "read of %d from %d entries) - corrupt file?", - delta_offset + 1, byte_array_count); - } - result_ptr[row_idx + result_offset] = string_data[delta_offset++]; - } else { - delta_offset++; - } - } - StringVector::AddHeapReference(result, *byte_array_data); -} - -class ParquetStringVectorBuffer : public VectorBuffer { -public: - explicit ParquetStringVectorBuffer(shared_ptr<ByteBuffer> buffer_p) - : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), buffer(std::move(buffer_p)) { - } - -private: - shared_ptr<ByteBuffer> buffer; -}; - -void StringColumnReader::PlainReference(shared_ptr<ByteBuffer> plain_data, Vector &result) { - StringVector::AddBuffer(result, make_buffer<ParquetStringVectorBuffer>(std::move(plain_data))); -} - -string_t StringParquetValueConversion::PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - auto &scr = reader.Cast<StringColumnReader>(); - uint32_t str_len = scr.fixed_width_string_length == 0 ? plain_data.read<uint32_t>() : scr.fixed_width_string_length; - plain_data.available(str_len); - auto plain_str = char_ptr_cast(plain_data.ptr); - auto actual_str_len = reader.Cast<StringColumnReader>().VerifyString(plain_str, str_len); - auto ret_str = string_t(plain_str, actual_str_len); - plain_data.inc(str_len); - return ret_str; -} - -void StringParquetValueConversion::PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - auto &scr = reader.Cast<StringColumnReader>(); - uint32_t str_len = scr.fixed_width_string_length == 0 ?
plain_data.read() : scr.fixed_width_string_length; - plain_data.inc(str_len); -} - -bool StringParquetValueConversion::PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { - return true; -} - -string_t StringParquetValueConversion::UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - return PlainRead(plain_data, reader); -} - -void StringParquetValueConversion::UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - PlainSkip(plain_data, reader); -} - -//===--------------------------------------------------------------------===// -// List Column Reader -//===--------------------------------------------------------------------===// -idx_t ListColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, - data_ptr_t repeat_out, Vector &result_out) { - idx_t result_offset = 0; - auto result_ptr = FlatVector::GetData(result_out); - auto &result_mask = FlatVector::Validity(result_out); - - if (pending_skips > 0) { - ApplyPendingSkips(pending_skips); - } - - D_ASSERT(ListVector::GetListSize(result_out) == 0); - // if an individual list is longer than STANDARD_VECTOR_SIZE we actually have to loop the child read to fill it - bool finished = false; - while (!finished) { - idx_t child_actual_num_values = 0; - - // check if we have any overflow from a previous read - if (overflow_child_count == 0) { - // we don't: read elements from the child reader - child_defines.zero(); - child_repeats.zero(); - // we don't know in advance how many values to read because of the beautiful repetition/definition setup - // we just read (up to) a vector from the child column, and see if we have read enough - // if we have not read enough, we read another vector - // if we have read enough, we leave any unhandled elements in the overflow vector for a subsequent read - auto child_req_num_values = - MinValue(STANDARD_VECTOR_SIZE, child_column_reader->GroupRowsAvailable()); - read_vector.ResetFromCache(read_cache); - child_actual_num_values = child_column_reader->Read(child_req_num_values, child_filter, child_defines_ptr, - child_repeats_ptr, read_vector); - } else { - // we do: use the overflow values - child_actual_num_values = overflow_child_count; - overflow_child_count = 0; - } - - if (child_actual_num_values == 0) { - // no more elements available: we are done - break; - } - read_vector.Verify(child_actual_num_values); - idx_t current_chunk_offset = ListVector::GetListSize(result_out); - - // hard-won piece of code this, modify at your own risk - // the intuition is that we have to only collapse values into lists that are repeated *on this level* - // the rest is pretty much handed up as-is as a single-valued list or NULL - idx_t child_idx; - for (child_idx = 0; child_idx < child_actual_num_values; child_idx++) { - if (child_repeats_ptr[child_idx] == max_repeat) { - // value repeats on this level, append - D_ASSERT(result_offset > 0); - result_ptr[result_offset - 1].length++; - continue; - } - - if (result_offset >= num_values) { - // we ran out of output space - finished = true; - break; - } - if (child_defines_ptr[child_idx] >= max_define) { - // value has been defined down the stack, hence its NOT NULL - result_ptr[result_offset].offset = child_idx + current_chunk_offset; - result_ptr[result_offset].length = 1; - } else if (child_defines_ptr[child_idx] == max_define - 1) { - // empty list - result_ptr[result_offset].offset = child_idx + current_chunk_offset; - result_ptr[result_offset].length = 0; - } else { - // value is NULL somewhere up the stack 
- result_mask.SetInvalid(result_offset); - result_ptr[result_offset].offset = 0; - result_ptr[result_offset].length = 0; - } - - repeat_out[result_offset] = child_repeats_ptr[child_idx]; - define_out[result_offset] = child_defines_ptr[child_idx]; - - result_offset++; - } - // actually append the required elements to the child list - ListVector::Append(result_out, read_vector, child_idx); - - // we have read more values from the child reader than we can fit into the result for this read - // we have to pass everything from child_idx to child_actual_num_values into the next call - if (child_idx < child_actual_num_values && result_offset == num_values) { - read_vector.Slice(read_vector, child_idx, child_actual_num_values); - overflow_child_count = child_actual_num_values - child_idx; - read_vector.Verify(overflow_child_count); - - // move values in the child repeats and defines *backward* by child_idx - for (idx_t repdef_idx = 0; repdef_idx < overflow_child_count; repdef_idx++) { - child_defines_ptr[repdef_idx] = child_defines_ptr[child_idx + repdef_idx]; - child_repeats_ptr[repdef_idx] = child_repeats_ptr[child_idx + repdef_idx]; - } - } - } - result_out.Verify(result_offset); - return result_offset; -} - -ListColumnReader::ListColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, - idx_t schema_idx_p, idx_t max_define_p, idx_t max_repeat_p, - unique_ptr child_column_reader_p) - : ColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, max_define_p, max_repeat_p), - child_column_reader(std::move(child_column_reader_p)), - read_cache(reader.allocator, ListType::GetChildType(Type())), read_vector(read_cache), overflow_child_count(0) { - - child_defines.resize(reader.allocator, STANDARD_VECTOR_SIZE); - child_repeats.resize(reader.allocator, STANDARD_VECTOR_SIZE); - child_defines_ptr = (uint8_t *)child_defines.ptr; - child_repeats_ptr = (uint8_t *)child_repeats.ptr; - - child_filter.set(); -} - -void ListColumnReader::ApplyPendingSkips(idx_t num_values) { - pending_skips -= num_values; - - auto define_out = unique_ptr(new uint8_t[num_values]); - auto repeat_out = unique_ptr(new uint8_t[num_values]); - - idx_t remaining = num_values; - idx_t read = 0; - - while (remaining) { - Vector result_out(Type()); - parquet_filter_t filter; - idx_t to_read = MinValue(remaining, STANDARD_VECTOR_SIZE); - read += Read(to_read, filter, define_out.get(), repeat_out.get(), result_out); - remaining -= to_read; - } - - if (read != num_values) { - throw InternalException("Not all skips done!"); - } -} - -//===--------------------------------------------------------------------===// -// Row NumberColumn Reader -//===--------------------------------------------------------------------===// -RowNumberColumnReader::RowNumberColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, - idx_t schema_idx_p, idx_t max_define_p, idx_t max_repeat_p) - : ColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, max_define_p, max_repeat_p) { -} - -unique_ptr RowNumberColumnReader::Stats(idx_t row_group_idx_p, const vector &columns) { - auto stats = NumericStats::CreateUnknown(type); - auto &row_groups = reader.GetFileMetadata()->row_groups; - D_ASSERT(row_group_idx_p < row_groups.size()); - idx_t row_group_offset_min = 0; - for (idx_t i = 0; i < row_group_idx_p; i++) { - row_group_offset_min += row_groups[i].num_rows; - } - - NumericStats::SetMin(stats, Value::BIGINT(UnsafeNumericCast(row_group_offset_min))); - NumericStats::SetMax( - stats, 
Value::BIGINT(UnsafeNumericCast(row_group_offset_min + row_groups[row_group_idx_p].num_rows))); - stats.Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - return stats.ToUnique(); -} - -void RowNumberColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, - TProtocol &protocol_p) { - row_group_offset = 0; - auto &row_groups = reader.GetFileMetadata()->row_groups; - for (idx_t i = 0; i < row_group_idx_p; i++) { - row_group_offset += row_groups[i].num_rows; - } -} - -idx_t RowNumberColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, - data_ptr_t repeat_out, Vector &result) { - - auto data_ptr = FlatVector::GetData(result); - for (idx_t i = 0; i < num_values; i++) { - data_ptr[i] = UnsafeNumericCast(row_group_offset++); - } - return num_values; -} - -//===--------------------------------------------------------------------===// -// Cast Column Reader -//===--------------------------------------------------------------------===// -CastColumnReader::CastColumnReader(unique_ptr child_reader_p, LogicalType target_type_p) - : ColumnReader(child_reader_p->Reader(), std::move(target_type_p), child_reader_p->Schema(), - child_reader_p->FileIdx(), child_reader_p->MaxDefine(), child_reader_p->MaxRepeat()), - child_reader(std::move(child_reader_p)) { - vector intermediate_types {child_reader->Type()}; - intermediate_chunk.Initialize(reader.allocator, intermediate_types); -} - -unique_ptr CastColumnReader::Stats(idx_t row_group_idx_p, const vector &columns) { - // casting stats is not supported (yet) - return nullptr; -} - -void CastColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, - TProtocol &protocol_p) { - child_reader->InitializeRead(row_group_idx_p, columns, protocol_p); -} - -idx_t CastColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, - data_ptr_t repeat_out, Vector &result) { - intermediate_chunk.Reset(); - auto &intermediate_vector = intermediate_chunk.data[0]; - - auto amount = child_reader->Read(num_values, filter, define_out, repeat_out, intermediate_vector); - if (!filter.all()) { - // work-around for filters: set all values that are filtered to NULL to prevent the cast from failing on - // uninitialized data - intermediate_vector.Flatten(amount); - auto &validity = FlatVector::Validity(intermediate_vector); - for (idx_t i = 0; i < amount; i++) { - if (!filter.test(i)) { - validity.SetInvalid(i); - } - } - } - string error_message; - bool all_succeeded = VectorOperations::DefaultTryCast(intermediate_vector, result, amount, &error_message); - if (!all_succeeded) { - string extended_error; - if (!reader.table_columns.empty()) { - // COPY .. FROM - extended_error = StringUtil::Format( - "In file \"%s\" the column \"%s\" has type %s, but we are trying to load it into column ", - reader.file_name, schema.name, intermediate_vector.GetType()); - if (FileIdx() < reader.table_columns.size()) { - extended_error += "\"" + reader.table_columns[FileIdx()] + "\" "; - } - extended_error += StringUtil::Format("with type %s.", result.GetType()); - extended_error += "\nThis means the Parquet schema does not match the schema of the table."; - extended_error += "\nPossible solutions:"; - extended_error += "\n* Insert by name instead of by position using \"INSERT INTO tbl BY NAME SELECT * FROM " - "read_parquet(...)\""; - extended_error += "\n* Manually specify which columns to insert using \"INSERT INTO tbl SELECT ... 
FROM " - "read_parquet(...)\""; - } else { - // read_parquet() with multiple files - extended_error = StringUtil::Format( - "In file \"%s\" the column \"%s\" has type %s, but we are trying to read it as type %s.", - reader.file_name, schema.name, intermediate_vector.GetType(), result.GetType()); - extended_error += - "\nThis can happen when reading multiple Parquet files. The schema information is taken from " - "the first Parquet file by default. Possible solutions:\n"; - extended_error += "* Enable the union_by_name=True option to combine the schema of all Parquet files " - "(duckdb.org/docs/data/multiple_files/combining_schemas)\n"; - extended_error += "* Use a COPY statement to automatically derive types from an existing table."; - } - throw ConversionException( - "In Parquet reader of file \"%s\": failed to cast column \"%s\" from type %s to %s: %s\n\n%s", - reader.file_name, schema.name, intermediate_vector.GetType(), result.GetType(), error_message, - extended_error); - } - return amount; -} - -void CastColumnReader::Skip(idx_t num_values) { - child_reader->Skip(num_values); -} - -idx_t CastColumnReader::GroupRowsAvailable() { - return child_reader->GroupRowsAvailable(); -} - -//===--------------------------------------------------------------------===// -// Expression Column Reader -//===--------------------------------------------------------------------===// -ExpressionColumnReader::ExpressionColumnReader(ClientContext &context, unique_ptr child_reader_p, - unique_ptr expr_p) - : ColumnReader(child_reader_p->Reader(), expr_p->return_type, child_reader_p->Schema(), child_reader_p->FileIdx(), - child_reader_p->MaxDefine(), child_reader_p->MaxRepeat()), - child_reader(std::move(child_reader_p)), expr(std::move(expr_p)), executor(context, expr.get()) { - vector intermediate_types {child_reader->Type()}; - intermediate_chunk.Initialize(reader.allocator, intermediate_types); -} - -unique_ptr ExpressionColumnReader::Stats(idx_t row_group_idx_p, const vector &columns) { - // expression stats is not supported (yet) - return nullptr; -} - -void ExpressionColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, - TProtocol &protocol_p) { - child_reader->InitializeRead(row_group_idx_p, columns, protocol_p); -} - -idx_t ExpressionColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, - data_ptr_t repeat_out, Vector &result) { - intermediate_chunk.Reset(); - auto &intermediate_vector = intermediate_chunk.data[0]; - - auto amount = child_reader->Read(num_values, filter, define_out, repeat_out, intermediate_vector); - if (!filter.all()) { - // work-around for filters: set all values that are filtered to NULL to prevent the cast from failing on - // uninitialized data - intermediate_vector.Flatten(amount); - auto &validity = FlatVector::Validity(intermediate_vector); - for (idx_t i = 0; i < amount; i++) { - if (!filter[i]) { - validity.SetInvalid(i); - } - } - } - // Execute the expression - intermediate_chunk.SetCardinality(amount); - executor.ExecuteExpression(intermediate_chunk, result); - return amount; -} - -void ExpressionColumnReader::Skip(idx_t num_values) { - child_reader->Skip(num_values); -} - -idx_t ExpressionColumnReader::GroupRowsAvailable() { - return child_reader->GroupRowsAvailable(); -} - -//===--------------------------------------------------------------------===// -// Struct Column Reader -//===--------------------------------------------------------------------===// -StructColumnReader::StructColumnReader(ParquetReader 
&reader, LogicalType type_p, const SchemaElement &schema_p, - idx_t schema_idx_p, idx_t max_define_p, idx_t max_repeat_p, - vector> child_readers_p) - : ColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, max_define_p, max_repeat_p), - child_readers(std::move(child_readers_p)) { - D_ASSERT(type.InternalType() == PhysicalType::STRUCT); -} - -ColumnReader &StructColumnReader::GetChildReader(idx_t child_idx) { - if (!child_readers[child_idx]) { - throw InternalException("StructColumnReader::GetChildReader(%d) - but this child reader is not set", child_idx); - } - return *child_readers[child_idx].get(); -} - -void StructColumnReader::InitializeRead(idx_t row_group_idx_p, const vector &columns, - TProtocol &protocol_p) { - for (auto &child : child_readers) { - if (!child) { - continue; - } - child->InitializeRead(row_group_idx_p, columns, protocol_p); - } -} - -idx_t StructColumnReader::Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, - data_ptr_t repeat_out, Vector &result) { - auto &struct_entries = StructVector::GetEntries(result); - D_ASSERT(StructType::GetChildTypes(Type()).size() == struct_entries.size()); - - if (pending_skips > 0) { - ApplyPendingSkips(pending_skips); - } - - optional_idx read_count; - for (idx_t i = 0; i < child_readers.size(); i++) { - auto &child = child_readers[i]; - auto &target_vector = *struct_entries[i]; - if (!child) { - // if we are not scanning this vector - set it to NULL - target_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(target_vector, true); - continue; - } - auto child_num_values = child->Read(num_values, filter, define_out, repeat_out, target_vector); - if (!read_count.IsValid()) { - read_count = child_num_values; - } else if (read_count.GetIndex() != child_num_values) { - throw std::runtime_error("Struct child row count mismatch"); - } - } - if (!read_count.IsValid()) { - read_count = num_values; - } - // set the validity mask for this level - auto &validity = FlatVector::Validity(result); - for (idx_t i = 0; i < read_count.GetIndex(); i++) { - if (define_out[i] < max_define) { - validity.SetInvalid(i); - } - } - - return read_count.GetIndex(); -} - -void StructColumnReader::Skip(idx_t num_values) { - for (auto &child : child_readers) { - if (!child) { - continue; - } - child->Skip(num_values); - } -} - -void StructColumnReader::RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) { - for (auto &child : child_readers) { - if (!child) { - continue; - } - child->RegisterPrefetch(transport, allow_merge); - } -} - -uint64_t StructColumnReader::TotalCompressedSize() { - uint64_t size = 0; - for (auto &child : child_readers) { - if (!child) { - continue; - } - size += child->TotalCompressedSize(); - } - return size; -} - -static bool TypeHasExactRowCount(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - return false; - case LogicalTypeId::STRUCT: - for (auto &kv : StructType::GetChildTypes(type)) { - if (TypeHasExactRowCount(kv.second)) { - return true; - } - } - return false; - default: - return true; - } -} - -idx_t StructColumnReader::GroupRowsAvailable() { - for (idx_t i = 0; i < child_readers.size(); i++) { - if (TypeHasExactRowCount(child_readers[i]->Type())) { - return child_readers[i]->GroupRowsAvailable(); - } - } - return child_readers[0]->GroupRowsAvailable(); -} - -//===--------------------------------------------------------------------===// -// Decimal Column Reader 
-//===--------------------------------------------------------------------===// -template <class DUCKDB_PHYSICAL_TYPE, bool FIXED_LENGTH> -struct DecimalParquetValueConversion { - - static DUCKDB_PHYSICAL_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - idx_t byte_len; - if (FIXED_LENGTH) { - byte_len = (idx_t)reader.Schema().type_length; /* sure, type length needs to be a signed int */ - } else { - byte_len = plain_data.read<uint32_t>(); - } - plain_data.available(byte_len); - auto res = ParquetDecimalUtils::ReadDecimalValue<DUCKDB_PHYSICAL_TYPE>(const_data_ptr_cast(plain_data.ptr), - byte_len, reader.Schema()); - - plain_data.inc(byte_len); - return res; - } - - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - uint32_t decimal_len = FIXED_LENGTH ? reader.Schema().type_length : plain_data.read<uint32_t>(); - plain_data.inc(decimal_len); - } - - static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { - return true; - } - - static DUCKDB_PHYSICAL_TYPE UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - return PlainRead(plain_data, reader); - } - - static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - PlainSkip(plain_data, reader); - } -}; - -template <class DUCKDB_PHYSICAL_TYPE, bool FIXED_LENGTH> -class DecimalColumnReader - : public TemplatedColumnReader<DUCKDB_PHYSICAL_TYPE, DecimalParquetValueConversion<DUCKDB_PHYSICAL_TYPE, FIXED_LENGTH>> { - using BaseType = - TemplatedColumnReader<DUCKDB_PHYSICAL_TYPE, DecimalParquetValueConversion<DUCKDB_PHYSICAL_TYPE, FIXED_LENGTH>>; - -public: - DecimalColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, // NOLINT - idx_t file_idx_p, idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader<DUCKDB_PHYSICAL_TYPE, DecimalParquetValueConversion<DUCKDB_PHYSICAL_TYPE, FIXED_LENGTH>>( - reader, std::move(type_p), schema_p, file_idx_p, max_define_p, max_repeat_p) {}; - -protected: -}; - -template <bool FIXED_LENGTH> -static unique_ptr<ColumnReader> CreateDecimalReaderInternal(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define, idx_t max_repeat) { - switch (type_p.InternalType()) { - case PhysicalType::INT16: - return make_uniq<DecimalColumnReader<int16_t, FIXED_LENGTH>>(reader, type_p, schema_p, file_idx_p, max_define, - max_repeat); - case PhysicalType::INT32: - return make_uniq<DecimalColumnReader<int32_t, FIXED_LENGTH>>(reader, type_p, schema_p, file_idx_p, max_define, - max_repeat); - case PhysicalType::INT64: - return make_uniq<DecimalColumnReader<int64_t, FIXED_LENGTH>>(reader, type_p, schema_p, file_idx_p, max_define, - max_repeat); - case PhysicalType::INT128: - return make_uniq<DecimalColumnReader<hugeint_t, FIXED_LENGTH>>(reader, type_p, schema_p, file_idx_p, max_define, - max_repeat); - case PhysicalType::DOUBLE: - return make_uniq<DecimalColumnReader<double, FIXED_LENGTH>>(reader, type_p, schema_p, file_idx_p, max_define, - max_repeat); - default: - throw InternalException("Unrecognized type for Decimal"); - } -} - -template <> -double ParquetDecimalUtils::ReadDecimalValue(const_data_ptr_t pointer, idx_t size, - const duckdb_parquet::SchemaElement &schema_ele) { - double res = 0; - bool positive = (*pointer & 0x80) == 0; - for (idx_t i = 0; i < size; i += 8) { - auto byte_size = MinValue<idx_t>(sizeof(uint64_t), size - i); - uint64_t input = 0; - auto res_ptr = reinterpret_cast<uint8_t *>(&input); - for (idx_t k = 0; k < byte_size; k++) { - auto byte = pointer[i + k]; - res_ptr[sizeof(uint64_t) - k - 1] = positive ?
byte : byte ^ 0xFF; - } - res *= double(NumericLimits<uint64_t>::Maximum()) + 1; - res += static_cast<double>(input); - } - if (!positive) { - res += 1; - res /= pow(10, schema_ele.scale); - return -res; - } - res /= pow(10, schema_ele.scale); - return res; -} - -unique_ptr<ColumnReader> ParquetDecimalUtils::CreateReader(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define, idx_t max_repeat) { - if (schema_p.__isset.type_length) { - return CreateDecimalReaderInternal<true>(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else { - return CreateDecimalReaderInternal<false>(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } -} - -//===--------------------------------------------------------------------===// -// UUID Column Reader -//===--------------------------------------------------------------------===// -struct UUIDValueConversion { - static hugeint_t ReadParquetUUID(const_data_ptr_t input) { - hugeint_t result; - result.lower = 0; - uint64_t unsigned_upper = 0; - for (idx_t i = 0; i < sizeof(uint64_t); i++) { - unsigned_upper <<= 8; - unsigned_upper += input[i]; - } - for (idx_t i = sizeof(uint64_t); i < sizeof(hugeint_t); i++) { - result.lower <<= 8; - result.lower += input[i]; - } - result.upper = static_cast<int64_t>(unsigned_upper ^ (uint64_t(1) << 63)); - return result; - } - - static hugeint_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.available(sizeof(hugeint_t)); - return UnsafePlainRead(plain_data, reader); - } - - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.inc(sizeof(hugeint_t)); - } - - static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { - return plain_data.check_available(count * sizeof(hugeint_t)); - } - - static hugeint_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - auto res = ReadParquetUUID(const_data_ptr_cast(plain_data.ptr)); - plain_data.unsafe_inc(sizeof(hugeint_t)); - return res; - } - - static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.unsafe_inc(sizeof(hugeint_t)); - } -}; - -class UUIDColumnReader : public TemplatedColumnReader<hugeint_t, UUIDValueConversion> { - -public: - UUIDColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader<hugeint_t, UUIDValueConversion>(reader, std::move(type_p), schema_p, file_idx_p, - max_define_p, max_repeat_p) {}; -}; - -//===--------------------------------------------------------------------===// -// Interval Column Reader -//===--------------------------------------------------------------------===// -struct IntervalValueConversion { - static constexpr const idx_t PARQUET_INTERVAL_SIZE = 12; - - static interval_t ReadParquetInterval(const_data_ptr_t input) { - interval_t result; - result.months = Load<uint32_t>(input); - result.days = Load<uint32_t>(input + sizeof(uint32_t)); - result.micros = int64_t(Load<uint32_t>(input + sizeof(uint32_t) * 2)) * 1000; - return result; - } - - static interval_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.available(PARQUET_INTERVAL_SIZE); - return UnsafePlainRead(plain_data, reader); - } - - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.inc(PARQUET_INTERVAL_SIZE); - } - - static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { - return plain_data.check_available(count * PARQUET_INTERVAL_SIZE); - } - - static interval_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - auto
res = ReadParquetInterval(const_data_ptr_cast(plain_data.ptr)); - plain_data.unsafe_inc(PARQUET_INTERVAL_SIZE); - return res; - } - - static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.unsafe_inc(PARQUET_INTERVAL_SIZE); - } -}; - -class IntervalColumnReader : public TemplatedColumnReader { - -public: - IntervalColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader(reader, std::move(type_p), schema_p, file_idx_p, - max_define_p, max_repeat_p) {}; -}; - -//===--------------------------------------------------------------------===// -// Create Column Reader -//===--------------------------------------------------------------------===// -template -unique_ptr CreateDecimalReader(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t file_idx_p, idx_t max_define, - idx_t max_repeat) { - switch (type_p.InternalType()) { - case PhysicalType::INT16: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case PhysicalType::INT32: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case PhysicalType::INT64: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - throw NotImplementedException("Unimplemented internal type for CreateDecimalReader"); - } -} - -unique_ptr ColumnReader::CreateReader(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t file_idx_p, idx_t max_define, - idx_t max_repeat) { - switch (type_p.id()) { - case LogicalTypeId::BOOLEAN: - return make_uniq(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::UTINYINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::USMALLINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::UINTEGER: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::UBIGINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::TINYINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::SMALLINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::INTEGER: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::BIGINT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::FLOAT: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::DOUBLE: - switch (schema_p.type) { - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - return ParquetDecimalUtils::CreateReader(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - return make_uniq>>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - switch (schema_p.type) { - case Type::INT96: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case Type::INT64: - if (schema_p.__isset.logicalType && schema_p.logicalType.__isset.TIMESTAMP) { - if (schema_p.logicalType.TIMESTAMP.unit.__isset.MILLIS) { - return 
make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.MICROS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.NANOS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } - } else if (schema_p.__isset.converted_type) { - switch (schema_p.converted_type) { - case ConvertedType::TIMESTAMP_MICROS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case ConvertedType::TIMESTAMP_MILLIS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - break; - } - } - default: - break; - } - break; - case LogicalTypeId::TIMESTAMP_NS: - switch (schema_p.type) { - case Type::INT96: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case Type::INT64: - if (schema_p.__isset.logicalType && schema_p.logicalType.__isset.TIMESTAMP) { - if (schema_p.logicalType.TIMESTAMP.unit.__isset.MILLIS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.MICROS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIMESTAMP.unit.__isset.NANOS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } - } else if (schema_p.__isset.converted_type) { - switch (schema_p.converted_type) { - case ConvertedType::TIMESTAMP_MICROS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case ConvertedType::TIMESTAMP_MILLIS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - break; - } - } - default: - break; - } - break; - case LogicalTypeId::DATE: - return make_uniq>(reader, type_p, schema_p, file_idx_p, - max_define, max_repeat); - case LogicalTypeId::TIME: - if (schema_p.__isset.logicalType && schema_p.logicalType.__isset.TIME) { - if (schema_p.logicalType.TIME.unit.__isset.MILLIS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIME.unit.__isset.MICROS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIME.unit.__isset.NANOS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } - } else if (schema_p.__isset.converted_type) { - switch (schema_p.converted_type) { - case ConvertedType::TIME_MICROS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case ConvertedType::TIME_MILLIS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - break; - } - } - throw NotImplementedException("Unsupported time encoding in Parquet file"); - case LogicalTypeId::TIME_TZ: - if (schema_p.__isset.logicalType && schema_p.logicalType.__isset.TIME) { - if (schema_p.logicalType.TIME.unit.__isset.MILLIS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIME.unit.__isset.MICROS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } else if (schema_p.logicalType.TIME.unit.__isset.NANOS) { - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - } - } else if 
(schema_p.__isset.converted_type) { - switch (schema_p.converted_type) { - case ConvertedType::TIME_MICROS: - return make_uniq>( - reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - break; - } - } - throw NotImplementedException("Unsupported time encoding in Parquet file"); - case LogicalTypeId::BLOB: - case LogicalTypeId::VARCHAR: - return make_uniq(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::DECIMAL: - // we have to figure out what kind of int we need - switch (schema_p.type) { - case Type::INT32: - return CreateDecimalReader(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case Type::INT64: - return CreateDecimalReader(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - return ParquetDecimalUtils::CreateReader(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - throw NotImplementedException("Unrecognized Parquet type for Decimal"); - } - break; - case LogicalTypeId::UUID: - return make_uniq(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::INTERVAL: - return make_uniq(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - case LogicalTypeId::SQLNULL: - return make_uniq(reader, type_p, schema_p, file_idx_p, max_define, max_repeat); - default: - break; - } - throw NotImplementedException(type_p.ToString()); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/column_writer.cpp b/src/duckdb/extension/parquet/column_writer.cpp deleted file mode 100644 index 677bacf42..000000000 --- a/src/duckdb/extension/parquet/column_writer.cpp +++ /dev/null @@ -1,2606 +0,0 @@ -#include "column_writer.hpp" - -#include "duckdb.hpp" -#include "geo_parquet.hpp" -#include "parquet_dbp_encoder.hpp" -#include "parquet_dlba_encoder.hpp" -#include "parquet_rle_bp_decoder.hpp" -#include "parquet_rle_bp_encoder.hpp" -#include "parquet_bss_encoder.hpp" -#include "parquet_statistics.hpp" -#include "parquet_writer.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" -#include "duckdb/common/serializer/memory_stream.hpp" -#include "duckdb/common/serializer/write_stream.hpp" -#include "duckdb/common/string_map_set.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/execution/expression_executor.hpp" -#endif - -#include "brotli/encode.h" -#include "lz4.hpp" -#include "miniz_wrapper.hpp" -#include "snappy.h" -#include "zstd.h" -#include "zstd/common/xxhash.hpp" - -#include - -namespace duckdb { - -using namespace duckdb_parquet; // NOLINT -using namespace duckdb_miniz; // NOLINT - -using duckdb_parquet::CompressionCodec; -using duckdb_parquet::ConvertedType; -using duckdb_parquet::Encoding; -using duckdb_parquet::FieldRepetitionType; -using duckdb_parquet::FileMetaData; -using duckdb_parquet::PageHeader; -using duckdb_parquet::PageType; -using ParquetRowGroup = duckdb_parquet::RowGroup; -using duckdb_parquet::Type; - -#define PARQUET_DEFINE_VALID 65535 - -//===--------------------------------------------------------------------===// -// ColumnWriterStatistics -//===--------------------------------------------------------------------===// -ColumnWriterStatistics::~ColumnWriterStatistics() { -} - -bool 
ColumnWriterStatistics::HasStats() { - return false; -} - -string ColumnWriterStatistics::GetMin() { - return string(); -} - -string ColumnWriterStatistics::GetMax() { - return string(); -} - -string ColumnWriterStatistics::GetMinValue() { - return string(); -} - -string ColumnWriterStatistics::GetMaxValue() { - return string(); -} - -//===--------------------------------------------------------------------===// -// RleBpEncoder -//===--------------------------------------------------------------------===// -RleBpEncoder::RleBpEncoder(uint32_t bit_width) - : byte_width((bit_width + 7) / 8), byte_count(idx_t(-1)), run_count(idx_t(-1)) { -} - -// we always RLE everything (for now) -void RleBpEncoder::BeginPrepare(uint32_t first_value) { - byte_count = 0; - run_count = 1; - current_run_count = 1; - last_value = first_value; -} - -void RleBpEncoder::FinishRun() { - // last value, or value has changed - // write out the current run - byte_count += ParquetDecodeUtils::GetVarintSize(current_run_count << 1) + byte_width; - current_run_count = 1; - run_count++; -} - -void RleBpEncoder::PrepareValue(uint32_t value) { - if (value != last_value) { - FinishRun(); - last_value = value; - } else { - current_run_count++; - } -} - -void RleBpEncoder::FinishPrepare() { - FinishRun(); -} - -idx_t RleBpEncoder::GetByteCount() { - D_ASSERT(byte_count != idx_t(-1)); - return byte_count; -} - -void RleBpEncoder::BeginWrite(WriteStream &writer, uint32_t first_value) { - // start the RLE runs - last_value = first_value; - current_run_count = 1; -} - -void RleBpEncoder::WriteRun(WriteStream &writer) { - // write the header of the run - ParquetDecodeUtils::VarintEncode(current_run_count << 1, writer); - // now write the value - D_ASSERT(last_value >> (byte_width * 8) == 0); - switch (byte_width) { - case 1: - writer.Write(last_value); - break; - case 2: - writer.Write(last_value); - break; - case 3: - writer.Write(last_value & 0xFF); - writer.Write((last_value >> 8) & 0xFF); - writer.Write((last_value >> 16) & 0xFF); - break; - case 4: - writer.Write(last_value); - break; - default: - throw InternalException("unsupported byte width for RLE encoding"); - } - current_run_count = 1; -} - -void RleBpEncoder::WriteValue(WriteStream &writer, uint32_t value) { - if (value != last_value) { - WriteRun(writer); - last_value = value; - } else { - current_run_count++; - } -} - -void RleBpEncoder::FinishWrite(WriteStream &writer) { - WriteRun(writer); -} - -//===--------------------------------------------------------------------===// -// ColumnWriter -//===--------------------------------------------------------------------===// -ColumnWriter::ColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) - : writer(writer), schema_idx(schema_idx), schema_path(std::move(schema_path_p)), max_repeat(max_repeat), - max_define(max_define), can_have_nulls(can_have_nulls) { -} -ColumnWriter::~ColumnWriter() { -} - -ColumnWriterState::~ColumnWriterState() { -} - -void ColumnWriter::CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, - unique_ptr &compressed_buf) { - switch (writer.GetCodec()) { - case CompressionCodec::UNCOMPRESSED: - compressed_size = temp_writer.GetPosition(); - compressed_data = temp_writer.GetData(); - break; - - case CompressionCodec::SNAPPY: { - compressed_size = duckdb_snappy::MaxCompressedLength(temp_writer.GetPosition()); - compressed_buf = unique_ptr(new data_t[compressed_size]); - 
duckdb_snappy::RawCompress(const_char_ptr_cast(temp_writer.GetData()), temp_writer.GetPosition(), - char_ptr_cast(compressed_buf.get()), &compressed_size); - compressed_data = compressed_buf.get(); - D_ASSERT(compressed_size <= duckdb_snappy::MaxCompressedLength(temp_writer.GetPosition())); - break; - } - case CompressionCodec::LZ4_RAW: { - compressed_size = duckdb_lz4::LZ4_compressBound(UnsafeNumericCast(temp_writer.GetPosition())); - compressed_buf = unique_ptr(new data_t[compressed_size]); - compressed_size = duckdb_lz4::LZ4_compress_default( - const_char_ptr_cast(temp_writer.GetData()), char_ptr_cast(compressed_buf.get()), - UnsafeNumericCast(temp_writer.GetPosition()), UnsafeNumericCast(compressed_size)); - compressed_data = compressed_buf.get(); - break; - } - case CompressionCodec::GZIP: { - MiniZStream s; - compressed_size = s.MaxCompressedLength(temp_writer.GetPosition()); - compressed_buf = unique_ptr(new data_t[compressed_size]); - s.Compress(const_char_ptr_cast(temp_writer.GetData()), temp_writer.GetPosition(), - char_ptr_cast(compressed_buf.get()), &compressed_size); - compressed_data = compressed_buf.get(); - break; - } - case CompressionCodec::ZSTD: { - compressed_size = duckdb_zstd::ZSTD_compressBound(temp_writer.GetPosition()); - compressed_buf = unique_ptr(new data_t[compressed_size]); - compressed_size = duckdb_zstd::ZSTD_compress((void *)compressed_buf.get(), compressed_size, - (const void *)temp_writer.GetData(), temp_writer.GetPosition(), - UnsafeNumericCast(writer.CompressionLevel())); - compressed_data = compressed_buf.get(); - break; - } - case CompressionCodec::BROTLI: { - - compressed_size = duckdb_brotli::BrotliEncoderMaxCompressedSize(temp_writer.GetPosition()); - compressed_buf = unique_ptr(new data_t[compressed_size]); - - duckdb_brotli::BrotliEncoderCompress(BROTLI_DEFAULT_QUALITY, BROTLI_DEFAULT_WINDOW, BROTLI_DEFAULT_MODE, - temp_writer.GetPosition(), temp_writer.GetData(), &compressed_size, - compressed_buf.get()); - compressed_data = compressed_buf.get(); - - break; - } - default: - throw InternalException("Unsupported codec for Parquet Writer"); - } - - if (compressed_size > idx_t(NumericLimits::Maximum())) { - throw InternalException("Parquet writer: %d compressed page size out of range for type integer", - temp_writer.GetPosition()); - } -} - -void ColumnWriter::HandleRepeatLevels(ColumnWriterState &state, ColumnWriterState *parent, idx_t count, - idx_t max_repeat) const { - if (!parent) { - // no repeat levels without a parent node - return; - } - while (state.repetition_levels.size() < parent->repetition_levels.size()) { - state.repetition_levels.push_back(parent->repetition_levels[state.repetition_levels.size()]); - } -} - -void ColumnWriter::HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, - const idx_t count, const uint16_t define_value, const uint16_t null_value) const { - if (parent) { - // parent node: inherit definition level from the parent - idx_t vector_index = 0; - while (state.definition_levels.size() < parent->definition_levels.size()) { - idx_t current_index = state.definition_levels.size(); - if (parent->definition_levels[current_index] != PARQUET_DEFINE_VALID) { - state.definition_levels.push_back(parent->definition_levels[current_index]); - } else if (validity.RowIsValid(vector_index)) { - state.definition_levels.push_back(define_value); - } else { - if (!can_have_nulls) { - throw IOException("Parquet writer: map key column is not allowed to contain NULL values"); - } - 
state.null_count++; - state.definition_levels.push_back(null_value); - } - if (parent->is_empty.empty() || !parent->is_empty[current_index]) { - vector_index++; - } - } - } else { - // no parent: set definition levels only from this validity mask - for (idx_t i = 0; i < count; i++) { - const auto is_null = !validity.RowIsValid(i); - state.definition_levels.emplace_back(is_null ? null_value : define_value); - state.null_count += is_null; - } - if (!can_have_nulls && state.null_count != 0) { - throw IOException("Parquet writer: map key column is not allowed to contain NULL values"); - } - } -} - -class ColumnWriterPageState { -public: - virtual ~ColumnWriterPageState() { - } - -public: - template - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -struct PageInformation { - idx_t offset = 0; - idx_t row_count = 0; - idx_t empty_count = 0; - idx_t estimated_page_size = 0; -}; - -struct PageWriteInformation { - PageHeader page_header; - unique_ptr temp_writer; - unique_ptr page_state; - idx_t write_page_idx = 0; - idx_t write_count = 0; - idx_t max_write_count = 0; - size_t compressed_size; - data_ptr_t compressed_data; - unique_ptr compressed_buf; -}; - -class BasicColumnWriterState : public ColumnWriterState { -public: - BasicColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) - : row_group(row_group), col_idx(col_idx) { - page_info.emplace_back(); - } - ~BasicColumnWriterState() override = default; - - duckdb_parquet::RowGroup &row_group; - idx_t col_idx; - vector page_info; - vector write_info; - unique_ptr stats_state; - idx_t current_page = 0; - - unique_ptr bloom_filter; -}; - -//===--------------------------------------------------------------------===// -// BasicColumnWriter -// A base class for writing all non-compound types (ex. numerics, strings) -//===--------------------------------------------------------------------===// -class BasicColumnWriter : public ColumnWriter { -public: - BasicColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) - : ColumnWriter(writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls) { - } - - ~BasicColumnWriter() override = default; - - //! We limit the uncompressed page size to 100MB - //! The max size in Parquet is 2GB, but we choose a more conservative limit - static constexpr const idx_t MAX_UNCOMPRESSED_PAGE_SIZE = 100000000; - //! Dictionary pages must be below 2GB. Unlike data pages, there's only one dictionary page. - //! For this reason we go with a much higher, but still a conservative upper bound of 1GB; - static constexpr const idx_t MAX_UNCOMPRESSED_DICT_PAGE_SIZE = 1e9; - //! If the dictionary has this many entries, we stop creating the dictionary - static constexpr const idx_t DICTIONARY_ANALYZE_THRESHOLD = 1e4; - //! The maximum size a key entry in an RLE page takes - static constexpr const idx_t MAX_DICTIONARY_KEY_SIZE = sizeof(uint32_t); - //! 
The size of encoding the string length - static constexpr const idx_t STRING_LENGTH_SIZE = sizeof(uint32_t); - -public: - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; - void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void BeginWrite(ColumnWriterState &state) override; - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; - void FinalizeWrite(ColumnWriterState &state) override; - -protected: - static void WriteLevels(WriteStream &temp_writer, const unsafe_vector &levels, idx_t max_value, - idx_t start_offset, idx_t count); - - virtual duckdb_parquet::Encoding::type GetEncoding(BasicColumnWriterState &state); - - void NextPage(BasicColumnWriterState &state); - void FlushPage(BasicColumnWriterState &state); - - //! Initializes the state used to track statistics during writing. Only used for scalar types. - virtual unique_ptr InitializeStatsState(); - - //! Initialize the writer for a specific page. Only used for scalar types. - virtual unique_ptr InitializePageState(BasicColumnWriterState &state); - - //! Flushes the writer for a specific page. Only used for scalar types. - virtual void FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state); - - //! Retrieves the row size of a vector at the specified location. Only used for scalar types. - virtual idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const; - //! Writes a (subset of a) vector to the specified serializer. Only used for scalar types. - virtual void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats, ColumnWriterPageState *page_state, - Vector &vector, idx_t chunk_start, idx_t chunk_end) = 0; - - virtual bool HasDictionary(BasicColumnWriterState &state_p) { - return false; - } - //! The number of elements in the dictionary - virtual idx_t DictionarySize(BasicColumnWriterState &state_p); - void WriteDictionary(BasicColumnWriterState &state, unique_ptr temp_writer, idx_t row_count); - virtual void FlushDictionary(BasicColumnWriterState &state, ColumnWriterStatistics *stats); - - void SetParquetStatistics(BasicColumnWriterState &state, duckdb_parquet::ColumnChunk &column); - void RegisterToRowGroup(duckdb_parquet::RowGroup &row_group); -}; - -unique_ptr BasicColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { - auto result = make_uniq(row_group, row_group.columns.size()); - RegisterToRowGroup(row_group); - return std::move(result); -} - -void BasicColumnWriter::RegisterToRowGroup(duckdb_parquet::RowGroup &row_group) { - duckdb_parquet::ColumnChunk column_chunk; - column_chunk.__isset.meta_data = true; - column_chunk.meta_data.codec = writer.GetCodec(); - column_chunk.meta_data.path_in_schema = schema_path; - column_chunk.meta_data.num_values = 0; - column_chunk.meta_data.type = writer.GetType(schema_idx); - row_group.columns.push_back(std::move(column_chunk)); -} - -unique_ptr BasicColumnWriter::InitializePageState(BasicColumnWriterState &state) { - return nullptr; -} - -void BasicColumnWriter::FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state) { -} - -void BasicColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &col_chunk = state.row_group.columns[state.col_idx]; - - idx_t start = 0; - idx_t vcount = parent ? 
parent->definition_levels.size() - state.definition_levels.size() : count; - idx_t parent_index = state.definition_levels.size(); - auto &validity = FlatVector::Validity(vector); - HandleRepeatLevels(state, parent, count, max_repeat); - HandleDefineLevels(state, parent, validity, count, max_define, max_define - 1); - - idx_t vector_index = 0; - reference page_info_ref = state.page_info.back(); - for (idx_t i = start; i < vcount; i++) { - auto &page_info = page_info_ref.get(); - page_info.row_count++; - col_chunk.meta_data.num_values++; - if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index + i]) { - page_info.empty_count++; - continue; - } - if (validity.RowIsValid(vector_index)) { - page_info.estimated_page_size += GetRowSize(vector, vector_index, state); - if (page_info.estimated_page_size >= MAX_UNCOMPRESSED_PAGE_SIZE) { - PageInformation new_info; - new_info.offset = page_info.offset + page_info.row_count; - state.page_info.push_back(new_info); - page_info_ref = state.page_info.back(); - } - } - vector_index++; - } -} - -duckdb_parquet::Encoding::type BasicColumnWriter::GetEncoding(BasicColumnWriterState &state) { - return Encoding::PLAIN; -} - -void BasicColumnWriter::BeginWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - - // set up the page write info - state.stats_state = InitializeStatsState(); - for (idx_t page_idx = 0; page_idx < state.page_info.size(); page_idx++) { - auto &page_info = state.page_info[page_idx]; - if (page_info.row_count == 0) { - D_ASSERT(page_idx + 1 == state.page_info.size()); - state.page_info.erase_at(page_idx); - break; - } - PageWriteInformation write_info; - // set up the header - auto &hdr = write_info.page_header; - hdr.compressed_page_size = 0; - hdr.uncompressed_page_size = 0; - hdr.type = PageType::DATA_PAGE; - hdr.__isset.data_page_header = true; - - hdr.data_page_header.num_values = UnsafeNumericCast(page_info.row_count); - hdr.data_page_header.encoding = GetEncoding(state); - hdr.data_page_header.definition_level_encoding = Encoding::RLE; - hdr.data_page_header.repetition_level_encoding = Encoding::RLE; - - write_info.temp_writer = make_uniq( - MaxValue(NextPowerOfTwo(page_info.estimated_page_size), MemoryStream::DEFAULT_INITIAL_CAPACITY)); - write_info.write_count = page_info.empty_count; - write_info.max_write_count = page_info.row_count; - write_info.page_state = InitializePageState(state); - - write_info.compressed_size = 0; - write_info.compressed_data = nullptr; - - state.write_info.push_back(std::move(write_info)); - } - - // start writing the first page - NextPage(state); -} - -void BasicColumnWriter::WriteLevels(WriteStream &temp_writer, const unsafe_vector &levels, idx_t max_value, - idx_t offset, idx_t count) { - if (levels.empty() || count == 0) { - return; - } - - // write the levels using the RLE-BP encoding - auto bit_width = RleBpDecoder::ComputeBitWidth((max_value)); - RleBpEncoder rle_encoder(bit_width); - - rle_encoder.BeginPrepare(levels[offset]); - for (idx_t i = offset + 1; i < offset + count; i++) { - rle_encoder.PrepareValue(levels[i]); - } - rle_encoder.FinishPrepare(); - - // start off by writing the byte count as a uint32_t - temp_writer.Write(rle_encoder.GetByteCount()); - rle_encoder.BeginWrite(temp_writer, levels[offset]); - for (idx_t i = offset + 1; i < offset + count; i++) { - rle_encoder.WriteValue(temp_writer, levels[i]); - } - rle_encoder.FinishWrite(temp_writer); -} - -void BasicColumnWriter::NextPage(BasicColumnWriterState &state) { - if (state.current_page > 0) { - // 
need to flush the current page - FlushPage(state); - } - if (state.current_page >= state.write_info.size()) { - state.current_page = state.write_info.size() + 1; - return; - } - auto &page_info = state.page_info[state.current_page]; - auto &write_info = state.write_info[state.current_page]; - state.current_page++; - - auto &temp_writer = *write_info.temp_writer; - - // write the repetition levels - WriteLevels(temp_writer, state.repetition_levels, max_repeat, page_info.offset, page_info.row_count); - - // write the definition levels - WriteLevels(temp_writer, state.definition_levels, max_define, page_info.offset, page_info.row_count); -} - -void BasicColumnWriter::FlushPage(BasicColumnWriterState &state) { - D_ASSERT(state.current_page > 0); - if (state.current_page > state.write_info.size()) { - return; - } - - // compress the page info - auto &write_info = state.write_info[state.current_page - 1]; - auto &temp_writer = *write_info.temp_writer; - auto &hdr = write_info.page_header; - - FlushPageState(temp_writer, write_info.page_state.get()); - - // now that we have finished writing the data we know the uncompressed size - if (temp_writer.GetPosition() > idx_t(NumericLimits::Maximum())) { - throw InternalException("Parquet writer: %d uncompressed page size out of range for type integer", - temp_writer.GetPosition()); - } - hdr.uncompressed_page_size = UnsafeNumericCast(temp_writer.GetPosition()); - - // compress the data - CompressPage(temp_writer, write_info.compressed_size, write_info.compressed_data, write_info.compressed_buf); - hdr.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); - D_ASSERT(hdr.uncompressed_page_size > 0); - D_ASSERT(hdr.compressed_page_size > 0); - - if (write_info.compressed_buf) { - // if the data has been compressed, we no longer need the uncompressed data - D_ASSERT(write_info.compressed_buf.get() == write_info.compressed_data); - write_info.temp_writer.reset(); - } -} - -unique_ptr BasicColumnWriter::InitializeStatsState() { - return make_uniq(); -} - -idx_t BasicColumnWriter::GetRowSize(const Vector &vector, const idx_t index, - const BasicColumnWriterState &state) const { - throw InternalException("GetRowSize unsupported for struct/list column writers"); -} - -void BasicColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - - idx_t remaining = count; - idx_t offset = 0; - while (remaining > 0) { - auto &write_info = state.write_info[state.current_page - 1]; - if (!write_info.temp_writer) { - throw InternalException("Writes are not correctly aligned!?"); - } - auto &temp_writer = *write_info.temp_writer; - idx_t write_count = MinValue(remaining, write_info.max_write_count - write_info.write_count); - D_ASSERT(write_count > 0); - - WriteVector(temp_writer, state.stats_state.get(), write_info.page_state.get(), vector, offset, - offset + write_count); - - write_info.write_count += write_count; - if (write_info.write_count == write_info.max_write_count) { - NextPage(state); - } - offset += write_count; - remaining -= write_count; - } -} - -void BasicColumnWriter::SetParquetStatistics(BasicColumnWriterState &state, duckdb_parquet::ColumnChunk &column_chunk) { - if (!state.stats_state) { - return; - } - if (max_repeat == 0) { - column_chunk.meta_data.statistics.null_count = NumericCast(state.null_count); - column_chunk.meta_data.statistics.__isset.null_count = true; - column_chunk.meta_data.__isset.statistics = true; - } - // set min/max/min_value/max_value - // this code is not going to 
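[Editor note] The Write loop above never writes across a page boundary: each chunk is clamped to the remaining row budget of the current page, and NextPage advances exactly when a page fills up. A self-contained sketch of that splitting logic, with a toy Page type standing in for PageWriteInformation (assumes the pages were pre-sized for the total row count, as BeginWrite guarantees):

#include <cstddef>
#include <vector>

struct Page {
	size_t write_count = 0;
	size_t max_write_count = 0; // row budget computed during Prepare
	std::vector<int> data;
};

// distribute `count` values over pages, advancing to the next page whenever
// the current one reaches its precomputed row budget
static void WriteAcrossPages(std::vector<Page> &pages, size_t &current_page,
                             const int *values, size_t count) {
	size_t offset = 0;
	while (count > 0) {
		Page &page = pages[current_page];
		size_t space = page.max_write_count - page.write_count;
		size_t to_write = count < space ? count : space;
		page.data.insert(page.data.end(), values + offset, values + offset + to_write);
		page.write_count += to_write;
		if (page.write_count == page.max_write_count) {
			current_page++; // mirrors NextPage(state)
		}
		offset += to_write;
		count -= to_write;
	}
}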
win any beauty contests, but well - auto min = state.stats_state->GetMin(); - if (!min.empty()) { - column_chunk.meta_data.statistics.min = std::move(min); - column_chunk.meta_data.statistics.__isset.min = true; - column_chunk.meta_data.__isset.statistics = true; - } - auto max = state.stats_state->GetMax(); - if (!max.empty()) { - column_chunk.meta_data.statistics.max = std::move(max); - column_chunk.meta_data.statistics.__isset.max = true; - column_chunk.meta_data.__isset.statistics = true; - } - if (state.stats_state->HasStats()) { - column_chunk.meta_data.statistics.min_value = state.stats_state->GetMinValue(); - column_chunk.meta_data.statistics.__isset.min_value = true; - column_chunk.meta_data.__isset.statistics = true; - - column_chunk.meta_data.statistics.max_value = state.stats_state->GetMaxValue(); - column_chunk.meta_data.statistics.__isset.max_value = true; - column_chunk.meta_data.__isset.statistics = true; - } - if (HasDictionary(state)) { - column_chunk.meta_data.statistics.distinct_count = UnsafeNumericCast(DictionarySize(state)); - column_chunk.meta_data.statistics.__isset.distinct_count = true; - column_chunk.meta_data.__isset.statistics = true; - } - for (const auto &write_info : state.write_info) { - // only care about data page encodings, data_page_header.encoding is meaningless for dict - if (write_info.page_header.type != PageType::DATA_PAGE && - write_info.page_header.type != PageType::DATA_PAGE_V2) { - continue; - } - column_chunk.meta_data.encodings.push_back(write_info.page_header.data_page_header.encoding); - } -} - -void BasicColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - auto &column_chunk = state.row_group.columns[state.col_idx]; - - // flush the last page (if any remains) - FlushPage(state); - - auto &column_writer = writer.GetWriter(); - auto start_offset = column_writer.GetTotalWritten(); - // flush the dictionary - if (HasDictionary(state)) { - column_chunk.meta_data.statistics.distinct_count = UnsafeNumericCast(DictionarySize(state)); - column_chunk.meta_data.statistics.__isset.distinct_count = true; - column_chunk.meta_data.dictionary_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); - column_chunk.meta_data.__isset.dictionary_page_offset = true; - FlushDictionary(state, state.stats_state.get()); - } - - // record the start position of the pages for this column - column_chunk.meta_data.data_page_offset = 0; - SetParquetStatistics(state, column_chunk); - - // write the individual pages to disk - idx_t total_uncompressed_size = 0; - for (auto &write_info : state.write_info) { - // set the data page offset whenever we see the *first* data page - if (column_chunk.meta_data.data_page_offset == 0 && (write_info.page_header.type == PageType::DATA_PAGE || - write_info.page_header.type == PageType::DATA_PAGE_V2)) { - column_chunk.meta_data.data_page_offset = UnsafeNumericCast(column_writer.GetTotalWritten()); - ; - } - D_ASSERT(write_info.page_header.uncompressed_page_size > 0); - auto header_start_offset = column_writer.GetTotalWritten(); - writer.Write(write_info.page_header); - // total uncompressed size in the column chunk includes the header size (!) 
- total_uncompressed_size += column_writer.GetTotalWritten() - header_start_offset; - total_uncompressed_size += write_info.page_header.uncompressed_page_size; - writer.WriteData(write_info.compressed_data, write_info.compressed_size); - } - column_chunk.meta_data.total_compressed_size = - UnsafeNumericCast(column_writer.GetTotalWritten() - start_offset); - column_chunk.meta_data.total_uncompressed_size = UnsafeNumericCast(total_uncompressed_size); - - if (state.bloom_filter) { - writer.BufferBloomFilter(state.col_idx, std::move(state.bloom_filter)); - } - // which row group is this? -} - -void BasicColumnWriter::FlushDictionary(BasicColumnWriterState &state, ColumnWriterStatistics *stats) { - throw InternalException("This page does not have a dictionary"); -} - -idx_t BasicColumnWriter::DictionarySize(BasicColumnWriterState &state) { - throw InternalException("This page does not have a dictionary"); -} - -void BasicColumnWriter::WriteDictionary(BasicColumnWriterState &state, unique_ptr temp_writer, - idx_t row_count) { - D_ASSERT(temp_writer); - D_ASSERT(temp_writer->GetPosition() > 0); - - // write the dictionary page header - PageWriteInformation write_info; - // set up the header - auto &hdr = write_info.page_header; - hdr.uncompressed_page_size = UnsafeNumericCast(temp_writer->GetPosition()); - hdr.type = PageType::DICTIONARY_PAGE; - hdr.__isset.dictionary_page_header = true; - - hdr.dictionary_page_header.encoding = Encoding::PLAIN; - hdr.dictionary_page_header.is_sorted = false; - hdr.dictionary_page_header.num_values = UnsafeNumericCast(row_count); - - write_info.temp_writer = std::move(temp_writer); - write_info.write_count = 0; - write_info.max_write_count = 0; - - // compress the contents of the dictionary page - CompressPage(*write_info.temp_writer, write_info.compressed_size, write_info.compressed_data, - write_info.compressed_buf); - hdr.compressed_page_size = UnsafeNumericCast(write_info.compressed_size); - - // insert the dictionary page as the first page to write for this column - state.write_info.insert(state.write_info.begin(), std::move(write_info)); -} - -//===--------------------------------------------------------------------===// -// Standard Column Writer -//===--------------------------------------------------------------------===// -template -class NumericStatisticsState : public ColumnWriterStatistics { -public: - NumericStatisticsState() : min(NumericLimits::Maximum()), max(NumericLimits::Minimum()) { - } - - T min; - T max; - -public: - bool HasStats() override { - return min <= max; - } - - string GetMin() override { - return NumericLimits::IsSigned() ? GetMinValue() : string(); - } - string GetMax() override { - return NumericLimits::IsSigned() ? GetMaxValue() : string(); - } - string GetMinValue() override { - return HasStats() ? string(char_ptr_cast(&min), sizeof(T)) : string(); - } - string GetMaxValue() override { - return HasStats() ? 
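[Editor note] NumericStatisticsState above seeds min with the type's maximum and max with its minimum, so `min <= max` doubles as a "have we seen any value" check, and stats are serialized as the raw bytes of the physical value. A stripped-down version of the same idea (MinMaxStats is a hypothetical name; assumes a little-endian host, as the writer does):

#include <limits>
#include <string>

template <class T>
struct MinMaxStats {
	T min = std::numeric_limits<T>::max();
	T max = std::numeric_limits<T>::lowest();

	void Update(T value) {
		if (value < min) {
			min = value;
		}
		if (value > max) {
			max = value;
		}
	}
	// false until Update() has seen at least one value
	bool HasStats() const {
		return min <= max;
	}
	// Parquet statistics carry the raw bytes of the physical value
	std::string MinValue() const {
		return HasStats() ? std::string(reinterpret_cast<const char *>(&min), sizeof(T)) : std::string();
	}
};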
string(char_ptr_cast(&max), sizeof(T)) : string(); - } -}; - -struct BaseParquetOperator { - - template - static void WriteToStream(const TGT &input, WriteStream &ser) { - ser.WriteData(const_data_ptr_cast(&input), sizeof(TGT)); - } - - template - static uint64_t XXHash64(const TGT &target_value) { - return duckdb_zstd::XXH64(&target_value, sizeof(target_value), 0); - } - - template - static unique_ptr InitializeStats() { - return nullptr; - } - - template - static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { - } -}; - -struct ParquetCastOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - return TGT(input); - } - template - static unique_ptr InitializeStats() { - return make_uniq>(); - } - - template - static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { - auto &numeric_stats = (NumericStatisticsState &)*stats; - if (LessThan::Operation(target_value, numeric_stats.min)) { - numeric_stats.min = target_value; - } - if (GreaterThan::Operation(target_value, numeric_stats.max)) { - numeric_stats.max = target_value; - } - } -}; - -struct ParquetTimestampNSOperator : public ParquetCastOperator { - template - static TGT Operation(SRC input) { - return TGT(input); - } -}; - -struct ParquetTimestampSOperator : public ParquetCastOperator { - template - static TGT Operation(SRC input) { - return Timestamp::FromEpochSecondsPossiblyInfinite(input).value; - } -}; - -class StringStatisticsState : public ColumnWriterStatistics { - static constexpr const idx_t MAX_STRING_STATISTICS_SIZE = 10000; - -public: - StringStatisticsState() : has_stats(false), values_too_big(false), min(), max() { - } - - bool has_stats; - bool values_too_big; - string min; - string max; - -public: - bool HasStats() override { - return has_stats; - } - - void Update(const string_t &val) { - if (values_too_big) { - return; - } - auto str_len = val.GetSize(); - if (str_len > MAX_STRING_STATISTICS_SIZE) { - // we avoid gathering stats when individual string values are too large - // this is because the statistics are copied into the Parquet file meta data in uncompressed format - // ideally we avoid placing several mega or giga-byte long strings there - // we put a threshold of 10KB, if we see strings that exceed this threshold we avoid gathering stats - values_too_big = true; - has_stats = false; - min = string(); - max = string(); - return; - } - if (!has_stats || LessThan::Operation(val, string_t(min))) { - min = val.GetString(); - } - if (!has_stats || GreaterThan::Operation(val, string_t(max))) { - max = val.GetString(); - } - has_stats = true; - } - - string GetMin() override { - return GetMinValue(); - } - string GetMax() override { - return GetMaxValue(); - } - string GetMinValue() override { - return HasStats() ? min : string(); - } - string GetMaxValue() override { - return HasStats() ? 
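[Editor note] The string statistics above are deliberately lossy: once a single value crosses the size threshold, stats collection is switched off for the whole column chunk, so oversized strings never get copied into the file footer. A compact restatement of that rule (std::string in place of string_t; the threshold mirrors MAX_STRING_STATISTICS_SIZE):

#include <cstddef>
#include <string>

struct CappedStringStats {
	static constexpr size_t MAX_SIZE = 10000;
	bool has_stats = false;
	bool values_too_big = false;
	std::string min, max;

	void Update(const std::string &val) {
		if (values_too_big) {
			return; // permanently disabled for this column chunk
		}
		if (val.size() > MAX_SIZE) {
			// drop everything gathered so far and stop gathering
			values_too_big = true;
			has_stats = false;
			min.clear();
			max.clear();
			return;
		}
		if (!has_stats || val < min) {
			min = val;
		}
		if (!has_stats || val > max) {
			max = val;
		}
		has_stats = true;
	}
};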
max : string(); - } -}; - -struct ParquetStringOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - return input; - } - - template - static unique_ptr InitializeStats() { - return make_uniq(); - } - - template - static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { - auto &string_stats = stats->Cast(); - string_stats.Update(target_value); - } - - template - static void WriteToStream(const TGT &target_value, WriteStream &ser) { - ser.Write(target_value.GetSize()); - ser.WriteData(const_data_ptr_cast(target_value.GetData()), target_value.GetSize()); - } - - template - static uint64_t XXHash64(const TGT &target_value) { - return duckdb_zstd::XXH64(target_value.GetData(), target_value.GetSize(), 0); - } -}; - -struct ParquetIntervalTargetType { - static constexpr const idx_t PARQUET_INTERVAL_SIZE = 12; - data_t bytes[PARQUET_INTERVAL_SIZE]; -}; - -struct ParquetIntervalOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - - if (input.days < 0 || input.months < 0 || input.micros < 0) { - throw IOException("Parquet files do not support negative intervals"); - } - TGT result; - Store(input.months, result.bytes); - Store(input.days, result.bytes + sizeof(uint32_t)); - Store(input.micros / 1000, result.bytes + sizeof(uint32_t) * 2); - return result; - } - - template - static void WriteToStream(const TGT &target_value, WriteStream &ser) { - ser.WriteData(target_value.bytes, ParquetIntervalTargetType::PARQUET_INTERVAL_SIZE); - } - - template - static uint64_t XXHash64(const TGT &target_value) { - return duckdb_zstd::XXH64(target_value.bytes, ParquetIntervalTargetType::PARQUET_INTERVAL_SIZE, 0); - } -}; - -struct ParquetUUIDTargetType { - static constexpr const idx_t PARQUET_UUID_SIZE = 16; - data_t bytes[PARQUET_UUID_SIZE]; -}; - -struct ParquetUUIDOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - TGT result; - uint64_t high_bytes = input.upper ^ (int64_t(1) << 63); - uint64_t low_bytes = input.lower; - for (idx_t i = 0; i < sizeof(uint64_t); i++) { - auto shift_count = (sizeof(uint64_t) - i - 1) * 8; - result.bytes[i] = (high_bytes >> shift_count) & 0xFF; - } - for (idx_t i = 0; i < sizeof(uint64_t); i++) { - auto shift_count = (sizeof(uint64_t) - i - 1) * 8; - result.bytes[sizeof(uint64_t) + i] = (low_bytes >> shift_count) & 0xFF; - } - return result; - } - - template - static void WriteToStream(const TGT &target_value, WriteStream &ser) { - ser.WriteData(target_value.bytes, ParquetUUIDTargetType::PARQUET_UUID_SIZE); - } - - template - static uint64_t XXHash64(const TGT &target_value) { - return duckdb_zstd::XXH64(target_value.bytes, ParquetUUIDTargetType::PARQUET_UUID_SIZE, 0); - } -}; - -struct ParquetTimeTZOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - return input.time().micros; - } -}; - -struct ParquetHugeintOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - return Hugeint::Cast(input); - } - - template - static unique_ptr InitializeStats() { - return make_uniq(); - } - - template - static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { - } -}; - -struct ParquetUhugeintOperator : public BaseParquetOperator { - template - static TGT Operation(SRC input) { - return Uhugeint::Cast(input); - } - - template - static unique_ptr InitializeStats() { - return make_uniq(); - } - - template - static void HandleStats(ColumnWriterStatistics *stats, TGT target_value) { - } -}; - 
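[Editor note] ParquetIntervalOperator above packs an interval into Parquet's 12-byte INTERVAL layout: three little-endian uint32 fields for months, days and milliseconds, with microseconds truncated to millis and negative components rejected. A standalone equivalent (the Interval struct is a stand-in for duckdb's interval_t; like the Store<> calls above, this assumes a little-endian host):

#include <cstdint>
#include <cstring>
#include <stdexcept>

struct Interval {
	int32_t months;
	int32_t days;
	int64_t micros;
};

static void ToParquetInterval(const Interval &in, uint8_t out[12]) {
	if (in.months < 0 || in.days < 0 || in.micros < 0) {
		throw std::runtime_error("Parquet files do not support negative intervals");
	}
	uint32_t months = static_cast<uint32_t>(in.months);
	uint32_t days = static_cast<uint32_t>(in.days);
	uint32_t millis = static_cast<uint32_t>(in.micros / 1000); // micros -> millis
	std::memcpy(out, &months, 4);     // bytes 0..3:  months
	std::memcpy(out + 4, &days, 4);   // bytes 4..7:  days
	std::memcpy(out + 8, &millis, 4); // bytes 8..11: milliseconds
}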
-template -static void TemplatedWritePlain(Vector &col, ColumnWriterStatistics *stats, const idx_t chunk_start, - const idx_t chunk_end, const ValidityMask &mask, WriteStream &ser) { - - const auto *ptr = FlatVector::GetData(col); - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - TGT target_value = OP::template Operation(ptr[r]); - OP::template HandleStats(stats, target_value); - OP::template WriteToStream(target_value, ser); - } -} - -template -class StandardColumnWriterState : public BasicColumnWriterState { -public: - StandardColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) - : BasicColumnWriterState(row_group, col_idx) { - } - ~StandardColumnWriterState() override = default; - - // analysis state for integer values for DELTA_BINARY_PACKED/DELTA_LENGTH_BYTE_ARRAY - idx_t total_value_count = 0; - idx_t total_string_size = 0; - - unordered_map dictionary; - duckdb_parquet::Encoding::type encoding; -}; - -template -class StandardWriterPageState : public ColumnWriterPageState { -public: - explicit StandardWriterPageState(const idx_t total_value_count, const idx_t total_string_size, - Encoding::type encoding_p, const unordered_map &dictionary_p) - : encoding(encoding_p), dbp_initialized(false), dbp_encoder(total_value_count), dlba_initialized(false), - dlba_encoder(total_value_count, total_string_size), bss_encoder(total_value_count, sizeof(TGT)), - dictionary(dictionary_p), dict_written_value(false), - dict_bit_width(RleBpDecoder::ComputeBitWidth(dictionary.size())), dict_encoder(dict_bit_width) { - } - duckdb_parquet::Encoding::type encoding; - - bool dbp_initialized; - DbpEncoder dbp_encoder; - - bool dlba_initialized; - DlbaEncoder dlba_encoder; - - BssEncoder bss_encoder; - - const unordered_map &dictionary; - bool dict_written_value; - uint32_t dict_bit_width; - RleBpEncoder dict_encoder; -}; - -namespace dbp_encoder { - -template -void BeginWrite(DbpEncoder &encoder, WriteStream &writer, const T &first_value) { - throw InternalException("Can't write type to DELTA_BINARY_PACKED column"); -} - -template <> -void BeginWrite(DbpEncoder &encoder, WriteStream &writer, const int64_t &first_value) { - encoder.BeginWrite(writer, first_value); -} - -template <> -void BeginWrite(DbpEncoder &encoder, WriteStream &writer, const int32_t &first_value) { - BeginWrite(encoder, writer, UnsafeNumericCast(first_value)); -} - -template <> -void BeginWrite(DbpEncoder &encoder, WriteStream &writer, const uint64_t &first_value) { - encoder.BeginWrite(writer, UnsafeNumericCast(first_value)); -} - -template <> -void BeginWrite(DbpEncoder &encoder, WriteStream &writer, const uint32_t &first_value) { - BeginWrite(encoder, writer, UnsafeNumericCast(first_value)); -} - -template -void WriteValue(DbpEncoder &encoder, WriteStream &writer, const T &value) { - throw InternalException("Can't write type to DELTA_BINARY_PACKED column"); -} - -template <> -void WriteValue(DbpEncoder &encoder, WriteStream &writer, const int64_t &value) { - encoder.WriteValue(writer, value); -} - -template <> -void WriteValue(DbpEncoder &encoder, WriteStream &writer, const int32_t &value) { - WriteValue(encoder, writer, UnsafeNumericCast(value)); -} - -template <> -void WriteValue(DbpEncoder &encoder, WriteStream &writer, const uint64_t &value) { - encoder.WriteValue(writer, UnsafeNumericCast(value)); -} - -template <> -void WriteValue(DbpEncoder &encoder, WriteStream &writer, const uint32_t &value) { - WriteValue(encoder, writer, UnsafeNumericCast(value)); -} - -} // 
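[Editor note] The dbp_encoder helpers above funnel every integer width onto the int64 DbpEncoder. Conceptually, DELTA_BINARY_PACKED stores the first value plus the deltas between consecutive values; the real encoder additionally zig-zags, blocks and bit-packs the deltas, but the delta step alone shows why sorted or slowly-changing columns shrink so much:

#include <cstdint>
#include <vector>

// store the first value as-is, then only the differences
static std::vector<int64_t> DeltaEncode(const std::vector<int64_t> &values) {
	std::vector<int64_t> deltas;
	if (values.empty()) {
		return deltas;
	}
	deltas.push_back(values[0]);
	for (size_t i = 1; i < values.size(); i++) {
		deltas.push_back(values[i] - values[i - 1]);
	}
	return deltas;
}
// e.g. {1000, 1001, 1003, 1004} -> {1000, 1, 2, 1}: small deltas bit-pack
// into a few bits each, regardless of the magnitude of the raw values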
namespace dbp_encoder - -namespace dlba_encoder { - -template -void BeginWrite(DlbaEncoder &encoder, WriteStream &writer, const T &first_value) { - throw InternalException("Can't write type to DELTA_LENGTH_BYTE_ARRAY column"); -} - -template <> -void BeginWrite(DlbaEncoder &encoder, WriteStream &writer, const string_t &first_value) { - encoder.BeginWrite(writer, first_value); -} - -template -void WriteValue(DlbaEncoder &encoder, WriteStream &writer, const T &value) { - throw InternalException("Can't write type to DELTA_LENGTH_BYTE_ARRAY column"); -} - -template <> -void WriteValue(DlbaEncoder &encoder, WriteStream &writer, const string_t &value) { - encoder.WriteValue(writer, value); -} - -// helpers to get size from strings -template -static constexpr idx_t GetDlbaStringSize(const SRC &src_value) { - return 0; -} - -template <> -idx_t GetDlbaStringSize(const string_t &src_value) { - return src_value.GetSize(); -} - -} // namespace dlba_encoder - -namespace bss_encoder { - -template -void WriteValue(BssEncoder &encoder, const T &value) { - throw InternalException("Can't write type to BYTE_STREAM_SPLIT column"); -} - -template <> -void WriteValue(BssEncoder &encoder, const float &value) { - encoder.WriteValue(value); -} - -template <> -void WriteValue(BssEncoder &encoder, const double &value) { - encoder.WriteValue(value); -} - -} // namespace bss_encoder - -template -class StandardColumnWriter : public BasicColumnWriter { -public: - StandardColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, // NOLINT - idx_t max_repeat, idx_t max_define, bool can_have_nulls) - : BasicColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls) { - } - ~StandardColumnWriter() override = default; - -public: - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override { - auto result = make_uniq>(row_group, row_group.columns.size()); - result->encoding = Encoding::RLE_DICTIONARY; - RegisterToRowGroup(row_group); - return std::move(result); - } - - unique_ptr InitializePageState(BasicColumnWriterState &state_p) override { - auto &state = state_p.Cast>(); - - auto result = make_uniq>(state.total_value_count, state.total_string_size, - state.encoding, state.dictionary); - return std::move(result); - } - - void FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state_p) override { - auto &page_state = state_p->Cast>(); - switch (page_state.encoding) { - case Encoding::DELTA_BINARY_PACKED: - if (!page_state.dbp_initialized) { - dbp_encoder::BeginWrite(page_state.dbp_encoder, temp_writer, 0); - } - page_state.dbp_encoder.FinishWrite(temp_writer); - break; - case Encoding::RLE_DICTIONARY: - D_ASSERT(page_state.dict_bit_width != 0); - if (!page_state.dict_written_value) { - // all values are null - // just write the bit width - temp_writer.Write(page_state.dict_bit_width); - return; - } - page_state.dict_encoder.FinishWrite(temp_writer); - break; - case Encoding::DELTA_LENGTH_BYTE_ARRAY: - if (!page_state.dlba_initialized) { - dlba_encoder::BeginWrite(page_state.dlba_encoder, temp_writer, string_t("")); - } - page_state.dlba_encoder.FinishWrite(temp_writer); - break; - case Encoding::BYTE_STREAM_SPLIT: - page_state.bss_encoder.FinishWrite(temp_writer); - break; - case Encoding::PLAIN: - break; - default: - throw InternalException("Unknown encoding"); - } - } - - Encoding::type GetEncoding(BasicColumnWriterState &state_p) override { - auto &state = state_p.Cast>(); - return state.encoding; - } - - bool HasAnalyze() override 
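[Editor note] BYTE_STREAM_SPLIT, handled by bss_encoder above, does no compression itself: it transposes the bytes of the values so that byte i of every float lands in stream i, which turns similar floats into long runs of near-identical bytes for the page codec to compress. A one-shot sketch of the transposition (the real BssEncoder streams values one at a time):

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

static std::vector<uint8_t> ByteStreamSplit(const std::vector<float> &values) {
	std::vector<uint8_t> out(values.size() * sizeof(float));
	for (size_t i = 0; i < values.size(); i++) {
		uint8_t bytes[sizeof(float)];
		std::memcpy(bytes, &values[i], sizeof(float));
		for (size_t b = 0; b < sizeof(float); b++) {
			// stream b occupies one contiguous region of the output
			out[b * values.size() + i] = bytes[b];
		}
	}
	return out;
}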
{ - return true; - } - - void Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) override { - auto &state = state_p.Cast>(); - - auto data_ptr = FlatVector::GetData(vector); - idx_t vector_index = 0; - uint32_t new_value_index = state.dictionary.size(); - - const bool check_parent_empty = parent && !parent->is_empty.empty(); - const idx_t parent_index = state.definition_levels.size(); - - const idx_t vcount = - check_parent_empty ? parent->definition_levels.size() - state.definition_levels.size() : count; - - const auto &validity = FlatVector::Validity(vector); - - for (idx_t i = 0; i < vcount; i++) { - if (check_parent_empty && parent->is_empty[parent_index + i]) { - continue; - } - if (validity.RowIsValid(vector_index)) { - const auto &src_value = data_ptr[vector_index]; - if (state.dictionary.size() <= writer.DictionarySizeLimit()) { - if (state.dictionary.find(src_value) == state.dictionary.end()) { - state.dictionary[src_value] = new_value_index; - new_value_index++; - } - } - state.total_value_count++; - state.total_string_size += dlba_encoder::GetDlbaStringSize(src_value); - } - vector_index++; - } - } - - void FinalizeAnalyze(ColumnWriterState &state_p) override { - const auto type = writer.GetType(schema_idx); - - auto &state = state_p.Cast>(); - if (state.dictionary.size() == 0 || state.dictionary.size() > writer.DictionarySizeLimit()) { - // If we aren't doing dictionary encoding, the following encodings are virtually always better than PLAIN - switch (type) { - case Type::type::INT32: - case Type::type::INT64: - state.encoding = Encoding::DELTA_BINARY_PACKED; - break; - case Type::type::BYTE_ARRAY: - state.encoding = Encoding::DELTA_LENGTH_BYTE_ARRAY; - break; - case Type::type::FLOAT: - case Type::type::DOUBLE: - state.encoding = Encoding::BYTE_STREAM_SPLIT; - break; - default: - state.encoding = Encoding::PLAIN; - } - state.dictionary.clear(); - } - } - - unique_ptr InitializeStatsState() override { - return OP::template InitializeStats(); - } - - bool HasDictionary(BasicColumnWriterState &state_p) override { - auto &state = state_p.Cast>(); - return state.encoding == Encoding::RLE_DICTIONARY; - } - - idx_t DictionarySize(BasicColumnWriterState &state_p) override { - auto &state = state_p.Cast>(); - return state.dictionary.size(); - } - - void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats, ColumnWriterPageState *page_state_p, - Vector &input_column, idx_t chunk_start, idx_t chunk_end) override { - auto &page_state = page_state_p->Cast>(); - - const auto &mask = FlatVector::Validity(input_column); - const auto *data_ptr = FlatVector::GetData(input_column); - - switch (page_state.encoding) { - case Encoding::RLE_DICTIONARY: { - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - auto &src_val = data_ptr[r]; - auto value_index = page_state.dictionary.at(src_val); - if (!page_state.dict_written_value) { - // first value - // write the bit-width as a one-byte entry - temp_writer.Write(page_state.dict_bit_width); - // now begin writing the actual value - page_state.dict_encoder.BeginWrite(temp_writer, value_index); - page_state.dict_written_value = true; - } else { - page_state.dict_encoder.WriteValue(temp_writer, value_index); - } - } - break; - } - case Encoding::DELTA_BINARY_PACKED: { - idx_t r = chunk_start; - if (!page_state.dbp_initialized) { - // find first non-null value - for (; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - const TGT 
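[Editor note] The Analyze pass above assigns dense indices to distinct values while the dictionary stays within the writer's size limit; FinalizeAnalyze then either keeps RLE_DICTIONARY or clears the dictionary and falls back to a delta/byte-split encoding. A condensed sketch of that decision, with std::string keys and an explicit limit parameter in place of writer.DictionarySizeLimit():

#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// returns true if dictionary encoding should be used
static bool BuildDictionary(const std::vector<std::string> &values, size_t limit,
                            std::unordered_map<std::string, uint32_t> &dictionary) {
	for (auto &value : values) {
		// like the code above, the map may grow to limit + 1 entries,
		// which is exactly what trips the fallback test afterwards
		if (dictionary.size() <= limit && dictionary.find(value) == dictionary.end()) {
			dictionary.emplace(value, static_cast<uint32_t>(dictionary.size()));
		}
	}
	if (dictionary.empty() || dictionary.size() > limit) {
		dictionary.clear(); // too many distinct values: fall back
		return false;
	}
	return true;
}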
target_value = OP::template Operation(data_ptr[r]); - OP::template HandleStats(stats, target_value); - dbp_encoder::BeginWrite(page_state.dbp_encoder, temp_writer, target_value); - page_state.dbp_initialized = true; - r++; // skip over - break; - } - } - - for (; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - const TGT target_value = OP::template Operation(data_ptr[r]); - OP::template HandleStats(stats, target_value); - dbp_encoder::WriteValue(page_state.dbp_encoder, temp_writer, target_value); - } - break; - } - case Encoding::DELTA_LENGTH_BYTE_ARRAY: { - idx_t r = chunk_start; - if (!page_state.dlba_initialized) { - // find first non-null value - for (; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - const TGT target_value = OP::template Operation(data_ptr[r]); - OP::template HandleStats(stats, target_value); - dlba_encoder::BeginWrite(page_state.dlba_encoder, temp_writer, target_value); - page_state.dlba_initialized = true; - r++; // skip over - break; - } - } - - for (; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - const TGT target_value = OP::template Operation(data_ptr[r]); - OP::template HandleStats(stats, target_value); - dlba_encoder::WriteValue(page_state.dlba_encoder, temp_writer, target_value); - } - break; - } - case Encoding::BYTE_STREAM_SPLIT: { - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (!mask.RowIsValid(r)) { - continue; - } - const TGT target_value = OP::template Operation(data_ptr[r]); - OP::template HandleStats(stats, target_value); - bss_encoder::WriteValue(page_state.bss_encoder, target_value); - } - break; - } - case Encoding::PLAIN: { - D_ASSERT(page_state.encoding == Encoding::PLAIN); - TemplatedWritePlain(input_column, stats, chunk_start, chunk_end, mask, temp_writer); - break; - } - default: - throw InternalException("Unknown encoding"); - } - } - - void FlushDictionary(BasicColumnWriterState &state_p, ColumnWriterStatistics *stats) override { - auto &state = state_p.Cast>(); - - D_ASSERT(state.encoding == Encoding::RLE_DICTIONARY); - - // first we need to sort the values in index order - auto values = vector(state.dictionary.size()); - for (const auto &entry : state.dictionary) { - values[entry.second] = entry.first; - } - - state.bloom_filter = - make_uniq(state.dictionary.size(), writer.BloomFilterFalsePositiveRatio()); - - // first write the contents of the dictionary page to a temporary buffer - auto temp_writer = make_uniq(MaxValue( - NextPowerOfTwo(state.dictionary.size() * sizeof(TGT)), MemoryStream::DEFAULT_INITIAL_CAPACITY)); - for (idx_t r = 0; r < values.size(); r++) { - const TGT target_value = OP::template Operation(values[r]); - // update the statistics - OP::template HandleStats(stats, target_value); - // update the bloom filter - auto hash = OP::template XXHash64(target_value); - state.bloom_filter->FilterInsert(hash); - // actually write the dictionary value - OP::template WriteToStream(target_value, *temp_writer); - } - // flush the dictionary page and add it to the to-be-written pages - WriteDictionary(state, std::move(temp_writer), values.size()); - // bloom filter will be queued for writing in ParquetWriter::BufferBloomFilter one level up - } - - // TODO this now vastly over-estimates the page size - idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state_p) const override { - return sizeof(TGT); - } -}; - -//===--------------------------------------------------------------------===// -// Boolean Column Writer 
-//===--------------------------------------------------------------------===// -class BooleanStatisticsState : public ColumnWriterStatistics { -public: - BooleanStatisticsState() : min(true), max(false) { - } - - bool min; - bool max; - -public: - bool HasStats() override { - return !(min && !max); - } - - string GetMin() override { - return GetMinValue(); - } - string GetMax() override { - return GetMaxValue(); - } - string GetMinValue() override { - return HasStats() ? string(const_char_ptr_cast(&min), sizeof(bool)) : string(); - } - string GetMaxValue() override { - return HasStats() ? string(const_char_ptr_cast(&max), sizeof(bool)) : string(); - } -}; - -class BooleanWriterPageState : public ColumnWriterPageState { -public: - uint8_t byte = 0; - uint8_t byte_pos = 0; -}; - -class BooleanColumnWriter : public BasicColumnWriter { -public: - BooleanColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) - : BasicColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls) { - } - ~BooleanColumnWriter() override = default; - -public: - unique_ptr InitializeStatsState() override { - return make_uniq(); - } - - void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats_p, ColumnWriterPageState *state_p, - Vector &input_column, idx_t chunk_start, idx_t chunk_end) override { - auto &stats = stats_p->Cast(); - auto &state = state_p->Cast(); - auto &mask = FlatVector::Validity(input_column); - - auto *ptr = FlatVector::GetData(input_column); - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (mask.RowIsValid(r)) { - // only encode if non-null - if (ptr[r]) { - stats.max = true; - state.byte |= 1 << state.byte_pos; - } else { - stats.min = false; - } - state.byte_pos++; - - if (state.byte_pos == 8) { - temp_writer.Write(state.byte); - state.byte = 0; - state.byte_pos = 0; - } - } - } - } - - unique_ptr InitializePageState(BasicColumnWriterState &state) override { - return make_uniq(); - } - - void FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state_p) override { - auto &state = state_p->Cast(); - if (state.byte_pos > 0) { - temp_writer.Write(state.byte); - state.byte = 0; - state.byte_pos = 0; - } - } - - idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { - return sizeof(bool); - } -}; - -//===--------------------------------------------------------------------===// -// Decimal Column Writer -//===--------------------------------------------------------------------===// -static void WriteParquetDecimal(hugeint_t input, data_ptr_t result) { - bool positive = input >= 0; - // numbers are stored as two's complement so some muckery is required - if (!positive) { - input = NumericLimits::Maximum() + input + 1; - } - uint64_t high_bytes = uint64_t(input.upper); - uint64_t low_bytes = input.lower; - - for (idx_t i = 0; i < sizeof(uint64_t); i++) { - auto shift_count = (sizeof(uint64_t) - i - 1) * 8; - result[i] = (high_bytes >> shift_count) & 0xFF; - } - for (idx_t i = 0; i < sizeof(uint64_t); i++) { - auto shift_count = (sizeof(uint64_t) - i - 1) * 8; - result[sizeof(uint64_t) + i] = (low_bytes >> shift_count) & 0xFF; - } - if (!positive) { - result[0] |= 0x80; - } -} - -class FixedDecimalStatistics : public ColumnWriterStatistics { -public: - FixedDecimalStatistics() : min(NumericLimits::Maximum()), max(NumericLimits::Minimum()) { - } - - hugeint_t min; - hugeint_t max; - -public: - 
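[Editor note] The boolean writer above packs valid values LSB-first into a single byte, flushing it every 8 values and once more at page end for any remainder. A standalone version of that page state:

#include <cstdint>
#include <vector>

struct BoolPacker {
	uint8_t byte = 0;
	uint8_t byte_pos = 0;
	std::vector<uint8_t> out;

	void Write(bool value) {
		if (value) {
			byte |= 1 << byte_pos; // set bit byte_pos, LSB-first
		}
		byte_pos++;
		if (byte_pos == 8) {
			Flush();
		}
	}
	// also called once at page end for a partially filled byte
	void Flush() {
		if (byte_pos > 0) {
			out.push_back(byte);
			byte = 0;
			byte_pos = 0;
		}
	}
};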
string GetStats(hugeint_t &input) { - data_t buffer[16]; - WriteParquetDecimal(input, buffer); - return string(const_char_ptr_cast(buffer), 16); - } - - bool HasStats() override { - return min <= max; - } - - void Update(hugeint_t &val) { - if (LessThan::Operation(val, min)) { - min = val; - } - if (GreaterThan::Operation(val, max)) { - max = val; - } - } - - string GetMin() override { - return GetMinValue(); - } - string GetMax() override { - return GetMaxValue(); - } - string GetMinValue() override { - return HasStats() ? GetStats(min) : string(); - } - string GetMaxValue() override { - return HasStats() ? GetStats(max) : string(); - } -}; - -class FixedDecimalColumnWriter : public BasicColumnWriter { -public: - FixedDecimalColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, bool can_have_nulls) - : BasicColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls) { - } - ~FixedDecimalColumnWriter() override = default; - -public: - unique_ptr InitializeStatsState() override { - return make_uniq(); - } - - void WriteVector(WriteStream &temp_writer, ColumnWriterStatistics *stats_p, ColumnWriterPageState *page_state, - Vector &input_column, idx_t chunk_start, idx_t chunk_end) override { - auto &mask = FlatVector::Validity(input_column); - auto *ptr = FlatVector::GetData(input_column); - auto &stats = stats_p->Cast(); - - data_t temp_buffer[16]; - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (mask.RowIsValid(r)) { - stats.Update(ptr[r]); - WriteParquetDecimal(ptr[r], temp_buffer); - temp_writer.WriteData(temp_buffer, 16); - } - } - } - - idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { - return sizeof(hugeint_t); - } -}; - -//===--------------------------------------------------------------------===// -// WKB Column Writer -//===--------------------------------------------------------------------===// -// Used to store the metadata for a WKB-encoded geometry column when writing -// GeoParquet files. 
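[Editor note] WriteParquetDecimal above renders a hugeint as a 16-byte two's-complement big-endian FIXED_LEN_BYTE_ARRAY. A simplified equivalent of the same layout for a plain int64_t source, to keep the sketch self-contained (sign extension replaces the wraparound-plus-sign-bit trick used above):

#include <cstdint>

static void WriteBigEndianDecimal(int64_t input, uint8_t result[16]) {
	// sign-extend to 128 bits: the top 8 bytes are all 0x00 or 0xFF
	uint8_t fill = input < 0 ? 0xFF : 0x00;
	for (int i = 0; i < 8; i++) {
		result[i] = fill;
	}
	uint64_t bits = static_cast<uint64_t>(input); // two's-complement bit pattern
	for (int i = 0; i < 8; i++) {
		result[8 + i] = static_cast<uint8_t>(bits >> ((7 - i) * 8)); // big-endian
	}
}
// e.g. -1 becomes sixteen 0xFF bytes; 1 becomes fifteen 0x00 bytes and 0x01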
-class WKBColumnWriterState final : public StandardColumnWriterState { -public: - WKBColumnWriterState(ClientContext &context, duckdb_parquet::RowGroup &row_group, idx_t col_idx) - : StandardColumnWriterState(row_group, col_idx), geo_data(), geo_data_writer(context) { - } - - GeoParquetColumnMetadata geo_data; - GeoParquetColumnMetadataWriter geo_data_writer; -}; - -class WKBColumnWriter final : public StandardColumnWriter { -public: - WKBColumnWriter(ClientContext &context_p, ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, - idx_t max_repeat, idx_t max_define, bool can_have_nulls, string name) - : StandardColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls), - column_name(std::move(name)), context(context_p) { - - this->writer.GetGeoParquetData().RegisterGeometryColumn(column_name); - } - - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override { - auto result = make_uniq(context, row_group, row_group.columns.size()); - result->encoding = Encoding::RLE_DICTIONARY; - RegisterToRowGroup(row_group); - return std::move(result); - } - - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override { - StandardColumnWriter::Write(state, vector, count); - - auto &geo_state = state.Cast(); - geo_state.geo_data_writer.Update(geo_state.geo_data, vector, count); - } - - void FinalizeWrite(ColumnWriterState &state) override { - StandardColumnWriter::FinalizeWrite(state); - - // Add the geodata object to the writer - const auto &geo_state = state.Cast(); - - // Merge this state's geo column data with the writer's geo column data - writer.GetGeoParquetData().FlushColumnMeta(column_name, geo_state.geo_data); - } - -private: - string column_name; - ClientContext &context; -}; - -//===--------------------------------------------------------------------===// -// Enum Column Writer -//===--------------------------------------------------------------------===// -class EnumWriterPageState : public ColumnWriterPageState { -public: - explicit EnumWriterPageState(uint32_t bit_width) : encoder(bit_width), written_value(false) { - } - - RleBpEncoder encoder; - bool written_value; -}; - -class EnumColumnWriter : public BasicColumnWriter { -public: - EnumColumnWriter(ParquetWriter &writer, LogicalType enum_type_p, idx_t schema_idx, vector schema_path_p, - idx_t max_repeat, idx_t max_define, bool can_have_nulls) - : BasicColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls), - enum_type(std::move(enum_type_p)) { - bit_width = RleBpDecoder::ComputeBitWidth(EnumType::GetSize(enum_type)); - } - ~EnumColumnWriter() override = default; - - LogicalType enum_type; - uint32_t bit_width; - -public: - unique_ptr InitializeStatsState() override { - return make_uniq(); - } - - template - void WriteEnumInternal(WriteStream &temp_writer, Vector &input_column, idx_t chunk_start, idx_t chunk_end, - EnumWriterPageState &page_state) { - auto &mask = FlatVector::Validity(input_column); - auto *ptr = FlatVector::GetData(input_column); - for (idx_t r = chunk_start; r < chunk_end; r++) { - if (mask.RowIsValid(r)) { - if (!page_state.written_value) { - // first value - // write the bit-width as a one-byte entry - temp_writer.Write(bit_width); - // now begin writing the actual value - page_state.encoder.BeginWrite(temp_writer, ptr[r]); - page_state.written_value = true; - } else { - page_state.encoder.WriteValue(temp_writer, ptr[r]); - } - } - } - } - - void WriteVector(WriteStream &temp_writer, 
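[Editor note] The enum writer above dictionary-encodes each row as an index into the enum's value list, so the per-value cost is just enough bits to address the dictionary. A minimal equivalent of the bit-width computation (assumption: this matches the semantics of RleBpDecoder::ComputeBitWidth, i.e. bits needed to represent the given count):

#include <cstdint>

static uint32_t ComputeBitWidth(uint64_t value) {
	uint32_t width = 0;
	while (value > 0) {
		width++;
		value >>= 1;
	}
	return width;
}
// e.g. an enum with 5 values -> 3 bits per index; GetRowSize above rounds
// this up to whole bytes with (bit_width + 7) / 8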
ColumnWriterStatistics *stats_p, ColumnWriterPageState *page_state_p, - Vector &input_column, idx_t chunk_start, idx_t chunk_end) override { - auto &page_state = page_state_p->Cast(); - switch (enum_type.InternalType()) { - case PhysicalType::UINT8: - WriteEnumInternal(temp_writer, input_column, chunk_start, chunk_end, page_state); - break; - case PhysicalType::UINT16: - WriteEnumInternal(temp_writer, input_column, chunk_start, chunk_end, page_state); - break; - case PhysicalType::UINT32: - WriteEnumInternal(temp_writer, input_column, chunk_start, chunk_end, page_state); - break; - default: - throw InternalException("Unsupported internal enum type"); - } - } - - unique_ptr InitializePageState(BasicColumnWriterState &state) override { - return make_uniq(bit_width); - } - - void FlushPageState(WriteStream &temp_writer, ColumnWriterPageState *state_p) override { - auto &page_state = state_p->Cast(); - if (!page_state.written_value) { - // all values are null - // just write the bit width - temp_writer.Write(bit_width); - return; - } - page_state.encoder.FinishWrite(temp_writer); - } - - duckdb_parquet::Encoding::type GetEncoding(BasicColumnWriterState &state) override { - return Encoding::RLE_DICTIONARY; - } - - bool HasDictionary(BasicColumnWriterState &state) override { - return true; - } - - idx_t DictionarySize(BasicColumnWriterState &state_p) override { - return EnumType::GetSize(enum_type); - } - - void FlushDictionary(BasicColumnWriterState &state, ColumnWriterStatistics *stats_p) override { - auto &stats = stats_p->Cast(); - // write the enum values to a dictionary page - auto &enum_values = EnumType::GetValuesInsertOrder(enum_type); - auto enum_count = EnumType::GetSize(enum_type); - auto string_values = FlatVector::GetData(enum_values); - // first write the contents of the dictionary page to a temporary buffer - auto temp_writer = make_uniq(); - for (idx_t r = 0; r < enum_count; r++) { - D_ASSERT(!FlatVector::IsNull(enum_values, r)); - // update the statistics - stats.Update(string_values[r]); - // write this string value to the dictionary - temp_writer->Write(string_values[r].GetSize()); - temp_writer->WriteData(const_data_ptr_cast(string_values[r].GetData()), string_values[r].GetSize()); - } - // flush the dictionary page and add it to the to-be-written pages - WriteDictionary(state, std::move(temp_writer), enum_count); - } - - idx_t GetRowSize(const Vector &vector, const idx_t index, const BasicColumnWriterState &state) const override { - return (bit_width + 7) / 8; - } -}; - -//===--------------------------------------------------------------------===// -// Struct Column Writer -//===--------------------------------------------------------------------===// -class StructColumnWriter : public ColumnWriter { -public: - StructColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, vector> child_writers_p, bool can_have_nulls) - : ColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls), - child_writers(std::move(child_writers_p)) { - } - ~StructColumnWriter() override = default; - - vector> child_writers; - -public: - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; - bool HasAnalyze() override; - void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void FinalizeAnalyze(ColumnWriterState &state) override; - void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) 
override; - - void BeginWrite(ColumnWriterState &state) override; - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; - void FinalizeWrite(ColumnWriterState &state) override; -}; - -class StructColumnWriterState : public ColumnWriterState { -public: - StructColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) - : row_group(row_group), col_idx(col_idx) { - } - ~StructColumnWriterState() override = default; - - duckdb_parquet::RowGroup &row_group; - idx_t col_idx; - vector> child_states; -}; - -unique_ptr StructColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { - auto result = make_uniq(row_group, row_group.columns.size()); - - result->child_states.reserve(child_writers.size()); - for (auto &child_writer : child_writers) { - result->child_states.push_back(child_writer->InitializeWriteState(row_group)); - } - return std::move(result); -} - -bool StructColumnWriter::HasAnalyze() { - for (auto &child_writer : child_writers) { - if (child_writer->HasAnalyze()) { - return true; - } - } - return false; -} - -void StructColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &child_vectors = StructVector::GetEntries(vector); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - // Need to check again. It might be that just one child needs it but the rest not - if (child_writers[child_idx]->HasAnalyze()) { - child_writers[child_idx]->Analyze(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], - count); - } - } -} - -void StructColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - // Need to check again. 
It might be that just one child needs it but the rest not - if (child_writers[child_idx]->HasAnalyze()) { - child_writers[child_idx]->FinalizeAnalyze(*state.child_states[child_idx]); - } - } -} - -void StructColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - - auto &validity = FlatVector::Validity(vector); - if (parent) { - // propagate empty entries from the parent - while (state.is_empty.size() < parent->is_empty.size()) { - state.is_empty.push_back(parent->is_empty[state.is_empty.size()]); - } - } - HandleRepeatLevels(state_p, parent, count, max_repeat); - HandleDefineLevels(state_p, parent, validity, count, PARQUET_DEFINE_VALID, max_define - 1); - auto &child_vectors = StructVector::GetEntries(vector); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - child_writers[child_idx]->Prepare(*state.child_states[child_idx], &state_p, *child_vectors[child_idx], count); - } -} - -void StructColumnWriter::BeginWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - child_writers[child_idx]->BeginWrite(*state.child_states[child_idx]); - } -} - -void StructColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &child_vectors = StructVector::GetEntries(vector); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - child_writers[child_idx]->Write(*state.child_states[child_idx], *child_vectors[child_idx], count); - } -} - -void StructColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - for (idx_t child_idx = 0; child_idx < child_writers.size(); child_idx++) { - // we add the null count of the struct to the null count of the children - state.child_states[child_idx]->null_count += state_p.null_count; - child_writers[child_idx]->FinalizeWrite(*state.child_states[child_idx]); - } -} - -//===--------------------------------------------------------------------===// -// List Column Writer -//===--------------------------------------------------------------------===// -class ListColumnWriter : public ColumnWriter { -public: - ListColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, unique_ptr child_writer_p, bool can_have_nulls) - : ColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, can_have_nulls), - child_writer(std::move(child_writer_p)) { - } - ~ListColumnWriter() override = default; - - unique_ptr child_writer; - -public: - unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) override; - bool HasAnalyze() override; - void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void FinalizeAnalyze(ColumnWriterState &state) override; - void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - - void BeginWrite(ColumnWriterState &state) override; - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; - void FinalizeWrite(ColumnWriterState &state) override; -}; - -class ListColumnWriterState : public ColumnWriterState { -public: - ListColumnWriterState(duckdb_parquet::RowGroup &row_group, idx_t col_idx) : row_group(row_group), col_idx(col_idx) { - } - ~ListColumnWriterState() override = default; - - duckdb_parquet::RowGroup &row_group; - idx_t col_idx; 
- unique_ptr child_state; - idx_t parent_index = 0; -}; - -unique_ptr ListColumnWriter::InitializeWriteState(duckdb_parquet::RowGroup &row_group) { - auto result = make_uniq(row_group, row_group.columns.size()); - result->child_state = child_writer->InitializeWriteState(row_group); - return std::move(result); -} - -bool ListColumnWriter::HasAnalyze() { - return child_writer->HasAnalyze(); -} -void ListColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &list_child = ListVector::GetEntry(vector); - auto list_count = ListVector::GetListSize(vector); - child_writer->Analyze(*state.child_state, &state_p, list_child, list_count); -} - -void ListColumnWriter::FinalizeAnalyze(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - child_writer->FinalizeAnalyze(*state.child_state); -} - -idx_t GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { - // returns a consecutive child list that fully flattens and repeats all required elements - auto &validity = FlatVector::Validity(list); - auto list_entries = FlatVector::GetData(list); - bool is_consecutive = true; - idx_t total_length = 0; - for (idx_t c = offset; c < offset + count; c++) { - if (!validity.RowIsValid(c)) { - continue; - } - if (list_entries[c].offset != total_length) { - is_consecutive = false; - } - total_length += list_entries[c].length; - } - if (is_consecutive) { - // already consecutive - leave it as-is - return total_length; - } - SelectionVector sel(total_length); - idx_t index = 0; - for (idx_t c = offset; c < offset + count; c++) { - if (!validity.RowIsValid(c)) { - continue; - } - for (idx_t k = 0; k < list_entries[c].length; k++) { - sel.set_index(index++, list_entries[c].offset + k); - } - } - result.Slice(sel, total_length); - result.Flatten(total_length); - return total_length; -} - -void ListColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - - auto list_data = FlatVector::GetData(vector); - auto &validity = FlatVector::Validity(vector); - - // write definition levels and repeats - idx_t start = 0; - idx_t vcount = parent ? parent->definition_levels.size() - state.parent_index : count; - idx_t vector_index = 0; - for (idx_t i = start; i < vcount; i++) { - idx_t parent_index = state.parent_index + i; - if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index]) { - state.definition_levels.push_back(parent->definition_levels[parent_index]); - state.repetition_levels.push_back(parent->repetition_levels[parent_index]); - state.is_empty.push_back(true); - continue; - } - auto first_repeat_level = - parent && !parent->repetition_levels.empty() ? 
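[Editor note] GetConsecutiveChildList above only materializes a selection vector when the child data is not already laid out back-to-back; the consecutiveness test itself is just an offset walk. Isolated here, with validity handling elided and ListEntry standing in for duckdb's list_entry_t:

#include <cstdint>
#include <vector>

struct ListEntry {
	uint64_t offset;
	uint64_t length;
};

// child data can be written as-is only if every list starts exactly where
// the previous one ended; otherwise the writer gathers it first
static bool IsConsecutive(const std::vector<ListEntry> &entries) {
	uint64_t expected_offset = 0;
	for (auto &entry : entries) {
		if (entry.offset != expected_offset) {
			return false;
		}
		expected_offset += entry.length;
	}
	return true;
}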
parent->repetition_levels[parent_index] : max_repeat; - if (parent && parent->definition_levels[parent_index] != PARQUET_DEFINE_VALID) { - state.definition_levels.push_back(parent->definition_levels[parent_index]); - state.repetition_levels.push_back(first_repeat_level); - state.is_empty.push_back(true); - } else if (validity.RowIsValid(vector_index)) { - // push the repetition levels - if (list_data[vector_index].length == 0) { - state.definition_levels.push_back(max_define); - state.is_empty.push_back(true); - } else { - state.definition_levels.push_back(PARQUET_DEFINE_VALID); - state.is_empty.push_back(false); - } - state.repetition_levels.push_back(first_repeat_level); - for (idx_t k = 1; k < list_data[vector_index].length; k++) { - state.repetition_levels.push_back(max_repeat + 1); - state.definition_levels.push_back(PARQUET_DEFINE_VALID); - state.is_empty.push_back(false); - } - } else { - if (!can_have_nulls) { - throw IOException("Parquet writer: map key column is not allowed to contain NULL values"); - } - state.definition_levels.push_back(max_define - 1); - state.repetition_levels.push_back(first_repeat_level); - state.is_empty.push_back(true); - } - vector_index++; - } - state.parent_index += vcount; - - auto &list_child = ListVector::GetEntry(vector); - Vector child_list(list_child); - auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); - child_writer->Prepare(*state.child_state, &state_p, child_list, child_length); -} - -void ListColumnWriter::BeginWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - child_writer->BeginWrite(*state.child_state); -} - -void ListColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - - auto &list_child = ListVector::GetEntry(vector); - Vector child_list(list_child); - auto child_length = GetConsecutiveChildList(vector, child_list, 0, count); - child_writer->Write(*state.child_state, child_list, child_length); -} - -void ListColumnWriter::FinalizeWrite(ColumnWriterState &state_p) { - auto &state = state_p.Cast(); - child_writer->FinalizeWrite(*state.child_state); -} - -//===--------------------------------------------------------------------===// -// Array Column Writer -//===--------------------------------------------------------------------===// -class ArrayColumnWriter : public ListColumnWriter { -public: - ArrayColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path_p, idx_t max_repeat, - idx_t max_define, unique_ptr child_writer_p, bool can_have_nulls) - : ListColumnWriter(writer, schema_idx, std::move(schema_path_p), max_repeat, max_define, - std::move(child_writer_p), can_have_nulls) { - } - ~ArrayColumnWriter() override = default; - -public: - void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) override; - void Write(ColumnWriterState &state, Vector &vector, idx_t count) override; -}; - -void ArrayColumnWriter::Analyze(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto &array_child = ArrayVector::GetEntry(vector); - auto array_size = ArrayType::GetSize(vector.GetType()); - child_writer->Analyze(*state.child_state, &state_p, array_child, array_size * count); -} - -void ArrayColumnWriter::Prepare(ColumnWriterState &state_p, ColumnWriterState *parent, Vector &vector, idx_t count) { - auto &state = 
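[Editor note] For intuition, here is how the list Prepare logic above assigns levels for a top-level LIST(INT) column, assuming the list writer has max_repeat = 0 and max_define = 2 (so a NULL list gets max_define - 1 = 1, an empty list gets max_define = 2, and present elements get the PARQUET_DEFINE_VALID sentinel, resolved later by the child's HandleDefineLevels):

	row 0: [10, 20] -> rep 0,1   def VALID,VALID   is_empty false,false
	row 1: NULL     -> rep 0     def 1             is_empty true
	row 2: []       -> rep 0     def 2             is_empty true
	row 3: [30]     -> rep 0     def VALID         is_empty false

Repetition level 0 starts a new row and max_repeat + 1 = 1 continues one; only rows 0 and 3 contribute entries that reach the child writer, which is why NULL and empty rows are tracked through is_empty.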
state_p.Cast(); - - auto array_size = ArrayType::GetSize(vector.GetType()); - auto &validity = FlatVector::Validity(vector); - - // write definition levels and repeats - // the main difference between this and ListColumnWriter::Prepare is that we need to make sure to write out - // repetition levels and definitions for the child elements of the array even if the array itself is NULL. - idx_t start = 0; - idx_t vcount = parent ? parent->definition_levels.size() - state.parent_index : count; - idx_t vector_index = 0; - for (idx_t i = start; i < vcount; i++) { - idx_t parent_index = state.parent_index + i; - if (parent && !parent->is_empty.empty() && parent->is_empty[parent_index]) { - state.definition_levels.push_back(parent->definition_levels[parent_index]); - state.repetition_levels.push_back(parent->repetition_levels[parent_index]); - state.is_empty.push_back(true); - continue; - } - auto first_repeat_level = - parent && !parent->repetition_levels.empty() ? parent->repetition_levels[parent_index] : max_repeat; - if (parent && parent->definition_levels[parent_index] != PARQUET_DEFINE_VALID) { - state.definition_levels.push_back(parent->definition_levels[parent_index]); - state.repetition_levels.push_back(first_repeat_level); - state.is_empty.push_back(false); - for (idx_t k = 1; k < array_size; k++) { - state.repetition_levels.push_back(max_repeat + 1); - state.definition_levels.push_back(parent->definition_levels[parent_index]); - state.is_empty.push_back(false); - } - } else if (validity.RowIsValid(vector_index)) { - // push the repetition levels - state.definition_levels.push_back(PARQUET_DEFINE_VALID); - state.is_empty.push_back(false); - - state.repetition_levels.push_back(first_repeat_level); - for (idx_t k = 1; k < array_size; k++) { - state.repetition_levels.push_back(max_repeat + 1); - state.definition_levels.push_back(PARQUET_DEFINE_VALID); - state.is_empty.push_back(false); - } - } else { - state.definition_levels.push_back(max_define - 1); - state.repetition_levels.push_back(first_repeat_level); - state.is_empty.push_back(false); - for (idx_t k = 1; k < array_size; k++) { - state.repetition_levels.push_back(max_repeat + 1); - state.definition_levels.push_back(max_define - 1); - state.is_empty.push_back(false); - } - } - vector_index++; - } - state.parent_index += vcount; - - auto &array_child = ArrayVector::GetEntry(vector); - child_writer->Prepare(*state.child_state, &state_p, array_child, count * array_size); -} - -void ArrayColumnWriter::Write(ColumnWriterState &state_p, Vector &vector, idx_t count) { - auto &state = state_p.Cast(); - auto array_size = ArrayType::GetSize(vector.GetType()); - auto &array_child = ArrayVector::GetEntry(vector); - child_writer->Write(*state.child_state, array_child, count * array_size); -} - -// special double/float class to deal with dictionary encoding and NaN equality -struct double_na_equal { - double_na_equal() : val(0) { - } - explicit double_na_equal(const double val_p) : val(val_p) { - } - // NOLINTNEXTLINE: allow implicit conversion to double - operator double() const { - return val; - } - - bool operator==(const double &right) const { - if (std::isnan(val) && std::isnan(right)) { - return true; - } - return val == right; - } - double val; -}; - -struct float_na_equal { - float_na_equal() : val(0) { - } - explicit float_na_equal(const float val_p) : val(val_p) { - } - // NOLINTNEXTLINE: allow implicit conversion to float - operator float() const { - return val; - } - bool operator==(const float &right) const { - if (std::isnan(val) && 
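[Editor note] double_na_equal and float_na_equal above exist because a dictionary keyed on raw doubles would treat every NaN as distinct (IEEE NaN != NaN), defeating dictionary encoding for columns containing NaNs. A sketch of the fix with std::unordered_map; the hash side shown here canonicalizes NaN before hashing, which is one way to keep hash and equality consistent (the writer itself hashes via XXHash64 on the converted target value):

#include <cmath>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <limits>
#include <unordered_map>

struct NanEqualDouble {
	double val;
	bool operator==(const NanEqualDouble &other) const {
		if (std::isnan(val) && std::isnan(other.val)) {
			return true; // treat all NaNs as one dictionary entry
		}
		return val == other.val;
	}
};

struct NanStableHash {
	size_t operator()(const NanEqualDouble &v) const {
		// map every NaN payload to one representative before hashing
		double d = std::isnan(v.val) ? std::numeric_limits<double>::quiet_NaN() : v.val;
		return std::hash<double>()(d);
	}
};

using DoubleDictionary = std::unordered_map<NanEqualDouble, uint32_t, NanStableHash>;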
std::isnan(right)) { - return true; - } - return val == right; - } - float val; -}; - -//===--------------------------------------------------------------------===// -// Create Column Writer -//===--------------------------------------------------------------------===// - -unique_ptr ColumnWriter::CreateWriterRecursive(ClientContext &context, - vector &schemas, - ParquetWriter &writer, const LogicalType &type, - const string &name, vector schema_path, - optional_ptr field_ids, - idx_t max_repeat, idx_t max_define, bool can_have_nulls) { - auto null_type = can_have_nulls ? FieldRepetitionType::OPTIONAL : FieldRepetitionType::REQUIRED; - if (!can_have_nulls) { - max_define--; - } - idx_t schema_idx = schemas.size(); - - optional_ptr field_id; - optional_ptr child_field_ids; - if (field_ids) { - auto field_id_it = field_ids->ids->find(name); - if (field_id_it != field_ids->ids->end()) { - field_id = &field_id_it->second; - child_field_ids = &field_id->child_field_ids; - } - } - - if (type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION) { - auto &child_types = StructType::GetChildTypes(type); - // set up the schema element for this struct - duckdb_parquet::SchemaElement schema_element; - schema_element.repetition_type = null_type; - schema_element.num_children = UnsafeNumericCast(child_types.size()); - schema_element.__isset.num_children = true; - schema_element.__isset.type = false; - schema_element.__isset.repetition_type = true; - schema_element.name = name; - if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(schema_element)); - schema_path.push_back(name); - - // construct the child types recursively - vector> child_writers; - child_writers.reserve(child_types.size()); - for (auto &child_type : child_types) { - child_writers.push_back(CreateWriterRecursive(context, schemas, writer, child_type.second, child_type.first, - schema_path, child_field_ids, max_repeat, max_define + 1)); - } - return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, - std::move(child_writers), can_have_nulls); - } - if (type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::ARRAY) { - auto is_list = type.id() == LogicalTypeId::LIST; - auto &child_type = is_list ? ListType::GetChildType(type) : ArrayType::GetChildType(type); - // set up the two schema elements for the list - // for some reason we only set the converted type in the OPTIONAL element - // first an OPTIONAL element - duckdb_parquet::SchemaElement optional_element; - optional_element.repetition_type = null_type; - optional_element.num_children = 1; - optional_element.converted_type = ConvertedType::LIST; - optional_element.__isset.num_children = true; - optional_element.__isset.type = false; - optional_element.__isset.repetition_type = true; - optional_element.__isset.converted_type = true; - optional_element.name = name; - if (field_id && field_id->set) { - optional_element.__isset.field_id = true; - optional_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(optional_element)); - schema_path.push_back(name); - - // then a REPEATED element - duckdb_parquet::SchemaElement repeated_element; - repeated_element.repetition_type = FieldRepetitionType::REPEATED; - repeated_element.num_children = 1; - repeated_element.__isset.num_children = true; - repeated_element.__isset.type = false; - repeated_element.__isset.repetition_type = true; - repeated_element.name = is_list ? 
"list" : "array"; - schemas.push_back(std::move(repeated_element)); - schema_path.emplace_back(is_list ? "list" : "array"); - - auto child_writer = CreateWriterRecursive(context, schemas, writer, child_type, "element", schema_path, - child_field_ids, max_repeat + 1, max_define + 2); - if (is_list) { - return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, - std::move(child_writer), can_have_nulls); - } else { - return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, - std::move(child_writer), can_have_nulls); - } - } - if (type.id() == LogicalTypeId::MAP) { - // map type - // maps are stored as follows: - // group (MAP) { - // repeated group key_value { - // required key; - // value; - // } - // } - // top map element - duckdb_parquet::SchemaElement top_element; - top_element.repetition_type = null_type; - top_element.num_children = 1; - top_element.converted_type = ConvertedType::MAP; - top_element.__isset.repetition_type = true; - top_element.__isset.num_children = true; - top_element.__isset.converted_type = true; - top_element.__isset.type = false; - top_element.name = name; - if (field_id && field_id->set) { - top_element.__isset.field_id = true; - top_element.field_id = field_id->field_id; - } - schemas.push_back(std::move(top_element)); - schema_path.push_back(name); - - // key_value element - duckdb_parquet::SchemaElement kv_element; - kv_element.repetition_type = FieldRepetitionType::REPEATED; - kv_element.num_children = 2; - kv_element.__isset.repetition_type = true; - kv_element.__isset.num_children = true; - kv_element.__isset.type = false; - kv_element.name = "key_value"; - schemas.push_back(std::move(kv_element)); - schema_path.emplace_back("key_value"); - - // construct the child types recursively - vector kv_types {MapType::KeyType(type), MapType::ValueType(type)}; - vector kv_names {"key", "value"}; - vector> child_writers; - child_writers.reserve(2); - for (idx_t i = 0; i < 2; i++) { - // key needs to be marked as REQUIRED - bool is_key = i == 0; - auto child_writer = CreateWriterRecursive(context, schemas, writer, kv_types[i], kv_names[i], schema_path, - child_field_ids, max_repeat + 1, max_define + 2, !is_key); - - child_writers.push_back(std::move(child_writer)); - } - auto struct_writer = make_uniq(writer, schema_idx, schema_path, max_repeat, max_define, - std::move(child_writers), can_have_nulls); - return make_uniq(writer, schema_idx, schema_path, max_repeat, max_define, - std::move(struct_writer), can_have_nulls); - } - duckdb_parquet::SchemaElement schema_element; - schema_element.type = ParquetWriter::DuckDBTypeToParquetType(type); - schema_element.repetition_type = null_type; - schema_element.__isset.num_children = false; - schema_element.__isset.type = true; - schema_element.__isset.repetition_type = true; - schema_element.name = name; - if (field_id && field_id->set) { - schema_element.__isset.field_id = true; - schema_element.field_id = field_id->field_id; - } - ParquetWriter::SetSchemaProperties(type, schema_element); - schemas.push_back(std::move(schema_element)); - schema_path.push_back(name); - if (type.id() == LogicalTypeId::BLOB && type.GetAlias() == "WKB_BLOB" && - GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - return make_uniq(context, writer, schema_idx, std::move(schema_path), max_repeat, max_define, - can_have_nulls, name); - } - - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, max_define, - 
can_have_nulls); - case LogicalTypeId::TINYINT: - return make_uniq>(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - case LogicalTypeId::SMALLINT: - return make_uniq>(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - case LogicalTypeId::INTEGER: - case LogicalTypeId::DATE: - return make_uniq>(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_MS: - return make_uniq>(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - case LogicalTypeId::TIME_TZ: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::HUGEINT: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::UHUGEINT: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::TIMESTAMP_NS: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::TIMESTAMP_SEC: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::UTINYINT: - return make_uniq>(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - case LogicalTypeId::USMALLINT: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case LogicalTypeId::UINTEGER: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case LogicalTypeId::UBIGINT: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case LogicalTypeId::FLOAT: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case LogicalTypeId::DOUBLE: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case PhysicalType::INT32: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - case PhysicalType::INT64: - return make_uniq>(writer, schema_idx, std::move(schema_path), - max_repeat, max_define, can_have_nulls); - default: - return make_uniq(writer, schema_idx, std::move(schema_path), max_repeat, - max_define, can_have_nulls); - } - case LogicalTypeId::BLOB: - case LogicalTypeId::VARCHAR: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::UUID: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::INTERVAL: - return make_uniq>( - writer, schema_idx, std::move(schema_path), max_repeat, max_define, can_have_nulls); - case LogicalTypeId::ENUM: - return make_uniq(writer, type, schema_idx, std::move(schema_path), max_repeat, max_define, - can_have_nulls); - default: - throw InternalException("Unsupported type \"%s\" in Parquet writer", type.ToString()); - } -} - 
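The Prepare logic above is easiest to follow with a concrete definition/repetition stream in hand. Below is a minimal, self-contained sketch (not the writer itself) of the scheme ListColumnWriter::Prepare uses for a top-level nullable LIST column: an empty list is recorded at max_define, a NULL list at max_define - 1, and every element after the first repeats at max_repeat + 1. The constants and the list-entry struct are illustrative assumptions; in the real writer they come from the schema and from DuckDB's list_entry_t.

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative constants (assumptions): a top-level nullable LIST column has
// max_repeat = 0 and max_define = 1 at the list level; child elements repeat
// at max_repeat + 1. PARQUET_DEFINE_VALID is a sentinel that the leaf writer
// later resolves against its own max_define.
static constexpr uint16_t MAX_DEFINE = 1;
static constexpr uint16_t MAX_REPEAT = 0;
static constexpr uint16_t PARQUET_DEFINE_VALID = 65535;

struct ListEntry {
	uint64_t offset;
	uint64_t length;
	bool is_null;
};

int main() {
	// Three rows of a LIST(INTEGER) column: [10, 20], [], NULL
	const std::vector<ListEntry> rows {{0, 2, false}, {2, 0, false}, {0, 0, true}};
	std::vector<uint16_t> defines, repeats;
	std::vector<bool> is_empty;

	for (const auto &row : rows) {
		if (row.is_null) {
			// NULL list: defined one level short of the list itself
			defines.push_back(MAX_DEFINE - 1);
			repeats.push_back(MAX_REPEAT);
			is_empty.push_back(true);
		} else if (row.length == 0) {
			// empty list: the list exists but contributes no child values
			defines.push_back(MAX_DEFINE);
			repeats.push_back(MAX_REPEAT);
			is_empty.push_back(true);
		} else {
			// the first element opens the record, the rest repeat one level deeper
			defines.push_back(PARQUET_DEFINE_VALID);
			repeats.push_back(MAX_REPEAT);
			is_empty.push_back(false);
			for (uint64_t k = 1; k < row.length; k++) {
				defines.push_back(PARQUET_DEFINE_VALID);
				repeats.push_back(MAX_REPEAT + 1);
				is_empty.push_back(false);
			}
		}
	}
	// Prints (65535,0) (65535,1) for [10, 20], then (1,0) for [], then (0,0) for NULL
	for (size_t i = 0; i < defines.size(); i++) {
		std::cout << "(" << defines[i] << "," << repeats[i] << ") empty=" << is_empty[i] << "\n";
	}
	return 0;
}

ArrayColumnWriter::Prepare differs in exactly one respect: a fixed-size ARRAY always emits array_size child slots, even for a NULL or undefined row, so the child writer sees a rectangular layout.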
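For MAP columns, the code above wraps a StructColumnWriter for {key, value} in a ListColumnWriter, with the key created as REQUIRED (can_have_nulls = false), which is what triggers the NULL-key IOException seen earlier. Rendered in parquet-tools-style notation, a MAP(VARCHAR, INTEGER) column named m would come out roughly as follows (the STRING annotation is an assumption based on how VARCHAR is annotated elsewhere):

optional group m (MAP) {
  repeated group key_value {
    required binary key (STRING);
    optional int32 value;
  }
}

The key_value group adds one repetition level and two definition levels for its children (max_repeat + 1, max_define + 2), matching the recursion above.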
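The float_na_equal/double_na_equal wrappers defined above exist because IEEE-754 says NaN != NaN: with raw doubles as hash-map keys, every NaN row would become a distinct dictionary entry. The wrappers compare NaNs as equal, and the matching std::hash specializations immediately below canonicalize NaN before hashing so that equal keys hash equally. A standalone sketch of the idea, with stand-in names since the real types live in this deleted file:

#include <cmath>
#include <cstdint>
#include <functional>
#include <iostream>
#include <limits>
#include <unordered_map>

// Stand-ins for duckdb's double_na_equal and its std::hash specialization.
struct DoubleNaEqual {
	double val;
	bool operator==(const DoubleNaEqual &other) const {
		if (std::isnan(val) && std::isnan(other.val)) {
			return true; // all NaNs collapse into one dictionary entry
		}
		return val == other.val;
	}
};

struct DoubleNaEqualHash {
	size_t operator()(const DoubleNaEqual &v) const {
		// canonicalize NaN so that keys comparing equal also hash equally
		const double canonical = std::isnan(v.val) ? std::numeric_limits<double>::quiet_NaN() : v.val;
		return std::hash<double> {}(canonical);
	}
};

int main() {
	std::unordered_map<DoubleNaEqual, uint32_t, DoubleNaEqualHash> dictionary;
	const double values[] = {1.5, std::nan(""), 1.5, std::nan("2")};
	for (double v : values) {
		// assign the next dictionary id on first sight, reuse it afterwards
		dictionary.emplace(DoubleNaEqual {v}, static_cast<uint32_t>(dictionary.size()));
	}
	std::cout << dictionary.size() << " distinct entries\n"; // prints 2, not 3
	return 0;
}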
-template <>
-struct NumericLimits<float> {
-	static constexpr float Minimum() {
-		return std::numeric_limits<float>::lowest();
-	};
-	static constexpr float Maximum() {
-		return std::numeric_limits<float>::max();
-	};
-	static constexpr bool IsSigned() {
-		return std::is_signed<float>::value;
-	}
-	static constexpr bool IsIntegral() {
-		return std::is_integral<float>::value;
-	}
-};
-
-template <>
-struct NumericLimits<double> {
-	static constexpr double Minimum() {
-		return std::numeric_limits<double>::lowest();
-	};
-	static constexpr double Maximum() {
-		return std::numeric_limits<double>::max();
-	};
-	static constexpr bool IsSigned() {
-		return std::is_signed<double>::value;
-	}
-	static constexpr bool IsIntegral() {
-		return std::is_integral<double>::value;
-	}
-};
-
-} // namespace duckdb
-
-namespace std {
-template <>
-struct hash<duckdb::ParquetIntervalTargetType> {
-	size_t operator()(const duckdb::ParquetIntervalTargetType &val) const {
-		return duckdb::Hash(duckdb::const_char_ptr_cast(val.bytes),
-		                    duckdb::ParquetIntervalTargetType::PARQUET_INTERVAL_SIZE);
-	}
-};
-
-template <>
-struct hash<duckdb::ParquetUUIDTargetType> {
-	size_t operator()(const duckdb::ParquetUUIDTargetType &val) const {
-		return duckdb::Hash(duckdb::const_char_ptr_cast(val.bytes), duckdb::ParquetUUIDTargetType::PARQUET_UUID_SIZE);
-	}
-};
-
-template <>
-struct hash<duckdb::float_na_equal> {
-	size_t operator()(const duckdb::float_na_equal &val) const {
-		if (std::isnan(val.val)) {
-			return duckdb::Hash(std::numeric_limits<float>::quiet_NaN());
-		}
-		return duckdb::Hash(val.val);
-	}
-};
-
-template <>
-struct hash<duckdb::double_na_equal> {
-	inline size_t operator()(const duckdb::double_na_equal &val) const {
-		if (std::isnan(val.val)) {
-			return duckdb::Hash(std::numeric_limits<double>::quiet_NaN());
-		}
-		return duckdb::Hash(val.val);
-	}
-};
-} // namespace std
diff --git a/src/duckdb/extension/parquet/geo_parquet.cpp b/src/duckdb/extension/parquet/geo_parquet.cpp
deleted file mode 100644
index ec252a50d..000000000
--- a/src/duckdb/extension/parquet/geo_parquet.cpp
+++ /dev/null
@@ -1,424 +0,0 @@
-
-#include "geo_parquet.hpp"
-
-#include "column_reader.hpp"
-#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
-#include "duckdb/execution/expression_executor.hpp"
-#include "duckdb/function/scalar_function.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-#include "duckdb/planner/expression/bound_reference_expression.hpp"
-#include "duckdb/main/extension_helper.hpp"
-#include "expression_column_reader.hpp"
-#include "parquet_reader.hpp"
-#include "yyjson.hpp"
-
-namespace duckdb {
-
-using namespace duckdb_yyjson; // NOLINT
-
-const char *WKBGeometryTypes::ToString(WKBGeometryType type) {
-	switch (type) {
-	case WKBGeometryType::POINT:
-		return "Point";
-	case WKBGeometryType::LINESTRING:
-		return "LineString";
-	case WKBGeometryType::POLYGON:
-		return "Polygon";
-	case WKBGeometryType::MULTIPOINT:
-		return "MultiPoint";
-	case WKBGeometryType::MULTILINESTRING:
-		return "MultiLineString";
-	case WKBGeometryType::MULTIPOLYGON:
-		return "MultiPolygon";
-	case WKBGeometryType::GEOMETRYCOLLECTION:
-		return "GeometryCollection";
-	case WKBGeometryType::POINT_Z:
-		return "Point Z";
-	case WKBGeometryType::LINESTRING_Z:
-		return "LineString Z";
-	case WKBGeometryType::POLYGON_Z:
-		return "Polygon Z";
-	case WKBGeometryType::MULTIPOINT_Z:
-		return "MultiPoint Z";
-	case WKBGeometryType::MULTILINESTRING_Z:
-		return "MultiLineString Z";
-	case WKBGeometryType::MULTIPOLYGON_Z:
-		return "MultiPolygon Z";
-	case WKBGeometryType::GEOMETRYCOLLECTION_Z:
-		return "GeometryCollection Z";
-	default:
-		throw NotImplementedException("Unsupported geometry type");
-	}
-}
-
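GeoParquet identifies geometry types by their WKB code, listed in ToString above; the Z variants are the 2D code plus 1000. GeoParquetColumnMetadataWriter::Update, later in this file, derives that code from spatial's st_geometrytype and st_zmflag results. A hedged sketch of that mapping, assuming st_geometrytype yields a 0-based id (0 = point) and st_zmflag yields 0/1/2/3 for XY/XYM/XYZ/XYZM:

#include <cstdint>
#include <stdexcept>

// Mirrors the computation in GeoParquetColumnMetadataWriter::Update below.
static uint16_t ToWKBTypeCode(uint8_t geometry_type_id, uint8_t zm_flag) {
	if (zm_flag == 1 || zm_flag == 3) {
		// GeoParquet has no M variants; the writer raises InvalidInputException here
		throw std::runtime_error("Geoparquet does not support geometries with M coordinates");
	}
	const bool has_z = zm_flag == 2;
	return static_cast<uint16_t>((geometry_type_id + 1) + (has_z ? 1000 : 0));
}

// ToWKBTypeCode(0, 0) == 1    -> "Point"
// ToWKBTypeCode(2, 2) == 1003 -> "Polygon Z"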
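GeoParquetFileMetadata::Write, further down in this file, serializes the collected per-column metadata into the file-level "geo" key-value entry. For a file with a single geometry column, the payload looks roughly like this (the column name, geometry types, and bbox values are illustrative; bbox order is min_x, min_y, max_x, max_y, and projjson is only emitted when a CRS was recorded):

{
  "version": "1.1.0",
  "primary_column": "geom",
  "columns": {
    "geom": {
      "encoding": "WKB",
      "geometry_types": ["Point", "Polygon Z"],
      "bbox": [4.7, 51.2, 5.1, 52.4],
      "projjson": { ... }
    }
  }
}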
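decode_utils.hpp, deleted further down in this diff, implements the zigzag transform that Parquet's DELTA encodings rely on: it interleaves negative and positive integers so that small magnitudes map to small unsigned values. A 64-bit specialization of the same formulas, as a minimal sketch:

#include <cassert>
#include <cstdint>

static uint64_t IntToZigzag(int64_t x) {
	return (static_cast<uint64_t>(x) << 1) ^ static_cast<uint64_t>(x >> 63);
}

static int64_t ZigzagToInt(uint64_t x) {
	return static_cast<int64_t>((x >> 1) ^ (0 - (x & 1)));
}

int main() {
	// 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, 2 -> 4: small magnitudes stay small
	assert(IntToZigzag(-1) == 1 && IntToZigzag(2) == 4);
	for (int64_t v : {int64_t(0), int64_t(-1), int64_t(42), int64_t(-123456)}) {
		assert(ZigzagToInt(IntToZigzag(v)) == v); // roundtrip, as the D_ASSERTs in the header check
	}
	return 0;
}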
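The same header also carries the ULEB128-style varint used for lengths and counts: 7 payload bits per byte, with the high bit set while more bytes follow. A standalone version mirroring VarintEncode/VarintDecode, with bounds checks swapped in for the ByteBuffer machinery:

#include <cstdint>
#include <stdexcept>
#include <vector>

static void VarintEncode(uint64_t val, std::vector<uint8_t> &out) {
	do {
		uint8_t byte = val & 127;
		val >>= 7;
		if (val != 0) {
			byte |= 128; // continuation bit: more bytes follow
		}
		out.push_back(byte);
	} while (val != 0);
}

static uint64_t VarintDecode(const std::vector<uint8_t> &buf, size_t &pos) {
	uint64_t result = 0;
	uint8_t shift = 0;
	while (true) {
		if (pos >= buf.size()) {
			throw std::runtime_error("Varint ended prematurely");
		}
		const uint8_t byte = buf[pos++];
		result |= static_cast<uint64_t>(byte & 127) << shift;
		if ((byte & 128) == 0) {
			break;
		}
		shift += 7;
		if (shift > 63) {
			throw std::runtime_error("Varint-decoding found too large number");
		}
	}
	return result;
}
// 300 encodes as {0xAC, 0x02}: 44 with the continuation bit set, then 2.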
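BssEncoder and BssDecoder, near the end of this diff, implement Parquet's BYTE_STREAM_SPLIT encoding: byte i of every value is written contiguously as its own stream, which typically makes float/double pages more compressible. A minimal float version of both directions, under the assumption of a simple std::vector buffer:

#include <cstdint>
#include <cstring>
#include <vector>

static std::vector<uint8_t> BssEncode(const std::vector<float> &values) {
	const size_t n = values.size();
	std::vector<uint8_t> out(n * sizeof(float));
	for (size_t i = 0; i < n; i++) {
		uint8_t bytes[sizeof(float)];
		std::memcpy(bytes, &values[i], sizeof(float));
		for (size_t b = 0; b < sizeof(float); b++) {
			out[b * n + i] = bytes[b]; // stream b holds byte b of every value
		}
	}
	return out;
}

static std::vector<float> BssDecode(const std::vector<uint8_t> &buf) {
	const size_t n = buf.size() / sizeof(float);
	std::vector<float> values(n);
	for (size_t i = 0; i < n; i++) {
		uint8_t bytes[sizeof(float)];
		for (size_t b = 0; b < sizeof(float); b++) {
			bytes[b] = buf[b * n + i]; // gather the value's bytes back together
		}
		std::memcpy(&values[i], bytes, sizeof(float));
	}
	return values;
}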
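BooleanParquetValueConversion, in the deleted boolean_column_reader.hpp further down, reads PLAIN-encoded booleans bit by bit, LSB-first within each byte, tracking a byte_pos cursor exactly as sketched here (a simplified standalone version that assumes the buffer holds at least ceil(count / 8) bytes):

#include <cstdint>
#include <vector>

static std::vector<bool> ReadPlainBooleans(const std::vector<uint8_t> &data, size_t count) {
	std::vector<bool> result;
	result.reserve(count);
	uint8_t byte_pos = 0; // bit cursor within the current byte
	size_t offset = 0;    // current byte within the buffer
	for (size_t i = 0; i < count; i++) {
		result.push_back((data[offset] >> byte_pos) & 1);
		if (++byte_pos == 8) {
			byte_pos = 0;
			offset++; // move to the next packed byte
		}
	}
	return result;
}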
-//------------------------------------------------------------------------------ -// GeoParquetColumnMetadataWriter -//------------------------------------------------------------------------------ -GeoParquetColumnMetadataWriter::GeoParquetColumnMetadataWriter(ClientContext &context) { - executor = make_uniq(context); - - auto &catalog = Catalog::GetSystemCatalog(context); - - // These functions are required to extract the geometry type, ZM flag and bounding box from a WKB blob - auto &type_func_set = - catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_geometrytype") - .Cast(); - auto &flag_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_zmflag") - .Cast(); - auto &bbox_func_set = catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_extent") - .Cast(); - - auto wkb_type = LogicalType(LogicalTypeId::BLOB); - wkb_type.SetAlias("WKB_BLOB"); - - auto type_func = type_func_set.functions.GetFunctionByArguments(context, {wkb_type}); - auto flag_func = flag_func_set.functions.GetFunctionByArguments(context, {wkb_type}); - auto bbox_func = bbox_func_set.functions.GetFunctionByArguments(context, {wkb_type}); - - auto type_type = LogicalType::UTINYINT; - auto flag_type = flag_func.return_type; - auto bbox_type = bbox_func.return_type; - - vector> type_args; - type_args.push_back(make_uniq(wkb_type, 0)); - - vector> flag_args; - flag_args.push_back(make_uniq(wkb_type, 0)); - - vector> bbox_args; - bbox_args.push_back(make_uniq(wkb_type, 0)); - - type_expr = make_uniq(type_type, type_func, std::move(type_args), nullptr); - flag_expr = make_uniq(flag_type, flag_func, std::move(flag_args), nullptr); - bbox_expr = make_uniq(bbox_type, bbox_func, std::move(bbox_args), nullptr); - - // Add the expressions to the executor - executor->AddExpression(*type_expr); - executor->AddExpression(*flag_expr); - executor->AddExpression(*bbox_expr); - - // Initialize the input and result chunks - // The input chunk should be empty, as we always reference the input vector - input_chunk.InitializeEmpty({wkb_type}); - result_chunk.Initialize(context, {type_type, flag_type, bbox_type}); -} - -void GeoParquetColumnMetadataWriter::Update(GeoParquetColumnMetadata &meta, Vector &vector, idx_t count) { - input_chunk.Reset(); - result_chunk.Reset(); - - // Reference the vector - input_chunk.data[0].Reference(vector); - input_chunk.SetCardinality(count); - - // Execute the expression - executor->Execute(input_chunk, result_chunk); - - // The first column is the geometry type - // The second column is the zm flag - // The third column is the bounding box - - UnifiedVectorFormat type_format; - UnifiedVectorFormat flag_format; - UnifiedVectorFormat bbox_format; - - result_chunk.data[0].ToUnifiedFormat(count, type_format); - result_chunk.data[1].ToUnifiedFormat(count, flag_format); - result_chunk.data[2].ToUnifiedFormat(count, bbox_format); - - const auto &bbox_components = StructVector::GetEntries(result_chunk.data[2]); - D_ASSERT(bbox_components.size() == 4); - - UnifiedVectorFormat xmin_format; - UnifiedVectorFormat ymin_format; - UnifiedVectorFormat xmax_format; - UnifiedVectorFormat ymax_format; - - bbox_components[0]->ToUnifiedFormat(count, xmin_format); - bbox_components[1]->ToUnifiedFormat(count, ymin_format); - bbox_components[2]->ToUnifiedFormat(count, xmax_format); - bbox_components[3]->ToUnifiedFormat(count, ymax_format); - - for (idx_t in_idx = 0; in_idx < count; in_idx++) { - const auto type_idx = 
type_format.sel->get_index(in_idx); - const auto flag_idx = flag_format.sel->get_index(in_idx); - const auto bbox_idx = bbox_format.sel->get_index(in_idx); - - const auto type_valid = type_format.validity.RowIsValid(type_idx); - const auto flag_valid = flag_format.validity.RowIsValid(flag_idx); - const auto bbox_valid = bbox_format.validity.RowIsValid(bbox_idx); - - if (!type_valid || !flag_valid || !bbox_valid) { - continue; - } - - // Update the geometry type - const auto flag = UnifiedVectorFormat::GetData(flag_format)[flag_idx]; - const auto type = UnifiedVectorFormat::GetData(type_format)[type_idx]; - if (flag == 1 || flag == 3) { - // M or ZM - throw InvalidInputException("Geoparquet does not support geometries with M coordinates"); - } - const auto has_z = flag == 2; - auto wkb_type = static_cast((type + 1) + (has_z ? 1000 : 0)); - meta.geometry_types.insert(wkb_type); - - // Update the bounding box - const auto min_x = UnifiedVectorFormat::GetData(xmin_format)[bbox_idx]; - const auto min_y = UnifiedVectorFormat::GetData(ymin_format)[bbox_idx]; - const auto max_x = UnifiedVectorFormat::GetData(xmax_format)[bbox_idx]; - const auto max_y = UnifiedVectorFormat::GetData(ymax_format)[bbox_idx]; - meta.bbox.Combine(min_x, max_x, min_y, max_y); - } -} - -//------------------------------------------------------------------------------ -// GeoParquetFileMetadata -//------------------------------------------------------------------------------ - -unique_ptr GeoParquetFileMetadata::TryRead(const duckdb_parquet::FileMetaData &file_meta_data, - const ClientContext &context) { - - // Conversion not enabled, or spatial is not loaded! - if (!IsGeoParquetConversionEnabled(context)) { - return nullptr; - } - - for (auto &kv : file_meta_data.key_value_metadata) { - if (kv.key == "geo") { - const auto geo_metadata = yyjson_read(kv.value.c_str(), kv.value.size(), 0); - if (!geo_metadata) { - // Could not parse the JSON - return nullptr; - } - - try { - // Check the root object - const auto root = yyjson_doc_get_root(geo_metadata); - if (!yyjson_is_obj(root)) { - throw InvalidInputException("Geoparquet metadata is not an object"); - } - - auto result = make_uniq(); - - // Check and parse the version - const auto version_val = yyjson_obj_get(root, "version"); - if (!yyjson_is_str(version_val)) { - throw InvalidInputException("Geoparquet metadata does not have a version"); - } - result->version = yyjson_get_str(version_val); - if (StringUtil::StartsWith(result->version, "2")) { - // Guard against a breaking future 2.0 version - throw InvalidInputException("Geoparquet version %s is not supported", result->version); - } - - // Check and parse the primary geometry column - const auto primary_geometry_column_val = yyjson_obj_get(root, "primary_column"); - if (!yyjson_is_str(primary_geometry_column_val)) { - throw InvalidInputException("Geoparquet metadata does not have a primary column"); - } - result->primary_geometry_column = yyjson_get_str(primary_geometry_column_val); - - // Check and parse the geometry columns - const auto columns_val = yyjson_obj_get(root, "columns"); - if (!yyjson_is_obj(columns_val)) { - throw InvalidInputException("Geoparquet metadata does not have a columns object"); - } - - // Iterate over all geometry columns - yyjson_obj_iter iter = yyjson_obj_iter_with(columns_val); - yyjson_val *column_key; - - while ((column_key = yyjson_obj_iter_next(&iter))) { - const auto column_val = yyjson_obj_iter_get_val(column_key); - const auto column_name = yyjson_get_str(column_key); - - auto &column 
= result->geometry_columns[column_name]; - - if (!yyjson_is_obj(column_val)) { - throw InvalidInputException("Geoparquet column '%s' is not an object", column_name); - } - - // Parse the encoding - const auto encoding_val = yyjson_obj_get(column_val, "encoding"); - if (!yyjson_is_str(encoding_val)) { - throw InvalidInputException("Geoparquet column '%s' does not have an encoding", column_name); - } - const auto encoding_str = yyjson_get_str(encoding_val); - if (strcmp(encoding_str, "WKB") == 0) { - column.geometry_encoding = GeoParquetColumnEncoding::WKB; - } else { - throw InvalidInputException("Geoparquet column '%s' has an unsupported encoding", column_name); - } - - // Parse the geometry types - const auto geometry_types_val = yyjson_obj_get(column_val, "geometry_types"); - if (!yyjson_is_arr(geometry_types_val)) { - throw InvalidInputException("Geoparquet column '%s' does not have geometry types", column_name); - } - // We dont care about the geometry types for now. - - // TODO: Parse the bounding box, other metadata that might be useful. - // (Only encoding and geometry types are required to be present) - } - - // Return the result - // Make sure to free the JSON document - yyjson_doc_free(geo_metadata); - return result; - - } catch (...) { - // Make sure to free the JSON document in case of an exception - yyjson_doc_free(geo_metadata); - throw; - } - } - } - return nullptr; -} - -void GeoParquetFileMetadata::FlushColumnMeta(const string &column_name, const GeoParquetColumnMetadata &meta) { - // Lock the metadata - lock_guard glock(write_lock); - - auto &column = geometry_columns[column_name]; - - // Combine the metadata - column.geometry_types.insert(meta.geometry_types.begin(), meta.geometry_types.end()); - column.bbox.Combine(meta.bbox); -} - -void GeoParquetFileMetadata::Write(duckdb_parquet::FileMetaData &file_meta_data) const { - - yyjson_mut_doc *doc = yyjson_mut_doc_new(nullptr); - yyjson_mut_val *root = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, root); - - // Add the version - yyjson_mut_obj_add_strncpy(doc, root, "version", version.c_str(), version.size()); - - // Add the primary column - yyjson_mut_obj_add_strncpy(doc, root, "primary_column", primary_geometry_column.c_str(), - primary_geometry_column.size()); - - // Add the columns - const auto json_columns = yyjson_mut_obj_add_obj(doc, root, "columns"); - - for (auto &column : geometry_columns) { - const auto column_json = yyjson_mut_obj_add_obj(doc, json_columns, column.first.c_str()); - yyjson_mut_obj_add_str(doc, column_json, "encoding", "WKB"); - const auto geometry_types = yyjson_mut_obj_add_arr(doc, column_json, "geometry_types"); - for (auto &geometry_type : column.second.geometry_types) { - const auto type_name = WKBGeometryTypes::ToString(geometry_type); - yyjson_mut_arr_add_str(doc, geometry_types, type_name); - } - const auto bbox = yyjson_mut_obj_add_arr(doc, column_json, "bbox"); - yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.min_x); - yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.min_y); - yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.max_x); - yyjson_mut_arr_add_real(doc, bbox, column.second.bbox.max_y); - - // If the CRS is present, add it - if (!column.second.projjson.empty()) { - const auto crs_doc = yyjson_read(column.second.projjson.c_str(), column.second.projjson.size(), 0); - if (!crs_doc) { - yyjson_mut_doc_free(doc); - throw InvalidInputException("Failed to parse CRS JSON"); - } - const auto crs_root = yyjson_doc_get_root(crs_doc); - const auto crs_val = 
yyjson_val_mut_copy(doc, crs_root); - const auto crs_key = yyjson_mut_strcpy(doc, "projjson"); - yyjson_mut_obj_add(column_json, crs_key, crs_val); - yyjson_doc_free(crs_doc); - } - } - - yyjson_write_err err; - size_t len; - char *json = yyjson_mut_write_opts(doc, 0, nullptr, &len, &err); - if (!json) { - yyjson_mut_doc_free(doc); - throw SerializationException("Failed to write JSON string: %s", err.msg); - } - - // Create a string from the JSON - duckdb_parquet::KeyValue kv; - kv.__set_key("geo"); - kv.__set_value(string(json, len)); - - // Free the JSON and the document - free(json); - yyjson_mut_doc_free(doc); - - file_meta_data.key_value_metadata.push_back(kv); - file_meta_data.__isset.key_value_metadata = true; -} - -bool GeoParquetFileMetadata::IsGeometryColumn(const string &column_name) const { - return geometry_columns.find(column_name) != geometry_columns.end(); -} - -void GeoParquetFileMetadata::RegisterGeometryColumn(const string &column_name) { - lock_guard glock(write_lock); - if (primary_geometry_column.empty()) { - primary_geometry_column = column_name; - } - geometry_columns[column_name] = GeoParquetColumnMetadata(); -} - -bool GeoParquetFileMetadata::IsGeoParquetConversionEnabled(const ClientContext &context) { - Value geoparquet_enabled; - if (!context.TryGetCurrentSetting("enable_geoparquet_conversion", geoparquet_enabled)) { - return false; - } - if (!geoparquet_enabled.GetValue()) { - // Disabled by setting - return false; - } - if (!context.db->ExtensionIsLoaded("spatial")) { - // Spatial extension is not loaded, we cant convert anyway - return false; - } - return true; -} - -unique_ptr GeoParquetFileMetadata::CreateColumnReader(ParquetReader &reader, - const LogicalType &logical_type, - const SchemaElement &s_ele, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p, - ClientContext &context) { - - D_ASSERT(IsGeometryColumn(s_ele.name)); - - const auto &column = geometry_columns[s_ele.name]; - - // Get the catalog - auto &catalog = Catalog::GetSystemCatalog(context); - - // WKB encoding - if (logical_type.id() == LogicalTypeId::BLOB && column.geometry_encoding == GeoParquetColumnEncoding::WKB) { - // Look for a conversion function in the catalog - auto &conversion_func_set = - catalog.GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, DEFAULT_SCHEMA, "st_geomfromwkb") - .Cast(); - auto conversion_func = conversion_func_set.functions.GetFunctionByArguments(context, {LogicalType::BLOB}); - - // Create a bound function call expression - auto args = vector>(); - args.push_back(std::move(make_uniq(LogicalType::BLOB, 0))); - auto expr = - make_uniq(conversion_func.return_type, conversion_func, std::move(args), nullptr); - - // Create a child reader - auto child_reader = - ColumnReader::CreateReader(reader, logical_type, s_ele, schema_idx_p, max_define_p, max_repeat_p); - - // Create an expression reader that applies the conversion function to the child reader - return make_uniq(context, std::move(child_reader), std::move(expr)); - } - - // Otherwise, unrecognized encoding - throw NotImplementedException("Unsupported geometry encoding"); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/boolean_column_reader.hpp b/src/duckdb/extension/parquet/include/boolean_column_reader.hpp deleted file mode 100644 index c37c62098..000000000 --- a/src/duckdb/extension/parquet/include/boolean_column_reader.hpp +++ /dev/null @@ -1,71 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// 
boolean_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -struct BooleanParquetValueConversion; - -class BooleanColumnReader : public TemplatedColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::BOOL; - -public: - BooleanColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, - max_define_p, max_repeat_p), - byte_pos(0) {}; - - uint8_t byte_pos; - - void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override { - byte_pos = 0; - TemplatedColumnReader::InitializeRead(row_group_idx_p, columns, - protocol_p); - } - - void ResetPage() override { - byte_pos = 0; - } -}; - -struct BooleanParquetValueConversion { - static bool PlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - plain_data.available(1); - return UnsafePlainRead(plain_data, reader); - } - - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - PlainRead(plain_data, reader); - } - - static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) { - return plain_data.check_available((count + 7) / 8); - } - - static bool UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) { - auto &byte_pos = reader.Cast().byte_pos; - bool ret = (*plain_data.ptr >> byte_pos) & 1; - if (++byte_pos == 8) { - byte_pos = 0; - plain_data.unsafe_inc(1); - } - return ret; - } - - static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) { - UnsafePlainRead(plain_data, reader); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/callback_column_reader.hpp b/src/duckdb/extension/parquet/include/callback_column_reader.hpp deleted file mode 100644 index 45c3e726e..000000000 --- a/src/duckdb/extension/parquet/include/callback_column_reader.hpp +++ /dev/null @@ -1,47 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// callback_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" -#include "parquet_reader.hpp" - -namespace duckdb { - -template -class CallbackColumnReader - : public TemplatedColumnReader> { - using BaseType = - TemplatedColumnReader>; - -public: - static constexpr const PhysicalType TYPE = PhysicalType::INVALID; - -public: - CallbackColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : TemplatedColumnReader>( - reader, std::move(type_p), schema_p, file_idx_p, max_define_p, max_repeat_p) { - } - -protected: - void Dictionary(shared_ptr dictionary_data, idx_t num_entries) { - BaseType::AllocateDict(num_entries * sizeof(DUCKDB_PHYSICAL_TYPE)); - auto dict_ptr = (DUCKDB_PHYSICAL_TYPE *)this->dict->ptr; - for (idx_t i = 0; i < num_entries; i++) { - dict_ptr[i] = FUNC(dictionary_data->read()); - } - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/cast_column_reader.hpp b/src/duckdb/extension/parquet/include/cast_column_reader.hpp deleted file mode 100644 index 640a77bda..000000000 --- 
a/src/duckdb/extension/parquet/include/cast_column_reader.hpp +++ /dev/null @@ -1,50 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// cast_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -//! A column reader that represents a cast over a child reader -class CastColumnReader : public ColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::INVALID; - -public: - CastColumnReader(unique_ptr child_reader, LogicalType target_type); - - unique_ptr child_reader; - DataChunk intermediate_chunk; - -public: - unique_ptr Stats(idx_t row_group_idx_p, const vector &columns) override; - void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override; - - idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result) override; - - void Skip(idx_t num_values) override; - idx_t GroupRowsAvailable() override; - - uint64_t TotalCompressedSize() override { - return child_reader->TotalCompressedSize(); - } - - idx_t FileOffset() const override { - return child_reader->FileOffset(); - } - - void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { - child_reader->RegisterPrefetch(transport, allow_merge); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/column_reader.hpp b/src/duckdb/extension/parquet/include/column_reader.hpp deleted file mode 100644 index 23d4fc3d4..000000000 --- a/src/duckdb/extension/parquet/include/column_reader.hpp +++ /dev/null @@ -1,219 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#include "parquet_bss_decoder.hpp" -#include "parquet_dbp_decoder.hpp" -#include "parquet_rle_bp_decoder.hpp" -#include "parquet_statistics.hpp" -#include "parquet_types.h" -#include "resizable_buffer.hpp" -#include "thrift_tools.hpp" -#ifndef DUCKDB_AMALGAMATION - -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types/vector_cache.hpp" -#endif - -namespace duckdb { -class ParquetReader; - -using duckdb_apache::thrift::protocol::TProtocol; - -using duckdb_parquet::ColumnChunk; -using duckdb_parquet::CompressionCodec; -using duckdb_parquet::FieldRepetitionType; -using duckdb_parquet::PageHeader; -using duckdb_parquet::SchemaElement; -using duckdb_parquet::Type; - -typedef std::bitset parquet_filter_t; - -class ColumnReader { -public: - ColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t file_idx_p, - idx_t max_define_p, idx_t max_repeat_p); - virtual ~ColumnReader(); - -public: - static unique_ptr CreateReader(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t schema_idx_p, idx_t max_define, - idx_t max_repeat); - virtual void InitializeRead(idx_t row_group_index, const vector &columns, TProtocol &protocol_p); - virtual idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result_out); - - virtual void Skip(idx_t num_values); - - ParquetReader 
&Reader(); - const LogicalType &Type() const; - const SchemaElement &Schema() const; - optional_ptr GetParentSchema() const; - void SetParentSchema(const SchemaElement &parent_schema); - - idx_t FileIdx() const; - idx_t MaxDefine() const; - idx_t MaxRepeat() const; - - virtual idx_t FileOffset() const; - virtual uint64_t TotalCompressedSize(); - virtual idx_t GroupRowsAvailable(); - - // register the range this reader will touch for prefetching - virtual void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge); - - virtual unique_ptr Stats(idx_t row_group_idx_p, const vector &columns); - - template - void PlainTemplated(shared_ptr plain_data, uint8_t *defines, uint64_t num_values, - parquet_filter_t *filter, idx_t result_offset, Vector &result) { - if (HasDefines()) { - if (CONVERSION::PlainAvailable(*plain_data, num_values)) { - PlainTemplatedInternal(*plain_data, defines, num_values, filter, - result_offset, result); - } else { - PlainTemplatedInternal(*plain_data, defines, num_values, filter, - result_offset, result); - } - } else { - if (CONVERSION::PlainAvailable(*plain_data, num_values)) { - PlainTemplatedInternal(*plain_data, defines, num_values, filter, - result_offset, result); - } else { - PlainTemplatedInternal(*plain_data, defines, num_values, filter, - result_offset, result); - } - } - } - -private: - template - void PlainTemplatedInternal(ByteBuffer &plain_data, const uint8_t *__restrict defines, const uint64_t num_values, - const parquet_filter_t *filter, const idx_t result_offset, Vector &result) { - const auto result_ptr = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - for (idx_t row_idx = result_offset; row_idx < result_offset + num_values; row_idx++) { - if (HAS_DEFINES && defines && defines[row_idx] != max_define) { - result_mask.SetInvalid(row_idx); - } else if (!filter || filter->test(row_idx)) { - result_ptr[row_idx] = - UNSAFE ? 
CONVERSION::UnsafePlainRead(plain_data, *this) : CONVERSION::PlainRead(plain_data, *this); - } else { // there is still some data there that we have to skip over - if (UNSAFE) { - CONVERSION::UnsafePlainSkip(plain_data, *this); - } else { - CONVERSION::PlainSkip(plain_data, *this); - } - } - } - } - -protected: - Allocator &GetAllocator(); - // readers that use the default Read() need to implement those - virtual void Plain(shared_ptr plain_data, uint8_t *defines, idx_t num_values, parquet_filter_t *filter, - idx_t result_offset, Vector &result); - // these are nops for most types, but not for strings - virtual void PlainReference(shared_ptr, Vector &result); - - virtual void PrepareDeltaLengthByteArray(ResizeableBuffer &buffer); - virtual void PrepareDeltaByteArray(ResizeableBuffer &buffer); - virtual void DeltaByteArray(uint8_t *defines, idx_t num_values, parquet_filter_t &filter, idx_t result_offset, - Vector &result); - - // applies any skips that were registered using Skip() - virtual void ApplyPendingSkips(idx_t num_values); - - bool HasDefines() const { - return max_define > 0; - } - - bool HasRepeats() const { - return max_repeat > 0; - } - -protected: - const SchemaElement &schema; - optional_ptr parent_schema; - - idx_t file_idx; - idx_t max_define; - idx_t max_repeat; - - ParquetReader &reader; - LogicalType type; - unique_ptr byte_array_data; - idx_t byte_array_count = 0; - - idx_t pending_skips = 0; - - virtual void ResetPage(); - -private: - void AllocateBlock(idx_t size); - void AllocateCompressed(idx_t size); - void PrepareRead(parquet_filter_t &filter); - void PreparePage(PageHeader &page_hdr); - void PrepareDataPage(PageHeader &page_hdr); - void PreparePageV2(PageHeader &page_hdr); - void DecompressInternal(CompressionCodec::type codec, const_data_ptr_t src, idx_t src_size, data_ptr_t dst, - idx_t dst_size); - void ConvertDictToSelVec(uint32_t *offsets, uint8_t *defines, parquet_filter_t &filter, idx_t read_now, - idx_t result_offset); - const ColumnChunk *chunk = nullptr; - - TProtocol *protocol; - idx_t page_rows_available; - idx_t group_rows_available; - idx_t chunk_read_offset; - - shared_ptr block; - - ResizeableBuffer compressed_buffer; - ResizeableBuffer offset_buffer; - - unique_ptr dict_decoder; - unique_ptr defined_decoder; - unique_ptr repeated_decoder; - unique_ptr dbp_decoder; - unique_ptr rle_decoder; - unique_ptr bss_decoder; - - // dummies for Skip() - parquet_filter_t none_filter; - ResizeableBuffer dummy_define; - ResizeableBuffer dummy_repeat; - - SelectionVector dictionary_selection_vector; - idx_t dictionary_size; - unique_ptr dictionary; - string dictionary_id; - -public: - template - TARGET &Cast() { - if (TARGET::TYPE != PhysicalType::INVALID && type.InternalType() != TARGET::TYPE) { - throw InternalException("Failed to cast column reader to type - type mismatch"); - } - return reinterpret_cast(*this); - } - - template - const TARGET &Cast() const { - if (TARGET::TYPE != PhysicalType::INVALID && type.InternalType() != TARGET::TYPE) { - throw InternalException("Failed to cast column reader to type - type mismatch"); - } - return reinterpret_cast(*this); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/column_writer.hpp b/src/duckdb/extension/parquet/include/column_writer.hpp deleted file mode 100644 index b27254f74..000000000 --- a/src/duckdb/extension/parquet/include/column_writer.hpp +++ /dev/null @@ -1,121 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// 
column_writer.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#include "parquet_types.h" - -namespace duckdb { -class MemoryStream; -class ParquetWriter; -class ColumnWriterPageState; -class BasicColumnWriterState; -struct ChildFieldIDs; -class ResizeableBuffer; -class ParquetBloomFilter; - -class ColumnWriterState { -public: - virtual ~ColumnWriterState(); - - unsafe_vector definition_levels; - unsafe_vector repetition_levels; - vector is_empty; - idx_t null_count = 0; - -public: - template - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -class ColumnWriterStatistics { -public: - virtual ~ColumnWriterStatistics(); - - virtual bool HasStats(); - virtual string GetMin(); - virtual string GetMax(); - virtual string GetMinValue(); - virtual string GetMaxValue(); - -public: - template - TARGET &Cast() { - DynamicCastCheck(this); - return reinterpret_cast(*this); - } - template - const TARGET &Cast() const { - D_ASSERT(dynamic_cast(this)); - return reinterpret_cast(*this); - } -}; - -class ColumnWriter { - -public: - ColumnWriter(ParquetWriter &writer, idx_t schema_idx, vector schema_path, idx_t max_repeat, - idx_t max_define, bool can_have_nulls); - virtual ~ColumnWriter(); - - ParquetWriter &writer; - idx_t schema_idx; - vector schema_path; - idx_t max_repeat; - idx_t max_define; - bool can_have_nulls; - -public: - //! Create the column writer for a specific type recursively - static unique_ptr - CreateWriterRecursive(ClientContext &context, vector &schemas, ParquetWriter &writer, - const LogicalType &type, const string &name, vector schema_path, - optional_ptr field_ids, idx_t max_repeat = 0, idx_t max_define = 1, - bool can_have_nulls = true); - - virtual unique_ptr InitializeWriteState(duckdb_parquet::RowGroup &row_group) = 0; - - //! indicates whether the write need to analyse the data before preparing it - virtual bool HasAnalyze() { - return false; - } - - virtual void Analyze(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) { - throw NotImplementedException("Writer does not need analysis"); - } - - //! 
Called after all data has been passed to Analyze - virtual void FinalizeAnalyze(ColumnWriterState &state) { - throw NotImplementedException("Writer does not need analysis"); - } - - virtual void Prepare(ColumnWriterState &state, ColumnWriterState *parent, Vector &vector, idx_t count) = 0; - - virtual void BeginWrite(ColumnWriterState &state) = 0; - virtual void Write(ColumnWriterState &state, Vector &vector, idx_t count) = 0; - virtual void FinalizeWrite(ColumnWriterState &state) = 0; - -protected: - void HandleDefineLevels(ColumnWriterState &state, ColumnWriterState *parent, const ValidityMask &validity, - const idx_t count, const uint16_t define_value, const uint16_t null_value) const; - void HandleRepeatLevels(ColumnWriterState &state_p, ColumnWriterState *parent, idx_t count, idx_t max_repeat) const; - - void CompressPage(MemoryStream &temp_writer, size_t &compressed_size, data_ptr_t &compressed_data, - unique_ptr &compressed_buf); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/decode_utils.hpp b/src/duckdb/extension/parquet/include/decode_utils.hpp deleted file mode 100644 index d6c4a854b..000000000 --- a/src/duckdb/extension/parquet/include/decode_utils.hpp +++ /dev/null @@ -1,178 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// decode_utils.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb/common/bitpacking.hpp" -#include "resizable_buffer.hpp" - -namespace duckdb { - -class ParquetDecodeUtils { - //===--------------------------------------------------------------------===// - // Bitpacking - //===--------------------------------------------------------------------===// -private: - static const uint64_t BITPACK_MASKS[]; - static const uint64_t BITPACK_MASKS_SIZE; - static const uint8_t BITPACK_DLEN; - - static void CheckWidth(const uint8_t width) { - if (width >= BITPACK_MASKS_SIZE) { - throw InvalidInputException("The width (%d) of the bitpacked data exceeds the supported max width (%d), " - "the file might be corrupted.", - width, BITPACK_MASKS_SIZE); - } - } - -public: - template - static void BitUnpack(ByteBuffer &src, bitpacking_width_t &bitpack_pos, T *dst, idx_t count, - const bitpacking_width_t width) { - CheckWidth(width); - const auto mask = BITPACK_MASKS[width]; - src.available(count * width / BITPACK_DLEN); // check if buffer has enough space available once - if (bitpack_pos == 0 && count >= BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE) { - idx_t remainder = count % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - idx_t aligned_count = count - remainder; - BitUnpackAlignedInternal(src, dst, aligned_count, width); - dst += aligned_count; - count = remainder; - } - for (idx_t i = 0; i < count; i++) { - auto val = (src.unsafe_get() >> bitpack_pos) & mask; - bitpack_pos += width; - while (bitpack_pos > BITPACK_DLEN) { - src.unsafe_inc(1); - val |= (static_cast(src.unsafe_get()) - << static_cast(BITPACK_DLEN - (bitpack_pos - width))) & - mask; - bitpack_pos -= BITPACK_DLEN; - } - dst[i] = val; - } - } - - template - static void BitPackAligned(T *src, data_ptr_t dst, const idx_t count, const bitpacking_width_t width) { - D_ASSERT(width < BITPACK_MASKS_SIZE); - D_ASSERT(count % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0); - BitpackingPrimitives::PackBuffer(dst, src, count, width); - } - - template - static void BitUnpackAlignedInternal(ByteBuffer &src, T *dst, const idx_t 
count, const bitpacking_width_t width) { - for (idx_t i = 0; i < count; i += BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE) { - const auto next_read = BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE * width / 8; - - // Buffer for alignment - T aligned_data[BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE]; - - // Copy over to aligned buffer - memcpy(aligned_data, src.ptr, next_read); - - // Unpack - BitpackingPrimitives::UnPackBlock(data_ptr_cast(dst), data_ptr_cast(aligned_data), width, true); - - src.unsafe_inc(next_read); - dst += BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - } - } - - template - static void BitUnpackAligned(ByteBuffer &src, T *dst, const idx_t count, const bitpacking_width_t width) { - CheckWidth(width); - if (count % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE != 0) { - throw InvalidInputException("Aligned bitpacking count must be a multiple of %llu", - BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE); - } - const auto read_size = count * width / BITPACK_DLEN; - src.available(read_size); // check if buffer has enough space available once - BitUnpackAlignedInternal(src, dst, count, width); - } - - //===--------------------------------------------------------------------===// - // Zigzag - //===--------------------------------------------------------------------===// -private: - //! https://lemire.me/blog/2022/11/25/making-all-your-integers-positive-with-zigzag-encoding/ - template - static typename std::enable_if::value, typename std::make_signed::type>::type - ZigzagToIntInternal(UNSIGNED x) { - return (x >> 1) ^ (-(x & 1)); - } - - template - static typename std::enable_if::value, typename std::make_unsigned::type>::type - IntToZigzagInternal(SIGNED x) { - using UNSIGNED = typename std::make_unsigned::type; - return (static_cast(x) << 1) ^ static_cast(x >> (sizeof(SIGNED) * 8 - 1)); - } - -public: - template - static typename std::enable_if::value, typename std::make_signed::type>::type - ZigzagToInt(UNSIGNED x) { - auto integer = ZigzagToIntInternal(x); - D_ASSERT(x == IntToZigzagInternal(integer)); // test roundtrip - return integer; - } - - template - static typename std::enable_if::value, typename std::make_unsigned::type>::type - IntToZigzag(SIGNED x) { - auto zigzag = IntToZigzagInternal(x); - D_ASSERT(x == ZigzagToIntInternal(zigzag)); // test roundtrip - return zigzag; - } - - //===--------------------------------------------------------------------===// - // Varint - //===--------------------------------------------------------------------===// -public: - template - static uint8_t GetVarintSize(T val) { - uint8_t res = 0; - do { - val >>= 7; - res++; - } while (val != 0); - return res; - } - - template - static void VarintEncode(T val, WriteStream &ser) { - do { - uint8_t byte = val & 127; - val >>= 7; - if (val != 0) { - byte |= 128; - } - ser.Write(byte); - } while (val != 0); - } - - template - static T VarintDecode(ByteBuffer &buf) { - T result = 0; - uint8_t shift = 0; - while (true) { - auto byte = buf.read(); - result |= T(byte & 127) << shift; - if ((byte & 128) == 0) { - break; - } - shift += 7; - if (shift > sizeof(T) * 8) { - throw std::runtime_error("Varint-decoding found too large number"); - } - } - return result; - } -}; -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/expression_column_reader.hpp b/src/duckdb/extension/parquet/include/expression_column_reader.hpp deleted file mode 100644 index c94a816d3..000000000 --- 
a/src/duckdb/extension/parquet/include/expression_column_reader.hpp +++ /dev/null @@ -1,52 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// expression_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -//! A column reader that executes an expression over a child reader -class ExpressionColumnReader : public ColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::INVALID; - -public: - ExpressionColumnReader(ClientContext &context, unique_ptr child_reader, unique_ptr expr); - - unique_ptr child_reader; - DataChunk intermediate_chunk; - unique_ptr expr; - ExpressionExecutor executor; - -public: - unique_ptr Stats(idx_t row_group_idx_p, const vector &columns) override; - void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override; - - idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result) override; - - void Skip(idx_t num_values) override; - idx_t GroupRowsAvailable() override; - - uint64_t TotalCompressedSize() override { - return child_reader->TotalCompressedSize(); - } - - idx_t FileOffset() const override { - return child_reader->FileOffset(); - } - - void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { - child_reader->RegisterPrefetch(transport, allow_merge); - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/geo_parquet.hpp b/src/duckdb/extension/parquet/include/geo_parquet.hpp deleted file mode 100644 index 0a9b0966f..000000000 --- a/src/duckdb/extension/parquet/include/geo_parquet.hpp +++ /dev/null @@ -1,146 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// geo_parquet.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_writer.hpp" -#include "duckdb/common/string.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "parquet_types.h" - -namespace duckdb { - -enum class WKBGeometryType : uint16_t { - POINT = 1, - LINESTRING = 2, - POLYGON = 3, - MULTIPOINT = 4, - MULTILINESTRING = 5, - MULTIPOLYGON = 6, - GEOMETRYCOLLECTION = 7, - - POINT_Z = 1001, - LINESTRING_Z = 1002, - POLYGON_Z = 1003, - MULTIPOINT_Z = 1004, - MULTILINESTRING_Z = 1005, - MULTIPOLYGON_Z = 1006, - GEOMETRYCOLLECTION_Z = 1007, -}; - -struct WKBGeometryTypes { - static const char *ToString(WKBGeometryType type); -}; - -struct GeometryBounds { - double min_x = NumericLimits::Maximum(); - double max_x = NumericLimits::Minimum(); - double min_y = NumericLimits::Maximum(); - double max_y = NumericLimits::Minimum(); - - GeometryBounds() = default; - - void Combine(const GeometryBounds &other) { - min_x = std::min(min_x, other.min_x); - max_x = std::max(max_x, other.max_x); - min_y = std::min(min_y, other.min_y); - max_y = std::max(max_y, other.max_y); - } - - void Combine(const double &x, const double &y) { - min_x = std::min(min_x, x); - max_x = std::max(max_x, x); - min_y = std::min(min_y, y); - max_y = std::max(max_y, y); - } - - void Combine(const double &min_x, const double &max_x, const double &min_y, const double &max_y) { - this->min_x = std::min(this->min_x, min_x); - 
this->max_x = std::max(this->max_x, max_x); - this->min_y = std::min(this->min_y, min_y); - this->max_y = std::max(this->max_y, max_y); - } -}; - -//------------------------------------------------------------------------------ -// GeoParquetMetadata -//------------------------------------------------------------------------------ -class ParquetReader; -class ColumnReader; -class ClientContext; -class ExpressionExecutor; - -enum class GeoParquetColumnEncoding : uint8_t { - WKB = 1, - POINT, - LINESTRING, - POLYGON, - MULTIPOINT, - MULTILINESTRING, - MULTIPOLYGON, -}; - -struct GeoParquetColumnMetadata { - // The encoding of the geometry column - GeoParquetColumnEncoding geometry_encoding; - - // The geometry types that are present in the column - set geometry_types; - - // The bounds of the geometry column - GeometryBounds bbox; - - // The crs of the geometry column (if any) in PROJJSON format - string projjson; -}; - -class GeoParquetColumnMetadataWriter { - unique_ptr executor; - DataChunk input_chunk; - DataChunk result_chunk; - - unique_ptr type_expr; - unique_ptr flag_expr; - unique_ptr bbox_expr; - -public: - explicit GeoParquetColumnMetadataWriter(ClientContext &context); - void Update(GeoParquetColumnMetadata &meta, Vector &vector, idx_t count); -}; - -class GeoParquetFileMetadata { -public: - // Try to read GeoParquet metadata. Returns nullptr if not found, invalid or the required spatial extension is not - // available. - - static unique_ptr TryRead(const duckdb_parquet::FileMetaData &file_meta_data, - const ClientContext &context); - void Write(duckdb_parquet::FileMetaData &file_meta_data) const; - - void FlushColumnMeta(const string &column_name, const GeoParquetColumnMetadata &meta); - const unordered_map &GetColumnMeta() const; - - unique_ptr CreateColumnReader(ParquetReader &reader, const LogicalType &logical_type, - const duckdb_parquet::SchemaElement &s_ele, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p, ClientContext &context); - - bool IsGeometryColumn(const string &column_name) const; - void RegisterGeometryColumn(const string &column_name); - - static bool IsGeoParquetConversionEnabled(const ClientContext &context); - -private: - mutex write_lock; - string version = "1.1.0"; - string primary_geometry_column; - unordered_map geometry_columns; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/list_column_reader.hpp b/src/duckdb/extension/parquet/include/list_column_reader.hpp deleted file mode 100644 index 67565dfbf..000000000 --- a/src/duckdb/extension/parquet/include/list_column_reader.hpp +++ /dev/null @@ -1,60 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// list_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -class ListColumnReader : public ColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::LIST; - -public: - ListColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p, unique_ptr child_column_reader_p); - - idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result_out) override; - - void ApplyPendingSkips(idx_t num_values) override; - - void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol 
&protocol_p) override { - child_column_reader->InitializeRead(row_group_idx_p, columns, protocol_p); - } - - idx_t GroupRowsAvailable() override { - return child_column_reader->GroupRowsAvailable() + overflow_child_count; - } - - uint64_t TotalCompressedSize() override { - return child_column_reader->TotalCompressedSize(); - } - - void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { - child_column_reader->RegisterPrefetch(transport, allow_merge); - } - -private: - unique_ptr child_column_reader; - ResizeableBuffer child_defines; - ResizeableBuffer child_repeats; - uint8_t *child_defines_ptr; - uint8_t *child_repeats_ptr; - - VectorCache read_cache; - Vector read_vector; - - parquet_filter_t child_filter; - - idx_t overflow_child_count; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/null_column_reader.hpp b/src/duckdb/extension/parquet/include/null_column_reader.hpp deleted file mode 100644 index 6d89c906b..000000000 --- a/src/duckdb/extension/parquet/include/null_column_reader.hpp +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// null_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "duckdb/common/helper.hpp" - -namespace duckdb { - -class NullColumnReader : public ColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::INVALID; - -public: - NullColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p) - : ColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, max_define_p, max_repeat_p) {}; - - shared_ptr dict; - -public: - void Plain(shared_ptr plain_data, uint8_t *defines, uint64_t num_values, parquet_filter_t *filter, - idx_t result_offset, Vector &result) override { - (void)defines; - (void)plain_data; - (void)filter; - - auto &result_mask = FlatVector::Validity(result); - for (idx_t row_idx = 0; row_idx < num_values; row_idx++) { - result_mask.SetInvalid(row_idx + result_offset); - } - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_bss_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_bss_decoder.hpp deleted file mode 100644 index b8cd8d11c..000000000 --- a/src/duckdb/extension/parquet/include/parquet_bss_decoder.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_bss_decoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once -#include "parquet_types.h" -#include "resizable_buffer.hpp" - -namespace duckdb { - -/// Decoder for the Byte Stream Split encoding -class BssDecoder { -public: - /// Create a decoder object. buffer/buffer_len is the encoded data. 
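An aside before the constructor: a self-contained round-trip sketch of the BYTE_STREAM_SPLIT layout this decoder reverses (an illustration only, not part of the deleted header). The encoding stores the k-th byte of every value contiguously, which often compresses better for floating-point data:

#include <cstdint>
#include <cstring>
#include <vector>

// Scatter the b-th byte of each value into byte stream b (encode), then
// gather the streams back into values (decode). With n values of type float
// there are sizeof(float) streams of n bytes each, laid out back to back.
int main() {
	constexpr size_t n = 3, w = sizeof(float);
	const float input[n] = {1.5f, -2.25f, 8.0f};
	std::vector<uint8_t> split(n * w);
	for (size_t i = 0; i < n; i++) {
		uint8_t bytes[w];
		std::memcpy(bytes, &input[i], w);
		for (size_t b = 0; b < w; b++) {
			split[b * n + i] = bytes[b]; // stream b, position i
		}
	}
	float output[n];
	for (size_t i = 0; i < n; i++) {
		uint8_t bytes[w];
		for (size_t b = 0; b < w; b++) {
			bytes[b] = split[b * n + i];
		}
		std::memcpy(&output[i], bytes, w);
	}
	return std::memcmp(input, output, sizeof(input)) == 0 ? 0 : 1; // round-trips
}

Note how GetBatch above indexes buffer_.ptr + byte_offset * num_buffer_values + value_offset_: that is exactly the stream-major layout built here.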
- BssDecoder(data_ptr_t buffer, uint32_t buffer_len) : buffer_(buffer, buffer_len), value_offset_(0) { - } - -public: - template - void GetBatch(data_ptr_t values_target_ptr, uint32_t batch_size) { - if (buffer_.len % sizeof(T) != 0) { - std::stringstream error; - error << "Data buffer size for the BYTE_STREAM_SPLIT encoding (" << buffer_.len - << ") should be a multiple of the type size (" << sizeof(T) << ")"; - throw std::runtime_error(error.str()); - } - uint32_t num_buffer_values = buffer_.len / sizeof(T); - - buffer_.available((value_offset_ + batch_size) * sizeof(T)); - - for (uint32_t byte_offset = 0; byte_offset < sizeof(T); ++byte_offset) { - data_ptr_t input_bytes = buffer_.ptr + byte_offset * num_buffer_values + value_offset_; - for (uint32_t i = 0; i < batch_size; ++i) { - values_target_ptr[byte_offset + i * sizeof(T)] = *(input_bytes + i); - } - } - value_offset_ += batch_size; - } - -private: - ByteBuffer buffer_; - uint32_t value_offset_; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_bss_encoder.hpp b/src/duckdb/extension/parquet/include/parquet_bss_encoder.hpp deleted file mode 100644 index 80da1726d..000000000 --- a/src/duckdb/extension/parquet/include/parquet_bss_encoder.hpp +++ /dev/null @@ -1,45 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_bss_encoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "decode_utils.hpp" - -namespace duckdb { - -class BssEncoder { -public: - explicit BssEncoder(const idx_t total_value_count_p, const idx_t bit_width_p) - : total_value_count(total_value_count_p), bit_width(bit_width_p), count(0), - buffer(Allocator::DefaultAllocator().Allocate(total_value_count * bit_width + 1)) { - } - -public: - template - void WriteValue(const T &value) { - D_ASSERT(sizeof(T) == bit_width); - for (idx_t i = 0; i < sizeof(T); i++) { - buffer.get()[i * total_value_count + count] = reinterpret_cast(&value)[i]; - } - count++; - } - - void FinishWrite(WriteStream &writer) { - D_ASSERT(count == total_value_count); - writer.WriteData(buffer.get(), total_value_count * bit_width); - } - -private: - const idx_t total_value_count; - const idx_t bit_width; - - idx_t count; - AllocatedData buffer; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_crypto.hpp b/src/duckdb/extension/parquet/include/parquet_crypto.hpp deleted file mode 100644 index 470648446..000000000 --- a/src/duckdb/extension/parquet/include/parquet_crypto.hpp +++ /dev/null @@ -1,92 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_crypto.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "parquet_types.h" -#include "duckdb/common/encryption_state.hpp" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/storage/object_cache.hpp" -#endif - -namespace duckdb { - -using duckdb_apache::thrift::TBase; -using duckdb_apache::thrift::protocol::TProtocol; - -class BufferedFileWriter; - -class ParquetKeys : public ObjectCacheEntry { -public: - static ParquetKeys &Get(ClientContext &context); - -public: - void AddKey(const string &key_name, const string &key); - bool HasKey(const string &key_name) const; - const string &GetKey(const string &key_name) const; - -public: - static string ObjectType(); - string GetObjectType() override; - -private: - unordered_map 
keys; -}; - -class ParquetEncryptionConfig { -public: - explicit ParquetEncryptionConfig(ClientContext &context); - ParquetEncryptionConfig(ClientContext &context, const Value &arg); - -public: - static shared_ptr Create(ClientContext &context, const Value &arg); - const string &GetFooterKey() const; - -public: - void Serialize(Serializer &serializer) const; - static shared_ptr Deserialize(Deserializer &deserializer); - -private: - ClientContext &context; - //! Name of the key used for the footer - string footer_key; - //! Mapping from column name to key name - unordered_map column_keys; -}; - -class ParquetCrypto { -public: - //! Encrypted modules - static constexpr idx_t LENGTH_BYTES = 4; - static constexpr idx_t NONCE_BYTES = 12; - static constexpr idx_t TAG_BYTES = 16; - - //! Block size we encrypt/decrypt - static constexpr idx_t CRYPTO_BLOCK_SIZE = 4096; - static constexpr idx_t BLOCK_SIZE = 16; - -public: - //! Decrypt and read a Thrift object from the transport protocol - static uint32_t Read(TBase &object, TProtocol &iprot, const string &key, const EncryptionUtil &encryption_util_p); - //! Encrypt and write a Thrift object to the transport protocol - static uint32_t Write(const TBase &object, TProtocol &oprot, const string &key, - const EncryptionUtil &encryption_util_p); - //! Decrypt and read a buffer - static uint32_t ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, const string &key, - const EncryptionUtil &encryption_util_p); - //! Encrypt and write a buffer to a file - static uint32_t WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, - const string &key, const EncryptionUtil &encryption_util_p); - -public: - static void AddKey(ClientContext &context, const FunctionParameters ¶meters); - static bool ValidKey(const std::string &key); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp deleted file mode 100644 index 4925a0ff9..000000000 --- a/src/duckdb/extension/parquet/include/parquet_dbp_decoder.hpp +++ /dev/null @@ -1,137 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_dbp_deccoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "decode_utils.hpp" - -namespace duckdb { - -class DbpDecoder { -public: - DbpDecoder(const data_ptr_t buffer, const uint32_t buffer_len) - : buffer_(buffer, buffer_len), - // - block_size_in_values(ParquetDecodeUtils::VarintDecode(buffer_)), - number_of_miniblocks_per_block(ParquetDecodeUtils::VarintDecode(buffer_)), - number_of_values_in_a_miniblock(block_size_in_values / number_of_miniblocks_per_block), - total_value_count(ParquetDecodeUtils::VarintDecode(buffer_)), - previous_value(ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_))), - // init state to something sane - is_first_value(true), read_values(0), min_delta(NumericLimits::Maximum()), - miniblock_index(number_of_miniblocks_per_block - 1), list_of_bitwidths_of_miniblocks(nullptr), - miniblock_offset(number_of_values_in_a_miniblock), - unpacked_data_offset(BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE) { - if (!(block_size_in_values % number_of_miniblocks_per_block == 0 && - number_of_values_in_a_miniblock % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0)) { - throw InvalidInputException("Parquet file has invalid block sizes for 
DELTA_BINARY_PACKED"); - } - }; - - ByteBuffer BufferPtr() const { - return buffer_; - } - - uint64_t TotalValues() const { - return total_value_count; - } - - template - void GetBatch(const data_ptr_t target_values_ptr, const idx_t batch_size) { - if (read_values + batch_size > total_value_count) { - throw std::runtime_error("DBP decode did not find enough values"); - } - read_values += batch_size; - GetBatchInternal(target_values_ptr, batch_size); - } - - void Finalize() { - if (miniblock_offset == number_of_values_in_a_miniblock) { - return; - } - auto data = make_unsafe_uniq_array(number_of_values_in_a_miniblock); - GetBatchInternal(data_ptr_cast(data.get()), number_of_values_in_a_miniblock - miniblock_offset); - } - -private: - template - void GetBatchInternal(const data_ptr_t target_values_ptr, const idx_t batch_size) { - if (batch_size == 0) { - return; - } - - auto target_values = reinterpret_cast(target_values_ptr); - idx_t target_values_offset = 0; - if (is_first_value) { - target_values[0] = static_cast(previous_value); - target_values_offset++; - is_first_value = false; - } - - while (target_values_offset < batch_size) { - // Copy over any remaining data - const idx_t next = MinValue(batch_size - target_values_offset, - BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE - unpacked_data_offset); - if (next != 0) { - for (idx_t i = 0; i < next; i++) { - auto &target = target_values[target_values_offset + i]; - const auto &unpacked_value = unpacked_data[unpacked_data_offset + i]; - target = static_cast(static_cast(previous_value) + static_cast(min_delta) + - unpacked_value); - previous_value = static_cast(target); - } - target_values_offset += next; - unpacked_data_offset += next; - continue; - } - - // Move to next miniblock / block - D_ASSERT(unpacked_data_offset == BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE); - D_ASSERT(miniblock_index < number_of_miniblocks_per_block); - D_ASSERT(miniblock_offset <= number_of_values_in_a_miniblock); - if (miniblock_offset == number_of_values_in_a_miniblock) { - miniblock_offset = 0; - if (++miniblock_index == number_of_miniblocks_per_block) { - // - min_delta = ParquetDecodeUtils::ZigzagToInt(ParquetDecodeUtils::VarintDecode(buffer_)); - buffer_.available(number_of_miniblocks_per_block); - list_of_bitwidths_of_miniblocks = buffer_.ptr; - buffer_.unsafe_inc(number_of_miniblocks_per_block); - miniblock_index = 0; - } - } - - // Unpack from current miniblock - ParquetDecodeUtils::BitUnpackAligned(buffer_, unpacked_data, - BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE, - list_of_bitwidths_of_miniblocks[miniblock_index]); - unpacked_data_offset = 0; - miniblock_offset += BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE; - } - } - -private: - ByteBuffer buffer_; - const idx_t block_size_in_values; - const idx_t number_of_miniblocks_per_block; - const idx_t number_of_values_in_a_miniblock; - const idx_t total_value_count; - int64_t previous_value; - - bool is_first_value; - idx_t read_values; - - //! 
Block stuff - int64_t min_delta; - idx_t miniblock_index; - bitpacking_width_t *list_of_bitwidths_of_miniblocks; - idx_t miniblock_offset; - uint64_t unpacked_data[BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE]; - idx_t unpacked_data_offset; -}; -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_dbp_encoder.hpp b/src/duckdb/extension/parquet/include/parquet_dbp_encoder.hpp deleted file mode 100644 index 791d10e08..000000000 --- a/src/duckdb/extension/parquet/include/parquet_dbp_encoder.hpp +++ /dev/null @@ -1,179 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_dbp_encoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "decode_utils.hpp" - -namespace duckdb { - -class DbpEncoder { -private: - static constexpr uint64_t BLOCK_SIZE_IN_VALUES = 2048; - static constexpr uint64_t NUMBER_OF_MINIBLOCKS_IN_A_BLOCK = 8; - static constexpr uint64_t NUMBER_OF_VALUES_IN_A_MINIBLOCK = BLOCK_SIZE_IN_VALUES / NUMBER_OF_MINIBLOCKS_IN_A_BLOCK; - -public: - explicit DbpEncoder(const idx_t total_value_count_p) : total_value_count(total_value_count_p), count(0) { - } - -public: - void BeginWrite(WriteStream &writer, const int64_t &first_value) { - // - - // the block size is a multiple of 128; it is stored as a ULEB128 int - ParquetDecodeUtils::VarintEncode(BLOCK_SIZE_IN_VALUES, writer); - // the miniblock count per block is a divisor of the block size such that their quotient, - // the number of values in a miniblock, is a multiple of 32 - static_assert(BLOCK_SIZE_IN_VALUES % NUMBER_OF_MINIBLOCKS_IN_A_BLOCK == 0 && - NUMBER_OF_VALUES_IN_A_MINIBLOCK % BitpackingPrimitives::BITPACKING_ALGORITHM_GROUP_SIZE == 0, - "invalid block sizes for DELTA_BINARY_PACKED"); - // it is stored as a ULEB128 int - ParquetDecodeUtils::VarintEncode(NUMBER_OF_MINIBLOCKS_IN_A_BLOCK, writer); - // the total value count is stored as a ULEB128 int - ParquetDecodeUtils::VarintEncode(total_value_count, writer); - // the first value is stored as a zigzag ULEB128 int - ParquetDecodeUtils::VarintEncode(ParquetDecodeUtils::IntToZigzag(first_value), writer); - - // initialize - if (total_value_count != 0) { - count++; - } - previous_value = first_value; - - min_delta = NumericLimits::Maximum(); - block_count = 0; - } - - void WriteValue(WriteStream &writer, const int64_t &value) { - // 1. Compute the differences between consecutive elements. For the first element in the block, - // use the last element in the previous block or, in the case of the first block, - // use the first value of the whole sequence, stored in the header. - - // Subtractions in steps 1) and 2) may incur signed arithmetic overflow, - // and so will the corresponding additions when decoding. - // Overflow should be allowed and handled as wrapping around in 2’s complement notation - // so that the original values are correctly restituted. - // This may require explicit care in some programming languages - // (for example by doing all arithmetic in the unsigned domain). - const auto delta = static_cast(static_cast(value) - static_cast(previous_value)); - previous_value = value; - // Compute the frame of reference (the minimum of the deltas in the block). - min_delta = MinValue(min_delta, delta); - // append. 
if block is full, write it out - data[block_count++] = delta; - if (block_count == BLOCK_SIZE_IN_VALUES) { - WriteBlock(writer); - } - } - - void FinishWrite(WriteStream &writer) { - if (count + block_count != total_value_count) { - throw InternalException("value count mismatch when writing DELTA_BINARY_PACKED"); - } - if (block_count != 0) { - WriteBlock(writer); - } - } - -private: - void WriteBlock(WriteStream &writer) { - D_ASSERT(count + block_count == total_value_count || block_count == BLOCK_SIZE_IN_VALUES); - const auto number_of_miniblocks = - (block_count + NUMBER_OF_VALUES_IN_A_MINIBLOCK - 1) / NUMBER_OF_VALUES_IN_A_MINIBLOCK; - for (idx_t miniblock_idx = 0; miniblock_idx < number_of_miniblocks; miniblock_idx++) { - for (idx_t i = 0; i < NUMBER_OF_VALUES_IN_A_MINIBLOCK; i++) { - const idx_t index = miniblock_idx * NUMBER_OF_VALUES_IN_A_MINIBLOCK + i; - auto &value = data[index]; - if (index < block_count) { - // 2. Compute the frame of reference (the minimum of the deltas in the block). - // Subtract this min delta from all deltas in the block. - // This guarantees that all values are non-negative. - D_ASSERT(min_delta <= value); - value = static_cast(static_cast(value) - static_cast(min_delta)); - } else { - // If there are not enough values to fill the last miniblock, we pad the miniblock - // so that its length is always the number of values in a full miniblock multiplied by the bit - // width. The values of the padding bits should be zero, but readers must accept paddings consisting - // of arbitrary bits as well. - value = 0; - } - } - } - - for (idx_t miniblock_idx = 0; miniblock_idx < NUMBER_OF_MINIBLOCKS_IN_A_BLOCK; miniblock_idx++) { - auto &width = list_of_bitwidths_of_miniblocks[miniblock_idx]; - if (miniblock_idx < number_of_miniblocks) { - const auto src = &data[miniblock_idx * NUMBER_OF_VALUES_IN_A_MINIBLOCK]; - width = BitpackingPrimitives::MinimumBitWidth(reinterpret_cast(src), - NUMBER_OF_VALUES_IN_A_MINIBLOCK); - D_ASSERT(width <= sizeof(int64_t) * 8); - } else { - // If, in the last block, less than miniblocks are needed to store the - // values, the bytes storing the bit widths of the unneeded miniblocks are still present, their value - // should be zero, but readers must accept arbitrary values as well. There are no additional padding - // bytes for the miniblock bodies though, as if their bit widths were 0 (regardless of the actual byte - // values). The reader knows when to stop reading by keeping track of the number of values read. - width = 0; - } - } - - // 3. Encode the frame of reference (min delta) as a zigzag ULEB128 int - // followed by the bit widths of the miniblocks - // and the delta values (minus the min delta) bit-packed per miniblock. 
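Before the implementation of step 3, a standalone aside on the "zigzag ULEB128" representation these comments keep referring to. The sketch below mirrors what ParquetDecodeUtils::IntToZigzag and VarintEncode are expected to do under the standard definitions; it is an illustration, not the deleted code:

#include <cassert>
#include <cstdint>

// Zigzag maps signed ints to unsigned so small magnitudes stay small:
// 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, ...
static uint64_t IntToZigzag(int64_t n) {
	return (static_cast<uint64_t>(n) << 1) ^ static_cast<uint64_t>(n >> 63);
}

// ULEB128: 7 payload bits per byte, high bit set while more bytes follow.
static int VarintEncode(uint64_t v, uint8_t *out) {
	int len = 0;
	do {
		uint8_t byte = v & 0x7F;
		v >>= 7;
		out[len++] = byte | (v ? 0x80 : 0);
	} while (v);
	return len;
}

int main() {
	assert(IntToZigzag(-1) == 1 && IntToZigzag(1) == 2);
	uint8_t buf[10];
	// 300 = 0b1_0010_1100 encodes as 0xAC 0x02 (low 7 bits first).
	int len = VarintEncode(300, buf);
	assert(len == 2 && buf[0] == 0xAC && buf[1] == 0x02);
	return 0;
}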
- // - - // the min delta is a zigzag ULEB128 int (we compute a minimum as we need positive integers for bit packing) - ParquetDecodeUtils::VarintEncode(ParquetDecodeUtils::IntToZigzag(min_delta), writer); - // the bitwidth of each block is stored as a byte - writer.WriteData(list_of_bitwidths_of_miniblocks, NUMBER_OF_MINIBLOCKS_IN_A_BLOCK); - // each miniblock is a list of bit packed ints according to the bit width stored at the beginning of the block - for (idx_t miniblock_idx = 0; miniblock_idx < number_of_miniblocks; miniblock_idx++) { - const auto src = &data[miniblock_idx * NUMBER_OF_VALUES_IN_A_MINIBLOCK]; - const auto &width = list_of_bitwidths_of_miniblocks[miniblock_idx]; - memset(data_packed, 0, sizeof(data_packed)); - ParquetDecodeUtils::BitPackAligned(reinterpret_cast(src), data_packed, - NUMBER_OF_VALUES_IN_A_MINIBLOCK, width); - const auto write_size = NUMBER_OF_VALUES_IN_A_MINIBLOCK * width / 8; -#ifdef DEBUG - // immediately verify that unpacking yields the input data - int64_t verification_data[NUMBER_OF_VALUES_IN_A_MINIBLOCK]; - ByteBuffer byte_buffer(data_ptr_cast(data_packed), write_size); - bitpacking_width_t bitpack_pos = 0; - ParquetDecodeUtils::BitUnpack(byte_buffer, bitpack_pos, verification_data, NUMBER_OF_VALUES_IN_A_MINIBLOCK, - width); - for (idx_t i = 0; i < NUMBER_OF_VALUES_IN_A_MINIBLOCK; i++) { - D_ASSERT(src[i] == verification_data[i]); - } -#endif - writer.WriteData(data_packed, write_size); - } - - count += block_count; - - min_delta = NumericLimits::Maximum(); - block_count = 0; - } - -private: - //! Overall fields - const idx_t total_value_count; - idx_t count; - int64_t previous_value; - - //! Block-specific fields - int64_t min_delta; - int64_t data[BLOCK_SIZE_IN_VALUES]; - idx_t block_count; - - //! Bitpacking fields - bitpacking_width_t list_of_bitwidths_of_miniblocks[NUMBER_OF_MINIBLOCKS_IN_A_BLOCK]; - data_t data_packed[NUMBER_OF_VALUES_IN_A_MINIBLOCK * sizeof(int64_t)]; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp b/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp deleted file mode 100644 index 119ed5672..000000000 --- a/src/duckdb/extension/parquet/include/parquet_decimal_utils.hpp +++ /dev/null @@ -1,58 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_decimal_utils.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -class ParquetDecimalUtils { -public: - template - static PHYSICAL_TYPE ReadDecimalValue(const_data_ptr_t pointer, idx_t size, const duckdb_parquet::SchemaElement &) { - PHYSICAL_TYPE res = 0; - - auto res_ptr = (uint8_t *)&res; - bool positive = (*pointer & 0x80) == 0; - - // numbers are stored as two's complement so some muckery is required - for (idx_t i = 0; i < MinValue(size, sizeof(PHYSICAL_TYPE)); i++) { - auto byte = *(pointer + (size - i - 1)); - res_ptr[i] = positive ? 
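// [editor's note] A worked instance of this big-endian two's-complement decode
// (illustration only): size = 2 and bytes = {0xFF, 0x85}, the two-byte encoding of
// -123. The sign bit of the first byte is set, so positive = false; the loop copies
// bytes in reverse order while flipping them, giving res_ptr[0] = 0x85 ^ 0xFF = 0x7A
// and res_ptr[1] = 0xFF ^ 0xFF = 0x00, i.e. res = 122. The tail below then applies
// res += 1 and returns -res = -123: the usual ~x + 1 negation, split across the
// copy loop and the return.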
byte : byte ^ 0xFF; - } - // Verify that there are only 0s here - if (size > sizeof(PHYSICAL_TYPE)) { - for (idx_t i = sizeof(PHYSICAL_TYPE); i < size; i++) { - auto byte = *(pointer + (size - i - 1)); - if (!positive) { - byte ^= 0xFF; - } - if (byte != 0) { - throw InvalidInputException("Invalid decimal encoding in Parquet file"); - } - } - } - if (!positive) { - res += 1; - return -res; - } - return res; - } - - static unique_ptr CreateReader(ParquetReader &reader, const LogicalType &type_p, - const SchemaElement &schema_p, idx_t file_idx_p, idx_t max_define, - idx_t max_repeat); -}; - -template <> -double ParquetDecimalUtils::ReadDecimalValue(const_data_ptr_t pointer, idx_t size, - const duckdb_parquet::SchemaElement &schema_ele); - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_dlba_encoder.hpp b/src/duckdb/extension/parquet/include/parquet_dlba_encoder.hpp deleted file mode 100644 index b3cd1aa96..000000000 --- a/src/duckdb/extension/parquet/include/parquet_dlba_encoder.hpp +++ /dev/null @@ -1,48 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_dlba_encoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "parquet_dbp_encoder.hpp" -#include "duckdb/common/serializer/memory_stream.hpp" - -namespace duckdb { - -class DlbaEncoder { -public: - DlbaEncoder(const idx_t total_value_count_p, const idx_t total_string_size_p) - : dbp_encoder(total_value_count_p), total_string_size(total_string_size_p), - buffer(Allocator::DefaultAllocator().Allocate(total_string_size + 1)), - stream(make_unsafe_uniq(buffer.get(), buffer.GetSize())) { - } - -public: - void BeginWrite(WriteStream &writer, const string_t &first_value) { - dbp_encoder.BeginWrite(writer, UnsafeNumericCast(first_value.GetSize())); - stream->WriteData(const_data_ptr_cast(first_value.GetData()), first_value.GetSize()); - } - - void WriteValue(WriteStream &writer, const string_t &value) { - dbp_encoder.WriteValue(writer, UnsafeNumericCast(value.GetSize())); - stream->WriteData(const_data_ptr_cast(value.GetData()), value.GetSize()); - } - - void FinishWrite(WriteStream &writer) { - D_ASSERT(stream->GetPosition() == total_string_size); - dbp_encoder.FinishWrite(writer); - writer.WriteData(buffer.get(), total_string_size); - } - -private: - DbpEncoder dbp_encoder; - const idx_t total_string_size; - AllocatedData buffer; - unsafe_unique_ptr stream; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_extension.hpp b/src/duckdb/extension/parquet/include/parquet_extension.hpp deleted file mode 100644 index 413a104b6..000000000 --- a/src/duckdb/extension/parquet/include/parquet_extension.hpp +++ /dev/null @@ -1,22 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_extension.hpp -// -// -//===----------------------------------------------------------------------===/ - -#pragma once - -#include "duckdb.hpp" - -namespace duckdb { - -class ParquetExtension : public Extension { -public: - void Load(DuckDB &db) override; - std::string Name() override; - std::string Version() const override; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp b/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp deleted file mode 100644 index b7373056f..000000000 --- 
a/src/duckdb/extension/parquet/include/parquet_file_metadata_cache.hpp +++ /dev/null @@ -1,48 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_file_metadata_cache.hpp -// -// -//===----------------------------------------------------------------------===// -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/storage/object_cache.hpp" -#include "geo_parquet.hpp" -#endif -#include "parquet_types.h" -namespace duckdb { - -//! ParquetFileMetadataCache -class ParquetFileMetadataCache : public ObjectCacheEntry { -public: - ParquetFileMetadataCache() : metadata(nullptr) { - } - ParquetFileMetadataCache(unique_ptr file_metadata, time_t r_time, - unique_ptr geo_metadata) - : metadata(std::move(file_metadata)), read_time(r_time), geo_metadata(std::move(geo_metadata)) { - } - - ~ParquetFileMetadataCache() override = default; - - //! Parquet file metadata - unique_ptr metadata; - - //! read time - time_t read_time; - - //! GeoParquet metadata - unique_ptr geo_metadata; - -public: - static string ObjectType() { - return "parquet_metadata"; - } - - string GetObjectType() override { - return ObjectType(); - } -}; -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_metadata.hpp b/src/duckdb/extension/parquet/include/parquet_metadata.hpp deleted file mode 100644 index 09ecd5afa..000000000 --- a/src/duckdb/extension/parquet/include/parquet_metadata.hpp +++ /dev/null @@ -1,41 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_metadata.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "parquet_reader.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -class ParquetMetaDataFunction : public TableFunction { -public: - ParquetMetaDataFunction(); -}; - -class ParquetSchemaFunction : public TableFunction { -public: - ParquetSchemaFunction(); -}; - -class ParquetKeyValueMetadataFunction : public TableFunction { -public: - ParquetKeyValueMetadataFunction(); -}; - -class ParquetFileMetadataFunction : public TableFunction { -public: - ParquetFileMetadataFunction(); -}; - -class ParquetBloomProbeFunction : public TableFunction { -public: - ParquetBloomProbeFunction(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_reader.hpp b/src/duckdb/extension/parquet/include/parquet_reader.hpp deleted file mode 100644 index 79de2b4cc..000000000 --- a/src/duckdb/extension/parquet/include/parquet_reader.hpp +++ /dev/null @@ -1,226 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/common.hpp" -#include "duckdb/common/encryption_state.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/multi_file_reader_options.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/null_filter.hpp" -#include "duckdb/planner/table_filter.hpp" -#endif -#include "column_reader.hpp" -#include "parquet_file_metadata_cache.hpp" 
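A brief aside between the reader's includes: the table functions declared above are easiest to understand from the SQL surface they expose. A minimal sketch through the public C++ API follows; 'example.parquet' is a placeholder path, and the projected column names are my recollection of what these functions return, so treat them as assumptions:

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// Per-row-group, per-column statistics (ParquetMetaDataFunction).
	con.Query("SELECT row_group_id, path_in_schema, stats_min_value, stats_max_value "
	          "FROM parquet_metadata('example.parquet')")
	    ->Print();
	// The file's schema (ParquetSchemaFunction).
	con.Query("SELECT name, type FROM parquet_schema('example.parquet')")->Print();
	// Key/value pairs stored in the footer (ParquetKeyValueMetadataFunction).
	con.Query("SELECT * FROM parquet_kv_metadata('example.parquet')")->Print();
	return 0;
}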
-#include "parquet_rle_bp_decoder.hpp" -#include "parquet_types.h" -#include "resizable_buffer.hpp" - -#include - -namespace duckdb_parquet { -namespace format { -class FileMetaData; -} -} // namespace duckdb_parquet - -namespace duckdb { -class Allocator; -class ClientContext; -class BaseStatistics; -class TableFilterSet; -class ParquetEncryptionConfig; - -struct ParquetReaderPrefetchConfig { - // Percentage of data in a row group span that should be scanned for enabling whole group prefetch - static constexpr double WHOLE_GROUP_PREFETCH_MINIMUM_SCAN = 0.95; -}; - -struct ParquetReaderScanState { - vector group_idx_list; - int64_t current_group; - idx_t group_offset; - unique_ptr file_handle; - unique_ptr root_reader; - std::unique_ptr thrift_file_proto; - - bool finished; - SelectionVector sel; - - ResizeableBuffer define_buf; - ResizeableBuffer repeat_buf; - - bool prefetch_mode = false; - bool current_group_prefetched = false; -}; - -struct ParquetColumnDefinition { -public: - static ParquetColumnDefinition FromSchemaValue(ClientContext &context, const Value &column_value); - -public: - int32_t field_id; - string name; - LogicalType type; - Value default_value; - -public: - void Serialize(Serializer &serializer) const; - static ParquetColumnDefinition Deserialize(Deserializer &deserializer); -}; - -struct ParquetOptions { - explicit ParquetOptions() { - } - explicit ParquetOptions(ClientContext &context); - - bool binary_as_string = false; - bool file_row_number = false; - shared_ptr encryption_config; - bool debug_use_openssl = true; - - MultiFileReaderOptions file_options; - vector schema; - idx_t explicit_cardinality = 0; - -public: - void Serialize(Serializer &serializer) const; - static ParquetOptions Deserialize(Deserializer &deserializer); -}; - -struct ParquetUnionData { - ~ParquetUnionData(); - - string file_name; - vector names; - vector types; - ParquetOptions options; - shared_ptr metadata; - unique_ptr reader; - - const string &GetFileName() { - return file_name; - } -}; - -class ParquetReader { -public: - using UNION_READER_DATA = unique_ptr; - -public: - ParquetReader(ClientContext &context, string file_name, ParquetOptions parquet_options, - shared_ptr metadata = nullptr); - ~ParquetReader(); - - FileSystem &fs; - Allocator &allocator; - string file_name; - vector return_types; - vector names; - shared_ptr metadata; - ParquetOptions parquet_options; - MultiFileReaderData reader_data; - unique_ptr root_reader; - shared_ptr encryption_util; - - //! Index of the file_row_number column - idx_t file_row_number_idx = DConstants::INVALID_INDEX; - //! Parquet schema for the generated columns - vector generated_column_schema; - //! 
Table column names - set when using COPY tbl FROM file.parquet - vector table_columns; - -public: - void InitializeScan(ClientContext &context, ParquetReaderScanState &state, vector groups_to_read); - void Scan(ParquetReaderScanState &state, DataChunk &output); - - static unique_ptr StoreUnionReader(unique_ptr reader_p, idx_t file_idx) { - auto result = make_uniq(); - result->file_name = reader_p->file_name; - if (file_idx == 0) { - result->names = reader_p->names; - result->types = reader_p->return_types; - result->options = reader_p->parquet_options; - result->metadata = reader_p->metadata; - result->reader = std::move(reader_p); - } else { - result->names = std::move(reader_p->names); - result->types = std::move(reader_p->return_types); - result->options = std::move(reader_p->parquet_options); - result->metadata = std::move(reader_p->metadata); - } - return result; - } - - idx_t NumRows(); - idx_t NumRowGroups(); - - const duckdb_parquet::FileMetaData *GetFileMetadata(); - - uint32_t Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot); - uint32_t ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, - const uint32_t buffer_size); - - unique_ptr ReadStatistics(const string &name); - static LogicalType DeriveLogicalType(const SchemaElement &s_ele, bool binary_as_string); - - FileHandle &GetHandle() { - return *file_handle; - } - - const string &GetFileName() { - return file_name; - } - const vector &GetNames() { - return names; - } - const vector &GetTypes() { - return return_types; - } - - static unique_ptr ReadStatistics(ClientContext &context, ParquetOptions parquet_options, - shared_ptr metadata, const string &name); - -private: - //! Construct a parquet reader but **do not** open a file, used in ReadStatistics only - ParquetReader(ClientContext &context, ParquetOptions parquet_options, - shared_ptr metadata); - - void InitializeSchema(ClientContext &context); - bool ScanInternal(ParquetReaderScanState &state, DataChunk &output); - unique_ptr CreateReader(ClientContext &context); - - unique_ptr CreateReaderRecursive(ClientContext &context, const vector &indexes, - idx_t depth, idx_t max_define, idx_t max_repeat, - idx_t &next_schema_idx, idx_t &next_file_idx); - const duckdb_parquet::RowGroup &GetGroup(ParquetReaderScanState &state); - uint64_t GetGroupCompressedSize(ParquetReaderScanState &state); - idx_t GetGroupOffset(ParquetReaderScanState &state); - // Group span is the distance between the min page offset and the max page offset plus the max page compressed size - uint64_t GetGroupSpan(ParquetReaderScanState &state); - void PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t out_col_idx); - LogicalType DeriveLogicalType(const SchemaElement &s_ele); - - template - std::runtime_error FormatException(const string fmt_str, Args... 
params) { - return std::runtime_error("Failed to read Parquet file \"" + file_name + - "\": " + StringUtil::Format(fmt_str, params...)); - } - -private: - unique_ptr file_handle; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp b/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp deleted file mode 100644 index b8dc35b35..000000000 --- a/src/duckdb/extension/parquet/include/parquet_rle_bp_decoder.hpp +++ /dev/null @@ -1,116 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_rle_bp_decoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once -#include "decode_utils.hpp" -#include "parquet_types.h" -#include "resizable_buffer.hpp" -#include "thrift_tools.hpp" - -namespace duckdb { - -class RleBpDecoder { -public: - /// Create a decoder object. buffer/buffer_len is the decoded data. - /// bit_width is the width of each value (before encoding). - RleBpDecoder(data_ptr_t buffer, uint32_t buffer_len, uint32_t bit_width) - : buffer_(buffer, buffer_len), bit_width_(bit_width), current_value_(0), repeat_count_(0), literal_count_(0) { - if (bit_width >= 64) { - throw std::runtime_error("Decode bit width too large"); - } - byte_encoded_len = ((bit_width_ + 7) / 8); - max_val = (uint64_t(1) << bit_width_) - 1; - } - - template - void GetBatch(data_ptr_t values_target_ptr, uint32_t batch_size) { - auto values = reinterpret_cast(values_target_ptr); - uint32_t values_read = 0; - - while (values_read < batch_size) { - if (repeat_count_ > 0) { - int repeat_batch = MinValue(batch_size - values_read, static_cast(repeat_count_)); - std::fill_n(values + values_read, repeat_batch, static_cast(current_value_)); - repeat_count_ -= repeat_batch; - values_read += repeat_batch; - } else if (literal_count_ > 0) { - uint32_t literal_batch = MinValue(batch_size - values_read, static_cast(literal_count_)); - ParquetDecodeUtils::BitUnpack(buffer_, bitpack_pos, values + values_read, literal_batch, bit_width_); - literal_count_ -= literal_batch; - values_read += literal_batch; - } else { - if (!NextCounts()) { - if (values_read != batch_size) { - throw std::runtime_error("RLE decode did not find enough values"); - } - return; - } - } - } - if (values_read != batch_size) { - throw std::runtime_error("RLE decode did not find enough values"); - } - } - - static uint8_t ComputeBitWidth(idx_t val) { - if (val == 0) { - return 0; - } - uint8_t ret = 1; - while ((((idx_t)1u << (idx_t)ret) - 1) < val) { - ret++; - } - return ret; - } - -private: - ByteBuffer buffer_; - - /// Number of bits needed to encode the value. Must be between 0 and 64. - uint32_t bit_width_; - uint64_t current_value_; - uint32_t repeat_count_; - uint32_t literal_count_; - uint8_t byte_encoded_len; - uint64_t max_val; - - uint8_t bitpack_pos = 0; - - /// Fills literal_count_ and repeat_count_ with next values. Returns false if there - /// are no more. - template - bool NextCounts() { - // Read the next run's indicator int, it could be a literal or repeated run. - // The int is encoded as a vlq-encoded value. 
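// [editor's note] A worked example of the hybrid header decoded below (illustration
// only): indicator = 0x26 = 0b100110. The low bit is 0, so this is a repeated run of
// 0x26 >> 1 = 19 copies of the value stored in the next byte_encoded_len bytes. Had
// the indicator been 0x27, the low bit 1 would make it a literal run of
// (0x27 >> 1) * 8 = 152 bit-packed values.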
- if (bitpack_pos != 0) { - buffer_.inc(1); - bitpack_pos = 0; - } - auto indicator_value = ParquetDecodeUtils::VarintDecode(buffer_); - - // lsb indicates if it is a literal run or repeated run - bool is_literal = indicator_value & 1; - if (is_literal) { - literal_count_ = (indicator_value >> 1) * 8; - } else { - repeat_count_ = indicator_value >> 1; - // (ARROW-4018) this is not big-endian compatible, lol - current_value_ = 0; - for (auto i = 0; i < byte_encoded_len; i++) { - current_value_ |= (buffer_.read() << (i * 8)); - } - // sanity check - if (repeat_count_ > 0 && current_value_ > max_val) { - throw std::runtime_error("Payload value bigger than allowed. Corrupted file?"); - } - } - // TODO complain if we run out of buffer - return true; - } -}; -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp b/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp deleted file mode 100644 index 029dd06eb..000000000 --- a/src/duckdb/extension/parquet/include/parquet_rle_bp_encoder.hpp +++ /dev/null @@ -1,49 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_rle_bp_encoder.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "parquet_types.h" -#include "thrift_tools.hpp" -#include "resizable_buffer.hpp" - -namespace duckdb { - -class RleBpEncoder { -public: - explicit RleBpEncoder(uint32_t bit_width); - -public: - //! NOTE: Prepare is only required if a byte count is required BEFORE writing - //! This is the case with e.g. writing repetition/definition levels - //! If GetByteCount() is not required, prepare can be safely skipped - void BeginPrepare(uint32_t first_value); - void PrepareValue(uint32_t value); - void FinishPrepare(); - - void BeginWrite(WriteStream &writer, uint32_t first_value); - void WriteValue(WriteStream &writer, uint32_t value); - void FinishWrite(WriteStream &writer); - - idx_t GetByteCount(); - -private: - //! meta information - uint32_t byte_width; - //! 
RLE run information - idx_t byte_count; - idx_t run_count; - idx_t current_run_count; - uint32_t last_value; - -private: - void FinishRun(); - void WriteRun(WriteStream &writer); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_statistics.hpp b/src/duckdb/extension/parquet/include/parquet_statistics.hpp deleted file mode 100644 index ad1f939c8..000000000 --- a/src/duckdb/extension/parquet/include/parquet_statistics.hpp +++ /dev/null @@ -1,110 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_statistics.hpp -// -// -//===----------------------------------------------------------------------===/ - -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/storage/statistics/base_statistics.hpp" -#endif -#include "parquet_types.h" - -namespace duckdb { - -using duckdb_parquet::ColumnChunk; -using duckdb_parquet::SchemaElement; - -struct LogicalType; -class ColumnReader; -class ResizeableBuffer; - -struct ParquetStatisticsUtils { - - static unique_ptr TransformColumnStatistics(const ColumnReader &reader, - const vector &columns); - - static Value ConvertValue(const LogicalType &type, const duckdb_parquet::SchemaElement &schema_ele, - const std::string &stats); - - static bool BloomFilterSupported(const LogicalTypeId &type_id); - - static bool BloomFilterExcludes(const TableFilter &filter, const duckdb_parquet::ColumnMetaData &column_meta_data, - duckdb_apache::thrift::protocol::TProtocol &file_proto, Allocator &allocator); - -private: - static Value ConvertValueInternal(const LogicalType &type, const duckdb_parquet::SchemaElement &schema_ele, - const std::string &stats); -}; - -class ParquetBloomFilter { - static constexpr const idx_t DEFAULT_BLOCK_COUNT = 32; // 4k filter - -public: - ParquetBloomFilter(idx_t num_entries, double bloom_filter_false_positive_ratio); - ParquetBloomFilter(unique_ptr data_p); - void FilterInsert(uint64_t x); - bool FilterCheck(uint64_t x); - void Shrink(idx_t new_block_count); - double OneRatio(); - ResizeableBuffer *Get(); - -private: - unique_ptr data; - idx_t block_count; -}; - -// see https://github.com/apache/parquet-format/blob/master/BloomFilter.md - -struct ParquetBloomBlock { - struct ParquetBloomMaskResult { - uint8_t bit_set[8] = {0}; - }; - - uint32_t block[8] = {0}; - - static bool check_bit(uint32_t &x, const uint8_t i) { - D_ASSERT(i < 32); - return (x >> i) & (uint32_t)1; - } - - static void set_bit(uint32_t &x, const uint8_t i) { - D_ASSERT(i < 32); - x |= (uint32_t)1 << i; - D_ASSERT(check_bit(x, i)); - } - - static ParquetBloomMaskResult Mask(uint32_t x) { - static const uint32_t parquet_bloom_salt[8] = {0x47b6137bU, 0x44974d91U, 0x8824ad5bU, 0xa2b7289dU, - 0x705495c7U, 0x2df1424bU, 0x9efc4947U, 0x5c6bfb31U}; - ParquetBloomMaskResult result; - for (idx_t i = 0; i < 8; i++) { - result.bit_set[i] = (x * parquet_bloom_salt[i]) >> 27; - } - return result; - } - - static void BlockInsert(ParquetBloomBlock &b, uint32_t x) { - auto masked = Mask(x); - for (idx_t i = 0; i < 8; i++) { - set_bit(b.block[i], masked.bit_set[i]); - D_ASSERT(check_bit(b.block[i], masked.bit_set[i])); - } - } - - static bool BlockCheck(ParquetBloomBlock &b, uint32_t x) { - auto masked = Mask(x); - for (idx_t i = 0; i < 8; i++) { - if (!check_bit(b.block[i], masked.bit_set[i])) { - return false; - } - } - return true; - } -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_support.hpp 
b/src/duckdb/extension/parquet/include/parquet_support.hpp deleted file mode 100644 index 91c43fcb4..000000000 --- a/src/duckdb/extension/parquet/include/parquet_support.hpp +++ /dev/null @@ -1,621 +0,0 @@ -#pragma once - -namespace duckdb { - -class StripeStreams { -public: - virtual ~StripeStreams() = default; - - /** - * get column selector for current stripe reading session - * @return column selector will hold column projection info - */ - virtual const dwio::common::ColumnSelector &getColumnSelector() const = 0; - - // Get row reader options - virtual const dwio::common::RowReaderOptions &getRowReaderOptions() const = 0; - - /** - * Get the encoding for the given column for this stripe. - */ - virtual const proto::ColumnEncoding &getEncoding(const EncodingKey &) const = 0; - - /** - * Get the stream for the given column/kind in this stripe. - * @param streamId stream identifier object - * @param throwIfNotFound fail if a stream is required and not found - * @return the new stream - */ - virtual unique_ptr getStream(const StreamIdentifier &si, bool throwIfNotFound) const = 0; - - /** - * visit all streams of given node and execute visitor logic - * return number of streams visited - */ - virtual uint32_t visitStreamsOfNode(uint32_t node, - std::function visitor) const = 0; - - /** - * Get the value of useVInts for the given column in this stripe. - * Defaults to true. - * @param streamId stream identifier - */ - virtual bool getUseVInts(const StreamIdentifier &streamId) const = 0; - - /** - * Get the memory pool for this reader. - */ - virtual memory::MemoryPool &getMemoryPool() const = 0; - - /** - * Get the RowGroupIndex.
- * @return a vector of RowIndex belonging to the stripe - */ - virtual unique_ptr getRowGroupIndex(const StreamIdentifier &si) const = 0; - - /** - * Get stride index provider which is used by string dictionary reader to - * get the row index stride index where next() happens - */ - virtual const StrideIndexProvider &getStrideIndexProvider() const = 0; -}; - -class ColumnReader { - -public: - ColumnReader(const EncodingKey &ek, StripeStreams &stripe); - - virtual ~ColumnReader() = default; - - /** - * Skip number of specified rows. - * @param numValues the number of values to skip - * @return the number of non-null values skipped - */ - virtual uint64_t skip(uint64_t numValues); - - /** - * Read the next group of values into a RowVector. - * @param numValues the number of values to read - * @param vector to read into - */ - virtual void next(uint64_t numValues, VectorPtr &result, const uint64_t *nulls = nullptr) = 0; -}; - -class SelectiveColumnReader : public ColumnReader { -public: - static constexpr uint64_t kStringBufferSize = 16 * 1024; - - SelectiveColumnReader(const EncodingKey &ek, StripeStreams &stripe, common::ScanSpec *scanSpec); - - /** - * Read the next group of values into a RowVector. - * @param numValues the number of values to read - * @param vector to read into - */ - void next(uint64_t /*numValues*/, VectorPtr & /*result*/, const uint64_t * /*incomingNulls*/) override { - DATALIB_CHECK(false) << "next() is only defined in SelectiveStructColumnReader"; - } - - // Creates a reader for the given stripe. - static unique_ptr build(const std::shared_ptr &requestedType, - const std::shared_ptr &dataType, - StripeStreams &stripe, common::ScanSpec *scanSpec, - uint32_t sequence = 0); - - // Seeks to offset and reads the rows in 'rows' and applies - // filters and value processing as given by 'scanSpec supplied at - // construction. 'offset' is relative to start of stripe. 'rows' are - // relative to 'offset', so that row 0 is the 'offset'th row from - // start of stripe. 'rows' is expected to stay constant - // between this and the next call to read. - virtual void read(vector_size_t offset, RowSet rows, const uint64_t *incomingNulls) = 0; - - // Extracts the values at 'rows' into '*result'. May rewrite or - // reallocate '*result'. 'rows' must be the same set or a subset of - // 'rows' passed to the last 'read(). - virtual void getValues(RowSet rows, VectorPtr *result) = 0; - - // Returns the rows that were selected/visited by the last - // read(). If 'this' has no filter, returns 'rows' passed to last - // read(). - const RowSet outputRows() const { - if (scanSpec_->hasFilter()) { - return outputRows_; - } - return inputRows_; - } - - // Advances to 'offset', so that the next item to be read is the - // offset-th from the start of stripe. - void seekTo(vector_size_t offset, bool readsNullsOnly); - - // The below functions are called from ColumnVisitor to fill the result set. 
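// [editor's note] The shape of the protocol here, for orientation: the ColumnVisitor
// template defined further below walks decoded values, applies its filter, and calls
// back into these helpers. addValue()/addNull() append to the values_/resultNulls_
// buffers declared at the bottom of this class, addOutputRow() records which rows
// survived the filter, and getValues() later materializes only those rows into the
// result vector.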
- inline void addOutputRow(vector_size_t row) { - outputRows_.push_back(row); - } - - template - inline void addNull() { - DATALIB_DCHECK(rawResultNulls_ && rawValues_ && (numValues_ + 1) * sizeof(T) < rawSize_); - - anyNulls_ = true; - bits::setBit(rawResultNulls_, numValues_); - reinterpret_cast(rawValues_)[numValues_] = T(); - numValues_++; - } - - template - inline void addValue(const T value) { - // @lint-ignore-every HOWTOEVEN ConstantArgumentPassByValue - static_assert(std::is_pod::value, "General case of addValue is only for primitive types"); - DATALIB_DCHECK(rawValues_ && (numValues _ + 1) * sizeof(T) < rawSize_); - reinterpret_cast(rawValues_)[numValues_] = value; - numValues_++; - } - - void dropResults(vector_size_t count) { - outputRows_.resize(outputRows_.size() - count); - numValues_ -= count; - } - - common::ScanSpec *scanSpec() const { - return scanSpec_; - } - - auto readOffset() const { - return readOffset_; - } - - void setReadOffset(vector_size_t readOffset) { - readOffset_ = readOffset; - } - -protected: - static constexpr int8_t kNoValueSize = -1; - - template - void ensureValuesCapacity(vector_size_t numRows); - - void prepareNulls(vector_size_t numRows, bool needNulls); - - template - void filterNulls(RowSet rows, bool isNull, bool extractValues); - - template - void prepareRead(vector_size_t offset, RowSet rows, const uint64_t *incomingNulls); - - void setOutputRows(RowSet rows) { - outputRows_.resize(rows.size()); - if (!rows.size()) { - return; - } - memcpy(outputRows_.data(), &rows[0], rows.size() * sizeof(vector_size_t)); - } - template - void getFlatValues(RowSet rows, VectorPtr *result); - - template - void compactScalarValues(RowSet rows); - - void addStringValue(folly::StringPiece value); - - // Specification of filters, value extraction, pruning etc. The - // spec is assigned at construction and the contents may change at - // run time based on adaptation. Owned by caller. - common::ScanSpec *const scanSpec_; - // Row number after last read row, relative to stripe start. - vector_size_t readOffset_ = 0; - // The rows to process in read(). References memory supplied by - // caller. The values must remain live until the next call to read(). - RowSet inputRows_; - // Rows passing the filter in readWithVisitor. Must stay - // constant between consecutive calls to read(). - vector outputRows_; - // The row number corresponding to each element in 'values_' - vector valueRows_; - // The set of all nulls in the range of read(). Created when first - // needed and then reused. Not returned to callers. - BufferPtr nullsInReadRange_; - // Nulls buffer for readWithVisitor. Not set if no nulls. 'numValues' - // is the index of the first non-set bit. - BufferPtr resultNulls_; - uint64_t *rawResultNulls_ = nullptr; - // Buffer for gathering scalar values in readWithVisitor. - BufferPtr values_; - // Writable content in 'values' - void *rawValues_ = nullptr; - vector_size_t numValues_ = 0; - // Size of fixed width value in 'rawValues'. For integers, values - // are read at 64 bit width and can be compacted or extracted at a - // different width. - int8_t valueSize_ = kNoValueSize; - // Buffers backing the StringViews in 'values' when reading strings. - vector stringBuffers_; - // Writable contents of 'stringBuffers_.back()'. - char *rawStringBuffer_ = nullptr; - // Total writable bytes in 'rawStringBuffer_'. - int32_t rawStringSize_ = 0; - // Number of written bytes in 'rawStringBuffer_'. - uint32_t rawStringUsed_ = 0; - - // True if last read() added any nulls. 
- bool anyNulls_ = false; - // True if all values in scope for last read() are null. - bool allNull_ = false; -}; - -struct ExtractValues { - static constexpr bool kSkipNulls = false; - - bool acceptsNulls() const { - return true; - } - - template - void addValue(vector_size_t /*rowIndex*/, V /*value*/) { - } - void addNull(vector_size_t /*rowIndex*/) { - } -}; - -class Filter { -protected: - Filter(bool deterministic, bool nullAllowed, FilterKind kind) - : nullAllowed_(nullAllowed), deterministic_(deterministic), kind_(kind) { - } - -public: - virtual ~Filter() = default; - - // Templates parametrized on filter need to know determinism at compile - // time. If this is false, deterministic() will be consulted at - // runtime. - static constexpr bool deterministic = true; - - FilterKind kind() const { - return kind_; - } - - virtual unique_ptr clone() const = 0; - - /** - * A filter becomes non-deterministic when applies to nested column, - * e.g. a[1] > 10 is non-deterministic because > 10 filter applies only to - * some positions, e.g. first entry in a set of entries that correspond to a - * single top-level position. - */ - virtual bool isDeterministic() const { - return deterministic_; - } - - /** - * When a filter applied to a nested column fails, the whole top-level - * position should fail. To enable this functionality, the filter keeps track - * of the boundaries of top-level positions and allows the caller to find out - * where the current top-level position started and how far it continues. - * @return number of positions from the start of the current top-level - * position up to the current position (excluding current position) - */ - virtual int getPrecedingPositionsToFail() const { - return 0; - } - - /** - * @return number of positions remaining until the end of the current - * top-level position - */ - virtual int getSucceedingPositionsToFail() const { - return 0; - } - - virtual bool testNull() const { - return nullAllowed_; - } - - /** - * Used to apply is [not] null filters to complex types, e.g. - * a[1] is null AND a[3] is not null, where a is an array(array(T)). - * - * In these case, the exact values are not known, but it is known whether they - * are null or not. Furthermore, for some positions only nulls are allowed - * (a[1] is null), for others only non-nulls (a[3] is not null), and for the - * rest both are allowed (a[2] and a[N], where N > 3). - */ - virtual bool testNonNull() const { - DWIO_RAISE("not supported"); - } - - virtual bool testInt64(int64_t /* unused */) const { - DWIO_RAISE("not supported"); - } - - virtual bool testDouble(double /* unused */) const { - DWIO_RAISE("not supported"); - } - - virtual bool testFloat(float /* unused */) const { - DWIO_RAISE("not supported"); - } - - virtual bool testBool(bool /* unused */) const { - DWIO_RAISE("not supported"); - } - - virtual bool testBytes(const char * /* unused */, int32_t /* unused */) const { - DWIO_RAISE("not supported"); - } - - /** - * Filters like string equality and IN, as well as conditions on cardinality - * of lists and maps can be at least partly decided by looking at lengths - * alone. If this is false, then no further checks are needed. If true, - * eventual filters on the data itself need to be evaluated. - */ - virtual bool testLength(int32_t /* unused */) const { - DWIO_RAISE("not supported"); - } - -protected: - const bool nullAllowed_; - -private: - const bool deterministic_; - const FilterKind kind_; -}; - -// Template parameter for controlling filtering and action on a set of rows. 
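Before the visitor template, a standalone sketch of a concrete filter against the interface above. It is deliberately simplified: FilterKind, the protected base constructor, and the DWIO_RAISE defaults are elided, so this illustrates the intended usage rather than being a drop-in subclass:

#include <cstdint>

// Simplified stand-in for the Filter interface above: a deterministic,
// null-rejecting range filter on 64-bit integers, e.g. "x BETWEEN 10 AND 20".
class Int64RangeFilter {
public:
	Int64RangeFilter(int64_t min, int64_t max) : min_(min), max_(max) {
	}
	bool testNull() const {
		return false; // NULLs never pass a range predicate
	}
	bool testInt64(int64_t value) const {
		return value >= min_ && value <= max_;
	}

private:
	const int64_t min_;
	const int64_t max_;
};

int main() {
	Int64RangeFilter filter(10, 20);
	// A reader calls testInt64 per decoded value and keeps only passing rows.
	return filter.testInt64(15) && !filter.testInt64(42) && !filter.testNull() ? 0 : 1;
}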
-template <typename T, typename TFilter, typename ExtractValues, bool isDense>
-class ColumnVisitor {
-public:
-	using FilterType = TFilter;
-	static constexpr bool dense = isDense;
-	ColumnVisitor(TFilter &filter, SelectiveColumnReader *reader, const RowSet &rows, ExtractValues values)
-	    : filter_(filter), reader_(reader), allowNulls_(!TFilter::deterministic || filter.testNull()), rows_(&rows[0]),
-	      numRows_(rows.size()), rowIndex_(0), values_(values) {
-	}
-
-	bool allowNulls() {
-		if (ExtractValues::kSkipNulls && TFilter::deterministic) {
-			return false;
-		}
-		return allowNulls_ && values_.acceptsNulls();
-	}
-
-	vector_size_t start() {
-		return isDense ? 0 : rowAt(0);
-	}
-
-	// Tests for a null value and processes it. If the value is not
-	// null, returns 0 and has no effect. If the value is null, advances
-	// to the next non-null value in 'rows_'. Returns the number of
-	// values (not including nulls) to skip to get to the next non-null.
-	// If there is no next non-null in 'rows_', sets 'atEnd'. If 'atEnd'
-	// is set and a non-zero skip is returned, the caller must perform
-	// the skip before returning.
-	FOLLY_ALWAYS_INLINE vector_size_t checkAndSkipNulls(const uint64_t *nulls, vector_size_t &current, bool &atEnd) {
-		auto testRow = currentRow();
-		// Check that the caller and the visitor are in sync about current row.
-		DATALIB_DCHECK(current == testRow);
-		uint32_t nullIndex = testRow >> 6;
-		uint64_t nullWord = nulls[nullIndex];
-		if (!nullWord) {
-			return 0;
-		}
-		uint8_t nullBit = testRow & 63;
-		if ((nullWord & (1UL << nullBit)) == 0) {
-			return 0;
-		}
-		// We have a null. We find the next non-null.
-		if (++rowIndex_ >= numRows_) {
-			atEnd = true;
-			return 0;
-		}
-		auto rowOfNullWord = testRow - nullBit;
-		if (isDense) {
-			if (nullBit == 63) {
-				nullBit = 0;
-				rowOfNullWord += 64;
-				nullWord = nulls[++nullIndex];
-			} else {
-				++nullBit;
-				// set all the bits below the row to null.
-				nullWord |= f4d::bits::lowMask(nullBit);
-			}
-			for (;;) {
-				auto nextNonNull = count_trailing_zeros(~nullWord);
-				if (rowOfNullWord + nextNonNull >= numRows_) {
-					// Nulls all the way to the end.
-					atEnd = true;
-					return 0;
-				}
-				if (nextNonNull < 64) {
-					DATALIB_CHECK(rowIndex_ <= rowOfNullWord + nextNonNull);
-					rowIndex_ = rowOfNullWord + nextNonNull;
-					current = currentRow();
-					return 0;
-				}
-				rowOfNullWord += 64;
-				nullWord = nulls[++nullIndex];
-			}
-		} else {
-			// Sparse row numbers. We find the first non-null and count
-			// how many non-nulls on rows not in 'rows_' we skipped.
-			int32_t toSkip = 0;
-			nullWord |= f4d::bits::lowMask(nullBit);
-			for (;;) {
-				testRow = currentRow();
-				while (testRow >= rowOfNullWord + 64) {
-					toSkip += __builtin_popcountll(~nullWord);
-					nullWord = nulls[++nullIndex];
-					rowOfNullWord += 64;
-				}
-				// testRow is inside nullWord. See if non-null.
-				nullBit = testRow & 63;
-				if ((nullWord & (1UL << nullBit)) == 0) {
-					toSkip += __builtin_popcountll(~nullWord & f4d::bits::lowMask(nullBit));
-					current = testRow;
-					return toSkip;
-				}
-				if (++rowIndex_ >= numRows_) {
-					// We end with a null. Add the non-nulls below the final null.
-					toSkip += __builtin_popcountll(~nullWord & f4d::bits::lowMask(testRow - rowOfNullWord));
-					atEnd = true;
-					return toSkip;
-				}
-			}
-		}
-	}
-
-	vector_size_t processNull(bool &atEnd) {
-		vector_size_t previous = currentRow();
-		if (filter_.testNull()) {
-			filterPassedForNull();
-		} else {
-			filterFailed();
-		}
-		if (++rowIndex_ >= numRows_) {
-			atEnd = true;
-			return rows_[numRows_ - 1] - previous;
-		}
-		if (TFilter::deterministic && isDense) {
-			return 0;
-		}
-		return currentRow() - previous - 1;
-	}
-
-	FOLLY_ALWAYS_INLINE vector_size_t process(T value, bool &atEnd) {
-		if (!TFilter::deterministic) {
-			auto previous = currentRow();
-			if (common::applyFilter(filter_, value)) {
-				filterPassed(value);
-			} else {
-				filterFailed();
-			}
-			if (++rowIndex_ >= numRows_) {
-				atEnd = true;
-				return rows_[numRows_ - 1] - previous;
-			}
-			return currentRow() - previous - 1;
-		}
-		// The filter passes or fails and we go to the next row if any.
-		if (common::applyFilter(filter_, value)) {
-			filterPassed(value);
-		} else {
-			filterFailed();
-		}
-		if (++rowIndex_ >= numRows_) {
-			atEnd = true;
-			return 0;
-		}
-		if (isDense) {
-			return 0;
-		}
-		return currentRow() - rows_[rowIndex_ - 1] - 1;
-	}
-
-	inline vector_size_t rowAt(vector_size_t index) {
-		if (isDense) {
-			return index;
-		}
-		return rows_[index];
-	}
-
-	vector_size_t currentRow() {
-		if (isDense) {
-			return rowIndex_;
-		}
-		return rows_[rowIndex_];
-	}
-
-	vector_size_t numRows() {
-		return numRows_;
-	}
-
-	void filterPassed(T value) {
-		addResult(value);
-		if (!std::is_same<TFilter, common::AlwaysTrue>::value) {
-			addOutputRow(currentRow());
-		}
-	}
-
-	inline void filterPassedForNull() {
-		addNull();
-		if (!std::is_same<TFilter, common::AlwaysTrue>::value) {
-			addOutputRow(currentRow());
-		}
-	}
-
-	FOLLY_ALWAYS_INLINE void filterFailed();
-	inline void addResult(T value);
-	inline void addNull();
-	inline void addOutputRow(vector_size_t row);
-
-protected:
-	TFilter &filter_;
-	SelectiveColumnReader *reader_;
-	const bool allowNulls_;
-	const vector_size_t *rows_;
-	vector_size_t numRows_;
-	vector_size_t rowIndex_;
-	ExtractValues values_;
-};
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/parquet/include/parquet_timestamp.hpp b/src/duckdb/extension/parquet/include/parquet_timestamp.hpp
deleted file mode 100644
index 8631af997..000000000
--- a/src/duckdb/extension/parquet/include/parquet_timestamp.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===----------------------------------------------------------------------===//
-// DuckDB
-//
-// parquet_timestamp.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include "duckdb.hpp"
-
-namespace duckdb {
-
-struct Int96 {
-	uint32_t value[3];
-};
-
-timestamp_t ImpalaTimestampToTimestamp(const Int96 &raw_ts);
-timestamp_ns_t ImpalaTimestampToTimestampNS(const Int96 &raw_ts);
-Int96 TimestampToImpalaTimestamp(timestamp_t &ts);
-
-timestamp_t ParquetTimestampMicrosToTimestamp(const int64_t &raw_ts);
-timestamp_t ParquetTimestampMsToTimestamp(const int64_t &raw_ts);
-timestamp_t ParquetTimestampNsToTimestamp(const int64_t &raw_ts);
-
-timestamp_ns_t ParquetTimestampMsToTimestampNs(const int64_t &raw_ms);
-timestamp_ns_t ParquetTimestampUsToTimestampNs(const int64_t &raw_us);
-timestamp_ns_t ParquetTimestampNsToTimestampNs(const int64_t &raw_ns);
-
-date_t ParquetIntToDate(const int32_t &raw_date);
-dtime_t ParquetIntToTimeMs(const int32_t &raw_time);
-dtime_t ParquetIntToTime(const int64_t &raw_time);
-dtime_t ParquetIntToTimeNs(const int64_t &raw_time);
-
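// Editorial sketch, not from the patched sources: the Int96 "Impala"
// timestamp above packs nanoseconds-since-midnight into the first two
// 32-bit words and a Julian day number into the third. Assuming that
// conventional layout (and <cstring> for std::memcpy), a conversion to
// microseconds since the Unix epoch looks roughly like this:
static constexpr int64_t JULIAN_TO_UNIX_EPOCH_DAYS = 2440588; // Julian day number of 1970-01-01
inline int64_t Int96ToMicrosSketch(const Int96 &raw) {
	int64_t nanos_in_day;
	std::memcpy(&nanos_in_day, raw.value, sizeof(nanos_in_day)); // little-endian nanoseconds within the day
	const int64_t days = static_cast<int32_t>(raw.value[2]) - JULIAN_TO_UNIX_EPOCH_DAYS;
	return days * 86400000000LL + nanos_in_day / 1000; // whole days plus intra-day time, in microseconds
}

-dtime_tz_t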
ParquetIntToTimeMsTZ(const int32_t &raw_time); -dtime_tz_t ParquetIntToTimeTZ(const int64_t &raw_time); -dtime_tz_t ParquetIntToTimeNsTZ(const int64_t &raw_time); - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/parquet_writer.hpp b/src/duckdb/extension/parquet/include/parquet_writer.hpp deleted file mode 100644 index e601926b2..000000000 --- a/src/duckdb/extension/parquet/include/parquet_writer.hpp +++ /dev/null @@ -1,155 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// parquet_writer.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/common.hpp" -#include "duckdb/common/encryption_state.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/mutex.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/function/copy_function.hpp" -#endif - -#include "parquet_statistics.hpp" -#include "column_writer.hpp" -#include "parquet_types.h" -#include "geo_parquet.hpp" -#include "thrift/protocol/TCompactProtocol.h" - -namespace duckdb { -class FileSystem; -class FileOpener; -class ParquetEncryptionConfig; - -class Serializer; -class Deserializer; - -struct PreparedRowGroup { - duckdb_parquet::RowGroup row_group; - vector> states; - vector> heaps; -}; - -struct FieldID; -struct ChildFieldIDs { - ChildFieldIDs(); - ChildFieldIDs Copy() const; - unique_ptr> ids; - - void Serialize(Serializer &serializer) const; - static ChildFieldIDs Deserialize(Deserializer &source); -}; - -struct FieldID { - static constexpr const auto DUCKDB_FIELD_ID = "__duckdb_field_id"; - FieldID(); - explicit FieldID(int32_t field_id); - FieldID Copy() const; - bool set; - int32_t field_id; - ChildFieldIDs child_field_ids; - - void Serialize(Serializer &serializer) const; - static FieldID Deserialize(Deserializer &source); -}; - -struct ParquetBloomFilterEntry { - unique_ptr bloom_filter; - idx_t row_group_idx; - idx_t column_idx; -}; - -class ParquetWriter { -public: - ParquetWriter(ClientContext &context, FileSystem &fs, string file_name, vector types, - vector names, duckdb_parquet::CompressionCodec::type codec, ChildFieldIDs field_ids, - const vector> &kv_metadata, - shared_ptr encryption_config, idx_t dictionary_size_limit, - double bloom_filter_false_positive_ratio, int64_t compression_level, bool debug_use_openssl); - -public: - void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result); - void FlushRowGroup(PreparedRowGroup &row_group); - void Flush(ColumnDataCollection &buffer); - void Finalize(); - - static duckdb_parquet::Type::type DuckDBTypeToParquetType(const LogicalType &duckdb_type); - static void SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele); - - duckdb_apache::thrift::protocol::TProtocol *GetProtocol() { - return protocol.get(); - } - duckdb_parquet::CompressionCodec::type GetCodec() { - return codec; - } - duckdb_parquet::Type::type GetType(idx_t schema_idx) { - return file_meta_data.schema[schema_idx].type; - } - LogicalType GetSQLType(idx_t schema_idx) const { - return sql_types[schema_idx]; - } - BufferedFileWriter &GetWriter() { - return *writer; - } - idx_t FileSize() { - lock_guard glock(lock); - return writer->total_written; - } - idx_t DictionarySizeLimit() const { - return dictionary_size_limit; - } - double 
BloomFilterFalsePositiveRatio() const { - return bloom_filter_false_positive_ratio; - } - int64_t CompressionLevel() const { - return compression_level; - } - idx_t NumberOfRowGroups() { - lock_guard glock(lock); - return file_meta_data.row_groups.size(); - } - - uint32_t Write(const duckdb_apache::thrift::TBase &object); - uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size); - - GeoParquetFileMetadata &GetGeoParquetData(); - - static bool TryGetParquetType(const LogicalType &duckdb_type, - optional_ptr type = nullptr); - - void BufferBloomFilter(idx_t col_idx, unique_ptr bloom_filter); - -private: - string file_name; - vector sql_types; - vector column_names; - duckdb_parquet::CompressionCodec::type codec; - ChildFieldIDs field_ids; - shared_ptr encryption_config; - idx_t dictionary_size_limit; - double bloom_filter_false_positive_ratio; - int64_t compression_level; - bool debug_use_openssl; - shared_ptr encryption_util; - - unique_ptr writer; - std::shared_ptr protocol; - duckdb_parquet::FileMetaData file_meta_data; - std::mutex lock; - - vector> column_writers; - - unique_ptr geoparquet_data; - vector bloom_filters; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/resizable_buffer.hpp b/src/duckdb/extension/parquet/include/resizable_buffer.hpp deleted file mode 100644 index 14658ecee..000000000 --- a/src/duckdb/extension/parquet/include/resizable_buffer.hpp +++ /dev/null @@ -1,111 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// resizable_buffer.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/allocator.hpp" -#endif - -#include - -namespace duckdb { - -class ByteBuffer { // on to the 10 thousandth impl -public: - ByteBuffer() {}; - ByteBuffer(data_ptr_t ptr, uint64_t len) : ptr(ptr), len(len) {}; - - data_ptr_t ptr = nullptr; - uint64_t len = 0; - -public: - void inc(const uint64_t increment) { - available(increment); - unsafe_inc(increment); - } - - void unsafe_inc(const uint64_t increment) { - len -= increment; - ptr += increment; - } - - template - T read() { - available(sizeof(T)); - return unsafe_read(); - } - - template - T unsafe_read() { - T val = unsafe_get(); - unsafe_inc(sizeof(T)); - return val; - } - - template - T get() { - available(sizeof(T)); - return unsafe_get(); - } - - template - T unsafe_get() { - return Load(ptr); - } - - void copy_to(char *dest, const uint64_t len) const { - available(len); - unsafe_copy_to(dest, len); - } - - void unsafe_copy_to(char *dest, const uint64_t len) const { - std::memcpy(dest, ptr, len); - } - - void zero() const { - std::memset(ptr, 0, len); - } - - void available(const uint64_t req_len) const { - if (!check_available(req_len)) { - throw std::runtime_error("Out of buffer"); - } - } - - bool check_available(const uint64_t req_len) const { - return req_len <= len; - } -}; - -class ResizeableBuffer : public ByteBuffer { -public: - ResizeableBuffer() { - } - ResizeableBuffer(Allocator &allocator, const uint64_t new_size) { - resize(allocator, new_size); - } - void resize(Allocator &allocator, const uint64_t new_size) { - len = new_size; - if (new_size == 0) { - return; - } - if (new_size > alloc_len) { - alloc_len = NextPowerOfTwo(new_size); - allocated_data = allocator.Allocate(alloc_len); - ptr = allocated_data.get(); - } - } - -private: - AllocatedData allocated_data; - idx_t 
alloc_len = 0; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/row_number_column_reader.hpp b/src/duckdb/extension/parquet/include/row_number_column_reader.hpp deleted file mode 100644 index cdd5df1f3..000000000 --- a/src/duckdb/extension/parquet/include/row_number_column_reader.hpp +++ /dev/null @@ -1,55 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// row_number_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/limits.hpp" -#endif -#include "column_reader.hpp" -#include "templated_column_reader.hpp" - -namespace duckdb { - -//! Reads a file-absolute row number as a virtual column that's not actually stored in the file -class RowNumberColumnReader : public ColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::INT64; - -public: - RowNumberColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p); - -public: - idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out, - Vector &result) override; - - unique_ptr Stats(idx_t row_group_idx_p, const vector &columns) override; - - void InitializeRead(idx_t row_group_idx_p, const vector &columns, TProtocol &protocol_p) override; - - void Skip(idx_t num_values) override { - row_group_offset += num_values; - } - idx_t GroupRowsAvailable() override { - return NumericLimits::Maximum(); - }; - uint64_t TotalCompressedSize() override { - return 0; - } - idx_t FileOffset() const override { - return 0; - } - void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override { - } - -private: - idx_t row_group_offset; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/string_column_reader.hpp b/src/duckdb/extension/parquet/include/string_column_reader.hpp deleted file mode 100644 index 2ab96a296..000000000 --- a/src/duckdb/extension/parquet/include/string_column_reader.hpp +++ /dev/null @@ -1,45 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// string_column_reader.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "column_reader.hpp" - -namespace duckdb { - -struct StringParquetValueConversion { - static string_t PlainRead(ByteBuffer &plain_data, ColumnReader &reader); - static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader); - static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count); - static string_t UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader); - static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader); -}; - -class StringColumnReader : public TemplatedColumnReader { -public: - static constexpr const PhysicalType TYPE = PhysicalType::VARCHAR; - -public: - StringColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p, - idx_t max_define_p, idx_t max_repeat_p); - idx_t fixed_width_string_length; - idx_t delta_offset = 0; - -public: - void PrepareDeltaLengthByteArray(ResizeableBuffer &buffer) override; - void PrepareDeltaByteArray(ResizeableBuffer &buffer) override; - void DeltaByteArray(uint8_t *defines, idx_t num_values, parquet_filter_t &filter, idx_t result_offset, - Vector &result) 
override;
-	static uint32_t VerifyString(const char *str_data, uint32_t str_len, const bool isVarchar);
-	uint32_t VerifyString(const char *str_data, uint32_t str_len);
-
-protected:
-	void PlainReference(shared_ptr<ByteBuffer> plain_data, Vector &result) override;
-};
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/parquet/include/struct_column_reader.hpp b/src/duckdb/extension/parquet/include/struct_column_reader.hpp
deleted file mode 100644
index 4a0254695..000000000
--- a/src/duckdb/extension/parquet/include/struct_column_reader.hpp
+++ /dev/null
@@ -1,40 +0,0 @@
-//===----------------------------------------------------------------------===//
-// DuckDB
-//
-// struct_column_reader.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include "column_reader.hpp"
-#include "templated_column_reader.hpp"
-
-namespace duckdb {
-
-class StructColumnReader : public ColumnReader {
-public:
-	static constexpr const PhysicalType TYPE = PhysicalType::STRUCT;
-
-public:
-	StructColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p,
-	                   idx_t max_define_p, idx_t max_repeat_p, vector<unique_ptr<ColumnReader>> child_readers_p);
-
-	vector<unique_ptr<ColumnReader>> child_readers;
-
-public:
-	ColumnReader &GetChildReader(idx_t child_idx);
-
-	void InitializeRead(idx_t row_group_idx_p, const vector<ColumnChunk> &columns, TProtocol &protocol_p) override;
-
-	idx_t Read(uint64_t num_values, parquet_filter_t &filter, data_ptr_t define_out, data_ptr_t repeat_out,
-	           Vector &result) override;
-
-	void Skip(idx_t num_values) override;
-	idx_t GroupRowsAvailable() override;
-	uint64_t TotalCompressedSize() override;
-	void RegisterPrefetch(ThriftFileTransport &transport, bool allow_merge) override;
-};
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/parquet/include/templated_column_reader.hpp b/src/duckdb/extension/parquet/include/templated_column_reader.hpp
deleted file mode 100644
index d85865309..000000000
--- a/src/duckdb/extension/parquet/include/templated_column_reader.hpp
+++ /dev/null
@@ -1,93 +0,0 @@
-//===----------------------------------------------------------------------===//
-// DuckDB
-//
-// templated_column_reader.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include "column_reader.hpp"
-#include "duckdb/common/helper.hpp"
-
-namespace duckdb {
-
-template <class VALUE_TYPE>
-struct TemplatedParquetValueConversion {
-
-	static VALUE_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) {
-		return plain_data.read<VALUE_TYPE>();
-	}
-
-	static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) {
-		plain_data.inc(sizeof(VALUE_TYPE));
-	}
-
-	static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) {
-		return plain_data.check_available(count * sizeof(VALUE_TYPE));
-	}
-
-	static VALUE_TYPE UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) {
-		return plain_data.unsafe_read<VALUE_TYPE>();
-	}
-
-	static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) {
-		plain_data.unsafe_inc(sizeof(VALUE_TYPE));
-	}
-};
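// Editorial sketch, not part of the original sources: the conversion policy
// above is a thin wrapper over ByteBuffer, so plain decoding reduces to
// bounds-checked little-endian loads. Decoding three plain-encoded INT32
// values could look like this (the buffer contents are assumed):
inline void DecodePlainInt32Sketch(data_ptr_t data, uint64_t len) {
	ByteBuffer buf(data, len);
	// One batched bounds check, mirroring PlainAvailable(), then the
	// unchecked fast path per value, mirroring UnsafePlainRead().
	if (buf.check_available(3 * sizeof(int32_t))) {
		for (idx_t i = 0; i < 3; i++) {
			int32_t value = buf.unsafe_read<int32_t>(); // advances ptr and shrinks len
			(void)value; // a real reader would append this to a result Vector
		}
	}
}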
-
-template <class VALUE_TYPE, class VALUE_CONVERSION>
-class TemplatedColumnReader : public ColumnReader {
-public:
-	static constexpr const PhysicalType TYPE = PhysicalType::INVALID;
-
-public:
-	TemplatedColumnReader(ParquetReader &reader, LogicalType type_p, const SchemaElement &schema_p, idx_t schema_idx_p,
-	                      idx_t max_define_p, idx_t max_repeat_p)
-	    : ColumnReader(reader, std::move(type_p), schema_p, schema_idx_p, max_define_p, max_repeat_p) {};
-
-	shared_ptr<ResizeableBuffer> dict;
-
-public:
-	void AllocateDict(idx_t size) {
-		if (!dict) {
-			dict = make_shared_ptr<ResizeableBuffer>(GetAllocator(), size);
-		} else {
-			dict->resize(GetAllocator(), size);
-		}
-	}
-
-	void Plain(shared_ptr<ByteBuffer> plain_data, uint8_t *defines, uint64_t num_values, parquet_filter_t *filter,
-	           idx_t result_offset, Vector &result) override {
-		PlainTemplated<VALUE_TYPE, VALUE_CONVERSION>(std::move(plain_data), defines, num_values, filter, result_offset,
-		                                             result);
-	}
-};
-
-template <class PARQUET_PHYSICAL_TYPE, class DUCKDB_PHYSICAL_TYPE,
-          DUCKDB_PHYSICAL_TYPE (*FUNC)(const PARQUET_PHYSICAL_TYPE &)>
-struct CallbackParquetValueConversion {
-
-	static DUCKDB_PHYSICAL_TYPE PlainRead(ByteBuffer &plain_data, ColumnReader &reader) {
-		return FUNC(plain_data.read<PARQUET_PHYSICAL_TYPE>());
-	}
-
-	static void PlainSkip(ByteBuffer &plain_data, ColumnReader &reader) {
-		plain_data.inc(sizeof(PARQUET_PHYSICAL_TYPE));
-	}
-
-	static bool PlainAvailable(const ByteBuffer &plain_data, const idx_t count) {
-		return plain_data.check_available(count * sizeof(PARQUET_PHYSICAL_TYPE));
-	}
-
-	static DUCKDB_PHYSICAL_TYPE UnsafePlainRead(ByteBuffer &plain_data, ColumnReader &reader) {
-		return FUNC(plain_data.unsafe_read<PARQUET_PHYSICAL_TYPE>());
-	}
-
-	static void UnsafePlainSkip(ByteBuffer &plain_data, ColumnReader &reader) {
-		plain_data.unsafe_inc(sizeof(PARQUET_PHYSICAL_TYPE));
-	}
-};
-
-} // namespace duckdb
diff --git a/src/duckdb/extension/parquet/include/thrift_tools.hpp b/src/duckdb/extension/parquet/include/thrift_tools.hpp
deleted file mode 100644
index de1eaca34..000000000
--- a/src/duckdb/extension/parquet/include/thrift_tools.hpp
+++ /dev/null
@@ -1,219 +0,0 @@
-//===----------------------------------------------------------------------===//
-// DuckDB
-//
-// thrift_tools.hpp
-//
-//
-//===----------------------------------------------------------------------===//
-
-#pragma once
-
-#include <list>
-#include "thrift/protocol/TCompactProtocol.h"
-#include "thrift/transport/TBufferTransports.h"
-
-#include "duckdb.hpp"
-#ifndef DUCKDB_AMALGAMATION
-#include "duckdb/common/file_system.hpp"
-#include "duckdb/common/allocator.hpp"
-#endif
-
-namespace duckdb {
-
-// A ReadHead for prefetching data in a specific range
-struct ReadHead {
-	ReadHead(idx_t location, uint64_t size) : location(location), size(size) {};
-	// Hint info
-	idx_t location;
-	uint64_t size;
-
-	// Current info
-	AllocatedData data;
-	bool data_isset = false;
-
-	idx_t GetEnd() const {
-		return size + location;
-	}
-
-	void Allocate(Allocator &allocator) {
-		data = allocator.Allocate(size);
-	}
-};
-
-// Comparator for ReadHeads that are either overlapping, adjacent, or within ALLOW_GAP bytes from each other
-struct ReadHeadComparator {
-	static constexpr uint64_t ALLOW_GAP = 1 << 14; // 16 KiB
-	bool operator()(const ReadHead *a, const ReadHead *b) const {
-		auto a_start = a->location;
-		auto a_end = a->location + a->size;
-		auto b_start = b->location;
-
-		if (a_end <= NumericLimits<idx_t>::Maximum() - ALLOW_GAP) {
-			a_end += ALLOW_GAP;
-		}
-
-		return a_start < b_start && a_end < b_start;
-	}
-};
-
-// Two-step read ahead buffer
-// 1: register all ranges that will be read, merging ranges that are consecutive
-// 2: prefetch all registered ranges
-struct ReadAheadBuffer {
-	ReadAheadBuffer(Allocator &allocator, FileHandle &handle) : allocator(allocator), handle(handle) {
-	}
-
-	// The list of read heads
-	std::list<ReadHead> read_heads;
-	// Set for merging consecutive ranges
-	std::set<ReadHead *, ReadHeadComparator> merge_set;
-
-	Allocator &allocator;
-	FileHandle &handle;
-
-	idx_t total_size = 0;
-
-	// Add a read head to the prefetching list
-	void AddReadHead(idx_t pos, uint64_t len, bool merge_buffers = true) {
-		// Attempt to merge with existing
-		if (merge_buffers) {
-			ReadHead new_read_head {pos, len};
-			auto lookup_set = merge_set.find(&new_read_head);
-			if (lookup_set != merge_set.end()) {
-				auto existing_head = *lookup_set;
-				auto new_start = MinValue(existing_head->location, new_read_head.location);
-				auto new_length = MaxValue(existing_head->GetEnd(), new_read_head.GetEnd()) - new_start;
-				existing_head->location = new_start;
-				existing_head->size = new_length;
-				return;
-			}
-		}
-
-		read_heads.emplace_front(ReadHead(pos, len));
-		total_size += len;
-		auto &read_head = read_heads.front();
-
-		if (merge_buffers) {
-			merge_set.insert(&read_head);
-		}
-
-		if (read_head.GetEnd() > handle.GetFileSize()) {
-			throw std::runtime_error("Prefetch registered for bytes outside file: " + handle.GetPath() +
-			                         ", attempted range: [" + std::to_string(pos) + ", " +
-			                         std::to_string(read_head.GetEnd()) +
-			                         "), file size: " + std::to_string(handle.GetFileSize()));
-		}
-	}
-
-	// Returns the relevant read head
-	ReadHead *GetReadHead(idx_t pos) {
-		for (auto &read_head : read_heads) {
-			if (pos >= read_head.location && pos < read_head.GetEnd()) {
-				return &read_head;
-			}
-		}
-		return nullptr;
-	}
-
-	// Prefetch all read heads
-	void Prefetch() {
-		for (auto &read_head : read_heads) {
-			read_head.Allocate(allocator);
-
-			if (read_head.GetEnd() > handle.GetFileSize()) {
-				throw std::runtime_error("Prefetch requested for bytes outside file");
-			}
-
-			handle.Read(read_head.data.get(), read_head.size, read_head.location);
-			read_head.data_isset = true;
-		}
-	}
-};
-
-class ThriftFileTransport : public duckdb_apache::thrift::transport::TVirtualTransport<ThriftFileTransport> {
-public:
-	static constexpr uint64_t PREFETCH_FALLBACK_BUFFERSIZE = 1000000;
-
-	ThriftFileTransport(Allocator &allocator, FileHandle &handle_p, bool prefetch_mode_p)
-	    : handle(handle_p), location(0), allocator(allocator), ra_buffer(ReadAheadBuffer(allocator, handle_p)),
-	      prefetch_mode(prefetch_mode_p) {
-	}
-
-	uint32_t read(uint8_t *buf, uint32_t len) {
-		auto prefetch_buffer = ra_buffer.GetReadHead(location);
-		if (prefetch_buffer != nullptr && location - prefetch_buffer->location + len <= prefetch_buffer->size) {
-			D_ASSERT(location - prefetch_buffer->location + len <= prefetch_buffer->size);
-
-			if (!prefetch_buffer->data_isset) {
-				prefetch_buffer->Allocate(allocator);
-				handle.Read(prefetch_buffer->data.get(), prefetch_buffer->size, prefetch_buffer->location);
-				prefetch_buffer->data_isset = true;
-			}
-			memcpy(buf, prefetch_buffer->data.get() + location - prefetch_buffer->location, len);
-		} else {
-			if (prefetch_mode && len < PREFETCH_FALLBACK_BUFFERSIZE && len > 0) {
-				Prefetch(location, MinValue(PREFETCH_FALLBACK_BUFFERSIZE, handle.GetFileSize() - location));
-				auto prefetch_buffer_fallback = ra_buffer.GetReadHead(location);
-				D_ASSERT(location - prefetch_buffer_fallback->location + len <= prefetch_buffer_fallback->size);
-				memcpy(buf, prefetch_buffer_fallback->data.get() + location - prefetch_buffer_fallback->location, len);
-			} else {
-				handle.Read(buf, len, location);
-			}
-		}
-		location += len;
-		return len;
-	}
-
-	// Prefetch a single buffer
-	void Prefetch(idx_t pos, uint64_t len) {
-		RegisterPrefetch(pos, len, false);
-		FinalizeRegistration();
-		PrefetchRegistered();
-	}
-
-	// Register a buffer for prefetching
-	void RegisterPrefetch(idx_t pos, uint64_t len, bool can_merge = true) {
-		ra_buffer.AddReadHead(pos, len, can_merge);
-	}
-
-	// Prevents any further merges, should be called before PrefetchRegistered
-	void FinalizeRegistration() {
-		ra_buffer.merge_set.clear();
-	}
-
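	// Editorial sketch, not part of the original class: the intended two-step
	// usage of the machinery above, written as a hypothetical helper member.
	// The offsets and lengths are made up; real callers derive them from
	// row-group metadata.
	void PrefetchTwoRangesSketch() {
		// Step 1: register everything that will be read; ranges within
		// ALLOW_GAP bytes of each other are merged into a single read head.
		RegisterPrefetch(/*pos=*/4096, /*len=*/1 << 20);
		RegisterPrefetch(/*pos=*/4096 + (1 << 20), /*len=*/1 << 20);
		// Step 2: freeze the merge set, then issue one read per merged range.
		FinalizeRegistration();
		PrefetchRegistered();
	}

-	// Prefetch all previously registered ranges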
- void PrefetchRegistered() { - ra_buffer.Prefetch(); - } - - void ClearPrefetch() { - ra_buffer.read_heads.clear(); - ra_buffer.merge_set.clear(); - } - - void SetLocation(idx_t location_p) { - location = location_p; - } - - idx_t GetLocation() { - return location; - } - idx_t GetSize() { - return handle.file_system.GetFileSize(handle); - } - -private: - FileHandle &handle; - idx_t location; - - Allocator &allocator; - - // Multi-buffer prefetch - ReadAheadBuffer ra_buffer; - - // Whether the prefetch mode is enabled. In this mode the DirectIO flag of the handle will be set and the parquet - // reader will manage the read buffering. - bool prefetch_mode; -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/include/zstd_file_system.hpp b/src/duckdb/extension/parquet/include/zstd_file_system.hpp deleted file mode 100644 index 5b132bc8a..000000000 --- a/src/duckdb/extension/parquet/include/zstd_file_system.hpp +++ /dev/null @@ -1,35 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// zstd_file_system.hpp -// -// -//===----------------------------------------------------------------------===// - -#pragma once - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/compressed_file_system.hpp" -#endif - -namespace duckdb { - -class ZStdFileSystem : public CompressedFileSystem { -public: - unique_ptr OpenCompressedFile(unique_ptr handle, bool write) override; - - std::string GetName() const override { - return "ZStdFileSystem"; - } - - unique_ptr CreateStream() override; - idx_t InBufferSize() override; - idx_t OutBufferSize() override; - - static int64_t DefaultCompressionLevel(); - static int64_t MinimumCompressionLevel(); - static int64_t MaximumCompressionLevel(); -}; - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_crypto.cpp b/src/duckdb/extension/parquet/parquet_crypto.cpp deleted file mode 100644 index 070c381d7..000000000 --- a/src/duckdb/extension/parquet/parquet_crypto.cpp +++ /dev/null @@ -1,420 +0,0 @@ -#include "parquet_crypto.hpp" - -#include "mbedtls_wrapper.hpp" -#include "thrift_tools.hpp" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/storage/arena_allocator.hpp" -#endif - -namespace duckdb { - -ParquetKeys &ParquetKeys::Get(ClientContext &context) { - auto &cache = ObjectCache::GetObjectCache(context); - if (!cache.Get(ParquetKeys::ObjectType())) { - cache.Put(ParquetKeys::ObjectType(), make_shared_ptr()); - } - return *cache.Get(ParquetKeys::ObjectType()); -} - -void ParquetKeys::AddKey(const string &key_name, const string &key) { - keys[key_name] = key; -} - -bool ParquetKeys::HasKey(const string &key_name) const { - return keys.find(key_name) != keys.end(); -} - -const string &ParquetKeys::GetKey(const string &key_name) const { - D_ASSERT(HasKey(key_name)); - return keys.at(key_name); -} - -string ParquetKeys::ObjectType() { - return "parquet_keys"; -} - -string ParquetKeys::GetObjectType() { - return ObjectType(); -} - -ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context_p) : context(context_p) { -} - -ParquetEncryptionConfig::ParquetEncryptionConfig(ClientContext &context_p, const Value &arg) - : ParquetEncryptionConfig(context_p) { - - if (arg.type().id() != LogicalTypeId::STRUCT) { - throw BinderException("Parquet encryption_config must be of type STRUCT"); - } - const auto &child_types = 
StructType::GetChildTypes(arg.type()); - auto &children = StructValue::GetChildren(arg); - const auto &keys = ParquetKeys::Get(context); - for (idx_t i = 0; i < StructType::GetChildCount(arg.type()); i++) { - auto &struct_key = child_types[i].first; - if (StringUtil::Lower(struct_key) == "footer_key") { - const auto footer_key_name = StringValue::Get(children[i].DefaultCastAs(LogicalType::VARCHAR)); - if (!keys.HasKey(footer_key_name)) { - throw BinderException( - "No key with name \"%s\" exists. Add it with PRAGMA add_parquet_key('','');", - footer_key_name); - } - footer_key = footer_key_name; - } else if (StringUtil::Lower(struct_key) == "column_keys") { - throw NotImplementedException("Parquet encryption_config column_keys not yet implemented"); - } else { - throw BinderException("Unknown key in encryption_config \"%s\"", struct_key); - } - } -} - -shared_ptr ParquetEncryptionConfig::Create(ClientContext &context, const Value &arg) { - return shared_ptr(new ParquetEncryptionConfig(context, arg)); -} - -const string &ParquetEncryptionConfig::GetFooterKey() const { - const auto &keys = ParquetKeys::Get(context); - D_ASSERT(!footer_key.empty()); - D_ASSERT(keys.HasKey(footer_key)); - return keys.GetKey(footer_key); -} - -using duckdb_apache::thrift::protocol::TCompactProtocolFactoryT; -using duckdb_apache::thrift::transport::TTransport; - -//! Encryption wrapper for a transport protocol -class EncryptionTransport : public TTransport { -public: - EncryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) - : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState()), - allocator(Allocator::DefaultAllocator(), ParquetCrypto::CRYPTO_BLOCK_SIZE) { - Initialize(key); - } - - bool isOpen() const override { - return trans.isOpen(); - } - - void open() override { - trans.open(); - } - - void close() override { - trans.close(); - } - - void write_virt(const uint8_t *buf, uint32_t len) override { - memcpy(allocator.Allocate(len), buf, len); - } - - uint32_t Finalize() { - // Write length - const auto ciphertext_length = allocator.SizeInBytes(); - const uint32_t total_length = ParquetCrypto::NONCE_BYTES + ciphertext_length + ParquetCrypto::TAG_BYTES; - - trans.write(const_data_ptr_cast(&total_length), ParquetCrypto::LENGTH_BYTES); - // Write nonce at beginning of encrypted chunk - trans.write(nonce, ParquetCrypto::NONCE_BYTES); - - data_t aes_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE]; - auto current = allocator.GetTail(); - - // Loop through the whole chunk - while (current != nullptr) { - for (idx_t pos = 0; pos < current->current_position; pos += ParquetCrypto::CRYPTO_BLOCK_SIZE) { - auto next = MinValue(current->current_position - pos, ParquetCrypto::CRYPTO_BLOCK_SIZE); - auto write_size = - aes->Process(current->data.get() + pos, next, aes_buffer, ParquetCrypto::CRYPTO_BLOCK_SIZE); - trans.write(aes_buffer, write_size); - } - current = current->prev; - } - - // Finalize the last encrypted data - data_t tag[ParquetCrypto::TAG_BYTES]; - auto write_size = aes->Finalize(aes_buffer, 0, tag, ParquetCrypto::TAG_BYTES); - trans.write(aes_buffer, write_size); - // Write tag for verification - trans.write(tag, ParquetCrypto::TAG_BYTES); - - return ParquetCrypto::LENGTH_BYTES + total_length; - } - -private: - void Initialize(const string &key) { - // Generate Nonce - aes->GenerateRandomData(nonce, ParquetCrypto::NONCE_BYTES); - // Initialize Encryption - aes->InitializeEncryption(nonce, ParquetCrypto::NONCE_BYTES, &key); - } - -private: 
- //! Protocol and corresponding transport that we're wrapping - TProtocol &prot; - TTransport &trans; - - //! AES context and buffers - shared_ptr aes; - - //! Nonce created by Initialize() - data_t nonce[ParquetCrypto::NONCE_BYTES]; - - //! Arena Allocator to fully materialize in memory before encrypting - ArenaAllocator allocator; -}; - -//! Decryption wrapper for a transport protocol -class DecryptionTransport : public TTransport { -public: - DecryptionTransport(TProtocol &prot_p, const string &key, const EncryptionUtil &encryption_util_p) - : prot(prot_p), trans(*prot.getTransport()), aes(encryption_util_p.CreateEncryptionState()), - read_buffer_size(0), read_buffer_offset(0) { - Initialize(key); - } - uint32_t read_virt(uint8_t *buf, uint32_t len) override { - const uint32_t result = len; - - if (len > transport_remaining - ParquetCrypto::TAG_BYTES + read_buffer_size - read_buffer_offset) { - throw InvalidInputException("Too many bytes requested from crypto buffer"); - } - - while (len != 0) { - if (read_buffer_offset == read_buffer_size) { - ReadBlock(buf); - } - const auto next = MinValue(read_buffer_size - read_buffer_offset, len); - read_buffer_offset += next; - buf += next; - len -= next; - } - - return result; - } - - uint32_t Finalize() { - - if (read_buffer_offset != read_buffer_size) { - throw InternalException("DecryptionTransport::Finalize was called with bytes remaining in read buffer: \n" - "read buffer offset: %d, read buffer size: %d", - read_buffer_offset, read_buffer_size); - } - - data_t computed_tag[ParquetCrypto::TAG_BYTES]; - - if (aes->IsOpenSSL()) { - // For OpenSSL, the obtained tag is an input argument for aes->Finalize() - transport_remaining -= trans.read(computed_tag, ParquetCrypto::TAG_BYTES); - if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { - throw InternalException( - "DecryptionTransport::Finalize was called with bytes remaining in AES context out"); - } - } else { - // For mbedtls, computed_tag is an output argument for aes->Finalize() - if (aes->Finalize(read_buffer, 0, computed_tag, ParquetCrypto::TAG_BYTES) != 0) { - throw InternalException( - "DecryptionTransport::Finalize was called with bytes remaining in AES context out"); - } - VerifyTag(computed_tag); - } - - if (transport_remaining != 0) { - throw InvalidInputException("Encoded ciphertext length differs from actual ciphertext length"); - } - - return ParquetCrypto::LENGTH_BYTES + total_bytes; - } - - AllocatedData ReadAll() { - D_ASSERT(transport_remaining == total_bytes - ParquetCrypto::NONCE_BYTES); - auto result = Allocator::DefaultAllocator().Allocate(transport_remaining - ParquetCrypto::TAG_BYTES); - read_virt(result.get(), transport_remaining - ParquetCrypto::TAG_BYTES); - Finalize(); - return result; - } - -private: - void Initialize(const string &key) { - // Read encoded length (don't add to read_bytes) - data_t length_buf[ParquetCrypto::LENGTH_BYTES]; - trans.read(length_buf, ParquetCrypto::LENGTH_BYTES); - total_bytes = Load(length_buf); - transport_remaining = total_bytes; - // Read nonce and initialize AES - transport_remaining -= trans.read(nonce, ParquetCrypto::NONCE_BYTES); - // check whether context is initialized - aes->InitializeDecryption(nonce, ParquetCrypto::NONCE_BYTES, &key); - } - - void ReadBlock(uint8_t *buf) { - // Read from transport into read_buffer at one AES block size offset (up to the tag) - read_buffer_size = MinValue(ParquetCrypto::CRYPTO_BLOCK_SIZE, transport_remaining - ParquetCrypto::TAG_BYTES); - transport_remaining -= 
trans.read(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size);
-
-		// Decrypt from read_buffer + block size into read_buffer start (decryption can trail behind in same buffer)
-#ifdef DEBUG
-		auto size = aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
-		                         ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
-		D_ASSERT(size == read_buffer_size);
-#else
-		aes->Process(read_buffer + ParquetCrypto::BLOCK_SIZE, read_buffer_size, buf,
-		             ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE);
-#endif
-		read_buffer_offset = 0;
-	}
-
-	void VerifyTag(data_t *computed_tag) {
-		data_t read_tag[ParquetCrypto::TAG_BYTES];
-		transport_remaining -= trans.read(read_tag, ParquetCrypto::TAG_BYTES);
-		if (memcmp(computed_tag, read_tag, ParquetCrypto::TAG_BYTES) != 0) {
-			throw InvalidInputException("Computed AES tag differs from read AES tag, are you using the right key?");
-		}
-	}
-
-private:
-	//! Protocol and corresponding transport that we're wrapping
-	TProtocol &prot;
-	TTransport &trans;
-
-	//! AES context and buffers
-	shared_ptr<EncryptionState> aes;
-
-	//! We read/decrypt big blocks at a time
-	data_t read_buffer[ParquetCrypto::CRYPTO_BLOCK_SIZE + ParquetCrypto::BLOCK_SIZE];
-	uint32_t read_buffer_size;
-	uint32_t read_buffer_offset;
-
-	//! Remaining bytes to read, set by Initialize(), decremented by ReadBlock()
-	uint32_t total_bytes;
-	uint32_t transport_remaining;
-	//! Nonce read by Initialize()
-	data_t nonce[ParquetCrypto::NONCE_BYTES];
-};
-
-class SimpleReadTransport : public TTransport {
-public:
-	explicit SimpleReadTransport(data_ptr_t read_buffer_p, uint32_t read_buffer_size_p)
-	    : read_buffer(read_buffer_p), read_buffer_size(read_buffer_size_p), read_buffer_offset(0) {
-	}
-
-	uint32_t read_virt(uint8_t *buf, uint32_t len) override {
-		const auto remaining = read_buffer_size - read_buffer_offset;
-		if (len > remaining) {
-			return remaining;
-		}
-		memcpy(buf, read_buffer + read_buffer_offset, len);
-		read_buffer_offset += len;
-		return len;
-	}
-
-private:
-	const data_ptr_t read_buffer;
-	const uint32_t read_buffer_size;
-	uint32_t read_buffer_offset;
-};
-
-uint32_t ParquetCrypto::Read(TBase &object, TProtocol &iprot, const string &key,
-                             const EncryptionUtil &encryption_util_p) {
-	TCompactProtocolFactoryT<DecryptionTransport> tproto_factory;
-	auto dprot = tproto_factory.getProtocol(std::make_shared<DecryptionTransport>(iprot, key, encryption_util_p));
-	auto &dtrans = reinterpret_cast<DecryptionTransport &>(*dprot->getTransport());
-
-	// We have to read the whole thing, otherwise thrift throws an error before we realize the decryption is wrong
-	auto all = dtrans.ReadAll();
-	TCompactProtocolFactoryT<SimpleReadTransport> tsimple_proto_factory;
-	auto simple_prot =
-	    tsimple_proto_factory.getProtocol(std::make_shared<SimpleReadTransport>(all.get(), all.GetSize()));
-
-	// Read the object
-	object.read(simple_prot.get());
-
-	return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + all.GetSize() + ParquetCrypto::TAG_BYTES;
-}
-
-uint32_t ParquetCrypto::Write(const TBase &object, TProtocol &oprot, const string &key,
-                              const EncryptionUtil &encryption_util_p) {
-	// Create encryption protocol
-	TCompactProtocolFactoryT<EncryptionTransport> tproto_factory;
-	auto eprot = tproto_factory.getProtocol(std::make_shared<EncryptionTransport>(oprot, key, encryption_util_p));
-	auto &etrans = reinterpret_cast<EncryptionTransport &>(*eprot->getTransport());
-
-	// Write the object in memory
-	object.write(eprot.get());
-
-	// Encrypt and write to oprot
-	return etrans.Finalize();
-}
-
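// Editorial sketch, not from the patched sources: the framing produced by
// EncryptionTransport::Finalize and consumed by DecryptionTransport is
//   length || nonce || ciphertext || tag
// where the length prefix covers everything after itself. The concrete sizes
// are assumed here to be the usual AES-GCM conventions (LENGTH_BYTES = 4,
// NONCE_BYTES = 12, TAG_BYTES = 16), not quoted from the headers; with a
// stream cipher mode the ciphertext is as long as the plaintext, so the total
// footprint of an encrypted thrift object is:
constexpr uint32_t EncryptedModuleSizeSketch(uint32_t plaintext_size) {
	return ParquetCrypto::LENGTH_BYTES + ParquetCrypto::NONCE_BYTES + plaintext_size + ParquetCrypto::TAG_BYTES;
}

-uint32_t ParquetCrypto::ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size,
-                                 const string &key, const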
EncryptionUtil &encryption_util_p) { - // Create decryption protocol - TCompactProtocolFactoryT tproto_factory; - auto dprot = tproto_factory.getProtocol(std::make_shared(iprot, key, encryption_util_p)); - auto &dtrans = reinterpret_cast(*dprot->getTransport()); - - // Read buffer - dtrans.read(buffer, buffer_size); - - // Verify AES tag and read length - return dtrans.Finalize(); -} - -uint32_t ParquetCrypto::WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size, - const string &key, const EncryptionUtil &encryption_util_p) { - // FIXME: we know the size upfront so we could do a streaming write instead of this - // Create encryption protocol - TCompactProtocolFactoryT tproto_factory; - auto eprot = tproto_factory.getProtocol(std::make_shared(oprot, key, encryption_util_p)); - auto &etrans = reinterpret_cast(*eprot->getTransport()); - - // Write the data in memory - etrans.write(buffer, buffer_size); - - // Encrypt and write to oprot - return etrans.Finalize(); -} - -bool ParquetCrypto::ValidKey(const std::string &key) { - switch (key.size()) { - case 16: - case 24: - case 32: - return true; - default: - return false; - } -} - -string Base64Decode(const string &key) { - auto result_size = Blob::FromBase64Size(key); - auto output = duckdb::unique_ptr(new unsigned char[result_size]); - Blob::FromBase64(key, output.get(), result_size); - string decoded_key(reinterpret_cast(output.get()), result_size); - return decoded_key; -} - -void ParquetCrypto::AddKey(ClientContext &context, const FunctionParameters ¶meters) { - const auto &key_name = StringValue::Get(parameters.values[0]); - const auto &key = StringValue::Get(parameters.values[1]); - - auto &keys = ParquetKeys::Get(context); - if (ValidKey(key)) { - keys.AddKey(key_name, key); - } else { - string decoded_key; - try { - decoded_key = Base64Decode(key); - } catch (const ConversionException &e) { - throw InvalidInputException("Invalid AES key. Not a plain AES key NOR a base64 encoded string"); - } - if (!ValidKey(decoded_key)) { - throw InvalidInputException( - "Invalid AES key. 
Must have a length of 128, 192, or 256 bits (16, 24, or 32 bytes)"); - } - keys.AddKey(key_name, decoded_key); - } -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_extension.cpp b/src/duckdb/extension/parquet/parquet_extension.cpp deleted file mode 100644 index 500b89919..000000000 --- a/src/duckdb/extension/parquet/parquet_extension.cpp +++ /dev/null @@ -1,1758 +0,0 @@ -#define DUCKDB_EXTENSION_MAIN - -#include "parquet_extension.hpp" - -#include "cast_column_reader.hpp" -#include "duckdb.hpp" -#include "duckdb/parser/expression/positional_reference_expression.hpp" -#include "duckdb/parser/query_node/select_node.hpp" -#include "duckdb/parser/tableref/subqueryref.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/query_node/bound_select_node.hpp" -#include "geo_parquet.hpp" -#include "parquet_crypto.hpp" -#include "parquet_metadata.hpp" -#include "parquet_reader.hpp" -#include "parquet_writer.hpp" -#include "struct_column_reader.hpp" -#include "zstd_file_system.hpp" - -#include -#include -#include -#include -#include -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" -#include "duckdb/common/constants.hpp" -#include "duckdb/common/enums/file_compression_type.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/type_visitor.hpp" -#include "duckdb/function/copy_function.hpp" -#include "duckdb/function/pragma_function.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/extension_util.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/storage/statistics/base_statistics.hpp" -#include "duckdb/storage/table/row_group.hpp" -#endif - -namespace duckdb { - -struct ParquetReadBindData : public TableFunctionData { - shared_ptr file_list; - unique_ptr multi_file_reader; - - shared_ptr initial_reader; - atomic chunk_count; - vector names; - vector types; - //! Table column names - set when using COPY tbl FROM file.parquet - vector table_columns; - - // The union readers are created (when parquet union_by_name option is on) during binding - // Those readers can be re-used during ParquetParallelStateNext - vector> union_readers; - - // These come from the initial_reader, but need to be stored in case the initial_reader is removed by a filter - idx_t initial_file_cardinality; - idx_t initial_file_row_groups; - idx_t explicit_cardinality = 0; // can be set to inject exterior cardinality knowledge (e.g. 
from a data lake) - ParquetOptions parquet_options; - MultiFileReaderBindData reader_bind; - - void Initialize(shared_ptr reader) { - initial_reader = std::move(reader); - initial_file_cardinality = initial_reader->NumRows(); - initial_file_row_groups = initial_reader->NumRowGroups(); - parquet_options = initial_reader->parquet_options; - } - void Initialize(ClientContext &, unique_ptr &union_data) { - Initialize(std::move(union_data->reader)); - } -}; - -struct ParquetReadLocalState : public LocalTableFunctionState { - shared_ptr reader; - ParquetReaderScanState scan_state; - bool is_parallel; - idx_t batch_index; - idx_t file_index; - //! The DataChunk containing all read columns (even columns that are immediately removed) - DataChunk all_columns; -}; - -enum class ParquetFileState : uint8_t { UNOPENED, OPENING, OPEN, CLOSED }; - -struct ParquetFileReaderData { - // Create data for an unopened file - explicit ParquetFileReaderData(const string &file_to_be_opened) - : reader(nullptr), file_state(ParquetFileState::UNOPENED), file_mutex(make_uniq()), - file_to_be_opened(file_to_be_opened) { - } - // Create data for an existing reader - explicit ParquetFileReaderData(shared_ptr reader_p) - : reader(std::move(reader_p)), file_state(ParquetFileState::OPEN), file_mutex(make_uniq()) { - } - // Create data for an existing reader - explicit ParquetFileReaderData(unique_ptr union_data_p) : file_mutex(make_uniq()) { - if (union_data_p->reader) { - reader = std::move(union_data_p->reader); - file_state = ParquetFileState::OPEN; - } else { - union_data = std::move(union_data_p); - file_state = ParquetFileState::UNOPENED; - } - } - - //! Currently opened reader for the file - shared_ptr reader; - //! Flag to indicate the file is being opened - ParquetFileState file_state; - //! Mutexes to wait for the file when it is being opened - unique_ptr file_mutex; - //! Parquet options for opening the file - unique_ptr union_data; - - //! (only set when file_state is UNOPENED) the file to be opened - string file_to_be_opened; -}; - -struct ParquetReadGlobalState : public GlobalTableFunctionState { - explicit ParquetReadGlobalState(MultiFileList &file_list_p) : file_list(file_list_p) { - } - explicit ParquetReadGlobalState(unique_ptr owned_file_list_p) - : file_list(*owned_file_list_p), owned_file_list(std::move(owned_file_list_p)) { - } - - //! The file list to scan - MultiFileList &file_list; - //! The scan over the file_list - MultiFileListScanData file_list_scan; - //! Owned multi file list - if filters have been dynamically pushed into the reader - unique_ptr owned_file_list; - - unique_ptr multi_file_reader_state; - - mutex lock; - - //! The current set of parquet readers - vector> readers; - - //! Signal to other threads that a file failed to open, letting every thread abort. - bool error_opening_file = false; - - //! Index of file currently up for scanning - atomic file_index; - //! Index of row group within file currently up for scanning - idx_t row_group_index; - //! 
Batch index of the next row group to be scanned - idx_t batch_index; - - idx_t max_threads; - vector projection_ids; - vector scanned_types; - vector column_indexes; - optional_ptr filters; - - idx_t MaxThreads() const override { - return max_threads; - } - - bool CanRemoveColumns() const { - return !projection_ids.empty(); - } -}; - -struct ParquetWriteBindData : public TableFunctionData { - vector sql_types; - vector column_names; - duckdb_parquet::CompressionCodec::type codec = duckdb_parquet::CompressionCodec::SNAPPY; - vector> kv_metadata; - idx_t row_group_size = DEFAULT_ROW_GROUP_SIZE; - idx_t row_group_size_bytes = NumericLimits::Maximum(); - - //! How/Whether to encrypt the data - shared_ptr encryption_config; - bool debug_use_openssl = true; - - //! After how many distinct values should we abandon dictionary compression and bloom filters? - idx_t dictionary_size_limit = row_group_size / 100; - - //! What false positive rate are we willing to accept for bloom filters - double bloom_filter_false_positive_ratio = 0.01; - - //! After how many row groups to rotate to a new file - optional_idx row_groups_per_file; - - ChildFieldIDs field_ids; - //! The compression level, higher value is more - int64_t compression_level = ZStdFileSystem::DefaultCompressionLevel(); -}; - -struct ParquetWriteGlobalState : public GlobalFunctionData { - unique_ptr writer; -}; - -struct ParquetWriteLocalState : public LocalFunctionData { - explicit ParquetWriteLocalState(ClientContext &context, const vector &types) - : buffer(BufferAllocator::Get(context), types) { - buffer.InitializeAppend(append_state); - } - - ColumnDataCollection buffer; - ColumnDataAppendState append_state; -}; - -BindInfo ParquetGetBindInfo(const optional_ptr bind_data) { - auto bind_info = BindInfo(ScanType::PARQUET); - auto &parquet_bind = bind_data->Cast(); - - vector file_path; - for (const auto &file : parquet_bind.file_list->Files()) { - file_path.emplace_back(file); - } - - // LCOV_EXCL_START - bind_info.InsertOption("file_path", Value::LIST(LogicalType::VARCHAR, file_path)); - bind_info.InsertOption("binary_as_string", Value::BOOLEAN(parquet_bind.parquet_options.binary_as_string)); - bind_info.InsertOption("file_row_number", Value::BOOLEAN(parquet_bind.parquet_options.file_row_number)); - bind_info.InsertOption("debug_use_openssl", Value::BOOLEAN(parquet_bind.parquet_options.debug_use_openssl)); - parquet_bind.parquet_options.file_options.AddBatchInfo(bind_info); - // LCOV_EXCL_STOP - return bind_info; -} - -static void ParseFileRowNumberOption(MultiFileReaderBindData &bind_data, ParquetOptions &options, - vector &return_types, vector &names) { - if (options.file_row_number) { - if (StringUtil::CIFind(names, "file_row_number") != DConstants::INVALID_INDEX) { - throw BinderException( - "Using file_row_number option on file with column named file_row_number is not supported"); - } - - bind_data.file_row_number_idx = names.size(); - return_types.emplace_back(LogicalType::BIGINT); - names.emplace_back("file_row_number"); - } -} - -static MultiFileReaderBindData BindSchema(ClientContext &context, vector &return_types, - vector &names, ParquetReadBindData &result, ParquetOptions &options) { - D_ASSERT(!options.schema.empty()); - - options.file_options.AutoDetectHivePartitioning(*result.file_list, context); - - auto &file_options = options.file_options; - if (file_options.union_by_name || file_options.hive_partitioning) { - throw BinderException("Parquet schema cannot be combined with union_by_name=true or hive_partitioning=true"); - 
}
-
-	vector<string> schema_col_names;
-	vector<LogicalType> schema_col_types;
-	schema_col_names.reserve(options.schema.size());
-	schema_col_types.reserve(options.schema.size());
-	for (const auto &column : options.schema) {
-		schema_col_names.push_back(column.name);
-		schema_col_types.push_back(column.type);
-	}
-
-	// perform the binding on the obtained set of names + types
-	MultiFileReaderBindData bind_data;
-	result.multi_file_reader->BindOptions(options.file_options, *result.file_list, schema_col_types, schema_col_names,
-	                                      bind_data);
-
-	names = schema_col_names;
-	return_types = schema_col_types;
-	D_ASSERT(names.size() == return_types.size());
-
-	ParseFileRowNumberOption(bind_data, options, return_types, names);
-
-	return bind_data;
-}
-
-static void InitializeParquetReader(ParquetReader &reader, const ParquetReadBindData &bind_data,
-                                    const vector<ColumnIndex> &global_column_ids,
-                                    optional_ptr<TableFilterSet> table_filters, ClientContext &context,
-                                    optional_idx file_idx, optional_ptr<MultiFileReaderGlobalState> reader_state) {
-	auto &parquet_options = bind_data.parquet_options;
-	auto &reader_data = reader.reader_data;
-
-	reader.table_columns = bind_data.table_columns;
-	// Mark the file in the file list we are scanning here
-	reader_data.file_list_idx = file_idx;
-
-	if (bind_data.parquet_options.schema.empty()) {
-		bind_data.multi_file_reader->InitializeReader(
-		    reader, parquet_options.file_options, bind_data.reader_bind, bind_data.types, bind_data.names,
-		    global_column_ids, table_filters, bind_data.file_list->GetFirstFile(), context, reader_state);
-		return;
-	}
-
-	// a fixed schema was supplied, initialize the MultiFileReader settings here so we can read using the schema
-
-	// this deals with hive partitioning and filename=true
-	bind_data.multi_file_reader->FinalizeBind(parquet_options.file_options, bind_data.reader_bind, reader.GetFileName(),
-	                                          reader.GetNames(), bind_data.types, bind_data.names, global_column_ids,
-	                                          reader_data, context, reader_state);
-
-	// create a mapping from field id to column index in file
-	unordered_map<int32_t, idx_t> field_id_to_column_index;
-	auto &column_readers = reader.root_reader->Cast<StructColumnReader>().child_readers;
-	for (idx_t column_index = 0; column_index < column_readers.size(); column_index++) {
-		auto &column_reader = *column_readers[column_index];
-		auto &column_schema = column_reader.Schema();
-		if (column_schema.__isset.field_id) {
-			field_id_to_column_index[column_schema.field_id] = column_index;
-		} else if (column_reader.GetParentSchema()) {
-			auto &parent_column_schema = *column_reader.GetParentSchema();
-			if (parent_column_schema.__isset.field_id) {
-				field_id_to_column_index[parent_column_schema.field_id] = column_index;
-			}
-		}
-	}
-
-	// loop through the schema definition
-	for (idx_t i = 0; i < global_column_ids.size(); i++) {
-		auto global_column_index = global_column_ids[i].GetPrimaryIndex();
-
-		// check if this is a constant column
-		bool constant = false;
-		for (auto &entry : reader_data.constant_map) {
-			if (entry.column_id == i) {
-				constant = true;
-				break;
-			}
-		}
-		if (constant) {
-			// this column is constant for this file
-			continue;
-		}
-
-		// Handle any generated columns that are not in the schema (currently only file_row_number)
-		if (global_column_index >= parquet_options.schema.size()) {
-			if (bind_data.reader_bind.file_row_number_idx == global_column_index) {
-				reader_data.column_mapping.push_back(i);
-				reader_data.column_ids.push_back(reader.file_row_number_idx);
-			}
-			continue;
-		}
-
-		const auto &column_definition = parquet_options.schema[global_column_index];
-		auto it =
field_id_to_column_index.find(column_definition.field_id); - if (it == field_id_to_column_index.end()) { - // field id not present in file, use default value - reader_data.constant_map.emplace_back(i, column_definition.default_value); - continue; - } - - const auto &local_column_index = it->second; - auto &column_reader = column_readers[local_column_index]; - if (column_reader->Type() != column_definition.type) { - // differing types, wrap in a cast column reader - reader_data.cast_map[local_column_index] = column_definition.type; - } - - reader_data.column_mapping.push_back(i); - reader_data.column_ids.push_back(local_column_index); - } - reader_data.empty_columns = reader_data.column_ids.empty(); - - // Finally, initialize the filters - bind_data.multi_file_reader->CreateFilterMap(bind_data.types, table_filters, reader_data, reader_state); - reader_data.filters = table_filters; -} - -static bool GetBooleanArgument(const pair> &option) { - if (option.second.empty()) { - return true; - } - Value boolean_value; - string error_message; - if (!option.second[0].DefaultTryCastAs(LogicalType::BOOLEAN, boolean_value, &error_message)) { - throw InvalidInputException("Unable to cast \"%s\" to BOOLEAN for Parquet option \"%s\"", - option.second[0].ToString(), option.first); - } - return BooleanValue::Get(boolean_value); -} - -TablePartitionInfo ParquetGetPartitionInfo(ClientContext &context, TableFunctionPartitionInput &input) { - auto &parquet_bind = input.bind_data->Cast(); - return parquet_bind.multi_file_reader->GetPartitionInfo(context, parquet_bind.reader_bind, input); -} - -class ParquetScanFunction { -public: - static TableFunctionSet GetFunctionSet() { - TableFunction table_function("parquet_scan", {LogicalType::VARCHAR}, ParquetScanImplementation, ParquetScanBind, - ParquetScanInitGlobal, ParquetScanInitLocal); - table_function.statistics = ParquetScanStats; - table_function.cardinality = ParquetCardinality; - table_function.table_scan_progress = ParquetProgress; - table_function.named_parameters["binary_as_string"] = LogicalType::BOOLEAN; - table_function.named_parameters["file_row_number"] = LogicalType::BOOLEAN; - table_function.named_parameters["debug_use_openssl"] = LogicalType::BOOLEAN; - table_function.named_parameters["compression"] = LogicalType::VARCHAR; - table_function.named_parameters["explicit_cardinality"] = LogicalType::UBIGINT; - table_function.named_parameters["schema"] = - LogicalType::MAP(LogicalType::INTEGER, LogicalType::STRUCT({{{"name", LogicalType::VARCHAR}, - {"type", LogicalType::VARCHAR}, - {"default_value", LogicalType::VARCHAR}}})); - table_function.named_parameters["encryption_config"] = LogicalTypeId::ANY; - table_function.get_partition_data = ParquetScanGetPartitionData; - table_function.serialize = ParquetScanSerialize; - table_function.deserialize = ParquetScanDeserialize; - table_function.get_bind_info = ParquetGetBindInfo; - table_function.projection_pushdown = true; - table_function.filter_pushdown = true; - table_function.filter_prune = true; - table_function.pushdown_complex_filter = ParquetComplexFilterPushdown; - table_function.get_partition_info = ParquetGetPartitionInfo; - - MultiFileReader::AddParameters(table_function); - - return MultiFileReader::CreateFunctionSet(table_function); - } - - static unique_ptr ParquetReadBind(ClientContext &context, CopyInfo &info, - vector &expected_names, - vector &expected_types) { - D_ASSERT(expected_names.size() == expected_types.size()); - ParquetOptions parquet_options(context); - - for (auto &option : 
info.options) { - auto loption = StringUtil::Lower(option.first); - if (loption == "compression" || loption == "codec" || loption == "row_group_size") { - // CODEC/COMPRESSION and ROW_GROUP_SIZE options have no effect on parquet read. - // These options are determined from the file. - continue; - } else if (loption == "binary_as_string") { - parquet_options.binary_as_string = GetBooleanArgument(option); - } else if (loption == "file_row_number") { - parquet_options.file_row_number = GetBooleanArgument(option); - } else if (loption == "debug_use_openssl") { - parquet_options.debug_use_openssl = GetBooleanArgument(option); - } else if (loption == "encryption_config") { - if (option.second.size() != 1) { - throw BinderException("Parquet encryption_config cannot be empty!"); - } - parquet_options.encryption_config = ParquetEncryptionConfig::Create(context, option.second[0]); - } else { - throw NotImplementedException("Unsupported option for COPY FROM parquet: %s", option.first); - } - } - - // TODO: Allow overriding the MultiFileReader for COPY FROM? - auto multi_file_reader = MultiFileReader::CreateDefault("ParquetCopy"); - vector paths = {info.file_path}; - auto file_list = multi_file_reader->CreateFileList(context, paths); - - return ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), expected_types, - expected_names, parquet_options); - } - - static unique_ptr ParquetScanStats(ClientContext &context, const FunctionData *bind_data_p, - column_t column_index) { - auto &bind_data = bind_data_p->Cast(); - - if (IsRowIdColumnId(column_index)) { - return nullptr; - } - - // NOTE: we do not want to parse the Parquet metadata for the sole purpose of getting column statistics - if (bind_data.file_list->GetExpandResult() == FileExpandResult::MULTIPLE_FILES) { - // multiple files, no luck! 
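
The guard above is a deliberate trade-off: footer statistics are served only when the glob expanded to a single file whose metadata is already in memory; with multiple files, the optimizer simply gets no column statistics rather than paying to parse every footer. A condensed restatement of the guard (hypothetical types, not the DuckDB API):

// Condensed sketch of the ParquetScanStats guards: statistics come from the
// already-parsed footer, and only for a single-file scan.
enum class ExpandResult { SINGLE_FILE, MULTIPLE_FILES };

struct Reader { /* holds parsed footer metadata */ };

const Reader *StatsSource(ExpandResult expand, const Reader *initial_reader) {
	if (expand == ExpandResult::MULTIPLE_FILES) {
		return nullptr; // would require parsing every file's metadata
	}
	return initial_reader; // may still be null if no reader was created at bind time
}
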
- return nullptr; - } - if (!bind_data.initial_reader) { - // no reader - return nullptr; - } - // scanning single parquet file and we have the metadata read already - return bind_data.initial_reader->ReadStatistics(bind_data.names[column_index]); - - return nullptr; - } - - static unique_ptr ParquetScanBindInternal(ClientContext &context, - unique_ptr multi_file_reader, - shared_ptr file_list, - vector &return_types, vector &names, - ParquetOptions parquet_options) { - auto result = make_uniq(); - result->multi_file_reader = std::move(multi_file_reader); - result->file_list = std::move(file_list); - - bool bound_on_first_file = true; - if (result->multi_file_reader->Bind(parquet_options.file_options, *result->file_list, result->types, - result->names, result->reader_bind)) { - result->multi_file_reader->BindOptions(parquet_options.file_options, *result->file_list, result->types, - result->names, result->reader_bind); - // Enable the parquet file_row_number on the parquet options if the file_row_number_idx was set - if (result->reader_bind.file_row_number_idx != DConstants::INVALID_INDEX) { - parquet_options.file_row_number = true; - } - bound_on_first_file = false; - } else if (!parquet_options.schema.empty()) { - // A schema was supplied: use the schema for binding - result->reader_bind = BindSchema(context, result->types, result->names, *result, parquet_options); - } else { - parquet_options.file_options.AutoDetectHivePartitioning(*result->file_list, context); - // Default bind - result->reader_bind = result->multi_file_reader->BindReader( - context, result->types, result->names, *result->file_list, *result, parquet_options); - } - if (parquet_options.explicit_cardinality) { - auto file_count = result->file_list->GetTotalFileCount(); - result->explicit_cardinality = parquet_options.explicit_cardinality; - result->initial_file_cardinality = result->explicit_cardinality / (file_count ? file_count : 1); - } - if (return_types.empty()) { - // no expected types - just copy the types - return_types = result->types; - names = result->names; - } else { - if (return_types.size() != result->types.size()) { - auto file_string = bound_on_first_file ? result->file_list->GetFirstFile() - : StringUtil::Join(result->file_list->GetPaths(), ","); - string extended_error; - extended_error = "Table schema: "; - for (idx_t col_idx = 0; col_idx < return_types.size(); col_idx++) { - if (col_idx > 0) { - extended_error += ", "; - } - extended_error += names[col_idx] + " " + return_types[col_idx].ToString(); - } - extended_error += "\nParquet schema: "; - for (idx_t col_idx = 0; col_idx < result->types.size(); col_idx++) { - if (col_idx > 0) { - extended_error += ", "; - } - extended_error += result->names[col_idx] + " " + result->types[col_idx].ToString(); - } - extended_error += "\n\nPossible solutions:"; - extended_error += "\n* Manually specify which columns to insert using \"INSERT INTO tbl SELECT ... 
" - "FROM read_parquet(...)\""; - throw ConversionException( - "Failed to read file(s) \"%s\" - column count mismatch: expected %d columns but found %d\n%s", - file_string, return_types.size(), result->types.size(), extended_error); - } - // expected types - overwrite the types we want to read instead - result->types = return_types; - result->table_columns = names; - } - result->parquet_options = parquet_options; - return std::move(result); - } - - static unique_ptr ParquetScanBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto multi_file_reader = MultiFileReader::Create(input.table_function); - - ParquetOptions parquet_options(context); - for (auto &kv : input.named_parameters) { - if (kv.second.IsNull()) { - throw BinderException("Cannot use NULL as function argument"); - } - auto loption = StringUtil::Lower(kv.first); - if (multi_file_reader->ParseOption(kv.first, kv.second, parquet_options.file_options, context)) { - continue; - } - if (loption == "binary_as_string") { - parquet_options.binary_as_string = BooleanValue::Get(kv.second); - } else if (loption == "file_row_number") { - parquet_options.file_row_number = BooleanValue::Get(kv.second); - } else if (loption == "debug_use_openssl") { - parquet_options.debug_use_openssl = BooleanValue::Get(kv.second); - } else if (loption == "schema") { - // Argument is a map that defines the schema - const auto &schema_value = kv.second; - const auto column_values = ListValue::GetChildren(schema_value); - if (column_values.empty()) { - throw BinderException("Parquet schema cannot be empty"); - } - parquet_options.schema.reserve(column_values.size()); - for (idx_t i = 0; i < column_values.size(); i++) { - parquet_options.schema.emplace_back( - ParquetColumnDefinition::FromSchemaValue(context, column_values[i])); - } - - // cannot be combined with hive_partitioning=true, so we disable auto-detection - parquet_options.file_options.auto_detect_hive_partitioning = false; - } else if (loption == "explicit_cardinality") { - parquet_options.explicit_cardinality = UBigIntValue::Get(kv.second); - } else if (loption == "encryption_config") { - parquet_options.encryption_config = ParquetEncryptionConfig::Create(context, kv.second); - } - } - - auto file_list = multi_file_reader->CreateFileList(context, input.inputs[0]); - return ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), return_types, names, - parquet_options); - } - - static double ParquetProgress(ClientContext &context, const FunctionData *bind_data_p, - const GlobalTableFunctionState *global_state) { - auto &bind_data = bind_data_p->Cast(); - auto &gstate = global_state->Cast(); - - auto total_count = gstate.file_list.GetTotalFileCount(); - if (total_count == 0) { - return 100.0; - } - if (bind_data.initial_file_cardinality == 0) { - return (100.0 * (static_cast(gstate.file_index) + 1.0)) / static_cast(total_count); - } - auto percentage = MinValue(100.0, (static_cast(bind_data.chunk_count) * STANDARD_VECTOR_SIZE * - 100.0 / static_cast(bind_data.initial_file_cardinality))); - return (percentage + 100.0 * static_cast(gstate.file_index)) / static_cast(total_count); - } - - static unique_ptr - ParquetScanInitLocal(ExecutionContext &context, TableFunctionInitInput &input, GlobalTableFunctionState *gstate_p) { - auto &bind_data = input.bind_data->Cast(); - auto &gstate = gstate_p->Cast(); - - auto result = make_uniq(); - result->is_parallel = true; - result->batch_index = 0; - - if (gstate.CanRemoveColumns()) { - 
result->all_columns.Initialize(context.client, gstate.scanned_types); - } - if (!ParquetParallelStateNext(context.client, bind_data, *result, gstate)) { - return nullptr; - } - return std::move(result); - } - - static unique_ptr ParquetDynamicFilterPushdown(ClientContext &context, - const ParquetReadBindData &data, - const vector &column_ids, - optional_ptr filters) { - if (!filters) { - return nullptr; - } - auto new_list = data.multi_file_reader->DynamicFilterPushdown( - context, *data.file_list, data.parquet_options.file_options, data.names, data.types, column_ids, *filters); - return new_list; - } - - static unique_ptr ParquetScanInitGlobal(ClientContext &context, - TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->CastNoConst(); - unique_ptr result; - - // before instantiating a scan trigger a dynamic filter pushdown if possible - auto new_list = ParquetDynamicFilterPushdown(context, bind_data, input.column_ids, input.filters); - if (new_list) { - result = make_uniq(std::move(new_list)); - } else { - result = make_uniq(*bind_data.file_list); - } - auto &file_list = result->file_list; - file_list.InitializeScan(result->file_list_scan); - - result->multi_file_reader_state = bind_data.multi_file_reader->InitializeGlobalState( - context, bind_data.parquet_options.file_options, bind_data.reader_bind, file_list, bind_data.types, - bind_data.names, input.column_indexes); - if (file_list.IsEmpty()) { - result->readers = {}; - } else if (!bind_data.union_readers.empty()) { - // TODO: confirm we are not changing behaviour by modifying the order here? - for (auto &reader : bind_data.union_readers) { - if (!reader) { - break; - } - result->readers.push_back(make_uniq(std::move(reader))); - } - if (result->readers.size() != file_list.GetTotalFileCount()) { - // This case happens with recursive CTEs: the first execution the readers have already - // been moved out of the bind data. 
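
The union-reader reuse here is deliberately best-effort: on a repeated execution of the same plan (the recursive-CTE case the comment describes), the readers were already moved out of the bind data, so the completeness check fails and the scan falls back to reopening every file. A generic sketch of that take-the-cache-only-if-complete pattern (hypothetical READER type, not DuckDB code):

#include <cstddef>
#include <memory>
#include <vector>

// Generic sketch: cached readers moved out of a shared bind state are only
// trusted when one cache entry survived per file; otherwise start from scratch.
template <class READER>
std::vector<std::unique_ptr<READER>> TakeCachedReaders(std::vector<std::unique_ptr<READER>> &cache,
                                                       size_t file_count) {
	std::vector<std::unique_ptr<READER>> result;
	for (auto &reader : cache) {
		if (!reader) {
			break; // already consumed by a previous execution
		}
		result.push_back(std::move(reader));
	}
	if (result.size() != file_count) {
		result.clear(); // incomplete cache: reopen every file instead
	}
	return result;
}

Clearing the partial cache keeps the ordering invariant simple: either every reader lines up with the file list, or none do.
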
- // FIXME: clean up this process and make it more explicit - result->readers = {}; - } - } else if (bind_data.initial_reader) { - // we can only use the initial reader if it was constructed from the first file - if (bind_data.initial_reader->file_name == file_list.GetFirstFile()) { - result->readers.push_back(make_uniq(std::move(bind_data.initial_reader))); - } - } - - // Ensure all readers are initialized and FileListScan is sync with readers list - for (auto &reader_data : result->readers) { - string file_name; - idx_t file_idx = result->file_list_scan.current_file_idx; - file_list.Scan(result->file_list_scan, file_name); - if (reader_data->union_data) { - if (file_name != reader_data->union_data->GetFileName()) { - throw InternalException("Mismatch in filename order and union reader order in parquet scan"); - } - } else { - D_ASSERT(reader_data->reader); - if (file_name != reader_data->reader->file_name) { - throw InternalException("Mismatch in filename order and reader order in parquet scan"); - } - InitializeParquetReader(*reader_data->reader, bind_data, input.column_indexes, input.filters, context, - file_idx, result->multi_file_reader_state); - } - } - - result->column_indexes = input.column_indexes; - result->filters = input.filters.get(); - result->row_group_index = 0; - result->file_index = 0; - result->batch_index = 0; - result->max_threads = ParquetScanMaxThreads(context, input.bind_data.get()); - - bool require_extra_columns = - result->multi_file_reader_state && result->multi_file_reader_state->RequiresExtraColumns(); - if (input.CanRemoveFilterColumns() || require_extra_columns) { - if (!input.projection_ids.empty()) { - result->projection_ids = input.projection_ids; - } else { - result->projection_ids.resize(input.column_indexes.size()); - iota(begin(result->projection_ids), end(result->projection_ids), 0); - } - - const auto table_types = bind_data.types; - for (const auto &col_idx : input.column_indexes) { - if (col_idx.IsRowIdColumn()) { - result->scanned_types.emplace_back(LogicalType::ROW_TYPE); - } else { - result->scanned_types.push_back(table_types[col_idx.GetPrimaryIndex()]); - } - } - } - - if (require_extra_columns) { - for (const auto &column_type : result->multi_file_reader_state->extra_columns) { - result->scanned_types.push_back(column_type); - } - } - - return std::move(result); - } - - static OperatorPartitionData ParquetScanGetPartitionData(ClientContext &context, - TableFunctionGetPartitionInput &input) { - auto &bind_data = input.bind_data->CastNoConst(); - auto &data = input.local_state->Cast(); - auto &gstate = input.global_state->Cast(); - OperatorPartitionData partition_data(data.batch_index); - bind_data.multi_file_reader->GetPartitionData(context, bind_data.reader_bind, data.reader->reader_data, - gstate.multi_file_reader_state, input.partition_info, - partition_data); - return partition_data; - } - - static void ParquetScanSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const TableFunction &function) { - auto &bind_data = bind_data_p->Cast(); - - serializer.WriteProperty(100, "files", bind_data.file_list->GetAllFiles()); - serializer.WriteProperty(101, "types", bind_data.types); - serializer.WriteProperty(102, "names", bind_data.names); - serializer.WriteProperty(103, "parquet_options", bind_data.parquet_options); - if (serializer.ShouldSerialize(3)) { - serializer.WriteProperty(104, "table_columns", bind_data.table_columns); - } - } - - static unique_ptr ParquetScanDeserialize(Deserializer &deserializer, TableFunction 
&function) { - auto &context = deserializer.Get(); - auto files = deserializer.ReadProperty>(100, "files"); - auto types = deserializer.ReadProperty>(101, "types"); - auto names = deserializer.ReadProperty>(102, "names"); - auto parquet_options = deserializer.ReadProperty(103, "parquet_options"); - auto table_columns = - deserializer.ReadPropertyWithExplicitDefault>(104, "table_columns", vector {}); - - vector file_path; - for (auto &path : files) { - file_path.emplace_back(path); - } - - auto multi_file_reader = MultiFileReader::Create(function); - auto file_list = multi_file_reader->CreateFileList(context, Value::LIST(LogicalType::VARCHAR, file_path), - FileGlobOptions::DISALLOW_EMPTY); - auto bind_data = ParquetScanBindInternal(context, std::move(multi_file_reader), std::move(file_list), types, - names, parquet_options); - bind_data->Cast().table_columns = std::move(table_columns); - return bind_data; - } - - static void ParquetScanImplementation(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - if (!data_p.local_state) { - return; - } - auto &data = data_p.local_state->Cast(); - auto &gstate = data_p.global_state->Cast(); - auto &bind_data = data_p.bind_data->CastNoConst(); - - bool rowgroup_finished; - do { - if (gstate.CanRemoveColumns()) { - data.all_columns.Reset(); - data.reader->Scan(data.scan_state, data.all_columns); - rowgroup_finished = data.all_columns.size() == 0; - bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data, - data.all_columns, gstate.multi_file_reader_state); - output.ReferenceColumns(data.all_columns, gstate.projection_ids); - } else { - data.reader->Scan(data.scan_state, output); - rowgroup_finished = output.size() == 0; - bind_data.multi_file_reader->FinalizeChunk(context, bind_data.reader_bind, data.reader->reader_data, - output, gstate.multi_file_reader_state); - } - - bind_data.chunk_count++; - if (output.size() > 0) { - return; - } - if (rowgroup_finished && !ParquetParallelStateNext(context, bind_data, data, gstate)) { - return; - } - } while (true); - } - - static unique_ptr ParquetCardinality(ClientContext &context, const FunctionData *bind_data) { - auto &data = bind_data->Cast(); - if (data.explicit_cardinality) { - return make_uniq(data.explicit_cardinality); - } - auto file_list_cardinality_estimate = data.file_list->GetCardinality(context); - if (file_list_cardinality_estimate) { - return file_list_cardinality_estimate; - } - return make_uniq(MaxValue(data.initial_file_cardinality, (idx_t)1) * - data.file_list->GetTotalFileCount()); - } - - static idx_t ParquetScanMaxThreads(ClientContext &context, const FunctionData *bind_data) { - auto &data = bind_data->Cast(); - - if (data.file_list->GetExpandResult() == FileExpandResult::MULTIPLE_FILES) { - return TaskScheduler::GetScheduler(context).NumberOfThreads(); - } - - return MaxValue(data.initial_file_row_groups, (idx_t)1); - } - - // Queries the metadataprovider for another file to scan, updating the files/reader lists in the process. - // Returns true if resized - static bool ResizeFiles(ParquetReadGlobalState ¶llel_state) { - string scanned_file; - if (!parallel_state.file_list.Scan(parallel_state.file_list_scan, scanned_file)) { - return false; - } - - // Push the file in the reader data, to be opened later - parallel_state.readers.push_back(make_uniq(scanned_file)); - - return true; - } - - // This function looks for the next available row group. 
If not available, it will open files from bind_data.files - // until there is a row group available for scanning or the files runs out - static bool ParquetParallelStateNext(ClientContext &context, const ParquetReadBindData &bind_data, - ParquetReadLocalState &scan_data, ParquetReadGlobalState ¶llel_state) { - unique_lock parallel_lock(parallel_state.lock); - - while (true) { - if (parallel_state.error_opening_file) { - return false; - } - - if (parallel_state.file_index >= parallel_state.readers.size() && !ResizeFiles(parallel_state)) { - return false; - } - - auto ¤t_reader_data = *parallel_state.readers[parallel_state.file_index]; - if (current_reader_data.file_state == ParquetFileState::OPEN) { - if (parallel_state.row_group_index < current_reader_data.reader->NumRowGroups()) { - // The current reader has rowgroups left to be scanned - scan_data.reader = current_reader_data.reader; - vector group_indexes {parallel_state.row_group_index}; - scan_data.reader->InitializeScan(context, scan_data.scan_state, group_indexes); - scan_data.batch_index = parallel_state.batch_index++; - scan_data.file_index = parallel_state.file_index; - parallel_state.row_group_index++; - return true; - } else { - // Close current file - current_reader_data.file_state = ParquetFileState::CLOSED; - current_reader_data.reader = nullptr; - - // Set state to the next file - parallel_state.file_index++; - parallel_state.row_group_index = 0; - - continue; - } - } - - if (TryOpenNextFile(context, bind_data, scan_data, parallel_state, parallel_lock)) { - continue; - } - - // Check if the current file is being opened, in that case we need to wait for it. - if (current_reader_data.file_state == ParquetFileState::OPENING) { - WaitForFile(parallel_state.file_index, parallel_state, parallel_lock); - } - } - } - - static void ParquetComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, - vector> &filters) { - auto &data = bind_data_p->Cast(); - - MultiFilePushdownInfo info(get); - auto new_list = data.multi_file_reader->ComplexFilterPushdown(context, *data.file_list, - data.parquet_options.file_options, info, filters); - - if (new_list) { - data.file_list = std::move(new_list); - MultiFileReader::PruneReaders(data, *data.file_list); - } - } - - //! Wait for a file to become available. Parallel lock should be locked when calling. - static void WaitForFile(idx_t file_index, ParquetReadGlobalState ¶llel_state, - unique_lock ¶llel_lock) { - while (true) { - // Get pointer to file mutex before unlocking - auto &file_mutex = *parallel_state.readers[file_index]->file_mutex; - - // To get the file lock, we first need to release the parallel_lock to prevent deadlocking. Note that this - // requires getting the ref to the file mutex pointer with the lock stil held: readers get be resized - parallel_lock.unlock(); - unique_lock current_file_lock(file_mutex); - parallel_lock.lock(); - - // Here we have both locks which means we can stop waiting if: - // - the thread opening the file is done and the file is available - // - the thread opening the file has failed - // - the file was somehow scanned till the end while we were waiting - if (parallel_state.file_index >= parallel_state.readers.size() || - parallel_state.readers[parallel_state.file_index]->file_state != ParquetFileState::OPENING || - parallel_state.error_opening_file) { - return; - } - } - } - - //! Helper function that try to start opening a next file. Parallel lock should be locked when calling. 
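
ParquetParallelStateNext above is, at its core, a mutex-protected cursor over (file, row group) pairs. A simplified model, assuming every file is already open and omitting the UNOPENED/OPENING/CLOSED states and the WaitForFile handshake:

#include <cstddef>
#include <mutex>
#include <vector>

// Simplified model of the scheduling loop: under one global lock, workers
// claim (file_index, row_group_index) pairs; the scan advances to the next
// file only once all row groups of the current file have been handed out.
struct ScanState {
	std::mutex lock;
	size_t file_index = 0;
	size_t row_group_index = 0;
	std::vector<size_t> row_groups_per_file; // assumed known and open, for brevity
};

bool NextRowGroup(ScanState &state, size_t &file_out, size_t &group_out) {
	std::lock_guard<std::mutex> guard(state.lock);
	while (state.file_index < state.row_groups_per_file.size()) {
		if (state.row_group_index < state.row_groups_per_file[state.file_index]) {
			file_out = state.file_index;
			group_out = state.row_group_index++;
			return true;
		}
		state.file_index++; // current file exhausted: move on
		state.row_group_index = 0;
	}
	return false; // every row group of every file has been claimed
}

Each worker calls NextRowGroup until it returns false; because the increment happens under the lock, no row group is handed out twice.
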
- static bool TryOpenNextFile(ClientContext &context, const ParquetReadBindData &bind_data, - ParquetReadLocalState &scan_data, ParquetReadGlobalState ¶llel_state, - unique_lock ¶llel_lock) { - const auto file_index_limit = - parallel_state.file_index + TaskScheduler::GetScheduler(context).NumberOfThreads(); - - for (idx_t i = parallel_state.file_index; i < file_index_limit; i++) { - // We check if we can resize files in this loop too otherwise we will only ever open 1 file ahead - if (i >= parallel_state.readers.size() && !ResizeFiles(parallel_state)) { - return false; - } - - auto ¤t_reader_data = *parallel_state.readers[i]; - if (current_reader_data.file_state == ParquetFileState::UNOPENED) { - current_reader_data.file_state = ParquetFileState::OPENING; - auto pq_options = bind_data.parquet_options; - - // Get pointer to file mutex before unlocking - auto ¤t_file_lock = *current_reader_data.file_mutex; - - // Now we switch which lock we are holding, instead of locking the global state, we grab the lock on - // the file we are opening. This file lock allows threads to wait for a file to be opened. - parallel_lock.unlock(); - unique_lock file_lock(current_file_lock); - - shared_ptr reader; - try { - if (current_reader_data.union_data) { - auto &union_data = *current_reader_data.union_data; - reader = make_shared_ptr(context, union_data.file_name, union_data.options, - union_data.metadata); - } else { - reader = - make_shared_ptr(context, current_reader_data.file_to_be_opened, pq_options); - } - InitializeParquetReader(*reader, bind_data, parallel_state.column_indexes, parallel_state.filters, - context, i, parallel_state.multi_file_reader_state); - } catch (...) { - parallel_lock.lock(); - parallel_state.error_opening_file = true; - throw; - } - - // Now re-lock the state and add the reader - parallel_lock.lock(); - current_reader_data.reader = std::move(reader); - current_reader_data.file_state = ParquetFileState::OPEN; - - return true; - } - } - - return false; - } -}; - -static case_insensitive_map_t GetChildNameToTypeMap(const LogicalType &type) { - case_insensitive_map_t name_to_type_map; - switch (type.id()) { - case LogicalTypeId::LIST: - name_to_type_map.emplace("element", ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - name_to_type_map.emplace("key", MapType::KeyType(type)); - name_to_type_map.emplace("value", MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - if (child_type.first == FieldID::DUCKDB_FIELD_ID) { - throw BinderException("Cannot have column named \"%s\" with FIELD_IDS", FieldID::DUCKDB_FIELD_ID); - } - name_to_type_map.emplace(child_type); - } - break; - default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNameToTypeMap"); - } // LCOV_EXCL_STOP - return name_to_type_map; -} - -static void GetChildNamesAndTypes(const LogicalType &type, vector &child_names, - vector &child_types) { - switch (type.id()) { - case LogicalTypeId::LIST: - child_names.emplace_back("element"); - child_types.emplace_back(ListType::GetChildType(type)); - break; - case LogicalTypeId::MAP: - child_names.emplace_back("key"); - child_names.emplace_back("value"); - child_types.emplace_back(MapType::KeyType(type)); - child_types.emplace_back(MapType::ValueType(type)); - break; - case LogicalTypeId::STRUCT: - for (auto &child_type : StructType::GetChildTypes(type)) { - child_names.emplace_back(child_type.first); - child_types.emplace_back(child_type.second); - } - break; - 
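
GetChildNamesAndTypes above exists so that FIELD_IDS 'auto' (see GenerateFieldIDs just below) can number columns depth-first while preserving declaration order, which a name-keyed map would lose. A generic sketch of that numbering scheme (hypothetical Node type, not DuckDB code):

#include <cstdint>
#include <map>
#include <string>
#include <vector>

// Generic sketch of depth-first field-id assignment: each column takes the
// next id, then its children are numbered recursively, in declaration order.
struct Node {
	std::string name;
	std::vector<Node> children;
};

void AssignFieldIds(std::vector<Node> &nodes, uint64_t &next_id,
                    std::map<std::string, uint64_t> &out, const std::string &prefix = "") {
	for (auto &node : nodes) {
		auto path = prefix.empty() ? node.name : prefix + "." + node.name;
		out[path] = next_id++; // parent id precedes its children's ids
		AssignFieldIds(node.children, next_id, out, path);
	}
}
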
default: // LCOV_EXCL_START - throw InternalException("Unexpected type in GetChildNamesAndTypes"); - } // LCOV_EXCL_STOP -} - -static void GenerateFieldIDs(ChildFieldIDs &field_ids, idx_t &field_id, const vector &names, - const vector &sql_types) { - D_ASSERT(names.size() == sql_types.size()); - for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) { - const auto &col_name = names[col_idx]; - auto inserted = field_ids.ids->insert(make_pair(col_name, FieldID(UnsafeNumericCast(field_id++)))); - D_ASSERT(inserted.second); - - const auto &col_type = sql_types[col_idx]; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - continue; - } - - // Cannot use GetChildNameToTypeMap here because we lose order, and we want to generate depth-first - vector child_names; - vector child_types; - GetChildNamesAndTypes(col_type, child_names, child_types); - - GenerateFieldIDs(inserted.first->second.child_field_ids, field_id, child_names, child_types); - } -} - -static void GetFieldIDs(const Value &field_ids_value, ChildFieldIDs &field_ids, - unordered_set &unique_field_ids, - const case_insensitive_map_t &name_to_type_map) { - const auto &struct_type = field_ids_value.type(); - if (struct_type.id() != LogicalTypeId::STRUCT) { - throw BinderException( - "Expected FIELD_IDS to be a STRUCT, e.g., {col1: 42, col2: {%s: 43, nested_col: 44}, col3: 44}", - FieldID::DUCKDB_FIELD_ID); - } - const auto &struct_children = StructValue::GetChildren(field_ids_value); - D_ASSERT(StructType::GetChildTypes(struct_type).size() == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - const auto &col_name = StringUtil::Lower(StructType::GetChildName(struct_type, i)); - if (col_name == FieldID::DUCKDB_FIELD_ID) { - continue; - } - - auto it = name_to_type_map.find(col_name); - if (it == name_to_type_map.end()) { - string names; - for (const auto &name : name_to_type_map) { - if (!names.empty()) { - names += ", "; - } - names += name.first; - } - throw BinderException( - "Column name \"%s\" specified in FIELD_IDS not found. Consider using WRITE_PARTITION_COLUMNS if this " - "column is a partition column. 
Available column names: [%s]", - col_name, names); - } - D_ASSERT(field_ids.ids->find(col_name) == field_ids.ids->end()); // Caught by STRUCT - deduplicates keys - - const auto &child_value = struct_children[i]; - const auto &child_type = child_value.type(); - optional_ptr field_id_value; - optional_ptr child_field_ids_value; - - if (child_type.id() == LogicalTypeId::STRUCT) { - const auto &nested_children = StructValue::GetChildren(child_value); - D_ASSERT(StructType::GetChildTypes(child_type).size() == nested_children.size()); - for (idx_t nested_i = 0; nested_i < nested_children.size(); nested_i++) { - const auto &field_id_or_nested_col = StructType::GetChildName(child_type, nested_i); - if (field_id_or_nested_col == FieldID::DUCKDB_FIELD_ID) { - field_id_value = &nested_children[nested_i]; - } else { - child_field_ids_value = &child_value; - } - } - } else { - field_id_value = &child_value; - } - - FieldID field_id; - if (field_id_value) { - Value field_id_integer_value = field_id_value->DefaultCastAs(LogicalType::INTEGER); - const uint32_t field_id_int = IntegerValue::Get(field_id_integer_value); - if (!unique_field_ids.insert(field_id_int).second) { - throw BinderException("Duplicate field_id %s found in FIELD_IDS", field_id_integer_value.ToString()); - } - field_id = FieldID(UnsafeNumericCast(field_id_int)); - } - auto inserted = field_ids.ids->insert(make_pair(col_name, std::move(field_id))); - D_ASSERT(inserted.second); - - if (child_field_ids_value) { - const auto &col_type = it->second; - if (col_type.id() != LogicalTypeId::LIST && col_type.id() != LogicalTypeId::MAP && - col_type.id() != LogicalTypeId::STRUCT) { - throw BinderException("Column \"%s\" with type \"%s\" cannot have a nested FIELD_IDS specification", - col_name, LogicalTypeIdToString(col_type.id())); - } - - GetFieldIDs(*child_field_ids_value, inserted.first->second.child_field_ids, unique_field_ids, - GetChildNameToTypeMap(col_type)); - } - } -} - -unique_ptr ParquetWriteBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { - D_ASSERT(names.size() == sql_types.size()); - bool row_group_size_bytes_set = false; - bool compression_level_set = false; - auto bind_data = make_uniq(); - for (auto &option : input.info.options) { - const auto loption = StringUtil::Lower(option.first); - if (option.second.size() != 1) { - // All parquet write options require exactly one argument - throw BinderException("%s requires exactly one argument", StringUtil::Upper(loption)); - } - if (loption == "row_group_size" || loption == "chunk_size") { - bind_data->row_group_size = option.second[0].GetValue(); - } else if (loption == "row_group_size_bytes") { - auto roption = option.second[0]; - if (roption.GetTypeMutable().id() == LogicalTypeId::VARCHAR) { - bind_data->row_group_size_bytes = DBConfig::ParseMemoryLimit(roption.ToString()); - } else { - bind_data->row_group_size_bytes = option.second[0].GetValue(); - } - row_group_size_bytes_set = true; - } else if (loption == "row_groups_per_file") { - bind_data->row_groups_per_file = option.second[0].GetValue(); - } else if (loption == "compression" || loption == "codec") { - const auto roption = StringUtil::Lower(option.second[0].ToString()); - if (roption == "uncompressed") { - bind_data->codec = duckdb_parquet::CompressionCodec::UNCOMPRESSED; - } else if (roption == "snappy") { - bind_data->codec = duckdb_parquet::CompressionCodec::SNAPPY; - } else if (roption == "gzip") { - bind_data->codec = duckdb_parquet::CompressionCodec::GZIP; - 
} else if (roption == "zstd") {
-				bind_data->codec = duckdb_parquet::CompressionCodec::ZSTD;
-			} else if (roption == "brotli") {
-				bind_data->codec = duckdb_parquet::CompressionCodec::BROTLI;
-			} else if (roption == "lz4" || roption == "lz4_raw") {
-				/* LZ4 is technically another compression scheme, but deprecated and arrow also uses them
-				 * interchangeably */
-				bind_data->codec = duckdb_parquet::CompressionCodec::LZ4_RAW;
-			} else {
-				throw BinderException(
-				    "Expected %s argument to be either [uncompressed, brotli, gzip, snappy, lz4, lz4_raw or zstd]",
-				    loption);
-			}
-		} else if (loption == "field_ids") {
-			if (option.second[0].type().id() == LogicalTypeId::VARCHAR &&
-			    StringUtil::Lower(StringValue::Get(option.second[0])) == "auto") {
-				idx_t field_id = 0;
-				GenerateFieldIDs(bind_data->field_ids, field_id, names, sql_types);
-			} else {
-				unordered_set<uint32_t> unique_field_ids;
-				case_insensitive_map_t<LogicalType> name_to_type_map;
-				for (idx_t col_idx = 0; col_idx < names.size(); col_idx++) {
-					if (names[col_idx] == FieldID::DUCKDB_FIELD_ID) {
-						throw BinderException("Cannot have a column named \"%s\" when writing FIELD_IDS",
-						                      FieldID::DUCKDB_FIELD_ID);
-					}
-					name_to_type_map.emplace(names[col_idx], sql_types[col_idx]);
-				}
-				GetFieldIDs(option.second[0], bind_data->field_ids, unique_field_ids, name_to_type_map);
-			}
-		} else if (loption == "kv_metadata") {
-			auto &kv_struct = option.second[0];
-			auto &kv_struct_type = kv_struct.type();
-			if (kv_struct_type.id() != LogicalTypeId::STRUCT) {
-				throw BinderException("Expected kv_metadata argument to be a STRUCT");
-			}
-			auto values = StructValue::GetChildren(kv_struct);
-			for (idx_t i = 0; i < values.size(); i++) {
-				auto value = values[i];
-				auto key = StructType::GetChildName(kv_struct_type, i);
-				// If the value is a blob, write the raw blob bytes
-				// otherwise, cast to string
-				if (value.type().id() == LogicalTypeId::BLOB) {
-					bind_data->kv_metadata.emplace_back(key, StringValue::Get(value));
-				} else {
-					bind_data->kv_metadata.emplace_back(key, value.ToString());
-				}
-			}
-		} else if (loption == "encryption_config") {
-			bind_data->encryption_config = ParquetEncryptionConfig::Create(context, option.second[0]);
-		} else if (loption == "dictionary_compression_ratio_threshold") {
-			// deprecated, ignore setting
-		} else if (loption == "dictionary_size_limit") {
-			auto val = option.second[0].GetValue<int64_t>();
-			if (val < 0) {
-				throw BinderException("dictionary_size_limit cannot be negative (use 0 to disable)");
-			}
-			bind_data->dictionary_size_limit = val;
-		} else if (loption == "bloom_filter_false_positive_ratio") {
-			auto val = option.second[0].GetValue<double>();
-			if (val <= 0) {
-				throw BinderException("bloom_filter_false_positive_ratio must be greater than 0");
-			}
-			bind_data->bloom_filter_false_positive_ratio = val;
-		} else if (loption == "debug_use_openssl") {
-			auto val = StringUtil::Lower(option.second[0].GetValue<string>());
-			if (val == "false") {
-				bind_data->debug_use_openssl = false;
-			} else if (val == "true") {
-				bind_data->debug_use_openssl = true;
-			} else {
-				throw BinderException("Expected debug_use_openssl to be a BOOLEAN");
-			}
-		} else if (loption == "compression_level") {
-			const auto val = option.second[0].GetValue<int64_t>();
-			if (val < ZStdFileSystem::MinimumCompressionLevel() || val > ZStdFileSystem::MaximumCompressionLevel()) {
-				throw BinderException("Compression level must be between %lld and %lld",
-				                      ZStdFileSystem::MinimumCompressionLevel(),
-				                      ZStdFileSystem::MaximumCompressionLevel());
-			}
-			bind_data->compression_level = val;
-			compression_level_set = true;
-		}
else { - throw NotImplementedException("Unrecognized option for PARQUET: %s", option.first.c_str()); - } - } - if (row_group_size_bytes_set) { - if (DBConfig::GetConfig(context).options.preserve_insertion_order) { - throw BinderException("ROW_GROUP_SIZE_BYTES does not work while preserving insertion order. Use \"SET " - "preserve_insertion_order=false;\" to disable preserving insertion order."); - } - } - - if (compression_level_set && bind_data->codec != CompressionCodec::ZSTD) { - throw BinderException("Compression level is only supported for the ZSTD compression codec"); - } - - bind_data->sql_types = sql_types; - bind_data->column_names = names; - return std::move(bind_data); -} - -unique_ptr ParquetWriteInitializeGlobal(ClientContext &context, FunctionData &bind_data, - const string &file_path) { - auto global_state = make_uniq(); - auto &parquet_bind = bind_data.Cast(); - - auto &fs = FileSystem::GetFileSystem(context); - global_state->writer = make_uniq( - context, fs, file_path, parquet_bind.sql_types, parquet_bind.column_names, parquet_bind.codec, - parquet_bind.field_ids.Copy(), parquet_bind.kv_metadata, parquet_bind.encryption_config, - parquet_bind.dictionary_size_limit, parquet_bind.bloom_filter_false_positive_ratio, - parquet_bind.compression_level, parquet_bind.debug_use_openssl); - return std::move(global_state); -} - -void ParquetWriteSink(ExecutionContext &context, FunctionData &bind_data_p, GlobalFunctionData &gstate, - LocalFunctionData &lstate, DataChunk &input) { - auto &bind_data = bind_data_p.Cast(); - auto &global_state = gstate.Cast(); - auto &local_state = lstate.Cast(); - - // append data to the local (buffered) chunk collection - local_state.buffer.Append(local_state.append_state, input); - - if (local_state.buffer.Count() >= bind_data.row_group_size || - local_state.buffer.SizeInBytes() >= bind_data.row_group_size_bytes) { - // if the chunk collection exceeds a certain size (rows/bytes) we flush it to the parquet file - local_state.append_state.current_chunk_state.handles.clear(); - global_state.writer->Flush(local_state.buffer); - local_state.buffer.InitializeAppend(local_state.append_state); - } -} - -void ParquetWriteCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - LocalFunctionData &lstate) { - auto &global_state = gstate.Cast(); - auto &local_state = lstate.Cast(); - // flush any data left in the local state to the file - global_state.writer->Flush(local_state.buffer); -} - -void ParquetWriteFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) { - auto &global_state = gstate.Cast(); - // finalize: write any additional metadata to the file here - global_state.writer->Finalize(); -} - -unique_ptr ParquetWriteInitializeLocal(ExecutionContext &context, FunctionData &bind_data_p) { - auto &bind_data = bind_data_p.Cast(); - return make_uniq(context.client, bind_data.sql_types); -} - -// LCOV_EXCL_START - -// FIXME: Have these be generated instead -template <> -const char *EnumUtil::ToChars(duckdb_parquet::CompressionCodec::type value) { - switch (value) { - case CompressionCodec::UNCOMPRESSED: - return "UNCOMPRESSED"; - break; - case CompressionCodec::SNAPPY: - return "SNAPPY"; - break; - case CompressionCodec::GZIP: - return "GZIP"; - break; - case CompressionCodec::LZO: - return "LZO"; - break; - case CompressionCodec::BROTLI: - return "BROTLI"; - break; - case CompressionCodec::LZ4: - return "LZ4"; - break; - case CompressionCodec::LZ4_RAW: - return "LZ4_RAW"; - break; - case 
CompressionCodec::ZSTD:
-		return "ZSTD";
-		break;
-	default:
-		throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
-	}
-}
-
-template <>
-duckdb_parquet::CompressionCodec::type
-EnumUtil::FromString<duckdb_parquet::CompressionCodec::type>(const char *value) {
-	if (StringUtil::Equals(value, "UNCOMPRESSED")) {
-		return CompressionCodec::UNCOMPRESSED;
-	}
-	if (StringUtil::Equals(value, "SNAPPY")) {
-		return CompressionCodec::SNAPPY;
-	}
-	if (StringUtil::Equals(value, "GZIP")) {
-		return CompressionCodec::GZIP;
-	}
-	if (StringUtil::Equals(value, "LZO")) {
-		return CompressionCodec::LZO;
-	}
-	if (StringUtil::Equals(value, "BROTLI")) {
-		return CompressionCodec::BROTLI;
-	}
-	if (StringUtil::Equals(value, "LZ4")) {
-		return CompressionCodec::LZ4;
-	}
-	if (StringUtil::Equals(value, "LZ4_RAW")) {
-		return CompressionCodec::LZ4_RAW;
-	}
-	if (StringUtil::Equals(value, "ZSTD")) {
-		return CompressionCodec::ZSTD;
-	}
-	throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
-}
-
-static optional_idx SerializeCompressionLevel(const int64_t compression_level) {
-	// negative levels are folded into the top of the unsigned range so they round-trip through optional_idx
-	return compression_level < 0 ? NumericLimits<idx_t>::Maximum() - NumericCast<idx_t>(AbsValue(compression_level))
-	                             : NumericCast<idx_t>(compression_level);
-}
-
-static int64_t DeserializeCompressionLevel(const optional_idx compression_level) {
-	// Was originally an optional_idx, now int64_t, so we still serialize as such
-	if (!compression_level.IsValid()) {
-		return ZStdFileSystem::DefaultCompressionLevel();
-	}
-	if (compression_level.GetIndex() > NumericCast<idx_t>(ZStdFileSystem::MaximumCompressionLevel())) {
-		// restore the negative compression level
-		return -NumericCast<int64_t>(NumericLimits<idx_t>::Maximum() - compression_level.GetIndex());
-	}
-	return NumericCast<int64_t>(compression_level.GetIndex());
-}
-
-static void ParquetCopySerialize(Serializer &serializer, const FunctionData &bind_data_p,
-                                 const CopyFunction &function) {
-	auto &bind_data = bind_data_p.Cast<ParquetWriteBindData>();
-	serializer.WriteProperty(100, "sql_types", bind_data.sql_types);
-	serializer.WriteProperty(101, "column_names", bind_data.column_names);
-	serializer.WriteProperty(102, "codec", bind_data.codec);
-	serializer.WriteProperty(103, "row_group_size", bind_data.row_group_size);
-	serializer.WriteProperty(104, "row_group_size_bytes", bind_data.row_group_size_bytes);
-	serializer.WriteProperty(105, "kv_metadata", bind_data.kv_metadata);
-	serializer.WriteProperty(106, "field_ids", bind_data.field_ids);
-	serializer.WritePropertyWithDefault<shared_ptr<ParquetEncryptionConfig>>(107, "encryption_config",
-	                                                                         bind_data.encryption_config, nullptr);
-
-	// 108 was dictionary_compression_ratio_threshold, but was deleted
-	const auto compression_level = SerializeCompressionLevel(bind_data.compression_level);
-	D_ASSERT(DeserializeCompressionLevel(compression_level) == bind_data.compression_level);
-	serializer.WritePropertyWithDefault(109, "compression_level", compression_level);
-	serializer.WriteProperty(110, "row_groups_per_file", bind_data.row_groups_per_file);
-	serializer.WriteProperty(111, "debug_use_openssl", bind_data.debug_use_openssl);
-	serializer.WriteProperty(112, "dictionary_size_limit", bind_data.dictionary_size_limit);
-	serializer.WriteProperty(113, "bloom_filter_false_positive_ratio", bind_data.bloom_filter_false_positive_ratio);
-}
-
-static unique_ptr<FunctionData> ParquetCopyDeserialize(Deserializer &deserializer, CopyFunction &function) {
-	auto data = make_uniq<ParquetWriteBindData>();
-	data->sql_types = deserializer.ReadProperty<vector<LogicalType>>(100, "sql_types");
-	data->column_names = deserializer.ReadProperty<vector<string>>(101, "column_names");
-	data->codec
= deserializer.ReadProperty(102, "codec"); - data->row_group_size = deserializer.ReadProperty(103, "row_group_size"); - data->row_group_size_bytes = deserializer.ReadProperty(104, "row_group_size_bytes"); - data->kv_metadata = deserializer.ReadProperty>>(105, "kv_metadata"); - data->field_ids = deserializer.ReadProperty(106, "field_ids"); - deserializer.ReadPropertyWithExplicitDefault>(107, "encryption_config", - data->encryption_config, nullptr); - deserializer.ReadDeletedProperty(108, "dictionary_compression_ratio_threshold"); - - optional_idx compression_level; - deserializer.ReadPropertyWithDefault(109, "compression_level", compression_level); - data->compression_level = DeserializeCompressionLevel(compression_level); - D_ASSERT(SerializeCompressionLevel(data->compression_level) == compression_level); - data->row_groups_per_file = - deserializer.ReadPropertyWithExplicitDefault(110, "row_groups_per_file", optional_idx::Invalid()); - data->debug_use_openssl = deserializer.ReadPropertyWithExplicitDefault(111, "debug_use_openssl", true); - data->dictionary_size_limit = - deserializer.ReadPropertyWithExplicitDefault(112, "dictionary_size_limit", data->row_group_size / 10); - data->bloom_filter_false_positive_ratio = - deserializer.ReadPropertyWithExplicitDefault(113, "bloom_filter_false_positive_ratio", 0.01); - - return std::move(data); -} -// LCOV_EXCL_STOP - -//===--------------------------------------------------------------------===// -// Execution Mode -//===--------------------------------------------------------------------===// -CopyFunctionExecutionMode ParquetWriteExecutionMode(bool preserve_insertion_order, bool supports_batch_index) { - if (!preserve_insertion_order) { - return CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; - } - if (supports_batch_index) { - return CopyFunctionExecutionMode::BATCH_COPY_TO_FILE; - } - return CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; -} -//===--------------------------------------------------------------------===// -// Prepare Batch -//===--------------------------------------------------------------------===// -struct ParquetWriteBatchData : public PreparedBatchData { - PreparedRowGroup prepared_row_group; -}; - -unique_ptr ParquetWritePrepareBatch(ClientContext &context, FunctionData &bind_data, - GlobalFunctionData &gstate, - unique_ptr collection) { - auto &global_state = gstate.Cast(); - auto result = make_uniq(); - global_state.writer->PrepareRowGroup(*collection, result->prepared_row_group); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Flush Batch -//===--------------------------------------------------------------------===// -void ParquetWriteFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - PreparedBatchData &batch_p) { - auto &global_state = gstate.Cast(); - auto &batch = batch_p.Cast(); - global_state.writer->FlushRowGroup(batch.prepared_row_group); -} - -//===--------------------------------------------------------------------===// -// Desired Batch Size -//===--------------------------------------------------------------------===// -idx_t ParquetWriteDesiredBatchSize(ClientContext &context, FunctionData &bind_data_p) { - auto &bind_data = bind_data_p.Cast(); - return bind_data.row_group_size; -} - -//===--------------------------------------------------------------------===// -// File rotation -//===--------------------------------------------------------------------===// -bool 
ParquetWriteRotateFiles(FunctionData &bind_data_p, const optional_idx &file_size_bytes) { - auto &bind_data = bind_data_p.Cast(); - return file_size_bytes.IsValid() || bind_data.row_groups_per_file.IsValid(); -} - -bool ParquetWriteRotateNextFile(GlobalFunctionData &gstate, FunctionData &bind_data_p, - const optional_idx &file_size_bytes) { - auto &global_state = gstate.Cast(); - auto &bind_data = bind_data_p.Cast(); - if (file_size_bytes.IsValid() && global_state.writer->FileSize() > file_size_bytes.GetIndex()) { - return true; - } - if (bind_data.row_groups_per_file.IsValid() && - global_state.writer->NumberOfRowGroups() >= bind_data.row_groups_per_file.GetIndex()) { - return true; - } - return false; -} - -//===--------------------------------------------------------------------===// -// Scan Replacement -//===--------------------------------------------------------------------===// -unique_ptr ParquetScanReplacement(ClientContext &context, ReplacementScanInput &input, - optional_ptr data) { - auto table_name = ReplacementScan::GetFullPath(input); - if (!ReplacementScan::CanReplace(table_name, {"parquet"})) { - return nullptr; - } - auto table_function = make_uniq(); - vector> children; - children.push_back(make_uniq(Value(table_name))); - table_function->function = make_uniq("parquet_scan", std::move(children)); - - if (!FileSystem::HasGlob(table_name)) { - auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); - } - - return std::move(table_function); -} - -//===--------------------------------------------------------------------===// -// Select -//===--------------------------------------------------------------------===// -// Helper predicates for ParquetWriteSelect -static bool IsTypeNotSupported(const LogicalType &type) { - if (type.IsNested()) { - return false; - } - return !ParquetWriter::TryGetParquetType(type); -} - -static bool IsTypeLossy(const LogicalType &type) { - return type.id() == LogicalTypeId::HUGEINT || type.id() == LogicalTypeId::UHUGEINT; -} - -static vector> ParquetWriteSelect(CopyToSelectInput &input) { - - auto &context = input.context; - - vector> result; - - bool any_change = false; - - for (auto &expr : input.select_list) { - - const auto &type = expr->return_type; - const auto &name = expr->GetAlias(); - - // Spatial types need to be encoded into WKB when writing GeoParquet. - // But dont perform this conversion if this is a EXPORT DATABASE statement - if (input.copy_to_type == CopyToType::COPY_TO_FILE && type.id() == LogicalTypeId::BLOB && type.HasAlias() && - type.GetAlias() == "GEOMETRY" && GeoParquetFileMetadata::IsGeoParquetConversionEnabled(context)) { - - LogicalType wkb_blob_type(LogicalTypeId::BLOB); - wkb_blob_type.SetAlias("WKB_BLOB"); - - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), wkb_blob_type, false); - cast_expr->SetAlias(name); - result.push_back(std::move(cast_expr)); - any_change = true; - } - // If this is an EXPORT DATABASE statement, we dont want to write "lossy" types, instead cast them to VARCHAR - else if (input.copy_to_type == CopyToType::EXPORT_DATABASE && TypeVisitor::Contains(type, IsTypeLossy)) { - // Replace all lossy types with VARCHAR - auto new_type = TypeVisitor::VisitReplace( - type, [](const LogicalType &ty) -> LogicalType { return IsTypeLossy(ty) ? 
LogicalType::VARCHAR : ty; }); - - // Cast the column to the new type - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), new_type, false); - cast_expr->SetAlias(name); - result.push_back(std::move(cast_expr)); - any_change = true; - } - // Else look if there is any unsupported type - else if (TypeVisitor::Contains(type, IsTypeNotSupported)) { - // If there is at least one unsupported type, replace all unsupported types with varchar - // and perform a CAST - auto new_type = TypeVisitor::VisitReplace(type, [](const LogicalType &ty) -> LogicalType { - return IsTypeNotSupported(ty) ? LogicalType::VARCHAR : ty; - }); - - auto cast_expr = BoundCastExpression::AddCastToType(context, std::move(expr), new_type, false); - cast_expr->SetAlias(name); - result.push_back(std::move(cast_expr)); - any_change = true; - } - // Otherwise, just reference the input column - else { - result.push_back(std::move(expr)); - } - } - - // If any change was made, return the new expressions - // otherwise, return an empty vector to indicate no change and avoid pushing another projection on to the plan - if (any_change) { - return result; - } - return {}; -} - -void ParquetExtension::Load(DuckDB &db) { - auto &db_instance = *db.instance; - auto &fs = db.GetFileSystem(); - fs.RegisterSubSystem(FileCompressionType::ZSTD, make_uniq()); - - auto scan_fun = ParquetScanFunction::GetFunctionSet(); - scan_fun.name = "read_parquet"; - ExtensionUtil::RegisterFunction(db_instance, scan_fun); - scan_fun.name = "parquet_scan"; - ExtensionUtil::RegisterFunction(db_instance, scan_fun); - - // parquet_metadata - ParquetMetaDataFunction meta_fun; - ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(meta_fun)); - - // parquet_schema - ParquetSchemaFunction schema_fun; - ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(schema_fun)); - - // parquet_key_value_metadata - ParquetKeyValueMetadataFunction kv_meta_fun; - ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(kv_meta_fun)); - - // parquet_file_metadata - ParquetFileMetadataFunction file_meta_fun; - ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(file_meta_fun)); - - // parquet_bloom_probe - ParquetBloomProbeFunction bloom_probe_fun; - ExtensionUtil::RegisterFunction(db_instance, MultiFileReader::CreateFunctionSet(bloom_probe_fun)); - - CopyFunction function("parquet"); - function.copy_to_select = ParquetWriteSelect; - function.copy_to_bind = ParquetWriteBind; - function.copy_to_initialize_global = ParquetWriteInitializeGlobal; - function.copy_to_initialize_local = ParquetWriteInitializeLocal; - function.copy_to_sink = ParquetWriteSink; - function.copy_to_combine = ParquetWriteCombine; - function.copy_to_finalize = ParquetWriteFinalize; - function.execution_mode = ParquetWriteExecutionMode; - function.copy_from_bind = ParquetScanFunction::ParquetReadBind; - function.copy_from_function = scan_fun.functions[0]; - function.prepare_batch = ParquetWritePrepareBatch; - function.flush_batch = ParquetWriteFlushBatch; - function.desired_batch_size = ParquetWriteDesiredBatchSize; - function.rotate_files = ParquetWriteRotateFiles; - function.rotate_next_file = ParquetWriteRotateNextFile; - function.serialize = ParquetCopySerialize; - function.deserialize = ParquetCopyDeserialize; - - function.extension = "parquet"; - ExtensionUtil::RegisterFunction(db_instance, function); - - // parquet_key - auto parquet_key_fun = 
PragmaFunction::PragmaCall("add_parquet_key", ParquetCrypto::AddKey, - {LogicalType::VARCHAR, LogicalType::VARCHAR}); - ExtensionUtil::RegisterFunction(db_instance, parquet_key_fun); - - auto &config = DBConfig::GetConfig(*db.instance); - config.replacement_scans.emplace_back(ParquetScanReplacement); - config.AddExtensionOption("binary_as_string", "In Parquet files, interpret binary data as a string.", - LogicalType::BOOLEAN); - config.AddExtensionOption("disable_parquet_prefetching", "Disable the prefetching mechanism in Parquet", - LogicalType::BOOLEAN, Value(false)); - config.AddExtensionOption("prefetch_all_parquet_files", - "Use the prefetching mechanism for all types of parquet files", LogicalType::BOOLEAN, - Value(false)); - config.AddExtensionOption("parquet_metadata_cache", - "Cache Parquet metadata - useful when reading the same files multiple times", - LogicalType::BOOLEAN, Value(false)); - config.AddExtensionOption( - "enable_geoparquet_conversion", - "Attempt to decode/encode geometry data in/as GeoParquet files if the spatial extension is present.", - LogicalType::BOOLEAN, Value::BOOLEAN(true)); -} - -std::string ParquetExtension::Name() { - return "parquet"; -} - -std::string ParquetExtension::Version() const { -#ifdef EXT_VERSION_PARQUET - return EXT_VERSION_PARQUET; -#else - return ""; -#endif -} - -} // namespace duckdb - -#ifdef DUCKDB_BUILD_LOADABLE_EXTENSION -extern "C" { - -DUCKDB_EXTENSION_API void parquet_init(duckdb::DatabaseInstance &db) { // NOLINT - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); -} - -DUCKDB_EXTENSION_API const char *parquet_version() { // NOLINT - return duckdb::DuckDB::LibraryVersion(); -} -} -#endif - -#ifndef DUCKDB_EXTENSION_MAIN -#error DUCKDB_EXTENSION_MAIN not defined -#endif diff --git a/src/duckdb/extension/parquet/parquet_metadata.cpp b/src/duckdb/extension/parquet/parquet_metadata.cpp deleted file mode 100644 index 456c9b3e4..000000000 --- a/src/duckdb/extension/parquet/parquet_metadata.cpp +++ /dev/null @@ -1,821 +0,0 @@ -#include "parquet_metadata.hpp" - -#include "parquet_statistics.hpp" - -#include - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/main/config.hpp" -#endif - -namespace duckdb { - -struct ParquetMetaDataBindData : public TableFunctionData { - vector return_types; - shared_ptr file_list; - unique_ptr multi_file_reader; -}; - -struct ParquetBloomProbeBindData : public ParquetMetaDataBindData { - string probe_column_name; - Value probe_constant; -}; - -enum class ParquetMetadataOperatorType : uint8_t { - META_DATA, - SCHEMA, - KEY_VALUE_META_DATA, - FILE_META_DATA, - BLOOM_PROBE -}; - -struct ParquetMetaDataOperatorData : public GlobalTableFunctionState { - explicit ParquetMetaDataOperatorData(ClientContext &context, const vector &types) - : collection(context, types) { - } - - ColumnDataCollection collection; - ColumnDataScanState scan_state; - - MultiFileListScanData file_list_scan; - string current_file; - -public: - static void BindMetaData(vector &return_types, vector &names); - static void BindSchema(vector &return_types, vector &names); - static void BindKeyValueMetaData(vector &return_types, vector &names); - static void BindFileMetaData(vector &return_types, vector &names); - static void BindBloomProbe(vector &return_types, vector &names); - - void LoadRowGroupMetadata(ClientContext 
&context, const vector &return_types, const string &file_path); - void LoadSchemaData(ClientContext &context, const vector &return_types, const string &file_path); - void LoadKeyValueMetaData(ClientContext &context, const vector &return_types, const string &file_path); - void LoadFileMetaData(ClientContext &context, const vector &return_types, const string &file_path); - void ExecuteBloomProbe(ClientContext &context, const vector &return_types, const string &file_path, - const string &column_name, const Value &probe); -}; - -template -string ConvertParquetElementToString(T &&entry) { - std::stringstream ss; - ss << entry; - return ss.str(); -} - -template -string PrintParquetElementToString(T &&entry) { - std::stringstream ss; - entry.printTo(ss); - return ss.str(); -} - -template -Value ParquetElementString(T &&value, bool is_set) { - if (!is_set) { - return Value(); - } - return Value(ConvertParquetElementToString(value)); -} - -Value ParquetElementStringVal(const string &value, bool is_set) { - if (!is_set) { - return Value(); - } - return Value(value); -} - -template -Value ParquetElementInteger(T &&value, bool is_set) { - if (!is_set) { - return Value(); - } - return Value::INTEGER(value); -} - -template -Value ParquetElementBigint(T &&value, bool is_set) { - if (!is_set) { - return Value(); - } - return Value::BIGINT(value); -} - -//===--------------------------------------------------------------------===// -// Row Group Meta Data -//===--------------------------------------------------------------------===// -void ParquetMetaDataOperatorData::BindMetaData(vector &return_types, vector &names) { - names.emplace_back("file_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("row_group_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("row_group_num_rows"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("row_group_num_columns"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("row_group_bytes"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("file_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("num_values"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("path_in_schema"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats_min"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats_max"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats_null_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("stats_distinct_count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("stats_min_value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats_max_value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("compression"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("encodings"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("index_page_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("dictionary_page_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("data_page_offset"); - 
return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_compressed_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_uncompressed_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("key_value_metadata"); - return_types.emplace_back(LogicalType::MAP(LogicalType::BLOB, LogicalType::BLOB)); - - names.emplace_back("bloom_filter_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("bloom_filter_length"); - return_types.emplace_back(LogicalType::BIGINT); -} - -Value ConvertParquetStats(const LogicalType &type, const duckdb_parquet::SchemaElement &schema_ele, bool stats_is_set, - const std::string &stats) { - if (!stats_is_set) { - return Value(LogicalType::VARCHAR); - } - return ParquetStatisticsUtils::ConvertValue(type, schema_ele, stats).DefaultCastAs(LogicalType::VARCHAR); -} - -void ParquetMetaDataOperatorData::LoadRowGroupMetadata(ClientContext &context, const vector &return_types, - const string &file_path) { - collection.Reset(); - ParquetOptions parquet_options(context); - auto reader = make_uniq(context, file_path, parquet_options); - idx_t count = 0; - DataChunk current_chunk; - current_chunk.Initialize(context, return_types); - auto meta_data = reader->GetFileMetadata(); - vector column_types; - vector schema_indexes; - for (idx_t schema_idx = 0; schema_idx < meta_data->schema.size(); schema_idx++) { - auto &schema_element = meta_data->schema[schema_idx]; - if (schema_element.num_children > 0) { - continue; - } - column_types.push_back(ParquetReader::DeriveLogicalType(schema_element, false)); - schema_indexes.push_back(schema_idx); - } - - for (idx_t row_group_idx = 0; row_group_idx < meta_data->row_groups.size(); row_group_idx++) { - auto &row_group = meta_data->row_groups[row_group_idx]; - - if (row_group.columns.size() > column_types.size()) { - throw InternalException("Too many columns in row group: corrupt file?"); - } - for (idx_t col_idx = 0; col_idx < row_group.columns.size(); col_idx++) { - auto &column = row_group.columns[col_idx]; - auto &col_meta = column.meta_data; - auto &stats = col_meta.statistics; - auto &schema_element = meta_data->schema[schema_indexes[col_idx]]; - auto &column_type = column_types[col_idx]; - - // file_name, LogicalType::VARCHAR - current_chunk.SetValue(0, count, file_path); - - // row_group_id, LogicalType::BIGINT - current_chunk.SetValue(1, count, Value::BIGINT(UnsafeNumericCast(row_group_idx))); - - // row_group_num_rows, LogicalType::BIGINT - current_chunk.SetValue(2, count, Value::BIGINT(row_group.num_rows)); - - // row_group_num_columns, LogicalType::BIGINT - current_chunk.SetValue(3, count, Value::BIGINT(UnsafeNumericCast(row_group.columns.size()))); - - // row_group_bytes, LogicalType::BIGINT - current_chunk.SetValue(4, count, Value::BIGINT(row_group.total_byte_size)); - - // column_id, LogicalType::BIGINT - current_chunk.SetValue(5, count, Value::BIGINT(UnsafeNumericCast(col_idx))); - - // file_offset, LogicalType::BIGINT - current_chunk.SetValue(6, count, ParquetElementBigint(column.file_offset, row_group.__isset.file_offset)); - - // num_values, LogicalType::BIGINT - current_chunk.SetValue(7, count, Value::BIGINT(col_meta.num_values)); - - // path_in_schema, LogicalType::VARCHAR - current_chunk.SetValue(8, count, StringUtil::Join(col_meta.path_in_schema, ", ")); - - // type, LogicalType::VARCHAR - current_chunk.SetValue(9, count, ConvertParquetElementToString(col_meta.type)); - - // stats_min, LogicalType::VARCHAR - 
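The Load* methods in this file all follow the same accumulation pattern: rows are written into a DataChunk and flushed to the ColumnDataCollection whenever a full vector is reached. A condensed sketch of just that pattern (the surrounding names are DuckDB's; the row loop is illustrative):

// assuming `context`, `return_types`, `collection` and `scan_state` as in the methods here
idx_t count = 0;
DataChunk chunk;
chunk.Initialize(context, return_types);
for (idx_t row = 0; row < total_rows; row++) { // illustrative row source
	// ... chunk.SetValue(col, count, ...) for every output column ...
	if (++count >= STANDARD_VECTOR_SIZE) {
		chunk.SetCardinality(count);
		collection.Append(chunk); // flush a full vector of rows
		count = 0;
		chunk.Reset();
	}
}
chunk.SetCardinality(count); // flush the (possibly empty) tail
collection.Append(chunk);
collection.InitializeScan(scan_state);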
current_chunk.SetValue(10, count, - ConvertParquetStats(column_type, schema_element, stats.__isset.min, stats.min)); - - // stats_max, LogicalType::VARCHAR - current_chunk.SetValue(11, count, - ConvertParquetStats(column_type, schema_element, stats.__isset.max, stats.max)); - - // stats_null_count, LogicalType::BIGINT - current_chunk.SetValue(12, count, ParquetElementBigint(stats.null_count, stats.__isset.null_count)); - - // stats_distinct_count, LogicalType::BIGINT - current_chunk.SetValue(13, count, ParquetElementBigint(stats.distinct_count, stats.__isset.distinct_count)); - - // stats_min_value, LogicalType::VARCHAR - current_chunk.SetValue( - 14, count, ConvertParquetStats(column_type, schema_element, stats.__isset.min_value, stats.min_value)); - - // stats_max_value, LogicalType::VARCHAR - current_chunk.SetValue( - 15, count, ConvertParquetStats(column_type, schema_element, stats.__isset.max_value, stats.max_value)); - - // compression, LogicalType::VARCHAR - current_chunk.SetValue(16, count, ConvertParquetElementToString(col_meta.codec)); - - // encodings, LogicalType::VARCHAR - vector encoding_string; - encoding_string.reserve(col_meta.encodings.size()); - for (auto &encoding : col_meta.encodings) { - encoding_string.push_back(ConvertParquetElementToString(encoding)); - } - current_chunk.SetValue(17, count, Value(StringUtil::Join(encoding_string, ", "))); - - // index_page_offset, LogicalType::BIGINT - current_chunk.SetValue( - 18, count, ParquetElementBigint(col_meta.index_page_offset, col_meta.__isset.index_page_offset)); - - // dictionary_page_offset, LogicalType::BIGINT - current_chunk.SetValue( - 19, count, - ParquetElementBigint(col_meta.dictionary_page_offset, col_meta.__isset.dictionary_page_offset)); - - // data_page_offset, LogicalType::BIGINT - current_chunk.SetValue(20, count, Value::BIGINT(col_meta.data_page_offset)); - - // total_compressed_size, LogicalType::BIGINT - current_chunk.SetValue(21, count, Value::BIGINT(col_meta.total_compressed_size)); - - // total_uncompressed_size, LogicalType::BIGINT - current_chunk.SetValue(22, count, Value::BIGINT(col_meta.total_uncompressed_size)); - - // key_value_metadata, LogicalType::MAP(LogicalType::BLOB, LogicalType::BLOB) - vector map_keys, map_values; - for (auto &entry : col_meta.key_value_metadata) { - map_keys.push_back(Value::BLOB_RAW(entry.key)); - map_values.push_back(Value::BLOB_RAW(entry.value)); - } - current_chunk.SetValue( - 23, count, - Value::MAP(LogicalType::BLOB, LogicalType::BLOB, std::move(map_keys), std::move(map_values))); - - // bloom_filter_offset, LogicalType::BIGINT - current_chunk.SetValue( - 24, count, ParquetElementBigint(col_meta.bloom_filter_offset, col_meta.__isset.bloom_filter_offset)); - - // bloom_filter_length, LogicalType::BIGINT - current_chunk.SetValue( - 25, count, ParquetElementBigint(col_meta.bloom_filter_length, col_meta.__isset.bloom_filter_length)); - - count++; - if (count >= STANDARD_VECTOR_SIZE) { - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - count = 0; - current_chunk.Reset(); - } - } - } - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - collection.InitializeScan(scan_state); -} - -//===--------------------------------------------------------------------===// -// Schema Data -//===--------------------------------------------------------------------===// -void ParquetMetaDataOperatorData::BindSchema(vector &return_types, vector &names) { - names.emplace_back("file_name"); - 
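From SQL, the columns bound above make row-group pruning decisions inspectable. A hypothetical query against a placeholder file:

// assuming a duckdb::Connection named con
con.Query("SELECT row_group_id, path_in_schema, stats_min_value, stats_max_value, "
          "total_compressed_size "
          "FROM parquet_metadata('example.parquet') ORDER BY row_group_id");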
return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type_length"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("repetition_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("num_children"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("converted_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("scale"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("precision"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("field_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("logical_type"); - return_types.emplace_back(LogicalType::VARCHAR); -} - -Value ParquetLogicalTypeToString(const duckdb_parquet::LogicalType &type, bool is_set) { - if (!is_set) { - return Value(); - } - if (type.__isset.STRING) { - return Value(PrintParquetElementToString(type.STRING)); - } - if (type.__isset.MAP) { - return Value(PrintParquetElementToString(type.MAP)); - } - if (type.__isset.LIST) { - return Value(PrintParquetElementToString(type.LIST)); - } - if (type.__isset.ENUM) { - return Value(PrintParquetElementToString(type.ENUM)); - } - if (type.__isset.DECIMAL) { - return Value(PrintParquetElementToString(type.DECIMAL)); - } - if (type.__isset.DATE) { - return Value(PrintParquetElementToString(type.DATE)); - } - if (type.__isset.TIME) { - return Value(PrintParquetElementToString(type.TIME)); - } - if (type.__isset.TIMESTAMP) { - return Value(PrintParquetElementToString(type.TIMESTAMP)); - } - if (type.__isset.INTEGER) { - return Value(PrintParquetElementToString(type.INTEGER)); - } - if (type.__isset.UNKNOWN) { - return Value(PrintParquetElementToString(type.UNKNOWN)); - } - if (type.__isset.JSON) { - return Value(PrintParquetElementToString(type.JSON)); - } - if (type.__isset.BSON) { - return Value(PrintParquetElementToString(type.BSON)); - } - if (type.__isset.UUID) { - return Value(PrintParquetElementToString(type.UUID)); - } - return Value(); -} - -void ParquetMetaDataOperatorData::LoadSchemaData(ClientContext &context, const vector &return_types, - const string &file_path) { - collection.Reset(); - ParquetOptions parquet_options(context); - auto reader = make_uniq(context, file_path, parquet_options); - idx_t count = 0; - DataChunk current_chunk; - current_chunk.Initialize(context, return_types); - auto meta_data = reader->GetFileMetadata(); - for (idx_t col_idx = 0; col_idx < meta_data->schema.size(); col_idx++) { - auto &column = meta_data->schema[col_idx]; - - // file_name, LogicalType::VARCHAR - current_chunk.SetValue(0, count, file_path); - - // name, LogicalType::VARCHAR - current_chunk.SetValue(1, count, column.name); - - // type, LogicalType::VARCHAR - current_chunk.SetValue(2, count, ParquetElementString(column.type, column.__isset.type)); - - // type_length, LogicalType::INTEGER - current_chunk.SetValue(3, count, ParquetElementInteger(column.type_length, column.__isset.type_length)); - - // repetition_type, LogicalType::VARCHAR - current_chunk.SetValue(4, count, ParquetElementString(column.repetition_type, column.__isset.repetition_type)); - - // num_children, LogicalType::BIGINT - current_chunk.SetValue(5, count, ParquetElementBigint(column.num_children, column.__isset.num_children)); - - // converted_type, 
LogicalType::VARCHAR - current_chunk.SetValue(6, count, ParquetElementString(column.converted_type, column.__isset.converted_type)); - - // scale, LogicalType::BIGINT - current_chunk.SetValue(7, count, ParquetElementBigint(column.scale, column.__isset.scale)); - - // precision, LogicalType::BIGINT - current_chunk.SetValue(8, count, ParquetElementBigint(column.precision, column.__isset.precision)); - - // field_id, LogicalType::BIGINT - current_chunk.SetValue(9, count, ParquetElementBigint(column.field_id, column.__isset.field_id)); - - // logical_type, LogicalType::VARCHAR - current_chunk.SetValue(10, count, ParquetLogicalTypeToString(column.logicalType, column.__isset.logicalType)); - - count++; - if (count >= STANDARD_VECTOR_SIZE) { - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - count = 0; - current_chunk.Reset(); - } - } - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - collection.InitializeScan(scan_state); -} - -//===--------------------------------------------------------------------===// -// KV Meta Data -//===--------------------------------------------------------------------===// -void ParquetMetaDataOperatorData::BindKeyValueMetaData(vector &return_types, vector &names) { - names.emplace_back("file_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("key"); - return_types.emplace_back(LogicalType::BLOB); - - names.emplace_back("value"); - return_types.emplace_back(LogicalType::BLOB); -} - -void ParquetMetaDataOperatorData::LoadKeyValueMetaData(ClientContext &context, const vector &return_types, - const string &file_path) { - collection.Reset(); - ParquetOptions parquet_options(context); - auto reader = make_uniq(context, file_path, parquet_options); - idx_t count = 0; - DataChunk current_chunk; - current_chunk.Initialize(context, return_types); - auto meta_data = reader->GetFileMetadata(); - - for (idx_t col_idx = 0; col_idx < meta_data->key_value_metadata.size(); col_idx++) { - auto &entry = meta_data->key_value_metadata[col_idx]; - - current_chunk.SetValue(0, count, Value(file_path)); - current_chunk.SetValue(1, count, Value::BLOB_RAW(entry.key)); - current_chunk.SetValue(2, count, Value::BLOB_RAW(entry.value)); - - count++; - if (count >= STANDARD_VECTOR_SIZE) { - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - count = 0; - current_chunk.Reset(); - } - } - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - collection.InitializeScan(scan_state); -} - -//===--------------------------------------------------------------------===// -// File Meta Data -//===--------------------------------------------------------------------===// -void ParquetMetaDataOperatorData::BindFileMetaData(vector &return_types, vector &names) { - names.emplace_back("file_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("created_by"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("num_rows"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("num_row_groups"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("format_version"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("encryption_algorithm"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("footer_signing_key_metadata"); - return_types.emplace_back(LogicalType::VARCHAR); -} - -void ParquetMetaDataOperatorData::LoadFileMetaData(ClientContext &context, 
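Because Parquet key/value metadata is arbitrary bytes, the function above exposes both columns as BLOBs. A hypothetical round trip from SQL (the file name is a placeholder; decode() raises an error if the bytes are not valid UTF-8):

// assuming a duckdb::Connection named con
con.Query("SELECT decode(key) AS key, decode(value) AS value "
          "FROM parquet_kv_metadata('example.parquet')");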
const vector &return_types, - const string &file_path) { - collection.Reset(); - ParquetOptions parquet_options(context); - auto reader = make_uniq(context, file_path, parquet_options); - DataChunk current_chunk; - current_chunk.Initialize(context, return_types); - auto meta_data = reader->GetFileMetadata(); - - // file_name - current_chunk.SetValue(0, 0, Value(file_path)); - // created_by - current_chunk.SetValue(1, 0, ParquetElementStringVal(meta_data->created_by, meta_data->__isset.created_by)); - // num_rows - current_chunk.SetValue(2, 0, Value::BIGINT(meta_data->num_rows)); - // num_row_groups - current_chunk.SetValue(3, 0, Value::BIGINT(UnsafeNumericCast(meta_data->row_groups.size()))); - // format_version - current_chunk.SetValue(4, 0, Value::BIGINT(meta_data->version)); - // encryption_algorithm - current_chunk.SetValue( - 5, 0, ParquetElementString(meta_data->encryption_algorithm, meta_data->__isset.encryption_algorithm)); - // footer_signing_key_metadata - current_chunk.SetValue(6, 0, - ParquetElementStringVal(meta_data->footer_signing_key_metadata, - meta_data->__isset.footer_signing_key_metadata)); - current_chunk.SetCardinality(1); - collection.Append(current_chunk); - collection.InitializeScan(scan_state); -} - -//===--------------------------------------------------------------------===// -// Bloom Probe -//===--------------------------------------------------------------------===// -void ParquetMetaDataOperatorData::BindBloomProbe(vector &return_types, vector &names) { - names.emplace_back("file_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("row_group_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("bloom_filter_excludes"); - return_types.emplace_back(LogicalType::BOOLEAN); -} - -void ParquetMetaDataOperatorData::ExecuteBloomProbe(ClientContext &context, const vector &return_types, - const string &file_path, const string &column_name, - const Value &probe) { - collection.Reset(); - ParquetOptions parquet_options(context); - auto reader = make_uniq(context, file_path, parquet_options); - idx_t count = 0; - DataChunk current_chunk; - current_chunk.Initialize(context, return_types); - auto meta_data = reader->GetFileMetadata(); - - optional_idx probe_column_idx; - for (idx_t column_idx = 0; column_idx < reader->names.size(); column_idx++) { - if (reader->names[column_idx] == column_name) { - probe_column_idx = column_idx; - } - } - - if (!probe_column_idx.IsValid()) { - throw InvalidInputException("Column %s not found in %s", column_name, file_path); - } - - auto &allocator = Allocator::DefaultAllocator(); - auto transport = std::make_shared(allocator, reader->GetHandle(), false); - auto protocol = - make_uniq>(std::move(transport)); - - D_ASSERT(!probe.IsNull()); - ConstantFilter filter(ExpressionType::COMPARE_EQUAL, - probe.CastAs(context, reader->GetTypes()[probe_column_idx.GetIndex()])); - - for (idx_t row_group_idx = 0; row_group_idx < meta_data->row_groups.size(); row_group_idx++) { - auto &row_group = meta_data->row_groups[row_group_idx]; - auto &column = row_group.columns[probe_column_idx.GetIndex()]; - - auto bloom_excludes = - ParquetStatisticsUtils::BloomFilterExcludes(filter, column.meta_data, *protocol, allocator); - current_chunk.SetValue(0, count, Value(file_path)); - current_chunk.SetValue(1, count, Value::BIGINT(NumericCast(row_group_idx))); - current_chunk.SetValue(2, count, Value::BOOLEAN(bloom_excludes)); - - count++; - if (count >= STANDARD_VECTOR_SIZE) { - 
current_chunk.SetCardinality(count); - collection.Append(current_chunk); - - count = 0; - current_chunk.Reset(); - } - } - - current_chunk.SetCardinality(count); - collection.Append(current_chunk); - collection.InitializeScan(scan_state); -} - -//===--------------------------------------------------------------------===// -// Bind -//===--------------------------------------------------------------------===// -template -unique_ptr ParquetMetaDataBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - - switch (TYPE) { - case ParquetMetadataOperatorType::SCHEMA: - ParquetMetaDataOperatorData::BindSchema(return_types, names); - break; - case ParquetMetadataOperatorType::META_DATA: - ParquetMetaDataOperatorData::BindMetaData(return_types, names); - break; - case ParquetMetadataOperatorType::KEY_VALUE_META_DATA: - ParquetMetaDataOperatorData::BindKeyValueMetaData(return_types, names); - break; - case ParquetMetadataOperatorType::FILE_META_DATA: - ParquetMetaDataOperatorData::BindFileMetaData(return_types, names); - break; - case ParquetMetadataOperatorType::BLOOM_PROBE: { - auto probe_bind_data = make_uniq(); - D_ASSERT(input.inputs.size() == 3); - if (input.inputs[1].IsNull() || input.inputs[2].IsNull()) { - throw InvalidInputException("Can't have NULL parameters for parquet_bloom_probe"); - } - probe_bind_data->probe_column_name = input.inputs[1].CastAs(context, LogicalType::VARCHAR).GetValue(); - probe_bind_data->probe_constant = input.inputs[2]; - result = std::move(probe_bind_data); - ParquetMetaDataOperatorData::BindBloomProbe(return_types, names); - break; - } - default: - throw InternalException("Unsupported ParquetMetadataOperatorType"); - } - - result->return_types = return_types; - result->multi_file_reader = MultiFileReader::Create(input.table_function); - result->file_list = result->multi_file_reader->CreateFileList(context, input.inputs[0]); - - return std::move(result); -} - -template -unique_ptr ParquetMetaDataInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - - auto result = make_uniq(context, bind_data.return_types); - - bind_data.file_list->InitializeScan(result->file_list_scan); - bind_data.file_list->Scan(result->file_list_scan, result->current_file); - - D_ASSERT(!bind_data.file_list->IsEmpty()); - - switch (TYPE) { - case ParquetMetadataOperatorType::SCHEMA: - result->LoadSchemaData(context, bind_data.return_types, bind_data.file_list->GetFirstFile()); - break; - case ParquetMetadataOperatorType::META_DATA: - result->LoadRowGroupMetadata(context, bind_data.return_types, bind_data.file_list->GetFirstFile()); - break; - case ParquetMetadataOperatorType::KEY_VALUE_META_DATA: - result->LoadKeyValueMetaData(context, bind_data.return_types, bind_data.file_list->GetFirstFile()); - break; - case ParquetMetadataOperatorType::FILE_META_DATA: - result->LoadFileMetaData(context, bind_data.return_types, bind_data.file_list->GetFirstFile()); - break; - case ParquetMetadataOperatorType::BLOOM_PROBE: { - auto &bloom_probe_bind_data = input.bind_data->Cast(); - result->ExecuteBloomProbe(context, bind_data.return_types, bind_data.file_list->GetFirstFile(), - bloom_probe_bind_data.probe_column_name, bloom_probe_bind_data.probe_constant); - break; - } - default: - throw InternalException("Unsupported ParquetMetadataOperatorType"); - } - - return std::move(result); -} - -template -void ParquetMetaDataImplementation(ClientContext &context, TableFunctionInput 
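The probe above answers, per row group, whether the column's bloom filter proves a constant absent. A hypothetical invocation (file and column names are placeholders):

// assuming a duckdb::Connection named con; row groups where bloom_filter_excludes
// is true can be skipped entirely for the predicate id = 42
con.Query("SELECT row_group_id, bloom_filter_excludes "
          "FROM parquet_bloom_probe('example.parquet', 'id', 42)");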
&data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - auto &bind_data = data_p.bind_data->Cast(); - - while (true) { - if (!data.collection.Scan(data.scan_state, output)) { - - // Try to get the next file - if (!bind_data.file_list->Scan(data.file_list_scan, data.current_file)) { - return; - } - - switch (TYPE) { - case ParquetMetadataOperatorType::SCHEMA: - data.LoadSchemaData(context, bind_data.return_types, data.current_file); - break; - case ParquetMetadataOperatorType::META_DATA: - data.LoadRowGroupMetadata(context, bind_data.return_types, data.current_file); - break; - case ParquetMetadataOperatorType::KEY_VALUE_META_DATA: - data.LoadKeyValueMetaData(context, bind_data.return_types, data.current_file); - break; - case ParquetMetadataOperatorType::FILE_META_DATA: - data.LoadFileMetaData(context, bind_data.return_types, data.current_file); - break; - case ParquetMetadataOperatorType::BLOOM_PROBE: { - auto &bloom_probe_bind_data = data_p.bind_data->Cast(); - data.ExecuteBloomProbe(context, bind_data.return_types, data.current_file, - bloom_probe_bind_data.probe_column_name, bloom_probe_bind_data.probe_constant); - break; - } - default: - throw InternalException("Unsupported ParquetMetadataOperatorType"); - } - continue; - } - if (output.size() != 0) { - return; - } - } -} - -ParquetMetaDataFunction::ParquetMetaDataFunction() - : TableFunction("parquet_metadata", {LogicalType::VARCHAR}, - ParquetMetaDataImplementation, - ParquetMetaDataBind, - ParquetMetaDataInit) { -} - -ParquetSchemaFunction::ParquetSchemaFunction() - : TableFunction("parquet_schema", {LogicalType::VARCHAR}, - ParquetMetaDataImplementation, - ParquetMetaDataBind, - ParquetMetaDataInit) { -} - -ParquetKeyValueMetadataFunction::ParquetKeyValueMetadataFunction() - : TableFunction("parquet_kv_metadata", {LogicalType::VARCHAR}, - ParquetMetaDataImplementation, - ParquetMetaDataBind, - ParquetMetaDataInit) { -} - -ParquetFileMetadataFunction::ParquetFileMetadataFunction() - : TableFunction("parquet_file_metadata", {LogicalType::VARCHAR}, - ParquetMetaDataImplementation, - ParquetMetaDataBind, - ParquetMetaDataInit) { -} - -ParquetBloomProbeFunction::ParquetBloomProbeFunction() - : TableFunction("parquet_bloom_probe", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::ANY}, - ParquetMetaDataImplementation, - ParquetMetaDataBind, - ParquetMetaDataInit) { -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_reader.cpp b/src/duckdb/extension/parquet/parquet_reader.cpp deleted file mode 100644 index 1d2565a2d..000000000 --- a/src/duckdb/extension/parquet/parquet_reader.cpp +++ /dev/null @@ -1,1243 +0,0 @@ -#include "parquet_reader.hpp" - -#include "boolean_column_reader.hpp" -#include "callback_column_reader.hpp" -#include "cast_column_reader.hpp" -#include "column_reader.hpp" -#include "duckdb.hpp" -#include "expression_column_reader.hpp" -#include "geo_parquet.hpp" -#include "list_column_reader.hpp" -#include "parquet_crypto.hpp" -#include "parquet_file_metadata_cache.hpp" -#include "parquet_statistics.hpp" -#include "parquet_timestamp.hpp" -#include "mbedtls_wrapper.hpp" -#include "row_number_column_reader.hpp" -#include "string_column_reader.hpp" -#include "struct_column_reader.hpp" -#include "templated_column_reader.hpp" -#include "thrift_tools.hpp" -#include "duckdb/main/config.hpp" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/encryption_state.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/helper.hpp" -#include 
"duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/struct_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/storage/object_cache.hpp" -#endif - -#include -#include -#include -#include - -namespace duckdb { - -using duckdb_parquet::ColumnChunk; -using duckdb_parquet::ConvertedType; -using duckdb_parquet::FieldRepetitionType; -using duckdb_parquet::FileCryptoMetaData; -using duckdb_parquet::FileMetaData; -using ParquetRowGroup = duckdb_parquet::RowGroup; -using duckdb_parquet::SchemaElement; -using duckdb_parquet::Statistics; -using duckdb_parquet::Type; - -static unique_ptr -CreateThriftFileProtocol(Allocator &allocator, FileHandle &file_handle, bool prefetch_mode) { - auto transport = std::make_shared(allocator, file_handle, prefetch_mode); - return make_uniq>(std::move(transport)); -} - -static shared_ptr -LoadMetadata(ClientContext &context, Allocator &allocator, FileHandle &file_handle, - const shared_ptr &encryption_config, - const EncryptionUtil &encryption_util) { - auto current_time = std::chrono::system_clock::to_time_t(std::chrono::system_clock::now()); - - auto file_proto = CreateThriftFileProtocol(allocator, file_handle, false); - auto &transport = reinterpret_cast(*file_proto->getTransport()); - auto file_size = transport.GetSize(); - if (file_size < 12) { - throw InvalidInputException("File '%s' too small to be a Parquet file", file_handle.path); - } - - ResizeableBuffer buf; - buf.resize(allocator, 8); - buf.zero(); - - transport.SetLocation(file_size - 8); - transport.read(buf.ptr, 8); - - bool footer_encrypted; - if (memcmp(buf.ptr + 4, "PAR1", 4) == 0) { - footer_encrypted = false; - if (encryption_config) { - throw InvalidInputException("File '%s' is not encrypted, but 'encryption_config' was set", - file_handle.path); - } - } else if (memcmp(buf.ptr + 4, "PARE", 4) == 0) { - footer_encrypted = true; - if (!encryption_config) { - throw InvalidInputException("File '%s' is encrypted, but 'encryption_config' was not set", - file_handle.path); - } - } else { - throw InvalidInputException("No magic bytes found at end of file '%s'", file_handle.path); - } - - // read four-byte footer length from just before the end magic bytes - auto footer_len = *reinterpret_cast(buf.ptr); - if (footer_len == 0 || file_size < 12 + footer_len) { - throw InvalidInputException("Footer length error in file '%s'", file_handle.path); - } - - auto metadata_pos = file_size - (footer_len + 8); - transport.SetLocation(metadata_pos); - transport.Prefetch(metadata_pos, footer_len); - - auto metadata = make_uniq(); - if (footer_encrypted) { - auto crypto_metadata = make_uniq(); - crypto_metadata->read(file_proto.get()); - if (crypto_metadata->encryption_algorithm.__isset.AES_GCM_CTR_V1) { - throw InvalidInputException("File '%s' is encrypted with AES_GCM_CTR_V1, but only AES_GCM_V1 is supported", - file_handle.path); - } - ParquetCrypto::Read(*metadata, *file_proto, encryption_config->GetFooterKey(), encryption_util); - } else { - metadata->read(file_proto.get()); - } - - // Try to read the GeoParquet metadata (if present) - auto geo_metadata = GeoParquetFileMetadata::TryRead(*metadata, context); - - return make_shared_ptr(std::move(metadata), current_time, std::move(geo_metadata)); -} - -LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele, 
bool binary_as_string) { - // inner node - if (s_ele.type == Type::FIXED_LEN_BYTE_ARRAY && !s_ele.__isset.type_length) { - throw IOException("FIXED_LEN_BYTE_ARRAY requires length to be set"); - } - if (s_ele.__isset.logicalType) { - if (s_ele.logicalType.__isset.UUID) { - if (s_ele.type == Type::FIXED_LEN_BYTE_ARRAY) { - return LogicalType::UUID; - } - } else if (s_ele.logicalType.__isset.TIMESTAMP) { - if (s_ele.logicalType.TIMESTAMP.isAdjustedToUTC) { - return LogicalType::TIMESTAMP_TZ; - } else if (s_ele.logicalType.TIMESTAMP.unit.__isset.NANOS) { - return LogicalType::TIMESTAMP_NS; - } - return LogicalType::TIMESTAMP; - } else if (s_ele.logicalType.__isset.TIME) { - if (s_ele.logicalType.TIME.isAdjustedToUTC) { - return LogicalType::TIME_TZ; - } - return LogicalType::TIME; - } - } - if (s_ele.__isset.converted_type) { - // Legacy NULL type; it no longer exists, but files containing it are still around, of course - if (static_cast(s_ele.converted_type) == 24) { - return LogicalTypeId::SQLNULL; - } - switch (s_ele.converted_type) { - case ConvertedType::INT_8: - if (s_ele.type == Type::INT32) { - return LogicalType::TINYINT; - } else { - throw IOException("INT8 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::INT_16: - if (s_ele.type == Type::INT32) { - return LogicalType::SMALLINT; - } else { - throw IOException("INT16 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::INT_32: - if (s_ele.type == Type::INT32) { - return LogicalType::INTEGER; - } else { - throw IOException("INT32 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::INT_64: - if (s_ele.type == Type::INT64) { - return LogicalType::BIGINT; - } else { - throw IOException("INT64 converted type can only be set for value of Type::INT64"); - } - case ConvertedType::UINT_8: - if (s_ele.type == Type::INT32) { - return LogicalType::UTINYINT; - } else { - throw IOException("UINT8 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::UINT_16: - if (s_ele.type == Type::INT32) { - return LogicalType::USMALLINT; - } else { - throw IOException("UINT16 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::UINT_32: - if (s_ele.type == Type::INT32) { - return LogicalType::UINTEGER; - } else { - throw IOException("UINT32 converted type can only be set for value of Type::INT32"); - } - case ConvertedType::UINT_64: - if (s_ele.type == Type::INT64) { - return LogicalType::UBIGINT; - } else { - throw IOException("UINT64 converted type can only be set for value of Type::INT64"); - } - case ConvertedType::DATE: - if (s_ele.type == Type::INT32) { - return LogicalType::DATE; - } else { - throw IOException("DATE converted type can only be set for value of Type::INT32"); - } - case ConvertedType::TIMESTAMP_MICROS: - case ConvertedType::TIMESTAMP_MILLIS: - if (s_ele.type == Type::INT64) { - return LogicalType::TIMESTAMP; - } else { - throw IOException("TIMESTAMP converted type can only be set for value of Type::INT64"); - } - case ConvertedType::DECIMAL: - if (!s_ele.__isset.precision || !s_ele.__isset.scale) { - throw IOException("DECIMAL requires a length and scale specifier!"); - } - if (s_ele.precision > DecimalType::MaxWidth()) { - return LogicalType::DOUBLE; - } - switch (s_ele.type) { - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - case Type::INT32: - case Type::INT64: - return LogicalType::DECIMAL(s_ele.precision, s_ele.scale); - default: - throw IOException( - "DECIMAL converted 
type can only be set for value of Type::(FIXED_LEN_)BYTE_ARRAY/INT32/INT64"); - } - case ConvertedType::UTF8: - case ConvertedType::ENUM: - switch (s_ele.type) { - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - return LogicalType::VARCHAR; - default: - throw IOException("UTF8 converted type can only be set for Type::(FIXED_LEN_)BYTE_ARRAY"); - } - case ConvertedType::TIME_MILLIS: - if (s_ele.type == Type::INT32) { - return LogicalType::TIME; - } else { - throw IOException("TIME_MILLIS converted type can only be set for value of Type::INT32"); - } - case ConvertedType::TIME_MICROS: - if (s_ele.type == Type::INT64) { - return LogicalType::TIME; - } else { - throw IOException("TIME_MICROS converted type can only be set for value of Type::INT64"); - } - case ConvertedType::INTERVAL: - return LogicalType::INTERVAL; - case ConvertedType::JSON: - return LogicalType::JSON(); - case ConvertedType::MAP: - case ConvertedType::MAP_KEY_VALUE: - case ConvertedType::LIST: - case ConvertedType::BSON: - default: - throw IOException("Unsupported converted type (%d)", (int32_t)s_ele.converted_type); - } - } else { - // no converted type set - // use default type for each physical type - switch (s_ele.type) { - case Type::BOOLEAN: - return LogicalType::BOOLEAN; - case Type::INT32: - return LogicalType::INTEGER; - case Type::INT64: - return LogicalType::BIGINT; - case Type::INT96: // always a timestamp it would seem - return LogicalType::TIMESTAMP; - case Type::FLOAT: - return LogicalType::FLOAT; - case Type::DOUBLE: - return LogicalType::DOUBLE; - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - if (binary_as_string) { - return LogicalType::VARCHAR; - } - return LogicalType::BLOB; - default: - return LogicalType::INVALID; - } - } -} - -LogicalType ParquetReader::DeriveLogicalType(const SchemaElement &s_ele) { - return DeriveLogicalType(s_ele, parquet_options.binary_as_string); -} - -unique_ptr ParquetReader::CreateReaderRecursive(ClientContext &context, - const vector &indexes, idx_t depth, - idx_t max_define, idx_t max_repeat, - idx_t &next_schema_idx, idx_t &next_file_idx) { - auto file_meta_data = GetFileMetadata(); - D_ASSERT(file_meta_data); - D_ASSERT(next_schema_idx < file_meta_data->schema.size()); - auto &s_ele = file_meta_data->schema[next_schema_idx]; - auto this_idx = next_schema_idx; - - auto repetition_type = FieldRepetitionType::REQUIRED; - if (s_ele.__isset.repetition_type && this_idx > 0) { - repetition_type = s_ele.repetition_type; - } - if (repetition_type != FieldRepetitionType::REQUIRED) { - max_define++; - } - if (repetition_type == FieldRepetitionType::REPEATED) { - max_repeat++; - } - - // Check for geoparquet spatial types - if (depth == 1) { - // geoparquet types have to be at the root of the schema, and have to be present in the kv metadata - if (metadata->geo_metadata && metadata->geo_metadata->IsGeometryColumn(s_ele.name)) { - return metadata->geo_metadata->CreateColumnReader(*this, DeriveLogicalType(s_ele), s_ele, next_file_idx++, - max_define, max_repeat, context); - } - } - - if (s_ele.__isset.num_children && s_ele.num_children > 0) { // inner node - child_list_t child_types; - vector> child_readers; - // this type is a nested type - it has child columns specified - // create a mapping for which column readers we should create - unordered_map> required_readers; - for (auto &index : indexes) { - required_readers.insert(make_pair(index.GetPrimaryIndex(), index.GetChildIndexes())); - } - - idx_t c_idx = 0; - while (c_idx < (idx_t)s_ele.num_children) { - 
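The max_define/max_repeat bookkeeping above implements Dremel-style levels: every non-REQUIRED ancestor adds a definition level, and every REPEATED ancestor additionally adds a repetition level. A worked example as comments (the schema is illustrative):

// message schema {               max_define  max_repeat
//   optional group tags {            1           0     (OPTIONAL raises define)
//     repeated group list {          2           1     (REPEATED raises both)
//       optional binary element      3           1     (OPTIONAL raises define)
//     }
//   }
// }
//
// A stored definition level of 3 means the element is present; 2 means a NULL
// element, 1 an empty list, 0 a NULL tags group. Repetition levels mark at
// which repeated ancestor a new entry starts.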
next_schema_idx++; - - auto &child_ele = file_meta_data->schema[next_schema_idx]; - - // figure out which child columns of this child column we should read - vector child_indexes; - auto entry = required_readers.find(c_idx); - if (entry != required_readers.end()) { - child_indexes = entry->second; - } - auto child_reader = CreateReaderRecursive(context, child_indexes, depth + 1, max_define, max_repeat, - next_schema_idx, next_file_idx); - child_types.push_back(make_pair(child_ele.name, child_reader->Type())); - if (indexes.empty() || entry != required_readers.end()) { - // either (1) indexes is empty, meaning we need to read all child columns, or (2) we need to read this - // column - child_readers.push_back(std::move(child_reader)); - } else { - // this column was explicitly not required - push an empty child reader here - child_readers.push_back(nullptr); - } - - c_idx++; - } - // rename child type entries if there are case-insensitive duplicates by appending _1, _2 etc. - // behavior consistent with the CSV reader - case_insensitive_map_t name_collision_count; - // deduplicate the child column names case-insensitively - for (auto &child_type : child_types) { - auto col_name = child_type.first; - // avoid duplicate column names - while (name_collision_count.find(col_name) != name_collision_count.end()) { - name_collision_count[col_name] += 1; - col_name = col_name + "_" + to_string(name_collision_count[col_name]); - } - child_type.first = col_name; - name_collision_count[col_name] = 0; - } - - D_ASSERT(!child_types.empty()); - unique_ptr result; - LogicalType result_type; - - bool is_repeated = repetition_type == FieldRepetitionType::REPEATED; - bool is_list = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::LIST; - bool is_map = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP; - bool is_map_kv = s_ele.__isset.converted_type && s_ele.converted_type == ConvertedType::MAP_KEY_VALUE; - if (!is_map_kv && this_idx > 0) { - // check if the parent node of this is a map - auto &p_ele = file_meta_data->schema[this_idx - 1]; - bool parent_is_map = p_ele.__isset.converted_type && p_ele.converted_type == ConvertedType::MAP; - bool parent_has_children = p_ele.__isset.num_children && p_ele.num_children == 1; - is_map_kv = parent_is_map && parent_has_children; - } - - if (is_map_kv) { - if (child_types.size() != 2) { - throw IOException("MAP_KEY_VALUE requires two children"); - } - if (!is_repeated) { - throw IOException("MAP_KEY_VALUE needs to be repeated"); - } - result_type = LogicalType::MAP(std::move(child_types[0].second), std::move(child_types[1].second)); - - auto struct_reader = - make_uniq(*this, ListType::GetChildType(result_type), s_ele, this_idx, - max_define - 1, max_repeat - 1, std::move(child_readers)); - return make_uniq(*this, result_type, s_ele, this_idx, max_define, max_repeat, - std::move(struct_reader)); - } - if (child_types.size() > 1 || (!is_list && !is_map && !is_repeated)) { - result_type = LogicalType::STRUCT(child_types); - result = make_uniq(*this, result_type, s_ele, this_idx, max_define, max_repeat, - std::move(child_readers)); - } else { - // if we have a struct with only a single type, pull up - result_type = child_types[0].second; - result = std::move(child_readers[0]); - } - if (is_repeated) { - result_type = LogicalType::LIST(result_type); - result = make_uniq(*this, result_type, s_ele, this_idx, max_define, max_repeat, - std::move(result)); - } - result->SetParentSchema(s_ele); - return result; - } else { // leaf node - if 
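A worked trace of the case-insensitive rename above, for three children named "a", "A", "a":

// "a" -> no collision, kept as "a"                    (counter["a"] = 0)
// "A" -> collides with "a" (case-insensitive), counter -> 1, renamed to "A_1"
// "a" -> collides again,                       counter -> 2, renamed to "a_2"
//
// The while loop re-checks each suffixed name, so if a column literally named
// "a_2" had appeared earlier, the rename would be pushed further to "a_2_1".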
(!s_ele.__isset.type) { - throw InvalidInputException( - "Node has neither num_children nor type set - this violates the Parquet spec (corrupted file)"); - } - if (s_ele.repetition_type == FieldRepetitionType::REPEATED) { - const auto derived_type = DeriveLogicalType(s_ele); - auto list_type = LogicalType::LIST(derived_type); - - auto element_reader = - ColumnReader::CreateReader(*this, derived_type, s_ele, next_file_idx++, max_define, max_repeat); - - return make_uniq(*this, list_type, s_ele, this_idx, max_define, max_repeat, - std::move(element_reader)); - } - // TODO check return value of derive type or should we only do this on read() - return ColumnReader::CreateReader(*this, DeriveLogicalType(s_ele), s_ele, next_file_idx++, max_define, - max_repeat); - } -} - -// TODO we don't need readers for columns we are not going to read anyway -unique_ptr ParquetReader::CreateReader(ClientContext &context) { - auto file_meta_data = GetFileMetadata(); - idx_t next_schema_idx = 0; - idx_t next_file_idx = 0; - - if (file_meta_data->schema.empty()) { - throw IOException("Parquet reader: no schema elements found"); - } - if (file_meta_data->schema[0].num_children == 0) { - throw IOException("Parquet reader: root schema element has no children"); - } - auto ret = CreateReaderRecursive(context, reader_data.column_indexes, 0, 0, 0, next_schema_idx, next_file_idx); - if (ret->Type().id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("Root element of Parquet file must be a struct"); - } - D_ASSERT(next_schema_idx == file_meta_data->schema.size() - 1); - D_ASSERT(file_meta_data->row_groups.empty() || next_file_idx == file_meta_data->row_groups[0].columns.size()); - - auto &root_struct_reader = ret->Cast(); - // add casts if required - for (auto &entry : reader_data.cast_map) { - auto column_idx = entry.first; - auto &expected_type = entry.second; - auto child_reader = std::move(root_struct_reader.child_readers[column_idx]); - auto cast_reader = make_uniq(std::move(child_reader), expected_type); - root_struct_reader.child_readers[column_idx] = std::move(cast_reader); - } - if (parquet_options.file_row_number) { - file_row_number_idx = root_struct_reader.child_readers.size(); - - generated_column_schema.push_back(SchemaElement()); - root_struct_reader.child_readers.push_back(make_uniq( - *this, LogicalType::BIGINT, generated_column_schema.back(), next_file_idx, 0, 0)); - } - - return ret; -} - -void ParquetReader::InitializeSchema(ClientContext &context) { - auto file_meta_data = GetFileMetadata(); - - if (file_meta_data->__isset.encryption_algorithm) { - if (file_meta_data->encryption_algorithm.__isset.AES_GCM_CTR_V1) { - throw InvalidInputException("File '%s' is encrypted with AES_GCM_CTR_V1, but only AES_GCM_V1 is supported", - file_name); - } - } - // check if we like this schema - if (file_meta_data->schema.size() < 2) { - throw FormatException("Need at least one non-root column in the file"); - } - root_reader = CreateReader(context); - auto &root_type = root_reader->Type(); - auto &child_types = StructType::GetChildTypes(root_type); - D_ASSERT(root_type.id() == LogicalTypeId::STRUCT); - for (auto &type_pair : child_types) { - names.push_back(type_pair.first); - return_types.push_back(type_pair.second); - } - - // Add generated constant column for row number - if (parquet_options.file_row_number) { - if (StringUtil::CIFind(names, "file_row_number") != DConstants::INVALID_INDEX) { - throw BinderException( - "Using file_row_number option on file with column named file_row_number is not 
supported"); - } - return_types.emplace_back(LogicalType::BIGINT); - names.emplace_back("file_row_number"); - } -} - -ParquetOptions::ParquetOptions(ClientContext &context) { - Value binary_as_string_val; - if (context.TryGetCurrentSetting("binary_as_string", binary_as_string_val)) { - binary_as_string = binary_as_string_val.GetValue(); - } -} - -ParquetColumnDefinition ParquetColumnDefinition::FromSchemaValue(ClientContext &context, const Value &column_value) { - ParquetColumnDefinition result; - result.field_id = IntegerValue::Get(StructValue::GetChildren(column_value)[0]); - - const auto &column_def = StructValue::GetChildren(column_value)[1]; - D_ASSERT(column_def.type().id() == LogicalTypeId::STRUCT); - - const auto children = StructValue::GetChildren(column_def); - result.name = StringValue::Get(children[0]); - result.type = TransformStringToLogicalType(StringValue::Get(children[1])); - string error_message; - if (!children[2].TryCastAs(context, result.type, result.default_value, &error_message)) { - throw BinderException("Unable to cast Parquet schema default_value \"%s\" to %s", children[2].ToString(), - result.type.ToString()); - } - - return result; -} - -ParquetReader::ParquetReader(ClientContext &context_p, string file_name_p, ParquetOptions parquet_options_p, - shared_ptr metadata_p) - : fs(FileSystem::GetFileSystem(context_p)), allocator(BufferAllocator::Get(context_p)), - parquet_options(std::move(parquet_options_p)) { - file_name = std::move(file_name_p); - file_handle = fs.OpenFile(file_name, FileFlags::FILE_FLAGS_READ); - if (!file_handle->CanSeek()) { - throw NotImplementedException( - "Reading parquet files from a FIFO stream is not supported and cannot be efficiently supported since " - "metadata is located at the end of the file. 
Write the stream to disk first and read from there instead."); - } - - // set pointer to factory method for AES state - auto &config = DBConfig::GetConfig(context_p); - if (config.encryption_util && parquet_options.debug_use_openssl) { - encryption_util = config.encryption_util; - } else { - encryption_util = make_shared_ptr(); - } - - // If metadata cached is disabled - // or if this file has cached metadata - // or if the cached version already expired - if (!metadata_p) { - Value metadata_cache = false; - context_p.TryGetCurrentSetting("parquet_metadata_cache", metadata_cache); - if (!metadata_cache.GetValue()) { - metadata = - LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, *encryption_util); - } else { - auto last_modify_time = fs.GetLastModifiedTime(*file_handle); - metadata = ObjectCache::GetObjectCache(context_p).Get(file_name); - if (!metadata || (last_modify_time + 10 >= metadata->read_time)) { - metadata = LoadMetadata(context_p, allocator, *file_handle, parquet_options.encryption_config, - *encryption_util); - ObjectCache::GetObjectCache(context_p).Put(file_name, metadata); - } - } - } else { - metadata = std::move(metadata_p); - } - InitializeSchema(context_p); -} - -ParquetUnionData::~ParquetUnionData() { -} - -ParquetReader::ParquetReader(ClientContext &context_p, ParquetOptions parquet_options_p, - shared_ptr metadata_p) - : fs(FileSystem::GetFileSystem(context_p)), allocator(BufferAllocator::Get(context_p)), - metadata(std::move(metadata_p)), parquet_options(std::move(parquet_options_p)) { - InitializeSchema(context_p); -} - -ParquetReader::~ParquetReader() { -} - -const FileMetaData *ParquetReader::GetFileMetadata() { - D_ASSERT(metadata); - D_ASSERT(metadata->metadata); - return metadata->metadata.get(); -} - -unique_ptr ParquetReader::ReadStatistics(const string &name) { - idx_t file_col_idx; - for (file_col_idx = 0; file_col_idx < names.size(); file_col_idx++) { - if (names[file_col_idx] == name) { - break; - } - } - if (file_col_idx == names.size()) { - return nullptr; - } - - unique_ptr column_stats; - auto file_meta_data = GetFileMetadata(); - auto &column_reader = root_reader->Cast().GetChildReader(file_col_idx); - - for (idx_t row_group_idx = 0; row_group_idx < file_meta_data->row_groups.size(); row_group_idx++) { - auto &row_group = file_meta_data->row_groups[row_group_idx]; - auto chunk_stats = column_reader.Stats(row_group_idx, row_group.columns); - if (!chunk_stats) { - return nullptr; - } - if (!column_stats) { - column_stats = std::move(chunk_stats); - } else { - column_stats->Merge(*chunk_stats); - } - } - return column_stats; -} - -unique_ptr ParquetReader::ReadStatistics(ClientContext &context, ParquetOptions parquet_options, - shared_ptr metadata, - const string &name) { - ParquetReader reader(context, std::move(parquet_options), std::move(metadata)); - return reader.ReadStatistics(name); -} - -uint32_t ParquetReader::Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot) { - if (parquet_options.encryption_config) { - return ParquetCrypto::Read(object, iprot, parquet_options.encryption_config->GetFooterKey(), *encryption_util); - } else { - return object.read(&iprot); - } -} - -uint32_t ParquetReader::ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer, - const uint32_t buffer_size) { - if (parquet_options.encryption_config) { - return ParquetCrypto::ReadData(iprot, buffer, buffer_size, parquet_options.encryption_config->GetFooterKey(), - *encryption_util); - } else { - return 
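The caching branch above only runs when the corresponding setting is on; a hypothetical session (the file name is a placeholder):

// assuming a duckdb::Connection named con
con.Query("SET parquet_metadata_cache = true");
con.Query("SELECT count(*) FROM 'example.parquet'"); // parses the footer and caches it
con.Query("SELECT count(*) FROM 'example.parquet'"); // reuses the cached metadata,
                                                     // unless the file was modified recently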
iprot.getTransport()->read(buffer, buffer_size); - } -} - -const ParquetRowGroup &ParquetReader::GetGroup(ParquetReaderScanState &state) { - auto file_meta_data = GetFileMetadata(); - D_ASSERT(state.current_group >= 0 && (idx_t)state.current_group < state.group_idx_list.size()); - D_ASSERT(state.group_idx_list[state.current_group] < file_meta_data->row_groups.size()); - return file_meta_data->row_groups[state.group_idx_list[state.current_group]]; -} - -uint64_t ParquetReader::GetGroupCompressedSize(ParquetReaderScanState &state) { - auto &group = GetGroup(state); - auto total_compressed_size = group.total_compressed_size; - - idx_t calc_compressed_size = 0; - - // If the global total_compressed_size is not set, we can still calculate it - if (group.total_compressed_size == 0) { - for (auto &column_chunk : group.columns) { - calc_compressed_size += column_chunk.meta_data.total_compressed_size; - } - } - - if (total_compressed_size != 0 && calc_compressed_size != 0 && - (idx_t)total_compressed_size != calc_compressed_size) { - throw InvalidInputException("mismatch between calculated compressed size and reported compressed size"); - } - - return total_compressed_size ? total_compressed_size : calc_compressed_size; -} - -uint64_t ParquetReader::GetGroupSpan(ParquetReaderScanState &state) { - auto &group = GetGroup(state); - idx_t min_offset = NumericLimits::Maximum(); - idx_t max_offset = NumericLimits::Minimum(); - - for (auto &column_chunk : group.columns) { - - // Set the min offset - idx_t current_min_offset = NumericLimits::Maximum(); - if (column_chunk.meta_data.__isset.dictionary_page_offset) { - current_min_offset = MinValue(current_min_offset, column_chunk.meta_data.dictionary_page_offset); - } - if (column_chunk.meta_data.__isset.index_page_offset) { - current_min_offset = MinValue(current_min_offset, column_chunk.meta_data.index_page_offset); - } - current_min_offset = MinValue(current_min_offset, column_chunk.meta_data.data_page_offset); - min_offset = MinValue(current_min_offset, min_offset); - max_offset = MaxValue(max_offset, column_chunk.meta_data.total_compressed_size + current_min_offset); - } - - return max_offset - min_offset; -} - -idx_t ParquetReader::GetGroupOffset(ParquetReaderScanState &state) { - auto &group = GetGroup(state); - idx_t min_offset = NumericLimits::Maximum(); - - for (auto &column_chunk : group.columns) { - if (column_chunk.meta_data.__isset.dictionary_page_offset) { - min_offset = MinValue(min_offset, column_chunk.meta_data.dictionary_page_offset); - } - if (column_chunk.meta_data.__isset.index_page_offset) { - min_offset = MinValue(min_offset, column_chunk.meta_data.index_page_offset); - } - min_offset = MinValue(min_offset, column_chunk.meta_data.data_page_offset); - } - - return min_offset; -} - -static FilterPropagateResult CheckParquetStringFilter(BaseStatistics &stats, const Statistics &pq_col_stats, - TableFilter &filter) { - if (filter.filter_type == TableFilterType::CONSTANT_COMPARISON) { - auto &constant_filter = filter.Cast(); - auto &min_value = pq_col_stats.min_value; - auto &max_value = pq_col_stats.max_value; - return StringStats::CheckZonemap(const_data_ptr_cast(min_value.c_str()), min_value.size(), - const_data_ptr_cast(max_value.c_str()), max_value.size(), - constant_filter.comparison_type, StringValue::Get(constant_filter.constant)); - } else { - return filter.CheckStatistics(stats); - } -} - -void ParquetReader::PrepareRowGroupBuffer(ParquetReaderScanState &state, idx_t col_idx) { - auto &group = GetGroup(state); - auto column_id = 
reader_data.column_ids[col_idx]; - auto &column_reader = state.root_reader->Cast().GetChildReader(column_id); - - // TODO move this to columnreader too - if (reader_data.filters) { - auto stats = column_reader.Stats(state.group_idx_list[state.current_group], group.columns); - // filters contain output chunk index, not file col idx! - auto global_id = reader_data.column_mapping[col_idx]; - auto filter_entry = reader_data.filters->filters.find(global_id); - - if (stats && filter_entry != reader_data.filters->filters.end()) { - auto &filter = *filter_entry->second; - - FilterPropagateResult prune_result; - // TODO we might not have stats but STILL a bloom filter so move this up - // check the bloom filter if present - if (!column_reader.Type().IsNested() && - ParquetStatisticsUtils::BloomFilterSupported(column_reader.Type().id()) && - ParquetStatisticsUtils::BloomFilterExcludes(filter, group.columns[column_reader.FileIdx()].meta_data, - *state.thrift_file_proto, allocator)) { - prune_result = FilterPropagateResult::FILTER_ALWAYS_FALSE; - } else if (column_reader.Type().id() == LogicalTypeId::VARCHAR && - group.columns[column_reader.FileIdx()].meta_data.statistics.__isset.min_value && - group.columns[column_reader.FileIdx()].meta_data.statistics.__isset.max_value) { - - // our StringStats only store the first 8 bytes of strings (even if Parquet has longer string stats) - // however, when reading remote Parquet files, skipping row groups is really important - // here, we implement a special case to check the full length for string filters - if (filter.filter_type == TableFilterType::CONJUNCTION_AND) { - const auto &and_filter = filter.Cast(); - auto and_result = FilterPropagateResult::FILTER_ALWAYS_TRUE; - for (auto &child_filter : and_filter.child_filters) { - auto child_prune_result = CheckParquetStringFilter( - *stats, group.columns[column_reader.FileIdx()].meta_data.statistics, *child_filter); - if (child_prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { - and_result = FilterPropagateResult::FILTER_ALWAYS_FALSE; - break; - } else if (child_prune_result != and_result) { - and_result = FilterPropagateResult::NO_PRUNING_POSSIBLE; - } - } - prune_result = and_result; - } else { - prune_result = CheckParquetStringFilter( - *stats, group.columns[column_reader.FileIdx()].meta_data.statistics, filter); - } - } else { - prune_result = filter.CheckStatistics(*stats); - } - - if (prune_result == FilterPropagateResult::FILTER_ALWAYS_FALSE) { - // this effectively will skip this chunk - state.group_offset = group.num_rows; - return; - } - } - } - - state.root_reader->InitializeRead(state.group_idx_list[state.current_group], group.columns, - *state.thrift_file_proto); -} - -idx_t ParquetReader::NumRows() { - return GetFileMetadata()->num_rows; -} - -idx_t ParquetReader::NumRowGroups() { - return GetFileMetadata()->row_groups.size(); -} - -void ParquetReader::InitializeScan(ClientContext &context, ParquetReaderScanState &state, - vector groups_to_read) { - state.current_group = -1; - state.finished = false; - state.group_offset = 0; - state.group_idx_list = std::move(groups_to_read); - state.sel.Initialize(STANDARD_VECTOR_SIZE); - if (!state.file_handle || state.file_handle->path != file_handle->path) { - auto flags = FileFlags::FILE_FLAGS_READ; - - Value disable_prefetch = false; - Value prefetch_all_files = false; - context.TryGetCurrentSetting("disable_parquet_prefetching", disable_prefetch); - context.TryGetCurrentSetting("prefetch_all_parquet_files", prefetch_all_files); - bool 
should_prefetch = !file_handle->OnDiskFile() || prefetch_all_files.GetValue(); - bool can_prefetch = file_handle->CanSeek() && !disable_prefetch.GetValue(); - - if (should_prefetch && can_prefetch) { - state.prefetch_mode = true; - flags |= FileFlags::FILE_FLAGS_DIRECT_IO; - } else { - state.prefetch_mode = false; - } - - state.file_handle = fs.OpenFile(file_handle->path, flags); - } - - state.thrift_file_proto = CreateThriftFileProtocol(allocator, *state.file_handle, state.prefetch_mode); - state.root_reader = CreateReader(context); - state.define_buf.resize(allocator, STANDARD_VECTOR_SIZE); - state.repeat_buf.resize(allocator, STANDARD_VECTOR_SIZE); -} - -void FilterIsNull(Vector &v, parquet_filter_t &filter_mask, idx_t count) { - if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto &mask = ConstantVector::Validity(v); - if (mask.RowIsValid(0)) { - filter_mask.reset(); - } - return; - } - - UnifiedVectorFormat unified; - v.ToUnifiedFormat(count, unified); - - if (unified.validity.AllValid()) { - filter_mask.reset(); - } else { - for (idx_t i = 0; i < count; i++) { - if (filter_mask.test(i)) { - filter_mask.set(i, !unified.validity.RowIsValid(unified.sel->get_index(i))); - } - } - } -} - -void FilterIsNotNull(Vector &v, parquet_filter_t &filter_mask, idx_t count) { - if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto &mask = ConstantVector::Validity(v); - if (!mask.RowIsValid(0)) { - filter_mask.reset(); - } - return; - } - - UnifiedVectorFormat unified; - v.ToUnifiedFormat(count, unified); - - if (!unified.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - if (filter_mask.test(i)) { - filter_mask.set(i, unified.validity.RowIsValid(unified.sel->get_index(i))); - } - } - } -} - -template -void TemplatedFilterOperation(Vector &v, T constant, parquet_filter_t &filter_mask, idx_t count) { - if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto v_ptr = ConstantVector::GetData(v); - auto &mask = ConstantVector::Validity(v); - - if (mask.RowIsValid(0)) { - if (!OP::Operation(v_ptr[0], constant)) { - filter_mask.reset(); - } - } else { - filter_mask.reset(); - } - return; - } - - UnifiedVectorFormat unified; - v.ToUnifiedFormat(count, unified); - auto data_ptr = UnifiedVectorFormat::GetData(unified); - - if (!unified.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - if (filter_mask.test(i)) { - auto idx = unified.sel->get_index(i); - bool is_valid = unified.validity.RowIsValid(idx); - if (is_valid) { - filter_mask.set(i, OP::Operation(data_ptr[idx], constant)); - } else { - filter_mask.set(i, false); - } - } - } - } else { - for (idx_t i = 0; i < count; i++) { - if (filter_mask.test(i)) { - filter_mask.set(i, OP::Operation(data_ptr[unified.sel->get_index(i)], constant)); - } - } - } -} - -template -void TemplatedFilterOperation(Vector &v, const Value &constant, parquet_filter_t &filter_mask, idx_t count) { - TemplatedFilterOperation(v, constant.template GetValueUnsafe(), filter_mask, count); -} - -template -static void FilterOperationSwitch(Vector &v, Value &constant, parquet_filter_t &filter_mask, idx_t count) { - if (filter_mask.none() || count == 0) { - return; - } - switch (v.GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::UINT8: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::UINT16: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::UINT32: - 
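TemplatedFilterOperation above is the heart of the row-level filtering: surviving rows are intersected with a comparison against a constant, and NULLs never pass. A simplified model, assuming a 2048-entry vector (DuckDB's STANDARD_VECTOR_SIZE) and with std::bitset standing in for parquet_filter_t:

```cpp
#include <bitset>
#include <cstddef>
#include <functional>
#include <optional>
#include <vector>

constexpr std::size_t VECTOR_SIZE = 2048; // mirrors STANDARD_VECTOR_SIZE
using FilterMask = std::bitset<VECTOR_SIZE>;

// Rows already filtered out stay out; surviving rows are ANDed with the
// comparison result, and NULL values never satisfy a comparison.
template <class T, class OP>
void FilterConstant(const std::vector<std::optional<T>> &column, T constant, FilterMask &mask) {
	for (std::size_t i = 0; i < column.size(); i++) {
		if (!mask.test(i)) {
			continue; // eliminated by an earlier filter
		}
		const auto &val = column[i];
		mask.set(i, val.has_value() && OP()(*val, constant));
	}
}

// usage: FilterConstant<int32_t, std::less<int32_t>>(col, 42, mask);
```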
TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::UINT64: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::INT8: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::INT16: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::INT32: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::INT64: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::INT128: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::FLOAT: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::DOUBLE: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - case PhysicalType::VARCHAR: - TemplatedFilterOperation(v, constant, filter_mask, count); - break; - default: - throw NotImplementedException("Unsupported type for filter %s", v.ToString()); - } -} - -static void ApplyFilter(Vector &v, TableFilter &filter, parquet_filter_t &filter_mask, idx_t count) { - switch (filter.filter_type) { - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction = filter.Cast(); - for (auto &child_filter : conjunction.child_filters) { - ApplyFilter(v, *child_filter, filter_mask, count); - } - break; - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction = filter.Cast(); - parquet_filter_t or_mask; - for (auto &child_filter : conjunction.child_filters) { - parquet_filter_t child_mask = filter_mask; - ApplyFilter(v, *child_filter, child_mask, count); - or_mask |= child_mask; - } - filter_mask &= or_mask; - break; - } - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = filter.Cast(); - switch (constant_filter.comparison_type) { - case ExpressionType::COMPARE_EQUAL: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - case ExpressionType::COMPARE_LESSTHAN: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - case ExpressionType::COMPARE_GREATERTHAN: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - case ExpressionType::COMPARE_NOTEQUAL: - FilterOperationSwitch(v, constant_filter.constant, filter_mask, count); - break; - default: - throw InternalException("Unsupported comparison for Parquet filter pushdown"); - } - break; - } - case TableFilterType::IS_NOT_NULL: - FilterIsNotNull(v, filter_mask, count); - break; - case TableFilterType::IS_NULL: - FilterIsNull(v, filter_mask, count); - break; - case TableFilterType::STRUCT_EXTRACT: { - auto &struct_filter = filter.Cast(); - auto &child = StructVector::GetEntries(v)[struct_filter.child_idx]; - ApplyFilter(*child, *struct_filter.child_filter, filter_mask, count); - } - case TableFilterType::OPTIONAL_FILTER: { - // we don't execute zone map filters here - we only consider them for zone map pruning - // do nothing to the mask. 
- break; - } - default: - D_ASSERT(0); - break; - } -} - -void ParquetReader::Scan(ParquetReaderScanState &state, DataChunk &result) { - while (ScanInternal(state, result)) { - if (result.size() > 0) { - break; - } - result.Reset(); - } -} - -bool ParquetReader::ScanInternal(ParquetReaderScanState &state, DataChunk &result) { - if (state.finished) { - return false; - } - - // see if we have to switch to the next row group in the parquet file - if (state.current_group < 0 || (int64_t)state.group_offset >= GetGroup(state).num_rows) { - state.current_group++; - state.group_offset = 0; - - auto &trans = reinterpret_cast(*state.thrift_file_proto->getTransport()); - trans.ClearPrefetch(); - state.current_group_prefetched = false; - - if ((idx_t)state.current_group == state.group_idx_list.size()) { - state.finished = true; - return false; - } - - uint64_t to_scan_compressed_bytes = 0; - for (idx_t col_idx = 0; col_idx < reader_data.column_ids.size(); col_idx++) { - PrepareRowGroupBuffer(state, col_idx); - - auto file_col_idx = reader_data.column_ids[col_idx]; - - auto &root_reader = state.root_reader->Cast(); - to_scan_compressed_bytes += root_reader.GetChildReader(file_col_idx).TotalCompressedSize(); - } - - auto &group = GetGroup(state); - if (state.prefetch_mode && state.group_offset != (idx_t)group.num_rows) { - uint64_t total_row_group_span = GetGroupSpan(state); - - double scan_percentage = (double)(to_scan_compressed_bytes) / static_cast(total_row_group_span); - - if (to_scan_compressed_bytes > total_row_group_span) { - throw IOException( - "The parquet file '%s' seems to have incorrectly set page offsets. This interferes with DuckDB's " - "prefetching optimization. DuckDB may still be able to scan this file by manually disabling the " - "prefetching mechanism using: 'SET disable_parquet_prefetching=true'.", - file_name); - } - - if (!reader_data.filters && - scan_percentage > ParquetReaderPrefetchConfig::WHOLE_GROUP_PREFETCH_MINIMUM_SCAN) { - // Prefetch the whole row group - if (!state.current_group_prefetched) { - auto total_compressed_size = GetGroupCompressedSize(state); - if (total_compressed_size > 0) { - trans.Prefetch(GetGroupOffset(state), total_row_group_span); - } - state.current_group_prefetched = true; - } - } else { - // lazy fetching is when all tuples in a column can be skipped. With lazy fetching the buffer is only - // fetched on the first read to that buffer. 
- bool lazy_fetch = reader_data.filters; - - // Prefetch column-wise - for (idx_t col_idx = 0; col_idx < reader_data.column_ids.size(); col_idx++) { - auto file_col_idx = reader_data.column_ids[col_idx]; - auto &root_reader = state.root_reader->Cast(); - - bool has_filter = false; - if (reader_data.filters) { - auto entry = reader_data.filters->filters.find(reader_data.column_mapping[col_idx]); - has_filter = entry != reader_data.filters->filters.end(); - } - root_reader.GetChildReader(file_col_idx).RegisterPrefetch(trans, !(lazy_fetch && !has_filter)); - } - - trans.FinalizeRegistration(); - - if (!lazy_fetch) { - trans.PrefetchRegistered(); - } - } - } - return true; - } - - auto this_output_chunk_rows = MinValue(STANDARD_VECTOR_SIZE, GetGroup(state).num_rows - state.group_offset); - result.SetCardinality(this_output_chunk_rows); - - if (this_output_chunk_rows == 0) { - state.finished = true; - return false; // end of last group, we are done - } - - // we evaluate simple table filters directly in this scan so we can skip decoding column data that's never going to - // be relevant - parquet_filter_t filter_mask; - filter_mask.set(); - - // mask out unused part of bitset - for (idx_t i = this_output_chunk_rows; i < STANDARD_VECTOR_SIZE; i++) { - filter_mask.set(i, false); - } - - state.define_buf.zero(); - state.repeat_buf.zero(); - - auto define_ptr = (uint8_t *)state.define_buf.ptr; - auto repeat_ptr = (uint8_t *)state.repeat_buf.ptr; - - auto &root_reader = state.root_reader->Cast(); - - if (reader_data.filters) { - vector need_to_read(reader_data.column_ids.size(), true); - - // first load the columns that are used in filters - for (auto &filter_col : reader_data.filters->filters) { - if (filter_mask.none()) { - // if no rows are left we can stop checking filters - break; - } - auto filter_entry = reader_data.filter_map[filter_col.first]; - if (filter_entry.is_constant) { - // this is a constant vector, look for the constant - auto &constant = reader_data.constant_map[filter_entry.index].value; - Vector constant_vector(constant); - ApplyFilter(constant_vector, *filter_col.second, filter_mask, this_output_chunk_rows); - } else { - auto id = filter_entry.index; - auto file_col_idx = reader_data.column_ids[id]; - auto result_idx = reader_data.column_mapping[id]; - - auto &result_vector = result.data[result_idx]; - auto &child_reader = root_reader.GetChildReader(file_col_idx); - child_reader.Read(result.size(), filter_mask, define_ptr, repeat_ptr, result_vector); - need_to_read[id] = false; - - ApplyFilter(result_vector, *filter_col.second, filter_mask, this_output_chunk_rows); - } - } - - // we still may have to read some cols - for (idx_t col_idx = 0; col_idx < reader_data.column_ids.size(); col_idx++) { - if (!need_to_read[col_idx]) { - continue; - } - auto file_col_idx = reader_data.column_ids[col_idx]; - if (filter_mask.none()) { - root_reader.GetChildReader(file_col_idx).Skip(result.size()); - continue; - } - auto &result_vector = result.data[reader_data.column_mapping[col_idx]]; - auto &child_reader = root_reader.GetChildReader(file_col_idx); - child_reader.Read(result.size(), filter_mask, define_ptr, repeat_ptr, result_vector); - } - - idx_t sel_size = 0; - for (idx_t i = 0; i < this_output_chunk_rows; i++) { - if (filter_mask.test(i)) { - state.sel.set_index(sel_size++, i); - } - } - - result.Slice(state.sel, sel_size); - } else { - for (idx_t col_idx = 0; col_idx < reader_data.column_ids.size(); col_idx++) { - auto file_col_idx = reader_data.column_ids[col_idx]; - auto 
&result_vector = result.data[reader_data.column_mapping[col_idx]]; - auto &child_reader = root_reader.GetChildReader(file_col_idx); - auto rows_read = child_reader.Read(result.size(), filter_mask, define_ptr, repeat_ptr, result_vector); - if (rows_read != result.size()) { - throw InvalidInputException("Mismatch in parquet read for column %llu, expected %llu rows, got %llu", - file_col_idx, result.size(), rows_read); - } - } - } - - state.group_offset += this_output_chunk_rows; - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_statistics.cpp b/src/duckdb/extension/parquet/parquet_statistics.cpp deleted file mode 100644 index 7e4ceac40..000000000 --- a/src/duckdb/extension/parquet/parquet_statistics.cpp +++ /dev/null @@ -1,610 +0,0 @@ -#include "parquet_statistics.hpp" - -#include "duckdb.hpp" -#include "parquet_decimal_utils.hpp" -#include "parquet_timestamp.hpp" -#include "string_column_reader.hpp" -#include "struct_column_reader.hpp" -#include "zstd/common/xxhash.hpp" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/storage/statistics/struct_stats.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#endif - -namespace duckdb { - -using duckdb_parquet::ConvertedType; -using duckdb_parquet::Type; - -static unique_ptr CreateNumericStats(const LogicalType &type, - const duckdb_parquet::SchemaElement &schema_ele, - const duckdb_parquet::Statistics &parquet_stats) { - auto stats = NumericStats::CreateUnknown(type); - - // for reasons unknown to science, Parquet defines *both* `min` and `min_value` as well as `max` and - // `max_value`. All are optional. such elegance. - Value min; - Value max; - if (parquet_stats.__isset.min_value) { - min = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.min_value); - } else if (parquet_stats.__isset.min) { - min = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.min); - } else { - min = Value(type); - } - if (parquet_stats.__isset.max_value) { - max = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.max_value); - } else if (parquet_stats.__isset.max) { - max = ParquetStatisticsUtils::ConvertValue(type, schema_ele, parquet_stats.max); - } else { - max = Value(type); - } - NumericStats::SetMin(stats, min); - NumericStats::SetMax(stats, max); - return stats.ToUnique(); -} - -Value ParquetStatisticsUtils::ConvertValue(const LogicalType &type, const duckdb_parquet::SchemaElement &schema_ele, - const std::string &stats) { - Value result; - string error; - auto stats_val = ConvertValueInternal(type, schema_ele, stats); - if (!stats_val.DefaultTryCastAs(type, result, &error)) { - return Value(type); - } - return result; -} -Value ParquetStatisticsUtils::ConvertValueInternal(const LogicalType &type, - const duckdb_parquet::SchemaElement &schema_ele, - const std::string &stats) { - auto stats_data = const_data_ptr_cast(stats.c_str()); - switch (type.id()) { - case LogicalTypeId::BOOLEAN: { - if (stats.size() != sizeof(bool)) { - throw InvalidInputException("Incorrect stats size for type BOOLEAN"); - } - return Value::BOOLEAN(Load(stats_data)); - } - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - if (stats.size() != sizeof(uint32_t)) { - throw InvalidInputException("Incorrect stats size for type UINTEGER"); - } - return Value::UINTEGER(Load(stats_data)); - case LogicalTypeId::UBIGINT: - if 
(stats.size() != sizeof(uint64_t)) { - throw InvalidInputException("Incorrect stats size for type UBIGINT"); - } - return Value::UBIGINT(Load(stats_data)); - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - if (stats.size() != sizeof(int32_t)) { - throw InvalidInputException("Incorrect stats size for type INTEGER"); - } - return Value::INTEGER(Load(stats_data)); - case LogicalTypeId::BIGINT: - if (stats.size() != sizeof(int64_t)) { - throw InvalidInputException("Incorrect stats size for type BIGINT"); - } - return Value::BIGINT(Load(stats_data)); - case LogicalTypeId::FLOAT: { - if (stats.size() != sizeof(float)) { - throw InvalidInputException("Incorrect stats size for type FLOAT"); - } - auto val = Load(stats_data); - if (!Value::FloatIsFinite(val)) { - return Value(); - } - return Value::FLOAT(val); - } - case LogicalTypeId::DOUBLE: { - switch (schema_ele.type) { - case Type::FIXED_LEN_BYTE_ARRAY: - case Type::BYTE_ARRAY: - // decimals cast to double - return Value::DOUBLE(ParquetDecimalUtils::ReadDecimalValue(stats_data, stats.size(), schema_ele)); - default: - break; - } - if (stats.size() != sizeof(double)) { - throw InvalidInputException("Incorrect stats size for type DOUBLE"); - } - auto val = Load(stats_data); - if (!Value::DoubleIsFinite(val)) { - return Value(); - } - return Value::DOUBLE(val); - } - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (schema_ele.type) { - case Type::INT32: { - if (stats.size() != sizeof(int32_t)) { - throw InvalidInputException("Incorrect stats size for type %s", type.ToString()); - } - return Value::DECIMAL(Load(stats_data), width, scale); - } - case Type::INT64: { - if (stats.size() != sizeof(int64_t)) { - throw InvalidInputException("Incorrect stats size for type %s", type.ToString()); - } - return Value::DECIMAL(Load(stats_data), width, scale); - } - case Type::BYTE_ARRAY: - case Type::FIXED_LEN_BYTE_ARRAY: - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL( - ParquetDecimalUtils::ReadDecimalValue(stats_data, stats.size(), schema_ele), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL( - ParquetDecimalUtils::ReadDecimalValue(stats_data, stats.size(), schema_ele), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL( - ParquetDecimalUtils::ReadDecimalValue(stats_data, stats.size(), schema_ele), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL( - ParquetDecimalUtils::ReadDecimalValue(stats_data, stats.size(), schema_ele), width, - scale); - default: - throw InvalidInputException("Unsupported internal type for decimal"); - } - default: - throw InternalException("Unsupported internal type for decimal?.."); - } - } - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - if (type.id() == LogicalTypeId::BLOB || !Value::StringIsValid(stats)) { - return Value(Blob::ToString(string_t(stats))); - } - return Value(stats); - case LogicalTypeId::DATE: - if (stats.size() != sizeof(int32_t)) { - throw InvalidInputException("Incorrect stats size for type DATE"); - } - return Value::DATE(date_t(Load(stats_data))); - case LogicalTypeId::TIME: { - int64_t val; - if (stats.size() == sizeof(int32_t)) { - val = Load(stats_data); - } else if (stats.size() == sizeof(int64_t)) { - val = Load(stats_data); - } else { - throw InvalidInputException("Incorrect stats size for type TIME"); - } - if (schema_ele.__isset.logicalType && 
schema_ele.logicalType.__isset.TIME) { - // logical type - if (schema_ele.logicalType.TIME.unit.__isset.MILLIS) { - return Value::TIME(Time::FromTimeMs(val)); - } else if (schema_ele.logicalType.TIME.unit.__isset.NANOS) { - return Value::TIME(Time::FromTimeNs(val)); - } else if (schema_ele.logicalType.TIME.unit.__isset.MICROS) { - return Value::TIME(dtime_t(val)); - } else { - throw InternalException("Time logicalType is set but unit is not defined"); - } - } - if (schema_ele.converted_type == duckdb_parquet::ConvertedType::TIME_MILLIS) { - return Value::TIME(Time::FromTimeMs(val)); - } else { - return Value::TIME(dtime_t(val)); - } - } - case LogicalTypeId::TIME_TZ: { - int64_t val; - if (stats.size() == sizeof(int32_t)) { - val = Load(stats_data); - } else if (stats.size() == sizeof(int64_t)) { - val = Load(stats_data); - } else { - throw InvalidInputException("Incorrect stats size for type TIMETZ"); - } - if (schema_ele.__isset.logicalType && schema_ele.logicalType.__isset.TIME) { - // logical type - if (schema_ele.logicalType.TIME.unit.__isset.MILLIS) { - return Value::TIMETZ(ParquetIntToTimeMsTZ(NumericCast(val))); - } else if (schema_ele.logicalType.TIME.unit.__isset.MICROS) { - return Value::TIMETZ(ParquetIntToTimeTZ(val)); - } else if (schema_ele.logicalType.TIME.unit.__isset.NANOS) { - return Value::TIMETZ(ParquetIntToTimeNsTZ(val)); - } else { - throw InternalException("Time With Time Zone logicalType is set but unit is not defined"); - } - } - return Value::TIMETZ(ParquetIntToTimeTZ(val)); - } - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: { - timestamp_t timestamp_value; - if (schema_ele.type == Type::INT96) { - if (stats.size() != sizeof(Int96)) { - throw InvalidInputException("Incorrect stats size for type TIMESTAMP"); - } - timestamp_value = ImpalaTimestampToTimestamp(Load(stats_data)); - } else { - D_ASSERT(schema_ele.type == Type::INT64); - if (stats.size() != sizeof(int64_t)) { - throw InvalidInputException("Incorrect stats size for type TIMESTAMP"); - } - auto val = Load(stats_data); - if (schema_ele.__isset.logicalType && schema_ele.logicalType.__isset.TIMESTAMP) { - // logical type - if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MILLIS) { - timestamp_value = Timestamp::FromEpochMs(val); - } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.NANOS) { - timestamp_value = Timestamp::FromEpochNanoSeconds(val); - } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MICROS) { - timestamp_value = timestamp_t(val); - } else { - throw InternalException("Timestamp logicalType is set but unit is not defined"); - } - } else if (schema_ele.converted_type == duckdb_parquet::ConvertedType::TIMESTAMP_MILLIS) { - timestamp_value = Timestamp::FromEpochMs(val); - } else { - timestamp_value = timestamp_t(val); - } - } - if (type.id() == LogicalTypeId::TIMESTAMP_TZ) { - return Value::TIMESTAMPTZ(timestamp_tz_t(timestamp_value)); - } - return Value::TIMESTAMP(timestamp_value); - } - case LogicalTypeId::TIMESTAMP_NS: { - timestamp_ns_t timestamp_value; - if (schema_ele.type == Type::INT96) { - if (stats.size() != sizeof(Int96)) { - throw InvalidInputException("Incorrect stats size for type TIMESTAMP_NS"); - } - timestamp_value = ImpalaTimestampToTimestampNS(Load(stats_data)); - } else { - D_ASSERT(schema_ele.type == Type::INT64); - if (stats.size() != sizeof(int64_t)) { - throw InvalidInputException("Incorrect stats size for type TIMESTAMP_NS"); - } - auto val = Load(stats_data); - if (schema_ele.__isset.logicalType && 
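The unit dispatch above normalizes TIME values to DuckDB's native microsecond resolution, based on which unit flag the logical-type annotation sets. Condensed into one hypothetical helper:

```cpp
#include <cstdint>

enum class TimeUnit { MILLIS, MICROS, NANOS };

int64_t ToMicros(int64_t raw, TimeUnit unit) {
	switch (unit) {
	case TimeUnit::MILLIS:
		return raw * 1000; // the conversion Time::FromTimeMs applies
	case TimeUnit::MICROS:
		return raw; // already DuckDB's native resolution
	case TimeUnit::NANOS:
	default:
		return raw / 1000; // Time::FromTimeNs truncates sub-microsecond digits
	}
}
```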
schema_ele.logicalType.__isset.TIMESTAMP) { - // logical type - if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MILLIS) { - timestamp_value = ParquetTimestampMsToTimestampNs(val); - } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.NANOS) { - timestamp_value = ParquetTimestampNsToTimestampNs(val); - } else if (schema_ele.logicalType.TIMESTAMP.unit.__isset.MICROS) { - timestamp_value = ParquetTimestampUsToTimestampNs(val); - } else { - throw InternalException("Timestamp (NS) logicalType is set but unit is unknown"); - } - } else if (schema_ele.converted_type == duckdb_parquet::ConvertedType::TIMESTAMP_MILLIS) { - timestamp_value = ParquetTimestampMsToTimestampNs(val); - } else { - timestamp_value = ParquetTimestampUsToTimestampNs(val); - } - } - return Value::TIMESTAMPNS(timestamp_value); - } - default: - throw InternalException("Unsupported type for stats %s", type.ToString()); - } -} - -unique_ptr ParquetStatisticsUtils::TransformColumnStatistics(const ColumnReader &reader, - const vector &columns) { - - // Not supported types - if (reader.Type().id() == LogicalTypeId::ARRAY || reader.Type().id() == LogicalTypeId::MAP || - reader.Type().id() == LogicalTypeId::LIST) { - return nullptr; - } - - unique_ptr row_group_stats; - - // Structs are handled differently (they dont have stats) - if (reader.Type().id() == LogicalTypeId::STRUCT) { - auto struct_stats = StructStats::CreateUnknown(reader.Type()); - auto &struct_reader = reader.Cast(); - // Recurse into child readers - for (idx_t i = 0; i < struct_reader.child_readers.size(); i++) { - if (!struct_reader.child_readers[i]) { - continue; - } - auto &child_reader = *struct_reader.child_readers[i]; - auto child_stats = ParquetStatisticsUtils::TransformColumnStatistics(child_reader, columns); - StructStats::SetChildStats(struct_stats, i, std::move(child_stats)); - } - row_group_stats = struct_stats.ToUnique(); - - // null count is generic - if (row_group_stats) { - row_group_stats->Set(StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES); - } - return row_group_stats; - } - - // Otherwise, its a standard column with stats - - auto &column_chunk = columns[reader.FileIdx()]; - if (!column_chunk.__isset.meta_data || !column_chunk.meta_data.__isset.statistics) { - // no stats present for row group - return nullptr; - } - auto &parquet_stats = column_chunk.meta_data.statistics; - - auto &type = reader.Type(); - auto &s_ele = reader.Schema(); - - switch (type.id()) { - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::DECIMAL: - row_group_stats = CreateNumericStats(type, s_ele, parquet_stats); - break; - case LogicalTypeId::VARCHAR: { - auto string_stats = StringStats::CreateEmpty(type); - if (parquet_stats.__isset.min_value) { - StringColumnReader::VerifyString(parquet_stats.min_value.c_str(), parquet_stats.min_value.size(), true); - StringStats::Update(string_stats, parquet_stats.min_value); - } else if (parquet_stats.__isset.min) { - StringColumnReader::VerifyString(parquet_stats.min.c_str(), 
parquet_stats.min.size(), true); - StringStats::Update(string_stats, parquet_stats.min); - } else { - return nullptr; - } - if (parquet_stats.__isset.max_value) { - StringColumnReader::VerifyString(parquet_stats.max_value.c_str(), parquet_stats.max_value.size(), true); - StringStats::Update(string_stats, parquet_stats.max_value); - } else if (parquet_stats.__isset.max) { - StringColumnReader::VerifyString(parquet_stats.max.c_str(), parquet_stats.max.size(), true); - StringStats::Update(string_stats, parquet_stats.max); - } else { - return nullptr; - } - StringStats::SetContainsUnicode(string_stats); - StringStats::ResetMaxStringLength(string_stats); - row_group_stats = string_stats.ToUnique(); - break; - } - default: - // no stats for you - break; - } // end of type switch - - // null count is generic - if (row_group_stats) { - row_group_stats->Set(StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES); - if (parquet_stats.__isset.null_count && parquet_stats.null_count == 0) { - row_group_stats->Set(StatsInfo::CANNOT_HAVE_NULL_VALUES); - } - } - return row_group_stats; -} - -static bool HasFilterConstants(const TableFilter &duckdb_filter) { - switch (duckdb_filter.filter_type) { - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = duckdb_filter.Cast(); - return (constant_filter.comparison_type == ExpressionType::COMPARE_EQUAL && !constant_filter.constant.IsNull()); - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and_filter = duckdb_filter.Cast(); - bool child_has_constant = false; - for (auto &child_filter : conjunction_and_filter.child_filters) { - child_has_constant |= HasFilterConstants(*child_filter); - } - return child_has_constant; - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_or_filter = duckdb_filter.Cast(); - bool child_has_constant = false; - for (auto &child_filter : conjunction_or_filter.child_filters) { - child_has_constant |= HasFilterConstants(*child_filter); - } - return child_has_constant; - } - default: - return false; - } -} - -template -uint64_t ValueXH64FixedWidth(const Value &constant) { - T val = constant.GetValue(); - return duckdb_zstd::XXH64(&val, sizeof(val), 0); -} - -// TODO we can only this if the parquet representation of the type exactly matches the duckdb rep! -// TODO TEST THIS! 
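ValueXH64FixedWidth above hashes the value's raw bytes exactly as the Parquet writer did, which is why the TODOs warn that this is only sound when DuckDB's in-memory representation matches the on-disk one. The same pattern, with FNV-1a as a self-contained stand-in for the XXH64 bundled with zstd in the original:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// FNV-1a, standing in for XXH64 purely to keep the sketch dependency-free.
uint64_t Fnv1a(const void *data, std::size_t len) {
	auto p = static_cast<const unsigned char *>(data);
	uint64_t h = 1469598103934665603ULL;
	for (std::size_t i = 0; i < len; i++) {
		h = (h ^ p[i]) * 1099511628211ULL;
	}
	return h;
}

template <class T>
uint64_t HashFixedWidth(T value) {
	// hash the raw bytes, exactly as the writer hashed them on disk
	return Fnv1a(&value, sizeof(value));
}
```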
-// TODO perhaps we can re-use some writer infra here -static uint64_t ValueXXH64(const Value &constant) { - switch (constant.type().InternalType()) { - case PhysicalType::UINT8: - return ValueXH64FixedWidth(constant); - case PhysicalType::INT8: - return ValueXH64FixedWidth(constant); - case PhysicalType::UINT16: - return ValueXH64FixedWidth(constant); - case PhysicalType::INT16: - return ValueXH64FixedWidth(constant); - case PhysicalType::UINT32: - return ValueXH64FixedWidth(constant); - case PhysicalType::INT32: - return ValueXH64FixedWidth(constant); - case PhysicalType::UINT64: - return ValueXH64FixedWidth(constant); - case PhysicalType::INT64: - return ValueXH64FixedWidth(constant); - case PhysicalType::FLOAT: - return ValueXH64FixedWidth(constant); - case PhysicalType::DOUBLE: - return ValueXH64FixedWidth(constant); - case PhysicalType::VARCHAR: { - auto val = constant.GetValue(); - return duckdb_zstd::XXH64(val.c_str(), val.length(), 0); - } - default: - return 0; - } -} - -static bool ApplyBloomFilter(const TableFilter &duckdb_filter, ParquetBloomFilter &bloom_filter) { - switch (duckdb_filter.filter_type) { - case TableFilterType::CONSTANT_COMPARISON: { - auto &constant_filter = duckdb_filter.Cast(); - auto is_compare_equal = constant_filter.comparison_type == ExpressionType::COMPARE_EQUAL; - D_ASSERT(!constant_filter.constant.IsNull()); - auto hash = ValueXXH64(constant_filter.constant); - return hash > 0 && !bloom_filter.FilterCheck(hash) && is_compare_equal; - } - case TableFilterType::CONJUNCTION_AND: { - auto &conjunction_and_filter = duckdb_filter.Cast(); - bool any_children_true = false; - for (auto &child_filter : conjunction_and_filter.child_filters) { - any_children_true |= ApplyBloomFilter(*child_filter, bloom_filter); - } - return any_children_true; - } - case TableFilterType::CONJUNCTION_OR: { - auto &conjunction_or_filter = duckdb_filter.Cast(); - bool all_children_true = true; - for (auto &child_filter : conjunction_or_filter.child_filters) { - all_children_true &= ApplyBloomFilter(*child_filter, bloom_filter); - } - return all_children_true; - } - default: - return false; - } -} - -bool ParquetStatisticsUtils::BloomFilterSupported(const LogicalTypeId &type_id) { - switch (type_id) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - return true; - default: - return false; - } -} - -bool ParquetStatisticsUtils::BloomFilterExcludes(const TableFilter &duckdb_filter, - const duckdb_parquet::ColumnMetaData &column_meta_data, - TProtocol &file_proto, Allocator &allocator) { - if (!HasFilterConstants(duckdb_filter) || !column_meta_data.__isset.bloom_filter_offset || - column_meta_data.bloom_filter_offset <= 0) { - return false; - } - // TODO check length against file length! 
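ApplyBloomFilter above encodes that a bloom filter proves absence, never presence: an AND is excluded if any conjunct is provably absent, an OR only if every disjunct is. Reduced to the boolean skeleton:

```cpp
#include <vector>

// "excludes" means the bloom filter proves the probed constant is absent.
bool AndExcludes(const std::vector<bool> &child_excludes) {
	for (bool e : child_excludes) {
		if (e) {
			return true; // one impossible conjunct falsifies the whole AND
		}
	}
	return false;
}

bool OrExcludes(const std::vector<bool> &child_excludes) {
	for (bool e : child_excludes) {
		if (!e) {
			return false; // one possible disjunct keeps the OR possible
		}
	}
	return !child_excludes.empty();
}
```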
- - auto &transport = reinterpret_cast(*file_proto.getTransport()); - transport.SetLocation(column_meta_data.bloom_filter_offset); - if (column_meta_data.__isset.bloom_filter_length && column_meta_data.bloom_filter_length > 0) { - transport.Prefetch(column_meta_data.bloom_filter_offset, column_meta_data.bloom_filter_length); - } - - duckdb_parquet::BloomFilterHeader filter_header; - // TODO the bloom filter could be encrypted, too, so need to double check that this is NOT the case - filter_header.read(&file_proto); - if (!filter_header.algorithm.__isset.BLOCK || !filter_header.compression.__isset.UNCOMPRESSED || - !filter_header.hash.__isset.XXHASH) { - return false; - } - - auto new_buffer = make_uniq(allocator, filter_header.numBytes); - transport.read(new_buffer->ptr, filter_header.numBytes); - ParquetBloomFilter bloom_filter(std::move(new_buffer)); - return ApplyBloomFilter(duckdb_filter, bloom_filter); -} - -ParquetBloomFilter::ParquetBloomFilter(idx_t num_entries, double bloom_filter_false_positive_ratio) { - - // aim for hit ratio of 0.01% - // see http://tfk.mit.edu/pdf/bloom.pdf - double f = bloom_filter_false_positive_ratio; - double k = 8.0; - double n = LossyNumericCast(num_entries); - double m = -k * n / std::log(1 - std::pow(f, 1 / k)); - auto b = MaxValue(NextPowerOfTwo(LossyNumericCast(m / k)) / 32, 1); - - D_ASSERT(b > 0 && IsPowerOfTwo(b)); - - data = make_uniq(Allocator::DefaultAllocator(), sizeof(ParquetBloomBlock) * b); - data->zero(); - block_count = data->len / sizeof(ParquetBloomBlock); - D_ASSERT(data->len % sizeof(ParquetBloomBlock) == 0); -} - -ParquetBloomFilter::ParquetBloomFilter(unique_ptr data_p) { - D_ASSERT(data_p->len % sizeof(ParquetBloomBlock) == 0); - data = std::move(data_p); - block_count = data->len / sizeof(ParquetBloomBlock); - D_ASSERT(data->len % sizeof(ParquetBloomBlock) == 0); -} - -void ParquetBloomFilter::FilterInsert(uint64_t x) { - auto blocks = reinterpret_cast(data->ptr); - uint64_t i = ((x >> 32) * block_count) >> 32; - auto &b = blocks[i]; - ParquetBloomBlock::BlockInsert(b, x); -} - -bool ParquetBloomFilter::FilterCheck(uint64_t x) { - auto blocks = reinterpret_cast(data->ptr); - auto i = ((x >> 32) * block_count) >> 32; - return ParquetBloomBlock::BlockCheck(blocks[i], x); -} - -// compiler optimizes this into a single instruction (popcnt) -static uint8_t PopCnt64(uint64_t n) { - uint8_t c = 0; - for (; n; ++c) { - n &= n - 1; - } - return c; -} - -double ParquetBloomFilter::OneRatio() { - auto bloom_ptr = reinterpret_cast(data->ptr); - idx_t one_count = 0; - for (idx_t b_idx = 0; b_idx < data->len / sizeof(uint64_t); ++b_idx) { - one_count += PopCnt64(bloom_ptr[b_idx]); - } - return LossyNumericCast(one_count) / (LossyNumericCast(data->len) * 8.0); -} - -ResizeableBuffer *ParquetBloomFilter::Get() { - return data.get(); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_timestamp.cpp b/src/duckdb/extension/parquet/parquet_timestamp.cpp deleted file mode 100644 index 84c158f19..000000000 --- a/src/duckdb/extension/parquet/parquet_timestamp.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "parquet_timestamp.hpp" - -#include "duckdb.hpp" -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#endif - -namespace duckdb { - -// surely they are joking -static constexpr int64_t JULIAN_TO_UNIX_EPOCH_DAYS = 2440588LL; -static constexpr int64_t MILLISECONDS_PER_DAY = 86400000LL; -static constexpr int64_t 
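The constructor math above follows the classic bloom-filter bound; annotated as a standalone sketch, together with the multiply-shift block selection used by FilterInsert/FilterCheck:

```cpp
#include <cmath>
#include <cstdint>

uint64_t NextPow2(uint64_t v) {
	uint64_t p = 1;
	while (p < v) {
		p <<= 1;
	}
	return p;
}

// m = -k*n / ln(1 - f^(1/k)) is the standard bit count for k probes, n entries
// and false-positive ratio f; m/k bytes (k = 8) are rounded up to a power of
// two and grouped into 32-byte split blocks, exactly as in the constructor.
uint64_t BloomBlockCount(uint64_t num_entries, double f) {
	const double k = 8.0;
	const double n = static_cast<double>(num_entries);
	const double m = -k * n / std::log(1 - std::pow(f, 1 / k));
	const uint64_t blocks = NextPow2(static_cast<uint64_t>(m / k)) / 32;
	return blocks == 0 ? 1 : blocks;
}

// Block selection as in FilterInsert/FilterCheck: map the top 32 bits of the
// hash onto [0, block_count) with a multiply-shift instead of a modulo.
uint64_t BlockIndex(uint64_t hash, uint64_t block_count) {
	return ((hash >> 32) * block_count) >> 32;
}
```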
MICROSECONDS_PER_DAY = MILLISECONDS_PER_DAY * 1000LL; -static constexpr int64_t NANOSECONDS_PER_MICRO = 1000LL; -static constexpr int64_t NANOSECONDS_PER_DAY = MICROSECONDS_PER_DAY * 1000LL; - -static inline int64_t ImpalaTimestampToDays(const Int96 &impala_timestamp) { - return impala_timestamp.value[2] - JULIAN_TO_UNIX_EPOCH_DAYS; -} - -static int64_t ImpalaTimestampToMicroseconds(const Int96 &impala_timestamp) { - int64_t days_since_epoch = ImpalaTimestampToDays(impala_timestamp); - auto nanoseconds = Load(const_data_ptr_cast(impala_timestamp.value)); - auto microseconds = nanoseconds / NANOSECONDS_PER_MICRO; - return days_since_epoch * MICROSECONDS_PER_DAY + microseconds; -} - -static int64_t ImpalaTimestampToNanoseconds(const Int96 &impala_timestamp) { - int64_t days_since_epoch = ImpalaTimestampToDays(impala_timestamp); - auto nanoseconds = Load(const_data_ptr_cast(impala_timestamp.value)); - return days_since_epoch * NANOSECONDS_PER_DAY + nanoseconds; -} - -timestamp_ns_t ImpalaTimestampToTimestampNS(const Int96 &raw_ts) { - timestamp_ns_t result; - result.value = ImpalaTimestampToNanoseconds(raw_ts); - return result; -} - -timestamp_t ImpalaTimestampToTimestamp(const Int96 &raw_ts) { - auto impala_us = ImpalaTimestampToMicroseconds(raw_ts); - return Timestamp::FromEpochMicroSeconds(impala_us); -} - -Int96 TimestampToImpalaTimestamp(timestamp_t &ts) { - int32_t hour, min, sec, msec; - Time::Convert(Timestamp::GetTime(ts), hour, min, sec, msec); - uint64_t ms_since_midnight = hour * 60 * 60 * 1000 + min * 60 * 1000 + sec * 1000 + msec; - auto days_since_epoch = Date::Epoch(Timestamp::GetDate(ts)) / int64_t(24 * 60 * 60); - // first two uint32 in Int96 are nanoseconds since midnights - // last uint32 is number of days since year 4713 BC ("Julian date") - Int96 impala_ts; - Store(ms_since_midnight * 1000000, data_ptr_cast(impala_ts.value)); - impala_ts.value[2] = days_since_epoch + JULIAN_TO_UNIX_EPOCH_DAYS; - return impala_ts; -} - -timestamp_t ParquetTimestampMicrosToTimestamp(const int64_t &raw_ts) { - return Timestamp::FromEpochMicroSeconds(raw_ts); -} - -timestamp_t ParquetTimestampMsToTimestamp(const int64_t &raw_ts) { - timestamp_t input(raw_ts); - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::FromEpochMs(raw_ts); -} - -timestamp_ns_t ParquetTimestampMsToTimestampNs(const int64_t &raw_ms) { - timestamp_ns_t input; - input.value = raw_ms; - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::TimestampNsFromEpochMillis(raw_ms); -} - -timestamp_ns_t ParquetTimestampUsToTimestampNs(const int64_t &raw_us) { - timestamp_ns_t input; - input.value = raw_us; - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::TimestampNsFromEpochMicros(raw_us); -} - -timestamp_ns_t ParquetTimestampNsToTimestampNs(const int64_t &raw_ns) { - timestamp_ns_t result; - result.value = raw_ns; - return result; -} - -timestamp_t ParquetTimestampNsToTimestamp(const int64_t &raw_ts) { - timestamp_t input(raw_ts); - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::FromEpochNanoSeconds(raw_ts); -} - -date_t ParquetIntToDate(const int32_t &raw_date) { - return date_t(raw_date); -} - -template -static T ParquetWrapTime(const T &raw, const T day) { - // Special case 24:00:00 - if (raw == day) { - return raw; - } - const auto modulus = raw % day; - return modulus + (modulus < 0) * day; -} - -dtime_t ParquetIntToTimeMs(const int32_t &raw_millis) { - return Time::FromTimeMs(raw_millis); -} - -dtime_t 
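ImpalaTimestampToMicroseconds above decodes the Int96 layout: bytes 0-7 hold nanoseconds since midnight, bytes 8-11 the Julian day number, and 2440588 is the Julian day of the Unix epoch (1970-01-01). As one self-contained helper:

```cpp
#include <cstdint>
#include <cstring>

struct Int96 {
	uint32_t value[3];
};

int64_t Int96ToMicrosSinceEpoch(const Int96 &ts) {
	constexpr int64_t JULIAN_TO_UNIX_EPOCH_DAYS = 2440588;
	constexpr int64_t MICROS_PER_DAY = 86400000000LL;
	int64_t nanos_since_midnight;
	std::memcpy(&nanos_since_midnight, ts.value, sizeof(nanos_since_midnight)); // bytes 0-7
	int64_t days = static_cast<int64_t>(ts.value[2]) - JULIAN_TO_UNIX_EPOCH_DAYS; // bytes 8-11
	return days * MICROS_PER_DAY + nanos_since_midnight / 1000;
}
```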
ParquetIntToTime(const int64_t &raw_micros) { - return dtime_t(raw_micros); -} - -dtime_t ParquetIntToTimeNs(const int64_t &raw_nanos) { - return Time::FromTimeNs(raw_nanos); -} - -dtime_tz_t ParquetIntToTimeMsTZ(const int32_t &raw_millis) { - const int32_t MSECS_PER_DAY = Interval::MSECS_PER_SEC * Interval::SECS_PER_DAY; - const auto millis = ParquetWrapTime(raw_millis, MSECS_PER_DAY); - return dtime_tz_t(Time::FromTimeMs(millis), 0); -} - -dtime_tz_t ParquetIntToTimeTZ(const int64_t &raw_micros) { - const auto micros = ParquetWrapTime(raw_micros, Interval::MICROS_PER_DAY); - return dtime_tz_t(dtime_t(micros), 0); -} - -dtime_tz_t ParquetIntToTimeNsTZ(const int64_t &raw_nanos) { - const auto nanos = ParquetWrapTime(raw_nanos, Interval::NANOS_PER_DAY); - return dtime_tz_t(Time::FromTimeNs(nanos), 0); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/parquet_writer.cpp b/src/duckdb/extension/parquet/parquet_writer.cpp deleted file mode 100644 index 883f03069..000000000 --- a/src/duckdb/extension/parquet/parquet_writer.cpp +++ /dev/null @@ -1,609 +0,0 @@ -#include "parquet_writer.hpp" - -#include "duckdb.hpp" -#include "mbedtls_wrapper.hpp" -#include "parquet_crypto.hpp" -#include "parquet_timestamp.hpp" -#include "resizable_buffer.hpp" - -#ifndef DUCKDB_AMALGAMATION -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/write_stream.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/connection.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#endif - -namespace duckdb { - -using namespace duckdb_apache::thrift; // NOLINT -using namespace duckdb_apache::thrift::protocol; // NOLINT -using namespace duckdb_apache::thrift::transport; // NOLINT - -using duckdb_parquet::CompressionCodec; -using duckdb_parquet::ConvertedType; -using duckdb_parquet::Encoding; -using duckdb_parquet::FieldRepetitionType; -using duckdb_parquet::FileCryptoMetaData; -using duckdb_parquet::FileMetaData; -using duckdb_parquet::PageHeader; -using duckdb_parquet::PageType; -using ParquetRowGroup = duckdb_parquet::RowGroup; -using duckdb_parquet::Type; - -ChildFieldIDs::ChildFieldIDs() : ids(make_uniq>()) { -} - -ChildFieldIDs ChildFieldIDs::Copy() const { - ChildFieldIDs result; - for (const auto &id : *ids) { - result.ids->emplace(id.first, id.second.Copy()); - } - return result; -} - -FieldID::FieldID() : set(false) { -} - -FieldID::FieldID(int32_t field_id_p) : set(true), field_id(field_id_p) { -} - -FieldID FieldID::Copy() const { - auto result = set ? 
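ParquetWrapTime, defined just above, folds out-of-range TIME values back into [0, day], keeping the special 24:00:00 value verbatim; the original computes the negative correction branchlessly, spelled out here with a ternary for clarity:

```cpp
#include <cstdint>

int64_t WrapTime(int64_t raw, int64_t day) {
	if (raw == day) {
		return raw; // keep 24:00:00 as-is
	}
	const int64_t modulus = raw % day;
	// C++'s % keeps the sign of raw; add one day back for negative inputs
	return modulus + (modulus < 0 ? day : 0);
}
```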
FieldID(field_id) : FieldID(); - result.child_field_ids = child_field_ids.Copy(); - return result; -} - -class MyTransport : public TTransport { -public: - explicit MyTransport(WriteStream &serializer) : serializer(serializer) { - } - - bool isOpen() const override { - return true; - } - - void open() override { - } - - void close() override { - } - - void write_virt(const uint8_t *buf, uint32_t len) override { - serializer.WriteData(const_data_ptr_cast(buf), len); - } - -private: - WriteStream &serializer; -}; - -bool ParquetWriter::TryGetParquetType(const LogicalType &duckdb_type, optional_ptr parquet_type_ptr) { - Type::type parquet_type; - switch (duckdb_type.id()) { - case LogicalTypeId::BOOLEAN: - parquet_type = Type::BOOLEAN; - break; - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::DATE: - parquet_type = Type::INT32; - break; - case LogicalTypeId::BIGINT: - parquet_type = Type::INT64; - break; - case LogicalTypeId::FLOAT: - parquet_type = Type::FLOAT; - break; - case LogicalTypeId::DOUBLE: - parquet_type = Type::DOUBLE; - break; - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::HUGEINT: - parquet_type = Type::DOUBLE; - break; - case LogicalTypeId::ENUM: - case LogicalTypeId::BLOB: - case LogicalTypeId::VARCHAR: - parquet_type = Type::BYTE_ARRAY; - break; - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_SEC: - parquet_type = Type::INT64; - break; - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - parquet_type = Type::INT32; - break; - case LogicalTypeId::UBIGINT: - parquet_type = Type::INT64; - break; - case LogicalTypeId::INTERVAL: - case LogicalTypeId::UUID: - parquet_type = Type::FIXED_LEN_BYTE_ARRAY; - break; - case LogicalTypeId::DECIMAL: - switch (duckdb_type.InternalType()) { - case PhysicalType::INT16: - case PhysicalType::INT32: - parquet_type = Type::INT32; - break; - case PhysicalType::INT64: - parquet_type = Type::INT64; - break; - case PhysicalType::INT128: - parquet_type = Type::FIXED_LEN_BYTE_ARRAY; - break; - default: - throw InternalException("Unsupported internal decimal type"); - } - break; - default: - // Anything that is not supported - return false; - } - if (parquet_type_ptr) { - *parquet_type_ptr = parquet_type; - } - return true; -} - -Type::type ParquetWriter::DuckDBTypeToParquetType(const LogicalType &duckdb_type) { - Type::type result; - if (TryGetParquetType(duckdb_type, &result)) { - return result; - } - throw NotImplementedException("Unimplemented type for Parquet \"%s\"", duckdb_type.ToString()); -} - -void ParquetWriter::SetSchemaProperties(const LogicalType &duckdb_type, duckdb_parquet::SchemaElement &schema_ele) { - if (duckdb_type.IsJSONType()) { - schema_ele.converted_type = ConvertedType::JSON; - schema_ele.__isset.converted_type = true; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__set_JSON(duckdb_parquet::JsonType()); - return; - } - switch (duckdb_type.id()) { - case LogicalTypeId::TINYINT: - schema_ele.converted_type = ConvertedType::INT_8; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::SMALLINT: - schema_ele.converted_type = ConvertedType::INT_16; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::INTEGER: - schema_ele.converted_type = ConvertedType::INT_32; - 
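MyTransport above adapts Thrift's transport interface to DuckDB's WriteStream so that Thrift-serialized metadata lands directly in the output file. The adapter shape, with a hypothetical Sink in place of WriteStream (the real class derives from TTransport and overrides write_virt):

```cpp
#include <cstdint>
#include <vector>

struct Sink {
	std::vector<uint8_t> bytes;
	void WriteData(const uint8_t *buf, std::size_t len) {
		bytes.insert(bytes.end(), buf, buf + len);
	}
};

class WriteOnlyTransport {
public:
	explicit WriteOnlyTransport(Sink &sink) : sink(sink) {
	}
	void write(const uint8_t *buf, uint32_t len) {
		sink.WriteData(buf, len); // everything Thrift emits lands in the sink
	}

private:
	Sink &sink;
};
```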
schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::BIGINT: - schema_ele.converted_type = ConvertedType::INT_64; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::UTINYINT: - schema_ele.converted_type = ConvertedType::UINT_8; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::USMALLINT: - schema_ele.converted_type = ConvertedType::UINT_16; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::UINTEGER: - schema_ele.converted_type = ConvertedType::UINT_32; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::UBIGINT: - schema_ele.converted_type = ConvertedType::UINT_64; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::DATE: - schema_ele.converted_type = ConvertedType::DATE; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIME: - schema_ele.converted_type = ConvertedType::TIME_MICROS; - schema_ele.__isset.converted_type = true; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.TIME = true; - schema_ele.logicalType.TIME.isAdjustedToUTC = (duckdb_type.id() == LogicalTypeId::TIME_TZ); - schema_ele.logicalType.TIME.unit.__isset.MICROS = true; - break; - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_SEC: - schema_ele.converted_type = ConvertedType::TIMESTAMP_MICROS; - schema_ele.__isset.converted_type = true; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.TIMESTAMP = true; - schema_ele.logicalType.TIMESTAMP.isAdjustedToUTC = (duckdb_type.id() == LogicalTypeId::TIMESTAMP_TZ); - schema_ele.logicalType.TIMESTAMP.unit.__isset.MICROS = true; - break; - case LogicalTypeId::TIMESTAMP_NS: - schema_ele.__isset.converted_type = false; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.TIMESTAMP = true; - schema_ele.logicalType.TIMESTAMP.isAdjustedToUTC = false; - schema_ele.logicalType.TIMESTAMP.unit.__isset.NANOS = true; - break; - case LogicalTypeId::TIMESTAMP_MS: - schema_ele.converted_type = ConvertedType::TIMESTAMP_MILLIS; - schema_ele.__isset.converted_type = true; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.TIMESTAMP = true; - schema_ele.logicalType.TIMESTAMP.isAdjustedToUTC = false; - schema_ele.logicalType.TIMESTAMP.unit.__isset.MILLIS = true; - break; - case LogicalTypeId::ENUM: - case LogicalTypeId::VARCHAR: - schema_ele.converted_type = ConvertedType::UTF8; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::INTERVAL: - schema_ele.type_length = 12; - schema_ele.converted_type = ConvertedType::INTERVAL; - schema_ele.__isset.type_length = true; - schema_ele.__isset.converted_type = true; - break; - case LogicalTypeId::UUID: - schema_ele.type_length = 16; - schema_ele.__isset.type_length = true; - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.UUID = true; - break; - case LogicalTypeId::DECIMAL: - schema_ele.converted_type = ConvertedType::DECIMAL; - schema_ele.precision = DecimalType::GetWidth(duckdb_type); - schema_ele.scale = DecimalType::GetScale(duckdb_type); - schema_ele.__isset.converted_type = true; - schema_ele.__isset.precision = true; - schema_ele.__isset.scale = true; - if (duckdb_type.InternalType() == PhysicalType::INT128) { - schema_ele.type_length = 16; - schema_ele.__isset.type_length = true; - } - schema_ele.__isset.logicalType = true; - schema_ele.logicalType.__isset.DECIMAL = true; - 
schema_ele.logicalType.DECIMAL.precision = schema_ele.precision; - schema_ele.logicalType.DECIMAL.scale = schema_ele.scale; - break; - default: - break; - } -} - -uint32_t ParquetWriter::Write(const duckdb_apache::thrift::TBase &object) { - if (encryption_config) { - return ParquetCrypto::Write(object, *protocol, encryption_config->GetFooterKey(), *encryption_util); - } else { - return object.write(protocol.get()); - } -} - -uint32_t ParquetWriter::WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size) { - if (encryption_config) { - return ParquetCrypto::WriteData(*protocol, buffer, buffer_size, encryption_config->GetFooterKey(), - *encryption_util); - } else { - protocol->getTransport()->write(buffer, buffer_size); - return buffer_size; - } -} - -void VerifyUniqueNames(const vector &names) { -#ifdef DEBUG - unordered_set name_set; - name_set.reserve(names.size()); - for (auto &column : names) { - auto res = name_set.insert(column); - D_ASSERT(res.second == true); - } - // If there would be duplicates, these sizes would differ - D_ASSERT(name_set.size() == names.size()); -#endif -} - -ParquetWriter::ParquetWriter(ClientContext &context, FileSystem &fs, string file_name_p, vector types_p, - vector names_p, CompressionCodec::type codec, ChildFieldIDs field_ids_p, - const vector> &kv_metadata, - shared_ptr encryption_config_p, idx_t dictionary_size_limit_p, - double bloom_filter_false_positive_ratio_p, int64_t compression_level_p, - bool debug_use_openssl_p) - : file_name(std::move(file_name_p)), sql_types(std::move(types_p)), column_names(std::move(names_p)), codec(codec), - field_ids(std::move(field_ids_p)), encryption_config(std::move(encryption_config_p)), - dictionary_size_limit(dictionary_size_limit_p), - bloom_filter_false_positive_ratio(bloom_filter_false_positive_ratio_p), compression_level(compression_level_p), - debug_use_openssl(debug_use_openssl_p) { - - // initialize the file writer - writer = make_uniq(fs, file_name.c_str(), - FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW); - if (encryption_config) { - auto &config = DBConfig::GetConfig(context); - if (config.encryption_util && debug_use_openssl) { - // Use OpenSSL - encryption_util = config.encryption_util; - } else { - encryption_util = make_shared_ptr(); - } - // encrypted parquet files start with the string "PARE" - writer->WriteData(const_data_ptr_cast("PARE"), 4); - // we only support this one for now, not "AES_GCM_CTR_V1" - file_meta_data.encryption_algorithm.__isset.AES_GCM_V1 = true; - } else { - // parquet files start with the string "PAR1" - writer->WriteData(const_data_ptr_cast("PAR1"), 4); - } - TCompactProtocolFactoryT tproto_factory; - protocol = tproto_factory.getProtocol(std::make_shared(*writer)); - - file_meta_data.num_rows = 0; - file_meta_data.version = 1; - - file_meta_data.__isset.created_by = true; - file_meta_data.created_by = - StringUtil::Format("DuckDB version %s (build %s)", DuckDB::LibraryVersion(), DuckDB::SourceID()); - - file_meta_data.schema.resize(1); - - for (auto &kv_pair : kv_metadata) { - duckdb_parquet::KeyValue kv; - kv.__set_key(kv_pair.first); - kv.__set_value(kv_pair.second); - file_meta_data.key_value_metadata.push_back(kv); - file_meta_data.__isset.key_value_metadata = true; - } - - // populate root schema object - file_meta_data.schema[0].name = "duckdb_schema"; - file_meta_data.schema[0].num_children = NumericCast(sql_types.size()); - file_meta_data.schema[0].__isset.num_children = true; - file_meta_data.schema[0].repetition_type = 
duckdb_parquet::FieldRepetitionType::REQUIRED; - file_meta_data.schema[0].__isset.repetition_type = true; - - auto &unique_names = column_names; - VerifyUniqueNames(unique_names); - - vector schema_path; - for (idx_t i = 0; i < sql_types.size(); i++) { - column_writers.push_back(ColumnWriter::CreateWriterRecursive( - context, file_meta_data.schema, *this, sql_types[i], unique_names[i], schema_path, &field_ids)); - } -} - -void ParquetWriter::PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result) { - // We write 8 columns at a time so that iterating over ColumnDataCollection is more efficient - static constexpr idx_t COLUMNS_PER_PASS = 8; - - // We want these to be in-memory/hybrid so we don't have to copy over strings to the dictionary - D_ASSERT(buffer.GetAllocatorType() == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || - buffer.GetAllocatorType() == ColumnDataAllocatorType::HYBRID); - - // set up a new row group for this chunk collection - auto &row_group = result.row_group; - row_group.num_rows = NumericCast(buffer.Count()); - row_group.total_byte_size = NumericCast(buffer.SizeInBytes()); - row_group.__isset.file_offset = true; - - auto &states = result.states; - // iterate over each of the columns of the chunk collection and write them - D_ASSERT(buffer.ColumnCount() == column_writers.size()); - for (idx_t col_idx = 0; col_idx < buffer.ColumnCount(); col_idx += COLUMNS_PER_PASS) { - const auto next = MinValue(buffer.ColumnCount() - col_idx, COLUMNS_PER_PASS); - vector column_ids; - vector> col_writers; - vector> write_states; - for (idx_t i = 0; i < next; i++) { - column_ids.emplace_back(col_idx + i); - col_writers.emplace_back(*column_writers[column_ids.back()]); - write_states.emplace_back(col_writers.back().get().InitializeWriteState(row_group)); - } - - for (auto &chunk : buffer.Chunks({column_ids})) { - for (idx_t i = 0; i < next; i++) { - if (col_writers[i].get().HasAnalyze()) { - col_writers[i].get().Analyze(*write_states[i], nullptr, chunk.data[i], chunk.size()); - } - } - } - - for (idx_t i = 0; i < next; i++) { - if (col_writers[i].get().HasAnalyze()) { - col_writers[i].get().FinalizeAnalyze(*write_states[i]); - } - } - - // Reserving these once at the start really pays off - for (auto &write_state : write_states) { - write_state->definition_levels.reserve(buffer.Count()); - } - - for (auto &chunk : buffer.Chunks({column_ids})) { - for (idx_t i = 0; i < next; i++) { - col_writers[i].get().Prepare(*write_states[i], nullptr, chunk.data[i], chunk.size()); - } - } - - for (idx_t i = 0; i < next; i++) { - col_writers[i].get().BeginWrite(*write_states[i]); - } - - for (auto &chunk : buffer.Chunks({column_ids})) { - for (idx_t i = 0; i < next; i++) { - col_writers[i].get().Write(*write_states[i], chunk.data[i], chunk.size()); - } - } - - for (auto &write_state : write_states) { - states.push_back(std::move(write_state)); - } - } - result.heaps = buffer.GetHeapReferences(); -} - -// Validation code adapted from Impala -static void ValidateOffsetInFile(const string &filename, idx_t col_idx, idx_t file_length, idx_t offset, - const string &offset_name) { - if (offset >= file_length) { - throw IOException("File '%s': metadata is corrupt. 
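PrepareRowGroup above deliberately writes COLUMNS_PER_PASS = 8 columns per iteration pass over the ColumnDataCollection, bounding the working set per pass. The pass construction on its own (illustrative helper):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<std::vector<std::size_t>> ColumnPasses(std::size_t column_count,
                                                   std::size_t columns_per_pass = 8) {
	std::vector<std::vector<std::size_t>> passes;
	for (std::size_t col = 0; col < column_count; col += columns_per_pass) {
		const std::size_t next = std::min(column_count - col, columns_per_pass);
		std::vector<std::size_t> ids;
		for (std::size_t i = 0; i < next; i++) {
			ids.push_back(col + i); // the column ids handled in this pass
		}
		passes.push_back(std::move(ids));
	}
	return passes;
}
```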
Column %d has invalid " - "%s (offset=%llu file_size=%llu).", - filename, col_idx, offset_name, offset, file_length); - } -} - -static void ValidateColumnOffsets(const string &filename, idx_t file_length, const ParquetRowGroup &row_group) { - for (idx_t i = 0; i < row_group.columns.size(); ++i) { - const auto &col_chunk = row_group.columns[i]; - ValidateOffsetInFile(filename, i, file_length, col_chunk.meta_data.data_page_offset, "data page offset"); - auto col_start = NumericCast(col_chunk.meta_data.data_page_offset); - // The file format requires that if a dictionary page exists, it be before data pages. - if (col_chunk.meta_data.__isset.dictionary_page_offset) { - ValidateOffsetInFile(filename, i, file_length, col_chunk.meta_data.dictionary_page_offset, - "dictionary page offset"); - if (NumericCast(col_chunk.meta_data.dictionary_page_offset) >= col_start) { - throw IOException("Parquet file '%s': metadata is corrupt. Dictionary " - "page (offset=%llu) must come before any data pages (offset=%llu).", - filename, col_chunk.meta_data.dictionary_page_offset, col_start); - } - col_start = col_chunk.meta_data.dictionary_page_offset; - } - auto col_len = NumericCast(col_chunk.meta_data.total_compressed_size); - auto col_end = col_start + col_len; - if (col_end <= 0 || col_end > file_length) { - throw IOException("Parquet file '%s': metadata is corrupt. Column %llu has " - "invalid column offsets (offset=%llu, size=%llu, file_size=%llu).", - filename, i, col_start, col_len, file_length); - } - } -} - -void ParquetWriter::FlushRowGroup(PreparedRowGroup &prepared) { - lock_guard glock(lock); - auto &row_group = prepared.row_group; - auto &states = prepared.states; - if (states.empty()) { - throw InternalException("Attempting to flush a row group with no rows"); - } - row_group.file_offset = NumericCast(writer->GetTotalWritten()); - for (idx_t col_idx = 0; col_idx < states.size(); col_idx++) { - const auto &col_writer = column_writers[col_idx]; - auto write_state = std::move(states[col_idx]); - col_writer->FinalizeWrite(*write_state); - } - // let's make sure all offsets are ay-okay - ValidateColumnOffsets(file_name, writer->GetTotalWritten(), row_group); - - // append the row group to the file meta data - file_meta_data.row_groups.push_back(row_group); - file_meta_data.num_rows += row_group.num_rows; - - prepared.heaps.clear(); -} - -void ParquetWriter::Flush(ColumnDataCollection &buffer) { - if (buffer.Count() == 0) { - return; - } - - PreparedRowGroup prepared_row_group; - PrepareRowGroup(buffer, prepared_row_group); - buffer.Reset(); - - FlushRowGroup(prepared_row_group); -} - -void ParquetWriter::Finalize() { - - // dump the bloom filters right before footer, not if stuff is encrypted - - for (auto &bloom_filter_entry : bloom_filters) { - D_ASSERT(!encryption_config); - // write nonsense bloom filter header - duckdb_parquet::BloomFilterHeader filter_header; - auto bloom_filter_bytes = bloom_filter_entry.bloom_filter->Get(); - filter_header.numBytes = NumericCast(bloom_filter_bytes->len); - filter_header.algorithm.__set_BLOCK(duckdb_parquet::SplitBlockAlgorithm()); - filter_header.compression.__set_UNCOMPRESSED(duckdb_parquet::Uncompressed()); - filter_header.hash.__set_XXHASH(duckdb_parquet::XxHash()); - - // set metadata flags - auto &column_chunk = - file_meta_data.row_groups[bloom_filter_entry.row_group_idx].columns[bloom_filter_entry.column_idx]; - - column_chunk.meta_data.__isset.bloom_filter_offset = true; - column_chunk.meta_data.bloom_filter_offset = 
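ValidateColumnOffsets above enforces two invariants: every page offset lies inside the file, and a dictionary page precedes its column's data pages. Condensed into a hypothetical single-chunk check:

```cpp
#include <cstdint>
#include <stdexcept>

struct ChunkOffsets {
	bool has_dictionary_page = false;
	uint64_t dictionary_page_offset = 0;
	uint64_t data_page_offset = 0;
	uint64_t total_compressed_size = 0;
};

void ValidateChunk(const ChunkOffsets &c, uint64_t file_length) {
	if (c.data_page_offset >= file_length) {
		throw std::runtime_error("data page offset past end of file");
	}
	uint64_t col_start = c.data_page_offset;
	if (c.has_dictionary_page) {
		if (c.dictionary_page_offset >= c.data_page_offset) {
			throw std::runtime_error("dictionary page must precede data pages");
		}
		col_start = c.dictionary_page_offset;
	}
	if (col_start + c.total_compressed_size > file_length) {
		throw std::runtime_error("column chunk extends past end of file");
	}
}
```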
NumericCast(writer->GetTotalWritten()); - - auto bloom_filter_header_size = Write(filter_header); - // write actual data - WriteData(bloom_filter_bytes->ptr, bloom_filter_bytes->len); - - column_chunk.meta_data.__isset.bloom_filter_length = true; - column_chunk.meta_data.bloom_filter_length = - NumericCast(bloom_filter_header_size + bloom_filter_bytes->len); - } - - const auto metadata_start_offset = writer->GetTotalWritten(); - if (encryption_config) { - // Crypto metadata is written unencrypted - FileCryptoMetaData crypto_metadata; - duckdb_parquet::AesGcmV1 aes_gcm_v1; - duckdb_parquet::EncryptionAlgorithm alg; - alg.__set_AES_GCM_V1(aes_gcm_v1); - crypto_metadata.__set_encryption_algorithm(alg); - crypto_metadata.write(protocol.get()); - } - - // Add geoparquet metadata to the file metadata - if (geoparquet_data) { - geoparquet_data->Write(file_meta_data); - } - - Write(file_meta_data); - - writer->Write(writer->GetTotalWritten() - metadata_start_offset); - - if (encryption_config) { - // encrypted parquet files also end with the string "PARE" - writer->WriteData(const_data_ptr_cast("PARE"), 4); - } else { - // parquet files also end with the string "PAR1" - writer->WriteData(const_data_ptr_cast("PAR1"), 4); - } - - // flush to disk - writer->Close(); - writer.reset(); -} - -GeoParquetFileMetadata &ParquetWriter::GetGeoParquetData() { - if (!geoparquet_data) { - geoparquet_data = make_uniq(); - } - return *geoparquet_data; -} - -void ParquetWriter::BufferBloomFilter(idx_t col_idx, unique_ptr bloom_filter) { - if (encryption_config) { - return; - } - ParquetBloomFilterEntry new_entry; - new_entry.bloom_filter = std::move(bloom_filter); - new_entry.column_idx = col_idx; - new_entry.row_group_idx = file_meta_data.row_groups.size(); - bloom_filters.push_back(std::move(new_entry)); -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/serialize_parquet.cpp b/src/duckdb/extension/parquet/serialize_parquet.cpp deleted file mode 100644 index bcb334a45..000000000 --- a/src/duckdb/extension/parquet/serialize_parquet.cpp +++ /dev/null @@ -1,88 +0,0 @@ -//===----------------------------------------------------------------------===// -// This file is automatically generated by scripts/generate_serialization.py -// Do not edit this file manually, your changes will be overwritten -//===----------------------------------------------------------------------===// - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "parquet_reader.hpp" -#include "parquet_crypto.hpp" -#include "parquet_writer.hpp" - -namespace duckdb { - -void ChildFieldIDs::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault>(100, "ids", ids.operator*()); -} - -ChildFieldIDs ChildFieldIDs::Deserialize(Deserializer &deserializer) { - ChildFieldIDs result; - deserializer.ReadPropertyWithDefault>(100, "ids", result.ids.operator*()); - return result; -} - -void FieldID::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "set", set); - serializer.WritePropertyWithDefault(101, "field_id", field_id); - serializer.WriteProperty(102, "child_field_ids", child_field_ids); -} - -FieldID FieldID::Deserialize(Deserializer &deserializer) { - FieldID result; - deserializer.ReadPropertyWithDefault(100, "set", result.set); - deserializer.ReadPropertyWithDefault(101, "field_id", result.field_id); - deserializer.ReadProperty(102, "child_field_ids", result.child_field_ids); - return result; -} - -void 
ParquetColumnDefinition::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "field_id", field_id); - serializer.WritePropertyWithDefault(101, "name", name); - serializer.WriteProperty(103, "type", type); - serializer.WriteProperty(104, "default_value", default_value); -} - -ParquetColumnDefinition ParquetColumnDefinition::Deserialize(Deserializer &deserializer) { - ParquetColumnDefinition result; - deserializer.ReadPropertyWithDefault(100, "field_id", result.field_id); - deserializer.ReadPropertyWithDefault(101, "name", result.name); - deserializer.ReadProperty(103, "type", result.type); - deserializer.ReadProperty(104, "default_value", result.default_value); - return result; -} - -void ParquetEncryptionConfig::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "footer_key", footer_key); - serializer.WritePropertyWithDefault>(101, "column_keys", column_keys); -} - -shared_ptr ParquetEncryptionConfig::Deserialize(Deserializer &deserializer) { - auto result = duckdb::shared_ptr(new ParquetEncryptionConfig(deserializer.Get())); - deserializer.ReadPropertyWithDefault(100, "footer_key", result->footer_key); - deserializer.ReadPropertyWithDefault>(101, "column_keys", result->column_keys); - return result; -} - -void ParquetOptions::Serialize(Serializer &serializer) const { - serializer.WritePropertyWithDefault(100, "binary_as_string", binary_as_string); - serializer.WritePropertyWithDefault(101, "file_row_number", file_row_number); - serializer.WriteProperty(102, "file_options", file_options); - serializer.WritePropertyWithDefault>(103, "schema", schema); - serializer.WritePropertyWithDefault>(104, "encryption_config", encryption_config, nullptr); - serializer.WritePropertyWithDefault(105, "debug_use_openssl", debug_use_openssl, true); - serializer.WritePropertyWithDefault(106, "explicit_cardinality", explicit_cardinality, 0); -} - -ParquetOptions ParquetOptions::Deserialize(Deserializer &deserializer) { - ParquetOptions result; - deserializer.ReadPropertyWithDefault(100, "binary_as_string", result.binary_as_string); - deserializer.ReadPropertyWithDefault(101, "file_row_number", result.file_row_number); - deserializer.ReadProperty(102, "file_options", result.file_options); - deserializer.ReadPropertyWithDefault>(103, "schema", result.schema); - deserializer.ReadPropertyWithExplicitDefault>(104, "encryption_config", result.encryption_config, nullptr); - deserializer.ReadPropertyWithExplicitDefault(105, "debug_use_openssl", result.debug_use_openssl, true); - deserializer.ReadPropertyWithExplicitDefault(106, "explicit_cardinality", result.explicit_cardinality, 0); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/extension/parquet/zstd_file_system.cpp b/src/duckdb/extension/parquet/zstd_file_system.cpp deleted file mode 100644 index 7204f3607..000000000 --- a/src/duckdb/extension/parquet/zstd_file_system.cpp +++ /dev/null @@ -1,200 +0,0 @@ -#include "zstd_file_system.hpp" - -#include "zstd.h" - -namespace duckdb { - -struct ZstdStreamWrapper : public StreamWrapper { - ~ZstdStreamWrapper() override; - - CompressedFile *file = nullptr; - duckdb_zstd::ZSTD_DStream *zstd_stream_ptr = nullptr; - duckdb_zstd::ZSTD_CStream *zstd_compress_ptr = nullptr; - bool writing = false; - -public: - void Initialize(CompressedFile &file, bool write) override; - bool Read(StreamData &stream_data) override; - void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override; - - void Close() 
override; - - void FlushStream(); -}; - -ZstdStreamWrapper::~ZstdStreamWrapper() { - if (Exception::UncaughtException()) { - return; - } - try { - Close(); - } catch (...) { // NOLINT: swallow exceptions in destructor - } -} - -void ZstdStreamWrapper::Initialize(CompressedFile &file, bool write) { - Close(); - this->file = &file; - this->writing = write; - if (write) { - zstd_compress_ptr = duckdb_zstd::ZSTD_createCStream(); - } else { - zstd_stream_ptr = duckdb_zstd::ZSTD_createDStream(); - } -} - -bool ZstdStreamWrapper::Read(StreamData &sd) { - D_ASSERT(!writing); - - duckdb_zstd::ZSTD_inBuffer in_buffer; - duckdb_zstd::ZSTD_outBuffer out_buffer; - - in_buffer.src = sd.in_buff_start; - in_buffer.size = sd.in_buff_end - sd.in_buff_start; - in_buffer.pos = 0; - - out_buffer.dst = sd.out_buff_start; - out_buffer.size = sd.out_buf_size; - out_buffer.pos = 0; - - auto res = duckdb_zstd::ZSTD_decompressStream(zstd_stream_ptr, &out_buffer, &in_buffer); - if (duckdb_zstd::ZSTD_isError(res)) { - throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); - } - - sd.in_buff_start = (data_ptr_t)in_buffer.src + in_buffer.pos; // NOLINT - sd.in_buff_end = (data_ptr_t)in_buffer.src + in_buffer.size; // NOLINT - sd.out_buff_end = (data_ptr_t)out_buffer.dst + out_buffer.pos; // NOLINT - return false; -} - -void ZstdStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, - int64_t uncompressed_size) { - D_ASSERT(writing); - - auto remaining = uncompressed_size; - while (remaining > 0) { - D_ASSERT(sd.out_buff.get() + sd.out_buf_size > sd.out_buff_start); - idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; - - duckdb_zstd::ZSTD_inBuffer in_buffer; - duckdb_zstd::ZSTD_outBuffer out_buffer; - - in_buffer.src = uncompressed_data; - in_buffer.size = remaining; - in_buffer.pos = 0; - - out_buffer.dst = sd.out_buff_start; - out_buffer.size = output_remaining; - out_buffer.pos = 0; - auto res = - duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_continue); - if (duckdb_zstd::ZSTD_isError(res)) { - throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); - } - idx_t input_consumed = in_buffer.pos; - idx_t written_to_output = out_buffer.pos; - sd.out_buff_start += written_to_output; - if (sd.out_buff_start == sd.out_buff.get() + sd.out_buf_size) { - // no more output buffer available: flush - file.child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); - sd.out_buff_start = sd.out_buff.get(); - } - uncompressed_data += input_consumed; - remaining -= UnsafeNumericCast(input_consumed); - } -} - -void ZstdStreamWrapper::FlushStream() { - auto &sd = file->stream_data; - duckdb_zstd::ZSTD_inBuffer in_buffer; - duckdb_zstd::ZSTD_outBuffer out_buffer; - - in_buffer.src = nullptr; - in_buffer.size = 0; - in_buffer.pos = 0; - while (true) { - idx_t output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; - - out_buffer.dst = sd.out_buff_start; - out_buffer.size = output_remaining; - out_buffer.pos = 0; - - auto res = - duckdb_zstd::ZSTD_compressStream2(zstd_compress_ptr, &out_buffer, &in_buffer, duckdb_zstd::ZSTD_e_end); - if (duckdb_zstd::ZSTD_isError(res)) { - throw IOException(duckdb_zstd::ZSTD_getErrorName(res)); - } - idx_t written_to_output = out_buffer.pos; - sd.out_buff_start += written_to_output; - if (sd.out_buff_start > sd.out_buff.get()) { - file->child_handle->Write(sd.out_buff.get(), sd.out_buff_start - sd.out_buff.get()); - sd.out_buff_start = 
sd.out_buff.get(); - } - if (res == 0) { - break; - } - } -} - -void ZstdStreamWrapper::Close() { - if (!zstd_stream_ptr && !zstd_compress_ptr) { - return; - } - if (writing) { - FlushStream(); - } - if (zstd_stream_ptr) { - duckdb_zstd::ZSTD_freeDStream(zstd_stream_ptr); - } - if (zstd_compress_ptr) { - duckdb_zstd::ZSTD_freeCStream(zstd_compress_ptr); - } - zstd_stream_ptr = nullptr; - zstd_compress_ptr = nullptr; -} - -class ZStdFile : public CompressedFile { -public: - ZStdFile(unique_ptr child_handle_p, const string &path, bool write) - : CompressedFile(zstd_fs, std::move(child_handle_p), path) { - Initialize(write); - } - - FileCompressionType GetFileCompressionType() override { - return FileCompressionType::ZSTD; - } - - ZStdFileSystem zstd_fs; -}; - -unique_ptr ZStdFileSystem::OpenCompressedFile(unique_ptr handle, bool write) { - auto path = handle->path; - return make_uniq(std::move(handle), path, write); -} - -unique_ptr ZStdFileSystem::CreateStream() { - return make_uniq(); -} - -idx_t ZStdFileSystem::InBufferSize() { - return duckdb_zstd::ZSTD_DStreamInSize(); -} - -idx_t ZStdFileSystem::OutBufferSize() { - return duckdb_zstd::ZSTD_DStreamOutSize(); -} - -int64_t ZStdFileSystem::DefaultCompressionLevel() { - return duckdb_zstd::ZSTD_defaultCLevel(); -} - -int64_t ZStdFileSystem::MinimumCompressionLevel() { - return duckdb_zstd::ZSTD_minCLevel(); -} - -int64_t ZStdFileSystem::MaximumCompressionLevel() { - return duckdb_zstd::ZSTD_maxCLevel(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog.cpp b/src/duckdb/src/catalog/catalog.cpp deleted file mode 100644 index 45047f185..000000000 --- a/src/duckdb/src/catalog/catalog.cpp +++ /dev/null @@ -1,1091 +0,0 @@ -#include "duckdb/catalog/catalog.hpp" - -#include "duckdb/catalog/catalog_search_path.hpp" -#include "duckdb/catalog/catalog_entry/list.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_set.hpp" -#include "duckdb/catalog/default/default_schemas.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" -#include "duckdb/parser/parsed_data/create_collation_info.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" -#include "duckdb/parser/parsed_data/create_index_info.hpp" -#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" -#include "duckdb/parser/parsed_data/create_secret_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/parsed_data/drop_info.hpp" -#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/catalog/default/default_types.hpp" -#include "duckdb/main/extension_entries.hpp" -#include "duckdb/main/extension/generated_extension_loader.hpp" -#include "duckdb/main/connection.hpp" 
-#include "duckdb/main/attached_database.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/catalog/similar_catalog_entry.hpp" -#include "duckdb/storage/database_size.hpp" -#include - -namespace duckdb { - -Catalog::Catalog(AttachedDatabase &db) : db(db) { -} - -Catalog::~Catalog() { -} - -DatabaseInstance &Catalog::GetDatabase() { - return db.GetDatabase(); -} - -AttachedDatabase &Catalog::GetAttached() { - return db; -} - -const AttachedDatabase &Catalog::GetAttached() const { - return db; -} - -const string &Catalog::GetName() const { - return GetAttached().GetName(); -} - -idx_t Catalog::GetOid() { - return GetAttached().oid; -} - -Catalog &Catalog::GetSystemCatalog(ClientContext &context) { - return Catalog::GetSystemCatalog(*context.db); -} - -const string &GetDefaultCatalog(CatalogEntryRetriever &retriever) { - return DatabaseManager::GetDefaultDatabase(retriever.GetContext()); -} - -optional_ptr Catalog::GetCatalogEntry(CatalogEntryRetriever &retriever, const string &catalog_name) { - auto &context = retriever.GetContext(); - auto &db_manager = DatabaseManager::Get(context); - if (catalog_name == TEMP_CATALOG) { - return &ClientData::Get(context).temporary_objects->GetCatalog(); - } - if (catalog_name == SYSTEM_CATALOG) { - return &GetSystemCatalog(context); - } - auto entry = - db_manager.GetDatabase(context, IsInvalidCatalog(catalog_name) ? GetDefaultCatalog(retriever) : catalog_name); - if (!entry) { - return nullptr; - } - return &entry->GetCatalog(); -} - -optional_ptr Catalog::GetCatalogEntry(ClientContext &context, const string &catalog_name) { - CatalogEntryRetriever entry_retriever(context); - return GetCatalogEntry(entry_retriever, catalog_name); -} - -Catalog &Catalog::GetCatalog(CatalogEntryRetriever &retriever, const string &catalog_name) { - auto catalog = Catalog::GetCatalogEntry(retriever, catalog_name); - if (!catalog) { - throw BinderException("Catalog \"%s\" does not exist!", catalog_name); - } - return *catalog; -} - -Catalog &Catalog::GetCatalog(ClientContext &context, const string &catalog_name) { - CatalogEntryRetriever entry_retriever(context); - return GetCatalog(entry_retriever, catalog_name); -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateSchema(ClientContext &context, CreateSchemaInfo &info) { - return CreateSchema(GetCatalogTransaction(context), info); -} - -CatalogTransaction Catalog::GetCatalogTransaction(ClientContext &context) { - return CatalogTransaction(*this, context); -} - -//===--------------------------------------------------------------------===// -// Table -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateTable(ClientContext &context, BoundCreateTableInfo &info) { - return CreateTable(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateTable(ClientContext &context, unique_ptr info) { - auto binder = Binder::CreateBinder(context); - auto bound_info = binder->BindCreateTableInfo(std::move(info)); - return CreateTable(context, *bound_info); -} - -optional_ptr Catalog::CreateTable(CatalogTransaction transaction, SchemaCatalogEntry &schema, - BoundCreateTableInfo &info) { - return schema.CreateTable(transaction, info); -} - -optional_ptr Catalog::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { - auto &schema = 
GetSchema(transaction, info.base->schema); - return CreateTable(transaction, schema, info); -} - -//===--------------------------------------------------------------------===// -// View -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateView(transaction, schema, info); -} - -optional_ptr Catalog::CreateView(ClientContext &context, CreateViewInfo &info) { - return CreateView(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateView(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateViewInfo &info) { - return schema.CreateView(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Sequence -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateSequence(transaction, schema, info); -} - -optional_ptr Catalog::CreateSequence(ClientContext &context, CreateSequenceInfo &info) { - return CreateSequence(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateSequence(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateSequenceInfo &info) { - return schema.CreateSequence(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Type -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateType(transaction, schema, info); -} - -optional_ptr Catalog::CreateType(ClientContext &context, CreateTypeInfo &info) { - return CreateType(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateType(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateTypeInfo &info) { - return schema.CreateType(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Table Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, CreateTableFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateTableFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreateTableFunction(ClientContext &context, CreateTableFunctionInfo &info) { - return CreateTableFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateTableFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateTableFunctionInfo &info) { - return schema.CreateTableFunction(transaction, info); -} - -optional_ptr Catalog::CreateTableFunction(ClientContext &context, - optional_ptr info) { - return CreateTableFunction(context, *info); -} - -//===--------------------------------------------------------------------===// -// Copy Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateCopyFunction(CatalogTransaction transaction, CreateCopyFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateCopyFunction(transaction, schema, info); -} - -optional_ptr 
Catalog::CreateCopyFunction(ClientContext &context, CreateCopyFunctionInfo &info) { - return CreateCopyFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateCopyFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateCopyFunctionInfo &info) { - return schema.CreateCopyFunction(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Pragma Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreatePragmaFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreatePragmaFunction(ClientContext &context, CreatePragmaFunctionInfo &info) { - return CreatePragmaFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreatePragmaFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreatePragmaFunctionInfo &info) { - return schema.CreatePragmaFunction(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Function -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateFunction(transaction, schema, info); -} - -optional_ptr Catalog::CreateFunction(ClientContext &context, CreateFunctionInfo &info) { - return CreateFunction(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateFunction(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateFunctionInfo &info) { - return schema.CreateFunction(transaction, info); -} - -optional_ptr Catalog::AddFunction(ClientContext &context, CreateFunctionInfo &info) { - info.on_conflict = OnCreateConflict::ALTER_ON_CONFLICT; - return CreateFunction(context, info); -} - -//===--------------------------------------------------------------------===// -// Collation -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - return CreateCollation(transaction, schema, info); -} - -optional_ptr Catalog::CreateCollation(ClientContext &context, CreateCollationInfo &info) { - return CreateCollation(GetCatalogTransaction(context), info); -} - -optional_ptr Catalog::CreateCollation(CatalogTransaction transaction, SchemaCatalogEntry &schema, - CreateCollationInfo &info) { - return schema.CreateCollation(transaction, info); -} - -//===--------------------------------------------------------------------===// -// Index -//===--------------------------------------------------------------------===// -optional_ptr Catalog::CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info) { - auto &schema = GetSchema(transaction, info.schema); - auto &table = schema.GetEntry(transaction, CatalogType::TABLE_ENTRY, info.table)->Cast(); - return schema.CreateIndex(transaction, info, table); -} - -optional_ptr Catalog::CreateIndex(ClientContext &context, CreateIndexInfo &info) { - return CreateIndex(GetCatalogTransaction(context), info); -} - -unique_ptr Catalog::BindAlterAddIndex(Binder &binder, TableCatalogEntry &table_entry, - unique_ptr plan, - unique_ptr create_info, - unique_ptr 
alter_info) { - throw NotImplementedException("BindAlterAddIndex not supported by this catalog"); -} - -//===--------------------------------------------------------------------===// -// Lookup Structures -//===--------------------------------------------------------------------===// -struct CatalogLookup { - CatalogLookup(Catalog &catalog, string schema_p, string name_p) - : catalog(catalog), schema(std::move(schema_p)), name(std::move(name_p)) { - } - - Catalog &catalog; - string schema; - string name; -}; - -//===--------------------------------------------------------------------===// -// Generic -//===--------------------------------------------------------------------===// -void Catalog::DropEntry(ClientContext &context, DropInfo &info) { - if (info.type == CatalogType::SCHEMA_ENTRY) { - // DROP SCHEMA - DropSchema(context, info); - return; - } - - CatalogEntryRetriever retriever(context); - auto lookup = LookupEntry(retriever, info.type, info.schema, info.name, info.if_not_found); - if (!lookup.Found()) { - return; - } - - lookup.schema->DropEntry(context, info); -} - -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &name, QueryErrorContext error_context) { - return *Catalog::GetSchema(context, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetSchema(ClientContext &context, const string &schema_name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - return GetSchema(GetCatalogTransaction(context), schema_name, if_not_found, error_context); -} - -SchemaCatalogEntry &Catalog::GetSchema(ClientContext &context, const string &catalog_name, const string &schema_name, - QueryErrorContext error_context) { - return *Catalog::GetSchema(context, catalog_name, schema_name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -SchemaCatalogEntry &Catalog::GetSchema(CatalogTransaction transaction, const string &name, - QueryErrorContext error_context) { - return *GetSchema(transaction, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -//===--------------------------------------------------------------------===// -// Lookup -//===--------------------------------------------------------------------===// -vector Catalog::SimilarEntriesInSchemas(ClientContext &context, const string &entry_name, - CatalogType type, - const reference_set_t &schemas) { - vector results; - for (auto schema_ref : schemas) { - auto &schema = schema_ref.get(); - auto transaction = schema.catalog.GetCatalogTransaction(context); - auto entry = schema.GetSimilarEntry(transaction, type, entry_name); - if (!entry.Found()) { - // no similar entry found - continue; - } - if (results.empty() || results[0].score <= entry.score) { - if (!results.empty() && results[0].score < entry.score) { - results.clear(); - } - - results.push_back(entry); - results.back().schema = &schema; - } - } - return results; -} - -vector GetCatalogEntries(CatalogEntryRetriever &retriever, const string &catalog, - const string &schema) { - auto &context = retriever.GetContext(); - vector entries; - auto &search_path = retriever.GetSearchPath(); - if (IsInvalidCatalog(catalog) && IsInvalidSchema(schema)) { - // no catalog or schema provided - scan the entire search path - entries = search_path.Get(); - } else if (IsInvalidCatalog(catalog)) { - auto catalogs = search_path.GetCatalogsForSchema(schema); - for (auto &catalog_name : catalogs) { - entries.emplace_back(catalog_name, schema); - } - if (entries.empty()) { - auto &default_entry = search_path.GetDefault(); 
- if (!IsInvalidCatalog(default_entry.catalog)) { - entries.emplace_back(default_entry.catalog, schema); - } else { - entries.emplace_back(DatabaseManager::GetDefaultDatabase(context), schema); - } - } - } else if (IsInvalidSchema(schema)) { - auto schemas = search_path.GetSchemasForCatalog(catalog); - for (auto &schema_name : schemas) { - entries.emplace_back(catalog, schema_name); - } - if (entries.empty()) { - entries.emplace_back(catalog, DEFAULT_SCHEMA); - } - } else { - // specific catalog and schema provided - entries.emplace_back(catalog, schema); - } - return entries; -} - -void FindMinimalQualification(CatalogEntryRetriever &retriever, const string &catalog_name, const string &schema_name, - bool &qualify_database, bool &qualify_schema) { - // check if we can qualify ONLY the schema - bool found = false; - auto entries = GetCatalogEntries(retriever, INVALID_CATALOG, schema_name); - for (auto &entry : entries) { - if (entry.catalog == catalog_name && entry.schema == schema_name) { - found = true; - break; - } - } - if (found) { - qualify_database = false; - qualify_schema = true; - return; - } - // check if we can qualify ONLY the catalog - found = false; - entries = GetCatalogEntries(retriever, catalog_name, INVALID_SCHEMA); - for (auto &entry : entries) { - if (entry.catalog == catalog_name && entry.schema == schema_name) { - found = true; - break; - } - } - if (found) { - qualify_database = true; - qualify_schema = false; - return; - } - // need to qualify both catalog and schema - qualify_database = true; - qualify_schema = true; -} - -bool Catalog::TryAutoLoad(ClientContext &context, const string &original_name) noexcept { - string extension_name = ExtensionHelper::ApplyExtensionAlias(original_name); - if (context.db->ExtensionIsLoaded(extension_name)) { - return true; - } -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(context); - if (!dbconfig.options.autoload_known_extensions) { - return false; - } - try { - if (ExtensionHelper::CanAutoloadExtension(extension_name)) { - return ExtensionHelper::TryAutoLoadExtension(context, extension_name); - } - } catch (...) 
{ - return false; - } -#endif - return false; -} - -void Catalog::AutoloadExtensionByConfigName(ClientContext &context, const string &configuration_name) { -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(context); - if (dbconfig.options.autoload_known_extensions) { - auto extension_name = ExtensionHelper::FindExtensionInEntries(configuration_name, EXTENSION_SETTINGS); - if (ExtensionHelper::CanAutoloadExtension(extension_name)) { - ExtensionHelper::AutoLoadExtension(context, extension_name); - return; - } - } -#endif - - throw Catalog::UnrecognizedConfigurationError(context, configuration_name); -} - -static bool IsAutoloadableFunction(CatalogType type) { - return (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || - type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY); -} - -bool IsTableFunction(CatalogType type) { - switch (type) { - case CatalogType::TABLE_FUNCTION_ENTRY: - case CatalogType::TABLE_MACRO_ENTRY: - case CatalogType::PRAGMA_FUNCTION_ENTRY: - return true; - default: - return false; - } -} - -bool IsScalarFunction(CatalogType type) { - switch (type) { - case CatalogType::SCALAR_FUNCTION_ENTRY: - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - case CatalogType::MACRO_ENTRY: - return true; - default: - return false; - } -} - -static bool CompareCatalogTypes(CatalogType type_a, CatalogType type_b) { - if (type_a == type_b) { - // Types are same - return true; - } - if (IsScalarFunction(type_a) && IsScalarFunction(type_b)) { - return true; - } - if (IsTableFunction(type_a) && IsTableFunction(type_b)) { - return true; - } - return false; -} - -bool Catalog::AutoLoadExtensionByCatalogEntry(DatabaseInstance &db, CatalogType type, const string &entry_name) { -#ifndef DUCKDB_DISABLE_EXTENSION_LOAD - auto &dbconfig = DBConfig::GetConfig(db); - if (dbconfig.options.autoload_known_extensions) { - string extension_name; - if (IsAutoloadableFunction(type)) { - auto lookup_result = ExtensionHelper::FindExtensionInFunctionEntries(entry_name, EXTENSION_FUNCTIONS); - if (lookup_result.empty()) { - return false; - } - for (auto &function : lookup_result) { - auto function_type = function.second; - // FIXME: what if there are two functions with the same name, from different extensions? 
- if (CompareCatalogTypes(type, function_type)) { - extension_name = function.first; - break; - } - } - } else if (type == CatalogType::COPY_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); - } else if (type == CatalogType::TYPE_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); - } else if (type == CatalogType::COLLATION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); - } - - if (!extension_name.empty() && ExtensionHelper::CanAutoloadExtension(extension_name)) { - ExtensionHelper::AutoLoadExtension(db, extension_name); - return true; - } - } -#endif - - return false; -} - -CatalogException Catalog::UnrecognizedConfigurationError(ClientContext &context, const string &name) { - // check if the setting exists in any extensions - auto extension_name = ExtensionHelper::FindExtensionInEntries(name, EXTENSION_SETTINGS); - if (!extension_name.empty()) { - auto error_message = "Setting with name \"" + name + "\" is not in the catalog, but it exists in the " + - extension_name + " extension."; - error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); - return CatalogException(error_message); - } - // the setting is not in an extension - // get a list of all options - vector potential_names = DBConfig::GetOptionNames(); - for (auto &entry : DBConfig::GetConfig(context).extension_parameters) { - potential_names.push_back(entry.first); - } - throw CatalogException::MissingEntry("configuration parameter", name, potential_names); -} - -CatalogException Catalog::CreateMissingEntryException(CatalogEntryRetriever &retriever, const string &entry_name, - CatalogType type, - const reference_set_t &schemas, - QueryErrorContext error_context) { - auto &context = retriever.GetContext(); - auto entries = SimilarEntriesInSchemas(context, entry_name, type, schemas); - - reference_set_t unseen_schemas; - auto &db_manager = DatabaseManager::Get(context); - auto databases = db_manager.GetDatabases(context); - auto &config = DBConfig::GetConfig(context); - - auto max_schema_count = config.GetSetting(context); - for (auto database : databases) { - if (unseen_schemas.size() >= max_schema_count) { - break; - } - auto &catalog = database.get().GetCatalog(); - auto current_schemas = catalog.GetAllSchemas(context); - for (auto ¤t_schema : current_schemas) { - if (unseen_schemas.size() >= max_schema_count) { - break; - } - unseen_schemas.insert(current_schema.get()); - } - } - // check if the entry exists in any extension - string extension_name; - if (type == CatalogType::TABLE_FUNCTION_ENTRY || type == CatalogType::SCALAR_FUNCTION_ENTRY || - type == CatalogType::AGGREGATE_FUNCTION_ENTRY || type == CatalogType::PRAGMA_FUNCTION_ENTRY) { - auto lookup_result = ExtensionHelper::FindExtensionInFunctionEntries(entry_name, EXTENSION_FUNCTIONS); - do { - if (lookup_result.empty()) { - break; - } - vector other_types; - string extension_for_error; - for (auto &function : lookup_result) { - auto function_type = function.second; - if (CompareCatalogTypes(type, function_type)) { - extension_name = function.first; - break; - } - extension_for_error = function.first; - other_types.push_back(CatalogTypeToString(function_type)); - } - if (!extension_name.empty()) { - break; - } - if (other_types.size() == 1) { - auto &function_type = other_types[0]; - auto error = - CatalogException("%s with name \"%s\" is not in the 
catalog, a function by this name exists " - "in the %s extension, but it's of a different type, namely %s", - CatalogTypeToString(type), entry_name, extension_for_error, function_type); - return error; - } else { - D_ASSERT(!other_types.empty()); - auto list_of_types = StringUtil::Join(other_types, ", "); - auto error = - CatalogException("%s with name \"%s\" is not in the catalog, functions with this name exist " - "in the %s extension, but they are of different types, namely %s", - CatalogTypeToString(type), entry_name, extension_for_error, list_of_types); - return error; - } - } while (false); - } else if (type == CatalogType::TYPE_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_TYPES); - } else if (type == CatalogType::COPY_FUNCTION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COPY_FUNCTIONS); - } else if (type == CatalogType::COLLATION_ENTRY) { - extension_name = ExtensionHelper::FindExtensionInEntries(entry_name, EXTENSION_COLLATIONS); - } - - // if we found an extension that can handle this catalog entry, create an error hinting the user - if (!extension_name.empty()) { - auto error_message = CatalogTypeToString(type) + " with name \"" + entry_name + - "\" is not in the catalog, but it exists in the " + extension_name + " extension."; - error_message = ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, extension_name); - return CatalogException(error_message); - } - - // entries in other schemas get a penalty - // however, if there is an exact match in another schema, we will always show it - static constexpr const double UNSEEN_PENALTY = 0.2; - auto unseen_entries = SimilarEntriesInSchemas(context, entry_name, type, unseen_schemas); - vector suggestions; - if (!unseen_entries.empty() && (unseen_entries[0].score == 1.0 || unseen_entries[0].score - UNSEEN_PENALTY > - (entries.empty() ? 
0.0 : entries[0].score))) { - // the closest matching entry requires qualification as it is not in the default search path - // check how to minimally qualify this entry - for (auto &unseen_entry : unseen_entries) { - auto catalog_name = unseen_entry.schema->catalog.GetName(); - auto schema_name = unseen_entry.schema->name; - bool qualify_database; - bool qualify_schema; - FindMinimalQualification(retriever, catalog_name, schema_name, qualify_database, qualify_schema); - suggestions.push_back(unseen_entry.GetQualifiedName(qualify_database, qualify_schema)); - } - } else if (!entries.empty()) { - for (auto &entry : entries) { - suggestions.push_back(entry.name); - } - } - - string did_you_mean; - std::sort(suggestions.begin(), suggestions.end()); - if (suggestions.size() > 2) { - auto last = suggestions.back(); - suggestions.pop_back(); - did_you_mean = StringUtil::Join(suggestions, ", ") + ", or " + last; - } else { - did_you_mean = StringUtil::Join(suggestions, " or "); - } - - return CatalogException::MissingEntry(type, entry_name, did_you_mean, error_context); -} - -CatalogEntryLookup Catalog::TryLookupEntryInternal(CatalogTransaction transaction, CatalogType type, - const string &schema, const string &name) { - auto schema_entry = GetSchema(transaction, schema, OnEntryNotFound::RETURN_NULL); - if (!schema_entry) { - return {nullptr, nullptr, ErrorData()}; - } - auto entry = schema_entry->GetEntry(transaction, type, name); - if (!entry) { - return {schema_entry, nullptr, ErrorData()}; - } - return {schema_entry, entry, ErrorData()}; -} - -CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, CatalogType type, const string &schema, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto &context = retriever.GetContext(); - reference_set_t schemas; - if (IsInvalidSchema(schema)) { - // try all schemas for this catalog - auto entries = GetCatalogEntries(retriever, GetName(), INVALID_SCHEMA); - for (auto &entry : entries) { - auto &candidate_schema = entry.schema; - auto transaction = GetCatalogTransaction(context); - auto result = TryLookupEntryInternal(transaction, type, candidate_schema, name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - } else { - auto transaction = GetCatalogTransaction(context); - auto result = TryLookupEntryInternal(transaction, type, schema, name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - return {nullptr, nullptr, ErrorData()}; - } else { - auto except = CreateMissingEntryException(retriever, name, type, schemas, error_context); - return {nullptr, nullptr, ErrorData(except)}; - } -} - -CatalogEntryLookup Catalog::LookupEntry(CatalogEntryRetriever &retriever, CatalogType type, const string &schema, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto res = TryLookupEntry(retriever, type, schema, name, if_not_found, error_context); - - if (res.error.HasError()) { - res.error.Throw(); - } - - return res; -} - -CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, vector &lookups, - CatalogType type, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto &context = retriever.GetContext(); - reference_set_t schemas; - for (auto &lookup : lookups) { - auto transaction = lookup.catalog.GetCatalogTransaction(context); 
- auto result = lookup.catalog.TryLookupEntryInternal(transaction, type, lookup.schema, lookup.name); - if (result.Found()) { - return result; - } - if (result.schema) { - schemas.insert(*result.schema); - } - } - - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - return {nullptr, nullptr, ErrorData()}; - } else { - auto except = CreateMissingEntryException(retriever, name, type, schemas, error_context); - return {nullptr, nullptr, ErrorData(except)}; - } -} - -CatalogEntryLookup Catalog::TryLookupDefaultTable(CatalogEntryRetriever &retriever, CatalogType type, - const string &catalog, const string &schema, const string &name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - // Default tables of catalogs can only be accessed by the catalog name directly - if (!schema.empty() || !catalog.empty()) { - return {nullptr, nullptr, ErrorData()}; - } - - vector catalog_by_name_lookups; - auto catalog_by_name = GetCatalogEntry(retriever, name); - if (catalog_by_name && catalog_by_name->HasDefaultTable()) { - catalog_by_name_lookups.emplace_back(*catalog_by_name, catalog_by_name->GetDefaultTableSchema(), - catalog_by_name->GetDefaultTable()); - } - - return TryLookupEntry(retriever, catalog_by_name_lookups, type, name, if_not_found, error_context); -} - -static void ThrowDefaultTableAmbiguityException(CatalogEntryLookup &base_lookup, CatalogEntryLookup &default_table, - const string &name) { - auto entry_type = CatalogTypeToString(base_lookup.entry->type); - string fully_qualified_name_hint; - if (base_lookup.schema) { - fully_qualified_name_hint = StringUtil::Format(": '%s.%s.%s'", base_lookup.schema->catalog.GetName(), - base_lookup.schema->name, base_lookup.entry->name); - } - string fully_qualified_catalog_name_hint = StringUtil::Format( - ": '%s.%s.%s'", default_table.schema->catalog.GetName(), default_table.schema->name, default_table.entry->name); - throw CatalogException( - "Ambiguity detected for '%s': this could either refer to the '%s' '%s', or the " - "attached catalog '%s' which has a default table. 
To avoid this error, either detach the catalog and " - "reattach under a different name, or use a fully qualified name for the '%s'%s or for the Catalog " - "Default Table%s.", - name, entry_type, name, name, entry_type, fully_qualified_name_hint, fully_qualified_catalog_name_hint); -} - -CatalogEntryLookup Catalog::TryLookupEntry(CatalogEntryRetriever &retriever, CatalogType type, const string &catalog, - const string &schema, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto entries = GetCatalogEntries(retriever, catalog, schema); - vector lookups; - vector final_lookups; - lookups.reserve(entries.size()); - for (auto &entry : entries) { - optional_ptr catalog_entry; - if (if_not_found == OnEntryNotFound::RETURN_NULL) { - catalog_entry = Catalog::GetCatalogEntry(retriever, entry.catalog); - } else { - catalog_entry = &Catalog::GetCatalog(retriever, entry.catalog); - } - if (!catalog_entry) { - return {nullptr, nullptr, ErrorData()}; - } - D_ASSERT(catalog_entry); - auto lookup_behavior = catalog_entry->CatalogTypeLookupRule(type); - if (lookup_behavior == CatalogLookupBehavior::STANDARD) { - lookups.emplace_back(*catalog_entry, entry.schema, name); - } else if (lookup_behavior == CatalogLookupBehavior::LOWER_PRIORITY) { - final_lookups.emplace_back(*catalog_entry, entry.schema, name); - } - } - - for (auto &lookup : final_lookups) { - lookups.emplace_back(std::move(lookup)); - } - - // Do the main lookup - auto lookup_result = TryLookupEntry(retriever, lookups, type, name, if_not_found, error_context); - - // Special case for tables: we do a second lookup searching for catalogs with default tables that also match this - // lookup - if (type == CatalogType::TABLE_ENTRY) { - auto lookup_result_default_table = - TryLookupDefaultTable(retriever, type, catalog, schema, name, OnEntryNotFound::RETURN_NULL, error_context); - - if (lookup_result_default_table.Found() && lookup_result.Found()) { - ThrowDefaultTableAmbiguityException(lookup_result, lookup_result_default_table, name); - } - - if (lookup_result_default_table.Found()) { - return lookup_result_default_table; - } - } - - return lookup_result; -} - -optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, CatalogType type, - const string &schema_name, const string &name, - OnEntryNotFound if_not_found, QueryErrorContext error_context) { - auto lookup_entry = TryLookupEntry(retriever, type, schema_name, name, if_not_found, error_context); - - // Try autoloading extension to resolve lookup - if (!lookup_entry.Found()) { - if (AutoLoadExtensionByCatalogEntry(*retriever.GetContext().db, type, name)) { - lookup_entry = TryLookupEntry(retriever, type, schema_name, name, if_not_found, error_context); - } - } - - if (lookup_entry.error.HasError()) { - lookup_entry.error.Throw(); - } - - return lookup_entry.entry.get(); -} - -optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema_name, - const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - CatalogEntryRetriever retriever(context); - return GetEntry(retriever, type, schema_name, name, if_not_found, error_context); -} - -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &schema, const string &name, - QueryErrorContext error_context) { - return *Catalog::GetEntry(context, type, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetEntry(CatalogEntryRetriever &retriever, CatalogType type, const 
string &catalog, - const string &schema, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto result = TryLookupEntry(retriever, type, catalog, schema, name, if_not_found, error_context); - - // Try autoloading extension to resolve lookup - if (!result.Found()) { - if (AutoLoadExtensionByCatalogEntry(*retriever.GetContext().db, type, name)) { - result = TryLookupEntry(retriever, type, catalog, schema, name, if_not_found, error_context); - } - } - - if (result.error.HasError()) { - result.error.Throw(); - } - - if (!result.Found()) { - D_ASSERT(if_not_found == OnEntryNotFound::RETURN_NULL); - return nullptr; - } - return result.entry.get(); -} -optional_ptr Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, - const string &schema, const string &name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - CatalogEntryRetriever retriever(context); - return GetEntry(retriever, type, catalog, schema, name, if_not_found, error_context); -} - -CatalogEntry &Catalog::GetEntry(ClientContext &context, CatalogType type, const string &catalog, const string &schema, - const string &name, QueryErrorContext error_context) { - return *Catalog::GetEntry(context, type, catalog, schema, name, OnEntryNotFound::THROW_EXCEPTION, error_context); -} - -optional_ptr Catalog::GetSchema(CatalogEntryRetriever &retriever, const string &catalog_name, - const string &schema_name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - auto entries = GetCatalogEntries(retriever, catalog_name, schema_name); - for (idx_t i = 0; i < entries.size(); i++) { - auto on_not_found = i + 1 == entries.size() ? if_not_found : OnEntryNotFound::RETURN_NULL; - auto &catalog = Catalog::GetCatalog(retriever, entries[i].catalog); - auto result = catalog.GetSchema(retriever.GetContext(), schema_name, on_not_found, error_context); - if (result) { - return result; - } - } - return nullptr; -} - -optional_ptr Catalog::GetSchema(ClientContext &context, const string &catalog_name, - const string &schema_name, OnEntryNotFound if_not_found, - QueryErrorContext error_context) { - CatalogEntryRetriever retriever(context); - return GetSchema(retriever, catalog_name, schema_name, if_not_found, error_context); -} - -vector> Catalog::GetSchemas(ClientContext &context) { - vector> schemas; - ScanSchemas(context, [&](SchemaCatalogEntry &entry) { schemas.push_back(entry); }); - return schemas; -} - -vector> Catalog::GetSchemas(CatalogEntryRetriever &retriever, - const string &catalog_name) { - vector> catalogs; - if (IsInvalidCatalog(catalog_name)) { - reference_set_t inserted_catalogs; - - auto &search_path = retriever.GetSearchPath(); - for (auto &entry : search_path.Get()) { - auto &catalog = Catalog::GetCatalog(retriever, entry.catalog); - if (inserted_catalogs.find(catalog) != inserted_catalogs.end()) { - continue; - } - inserted_catalogs.insert(catalog); - catalogs.push_back(catalog); - } - } else { - catalogs.push_back(Catalog::GetCatalog(retriever, catalog_name)); - } - vector> result; - for (auto catalog : catalogs) { - auto schemas = catalog.get().GetSchemas(retriever.GetContext()); - result.insert(result.end(), schemas.begin(), schemas.end()); - } - return result; -} - -vector> Catalog::GetSchemas(ClientContext &context, const string &catalog_name) { - CatalogEntryRetriever retriever(context); - return GetSchemas(retriever, catalog_name); -} - -vector> Catalog::GetAllSchemas(ClientContext &context) { - vector> result; - - auto &db_manager = 
DatabaseManager::Get(context); - auto databases = db_manager.GetDatabases(context); - for (auto database : databases) { - auto &catalog = database.get().GetCatalog(); - auto new_schemas = catalog.GetSchemas(context); - result.insert(result.end(), new_schemas.begin(), new_schemas.end()); - } - sort(result.begin(), result.end(), - [&](reference left_p, reference right_p) { - auto &left = left_p.get(); - auto &right = right_p.get(); - if (left.catalog.GetName() < right.catalog.GetName()) { - return true; - } - if (left.catalog.GetName() == right.catalog.GetName()) { - return left.name < right.name; - } - return false; - }); - - return result; -} - -void Catalog::Alter(CatalogTransaction transaction, AlterInfo &info) { - if (transaction.HasContext()) { - CatalogEntryRetriever retriever(transaction.GetContext()); - auto lookup = LookupEntry(retriever, info.GetCatalogType(), info.schema, info.name, info.if_not_found); - if (!lookup.Found()) { - return; - } - return lookup.schema->Alter(transaction, info); - } - D_ASSERT(info.if_not_found == OnEntryNotFound::THROW_EXCEPTION); - auto &schema = GetSchema(transaction, info.schema); - return schema.Alter(transaction, info); -} - -void Catalog::Alter(ClientContext &context, AlterInfo &info) { - Alter(GetCatalogTransaction(context), info); -} - -vector Catalog::GetMetadataInfo(ClientContext &context) { - return vector(); -} - -optional_ptr Catalog::GetDependencyManager() { - return nullptr; -} - -//! Whether this catalog has a default table. Catalogs with a default table can be queried by their catalog name -bool Catalog::HasDefaultTable() const { - return !default_table.empty(); -} - -void Catalog::SetDefaultTable(const string &schema, const string &name) { - default_table = name; - default_table_schema = schema; -} - -string Catalog::GetDefaultTable() const { - return default_table; -} - -string Catalog::GetDefaultTableSchema() const { - return !default_table_schema.empty() ? 
default_table_schema : DEFAULT_SCHEMA; -} - -void Catalog::Verify() { -} - -bool Catalog::IsSystemCatalog() const { - return db.IsSystem(); -} - -bool Catalog::IsTemporaryCatalog() const { - return db.IsTemporary(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry.cpp deleted file mode 100644 index d6a96d6ec..000000000 --- a/src/duckdb/src/catalog/catalog_entry.cpp +++ /dev/null @@ -1,125 +0,0 @@ -#include "duckdb/catalog/catalog_entry.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/serializer/binary_deserializer.hpp" -#include "duckdb/common/serializer/binary_serializer.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/parser/parsed_data/create_info.hpp" - -namespace duckdb { - -CatalogEntry::CatalogEntry(CatalogType type, string name_p, idx_t oid) - : oid(oid), type(type), set(nullptr), name(std::move(name_p)), deleted(false), temporary(false), internal(false), - parent(nullptr) { -} - -CatalogEntry::CatalogEntry(CatalogType type, Catalog &catalog, string name_p) - : CatalogEntry(type, std::move(name_p), catalog.GetDatabase().GetDatabaseManager().NextOid()) { -} - -CatalogEntry::~CatalogEntry() { -} - -void CatalogEntry::SetAsRoot() { -} - -// LCOV_EXCL_START -unique_ptr CatalogEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - throw InternalException("Unsupported alter type for catalog entry!"); -} - -unique_ptr CatalogEntry::AlterEntry(CatalogTransaction transaction, AlterInfo &info) { - if (!transaction.context) { - throw InternalException("Cannot AlterEntry without client context"); - } - return AlterEntry(*transaction.context, info); -} - -void CatalogEntry::UndoAlter(ClientContext &context, AlterInfo &info) { -} - -unique_ptr CatalogEntry::Copy(ClientContext &context) const { - throw InternalException("Unsupported copy type for catalog entry!"); -} - -unique_ptr CatalogEntry::GetInfo() const { - throw InternalException("Unsupported type for CatalogEntry::GetInfo!"); -} - -string CatalogEntry::ToSQL() const { - throw InternalException("Unsupported catalog type for ToSQL()"); -} - -void CatalogEntry::SetChild(unique_ptr child_p) { - child = std::move(child_p); - if (child) { - child->parent = this; - } -} - -unique_ptr CatalogEntry::TakeChild() { - if (child) { - child->parent = nullptr; - } - return std::move(child); -} - -bool CatalogEntry::HasChild() const { - return child != nullptr; -} -bool CatalogEntry::HasParent() const { - return parent != nullptr; -} - -CatalogEntry &CatalogEntry::Child() { - return *child; -} - -CatalogEntry &CatalogEntry::Parent() { - return *parent; -} - -Catalog &CatalogEntry::ParentCatalog() { - throw InternalException("CatalogEntry::ParentCatalog called on catalog entry without catalog"); -} - -const Catalog &CatalogEntry::ParentCatalog() const { - throw InternalException("CatalogEntry::ParentCatalog called on catalog entry without catalog"); -} - -SchemaCatalogEntry &CatalogEntry::ParentSchema() { - throw InternalException("CatalogEntry::ParentSchema called on catalog entry without schema"); -} - -const SchemaCatalogEntry &CatalogEntry::ParentSchema() const { - throw InternalException("CatalogEntry::ParentSchema called on catalog entry without schema"); -} -// LCOV_EXCL_STOP - -void CatalogEntry::Serialize(Serializer &serializer) const { - const auto info = GetInfo(); - info->Serialize(serializer); -} - -unique_ptr CatalogEntry::Deserialize(Deserializer &deserializer) { - return 
CreateInfo::Deserialize(deserializer); -} - -void CatalogEntry::Verify(Catalog &catalog_p) { -} - -void CatalogEntry::Rollback(CatalogEntry &prev_entry) { -} - -InCatalogEntry::InCatalogEntry(CatalogType type, Catalog &catalog, string name) - : CatalogEntry(type, catalog, std::move(name)), catalog(catalog) { -} - -InCatalogEntry::~InCatalogEntry() { -} - -void InCatalogEntry::Verify(Catalog &catalog_p) { - D_ASSERT(&catalog_p == &catalog); -} -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp b/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp deleted file mode 100644 index b8b2f0e87..000000000 --- a/src/duckdb/src/catalog/catalog_entry/column_dependency_manager.cpp +++ /dev/null @@ -1,271 +0,0 @@ -#include "duckdb/catalog/catalog_entry/column_dependency_manager.hpp" -#include "duckdb/parser/column_definition.hpp" -#include "duckdb/common/set.hpp" -#include "duckdb/common/queue.hpp" -#include "duckdb/common/exception/binder_exception.hpp" - -namespace duckdb { - -ColumnDependencyManager::ColumnDependencyManager() { -} - -ColumnDependencyManager::~ColumnDependencyManager() { -} - -void ColumnDependencyManager::AddGeneratedColumn(const ColumnDefinition &column, const ColumnList &list) { - D_ASSERT(column.Generated()); - vector referenced_columns; - column.GetListOfDependencies(referenced_columns); - vector indices; - for (auto &col : referenced_columns) { - if (!list.ColumnExists(col)) { - throw BinderException("Column \"%s\" referenced by generated column does not exist", col); - } - auto &entry = list.GetColumn(col); - indices.push_back(entry.Logical()); - } - return AddGeneratedColumn(column.Logical(), indices); -} - -void ColumnDependencyManager::AddGeneratedColumn(LogicalIndex index, const vector &indices, bool root) { - if (indices.empty()) { - return; - } - auto &list = dependents_map[index]; - // Create a link between the dependencies - for (auto &dep : indices) { - // Add this column as a dependency of the new column - list.insert(dep); - // Add the new column as a dependent of the column - dependencies_map[dep].insert(index); - // Inherit the dependencies - if (HasDependencies(dep)) { - auto &inherited_deps = dependents_map[dep]; - D_ASSERT(!inherited_deps.empty()); - for (auto &inherited_dep : inherited_deps) { - list.insert(inherited_dep); - dependencies_map[inherited_dep].insert(index); - } - } - if (!root) { - continue; - } - direct_dependencies[index].insert(dep); - } - if (!HasDependents(index)) { - return; - } - auto &dependents = dependencies_map[index]; - if (dependents.count(index)) { - throw InvalidInputException("Circular dependency encountered when resolving generated column expressions"); - } - // Also let the dependents of this generated column inherit the dependencies - for (auto &dependent : dependents) { - AddGeneratedColumn(dependent, indices, false); - } -} - -vector ColumnDependencyManager::RemoveColumn(LogicalIndex index, idx_t column_amount) { - // Always add the initial column - deleted_columns.insert(index); - - RemoveGeneratedColumn(index); - RemoveStandardColumn(index); - - // Clean up the internal list - vector new_indices = CleanupInternals(column_amount); - D_ASSERT(deleted_columns.empty()); - return new_indices; -} - -bool ColumnDependencyManager::IsDependencyOf(LogicalIndex gcol, LogicalIndex col) const { - auto entry = dependents_map.find(gcol); - if (entry == dependents_map.end()) { - return false; - } - auto &list = entry->second; - return list.count(col); -} - -bool 
ColumnDependencyManager::HasDependencies(LogicalIndex index) const { - auto entry = dependents_map.find(index); - if (entry == dependents_map.end()) { - return false; - } - return true; -} - -const logical_index_set_t &ColumnDependencyManager::GetDependencies(LogicalIndex index) const { - auto entry = dependents_map.find(index); - D_ASSERT(entry != dependents_map.end()); - return entry->second; -} - -bool ColumnDependencyManager::HasDependents(LogicalIndex index) const { - auto entry = dependencies_map.find(index); - if (entry == dependencies_map.end()) { - return false; - } - return true; -} - -const logical_index_set_t &ColumnDependencyManager::GetDependents(LogicalIndex index) const { - auto entry = dependencies_map.find(index); - D_ASSERT(entry != dependencies_map.end()); - return entry->second; -} - -void ColumnDependencyManager::RemoveStandardColumn(LogicalIndex index) { - if (!HasDependents(index)) { - return; - } - auto dependents = dependencies_map[index]; - for (auto &gcol : dependents) { - // If index is a direct dependency of gcol, remove it from the list - if (direct_dependencies.find(gcol) != direct_dependencies.end()) { - direct_dependencies[gcol].erase(index); - } - RemoveGeneratedColumn(gcol); - } - // Remove this column from the dependencies map - dependencies_map.erase(index); -} - -void ColumnDependencyManager::RemoveGeneratedColumn(LogicalIndex index) { - deleted_columns.insert(index); - if (!HasDependencies(index)) { - return; - } - auto &dependencies = dependents_map[index]; - for (auto &col : dependencies) { - // Remove this generated column from the list of this column - auto &col_dependents = dependencies_map[col]; - D_ASSERT(col_dependents.count(index)); - col_dependents.erase(index); - // If the resulting list is empty, remove the column from the dependencies map altogether - if (col_dependents.empty()) { - dependencies_map.erase(col); - } - } - // Remove this column from the dependents_map map - dependents_map.erase(index); -} - -void ColumnDependencyManager::AdjustSingle(LogicalIndex idx, idx_t offset) { - D_ASSERT(idx.index >= offset); - LogicalIndex new_idx = LogicalIndex(idx.index - offset); - // Adjust this index in the dependents of this column - bool has_dependents = HasDependents(idx); - bool has_dependencies = HasDependencies(idx); - - if (has_dependents) { - auto &dependents = GetDependents(idx); - for (auto &dep : dependents) { - auto &dep_dependencies = dependents_map[dep]; - dep_dependencies.erase(idx); - D_ASSERT(!dep_dependencies.count(new_idx)); - dep_dependencies.insert(new_idx); - } - } - if (has_dependencies) { - auto &dependencies = GetDependencies(idx); - for (auto &dep : dependencies) { - auto &dep_dependents = dependencies_map[dep]; - dep_dependents.erase(idx); - D_ASSERT(!dep_dependents.count(new_idx)); - dep_dependents.insert(new_idx); - } - } - if (has_dependents) { - D_ASSERT(!dependencies_map.count(new_idx)); - dependencies_map[new_idx] = std::move(dependencies_map[idx]); - dependencies_map.erase(idx); - } - if (has_dependencies) { - D_ASSERT(!dependents_map.count(new_idx)); - dependents_map[new_idx] = std::move(dependents_map[idx]); - dependents_map.erase(idx); - } -} - -vector ColumnDependencyManager::CleanupInternals(idx_t column_amount) { - vector to_adjust; - D_ASSERT(!deleted_columns.empty()); - // Get the lowest index that was deleted - vector new_indices(column_amount, LogicalIndex(DConstants::INVALID_INDEX)); - idx_t threshold = deleted_columns.begin()->index; - - idx_t offset = 0; - for (idx_t i = 0; i < column_amount; i++) 
{ - auto current_index = LogicalIndex(i); - auto new_index = LogicalIndex(i - offset); - new_indices[i] = new_index; - if (deleted_columns.count(current_index)) { - offset++; - continue; - } - if (i > threshold && (HasDependencies(current_index) || HasDependents(current_index))) { - to_adjust.push_back(current_index); - } - } - - // Adjust all indices inside the dependency managers internal mappings - for (auto &col : to_adjust) { - auto offset = col.index - new_indices[col.index].index; - AdjustSingle(col, offset); - } - deleted_columns.clear(); - return new_indices; -} - -stack ColumnDependencyManager::GetBindOrder(const ColumnList &columns) { - stack bind_order; - queue to_visit; - logical_index_set_t visited; - - for (auto &entry : direct_dependencies) { - auto dependent = entry.first; - //! Skip the dependents that are also dependencies - if (dependencies_map.find(dependent) != dependencies_map.end()) { - continue; - } - bind_order.push(dependent); - visited.insert(dependent); - for (auto &dependency : direct_dependencies[dependent]) { - to_visit.push(dependency); - } - } - - while (!to_visit.empty()) { - auto column = to_visit.front(); - to_visit.pop(); - - //! If this column does not have dependencies, the queue stops getting filled - if (direct_dependencies.find(column) == direct_dependencies.end()) { - continue; - } - bind_order.push(column); - visited.insert(column); - - for (auto &dependency : direct_dependencies[column]) { - to_visit.push(dependency); - } - } - - // Add generated columns that have no dependencies, but still might need to have their type resolved - for (auto &col : columns.Logical()) { - // Not a generated column - if (!col.Generated()) { - continue; - } - // Already added to the bind_order stack - if (visited.count(col.Logical())) { - continue; - } - bind_order.push(col.Logical()); - } - - return bind_order; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp deleted file mode 100644 index 25544a343..000000000 --- a/src/duckdb/src/catalog/catalog_entry/copy_function_catalog_entry.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" - -namespace duckdb { - -CopyFunctionCatalogEntry::CopyFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateCopyFunctionInfo &info) - : StandardEntry(CatalogType::COPY_FUNCTION_ENTRY, schema, catalog, info.name), function(info.function) { -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_dependent_entry.cpp b/src/duckdb/src/catalog/catalog_entry/dependency/dependency_dependent_entry.cpp deleted file mode 100644 index 5baaed41a..000000000 --- a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_dependent_entry.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "duckdb/catalog/catalog_entry/dependency/dependency_dependent_entry.hpp" - -namespace duckdb { - -DependencyDependentEntry::DependencyDependentEntry(Catalog &catalog, const DependencyInfo &info) - : DependencyEntry(catalog, DependencyEntryType::DEPENDENT, - MangledDependencyName(DependencyManager::MangleName(info.subject.entry), - DependencyManager::MangleName(info.dependent.entry)), - info) { -} - -const MangledEntryName &DependencyDependentEntry::EntryMangledName() const { - return dependent_name; -} - -const CatalogEntryInfo 
&DependencyDependentEntry::EntryInfo() const { - return dependent.entry; -} - -const MangledEntryName &DependencyDependentEntry::SourceMangledName() const { - return subject_name; -} - -const CatalogEntryInfo &DependencyDependentEntry::SourceInfo() const { - return subject.entry; -} - -DependencyDependentEntry::~DependencyDependentEntry() { -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_entry.cpp b/src/duckdb/src/catalog/catalog_entry/dependency/dependency_entry.cpp deleted file mode 100644 index ab425bd98..000000000 --- a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_entry.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -DependencyEntry::DependencyEntry(Catalog &catalog, DependencyEntryType side, const MangledDependencyName &name, - const DependencyInfo &info) - : InCatalogEntry(CatalogType::DEPENDENCY_ENTRY, catalog, name.name), - dependent_name(DependencyManager::MangleName(info.dependent.entry)), - subject_name(DependencyManager::MangleName(info.subject.entry)), dependent(info.dependent), subject(info.subject), - side(side) { - D_ASSERT(info.dependent.entry.type != CatalogType::DEPENDENCY_ENTRY); - D_ASSERT(info.subject.entry.type != CatalogType::DEPENDENCY_ENTRY); - if (catalog.IsTemporaryCatalog()) { - temporary = true; - } -} - -const MangledEntryName &DependencyEntry::SubjectMangledName() const { - return subject_name; -} - -const DependencySubject &DependencyEntry::Subject() const { - return subject; -} - -const MangledEntryName &DependencyEntry::DependentMangledName() const { - return dependent_name; -} - -const DependencyDependent &DependencyEntry::Dependent() const { - return dependent; -} - -DependencyEntry::~DependencyEntry() { -} - -DependencyEntryType DependencyEntry::Side() const { - return side; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_subject_entry.cpp b/src/duckdb/src/catalog/catalog_entry/dependency/dependency_subject_entry.cpp deleted file mode 100644 index eb9c7f63a..000000000 --- a/src/duckdb/src/catalog/catalog_entry/dependency/dependency_subject_entry.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "duckdb/catalog/catalog_entry/dependency/dependency_subject_entry.hpp" - -namespace duckdb { - -DependencySubjectEntry::DependencySubjectEntry(Catalog &catalog, const DependencyInfo &info) - : DependencyEntry(catalog, DependencyEntryType::SUBJECT, - MangledDependencyName(DependencyManager::MangleName(info.dependent.entry), - DependencyManager::MangleName(info.subject.entry)), - info) { -} - -const MangledEntryName &DependencySubjectEntry::EntryMangledName() const { - return subject_name; -} - -const CatalogEntryInfo &DependencySubjectEntry::EntryInfo() const { - return subject.entry; -} - -const MangledEntryName &DependencySubjectEntry::SourceMangledName() const { - return dependent_name; -} - -const CatalogEntryInfo &DependencySubjectEntry::SourceInfo() const { - return dependent.entry; -} - -DependencySubjectEntry::~DependencySubjectEntry() { -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp deleted file mode 100644 index c70984e53..000000000 --- a/src/duckdb/src/catalog/catalog_entry/duck_index_entry.cpp 
+++ /dev/null @@ -1,65 +0,0 @@ -#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" - -#include "duckdb/storage/data_table.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" - -namespace duckdb { - -IndexDataTableInfo::IndexDataTableInfo(shared_ptr info_p, const string &index_name_p) - : info(std::move(info_p)), index_name(index_name_p) { -} - -void DuckIndexEntry::Rollback(CatalogEntry &) { - if (!info) { - return; - } - if (!info->info) { - return; - } - info->info->GetIndexes().RemoveIndex(name); -} - -DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, - TableCatalogEntry &table_p) - : IndexCatalogEntry(catalog, schema, create_info), initial_index_size(0) { - - auto &table = table_p.Cast(); - auto &storage = table.GetStorage(); - info = make_shared_ptr(storage.GetDataTableInfo(), name); -} - -DuckIndexEntry::DuckIndexEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &create_info, - shared_ptr storage_info) - : IndexCatalogEntry(catalog, schema, create_info), info(std::move(storage_info)), initial_index_size(0) { -} - -unique_ptr DuckIndexEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - - auto result = make_uniq(catalog, schema, cast_info, info); - result->initial_index_size = initial_index_size; - - return std::move(result); -} - -string DuckIndexEntry::GetSchemaName() const { - return GetDataTableInfo().GetSchemaName(); -} - -string DuckIndexEntry::GetTableName() const { - return GetDataTableInfo().GetTableName(); -} - -DataTableInfo &DuckIndexEntry::GetDataTableInfo() const { - return *info->info; -} - -void DuckIndexEntry::CommitDrop() { - D_ASSERT(info); - auto &indexes = GetDataTableInfo().GetIndexes(); - indexes.CommitDrop(name); - indexes.RemoveIndex(name); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp deleted file mode 100644 index 9f07d4acf..000000000 --- a/src/duckdb/src/catalog/catalog_entry/duck_schema_entry.cpp +++ /dev/null @@ -1,425 +0,0 @@ -#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" - -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/copy_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/catalog/default/default_functions.hpp" -#include "duckdb/catalog/default/default_table_functions.hpp" -#include "duckdb/catalog/default/default_types.hpp" -#include "duckdb/catalog/default/default_views.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/main/attached_database.hpp" -#include 
"duckdb/main/database.hpp" -#include "duckdb/parser/constraints/foreign_key_constraint.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/parser/parsed_data/create_collation_info.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" -#include "duckdb/parser/parsed_data/create_index_info.hpp" -#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/parsed_data/drop_info.hpp" -#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" -#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/transaction/meta_transaction.hpp" - -namespace duckdb { - -static void FindForeignKeyInformation(TableCatalogEntry &table, AlterForeignKeyType alter_fk_type, - vector> &fk_arrays) { - auto &constraints = table.GetConstraints(); - auto &catalog = table.ParentCatalog(); - auto &name = table.name; - for (idx_t i = 0; i < constraints.size(); i++) { - auto &cond = constraints[i]; - if (cond->type != ConstraintType::FOREIGN_KEY) { - continue; - } - auto &fk = cond->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - AlterEntryData alter_data(catalog.GetName(), fk.info.schema, fk.info.table, - OnEntryNotFound::THROW_EXCEPTION); - fk_arrays.push_back(make_uniq(std::move(alter_data), name, fk.pk_columns, - fk.fk_columns, fk.info.pk_keys, fk.info.fk_keys, - alter_fk_type)); - } else if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && - alter_fk_type == AlterForeignKeyType::AFT_DELETE) { - throw CatalogException("Could not drop the table because this table is main key table of the table \"%s\"", - fk.info.table); - } - } -} - -static void LazyLoadIndexes(ClientContext &context, CatalogEntry &entry) { - if (entry.type == CatalogType::TABLE_ENTRY) { - auto &table_entry = entry.Cast(); - table_entry.GetStorage().InitializeIndexes(context); - } else if (entry.type == CatalogType::INDEX_ENTRY) { - auto &index_entry = entry.Cast(); - auto &table_entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, index_entry.catalog.GetName(), - index_entry.GetSchemaName(), index_entry.GetTableName()) - .Cast(); - table_entry.GetStorage().InitializeIndexes(context); - } -} - -DuckSchemaEntry::DuckSchemaEntry(Catalog &catalog, CreateSchemaInfo &info) - : SchemaCatalogEntry(catalog, info), tables(catalog, make_uniq(catalog, *this)), - indexes(catalog), table_functions(catalog, make_uniq(catalog, *this)), - copy_functions(catalog), pragma_functions(catalog), - functions(catalog, make_uniq(catalog, *this)), sequences(catalog), collations(catalog), - types(catalog, make_uniq(catalog, *this)) { -} - -unique_ptr DuckSchemaEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - - auto result = make_uniq(catalog, cast_info); - - return std::move(result); -} - -optional_ptr DuckSchemaEntry::AddEntryInternal(CatalogTransaction transaction, - unique_ptr entry, - OnCreateConflict on_conflict, - LogicalDependencyList dependencies) { - auto entry_name = 
entry->name; - auto entry_type = entry->type; - auto result = entry.get(); - - if (transaction.context) { - auto &meta = MetaTransaction::Get(transaction.GetContext()); - auto modified_database = meta.ModifiedDatabase(); - auto &db = ParentCatalog().GetAttached(); - if (!db.IsTemporary() && !db.IsSystem()) { - if (!modified_database || !RefersToSameObject(*modified_database, ParentCatalog().GetAttached())) { - throw InternalException( - "DuckSchemaEntry::AddEntryInternal called but this database is not marked as modified"); - } - } - } - // first find the set for this entry - auto &set = GetCatalogSet(entry_type); - dependencies.AddDependency(*this); - if (on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT) { - auto old_entry = set.GetEntry(transaction, entry_name); - if (old_entry) { - return nullptr; - } - } - - if (on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT) { - // CREATE OR REPLACE: first try to drop the entry - auto old_entry = set.GetEntry(transaction, entry_name); - if (old_entry) { - if (dependencies.Contains(*old_entry)) { - throw CatalogException("CREATE OR REPLACE is not allowed to depend on itself"); - } - if (old_entry->type != entry_type) { - throw CatalogException("Existing object %s is of type %s, trying to replace with type %s", entry_name, - CatalogTypeToString(old_entry->type), CatalogTypeToString(entry_type)); - } - OnDropEntry(transaction, *old_entry); - (void)set.DropEntry(transaction, entry_name, false, entry->internal); - } - } - // now try to add the entry - if (!set.CreateEntry(transaction, entry_name, std::move(entry), dependencies)) { - // entry already exists! - if (on_conflict == OnCreateConflict::ERROR_ON_CONFLICT) { - throw CatalogException::EntryAlreadyExists(entry_type, entry_name); - } else { - return nullptr; - } - } - return result; -} - -optional_ptr DuckSchemaEntry::CreateTable(CatalogTransaction transaction, BoundCreateTableInfo &info) { - auto table = make_uniq(catalog, *this, info); - - // add a foreign key constraint in main key table if there is a foreign key constraint - vector> fk_arrays; - FindForeignKeyInformation(*table, AlterForeignKeyType::AFT_ADD, fk_arrays); - for (idx_t i = 0; i < fk_arrays.size(); i++) { - // alter primary key table - auto &fk_info = *fk_arrays[i]; - Alter(transaction, fk_info); - - // make a dependency between this table and referenced table - auto &set = GetCatalogSet(CatalogType::TABLE_ENTRY); - info.dependencies.AddDependency(*set.GetEntry(transaction, fk_info.name)); - } - for (auto &dep : info.dependencies.Set()) { - table->dependencies.AddDependency(dep); - } - - auto entry = AddEntryInternal(transaction, std::move(table), info.Base().on_conflict, info.dependencies); - if (!entry) { - return nullptr; - } - - return entry; -} - -optional_ptr DuckSchemaEntry::CreateFunction(CatalogTransaction transaction, CreateFunctionInfo &info) { - if (info.on_conflict == OnCreateConflict::ALTER_ON_CONFLICT) { - // check if the original entry exists - auto &catalog_set = GetCatalogSet(info.type); - auto current_entry = catalog_set.GetEntry(transaction, info.name); - if (current_entry) { - // the current entry exists - alter it instead - auto alter_info = info.GetAlterInfo(); - Alter(transaction, *alter_info); - return nullptr; - } - } - unique_ptr function; - switch (info.type) { - case CatalogType::SCALAR_FUNCTION_ENTRY: - function = make_uniq_base(catalog, *this, - info.Cast()); - break; - case CatalogType::TABLE_FUNCTION_ENTRY: - function = make_uniq_base(catalog, *this, - info.Cast()); - break; - case 
CatalogType::MACRO_ENTRY: - // create a macro function - function = make_uniq_base(catalog, *this, info.Cast()); - break; - - case CatalogType::TABLE_MACRO_ENTRY: - // create a macro table function - function = make_uniq_base(catalog, *this, info.Cast()); - break; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - D_ASSERT(info.type == CatalogType::AGGREGATE_FUNCTION_ENTRY); - // create an aggregate function - function = make_uniq_base( - catalog, *this, info.Cast()); - break; - default: - throw InternalException("Unknown function type \"%s\"", CatalogTypeToString(info.type)); - } - function->internal = info.internal; - return AddEntry(transaction, std::move(function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::AddEntry(CatalogTransaction transaction, unique_ptr entry, - OnCreateConflict on_conflict) { - LogicalDependencyList dependencies = entry->dependencies; - return AddEntryInternal(transaction, std::move(entry), on_conflict, dependencies); -} - -optional_ptr DuckSchemaEntry::CreateSequence(CatalogTransaction transaction, CreateSequenceInfo &info) { - auto sequence = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(sequence), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateType(CatalogTransaction transaction, CreateTypeInfo &info) { - auto type_entry = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(type_entry), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateView(CatalogTransaction transaction, CreateViewInfo &info) { - auto view = make_uniq(catalog, *this, info); - return AddEntry(transaction, std::move(view), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateIndex(CatalogTransaction transaction, CreateIndexInfo &info, - TableCatalogEntry &table) { - info.dependencies.AddDependency(table); - - // currently, we can not alter PK/FK/UNIQUE constraints - // concurrency-safe name checks against other INDEX catalog entries happens in the catalog - if (info.on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT && - !table.GetStorage().IndexNameIsUnique(info.index_name)) { - throw CatalogException("An index with the name " + info.index_name + " already exists!"); - } - - auto index = make_uniq(catalog, *this, info, table); - auto dependencies = index->dependencies; - return AddEntryInternal(transaction, std::move(index), info.on_conflict, dependencies); -} - -optional_ptr DuckSchemaEntry::CreateCollation(CatalogTransaction transaction, CreateCollationInfo &info) { - auto collation = make_uniq(catalog, *this, info); - collation->internal = info.internal; - return AddEntry(transaction, std::move(collation), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateTableFunction(CatalogTransaction transaction, - CreateTableFunctionInfo &info) { - auto table_function = make_uniq(catalog, *this, info); - table_function->internal = info.internal; - return AddEntry(transaction, std::move(table_function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreateCopyFunction(CatalogTransaction transaction, - CreateCopyFunctionInfo &info) { - auto copy_function = make_uniq(catalog, *this, info); - copy_function->internal = info.internal; - return AddEntry(transaction, std::move(copy_function), info.on_conflict); -} - -optional_ptr DuckSchemaEntry::CreatePragmaFunction(CatalogTransaction transaction, - CreatePragmaFunctionInfo &info) { - auto pragma_function = make_uniq(catalog, *this, info); - pragma_function->internal = info.internal; - return AddEntry(transaction, std::move(pragma_function), 
info.on_conflict); -} - -void DuckSchemaEntry::Alter(CatalogTransaction transaction, AlterInfo &info) { - CatalogType type = info.GetCatalogType(); - - auto &set = GetCatalogSet(type); - if (info.type == AlterType::CHANGE_OWNERSHIP) { - if (!set.AlterOwnership(transaction, info.Cast())) { - throw CatalogException("Couldn't change ownership!"); - } - } else { - string name = info.name; - if (!set.AlterEntry(transaction, name, info)) { - throw CatalogException::MissingEntry(type, name, string()); - } - } -} - -void DuckSchemaEntry::Scan(ClientContext &context, CatalogType type, - const std::function &callback) { - auto &set = GetCatalogSet(type); - set.Scan(GetCatalogTransaction(context), callback); -} - -void DuckSchemaEntry::Scan(CatalogType type, const std::function &callback) { - auto &set = GetCatalogSet(type); - set.Scan(callback); -} - -void DuckSchemaEntry::DropEntry(ClientContext &context, DropInfo &info) { - auto &set = GetCatalogSet(info.type); - - // first find the entry - auto transaction = GetCatalogTransaction(context); - auto existing_entry = set.GetEntry(transaction, info.name); - if (!existing_entry) { - throw InternalException("Failed to drop entry \"%s\" - entry could not be found", info.name); - } - if (existing_entry->type != info.type) { - throw CatalogException("Existing object %s is of type %s, trying to drop type %s", info.name, - CatalogTypeToString(existing_entry->type), CatalogTypeToString(info.type)); - } - - // if this is a index or table with indexes, initialize any unknown index instances - LazyLoadIndexes(context, *existing_entry); - - vector> fk_arrays; - if (existing_entry->type == CatalogType::TABLE_ENTRY) { - // if there is a foreign key constraint, get that information - auto &table_entry = existing_entry->Cast(); - FindForeignKeyInformation(table_entry, AlterForeignKeyType::AFT_DELETE, fk_arrays); - } - - OnDropEntry(transaction, *existing_entry); - if (!set.DropEntry(transaction, info.name, info.cascade, info.allow_drop_internal)) { - throw InternalException("Could not drop element because of an internal error"); - } - - // remove the foreign key constraint in main key table if main key table's name is valid - for (idx_t i = 0; i < fk_arrays.size(); i++) { - // alter primary key table - Alter(transaction, *fk_arrays[i]); - } -} - -void DuckSchemaEntry::OnDropEntry(CatalogTransaction transaction, CatalogEntry &entry) { - if (!transaction.transaction) { - return; - } - if (entry.type != CatalogType::TABLE_ENTRY) { - return; - } - // if we have transaction local insertions for this table - clear them - auto &table_entry = entry.Cast(); - auto &local_storage = LocalStorage::Get(transaction.transaction->Cast()); - local_storage.DropTable(table_entry.GetStorage()); -} - -optional_ptr DuckSchemaEntry::GetEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - return GetCatalogSet(type).GetEntry(transaction, name); -} - -CatalogSet::EntryLookup DuckSchemaEntry::GetEntryDetailed(CatalogTransaction transaction, CatalogType type, - const string &name) { - return GetCatalogSet(type).GetEntryDetailed(transaction, name); -} - -SimilarCatalogEntry DuckSchemaEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - return GetCatalogSet(type).SimilarEntry(transaction, name); -} - -CatalogSet &DuckSchemaEntry::GetCatalogSet(CatalogType type) { - switch (type) { - case CatalogType::VIEW_ENTRY: - case CatalogType::TABLE_ENTRY: - return tables; - case CatalogType::INDEX_ENTRY: - return indexes; - case 
CatalogType::TABLE_FUNCTION_ENTRY: - case CatalogType::TABLE_MACRO_ENTRY: - return table_functions; - case CatalogType::COPY_FUNCTION_ENTRY: - return copy_functions; - case CatalogType::PRAGMA_FUNCTION_ENTRY: - return pragma_functions; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - case CatalogType::SCALAR_FUNCTION_ENTRY: - case CatalogType::MACRO_ENTRY: - return functions; - case CatalogType::SEQUENCE_ENTRY: - return sequences; - case CatalogType::COLLATION_ENTRY: - return collations; - case CatalogType::TYPE_ENTRY: - return types; - default: - throw InternalException("Unsupported catalog type in schema"); - } -} - -void DuckSchemaEntry::Verify(Catalog &catalog) { - InCatalogEntry::Verify(catalog); - - tables.Verify(catalog); - indexes.Verify(catalog); - table_functions.Verify(catalog); - copy_functions.Verify(catalog); - pragma_functions.Verify(catalog); - functions.Verify(catalog); - sequences.Verify(catalog); - collations.Verify(catalog); - types.Verify(catalog); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp b/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp deleted file mode 100644 index 4983710d9..000000000 --- a/src/duckdb/src/catalog/catalog_entry/duck_table_entry.cpp +++ /dev/null @@ -1,912 +0,0 @@ -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" - -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/exception/transaction_exception.hpp" -#include "duckdb/common/index_map.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/function/table/table_scan.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/parser/constraints/list.hpp" -#include "duckdb/parser/parsed_data/comment_on_column_info.hpp" -#include "duckdb/parser/parsed_expression_iterator.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/constraints/bound_check_constraint.hpp" -#include "duckdb/planner/constraints/bound_foreign_key_constraint.hpp" -#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" -#include "duckdb/planner/constraints/bound_unique_constraint.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression_binder/alter_binder.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/storage/table_storage_info.hpp" - -namespace duckdb { - -IndexStorageInfo GetIndexInfo(const IndexConstraintType type, const bool v1_0_0_storage, unique_ptr &info, - const idx_t id) { - - auto &table_info = info->Cast(); - auto constraint_name = EnumUtil::ToString(type) + "_"; - auto name = constraint_name + table_info.table + "_" + to_string(id); - IndexStorageInfo index_info(name); - if (!v1_0_0_storage) { - index_info.options.emplace("v1_0_0_storage", v1_0_0_storage); - } - return index_info; -} - -DuckTableEntry::DuckTableEntry(Catalog &catalog, SchemaCatalogEntry &schema, BoundCreateTableInfo &info, - shared_ptr inherited_storage) - : TableCatalogEntry(catalog, schema, info.Base()), storage(std::move(inherited_storage)), - column_dependency_manager(std::move(info.column_dependency_manager)) { - - if (storage) { - if (!info.indexes.empty()) { - storage->SetIndexStorageInfo(std::move(info.indexes)); - } - return; - } - - // create the physical storage - vector column_defs; - for (auto &col_def : 
columns.Physical()) { - column_defs.push_back(col_def.Copy()); - } - storage = make_shared_ptr(catalog.GetAttached(), StorageManager::Get(catalog).GetTableIOManager(&info), - schema.name, name, std::move(column_defs), std::move(info.data)); - - // Create the unique indexes for the UNIQUE, PRIMARY KEY, and FOREIGN KEY constraints. - idx_t indexes_idx = 0; - for (idx_t i = 0; i < constraints.size(); i++) { - auto &constraint = constraints[i]; - if (constraint->type == ConstraintType::UNIQUE) { - - // UNIQUE constraint: Create a unique index. - auto &unique = constraint->Cast(); - IndexConstraintType constraint_type = IndexConstraintType::UNIQUE; - if (unique.is_primary_key) { - constraint_type = IndexConstraintType::PRIMARY; - } - - auto column_indexes = unique.GetLogicalIndexes(columns); - if (info.indexes.empty()) { - auto index_info = GetIndexInfo(constraint_type, false, info.base, i); - storage->AddIndex(columns, column_indexes, constraint_type, index_info); - continue; - } - - // We read the index from an old storage version applying a dummy name. - if (info.indexes[indexes_idx].name.empty()) { - auto name_info = GetIndexInfo(constraint_type, true, info.base, i); - info.indexes[indexes_idx].name = name_info.name; - } - - // Now we can add the index. - storage->AddIndex(columns, column_indexes, constraint_type, info.indexes[indexes_idx++]); - continue; - } - - if (constraint->type == ConstraintType::FOREIGN_KEY) { - // Create a FOREIGN KEY index. - auto &bfk = constraint->Cast(); - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - - vector column_indexes; - for (const auto &physical_index : bfk.info.fk_keys) { - auto &col = columns.GetColumn(physical_index); - column_indexes.push_back(col.Logical()); - } - - if (info.indexes.empty()) { - auto constraint_type = IndexConstraintType::FOREIGN; - auto index_info = GetIndexInfo(constraint_type, false, info.base, i); - storage->AddIndex(columns, column_indexes, constraint_type, index_info); - continue; - } - - // We read the index from an old storage version applying a dummy name. - if (info.indexes[indexes_idx].name.empty()) { - auto name_info = GetIndexInfo(IndexConstraintType::FOREIGN, true, info.base, i); - info.indexes[indexes_idx].name = name_info.name; - } - - // Now we can add the index. 
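// For reference, the implicit index names produced by GetIndexInfo above follow
// the pattern "<constraint type>_<table>_<constraint position>". As a sketch
// (the exact EnumUtil::ToString spelling is an assumption, not shown here): a
// PRIMARY KEY declared as the first constraint of table "city" would get an
// index named roughly "PRIMARY_city_0". For pre-v1.0.0 storage files the name
// is only the dummy filled in above; for newer files GetIndexInfo additionally
// records "v1_0_0_storage" = false in the index options.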
- storage->AddIndex(columns, column_indexes, IndexConstraintType::FOREIGN, info.indexes[indexes_idx++]); - } - } - } - - if (!info.indexes.empty()) { - storage->SetIndexStorageInfo(std::move(info.indexes)); - } -} - -unique_ptr DuckTableEntry::GetStatistics(ClientContext &context, column_t column_id) { - if (column_id == COLUMN_IDENTIFIER_ROW_ID) { - return nullptr; - } - auto &column = columns.GetColumn(LogicalIndex(column_id)); - if (column.Generated()) { - return nullptr; - } - return storage->GetStatistics(context, column.StorageOid()); -} - -unique_ptr DuckTableEntry::GetSample() { - return storage->GetSample(); -} - -unique_ptr DuckTableEntry::AlterEntry(CatalogTransaction transaction, AlterInfo &info) { - if (transaction.HasContext()) { - return AlterEntry(transaction.GetContext(), info); - } - if (info.type != AlterType::ALTER_TABLE) { - return CatalogEntry::AlterEntry(transaction, info); - } - - auto &table_info = info.Cast(); - if (table_info.alter_table_type != AlterTableType::FOREIGN_KEY_CONSTRAINT) { - return CatalogEntry::AlterEntry(transaction, info); - } - - auto &foreign_key_constraint_info = table_info.Cast(); - if (foreign_key_constraint_info.type != AlterForeignKeyType::AFT_ADD) { - return CatalogEntry::AlterEntry(transaction, info); - } - - // We add foreign key constraints without a client context during checkpoint loading. - return AddForeignKeyConstraint(nullptr, foreign_key_constraint_info); -} - -unique_ptr DuckTableEntry::AlterEntry(ClientContext &context, AlterInfo &info) { - D_ASSERT(!internal); - - // Column comments have a special alter type - if (info.type == AlterType::SET_COLUMN_COMMENT) { - auto &comment_on_column_info = info.Cast(); - return SetColumnComment(context, comment_on_column_info); - } - - if (info.type != AlterType::ALTER_TABLE) { - throw CatalogException("Can only modify table with ALTER TABLE statement"); - } - auto &table_info = info.Cast(); - switch (table_info.alter_table_type) { - case AlterTableType::RENAME_COLUMN: { - auto &rename_info = table_info.Cast(); - return RenameColumn(context, rename_info); - } - case AlterTableType::RENAME_TABLE: { - auto &rename_info = table_info.Cast(); - auto copied_table = Copy(context); - copied_table->name = rename_info.new_table_name; - storage->SetTableName(rename_info.new_table_name); - return copied_table; - } - case AlterTableType::ADD_COLUMN: { - auto &add_info = table_info.Cast(); - return AddColumn(context, add_info); - } - case AlterTableType::REMOVE_COLUMN: { - auto &remove_info = table_info.Cast(); - return RemoveColumn(context, remove_info); - } - case AlterTableType::SET_DEFAULT: { - auto &set_default_info = table_info.Cast(); - return SetDefault(context, set_default_info); - } - case AlterTableType::ALTER_COLUMN_TYPE: { - auto &change_type_info = table_info.Cast(); - return ChangeColumnType(context, change_type_info); - } - case AlterTableType::FOREIGN_KEY_CONSTRAINT: { - auto &foreign_key_constraint_info = table_info.Cast(); - if (foreign_key_constraint_info.type == AlterForeignKeyType::AFT_ADD) { - return AddForeignKeyConstraint(context, foreign_key_constraint_info); - } else { - return DropForeignKeyConstraint(context, foreign_key_constraint_info); - } - } - case AlterTableType::SET_NOT_NULL: { - auto &set_not_null_info = table_info.Cast(); - return SetNotNull(context, set_not_null_info); - } - case AlterTableType::DROP_NOT_NULL: { - auto &drop_not_null_info = table_info.Cast(); - return DropNotNull(context, drop_not_null_info); - } - case AlterTableType::ADD_CONSTRAINT: { - auto 
&add_constraint_info = table_info.Cast(); - return AddConstraint(context, add_constraint_info); - } - default: - throw InternalException("Unrecognized alter table type!"); - } -} - -void DuckTableEntry::UndoAlter(ClientContext &context, AlterInfo &info) { - D_ASSERT(!internal); - D_ASSERT(info.type == AlterType::ALTER_TABLE); - auto &table_info = info.Cast(); - switch (table_info.alter_table_type) { - case AlterTableType::RENAME_TABLE: { - storage->SetTableName(this->name); - break; - default: - break; - } - } -} - -static void RenameExpression(ParsedExpression &expr, RenameColumnInfo &info) { - if (expr.GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &colref = expr.Cast(); - if (colref.column_names.back() == info.old_name) { - colref.column_names.back() = info.new_name; - } - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { RenameExpression((ParsedExpression &)child, info); }); -} - -unique_ptr DuckTableEntry::RenameColumn(ClientContext &context, RenameColumnInfo &info) { - auto rename_idx = GetColumnIndex(info.old_name); - if (rename_idx.index == COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot rename rowid column"); - } - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (rename_idx == col.Logical()) { - copy.SetName(info.new_name); - } - if (col.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), rename_idx)) { - RenameExpression(copy.GeneratedExpressionMutable(), info); - } - create_info->columns.AddColumn(std::move(copy)); - } - for (idx_t c_idx = 0; c_idx < constraints.size(); c_idx++) { - auto copy = constraints[c_idx]->Copy(); - switch (copy->type) { - case ConstraintType::NOT_NULL: - // NOT NULL constraint: no adjustments necessary - break; - case ConstraintType::CHECK: { - // CHECK constraint: need to rename column references that refer to the renamed column - auto &check = copy->Cast(); - RenameExpression(*check.expression, info); - break; - } - case ConstraintType::UNIQUE: { - // UNIQUE constraint: possibly need to rename columns - auto &unique = copy->Cast(); - for (auto &column_name : unique.GetColumnNamesMutable()) { - if (column_name == info.old_name) { - column_name = info.new_name; - } - } - break; - } - case ConstraintType::FOREIGN_KEY: { - // FOREIGN KEY constraint: possibly need to rename columns - auto &fk = copy->Cast(); - vector columns = fk.pk_columns; - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - columns = fk.fk_columns; - } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - for (idx_t i = 0; i < fk.fk_columns.size(); i++) { - columns.push_back(fk.fk_columns[i]); - } - } - for (idx_t i = 0; i < columns.size(); i++) { - if (columns[i] == info.old_name) { - throw CatalogException( - "Cannot rename column \"%s\" because this is involved in the foreign key constraint", - info.old_name); - } - } - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - create_info->constraints.push_back(std::move(copy)); - } - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::AddColumn(ClientContext &context, AddColumnInfo &info) { - auto col_name = info.new_column.GetName(); 
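// A sketch of the statement that reaches this path with if_column_not_exists set:
//
//   ALTER TABLE t ADD COLUMN IF NOT EXISTS j INTEGER;
//
// If column j already exists, the early-out below makes the statement a silent
// no-op: returning nullptr tells the caller that no replacement table version
// needs to be created.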
- - // We're checking for the opposite condition (ADD COLUMN IF _NOT_ EXISTS ...). - if (info.if_column_not_exists && ColumnExists(col_name)) { - return nullptr; - } - - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - - for (auto &col : columns.Logical()) { - create_info->columns.AddColumn(col.Copy()); - } - for (auto &constraint : constraints) { - create_info->constraints.push_back(constraint->Copy()); - } - auto binder = Binder::CreateBinder(context); - binder->BindLogicalType(info.new_column.TypeMutable(), &catalog, schema.name); - info.new_column.SetOid(columns.LogicalColumnCount()); - info.new_column.SetStorageOid(columns.PhysicalColumnCount()); - auto col = info.new_column.Copy(); - - create_info->columns.AddColumn(std::move(col)); - - vector> bound_defaults; - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema, bound_defaults); - auto new_storage = make_shared_ptr(context, *storage, info.new_column, *bound_defaults.back()); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -void DuckTableEntry::UpdateConstraintsOnColumnDrop(const LogicalIndex &removed_index, - const vector &adjusted_indices, - const RemoveColumnInfo &info, CreateTableInfo &create_info, - const vector> &bound_constraints, - bool is_generated) { - // handle constraints for the new table - D_ASSERT(constraints.size() == bound_constraints.size()); - for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { - auto &constraint = constraints[constr_idx]; - auto &bound_constraint = bound_constraints[constr_idx]; - switch (constraint->type) { - case ConstraintType::NOT_NULL: { - auto ¬_null_constraint = bound_constraint->Cast(); - auto not_null_index = columns.PhysicalToLogical(not_null_constraint.index); - if (not_null_index != removed_index) { - // the constraint is not about this column: we need to copy it - // we might need to shift the index back by one though, to account for the removed column - auto new_index = adjusted_indices[not_null_index.index]; - create_info.constraints.push_back(make_uniq(new_index)); - } - break; - } - case ConstraintType::CHECK: { - // Generated columns can not be part of an index - // CHECK constraint - auto &bound_check = bound_constraint->Cast(); - // check if the removed column is part of the check constraint - if (is_generated) { - // generated columns can not be referenced by constraints, we can just add the constraint back - create_info.constraints.push_back(constraint->Copy()); - break; - } - auto physical_index = columns.LogicalToPhysical(removed_index); - if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { - if (bound_check.bound_columns.size() > 1) { - // CHECK constraint that concerns mult - throw CatalogException( - "Cannot drop column \"%s\" because there is a CHECK constraint that depends on it", - info.removed_column); - } else { - // CHECK constraint that ONLY concerns this column, strip the constraint - } - } else { - // check constraint does not concern the removed column: simply re-add it - create_info.constraints.push_back(constraint->Copy()); - } - break; - } - case ConstraintType::UNIQUE: { - auto copy = constraint->Copy(); - auto &unique = copy->Cast(); - if (unique.HasIndex()) { - if (unique.GetIndex() == removed_index) { - throw CatalogException( - "Cannot drop column \"%s\" because there is a UNIQUE constraint that depends on it", - info.removed_column); - } 
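// adjusted_indices (from ColumnDependencyManager::RemoveColumn above) maps old
// logical column positions to their post-drop positions. A worked example for a
// table (a, b, c) with UNIQUE(c): dropping b yields adjusted_indices = [0, 1, 1],
// so the constraint's stored index 2 (column c) is rewritten to 1 below.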
- unique.SetIndex(adjusted_indices[unique.GetIndex().index]); - } - create_info.constraints.push_back(std::move(copy)); - break; - } - case ConstraintType::FOREIGN_KEY: { - auto copy = constraint->Copy(); - auto &fk = copy->Cast(); - vector columns = fk.pk_columns; - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - columns = fk.fk_columns; - } else if (fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - for (idx_t i = 0; i < fk.fk_columns.size(); i++) { - columns.push_back(fk.fk_columns[i]); - } - } - for (idx_t i = 0; i < columns.size(); i++) { - if (columns[i] == info.removed_column) { - throw CatalogException( - "Cannot drop column \"%s\" because there is a FOREIGN KEY constraint that depends on it", - info.removed_column); - } - } - create_info.constraints.push_back(std::move(copy)); - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - } -} - -unique_ptr DuckTableEntry::RemoveColumn(ClientContext &context, RemoveColumnInfo &info) { - auto removed_index = GetColumnIndex(info.removed_column, info.if_column_exists); - if (!removed_index.IsValid()) { - if (!info.if_column_exists) { - throw CatalogException("Cannot drop column: rowid column cannot be dropped"); - } - return nullptr; - } - - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - - logical_index_set_t removed_columns; - if (column_dependency_manager.HasDependents(removed_index)) { - removed_columns = column_dependency_manager.GetDependents(removed_index); - } - if (!removed_columns.empty() && !info.cascade) { - throw CatalogException("Cannot drop column: column is a dependency of 1 or more generated column(s)"); - } - bool dropped_column_is_generated = false; - for (auto &col : columns.Logical()) { - if (col.Logical() == removed_index || removed_columns.count(col.Logical())) { - if (col.Generated()) { - dropped_column_is_generated = true; - } - continue; - } - create_info->columns.AddColumn(col.Copy()); - } - if (create_info->columns.empty()) { - throw CatalogException("Cannot drop column: table only has one column remaining!"); - } - auto adjusted_indices = column_dependency_manager.RemoveColumn(removed_index, columns.LogicalColumnCount()); - - auto binder = Binder::CreateBinder(context); - auto bound_constraints = binder->BindConstraints(constraints, name, columns); - - UpdateConstraintsOnColumnDrop(removed_index, adjusted_indices, info, *create_info, bound_constraints, - dropped_column_is_generated); - - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - if (columns.GetColumn(LogicalIndex(removed_index)).Generated()) { - return make_uniq(catalog, schema, *bound_create_info, storage); - } - auto new_storage = - make_shared_ptr(context, *storage, columns.LogicalToPhysical(LogicalIndex(removed_index)).index); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -unique_ptr DuckTableEntry::SetDefault(ClientContext &context, SetDefaultInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - auto default_idx = GetColumnIndex(info.column_name); - if (default_idx.index == COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot SET DEFAULT for rowid column"); - } - - // Copy all the columns, changing the value of the one that was specified by 'column_name' - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (default_idx == 
col.Logical()) { - // set the default value of this column - if (copy.Generated()) { - throw BinderException("Cannot SET DEFAULT for generated column \"%s\"", col.Name()); - } - copy.SetDefaultValue(info.expression ? info.expression->Copy() : nullptr); - } - create_info->columns.AddColumn(std::move(copy)); - } - // Copy all the constraints - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::SetNotNull(ClientContext &context, SetNotNullInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - - auto not_null_idx = GetColumnIndex(info.column_name); - if (columns.GetColumn(LogicalIndex(not_null_idx)).Generated()) { - throw BinderException("Unsupported constraint for generated column!"); - } - bool has_not_null = false; - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - if (not_null.index == not_null_idx) { - has_not_null = true; - } - } - create_info->constraints.push_back(std::move(constraint)); - } - if (!has_not_null) { - create_info->constraints.push_back(make_uniq(not_null_idx)); - } - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - - // Early return - if (has_not_null) { - return make_uniq(catalog, schema, *bound_create_info, storage); - } - - // Return with new storage info. Note that we need the bound column index here. - auto physical_columns = columns.LogicalToPhysical(LogicalIndex(not_null_idx)); - auto bound_constraint = make_uniq(physical_columns); - auto new_storage = make_shared_ptr(context, *storage, *bound_constraint); - return make_uniq(catalog, schema, *bound_create_info, new_storage); -} - -unique_ptr DuckTableEntry::DropNotNull(ClientContext &context, DropNotNullInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - - auto not_null_idx = GetColumnIndex(info.column_name); - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - // Skip/drop not_null - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - if (not_null.index == not_null_idx) { - continue; - } - } - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::ChangeColumnType(ClientContext &context, ChangeColumnTypeInfo &info) { - auto binder = Binder::CreateBinder(context); - binder->BindLogicalType(info.target_type, &catalog, schema.name); - - auto change_idx = GetColumnIndex(info.column_name); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - - // Bind the USING expression. 
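// The statement shape handled here is, as a sketch:
//
//   ALTER TABLE t ALTER COLUMN v TYPE INTEGER USING length(v);
//
// The AlterBinder resolves column references in the USING expression against
// this table and records every referenced column in bound_columns; those become
// the scan columns used to rewrite the stored data. When no USING clause is
// written, the frontend is assumed to supply an implicit cast of the column
// itself (an assumption about the parser, not shown in this file).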
- vector bound_columns; - AlterBinder expr_binder(*binder, context, *this, bound_columns, info.target_type); - auto expression = info.expression->Copy(); - auto bound_expression = expr_binder.Bind(expression); - - // Infer the target_type from the USING expression, if not set explicitly. - if (info.target_type == LogicalType::UNKNOWN) { - info.target_type = bound_expression->return_type; - } - - auto bound_constraints = binder->BindConstraints(constraints, name, columns); - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (change_idx == col.Logical()) { - // set the type of this column - if (copy.Generated()) { - throw NotImplementedException("Changing types of generated columns is not supported yet"); - } - copy.SetType(info.target_type); - } - // TODO: check if the generated_expression breaks, only delete it if it does - if (copy.Generated() && column_dependency_manager.IsDependencyOf(col.Logical(), change_idx)) { - throw BinderException( - "This column is referenced by the generated column \"%s\", so its type can not be changed", - copy.Name()); - } - create_info->columns.AddColumn(std::move(copy)); - } - - for (idx_t constr_idx = 0; constr_idx < constraints.size(); constr_idx++) { - auto constraint = constraints[constr_idx]->Copy(); - switch (constraint->type) { - case ConstraintType::CHECK: { - auto &bound_check = bound_constraints[constr_idx]->Cast(); - auto physical_index = columns.LogicalToPhysical(change_idx); - if (bound_check.bound_columns.find(physical_index) != bound_check.bound_columns.end()) { - throw BinderException("Cannot change the type of a column that has a CHECK constraint specified"); - } - break; - } - case ConstraintType::NOT_NULL: - break; - case ConstraintType::UNIQUE: { - auto &bound_unique = bound_constraints[constr_idx]->Cast(); - auto physical_index = columns.LogicalToPhysical(change_idx); - if (bound_unique.key_set.find(physical_index) != bound_unique.key_set.end()) { - throw BinderException( - "Cannot change the type of a column that has a UNIQUE or PRIMARY KEY constraint specified"); - } - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &bfk = bound_constraints[constr_idx]->Cast(); - auto key_set = bfk.pk_key_set; - if (bfk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE) { - key_set = bfk.fk_key_set; - } else if (bfk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - key_set.insert(bfk.info.fk_keys.begin(), bfk.info.fk_keys.end()); - } - if (key_set.find(columns.LogicalToPhysical(change_idx)) != key_set.end()) { - throw BinderException("Cannot change the type of a column that has a FOREIGN KEY constraint specified"); - } - break; - } - default: - throw InternalException("Unsupported constraint for entry!"); - } - create_info->constraints.push_back(std::move(constraint)); - } - - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - - vector storage_oids; - for (idx_t i = 0; i < bound_columns.size(); i++) { - storage_oids.emplace_back(columns.LogicalToPhysical(bound_columns[i]).index); - } - if (storage_oids.empty()) { - storage_oids.emplace_back(COLUMN_IDENTIFIER_ROW_ID); - } - - auto new_storage = - make_shared_ptr(context, *storage, columns.LogicalToPhysical(LogicalIndex(change_idx)).index, - info.target_type, std::move(storage_oids), *bound_expression); - auto result = make_uniq(catalog, schema, *bound_create_info, new_storage); - return std::move(result); -} - -unique_ptr DuckTableEntry::SetColumnComment(ClientContext &context, SetColumnCommentInfo &info) { - auto 
create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - auto default_idx = GetColumnIndex(info.column_name); - if (default_idx.index == COLUMN_IDENTIFIER_ROW_ID) { - throw CatalogException("Cannot SET DEFAULT for rowid column"); - } - - // Copy all the columns, changing the value of the one that was specified by 'column_name' - for (auto &col : columns.Logical()) { - auto copy = col.Copy(); - if (default_idx == col.Logical()) { - copy.SetComment(info.comment_value); - } - create_info->columns.AddColumn(std::move(copy)); - } - // Copy all the constraints - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::AddForeignKeyConstraint(optional_ptr context, - AlterForeignKeyInfo &info) { - D_ASSERT(info.type == AlterForeignKeyType::AFT_ADD); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - - create_info->columns = columns.Copy(); - for (idx_t i = 0; i < constraints.size(); i++) { - create_info->constraints.push_back(constraints[i]->Copy()); - } - ForeignKeyInfo fk_info; - fk_info.type = ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE; - fk_info.schema = info.schema; - fk_info.table = info.fk_table; - fk_info.pk_keys = info.pk_keys; - fk_info.fk_keys = info.fk_keys; - create_info->constraints.push_back( - make_uniq(info.pk_columns, info.fk_columns, std::move(fk_info))); - - unique_ptr bound_create_info; - if (context) { - auto binder = Binder::CreateBinder(*context); - bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - } else { - bound_create_info = Binder::BindCreateTableCheckpoint(std::move(create_info), schema); - } - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -unique_ptr DuckTableEntry::DropForeignKeyConstraint(ClientContext &context, AlterForeignKeyInfo &info) { - D_ASSERT(info.type == AlterForeignKeyType::AFT_DELETE); - auto create_info = make_uniq(schema, name); - create_info->temporary = temporary; - create_info->comment = comment; - create_info->tags = tags; - - create_info->columns = columns.Copy(); - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - if (constraint->type == ConstraintType::FOREIGN_KEY) { - ForeignKeyConstraint &fk = constraint->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE && fk.info.table == info.fk_table) { - continue; - } - } - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -void DuckTableEntry::Rollback(CatalogEntry &prev_entry) { - if (prev_entry.type != CatalogType::TABLE_ENTRY) { - return; - } - - // Rolls back any physical index creation. - // FIXME: Currently only works for PKs. - // FIXME: Should be changed to work for any index-based constraint. 
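// A concrete scenario for this rollback, as a sketch: a transaction executes
// ALTER TABLE t ADD PRIMARY KEY (i), which registers a new physical index on
// the shared DataTableInfo, and then aborts (e.g. a duplicate key is found
// while populating the index). The name comparison below removes exactly those
// index-based constraints that exist on this failed version but not on the
// previous one.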
- - auto &table = Cast(); - auto &prev_table = prev_entry.Cast(); - auto &prev_info = prev_table.GetStorage().GetDataTableInfo(); - auto &prev_indexes = prev_info->GetIndexes(); - - // Find all index-based constraints that exist in rollback_table, but not in table. - // Then, remove them. - - unordered_set names; - for (const auto &constraint : prev_table.GetConstraints()) { - if (constraint->type != ConstraintType::UNIQUE) { - continue; - } - const auto &unique = constraint->Cast(); - if (unique.is_primary_key) { - auto index_name = unique.GetName(prev_table.name); - names.insert(index_name); - } - } - - for (const auto &constraint : GetConstraints()) { - if (constraint->type != ConstraintType::UNIQUE) { - continue; - } - const auto &unique = constraint->Cast(); - if (!unique.IsPrimaryKey()) { - continue; - } - auto index_name = unique.GetName(table.name); - if (names.find(index_name) == names.end()) { - prev_indexes.RemoveIndex(index_name); - } - } -} - -unique_ptr DuckTableEntry::AddConstraint(ClientContext &context, AddConstraintInfo &info) { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - - // Copy all columns and constraints to the modified table. - create_info->columns = columns.Copy(); - for (const auto &constraint : constraints) { - create_info->constraints.push_back(constraint->Copy()); - } - - if (info.constraint->type == ConstraintType::UNIQUE) { - const auto &unique = info.constraint->Cast(); - const auto existing_pk = GetPrimaryKey(); - - if (unique.is_primary_key && existing_pk) { - auto existing_name = existing_pk->ToString(); - throw CatalogException("table \"%s\" can have only one primary key: %s", name, existing_name); - } - create_info->constraints.push_back(info.constraint->Copy()); - - } else { - throw InternalException("unsupported constraint type in ALTER TABLE statement"); - } - - // We create a physical table with a new constraint and a new unique index. 
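// Every ALTER in this file follows the same copy-modify-rebind shape: copy the
// table definition, apply the single change, bind the result, and construct a new
// immutable entry. A compact sketch of that shape under simplified, hypothetical
// types (SketchTableDef / SketchConstraint are not DuckDB classes):
#include <memory>
#include <string>
#include <vector>

struct SketchConstraint {
	std::string definition;
};

struct SketchTableDef {
	std::vector<std::string> columns;
	std::vector<SketchConstraint> constraints;
};

static std::unique_ptr<SketchTableDef> CopyModifyRebind(const SketchTableDef &old_def, SketchConstraint added) {
	auto def = std::make_unique<SketchTableDef>();
	def->columns = old_def.columns;               // copy all columns unchanged
	def->constraints = old_def.constraints;       // copy all existing constraints
	def->constraints.push_back(std::move(added)); // apply the one change
	return def; // the caller binds this definition and builds a fresh catalog entry
}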
- const auto binder = Binder::CreateBinder(context); - const auto bound_constraint = binder->BindConstraint(*info.constraint, create_info->table, create_info->columns); - const auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - - auto new_storage = make_shared_ptr(context, *storage, *bound_constraint); - auto new_entry = make_uniq(catalog, schema, *bound_create_info, new_storage); - return std::move(new_entry); -} - -unique_ptr DuckTableEntry::Copy(ClientContext &context) const { - auto create_info = make_uniq(schema, name); - create_info->comment = comment; - create_info->tags = tags; - create_info->columns = columns.Copy(); - - for (idx_t i = 0; i < constraints.size(); i++) { - auto constraint = constraints[i]->Copy(); - create_info->constraints.push_back(std::move(constraint)); - } - - auto binder = Binder::CreateBinder(context); - auto bound_create_info = binder->BindCreateTableInfo(std::move(create_info), schema); - return make_uniq(catalog, schema, *bound_create_info, storage); -} - -void DuckTableEntry::SetAsRoot() { - storage->SetAsRoot(); - storage->SetTableName(name); -} - -void DuckTableEntry::CommitAlter(string &column_name) { - D_ASSERT(!column_name.empty()); - optional_idx removed_index; - for (auto &col : columns.Logical()) { - if (col.Name() == column_name) { - // No need to alter storage, removed column is generated column - if (col.Generated()) { - return; - } - removed_index = col.Oid(); - break; - } - } - storage->CommitDropColumn(columns.LogicalToPhysical(LogicalIndex(removed_index.GetIndex())).index); -} - -void DuckTableEntry::CommitDrop() { - storage->CommitDropTable(); -} - -DataTable &DuckTableEntry::GetStorage() { - return *storage; -} - -TableFunction DuckTableEntry::GetScanFunction(ClientContext &context, unique_ptr &bind_data) { - bind_data = make_uniq(*this); - return TableScanFunction::GetFunction(); -} - -vector DuckTableEntry::GetColumnSegmentInfo() { - return storage->GetColumnSegmentInfo(); -} - -TableStorageInfo DuckTableEntry::GetStorageInfo(ClientContext &context) { - return storage->GetStorageInfo(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp deleted file mode 100644 index 2c5cb9ae7..000000000 --- a/src/duckdb/src/catalog/catalog_entry/index_catalog_entry.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp" - -namespace duckdb { - -IndexCatalogEntry::IndexCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateIndexInfo &info) - : StandardEntry(CatalogType::INDEX_ENTRY, schema, catalog, info.index_name), sql(info.sql), options(info.options), - index_type(info.index_type), index_constraint_type(info.constraint_type), column_ids(info.column_ids) { - - this->temporary = info.temporary; - this->dependencies = info.dependencies; - this->comment = info.comment; - for (auto &expr : info.expressions) { - D_ASSERT(expr); - expressions.push_back(expr->Copy()); - } - for (auto &parsed_expr : info.parsed_expressions) { - D_ASSERT(parsed_expr); - parsed_expressions.push_back(parsed_expr->Copy()); - } -} - -unique_ptr IndexCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = GetSchemaName(); - result->table = GetTableName(); - - result->temporary = temporary; - result->sql = sql; - result->index_name = name; - result->index_type = index_type; - result->constraint_type = index_constraint_type; - result->column_ids = column_ids; - 
result->dependencies = dependencies; - - for (auto &expr : expressions) { - result->expressions.push_back(expr->Copy()); - } - for (auto &expr : parsed_expressions) { - result->parsed_expressions.push_back(expr->Copy()); - } - - result->comment = comment; - result->tags = tags; - - return std::move(result); -} - -string IndexCatalogEntry::ToSQL() const { - auto info = GetInfo(); - return info->ToString(); -} - -bool IndexCatalogEntry::IsUnique() const { - return (index_constraint_type == IndexConstraintType::UNIQUE || - index_constraint_type == IndexConstraintType::PRIMARY); -} - -bool IndexCatalogEntry::IsPrimary() const { - return (index_constraint_type == IndexConstraintType::PRIMARY); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp deleted file mode 100644 index 6aa9a52c3..000000000 --- a/src/duckdb/src/catalog/catalog_entry/macro_catalog_entry.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" -#include "duckdb/function/scalar_macro_function.hpp" - -namespace duckdb { - -MacroCatalogEntry::MacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : FunctionEntry( - (info.macros[0]->type == MacroType::SCALAR_MACRO ? CatalogType::MACRO_ENTRY : CatalogType::TABLE_MACRO_ENTRY), - catalog, schema, info), - macros(std::move(info.macros)) { - this->temporary = info.temporary; - this->internal = info.internal; - this->dependencies = info.dependencies; - this->comment = info.comment; - this->tags = info.tags; -} - -ScalarMacroCatalogEntry::ScalarMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : MacroCatalogEntry(catalog, schema, info) { -} - -unique_ptr ScalarMacroCatalogEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - auto result = make_uniq(catalog, schema, cast_info); - return std::move(result); -} - -TableMacroCatalogEntry::TableMacroCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateMacroInfo &info) - : MacroCatalogEntry(catalog, schema, info) { -} - -unique_ptr TableMacroCatalogEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - auto result = make_uniq(catalog, schema, cast_info); - return std::move(result); -} - -unique_ptr MacroCatalogEntry::GetInfo() const { - auto info = make_uniq(type); - info->catalog = catalog.GetName(); - info->schema = schema.name; - info->name = name; - for (auto &function : macros) { - info->macros.push_back(function->Copy()); - } - info->dependencies = dependencies; - info->comment = comment; - info->tags = tags; - return std::move(info); -} - -string MacroCatalogEntry::ToSQL() const { - auto create_info = GetInfo(); - return create_info->ToString(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp deleted file mode 100644 index ff247dcb0..000000000 --- a/src/duckdb/src/catalog/catalog_entry/pragma_function_catalog_entry.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" -#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" - -namespace duckdb { - 
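// The catalog entries above share one serialization pattern: GetInfo() rebuilds the
// CREATE statement's parse node and ToSQL() just prints it, so SQL generation lives
// in a single place. A sketch of that round trip with hypothetical minimal types
// (SketchInfo / SketchEntry are illustrations, not DuckDB classes):
#include <memory>
#include <string>

struct SketchInfo {
	std::string name;
	std::string ToString() const {
		return "CREATE ENTRY " + name + ";"; // stand-in for CreateInfo::ToString
	}
};

struct SketchEntry {
	std::string name;
	// Rebuild the parse node that would recreate this entry.
	std::unique_ptr<SketchInfo> GetInfo() const {
		auto info = std::make_unique<SketchInfo>();
		info->name = name;
		return info;
	}
	// SQL rendering is always delegated to the rebuilt info object.
	std::string ToSQL() const {
		return GetInfo()->ToString();
	}
};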
-PragmaFunctionCatalogEntry::PragmaFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreatePragmaFunctionInfo &info) - : FunctionEntry(CatalogType::PRAGMA_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp deleted file mode 100644 index f983fb762..000000000 --- a/src/duckdb/src/catalog/catalog_entry/scalar_function_catalog_entry.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" - -namespace duckdb { - -ScalarFunctionCatalogEntry::ScalarFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateScalarFunctionInfo &info) - : FunctionEntry(CatalogType::SCALAR_FUNCTION_ENTRY, catalog, schema, info), functions(info.functions) { -} - -unique_ptr ScalarFunctionCatalogEntry::AlterEntry(CatalogTransaction transaction, AlterInfo &info) { - if (info.type != AlterType::ALTER_SCALAR_FUNCTION) { - throw InternalException("Attempting to alter ScalarFunctionCatalogEntry with unsupported alter type"); - } - auto &function_info = info.Cast(); - if (function_info.alter_scalar_function_type != AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS) { - throw InternalException( - "Attempting to alter ScalarFunctionCatalogEntry with unsupported alter scalar function type"); - } - auto &add_overloads = function_info.Cast(); - - ScalarFunctionSet new_set = functions; - if (!new_set.MergeFunctionSet(add_overloads.new_overloads->functions, true)) { - throw BinderException( - "Failed to add new function overloads to function \"%s\": function overload already exists", name); - } - CreateScalarFunctionInfo new_info(std::move(new_set)); - new_info.internal = internal; - new_info.descriptions = descriptions; - new_info.descriptions.insert(new_info.descriptions.end(), add_overloads.new_overloads->descriptions.begin(), - add_overloads.new_overloads->descriptions.end()); - return make_uniq(catalog, schema, new_info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp deleted file mode 100644 index ef14221ff..000000000 --- a/src/duckdb/src/catalog/catalog_entry/schema_catalog_entry.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/default/default_schemas.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" - -#include - -namespace duckdb { - -SchemaCatalogEntry::SchemaCatalogEntry(Catalog &catalog, CreateSchemaInfo &info) - : InCatalogEntry(CatalogType::SCHEMA_ENTRY, catalog, info.schema) { - this->internal = info.internal; - this->comment = info.comment; - this->tags = info.tags; -} - -CatalogTransaction SchemaCatalogEntry::GetCatalogTransaction(ClientContext &context) { - return CatalogTransaction(catalog, context); -} - -optional_ptr SchemaCatalogEntry::CreateIndex(ClientContext &context, CreateIndexInfo &info, - TableCatalogEntry &table) { - return CreateIndex(GetCatalogTransaction(context), info, table); -} - -SimilarCatalogEntry 
SchemaCatalogEntry::GetSimilarEntry(CatalogTransaction transaction, CatalogType type, - const string &name) { - SimilarCatalogEntry result; - Scan(transaction.GetContext(), type, [&](CatalogEntry &entry) { - auto entry_score = StringUtil::SimilarityRating(entry.name, name); - if (entry_score > result.score) { - result.score = entry_score; - result.name = entry.name; - } - }); - return result; -} - -//! This should not be used, it's only implemented to not put the burden of implementing it on every derived class of -//! SchemaCatalogEntry -CatalogSet::EntryLookup SchemaCatalogEntry::GetEntryDetailed(CatalogTransaction transaction, CatalogType type, - const string &name) { - CatalogSet::EntryLookup result; - result.result = GetEntry(transaction, type, name); - if (!result.result) { - result.reason = CatalogSet::EntryLookup::FailureReason::DELETED; - } else { - result.reason = CatalogSet::EntryLookup::FailureReason::SUCCESS; - } - return result; -} - -unique_ptr SchemaCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = name; - result->comment = comment; - result->tags = tags; - return std::move(result); -} - -string SchemaCatalogEntry::ToSQL() const { - auto create_schema_info = GetInfo(); - return create_schema_info->ToString(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp deleted file mode 100644 index 085048b0a..000000000 --- a/src/duckdb/src/catalog/catalog_entry/sequence_catalog_entry.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/transaction/duck_transaction.hpp" - -#include -#include - -namespace duckdb { - -SequenceData::SequenceData(CreateSequenceInfo &info) - : usage_count(info.usage_count), counter(info.start_value), last_value(info.start_value), increment(info.increment), - start_value(info.start_value), min_value(info.min_value), max_value(info.max_value), cycle(info.cycle) { -} - -SequenceCatalogEntry::SequenceCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateSequenceInfo &info) - : StandardEntry(CatalogType::SEQUENCE_ENTRY, schema, catalog, info.name), data(info) { - this->temporary = info.temporary; - this->comment = info.comment; - this->tags = info.tags; -} - -unique_ptr SequenceCatalogEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - - auto result = make_uniq(catalog, schema, cast_info); - result->data = GetData(); - - return std::move(result); -} - -SequenceData SequenceCatalogEntry::GetData() const { - lock_guard seqlock(lock); - return data; -} - -int64_t SequenceCatalogEntry::CurrentValue() { - lock_guard seqlock(lock); - int64_t result; - if (data.usage_count == 0u) { - throw SequenceException("currval: sequence is not yet defined in this session"); - } - result = data.last_value; - return result; -} - -int64_t SequenceCatalogEntry::NextValue(DuckTransaction &transaction) { - lock_guard seqlock(lock); - int64_t result; - result = data.counter; - bool overflow = !TryAddOperator::Operation(data.counter, data.increment, data.counter); - if (data.cycle) { - if (overflow) { 
- data.counter = data.increment < 0 ? data.max_value : data.min_value; - } else if (data.counter < data.min_value) { - data.counter = data.max_value; - } else if (data.counter > data.max_value) { - data.counter = data.min_value; - } - } else { - if (result < data.min_value || (overflow && data.increment < 0)) { - throw SequenceException("nextval: reached minimum value of sequence \"%s\" (%lld)", name, data.min_value); - } - if (result > data.max_value || overflow) { - throw SequenceException("nextval: reached maximum value of sequence \"%s\" (%lld)", name, data.max_value); - } - } - data.last_value = result; - data.usage_count++; - if (!temporary) { - transaction.PushSequenceUsage(*this, data); - } - return result; -} - -void SequenceCatalogEntry::ReplayValue(uint64_t v_usage_count, int64_t v_counter) { - if (v_usage_count > data.usage_count) { - data.usage_count = v_usage_count; - data.counter = v_counter; - } -} - -unique_ptr SequenceCatalogEntry::GetInfo() const { - auto seq_data = GetData(); - - auto result = make_uniq(); - result->catalog = catalog.GetName(); - result->schema = schema.name; - result->name = name; - result->usage_count = seq_data.usage_count; - result->increment = seq_data.increment; - result->min_value = seq_data.min_value; - result->max_value = seq_data.max_value; - result->start_value = seq_data.counter; - result->cycle = seq_data.cycle; - result->dependencies = dependencies; - result->comment = comment; - result->tags = tags; - return std::move(result); -} - -string SequenceCatalogEntry::ToSQL() const { - auto seq_data = GetData(); - - std::stringstream ss; - ss << "CREATE SEQUENCE "; - ss << name; - ss << " INCREMENT BY " << seq_data.increment; - ss << " MINVALUE " << seq_data.min_value; - ss << " MAXVALUE " << seq_data.max_value; - ss << " START " << seq_data.counter; - ss << " " << (seq_data.cycle ? 
"CYCLE" : "NO CYCLE") << ";"; - return ss.str(); -} -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp deleted file mode 100644 index 3070b2e30..000000000 --- a/src/duckdb/src/catalog/catalog_entry/table_catalog_entry.cpp +++ /dev/null @@ -1,335 +0,0 @@ -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/parser/constraints/list.hpp" -#include "duckdb/parser/parsed_data/create_table_info.hpp" -#include "duckdb/storage/table_storage_info.hpp" -#include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/planner/constraints/bound_check_constraint.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/common/extra_type_info.hpp" -#include "duckdb/parser/expression/cast_expression.hpp" - -#include - -namespace duckdb { - -TableCatalogEntry::TableCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTableInfo &info) - : StandardEntry(CatalogType::TABLE_ENTRY, schema, catalog, info.table), columns(std::move(info.columns)), - constraints(std::move(info.constraints)) { - this->temporary = info.temporary; - this->dependencies = info.dependencies; - this->comment = info.comment; - this->tags = info.tags; -} - -bool TableCatalogEntry::HasGeneratedColumns() const { - return columns.LogicalColumnCount() != columns.PhysicalColumnCount(); -} - -LogicalIndex TableCatalogEntry::GetColumnIndex(string &column_name, bool if_exists) const { - auto entry = columns.GetColumnIndex(column_name); - if (!entry.IsValid()) { - if (if_exists) { - return entry; - } - throw BinderException("Table \"%s\" does not have a column with name \"%s\"", name, column_name); - } - return entry; -} - -unique_ptr TableCatalogEntry::GetSample() { - return nullptr; -} - -bool TableCatalogEntry::ColumnExists(const string &name) const { - return columns.ColumnExists(name); -} - -const ColumnDefinition &TableCatalogEntry::GetColumn(const string &name) const { - return columns.GetColumn(name); -} - -vector TableCatalogEntry::GetTypes() const { - vector types; - for (auto &col : columns.Physical()) { - types.push_back(col.Type()); - } - return types; -} - -unique_ptr TableCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->catalog = catalog.GetName(); - result->schema = schema.name; - result->table = name; - result->columns = columns.Copy(); - result->constraints.reserve(constraints.size()); - result->dependencies = dependencies; - std::for_each(constraints.begin(), constraints.end(), - [&result](const unique_ptr &c) { result->constraints.emplace_back(c->Copy()); }); - result->comment = comment; - result->tags = tags; - return std::move(result); -} - -string TableCatalogEntry::ColumnsToSQL(const ColumnList &columns, const vector> &constraints) { - std::stringstream ss; - - ss << "("; - - // find all columns that have NOT NULL specified, but are NOT primary key columns - logical_index_set_t not_null_columns; - logical_index_set_t unique_columns; - logical_index_set_t pk_columns; - unordered_set multi_key_pks; - vector extra_constraints; - for (auto &constraint : constraints) { - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - 
not_null_columns.insert(not_null.index); - } else if (constraint->type == ConstraintType::UNIQUE) { - auto &pk = constraint->Cast(); - if (pk.HasIndex()) { - // no columns specified: single column constraint - if (pk.IsPrimaryKey()) { - pk_columns.insert(pk.GetIndex()); - } else { - unique_columns.insert(pk.GetIndex()); - } - } else { - // multi-column constraint, this constraint needs to go at the end after all columns - if (pk.IsPrimaryKey()) { - // multi key pk column: insert set of columns into multi_key_pks - for (auto &col : pk.GetColumnNames()) { - multi_key_pks.insert(col); - } - } - extra_constraints.push_back(constraint->ToString()); - } - } else if (constraint->type == ConstraintType::FOREIGN_KEY) { - auto &fk = constraint->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE || - fk.info.type == ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE) { - extra_constraints.push_back(constraint->ToString()); - } - } else { - extra_constraints.push_back(constraint->ToString()); - } - } - - for (auto &column : columns.Logical()) { - if (column.Oid() > 0) { - ss << ", "; - } - ss << KeywordHelper::WriteOptionallyQuoted(column.Name()) << " "; - auto &column_type = column.Type(); - if (column_type.id() != LogicalTypeId::ANY) { - ss << column.Type().ToString(); - } - auto extra_type_info = column_type.AuxInfo(); - if (extra_type_info && extra_type_info->type == ExtraTypeInfoType::STRING_TYPE_INFO) { - auto &string_info = extra_type_info->Cast(); - if (!string_info.collation.empty()) { - ss << " COLLATE " + string_info.collation; - } - } - bool not_null = not_null_columns.find(column.Logical()) != not_null_columns.end(); - bool is_single_key_pk = pk_columns.find(column.Logical()) != pk_columns.end(); - bool is_multi_key_pk = multi_key_pks.find(column.Name()) != multi_key_pks.end(); - bool is_unique = unique_columns.find(column.Logical()) != unique_columns.end(); - if (column.Generated()) { - reference generated_expression = column.GeneratedExpression(); - if (column_type.id() != LogicalTypeId::ANY) { - // We artificially add a cast if the type is specified, need to strip it - auto &expr = generated_expression.get(); - D_ASSERT(expr.GetExpressionType() == ExpressionType::OPERATOR_CAST); - auto &cast_expr = expr.Cast(); - D_ASSERT(cast_expr.cast_type.id() == column_type.id()); - generated_expression = *cast_expr.child; - } - ss << " GENERATED ALWAYS AS(" << generated_expression.get().ToString() << ")"; - } else if (column.HasDefaultValue()) { - ss << " DEFAULT(" << column.DefaultValue().ToString() << ")"; - } - if (not_null && !is_single_key_pk && !is_multi_key_pk) { - // NOT NULL but not a primary key column - ss << " NOT NULL"; - } - if (is_single_key_pk) { - // single column pk: insert constraint here - ss << " PRIMARY KEY"; - } - if (is_unique) { - // single column unique: insert constraint here - ss << " UNIQUE"; - } - } - // print any extra constraints that still need to be printed - for (auto &extra_constraint : extra_constraints) { - ss << ", "; - ss << extra_constraint; - } - - ss << ")"; - return ss.str(); -} - -string TableCatalogEntry::ColumnNamesToSQL(const ColumnList &columns) { - if (columns.empty()) { - return ""; - } - - std::stringstream ss; - ss << "("; - - for (auto &column : columns.Logical()) { - if (column.Oid() > 0) { - ss << ", "; - } - ss << KeywordHelper::WriteOptionallyQuoted(column.Name()) << " "; - } - ss << ")"; - return ss.str(); -} - -string TableCatalogEntry::ToSQL() const { - auto create_info = GetInfo(); - return create_info->ToString(); -} - 
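// ColumnsToSQL above splits constraints in two: single-column PRIMARY KEY / UNIQUE /
// NOT NULL render inline after their column, while multi-column constraints are
// collected and appended after the last column. A self-contained sketch of that
// split, assuming the simplified SketchColumn type below (not DuckDB's ColumnList):
#include <cstddef>
#include <sstream>
#include <string>
#include <vector>

struct SketchColumn {
	std::string name;
	std::string type;
	bool not_null = false;
	bool single_column_pk = false;
};

static std::string ColumnsToSQLSketch(const std::vector<SketchColumn> &cols,
                                      const std::vector<std::string> &multi_column_constraints) {
	std::stringstream ss;
	ss << "(";
	for (std::size_t i = 0; i < cols.size(); i++) {
		if (i > 0) {
			ss << ", ";
		}
		ss << cols[i].name << " " << cols[i].type;
		if (cols[i].single_column_pk) {
			ss << " PRIMARY KEY"; // inline form: constraint covers exactly this column
		} else if (cols[i].not_null) {
			ss << " NOT NULL";
		}
	}
	for (const auto &constraint : multi_column_constraints) {
		ss << ", " << constraint; // trailing form: constraint spans several columns
	}
	ss << ")";
	return ss.str();
}
// ColumnsToSQLSketch({{"i", "INTEGER", false, true}, {"j", "INTEGER", true}}, {"UNIQUE(i, j)"})
// yields "(i INTEGER PRIMARY KEY, j INTEGER NOT NULL, UNIQUE(i, j))".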
-const ColumnList &TableCatalogEntry::GetColumns() const { - return columns; -} - -const ColumnDefinition &TableCatalogEntry::GetColumn(LogicalIndex idx) const { - return columns.GetColumn(idx); -} - -const vector> &TableCatalogEntry::GetConstraints() const { - return constraints; -} - -// LCOV_EXCL_START -DataTable &TableCatalogEntry::GetStorage() { - throw InternalException("Calling GetStorage on a TableCatalogEntry that is not a DuckTableEntry"); -} -// LCOV_EXCL_STOP - -static void BindExtraColumns(TableCatalogEntry &table, LogicalGet &get, LogicalProjection &proj, LogicalUpdate &update, - physical_index_set_t &bound_columns) { - if (bound_columns.size() <= 1) { - return; - } - idx_t found_column_count = 0; - physical_index_set_t found_columns; - for (idx_t i = 0; i < update.columns.size(); i++) { - if (bound_columns.find(update.columns[i]) != bound_columns.end()) { - // this column is referenced in the CHECK constraint - found_column_count++; - found_columns.insert(update.columns[i]); - } - } - if (found_column_count > 0 && found_column_count != bound_columns.size()) { - // columns in this CHECK constraint were referenced, but not all were part of the UPDATE - // add them to the scan and update set - for (auto &check_column_id : bound_columns) { - if (found_columns.find(check_column_id) != found_columns.end()) { - // column is already projected - continue; - } - // column is not projected yet: project it by adding the clause "i=i" to the set of updated columns - auto &column = table.GetColumns().GetColumn(check_column_id); - update.expressions.push_back(make_uniq( - column.Type(), ColumnBinding(proj.table_index, proj.expressions.size()))); - proj.expressions.push_back(make_uniq( - column.Type(), ColumnBinding(get.table_index, get.GetColumnIds().size()))); - get.AddColumnId(check_column_id.index); - update.columns.push_back(check_column_id); - } - } -} - -vector TableCatalogEntry::GetColumnSegmentInfo() { - return {}; -} - -void TableCatalogEntry::BindUpdateConstraints(Binder &binder, LogicalGet &get, LogicalProjection &proj, - LogicalUpdate &update, ClientContext &context) { - // check the constraints and indexes of the table to see if we need to project any additional columns - // we do this for indexes with multiple columns and CHECK constraints in the UPDATE clause - // suppose we have a constraint CHECK(i + j < 10); now we need both i and j to check the constraint - // if we are only updating one of the two columns we add the other one to the UPDATE set - // with a "useless" update (i.e. i=i) so we can verify that the CHECK constraint is not violated - auto bound_constraints = binder.BindConstraints(constraints, name, columns); - for (auto &constraint : bound_constraints) { - if (constraint->type == ConstraintType::CHECK) { - auto &check = constraint->Cast(); - // check constraint! check if we need to add any extra columns to the UPDATE clause - BindExtraColumns(*this, get, proj, update, check.bound_columns); - } - } - if (update.return_chunk) { - physical_index_set_t all_columns; - for (auto &column : GetColumns().Physical()) { - all_columns.insert(column.Physical()); - } - BindExtraColumns(*this, get, proj, update, all_columns); - } - // for index updates we always turn any update into an insert and a delete - // we thus need all the columns to be available, hence we check if the update touches any index columns - // If the returning keyword is used, we need access to the whole row in case the user requests it. - // Therefore switch the update to a delete and insert. 
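// Sketch of the decision rule implemented just below: an UPDATE is rewritten as
// DELETE + INSERT as soon as any written column participates in an index or has a
// type that cannot be updated in place (e.g. LIST). The parameter names here are
// illustrative stand-ins for the real index_info / SupportsRegularUpdate checks:
#include <cstdint>
#include <unordered_set>
#include <vector>

static bool NeedsDeleteAndInsert(const std::vector<uint64_t> &updated_columns,
                                 const std::unordered_set<uint64_t> &indexed_columns,
                                 const std::unordered_set<uint64_t> &no_in_place_update_columns) {
	for (const auto col : updated_columns) {
		if (indexed_columns.count(col) != 0 || no_in_place_update_columns.count(col) != 0) {
			return true; // the whole row must be rewritten
		}
	}
	return false; // every touched column supports an in-place update
}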
- update.update_is_del_and_insert = false; - TableStorageInfo table_storage_info = GetStorageInfo(context); - for (auto index : table_storage_info.index_info) { - for (auto &column : update.columns) { - if (index.column_set.find(column.index) != index.column_set.end()) { - update.update_is_del_and_insert = true; - break; - } - } - }; - - // we also convert any updates on LIST columns into delete + insert - for (auto &col_index : update.columns) { - auto &column = GetColumns().GetColumn(col_index); - if (!column.Type().SupportsRegularUpdate()) { - update.update_is_del_and_insert = true; - break; - } - } - - if (update.update_is_del_and_insert) { - // the update updates a column required by an index or requires returning the updated rows, - // push projections for all columns - physical_index_set_t all_columns; - for (auto &column : GetColumns().Physical()) { - all_columns.insert(column.Physical()); - } - BindExtraColumns(*this, get, proj, update, all_columns); - } -} - -optional_ptr TableCatalogEntry::GetPrimaryKey() const { - for (const auto &constraint : GetConstraints()) { - if (constraint->type == ConstraintType::UNIQUE) { - auto &unique = constraint->Cast(); - if (unique.IsPrimaryKey()) { - return &unique; - } - } - } - return nullptr; -} - -bool TableCatalogEntry::HasPrimaryKey() const { - return GetPrimaryKey() != nullptr; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp deleted file mode 100644 index 9b80a1d5d..000000000 --- a/src/duckdb/src/catalog/catalog_entry/table_function_catalog_entry.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" -#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" - -namespace duckdb { - -TableFunctionCatalogEntry::TableFunctionCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, - CreateTableFunctionInfo &info) - : FunctionEntry(CatalogType::TABLE_FUNCTION_ENTRY, catalog, schema, info), functions(std::move(info.functions)) { - D_ASSERT(this->functions.Size() > 0); -} - -unique_ptr TableFunctionCatalogEntry::AlterEntry(CatalogTransaction transaction, AlterInfo &info) { - if (info.type != AlterType::ALTER_TABLE_FUNCTION) { - throw InternalException("Attempting to alter TableFunctionCatalogEntry with unsupported alter type"); - } - auto &function_info = info.Cast(); - if (function_info.alter_table_function_type != AlterTableFunctionType::ADD_FUNCTION_OVERLOADS) { - throw InternalException( - "Attempting to alter TableFunctionCatalogEntry with unsupported alter table function type"); - } - auto &add_overloads = function_info.Cast(); - - TableFunctionSet new_set = functions; - if (!new_set.MergeFunctionSet(add_overloads.new_overloads)) { - throw BinderException("Failed to add new function overloads to function \"%s\": function already exists", name); - } - CreateTableFunctionInfo new_info(std::move(new_set)); - return make_uniq(catalog, schema, new_info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp deleted file mode 100644 index be9c24834..000000000 --- a/src/duckdb/src/catalog/catalog_entry/type_catalog_entry.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include 
"duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/parser/keyword_helper.hpp" -#include -#include - -namespace duckdb { - -TypeCatalogEntry::TypeCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateTypeInfo &info) - : StandardEntry(CatalogType::TYPE_ENTRY, schema, catalog, info.name), user_type(info.type), - bind_function(info.bind_function) { - this->temporary = info.temporary; - this->internal = info.internal; - this->dependencies = info.dependencies; - this->comment = info.comment; - this->tags = info.tags; -} - -unique_ptr TypeCatalogEntry::Copy(ClientContext &context) const { - auto info_copy = GetInfo(); - auto &cast_info = info_copy->Cast(); - auto result = make_uniq(catalog, schema, cast_info); - return std::move(result); -} - -unique_ptr TypeCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->catalog = catalog.GetName(); - result->schema = schema.name; - result->name = name; - result->type = user_type; - result->dependencies = dependencies; - result->comment = comment; - result->tags = tags; - result->bind_function = bind_function; - return std::move(result); -} - -string TypeCatalogEntry::ToSQL() const { - std::stringstream ss; - ss << "CREATE TYPE "; - ss << KeywordHelper::WriteOptionallyQuoted(name); - ss << " AS "; - - auto user_type_copy = user_type; - - // Strip off the potential alias so ToString doesn't just output the alias - user_type_copy.SetAlias(""); - D_ASSERT(user_type_copy.GetAlias().empty()); - - ss << user_type_copy.ToString(); - ss << ";"; - return ss.str(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp b/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp deleted file mode 100644 index 9f029f214..000000000 --- a/src/duckdb/src/catalog/catalog_entry/view_catalog_entry.cpp +++ /dev/null @@ -1,110 +0,0 @@ -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" - -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/exception/binder_exception.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/parsed_data/comment_on_column_info.hpp" -#include "duckdb/common/limits.hpp" - -#include - -namespace duckdb { - -void ViewCatalogEntry::Initialize(CreateViewInfo &info) { - query = std::move(info.query); - this->aliases = info.aliases; - this->types = info.types; - this->names = info.names; - this->temporary = info.temporary; - this->sql = info.sql; - this->internal = info.internal; - this->dependencies = info.dependencies; - this->comment = info.comment; - this->tags = info.tags; - this->column_comments = info.column_comments; -} - -ViewCatalogEntry::ViewCatalogEntry(Catalog &catalog, SchemaCatalogEntry &schema, CreateViewInfo &info) - : StandardEntry(CatalogType::VIEW_ENTRY, schema, catalog, info.view_name) { - Initialize(info); -} - -unique_ptr ViewCatalogEntry::GetInfo() const { - auto result = make_uniq(); - result->schema = schema.name; - result->view_name = name; - result->sql = sql; - result->query = unique_ptr_cast(query->Copy()); - result->aliases = aliases; - result->names = names; - result->types = types; - result->temporary = temporary; - result->dependencies = dependencies; - result->comment = comment; - result->tags = tags; - result->column_comments = column_comments; - return std::move(result); -} - -unique_ptr ViewCatalogEntry::AlterEntry(ClientContext 
&context, AlterInfo &info) { - D_ASSERT(!internal); - - // Column comments have a special alter type - if (info.type == AlterType::SET_COLUMN_COMMENT) { - auto &comment_on_column_info = info.Cast(); - auto copied_view = Copy(context); - - for (idx_t i = 0; i < names.size(); i++) { - const auto &col_name = names[i]; - if (col_name == comment_on_column_info.column_name) { - auto &copied_view_entry = copied_view->Cast(); - - // If vector is empty, we need to initialize it on setting here - if (copied_view_entry.column_comments.empty()) { - copied_view_entry.column_comments = vector(copied_view_entry.types.size()); - } - - copied_view_entry.column_comments[i] = comment_on_column_info.comment_value; - return copied_view; - } - } - throw BinderException("View \"%s\" does not have a column with name \"%s\"", name, - comment_on_column_info.column_name); - } - - if (info.type != AlterType::ALTER_VIEW) { - throw CatalogException("Can only modify view with ALTER VIEW statement"); - } - auto &view_info = info.Cast(); - switch (view_info.alter_view_type) { - case AlterViewType::RENAME_VIEW: { - auto &rename_info = view_info.Cast(); - auto copied_view = Copy(context); - copied_view->name = rename_info.new_view_name; - return copied_view; - } - default: - throw InternalException("Unrecognized alter view type!"); - } -} - -string ViewCatalogEntry::ToSQL() const { - if (sql.empty()) { - //! Return empty sql with view name so pragma view_tables don't complain - return sql; - } - auto info = GetInfo(); - auto result = info->ToString(); - return result; -} - -unique_ptr ViewCatalogEntry::Copy(ClientContext &context) const { - D_ASSERT(!internal); - auto create_info = GetInfo(); - - return make_uniq(catalog, schema, create_info->Cast()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_entry_retriever.cpp b/src/duckdb/src/catalog/catalog_entry_retriever.cpp deleted file mode 100644 index c37562d72..000000000 --- a/src/duckdb/src/catalog/catalog_entry_retriever.cpp +++ /dev/null @@ -1,119 +0,0 @@ -#include "duckdb/catalog/catalog_entry_retriever.hpp" -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry.hpp" -#include "duckdb/parser/query_error_context.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/enums/on_entry_not_found.hpp" -#include "duckdb/common/enums/catalog_type.hpp" -#include "duckdb/common/optional_ptr.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -LogicalType CatalogEntryRetriever::GetType(Catalog &catalog, const string &schema, const string &name, - OnEntryNotFound on_entry_not_found) { - QueryErrorContext error_context; - auto result = GetEntry(CatalogType::TYPE_ENTRY, catalog, schema, name, on_entry_not_found, error_context); - if (!result) { - return LogicalType::INVALID; - } - auto &type_entry = result->Cast(); - return type_entry.user_type; -} - -LogicalType CatalogEntryRetriever::GetType(const string &catalog, const string &schema, const string &name, - OnEntryNotFound on_entry_not_found) { - QueryErrorContext error_context; - auto result = GetEntry(CatalogType::TYPE_ENTRY, catalog, schema, name, on_entry_not_found, error_context); - if (!result) { - return LogicalType::INVALID; - } - auto &type_entry = result->Cast(); - return type_entry.user_type; -} - -optional_ptr CatalogEntryRetriever::GetEntry(CatalogType type, const string &catalog, - const string &schema, const string &name, - OnEntryNotFound on_entry_not_found, - 
QueryErrorContext error_context) { - return ReturnAndCallback(Catalog::GetEntry(*this, type, catalog, schema, name, on_entry_not_found, error_context)); -} - -optional_ptr CatalogEntryRetriever::GetSchema(const string &catalog, const string &name, - OnEntryNotFound on_entry_not_found, - QueryErrorContext error_context) { - auto result = Catalog::GetSchema(*this, catalog, name, on_entry_not_found, error_context); - if (!result) { - return result; - } - if (callback) { - // Call the callback if it's set - callback(*result); - } - return result; -} - -optional_ptr CatalogEntryRetriever::GetEntry(CatalogType type, Catalog &catalog, const string &schema, - const string &name, OnEntryNotFound on_entry_not_found, - QueryErrorContext error_context) { - return ReturnAndCallback(catalog.GetEntry(*this, type, schema, name, on_entry_not_found, error_context)); -} - -optional_ptr CatalogEntryRetriever::ReturnAndCallback(optional_ptr result) { - if (!result) { - return result; - } - if (callback) { - // Call the callback if it's set - callback(*result); - } - return result; -} - -void CatalogEntryRetriever::Inherit(const CatalogEntryRetriever &parent) { - this->callback = parent.callback; - this->search_path = parent.search_path; -} - -CatalogSearchPath &CatalogEntryRetriever::GetSearchPath() { - if (search_path) { - return *search_path; - } - return *ClientData::Get(context).catalog_search_path; -} - -void CatalogEntryRetriever::SetSearchPath(vector entries) { - vector new_path; - for (auto &entry : entries) { - if (IsInvalidCatalog(entry.catalog) || entry.catalog == SYSTEM_CATALOG || entry.catalog == TEMP_CATALOG) { - continue; - } - new_path.push_back(std::move(entry)); - } - if (new_path.empty()) { - return; - } - - // push the set paths from the ClientContext behind the provided paths - auto &client_search_path = *ClientData::Get(context).catalog_search_path; - auto &set_paths = client_search_path.GetSetPaths(); - for (auto path : set_paths) { - if (IsInvalidCatalog(path.catalog)) { - path.catalog = DatabaseManager::GetDefaultDatabase(context); - } - new_path.push_back(std::move(path)); - } - - this->search_path = make_shared_ptr(context, std::move(new_path)); -} - -void CatalogEntryRetriever::SetCallback(catalog_entry_callback_t callback) { - this->callback = std::move(callback); -} - -catalog_entry_callback_t CatalogEntryRetriever::GetCallback() { - return callback; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_search_path.cpp b/src/duckdb/src/catalog/catalog_search_path.cpp deleted file mode 100644 index dcf2569d8..000000000 --- a/src/duckdb/src/catalog/catalog_search_path.cpp +++ /dev/null @@ -1,284 +0,0 @@ -#include "duckdb/catalog/catalog_search_path.hpp" -#include "duckdb/catalog/default/default_schemas.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/constants.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database_manager.hpp" - -namespace duckdb { - -CatalogSearchEntry::CatalogSearchEntry(string catalog_p, string schema_p) - : catalog(std::move(catalog_p)), schema(std::move(schema_p)) { -} - -string CatalogSearchEntry::ToString() const { - if (catalog.empty()) { - return WriteOptionallyQuoted(schema); - } else { - return WriteOptionallyQuoted(catalog) + "." + WriteOptionallyQuoted(schema); - } -} - -string CatalogSearchEntry::WriteOptionallyQuoted(const string &input) { - for (idx_t i = 0; i < input.size(); i++) { - if (input[i] == '.' 
|| input[i] == ',') { - return "\"" + input + "\""; - } - } - return input; -} - -string CatalogSearchEntry::ListToString(const vector &input) { - string result; - for (auto &entry : input) { - if (!result.empty()) { - result += ","; - } - result += entry.ToString(); - } - return result; -} - -CatalogSearchEntry CatalogSearchEntry::ParseInternal(const string &input, idx_t &idx) { - string catalog; - string schema; - string entry; - bool finished = false; -normal: - for (; idx < input.size(); idx++) { - if (input[idx] == '"') { - idx++; - goto quoted; - } else if (input[idx] == '.') { - goto separator; - } else if (input[idx] == ',') { - finished = true; - goto separator; - } - entry += input[idx]; - } - finished = true; - goto separator; -quoted: - //! look for another quote - for (; idx < input.size(); idx++) { - if (input[idx] == '"') { - //! unquote - idx++; - if (idx < input.size() && input[idx] == '"') { - // escaped quote - entry += input[idx]; - continue; - } - goto normal; - } - entry += input[idx]; - } - throw ParserException("Unterminated quote in qualified name!"); -separator: - if (entry.empty()) { - throw ParserException("Unexpected dot - empty CatalogSearchEntry"); - } - if (schema.empty()) { - // if we parse one entry it is the schema - schema = std::move(entry); - } else if (catalog.empty()) { - // if we parse two entries it is [catalog.schema] - catalog = std::move(schema); - schema = std::move(entry); - } else { - throw ParserException("Too many dots - expected [schema] or [catalog.schema] for CatalogSearchEntry"); - } - entry = ""; - idx++; - if (finished) { - goto final; - } - goto normal; -final: - if (schema.empty()) { - throw ParserException("Unexpected end of entry - empty CatalogSearchEntry"); - } - return CatalogSearchEntry(std::move(catalog), std::move(schema)); -} - -CatalogSearchEntry CatalogSearchEntry::Parse(const string &input) { - idx_t pos = 0; - auto result = ParseInternal(input, pos); - if (pos < input.size()) { - throw ParserException("Failed to convert entry \"%s\" to CatalogSearchEntry - expected a single entry", input); - } - return result; -} - -vector CatalogSearchEntry::ParseList(const string &input) { - idx_t pos = 0; - vector result; - while (pos < input.size()) { - auto entry = ParseInternal(input, pos); - result.push_back(entry); - } - return result; -} - -CatalogSearchPath::CatalogSearchPath(ClientContext &context_p, vector entries) - : context(context_p) { - SetPathsInternal(std::move(entries)); -} - -CatalogSearchPath::CatalogSearchPath(ClientContext &context_p) : CatalogSearchPath(context_p, {}) { -} - -void CatalogSearchPath::Reset() { - vector empty; - SetPathsInternal(empty); -} - -string CatalogSearchPath::GetSetName(CatalogSetPathType set_type) { - switch (set_type) { - case CatalogSetPathType::SET_SCHEMA: - return "SET schema"; - case CatalogSetPathType::SET_SCHEMAS: - return "SET search_path"; - default: - throw InternalException("Unrecognized CatalogSetPathType"); - } -} - -void CatalogSearchPath::Set(vector new_paths, CatalogSetPathType set_type) { - if (set_type != CatalogSetPathType::SET_SCHEMAS && new_paths.size() != 1) { - throw CatalogException("%s can set only 1 schema. 
This has %d", GetSetName(set_type), new_paths.size()); - } - for (auto &path : new_paths) { - auto schema_entry = Catalog::GetSchema(context, path.catalog, path.schema, OnEntryNotFound::RETURN_NULL); - if (schema_entry) { - // we are setting a schema - update the catalog and schema - if (path.catalog.empty()) { - path.catalog = GetDefault().catalog; - } - continue; - } - // only schema supplied - check if this is a catalog instead - if (path.catalog.empty()) { - auto catalog = Catalog::GetCatalogEntry(context, path.schema); - if (catalog) { - auto schema = catalog->GetSchema(context, DEFAULT_SCHEMA, OnEntryNotFound::RETURN_NULL); - if (schema) { - path.catalog = std::move(path.schema); - path.schema = schema->name; - continue; - } - } - } - throw CatalogException("%s: No catalog + schema named \"%s\" found.", GetSetName(set_type), path.ToString()); - } - if (set_type == CatalogSetPathType::SET_SCHEMA) { - if (new_paths[0].catalog == TEMP_CATALOG || new_paths[0].catalog == SYSTEM_CATALOG) { - throw CatalogException("%s cannot be set to internal schema \"%s\"", GetSetName(set_type), - new_paths[0].catalog); - } - } - SetPathsInternal(std::move(new_paths)); -} - -void CatalogSearchPath::Set(CatalogSearchEntry new_value, CatalogSetPathType set_type) { - vector new_paths {std::move(new_value)}; - Set(std::move(new_paths), set_type); -} - -const vector &CatalogSearchPath::Get() { - return paths; -} - -string CatalogSearchPath::GetDefaultSchema(const string &catalog) { - for (auto &path : paths) { - if (path.catalog == TEMP_CATALOG) { - continue; - } - if (StringUtil::CIEquals(path.catalog, catalog)) { - return path.schema; - } - } - return DEFAULT_SCHEMA; -} - -string CatalogSearchPath::GetDefaultCatalog(const string &schema) { - if (DefaultSchemaGenerator::IsDefaultSchema(schema)) { - return SYSTEM_CATALOG; - } - for (auto &path : paths) { - if (path.catalog == TEMP_CATALOG) { - continue; - } - if (StringUtil::CIEquals(path.schema, schema)) { - return path.catalog; - } - } - return INVALID_CATALOG; -} - -vector CatalogSearchPath::GetCatalogsForSchema(const string &schema) { - vector schemas; - if (DefaultSchemaGenerator::IsDefaultSchema(schema)) { - schemas.push_back(SYSTEM_CATALOG); - } else { - for (auto &path : paths) { - if (StringUtil::CIEquals(path.schema, schema)) { - schemas.push_back(path.catalog); - } - } - } - return schemas; -} - -vector CatalogSearchPath::GetSchemasForCatalog(const string &catalog) { - vector schemas; - for (auto &path : paths) { - if (StringUtil::CIEquals(path.catalog, catalog)) { - schemas.push_back(path.schema); - } - } - return schemas; -} - -const CatalogSearchEntry &CatalogSearchPath::GetDefault() { - const auto &paths = Get(); - D_ASSERT(paths.size() >= 2); - return paths[1]; -} - -void CatalogSearchPath::SetPathsInternal(vector new_paths) { - this->set_paths = std::move(new_paths); - - paths.clear(); - paths.reserve(set_paths.size() + 3); - paths.emplace_back(TEMP_CATALOG, DEFAULT_SCHEMA); - for (auto &path : set_paths) { - paths.push_back(path); - } - paths.emplace_back(INVALID_CATALOG, DEFAULT_SCHEMA); - paths.emplace_back(SYSTEM_CATALOG, DEFAULT_SCHEMA); - paths.emplace_back(SYSTEM_CATALOG, "pg_catalog"); -} - -bool CatalogSearchPath::SchemaInSearchPath(ClientContext &context, const string &catalog_name, - const string &schema_name) { - for (auto &path : paths) { - if (!StringUtil::CIEquals(path.schema, schema_name)) { - continue; - } - if (StringUtil::CIEquals(path.catalog, catalog_name)) { - return true; - } - if (IsInvalidCatalog(path.catalog) && - 
StringUtil::CIEquals(catalog_name, DatabaseManager::GetDefaultDatabase(context))) { - return true; - } - } - return false; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/catalog_set.cpp b/src/duckdb/src/catalog/catalog_set.cpp deleted file mode 100644 index b95f85747..000000000 --- a/src/duckdb/src/catalog/catalog_set.cpp +++ /dev/null @@ -1,731 +0,0 @@ -#include "duckdb/catalog/catalog_set.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/memory_stream.hpp" -#include "duckdb/common/serializer/binary_serializer.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/transaction/duck_transaction_manager.hpp" -#include "duckdb/transaction/transaction_manager.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/common/exception/transaction_exception.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" - -namespace duckdb { - -void CatalogEntryMap::AddEntry(unique_ptr entry) { - auto name = entry->name; - - if (entries.find(name) != entries.end()) { - throw InternalException("Entry with name \"%s\" already exists", name); - } - entries.insert(make_pair(name, std::move(entry))); -} - -void CatalogEntryMap::UpdateEntry(unique_ptr catalog_entry) { - auto name = catalog_entry->name; - - auto entry = entries.find(name); - if (entry == entries.end()) { - throw InternalException("Entry with name \"%s\" does not exist", name); - } - - auto existing = std::move(entry->second); - entry->second = std::move(catalog_entry); - entry->second->SetChild(std::move(existing)); -} - -case_insensitive_tree_t> &CatalogEntryMap::Entries() { - return entries; -} - -void CatalogEntryMap::DropEntry(CatalogEntry &entry) { - auto &name = entry.name; - auto chain = GetEntry(name); - if (!chain) { - throw InternalException("Attempting to drop entry with name \"%s\" but no chain with that name exists", name); - } - auto child = entry.TakeChild(); - if (!entry.HasParent()) { - // This is the top of the chain - D_ASSERT(chain.get() == &entry); - auto it = entries.find(name); - D_ASSERT(it != entries.end()); - - // Remove the entry - it->second.reset(); - if (child) { - // Replace it with its child - it->second = std::move(child); - } else { - entries.erase(it); - } - } else { - // Just replace the entry with its child - auto &parent = entry.Parent(); - parent.SetChild(std::move(child)); - } -} - -optional_ptr CatalogEntryMap::GetEntry(const string &name) { - auto entry = entries.find(name); - if (entry == entries.end()) { - return nullptr; - } - return entry->second.get(); -} - -CatalogSet::CatalogSet(Catalog &catalog_p, unique_ptr defaults) - : catalog(catalog_p.Cast()), defaults(std::move(defaults)) { - D_ASSERT(catalog_p.IsDuckCatalog()); -} -CatalogSet::~CatalogSet() { -} - -bool CatalogSet::StartChain(CatalogTransaction transaction, const string &name, unique_lock &read_lock) { - D_ASSERT(!map.GetEntry(name)); - - // check if there is a default entry - auto entry = CreateDefaultEntry(transaction, name, read_lock); - if (entry) { - return false; - } - - // first create a dummy deleted entry - // so other transactions will see that instead of 
the entry that is to be added. - auto dummy_node = make_uniq(CatalogType::INVALID, catalog, name); - dummy_node->timestamp = 0; - dummy_node->deleted = true; - dummy_node->set = this; - - map.AddEntry(std::move(dummy_node)); - return true; -} - -bool CatalogSet::VerifyVacancy(CatalogTransaction transaction, CatalogEntry &entry) { - if (HasConflict(transaction, entry.timestamp)) { - // A transaction that is not visible to our snapshot has already made a change to this entry. - // Because of Catalog limitations we can't push our change on this, even if the change was made by another - // active transaction that might end up being aborted. So we have to cancel this transaction. - throw TransactionException("Catalog write-write conflict on create with \"%s\"", entry.name); - } - // The entry is visible to our snapshot - if (!entry.deleted) { - return false; - } - return true; -} - -static bool IsDependencyEntry(CatalogEntry &entry) { - return entry.type == CatalogType::DEPENDENCY_ENTRY; -} - -void CatalogSet::CheckCatalogEntryInvariants(CatalogEntry &value, const string &name) { - if (value.internal && !catalog.IsSystemCatalog() && name != DEFAULT_SCHEMA) { - throw InternalException("Attempting to create internal entry \"%s\" in non-system catalog - internal entries " - "can only be created in the system catalog", - name); - } - if (!value.internal) { - if (!value.temporary && catalog.IsSystemCatalog() && !IsDependencyEntry(value)) { - throw InternalException( - "Attempting to create non-internal entry \"%s\" in system catalog - the system catalog " - "can only contain internal entries", - name); - } - if (value.temporary && !catalog.IsTemporaryCatalog()) { - throw InternalException("Attempting to create temporary entry \"%s\" in non-temporary catalog", name); - } - if (!value.temporary && catalog.IsTemporaryCatalog() && name != DEFAULT_SCHEMA) { - throw InvalidInputException("Cannot create non-temporary entry \"%s\" in temporary catalog", name); - } - } -} - -optional_ptr CatalogSet::CreateCommittedEntry(unique_ptr entry) { - auto existing_entry = map.GetEntry(entry->name); - if (existing_entry) { - // Return null if an entry by that name already exists - return nullptr; - } - - auto catalog_entry = entry.get(); - - entry->set = this; - // Give the entry commit id 0, so it is visible to all transactions - entry->timestamp = 0; - map.AddEntry(std::move(entry)); - - return catalog_entry; -} - -bool CatalogSet::CreateEntryInternal(CatalogTransaction transaction, const string &name, unique_ptr value, - unique_lock &read_lock, bool should_be_empty) { - auto entry_value = map.GetEntry(name); - if (!entry_value) { - // Add a dummy node to start the chain - if (!StartChain(transaction, name, read_lock)) { - return false; - } - } else if (should_be_empty) { - // Verify that the entry is deleted, not altered by another transaction - if (!VerifyVacancy(transaction, *entry_value)) { - return false; - } - } - - // Finally add the new entry to the chain - auto value_ptr = value.get(); - map.UpdateEntry(std::move(value)); - // Push the old entry in the undo buffer for this transaction, so it can be restored in the event of failure - if (transaction.transaction) { - DuckTransactionManager::Get(GetCatalog().GetAttached()) - .PushCatalogEntry(*transaction.transaction, value_ptr->Child()); - } - return true; -} - -bool CatalogSet::CreateEntry(CatalogTransaction transaction, const string &name, unique_ptr value, - const LogicalDependencyList &dependencies) { - CheckCatalogEntryInvariants(*value, name); - - // 
Mark this entry as being created by the current active transaction - value->timestamp = transaction.transaction_id; - value->set = this; - catalog.GetDependencyManager()->AddObject(transaction, *value, dependencies); - - // lock the catalog for writing - lock_guard write_lock(catalog.GetWriteLock()); - // lock this catalog set to disallow reading - unique_lock read_lock(catalog_lock); - - return CreateEntryInternal(transaction, name, std::move(value), read_lock); -} - -bool CatalogSet::CreateEntry(ClientContext &context, const string &name, unique_ptr value, - const LogicalDependencyList &dependencies) { - return CreateEntry(catalog.GetCatalogTransaction(context), name, std::move(value), dependencies); -} - -//! This method is used to retrieve an entry for the purpose of making a new version, through an alter/drop/create -optional_ptr CatalogSet::GetEntryInternal(CatalogTransaction transaction, const string &name) { - auto entry_value = map.GetEntry(name); - if (!entry_value) { - return nullptr; - } - auto &catalog_entry = *entry_value; - - // Check if this entry is visible to our snapshot - if (HasConflict(transaction, catalog_entry.timestamp)) { - // We intend to create a new version of the entry. - // Another transaction has already made an edit to this catalog entry, because of limitations in the Catalog we - // can't create an edit alongside this even if the other transaction might end up getting aborted. So we have to - // abort the transaction. - throw TransactionException("Catalog write-write conflict on alter with \"%s\"", catalog_entry.name); - } - // The entry is visible to our snapshot, check if it's deleted - if (catalog_entry.deleted) { - return nullptr; - } - return &catalog_entry; -} - -bool CatalogSet::AlterOwnership(CatalogTransaction transaction, ChangeOwnershipInfo &info) { - // lock the catalog for writing - unique_lock write_lock(catalog.GetWriteLock()); - - auto entry = GetEntryInternal(transaction, info.name); - if (!entry) { - return false; - } - optional_ptr owner_entry; - auto schema = catalog.GetSchema(transaction, info.owner_schema, OnEntryNotFound::RETURN_NULL); - if (schema) { - vector entry_types {CatalogType::TABLE_ENTRY, CatalogType::SEQUENCE_ENTRY}; - for (auto entry_type : entry_types) { - owner_entry = schema->GetEntry(transaction, entry_type, info.owner_name); - if (owner_entry) { - break; - } - } - } - if (!owner_entry) { - throw CatalogException("CatalogElement \"%s.%s\" does not exist!", info.owner_schema, info.owner_name); - } - write_lock.unlock(); - catalog.GetDependencyManager()->AddOwnership(transaction, *owner_entry, *entry); - return true; -} - -bool CatalogSet::RenameEntryInternal(CatalogTransaction transaction, CatalogEntry &old, const string &new_name, - AlterInfo &alter_info, unique_lock &read_lock) { - auto &original_name = old.name; - - auto &context = *transaction.context; - auto entry_value = map.GetEntry(new_name); - if (entry_value) { - auto &existing_entry = GetEntryForTransaction(transaction, *entry_value); - if (!existing_entry.deleted) { - // There exists an entry by this name that is not deleted - old.UndoAlter(context, alter_info); - throw CatalogException("Could not rename \"%s\" to \"%s\": another entry with this name already exists!", - original_name, new_name); - } - } - - // Add a RENAMED_ENTRY before adding a DELETED_ENTRY, this makes it so that when this is committed - // we know that this was not a DROP statement. 
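// Sketch of the version-chain layout these tombstones rely on: each name maps to a
// chain of entry versions, newest first; deletes and renames push tombstone nodes
// stamped with the writing transaction's id, and a reader walks down the chain to
// the newest version visible to its snapshot. The types below are hypothetical
// simplifications of CatalogEntry's parent/child links and visibility rules:
#include <cstdint>
#include <memory>

struct SketchVersion {
	uint64_t timestamp = 0;               // id of the transaction that wrote this version
	bool deleted = false;                 // true for DELETED_ENTRY / RENAMED_ENTRY tombstones
	std::unique_ptr<SketchVersion> child; // the older version this one shadows
};

// Walk towards older versions until one is visible to the reader's snapshot.
static const SketchVersion *VisibleVersion(const SketchVersion *head, uint64_t snapshot_id) {
	while (head && head->timestamp > snapshot_id) {
		head = head->child.get();
	}
	return head; // may be a tombstone, in which case the entry reads as dropped
}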
-	auto renamed_tombstone = make_uniq<InCatalogEntry>(CatalogType::RENAMED_ENTRY, old.ParentCatalog(), original_name);
-	renamed_tombstone->timestamp = transaction.transaction_id;
-	renamed_tombstone->deleted = false;
-	renamed_tombstone->set = this;
-	if (!CreateEntryInternal(transaction, original_name, std::move(renamed_tombstone), read_lock,
-	                         /*should_be_empty = */ false)) {
-		return false;
-	}
-	if (!DropEntryInternal(transaction, original_name, false)) {
-		return false;
-	}
-
-	// Add the renamed entry
-	// Start this off with a RENAMED_ENTRY node, for commit/cleanup/rollback purposes
-	auto renamed_node = make_uniq<InCatalogEntry>(CatalogType::RENAMED_ENTRY, catalog, new_name);
-	renamed_node->timestamp = transaction.transaction_id;
-	renamed_node->deleted = false;
-	renamed_node->set = this;
-	return CreateEntryInternal(transaction, new_name, std::move(renamed_node), read_lock);
-}
-
-bool CatalogSet::AlterEntry(CatalogTransaction transaction, const string &name, AlterInfo &alter_info) {
-	// If the entry does not exist, there is nothing to alter
-	auto entry = GetEntry(transaction, name);
-	if (!entry) {
-		return false;
-	}
-	if (!alter_info.allow_internal && entry->internal) {
-		throw CatalogException("Cannot alter entry \"%s\" because it is an internal system entry", entry->name);
-	}
-
-	unique_ptr<CatalogEntry> value;
-	if (alter_info.type == AlterType::SET_COMMENT) {
-		// Copy the existing entry; we are only changing metadata here
-		if (!transaction.context) {
-			throw InternalException("Cannot AlterEntry::SET_COMMENT without client context");
-		}
-		value = entry->Copy(*transaction.context);
-		value->comment = alter_info.Cast<SetCommentInfo>().comment_value;
-	} else {
-		// Use the existing entry to create the altered entry
-		value = entry->AlterEntry(transaction, alter_info);
-		if (!value) {
-			// alter failed, but did not result in an error
-			return true;
-		}
-	}
-
-	// lock the catalog for writing
-	unique_lock<mutex> write_lock(catalog.GetWriteLock());
-	// lock this catalog set to disallow reading
-	unique_lock<mutex> read_lock(catalog_lock);
-
-	// fetch the entry again before doing the modification
-	// this will catch any write-write conflicts between transactions
-	entry = GetEntryInternal(transaction, name);
-
-	// Mark this entry as being created by this transaction
-	value->timestamp = transaction.transaction_id;
-	value->set = this;
-
-	if (!StringUtil::CIEquals(value->name, entry->name)) {
-		if (!RenameEntryInternal(transaction, *entry, value->name, alter_info, read_lock)) {
-			return false;
-		}
-	}
-	auto new_entry = value.get();
-	map.UpdateEntry(std::move(value));
-
-	// push the old entry into the undo buffer for this transaction
-	if (transaction.transaction) {
-		// serialize the AlterInfo into a temporary buffer
-		MemoryStream stream;
-		BinarySerializer serializer(stream);
-		serializer.Begin();
-		serializer.WriteProperty(100, "column_name", alter_info.GetColumnName());
-		serializer.WriteProperty(101, "alter_info", &alter_info);
-		serializer.End();
-
-		DuckTransactionManager::Get(GetCatalog().GetAttached())
-		    .PushCatalogEntry(*transaction.transaction, new_entry->Child(), stream.GetData(), stream.GetPosition());
-	}
-
-	read_lock.unlock();
-	write_lock.unlock();
-
-	// Check the dependency manager to verify that there are no conflicting dependencies with this alter
-	catalog.GetDependencyManager()->AlterObject(transaction, *entry, *new_entry, alter_info);
-
-	return true;
-}
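-
-// Note on the serialization above: the undo-buffer entry for an ALTER carries the
-// binary-serialized AlterInfo (field 101) plus the affected column name (field 100)
-// alongside the old catalog entry, presumably so the commit path can replay the
-// ALTER later (e.g. when writing the catalog change out) without re-deriving it
-// from the entries themselves.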
-
-bool CatalogSet::DropDependencies(CatalogTransaction transaction, const string &name, bool cascade,
-                                  bool allow_drop_internal) {
-	auto entry = GetEntry(transaction, name);
-	if (!entry) {
-		return false;
-	}
-	if (entry->internal && !allow_drop_internal) {
-		throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name);
-	}
-	// check any dependencies of this object
-	D_ASSERT(entry->ParentCatalog().IsDuckCatalog());
-	auto &duck_catalog = entry->ParentCatalog().Cast<DuckCatalog>();
-	duck_catalog.GetDependencyManager()->DropObject(transaction, *entry, cascade);
-	return true;
-}
-
-bool CatalogSet::DropEntryInternal(CatalogTransaction transaction, const string &name, bool allow_drop_internal) {
-	// lock the catalog for writing
-	// we can only delete an entry that exists
-	auto entry = GetEntryInternal(transaction, name);
-	if (!entry) {
-		return false;
-	}
-	if (entry->internal && !allow_drop_internal) {
-		throw CatalogException("Cannot drop entry \"%s\" because it is an internal system entry", entry->name);
-	}
-
-	// create a new tombstone entry and make it the head of the version chain:
-	// set its timestamp to the timestamp of the current transaction
-	// and mark it as deleted
-	auto value = make_uniq<InCatalogEntry>(CatalogType::DELETED_ENTRY, entry->ParentCatalog(), entry->name);
-	value->timestamp = transaction.transaction_id;
-	value->set = this;
-	value->deleted = true;
-	auto value_ptr = value.get();
-	map.UpdateEntry(std::move(value));
-
-	// push the old entry into the undo buffer for this transaction
-	if (transaction.transaction) {
-		DuckTransactionManager::Get(GetCatalog().GetAttached())
-		    .PushCatalogEntry(*transaction.transaction, value_ptr->Child());
-	}
-	return true;
-}
-
-bool CatalogSet::DropEntry(CatalogTransaction transaction, const string &name, bool cascade, bool allow_drop_internal) {
-	if (!DropDependencies(transaction, name, cascade, allow_drop_internal)) {
-		return false;
-	}
-	lock_guard<mutex> write_lock(catalog.GetWriteLock());
-	lock_guard<mutex> read_lock(catalog_lock);
-	return DropEntryInternal(transaction, name, allow_drop_internal);
-}
-
-bool CatalogSet::DropEntry(ClientContext &context, const string &name, bool cascade, bool allow_drop_internal) {
-	return DropEntry(catalog.GetCatalogTransaction(context), name, cascade, allow_drop_internal);
-}
-
-//! Verify that the object referenced by the dependency still exists when we commit the dependency
-void CatalogSet::VerifyExistenceOfDependency(transaction_t commit_id, CatalogEntry &entry) {
-	auto &duck_catalog = GetCatalog();
-
-	// Make sure that we don't see any uncommitted changes
-	auto transaction_id = MAX_TRANSACTION_ID;
-	// This will allow us to see all committed changes made before this COMMIT happened
-	auto tx_start_time = commit_id;
-	CatalogTransaction commit_transaction(duck_catalog.GetDatabase(), transaction_id, tx_start_time);
-
-	D_ASSERT(entry.type == CatalogType::DEPENDENCY_ENTRY);
-	auto &dep = entry.Cast<DependencyEntry>();
-	duck_catalog.GetDependencyManager()->VerifyExistence(commit_transaction, dep);
-}
-
-//! Verify that no dependency creations that reference the entry we're dropping
-//! were committed since our transaction started
-void CatalogSet::CommitDrop(transaction_t commit_id, transaction_t start_time, CatalogEntry &entry) {
-	auto &duck_catalog = GetCatalog();
-
-	// Make sure that we don't see any uncommitted changes
-	auto transaction_id = MAX_TRANSACTION_ID;
-	// This will allow us to see all committed changes made before this COMMIT happened
-	auto tx_start_time = commit_id;
-	CatalogTransaction commit_transaction(duck_catalog.GetDatabase(), transaction_id, tx_start_time);
-
-	duck_catalog.GetDependencyManager()->VerifyCommitDrop(commit_transaction, start_time, entry);
-}
-
-DuckCatalog &CatalogSet::GetCatalog() {
-	return catalog;
-}
-
-void CatalogSet::CleanupEntry(CatalogEntry &catalog_entry) {
-	// destroy the backed-up entry: it is no longer required
-	lock_guard<mutex> write_lock(catalog.GetWriteLock());
-	lock_guard<mutex> lock(catalog_lock);
-	auto &parent = catalog_entry.Parent();
-	map.DropEntry(catalog_entry);
-	if (parent.deleted && !parent.HasChild() && !parent.HasParent()) {
-		// The entry's parent is a tombstone and the entry had no child
-		// clean up the mapping and the tombstone entry as well
-		D_ASSERT(map.GetEntry(parent.name).get() == &parent);
-		map.DropEntry(parent);
-	}
-}
-
-bool CatalogSet::CreatedByOtherActiveTransaction(CatalogTransaction transaction, transaction_t timestamp) {
-	// True if the entry was made by another active (not yet committed) transaction
-	return (timestamp >= TRANSACTION_ID_START && timestamp != transaction.transaction_id);
-}
-
-bool CatalogSet::CommittedAfterStarting(CatalogTransaction transaction, transaction_t timestamp) {
-	// The entry has been committed after this transaction started, so it is not our source of truth.
-	return (timestamp < TRANSACTION_ID_START && timestamp > transaction.start_time);
-}
-
-bool CatalogSet::HasConflict(CatalogTransaction transaction, transaction_t timestamp) {
-	return CreatedByOtherActiveTransaction(transaction, timestamp) || CommittedAfterStarting(transaction, timestamp);
-}
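-
-// Worked example of the two conflict rules above (illustrative numbers only):
-// commit ids live below TRANSACTION_ID_START, live transaction ids at or above it.
-// For a transaction with start_time = 10 and transaction_id = TRANSACTION_ID_START + 7:
-//     timestamp 8                        -> committed before we started: no conflict
-//     timestamp 12                       -> committed after we started:  conflict
-//     timestamp TRANSACTION_ID_START + 3 -> made by another live transaction: conflict
-//     timestamp TRANSACTION_ID_START + 7 -> our own change: no conflict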
-
-bool CatalogSet::IsCommitted(transaction_t timestamp) {
-	//! FIXME: `transaction_t` itself should be a class that has these methods
-	return timestamp < TRANSACTION_ID_START;
-}
-
-bool CatalogSet::UseTimestamp(CatalogTransaction transaction, transaction_t timestamp) {
-	if (timestamp == transaction.transaction_id) {
-		// we created this version
-		return true;
-	}
-	if (timestamp < transaction.start_time) {
-		// this version was committed before we started the transaction
-		return true;
-	}
-	return false;
-}
-
-CatalogEntry &CatalogSet::GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry &current) {
-	bool visible;
-	return GetEntryForTransaction(transaction, current, visible);
-}
-
-CatalogEntry &CatalogSet::GetEntryForTransaction(CatalogTransaction transaction, CatalogEntry &current, bool &visible) {
-	reference<CatalogEntry> entry(current);
-	while (entry.get().HasChild()) {
-		if (UseTimestamp(transaction, entry.get().timestamp)) {
-			visible = true;
-			return entry.get();
-		}
-		entry = entry.get().Child();
-	}
-	visible = false;
-	return entry.get();
-}
-
-CatalogEntry &CatalogSet::GetCommittedEntry(CatalogEntry &current) {
-	reference<CatalogEntry> entry(current);
-	while (entry.get().HasChild()) {
-		if (entry.get().timestamp < TRANSACTION_ID_START) {
-			// this entry is committed: use it
-			break;
-		}
-		entry = entry.get().Child();
-	}
-	return entry.get();
-}
-
-SimilarCatalogEntry CatalogSet::SimilarEntry(CatalogTransaction transaction, const string &name) {
-	unique_lock<mutex> lock(catalog_lock);
-	CreateDefaultEntries(transaction, lock);
-
-	SimilarCatalogEntry result;
-	for (auto &kv : map.Entries()) {
-		auto entry_score = StringUtil::SimilarityRating(kv.first, name);
-		if (entry_score > result.score) {
-			result.score = entry_score;
-			result.name = kv.first;
-		}
-	}
-	return result;
-}
-
-optional_ptr<CatalogEntry> CatalogSet::CreateDefaultEntry(CatalogTransaction transaction, const string &name,
-                                                          unique_lock<mutex> &read_lock) {
-	// no entry found with this name, check for defaults
-	if (!defaults || defaults->created_all_entries) {
-		// no defaults either: return null
-		return nullptr;
-	}
-	read_lock.unlock();
-	// this catalog set has a default map defined
-	// check if there is a default entry that we can create with this name
-	auto entry = defaults->CreateDefaultEntry(transaction, name);
-
-	read_lock.lock();
-	if (!entry) {
-		// no default entry
-		return nullptr;
-	}
-	// there is a default entry! create it
-	auto result = CreateCommittedEntry(std::move(entry));
-	if (result) {
-		return result;
-	}
-	// we found a default entry, but failed to create it
-	// this means somebody else created the entry first
-	// retry the lookup
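-	// (Retrying is safe: CreateCommittedEntry only fails when the name is already
-	// occupied, and default entries are inserted with commit id 0, which every
-	// snapshot can see - so the re-lookup below finds the winner's entry.)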
-	read_lock.unlock();
-	return GetEntry(transaction, name);
-}
-
-CatalogSet::EntryLookup CatalogSet::GetEntryDetailed(CatalogTransaction transaction, const string &name) {
-	unique_lock<mutex> read_lock(catalog_lock);
-	auto entry_value = map.GetEntry(name);
-	if (entry_value) {
-		// we found an entry for this name
-		// check the version numbers
-
-		auto &catalog_entry = *entry_value;
-		bool visible;
-		auto &current = GetEntryForTransaction(transaction, catalog_entry, visible);
-		if (current.deleted) {
-			if (!visible) {
-				return EntryLookup {nullptr, EntryLookup::FailureReason::INVISIBLE};
-			} else {
-				return EntryLookup {nullptr, EntryLookup::FailureReason::DELETED};
-			}
-		}
-		D_ASSERT(StringUtil::CIEquals(name, current.name));
-		return EntryLookup {&current, EntryLookup::FailureReason::SUCCESS};
-	}
-	auto default_entry = CreateDefaultEntry(transaction, name, read_lock);
-	if (!default_entry) {
-		return EntryLookup {default_entry, EntryLookup::FailureReason::NOT_PRESENT};
-	}
-	return EntryLookup {default_entry, EntryLookup::FailureReason::SUCCESS};
-}
-
-optional_ptr<CatalogEntry> CatalogSet::GetEntry(CatalogTransaction transaction, const string &name) {
-	auto lookup = GetEntryDetailed(transaction, name);
-	return lookup.result;
-}
-
-optional_ptr<CatalogEntry> CatalogSet::GetEntry(ClientContext &context, const string &name) {
-	return GetEntry(catalog.GetCatalogTransaction(context), name);
-}
-
-void CatalogSet::UpdateTimestamp(CatalogEntry &entry, transaction_t timestamp) {
-	entry.timestamp = timestamp;
-}
-
-void CatalogSet::Undo(CatalogEntry &entry) {
-	lock_guard<mutex> write_lock(catalog.GetWriteLock());
-	lock_guard<mutex> lock(catalog_lock);
-
-	// the entry has to be restored
-	// and entry->parent has to be removed ("rolled back")
-
-	// i.e. we have to place (entry) where (entry->parent) was
-	auto &to_be_removed_node = entry.Parent();
-	to_be_removed_node.Rollback(entry);
-
-	D_ASSERT(StringUtil::CIEquals(entry.name, to_be_removed_node.name));
-	if (!to_be_removed_node.HasParent()) {
-		to_be_removed_node.Child().SetAsRoot();
-	}
-	map.DropEntry(to_be_removed_node);
-
-	if (entry.type == CatalogType::INVALID) {
-		// This was the root of the entry chain
-		map.DropEntry(entry);
-	}
-}
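-
-// Rollback sketch (illustration only): Undo(entry) unlinks the version sitting on
-// top of `entry`:
-//
-//     before:  [uncommitted v2] -> [v1] -> ...      Undo(v1)
-//     after:   [v1] -> ...
-//
-// When the restored `entry` is itself the INVALID dummy root - i.e. we are rolling
-// back the creation of a brand-new name - the dummy is dropped as well and the
-// mapping disappears entirely.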
-
-void CatalogSet::CreateDefaultEntries(CatalogTransaction transaction, unique_lock<mutex> &read_lock) {
-	if (!defaults || defaults->created_all_entries) {
-		return;
-	}
-	// this catalog set has a default set defined:
-	auto default_entries = defaults->GetDefaultEntries();
-	for (auto &default_entry : default_entries) {
-		auto entry_value = map.GetEntry(default_entry);
-		if (!entry_value) {
-			// we unlock during the CreateEntry, since it might reference other catalog sets...
-			// specifically for views this can happen since the view will be bound
-			read_lock.unlock();
-			auto entry = defaults->CreateDefaultEntry(transaction, default_entry);
-			if (!entry) {
-				throw InternalException("Failed to create default entry for %s", default_entry);
-			}
-
-			read_lock.lock();
-			CreateCommittedEntry(std::move(entry));
-		}
-	}
-	defaults->created_all_entries = true;
-}
-
-void CatalogSet::Scan(CatalogTransaction transaction, const std::function<void(CatalogEntry &)> &callback) {
-	// lock the catalog set
-	unique_lock<mutex> lock(catalog_lock);
-	CreateDefaultEntries(transaction, lock);
-
-	for (auto &kv : map.Entries()) {
-		auto &entry = *kv.second;
-		auto &entry_for_transaction = GetEntryForTransaction(transaction, entry);
-		if (!entry_for_transaction.deleted) {
-			callback(entry_for_transaction);
-		}
-	}
-}
-
-void CatalogSet::Scan(ClientContext &context, const std::function<void(CatalogEntry &)> &callback) {
-	Scan(catalog.GetCatalogTransaction(context), callback);
-}
-
-void CatalogSet::ScanWithPrefix(CatalogTransaction transaction, const std::function<void(CatalogEntry &)> &callback,
-                                const string &prefix) {
-	// lock the catalog set
-	unique_lock<mutex> lock(catalog_lock);
-	CreateDefaultEntries(transaction, lock);
-
-	auto &entries = map.Entries();
-	auto it = entries.lower_bound(prefix);
-	auto end = entries.upper_bound(prefix + char(255));
-	for (; it != end; it++) {
-		auto &entry = *it->second;
-		auto &entry_for_transaction = GetEntryForTransaction(transaction, entry);
-		if (!entry_for_transaction.deleted) {
-			callback(entry_for_transaction);
-		}
-	}
-}
-
-void CatalogSet::Scan(const std::function<void(CatalogEntry &)> &callback) {
-	// lock the catalog set
-	lock_guard<mutex> lock(catalog_lock);
-	for (auto &kv : map.Entries()) {
-		auto &entry = *kv.second;
-		auto &committed_entry = GetCommittedEntry(entry);
-		if (!committed_entry.deleted) {
-			callback(committed_entry);
-		}
-	}
-}
-
-void CatalogSet::Verify(Catalog &catalog_p) {
-	D_ASSERT(&catalog_p == &catalog);
-	vector<reference<CatalogEntry>> entries;
-	Scan([&](CatalogEntry &entry) { entries.push_back(entry); });
-	for (auto &entry : entries) {
-		entry.get().Verify(catalog_p);
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/catalog/catalog_transaction.cpp b/src/duckdb/src/catalog/catalog_transaction.cpp
deleted file mode 100644
index fbe100d3a..000000000
--- a/src/duckdb/src/catalog/catalog_transaction.cpp
+++ /dev/null
@@ -1,42 +0,0 @@
-#include "duckdb/catalog/catalog_transaction.hpp"
-#include "duckdb/catalog/catalog.hpp"
-#include "duckdb/transaction/duck_transaction.hpp"
-#include "duckdb/main/database.hpp"
-
-namespace duckdb {
-
-CatalogTransaction::CatalogTransaction(Catalog &catalog, ClientContext &context) {
-	auto &transaction = Transaction::Get(context, catalog);
-	this->db = &DatabaseInstance::GetDatabase(context);
-	if (!transaction.IsDuckTransaction()) {
-		this->transaction_id = transaction_t(-1);
-		this->start_time = transaction_t(-1);
-	} else {
-		auto &dtransaction = transaction.Cast<DuckTransaction>();
-		this->transaction_id = dtransaction.transaction_id;
-		this->start_time = dtransaction.start_time;
-	}
-	this->transaction = &transaction;
-	this->context = &context;
-}
-
-CatalogTransaction::CatalogTransaction(DatabaseInstance &db, transaction_t transaction_id_p, transaction_t start_time_p)
-    : db(&db), context(nullptr), transaction(nullptr), transaction_id(transaction_id_p), start_time(start_time_p) {
-}
-
-ClientContext &CatalogTransaction::GetContext() {
-	if (!context) {
-		throw InternalException("Attempting to get a context in a CatalogTransaction without a context");
-	}
-	return *context;
-}
-
-CatalogTransaction 
CatalogTransaction::GetSystemCatalogTransaction(ClientContext &context) { - return CatalogTransaction(Catalog::GetSystemCatalog(context), context); -} - -CatalogTransaction CatalogTransaction::GetSystemTransaction(DatabaseInstance &db) { - return CatalogTransaction(db, 1, 1); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_functions.cpp b/src/duckdb/src/catalog/default/default_functions.cpp deleted file mode 100644 index 7702bc4e0..000000000 --- a/src/duckdb/src/catalog/default/default_functions.cpp +++ /dev/null @@ -1,267 +0,0 @@ -#include "duckdb/catalog/default/default_functions.hpp" -#include "duckdb/parser/parser.hpp" -#include "duckdb/parser/parsed_data/create_macro_info.hpp" -#include "duckdb/parser/expression/columnref_expression.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/function/table_macro_function.hpp" - -#include "duckdb/function/scalar_macro_function.hpp" - -namespace duckdb { - -static const DefaultMacro internal_macros[] = { - {DEFAULT_SCHEMA, "current_role", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_user", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // user name of current execution context - {DEFAULT_SCHEMA, "current_catalog", {nullptr}, {{nullptr, nullptr}}, "main.current_database()"}, // name of current database (called "catalog" in the SQL standard) - {DEFAULT_SCHEMA, "user", {nullptr}, {{nullptr, nullptr}}, "current_user"}, // equivalent to current_user - {DEFAULT_SCHEMA, "session_user", {nullptr}, {{nullptr, nullptr}}, "'duckdb'"}, // session user name - {"pg_catalog", "inet_client_addr", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // address of the remote connection - {"pg_catalog", "inet_client_port", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // port of the remote connection - {"pg_catalog", "inet_server_addr", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // address of the local connection - {"pg_catalog", "inet_server_port", {nullptr}, {{nullptr, nullptr}}, "NULL"}, // port of the local connection - {"pg_catalog", "pg_my_temp_schema", {nullptr}, {{nullptr, nullptr}}, "0"}, // OID of session's temporary schema, or 0 if none - {"pg_catalog", "pg_is_other_temp_schema", {"schema_id", nullptr}, {{nullptr, nullptr}}, "false"}, // is schema another session's temporary schema? 
- - {"pg_catalog", "pg_conf_load_time", {nullptr}, {{nullptr, nullptr}}, "current_timestamp"}, // configuration load time - {"pg_catalog", "pg_postmaster_start_time", {nullptr}, {{nullptr, nullptr}}, "current_timestamp"}, // server start time - - {"pg_catalog", "pg_typeof", {"expression", nullptr}, {{nullptr, nullptr}}, "lower(typeof(expression))"}, // get the data type of any value - - {"pg_catalog", "current_database", {nullptr}, {{nullptr, nullptr}}, "system.main.current_database()"}, // name of current database (called "catalog" in the SQL standard) - {"pg_catalog", "current_query", {nullptr}, {{nullptr, nullptr}}, "system.main.current_query()"}, // the currently executing query (NULL if not inside a plpgsql function) - {"pg_catalog", "current_schema", {nullptr}, {{nullptr, nullptr}}, "system.main.current_schema()"}, // name of current schema - {"pg_catalog", "current_schemas", {"include_implicit"}, {{nullptr, nullptr}}, "system.main.current_schemas(include_implicit)"}, // names of schemas in search path - - // privilege functions - {"pg_catalog", "has_any_column_privilege", {"table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for any column of table - {"pg_catalog", "has_any_column_privilege", {"user", "table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for any column of table - {"pg_catalog", "has_column_privilege", {"table", "column", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for column - {"pg_catalog", "has_column_privilege", {"user", "table", "column", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for column - {"pg_catalog", "has_database_privilege", {"database", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for database - {"pg_catalog", "has_database_privilege", {"user", "database", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for database - {"pg_catalog", "has_foreign_data_wrapper_privilege", {"fdw", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for foreign-data wrapper - {"pg_catalog", "has_foreign_data_wrapper_privilege", {"user", "fdw", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for foreign-data wrapper - {"pg_catalog", "has_function_privilege", {"function", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for function - {"pg_catalog", "has_function_privilege", {"user", "function", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for function - {"pg_catalog", "has_language_privilege", {"language", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for language - {"pg_catalog", "has_language_privilege", {"user", "language", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for language - {"pg_catalog", "has_schema_privilege", {"schema", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for schema - {"pg_catalog", "has_schema_privilege", {"user", "schema", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for schema - {"pg_catalog", "has_sequence_privilege", {"sequence", "privilege", nullptr}, {{nullptr, 
nullptr}}, "true"}, //boolean //does current user have privilege for sequence - {"pg_catalog", "has_sequence_privilege", {"user", "sequence", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for sequence - {"pg_catalog", "has_server_privilege", {"server", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for foreign server - {"pg_catalog", "has_server_privilege", {"user", "server", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for foreign server - {"pg_catalog", "has_table_privilege", {"table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for table - {"pg_catalog", "has_table_privilege", {"user", "table", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for table - {"pg_catalog", "has_tablespace_privilege", {"tablespace", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for tablespace - {"pg_catalog", "has_tablespace_privilege", {"user", "tablespace", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for tablespace - - // various postgres system functions - {"pg_catalog", "pg_get_viewdef", {"oid", nullptr}, {{nullptr, nullptr}}, "(select sql from duckdb_views() v where v.view_oid=oid)"}, - {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", nullptr}, {{nullptr, nullptr}}, "(select constraint_text from duckdb_constraints() d_constraint where d_constraint.table_oid=constraint_oid//1000000 and d_constraint.constraint_index=constraint_oid%1000000)"}, - {"pg_catalog", "pg_get_constraintdef", {"constraint_oid", "pretty_bool", nullptr}, {{nullptr, nullptr}}, "pg_get_constraintdef(constraint_oid)"}, - {"pg_catalog", "pg_get_expr", {"pg_node_tree", "relation_oid", nullptr}, {{nullptr, nullptr}}, "pg_node_tree"}, - {"pg_catalog", "format_pg_type", {"logical_type", "type_name", nullptr}, {{nullptr, nullptr}}, "case upper(logical_type) when 'FLOAT' then 'float4' when 'DOUBLE' then 'float8' when 'DECIMAL' then 'numeric' when 'ENUM' then lower(type_name) when 'VARCHAR' then 'varchar' when 'BLOB' then 'bytea' when 'TIMESTAMP' then 'timestamp' when 'TIME' then 'time' when 'TIMESTAMP WITH TIME ZONE' then 'timestamptz' when 'TIME WITH TIME ZONE' then 'timetz' when 'SMALLINT' then 'int2' when 'INTEGER' then 'int4' when 'BIGINT' then 'int8' when 'BOOLEAN' then 'bool' else lower(logical_type) end"}, - {"pg_catalog", "format_type", {"type_oid", "typemod", nullptr}, {{nullptr, nullptr}}, "(select format_pg_type(logical_type, type_name) from duckdb_types() t where t.type_oid=type_oid) || case when typemod>0 then concat('(', typemod//1000, ',', typemod%1000, ')') else '' end"}, - {"pg_catalog", "map_to_pg_oid", {"type_name", nullptr}, {{nullptr, nullptr}}, "case type_name when 'bool' then 16 when 'int16' then 21 when 'int' then 23 when 'bigint' then 20 when 'date' then 1082 when 'time' then 1083 when 'datetime' then 1114 when 'dec' then 1700 when 'float' then 700 when 'double' then 701 when 'bpchar' then 1043 when 'binary' then 17 when 'interval' then 1186 when 'timestamptz' then 1184 when 'timetz' then 1266 when 'bit' then 1560 when 'guid' then 2950 else null end"}, // map duckdb_oid to pg_oid. 
If no corresponding type, return null - - {"pg_catalog", "pg_has_role", {"user", "role", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does user have privilege for role - {"pg_catalog", "pg_has_role", {"role", "privilege", nullptr}, {{nullptr, nullptr}}, "true"}, //boolean //does current user have privilege for role - - {"pg_catalog", "col_description", {"table_oid", "column_number", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a table column - {"pg_catalog", "obj_description", {"object_oid", "catalog_name", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a database object - {"pg_catalog", "shobj_description", {"object_oid", "catalog_name", nullptr}, {{nullptr, nullptr}}, "NULL"}, // get comment for a shared database object - - // visibility functions - {"pg_catalog", "pg_collation_is_visible", {"collation_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_conversion_is_visible", {"conversion_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_function_is_visible", {"function_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_opclass_is_visible", {"opclass_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_operator_is_visible", {"operator_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_opfamily_is_visible", {"opclass_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_table_is_visible", {"table_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_ts_config_is_visible", {"config_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_ts_dict_is_visible", {"dict_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_ts_parser_is_visible", {"parser_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_ts_template_is_visible", {"template_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - {"pg_catalog", "pg_type_is_visible", {"type_oid", nullptr}, {{nullptr, nullptr}}, "true"}, - - {"pg_catalog", "pg_size_pretty", {"bytes", nullptr}, {{nullptr, nullptr}}, "format_bytes(bytes)"}, - - {DEFAULT_SCHEMA, "round_even", {"x", "n", nullptr}, {{nullptr, nullptr}}, "CASE ((abs(x) * power(10, n+1)) % 10) WHEN 5 THEN round(x/2, n) * 2 ELSE round(x, n) END"}, - {DEFAULT_SCHEMA, "roundbankers", {"x", "n", nullptr}, {{nullptr, nullptr}}, "round_even(x, n)"}, - {DEFAULT_SCHEMA, "nullif", {"a", "b", nullptr}, {{nullptr, nullptr}}, "CASE WHEN a=b THEN NULL ELSE a END"}, - {DEFAULT_SCHEMA, "list_append", {"l", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(l, list_value(e))"}, - {DEFAULT_SCHEMA, "array_append", {"arr", "el", nullptr}, {{nullptr, nullptr}}, "list_append(arr, el)"}, - {DEFAULT_SCHEMA, "list_prepend", {"e", "l", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), l)"}, - {DEFAULT_SCHEMA, "array_prepend", {"el", "arr", nullptr}, {{nullptr, nullptr}}, "list_prepend(el, arr)"}, - {DEFAULT_SCHEMA, "array_pop_back", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[:LEN(arr)-1]"}, - {DEFAULT_SCHEMA, "array_pop_front", {"arr", nullptr}, {{nullptr, nullptr}}, "arr[2:]"}, - {DEFAULT_SCHEMA, "array_push_back", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(arr, list_value(e))"}, - {DEFAULT_SCHEMA, "array_push_front", {"arr", "e", nullptr}, {{nullptr, nullptr}}, "list_concat(list_value(e), arr)"}, - {DEFAULT_SCHEMA, "array_to_string", {"arr", "sep", nullptr}, {{nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, - // Test default parameters - {DEFAULT_SCHEMA, 
"array_to_string_comma_default", {"arr", nullptr}, {{"sep", "','"}, {nullptr, nullptr}}, "list_aggr(arr::varchar[], 'string_agg', sep)"}, - - {DEFAULT_SCHEMA, "generate_subscripts", {"arr", "dim", nullptr}, {{nullptr, nullptr}}, "unnest(generate_series(1, array_length(arr, dim)))"}, - {DEFAULT_SCHEMA, "fdiv", {"x", "y", nullptr}, {{nullptr, nullptr}}, "floor(x/y)"}, - {DEFAULT_SCHEMA, "fmod", {"x", "y", nullptr}, {{nullptr, nullptr}}, "(x-y*floor(x/y))"}, - {DEFAULT_SCHEMA, "split_part", {"string", "delimiter", "position", nullptr}, {{nullptr, nullptr}}, "if(string IS NOT NULL AND delimiter IS NOT NULL AND position IS NOT NULL, coalesce(string_split(string, delimiter)[position],''), NULL)"}, - {DEFAULT_SCHEMA, "geomean", {"x", nullptr}, {{nullptr, nullptr}}, "exp(avg(ln(x)))"}, - {DEFAULT_SCHEMA, "geometric_mean", {"x", nullptr}, {{nullptr, nullptr}}, "geomean(x)"}, - - {DEFAULT_SCHEMA, "weighted_avg", {"value", "weight", nullptr}, {{nullptr, nullptr}}, "SUM(value * weight) / SUM(CASE WHEN value IS NOT NULL THEN weight ELSE 0 END)"}, - {DEFAULT_SCHEMA, "wavg", {"value", "weight", nullptr}, {{nullptr, nullptr}}, "weighted_avg(value, weight)"}, - - {DEFAULT_SCHEMA, "list_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "l[:-:-1]"}, - {DEFAULT_SCHEMA, "array_reverse", {"l", nullptr}, {{nullptr, nullptr}}, "list_reverse(l)"}, - - // FIXME implement as actual function if we encounter a lot of performance issues. Complexity now: n * m, with hashing possibly n + m - {DEFAULT_SCHEMA, "list_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_filter(list_distinct(l1), (variable_intersect) -> list_contains(l2, variable_intersect))"}, - {DEFAULT_SCHEMA, "array_intersect", {"l1", "l2", nullptr}, {{nullptr, nullptr}}, "list_intersect(l1, l2)"}, - - // algebraic list aggregates - {DEFAULT_SCHEMA, "list_avg", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'avg')"}, - {DEFAULT_SCHEMA, "list_var_samp", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'var_samp')"}, - {DEFAULT_SCHEMA, "list_var_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'var_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'stddev_pop')"}, - {DEFAULT_SCHEMA, "list_stddev_samp", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'stddev_samp')"}, - {DEFAULT_SCHEMA, "list_sem", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'sem')"}, - - // distributive list aggregates - {DEFAULT_SCHEMA, "list_approx_count_distinct", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'approx_count_distinct')"}, - {DEFAULT_SCHEMA, "list_bit_xor", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_xor')"}, - {DEFAULT_SCHEMA, "list_bit_or", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_or')"}, - {DEFAULT_SCHEMA, "list_bit_and", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bit_and')"}, - {DEFAULT_SCHEMA, "list_bool_and", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bool_and')"}, - {DEFAULT_SCHEMA, "list_bool_or", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'bool_or')"}, - {DEFAULT_SCHEMA, "list_count", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'count')"}, - {DEFAULT_SCHEMA, "list_entropy", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'entropy')"}, - {DEFAULT_SCHEMA, "list_last", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'last')"}, - {DEFAULT_SCHEMA, "list_first", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'first')"}, - {DEFAULT_SCHEMA, "list_any_value", {"l", nullptr}, {{nullptr, nullptr}}, 
"list_aggr(l, 'any_value')"}, - {DEFAULT_SCHEMA, "list_kurtosis", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'kurtosis')"}, - {DEFAULT_SCHEMA, "list_kurtosis_pop", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'kurtosis_pop')"}, - {DEFAULT_SCHEMA, "list_min", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'min')"}, - {DEFAULT_SCHEMA, "list_max", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'max')"}, - {DEFAULT_SCHEMA, "list_product", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'product')"}, - {DEFAULT_SCHEMA, "list_skewness", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'skewness')"}, - {DEFAULT_SCHEMA, "list_sum", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'sum')"}, - {DEFAULT_SCHEMA, "list_string_agg", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'string_agg')"}, - - // holistic list aggregates - {DEFAULT_SCHEMA, "list_mode", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'mode')"}, - {DEFAULT_SCHEMA, "list_median", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'median')"}, - {DEFAULT_SCHEMA, "list_mad", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'mad')"}, - - // nested list aggregates - {DEFAULT_SCHEMA, "list_histogram", {"l", nullptr}, {{nullptr, nullptr}}, "list_aggr(l, 'histogram')"}, - - // map functions - {DEFAULT_SCHEMA, "map_contains_entry", {"map", "key", "value"}, {{nullptr, nullptr}}, "contains(map_entries(map), {'key': key, 'value': value})"}, - {DEFAULT_SCHEMA, "map_contains_value", {"map", "value", nullptr}, {{nullptr, nullptr}}, "contains(map_values(map), value)"}, - - // date functions - {DEFAULT_SCHEMA, "date_add", {"date", "interval", nullptr}, {{nullptr, nullptr}}, "date + interval"}, - {DEFAULT_SCHEMA, "current_date", {nullptr}, {{nullptr, nullptr}}, "current_timestamp::DATE"}, - {DEFAULT_SCHEMA, "today", {nullptr}, {{nullptr, nullptr}}, "current_timestamp::DATE"}, - {DEFAULT_SCHEMA, "get_current_time", {nullptr}, {{nullptr, nullptr}}, "current_timestamp::TIMETZ"}, - - // regexp functions - {DEFAULT_SCHEMA, "regexp_split_to_table", {"text", "pattern", nullptr}, {{nullptr, nullptr}}, "unnest(string_split_regex(text, pattern))"}, - - // storage helper functions - {DEFAULT_SCHEMA, "get_block_size", {"db_name"}, {{nullptr, nullptr}}, "(SELECT block_size FROM pragma_database_size() WHERE database_name = db_name)"}, - - // string functions - {DEFAULT_SCHEMA, "md5_number_upper", {"param"}, {{nullptr, nullptr}}, "((md5_number(param)::bit::varchar)[65:])::bit::uint64"}, - {DEFAULT_SCHEMA, "md5_number_lower", {"param"}, {{nullptr, nullptr}}, "((md5_number(param)::bit::varchar)[:64])::bit::uint64"}, - - {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} - }; - -unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(const DefaultMacro &default_macro) { - return CreateInternalMacroInfo(array_ptr(default_macro)); -} - - -unique_ptr DefaultFunctionGenerator::CreateInternalMacroInfo(array_ptr macros) { - auto type = CatalogType::MACRO_ENTRY; - auto bind_info = make_uniq(type); - for(auto &default_macro : macros) { - // parse the expression - auto expressions = Parser::ParseExpressionList(default_macro.macro); - D_ASSERT(expressions.size() == 1); - - auto function = make_uniq(std::move(expressions[0])); - for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { - function->parameters.push_back( - make_uniq(default_macro.parameters[param_idx])); - } - for (idx_t named_idx = 0; default_macro.named_parameters[named_idx].name != nullptr; named_idx++) { - auto 
expr_list = Parser::ParseExpressionList(default_macro.named_parameters[named_idx].default_value);
-			if (expr_list.size() != 1) {
-				throw InternalException("Expected a single expression");
-			}
-			function->default_parameters.insert(
-			    make_pair(default_macro.named_parameters[named_idx].name, std::move(expr_list[0])));
-		}
-		D_ASSERT(function->type == MacroType::SCALAR_MACRO);
-		bind_info->macros.push_back(std::move(function));
-	}
-	bind_info->schema = macros[0].schema;
-	bind_info->name = macros[0].name;
-	bind_info->temporary = true;
-	bind_info->internal = true;
-	return bind_info;
-}
-
-static bool DefaultFunctionMatches(const DefaultMacro &macro, const string &schema, const string &name) {
-	return macro.schema == schema && macro.name == name;
-}
-
-static unique_ptr<CreateMacroInfo> GetDefaultFunction(const string &input_schema, const string &input_name) {
-	auto schema = StringUtil::Lower(input_schema);
-	auto name = StringUtil::Lower(input_name);
-	for (idx_t index = 0; internal_macros[index].name != nullptr; index++) {
-		if (DefaultFunctionMatches(internal_macros[index], schema, name)) {
-			// found the function! keep on iterating to find all overloads
-			idx_t overload_count;
-			for (overload_count = 1; internal_macros[index + overload_count].name; overload_count++) {
-				if (!DefaultFunctionMatches(internal_macros[index + overload_count], schema, name)) {
-					break;
-				}
-			}
-			return DefaultFunctionGenerator::CreateInternalMacroInfo(
-			    array_ptr<const DefaultMacro>(internal_macros + index, overload_count));
-		}
-	}
-	return nullptr;
-}
-
-DefaultFunctionGenerator::DefaultFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema)
-    : DefaultGenerator(catalog), schema(schema) {
-}
-
-unique_ptr<CatalogEntry> DefaultFunctionGenerator::CreateDefaultEntry(ClientContext &context,
-                                                                      const string &entry_name) {
-	auto info = GetDefaultFunction(schema.name, entry_name);
-	if (info) {
-		return make_uniq_base<CatalogEntry, ScalarMacroCatalogEntry>(catalog, schema, info->Cast<CreateMacroInfo>());
-	}
-	return nullptr;
-}
-
-vector<string> DefaultFunctionGenerator::GetDefaultEntries() {
-	vector<string> result;
-	for (idx_t index = 0; internal_macros[index].name != nullptr; index++) {
-		if (StringUtil::Lower(internal_macros[index].name) != internal_macros[index].name) {
-			throw InternalException("Default macro name %s should be lowercase", internal_macros[index].name);
-		}
-		if (internal_macros[index].schema == schema.name) {
-			result.emplace_back(internal_macros[index].name);
-		}
-	}
-	return result;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/catalog/default/default_generator.cpp b/src/duckdb/src/catalog/default/default_generator.cpp
deleted file mode 100644
index 2fbb2b646..000000000
--- a/src/duckdb/src/catalog/default/default_generator.cpp
+++ /dev/null
@@ -1,24 +0,0 @@
-#include "duckdb/catalog/default/default_generator.hpp"
-#include "duckdb/catalog/catalog_transaction.hpp"
-
-namespace duckdb {
-
-DefaultGenerator::DefaultGenerator(Catalog &catalog) : catalog(catalog), created_all_entries(false) {
-}
-DefaultGenerator::~DefaultGenerator() {
-}
-
-unique_ptr<CatalogEntry> DefaultGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) {
-	throw InternalException("CreateDefaultEntry with ClientContext called but not supported in this generator");
-}
-
-unique_ptr<CatalogEntry> DefaultGenerator::CreateDefaultEntry(CatalogTransaction transaction,
-                                                              const string &entry_name) {
-	if (!transaction.context) {
-		// no context - cannot create default entry
-		return nullptr;
-	}
-	return CreateDefaultEntry(*transaction.context, entry_name);
-}
-
-} // namespace duckdb
diff --git 
a/src/duckdb/src/catalog/default/default_schemas.cpp b/src/duckdb/src/catalog/default/default_schemas.cpp deleted file mode 100644 index 64aaf56d2..000000000 --- a/src/duckdb/src/catalog/default/default_schemas.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "duckdb/catalog/default/default_schemas.hpp" -#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -struct DefaultSchema { - const char *name; -}; - -static const DefaultSchema internal_schemas[] = {{"information_schema"}, {"pg_catalog"}, {nullptr}}; - -bool DefaultSchemaGenerator::IsDefaultSchema(const string &input_schema) { - auto schema = StringUtil::Lower(input_schema); - for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { - if (internal_schemas[index].name == schema) { - return true; - } - } - return false; -} - -DefaultSchemaGenerator::DefaultSchemaGenerator(Catalog &catalog) : DefaultGenerator(catalog) { -} - -unique_ptr DefaultSchemaGenerator::CreateDefaultEntry(CatalogTransaction transaction, - const string &entry_name) { - if (IsDefaultSchema(entry_name)) { - CreateSchemaInfo info; - info.schema = StringUtil::Lower(entry_name); - info.internal = true; - return make_uniq_base(catalog, info); - } - return nullptr; -} - -vector DefaultSchemaGenerator::GetDefaultEntries() { - vector result; - for (idx_t index = 0; internal_schemas[index].name != nullptr; index++) { - result.emplace_back(internal_schemas[index].name); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_table_functions.cpp b/src/duckdb/src/catalog/default/default_table_functions.cpp deleted file mode 100644 index b0755c834..000000000 --- a/src/duckdb/src/catalog/default/default_table_functions.cpp +++ /dev/null @@ -1,148 +0,0 @@ -#include "duckdb/catalog/default/default_table_functions.hpp" -#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" -#include "duckdb/parser/parser.hpp" -#include "duckdb/parser/parsed_data/create_macro_info.hpp" -#include "duckdb/parser/statement/select_statement.hpp" -#include "duckdb/function/table_macro_function.hpp" - -namespace duckdb { - -// clang-format off -static const DefaultTableMacro internal_table_macros[] = { - {DEFAULT_SCHEMA, "histogram_values", {"source", "col_name", nullptr}, {{"bin_count", "10"}, {"technique", "'auto'"}, {nullptr, nullptr}}, R"( -WITH bins AS ( - SELECT - CASE - WHEN (NOT (can_cast_implicitly(MIN(col_name), NULL::BIGINT) OR - can_cast_implicitly(MIN(col_name), NULL::DOUBLE) OR - can_cast_implicitly(MIN(col_name), NULL::TIMESTAMP)) AND technique='auto') - OR technique='sample' - THEN - approx_top_k(col_name, bin_count) - WHEN technique='equi-height' - THEN - quantile(col_name, [x / bin_count::DOUBLE for x in generate_series(1, bin_count)]) - WHEN technique='equi-width' - THEN - equi_width_bins(MIN(col_name), MAX(col_name), bin_count, false) - WHEN technique='equi-width-nice' OR technique='auto' - THEN - equi_width_bins(MIN(col_name), MAX(col_name), bin_count, true) - ELSE - error(concat('Unrecognized technique ', technique)) - END AS bins - FROM query_table(source::VARCHAR) - ) -SELECT UNNEST(map_keys(histogram)) AS bin, UNNEST(map_values(histogram)) AS count -FROM ( - SELECT CASE - WHEN (NOT (can_cast_implicitly(MIN(col_name), NULL::BIGINT) OR - can_cast_implicitly(MIN(col_name), NULL::DOUBLE) OR - can_cast_implicitly(MIN(col_name), NULL::TIMESTAMP)) AND technique='auto') - OR 
technique='sample' - THEN - histogram_exact(col_name, bins) - ELSE - histogram(col_name, bins) - END AS histogram - FROM query_table(source::VARCHAR), bins -); -)"}, - {DEFAULT_SCHEMA, "histogram", {"source", "col_name", nullptr}, {{"bin_count", "10"}, {"technique", "'auto'"}, {nullptr, nullptr}}, R"( -SELECT - CASE - WHEN is_histogram_other_bin(bin) - THEN '(other values)' - WHEN (NOT (can_cast_implicitly(bin, NULL::BIGINT) OR - can_cast_implicitly(bin, NULL::DOUBLE) OR - can_cast_implicitly(bin, NULL::TIMESTAMP)) AND technique='auto') - OR technique='sample' - THEN bin::VARCHAR - WHEN row_number() over () = 1 - THEN concat('x <= ', bin::VARCHAR) - ELSE concat(lag(bin::VARCHAR) over (), ' < x <= ', bin::VARCHAR) - END AS bin, - count, - bar(count, 0, max(count) over ()) AS bar -FROM histogram_values(source, col_name, bin_count := bin_count, technique := technique); -)"}, - {nullptr, nullptr, {nullptr}, {{nullptr, nullptr}}, nullptr} - }; -// clang-format on - -DefaultTableFunctionGenerator::DefaultTableFunctionGenerator(Catalog &catalog, SchemaCatalogEntry &schema) - : DefaultGenerator(catalog), schema(schema) { -} - -unique_ptr -DefaultTableFunctionGenerator::CreateInternalTableMacroInfo(const DefaultTableMacro &default_macro, - unique_ptr function) { - for (idx_t param_idx = 0; default_macro.parameters[param_idx] != nullptr; param_idx++) { - function->parameters.push_back(make_uniq(default_macro.parameters[param_idx])); - } - for (idx_t named_idx = 0; default_macro.named_parameters[named_idx].name != nullptr; named_idx++) { - auto expr_list = Parser::ParseExpressionList(default_macro.named_parameters[named_idx].default_value); - if (expr_list.size() != 1) { - throw InternalException("Expected a single expression"); - } - function->default_parameters.insert( - make_pair(default_macro.named_parameters[named_idx].name, std::move(expr_list[0]))); - } - - auto type = CatalogType::TABLE_MACRO_ENTRY; - auto bind_info = make_uniq(type); - bind_info->schema = default_macro.schema; - bind_info->name = default_macro.name; - bind_info->temporary = true; - bind_info->internal = true; - bind_info->macros.push_back(std::move(function)); - return bind_info; -} - -unique_ptr -DefaultTableFunctionGenerator::CreateTableMacroInfo(const DefaultTableMacro &default_macro) { - Parser parser; - parser.ParseQuery(default_macro.macro); - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw InternalException("Expected a single select statement in CreateTableMacroInfo internal"); - } - auto node = std::move(parser.statements[0]->Cast().node); - - auto result = make_uniq(std::move(node)); - return CreateInternalTableMacroInfo(default_macro, std::move(result)); -} - -static unique_ptr GetDefaultTableFunction(const string &input_schema, const string &input_name) { - auto schema = StringUtil::Lower(input_schema); - auto name = StringUtil::Lower(input_name); - for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { - if (internal_table_macros[index].schema == schema && internal_table_macros[index].name == name) { - return DefaultTableFunctionGenerator::CreateTableMacroInfo(internal_table_macros[index]); - } - } - return nullptr; -} - -unique_ptr DefaultTableFunctionGenerator::CreateDefaultEntry(ClientContext &context, - const string &entry_name) { - auto info = GetDefaultTableFunction(schema.name, entry_name); - if (info) { - return make_uniq_base(catalog, schema, info->Cast()); - } - return nullptr; -} - -vector 
DefaultTableFunctionGenerator::GetDefaultEntries() { - vector result; - for (idx_t index = 0; internal_table_macros[index].name != nullptr; index++) { - if (StringUtil::Lower(internal_table_macros[index].name) != internal_table_macros[index].name) { - throw InternalException("Default macro name %s should be lowercase", internal_table_macros[index].name); - } - if (internal_table_macros[index].schema == schema.name) { - result.emplace_back(internal_table_macros[index].name); - } - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_types.cpp b/src/duckdb/src/catalog/default/default_types.cpp deleted file mode 100644 index 23edac049..000000000 --- a/src/duckdb/src/catalog/default/default_types.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "duckdb/catalog/default/default_types.hpp" - -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/catalog/default/builtin_types/types.hpp" - -namespace duckdb { - -LogicalTypeId DefaultTypeGenerator::GetDefaultType(const string &name) { - auto &internal_types = BUILTIN_TYPES; - for (auto &type : internal_types) { - if (StringUtil::CIEquals(name, type.name)) { - return type.type; - } - } - return LogicalType::INVALID; -} - -DefaultTypeGenerator::DefaultTypeGenerator(Catalog &catalog, SchemaCatalogEntry &schema) - : DefaultGenerator(catalog), schema(schema) { -} - -unique_ptr DefaultTypeGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) { - if (schema.name != DEFAULT_SCHEMA) { - return nullptr; - } - auto type_id = GetDefaultType(entry_name); - if (type_id == LogicalTypeId::INVALID) { - return nullptr; - } - CreateTypeInfo info; - info.name = entry_name; - info.type = LogicalType(type_id); - info.internal = true; - info.temporary = true; - return make_uniq_base(catalog, schema, info); -} - -vector DefaultTypeGenerator::GetDefaultEntries() { - vector result; - if (schema.name != DEFAULT_SCHEMA) { - return result; - } - auto &internal_types = BUILTIN_TYPES; - for (auto &type : internal_types) { - result.emplace_back(StringUtil::Lower(type.name)); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/default/default_views.cpp b/src/duckdb/src/catalog/default/default_views.cpp deleted file mode 100644 index 718696471..000000000 --- a/src/duckdb/src/catalog/default/default_views.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "duckdb/catalog/default/default_views.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -struct DefaultView { - const char *schema; - const char *name; - const char *sql; -}; - -static const DefaultView internal_views[] = { - {DEFAULT_SCHEMA, "pragma_database_list", "SELECT database_oid AS seq, database_name AS name, path AS file FROM duckdb_databases() WHERE NOT internal ORDER BY 1"}, - {DEFAULT_SCHEMA, "sqlite_master", "select 'table' \"type\", table_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_tables union all select 'view' \"type\", view_name \"name\", view_name \"tbl_name\", 0 rootpage, sql from duckdb_views union all select 'index' \"type\", index_name \"name\", table_name \"tbl_name\", 0 rootpage, sql from duckdb_indexes;"}, - 
{DEFAULT_SCHEMA, "sqlite_schema", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "sqlite_temp_master", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "sqlite_temp_schema", "SELECT * FROM sqlite_master"}, - {DEFAULT_SCHEMA, "duckdb_constraints", "SELECT * FROM duckdb_constraints()"}, - {DEFAULT_SCHEMA, "duckdb_columns", "SELECT * FROM duckdb_columns() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_databases", "SELECT * FROM duckdb_databases() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_indexes", "SELECT * FROM duckdb_indexes()"}, - {DEFAULT_SCHEMA, "duckdb_schemas", "SELECT * FROM duckdb_schemas() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_tables", "SELECT * FROM duckdb_tables() WHERE NOT internal"}, - {DEFAULT_SCHEMA, "duckdb_types", "SELECT * FROM duckdb_types()"}, - {DEFAULT_SCHEMA, "duckdb_views", "SELECT * FROM duckdb_views() WHERE NOT internal"}, - {"pg_catalog", "pg_am", "SELECT 0 oid, 'art' amname, NULL amhandler, 'i' amtype"}, - {"pg_catalog", "pg_attribute", "SELECT table_oid attrelid, column_name attname, data_type_id atttypid, 0 attstattarget, NULL attlen, column_index attnum, 0 attndims, -1 attcacheoff, case when data_type ilike '%decimal%' then numeric_precision*1000+numeric_scale else -1 end atttypmod, false attbyval, NULL attstorage, NULL attalign, NOT is_nullable attnotnull, column_default IS NOT NULL atthasdef, false atthasmissing, '' attidentity, '' attgenerated, false attisdropped, true attislocal, 0 attinhcount, 0 attcollation, NULL attcompression, NULL attacl, NULL attoptions, NULL attfdwoptions, NULL attmissingval FROM duckdb_columns()"}, - {"pg_catalog", "pg_attrdef", "SELECT column_index oid, table_oid adrelid, column_index adnum, column_default adbin from duckdb_columns() where column_default is not null;"}, - {"pg_catalog", "pg_class", "SELECT table_oid oid, table_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, estimated_size::real reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, index_count > 0 relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'r' relkind, column_count relnatts, check_constraint_count relchecks, false relhasoids, has_primary_key relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_tables() UNION ALL SELECT view_oid oid, view_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'v' relkind, column_count relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_views() UNION ALL SELECT sequence_oid oid, sequence_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, case when temporary then 't' else 'p' end relpersistence, 'S' relkind, 
0 relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_sequences() UNION ALL SELECT index_oid oid, index_name relname, schema_oid relnamespace, 0 reltype, 0 reloftype, 0 relowner, 0 relam, 0 relfilenode, 0 reltablespace, 0 relpages, 0 reltuples, 0 relallvisible, 0 reltoastrelid, 0 reltoastidxid, false relhasindex, false relisshared, 't' relpersistence, 'i' relkind, NULL relnatts, 0 relchecks, false relhasoids, false relhaspkey, false relhasrules, false relhastriggers, false relhassubclass, false relrowsecurity, true relispopulated, NULL relreplident, false relispartition, 0 relrewrite, 0 relfrozenxid, NULL relminmxid, NULL relacl, NULL reloptions, NULL relpartbound FROM duckdb_indexes()"}, - {"pg_catalog", "pg_constraint", "SELECT table_oid*1000000+constraint_index oid, constraint_text conname, schema_oid connamespace, CASE constraint_type WHEN 'CHECK' then 'c' WHEN 'UNIQUE' then 'u' WHEN 'PRIMARY KEY' THEN 'p' WHEN 'FOREIGN KEY' THEN 'f' ELSE 'x' END contype, false condeferrable, false condeferred, true convalidated, table_oid conrelid, 0 contypid, 0 conindid, 0 conparentid, 0 confrelid, NULL confupdtype, NULL confdeltype, NULL confmatchtype, true conislocal, 0 coninhcount, false connoinherit, constraint_column_indexes conkey, NULL confkey, NULL conpfeqop, NULL conppeqop, NULL conffeqop, NULL conexclop, expression conbin FROM duckdb_constraints()"}, - {"pg_catalog", "pg_database", "SELECT database_oid oid, database_name datname FROM duckdb_databases()"}, - {"pg_catalog", "pg_depend", "SELECT * FROM duckdb_dependencies()"}, - {"pg_catalog", "pg_description", "SELECT table_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_tables() WHERE NOT internal UNION ALL SELECT table_oid AS objoid, database_oid AS classoid, column_index AS objsubid, comment AS description FROM duckdb_columns() WHERE NOT internal UNION ALL SELECT view_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_views() WHERE NOT internal UNION ALL SELECT index_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_indexes UNION ALL SELECT sequence_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_sequences() UNION ALL SELECT type_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_types() WHERE NOT internal UNION ALL SELECT function_oid AS objoid, database_oid AS classoid, 0 AS objsubid, comment AS description FROM duckdb_functions() WHERE NOT internal;"}, - {"pg_catalog", "pg_enum", "SELECT NULL oid, a.type_oid enumtypid, list_position(b.labels, a.elabel) enumsortorder, a.elabel enumlabel FROM (SELECT UNNEST(labels) elabel, type_oid FROM duckdb_types() WHERE logical_type='ENUM') a JOIN duckdb_types() b ON a.type_oid=b.type_oid;"}, - {"pg_catalog", "pg_index", "SELECT index_oid indexrelid, table_oid indrelid, 0 indnatts, 0 indnkeyatts, is_unique indisunique, is_primary indisprimary, false indisexclusion, true indimmediate, false indisclustered, true indisvalid, false indcheckxmin, true indisready, true indislive, false indisreplident, NULL::INT[] indkey, NULL::OID[] indcollation, NULL::OID[] indclass, NULL::INT[] indoption, expressions indexprs, NULL indpred FROM 
duckdb_indexes()"}, - {"pg_catalog", "pg_indexes", "SELECT schema_name schemaname, table_name tablename, index_name indexname, NULL \"tablespace\", sql indexdef FROM duckdb_indexes()"}, - {"pg_catalog", "pg_namespace", "SELECT oid, schema_name nspname, 0 nspowner, NULL nspacl FROM duckdb_schemas()"}, - {"pg_catalog", "pg_proc", "SELECT f.function_oid oid, function_name proname, s.oid pronamespace, NULL proowner, NULL prolang, 0 procost, 0 prorows, varargs provariadic, 0 prosupport, CASE function_type WHEN 'aggregate' THEN 'a' ELSE 'f' END prokind, false prosecdef, false proleakproof, false proisstrict, function_type = 'table' proretset, case (stability) when 'CONSISTENT' then 'i' when 'CONSISTENT_WITHIN_QUERY' then 's' when 'VOLATILE' then 'v' end provolatile, 'u' proparallel, length(parameters) pronargs, 0 pronargdefaults, return_type prorettype, parameter_types proargtypes, NULL proallargtypes, NULL proargmodes, parameters proargnames, NULL proargdefaults, NULL protrftypes, NULL prosrc, NULL probin, macro_definition prosqlbody, NULL proconfig, NULL proacl, function_type = 'aggregate' proisagg, FROM duckdb_functions() f LEFT JOIN duckdb_schemas() s USING (database_name, schema_name)"}, - {"pg_catalog", "pg_sequence", "SELECT sequence_oid seqrelid, 0 seqtypid, start_value seqstart, increment_by seqincrement, max_value seqmax, min_value seqmin, 0 seqcache, cycle seqcycle FROM duckdb_sequences()"}, - {"pg_catalog", "pg_sequences", "SELECT schema_name schemaname, sequence_name sequencename, 'duckdb' sequenceowner, 0 data_type, start_value, min_value, max_value, increment_by, cycle, 0 cache_size, last_value FROM duckdb_sequences()"}, - {"pg_catalog", "pg_settings", "SELECT name, value setting, description short_desc, CASE WHEN input_type = 'VARCHAR' THEN 'string' WHEN input_type = 'BOOLEAN' THEN 'bool' WHEN input_type IN ('BIGINT', 'UBIGINT') THEN 'integer' ELSE input_type END vartype FROM duckdb_settings()"}, - {"pg_catalog", "pg_tables", "SELECT schema_name schemaname, table_name tablename, 'duckdb' tableowner, NULL \"tablespace\", index_count > 0 hasindexes, false hasrules, false hastriggers FROM duckdb_tables()"}, - {"pg_catalog", "pg_tablespace", "SELECT 0 oid, 'pg_default' spcname, 0 spcowner, NULL spcacl, NULL spcoptions"}, - {"pg_catalog", "pg_type", "SELECT CASE WHEN type_oid IS NULL THEN NULL WHEN logical_type = 'ENUM' AND type_name <> 'enum' THEN type_oid ELSE map_to_pg_oid(type_name) END oid, format_pg_type(logical_type, type_name) typname, schema_oid typnamespace, 0 typowner, type_size typlen, false typbyval, CASE WHEN logical_type='ENUM' THEN 'e' else 'b' end typtype, CASE WHEN type_category='NUMERIC' THEN 'N' WHEN type_category='STRING' THEN 'S' WHEN type_category='DATETIME' THEN 'D' WHEN type_category='BOOLEAN' THEN 'B' WHEN type_category='COMPOSITE' THEN 'C' WHEN type_category='USER' THEN 'U' ELSE 'X' END typcategory, false typispreferred, true typisdefined, NULL typdelim, NULL typrelid, NULL typsubscript, NULL typelem, NULL typarray, NULL typinput, NULL typoutput, NULL typreceive, NULL typsend, NULL typmodin, NULL typmodout, NULL typanalyze, 'd' typalign, 'p' typstorage, NULL typnotnull, NULL typbasetype, NULL typtypmod, NULL typndims, NULL typcollation, NULL typdefaultbin, NULL typdefault, NULL typacl FROM duckdb_types() WHERE type_oid IS NOT NULL;"}, - {"pg_catalog", "pg_views", "SELECT schema_name schemaname, view_name viewname, 'duckdb' viewowner, sql definition FROM duckdb_views()"}, - {"information_schema", "columns", "SELECT database_name table_catalog, schema_name 
- {"information_schema", "columns", "SELECT database_name table_catalog, schema_name table_schema, table_name, column_name, column_index ordinal_position, column_default, CASE WHEN is_nullable THEN 'YES' ELSE 'NO' END is_nullable, data_type, character_maximum_length, NULL::INT character_octet_length, numeric_precision, numeric_precision_radix, numeric_scale, NULL::INT datetime_precision, NULL::VARCHAR interval_type, NULL::INT interval_precision, NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, NULL::VARCHAR character_set_name, NULL::VARCHAR collation_catalog, NULL::VARCHAR collation_schema, NULL::VARCHAR collation_name, NULL::VARCHAR domain_catalog, NULL::VARCHAR domain_schema, NULL::VARCHAR domain_name, NULL::VARCHAR udt_catalog, NULL::VARCHAR udt_schema, NULL::VARCHAR udt_name, NULL::VARCHAR scope_catalog, NULL::VARCHAR scope_schema, NULL::VARCHAR scope_name, NULL::BIGINT maximum_cardinality, NULL::VARCHAR dtd_identifier, NULL::BOOL is_self_referencing, NULL::BOOL is_identity, NULL::VARCHAR identity_generation, NULL::VARCHAR identity_start, NULL::VARCHAR identity_increment, NULL::VARCHAR identity_maximum, NULL::VARCHAR identity_minimum, NULL::BOOL identity_cycle, NULL::VARCHAR is_generated, NULL::VARCHAR generation_expression, NULL::BOOL is_updatable, comment AS COLUMN_COMMENT FROM duckdb_columns;"},
- {"information_schema", "schemata", "SELECT database_name catalog_name, schema_name, 'duckdb' schema_owner, NULL::VARCHAR default_character_set_catalog, NULL::VARCHAR default_character_set_schema, NULL::VARCHAR default_character_set_name, sql sql_path FROM duckdb_schemas()"},
- {"information_schema", "tables", "SELECT database_name table_catalog, schema_name table_schema, table_name, CASE WHEN temporary THEN 'LOCAL TEMPORARY' ELSE 'BASE TABLE' END table_type, NULL::VARCHAR self_referencing_column_name, NULL::VARCHAR reference_generation, NULL::VARCHAR user_defined_type_catalog, NULL::VARCHAR user_defined_type_schema, NULL::VARCHAR user_defined_type_name, 'YES' is_insertable_into, 'NO' is_typed, CASE WHEN temporary THEN 'PRESERVE' ELSE NULL END commit_action, comment AS TABLE_COMMENT FROM duckdb_tables() UNION ALL SELECT database_name table_catalog, schema_name table_schema, view_name table_name, 'VIEW' table_type, NULL self_referencing_column_name, NULL reference_generation, NULL user_defined_type_catalog, NULL user_defined_type_schema, NULL user_defined_type_name, 'NO' is_insertable_into, 'NO' is_typed, NULL commit_action, comment AS TABLE_COMMENT FROM duckdb_views;"},
- {"information_schema", "character_sets", "SELECT NULL::VARCHAR character_set_catalog, NULL::VARCHAR character_set_schema, 'UTF8' character_set_name, 'UCS' character_repertoire, 'UTF8' form_of_use, current_database() default_collate_catalog, 'pg_catalog' default_collate_schema, 'ucs_basic' default_collate_name;"},
- {"information_schema", "referential_constraints", "SELECT f.database_name constraint_catalog, f.schema_name constraint_schema, f.constraint_name constraint_name, c.database_name unique_constraint_catalog, c.schema_name unique_constraint_schema, c.constraint_name unique_constraint_name, 'NONE' match_option, 'NO ACTION' update_rule, 'NO ACTION' delete_rule FROM duckdb_constraints() c, duckdb_constraints() f WHERE f.constraint_type = 'FOREIGN KEY' AND (c.constraint_type = 'UNIQUE' OR c.constraint_type = 'PRIMARY KEY') AND f.database_oid = c.database_oid AND f.schema_oid = c.schema_oid AND lower(f.referenced_table) = lower(c.table_name) AND [lower(x) for x in f.referenced_column_names] = [lower(x) for x in c.constraint_column_names]"},
- {"information_schema", "key_column_usage", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, UNNEST(constraint_column_names) column_name, UNNEST(generate_series(1, len(constraint_column_names))) ordinal_position, CASE constraint_type WHEN 'FOREIGN KEY' THEN 1 ELSE NULL END position_in_unique_constraint FROM duckdb_constraints() WHERE constraint_type = 'FOREIGN KEY' OR constraint_type = 'PRIMARY KEY' OR constraint_type = 'UNIQUE';"},
- {"information_schema", "table_constraints", "SELECT database_name constraint_catalog, schema_name constraint_schema, constraint_name, database_name table_catalog, schema_name table_schema, table_name, CASE constraint_type WHEN 'NOT NULL' THEN 'CHECK' ELSE constraint_type END constraint_type, 'NO' is_deferrable, 'NO' initially_deferred, 'YES' enforced, 'YES' nulls_distinct FROM duckdb_constraints() WHERE constraint_type = 'PRIMARY KEY' OR constraint_type = 'FOREIGN KEY' OR constraint_type = 'UNIQUE' OR constraint_type = 'CHECK' OR constraint_type = 'NOT NULL';"},
- {"information_schema", "constraint_column_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, column_name, database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type, constraint_text FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type NOT IN ('NOT NULL') );"},
- {"information_schema", "constraint_table_usage", "SELECT database_name AS table_catalog, schema_name AS table_schema, table_name, database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, constraint_type FROM duckdb_constraints() WHERE constraint_type NOT IN ('NOT NULL');"},
- {"information_schema", "check_constraints", "SELECT database_name AS constraint_catalog, schema_name AS constraint_schema, constraint_name, CASE constraint_type WHEN 'NOT NULL' THEN column_name || ' IS NOT NULL' ELSE constraint_text END AS check_clause FROM (SELECT dc.*, UNNEST(dc.constraint_column_names) AS column_name FROM duckdb_constraints() AS dc WHERE constraint_type IN ('CHECK', 'NOT NULL'));"},
- {"information_schema", "views", "SELECT database_name AS table_catalog, schema_name AS table_schema, view_name AS table_name, sql AS view_definition, 'NONE' AS check_option, 'NO' AS is_updatable, 'NO' AS is_insertable_into, 'NO' AS is_trigger_updatable, 'NO' AS is_trigger_deletable, 'NO' AS is_trigger_insertable_into FROM duckdb_views();"},
- {nullptr, nullptr, nullptr}};
-
-static unique_ptr<CreateViewInfo> GetDefaultView(ClientContext &context, const string &input_schema, const string &input_name) {
- auto schema = StringUtil::Lower(input_schema);
- auto name = StringUtil::Lower(input_name);
- for (idx_t index = 0; internal_views[index].name != nullptr; index++) {
- if (internal_views[index].schema == schema && internal_views[index].name == name) {
- auto result = make_uniq<CreateViewInfo>();
- result->schema = schema;
- result->view_name = name;
- result->sql = internal_views[index].sql;
- result->temporary = true;
- result->internal = true;
-
- return CreateViewInfo::FromSelect(context, std::move(result));
- }
- }
- return nullptr;
-}
-
-DefaultViewGenerator::DefaultViewGenerator(Catalog &catalog, SchemaCatalogEntry &schema)
- : DefaultGenerator(catalog), schema(schema) {
-}
-
-unique_ptr<CatalogEntry> DefaultViewGenerator::CreateDefaultEntry(ClientContext &context, const string &entry_name) {
- auto info = GetDefaultView(context, schema.name, entry_name);
- if (info) {
- return make_uniq_base<CatalogEntry, ViewCatalogEntry>(catalog, schema, *info);
- }
- return nullptr;
-}
-
-vector<string> DefaultViewGenerator::GetDefaultEntries() {
- vector<string> result;
- for (idx_t index = 0; internal_views[index].name != nullptr; index++) {
- if (internal_views[index].schema == schema.name) {
- result.emplace_back(internal_views[index].name);
- }
- }
- return result;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/catalog/dependency_catalog_set.cpp b/src/duckdb/src/catalog/dependency_catalog_set.cpp
deleted file mode 100644
index bfb3862a4..000000000
--- a/src/duckdb/src/catalog/dependency_catalog_set.cpp
+++ /dev/null
@@ -1,50 +0,0 @@
-#include "duckdb/catalog/dependency_catalog_set.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp"
-#include "duckdb/catalog/dependency_list.hpp"
-
-namespace duckdb {
-
-MangledDependencyName DependencyCatalogSet::ApplyPrefix(const MangledEntryName &name) const {
- return MangledDependencyName(mangled_name, name);
-}
-
-bool DependencyCatalogSet::CreateEntry(CatalogTransaction transaction, const MangledEntryName &name,
- unique_ptr<CatalogEntry> value) {
- auto new_name = ApplyPrefix(name);
- const LogicalDependencyList EMPTY_DEPENDENCIES;
- return set.CreateEntry(transaction, new_name.name, std::move(value), EMPTY_DEPENDENCIES);
-}
-
-CatalogSet::EntryLookup DependencyCatalogSet::GetEntryDetailed(CatalogTransaction transaction,
- const MangledEntryName &name) {
- auto new_name = ApplyPrefix(name);
- return set.GetEntryDetailed(transaction, new_name.name);
-}
-
-optional_ptr<CatalogEntry> DependencyCatalogSet::GetEntry(CatalogTransaction transaction,
- const MangledEntryName &name) {
- auto new_name = ApplyPrefix(name);
- return set.GetEntry(transaction, new_name.name);
-}
-
-void DependencyCatalogSet::Scan(CatalogTransaction transaction, const std::function<void(CatalogEntry &)> &callback) {
- set.ScanWithPrefix(
- transaction,
- [&](CatalogEntry &entry) {
- auto &dep = entry.Cast<DependencyEntry>();
- auto &from = dep.SourceMangledName();
- if (!StringUtil::CIEquals(from.name, mangled_name.name)) {
- return;
- }
- callback(entry);
- },
- mangled_name.name);
-}
-
-bool DependencyCatalogSet::DropEntry(CatalogTransaction transaction, const MangledEntryName &name, bool cascade,
- bool allow_drop_internal) {
- auto new_name = ApplyPrefix(name);
- return set.DropEntry(transaction, new_name.name, cascade, allow_drop_internal);
-}
-
-} // namespace duckdb
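Editor's note: DependencyCatalogSet stores the dependency links of every entry in one flat CatalogSet by prefixing each key with the mangled name of the owning entry; Scan then uses ScanWithPrefix plus a case-insensitive re-check of the source name. A worked sketch in comments (example names assumed):

// mangled_name of the set's owner:        "Table\0main\0t1"
// ApplyPrefix("Sequence\0main\0s1")  ->   "Table\0main\0t1" + "\0" + "Sequence\0main\0s1"
// All links belonging to t1 therefore share the "Table\0main\0t1\0" prefix
// and can be enumerated with a single prefix scan over the underlying set.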
&a, const LogicalDependency &b) const { - if (a.entry.type != b.entry.type) { - return false; - } - if (a.entry.name != b.entry.name) { - return false; - } - if (a.entry.schema != b.entry.schema) { - return false; - } - if (a.catalog != b.catalog) { - return false; - } - return true; -} - -LogicalDependency::LogicalDependency() : entry(), catalog() { -} - -static string GetSchema(CatalogEntry &entry) { - if (entry.type == CatalogType::SCHEMA_ENTRY) { - return entry.name; - } - return entry.ParentSchema().name; -} - -LogicalDependency::LogicalDependency(CatalogEntry &entry) { - catalog = INVALID_CATALOG; - if (entry.type == CatalogType::DEPENDENCY_ENTRY) { - auto &dependency_entry = entry.Cast(); - - this->entry = dependency_entry.EntryInfo(); - } else { - this->entry.schema = GetSchema(entry); - this->entry.name = entry.name; - this->entry.type = entry.type; - catalog = entry.ParentCatalog().GetName(); - } -} - -bool LogicalDependency::operator==(const LogicalDependency &other) const { - return other.entry.name == entry.name && other.entry.schema == entry.schema && other.entry.type == entry.type; -} - -void LogicalDependencyList::AddDependency(CatalogEntry &entry) { - LogicalDependency dependency(entry); - set.insert(dependency); -} - -void LogicalDependencyList::AddDependency(const LogicalDependency &entry) { - set.insert(entry); -} - -bool LogicalDependencyList::Contains(CatalogEntry &entry_p) { - LogicalDependency logical_entry(entry_p); - return set.count(logical_entry); -} - -void LogicalDependencyList::VerifyDependencies(Catalog &catalog, const string &name) { - for (auto &dep : set) { - if (dep.catalog != catalog.GetName()) { - throw DependencyException( - "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog " - "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.", - name, dep.entry.name, dep.catalog, catalog.GetName()); - } - } -} - -const LogicalDependencyList::create_info_set_t &LogicalDependencyList::Set() const { - return set; -} - -bool LogicalDependencyList::operator==(const LogicalDependencyList &other) const { - if (set.size() != other.set.size()) { - return false; - } - - for (auto &entry : set) { - if (!other.set.count(entry)) { - return false; - } - } - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp deleted file mode 100644 index 4fcb4d4d1..000000000 --- a/src/duckdb/src/catalog/dependency_manager.cpp +++ /dev/null @@ -1,822 +0,0 @@ -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/catalog/catalog_entry.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/common/enums/catalog_type.hpp" -#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp" -#include "duckdb/catalog/catalog_entry/dependency/dependency_subject_entry.hpp" -#include "duckdb/catalog/catalog_entry/dependency/dependency_dependent_entry.hpp" -#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" -#include "duckdb/common/queue.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/parser/constraints/foreign_key_constraint.hpp" -#include "duckdb/catalog/dependency_catalog_set.hpp" - -namespace duckdb { - 
diff --git a/src/duckdb/src/catalog/dependency_manager.cpp b/src/duckdb/src/catalog/dependency_manager.cpp
deleted file mode 100644
index 4fcb4d4d1..000000000
--- a/src/duckdb/src/catalog/dependency_manager.cpp
+++ /dev/null
@@ -1,822 +0,0 @@
-#include "duckdb/catalog/dependency_manager.hpp"
-#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp"
-#include "duckdb/catalog/duck_catalog.hpp"
-#include "duckdb/catalog/catalog_entry.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/database.hpp"
-#include "duckdb/parser/expression/constant_expression.hpp"
-#include "duckdb/catalog/dependency_list.hpp"
-#include "duckdb/common/enums/catalog_type.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_subject_entry.hpp"
-#include "duckdb/catalog/catalog_entry/dependency/dependency_dependent_entry.hpp"
-#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp"
-#include "duckdb/common/queue.hpp"
-#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
-#include "duckdb/parser/constraints/foreign_key_constraint.hpp"
-#include "duckdb/catalog/dependency_catalog_set.hpp"
-
-namespace duckdb {
-
-static void AssertMangledName(const string &mangled_name, idx_t expected_null_bytes) {
-#ifdef DEBUG
- idx_t nullbyte_count = 0;
- for (auto &ch : mangled_name) {
- nullbyte_count += ch == '\0';
- }
- D_ASSERT(nullbyte_count == expected_null_bytes);
-#endif
-}
-
-MangledEntryName::MangledEntryName(const CatalogEntryInfo &info) {
- auto &type = info.type;
- auto &schema = info.schema;
- auto &name = info.name;
-
- this->name = CatalogTypeToString(type) + '\0' + schema + '\0' + name;
- AssertMangledName(this->name, 2);
-}
-
-MangledDependencyName::MangledDependencyName(const MangledEntryName &from, const MangledEntryName &to) {
- this->name = from.name + '\0' + to.name;
- AssertMangledName(this->name, 5);
-}
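Editor's note: a worked example of the mangling scheme above (entry names assumed, the format follows directly from the two constructors; CatalogTypeToString(TABLE_ENTRY) is assumed to return "Table"):

// MangledEntryName({CatalogType::TABLE_ENTRY, "main", "t1"}) yields
//     "Table\0main\0t1"                                   -> 2 NUL separators
// MangledDependencyName(from, to) joins two such names with one more NUL:
//     "Table\0main\0t1" "\0" "Sequence\0main\0s1"         -> 5 NULs, matching the assert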
-
-DependencyManager::DependencyManager(DuckCatalog &catalog) : catalog(catalog), subjects(catalog), dependents(catalog) {
-}
-
-string DependencyManager::GetSchema(const CatalogEntry &entry) {
- if (entry.type == CatalogType::SCHEMA_ENTRY) {
- return entry.name;
- }
- return entry.ParentSchema().name;
-}
-
-MangledEntryName DependencyManager::MangleName(const CatalogEntryInfo &info) {
- return MangledEntryName(info);
-}
-
-MangledEntryName DependencyManager::MangleName(const CatalogEntry &entry) {
- if (entry.type == CatalogType::DEPENDENCY_ENTRY) {
- auto &dependency_entry = entry.Cast<DependencyEntry>();
- return dependency_entry.EntryMangledName();
- }
- auto type = entry.type;
- auto schema = GetSchema(entry);
- auto name = entry.name;
- CatalogEntryInfo info {type, schema, name};
-
- return MangleName(info);
-}
-
-DependencyInfo DependencyInfo::FromSubject(DependencyEntry &dep) {
- return DependencyInfo {/*dependent = */ dep.Dependent(),
- /*subject = */ dep.Subject()};
-}
-
-DependencyInfo DependencyInfo::FromDependent(DependencyEntry &dep) {
- return DependencyInfo {/*dependent = */ dep.Dependent(),
- /*subject = */ dep.Subject()};
-}
-
-// ----------- DEPENDENCY_MANAGER -----------
-
-bool DependencyManager::IsSystemEntry(CatalogEntry &entry) const {
- if (entry.internal) {
- return true;
- }
-
- switch (entry.type) {
- case CatalogType::DEPENDENCY_ENTRY:
- case CatalogType::DATABASE_ENTRY:
- case CatalogType::RENAMED_ENTRY:
- return true;
- default:
- return false;
- }
-}
-
-CatalogSet &DependencyManager::Dependents() {
- return dependents;
-}
-
-CatalogSet &DependencyManager::Subjects() {
- return subjects;
-}
-
-void DependencyManager::ScanSetInternal(CatalogTransaction transaction, const CatalogEntryInfo &info,
- bool scan_subjects, dependency_callback_t &callback) {
- catalog_entry_set_t other_entries;
-
- auto cb = [&](CatalogEntry &other) {
- D_ASSERT(other.type == CatalogType::DEPENDENCY_ENTRY);
- auto &other_entry = other.Cast<DependencyEntry>();
-#ifdef DEBUG
- auto side = other_entry.Side();
- if (scan_subjects) {
- D_ASSERT(side == DependencyEntryType::SUBJECT);
- } else {
- D_ASSERT(side == DependencyEntryType::DEPENDENT);
- }
-
-#endif
-
- other_entries.insert(other_entry);
- callback(other_entry);
- };
-
- if (scan_subjects) {
- DependencyCatalogSet subjects(Subjects(), info);
- subjects.Scan(transaction, cb);
- } else {
- DependencyCatalogSet dependents(Dependents(), info);
- dependents.Scan(transaction, cb);
- }
-
-#ifdef DEBUG
- // Verify some invariants
- // Every dependency should have a matching dependent in the other set
- // And vice versa
- auto mangled_name = MangleName(info);
-
- if (scan_subjects) {
- for (auto &entry : other_entries) {
- auto other_info = GetLookupProperties(entry);
- DependencyCatalogSet other_dependents(Dependents(), other_info);
-
- // Verify that the other half of the dependency also exists
- auto dependent = other_dependents.GetEntryDetailed(transaction, mangled_name);
- D_ASSERT(dependent.reason != CatalogSet::EntryLookup::FailureReason::NOT_PRESENT);
- }
- } else {
- for (auto &entry : other_entries) {
- auto other_info = GetLookupProperties(entry);
- DependencyCatalogSet other_subjects(Subjects(), other_info);
-
- // Verify that the other half of the dependent also exists
- auto subject = other_subjects.GetEntryDetailed(transaction, mangled_name);
- D_ASSERT(subject.reason != CatalogSet::EntryLookup::FailureReason::NOT_PRESENT);
- }
- }
-#endif
-}
-
-void DependencyManager::ScanDependents(CatalogTransaction transaction, const CatalogEntryInfo &info,
- dependency_callback_t &callback) {
- ScanSetInternal(transaction, info, false, callback);
-}
-
-void DependencyManager::ScanSubjects(CatalogTransaction transaction, const CatalogEntryInfo &info,
- dependency_callback_t &callback) {
- ScanSetInternal(transaction, info, true, callback);
-}
-
-void DependencyManager::RemoveDependency(CatalogTransaction transaction, const DependencyInfo &info) {
- auto &dependent = info.dependent;
- auto &subject = info.subject;
-
- // The dependents of the dependency (target)
- DependencyCatalogSet dependents(Dependents(), subject.entry);
- // The subjects of the dependencies of the dependent
- DependencyCatalogSet subjects(Subjects(), dependent.entry);
-
- auto dependent_mangled = MangledEntryName(dependent.entry);
- auto subject_mangled = MangledEntryName(subject.entry);
-
- auto dependent_p = dependents.GetEntry(transaction, dependent_mangled);
- if (dependent_p) {
- // 'dependent' is no longer inhibiting the deletion of 'dependency'
- dependents.DropEntry(transaction, dependent_mangled, false);
- }
- auto subject_p = subjects.GetEntry(transaction, subject_mangled);
- if (subject_p) {
- // 'dependency' is no longer required by 'dependent'
- subjects.DropEntry(transaction, subject_mangled, false);
- }
-}
-
-void DependencyManager::CreateSubject(CatalogTransaction transaction, const DependencyInfo &info) {
- auto &from = info.dependent.entry;
-
- DependencyCatalogSet set(Subjects(), from);
- auto dep = make_uniq_base<DependencyEntry, DependencySubjectEntry>(catalog, info);
- auto entry_name = dep->EntryMangledName();
-
- //! Add to the list of objects that 'dependent' has a dependency on
- set.CreateEntry(transaction, entry_name, std::move(dep));
-}
-
-void DependencyManager::CreateDependent(CatalogTransaction transaction, const DependencyInfo &info) {
- auto &from = info.subject.entry;
-
- DependencyCatalogSet set(Dependents(), from);
- auto dep = make_uniq_base<DependencyEntry, DependencyDependentEntry>(catalog, info);
- auto entry_name = dep->EntryMangledName();
-
- //! Add to the list of objects that depend on 'subject'
- set.CreateEntry(transaction, entry_name, std::move(dep));
-}
-
-void DependencyManager::CreateDependency(CatalogTransaction transaction, DependencyInfo &info) {
- DependencyCatalogSet subjects(Subjects(), info.dependent.entry);
- DependencyCatalogSet dependents(Dependents(), info.subject.entry);
-
- auto subject_mangled = MangleName(info.subject.entry);
- auto dependent_mangled = MangleName(info.dependent.entry);
-
- auto &dependent_flags = info.dependent.flags;
- auto &subject_flags = info.subject.flags;
-
- auto existing_subject = subjects.GetEntry(transaction, subject_mangled);
- auto existing_dependent = dependents.GetEntry(transaction, dependent_mangled);
-
- // Inherit the existing flags and drop the existing entry if present
- if (existing_subject) {
- auto &existing = existing_subject->Cast<DependencyEntry>();
- auto existing_flags = existing.Subject().flags;
- if (existing_flags != subject_flags) {
- subject_flags.Apply(existing_flags);
- }
- subjects.DropEntry(transaction, subject_mangled, false, false);
- }
- if (existing_dependent) {
- auto &existing = existing_dependent->Cast<DependencyEntry>();
- auto existing_flags = existing.Dependent().flags;
- if (existing_flags != dependent_flags) {
- dependent_flags.Apply(existing_flags);
- }
- dependents.DropEntry(transaction, dependent_mangled, false, false);
- }
-
- // Create an entry in the dependents map of the object that is the target of the dependency
- CreateDependent(transaction, info);
- // Create an entry in the subjects map of the object that is targeting another entry
- CreateSubject(transaction, info);
-}
-
-void DependencyManager::CreateDependencies(CatalogTransaction transaction, const CatalogEntry &object,
- const LogicalDependencyList &dependencies) {
- DependencyDependentFlags dependency_flags;
- if (object.type != CatalogType::INDEX_ENTRY) {
- // indexes do not require CASCADE to be dropped, they are simply always dropped along with the table
- dependency_flags.SetBlocking();
- }
-
- const auto object_info = GetLookupProperties(object);
- // check for each object in the sources if they were not deleted yet
- for (auto &dependency : dependencies.Set()) {
- if (dependency.catalog != object.ParentCatalog().GetName()) {
- throw DependencyException(
- "Error adding dependency for object \"%s\" - dependency \"%s\" is in catalog "
- "\"%s\", which does not match the catalog \"%s\".\nCross catalog dependencies are not supported.",
- object.name, dependency.entry.name, dependency.catalog, object.ParentCatalog().GetName());
- }
- }
-
- // add the object to the dependents_map of each object that it depends on
- for (auto &dependency : dependencies.Set()) {
- DependencyInfo info {/*dependent = */ DependencyDependent {GetLookupProperties(object), dependency_flags},
- /*subject = */ DependencySubject {dependency.entry, DependencySubjectFlags()}};
- CreateDependency(transaction, info);
- }
-}
-
-void DependencyManager::AddObject(CatalogTransaction transaction, CatalogEntry &object,
- const LogicalDependencyList &dependencies) {
- if (IsSystemEntry(object)) {
- // Don't do anything for this
- return;
- }
- CreateDependencies(transaction, object, dependencies);
-}
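Editor's note: one logical dependency always produces two physical entries, kept in sync by CreateDependency above. A comment sketch for a table t1 that depends on a sequence s1 (names and the assumption that the sequence dependency is registered are illustrative):

// Subjects()  : key "Table\0main\0t1" "\0" "Sequence\0main\0s1"  -> "t1 depends on s1"
// Dependents(): key "Sequence\0main\0s1" "\0" "Table\0main\0t1"  -> "s1 is depended on by t1"
// RemoveDependency drops both halves; the DEBUG block in ScanSetInternal
// asserts that the two sets never diverge.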
-
-static bool CascadeDrop(bool cascade, const DependencyDependentFlags &flags) {
- if (cascade) {
- return true;
- }
- if (flags.IsOwnedBy()) {
- // We are owned by this object; while it exists we cannot be dropped without CASCADE.
- return false;
- }
- return !flags.IsBlocking();
-}
-
-CatalogEntryInfo DependencyManager::GetLookupProperties(const CatalogEntry &entry) {
- if (entry.type == CatalogType::DEPENDENCY_ENTRY) {
- auto &dependency_entry = entry.Cast<DependencyEntry>();
- return dependency_entry.EntryInfo();
- } else {
- auto schema = DependencyManager::GetSchema(entry);
- auto &name = entry.name;
- auto &type = entry.type;
- return CatalogEntryInfo {type, schema, name};
- }
-}
-
-optional_ptr<CatalogEntry> DependencyManager::LookupEntry(CatalogTransaction transaction, CatalogEntry &dependency) {
- if (dependency.type != CatalogType::DEPENDENCY_ENTRY) {
- return &dependency;
- }
- auto info = GetLookupProperties(dependency);
-
- auto &type = info.type;
- auto &schema = info.schema;
- auto &name = info.name;
-
- // Lookup the schema
- auto schema_entry = catalog.GetSchema(transaction, schema, OnEntryNotFound::RETURN_NULL);
- if (type == CatalogType::SCHEMA_ENTRY || !schema_entry) {
- // This is a schema entry, perform the callback only providing the schema
- return reinterpret_cast<CatalogEntry *>(schema_entry.get());
- }
- auto entry = schema_entry->GetEntry(transaction, type, name);
- return entry;
-}
-
-void DependencyManager::CleanupDependencies(CatalogTransaction transaction, CatalogEntry &object) {
- // Collect the dependencies
- vector<DependencyInfo> to_remove;
-
- auto info = GetLookupProperties(object);
- ScanSubjects(transaction, info,
- [&](DependencyEntry &dep) { to_remove.push_back(DependencyInfo::FromSubject(dep)); });
- ScanDependents(transaction, info,
- [&](DependencyEntry &dep) { to_remove.push_back(DependencyInfo::FromDependent(dep)); });
-
- // Remove the dependency entries
- for (auto &dep : to_remove) {
- RemoveDependency(transaction, dep);
- }
-}
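Editor's note: the CascadeDrop decision above collapses to a small truth table (derived directly from the code, reproduced here for reference):

// cascade | dependent flags | dropped along with the subject?
// true    | anything        | yes
// false   | owned-by        | no (the owner still exists)
// false   | blocking        | no (surfaces as the DependencyException in CheckDropDependencies)
// false   | non-blocking    | yes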
StringUtil::Format("secret function \"%s\"", info.name); - } - default: - throw InternalException("CatalogType not handled in EntryToString (DependencyManager) for %s", - CatalogTypeToString(type)); - }; -} - -string DependencyManager::CollectDependents(CatalogTransaction transaction, catalog_entry_set_t &entries, - CatalogEntryInfo &info) { - string result; - for (auto &entry : entries) { - D_ASSERT(!IsSystemEntry(entry.get())); - auto other_info = GetLookupProperties(entry); - result += StringUtil::Format("%s depends on %s.\n", EntryToString(other_info), EntryToString(info)); - catalog_entry_set_t entry_dependents; - ScanDependents(transaction, other_info, [&](DependencyEntry &dep) { - auto child = LookupEntry(transaction, dep); - if (!child) { - return; - } - if (!CascadeDrop(false, dep.Dependent().flags)) { - entry_dependents.insert(*child); - } - }); - if (!entry_dependents.empty()) { - result += CollectDependents(transaction, entry_dependents, other_info); - } - } - return result; -} - -void DependencyManager::VerifyExistence(CatalogTransaction transaction, DependencyEntry &object) { - auto &subject = object.Subject(); - - CatalogEntryInfo info; - if (subject.flags.IsOwnership()) { - info = object.SourceInfo(); - } else { - info = object.EntryInfo(); - } - - auto &type = info.type; - auto &schema = info.schema; - auto &name = info.name; - - auto &duck_catalog = catalog.Cast(); - auto &schema_catalog_set = duck_catalog.GetSchemaCatalogSet(); - - CatalogSet::EntryLookup lookup_result; - lookup_result = schema_catalog_set.GetEntryDetailed(transaction, schema); - - if (type != CatalogType::SCHEMA_ENTRY && lookup_result.result) { - auto &schema_entry = lookup_result.result->Cast(); - lookup_result = schema_entry.GetEntryDetailed(transaction, type, name); - } - - if (lookup_result.reason == CatalogSet::EntryLookup::FailureReason::DELETED) { - throw DependencyException("Could not commit creation of dependency, subject \"%s\" has been deleted", - object.SourceInfo().name); - } -} - -void DependencyManager::VerifyCommitDrop(CatalogTransaction transaction, transaction_t start_time, - CatalogEntry &object) { - if (IsSystemEntry(object)) { - return; - } - auto info = GetLookupProperties(object); - ScanDependents(transaction, info, [&](DependencyEntry &dep) { - auto dep_committed_at = dep.timestamp.load(); - if (dep_committed_at > start_time) { - // In the event of a CASCADE, the dependency drop has not committed yet - // so we would be halted by the existence of a dependency we are already dropping unless we check the - // timestamp - // - // Which differentiates between objects that we were already aware of (and will subsequently be dropped) and - // objects that were introduced inbetween, which should cause this error: - throw DependencyException( - "Could not commit DROP of \"%s\" because a dependency was created after the transaction started", - object.name); - } - }); - ScanSubjects(transaction, info, [&](DependencyEntry &dep) { - auto dep_committed_at = dep.timestamp.load(); - if (!dep.Dependent().flags.IsOwnedBy()) { - return; - } - D_ASSERT(dep.Subject().flags.IsOwnership()); - if (dep_committed_at > start_time) { - // Same as above, objects that are owned by the object that is being dropped will be dropped as part of this - // transaction. 
Only objects that were introduced by other transactions, that this transaction could not - // see, should cause this error: - throw DependencyException( - "Could not commit DROP of \"%s\" because a dependency was created after the transaction started", - object.name); - } - }); -} - -catalog_entry_set_t DependencyManager::CheckDropDependencies(CatalogTransaction transaction, CatalogEntry &object, - bool cascade) { - if (IsSystemEntry(object)) { - // Don't do anything for this - return catalog_entry_set_t(); - } - - catalog_entry_set_t to_drop; - catalog_entry_set_t blocking_dependents; - - auto info = GetLookupProperties(object); - // Look through all the objects that depend on the 'object' - ScanDependents(transaction, info, [&](DependencyEntry &dep) { - // It makes no sense to have a schema depend on anything - D_ASSERT(dep.EntryInfo().type != CatalogType::SCHEMA_ENTRY); - auto entry = LookupEntry(transaction, dep); - if (!entry) { - return; - } - - if (!CascadeDrop(cascade, dep.Dependent().flags)) { - // no cascade and there are objects that depend on this object: throw error - blocking_dependents.insert(*entry); - } else { - to_drop.insert(*entry); - } - }); - if (!blocking_dependents.empty()) { - string error_string = - StringUtil::Format("Cannot drop entry \"%s\" because there are entries that depend on it.\n", object.name); - error_string += CollectDependents(transaction, blocking_dependents, info); - error_string += "Use DROP...CASCADE to drop all dependents."; - throw DependencyException(error_string); - } - - // Look through all the entries that 'object' depends on - ScanSubjects(transaction, info, [&](DependencyEntry &dep) { - auto flags = dep.Subject().flags; - if (flags.IsOwnership()) { - // We own this object, it should be dropped along with the table - auto entry = LookupEntry(transaction, dep); - to_drop.insert(*entry); - } - }); - return to_drop; -} - -void DependencyManager::DropObject(CatalogTransaction transaction, CatalogEntry &object, bool cascade) { - if (IsSystemEntry(object)) { - // Don't do anything for this - return; - } - - // Check if there are any entries that block the DROP because they still depend on the object - auto to_drop = CheckDropDependencies(transaction, object, cascade); - CleanupDependencies(transaction, object); - - for (auto &entry : to_drop) { - auto set = entry.get().set; - D_ASSERT(set); - set->DropEntry(transaction, entry.get().name, cascade); - } -} - -void DependencyManager::ReorderEntries(catalog_entry_vector_t &entries, ClientContext &context) { - auto transaction = catalog.GetCatalogTransaction(context); - // Read all the entries visible to this snapshot - ReorderEntries(entries, transaction); -} - -void DependencyManager::ReorderEntries(catalog_entry_vector_t &entries) { - // Read all committed entries - CatalogTransaction transaction(catalog.GetDatabase(), TRANSACTION_ID_START - 1, TRANSACTION_ID_START - 1); - ReorderEntries(entries, transaction); -} - -void DependencyManager::ReorderEntry(CatalogTransaction transaction, CatalogEntry &entry, catalog_entry_set_t &visited, - catalog_entry_vector_t &order) { - auto &catalog_entry = *LookupEntry(transaction, entry); - // We use this in CheckpointManager, it has the highest commit ID, allowing us to read any committed data - bool allow_internal = transaction.start_time == TRANSACTION_ID_START - 1; - if (visited.count(catalog_entry) || (!allow_internal && catalog_entry.internal)) { - // Already seen and ordered appropriately - return; - } - - // Check if there are any entries that this 
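Editor's note: end to end, CheckDropDependencies is what turns a blocked drop into the error text assembled by CollectDependents. A hedged SQL illustration in comments (this assumes type dependencies are tracked for ENUM columns, which the surrounding code suggests but this diff does not show):

// CREATE TYPE mood AS ENUM ('ok', 'meh');
// CREATE TABLE t (m mood);
// DROP TYPE mood;          -- error: Cannot drop entry "mood" because there are entries that depend on it.
// DROP TYPE mood CASCADE;  -- t is collected into to_drop and dropped as well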
-
-void DependencyManager::ReorderEntries(catalog_entry_vector_t &entries, ClientContext &context) {
- auto transaction = catalog.GetCatalogTransaction(context);
- // Read all the entries visible to this snapshot
- ReorderEntries(entries, transaction);
-}
-
-void DependencyManager::ReorderEntries(catalog_entry_vector_t &entries) {
- // Read all committed entries
- CatalogTransaction transaction(catalog.GetDatabase(), TRANSACTION_ID_START - 1, TRANSACTION_ID_START - 1);
- ReorderEntries(entries, transaction);
-}
-
-void DependencyManager::ReorderEntry(CatalogTransaction transaction, CatalogEntry &entry, catalog_entry_set_t &visited,
- catalog_entry_vector_t &order) {
- auto &catalog_entry = *LookupEntry(transaction, entry);
- // We use this in CheckpointManager, it has the highest commit ID, allowing us to read any committed data
- bool allow_internal = transaction.start_time == TRANSACTION_ID_START - 1;
- if (visited.count(catalog_entry) || (!allow_internal && catalog_entry.internal)) {
- // Already seen and ordered appropriately
- return;
- }
-
- // Check if there are any entries that this entry depends on, those are written first
- catalog_entry_vector_t dependents;
- auto info = GetLookupProperties(entry);
- ScanSubjects(transaction, info, [&](DependencyEntry &dep) { dependents.push_back(dep); });
- for (auto &dep : dependents) {
- ReorderEntry(transaction, dep, visited, order);
- }
-
- // Then write the entry
- visited.insert(catalog_entry);
- order.push_back(catalog_entry);
-}
-
-void DependencyManager::ReorderEntries(catalog_entry_vector_t &entries, CatalogTransaction transaction) {
- catalog_entry_vector_t reordered;
- catalog_entry_set_t visited;
- for (auto &entry : entries) {
- ReorderEntry(transaction, entry, visited, reordered);
- }
- // If this would fail, that means there are more entries that we somehow reached through the dependency manager
- // but those entries should not actually be visible to this transaction
- D_ASSERT(entries.size() == reordered.size());
- entries.clear();
- entries = reordered;
-}
-
-void DependencyManager::AlterObject(CatalogTransaction transaction, CatalogEntry &old_obj, CatalogEntry &new_obj,
- AlterInfo &alter_info) {
- if (IsSystemEntry(new_obj)) {
- D_ASSERT(IsSystemEntry(old_obj));
- // Don't do anything for this
- return;
- }
-
- const auto old_info = GetLookupProperties(old_obj);
- const auto new_info = GetLookupProperties(new_obj);
-
- vector<DependencyInfo> dependencies;
- // Other entries that depend on us
- ScanDependents(transaction, old_info, [&](DependencyEntry &dep) {
- // It makes no sense to have a schema depend on anything
- D_ASSERT(dep.EntryInfo().type != CatalogType::SCHEMA_ENTRY);
-
- bool disallow_alter = true;
- switch (alter_info.type) {
- case AlterType::ALTER_TABLE: {
- auto &alter_table = alter_info.Cast<AlterTableInfo>();
- switch (alter_table.alter_table_type) {
- case AlterTableType::FOREIGN_KEY_CONSTRAINT: {
- // These alters are made as part of a CREATE or DROP table statement when a foreign key column is
- // present either adding or removing a reference to the referenced primary key table
- disallow_alter = false;
- break;
- }
- case AlterTableType::ADD_COLUMN: {
- disallow_alter = false;
- break;
- }
- default:
- break;
- }
- break;
- }
- case AlterType::SET_COLUMN_COMMENT:
- case AlterType::SET_COMMENT: {
- disallow_alter = false;
- break;
- }
- default:
- break;
- }
- if (disallow_alter) {
- throw DependencyException("Cannot alter entry \"%s\" because there are entries that "
- "depend on it.",
- old_obj.name);
- }
-
- auto dep_info = DependencyInfo::FromDependent(dep);
- dep_info.subject.entry = new_info;
- dependencies.emplace_back(dep_info);
- });
-
- // Keep old dependencies
- dependency_set_t dependents;
- ScanSubjects(transaction, old_info, [&](DependencyEntry &dep) {
- auto entry = LookupEntry(transaction, dep);
- if (!entry) {
- return;
- }
-
- auto dep_info = DependencyInfo::FromSubject(dep);
- dep_info.dependent.entry = new_info;
- dependencies.emplace_back(dep_info);
- });
-
- // FIXME: we should update dependencies in the future
- // some alters could cause dependencies to change (imagine types of table columns)
- // or DEFAULT depending on a sequence
- if (!StringUtil::CIEquals(old_obj.name, new_obj.name)) {
- // The name has been changed, we need to recreate the dependency links
- CleanupDependencies(transaction, old_obj);
- }
-
- // Reinstate the old dependencies
- for (auto &dep : dependencies) {
- CreateDependency(transaction, dep);
- }
-}
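Editor's note: ReorderEntry is a depth-first topological sort; every entry's subjects (the things it depends on) are emitted strictly before the entry itself, which is what checkpointing needs so that the load order at startup is valid. A comment sketch with assumed entries:

// given: v1 depends on t1, t1 depends on s1
// ReorderEntries({v1, t1, s1}) visits v1 first, recurses into t1, then s1,
// and pushes s1, t1, v1 -- dependencies always precede their dependents.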
-
-void DependencyManager::Scan(
- ClientContext &context,
- const std::function<void(CatalogEntry &, CatalogEntry &, const DependencyDependentFlags &)> &callback) {
- auto transaction = catalog.GetCatalogTransaction(context);
- lock_guard<mutex> write_lock(catalog.GetWriteLock());
-
- // All the objects registered in the dependency manager
- catalog_entry_set_t entries;
- dependents.Scan(transaction, [&](CatalogEntry &set) {
- auto entry = LookupEntry(transaction, set);
- entries.insert(*entry);
- });
-
- // For every registered entry, get the dependents
- for (auto &entry : entries) {
- auto entry_info = GetLookupProperties(entry);
- // Scan all the dependents of the entry
- ScanDependents(transaction, entry_info, [&](DependencyEntry &dependent) {
- auto dep = LookupEntry(transaction, dependent);
- if (!dep) {
- return;
- }
- auto &dependent_entry = *dep;
- callback(entry, dependent_entry, dependent.Dependent().flags);
- });
- }
-}
-
-void DependencyManager::AddOwnership(CatalogTransaction transaction, CatalogEntry &owner, CatalogEntry &entry) {
- if (IsSystemEntry(entry) || IsSystemEntry(owner)) {
- return;
- }
-
- // If the owner is already owned by something else, throw an error
- const auto owner_info = GetLookupProperties(owner);
- ScanDependents(transaction, owner_info, [&](DependencyEntry &dep) {
- if (dep.Dependent().flags.IsOwnedBy()) {
- throw DependencyException("%s can not become the owner, it is already owned by %s", owner.name,
- dep.EntryInfo().name);
- }
- });
-
- // If the entry is the owner of another entry, throw an error
- auto entry_info = GetLookupProperties(entry);
- ScanSubjects(transaction, entry_info, [&](DependencyEntry &other) {
- auto dependent_entry = LookupEntry(transaction, other);
- if (!dependent_entry) {
- return;
- }
- auto &dep = *dependent_entry;
-
- auto flags = other.Dependent().flags;
- if (!flags.IsOwnedBy()) {
- return;
- }
- throw DependencyException("%s already owns %s. Cannot have circular dependencies", entry.name, dep.name);
- });
-
- // If the entry is already owned, throw an error
- ScanDependents(transaction, entry_info, [&](DependencyEntry &other) {
- auto dependent_entry = LookupEntry(transaction, other);
- if (!dependent_entry) {
- return;
- }
-
- auto &dep = *dependent_entry;
- auto flags = other.Subject().flags;
- if (!flags.IsOwnership()) {
- return;
- }
- if (&dep != &owner) {
- throw DependencyException("%s is already owned by %s", entry.name, dep.name);
- }
- });
-
- DependencyInfo info {
- /*dependent = */ DependencyDependent {GetLookupProperties(owner), DependencyDependentFlags().SetOwnedBy()},
- /*subject = */ DependencySubject {GetLookupProperties(entry), DependencySubjectFlags().SetOwnership()}};
- CreateDependency(transaction, info);
-}
-
-static string FormatString(const MangledEntryName &mangled) {
- auto input = mangled.name;
- for (size_t i = 0; i < input.size(); i++) {
- if (input[i] == '\0') {
- input[i] = '_';
- }
- }
- return input;
-}
-
-void DependencyManager::PrintSubjects(CatalogTransaction transaction, const CatalogEntryInfo &info) {
- auto name = MangleName(info);
- Printer::Print(StringUtil::Format("Subjects of %s", FormatString(name)));
- auto subjects = DependencyCatalogSet(Subjects(), info);
- subjects.Scan(transaction, [&](CatalogEntry &dependency) {
- auto &dep = dependency.Cast<DependencyEntry>();
- auto &entry_info = dep.EntryInfo();
- auto type = entry_info.type;
- auto schema = entry_info.schema;
- auto name = entry_info.name;
- Printer::Print(StringUtil::Format("Schema: %s | Name: %s | Type: %s | Dependent type: %s | Subject type: %s",
- schema, name, CatalogTypeToString(type), dep.Dependent().flags.ToString(),
- dep.Subject().flags.ToString()));
- });
-}
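Editor's note: the three scans inside AddOwnership above encode the ownership rules; a comment summary (derived from the code, the wording is the editor's):

// 1. an entry that is already owned cannot itself become an owner   (no ownership chains)
// 2. an entry cannot be owned by something it already owns          (no cycles)
// 3. an entry can have at most one owner; re-adding the same owner is tolerated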
&info) { - auto name = MangleName(info); - Printer::Print(StringUtil::Format("Dependents of %s", FormatString(name))); - auto dependents = DependencyCatalogSet(Dependents(), info); - dependents.Scan(transaction, [&](CatalogEntry &dependent) { - auto &dep = dependent.Cast(); - auto &entry_info = dep.EntryInfo(); - auto type = entry_info.type; - auto schema = entry_info.schema; - auto name = entry_info.name; - Printer::Print(StringUtil::Format("Schema: %s | Name: %s | Type: %s | Dependent type: %s | Subject type: %s", - schema, name, CatalogTypeToString(type), dep.Dependent().flags.ToString(), - dep.Subject().flags.ToString())); - }); -} - -} // namespace duckdb diff --git a/src/duckdb/src/catalog/duck_catalog.cpp b/src/duckdb/src/catalog/duck_catalog.cpp deleted file mode 100644 index a55349a74..000000000 --- a/src/duckdb/src/catalog/duck_catalog.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/catalog/catalog_entry/duck_schema_entry.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/parser/parsed_data/drop_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/catalog/default/default_schemas.hpp" -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/transaction/duck_transaction_manager.hpp" -#include "duckdb/function/function_list.hpp" - -namespace duckdb { - -DuckCatalog::DuckCatalog(AttachedDatabase &db) - : Catalog(db), dependency_manager(make_uniq(*this)), - schemas(make_uniq(*this, IsSystemCatalog() ? make_uniq(*this) : nullptr)) { -} - -DuckCatalog::~DuckCatalog() { -} - -void DuckCatalog::Initialize(bool load_builtin) { - // first initialize the base system catalogs - // these are never written to the WAL - // we start these at 1 because deleted entries default to 0 - auto data = CatalogTransaction::GetSystemTransaction(GetDatabase()); - - // create the default schema - CreateSchemaInfo info; - info.schema = DEFAULT_SCHEMA; - info.internal = true; - CreateSchema(data, info); - - if (load_builtin) { - BuiltinFunctions builtin(data, *this); - builtin.Initialize(); - - // initialize default functions - FunctionList::RegisterFunctions(*this, data); - } - - Verify(); -} - -bool DuckCatalog::IsDuckCatalog() { - return true; -} - -optional_ptr DuckCatalog::GetDependencyManager() { - return dependency_manager.get(); -} - -//===--------------------------------------------------------------------===// -// Schema -//===--------------------------------------------------------------------===// -optional_ptr DuckCatalog::CreateSchemaInternal(CatalogTransaction transaction, CreateSchemaInfo &info) { - LogicalDependencyList dependencies; - - if (!info.internal && DefaultSchemaGenerator::IsDefaultSchema(info.schema)) { - return nullptr; - } - auto entry = make_uniq(*this, info); - auto result = entry.get(); - if (!schemas->CreateEntry(transaction, info.schema, std::move(entry), dependencies)) { - return nullptr; - } - return result; -} - -optional_ptr DuckCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) { - D_ASSERT(!info.schema.empty()); - auto result = CreateSchemaInternal(transaction, info); - if (!result) { - switch (info.on_conflict) { - case OnCreateConflict::ERROR_ON_CONFLICT: - throw CatalogException::EntryAlreadyExists(CatalogType::SCHEMA_ENTRY, info.schema); - case OnCreateConflict::REPLACE_ON_CONFLICT: { - DropInfo drop_info; - drop_info.type = 
-
-optional_ptr<CatalogEntry> DuckCatalog::CreateSchema(CatalogTransaction transaction, CreateSchemaInfo &info) {
- D_ASSERT(!info.schema.empty());
- auto result = CreateSchemaInternal(transaction, info);
- if (!result) {
- switch (info.on_conflict) {
- case OnCreateConflict::ERROR_ON_CONFLICT:
- throw CatalogException::EntryAlreadyExists(CatalogType::SCHEMA_ENTRY, info.schema);
- case OnCreateConflict::REPLACE_ON_CONFLICT: {
- DropInfo drop_info;
- drop_info.type = CatalogType::SCHEMA_ENTRY;
- drop_info.catalog = info.catalog;
- drop_info.name = info.schema;
- DropSchema(transaction, drop_info);
- result = CreateSchemaInternal(transaction, info);
- if (!result) {
- throw InternalException("Failed to create schema entry in CREATE_OR_REPLACE");
- }
- break;
- }
- case OnCreateConflict::IGNORE_ON_CONFLICT:
- break;
- default:
- throw InternalException("Unsupported OnCreateConflict for CreateSchema");
- }
- return nullptr;
- }
- return result;
-}
-
-void DuckCatalog::DropSchema(CatalogTransaction transaction, DropInfo &info) {
- D_ASSERT(!info.name.empty());
- if (!schemas->DropEntry(transaction, info.name, info.cascade)) {
- if (info.if_not_found == OnEntryNotFound::THROW_EXCEPTION) {
- throw CatalogException::MissingEntry(CatalogType::SCHEMA_ENTRY, info.name, string());
- }
- }
-}
-
-void DuckCatalog::DropSchema(ClientContext &context, DropInfo &info) {
- DropSchema(GetCatalogTransaction(context), info);
-}
-
-void DuckCatalog::ScanSchemas(ClientContext &context, std::function<void(SchemaCatalogEntry &)> callback) {
- schemas->Scan(GetCatalogTransaction(context),
- [&](CatalogEntry &entry) { callback(entry.Cast<SchemaCatalogEntry>()); });
-}
-
-void DuckCatalog::ScanSchemas(std::function<void(SchemaCatalogEntry &)> callback) {
- schemas->Scan([&](CatalogEntry &entry) { callback(entry.Cast<SchemaCatalogEntry>()); });
-}
-
-CatalogSet &DuckCatalog::GetSchemaCatalogSet() {
- return *schemas;
-}
-
-optional_ptr<SchemaCatalogEntry> DuckCatalog::GetSchema(CatalogTransaction transaction, const string &schema_name,
- OnEntryNotFound if_not_found, QueryErrorContext error_context) {
- D_ASSERT(!schema_name.empty());
- auto entry = schemas->GetEntry(transaction, schema_name);
- if (!entry) {
- if (if_not_found == OnEntryNotFound::THROW_EXCEPTION) {
- throw CatalogException(error_context, "Schema with name %s does not exist!", schema_name);
- }
- return nullptr;
- }
- return &entry->Cast<SchemaCatalogEntry>();
-}
-
-DatabaseSize DuckCatalog::GetDatabaseSize(ClientContext &context) {
- auto &transaction = DuckTransactionManager::Get(db);
- auto lock = transaction.SharedCheckpointLock();
- return db.GetStorageManager().GetDatabaseSize();
-}
-
-vector<MetadataBlockInfo> DuckCatalog::GetMetadataInfo(ClientContext &context) {
- auto &transaction = DuckTransactionManager::Get(db);
- auto lock = transaction.SharedCheckpointLock();
- return db.GetStorageManager().GetMetadataInfo();
-}
-
-bool DuckCatalog::InMemory() {
- return db.GetStorageManager().InMemory();
-}
-
-string DuckCatalog::GetDBPath() {
- return db.GetStorageManager().GetDBPath();
-}
-
-void DuckCatalog::Verify() {
-#ifdef DEBUG
- Catalog::Verify();
- schemas->Verify(*this);
-#endif
-}
-
-optional_idx DuckCatalog::GetCatalogVersion(ClientContext &context) {
- auto &transaction_manager = DuckTransactionManager::Get(db);
- auto transaction = GetCatalogTransaction(context);
- D_ASSERT(transaction.transaction);
- return transaction_manager.GetCatalogVersion(*transaction.transaction);
-}
-
-} // namespace duckdb
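Editor's note: a hedged SQL mapping of the OnCreateConflict branches in CreateSchema above (illustrative; the exact syntax accepted is not shown in this diff):

// CREATE SCHEMA s1;                -- ERROR_ON_CONFLICT: throws if s1 already exists
// CREATE SCHEMA IF NOT EXISTS s1;  -- IGNORE_ON_CONFLICT: returns silently
// CREATE OR REPLACE SCHEMA s1;     -- REPLACE_ON_CONFLICT: DropSchema(s1), then re-create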
diff --git a/src/duckdb/src/catalog/similar_catalog_entry.cpp b/src/duckdb/src/catalog/similar_catalog_entry.cpp
deleted file mode 100644
index d3e3487b9..000000000
--- a/src/duckdb/src/catalog/similar_catalog_entry.cpp
+++ /dev/null
@@ -1,26 +0,0 @@
-#include "duckdb/catalog/similar_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
-#include "duckdb/catalog/catalog.hpp"
-
-namespace duckdb {
-
-string SimilarCatalogEntry::GetQualifiedName(bool qualify_catalog, bool qualify_schema) const {
- D_ASSERT(Found());
- string result;
- if (qualify_catalog) {
- result += schema->catalog.GetName();
- }
- if (qualify_schema) {
- if (!result.empty()) {
- result += ".";
- }
- result += schema->name;
- }
- if (!result.empty()) {
- result += ".";
- }
- result += name;
- return result;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/adbc/adbc.cpp b/src/duckdb/src/common/adbc/adbc.cpp
deleted file mode 100644
index 7323f3b1f..000000000
--- a/src/duckdb/src/common/adbc/adbc.cpp
+++ /dev/null
@@ -1,1351 +0,0 @@
-#include "duckdb/common/adbc/adbc.hpp"
-#include "duckdb/common/adbc/adbc-init.hpp"
-
-#include "duckdb/common/string.hpp"
-#include "duckdb/common/string_util.hpp"
-
-#include "duckdb.h"
-#include "duckdb/common/arrow/arrow_wrapper.hpp"
-#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp"
-
-#include "duckdb/main/capi/capi_internal.hpp"
-
-#ifndef DUCKDB_AMALGAMATION
-#include "duckdb/main/connection.hpp"
-#endif
-
-#include "duckdb/common/adbc/options.h"
-#include "duckdb/common/adbc/single_batch_array_stream.hpp"
-#include "duckdb/function/table/arrow.hpp"
-
-#include <stdlib.h>
-#include <string.h>
-
-// We must leak the symbols of the init function
-AdbcStatusCode duckdb_adbc_init(int version, void *driver, struct AdbcError *error) {
- if (!driver) {
- return ADBC_STATUS_INVALID_ARGUMENT;
- }
- auto adbc_driver = static_cast<AdbcDriver *>(driver);
-
- adbc_driver->DatabaseNew = duckdb_adbc::DatabaseNew;
- adbc_driver->DatabaseSetOption = duckdb_adbc::DatabaseSetOption;
- adbc_driver->DatabaseInit = duckdb_adbc::DatabaseInit;
- adbc_driver->DatabaseRelease = duckdb_adbc::DatabaseRelease;
- adbc_driver->ConnectionNew = duckdb_adbc::ConnectionNew;
- adbc_driver->ConnectionSetOption = duckdb_adbc::ConnectionSetOption;
- adbc_driver->ConnectionInit = duckdb_adbc::ConnectionInit;
- adbc_driver->ConnectionRelease = duckdb_adbc::ConnectionRelease;
- adbc_driver->ConnectionGetTableTypes = duckdb_adbc::ConnectionGetTableTypes;
- adbc_driver->StatementNew = duckdb_adbc::StatementNew;
- adbc_driver->StatementRelease = duckdb_adbc::StatementRelease;
- adbc_driver->StatementBind = duckdb_adbc::StatementBind;
- adbc_driver->StatementBindStream = duckdb_adbc::StatementBindStream;
- adbc_driver->StatementExecuteQuery = duckdb_adbc::StatementExecuteQuery;
- adbc_driver->StatementPrepare = duckdb_adbc::StatementPrepare;
- adbc_driver->StatementSetOption = duckdb_adbc::StatementSetOption;
- adbc_driver->StatementSetSqlQuery = duckdb_adbc::StatementSetSqlQuery;
- adbc_driver->ConnectionGetObjects = duckdb_adbc::ConnectionGetObjects;
- adbc_driver->ConnectionCommit = duckdb_adbc::ConnectionCommit;
- adbc_driver->ConnectionRollback = duckdb_adbc::ConnectionRollback;
- adbc_driver->ConnectionReadPartition = duckdb_adbc::ConnectionReadPartition;
- adbc_driver->StatementExecutePartitions = duckdb_adbc::StatementExecutePartitions;
- adbc_driver->ConnectionGetInfo = duckdb_adbc::ConnectionGetInfo;
- adbc_driver->StatementGetParameterSchema = duckdb_adbc::StatementGetParameterSchema;
- adbc_driver->ConnectionGetTableSchema = duckdb_adbc::ConnectionGetTableSchema;
- adbc_driver->StatementSetSubstraitPlan = duckdb_adbc::StatementSetSubstraitPlan;
- return ADBC_STATUS_OK;
-}
-
-namespace duckdb_adbc {
-
-enum class IngestionMode { CREATE = 0, APPEND = 1 };
-
-struct DuckDBAdbcStatementWrapper {
- duckdb_connection connection;
- duckdb_arrow result;
- duckdb_prepared_statement statement;
- char *ingestion_table_name;
- char *db_schema;
- ArrowArrayStream ingestion_stream;
- IngestionMode ingestion_mode = IngestionMode::CREATE;
- bool temporary_table = false;
- uint8_t *substrait_plan;
- uint64_t plan_length;
-};
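Editor's note: a minimal consumer sketch of the entry point above (not from the diff; it assumes the standard ADBC 1.0 struct layout and the ADBC_VERSION_1_0_0 constant from adbc.h):

AdbcDriver driver = {};
AdbcError error = {};
if (duckdb_adbc_init(ADBC_VERSION_1_0_0, &driver, &error) == ADBC_STATUS_OK) {
	AdbcDatabase db = {};
	driver.DatabaseNew(&db, &error);
	driver.DatabaseSetOption(&db, "path", ":memory:", &error); // "path" is intercepted, see DatabaseSetOption below
	driver.DatabaseInit(&db, &error);
	// ... create connections and statements via the remaining function pointers ...
	driver.DatabaseRelease(&db, &error);
}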
AdbcConnection *connection, struct ArrowArrayStream *out, const char *query, - struct AdbcError *error) { - AdbcStatement statement; - - auto status = StatementNew(connection, &statement, error); - if (status != ADBC_STATUS_OK) { - StatementRelease(&statement, error); - SetError(error, "unable to initialize statement"); - return status; - } - status = StatementSetSqlQuery(&statement, query, error); - if (status != ADBC_STATUS_OK) { - StatementRelease(&statement, error); - SetError(error, "unable to initialize statement"); - return status; - } - status = StatementExecuteQuery(&statement, out, nullptr, error); - if (status != ADBC_STATUS_OK) { - StatementRelease(&statement, error); - SetError(error, "unable to initialize statement"); - return status; - } - StatementRelease(&statement, error); - return ADBC_STATUS_OK; -} - -struct DuckDBAdbcDatabaseWrapper { - //! The DuckDB Database Configuration - duckdb_config config = nullptr; - //! The DuckDB Database - duckdb_database database = nullptr; - //! Path of Disk-Based Database or :memory: database - std::string path; -}; - -static void EmptyErrorRelease(AdbcError *error) { - // The object is valid but doesn't contain any data that needs to be cleaned up - // Just set the release to nullptr to indicate that it's no longer valid. - error->release = nullptr; -} - -void InitializeADBCError(AdbcError *error) { - if (!error) { - return; - } - error->message = nullptr; - // Don't set to nullptr, as that indicates that it's invalid - error->release = EmptyErrorRelease; - std::memset(error->sqlstate, '\0', sizeof(error->sqlstate)); - error->vendor_code = -1; -} - -AdbcStatusCode CheckResult(const duckdb_state &res, AdbcError *error, const char *error_msg) { - if (!error) { - // Error should be a non-null pointer - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (res != DuckDBSuccess) { - SetError(error, error_msg); - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode DatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - database->private_data = nullptr; - // you can't malloc a struct with a non-trivial C++ constructor - // and std::string has a non-trivial constructor. so we need - // to use new and delete rather than malloc and free. 
- auto wrapper = new (std::nothrow) DuckDBAdbcDatabaseWrapper; - if (!wrapper) { - SetError(error, "Allocation error"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - database->private_data = wrapper; - auto res = duckdb_create_config(&wrapper->config); - return CheckResult(res, error, "Failed to allocate"); -} - -AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Statement is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!plan) { - SetError(error, "Substrait Plan is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (length == 0) { - SetError(error, "Can't execute plan with size = 0"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = static_cast(statement->private_data); - if (wrapper->ingestion_stream.release) { - // Release any resources currently held by the ingestion stream before we overwrite it - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - wrapper->ingestion_stream.release = nullptr; - } - if (wrapper->statement) { - duckdb_destroy_prepare(&wrapper->statement); - wrapper->statement = nullptr; - } - wrapper->substrait_plan = static_cast(malloc(sizeof(uint8_t) * length)); - wrapper->plan_length = length; - memcpy(wrapper->substrait_plan, plan, length); - return ADBC_STATUS_OK; -} - -AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!key) { - SetError(error, "Missing key"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = static_cast(database->private_data); - if (strcmp(key, "path") == 0) { - wrapper->path = value; - return ADBC_STATUS_OK; - } - auto res = duckdb_set_config(wrapper->config, key, value); - - return CheckResult(res, error, "Failed to set configuration option"); -} - -AdbcStatusCode DatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { - if (!error) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!database) { - SetError(error, "ADBC Database has an invalid pointer"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - char *errormsg = nullptr; - // TODO can we set the database path via option, too? Does not look like it... 
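-	// (It can, in fact: DatabaseSetOption above intercepts the "path" key into
-	// wrapper->path, so a client may call AdbcDatabaseSetOption(&db, "path",
-	// "my.duckdb", &err) before AdbcDatabaseInit — "my.duckdb" being an
-	// illustrative file name.)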
- auto wrapper = static_cast(database->private_data); - auto res = duckdb_open_ext(wrapper->path.c_str(), &wrapper->database, wrapper->config, &errormsg); - auto adbc_result = CheckResult(res, error, errormsg); - if (errormsg) { - free(errormsg); - } - return adbc_result; -} - -AdbcStatusCode DatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - - if (database && database->private_data) { - auto wrapper = static_cast(database->private_data); - - duckdb_close(&wrapper->database); - duckdb_destroy_config(&wrapper->config); - delete wrapper; - database->private_data = nullptr; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, const char *db_schema, - const char *table_name, struct ArrowSchema *schema, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (db_schema == nullptr || strlen(db_schema) == 0) { - // if schema is not set, we use the default schema - db_schema = "main"; - } - if (table_name == nullptr) { - SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); - return ADBC_STATUS_INVALID_ARGUMENT; - } else if (strlen(table_name) == 0) { - SetError(error, "AdbcConnectionGetTableSchema: must provide table_name"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - ArrowArrayStream arrow_stream; - - std::string query = "SELECT * FROM "; - if (catalog != nullptr && strlen(catalog) > 0) { - query += std::string(catalog) + "."; - } - query += std::string(db_schema) + "."; - query += std::string(table_name) + " LIMIT 0;"; - - auto success = QueryInternal(connection, &arrow_stream, query.c_str(), error); - if (success != ADBC_STATUS_OK) { - return success; - } - arrow_stream.get_schema(&arrow_stream, schema); - arrow_stream.release(&arrow_stream); - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - connection->private_data = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode ExecuteQuery(duckdb::Connection *conn, const char *query, struct AdbcError *error) { - auto res = conn->Query(query); - if (res->HasError()) { - auto error_message = "Failed to execute query \"" + std::string(query) + "\": " + res->GetError(); - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode ConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, - struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto conn = static_cast(connection->private_data); - if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { - if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { - if (conn->HasActiveTransaction()) { - AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); - if (status != ADBC_STATUS_OK) { - return status; - } - } else { - // no-op - } - } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { - if (conn->HasActiveTransaction()) { - // no-op - } else { - // begin - AdbcStatusCode status = ExecuteQuery(conn, "START TRANSACTION", error); - if (status != ADBC_STATUS_OK) { - return status; - } - } - } else { - auto error_message = "Invalid connection option value " + std::string(key) + "=" + std::string(value); - SetError(error, error_message); - 
return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; - } - auto error_message = - "Unknown connection option " + std::string(key) + "=" + (value ? std::string(value) : "(NULL)"); - SetError(error, error_message); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - SetError(error, "Read Partitions are not supported in DuckDB"); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - SetError(error, "Execute Partitions are not supported in DuckDB"); - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto conn = static_cast(connection->private_data); - if (!conn->HasActiveTransaction()) { - SetError(error, "No active transaction, cannot commit"); - return ADBC_STATUS_INVALID_STATE; - } - - AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); - if (status != ADBC_STATUS_OK) { - return status; - } - return ExecuteQuery(conn, "START TRANSACTION", error); -} - -AdbcStatusCode ConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection) { - SetError(error, "Connection is not set"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto conn = static_cast(connection->private_data); - if (!conn->HasActiveTransaction()) { - SetError(error, "No active transaction, cannot rollback"); - return ADBC_STATUS_INVALID_STATE; - } - - AdbcStatusCode status = ExecuteQuery(conn, "ROLLBACK", error); - if (status != ADBC_STATUS_OK) { - return status; - } - return ExecuteQuery(conn, "START TRANSACTION", error); -} - -enum class AdbcInfoCode : uint32_t { - VENDOR_NAME, - VENDOR_VERSION, - DRIVER_NAME, - DRIVER_VERSION, - DRIVER_ARROW_VERSION, - UNRECOGNIZED // always the last entry of the enum -}; - -static AdbcInfoCode ConvertToInfoCode(uint32_t info_code) { - switch (info_code) { - case 0: - return AdbcInfoCode::VENDOR_NAME; - case 1: - return AdbcInfoCode::VENDOR_VERSION; - case 2: - return AdbcInfoCode::DRIVER_NAME; - case 3: - return AdbcInfoCode::DRIVER_VERSION; - case 4: - return AdbcInfoCode::DRIVER_ARROW_VERSION; - default: - return AdbcInfoCode::UNRECOGNIZED; - } -} - -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, const uint32_t *info_codes, - size_t info_codes_length, struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "Connection is invalid"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!out) { - SetError(error, "Output parameter was not provided"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - // If 'info_codes' is NULL, we should output all the info codes we recognize - size_t length = info_codes ? 
info_codes_length : static_cast(AdbcInfoCode::UNRECOGNIZED); - - duckdb::string q = R"EOF( - select - name::UINTEGER as info_name, - info::UNION( - string_value VARCHAR, - bool_value BOOL, - int64_value BIGINT, - int32_bitmask INTEGER, - string_list VARCHAR[], - int32_to_int32_list_map MAP(INTEGER, INTEGER[]) - ) as info_value from values - )EOF"; - - duckdb::string results = ""; - - for (size_t i = 0; i < length; i++) { - auto code = duckdb::NumericCast(info_codes ? info_codes[i] : i); - auto info_code = ConvertToInfoCode(code); - switch (info_code) { - case AdbcInfoCode::VENDOR_NAME: { - results += "(0, 'duckdb'),"; - break; - } - case AdbcInfoCode::VENDOR_VERSION: { - results += duckdb::StringUtil::Format("(1, '%s'),", duckdb_library_version()); - break; - } - case AdbcInfoCode::DRIVER_NAME: { - results += "(2, 'ADBC DuckDB Driver'),"; - break; - } - case AdbcInfoCode::DRIVER_VERSION: { - // TODO: fill in driver version - results += "(3, '(unknown)'),"; - break; - } - case AdbcInfoCode::DRIVER_ARROW_VERSION: { - // TODO: fill in arrow version - results += "(4, '(unknown)'),"; - break; - } - case AdbcInfoCode::UNRECOGNIZED: { - // Unrecognized codes are not an error, just ignored - continue; - } - default: { - // Codes that we have implemented but not handled here are a developer error - SetError(error, "Info code recognized but not handled"); - return ADBC_STATUS_INTERNAL; - } - } - } - if (results.empty()) { - // Add a group of values so the query parses - q += "(NULL, NULL)"; - } else { - q += results; - } - q += " tbl(name, info)"; - if (results.empty()) { - // Add an impossible where clause to return an empty result set - q += " where true = false"; - } - return QueryInternal(connection, out, q.c_str(), error); -} - -AdbcStatusCode ConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, - struct AdbcError *error) { - if (!database) { - SetError(error, "Missing database object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!database->private_data) { - SetError(error, "Invalid database"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto database_wrapper = static_cast(database->private_data); - - connection->private_data = nullptr; - auto res = - duckdb_connect(database_wrapper->database, reinterpret_cast(&connection->private_data)); - return CheckResult(res, error, "Failed to connect to Database"); -} - -AdbcStatusCode ConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { - if (connection && connection->private_data) { - duckdb_disconnect(reinterpret_cast(&connection->private_data)); - connection->private_data = nullptr; - } - return ADBC_STATUS_OK; -} - -// some stream callbacks - -static int get_schema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { - if (!stream || !stream->private_data || !out) { - return DuckDBError; - } - return duckdb_query_arrow_schema(static_cast(stream->private_data), - reinterpret_cast(&out)); -} - -static int get_next(struct ArrowArrayStream *stream, struct ArrowArray *out) { - if (!stream || !stream->private_data || !out) { - return DuckDBError; - } - out->release = nullptr; - - return duckdb_query_arrow_array(static_cast(stream->private_data), - reinterpret_cast(&out)); -} - -void release(struct ArrowArrayStream *stream) { - if (!stream || !stream->release) { - return; - } - if (stream->private_data) { - duckdb_destroy_arrow(reinterpret_cast(&stream->private_data)); - 
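-		// (The Arrow C stream contract requires release() to leave the struct with
-		// stream->release == nullptr, which is how consumers detect an
-		// already-released stream; the lines below do exactly that.)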
stream->private_data = nullptr;
-	}
-	stream->release = nullptr;
-}
-
-const char *get_last_error(struct ArrowArrayStream *stream) {
-	if (!stream) {
-		return nullptr;
-	}
-	return nullptr;
-	// return duckdb_query_arrow_error(stream);
-}
-
-// this is an evil hack; normally we would need a stream factory here, but it's probably much easier if the adbc
-// clients just hand over a stream
-
-duckdb::unique_ptr<duckdb::ArrowArrayStreamWrapper> stream_produce(uintptr_t factory_ptr,
-                                                                   duckdb::ArrowStreamParameters &parameters) {
-
-	// TODO this will ignore any projections or filters but since we don't expose the scan it should be sort of fine
-	auto res = duckdb::make_uniq<duckdb::ArrowArrayStreamWrapper>();
-	res->arrow_array_stream = *reinterpret_cast<ArrowArrayStream *>(factory_ptr);
-	return res;
-}
-
-void stream_schema(ArrowArrayStream *stream, ArrowSchema &schema) {
-	stream->get_schema(stream, &schema);
-}
-
-AdbcStatusCode Ingest(duckdb_connection connection, const char *table_name, const char *schema,
-                      struct ArrowArrayStream *input, struct AdbcError *error, IngestionMode ingestion_mode,
-                      bool temporary) {
-
-	if (!connection) {
-		SetError(error, "Missing connection object");
-		return ADBC_STATUS_INVALID_ARGUMENT;
-	}
-	if (!input) {
-		SetError(error, "Missing input arrow stream pointer");
-		return ADBC_STATUS_INVALID_ARGUMENT;
-	}
-	if (!table_name) {
-		SetError(error, "Missing database object name");
-		return ADBC_STATUS_INVALID_ARGUMENT;
-	}
-	if (schema && temporary) {
-		// Temporary option is not supported with ADBC_INGEST_OPTION_TARGET_DB_SCHEMA or
-		// ADBC_INGEST_OPTION_TARGET_CATALOG
-		SetError(error, "Temporary option is not supported with schema");
-		return ADBC_STATUS_INVALID_ARGUMENT;
-	}
-
-	auto cconn = reinterpret_cast<duckdb::Connection *>(connection);
-
-	auto arrow_scan =
-	    cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(input)),
-	                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_produce)),
-	                                        duckdb::Value::POINTER(reinterpret_cast<uintptr_t>(stream_schema))});
-	try {
-		switch (ingestion_mode) {
-		case IngestionMode::CREATE:
-			if (schema) {
-				arrow_scan->Create(schema, table_name, temporary);
-			} else {
-				arrow_scan->Create(table_name, temporary);
-			}
-			break;
-		case IngestionMode::APPEND: {
-			arrow_scan->CreateView("temp_adbc_view", true, true);
-			std::string query;
-			if (schema) {
-				query = duckdb::StringUtil::Format("insert into \"%s.%s\" select * from temp_adbc_view", schema,
-				                                   table_name);
-			} else {
-				query = duckdb::StringUtil::Format("insert into \"%s\" select * from temp_adbc_view", table_name);
-			}
-			auto result = cconn->Query(query);
-			break;
-		}
-		}
-		// After creating a table, the arrow array stream is released. Hence we must set it as released to avoid
-		// double-releasing it
-		input->release = nullptr;
-	} catch (std::exception &ex) {
-		if (error) {
-			duckdb::ErrorData parsed_error(ex);
-			error->message = strdup(parsed_error.RawMessage().c_str());
-		}
-		return ADBC_STATUS_INTERNAL;
-	} catch (...)
{ - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, - struct AdbcError *error) { - if (!connection) { - SetError(error, "Missing connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!connection->private_data) { - SetError(error, "Invalid connection object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - statement->private_data = nullptr; - - auto statement_wrapper = static_cast(malloc(sizeof(DuckDBAdbcStatementWrapper))); - if (!statement_wrapper) { - SetError(error, "Allocation error"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - statement->private_data = statement_wrapper; - statement_wrapper->connection = static_cast(connection->private_data); - statement_wrapper->statement = nullptr; - statement_wrapper->result = nullptr; - statement_wrapper->ingestion_stream.release = nullptr; - statement_wrapper->ingestion_table_name = nullptr; - statement_wrapper->db_schema = nullptr; - statement_wrapper->substrait_plan = nullptr; - statement_wrapper->temporary_table = false; - - statement_wrapper->ingestion_mode = IngestionMode::CREATE; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement || !statement->private_data) { - return ADBC_STATUS_OK; - } - auto wrapper = static_cast(statement->private_data); - if (wrapper->statement) { - duckdb_destroy_prepare(&wrapper->statement); - wrapper->statement = nullptr; - } - if (wrapper->result) { - duckdb_destroy_arrow(&wrapper->result); - wrapper->result = nullptr; - } - if (wrapper->ingestion_stream.release) { - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - wrapper->ingestion_stream.release = nullptr; - } - if (wrapper->ingestion_table_name) { - free(wrapper->ingestion_table_name); - wrapper->ingestion_table_name = nullptr; - } - if (wrapper->db_schema) { - free(wrapper->db_schema); - wrapper->db_schema = nullptr; - } - if (wrapper->substrait_plan) { - free(wrapper->substrait_plan); - wrapper->substrait_plan = nullptr; - } - free(statement->private_data); - statement->private_data = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!schema) { - SetError(error, "Missing schema object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = static_cast(statement->private_data); - // TODO: we might want to cache this, but then we need to return a deep copy anyways.., so I'm not sure if that - // would be worth the extra management - auto res = duckdb_prepared_arrow_schema(wrapper->statement, reinterpret_cast(&schema)); - if (res != DuckDBSuccess) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode GetPreparedParameters(duckdb_connection connection, duckdb::unique_ptr &result, - ArrowArrayStream *input, AdbcError *error) { - - auto cconn = reinterpret_cast(connection); - - try { - auto arrow_scan = - cconn->TableFunction("arrow_scan", {duckdb::Value::POINTER(reinterpret_cast(input)), - 
duckdb::Value::POINTER(reinterpret_cast(stream_produce)), - duckdb::Value::POINTER(reinterpret_cast(stream_schema))}); - result = arrow_scan->Execute(); - // After creating a table, the arrow array stream is released. Hence we must set it as released to avoid - // double-releasing it - input->release = nullptr; - } catch (std::exception &ex) { - if (error) { - ::duckdb::ErrorData parsed_error(ex); - error->message = strdup(parsed_error.RawMessage().c_str()); - } - return ADBC_STATUS_INTERNAL; - } catch (...) { - return ADBC_STATUS_INTERNAL; - } - return ADBC_STATUS_OK; -} - -static AdbcStatusCode IngestToTableFromBoundStream(DuckDBAdbcStatementWrapper *statement, AdbcError *error) { - // See ADBC_INGEST_OPTION_TARGET_TABLE - D_ASSERT(statement->ingestion_stream.release); - D_ASSERT(statement->ingestion_table_name); - - // Take the input stream from the statement - auto stream = statement->ingestion_stream; - statement->ingestion_stream.release = nullptr; - - // Ingest into a table from the bound stream - return Ingest(statement->connection, statement->ingestion_table_name, statement->db_schema, &stream, error, - statement->ingestion_mode, statement->temporary_table); -} - -AdbcStatusCode StatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, - int64_t *rows_affected, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto wrapper = static_cast(statement->private_data); - - // TODO: Set affected rows, careful with early return - if (rows_affected) { - *rows_affected = 0; - } - - const auto has_stream = wrapper->ingestion_stream.release != nullptr; - const auto to_table = wrapper->ingestion_table_name != nullptr; - - if (has_stream && to_table) { - return IngestToTableFromBoundStream(wrapper, error); - } - if (wrapper->substrait_plan != nullptr) { - auto plan_str = std::string(reinterpret_cast(wrapper->substrait_plan), wrapper->plan_length); - duckdb::vector params; - params.emplace_back(duckdb::Value::BLOB_RAW(plan_str)); - duckdb::unique_ptr query_result; - try { - query_result = reinterpret_cast(wrapper->connection) - ->TableFunction("from_substrait", params) - ->Execute(); - } catch (duckdb::Exception &e) { - std::string error_msg = "It was not possible to execute substrait query. 
" + std::string(e.what()); - SetError(error, error_msg); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto arrow_wrapper = new duckdb::ArrowResultWrapper(); - arrow_wrapper->result = - duckdb::unique_ptr_cast(std::move(query_result)); - wrapper->result = reinterpret_cast(arrow_wrapper); - } else if (has_stream) { - // A stream was bound to the statement, use that to bind parameters - duckdb::unique_ptr result; - ArrowArrayStream stream = wrapper->ingestion_stream; - wrapper->ingestion_stream.release = nullptr; - auto adbc_res = GetPreparedParameters(wrapper->connection, result, &stream, error); - if (adbc_res != ADBC_STATUS_OK) { - return adbc_res; - } - if (!result) { - return ADBC_STATUS_INVALID_ARGUMENT; - } - duckdb::unique_ptr chunk; - auto prepared_statement_params = - reinterpret_cast(wrapper->statement)->statement->named_param_map.size(); - - while ((chunk = result->Fetch()) != nullptr) { - if (chunk->size() == 0) { - SetError(error, "Please provide a non-empty chunk to be bound"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (chunk->size() != 1) { - // TODO: add support for binding multiple rows - SetError(error, "Binding multiple rows at once is not supported yet"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - if (chunk->ColumnCount() > prepared_statement_params) { - SetError(error, "Input data has more column than prepared statement has parameters"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - duckdb_clear_bindings(wrapper->statement); - for (idx_t col_idx = 0; col_idx < chunk->ColumnCount(); col_idx++) { - auto val = chunk->GetValue(col_idx, 0); - auto duck_val = reinterpret_cast(&val); - auto res = duckdb_bind_value(wrapper->statement, 1 + col_idx, duck_val); - if (res != DuckDBSuccess) { - SetError(error, duckdb_prepare_error(wrapper->statement)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); - if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - } else { - auto res = duckdb_execute_prepared_arrow(wrapper->statement, &wrapper->result); - if (res != DuckDBSuccess) { - SetError(error, duckdb_query_arrow_error(wrapper->result)); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - if (out) { - out->private_data = wrapper->result; - out->get_schema = get_schema; - out->get_next = get_next; - out->release = release; - out->get_last_error = get_last_error; - - // because we handed out the stream pointer its no longer our responsibility to destroy it in - // AdbcStatementRelease, this is now done in release() - wrapper->result = nullptr; - } - - return ADBC_STATUS_OK; -} - -// this is a nop for us -AdbcStatusCode StatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!query) { - SetError(error, "Missing query"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = 
static_cast(statement->private_data); - if (wrapper->ingestion_stream.release) { - // Release any resources currently held by the ingestion stream before we overwrite it - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - wrapper->ingestion_stream.release = nullptr; - } - if (wrapper->statement) { - duckdb_destroy_prepare(&wrapper->statement); - wrapper->statement = nullptr; - } - auto res = duckdb_prepare(wrapper->connection, query, &wrapper->statement); - auto error_msg = duckdb_prepare_error(wrapper->statement); - return CheckResult(res, error, error_msg); -} - -AdbcStatusCode StatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schemas, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!values) { - SetError(error, "Missing values object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!schemas) { - SetError(error, "Invalid schemas object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = static_cast(statement->private_data); - if (wrapper->ingestion_stream.release) { - // Free the stream that was previously bound - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } - auto status = BatchToArrayStream(values, schemas, &wrapper->ingestion_stream, error); - return status; -} - -AdbcStatusCode StatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *values, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!values) { - SetError(error, "Missing values object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = static_cast(statement->private_data); - if (wrapper->ingestion_stream.release) { - // Release any resources currently held by the ingestion stream before we overwrite it - wrapper->ingestion_stream.release(&wrapper->ingestion_stream); - } - wrapper->ingestion_stream = *values; - values->release = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode StatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, - struct AdbcError *error) { - if (!statement) { - SetError(error, "Missing statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!statement->private_data) { - SetError(error, "Invalid statement object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - if (!key) { - SetError(error, "Missing key object"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - auto wrapper = static_cast(statement->private_data); - - if (strcmp(key, ADBC_INGEST_OPTION_TARGET_TABLE) == 0) { - wrapper->ingestion_table_name = strdup(value); - wrapper->temporary_table = false; - return ADBC_STATUS_OK; - } - if (strcmp(key, ADBC_INGEST_OPTION_TEMPORARY) == 0) { - if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { - if (wrapper->db_schema) { - SetError(error, "Temporary option is not supported with schema"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - wrapper->temporary_table = true; - return ADBC_STATUS_OK; - } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { - wrapper->temporary_table = false; - return ADBC_STATUS_OK; - } else { - SetError( - error, - "ADBC_INGEST_OPTION_TEMPORARY, can only be 
ADBC_OPTION_VALUE_ENABLED or ADBC_OPTION_VALUE_DISABLED"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - - if (strcmp(key, ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) == 0) { - if (wrapper->temporary_table) { - SetError(error, "Temporary option is not supported with schema"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - wrapper->db_schema = strdup(value); - return ADBC_STATUS_OK; - } - - if (strcmp(key, ADBC_INGEST_OPTION_MODE) == 0) { - if (strcmp(value, ADBC_INGEST_OPTION_MODE_CREATE) == 0) { - wrapper->ingestion_mode = IngestionMode::CREATE; - return ADBC_STATUS_OK; - } else if (strcmp(value, ADBC_INGEST_OPTION_MODE_APPEND) == 0) { - wrapper->ingestion_mode = IngestionMode::APPEND; - return ADBC_STATUS_OK; - } else { - SetError(error, "Invalid ingestion mode"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - } - std::stringstream ss; - ss << "Statement Set Option " << key << " is not yet accepted by DuckDB"; - SetError(error, ss.str()); - return ADBC_STATUS_INVALID_ARGUMENT; -} - -AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, - const char *db_schema, const char *table_name, const char **table_type, - const char *column_name, struct ArrowArrayStream *out, struct AdbcError *error) { - if (table_type != nullptr) { - SetError(error, "Table types parameter not yet supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - std::string catalog_filter = catalog ? catalog : "%"; - std::string db_schema_filter = db_schema ? db_schema : "%"; - std::string table_name_filter = table_name ? table_name : "%"; - std::string column_name_filter = column_name ? column_name : "%"; - - std::string query; - switch (depth) { - case ADBC_OBJECT_DEPTH_CATALOGS: - // Return metadata on catalogs. - query = duckdb::StringUtil::Format(R"( - SELECT - catalog_name, - []::STRUCT( - db_schema_name VARCHAR, - db_schema_tables STRUCT( - table_name VARCHAR, - table_type VARCHAR, - table_columns STRUCT( - column_name VARCHAR, - ordinal_position INTEGER, - remarks VARCHAR, - xdbc_data_type SMALLINT, - xdbc_type_name VARCHAR, - xdbc_column_size INTEGER, - xdbc_decimal_digits SMALLINT, - xdbc_num_prec_radix SMALLINT, - xdbc_nullable SMALLINT, - xdbc_column_def VARCHAR, - xdbc_sql_data_type SMALLINT, - xdbc_datetime_sub SMALLINT, - xdbc_char_octet_length INTEGER, - xdbc_is_nullable VARCHAR, - xdbc_scope_catalog VARCHAR, - xdbc_scope_schema VARCHAR, - xdbc_scope_table VARCHAR, - xdbc_is_autoincrement BOOLEAN, - xdbc_is_generatedcolumn BOOLEAN - )[], - table_constraints STRUCT( - constraint_name VARCHAR, - constraint_type VARCHAR, - constraint_column_names VARCHAR[], - constraint_column_usage STRUCT(fk_catalog VARCHAR, fk_db_schema VARCHAR, fk_table VARCHAR, fk_column_name VARCHAR)[] - )[] - )[] - )[] catalog_db_schemas - FROM - information_schema.schemata - WHERE catalog_name LIKE '%s' - GROUP BY catalog_name - )", - catalog_filter); - break; - case ADBC_OBJECT_DEPTH_DB_SCHEMAS: - // Return metadata on catalogs and schemas. 
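-		// (Client-side sketch via the standard ADBC entry point; the "mydb%"
-		// pattern is illustrative:
-		//     struct ArrowArrayStream out;
-		//     AdbcConnectionGetObjects(&conn, ADBC_OBJECT_DEPTH_DB_SCHEMAS, "mydb%",
-		//                              nullptr, nullptr, nullptr, nullptr, &out, &err);
-		// returns one row per matching catalog with its list of schemas.)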
- query = duckdb::StringUtil::Format(R"( - WITH db_schemas AS ( - SELECT - catalog_name, - schema_name, - FROM information_schema.schemata - WHERE schema_name LIKE '%s' - ) - - SELECT - catalog_name, - LIST({ - db_schema_name: schema_name, - db_schema_tables: []::STRUCT( - table_name VARCHAR, - table_type VARCHAR, - table_columns STRUCT( - column_name VARCHAR, - ordinal_position INTEGER, - remarks VARCHAR, - xdbc_data_type SMALLINT, - xdbc_type_name VARCHAR, - xdbc_column_size INTEGER, - xdbc_decimal_digits SMALLINT, - xdbc_num_prec_radix SMALLINT, - xdbc_nullable SMALLINT, - xdbc_column_def VARCHAR, - xdbc_sql_data_type SMALLINT, - xdbc_datetime_sub SMALLINT, - xdbc_char_octet_length INTEGER, - xdbc_is_nullable VARCHAR, - xdbc_scope_catalog VARCHAR, - xdbc_scope_schema VARCHAR, - xdbc_scope_table VARCHAR, - xdbc_is_autoincrement BOOLEAN, - xdbc_is_generatedcolumn BOOLEAN - )[], - table_constraints STRUCT( - constraint_name VARCHAR, - constraint_type VARCHAR, - constraint_column_names VARCHAR[], - constraint_column_usage STRUCT(fk_catalog VARCHAR, fk_db_schema VARCHAR, fk_table VARCHAR, fk_column_name VARCHAR)[] - )[] - )[], - }) FILTER (dbs.schema_name is not null) catalog_db_schemas - FROM - information_schema.schemata - LEFT JOIN db_schemas dbs - USING (catalog_name, schema_name) - WHERE catalog_name LIKE '%s' - GROUP BY catalog_name - )", - db_schema_filter, catalog_filter); - break; - case ADBC_OBJECT_DEPTH_TABLES: - // Return metadata on catalogs, schemas, and tables. - query = duckdb::StringUtil::Format(R"( - WITH tables AS ( - SELECT - table_catalog catalog_name, - table_schema schema_name, - LIST({ - table_name: table_name, - table_type: table_type, - table_columns: []::STRUCT( - column_name VARCHAR, - ordinal_position INTEGER, - remarks VARCHAR, - xdbc_data_type SMALLINT, - xdbc_type_name VARCHAR, - xdbc_column_size INTEGER, - xdbc_decimal_digits SMALLINT, - xdbc_num_prec_radix SMALLINT, - xdbc_nullable SMALLINT, - xdbc_column_def VARCHAR, - xdbc_sql_data_type SMALLINT, - xdbc_datetime_sub SMALLINT, - xdbc_char_octet_length INTEGER, - xdbc_is_nullable VARCHAR, - xdbc_scope_catalog VARCHAR, - xdbc_scope_schema VARCHAR, - xdbc_scope_table VARCHAR, - xdbc_is_autoincrement BOOLEAN, - xdbc_is_generatedcolumn BOOLEAN - )[], - table_constraints: []::STRUCT( - constraint_name VARCHAR, - constraint_type VARCHAR, - constraint_column_names VARCHAR[], - constraint_column_usage STRUCT(fk_catalog VARCHAR, fk_db_schema VARCHAR, fk_table VARCHAR, fk_column_name VARCHAR)[] - )[], - }) db_schema_tables - FROM information_schema.tables - WHERE table_name LIKE '%s' - GROUP BY table_catalog, table_schema - ), - db_schemas AS ( - SELECT - catalog_name, - schema_name, - db_schema_tables, - FROM information_schema.schemata - LEFT JOIN tables - USING (catalog_name, schema_name) - WHERE schema_name LIKE '%s' - ) - - SELECT - catalog_name, - LIST({ - db_schema_name: schema_name, - db_schema_tables: db_schema_tables, - }) FILTER (dbs.schema_name is not null) catalog_db_schemas - FROM - information_schema.schemata - LEFT JOIN db_schemas dbs - USING (catalog_name, schema_name) - WHERE catalog_name LIKE '%s' - GROUP BY catalog_name - )", - table_name_filter, db_schema_filter, catalog_filter); - break; - case ADBC_OBJECT_DEPTH_COLUMNS: - // Return metadata on catalogs, schemas, tables, and columns. 
- query = duckdb::StringUtil::Format(R"( - WITH columns AS ( - SELECT - table_catalog, - table_schema, - table_name, - LIST({ - column_name: column_name, - ordinal_position: ordinal_position, - remarks : '', - xdbc_data_type: NULL::SMALLINT, - xdbc_type_name: NULL::VARCHAR, - xdbc_column_size: NULL::INTEGER, - xdbc_decimal_digits: NULL::SMALLINT, - xdbc_num_prec_radix: NULL::SMALLINT, - xdbc_nullable: NULL::SMALLINT, - xdbc_column_def: NULL::VARCHAR, - xdbc_sql_data_type: NULL::SMALLINT, - xdbc_datetime_sub: NULL::SMALLINT, - xdbc_char_octet_length: NULL::INTEGER, - xdbc_is_nullable: NULL::VARCHAR, - xdbc_scope_catalog: NULL::VARCHAR, - xdbc_scope_schema: NULL::VARCHAR, - xdbc_scope_table: NULL::VARCHAR, - xdbc_is_autoincrement: NULL::BOOLEAN, - xdbc_is_generatedcolumn: NULL::BOOLEAN, - }) table_columns - FROM information_schema.columns - WHERE column_name LIKE '%s' - GROUP BY table_catalog, table_schema, table_name - ), - constraints AS ( - SELECT - table_catalog, - table_schema, - table_name, - LIST( - { - constraint_name: constraint_name, - constraint_type: constraint_type, - constraint_column_names: []::VARCHAR[], - constraint_column_usage: []::STRUCT(fk_catalog VARCHAR, fk_db_schema VARCHAR, fk_table VARCHAR, fk_column_name VARCHAR)[], - } - ) table_constraints - FROM information_schema.table_constraints - GROUP BY table_catalog, table_schema, table_name - ), - tables AS ( - SELECT - table_catalog catalog_name, - table_schema schema_name, - LIST({ - table_name: table_name, - table_type: table_type, - table_columns: table_columns, - table_constraints: table_constraints, - }) db_schema_tables - FROM information_schema.tables - LEFT JOIN columns - USING (table_catalog, table_schema, table_name) - LEFT JOIN constraints - USING (table_catalog, table_schema, table_name) - WHERE table_name LIKE '%s' - GROUP BY table_catalog, table_schema - ), - db_schemas AS ( - SELECT - catalog_name, - schema_name, - db_schema_tables, - FROM information_schema.schemata - LEFT JOIN tables - USING (catalog_name, schema_name) - WHERE schema_name LIKE '%s' - ) - - SELECT - catalog_name, - LIST({ - db_schema_name: schema_name, - db_schema_tables: db_schema_tables, - }) FILTER (dbs.schema_name is not null) catalog_db_schemas - FROM - information_schema.schemata - LEFT JOIN db_schemas dbs - USING (catalog_name, schema_name) - WHERE catalog_name LIKE '%s' - GROUP BY catalog_name - )", - column_name_filter, table_name_filter, db_schema_filter, catalog_filter); - break; - default: - SetError(error, "Invalid value of Depth"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - return QueryInternal(connection, out, query.c_str(), error); -} - -AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *out, - struct AdbcError *error) { - const auto q = "SELECT DISTINCT table_type FROM information_schema.tables ORDER BY table_type"; - return QueryInternal(connection, out, q, error); -} - -} // namespace duckdb_adbc diff --git a/src/duckdb/src/common/adbc/driver_manager.cpp b/src/duckdb/src/common/adbc/driver_manager.cpp deleted file mode 100644 index 1d2bc1f23..000000000 --- a/src/duckdb/src/common/adbc/driver_manager.cpp +++ /dev/null @@ -1,1624 +0,0 @@ -//////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////// -// THIS FILE IS GENERATED BY apache/arrow, DO NOT EDIT MANUALLY // -//////////////////////////////////////////////////////////////////// 
-//////////////////////////////////////////////////////////////////// - -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/adbc/driver_manager.h" -#include "duckdb/common/adbc/adbc.h" -#include "duckdb/common/adbc/adbc.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#if defined(_WIN32) -#include // Must come first - -#include -#include -#else -#include -#endif // defined(_WIN32) - -// Platform-specific helpers - -#if defined(_WIN32) -/// Append a description of the Windows error to the buffer. -void GetWinError(std::string *buffer) { - DWORD rc = GetLastError(); - LPVOID message; - - FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - /*lpSource=*/nullptr, rc, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), - reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); - - (*buffer) += '('; - (*buffer) += std::to_string(rc); - (*buffer) += ") "; - (*buffer) += reinterpret_cast(message); - LocalFree(message); -} - -#endif // defined(_WIN32) - -// Error handling - -void ReleaseError(struct AdbcError *error) { - if (error) { - if (error->message) - delete[] error->message; - error->message = nullptr; - error->release = nullptr; - } -} - -void SetError(struct AdbcError *error, const std::string &message) { - if (!error) - return; - if (error->message) { - // Append - std::string buffer = error->message; - buffer.reserve(buffer.size() + message.size() + 1); - buffer += '\n'; - buffer += message; - error->release(error); - - error->message = new char[buffer.size() + 1]; - buffer.copy(error->message, buffer.size()); - error->message[buffer.size()] = '\0'; - } else { - error->message = new char[message.size() + 1]; - message.copy(error->message, message.size()); - error->message[message.size()] = '\0'; - } - error->release = ReleaseError; -} - -// Driver state - -/// A driver DLL. -struct ManagedLibrary { - ManagedLibrary() : handle(nullptr) { - } - ManagedLibrary(ManagedLibrary &&other) : handle(other.handle) { - other.handle = nullptr; - } - ManagedLibrary(const ManagedLibrary &) = delete; - ManagedLibrary &operator=(const ManagedLibrary &) = delete; - ManagedLibrary &operator=(ManagedLibrary &&other) noexcept { - this->handle = other.handle; - other.handle = nullptr; - return *this; - } - - ~ManagedLibrary() { - Release(); - } - - void Release() { - // TODO(apache/arrow-adbc#204): causes tests to segfault - // Need to refcount the driver DLL; also, errors may retain a reference to - // release() from the DLL - how to handle this? 
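-		// (Hence Release() is currently a deliberate no-op: the driver DLL is
-		// never dlclose()'d / FreeLibrary()'d and simply stays mapped for the
-		// lifetime of the process.)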
- } - - AdbcStatusCode Load(const char *library, struct AdbcError *error) { - std::string error_message; -#if defined(_WIN32) - HMODULE handle = LoadLibraryExA(library, NULL, 0); - if (!handle) { - error_message += library; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - - std::string full_driver_name = library; - full_driver_name += ".dll"; - handle = LoadLibraryExA(full_driver_name.c_str(), NULL, 0); - if (!handle) { - error_message += '\n'; - error_message += full_driver_name; - error_message += ": LoadLibraryExA() failed: "; - GetWinError(&error_message); - } - } - if (!handle) { - SetError(error, error_message); - return ADBC_STATUS_INTERNAL; - } else { - this->handle = handle; - } -#else - const std::string kPlatformLibraryPrefix = "lib"; -#if defined(__APPLE__) - const std::string kPlatformLibrarySuffix = ".dylib"; -#else - static const std::string kPlatformLibrarySuffix = ".so"; -#endif // defined(__APPLE__) - - void *handle = dlopen(library, RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message = "dlopen() failed: "; - error_message += dlerror(); - - // If applicable, append the shared library prefix/extension and - // try again (this way you don't have to hardcode driver names by - // platform in the application) - const std::string driver_str = library; - - std::string full_driver_name; - if (driver_str.size() < kPlatformLibraryPrefix.size() || - driver_str.compare(0, kPlatformLibraryPrefix.size(), kPlatformLibraryPrefix) != 0) { - full_driver_name += kPlatformLibraryPrefix; - } - full_driver_name += library; - if (driver_str.size() < kPlatformLibrarySuffix.size() || - driver_str.compare(full_driver_name.size() - kPlatformLibrarySuffix.size(), - kPlatformLibrarySuffix.size(), kPlatformLibrarySuffix) != 0) { - full_driver_name += kPlatformLibrarySuffix; - } - handle = dlopen(full_driver_name.c_str(), RTLD_NOW | RTLD_LOCAL); - if (!handle) { - error_message += "\ndlopen() failed: "; - error_message += dlerror(); - } - } - if (handle) { - this->handle = handle; - } else { - return ADBC_STATUS_INTERNAL; - } -#endif // defined(_WIN32) - return ADBC_STATUS_OK; - } - - AdbcStatusCode Lookup(const char *name, void **func, struct AdbcError *error) { -#if defined(_WIN32) - void *load_handle = reinterpret_cast(GetProcAddress(handle, name)); - if (!load_handle) { - std::string message = "GetProcAddress("; - message += name; - message += ") failed: "; - GetWinError(&message); - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } -#else - void *load_handle = dlsym(handle, name); - if (!load_handle) { - std::string message = "dlsym("; - message += name; - message += ") failed: "; - message += dlerror(); - SetError(error, message); - return ADBC_STATUS_INTERNAL; - } -#endif // defined(_WIN32) - *func = load_handle; - return ADBC_STATUS_OK; - } - -#if defined(_WIN32) - // The loaded DLL - HMODULE handle; -#else - void *handle; -#endif // defined(_WIN32) -}; - -/// Hold the driver DLL and the driver release callback in the driver struct. -struct ManagerDriverState { - // The original release callback - AdbcStatusCode (*driver_release)(struct AdbcDriver *driver, struct AdbcError *error); - - ManagedLibrary handle; -}; - -/// Unload the driver DLL. 
-static AdbcStatusCode ReleaseDriver(struct AdbcDriver *driver, struct AdbcError *error) { - AdbcStatusCode status = ADBC_STATUS_OK; - - if (!driver->private_manager) - return status; - ManagerDriverState *state = reinterpret_cast(driver->private_manager); - - if (state->driver_release) { - status = state->driver_release(driver, error); - } - state->handle.Release(); - - driver->private_manager = nullptr; - delete state; - return status; -} - -// ArrowArrayStream wrapper to support AdbcErrorFromArrayStream - -struct ErrorArrayStream { - struct ArrowArrayStream stream; - struct AdbcDriver *private_driver; -}; - -void ErrorArrayStreamRelease(struct ArrowArrayStream *stream) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return; - - auto *private_data = reinterpret_cast(stream->private_data); - private_data->stream.release(&private_data->stream); - delete private_data; - std::memset(stream, 0, sizeof(*stream)); -} - -const char *ErrorArrayStreamGetLastError(struct ArrowArrayStream *stream) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return nullptr; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_last_error(&private_data->stream); -} - -int ErrorArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *array) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return EINVAL; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_next(&private_data->stream, array); -} - -int ErrorArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *schema) { - if (stream->release != ErrorArrayStreamRelease || !stream->private_data) - return EINVAL; - auto *private_data = reinterpret_cast(stream->private_data); - return private_data->stream.get_schema(&private_data->stream, schema); -} - -// Default stubs - -int ErrorGetDetailCount(const struct AdbcError *error) { - return 0; -} - -struct AdbcErrorDetail ErrorGetDetail(const struct AdbcError *error, int index) { - return {nullptr, nullptr, 0}; -} - -const struct AdbcError *ErrorFromArrayStream(struct ArrowArrayStream *stream, AdbcStatusCode *status) { - return nullptr; -} - -void ErrorArrayStreamInit(struct ArrowArrayStream *out, struct AdbcDriver *private_driver) { - if (!out || !out->release || - // Don't bother wrapping if driver didn't claim support - private_driver->ErrorFromArrayStream == ErrorFromArrayStream) { - return; - } - struct ErrorArrayStream *private_data = new ErrorArrayStream; - private_data->stream = *out; - private_data->private_driver = private_driver; - out->get_last_error = ErrorArrayStreamGetLastError; - out->get_next = ErrorArrayStreamGetNext; - out->get_schema = ErrorArrayStreamGetSchema; - out->release = ErrorArrayStreamRelease; - out->private_data = private_data; -} - -AdbcStatusCode DatabaseGetOption(struct AdbcDatabase *database, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionBytes(struct AdbcDatabase *database, const char *key, uint8_t *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionInt(struct AdbcDatabase *database, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseGetOptionDouble(struct AdbcDatabase *database, const char *key, double *value, - struct AdbcError *error) { - 
return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode DatabaseSetOptionBytes(struct AdbcDatabase *database, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode DatabaseSetOptionInt(struct AdbcDatabase *database, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode DatabaseSetOptionDouble(struct AdbcDatabase *database, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCancel(struct AdbcConnection *connection, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetOption(struct AdbcConnection *connection, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionBytes(struct AdbcConnection *connection, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionInt(struct AdbcConnection *connection, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetOptionDouble(struct AdbcConnection *connection, const char *key, double *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode ConnectionGetStatistics(struct AdbcConnection *, const char *, const char *, const char *, char, - struct ArrowArrayStream *, struct AdbcError *) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetStatisticNames(struct AdbcConnection *, struct ArrowArrayStream *, struct AdbcError *) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionBytes(struct AdbcConnection *, const char *, const uint8_t *, size_t, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionInt(struct AdbcConnection *connection, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOptionDouble(struct AdbcConnection *connection, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementCancel(struct AdbcStatement *statement, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementGetOption(struct AdbcStatement *statement, const char *key, char *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionBytes(struct AdbcStatement *statement, const char *key, uint8_t *value, size_t *length, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionInt(struct AdbcStatement *statement, const char *key, int64_t *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementGetOptionDouble(struct AdbcStatement *statement, const char *key, double *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_FOUND; -} - -AdbcStatusCode StatementSetOptionBytes(struct AdbcStatement *, const char *, const uint8_t *, size_t, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode 
StatementSetOptionInt(struct AdbcStatement *statement, const char *key, int64_t value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetOptionDouble(struct AdbcStatement *statement, const char *key, double value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -/// Temporary state while the database is being configured. -struct TempDatabase { - std::unordered_map options; - std::unordered_map bytes_options; - std::unordered_map int_options; - std::unordered_map double_options; - std::string driver; - std::string entrypoint; - AdbcDriverInitFunc init_func = nullptr; -}; - -/// Temporary state while the database is being configured. -struct TempConnection { - std::unordered_map options; - std::unordered_map bytes_options; - std::unordered_map int_options; - std::unordered_map double_options; -}; - -static const char kDefaultEntrypoint[] = "AdbcDriverInit"; - -// Other helpers (intentionally not in an anonymous namespace so they can be tested) - -ADBC_EXPORT -std::string AdbcDriverManagerDefaultEntrypoint(const std::string &driver) { - /// - libadbc_driver_sqlite.so.2.0.0 -> AdbcDriverSqliteInit - /// - adbc_driver_sqlite.dll -> AdbcDriverSqliteInit - /// - proprietary_driver.dll -> AdbcProprietaryDriverInit - - // Potential path -> filename - // Treat both \ and / as directory separators on all platforms for simplicity - std::string filename; - { - size_t pos = driver.find_last_of("/\\"); - if (pos != std::string::npos) { - filename = driver.substr(pos + 1); - } else { - filename = driver; - } - } - - // Remove all extensions - { - size_t pos = filename.find('.'); - if (pos != std::string::npos) { - filename = filename.substr(0, pos); - } - } - - // Remove lib prefix - // https://stackoverflow.com/q/1878001/262727 - if (filename.rfind("lib", 0) == 0) { - filename = filename.substr(3); - } - - // Split on underscores, hyphens - // Capitalize and join - std::string entrypoint; - entrypoint.reserve(filename.size()); - size_t pos = 0; - while (pos < filename.size()) { - size_t prev = pos; - pos = filename.find_first_of("-_", pos); - // if pos == npos this is the entire filename - std::string token = filename.substr(prev, pos - prev); - // capitalize first letter - token[0] = duckdb::NumericCast(std::toupper(static_cast(token[0]))); - - entrypoint += token; - - if (pos != std::string::npos) { - pos++; - } - } - - if (entrypoint.rfind("Adbc", 0) != 0) { - entrypoint = "Adbc" + entrypoint; - } - entrypoint += "Init"; - - return entrypoint; -} - -// Direct implementations of API methods - -int AdbcErrorGetDetailCount(const struct AdbcError *error) { - if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && error->private_driver) { - return error->private_driver->ErrorGetDetailCount(error); - } - return 0; -} - -struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError *error, int index) { - if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && error->private_driver) { - return error->private_driver->ErrorGetDetail(error, index); - } - return {nullptr, nullptr, 0}; -} - -const struct AdbcError *AdbcErrorFromArrayStream(struct ArrowArrayStream *stream, AdbcStatusCode *status) { - if (!stream->private_data || stream->release != ErrorArrayStreamRelease) { - return nullptr; - } - auto *private_data = reinterpret_cast(stream->private_data); - auto *error = private_data->private_driver->ErrorFromArrayStream(&private_data->stream, status); - if (error) { - 
const_cast(error)->private_driver = private_data->private_driver; - } - return error; -} - -#define INIT_ERROR(ERROR, SOURCE) \ - if ((ERROR) != nullptr && (ERROR)->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { \ - (ERROR)->private_driver = (SOURCE)->private_driver; \ - } - -#define WRAP_STREAM(EXPR, OUT, SOURCE) \ - if (!(OUT)) { \ - /* Happens for ExecuteQuery where out is optional */ \ - return EXPR; \ - } \ - AdbcStatusCode status_code = EXPR; \ - ErrorArrayStreamInit(OUT, (SOURCE)->private_driver); \ - return status_code; - -AdbcStatusCode DatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionCommit(struct AdbcConnection *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetInfo(struct AdbcConnection *connection, const uint32_t *info_codes, - size_t info_codes_length, struct ArrowArrayStream *out, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetObjects(struct AdbcConnection *, int, const char *, const char *, const char *, - const char **, const char *, struct ArrowArrayStream *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetTableSchema(struct AdbcConnection *, const char *, const char *, const char *, - struct ArrowSchema *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionGetTableTypes(struct AdbcConnection *, struct ArrowArrayStream *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionRollback(struct AdbcConnection *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode ConnectionSetOption(struct AdbcConnection *, const char *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementBind(struct AdbcStatement *, struct ArrowArray *, struct ArrowSchema *, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementExecutePartitions(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementPrepare(struct AdbcStatement *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetOption(struct AdbcStatement *, const char *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetSqlQuery(struct AdbcStatement *, const char *, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode StatementSetSubstraitPlan(struct AdbcStatement *, const uint8_t *, size_t, struct AdbcError *error) { - return ADBC_STATUS_NOT_IMPLEMENTED; -} - -AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase *database, struct AdbcError *error) { - // Allocate a temporary structure to store options pre-Init - database->private_data = 
new TempDatabase(); - database->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase *database, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOption(database, key, value, length, error); - } - const auto *args = reinterpret_cast<const TempDatabase *>(database->private_data); - const std::string *result = nullptr; - if (std::strcmp(key, "driver") == 0) { - result = &args->driver; - } else if (std::strcmp(key, "entrypoint") == 0) { - result = &args->entrypoint; - } else { - const auto it = args->options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - result = &it->second; - } - - if (*length >= result->size() + 1) { - // Enough space - std::memcpy(value, result->c_str(), result->size() + 1); - } - *length = result->size() + 1; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase *database, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionBytes(database, key, value, length, error); - } - const auto *args = reinterpret_cast<const TempDatabase *>(database->private_data); - const auto it = args->bytes_options.find(key); - if (it == args->bytes_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - const std::string &result = it->second; - - if (*length >= result.size()) { - // Enough space - std::memcpy(value, result.c_str(), result.size()); - } - *length = result.size(); - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase *database, const char *key, int64_t *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionInt(database, key, value, error); - } - const auto *args = reinterpret_cast<const TempDatabase *>(database->private_data); - const auto it = args->int_options.find(key); - if (it == args->int_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase *database, const char *key, double *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseGetOptionDouble(database, key, value, error); - } - const auto *args = reinterpret_cast<const TempDatabase *>(database->private_data); - const auto it = args->double_options.find(key); - if (it == args->double_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase *database, const char *key, const char *value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOption(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast<TempDatabase *>(database->private_data); - if (std::strcmp(key, "driver") == 0) { - args->driver = value; - } else if (std::strcmp(key, "entrypoint") == 0) { - args->entrypoint = value; - } else { - args->options[key] = value; - } - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase *database, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - 
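// The getters above follow ADBC's two-call length protocol: call once to
// learn the required size, then call again with a large-enough buffer. A
// caller-side sketch (buffer name is illustrative):
//
//   size_t length = 0;
//   AdbcDatabaseGetOption(&database, "driver", nullptr, &length, &error);
//   std::vector<char> buf(length);  // length includes the trailing NUL
//   AdbcDatabaseGetOption(&database, "driver", buf.data(), &length, &error);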
return database->private_driver->DatabaseSetOptionBytes(database, key, value, length, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->bytes_options[key] = std::string(reinterpret_cast(value), length); - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase *database, const char *key, int64_t value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOptionInt(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->int_options[key] = value; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase *database, const char *key, double value, - struct AdbcError *error) { - if (database->private_driver) { - INIT_ERROR(error, database); - return database->private_driver->DatabaseSetOptionDouble(database, key, value, error); - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->double_options[key] = value; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDriverManagerDatabaseSetInitFunc(struct AdbcDatabase *database, AdbcDriverInitFunc init_func, - struct AdbcError *error) { - if (database->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - - TempDatabase *args = reinterpret_cast(database->private_data); - args->init_func = init_func; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_data) { - SetError(error, "Must call AdbcDatabaseNew first"); - return ADBC_STATUS_INVALID_STATE; - } - TempDatabase *args = reinterpret_cast(database->private_data); - if (args->init_func) { - // Do nothing - } else if (args->driver.empty()) { - SetError(error, "Must provide 'driver' parameter"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - database->private_driver = new AdbcDriver; - std::memset(database->private_driver, 0, sizeof(AdbcDriver)); - AdbcStatusCode status; - // So we don't confuse a driver into thinking it's initialized already - database->private_data = nullptr; - if (args->init_func) { - status = AdbcLoadDriverFromInitFunc(args->init_func, ADBC_VERSION_1_1_0, database->private_driver, error); - } else if (!args->entrypoint.empty()) { - status = AdbcLoadDriver(args->driver.c_str(), args->entrypoint.c_str(), ADBC_VERSION_1_1_0, - database->private_driver, error); - } else { - status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, database->private_driver, error); - } - if (status != ADBC_STATUS_OK) { - // Restore private_data so it will be released by AdbcDatabaseRelease - database->private_data = args; - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - status = database->private_driver->DatabaseNew(database, error); - if (status != ADBC_STATUS_OK) { - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - return status; - } - auto options = std::move(args->options); - auto bytes_options = std::move(args->bytes_options); - auto int_options = std::move(args->int_options); - auto double_options = std::move(args->double_options); - delete args; - - INIT_ERROR(error, database); - for (const auto &option : options) { - 
status = - database->private_driver->DatabaseSetOption(database, option.first.c_str(), option.second.c_str(), error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : bytes_options) { - status = database->private_driver->DatabaseSetOptionBytes( - database, option.first.c_str(), reinterpret_cast(option.second.data()), - option.second.size(), error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : int_options) { - status = database->private_driver->DatabaseSetOptionInt(database, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - break; - } - for (const auto &option : double_options) { - status = - database->private_driver->DatabaseSetOptionDouble(database, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - break; - } - - if (status != ADBC_STATUS_OK) { - // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_driver = nullptr; - // Should be redundant, but ensure that AdbcDatabaseRelease - // below doesn't think that it contains a TempDatabase - database->private_data = nullptr; - return status; - } - return database->private_driver->DatabaseInit(database, error); -} - -AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase *database, struct AdbcError *error) { - if (!database->private_driver) { - if (database->private_data) { - TempDatabase *args = reinterpret_cast(database->private_data); - delete args; - database->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, database); - auto status = database->private_driver->DatabaseRelease(database, error); - if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); - } - delete database->private_driver; - database->private_data = nullptr; - database->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionCancel(connection, error); -} - -AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionCommit(connection, error); -} - -AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection *connection, const uint32_t *info_codes, - size_t info_codes_length, struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetInfo(connection, info_codes, info_codes_length, out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection *connection, int depth, const char *catalog, - const char *db_schema, const char *table_name, const char **table_types, - const char *column_name, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - 
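// Taken together, AdbcDatabaseNew / AdbcDatabaseSetOption* / AdbcDatabaseInit
// above implement deferred configuration: options are staged in a
// TempDatabase until the driver is loaded, then replayed. A caller-side
// sketch (the driver path is illustrative):
//
//   struct AdbcDatabase database = {};
//   AdbcDatabaseNew(&database, &error);    // allocates a TempDatabase
//   AdbcDatabaseSetOption(&database, "driver", "libadbc_driver_sqlite.so", &error);
//   AdbcDatabaseInit(&database, &error);   // loads the driver, replays the
//                                          // staged options, then DatabaseInit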
WRAP_STREAM(connection->private_driver->ConnectionGetObjects(connection, depth, catalog, db_schema, table_name, - table_types, column_name, stream, error), - stream, connection); -} - -AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection *connection, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast<const TempConnection *>(connection->private_data); - const auto it = args->options.find(key); - if (it == args->options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - if (*length >= it->second.size() + 1) { - std::memcpy(value, it->second.c_str(), it->second.size() + 1); - } - *length = it->second.size() + 1; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOption(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection *connection, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOptionBytes: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast<const TempConnection *>(connection->private_data); - const auto it = args->bytes_options.find(key); - if (it == args->bytes_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - if (*length >= it->second.size() + 1) { - std::memcpy(value, it->second.data(), it->second.size() + 1); - } - *length = it->second.size() + 1; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionBytes(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection *connection, const char *key, int64_t *value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOptionInt: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast<const TempConnection *>(connection->private_data); - const auto it = args->int_options.find(key); - if (it == args->int_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionInt(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection *connection, const char *key, double *value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionGetOptionDouble: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, get the saved option - const auto *args = reinterpret_cast<const TempConnection *>(connection->private_data); - const auto it = args->double_options.find(key); - if (it == args->double_options.end()) { - return ADBC_STATUS_NOT_FOUND; - } - *value = it->second; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetOptionDouble(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection *connection, const char *catalog, - 
const char *db_schema, const char *table_name, char approximate, - struct ArrowArrayStream *out, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetStatistics(connection, catalog, db_schema, table_name, - approximate == 1, out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection *connection, struct ArrowArrayStream *out, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetStatisticNames(connection, out, error), out, connection); -} - -AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection *connection, const char *catalog, - const char *db_schema, const char *table_name, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionGetTableSchema(connection, catalog, db_schema, table_name, schema, - error); -} - -AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection *connection, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionGetTableTypes(connection, stream, error), stream, connection); -} - -AdbcStatusCode AdbcConnectionInit(struct AdbcConnection *connection, struct AdbcDatabase *database, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "Must call AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } else if (!database->private_driver) { - SetError(error, "Database is not initialized"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - TempConnection *args = reinterpret_cast(connection->private_data); - connection->private_data = nullptr; - std::unordered_map options = std::move(args->options); - std::unordered_map bytes_options = std::move(args->bytes_options); - std::unordered_map int_options = std::move(args->int_options); - std::unordered_map double_options = std::move(args->double_options); - delete args; - - auto status = database->private_driver->ConnectionNew(connection, error); - if (status != ADBC_STATUS_OK) - return status; - connection->private_driver = database->private_driver; - - for (const auto &option : options) { - status = database->private_driver->ConnectionSetOption(connection, option.first.c_str(), option.second.c_str(), - error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : bytes_options) { - status = database->private_driver->ConnectionSetOptionBytes( - connection, option.first.c_str(), reinterpret_cast(option.second.data()), - option.second.size(), error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : int_options) { - status = - database->private_driver->ConnectionSetOptionInt(connection, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - return status; - } - for (const auto &option : double_options) { - status = - database->private_driver->ConnectionSetOptionDouble(connection, option.first.c_str(), option.second, error); - if (status != ADBC_STATUS_OK) - return status; - } - INIT_ERROR(error, connection); - return 
connection->private_driver->ConnectionInit(connection, database, error); -} - -AdbcStatusCode AdbcConnectionNew(struct AdbcConnection *connection, struct AdbcError *error) { - // Allocate a temporary structure to store options pre-Init, because - // we don't get access to the database (and hence the driver - // function table) until then - connection->private_data = new TempConnection; - connection->private_driver = nullptr; - return ADBC_STATUS_OK; -} - -AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection *connection, const uint8_t *serialized_partition, - size_t serialized_length, struct ArrowArrayStream *out, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - WRAP_STREAM(connection->private_driver->ConnectionReadPartition(connection, serialized_partition, serialized_length, - out, error), - out, connection); -} - -AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - if (connection->private_data) { - TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); - delete args; - connection->private_data = nullptr; - return ADBC_STATUS_OK; - } - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - auto status = connection->private_driver->ConnectionRelease(connection, error); - connection->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection *connection, struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionRollback(connection, error); -} - -AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection *connection, const char *key, const char *value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOption: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); - args->options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOption(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection *connection, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionBytes: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); - args->bytes_options[key] = std::string(reinterpret_cast<const char *>(value), length); - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionBytes(connection, key, value, length, error); -} - -AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection *connection, const char *key, int64_t value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionInt: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast<TempConnection *>(connection->private_data); - 
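// Connections mirror the database pattern above: options set before
// AdbcConnectionInit are staged in the TempConnection maps and replayed
// during Init. A caller-side sketch (the option key is illustrative, not a
// defined ADBC option):
//
//   struct AdbcConnection connection = {};
//   AdbcConnectionNew(&connection, &error);   // allocates a TempConnection
//   AdbcConnectionSetOptionInt(&connection, "example.option", 42, &error);
//   AdbcConnectionInit(&connection, &database, &error);  // ConnectionNew + replay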
args->int_options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionInt(connection, key, value, error); -} - -AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection *connection, const char *key, double value, - struct AdbcError *error) { - if (!connection->private_data) { - SetError(error, "AdbcConnectionSetOptionDouble: must AdbcConnectionNew first"); - return ADBC_STATUS_INVALID_STATE; - } - if (!connection->private_driver) { - // Init not yet called, save the option - TempConnection *args = reinterpret_cast(connection->private_data); - args->double_options[key] = value; - return ADBC_STATUS_OK; - } - INIT_ERROR(error, connection); - return connection->private_driver->ConnectionSetOptionDouble(connection, key, value, error); -} - -AdbcStatusCode AdbcStatementBind(struct AdbcStatement *statement, struct ArrowArray *values, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementBind(statement, values, schema, error); -} - -AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement *statement, struct ArrowArrayStream *stream, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementBindStream(statement, stream, error); -} - -AdbcStatusCode AdbcStatementCancel(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementCancel(statement, error); -} - -// XXX: cpplint gets confused here if declared as 'struct ArrowSchema* schema' -AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement *statement, ArrowSchema *schema, - struct AdbcPartitions *partitions, int64_t *rows_affected, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementExecutePartitions(statement, schema, partitions, rows_affected, error); -} - -AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement *statement, struct ArrowArrayStream *out, - int64_t *rows_affected, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - WRAP_STREAM(statement->private_driver->StatementExecuteQuery(statement, out, rows_affected, error), out, statement); -} - -AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementExecuteSchema(statement, schema, error); -} - -AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement *statement, const char *key, char *value, size_t *length, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOption(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement *statement, const char *key, uint8_t *value, - size_t *length, struct AdbcError *error) { - if 
(!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionBytes(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement *statement, const char *key, int64_t *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionInt(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement *statement, const char *key, double *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetOptionDouble(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement *statement, struct ArrowSchema *schema, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementGetParameterSchema(statement, schema, error); -} - -AdbcStatusCode AdbcStatementNew(struct AdbcConnection *connection, struct AdbcStatement *statement, - struct AdbcError *error) { - if (!connection->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, connection); - auto status = connection->private_driver->StatementNew(connection, statement, error); - statement->private_driver = connection->private_driver; - return status; -} - -AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementPrepare(statement, error); -} - -AdbcStatusCode AdbcStatementRelease(struct AdbcStatement *statement, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - auto status = statement->private_driver->StatementRelease(statement, error); - statement->private_driver = nullptr; - return status; -} - -AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement *statement, const char *key, const char *value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOption(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement *statement, const char *key, const uint8_t *value, - size_t length, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOptionBytes(statement, key, value, length, error); -} - -AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement *statement, const char *key, int64_t value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetOptionInt(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement *statement, const char *key, double value, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, 
statement); - return statement->private_driver->StatementSetOptionDouble(statement, key, value, error); -} - -AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement *statement, const char *query, struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetSqlQuery(statement, query, error); -} - -AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement *statement, const uint8_t *plan, size_t length, - struct AdbcError *error) { - if (!statement->private_driver) { - return ADBC_STATUS_INVALID_STATE; - } - INIT_ERROR(error, statement); - return statement->private_driver->StatementSetSubstraitPlan(statement, plan, length, error); -} - -const char *AdbcStatusCodeMessage(AdbcStatusCode code) { -#define CASE(CONSTANT) \ - case ADBC_STATUS_##CONSTANT: \ - return #CONSTANT; - - switch (code) { - CASE(OK); - CASE(UNKNOWN); - CASE(NOT_IMPLEMENTED); - CASE(NOT_FOUND); - CASE(ALREADY_EXISTS); - CASE(INVALID_ARGUMENT); - CASE(INVALID_STATE); - CASE(INVALID_DATA); - CASE(INTEGRITY); - CASE(INTERNAL); - CASE(IO); - CASE(CANCELLED); - CASE(TIMEOUT); - CASE(UNAUTHENTICATED); - CASE(UNAUTHORIZED); - default: - return "(invalid code)"; - } -#undef CASE -} - -AdbcStatusCode AdbcLoadDriver(const char *driver_name, const char *entrypoint, int version, void *raw_driver, - struct AdbcError *error) { - AdbcDriverInitFunc init_func; - std::string error_message; - - switch (version) { - case ADBC_VERSION_1_0_0: - case ADBC_VERSION_1_1_0: - break; - default: - SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - if (!raw_driver) { - SetError(error, "Must provide non-NULL raw_driver"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - auto *driver = reinterpret_cast(raw_driver); - - ManagedLibrary library; - AdbcStatusCode status = library.Load(driver_name, error); - if (status != ADBC_STATUS_OK) { - // AdbcDatabaseInit tries to call this if set - driver->release = nullptr; - return status; - } - - void *load_handle = nullptr; - if (entrypoint) { - status = library.Lookup(entrypoint, &load_handle, error); - } else { - auto name = AdbcDriverManagerDefaultEntrypoint(driver_name); - status = library.Lookup(name.c_str(), &load_handle, error); - if (status != ADBC_STATUS_OK) { - status = library.Lookup(kDefaultEntrypoint, &load_handle, error); - } - } - - if (status != ADBC_STATUS_OK) { - library.Release(); - return status; - } - init_func = reinterpret_cast(load_handle); - - status = AdbcLoadDriverFromInitFunc(init_func, version, driver, error); - if (status == ADBC_STATUS_OK) { - ManagerDriverState *state = new ManagerDriverState; - state->driver_release = driver->release; - state->handle = std::move(library); - driver->release = &ReleaseDriver; - driver->private_manager = state; - } else { - library.Release(); - } - return status; -} - -AdbcStatusCode AdbcLoadDriverFromInitFunc(AdbcDriverInitFunc init_func, int version, void *raw_driver, - struct AdbcError *error) { - constexpr std::array kSupportedVersions = { - ADBC_VERSION_1_1_0, - ADBC_VERSION_1_0_0, - }; - - if (!raw_driver) { - SetError(error, "Must provide non-NULL raw_driver"); - return ADBC_STATUS_INVALID_ARGUMENT; - } - - switch (version) { - case ADBC_VERSION_1_0_0: - case ADBC_VERSION_1_1_0: - break; - default: - SetError(error, "Only ADBC 1.0.0 and 1.1.0 are supported"); - return ADBC_STATUS_NOT_IMPLEMENTED; - } - -#define FILL_DEFAULT(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ 
- DRIVER->STUB = &STUB; \ - } -#define CHECK_REQUIRED(DRIVER, STUB) \ - if (!DRIVER->STUB) { \ - SetError(error, "Driver does not implement required function Adbc" #STUB); \ - return ADBC_STATUS_INTERNAL; \ - } - - // Starting from the passed version, try each (older) version in - // succession with the underlying driver until we find one that's - // accepted. - AdbcStatusCode result = ADBC_STATUS_NOT_IMPLEMENTED; - for (const int try_version : kSupportedVersions) { - if (try_version > version) - continue; - result = init_func(try_version, raw_driver, error); - if (result != ADBC_STATUS_NOT_IMPLEMENTED) - break; - } - if (result != ADBC_STATUS_OK) { - return result; - } - - if (version >= ADBC_VERSION_1_0_0) { - auto *driver = reinterpret_cast(raw_driver); - CHECK_REQUIRED(driver, DatabaseNew); - CHECK_REQUIRED(driver, DatabaseInit); - CHECK_REQUIRED(driver, DatabaseRelease); - FILL_DEFAULT(driver, DatabaseSetOption); - - CHECK_REQUIRED(driver, ConnectionNew); - CHECK_REQUIRED(driver, ConnectionInit); - CHECK_REQUIRED(driver, ConnectionRelease); - FILL_DEFAULT(driver, ConnectionCommit); - FILL_DEFAULT(driver, ConnectionGetInfo); - FILL_DEFAULT(driver, ConnectionGetObjects); - FILL_DEFAULT(driver, ConnectionGetTableSchema); - FILL_DEFAULT(driver, ConnectionGetTableTypes); - FILL_DEFAULT(driver, ConnectionReadPartition); - FILL_DEFAULT(driver, ConnectionRollback); - FILL_DEFAULT(driver, ConnectionSetOption); - - FILL_DEFAULT(driver, StatementExecutePartitions); - CHECK_REQUIRED(driver, StatementExecuteQuery); - CHECK_REQUIRED(driver, StatementNew); - CHECK_REQUIRED(driver, StatementRelease); - FILL_DEFAULT(driver, StatementBind); - FILL_DEFAULT(driver, StatementGetParameterSchema); - FILL_DEFAULT(driver, StatementPrepare); - FILL_DEFAULT(driver, StatementSetOption); - FILL_DEFAULT(driver, StatementSetSqlQuery); - FILL_DEFAULT(driver, StatementSetSubstraitPlan); - } - if (version >= ADBC_VERSION_1_1_0) { - auto *driver = reinterpret_cast(raw_driver); - FILL_DEFAULT(driver, ErrorGetDetailCount); - FILL_DEFAULT(driver, ErrorGetDetail); - FILL_DEFAULT(driver, ErrorFromArrayStream); - - FILL_DEFAULT(driver, DatabaseGetOption); - FILL_DEFAULT(driver, DatabaseGetOptionBytes); - FILL_DEFAULT(driver, DatabaseGetOptionDouble); - FILL_DEFAULT(driver, DatabaseGetOptionInt); - FILL_DEFAULT(driver, DatabaseSetOptionBytes); - FILL_DEFAULT(driver, DatabaseSetOptionDouble); - FILL_DEFAULT(driver, DatabaseSetOptionInt); - - FILL_DEFAULT(driver, ConnectionCancel); - FILL_DEFAULT(driver, ConnectionGetOption); - FILL_DEFAULT(driver, ConnectionGetOptionBytes); - FILL_DEFAULT(driver, ConnectionGetOptionDouble); - FILL_DEFAULT(driver, ConnectionGetOptionInt); - FILL_DEFAULT(driver, ConnectionGetStatistics); - FILL_DEFAULT(driver, ConnectionGetStatisticNames); - FILL_DEFAULT(driver, ConnectionSetOptionBytes); - FILL_DEFAULT(driver, ConnectionSetOptionDouble); - FILL_DEFAULT(driver, ConnectionSetOptionInt); - - FILL_DEFAULT(driver, StatementCancel); - FILL_DEFAULT(driver, StatementExecuteSchema); - FILL_DEFAULT(driver, StatementGetOption); - FILL_DEFAULT(driver, StatementGetOptionBytes); - FILL_DEFAULT(driver, StatementGetOptionDouble); - FILL_DEFAULT(driver, StatementGetOptionInt); - FILL_DEFAULT(driver, StatementSetOptionBytes); - FILL_DEFAULT(driver, StatementSetOptionDouble); - FILL_DEFAULT(driver, StatementSetOptionInt); - } - - return ADBC_STATUS_OK; - -#undef FILL_DEFAULT -#undef CHECK_REQUIRED -} diff --git a/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp 
b/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp deleted file mode 100644 index cea53b188..000000000 --- a/src/duckdb/src/common/adbc/nanoarrow/allocator.cpp +++ /dev/null @@ -1,57 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include - -#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" - -namespace duckdb_nanoarrow { - -void *ArrowMalloc(int64_t size) { - return malloc(size_t(size)); -} - -void *ArrowRealloc(void *ptr, int64_t size) { - return realloc(ptr, size_t(size)); -} - -void ArrowFree(void *ptr) { - free(ptr); -} - -static uint8_t *ArrowBufferAllocatorMallocAllocate(struct ArrowBufferAllocator *allocator, int64_t size) { - return (uint8_t *)ArrowMalloc(size); -} - -static uint8_t *ArrowBufferAllocatorMallocReallocate(struct ArrowBufferAllocator *allocator, uint8_t *ptr, - int64_t old_size, int64_t new_size) { - return (uint8_t *)ArrowRealloc(ptr, new_size); -} - -static void ArrowBufferAllocatorMallocFree(struct ArrowBufferAllocator *allocator, uint8_t *ptr, int64_t size) { - ArrowFree(ptr); -} - -static struct ArrowBufferAllocator ArrowBufferAllocatorMalloc = { - &ArrowBufferAllocatorMallocAllocate, &ArrowBufferAllocatorMallocReallocate, &ArrowBufferAllocatorMallocFree, NULL}; - -struct ArrowBufferAllocator *ArrowBufferAllocatorDefault() { - return &ArrowBufferAllocatorMalloc; -} - -} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp b/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp deleted file mode 100644 index cb3009c01..000000000 --- a/src/duckdb/src/common/adbc/nanoarrow/metadata.cpp +++ /dev/null @@ -1,121 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
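// The reader below parses Arrow C data interface schema metadata, which is a
// packed, native-endian binary layout:
//
//   int32 n_pairs
//   then, n_pairs times:
//     int32 key_len,   key_len bytes of key     (not NUL-terminated)
//     int32 value_len, value_len bytes of value (not NUL-terminated)
//
// For example, a single pair "k" -> "val" occupies 4 + (4 + 1) + (4 + 3)
// bytes; the memcpy-based parsing assumes the host's native int32
// representation, per the Arrow C interface specification.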
- -#include -#include -#include - -#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" - -namespace duckdb_nanoarrow { - -ArrowErrorCode ArrowMetadataReaderInit(struct ArrowMetadataReader *reader, const char *metadata) { - reader->metadata = metadata; - - if (reader->metadata == NULL) { - reader->offset = 0; - reader->remaining_keys = 0; - } else { - memcpy(&reader->remaining_keys, reader->metadata, sizeof(int32_t)); - reader->offset = sizeof(int32_t); - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowMetadataReaderRead(struct ArrowMetadataReader *reader, struct ArrowStringView *key_out, - struct ArrowStringView *value_out) { - if (reader->remaining_keys <= 0) { - return EINVAL; - } - - int64_t pos = 0; - - int32_t key_size; - memcpy(&key_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); - pos += sizeof(int32_t); - - key_out->data = reader->metadata + reader->offset + pos; - key_out->n_bytes = key_size; - pos += key_size; - - int32_t value_size; - memcpy(&value_size, reader->metadata + reader->offset + pos, sizeof(int32_t)); - pos += sizeof(int32_t); - - value_out->data = reader->metadata + reader->offset + pos; - value_out->n_bytes = value_size; - pos += value_size; - - reader->offset += pos; - reader->remaining_keys--; - return NANOARROW_OK; -} - -int64_t ArrowMetadataSizeOf(const char *metadata) { - if (metadata == NULL) { - return 0; - } - - struct ArrowMetadataReader reader; - struct ArrowStringView key; - struct ArrowStringView value; - ArrowMetadataReaderInit(&reader, metadata); - - int64_t size = sizeof(int32_t); - while (ArrowMetadataReaderRead(&reader, &key, &value) == NANOARROW_OK) { - size += sizeof(int32_t) + uint64_t(key.n_bytes) + sizeof(int32_t) + uint64_t(value.n_bytes); - } - - return size; -} - -ArrowErrorCode ArrowMetadataGetValue(const char *metadata, const char *key, const char *default_value, - struct ArrowStringView *value_out) { - struct ArrowStringView target_key_view = {key, static_cast(strlen(key))}; - value_out->data = default_value; - if (default_value != NULL) { - value_out->n_bytes = int64_t(strlen(default_value)); - } else { - value_out->n_bytes = 0; - } - - struct ArrowMetadataReader reader; - struct ArrowStringView key_view; - struct ArrowStringView value; - ArrowMetadataReaderInit(&reader, metadata); - - while (ArrowMetadataReaderRead(&reader, &key_view, &value) == NANOARROW_OK) { - int key_equal = target_key_view.n_bytes == key_view.n_bytes && - strncmp(target_key_view.data, key_view.data, size_t(key_view.n_bytes)) == 0; - if (key_equal) { - value_out->data = value.data; - value_out->n_bytes = value.n_bytes; - break; - } - } - - return NANOARROW_OK; -} - -char ArrowMetadataHasKey(const char *metadata, const char *key) { - struct ArrowStringView value; - ArrowMetadataGetValue(metadata, key, NULL, &value); - return value.data != NULL; -} - -} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/schema.cpp b/src/duckdb/src/common/adbc/nanoarrow/schema.cpp deleted file mode 100644 index 38d1b314f..000000000 --- a/src/duckdb/src/common/adbc/nanoarrow/schema.cpp +++ /dev/null @@ -1,475 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include -#include -#include -#include - -#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" - -namespace duckdb_nanoarrow { - -void ArrowSchemaRelease(struct ArrowSchema *schema) { - if (schema->format != NULL) - ArrowFree((void *)schema->format); - if (schema->name != NULL) - ArrowFree((void *)schema->name); - if (schema->metadata != NULL) - ArrowFree((void *)schema->metadata); - - // This object owns the memory for all the children, but those - // children may have been generated elsewhere and might have - // their own release() callback. - if (schema->children != NULL) { - for (int64_t i = 0; i < schema->n_children; i++) { - if (schema->children[i] != NULL) { - if (schema->children[i]->release != NULL) { - schema->children[i]->release(schema->children[i]); - } - - ArrowFree(schema->children[i]); - } - } - - ArrowFree(schema->children); - } - - // This object owns the memory for the dictionary but it - // may have been generated somewhere else and have its own - // release() callback. - if (schema->dictionary != NULL) { - if (schema->dictionary->release != NULL) { - schema->dictionary->release(schema->dictionary); - } - - ArrowFree(schema->dictionary); - } - - // private data not currently used - if (schema->private_data != NULL) { - ArrowFree(schema->private_data); - } - - schema->release = NULL; -} - -const char *ArrowSchemaFormatTemplate(enum ArrowType data_type) { - switch (data_type) { - case NANOARROW_TYPE_UNINITIALIZED: - return NULL; - case NANOARROW_TYPE_NA: - return "n"; - case NANOARROW_TYPE_BOOL: - return "b"; - - case NANOARROW_TYPE_UINT8: - return "C"; - case NANOARROW_TYPE_INT8: - return "c"; - case NANOARROW_TYPE_UINT16: - return "S"; - case NANOARROW_TYPE_INT16: - return "s"; - case NANOARROW_TYPE_UINT32: - return "I"; - case NANOARROW_TYPE_INT32: - return "i"; - case NANOARROW_TYPE_UINT64: - return "L"; - case NANOARROW_TYPE_INT64: - return "l"; - - case NANOARROW_TYPE_HALF_FLOAT: - return "e"; - case NANOARROW_TYPE_FLOAT: - return "f"; - case NANOARROW_TYPE_DOUBLE: - return "g"; - - case NANOARROW_TYPE_STRING: - return "u"; - case NANOARROW_TYPE_LARGE_STRING: - return "U"; - case NANOARROW_TYPE_BINARY: - return "z"; - case NANOARROW_TYPE_LARGE_BINARY: - return "Z"; - - case NANOARROW_TYPE_DATE32: - return "tdD"; - case NANOARROW_TYPE_DATE64: - return "tdm"; - case NANOARROW_TYPE_INTERVAL_MONTHS: - return "tiM"; - case NANOARROW_TYPE_INTERVAL_DAY_TIME: - return "tiD"; - case NANOARROW_TYPE_INTERVAL_MONTH_DAY_NANO: - return "tin"; - - case NANOARROW_TYPE_LIST: - return "+l"; - case NANOARROW_TYPE_LARGE_LIST: - return "+L"; - case NANOARROW_TYPE_STRUCT: - return "+s"; - case NANOARROW_TYPE_MAP: - return "+m"; - - default: - return NULL; - } -} - -ArrowErrorCode ArrowSchemaInit(struct ArrowSchema *schema, enum ArrowType data_type) { - schema->format = NULL; - schema->name = NULL; - schema->metadata = NULL; - schema->flags = ARROW_FLAG_NULLABLE; - schema->n_children = 0; - schema->children = NULL; - schema->dictionary = NULL; - schema->private_data = NULL; - schema->release = &ArrowSchemaRelease; - - // We don't allocate the dictionary because it 
has to be nullptr - // for non-dictionary-encoded arrays. - - // Set the format to a valid format string for data_type - const char *template_format = ArrowSchemaFormatTemplate(data_type); - - // If data_type isn't recognized and not explicitly unset - if (template_format == NULL && data_type != NANOARROW_TYPE_UNINITIALIZED) { - schema->release(schema); - return EINVAL; - } - - int result = ArrowSchemaSetFormat(schema, template_format); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaInitFixedSize(struct ArrowSchema *schema, enum ArrowType data_type, int32_t fixed_size) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - if (fixed_size <= 0) { - schema->release(schema); - return EINVAL; - } - - char buffer[64]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_FIXED_SIZE_BINARY: - n_chars = snprintf(buffer, sizeof(buffer), "w:%d", (int)fixed_size); - break; - case NANOARROW_TYPE_FIXED_SIZE_LIST: - n_chars = snprintf(buffer, sizeof(buffer), "+w:%d", (int)fixed_size); - break; - default: - schema->release(schema); - return EINVAL; - } - - buffer[n_chars] = '\0'; - result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - } - - return result; -} - -ArrowErrorCode ArrowSchemaInitDecimal(struct ArrowSchema *schema, enum ArrowType data_type, int32_t decimal_precision, - int32_t decimal_scale) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - if (decimal_precision <= 0) { - schema->release(schema); - return EINVAL; - } - - char buffer[64]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_DECIMAL128: - n_chars = snprintf(buffer, sizeof(buffer), "d:%d,%d", decimal_precision, decimal_scale); - break; - case NANOARROW_TYPE_DECIMAL256: - n_chars = snprintf(buffer, sizeof(buffer), "d:%d,%d,256", decimal_precision, decimal_scale); - break; - default: - schema->release(schema); - return EINVAL; - } - - buffer[n_chars] = '\0'; - - result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -static const char *ArrowTimeUnitString(enum ArrowTimeUnit time_unit) { - switch (time_unit) { - case NANOARROW_TIME_UNIT_SECOND: - return "s"; - case NANOARROW_TIME_UNIT_MILLI: - return "m"; - case NANOARROW_TIME_UNIT_MICRO: - return "u"; - case NANOARROW_TIME_UNIT_NANO: - return "n"; - default: - return NULL; - } -} - -ArrowErrorCode ArrowSchemaInitDateTime(struct ArrowSchema *schema, enum ArrowType data_type, - enum ArrowTimeUnit time_unit, const char *timezone) { - int result = ArrowSchemaInit(schema, NANOARROW_TYPE_UNINITIALIZED); - if (result != NANOARROW_OK) { - return result; - } - - const char *time_unit_str = ArrowTimeUnitString(time_unit); - if (time_unit_str == NULL) { - schema->release(schema); - return EINVAL; - } - - char buffer[128]; - int n_chars; - switch (data_type) { - case NANOARROW_TYPE_TIME32: - case NANOARROW_TYPE_TIME64: - if (timezone != NULL) { - schema->release(schema); - return EINVAL; - } - n_chars = snprintf(buffer, sizeof(buffer), "tt%s", time_unit_str); - break; - case NANOARROW_TYPE_TIMESTAMP: - if (timezone == NULL) { - timezone = ""; - } - n_chars = snprintf(buffer, sizeof(buffer), "ts%s:%s", time_unit_str, timezone); - break; - case NANOARROW_TYPE_DURATION: - if (timezone != 
NULL) { - schema->release(schema); - return EINVAL; - } - n_chars = snprintf(buffer, sizeof(buffer), "tD%s", time_unit_str); - break; - default: - schema->release(schema); - return EINVAL; - } - - if (static_cast(n_chars) >= sizeof(buffer)) { - schema->release(schema); - return ERANGE; - } - - buffer[n_chars] = '\0'; - - result = ArrowSchemaSetFormat(schema, buffer); - if (result != NANOARROW_OK) { - schema->release(schema); - return result; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetFormat(struct ArrowSchema *schema, const char *format) { - if (schema->format != NULL) { - ArrowFree((void *)schema->format); - } - - if (format != NULL) { - size_t format_size = strlen(format) + 1; - schema->format = (const char *)ArrowMalloc(int64_t(format_size)); - if (schema->format == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->format, format, format_size); - } else { - schema->format = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetName(struct ArrowSchema *schema, const char *name) { - if (schema->name != NULL) { - ArrowFree((void *)schema->name); - } - - if (name != NULL) { - size_t name_size = strlen(name) + 1; - schema->name = (const char *)ArrowMalloc(int64_t(name_size)); - if (schema->name == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->name, name, name_size); - } else { - schema->name = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaSetMetadata(struct ArrowSchema *schema, const char *metadata) { - if (schema->metadata != NULL) { - ArrowFree((void *)schema->metadata); - } - - if (metadata != NULL) { - auto metadata_size = ArrowMetadataSizeOf(metadata); - schema->metadata = (const char *)ArrowMalloc(metadata_size); - if (schema->metadata == NULL) { - return ENOMEM; - } - - memcpy((void *)schema->metadata, metadata, size_t(metadata_size)); - } else { - schema->metadata = NULL; - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaAllocateChildren(struct ArrowSchema *schema, int64_t n_children) { - if (schema->children != NULL) { - return EEXIST; - } - - if (n_children > 0) { - schema->children = - (struct ArrowSchema **)ArrowMalloc(int64_t(uint64_t(n_children) * sizeof(struct ArrowSchema *))); - - if (schema->children == NULL) { - return ENOMEM; - } - - schema->n_children = n_children; - - memset(schema->children, 0, uint64_t(n_children) * sizeof(struct ArrowSchema *)); - - for (int64_t i = 0; i < n_children; i++) { - schema->children[i] = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); - - if (schema->children[i] == NULL) { - return ENOMEM; - } - - schema->children[i]->release = NULL; - } - } - - return NANOARROW_OK; -} - -ArrowErrorCode ArrowSchemaAllocateDictionary(struct ArrowSchema *schema) { - if (schema->dictionary != NULL) { - return EEXIST; - } - - schema->dictionary = (struct ArrowSchema *)ArrowMalloc(sizeof(struct ArrowSchema)); - if (schema->dictionary == NULL) { - return ENOMEM; - } - - schema->dictionary->release = NULL; - return NANOARROW_OK; -} - -int ArrowSchemaDeepCopy(struct ArrowSchema *schema, struct ArrowSchema *schema_out) { - int result; - result = ArrowSchemaInit(schema_out, NANOARROW_TYPE_NA); - if (result != NANOARROW_OK) { - return result; - } - - result = ArrowSchemaSetFormat(schema_out, schema->format); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaSetName(schema_out, schema->name); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = 
ArrowSchemaSetMetadata(schema_out, schema->metadata); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaAllocateChildren(schema_out, schema->n_children); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - for (int64_t i = 0; i < schema->n_children; i++) { - result = ArrowSchemaDeepCopy(schema->children[i], schema_out->children[i]); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - } - - if (schema->dictionary != NULL) { - result = ArrowSchemaAllocateDictionary(schema_out); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - - result = ArrowSchemaDeepCopy(schema->dictionary, schema_out->dictionary); - if (result != NANOARROW_OK) { - schema_out->release(schema_out); - return result; - } - } - - return NANOARROW_OK; -} - -} // namespace duckdb_nanoarrow diff --git a/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp b/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp deleted file mode 100644 index bddcd4e09..000000000 --- a/src/duckdb/src/common/adbc/nanoarrow/single_batch_array_stream.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include "duckdb/common/adbc/single_batch_array_stream.hpp" -#include "duckdb/common/arrow/nanoarrow/nanoarrow.h" -#include "duckdb/common/adbc/adbc.hpp" - -#include "duckdb.h" -#include "duckdb/common/arrow/arrow_wrapper.hpp" -#include "duckdb/common/arrow/nanoarrow/nanoarrow.hpp" - -#include -#include -#include -#include -#include - -namespace duckdb_adbc { - -using duckdb_nanoarrow::ArrowSchemaDeepCopy; - -static const char *SingleBatchArrayStreamGetLastError(struct ArrowArrayStream *stream) { - return NULL; -} - -static int SingleBatchArrayStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *batch) { - if (!stream || !stream->private_data) { - return EINVAL; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - - memcpy(batch, &impl->batch, sizeof(*batch)); - memset(&impl->batch, 0, sizeof(*batch)); - return 0; -} - -static int SingleBatchArrayStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *schema) { - if (!stream || !stream->private_data) { - return EINVAL; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - - return ArrowSchemaDeepCopy(&impl->schema, schema); -} - -static void SingleBatchArrayStreamRelease(struct ArrowArrayStream *stream) { - if (!stream || !stream->private_data) { - return; - } - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)stream->private_data; - impl->schema.release(&impl->schema); - if (impl->batch.release) { - impl->batch.release(&impl->batch); - } - free(impl); - - memset(stream, 0, sizeof(*stream)); -} - -AdbcStatusCode BatchToArrayStream(struct ArrowArray *values, struct ArrowSchema *schema, - struct ArrowArrayStream *stream, struct AdbcError *error) { - if (!values->release) { - SetError(error, "ArrowArray is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (!schema->release) { - SetError(error, "ArrowSchema is not initialized"); - return ADBC_STATUS_INTERNAL; - } else if (stream->release) { - SetError(error, "ArrowArrayStream is already initialized"); - return ADBC_STATUS_INTERNAL; - } - - struct SingleBatchArrayStream *impl = (struct SingleBatchArrayStream *)malloc(sizeof(*impl)); - memcpy(&impl->schema, schema, sizeof(*schema)); - memcpy(&impl->batch, values, 
sizeof(*values)); - memset(schema, 0, sizeof(*schema)); - memset(values, 0, sizeof(*values)); - stream->private_data = impl; - stream->get_last_error = SingleBatchArrayStreamGetLastError; - stream->get_next = SingleBatchArrayStreamGetNext; - stream->get_schema = SingleBatchArrayStreamGetSchema; - stream->release = SingleBatchArrayStreamRelease; - - return ADBC_STATUS_OK; -} - -} // namespace duckdb_adbc diff --git a/src/duckdb/src/common/allocator.cpp b/src/duckdb/src/common/allocator.cpp deleted file mode 100644 index 977087939..000000000 --- a/src/duckdb/src/common/allocator.cpp +++ /dev/null @@ -1,343 +0,0 @@ -#include "duckdb/common/allocator.hpp" - -#include "duckdb/common/assert.hpp" -#include "duckdb/common/atomic.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/timestamp.hpp" - -#include - -#ifdef DUCKDB_DEBUG_ALLOCATION -#include "duckdb/common/mutex.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/unordered_map.hpp" - -#include -#endif - -#ifndef USE_JEMALLOC -#if defined(DUCKDB_EXTENSION_JEMALLOC_LINKED) && DUCKDB_EXTENSION_JEMALLOC_LINKED && !defined(WIN32) && \ - INTPTR_MAX == INT64_MAX -#define USE_JEMALLOC -#endif -#endif - -#ifdef USE_JEMALLOC -#include "jemalloc_extension.hpp" -#endif - -#ifdef __GLIBC__ -#include -#endif - -namespace duckdb { - -AllocatedData::AllocatedData() : allocator(nullptr), pointer(nullptr), allocated_size(0) { -} - -AllocatedData::AllocatedData(Allocator &allocator, data_ptr_t pointer, idx_t allocated_size) - : allocator(&allocator), pointer(pointer), allocated_size(allocated_size) { - if (!pointer) { - throw InternalException("AllocatedData object constructed with nullptr"); - } -} -AllocatedData::~AllocatedData() { - Reset(); -} - -AllocatedData::AllocatedData(AllocatedData &&other) noexcept - : allocator(other.allocator), pointer(nullptr), allocated_size(0) { - std::swap(pointer, other.pointer); - std::swap(allocated_size, other.allocated_size); -} - -AllocatedData &AllocatedData::operator=(AllocatedData &&other) noexcept { - std::swap(allocator, other.allocator); - std::swap(pointer, other.pointer); - std::swap(allocated_size, other.allocated_size); - return *this; -} - -void AllocatedData::Reset() { - if (!pointer) { - return; - } - D_ASSERT(allocator); - allocator->FreeData(pointer, allocated_size); - allocated_size = 0; - pointer = nullptr; -} - -//===--------------------------------------------------------------------===// -// Debug Info -//===--------------------------------------------------------------------===// -struct AllocatorDebugInfo { -#ifdef DEBUG - AllocatorDebugInfo(); - ~AllocatorDebugInfo(); - - void AllocateData(data_ptr_t pointer, idx_t size); - void FreeData(data_ptr_t pointer, idx_t size); - void ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size); - -private: - //! The number of bytes that are outstanding (i.e. that have been allocated - but not freed) - //! Used for debug purposes - atomic allocation_count; -#ifdef DUCKDB_DEBUG_ALLOCATION - mutex pointer_lock; - //! 
Set of active outstanding pointers together with stack traces - unordered_map> pointers; -#endif -#endif -}; - -PrivateAllocatorData::PrivateAllocatorData() { -} - -PrivateAllocatorData::~PrivateAllocatorData() { -} - -//===--------------------------------------------------------------------===// -// Allocator -//===--------------------------------------------------------------------===// -Allocator::Allocator() - : Allocator(Allocator::DefaultAllocate, Allocator::DefaultFree, Allocator::DefaultReallocate, nullptr) { -} - -Allocator::Allocator(allocate_function_ptr_t allocate_function_p, free_function_ptr_t free_function_p, - reallocate_function_ptr_t reallocate_function_p, unique_ptr private_data_p) - : allocate_function(allocate_function_p), free_function(free_function_p), - reallocate_function(reallocate_function_p), private_data(std::move(private_data_p)) { - D_ASSERT(allocate_function); - D_ASSERT(free_function); - D_ASSERT(reallocate_function); -#ifdef DEBUG - if (!private_data) { - private_data = make_uniq(); - } - private_data->debug_info = make_uniq(); -#endif -} - -Allocator::~Allocator() { -} - -data_ptr_t Allocator::AllocateData(idx_t size) { - D_ASSERT(size > 0); - if (size >= MAXIMUM_ALLOC_SIZE) { - D_ASSERT(false); - throw InternalException("Requested allocation size of %llu is out of range - maximum allocation size is %llu", - size, MAXIMUM_ALLOC_SIZE); - } - auto result = allocate_function(private_data.get(), size); -#ifdef DEBUG - D_ASSERT(private_data); - if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { - private_data->debug_info->AllocateData(result, size); - } -#endif - if (!result) { - throw OutOfMemoryException("Failed to allocate block of %llu bytes (bad allocation)", size); - } - return result; -} - -void Allocator::FreeData(data_ptr_t pointer, idx_t size) { - if (!pointer) { - return; - } - D_ASSERT(size > 0); -#ifdef DEBUG - D_ASSERT(private_data); - if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { - private_data->debug_info->FreeData(pointer, size); - } -#endif - free_function(private_data.get(), pointer, size); -} - -data_ptr_t Allocator::ReallocateData(data_ptr_t pointer, idx_t old_size, idx_t size) { - if (!pointer) { - return nullptr; - } - if (size >= MAXIMUM_ALLOC_SIZE) { - D_ASSERT(false); - throw InternalException( - "Requested re-allocation size of %llu is out of range - maximum allocation size is %llu", size, - MAXIMUM_ALLOC_SIZE); - } - auto new_pointer = reallocate_function(private_data.get(), pointer, old_size, size); -#ifdef DEBUG - D_ASSERT(private_data); - if (private_data->free_type != AllocatorFreeType::DOES_NOT_REQUIRE_FREE) { - private_data->debug_info->ReallocateData(pointer, new_pointer, old_size, size); - } -#endif - if (!new_pointer) { - throw OutOfMemoryException("Failed to re-allocate block of %llu bytes (bad allocation)", size); - } - return new_pointer; -} - -data_ptr_t Allocator::DefaultAllocate(PrivateAllocatorData *private_data, idx_t size) { -#ifdef USE_JEMALLOC - return JemallocExtension::Allocate(private_data, size); -#else - auto default_allocate_result = malloc(size); - if (!default_allocate_result) { - throw std::bad_alloc(); - } - return data_ptr_cast(default_allocate_result); -#endif -} - -void Allocator::DefaultFree(PrivateAllocatorData *private_data, data_ptr_t pointer, idx_t size) { -#ifdef USE_JEMALLOC - JemallocExtension::Free(private_data, pointer, size); -#else - free(pointer); -#endif -} - -data_ptr_t Allocator::DefaultReallocate(PrivateAllocatorData *private_data, 
data_ptr_t pointer, idx_t old_size, - idx_t size) { -#ifdef USE_JEMALLOC - return JemallocExtension::Reallocate(private_data, pointer, old_size, size); -#else - return data_ptr_cast(realloc(pointer, size)); -#endif -} - -shared_ptr<Allocator> &Allocator::DefaultAllocatorReference() { - static shared_ptr<Allocator> DEFAULT_ALLOCATOR = make_shared_ptr<Allocator>(); - return DEFAULT_ALLOCATOR; -} - -Allocator &Allocator::DefaultAllocator() { - return *DefaultAllocatorReference(); -} - -optional_idx Allocator::DecayDelay() { -#ifdef USE_JEMALLOC - return NumericCast<idx_t>(JemallocExtension::DecayDelay()); -#else - return optional_idx(); -#endif -} - -bool Allocator::SupportsFlush() { -#if defined(USE_JEMALLOC) || defined(__GLIBC__) - return true; -#else - return false; -#endif -} - -static void MallocTrim(idx_t pad) { -#ifdef __GLIBC__ - static constexpr int64_t TRIM_INTERVAL_MS = 100; - static atomic<int64_t> LAST_TRIM_TIMESTAMP_MS {0}; - - int64_t last_trim_timestamp_ms = LAST_TRIM_TIMESTAMP_MS.load(); - auto current_ts = Timestamp::GetCurrentTimestamp(); - auto current_timestamp_ms = Cast::Operation<timestamp_t, timestamp_ms_t>(current_ts).value; - - if (current_timestamp_ms - last_trim_timestamp_ms < TRIM_INTERVAL_MS) { - return; // We trimmed less than TRIM_INTERVAL_MS ago - } - if (!LAST_TRIM_TIMESTAMP_MS.compare_exchange_strong(last_trim_timestamp_ms, current_timestamp_ms, - std::memory_order_acquire, std::memory_order_relaxed)) { - return; // Another thread has updated LAST_TRIM_TIMESTAMP_MS since we loaded it - } - - // We successfully updated LAST_TRIM_TIMESTAMP_MS, we can trim - malloc_trim(pad); -#endif -} - -void Allocator::ThreadFlush(bool allocator_background_threads, idx_t threshold, idx_t thread_count) { -#ifdef USE_JEMALLOC - if (!allocator_background_threads) { - JemallocExtension::ThreadFlush(threshold); - } -#endif - MallocTrim(thread_count * threshold); -} - -void Allocator::ThreadIdle() { -#ifdef USE_JEMALLOC - JemallocExtension::ThreadIdle(); -#endif -} - -void Allocator::FlushAll() { -#ifdef USE_JEMALLOC - JemallocExtension::FlushAll(); -#endif - MallocTrim(0); -} - -void Allocator::SetBackgroundThreads(bool enable) { -#ifdef USE_JEMALLOC - JemallocExtension::SetBackgroundThreads(enable); -#endif -} - -//===--------------------------------------------------------------------===// -// Debug Info (extended) -//===--------------------------------------------------------------------===// -#ifdef DEBUG -AllocatorDebugInfo::AllocatorDebugInfo() { - allocation_count = 0; -} -AllocatorDebugInfo::~AllocatorDebugInfo() { -#ifdef DUCKDB_DEBUG_ALLOCATION - if (allocation_count != 0) { - printf("Outstanding allocations found for Allocator\n"); - for (auto &entry : pointers) { - printf("Allocation of size %llu at address %p\n", entry.second.first, (void *)entry.first); - printf("Stack trace:\n%s\n", entry.second.second.c_str()); - printf("\n"); - } - } -#endif - //! Verify that there is no outstanding memory still associated with the batched allocator - //! Only works for access to the batched allocator through the batched allocator interface - //!
If this assertion triggers, enable DUCKDB_DEBUG_ALLOCATION for more information about the allocations - D_ASSERT(allocation_count == 0); -} - -void AllocatorDebugInfo::AllocateData(data_ptr_t pointer, idx_t size) { - allocation_count += size; -#ifdef DUCKDB_DEBUG_ALLOCATION - lock_guard l(pointer_lock); - pointers[pointer] = make_pair(size, Exception::GetStackTrace()); -#endif -} - -void AllocatorDebugInfo::FreeData(data_ptr_t pointer, idx_t size) { - D_ASSERT(allocation_count >= size); - allocation_count -= size; -#ifdef DUCKDB_DEBUG_ALLOCATION - lock_guard l(pointer_lock); - // verify that the pointer exists - D_ASSERT(pointers.find(pointer) != pointers.end()); - // verify that the stored size matches the passed in size - D_ASSERT(pointers[pointer].first == size); - // erase the pointer - pointers.erase(pointer); -#endif -} - -void AllocatorDebugInfo::ReallocateData(data_ptr_t pointer, data_ptr_t new_pointer, idx_t old_size, idx_t new_size) { - FreeData(pointer, old_size); - AllocateData(new_pointer, new_size); -} - -#endif - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/bool_data.cpp b/src/duckdb/src/common/arrow/appender/bool_data.cpp deleted file mode 100644 index 78befb603..000000000 --- a/src/duckdb/src/common/arrow/appender/bool_data.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/appender/bool_data.hpp" - -namespace duckdb { - -void ArrowBoolData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto byte_count = (capacity + 7) / 8; - result.GetMainBuffer().reserve(byte_count); - (void)AppendValidity; // silence a compiler warning about unused static function -} - -void ArrowBoolData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - idx_t size = to - from; - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - auto &main_buffer = append_data.GetMainBuffer(); - auto &validity_buffer = append_data.GetValidityBuffer(); - // we initialize both the validity and the bit set to 1's - ResizeValidity(validity_buffer, append_data.row_count + size); - ResizeValidity(main_buffer, append_data.row_count + size); - auto data = UnifiedVectorFormat::GetData(format); - - auto result_data = main_buffer.GetData(); - auto validity_data = validity_buffer.GetData(); - uint8_t current_bit; - idx_t current_byte; - GetBitPosition(append_data.row_count, current_byte, current_bit); - for (idx_t i = from; i < to; i++) { - auto source_idx = format.sel->get_index(i); - // append the validity mask - if (!format.validity.RowIsValid(source_idx)) { - SetNull(append_data, validity_data, current_byte, current_bit); - } else if (!data[source_idx]) { - UnsetBit(result_data, current_byte, current_bit); - } - NextBit(current_byte, current_bit); - } - append_data.row_count += size; -} - -void ArrowBoolData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 2; - result->buffers[1] = append_data.GetMainBuffer().data(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp b/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp deleted file mode 100644 index 172144fd3..000000000 --- a/src/duckdb/src/common/arrow/appender/fixed_size_list_data.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/appender/fixed_size_list_data.hpp" - -namespace duckdb { - 
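// ---------------------------------------------------------------------------
// [editor's sketch] Not part of this diff. The boolean appender in
// bool_data.cpp above relies on Arrow's LSB-ordered bit packing: a buffer of
// (capacity + 7) / 8 bytes, one bit per row, initialized to all 1s so that
// only false values and nulls need a bit cleared. A minimal standalone
// illustration of that arithmetic (helper names are hypothetical; assumes
// <cstdint> and <vector>):
static void GetBitPositionSketch(uint64_t row, uint64_t &byte_idx, uint8_t &bit_idx) {
	byte_idx = row / 8;                      // the byte holding this row's bit
	bit_idx = static_cast<uint8_t>(row % 8); // the bit within that byte (LSB first)
}

static void UnsetBitSketch(std::vector<uint8_t> &buffer, uint64_t row) {
	uint64_t byte_idx;
	uint8_t bit_idx;
	GetBitPositionSketch(row, byte_idx, bit_idx);
	buffer[byte_idx] &= static_cast<uint8_t>(~(1u << bit_idx)); // clear the bit
}
// This is why Append above computes (current_byte, current_bit) once and then
// walks them forward with NextBit instead of redoing the division per row.
// ---------------------------------------------------------------------------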
-//===--------------------------------------------------------------------===// -// Arrays -//===--------------------------------------------------------------------===// -void ArrowFixedSizeListData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto &child_type = ArrayType::GetChildType(type); - auto array_size = ArrayType::GetSize(type); - auto child_buffer = ArrowAppender::InitializeChild(child_type, capacity * array_size, result.options); - result.child_data.push_back(std::move(child_buffer)); -} - -void ArrowFixedSizeListData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, - idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - AppendValidity(append_data, format, from, to); - input.Flatten(input_size); - auto array_size = ArrayType::GetSize(input.GetType()); - auto &child_vector = ArrayVector::GetEntry(input); - auto &child_data = *append_data.child_data[0]; - child_data.append_vector(child_data, child_vector, from * array_size, to * array_size, size * array_size); - append_data.row_count += size; -} - -void ArrowFixedSizeListData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 1; - auto &child_type = ArrayType::GetChildType(type); - ArrowAppender::AddChildren(append_data, 1); - result->children = append_data.child_pointers.data(); - result->n_children = 1; - append_data.child_arrays[0] = *ArrowAppender::FinalizeChild(child_type, std::move(append_data.child_data[0])); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/struct_data.cpp b/src/duckdb/src/common/arrow/appender/struct_data.cpp deleted file mode 100644 index b2afa62d1..000000000 --- a/src/duckdb/src/common/arrow/appender/struct_data.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/appender/struct_data.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Structs -//===--------------------------------------------------------------------===// -void ArrowStructData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - auto &children = StructType::GetChildTypes(type); - for (auto &child : children) { - auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); - result.child_data.push_back(std::move(child_buffer)); - } -} - -void ArrowStructData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - AppendValidity(append_data, format, from, to); - // append the children of the struct - auto &children = StructVector::GetEntries(input); - for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { - auto &child = children[child_idx]; - auto &child_data = *append_data.child_data[child_idx]; - child_data.append_vector(child_data, *child, from, to, size); - } - append_data.row_count += size; -} - -void ArrowStructData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 1; - - auto &child_types = StructType::GetChildTypes(type); - ArrowAppender::AddChildren(append_data, child_types.size()); - result->children = append_data.child_pointers.data(); - result->n_children = NumericCast(child_types.size()); - for (idx_t i = 0; i < 
child_types.size(); i++) { - auto &child_type = child_types[i].second; - append_data.child_arrays[i] = *ArrowAppender::FinalizeChild(child_type, std::move(append_data.child_data[i])); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/appender/union_data.cpp b/src/duckdb/src/common/arrow/appender/union_data.cpp deleted file mode 100644 index 1e9f4f432..000000000 --- a/src/duckdb/src/common/arrow/appender/union_data.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/appender/union_data.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Unions -//===--------------------------------------------------------------------===// -void ArrowUnionData::Initialize(ArrowAppendData &result, const LogicalType &type, idx_t capacity) { - result.GetMainBuffer().reserve(capacity * sizeof(int8_t)); - - for (auto &child : UnionType::CopyMemberTypes(type)) { - auto child_buffer = ArrowAppender::InitializeChild(child.second, capacity, result.options); - result.child_data.push_back(std::move(child_buffer)); - } - (void)AppendValidity; // silence a compiler warning about unused static function -} - -void ArrowUnionData::Append(ArrowAppendData &append_data, Vector &input, idx_t from, idx_t to, idx_t input_size) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(input_size, format); - idx_t size = to - from; - - auto &types_buffer = append_data.GetMainBuffer(); - - duckdb::vector<Vector> child_vectors; - for (const auto &child : UnionType::CopyMemberTypes(input.GetType())) { - child_vectors.emplace_back(child.second, size); - } - - for (idx_t input_idx = from; input_idx < to; input_idx++) { - const auto &val = input.GetValue(input_idx); - - idx_t tag = 0; - Value resolved_value(nullptr); - if (!val.IsNull()) { - tag = UnionValue::GetTag(val); - - resolved_value = UnionValue::GetValue(val); - } - - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - child_vectors[child_idx].SetValue(input_idx, child_idx == tag ?
resolved_value : Value(nullptr)); - } - types_buffer.push_back(NumericCast(tag)); - } - - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - auto &child_buffer = append_data.child_data[child_idx]; - auto &child = child_vectors[child_idx]; - child_buffer->append_vector(*child_buffer, child, from, to, size); - } - append_data.row_count += size; -} - -void ArrowUnionData::Finalize(ArrowAppendData &append_data, const LogicalType &type, ArrowArray *result) { - result->n_buffers = 1; - result->buffers[0] = append_data.GetMainBuffer().data(); - - auto &child_types = UnionType::CopyMemberTypes(type); - ArrowAppender::AddChildren(append_data, child_types.size()); - result->children = append_data.child_pointers.data(); - result->n_children = NumericCast(child_types.size()); - for (idx_t i = 0; i < child_types.size(); i++) { - auto &child_type = child_types[i].second; - append_data.child_arrays[i] = *ArrowAppender::FinalizeChild(child_type, std::move(append_data.child_data[i])); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_appender.cpp b/src/duckdb/src/common/arrow/arrow_appender.cpp deleted file mode 100644 index 632bffc66..000000000 --- a/src/duckdb/src/common/arrow/arrow_appender.cpp +++ /dev/null @@ -1,306 +0,0 @@ -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/arrow_buffer.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/common/array.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/uuid.hpp" -#include "duckdb/function/table/arrow.hpp" -#include "duckdb/common/arrow/appender/append_data.hpp" -#include "duckdb/common/arrow/appender/list.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// ArrowAppender -//===--------------------------------------------------------------------===// - -ArrowAppender::ArrowAppender(vector types_p, const idx_t initial_capacity, ClientProperties options) - : types(std::move(types_p)) { - for (auto &type : types) { - auto entry = InitializeChild(type, initial_capacity, options); - root_data.push_back(std::move(entry)); - } -} - -ArrowAppender::~ArrowAppender() { -} - -//! 
Append a data chunk to the underlying arrow array -void ArrowAppender::Append(DataChunk &input, idx_t from, idx_t to, idx_t input_size) { - D_ASSERT(types == input.GetTypes()); - D_ASSERT(to >= from); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - root_data[i]->append_vector(*root_data[i], input.data[i], from, to, input_size); - } - row_count += to - from; -} - -idx_t ArrowAppender::RowCount() const { - return row_count; -} - -void ArrowAppender::ReleaseArray(ArrowArray *array) { - if (!array || !array->release) { - return; - } - auto holder = static_cast(array->private_data); - for (int64_t i = 0; i < array->n_children; i++) { - auto child = array->children[i]; - if (!child->release) { - // Child was moved out of the array - continue; - } - child->release(child); - D_ASSERT(!child->release); - } - if (array->dictionary && array->dictionary->release) { - array->dictionary->release(array->dictionary); - } - array->release = nullptr; - delete holder; -} - -//===--------------------------------------------------------------------===// -// Finalize Arrow Child -//===--------------------------------------------------------------------===// -ArrowArray *ArrowAppender::FinalizeChild(const LogicalType &type, unique_ptr append_data_p) { - auto result = make_uniq(); - - auto &append_data = *append_data_p; - result->private_data = append_data_p.release(); - result->release = ReleaseArray; - result->n_children = 0; - result->null_count = 0; - result->offset = 0; - result->dictionary = nullptr; - result->buffers = append_data.buffers.data(); - result->null_count = NumericCast(append_data.null_count); - result->length = NumericCast(append_data.row_count); - result->buffers[0] = append_data.GetValidityBuffer().data(); - - if (append_data.finalize) { - append_data.finalize(append_data, type, result.get()); - } - - append_data.array = std::move(result); - return append_data.array.get(); -} - -//! 
Returns the underlying arrow array -ArrowArray ArrowAppender::Finalize() { - D_ASSERT(root_data.size() == types.size()); - auto root_holder = make_uniq(options); - - ArrowArray result; - AddChildren(*root_holder, types.size()); - result.children = root_holder->child_pointers.data(); - result.n_children = NumericCast(types.size()); - - // Configure root array - result.length = NumericCast(row_count); - result.n_buffers = 1; - result.buffers = root_holder->buffers.data(); // there is no actual buffer there since we don't have NULLs - result.offset = 0; - result.null_count = 0; // needs to be 0 - result.dictionary = nullptr; - root_holder->child_data = std::move(root_data); - - for (idx_t i = 0; i < root_holder->child_data.size(); i++) { - root_holder->child_arrays[i] = *ArrowAppender::FinalizeChild(types[i], std::move(root_holder->child_data[i])); - } - - // Release ownership to caller - result.private_data = root_holder.release(); - result.release = ArrowAppender::ReleaseArray; - return result; -} - -//===--------------------------------------------------------------------===// -// Initialize Arrow Child -//===--------------------------------------------------------------------===// - -template -static void InitializeAppenderForType(ArrowAppendData &append_data) { - append_data.initialize = OP::Initialize; - append_data.append_vector = OP::Append; - append_data.finalize = OP::Finalize; -} - -static void InitializeFunctionPointers(ArrowAppendData &append_data, const LogicalType &type) { - // handle special logical types - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::TINYINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::SMALLINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::DATE: - case LogicalTypeId::INTEGER: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::TIME_TZ: { - if (append_data.options.arrow_lossless_conversion) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - break; - } - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::BIGINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UUID: - if (append_data.options.arrow_lossless_conversion) { - InitializeAppenderForType>(append_data); - } else { - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - } - break; - case LogicalTypeId::HUGEINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UHUGEINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UTINYINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::USMALLINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UINTEGER: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UBIGINT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::FLOAT: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::DOUBLE: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::DECIMAL: - switch (type.InternalType()) { - case PhysicalType::INT16: - InitializeAppenderForType>(append_data); - break; - 
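// [editor's note] DECIMAL dispatch: DuckDB stores a decimal in the smallest
// integer physical type that fits its declared width (INT16/INT32/INT64/INT128),
// so each physical type needs its own appender instantiation even though the
// Arrow format string is always "d:<width>,<scale>" (see the DECIMAL case of
// SetArrowFormat in arrow_converter.cpp below).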
case PhysicalType::INT32: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::INT64: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::INT128: - InitializeAppenderForType>(append_data); - break; - default: - throw InternalException("Unsupported internal decimal type"); - } - break; - case LogicalTypeId::VARCHAR: - if (append_data.options.produce_arrow_string_view) { - InitializeAppenderForType(append_data); - } else { - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - } - break; - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - break; - case LogicalTypeId::ENUM: - switch (type.InternalType()) { - case PhysicalType::UINT8: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::UINT16: - InitializeAppenderForType>(append_data); - break; - case PhysicalType::UINT32: - InitializeAppenderForType>(append_data); - break; - default: - throw InternalException("Unsupported internal enum type"); - } - break; - case LogicalTypeId::INTERVAL: - InitializeAppenderForType>(append_data); - break; - case LogicalTypeId::UNION: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::STRUCT: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::ARRAY: - InitializeAppenderForType(append_data); - break; - case LogicalTypeId::LIST: { - if (append_data.options.arrow_use_list_view) { - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - } else { - if (append_data.options.arrow_offset_size == ArrowOffsetSize::LARGE) { - InitializeAppenderForType>(append_data); - } else { - InitializeAppenderForType>(append_data); - } - } - break; - } - case LogicalTypeId::MAP: - // Arrow MapArray only supports 32-bit offsets. There is no LargeMapArray type in Arrow. 
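// [editor's note] An Arrow map is physically a list of {key, value} structs;
// SetArrowMapFormat in arrow_converter.cpp below emits "+m" with a single
// struct child named "entries" to match, which is why the list-offset and
// list-view options selected for LIST just above do not apply here.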
- InitializeAppenderForType>(append_data); - break; - default: - throw NotImplementedException("Unsupported type in DuckDB -> Arrow Conversion: %s\n", type.ToString()); - } -} - -unique_ptr ArrowAppender::InitializeChild(const LogicalType &type, const idx_t capacity, - ClientProperties &options) { - auto result = make_uniq(options); - InitializeFunctionPointers(*result, type); - - const auto byte_count = (capacity + 7) / 8; - result->GetValidityBuffer().reserve(byte_count); - result->initialize(*result, type, capacity); - return result; -} - -void ArrowAppender::AddChildren(ArrowAppendData &data, const idx_t count) { - data.child_pointers.resize(count); - data.child_arrays.resize(count); - for (idx_t i = 0; i < count; i++) { - data.child_pointers[i] = &data.child_arrays[i]; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_converter.cpp b/src/duckdb/src/common/arrow/arrow_converter.cpp deleted file mode 100644 index 582daf34b..000000000 --- a/src/duckdb/src/common/arrow/arrow_converter.cpp +++ /dev/null @@ -1,432 +0,0 @@ -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/arrow/arrow.hpp" -#include "duckdb/common/arrow/arrow_converter.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/sel_cache.hpp" -#include "duckdb/common/types/vector_cache.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/vector.hpp" -#include -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/arrow/schema_metadata.hpp" - -namespace duckdb { - -void ArrowConverter::ToArrowArray(DataChunk &input, ArrowArray *out_array, ClientProperties options) { - ArrowAppender appender(input.GetTypes(), input.size(), std::move(options)); - appender.Append(input, 0, input.size(), input.size()); - *out_array = appender.Finalize(); -} - -unsafe_unique_array AddName(const string &name) { - auto name_ptr = make_unsafe_uniq_array(name.size() + 1); - for (size_t i = 0; i < name.size(); i++) { - name_ptr[i] = name[i]; - } - name_ptr[name.size()] = '\0'; - return name_ptr; -} - -//===--------------------------------------------------------------------===// -// Arrow Schema -//===--------------------------------------------------------------------===// -struct DuckDBArrowSchemaHolder { - // unused in children - vector children; - // unused in children - vector children_ptrs; - //! used for nested structures - std::list> nested_children; - std::list> nested_children_ptr; - //! This holds strings created to represent decimal types - vector> owned_type_names; - vector> owned_column_names; - //! This holds any values created for metadata info - vector> metadata_info; -}; - -static void ReleaseDuckDBArrowSchema(ArrowSchema *schema) { - if (!schema || !schema->release) { - return; - } - schema->release = nullptr; - auto holder = static_cast(schema->private_data); - delete holder; -} - -void InitializeChild(ArrowSchema &child, DuckDBArrowSchemaHolder &root_holder, const string &name = "") { - //! 
Child is cleaned up by parent - child.private_data = nullptr; - child.release = ReleaseDuckDBArrowSchema; - - // Store the child schema - child.flags = ARROW_FLAG_NULLABLE; - root_holder.owned_type_names.push_back(AddName(name)); - - child.name = root_holder.owned_type_names.back().get(); - child.n_children = 0; - child.children = nullptr; - child.metadata = nullptr; - child.dictionary = nullptr; -} - -void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options); - -void SetArrowMapFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options) { - child.format = "+m"; - //! Map has one child which is a struct - child.n_children = 1; - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.children = &root_holder.nested_children_ptr.back()[0]; - child.children[0]->name = "entries"; - SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); -} - -void SetArrowFormat(DuckDBArrowSchemaHolder &root_holder, ArrowSchema &child, const LogicalType &type, - const ClientProperties &options) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - child.format = "b"; - break; - case LogicalTypeId::TINYINT: - child.format = "c"; - break; - case LogicalTypeId::SMALLINT: - child.format = "s"; - break; - case LogicalTypeId::INTEGER: - child.format = "i"; - break; - case LogicalTypeId::BIGINT: - child.format = "l"; - break; - case LogicalTypeId::UTINYINT: - child.format = "C"; - break; - case LogicalTypeId::USMALLINT: - child.format = "S"; - break; - case LogicalTypeId::UINTEGER: - child.format = "I"; - break; - case LogicalTypeId::UBIGINT: - child.format = "L"; - break; - case LogicalTypeId::FLOAT: - child.format = "f"; - break; - case LogicalTypeId::HUGEINT: { - if (options.arrow_lossless_conversion) { - child.format = "w:16"; - auto schema_metadata = ArrowSchemaMetadata::DuckDBInternalType("hugeint"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - } else { - child.format = "d:38,0"; - } - break; - } - case LogicalTypeId::UHUGEINT: { - child.format = "w:16"; - auto schema_metadata = ArrowSchemaMetadata::DuckDBInternalType("uhugeint"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - break; - } - case LogicalTypeId::VARINT: { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "Z"; - } else { - child.format = "z"; - } - auto schema_metadata = ArrowSchemaMetadata::DuckDBInternalType("varint"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - break; - } - case LogicalTypeId::DOUBLE: - child.format = "g"; - break; - case LogicalTypeId::UUID: { - if (options.arrow_lossless_conversion) { - // This is a canonical extension, hence needs the "arrow." 
prefix - child.format = "w:16"; - auto schema_metadata = ArrowSchemaMetadata::ArrowCanonicalType("arrow.uuid"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - } else { - if (options.produce_arrow_string_view) { - child.format = "vu"; - } else { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "U"; - } else { - child.format = "u"; - } - } - } - break; - } - case LogicalTypeId::VARCHAR: - if (type.IsJSONType() && options.arrow_lossless_conversion) { - auto schema_metadata = ArrowSchemaMetadata::ArrowCanonicalType("arrow.json"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - } - if (options.produce_arrow_string_view) { - child.format = "vu"; - } else { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "U"; - } else { - child.format = "u"; - } - } - break; - case LogicalTypeId::DATE: - child.format = "tdD"; - break; - case LogicalTypeId::TIME_TZ: { - if (options.arrow_lossless_conversion) { - child.format = "w:8"; - auto schema_metadata = ArrowSchemaMetadata::DuckDBInternalType("time_tz"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - } else { - child.format = "ttu"; - } - break; - } - case LogicalTypeId::TIME: - child.format = "ttu"; - break; - case LogicalTypeId::TIMESTAMP: - child.format = "tsu:"; - break; - case LogicalTypeId::TIMESTAMP_TZ: { - string format = "tsu:" + options.time_zone; - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - break; - } - case LogicalTypeId::TIMESTAMP_SEC: - child.format = "tss:"; - break; - case LogicalTypeId::TIMESTAMP_NS: - child.format = "tsn:"; - break; - case LogicalTypeId::TIMESTAMP_MS: - child.format = "tsm:"; - break; - case LogicalTypeId::INTERVAL: - child.format = "tin"; - break; - case LogicalTypeId::DECIMAL: { - uint8_t width, scale; - type.GetDecimalProperties(width, scale); - string format = "d:" + to_string(width) + "," + to_string(scale); - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - break; - } - case LogicalTypeId::SQLNULL: { - child.format = "n"; - break; - } - case LogicalTypeId::BLOB: - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "Z"; - } else { - child.format = "z"; - } - break; - case LogicalTypeId::BIT: { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "Z"; - } else { - child.format = "z"; - } - if (options.arrow_lossless_conversion) { - auto schema_metadata = ArrowSchemaMetadata::DuckDBInternalType("bit"); - root_holder.metadata_info.emplace_back(schema_metadata.SerializeMetadata()); - child.metadata = root_holder.metadata_info.back().get(); - } - break; - } - case LogicalTypeId::LIST: { - if (options.arrow_use_list_view) { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "+vL"; - } else { - child.format = "+vl"; - } - } else { - if (options.arrow_offset_size == ArrowOffsetSize::LARGE) { - child.format = "+L"; - } else { - child.format = "+l"; - } - } - child.n_children = 1; - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - 
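// [editor's note] nested_children is a std::list of vectors rather than a
// vector of vectors on purpose: the raw ArrowSchema pointers handed out below
// must stay valid for the holder's lifetime, and list nodes never move when
// later children are emplaced, whereas a growing outer vector could
// reallocate and invalidate them.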
root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.children = &root_holder.nested_children_ptr.back()[0]; - child.children[0]->name = "l"; - SetArrowFormat(root_holder, **child.children, ListType::GetChildType(type), options); - break; - } - case LogicalTypeId::STRUCT: { - child.format = "+s"; - auto &child_types = StructType::GetChildTypes(type); - child.n_children = NumericCast(child_types.size()); - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(child_types.size()); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().resize(child_types.size()); - for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; - } - child.children = &root_holder.nested_children_ptr.back()[0]; - for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - - InitializeChild(*child.children[type_idx], root_holder); - - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); - - child.children[type_idx]->name = root_holder.owned_type_names.back().get(); - SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); - } - break; - } - case LogicalTypeId::ARRAY: { - auto array_size = ArrayType::GetSize(type); - auto &child_type = ArrayType::GetChildType(type); - auto format = "+w:" + to_string(array_size); - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - - child.n_children = 1; - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.children = &root_holder.nested_children_ptr.back()[0]; - SetArrowFormat(root_holder, **child.children, child_type, options); - break; - } - case LogicalTypeId::MAP: { - SetArrowMapFormat(root_holder, child, type, options); - break; - } - case LogicalTypeId::UNION: { - std::string format = "+us:"; - - auto &child_types = UnionType::CopyMemberTypes(type); - child.n_children = NumericCast(child_types.size()); - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(child_types.size()); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().resize(child_types.size()); - for (idx_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - root_holder.nested_children_ptr.back()[type_idx] = &root_holder.nested_children.back()[type_idx]; - } - child.children = &root_holder.nested_children_ptr.back()[0]; - for (size_t type_idx = 0; type_idx < child_types.size(); type_idx++) { - - InitializeChild(*child.children[type_idx], root_holder); - - root_holder.owned_type_names.push_back(AddName(child_types[type_idx].first)); - - child.children[type_idx]->name = root_holder.owned_type_names.back().get(); - SetArrowFormat(root_holder, *child.children[type_idx], child_types[type_idx].second, options); - - format += to_string(type_idx) + ","; - } - - format.pop_back(); - - root_holder.owned_type_names.push_back(AddName(format)); - child.format = root_holder.owned_type_names.back().get(); - - break; - } - case LogicalTypeId::ENUM: { - // TODO 
what do we do with pointer enums here? - switch (EnumType::GetPhysicalType(type)) { - case PhysicalType::UINT8: - child.format = "C"; - break; - case PhysicalType::UINT16: - child.format = "S"; - break; - case PhysicalType::UINT32: - child.format = "I"; - break; - default: - throw InternalException("Unsupported Enum Internal Type"); - } - root_holder.nested_children.emplace_back(); - root_holder.nested_children.back().resize(1); - root_holder.nested_children_ptr.emplace_back(); - root_holder.nested_children_ptr.back().push_back(&root_holder.nested_children.back()[0]); - InitializeChild(root_holder.nested_children.back()[0], root_holder); - child.dictionary = root_holder.nested_children_ptr.back()[0]; - child.dictionary->format = "u"; - break; - } - default: - throw NotImplementedException("Unsupported Arrow type " + type.ToString()); - } -} - -void ArrowConverter::ToArrowSchema(ArrowSchema *out_schema, const vector &types, - const vector &names, const ClientProperties &options) { - D_ASSERT(out_schema); - D_ASSERT(types.size() == names.size()); - idx_t column_count = types.size(); - // Allocate as unique_ptr first to cleanup properly on error - auto root_holder = make_uniq(); - - // Allocate the children - root_holder->children.resize(column_count); - root_holder->children_ptrs.resize(column_count, nullptr); - for (size_t i = 0; i < column_count; ++i) { - root_holder->children_ptrs[i] = &root_holder->children[i]; - } - out_schema->children = root_holder->children_ptrs.data(); - out_schema->n_children = NumericCast(column_count); - - // Store the schema - out_schema->format = "+s"; // struct apparently - out_schema->flags = 0; - out_schema->metadata = nullptr; - out_schema->name = "duckdb_query_result"; - out_schema->dictionary = nullptr; - - // Configure all child schemas - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - root_holder->owned_column_names.push_back(AddName(names[col_idx])); - auto &child = root_holder->children[col_idx]; - InitializeChild(child, *root_holder, names[col_idx]); - SetArrowFormat(*root_holder, child, types[col_idx], options); - } - - // Release ownership to caller - out_schema->private_data = root_holder.release(); - out_schema->release = ReleaseDuckDBArrowSchema; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_merge_event.cpp b/src/duckdb/src/common/arrow/arrow_merge_event.cpp deleted file mode 100644 index 1315ad1ad..000000000 --- a/src/duckdb/src/common/arrow/arrow_merge_event.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include "duckdb/common/arrow/arrow_merge_event.hpp" -#include "duckdb/common/arrow/arrow_util.hpp" -#include "duckdb/storage/storage_info.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Arrow Batch Task -//===--------------------------------------------------------------------===// - -ArrowBatchTask::ArrowBatchTask(ArrowQueryResult &result, vector record_batch_indices, Executor &executor, - shared_ptr event_p, BatchCollectionChunkScanState scan_state, - vector names, idx_t batch_size) - : ExecutorTask(executor, event_p), result(result), record_batch_indices(std::move(record_batch_indices)), - event(std::move(event_p)), batch_size(batch_size), names(std::move(names)), scan_state(std::move(scan_state)) { -} - -void ArrowBatchTask::ProduceRecordBatches() { - auto &arrays = result.Arrays(); - auto arrow_options = executor.context.GetClientProperties(); - for (auto &index : record_batch_indices) { - auto &array = arrays[index]; - D_ASSERT(array); - 
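// [editor's note] Schedule() below sizes record_batch_indices from exact
// per-batch tuple counts, so every fetch in this loop should produce rows;
// the D_ASSERT(count != 0) a few lines down documents that invariant rather
// than guarding a condition expected at runtime.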
idx_t count; - count = ArrowUtil::FetchChunk(scan_state, arrow_options, batch_size, &array->arrow_array); - (void)count; - D_ASSERT(count != 0); - } -} - -TaskExecutionResult ArrowBatchTask::ExecuteTask(TaskExecutionMode mode) { - ProduceRecordBatches(); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -//===--------------------------------------------------------------------===// -// Arrow Merge Event -//===--------------------------------------------------------------------===// - -ArrowMergeEvent::ArrowMergeEvent(ArrowQueryResult &result, BatchedDataCollection &batches, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), result(result), batches(batches) { - record_batch_size = result.BatchSize(); -} - -namespace { - -struct BatchesForTask { - idx_t tuple_count; - BatchedChunkIteratorRange batches; -}; - -struct BatchesToTaskTransformer { -public: - explicit BatchesToTaskTransformer(BatchedDataCollection &batches) : batches(batches), batch_index(0) { - batch_count = batches.BatchCount(); - } - idx_t GetIndex() const { - return batch_index; - } - bool TryGetNextBatchSize(idx_t &tuple_count) { - if (batch_index >= batch_count) { - return false; - } - auto internal_index = batches.IndexToBatchIndex(batch_index++); - auto tuples_in_batch = batches.BatchSize(internal_index); - tuple_count = tuples_in_batch; - return true; - } - -public: - BatchedDataCollection &batches; - idx_t batch_index; - idx_t batch_count; -}; - -} // namespace - -void ArrowMergeEvent::Schedule() { - vector> tasks; - - BatchesToTaskTransformer transformer(batches); - vector task_data; - bool finished = false; - // First we convert our list of batches into units of Storage::ROW_GROUP_SIZE tuples each - while (!finished) { - idx_t tuples_for_task = 0; - idx_t start_index = transformer.GetIndex(); - idx_t end_index = start_index; - while (tuples_for_task < DEFAULT_ROW_GROUP_SIZE) { - idx_t batch_size; - if (!transformer.TryGetNextBatchSize(batch_size)) { - finished = true; - break; - } - end_index++; - tuples_for_task += batch_size; - } - if (start_index == end_index) { - break; - } - BatchesForTask batches_for_task; - batches_for_task.tuple_count = tuples_for_task; - batches_for_task.batches = batches.BatchRange(start_index, end_index); - task_data.push_back(batches_for_task); - } - - // Now we produce tasks from these units - // Every task is given a scan_state created from the range of batches - // and a vector of indices indicating the arrays (record batches) they should populate - idx_t record_batch_index = 0; - for (auto &data : task_data) { - const auto tuples = data.tuple_count; - - auto full_batches = tuples / record_batch_size; - auto remainder = tuples % record_batch_size; - auto total_batches = full_batches + !!remainder; - - vector record_batch_indices(total_batches); - for (idx_t i = 0; i < total_batches; i++) { - record_batch_indices[i] = record_batch_index++; - } - - BatchCollectionChunkScanState scan_state(batches, data.batches, pipeline->executor.context); - tasks.push_back(make_uniq(result, std::move(record_batch_indices), pipeline->executor, - shared_from_this(), std::move(scan_state), result.names, - record_batch_size)); - } - - // Allocate the list of record batches inside the query result - { - vector> arrays; - arrays.resize(record_batch_index); - for (idx_t i = 0; i < record_batch_index; i++) { - arrays[i] = make_uniq(); - } - result.SetArrowData(std::move(arrays)); - } - D_ASSERT(!tasks.empty()); - SetTasks(std::move(tasks)); -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/common/arrow/arrow_query_result.cpp b/src/duckdb/src/common/arrow/arrow_query_result.cpp deleted file mode 100644 index 396a99944..000000000 --- a/src/duckdb/src/common/arrow/arrow_query_result.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "duckdb/common/arrow/arrow_query_result.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/box_renderer.hpp" -#include "duckdb/common/arrow/arrow_converter.hpp" - -namespace duckdb { - -ArrowQueryResult::ArrowQueryResult(StatementType statement_type, StatementProperties properties, vector names_p, - vector types_p, ClientProperties client_properties, idx_t batch_size) - : QueryResult(QueryResultType::ARROW_RESULT, statement_type, std::move(properties), std::move(types_p), - std::move(names_p), std::move(client_properties)), - batch_size(batch_size) { -} - -ArrowQueryResult::ArrowQueryResult(ErrorData error) : QueryResult(QueryResultType::ARROW_RESULT, std::move(error)) { -} - -unique_ptr ArrowQueryResult::Fetch() { - throw NotImplementedException("Can't 'Fetch' from ArrowQueryResult"); -} -unique_ptr ArrowQueryResult::FetchRaw() { - throw NotImplementedException("Can't 'FetchRaw' from ArrowQueryResult"); -} - -string ArrowQueryResult::ToString() { - // FIXME: can't throw an exception here as it's used for verification - return ""; -} - -vector> ArrowQueryResult::ConsumeArrays() { - if (HasError()) { - throw InvalidInputException("Attempting to fetch ArrowArrays from an unsuccessful query result\n: Error %s", - GetError()); - } - return std::move(arrays); -} - -vector> &ArrowQueryResult::Arrays() { - if (HasError()) { - throw InvalidInputException("Attempting to fetch ArrowArrays from an unsuccessful query result\n: Error %s", - GetError()); - } - return arrays; -} - -void ArrowQueryResult::SetArrowData(vector> arrays) { - D_ASSERT(this->arrays.empty()); - this->arrays = std::move(arrays); -} - -idx_t ArrowQueryResult::BatchSize() const { - return batch_size; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_util.cpp b/src/duckdb/src/common/arrow/arrow_util.cpp deleted file mode 100644 index 423a6dd2c..000000000 --- a/src/duckdb/src/common/arrow/arrow_util.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/common/arrow/arrow_util.hpp" -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/common/types/data_chunk.hpp" - -namespace duckdb { - -bool ArrowUtil::TryFetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t batch_size, ArrowArray *out, - idx_t &count, ErrorData &error) { - count = 0; - ArrowAppender appender(scan_state.Types(), batch_size, std::move(options)); - auto remaining_tuples_in_chunk = scan_state.RemainingInChunk(); - if (remaining_tuples_in_chunk) { - // We start by scanning the non-finished current chunk - idx_t cur_consumption = MinValue(remaining_tuples_in_chunk, batch_size); - count += cur_consumption; - auto ¤t_chunk = scan_state.CurrentChunk(); - appender.Append(current_chunk, scan_state.CurrentOffset(), scan_state.CurrentOffset() + cur_consumption, - current_chunk.size()); - scan_state.IncreaseOffset(cur_consumption); - } - while (count < batch_size) { - if (!scan_state.LoadNextChunk(error)) { - if (scan_state.HasError()) { - error = scan_state.GetError(); - } - return false; - } - if (scan_state.ChunkIsEmpty()) { - // The scan was successful, but an empty chunk was returned - break; - } - auto ¤t_chunk = scan_state.CurrentChunk(); - if (scan_state.Finished() || current_chunk.size() == 0) { - break; 
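// [editor's note] Breaking out here (or at the empty-chunk check above) is
// not a failure: TryFetchChunk may return true with count < batch_size when
// the source runs out of tuples, and only returns false when LoadNextChunk
// itself reports an error.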
- } - // The amount we still need to append into this chunk - auto remaining = batch_size - count; - - // The amount remaining, capped by the amount left in the current chunk - auto to_append_to_batch = MinValue(remaining, scan_state.RemainingInChunk()); - appender.Append(current_chunk, 0, to_append_to_batch, current_chunk.size()); - count += to_append_to_batch; - scan_state.IncreaseOffset(to_append_to_batch); - } - if (count > 0) { - *out = appender.Finalize(); - } - return true; -} - -idx_t ArrowUtil::FetchChunk(ChunkScanState &scan_state, ClientProperties options, idx_t chunk_size, ArrowArray *out) { - ErrorData error; - idx_t result_count; - if (!TryFetchChunk(scan_state, std::move(options), chunk_size, out, result_count, error)) { - error.Throw(); - } - return result_count; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/arrow_wrapper.cpp b/src/duckdb/src/common/arrow/arrow_wrapper.cpp deleted file mode 100644 index 2d17076f4..000000000 --- a/src/duckdb/src/common/arrow/arrow_wrapper.cpp +++ /dev/null @@ -1,180 +0,0 @@ -#include "duckdb/common/arrow/arrow_wrapper.hpp" -#include "duckdb/common/arrow/arrow_util.hpp" -#include "duckdb/common/arrow/arrow_converter.hpp" - -#include "duckdb/common/assert.hpp" -#include "duckdb/common/exception.hpp" - -#include "duckdb/main/stream_query_result.hpp" - -#include "duckdb/common/arrow/result_arrow_wrapper.hpp" -#include "duckdb/common/arrow/arrow_appender.hpp" -#include "duckdb/main/query_result.hpp" -#include "duckdb/main/chunk_scan_state/query_result.hpp" - -namespace duckdb { - -ArrowSchemaWrapper::~ArrowSchemaWrapper() { - if (arrow_schema.release) { - arrow_schema.release(&arrow_schema); - D_ASSERT(!arrow_schema.release); - } -} - -ArrowArrayWrapper::~ArrowArrayWrapper() { - if (arrow_array.release) { - arrow_array.release(&arrow_array); - D_ASSERT(!arrow_array.release); - } -} - -ArrowArrayStreamWrapper::~ArrowArrayStreamWrapper() { - if (arrow_array_stream.release) { - arrow_array_stream.release(&arrow_array_stream); - D_ASSERT(!arrow_array_stream.release); - } -} - -void ArrowArrayStreamWrapper::GetSchema(ArrowSchemaWrapper &schema) { - D_ASSERT(arrow_array_stream.get_schema); - // LCOV_EXCL_START - if (arrow_array_stream.get_schema(&arrow_array_stream, &schema.arrow_schema)) { - throw InvalidInputException("arrow_scan: get_schema failed(): %s", string(GetError())); - } - if (!schema.arrow_schema.release) { - throw InvalidInputException("arrow_scan: released schema passed"); - } - if (schema.arrow_schema.n_children < 1) { - throw InvalidInputException("arrow_scan: empty schema passed"); - } - // LCOV_EXCL_STOP -} - -shared_ptr ArrowArrayStreamWrapper::GetNextChunk() { - auto current_chunk = make_shared_ptr(); - if (arrow_array_stream.get_next(&arrow_array_stream, ¤t_chunk->arrow_array)) { // LCOV_EXCL_START - throw InvalidInputException("arrow_scan: get_next failed(): %s", string(GetError())); - } // LCOV_EXCL_STOP - - return current_chunk; -} - -const char *ArrowArrayStreamWrapper::GetError() { // LCOV_EXCL_START - return arrow_array_stream.get_last_error(&arrow_array_stream); -} // LCOV_EXCL_STOP - -int ResultArrowArrayStreamWrapper::MyStreamGetSchema(struct ArrowArrayStream *stream, struct ArrowSchema *out) { - if (!stream->release) { - return -1; - } - out->release = nullptr; - auto my_stream = reinterpret_cast(stream->private_data); - if (!my_stream->column_types.empty()) { - try { - ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, - my_stream->result->client_properties); - } 
catch (std::runtime_error &e) { - my_stream->last_error = ErrorData(e); - return -1; - } - return 0; - } - - auto &result = *my_stream->result; - if (result.HasError()) { - my_stream->last_error = result.GetErrorObject(); - return -1; - } - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - my_stream->last_error = ErrorData("Query Stream is closed"); - return -1; - } - } - if (my_stream->column_types.empty()) { - my_stream->column_types = result.types; - my_stream->column_names = result.names; - } - try { - ArrowConverter::ToArrowSchema(out, my_stream->column_types, my_stream->column_names, - my_stream->result->client_properties); - } catch (std::runtime_error &e) { - my_stream->last_error = ErrorData(e); - return -1; - } - return 0; -} - -int ResultArrowArrayStreamWrapper::MyStreamGetNext(struct ArrowArrayStream *stream, struct ArrowArray *out) { - if (!stream->release) { - return -1; - } - auto my_stream = reinterpret_cast(stream->private_data); - auto &result = *my_stream->result; - auto &scan_state = *my_stream->scan_state; - if (result.HasError()) { - my_stream->last_error = result.GetErrorObject(); - return -1; - } - if (result.type == QueryResultType::STREAM_RESULT) { - auto &stream_result = result.Cast(); - if (!stream_result.IsOpen()) { - // Nothing to output - out->release = nullptr; - return 0; - } - } - if (my_stream->column_types.empty()) { - my_stream->column_types = result.types; - my_stream->column_names = result.names; - } - idx_t result_count; - ErrorData error; - if (!ArrowUtil::TryFetchChunk(scan_state, result.client_properties, my_stream->batch_size, out, result_count, - error)) { - D_ASSERT(error.HasError()); - my_stream->last_error = error; - return -1; - } - if (result_count == 0) { - // Nothing to output - out->release = nullptr; - } - return 0; -} - -void ResultArrowArrayStreamWrapper::MyStreamRelease(struct ArrowArrayStream *stream) { - if (!stream || !stream->release) { - return; - } - stream->release = nullptr; - delete reinterpret_cast(stream->private_data); -} - -const char *ResultArrowArrayStreamWrapper::MyStreamGetLastError(struct ArrowArrayStream *stream) { - if (!stream->release) { - return "stream was released"; - } - D_ASSERT(stream->private_data); - auto my_stream = reinterpret_cast(stream->private_data); - return my_stream->last_error.Message().c_str(); -} - -ResultArrowArrayStreamWrapper::ResultArrowArrayStreamWrapper(unique_ptr result_p, idx_t batch_size_p) - : result(std::move(result_p)), scan_state(make_uniq(*result)) { - //! We first initialize the private data of the stream - stream.private_data = this; - //! Ceil Approx_Batch_Size/STANDARD_VECTOR_SIZE - if (batch_size_p == 0) { - throw std::runtime_error("Approximate Batch Size of Record Batch MUST be higher than 0"); - } - batch_size = batch_size_p; - //! 
We initialize the stream functions - stream.get_schema = ResultArrowArrayStreamWrapper::MyStreamGetSchema; - stream.get_next = ResultArrowArrayStreamWrapper::MyStreamGetNext; - stream.release = ResultArrowArrayStreamWrapper::MyStreamRelease; - stream.get_last_error = ResultArrowArrayStreamWrapper::MyStreamGetLastError; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp deleted file mode 100644 index 11406c540..000000000 --- a/src/duckdb/src/common/arrow/physical_arrow_batch_collector.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "duckdb/common/arrow/physical_arrow_batch_collector.hpp" -#include "duckdb/common/types/batched_data_collection.hpp" -#include "duckdb/common/arrow/arrow_query_result.hpp" -#include "duckdb/common/arrow/arrow_merge_event.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/arrow/physical_arrow_collector.hpp" - -namespace duckdb { - -unique_ptr PhysicalArrowBatchCollector::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkFinalizeType PhysicalArrowBatchCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - auto total_tuple_count = gstate.data.Count(); - if (total_tuple_count == 0) { - // Create the result containing a single empty result conversion - gstate.result = make_uniq(statement_type, properties, names, types, - context.GetClientProperties(), record_batch_size); - return SinkFinalizeType::READY; - } - - // Already create the final query result - gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), - record_batch_size); - // Spawn an event that will populate the conversion result - auto &arrow_result = gstate.result->Cast(); - auto new_event = make_shared_ptr(arrow_result, gstate.data, pipeline); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp b/src/duckdb/src/common/arrow/physical_arrow_collector.cpp deleted file mode 100644 index d82246b4c..000000000 --- a/src/duckdb/src/common/arrow/physical_arrow_collector.cpp +++ /dev/null @@ -1,128 +0,0 @@ -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/arrow/physical_arrow_collector.hpp" -#include "duckdb/common/arrow/physical_arrow_batch_collector.hpp" -#include "duckdb/common/arrow/arrow_query_result.hpp" -#include "duckdb/main/prepared_statement_data.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -unique_ptr PhysicalArrowCollector::Create(ClientContext &context, PreparedStatementData &data, - idx_t batch_size) { - if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, *data.plan)) { - // the plan is not order preserving, so we just use the parallel materialized collector - return make_uniq_base(data, true, batch_size); - } else if (!PhysicalPlanGenerator::UseBatchIndex(context, *data.plan)) { - // the plan is order preserving, but we cannot use the batch index: use a single-threaded result collector - return make_uniq_base(data, false, batch_size); - } else { - return make_uniq_base(data, batch_size); - } -} - -SinkResultType PhysicalArrowCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput 
&input) const { - auto &lstate = input.local_state.Cast(); - // Append to the appender, up to chunk size - - auto count = chunk.size(); - auto &appender = lstate.appender; - D_ASSERT(count != 0); - - idx_t processed = 0; - do { - if (!appender) { - // Create the appender if we haven't started this chunk yet - auto properties = context.client.GetClientProperties(); - D_ASSERT(processed < count); - auto initial_capacity = MinValue(record_batch_size, count - processed); - appender = make_uniq(types, initial_capacity, properties); - } - - // Figure out how much we can still append to this chunk - auto row_count = appender->RowCount(); - D_ASSERT(record_batch_size > row_count); - auto to_append = MinValue(record_batch_size - row_count, count - processed); - - // Append and check if the chunk is finished - appender->Append(chunk, processed, processed + to_append, count); - processed += to_append; - row_count = appender->RowCount(); - if (row_count >= record_batch_size) { - lstate.FinishArray(); - } - } while (processed < count); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalArrowCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &last_appender = lstate.appender; - auto &arrays = lstate.finished_arrays; - if (arrays.empty() && !last_appender) { - // Nothing to do - return SinkCombineResultType::FINISHED; - } - if (last_appender) { - // FIXME: we could set these aside and merge them in a finalize event in an effort to create more balanced - // chunks out of these remnants - lstate.FinishArray(); - } - // Collect all the finished arrays - lock_guard l(gstate.glock); - // Move the arrays from our local state into the global state - gstate.chunks.insert(gstate.chunks.end(), std::make_move_iterator(arrays.begin()), - std::make_move_iterator(arrays.end())); - arrays.clear(); - gstate.tuple_count += lstate.tuple_count; - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalArrowCollector::GetResult(GlobalSinkState &state_p) { - auto &gstate = state_p.Cast(); - return std::move(gstate.result); -} - -unique_ptr PhysicalArrowCollector::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalArrowCollector::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(); -} - -SinkFinalizeType PhysicalArrowCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - if (gstate.chunks.empty()) { - if (gstate.tuple_count != 0) { - throw InternalException( - "PhysicalArrowCollector Finalize contains no chunks, but tuple_count is non-zero (%d)", - gstate.tuple_count); - } - gstate.result = make_uniq(statement_type, properties, names, types, - context.GetClientProperties(), record_batch_size); - return SinkFinalizeType::READY; - } - - gstate.result = make_uniq(statement_type, properties, names, types, context.GetClientProperties(), - record_batch_size); - auto &arrow_result = gstate.result->Cast(); - arrow_result.SetArrowData(std::move(gstate.chunks)); - - return SinkFinalizeType::READY; -} - -bool PhysicalArrowCollector::ParallelSink() const { - return parallel; -} - -bool PhysicalArrowCollector::SinkOrderDependent() const { - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/arrow/schema_metadata.cpp b/src/duckdb/src/common/arrow/schema_metadata.cpp 
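The MyStream* callbacks earlier in this diff expose a query result through the Arrow C stream interface. For orientation, the following is a minimal consumer sketch driving such a stream; ConsumeStream and its error handling are illustrative rather than part of DuckDB, and the only assumption is the standard ArrowArrayStream ABI.

// Minimal consumer sketch for the Arrow C stream interface (illustrative;
// ConsumeStream is not a DuckDB function).
#include <cstdio>
#include "arrow/c/abi.h" // any header declaring ArrowSchema/ArrowArray/ArrowArrayStream

static int ConsumeStream(ArrowArrayStream &stream) {
	ArrowSchema schema;
	if (stream.get_schema(&stream, &schema) != 0) {
		// MyStreamGetSchema stored the message in last_error
		fprintf(stderr, "schema error: %s\n", stream.get_last_error(&stream));
		return -1;
	}
	schema.release(&schema);
	while (true) {
		ArrowArray array;
		if (stream.get_next(&stream, &array) != 0) {
			fprintf(stderr, "fetch error: %s\n", stream.get_last_error(&stream));
			stream.release(&stream);
			return -1;
		}
		if (!array.release) {
			break; // end of stream: MyStreamGetNext set out->release to nullptr
		}
		// ... process array.length rows against the schema ...
		array.release(&array);
	}
	stream.release(&stream); // frees the ResultArrowArrayStreamWrapper
	return 0;
}

Release must be called exactly once: as the code above shows, MyStreamRelease then deletes the wrapper that owns the underlying query result.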
deleted file mode 100644 index 836f89f2c..000000000 --- a/src/duckdb/src/common/arrow/schema_metadata.cpp +++ /dev/null @@ -1,131 +0,0 @@ -#include "duckdb/common/arrow/schema_metadata.hpp" - -namespace duckdb { - -ArrowSchemaMetadata::ArrowSchemaMetadata(const char *metadata) { - if (metadata) { - // Read the number of key-value pairs (int32) - int32_t num_pairs; - memcpy(&num_pairs, metadata, sizeof(int32_t)); - metadata += sizeof(int32_t); - - // Loop through each key-value pair - for (int32_t i = 0; i < num_pairs; ++i) { - // Read the length of the key (int32) - int32_t key_length; - memcpy(&key_length, metadata, sizeof(int32_t)); - metadata += sizeof(int32_t); - - // Read the key - std::string key(metadata, static_cast(key_length)); - metadata += key_length; - - // Read the length of the value (int32) - int32_t value_length; - memcpy(&value_length, metadata, sizeof(int32_t)); - metadata += sizeof(int32_t); - - // Read the value - const std::string value(metadata, static_cast(value_length)); - metadata += value_length; - schema_metadata_map[key] = value; - } - } - extension_metadata_map = StringUtil::ParseJSONMap(schema_metadata_map[ARROW_METADATA_KEY]); -} - -void ArrowSchemaMetadata::AddOption(const string &key, const string &value) { - schema_metadata_map[key] = value; -} -string ArrowSchemaMetadata::GetOption(const string &key) const { - auto it = schema_metadata_map.find(key); - if (it != schema_metadata_map.end()) { - return it->second; - } else { - return ""; - } -} - -string ArrowSchemaMetadata::GetExtensionName() const { - return GetOption(ARROW_EXTENSION_NAME); -} - -ArrowSchemaMetadata ArrowSchemaMetadata::ArrowCanonicalType(const string &extension_name) { - ArrowSchemaMetadata metadata; - metadata.AddOption(ARROW_EXTENSION_NAME, extension_name); - metadata.AddOption(ARROW_METADATA_KEY, ""); - return metadata; -} - -ArrowSchemaMetadata ArrowSchemaMetadata::DuckDBInternalType(const string &type_name) { - ArrowSchemaMetadata metadata; - metadata.AddOption(ARROW_EXTENSION_NAME, ARROW_EXTENSION_NON_CANONICAL); - // We have to set the metadata key with type_name and vendor_name. - metadata.extension_metadata_map["vendor_name"] = "DuckDB"; - metadata.extension_metadata_map["type_name"] = type_name; - metadata.AddOption(ARROW_METADATA_KEY, StringUtil::ToJSONMap(metadata.extension_metadata_map)); - return metadata; -} - -bool ArrowSchemaMetadata::IsNonCanonicalType(const string &type, const string &vendor) const { - if (schema_metadata_map.find(ARROW_EXTENSION_NAME) == schema_metadata_map.end()) { - return false; - } - if (schema_metadata_map.find(ARROW_EXTENSION_NAME)->second != ARROW_EXTENSION_NON_CANONICAL) { - return false; - } - if (extension_metadata_map.find("type_name") == extension_metadata_map.end() || - extension_metadata_map.find("vendor_name") == extension_metadata_map.end()) { - return false; - } - auto vendor_name = extension_metadata_map.find("vendor_name")->second; - auto type_name = extension_metadata_map.find("type_name")->second; - return vendor_name == vendor && type_name == type; -} - -bool ArrowSchemaMetadata::HasExtension() const { - auto arrow_extension = GetOption(ArrowSchemaMetadata::ARROW_EXTENSION_NAME); - // FIXME: We are currently ignoring the ogc extensions - return !arrow_extension.empty() && !StringUtil::StartsWith(arrow_extension, "ogc"); -} - -unsafe_unique_array ArrowSchemaMetadata::SerializeMetadata() const { - // First we have to figure out the total size: - // 1. 
number of key-value pairs (int32) - idx_t total_size = sizeof(int32_t); - for (const auto &option : schema_metadata_map) { - // 2. Length of the key and value (2 * int32) - total_size += 2 * sizeof(int32_t); - // 3. Length of key - total_size += option.first.size(); - // 4. Length of value - total_size += option.second.size(); - } - auto metadata_array_ptr = make_unsafe_uniq_array(total_size); - auto metadata_ptr = metadata_array_ptr.get(); - // 1. number of key-value pairs (int32) - const idx_t map_size = schema_metadata_map.size(); - memcpy(metadata_ptr, &map_size, sizeof(int32_t)); - metadata_ptr += sizeof(int32_t); - // Iterate through each key-value pair in the map - for (const auto &pair : schema_metadata_map) { - const std::string &key = pair.first; - idx_t key_size = key.size(); - // Length of the key (int32) - memcpy(metadata_ptr, &key_size, sizeof(int32_t)); - metadata_ptr += sizeof(int32_t); - // Key - memcpy(metadata_ptr, key.c_str(), key_size); - metadata_ptr += key_size; - const std::string &value = pair.second; - const idx_t value_size = value.size(); - // Length of the value (int32) - memcpy(metadata_ptr, &value_size, sizeof(int32_t)); - metadata_ptr += sizeof(int32_t); - // Value - memcpy(metadata_ptr, value.c_str(), value_size); - metadata_ptr += value_size; - } - return metadata_array_ptr; -} -} // namespace duckdb diff --git a/src/duckdb/src/common/assert.cpp b/src/duckdb/src/common/assert.cpp deleted file mode 100644 index 02f62212e..000000000 --- a/src/duckdb/src/common/assert.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "duckdb/common/assert.hpp" -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -void DuckDBAssertInternal(bool condition, const char *condition_name, const char *file, int linenr) { -#ifdef DISABLE_ASSERTIONS - return; -#endif - if (condition) { - return; - } - throw InternalException("Assertion triggered in file \"%s\" on line %d: %s", file, linenr, condition_name); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/bind_helpers.cpp b/src/duckdb/src/common/bind_helpers.cpp deleted file mode 100644 index b618b797b..000000000 --- a/src/duckdb/src/common/bind_helpers.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "duckdb/common/bind_helpers.hpp" -#include "duckdb/common/common.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/case_insensitive_map.hpp" -#include "duckdb/common/exception/binder_exception.hpp" -#include - -namespace duckdb { - -Value ConvertVectorToValue(vector set) { - if (set.empty()) { - return Value::LIST(LogicalType::BOOLEAN, std::move(set)); - } - return Value::LIST(std::move(set)); -} - -vector ParseColumnList(const vector &set, vector &names, const string &loption) { - vector result; - - if (set.empty()) { - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - // list of options: parse the list - case_insensitive_map_t option_map; - for (idx_t i = 0; i < set.size(); i++) { - option_map[set[i].ToString()] = false; - } - result.resize(names.size(), false); - for (idx_t i = 0; i < names.size(); i++) { - auto entry = option_map.find(names[i]); - if (entry != option_map.end()) { - result[i] = true; - entry->second = true; - } - } - for (auto &entry : option_map) { - if (!entry.second) { - throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, - entry.first.c_str()); - } - } - return result; -} - -vector ParseColumnList(const Value &value, 
vector &names, const string &loption) { - vector result; - - // Only accept a list of arguments - if (value.type().id() != LogicalTypeId::LIST) { - // Support a single argument if it's '*' - if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { - result.resize(names.size(), true); - return result; - } - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - auto &children = ListValue::GetChildren(value); - // accept '*' as single argument - if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && - children[0].GetValue() == "*") { - result.resize(names.size(), true); - return result; - } - return ParseColumnList(children, names, loption); -} - -vector ParseColumnsOrdered(const vector &set, vector &names, const string &loption) { - vector result; - - if (set.empty()) { - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - - // Maps option to bool indicating if its found and the index in the original set - case_insensitive_map_t> option_map; - for (idx_t i = 0; i < set.size(); i++) { - option_map[set[i].ToString()] = {false, i}; - } - result.resize(option_map.size()); - - for (idx_t i = 0; i < names.size(); i++) { - auto entry = option_map.find(names[i]); - if (entry != option_map.end()) { - result[entry->second.second] = i; - entry->second.first = true; - } - } - for (auto &entry : option_map) { - if (!entry.second.first) { - throw BinderException("\"%s\" expected to find %s, but it was not found in the table", loption, - entry.first.c_str()); - } - } - return result; -} - -vector ParseColumnsOrdered(const Value &value, vector &names, const string &loption) { - vector result; - - // Only accept a list of arguments - if (value.type().id() != LogicalTypeId::LIST) { - // Support a single argument if it's '*' - if (value.type().id() == LogicalTypeId::VARCHAR && value.GetValue() == "*") { - result.resize(names.size(), 0); - std::iota(std::begin(result), std::end(result), 0); - return result; - } - throw BinderException("\"%s\" expects a column list or * as parameter", loption); - } - auto &children = ListValue::GetChildren(value); - // accept '*' as single argument - if (children.size() == 1 && children[0].type().id() == LogicalTypeId::VARCHAR && - children[0].GetValue() == "*") { - result.resize(names.size(), 0); - std::iota(std::begin(result), std::end(result), 0); - return result; - } - return ParseColumnsOrdered(children, names, loption); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/box_renderer.cpp b/src/duckdb/src/common/box_renderer.cpp deleted file mode 100644 index f85cf622d..000000000 --- a/src/duckdb/src/common/box_renderer.cpp +++ /dev/null @@ -1,1092 +0,0 @@ -#include "duckdb/common/box_renderer.hpp" - -#include "duckdb/common/printer.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "utf8proc_wrapper.hpp" - -#include - -namespace duckdb { - -const idx_t BoxRenderer::SPLIT_COLUMN = idx_t(-1); - -//===--------------------------------------------------------------------===// -// Result Renderer -//===--------------------------------------------------------------------===// -BaseResultRenderer::BaseResultRenderer() : value_type(LogicalTypeId::INVALID) { -} - -BaseResultRenderer::~BaseResultRenderer() { -} - -BaseResultRenderer &BaseResultRenderer::operator<<(char c) { - RenderLayout(string(1, c)); - return *this; -} - -BaseResultRenderer 
&BaseResultRenderer::operator<<(const string &val) { - RenderLayout(val); - return *this; -} - -void BaseResultRenderer::Render(ResultRenderType render_mode, const string &val) { - switch (render_mode) { - case ResultRenderType::LAYOUT: - RenderLayout(val); - break; - case ResultRenderType::COLUMN_NAME: - RenderColumnName(val); - break; - case ResultRenderType::COLUMN_TYPE: - RenderType(val); - break; - case ResultRenderType::VALUE: - RenderValue(val, value_type); - break; - case ResultRenderType::NULL_VALUE: - RenderNull(val, value_type); - break; - case ResultRenderType::FOOTER: - RenderFooter(val); - break; - default: - throw InternalException("Unsupported type for result renderer"); - } -} - -void BaseResultRenderer::SetValueType(const LogicalType &type) { - value_type = type; -} - -void StringResultRenderer::RenderLayout(const string &text) { - result += text; -} - -void StringResultRenderer::RenderColumnName(const string &text) { - result += text; -} - -void StringResultRenderer::RenderType(const string &text) { - result += text; -} - -void StringResultRenderer::RenderValue(const string &text, const LogicalType &type) { - result += text; -} - -void StringResultRenderer::RenderNull(const string &text, const LogicalType &type) { - result += text; -} - -void StringResultRenderer::RenderFooter(const string &text) { - result += text; -} - -const string &StringResultRenderer::str() { - return result; -} - -//===--------------------------------------------------------------------===// -// Box Renderer -//===--------------------------------------------------------------------===// -BoxRenderer::BoxRenderer(BoxRendererConfig config_p) : config(std::move(config_p)) { -} - -string BoxRenderer::ToString(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - StringResultRenderer ss; - Render(context, names, result, ss); - return ss.str(); -} - -void BoxRenderer::Print(ClientContext &context, const vector &names, const ColumnDataCollection &result) { - Printer::Print(ToString(context, names, result)); -} - -void BoxRenderer::RenderValue(BaseResultRenderer &ss, const string &value, idx_t column_width, - ResultRenderType render_mode, ValueRenderAlignment alignment) { - auto render_width = Utf8Proc::RenderWidth(value); - - const string *render_value = &value; - string small_value; - if (render_width > column_width) { - // the string is too large to fit in this column! - // the size of this column must have been reduced - // figure out how much of this value we can render - idx_t pos = 0; - idx_t current_render_width = config.DOTDOTDOT_LENGTH; - while (pos < value.size()) { - // check if this character fits... - auto char_size = Utf8Proc::RenderWidth(value.c_str(), value.size(), pos); - if (current_render_width + char_size >= column_width) { - // it doesn't! stop - break; - } - // it does! 
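// (Note: the loop advances one grapheme cluster at a time and measures
// render width in terminal cells, so a multi-cell character is either kept
// whole or dropped whole; the width of the "..." suffix is reserved up
// front via DOTDOTDOT_LENGTH, so the truncated value still fits the column.)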
move to the next character - current_render_width += char_size; - pos = Utf8Proc::NextGraphemeCluster(value.c_str(), value.size(), pos); - } - small_value = value.substr(0, pos) + config.DOTDOTDOT; - render_value = &small_value; - render_width = current_render_width; - } - auto padding_count = (column_width - render_width) + 2; - idx_t lpadding; - idx_t rpadding; - switch (alignment) { - case ValueRenderAlignment::LEFT: - lpadding = 1; - rpadding = padding_count - 1; - break; - case ValueRenderAlignment::MIDDLE: - lpadding = padding_count / 2; - rpadding = padding_count - lpadding; - break; - case ValueRenderAlignment::RIGHT: - lpadding = padding_count - 1; - rpadding = 1; - break; - default: - throw InternalException("Unrecognized value renderer alignment"); - } - ss << config.VERTICAL; - ss << string(lpadding, ' '); - ss.Render(render_mode, *render_value); - ss << string(rpadding, ' '); -} - -string BoxRenderer::RenderType(const LogicalType &type) { - if (type.HasAlias()) { - return StringUtil::Lower(type.ToString()); - } - switch (type.id()) { - case LogicalTypeId::TINYINT: - return "int8"; - case LogicalTypeId::SMALLINT: - return "int16"; - case LogicalTypeId::INTEGER: - return "int32"; - case LogicalTypeId::BIGINT: - return "int64"; - case LogicalTypeId::HUGEINT: - return "int128"; - case LogicalTypeId::UTINYINT: - return "uint8"; - case LogicalTypeId::USMALLINT: - return "uint16"; - case LogicalTypeId::UINTEGER: - return "uint32"; - case LogicalTypeId::UBIGINT: - return "uint64"; - case LogicalTypeId::UHUGEINT: - return "uint128"; - case LogicalTypeId::LIST: { - auto child = RenderType(ListType::GetChildType(type)); - return child + "[]"; - } - default: - return StringUtil::Lower(type.ToString()); - } -} - -ValueRenderAlignment BoxRenderer::TypeAlignment(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return ValueRenderAlignment::RIGHT; - default: - return ValueRenderAlignment::LEFT; - } -} - -string BoxRenderer::TryFormatLargeNumber(const string &numeric) { - // we only return a readable rendering if the number is > 1 million - if (numeric.size() <= 5) { - // number too small for sure - return string(); - } - // get the number to summarize - idx_t number = 0; - bool negative = false; - idx_t i = 0; - if (numeric[0] == '-') { - negative = true; - i++; - } - for (; i < numeric.size(); i++) { - char c = numeric[i]; - if (c == '.') { - break; - } - if (c < '0' || c > '9') { - // not a number or something funky (e.g. 
1.23e7) - // we could theoretically summarize numbers with exponents - return string(); - } - if (number >= 1000000000000000000ULL) { - // number too big - return string(); - } - number = number * 10 + static_cast(c - '0'); - } - struct UnitBase { - idx_t base; - const char *name; - }; - static constexpr idx_t BASE_COUNT = 5; - UnitBase bases[] = {{1000000ULL, "million"}, - {1000000000ULL, "billion"}, - {1000000000000ULL, "trillion"}, - {1000000000000000ULL, "quadrillion"}, - {1000000000000000000ULL, "quintillion"}}; - idx_t base = 0; - string unit; - for (idx_t i = 0; i < BASE_COUNT; i++) { - // round the number according to this base - idx_t rounded_number = number + ((bases[i].base / 100ULL) / 2); - if (rounded_number >= bases[i].base) { - base = bases[i].base; - unit = bases[i].name; - } - } - if (unit.empty()) { - return string(); - } - number += (base / 100ULL) / 2; - idx_t decimal_unit = number / (base / 100ULL); - string decimal_str = to_string(decimal_unit); - string result; - if (negative) { - result += "-"; - } - result += decimal_str.substr(0, decimal_str.size() - 2); - result += config.decimal_separator == '\0' ? '.' : config.decimal_separator; - result += decimal_str.substr(decimal_str.size() - 2, 2); - result += " "; - result += unit; - return result; -} - -list BoxRenderer::FetchRenderCollections(ClientContext &context, - const ColumnDataCollection &result, idx_t top_rows, - idx_t bottom_rows) { - auto column_count = result.ColumnCount(); - vector varchar_types; - for (idx_t c = 0; c < column_count; c++) { - varchar_types.emplace_back(LogicalType::VARCHAR); - } - std::list collections; - collections.emplace_back(context, varchar_types); - collections.emplace_back(context, varchar_types); - - auto &top_collection = collections.front(); - auto &bottom_collection = collections.back(); - - DataChunk fetch_result; - fetch_result.Initialize(context, result.Types()); - - DataChunk insert_result; - insert_result.Initialize(context, varchar_types); - - if (config.large_number_rendering == LargeNumberRendering::FOOTER) { - if (config.render_mode != RenderMode::ROWS || result.Count() != 1) { - // large number footer can only be constructed (1) if we have a single row, and (2) in ROWS mode - config.large_number_rendering = LargeNumberRendering::NONE; - } - } - - // fetch the top rows from the ColumnDataCollection - idx_t chunk_idx = 0; - idx_t row_idx = 0; - while (row_idx < top_rows) { - fetch_result.Reset(); - insert_result.Reset(); - // fetch the next chunk - result.FetchChunk(chunk_idx, fetch_result); - idx_t insert_count = MinValue(fetch_result.size(), top_rows - row_idx); - - // cast all columns to varchar - for (idx_t c = 0; c < column_count; c++) { - VectorOperations::Cast(context, fetch_result.data[c], insert_result.data[c], insert_count); - } - insert_result.SetCardinality(insert_count); - - // construct the render collection - top_collection.Append(insert_result); - - // if we have are constructing a footer - if (config.large_number_rendering == LargeNumberRendering::FOOTER) { - D_ASSERT(insert_count == 1); - vector readable_numbers; - readable_numbers.resize(column_count); - bool all_readable = true; - for (idx_t c = 0; c < column_count; c++) { - if (!result.Types()[c].IsNumeric()) { - // not a numeric type - cannot summarize - all_readable = false; - break; - } - // add a readable rendering of the value (i.e. 
"1234567" becomes "1.23 million") - // we only add the rendering if the string is big - auto numeric_val = insert_result.data[c].GetValue(0).ToString(); - readable_numbers[c] = TryFormatLargeNumber(numeric_val); - if (readable_numbers[c].empty()) { - all_readable = false; - break; - } - readable_numbers[c] = "(" + readable_numbers[c] + ")"; - } - insert_result.Reset(); - if (all_readable) { - for (idx_t c = 0; c < column_count; c++) { - insert_result.data[c].SetValue(0, Value(readable_numbers[c])); - } - insert_result.SetCardinality(1); - top_collection.Append(insert_result); - } - } - - chunk_idx++; - row_idx += fetch_result.size(); - } - - // fetch the bottom rows from the ColumnDataCollection - row_idx = 0; - chunk_idx = result.ChunkCount() - 1; - while (row_idx < bottom_rows) { - fetch_result.Reset(); - insert_result.Reset(); - // fetch the next chunk - result.FetchChunk(chunk_idx, fetch_result); - idx_t insert_count = MinValue(fetch_result.size(), bottom_rows - row_idx); - - // invert the rows - SelectionVector inverted_sel(insert_count); - for (idx_t r = 0; r < insert_count; r++) { - inverted_sel.set_index(r, fetch_result.size() - r - 1); - } - - for (idx_t c = 0; c < column_count; c++) { - Vector slice(fetch_result.data[c], inverted_sel, insert_count); - VectorOperations::Cast(context, slice, insert_result.data[c], insert_count); - } - insert_result.SetCardinality(insert_count); - // construct the render collection - bottom_collection.Append(insert_result); - - chunk_idx--; - row_idx += fetch_result.size(); - } - return collections; -} - -list BoxRenderer::PivotCollections(ClientContext &context, list input, - vector &column_names, - vector &result_types, idx_t row_count) { - auto &top = input.front(); - auto &bottom = input.back(); - - vector varchar_types; - vector new_names; - new_names.emplace_back("Column"); - new_names.emplace_back("Type"); - varchar_types.emplace_back(LogicalType::VARCHAR); - varchar_types.emplace_back(LogicalType::VARCHAR); - for (idx_t r = 0; r < top.Count(); r++) { - new_names.emplace_back("Row " + to_string(r + 1)); - varchar_types.emplace_back(LogicalType::VARCHAR); - } - for (idx_t r = 0; r < bottom.Count(); r++) { - auto row_index = row_count - bottom.Count() + r + 1; - new_names.emplace_back("Row " + to_string(row_index)); - varchar_types.emplace_back(LogicalType::VARCHAR); - } - // - DataChunk row_chunk; - row_chunk.Initialize(Allocator::DefaultAllocator(), varchar_types); - std::list result; - result.emplace_back(context, varchar_types); - result.emplace_back(context, varchar_types); - auto &res_coll = result.front(); - ColumnDataAppendState append_state; - res_coll.InitializeAppend(append_state); - for (idx_t c = 0; c < top.ColumnCount(); c++) { - vector column_ids {c}; - auto row_index = row_chunk.size(); - idx_t current_index = 0; - row_chunk.SetValue(current_index++, row_index, column_names[c]); - row_chunk.SetValue(current_index++, row_index, RenderType(result_types[c])); - for (auto &collection : input) { - for (auto &chunk : collection.Chunks(column_ids)) { - for (idx_t r = 0; r < chunk.size(); r++) { - row_chunk.SetValue(current_index++, row_index, chunk.GetValue(0, r)); - } - } - } - row_chunk.SetCardinality(row_chunk.size() + 1); - if (row_chunk.size() == STANDARD_VECTOR_SIZE || c + 1 == top.ColumnCount()) { - res_coll.Append(append_state, row_chunk); - row_chunk.Reset(); - } - } - column_names = std::move(new_names); - result_types = std::move(varchar_types); - return result; -} - -string BoxRenderer::ConvertRenderValue(const string 
&input) { - string result; - result.reserve(input.size()); - for (idx_t c = 0; c < input.size(); c++) { - data_t byte_value = const_data_ptr_cast(input.c_str())[c]; - if (byte_value < 32) { - // ASCII control character - result += "\\"; - switch (input[c]) { - case 7: - // bell - result += 'a'; - break; - case 8: - // backspace - result += 'b'; - break; - case 9: - // tab - result += 't'; - break; - case 10: - // newline - result += 'n'; - break; - case 11: - // vertical tab - result += 'v'; - break; - case 12: - // form feed - result += 'f'; - break; - case 13: - // cariage return - result += 'r'; - break; - case 27: - // escape - result += 'e'; - break; - default: - result += to_string(byte_value); - break; - } - } else { - result += input[c]; - } - } - return result; -} - -string BoxRenderer::FormatNumber(const string &input) { - if (config.large_number_rendering == LargeNumberRendering::ALL) { - // when large number rendering is set to ALL, we try to format all numbers as large numbers - auto number = TryFormatLargeNumber(input); - if (!number.empty()) { - return number; - } - } - if (config.decimal_separator == '\0' && config.thousand_separator == '\0') { - // no thousand separator - return input; - } - // first check how many digits there are (preceding any decimal point) - idx_t character_count = 0; - for (auto c : input) { - if (!StringUtil::CharacterIsDigit(c)) { - break; - } - character_count++; - } - // find the position of the first thousand separator - idx_t separator_position = character_count % 3 == 0 ? 3 : character_count % 3; - // now add the thousand separators - string result; - for (idx_t c = 0; c < character_count; c++) { - if (c == separator_position && config.thousand_separator != '\0') { - result += config.thousand_separator; - separator_position += 3; - } - result += input[c]; - } - // add any remaining characters - for (idx_t c = character_count; c < input.size(); c++) { - if (input[c] == '.' 
&& config.decimal_separator != '\0') { - result += config.decimal_separator; - } else { - result += input[c]; - } - } - return result; -} - -string BoxRenderer::ConvertRenderValue(const string &input, const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return FormatNumber(input); - default: - return ConvertRenderValue(input); - } -} - -string BoxRenderer::GetRenderValue(BaseResultRenderer &ss, ColumnDataRowCollection &rows, idx_t c, idx_t r, - const LogicalType &type, ResultRenderType &render_mode) { - try { - render_mode = ResultRenderType::VALUE; - ss.SetValueType(type); - auto row = rows.GetValue(c, r); - if (row.IsNull()) { - render_mode = ResultRenderType::NULL_VALUE; - return config.null_value; - } - return ConvertRenderValue(StringValue::Get(row), type); - } catch (std::exception &ex) { - return "????INVALID VALUE - " + string(ex.what()) + "?????"; - } -} - -vector BoxRenderer::ComputeRenderWidths(const vector &names, const vector &result_types, - list &collections, idx_t min_width, - idx_t max_width, vector &column_map, idx_t &total_length) { - auto column_count = result_types.size(); - - vector widths; - widths.reserve(column_count); - for (idx_t c = 0; c < column_count; c++) { - auto name_width = Utf8Proc::RenderWidth(ConvertRenderValue(names[c])); - auto type_width = Utf8Proc::RenderWidth(RenderType(result_types[c])); - widths.push_back(MaxValue(name_width, type_width)); - } - - // now iterate over the data in the render collection and find out the true max width - for (auto &collection : collections) { - for (auto &chunk : collection.Chunks()) { - for (idx_t c = 0; c < column_count; c++) { - auto string_data = FlatVector::GetData(chunk.data[c]); - for (idx_t r = 0; r < chunk.size(); r++) { - string render_value; - if (FlatVector::IsNull(chunk.data[c], r)) { - render_value = config.null_value; - } else { - render_value = ConvertRenderValue(string_data[r].GetString(), result_types[c]); - } - auto render_width = Utf8Proc::RenderWidth(render_value); - widths[c] = MaxValue(render_width, widths[c]); - } - } - } - } - - // figure out the total length - // we start off with a pipe (|) - total_length = 1; - for (idx_t c = 0; c < widths.size(); c++) { - // each column has a space at the beginning, and a space plus a pipe (|) at the end - // hence + 3 - total_length += widths[c] + 3; - } - if (total_length < min_width) { - // if there are hidden rows we should always display that - // stretch up the first column until we have space to show the row count - widths[0] += min_width - total_length; - total_length = min_width; - } - // now we need to constrain the length - unordered_set pruned_columns; - if (total_length > max_width) { - // before we remove columns, check if we can just reduce the size of columns - for (auto &w : widths) { - if (w > config.max_col_width) { - auto max_diff = w - config.max_col_width; - if (total_length - max_diff <= max_width) { - // if we reduce the size of this column we fit within the limits! 
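// (Worked example with hypothetical numbers: max_width = 80, total_length =
// 90, and one column of width w = 50 with config.max_col_width = 30. Then
// max_diff = 20 and total_length - max_diff = 70 <= 80, so shrinking this
// one column suffices: it is cut by exactly total_length - max_width = 10
// cells, leaving w = 40 and total_length = 80.)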
- // reduce the width exactly enough so that the box fits - w -= total_length - max_width; - total_length = max_width; - break; - } else { - // reducing the width of this column does not make the result fit - // reduce the column width by the maximum amount anyway - w = config.max_col_width; - total_length -= max_diff; - } - } - } - - if (total_length > max_width) { - // the total length is still too large - // we need to remove columns! - // first, we add 6 characters to the total length - // this is what we need to add the "..." in the middle - total_length += 3 + config.DOTDOTDOT_LENGTH; - // now select columns to prune - // we select columns in zig-zag order starting from the middle - // e.g. if we have 10 columns, we remove #5, then #4, then #6, then #3, then #7, etc - int64_t offset = 0; - while (total_length > max_width) { - auto c = NumericCast(NumericCast(column_count) / 2 + offset); - total_length -= widths[c] + 3; - pruned_columns.insert(c); - if (offset >= 0) { - offset = -offset - 1; - } else { - offset = -offset; - } - } - } - } - - bool added_split_column = false; - vector new_widths; - for (idx_t c = 0; c < column_count; c++) { - if (pruned_columns.find(c) == pruned_columns.end()) { - column_map.push_back(c); - new_widths.push_back(widths[c]); - } else { - if (!added_split_column) { - // "..." - column_map.push_back(SPLIT_COLUMN); - new_widths.push_back(config.DOTDOTDOT_LENGTH); - added_split_column = true; - } - } - } - return new_widths; -} - -void BoxRenderer::RenderHeader(const vector &names, const vector &result_types, - const vector &column_map, const vector &widths, - const vector &boundaries, idx_t total_length, bool has_results, - BaseResultRenderer &ss) { - auto column_count = column_map.size(); - // render the top line - ss << config.LTCORNER; - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < column_count && k == boundaries[column_index]) { - ss << config.TMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << config.RTCORNER; - ss << '\n'; - - // render the header names - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string name; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - render_mode = ResultRenderType::LAYOUT; - name = config.DOTDOTDOT; - } else { - render_mode = ResultRenderType::COLUMN_NAME; - name = ConvertRenderValue(names[column_idx]); - } - RenderValue(ss, name, widths[c], render_mode); - } - ss << config.VERTICAL; - ss << '\n'; - - // render the types - if (config.render_mode == RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string type; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - render_mode = ResultRenderType::LAYOUT; - } else { - render_mode = ResultRenderType::COLUMN_TYPE; - type = RenderType(result_types[column_idx]); - } - RenderValue(ss, type, widths[c], render_mode); - } - ss << config.VERTICAL; - ss << '\n'; - } - - // render the line under the header - ss << config.LMIDDLE; - column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < column_count && k == boundaries[column_index]) { - ss << (has_results ? 
config.MIDDLE : config.DMIDDLE); - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << config.RMIDDLE; - ss << '\n'; -} - -void BoxRenderer::RenderValues(const list &collections, const vector &column_map, - const vector &widths, const vector &result_types, - BaseResultRenderer &ss) { - auto &top_collection = collections.front(); - auto &bottom_collection = collections.back(); - // render the top rows - auto top_rows = top_collection.Count(); - auto bottom_rows = bottom_collection.Count(); - auto column_count = column_map.size(); - - bool large_number_footer = config.large_number_rendering == LargeNumberRendering::FOOTER; - vector alignments; - if (config.render_mode == RenderMode::ROWS) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - if (column_idx == SPLIT_COLUMN) { - alignments.push_back(ValueRenderAlignment::MIDDLE); - } else if (large_number_footer && result_types[column_idx].IsNumeric()) { - alignments.push_back(ValueRenderAlignment::MIDDLE); - } else { - alignments.push_back(TypeAlignment(result_types[column_idx])); - } - } - } - - auto rows = top_collection.GetRows(); - for (idx_t r = 0; r < top_rows; r++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - render_mode = ResultRenderType::LAYOUT; - } else { - str = GetRenderValue(ss, rows, column_idx, r, result_types[column_idx], render_mode); - } - ValueRenderAlignment alignment; - if (config.render_mode == RenderMode::ROWS) { - alignment = alignments[c]; - if (large_number_footer && r == 1) { - // render readable numbers with highlighting of a NULL value - render_mode = ResultRenderType::NULL_VALUE; - } - } else { - switch (c) { - case 0: - render_mode = ResultRenderType::COLUMN_NAME; - break; - case 1: - render_mode = ResultRenderType::COLUMN_TYPE; - break; - default: - render_mode = ResultRenderType::VALUE; - break; - } - if (c < 2) { - alignment = ValueRenderAlignment::LEFT; - } else if (c == SPLIT_COLUMN) { - alignment = ValueRenderAlignment::MIDDLE; - } else { - alignment = ValueRenderAlignment::RIGHT; - } - } - RenderValue(ss, str, widths[c], render_mode, alignment); - } - ss << config.VERTICAL; - ss << '\n'; - } - - if (bottom_rows > 0) { - if (config.render_mode == RenderMode::COLUMNS) { - throw InternalException("Columns render mode does not support bottom rows"); - } - // render the bottom rows - // first render the divider - auto brows = bottom_collection.GetRows(); - for (idx_t k = 0; k < 3; k++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - auto alignment = alignments[c]; - if (alignment == ValueRenderAlignment::MIDDLE || column_idx == SPLIT_COLUMN) { - str = config.DOT; - } else { - // align the dots in the center of the column - ResultRenderType render_mode; - auto top_value = - GetRenderValue(ss, rows, column_idx, top_rows - 1, result_types[column_idx], render_mode); - auto bottom_value = - GetRenderValue(ss, brows, column_idx, bottom_rows - 1, result_types[column_idx], render_mode); - auto top_length = MinValue(widths[c], Utf8Proc::RenderWidth(top_value)); - auto bottom_length = MinValue(widths[c], Utf8Proc::RenderWidth(bottom_value)); - auto dot_length = MinValue(top_length, bottom_length); - if (top_length == 0) { - dot_length = bottom_length; - } else if (bottom_length == 0) { - dot_length = top_length; - } - if (dot_length > 1) { - auto padding = dot_length - 1; - 
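// (For instance, if the narrower of the two adjacent values renders 5 cells
// wide, dot_length = 5 and padding = 4; a RIGHT-aligned column then gets
// right_padding = 2 and left_padding = 2, centering the dot under the span
// the values occupy.)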
idx_t left_padding, right_padding; - switch (alignment) { - case ValueRenderAlignment::LEFT: - left_padding = padding / 2; - right_padding = padding - left_padding; - break; - case ValueRenderAlignment::RIGHT: - right_padding = padding / 2; - left_padding = padding - right_padding; - break; - default: - throw InternalException("Unrecognized value renderer alignment"); - } - str = string(left_padding, ' ') + config.DOT + string(right_padding, ' '); - } else { - if (dot_length == 0) { - // everything is empty - alignment = ValueRenderAlignment::MIDDLE; - } - str = config.DOT; - } - } - RenderValue(ss, str, widths[c], ResultRenderType::LAYOUT, alignment); - } - ss << config.VERTICAL; - ss << '\n'; - } - // note that the bottom rows are in reverse order - for (idx_t r = 0; r < bottom_rows; r++) { - for (idx_t c = 0; c < column_count; c++) { - auto column_idx = column_map[c]; - string str; - ResultRenderType render_mode; - if (column_idx == SPLIT_COLUMN) { - str = config.DOTDOTDOT; - render_mode = ResultRenderType::LAYOUT; - } else { - str = GetRenderValue(ss, brows, column_idx, bottom_rows - r - 1, result_types[column_idx], - render_mode); - } - RenderValue(ss, str, widths[c], render_mode, alignments[c]); - } - ss << config.VERTICAL; - ss << '\n'; - } - } -} - -void BoxRenderer::RenderRowCount(string row_count_str, string shown_str, const string &column_count_str, - const vector &boundaries, bool has_hidden_rows, bool has_hidden_columns, - idx_t total_length, idx_t row_count, idx_t column_count, idx_t minimum_row_length, - BaseResultRenderer &ss) { - // check if we can merge the row_count_str and the shown_str - bool display_shown_separately = has_hidden_rows; - if (has_hidden_rows && total_length >= row_count_str.size() + shown_str.size() + 5) { - // we can! - row_count_str += " " + shown_str; - shown_str = string(); - display_shown_separately = false; - minimum_row_length = row_count_str.size() + 4; - } - auto minimum_length = row_count_str.size() + column_count_str.size() + 6; - bool render_rows_and_columns = total_length >= minimum_length && - ((has_hidden_columns && row_count > 0) || (row_count >= 10 && column_count > 1)); - bool render_rows = total_length >= minimum_row_length && (row_count == 0 || row_count >= 10); - bool render_anything = true; - if (!render_rows && !render_rows_and_columns) { - render_anything = false; - } - // render the bottom of the result values, if there are any - if (row_count > 0) { - ss << (render_anything ? config.LMIDDLE : config.LDCORNER); - idx_t column_index = 0; - for (idx_t k = 0; k < total_length - 2; k++) { - if (column_index + 1 < boundaries.size() && k == boundaries[column_index]) { - ss << config.DMIDDLE; - column_index++; - } else { - ss << config.HORIZONTAL; - } - } - ss << (render_anything ? 
config.RMIDDLE : config.RDCORNER); - ss << '\n'; - } - if (!render_anything) { - return; - } - - if (render_rows_and_columns) { - ss << config.VERTICAL; - ss << " "; - ss.Render(ResultRenderType::FOOTER, row_count_str); - ss << string(total_length - row_count_str.size() - column_count_str.size() - 4, ' '); - ss.Render(ResultRenderType::FOOTER, column_count_str); - ss << " "; - ss << config.VERTICAL; - ss << '\n'; - } else if (render_rows) { - RenderValue(ss, row_count_str, total_length - 4, ResultRenderType::FOOTER); - ss << config.VERTICAL; - ss << '\n'; - - if (display_shown_separately) { - RenderValue(ss, shown_str, total_length - 4, ResultRenderType::FOOTER); - ss << config.VERTICAL; - ss << '\n'; - } - } - // render the bottom line - ss << config.LDCORNER; - for (idx_t k = 0; k < total_length - 2; k++) { - ss << config.HORIZONTAL; - } - ss << config.RDCORNER; - ss << '\n'; -} - -void BoxRenderer::Render(ClientContext &context, const vector &names, const ColumnDataCollection &result, - BaseResultRenderer &ss) { - if (result.ColumnCount() != names.size()) { - throw InternalException("Error in BoxRenderer::Render - unaligned columns and names"); - } - auto max_width = config.max_width; - if (max_width == 0) { - if (Printer::IsTerminal(OutputStream::STREAM_STDOUT)) { - max_width = Printer::TerminalWidth(); - } else { - max_width = 120; - } - } - // we do not support max widths under 80 - max_width = MaxValue(80, max_width); - - // figure out how many/which rows to render - idx_t row_count = result.Count(); - idx_t rows_to_render = MinValue(row_count, config.max_rows); - if (row_count <= config.max_rows + 3) { - // hiding rows adds 3 extra rows - // so hiding rows makes no sense if we are only slightly over the limit - // if we are 1 row over the limit hiding rows will actually increase the number of lines we display! - // in this case render all the rows - rows_to_render = row_count; - } - idx_t top_rows; - idx_t bottom_rows; - if (rows_to_render == row_count) { - top_rows = row_count; - bottom_rows = 0; - } else { - top_rows = rows_to_render / 2 + (rows_to_render % 2 != 0 ? 1 : 0); - bottom_rows = rows_to_render - top_rows; - } - auto row_count_str = to_string(row_count) + " rows"; - bool has_limited_rows = config.limit > 0 && row_count == config.limit; - if (has_limited_rows) { - row_count_str = "? rows"; - } - string shown_str; - bool has_hidden_rows = top_rows < row_count; - if (has_hidden_rows) { - shown_str = "("; - if (has_limited_rows) { - shown_str += ">" + to_string(config.limit - 1) + " rows, "; - } - shown_str += to_string(top_rows + bottom_rows) + " shown)"; - } - auto minimum_row_length = MaxValue(row_count_str.size(), shown_str.size()) + 4; - - // fetch the top and bottom render collections from the result - auto collections = FetchRenderCollections(context, result, top_rows, bottom_rows); - auto column_names = names; - auto result_types = result.Types(); - if (config.render_mode == RenderMode::COLUMNS) { - collections = PivotCollections(context, std::move(collections), column_names, result_types, row_count); - } - - // for each column, figure out the width - // start off by figuring out the name of the header by looking at the column name and column type - idx_t min_width = has_hidden_rows || row_count == 0 ? 
minimum_row_length : 0; - vector column_map; - idx_t total_length; - auto widths = - ComputeRenderWidths(column_names, result_types, collections, min_width, max_width, column_map, total_length); - - // render boundaries for the individual columns - vector boundaries; - for (idx_t c = 0; c < widths.size(); c++) { - idx_t render_boundary; - if (c == 0) { - render_boundary = widths[c] + 2; - } else { - render_boundary = boundaries[c - 1] + widths[c] + 3; - } - boundaries.push_back(render_boundary); - } - - // now begin rendering - // first render the header - RenderHeader(column_names, result_types, column_map, widths, boundaries, total_length, row_count > 0, ss); - - // render the values, if there are any - RenderValues(collections, column_map, widths, result_types, ss); - - // render the row count and column count - auto column_count_str = to_string(result.ColumnCount()) + " column"; - if (result.ColumnCount() > 1) { - column_count_str += "s"; - } - bool has_hidden_columns = false; - for (auto entry : column_map) { - if (entry == SPLIT_COLUMN) { - has_hidden_columns = true; - break; - } - } - idx_t column_count = column_map.size(); - if (config.render_mode == RenderMode::COLUMNS) { - if (has_hidden_columns) { - has_hidden_rows = true; - shown_str = " (" + to_string(column_count - 3) + " shown)"; - } else { - shown_str = string(); - } - } else { - if (has_hidden_columns) { - column_count--; - column_count_str += " (" + to_string(column_count) + " shown)"; - } - } - - RenderRowCount(std::move(row_count_str), std::move(shown_str), column_count_str, boundaries, has_hidden_rows, - has_hidden_columns, total_length, row_count, column_count, minimum_row_length, ss); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/cgroups.cpp b/src/duckdb/src/common/cgroups.cpp deleted file mode 100644 index d7ddd980c..000000000 --- a/src/duckdb/src/common/cgroups.cpp +++ /dev/null @@ -1,186 +0,0 @@ -#include "duckdb/common/cgroups.hpp" - -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -#include - -namespace duckdb { - -optional_idx CGroups::GetMemoryLimit(FileSystem &fs) { - // First, try cgroup v2 - auto cgroup_v2_limit = GetCGroupV2MemoryLimit(fs); - if (cgroup_v2_limit.IsValid()) { - return cgroup_v2_limit; - } - - // If cgroup v2 fails, try cgroup v1 - return GetCGroupV1MemoryLimit(fs); -} - -optional_idx CGroups::GetCGroupV2MemoryLimit(FileSystem &fs) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - const char *cgroup_self = "/proc/self/cgroup"; - const char *memory_max = "/sys/fs/cgroup/%s/memory.max"; - - if (!fs.FileExists(cgroup_self)) { - return optional_idx(); - } - - string cgroup_path = ReadCGroupPath(fs, cgroup_self); - if (cgroup_path.empty()) { - return optional_idx(); - } - - char memory_max_path[256]; - snprintf(memory_max_path, sizeof(memory_max_path), memory_max, cgroup_path.c_str()); - - if (!fs.FileExists(memory_max_path)) { - return optional_idx(); - } - - return ReadCGroupValue(fs, memory_max_path); -#else - return optional_idx(); -#endif -} - -optional_idx CGroups::GetCGroupV1MemoryLimit(FileSystem &fs) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - const char *cgroup_self = "/proc/self/cgroup"; - const char *memory_limit = "/sys/fs/cgroup/memory/%s/memory.limit_in_bytes"; - - if (!fs.FileExists(cgroup_self)) { - return optional_idx(); - } - - string memory_cgroup_path = 
ReadMemoryCGroupPath(fs, cgroup_self); - if (memory_cgroup_path.empty()) { - return optional_idx(); - } - - char memory_limit_path[256]; - snprintf(memory_limit_path, sizeof(memory_limit_path), memory_limit, memory_cgroup_path.c_str()); - - if (!fs.FileExists(memory_limit_path)) { - return optional_idx(); - } - - return ReadCGroupValue(fs, memory_limit_path); -#else - return optional_idx(); -#endif -} - -string CGroups::ReadCGroupPath(FileSystem &fs, const char *cgroup_file) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - auto handle = fs.OpenFile(cgroup_file, FileFlags::FILE_FLAGS_READ); - char buffer[1024]; - auto bytes_read = fs.Read(*handle, buffer, sizeof(buffer) - 1); - buffer[bytes_read] = '\0'; - - // For cgroup v2, we're looking for a single line with "0::/path" - string content(buffer); - auto pos = content.find("::"); - if (pos != string::npos) { - // remove trailing \n - auto pos2 = content.find('\n', pos + 2); - if (pos2 != string::npos) { - return content.substr(pos + 2, pos2 - (pos + 2)); - } else { - return content.substr(pos + 2); - } - } -#endif - return ""; -} - -string CGroups::ReadMemoryCGroupPath(FileSystem &fs, const char *cgroup_file) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - auto handle = fs.OpenFile(cgroup_file, FileFlags::FILE_FLAGS_READ); - char buffer[1024]; - auto bytes_read = fs.Read(*handle, buffer, sizeof(buffer) - 1); - buffer[bytes_read] = '\0'; - - // For cgroup v1, we're looking for a line with "memory:/path" - string content(buffer); - size_t pos = 0; - string line; - while ((pos = content.find('\n')) != string::npos) { - line = content.substr(0, pos); - if (line.find("memory:") == 0) { - return line.substr(line.find(':') + 1); - } - content.erase(0, pos + 1); - } -#endif - return ""; -} - -optional_idx CGroups::ReadCGroupValue(FileSystem &fs, const char *file_path) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ); - char buffer[100]; - auto bytes_read = fs.Read(*handle, buffer, 99); - buffer[bytes_read] = '\0'; - - idx_t value; - if (TryCast::Operation(string_t(buffer), value)) { - return optional_idx(value); - } -#endif - return optional_idx(); -} - -idx_t CGroups::GetCPULimit(FileSystem &fs, idx_t physical_cores) { -#if defined(__linux__) && !defined(DUCKDB_WASM) - static constexpr const char *cpu_max = "/sys/fs/cgroup/cpu.max"; - static constexpr const char *cfs_quota = "/sys/fs/cgroup/cpu/cpu.cfs_quota_us"; - static constexpr const char *cfs_period = "/sys/fs/cgroup/cpu/cpu.cfs_period_us"; - - int64_t quota, period; - char byte_buffer[1000]; - unique_ptr handle; - int64_t read_bytes; - - if (fs.FileExists(cpu_max)) { - // cgroup v2 - handle = fs.OpenFile(cpu_max, FileFlags::FILE_FLAGS_READ); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 " %" SCNd64 "", "a, &period) != 2) { - return physical_cores; - } - } else if (fs.FileExists(cfs_quota) && fs.FileExists(cfs_period)) { - // cgroup v1 - handle = fs.OpenFile(cfs_quota, FileFlags::FILE_FLAGS_READ); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 "", "a) != 1) { - return physical_cores; - } - - handle = fs.OpenFile(cfs_period, FileFlags::FILE_FLAGS_READ); - read_bytes = fs.Read(*handle, (void *)byte_buffer, 999); - byte_buffer[read_bytes] = '\0'; - if (std::sscanf(byte_buffer, "%" SCNd64 "", &period) != 1) { - return physical_cores; - } - } 
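// (Worked example, hypothetical file contents: a cgroup v2 cpu.max of
// "200000 100000" parses to quota = 200000 and period = 100000, which the
// arithmetic below turns into ceil(200000 / 100000) = 2 usable cores; an
// unlimited "max 100000" fails the two-field sscanf above and falls back
// to the physical core count.)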
else { - // No cgroup quota - return physical_cores; - } - if (quota > 0 && period > 0) { - return idx_t(std::ceil((double)quota / (double)period)); - } else { - return physical_cores; - } -#else - return physical_cores; -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/checksum.cpp b/src/duckdb/src/common/checksum.cpp deleted file mode 100644 index 2fbca299e..000000000 --- a/src/duckdb/src/common/checksum.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "duckdb/common/checksum.hpp" -#include "duckdb/common/types/hash.hpp" - -namespace duckdb { - -hash_t Checksum(uint64_t x) { - return x * UINT64_C(0xbf58476d1ce4e5b9); -} - -uint64_t Checksum(uint8_t *buffer, size_t size) { - uint64_t result = 5381; - uint64_t *ptr = reinterpret_cast(buffer); - size_t i; - // for efficiency, we first checksum uint64_t values - for (i = 0; i < size / 8; i++) { - result ^= Checksum(ptr[i]); - } - if (size - i * 8 > 0) { - // the remaining 0-7 bytes we hash using a string hash - result ^= Hash(buffer + i * 8, size - i * 8); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/compressed_file_system.cpp b/src/duckdb/src/common/compressed_file_system.cpp deleted file mode 100644 index ddc325cf1..000000000 --- a/src/duckdb/src/common/compressed_file_system.cpp +++ /dev/null @@ -1,156 +0,0 @@ -#include "duckdb/common/compressed_file_system.hpp" -#include "duckdb/common/numeric_utils.hpp" - -namespace duckdb { - -StreamWrapper::~StreamWrapper() { -} - -CompressedFile::CompressedFile(CompressedFileSystem &fs, unique_ptr child_handle_p, const string &path) - : FileHandle(fs, path, child_handle_p->GetFlags()), compressed_fs(fs), child_handle(std::move(child_handle_p)) { -} - -CompressedFile::~CompressedFile() { - CompressedFile::Close(); -} - -void CompressedFile::Initialize(bool write) { - Close(); - - this->write = write; - stream_data.in_buf_size = compressed_fs.InBufferSize(); - stream_data.out_buf_size = compressed_fs.OutBufferSize(); - stream_data.in_buff = make_unsafe_uniq_array(stream_data.in_buf_size); - stream_data.in_buff_start = stream_data.in_buff.get(); - stream_data.in_buff_end = stream_data.in_buff.get(); - stream_data.out_buff = make_unsafe_uniq_array(stream_data.out_buf_size); - stream_data.out_buff_start = stream_data.out_buff.get(); - stream_data.out_buff_end = stream_data.out_buff.get(); - - stream_wrapper = compressed_fs.CreateStream(); - stream_wrapper->Initialize(*this, write); -} - -idx_t CompressedFile::GetProgress() { - return current_position; -} - -int64_t CompressedFile::ReadData(void *buffer, int64_t remaining) { - idx_t total_read = 0; - while (true) { - // first check if there are input bytes available in the output buffers - if (stream_data.out_buff_start != stream_data.out_buff_end) { - // there is! copy it into the output buffer - auto available = - MinValue(UnsafeNumericCast(remaining), - UnsafeNumericCast(stream_data.out_buff_end - stream_data.out_buff_start)); - memcpy(data_ptr_t(buffer) + total_read, stream_data.out_buff_start, available); - - // increment the total read variables as required - stream_data.out_buff_start += available; - total_read += available; - remaining = UnsafeNumericCast(UnsafeNumericCast(remaining) - available); - if (remaining == 0) { - // done! 
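// (That is, the request was satisfied entirely from bytes previously
// decompressed into out_buff; the refill-and-decompress steps further down
// only run once the output buffer has been drained.)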
read enough - return UnsafeNumericCast(total_read); - } - } - if (!stream_wrapper) { - return UnsafeNumericCast(total_read); - } - current_position += static_cast(stream_data.in_buff_end - stream_data.in_buff_start); - // ran out of buffer: read more data from the child stream - stream_data.out_buff_start = stream_data.out_buff.get(); - stream_data.out_buff_end = stream_data.out_buff.get(); - D_ASSERT(stream_data.in_buff_start <= stream_data.in_buff_end); - D_ASSERT(stream_data.in_buff_end <= stream_data.in_buff_start + stream_data.in_buf_size); - - // read more input when requested and still data in the input stream - if (stream_data.refresh && (stream_data.in_buff_end == stream_data.in_buff.get() + stream_data.in_buf_size)) { - auto bufrem = stream_data.in_buff_end - stream_data.in_buff_start; - // buffer not empty, move remaining bytes to the beginning - memmove(stream_data.in_buff.get(), stream_data.in_buff_start, UnsafeNumericCast(bufrem)); - stream_data.in_buff_start = stream_data.in_buff.get(); - // refill the rest of input buffer - auto sz = child_handle->Read(stream_data.in_buff_start + bufrem, - stream_data.in_buf_size - UnsafeNumericCast(bufrem)); - stream_data.in_buff_end = stream_data.in_buff_start + bufrem + sz; - if (sz <= 0) { - stream_wrapper.reset(); - break; - } - } - - // read more input if none available - if (stream_data.in_buff_start == stream_data.in_buff_end) { - // empty input buffer: refill from the start - stream_data.in_buff_start = stream_data.in_buff.get(); - stream_data.in_buff_end = stream_data.in_buff_start; - auto sz = child_handle->Read(stream_data.in_buff.get(), stream_data.in_buf_size); - if (sz <= 0) { - stream_wrapper.reset(); - break; - } - stream_data.in_buff_end = stream_data.in_buff_start + sz; - } - - auto finished = stream_wrapper->Read(stream_data); - if (finished) { - stream_wrapper.reset(); - } - } - return UnsafeNumericCast(total_read); -} - -int64_t CompressedFile::WriteData(data_ptr_t buffer, int64_t nr_bytes) { - stream_wrapper->Write(*this, stream_data, buffer, nr_bytes); - return nr_bytes; -} - -void CompressedFile::Close() { - if (stream_wrapper) { - stream_wrapper->Close(); - stream_wrapper.reset(); - } - stream_data.in_buff.reset(); - stream_data.out_buff.reset(); - stream_data.out_buff_start = nullptr; - stream_data.out_buff_end = nullptr; - stream_data.in_buff_start = nullptr; - stream_data.in_buff_end = nullptr; - stream_data.in_buf_size = 0; - stream_data.out_buf_size = 0; - stream_data.refresh = false; -} - -int64_t CompressedFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &compressed_file = handle.Cast(); - return compressed_file.ReadData(buffer, nr_bytes); -} - -int64_t CompressedFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &compressed_file = handle.Cast(); - return compressed_file.WriteData(data_ptr_cast(buffer), nr_bytes); -} - -void CompressedFileSystem::Reset(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - compressed_file.child_handle->Reset(); - compressed_file.Initialize(compressed_file.write); -} - -int64_t CompressedFileSystem::GetFileSize(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - return NumericCast(compressed_file.child_handle->GetFileSize()); -} - -bool CompressedFileSystem::OnDiskFile(FileHandle &handle) { - auto &compressed_file = handle.Cast(); - return compressed_file.child_handle->OnDiskFile(); -} - -bool CompressedFileSystem::CanSeek() { - return false; -} - -} // namespace duckdb diff --git 
diff --git a/src/duckdb/src/common/constants.cpp b/src/duckdb/src/common/constants.cpp
deleted file mode 100644
index edafe6b67..000000000
--- a/src/duckdb/src/common/constants.cpp
+++ /dev/null
@@ -1,59 +0,0 @@
-#include "duckdb/common/constants.hpp"
-#include "duckdb/common/exception.hpp"
-
-#include "duckdb/common/limits.hpp"
-#include "duckdb/common/vector_size.hpp"
-
-namespace duckdb {
-
-constexpr const idx_t DConstants::INVALID_INDEX;
-const row_t MAX_ROW_ID = 36028797018960000ULL;       // 2^55
-const row_t MAX_ROW_ID_LOCAL = 72057594037920000ULL; // 2^56
-const column_t COLUMN_IDENTIFIER_ROW_ID = (column_t)-1;
-const double PI = 3.141592653589793;
-
-const transaction_t TRANSACTION_ID_START = 4611686018427388000ULL;                // 2^62
-const transaction_t MAX_TRANSACTION_ID = NumericLimits<transaction_t>::Maximum(); // 2^63
-const transaction_t NOT_DELETED_ID = NumericLimits<transaction_t>::Maximum() - 1; // 2^64 - 1
-const transaction_t MAXIMUM_QUERY_ID = NumericLimits<transaction_t>::Maximum();   // 2^64
-
-bool IsPowerOfTwo(uint64_t v) {
-	return (v & (v - 1)) == 0;
-}
-
-uint64_t NextPowerOfTwo(uint64_t v) {
-	auto v_in = v;
-	if (v < 1) { // this is not strictly right but we seem to rely on it in places
-		return 2;
-	}
-	v--;
-	v |= v >> 1;
-	v |= v >> 2;
-	v |= v >> 4;
-	v |= v >> 8;
-	v |= v >> 16;
-	v |= v >> 32;
-	v++;
-	if (v == 0) {
-		throw OutOfRangeException("Can't find next power of 2 for %llu", v_in);
-	}
-	return v;
-}
-
-uint64_t PreviousPowerOfTwo(uint64_t v) {
-	return NextPowerOfTwo((v / 2) + 1);
-}
-
-bool IsInvalidSchema(const string &str) {
-	return str.empty();
-}
-
-bool IsInvalidCatalog(const string &str) {
-	return str.empty();
-}
-
-bool IsRowIdColumnId(column_t column_id) {
-	return column_id == COLUMN_IDENTIFIER_ROW_ID;
-}
-
-} // namespace duckdb
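NextPowerOfTwo rounds up by smearing the highest set bit into every lower position and then adding one. A few spot checks, written against local copies of the three helpers so the sketch stands alone:

#include <cassert>
#include <cstdint>

// local copies of the helpers above, for a quick sanity check
static bool IsPowerOfTwo(uint64_t v) {
	return (v & (v - 1)) == 0;
}
static uint64_t NextPowerOfTwo(uint64_t v) {
	if (v < 1) {
		return 2; // mirrors the deleted code's special case for 0
	}
	v--;
	v |= v >> 1;
	v |= v >> 2;
	v |= v >> 4;
	v |= v >> 8;
	v |= v >> 16;
	v |= v >> 32;
	return v + 1;
}
static uint64_t PreviousPowerOfTwo(uint64_t v) {
	return NextPowerOfTwo((v / 2) + 1);
}

int main() {
	assert(NextPowerOfTwo(5) == 8);
	assert(NextPowerOfTwo(8) == 8);    // already a power of two: unchanged
	assert(NextPowerOfTwo(0) == 2);    // the "not strictly right" case the comment mentions
	assert(PreviousPowerOfTwo(9) == 8);
	assert(IsPowerOfTwo(0));           // quirk: 0 passes the (v & (v - 1)) == 0 test
	return 0;
}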
diff --git a/src/duckdb/src/common/crypto/md5.cpp b/src/duckdb/src/common/crypto/md5.cpp
deleted file mode 100644
index d0d679c32..000000000
--- a/src/duckdb/src/common/crypto/md5.cpp
+++ /dev/null
@@ -1,257 +0,0 @@
-/*
-** This code taken from the SQLite test library. Originally found on
-** the internet. The original header comment follows this comment.
-** The code is largely unchanged, but there have been some modifications.
-*/
-/*
- * This code implements the MD5 message-digest algorithm.
- * The algorithm is due to Ron Rivest. This code was
- * written by Colin Plumb in 1993, no copyright is claimed.
- * This code is in the public domain; do with it what you wish.
- *
- * Equivalent code is available from RSA Data Security, Inc.
- * This code has been tested against that, and is equivalent,
- * except that you don't need to include two pages of legalese
- * with every copy.
- *
- * To compute the message digest of a chunk of bytes, declare an
- * MD5Context structure, pass it to MD5Init, call MD5Update as
- * needed on buffers full of bytes, and then call MD5Final, which
- * will fill a supplied 16-byte array with the digest.
- */
-#include "duckdb/common/crypto/md5.hpp"
-#include "mbedtls_wrapper.hpp"
-
-namespace duckdb {
-
-/*
- * Note: this code is harmless on little-endian machines.
- */
-static void ByteReverse(unsigned char *buf, unsigned longs) {
-	uint32_t t;
-	do {
-		t = (uint32_t)((unsigned)buf[3] << 8 | buf[2]) << 16 | ((unsigned)buf[1] << 8 | buf[0]);
-		*reinterpret_cast<uint32_t *>(buf) = t;
-		buf += 4;
-	} while (--longs);
-}
-/* The four core functions - F1 is optimized somewhat */
-
-/* #define F1(x, y, z) (x & y | ~x & z) */
-#define F1(x, y, z) ((z) ^ ((x) & ((y) ^ (z))))
-#define F2(x, y, z) F1(z, x, y)
-#define F3(x, y, z) ((x) ^ (y) ^ (z))
-#define F4(x, y, z) ((y) ^ ((x) | ~(z)))
-
-/* This is the central step in the MD5 algorithm. */
-#define MD5STEP(f, w, x, y, z, data, s) ((w) += f(x, y, z) + (data), (w) = (w) << (s) | (w) >> (32 - (s)), (w) += (x))
-
-/*
- * The core of the MD5 algorithm, this alters an existing MD5 hash to
- * reflect the addition of 16 longwords of new data. MD5Update blocks
- * the data and converts bytes into longwords for this routine.
- */
-static void MD5Transform(uint32_t buf[4], const uint32_t in[16]) {
-	uint32_t a, b, c, d;
-
-	a = buf[0];
-	b = buf[1];
-	c = buf[2];
-	d = buf[3];
-
-	MD5STEP(F1, a, b, c, d, in[0] + 0xd76aa478, 7);
-	MD5STEP(F1, d, a, b, c, in[1] + 0xe8c7b756, 12);
-	MD5STEP(F1, c, d, a, b, in[2] + 0x242070db, 17);
-	MD5STEP(F1, b, c, d, a, in[3] + 0xc1bdceee, 22);
-	MD5STEP(F1, a, b, c, d, in[4] + 0xf57c0faf, 7);
-	MD5STEP(F1, d, a, b, c, in[5] + 0x4787c62a, 12);
-	MD5STEP(F1, c, d, a, b, in[6] + 0xa8304613, 17);
-	MD5STEP(F1, b, c, d, a, in[7] + 0xfd469501, 22);
-	MD5STEP(F1, a, b, c, d, in[8] + 0x698098d8, 7);
-	MD5STEP(F1, d, a, b, c, in[9] + 0x8b44f7af, 12);
-	MD5STEP(F1, c, d, a, b, in[10] + 0xffff5bb1, 17);
-	MD5STEP(F1, b, c, d, a, in[11] + 0x895cd7be, 22);
-	MD5STEP(F1, a, b, c, d, in[12] + 0x6b901122, 7);
-	MD5STEP(F1, d, a, b, c, in[13] + 0xfd987193, 12);
-	MD5STEP(F1, c, d, a, b, in[14] + 0xa679438e, 17);
-	MD5STEP(F1, b, c, d, a, in[15] + 0x49b40821, 22);
-
-	MD5STEP(F2, a, b, c, d, in[1] + 0xf61e2562, 5);
-	MD5STEP(F2, d, a, b, c, in[6] + 0xc040b340, 9);
-	MD5STEP(F2, c, d, a, b, in[11] + 0x265e5a51, 14);
-	MD5STEP(F2, b, c, d, a, in[0] + 0xe9b6c7aa, 20);
-	MD5STEP(F2, a, b, c, d, in[5] + 0xd62f105d, 5);
-	MD5STEP(F2, d, a, b, c, in[10] + 0x02441453, 9);
-	MD5STEP(F2, c, d, a, b, in[15] + 0xd8a1e681, 14);
-	MD5STEP(F2, b, c, d, a, in[4] + 0xe7d3fbc8, 20);
-	MD5STEP(F2, a, b, c, d, in[9] + 0x21e1cde6, 5);
-	MD5STEP(F2, d, a, b, c, in[14] + 0xc33707d6, 9);
-	MD5STEP(F2, c, d, a, b, in[3] + 0xf4d50d87, 14);
-	MD5STEP(F2, b, c, d, a, in[8] + 0x455a14ed, 20);
-	MD5STEP(F2, a, b, c, d, in[13] + 0xa9e3e905, 5);
-	MD5STEP(F2, d, a, b, c, in[2] + 0xfcefa3f8, 9);
-	MD5STEP(F2, c, d, a, b, in[7] + 0x676f02d9, 14);
-	MD5STEP(F2, b, c, d, a, in[12] + 0x8d2a4c8a, 20);
-
-	MD5STEP(F3, a, b, c, d, in[5] + 0xfffa3942, 4);
-	MD5STEP(F3, d, a, b, c, in[8] + 0x8771f681, 11);
-	MD5STEP(F3, c, d, a, b, in[11] + 0x6d9d6122, 16);
-	MD5STEP(F3, b, c, d, a, in[14] + 0xfde5380c, 23);
-	MD5STEP(F3, a, b, c, d, in[1] + 0xa4beea44, 4);
-	MD5STEP(F3, d, a, b, c, in[4] + 0x4bdecfa9, 11);
-	MD5STEP(F3, c, d, a, b, in[7] + 0xf6bb4b60, 16);
-	MD5STEP(F3, b, c, d, a, in[10] + 0xbebfbc70, 23);
-	MD5STEP(F3, a, b, c, d, in[13] + 0x289b7ec6, 4);
-	MD5STEP(F3, d, a, b, c, in[0] + 0xeaa127fa, 11);
-	MD5STEP(F3, c, d, a, b, in[3] + 0xd4ef3085, 16);
-	MD5STEP(F3, b, c, d, a, in[6] + 0x04881d05, 23);
-	MD5STEP(F3, a, b, c, d, in[9] + 0xd9d4d039, 4);
-	MD5STEP(F3, d, a, b, c, in[12] + 0xe6db99e5, 11);
-	MD5STEP(F3, c, d, a, b, in[15] + 0x1fa27cf8, 16);
-	MD5STEP(F3, b, c, d, a, in[2] + 0xc4ac5665, 23);
-
-	MD5STEP(F4, a, b, c, d, in[0] + 0xf4292244, 6);
-	MD5STEP(F4, d, a, b, c, in[7] + 0x432aff97, 10);
-	MD5STEP(F4, c, d, a, b, in[14] + 0xab9423a7, 15);
-	MD5STEP(F4, b, c, d, a, in[5] + 0xfc93a039, 21);
-	MD5STEP(F4, a, b, c, d, in[12] + 0x655b59c3, 6);
-	MD5STEP(F4, d, a, b, c, in[3] + 0x8f0ccc92, 10);
-	MD5STEP(F4, c, d, a, b, in[10] + 0xffeff47d, 15);
-	MD5STEP(F4, b, c, d, a, in[1] + 0x85845dd1, 21);
-	MD5STEP(F4, a, b, c, d, in[8] + 0x6fa87e4f, 6);
-	MD5STEP(F4, d, a, b, c, in[15] + 0xfe2ce6e0, 10);
-	MD5STEP(F4, c, d, a, b, in[6] + 0xa3014314, 15);
-	MD5STEP(F4, b, c, d, a, in[13] + 0x4e0811a1, 21);
-	MD5STEP(F4, a, b, c, d, in[4] + 0xf7537e82, 6);
-	MD5STEP(F4, d, a, b, c, in[11] + 0xbd3af235, 10);
-	MD5STEP(F4, c, d, a, b, in[2] + 0x2ad7d2bb, 15);
-	MD5STEP(F4, b, c, d, a, in[9] + 0xeb86d391, 21);
-
-	buf[0] += a;
-	buf[1] += b;
-	buf[2] += c;
-	buf[3] += d;
-}
-
-/*
- * Start MD5 accumulation. Set bit count to 0 and buffer to mysterious
- * initialization constants.
- */
-MD5Context::MD5Context() {
-	buf[0] = 0x67452301;
-	buf[1] = 0xefcdab89;
-	buf[2] = 0x98badcfe;
-	buf[3] = 0x10325476;
-	bits[0] = 0;
-	bits[1] = 0;
-}
-
-/*
- * Update context to reflect the concatenation of another buffer full
- * of bytes.
- */
-void MD5Context::MD5Update(const_data_ptr_t input, idx_t len) {
-	uint32_t t;
-
-	/* Update bitcount */
-
-	t = bits[0];
-	bits[0] = t + ((uint32_t)len << 3);
-	if (bits[0] < t) {
-		bits[1]++; /* Carry from low to high */
-	}
-	bits[1] += len >> 29;
-
-	t = (t >> 3) & 0x3f; /* Bytes already in shsInfo->data */
-
-	/* Handle any leading odd-sized chunks */
-
-	if (t) {
-		unsigned char *p = (unsigned char *)in + t;
-
-		t = 64 - t;
-		if (len < t) {
-			memcpy(p, input, len);
-			return;
-		}
-		memcpy(p, input, t);
-		ByteReverse(in, 16);
-		MD5Transform(buf, reinterpret_cast<uint32_t *>(in));
-		input += t;
-		len -= t;
-	}
-
-	/* Process data in 64-byte chunks */
-
-	while (len >= 64) {
-		memcpy(in, input, 64);
-		ByteReverse(in, 16);
-		MD5Transform(buf, reinterpret_cast<uint32_t *>(in));
-		input += 64;
-		len -= 64;
-	}
-
-	/* Handle any remaining bytes of data. */
-	memcpy(in, input, len);
-}
-
-/*
- * Final wrapup - pad to 64-byte boundary with the bit pattern
- * 1 0* (64-bit count of bits processed, MSB-first)
- */
-void MD5Context::Finish(data_ptr_t out_digest) {
-	unsigned count;
-	unsigned char *p;
-
-	/* Compute number of bytes mod 64 */
-	count = (bits[0] >> 3) & 0x3F;
-
-	/* Set the first char of padding to 0x80. This is safe since there is
-	   always at least one byte free */
-	p = in + count;
-	*p++ = 0x80;
-
-	/* Bytes of padding needed to make 64 bytes */
-	count = 64 - 1 - count;
-
-	/* Pad out to 56 mod 64 */
-	if (count < 8) {
-		/* Two lots of padding: Pad the first block to 64 bytes */
-		memset(p, 0, count);
-		ByteReverse(in, 16);
-		MD5Transform(buf, reinterpret_cast<uint32_t *>(in));
-
-		/* Now fill the next block with 56 bytes */
-		memset(in, 0, 56);
-	} else {
-		/* Pad block to 56 bytes */
-		memset(p, 0, count - 8);
-	}
-	ByteReverse(in, 14);
-
-	/* Append length in bits and transform */
-	(reinterpret_cast<uint32_t *>(in))[14] = bits[0];
-	(reinterpret_cast<uint32_t *>(in))[15] = bits[1];
-
-	MD5Transform(buf, reinterpret_cast<uint32_t *>(in));
-	ByteReverse(reinterpret_cast<unsigned char *>(buf), 4);
-	memcpy(out_digest, buf, 16);
-}
-
-void MD5Context::FinishHex(char *out_digest) {
-	data_t digest[MD5_HASH_LENGTH_BINARY];
-	Finish(digest);
-	duckdb_mbedtls::MbedTlsWrapper::ToBase16(reinterpret_cast<char *>(digest), out_digest, MD5_HASH_LENGTH_BINARY);
-}
-
-string MD5Context::FinishHex() {
-	char digest[MD5_HASH_LENGTH_TEXT];
-	FinishHex(digest);
-	return string(digest, MD5_HASH_LENGTH_TEXT);
-}
-
-void MD5Context::Add(const char *data) {
-	MD5Update(const_data_ptr_cast(data), strlen(data));
-}
-
-} // namespace duckdb
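The incremental interface is construct, then Add (or MD5Update) as many times as needed, then FinishHex. A minimal check against the well-known RFC 1321 test vector for "abc", assuming only the md5.hpp header that accompanied this deleted file:

#include "duckdb/common/crypto/md5.hpp" // the header deleted alongside this file
#include <cassert>

int main() {
	duckdb::MD5Context ctx;
	ctx.Add("abc");             // feeds strlen("abc") bytes through MD5Update
	auto hex = ctx.FinishHex(); // pads, appends the 64-bit bit count, hex-encodes
	assert(hex == "900150983cd24fb0d6963f7d28e17f72"); // RFC 1321 test vector for "abc"
	return 0;
}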
diff --git a/src/duckdb/src/common/encryption_state.cpp b/src/duckdb/src/common/encryption_state.cpp
deleted file mode 100644
index b6343b630..000000000
--- a/src/duckdb/src/common/encryption_state.cpp
+++ /dev/null
@@ -1,38 +0,0 @@
-#include "duckdb/common/encryption_state.hpp"
-
-namespace duckdb {
-
-EncryptionState::EncryptionState() {
-	// abstract class, no implementation needed
-}
-
-EncryptionState::~EncryptionState() {
-}
-
-bool EncryptionState::IsOpenSSL() {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-void EncryptionState::InitializeEncryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, const std::string *key) {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-void EncryptionState::InitializeDecryption(duckdb::const_data_ptr_t iv, duckdb::idx_t iv_len, const std::string *key) {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-size_t EncryptionState::Process(duckdb::const_data_ptr_t in, duckdb::idx_t in_len, duckdb::data_ptr_t out,
-                                duckdb::idx_t out_len) {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-size_t EncryptionState::Finalize(duckdb::data_ptr_t out, duckdb::idx_t out_len, duckdb::data_ptr_t tag,
-                                 duckdb::idx_t tag_len) {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-void EncryptionState::GenerateRandomData(duckdb::data_ptr_t data, duckdb::idx_t len) {
-	throw NotImplementedException("EncryptionState Abstract Class is called");
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/enum_util.cpp b/src/duckdb/src/common/enum_util.cpp
deleted file mode 100644
index 1fd91919b..000000000
--- a/src/duckdb/src/common/enum_util.cpp
+++ /dev/null
@@ -1,4240 +0,0 @@
-//-------------------------------------------------------------------------
-// This file is automatically generated by scripts/generate_enum_util.py
-// Do not edit this file manually, your changes will be overwritten
-// If you want to exclude an enum from serialization, add it to the blacklist in the script
-//
-// Note: The generated code will only work properly if the enum is a top level item in the duckdb namespace
-// If the enum is nested in a class, or in another
namespace, the generated code will not compile. -// You should move the enum to the duckdb namespace, manually write a specialization or add it to the blacklist -//------------------------------------------------------------------------- - - -#include "duckdb/common/enum_util.hpp" -#include "duckdb/catalog/catalog_entry/dependency/dependency_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_column_type.hpp" -#include "duckdb/common/box_renderer.hpp" -#include "duckdb/common/enums/access_mode.hpp" -#include "duckdb/common/enums/aggregate_handling.hpp" -#include "duckdb/common/enums/catalog_lookup_behavior.hpp" -#include "duckdb/common/enums/catalog_type.hpp" -#include "duckdb/common/enums/compression_type.hpp" -#include "duckdb/common/enums/copy_overwrite_mode.hpp" -#include "duckdb/common/enums/cte_materialize.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/enums/debug_initialize.hpp" -#include "duckdb/common/enums/destroy_buffer_upon.hpp" -#include "duckdb/common/enums/explain_format.hpp" -#include "duckdb/common/enums/expression_type.hpp" -#include "duckdb/common/enums/file_compression_type.hpp" -#include "duckdb/common/enums/file_glob_options.hpp" -#include "duckdb/common/enums/filter_propagate_result.hpp" -#include "duckdb/common/enums/function_errors.hpp" -#include "duckdb/common/enums/index_constraint_type.hpp" -#include "duckdb/common/enums/join_type.hpp" -#include "duckdb/common/enums/joinref_type.hpp" -#include "duckdb/common/enums/logical_operator_type.hpp" -#include "duckdb/common/enums/memory_tag.hpp" -#include "duckdb/common/enums/metric_type.hpp" -#include "duckdb/common/enums/on_create_conflict.hpp" -#include "duckdb/common/enums/on_entry_not_found.hpp" -#include "duckdb/common/enums/operator_result_type.hpp" -#include "duckdb/common/enums/optimizer_type.hpp" -#include "duckdb/common/enums/order_preservation_type.hpp" -#include "duckdb/common/enums/order_type.hpp" -#include "duckdb/common/enums/output_type.hpp" -#include "duckdb/common/enums/pending_execution_result.hpp" -#include "duckdb/common/enums/physical_operator_type.hpp" -#include "duckdb/common/enums/prepared_statement_mode.hpp" -#include "duckdb/common/enums/profiler_format.hpp" -#include "duckdb/common/enums/quantile_enum.hpp" -#include "duckdb/common/enums/relation_type.hpp" -#include "duckdb/common/enums/scan_options.hpp" -#include "duckdb/common/enums/set_operation_type.hpp" -#include "duckdb/common/enums/set_scope.hpp" -#include "duckdb/common/enums/set_type.hpp" -#include "duckdb/common/enums/statement_type.hpp" -#include "duckdb/common/enums/stream_execution_result.hpp" -#include "duckdb/common/enums/subquery_type.hpp" -#include "duckdb/common/enums/tableref_type.hpp" -#include "duckdb/common/enums/undo_flags.hpp" -#include "duckdb/common/enums/vector_type.hpp" -#include "duckdb/common/enums/wal_type.hpp" -#include "duckdb/common/enums/window_aggregation_mode.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/exception_format_value.hpp" -#include "duckdb/common/extra_type_info.hpp" -#include "duckdb/common/file_buffer.hpp" -#include "duckdb/common/file_open_flags.hpp" -#include "duckdb/common/multi_file_list.hpp" -#include "duckdb/common/operator/decimal_cast_operators.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/column/column_data_scan_states.hpp" -#include "duckdb/common/types/column/partitioned_column_data.hpp" -#include 
"duckdb/common/types/conflict_manager.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/hyperloglog.hpp" -#include "duckdb/common/types/row/partitioned_tuple_data.hpp" -#include "duckdb/common/types/row/tuple_data_states.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types/vector_buffer.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/index/art/node.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_option.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_state.hpp" -#include "duckdb/execution/operator/csv_scanner/quote_rules.hpp" -#include "duckdb/execution/reservoir_sample.hpp" -#include "duckdb/function/aggregate_state.hpp" -#include "duckdb/function/compression_function.hpp" -#include "duckdb/function/copy_function.hpp" -#include "duckdb/function/function.hpp" -#include "duckdb/function/macro_function.hpp" -#include "duckdb/function/partition_stats.hpp" -#include "duckdb/function/scalar/compressed_materialization_utils.hpp" -#include "duckdb/function/scalar/strftime_format.hpp" -#include "duckdb/function/table/arrow/enum/arrow_datetime_type.hpp" -#include "duckdb/function/table/arrow/enum/arrow_type_info_type.hpp" -#include "duckdb/function/table/arrow/enum/arrow_variable_size_type.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/appender.hpp" -#include "duckdb/main/capi/capi_internal.hpp" -#include "duckdb/main/client_properties.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/error_manager.hpp" -#include "duckdb/main/extension.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/main/extension_install_info.hpp" -#include "duckdb/main/query_result.hpp" -#include "duckdb/main/secret/secret.hpp" -#include "duckdb/main/settings.hpp" -#include "duckdb/parallel/interrupt.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/task.hpp" -#include "duckdb/parser/constraint.hpp" -#include "duckdb/parser/expression/parameter_expression.hpp" -#include "duckdb/parser/expression/window_expression.hpp" -#include "duckdb/parser/parsed_data/alter_info.hpp" -#include "duckdb/parser/parsed_data/alter_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/alter_table_function_info.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/parser/parsed_data/create_sequence_info.hpp" -#include "duckdb/parser/parsed_data/extra_drop_info.hpp" -#include "duckdb/parser/parsed_data/load_info.hpp" -#include "duckdb/parser/parsed_data/parse_info.hpp" -#include "duckdb/parser/parsed_data/pragma_info.hpp" -#include "duckdb/parser/parsed_data/sample_options.hpp" -#include "duckdb/parser/parsed_data/transaction_info.hpp" -#include "duckdb/parser/parser_extension.hpp" -#include "duckdb/parser/query_node.hpp" -#include "duckdb/parser/result_modifier.hpp" -#include "duckdb/parser/simplified_token.hpp" -#include "duckdb/parser/statement/copy_statement.hpp" -#include "duckdb/parser/statement/explain_statement.hpp" -#include "duckdb/parser/statement/insert_statement.hpp" -#include "duckdb/parser/tableref/showref.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/storage/buffer/block_handle.hpp" -#include "duckdb/storage/compression/bitpacking.hpp" -#include "duckdb/storage/magic_bytes.hpp" -#include "duckdb/storage/statistics/base_statistics.hpp" -#include 
"duckdb/storage/table/chunk_info.hpp" -#include "duckdb/storage/table/column_segment.hpp" -#include "duckdb/storage/temporary_file_manager.hpp" -#include "duckdb/verification/statement_verifier.hpp" - -namespace duckdb { - -const StringUtil::EnumStringLiteral *GetARTAppendModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ARTAppendMode::DEFAULT), "DEFAULT" }, - { static_cast(ARTAppendMode::IGNORE_DUPLICATES), "IGNORE_DUPLICATES" }, - { static_cast(ARTAppendMode::INSERT_DUPLICATES), "INSERT_DUPLICATES" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ARTAppendMode value) { - return StringUtil::EnumToString(GetARTAppendModeValues(), 3, "ARTAppendMode", static_cast(value)); -} - -template<> -ARTAppendMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetARTAppendModeValues(), 3, "ARTAppendMode", value)); -} - -const StringUtil::EnumStringLiteral *GetARTConflictTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ARTConflictType::NO_CONFLICT), "NO_CONFLICT" }, - { static_cast(ARTConflictType::CONSTRAINT), "CONSTRAINT" }, - { static_cast(ARTConflictType::TRANSACTION), "TRANSACTION" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ARTConflictType value) { - return StringUtil::EnumToString(GetARTConflictTypeValues(), 3, "ARTConflictType", static_cast(value)); -} - -template<> -ARTConflictType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetARTConflictTypeValues(), 3, "ARTConflictType", value)); -} - -const StringUtil::EnumStringLiteral *GetAccessModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AccessMode::UNDEFINED), "UNDEFINED" }, - { static_cast(AccessMode::AUTOMATIC), "AUTOMATIC" }, - { static_cast(AccessMode::READ_ONLY), "READ_ONLY" }, - { static_cast(AccessMode::READ_WRITE), "READ_WRITE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AccessMode value) { - return StringUtil::EnumToString(GetAccessModeValues(), 4, "AccessMode", static_cast(value)); -} - -template<> -AccessMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAccessModeValues(), 4, "AccessMode", value)); -} - -const StringUtil::EnumStringLiteral *GetAggregateCombineTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AggregateCombineType::PRESERVE_INPUT), "PRESERVE_INPUT" }, - { static_cast(AggregateCombineType::ALLOW_DESTRUCTIVE), "ALLOW_DESTRUCTIVE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AggregateCombineType value) { - return StringUtil::EnumToString(GetAggregateCombineTypeValues(), 2, "AggregateCombineType", static_cast(value)); -} - -template<> -AggregateCombineType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAggregateCombineTypeValues(), 2, "AggregateCombineType", value)); -} - -const StringUtil::EnumStringLiteral *GetAggregateDistinctDependentValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AggregateDistinctDependent::DISTINCT_DEPENDENT), "DISTINCT_DEPENDENT" }, - { static_cast(AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT), "NOT_DISTINCT_DEPENDENT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AggregateDistinctDependent value) { - return StringUtil::EnumToString(GetAggregateDistinctDependentValues(), 2, 
"AggregateDistinctDependent", static_cast(value)); -} - -template<> -AggregateDistinctDependent EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAggregateDistinctDependentValues(), 2, "AggregateDistinctDependent", value)); -} - -const StringUtil::EnumStringLiteral *GetAggregateHandlingValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AggregateHandling::STANDARD_HANDLING), "STANDARD_HANDLING" }, - { static_cast(AggregateHandling::NO_AGGREGATES_ALLOWED), "NO_AGGREGATES_ALLOWED" }, - { static_cast(AggregateHandling::FORCE_AGGREGATES), "FORCE_AGGREGATES" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AggregateHandling value) { - return StringUtil::EnumToString(GetAggregateHandlingValues(), 3, "AggregateHandling", static_cast(value)); -} - -template<> -AggregateHandling EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAggregateHandlingValues(), 3, "AggregateHandling", value)); -} - -const StringUtil::EnumStringLiteral *GetAggregateOrderDependentValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AggregateOrderDependent::ORDER_DEPENDENT), "ORDER_DEPENDENT" }, - { static_cast(AggregateOrderDependent::NOT_ORDER_DEPENDENT), "NOT_ORDER_DEPENDENT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AggregateOrderDependent value) { - return StringUtil::EnumToString(GetAggregateOrderDependentValues(), 2, "AggregateOrderDependent", static_cast(value)); -} - -template<> -AggregateOrderDependent EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAggregateOrderDependentValues(), 2, "AggregateOrderDependent", value)); -} - -const StringUtil::EnumStringLiteral *GetAggregateTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AggregateType::NON_DISTINCT), "NON_DISTINCT" }, - { static_cast(AggregateType::DISTINCT), "DISTINCT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AggregateType value) { - return StringUtil::EnumToString(GetAggregateTypeValues(), 2, "AggregateType", static_cast(value)); -} - -template<> -AggregateType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAggregateTypeValues(), 2, "AggregateType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterForeignKeyTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterForeignKeyType::AFT_ADD), "AFT_ADD" }, - { static_cast(AlterForeignKeyType::AFT_DELETE), "AFT_DELETE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterForeignKeyType value) { - return StringUtil::EnumToString(GetAlterForeignKeyTypeValues(), 2, "AlterForeignKeyType", static_cast(value)); -} - -template<> -AlterForeignKeyType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterForeignKeyTypeValues(), 2, "AlterForeignKeyType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterScalarFunctionTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterScalarFunctionType::INVALID), "INVALID" }, - { static_cast(AlterScalarFunctionType::ADD_FUNCTION_OVERLOADS), "ADD_FUNCTION_OVERLOADS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterScalarFunctionType value) { - return StringUtil::EnumToString(GetAlterScalarFunctionTypeValues(), 2, "AlterScalarFunctionType", 
static_cast(value)); -} - -template<> -AlterScalarFunctionType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterScalarFunctionTypeValues(), 2, "AlterScalarFunctionType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterTableFunctionTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterTableFunctionType::INVALID), "INVALID" }, - { static_cast(AlterTableFunctionType::ADD_FUNCTION_OVERLOADS), "ADD_FUNCTION_OVERLOADS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterTableFunctionType value) { - return StringUtil::EnumToString(GetAlterTableFunctionTypeValues(), 2, "AlterTableFunctionType", static_cast(value)); -} - -template<> -AlterTableFunctionType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterTableFunctionTypeValues(), 2, "AlterTableFunctionType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterTableTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterTableType::INVALID), "INVALID" }, - { static_cast(AlterTableType::RENAME_COLUMN), "RENAME_COLUMN" }, - { static_cast(AlterTableType::RENAME_TABLE), "RENAME_TABLE" }, - { static_cast(AlterTableType::ADD_COLUMN), "ADD_COLUMN" }, - { static_cast(AlterTableType::REMOVE_COLUMN), "REMOVE_COLUMN" }, - { static_cast(AlterTableType::ALTER_COLUMN_TYPE), "ALTER_COLUMN_TYPE" }, - { static_cast(AlterTableType::SET_DEFAULT), "SET_DEFAULT" }, - { static_cast(AlterTableType::FOREIGN_KEY_CONSTRAINT), "FOREIGN_KEY_CONSTRAINT" }, - { static_cast(AlterTableType::SET_NOT_NULL), "SET_NOT_NULL" }, - { static_cast(AlterTableType::DROP_NOT_NULL), "DROP_NOT_NULL" }, - { static_cast(AlterTableType::SET_COLUMN_COMMENT), "SET_COLUMN_COMMENT" }, - { static_cast(AlterTableType::ADD_CONSTRAINT), "ADD_CONSTRAINT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterTableType value) { - return StringUtil::EnumToString(GetAlterTableTypeValues(), 12, "AlterTableType", static_cast(value)); -} - -template<> -AlterTableType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterTableTypeValues(), 12, "AlterTableType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterType::INVALID), "INVALID" }, - { static_cast(AlterType::ALTER_TABLE), "ALTER_TABLE" }, - { static_cast(AlterType::ALTER_VIEW), "ALTER_VIEW" }, - { static_cast(AlterType::ALTER_SEQUENCE), "ALTER_SEQUENCE" }, - { static_cast(AlterType::CHANGE_OWNERSHIP), "CHANGE_OWNERSHIP" }, - { static_cast(AlterType::ALTER_SCALAR_FUNCTION), "ALTER_SCALAR_FUNCTION" }, - { static_cast(AlterType::ALTER_TABLE_FUNCTION), "ALTER_TABLE_FUNCTION" }, - { static_cast(AlterType::SET_COMMENT), "SET_COMMENT" }, - { static_cast(AlterType::SET_COLUMN_COMMENT), "SET_COLUMN_COMMENT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterType value) { - return StringUtil::EnumToString(GetAlterTypeValues(), 9, "AlterType", static_cast(value)); -} - -template<> -AlterType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterTypeValues(), 9, "AlterType", value)); -} - -const StringUtil::EnumStringLiteral *GetAlterViewTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AlterViewType::INVALID), "INVALID" }, - { static_cast(AlterViewType::RENAME_VIEW), 
"RENAME_VIEW" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AlterViewType value) { - return StringUtil::EnumToString(GetAlterViewTypeValues(), 2, "AlterViewType", static_cast(value)); -} - -template<> -AlterViewType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAlterViewTypeValues(), 2, "AlterViewType", value)); -} - -const StringUtil::EnumStringLiteral *GetAppenderTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(AppenderType::LOGICAL), "LOGICAL" }, - { static_cast(AppenderType::PHYSICAL), "PHYSICAL" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(AppenderType value) { - return StringUtil::EnumToString(GetAppenderTypeValues(), 2, "AppenderType", static_cast(value)); -} - -template<> -AppenderType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetAppenderTypeValues(), 2, "AppenderType", value)); -} - -const StringUtil::EnumStringLiteral *GetArrowDateTimeTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ArrowDateTimeType::MILLISECONDS), "MILLISECONDS" }, - { static_cast(ArrowDateTimeType::MICROSECONDS), "MICROSECONDS" }, - { static_cast(ArrowDateTimeType::NANOSECONDS), "NANOSECONDS" }, - { static_cast(ArrowDateTimeType::SECONDS), "SECONDS" }, - { static_cast(ArrowDateTimeType::DAYS), "DAYS" }, - { static_cast(ArrowDateTimeType::MONTHS), "MONTHS" }, - { static_cast(ArrowDateTimeType::MONTH_DAY_NANO), "MONTH_DAY_NANO" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ArrowDateTimeType value) { - return StringUtil::EnumToString(GetArrowDateTimeTypeValues(), 7, "ArrowDateTimeType", static_cast(value)); -} - -template<> -ArrowDateTimeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetArrowDateTimeTypeValues(), 7, "ArrowDateTimeType", value)); -} - -const StringUtil::EnumStringLiteral *GetArrowOffsetSizeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ArrowOffsetSize::REGULAR), "REGULAR" }, - { static_cast(ArrowOffsetSize::LARGE), "LARGE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ArrowOffsetSize value) { - return StringUtil::EnumToString(GetArrowOffsetSizeValues(), 2, "ArrowOffsetSize", static_cast(value)); -} - -template<> -ArrowOffsetSize EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetArrowOffsetSizeValues(), 2, "ArrowOffsetSize", value)); -} - -const StringUtil::EnumStringLiteral *GetArrowTypeInfoTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ArrowTypeInfoType::LIST), "LIST" }, - { static_cast(ArrowTypeInfoType::STRUCT), "STRUCT" }, - { static_cast(ArrowTypeInfoType::DATE_TIME), "DATE_TIME" }, - { static_cast(ArrowTypeInfoType::STRING), "STRING" }, - { static_cast(ArrowTypeInfoType::ARRAY), "ARRAY" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ArrowTypeInfoType value) { - return StringUtil::EnumToString(GetArrowTypeInfoTypeValues(), 5, "ArrowTypeInfoType", static_cast(value)); -} - -template<> -ArrowTypeInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetArrowTypeInfoTypeValues(), 5, "ArrowTypeInfoType", value)); -} - -const StringUtil::EnumStringLiteral *GetArrowVariableSizeTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { 
static_cast(ArrowVariableSizeType::NORMAL), "NORMAL" }, - { static_cast(ArrowVariableSizeType::FIXED_SIZE), "FIXED_SIZE" }, - { static_cast(ArrowVariableSizeType::SUPER_SIZE), "SUPER_SIZE" }, - { static_cast(ArrowVariableSizeType::VIEW), "VIEW" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ArrowVariableSizeType value) { - return StringUtil::EnumToString(GetArrowVariableSizeTypeValues(), 4, "ArrowVariableSizeType", static_cast(value)); -} - -template<> -ArrowVariableSizeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetArrowVariableSizeTypeValues(), 4, "ArrowVariableSizeType", value)); -} - -const StringUtil::EnumStringLiteral *GetBinderTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(BinderType::REGULAR_BINDER), "REGULAR_BINDER" }, - { static_cast(BinderType::VIEW_BINDER), "VIEW_BINDER" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(BinderType value) { - return StringUtil::EnumToString(GetBinderTypeValues(), 2, "BinderType", static_cast(value)); -} - -template<> -BinderType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetBinderTypeValues(), 2, "BinderType", value)); -} - -const StringUtil::EnumStringLiteral *GetBindingModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(BindingMode::STANDARD_BINDING), "STANDARD_BINDING" }, - { static_cast(BindingMode::EXTRACT_NAMES), "EXTRACT_NAMES" }, - { static_cast(BindingMode::EXTRACT_REPLACEMENT_SCANS), "EXTRACT_REPLACEMENT_SCANS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(BindingMode value) { - return StringUtil::EnumToString(GetBindingModeValues(), 3, "BindingMode", static_cast(value)); -} - -template<> -BindingMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetBindingModeValues(), 3, "BindingMode", value)); -} - -const StringUtil::EnumStringLiteral *GetBitpackingModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(BitpackingMode::INVALID), "INVALID" }, - { static_cast(BitpackingMode::AUTO), "AUTO" }, - { static_cast(BitpackingMode::CONSTANT), "CONSTANT" }, - { static_cast(BitpackingMode::CONSTANT_DELTA), "CONSTANT_DELTA" }, - { static_cast(BitpackingMode::DELTA_FOR), "DELTA_FOR" }, - { static_cast(BitpackingMode::FOR), "FOR" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(BitpackingMode value) { - return StringUtil::EnumToString(GetBitpackingModeValues(), 6, "BitpackingMode", static_cast(value)); -} - -template<> -BitpackingMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetBitpackingModeValues(), 6, "BitpackingMode", value)); -} - -const StringUtil::EnumStringLiteral *GetBlockStateValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(BlockState::BLOCK_UNLOADED), "BLOCK_UNLOADED" }, - { static_cast(BlockState::BLOCK_LOADED), "BLOCK_LOADED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(BlockState value) { - return StringUtil::EnumToString(GetBlockStateValues(), 2, "BlockState", static_cast(value)); -} - -template<> -BlockState EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetBlockStateValues(), 2, "BlockState", value)); -} - -const StringUtil::EnumStringLiteral *GetCAPIResultSetTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { 
static_cast(CAPIResultSetType::CAPI_RESULT_TYPE_NONE), "CAPI_RESULT_TYPE_NONE" }, - { static_cast(CAPIResultSetType::CAPI_RESULT_TYPE_MATERIALIZED), "CAPI_RESULT_TYPE_MATERIALIZED" }, - { static_cast(CAPIResultSetType::CAPI_RESULT_TYPE_STREAMING), "CAPI_RESULT_TYPE_STREAMING" }, - { static_cast(CAPIResultSetType::CAPI_RESULT_TYPE_DEPRECATED), "CAPI_RESULT_TYPE_DEPRECATED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CAPIResultSetType value) { - return StringUtil::EnumToString(GetCAPIResultSetTypeValues(), 4, "CAPIResultSetType", static_cast(value)); -} - -template<> -CAPIResultSetType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCAPIResultSetTypeValues(), 4, "CAPIResultSetType", value)); -} - -const StringUtil::EnumStringLiteral *GetCSVStateValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CSVState::STANDARD), "STANDARD" }, - { static_cast(CSVState::DELIMITER), "DELIMITER" }, - { static_cast(CSVState::DELIMITER_FIRST_BYTE), "DELIMITER_FIRST_BYTE" }, - { static_cast(CSVState::DELIMITER_SECOND_BYTE), "DELIMITER_SECOND_BYTE" }, - { static_cast(CSVState::DELIMITER_THIRD_BYTE), "DELIMITER_THIRD_BYTE" }, - { static_cast(CSVState::RECORD_SEPARATOR), "RECORD_SEPARATOR" }, - { static_cast(CSVState::CARRIAGE_RETURN), "CARRIAGE_RETURN" }, - { static_cast(CSVState::QUOTED), "QUOTED" }, - { static_cast(CSVState::UNQUOTED), "UNQUOTED" }, - { static_cast(CSVState::ESCAPE), "ESCAPE" }, - { static_cast(CSVState::INVALID), "INVALID" }, - { static_cast(CSVState::NOT_SET), "NOT_SET" }, - { static_cast(CSVState::QUOTED_NEW_LINE), "QUOTED_NEW_LINE" }, - { static_cast(CSVState::EMPTY_SPACE), "EMPTY_SPACE" }, - { static_cast(CSVState::COMMENT), "COMMENT" }, - { static_cast(CSVState::STANDARD_NEWLINE), "STANDARD_NEWLINE" }, - { static_cast(CSVState::UNQUOTED_ESCAPE), "UNQUOTED_ESCAPE" }, - { static_cast(CSVState::ESCAPED_RETURN), "ESCAPED_RETURN" }, - { static_cast(CSVState::MAYBE_QUOTED), "MAYBE_QUOTED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CSVState value) { - return StringUtil::EnumToString(GetCSVStateValues(), 19, "CSVState", static_cast(value)); -} - -template<> -CSVState EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCSVStateValues(), 19, "CSVState", value)); -} - -const StringUtil::EnumStringLiteral *GetCTEMaterializeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CTEMaterialize::CTE_MATERIALIZE_DEFAULT), "CTE_MATERIALIZE_DEFAULT" }, - { static_cast(CTEMaterialize::CTE_MATERIALIZE_ALWAYS), "CTE_MATERIALIZE_ALWAYS" }, - { static_cast(CTEMaterialize::CTE_MATERIALIZE_NEVER), "CTE_MATERIALIZE_NEVER" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CTEMaterialize value) { - return StringUtil::EnumToString(GetCTEMaterializeValues(), 3, "CTEMaterialize", static_cast(value)); -} - -template<> -CTEMaterialize EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCTEMaterializeValues(), 3, "CTEMaterialize", value)); -} - -const StringUtil::EnumStringLiteral *GetCatalogLookupBehaviorValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CatalogLookupBehavior::STANDARD), "STANDARD" }, - { static_cast(CatalogLookupBehavior::LOWER_PRIORITY), "LOWER_PRIORITY" }, - { static_cast(CatalogLookupBehavior::NEVER_LOOKUP), "NEVER_LOOKUP" } - }; - return values; -} - -template<> -const char* 
EnumUtil::ToChars(CatalogLookupBehavior value) { - return StringUtil::EnumToString(GetCatalogLookupBehaviorValues(), 3, "CatalogLookupBehavior", static_cast(value)); -} - -template<> -CatalogLookupBehavior EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCatalogLookupBehaviorValues(), 3, "CatalogLookupBehavior", value)); -} - -const StringUtil::EnumStringLiteral *GetCatalogTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CatalogType::INVALID), "INVALID" }, - { static_cast(CatalogType::TABLE_ENTRY), "TABLE_ENTRY" }, - { static_cast(CatalogType::SCHEMA_ENTRY), "SCHEMA_ENTRY" }, - { static_cast(CatalogType::VIEW_ENTRY), "VIEW_ENTRY" }, - { static_cast(CatalogType::INDEX_ENTRY), "INDEX_ENTRY" }, - { static_cast(CatalogType::PREPARED_STATEMENT), "PREPARED_STATEMENT" }, - { static_cast(CatalogType::SEQUENCE_ENTRY), "SEQUENCE_ENTRY" }, - { static_cast(CatalogType::COLLATION_ENTRY), "COLLATION_ENTRY" }, - { static_cast(CatalogType::TYPE_ENTRY), "TYPE_ENTRY" }, - { static_cast(CatalogType::DATABASE_ENTRY), "DATABASE_ENTRY" }, - { static_cast(CatalogType::TABLE_FUNCTION_ENTRY), "TABLE_FUNCTION_ENTRY" }, - { static_cast(CatalogType::SCALAR_FUNCTION_ENTRY), "SCALAR_FUNCTION_ENTRY" }, - { static_cast(CatalogType::AGGREGATE_FUNCTION_ENTRY), "AGGREGATE_FUNCTION_ENTRY" }, - { static_cast(CatalogType::PRAGMA_FUNCTION_ENTRY), "PRAGMA_FUNCTION_ENTRY" }, - { static_cast(CatalogType::COPY_FUNCTION_ENTRY), "COPY_FUNCTION_ENTRY" }, - { static_cast(CatalogType::MACRO_ENTRY), "MACRO_ENTRY" }, - { static_cast(CatalogType::TABLE_MACRO_ENTRY), "TABLE_MACRO_ENTRY" }, - { static_cast(CatalogType::DELETED_ENTRY), "DELETED_ENTRY" }, - { static_cast(CatalogType::RENAMED_ENTRY), "RENAMED_ENTRY" }, - { static_cast(CatalogType::SECRET_ENTRY), "SECRET_ENTRY" }, - { static_cast(CatalogType::SECRET_TYPE_ENTRY), "SECRET_TYPE_ENTRY" }, - { static_cast(CatalogType::SECRET_FUNCTION_ENTRY), "SECRET_FUNCTION_ENTRY" }, - { static_cast(CatalogType::DEPENDENCY_ENTRY), "DEPENDENCY_ENTRY" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CatalogType value) { - return StringUtil::EnumToString(GetCatalogTypeValues(), 23, "CatalogType", static_cast(value)); -} - -template<> -CatalogType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCatalogTypeValues(), 23, "CatalogType", value)); -} - -const StringUtil::EnumStringLiteral *GetCheckpointAbortValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CheckpointAbort::NO_ABORT), "NONE" }, - { static_cast(CheckpointAbort::DEBUG_ABORT_BEFORE_TRUNCATE), "BEFORE_TRUNCATE" }, - { static_cast(CheckpointAbort::DEBUG_ABORT_BEFORE_HEADER), "BEFORE_HEADER" }, - { static_cast(CheckpointAbort::DEBUG_ABORT_AFTER_FREE_LIST_WRITE), "AFTER_FREE_LIST_WRITE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CheckpointAbort value) { - return StringUtil::EnumToString(GetCheckpointAbortValues(), 4, "CheckpointAbort", static_cast(value)); -} - -template<> -CheckpointAbort EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCheckpointAbortValues(), 4, "CheckpointAbort", value)); -} - -const StringUtil::EnumStringLiteral *GetChunkInfoTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ChunkInfoType::CONSTANT_INFO), "CONSTANT_INFO" }, - { static_cast(ChunkInfoType::VECTOR_INFO), "VECTOR_INFO" }, - { static_cast(ChunkInfoType::EMPTY_INFO), 
"EMPTY_INFO" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ChunkInfoType value) { - return StringUtil::EnumToString(GetChunkInfoTypeValues(), 3, "ChunkInfoType", static_cast(value)); -} - -template<> -ChunkInfoType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetChunkInfoTypeValues(), 3, "ChunkInfoType", value)); -} - -const StringUtil::EnumStringLiteral *GetColumnDataAllocatorTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR), "BUFFER_MANAGER_ALLOCATOR" }, - { static_cast(ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR), "IN_MEMORY_ALLOCATOR" }, - { static_cast(ColumnDataAllocatorType::HYBRID), "HYBRID" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ColumnDataAllocatorType value) { - return StringUtil::EnumToString(GetColumnDataAllocatorTypeValues(), 3, "ColumnDataAllocatorType", static_cast(value)); -} - -template<> -ColumnDataAllocatorType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetColumnDataAllocatorTypeValues(), 3, "ColumnDataAllocatorType", value)); -} - -const StringUtil::EnumStringLiteral *GetColumnDataScanPropertiesValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ColumnDataScanProperties::INVALID), "INVALID" }, - { static_cast(ColumnDataScanProperties::ALLOW_ZERO_COPY), "ALLOW_ZERO_COPY" }, - { static_cast(ColumnDataScanProperties::DISALLOW_ZERO_COPY), "DISALLOW_ZERO_COPY" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ColumnDataScanProperties value) { - return StringUtil::EnumToString(GetColumnDataScanPropertiesValues(), 3, "ColumnDataScanProperties", static_cast(value)); -} - -template<> -ColumnDataScanProperties EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetColumnDataScanPropertiesValues(), 3, "ColumnDataScanProperties", value)); -} - -const StringUtil::EnumStringLiteral *GetColumnSegmentTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ColumnSegmentType::TRANSIENT), "TRANSIENT" }, - { static_cast(ColumnSegmentType::PERSISTENT), "PERSISTENT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ColumnSegmentType value) { - return StringUtil::EnumToString(GetColumnSegmentTypeValues(), 2, "ColumnSegmentType", static_cast(value)); -} - -template<> -ColumnSegmentType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetColumnSegmentTypeValues(), 2, "ColumnSegmentType", value)); -} - -const StringUtil::EnumStringLiteral *GetCompressedMaterializationDirectionValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CompressedMaterializationDirection::INVALID), "INVALID" }, - { static_cast(CompressedMaterializationDirection::COMPRESS), "COMPRESS" }, - { static_cast(CompressedMaterializationDirection::DECOMPRESS), "DECOMPRESS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CompressedMaterializationDirection value) { - return StringUtil::EnumToString(GetCompressedMaterializationDirectionValues(), 3, "CompressedMaterializationDirection", static_cast(value)); -} - -template<> -CompressedMaterializationDirection EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCompressedMaterializationDirectionValues(), 3, "CompressedMaterializationDirection", value)); -} - -const 
StringUtil::EnumStringLiteral *GetCompressionTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CompressionType::COMPRESSION_AUTO), "COMPRESSION_AUTO" }, - { static_cast(CompressionType::COMPRESSION_UNCOMPRESSED), "COMPRESSION_UNCOMPRESSED" }, - { static_cast(CompressionType::COMPRESSION_CONSTANT), "COMPRESSION_CONSTANT" }, - { static_cast(CompressionType::COMPRESSION_RLE), "COMPRESSION_RLE" }, - { static_cast(CompressionType::COMPRESSION_DICTIONARY), "COMPRESSION_DICTIONARY" }, - { static_cast(CompressionType::COMPRESSION_PFOR_DELTA), "COMPRESSION_PFOR_DELTA" }, - { static_cast(CompressionType::COMPRESSION_BITPACKING), "COMPRESSION_BITPACKING" }, - { static_cast(CompressionType::COMPRESSION_FSST), "COMPRESSION_FSST" }, - { static_cast(CompressionType::COMPRESSION_CHIMP), "COMPRESSION_CHIMP" }, - { static_cast(CompressionType::COMPRESSION_PATAS), "COMPRESSION_PATAS" }, - { static_cast(CompressionType::COMPRESSION_ALP), "COMPRESSION_ALP" }, - { static_cast(CompressionType::COMPRESSION_ALPRD), "COMPRESSION_ALPRD" }, - { static_cast(CompressionType::COMPRESSION_ZSTD), "COMPRESSION_ZSTD" }, - { static_cast(CompressionType::COMPRESSION_ROARING), "COMPRESSION_ROARING" }, - { static_cast(CompressionType::COMPRESSION_EMPTY), "COMPRESSION_EMPTY" }, - { static_cast(CompressionType::COMPRESSION_COUNT), "COMPRESSION_COUNT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CompressionType value) { - return StringUtil::EnumToString(GetCompressionTypeValues(), 16, "CompressionType", static_cast(value)); -} - -template<> -CompressionType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCompressionTypeValues(), 16, "CompressionType", value)); -} - -const StringUtil::EnumStringLiteral *GetCompressionValidityValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CompressionValidity::REQUIRES_VALIDITY), "REQUIRES_VALIDITY" }, - { static_cast(CompressionValidity::NO_VALIDITY_REQUIRED), "NO_VALIDITY_REQUIRED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CompressionValidity value) { - return StringUtil::EnumToString(GetCompressionValidityValues(), 2, "CompressionValidity", static_cast(value)); -} - -template<> -CompressionValidity EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCompressionValidityValues(), 2, "CompressionValidity", value)); -} - -const StringUtil::EnumStringLiteral *GetConflictManagerModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ConflictManagerMode::SCAN), "SCAN" }, - { static_cast(ConflictManagerMode::THROW), "THROW" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ConflictManagerMode value) { - return StringUtil::EnumToString(GetConflictManagerModeValues(), 2, "ConflictManagerMode", static_cast(value)); -} - -template<> -ConflictManagerMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetConflictManagerModeValues(), 2, "ConflictManagerMode", value)); -} - -const StringUtil::EnumStringLiteral *GetConstraintTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ConstraintType::INVALID), "INVALID" }, - { static_cast(ConstraintType::NOT_NULL), "NOT_NULL" }, - { static_cast(ConstraintType::CHECK), "CHECK" }, - { static_cast(ConstraintType::UNIQUE), "UNIQUE" }, - { static_cast(ConstraintType::FOREIGN_KEY), "FOREIGN_KEY" } - }; - return values; -} - 
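Every generated pair above follows the same contract: ToChars looks the value up in the enum's values table, FromString does the reverse lookup by name. A minimal round-trip check, assuming the enum_util.hpp header these specializations implement and using AccessMode from the tables earlier in this file; the explicit template arguments are the ones the generator emits for each enum.

#include "duckdb/common/enum_util.hpp" // header implemented by this generated file
#include <cassert>
#include <cstring>

int main() {
	using duckdb::AccessMode;
	// value -> name via the generated values table
	const char *name = duckdb::EnumUtil::ToChars<AccessMode>(AccessMode::READ_ONLY);
	assert(std::strcmp(name, "READ_ONLY") == 0);
	// name -> value via the reverse lookup
	assert(duckdb::EnumUtil::FromString<AccessMode>("READ_ONLY") == AccessMode::READ_ONLY);
	return 0;
}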
-template<> -const char* EnumUtil::ToChars(ConstraintType value) { - return StringUtil::EnumToString(GetConstraintTypeValues(), 5, "ConstraintType", static_cast(value)); -} - -template<> -ConstraintType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetConstraintTypeValues(), 5, "ConstraintType", value)); -} - -const StringUtil::EnumStringLiteral *GetCopyFunctionReturnTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CopyFunctionReturnType::CHANGED_ROWS), "CHANGED_ROWS" }, - { static_cast(CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST), "CHANGED_ROWS_AND_FILE_LIST" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CopyFunctionReturnType value) { - return StringUtil::EnumToString(GetCopyFunctionReturnTypeValues(), 2, "CopyFunctionReturnType", static_cast(value)); -} - -template<> -CopyFunctionReturnType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCopyFunctionReturnTypeValues(), 2, "CopyFunctionReturnType", value)); -} - -const StringUtil::EnumStringLiteral *GetCopyOverwriteModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CopyOverwriteMode::COPY_ERROR_ON_CONFLICT), "COPY_ERROR_ON_CONFLICT" }, - { static_cast(CopyOverwriteMode::COPY_OVERWRITE), "COPY_OVERWRITE" }, - { static_cast(CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE), "COPY_OVERWRITE_OR_IGNORE" }, - { static_cast(CopyOverwriteMode::COPY_APPEND), "COPY_APPEND" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CopyOverwriteMode value) { - return StringUtil::EnumToString(GetCopyOverwriteModeValues(), 4, "CopyOverwriteMode", static_cast(value)); -} - -template<> -CopyOverwriteMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCopyOverwriteModeValues(), 4, "CopyOverwriteMode", value)); -} - -const StringUtil::EnumStringLiteral *GetCopyToTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(CopyToType::COPY_TO_FILE), "COPY_TO_FILE" }, - { static_cast(CopyToType::EXPORT_DATABASE), "EXPORT_DATABASE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(CopyToType value) { - return StringUtil::EnumToString(GetCopyToTypeValues(), 2, "CopyToType", static_cast(value)); -} - -template<> -CopyToType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetCopyToTypeValues(), 2, "CopyToType", value)); -} - -const StringUtil::EnumStringLiteral *GetDataFileTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(DataFileType::FILE_DOES_NOT_EXIST), "FILE_DOES_NOT_EXIST" }, - { static_cast(DataFileType::DUCKDB_FILE), "DUCKDB_FILE" }, - { static_cast(DataFileType::SQLITE_FILE), "SQLITE_FILE" }, - { static_cast(DataFileType::PARQUET_FILE), "PARQUET_FILE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(DataFileType value) { - return StringUtil::EnumToString(GetDataFileTypeValues(), 4, "DataFileType", static_cast(value)); -} - -template<> -DataFileType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetDataFileTypeValues(), 4, "DataFileType", value)); -} - -const StringUtil::EnumStringLiteral *GetDateCastResultValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(DateCastResult::SUCCESS), "SUCCESS" }, - { static_cast(DateCastResult::ERROR_INCORRECT_FORMAT), "ERROR_INCORRECT_FORMAT" 
-const StringUtil::EnumStringLiteral *GetDatePartSpecifierValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DatePartSpecifier::YEAR), "YEAR" },
-        { static_cast<uint32_t>(DatePartSpecifier::MONTH), "MONTH" },
-        { static_cast<uint32_t>(DatePartSpecifier::DAY), "DAY" },
-        { static_cast<uint32_t>(DatePartSpecifier::DECADE), "DECADE" },
-        { static_cast<uint32_t>(DatePartSpecifier::CENTURY), "CENTURY" },
-        { static_cast<uint32_t>(DatePartSpecifier::MILLENNIUM), "MILLENNIUM" },
-        { static_cast<uint32_t>(DatePartSpecifier::MICROSECONDS), "MICROSECONDS" },
-        { static_cast<uint32_t>(DatePartSpecifier::MILLISECONDS), "MILLISECONDS" },
-        { static_cast<uint32_t>(DatePartSpecifier::SECOND), "SECOND" },
-        { static_cast<uint32_t>(DatePartSpecifier::MINUTE), "MINUTE" },
-        { static_cast<uint32_t>(DatePartSpecifier::HOUR), "HOUR" },
-        { static_cast<uint32_t>(DatePartSpecifier::DOW), "DOW" },
-        { static_cast<uint32_t>(DatePartSpecifier::ISODOW), "ISODOW" },
-        { static_cast<uint32_t>(DatePartSpecifier::WEEK), "WEEK" },
-        { static_cast<uint32_t>(DatePartSpecifier::ISOYEAR), "ISOYEAR" },
-        { static_cast<uint32_t>(DatePartSpecifier::QUARTER), "QUARTER" },
-        { static_cast<uint32_t>(DatePartSpecifier::DOY), "DOY" },
-        { static_cast<uint32_t>(DatePartSpecifier::YEARWEEK), "YEARWEEK" },
-        { static_cast<uint32_t>(DatePartSpecifier::ERA), "ERA" },
-        { static_cast<uint32_t>(DatePartSpecifier::TIMEZONE), "TIMEZONE" },
-        { static_cast<uint32_t>(DatePartSpecifier::TIMEZONE_HOUR), "TIMEZONE_HOUR" },
-        { static_cast<uint32_t>(DatePartSpecifier::TIMEZONE_MINUTE), "TIMEZONE_MINUTE" },
-        { static_cast<uint32_t>(DatePartSpecifier::EPOCH), "EPOCH" },
-        { static_cast<uint32_t>(DatePartSpecifier::JULIAN_DAY), "JULIAN_DAY" },
-        { static_cast<uint32_t>(DatePartSpecifier::INVALID), "INVALID" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DatePartSpecifier>(DatePartSpecifier value) {
-    return StringUtil::EnumToString(GetDatePartSpecifierValues(), 25, "DatePartSpecifier", static_cast<uint32_t>(value));
-}
-
-template<>
-DatePartSpecifier EnumUtil::FromString<DatePartSpecifier>(const char *value) {
-    return static_cast<DatePartSpecifier>(StringUtil::StringToEnum(GetDatePartSpecifierValues(), 25, "DatePartSpecifier", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDebugInitializeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DebugInitialize::NO_INITIALIZE), "NO_INITIALIZE" },
-        { static_cast<uint32_t>(DebugInitialize::DEBUG_ZERO_INITIALIZE), "DEBUG_ZERO_INITIALIZE" },
-        { static_cast<uint32_t>(DebugInitialize::DEBUG_ONE_INITIALIZE), "DEBUG_ONE_INITIALIZE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DebugInitialize>(DebugInitialize value) {
-    return StringUtil::EnumToString(GetDebugInitializeValues(), 3, "DebugInitialize", static_cast<uint32_t>(value));
-}
-
-template<>
-DebugInitialize EnumUtil::FromString<DebugInitialize>(const char *value) {
-    return static_cast<DebugInitialize>(StringUtil::StringToEnum(GetDebugInitializeValues(), 3, "DebugInitialize", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDefaultOrderByNullTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DefaultOrderByNullType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(DefaultOrderByNullType::NULLS_FIRST), "NULLS_FIRST" },
-        { static_cast<uint32_t>(DefaultOrderByNullType::NULLS_LAST), "NULLS_LAST" },
-        { static_cast<uint32_t>(DefaultOrderByNullType::NULLS_FIRST_ON_ASC_LAST_ON_DESC), "NULLS_FIRST_ON_ASC_LAST_ON_DESC" },
-        { static_cast<uint32_t>(DefaultOrderByNullType::NULLS_LAST_ON_ASC_FIRST_ON_DESC), "NULLS_LAST_ON_ASC_FIRST_ON_DESC" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DefaultOrderByNullType>(DefaultOrderByNullType value) {
-    return StringUtil::EnumToString(GetDefaultOrderByNullTypeValues(), 5, "DefaultOrderByNullType", static_cast<uint32_t>(value));
-}
-
-template<>
-DefaultOrderByNullType EnumUtil::FromString<DefaultOrderByNullType>(const char *value) {
-    return static_cast<DefaultOrderByNullType>(StringUtil::StringToEnum(GetDefaultOrderByNullTypeValues(), 5, "DefaultOrderByNullType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDependencyEntryTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DependencyEntryType::SUBJECT), "SUBJECT" },
-        { static_cast<uint32_t>(DependencyEntryType::DEPENDENT), "DEPENDENT" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DependencyEntryType>(DependencyEntryType value) {
-    return StringUtil::EnumToString(GetDependencyEntryTypeValues(), 2, "DependencyEntryType", static_cast<uint32_t>(value));
-}
-
-template<>
-DependencyEntryType EnumUtil::FromString<DependencyEntryType>(const char *value) {
-    return static_cast<DependencyEntryType>(StringUtil::StringToEnum(GetDependencyEntryTypeValues(), 2, "DependencyEntryType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDeprecatedIndexTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DeprecatedIndexType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(DeprecatedIndexType::ART), "ART" },
-        { static_cast<uint32_t>(DeprecatedIndexType::EXTENSION), "EXTENSION" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DeprecatedIndexType>(DeprecatedIndexType value) {
-    return StringUtil::EnumToString(GetDeprecatedIndexTypeValues(), 3, "DeprecatedIndexType", static_cast<uint32_t>(value));
-}
-
-template<>
-DeprecatedIndexType EnumUtil::FromString<DeprecatedIndexType>(const char *value) {
-    return static_cast<DeprecatedIndexType>(StringUtil::StringToEnum(GetDeprecatedIndexTypeValues(), 3, "DeprecatedIndexType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDestroyBufferUponValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DestroyBufferUpon::BLOCK), "BLOCK" },
-        { static_cast<uint32_t>(DestroyBufferUpon::EVICTION), "EVICTION" },
-        { static_cast<uint32_t>(DestroyBufferUpon::UNPIN), "UNPIN" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DestroyBufferUpon>(DestroyBufferUpon value) {
-    return StringUtil::EnumToString(GetDestroyBufferUponValues(), 3, "DestroyBufferUpon", static_cast<uint32_t>(value));
-}
-
-template<>
-DestroyBufferUpon EnumUtil::FromString<DestroyBufferUpon>(const char *value) {
-    return static_cast<DestroyBufferUpon>(StringUtil::StringToEnum(GetDestroyBufferUponValues(), 3, "DestroyBufferUpon", value));
-}
-
-const StringUtil::EnumStringLiteral *GetDistinctTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(DistinctType::DISTINCT), "DISTINCT" },
-        { static_cast<uint32_t>(DistinctType::DISTINCT_ON), "DISTINCT_ON" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<DistinctType>(DistinctType value) {
-    return StringUtil::EnumToString(GetDistinctTypeValues(), 2, "DistinctType", static_cast<uint32_t>(value));
-}
-
-template<>
-DistinctType EnumUtil::FromString<DistinctType>(const char *value) {
-    return static_cast<DistinctType>(StringUtil::StringToEnum(GetDistinctTypeValues(), 2, "DistinctType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetErrorTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ErrorType::UNSIGNED_EXTENSION), "UNSIGNED_EXTENSION" },
-        { static_cast<uint32_t>(ErrorType::INVALIDATED_TRANSACTION), "INVALIDATED_TRANSACTION" },
-        { static_cast<uint32_t>(ErrorType::INVALIDATED_DATABASE), "INVALIDATED_DATABASE" },
-        { static_cast<uint32_t>(ErrorType::ERROR_COUNT), "ERROR_COUNT" },
-        { static_cast<uint32_t>(ErrorType::INVALID), "INVALID" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ErrorType>(ErrorType value) {
-    return StringUtil::EnumToString(GetErrorTypeValues(), 5, "ErrorType", static_cast<uint32_t>(value));
-}
-
-template<>
-ErrorType EnumUtil::FromString<ErrorType>(const char *value) {
-    return static_cast<ErrorType>(StringUtil::StringToEnum(GetErrorTypeValues(), 5, "ErrorType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExceptionFormatValueTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE), "FORMAT_VALUE_TYPE_DOUBLE" },
-        { static_cast<uint32_t>(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), "FORMAT_VALUE_TYPE_INTEGER" },
-        { static_cast<uint32_t>(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), "FORMAT_VALUE_TYPE_STRING" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExceptionFormatValueType>(ExceptionFormatValueType value) {
-    return StringUtil::EnumToString(GetExceptionFormatValueTypeValues(), 3, "ExceptionFormatValueType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExceptionFormatValueType EnumUtil::FromString<ExceptionFormatValueType>(const char *value) {
-    return static_cast<ExceptionFormatValueType>(StringUtil::StringToEnum(GetExceptionFormatValueTypeValues(), 3, "ExceptionFormatValueType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExceptionTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExceptionType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(ExceptionType::OUT_OF_RANGE), "OUT_OF_RANGE" },
-        { static_cast<uint32_t>(ExceptionType::CONVERSION), "CONVERSION" },
-        { static_cast<uint32_t>(ExceptionType::UNKNOWN_TYPE), "UNKNOWN_TYPE" },
-        { static_cast<uint32_t>(ExceptionType::DECIMAL), "DECIMAL" },
-        { static_cast<uint32_t>(ExceptionType::MISMATCH_TYPE), "MISMATCH_TYPE" },
-        { static_cast<uint32_t>(ExceptionType::DIVIDE_BY_ZERO), "DIVIDE_BY_ZERO" },
-        { static_cast<uint32_t>(ExceptionType::OBJECT_SIZE), "OBJECT_SIZE" },
-        { static_cast<uint32_t>(ExceptionType::INVALID_TYPE), "INVALID_TYPE" },
-        { static_cast<uint32_t>(ExceptionType::SERIALIZATION), "SERIALIZATION" },
-        { static_cast<uint32_t>(ExceptionType::TRANSACTION), "TRANSACTION" },
-        { static_cast<uint32_t>(ExceptionType::NOT_IMPLEMENTED), "NOT_IMPLEMENTED" },
-        { static_cast<uint32_t>(ExceptionType::EXPRESSION), "EXPRESSION" },
-        { static_cast<uint32_t>(ExceptionType::CATALOG), "CATALOG" },
-        { static_cast<uint32_t>(ExceptionType::PARSER), "PARSER" },
-        { static_cast<uint32_t>(ExceptionType::PLANNER), "PLANNER" },
-        { static_cast<uint32_t>(ExceptionType::SCHEDULER), "SCHEDULER" },
-        { static_cast<uint32_t>(ExceptionType::EXECUTOR), "EXECUTOR" },
-        { static_cast<uint32_t>(ExceptionType::CONSTRAINT), "CONSTRAINT" },
-        { static_cast<uint32_t>(ExceptionType::INDEX), "INDEX" },
-        { static_cast<uint32_t>(ExceptionType::STAT), "STAT" },
-        { static_cast<uint32_t>(ExceptionType::CONNECTION), "CONNECTION" },
-        { static_cast<uint32_t>(ExceptionType::SYNTAX), "SYNTAX" },
-        { static_cast<uint32_t>(ExceptionType::SETTINGS), "SETTINGS" },
-        { static_cast<uint32_t>(ExceptionType::BINDER), "BINDER" },
-        { static_cast<uint32_t>(ExceptionType::NETWORK), "NETWORK" },
-        { static_cast<uint32_t>(ExceptionType::OPTIMIZER), "OPTIMIZER" },
-        { static_cast<uint32_t>(ExceptionType::NULL_POINTER), "NULL_POINTER" },
-        { static_cast<uint32_t>(ExceptionType::IO), "IO" },
-        { static_cast<uint32_t>(ExceptionType::INTERRUPT), "INTERRUPT" },
-        { static_cast<uint32_t>(ExceptionType::FATAL), "FATAL" },
-        { static_cast<uint32_t>(ExceptionType::INTERNAL), "INTERNAL" },
-        { static_cast<uint32_t>(ExceptionType::INVALID_INPUT), "INVALID_INPUT" },
-        { static_cast<uint32_t>(ExceptionType::OUT_OF_MEMORY), "OUT_OF_MEMORY" },
-        { static_cast<uint32_t>(ExceptionType::PERMISSION), "PERMISSION" },
-        { static_cast<uint32_t>(ExceptionType::PARAMETER_NOT_RESOLVED), "PARAMETER_NOT_RESOLVED" },
-        { static_cast<uint32_t>(ExceptionType::PARAMETER_NOT_ALLOWED), "PARAMETER_NOT_ALLOWED" },
-        { static_cast<uint32_t>(ExceptionType::DEPENDENCY), "DEPENDENCY" },
-        { static_cast<uint32_t>(ExceptionType::HTTP), "HTTP" },
-        { static_cast<uint32_t>(ExceptionType::MISSING_EXTENSION), "MISSING_EXTENSION" },
-        { static_cast<uint32_t>(ExceptionType::AUTOLOAD), "AUTOLOAD" },
-        { static_cast<uint32_t>(ExceptionType::SEQUENCE), "SEQUENCE" },
-        { static_cast<uint32_t>(ExceptionType::INVALID_CONFIGURATION), "INVALID_CONFIGURATION" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExceptionType>(ExceptionType value) {
-    return StringUtil::EnumToString(GetExceptionTypeValues(), 43, "ExceptionType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExceptionType EnumUtil::FromString<ExceptionType>(const char *value) {
-    return static_cast<ExceptionType>(StringUtil::StringToEnum(GetExceptionTypeValues(), 43, "ExceptionType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExplainFormatValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExplainFormat::DEFAULT), "DEFAULT" },
-        { static_cast<uint32_t>(ExplainFormat::TEXT), "TEXT" },
-        { static_cast<uint32_t>(ExplainFormat::JSON), "JSON" },
-        { static_cast<uint32_t>(ExplainFormat::HTML), "HTML" },
-        { static_cast<uint32_t>(ExplainFormat::GRAPHVIZ), "GRAPHVIZ" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExplainFormat>(ExplainFormat value) {
-    return StringUtil::EnumToString(GetExplainFormatValues(), 5, "ExplainFormat", static_cast<uint32_t>(value));
-}
-
-template<>
-ExplainFormat EnumUtil::FromString<ExplainFormat>(const char *value) {
-    return static_cast<ExplainFormat>(StringUtil::StringToEnum(GetExplainFormatValues(), 5, "ExplainFormat", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExplainOutputTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExplainOutputType::ALL), "ALL" },
-        { static_cast<uint32_t>(ExplainOutputType::OPTIMIZED_ONLY), "OPTIMIZED_ONLY" },
-        { static_cast<uint32_t>(ExplainOutputType::PHYSICAL_ONLY), "PHYSICAL_ONLY" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExplainOutputType>(ExplainOutputType value) {
-    return StringUtil::EnumToString(GetExplainOutputTypeValues(), 3, "ExplainOutputType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExplainOutputType EnumUtil::FromString<ExplainOutputType>(const char *value) {
-    return static_cast<ExplainOutputType>(StringUtil::StringToEnum(GetExplainOutputTypeValues(), 3, "ExplainOutputType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExplainTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExplainType::EXPLAIN_STANDARD), "EXPLAIN_STANDARD" },
-        { static_cast<uint32_t>(ExplainType::EXPLAIN_ANALYZE), "EXPLAIN_ANALYZE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExplainType>(ExplainType value) {
-    return StringUtil::EnumToString(GetExplainTypeValues(), 2, "ExplainType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExplainType EnumUtil::FromString<ExplainType>(const char *value) {
-    return static_cast<ExplainType>(StringUtil::StringToEnum(GetExplainTypeValues(), 2, "ExplainType", value));
-}
-
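These specializations implement the EnumUtil::ToChars / EnumUtil::FromString interface declared in duckdb/common/enum_util.hpp. A short usage sketch follows; the header paths are assumptions, and both directions are expected to raise an error on input missing from the table rather than return a sentinel:

    #include "duckdb/common/enum_util.hpp" // EnumUtil::ToChars / EnumUtil::FromString
    #include "duckdb/common/exception.hpp" // ExceptionType

    using namespace duckdb;

    // Round trip: enum -> string -> enum. FromString is an exact,
    // case-sensitive match against the table's string literals.
    static void RoundTripExample() {
        const char *name = EnumUtil::ToChars<ExceptionType>(ExceptionType::SYNTAX); // "SYNTAX"
        ExceptionType back = EnumUtil::FromString<ExceptionType>(name);
        (void)back; // back == ExceptionType::SYNTAX
    }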
-const StringUtil::EnumStringLiteral *GetExponentTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExponentType::NONE), "NONE" },
-        { static_cast<uint32_t>(ExponentType::POSITIVE), "POSITIVE" },
-        { static_cast<uint32_t>(ExponentType::NEGATIVE), "NEGATIVE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExponentType>(ExponentType value) {
-    return StringUtil::EnumToString(GetExponentTypeValues(), 3, "ExponentType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExponentType EnumUtil::FromString<ExponentType>(const char *value) {
-    return static_cast<ExponentType>(StringUtil::StringToEnum(GetExponentTypeValues(), 3, "ExponentType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExpressionClassValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExpressionClass::INVALID), "INVALID" },
-        { static_cast<uint32_t>(ExpressionClass::AGGREGATE), "AGGREGATE" },
-        { static_cast<uint32_t>(ExpressionClass::CASE), "CASE" },
-        { static_cast<uint32_t>(ExpressionClass::CAST), "CAST" },
-        { static_cast<uint32_t>(ExpressionClass::COLUMN_REF), "COLUMN_REF" },
-        { static_cast<uint32_t>(ExpressionClass::COMPARISON), "COMPARISON" },
-        { static_cast<uint32_t>(ExpressionClass::CONJUNCTION), "CONJUNCTION" },
-        { static_cast<uint32_t>(ExpressionClass::CONSTANT), "CONSTANT" },
-        { static_cast<uint32_t>(ExpressionClass::DEFAULT), "DEFAULT" },
-        { static_cast<uint32_t>(ExpressionClass::FUNCTION), "FUNCTION" },
-        { static_cast<uint32_t>(ExpressionClass::OPERATOR), "OPERATOR" },
-        { static_cast<uint32_t>(ExpressionClass::STAR), "STAR" },
-        { static_cast<uint32_t>(ExpressionClass::SUBQUERY), "SUBQUERY" },
-        { static_cast<uint32_t>(ExpressionClass::WINDOW), "WINDOW" },
-        { static_cast<uint32_t>(ExpressionClass::PARAMETER), "PARAMETER" },
-        { static_cast<uint32_t>(ExpressionClass::COLLATE), "COLLATE" },
-        { static_cast<uint32_t>(ExpressionClass::LAMBDA), "LAMBDA" },
-        { static_cast<uint32_t>(ExpressionClass::POSITIONAL_REFERENCE), "POSITIONAL_REFERENCE" },
-        { static_cast<uint32_t>(ExpressionClass::BETWEEN), "BETWEEN" },
-        { static_cast<uint32_t>(ExpressionClass::LAMBDA_REF), "LAMBDA_REF" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_AGGREGATE), "BOUND_AGGREGATE" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_CASE), "BOUND_CASE" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_CAST), "BOUND_CAST" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_COLUMN_REF), "BOUND_COLUMN_REF" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_COMPARISON), "BOUND_COMPARISON" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_CONJUNCTION), "BOUND_CONJUNCTION" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_CONSTANT), "BOUND_CONSTANT" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_DEFAULT), "BOUND_DEFAULT" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_FUNCTION), "BOUND_FUNCTION" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_OPERATOR), "BOUND_OPERATOR" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_PARAMETER), "BOUND_PARAMETER" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_REF), "BOUND_REF" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_SUBQUERY), "BOUND_SUBQUERY" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_WINDOW), "BOUND_WINDOW" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_BETWEEN), "BOUND_BETWEEN" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_UNNEST), "BOUND_UNNEST" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_LAMBDA), "BOUND_LAMBDA" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_LAMBDA_REF), "BOUND_LAMBDA_REF" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_EXPRESSION), "BOUND_EXPRESSION" },
-        { static_cast<uint32_t>(ExpressionClass::BOUND_EXPANDED), "BOUND_EXPANDED" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExpressionClass>(ExpressionClass value) {
-    return StringUtil::EnumToString(GetExpressionClassValues(), 40, "ExpressionClass", static_cast<uint32_t>(value));
-}
-
-template<>
-ExpressionClass EnumUtil::FromString<ExpressionClass>(const char *value) {
-    return static_cast<ExpressionClass>(StringUtil::StringToEnum(GetExpressionClassValues(), 40, "ExpressionClass", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExpressionTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExpressionType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_CAST), "OPERATOR_CAST" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_NOT), "OPERATOR_NOT" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_IS_NULL), "OPERATOR_IS_NULL" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_IS_NOT_NULL), "OPERATOR_IS_NOT_NULL" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_EQUAL), "COMPARE_EQUAL" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_NOTEQUAL), "COMPARE_NOTEQUAL" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_LESSTHAN), "COMPARE_LESSTHAN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_GREATERTHAN), "COMPARE_GREATERTHAN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_LESSTHANOREQUALTO), "COMPARE_LESSTHANOREQUALTO" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_GREATERTHANOREQUALTO), "COMPARE_GREATERTHANOREQUALTO" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_IN), "COMPARE_IN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_NOT_IN), "COMPARE_NOT_IN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_DISTINCT_FROM), "COMPARE_DISTINCT_FROM" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_BETWEEN), "COMPARE_BETWEEN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_NOT_BETWEEN), "COMPARE_NOT_BETWEEN" },
-        { static_cast<uint32_t>(ExpressionType::COMPARE_NOT_DISTINCT_FROM), "COMPARE_NOT_DISTINCT_FROM" },
-        { static_cast<uint32_t>(ExpressionType::CONJUNCTION_AND), "CONJUNCTION_AND" },
-        { static_cast<uint32_t>(ExpressionType::CONJUNCTION_OR), "CONJUNCTION_OR" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_CONSTANT), "VALUE_CONSTANT" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_PARAMETER), "VALUE_PARAMETER" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_TUPLE), "VALUE_TUPLE" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_TUPLE_ADDRESS), "VALUE_TUPLE_ADDRESS" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_NULL), "VALUE_NULL" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_VECTOR), "VALUE_VECTOR" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_SCALAR), "VALUE_SCALAR" },
-        { static_cast<uint32_t>(ExpressionType::VALUE_DEFAULT), "VALUE_DEFAULT" },
-        { static_cast<uint32_t>(ExpressionType::AGGREGATE), "AGGREGATE" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_AGGREGATE), "BOUND_AGGREGATE" },
-        { static_cast<uint32_t>(ExpressionType::GROUPING_FUNCTION), "GROUPING_FUNCTION" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_AGGREGATE), "WINDOW_AGGREGATE" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_RANK), "WINDOW_RANK" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_RANK_DENSE), "WINDOW_RANK_DENSE" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_NTILE), "WINDOW_NTILE" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_PERCENT_RANK), "WINDOW_PERCENT_RANK" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_CUME_DIST), "WINDOW_CUME_DIST" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_ROW_NUMBER), "WINDOW_ROW_NUMBER" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_FIRST_VALUE), "WINDOW_FIRST_VALUE" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_LAST_VALUE), "WINDOW_LAST_VALUE" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_LEAD), "WINDOW_LEAD" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_LAG), "WINDOW_LAG" },
-        { static_cast<uint32_t>(ExpressionType::WINDOW_NTH_VALUE), "WINDOW_NTH_VALUE" },
-        { static_cast<uint32_t>(ExpressionType::FUNCTION), "FUNCTION" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_FUNCTION), "BOUND_FUNCTION" },
-        { static_cast<uint32_t>(ExpressionType::CASE_EXPR), "CASE_EXPR" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_NULLIF), "OPERATOR_NULLIF" },
-        { static_cast<uint32_t>(ExpressionType::OPERATOR_COALESCE), "OPERATOR_COALESCE" },
-        { static_cast<uint32_t>(ExpressionType::ARRAY_EXTRACT), "ARRAY_EXTRACT" },
-        { static_cast<uint32_t>(ExpressionType::ARRAY_SLICE), "ARRAY_SLICE" },
-        { static_cast<uint32_t>(ExpressionType::STRUCT_EXTRACT), "STRUCT_EXTRACT" },
-        { static_cast<uint32_t>(ExpressionType::ARRAY_CONSTRUCTOR), "ARRAY_CONSTRUCTOR" },
-        { static_cast<uint32_t>(ExpressionType::ARROW), "ARROW" },
-        { static_cast<uint32_t>(ExpressionType::SUBQUERY), "SUBQUERY" },
-        { static_cast<uint32_t>(ExpressionType::STAR), "STAR" },
-        { static_cast<uint32_t>(ExpressionType::TABLE_STAR), "TABLE_STAR" },
-        { static_cast<uint32_t>(ExpressionType::PLACEHOLDER), "PLACEHOLDER" },
-        { static_cast<uint32_t>(ExpressionType::COLUMN_REF), "COLUMN_REF" },
-        { static_cast<uint32_t>(ExpressionType::FUNCTION_REF), "FUNCTION_REF" },
-        { static_cast<uint32_t>(ExpressionType::TABLE_REF), "TABLE_REF" },
-        { static_cast<uint32_t>(ExpressionType::LAMBDA_REF), "LAMBDA_REF" },
-        { static_cast<uint32_t>(ExpressionType::CAST), "CAST" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_REF), "BOUND_REF" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_COLUMN_REF), "BOUND_COLUMN_REF" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_UNNEST), "BOUND_UNNEST" },
-        { static_cast<uint32_t>(ExpressionType::COLLATE), "COLLATE" },
-        { static_cast<uint32_t>(ExpressionType::LAMBDA), "LAMBDA" },
-        { static_cast<uint32_t>(ExpressionType::POSITIONAL_REFERENCE), "POSITIONAL_REFERENCE" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_LAMBDA_REF), "BOUND_LAMBDA_REF" },
-        { static_cast<uint32_t>(ExpressionType::BOUND_EXPANDED), "BOUND_EXPANDED" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExpressionType>(ExpressionType value) {
-    return StringUtil::EnumToString(GetExpressionTypeValues(), 69, "ExpressionType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExpressionType EnumUtil::FromString<ExpressionType>(const char *value) {
-    return static_cast<ExpressionType>(StringUtil::StringToEnum(GetExpressionTypeValues(), 69, "ExpressionType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExtensionABITypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtensionABIType::UNKNOWN), "UNKNOWN" },
-        { static_cast<uint32_t>(ExtensionABIType::CPP), "CPP" },
-        { static_cast<uint32_t>(ExtensionABIType::C_STRUCT), "C_STRUCT" },
-        { static_cast<uint32_t>(ExtensionABIType::C_STRUCT_UNSTABLE), "C_STRUCT_UNSTABLE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtensionABIType>(ExtensionABIType value) {
-    return StringUtil::EnumToString(GetExtensionABITypeValues(), 4, "ExtensionABIType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtensionABIType EnumUtil::FromString<ExtensionABIType>(const char *value) {
-    return static_cast<ExtensionABIType>(StringUtil::StringToEnum(GetExtensionABITypeValues(), 4, "ExtensionABIType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExtensionInstallModeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtensionInstallMode::UNKNOWN), "UNKNOWN" },
-        { static_cast<uint32_t>(ExtensionInstallMode::REPOSITORY), "REPOSITORY" },
-        { static_cast<uint32_t>(ExtensionInstallMode::CUSTOM_PATH), "CUSTOM_PATH" },
-        { static_cast<uint32_t>(ExtensionInstallMode::STATICALLY_LINKED), "STATICALLY_LINKED" },
-        { static_cast<uint32_t>(ExtensionInstallMode::NOT_INSTALLED), "NOT_INSTALLED" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtensionInstallMode>(ExtensionInstallMode value) {
-    return StringUtil::EnumToString(GetExtensionInstallModeValues(), 5, "ExtensionInstallMode", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtensionInstallMode EnumUtil::FromString<ExtensionInstallMode>(const char *value) {
-    return static_cast<ExtensionInstallMode>(StringUtil::StringToEnum(GetExtensionInstallModeValues(), 5, "ExtensionInstallMode", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExtensionLoadResultValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtensionLoadResult::LOADED_EXTENSION), "LOADED_EXTENSION" },
-        { static_cast<uint32_t>(ExtensionLoadResult::EXTENSION_UNKNOWN), "EXTENSION_UNKNOWN" },
-        { static_cast<uint32_t>(ExtensionLoadResult::NOT_LOADED), "NOT_LOADED" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtensionLoadResult>(ExtensionLoadResult value) {
-    return StringUtil::EnumToString(GetExtensionLoadResultValues(), 3, "ExtensionLoadResult", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtensionLoadResult EnumUtil::FromString<ExtensionLoadResult>(const char *value) {
-    return static_cast<ExtensionLoadResult>(StringUtil::StringToEnum(GetExtensionLoadResultValues(), 3, "ExtensionLoadResult", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExtensionUpdateResultTagValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::UNKNOWN), "UNKNOWN" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::NO_UPDATE_AVAILABLE), "NO_UPDATE_AVAILABLE" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::NOT_A_REPOSITORY), "NOT_A_REPOSITORY" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::NOT_INSTALLED), "NOT_INSTALLED" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::STATICALLY_LOADED), "STATICALLY_LOADED" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::MISSING_INSTALL_INFO), "MISSING_INSTALL_INFO" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::REDOWNLOADED), "REDOWNLOADED" },
-        { static_cast<uint32_t>(ExtensionUpdateResultTag::UPDATED), "UPDATED" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtensionUpdateResultTag>(ExtensionUpdateResultTag value) {
-    return StringUtil::EnumToString(GetExtensionUpdateResultTagValues(), 8, "ExtensionUpdateResultTag", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtensionUpdateResultTag EnumUtil::FromString<ExtensionUpdateResultTag>(const char *value) {
-    return static_cast<ExtensionUpdateResultTag>(StringUtil::StringToEnum(GetExtensionUpdateResultTagValues(), 8, "ExtensionUpdateResultTag", value));
-}
-
-const StringUtil::EnumStringLiteral *GetExtraDropInfoTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtraDropInfoType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(ExtraDropInfoType::SECRET_INFO), "SECRET_INFO" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtraDropInfoType>(ExtraDropInfoType value) {
-    return StringUtil::EnumToString(GetExtraDropInfoTypeValues(), 2, "ExtraDropInfoType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtraDropInfoType EnumUtil::FromString<ExtraDropInfoType>(const char *value) {
-    return static_cast<ExtraDropInfoType>(StringUtil::StringToEnum(GetExtraDropInfoTypeValues(), 2, "ExtraDropInfoType", value));
-}
-
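Each ToChars/FromString call site hard-codes the table's element count (2 for ExtraDropInfoType above, 43 for ExceptionType, and so on), so the generator must keep the literal and the array in sync. A defensive sketch for hand-maintained tables, not something this generated file does, is to derive the count from the array type:

    #include <cstddef>
    #include <cstdint>

    // Stand-in for StringUtil::EnumStringLiteral (assumed layout).
    struct EnumStringLiteral {
        uint32_t number;
        const char *string;
    };

    // Yields the number of elements in a C array at compile time.
    template <typename T, std::size_t N>
    constexpr std::size_t ArrayCount(const T (&)[N]) {
        return N;
    }

    static constexpr EnumStringLiteral kDemo[] {
        { 0, "INVALID" },
        { 1, "SECRET_INFO" }
    };
    static_assert(ArrayCount(kDemo) == 2, "hard-coded count drifted from the table");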
-const StringUtil::EnumStringLiteral *GetExtraTypeInfoTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ExtraTypeInfoType::INVALID_TYPE_INFO), "INVALID_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::GENERIC_TYPE_INFO), "GENERIC_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::DECIMAL_TYPE_INFO), "DECIMAL_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::STRING_TYPE_INFO), "STRING_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::LIST_TYPE_INFO), "LIST_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::STRUCT_TYPE_INFO), "STRUCT_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::ENUM_TYPE_INFO), "ENUM_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::USER_TYPE_INFO), "USER_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), "AGGREGATE_STATE_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::ARRAY_TYPE_INFO), "ARRAY_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::ANY_TYPE_INFO), "ANY_TYPE_INFO" },
-        { static_cast<uint32_t>(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), "INTEGER_LITERAL_TYPE_INFO" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ExtraTypeInfoType>(ExtraTypeInfoType value) {
-    return StringUtil::EnumToString(GetExtraTypeInfoTypeValues(), 12, "ExtraTypeInfoType", static_cast<uint32_t>(value));
-}
-
-template<>
-ExtraTypeInfoType EnumUtil::FromString<ExtraTypeInfoType>(const char *value) {
-    return static_cast<ExtraTypeInfoType>(StringUtil::StringToEnum(GetExtraTypeInfoTypeValues(), 12, "ExtraTypeInfoType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFileBufferTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FileBufferType::BLOCK), "BLOCK" },
-        { static_cast<uint32_t>(FileBufferType::MANAGED_BUFFER), "MANAGED_BUFFER" },
-        { static_cast<uint32_t>(FileBufferType::TINY_BUFFER), "TINY_BUFFER" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FileBufferType>(FileBufferType value) {
-    return StringUtil::EnumToString(GetFileBufferTypeValues(), 3, "FileBufferType", static_cast<uint32_t>(value));
-}
-
-template<>
-FileBufferType EnumUtil::FromString<FileBufferType>(const char *value) {
-    return static_cast<FileBufferType>(StringUtil::StringToEnum(GetFileBufferTypeValues(), 3, "FileBufferType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFileCompressionTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FileCompressionType::AUTO_DETECT), "AUTO_DETECT" },
-        { static_cast<uint32_t>(FileCompressionType::UNCOMPRESSED), "UNCOMPRESSED" },
-        { static_cast<uint32_t>(FileCompressionType::GZIP), "GZIP" },
-        { static_cast<uint32_t>(FileCompressionType::ZSTD), "ZSTD" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FileCompressionType>(FileCompressionType value) {
-    return StringUtil::EnumToString(GetFileCompressionTypeValues(), 4, "FileCompressionType", static_cast<uint32_t>(value));
-}
-
-template<>
-FileCompressionType EnumUtil::FromString<FileCompressionType>(const char *value) {
-    return static_cast<FileCompressionType>(StringUtil::StringToEnum(GetFileCompressionTypeValues(), 4, "FileCompressionType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFileExpandResultValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FileExpandResult::NO_FILES), "NO_FILES" },
-        { static_cast<uint32_t>(FileExpandResult::SINGLE_FILE), "SINGLE_FILE" },
-        { static_cast<uint32_t>(FileExpandResult::MULTIPLE_FILES), "MULTIPLE_FILES" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FileExpandResult>(FileExpandResult value) {
-    return StringUtil::EnumToString(GetFileExpandResultValues(), 3, "FileExpandResult", static_cast<uint32_t>(value));
-}
-
-template<>
-FileExpandResult EnumUtil::FromString<FileExpandResult>(const char *value) {
-    return static_cast<FileExpandResult>(StringUtil::StringToEnum(GetFileExpandResultValues(), 3, "FileExpandResult", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFileGlobOptionsValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FileGlobOptions::DISALLOW_EMPTY), "DISALLOW_EMPTY" },
-        { static_cast<uint32_t>(FileGlobOptions::ALLOW_EMPTY), "ALLOW_EMPTY" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FileGlobOptions>(FileGlobOptions value) {
-    return StringUtil::EnumToString(GetFileGlobOptionsValues(), 2, "FileGlobOptions", static_cast<uint32_t>(value));
-}
-
-template<>
-FileGlobOptions EnumUtil::FromString<FileGlobOptions>(const char *value) {
-    return static_cast<FileGlobOptions>(StringUtil::StringToEnum(GetFileGlobOptionsValues(), 2, "FileGlobOptions", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFileLockTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FileLockType::NO_LOCK), "NO_LOCK" },
-        { static_cast<uint32_t>(FileLockType::READ_LOCK), "READ_LOCK" },
-        { static_cast<uint32_t>(FileLockType::WRITE_LOCK), "WRITE_LOCK" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FileLockType>(FileLockType value) {
-    return StringUtil::EnumToString(GetFileLockTypeValues(), 3, "FileLockType", static_cast<uint32_t>(value));
-}
-
-template<>
-FileLockType EnumUtil::FromString<FileLockType>(const char *value) {
-    return static_cast<FileLockType>(StringUtil::StringToEnum(GetFileLockTypeValues(), 3, "FileLockType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFilterPropagateResultValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FilterPropagateResult::NO_PRUNING_POSSIBLE), "NO_PRUNING_POSSIBLE" },
-        { static_cast<uint32_t>(FilterPropagateResult::FILTER_ALWAYS_TRUE), "FILTER_ALWAYS_TRUE" },
-        { static_cast<uint32_t>(FilterPropagateResult::FILTER_ALWAYS_FALSE), "FILTER_ALWAYS_FALSE" },
-        { static_cast<uint32_t>(FilterPropagateResult::FILTER_TRUE_OR_NULL), "FILTER_TRUE_OR_NULL" },
-        { static_cast<uint32_t>(FilterPropagateResult::FILTER_FALSE_OR_NULL), "FILTER_FALSE_OR_NULL" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FilterPropagateResult>(FilterPropagateResult value) {
-    return StringUtil::EnumToString(GetFilterPropagateResultValues(), 5, "FilterPropagateResult", static_cast<uint32_t>(value));
-}
-
-template<>
-FilterPropagateResult EnumUtil::FromString<FilterPropagateResult>(const char *value) {
-    return static_cast<FilterPropagateResult>(StringUtil::StringToEnum(GetFilterPropagateResultValues(), 5, "FilterPropagateResult", value));
-}
-
-const StringUtil::EnumStringLiteral *GetForeignKeyTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE), "FK_TYPE_PRIMARY_KEY_TABLE" },
-        { static_cast<uint32_t>(ForeignKeyType::FK_TYPE_FOREIGN_KEY_TABLE), "FK_TYPE_FOREIGN_KEY_TABLE" },
-        { static_cast<uint32_t>(ForeignKeyType::FK_TYPE_SELF_REFERENCE_TABLE), "FK_TYPE_SELF_REFERENCE_TABLE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ForeignKeyType>(ForeignKeyType value) {
-    return StringUtil::EnumToString(GetForeignKeyTypeValues(), 3, "ForeignKeyType", static_cast<uint32_t>(value));
-}
-
-template<>
-ForeignKeyType EnumUtil::FromString<ForeignKeyType>(const char *value) {
-    return static_cast<ForeignKeyType>(StringUtil::StringToEnum(GetForeignKeyTypeValues(), 3, "ForeignKeyType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFunctionCollationHandlingValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FunctionCollationHandling::PROPAGATE_COLLATIONS), "PROPAGATE_COLLATIONS" },
-        { static_cast<uint32_t>(FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS), "PUSH_COMBINABLE_COLLATIONS" },
-        { static_cast<uint32_t>(FunctionCollationHandling::IGNORE_COLLATIONS), "IGNORE_COLLATIONS" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FunctionCollationHandling>(FunctionCollationHandling value) {
-    return StringUtil::EnumToString(GetFunctionCollationHandlingValues(), 3, "FunctionCollationHandling", static_cast<uint32_t>(value));
-}
-
-template<>
-FunctionCollationHandling EnumUtil::FromString<FunctionCollationHandling>(const char *value) {
-    return static_cast<FunctionCollationHandling>(StringUtil::StringToEnum(GetFunctionCollationHandlingValues(), 3, "FunctionCollationHandling", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFunctionErrorsValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FunctionErrors::CANNOT_ERROR), "CANNOT_ERROR" },
-        { static_cast<uint32_t>(FunctionErrors::CAN_THROW_RUNTIME_ERROR), "CAN_THROW_RUNTIME_ERROR" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FunctionErrors>(FunctionErrors value) {
-    return StringUtil::EnumToString(GetFunctionErrorsValues(), 2, "FunctionErrors", static_cast<uint32_t>(value));
-}
-
-template<>
-FunctionErrors EnumUtil::FromString<FunctionErrors>(const char *value) {
-    return static_cast<FunctionErrors>(StringUtil::StringToEnum(GetFunctionErrorsValues(), 2, "FunctionErrors", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFunctionNullHandlingValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FunctionNullHandling::DEFAULT_NULL_HANDLING), "DEFAULT_NULL_HANDLING" },
-        { static_cast<uint32_t>(FunctionNullHandling::SPECIAL_HANDLING), "SPECIAL_HANDLING" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FunctionNullHandling>(FunctionNullHandling value) {
-    return StringUtil::EnumToString(GetFunctionNullHandlingValues(), 2, "FunctionNullHandling", static_cast<uint32_t>(value));
-}
-
-template<>
-FunctionNullHandling EnumUtil::FromString<FunctionNullHandling>(const char *value) {
-    return static_cast<FunctionNullHandling>(StringUtil::StringToEnum(GetFunctionNullHandlingValues(), 2, "FunctionNullHandling", value));
-}
-
-const StringUtil::EnumStringLiteral *GetFunctionStabilityValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(FunctionStability::CONSISTENT), "CONSISTENT" },
-        { static_cast<uint32_t>(FunctionStability::VOLATILE), "VOLATILE" },
-        { static_cast<uint32_t>(FunctionStability::CONSISTENT_WITHIN_QUERY), "CONSISTENT_WITHIN_QUERY" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<FunctionStability>(FunctionStability value) {
-    return StringUtil::EnumToString(GetFunctionStabilityValues(), 3, "FunctionStability", static_cast<uint32_t>(value));
-}
-
-template<>
-FunctionStability EnumUtil::FromString<FunctionStability>(const char *value) {
-    return static_cast<FunctionStability>(StringUtil::StringToEnum(GetFunctionStabilityValues(), 3, "FunctionStability", value));
-}
-
-const StringUtil::EnumStringLiteral *GetGateStatusValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(GateStatus::GATE_NOT_SET), "GATE_NOT_SET" },
-        { static_cast<uint32_t>(GateStatus::GATE_SET), "GATE_SET" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<GateStatus>(GateStatus value) {
-    return StringUtil::EnumToString(GetGateStatusValues(), 2, "GateStatus", static_cast<uint32_t>(value));
-}
-
-template<>
-GateStatus EnumUtil::FromString<GateStatus>(const char *value) {
-    return static_cast<GateStatus>(StringUtil::StringToEnum(GetGateStatusValues(), 2, "GateStatus", value));
-}
-
-const StringUtil::EnumStringLiteral *GetHLLStorageTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(HLLStorageType::HLL_V1), "HLL_V1" },
-        { static_cast<uint32_t>(HLLStorageType::HLL_V2), "HLL_V2" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<HLLStorageType>(HLLStorageType value) {
-    return StringUtil::EnumToString(GetHLLStorageTypeValues(), 2, "HLLStorageType", static_cast<uint32_t>(value));
-}
-
-template<>
-HLLStorageType EnumUtil::FromString<HLLStorageType>(const char *value) {
-    return static_cast<HLLStorageType>(StringUtil::StringToEnum(GetHLLStorageTypeValues(), 2, "HLLStorageType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetIndexConstraintTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(IndexConstraintType::NONE), "NONE" },
-        { static_cast<uint32_t>(IndexConstraintType::UNIQUE), "UNIQUE" },
-        { static_cast<uint32_t>(IndexConstraintType::PRIMARY), "PRIMARY" },
-        { static_cast<uint32_t>(IndexConstraintType::FOREIGN), "FOREIGN" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<IndexConstraintType>(IndexConstraintType value) {
-    return StringUtil::EnumToString(GetIndexConstraintTypeValues(), 4, "IndexConstraintType", static_cast<uint32_t>(value));
-}
-
-template<>
-IndexConstraintType EnumUtil::FromString<IndexConstraintType>(const char *value) {
-    return static_cast<IndexConstraintType>(StringUtil::StringToEnum(GetIndexConstraintTypeValues(), 4, "IndexConstraintType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetInsertColumnOrderValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(InsertColumnOrder::INSERT_BY_POSITION), "INSERT_BY_POSITION" },
-        { static_cast<uint32_t>(InsertColumnOrder::INSERT_BY_NAME), "INSERT_BY_NAME" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<InsertColumnOrder>(InsertColumnOrder value) {
-    return StringUtil::EnumToString(GetInsertColumnOrderValues(), 2, "InsertColumnOrder", static_cast<uint32_t>(value));
-}
-
-template<>
-InsertColumnOrder EnumUtil::FromString<InsertColumnOrder>(const char *value) {
-    return static_cast<InsertColumnOrder>(StringUtil::StringToEnum(GetInsertColumnOrderValues(), 2, "InsertColumnOrder", value));
-}
-
-const StringUtil::EnumStringLiteral *GetInterruptModeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(InterruptMode::NO_INTERRUPTS), "NO_INTERRUPTS" },
-        { static_cast<uint32_t>(InterruptMode::TASK), "TASK" },
-        { static_cast<uint32_t>(InterruptMode::BLOCKING), "BLOCKING" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<InterruptMode>(InterruptMode value) {
-    return StringUtil::EnumToString(GetInterruptModeValues(), 3, "InterruptMode", static_cast<uint32_t>(value));
-}
-
-template<>
-InterruptMode EnumUtil::FromString<InterruptMode>(const char *value) {
-    return static_cast<InterruptMode>(StringUtil::StringToEnum(GetInterruptModeValues(), 3, "InterruptMode", value));
-}
-
-const StringUtil::EnumStringLiteral *GetJoinRefTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(JoinRefType::REGULAR), "REGULAR" },
-        { static_cast<uint32_t>(JoinRefType::NATURAL), "NATURAL" },
-        { static_cast<uint32_t>(JoinRefType::CROSS), "CROSS" },
-        { static_cast<uint32_t>(JoinRefType::POSITIONAL), "POSITIONAL" },
-        { static_cast<uint32_t>(JoinRefType::ASOF), "ASOF" },
-        { static_cast<uint32_t>(JoinRefType::DEPENDENT), "DEPENDENT" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<JoinRefType>(JoinRefType value) {
-    return StringUtil::EnumToString(GetJoinRefTypeValues(), 6, "JoinRefType", static_cast<uint32_t>(value));
-}
-
-template<>
-JoinRefType EnumUtil::FromString<JoinRefType>(const char *value) {
-    return static_cast<JoinRefType>(StringUtil::StringToEnum(GetJoinRefTypeValues(), 6, "JoinRefType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetJoinTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(JoinType::INVALID), "INVALID" },
-        { static_cast<uint32_t>(JoinType::LEFT), "LEFT" },
-        { static_cast<uint32_t>(JoinType::RIGHT), "RIGHT" },
-        { static_cast<uint32_t>(JoinType::INNER), "INNER" },
-        { static_cast<uint32_t>(JoinType::OUTER), "FULL" },
-        { static_cast<uint32_t>(JoinType::SEMI), "SEMI" },
-        { static_cast<uint32_t>(JoinType::ANTI), "ANTI" },
-        { static_cast<uint32_t>(JoinType::MARK), "MARK" },
-        { static_cast<uint32_t>(JoinType::SINGLE), "SINGLE" },
-        { static_cast<uint32_t>(JoinType::RIGHT_SEMI), "RIGHT_SEMI" },
-        { static_cast<uint32_t>(JoinType::RIGHT_ANTI), "RIGHT_ANTI" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<JoinType>(JoinType value) {
-    return StringUtil::EnumToString(GetJoinTypeValues(), 11, "JoinType", static_cast<uint32_t>(value));
-}
-
-template<>
-JoinType EnumUtil::FromString<JoinType>(const char *value) {
-    return static_cast<JoinType>(StringUtil::StringToEnum(GetJoinTypeValues(), 11, "JoinType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetKeywordCategoryValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(KeywordCategory::KEYWORD_RESERVED), "KEYWORD_RESERVED" },
-        { static_cast<uint32_t>(KeywordCategory::KEYWORD_UNRESERVED), "KEYWORD_UNRESERVED" },
-        { static_cast<uint32_t>(KeywordCategory::KEYWORD_TYPE_FUNC), "KEYWORD_TYPE_FUNC" },
-        { static_cast<uint32_t>(KeywordCategory::KEYWORD_COL_NAME), "KEYWORD_COL_NAME" },
-        { static_cast<uint32_t>(KeywordCategory::KEYWORD_NONE), "KEYWORD_NONE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<KeywordCategory>(KeywordCategory value) {
-    return StringUtil::EnumToString(GetKeywordCategoryValues(), 5, "KeywordCategory", static_cast<uint32_t>(value));
-}
-
-template<>
-KeywordCategory EnumUtil::FromString<KeywordCategory>(const char *value) {
-    return static_cast<KeywordCategory>(StringUtil::StringToEnum(GetKeywordCategoryValues(), 5, "KeywordCategory", value));
-}
-
-const StringUtil::EnumStringLiteral *GetLimitNodeTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(LimitNodeType::UNSET), "UNSET" },
-        { static_cast<uint32_t>(LimitNodeType::CONSTANT_VALUE), "CONSTANT_VALUE" },
-        { static_cast<uint32_t>(LimitNodeType::CONSTANT_PERCENTAGE), "CONSTANT_PERCENTAGE" },
-        { static_cast<uint32_t>(LimitNodeType::EXPRESSION_VALUE), "EXPRESSION_VALUE" },
-        { static_cast<uint32_t>(LimitNodeType::EXPRESSION_PERCENTAGE), "EXPRESSION_PERCENTAGE" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<LimitNodeType>(LimitNodeType value) {
-    return StringUtil::EnumToString(GetLimitNodeTypeValues(), 5, "LimitNodeType", static_cast<uint32_t>(value));
-}
-
-template<>
-LimitNodeType EnumUtil::FromString<LimitNodeType>(const char *value) {
-    return static_cast<LimitNodeType>(StringUtil::StringToEnum(GetLimitNodeTypeValues(), 5, "LimitNodeType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetLoadTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(LoadType::LOAD), "LOAD" },
-        { static_cast<uint32_t>(LoadType::INSTALL), "INSTALL" },
-        { static_cast<uint32_t>(LoadType::FORCE_INSTALL), "FORCE_INSTALL" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<LoadType>(LoadType value) {
-    return StringUtil::EnumToString(GetLoadTypeValues(), 3, "LoadType", static_cast<uint32_t>(value));
-}
-
-template<>
-LoadType EnumUtil::FromString<LoadType>(const char *value) {
-    return static_cast<LoadType>(StringUtil::StringToEnum(GetLoadTypeValues(), 3, "LoadType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetLogicalOperatorTypeValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_INVALID), "LOGICAL_INVALID" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_PROJECTION), "LOGICAL_PROJECTION" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_FILTER), "LOGICAL_FILTER" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY), "LOGICAL_AGGREGATE_AND_GROUP_BY" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_WINDOW), "LOGICAL_WINDOW" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_UNNEST), "LOGICAL_UNNEST" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_LIMIT), "LOGICAL_LIMIT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_ORDER_BY), "LOGICAL_ORDER_BY" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_TOP_N), "LOGICAL_TOP_N" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_COPY_TO_FILE), "LOGICAL_COPY_TO_FILE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DISTINCT), "LOGICAL_DISTINCT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_SAMPLE), "LOGICAL_SAMPLE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_PIVOT), "LOGICAL_PIVOT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_COPY_DATABASE), "LOGICAL_COPY_DATABASE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_GET), "LOGICAL_GET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CHUNK_GET), "LOGICAL_CHUNK_GET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DELIM_GET), "LOGICAL_DELIM_GET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXPRESSION_GET), "LOGICAL_EXPRESSION_GET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DUMMY_SCAN), "LOGICAL_DUMMY_SCAN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EMPTY_RESULT), "LOGICAL_EMPTY_RESULT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CTE_REF), "LOGICAL_CTE_REF" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_JOIN), "LOGICAL_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DELIM_JOIN), "LOGICAL_DELIM_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_COMPARISON_JOIN), "LOGICAL_COMPARISON_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_ANY_JOIN), "LOGICAL_ANY_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CROSS_PRODUCT), "LOGICAL_CROSS_PRODUCT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_POSITIONAL_JOIN), "LOGICAL_POSITIONAL_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_ASOF_JOIN), "LOGICAL_ASOF_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DEPENDENT_JOIN), "LOGICAL_DEPENDENT_JOIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_UNION), "LOGICAL_UNION" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXCEPT), "LOGICAL_EXCEPT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_INTERSECT), "LOGICAL_INTERSECT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_RECURSIVE_CTE), "LOGICAL_RECURSIVE_CTE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_MATERIALIZED_CTE), "LOGICAL_MATERIALIZED_CTE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_INSERT), "LOGICAL_INSERT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DELETE), "LOGICAL_DELETE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_UPDATE), "LOGICAL_UPDATE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_ALTER), "LOGICAL_ALTER" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_TABLE), "LOGICAL_CREATE_TABLE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_INDEX), "LOGICAL_CREATE_INDEX" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_SEQUENCE), "LOGICAL_CREATE_SEQUENCE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_VIEW), "LOGICAL_CREATE_VIEW" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_SCHEMA), "LOGICAL_CREATE_SCHEMA" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_MACRO), "LOGICAL_CREATE_MACRO" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DROP), "LOGICAL_DROP" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_PRAGMA), "LOGICAL_PRAGMA" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_TRANSACTION), "LOGICAL_TRANSACTION" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_TYPE), "LOGICAL_CREATE_TYPE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_ATTACH), "LOGICAL_ATTACH" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_DETACH), "LOGICAL_DETACH" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXPLAIN), "LOGICAL_EXPLAIN" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_PREPARE), "LOGICAL_PREPARE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXECUTE), "LOGICAL_EXECUTE" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXPORT), "LOGICAL_EXPORT" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_VACUUM), "LOGICAL_VACUUM" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_SET), "LOGICAL_SET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_LOAD), "LOGICAL_LOAD" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_RESET), "LOGICAL_RESET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS), "LOGICAL_UPDATE_EXTENSIONS" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_CREATE_SECRET), "LOGICAL_CREATE_SECRET" },
-        { static_cast<uint32_t>(LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR), "LOGICAL_EXTENSION_OPERATOR" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<LogicalOperatorType>(LogicalOperatorType value) {
-    return StringUtil::EnumToString(GetLogicalOperatorTypeValues(), 61, "LogicalOperatorType", static_cast<uint32_t>(value));
-}
-
-template<>
-LogicalOperatorType EnumUtil::FromString<LogicalOperatorType>(const char *value) {
-    return static_cast<LogicalOperatorType>(StringUtil::StringToEnum(GetLogicalOperatorTypeValues(), 61, "LogicalOperatorType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetLogicalTypeIdValues() {
-    static constexpr StringUtil::EnumStringLiteral values[] {
-        { static_cast<uint32_t>(LogicalTypeId::INVALID), "INVALID" },
-        { static_cast<uint32_t>(LogicalTypeId::SQLNULL), "NULL" },
-        { static_cast<uint32_t>(LogicalTypeId::UNKNOWN), "UNKNOWN" },
-        { static_cast<uint32_t>(LogicalTypeId::ANY), "ANY" },
-        { static_cast<uint32_t>(LogicalTypeId::USER), "USER" },
-        { static_cast<uint32_t>(LogicalTypeId::BOOLEAN), "BOOLEAN" },
-        { static_cast<uint32_t>(LogicalTypeId::TINYINT), "TINYINT" },
-        { static_cast<uint32_t>(LogicalTypeId::SMALLINT), "SMALLINT" },
-        { static_cast<uint32_t>(LogicalTypeId::INTEGER), "INTEGER" },
-        { static_cast<uint32_t>(LogicalTypeId::BIGINT), "BIGINT" },
-        { static_cast<uint32_t>(LogicalTypeId::DATE), "DATE" },
-        { static_cast<uint32_t>(LogicalTypeId::TIME), "TIME" },
-        { static_cast<uint32_t>(LogicalTypeId::TIMESTAMP_SEC), "TIMESTAMP_S" },
-        { static_cast<uint32_t>(LogicalTypeId::TIMESTAMP_MS), "TIMESTAMP_MS" },
-        { static_cast<uint32_t>(LogicalTypeId::TIMESTAMP), "TIMESTAMP" },
-        { static_cast<uint32_t>(LogicalTypeId::TIMESTAMP_NS), "TIMESTAMP_NS" },
-        { static_cast<uint32_t>(LogicalTypeId::DECIMAL), "DECIMAL" },
-        { static_cast<uint32_t>(LogicalTypeId::FLOAT), "FLOAT" },
-        { static_cast<uint32_t>(LogicalTypeId::DOUBLE), "DOUBLE" },
-        { static_cast<uint32_t>(LogicalTypeId::CHAR), "CHAR" },
-        { static_cast<uint32_t>(LogicalTypeId::VARCHAR), "VARCHAR" },
-        { static_cast<uint32_t>(LogicalTypeId::BLOB), "BLOB" },
-        { static_cast<uint32_t>(LogicalTypeId::INTERVAL), "INTERVAL" },
-        { static_cast<uint32_t>(LogicalTypeId::UTINYINT), "UTINYINT" },
-        { static_cast<uint32_t>(LogicalTypeId::USMALLINT), "USMALLINT" },
-        { static_cast<uint32_t>(LogicalTypeId::UINTEGER), "UINTEGER" },
-        { static_cast<uint32_t>(LogicalTypeId::UBIGINT), "UBIGINT" },
-        { static_cast<uint32_t>(LogicalTypeId::TIMESTAMP_TZ), "TIMESTAMP WITH TIME ZONE" },
-        { static_cast<uint32_t>(LogicalTypeId::TIME_TZ), "TIME WITH TIME ZONE" },
-        { static_cast<uint32_t>(LogicalTypeId::BIT), "BIT" },
-        { static_cast<uint32_t>(LogicalTypeId::STRING_LITERAL), "STRING_LITERAL" },
-        { static_cast<uint32_t>(LogicalTypeId::INTEGER_LITERAL), "INTEGER_LITERAL" },
-        { static_cast<uint32_t>(LogicalTypeId::VARINT), "VARINT" },
-        { static_cast<uint32_t>(LogicalTypeId::UHUGEINT), "UHUGEINT" },
-        { static_cast<uint32_t>(LogicalTypeId::HUGEINT), "HUGEINT" },
-        { static_cast<uint32_t>(LogicalTypeId::POINTER), "POINTER" },
-        { static_cast<uint32_t>(LogicalTypeId::VALIDITY), "VALIDITY" },
-        { static_cast<uint32_t>(LogicalTypeId::UUID), "UUID" },
-        { static_cast<uint32_t>(LogicalTypeId::STRUCT), "STRUCT" },
-        { static_cast<uint32_t>(LogicalTypeId::LIST), "LIST" },
-        { static_cast<uint32_t>(LogicalTypeId::MAP), "MAP" },
-        { static_cast<uint32_t>(LogicalTypeId::TABLE), "TABLE" },
-        { static_cast<uint32_t>(LogicalTypeId::ENUM), "ENUM" },
-        { static_cast<uint32_t>(LogicalTypeId::AGGREGATE_STATE), "AGGREGATE_STATE" },
-        { static_cast<uint32_t>(LogicalTypeId::LAMBDA), "LAMBDA" },
-        { static_cast<uint32_t>(LogicalTypeId::UNION), "UNION" },
-        { static_cast<uint32_t>(LogicalTypeId::ARRAY), "ARRAY" }
-    };
-    return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<LogicalTypeId>(LogicalTypeId value) {
-    return StringUtil::EnumToString(GetLogicalTypeIdValues(), 47, "LogicalTypeId", static_cast<uint32_t>(value));
-}
-
-template<>
-LogicalTypeId EnumUtil::FromString<LogicalTypeId>(const char *value) {
-    return static_cast<LogicalTypeId>(StringUtil::StringToEnum(GetLogicalTypeIdValues(), 47, "LogicalTypeId", value));
-}
-
StringUtil::EnumToString(GetLookupResultTypeValues(), 3, "LookupResultType", static_cast(value)); -} - -template<> -LookupResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetLookupResultTypeValues(), 3, "LookupResultType", value)); -} - -const StringUtil::EnumStringLiteral *GetMacroTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MacroType::VOID_MACRO), "VOID_MACRO" }, - { static_cast(MacroType::TABLE_MACRO), "TABLE_MACRO" }, - { static_cast(MacroType::SCALAR_MACRO), "SCALAR_MACRO" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MacroType value) { - return StringUtil::EnumToString(GetMacroTypeValues(), 3, "MacroType", static_cast(value)); -} - -template<> -MacroType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMacroTypeValues(), 3, "MacroType", value)); -} - -const StringUtil::EnumStringLiteral *GetMapInvalidReasonValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MapInvalidReason::VALID), "VALID" }, - { static_cast(MapInvalidReason::NULL_KEY), "NULL_KEY" }, - { static_cast(MapInvalidReason::DUPLICATE_KEY), "DUPLICATE_KEY" }, - { static_cast(MapInvalidReason::NOT_ALIGNED), "NOT_ALIGNED" }, - { static_cast(MapInvalidReason::INVALID_PARAMS), "INVALID_PARAMS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MapInvalidReason value) { - return StringUtil::EnumToString(GetMapInvalidReasonValues(), 5, "MapInvalidReason", static_cast(value)); -} - -template<> -MapInvalidReason EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMapInvalidReasonValues(), 5, "MapInvalidReason", value)); -} - -const StringUtil::EnumStringLiteral *GetMemoryTagValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MemoryTag::BASE_TABLE), "BASE_TABLE" }, - { static_cast(MemoryTag::HASH_TABLE), "HASH_TABLE" }, - { static_cast(MemoryTag::PARQUET_READER), "PARQUET_READER" }, - { static_cast(MemoryTag::CSV_READER), "CSV_READER" }, - { static_cast(MemoryTag::ORDER_BY), "ORDER_BY" }, - { static_cast(MemoryTag::ART_INDEX), "ART_INDEX" }, - { static_cast(MemoryTag::COLUMN_DATA), "COLUMN_DATA" }, - { static_cast(MemoryTag::METADATA), "METADATA" }, - { static_cast(MemoryTag::OVERFLOW_STRINGS), "OVERFLOW_STRINGS" }, - { static_cast(MemoryTag::IN_MEMORY_TABLE), "IN_MEMORY_TABLE" }, - { static_cast(MemoryTag::ALLOCATOR), "ALLOCATOR" }, - { static_cast(MemoryTag::EXTENSION), "EXTENSION" }, - { static_cast(MemoryTag::TRANSACTION), "TRANSACTION" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MemoryTag value) { - return StringUtil::EnumToString(GetMemoryTagValues(), 13, "MemoryTag", static_cast(value)); -} - -template<> -MemoryTag EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMemoryTagValues(), 13, "MemoryTag", value)); -} - -const StringUtil::EnumStringLiteral *GetMetaPipelineTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetaPipelineType::REGULAR), "REGULAR" }, - { static_cast(MetaPipelineType::JOIN_BUILD), "JOIN_BUILD" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MetaPipelineType value) { - return StringUtil::EnumToString(GetMetaPipelineTypeValues(), 2, "MetaPipelineType", static_cast(value)); -} - -template<> -MetaPipelineType EnumUtil::FromString(const char *value) { - return 
static_cast(StringUtil::StringToEnum(GetMetaPipelineTypeValues(), 2, "MetaPipelineType", value)); -} - -const StringUtil::EnumStringLiteral *GetMetricsTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(MetricsType::QUERY_NAME), "QUERY_NAME" }, - { static_cast(MetricsType::BLOCKED_THREAD_TIME), "BLOCKED_THREAD_TIME" }, - { static_cast(MetricsType::CPU_TIME), "CPU_TIME" }, - { static_cast(MetricsType::EXTRA_INFO), "EXTRA_INFO" }, - { static_cast(MetricsType::CUMULATIVE_CARDINALITY), "CUMULATIVE_CARDINALITY" }, - { static_cast(MetricsType::OPERATOR_TYPE), "OPERATOR_TYPE" }, - { static_cast(MetricsType::OPERATOR_CARDINALITY), "OPERATOR_CARDINALITY" }, - { static_cast(MetricsType::CUMULATIVE_ROWS_SCANNED), "CUMULATIVE_ROWS_SCANNED" }, - { static_cast(MetricsType::OPERATOR_ROWS_SCANNED), "OPERATOR_ROWS_SCANNED" }, - { static_cast(MetricsType::OPERATOR_TIMING), "OPERATOR_TIMING" }, - { static_cast(MetricsType::RESULT_SET_SIZE), "RESULT_SET_SIZE" }, - { static_cast(MetricsType::LATENCY), "LATENCY" }, - { static_cast(MetricsType::ROWS_RETURNED), "ROWS_RETURNED" }, - { static_cast(MetricsType::OPERATOR_NAME), "OPERATOR_NAME" }, - { static_cast(MetricsType::ALL_OPTIMIZERS), "ALL_OPTIMIZERS" }, - { static_cast(MetricsType::CUMULATIVE_OPTIMIZER_TIMING), "CUMULATIVE_OPTIMIZER_TIMING" }, - { static_cast(MetricsType::PLANNER), "PLANNER" }, - { static_cast(MetricsType::PLANNER_BINDING), "PLANNER_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER), "PHYSICAL_PLANNER" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING), "PHYSICAL_PLANNER_COLUMN_BINDING" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES), "PHYSICAL_PLANNER_RESOLVE_TYPES" }, - { static_cast(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN), "PHYSICAL_PLANNER_CREATE_PLAN" }, - { static_cast(MetricsType::OPTIMIZER_EXPRESSION_REWRITER), "OPTIMIZER_EXPRESSION_REWRITER" }, - { static_cast(MetricsType::OPTIMIZER_FILTER_PULLUP), "OPTIMIZER_FILTER_PULLUP" }, - { static_cast(MetricsType::OPTIMIZER_FILTER_PUSHDOWN), "OPTIMIZER_FILTER_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP), "OPTIMIZER_EMPTY_RESULT_PULLUP" }, - { static_cast(MetricsType::OPTIMIZER_CTE_FILTER_PUSHER), "OPTIMIZER_CTE_FILTER_PUSHER" }, - { static_cast(MetricsType::OPTIMIZER_REGEX_RANGE), "OPTIMIZER_REGEX_RANGE" }, - { static_cast(MetricsType::OPTIMIZER_IN_CLAUSE), "OPTIMIZER_IN_CLAUSE" }, - { static_cast(MetricsType::OPTIMIZER_JOIN_ORDER), "OPTIMIZER_JOIN_ORDER" }, - { static_cast(MetricsType::OPTIMIZER_DELIMINATOR), "OPTIMIZER_DELIMINATOR" }, - { static_cast(MetricsType::OPTIMIZER_UNNEST_REWRITER), "OPTIMIZER_UNNEST_REWRITER" }, - { static_cast(MetricsType::OPTIMIZER_UNUSED_COLUMNS), "OPTIMIZER_UNUSED_COLUMNS" }, - { static_cast(MetricsType::OPTIMIZER_STATISTICS_PROPAGATION), "OPTIMIZER_STATISTICS_PROPAGATION" }, - { static_cast(MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS), "OPTIMIZER_COMMON_SUBEXPRESSIONS" }, - { static_cast(MetricsType::OPTIMIZER_COMMON_AGGREGATE), "OPTIMIZER_COMMON_AGGREGATE" }, - { static_cast(MetricsType::OPTIMIZER_COLUMN_LIFETIME), "OPTIMIZER_COLUMN_LIFETIME" }, - { static_cast(MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE), "OPTIMIZER_BUILD_SIDE_PROBE_SIDE" }, - { static_cast(MetricsType::OPTIMIZER_LIMIT_PUSHDOWN), "OPTIMIZER_LIMIT_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_TOP_N), "OPTIMIZER_TOP_N" }, - { static_cast(MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION), "OPTIMIZER_COMPRESSED_MATERIALIZATION" }, - { 
static_cast(MetricsType::OPTIMIZER_DUPLICATE_GROUPS), "OPTIMIZER_DUPLICATE_GROUPS" }, - { static_cast(MetricsType::OPTIMIZER_REORDER_FILTER), "OPTIMIZER_REORDER_FILTER" }, - { static_cast(MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN), "OPTIMIZER_SAMPLING_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN), "OPTIMIZER_JOIN_FILTER_PUSHDOWN" }, - { static_cast(MetricsType::OPTIMIZER_EXTENSION), "OPTIMIZER_EXTENSION" }, - { static_cast(MetricsType::OPTIMIZER_MATERIALIZED_CTE), "OPTIMIZER_MATERIALIZED_CTE" }, - { static_cast(MetricsType::OPTIMIZER_SUM_REWRITER), "OPTIMIZER_SUM_REWRITER" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(MetricsType value) { - return StringUtil::EnumToString(GetMetricsTypeValues(), 48, "MetricsType", static_cast(value)); -} - -template<> -MetricsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetMetricsTypeValues(), 48, "MetricsType", value)); -} - -const StringUtil::EnumStringLiteral *GetNTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(NType::PREFIX), "PREFIX" }, - { static_cast(NType::LEAF), "LEAF" }, - { static_cast(NType::NODE_4), "NODE_4" }, - { static_cast(NType::NODE_16), "NODE_16" }, - { static_cast(NType::NODE_48), "NODE_48" }, - { static_cast(NType::NODE_256), "NODE_256" }, - { static_cast(NType::LEAF_INLINED), "LEAF_INLINED" }, - { static_cast(NType::NODE_7_LEAF), "NODE_7_LEAF" }, - { static_cast(NType::NODE_15_LEAF), "NODE_15_LEAF" }, - { static_cast(NType::NODE_256_LEAF), "NODE_256_LEAF" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(NType value) { - return StringUtil::EnumToString(GetNTypeValues(), 10, "NType", static_cast(value)); -} - -template<> -NType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetNTypeValues(), 10, "NType", value)); -} - -const StringUtil::EnumStringLiteral *GetNewLineIdentifierValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(NewLineIdentifier::SINGLE_N), "SINGLE_N" }, - { static_cast(NewLineIdentifier::CARRY_ON), "CARRY_ON" }, - { static_cast(NewLineIdentifier::NOT_SET), "NOT_SET" }, - { static_cast(NewLineIdentifier::SINGLE_R), "SINGLE_R" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(NewLineIdentifier value) { - return StringUtil::EnumToString(GetNewLineIdentifierValues(), 4, "NewLineIdentifier", static_cast(value)); -} - -template<> -NewLineIdentifier EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetNewLineIdentifierValues(), 4, "NewLineIdentifier", value)); -} - -const StringUtil::EnumStringLiteral *GetOnConflictActionValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(OnConflictAction::THROW), "THROW" }, - { static_cast(OnConflictAction::NOTHING), "NOTHING" }, - { static_cast(OnConflictAction::UPDATE), "UPDATE" }, - { static_cast(OnConflictAction::REPLACE), "REPLACE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(OnConflictAction value) { - return StringUtil::EnumToString(GetOnConflictActionValues(), 4, "OnConflictAction", static_cast(value)); -} - -template<> -OnConflictAction EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetOnConflictActionValues(), 4, "OnConflictAction", value)); -} - -const StringUtil::EnumStringLiteral *GetOnCreateConflictValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { 
-const StringUtil::EnumStringLiteral *GetOnCreateConflictValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OnCreateConflict::ERROR_ON_CONFLICT), "ERROR_ON_CONFLICT" },
-		{ static_cast<uint32_t>(OnCreateConflict::IGNORE_ON_CONFLICT), "IGNORE_ON_CONFLICT" },
-		{ static_cast<uint32_t>(OnCreateConflict::REPLACE_ON_CONFLICT), "REPLACE_ON_CONFLICT" },
-		{ static_cast<uint32_t>(OnCreateConflict::ALTER_ON_CONFLICT), "ALTER_ON_CONFLICT" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OnCreateConflict>(OnCreateConflict value) {
-	return StringUtil::EnumToString(GetOnCreateConflictValues(), 4, "OnCreateConflict", static_cast<uint32_t>(value));
-}
-
-template<>
-OnCreateConflict EnumUtil::FromString<OnCreateConflict>(const char *value) {
-	return static_cast<OnCreateConflict>(StringUtil::StringToEnum(GetOnCreateConflictValues(), 4, "OnCreateConflict", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOnEntryNotFoundValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OnEntryNotFound::THROW_EXCEPTION), "THROW_EXCEPTION" },
-		{ static_cast<uint32_t>(OnEntryNotFound::RETURN_NULL), "RETURN_NULL" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OnEntryNotFound>(OnEntryNotFound value) {
-	return StringUtil::EnumToString(GetOnEntryNotFoundValues(), 2, "OnEntryNotFound", static_cast<uint32_t>(value));
-}
-
-template<>
-OnEntryNotFound EnumUtil::FromString<OnEntryNotFound>(const char *value) {
-	return static_cast<OnEntryNotFound>(StringUtil::StringToEnum(GetOnEntryNotFoundValues(), 2, "OnEntryNotFound", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOperatorFinalizeResultTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OperatorFinalizeResultType::HAVE_MORE_OUTPUT), "HAVE_MORE_OUTPUT" },
-		{ static_cast<uint32_t>(OperatorFinalizeResultType::FINISHED), "FINISHED" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OperatorFinalizeResultType>(OperatorFinalizeResultType value) {
-	return StringUtil::EnumToString(GetOperatorFinalizeResultTypeValues(), 2, "OperatorFinalizeResultType", static_cast<uint32_t>(value));
-}
-
-template<>
-OperatorFinalizeResultType EnumUtil::FromString<OperatorFinalizeResultType>(const char *value) {
-	return static_cast<OperatorFinalizeResultType>(StringUtil::StringToEnum(GetOperatorFinalizeResultTypeValues(), 2, "OperatorFinalizeResultType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOperatorResultTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OperatorResultType::NEED_MORE_INPUT), "NEED_MORE_INPUT" },
-		{ static_cast<uint32_t>(OperatorResultType::HAVE_MORE_OUTPUT), "HAVE_MORE_OUTPUT" },
-		{ static_cast<uint32_t>(OperatorResultType::FINISHED), "FINISHED" },
-		{ static_cast<uint32_t>(OperatorResultType::BLOCKED), "BLOCKED" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OperatorResultType>(OperatorResultType value) {
-	return StringUtil::EnumToString(GetOperatorResultTypeValues(), 4, "OperatorResultType", static_cast<uint32_t>(value));
-}
-
-template<>
-OperatorResultType EnumUtil::FromString<OperatorResultType>(const char *value) {
-	return static_cast<OperatorResultType>(StringUtil::StringToEnum(GetOperatorResultTypeValues(), 4, "OperatorResultType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOptimizerTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OptimizerType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(OptimizerType::EXPRESSION_REWRITER), "EXPRESSION_REWRITER" },
-		{ static_cast<uint32_t>(OptimizerType::FILTER_PULLUP), "FILTER_PULLUP" },
-		{ static_cast<uint32_t>(OptimizerType::FILTER_PUSHDOWN), "FILTER_PUSHDOWN" },
-		{ static_cast<uint32_t>(OptimizerType::EMPTY_RESULT_PULLUP), "EMPTY_RESULT_PULLUP" },
-		{ static_cast<uint32_t>(OptimizerType::CTE_FILTER_PUSHER), "CTE_FILTER_PUSHER" },
-		{ static_cast<uint32_t>(OptimizerType::REGEX_RANGE), "REGEX_RANGE" },
-		{ static_cast<uint32_t>(OptimizerType::IN_CLAUSE), "IN_CLAUSE" },
-		{ static_cast<uint32_t>(OptimizerType::JOIN_ORDER), "JOIN_ORDER" },
-		{ static_cast<uint32_t>(OptimizerType::DELIMINATOR), "DELIMINATOR" },
-		{ static_cast<uint32_t>(OptimizerType::UNNEST_REWRITER), "UNNEST_REWRITER" },
-		{ static_cast<uint32_t>(OptimizerType::UNUSED_COLUMNS), "UNUSED_COLUMNS" },
-		{ static_cast<uint32_t>(OptimizerType::STATISTICS_PROPAGATION), "STATISTICS_PROPAGATION" },
-		{ static_cast<uint32_t>(OptimizerType::COMMON_SUBEXPRESSIONS), "COMMON_SUBEXPRESSIONS" },
-		{ static_cast<uint32_t>(OptimizerType::COMMON_AGGREGATE), "COMMON_AGGREGATE" },
-		{ static_cast<uint32_t>(OptimizerType::COLUMN_LIFETIME), "COLUMN_LIFETIME" },
-		{ static_cast<uint32_t>(OptimizerType::BUILD_SIDE_PROBE_SIDE), "BUILD_SIDE_PROBE_SIDE" },
-		{ static_cast<uint32_t>(OptimizerType::LIMIT_PUSHDOWN), "LIMIT_PUSHDOWN" },
-		{ static_cast<uint32_t>(OptimizerType::TOP_N), "TOP_N" },
-		{ static_cast<uint32_t>(OptimizerType::COMPRESSED_MATERIALIZATION), "COMPRESSED_MATERIALIZATION" },
-		{ static_cast<uint32_t>(OptimizerType::DUPLICATE_GROUPS), "DUPLICATE_GROUPS" },
-		{ static_cast<uint32_t>(OptimizerType::REORDER_FILTER), "REORDER_FILTER" },
-		{ static_cast<uint32_t>(OptimizerType::SAMPLING_PUSHDOWN), "SAMPLING_PUSHDOWN" },
-		{ static_cast<uint32_t>(OptimizerType::JOIN_FILTER_PUSHDOWN), "JOIN_FILTER_PUSHDOWN" },
-		{ static_cast<uint32_t>(OptimizerType::EXTENSION), "EXTENSION" },
-		{ static_cast<uint32_t>(OptimizerType::MATERIALIZED_CTE), "MATERIALIZED_CTE" },
-		{ static_cast<uint32_t>(OptimizerType::SUM_REWRITER), "SUM_REWRITER" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OptimizerType>(OptimizerType value) {
-	return StringUtil::EnumToString(GetOptimizerTypeValues(), 27, "OptimizerType", static_cast<uint32_t>(value));
-}
-
-template<>
-OptimizerType EnumUtil::FromString<OptimizerType>(const char *value) {
-	return static_cast<OptimizerType>(StringUtil::StringToEnum(GetOptimizerTypeValues(), 27, "OptimizerType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOrderByNullTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OrderByNullType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(OrderByNullType::ORDER_DEFAULT), "ORDER_DEFAULT" },
-		{ static_cast<uint32_t>(OrderByNullType::ORDER_DEFAULT), "DEFAULT" },
-		{ static_cast<uint32_t>(OrderByNullType::NULLS_FIRST), "NULLS_FIRST" },
-		{ static_cast<uint32_t>(OrderByNullType::NULLS_FIRST), "NULLS FIRST" },
-		{ static_cast<uint32_t>(OrderByNullType::NULLS_LAST), "NULLS_LAST" },
-		{ static_cast<uint32_t>(OrderByNullType::NULLS_LAST), "NULLS LAST" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OrderByNullType>(OrderByNullType value) {
-	return StringUtil::EnumToString(GetOrderByNullTypeValues(), 7, "OrderByNullType", static_cast<uint32_t>(value));
-}
-
-template<>
-OrderByNullType EnumUtil::FromString<OrderByNullType>(const char *value) {
-	return static_cast<OrderByNullType>(StringUtil::StringToEnum(GetOrderByNullTypeValues(), 7, "OrderByNullType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOrderPreservationTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OrderPreservationType::NO_ORDER), "NO_ORDER" },
-		{ static_cast<uint32_t>(OrderPreservationType::INSERTION_ORDER), "INSERTION_ORDER" },
-		{ static_cast<uint32_t>(OrderPreservationType::FIXED_ORDER), "FIXED_ORDER" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OrderPreservationType>(OrderPreservationType value) {
-	return StringUtil::EnumToString(GetOrderPreservationTypeValues(), 3, "OrderPreservationType", static_cast<uint32_t>(value));
-}
-
-template<>
-OrderPreservationType EnumUtil::FromString<OrderPreservationType>(const char *value) {
-	return static_cast<OrderPreservationType>(StringUtil::StringToEnum(GetOrderPreservationTypeValues(), 3, "OrderPreservationType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOrderTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OrderType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(OrderType::ORDER_DEFAULT), "ORDER_DEFAULT" },
-		{ static_cast<uint32_t>(OrderType::ORDER_DEFAULT), "DEFAULT" },
-		{ static_cast<uint32_t>(OrderType::ASCENDING), "ASCENDING" },
-		{ static_cast<uint32_t>(OrderType::ASCENDING), "ASC" },
-		{ static_cast<uint32_t>(OrderType::DESCENDING), "DESCENDING" },
-		{ static_cast<uint32_t>(OrderType::DESCENDING), "DESC" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OrderType>(OrderType value) {
-	return StringUtil::EnumToString(GetOrderTypeValues(), 7, "OrderType", static_cast<uint32_t>(value));
-}
-
-template<>
-OrderType EnumUtil::FromString<OrderType>(const char *value) {
-	return static_cast<OrderType>(StringUtil::StringToEnum(GetOrderTypeValues(), 7, "OrderType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetOutputStreamValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(OutputStream::STREAM_STDOUT), "STREAM_STDOUT" },
-		{ static_cast<uint32_t>(OutputStream::STREAM_STDERR), "STREAM_STDERR" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<OutputStream>(OutputStream value) {
-	return StringUtil::EnumToString(GetOutputStreamValues(), 2, "OutputStream", static_cast<uint32_t>(value));
-}
-
-template<>
-OutputStream EnumUtil::FromString<OutputStream>(const char *value) {
-	return static_cast<OutputStream>(StringUtil::StringToEnum(GetOutputStreamValues(), 2, "OutputStream", value));
-}
-
-const StringUtil::EnumStringLiteral *GetParseInfoTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(ParseInfoType::ALTER_INFO), "ALTER_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::ATTACH_INFO), "ATTACH_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::COPY_INFO), "COPY_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::CREATE_INFO), "CREATE_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::CREATE_SECRET_INFO), "CREATE_SECRET_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::DETACH_INFO), "DETACH_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::DROP_INFO), "DROP_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::BOUND_EXPORT_DATA), "BOUND_EXPORT_DATA" },
-		{ static_cast<uint32_t>(ParseInfoType::LOAD_INFO), "LOAD_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::PRAGMA_INFO), "PRAGMA_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::SHOW_SELECT_INFO), "SHOW_SELECT_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::TRANSACTION_INFO), "TRANSACTION_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::VACUUM_INFO), "VACUUM_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::COMMENT_ON_INFO), "COMMENT_ON_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::COMMENT_ON_COLUMN_INFO), "COMMENT_ON_COLUMN_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::COPY_DATABASE_INFO), "COPY_DATABASE_INFO" },
-		{ static_cast<uint32_t>(ParseInfoType::UPDATE_EXTENSIONS_INFO), "UPDATE_EXTENSIONS_INFO" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ParseInfoType>(ParseInfoType value) {
-	return StringUtil::EnumToString(GetParseInfoTypeValues(), 17, "ParseInfoType", static_cast<uint32_t>(value));
-}
-
-template<>
-ParseInfoType EnumUtil::FromString<ParseInfoType>(const char *value) {
-	return static_cast<ParseInfoType>(StringUtil::StringToEnum(GetParseInfoTypeValues(), 17, "ParseInfoType", value));
-}
-
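Note the duplicate numeric keys in tables such as GetOrderTypeValues(): ASCENDING maps to both "ASCENDING" and "ASC". With first-match-wins lookups, ToChars always returns the canonical (first) spelling, while FromString accepts any registered alias. Using the sketch helpers above (the numeric values here are illustrative, not the actual enumerator values):

// Aliased entries, shaped like GetOrderTypeValues(): canonical name first.
static constexpr EnumStringLiteral order_type_values[] {
	{0, "INVALID"}, {1, "ORDER_DEFAULT"}, {1, "DEFAULT"},
	{2, "ASCENDING"}, {2, "ASC"}, {3, "DESCENDING"}, {3, "DESC"},
};

// EnumToString(order_type_values, 7, "OrderType", 2)     -> "ASCENDING" (first match)
// StringToEnum(order_type_values, 7, "OrderType", "ASC") -> 2 (alias accepted)
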
-const StringUtil::EnumStringLiteral *GetParserExtensionResultTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(ParserExtensionResultType::PARSE_SUCCESSFUL), "PARSE_SUCCESSFUL" },
-		{ static_cast<uint32_t>(ParserExtensionResultType::DISPLAY_ORIGINAL_ERROR), "DISPLAY_ORIGINAL_ERROR" },
-		{ static_cast<uint32_t>(ParserExtensionResultType::DISPLAY_EXTENSION_ERROR), "DISPLAY_EXTENSION_ERROR" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ParserExtensionResultType>(ParserExtensionResultType value) {
-	return StringUtil::EnumToString(GetParserExtensionResultTypeValues(), 3, "ParserExtensionResultType", static_cast<uint32_t>(value));
-}
-
-template<>
-ParserExtensionResultType EnumUtil::FromString<ParserExtensionResultType>(const char *value) {
-	return static_cast<ParserExtensionResultType>(StringUtil::StringToEnum(GetParserExtensionResultTypeValues(), 3, "ParserExtensionResultType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetPartitionSortStageValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(PartitionSortStage::INIT), "INIT" },
-		{ static_cast<uint32_t>(PartitionSortStage::SCAN), "SCAN" },
-		{ static_cast<uint32_t>(PartitionSortStage::PREPARE), "PREPARE" },
-		{ static_cast<uint32_t>(PartitionSortStage::MERGE), "MERGE" },
-		{ static_cast<uint32_t>(PartitionSortStage::SORTED), "SORTED" },
-		{ static_cast<uint32_t>(PartitionSortStage::FINISHED), "FINISHED" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<PartitionSortStage>(PartitionSortStage value) {
-	return StringUtil::EnumToString(GetPartitionSortStageValues(), 6, "PartitionSortStage", static_cast<uint32_t>(value));
-}
-
-template<>
-PartitionSortStage EnumUtil::FromString<PartitionSortStage>(const char *value) {
-	return static_cast<PartitionSortStage>(StringUtil::StringToEnum(GetPartitionSortStageValues(), 6, "PartitionSortStage", value));
-}
-
-const StringUtil::EnumStringLiteral *GetPartitionedColumnDataTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(PartitionedColumnDataType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(PartitionedColumnDataType::RADIX), "RADIX" },
-		{ static_cast<uint32_t>(PartitionedColumnDataType::HIVE), "HIVE" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<PartitionedColumnDataType>(PartitionedColumnDataType value) {
-	return StringUtil::EnumToString(GetPartitionedColumnDataTypeValues(), 3, "PartitionedColumnDataType", static_cast<uint32_t>(value));
-}
-
-template<>
-PartitionedColumnDataType EnumUtil::FromString<PartitionedColumnDataType>(const char *value) {
-	return static_cast<PartitionedColumnDataType>(StringUtil::StringToEnum(GetPartitionedColumnDataTypeValues(), 3, "PartitionedColumnDataType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetPartitionedTupleDataTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(PartitionedTupleDataType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(PartitionedTupleDataType::RADIX), "RADIX" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<PartitionedTupleDataType>(PartitionedTupleDataType value) {
-	return StringUtil::EnumToString(GetPartitionedTupleDataTypeValues(), 2, "PartitionedTupleDataType", static_cast<uint32_t>(value));
-}
-
-template<>
-PartitionedTupleDataType EnumUtil::FromString<PartitionedTupleDataType>(const char *value) {
-	return static_cast<PartitionedTupleDataType>(StringUtil::StringToEnum(GetPartitionedTupleDataTypeValues(), 2, "PartitionedTupleDataType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetPendingExecutionResultValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(PendingExecutionResult::RESULT_READY), "RESULT_READY" },
-		{ static_cast<uint32_t>(PendingExecutionResult::RESULT_NOT_READY), "RESULT_NOT_READY" },
-		{ static_cast<uint32_t>(PendingExecutionResult::EXECUTION_ERROR), "EXECUTION_ERROR" },
-		{ static_cast<uint32_t>(PendingExecutionResult::BLOCKED), "BLOCKED" },
-		{ static_cast<uint32_t>(PendingExecutionResult::NO_TASKS_AVAILABLE), "NO_TASKS_AVAILABLE" },
-		{ static_cast<uint32_t>(PendingExecutionResult::EXECUTION_FINISHED), "EXECUTION_FINISHED" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<PendingExecutionResult>(PendingExecutionResult value) {
-	return StringUtil::EnumToString(GetPendingExecutionResultValues(), 6, "PendingExecutionResult", static_cast<uint32_t>(value));
-}
-
-template<>
-PendingExecutionResult EnumUtil::FromString<PendingExecutionResult>(const char *value) {
-	return static_cast<PendingExecutionResult>(StringUtil::StringToEnum(GetPendingExecutionResultValues(), 6, "PendingExecutionResult", value));
-}
-
-const StringUtil::EnumStringLiteral *GetPhysicalOperatorTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(PhysicalOperatorType::INVALID), "INVALID" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::ORDER_BY), "ORDER_BY" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::LIMIT), "LIMIT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::STREAMING_LIMIT), "STREAMING_LIMIT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::LIMIT_PERCENT), "LIMIT_PERCENT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::TOP_N), "TOP_N" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::WINDOW), "WINDOW" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::UNNEST), "UNNEST" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::UNGROUPED_AGGREGATE), "UNGROUPED_AGGREGATE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::HASH_GROUP_BY), "HASH_GROUP_BY" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PERFECT_HASH_GROUP_BY), "PERFECT_HASH_GROUP_BY" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PARTITIONED_AGGREGATE), "PARTITIONED_AGGREGATE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::FILTER), "FILTER" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PROJECTION), "PROJECTION" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::COPY_TO_FILE), "COPY_TO_FILE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::BATCH_COPY_TO_FILE), "BATCH_COPY_TO_FILE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RESERVOIR_SAMPLE), "RESERVOIR_SAMPLE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::STREAMING_SAMPLE), "STREAMING_SAMPLE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::STREAMING_WINDOW), "STREAMING_WINDOW" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PIVOT), "PIVOT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::COPY_DATABASE), "COPY_DATABASE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::TABLE_SCAN), "TABLE_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::DUMMY_SCAN), "DUMMY_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::COLUMN_DATA_SCAN), "COLUMN_DATA_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CHUNK_SCAN), "CHUNK_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RECURSIVE_CTE_SCAN), "RECURSIVE_CTE_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CTE_SCAN), "CTE_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::DELIM_SCAN), "DELIM_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXPRESSION_SCAN), "EXPRESSION_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::POSITIONAL_SCAN), "POSITIONAL_SCAN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::BLOCKWISE_NL_JOIN), "BLOCKWISE_NL_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::NESTED_LOOP_JOIN), "NESTED_LOOP_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::HASH_JOIN), "HASH_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CROSS_PRODUCT), "CROSS_PRODUCT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PIECEWISE_MERGE_JOIN), "PIECEWISE_MERGE_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::IE_JOIN), "IE_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::LEFT_DELIM_JOIN), "LEFT_DELIM_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RIGHT_DELIM_JOIN), "RIGHT_DELIM_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::POSITIONAL_JOIN), "POSITIONAL_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::ASOF_JOIN), "ASOF_JOIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::UNION), "UNION" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RECURSIVE_CTE), "RECURSIVE_CTE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CTE), "CTE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::INSERT), "INSERT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::BATCH_INSERT), "BATCH_INSERT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::DELETE_OPERATOR), "DELETE_OPERATOR" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::UPDATE), "UPDATE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_TABLE), "CREATE_TABLE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_TABLE_AS), "CREATE_TABLE_AS" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::BATCH_CREATE_TABLE_AS), "BATCH_CREATE_TABLE_AS" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_INDEX), "CREATE_INDEX" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::ALTER), "ALTER" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_SEQUENCE), "CREATE_SEQUENCE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_VIEW), "CREATE_VIEW" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_SCHEMA), "CREATE_SCHEMA" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_MACRO), "CREATE_MACRO" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::DROP), "DROP" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PRAGMA), "PRAGMA" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::TRANSACTION), "TRANSACTION" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_TYPE), "CREATE_TYPE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::ATTACH), "ATTACH" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::DETACH), "DETACH" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXPLAIN), "EXPLAIN" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXPLAIN_ANALYZE), "EXPLAIN_ANALYZE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EMPTY_RESULT), "EMPTY_RESULT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXECUTE), "EXECUTE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::PREPARE), "PREPARE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::VACUUM), "VACUUM" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXPORT), "EXPORT" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::SET), "SET" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::SET_VARIABLE), "SET_VARIABLE" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::LOAD), "LOAD" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::INOUT_FUNCTION), "INOUT_FUNCTION" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RESULT_COLLECTOR), "RESULT_COLLECTOR" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::RESET), "RESET" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::EXTENSION), "EXTENSION" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::VERIFY_VECTOR), "VERIFY_VECTOR" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::UPDATE_EXTENSIONS), "UPDATE_EXTENSIONS" },
-		{ static_cast<uint32_t>(PhysicalOperatorType::CREATE_SECRET), "CREATE_SECRET" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<PhysicalOperatorType>(PhysicalOperatorType value) {
-	return StringUtil::EnumToString(GetPhysicalOperatorTypeValues(), 79, "PhysicalOperatorType", static_cast<uint32_t>(value));
-}
-
-template<>
-PhysicalOperatorType EnumUtil::FromString<PhysicalOperatorType>(const char *value) {
-	return static_cast<PhysicalOperatorType>(StringUtil::StringToEnum(GetPhysicalOperatorTypeValues(), 79, "PhysicalOperatorType", value));
-}
-
static_cast(PhysicalType::VARCHAR), "VARCHAR" }, - { static_cast(PhysicalType::UINT128), "UINT128" }, - { static_cast(PhysicalType::INT128), "INT128" }, - { static_cast(PhysicalType::UNKNOWN), "UNKNOWN" }, - { static_cast(PhysicalType::BIT), "BIT" }, - { static_cast(PhysicalType::INVALID), "INVALID" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PhysicalType value) { - return StringUtil::EnumToString(GetPhysicalTypeValues(), 21, "PhysicalType", static_cast(value)); -} - -template<> -PhysicalType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPhysicalTypeValues(), 21, "PhysicalType", value)); -} - -const StringUtil::EnumStringLiteral *GetPragmaTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PragmaType::PRAGMA_STATEMENT), "PRAGMA_STATEMENT" }, - { static_cast(PragmaType::PRAGMA_CALL), "PRAGMA_CALL" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PragmaType value) { - return StringUtil::EnumToString(GetPragmaTypeValues(), 2, "PragmaType", static_cast(value)); -} - -template<> -PragmaType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPragmaTypeValues(), 2, "PragmaType", value)); -} - -const StringUtil::EnumStringLiteral *GetPreparedParamTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PreparedParamType::AUTO_INCREMENT), "AUTO_INCREMENT" }, - { static_cast(PreparedParamType::POSITIONAL), "POSITIONAL" }, - { static_cast(PreparedParamType::NAMED), "NAMED" }, - { static_cast(PreparedParamType::INVALID), "INVALID" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PreparedParamType value) { - return StringUtil::EnumToString(GetPreparedParamTypeValues(), 4, "PreparedParamType", static_cast(value)); -} - -template<> -PreparedParamType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPreparedParamTypeValues(), 4, "PreparedParamType", value)); -} - -const StringUtil::EnumStringLiteral *GetPreparedStatementModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(PreparedStatementMode::PREPARE_ONLY), "PREPARE_ONLY" }, - { static_cast(PreparedStatementMode::PREPARE_AND_EXECUTE), "PREPARE_AND_EXECUTE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(PreparedStatementMode value) { - return StringUtil::EnumToString(GetPreparedStatementModeValues(), 2, "PreparedStatementMode", static_cast(value)); -} - -template<> -PreparedStatementMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetPreparedStatementModeValues(), 2, "PreparedStatementMode", value)); -} - -const StringUtil::EnumStringLiteral *GetProfilerPrintFormatValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(ProfilerPrintFormat::QUERY_TREE), "QUERY_TREE" }, - { static_cast(ProfilerPrintFormat::JSON), "JSON" }, - { static_cast(ProfilerPrintFormat::QUERY_TREE_OPTIMIZER), "QUERY_TREE_OPTIMIZER" }, - { static_cast(ProfilerPrintFormat::NO_OUTPUT), "NO_OUTPUT" }, - { static_cast(ProfilerPrintFormat::HTML), "HTML" }, - { static_cast(ProfilerPrintFormat::GRAPHVIZ), "GRAPHVIZ" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(ProfilerPrintFormat value) { - return StringUtil::EnumToString(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", static_cast(value)); -} - -template<> -ProfilerPrintFormat 
EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetProfilerPrintFormatValues(), 6, "ProfilerPrintFormat", value)); -} - -const StringUtil::EnumStringLiteral *GetQuantileSerializationTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(QuantileSerializationType::NON_DECIMAL), "NON_DECIMAL" }, - { static_cast(QuantileSerializationType::DECIMAL_DISCRETE), "DECIMAL_DISCRETE" }, - { static_cast(QuantileSerializationType::DECIMAL_DISCRETE_LIST), "DECIMAL_DISCRETE_LIST" }, - { static_cast(QuantileSerializationType::DECIMAL_CONTINUOUS), "DECIMAL_CONTINUOUS" }, - { static_cast(QuantileSerializationType::DECIMAL_CONTINUOUS_LIST), "DECIMAL_CONTINUOUS_LIST" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(QuantileSerializationType value) { - return StringUtil::EnumToString(GetQuantileSerializationTypeValues(), 5, "QuantileSerializationType", static_cast(value)); -} - -template<> -QuantileSerializationType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQuantileSerializationTypeValues(), 5, "QuantileSerializationType", value)); -} - -const StringUtil::EnumStringLiteral *GetQueryNodeTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(QueryNodeType::SELECT_NODE), "SELECT_NODE" }, - { static_cast(QueryNodeType::SET_OPERATION_NODE), "SET_OPERATION_NODE" }, - { static_cast(QueryNodeType::BOUND_SUBQUERY_NODE), "BOUND_SUBQUERY_NODE" }, - { static_cast(QueryNodeType::RECURSIVE_CTE_NODE), "RECURSIVE_CTE_NODE" }, - { static_cast(QueryNodeType::CTE_NODE), "CTE_NODE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(QueryNodeType value) { - return StringUtil::EnumToString(GetQueryNodeTypeValues(), 5, "QueryNodeType", static_cast(value)); -} - -template<> -QueryNodeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryNodeTypeValues(), 5, "QueryNodeType", value)); -} - -const StringUtil::EnumStringLiteral *GetQueryResultTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(QueryResultType::MATERIALIZED_RESULT), "MATERIALIZED_RESULT" }, - { static_cast(QueryResultType::STREAM_RESULT), "STREAM_RESULT" }, - { static_cast(QueryResultType::PENDING_RESULT), "PENDING_RESULT" }, - { static_cast(QueryResultType::ARROW_RESULT), "ARROW_RESULT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(QueryResultType value) { - return StringUtil::EnumToString(GetQueryResultTypeValues(), 4, "QueryResultType", static_cast(value)); -} - -template<> -QueryResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQueryResultTypeValues(), 4, "QueryResultType", value)); -} - -const StringUtil::EnumStringLiteral *GetQuoteRuleValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(QuoteRule::QUOTES_RFC), "QUOTES_RFC" }, - { static_cast(QuoteRule::QUOTES_OTHER), "QUOTES_OTHER" }, - { static_cast(QuoteRule::NO_QUOTES), "NO_QUOTES" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(QuoteRule value) { - return StringUtil::EnumToString(GetQuoteRuleValues(), 3, "QuoteRule", static_cast(value)); -} - -template<> -QuoteRule EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetQuoteRuleValues(), 3, "QuoteRule", value)); -} - -const StringUtil::EnumStringLiteral *GetRelationTypeValues() { - static constexpr 
-const StringUtil::EnumStringLiteral *GetRelationTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(RelationType::INVALID_RELATION), "INVALID_RELATION" },
-		{ static_cast<uint32_t>(RelationType::TABLE_RELATION), "TABLE_RELATION" },
-		{ static_cast<uint32_t>(RelationType::PROJECTION_RELATION), "PROJECTION_RELATION" },
-		{ static_cast<uint32_t>(RelationType::FILTER_RELATION), "FILTER_RELATION" },
-		{ static_cast<uint32_t>(RelationType::EXPLAIN_RELATION), "EXPLAIN_RELATION" },
-		{ static_cast<uint32_t>(RelationType::CROSS_PRODUCT_RELATION), "CROSS_PRODUCT_RELATION" },
-		{ static_cast<uint32_t>(RelationType::JOIN_RELATION), "JOIN_RELATION" },
-		{ static_cast<uint32_t>(RelationType::AGGREGATE_RELATION), "AGGREGATE_RELATION" },
-		{ static_cast<uint32_t>(RelationType::SET_OPERATION_RELATION), "SET_OPERATION_RELATION" },
-		{ static_cast<uint32_t>(RelationType::DISTINCT_RELATION), "DISTINCT_RELATION" },
-		{ static_cast<uint32_t>(RelationType::LIMIT_RELATION), "LIMIT_RELATION" },
-		{ static_cast<uint32_t>(RelationType::ORDER_RELATION), "ORDER_RELATION" },
-		{ static_cast<uint32_t>(RelationType::CREATE_VIEW_RELATION), "CREATE_VIEW_RELATION" },
-		{ static_cast<uint32_t>(RelationType::CREATE_TABLE_RELATION), "CREATE_TABLE_RELATION" },
-		{ static_cast<uint32_t>(RelationType::INSERT_RELATION), "INSERT_RELATION" },
-		{ static_cast<uint32_t>(RelationType::VALUE_LIST_RELATION), "VALUE_LIST_RELATION" },
-		{ static_cast<uint32_t>(RelationType::MATERIALIZED_RELATION), "MATERIALIZED_RELATION" },
-		{ static_cast<uint32_t>(RelationType::DELETE_RELATION), "DELETE_RELATION" },
-		{ static_cast<uint32_t>(RelationType::UPDATE_RELATION), "UPDATE_RELATION" },
-		{ static_cast<uint32_t>(RelationType::WRITE_CSV_RELATION), "WRITE_CSV_RELATION" },
-		{ static_cast<uint32_t>(RelationType::WRITE_PARQUET_RELATION), "WRITE_PARQUET_RELATION" },
-		{ static_cast<uint32_t>(RelationType::READ_CSV_RELATION), "READ_CSV_RELATION" },
-		{ static_cast<uint32_t>(RelationType::SUBQUERY_RELATION), "SUBQUERY_RELATION" },
-		{ static_cast<uint32_t>(RelationType::TABLE_FUNCTION_RELATION), "TABLE_FUNCTION_RELATION" },
-		{ static_cast<uint32_t>(RelationType::VIEW_RELATION), "VIEW_RELATION" },
-		{ static_cast<uint32_t>(RelationType::QUERY_RELATION), "QUERY_RELATION" },
-		{ static_cast<uint32_t>(RelationType::DELIM_JOIN_RELATION), "DELIM_JOIN_RELATION" },
-		{ static_cast<uint32_t>(RelationType::DELIM_GET_RELATION), "DELIM_GET_RELATION" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<RelationType>(RelationType value) {
-	return StringUtil::EnumToString(GetRelationTypeValues(), 28, "RelationType", static_cast<uint32_t>(value));
-}
-
-template<>
-RelationType EnumUtil::FromString<RelationType>(const char *value) {
-	return static_cast<RelationType>(StringUtil::StringToEnum(GetRelationTypeValues(), 28, "RelationType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetRenderModeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(RenderMode::ROWS), "ROWS" },
-		{ static_cast<uint32_t>(RenderMode::COLUMNS), "COLUMNS" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<RenderMode>(RenderMode value) {
-	return StringUtil::EnumToString(GetRenderModeValues(), 2, "RenderMode", static_cast<uint32_t>(value));
-}
-
-template<>
-RenderMode EnumUtil::FromString<RenderMode>(const char *value) {
-	return static_cast<RenderMode>(StringUtil::StringToEnum(GetRenderModeValues(), 2, "RenderMode", value));
-}
-
-const StringUtil::EnumStringLiteral *GetResultModifierTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(ResultModifierType::LIMIT_MODIFIER), "LIMIT_MODIFIER" },
-		{ static_cast<uint32_t>(ResultModifierType::ORDER_MODIFIER), "ORDER_MODIFIER" },
-		{ static_cast<uint32_t>(ResultModifierType::DISTINCT_MODIFIER), "DISTINCT_MODIFIER" },
-		{ static_cast<uint32_t>(ResultModifierType::LIMIT_PERCENT_MODIFIER), "LIMIT_PERCENT_MODIFIER" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ResultModifierType>(ResultModifierType value) {
-	return StringUtil::EnumToString(GetResultModifierTypeValues(), 4, "ResultModifierType", static_cast<uint32_t>(value));
-}
-
-template<>
-ResultModifierType EnumUtil::FromString<ResultModifierType>(const char *value) {
-	return static_cast<ResultModifierType>(StringUtil::StringToEnum(GetResultModifierTypeValues(), 4, "ResultModifierType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSampleMethodValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SampleMethod::SYSTEM_SAMPLE), "System" },
-		{ static_cast<uint32_t>(SampleMethod::BERNOULLI_SAMPLE), "Bernoulli" },
-		{ static_cast<uint32_t>(SampleMethod::RESERVOIR_SAMPLE), "Reservoir" },
-		{ static_cast<uint32_t>(SampleMethod::INVALID), "INVALID" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SampleMethod>(SampleMethod value) {
-	return StringUtil::EnumToString(GetSampleMethodValues(), 4, "SampleMethod", static_cast<uint32_t>(value));
-}
-
-template<>
-SampleMethod EnumUtil::FromString<SampleMethod>(const char *value) {
-	return static_cast<SampleMethod>(StringUtil::StringToEnum(GetSampleMethodValues(), 4, "SampleMethod", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSampleTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SampleType::BLOCKING_SAMPLE), "BLOCKING_SAMPLE" },
-		{ static_cast<uint32_t>(SampleType::RESERVOIR_SAMPLE), "RESERVOIR_SAMPLE" },
-		{ static_cast<uint32_t>(SampleType::RESERVOIR_PERCENTAGE_SAMPLE), "RESERVOIR_PERCENTAGE_SAMPLE" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SampleType>(SampleType value) {
-	return StringUtil::EnumToString(GetSampleTypeValues(), 3, "SampleType", static_cast<uint32_t>(value));
-}
-
-template<>
-SampleType EnumUtil::FromString<SampleType>(const char *value) {
-	return static_cast<SampleType>(StringUtil::StringToEnum(GetSampleTypeValues(), 3, "SampleType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSamplingStateValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SamplingState::RANDOM), "RANDOM" },
-		{ static_cast<uint32_t>(SamplingState::RESERVOIR), "RESERVOIR" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SamplingState>(SamplingState value) {
-	return StringUtil::EnumToString(GetSamplingStateValues(), 2, "SamplingState", static_cast<uint32_t>(value));
-}
-
-template<>
-SamplingState EnumUtil::FromString<SamplingState>(const char *value) {
-	return static_cast<SamplingState>(StringUtil::StringToEnum(GetSamplingStateValues(), 2, "SamplingState", value));
-}
-
-const StringUtil::EnumStringLiteral *GetScanTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(ScanType::TABLE), "TABLE" },
-		{ static_cast<uint32_t>(ScanType::PARQUET), "PARQUET" },
-		{ static_cast<uint32_t>(ScanType::EXTERNAL), "EXTERNAL" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ScanType>(ScanType value) {
-	return StringUtil::EnumToString(GetScanTypeValues(), 3, "ScanType", static_cast<uint32_t>(value));
-}
-
-template<>
-ScanType EnumUtil::FromString<ScanType>(const char *value) {
-	return static_cast<ScanType>(StringUtil::StringToEnum(GetScanTypeValues(), 3, "ScanType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSecretDisplayTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SecretDisplayType::REDACTED), "REDACTED" },
-		{ static_cast<uint32_t>(SecretDisplayType::UNREDACTED), "UNREDACTED" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SecretDisplayType>(SecretDisplayType value) {
-	return StringUtil::EnumToString(GetSecretDisplayTypeValues(), 2, "SecretDisplayType", static_cast<uint32_t>(value));
-}
-
-template<>
-SecretDisplayType EnumUtil::FromString<SecretDisplayType>(const char *value) {
-	return static_cast<SecretDisplayType>(StringUtil::StringToEnum(GetSecretDisplayTypeValues(), 2, "SecretDisplayType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSecretPersistTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SecretPersistType::DEFAULT), "DEFAULT" },
-		{ static_cast<uint32_t>(SecretPersistType::TEMPORARY), "TEMPORARY" },
-		{ static_cast<uint32_t>(SecretPersistType::PERSISTENT), "PERSISTENT" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SecretPersistType>(SecretPersistType value) {
-	return StringUtil::EnumToString(GetSecretPersistTypeValues(), 3, "SecretPersistType", static_cast<uint32_t>(value));
-}
-
-template<>
-SecretPersistType EnumUtil::FromString<SecretPersistType>(const char *value) {
-	return static_cast<SecretPersistType>(StringUtil::StringToEnum(GetSecretPersistTypeValues(), 3, "SecretPersistType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSecretSerializationTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SecretSerializationType::CUSTOM), "CUSTOM" },
-		{ static_cast<uint32_t>(SecretSerializationType::KEY_VALUE_SECRET), "KEY_VALUE_SECRET" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SecretSerializationType>(SecretSerializationType value) {
-	return StringUtil::EnumToString(GetSecretSerializationTypeValues(), 2, "SecretSerializationType", static_cast<uint32_t>(value));
-}
-
-template<>
-SecretSerializationType EnumUtil::FromString<SecretSerializationType>(const char *value) {
-	return static_cast<SecretSerializationType>(StringUtil::StringToEnum(GetSecretSerializationTypeValues(), 2, "SecretSerializationType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSequenceInfoValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_START), "SEQ_START" },
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_INC), "SEQ_INC" },
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_MIN), "SEQ_MIN" },
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_MAX), "SEQ_MAX" },
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_CYCLE), "SEQ_CYCLE" },
-		{ static_cast<uint32_t>(SequenceInfo::SEQ_OWN), "SEQ_OWN" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SequenceInfo>(SequenceInfo value) {
-	return StringUtil::EnumToString(GetSequenceInfoValues(), 6, "SequenceInfo", static_cast<uint32_t>(value));
-}
-
-template<>
-SequenceInfo EnumUtil::FromString<SequenceInfo>(const char *value) {
-	return static_cast<SequenceInfo>(StringUtil::StringToEnum(GetSequenceInfoValues(), 6, "SequenceInfo", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSetOperationTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SetOperationType::NONE), "NONE" },
-		{ static_cast<uint32_t>(SetOperationType::UNION), "UNION" },
-		{ static_cast<uint32_t>(SetOperationType::EXCEPT), "EXCEPT" },
-		{ static_cast<uint32_t>(SetOperationType::INTERSECT), "INTERSECT" },
-		{ static_cast<uint32_t>(SetOperationType::UNION_BY_NAME), "UNION_BY_NAME" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SetOperationType>(SetOperationType value) {
-	return StringUtil::EnumToString(GetSetOperationTypeValues(), 5, "SetOperationType", static_cast<uint32_t>(value));
-}
-
-template<>
-SetOperationType EnumUtil::FromString<SetOperationType>(const char *value) {
-	return static_cast<SetOperationType>(StringUtil::StringToEnum(GetSetOperationTypeValues(), 5, "SetOperationType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSetScopeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SetScope::AUTOMATIC), "AUTOMATIC" },
-		{ static_cast<uint32_t>(SetScope::LOCAL), "LOCAL" },
-		{ static_cast<uint32_t>(SetScope::SESSION), "SESSION" },
-		{ static_cast<uint32_t>(SetScope::GLOBAL), "GLOBAL" },
-		{ static_cast<uint32_t>(SetScope::VARIABLE), "VARIABLE" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SetScope>(SetScope value) {
-	return StringUtil::EnumToString(GetSetScopeValues(), 5, "SetScope", static_cast<uint32_t>(value));
-}
-
-template<>
-SetScope EnumUtil::FromString<SetScope>(const char *value) {
-	return static_cast<SetScope>(StringUtil::StringToEnum(GetSetScopeValues(), 5, "SetScope", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSetTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SetType::SET), "SET" },
-		{ static_cast<uint32_t>(SetType::RESET), "RESET" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SetType>(SetType value) {
-	return StringUtil::EnumToString(GetSetTypeValues(), 2, "SetType", static_cast<uint32_t>(value));
-}
-
-template<>
-SetType EnumUtil::FromString<SetType>(const char *value) {
-	return static_cast<SetType>(StringUtil::StringToEnum(GetSetTypeValues(), 2, "SetType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSettingScopeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SettingScope::GLOBAL), "GLOBAL" },
-		{ static_cast<uint32_t>(SettingScope::LOCAL), "LOCAL" },
-		{ static_cast<uint32_t>(SettingScope::SECRET), "SECRET" },
-		{ static_cast<uint32_t>(SettingScope::INVALID), "INVALID" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SettingScope>(SettingScope value) {
-	return StringUtil::EnumToString(GetSettingScopeValues(), 4, "SettingScope", static_cast<uint32_t>(value));
-}
-
-template<>
-SettingScope EnumUtil::FromString<SettingScope>(const char *value) {
-	return static_cast<SettingScope>(StringUtil::StringToEnum(GetSettingScopeValues(), 4, "SettingScope", value));
-}
-
-const StringUtil::EnumStringLiteral *GetShowTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(ShowType::SUMMARY), "SUMMARY" },
-		{ static_cast<uint32_t>(ShowType::DESCRIBE), "DESCRIBE" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<ShowType>(ShowType value) {
-	return StringUtil::EnumToString(GetShowTypeValues(), 2, "ShowType", static_cast<uint32_t>(value));
-}
-
-template<>
-ShowType EnumUtil::FromString<ShowType>(const char *value) {
-	return static_cast<ShowType>(StringUtil::StringToEnum(GetShowTypeValues(), 2, "ShowType", value));
-}
-
-const StringUtil::EnumStringLiteral *GetSimplifiedTokenTypeValues() {
-	static constexpr StringUtil::EnumStringLiteral values[] {
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_IDENTIFIER), "SIMPLIFIED_TOKEN_IDENTIFIER" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_NUMERIC_CONSTANT), "SIMPLIFIED_TOKEN_NUMERIC_CONSTANT" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_STRING_CONSTANT), "SIMPLIFIED_TOKEN_STRING_CONSTANT" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_OPERATOR), "SIMPLIFIED_TOKEN_OPERATOR" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_KEYWORD), "SIMPLIFIED_TOKEN_KEYWORD" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_COMMENT), "SIMPLIFIED_TOKEN_COMMENT" },
-		{ static_cast<uint32_t>(SimplifiedTokenType::SIMPLIFIED_TOKEN_ERROR), "SIMPLIFIED_TOKEN_ERROR" }
-	};
-	return values;
-}
-
-template<>
-const char* EnumUtil::ToChars<SimplifiedTokenType>(SimplifiedTokenType value) {
-	return StringUtil::EnumToString(GetSimplifiedTokenTypeValues(), 7, "SimplifiedTokenType", static_cast<uint32_t>(value));
-}
-
-template<>
-SimplifiedTokenType EnumUtil::FromString<SimplifiedTokenType>(const char *value) {
-	return static_cast<SimplifiedTokenType>(StringUtil::StringToEnum(GetSimplifiedTokenTypeValues(), 7, "SimplifiedTokenType", value));
-}
-
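Most tables map an enumerator to its own upper-case name, but a few — GetSampleMethodValues() above, with "System", "Bernoulli", "Reservoir" — store mixed-case display strings. Whether FromString then accepts other casings depends on how StringUtil::StringToEnum compares strings; the sketch earlier compares case-sensitively, so only the stored spelling would round-trip (the numeric values below are illustrative):

static constexpr EnumStringLiteral sample_method_values[] {
	{0, "System"}, {1, "Bernoulli"}, {2, "Reservoir"}, {3, "INVALID"},
};
// With a case-sensitive scan:
//   StringToEnum(sample_method_values, 4, "SampleMethod", "Reservoir") -> 2
//   StringToEnum(sample_method_values, 4, "SampleMethod", "RESERVOIR") -> throws
// A case-insensitive comparison would accept both spellings.
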
static_cast(SinkCombineResultType::BLOCKED), "BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SinkCombineResultType value) { - return StringUtil::EnumToString(GetSinkCombineResultTypeValues(), 2, "SinkCombineResultType", static_cast(value)); -} - -template<> -SinkCombineResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSinkCombineResultTypeValues(), 2, "SinkCombineResultType", value)); -} - -const StringUtil::EnumStringLiteral *GetSinkFinalizeTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(SinkFinalizeType::READY), "READY" }, - { static_cast(SinkFinalizeType::NO_OUTPUT_POSSIBLE), "NO_OUTPUT_POSSIBLE" }, - { static_cast(SinkFinalizeType::BLOCKED), "BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SinkFinalizeType value) { - return StringUtil::EnumToString(GetSinkFinalizeTypeValues(), 3, "SinkFinalizeType", static_cast(value)); -} - -template<> -SinkFinalizeType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSinkFinalizeTypeValues(), 3, "SinkFinalizeType", value)); -} - -const StringUtil::EnumStringLiteral *GetSinkNextBatchTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(SinkNextBatchType::READY), "READY" }, - { static_cast(SinkNextBatchType::BLOCKED), "BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SinkNextBatchType value) { - return StringUtil::EnumToString(GetSinkNextBatchTypeValues(), 2, "SinkNextBatchType", static_cast(value)); -} - -template<> -SinkNextBatchType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSinkNextBatchTypeValues(), 2, "SinkNextBatchType", value)); -} - -const StringUtil::EnumStringLiteral *GetSinkResultTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(SinkResultType::NEED_MORE_INPUT), "NEED_MORE_INPUT" }, - { static_cast(SinkResultType::FINISHED), "FINISHED" }, - { static_cast(SinkResultType::BLOCKED), "BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SinkResultType value) { - return StringUtil::EnumToString(GetSinkResultTypeValues(), 3, "SinkResultType", static_cast(value)); -} - -template<> -SinkResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSinkResultTypeValues(), 3, "SinkResultType", value)); -} - -const StringUtil::EnumStringLiteral *GetSourceResultTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(SourceResultType::HAVE_MORE_OUTPUT), "HAVE_MORE_OUTPUT" }, - { static_cast(SourceResultType::FINISHED), "FINISHED" }, - { static_cast(SourceResultType::BLOCKED), "BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SourceResultType value) { - return StringUtil::EnumToString(GetSourceResultTypeValues(), 3, "SourceResultType", static_cast(value)); -} - -template<> -SourceResultType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSourceResultTypeValues(), 3, "SourceResultType", value)); -} - -const StringUtil::EnumStringLiteral *GetStatementReturnTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StatementReturnType::QUERY_RESULT), "QUERY_RESULT" }, - { static_cast(StatementReturnType::CHANGED_ROWS), "CHANGED_ROWS" }, - { static_cast(StatementReturnType::NOTHING), "NOTHING" } - 
}; - return values; -} - -template<> -const char* EnumUtil::ToChars(StatementReturnType value) { - return StringUtil::EnumToString(GetStatementReturnTypeValues(), 3, "StatementReturnType", static_cast(value)); -} - -template<> -StatementReturnType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatementReturnTypeValues(), 3, "StatementReturnType", value)); -} - -const StringUtil::EnumStringLiteral *GetStatementTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StatementType::INVALID_STATEMENT), "INVALID_STATEMENT" }, - { static_cast(StatementType::SELECT_STATEMENT), "SELECT_STATEMENT" }, - { static_cast(StatementType::INSERT_STATEMENT), "INSERT_STATEMENT" }, - { static_cast(StatementType::UPDATE_STATEMENT), "UPDATE_STATEMENT" }, - { static_cast(StatementType::CREATE_STATEMENT), "CREATE_STATEMENT" }, - { static_cast(StatementType::DELETE_STATEMENT), "DELETE_STATEMENT" }, - { static_cast(StatementType::PREPARE_STATEMENT), "PREPARE_STATEMENT" }, - { static_cast(StatementType::EXECUTE_STATEMENT), "EXECUTE_STATEMENT" }, - { static_cast(StatementType::ALTER_STATEMENT), "ALTER_STATEMENT" }, - { static_cast(StatementType::TRANSACTION_STATEMENT), "TRANSACTION_STATEMENT" }, - { static_cast(StatementType::COPY_STATEMENT), "COPY_STATEMENT" }, - { static_cast(StatementType::ANALYZE_STATEMENT), "ANALYZE_STATEMENT" }, - { static_cast(StatementType::VARIABLE_SET_STATEMENT), "VARIABLE_SET_STATEMENT" }, - { static_cast(StatementType::CREATE_FUNC_STATEMENT), "CREATE_FUNC_STATEMENT" }, - { static_cast(StatementType::EXPLAIN_STATEMENT), "EXPLAIN_STATEMENT" }, - { static_cast(StatementType::DROP_STATEMENT), "DROP_STATEMENT" }, - { static_cast(StatementType::EXPORT_STATEMENT), "EXPORT_STATEMENT" }, - { static_cast(StatementType::PRAGMA_STATEMENT), "PRAGMA_STATEMENT" }, - { static_cast(StatementType::VACUUM_STATEMENT), "VACUUM_STATEMENT" }, - { static_cast(StatementType::CALL_STATEMENT), "CALL_STATEMENT" }, - { static_cast(StatementType::SET_STATEMENT), "SET_STATEMENT" }, - { static_cast(StatementType::LOAD_STATEMENT), "LOAD_STATEMENT" }, - { static_cast(StatementType::RELATION_STATEMENT), "RELATION_STATEMENT" }, - { static_cast(StatementType::EXTENSION_STATEMENT), "EXTENSION_STATEMENT" }, - { static_cast(StatementType::LOGICAL_PLAN_STATEMENT), "LOGICAL_PLAN_STATEMENT" }, - { static_cast(StatementType::ATTACH_STATEMENT), "ATTACH_STATEMENT" }, - { static_cast(StatementType::DETACH_STATEMENT), "DETACH_STATEMENT" }, - { static_cast(StatementType::MULTI_STATEMENT), "MULTI_STATEMENT" }, - { static_cast(StatementType::COPY_DATABASE_STATEMENT), "COPY_DATABASE_STATEMENT" }, - { static_cast(StatementType::UPDATE_EXTENSIONS_STATEMENT), "UPDATE_EXTENSIONS_STATEMENT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(StatementType value) { - return StringUtil::EnumToString(GetStatementTypeValues(), 30, "StatementType", static_cast(value)); -} - -template<> -StatementType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatementTypeValues(), 30, "StatementType", value)); -} - -const StringUtil::EnumStringLiteral *GetStatisticsTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StatisticsType::NUMERIC_STATS), "NUMERIC_STATS" }, - { static_cast(StatisticsType::STRING_STATS), "STRING_STATS" }, - { static_cast(StatisticsType::LIST_STATS), "LIST_STATS" }, - { static_cast(StatisticsType::STRUCT_STATS), "STRUCT_STATS" }, - { 
static_cast(StatisticsType::BASE_STATS), "BASE_STATS" }, - { static_cast(StatisticsType::ARRAY_STATS), "ARRAY_STATS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(StatisticsType value) { - return StringUtil::EnumToString(GetStatisticsTypeValues(), 6, "StatisticsType", static_cast(value)); -} - -template<> -StatisticsType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatisticsTypeValues(), 6, "StatisticsType", value)); -} - -const StringUtil::EnumStringLiteral *GetStatsInfoValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StatsInfo::CAN_HAVE_NULL_VALUES), "CAN_HAVE_NULL_VALUES" }, - { static_cast(StatsInfo::CANNOT_HAVE_NULL_VALUES), "CANNOT_HAVE_NULL_VALUES" }, - { static_cast(StatsInfo::CAN_HAVE_VALID_VALUES), "CAN_HAVE_VALID_VALUES" }, - { static_cast(StatsInfo::CANNOT_HAVE_VALID_VALUES), "CANNOT_HAVE_VALID_VALUES" }, - { static_cast(StatsInfo::CAN_HAVE_NULL_AND_VALID_VALUES), "CAN_HAVE_NULL_AND_VALID_VALUES" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(StatsInfo value) { - return StringUtil::EnumToString(GetStatsInfoValues(), 5, "StatsInfo", static_cast(value)); -} - -template<> -StatsInfo EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStatsInfoValues(), 5, "StatsInfo", value)); -} - -const StringUtil::EnumStringLiteral *GetStrTimeSpecifierValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME), "ABBREVIATED_WEEKDAY_NAME" }, - { static_cast(StrTimeSpecifier::FULL_WEEKDAY_NAME), "FULL_WEEKDAY_NAME" }, - { static_cast(StrTimeSpecifier::WEEKDAY_DECIMAL), "WEEKDAY_DECIMAL" }, - { static_cast(StrTimeSpecifier::DAY_OF_MONTH_PADDED), "DAY_OF_MONTH_PADDED" }, - { static_cast(StrTimeSpecifier::DAY_OF_MONTH), "DAY_OF_MONTH" }, - { static_cast(StrTimeSpecifier::ABBREVIATED_MONTH_NAME), "ABBREVIATED_MONTH_NAME" }, - { static_cast(StrTimeSpecifier::FULL_MONTH_NAME), "FULL_MONTH_NAME" }, - { static_cast(StrTimeSpecifier::MONTH_DECIMAL_PADDED), "MONTH_DECIMAL_PADDED" }, - { static_cast(StrTimeSpecifier::MONTH_DECIMAL), "MONTH_DECIMAL" }, - { static_cast(StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED), "YEAR_WITHOUT_CENTURY_PADDED" }, - { static_cast(StrTimeSpecifier::YEAR_WITHOUT_CENTURY), "YEAR_WITHOUT_CENTURY" }, - { static_cast(StrTimeSpecifier::YEAR_DECIMAL), "YEAR_DECIMAL" }, - { static_cast(StrTimeSpecifier::HOUR_24_PADDED), "HOUR_24_PADDED" }, - { static_cast(StrTimeSpecifier::HOUR_24_DECIMAL), "HOUR_24_DECIMAL" }, - { static_cast(StrTimeSpecifier::HOUR_12_PADDED), "HOUR_12_PADDED" }, - { static_cast(StrTimeSpecifier::HOUR_12_DECIMAL), "HOUR_12_DECIMAL" }, - { static_cast(StrTimeSpecifier::AM_PM), "AM_PM" }, - { static_cast(StrTimeSpecifier::MINUTE_PADDED), "MINUTE_PADDED" }, - { static_cast(StrTimeSpecifier::MINUTE_DECIMAL), "MINUTE_DECIMAL" }, - { static_cast(StrTimeSpecifier::SECOND_PADDED), "SECOND_PADDED" }, - { static_cast(StrTimeSpecifier::SECOND_DECIMAL), "SECOND_DECIMAL" }, - { static_cast(StrTimeSpecifier::MICROSECOND_PADDED), "MICROSECOND_PADDED" }, - { static_cast(StrTimeSpecifier::MILLISECOND_PADDED), "MILLISECOND_PADDED" }, - { static_cast(StrTimeSpecifier::UTC_OFFSET), "UTC_OFFSET" }, - { static_cast(StrTimeSpecifier::TZ_NAME), "TZ_NAME" }, - { static_cast(StrTimeSpecifier::DAY_OF_YEAR_PADDED), "DAY_OF_YEAR_PADDED" }, - { static_cast(StrTimeSpecifier::DAY_OF_YEAR_DECIMAL), "DAY_OF_YEAR_DECIMAL" }, - { 
static_cast(StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST), "WEEK_NUMBER_PADDED_SUN_FIRST" }, - { static_cast(StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST), "WEEK_NUMBER_PADDED_MON_FIRST" }, - { static_cast(StrTimeSpecifier::LOCALE_APPROPRIATE_DATE_AND_TIME), "LOCALE_APPROPRIATE_DATE_AND_TIME" }, - { static_cast(StrTimeSpecifier::LOCALE_APPROPRIATE_DATE), "LOCALE_APPROPRIATE_DATE" }, - { static_cast(StrTimeSpecifier::LOCALE_APPROPRIATE_TIME), "LOCALE_APPROPRIATE_TIME" }, - { static_cast(StrTimeSpecifier::NANOSECOND_PADDED), "NANOSECOND_PADDED" }, - { static_cast(StrTimeSpecifier::YEAR_ISO), "YEAR_ISO" }, - { static_cast(StrTimeSpecifier::WEEKDAY_ISO), "WEEKDAY_ISO" }, - { static_cast(StrTimeSpecifier::WEEK_NUMBER_ISO), "WEEK_NUMBER_ISO" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(StrTimeSpecifier value) { - return StringUtil::EnumToString(GetStrTimeSpecifierValues(), 36, "StrTimeSpecifier", static_cast(value)); -} - -template<> -StrTimeSpecifier EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStrTimeSpecifierValues(), 36, "StrTimeSpecifier", value)); -} - -const StringUtil::EnumStringLiteral *GetStreamExecutionResultValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(StreamExecutionResult::CHUNK_READY), "CHUNK_READY" }, - { static_cast(StreamExecutionResult::CHUNK_NOT_READY), "CHUNK_NOT_READY" }, - { static_cast(StreamExecutionResult::EXECUTION_ERROR), "EXECUTION_ERROR" }, - { static_cast(StreamExecutionResult::EXECUTION_CANCELLED), "EXECUTION_CANCELLED" }, - { static_cast(StreamExecutionResult::BLOCKED), "BLOCKED" }, - { static_cast(StreamExecutionResult::NO_TASKS_AVAILABLE), "NO_TASKS_AVAILABLE" }, - { static_cast(StreamExecutionResult::EXECUTION_FINISHED), "EXECUTION_FINISHED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(StreamExecutionResult value) { - return StringUtil::EnumToString(GetStreamExecutionResultValues(), 7, "StreamExecutionResult", static_cast(value)); -} - -template<> -StreamExecutionResult EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetStreamExecutionResultValues(), 7, "StreamExecutionResult", value)); -} - -const StringUtil::EnumStringLiteral *GetSubqueryTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(SubqueryType::INVALID), "INVALID" }, - { static_cast(SubqueryType::SCALAR), "SCALAR" }, - { static_cast(SubqueryType::EXISTS), "EXISTS" }, - { static_cast(SubqueryType::NOT_EXISTS), "NOT_EXISTS" }, - { static_cast(SubqueryType::ANY), "ANY" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(SubqueryType value) { - return StringUtil::EnumToString(GetSubqueryTypeValues(), 5, "SubqueryType", static_cast(value)); -} - -template<> -SubqueryType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetSubqueryTypeValues(), 5, "SubqueryType", value)); -} - -const StringUtil::EnumStringLiteral *GetTableColumnTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TableColumnType::STANDARD), "STANDARD" }, - { static_cast(TableColumnType::GENERATED), "GENERATED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TableColumnType value) { - return StringUtil::EnumToString(GetTableColumnTypeValues(), 2, "TableColumnType", static_cast(value)); -} - -template<> -TableColumnType EnumUtil::FromString(const char *value) { - return 
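Every ToChars/FromString pair above repeats the table length as a literal (30 for StatementType, 36 for StrTimeSpecifier, and so on), so the count must be kept in sync with the array whenever an enum gains a member — trivial for a code generator, error-prone by hand. A compile-time guard is one way to make that invariant explicit (illustrative only; the generated file does not do this):

#include <iterator>

static constexpr EnumStringLiteral statement_return_values[] {
	{0, "QUERY_RESULT"}, {1, "CHANGED_ROWS"}, {2, "NOTHING"},
};
// Tie the literal count to the array length at compile time.
static_assert(std::size(statement_return_values) == 3,
              "StatementReturnType table size changed; update the count");
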
static_cast(StringUtil::StringToEnum(GetTableColumnTypeValues(), 2, "TableColumnType", value)); -} - -const StringUtil::EnumStringLiteral *GetTableFilterTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TableFilterType::CONSTANT_COMPARISON), "CONSTANT_COMPARISON" }, - { static_cast(TableFilterType::IS_NULL), "IS_NULL" }, - { static_cast(TableFilterType::IS_NOT_NULL), "IS_NOT_NULL" }, - { static_cast(TableFilterType::CONJUNCTION_OR), "CONJUNCTION_OR" }, - { static_cast(TableFilterType::CONJUNCTION_AND), "CONJUNCTION_AND" }, - { static_cast(TableFilterType::STRUCT_EXTRACT), "STRUCT_EXTRACT" }, - { static_cast(TableFilterType::OPTIONAL_FILTER), "OPTIONAL_FILTER" }, - { static_cast(TableFilterType::IN_FILTER), "IN_FILTER" }, - { static_cast(TableFilterType::DYNAMIC_FILTER), "DYNAMIC_FILTER" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TableFilterType value) { - return StringUtil::EnumToString(GetTableFilterTypeValues(), 9, "TableFilterType", static_cast(value)); -} - -template<> -TableFilterType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTableFilterTypeValues(), 9, "TableFilterType", value)); -} - -const StringUtil::EnumStringLiteral *GetTablePartitionInfoValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TablePartitionInfo::NOT_PARTITIONED), "NOT_PARTITIONED" }, - { static_cast(TablePartitionInfo::SINGLE_VALUE_PARTITIONS), "SINGLE_VALUE_PARTITIONS" }, - { static_cast(TablePartitionInfo::OVERLAPPING_PARTITIONS), "OVERLAPPING_PARTITIONS" }, - { static_cast(TablePartitionInfo::DISJOINT_PARTITIONS), "DISJOINT_PARTITIONS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TablePartitionInfo value) { - return StringUtil::EnumToString(GetTablePartitionInfoValues(), 4, "TablePartitionInfo", static_cast(value)); -} - -template<> -TablePartitionInfo EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTablePartitionInfoValues(), 4, "TablePartitionInfo", value)); -} - -const StringUtil::EnumStringLiteral *GetTableReferenceTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TableReferenceType::INVALID), "INVALID" }, - { static_cast(TableReferenceType::BASE_TABLE), "BASE_TABLE" }, - { static_cast(TableReferenceType::SUBQUERY), "SUBQUERY" }, - { static_cast(TableReferenceType::JOIN), "JOIN" }, - { static_cast(TableReferenceType::TABLE_FUNCTION), "TABLE_FUNCTION" }, - { static_cast(TableReferenceType::EXPRESSION_LIST), "EXPRESSION_LIST" }, - { static_cast(TableReferenceType::CTE), "CTE" }, - { static_cast(TableReferenceType::EMPTY_FROM), "EMPTY" }, - { static_cast(TableReferenceType::PIVOT), "PIVOT" }, - { static_cast(TableReferenceType::SHOW_REF), "SHOW_REF" }, - { static_cast(TableReferenceType::COLUMN_DATA), "COLUMN_DATA" }, - { static_cast(TableReferenceType::DELIM_GET), "DELIM_GET" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TableReferenceType value) { - return StringUtil::EnumToString(GetTableReferenceTypeValues(), 12, "TableReferenceType", static_cast(value)); -} - -template<> -TableReferenceType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTableReferenceTypeValues(), 12, "TableReferenceType", value)); -} - -const StringUtil::EnumStringLiteral *GetTableScanTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TableScanType::TABLE_SCAN_REGULAR), 
"TABLE_SCAN_REGULAR" }, - { static_cast(TableScanType::TABLE_SCAN_COMMITTED_ROWS), "TABLE_SCAN_COMMITTED_ROWS" }, - { static_cast(TableScanType::TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES), "TABLE_SCAN_COMMITTED_ROWS_DISALLOW_UPDATES" }, - { static_cast(TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED), "TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED" }, - { static_cast(TableScanType::TABLE_SCAN_LATEST_COMMITTED_ROWS), "TABLE_SCAN_LATEST_COMMITTED_ROWS" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TableScanType value) { - return StringUtil::EnumToString(GetTableScanTypeValues(), 5, "TableScanType", static_cast(value)); -} - -template<> -TableScanType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTableScanTypeValues(), 5, "TableScanType", value)); -} - -const StringUtil::EnumStringLiteral *GetTaskExecutionModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TaskExecutionMode::PROCESS_ALL), "PROCESS_ALL" }, - { static_cast(TaskExecutionMode::PROCESS_PARTIAL), "PROCESS_PARTIAL" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TaskExecutionMode value) { - return StringUtil::EnumToString(GetTaskExecutionModeValues(), 2, "TaskExecutionMode", static_cast(value)); -} - -template<> -TaskExecutionMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTaskExecutionModeValues(), 2, "TaskExecutionMode", value)); -} - -const StringUtil::EnumStringLiteral *GetTaskExecutionResultValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TaskExecutionResult::TASK_FINISHED), "TASK_FINISHED" }, - { static_cast(TaskExecutionResult::TASK_NOT_FINISHED), "TASK_NOT_FINISHED" }, - { static_cast(TaskExecutionResult::TASK_ERROR), "TASK_ERROR" }, - { static_cast(TaskExecutionResult::TASK_BLOCKED), "TASK_BLOCKED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TaskExecutionResult value) { - return StringUtil::EnumToString(GetTaskExecutionResultValues(), 4, "TaskExecutionResult", static_cast(value)); -} - -template<> -TaskExecutionResult EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTaskExecutionResultValues(), 4, "TaskExecutionResult", value)); -} - -const StringUtil::EnumStringLiteral *GetTemporaryBufferSizeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TemporaryBufferSize::INVALID), "INVALID" }, - { static_cast(TemporaryBufferSize::S32K), "S32K" }, - { static_cast(TemporaryBufferSize::S64K), "S64K" }, - { static_cast(TemporaryBufferSize::S96K), "S96K" }, - { static_cast(TemporaryBufferSize::S128K), "S128K" }, - { static_cast(TemporaryBufferSize::S160K), "S160K" }, - { static_cast(TemporaryBufferSize::S192K), "S192K" }, - { static_cast(TemporaryBufferSize::S224K), "S224K" }, - { static_cast(TemporaryBufferSize::DEFAULT), "DEFAULT" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TemporaryBufferSize value) { - return StringUtil::EnumToString(GetTemporaryBufferSizeValues(), 9, "TemporaryBufferSize", static_cast(value)); -} - -template<> -TemporaryBufferSize EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTemporaryBufferSizeValues(), 9, "TemporaryBufferSize", value)); -} - -const StringUtil::EnumStringLiteral *GetTemporaryCompressionLevelValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { 
static_cast(TemporaryCompressionLevel::ZSTD_MINUS_FIVE), "ZSTD_MINUS_FIVE" }, - { static_cast(TemporaryCompressionLevel::ZSTD_MINUS_THREE), "ZSTD_MINUS_THREE" }, - { static_cast(TemporaryCompressionLevel::ZSTD_MINUS_ONE), "ZSTD_MINUS_ONE" }, - { static_cast(TemporaryCompressionLevel::UNCOMPRESSED), "UNCOMPRESSED" }, - { static_cast(TemporaryCompressionLevel::ZSTD_ONE), "ZSTD_ONE" }, - { static_cast(TemporaryCompressionLevel::ZSTD_THREE), "ZSTD_THREE" }, - { static_cast(TemporaryCompressionLevel::ZSTD_FIVE), "ZSTD_FIVE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TemporaryCompressionLevel value) { - return StringUtil::EnumToString(GetTemporaryCompressionLevelValues(), 7, "TemporaryCompressionLevel", static_cast(value)); -} - -template<> -TemporaryCompressionLevel EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTemporaryCompressionLevelValues(), 7, "TemporaryCompressionLevel", value)); -} - -const StringUtil::EnumStringLiteral *GetTimestampCastResultValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TimestampCastResult::SUCCESS), "SUCCESS" }, - { static_cast(TimestampCastResult::ERROR_INCORRECT_FORMAT), "ERROR_INCORRECT_FORMAT" }, - { static_cast(TimestampCastResult::ERROR_NON_UTC_TIMEZONE), "ERROR_NON_UTC_TIMEZONE" }, - { static_cast(TimestampCastResult::ERROR_RANGE), "ERROR_RANGE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TimestampCastResult value) { - return StringUtil::EnumToString(GetTimestampCastResultValues(), 4, "TimestampCastResult", static_cast(value)); -} - -template<> -TimestampCastResult EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTimestampCastResultValues(), 4, "TimestampCastResult", value)); -} - -const StringUtil::EnumStringLiteral *GetTransactionModifierTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TransactionModifierType::TRANSACTION_DEFAULT_MODIFIER), "TRANSACTION_DEFAULT_MODIFIER" }, - { static_cast(TransactionModifierType::TRANSACTION_READ_ONLY), "TRANSACTION_READ_ONLY" }, - { static_cast(TransactionModifierType::TRANSACTION_READ_WRITE), "TRANSACTION_READ_WRITE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TransactionModifierType value) { - return StringUtil::EnumToString(GetTransactionModifierTypeValues(), 3, "TransactionModifierType", static_cast(value)); -} - -template<> -TransactionModifierType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTransactionModifierTypeValues(), 3, "TransactionModifierType", value)); -} - -const StringUtil::EnumStringLiteral *GetTransactionTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(TransactionType::INVALID), "INVALID" }, - { static_cast(TransactionType::BEGIN_TRANSACTION), "BEGIN_TRANSACTION" }, - { static_cast(TransactionType::COMMIT), "COMMIT" }, - { static_cast(TransactionType::ROLLBACK), "ROLLBACK" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TransactionType value) { - return StringUtil::EnumToString(GetTransactionTypeValues(), 4, "TransactionType", static_cast(value)); -} - -template<> -TransactionType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTransactionTypeValues(), 4, "TransactionType", value)); -} - -const StringUtil::EnumStringLiteral *GetTupleDataPinPropertiesValues() { - static constexpr 
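TimestampCastResult, just above, is unusual among these enums in that it is an error taxonomy rather than a serialized tag: it separates "could not parse" from "parsed but out of range" and from "carries a non-UTC offset", so a caller can pick a targeted message without re-parsing. A hedged usage sketch, assuming Timestamp::TryConvertTimestamp returns this enum:

// Sketch only: TryConvertTimestamp returning TimestampCastResult is an assumption.
static string DescribeTimestampCast(const char *str, idx_t len) {
	timestamp_t result;
	switch (Timestamp::TryConvertTimestamp(str, len, result)) {
	case TimestampCastResult::SUCCESS:
		return "ok";
	case TimestampCastResult::ERROR_INCORRECT_FORMAT:
		return "value is not in a recognized timestamp format";
	case TimestampCastResult::ERROR_NON_UTC_TIMEZONE:
		return "value carries a non-UTC offset; use TIMESTAMPTZ instead";
	case TimestampCastResult::ERROR_RANGE:
		return "value is outside the representable timestamp range";
	}
	return "unreachable";
}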
StringUtil::EnumStringLiteral values[] { - { static_cast(TupleDataPinProperties::INVALID), "INVALID" }, - { static_cast(TupleDataPinProperties::KEEP_EVERYTHING_PINNED), "KEEP_EVERYTHING_PINNED" }, - { static_cast(TupleDataPinProperties::UNPIN_AFTER_DONE), "UNPIN_AFTER_DONE" }, - { static_cast(TupleDataPinProperties::DESTROY_AFTER_DONE), "DESTROY_AFTER_DONE" }, - { static_cast(TupleDataPinProperties::ALREADY_PINNED), "ALREADY_PINNED" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(TupleDataPinProperties value) { - return StringUtil::EnumToString(GetTupleDataPinPropertiesValues(), 5, "TupleDataPinProperties", static_cast(value)); -} - -template<> -TupleDataPinProperties EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetTupleDataPinPropertiesValues(), 5, "TupleDataPinProperties", value)); -} - -const StringUtil::EnumStringLiteral *GetUndoFlagsValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(UndoFlags::EMPTY_ENTRY), "EMPTY_ENTRY" }, - { static_cast(UndoFlags::CATALOG_ENTRY), "CATALOG_ENTRY" }, - { static_cast(UndoFlags::INSERT_TUPLE), "INSERT_TUPLE" }, - { static_cast(UndoFlags::DELETE_TUPLE), "DELETE_TUPLE" }, - { static_cast(UndoFlags::UPDATE_TUPLE), "UPDATE_TUPLE" }, - { static_cast(UndoFlags::SEQUENCE_VALUE), "SEQUENCE_VALUE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(UndoFlags value) { - return StringUtil::EnumToString(GetUndoFlagsValues(), 6, "UndoFlags", static_cast(value)); -} - -template<> -UndoFlags EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetUndoFlagsValues(), 6, "UndoFlags", value)); -} - -const StringUtil::EnumStringLiteral *GetUnionInvalidReasonValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(UnionInvalidReason::VALID), "VALID" }, - { static_cast(UnionInvalidReason::TAG_OUT_OF_RANGE), "TAG_OUT_OF_RANGE" }, - { static_cast(UnionInvalidReason::NO_MEMBERS), "NO_MEMBERS" }, - { static_cast(UnionInvalidReason::VALIDITY_OVERLAP), "VALIDITY_OVERLAP" }, - { static_cast(UnionInvalidReason::TAG_MISMATCH), "TAG_MISMATCH" }, - { static_cast(UnionInvalidReason::NULL_TAG), "NULL_TAG" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(UnionInvalidReason value) { - return StringUtil::EnumToString(GetUnionInvalidReasonValues(), 6, "UnionInvalidReason", static_cast(value)); -} - -template<> -UnionInvalidReason EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetUnionInvalidReasonValues(), 6, "UnionInvalidReason", value)); -} - -const StringUtil::EnumStringLiteral *GetVectorAuxiliaryDataTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VectorAuxiliaryDataType::ARROW_AUXILIARY), "ARROW_AUXILIARY" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VectorAuxiliaryDataType value) { - return StringUtil::EnumToString(GetVectorAuxiliaryDataTypeValues(), 1, "VectorAuxiliaryDataType", static_cast(value)); -} - -template<> -VectorAuxiliaryDataType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVectorAuxiliaryDataTypeValues(), 1, "VectorAuxiliaryDataType", value)); -} - -const StringUtil::EnumStringLiteral *GetVectorBufferTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VectorBufferType::STANDARD_BUFFER), "STANDARD_BUFFER" }, - { static_cast(VectorBufferType::DICTIONARY_BUFFER), 
"DICTIONARY_BUFFER" }, - { static_cast(VectorBufferType::VECTOR_CHILD_BUFFER), "VECTOR_CHILD_BUFFER" }, - { static_cast(VectorBufferType::STRING_BUFFER), "STRING_BUFFER" }, - { static_cast(VectorBufferType::FSST_BUFFER), "FSST_BUFFER" }, - { static_cast(VectorBufferType::STRUCT_BUFFER), "STRUCT_BUFFER" }, - { static_cast(VectorBufferType::LIST_BUFFER), "LIST_BUFFER" }, - { static_cast(VectorBufferType::MANAGED_BUFFER), "MANAGED_BUFFER" }, - { static_cast(VectorBufferType::OPAQUE_BUFFER), "OPAQUE_BUFFER" }, - { static_cast(VectorBufferType::ARRAY_BUFFER), "ARRAY_BUFFER" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VectorBufferType value) { - return StringUtil::EnumToString(GetVectorBufferTypeValues(), 10, "VectorBufferType", static_cast(value)); -} - -template<> -VectorBufferType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVectorBufferTypeValues(), 10, "VectorBufferType", value)); -} - -const StringUtil::EnumStringLiteral *GetVectorTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VectorType::FLAT_VECTOR), "FLAT_VECTOR" }, - { static_cast(VectorType::FSST_VECTOR), "FSST_VECTOR" }, - { static_cast(VectorType::CONSTANT_VECTOR), "CONSTANT_VECTOR" }, - { static_cast(VectorType::DICTIONARY_VECTOR), "DICTIONARY_VECTOR" }, - { static_cast(VectorType::SEQUENCE_VECTOR), "SEQUENCE_VECTOR" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VectorType value) { - return StringUtil::EnumToString(GetVectorTypeValues(), 5, "VectorType", static_cast(value)); -} - -template<> -VectorType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVectorTypeValues(), 5, "VectorType", value)); -} - -const StringUtil::EnumStringLiteral *GetVerificationTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VerificationType::ORIGINAL), "ORIGINAL" }, - { static_cast(VerificationType::COPIED), "COPIED" }, - { static_cast(VerificationType::DESERIALIZED), "DESERIALIZED" }, - { static_cast(VerificationType::PARSED), "PARSED" }, - { static_cast(VerificationType::UNOPTIMIZED), "UNOPTIMIZED" }, - { static_cast(VerificationType::NO_OPERATOR_CACHING), "NO_OPERATOR_CACHING" }, - { static_cast(VerificationType::PREPARED), "PREPARED" }, - { static_cast(VerificationType::EXTERNAL), "EXTERNAL" }, - { static_cast(VerificationType::FETCH_ROW_AS_SCAN), "FETCH_ROW_AS_SCAN" }, - { static_cast(VerificationType::INVALID), "INVALID" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VerificationType value) { - return StringUtil::EnumToString(GetVerificationTypeValues(), 10, "VerificationType", static_cast(value)); -} - -template<> -VerificationType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetVerificationTypeValues(), 10, "VerificationType", value)); -} - -const StringUtil::EnumStringLiteral *GetVerifyExistenceTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(VerifyExistenceType::APPEND), "APPEND" }, - { static_cast(VerifyExistenceType::APPEND_FK), "APPEND_FK" }, - { static_cast(VerifyExistenceType::DELETE_FK), "DELETE_FK" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(VerifyExistenceType value) { - return StringUtil::EnumToString(GetVerifyExistenceTypeValues(), 3, "VerifyExistenceType", static_cast(value)); -} - -template<> -VerifyExistenceType EnumUtil::FromString(const char *value) { - return 
static_cast(StringUtil::StringToEnum(GetVerifyExistenceTypeValues(), 3, "VerifyExistenceType", value)); -} - -const StringUtil::EnumStringLiteral *GetWALTypeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(WALType::INVALID), "INVALID" }, - { static_cast(WALType::CREATE_TABLE), "CREATE_TABLE" }, - { static_cast(WALType::DROP_TABLE), "DROP_TABLE" }, - { static_cast(WALType::CREATE_SCHEMA), "CREATE_SCHEMA" }, - { static_cast(WALType::DROP_SCHEMA), "DROP_SCHEMA" }, - { static_cast(WALType::CREATE_VIEW), "CREATE_VIEW" }, - { static_cast(WALType::DROP_VIEW), "DROP_VIEW" }, - { static_cast(WALType::CREATE_SEQUENCE), "CREATE_SEQUENCE" }, - { static_cast(WALType::DROP_SEQUENCE), "DROP_SEQUENCE" }, - { static_cast(WALType::SEQUENCE_VALUE), "SEQUENCE_VALUE" }, - { static_cast(WALType::CREATE_MACRO), "CREATE_MACRO" }, - { static_cast(WALType::DROP_MACRO), "DROP_MACRO" }, - { static_cast(WALType::CREATE_TYPE), "CREATE_TYPE" }, - { static_cast(WALType::DROP_TYPE), "DROP_TYPE" }, - { static_cast(WALType::ALTER_INFO), "ALTER_INFO" }, - { static_cast(WALType::CREATE_TABLE_MACRO), "CREATE_TABLE_MACRO" }, - { static_cast(WALType::DROP_TABLE_MACRO), "DROP_TABLE_MACRO" }, - { static_cast(WALType::CREATE_INDEX), "CREATE_INDEX" }, - { static_cast(WALType::DROP_INDEX), "DROP_INDEX" }, - { static_cast(WALType::USE_TABLE), "USE_TABLE" }, - { static_cast(WALType::INSERT_TUPLE), "INSERT_TUPLE" }, - { static_cast(WALType::DELETE_TUPLE), "DELETE_TUPLE" }, - { static_cast(WALType::UPDATE_TUPLE), "UPDATE_TUPLE" }, - { static_cast(WALType::ROW_GROUP_DATA), "ROW_GROUP_DATA" }, - { static_cast(WALType::WAL_VERSION), "WAL_VERSION" }, - { static_cast(WALType::CHECKPOINT), "CHECKPOINT" }, - { static_cast(WALType::WAL_FLUSH), "WAL_FLUSH" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(WALType value) { - return StringUtil::EnumToString(GetWALTypeValues(), 27, "WALType", static_cast(value)); -} - -template<> -WALType EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetWALTypeValues(), 27, "WALType", value)); -} - -const StringUtil::EnumStringLiteral *GetWindowAggregationModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(WindowAggregationMode::WINDOW), "WINDOW" }, - { static_cast(WindowAggregationMode::COMBINE), "COMBINE" }, - { static_cast(WindowAggregationMode::SEPARATE), "SEPARATE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(WindowAggregationMode value) { - return StringUtil::EnumToString(GetWindowAggregationModeValues(), 3, "WindowAggregationMode", static_cast(value)); -} - -template<> -WindowAggregationMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetWindowAggregationModeValues(), 3, "WindowAggregationMode", value)); -} - -const StringUtil::EnumStringLiteral *GetWindowBoundaryValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(WindowBoundary::INVALID), "INVALID" }, - { static_cast(WindowBoundary::UNBOUNDED_PRECEDING), "UNBOUNDED_PRECEDING" }, - { static_cast(WindowBoundary::UNBOUNDED_FOLLOWING), "UNBOUNDED_FOLLOWING" }, - { static_cast(WindowBoundary::CURRENT_ROW_RANGE), "CURRENT_ROW_RANGE" }, - { static_cast(WindowBoundary::CURRENT_ROW_ROWS), "CURRENT_ROW_ROWS" }, - { static_cast(WindowBoundary::EXPR_PRECEDING_ROWS), "EXPR_PRECEDING_ROWS" }, - { static_cast(WindowBoundary::EXPR_FOLLOWING_ROWS), "EXPR_FOLLOWING_ROWS" }, - { static_cast(WindowBoundary::EXPR_PRECEDING_RANGE), 
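WindowBoundary folds the frame unit (ROWS versus RANGE) into the bound itself, so each SQL frame clause binds to a (start, end) pair of these values; the table resumes below with the two RANGE expression bounds. The correspondence here is inferred from the value names, not taken from the binder:

// SQL frame clause                                   -> (start, end) WindowBoundary pair (inferred)
// ROWS BETWEEN 1 PRECEDING AND CURRENT ROW           -> EXPR_PRECEDING_ROWS, CURRENT_ROW_ROWS
// RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW  -> UNBOUNDED_PRECEDING, CURRENT_ROW_RANGE
// ROWS BETWEEN CURRENT ROW AND 3 FOLLOWING           -> CURRENT_ROW_ROWS, EXPR_FOLLOWING_ROWS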
"EXPR_PRECEDING_RANGE" }, - { static_cast(WindowBoundary::EXPR_FOLLOWING_RANGE), "EXPR_FOLLOWING_RANGE" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(WindowBoundary value) { - return StringUtil::EnumToString(GetWindowBoundaryValues(), 9, "WindowBoundary", static_cast(value)); -} - -template<> -WindowBoundary EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetWindowBoundaryValues(), 9, "WindowBoundary", value)); -} - -const StringUtil::EnumStringLiteral *GetWindowExcludeModeValues() { - static constexpr StringUtil::EnumStringLiteral values[] { - { static_cast(WindowExcludeMode::NO_OTHER), "NO_OTHER" }, - { static_cast(WindowExcludeMode::CURRENT_ROW), "CURRENT_ROW" }, - { static_cast(WindowExcludeMode::GROUP), "GROUP" }, - { static_cast(WindowExcludeMode::TIES), "TIES" } - }; - return values; -} - -template<> -const char* EnumUtil::ToChars(WindowExcludeMode value) { - return StringUtil::EnumToString(GetWindowExcludeModeValues(), 4, "WindowExcludeMode", static_cast(value)); -} - -template<> -WindowExcludeMode EnumUtil::FromString(const char *value) { - return static_cast(StringUtil::StringToEnum(GetWindowExcludeModeValues(), 4, "WindowExcludeMode", value)); -} - -} - diff --git a/src/duckdb/src/common/enums/catalog_type.cpp b/src/duckdb/src/common/enums/catalog_type.cpp deleted file mode 100644 index 55dbcb41a..000000000 --- a/src/duckdb/src/common/enums/catalog_type.cpp +++ /dev/null @@ -1,114 +0,0 @@ -#include "duckdb/common/enums/catalog_type.hpp" - -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -// LCOV_EXCL_START -string CatalogTypeToString(CatalogType type) { - switch (type) { - case CatalogType::COLLATION_ENTRY: - return "Collation"; - case CatalogType::TYPE_ENTRY: - return "Type"; - case CatalogType::TABLE_ENTRY: - return "Table"; - case CatalogType::SCHEMA_ENTRY: - return "Schema"; - case CatalogType::DATABASE_ENTRY: - return "Database"; - case CatalogType::TABLE_FUNCTION_ENTRY: - return "Table Function"; - case CatalogType::SCALAR_FUNCTION_ENTRY: - return "Scalar Function"; - case CatalogType::AGGREGATE_FUNCTION_ENTRY: - return "Aggregate Function"; - case CatalogType::COPY_FUNCTION_ENTRY: - return "Copy Function"; - case CatalogType::PRAGMA_FUNCTION_ENTRY: - return "Pragma Function"; - case CatalogType::MACRO_ENTRY: - return "Macro Function"; - case CatalogType::TABLE_MACRO_ENTRY: - return "Table Macro Function"; - case CatalogType::VIEW_ENTRY: - return "View"; - case CatalogType::INDEX_ENTRY: - return "Index"; - case CatalogType::PREPARED_STATEMENT: - return "Prepared Statement"; - case CatalogType::SEQUENCE_ENTRY: - return "Sequence"; - case CatalogType::SECRET_ENTRY: - return "Secret"; - case CatalogType::SECRET_TYPE_ENTRY: - return "Secret Type"; - case CatalogType::SECRET_FUNCTION_ENTRY: - return "Secret Function"; - case CatalogType::INVALID: - case CatalogType::DELETED_ENTRY: - case CatalogType::RENAMED_ENTRY: - case CatalogType::DEPENDENCY_ENTRY: - break; - } - return "INVALID"; -} - -CatalogType CatalogTypeFromString(const string &type) { - if (type == "Collation") { - return CatalogType::COLLATION_ENTRY; - } - if (type == "Type") { - return CatalogType::TYPE_ENTRY; - } - if (type == "Table") { - return CatalogType::TABLE_ENTRY; - } - if (type == "Schema") { - return CatalogType::SCHEMA_ENTRY; - } - if (type == "Database") { - return CatalogType::DATABASE_ENTRY; - } - if (type == "Table Function") { - return CatalogType::TABLE_FUNCTION_ENTRY; - } - if (type == "Scalar Function") { - return 
CatalogType::SCALAR_FUNCTION_ENTRY; - } - if (type == "Aggregate Function") { - return CatalogType::AGGREGATE_FUNCTION_ENTRY; - } - if (type == "Copy Function") { - return CatalogType::COPY_FUNCTION_ENTRY; - } - if (type == "Pragma Function") { - return CatalogType::PRAGMA_FUNCTION_ENTRY; - } - if (type == "Macro Function") { - return CatalogType::MACRO_ENTRY; - } - if (type == "Table Macro Function") { - return CatalogType::TABLE_MACRO_ENTRY; - } - if (type == "View") { - return CatalogType::VIEW_ENTRY; - } - if (type == "Index") { - return CatalogType::INDEX_ENTRY; - } - if (type == "Prepared Statement") { - return CatalogType::PREPARED_STATEMENT; - } - if (type == "Sequence") { - return CatalogType::SEQUENCE_ENTRY; - } - if (type == "INVALID") { - return CatalogType::INVALID; - } - throw InternalException("Unrecognized CatalogType '%s'", type); -} - -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/compression_type.cpp b/src/duckdb/src/common/enums/compression_type.cpp deleted file mode 100644 index 3bc66fe08..000000000 --- a/src/duckdb/src/common/enums/compression_type.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "duckdb/common/enums/compression_type.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -// LCOV_EXCL_START - -vector<string> ListCompressionTypes(void) { - vector<string> compression_types; - uint8_t amount_of_compression_options = (uint8_t)CompressionType::COMPRESSION_COUNT; - compression_types.reserve(amount_of_compression_options); - for (uint8_t i = 0; i < amount_of_compression_options; i++) { - compression_types.push_back(CompressionTypeToString((CompressionType)i)); - } - return compression_types; -} - -bool CompressionTypeIsDeprecated(CompressionType compression_type) { - const bool is_patas = compression_type == CompressionType::COMPRESSION_PATAS; - const bool is_chimp = compression_type == CompressionType::COMPRESSION_CHIMP; - return (is_patas || is_chimp); -} - -CompressionType CompressionTypeFromString(const string &str) { - auto compression = StringUtil::Lower(str); - //! NOTE: this explicitly does not include 'constant' and 'empty validity', these are internal compression functions - //! 
not general purpose - if (compression == "uncompressed") { - return CompressionType::COMPRESSION_UNCOMPRESSED; - } else if (compression == "rle") { - return CompressionType::COMPRESSION_RLE; - } else if (compression == "dictionary") { - return CompressionType::COMPRESSION_DICTIONARY; - } else if (compression == "pfor") { - return CompressionType::COMPRESSION_PFOR_DELTA; - } else if (compression == "bitpacking") { - return CompressionType::COMPRESSION_BITPACKING; - } else if (compression == "fsst") { - return CompressionType::COMPRESSION_FSST; - } else if (compression == "chimp") { - return CompressionType::COMPRESSION_CHIMP; - } else if (compression == "patas") { - return CompressionType::COMPRESSION_PATAS; - } else if (compression == "zstd") { - return CompressionType::COMPRESSION_ZSTD; - } else if (compression == "alp") { - return CompressionType::COMPRESSION_ALP; - } else if (compression == "alprd") { - return CompressionType::COMPRESSION_ALPRD; - } else if (compression == "roaring") { - return CompressionType::COMPRESSION_ROARING; - } else { - return CompressionType::COMPRESSION_AUTO; - } -} - -string CompressionTypeToString(CompressionType type) { - switch (type) { - case CompressionType::COMPRESSION_AUTO: - return "Auto"; - case CompressionType::COMPRESSION_UNCOMPRESSED: - return "Uncompressed"; - case CompressionType::COMPRESSION_CONSTANT: - return "Constant"; - case CompressionType::COMPRESSION_RLE: - return "RLE"; - case CompressionType::COMPRESSION_DICTIONARY: - return "Dictionary"; - case CompressionType::COMPRESSION_PFOR_DELTA: - return "PFOR"; - case CompressionType::COMPRESSION_BITPACKING: - return "BitPacking"; - case CompressionType::COMPRESSION_FSST: - return "FSST"; - case CompressionType::COMPRESSION_CHIMP: - return "Chimp"; - case CompressionType::COMPRESSION_PATAS: - return "Patas"; - case CompressionType::COMPRESSION_ZSTD: - return "ZSTD"; - case CompressionType::COMPRESSION_ALP: - return "ALP"; - case CompressionType::COMPRESSION_ALPRD: - return "ALPRD"; - case CompressionType::COMPRESSION_ROARING: - return "Roaring"; - case CompressionType::COMPRESSION_EMPTY: - return "Empty Validity"; - default: - throw InternalException("Unrecognized compression type!"); - } -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/date_part_specifier.cpp b/src/duckdb/src/common/enums/date_part_specifier.cpp deleted file mode 100644 index 032a82165..000000000 --- a/src/duckdb/src/common/enums/date_part_specifier.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" - -namespace duckdb { - -bool TryGetDatePartSpecifier(const string &specifier_p, DatePartSpecifier &result) { - auto specifier = StringUtil::Lower(specifier_p); - if (specifier == "year" || specifier == "yr" || specifier == "y" || specifier == "years" || specifier == "yrs") { - result = DatePartSpecifier::YEAR; - } else if (specifier == "month" || specifier == "mon" || specifier == "months" || specifier == "mons") { - result = DatePartSpecifier::MONTH; - } else if (specifier == "day" || specifier == "days" || specifier == "d" || specifier == "dayofmonth") { - result = DatePartSpecifier::DAY; - } else if (specifier == "decade" || specifier == "dec" || specifier == "decades" || specifier == "decs") { - result = DatePartSpecifier::DECADE; - } else if (specifier == "century" || specifier == "cent" || specifier == "centuries" || specifier == "c") { - result 
= DatePartSpecifier::CENTURY; - } else if (specifier == "millennium" || specifier == "mil" || specifier == "millenniums" || - specifier == "millennia" || specifier == "mils" || specifier == "millenium") { - result = DatePartSpecifier::MILLENNIUM; - } else if (specifier == "microseconds" || specifier == "microsecond" || specifier == "us" || specifier == "usec" || - specifier == "usecs" || specifier == "usecond" || specifier == "useconds") { - result = DatePartSpecifier::MICROSECONDS; - } else if (specifier == "milliseconds" || specifier == "millisecond" || specifier == "ms" || specifier == "msec" || - specifier == "msecs" || specifier == "msecond" || specifier == "mseconds") { - result = DatePartSpecifier::MILLISECONDS; - } else if (specifier == "second" || specifier == "sec" || specifier == "seconds" || specifier == "secs" || - specifier == "s") { - result = DatePartSpecifier::SECOND; - } else if (specifier == "minute" || specifier == "min" || specifier == "minutes" || specifier == "mins" || - specifier == "m") { - result = DatePartSpecifier::MINUTE; - } else if (specifier == "hour" || specifier == "hr" || specifier == "hours" || specifier == "hrs" || - specifier == "h") { - result = DatePartSpecifier::HOUR; - } else if (specifier == "epoch") { - // seconds since 1970-01-01 - result = DatePartSpecifier::EPOCH; - } else if (specifier == "dow" || specifier == "dayofweek" || specifier == "weekday") { - // day of the week (Sunday = 0, Saturday = 6) - result = DatePartSpecifier::DOW; - } else if (specifier == "isodow") { - // isodow (Monday = 1, Sunday = 7) - result = DatePartSpecifier::ISODOW; - } else if (specifier == "week" || specifier == "weeks" || specifier == "w" || specifier == "weekofyear") { - // ISO week number - result = DatePartSpecifier::WEEK; - } else if (specifier == "doy" || specifier == "dayofyear") { - // day of the year (1-365/366) - result = DatePartSpecifier::DOY; - } else if (specifier == "quarter" || specifier == "quarters") { - // quarter of the year (1-4) - result = DatePartSpecifier::QUARTER; - } else if (specifier == "yearweek") { - // Combined isoyear and isoweek YYYYWW - result = DatePartSpecifier::YEARWEEK; - } else if (specifier == "isoyear") { - // ISO year (first week of the year may be in previous year) - result = DatePartSpecifier::ISOYEAR; - } else if (specifier == "era") { - result = DatePartSpecifier::ERA; - } else if (specifier == "timezone") { - result = DatePartSpecifier::TIMEZONE; - } else if (specifier == "timezone_hour") { - result = DatePartSpecifier::TIMEZONE_HOUR; - } else if (specifier == "timezone_minute") { - result = DatePartSpecifier::TIMEZONE_MINUTE; - } else if (specifier == "julian" || specifier == "jd") { - result = DatePartSpecifier::JULIAN_DAY; - } else { - return false; - } - return true; -} - -DatePartSpecifier GetDatePartSpecifier(const string &specifier) { - DatePartSpecifier result; - if (!TryGetDatePartSpecifier(specifier, result)) { - throw ConversionException("extract specifier \"%s\" not recognized", specifier); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/expression_type.cpp b/src/duckdb/src/common/enums/expression_type.cpp deleted file mode 100644 index 105f4dfa7..000000000 --- a/src/duckdb/src/common/enums/expression_type.cpp +++ /dev/null @@ -1,334 +0,0 @@ -#include "duckdb/common/enums/expression_type.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -string ExpressionTypeToString(ExpressionType type) { - switch (type) 
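TryGetDatePartSpecifier canonicalizes a deliberately generous alias set (case-insensitive, singular and plural forms, and common abbreviations, including the "millenium" misspelling), while GetDatePartSpecifier converts a failed lookup into a ConversionException. A small usage sketch:

static void DatePartAliasExample() {
	DatePartSpecifier part;
	// "Yrs" is lower-cased and then matched as one of the YEAR aliases.
	if (TryGetDatePartSpecifier("Yrs", part)) {
		D_ASSERT(part == DatePartSpecifier::YEAR);
	}
	// Unknown specifiers simply return false here ...
	D_ASSERT(!TryGetDatePartSpecifier("fortnight", part));
	// ... whereas GetDatePartSpecifier("fortnight") would throw:
	// ConversionException: extract specifier "fortnight" not recognized
}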
{ - case ExpressionType::OPERATOR_CAST: - return "CAST"; - case ExpressionType::OPERATOR_NOT: - return "NOT"; - case ExpressionType::OPERATOR_IS_NULL: - return "IS_NULL"; - case ExpressionType::OPERATOR_IS_NOT_NULL: - return "IS_NOT_NULL"; - case ExpressionType::COMPARE_EQUAL: - return "EQUAL"; - case ExpressionType::COMPARE_NOTEQUAL: - return "NOTEQUAL"; - case ExpressionType::COMPARE_LESSTHAN: - return "LESSTHAN"; - case ExpressionType::COMPARE_GREATERTHAN: - return "GREATERTHAN"; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return "LESSTHANOREQUALTO"; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return "GREATERTHANOREQUALTO"; - case ExpressionType::COMPARE_IN: - return "IN"; - case ExpressionType::COMPARE_DISTINCT_FROM: - return "DISTINCT_FROM"; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return "NOT_DISTINCT_FROM"; - case ExpressionType::CONJUNCTION_AND: - return "AND"; - case ExpressionType::CONJUNCTION_OR: - return "OR"; - case ExpressionType::VALUE_CONSTANT: - return "CONSTANT"; - case ExpressionType::VALUE_PARAMETER: - return "PARAMETER"; - case ExpressionType::VALUE_TUPLE: - return "TUPLE"; - case ExpressionType::VALUE_TUPLE_ADDRESS: - return "TUPLE_ADDRESS"; - case ExpressionType::VALUE_NULL: - return "NULL"; - case ExpressionType::VALUE_VECTOR: - return "VECTOR"; - case ExpressionType::VALUE_SCALAR: - return "SCALAR"; - case ExpressionType::AGGREGATE: - return "AGGREGATE"; - case ExpressionType::WINDOW_AGGREGATE: - return "WINDOW_AGGREGATE"; - case ExpressionType::WINDOW_RANK: - return "RANK"; - case ExpressionType::WINDOW_RANK_DENSE: - return "RANK_DENSE"; - case ExpressionType::WINDOW_PERCENT_RANK: - return "PERCENT_RANK"; - case ExpressionType::WINDOW_ROW_NUMBER: - return "ROW_NUMBER"; - case ExpressionType::WINDOW_FIRST_VALUE: - return "FIRST_VALUE"; - case ExpressionType::WINDOW_LAST_VALUE: - return "LAST_VALUE"; - case ExpressionType::WINDOW_NTH_VALUE: - return "NTH_VALUE"; - case ExpressionType::WINDOW_CUME_DIST: - return "CUME_DIST"; - case ExpressionType::WINDOW_LEAD: - return "LEAD"; - case ExpressionType::WINDOW_LAG: - return "LAG"; - case ExpressionType::WINDOW_NTILE: - return "NTILE"; - case ExpressionType::FUNCTION: - return "FUNCTION"; - case ExpressionType::CASE_EXPR: - return "CASE"; - case ExpressionType::OPERATOR_NULLIF: - return "NULLIF"; - case ExpressionType::OPERATOR_COALESCE: - return "COALESCE"; - case ExpressionType::ARRAY_EXTRACT: - return "ARRAY_EXTRACT"; - case ExpressionType::ARRAY_SLICE: - return "ARRAY_SLICE"; - case ExpressionType::STRUCT_EXTRACT: - return "STRUCT_EXTRACT"; - case ExpressionType::SUBQUERY: - return "SUBQUERY"; - case ExpressionType::STAR: - return "STAR"; - case ExpressionType::PLACEHOLDER: - return "PLACEHOLDER"; - case ExpressionType::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionType::LAMBDA_REF: - return "LAMBDA_REF"; - case ExpressionType::FUNCTION_REF: - return "FUNCTION_REF"; - case ExpressionType::TABLE_REF: - return "TABLE_REF"; - case ExpressionType::CAST: - return "CAST"; - case ExpressionType::COMPARE_NOT_IN: - return "COMPARE_NOT_IN"; - case ExpressionType::COMPARE_BETWEEN: - return "COMPARE_BETWEEN"; - case ExpressionType::COMPARE_NOT_BETWEEN: - return "COMPARE_NOT_BETWEEN"; - case ExpressionType::VALUE_DEFAULT: - return "VALUE_DEFAULT"; - case ExpressionType::BOUND_REF: - return "BOUND_REF"; - case ExpressionType::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionType::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionType::BOUND_AGGREGATE: - return 
"BOUND_AGGREGATE"; - case ExpressionType::GROUPING_FUNCTION: - return "GROUPING"; - case ExpressionType::ARRAY_CONSTRUCTOR: - return "ARRAY_CONSTRUCTOR"; - case ExpressionType::TABLE_STAR: - return "TABLE_STAR"; - case ExpressionType::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionType::COLLATE: - return "COLLATE"; - case ExpressionType::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionType::BOUND_LAMBDA_REF: - return "BOUND_LAMBDA_REF"; - case ExpressionType::LAMBDA: - return "LAMBDA"; - case ExpressionType::ARROW: - return "ARROW"; - case ExpressionType::BOUND_EXPANDED: - return "BOUND_EXPANDED"; - case ExpressionType::INVALID: - break; - } - return "INVALID"; -} -string ExpressionClassToString(ExpressionClass type) { - switch (type) { - case ExpressionClass::INVALID: - return "INVALID"; - case ExpressionClass::AGGREGATE: - return "AGGREGATE"; - case ExpressionClass::CASE: - return "CASE"; - case ExpressionClass::CAST: - return "CAST"; - case ExpressionClass::COLUMN_REF: - return "COLUMN_REF"; - case ExpressionClass::LAMBDA_REF: - return "LAMBDA_REF"; - case ExpressionClass::COMPARISON: - return "COMPARISON"; - case ExpressionClass::CONJUNCTION: - return "CONJUNCTION"; - case ExpressionClass::CONSTANT: - return "CONSTANT"; - case ExpressionClass::DEFAULT: - return "DEFAULT"; - case ExpressionClass::FUNCTION: - return "FUNCTION"; - case ExpressionClass::OPERATOR: - return "OPERATOR"; - case ExpressionClass::STAR: - return "STAR"; - case ExpressionClass::SUBQUERY: - return "SUBQUERY"; - case ExpressionClass::WINDOW: - return "WINDOW"; - case ExpressionClass::PARAMETER: - return "PARAMETER"; - case ExpressionClass::COLLATE: - return "COLLATE"; - case ExpressionClass::LAMBDA: - return "LAMBDA"; - case ExpressionClass::POSITIONAL_REFERENCE: - return "POSITIONAL_REFERENCE"; - case ExpressionClass::BETWEEN: - return "BETWEEN"; - case ExpressionClass::BOUND_AGGREGATE: - return "BOUND_AGGREGATE"; - case ExpressionClass::BOUND_CASE: - return "BOUND_CASE"; - case ExpressionClass::BOUND_CAST: - return "BOUND_CAST"; - case ExpressionClass::BOUND_COLUMN_REF: - return "BOUND_COLUMN_REF"; - case ExpressionClass::BOUND_COMPARISON: - return "BOUND_COMPARISON"; - case ExpressionClass::BOUND_CONJUNCTION: - return "BOUND_CONJUNCTION"; - case ExpressionClass::BOUND_CONSTANT: - return "BOUND_CONSTANT"; - case ExpressionClass::BOUND_DEFAULT: - return "BOUND_DEFAULT"; - case ExpressionClass::BOUND_FUNCTION: - return "BOUND_FUNCTION"; - case ExpressionClass::BOUND_OPERATOR: - return "BOUND_OPERATOR"; - case ExpressionClass::BOUND_PARAMETER: - return "BOUND_PARAMETER"; - case ExpressionClass::BOUND_REF: - return "BOUND_REF"; - case ExpressionClass::BOUND_SUBQUERY: - return "BOUND_SUBQUERY"; - case ExpressionClass::BOUND_WINDOW: - return "BOUND_WINDOW"; - case ExpressionClass::BOUND_BETWEEN: - return "BOUND_BETWEEN"; - case ExpressionClass::BOUND_UNNEST: - return "BOUND_UNNEST"; - case ExpressionClass::BOUND_LAMBDA: - return "BOUND_LAMBDA"; - case ExpressionClass::BOUND_EXPRESSION: - return "BOUND_EXPRESSION"; - case ExpressionClass::BOUND_EXPANDED: - return "BOUND_EXPANDED"; - default: - return "ExpressionClass::!!UNIMPLEMENTED_CASE!!"; - } -} - -string ExpressionTypeToOperator(ExpressionType type) { - switch (type) { - case ExpressionType::COMPARE_EQUAL: - return "="; - case ExpressionType::COMPARE_NOTEQUAL: - return "!="; - case ExpressionType::COMPARE_LESSTHAN: - return "<"; - case ExpressionType::COMPARE_GREATERTHAN: - return ">"; - case 
ExpressionType::COMPARE_LESSTHANOREQUALTO: - return "<="; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return ">="; - case ExpressionType::COMPARE_DISTINCT_FROM: - return "IS DISTINCT FROM"; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return "IS NOT DISTINCT FROM"; - case ExpressionType::CONJUNCTION_AND: - return "AND"; - case ExpressionType::CONJUNCTION_OR: - return "OR"; - default: - return ""; - } -} - -ExpressionType NegateComparisonExpression(ExpressionType type) { - ExpressionType negated_type = ExpressionType::INVALID; - switch (type) { - case ExpressionType::COMPARE_EQUAL: - negated_type = ExpressionType::COMPARE_NOTEQUAL; - break; - case ExpressionType::COMPARE_NOTEQUAL: - negated_type = ExpressionType::COMPARE_EQUAL; - break; - case ExpressionType::COMPARE_LESSTHAN: - negated_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHAN: - negated_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - negated_type = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - negated_type = ExpressionType::COMPARE_LESSTHAN; - break; - default: - throw InternalException("Unsupported comparison type in negation"); - } - return negated_type; -} - -ExpressionType FlipComparisonExpression(ExpressionType type) { - ExpressionType flipped_type = ExpressionType::INVALID; - switch (type) { - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - case ExpressionType::COMPARE_DISTINCT_FROM: - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_EQUAL: - flipped_type = type; - break; - case ExpressionType::COMPARE_LESSTHAN: - flipped_type = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_GREATERTHAN: - flipped_type = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - flipped_type = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - flipped_type = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - default: - throw InternalException("Unsupported comparison type in flip"); - } - return flipped_type; -} - -ExpressionType OperatorToExpressionType(const string &op) { - if (op == "=" || op == "==") { - return ExpressionType::COMPARE_EQUAL; - } else if (op == "!=" || op == "<>") { - return ExpressionType::COMPARE_NOTEQUAL; - } else if (op == "<") { - return ExpressionType::COMPARE_LESSTHAN; - } else if (op == ">") { - return ExpressionType::COMPARE_GREATERTHAN; - } else if (op == "<=") { - return ExpressionType::COMPARE_LESSTHANOREQUALTO; - } else if (op == ">=") { - return ExpressionType::COMPARE_GREATERTHANOREQUALTO; - } - return ExpressionType::INVALID; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/file_compression_type.cpp b/src/duckdb/src/common/enums/file_compression_type.cpp deleted file mode 100644 index 44066f32c..000000000 --- a/src/duckdb/src/common/enums/file_compression_type.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "duckdb/common/enums/file_compression_type.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/exception/parser_exception.hpp" - -namespace duckdb { - -FileCompressionType FileCompressionTypeFromString(const string &input) { - auto parameter = StringUtil::Lower(input); - if (parameter == "infer" || parameter == "auto") { - return FileCompressionType::AUTO_DETECT; - } else if (parameter == "gzip") { - return 
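NegateComparisonExpression and FlipComparisonExpression are easy to confuse: negation inverts the truth value (useful for pushing a NOT through a predicate), while flipping swaps the operand sides (useful when commuting a comparison), and the two disagree on every inequality:

static void ComparisonEnumExamples() {
	// NOT (a < b)  is  a >= b: negation pairs each inequality with its complement.
	D_ASSERT(NegateComparisonExpression(ExpressionType::COMPARE_LESSTHAN) ==
	         ExpressionType::COMPARE_GREATERTHANOREQUALTO);
	// (a < b)  is  (b > a): flipping pairs each inequality with its mirror image.
	D_ASSERT(FlipComparisonExpression(ExpressionType::COMPARE_LESSTHAN) ==
	         ExpressionType::COMPARE_GREATERTHAN);
	// Equality is its own mirror image, but its negation is NOTEQUAL.
	D_ASSERT(FlipComparisonExpression(ExpressionType::COMPARE_EQUAL) == ExpressionType::COMPARE_EQUAL);
	D_ASSERT(NegateComparisonExpression(ExpressionType::COMPARE_EQUAL) == ExpressionType::COMPARE_NOTEQUAL);
}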
FileCompressionType::GZIP; - } else if (parameter == "zstd") { - return FileCompressionType::ZSTD; - } else if (parameter == "uncompressed" || parameter == "none" || parameter.empty()) { - return FileCompressionType::UNCOMPRESSED; - } else { - throw ParserException("Unrecognized file compression type \"%s\"", input); - } -} - -string CompressionExtensionFromType(const FileCompressionType type) { - switch (type) { - case FileCompressionType::GZIP: - return ".gz"; - case FileCompressionType::ZSTD: - return ".zst"; - default: - throw NotImplementedException("Compression Extension of file compression type is not implemented"); - } -} - -bool IsFileCompressed(string path, FileCompressionType type) { - auto extension = CompressionExtensionFromType(type); - std::size_t question_mark_pos = std::string::npos; - if (!StringUtil::StartsWith(path, "\\\\?\\")) { - question_mark_pos = path.find('?'); - } - path = path.substr(0, question_mark_pos); - if (StringUtil::EndsWith(path, extension)) { - return true; - } - return false; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/join_type.cpp b/src/duckdb/src/common/enums/join_type.cpp deleted file mode 100644 index 6bfe54892..000000000 --- a/src/duckdb/src/common/enums/join_type.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "duckdb/common/enums/join_type.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -bool IsLeftOuterJoin(JoinType type) { - return type == JoinType::LEFT || type == JoinType::OUTER; -} - -bool IsRightOuterJoin(JoinType type) { - return type == JoinType::OUTER || type == JoinType::RIGHT; -} - -bool PropagatesBuildSide(JoinType type) { - return type == JoinType::OUTER || type == JoinType::RIGHT || type == JoinType::RIGHT_ANTI || - type == JoinType::RIGHT_SEMI; -} - -bool HasInverseJoinType(JoinType type) { - return type != JoinType::SINGLE && type != JoinType::MARK; -} - -JoinType InverseJoinType(JoinType type) { - D_ASSERT(HasInverseJoinType(type)); - switch (type) { - case JoinType::LEFT: - return JoinType::RIGHT; - case JoinType::RIGHT: - return JoinType::LEFT; - case JoinType::INNER: - return JoinType::INNER; - case JoinType::OUTER: - return JoinType::OUTER; - case JoinType::SEMI: - return JoinType::RIGHT_SEMI; - case JoinType::ANTI: - return JoinType::RIGHT_ANTI; - case JoinType::RIGHT_SEMI: - return JoinType::SEMI; - case JoinType::RIGHT_ANTI: - return JoinType::ANTI; - default: - throw NotImplementedException("InverseJoinType for JoinType::%s", EnumUtil::ToString(type)); - } -} - -// **DEPRECATED**: Use EnumUtil directly instead. 
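InverseJoinType, defined just below, answers "which join type gives the same result if the two children are swapped", which is what the build-side/probe-side optimizer needs; SINGLE and MARK joins have no such counterpart, hence the HasInverseJoinType guard. For example:

static void InverseJoinTypeExamples() {
	// Swapping the children of a SEMI join yields a RIGHT_SEMI join, and vice versa.
	D_ASSERT(InverseJoinType(JoinType::SEMI) == JoinType::RIGHT_SEMI);
	D_ASSERT(InverseJoinType(JoinType::RIGHT_SEMI) == JoinType::SEMI);
	// INNER and OUTER joins are symmetric, so they are their own inverses.
	D_ASSERT(InverseJoinType(JoinType::INNER) == JoinType::INNER);
	D_ASSERT(InverseJoinType(JoinType::OUTER) == JoinType::OUTER);
}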
-string JoinTypeToString(JoinType type) { - return EnumUtil::ToString(type); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/logical_operator_type.cpp b/src/duckdb/src/common/enums/logical_operator_type.cpp deleted file mode 100644 index 550bea5be..000000000 --- a/src/duckdb/src/common/enums/logical_operator_type.cpp +++ /dev/null @@ -1,138 +0,0 @@ -#include "duckdb/common/enums/logical_operator_type.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Value <--> String Utilities -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -string LogicalOperatorToString(LogicalOperatorType type) { - switch (type) { - case LogicalOperatorType::LOGICAL_GET: - return "GET"; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - return "CHUNK_GET"; - case LogicalOperatorType::LOGICAL_DELIM_GET: - return "DELIM_GET"; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - return "EMPTY_RESULT"; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - return "EXPRESSION_GET"; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - return "ANY_JOIN"; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return "ASOF_JOIN"; - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - return "DEPENDENT_JOIN"; - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - return "COMPARISON_JOIN"; - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return "DELIM_JOIN"; - case LogicalOperatorType::LOGICAL_PROJECTION: - return "PROJECTION"; - case LogicalOperatorType::LOGICAL_FILTER: - return "FILTER"; - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - return "AGGREGATE"; - case LogicalOperatorType::LOGICAL_WINDOW: - return "WINDOW"; - case LogicalOperatorType::LOGICAL_UNNEST: - return "UNNEST"; - case LogicalOperatorType::LOGICAL_LIMIT: - return "LIMIT"; - case LogicalOperatorType::LOGICAL_ORDER_BY: - return "ORDER_BY"; - case LogicalOperatorType::LOGICAL_TOP_N: - return "TOP_N"; - case LogicalOperatorType::LOGICAL_SAMPLE: - return "SAMPLE"; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - return "COPY_TO_FILE"; - case LogicalOperatorType::LOGICAL_COPY_DATABASE: - return "COPY_DATABASE"; - case LogicalOperatorType::LOGICAL_JOIN: - return "JOIN"; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - return "CROSS_PRODUCT"; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - return "POSITIONAL_JOIN"; - case LogicalOperatorType::LOGICAL_UNION: - return "UNION"; - case LogicalOperatorType::LOGICAL_EXCEPT: - return "EXCEPT"; - case LogicalOperatorType::LOGICAL_INTERSECT: - return "INTERSECT"; - case LogicalOperatorType::LOGICAL_INSERT: - return "INSERT"; - case LogicalOperatorType::LOGICAL_DISTINCT: - return "DISTINCT"; - case LogicalOperatorType::LOGICAL_DELETE: - return "DELETE"; - case LogicalOperatorType::LOGICAL_UPDATE: - return "UPDATE"; - case LogicalOperatorType::LOGICAL_PREPARE: - return "PREPARE"; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - return "DUMMY_SCAN"; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - return "CREATE_INDEX"; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - return "CREATE_TABLE"; - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - return "CREATE_MACRO"; - case LogicalOperatorType::LOGICAL_EXPLAIN: - return "EXPLAIN"; - case LogicalOperatorType::LOGICAL_EXECUTE: - return "EXECUTE"; - case LogicalOperatorType::LOGICAL_VACUUM: - return "VACUUM"; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - return "REC_CTE"; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - 
return "CTE"; - case LogicalOperatorType::LOGICAL_CTE_REF: - return "CTE_SCAN"; - case LogicalOperatorType::LOGICAL_ALTER: - return "ALTER"; - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - return "CREATE_TYPE"; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - return "CREATE_VIEW"; - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - return "CREATE_SCHEMA"; - case LogicalOperatorType::LOGICAL_CREATE_SECRET: - return "CREATE_SECRET"; - case LogicalOperatorType::LOGICAL_ATTACH: - return "ATTACH"; - case LogicalOperatorType::LOGICAL_DETACH: - return "DETACH"; - case LogicalOperatorType::LOGICAL_DROP: - return "DROP"; - case LogicalOperatorType::LOGICAL_PRAGMA: - return "PRAGMA"; - case LogicalOperatorType::LOGICAL_TRANSACTION: - return "TRANSACTION"; - case LogicalOperatorType::LOGICAL_EXPORT: - return "EXPORT"; - case LogicalOperatorType::LOGICAL_SET: - return "SET"; - case LogicalOperatorType::LOGICAL_RESET: - return "RESET"; - case LogicalOperatorType::LOGICAL_LOAD: - return "LOAD"; - case LogicalOperatorType::LOGICAL_INVALID: - break; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - return "CUSTOM_OP"; - case LogicalOperatorType::LOGICAL_PIVOT: - return "PIVOT"; - case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS: - return "UPDATE_EXTENSIONS"; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/metric_type.cpp b/src/duckdb/src/common/enums/metric_type.cpp deleted file mode 100644 index 94a75cc1e..000000000 --- a/src/duckdb/src/common/enums/metric_type.cpp +++ /dev/null @@ -1,226 +0,0 @@ -//------------------------------------------------------------------------- -// DuckDB -// -// -// duckdb/common/enums/metrics_type.hpp -// -// This file is automatically generated by scripts/generate_metric_enums.py -// Do not edit this file manually, your changes will be overwritten -//------------------------------------------------------------------------- - -#include "duckdb/common/enums/metric_type.hpp" -namespace duckdb { - -profiler_settings_t MetricsUtils::GetOptimizerMetrics() { - return { - MetricsType::OPTIMIZER_EXPRESSION_REWRITER, - MetricsType::OPTIMIZER_FILTER_PULLUP, - MetricsType::OPTIMIZER_FILTER_PUSHDOWN, - MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP, - MetricsType::OPTIMIZER_CTE_FILTER_PUSHER, - MetricsType::OPTIMIZER_REGEX_RANGE, - MetricsType::OPTIMIZER_IN_CLAUSE, - MetricsType::OPTIMIZER_JOIN_ORDER, - MetricsType::OPTIMIZER_DELIMINATOR, - MetricsType::OPTIMIZER_UNNEST_REWRITER, - MetricsType::OPTIMIZER_UNUSED_COLUMNS, - MetricsType::OPTIMIZER_STATISTICS_PROPAGATION, - MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS, - MetricsType::OPTIMIZER_COMMON_AGGREGATE, - MetricsType::OPTIMIZER_COLUMN_LIFETIME, - MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE, - MetricsType::OPTIMIZER_LIMIT_PUSHDOWN, - MetricsType::OPTIMIZER_TOP_N, - MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION, - MetricsType::OPTIMIZER_DUPLICATE_GROUPS, - MetricsType::OPTIMIZER_REORDER_FILTER, - MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN, - MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN, - MetricsType::OPTIMIZER_EXTENSION, - MetricsType::OPTIMIZER_MATERIALIZED_CTE, - MetricsType::OPTIMIZER_SUM_REWRITER, - }; -} - -profiler_settings_t MetricsUtils::GetPhaseTimingMetrics() { - return { - MetricsType::ALL_OPTIMIZERS, - MetricsType::CUMULATIVE_OPTIMIZER_TIMING, - MetricsType::PLANNER, - MetricsType::PLANNER_BINDING, - MetricsType::PHYSICAL_PLANNER, - 
MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING, - MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES, - MetricsType::PHYSICAL_PLANNER_CREATE_PLAN, - }; -} - -MetricsType MetricsUtils::GetOptimizerMetricByType(OptimizerType type) { - switch(type) { - case OptimizerType::EXPRESSION_REWRITER: - return MetricsType::OPTIMIZER_EXPRESSION_REWRITER; - case OptimizerType::FILTER_PULLUP: - return MetricsType::OPTIMIZER_FILTER_PULLUP; - case OptimizerType::FILTER_PUSHDOWN: - return MetricsType::OPTIMIZER_FILTER_PUSHDOWN; - case OptimizerType::EMPTY_RESULT_PULLUP: - return MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP; - case OptimizerType::CTE_FILTER_PUSHER: - return MetricsType::OPTIMIZER_CTE_FILTER_PUSHER; - case OptimizerType::REGEX_RANGE: - return MetricsType::OPTIMIZER_REGEX_RANGE; - case OptimizerType::IN_CLAUSE: - return MetricsType::OPTIMIZER_IN_CLAUSE; - case OptimizerType::JOIN_ORDER: - return MetricsType::OPTIMIZER_JOIN_ORDER; - case OptimizerType::DELIMINATOR: - return MetricsType::OPTIMIZER_DELIMINATOR; - case OptimizerType::UNNEST_REWRITER: - return MetricsType::OPTIMIZER_UNNEST_REWRITER; - case OptimizerType::UNUSED_COLUMNS: - return MetricsType::OPTIMIZER_UNUSED_COLUMNS; - case OptimizerType::STATISTICS_PROPAGATION: - return MetricsType::OPTIMIZER_STATISTICS_PROPAGATION; - case OptimizerType::COMMON_SUBEXPRESSIONS: - return MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS; - case OptimizerType::COMMON_AGGREGATE: - return MetricsType::OPTIMIZER_COMMON_AGGREGATE; - case OptimizerType::COLUMN_LIFETIME: - return MetricsType::OPTIMIZER_COLUMN_LIFETIME; - case OptimizerType::BUILD_SIDE_PROBE_SIDE: - return MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE; - case OptimizerType::LIMIT_PUSHDOWN: - return MetricsType::OPTIMIZER_LIMIT_PUSHDOWN; - case OptimizerType::TOP_N: - return MetricsType::OPTIMIZER_TOP_N; - case OptimizerType::COMPRESSED_MATERIALIZATION: - return MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION; - case OptimizerType::DUPLICATE_GROUPS: - return MetricsType::OPTIMIZER_DUPLICATE_GROUPS; - case OptimizerType::REORDER_FILTER: - return MetricsType::OPTIMIZER_REORDER_FILTER; - case OptimizerType::SAMPLING_PUSHDOWN: - return MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN; - case OptimizerType::JOIN_FILTER_PUSHDOWN: - return MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN; - case OptimizerType::EXTENSION: - return MetricsType::OPTIMIZER_EXTENSION; - case OptimizerType::MATERIALIZED_CTE: - return MetricsType::OPTIMIZER_MATERIALIZED_CTE; - case OptimizerType::SUM_REWRITER: - return MetricsType::OPTIMIZER_SUM_REWRITER; - default: - throw InternalException("OptimizerType %s cannot be converted to a MetricsType", EnumUtil::ToString(type)); - }; -} - -OptimizerType MetricsUtils::GetOptimizerTypeByMetric(MetricsType type) { - switch(type) { - case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: - return OptimizerType::EXPRESSION_REWRITER; - case MetricsType::OPTIMIZER_FILTER_PULLUP: - return OptimizerType::FILTER_PULLUP; - case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: - return OptimizerType::FILTER_PUSHDOWN; - case MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP: - return OptimizerType::EMPTY_RESULT_PULLUP; - case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: - return OptimizerType::CTE_FILTER_PUSHER; - case MetricsType::OPTIMIZER_REGEX_RANGE: - return OptimizerType::REGEX_RANGE; - case MetricsType::OPTIMIZER_IN_CLAUSE: - return OptimizerType::IN_CLAUSE; - case MetricsType::OPTIMIZER_JOIN_ORDER: - return OptimizerType::JOIN_ORDER; - case MetricsType::OPTIMIZER_DELIMINATOR: - return OptimizerType::DELIMINATOR; - case 
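GetOptimizerMetricByType and GetOptimizerTypeByMetric form a bijection over the OPTIMIZER_* metrics only: the forward direction throws an InternalException for optimizer types without a metric, while the reverse direction, as its default case below shows, falls back to OptimizerType::INVALID. The round-trip therefore holds only on the optimizer subset:

static void OptimizerMetricRoundTrip() {
	// Every optimizer metric maps to an OptimizerType and back to itself.
	for (auto metric : MetricsUtils::GetOptimizerMetrics()) {
		auto optimizer = MetricsUtils::GetOptimizerTypeByMetric(metric);
		D_ASSERT(MetricsUtils::GetOptimizerMetricByType(optimizer) == metric);
	}
	// A non-optimizer metric maps to INVALID instead of throwing.
	D_ASSERT(MetricsUtils::GetOptimizerTypeByMetric(MetricsType::PLANNER) == OptimizerType::INVALID);
}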
MetricsType::OPTIMIZER_UNNEST_REWRITER: - return OptimizerType::UNNEST_REWRITER; - case MetricsType::OPTIMIZER_UNUSED_COLUMNS: - return OptimizerType::UNUSED_COLUMNS; - case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: - return OptimizerType::STATISTICS_PROPAGATION; - case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - return OptimizerType::COMMON_SUBEXPRESSIONS; - case MetricsType::OPTIMIZER_COMMON_AGGREGATE: - return OptimizerType::COMMON_AGGREGATE; - case MetricsType::OPTIMIZER_COLUMN_LIFETIME: - return OptimizerType::COLUMN_LIFETIME; - case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - return OptimizerType::BUILD_SIDE_PROBE_SIDE; - case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: - return OptimizerType::LIMIT_PUSHDOWN; - case MetricsType::OPTIMIZER_TOP_N: - return OptimizerType::TOP_N; - case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - return OptimizerType::COMPRESSED_MATERIALIZATION; - case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: - return OptimizerType::DUPLICATE_GROUPS; - case MetricsType::OPTIMIZER_REORDER_FILTER: - return OptimizerType::REORDER_FILTER; - case MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN: - return OptimizerType::SAMPLING_PUSHDOWN; - case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - return OptimizerType::JOIN_FILTER_PUSHDOWN; - case MetricsType::OPTIMIZER_EXTENSION: - return OptimizerType::EXTENSION; - case MetricsType::OPTIMIZER_MATERIALIZED_CTE: - return OptimizerType::MATERIALIZED_CTE; - case MetricsType::OPTIMIZER_SUM_REWRITER: - return OptimizerType::SUM_REWRITER; - default: - return OptimizerType::INVALID; - }; -} - -bool MetricsUtils::IsOptimizerMetric(MetricsType type) { - switch(type) { - case MetricsType::OPTIMIZER_EXPRESSION_REWRITER: - case MetricsType::OPTIMIZER_FILTER_PULLUP: - case MetricsType::OPTIMIZER_FILTER_PUSHDOWN: - case MetricsType::OPTIMIZER_EMPTY_RESULT_PULLUP: - case MetricsType::OPTIMIZER_CTE_FILTER_PUSHER: - case MetricsType::OPTIMIZER_REGEX_RANGE: - case MetricsType::OPTIMIZER_IN_CLAUSE: - case MetricsType::OPTIMIZER_JOIN_ORDER: - case MetricsType::OPTIMIZER_DELIMINATOR: - case MetricsType::OPTIMIZER_UNNEST_REWRITER: - case MetricsType::OPTIMIZER_UNUSED_COLUMNS: - case MetricsType::OPTIMIZER_STATISTICS_PROPAGATION: - case MetricsType::OPTIMIZER_COMMON_SUBEXPRESSIONS: - case MetricsType::OPTIMIZER_COMMON_AGGREGATE: - case MetricsType::OPTIMIZER_COLUMN_LIFETIME: - case MetricsType::OPTIMIZER_BUILD_SIDE_PROBE_SIDE: - case MetricsType::OPTIMIZER_LIMIT_PUSHDOWN: - case MetricsType::OPTIMIZER_TOP_N: - case MetricsType::OPTIMIZER_COMPRESSED_MATERIALIZATION: - case MetricsType::OPTIMIZER_DUPLICATE_GROUPS: - case MetricsType::OPTIMIZER_REORDER_FILTER: - case MetricsType::OPTIMIZER_SAMPLING_PUSHDOWN: - case MetricsType::OPTIMIZER_JOIN_FILTER_PUSHDOWN: - case MetricsType::OPTIMIZER_EXTENSION: - case MetricsType::OPTIMIZER_MATERIALIZED_CTE: - case MetricsType::OPTIMIZER_SUM_REWRITER: - return true; - default: - return false; - }; -} - -bool MetricsUtils::IsPhaseTimingMetric(MetricsType type) { - switch(type) { - case MetricsType::ALL_OPTIMIZERS: - case MetricsType::CUMULATIVE_OPTIMIZER_TIMING: - case MetricsType::PLANNER: - case MetricsType::PLANNER_BINDING: - case MetricsType::PHYSICAL_PLANNER: - case MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING: - case MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES: - case MetricsType::PHYSICAL_PLANNER_CREATE_PLAN: - return true; - default: - return false; - }; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/optimizer_type.cpp b/src/duckdb/src/common/enums/optimizer_type.cpp deleted file 
mode 100644 index f2555dc58..000000000 --- a/src/duckdb/src/common/enums/optimizer_type.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include "duckdb/common/enums/optimizer_type.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/exception/parser_exception.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -struct DefaultOptimizerType { - const char *name; - OptimizerType type; -}; - -static const DefaultOptimizerType internal_optimizer_types[] = { - {"expression_rewriter", OptimizerType::EXPRESSION_REWRITER}, - {"filter_pullup", OptimizerType::FILTER_PULLUP}, - {"filter_pushdown", OptimizerType::FILTER_PUSHDOWN}, - {"empty_result_pullup", OptimizerType::EMPTY_RESULT_PULLUP}, - {"cte_filter_pusher", OptimizerType::CTE_FILTER_PUSHER}, - {"regex_range", OptimizerType::REGEX_RANGE}, - {"in_clause", OptimizerType::IN_CLAUSE}, - {"join_order", OptimizerType::JOIN_ORDER}, - {"deliminator", OptimizerType::DELIMINATOR}, - {"unnest_rewriter", OptimizerType::UNNEST_REWRITER}, - {"unused_columns", OptimizerType::UNUSED_COLUMNS}, - {"statistics_propagation", OptimizerType::STATISTICS_PROPAGATION}, - {"common_subexpressions", OptimizerType::COMMON_SUBEXPRESSIONS}, - {"common_aggregate", OptimizerType::COMMON_AGGREGATE}, - {"column_lifetime", OptimizerType::COLUMN_LIFETIME}, - {"limit_pushdown", OptimizerType::LIMIT_PUSHDOWN}, - {"top_n", OptimizerType::TOP_N}, - {"build_side_probe_side", OptimizerType::BUILD_SIDE_PROBE_SIDE}, - {"compressed_materialization", OptimizerType::COMPRESSED_MATERIALIZATION}, - {"duplicate_groups", OptimizerType::DUPLICATE_GROUPS}, - {"reorder_filter", OptimizerType::REORDER_FILTER}, - {"sampling_pushdown", OptimizerType::SAMPLING_PUSHDOWN}, - {"join_filter_pushdown", OptimizerType::JOIN_FILTER_PUSHDOWN}, - {"extension", OptimizerType::EXTENSION}, - {"materialized_cte", OptimizerType::MATERIALIZED_CTE}, - {"sum_rewriter", OptimizerType::SUM_REWRITER}, - {nullptr, OptimizerType::INVALID}}; - -string OptimizerTypeToString(OptimizerType type) { - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - if (internal_optimizer_types[i].type == type) { - return internal_optimizer_types[i].name; - } - } - throw InternalException("Invalid optimizer type"); -} - -vector ListAllOptimizers() { - vector result; - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - result.push_back(internal_optimizer_types[i].name); - } - return result; -} - -OptimizerType OptimizerTypeFromString(const string &str) { - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - if (internal_optimizer_types[i].name == str) { - return internal_optimizer_types[i].type; - } - } - // optimizer not found, construct candidate list - vector optimizer_names; - for (idx_t i = 0; internal_optimizer_types[i].name; i++) { - optimizer_names.emplace_back(internal_optimizer_types[i].name); - } - throw ParserException("Optimizer type \"%s\" not recognized\n%s", str, - StringUtil::CandidatesErrorMessage(optimizer_names, str, "Candidate optimizers")); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/physical_operator_type.cpp b/src/duckdb/src/common/enums/physical_operator_type.cpp deleted file mode 100644 index f520cb448..000000000 --- a/src/duckdb/src/common/enums/physical_operator_type.cpp +++ /dev/null @@ -1,171 +0,0 @@ -#include "duckdb/common/enums/physical_operator_type.hpp" - -namespace duckdb { - -// LCOV_EXCL_START -string PhysicalOperatorToString(PhysicalOperatorType type) { - switch (type) { - case PhysicalOperatorType::TABLE_SCAN: - return 
"TABLE_SCAN"; - case PhysicalOperatorType::DUMMY_SCAN: - return "DUMMY_SCAN"; - case PhysicalOperatorType::CHUNK_SCAN: - return "CHUNK_SCAN"; - case PhysicalOperatorType::COLUMN_DATA_SCAN: - return "COLUMN_DATA_SCAN"; - case PhysicalOperatorType::DELIM_SCAN: - return "DELIM_SCAN"; - case PhysicalOperatorType::ORDER_BY: - return "ORDER_BY"; - case PhysicalOperatorType::LIMIT: - return "LIMIT"; - case PhysicalOperatorType::LIMIT_PERCENT: - return "LIMIT_PERCENT"; - case PhysicalOperatorType::STREAMING_LIMIT: - return "STREAMING_LIMIT"; - case PhysicalOperatorType::RESERVOIR_SAMPLE: - return "RESERVOIR_SAMPLE"; - case PhysicalOperatorType::STREAMING_SAMPLE: - return "STREAMING_SAMPLE"; - case PhysicalOperatorType::TOP_N: - return "TOP_N"; - case PhysicalOperatorType::WINDOW: - return "WINDOW"; - case PhysicalOperatorType::STREAMING_WINDOW: - return "STREAMING_WINDOW"; - case PhysicalOperatorType::UNNEST: - return "UNNEST"; - case PhysicalOperatorType::UNGROUPED_AGGREGATE: - return "UNGROUPED_AGGREGATE"; - case PhysicalOperatorType::HASH_GROUP_BY: - return "HASH_GROUP_BY"; - case PhysicalOperatorType::PERFECT_HASH_GROUP_BY: - return "PERFECT_HASH_GROUP_BY"; - case PhysicalOperatorType::PARTITIONED_AGGREGATE: - return "PARTITIONED_AGGREGATE"; - case PhysicalOperatorType::FILTER: - return "FILTER"; - case PhysicalOperatorType::PROJECTION: - return "PROJECTION"; - case PhysicalOperatorType::COPY_TO_FILE: - return "COPY_TO_FILE"; - case PhysicalOperatorType::BATCH_COPY_TO_FILE: - return "BATCH_COPY_TO_FILE"; - case PhysicalOperatorType::LEFT_DELIM_JOIN: - return "LEFT_DELIM_JOIN"; - case PhysicalOperatorType::RIGHT_DELIM_JOIN: - return "RIGHT_DELIM_JOIN"; - case PhysicalOperatorType::BLOCKWISE_NL_JOIN: - return "BLOCKWISE_NL_JOIN"; - case PhysicalOperatorType::NESTED_LOOP_JOIN: - return "NESTED_LOOP_JOIN"; - case PhysicalOperatorType::HASH_JOIN: - return "HASH_JOIN"; - case PhysicalOperatorType::PIECEWISE_MERGE_JOIN: - return "PIECEWISE_MERGE_JOIN"; - case PhysicalOperatorType::IE_JOIN: - return "IE_JOIN"; - case PhysicalOperatorType::ASOF_JOIN: - return "ASOF_JOIN"; - case PhysicalOperatorType::CROSS_PRODUCT: - return "CROSS_PRODUCT"; - case PhysicalOperatorType::POSITIONAL_JOIN: - return "POSITIONAL_JOIN"; - case PhysicalOperatorType::POSITIONAL_SCAN: - return "POSITIONAL_SCAN"; - case PhysicalOperatorType::UNION: - return "UNION"; - case PhysicalOperatorType::INSERT: - return "INSERT"; - case PhysicalOperatorType::BATCH_INSERT: - return "BATCH_INSERT"; - case PhysicalOperatorType::DELETE_OPERATOR: - return "DELETE"; - case PhysicalOperatorType::UPDATE: - return "UPDATE"; - case PhysicalOperatorType::EMPTY_RESULT: - return "EMPTY_RESULT"; - case PhysicalOperatorType::CREATE_TABLE: - return "CREATE_TABLE"; - case PhysicalOperatorType::CREATE_TABLE_AS: - return "CREATE_TABLE_AS"; - case PhysicalOperatorType::BATCH_CREATE_TABLE_AS: - return "BATCH_CREATE_TABLE_AS"; - case PhysicalOperatorType::CREATE_INDEX: - return "CREATE_INDEX"; - case PhysicalOperatorType::EXPLAIN: - return "EXPLAIN"; - case PhysicalOperatorType::EXPLAIN_ANALYZE: - return "EXPLAIN_ANALYZE"; - case PhysicalOperatorType::EXECUTE: - return "EXECUTE"; - case PhysicalOperatorType::VACUUM: - return "VACUUM"; - case PhysicalOperatorType::RECURSIVE_CTE: - return "REC_CTE"; - case PhysicalOperatorType::CTE: - return "CTE"; - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: - return "REC_CTE_SCAN"; - case PhysicalOperatorType::CTE_SCAN: - return "CTE_SCAN"; - case PhysicalOperatorType::EXPRESSION_SCAN: - return "EXPRESSION_SCAN"; - case 
PhysicalOperatorType::ALTER: - return "ALTER"; - case PhysicalOperatorType::CREATE_SEQUENCE: - return "CREATE_SEQUENCE"; - case PhysicalOperatorType::CREATE_VIEW: - return "CREATE_VIEW"; - case PhysicalOperatorType::CREATE_SCHEMA: - return "CREATE_SCHEMA"; - case PhysicalOperatorType::CREATE_MACRO: - return "CREATE_MACRO"; - case PhysicalOperatorType::CREATE_SECRET: - return "CREATE_SECRET"; - case PhysicalOperatorType::DROP: - return "DROP"; - case PhysicalOperatorType::PRAGMA: - return "PRAGMA"; - case PhysicalOperatorType::TRANSACTION: - return "TRANSACTION"; - case PhysicalOperatorType::PREPARE: - return "PREPARE"; - case PhysicalOperatorType::EXPORT: - return "EXPORT"; - case PhysicalOperatorType::SET: - return "SET"; - case PhysicalOperatorType::SET_VARIABLE: - return "SET_VARIABLE"; - case PhysicalOperatorType::RESET: - return "RESET"; - case PhysicalOperatorType::LOAD: - return "LOAD"; - case PhysicalOperatorType::INOUT_FUNCTION: - return "INOUT_FUNCTION"; - case PhysicalOperatorType::CREATE_TYPE: - return "CREATE_TYPE"; - case PhysicalOperatorType::ATTACH: - return "ATTACH"; - case PhysicalOperatorType::DETACH: - return "DETACH"; - case PhysicalOperatorType::RESULT_COLLECTOR: - return "RESULT_COLLECTOR"; - case PhysicalOperatorType::EXTENSION: - return "EXTENSION"; - case PhysicalOperatorType::PIVOT: - return "PIVOT"; - case PhysicalOperatorType::COPY_DATABASE: - return "COPY_DATABASE"; - case PhysicalOperatorType::VERIFY_VECTOR: - return "VERIFY_VECTOR"; - case PhysicalOperatorType::UPDATE_EXTENSIONS: - return "UPDATE_EXTENSIONS"; - case PhysicalOperatorType::INVALID: - break; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/relation_type.cpp b/src/duckdb/src/common/enums/relation_type.cpp deleted file mode 100644 index 4f58ed7c4..000000000 --- a/src/duckdb/src/common/enums/relation_type.cpp +++ /dev/null @@ -1,71 +0,0 @@ -#include "duckdb/common/enums/relation_type.hpp" - -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -// LCOV_EXCL_START -string RelationTypeToString(RelationType type) { - switch (type) { - case RelationType::TABLE_RELATION: - return "TABLE_RELATION"; - case RelationType::DELIM_GET_RELATION: - return "DELIM_GET_RELATION"; - case RelationType::DELIM_JOIN_RELATION: - return "DELIM_JOIN_RELATION"; - case RelationType::PROJECTION_RELATION: - return "PROJECTION_RELATION"; - case RelationType::FILTER_RELATION: - return "FILTER_RELATION"; - case RelationType::EXPLAIN_RELATION: - return "EXPLAIN_RELATION"; - case RelationType::CROSS_PRODUCT_RELATION: - return "CROSS_PRODUCT_RELATION"; - case RelationType::JOIN_RELATION: - return "JOIN_RELATION"; - case RelationType::AGGREGATE_RELATION: - return "AGGREGATE_RELATION"; - case RelationType::SET_OPERATION_RELATION: - return "SET_OPERATION_RELATION"; - case RelationType::DISTINCT_RELATION: - return "DISTINCT_RELATION"; - case RelationType::LIMIT_RELATION: - return "LIMIT_RELATION"; - case RelationType::ORDER_RELATION: - return "ORDER_RELATION"; - case RelationType::CREATE_VIEW_RELATION: - return "CREATE_VIEW_RELATION"; - case RelationType::CREATE_TABLE_RELATION: - return "CREATE_TABLE_RELATION"; - case RelationType::INSERT_RELATION: - return "INSERT_RELATION"; - case RelationType::VALUE_LIST_RELATION: - return "VALUE_LIST_RELATION"; - case RelationType::MATERIALIZED_RELATION: - return "MATERIALIZED_RELATION"; - case RelationType::DELETE_RELATION: - return "DELETE_RELATION"; - case RelationType::UPDATE_RELATION: - return "UPDATE_RELATION"; - case 
RelationType::WRITE_CSV_RELATION: - return "WRITE_CSV_RELATION"; - case RelationType::WRITE_PARQUET_RELATION: - return "WRITE_PARQUET_RELATION"; - case RelationType::READ_CSV_RELATION: - return "READ_CSV_RELATION"; - case RelationType::SUBQUERY_RELATION: - return "SUBQUERY_RELATION"; - case RelationType::TABLE_FUNCTION_RELATION: - return "TABLE_FUNCTION_RELATION"; - case RelationType::VIEW_RELATION: - return "VIEW_RELATION"; - case RelationType::QUERY_RELATION: - return "QUERY_RELATION"; - case RelationType::INVALID_RELATION: - break; - } - return "INVALID_RELATION"; -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/enums/statement_type.cpp b/src/duckdb/src/common/enums/statement_type.cpp deleted file mode 100644 index 98524b938..000000000 --- a/src/duckdb/src/common/enums/statement_type.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "duckdb/common/enums/statement_type.hpp" - -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -// LCOV_EXCL_START -string StatementTypeToString(StatementType type) { - switch (type) { - case StatementType::SELECT_STATEMENT: - return "SELECT"; - case StatementType::INSERT_STATEMENT: - return "INSERT"; - case StatementType::UPDATE_STATEMENT: - return "UPDATE"; - case StatementType::DELETE_STATEMENT: - return "DELETE"; - case StatementType::PREPARE_STATEMENT: - return "PREPARE"; - case StatementType::EXECUTE_STATEMENT: - return "EXECUTE"; - case StatementType::ALTER_STATEMENT: - return "ALTER"; - case StatementType::TRANSACTION_STATEMENT: - return "TRANSACTION"; - case StatementType::COPY_STATEMENT: - return "COPY"; - case StatementType::COPY_DATABASE_STATEMENT: - return "COPY_DATABASE"; - case StatementType::ANALYZE_STATEMENT: - return "ANALYZE"; - case StatementType::VARIABLE_SET_STATEMENT: - return "VARIABLE_SET"; - case StatementType::CREATE_FUNC_STATEMENT: - return "CREATE_FUNC"; - case StatementType::EXPLAIN_STATEMENT: - return "EXPLAIN"; - case StatementType::CREATE_STATEMENT: - return "CREATE"; - case StatementType::DROP_STATEMENT: - return "DROP"; - case StatementType::PRAGMA_STATEMENT: - return "PRAGMA"; - case StatementType::VACUUM_STATEMENT: - return "VACUUM"; - case StatementType::RELATION_STATEMENT: - return "RELATION"; - case StatementType::EXPORT_STATEMENT: - return "EXPORT"; - case StatementType::CALL_STATEMENT: - return "CALL"; - case StatementType::SET_STATEMENT: - return "SET"; - case StatementType::LOAD_STATEMENT: - return "LOAD"; - case StatementType::EXTENSION_STATEMENT: - return "EXTENSION"; - case StatementType::LOGICAL_PLAN_STATEMENT: - return "LOGICAL_PLAN"; - case StatementType::ATTACH_STATEMENT: - return "ATTACH"; - case StatementType::DETACH_STATEMENT: - return "DETACH"; - case StatementType::MULTI_STATEMENT: - return "MULTI"; - case StatementType::UPDATE_EXTENSIONS_STATEMENT: - return "UPDATE_EXTENSIONS"; - case StatementType::INVALID_STATEMENT: - break; - } - return "INVALID"; -} - -string StatementReturnTypeToString(StatementReturnType type) { - switch (type) { - case StatementReturnType::QUERY_RESULT: - return "QUERY_RESULT"; - case StatementReturnType::CHANGED_ROWS: - return "CHANGED_ROWS"; - case StatementReturnType::NOTHING: - return "NOTHING"; - } - return "INVALID"; -} -// LCOV_EXCL_STOP - -void StatementProperties::RegisterDBRead(Catalog &catalog, ClientContext &context) { - auto catalog_identity = CatalogIdentity {catalog.GetOid(), catalog.GetCatalogVersion(context)}; - D_ASSERT(read_databases.count(catalog.GetName()) == 0 || read_databases[catalog.GetName()] == catalog_identity); 
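// Record (or refresh) this catalog's identity under its name; the D_ASSERT above
// checks, in debug builds, that a repeated registration of the same catalog carries
// an unchanged (oid, catalog version) identity.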
- read_databases[catalog.GetName()] = catalog_identity; -} - -void StatementProperties::RegisterDBModify(Catalog &catalog, ClientContext &context) { - auto catalog_identity = CatalogIdentity {catalog.GetOid(), catalog.GetCatalogVersion(context)}; - D_ASSERT(modified_databases.count(catalog.GetName()) == 0 || - modified_databases[catalog.GetName()] == catalog_identity); - modified_databases[catalog.GetName()] = catalog_identity; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/error_data.cpp b/src/duckdb/src/common/error_data.cpp deleted file mode 100644 index 9cca3264c..000000000 --- a/src/duckdb/src/common/error_data.cpp +++ /dev/null @@ -1,147 +0,0 @@ -#include "duckdb/common/error_data.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/stacktrace.hpp" -#include "duckdb/parser/parsed_expression.hpp" -#include "duckdb/parser/query_error_context.hpp" -#include "duckdb/parser/tableref.hpp" - -namespace duckdb { - -ErrorData::ErrorData() : initialized(false), type(ExceptionType::INVALID) { -} - -ErrorData::ErrorData(const std::exception &ex) : ErrorData(ex.what()) { -} - -ErrorData::ErrorData(ExceptionType type, const string &message) - : initialized(true), type(type), raw_message(SanitizeErrorMessage(message)), - final_message(ConstructFinalMessage()) { -} - -ErrorData::ErrorData(const string &message) - : initialized(true), type(ExceptionType::INVALID), raw_message(string()), final_message(string()) { - - // parse the constructed JSON - if (message.empty() || message[0] != '{') { - // not JSON! Use the message as a raw Exception message and leave type as uninitialized - - if (message == std::bad_alloc().what()) { - type = ExceptionType::OUT_OF_MEMORY; - raw_message = "Allocation failure"; - } else { - raw_message = message; - } - } else { - auto info = StringUtil::ParseJSONMap(message); - for (auto &entry : info) { - if (entry.first == "exception_type") { - type = Exception::StringToExceptionType(entry.second); - } else if (entry.first == "exception_message") { - raw_message = SanitizeErrorMessage(entry.second); - } else { - extra_info[entry.first] = entry.second; - } - } - } - - final_message = ConstructFinalMessage(); -} - -string ErrorData::SanitizeErrorMessage(string error) { - return StringUtil::Replace(std::move(error), string("\0", 1), "\\0"); -} - -string ErrorData::ConstructFinalMessage() const { - std::string error; - if (type != ExceptionType::UNKNOWN_TYPE) { - error = Exception::ExceptionTypeToString(type) + " "; - } - error += "Error: " + raw_message; - if (type == ExceptionType::INTERNAL) { - error += "\nThis error signals an assertion failure within DuckDB. 
This usually occurs due to " - "unexpected conditions or errors in the program's logic.\nFor more information, see " - "https://duckdb.org/docs/dev/internal_errors"; - } - return error; -} - -void ErrorData::Throw(const string &prepended_message) const { - D_ASSERT(initialized); - if (!prepended_message.empty()) { - string new_message = prepended_message + raw_message; - throw Exception(type, new_message, extra_info); - } else { - throw Exception(type, raw_message, extra_info); - } -} - -const ExceptionType &ErrorData::Type() const { - D_ASSERT(initialized); - return this->type; -} - -bool ErrorData::operator==(const ErrorData &other) const { - if (initialized != other.initialized) { - return false; - } - if (type != other.type) { - return false; - } - return raw_message == other.raw_message; -} - -void ErrorData::ConvertErrorToJSON() { - if (raw_message.empty() || raw_message[0] == '{') { - // empty or already JSON - return; - } - raw_message = StringUtil::ExceptionToJSONMap(type, raw_message, extra_info); - final_message = raw_message; -} - -void ErrorData::FinalizeError() { - auto entry = extra_info.find("stack_trace_pointers"); - if (entry != extra_info.end()) { - auto stack_trace = StackTrace::ResolveStacktraceSymbols(entry->second); - extra_info["stack_trace"] = std::move(stack_trace); - extra_info.erase("stack_trace_pointers"); - } -} - -void ErrorData::AddErrorLocation(const string &query) { - if (!query.empty()) { - auto entry = extra_info.find("position"); - if (entry != extra_info.end()) { - raw_message = QueryErrorContext::Format(query, raw_message, std::stoull(entry->second)); - } - } - { - auto entry = extra_info.find("stack_trace"); - if (entry != extra_info.end()) { - raw_message += "\n\nStack Trace:\n" + entry->second; - } - } - final_message = ConstructFinalMessage(); -} - -void ErrorData::AddQueryLocation(optional_idx query_location) { - Exception::SetQueryLocation(query_location, extra_info); -} - -void ErrorData::AddQueryLocation(QueryErrorContext error_context) { - AddQueryLocation(error_context.query_location); -} - -void ErrorData::AddQueryLocation(const ParsedExpression &ref) { - AddQueryLocation(ref.GetQueryLocation()); -} - -void ErrorData::AddQueryLocation(const TableRef &ref) { - AddQueryLocation(ref.query_location); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception.cpp b/src/duckdb/src/common/exception.cpp deleted file mode 100644 index 991f929f2..000000000 --- a/src/duckdb/src/common/exception.cpp +++ /dev/null @@ -1,354 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/exception/list.hpp" -#include "duckdb/parser/tableref.hpp" -#include "duckdb/planner/expression.hpp" - -#ifdef DUCKDB_CRASH_ON_ASSERT -#include "duckdb/common/printer.hpp" -#include -#include -#endif -#include "duckdb/common/stacktrace.hpp" - -namespace duckdb { - -Exception::Exception(ExceptionType exception_type, const string &message) - : std::runtime_error(ToJSON(exception_type, message)) { -} - -Exception::Exception(ExceptionType exception_type, const string &message, - const unordered_map &extra_info) - : std::runtime_error(ToJSON(exception_type, message, extra_info)) { -} - -string Exception::ToJSON(ExceptionType type, const string &message) { - unordered_map extra_info; - return ToJSON(type, message, extra_info); -} - -string Exception::ToJSON(ExceptionType type, const string &message, const unordered_map &extra_info) { 
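// By default, stack trace pointers are captured into the extra-info map only for
// INTERNAL exceptions; compiling with DUCKDB_DEBUG_STACKTRACE extends the capture to
// every exception type. The raw pointers are resolved to symbol names later, in
// ErrorData::FinalizeError.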
-#ifndef DUCKDB_DEBUG_STACKTRACE - // by default we only enable stack traces for internal exceptions - if (type == ExceptionType::INTERNAL) -#endif - { - auto extended_extra_info = extra_info; - extended_extra_info["stack_trace_pointers"] = StackTrace::GetStacktracePointers(); - return StringUtil::ExceptionToJSONMap(type, message, extended_extra_info); - } - return StringUtil::ExceptionToJSONMap(type, message, extra_info); -} - -bool Exception::UncaughtException() { -#if __cplusplus >= 201703L - return std::uncaught_exceptions() > 0; -#else - return std::uncaught_exception(); -#endif -} - -bool Exception::InvalidatesTransaction(ExceptionType exception_type) { - switch (exception_type) { - case ExceptionType::BINDER: - case ExceptionType::CATALOG: - case ExceptionType::CONNECTION: - case ExceptionType::PARAMETER_NOT_ALLOWED: - case ExceptionType::PARSER: - case ExceptionType::PERMISSION: - return false; - default: - return true; - } -} - -bool Exception::InvalidatesDatabase(ExceptionType exception_type) { - switch (exception_type) { - case ExceptionType::FATAL: - return true; - default: - return false; - } -} - -string Exception::GetStackTrace(idx_t max_depth) { - return StackTrace::GetStackTrace(max_depth); -} - -string Exception::ConstructMessageRecursive(const string &msg, std::vector &values) { -#ifdef DEBUG - // Verify that we have the required amount of values for the message - idx_t parameter_count = 0; - for (idx_t i = 0; i + 1 < msg.size(); i++) { - if (msg[i] != '%') { - continue; - } - if (msg[i + 1] == '%') { - i++; - continue; - } - parameter_count++; - } - if (parameter_count != values.size()) { - throw InternalException("Primary exception: %s\nSecondary exception in ConstructMessageRecursive: Expected %d " - "parameters, received %d", - msg.c_str(), parameter_count, values.size()); - } - -#endif - return ExceptionFormatValue::Format(msg, values); -} - -struct ExceptionEntry { - ExceptionType type; - char text[48]; -}; - -static constexpr ExceptionEntry EXCEPTION_MAP[] = {{ExceptionType::INVALID, "Invalid"}, - {ExceptionType::OUT_OF_RANGE, "Out of Range"}, - {ExceptionType::CONVERSION, "Conversion"}, - {ExceptionType::UNKNOWN_TYPE, "Unknown Type"}, - {ExceptionType::DECIMAL, "Decimal"}, - {ExceptionType::MISMATCH_TYPE, "Mismatch Type"}, - {ExceptionType::DIVIDE_BY_ZERO, "Divide by Zero"}, - {ExceptionType::OBJECT_SIZE, "Object Size"}, - {ExceptionType::INVALID_TYPE, "Invalid type"}, - {ExceptionType::SERIALIZATION, "Serialization"}, - {ExceptionType::TRANSACTION, "TransactionContext"}, - {ExceptionType::NOT_IMPLEMENTED, "Not implemented"}, - {ExceptionType::EXPRESSION, "Expression"}, - {ExceptionType::CATALOG, "Catalog"}, - {ExceptionType::PARSER, "Parser"}, - {ExceptionType::BINDER, "Binder"}, - {ExceptionType::PLANNER, "Planner"}, - {ExceptionType::SCHEDULER, "Scheduler"}, - {ExceptionType::EXECUTOR, "Executor"}, - {ExceptionType::CONSTRAINT, "Constraint"}, - {ExceptionType::INDEX, "Index"}, - {ExceptionType::STAT, "Stat"}, - {ExceptionType::CONNECTION, "Connection"}, - {ExceptionType::SYNTAX, "Syntax"}, - {ExceptionType::SETTINGS, "Settings"}, - {ExceptionType::OPTIMIZER, "Optimizer"}, - {ExceptionType::NULL_POINTER, "NullPointer"}, - {ExceptionType::IO, "IO"}, - {ExceptionType::INTERRUPT, "INTERRUPT"}, - {ExceptionType::FATAL, "FATAL"}, - {ExceptionType::INTERNAL, "INTERNAL"}, - {ExceptionType::INVALID_INPUT, "Invalid Input"}, - {ExceptionType::OUT_OF_MEMORY, "Out of Memory"}, - {ExceptionType::PERMISSION, "Permission"}, - {ExceptionType::PARAMETER_NOT_RESOLVED, "Parameter 
Not Resolved"}, - {ExceptionType::PARAMETER_NOT_ALLOWED, "Parameter Not Allowed"}, - {ExceptionType::DEPENDENCY, "Dependency"}, - {ExceptionType::MISSING_EXTENSION, "Missing Extension"}, - {ExceptionType::HTTP, "HTTP"}, - {ExceptionType::AUTOLOAD, "Extension Autoloading"}, - {ExceptionType::SEQUENCE, "Sequence"}, - {ExceptionType::INVALID_CONFIGURATION, "Invalid Configuration"}}; - -string Exception::ExceptionTypeToString(ExceptionType type) { - for (auto &e : EXCEPTION_MAP) { - if (e.type == type) { - return e.text; - } - } - return "Unknown"; -} - -ExceptionType Exception::StringToExceptionType(const string &type) { - for (auto &e : EXCEPTION_MAP) { - if (e.text == type) { - return e.type; - } - } - return ExceptionType::INVALID; -} - -unordered_map Exception::InitializeExtraInfo(const Expression &expr) { - return InitializeExtraInfo(expr.GetQueryLocation()); -} - -unordered_map Exception::InitializeExtraInfo(const ParsedExpression &expr) { - return InitializeExtraInfo(expr.GetQueryLocation()); -} - -unordered_map Exception::InitializeExtraInfo(const QueryErrorContext &error_context) { - return InitializeExtraInfo(error_context.query_location); -} - -unordered_map Exception::InitializeExtraInfo(const TableRef &ref) { - return InitializeExtraInfo(ref.query_location); -} - -unordered_map Exception::InitializeExtraInfo(optional_idx error_location) { - unordered_map result; - SetQueryLocation(error_location, result); - return result; -} - -unordered_map Exception::InitializeExtraInfo(const string &subtype, optional_idx error_location) { - unordered_map result; - result["error_subtype"] = subtype; - SetQueryLocation(error_location, result); - return result; -} - -void Exception::SetQueryLocation(optional_idx error_location, unordered_map &extra_info) { - if (error_location.IsValid()) { - extra_info["position"] = to_string(error_location.GetIndex()); - } -} - -InvalidTypeException::InvalidTypeException(PhysicalType type, const string &msg) - : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + TypeIdToString(type) + "]: " + msg) { -} - -InvalidTypeException::InvalidTypeException(const LogicalType &type, const string &msg) - : Exception(ExceptionType::INVALID_TYPE, "Invalid Type [" + type.ToString() + "]: " + msg) { -} - -InvalidTypeException::InvalidTypeException(const string &msg) : Exception(ExceptionType::INVALID_TYPE, msg) { -} - -TypeMismatchException::TypeMismatchException(const PhysicalType type_1, const PhysicalType type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + TypeIdToString(type_1) + " does not match with " + TypeIdToString(type_2) + ". " + msg) { -} - -TypeMismatchException::TypeMismatchException(const LogicalType &type_1, const LogicalType &type_2, const string &msg) - : TypeMismatchException(optional_idx(), type_1, type_2, msg) { -} - -TypeMismatchException::TypeMismatchException(optional_idx error_location, const LogicalType &type_1, - const LogicalType &type_2, const string &msg) - : Exception(ExceptionType::MISMATCH_TYPE, - "Type " + type_1.ToString() + " does not match with " + type_2.ToString() + ". 
" + msg, - Exception::InitializeExtraInfo(error_location)) { -} - -TypeMismatchException::TypeMismatchException(const string &msg) : Exception(ExceptionType::MISMATCH_TYPE, msg) { -} - -TransactionException::TransactionException(const string &msg) : Exception(ExceptionType::TRANSACTION, msg) { -} - -NotImplementedException::NotImplementedException(const string &msg) : Exception(ExceptionType::NOT_IMPLEMENTED, msg) { -} - -OutOfRangeException::OutOfRangeException(const string &msg) : Exception(ExceptionType::OUT_OF_RANGE, msg) { -} - -OutOfRangeException::OutOfRangeException(const int64_t value, const PhysicalType orig_type, const PhysicalType new_type) - : Exception(ExceptionType::OUT_OF_RANGE, "Type " + TypeIdToString(orig_type) + " with value " + - to_string((intmax_t)value) + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -OutOfRangeException::OutOfRangeException(const double value, const PhysicalType orig_type, const PhysicalType new_type) - : Exception(ExceptionType::OUT_OF_RANGE, "Type " + TypeIdToString(orig_type) + " with value " + to_string(value) + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -OutOfRangeException::OutOfRangeException(const hugeint_t value, const PhysicalType orig_type, - const PhysicalType new_type) - : Exception(ExceptionType::OUT_OF_RANGE, "Type " + TypeIdToString(orig_type) + " with value " + value.ToString() + - " can't be cast because the value is out of range " - "for the destination type " + - TypeIdToString(new_type)) { -} - -OutOfRangeException::OutOfRangeException(const PhysicalType var_type, const idx_t length) - : Exception(ExceptionType::OUT_OF_RANGE, - "The value is too long to fit into type " + TypeIdToString(var_type) + "(" + to_string(length) + ")") { -} - -ConnectionException::ConnectionException(const string &msg) : Exception(ExceptionType::CONNECTION, msg) { -} - -PermissionException::PermissionException(const string &msg) : Exception(ExceptionType::PERMISSION, msg) { -} - -SyntaxException::SyntaxException(const string &msg) : Exception(ExceptionType::SYNTAX, msg) { -} - -ExecutorException::ExecutorException(const string &msg) : Exception(ExceptionType::EXECUTOR, msg) { -} - -ConstraintException::ConstraintException(const string &msg) : Exception(ExceptionType::CONSTRAINT, msg) { -} - -DependencyException::DependencyException(const string &msg) : Exception(ExceptionType::DEPENDENCY, msg) { -} - -IOException::IOException(const string &msg) : Exception(ExceptionType::IO, msg) { -} - -IOException::IOException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::IO, msg, extra_info) { -} - -MissingExtensionException::MissingExtensionException(const string &msg) - : Exception(ExceptionType::MISSING_EXTENSION, msg) { -} - -AutoloadException::AutoloadException(const string &extension_name, const string &message) - : Exception(ExceptionType::AUTOLOAD, - "An error occurred while trying to automatically install the required extension '" + extension_name + - "':\n" + message) { -} - -SerializationException::SerializationException(const string &msg) : Exception(ExceptionType::SERIALIZATION, msg) { -} - -SequenceException::SequenceException(const string &msg) : Exception(ExceptionType::SEQUENCE, msg) { -} - -InterruptException::InterruptException() : Exception(ExceptionType::INTERRUPT, "Interrupted!") { -} - -FatalException::FatalException(ExceptionType type, const string &msg) : 
Exception(type, msg) { -} - -InternalException::InternalException(const string &msg) : Exception(ExceptionType::INTERNAL, msg) { -#ifdef DUCKDB_CRASH_ON_ASSERT - Printer::Print("ABORT THROWN BY INTERNAL EXCEPTION: " + msg); - Printer::Print(StackTrace::GetStackTrace()); - abort(); -#endif -} - -InvalidInputException::InvalidInputException(const string &msg) : Exception(ExceptionType::INVALID_INPUT, msg) { -} - -InvalidInputException::InvalidInputException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_INPUT, msg, extra_info) { -} - -InvalidConfigurationException::InvalidConfigurationException(const string &msg) - : Exception(ExceptionType::INVALID_CONFIGURATION, msg) { -} - -InvalidConfigurationException::InvalidConfigurationException(const string &msg, - const unordered_map &extra_info) - : Exception(ExceptionType::INVALID_CONFIGURATION, msg, extra_info) { -} - -OutOfMemoryException::OutOfMemoryException(const string &msg) : Exception(ExceptionType::OUT_OF_MEMORY, msg) { -} - -ParameterNotAllowedException::ParameterNotAllowedException(const string &msg) - : Exception(ExceptionType::PARAMETER_NOT_ALLOWED, msg) { -} - -ParameterNotResolvedException::ParameterNotResolvedException() - : Exception(ExceptionType::PARAMETER_NOT_RESOLVED, "Parameter types could not be resolved") { -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception/binder_exception.cpp b/src/duckdb/src/common/exception/binder_exception.cpp deleted file mode 100644 index cc4e08722..000000000 --- a/src/duckdb/src/common/exception/binder_exception.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "duckdb/common/exception/binder_exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/function.hpp" - -namespace duckdb { - -BinderException::BinderException(const string &msg) : Exception(ExceptionType::BINDER, msg) { -} - -BinderException::BinderException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::BINDER, msg, extra_info) { -} - -BinderException BinderException::ColumnNotFound(const string &name, const vector &similar_bindings, - QueryErrorContext context) { - auto extra_info = Exception::InitializeExtraInfo("COLUMN_NOT_FOUND", context.query_location); - string candidate_str = StringUtil::CandidatesMessage(similar_bindings, "Candidate bindings"); - extra_info["name"] = name; - if (!similar_bindings.empty()) { - extra_info["candidates"] = StringUtil::Join(similar_bindings, ","); - } - return BinderException( - StringUtil::Format("Referenced column \"%s\" not found in FROM clause!%s", name, candidate_str), extra_info); -} - -BinderException BinderException::NoMatchingFunction(const string &name, const vector &arguments, - const vector &candidates) { - auto extra_info = Exception::InitializeExtraInfo("NO_MATCHING_FUNCTION", optional_idx()); - // no matching function was found, throw an error - string call_str = Function::CallToString(name, arguments); - string candidate_str; - for (auto &candidate : candidates) { - candidate_str += "\t" + candidate + "\n"; - } - extra_info["name"] = name; - extra_info["call"] = call_str; - if (!candidates.empty()) { - extra_info["candidates"] = StringUtil::Join(candidates, ","); - } - return BinderException( - StringUtil::Format("No function matches the given name and argument types '%s'. 
You might need to add " - "explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str), - extra_info); -} - -BinderException BinderException::Unsupported(ParsedExpression &expr, const string &message) { - auto extra_info = Exception::InitializeExtraInfo("UNSUPPORTED", expr.GetQueryLocation()); - return BinderException(message, extra_info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception/catalog_exception.cpp b/src/duckdb/src/common/exception/catalog_exception.cpp deleted file mode 100644 index 2ab210f0e..000000000 --- a/src/duckdb/src/common/exception/catalog_exception.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "duckdb/common/exception/catalog_exception.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -CatalogException::CatalogException(const string &msg) : Exception(ExceptionType::CATALOG, msg) { -} - -CatalogException::CatalogException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::CATALOG, msg, extra_info) { -} - -CatalogException CatalogException::MissingEntry(CatalogType type, const string &name, const string &suggestion, - QueryErrorContext context) { - string did_you_mean; - if (!suggestion.empty()) { - did_you_mean = "\nDid you mean \"" + suggestion + "\"?"; - } - - auto extra_info = Exception::InitializeExtraInfo("MISSING_ENTRY", context.query_location); - - extra_info["name"] = name; - extra_info["type"] = CatalogTypeToString(type); - if (!suggestion.empty()) { - extra_info["candidates"] = suggestion; - } - return CatalogException( - StringUtil::Format("%s with name %s does not exist!%s", CatalogTypeToString(type), name, did_you_mean), - extra_info); -} - -CatalogException CatalogException::MissingEntry(const string &type, const string &name, - const vector &suggestions, QueryErrorContext context) { - auto extra_info = Exception::InitializeExtraInfo("MISSING_ENTRY", context.query_location); - extra_info["error_subtype"] = "MISSING_ENTRY"; - extra_info["name"] = name; - extra_info["type"] = type; - if (!suggestions.empty()) { - extra_info["candidates"] = StringUtil::Join(suggestions, ", "); - } - return CatalogException(StringUtil::Format("unrecognized %s \"%s\"\n%s", type, name, - StringUtil::CandidatesErrorMessage(suggestions, name, "Did you mean")), - extra_info); -} - -CatalogException CatalogException::EntryAlreadyExists(CatalogType type, const string &name, QueryErrorContext context) { - auto extra_info = Exception::InitializeExtraInfo("ENTRY_ALREADY_EXISTS", optional_idx()); - extra_info["name"] = name; - extra_info["type"] = CatalogTypeToString(type); - return CatalogException(StringUtil::Format("%s with name \"%s\" already exists!", CatalogTypeToString(type), name), - extra_info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception/conversion_exception.cpp b/src/duckdb/src/common/exception/conversion_exception.cpp deleted file mode 100644 index 013dbdb9e..000000000 --- a/src/duckdb/src/common/exception/conversion_exception.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/types.hpp" - -namespace duckdb { - -ConversionException::ConversionException(const PhysicalType orig_type, const PhysicalType new_type) - : Exception(ExceptionType::CONVERSION, - "Type " + TypeIdToString(orig_type) + " can't be cast as " + TypeIdToString(new_type)) { -} - -ConversionException::ConversionException(const LogicalType &orig_type, const LogicalType &new_type) - : 
Exception(ExceptionType::CONVERSION, - "Type " + orig_type.ToString() + " can't be cast as " + new_type.ToString()) { -} - -ConversionException::ConversionException(const string &msg) : Exception(ExceptionType::CONVERSION, msg) { -} - -ConversionException::ConversionException(optional_idx error_location, const string &msg) - : Exception(ExceptionType::CONVERSION, msg, Exception::InitializeExtraInfo(error_location)) { -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception/parser_exception.cpp b/src/duckdb/src/common/exception/parser_exception.cpp deleted file mode 100644 index f3875da38..000000000 --- a/src/duckdb/src/common/exception/parser_exception.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "duckdb/common/exception/parser_exception.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/parser/query_error_context.hpp" - -namespace duckdb { - -ParserException::ParserException(const string &msg) : Exception(ExceptionType::PARSER, msg) { -} - -ParserException::ParserException(const string &msg, const unordered_map &extra_info) - : Exception(ExceptionType::PARSER, msg, extra_info) { -} - -ParserException ParserException::SyntaxError(const string &query, const string &error_message, - optional_idx error_location) { - return ParserException(error_message, Exception::InitializeExtraInfo("SYNTAX_ERROR", error_location)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/exception_format_value.cpp b/src/duckdb/src/common/exception_format_value.cpp deleted file mode 100644 index ddef4e10c..000000000 --- a/src/duckdb/src/common/exception_format_value.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/helper.hpp" // defines DUCKDB_EXPLICIT_FALLTHROUGH which fmt will use to annotate -#include "fmt/format.h" -#include "fmt/printf.h" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/parser/keyword_helper.hpp" - -namespace duckdb { - -ExceptionFormatValue::ExceptionFormatValue(double dbl_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE), dbl_val(dbl_val) { -} -ExceptionFormatValue::ExceptionFormatValue(int64_t int_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER), int_val(int_val) { -} -ExceptionFormatValue::ExceptionFormatValue(hugeint_t huge_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(Hugeint::ToString(huge_val)) { -} -ExceptionFormatValue::ExceptionFormatValue(uhugeint_t uhuge_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(Uhugeint::ToString(uhuge_val)) { -} -ExceptionFormatValue::ExceptionFormatValue(string str_val) - : type(ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING), str_val(std::move(str_val)) { -} - -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(PhysicalType value) { - return ExceptionFormatValue(TypeIdToString(value)); -} -template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(LogicalType value) { // NOLINT: templating requires us to copy value here - return ExceptionFormatValue(value.ToString()); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(float value) { - return ExceptionFormatValue(double(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(double value) { - return ExceptionFormatValue(double(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(string value) { - 
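// Each of these CreateFormatValue specializations normalizes its argument to one of
// the three ExceptionFormatValue payloads (double, int64_t, or string); a plain
// string is simply moved into the string payload.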
return ExceptionFormatValue(std::move(value)); -} - -template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLString value) { // NOLINT: templating requires us to copy value here - return KeywordHelper::WriteQuoted(value.raw_string, '\''); -} - -template <> -ExceptionFormatValue -ExceptionFormatValue::CreateFormatValue(SQLIdentifier value) { // NOLINT: templating requires us to copy value here - return KeywordHelper::WriteOptionallyQuoted(value.raw_string, '"'); -} - -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(const char *value) { - return ExceptionFormatValue(string(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(char *value) { - return ExceptionFormatValue(string(value)); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(hugeint_t value) { - return ExceptionFormatValue(value); -} -template <> -ExceptionFormatValue ExceptionFormatValue::CreateFormatValue(uhugeint_t value) { - return ExceptionFormatValue(value); -} - -string ExceptionFormatValue::Format(const string &msg, std::vector &values) { - try { - std::vector> format_args; - for (auto &val : values) { - switch (val.type) { - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_DOUBLE: - format_args.push_back(duckdb_fmt::internal::make_arg(val.dbl_val)); - break; - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_INTEGER: - format_args.push_back(duckdb_fmt::internal::make_arg(val.int_val)); - break; - case ExceptionFormatValueType::FORMAT_VALUE_TYPE_STRING: - format_args.push_back(duckdb_fmt::internal::make_arg(val.str_val)); - break; - } - } - return duckdb_fmt::vsprintf(msg, duckdb_fmt::basic_format_args( - format_args.data(), static_cast(format_args.size()))); - } catch (std::exception &ex) { // LCOV_EXCL_START - // work-around for oss-fuzz limiting memory which causes issues here - if (StringUtil::Contains(ex.what(), "fuzz mode")) { - throw InvalidInputException(msg); - } - throw InternalException(std::string("Primary exception: ") + msg + - "\nSecondary exception in ExceptionFormatValue: " + ex.what()); - } // LCOV_EXCL_STOP -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/extra_type_info.cpp b/src/duckdb/src/common/extra_type_info.cpp deleted file mode 100644 index 117f9b0f7..000000000 --- a/src/duckdb/src/common/extra_type_info.cpp +++ /dev/null @@ -1,467 +0,0 @@ -#include "duckdb/common/extra_type_info.hpp" -#include "duckdb/common/extra_type_info/enum_type_info.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/string_map_set.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Extension Type Info -//===--------------------------------------------------------------------===// - -bool ExtensionTypeInfo::Equals(optional_ptr lhs, optional_ptr rhs) { - // Either both are null, or both are the same, so they are equal - if (lhs.get() == rhs.get()) { - return true; - } - // If one is null, then we cant compare them - if (lhs == nullptr || rhs == nullptr) { - return true; - } - - // Both are not null, so we can compare them - D_ASSERT(lhs != nullptr && rhs != nullptr); - - // Compare modifiers - const auto &lhs_mods = lhs->modifiers; - const auto &rhs_mods = rhs->modifiers; - const auto common_mods = 
MinValue(lhs_mods.size(), rhs_mods.size()); - for (idx_t i = 0; i < common_mods; i++) { - // If the types are not strictly equal, they are not equal - auto &lhs_val = lhs_mods[i].value; - auto &rhs_val = rhs_mods[i].value; - - if (lhs_val.type() != rhs_val.type()) { - return false; - } - - // If both are null, its fine - if (lhs_val.IsNull() && rhs_val.IsNull()) { - continue; - } - - // If one is null, the other must be null too - if (lhs_val.IsNull() != rhs_val.IsNull()) { - return false; - } - - if (lhs_val != rhs_val) { - return false; - } - } - - // Properties are optional, so only compare those present in both - const auto &lhs_props = lhs->properties; - const auto &rhs_props = rhs->properties; - - for (const auto &kv : lhs_props) { - auto it = rhs_props.find(kv.first); - if (it == rhs_props.end()) { - // Continue - continue; - } - if (kv.second != it->second) { - // Mismatch! - return false; - } - } - - // All ok! - return true; -} - -//===--------------------------------------------------------------------===// -// Extra Type Info -//===--------------------------------------------------------------------===// -ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type) : type(type) { -} -ExtraTypeInfo::ExtraTypeInfo(ExtraTypeInfoType type, string alias) : type(type), alias(std::move(alias)) { -} -ExtraTypeInfo::~ExtraTypeInfo() { -} - -ExtraTypeInfo::ExtraTypeInfo(const ExtraTypeInfo &other) : type(other.type), alias(other.alias) { - if (other.extension_info) { - extension_info = make_uniq(*other.extension_info); - } -} - -ExtraTypeInfo &ExtraTypeInfo::operator=(const ExtraTypeInfo &other) { - type = other.type; - alias = other.alias; - if (other.extension_info) { - extension_info = make_uniq(*other.extension_info); - } - return *this; -} - -shared_ptr ExtraTypeInfo::Copy() const { - return shared_ptr(new ExtraTypeInfo(*this)); -} - -bool ExtraTypeInfo::Equals(ExtraTypeInfo *other_p) const { - if (type == ExtraTypeInfoType::INVALID_TYPE_INFO || type == ExtraTypeInfoType::STRING_TYPE_INFO || - type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { - if (!other_p) { - if (!alias.empty()) { - return false; - } - if (extension_info) { - return false; - } - //! 
We only need to compare aliases when both types have them in this case - return true; - } - if (alias != other_p->alias) { - return false; - } - if (!ExtensionTypeInfo::Equals(extension_info, other_p->extension_info)) { - return false; - } - return true; - } - if (!other_p) { - return false; - } - if (type != other_p->type) { - return false; - } - if (alias != other_p->alias) { - return false; - } - if (!ExtensionTypeInfo::Equals(extension_info, other_p->extension_info)) { - return false; - } - return EqualsInternal(other_p); -} - -bool ExtraTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - // Do nothing - return true; -} - -//===--------------------------------------------------------------------===// -// Decimal Type Info -//===--------------------------------------------------------------------===// -DecimalTypeInfo::DecimalTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO) { -} - -DecimalTypeInfo::DecimalTypeInfo(uint8_t width_p, uint8_t scale_p) - : ExtraTypeInfo(ExtraTypeInfoType::DECIMAL_TYPE_INFO), width(width_p), scale(scale_p) { - D_ASSERT(width_p >= scale_p); -} - -bool DecimalTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return width == other.width && scale == other.scale; -} - -shared_ptr DecimalTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// String Type Info -//===--------------------------------------------------------------------===// -StringTypeInfo::StringTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO) { -} - -StringTypeInfo::StringTypeInfo(string collation_p) - : ExtraTypeInfo(ExtraTypeInfoType::STRING_TYPE_INFO), collation(std::move(collation_p)) { -} - -bool StringTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - // collation info has no impact on equality - return true; -} - -shared_ptr StringTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// List Type Info -//===--------------------------------------------------------------------===// -ListTypeInfo::ListTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO) { -} - -ListTypeInfo::ListTypeInfo(LogicalType child_type_p) - : ExtraTypeInfo(ExtraTypeInfoType::LIST_TYPE_INFO), child_type(std::move(child_type_p)) { -} - -bool ListTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return child_type == other.child_type; -} - -shared_ptr ListTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// Struct Type Info -//===--------------------------------------------------------------------===// -StructTypeInfo::StructTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO) { -} - -StructTypeInfo::StructTypeInfo(child_list_t child_types_p) - : ExtraTypeInfo(ExtraTypeInfoType::STRUCT_TYPE_INFO), child_types(std::move(child_types_p)) { -} - -bool StructTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return child_types == other.child_types; -} - -shared_ptr StructTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// Aggregate State Type Info -//===--------------------------------------------------------------------===// -AggregateStateTypeInfo::AggregateStateTypeInfo() : 
ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO) { -} - -AggregateStateTypeInfo::AggregateStateTypeInfo(aggregate_state_t state_type_p) - : ExtraTypeInfo(ExtraTypeInfoType::AGGREGATE_STATE_TYPE_INFO), state_type(std::move(state_type_p)) { -} - -bool AggregateStateTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return state_type.function_name == other.state_type.function_name && - state_type.return_type == other.state_type.return_type && - state_type.bound_argument_types == other.state_type.bound_argument_types; -} - -shared_ptr AggregateStateTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// User Type Info -//===--------------------------------------------------------------------===// -UserTypeInfo::UserTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO) { -} - -UserTypeInfo::UserTypeInfo(string name_p) - : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO), user_type_name(std::move(name_p)) { -} - -UserTypeInfo::UserTypeInfo(string name_p, vector modifiers_p) - : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO), user_type_name(std::move(name_p)), - user_type_modifiers(std::move(modifiers_p)) { -} - -UserTypeInfo::UserTypeInfo(string catalog_p, string schema_p, string name_p, vector modifiers_p) - : ExtraTypeInfo(ExtraTypeInfoType::USER_TYPE_INFO), catalog(std::move(catalog_p)), schema(std::move(schema_p)), - user_type_name(std::move(name_p)), user_type_modifiers(std::move(modifiers_p)) { -} - -bool UserTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return other.user_type_name == user_type_name; -} - -shared_ptr UserTypeInfo::Copy() const { - return make_shared_ptr(*this); -} - -//===--------------------------------------------------------------------===// -// Enum Type Info -//===--------------------------------------------------------------------===// -PhysicalType EnumTypeInfo::DictType(idx_t size) { - if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT8; - } else if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT16; - } else if (size <= NumericLimits::Maximum()) { - return PhysicalType::UINT32; - } else { - throw InternalException("Enum size must be lower than " + std::to_string(NumericLimits::Maximum())); - } -} - -EnumTypeInfo::EnumTypeInfo(Vector &values_insert_order_p, idx_t dict_size_p) - : ExtraTypeInfo(ExtraTypeInfoType::ENUM_TYPE_INFO), values_insert_order(values_insert_order_p), - dict_type(EnumDictType::VECTOR_DICT), dict_size(dict_size_p) { -} - -const EnumDictType &EnumTypeInfo::GetEnumDictType() const { - return dict_type; -} - -const Vector &EnumTypeInfo::GetValuesInsertOrder() const { - return values_insert_order; -} - -const idx_t &EnumTypeInfo::GetDictSize() const { - return dict_size; -} - -LogicalType EnumTypeInfo::CreateType(Vector &ordered_data, idx_t size) { - // Generate EnumTypeInfo - shared_ptr info; - auto enum_internal_type = EnumTypeInfo::DictType(size); - switch (enum_internal_type) { - case PhysicalType::UINT8: - info = make_shared_ptr>(ordered_data, size); - break; - case PhysicalType::UINT16: - info = make_shared_ptr>(ordered_data, size); - break; - case PhysicalType::UINT32: - info = make_shared_ptr>(ordered_data, size); - break; - default: - throw InternalException("Invalid Physical Type for ENUMs"); - } - // Generate Actual Enum Type - return LogicalType(LogicalTypeId::ENUM, info); -} - -template -int64_t 
TemplatedGetPos(const string_map_t &map, const string_t &key) { - auto it = map.find(key); - if (it == map.end()) { - return -1; - } - return it->second; -} - -int64_t EnumType::GetPos(const LogicalType &type, const string_t &key) { - auto info = type.AuxInfo(); - switch (type.InternalType()) { - case PhysicalType::UINT8: - return TemplatedGetPos(info->Cast>().GetValues(), key); - case PhysicalType::UINT16: - return TemplatedGetPos(info->Cast>().GetValues(), key); - case PhysicalType::UINT32: - return TemplatedGetPos(info->Cast>().GetValues(), key); - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } -} - -string_t EnumType::GetString(const LogicalType &type, idx_t pos) { - D_ASSERT(pos < EnumType::GetSize(type)); - return FlatVector::GetData(EnumType::GetValuesInsertOrder(type))[pos]; -} - -shared_ptr EnumTypeInfo::Deserialize(Deserializer &deserializer) { - auto values_count = deserializer.ReadProperty(200, "values_count"); - auto enum_internal_type = EnumTypeInfo::DictType(values_count); - switch (enum_internal_type) { - case PhysicalType::UINT8: - return EnumTypeInfoTemplated::Deserialize(deserializer, NumericCast(values_count)); - case PhysicalType::UINT16: - return EnumTypeInfoTemplated::Deserialize(deserializer, NumericCast(values_count)); - case PhysicalType::UINT32: - return EnumTypeInfoTemplated::Deserialize(deserializer, NumericCast(values_count)); - default: - throw InternalException("Invalid Physical Type for ENUMs"); - } -} - -// Equalities are only used in enums with different catalog entries -bool EnumTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - if (dict_type != other.dict_type) { - return false; - } - D_ASSERT(dict_type == EnumDictType::VECTOR_DICT); - // We must check if both enums have the same size - if (other.dict_size != dict_size) { - return false; - } - auto other_vector_ptr = FlatVector::GetData(other.values_insert_order); - auto this_vector_ptr = FlatVector::GetData(values_insert_order); - - // Now we must check if all strings are the same - for (idx_t i = 0; i < dict_size; i++) { - if (!Equals::Operation(other_vector_ptr[i], this_vector_ptr[i])) { - return false; - } - } - return true; -} - -void EnumTypeInfo::Serialize(Serializer &serializer) const { - ExtraTypeInfo::Serialize(serializer); - - // Enums are special in that we serialize their values as a list instead of dumping the whole vector - auto strings = FlatVector::GetData(values_insert_order); - serializer.WriteProperty(200, "values_count", dict_size); - serializer.WriteList(201, "values", dict_size, - [&](Serializer::List &list, idx_t i) { list.WriteElement(strings[i]); }); -} - -shared_ptr EnumTypeInfo::Copy() const { - Vector values_insert_order_copy(LogicalType::VARCHAR, false, false, 0); - values_insert_order_copy.Reference(values_insert_order); - return make_shared_ptr(values_insert_order_copy, dict_size); -} - -//===--------------------------------------------------------------------===// -// ArrayTypeInfo -//===--------------------------------------------------------------------===// - -ArrayTypeInfo::ArrayTypeInfo(LogicalType child_type_p, uint32_t size_p) - : ExtraTypeInfo(ExtraTypeInfoType::ARRAY_TYPE_INFO), child_type(std::move(child_type_p)), size(size_p) { -} - -bool ArrayTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const { - auto &other = other_p->Cast(); - return child_type == other.child_type && size == other.size; -} - -shared_ptr ArrayTypeInfo::Copy() const { - return 
-	return make_shared_ptr<ArrayTypeInfo>(*this);
-}
-
-//===--------------------------------------------------------------------===//
-// Any Type Info
-//===--------------------------------------------------------------------===//
-AnyTypeInfo::AnyTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO) {
-}
-
-AnyTypeInfo::AnyTypeInfo(LogicalType target_type_p, idx_t cast_score_p)
-    : ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO), target_type(std::move(target_type_p)), cast_score(cast_score_p) {
-}
-
-bool AnyTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
-	auto &other = other_p->Cast<AnyTypeInfo>();
-	return target_type == other.target_type && cast_score == other.cast_score;
-}
-
-shared_ptr<ExtraTypeInfo> AnyTypeInfo::Copy() const {
-	return make_shared_ptr<AnyTypeInfo>(*this);
-}
-
-//===--------------------------------------------------------------------===//
-// Integer Literal Type Info
-//===--------------------------------------------------------------------===//
-IntegerLiteralTypeInfo::IntegerLiteralTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO) {
-}
-
-IntegerLiteralTypeInfo::IntegerLiteralTypeInfo(Value constant_value_p)
-    : ExtraTypeInfo(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), constant_value(std::move(constant_value_p)) {
-	if (constant_value.IsNull()) {
-		throw InternalException("Integer literal cannot be NULL");
-	}
-}
-
-bool IntegerLiteralTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
-	auto &other = other_p->Cast<IntegerLiteralTypeInfo>();
-	return constant_value == other.constant_value;
-}
-
-shared_ptr<ExtraTypeInfo> IntegerLiteralTypeInfo::Copy() const {
-	return make_shared_ptr<IntegerLiteralTypeInfo>(*this);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/file_buffer.cpp b/src/duckdb/src/common/file_buffer.cpp
deleted file mode 100644
index b1b1febb3..000000000
--- a/src/duckdb/src/common/file_buffer.cpp
+++ /dev/null
@@ -1,111 +0,0 @@
-#include "duckdb/common/file_buffer.hpp"
-
-#include "duckdb/common/allocator.hpp"
-#include "duckdb/common/checksum.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/file_system.hpp"
-#include "duckdb/common/helper.hpp"
-#include "duckdb/storage/storage_info.hpp"
-#include <cstring>
-
-namespace duckdb {
-
-FileBuffer::FileBuffer(Allocator &allocator, FileBufferType type, uint64_t user_size)
-    : allocator(allocator), type(type) {
-	Init();
-	if (user_size) {
-		Resize(user_size);
-	}
-}
-
-void FileBuffer::Init() {
-	buffer = nullptr;
-	size = 0;
-	internal_buffer = nullptr;
-	internal_size = 0;
-}
-
-FileBuffer::FileBuffer(FileBuffer &source, FileBufferType type_p) : allocator(source.allocator), type(type_p) {
-	// take over the structures of the source buffer
-	buffer = source.buffer;
-	size = source.size;
-	internal_buffer = source.internal_buffer;
-	internal_size = source.internal_size;
-
-	source.Init();
-}
-
-FileBuffer::~FileBuffer() {
-	if (!internal_buffer) {
-		return;
-	}
-	allocator.FreeData(internal_buffer, internal_size);
-}
-
-void FileBuffer::ReallocBuffer(idx_t new_size) {
-	data_ptr_t new_buffer;
-	if (internal_buffer) {
-		new_buffer = allocator.ReallocateData(internal_buffer, internal_size, new_size);
-	} else {
-		new_buffer = allocator.AllocateData(new_size);
-	}
-
-	// FIXME: should we throw one of our exceptions here?
-	if (!new_buffer) {
-		throw std::bad_alloc();
-	}
-	internal_buffer = new_buffer;
-	internal_size = new_size;
-
-	// The caller must update these.
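// [Editorial note] The nulling below is deliberate: after a reallocation the old
// user-visible pointer may dangle, so ReallocBuffer() invalidates it and Resize()
// recomputes it from CalculateMemory(), which reserves header space and rounds the
// allocation up to an aligned size. A minimal sketch of that layout math follows
// (disabled; the SECTOR and HEADER values are assumptions for illustration, not
// DuckDB's actual constants):
#if 0
#include <cstdint>
constexpr uint64_t SECTOR = 4096, HEADER = 8;
constexpr uint64_t AlignUp(uint64_t v) {
	return (v + SECTOR - 1) & ~(SECTOR - 1); // round up to a sector multiple (power of two)
}
// Resize(N): alloc = AlignUp(HEADER + N); buffer = internal_buffer + HEADER;
//            size = alloc - HEADER;       // usable bytes visible to callers
#endif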
- buffer = nullptr; - size = 0; -} - -FileBuffer::MemoryRequirement FileBuffer::CalculateMemory(uint64_t user_size) { - FileBuffer::MemoryRequirement result; - - if (type == FileBufferType::TINY_BUFFER) { - // We never do IO on tiny buffers, so there's no need to add a header or sector-align. - result.header_size = 0; - result.alloc_size = user_size; - } else { - result.header_size = Storage::DEFAULT_BLOCK_HEADER_SIZE; - result.alloc_size = AlignValue(result.header_size + user_size); - } - return result; -} - -void FileBuffer::Resize(uint64_t new_size) { - auto req = CalculateMemory(new_size); - ReallocBuffer(req.alloc_size); - - if (new_size > 0) { - buffer = internal_buffer + req.header_size; - size = internal_size - req.header_size; - } -} - -void FileBuffer::Read(FileHandle &handle, uint64_t location) { - D_ASSERT(type != FileBufferType::TINY_BUFFER); - handle.Read(internal_buffer, internal_size, location); -} - -void FileBuffer::Write(FileHandle &handle, uint64_t location) { - D_ASSERT(type != FileBufferType::TINY_BUFFER); - handle.Write(internal_buffer, internal_size, location); -} - -void FileBuffer::Clear() { - memset(internal_buffer, 0, internal_size); -} - -void FileBuffer::Initialize(DebugInitialize initialize) { - if (initialize == DebugInitialize::NO_INITIALIZE) { - return; - } - uint8_t value = initialize == DebugInitialize::DEBUG_ZERO_INITIALIZE ? 0 : 0xFF; - memset(internal_buffer, value, internal_size); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/file_system.cpp b/src/duckdb/src/common/file_system.cpp deleted file mode 100644 index fbcf0cede..000000000 --- a/src/duckdb/src/common/file_system.cpp +++ /dev/null @@ -1,662 +0,0 @@ -#include "duckdb/common/file_system.hpp" - -#include "duckdb/common/checksum.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/file_opener.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/windows.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/common/windows_util.hpp" -#include "duckdb/common/operator/multiply.hpp" - -#include -#include - -#ifndef _WIN32 -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef __MVS__ -#define _XOPEN_SOURCE_EXTENDED 1 -#include -// enjoy - https://reviews.llvm.org/D92110 -#define PATH_MAX _XOPEN_PATH_MAX -#endif - -#else -#include -#include - -#ifdef __MINGW32__ -// need to manually define this for mingw -extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); -#endif - -#undef FILE_CREATE // woo mingw -#endif - -namespace duckdb { - -constexpr FileOpenFlags FileFlags::FILE_FLAGS_READ; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_WRITE; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_DIRECT_IO; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_FILE_CREATE; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_FILE_CREATE_NEW; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_APPEND; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_PRIVATE; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_PARALLEL_ACCESS; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_EXCLUSIVE_CREATE; -constexpr FileOpenFlags FileFlags::FILE_FLAGS_NULL_IF_EXISTS; - -void FileOpenFlags::Verify() { -#ifdef DEBUG - bool is_read = flags & 
FileOpenFlags::FILE_FLAGS_READ; - bool is_write = flags & FileOpenFlags::FILE_FLAGS_WRITE; - bool is_create = - (flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE) || (flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE_NEW); - bool is_private = (flags & FileOpenFlags::FILE_FLAGS_PRIVATE); - bool null_if_not_exists = flags & FileOpenFlags::FILE_FLAGS_NULL_IF_NOT_EXISTS; - bool exclusive_create = flags & FileOpenFlags::FILE_FLAGS_EXCLUSIVE_CREATE; - bool null_if_exists = flags & FileOpenFlags::FILE_FLAGS_NULL_IF_EXISTS; - - // require either READ or WRITE (or both) - D_ASSERT(is_read || is_write); - // CREATE/Append flags require writing - D_ASSERT(is_write || !(flags & FileOpenFlags::FILE_FLAGS_APPEND)); - D_ASSERT(is_write || !(flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE)); - D_ASSERT(is_write || !(flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE_NEW)); - // cannot combine CREATE and CREATE_NEW flags - D_ASSERT(!(flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE && flags & FileOpenFlags::FILE_FLAGS_FILE_CREATE_NEW)); - - // For is_private can only be set along with a create flag - D_ASSERT(!is_private || is_create); - // FILE_FLAGS_NULL_IF_NOT_EXISTS cannot be combined with CREATE/CREATE_NEW - D_ASSERT(!(null_if_not_exists && is_create)); - // FILE_FLAGS_EXCLUSIVE_CREATE only can be combined with CREATE/CREATE_NEW - D_ASSERT(!exclusive_create || is_create); - // FILE_FLAGS_NULL_IF_EXISTS only can be set with EXCLUSIVE_CREATE - D_ASSERT(!null_if_exists || exclusive_create); -#endif -} - -FileSystem::~FileSystem() { -} - -FileSystem &FileSystem::GetFileSystem(ClientContext &context) { - auto &client_data = ClientData::Get(context); - return *client_data.client_file_system; -} - -bool PathMatched(const string &path, const string &sub_path) { - return path.rfind(sub_path, 0) == 0; -} - -#ifndef _WIN32 - -string FileSystem::GetEnvVariable(const string &name) { - const char *env = getenv(name.c_str()); - if (!env) { - return string(); - } - return env; -} - -bool FileSystem::IsPathAbsolute(const string &path) { - auto path_separator = PathSeparator(path); - return PathMatched(path, path_separator) || StringUtil::StartsWith(path, "file:/"); -} - -string FileSystem::PathSeparator(const string &path) { - return "/"; -} - -void FileSystem::SetWorkingDirectory(const string &path) { - if (chdir(path.c_str()) != 0) { - throw IOException("Could not change working directory!"); - } -} - -optional_idx FileSystem::GetAvailableMemory() { - errno = 0; - -#ifdef __MVS__ - struct rlimit limit; - int rlim_rc = getrlimit(RLIMIT_AS, &limit); - idx_t max_memory = MinValue(limit.rlim_max, UINTPTR_MAX); -#else - idx_t max_memory = MinValue((idx_t)sysconf(_SC_PHYS_PAGES) * (idx_t)sysconf(_SC_PAGESIZE), UINTPTR_MAX); -#endif - if (errno != 0) { - return optional_idx(); - } - return max_memory; -} - -optional_idx FileSystem::GetAvailableDiskSpace(const string &path) { - struct statvfs vfs; - - auto ret = statvfs(path.c_str(), &vfs); - if (ret == -1) { - return optional_idx(); - } - auto block_size = vfs.f_frsize; - // These are the blocks available for creating new files or extending existing ones - auto available_blocks = vfs.f_bfree; - idx_t available_disk_space = DConstants::INVALID_INDEX; - if (!TryMultiplyOperator::Operation(static_cast(block_size), static_cast(available_blocks), - available_disk_space)) { - return optional_idx(); - } - return available_disk_space; -} - -string FileSystem::GetWorkingDirectory() { - auto buffer = make_unsafe_uniq_array(PATH_MAX); - char *ret = getcwd(buffer.get(), PATH_MAX); - if (!ret) { - 
throw IOException("Could not get working directory!"); - } - return string(buffer.get()); -} - -string FileSystem::NormalizeAbsolutePath(const string &path) { - D_ASSERT(IsPathAbsolute(path)); - return path; -} - -#else - -string FileSystem::GetEnvVariable(const string &env) { - // first convert the environment variable name to the correct encoding - auto env_w = WindowsUtil::UTF8ToUnicode(env.c_str()); - // use _wgetenv to get the value - auto res_w = _wgetenv(env_w.c_str()); - if (!res_w) { - // no environment variable of this name found - return string(); - } - return WindowsUtil::UnicodeToUTF8(res_w); -} - -static bool StartsWithSingleBackslash(const string &path) { - if (path.size() < 2) { - return false; - } - if (path[0] != '/' && path[0] != '\\') { - return false; - } - if (path[1] == '/' || path[1] == '\\') { - return false; - } - return true; -} - -bool FileSystem::IsPathAbsolute(const string &path) { - // 1) A single backslash or forward-slash - if (StartsWithSingleBackslash(path)) { - return true; - } - // 2) special "long paths" on windows - if (PathMatched(path, "\\\\?\\")) { - return true; - } - // 3) a network path - if (PathMatched(path, "\\\\")) { - return true; - } - // 4) A disk designator with a backslash (e.g., C:\ or C:/) - auto path_aux = path; - path_aux.erase(0, 1); - if (PathMatched(path_aux, ":\\") || PathMatched(path_aux, ":/")) { - return true; - } - return false; -} - -string FileSystem::NormalizeAbsolutePath(const string &path) { - D_ASSERT(IsPathAbsolute(path)); - auto result = StringUtil::Lower(FileSystem::ConvertSeparators(path)); - if (StartsWithSingleBackslash(result)) { - // Path starts with a single backslash or forward slash - // prepend drive letter - return GetWorkingDirectory().substr(0, 2) + result; - } - return result; -} - -string FileSystem::PathSeparator(const string &path) { - if (StringUtil::StartsWith(path, "file:")) { - return "/"; - } else { - return "\\"; - } -} - -void FileSystem::SetWorkingDirectory(const string &path) { - auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); - if (!SetCurrentDirectoryW(unicode_path.c_str())) { - throw IOException("Could not change working directory to \"%s\"", path); - } -} - -optional_idx FileSystem::GetAvailableMemory() { - ULONGLONG available_memory_kb; - if (GetPhysicallyInstalledSystemMemory(&available_memory_kb)) { - return MinValue(available_memory_kb * 1000, UINTPTR_MAX); - } - // fallback: try GlobalMemoryStatusEx - MEMORYSTATUSEX mem_state; - mem_state.dwLength = sizeof(MEMORYSTATUSEX); - - if (GlobalMemoryStatusEx(&mem_state)) { - return MinValue(mem_state.ullTotalPhys, UINTPTR_MAX); - } - return optional_idx(); -} - -optional_idx FileSystem::GetAvailableDiskSpace(const string &path) { - ULARGE_INTEGER available_bytes, total_bytes, free_bytes; - - auto unicode_path = WindowsUtil::UTF8ToUnicode(path.c_str()); - if (!GetDiskFreeSpaceExW(unicode_path.c_str(), &available_bytes, &total_bytes, &free_bytes)) { - return optional_idx(); - } - (void)total_bytes; - (void)free_bytes; - return NumericCast(available_bytes.QuadPart); -} - -string FileSystem::GetWorkingDirectory() { - idx_t count = GetCurrentDirectoryW(0, nullptr); - if (count == 0) { - throw IOException("Could not get working directory!"); - } - auto buffer = make_unsafe_uniq_array(count); - idx_t ret = GetCurrentDirectoryW(count, buffer.get()); - if (count != ret + 1) { - throw IOException("Could not get working directory!"); - } - return WindowsUtil::UnicodeToUTF8(buffer.get()); -} - -#endif - -string FileSystem::JoinPath(const 
string &a, const string &b) { - // FIXME: sanitize paths - return a.empty() ? b : a + PathSeparator(a) + b; -} - -string FileSystem::ConvertSeparators(const string &path) { - auto separator_str = PathSeparator(path); - char separator = separator_str[0]; - if (separator == '/') { - // on unix-based systems we only accept / as a separator - return path; - } - // on windows-based systems we accept both - return StringUtil::Replace(path, "/", separator_str); -} - -string FileSystem::ExtractName(const string &path) { - if (path.empty()) { - return string(); - } - auto normalized_path = ConvertSeparators(path); - auto sep = PathSeparator(path); - auto splits = StringUtil::Split(normalized_path, sep); - D_ASSERT(!splits.empty()); - return splits.back(); -} - -string FileSystem::ExtractBaseName(const string &path) { - if (path.empty()) { - return string(); - } - auto vec = StringUtil::Split(ExtractName(path), "."); - D_ASSERT(!vec.empty()); - return vec[0]; -} - -string FileSystem::GetHomeDirectory(optional_ptr opener) { - // read the home_directory setting first, if it is set - if (opener) { - Value result; - if (opener->TryGetCurrentSetting("home_directory", result)) { - if (!result.IsNull() && !result.ToString().empty()) { - return result.ToString(); - } - } - } - // fallback to the default home directories for the specified system -#ifdef DUCKDB_WINDOWS - return FileSystem::GetEnvVariable("USERPROFILE"); -#else - return FileSystem::GetEnvVariable("HOME"); -#endif -} - -string FileSystem::GetHomeDirectory() { - return GetHomeDirectory(nullptr); -} - -string FileSystem::ExpandPath(const string &path, optional_ptr opener) { - if (path.empty()) { - return path; - } - if (path[0] == '~') { - return GetHomeDirectory(opener) + path.substr(1); - } - return path; -} - -string FileSystem::ExpandPath(const string &path) { - return FileSystem::ExpandPath(path, nullptr); -} - -// LCOV_EXCL_START -unique_ptr FileSystem::OpenFile(const string &path, FileOpenFlags flags, optional_ptr opener) { - throw NotImplementedException("%s: OpenFile is not implemented!", GetName()); -} - -void FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - throw NotImplementedException("%s: Read (with location) is not implemented!", GetName()); -} - -bool FileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) { - // This is not a required method. Derived FileSystems may optionally override/implement. 
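// [Editorial note] This default is the "safe no-op" flavour of the virtual
// filesystem pattern used throughout this file: optional capabilities return a
// harmless default, while required ones throw NotImplementedException. A derived
// filesystem advertises support simply by overriding. Hypothetical sketch
// (disabled; the subclass and its fallocate strategy are illustrative, not part
// of this file):
#if 0
class HolePunchingFileSystem : public LocalFileSystem {
public:
	bool Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) override {
		// e.g. forward to fallocate(FALLOC_FL_PUNCH_HOLE, ...) on Linux
		return true; // report that the range was actually trimmed
	}
};
#endif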
- return false; -} - -void FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - throw NotImplementedException("%s: Write (with location) is not implemented!", GetName()); -} - -int64_t FileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - throw NotImplementedException("%s: Read is not implemented!", GetName()); -} - -int64_t FileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - throw NotImplementedException("%s: Write is not implemented!", GetName()); -} - -int64_t FileSystem::GetFileSize(FileHandle &handle) { - throw NotImplementedException("%s: GetFileSize is not implemented!", GetName()); -} - -time_t FileSystem::GetLastModifiedTime(FileHandle &handle) { - throw NotImplementedException("%s: GetLastModifiedTime is not implemented!", GetName()); -} - -FileType FileSystem::GetFileType(FileHandle &handle) { - return FileType::FILE_TYPE_INVALID; -} - -void FileSystem::Truncate(FileHandle &handle, int64_t new_size) { - throw NotImplementedException("%s: Truncate is not implemented!", GetName()); -} - -bool FileSystem::DirectoryExists(const string &directory, optional_ptr opener) { - throw NotImplementedException("%s: DirectoryExists is not implemented!", GetName()); -} - -void FileSystem::CreateDirectory(const string &directory, optional_ptr opener) { - throw NotImplementedException("%s: CreateDirectory is not implemented!", GetName()); -} - -void FileSystem::RemoveDirectory(const string &directory, optional_ptr opener) { - throw NotImplementedException("%s: RemoveDirectory is not implemented!", GetName()); -} - -bool FileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - throw NotImplementedException("%s: ListFiles is not implemented!", GetName()); -} - -void FileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { - throw NotImplementedException("%s: MoveFile is not implemented!", GetName()); -} - -bool FileSystem::FileExists(const string &filename, optional_ptr opener) { - throw NotImplementedException("%s: FileExists is not implemented!", GetName()); -} - -bool FileSystem::IsPipe(const string &filename, optional_ptr opener) { - return false; -} - -void FileSystem::RemoveFile(const string &filename, optional_ptr opener) { - throw NotImplementedException("%s: RemoveFile is not implemented!", GetName()); -} - -void FileSystem::FileSync(FileHandle &handle) { - throw NotImplementedException("%s: FileSync is not implemented!", GetName()); -} - -bool FileSystem::HasGlob(const string &str) { - for (idx_t i = 0; i < str.size(); i++) { - switch (str[i]) { - case '*': - case '?': - case '[': - return true; - default: - break; - } - } - return false; -} - -vector FileSystem::Glob(const string &path, FileOpener *opener) { - throw NotImplementedException("%s: Glob is not implemented!", GetName()); -} - -void FileSystem::RegisterSubSystem(unique_ptr sub_fs) { - throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr sub_fs) { - throw NotImplementedException("%s: Can't register a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::UnregisterSubSystem(const string &name) { - throw NotImplementedException("%s: Can't unregister a sub system on a non-virtual file system", GetName()); -} - -void FileSystem::SetDisabledFileSystems(const vector &names) { - throw NotImplementedException("%s: Can't 
disable file systems on a non-virtual file system", GetName()); -} - -vector FileSystem::ListSubSystems() { - throw NotImplementedException("%s: Can't list sub systems on a non-virtual file system", GetName()); -} - -bool FileSystem::CanHandleFile(const string &fpath) { - throw NotImplementedException("%s: CanHandleFile is not implemented!", GetName()); -} - -static string LookupExtensionForPattern(const string &pattern) { - for (const auto &entry : EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(pattern, entry.name)) { - return entry.extension; - } - } - return ""; -} - -vector FileSystem::GlobFiles(const string &pattern, ClientContext &context, FileGlobOptions options) { - auto result = Glob(pattern); - if (result.empty()) { - string required_extension = LookupExtensionForPattern(pattern); - if (!required_extension.empty() && !context.db->ExtensionIsLoaded(required_extension)) { - auto &dbconfig = DBConfig::GetConfig(context); - if (!ExtensionHelper::CanAutoloadExtension(required_extension) || - !dbconfig.options.autoload_known_extensions) { - auto error_message = - "File " + pattern + " requires the extension " + required_extension + " to be loaded"; - error_message = - ExtensionHelper::AddExtensionInstallHintToErrorMsg(context, error_message, required_extension); - throw MissingExtensionException(error_message); - } - // an extension is required to read this file, but it is not loaded - try to load it - ExtensionHelper::AutoLoadExtension(context, required_extension); - // success! glob again - // check the extension is loaded just in case to prevent an infinite loop here - if (!context.db->ExtensionIsLoaded(required_extension)) { - throw InternalException("Extension load \"%s\" did not throw but somehow the extension was not loaded", - required_extension); - } - return GlobFiles(pattern, context, options); - } - if (options == FileGlobOptions::DISALLOW_EMPTY) { - throw IOException("No files found that match the pattern \"%s\"", pattern); - } - } - return result; -} - -void FileSystem::Seek(FileHandle &handle, idx_t location) { - throw NotImplementedException("%s: Seek is not implemented!", GetName()); -} - -void FileSystem::Reset(FileHandle &handle) { - handle.Seek(0); -} - -idx_t FileSystem::SeekPosition(FileHandle &handle) { - throw NotImplementedException("%s: SeekPosition is not implemented!", GetName()); -} - -bool FileSystem::CanSeek() { - throw NotImplementedException("%s: CanSeek is not implemented!", GetName()); -} - -bool FileSystem::IsManuallySet() { - return false; -} - -unique_ptr FileSystem::OpenCompressedFile(unique_ptr handle, bool write) { - throw NotImplementedException("%s: OpenCompressedFile is not implemented!", GetName()); -} - -bool FileSystem::OnDiskFile(FileHandle &handle) { - throw NotImplementedException("%s: OnDiskFile is not implemented!", GetName()); -} -// LCOV_EXCL_STOP - -FileHandle::FileHandle(FileSystem &file_system, string path_p, FileOpenFlags flags) - : file_system(file_system), path(std::move(path_p)), flags(flags) { -} - -FileHandle::~FileHandle() { -} - -int64_t FileHandle::Read(void *buffer, idx_t nr_bytes) { - return file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes)); -} - -bool FileHandle::Trim(idx_t offset_bytes, idx_t length_bytes) { - return file_system.Trim(*this, offset_bytes, length_bytes); -} - -int64_t FileHandle::Write(void *buffer, idx_t nr_bytes) { - return file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes)); -} - -void FileHandle::Read(void *buffer, idx_t nr_bytes, idx_t location) { - 
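	// [Editorial note] FileHandle is a thin delegating wrapper: every method below
	// simply forwards to the owning FileSystem together with *this, so callers hold
	// a cheap handle while behaviour stays polymorphic per filesystem (local, gzip,
	// remote, ...).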
file_system.Read(*this, buffer, UnsafeNumericCast(nr_bytes), location); -} - -void FileHandle::Write(void *buffer, idx_t nr_bytes, idx_t location) { - file_system.Write(*this, buffer, UnsafeNumericCast(nr_bytes), location); -} - -void FileHandle::Seek(idx_t location) { - file_system.Seek(*this, location); -} - -void FileHandle::Reset() { - file_system.Reset(*this); -} - -idx_t FileHandle::SeekPosition() { - return file_system.SeekPosition(*this); -} - -bool FileHandle::CanSeek() { - return file_system.CanSeek(); -} - -FileCompressionType FileHandle::GetFileCompressionType() { - return FileCompressionType::UNCOMPRESSED; -} - -bool FileHandle::IsPipe() { - return file_system.IsPipe(path); -} - -string FileHandle::ReadLine() { - string result; - char buffer[1]; - while (true) { - auto tuples_read = UnsafeNumericCast(Read(buffer, 1)); - if (tuples_read == 0 || buffer[0] == '\n') { - return result; - } - if (buffer[0] != '\r') { - result += buffer[0]; - } - } -} - -bool FileHandle::OnDiskFile() { - return file_system.OnDiskFile(*this); -} - -idx_t FileHandle::GetFileSize() { - return NumericCast(file_system.GetFileSize(*this)); -} - -void FileHandle::Sync() { - file_system.FileSync(*this); -} - -void FileHandle::Truncate(int64_t new_size) { - file_system.Truncate(*this, new_size); -} - -FileType FileHandle::GetType() { - return file_system.GetFileType(*this); -} - -idx_t FileHandle::GetProgress() { - throw NotImplementedException("GetProgress is not implemented for this file handle"); -} - -bool FileSystem::IsRemoteFile(const string &path) { - string extension = ""; - return IsRemoteFile(path, extension); -} - -bool FileSystem::IsRemoteFile(const string &path, string &extension) { - for (const auto &entry : EXTENSION_FILE_PREFIXES) { - if (StringUtil::StartsWith(path, entry.name)) { - extension = entry.extension; - return true; - } - } - return false; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/filename_pattern.cpp b/src/duckdb/src/common/filename_pattern.cpp deleted file mode 100644 index 04851ad3f..000000000 --- a/src/duckdb/src/common/filename_pattern.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "duckdb/common/filename_pattern.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -void FilenamePattern::SetFilenamePattern(const string &pattern) { - const string id_format {"{i}"}; - const string uuid_format {"{uuid}"}; - - base = pattern; - - pos = base.find(id_format); - uuid = false; - if (pos != string::npos) { - base = StringUtil::Replace(base, id_format, ""); - uuid = false; - } - - pos = base.find(uuid_format); - if (pos != string::npos) { - base = StringUtil::Replace(base, uuid_format, ""); - uuid = true; - } - - pos = std::min(pos, (idx_t)base.length()); -} - -string FilenamePattern::CreateFilename(FileSystem &fs, const string &path, const string &extension, - idx_t offset) const { - string result(base); - string replacement; - - if (uuid) { - replacement = UUID::ToString(UUID::GenerateRandomUUID()); - } else { - replacement = std::to_string(offset); - } - result.insert(pos, replacement); - return fs.JoinPath(path, result + "." 
+ extension); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/fsst.cpp b/src/duckdb/src/common/fsst.cpp deleted file mode 100644 index 1e28ad5ab..000000000 --- a/src/duckdb/src/common/fsst.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "duckdb/storage/string_uncompressed.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/fsst.hpp" -#include "fsst.h" - -namespace duckdb { - -string_t FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, Vector &result, const char *compressed_string, - const idx_t compressed_string_len, vector &decompress_buffer) { - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); - auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT - auto decompressed_string_size = duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - decompress_buffer.size(), decompress_buffer.data()); - - D_ASSERT(!decompress_buffer.empty()); - D_ASSERT(decompressed_string_size <= decompress_buffer.size() - 1); - return StringVector::AddStringOrBlob(result, const_char_ptr_cast(decompress_buffer.data()), - decompressed_string_size); -} - -string FSSTPrimitives::DecompressValue(void *duckdb_fsst_decoder, const char *compressed_string, - const idx_t compressed_string_len, vector &decompress_buffer) { - - auto compressed_string_ptr = (unsigned char *)compressed_string; // NOLINT - auto fsst_decoder = reinterpret_cast(duckdb_fsst_decoder); - auto decompressed_string_size = duckdb_fsst_decompress(fsst_decoder, compressed_string_len, compressed_string_ptr, - decompress_buffer.size(), decompress_buffer.data()); - - D_ASSERT(!decompress_buffer.empty()); - D_ASSERT(decompressed_string_size <= decompress_buffer.size() - 1); - return string(char_ptr_cast(decompress_buffer.data()), decompressed_string_size); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/gzip_file_system.cpp b/src/duckdb/src/common/gzip_file_system.cpp deleted file mode 100644 index ee0a21580..000000000 --- a/src/duckdb/src/common/gzip_file_system.cpp +++ /dev/null @@ -1,424 +0,0 @@ -#include "duckdb/common/gzip_file_system.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/numeric_utils.hpp" - -#include "miniz.hpp" -#include "miniz_wrapper.hpp" - -#include "duckdb/common/limits.hpp" - -namespace duckdb { - -/* - - 0 2 bytes magic header 0x1f, 0x8b (\037 \213) - 2 1 byte compression method - 0: store (copied) - 1: compress - 2: pack - 3: lzh - 4..7: reserved - 8: deflate - 3 1 byte flags - bit 0 set: file probably ascii text - bit 1 set: continuation of multi-part gzip file, part number present - bit 2 set: extra field present - bit 3 set: original file name present - bit 4 set: file comment present - bit 5 set: file is encrypted, encryption header present - bit 6,7: reserved - 4 4 bytes file modification time in Unix format - 8 1 byte extra flags (depend on compression method) - 9 1 byte OS type -[ - 2 bytes optional part number (second part=1) -]? -[ - 2 bytes optional extra field length (e) - (e)bytes optional extra field -]? -[ - bytes optional original file name, zero terminated -]? -[ - bytes optional file comment, zero terminated -]? -[ - 12 bytes optional encryption header -]? 
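   [Editorial example] A minimal header as produced by `gzip` on Unix, annotated
   byte by byte (hex, illustrative):
     1f 8b        magic
     08           compression method = deflate
     00           flags (no extra field, name, comment or encryption)
     00 00 00 00  mtime (unset)
     00           extra flags
     03           OS = Unix
   With flags = 0 the raw deflate stream starts immediately at offset 10.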
- bytes compressed data - 4 bytes crc32 - 4 bytes uncompressed input size modulo 2^32 - - */ - -static idx_t GZipConsumeString(FileHandle &input) { - idx_t size = 1; // terminator - char buffer[1]; - while (input.Read(buffer, 1) == 1) { - if (buffer[0] == '\0') { - break; - } - size++; - } - return size; -} - -struct MiniZStreamWrapper : public StreamWrapper { - ~MiniZStreamWrapper() override; - - CompressedFile *file = nullptr; - unique_ptr mz_stream_ptr; - bool writing = false; - duckdb_miniz::mz_ulong crc; - idx_t total_size; - -public: - void Initialize(CompressedFile &file, bool write) override; - - bool Read(StreamData &stream_data) override; - void Write(CompressedFile &file, StreamData &stream_data, data_ptr_t buffer, int64_t nr_bytes) override; - - void Close() override; - - void FlushStream(); -}; - -MiniZStreamWrapper::~MiniZStreamWrapper() { - // avoid closing if destroyed during stack unwinding - if (Exception::UncaughtException()) { - return; - } - try { - MiniZStreamWrapper::Close(); - } catch (...) { // NOLINT - cannot throw in exception - } -} - -void MiniZStreamWrapper::Initialize(CompressedFile &file, bool write) { - Close(); - this->file = &file; - mz_stream_ptr = make_uniq(); - memset(mz_stream_ptr.get(), 0, sizeof(duckdb_miniz::mz_stream)); - this->writing = write; - - // TODO use custom alloc/free methods in miniz to throw exceptions on OOM - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - if (write) { - crc = MZ_CRC32_INIT; - total_size = 0; - - MiniZStream::InitializeGZIPHeader(gzip_hdr); - file.child_handle->Write(gzip_hdr, GZIP_HEADER_MINSIZE); - - auto ret = mz_deflateInit2(mz_stream_ptr.get(), duckdb_miniz::MZ_DEFAULT_LEVEL, MZ_DEFLATED, - -MZ_DEFAULT_WINDOW_BITS, 1, 0); - if (ret != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } else { - idx_t data_start = GZIP_HEADER_MINSIZE; - auto read_count = file.child_handle->Read(gzip_hdr, GZIP_HEADER_MINSIZE); - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, NumericCast(read_count)); - // Skip over the extra field if necessary - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - uint8_t gzip_xlen[2]; - file.child_handle->Seek(data_start); - file.child_handle->Read(gzip_xlen, 2); - auto xlen = NumericCast((uint8_t)gzip_xlen[0] | (uint8_t)gzip_xlen[1] << 8); - data_start += xlen + 2; - } - // Skip over the file name if necessary - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - file.child_handle->Seek(data_start); - data_start += GZipConsumeString(*file.child_handle); - } - file.child_handle->Seek(data_start); - // stream is now set to beginning of payload data - auto ret = duckdb_miniz::mz_inflateInit2(mz_stream_ptr.get(), -MZ_DEFAULT_WINDOW_BITS); - if (ret != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } -} - -bool MiniZStreamWrapper::Read(StreamData &sd) { - // Handling for the concatenated files - if (sd.refresh) { - auto available = (uint32_t)(sd.in_buff_end - sd.in_buff_start); - if (available <= GZIP_FOOTER_SIZE) { - // Only footer is available so we just close and return finished - Close(); - return true; - } - - sd.refresh = false; - auto body_ptr = sd.in_buff_start + GZIP_FOOTER_SIZE; - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); - body_ptr += GZIP_HEADER_MINSIZE; - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - auto xlen = NumericCast((uint8_t)*body_ptr | (uint8_t) * (body_ptr + 1) << 8); - body_ptr += xlen + 2; - if (GZIP_FOOTER_SIZE + GZIP_HEADER_MINSIZE 
+ 2 + xlen >= GZIP_HEADER_MAXSIZE) { - throw InternalException("Extra field resulting in GZIP header larger than defined maximum (%d)", - GZIP_HEADER_MAXSIZE); - } - } - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - char c; - do { - c = UnsafeNumericCast(*body_ptr); - body_ptr++; - } while (c != '\0' && body_ptr < sd.in_buff_end); - if ((idx_t)(body_ptr - sd.in_buff_start) >= GZIP_HEADER_MAXSIZE) { - throw InternalException("Filename resulting in GZIP header larger than defined maximum (%d)", - GZIP_HEADER_MAXSIZE); - } - } - sd.in_buff_start = body_ptr; - if (sd.in_buff_end - sd.in_buff_start < 1) { - Close(); - return true; - } - duckdb_miniz::mz_inflateEnd(mz_stream_ptr.get()); - auto sta = duckdb_miniz::mz_inflateInit2(mz_stream_ptr.get(), -MZ_DEFAULT_WINDOW_BITS); - if (sta != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - } - - // actually decompress - mz_stream_ptr->next_in = sd.in_buff_start; - D_ASSERT(sd.in_buff_end - sd.in_buff_start < NumericLimits::Maximum()); - mz_stream_ptr->avail_in = (uint32_t)(sd.in_buff_end - sd.in_buff_start); - mz_stream_ptr->next_out = data_ptr_cast(sd.out_buff_end); - mz_stream_ptr->avail_out = (uint32_t)((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_end); - auto ret = duckdb_miniz::mz_inflate(mz_stream_ptr.get(), duckdb_miniz::MZ_NO_FLUSH); - if (ret != duckdb_miniz::MZ_OK && ret != duckdb_miniz::MZ_STREAM_END) { - throw IOException("Failed to decode gzip stream: %s", duckdb_miniz::mz_error(ret)); - } - // update pointers following inflate() - sd.in_buff_start = (data_ptr_t)mz_stream_ptr->next_in; // NOLINT - sd.in_buff_end = sd.in_buff_start + mz_stream_ptr->avail_in; - sd.out_buff_end = data_ptr_cast(mz_stream_ptr->next_out); - D_ASSERT(sd.out_buff_end + mz_stream_ptr->avail_out == sd.out_buff.get() + sd.out_buf_size); - - // if stream ended, deallocate inflator - if (ret == duckdb_miniz::MZ_STREAM_END) { - // Concatenated GZIP potentially coming up - refresh input buffer - sd.refresh = true; - } - return false; -} - -void MiniZStreamWrapper::Write(CompressedFile &file, StreamData &sd, data_ptr_t uncompressed_data, - int64_t uncompressed_size) { - // update the src and the total size - crc = duckdb_miniz::mz_crc32(crc, reinterpret_cast(uncompressed_data), - UnsafeNumericCast(uncompressed_size)); - total_size += UnsafeNumericCast(uncompressed_size); - - auto remaining = uncompressed_size; - while (remaining > 0) { - auto output_remaining = UnsafeNumericCast((sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start); - - mz_stream_ptr->next_in = reinterpret_cast(uncompressed_data); - mz_stream_ptr->avail_in = NumericCast(remaining); - mz_stream_ptr->next_out = sd.out_buff_start; - mz_stream_ptr->avail_out = NumericCast(output_remaining); - - auto res = mz_deflate(mz_stream_ptr.get(), duckdb_miniz::MZ_NO_FLUSH); - if (res != duckdb_miniz::MZ_OK) { - D_ASSERT(res != duckdb_miniz::MZ_STREAM_END); - throw InternalException("Failed to compress GZIP block"); - } - sd.out_buff_start += output_remaining - mz_stream_ptr->avail_out; - if (mz_stream_ptr->avail_out == 0) { - // no more output buffer available: flush - file.child_handle->Write(sd.out_buff.get(), - UnsafeNumericCast(sd.out_buff_start - sd.out_buff.get())); - sd.out_buff_start = sd.out_buff.get(); - } - auto written = UnsafeNumericCast(remaining - mz_stream_ptr->avail_in); - uncompressed_data += written; - remaining = mz_stream_ptr->avail_in; - } -} - -void MiniZStreamWrapper::FlushStream() { - auto &sd = file->stream_data; - mz_stream_ptr->next_in = nullptr; - 
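// [Editorial note] What follows is the canonical zlib/miniz "finish" loop: with no
// further input (next_in = nullptr, avail_in = 0), deflate() is called with the
// FINISH flag repeatedly, draining the output buffer each round, until it reports
// STREAM_END. The same loop against plain zlib looks like this (disabled sketch;
// out_buf, fd and write_all are assumed helpers, not part of this file):
#if 0
int ret;
do {
	strm.next_out = out_buf;
	strm.avail_out = sizeof(out_buf);
	ret = deflate(&strm, Z_FINISH); // flush the remaining compressed bytes
	write_all(fd, out_buf, sizeof(out_buf) - strm.avail_out);
} while (ret != Z_STREAM_END);
#endif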
mz_stream_ptr->avail_in = 0; - while (true) { - auto output_remaining = (sd.out_buff.get() + sd.out_buf_size) - sd.out_buff_start; - mz_stream_ptr->next_out = sd.out_buff_start; - mz_stream_ptr->avail_out = NumericCast(output_remaining); - - auto res = mz_deflate(mz_stream_ptr.get(), duckdb_miniz::MZ_FINISH); - sd.out_buff_start += (output_remaining - mz_stream_ptr->avail_out); - if (sd.out_buff_start > sd.out_buff.get()) { - file->child_handle->Write(sd.out_buff.get(), - UnsafeNumericCast(sd.out_buff_start - sd.out_buff.get())); - sd.out_buff_start = sd.out_buff.get(); - } - if (res == duckdb_miniz::MZ_STREAM_END) { - break; - } - if (res != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to compress GZIP block"); - } - } -} - -void MiniZStreamWrapper::Close() { - if (!mz_stream_ptr) { - return; - } - if (writing) { - // flush anything remaining in the stream - FlushStream(); - - // write the footer - unsigned char gzip_footer[MiniZStream::GZIP_FOOTER_SIZE]; - MiniZStream::InitializeGZIPFooter(gzip_footer, crc, total_size); - file->child_handle->Write(gzip_footer, MiniZStream::GZIP_FOOTER_SIZE); - - duckdb_miniz::mz_deflateEnd(mz_stream_ptr.get()); - } else { - duckdb_miniz::mz_inflateEnd(mz_stream_ptr.get()); - } - mz_stream_ptr = nullptr; - file = nullptr; -} - -class GZipFile : public CompressedFile { -public: - GZipFile(unique_ptr child_handle_p, const string &path, bool write) - : CompressedFile(gzip_fs, std::move(child_handle_p), path) { - Initialize(write); - } - FileCompressionType GetFileCompressionType() override { - return FileCompressionType::GZIP; - } - GZipFileSystem gzip_fs; -}; - -void GZipFileSystem::VerifyGZIPHeader(uint8_t gzip_hdr[], idx_t read_count) { - // check for incorrectly formatted files - if (read_count != GZIP_HEADER_MINSIZE) { - throw IOException("Input is not a GZIP stream"); - } - if (gzip_hdr[0] != 0x1F || gzip_hdr[1] != 0x8B) { // magic header - throw IOException("Input is not a GZIP stream"); - } - if (gzip_hdr[2] != GZIP_COMPRESSION_DEFLATE) { // compression method - throw IOException("Unsupported GZIP compression method"); - } - if (gzip_hdr[3] & GZIP_FLAG_UNSUPPORTED) { - throw IOException("Unsupported GZIP archive"); - } -} - -bool GZipFileSystem::CheckIsZip(const char *data, duckdb::idx_t size) { - if (size < GZIP_HEADER_MINSIZE) { - return false; - } - - auto data_ptr = reinterpret_cast(data); - if (data_ptr[0] != 0x1F || data_ptr[1] != 0x8B) { - return false; - } - - if (data_ptr[2] != GZIP_COMPRESSION_DEFLATE) { - return false; - } - - return true; -} - -string GZipFileSystem::UncompressGZIPString(const string &in) { - return UncompressGZIPString(in.data(), in.size()); -} - -string GZipFileSystem::UncompressGZIPString(const char *data, idx_t size) { - // decompress file - auto body_ptr = data; - - auto mz_stream_ptr = make_uniq(); - memset(mz_stream_ptr.get(), 0, sizeof(duckdb_miniz::mz_stream)); - - uint8_t gzip_hdr[GZIP_HEADER_MINSIZE]; - - // check for incorrectly formatted files - - // TODO this is mostly the same as gzip_file_system.cpp - if (size < GZIP_HEADER_MINSIZE) { - throw IOException("Input is not a GZIP stream"); - } - memcpy(gzip_hdr, body_ptr, GZIP_HEADER_MINSIZE); - body_ptr += GZIP_HEADER_MINSIZE; - GZipFileSystem::VerifyGZIPHeader(gzip_hdr, GZIP_HEADER_MINSIZE); - - if (gzip_hdr[3] & GZIP_FLAG_EXTRA) { - throw IOException("Extra field in a GZIP stream unsupported"); - } - - if (gzip_hdr[3] & GZIP_FLAG_NAME) { - char c; - do { - c = *body_ptr; - body_ptr++; - } while (c != '\0' && (idx_t)(body_ptr - data) < size); - } 
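// [Editorial note] The negative window-bits value passed to mz_inflateInit2()
// below selects *raw* deflate: no zlib header or Adler-32 trailer is expected,
// which is exactly why the gzip wrapper (magic, flags, optional name) had to be
// skipped by hand above. With plain zlib the two choices look like this
// (disabled sketch; z_stream setup elided):
#if 0
inflateInit2(&strm, -15);     // raw deflate: the caller handles gzip framing itself
inflateInit2(&strm, 15 + 16); // windowBits + 16: zlib parses the gzip header/footer
#endif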
- - // stream is now set to beginning of payload data - auto status = duckdb_miniz::mz_inflateInit2(mz_stream_ptr.get(), -MZ_DEFAULT_WINDOW_BITS); - if (status != duckdb_miniz::MZ_OK) { - throw InternalException("Failed to initialize miniz"); - } - - auto bytes_remaining = size - NumericCast(body_ptr - data); - mz_stream_ptr->next_in = const_uchar_ptr_cast(body_ptr); - mz_stream_ptr->avail_in = NumericCast(bytes_remaining); - - unsigned char decompress_buffer[BUFSIZ]; - string decompressed; - - while (status == duckdb_miniz::MZ_OK) { - mz_stream_ptr->next_out = decompress_buffer; - mz_stream_ptr->avail_out = sizeof(decompress_buffer); - status = mz_inflate(mz_stream_ptr.get(), duckdb_miniz::MZ_NO_FLUSH); - if (status != duckdb_miniz::MZ_STREAM_END && status != duckdb_miniz::MZ_OK) { - throw IOException("Failed to uncompress"); - } - decompressed.append(char_ptr_cast(decompress_buffer), mz_stream_ptr->total_out - decompressed.size()); - } - duckdb_miniz::mz_inflateEnd(mz_stream_ptr.get()); - - if (decompressed.empty()) { - throw IOException("Failed to uncompress"); - } - return decompressed; -} - -unique_ptr GZipFileSystem::OpenCompressedFile(unique_ptr handle, bool write) { - auto path = handle->path; - return make_uniq(std::move(handle), path, write); -} - -unique_ptr GZipFileSystem::CreateStream() { - return make_uniq(); -} - -idx_t GZipFileSystem::InBufferSize() { - return BUFFER_SIZE; -} - -idx_t GZipFileSystem::OutBufferSize() { - return BUFFER_SIZE; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/hive_partitioning.cpp b/src/duckdb/src/common/hive_partitioning.cpp deleted file mode 100644 index c84d0505d..000000000 --- a/src/duckdb/src/common/hive_partitioning.cpp +++ /dev/null @@ -1,413 +0,0 @@ -#include "duckdb/common/hive_partitioning.hpp" - -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_columnref_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/common/multi_file_list.hpp" - -namespace duckdb { - -struct PartitioningColumnValue { - explicit PartitioningColumnValue(string value_p) : value(std::move(value_p)) { - } - PartitioningColumnValue(string key_p, string value_p) : key(std::move(key_p)), value(std::move(value_p)) { - } - - string key; - string value; -}; - -static unordered_map -GetKnownColumnValues(const string &filename, const HivePartitioningFilterInfo &filter_info) { - unordered_map result; - - auto &column_map = filter_info.column_map; - if (filter_info.filename_enabled) { - auto lookup_column_id = column_map.find("filename"); - if (lookup_column_id != column_map.end()) { - result.insert(make_pair(lookup_column_id->second, PartitioningColumnValue(filename))); - } - } - - if (filter_info.hive_enabled) { - auto partitions = HivePartitioning::Parse(filename); - for (auto &partition : partitions) { - auto lookup_column_id = column_map.find(partition.first); - if (lookup_column_id != column_map.end()) { - result.insert( - make_pair(lookup_column_id->second, PartitioningColumnValue(partition.first, partition.second))); - } - } - } - - return result; -} - -// Takes an expression and converts a list of known column_refs to constants -static void ConvertKnownColRefToConstants(ClientContext &context, unique_ptr &expr, - const unordered_map &known_column_values, - 
idx_t table_index) { - if (expr->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr->Cast(); - - // This bound column ref is for another table - if (table_index != bound_colref.binding.table_index) { - return; - } - - auto lookup = known_column_values.find(bound_colref.binding.column_index); - if (lookup != known_column_values.end()) { - auto &partition_val = lookup->second; - Value result_val; - if (partition_val.key.empty()) { - // filename column - use directly - result_val = Value(partition_val.value); - } else { - // hive partitioning column - cast the value to the target type - result_val = HivePartitioning::GetValue(context, partition_val.key, partition_val.value, - bound_colref.return_type); - } - expr = make_uniq(std::move(result_val)); - } - } else { - ExpressionIterator::EnumerateChildren(*expr, [&](unique_ptr &child) { - ConvertKnownColRefToConstants(context, child, known_column_values, table_index); - }); - } -} - -string HivePartitioning::Escape(const string &input) { - return StringUtil::URLEncode(input); -} - -string HivePartitioning::Unescape(const string &input) { - return StringUtil::URLDecode(input); -} - -// matches hive partitions in file name. For example: -// - s3://bucket/var1=value1/bla/bla/var2=value2 -// - http(s)://domain(:port)/lala/kasdl/var1=value1/?not-a-var=not-a-value -// - folder/folder/folder/../var1=value1/etc/.//var2=value2 -std::map HivePartitioning::Parse(const string &filename) { - idx_t partition_start = 0; - idx_t equality_sign = 0; - bool candidate_partition = true; - std::map result; - for (idx_t c = 0; c < filename.size(); c++) { - if (filename[c] == '?' || filename[c] == '\n') { - // get parameter or newline - not a partition - candidate_partition = false; - } - if (filename[c] == '\\' || filename[c] == '/') { - // separator - if (candidate_partition && equality_sign > partition_start) { - // we found a partition with an equality sign - string key = filename.substr(partition_start, equality_sign - partition_start); - string value = filename.substr(equality_sign + 1, c - equality_sign - 1); - result.insert(make_pair(std::move(key), std::move(value))); - } - partition_start = c + 1; - candidate_partition = true; - } else if (filename[c] == '=') { - if (equality_sign > partition_start) { - // multiple equality signs - not a partition - candidate_partition = false; - } - equality_sign = c; - } - } - return result; -} - -Value HivePartitioning::GetValue(ClientContext &context, const string &key, const string &str_val, - const LogicalType &type) { - // Handle nulls - if (StringUtil::CIEquals(str_val, "NULL")) { - return Value(type); - } - if (type.id() == LogicalTypeId::VARCHAR) { - // for string values we can directly return the type - return Value(Unescape(str_val)); - } - if (str_val.empty()) { - // empty strings are NULL for non-string types - return Value(type); - } - - // cast to the target type - Value value(Unescape(str_val)); - if (!value.TryCastAs(context, type)) { - throw InvalidInputException("Unable to cast '%s' (from hive partition column '%s') to: '%s'", value.ToString(), - StringUtil::Upper(key), type.ToString()); - } - return value; -} - -// TODO: this can still be improved by removing the parts of filter expressions that are true for all remaining files. -// currently, only expressions that cannot be evaluated during pushdown are removed. 
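// [Editorial note] For reference, Parse() above maps a path such as
// "s3://bucket/year=2024/month=06/part-0.parquet" to {year: "2024", month: "06"};
// ApplyFiltersToFileList() below substitutes those strings (cast via GetValue)
// into each pushed-down filter and drops files whose filter folds to false.
// Hypothetical usage sketch (disabled):
#if 0
auto parts = HivePartitioning::Parse("s3://bucket/year=2024/month=06/part-0.parquet");
// parts["year"] == "2024" && parts["month"] == "06"
// a pushed-down predicate such as  year = 2023  then folds to FALSE
// for this file, so the file is pruned from the scan list
#endif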
-void HivePartitioning::ApplyFiltersToFileList(ClientContext &context, vector &files, - vector> &filters, - const HivePartitioningFilterInfo &filter_info, - MultiFilePushdownInfo &info) { - - vector pruned_files; - vector have_preserved_filter(filters.size(), false); - vector> pruned_filters; - unordered_set filters_applied_to_files; - auto table_index = info.table_index; - - if ((!filter_info.filename_enabled && !filter_info.hive_enabled) || filters.empty()) { - return; - } - - for (idx_t i = 0; i < files.size(); i++) { - auto &file = files[i]; - bool should_prune_file = false; - auto known_values = GetKnownColumnValues(file, filter_info); - - for (idx_t j = 0; j < filters.size(); j++) { - auto &filter = filters[j]; - unique_ptr filter_copy = filter->Copy(); - ConvertKnownColRefToConstants(context, filter_copy, known_values, table_index); - // Evaluate the filter, if it can be evaluated here, we can not prune this filter - Value result_value; - - if (!filter_copy->IsScalar() || !filter_copy->IsFoldable() || - !ExpressionExecutor::TryEvaluateScalar(context, *filter_copy, result_value)) { - // can not be evaluated only with the filename/hive columns added, we can not prune this filter - if (!have_preserved_filter[j]) { - pruned_filters.emplace_back(filter->Copy()); - have_preserved_filter[j] = true; - } - } else if (result_value.IsNull() || !result_value.GetValue()) { - // filter evaluates to false - should_prune_file = true; - // convert the filter to a table filter. - if (filters_applied_to_files.find(j) == filters_applied_to_files.end()) { - info.extra_info.file_filters += filter->ToString(); - filters_applied_to_files.insert(j); - } - } - } - - if (!should_prune_file) { - pruned_files.push_back(file); - } - } - - D_ASSERT(filters.size() >= pruned_filters.size()); - - info.extra_info.total_files = files.size(); - info.extra_info.filtered_files = pruned_files.size(); - - filters = std::move(pruned_filters); - files = std::move(pruned_files); -} - -void HivePartitionedColumnData::InitializeKeys() { - keys.resize(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - keys[i].values.resize(group_by_columns.size()); - } -} - -template -static inline Value GetHiveKeyValue(const T &val) { - return Value::CreateValue(val); -} - -template -static inline Value GetHiveKeyValue(const T &val, const LogicalType &type) { - auto result = GetHiveKeyValue(val); - result.Reinterpret(type); - return result; -} - -static inline Value GetHiveKeyNullValue(const LogicalType &type) { - Value result; - result.Reinterpret(type); - return result; -} - -template -static void TemplatedGetHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - UnifiedVectorFormat format; - input.ToUnifiedFormat(count, format); - - const auto &sel = *format.sel; - const auto data = UnifiedVectorFormat::GetData(format); - const auto &validity = format.validity; - - const auto &type = input.GetType(); - - const auto reinterpret = Value::CreateValue(data[0]).GetTypeMutable() != type; - if (reinterpret) { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - const auto idx = sel.get_index(i); - if (validity.RowIsValid(idx)) { - key.values[col_idx] = GetHiveKeyValue(data[idx], type); - } else { - key.values[col_idx] = GetHiveKeyNullValue(type); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - const auto idx = sel.get_index(i); - if (validity.RowIsValid(idx)) { - key.values[col_idx] = GetHiveKeyValue(data[idx]); - } else { - 
key.values[col_idx] = GetHiveKeyNullValue(type); - } - } - } -} - -static void GetNestedHivePartitionValues(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - key.values[col_idx] = input.GetValue(i); - } -} - -static void GetHivePartitionValuesTypeSwitch(Vector &input, vector &keys, const idx_t col_idx, - const idx_t count) { - const auto &type = input.GetType(); - switch (type.InternalType()) { - case PhysicalType::BOOL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT32: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INT128: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT8: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT16: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT32: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT64: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::UINT128: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::FLOAT: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::DOUBLE: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::INTERVAL: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::VARCHAR: - TemplatedGetHivePartitionValues(input, keys, col_idx, count); - break; - case PhysicalType::STRUCT: - case PhysicalType::LIST: - GetNestedHivePartitionValues(input, keys, col_idx, count); - break; - default: - throw InternalException("Unsupported type for HivePartitionedColumnData::ComputePartitionIndices"); - } -} - -void HivePartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { - const auto count = input.size(); - - input.Hash(group_by_columns, hashes_v); - hashes_v.Flatten(count); - - for (idx_t col_idx = 0; col_idx < group_by_columns.size(); col_idx++) { - auto &group_by_col = input.data[group_by_columns[col_idx]]; - GetHivePartitionValuesTypeSwitch(group_by_col, keys, col_idx, count); - } - - const auto hashes = FlatVector::GetData(hashes_v); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - for (idx_t i = 0; i < count; i++) { - auto &key = keys[i]; - key.hash = hashes[i]; - auto lookup = local_partition_map.find(key); - if (lookup == local_partition_map.end()) { - idx_t new_partition_id = RegisterNewPartition(key, state); - partition_indices[i] = new_partition_id; - } else { - partition_indices[i] = lookup->second; - } - } -} - -std::map HivePartitionedColumnData::GetReverseMap() { - std::map ret; - for (const auto &pair : local_partition_map) { - ret[pair.second] = &(pair.first); - } - return ret; -} - -HivePartitionedColumnData::HivePartitionedColumnData(ClientContext &context, vector types, - vector partition_by_cols, - shared_ptr global_state) - : 
PartitionedColumnData(PartitionedColumnDataType::HIVE, context, std::move(types)), - global_state(std::move(global_state)), group_by_columns(std::move(partition_by_cols)), - hashes_v(LogicalType::HASH) { - InitializeKeys(); - CreateAllocator(); -} - -void HivePartitionedColumnData::AddNewPartition(HivePartitionKey key, idx_t partition_id, - PartitionedColumnDataAppendState &state) { - local_partition_map.emplace(std::move(key), partition_id); - - if (state.partition_append_states.size() <= partition_id) { - state.partition_append_states.resize(partition_id + 1); - state.partition_buffers.resize(partition_id + 1); - partitions.resize(partition_id + 1); - } - state.partition_append_states[partition_id] = make_uniq(); - state.partition_buffers[partition_id] = CreatePartitionBuffer(); - partitions[partition_id] = CreatePartitionCollection(0); - partitions[partition_id]->InitializeAppend(*state.partition_append_states[partition_id]); -} - -idx_t HivePartitionedColumnData::RegisterNewPartition(HivePartitionKey key, PartitionedColumnDataAppendState &state) { - idx_t partition_id; - if (global_state) { - // Synchronize Global state with our local state with the newly discovered partition - unique_lock lck_gstate(global_state->lock); - - // Insert into global map, or return partition if already present - auto res = global_state->partition_map.emplace(std::make_pair(key, global_state->partition_map.size())); - partition_id = res.first->second; - } else { - partition_id = local_partition_map.size(); - } - AddNewPartition(std::move(key), partition_id, state); - return partition_id; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/http_util.cpp b/src/duckdb/src/common/http_util.cpp deleted file mode 100644 index 71248367e..000000000 --- a/src/duckdb/src/common/http_util.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include "duckdb/common/http_util.hpp" - -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -void HTTPUtil::ParseHTTPProxyHost(string &proxy_value, string &hostname_out, idx_t &port_out, idx_t default_port) { - auto sanitized_proxy_value = proxy_value; - if (StringUtil::StartsWith(proxy_value, "http://")) { - sanitized_proxy_value = proxy_value.substr(7); - } - auto proxy_split = StringUtil::Split(sanitized_proxy_value, ":"); - if (proxy_split.size() == 1) { - hostname_out = proxy_split[0]; - port_out = default_port; - } else if (proxy_split.size() == 2) { - idx_t port; - if (!TryCast::Operation(proxy_split[1], port, false)) { - throw InvalidInputException("Failed to parse port from http_proxy '%s'", proxy_value); - } - hostname_out = proxy_split[0]; - port_out = port; - } else { - throw InvalidInputException("Failed to parse http_proxy '%s' into a host and port", proxy_value); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/local_file_system.cpp b/src/duckdb/src/common/local_file_system.cpp deleted file mode 100644 index 83e242526..000000000 --- a/src/duckdb/src/common/local_file_system.cpp +++ /dev/null @@ -1,1451 +0,0 @@ -#include "duckdb/common/local_file_system.hpp" - -#include "duckdb/common/checksum.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/file_opener.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/windows.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" - -#include -#include -#include - -#ifndef _WIN32 -#include 
-#include -#include -#include -#include -#else -#include "duckdb/common/windows_util.hpp" - -#include -#include - -#ifdef __MINGW32__ -// need to manually define this for mingw -extern "C" WINBASEAPI BOOL WINAPI GetPhysicallyInstalledSystemMemory(PULONGLONG); -extern "C" WINBASEAPI BOOL QueryFullProcessImageNameW(HANDLE, DWORD, LPWSTR, PDWORD); -#endif - -#undef FILE_CREATE // woo mingw -#endif - -// includes for giving a better error message on lock conflicts -#if defined(__linux__) || defined(__APPLE__) -#include -#endif - -#if defined(__linux__) -// See https://man7.org/linux/man-pages/man2/fallocate.2.html -#ifndef _GNU_SOURCE -#define _GNU_SOURCE /* See feature_test_macros(7) */ -#endif -#include -#include -// See e.g.: -// https://opensource.apple.com/source/CarbonHeaders/CarbonHeaders-18.1/TargetConditionals.h.auto.html -#elif defined(__APPLE__) -#include -#if not(defined(TARGET_OS_IPHONE) && TARGET_OS_IPHONE == 1) -#include -#endif -#elif defined(_WIN32) -#include -#endif - -namespace duckdb { -#ifndef _WIN32 -bool LocalFileSystem::FileExists(const string &filename, optional_ptr opener) { - if (!filename.empty()) { - auto normalized_file = NormalizeLocalPath(filename); - if (access(normalized_file, 0) == 0) { - struct stat status; - stat(normalized_file, &status); - if (S_ISREG(status.st_mode)) { - return true; - } - } - } - // if any condition fails - return false; -} - -bool LocalFileSystem::IsPipe(const string &filename, optional_ptr opener) { - if (!filename.empty()) { - auto normalized_file = NormalizeLocalPath(filename); - if (access(normalized_file, 0) == 0) { - struct stat status; - stat(normalized_file, &status); - if (S_ISFIFO(status.st_mode)) { - return true; - } - } - } - // if any condition fails - return false; -} - -#else -static std::wstring NormalizePathAndConvertToUnicode(const string &path) { - string normalized_path_copy; - const char *normalized_path; - if (StringUtil::StartsWith(path, "file:/")) { - normalized_path_copy = LocalFileSystem::NormalizeLocalPath(path); - normalized_path_copy = LocalFileSystem().ConvertSeparators(normalized_path_copy); - normalized_path = normalized_path_copy.c_str(); - } else { - normalized_path = path.c_str(); - } - return WindowsUtil::UTF8ToUnicode(normalized_path); -} - -bool LocalFileSystem::FileExists(const string &filename, optional_ptr opener) { - auto unicode_path = NormalizePathAndConvertToUnicode(filename); - const wchar_t *wpath = unicode_path.c_str(); - if (_waccess(wpath, 0) == 0) { - struct _stati64 status; - _wstati64(wpath, &status); - if (status.st_mode & S_IFREG) { - return true; - } - } - return false; -} -bool LocalFileSystem::IsPipe(const string &filename, optional_ptr opener) { - auto unicode_path = NormalizePathAndConvertToUnicode(filename); - const wchar_t *wpath = unicode_path.c_str(); - if (_waccess(wpath, 0) == 0) { - struct _stati64 status; - _wstati64(wpath, &status); - if (status.st_mode & _S_IFCHR) { - return true; - } - } - return false; -} -#endif - -#ifndef _WIN32 -// somehow sometimes this is missing -#ifndef O_CLOEXEC -#define O_CLOEXEC 0 -#endif - -// Solaris -#ifndef O_DIRECT -#define O_DIRECT 0 -#endif - -struct UnixFileHandle : public FileHandle { -public: - UnixFileHandle(FileSystem &file_system, string path, int fd, FileOpenFlags flags) - : FileHandle(file_system, std::move(path), flags), fd(fd) { - } - ~UnixFileHandle() override { - UnixFileHandle::Close(); - } - - int fd; - -public: - void Close() override { - if (fd != -1) { - close(fd); - fd = -1; - } - }; -}; - -static FileType 
GetFileTypeInternal(int fd) { // LCOV_EXCL_START - struct stat s; - if (fstat(fd, &s) == -1) { - return FileType::FILE_TYPE_INVALID; - } - switch (s.st_mode & S_IFMT) { - case S_IFBLK: - return FileType::FILE_TYPE_BLOCKDEV; - case S_IFCHR: - return FileType::FILE_TYPE_CHARDEV; - case S_IFIFO: - return FileType::FILE_TYPE_FIFO; - case S_IFDIR: - return FileType::FILE_TYPE_DIR; - case S_IFLNK: - return FileType::FILE_TYPE_LINK; - case S_IFREG: - return FileType::FILE_TYPE_REGULAR; - case S_IFSOCK: - return FileType::FILE_TYPE_SOCKET; - default: - return FileType::FILE_TYPE_INVALID; - } -} // LCOV_EXCL_STOP - -#if __APPLE__ && !TARGET_OS_IPHONE - -static string AdditionalProcessInfo(FileSystem &fs, pid_t pid) { - if (pid == getpid()) { - return "Lock is already held in current process, likely another DuckDB instance"; - } - - string process_name, process_owner; - // macOS >= 10.7 has PROC_PIDT_SHORTBSDINFO -#ifdef PROC_PIDT_SHORTBSDINFO - // try to find out more about the process holding the lock - struct proc_bsdshortinfo proc; - if (proc_pidinfo(pid, PROC_PIDT_SHORTBSDINFO, 0, &proc, PROC_PIDT_SHORTBSDINFO_SIZE) == - PROC_PIDT_SHORTBSDINFO_SIZE) { - process_name = proc.pbsi_comm; // only a short version however, let's take it in case proc_pidpath() below fails - // try to get actual name of conflicting process owner - auto pw = getpwuid(proc.pbsi_uid); - if (pw) { - process_owner = pw->pw_name; - } - } -#else - return string(); -#endif - // try to get a better process name (full path) - char full_exec_path[PROC_PIDPATHINFO_MAXSIZE]; - if (proc_pidpath(pid, full_exec_path, PROC_PIDPATHINFO_MAXSIZE) > 0) { - // we managed to get the full executable path - prefer it over the short name - process_name = full_exec_path; - } - return StringUtil::Format("Conflicting lock is held in %s%s", - !process_name.empty() ? StringUtil::Format("%s (PID %d)", process_name, pid) - : StringUtil::Format("PID %d", pid), - !process_owner.empty() ? StringUtil::Format(" by user %s", process_owner) : ""); -} - -#elif __linux__ - -static string AdditionalProcessInfo(FileSystem &fs, pid_t pid) { - if (pid == getpid()) { - return "Lock is already held in current process, likely another DuckDB instance"; - } - string process_name, process_owner; - - try { - auto cmdline_file = fs.OpenFile(StringUtil::Format("/proc/%d/cmdline", pid), FileFlags::FILE_FLAGS_READ); - auto cmdline = cmdline_file->ReadLine(); - process_name = basename(const_cast<char *>(cmdline.c_str())); // NOLINT: old C API does not take const - } catch (std::exception &) { - // ignore - } - - // we would like to provide a full path to the executable if possible but we might not have rights - { - char exe_target[PATH_MAX]; - memset(exe_target, '\0', PATH_MAX); - auto proc_exe_link = StringUtil::Format("/proc/%d/exe", pid); - auto readlink_n = readlink(proc_exe_link.c_str(), exe_target, PATH_MAX); - if (readlink_n > 0) { - process_name = exe_target; - } - } - - // try to find out who created that process - try { - auto loginuid_file = fs.OpenFile(StringUtil::Format("/proc/%d/loginuid", pid), FileFlags::FILE_FLAGS_READ); - auto uid = std::stoi(loginuid_file->ReadLine()); - auto pw = getpwuid(uid); - if (pw) { - process_owner = pw->pw_name; - } - } catch (std::exception &) { - // ignore - } - - return StringUtil::Format("Conflicting lock is held in %s%s", - !process_name.empty() ? StringUtil::Format("%s (PID %d)", process_name, pid) - : StringUtil::Format("PID %d", pid), - !process_owner.empty() ?
StringUtil::Format(" by user %s", process_owner) : ""); -} - -#else -static string AdditionalProcessInfo(FileSystem &fs, pid_t pid) { - return ""; -} -#endif - -bool LocalFileSystem::IsPrivateFile(const string &path_p, FileOpener *opener) { - auto path = FileSystem::ExpandPath(path_p, opener); - auto normalized_path = NormalizeLocalPath(path); - - struct stat st; - - if (lstat(normalized_path, &st) != 0) { - throw IOException( - "Failed to stat '%s' when checking file permissions, file may be missing or have incorrect permissions", - path.c_str()); - } - - // If group or other have any permission, the file is not private - if (st.st_mode & (S_IRGRP | S_IWGRP | S_IXGRP | S_IROTH | S_IWOTH | S_IXOTH)) { - return false; - } - - return true; -} - -unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenFlags flags, - optional_ptr opener) { - auto path = FileSystem::ExpandPath(path_p, opener); - auto normalized_path = NormalizeLocalPath(path); - if (flags.Compression() != FileCompressionType::UNCOMPRESSED) { - throw NotImplementedException("Unsupported compression type for default file system"); - } - - flags.Verify(); - - int open_flags = 0; - int rc; - bool open_read = flags.OpenForReading(); - bool open_write = flags.OpenForWriting(); - if (open_read && open_write) { - open_flags = O_RDWR; - } else if (open_read) { - open_flags = O_RDONLY; - } else if (open_write) { - open_flags = O_WRONLY; - } else { - throw InternalException("READ, WRITE or both should be specified when opening a file"); - } - if (open_write) { - // need Read or Write - D_ASSERT(flags.OpenForWriting()); - open_flags |= O_CLOEXEC; - if (flags.CreateFileIfNotExists()) { - open_flags |= O_CREAT; - } else if (flags.OverwriteExistingFile()) { - open_flags |= O_CREAT | O_TRUNC; - } - if (flags.OpenForAppending()) { - open_flags |= O_APPEND; - } - } - if (flags.DirectIO()) { -#if defined(__sun) && defined(__SVR4) - throw InvalidInputException("DIRECT_IO not supported on Solaris"); -#endif -#if defined(__DARWIN__) || defined(__APPLE__) || defined(__OpenBSD__) - // OSX does not have O_DIRECT, instead we need to use fcntl afterwards to support direct IO -#else - open_flags |= O_DIRECT; -#endif - } - - // Determine permissions - mode_t filesec; - if (flags.CreatePrivateFile()) { - open_flags |= O_EXCL; // Ensure we error on existing files or the permissions may not set - filesec = 0600; - } else { - filesec = 0666; - } - - if (flags.ExclusiveCreate()) { - open_flags |= O_EXCL; - } - - // Open the file - int fd = open(normalized_path, open_flags, filesec); - - if (fd == -1) { - if (flags.ReturnNullIfNotExists() && errno == ENOENT) { - return nullptr; - } - if (flags.ReturnNullIfExists() && errno == EEXIST) { - return nullptr; - } - throw IOException("Cannot open file \"%s\": %s", {{"errno", std::to_string(errno)}}, path, strerror(errno)); - } - -#if defined(__DARWIN__) || defined(__APPLE__) - if (flags.DirectIO()) { - // OSX requires fcntl for Direct IO - rc = fcntl(fd, F_NOCACHE, 1); - if (rc == -1) { - throw IOException("Could not enable direct IO for file \"%s\": %s", path, strerror(errno)); - } - } -#endif - - if (flags.Lock() != FileLockType::NO_LOCK) { - // set lock on file - // but only if it is not an input/output stream - auto file_type = GetFileTypeInternal(fd); - if (file_type != FileType::FILE_TYPE_FIFO && file_type != FileType::FILE_TYPE_SOCKET) { - struct flock fl; - memset(&fl, 0, sizeof fl); - fl.l_type = flags.Lock() == FileLockType::READ_LOCK ? 
F_RDLCK : F_WRLCK; - fl.l_whence = SEEK_SET; - fl.l_start = 0; - fl.l_len = 0; - rc = fcntl(fd, F_SETLK, &fl); - // Retain the original error. - int retained_errno = errno; - bool has_error = rc == -1; - string extended_error; - if (has_error) { - if (retained_errno == ENOTSUP) { - // file lock not supported for this file system - if (flags.Lock() == FileLockType::READ_LOCK) { - // for read-only, we ignore not-supported errors - has_error = false; - errno = 0; - } else { - extended_error = "File locks are not supported for this file system, cannot open the file in " - "read-write mode. Try opening the file in read-only mode"; - } - } - } - if (has_error) { - if (extended_error.empty()) { - // try to find out who is holding the lock using F_GETLK - rc = fcntl(fd, F_GETLK, &fl); - if (rc == -1) { // fnctl does not want to help us - extended_error = strerror(errno); - } else { - extended_error = AdditionalProcessInfo(*this, fl.l_pid); - } - if (flags.Lock() == FileLockType::WRITE_LOCK) { - // maybe we can get a read lock instead and tell this to the user. - fl.l_type = F_RDLCK; - rc = fcntl(fd, F_SETLK, &fl); - if (rc != -1) { // success! - extended_error += - ". However, you would be able to open this database in read-only mode, e.g. by " - "using the -readonly parameter in the CLI"; - } - } - } - rc = close(fd); - if (rc == -1) { - extended_error += ". Also, failed closing file"; - } - extended_error += ". See also https://duckdb.org/docs/connect/concurrency"; - throw IOException("Could not set lock on file \"%s\": %s", {{"errno", std::to_string(retained_errno)}}, - path, extended_error); - } - } - } - return make_uniq(*this, path, fd, flags); -} - -void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { - int fd = handle.Cast().fd; - off_t offset = lseek(fd, UnsafeNumericCast(location), SEEK_SET); - if (offset == (off_t)-1) { - throw IOException("Could not seek to location %lld for file \"%s\": %s", {{"errno", std::to_string(errno)}}, - location, handle.path, strerror(errno)); - } -} - -idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { - int fd = handle.Cast().fd; - off_t position = lseek(fd, 0, SEEK_CUR); - if (position == (off_t)-1) { - throw IOException("Could not get file position file \"%s\": %s", {{"errno", std::to_string(errno)}}, - handle.path, strerror(errno)); - } - return UnsafeNumericCast(position); -} - -void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - int fd = handle.Cast().fd; - auto read_buffer = char_ptr_cast(buffer); - while (nr_bytes > 0) { - int64_t bytes_read = - pread(fd, read_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); - if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } - if (bytes_read == 0) { - throw IOException( - "Could not read enough bytes from file \"%s\": attempted to read %llu bytes from location %llu", - handle.path, nr_bytes, location); - } - read_buffer += bytes_read; - nr_bytes -= bytes_read; - location += UnsafeNumericCast(bytes_read); - } -} - -int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - int fd = handle.Cast().fd; - int64_t bytes_read = read(fd, buffer, UnsafeNumericCast(nr_bytes)); - if (bytes_read == -1) { - throw IOException("Could not read from file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } - return bytes_read; -} - -void LocalFileSystem::Write(FileHandle 
&handle, void *buffer, int64_t nr_bytes, idx_t location) { - int fd = handle.Cast().fd; - auto write_buffer = char_ptr_cast(buffer); - while (nr_bytes > 0) { - int64_t bytes_written = - pwrite(fd, write_buffer, UnsafeNumericCast(nr_bytes), UnsafeNumericCast(location)); - if (bytes_written < 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } - if (bytes_written == 0) { - throw IOException("Could not write to file \"%s\" - attempted to write 0 bytes: %s", - {{"errno", std::to_string(errno)}}, handle.path, strerror(errno)); - } - write_buffer += bytes_written; - nr_bytes -= bytes_written; - location += UnsafeNumericCast(bytes_written); - } -} - -int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - int fd = handle.Cast().fd; - int64_t bytes_written = 0; - while (nr_bytes > 0) { - auto bytes_to_write = MinValue(idx_t(NumericLimits::Maximum()), idx_t(nr_bytes)); - int64_t current_bytes_written = write(fd, buffer, bytes_to_write); - if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } - bytes_written += current_bytes_written; - buffer = (void *)(data_ptr_cast(buffer) + current_bytes_written); - nr_bytes -= current_bytes_written; - } - return bytes_written; -} - -bool LocalFileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) { -#if defined(__linux__) - // FALLOC_FL_PUNCH_HOLE requires glibc 2.18 or up -#if __GLIBC__ < 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ < 18) - return false; -#else - int fd = handle.Cast().fd; - int res = fallocate(fd, FALLOC_FL_PUNCH_HOLE | FALLOC_FL_KEEP_SIZE, UnsafeNumericCast(offset_bytes), - UnsafeNumericCast(length_bytes)); - return res == 0; -#endif -#else - return false; -#endif -} - -int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - throw IOException("Failed to get file size for file \"%s\": %s", {{"errno", std::to_string(errno)}}, - handle.path, strerror(errno)); - } - return s.st_size; -} - -time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - int fd = handle.Cast().fd; - struct stat s; - if (fstat(fd, &s) == -1) { - throw IOException("Failed to get last modified time for file \"%s\": %s", {{"errno", std::to_string(errno)}}, - handle.path, strerror(errno)); - } - return s.st_mtime; -} - -FileType LocalFileSystem::GetFileType(FileHandle &handle) { - int fd = handle.Cast().fd; - return GetFileTypeInternal(fd); -} - -void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - int fd = handle.Cast().fd; - if (ftruncate(fd, new_size) != 0) { - throw IOException("Could not truncate file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } -} - -bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { - if (!directory.empty()) { - auto normalized_dir = NormalizeLocalPath(directory); - if (access(normalized_dir, 0) == 0) { - struct stat status; - stat(normalized_dir, &status); - if (status.st_mode & S_IFDIR) { - return true; - } - } - } - // if any condition fails - return false; -} - -void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr opener) { - struct stat st; - - auto normalized_dir = NormalizeLocalPath(directory); - if (stat(normalized_dir, &st) != 0) { - /* Directory does not exist. 
EEXIST for race condition */ - if (mkdir(normalized_dir, 0755) != 0 && errno != EEXIST) { - throw IOException("Failed to create directory \"%s\": %s", {{"errno", std::to_string(errno)}}, directory, - strerror(errno)); - } - } else if (!S_ISDIR(st.st_mode)) { - throw IOException("Failed to create directory \"%s\": path exists but is not a directory!", - {{"errno", std::to_string(errno)}}, directory); - } -} - -int RemoveDirectoryRecursive(const char *path) { - DIR *d = opendir(path); - idx_t path_len = (idx_t)strlen(path); - int r = -1; - - if (d) { - struct dirent *p; - r = 0; - while (!r && (p = readdir(d))) { - int r2 = -1; - char *buf; - idx_t len; - /* Skip the names "." and ".." as we don't want to recurse on them. */ - if (!strcmp(p->d_name, ".") || !strcmp(p->d_name, "..")) { - continue; - } - len = path_len + (idx_t)strlen(p->d_name) + 2; - buf = new (std::nothrow) char[len]; - if (buf) { - struct stat statbuf; - snprintf(buf, len, "%s/%s", path, p->d_name); - if (!stat(buf, &statbuf)) { - if (S_ISDIR(statbuf.st_mode)) { - r2 = RemoveDirectoryRecursive(buf); - } else { - r2 = unlink(buf); - } - } - delete[] buf; - } - r = r2; - } - closedir(d); - } - if (!r) { - r = rmdir(path); - } - return r; -} - -void LocalFileSystem::RemoveDirectory(const string &directory, optional_ptr opener) { - auto normalized_dir = NormalizeLocalPath(directory); - RemoveDirectoryRecursive(normalized_dir); -} - -void LocalFileSystem::RemoveFile(const string &filename, optional_ptr opener) { - auto normalized_file = NormalizeLocalPath(filename); - if (std::remove(normalized_file) != 0) { - throw IOException("Could not remove file \"%s\": %s", {{"errno", std::to_string(errno)}}, filename, - strerror(errno)); - } -} - -bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - auto normalized_dir = NormalizeLocalPath(directory); - auto dir = opendir(normalized_dir); - if (!dir) { - return false; - } - - // RAII wrapper around DIR to automatically free on exceptions in callback - std::unique_ptr> dir_unique_ptr(dir, [](DIR *d) { closedir(d); }); - - struct dirent *ent; - // loop over all files in the directory - while ((ent = readdir(dir)) != nullptr) { - string name = string(ent->d_name); - // skip . .. and empty files - if (name.empty() || name == "." || name == "..") { - continue; - } - // now stat the file to figure out if it is a regular file or directory - string full_path = JoinPath(normalized_dir, name); - struct stat status; - auto res = stat(full_path.c_str(), &status); - if (res != 0) { - continue; - } - if (!(status.st_mode & S_IFREG) && !(status.st_mode & S_IFDIR)) { - // not a file or directory: skip - continue; - } - // invoke callback - callback(name, status.st_mode & S_IFDIR); - } - - return true; -} - -void LocalFileSystem::FileSync(FileHandle &handle) { - int fd = handle.Cast().fd; - if (fsync(fd) != 0) { - throw FatalException("fsync failed!"); - } -} - -void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { - auto normalized_source = NormalizeLocalPath(source); - auto normalized_target = NormalizeLocalPath(target); - //! 
FIXME: rename does not guarantee atomicity or overwriting target file if it exists - if (rename(normalized_source, normalized_target) != 0) { - throw IOException("Could not rename file!", {{"errno", std::to_string(errno)}}); - } -} - -std::string LocalFileSystem::GetLastErrorAsString() { - return string(); -} - -#else - -constexpr char PIPE_PREFIX[] = "\\\\.\\pipe\\"; - -// Returns the last Win32 error, in string format. Returns an empty string if there is no error. -std::string LocalFileSystem::GetLastErrorAsString() { - // Get the error message, if any. - DWORD errorMessageID = GetLastError(); - if (errorMessageID == 0) - return std::string(); // No error message has been recorded - - LPSTR messageBuffer = nullptr; - idx_t size = - FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, - NULL, errorMessageID, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); - - std::string message(messageBuffer, size); - - // Free the buffer. - LocalFree(messageBuffer); - - return message; -} - -struct WindowsFileHandle : public FileHandle { -public: - WindowsFileHandle(FileSystem &file_system, string path, HANDLE fd, FileOpenFlags flags) - : FileHandle(file_system, path, flags), position(0), fd(fd) { - } - ~WindowsFileHandle() override { - Close(); - } - - idx_t position; - HANDLE fd; - -public: - void Close() override { - if (!fd) { - return; - } - CloseHandle(fd); - fd = nullptr; - }; -}; - -static string AdditionalLockInfo(const std::wstring path) { - // try to find out if another process is holding the lock - - // init of the somewhat obscure "Windows Restart Manager" - // see also https://devblogs.microsoft.com/oldnewthing/20120217-00/?p=8283 - - DWORD session, status, reason; - WCHAR session_key[CCH_RM_SESSION_KEY + 1] = {0}; - - status = RmStartSession(&session, 0, session_key); - if (status != ERROR_SUCCESS) { - return ""; - } - - PCWSTR path_ptr = path.c_str(); - status = RmRegisterResources(session, 1, &path_ptr, 0, NULL, 0, NULL); - if (status != ERROR_SUCCESS) { - return ""; - } - UINT process_info_size_needed, process_info_size; - - // we first call with nProcInfo = 0 to find out how much to allocate - process_info_size = 0; - status = RmGetList(session, &process_info_size_needed, &process_info_size, NULL, &reason); - if (status != ERROR_MORE_DATA || process_info_size_needed == 0) { - return ""; - } - - // allocate - auto process_info_buffer = duckdb::unique_ptr(new RM_PROCESS_INFO[process_info_size_needed]); - auto process_info = process_info_buffer.get(); - - // now call again to get actual data - process_info_size = process_info_size_needed; - status = RmGetList(session, &process_info_size_needed, &process_info_size, process_info, &reason); - if (status != ERROR_SUCCESS || process_info_size == 0) { - return ""; - } - - string conflict_string = "File is already open in "; - - for (UINT process_idx = 0; process_idx < process_info_size; process_idx++) { - string process_name = WindowsUtil::UnicodeToUTF8(process_info[process_idx].strAppName); - auto pid = process_info[process_idx].Process.dwProcessId; - - // find out full path if possible - HANDLE process = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, FALSE, pid); - if (process) { - WCHAR full_path[MAX_PATH]; - DWORD full_path_size = MAX_PATH; - if (QueryFullProcessImageNameW(process, 0, full_path, &full_path_size) && full_path_size <= MAX_PATH) { - process_name = WindowsUtil::UnicodeToUTF8(full_path); - } - CloseHandle(process); - } - conflict_string += 
StringUtil::Format("\n%s (PID %d)", process_name, pid); - } - - RmEndSession(session); - return conflict_string; -} - -bool LocalFileSystem::IsPrivateFile(const string &path_p, FileOpener *opener) { - // TODO: detect if file is shared in windows - return true; -} - -unique_ptr LocalFileSystem::OpenFile(const string &path_p, FileOpenFlags flags, - optional_ptr opener) { - auto path = FileSystem::ExpandPath(path_p, opener); - auto unicode_path = NormalizePathAndConvertToUnicode(path); - if (flags.Compression() != FileCompressionType::UNCOMPRESSED) { - throw NotImplementedException("Unsupported compression type for default file system"); - } - flags.Verify(); - - DWORD desired_access; - DWORD share_mode; - DWORD creation_disposition = OPEN_EXISTING; - DWORD flags_and_attributes = FILE_ATTRIBUTE_NORMAL; - bool open_read = flags.OpenForReading(); - bool open_write = flags.OpenForWriting(); - if (open_read && open_write) { - desired_access = GENERIC_READ | GENERIC_WRITE; - share_mode = 0; - } else if (open_read) { - desired_access = GENERIC_READ; - share_mode = FILE_SHARE_READ; - } else if (open_write) { - desired_access = GENERIC_WRITE; - share_mode = 0; - } else { - throw InternalException("READ, WRITE or both should be specified when opening a file"); - } - if (open_write) { - if (flags.CreateFileIfNotExists()) { - creation_disposition = OPEN_ALWAYS; - } else if (flags.OverwriteExistingFile()) { - creation_disposition = CREATE_ALWAYS; - } - } - if (flags.DirectIO()) { - flags_and_attributes |= FILE_FLAG_NO_BUFFERING; - } - HANDLE hFile = CreateFileW(unicode_path.c_str(), desired_access, share_mode, NULL, creation_disposition, - flags_and_attributes, NULL); - if (hFile == INVALID_HANDLE_VALUE) { - if (flags.ReturnNullIfNotExists() && GetLastError() == ERROR_FILE_NOT_FOUND) { - return nullptr; - } - auto error = LocalFileSystem::GetLastErrorAsString(); - - auto better_error = AdditionalLockInfo(unicode_path); - if (!better_error.empty()) { - throw IOException(better_error); - } else { - throw IOException("Cannot open file \"%s\": %s", path.c_str(), error); - } - } - auto handle = make_uniq(*this, path.c_str(), hFile, flags); - if (flags.OpenForAppending()) { - auto file_size = GetFileSize(*handle); - SetFilePointer(*handle, file_size); - } - return std::move(handle); -} - -void LocalFileSystem::SetFilePointer(FileHandle &handle, idx_t location) { - auto &whandle = handle.Cast(); - whandle.position = location; - LARGE_INTEGER wlocation; - wlocation.QuadPart = location; - SetFilePointerEx(whandle.fd, wlocation, NULL, FILE_BEGIN); -} - -idx_t LocalFileSystem::GetFilePointer(FileHandle &handle) { - return handle.Cast().position; -} - -static DWORD FSInternalRead(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { - DWORD bytes_read = 0; - OVERLAPPED ov = {}; - ov.Internal = 0; - ov.InternalHigh = 0; - ov.Offset = location & 0xFFFFFFFF; - ov.OffsetHigh = location >> 32; - ov.hEvent = 0; - auto rc = ReadFile(hFile, buffer, (DWORD)nr_bytes, &bytes_read, &ov); - if (!rc) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Could not read file \"%s\" (error in ReadFile(location: %llu, nr_bytes: %lld)): %s", - handle.path, location, nr_bytes, error); - } - return bytes_read; -} - -void LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - HANDLE hFile = ((WindowsFileHandle &)handle).fd; - auto bytes_read = FSInternalRead(handle, hFile, buffer, nr_bytes, location); - if (bytes_read != nr_bytes) { - throw 
IOException("Could not read all bytes from file \"%s\": wanted=%lld read=%lld", handle.path, nr_bytes, - bytes_read); - } -} - -int64_t LocalFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - HANDLE hFile = handle.Cast().fd; - auto &pos = handle.Cast().position; - auto n = std::min(std::max(GetFileSize(handle), pos) - pos, nr_bytes); - auto bytes_read = FSInternalRead(handle, hFile, buffer, n, pos); - pos += bytes_read; - return bytes_read; -} - -static DWORD FSInternalWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { - DWORD bytes_written = 0; - OVERLAPPED ov = {}; - ov.Internal = 0; - ov.InternalHigh = 0; - ov.Offset = location & 0xFFFFFFFF; - ov.OffsetHigh = location >> 32; - ov.hEvent = 0; - auto rc = WriteFile(hFile, buffer, (DWORD)nr_bytes, &bytes_written, &ov); - if (!rc) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Could not write file \"%s\" (error in WriteFile): %s", handle.path, error); - } - return bytes_written; -} - -static int64_t FSWrite(FileHandle &handle, HANDLE hFile, void *buffer, int64_t nr_bytes, idx_t location) { - int64_t bytes_written = 0; - while (nr_bytes > 0) { - auto bytes_to_write = MinValue(idx_t(NumericLimits::Maximum()), idx_t(nr_bytes)); - DWORD current_bytes_written = FSInternalWrite(handle, hFile, buffer, bytes_to_write, location); - if (current_bytes_written <= 0) { - throw IOException("Could not write file \"%s\": %s", {{"errno", std::to_string(errno)}}, handle.path, - strerror(errno)); - } - bytes_written += current_bytes_written; - buffer = (void *)(data_ptr_cast(buffer) + current_bytes_written); - location += current_bytes_written; - nr_bytes -= current_bytes_written; - } - return bytes_written; -} - -void LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - HANDLE hFile = handle.Cast().fd; - auto bytes_written = FSWrite(handle, hFile, buffer, nr_bytes, location); - if (bytes_written != nr_bytes) { - throw IOException("Could not write all bytes from file \"%s\": wanted=%lld wrote=%lld", handle.path, nr_bytes, - bytes_written); - } -} - -int64_t LocalFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - HANDLE hFile = handle.Cast().fd; - auto &pos = handle.Cast().position; - auto bytes_written = FSWrite(handle, hFile, buffer, nr_bytes, pos); - pos += bytes_written; - return bytes_written; -} - -bool LocalFileSystem::Trim(FileHandle &handle, idx_t offset_bytes, idx_t length_bytes) { - // TODO: Not yet implemented on windows. 
- return false; -} - -int64_t LocalFileSystem::GetFileSize(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - LARGE_INTEGER result; - if (!GetFileSizeEx(hFile, &result)) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to get file size for file \"%s\": %s", handle.path, error); - } - return result.QuadPart; -} - -time_t LocalFileSystem::GetLastModifiedTime(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - - // https://docs.microsoft.com/en-us/windows/win32/api/fileapi/nf-fileapi-getfiletime - FILETIME last_write; - if (GetFileTime(hFile, nullptr, nullptr, &last_write) == 0) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to get last modified time for file \"%s\": %s", handle.path, error); - } - - // https://stackoverflow.com/questions/29266743/what-is-dwlowdatetime-and-dwhighdatetime - ULARGE_INTEGER ul; - ul.LowPart = last_write.dwLowDateTime; - ul.HighPart = last_write.dwHighDateTime; - int64_t fileTime64 = ul.QuadPart; - - // fileTime64 contains a 64-bit value representing the number of - // 100-nanosecond intervals since January 1, 1601 (UTC). - // https://docs.microsoft.com/en-us/windows/win32/api/minwinbase/ns-minwinbase-filetime - - // Adapted from: https://stackoverflow.com/questions/6161776/convert-windows-filetime-to-second-in-unix-linux - const auto WINDOWS_TICK = 10000000; - const auto SEC_TO_UNIX_EPOCH = 11644473600LL; - time_t result = (fileTime64 / WINDOWS_TICK - SEC_TO_UNIX_EPOCH); - return result; -} - -void LocalFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - HANDLE hFile = handle.Cast().fd; - // seek to the location - SetFilePointer(handle, new_size); - // now set the end of file position - if (!SetEndOfFile(hFile)) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failure in SetEndOfFile call on file \"%s\": %s", handle.path, error); - } -} - -static DWORD WindowsGetFileAttributes(const string &filename) { - auto unicode_path = NormalizePathAndConvertToUnicode(filename); - return GetFileAttributesW(unicode_path.c_str()); -} - -static DWORD WindowsGetFileAttributes(const std::wstring &filename) { - return GetFileAttributesW(filename.c_str()); -} - -bool LocalFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { - DWORD attrs = WindowsGetFileAttributes(directory); - return (attrs != INVALID_FILE_ATTRIBUTES && (attrs & FILE_ATTRIBUTE_DIRECTORY)); -} - -void LocalFileSystem::CreateDirectory(const string &directory, optional_ptr opener) { - if (DirectoryExists(directory)) { - return; - } - auto unicode_path = NormalizePathAndConvertToUnicode(directory); - if (directory.empty() || !CreateDirectoryW(unicode_path.c_str(), NULL) || !DirectoryExists(directory)) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to create directory \"%s\": %s", directory.c_str(), error); - } -} - -static void DeleteDirectoryRecursive(FileSystem &fs, string directory) { - fs.ListFiles(directory, [&](const string &fname, bool is_directory) { - if (is_directory) { - DeleteDirectoryRecursive(fs, fs.JoinPath(directory, fname)); - } else { - fs.RemoveFile(fs.JoinPath(directory, fname)); - } - }); - auto unicode_path = NormalizePathAndConvertToUnicode(directory); - if (!RemoveDirectoryW(unicode_path.c_str())) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to delete directory \"%s\": %s", directory, error); - } -} - -void LocalFileSystem::RemoveDirectory(const string 
&directory, optional_ptr opener) { - if (FileExists(directory)) { - throw IOException("Attempting to delete directory \"%s\", but it is a file and not a directory!", directory); - } - if (!DirectoryExists(directory)) { - return; - } - DeleteDirectoryRecursive(*this, directory); -} - -void LocalFileSystem::RemoveFile(const string &filename, optional_ptr opener) { - auto unicode_path = NormalizePathAndConvertToUnicode(filename); - if (!DeleteFileW(unicode_path.c_str())) { - auto error = LocalFileSystem::GetLastErrorAsString(); - throw IOException("Failed to delete file \"%s\": %s", filename, error); - } -} - -bool LocalFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - string search_dir = JoinPath(directory, "*"); - auto unicode_path = NormalizePathAndConvertToUnicode(search_dir); - - WIN32_FIND_DATAW ffd; - HANDLE hFind = FindFirstFileW(unicode_path.c_str(), &ffd); - if (hFind == INVALID_HANDLE_VALUE) { - return false; - } - do { - string cFileName = WindowsUtil::UnicodeToUTF8(ffd.cFileName); - if (cFileName == "." || cFileName == "..") { - continue; - } - callback(cFileName, ffd.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY); - } while (FindNextFileW(hFind, &ffd) != 0); - - DWORD dwError = GetLastError(); - if (dwError != ERROR_NO_MORE_FILES) { - FindClose(hFind); - return false; - } - - FindClose(hFind); - return true; -} - -void LocalFileSystem::FileSync(FileHandle &handle) { - HANDLE hFile = handle.Cast().fd; - if (FlushFileBuffers(hFile) == 0) { - throw IOException("Could not flush file handle to disk!"); - } -} - -void LocalFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { - auto source_unicode = NormalizePathAndConvertToUnicode(source); - auto target_unicode = NormalizePathAndConvertToUnicode(target); - - if (!MoveFileW(source_unicode.c_str(), target_unicode.c_str())) { - throw IOException("Could not move file: %s", GetLastErrorAsString()); - } -} - -FileType LocalFileSystem::GetFileType(FileHandle &handle) { - auto path = handle.Cast().path; - // pipes in windows are just files in '\\.\pipe\' folder - if (strncmp(path.c_str(), PIPE_PREFIX, strlen(PIPE_PREFIX)) == 0) { - return FileType::FILE_TYPE_FIFO; - } - auto normalized_path = NormalizePathAndConvertToUnicode(path); - DWORD attrs = WindowsGetFileAttributes(normalized_path); - if (attrs != INVALID_FILE_ATTRIBUTES) { - if (attrs & FILE_ATTRIBUTE_DIRECTORY) { - return FileType::FILE_TYPE_DIR; - } else { - return FileType::FILE_TYPE_REGULAR; - } - } - return FileType::FILE_TYPE_INVALID; -} -#endif - -bool LocalFileSystem::CanSeek() { - return true; -} - -bool LocalFileSystem::OnDiskFile(FileHandle &handle) { - return true; -} - -void LocalFileSystem::Seek(FileHandle &handle, idx_t location) { - if (!CanSeek()) { - throw IOException("Cannot seek in files of this type"); - } - SetFilePointer(handle, location); -} - -idx_t LocalFileSystem::SeekPosition(FileHandle &handle) { - if (!CanSeek()) { - throw IOException("Cannot seek in files of this type"); - } - return GetFilePointer(handle); -} - -static bool IsCrawl(const string &glob) { - // glob must match exactly - return glob == "**"; -} -static bool HasMultipleCrawl(const vector &splits) { - return std::count(splits.begin(), splits.end(), "**") > 1; -} -static bool IsSymbolicLink(const string &path) { - auto normalized_path = LocalFileSystem::NormalizeLocalPath(path); -#ifndef _WIN32 - struct stat status; - return (lstat(normalized_path, &status) != -1 && S_ISLNK(status.st_mode)); -#else - auto 
attributes = WindowsGetFileAttributes(path); - if (attributes == INVALID_FILE_ATTRIBUTES) - return false; - return attributes & FILE_ATTRIBUTE_REPARSE_POINT; -#endif -} - -static void RecursiveGlobDirectories(FileSystem &fs, const string &path, vector &result, bool match_directory, - bool join_path) { - - fs.ListFiles(path, [&](const string &fname, bool is_directory) { - string concat; - if (join_path) { - concat = fs.JoinPath(path, fname); - } else { - concat = fname; - } - if (IsSymbolicLink(concat)) { - return; - } - if (is_directory == match_directory) { - result.push_back(concat); - } - if (is_directory) { - RecursiveGlobDirectories(fs, concat, result, match_directory, true); - } - }); -} - -static void GlobFilesInternal(FileSystem &fs, const string &path, const string &glob, bool match_directory, - vector &result, bool join_path) { - fs.ListFiles(path, [&](const string &fname, bool is_directory) { - if (is_directory != match_directory) { - return; - } - if (Glob(fname.c_str(), fname.size(), glob.c_str(), glob.size())) { - if (join_path) { - result.push_back(fs.JoinPath(path, fname)); - } else { - result.push_back(fname); - } - } - }); -} - -vector LocalFileSystem::FetchFileWithoutGlob(const string &path, FileOpener *opener, bool absolute_path) { - vector result; - if (FileExists(path, opener) || IsPipe(path, opener)) { - result.push_back(path); - } else if (!absolute_path) { - Value value; - if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { - auto search_paths_str = value.ToString(); - vector search_paths = StringUtil::Split(search_paths_str, ','); - for (const auto &search_path : search_paths) { - auto joined_path = JoinPath(search_path, path); - if (FileExists(joined_path, opener) || IsPipe(joined_path, opener)) { - result.push_back(joined_path); - } - } - } - } - return result; -} - -// Helper function to handle file:/ URLs -static idx_t GetFileUrlOffset(const string &path) { - if (!StringUtil::StartsWith(path, "file:/")) { - return 0; - } - - // Url without host: file:/some/path - if (path[6] != '/') { -#ifdef _WIN32 - return 6; -#else - return 5; -#endif - } - - // Url with empty host: file:///some/path - if (path[7] == '/') { -#ifdef _WIN32 - return 8; -#else - return 7; -#endif - } - - // Url with localhost: file://localhost/some/path - if (path.compare(7, 10, "localhost/") == 0) { -#ifdef _WIN32 - return 17; -#else - return 16; -#endif - } - - // unkown file:/ url format - return 0; -} - -const char *LocalFileSystem::NormalizeLocalPath(const string &path) { - return path.c_str() + GetFileUrlOffset(path); -} - -vector LocalFileSystem::Glob(const string &path, FileOpener *opener) { - if (path.empty()) { - return vector(); - } - // split up the path into separate chunks - vector splits; - - bool is_file_url = StringUtil::StartsWith(path, "file:/"); - idx_t file_url_path_offset = GetFileUrlOffset(path); - - idx_t last_pos = 0; - for (idx_t i = file_url_path_offset; i < path.size(); i++) { - if (path[i] == '\\' || path[i] == '/') { - if (i == last_pos) { - // empty: skip this position - last_pos = i + 1; - continue; - } - if (splits.empty()) { - // splits.push_back(path.substr(file_url_path_offset, i-file_url_path_offset)); - splits.push_back(path.substr(0, i)); - } else { - splits.push_back(path.substr(last_pos, i - last_pos)); - } - last_pos = i + 1; - } - } - splits.push_back(path.substr(last_pos, path.size() - last_pos)); - // handle absolute paths - bool absolute_path = false; - if (IsPathAbsolute(path)) { - // first character is a slash - unix absolute 
path - absolute_path = true; - } else if (StringUtil::Contains(splits[0], ":")) { // TODO: this is weird? shouldn't IsPathAbsolute handle this? - // first split has a colon - windows absolute path - absolute_path = true; - } else if (splits[0] == "~") { - // starts with home directory - auto home_directory = GetHomeDirectory(opener); - if (!home_directory.empty()) { - absolute_path = true; - splits[0] = home_directory; - D_ASSERT(path[0] == '~'); - if (!HasGlob(path)) { - return Glob(home_directory + path.substr(1)); - } - } - } - // Check if the path has a glob at all - if (!HasGlob(path)) { - // no glob: return only the file (if it exists or is a pipe) - return FetchFileWithoutGlob(path, opener, absolute_path); - } - vector previous_directories; - if (absolute_path) { - // for absolute paths, we don't start by scanning the current directory - previous_directories.push_back(splits[0]); - } else { - // If file_search_path is set, use those paths as the first glob elements - Value value; - if (opener && opener->TryGetCurrentSetting("file_search_path", value)) { - auto search_paths_str = value.ToString(); - vector search_paths = StringUtil::Split(search_paths_str, ','); - for (const auto &search_path : search_paths) { - previous_directories.push_back(search_path); - } - } - } - - if (HasMultipleCrawl(splits)) { - throw IOException("Cannot use multiple \'**\' in one path"); - } - - idx_t start_index; - if (is_file_url) { - start_index = 1; - } else if (absolute_path) { - start_index = 1; - } else { - start_index = 0; - } - - for (idx_t i = start_index ? 1 : 0; i < splits.size(); i++) { - bool is_last_chunk = i + 1 == splits.size(); - bool has_glob = HasGlob(splits[i]); - // if it's the last chunk we need to find files, otherwise we find directories - // not the last chunk: gather a list of all directories that match the glob pattern - vector result; - if (!has_glob) { - // no glob, just append as-is - if (previous_directories.empty()) { - result.push_back(splits[i]); - } else { - if (is_last_chunk) { - for (auto &prev_directory : previous_directories) { - const string filename = JoinPath(prev_directory, splits[i]); - if (FileExists(filename, opener) || DirectoryExists(filename, opener)) { - result.push_back(filename); - } - } - } else { - for (auto &prev_directory : previous_directories) { - result.push_back(JoinPath(prev_directory, splits[i])); - } - } - } - } else { - if (IsCrawl(splits[i])) { - if (!is_last_chunk) { - result = previous_directories; - } - if (previous_directories.empty()) { - RecursiveGlobDirectories(*this, ".", result, !is_last_chunk, false); - } else { - for (auto &prev_dir : previous_directories) { - RecursiveGlobDirectories(*this, prev_dir, result, !is_last_chunk, true); - } - } - } else { - if (previous_directories.empty()) { - // no previous directories: list in the current path - GlobFilesInternal(*this, ".", splits[i], !is_last_chunk, result, false); - } else { - // previous directories - // we iterate over each of the previous directories, and apply the glob of the current directory - for (auto &prev_directory : previous_directories) { - GlobFilesInternal(*this, prev_directory, splits[i], !is_last_chunk, result, true); - } - } - } - } - if (result.empty()) { - // no result found that matches the glob - // last ditch effort: search the path as a string literal - return FetchFileWithoutGlob(path, opener, absolute_path); - } - if (is_last_chunk) { - return result; - } - previous_directories = std::move(result); - } - return vector(); -} - -unique_ptr 
FileSystem::CreateLocal() { - return make_uniq(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/multi_file_list.cpp b/src/duckdb/src/common/multi_file_list.cpp deleted file mode 100644 index 668a5b363..000000000 --- a/src/duckdb/src/common/multi_file_list.cpp +++ /dev/null @@ -1,379 +0,0 @@ -#include "duckdb/common/multi_file_reader.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/common/string_util.hpp" - -#include - -namespace duckdb { - -MultiFilePushdownInfo::MultiFilePushdownInfo(LogicalGet &get) - : table_index(get.table_index), column_names(get.names), column_indexes(get.GetColumnIds()), - extra_info(get.extra_info) { - for (auto &col_id : column_indexes) { - column_ids.push_back(col_id.GetPrimaryIndex()); - } -} - -MultiFilePushdownInfo::MultiFilePushdownInfo(idx_t table_index, const vector &column_names, - const vector &column_ids, ExtraOperatorInfo &extra_info) - : table_index(table_index), column_names(column_names), column_ids(column_ids), extra_info(extra_info) { -} - -// Helper method to do Filter Pushdown into a MultiFileList -bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &options, MultiFilePushdownInfo &info, - vector> &filters, vector &expanded_files) { - HivePartitioningFilterInfo filter_info; - for (idx_t i = 0; i < info.column_ids.size(); i++) { - if (!IsRowIdColumnId(info.column_ids[i])) { - filter_info.column_map.insert({info.column_names[info.column_ids[i]], i}); - } - } - filter_info.hive_enabled = options.hive_partitioning; - filter_info.filename_enabled = options.filename; - - auto start_files = expanded_files.size(); - HivePartitioning::ApplyFiltersToFileList(context, expanded_files, filters, filter_info, info); - - if (expanded_files.size() != start_files) { - return true; - } - - return false; -} - -bool PushdownInternal(ClientContext &context, const MultiFileReaderOptions &options, const vector &names, - const vector &types, const vector &column_ids, - const TableFilterSet &filters, vector &expanded_files) { - idx_t table_index = 0; - ExtraOperatorInfo extra_info; - - // construct the pushdown info - MultiFilePushdownInfo info(table_index, names, column_ids, extra_info); - - // construct the set of expressions from the table filters - vector> filter_expressions; - for (auto &entry : filters.filters) { - auto column_idx = column_ids[entry.first]; - auto column_ref = - make_uniq(types[column_idx], ColumnBinding(table_index, entry.first)); - auto filter_expr = entry.second->ToExpression(*column_ref); - filter_expressions.push_back(std::move(filter_expr)); - } - - // call the original PushdownInternal method - return PushdownInternal(context, options, info, filter_expressions, expanded_files); -} - -//===--------------------------------------------------------------------===// -// MultiFileListIterator -//===--------------------------------------------------------------------===// -MultiFileListIterationHelper MultiFileList::Files() { - return MultiFileListIterationHelper(*this); -} - -MultiFileListIterationHelper::MultiFileListIterationHelper(MultiFileList &file_list_p) : file_list(file_list_p) { -} - -MultiFileListIterationHelper::MultiFileListIterator::MultiFileListIterator(MultiFileList *file_list_p) - : file_list(file_list_p) { - if 
(!file_list) { - return; - } - - file_list->InitializeScan(file_scan_data); - if (!file_list->Scan(file_scan_data, current_file)) { - // There is no first file: move iterator to nop state - file_list = nullptr; - file_scan_data.current_file_idx = DConstants::INVALID_INDEX; - } -} - -void MultiFileListIterationHelper::MultiFileListIterator::Next() { - if (!file_list) { - return; - } - - if (!file_list->Scan(file_scan_data, current_file)) { - // exhausted collection: move iterator to nop state - file_list = nullptr; - file_scan_data.current_file_idx = DConstants::INVALID_INDEX; - } -} - -MultiFileListIterationHelper::MultiFileListIterator MultiFileListIterationHelper::begin() { // NOLINT: match stl API - return MultiFileListIterationHelper::MultiFileListIterator( - file_list.GetExpandResult() == FileExpandResult::NO_FILES ? nullptr : &file_list); -} -MultiFileListIterationHelper::MultiFileListIterator MultiFileListIterationHelper::end() { // NOLINT: match stl API - return MultiFileListIterationHelper::MultiFileListIterator(nullptr); -} - -MultiFileListIterationHelper::MultiFileListIterator &MultiFileListIterationHelper::MultiFileListIterator::operator++() { - Next(); - return *this; -} - -bool MultiFileListIterationHelper::MultiFileListIterator::operator!=(const MultiFileListIterator &other) const { - return file_list != other.file_list || file_scan_data.current_file_idx != other.file_scan_data.current_file_idx; -} - -const string &MultiFileListIterationHelper::MultiFileListIterator::operator*() const { - return current_file; -} - -//===--------------------------------------------------------------------===// -// MultiFileList -//===--------------------------------------------------------------------===// -MultiFileList::MultiFileList(vector paths, FileGlobOptions options) - : paths(std::move(paths)), glob_options(options) { -} - -MultiFileList::~MultiFileList() { -} - -const vector MultiFileList::GetPaths() const { - return paths; -} - -void MultiFileList::InitializeScan(MultiFileListScanData &iterator) { - iterator.current_file_idx = 0; -} - -bool MultiFileList::Scan(MultiFileListScanData &iterator, string &result_file) { - D_ASSERT(iterator.current_file_idx != DConstants::INVALID_INDEX); - auto maybe_file = GetFile(iterator.current_file_idx); - - if (maybe_file.empty()) { - D_ASSERT(iterator.current_file_idx >= GetTotalFileCount()); - return false; - } - - result_file = maybe_file; - iterator.current_file_idx++; - return true; -} - -unique_ptr MultiFileList::ComplexFilterPushdown(ClientContext &context, - const MultiFileReaderOptions &options, - MultiFilePushdownInfo &info, - vector> &filters) { - // By default the filter pushdown into a multifilelist does nothing - return nullptr; -} - -unique_ptr -MultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, - const vector &names, const vector &types, - const vector &column_ids, TableFilterSet &filters) const { - // By default the filter pushdown into a multifilelist does nothing - return nullptr; -} - -unique_ptr MultiFileList::GetCardinality(ClientContext &context) { - return nullptr; -} - -string MultiFileList::GetFirstFile() { - return GetFile(0); -} - -bool MultiFileList::IsEmpty() { - return GetExpandResult() == FileExpandResult::NO_FILES; -} - -//===--------------------------------------------------------------------===// -// SimpleMultiFileList -//===--------------------------------------------------------------------===// -SimpleMultiFileList::SimpleMultiFileList(vector paths_p) - : 
MultiFileList(std::move(paths_p), FileGlobOptions::ALLOW_EMPTY) { -} - -unique_ptr SimpleMultiFileList::ComplexFilterPushdown(ClientContext &context_p, - const MultiFileReaderOptions &options, - MultiFilePushdownInfo &info, - vector> &filters) { - if (!options.hive_partitioning && !options.filename) { - return nullptr; - } - - // FIXME: don't copy list until first file is filtered - auto file_copy = paths; - auto res = PushdownInternal(context_p, options, info, filters, file_copy); - - if (res) { - return make_uniq(file_copy); - } - - return nullptr; -} - -unique_ptr -SimpleMultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, - const vector &names, const vector &types, - const vector &column_ids, TableFilterSet &filters) const { - if (!options.hive_partitioning && !options.filename) { - return nullptr; - } - - // FIXME: don't copy list until first file is filtered - auto file_copy = paths; - auto res = PushdownInternal(context, options, names, types, column_ids, filters, file_copy); - if (res) { - return make_uniq(file_copy); - } - - return nullptr; -} - -vector SimpleMultiFileList::GetAllFiles() { - return paths; -} - -FileExpandResult SimpleMultiFileList::GetExpandResult() { - if (paths.size() > 1) { - return FileExpandResult::MULTIPLE_FILES; - } else if (paths.size() == 1) { - return FileExpandResult::SINGLE_FILE; - } - - return FileExpandResult::NO_FILES; -} - -string SimpleMultiFileList::GetFile(idx_t i) { - if (paths.empty() || i >= paths.size()) { - return ""; - } - - return paths[i]; -} - -idx_t SimpleMultiFileList::GetTotalFileCount() { - return paths.size(); -} - -//===--------------------------------------------------------------------===// -// GlobMultiFileList -//===--------------------------------------------------------------------===// -GlobMultiFileList::GlobMultiFileList(ClientContext &context_p, vector paths_p, FileGlobOptions options) - : MultiFileList(std::move(paths_p), options), context(context_p), current_path(0) { -} - -unique_ptr GlobMultiFileList::ComplexFilterPushdown(ClientContext &context_p, - const MultiFileReaderOptions &options, - MultiFilePushdownInfo &info, - vector> &filters) { - lock_guard lck(lock); - - // Expand all - // FIXME: lazy expansion - // FIXME: push down filters into glob - while (ExpandNextPath()) { - } - - if (!options.hive_partitioning && !options.filename) { - return nullptr; - } - auto res = PushdownInternal(context, options, info, filters, expanded_files); - if (res) { - return make_uniq(expanded_files); - } - - return nullptr; -} - -unique_ptr -GlobMultiFileList::DynamicFilterPushdown(ClientContext &context, const MultiFileReaderOptions &options, - const vector &names, const vector &types, - const vector &column_ids, TableFilterSet &filters) const { - if (!options.hive_partitioning && !options.filename) { - return nullptr; - } - lock_guard lck(lock); - - // Expand all paths into a copy - // FIXME: lazy expansion and push filters into glob - idx_t path_index = current_path; - auto file_list = expanded_files; - while (ExpandPathInternal(path_index, file_list)) { - } - - auto res = PushdownInternal(context, options, names, types, column_ids, filters, file_list); - if (res) { - return make_uniq(file_list); - } - - return nullptr; -} - -vector GlobMultiFileList::GetAllFiles() { - lock_guard lck(lock); - while (ExpandNextPath()) { - } - return expanded_files; -} - -idx_t GlobMultiFileList::GetTotalFileCount() { - lock_guard lck(lock); - while (ExpandNextPath()) { - } - return 
expanded_files.size(); -} - -FileExpandResult GlobMultiFileList::GetExpandResult() { - // GetFile(1) will ensure at least the first 2 files are expanded if they are available - GetFile(1); - - if (expanded_files.size() > 1) { - return FileExpandResult::MULTIPLE_FILES; - } else if (expanded_files.size() == 1) { - return FileExpandResult::SINGLE_FILE; - } - - return FileExpandResult::NO_FILES; -} - -string GlobMultiFileList::GetFile(idx_t i) { - lock_guard lck(lock); - return GetFileInternal(i); -} - -string GlobMultiFileList::GetFileInternal(idx_t i) { - while (expanded_files.size() <= i) { - if (!ExpandNextPath()) { - return ""; - } - } - D_ASSERT(expanded_files.size() > i); - return expanded_files[i]; -} - -bool GlobMultiFileList::ExpandPathInternal(idx_t ¤t_path, vector &result) const { - if (current_path >= paths.size()) { - return false; - } - - auto &fs = FileSystem::GetFileSystem(context); - auto glob_files = fs.GlobFiles(paths[current_path], context, glob_options); - std::sort(glob_files.begin(), glob_files.end()); - result.insert(result.end(), glob_files.begin(), glob_files.end()); - - current_path++; - return true; -} - -bool GlobMultiFileList::ExpandNextPath() { - return ExpandPathInternal(current_path, expanded_files); -} - -bool GlobMultiFileList::IsFullyExpanded() const { - return current_path == paths.size(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/multi_file_reader.cpp b/src/duckdb/src/common/multi_file_reader.cpp deleted file mode 100644 index 97eba1014..000000000 --- a/src/duckdb/src/common/multi_file_reader.cpp +++ /dev/null @@ -1,618 +0,0 @@ -#include "duckdb/common/multi_file_reader.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/planner/expression/bound_columnref_expression.hpp" -#include "duckdb/common/string_util.hpp" - -#include - -namespace duckdb { - -MultiFileReaderGlobalState::~MultiFileReaderGlobalState() { -} - -MultiFileReader::~MultiFileReader() { -} - -unique_ptr MultiFileReader::Create(const TableFunction &table_function) { - unique_ptr res; - if (table_function.get_multi_file_reader) { - res = table_function.get_multi_file_reader(table_function); - res->function_name = table_function.name; - } else { - res = make_uniq(); - res->function_name = table_function.name; - } - return res; -} - -unique_ptr MultiFileReader::CreateDefault(const string &function_name) { - auto res = make_uniq(); - res->function_name = function_name; - return res; -} - -Value MultiFileReader::CreateValueFromFileList(const vector &file_list) { - vector files; - for (auto &file : file_list) { - files.push_back(file); - } - return Value::LIST(LogicalType::VARCHAR, std::move(files)); -} - -void MultiFileReader::AddParameters(TableFunction &table_function) { - table_function.named_parameters["filename"] = LogicalType::ANY; - table_function.named_parameters["hive_partitioning"] = LogicalType::BOOLEAN; - table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; - table_function.named_parameters["hive_types"] = LogicalType::ANY; - table_function.named_parameters["hive_types_autocast"] = LogicalType::BOOLEAN; -} - -vector MultiFileReader::ParsePaths(const Value &input) { - if (input.IsNull()) { - throw ParserException("%s cannot take NULL list as parameter", function_name); - } - - if 
diff --git a/src/duckdb/src/common/multi_file_reader.cpp b/src/duckdb/src/common/multi_file_reader.cpp deleted file mode 100644 index 97eba1014..000000000 --- a/src/duckdb/src/common/multi_file_reader.cpp +++ /dev/null @@ -1,618 +0,0 @@ -#include "duckdb/common/multi_file_reader.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/planner/expression/bound_columnref_expression.hpp" -#include "duckdb/common/string_util.hpp" - -#include <algorithm> - -namespace duckdb { - -MultiFileReaderGlobalState::~MultiFileReaderGlobalState() { -} - -MultiFileReader::~MultiFileReader() { -} - -unique_ptr<MultiFileReader> MultiFileReader::Create(const TableFunction &table_function) { - unique_ptr<MultiFileReader> res; - if (table_function.get_multi_file_reader) { - res = table_function.get_multi_file_reader(table_function); - res->function_name = table_function.name; - } else { - res = make_uniq<MultiFileReader>(); - res->function_name = table_function.name; - } - return res; -} - -unique_ptr<MultiFileReader> MultiFileReader::CreateDefault(const string &function_name) { - auto res = make_uniq<MultiFileReader>(); - res->function_name = function_name; - return res; -} - -Value MultiFileReader::CreateValueFromFileList(const vector<string> &file_list) { - vector<Value> files; - for (auto &file : file_list) { - files.push_back(file); - } - return Value::LIST(LogicalType::VARCHAR, std::move(files)); -} - -void MultiFileReader::AddParameters(TableFunction &table_function) { - table_function.named_parameters["filename"] = LogicalType::ANY; - table_function.named_parameters["hive_partitioning"] = LogicalType::BOOLEAN; - table_function.named_parameters["union_by_name"] = LogicalType::BOOLEAN; - table_function.named_parameters["hive_types"] = LogicalType::ANY; - table_function.named_parameters["hive_types_autocast"] = LogicalType::BOOLEAN; -} - -vector<string> MultiFileReader::ParsePaths(const Value &input) { - if (input.IsNull()) { - throw ParserException("%s cannot take NULL list as parameter", function_name); - } - - if (input.type().id() == LogicalTypeId::VARCHAR) { - return {StringValue::Get(input)}; - } else if (input.type().id() == LogicalTypeId::LIST) { - vector<string> paths; - for (auto &val : ListValue::GetChildren(input)) { - if (val.IsNull()) { - throw ParserException("%s reader cannot take NULL input as parameter", function_name); - } - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw ParserException("%s reader can only take a list of strings as a parameter", function_name); - } - paths.push_back(StringValue::Get(val)); - } - return paths; - } else { - throw InternalException("Unsupported type for MultiFileReader::ParsePaths called with: '%s'", - input.type().ToString()); - } -} - -shared_ptr<MultiFileList> MultiFileReader::CreateFileList(ClientContext &context, const vector<string> &paths, - FileGlobOptions options) { - vector<string> result_files; - - auto res = make_uniq<GlobMultiFileList>(context, paths, options); - if (res->GetExpandResult() == FileExpandResult::NO_FILES && options == FileGlobOptions::DISALLOW_EMPTY) { - throw IOException("%s needs at least one file to read", function_name); - } - return std::move(res); -} - -shared_ptr<MultiFileList> MultiFileReader::CreateFileList(ClientContext &context, const Value &input, - FileGlobOptions options) { - auto paths = ParsePaths(input); - return CreateFileList(context, paths, options); -} - -bool MultiFileReader::ParseOption(const string &key, const Value &val, MultiFileReaderOptions &options, - ClientContext &context) { - auto loption = StringUtil::Lower(key); - if (loption == "filename") { - if (val.type() == LogicalType::VARCHAR) { - // a VARCHAR value is interpreted as the name of the column containing the filename - options.filename = true; - options.filename_column = StringValue::Get(val); - } else { - Value boolean_value; - string error_message; - if (val.DefaultTryCastAs(LogicalType::BOOLEAN, boolean_value, &error_message)) { - // If the argument can be cast to boolean, we just interpret it as a boolean - options.filename = BooleanValue::Get(boolean_value); - } - } - } else if (loption == "hive_partitioning") { - options.hive_partitioning = BooleanValue::Get(val); - options.auto_detect_hive_partitioning = false; - } else if (loption == "union_by_name") { - options.union_by_name = BooleanValue::Get(val); - } else if (loption == "hive_types_autocast" || loption == "hive_type_autocast") { - options.hive_types_autocast = BooleanValue::Get(val); - } else if (loption == "hive_types" || loption == "hive_type") { - if (val.type().id() != LogicalTypeId::STRUCT) { - throw InvalidInputException( - "'hive_types' only accepts a STRUCT('name':VARCHAR, ...), but '%s' was provided", - val.type().ToString()); - } - // verify that all the children of the struct value are VARCHAR - auto &children = StructValue::GetChildren(val); - for (idx_t i = 0; i < children.size(); i++) { - const Value &child = children[i]; - if (child.type().id() != LogicalTypeId::VARCHAR) { - throw InvalidInputException("hive_types: '%s' must be a VARCHAR, instead: '%s' was provided", - StructType::GetChildName(val.type(), i), child.type().ToString()); - } - // for every child of the struct, get the logical type - LogicalType transformed_type = TransformStringToLogicalType(child.ToString(), context); - const string &name = StructType::GetChildName(val.type(), i); - options.hive_types_schema[name] = transformed_type; - } - D_ASSERT(!options.hive_types_schema.empty()); - } else { - return false; - } - return true; -}
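A short sketch of driving the option parsing above directly, mirroring what the binder does for a call like read_parquet(..., hive_types = {'year': 'BIGINT'}); the struct children must be VARCHAR type names, which ParseOption resolves through TransformStringToLogicalType. The database/connection harness is illustrative only:

#include "duckdb.hpp"
#include "duckdb/common/multi_file_reader.hpp"

using namespace duckdb;

int main() {
	DuckDB db(nullptr);
	Connection con(db);
	auto &context = *con.context;

	auto reader = MultiFileReader::CreateDefault("read_parquet");
	MultiFileReaderOptions options;

	// hive_types := {'year': 'BIGINT', 'month': 'BIGINT'}
	child_list_t<Value> hive_types;
	hive_types.emplace_back("year", Value("BIGINT"));
	hive_types.emplace_back("month", Value("BIGINT"));
	reader->ParseOption("hive_types", Value::STRUCT(std::move(hive_types)), options, context);

	// an explicit hive_partitioning flag also disables auto-detection
	reader->ParseOption("hive_partitioning", Value::BOOLEAN(true), options, context);

	// options.hive_types_schema now maps "year" and "month" to BIGINT
	return 0;
}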
-unique_ptr<MultiFileList> MultiFileReader::ComplexFilterPushdown(ClientContext &context, MultiFileList &files, - const MultiFileReaderOptions &options, - MultiFilePushdownInfo &info, - vector<unique_ptr<Expression>> &filters) { - return files.ComplexFilterPushdown(context, options, info, filters); -} - -unique_ptr<MultiFileList> MultiFileReader::DynamicFilterPushdown(ClientContext &context, const MultiFileList &files, - const MultiFileReaderOptions &options, - const vector<string> &names, - const vector<LogicalType> &types, - const vector<column_t> &column_ids, - TableFilterSet &filters) { - return files.DynamicFilterPushdown(context, options, names, types, column_ids, filters); -} - -bool MultiFileReader::Bind(MultiFileReaderOptions &options, MultiFileList &files, vector<LogicalType> &return_types, - vector<string> &names, MultiFileReaderBindData &bind_data) { - // The default MultiFileReader cannot perform any binding, as it uses MultiFileLists with no schema information. - return false; -} - -void MultiFileReader::BindOptions(MultiFileReaderOptions &options, MultiFileList &files, - vector<LogicalType> &return_types, vector<string> &names, - MultiFileReaderBindData &bind_data) { - // Add generated constant column for filename - if (options.filename) { - if (std::find(names.begin(), names.end(), options.filename_column) != names.end()) { - throw BinderException("Option filename adds column \"%s\", but a column with this name is also in the " - "file. Try setting a different name: filename='<new name>'", - options.filename_column); - } - bind_data.filename_idx = names.size(); - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back(options.filename_column); - } - - // Add generated constant columns from hive partitioning scheme - if (options.hive_partitioning) { - D_ASSERT(files.GetExpandResult() != FileExpandResult::NO_FILES); - auto partitions = HivePartitioning::Parse(files.GetFirstFile()); - // verify that all files have the same hive partitioning scheme - for (const auto &file : files.Files()) { - auto file_partitions = HivePartitioning::Parse(file); - for (auto &part_info : partitions) { - if (file_partitions.find(part_info.first) == file_partitions.end()) { - string error = "Hive partition mismatch between file \"%s\" and \"%s\": key \"%s\" not found"; - if (options.auto_detect_hive_partitioning == true) { - throw InternalException(error + " (hive partitioning was autodetected)", files.GetFirstFile(), - file, part_info.first); - } - throw BinderException(error.c_str(), files.GetFirstFile(), file, part_info.first); - } - } - if (partitions.size() != file_partitions.size()) { - string error_msg = "Hive partition mismatch between file \"%s\" and \"%s\""; - if (options.auto_detect_hive_partitioning == true) { - throw InternalException(error_msg + " (hive partitioning was autodetected)", files.GetFirstFile(), - file); - } - throw BinderException(error_msg.c_str(), files.GetFirstFile(), file); - } - } - - if (!options.hive_types_schema.empty()) { - // verify that all hive_types are existing partitions - options.VerifyHiveTypesArePartitions(partitions); - } - - for (auto &part : partitions) { - idx_t hive_partitioning_index; - auto lookup = std::find_if(names.begin(), names.end(), [&](const string &col_name) { - return StringUtil::CIEquals(col_name, part.first); - }); - if (lookup != names.end()) { - // hive partitioning column also exists in file - override - auto idx = NumericCast<idx_t>(lookup - names.begin()); - hive_partitioning_index = idx; - return_types[idx] = options.GetHiveLogicalType(part.first); - } else { - // hive partitioning column does not exist in file - add a new column containing the key - hive_partitioning_index = names.size(); - return_types.emplace_back(options.GetHiveLogicalType(part.first)); - names.emplace_back(part.first); - }
- bind_data.hive_partitioning_indexes.emplace_back(part.first, hive_partitioning_index); - } - } -} - -void MultiFileReader::FinalizeBind(const MultiFileReaderOptions &file_options, const MultiFileReaderBindData &options, - const string &filename, const vector &local_names, - const vector &global_types, const vector &global_names, - const vector &global_column_ids, MultiFileReaderData &reader_data, - ClientContext &context, optional_ptr global_state) { - - // create a map of name -> column index - case_insensitive_map_t name_map; - if (file_options.union_by_name) { - for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { - name_map[local_names[col_idx]] = col_idx; - } - } - for (idx_t i = 0; i < global_column_ids.size(); i++) { - auto &col_idx = global_column_ids[i]; - if (col_idx.IsRowIdColumn()) { - // row-id - reader_data.constant_map.emplace_back(i, Value::BIGINT(42)); - continue; - } - auto column_id = col_idx.GetPrimaryIndex(); - if (column_id == options.filename_idx) { - // filename - reader_data.constant_map.emplace_back(i, Value(filename)); - continue; - } - if (!options.hive_partitioning_indexes.empty()) { - // hive partition constants - auto partitions = HivePartitioning::Parse(filename); - D_ASSERT(partitions.size() == options.hive_partitioning_indexes.size()); - bool found_partition = false; - for (auto &entry : options.hive_partitioning_indexes) { - if (column_id == entry.index) { - Value value = file_options.GetHivePartitionValue(partitions[entry.value], entry.value, context); - reader_data.constant_map.emplace_back(i, value); - found_partition = true; - break; - } - } - if (found_partition) { - continue; - } - } - if (file_options.union_by_name) { - auto &global_name = global_names[column_id]; - auto entry = name_map.find(global_name); - bool not_present_in_file = entry == name_map.end(); - if (not_present_in_file) { - // we need to project a column with name \"global_name\" - but it does not exist in the current file - // push a NULL value of the specified type - reader_data.constant_map.emplace_back(i, Value(global_types[column_id])); - continue; - } - } - } -} - -unique_ptr -MultiFileReader::InitializeGlobalState(ClientContext &context, const MultiFileReaderOptions &file_options, - const MultiFileReaderBindData &bind_data, const MultiFileList &file_list, - const vector &global_types, const vector &global_names, - const vector &global_column_ids) { - // By default, the multifilereader does not require any global state - return nullptr; -} - -void MultiFileReader::CreateNameMapping(const string &file_name, const vector &local_types, - const vector &local_names, const vector &global_types, - const vector &global_names, - const vector &global_column_ids, MultiFileReaderData &reader_data, - const string &initial_file, - optional_ptr global_state) { - D_ASSERT(global_types.size() == global_names.size()); - D_ASSERT(local_types.size() == local_names.size()); - // we have expected types: create a map of name -> column index - case_insensitive_map_t name_map; - for (idx_t col_idx = 0; col_idx < local_names.size(); col_idx++) { - name_map[local_names[col_idx]] = col_idx; - } - for (idx_t i = 0; i < global_column_ids.size(); i++) { - // check if this is a constant column - bool constant = false; - for (auto &entry : reader_data.constant_map) { - if (entry.column_id == i) { - constant = true; - break; - } - } - if (constant) { - // this column is constant for this file - continue; - } - // not constant - look up the column in the name map - auto &global_idx = 
global_column_ids[i]; - auto global_id = global_idx.GetPrimaryIndex(); - if (global_id >= global_types.size()) { - throw InternalException( - "MultiFileReader::CreatePositionalMapping - global_id is out of range in global_types for this file"); - } - auto &global_name = global_names[global_id]; - auto entry = name_map.find(global_name); - if (entry == name_map.end()) { - string candidate_names; - for (auto &local_name : local_names) { - if (!candidate_names.empty()) { - candidate_names += ", "; - } - candidate_names += local_name; - } - throw IOException( - StringUtil::Format("Failed to read file \"%s\": schema mismatch in glob: column \"%s\" was read from " - "the original file \"%s\", but could not be found in file \"%s\".\nCandidate names: " - "%s\nIf you are trying to " - "read files with different schemas, try setting union_by_name=True", - file_name, global_name, initial_file, file_name, candidate_names)); - } - // we found the column in the local file - check if the types are the same - auto local_id = entry->second; - D_ASSERT(global_id < global_types.size()); - D_ASSERT(local_id < local_types.size()); - auto &global_type = global_types[global_id]; - auto &local_type = local_types[local_id]; - ColumnIndex local_index(local_id); - if (global_type != local_type) { - // the types are not the same - add a cast - reader_data.cast_map[local_id] = global_type; - } else { - local_index = ColumnIndex(local_id, global_idx.GetChildIndexes()); - } - // create the mapping - reader_data.column_mapping.push_back(i); - reader_data.column_ids.push_back(local_id); - reader_data.column_indexes.push_back(std::move(local_index)); - } - - reader_data.empty_columns = reader_data.column_indexes.empty(); -} - -void MultiFileReader::CreateMapping(const string &file_name, const vector &local_types, - const vector &local_names, const vector &global_types, - const vector &global_names, const vector &global_column_ids, - optional_ptr filters, MultiFileReaderData &reader_data, - const string &initial_file, const MultiFileReaderBindData &options, - optional_ptr global_state) { - CreateNameMapping(file_name, local_types, local_names, global_types, global_names, global_column_ids, reader_data, - initial_file, global_state); - CreateFilterMap(global_types, filters, reader_data, global_state); -} - -void MultiFileReader::CreateFilterMap(const vector &global_types, optional_ptr filters, - MultiFileReaderData &reader_data, - optional_ptr global_state) { - if (filters) { - auto filter_map_size = global_types.size(); - if (global_state) { - filter_map_size += global_state->extra_columns.size(); - } - reader_data.filter_map.resize(filter_map_size); - - for (idx_t c = 0; c < reader_data.column_mapping.size(); c++) { - auto map_index = reader_data.column_mapping[c]; - reader_data.filter_map[map_index].index = c; - reader_data.filter_map[map_index].is_constant = false; - } - for (idx_t c = 0; c < reader_data.constant_map.size(); c++) { - auto constant_index = reader_data.constant_map[c].column_id; - reader_data.filter_map[constant_index].index = c; - reader_data.filter_map[constant_index].is_constant = true; - } - } -} - -void MultiFileReader::FinalizeChunk(ClientContext &context, const MultiFileReaderBindData &bind_data, - const MultiFileReaderData &reader_data, DataChunk &chunk, - optional_ptr global_state) { - // reference all the constants set up in MultiFileReader::FinalizeBind - for (auto &entry : reader_data.constant_map) { - chunk.data[entry.column_id].Reference(entry.value); - } - chunk.Verify(); -} - -void 
MultiFileReader::GetPartitionData(ClientContext &context, const MultiFileReaderBindData &bind_data, - const MultiFileReaderData &reader_data, - optional_ptr global_state, - const OperatorPartitionInfo &partition_info, - OperatorPartitionData &partition_data) { - for (auto &col : partition_info.partition_columns) { - bool found_constant = false; - for (auto &constant : reader_data.constant_map) { - if (constant.column_id == col) { - found_constant = true; - partition_data.partition_data.emplace_back(constant.value); - break; - } - } - if (!found_constant) { - throw InternalException( - "MultiFileReader::GetPartitionData - did not find constant for the given partition"); - } - } -} - -TablePartitionInfo MultiFileReader::GetPartitionInfo(ClientContext &context, const MultiFileReaderBindData &bind_data, - TableFunctionPartitionInput &input) { - // check if all of the columns are in the hive partition set - for (auto &partition_col : input.partition_ids) { - // check if this column is in the hive partitioned set - bool found = false; - for (auto &partition : bind_data.hive_partitioning_indexes) { - if (partition.index == partition_col) { - found = true; - break; - } - } - if (!found) { - // the column is not partitioned - hive partitioning alone can't guarantee the groups are partitioned - return TablePartitionInfo::NOT_PARTITIONED; - } - } - // if all columns are in the hive partitioning set, we know that each partition will only have a single value - // i.e. if the hive partitioning is by (YEAR, MONTH), each partition will have a single unique (YEAR, MONTH) - return TablePartitionInfo::SINGLE_VALUE_PARTITIONS; -} - -TableFunctionSet MultiFileReader::CreateFunctionSet(TableFunction table_function) { - TableFunctionSet function_set(table_function.name); - function_set.AddFunction(table_function); - D_ASSERT(table_function.arguments.size() >= 1 && table_function.arguments[0] == LogicalType::VARCHAR); - table_function.arguments[0] = LogicalType::LIST(LogicalType::VARCHAR); - function_set.AddFunction(std::move(table_function)); - return function_set; -} - -HivePartitioningIndex::HivePartitioningIndex(string value_p, idx_t index) : value(std::move(value_p)), index(index) { -} - -void MultiFileReaderOptions::AddBatchInfo(BindInfo &bind_info) const { - bind_info.InsertOption("filename", Value(filename_column)); - bind_info.InsertOption("hive_partitioning", Value::BOOLEAN(hive_partitioning)); - bind_info.InsertOption("auto_detect_hive_partitioning", Value::BOOLEAN(auto_detect_hive_partitioning)); - bind_info.InsertOption("union_by_name", Value::BOOLEAN(union_by_name)); - bind_info.InsertOption("hive_types_autocast", Value::BOOLEAN(hive_types_autocast)); -} - -void UnionByName::CombineUnionTypes(const vector &col_names, const vector &sql_types, - vector &union_col_types, vector &union_col_names, - case_insensitive_map_t &union_names_map) { - D_ASSERT(col_names.size() == sql_types.size()); - - for (idx_t col = 0; col < col_names.size(); ++col) { - auto union_find = union_names_map.find(col_names[col]); - - if (union_find != union_names_map.end()) { - // given same name , union_col's type must compatible with col's type - auto ¤t_type = union_col_types[union_find->second]; - auto compatible_type = LogicalType::ForceMaxLogicalType(current_type, sql_types[col]); - union_col_types[union_find->second] = compatible_type; - } else { - union_names_map[col_names[col]] = union_col_names.size(); - union_col_names.emplace_back(col_names[col]); - union_col_types.emplace_back(sql_types[col]); - } - } -} - -bool 
MultiFileReaderOptions::AutoDetectHivePartitioningInternal(MultiFileList &files, ClientContext &context) { - auto first_file = files.GetFirstFile(); - auto partitions = HivePartitioning::Parse(first_file); - if (partitions.empty()) { - // no partitions found in first file - return false; - } - - for (const auto &file : files.Files()) { - auto new_partitions = HivePartitioning::Parse(file); - if (new_partitions.size() != partitions.size()) { - // partition count mismatch - return false; - } - for (auto &part : new_partitions) { - auto entry = partitions.find(part.first); - if (entry == partitions.end()) { - // differing partitions between files - return false; - } - } - } - return true; -} -void MultiFileReaderOptions::AutoDetectHiveTypesInternal(MultiFileList &files, ClientContext &context) { - const LogicalType candidates[] = {LogicalType::DATE, LogicalType::TIMESTAMP, LogicalType::BIGINT}; - - unordered_map<string, LogicalType> detected_types; - for (const auto &file : files.Files()) { - auto partitions = HivePartitioning::Parse(file); - if (partitions.empty()) { - return; - } - - for (auto &part : partitions) { - const string &name = part.first; - if (hive_types_schema.find(name) != hive_types_schema.end()) { - // type was explicitly provided by the user - continue; - } - LogicalType detected_type = LogicalType::VARCHAR; - Value value(part.second); - for (auto &candidate : candidates) { - const bool success = value.TryCastAs(context, candidate, true); - if (success) { - detected_type = candidate; - break; - } - } - auto entry = detected_types.find(name); - if (entry == detected_types.end()) { - // type was not yet detected - insert it - detected_types.insert(make_pair(name, std::move(detected_type))); - } else { - // type was already detected - check if the type matches - // if not, promote to VARCHAR - if (entry->second != detected_type) { - entry->second = LogicalType::VARCHAR; - } - } - } - } - for (auto &entry : detected_types) { - hive_types_schema.insert(make_pair(entry.first, std::move(entry.second))); - } -} -void MultiFileReaderOptions::AutoDetectHivePartitioning(MultiFileList &files, ClientContext &context) { - D_ASSERT(files.GetExpandResult() != FileExpandResult::NO_FILES); - const bool hp_explicitly_disabled = !auto_detect_hive_partitioning && !hive_partitioning; - const bool ht_enabled = !hive_types_schema.empty(); - if (hp_explicitly_disabled && ht_enabled) { - throw InvalidInputException("cannot disable hive_partitioning when hive_types is enabled"); - } - if (ht_enabled && auto_detect_hive_partitioning && !hive_partitioning) { - // hive_types flag implies hive_partitioning - hive_partitioning = true; - auto_detect_hive_partitioning = false; - } - if (auto_detect_hive_partitioning) { - hive_partitioning = AutoDetectHivePartitioningInternal(files, context); - } - if (hive_partitioning && hive_types_autocast) { - AutoDetectHiveTypesInternal(files, context); - } -} -void MultiFileReaderOptions::VerifyHiveTypesArePartitions(const std::map<string, string> &partitions) const { - for (auto &hive_type : hive_types_schema) { - if (partitions.find(hive_type.first) == partitions.end()) { - throw InvalidInputException("Unknown hive_type: \"%s\" does not appear to be a partition", hive_type.first); - } - } -} -LogicalType MultiFileReaderOptions::GetHiveLogicalType(const string &hive_partition_column) const { - if (!hive_types_schema.empty()) { - auto it = hive_types_schema.find(hive_partition_column); - if (it != hive_types_schema.end()) { - return it->second; - } - } - return LogicalType::VARCHAR; -}
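The detection pass above tries DATE, TIMESTAMP, then BIGINT for each partition value and demotes a key to VARCHAR when two files disagree. A standalone analogue of that promotion rule (plain C++ with a simplified candidate check, not DuckDB's Value::TryCastAs machinery):

#include <cstdio>
#include <cstdlib>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

enum class DetectedType { DATE, BIGINT, VARCHAR };

// Try candidates in order of specificity; VARCHAR is the universal fallback.
static DetectedType DetectOne(const std::string &value) {
	int y, m, d;
	if (std::sscanf(value.c_str(), "%4d-%2d-%2d", &y, &m, &d) == 3) {
		return DetectedType::DATE; // e.g. day=2024-06-01
	}
	char *end = nullptr;
	(void)std::strtoll(value.c_str(), &end, 10);
	if (end && *end == '\0' && end != value.c_str()) {
		return DetectedType::BIGINT; // e.g. year=2024
	}
	return DetectedType::VARCHAR;
}

int main() {
	// key -> values seen across files; "month" mixes 06 and "jun", so it is
	// demoted to VARCHAR while "year" stays BIGINT.
	std::vector<std::pair<std::string, std::string>> observed = {
	    {"year", "2024"}, {"month", "06"}, {"year", "2023"}, {"month", "jun"}};
	std::unordered_map<std::string, DetectedType> detected;
	for (auto &entry : observed) {
		auto type = DetectOne(entry.second);
		auto it = detected.find(entry.first);
		if (it == detected.end()) {
			detected.emplace(entry.first, type);
		} else if (it->second != type) {
			it->second = DetectedType::VARCHAR; // mismatch across files
		}
	}
	std::printf("year -> %s\n", detected["year"] == DetectedType::BIGINT ? "BIGINT" : "VARCHAR");
	std::printf("month -> %s\n", detected["month"] == DetectedType::VARCHAR ? "VARCHAR" : "?");
	return 0;
}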
-bool MultiFileReaderOptions::AnySet() { - return filename || hive_partitioning || union_by_name; -} - -Value MultiFileReaderOptions::GetHivePartitionValue(const string &value, const string &key, - ClientContext &context) const { - auto it = hive_types_schema.find(key); - if (it == hive_types_schema.end()) { - return HivePartitioning::GetValue(context, key, value, LogicalType::VARCHAR); - } - return HivePartitioning::GetValue(context, key, value, it->second); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/opener_file_system.cpp b/src/duckdb/src/common/opener_file_system.cpp deleted file mode 100644 index 8f55d6898..000000000 --- a/src/duckdb/src/common/opener_file_system.cpp +++ /dev/null @@ -1,37 +0,0 @@ -#include "duckdb/common/opener_file_system.hpp" -#include "duckdb/common/file_opener.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -void OpenerFileSystem::VerifyNoOpener(optional_ptr<FileOpener> opener) { - if (opener) { - throw InternalException("OpenerFileSystem cannot take an opener - the opener is pushed automatically"); - } -} -void OpenerFileSystem::VerifyCanAccessFileInternal(const string &path, FileType type) { - auto opener = GetOpener(); - if (!opener) { - return; - } - auto db = opener->TryGetDatabase(); - if (!db) { - return; - } - auto &config = db->config; - if (!config.CanAccessFile(path, type)) { - throw PermissionException("Cannot access %s \"%s\" - file system operations are disabled by configuration", - type == FileType::FILE_TYPE_DIR ? "directory" : "file", path); - } -} - -void OpenerFileSystem::VerifyCanAccessFile(const string &path) { - VerifyCanAccessFileInternal(path, FileType::FILE_TYPE_REGULAR); -} - -void OpenerFileSystem::VerifyCanAccessDirectory(const string &path) { - VerifyCanAccessFileInternal(path, FileType::FILE_TYPE_DIR); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/operator/cast_operators.cpp b/src/duckdb/src/common/operator/cast_operators.cpp deleted file mode 100644 index 9742c6087..000000000 --- a/src/duckdb/src/common/operator/cast_operators.cpp +++ /dev/null @@ -1,2781 +0,0 @@ -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/operator/string_cast.hpp" -#include "duckdb/common/operator/numeric_cast.hpp" -#include "duckdb/common/operator/decimal_cast_operators.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/subtract.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types.hpp" -#include "fast_float/fast_float.h" -#include "fmt/format.h" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/operator/integer_cast_operator.hpp" -#include "duckdb/common/operator/double_cast_operator.hpp" - -#include <cctype> -#include <cmath> -#include <cstdlib> - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Cast bool -> Numeric
-//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(bool input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(bool input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int8_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int8_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, 
uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int8_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int16_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int16_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int16_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int32_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int32_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, int64_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int32_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast int64_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(int64_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(int64_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast hugeint_t -> Numeric 
-//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(hugeint_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(hugeint_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uhugeint_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uhugeint_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, uint16_t &result, bool strict) { - return 
NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uhugeint_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint8_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint8_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint8_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint16_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint16_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, 
result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint16_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast uint32_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint32_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint32_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - 
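All of these specializations are one-line forwarders into NumericTryCast::Operation, which keeps the overflow logic in a single template; the specializations exist only so the cast dispatcher can be instantiated per type pair. A rough standalone analogue of the pattern (not DuckDB's implementation, which additionally handles strict mode and float-to-integer edge cases):

#include <cstdint>
#include <cstdio>
#include <limits>

// One template does the range check; per-type TryCast overloads would just
// forward here. The long double comparison is a simplification - a production
// version avoids floating-point rounding pitfalls near the 64-bit extremes.
template <class SRC, class DST>
bool NumericTryCastSketch(SRC input, DST &result) {
	using W = long double;
	if (static_cast<W>(input) < static_cast<W>(std::numeric_limits<DST>::lowest()) ||
	    static_cast<W>(input) > static_cast<W>(std::numeric_limits<DST>::max())) {
		return false; // out of range for the destination type
	}
	result = static_cast<DST>(input);
	return true;
}

int main() {
	int32_t small = 0;
	bool ok = NumericTryCastSketch<int64_t, int32_t>(int64_t(1) << 40, small);
	std::printf("overflowing cast ok? %s\n", ok ? "yes" : "no"); // "no"
	ok = NumericTryCastSketch<int64_t, int32_t>(42, small);
	std::printf("in-range cast ok? %s, value=%d\n", ok ? "yes" : "no", small);
	return 0;
}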
-//===--------------------------------------------------------------------===// -// Cast uint64_t -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(uint64_t input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(uint64_t input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast float -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(float input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, 
uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(float input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast double -> Numeric -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(double input, bool &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, int64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, hugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uhugeint_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint8_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint16_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint32_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, uint64_t &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, float &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -template <> -bool TryCast::Operation(double input, double &result, bool strict) { - return NumericTryCast::Operation(input, result, strict); -} - -//===--------------------------------------------------------------------===// -// Cast String -> Numeric -//===--------------------------------------------------------------------===// - -template <> -bool TryCast::Operation(string_t input, bool &result, bool strict) { - auto input_data = reinterpret_cast<const char *>(input.GetData()); - auto input_size = input.GetSize(); - return TryCastStringBool(input_data, input_size, result, strict); -} -template <> -bool TryCast::Operation(string_t input, int8_t &result, bool strict) { - return TrySimpleIntegerCast<int8_t>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, int16_t &result, bool strict) { - return TrySimpleIntegerCast<int16_t>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, int32_t &result, bool strict) { - return TrySimpleIntegerCast<int32_t>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, int64_t &result, bool strict) { - return TrySimpleIntegerCast<int64_t>(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCast::Operation(string_t input, uint8_t &result, bool strict) { - return TrySimpleIntegerCast<uint8_t, false>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint16_t &result, bool strict) { - return TrySimpleIntegerCast<uint16_t, false>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint32_t &result, bool strict) { - return TrySimpleIntegerCast<uint32_t, false>(input.GetData(), input.GetSize(), result, strict); -} -template <> -bool TryCast::Operation(string_t input, uint64_t &result, bool strict) { - return TrySimpleIntegerCast<uint64_t, false>(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCast::Operation(string_t input, float &result, bool strict) { - return TryDoubleCast<float>(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCast::Operation(string_t input, double &result, bool strict) { - return TryDoubleCast<double>(input.GetData(), input.GetSize(), result, strict); -} - -template <> -bool TryCastErrorMessageCommaSeparated::Operation(string_t input, float &result, CastParameters &parameters) { - if (!TryDoubleCast<float>(input.GetData(), input.GetSize(), result, parameters.strict, ',')) { - HandleCastError::AssignError(StringUtil::Format("Could not cast string to float: \"%s\"", input.GetString()), - parameters); - return false; - } - return true; -} - -template <> -bool TryCastErrorMessageCommaSeparated::Operation(string_t input, double &result, CastParameters &parameters) { - if (!TryDoubleCast<double>(input.GetData(), input.GetSize(), result, parameters.strict, ',')) { - HandleCastError::AssignError(StringUtil::Format("Could not cast string to double: \"%s\"", input.GetString()), - parameters); - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Date -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(date_t input, date_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(date_t input, timestamp_t &result, bool strict) { - if (input == date_t::infinity()) { - result = timestamp_t::infinity(); - return true; - } else if (input == date_t::ninfinity()) { - result = timestamp_t::ninfinity(); - return true; - } - return Timestamp::TryFromDatetime(input, Time::FromTime(0, 0, 0), result); -} - -//===--------------------------------------------------------------------===// -// Cast From Time -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(dtime_t input, dtime_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(dtime_t input, dtime_tz_t &result, bool strict) { - result = dtime_tz_t(input, 0); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Time With Time Zone (Offset) -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(dtime_tz_t input, dtime_tz_t &result, bool strict) { - result = input; - return 
true; -} - -template <> -bool TryCast::Operation(dtime_tz_t input, dtime_t &result, bool strict) { - result = input.time(); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Timestamps -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(timestamp_t input, date_t &result, bool strict) { - result = Timestamp::GetDate(input); - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, dtime_t &result, bool strict) { - if (!Timestamp::IsFinite(input)) { - return false; - } - result = Timestamp::GetTime(input); - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_t &result, bool strict) { - result = input; - return true; -} - -template <> -bool TryCast::Operation(timestamp_sec_t input, timestamp_sec_t &result, bool strict) { - result.value = input.value; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_sec_t &result, bool strict) { - D_ASSERT(Timestamp::IsFinite(input)); - result.value = input.value / Interval::MICROS_PER_SEC; - return true; -} - -template <> -bool TryCast::Operation(timestamp_ms_t input, timestamp_ms_t &result, bool strict) { - result.value = input.value; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_ms_t &result, bool strict) { - D_ASSERT(Timestamp::IsFinite(input)); - result.value = input.value / Interval::MICROS_PER_MSEC; - return true; -} - -template <> -bool TryCast::Operation(timestamp_ns_t input, timestamp_ns_t &result, bool strict) { - result.value = input.value; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_ns_t &result, bool strict) { - D_ASSERT(Timestamp::IsFinite(input)); - if (!TryMultiplyOperator::Operation(input.value, Interval::NANOS_PER_MICRO, result.value)) { - throw ConversionException("Could not convert TIMESTAMP to TIMESTAMP_NS"); - } - return true; -} - -template <> -bool TryCast::Operation(timestamp_tz_t input, timestamp_tz_t &result, bool strict) { - result.value = input.value; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, timestamp_tz_t &result, bool strict) { - result.value = input.value; - return true; -} - -template <> -bool TryCast::Operation(timestamp_t input, dtime_tz_t &result, bool strict) { - if (!Timestamp::IsFinite(input)) { - return false; - } - result = dtime_tz_t(Timestamp::GetTime(input), 0); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast from Interval -//===--------------------------------------------------------------------===// -template <> -bool TryCast::Operation(interval_t input, interval_t &result, bool strict) { - result = input; - return true; -} - -//===--------------------------------------------------------------------===// -// Non-Standard Timestamps -//===--------------------------------------------------------------------===// -template <> -duckdb::string_t CastFromTimestampNS::Operation(duckdb::timestamp_ns_t input, Vector &result) { - return StringCast::Operation(input, result); -} -template <> -duckdb::string_t CastFromTimestampMS::Operation(duckdb::timestamp_t input, Vector &result) { - return StringCast::Operation(CastTimestampMsToUs::Operation(input), result); -} -template <> -duckdb::string_t CastFromTimestampSec::Operation(duckdb::timestamp_t input, Vector &result) { - return 
StringCast::Operation(CastTimestampSecToUs::Operation(input), result); -} - -template <> -timestamp_t CastTimestampUsToMs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - timestamp_t cast_timestamp(Timestamp::GetEpochRounded(input, Interval::MICROS_PER_MSEC)); - return cast_timestamp; -} - -template <> -timestamp_t CastTimestampUsToNs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - timestamp_t cast_timestamp(Timestamp::GetEpochNanoSeconds(input)); - return cast_timestamp; -} - -template <> -timestamp_t CastTimestampUsToSec::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - timestamp_t cast_timestamp(Timestamp::GetEpochRounded(input, Interval::MICROS_PER_SEC)); - return cast_timestamp; -} - -template <> -timestamp_t CastTimestampMsToUs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::FromEpochMs(input.value); -} - -template <> -date_t CastTimestampMsToDate::Operation(timestamp_t input) { - return Timestamp::GetDate(Timestamp::FromEpochMs(input.value)); -} - -template <> -dtime_t CastTimestampMsToTime::Operation(timestamp_t input) { - return Timestamp::GetTime(Timestamp::FromEpochMs(input.value)); -} - -template <> -timestamp_t CastTimestampMsToNs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - auto us = CastTimestampMsToUs::Operation(input); - return CastTimestampUsToNs::Operation(us); -} - -template <> -timestamp_t CastTimestampNsToUs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::FromEpochNanoSeconds(input.value); -} - -template <> -timestamp_t CastTimestampSecToUs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - return Timestamp::FromEpochSeconds(input.value); -} - -template <> -date_t CastTimestampNsToDate::Operation(timestamp_t input) { - if (input == timestamp_t::infinity()) { - return date_t::infinity(); - } else if (input == timestamp_t::ninfinity()) { - return date_t::ninfinity(); - } - const auto us = CastTimestampNsToUs::Operation(input); - return Timestamp::GetDate(us); -} - -template <> -dtime_t CastTimestampNsToTime::Operation(timestamp_t input) { - const auto us = CastTimestampNsToUs::Operation(input); - return Timestamp::GetTime(us); -} - -template <> -timestamp_t CastTimestampSecToMs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - auto us = CastTimestampSecToUs::Operation(input); - return CastTimestampUsToMs::Operation(us); -} - -template <> -timestamp_t CastTimestampSecToNs::Operation(timestamp_t input) { - if (!Timestamp::IsFinite(input)) { - return input; - } - auto us = CastTimestampSecToUs::Operation(input); - return CastTimestampUsToNs::Operation(us); -} - -template <> -date_t CastTimestampSecToDate::Operation(timestamp_t input) { - const auto us = CastTimestampSecToUs::Operation(input); - return Timestamp::GetDate(us); -} - -template <> -dtime_t CastTimestampSecToTime::Operation(timestamp_t input) { - const auto us = CastTimestampSecToUs::Operation(input); - return Timestamp::GetTime(us); -} - -//===--------------------------------------------------------------------===// -// Cast To Timestamp -//===--------------------------------------------------------------------===// -template <> -bool TryCastToTimestampNS::Operation(string_t input, timestamp_ns_t &result, bool strict) { - return 
TryCast::Operation(input, result, strict); -} - -template <> -bool TryCastToTimestampMS::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result = CastTimestampUsToMs::Operation(result); - return true; -} - -template <> -bool TryCastToTimestampSec::Operation(string_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - result = CastTimestampUsToSec::Operation(result); - return true; -} - -template <> -bool TryCastToTimestampNS::Operation(date_t input, timestamp_ns_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - if (!Timestamp::IsFinite(result)) { - return true; - } - if (!TryMultiplyOperator::Operation(result.value, Interval::NANOS_PER_MICRO, result.value)) { - return false; - } - return true; -} - -template <> -bool TryCastToTimestampMS::Operation(date_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - if (!Timestamp::IsFinite(result)) { - return true; - } - result.value /= Interval::MICROS_PER_MSEC; - return true; -} - -template <> -bool TryCastToTimestampSec::Operation(date_t input, timestamp_t &result, bool strict) { - if (!TryCast::Operation(input, result, strict)) { - return false; - } - if (!Timestamp::IsFinite(result)) { - return true; - } - result.value /= Interval::MICROS_PER_MSEC * Interval::MSECS_PER_SEC; - return true; -} - -//===--------------------------------------------------------------------===// -// Cast From Blob -//===--------------------------------------------------------------------===// -template <> -string_t CastFromBlob::Operation(string_t input, Vector &vector) { - idx_t result_size = Blob::GetStringSize(input); - - string_t result = StringVector::EmptyString(vector, result_size); - Blob::ToString(input, result.GetDataWriteable()); - result.Finalize(); - - return result; -} - -template <> -string_t CastFromBlobToBit::Operation(string_t input, Vector &vector) { - idx_t result_size = input.GetSize() + 1; - if (result_size <= 1) { - throw ConversionException("Cannot cast empty BLOB to BIT"); - } - return StringVector::AddStringOrBlob(vector, Bit::BlobToBit(input)); -} - -//===--------------------------------------------------------------------===// -// Cast From Bit -//===--------------------------------------------------------------------===// -template <> -string_t CastFromBitToString::Operation(string_t input, Vector &vector) { - - idx_t result_size = Bit::BitLength(input); - string_t result = StringVector::EmptyString(vector, result_size); - Bit::ToString(input, result.GetDataWriteable()); - result.Finalize(); - - return result; -} - -//===--------------------------------------------------------------------===// -// Cast From Pointer -//===--------------------------------------------------------------------===// -template <> -string_t CastFromPointer::Operation(uintptr_t input, Vector &vector) { - std::string s = duckdb_fmt::format("0x{:x}", input); - return StringVector::AddString(vector, s); -} - -//===--------------------------------------------------------------------===// -// Cast To Blob -//===--------------------------------------------------------------------===// -template <> -bool TryCastToBlob::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters) { - idx_t result_size; - if (!Blob::TryGetBlobSize(input, result_size, parameters)) { - return false; 
- } - - result = StringVector::EmptyString(result_vector, result_size); - Blob::ToBlob(input, data_ptr_cast(result.GetDataWriteable())); - result.Finalize(); - return true; -} - -//===--------------------------------------------------------------------===// -// Cast To Bit -//===--------------------------------------------------------------------===// -template <> -bool TryCastToBit::Operation(string_t input, string_t &result, Vector &result_vector, CastParameters ¶meters) { - idx_t result_size; - if (!Bit::TryGetBitStringSize(input, result_size, parameters.error_message)) { - return false; - } - - result = StringVector::EmptyString(result_vector, result_size); - Bit::ToBit(input, result); - result.Finalize(); - return true; -} - -template <> -bool CastFromBitToNumeric::Operation(string_t input, bool &result, CastParameters ¶meters) { - D_ASSERT(input.GetSize() > 1); - - uint8_t value; - bool success = CastFromBitToNumeric::Operation(input, value, parameters); - result = (value > 0); - return (success); -} - -template <> -bool CastFromBitToNumeric::Operation(string_t input, hugeint_t &result, CastParameters ¶meters) { - D_ASSERT(input.GetSize() > 1); - - if (input.GetSize() - 1 > sizeof(hugeint_t)) { - throw ConversionException(parameters.query_location, "Bitstring doesn't fit inside of %s", - GetTypeId()); - } - Bit::BitToNumeric(input, result); - return (true); -} - -template <> -bool CastFromBitToNumeric::Operation(string_t input, uhugeint_t &result, CastParameters ¶meters) { - D_ASSERT(input.GetSize() > 1); - - if (input.GetSize() - 1 > sizeof(uhugeint_t)) { - throw ConversionException(parameters.query_location, "Bitstring doesn't fit inside of %s", - GetTypeId()); - } - Bit::BitToNumeric(input, result); - return (true); -} - -//===--------------------------------------------------------------------===// -// Cast From UUID -//===--------------------------------------------------------------------===// -template <> -string_t CastFromUUID::Operation(hugeint_t input, Vector &vector) { - string_t result = StringVector::EmptyString(vector, 36); - UUID::ToString(input, result.GetDataWriteable()); - result.Finalize(); - return result; -} - -//===--------------------------------------------------------------------===// -// Cast To UUID -//===--------------------------------------------------------------------===// -template <> -bool TryCastToUUID::Operation(string_t input, hugeint_t &result, Vector &result_vector, CastParameters ¶meters) { - return UUID::FromString(input.GetString(), result); -} - -//===--------------------------------------------------------------------===// -// Cast To Date -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, date_t &result, CastParameters ¶meters) { - idx_t pos; - bool special = false; - switch (Date::TryConvertDate(input.GetData(), input.GetSize(), pos, result, special, parameters.strict)) { - case DateCastResult::SUCCESS: - break; - case DateCastResult::ERROR_INCORRECT_FORMAT: - HandleCastError::AssignError(Date::FormatError(input), parameters); - return false; - case DateCastResult::ERROR_RANGE: - HandleCastError::AssignError(Date::RangeError(input), parameters); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, date_t &result, bool strict) { - idx_t pos; - bool special = false; - return Date::TryConvertDate(input.GetData(), input.GetSize(), pos, result, special, strict) == - DateCastResult::SUCCESS; -} - -template <> 
-date_t Cast::Operation(string_t input) { - return Date::FromCString(input.GetData(), input.GetSize()); -} - -//===--------------------------------------------------------------------===// -// Cast To Time -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, dtime_t &result, CastParameters ¶meters) { - if (!TryCast::Operation(input, result, parameters.strict)) { - HandleCastError::AssignError(Time::ConversionError(input), parameters); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, dtime_t &result, bool strict) { - idx_t pos; - return Time::TryConvertTime(input.GetData(), input.GetSize(), pos, result, strict); -} - -template <> -dtime_t Cast::Operation(string_t input) { - return Time::FromCString(input.GetData(), input.GetSize()); -} - -//===--------------------------------------------------------------------===// -// Cast To TimeTZ -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, dtime_tz_t &result, CastParameters ¶meters) { - if (!TryCast::Operation(input, result, parameters.strict)) { - HandleCastError::AssignError(Time::ConversionError(input), parameters); - return false; - } - return true; -} - -template <> -bool TryCast::Operation(string_t input, dtime_tz_t &result, bool strict) { - idx_t pos; - bool has_offset; - return Time::TryConvertTimeTZ(input.GetData(), input.GetSize(), pos, result, has_offset, strict); -} - -template <> -dtime_tz_t Cast::Operation(string_t input) { - dtime_tz_t result; - if (!TryCast::Operation(input, result, false)) { - throw ConversionException(Time::ConversionError(input)); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Cast To Timestamp -//===--------------------------------------------------------------------===// -template <> -bool TryCastErrorMessage::Operation(string_t input, timestamp_t &result, CastParameters ¶meters) { - switch (Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result)) { - case TimestampCastResult::SUCCESS: - return true; - case TimestampCastResult::ERROR_INCORRECT_FORMAT: - HandleCastError::AssignError(Timestamp::FormatError(input), parameters); - break; - case TimestampCastResult::ERROR_NON_UTC_TIMEZONE: - HandleCastError::AssignError(Timestamp::UnsupportedTimezoneError(input), parameters); - break; - case TimestampCastResult::ERROR_RANGE: - HandleCastError::AssignError(Timestamp::RangeError(input), parameters); - break; - } - return false; -} - -template <> -bool TryCast::Operation(string_t input, timestamp_t &result, bool strict) { - return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; -} - -template <> -bool TryCast::Operation(string_t input, timestamp_ns_t &result, bool strict) { - return Timestamp::TryConvertTimestamp(input.GetData(), input.GetSize(), result) == TimestampCastResult::SUCCESS; -} - -template <> -timestamp_t Cast::Operation(string_t input) { - return Timestamp::FromCString(input.GetData(), input.GetSize()); -} - -template <> -timestamp_ns_t Cast::Operation(string_t input) { - int32_t nanos; - const auto ts = Timestamp::FromCString(input.GetData(), input.GetSize(), &nanos); - timestamp_ns_t result; - if (!Timestamp::TryFromTimestampNanos(ts, nanos, result)) { - throw ConversionException(Timestamp::RangeError(input)); - } - return result; -} - 
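The string-to-temporal casts above share one contract: failures are reported through the return value (TryCastErrorMessage additionally routes a formatted message through CastParameters), while the plain Cast::Operation variants throw ConversionException. A minimal caller-side sketch, assuming DuckDB's internal headers are available; the helper name is hypothetical and not part of the deleted file:

#include "duckdb/common/operator/cast_operators.hpp"

using namespace duckdb;

// Hypothetical helper: parse a C string into a timestamp_t,
// reporting failure instead of throwing.
static bool ParseTimestampOrFalse(const char *str, timestamp_t &ts) {
	string_t input(str); // wraps the existing buffer; no copy is made
	// Resolves to the specialization above, which delegates to
	// Timestamp::TryConvertTimestamp; the strict flag is part of the
	// generic TryCast interface.
	return TryCast::Operation(input, ts, false);
}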
-//===--------------------------------------------------------------------===//
-// Cast From Interval
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastErrorMessage::Operation(string_t input, interval_t &result, CastParameters &parameters) {
-	return Interval::FromCString(input.GetData(), input.GetSize(), result, parameters.error_message, parameters.strict);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast to hugeint / uhugeint
-//===--------------------------------------------------------------------===//
-// parsing hugeint from string is done a bit differently for performance reasons
-// for other integer types we keep track of a single value
-// and multiply that value by 10 for every digit we read
-// however, for hugeints, multiplication is very expensive (>20X as expensive as for int64)
-// for that reason, we parse numbers first into an int64 value
-// when that value is full, we perform a HUGEINT multiplication to flush it into the hugeint
-// this takes the number of HUGEINT multiplications down from [0-38] to [0-2]
-
-template <class T, class INTERMEDIATE_T, class OP>
-struct HugeIntCastData {
-	using ResultType = T;
-	using IntermediateType = INTERMEDIATE_T;
-	using Operation = OP;
-	ResultType result;
-	IntermediateType intermediate;
-	uint8_t digits;
-
-	ResultType decimal;
-	uint16_t decimal_total_digits;
-	IntermediateType decimal_intermediate;
-	uint16_t decimal_intermediate_digits;
-
-	bool Flush() {
-		if (digits == 0 && intermediate == 0) {
-			return true;
-		}
-		if (result.lower != 0 || result.upper != 0) {
-			if (digits > 38) {
-				return false;
-			}
-			if (!OP::TryMultiply(result, OP::POWERS_OF_TEN[digits], result)) {
-				return false;
-			}
-		}
-		if (!OP::TryAddInPlace(result, ResultType(intermediate))) {
-			return false;
-		}
-		digits = 0;
-		intermediate = 0;
-		return true;
-	}
-
-	bool FlushDecimal() {
-		if (decimal_intermediate_digits == 0 && decimal_intermediate == 0) {
-			return true;
-		}
-		if (decimal.lower != 0 || decimal.upper != 0) {
-			if (decimal_intermediate_digits > 38) {
-				return false;
-			}
-			if (!OP::TryMultiply(decimal, OP::POWERS_OF_TEN[decimal_intermediate_digits], decimal)) {
-				return false;
-			}
-		}
-		if (!OP::TryAddInPlace(decimal, ResultType(decimal_intermediate))) {
-			return false;
-		}
-		decimal_total_digits += decimal_intermediate_digits;
-		decimal_intermediate_digits = 0;
-		decimal_intermediate = 0;
-		return true;
-	}
-};
-
-struct HugeIntegerCastOperation {
-	template <class T, bool NEGATIVE>
-	static bool HandleDigit(T &state, uint8_t digit) {
-		if (NEGATIVE) {
-			if (DUCKDB_UNLIKELY(state.intermediate < (NumericLimits<typename T::IntermediateType>::Minimum() + digit) / 10)) {
-				// intermediate is full: need to flush it
-				if (!state.Flush()) {
-					return false;
-				}
-			}
-			state.intermediate = state.intermediate * 10 - digit;
-		} else {
-			if (DUCKDB_UNLIKELY(state.intermediate > (NumericLimits<typename T::IntermediateType>::Maximum() - digit) / 10)) {
-				if (!state.Flush()) {
-					return false;
-				}
-			}
-			state.intermediate = state.intermediate * 10 + digit;
-		}
-		state.digits++;
-		return true;
-	}
-
-	template <class T, bool NEGATIVE>
-	static bool HandleHexDigit(T &state, uint8_t digit) {
-		return false;
-	}
-
-	template <class T, bool NEGATIVE>
-	static bool HandleBinaryDigit(T &state, uint8_t digit) {
-		return false;
-	}
-
-	template <class T, bool NEGATIVE>
-	static bool HandleExponent(T &state, int32_t exponent) {
-		using result_t = typename T::ResultType;
-		if (!state.Flush()) {
-			return false;
-		}
-
-		int32_t e = exponent;
-		if (e < -38) {
-			state.result = 0;
-			return true;
-		}
-
-		// Negative Exponent
-		result_t remainder = 0;
-		if (e < 0) {
-			state.result = T::Operation::DivMod(state.result, T::Operation::POWERS_OF_TEN[-e], remainder);
-			if (remainder < 0) {
-				result_t negate_result;
-				if (!T::Operation::TryNegate(remainder, negate_result)) {
-					return false;
-				}
-				remainder = negate_result;
-			}
-			state.decimal = remainder;
-			state.decimal_total_digits = static_cast<uint16_t>(-e);
-			state.decimal_intermediate = 0;
-			state.decimal_intermediate_digits = 0;
-			return Finalize<T, NEGATIVE>(state);
-		}
-
-		// Positive Exponent
-		if (state.result != 0) {
-			if (e > 38 || !TryMultiplyOperator::Operation(state.result, T::Operation::POWERS_OF_TEN[e], state.result)) {
-				return false;
-			}
-		}
-		if (!state.FlushDecimal()) {
-			return false;
-		}
-		if (state.decimal == 0) {
-			return Finalize<T, NEGATIVE>(state);
-		}
-
-		e = exponent - state.decimal_total_digits;
-		if (e < 0) {
-			state.decimal = T::Operation::DivMod(state.decimal, T::Operation::POWERS_OF_TEN[-e], remainder);
-			state.decimal_total_digits -= (exponent);
-		} else {
-			if (e > 38 || !TryMultiplyOperator::Operation(state.decimal, T::Operation::POWERS_OF_TEN[e], state.decimal)) {
-				return false;
-			}
-		}
-
-		if (NEGATIVE) {
-			if (!TrySubtractOperator::Operation(state.result, state.decimal, state.result)) {
-				return false;
-			}
-		} else if (!TryAddOperator::Operation(state.result, state.decimal, state.result)) {
-			return false;
-		}
-		state.decimal = remainder;
-		return Finalize<T, NEGATIVE>(state);
-	}
-
-	template <class T, bool NEGATIVE>
-	static bool HandleDecimal(T &state, uint8_t digit) {
-		if (!state.Flush()) {
-			return false;
-		}
-		if (DUCKDB_UNLIKELY(state.decimal_intermediate > (NumericLimits<typename T::IntermediateType>::Maximum() - digit) / 10)) {
-			if (!state.FlushDecimal()) {
-				return false;
-			}
-		}
-		state.decimal_intermediate = state.decimal_intermediate * 10 + digit;
-		state.decimal_intermediate_digits++;
-		return true;
-	}
-
-	template <class T, bool NEGATIVE>
-	static bool Finalize(T &state) {
-		using result_t = typename T::ResultType;
-		if (!state.Flush() || !state.FlushDecimal()) {
-			return false;
-		}
-
-		if (state.decimal == 0 || state.decimal_total_digits == 0) {
-			return true;
-		}
-
-		// Get the first (left-most) digit of the decimals
-		while (state.decimal_total_digits > 39) {
-			state.decimal /= T::Operation::POWERS_OF_TEN[39];
-			state.decimal_total_digits -= 39;
-		}
-		D_ASSERT((state.decimal_total_digits - 1) >= 0 && (state.decimal_total_digits - 1) <= 39);
-		state.decimal /= T::Operation::POWERS_OF_TEN[state.decimal_total_digits - 1];
-
-		if (state.decimal >= 5) {
-			if (NEGATIVE) {
-				return TrySubtractOperator::Operation(state.result, result_t(1), state.result);
-			} else {
-				return TryAddOperator::Operation(state.result, result_t(1), state.result);
-			}
-		}
-		return true;
-	}
-};
-
-template <>
-bool TryCast::Operation(string_t input, hugeint_t &result, bool strict) {
-	HugeIntCastData<hugeint_t, int64_t, Hugeint> state {};
-	if (!TryIntegerCast<HugeIntCastData<hugeint_t, int64_t, Hugeint>, true, true, HugeIntegerCastOperation>(
-	        input.GetData(), input.GetSize(), state, strict)) {
-		return false;
-	}
-	result = state.result;
-	return true;
-}
-
-template <>
-bool TryCast::Operation(string_t input, uhugeint_t &result, bool strict) {
-	HugeIntCastData<uhugeint_t, uint64_t, Uhugeint> state {};
-	if (!TryIntegerCast<HugeIntCastData<uhugeint_t, uint64_t, Uhugeint>, false, true, HugeIntegerCastOperation>(
-	        input.GetData(), input.GetSize(), state, strict)) {
-		return false;
-	}
-	result = state.result;
-	return true;
-}
-
-//===--------------------------------------------------------------------===//
-// Decimal String Cast
-//===--------------------------------------------------------------------===//
-
-template <>
-bool TryCastToDecimal::Operation(string_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int16_t>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(string_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int32_t>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(string_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int64_t>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(string_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<hugeint_t>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimalCommaSeparated::Operation(string_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int16_t, ','>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimalCommaSeparated::Operation(string_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int32_t, ','>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimalCommaSeparated::Operation(string_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<int64_t, ','>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimalCommaSeparated::Operation(string_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryDecimalStringCast<hugeint_t, ','>(input, result, parameters, width, scale);
-}
-
-template <>
-string_t StringCastFromDecimal::Operation(int16_t input, uint8_t width, uint8_t scale, Vector &result) {
-	return DecimalToString::Format(input, width, scale, result);
-}
-
-template <>
-string_t StringCastFromDecimal::Operation(int32_t input, uint8_t width, uint8_t scale, Vector &result) {
-	return DecimalToString::Format(input, width, scale, result);
-}
-
-template <>
-string_t StringCastFromDecimal::Operation(int64_t input, uint8_t width, uint8_t scale, Vector &result) {
-	return DecimalToString::Format(input, width, scale, result);
-}
-
-template <>
-string_t StringCastFromDecimal::Operation(hugeint_t input, uint8_t width, uint8_t scale, Vector &result) {
-	return DecimalToString::Format(input, width, scale, result);
-}
-
-//===--------------------------------------------------------------------===//
-// Decimal Casts
-//===--------------------------------------------------------------------===//
-// Decimal <-> Bool
-//===--------------------------------------------------------------------===//
-template <class T, class OP>
-bool TryCastBoolToDecimal(bool input, T &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	if (width > scale) {
-		result = UnsafeNumericCast<T>(input ? OP::POWERS_OF_TEN[scale] : 0);
-		return true;
-	} else {
-		return TryCast::Operation(input, result);
-	}
-}
-
-template <>
-bool TryCastToDecimal::Operation(bool input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastBoolToDecimal<int16_t, NumericHelper>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(bool input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastBoolToDecimal<int32_t, NumericHelper>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(bool input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastBoolToDecimal<int64_t, NumericHelper>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastToDecimal::Operation(bool input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastBoolToDecimal<hugeint_t, Hugeint>(input, result, parameters, width, scale);
-}
-
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, bool &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCast::Operation(input, result);
-}
-
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, bool &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCast::Operation(input, result);
-}
-
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, bool &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCast::Operation(input, result);
-}
-
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, bool &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCast::Operation(input, result);
-}
-
-//===--------------------------------------------------------------------===//
-// Numeric -> Decimal Cast
-//===--------------------------------------------------------------------===//
-struct SignedToDecimalOperator {
-	template <class SRC, class DST>
-	static bool Operation(SRC input, DST max_width) {
-		return int64_t(input) >= int64_t(max_width) || int64_t(input) <= int64_t(-max_width);
-	}
-};
-
-struct UnsignedToDecimalOperator {
-	template <class SRC, class DST>
-	static bool Operation(SRC input, DST max_width) {
-		return uint64_t(input) >= uint64_t(max_width);
-	}
-};
-
-template <class SRC, class DST, class OP>
-bool StandardNumericToDecimalCast(SRC input, DST &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	// check for overflow
-	DST max_width = UnsafeNumericCast<DST>(NumericHelper::POWERS_OF_TEN[width - scale]);
-	if (OP::template Operation<SRC, DST>(input, max_width)) {
-		string error = StringUtil::Format("Could not cast value %d to DECIMAL(%d,%d)", input, width, scale);
-		HandleCastError::AssignError(error, parameters);
-		return false;
-	}
-	result = UnsafeNumericCast<DST>(DST(input) * NumericHelper::POWERS_OF_TEN[scale]);
-	return true;
-}
-
-template <class SRC>
-bool NumericToHugeDecimalCast(SRC input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	// check for overflow
-	hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale];
-	hugeint_t hinput = Hugeint::Convert(input);
-	if (hinput >= max_width || hinput <= -max_width) {
-		string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", hinput.ToString(), width, scale);
-		HandleCastError::AssignError(error, parameters);
-		return false;
-	}
-	result = hinput * Hugeint::POWERS_OF_TEN[scale];
-	return true;
-}
-
-//===--------------------------------------------------------------------===//
-// Cast int8_t -> Decimal
-//===--------------------------------------------------------------------===//
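The per-type specializations that follow (int8_t through uint64_t) are mechanical instantiations of the two helpers above. The shared rule is easiest to check with concrete numbers: a value fits DECIMAL(width, scale) only if |input| < 10^(width - scale), and it is stored scaled by 10^scale, so for DECIMAL(4,2) the cutoff max_width is 10^2 = 100: 99 is accepted and stored as 9900, while 100 is rejected. A self-contained sketch of that rule in plain C++ (illustrative only, not the DuckDB API):

#include <cstdint>
#include <optional>

// Integer power of ten; width - scale and scale are small (<= 18 for this
// int64_t sketch; the helpers above go up to 38 via hugeint).
static int64_t Pow10(uint8_t n) {
	int64_t r = 1;
	while (n--) {
		r *= 10;
	}
	return r;
}

// Returns the scaled storage value for DECIMAL(width, scale), or nullopt on overflow.
static std::optional<int64_t> ToDecimalStorage(int32_t input, uint8_t width, uint8_t scale) {
	const int64_t max_width = Pow10(width - scale);
	if (input >= max_width || input <= -max_width) {
		return std::nullopt; // e.g. 100 does not fit DECIMAL(4, 2)
	}
	return int64_t(input) * Pow10(scale); // e.g. 99 -> 9900 for DECIMAL(4, 2)
}

Checking against 10^(width - scale) before multiplying is what keeps the scaled result within the declared width.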
-template <> -bool TryCastToDecimal::Operation(int8_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int8_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int16_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int16_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int16_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int32_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int32_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int32_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int32_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int32_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast int64_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(int64_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) 
{ - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, scale); -} -template <> -bool TryCastToDecimal::Operation(int64_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint8_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint8_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint8_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint16_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint16_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint16_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint32_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint32_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint32_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint32_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool 
TryCastToDecimal::Operation(uint32_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Cast uint64_t -> Decimal -//===--------------------------------------------------------------------===// -template <> -bool TryCastToDecimal::Operation(uint64_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return StandardNumericToDecimalCast(input, result, parameters, width, - scale); -} -template <> -bool TryCastToDecimal::Operation(uint64_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return NumericToHugeDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Hugeint -> Decimal Cast -//===--------------------------------------------------------------------===// -template -bool HugeintToDecimalCast(hugeint_t input, DST &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - // check for overflow - hugeint_t max_width = Hugeint::POWERS_OF_TEN[width - scale]; - if (input >= max_width || input <= -max_width) { - string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", input.ToString(), width, scale); - HandleCastError::AssignError(error, parameters); - return false; - } - result = Hugeint::Cast(input * Hugeint::POWERS_OF_TEN[scale]); - return true; -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(hugeint_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return HugeintToDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Uhugeint -> Decimal Cast -//===--------------------------------------------------------------------===// -template -bool UhugeintToDecimalCast(uhugeint_t input, DST &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - // check for overflow - uhugeint_t max_width = Uhugeint::POWERS_OF_TEN[width - scale]; - if (input >= max_width) { - string error = StringUtil::Format("Could not cast value %s to DECIMAL(%d,%d)", input.ToString(), width, scale); - HandleCastError::AssignError(error, parameters); - return false; - } - result = Uhugeint::Cast(input * Uhugeint::POWERS_OF_TEN[scale]); - 
return true; -} - -template <> -bool TryCastToDecimal::Operation(uhugeint_t input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return UhugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(uhugeint_t input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return UhugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(uhugeint_t input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return UhugeintToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(uhugeint_t input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return UhugeintToDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Float/Double -> Decimal Cast -//===--------------------------------------------------------------------===// -template -bool DoubleToDecimalCast(SRC input, DST &result, CastParameters ¶meters, uint8_t width, uint8_t scale) { - double value = input * NumericHelper::DOUBLE_POWERS_OF_TEN[scale]; - double roundedValue = round(value); - if (roundedValue <= -NumericHelper::DOUBLE_POWERS_OF_TEN[width] || - roundedValue >= NumericHelper::DOUBLE_POWERS_OF_TEN[width]) { - string error = StringUtil::Format("Could not cast value %f to DECIMAL(%d,%d)", input, width, scale); - HandleCastError::AssignError(error, parameters); - return false; - } - result = Cast::Operation(static_cast(value)); - return true; -} - -template <> -bool TryCastToDecimal::Operation(float input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(float input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int16_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int32_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, int64_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -template <> -bool TryCastToDecimal::Operation(double input, hugeint_t &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return DoubleToDecimalCast(input, result, parameters, width, scale); -} - -//===--------------------------------------------------------------------===// -// Decimal -> Numeric Cast 
-//===--------------------------------------------------------------------===//
-template <class SRC, class DST>
-bool TryCastDecimalToNumeric(SRC input, DST &result, CastParameters &parameters, uint8_t scale) {
-	// Round away from 0.
-	const auto power = NumericHelper::POWERS_OF_TEN[scale];
-	// https://graphics.stanford.edu/~seander/bithacks.html#ConditionalNegate
-	const auto fNegate = int64_t(input < 0);
-	const auto rounding = ((power ^ -fNegate) + fNegate) / 2;
-	const auto scaled_value = (input + rounding) / power;
-	if (!TryCast::Operation(UnsafeNumericCast<SRC>(scaled_value), result)) {
-		string error = StringUtil::Format("Failed to cast decimal value %d to type %s", scaled_value, GetTypeId<DST>());
-		HandleCastError::AssignError(error, parameters);
-		return false;
-	}
-	return true;
-}
-
-template <class DST>
-bool TryCastHugeDecimalToNumeric(hugeint_t input, DST &result, CastParameters &parameters, uint8_t scale) {
-	const auto power = Hugeint::POWERS_OF_TEN[scale];
-	const auto rounding = ((input < 0) ? -power : power) / 2;
-	auto scaled_value = (input + rounding) / power;
-	if (!TryCast::Operation(scaled_value, result)) {
-		string error = StringUtil::Format("Failed to cast decimal value %s to type %s",
-		                                  ConvertToString::Operation(scaled_value), GetTypeId<DST>());
-		HandleCastError::AssignError(error, parameters);
-		return false;
-	}
-	return true;
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> int8_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, int8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, int8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, int8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, int8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> int16_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, int16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> int32_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, int32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> int64_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, int64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> uint8_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, uint8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, uint8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, uint8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, uint8_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> uint16_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, uint16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, uint16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, uint16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, uint16_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> uint32_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, uint32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, uint32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, uint32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, uint32_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> uint64_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, uint64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, uint64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, uint64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, uint64_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> hugeint_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, hugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Cast Decimal -> uhugeint_t
-//===--------------------------------------------------------------------===//
-template <>
-bool TryCastFromDecimal::Operation(int16_t input, uhugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int32_t input, uhugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(int64_t input, uhugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastDecimalToNumeric(input, result, parameters, scale);
-}
-template <>
-bool TryCastFromDecimal::Operation(hugeint_t input, uhugeint_t &result, CastParameters &parameters, uint8_t width, uint8_t scale) {
-	return TryCastHugeDecimalToNumeric(input, result, parameters, scale);
-}
-
-//===--------------------------------------------------------------------===//
-// Decimal -> Float/Double Cast
-//===--------------------------------------------------------------------===//
-template <class SRC, class DST>
-static bool IsRepresentableExactly(SRC input, DST);
-
-template <>
-bool IsRepresentableExactly(int16_t input, float dst) {
-	return true;
-}
-
-const int64_t MAX_INT_REPRESENTABLE_IN_FLOAT = 0x001000000LL;
-const int64_t MAX_INT_REPRESENTABLE_IN_DOUBLE = 0x0020000000000000LL;
-
-template <>
-bool IsRepresentableExactly(int32_t input, float dst) {
-	return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT);
-}
-
-template <>
-bool IsRepresentableExactly(int64_t input, float dst) {
-	return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT);
-}
-
-template <>
-bool IsRepresentableExactly(hugeint_t input, float dst) {
-	return (input <= MAX_INT_REPRESENTABLE_IN_FLOAT && input >= -MAX_INT_REPRESENTABLE_IN_FLOAT);
-}
-
-template <>
-bool IsRepresentableExactly(int16_t input, double dst) {
-	return true;
-}
-
-template <>
-bool IsRepresentableExactly(int32_t input, double dst) {
-	return true;
-}
-
-template <>
-bool IsRepresentableExactly(int64_t input, double dst) {
-	return (input <= MAX_INT_REPRESENTABLE_IN_DOUBLE && input >= -MAX_INT_REPRESENTABLE_IN_DOUBLE);
-}
-
-template <>
-bool IsRepresentableExactly(hugeint_t input, double dst) {
-	return (input <= MAX_INT_REPRESENTABLE_IN_DOUBLE && input >= -MAX_INT_REPRESENTABLE_IN_DOUBLE);
-}
-
-template <class SRC>
-static SRC GetPowerOfTen(SRC input, uint8_t scale) {
-	return static_cast<SRC>(NumericHelper::POWERS_OF_TEN[scale]);
-}
-
-template <>
-hugeint_t GetPowerOfTen(hugeint_t input, uint8_t scale) {
-	return Hugeint::POWERS_OF_TEN[scale];
-}
-
-template <class SRC>
-static void GetDivMod(SRC lhs, SRC rhs, SRC &div, SRC &mod) {
-	div = lhs / rhs;
-	mod = lhs % rhs;
-}
-
-template <>
-void GetDivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &div, hugeint_t &mod) {
-	div = Hugeint::DivMod(lhs, rhs, mod);
-}
-
-template <class SRC, class DST>
-bool TryCastDecimalToFloatingPoint(SRC input, DST &result, uint8_t scale) {
-	if (IsRepresentableExactly(input, DST(0.0)) || scale == 0) {
-		// Fast path, integer is representable exactly as a float/double
-		result = Cast::Operation<SRC, DST>(input) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]);
-		return true;
-	}
-	auto power_of_ten = GetPowerOfTen(input, scale);
-
-	SRC div = 0;
-	SRC mod = 0;
-
GetDivMod(input, power_of_ten, div, mod); - - result = Cast::Operation(div) + - Cast::Operation(mod) / DST(NumericHelper::DOUBLE_POWERS_OF_TEN[scale]); - return true; -} - -// DECIMAL -> FLOAT -template <> -bool TryCastFromDecimal::Operation(int16_t input, float &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int32_t input, float &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int64_t input, float &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, float &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -// DECIMAL -> DOUBLE -template <> -bool TryCastFromDecimal::Operation(int16_t input, double &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int32_t input, double &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(int64_t input, double &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -template <> -bool TryCastFromDecimal::Operation(hugeint_t input, double &result, CastParameters ¶meters, uint8_t width, - uint8_t scale) { - return TryCastDecimalToFloatingPoint(input, result, scale); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/operator/convert_to_string.cpp b/src/duckdb/src/common/operator/convert_to_string.cpp deleted file mode 100644 index 758717759..000000000 --- a/src/duckdb/src/common/operator/convert_to_string.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/common/operator/convert_to_string.hpp" -#include "duckdb/common/operator/string_cast.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -template -string StandardStringCast(T input) { - Vector v(LogicalType::VARCHAR); - return StringCast::Operation(input, v).GetString(); -} - -template <> -string ConvertToString::Operation(bool input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int8_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int16_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int32_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(int64_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint8_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint16_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint32_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(uint64_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(hugeint_t input) { - return StandardStringCast(input); -} -template <> -string 
ConvertToString::Operation(uhugeint_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(float input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(double input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(interval_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(date_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(dtime_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(timestamp_t input) { - return StandardStringCast(input); -} -template <> -string ConvertToString::Operation(string_t input) { - return input.GetString(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/operator/string_cast.cpp b/src/duckdb/src/common/operator/string_cast.cpp deleted file mode 100644 index e152c71bf..000000000 --- a/src/duckdb/src/common/operator/string_cast.cpp +++ /dev/null @@ -1,303 +0,0 @@ -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/operator/string_cast.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Cast Numeric -> String -//===--------------------------------------------------------------------===// -template <> -string_t StringCast::Operation(bool input, Vector &vector) { - if (input) { - return StringVector::AddString(vector, "true", 4); - } else { - return StringVector::AddString(vector, "false", 5); - } -} - -template <> -string_t StringCast::Operation(int8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(int16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -string_t StringCast::Operation(int32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(int64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint8_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint16_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint32_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(uint64_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} -template <> -duckdb::string_t StringCast::Operation(hugeint_t input, Vector &vector) { - return NumericHelper::FormatSigned(input, vector); -} - -template <> -string_t StringCast::Operation(float input, Vector &vector) { - std::string s = duckdb_fmt::format("{}", input); - return StringVector::AddString(vector, s); -} - -template <> -string_t StringCast::Operation(double input, Vector &vector) { - std::string s = duckdb_fmt::format("{}", input); - return StringVector::AddString(vector, s); -} - -template <> -string_t 
StringCast::Operation(interval_t input, Vector &vector) { - char buffer[70] = {}; - idx_t length = IntervalToStringCast::Format(input, buffer); - return StringVector::AddString(vector, buffer, length); -} - -template <> -duckdb::string_t StringCast::Operation(uhugeint_t input, Vector &vector) { - return UhugeintToStringCast::Format(input, vector); -} - -template <> -duckdb::string_t StringCast::Operation(date_t input, Vector &vector) { - if (input == date_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } else if (input == date_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - int32_t date[3]; - Date::Convert(input, date[0], date[1], date[2]); - - idx_t year_length; - bool add_bc; - idx_t length = DateToStringCast::Length(date, year_length, add_bc); - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - DateToStringCast::Format(data, date, year_length, add_bc); - - result.Finalize(); - return result; -} - -template <> -duckdb::string_t StringCast::Operation(dtime_t input, Vector &vector) { - int32_t time[4]; - Time::Convert(input, time[0], time[1], time[2], time[3]); - - char micro_buffer[10] = {}; - idx_t length = TimeToStringCast::Length(time, micro_buffer); - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - TimeToStringCast::Format(data, length, time, micro_buffer); - - result.Finalize(); - return result; -} - -template -duckdb::string_t StringFromTimestamp(timestamp_t input, Vector &vector) { - if (input == timestamp_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } - if (input == timestamp_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - - date_t date_entry; - dtime_t time_entry; - int32_t picos = 0; - if (HAS_NANOS) { - timestamp_ns_t ns; - ns.value = input.value; - Timestamp::Convert(ns, date_entry, time_entry, picos); - // Use picoseconds so we have 6 digits - picos *= 1000; - } else { - Timestamp::Convert(input, date_entry, time_entry); - } - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], time[3]); - - // format for timestamp is DATE TIME (separated by space) - idx_t year_length; - bool add_bc; - char micro_buffer[6] = {}; - char nano_buffer[6] = {}; - idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); - idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - idx_t nano_length = 0; - if (picos) { - // If there are ps, we need all the µs - time_length = 15; - nano_length = 6; - nano_length -= NumericCast(TimeToStringCast::FormatMicros(picos, nano_buffer)); - } - const idx_t length = date_length + 1 + time_length + nano_length; - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - DateToStringCast::Format(data, date, year_length, add_bc); - data += date_length; - *data++ = ' '; - TimeToStringCast::Format(data, time_length, time, micro_buffer); - data += time_length; - memcpy(data, nano_buffer, nano_length); - D_ASSERT(data + nano_length <= result.GetDataWriteable() + length); - - result.Finalize(); - return result; -} - -template <> -duckdb::string_t StringCast::Operation(timestamp_t input, Vector &vector) { - return StringFromTimestamp(input, vector); -} - -template <> -duckdb::string_t StringCast::Operation(timestamp_ns_t input, Vector &vector) { - return StringFromTimestamp(input, 
vector); -} - -template <> -duckdb::string_t StringCast::Operation(duckdb::string_t input, Vector &result) { - return StringVector::AddStringOrBlob(result, input); -} - -template <> -string_t StringCastTZ::Operation(dtime_tz_t input, Vector &vector) { - int32_t time[4]; - Time::Convert(input.time(), time[0], time[1], time[2], time[3]); - - char micro_buffer[10] = {}; - const auto time_length = TimeToStringCast::Length(time, micro_buffer); - idx_t length = time_length; - - const auto offset = input.offset(); - const bool negative = (offset < 0); - ++length; - - auto ss = std::abs(offset); - const auto hh = ss / Interval::SECS_PER_HOUR; - - const auto hh_length = UnsafeNumericCast((hh < 100) ? 2 : NumericHelper::UnsignedLength(uint32_t(hh))); - length += hh_length; - - ss %= Interval::SECS_PER_HOUR; - const auto mm = ss / Interval::SECS_PER_MINUTE; - if (mm) { - length += 3; - } - - ss %= Interval::SECS_PER_MINUTE; - if (ss) { - length += 3; - } - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); - pos += time_length; - - data[pos++] = negative ? '-' : '+'; - if (hh < 100) { - TimeToStringCast::FormatTwoDigits(data + pos, hh); - } else { - NumericHelper::FormatUnsigned(hh, data + pos + hh_length); - } - pos += hh_length; - - if (mm) { - data[pos++] = ':'; - TimeToStringCast::FormatTwoDigits(data + pos, mm); - pos += 2; - } - - if (ss) { - data[pos++] = ':'; - TimeToStringCast::FormatTwoDigits(data + pos, ss); - pos += 2; - } - - result.Finalize(); - return result; -} - -template <> -string_t StringCastTZ::Operation(timestamp_t input, Vector &vector) { - if (input == timestamp_t::infinity()) { - return StringVector::AddString(vector, Date::PINF); - } - if (input == timestamp_t::ninfinity()) { - return StringVector::AddString(vector, Date::NINF); - } - - date_t date_entry; - dtime_t time_entry; - Timestamp::Convert(input, date_entry, time_entry); - - int32_t date[3], time[4]; - Date::Convert(date_entry, date[0], date[1], date[2]); - Time::Convert(time_entry, time[0], time[1], time[2], time[3]); - - // format for timestamptz is DATE TIME+00 (separated by space) - idx_t year_length; - bool add_bc; - char micro_buffer[6] = {}; - const idx_t date_length = DateToStringCast::Length(date, year_length, add_bc); - const idx_t time_length = TimeToStringCast::Length(time, micro_buffer); - const idx_t length = date_length + 1 + time_length + 3; - - string_t result = StringVector::EmptyString(vector, length); - auto data = result.GetDataWriteable(); - - idx_t pos = 0; - DateToStringCast::Format(data + pos, date, year_length, add_bc); - pos += date_length; - data[pos++] = ' '; - TimeToStringCast::Format(data + pos, time_length, time, micro_buffer); - pos += time_length; - data[pos++] = '+'; - data[pos++] = '0'; - data[pos++] = '0'; - - result.Finalize(); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/pipe_file_system.cpp b/src/duckdb/src/common/pipe_file_system.cpp deleted file mode 100644 index 3345e4987..000000000 --- a/src/duckdb/src/common/pipe_file_system.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "duckdb/common/pipe_file_system.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/numeric_utils.hpp" - -namespace duckdb { -class PipeFile : public FileHandle { -public: - explicit PipeFile(unique_ptr child_handle_p) - : 
FileHandle(pipe_fs, child_handle_p->path, child_handle_p->GetFlags()), - child_handle(std::move(child_handle_p)) { - } - - PipeFileSystem pipe_fs; - unique_ptr child_handle; - -public: - int64_t ReadChunk(void *buffer, int64_t nr_bytes); - int64_t WriteChunk(void *buffer, int64_t nr_bytes); - - void Close() override { - } -}; - -int64_t PipeFile::ReadChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Read(buffer, UnsafeNumericCast(nr_bytes)); -} -int64_t PipeFile::WriteChunk(void *buffer, int64_t nr_bytes) { - return child_handle->Write(buffer, UnsafeNumericCast(nr_bytes)); -} - -void PipeFileSystem::Reset(FileHandle &handle) { - throw InternalException("Cannot reset pipe file system"); -} - -int64_t PipeFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &pipe = handle.Cast(); - return pipe.ReadChunk(buffer, nr_bytes); -} - -int64_t PipeFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &pipe = handle.Cast(); - return pipe.WriteChunk(buffer, nr_bytes); -} - -int64_t PipeFileSystem::GetFileSize(FileHandle &handle) { - return 0; -} - -void PipeFileSystem::FileSync(FileHandle &handle) { -} - -unique_ptr PipeFileSystem::OpenPipe(unique_ptr handle) { - return make_uniq(std::move(handle)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/printer.cpp b/src/duckdb/src/common/printer.cpp deleted file mode 100644 index 0c704b74d..000000000 --- a/src/duckdb/src/common/printer.cpp +++ /dev/null @@ -1,81 +0,0 @@ -#include "duckdb/common/printer.hpp" -#include "duckdb/common/progress_bar/progress_bar.hpp" -#include "duckdb/common/windows_util.hpp" -#include "duckdb/common/windows.hpp" -#include - -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS -#include -#else -#include -#include -#include -#endif -#endif - -namespace duckdb { - -void Printer::RawPrint(OutputStream stream, const string &str) { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - if (IsTerminal(stream)) { - // print utf8 to terminal - auto unicode = WindowsUtil::UTF8ToMBCS(str.c_str()); - fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", unicode.c_str()); - return; - } -#endif - fprintf(stream == OutputStream::STREAM_STDERR ? stderr : stdout, "%s", str.c_str()); -#endif -} - -// LCOV_EXCL_START -void Printer::Print(OutputStream stream, const string &str) { - Printer::RawPrint(stream, str); - Printer::RawPrint(stream, "\n"); -} -void Printer::Flush(OutputStream stream) { -#ifndef DUCKDB_DISABLE_PRINT - fflush(stream == OutputStream::STREAM_STDERR ? stderr : stdout); -#endif -} - -void Printer::Print(const string &str) { - Printer::Print(OutputStream::STREAM_STDERR, str); -} - -bool Printer::IsTerminal(OutputStream stream) { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - auto stream_handle = stream == OutputStream::STREAM_STDERR ? STD_ERROR_HANDLE : STD_OUTPUT_HANDLE; - return GetFileType(GetStdHandle(stream_handle)) == FILE_TYPE_CHAR; -#else - return isatty(stream == OutputStream::STREAM_STDERR ? 
2 : 1); -#endif -#else - throw InternalException("IsTerminal called while printing is disabled"); -#endif -} - -idx_t Printer::TerminalWidth() { -#ifndef DUCKDB_DISABLE_PRINT -#ifdef DUCKDB_WINDOWS - CONSOLE_SCREEN_BUFFER_INFO csbi; - int rows; - - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - rows = csbi.srWindow.Right - csbi.srWindow.Left + 1; - return rows; -#else - struct winsize w; - ioctl(0, TIOCGWINSZ, &w); - return w.ws_col; -#endif -#else - throw InternalException("TerminalWidth called while printing is disabled"); -#endif -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/progress_bar/progress_bar.cpp b/src/duckdb/src/common/progress_bar/progress_bar.cpp deleted file mode 100644 index 4a3de66fb..000000000 --- a/src/duckdb/src/common/progress_bar/progress_bar.cpp +++ /dev/null @@ -1,158 +0,0 @@ -#include "duckdb/common/progress_bar/progress_bar.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp" - -namespace duckdb { - -void QueryProgress::Initialize() { - percentage = -1; - rows_processed = 0; - total_rows_to_process = 0; -} - -void QueryProgress::Restart() { - percentage = 0; - rows_processed = 0; - total_rows_to_process = 0; -} - -double QueryProgress::GetPercentage() { - return percentage; -} -uint64_t QueryProgress::GetRowsProcesseed() { - return rows_processed; -} -uint64_t QueryProgress::GetTotalRowsToProcess() { - return total_rows_to_process; -} - -QueryProgress::QueryProgress() { - Initialize(); -} - -QueryProgress &QueryProgress::operator=(const QueryProgress &other) { - if (this != &other) { - percentage = other.percentage.load(); - rows_processed = other.rows_processed.load(); - total_rows_to_process = other.total_rows_to_process.load(); - } - return *this; -} - -QueryProgress::QueryProgress(const QueryProgress &other) { - percentage = other.percentage.load(); - rows_processed = other.rows_processed.load(); - total_rows_to_process = other.total_rows_to_process.load(); -} - -void ProgressBar::SystemOverrideCheck(ClientConfig &config) { - if (config.system_progress_bar_disable_reason != nullptr) { - throw InvalidInputException("Could not change the progress bar setting because: '%s'", - config.system_progress_bar_disable_reason); - } -} - -unique_ptr ProgressBar::DefaultProgressBarDisplay() { - return make_uniq(); -} - -ProgressBar::ProgressBar(Executor &executor, idx_t show_progress_after, - progress_bar_display_create_func_t create_display_func) - : executor(executor), show_progress_after(show_progress_after) { - if (create_display_func) { - display = create_display_func(); - } -} - -QueryProgress ProgressBar::GetDetailedQueryProgress() { - return query_progress; -} - -void ProgressBar::Start() { - profiler.Start(); - query_progress.Initialize(); - supported = true; -} - -bool ProgressBar::PrintEnabled() const { - return display != nullptr; -} - -bool ProgressBar::ShouldPrint(bool final) const { - if (!PrintEnabled()) { - // Don't print progress at all - return false; - } - if (!supported) { - return false; - } - - double elapsed_time = -1.0; - if (elapsed_time < 0.0) { - elapsed_time = profiler.Elapsed(); - } - - auto sufficient_time_elapsed = elapsed_time > static_cast(show_progress_after) / 1000.0; - if (!sufficient_time_elapsed) { - // Don't print yet - return false; - } - if (final) { - // Print the last completed bar - return true; - } - return query_progress.percentage > -1; -} - -void ProgressBar::Update(bool final) { - if (!final 
&& !supported) { - return; - } - - ProgressData progress; - idx_t invalid_pipelines = executor.GetPipelinesProgress(progress); - - double new_percentage = 0.0; - if (invalid_pipelines == 0 && progress.IsValid()) { - if (progress.total > 1e15) { - progress.Normalize(1e15); - } - query_progress.rows_processed = idx_t(progress.done); - query_progress.total_rows_to_process = idx_t(progress.total); - new_percentage = progress.ProgressDone() * 100; - } - - if (!final && invalid_pipelines > 0) { - return; - } - - if (new_percentage > query_progress.percentage) { - query_progress.percentage = new_percentage; - } - if (ShouldPrint(final)) { - if (final) { - FinishProgressBarPrint(); - } else { - PrintProgress(LossyNumericCast(query_progress.percentage.load())); - } - } -} - -void ProgressBar::PrintProgress(int current_percentage_p) { - D_ASSERT(display); - display->Update(current_percentage_p); -} - -void ProgressBar::FinishProgressBarPrint() { - if (finished) { - return; - } - D_ASSERT(display); - display->Finish(); - finished = true; - if (query_progress.percentage == 0) { - query_progress.Initialize(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp b/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp deleted file mode 100644 index 912b8ccce..000000000 --- a/src/duckdb/src/common/progress_bar/terminal_progress_bar_display.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include "duckdb/common/progress_bar/display/terminal_progress_bar_display.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/to_string.hpp" - -namespace duckdb { - -int32_t TerminalProgressBarDisplay::NormalizePercentage(double percentage) { - if (percentage > 100) { - return 100; - } - if (percentage < 0) { - return 0; - } - return int32_t(percentage); -} - -void TerminalProgressBarDisplay::PrintProgressInternal(int32_t percentage) { - string result; - // we divide the number of blocks by the percentage - // 0% = 0 - // 100% = PROGRESS_BAR_WIDTH - // the percentage determines how many blocks we need to draw - double blocks_to_draw = PROGRESS_BAR_WIDTH * (percentage / 100.0); - // because of the power of unicode, we can also draw partial blocks - - // render the percentage with some padding to ensure everything stays nicely aligned - result = "\r"; - if (percentage < 100) { - result += " "; - } - if (percentage < 10) { - result += " "; - } - result += to_string(percentage) + "%"; - result += " "; - result += PROGRESS_START; - idx_t i; - for (i = 0; i < idx_t(blocks_to_draw); i++) { - result += PROGRESS_BLOCK; - } - if (i < PROGRESS_BAR_WIDTH) { - // print a partial block based on the percentage of the progress bar remaining - idx_t index = idx_t((blocks_to_draw - static_cast(idx_t(blocks_to_draw))) * PARTIAL_BLOCK_COUNT); - if (index >= PARTIAL_BLOCK_COUNT) { - index = PARTIAL_BLOCK_COUNT - 1; - } - result += PROGRESS_PARTIAL[index]; - i++; - } - for (; i < PROGRESS_BAR_WIDTH; i++) { - result += PROGRESS_EMPTY; - } - result += PROGRESS_END; - result += " "; - - Printer::RawPrint(OutputStream::STREAM_STDOUT, result); -} - -void TerminalProgressBarDisplay::Update(double percentage) { - auto percentage_int = NormalizePercentage(percentage); - if (percentage_int == rendered_percentage) { - return; - } - PrintProgressInternal(percentage_int); - Printer::Flush(OutputStream::STREAM_STDOUT); - rendered_percentage = percentage_int; -} - -void TerminalProgressBarDisplay::Finish() { - PrintProgressInternal(100); - 
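// Render a final 100% bar before the newline: Update() skips redraws when the integer percentage is unchanged, so without this the bar could end below 100.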
Printer::RawPrint(OutputStream::STREAM_STDOUT, "\n"); - Printer::Flush(OutputStream::STREAM_STDOUT); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/radix_partitioning.cpp b/src/duckdb/src/common/radix_partitioning.cpp deleted file mode 100644 index 8b091a803..000000000 --- a/src/duckdb/src/common/radix_partitioning.cpp +++ /dev/null @@ -1,248 +0,0 @@ -#include "duckdb/common/radix_partitioning.hpp" - -#include "duckdb/common/types/column/partitioned_column_data.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" - -namespace duckdb { - -//! Templated radix partitioning constants, can be templated to the number of radix bits -template <idx_t radix_bits> -struct RadixPartitioningConstants { -public: - //! Bitmask of the upper bits starting at the 5th byte - static constexpr idx_t NUM_PARTITIONS = RadixPartitioning::NumberOfPartitions(radix_bits); - static constexpr idx_t SHIFT = RadixPartitioning::Shift(radix_bits); - static constexpr hash_t MASK = RadixPartitioning::Mask(radix_bits); - -public: - //! Apply bitmask and right shift to get a number between 0 and NUM_PARTITIONS - static hash_t ApplyMask(const hash_t hash) { - D_ASSERT((hash & MASK) >> SHIFT < NUM_PARTITIONS); - return (hash & MASK) >> SHIFT; - } -}; - -//! Dispatch a runtime radix_bits value to the matching compile-time template parameter -template <class OP, class RETURN_TYPE, typename... ARGS> -RETURN_TYPE RadixBitsSwitch(const idx_t radix_bits, ARGS &&... args) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - switch (radix_bits) { - case 0: - return OP::template Operation<0>(std::forward<ARGS>(args)...); - case 1: - return OP::template Operation<1>(std::forward<ARGS>(args)...); - case 2: - return OP::template Operation<2>(std::forward<ARGS>(args)...); - case 3: - return OP::template Operation<3>(std::forward<ARGS>(args)...); - case 4: - return OP::template Operation<4>(std::forward<ARGS>(args)...); - case 5: // LCOV_EXCL_START - return OP::template Operation<5>(std::forward<ARGS>(args)...); - case 6: - return OP::template Operation<6>(std::forward<ARGS>(args)...); - case 7: - return OP::template Operation<7>(std::forward<ARGS>(args)...); - case 8: - return OP::template Operation<8>(std::forward<ARGS>(args)...); - case 9: - return OP::template Operation<9>(std::forward<ARGS>(args)...); - case 10: - return OP::template Operation<10>(std::forward<ARGS>(args)...); - case 11: - return OP::template Operation<11>(std::forward<ARGS>(args)...); - case 12: - return OP::template Operation<12>(std::forward<ARGS>(args)...); - default: - throw InternalException( - "radix_bits higher than RadixPartitioning::MAX_RADIX_BITS encountered in RadixBitsSwitch"); - } // LCOV_EXCL_STOP -} - -struct SelectFunctor { - template <idx_t radix_bits> - static idx_t Operation(Vector &hashes, const SelectionVector *sel, const idx_t count, - const ValidityMask &partition_mask, SelectionVector *true_sel, SelectionVector *false_sel) { - using CONSTANTS = RadixPartitioningConstants<radix_bits>; - return UnaryExecutor::Select<hash_t>( - hashes, sel, count, - [&](const hash_t hash) { - const auto partition_idx = CONSTANTS::ApplyMask(hash); - return partition_mask.RowIsValidUnsafe(partition_idx); - }, - true_sel, false_sel); - } -}; - -idx_t RadixPartitioning::Select(Vector &hashes, const SelectionVector *sel, const idx_t count, const idx_t radix_bits, - const ValidityMask &partition_mask, SelectionVector *true_sel, - SelectionVector *false_sel) { - return RadixBitsSwitch<SelectFunctor, idx_t>(radix_bits, hashes, sel, count, partition_mask, true_sel, false_sel); -} - -struct ComputePartitionIndicesFunctor { - template <idx_t radix_bits> - static void Operation(Vector &hashes, Vector &partition_indices, const 
SelectionVector &append_sel, - const idx_t append_count) { - using CONSTANTS = RadixPartitioningConstants; - if (append_sel.IsSet()) { - auto hashes_sliced = Vector(hashes, append_sel, append_count); - UnaryExecutor::Execute(hashes_sliced, partition_indices, append_count, - [&](hash_t hash) { return CONSTANTS::ApplyMask(hash); }); - } else { - UnaryExecutor::Execute(hashes, partition_indices, append_count, - [&](hash_t hash) { return CONSTANTS::ApplyMask(hash); }); - } - } -}; - -//===--------------------------------------------------------------------===// -// Column Data Partitioning -//===--------------------------------------------------------------------===// -RadixPartitionedColumnData::RadixPartitionedColumnData(ClientContext &context_p, vector types_p, - idx_t radix_bits_p, idx_t hash_col_idx_p) - : PartitionedColumnData(PartitionedColumnDataType::RADIX, context_p, std::move(types_p)), radix_bits(radix_bits_p), - hash_col_idx(hash_col_idx_p) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - D_ASSERT(hash_col_idx < types.size()); - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - allocators->allocators.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - CreateAllocator(); - allocators->allocators.back()->SetPartitionIndex(i); - } - D_ASSERT(allocators->allocators.size() == num_partitions); -} - -RadixPartitionedColumnData::RadixPartitionedColumnData(const RadixPartitionedColumnData &other) - : PartitionedColumnData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { - for (idx_t i = 0; i < RadixPartitioning::NumberOfPartitions(radix_bits); i++) { - partitions.emplace_back(CreatePartitionCollection(i)); - } -} - -RadixPartitionedColumnData::~RadixPartitionedColumnData() { -} - -void RadixPartitionedColumnData::InitializeAppendStateInternal(PartitionedColumnDataAppendState &state) const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - state.partition_append_states.reserve(num_partitions); - state.partition_buffers.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - state.partition_append_states.emplace_back(make_uniq()); - partitions[i]->InitializeAppend(*state.partition_append_states[i]); - state.partition_buffers.emplace_back(CreatePartitionBuffer()); - } - - // Initialize fixed-size map - state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); -} - -void RadixPartitionedColumnData::ComputePartitionIndices(PartitionedColumnDataAppendState &state, DataChunk &input) { - D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - D_ASSERT(state.partition_buffers.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - *FlatVector::IncrementalSelectionVector(), input.size()); -} - -//===--------------------------------------------------------------------===// -// Tuple Data Partitioning -//===--------------------------------------------------------------------===// -RadixPartitionedTupleData::RadixPartitionedTupleData(BufferManager &buffer_manager, const TupleDataLayout &layout_p, - const idx_t radix_bits_p, const idx_t hash_col_idx_p) - : PartitionedTupleData(PartitionedTupleDataType::RADIX, buffer_manager, layout_p.Copy()), radix_bits(radix_bits_p), - hash_col_idx(hash_col_idx_p) { - D_ASSERT(radix_bits <= RadixPartitioning::MAX_RADIX_BITS); - D_ASSERT(hash_col_idx < layout.GetTypes().size()); - 
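// One allocator per partition keeps each partition's blocks separate, so a finished partition can later be finalized and unpinned without touching the others.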
const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - allocators->allocators.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - CreateAllocator(); - } - D_ASSERT(allocators->allocators.size() == num_partitions); - Initialize(); -} - -RadixPartitionedTupleData::RadixPartitionedTupleData(const RadixPartitionedTupleData &other) - : PartitionedTupleData(other), radix_bits(other.radix_bits), hash_col_idx(other.hash_col_idx) { - Initialize(); -} - -RadixPartitionedTupleData::~RadixPartitionedTupleData() { -} - -void RadixPartitionedTupleData::Initialize() { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - for (idx_t i = 0; i < num_partitions; i++) { - partitions.emplace_back(CreatePartitionCollection(i)); - partitions.back()->SetPartitionIndex(i); - } -} - -void RadixPartitionedTupleData::InitializeAppendStateInternal(PartitionedTupleDataAppendState &state, - const TupleDataPinProperties properties) const { - // Init pin state per partition - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - state.partition_pin_states.reserve(num_partitions); - for (idx_t i = 0; i < num_partitions; i++) { - state.partition_pin_states.emplace_back(make_unsafe_uniq()); - partitions[i]->InitializeAppend(*state.partition_pin_states[i], properties); - } - - // Init single chunk state - auto column_count = layout.ColumnCount(); - vector column_ids; - column_ids.reserve(column_count); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - column_ids.emplace_back(col_idx); - } - partitions[0]->InitializeChunkState(state.chunk_state, std::move(column_ids)); - - // Initialize fixed-size map - state.fixed_partition_entries.resize(RadixPartitioning::NumberOfPartitions(radix_bits)); -} - -void RadixPartitionedTupleData::ComputePartitionIndices(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel, const idx_t append_count) { - D_ASSERT(partitions.size() == RadixPartitioning::NumberOfPartitions(radix_bits)); - RadixBitsSwitch(radix_bits, input.data[hash_col_idx], state.partition_indices, - append_sel, append_count); -} - -void RadixPartitionedTupleData::ComputePartitionIndices(Vector &row_locations, idx_t count, - Vector &partition_indices) const { - Vector intermediate(LogicalType::HASH); - partitions[0]->Gather(row_locations, *FlatVector::IncrementalSelectionVector(), count, hash_col_idx, intermediate, - *FlatVector::IncrementalSelectionVector(), nullptr); - RadixBitsSwitch(radix_bits, intermediate, partition_indices, - *FlatVector::IncrementalSelectionVector(), count); -} - -void RadixPartitionedTupleData::RepartitionFinalizeStates(PartitionedTupleData &old_partitioned_data, - PartitionedTupleData &new_partitioned_data, - PartitionedTupleDataAppendState &state, - idx_t finished_partition_idx) const { - D_ASSERT(old_partitioned_data.GetType() == PartitionedTupleDataType::RADIX && - new_partitioned_data.GetType() == PartitionedTupleDataType::RADIX); - const auto &old_radix_partitions = old_partitioned_data.Cast(); - const auto &new_radix_partitions = new_partitioned_data.Cast(); - const auto old_radix_bits = old_radix_partitions.GetRadixBits(); - const auto new_radix_bits = new_radix_partitions.GetRadixBits(); - D_ASSERT(new_radix_bits > old_radix_bits); - - // We take the most significant digits as the partition index - // When repartitioning, e.g., partition 0 from "old" goes into the first N partitions in "new" - // When partition 0 is done, we can already 
finalize the append states, unpinning blocks - const auto multiplier = RadixPartitioning::NumberOfPartitions(new_radix_bits - old_radix_bits); - const auto from_idx = finished_partition_idx * multiplier; - const auto to_idx = from_idx + multiplier; - auto &partitions = new_partitioned_data.GetPartitions(); - for (idx_t partition_index = from_idx; partition_index < to_idx; partition_index++) { - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - partition.FinalizePinState(partition_pin_state); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/random_engine.cpp b/src/duckdb/src/common/random_engine.cpp deleted file mode 100644 index cf558ea7a..000000000 --- a/src/duckdb/src/common/random_engine.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "duckdb/common/random_engine.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "pcg_random.hpp" - -#ifdef __linux__ -#include -#include -#else -#include -#endif -namespace duckdb { - -struct RandomState { - RandomState() { - } - - pcg32 pcg; -}; - -RandomEngine::RandomEngine(int64_t seed) : random_state(make_uniq()) { - if (seed < 0) { -#ifdef __linux__ - idx_t random_seed = 0; - auto result = syscall(SYS_getrandom, &random_seed, sizeof(random_seed), 0); - if (result == -1) { - // Something went wrong with the syscall, we use chrono - const auto now = std::chrono::high_resolution_clock::now(); - random_seed = now.time_since_epoch().count(); - } - random_state->pcg.seed(random_seed); -#else - random_state->pcg.seed(pcg_extras::seed_seq_from()); -#endif - } else { - random_state->pcg.seed(NumericCast(seed)); - } -} - -RandomEngine::~RandomEngine() { -} - -double RandomEngine::NextRandom(double min, double max) { - D_ASSERT(max >= min); - return min + (NextRandom() * (max - min)); -} - -double RandomEngine::NextRandom() { - auto uint64 = NextRandomInteger64(); - return std::ldexp(uint64, -64); -} - -double RandomEngine::NextRandom32(double min, double max) { - D_ASSERT(max >= min); - return min + (NextRandom32() * (max - min)); -} - -double RandomEngine::NextRandom32() { - auto uint32 = NextRandomInteger(); - return std::ldexp(uint32, -32); -} - -uint32_t RandomEngine::NextRandomInteger() { - return random_state->pcg(); -} - -uint64_t RandomEngine::NextRandomInteger64() { - return (static_cast(NextRandomInteger()) << UINT64_C(32)) | static_cast(NextRandomInteger()); -} - -uint32_t RandomEngine::NextRandomInteger(uint32_t min, uint32_t max) { - return min + static_cast(NextRandom() * double(max - min)); -} - -uint32_t RandomEngine::NextRandomInteger32(uint32_t min, uint32_t max) { - return min + static_cast(NextRandom32() * double(max - min)); -} - -void RandomEngine::SetSeed(uint64_t seed) { - random_state->pcg.seed(seed); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/re2_regex.cpp b/src/duckdb/src/common/re2_regex.cpp deleted file mode 100644 index e934e105c..000000000 --- a/src/duckdb/src/common/re2_regex.cpp +++ /dev/null @@ -1,100 +0,0 @@ -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector.hpp" -#include - -#include "duckdb/common/re2_regex.hpp" -#include "re2/re2.h" - -namespace duckdb_re2 { - -static size_t GetMultibyteCharLength(const char c) { - if ((c & 0x80) == 0) { - return 1; // 1-byte character (ASCII) - } else if ((c & 0xE0) == 0xC0) { - return 2; // 2-byte character - } else if ((c & 0xF0) == 0xE0) { - return 3; // 3-byte character - } else if ((c & 0xF8) == 0xF0) { 
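// the 11110xxx leading-byte pattern (mask 0xF8) marks a 4-byte sequence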
- return 4; // 4-byte character - } else { - return 0; // invalid UTF-8 leading byte - } -} - -Regex::Regex(const std::string &pattern, RegexOptions options) { - RE2::Options o; - o.set_case_sensitive(options != RegexOptions::CASE_INSENSITIVE); - regex = duckdb::make_shared_ptr<RE2>(StringPiece(pattern), o); -} - -bool RegexSearchInternal(const char *input_data, size_t input_size, Match &match, const RE2 &regex, RE2::Anchor anchor, - size_t start, size_t end) { - duckdb::vector<StringPiece> target_groups; - auto group_count = duckdb::UnsafeNumericCast(regex.NumberOfCapturingGroups() + 1); - target_groups.resize(group_count); - match.groups.clear(); - if (!regex.Match(StringPiece(input_data, input_size), start, end, anchor, target_groups.data(), - duckdb::UnsafeNumericCast(group_count))) { - return false; - } - for (auto &group : target_groups) { - GroupMatch group_match; - group_match.text = group.ToString(); - group_match.position = group.data() != nullptr ? duckdb::NumericCast(group.data() - input_data) : 0; - match.groups.emplace_back(group_match); - } - return true; -} - -bool RegexSearch(const std::string &input, Match &match, const Regex &regex) { - auto input_sz = input.size(); - return RegexSearchInternal(input.c_str(), input_sz, match, regex.GetRegex(), RE2::UNANCHORED, 0, input_sz); -} - -bool RegexMatch(const std::string &input, Match &match, const Regex &regex) { - auto input_sz = input.size(); - return RegexSearchInternal(input.c_str(), input_sz, match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, input_sz); -} - -bool RegexMatch(const char *start, const char *end, Match &match, const Regex &regex) { - auto sz = duckdb::UnsafeNumericCast(end - start); - return RegexSearchInternal(start, sz, match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, sz); -} - -bool RegexMatch(const std::string &input, const Regex &regex) { - Match nop_match; - auto input_sz = input.size(); - return RegexSearchInternal(input.c_str(), input_sz, nop_match, regex.GetRegex(), RE2::ANCHOR_BOTH, 0, input_sz); -} - -duckdb::vector<Match> RegexFindAll(const std::string &input, const Regex &regex) { - return RegexFindAll(input.c_str(), input.size(), regex.GetRegex()); -} - -duckdb::vector<Match> RegexFindAll(const char *input_data, size_t input_size, const RE2 &regex) { - duckdb::vector<Match> matches; - size_t position = 0; - Match match; - while (RegexSearchInternal(input_data, input_size, match, regex, RE2::UNANCHORED, position, input_size)) { - if (match.length(0)) { - position = match.position(0) + match.length(0); - } else { // match.length(0) == 0 - auto next_char_length = GetMultibyteCharLength(input_data[match.position(0)]); - if (!next_char_length) { - throw duckdb::InvalidInputException("Invalid UTF-8 leading byte at position " + - std::to_string(match.position(0) + 1)); - } - if (match.position(0) + next_char_length < input_size) { - position = match.position(0) + next_char_length; - } else { - matches.emplace_back(match); - break; - } - } - matches.emplace_back(match); - } - return matches; -} - -} // namespace duckdb_re2 diff --git a/src/duckdb/src/common/render_tree.cpp b/src/duckdb/src/common/render_tree.cpp deleted file mode 100644 index 582d5e1ad..000000000 --- a/src/duckdb/src/common/render_tree.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include "duckdb/common/render_tree.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_delim_join.hpp" -#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" - -namespace duckdb { - -struct PipelineRenderNode { - explicit 
PipelineRenderNode(const PhysicalOperator &op) : op(op) { - } - - const PhysicalOperator &op; - unique_ptr child; -}; - -} // namespace duckdb - -namespace { - -using duckdb::MaxValue; -using duckdb::PhysicalDelimJoin; -using duckdb::PhysicalOperator; -using duckdb::PhysicalOperatorType; -using duckdb::PhysicalPositionalScan; -using duckdb::PipelineRenderNode; -using duckdb::RenderTreeNode; - -class TreeChildrenIterator { -public: - template - static bool HasChildren(const T &op) { - return !op.children.empty(); - } - template - static void Iterate(const T &op, const std::function &callback) { - for (auto &child : op.children) { - callback(*child); - } - } -}; - -template <> -bool TreeChildrenIterator::HasChildren(const PhysicalOperator &op) { - return !op.GetChildren().empty(); -} -template <> -void TreeChildrenIterator::Iterate(const PhysicalOperator &op, - const std::function &callback) { - for (auto &child : op.GetChildren()) { - callback(child); - } -} - -template <> -bool TreeChildrenIterator::HasChildren(const PipelineRenderNode &op) { - return op.child.get(); -} - -template <> -void TreeChildrenIterator::Iterate(const PipelineRenderNode &op, - const std::function &callback) { - if (op.child) { - callback(*op.child); - } -} - -} // namespace - -namespace duckdb { - -template -static void GetTreeWidthHeight(const T &op, idx_t &width, idx_t &height) { - if (!TreeChildrenIterator::HasChildren(op)) { - width = 1; - height = 1; - return; - } - width = 0; - height = 0; - - TreeChildrenIterator::Iterate(op, [&](const T &child) { - idx_t child_width, child_height; - GetTreeWidthHeight(child, child_width, child_height); - width += child_width; - height = MaxValue(height, child_height); - }); - height++; -} - -static unique_ptr CreateNode(const LogicalOperator &op) { - return make_uniq(op.GetName(), op.ParamsToString()); -} - -static unique_ptr CreateNode(const PhysicalOperator &op) { - return make_uniq(op.GetName(), op.ParamsToString()); -} - -static unique_ptr CreateNode(const PipelineRenderNode &op) { - return CreateNode(op.op); -} - -static unique_ptr CreateNode(const ProfilingNode &op) { - auto &info = op.GetProfilingInfo(); - InsertionOrderPreservingMap extra_info; - if (info.Enabled(info.settings, MetricsType::EXTRA_INFO)) { - extra_info = op.GetProfilingInfo().extra_info; - } - - string node_name = "QUERY"; - if (op.depth > 0) { - node_name = info.GetMetricAsString(MetricsType::OPERATOR_TYPE); - } - - auto result = make_uniq(node_name, extra_info); - if (info.Enabled(info.settings, MetricsType::OPERATOR_CARDINALITY)) { - auto cardinality = info.GetMetricAsString(MetricsType::OPERATOR_CARDINALITY); - result->extra_text[RenderTreeNode::CARDINALITY] = cardinality; - } - if (info.Enabled(info.settings, MetricsType::OPERATOR_TIMING)) { - auto value = info.metrics.at(MetricsType::OPERATOR_TIMING).GetValue(); - string timing = StringUtil::Format("%.2f", value); - result->extra_text[RenderTreeNode::TIMING] = timing + "s"; - } - return result; -} - -template -static idx_t CreateTreeRecursive(RenderTree &result, const T &op, idx_t x, idx_t y) { - auto node = CreateNode(op); - - if (!TreeChildrenIterator::HasChildren(op)) { - result.SetNode(x, y, std::move(node)); - return 1; - } - idx_t width = 0; - // render the children of this node - TreeChildrenIterator::Iterate(op, [&](const T &child) { - auto child_x = x + width; - auto child_y = y + 1; - node->AddChildPosition(child_x, child_y); - width += CreateTreeRecursive(result, child, child_x, child_y); - }); - result.SetNode(x, y, 
std::move(node)); - return width; -} - -template -static unique_ptr CreateTree(const T &op) { - idx_t width, height; - GetTreeWidthHeight(op, width, height); - - auto result = make_uniq(width, height); - - // now fill in the tree - CreateTreeRecursive(*result, op, 0, 0); - return result; -} - -RenderTree::RenderTree(idx_t width_p, idx_t height_p) : width(width_p), height(height_p) { - nodes = make_uniq_array>((width + 1) * (height + 1)); -} - -optional_ptr RenderTree::GetNode(idx_t x, idx_t y) { - if (x >= width || y >= height) { - return nullptr; - } - return nodes[GetPosition(x, y)].get(); -} - -bool RenderTree::HasNode(idx_t x, idx_t y) { - if (x >= width || y >= height) { - return false; - } - return nodes[GetPosition(x, y)].get() != nullptr; -} - -idx_t RenderTree::GetPosition(idx_t x, idx_t y) { - return y * width + x; -} - -void RenderTree::SetNode(idx_t x, idx_t y, unique_ptr node) { - nodes[GetPosition(x, y)] = std::move(node); -} - -unique_ptr RenderTree::CreateRenderTree(const LogicalOperator &op) { - return CreateTree(op); -} - -unique_ptr RenderTree::CreateRenderTree(const PhysicalOperator &op) { - return CreateTree(op); -} - -unique_ptr RenderTree::CreateRenderTree(const ProfilingNode &op) { - return CreateTree(op); -} - -void RenderTree::SanitizeKeyNames() { - for (idx_t i = 0; i < width * height; i++) { - if (!nodes[i]) { - continue; - } - InsertionOrderPreservingMap new_map; - for (auto &entry : nodes[i]->extra_text) { - auto key = entry.first; - if (StringUtil::StartsWith(key, "__")) { - key = StringUtil::Replace(key, "__", ""); - key = StringUtil::Replace(key, "_", " "); - key = StringUtil::Title(key); - } - auto &value = entry.second; - new_map.insert(make_pair(key, value)); - } - nodes[i]->extra_text = std::move(new_map); - } -} - -unique_ptr RenderTree::CreateRenderTree(const Pipeline &pipeline) { - auto operators = pipeline.GetOperators(); - D_ASSERT(!operators.empty()); - unique_ptr node; - for (auto &op : operators) { - auto new_node = make_uniq(op.get()); - new_node->child = std::move(node); - node = std::move(new_node); - } - return CreateTree(*node); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_aggregate.cpp b/src/duckdb/src/common/row_operations/row_aggregate.cpp deleted file mode 100644 index fb433eb5b..000000000 --- a/src/duckdb/src/common/row_operations/row_aggregate.cpp +++ /dev/null @@ -1,123 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_operations/row_aggregate.cpp -// -// -//===----------------------------------------------------------------------===// -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" - -namespace duckdb { - -void RowOperations::InitializeStates(TupleDataLayout &layout, Vector &addresses, const SelectionVector &sel, - idx_t count) { - if (count == 0) { - return; - } - auto pointers = FlatVector::GetData(addresses); - auto &offsets = layout.GetOffsets(); - auto aggr_idx = layout.ColumnCount(); - - for (const auto &aggr : layout.GetAggregates()) { - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = pointers[row_idx]; - aggr.function.initialize(aggr.function, row + offsets[aggr_idx]); - } - ++aggr_idx; - } -} - -void RowOperations::DestroyStates(RowOperationsState &state, 
TupleDataLayout &layout, Vector &addresses, idx_t count) { - if (count == 0) { - return; - } - // Move to the first aggregate state - VectorOperations::AddInPlace(addresses, UnsafeNumericCast(layout.GetAggrOffset()), count); - for (const auto &aggr : layout.GetAggregates()) { - if (aggr.function.destructor) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.destructor(addresses, aggr_input_data, count); - } - // Move to the next aggregate state - VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggr.payload_size), count); - } -} - -void RowOperations::UpdateStates(RowOperationsState &state, AggregateObject &aggr, Vector &addresses, - DataChunk &payload, idx_t arg_idx, idx_t count) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.update(aggr.child_count == 0 ? nullptr : &payload.data[arg_idx], aggr_input_data, aggr.child_count, - addresses, count); -} - -void RowOperations::UpdateFilteredStates(RowOperationsState &state, AggregateFilterData &filter_data, - AggregateObject &aggr, Vector &addresses, DataChunk &payload, idx_t arg_idx) { - idx_t count = filter_data.ApplyFilter(payload); - if (count == 0) { - return; - } - - Vector filtered_addresses(addresses, filter_data.true_sel, count); - filtered_addresses.Flatten(count); - - UpdateStates(state, aggr, filtered_addresses, filter_data.filtered_payload, arg_idx, count); -} - -void RowOperations::CombineStates(RowOperationsState &state, TupleDataLayout &layout, Vector &sources, Vector &targets, - idx_t count) { - if (count == 0) { - return; - } - - // Move to the first aggregate states - VectorOperations::AddInPlace(sources, UnsafeNumericCast(layout.GetAggrOffset()), count); - VectorOperations::AddInPlace(targets, UnsafeNumericCast(layout.GetAggrOffset()), count); - - // Keep track of the offset - idx_t offset = layout.GetAggrOffset(); - - for (auto &aggr : layout.GetAggregates()) { - D_ASSERT(aggr.function.combine); - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); - aggr.function.combine(sources, targets, aggr_input_data, count); - - // Move to the next aggregate states - VectorOperations::AddInPlace(sources, UnsafeNumericCast(aggr.payload_size), count); - VectorOperations::AddInPlace(targets, UnsafeNumericCast(aggr.payload_size), count); - - // Increment the offset - offset += aggr.payload_size; - } - - // Now subtract the offset to get back to the original position - VectorOperations::AddInPlace(sources, -UnsafeNumericCast(offset), count); - VectorOperations::AddInPlace(targets, -UnsafeNumericCast(offset), count); -} - -void RowOperations::FinalizeStates(RowOperationsState &state, TupleDataLayout &layout, Vector &addresses, - DataChunk &result, idx_t aggr_idx) { - // Copy the addresses - Vector addresses_copy(LogicalType::POINTER); - VectorOperations::Copy(addresses, addresses_copy, result.size(), 0, 0); - - // Move to the first aggregate state - VectorOperations::AddInPlace(addresses_copy, UnsafeNumericCast(layout.GetAggrOffset()), result.size()); - - auto &aggregates = layout.GetAggregates(); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &target = result.data[aggr_idx + i]; - auto &aggr = aggregates[i]; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), state.allocator); - aggr.function.finalize(addresses_copy, aggr_input_data, target, result.size(), 0); - - // Move to the next aggregate state - VectorOperations::AddInPlace(addresses_copy, 
UnsafeNumericCast(aggr.payload_size), result.size()); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_external.cpp b/src/duckdb/src/common/row_operations/row_external.cpp deleted file mode 100644 index 161fa7501..000000000 --- a/src/duckdb/src/common/row_operations/row_external.cpp +++ /dev/null @@ -1,164 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_operations/row_external.cpp -// -// -//===----------------------------------------------------------------------===// -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_layout.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -void RowOperations::SwizzleColumns(const RowLayout &layout, const data_ptr_t base_row_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Load heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = Load(heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string pointer with the within-row offset (if not inlined) - Store(UnsafeNumericCast(Load(string_ptr) - heap_row_ptrs[i]), - string_ptr); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data pointer with the within-row offset - Store(UnsafeNumericCast(Load(col_ptr) - heap_row_ptrs[i]), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -void RowOperations::SwizzleHeapPointer(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - const idx_t count, const idx_t base_offset) { - const idx_t row_width = layout.GetRowWidth(); - row_ptr += layout.GetHeapOffset(); - idx_t cumulative_offset = 0; - for (idx_t i = 0; i < count; i++) { - Store(base_offset + cumulative_offset, row_ptr); - cumulative_offset += Load(heap_base_ptr + cumulative_offset); - row_ptr += row_width; - } -} - -void RowOperations::CopyHeapAndSwizzle(const RowLayout &layout, data_ptr_t row_ptr, const data_ptr_t heap_base_ptr, - data_ptr_t heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - const auto heap_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - // Figure out source and size - const auto source_heap_ptr = Load(row_ptr + heap_offset); - const auto size = Load(source_heap_ptr); - D_ASSERT(size >= sizeof(uint32_t)); - - // Copy and swizzle - memcpy(heap_ptr, source_heap_ptr, size); - Store(UnsafeNumericCast(heap_ptr - heap_base_ptr), row_ptr + heap_offset); - - // Increment for next iteration - row_ptr += row_width; - heap_ptr += size; - } -} - -void 
RowOperations::UnswizzleHeapPointer(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const auto row_width = layout.GetRowWidth(); - data_ptr_t heap_ptr_ptr = base_row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - Store(base_heap_ptr + Load(heap_ptr_ptr), heap_ptr_ptr); - heap_ptr_ptr += row_width; - } -} - -static inline void VerifyUnswizzledString(const RowLayout &layout, const idx_t &col_idx, const data_ptr_t &row_ptr) { -#ifdef DEBUG - if (layout.GetTypes()[col_idx].id() != LogicalTypeId::VARCHAR) { - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - ValidityBytes row_mask(row_ptr, layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - auto str = Load(row_ptr + layout.GetOffsets()[col_idx]); - str.Verify(); - } -#endif -} - -void RowOperations::UnswizzlePointers(const RowLayout &layout, const data_ptr_t base_row_ptr, - const data_ptr_t base_heap_ptr, const idx_t count) { - const idx_t row_width = layout.GetRowWidth(); - data_ptr_t heap_row_ptrs[STANDARD_VECTOR_SIZE]; - idx_t done = 0; - while (done != count) { - const idx_t next = MinValue(count - done, STANDARD_VECTOR_SIZE); - const data_ptr_t row_ptr = base_row_ptr + done * row_width; - // Restore heap row pointers - data_ptr_t heap_ptr_ptr = row_ptr + layout.GetHeapOffset(); - for (idx_t i = 0; i < next; i++) { - heap_row_ptrs[i] = base_heap_ptr + Load(heap_ptr_ptr); - Store(heap_row_ptrs[i], heap_ptr_ptr); - heap_ptr_ptr += row_width; - } - // Loop through the blob columns - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto physical_type = layout.GetTypes()[col_idx].InternalType(); - if (TypeIsConstantSize(physical_type)) { - continue; - } - data_ptr_t col_ptr = row_ptr + layout.GetOffsets()[col_idx]; - if (physical_type == PhysicalType::VARCHAR) { - data_ptr_t string_ptr = col_ptr + string_t::HEADER_SIZE; - for (idx_t i = 0; i < next; i++) { - if (Load(col_ptr) > string_t::INLINE_LENGTH) { - // Overwrite the string offset with the pointer (if not inlined) - Store(heap_row_ptrs[i] + Load(string_ptr), string_ptr); - VerifyUnswizzledString(layout, col_idx, row_ptr + i * row_width); - } - col_ptr += row_width; - string_ptr += row_width; - } - } else { - // Non-varchar blob columns - for (idx_t i = 0; i < next; i++) { - // Overwrite the column data offset with the pointer - Store(heap_row_ptrs[i] + Load(col_ptr), col_ptr); - col_ptr += row_width; - } - } - } - done += next; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_gather.cpp b/src/duckdb/src/common/row_operations/row_gather.cpp deleted file mode 100644 index 40743279b..000000000 --- a/src/duckdb/src/common/row_operations/row_gather.cpp +++ /dev/null @@ -1,181 +0,0 @@ -//===--------------------------------------------------------------------===// -// row_gather.cpp -// Description: This file contains the implementation of the gather operators -//===--------------------------------------------------------------------===// - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/constant_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/row/tuple_data_layout.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - 
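The gather loops that follow all share one access pattern: rows sit back to back in memory, each layout.GetRowWidth() bytes wide, and a column lives at a fixed byte offset inside every row. A minimal standalone sketch of that pattern, using plain pointers and memcpy in place of DuckDB's Vector, SelectionVector, and Load helpers (the names here are illustrative only, not DuckDB API):

#include <cstddef>
#include <cstdint>
#include <cstring>

template <class T>
void GatherColumn(const uint8_t *rows, std::size_t row_count, std::size_t row_width,
                  std::size_t col_offset, T *out) {
	for (std::size_t i = 0; i < row_count; i++) {
		// memcpy rather than a pointer cast: row storage carries no alignment guarantee
		std::memcpy(&out[i], rows + i * row_width + col_offset, sizeof(T));
	}
}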
-using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedGatherLoop(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - idx_t build_size) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - data[col_idx] = Load(row + col_offset); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. - col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } - } -} - -static void GatherVarchar(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - idx_t count, const RowLayout &layout, idx_t col_no, idx_t build_size, - data_ptr_t base_heap_ptr) { - // Precompute mask indexes - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_no, entry_idx, idx_in_entry); - - auto ptrs = FlatVector::GetData(rows); - auto data = FlatVector::GetData(col); - auto &col_mask = FlatVector::Validity(col); - - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - auto col_idx = col_sel.get_index(i); - auto col_ptr = row + col_offset; - data[col_idx] = Load(col_ptr); - ValidityBytes row_mask(row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - if (build_size > STANDARD_VECTOR_SIZE && col_mask.AllValid()) { - //! We need to initialize the mask with the vector size. 
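// The default validity mask only covers STANDARD_VECTOR_SIZE entries; grow it before flagging rows past that boundary.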
- col_mask.Initialize(build_size); - } - col_mask.SetInvalid(col_idx); - } else if (base_heap_ptr && Load(col_ptr) > string_t::INLINE_LENGTH) { - // Not inline, so unswizzle the copied pointer the pointer - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - auto string_ptr = data_ptr_t(data + col_idx) + string_t::HEADER_SIZE; - Store(heap_row_ptr + Load(string_ptr), string_ptr); -#ifdef DEBUG - data[col_idx].Verify(); -#endif - } - } -} - -static void GatherNestedVector(Vector &rows, const SelectionVector &row_sel, Vector &col, - const SelectionVector &col_sel, idx_t count, const RowLayout &layout, idx_t col_no, - data_ptr_t base_heap_ptr) { - const auto &offsets = layout.GetOffsets(); - const auto col_offset = offsets[col_no]; - const auto heap_offset = layout.GetHeapOffset(); - auto ptrs = FlatVector::GetData(rows); - - // Build the gather locations - auto data_locations = make_unsafe_uniq_array_uninitialized(count); - auto mask_locations = make_unsafe_uniq_array_uninitialized(count); - for (idx_t i = 0; i < count; i++) { - auto row_idx = row_sel.get_index(i); - auto row = ptrs[row_idx]; - mask_locations[i] = row; - auto col_ptr = ptrs[row_idx] + col_offset; - if (base_heap_ptr) { - auto heap_ptr_ptr = row + heap_offset; - auto heap_row_ptr = base_heap_ptr + Load(heap_ptr_ptr); - data_locations[i] = heap_row_ptr + Load(col_ptr); - } else { - data_locations[i] = Load(col_ptr); - } - } - - // Deserialise into the selected locations - NestedValidity parent_validity(mask_locations.get(), col_no); - RowOperations::HeapGather(col, count, col_sel, data_locations.get(), &parent_validity); -} - -void RowOperations::Gather(Vector &rows, const SelectionVector &row_sel, Vector &col, const SelectionVector &col_sel, - const idx_t count, const RowLayout &layout, const idx_t col_no, const idx_t build_size, - data_ptr_t heap_ptr) { - D_ASSERT(rows.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(rows.GetType().id() == LogicalTypeId::POINTER); // "Cannot gather from non-pointer type!" 
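// [Editor's note] The switch below dispatches on the physical type:
// fixed-width types use the templated gather loop, VARCHAR additionally
// unswizzles out-of-line string pointers against the heap base, and nested
// types (LIST/STRUCT/ARRAY) are deserialized from the row heap.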
-
-	col.SetVectorType(VectorType::FLAT_VECTOR);
-	switch (col.GetType().InternalType()) {
-	case PhysicalType::UINT8:
-		TemplatedGatherLoop<uint8_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::UINT16:
-		TemplatedGatherLoop<uint16_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::UINT32:
-		TemplatedGatherLoop<uint32_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::UINT64:
-		TemplatedGatherLoop<uint64_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::UINT128:
-		TemplatedGatherLoop<uhugeint_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		TemplatedGatherLoop<int8_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::INT16:
-		TemplatedGatherLoop<int16_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::INT32:
-		TemplatedGatherLoop<int32_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::INT64:
-		TemplatedGatherLoop<int64_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::INT128:
-		TemplatedGatherLoop<hugeint_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::FLOAT:
-		TemplatedGatherLoop<float>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::DOUBLE:
-		TemplatedGatherLoop<double>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::INTERVAL:
-		TemplatedGatherLoop<interval_t>(rows, row_sel, col, col_sel, count, layout, col_no, build_size);
-		break;
-	case PhysicalType::VARCHAR:
-		GatherVarchar(rows, row_sel, col, col_sel, count, layout, col_no, build_size, heap_ptr);
-		break;
-	case PhysicalType::LIST:
-	case PhysicalType::STRUCT:
-	case PhysicalType::ARRAY:
-		GatherNestedVector(rows, row_sel, col, col_sel, count, layout, col_no, heap_ptr);
-		break;
-	default:
-		throw InternalException("Unimplemented type for RowOperations::Gather");
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/row_operations/row_heap_gather.cpp b/src/duckdb/src/common/row_operations/row_heap_gather.cpp
deleted file mode 100644
index fa433c64e..000000000
--- a/src/duckdb/src/common/row_operations/row_heap_gather.cpp
+++ /dev/null
@@ -1,276 +0,0 @@
-#include "duckdb/common/helper.hpp"
-#include "duckdb/common/row_operations/row_operations.hpp"
-#include "duckdb/common/types/vector.hpp"
-
-namespace duckdb {
-
-using ValidityBytes = TemplatedValidityMask<uint8_t>;
-
-template <class T>
-static void TemplatedHeapGather(Vector &v, const idx_t count, const SelectionVector &sel, data_ptr_t *key_locations) {
-	auto target = FlatVector::GetData<T>(v);
-
-	for (idx_t i = 0; i < count; ++i) {
-		const auto col_idx = sel.get_index(i);
-		target[col_idx] = Load<T>(key_locations[i]);
-		key_locations[i] += sizeof(T);
-	}
-}
-
-static void HeapGatherStringVector(Vector &v, const idx_t vcount, const SelectionVector &sel,
-                                   data_ptr_t *key_locations) {
-	const auto &validity = FlatVector::Validity(v);
-	auto target = FlatVector::GetData<string_t>(v);
-
-	for (idx_t i = 0; i < vcount; i++) {
-		const auto col_idx = sel.get_index(i);
-		if (!validity.RowIsValid(col_idx)) {
-			continue;
-		}
-		auto len = Load<uint32_t>(key_locations[i]);
-		key_locations[i] += sizeof(uint32_t);
-		target[col_idx] = StringVector::AddStringOrBlob(v, string_t(const_char_ptr_cast(key_locations[i]), len));
-		key_locations[i] += len;
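// [Editor's note] Heap strings are length-prefixed: a uint32_t byte count
// followed by the raw bytes; both reads above advance the per-row cursor
// in key_locations[i] past the consumed data.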
} -} - -static void HeapGatherStructVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // struct must have a validitymask for its fields - auto &child_types = StructType::GetChildTypes(v.GetType()); - const idx_t struct_validitymask_size = (child_types.size() + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < vcount; i++) { - // use key_locations as the validitymask, and create struct_key_locations - struct_validitymask_locations[i] = key_locations[i]; - key_locations[i] += struct_validitymask_size; - } - - // now deserialize into the struct vectors - auto &children = StructVector::GetEntries(v); - for (idx_t i = 0; i < child_types.size(); i++) { - NestedValidity parent_validity(struct_validitymask_locations, i); - RowOperations::HeapGather(*children[i], vcount, sel, key_locations, &parent_validity); - } -} - -static void HeapGatherListVector(Vector &v, const idx_t vcount, const SelectionVector &sel, data_ptr_t *key_locations) { - const auto &validity = FlatVector::Validity(v); - - auto child_type = ListType::GetChildType(v.GetType()); - auto list_data = ListVector::GetData(v); - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - uint64_t entry_offset = ListVector::GetListSize(v); - for (idx_t i = 0; i < vcount; i++) { - const auto col_idx = sel.get_index(i); - if (!validity.RowIsValid(col_idx)) { - continue; - } - // read list length - auto entry_remaining = Load(key_locations[i]); - key_locations[i] += sizeof(uint64_t); - // set list entry attributes - list_data[col_idx].length = entry_remaining; - list_data[col_idx].offset = entry_offset; - // skip over the validity mask - data_ptr_t validitymask_location = key_locations[i]; - idx_t offset_in_byte = 0; - key_locations[i] += (entry_remaining + 7) / 8; - // entry sizes - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type.InternalType())) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += entry_remaining * sizeof(idx_t); - } - - // now read the list data - while (entry_remaining > 0) { - auto next = MinValue(entry_remaining, (idx_t)STANDARD_VECTOR_SIZE); - - // initialize a new vector to append - Vector append_vector(v.GetType()); - append_vector.SetVectorType(v.GetVectorType()); - - auto &list_vec_to_append = ListVector::GetEntry(append_vector); - - // set validity - //! Since we are constructing the vector, this will always be a flat vector. 
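// [Editor's note] Serialized list layout per row, as read above:
// [uint64_t length][validity bitmask, (length + 7) / 8 bytes]
// [idx_t entry sizes, only for variable-size child types][entries...].
// Long lists are materialized in STANDARD_VECTOR_SIZE chunks and appended
// to the result list vector chunk by chunk.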
- auto &append_validity = FlatVector::Validity(list_vec_to_append); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - append_validity.Set(entry_idx, *(validitymask_location) & (1 << offset_in_byte)); - if (++offset_in_byte == 8) { - validitymask_location++; - offset_in_byte = 0; - } - } - - // compute entry sizes and set locations where the list entries are - if (TypeIsConstantSize(child_type.InternalType())) { - // constant size list entries - const idx_t type_size = GetTypeIdSize(child_type.InternalType()); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now deserialize and add to listvector - RowOperations::HeapGather(list_vec_to_append, next, *FlatVector::IncrementalSelectionVector(), - list_entry_locations, nullptr); - ListVector::Append(v, list_vec_to_append, next); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapGatherArrayVector(Vector &v, const idx_t vcount, const SelectionVector &sel, - data_ptr_t *key_locations) { - // Setup - auto &child_type = ArrayType::GetChildType(v.GetType()); - auto array_size = ArrayType::GetSize(v.GetType()); - auto &child_vector = ArrayVector::GetEntry(v); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < vcount; i++) { - // Setup validity mask - data_ptr_t array_validitymask_location = key_locations[i]; - key_locations[i] += array_validitymask_size; - - NestedValidity parent_validity(array_validitymask_location); - - // The size of each variable size entry is stored after the validity mask - // (if the child type is variable size) - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // row idx - const auto row_idx = sel.get_index(i); - - idx_t array_start = row_idx * array_size; - idx_t elem_remaining = array_size; - - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - SelectionVector array_sel(STANDARD_VECTOR_SIZE); - - if (child_type_is_var_size) { - // variable size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += Load(var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } else { - // constant size list entries - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - array_sel.set_index(elem_idx, array_start + elem_idx); - } - } - - // Pass on this array's validity mask to the child vector - RowOperations::HeapGather(child_vector, chunk_size, array_sel, array_entry_locations, &parent_validity); - - elem_remaining -= chunk_size; - array_start += chunk_size; - 
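// [Editor's note] All chunks of one array share a single serialized
// validity mask; OffsetListBy advances the bit cursor so the next chunk
// tests the correct bits instead of re-reading the first ones.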
parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapGather(Vector &v, const idx_t &vcount, const SelectionVector &sel, data_ptr_t *key_locations, - optional_ptr parent_validity) { - v.SetVectorType(VectorType::FLAT_VECTOR); - - auto &validity = FlatVector::Validity(v); - if (parent_validity) { - for (idx_t i = 0; i < vcount; i++) { - const auto valid = parent_validity->IsValid(i); - const auto col_idx = sel.get_index(i); - validity.Set(col_idx, valid); - } - } - - auto type = v.GetType().InternalType(); - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT8: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT16: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT32: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT64: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::UINT128: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::FLOAT: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::DOUBLE: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::INTERVAL: - TemplatedHeapGather(v, vcount, sel, key_locations); - break; - case PhysicalType::VARCHAR: - HeapGatherStringVector(v, vcount, sel, key_locations); - break; - case PhysicalType::STRUCT: - HeapGatherStructVector(v, vcount, sel, key_locations); - break; - case PhysicalType::LIST: - HeapGatherListVector(v, vcount, sel, key_locations); - break; - case PhysicalType::ARRAY: - HeapGatherArrayVector(v, vcount, sel, key_locations); - break; - default: - throw NotImplementedException("Unimplemented deserialize from row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp b/src/duckdb/src/common/row_operations/row_heap_scatter.cpp deleted file mode 100644 index 01cf7b589..000000000 --- a/src/duckdb/src/common/row_operations/row_heap_scatter.cpp +++ /dev/null @@ -1,581 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TemplatedValidityMask; - -NestedValidity::NestedValidity(data_ptr_t validitymask_location) - : list_validity_location(validitymask_location), struct_validity_locations(nullptr), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { -} - -NestedValidity::NestedValidity(data_ptr_t *validitymask_locations, idx_t child_vector_index) - : list_validity_location(nullptr), struct_validity_locations(validitymask_locations), entry_idx(0), idx_in_entry(0), - list_validity_offset(0) { - ValidityBytes::GetEntryIndex(child_vector_index, entry_idx, idx_in_entry); -} - -void NestedValidity::SetInvalid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - 
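// [Editor's note] GetEntryIndex splits a flat bit index into a byte index
// (idx / 8) and a bit-in-byte index (idx % 8); SetInvalid then clears that
// bit with mask[byte] &= ~(1 << bit), and IsValid tests the same bit.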
ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = ~(1UL << list_idx_in_entry); - list_validity_location[list_entry_idx] &= bit; - } else { - // Is Struct - const auto bit = ~(1UL << idx_in_entry); - *(struct_validity_locations[idx] + entry_idx) &= bit; - } -} - -void NestedValidity::OffsetListBy(idx_t offset) { - list_validity_offset += offset; -} - -bool NestedValidity::IsValid(idx_t idx) { - if (list_validity_location) { - // Is List - - idx = idx + list_validity_offset; - - idx_t list_entry_idx; - idx_t list_idx_in_entry; - ValidityBytes::GetEntryIndex(idx, list_entry_idx, list_idx_in_entry); - const auto bit = (1UL << list_idx_in_entry); - return list_validity_location[list_entry_idx] & bit; - } else { - // Is Struct - const auto bit = (1UL << idx_in_entry); - return *(struct_validity_locations[idx] + entry_idx) & bit; - } -} - -static void ComputeStringEntrySizes(UnifiedVectorFormat &vdata, idx_t entry_sizes[], const idx_t ser_count, - const SelectionVector &sel, const idx_t offset) { - auto strings = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto str_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(str_idx)) { - entry_sizes[i] += sizeof(uint32_t) + strings[str_idx].GetSize(); - } - } -} - -static void ComputeStructEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - // obtain child vectors - idx_t num_children; - auto &children = StructVector::GetEntries(v); - num_children = children.size(); - // add struct validitymask size - const idx_t struct_validitymask_size = (num_children + 7) / 8; - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += struct_validitymask_size; - } - // compute size of child vectors - for (auto &struct_vector : children) { - RowOperations::ComputeEntrySizes(*struct_vector, entry_sizes, vcount, ser_count, sel, offset); - } -} - -static void ComputeListEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto list_entry = list_data[source_idx]; - - // make room for list length, list validitymask - entry_sizes[i] += sizeof(list_entry.length); - entry_sizes[i] += (list_entry.length + 7) / 8; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ListType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += list_entry.length * sizeof(list_entry.length); - } - - // compute size of each the elements in list_entry and sum them - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // compute and add to the total - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t list_idx = 0; list_idx < next; list_idx++) { - entry_sizes[i] += list_entry_sizes[list_idx]; - } - - // update for next iteration - entry_remaining -= 
next; - entry_offset += next; - } - } - } -} - -static void ComputeArrayEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_vector = ArrayVector::GetEntry(v); - - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - const idx_t array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - - // Validity for the array elements - entry_sizes[i] += array_validitymask_size; - - // serialize size of each entry (if non-constant size) - if (!TypeIsConstantSize(ArrayType::GetChildType(v.GetType()).InternalType())) { - entry_sizes[i] += array_size * sizeof(idx_t); - } - - auto elem_idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(elem_idx + offset); - - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - // the array could span multiple vectors, so we divide it into chunks - while (elem_remaining > 0) { - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // compute and add to the total - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t arr_elem_idx = 0; arr_elem_idx < chunk_size; arr_elem_idx++) { - entry_sizes[i] += array_entry_sizes[arr_elem_idx]; - } - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, UnifiedVectorFormat &vdata, idx_t entry_sizes[], idx_t vcount, - idx_t ser_count, const SelectionVector &sel, idx_t offset) { - const auto physical_type = v.GetType().InternalType(); - if (TypeIsConstantSize(physical_type)) { - const auto type_size = GetTypeIdSize(physical_type); - for (idx_t i = 0; i < ser_count; i++) { - entry_sizes[i] += type_size; - } - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::STRUCT: - ComputeStructEntrySizes(v, entry_sizes, vcount, ser_count, sel, offset); - break; - case PhysicalType::LIST: - ComputeListEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - case PhysicalType::ARRAY: - ComputeArrayEntrySizes(v, vdata, entry_sizes, ser_count, sel, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Column with variable size type %s cannot be serialized to row-format", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::ComputeEntrySizes(Vector &v, idx_t entry_sizes[], idx_t vcount, idx_t ser_count, - const SelectionVector &sel, idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - ComputeEntrySizes(v, vdata, entry_sizes, vcount, ser_count, sel, offset); -} - -template -static void TemplatedHeapScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = 
sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - - auto target = (T *)key_locations[i]; - Store(source[source_idx], data_ptr_cast(target)); - key_locations[i] += sizeof(T); - - // set the validitymask - if (!vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStringVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto strings = UnifiedVectorFormat::GetData(vdata); - if (!parent_validity) { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } - } - } else { - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (vdata.validity.RowIsValid(source_idx)) { - auto &string_entry = strings[source_idx]; - // store string size - Store(NumericCast(string_entry.GetSize()), key_locations[i]); - key_locations[i] += sizeof(uint32_t); - // store the string - memcpy(key_locations[i], string_entry.GetData(), string_entry.GetSize()); - key_locations[i] += string_entry.GetSize(); - } else { - // set the validitymask - parent_validity->SetInvalid(i); - } - } - } -} - -static void HeapScatterStructVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto &children = StructVector::GetEntries(v); - idx_t num_children = children.size(); - - // struct must have a validitymask for its fields - const idx_t struct_validitymask_size = (num_children + 7) / 8; - data_ptr_t struct_validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < ser_count; i++) { - // initialize the struct validity mask - struct_validitymask_locations[i] = key_locations[i]; - memset(struct_validitymask_locations[i], -1, struct_validitymask_size); - key_locations[i] += struct_validitymask_size; - - // set whether the whole struct is null - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - } - - // now serialize the struct vectors - for (idx_t i = 0; i < children.size(); i++) { - auto &struct_vector = *children[i]; - NestedValidity struct_validity(struct_validitymask_locations, i); - RowOperations::HeapScatter(struct_vector, vcount, sel, ser_count, key_locations, &struct_validity, offset); - } -} - -static void HeapScatterListVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - - UnifiedVectorFormat list_vdata; - child_vector.ToUnifiedFormat(ListVector::GetListSize(v), list_vdata); - auto child_type = 
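// [Editor's note] The scatter below writes exactly the layout that
// HeapGatherListVector reads back: the uint64_t length, a validity bitmask
// pre-filled with 0xFF via memset(-1), per-entry sizes for variable-size
// child types, and finally the entries themselves.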
ListType::GetChildType(v.GetType()).InternalType(); - - idx_t list_entry_sizes[STANDARD_VECTOR_SIZE]; - data_ptr_t list_entry_locations[STANDARD_VECTOR_SIZE]; - - for (idx_t i = 0; i < ser_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx + offset); - if (!vdata.validity.RowIsValid(source_idx)) { - if (parent_validity) { - // set the row validitymask for this column to invalid - parent_validity->SetInvalid(i); - } - continue; - } - auto list_entry = list_data[source_idx]; - - // store list length - Store(list_entry.length, key_locations[i]); - key_locations[i] += sizeof(list_entry.length); - - // make room for the validitymask - data_ptr_t list_validitymask_location = key_locations[i]; - idx_t entry_offset_in_byte = 0; - idx_t validitymask_size = (list_entry.length + 7) / 8; - memset(list_validitymask_location, -1, validitymask_size); - key_locations[i] += validitymask_size; - - // serialize size of each entry (if non-constant size) - data_ptr_t var_entry_size_ptr = nullptr; - if (!TypeIsConstantSize(child_type)) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += list_entry.length * sizeof(idx_t); - } - - auto entry_remaining = list_entry.length; - auto entry_offset = list_entry.offset; - while (entry_remaining > 0) { - // the list entry can span multiple vectors - auto next = MinValue((idx_t)STANDARD_VECTOR_SIZE, entry_remaining); - - // serialize list validity - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - auto list_idx = list_vdata.sel->get_index(entry_idx + entry_offset); - if (!list_vdata.validity.RowIsValid(list_idx)) { - *(list_validitymask_location) &= ~(1UL << entry_offset_in_byte); - } - if (++entry_offset_in_byte == 8) { - list_validitymask_location++; - entry_offset_in_byte = 0; - } - } - - if (TypeIsConstantSize(child_type)) { - // constant size list entries: set list entry locations - const idx_t type_size = GetTypeIdSize(child_type); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += type_size; - } - } else { - // variable size list entries: compute entry sizes and set list entry locations - std::fill_n(list_entry_sizes, next, 0); - RowOperations::ComputeEntrySizes(child_vector, list_entry_sizes, next, next, - *FlatVector::IncrementalSelectionVector(), entry_offset); - for (idx_t entry_idx = 0; entry_idx < next; entry_idx++) { - list_entry_locations[entry_idx] = key_locations[i]; - key_locations[i] += list_entry_sizes[entry_idx]; - Store(list_entry_sizes[entry_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } - - // now serialize to the locations - RowOperations::HeapScatter(child_vector, ListVector::GetListSize(v), - *FlatVector::IncrementalSelectionVector(), next, list_entry_locations, nullptr, - entry_offset); - - // update for next iteration - entry_remaining -= next; - entry_offset += next; - } - } -} - -static void HeapScatterArrayVector(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, - idx_t offset) { - - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - auto child_type = ArrayType::GetChildType(v.GetType()); - auto child_type_size = GetTypeIdSize(child_type.InternalType()); - auto child_type_is_var_size = !TypeIsConstantSize(child_type.InternalType()); - - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - - UnifiedVectorFormat child_vdata; - 
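// [Editor's note] Arrays are serialized like lists, except the length is
// fixed by the type (array_size), so no per-row length prefix is written:
// the layout is [validity mask][entry sizes, if the child is variable-size]
// [array_size elements].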
child_vector.ToUnifiedFormat(ArrayVector::GetTotalSize(v), child_vdata); - - data_ptr_t array_entry_locations[STANDARD_VECTOR_SIZE]; - idx_t array_entry_sizes[STANDARD_VECTOR_SIZE]; - - // array must have a validitymask for its elements - auto array_validitymask_size = (array_size + 7) / 8; - - for (idx_t i = 0; i < ser_count; i++) { - // Set if the whole array itself is null in the parent entry - auto source_idx = vdata.sel->get_index(sel.get_index(i) + offset); - if (parent_validity && !vdata.validity.RowIsValid(source_idx)) { - parent_validity->SetInvalid(i); - } - - // Now we can serialize the array itself - // Every array starts with a validity mask for the children - data_ptr_t array_validitymask_location = key_locations[i]; - memset(array_validitymask_location, -1, array_validitymask_size); - key_locations[i] += array_validitymask_size; - - NestedValidity array_parent_validity(array_validitymask_location); - - // If the array contains variable size entries, we reserve spaces for them here - data_ptr_t var_entry_size_ptr = nullptr; - if (child_type_is_var_size) { - var_entry_size_ptr = key_locations[i]; - key_locations[i] += array_size * sizeof(idx_t); - } - - // Then comes the elements - auto array_start = source_idx * array_size; - auto elem_remaining = array_size; - - while (elem_remaining > 0) { - // the array elements can span multiple vectors, so we divide it into chunks - auto chunk_size = MinValue(static_cast(STANDARD_VECTOR_SIZE), elem_remaining); - - // Setup the locations for the elements - if (child_type_is_var_size) { - // The elements are variable sized - std::fill_n(array_entry_sizes, chunk_size, 0); - RowOperations::ComputeEntrySizes(child_vector, array_entry_sizes, chunk_size, chunk_size, - *FlatVector::IncrementalSelectionVector(), array_start); - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += array_entry_sizes[elem_idx]; - - // Now store the size of the entry - Store(array_entry_sizes[elem_idx], var_entry_size_ptr); - var_entry_size_ptr += sizeof(idx_t); - } - } else { - // The elements are constant sized - for (idx_t elem_idx = 0; elem_idx < chunk_size; elem_idx++) { - array_entry_locations[elem_idx] = key_locations[i]; - key_locations[i] += child_type_size; - } - } - - RowOperations::HeapScatter(child_vector, ArrayVector::GetTotalSize(v), - *FlatVector::IncrementalSelectionVector(), chunk_size, array_entry_locations, - &array_parent_validity, array_start); - - // update for next iteration - elem_remaining -= chunk_size; - array_start += chunk_size; - array_parent_validity.OffsetListBy(chunk_size); - } - } -} - -void RowOperations::HeapScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count, - data_ptr_t *key_locations, optional_ptr parent_validity, idx_t offset) { - if (TypeIsConstantSize(v.GetType().InternalType())) { - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(vcount, vdata); - RowOperations::HeapScatterVData(vdata, v.GetType().InternalType(), sel, ser_count, key_locations, - parent_validity, offset); - } else { - switch (v.GetType().InternalType()) { - case PhysicalType::VARCHAR: - HeapScatterStringVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::STRUCT: - HeapScatterStructVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::LIST: - HeapScatterListVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - case 
PhysicalType::ARRAY: - HeapScatterArrayVector(v, vcount, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - // LCOV_EXCL_START - throw NotImplementedException("Serialization of variable length vector with type %s", - v.GetType().ToString()); - // LCOV_EXCL_STOP - } - } -} - -void RowOperations::HeapScatterVData(UnifiedVectorFormat &vdata, PhysicalType type, const SelectionVector &sel, - idx_t ser_count, data_ptr_t *key_locations, - optional_ptr parent_validity, idx_t offset) { - switch (type) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT8: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT16: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT32: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT64: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::UINT128: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::FLOAT: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::DOUBLE: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - case PhysicalType::INTERVAL: - TemplatedHeapScatter(vdata, sel, ser_count, key_locations, parent_validity, offset); - break; - default: - throw NotImplementedException("FIXME: Serialize to of constant type column to row-format"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_matcher.cpp b/src/duckdb/src/common/row_operations/row_matcher.cpp deleted file mode 100644 index d08ab9d4e..000000000 --- a/src/duckdb/src/common/row_operations/row_matcher.cpp +++ /dev/null @@ -1,434 +0,0 @@ -#include "duckdb/common/row_operations/row_matcher.hpp" - -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/row/tuple_data_collection.hpp" - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -template -static idx_t TemplatedMatchLoop(const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, - const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, - SelectionVector *no_match_sel, idx_t &no_match_count) { - using COMPARISON_OP = ComparisonOperationWrapper; - - // LHS - const auto &lhs_sel = *lhs_format.unified.sel; - const auto lhs_data = UnifiedVectorFormat::GetData(lhs_format.unified); - const auto &lhs_validity = lhs_format.unified.validity; - - // RHS - const auto rhs_locations = FlatVector::GetData(rhs_row_locations); - const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; - idx_t entry_idx; - idx_t 
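// [Editor's note] The loop below compacts `sel` in place: indexes whose LHS
// value matches the row-stored RHS value are kept (match_count grows), and,
// when NO_MATCH_SEL is set, failing indexes are appended to no_match_sel
// instead of being dropped, so callers can follow up on both outcomes.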
idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - idx_t match_count = 0; - for (idx_t i = 0; i < count; i++) { - const auto idx = sel.get_index(i); - - const auto lhs_idx = lhs_sel.get_index(idx); - const auto lhs_null = LHS_ALL_VALID ? false : !lhs_validity.RowIsValid(lhs_idx); - - const auto &rhs_location = rhs_locations[idx]; - const ValidityBytes rhs_mask(rhs_location, rhs_layout.ColumnCount()); - const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); - - if (COMPARISON_OP::template Operation(lhs_data[lhs_idx], Load(rhs_location + rhs_offset_in_row), lhs_null, - rhs_null)) { - sel.set_index(match_count++, idx); - } else if (NO_MATCH_SEL) { - no_match_sel->set_index(no_match_count++, idx); - } - } - return match_count; -} - -template -static idx_t TemplatedMatch(Vector &, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, const idx_t count, - const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, const idx_t col_idx, - const vector &, SelectionVector *no_match_sel, idx_t &no_match_count) { - if (lhs_format.unified.validity.AllValid()) { - return TemplatedMatchLoop(lhs_format, sel, count, rhs_layout, rhs_row_locations, - col_idx, no_match_sel, no_match_count); - } else { - return TemplatedMatchLoop(lhs_format, sel, count, rhs_layout, rhs_row_locations, - col_idx, no_match_sel, no_match_count); - } -} - -template -static idx_t StructMatchEquality(Vector &lhs_vector, const TupleDataVectorFormat &lhs_format, SelectionVector &sel, - const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - const idx_t col_idx, const vector &child_functions, - SelectionVector *no_match_sel, idx_t &no_match_count) { - using COMPARISON_OP = ComparisonOperationWrapper; - - // LHS - const auto &lhs_sel = *lhs_format.unified.sel; - const auto &lhs_validity = lhs_format.unified.validity; - - // RHS - const auto rhs_locations = FlatVector::GetData(rhs_row_locations); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - idx_t match_count = 0; - for (idx_t i = 0; i < count; i++) { - const auto idx = sel.get_index(i); - - const auto lhs_idx = lhs_sel.get_index(idx); - const auto lhs_null = lhs_validity.AllValid() ? 
false : !lhs_validity.RowIsValid(lhs_idx); - - const auto &rhs_location = rhs_locations[idx]; - const ValidityBytes rhs_mask(rhs_location, rhs_layout.ColumnCount()); - const auto rhs_null = !rhs_mask.RowIsValid(rhs_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry); - - // For structs there is no value to compare, here we match NULLs and let recursion do the rest - // So we use the comparison only if rhs or LHS is NULL and COMPARE_NULL is true - if (!(lhs_null || rhs_null) || - (COMPARISON_OP::COMPARE_NULL && COMPARISON_OP::template Operation(0, 0, lhs_null, rhs_null))) { - sel.set_index(match_count++, idx); - } else if (NO_MATCH_SEL) { - no_match_sel->set_index(no_match_count++, idx); - } - } - - // Create a Vector of pointers to the start of the TupleDataLayout of the STRUCT - Vector rhs_struct_row_locations(LogicalType::POINTER); - const auto rhs_offset_in_row = rhs_layout.GetOffsets()[col_idx]; - auto rhs_struct_locations = FlatVector::GetData(rhs_struct_row_locations); - for (idx_t i = 0; i < match_count; i++) { - const auto idx = sel.get_index(i); - rhs_struct_locations[idx] = rhs_locations[idx] + rhs_offset_in_row; - } - - // Get the struct layout and struct entries - const auto &rhs_struct_layout = rhs_layout.GetStructLayout(col_idx); - auto &lhs_struct_vectors = StructVector::GetEntries(lhs_vector); - D_ASSERT(rhs_struct_layout.ColumnCount() == lhs_struct_vectors.size()); - - for (idx_t struct_col_idx = 0; struct_col_idx < rhs_struct_layout.ColumnCount(); struct_col_idx++) { - auto &lhs_struct_vector = *lhs_struct_vectors[struct_col_idx]; - auto &lhs_struct_format = lhs_format.children[struct_col_idx]; - const auto &child_function = child_functions[struct_col_idx]; - match_count = child_function.function(lhs_struct_vector, lhs_struct_format, sel, match_count, rhs_struct_layout, - rhs_struct_row_locations, struct_col_idx, child_function.child_functions, - no_match_sel, no_match_count); - } - - return match_count; -} - -template -static idx_t SelectComparison(Vector &, Vector &, const SelectionVector &, idx_t, SelectionVector *, - SelectionVector *) { - throw NotImplementedException("Unsupported list comparison operand for RowMatcher::GetMatchFunction"); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedEquals(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NestedNotEquals(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctFrom(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::NotDistinctFrom(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t 
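// [Editor's note] These explicit specializations map each comparison
// operator to its VectorOperations nested/distinct counterpart; the
// unspecialized template throws, so nested-type matching is only available
// for operators with defined nested comparison semantics.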
SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctLessThan(left, right, &sel, count, true_sel, false_sel); -} - -template <> -idx_t SelectComparison(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - return VectorOperations::DistinctLessThanEquals(left, right, &sel, count, true_sel, false_sel); -} - -template -static idx_t GenericNestedMatch(Vector &lhs_vector, const TupleDataVectorFormat &, SelectionVector &sel, - const idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - const idx_t col_idx, const vector &, SelectionVector *no_match_sel, - idx_t &no_match_count) { - const auto &type = rhs_layout.GetTypes()[col_idx]; - - // Gather a dense Vector containing the column values being matched - Vector key(type); - const auto gather_function = TupleDataCollection::GetGatherFunction(type); - gather_function.function(rhs_layout, rhs_row_locations, col_idx, sel, count, key, - *FlatVector::IncrementalSelectionVector(), nullptr, gather_function.child_functions); - Vector::Verify(key, *FlatVector::IncrementalSelectionVector(), count); - - // Densify the input column - Vector sliced(lhs_vector, sel, count); - - if (NO_MATCH_SEL) { - SelectionVector no_match_sel_offset(no_match_sel->data() + no_match_count); - auto match_count = SelectComparison(sliced, key, sel, count, &sel, &no_match_sel_offset); - no_match_count += count - match_count; - return match_count; - } - return SelectComparison(sliced, key, sel, count, &sel, nullptr); -} - -void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates) { - match_functions.reserve(predicates.size()); - for (idx_t col_idx = 0; col_idx < predicates.size(); col_idx++) { - match_functions.push_back(GetMatchFunction(no_match_sel, layout.GetTypes()[col_idx], predicates[col_idx])); - } -} - -void RowMatcher::Initialize(const bool no_match_sel, const TupleDataLayout &layout, const Predicates &predicates, - vector &columns) { - - // The columns must have the same size as the predicates vector - D_ASSERT(columns.size() == predicates.size()); - - // The largest column_id must be smaller than the number of types to not cause an out-of-bounds error - D_ASSERT(*max_element(columns.begin(), columns.end()) < layout.GetTypes().size()); - - match_functions.reserve(predicates.size()); - for (idx_t idx = 0; idx < predicates.size(); idx++) { - column_t col_idx = columns[idx]; - match_functions.push_back(GetMatchFunction(no_match_sel, layout.GetTypes()[col_idx], predicates[idx])); - } -} - -idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, - idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - SelectionVector *no_match_sel, idx_t &no_match_count) { - D_ASSERT(!match_functions.empty()); - for (idx_t col_idx = 0; col_idx < match_functions.size(); col_idx++) { - const auto &match_function = match_functions[col_idx]; - count = - match_function.function(lhs.data[col_idx], lhs_formats[col_idx], sel, count, rhs_layout, rhs_row_locations, - col_idx, 
match_function.child_functions, no_match_sel, no_match_count); - } - return count; -} - -idx_t RowMatcher::Match(DataChunk &lhs, const vector &lhs_formats, SelectionVector &sel, - idx_t count, const TupleDataLayout &rhs_layout, Vector &rhs_row_locations, - SelectionVector *no_match_sel, idx_t &no_match_count, const vector &columns) { - D_ASSERT(!match_functions.empty()); - - // The column_ids must have the same size as the match_functions vector - D_ASSERT(columns.size() == match_functions.size()); - - // The largest column_id must be smaller than the number columns to not cause an out-of-bounds error - D_ASSERT(*max_element(columns.begin(), columns.end()) < lhs.ColumnCount()); - - for (idx_t fun_idx = 0; fun_idx < match_functions.size(); fun_idx++) { - // if we only care about specific columns, we need to use the column_ids to get the correct column index - // otherwise, we just use the fun_idx - const auto col_idx = columns[fun_idx]; - - const auto &match_function = match_functions[fun_idx]; - count = - match_function.function(lhs.data[col_idx], lhs_formats[col_idx], sel, count, rhs_layout, rhs_row_locations, - col_idx, match_function.child_functions, no_match_sel, no_match_count); - } - return count; -} - -MatchFunction RowMatcher::GetMatchFunction(const bool no_match_sel, const LogicalType &type, - const ExpressionType predicate) { - return no_match_sel ? GetMatchFunction(type, predicate) : GetMatchFunction(type, predicate); -} - -template -MatchFunction RowMatcher::GetMatchFunction(const LogicalType &type, const ExpressionType predicate) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - return GetMatchFunction(predicate); - case PhysicalType::INT8: - return GetMatchFunction(predicate); - case PhysicalType::INT16: - return GetMatchFunction(predicate); - case PhysicalType::INT32: - return GetMatchFunction(predicate); - case PhysicalType::INT64: - return GetMatchFunction(predicate); - case PhysicalType::INT128: - return GetMatchFunction(predicate); - case PhysicalType::UINT8: - return GetMatchFunction(predicate); - case PhysicalType::UINT16: - return GetMatchFunction(predicate); - case PhysicalType::UINT32: - return GetMatchFunction(predicate); - case PhysicalType::UINT64: - return GetMatchFunction(predicate); - case PhysicalType::UINT128: - return GetMatchFunction(predicate); - case PhysicalType::FLOAT: - return GetMatchFunction(predicate); - case PhysicalType::DOUBLE: - return GetMatchFunction(predicate); - case PhysicalType::INTERVAL: - return GetMatchFunction(predicate); - case PhysicalType::VARCHAR: - return GetMatchFunction(predicate); - case PhysicalType::STRUCT: - return GetStructMatchFunction(type, predicate); - case PhysicalType::LIST: - return GetListMatchFunction(predicate); - case PhysicalType::ARRAY: - // Same logic as for lists - return GetListMatchFunction(predicate); - default: - throw InternalException("Unsupported PhysicalType for RowMatcher::GetMatchFunction: %s", - EnumUtil::ToString(type.InternalType())); - } -} - -template -MatchFunction RowMatcher::GetMatchFunction(const ExpressionType predicate) { - MatchFunction result; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_GREATERTHAN: - 
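// [Editor's note] Each case wires one ExpressionType to the templated match
// loop instantiated with the corresponding comparison operator; the DISTINCT
// variants differ from plain (in)equality only in treating NULL as a
// comparable value rather than as an automatic non-match.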
result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_LESSTHAN: - result.function = TemplatedMatch; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = TemplatedMatch; - break; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - return result; -} - -template -MatchFunction RowMatcher::GetStructMatchFunction(const LogicalType &type, const ExpressionType predicate) { - // We perform equality conditions like it's just a row, but we cannot perform inequality conditions like a row, - // because for equality conditions we need to always loop through all columns, but for inequality conditions, - // we need to find the first inequality, so the loop looks very different - MatchFunction result; - ExpressionType child_predicate = predicate; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = StructMatchEquality; - child_predicate = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = StructMatchEquality; - break; - case ExpressionType::COMPARE_GREATERTHAN: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_LESSTHAN: - result.function = GenericNestedMatch; - return result; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = GenericNestedMatch; - return result; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetStructMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - - result.child_functions.reserve(StructType::GetChildCount(type)); - for (const auto &child_type : StructType::GetChildTypes(type)) { - result.child_functions.push_back(GetMatchFunction(child_type.second, child_predicate)); - } - - return result; -} - -template -MatchFunction RowMatcher::GetListMatchFunction(const ExpressionType predicate) { - MatchFunction result; - switch (predicate) { - case ExpressionType::COMPARE_EQUAL: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_NOTEQUAL: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_GREATERTHAN: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_LESSTHAN: - result.function = GenericNestedMatch; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - result.function = GenericNestedMatch; - break; - default: - throw InternalException("Unsupported ExpressionType for RowMatcher::GetListMatchFunction: %s", - EnumUtil::ToString(predicate)); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp b/src/duckdb/src/common/row_operations/row_radix_scatter.cpp deleted file mode 100644 index a85a71997..000000000 --- 
a/src/duckdb/src/common/row_operations/row_radix_scatter.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/common/helper.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -template -void TemplatedRadixScatter(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeData(key_locations[i] + 1, source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < sizeof(T) + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', sizeof(T)); - } - key_locations[i] += sizeof(T) + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeData(key_locations[i], source[source_idx]); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < sizeof(T); s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += sizeof(T); - } - } -} - -void RadixScatterStringVector(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, idx_t offset) { - auto source = UnifiedVectorFormat::GetData(vdata); - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - key_locations[i][0] = valid; - Radix::EncodeStringDataPrefix(key_locations[i] + 1, source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 1; s < prefix_len + 1; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - } else { - key_locations[i][0] = invalid; - memset(key_locations[i] + 1, '\0', prefix_len); - } - key_locations[i] += prefix_len + 1; - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - // write value - Radix::EncodeStringDataPrefix(key_locations[i], source[source_idx], prefix_len); - // invert bits if desc - if (desc) { - for (idx_t s = 0; s < prefix_len; s++) { - *(key_locations[i] + s) = ~*(key_locations[i] + s); - } - } - key_locations[i] += prefix_len; - } - } -} - -void RadixScatterListVector(Vector &v, UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t add_count, - data_ptr_t *key_locations, const bool desc, const bool has_null, const bool nulls_first, - const idx_t prefix_len, const idx_t width, const idx_t offset) { - auto list_data = ListVector::GetData(v); - auto &child_vector = ListVector::GetEntry(v); - auto list_size = ListVector::GetListSize(v); - child_vector.Flatten(list_size); - - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - // write validity and according value - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - auto &list_entry = list_data[source_idx]; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 2, - list_entry.offset); - } else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 2); - key_location += width - 2; - } - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - auto &list_entry = list_data[source_idx]; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - if (list_entry.length > 0) { - // denote that the list is not empty with a 1 - *key_location++ = 1; - RowOperations::RadixScatter(child_vector, list_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, - list_entry.offset); - } 
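// [Editor's note] List keys encode, after the optional validity byte, one
// byte marking empty (0) vs non-empty (1), then the radix encoding of the
// first child element in the remaining width; empty lists zero-pad instead,
// so under memcmp an empty list sorts before any non-empty one.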
else { - // denote that the list is empty with a 0 - *key_location++ = 0; - // mark rest of bits as empty - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterArrayVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - auto &child_vector = ArrayVector::GetEntry(v); - auto array_size = ArrayType::GetSize(v.GetType()); - - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 1 : 0; - const data_t invalid = 1 - valid; - - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - if (validity.RowIsValid(source_idx)) { - *key_location++ = valid; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width - 1, array_offset); - - // invert bits if desc - if (desc) { - // skip over validity byte, handled by nulls first/last - for (key_location = key_location_start + 1; key_location < key_location_start + width; - key_location++) { - *key_location = ~*key_location; - } - } - } else { - *key_location++ = invalid; - memset(key_location, '\0', width - 1); - key_location += width - 1; - } - D_ASSERT(key_location == key_location_start + width); - } - } else { - for (idx_t i = 0; i < add_count; i++) { - auto idx = sel.get_index(i); - auto source_idx = vdata.sel->get_index(idx) + offset; - data_ptr_t &key_location = key_locations[i]; - const data_ptr_t key_location_start = key_location; - - auto array_offset = source_idx * array_size; - RowOperations::RadixScatter(child_vector, array_size, *FlatVector::IncrementalSelectionVector(), 1, - key_locations + i, false, true, false, prefix_len, width, array_offset); - // invert bits if desc - if (desc) { - for (key_location = key_location_start; key_location < key_location_start + width; key_location++) { - *key_location = ~*key_location; - } - } - D_ASSERT(key_location == key_location_start + width); - } - } -} - -void RadixScatterStructVector(Vector &v, UnifiedVectorFormat &vdata, idx_t vcount, const SelectionVector &sel, - idx_t add_count, data_ptr_t *key_locations, const bool desc, const bool has_null, - const bool nulls_first, const idx_t prefix_len, idx_t width, const idx_t offset) { - // serialize null values - if (has_null) { - auto &validity = vdata.validity; - const data_t valid = nulls_first ? 
1 : 0;
-		const data_t invalid = 1 - valid;
-
-		for (idx_t i = 0; i < add_count; i++) {
-			auto idx = sel.get_index(i);
-			auto source_idx = vdata.sel->get_index(idx) + offset;
-
-			// write validity and according value
-			if (validity.RowIsValid(source_idx)) {
-				key_locations[i][0] = valid;
-			} else {
-				key_locations[i][0] = invalid;
-				memset(key_locations[i] + 1, '\0', width - 1);
-			}
-			key_locations[i]++;
-		}
-		width--;
-	}
-	// serialize the struct
-	auto &child_vector = *StructVector::GetEntries(v)[0];
-	RowOperations::RadixScatter(child_vector, vcount, *FlatVector::IncrementalSelectionVector(), add_count,
-	                            key_locations, false, true, false, prefix_len, width, offset);
-	// invert bits if desc
-	if (desc) {
-		for (idx_t i = 0; i < add_count; i++) {
-			for (idx_t s = 0; s < width; s++) {
-				*(key_locations[i] - width + s) = ~*(key_locations[i] - width + s);
-			}
-		}
-	}
-}
-
-void RowOperations::RadixScatter(Vector &v, idx_t vcount, const SelectionVector &sel, idx_t ser_count,
-                                 data_ptr_t *key_locations, bool desc, bool has_null, bool nulls_first,
-                                 idx_t prefix_len, idx_t width, idx_t offset) {
-#ifdef DEBUG
-	// initialize to verify written width later
-	auto key_locations_copy = make_uniq_array<data_ptr_t>(ser_count);
-	for (idx_t i = 0; i < ser_count; i++) {
-		key_locations_copy[i] = key_locations[i];
-	}
-#endif
-
-	UnifiedVectorFormat vdata;
-	v.ToUnifiedFormat(vcount, vdata);
-	switch (v.GetType().InternalType()) {
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		TemplatedRadixScatter<int8_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::INT16:
-		TemplatedRadixScatter<int16_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::INT32:
-		TemplatedRadixScatter<int32_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::INT64:
-		TemplatedRadixScatter<int64_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::UINT8:
-		TemplatedRadixScatter<uint8_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::UINT16:
-		TemplatedRadixScatter<uint16_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::UINT32:
-		TemplatedRadixScatter<uint32_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::UINT64:
-		TemplatedRadixScatter<uint64_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::INT128:
-		TemplatedRadixScatter<hugeint_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::UINT128:
-		TemplatedRadixScatter<uhugeint_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::FLOAT:
-		TemplatedRadixScatter<float>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::DOUBLE:
-		TemplatedRadixScatter<double>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::INTERVAL:
-		TemplatedRadixScatter<interval_t>(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, offset);
-		break;
-	case PhysicalType::VARCHAR:
-		RadixScatterStringVector(vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, offset);
-		break;
-	case PhysicalType::LIST:
-		RadixScatterListVector(v, vdata, sel, ser_count, key_locations, desc, has_null, nulls_first, prefix_len, width,
-		                       offset);
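// [Editor's note, a hedged aside] Each case above emits a fixed-width,
// memcmp-comparable key. For signed integers the usual encoding (what
// Radix::EncodeData does, alongside the EncodeStringDataPrefix call used for
// strings earlier) is big-endian byte order with the sign bit flipped, e.g.
// for int32:
//     -1 -> 7F FF FF FF,  0 -> 80 00 00 00,  1 -> 80 00 00 01
// so that unsigned byte comparison matches signed integer order; a
// descending column then additionally inverts every byte.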
- break; - case PhysicalType::STRUCT: - RadixScatterStructVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - case PhysicalType::ARRAY: - RadixScatterArrayVector(v, vdata, vcount, sel, ser_count, key_locations, desc, has_null, nulls_first, - prefix_len, width, offset); - break; - default: - throw NotImplementedException("Cannot ORDER BY column with type %s", v.GetType().ToString()); - } - -#ifdef DEBUG - for (idx_t i = 0; i < ser_count; i++) { - D_ASSERT(key_locations[i] == key_locations_copy[i] + width); - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/row_operations/row_scatter.cpp b/src/duckdb/src/common/row_operations/row_scatter.cpp deleted file mode 100644 index a535e1a27..000000000 --- a/src/duckdb/src/common/row_operations/row_scatter.cpp +++ /dev/null @@ -1,236 +0,0 @@ -//===--------------------------------------------------------------------===// -// row_scatter.cpp -// Description: This file contains the implementation of the row scattering -// operators -//===--------------------------------------------------------------------===// - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/common/types/row/row_layout.hpp" -#include "duckdb/common/types/selection_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = RowLayout::ValidityBytes; - -template -static void TemplatedScatter(UnifiedVectorFormat &col, Vector &rows, const SelectionVector &sel, const idx_t count, - const idx_t col_offset, const idx_t col_no, const idx_t col_count) { - auto data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - if (!col.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - auto isnull = !col.validity.RowIsValid(col_idx); - T store_value = isnull ? NullValue() : data[col_idx]; - Store(store_value, row + col_offset); - if (isnull) { - ValidityBytes col_mask(ptrs[idx], col_count); - col_mask.SetInvalidUnsafe(col_no); - } - } - } else { - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - - Store(data[col_idx], row + col_offset); - } - } -} - -static void ComputeStringEntrySizes(const UnifiedVectorFormat &col, idx_t entry_sizes[], const SelectionVector &sel, - const idx_t count, const idx_t offset = 0) { - auto data = UnifiedVectorFormat::GetData(col); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx) + offset; - const auto &str = data[col_idx]; - if (col.validity.RowIsValid(col_idx) && !str.IsInlined()) { - entry_sizes[i] += str.GetSize(); - } - } -} - -static void ScatterStringVector(UnifiedVectorFormat &col, Vector &rows, data_ptr_t str_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t col_count) { - auto string_data = UnifiedVectorFormat::GetData(col); - auto ptrs = FlatVector::GetData(rows); - - // Write out zero length to avoid swizzling problems. 
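// [Editor's note, a hedged aside] string_t keeps short strings inline and
// long strings as a pointer into the heap. When row data spills to disk,
// "swizzling" rewrites those pointers as offsets; storing an explicit
// zero-length string_t in NULL slots below ensures the swizzling pass never
// chases a stale, uninitialized pointer.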
- const string_t null(nullptr, 0); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto col_idx = col.sel->get_index(idx); - auto row = ptrs[idx]; - if (!col.validity.RowIsValid(col_idx)) { - ValidityBytes col_mask(row, col_count); - col_mask.SetInvalidUnsafe(col_no); - Store(null, row + col_offset); - } else if (string_data[col_idx].IsInlined()) { - Store(string_data[col_idx], row + col_offset); - } else { - const auto &str = string_data[col_idx]; - string_t inserted(const_char_ptr_cast(str_locations[i]), UnsafeNumericCast(str.GetSize())); - memcpy(inserted.GetDataWriteable(), str.GetData(), str.GetSize()); - str_locations[i] += str.GetSize(); - inserted.Finalize(); - Store(inserted, row + col_offset); - } - } -} - -static void ScatterNestedVector(Vector &vec, UnifiedVectorFormat &col, Vector &rows, data_ptr_t data_locations[], - const SelectionVector &sel, const idx_t count, const idx_t col_offset, - const idx_t col_no, const idx_t vcount) { - // Store pointers to the data in the row - // Do this first because SerializeVector destroys the locations - auto ptrs = FlatVector::GetData(rows); - data_ptr_t validitymask_locations[STANDARD_VECTOR_SIZE]; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto row = ptrs[idx]; - validitymask_locations[i] = row; - - Store(data_locations[i], row + col_offset); - } - - // Serialise the data - NestedValidity parent_validity(validitymask_locations, col_no); - RowOperations::HeapScatter(vec, vcount, sel, count, data_locations, &parent_validity); -} - -void RowOperations::Scatter(DataChunk &columns, UnifiedVectorFormat col_data[], const RowLayout &layout, Vector &rows, - RowDataCollection &string_heap, const SelectionVector &sel, idx_t count) { - if (count == 0) { - return; - } - - // Set the validity mask for each row before inserting data - idx_t column_count = layout.ColumnCount(); - auto ptrs = FlatVector::GetData(rows); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - ValidityBytes(row, column_count).SetAllValid(layout.ColumnCount()); - } - - const auto vcount = columns.size(); - auto &offsets = layout.GetOffsets(); - auto &types = layout.GetTypes(); - - // Compute the entry size of the variable size columns - vector handles; - data_ptr_t data_locations[STANDARD_VECTOR_SIZE]; - if (!layout.AllConstant()) { - idx_t entry_sizes[STANDARD_VECTOR_SIZE]; - std::fill_n(entry_sizes, count, sizeof(uint32_t)); - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - if (TypeIsConstantSize(types[col_no].InternalType())) { - continue; - } - - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - switch (types[col_no].InternalType()) { - case PhysicalType::VARCHAR: - ComputeStringEntrySizes(col, entry_sizes, sel, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - RowOperations::ComputeEntrySizes(vec, col, entry_sizes, vcount, count, sel); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } - - // Build out the buffer space - handles = string_heap.Build(count, data_locations, entry_sizes); - - // Serialize information that is needed for swizzling if the computation goes out-of-core - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto row_idx = sel.get_index(i); - auto row = ptrs[row_idx]; - // Pointer to this row in the heap block - Store(data_locations[i], row + heap_pointer_offset); - // Row 
size is stored in the heap in front of each row - Store(NumericCast(entry_sizes[i]), data_locations[i]); - data_locations[i] += sizeof(uint32_t); - } - } - - for (idx_t col_no = 0; col_no < types.size(); col_no++) { - auto &vec = columns.data[col_no]; - auto &col = col_data[col_no]; - auto col_offset = offsets[col_no]; - - switch (types[col_no].InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT8: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT16: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT32: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT64: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::UINT128: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::FLOAT: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::DOUBLE: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::INTERVAL: - TemplatedScatter(col, rows, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::VARCHAR: - ScatterStringVector(col, rows, data_locations, sel, count, col_offset, col_no, column_count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - ScatterNestedVector(vec, col, rows, data_locations, sel, count, col_offset, col_no, vcount); - break; - default: - throw InternalException("Unsupported type for RowOperations::Scatter"); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/binary_deserializer.cpp b/src/duckdb/src/common/serializer/binary_deserializer.cpp deleted file mode 100644 index 0fec7bea6..000000000 --- a/src/duckdb/src/common/serializer/binary_deserializer.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/serializer/binary_deserializer.hpp" - -namespace duckdb { - -//------------------------------------------------------------------------- -// Nested Type Hooks -//------------------------------------------------------------------------- -void BinaryDeserializer::OnPropertyBegin(const field_id_t field_id, const char *) { - auto field = NextField(); - if (field != field_id) { - throw SerializationException("Failed to deserialize: field id mismatch, expected: %d, got: %d", field_id, - field); - } -} - -void BinaryDeserializer::OnPropertyEnd() { -} - -bool BinaryDeserializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *s) { - auto next_field = PeekField(); - auto present = next_field == field_id; - if (present) { - ConsumeField(); - } - return present; -} - -void BinaryDeserializer::OnOptionalPropertyEnd(bool present) { -} - -void BinaryDeserializer::OnObjectBegin() { - nesting_level++; -} - -void 
BinaryDeserializer::OnObjectEnd() {
-	auto next_field = NextField();
-	if (next_field != MESSAGE_TERMINATOR_FIELD_ID) {
-		throw SerializationException("Failed to deserialize: expected end of object, but found field id: %d",
-		                             next_field);
-	}
-	nesting_level--;
-}
-
-idx_t BinaryDeserializer::OnListBegin() {
-	return VarIntDecode<idx_t>();
-}
-
-void BinaryDeserializer::OnListEnd() {
-}
-
-bool BinaryDeserializer::OnNullableBegin() {
-	return ReadBool();
-}
-
-void BinaryDeserializer::OnNullableEnd() {
-}
-
-//-------------------------------------------------------------------------
-// Primitive Types
-//-------------------------------------------------------------------------
-bool BinaryDeserializer::ReadBool() {
-	return static_cast<bool>(ReadPrimitive<uint8_t>());
-}
-
-char BinaryDeserializer::ReadChar() {
-	return ReadPrimitive<char>();
-}
-
-int8_t BinaryDeserializer::ReadSignedInt8() {
-	return VarIntDecode<int8_t>();
-}
-
-uint8_t BinaryDeserializer::ReadUnsignedInt8() {
-	return VarIntDecode<uint8_t>();
-}
-
-int16_t BinaryDeserializer::ReadSignedInt16() {
-	return VarIntDecode<int16_t>();
-}
-
-uint16_t BinaryDeserializer::ReadUnsignedInt16() {
-	return VarIntDecode<uint16_t>();
-}
-
-int32_t BinaryDeserializer::ReadSignedInt32() {
-	return VarIntDecode<int32_t>();
-}
-
-uint32_t BinaryDeserializer::ReadUnsignedInt32() {
-	return VarIntDecode<uint32_t>();
-}
-
-int64_t BinaryDeserializer::ReadSignedInt64() {
-	return VarIntDecode<int64_t>();
-}
-
-uint64_t BinaryDeserializer::ReadUnsignedInt64() {
-	return VarIntDecode<uint64_t>();
-}
-
-float BinaryDeserializer::ReadFloat() {
-	auto value = ReadPrimitive<float>();
-	return value;
-}
-
-double BinaryDeserializer::ReadDouble() {
-	auto value = ReadPrimitive<double>();
-	return value;
-}
-
-string BinaryDeserializer::ReadString() {
-	auto len = VarIntDecode<uint32_t>();
-	if (len == 0) {
-		return string();
-	}
-	auto buffer = make_unsafe_uniq_array_uninitialized<data_t>(len);
-	ReadData(buffer.get(), len);
-	return string(const_char_ptr_cast(buffer.get()), len);
-}
-
-hugeint_t BinaryDeserializer::ReadHugeInt() {
-	auto upper = VarIntDecode<int64_t>();
-	auto lower = VarIntDecode<uint64_t>();
-	return hugeint_t(upper, lower);
-}
-
-uhugeint_t BinaryDeserializer::ReadUhugeInt() {
-	auto upper = VarIntDecode<uint64_t>();
-	auto lower = VarIntDecode<uint64_t>();
-	return uhugeint_t(upper, lower);
-}
-
-void BinaryDeserializer::ReadDataPtr(data_ptr_t &ptr_p, idx_t count) {
-	auto len = VarIntDecode<idx_t>();
-	if (len != count) {
-		throw SerializationException("Tried to read blob of %d size, but only %d elements are available", count, len);
-	}
-	ReadData(ptr_p, count);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/serializer/binary_serializer.cpp b/src/duckdb/src/common/serializer/binary_serializer.cpp
deleted file mode 100644
index 8ec78c905..000000000
--- a/src/duckdb/src/common/serializer/binary_serializer.cpp
+++ /dev/null
@@ -1,175 +0,0 @@
-#include "duckdb/common/serializer/binary_serializer.hpp"
-
-#ifdef DEBUG
-#include "duckdb/common/string_util.hpp"
-#endif
-
-namespace duckdb {
-
-void BinarySerializer::OnPropertyBegin(const field_id_t field_id, const char *tag) {
-	// Just write the field id straight up
-	Write(field_id);
-#ifdef DEBUG
-	// First of all, check that we are inside an object
-	if (debug_stack.empty()) {
-		throw InternalException("OnPropertyBegin called outside of object");
-	}
-
-	// Check that the tag is unique
-	auto &state = debug_stack.back();
-	auto &seen_field_ids = state.seen_field_ids;
-	auto &seen_field_tags = state.seen_field_tags;
-	auto &seen_fields = state.seen_fields;
-
-	if (seen_field_ids.find(field_id) != seen_field_ids.end() || seen_field_tags.find(tag)
!= seen_field_tags.end()) { - string all_fields; - for (auto &field : seen_fields) { - all_fields += StringUtil::Format("\"%s\":%d ", field.first, field.second); - } - throw InternalException("Duplicate field id/tag in field: \"%s\":%d, other fields: %s", tag, field_id, - all_fields); - } - - seen_field_ids.insert(field_id); - seen_field_tags.insert(tag); - seen_fields.emplace_back(tag, field_id); -#else - (void)tag; -#endif -} - -void BinarySerializer::OnPropertyEnd() { - // Nothing to do here -} - -void BinarySerializer::OnOptionalPropertyBegin(const field_id_t field_id, const char *tag, bool present) { - // Dont write anything at all if the property is not present - if (present) { - OnPropertyBegin(field_id, tag); - } -} - -void BinarySerializer::OnOptionalPropertyEnd(bool present) { - // Nothing to do here -} - -//------------------------------------------------------------------------- -// Nested Type Hooks -//------------------------------------------------------------------------- -void BinarySerializer::OnObjectBegin() { -#ifdef DEBUG - debug_stack.emplace_back(); -#endif -} - -void BinarySerializer::OnObjectEnd() { -#ifdef DEBUG - debug_stack.pop_back(); -#endif - // Write object terminator - Write(MESSAGE_TERMINATOR_FIELD_ID); -} - -void BinarySerializer::OnListBegin(idx_t count) { - VarIntEncode(count); -} - -void BinarySerializer::OnListEnd() { -} - -void BinarySerializer::OnNullableBegin(bool present) { - WriteValue(present); -} - -void BinarySerializer::OnNullableEnd() { -} - -//------------------------------------------------------------------------- -// Primitive Types -//------------------------------------------------------------------------- -void BinarySerializer::WriteNull() { - // This should never be called, optional writes should be handled by OnOptionalBegin -} - -void BinarySerializer::WriteValue(bool value) { - Write(value); -} - -void BinarySerializer::WriteValue(uint8_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(char value) { - Write(value); -} - -void BinarySerializer::WriteValue(int8_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint16_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int16_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint32_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int32_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(uint64_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(int64_t value) { - VarIntEncode(value); -} - -void BinarySerializer::WriteValue(hugeint_t value) { - VarIntEncode(value.upper); - VarIntEncode(value.lower); -} - -void BinarySerializer::WriteValue(uhugeint_t value) { - VarIntEncode(value.upper); - VarIntEncode(value.lower); -} - -void BinarySerializer::WriteValue(float value) { - Write(value); -} - -void BinarySerializer::WriteValue(double value) { - Write(value); -} - -void BinarySerializer::WriteValue(const string &value) { - auto len = NumericCast(value.length()); - VarIntEncode(len); - WriteData(value.c_str(), len); -} - -void BinarySerializer::WriteValue(const string_t value) { - auto len = NumericCast(value.GetSize()); - VarIntEncode(len); - WriteData(value.GetDataUnsafe(), len); -} - -void BinarySerializer::WriteValue(const char *value) { - auto len = NumericCast(strlen(value)); - VarIntEncode(len); - WriteData(value, len); -} - -void BinarySerializer::WriteDataPtr(const_data_ptr_t ptr, idx_t count) { - 
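/* [Editor's note] Blobs, like strings above, are length-prefixed with a
   varint. A minimal sketch of LEB128-style unsigned varint encoding, under
   the assumption that this is the scheme VarIntEncode implements (the sink
   EmitByte is hypothetical):

       void VarIntEncodeSketch(uint64_t value) {
           do {
               uint8_t b = value & 0x7F; // low 7 bits
               value >>= 7;
               if (value != 0) {
                   b |= 0x80;            // continuation bit: more bytes follow
               }
               EmitByte(b);
           } while (value != 0);
       }

   Small values thus cost one byte on the wire instead of eight. */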
VarIntEncode(static_cast(count)); - WriteData(ptr, count); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/buffered_file_reader.cpp b/src/duckdb/src/common/serializer/buffered_file_reader.cpp deleted file mode 100644 index 96cd2e080..000000000 --- a/src/duckdb/src/common/serializer/buffered_file_reader.cpp +++ /dev/null @@ -1,72 +0,0 @@ -#include "duckdb/common/serializer/buffered_file_reader.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/serializer/buffered_file_writer.hpp" - -#include -#include - -namespace duckdb { - -BufferedFileReader::BufferedFileReader(FileSystem &fs, const char *path, FileLockType lock_type, - optional_ptr opener) - : fs(fs), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), read_data(0), - total_read(0) { - handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | lock_type, opener.get()); - file_size = NumericCast(fs.GetFileSize(*handle)); -} - -BufferedFileReader::BufferedFileReader(FileSystem &fs, unique_ptr handle_p) - : fs(fs), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), read_data(0), - handle(std::move(handle_p)), total_read(0) { - file_size = NumericCast(fs.GetFileSize(*handle)); -} - -void BufferedFileReader::ReadData(data_ptr_t target_buffer, uint64_t read_size) { - // first copy anything we can from the buffer - data_ptr_t end_ptr = target_buffer + read_size; - while (true) { - idx_t to_read = MinValue(UnsafeNumericCast(end_ptr - target_buffer), read_data - offset); - if (to_read > 0) { - memcpy(target_buffer, data.get() + offset, to_read); - offset += to_read; - target_buffer += to_read; - } - if (target_buffer < end_ptr) { - D_ASSERT(offset == read_data); - total_read += read_data; - // did not finish reading yet but exhausted buffer - // read data into buffer - offset = 0; - read_data = UnsafeNumericCast(fs.Read(*handle, data.get(), FILE_BUFFER_SIZE)); - if (read_data == 0) { - throw SerializationException("not enough data in file to deserialize result"); - } - } else { - return; - } - } -} - -bool BufferedFileReader::Finished() { - return total_read + offset == file_size; -} - -void BufferedFileReader::Seek(uint64_t location) { - D_ASSERT(location <= file_size); - handle->Seek(location); - total_read = location; - read_data = offset = 0; -} - -void BufferedFileReader::Reset() { - handle->Reset(); - total_read = 0; - read_data = offset = 0; -} - -uint64_t BufferedFileReader::CurrentOffset() { - return total_read + offset; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/buffered_file_writer.cpp b/src/duckdb/src/common/serializer/buffered_file_writer.cpp deleted file mode 100644 index f0d811921..000000000 --- a/src/duckdb/src/common/serializer/buffered_file_writer.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "duckdb/common/serializer/buffered_file_writer.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/typedefs.hpp" - -#include - -namespace duckdb { - -// Remove this when we switch C++17: https://stackoverflow.com/a/53350948 -constexpr FileOpenFlags BufferedFileWriter::DEFAULT_OPEN_FLAGS; - -BufferedFileWriter::BufferedFileWriter(FileSystem &fs, const string &path_p, FileOpenFlags open_flags) - : fs(fs), path(path_p), data(make_unsafe_uniq_array_uninitialized(FILE_BUFFER_SIZE)), offset(0), - total_written(0) { - handle = fs.OpenFile(path, open_flags | FileLockType::WRITE_LOCK); -} - -idx_t BufferedFileWriter::GetFileSize() { - return 
NumericCast<idx_t>(fs.GetFileSize(*handle)) + offset;
-}
-
-idx_t BufferedFileWriter::GetTotalWritten() const {
-	return total_written + offset;
-}
-
-void BufferedFileWriter::WriteData(const_data_ptr_t buffer, idx_t write_size) {
-	if (write_size >= (2ULL * FILE_BUFFER_SIZE - offset)) {
-		idx_t to_copy = 0;
-		// Before performing direct I/O, check whether there is data in the current internal buffer.
-		// If so, fill the buffer (to avoid a too-small write operation), flush it, and then write
-		// all the remaining data directly.
-		// This avoids splitting a large buffer into N * FILE_BUFFER_SIZE writes
-		if (offset != 0) {
-			// Some data is still present in the buffer, write it out first
-			to_copy = FILE_BUFFER_SIZE - offset;
-			memcpy(data.get() + offset, buffer, to_copy);
-			offset += to_copy;
-			Flush(); // Flush the buffer before writing everything else
-		}
-		idx_t remaining_to_write = write_size - to_copy;
-		fs.Write(*handle, const_cast<data_ptr_t>(buffer + to_copy), // NOLINT: wrong API in Write
-		         UnsafeNumericCast<int64_t>(remaining_to_write));
-		total_written += remaining_to_write;
-	} else {
-		// first copy anything we can from the buffer
-		const_data_ptr_t end_ptr = buffer + write_size;
-		while (buffer < end_ptr) {
-			idx_t to_write = MinValue<idx_t>(UnsafeNumericCast<idx_t>(end_ptr - buffer), FILE_BUFFER_SIZE - offset);
-			D_ASSERT(to_write > 0);
-			memcpy(data.get() + offset, buffer, to_write);
-			offset += to_write;
-			buffer += to_write;
-			if (offset == FILE_BUFFER_SIZE) {
-				Flush();
-			}
-		}
-	}
-}
-
-void BufferedFileWriter::Flush() {
-	if (offset == 0) {
-		return;
-	}
-	fs.Write(*handle, data.get(), UnsafeNumericCast<int64_t>(offset));
-	total_written += offset;
-	offset = 0;
-}
-
-void BufferedFileWriter::Close() {
-	Flush();
-	handle->Close();
-	handle.reset();
-}
-
-void BufferedFileWriter::Sync() {
-	Flush();
-	handle->Sync();
-}
-
-void BufferedFileWriter::Truncate(idx_t size) {
-	auto persistent = NumericCast<idx_t>(fs.GetFileSize(*handle));
-	D_ASSERT(size <= persistent + offset);
-	if (persistent <= size) {
-		// truncating into the pending write buffer.
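// [Editor's note, a hedged worked example] With 4096 bytes already
// persistent on disk and offset = 100 buffered, Truncate(4150) leaves the
// file untouched and rewinds the buffer to offset = 54, while Truncate(4000)
// shrinks the file itself and discards the buffered bytes entirely.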
- offset = size - persistent; - } else { - // truncate the physical file on disk - handle->Truncate(NumericCast(size)); - // reset anything written in the buffer - offset = 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/memory_stream.cpp b/src/duckdb/src/common/serializer/memory_stream.cpp deleted file mode 100644 index e5f0455e3..000000000 --- a/src/duckdb/src/common/serializer/memory_stream.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include "duckdb/common/serializer/memory_stream.hpp" - -namespace duckdb { - -MemoryStream::MemoryStream(idx_t capacity) : position(0), capacity(capacity), owns_data(true) { - D_ASSERT(capacity != 0 && IsPowerOfTwo(capacity)); - auto data_malloc_result = malloc(capacity); - if (!data_malloc_result) { - throw std::bad_alloc(); - } - data = static_cast(data_malloc_result); -} - -MemoryStream::MemoryStream(data_ptr_t buffer, idx_t capacity) - : position(0), capacity(capacity), owns_data(false), data(buffer) { -} - -MemoryStream::~MemoryStream() { - if (owns_data) { - free(data); - } -} - -MemoryStream::MemoryStream(MemoryStream &&other) noexcept { - // Move the data from the other stream into this stream - data = other.data; - position = other.position; - capacity = other.capacity; - owns_data = other.owns_data; - - // Reset the other stream - other.data = nullptr; - other.position = 0; - other.capacity = 0; - other.owns_data = false; -} - -MemoryStream &MemoryStream::operator=(MemoryStream &&other) noexcept { - if (this != &other) { - // Free the current data - if (owns_data) { - free(data); - } - - // Move the data from the other stream into this stream - data = other.data; - position = other.position; - capacity = other.capacity; - owns_data = other.owns_data; - - // Reset the other stream - other.data = nullptr; - other.position = 0; - other.capacity = 0; - other.owns_data = false; - } - return *this; -} - -void MemoryStream::WriteData(const_data_ptr_t source, idx_t write_size) { - while (position + write_size > capacity) { - if (owns_data) { - capacity *= 2; - data = static_cast(realloc(data, capacity)); - } else { - throw SerializationException("Failed to serialize: not enough space in buffer to fulfill write request"); - } - } - memcpy(data + position, source, write_size); - position += write_size; -} - -void MemoryStream::ReadData(data_ptr_t destination, idx_t read_size) { - if (position + read_size > capacity) { - throw SerializationException("Failed to deserialize: not enough data in buffer to fulfill read request"); - } - memcpy(destination, data + position, read_size); - position += read_size; -} - -void MemoryStream::Rewind() { - position = 0; -} - -void MemoryStream::Release() { - owns_data = false; -} - -data_ptr_t MemoryStream::GetData() const { - return data; -} - -idx_t MemoryStream::GetPosition() const { - return position; -} - -idx_t MemoryStream::GetCapacity() const { - return capacity; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/serializer/serializer.cpp b/src/duckdb/src/common/serializer/serializer.cpp deleted file mode 100644 index 91a7f4772..000000000 --- a/src/duckdb/src/common/serializer/serializer.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/types/value.hpp" - -namespace duckdb { - -template <> -void Serializer::WriteValue(const vector &vec) { - auto count = vec.size(); - OnListBegin(count); - for (auto item : vec) { - WriteValue(item); - } - OnListEnd(); -} - -template <> -void Serializer::WritePropertyWithDefault(const 
field_id_t field_id, const char *tag, const Value &value, - const Value &&default_value) { - // If current value is default, don't write it - if (!options.serialize_default_values && ValueOperations::NotDistinctFrom(value, default_value)) { - OnOptionalPropertyBegin(field_id, tag, false); - OnOptionalPropertyEnd(false); - return; - } - OnOptionalPropertyBegin(field_id, tag, true); - WriteValue(value); - OnOptionalPropertyEnd(true); -} - -void Serializer::List::WriteElement(data_ptr_t ptr, idx_t size) { - serializer.WriteDataPtr(ptr, size); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/comparators.cpp b/src/duckdb/src/common/sort/comparators.cpp deleted file mode 100644 index 4df4cccc4..000000000 --- a/src/duckdb/src/common/sort/comparators.cpp +++ /dev/null @@ -1,507 +0,0 @@ -#include "duckdb/common/sort/comparators.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -bool Comparators::TieIsBreakable(const idx_t &tie_col, const data_ptr_t &row_ptr, const SortLayout &sort_layout) { - const auto &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - // Check if the blob is NULL - ValidityBytes row_mask(row_ptr, sort_layout.column_count); - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - if (!row_mask.RowIsValid(row_mask.GetValidityEntry(entry_idx), idx_in_entry)) { - // Can't break a NULL tie - return false; - } - auto &row_layout = sort_layout.blob_layout; - if (row_layout.GetTypes()[col_idx].InternalType() != PhysicalType::VARCHAR) { - // Nested type, must be broken - return true; - } - const auto &tie_col_offset = row_layout.GetOffsets()[col_idx]; - auto tie_string = Load(row_ptr + tie_col_offset); - if (tie_string.GetSize() < sort_layout.prefix_lengths[tie_col] && tie_string.GetSize() > 0) { - // No need to break the tie - we already compared the full string - return false; - } - return true; -} - -int Comparators::CompareTuple(const SBScanState &left, const SBScanState &right, const data_ptr_t &l_ptr, - const data_ptr_t &r_ptr, const SortLayout &sort_layout, const bool &external_sort) { - // Compare the sorting columns one by one - int comp_res = 0; - data_ptr_t l_ptr_offset = l_ptr; - data_ptr_t r_ptr_offset = r_ptr; - for (idx_t col_idx = 0; col_idx < sort_layout.column_count; col_idx++) { - comp_res = FastMemcmp(l_ptr_offset, r_ptr_offset, sort_layout.column_sizes[col_idx]); - if (comp_res == 0 && !sort_layout.constant_size[col_idx]) { - comp_res = BreakBlobTie(col_idx, left, right, sort_layout, external_sort); - } - if (comp_res != 0) { - break; - } - l_ptr_offset += sort_layout.column_sizes[col_idx]; - r_ptr_offset += sort_layout.column_sizes[col_idx]; - } - return comp_res; -} - -int Comparators::CompareVal(const data_ptr_t l_ptr, const data_ptr_t r_ptr, const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::VARCHAR: - return TemplatedCompareVal(l_ptr, r_ptr); - case PhysicalType::LIST: - case PhysicalType::ARRAY: - case PhysicalType::STRUCT: { - auto l_nested_ptr = Load(l_ptr); - auto r_nested_ptr = Load(r_ptr); - return CompareValAndAdvance(l_nested_ptr, r_nested_ptr, type, true); - } - default: - throw NotImplementedException("Unimplemented CompareVal for type %s", type.ToString()); - } -} - -int Comparators::BreakBlobTie(const idx_t &tie_col, const SBScanState &left, const SBScanState &right, - const SortLayout &sort_layout, const bool &external) { - data_ptr_t l_data_ptr = 
left.DataPtr(*left.sb->blob_sorting_data); - data_ptr_t r_data_ptr = right.DataPtr(*right.sb->blob_sorting_data); - if (!TieIsBreakable(tie_col, l_data_ptr, sort_layout) && !TieIsBreakable(tie_col, r_data_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return 0; - } - // Align the pointers - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - l_data_ptr += tie_col_offset; - r_data_ptr += tie_col_offset; - // Do the comparison - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? -1 : 1; - const auto &type = sort_layout.blob_layout.GetTypes()[col_idx]; - int result; - if (external) { - // Store heap pointers - data_ptr_t l_heap_ptr = left.HeapPtr(*left.sb->blob_sorting_data); - data_ptr_t r_heap_ptr = right.HeapPtr(*right.sb->blob_sorting_data); - // Unswizzle offset to pointer - UnswizzleSingleValue(l_data_ptr, l_heap_ptr, type); - UnswizzleSingleValue(r_data_ptr, r_heap_ptr, type); - // Compare - result = CompareVal(l_data_ptr, r_data_ptr, type); - // Swizzle the pointers back to offsets - SwizzleSingleValue(l_data_ptr, l_heap_ptr, type); - SwizzleSingleValue(r_data_ptr, r_heap_ptr, type); - } else { - result = CompareVal(l_data_ptr, r_data_ptr, type); - } - return order * result; -} - -template -int Comparators::TemplatedCompareVal(const data_ptr_t &left_ptr, const data_ptr_t &right_ptr) { - const auto left_val = Load(left_ptr); - const auto right_val = Load(right_ptr); - if (Equals::Operation(left_val, right_val)) { - return 0; - } else if (LessThan::Operation(left_val, right_val)) { - return -1; - } else { - return 1; - } -} - -int Comparators::CompareValAndAdvance(data_ptr_t &l_ptr, data_ptr_t &r_ptr, const LogicalType &type, bool valid) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT8: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT16: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT32: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT64: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::UINT128: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::FLOAT: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::DOUBLE: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::INTERVAL: - return TemplatedCompareAndAdvance(l_ptr, r_ptr); - case PhysicalType::VARCHAR: - return CompareStringAndAdvance(l_ptr, r_ptr, valid); - case PhysicalType::LIST: - return CompareListAndAdvance(l_ptr, r_ptr, ListType::GetChildType(type), valid); - case PhysicalType::STRUCT: - return CompareStructAndAdvance(l_ptr, r_ptr, StructType::GetChildTypes(type), valid); - case PhysicalType::ARRAY: - return CompareArrayAndAdvance(l_ptr, r_ptr, ArrayType::GetChildType(type), valid, ArrayType::GetSize(type)); - default: - throw NotImplementedException("Unimplemented CompareValAndAdvance for type %s", type.ToString()); - } -} - -template -int Comparators::TemplatedCompareAndAdvance(data_ptr_t 
&left_ptr, data_ptr_t &right_ptr) { - auto result = TemplatedCompareVal(left_ptr, right_ptr); - left_ptr += sizeof(T); - right_ptr += sizeof(T); - return result; -} - -int Comparators::CompareStringAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, bool valid) { - if (!valid) { - return 0; - } - uint32_t left_string_size = Load(left_ptr); - uint32_t right_string_size = Load(right_ptr); - left_ptr += sizeof(uint32_t); - right_ptr += sizeof(uint32_t); - auto memcmp_res = memcmp(const_char_ptr_cast(left_ptr), const_char_ptr_cast(right_ptr), - std::min(left_string_size, right_string_size)); - - left_ptr += left_string_size; - right_ptr += right_string_size; - - if (memcmp_res != 0) { - return memcmp_res; - } - if (left_string_size == right_string_size) { - return 0; - } - if (left_string_size < right_string_size) { - return -1; - } - return 1; -} - -int Comparators::CompareStructAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const child_list_t &types, bool valid) { - idx_t count = types.size(); - // Load validity masks - ValidityBytes left_validity(left_ptr, types.size()); - ValidityBytes right_validity(right_ptr, types.size()); - left_ptr += (count + 7) / 8; - right_ptr += (count + 7) / 8; - // Initialize variables - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Compare - int comp_res = 0; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - auto &type = types[i].second; - if ((left_valid == right_valid) || TypeIsConstantSize(type.InternalType())) { - comp_res = CompareValAndAdvance(left_ptr, right_ptr, types[i].second, left_valid && valid); - } - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -int Comparators::CompareArrayAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid, idx_t array_size) { - if (!valid) { - return 0; - } - - // Load array validity masks - ValidityBytes left_validity(left_ptr, array_size); - ValidityBytes right_validity(right_ptr, array_size); - left_ptr += (array_size + 7) / 8; - right_ptr += (array_size + 7) / 8; - - int comp_res = 0; - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT32: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT8: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT16: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT32: - comp_res = - 
TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::UINT64: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INT128: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - case PhysicalType::INTERVAL: - comp_res = - TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, array_size); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized array entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. We need to skip over them - left_ptr += array_size * sizeof(idx_t); - right_ptr += array_size * sizeof(idx_t); - for (idx_t i = 0; i < array_size; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareArrayAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - return comp_res; -} - -int Comparators::CompareListAndAdvance(data_ptr_t &left_ptr, data_ptr_t &right_ptr, const LogicalType &type, - bool valid) { - if (!valid) { - return 0; - } - // Load list lengths - auto left_len = Load(left_ptr); - auto right_len = Load(right_ptr); - left_ptr += sizeof(idx_t); - right_ptr += sizeof(idx_t); - // Load list validity masks - ValidityBytes left_validity(left_ptr, left_len); - ValidityBytes right_validity(right_ptr, right_len); - left_ptr += (left_len + 7) / 8; - right_ptr += (right_len + 7) / 8; - // Compare - int comp_res = 0; - idx_t count = MinValue(left_len, right_len); - if (TypeIsConstantSize(type.InternalType())) { - // Templated code for fixed-size types - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case 
PhysicalType::INT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT8: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT16: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT32: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT64: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::UINT128: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::FLOAT: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::DOUBLE: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - case PhysicalType::INTERVAL: - comp_res = TemplatedCompareListLoop(left_ptr, right_ptr, left_validity, right_validity, count); - break; - default: - throw NotImplementedException("CompareListAndAdvance for fixed-size type %s", type.ToString()); - } - } else { - // Variable-sized list entries - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - // Size (in bytes) of all variable-sizes entries is stored before the entries begin, - // to make deserialization easier. 
We need to skip over them - left_ptr += left_len * sizeof(idx_t); - right_ptr += right_len * sizeof(idx_t); - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - if (left_valid && right_valid) { - switch (type.InternalType()) { - case PhysicalType::LIST: - comp_res = CompareListAndAdvance(left_ptr, right_ptr, ListType::GetChildType(type), left_valid); - break; - case PhysicalType::ARRAY: - comp_res = CompareArrayAndAdvance(left_ptr, right_ptr, ArrayType::GetChildType(type), left_valid, - ArrayType::GetSize(type)); - break; - case PhysicalType::VARCHAR: - comp_res = CompareStringAndAdvance(left_ptr, right_ptr, left_valid); - break; - case PhysicalType::STRUCT: - comp_res = - CompareStructAndAdvance(left_ptr, right_ptr, StructType::GetChildTypes(type), left_valid); - break; - default: - throw NotImplementedException("CompareListAndAdvance for variable-size type %s", type.ToString()); - } - } else if (!left_valid && !right_valid) { - comp_res = 0; - } else if (left_valid) { - comp_res = -1; - } else { - comp_res = 1; - } - if (comp_res != 0) { - break; - } - } - } - // All values that we looped over were equal - if (comp_res == 0 && left_len != right_len) { - // Smaller lists first - if (left_len < right_len) { - comp_res = -1; - } else { - comp_res = 1; - } - } - return comp_res; -} - -template -int Comparators::TemplatedCompareListLoop(data_ptr_t &left_ptr, data_ptr_t &right_ptr, - const ValidityBytes &left_validity, const ValidityBytes &right_validity, - const idx_t &count) { - int comp_res = 0; - bool left_valid; - bool right_valid; - idx_t entry_idx; - idx_t idx_in_entry; - for (idx_t i = 0; i < count; i++) { - ValidityBytes::GetEntryIndex(i, entry_idx, idx_in_entry); - left_valid = left_validity.RowIsValid(left_validity.GetValidityEntry(entry_idx), idx_in_entry); - right_valid = right_validity.RowIsValid(right_validity.GetValidityEntry(entry_idx), idx_in_entry); - comp_res = TemplatedCompareAndAdvance(left_ptr, right_ptr); - if (!left_valid && !right_valid) { - comp_res = 0; - } else if (!left_valid) { - comp_res = 1; - } else if (!right_valid) { - comp_res = -1; - } - if (comp_res != 0) { - break; - } - } - return comp_res; -} - -void Comparators::UnswizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(heap_ptr + Load(data_ptr), data_ptr); -} - -void Comparators::SwizzleSingleValue(data_ptr_t data_ptr, const data_ptr_t &heap_ptr, const LogicalType &type) { - if (type.InternalType() == PhysicalType::VARCHAR) { - data_ptr += string_t::HEADER_SIZE; - } - Store(UnsafeNumericCast(Load(data_ptr) - heap_ptr), data_ptr); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/merge_sorter.cpp b/src/duckdb/src/common/sort/merge_sorter.cpp deleted file mode 100644 index b36887e66..000000000 --- a/src/duckdb/src/common/sort/merge_sorter.cpp +++ /dev/null @@ -1,663 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -MergeSorter::MergeSorter(GlobalSortState &state, BufferManager &buffer_manager) - : state(state), buffer_manager(buffer_manager), sort_layout(state.sort_layout) { -} - -void 
MergeSorter::PerformInMergeRound() { - while (true) { - { - lock_guard pair_guard(state.lock); - if (state.pair_idx == state.num_pairs) { - break; - } - GetNextPartition(); - } - MergePartition(); - } -} - -void MergeSorter::MergePartition() { - auto &left_block = *left->sb; - auto &right_block = *right->sb; -#ifdef DEBUG - D_ASSERT(left_block.radix_sorting_data.size() == left_block.payload_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.payload_data->data_blocks.size()); - if (!state.payload_layout.AllConstant() && state.external) { - D_ASSERT(left_block.payload_data->data_blocks.size() == left_block.payload_data->heap_blocks.size()); - D_ASSERT(right_block.payload_data->data_blocks.size() == right_block.payload_data->heap_blocks.size()); - } - if (!sort_layout.all_constant) { - D_ASSERT(left_block.radix_sorting_data.size() == left_block.blob_sorting_data->data_blocks.size()); - D_ASSERT(right_block.radix_sorting_data.size() == right_block.blob_sorting_data->data_blocks.size()); - if (state.external) { - D_ASSERT(left_block.blob_sorting_data->data_blocks.size() == - left_block.blob_sorting_data->heap_blocks.size()); - D_ASSERT(right_block.blob_sorting_data->data_blocks.size() == - right_block.blob_sorting_data->heap_blocks.size()); - } - } -#endif - // Set up the write block - // Each merge task produces a SortedBlock with exactly state.block_capacity rows or less - result->InitializeWrite(); - // Initialize arrays to store merge data - bool left_smaller[STANDARD_VECTOR_SIZE]; - idx_t next_entry_sizes[STANDARD_VECTOR_SIZE]; - // Merge loop -#ifdef DEBUG - auto l_count = left->Remaining(); - auto r_count = right->Remaining(); -#endif - while (true) { - auto l_remaining = left->Remaining(); - auto r_remaining = right->Remaining(); - if (l_remaining + r_remaining == 0) { - // Done - break; - } - const idx_t next = MinValue(l_remaining + r_remaining, (idx_t)STANDARD_VECTOR_SIZE); - if (l_remaining != 0 && r_remaining != 0) { - // Compute the merge (not needed if one side is exhausted) - ComputeMerge(next, left_smaller); - } - // Actually merge the data (radix, blob, and payload) - MergeRadix(next, left_smaller); - if (!sort_layout.all_constant) { - MergeData(*result->blob_sorting_data, *left_block.blob_sorting_data, *right_block.blob_sorting_data, next, - left_smaller, next_entry_sizes, true); - D_ASSERT(result->radix_sorting_data.size() == result->blob_sorting_data->data_blocks.size()); - } - MergeData(*result->payload_data, *left_block.payload_data, *right_block.payload_data, next, left_smaller, - next_entry_sizes, false); - D_ASSERT(result->radix_sorting_data.size() == result->payload_data->data_blocks.size()); - } -#ifdef DEBUG - D_ASSERT(result->Count() == l_count + r_count); -#endif -} - -void MergeSorter::GetNextPartition() { - // Create result block - state.sorted_blocks_temp[state.pair_idx].push_back(make_uniq(buffer_manager, state)); - result = state.sorted_blocks_temp[state.pair_idx].back().get(); - // Determine which blocks must be merged - auto &left_block = *state.sorted_blocks[state.pair_idx * 2]; - auto &right_block = *state.sorted_blocks[state.pair_idx * 2 + 1]; - const idx_t l_count = left_block.Count(); - const idx_t r_count = right_block.Count(); - // Initialize left and right reader - left = make_uniq(buffer_manager, state); - right = make_uniq(buffer_manager, state); - // Compute the work that this thread must do using Merge Path - idx_t l_end; - idx_t r_end; - if (state.l_start + state.r_start + state.block_capacity < l_count + 
r_count) { - left->sb = state.sorted_blocks[state.pair_idx * 2].get(); - right->sb = state.sorted_blocks[state.pair_idx * 2 + 1].get(); - const idx_t intersection = state.l_start + state.r_start + state.block_capacity; - GetIntersection(intersection, l_end, r_end); - D_ASSERT(l_end <= l_count); - D_ASSERT(r_end <= r_count); - D_ASSERT(intersection == l_end + r_end); - } else { - l_end = l_count; - r_end = r_count; - } - // Create slices of the data that this thread must merge - left->SetIndices(0, 0); - right->SetIndices(0, 0); - left_input = left_block.CreateSlice(state.l_start, l_end, left->entry_idx); - right_input = right_block.CreateSlice(state.r_start, r_end, right->entry_idx); - left->sb = left_input.get(); - right->sb = right_input.get(); - state.l_start = l_end; - state.r_start = r_end; - D_ASSERT(left->Remaining() + right->Remaining() == state.block_capacity || (l_end == l_count && r_end == r_count)); - // Update global state - if (state.l_start == l_count && state.r_start == r_count) { - // Delete references to previous pair - state.sorted_blocks[state.pair_idx * 2] = nullptr; - state.sorted_blocks[state.pair_idx * 2 + 1] = nullptr; - // Advance pair - state.pair_idx++; - state.l_start = 0; - state.r_start = 0; - } -} - -int MergeSorter::CompareUsingGlobalIndex(SBScanState &l, SBScanState &r, const idx_t l_idx, const idx_t r_idx) { - D_ASSERT(l_idx < l.sb->Count()); - D_ASSERT(r_idx < r.sb->Count()); - - // Easy comparison using the previous result (intersections must increase monotonically) - if (l_idx < state.l_start) { - return -1; - } - if (r_idx < state.r_start) { - return 1; - } - - l.sb->GlobalToLocalIndex(l_idx, l.block_idx, l.entry_idx); - r.sb->GlobalToLocalIndex(r_idx, r.block_idx, r.entry_idx); - - l.PinRadix(l.block_idx); - r.PinRadix(r.block_idx); - data_ptr_t l_ptr = l.radix_handle.Ptr() + l.entry_idx * sort_layout.entry_size; - data_ptr_t r_ptr = r.radix_handle.Ptr() + r.entry_idx * sort_layout.entry_size; - - int comp_res; - if (sort_layout.all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, sort_layout.comparison_size); - } else { - l.PinData(*l.sb->blob_sorting_data); - r.PinData(*r.sb->blob_sorting_data); - comp_res = Comparators::CompareTuple(l, r, l_ptr, r_ptr, sort_layout, state.external); - } - return comp_res; -} - -void MergeSorter::GetIntersection(const idx_t diagonal, idx_t &l_idx, idx_t &r_idx) { - const idx_t l_count = left->sb->Count(); - const idx_t r_count = right->sb->Count(); - // Cover some edge cases - // Code coverage off because these edge cases cannot happen unless other code changes - // Edge cases have been tested extensively while developing Merge Path in a script - // LCOV_EXCL_START - if (diagonal >= l_count + r_count) { - l_idx = l_count; - r_idx = r_count; - return; - } else if (diagonal == 0) { - l_idx = 0; - r_idx = 0; - return; - } else if (l_count == 0) { - l_idx = 0; - r_idx = diagonal; - return; - } else if (r_count == 0) { - r_idx = 0; - l_idx = diagonal; - return; - } - // LCOV_EXCL_STOP - // Determine offsets for the binary search - const idx_t l_offset = MinValue(l_count, diagonal); - const idx_t r_offset = diagonal > l_count ? diagonal - l_count : 0; - D_ASSERT(l_offset + r_offset == diagonal); - const idx_t search_space = diagonal > MaxValue(l_count, r_count) ? 
l_count + r_count - diagonal - : MinValue(diagonal, MinValue(l_count, r_count)); - // Double binary search - idx_t li = 0; - idx_t ri = search_space - 1; - idx_t middle; - int comp_res; - while (li <= ri) { - middle = (li + ri) / 2; - l_idx = l_offset - middle; - r_idx = r_offset + middle; - if (l_idx == l_count || r_idx == 0) { - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (comp_res > 0) { - l_idx--; - r_idx++; - } else { - return; - } - if (l_idx == 0 || r_idx == r_count) { - // This case is incredibly difficult to cover as it is dependent on parallelism randomness - // But it has been tested extensively during development in a script - // LCOV_EXCL_START - return; - // LCOV_EXCL_STOP - } else { - break; - } - } - comp_res = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx); - if (comp_res > 0) { - li = middle + 1; - } else { - ri = middle - 1; - } - } - int l_r_min1 = CompareUsingGlobalIndex(*left, *right, l_idx, r_idx - 1); - int l_min1_r = CompareUsingGlobalIndex(*left, *right, l_idx - 1, r_idx); - if (l_r_min1 > 0 && l_min1_r < 0) { - return; - } else if (l_r_min1 > 0) { - l_idx--; - r_idx++; - } else if (l_min1_r < 0) { - l_idx++; - r_idx--; - } -} - -void MergeSorter::ComputeMerge(const idx_t &count, bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - auto &l_sorted_block = *l.sb; - auto &r_sorted_block = *r.sb; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - // Data pointers for both sides - data_ptr_t l_radix_ptr; - data_ptr_t r_radix_ptr; - // Compute the merge of the next 'count' tuples - idx_t compared = 0; - while (compared < count) { - // Move to the next block (if needed) - if (l.block_idx < l_sorted_block.radix_sorting_data.size() && - l.entry_idx == l_sorted_block.radix_sorting_data[l.block_idx]->count) { - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_sorted_block.radix_sorting_data.size() && - r.entry_idx == r_sorted_block.radix_sorting_data[r.block_idx]->count) { - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_sorted_block.radix_sorting_data.size(); - const bool r_done = r.block_idx == r_sorted_block.radix_sorting_data.size(); - if (l_done || r_done) { - // One of the sides is exhausted, no need to compare - break; - } - // Pin the radix sorting data - left->PinRadix(l.block_idx); - l_radix_ptr = left->RadixPtr(); - right->PinRadix(r.block_idx); - r_radix_ptr = right->RadixPtr(); - - const idx_t l_count = l_sorted_block.radix_sorting_data[l.block_idx]->count; - const idx_t r_count = r_sorted_block.radix_sorting_data[r.block_idx]->count; - // Compute the merge - if (sort_layout.all_constant) { - // All sorting columns are constant size - for (; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = FastMemcmp(l_radix_ptr, r_radix_ptr, sort_layout.comparison_size) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } else { - // Pin the blob data - left->PinData(*l_sorted_block.blob_sorting_data); - right->PinData(*r_sorted_block.blob_sorting_data); - // Merge with variable size sorting columns - for 
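[Editor's note] GetIntersection above is the Merge Path step: a binary search along one cross diagonal of the implicit merge matrix, so that each thread gets an equal-sized, independent slice of the merge. A textbook sketch on plain sorted vectors; the real code searches over blocked global indices via CompareUsingGlobalIndex, which this simplified version ignores:

#include <algorithm>
#include <cstddef>
#include <vector>

// Sketch only: returns (l_idx, r_idx) with l_idx + r_idx == k such that
// merging l[0..l_idx) with r[0..r_idx) yields the first k output rows.
static void MergePathSplit(const std::vector<int> &l, const std::vector<int> &r, std::size_t k, std::size_t &l_idx,
                           std::size_t &r_idx) {
	std::size_t lo = k > r.size() ? k - r.size() : 0;
	std::size_t hi = std::min(k, l.size());
	while (lo < hi) {
		const std::size_t mid = (lo + hi) / 2; // candidate for l_idx
		if (l[mid] < r[k - mid - 1]) {
			lo = mid + 1; // took too few elements from the left run
		} else {
			hi = mid; // took enough (or too many) from the left run
		}
	}
	l_idx = lo;
	r_idx = k - lo;
}

With this split, thread t can call MergePathSplit at diagonal k = t * block_capacity and merge its slice independently, which is how GetNextPartition carves the input into block_capacity-sized merge tasks.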
(; compared < count && l.entry_idx < l_count && r.entry_idx < r_count; compared++) { - left_smaller[compared] = - Comparators::CompareTuple(*left, *right, l_radix_ptr, r_radix_ptr, sort_layout, state.external) < 0; - const bool &l_smaller = left_smaller[compared]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to increment entries and pointers - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - l_radix_ptr += l_smaller * sort_layout.entry_size; - r_radix_ptr += r_smaller * sort_layout.entry_size; - } - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeRadix(const idx_t &count, const bool left_smaller[]) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - auto &l_blocks = l.sb->radix_sorting_data; - auto &r_blocks = r.sb->radix_sorting_data; - RowDataBlock *l_block = nullptr; - RowDataBlock *r_block = nullptr; - - data_ptr_t l_ptr; - data_ptr_t r_ptr; - - RowDataBlock *result_block = result->radix_sorting_data.back().get(); - auto result_handle = buffer_manager.Pin(result_block->block); - data_ptr_t result_ptr = result_handle.Ptr() + result_block->count * sort_layout.entry_size; - - idx_t copied = 0; - while (copied < count) { - // Move to the next block (if needed) - if (l.block_idx < l_blocks.size() && l.entry_idx == l_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_blocks[l.block_idx]->block = nullptr; - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_blocks.size() && r.entry_idx == r_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_blocks[r.block_idx]->block = nullptr; - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_blocks.size(); - const bool r_done = r.block_idx == r_blocks.size(); - // Pin the radix sortable blocks - idx_t l_count; - if (!l_done) { - l_block = l_blocks[l.block_idx].get(); - left->PinRadix(l.block_idx); - l_ptr = l.RadixPtr(); - l_count = l_block->count; - } else { - l_count = 0; - } - idx_t r_count; - if (!r_done) { - r_block = r_blocks[r.block_idx].get(); - r.PinRadix(r.block_idx); - r_ptr = r.RadixPtr(); - r_count = r_block->count; - } else { - r_count = 0; - } - // Copy using computed merge - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_block, result_ptr, - sort_layout.entry_size, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_block, result_ptr, sort_layout.entry_size, copied, count); - } - } - // Reset block indices - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); -} - -void MergeSorter::MergeData(SortedData &result_data, SortedData &l_data, SortedData &r_data, const idx_t &count, - const bool left_smaller[], idx_t next_entry_sizes[], bool reset_indices) { - auto &l = *left; - auto &r = *right; - // Save indices to restore afterwards - idx_t l_block_idx_before = l.block_idx; - idx_t 
l_entry_idx_before = l.entry_idx; - idx_t r_block_idx_before = r.block_idx; - idx_t r_entry_idx_before = r.entry_idx; - - const auto &layout = result_data.layout; - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - - // Left and right row data to merge - data_ptr_t l_ptr; - data_ptr_t r_ptr; - // Accompanying left and right heap data (if needed) - data_ptr_t l_heap_ptr; - data_ptr_t r_heap_ptr; - - // Result rows to write to - RowDataBlock *result_data_block = result_data.data_blocks.back().get(); - auto result_data_handle = buffer_manager.Pin(result_data_block->block); - data_ptr_t result_data_ptr = result_data_handle.Ptr() + result_data_block->count * row_width; - // Result heap to write to (if needed) - RowDataBlock *result_heap_block = nullptr; - BufferHandle result_heap_handle; - data_ptr_t result_heap_ptr; - if (!layout.AllConstant() && state.external) { - result_heap_block = result_data.heap_blocks.back().get(); - result_heap_handle = buffer_manager.Pin(result_heap_block->block); - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - - idx_t copied = 0; - while (copied < count) { - // Move to new data blocks (if needed) - if (l.block_idx < l_data.data_blocks.size() && l.entry_idx == l_data.data_blocks[l.block_idx]->count) { - // Delete reference to previous block - l_data.data_blocks[l.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - l_data.heap_blocks[l.block_idx]->block = nullptr; - } - // Advance block - l.block_idx++; - l.entry_idx = 0; - } - if (r.block_idx < r_data.data_blocks.size() && r.entry_idx == r_data.data_blocks[r.block_idx]->count) { - // Delete reference to previous block - r_data.data_blocks[r.block_idx]->block = nullptr; - if (!layout.AllConstant() && state.external) { - r_data.heap_blocks[r.block_idx]->block = nullptr; - } - // Advance block - r.block_idx++; - r.entry_idx = 0; - } - const bool l_done = l.block_idx == l_data.data_blocks.size(); - const bool r_done = r.block_idx == r_data.data_blocks.size(); - // Pin the row data blocks - if (!l_done) { - l.PinData(l_data); - l_ptr = l.DataPtr(l_data); - } - if (!r_done) { - r.PinData(r_data); - r_ptr = r.DataPtr(r_data); - } - const idx_t &l_count = !l_done ? l_data.data_blocks[l.block_idx]->count : 0; - const idx_t &r_count = !r_done ? r_data.data_blocks[r.block_idx]->count : 0; - // Perform the merge - if (layout.AllConstant() || !state.external) { - // If all constant size, or if we are doing an in-memory sort, we do not need to touch the heap - if (!l_done && !r_done) { - // Both sides have data - merge - MergeRows(l_ptr, l.entry_idx, l_count, r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, - row_width, left_smaller, copied, count); - } else if (r_done) { - // Right side is exhausted - FlushRows(l_ptr, l.entry_idx, l_count, *result_data_block, result_data_ptr, row_width, copied, count); - } else { - // Left side is exhausted - FlushRows(r_ptr, r.entry_idx, r_count, *result_data_block, result_data_ptr, row_width, copied, count); - } - } else { - // External sorting with variable size data. 
Pin the heap blocks too - if (!l_done) { - l_heap_ptr = l.BaseHeapPtr(l_data) + Load(l_ptr + heap_pointer_offset); - D_ASSERT(l_heap_ptr - l.BaseHeapPtr(l_data) >= 0); - D_ASSERT((idx_t)(l_heap_ptr - l.BaseHeapPtr(l_data)) < l_data.heap_blocks[l.block_idx]->byte_offset); - } - if (!r_done) { - r_heap_ptr = r.BaseHeapPtr(r_data) + Load(r_ptr + heap_pointer_offset); - D_ASSERT(r_heap_ptr - r.BaseHeapPtr(r_data) >= 0); - D_ASSERT((idx_t)(r_heap_ptr - r.BaseHeapPtr(r_data)) < r_data.heap_blocks[r.block_idx]->byte_offset); - } - // Both the row and heap data need to be dealt with - if (!l_done && !r_done) { - // Both sides have data - merge - idx_t l_idx_copy = l.entry_idx; - idx_t r_idx_copy = r.entry_idx; - data_ptr_t result_data_ptr_copy = result_data_ptr; - idx_t copied_copy = copied; - // Merge row data - MergeRows(l_ptr, l_idx_copy, l_count, r_ptr, r_idx_copy, r_count, *result_data_block, - result_data_ptr_copy, row_width, left_smaller, copied_copy, count); - const idx_t merged = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t l_heap_ptr_copy = l_heap_ptr; - data_ptr_t r_heap_ptr_copy = r_heap_ptr; - for (idx_t i = 0; i < merged; i++) { - // Store base heap offset in the row data - Store(result_heap_block->byte_offset + copy_bytes, result_data_ptr + heap_pointer_offset); - result_data_ptr += row_width; - // Compute entry size and add to total - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - auto &entry_size = next_entry_sizes[copied + i]; - entry_size = - l_smaller * Load(l_heap_ptr_copy) + r_smaller * Load(r_heap_ptr_copy); - D_ASSERT(entry_size >= sizeof(uint32_t)); - D_ASSERT(NumericCast(l_heap_ptr_copy - l.BaseHeapPtr(l_data)) + l_smaller * entry_size <= - l_data.heap_blocks[l.block_idx]->byte_offset); - D_ASSERT(NumericCast(r_heap_ptr_copy - r.BaseHeapPtr(r_data)) + r_smaller * entry_size <= - r_data.heap_blocks[r.block_idx]->byte_offset); - l_heap_ptr_copy += l_smaller * entry_size; - r_heap_ptr_copy += r_smaller * entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (result_heap_block->byte_offset + copy_bytes > result_heap_block->capacity) { - idx_t new_capacity = result_heap_block->byte_offset + copy_bytes; - buffer_manager.ReAllocate(result_heap_block->block, new_capacity); - result_heap_block->capacity = new_capacity; - result_heap_ptr = result_heap_handle.Ptr() + result_heap_block->byte_offset; - } - D_ASSERT(result_heap_block->byte_offset + copy_bytes <= result_heap_block->capacity); - // Now copy the heap data - for (idx_t i = 0; i < merged; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - const auto &entry_size = next_entry_sizes[copied + i]; - memcpy(result_heap_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_heap_ptr) + - r_smaller * CastPointerToValue(r_heap_ptr)), - entry_size); - D_ASSERT(Load(result_heap_ptr) == entry_size); - result_heap_ptr += entry_size; - l_heap_ptr += l_smaller * entry_size; - r_heap_ptr += r_smaller * entry_size; - l.entry_idx += l_smaller; - r.entry_idx += r_smaller; - } - // Update result indices and pointers - result_heap_block->count += merged; - result_heap_block->byte_offset += copy_bytes; - copied += merged; - } else if (r_done) { - // Right side is exhausted - flush left - FlushBlobs(layout, l_count, l_ptr, l.entry_idx, l_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, 
result_heap_ptr, copied, count); - } else { - // Left side is exhausted - flush right - FlushBlobs(layout, r_count, r_ptr, r.entry_idx, r_heap_ptr, *result_data_block, result_data_ptr, - *result_heap_block, result_heap_handle, result_heap_ptr, copied, count); - } - D_ASSERT(result_data_block->count == result_heap_block->count); - } - } - if (reset_indices) { - left->SetIndices(l_block_idx_before, l_entry_idx_before); - right->SetIndices(r_block_idx_before, r_entry_idx_before); - } -} - -void MergeSorter::MergeRows(data_ptr_t &l_ptr, idx_t &l_entry_idx, const idx_t &l_count, data_ptr_t &r_ptr, - idx_t &r_entry_idx, const idx_t &r_count, RowDataBlock &target_block, - data_ptr_t &target_ptr, const idx_t &entry_size, const bool left_smaller[], idx_t &copied, - const idx_t &count) { - const idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - idx_t i; - for (i = 0; i < next && l_entry_idx < l_count && r_entry_idx < r_count; i++) { - const bool &l_smaller = left_smaller[copied + i]; - const bool r_smaller = !l_smaller; - // Use comparison bool (0 or 1) to copy an entry from either side - FastMemcpy( - target_ptr, - reinterpret_cast(l_smaller * CastPointerToValue(l_ptr) + r_smaller * CastPointerToValue(r_ptr)), - entry_size); - target_ptr += entry_size; - // Use the comparison bool to increment entries and pointers - l_entry_idx += l_smaller; - r_entry_idx += r_smaller; - l_ptr += l_smaller * entry_size; - r_ptr += r_smaller * entry_size; - } - // Update counts - target_block.count += i; - copied += i; -} - -void MergeSorter::FlushRows(data_ptr_t &source_ptr, idx_t &source_entry_idx, const idx_t &source_count, - RowDataBlock &target_block, data_ptr_t &target_ptr, const idx_t &entry_size, idx_t &copied, - const idx_t &count) { - // Compute how many entries we can fit - idx_t next = MinValue(count - copied, target_block.capacity - target_block.count); - next = MinValue(next, source_count - source_entry_idx); - // Copy them all in a single memcpy - const idx_t copy_bytes = next * entry_size; - memcpy(target_ptr, source_ptr, copy_bytes); - target_ptr += copy_bytes; - source_ptr += copy_bytes; - // Update counts - source_entry_idx += next; - target_block.count += next; - copied += next; -} - -void MergeSorter::FlushBlobs(const RowLayout &layout, const idx_t &source_count, data_ptr_t &source_data_ptr, - idx_t &source_entry_idx, data_ptr_t &source_heap_ptr, RowDataBlock &target_data_block, - data_ptr_t &target_data_ptr, RowDataBlock &target_heap_block, - BufferHandle &target_heap_handle, data_ptr_t &target_heap_ptr, idx_t &copied, - const idx_t &count) { - const idx_t row_width = layout.GetRowWidth(); - const idx_t heap_pointer_offset = layout.GetHeapOffset(); - idx_t source_entry_idx_copy = source_entry_idx; - data_ptr_t target_data_ptr_copy = target_data_ptr; - idx_t copied_copy = copied; - // Flush row data - FlushRows(source_data_ptr, source_entry_idx_copy, source_count, target_data_block, target_data_ptr_copy, row_width, - copied_copy, count); - const idx_t flushed = copied_copy - copied; - // Compute the entry sizes and number of heap bytes that will be copied - idx_t copy_bytes = 0; - data_ptr_t source_heap_ptr_copy = source_heap_ptr; - for (idx_t i = 0; i < flushed; i++) { - // Store base heap offset in the row data - Store(target_heap_block.byte_offset + copy_bytes, target_data_ptr + heap_pointer_offset); - target_data_ptr += row_width; - // Compute entry size and add to total - auto entry_size = Load(source_heap_ptr_copy); - D_ASSERT(entry_size >= 
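[Editor's note] FlushRows can copy a whole run of fixed-size rows with a single memcpy; FlushBlobs can do the same for the heap because every heap entry begins with its total size as a uint32_t, so the byte count of a run can be summed in a first pass. A small sketch of walking such a size-prefixed heap (the struct and names are illustrative, not the exact RowLayout):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Sketch only: size-prefixed heap entries, as FlushBlobs above assumes.
struct HeapEntryView {
	const uint8_t *ptr;

	uint32_t Size() const { // every entry starts with its own total size
		uint32_t size;
		std::memcpy(&size, ptr, sizeof(size));
		return size;
	}
	HeapEntryView Next() const { // entries are laid out back to back
		return HeapEntryView{ptr + Size()};
	}
};

// Total bytes of 'count' consecutive entries: lets the caller size one memcpy
static std::size_t HeapRunBytes(HeapEntryView entry, std::size_t count) {
	std::size_t bytes = 0;
	for (std::size_t i = 0; i < count; i++) {
		bytes += entry.Size();
		entry = entry.Next();
	}
	return bytes;
}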
sizeof(uint32_t)); - source_heap_ptr_copy += entry_size; - copy_bytes += entry_size; - } - // Reallocate result heap block size (if needed) - if (target_heap_block.byte_offset + copy_bytes > target_heap_block.capacity) { - idx_t new_capacity = target_heap_block.byte_offset + copy_bytes; - buffer_manager.ReAllocate(target_heap_block.block, new_capacity); - target_heap_block.capacity = new_capacity; - target_heap_ptr = target_heap_handle.Ptr() + target_heap_block.byte_offset; - } - D_ASSERT(target_heap_block.byte_offset + copy_bytes <= target_heap_block.capacity); - // Copy the heap data in one go - memcpy(target_heap_ptr, source_heap_ptr, copy_bytes); - target_heap_ptr += copy_bytes; - source_heap_ptr += copy_bytes; - source_entry_idx += flushed; - copied += flushed; - // Update result indices and pointers - target_heap_block.count += flushed; - target_heap_block.byte_offset += copy_bytes; - D_ASSERT(target_heap_block.byte_offset <= target_heap_block.capacity); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/partition_state.cpp b/src/duckdb/src/common/sort/partition_state.cpp deleted file mode 100644 index 9c4379695..000000000 --- a/src/duckdb/src/common/sort/partition_state.cpp +++ /dev/null @@ -1,688 +0,0 @@ -#include "duckdb/common/sort/partition_state.hpp" - -#include "duckdb/common/types/column/column_data_consumer.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/parallel/executor_task.hpp" - -#include - -namespace duckdb { - -PartitionGlobalHashGroup::PartitionGlobalHashGroup(BufferManager &buffer_manager, const Orders &partitions, - const Orders &orders, const Types &payload_types, bool external) - : count(0) { - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = external; - - // Set up a comparator for the partition subset - partition_layout = global_sort->sort_layout.GetPrefixComparisonLayout(partitions.size()); -} - -void PartitionGlobalHashGroup::ComputeMasks(ValidityMask &partition_mask, OrderMasks &order_masks) { - D_ASSERT(count > 0); - - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - - partition_mask.SetValidUnsafe(0); - unordered_map prefixes; - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(0); - D_ASSERT(order_mask.first >= partition_layout.column_count); - prefixes[order_mask.first] = global_sort->sort_layout.GetPrefixComparisonLayout(order_mask.first); - } - - for (++curr; curr.GetIndex() < count; ++curr) { - // Compare the partition subset first because if that differs, then so does the full ordering - const auto part_cmp = ComparePartitions(prev, curr); - - if (part_cmp) { - partition_mask.SetValidUnsafe(curr.GetIndex()); - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } else { - for (auto &order_mask : order_masks) { - if (prev.Compare(curr, prefixes[order_mask.first])) { - order_mask.second.SetValidUnsafe(curr.GetIndex()); - } - } - } - ++prev; - } -} - -void PartitionGlobalSinkState::GenerateOrderings(Orders &partitions, Orders &orders, - const vector> &partition_bys, - const Orders &order_bys, - const vector> &partition_stats) { - - // we sort by both 1) partition by expression list and 2) order by expressions - const auto partition_cols = partition_bys.size(); - for (idx_t prt_idx = 0; prt_idx < 
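[Editor's note] ComputeMasks above turns the sorted order into partition and order boundary masks by comparing each row to its predecessor, partition prefix first. A simplified sketch of the same one-pass idea on materialized keys (the real code compares serialized rows through SBIterator):

#include <cstddef>
#include <vector>

// Sketch only: mark rows that start a new partition, one pass over neighbours.
static std::vector<bool> PartitionBoundaries(const std::vector<std::vector<int>> &sorted_keys,
                                             std::size_t prefix_cols) {
	std::vector<bool> boundary(sorted_keys.size(), false);
	if (!sorted_keys.empty()) {
		boundary[0] = true; // the first row always starts a partition
	}
	for (std::size_t i = 1; i < sorted_keys.size(); i++) {
		for (std::size_t c = 0; c < prefix_cols; c++) {
			if (sorted_keys[i][c] != sorted_keys[i - 1][c]) {
				boundary[i] = true; // prefix differs from the predecessor
				break;
			}
		}
	}
	return boundary;
}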
partition_cols; prt_idx++) { - auto &pexpr = partition_bys[prt_idx]; - - if (partition_stats.empty() || !partition_stats[prt_idx]) { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), nullptr); - } else { - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, pexpr->Copy(), - partition_stats[prt_idx]->ToUnique()); - } - partitions.emplace_back(orders.back().Copy()); - } - - for (const auto &order : order_bys) { - orders.emplace_back(order.Copy()); - } -} - -PartitionGlobalSinkState::PartitionGlobalSinkState(ClientContext &context, - const vector> &partition_bys, - const vector &order_bys, - const Types &payload_types, - const vector> &partition_stats, - idx_t estimated_cardinality) - : context(context), buffer_manager(BufferManager::GetBufferManager(context)), allocator(Allocator::Get(context)), - fixed_bits(0), payload_types(payload_types), memory_per_thread(0), max_bits(1), count(0) { - - GenerateOrderings(partitions, orders, partition_bys, order_bys, partition_stats); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - external = ClientConfig::GetConfig(context).GetSetting(context); - - const auto thread_pages = PreviousPowerOfTwo(memory_per_thread / (4 * buffer_manager.GetBlockAllocSize())); - while (max_bits < 10 && (thread_pages >> max_bits) > 1) { - ++max_bits; - } - - if (!orders.empty()) { - if (partitions.empty()) { - // Sort early into a dedicated hash group if we only sort. - grouping_types.Initialize(payload_types); - auto new_group = - make_uniq(buffer_manager, partitions, orders, payload_types, external); - hash_groups.emplace_back(std::move(new_group)); - } else { - auto types = payload_types; - types.push_back(LogicalType::HASH); - grouping_types.Initialize(types); - ResizeGroupingData(estimated_cardinality); - } - } -} - -bool PartitionGlobalSinkState::HasMergeTasks() const { - if (grouping_data) { - auto &groups = grouping_data->GetPartitions(); - return !groups.empty(); - } else if (!hash_groups.empty()) { - D_ASSERT(hash_groups.size() == 1); - return hash_groups[0]->count > 0; - } else { - return false; - } -} - -void PartitionGlobalSinkState::SyncPartitioning(const PartitionGlobalSinkState &other) { - fixed_bits = other.grouping_data ? other.grouping_data->GetRadixBits() : 0; - - const auto old_bits = grouping_data ? grouping_data->GetRadixBits() : 0; - if (fixed_bits != old_bits) { - const auto hash_col_idx = payload_types.size(); - grouping_data = make_uniq(buffer_manager, grouping_types, fixed_bits, hash_col_idx); - } -} - -unique_ptr PartitionGlobalSinkState::CreatePartition(idx_t new_bits) const { - const auto hash_col_idx = payload_types.size(); - return make_uniq(buffer_manager, grouping_types, new_bits, hash_col_idx); -} - -void PartitionGlobalSinkState::ResizeGroupingData(idx_t cardinality) { - // Have we started to combine? Then just live with it. - if (fixed_bits || (grouping_data && !grouping_data->GetPartitions().empty())) { - return; - } - // Is the average partition size too large? - const idx_t partition_size = DEFAULT_ROW_GROUP_SIZE; - const auto bits = grouping_data ? grouping_data->GetRadixBits() : 0; - auto new_bits = bits ? 
bits : 4; - while (new_bits < max_bits && (cardinality / RadixPartitioning::NumberOfPartitions(new_bits)) > partition_size) { - ++new_bits; - } - - // Repartition the grouping data - if (new_bits != bits) { - grouping_data = CreatePartition(new_bits); - } -} - -void PartitionGlobalSinkState::SyncLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // We are done if the local_partition is right sized. - auto &local_radix = local_partition->Cast(); - const auto new_bits = grouping_data->GetRadixBits(); - if (local_radix.GetRadixBits() == new_bits) { - return; - } - - // If the local partition is now too small, flush it and reallocate - auto new_partition = CreatePartition(new_bits); - local_partition->FlushAppendState(*local_append); - local_partition->Repartition(*new_partition); - - local_partition = std::move(new_partition); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); -} - -void PartitionGlobalSinkState::UpdateLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - // Make sure grouping_data doesn't change under us. - lock_guard guard(lock); - - if (!local_partition) { - local_partition = CreatePartition(grouping_data->GetRadixBits()); - local_append = make_uniq(); - local_partition->InitializeAppendState(*local_append); - return; - } - - // Grow the groups if they are too big - ResizeGroupingData(count); - - // Sync local partition to have the same bit count - SyncLocalPartition(local_partition, local_append); -} - -void PartitionGlobalSinkState::CombineLocalPartition(GroupingPartition &local_partition, GroupingAppend &local_append) { - if (!local_partition) { - return; - } - local_partition->FlushAppendState(*local_append); - - // Make sure grouping_data doesn't change under us. - // Combine has an internal mutex, so this is single-threaded anyway. - lock_guard guard(lock); - SyncLocalPartition(local_partition, local_append); - grouping_data->Combine(*local_partition); -} - -PartitionLocalMergeState::PartitionLocalMergeState(PartitionGlobalSinkState &gstate) - : merge_state(nullptr), stage(PartitionSortStage::INIT), finished(true), executor(gstate.context) { - - // Set up the sort expression computation. - vector sort_types; - for (auto &order : gstate.orders) { - auto &oexpr = order.expression; - sort_types.emplace_back(oexpr->return_type); - executor.AddExpression(*oexpr); - } - sort_chunk.Initialize(gstate.allocator, sort_types); - payload_chunk.Initialize(gstate.allocator, gstate.payload_types); -} - -void PartitionLocalMergeState::Scan() { - if (!merge_state->group_data) { - // OVER(ORDER BY...) - // Already sorted - return; - } - - auto &group_data = *merge_state->group_data; - auto &hash_group = *merge_state->hash_group; - auto &chunk_state = merge_state->chunk_state; - // Copy the data from the group into the sort code. 
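[Editor's note] ResizeGroupingData above grows the radix partitioning until the average partition fits under one row group. A minimal sketch of that bit-count calculation, where 122880 stands in for DEFAULT_ROW_GROUP_SIZE and (1 << bits) for RadixPartitioning::NumberOfPartitions(bits):

#include <cstddef>
#include <cstdint>

// Sketch only: double the partition count until partitions are small enough.
static uint8_t PickRadixBits(std::size_t cardinality, uint8_t current_bits, uint8_t max_bits) {
	const std::size_t partition_size = 122880;
	uint8_t bits = current_bits ? current_bits : 4; // start at 16 partitions
	while (bits < max_bits && cardinality / (std::size_t(1) << bits) > partition_size) {
		bits++;
	}
	return bits;
}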
- auto &global_sort = *hash_group.global_sort; - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - TupleDataScanState local_scan; - group_data.InitializeScan(local_scan, merge_state->column_ids); - while (group_data.Scan(chunk_state, local_scan, payload_chunk)) { - sort_chunk.Reset(); - executor.Execute(payload_chunk, sort_chunk); - - local_sort.SinkChunk(sort_chunk, payload_chunk); - if (local_sort.SizeInBytes() > merge_state->memory_per_thread) { - local_sort.Sort(global_sort, true); - } - hash_group.count += payload_chunk.size(); - } - - global_sort.AddLocalState(local_sort); -} - -// Per-thread sink state -PartitionLocalSinkState::PartitionLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : gstate(gstate_p), allocator(Allocator::Get(context)), executor(context) { - - vector group_types; - for (idx_t prt_idx = 0; prt_idx < gstate.partitions.size(); prt_idx++) { - auto &pexpr = *gstate.partitions[prt_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - sort_cols = gstate.orders.size() + group_types.size(); - - if (sort_cols) { - auto payload_types = gstate.payload_types; - if (!group_types.empty()) { - // OVER(PARTITION BY...) - group_chunk.Initialize(allocator, group_types); - payload_types.emplace_back(LogicalType::HASH); - } else { - // OVER(ORDER BY...) - for (idx_t ord_idx = 0; ord_idx < gstate.orders.size(); ord_idx++) { - auto &pexpr = *gstate.orders[ord_idx].expression.get(); - group_types.push_back(pexpr.return_type); - executor.AddExpression(pexpr); - } - group_chunk.Initialize(allocator, group_types); - - // Single partition - auto &global_sort = *gstate.hash_groups[0]->global_sort; - local_sort = make_uniq(); - local_sort->Initialize(global_sort, global_sort.buffer_manager); - } - // OVER(...) - payload_chunk.Initialize(allocator, payload_types); - } else { - // OVER() - payload_layout.Initialize(gstate.payload_types); - } -} - -void PartitionLocalSinkState::Hash(DataChunk &input_chunk, Vector &hash_vector) { - const auto count = input_chunk.size(); - D_ASSERT(group_chunk.ColumnCount() > 0); - - // OVER(PARTITION BY...) 
(hash grouping) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - VectorOperations::Hash(group_chunk.data[0], hash_vector, count); - for (idx_t prt_idx = 1; prt_idx < group_chunk.ColumnCount(); ++prt_idx) { - VectorOperations::CombineHash(hash_vector, group_chunk.data[prt_idx], count); - } -} - -void PartitionLocalSinkState::Sink(DataChunk &input_chunk) { - gstate.count += input_chunk.size(); - - // OVER() - if (sort_cols == 0) { - // No sorts, so build paged row chunks - if (!rows) { - const auto entry_size = payload_layout.GetRowWidth(); - const auto block_size = gstate.buffer_manager.GetBlockSize(); - const auto capacity = MaxValue(STANDARD_VECTOR_SIZE, block_size / entry_size + 1); - rows = make_uniq(gstate.buffer_manager, capacity, entry_size); - strings = make_uniq(gstate.buffer_manager, block_size, 1U, true); - } - const auto row_count = input_chunk.size(); - const auto row_sel = FlatVector::IncrementalSelectionVector(); - Vector addresses(LogicalType::POINTER); - auto key_locations = FlatVector::GetData(addresses); - const auto prev_rows_blocks = rows->blocks.size(); - auto handles = rows->Build(row_count, key_locations, nullptr, row_sel); - auto input_data = input_chunk.ToUnifiedFormat(); - RowOperations::Scatter(input_chunk, input_data.get(), payload_layout, addresses, *strings, *row_sel, row_count); - // Mark that row blocks contain pointers (heap blocks are pinned) - if (!payload_layout.AllConstant()) { - D_ASSERT(strings->keep_pinned); - for (size_t i = prev_rows_blocks; i < rows->blocks.size(); ++i) { - rows->blocks[i]->block->SetSwizzling("PartitionLocalSinkState::Sink"); - } - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - group_chunk.Reset(); - executor.Execute(input_chunk, group_chunk); - local_sort->SinkChunk(group_chunk, input_chunk); - - auto &hash_group = *gstate.hash_groups[0]; - hash_group.count += input_chunk.size(); - - if (local_sort->SizeInBytes() > gstate.memory_per_thread) { - auto &global_sort = *hash_group.global_sort; - local_sort->Sort(global_sort, true); - } - return; - } - - // OVER(...) - payload_chunk.Reset(); - auto &hash_vector = payload_chunk.data.back(); - Hash(input_chunk, hash_vector); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); ++col_idx) { - payload_chunk.data[col_idx].Reference(input_chunk.data[col_idx]); - } - payload_chunk.SetCardinality(input_chunk); - - gstate.UpdateLocalPartition(local_partition, local_append); - local_partition->Append(*local_append, payload_chunk); -} - -void PartitionLocalSinkState::Combine() { - // OVER() - if (sort_cols == 0) { - // Only one partition again, so need a global lock. - lock_guard glock(gstate.lock); - if (gstate.rows) { - if (rows) { - gstate.rows->Merge(*rows); - gstate.strings->Merge(*strings); - rows.reset(); - strings.reset(); - } - } else { - gstate.rows = std::move(rows); - gstate.strings = std::move(strings); - } - return; - } - - if (local_sort) { - // OVER(ORDER BY...) - auto &hash_group = *gstate.hash_groups[0]; - auto &global_sort = *hash_group.global_sort; - global_sort.AddLocalState(*local_sort); - local_sort.reset(); - return; - } - - // OVER(...) 
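[Editor's note] Hash above hashes the first PARTITION BY column and then folds each further column in with VectorOperations::CombineHash. A per-row sketch of the same shape; the mixing function here is an illustrative boost-style combiner, not DuckDB's actual one:

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: fold per-column hashes into one partition hash.
static uint64_t CombineHashes(uint64_t h, uint64_t next) {
	return h ^ (next + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2));
}

static uint64_t HashRow(const std::vector<uint64_t> &column_hashes) {
	uint64_t h = column_hashes.empty() ? 0 : column_hashes[0]; // hash of column 0
	for (std::size_t i = 1; i < column_hashes.size(); i++) {
		h = CombineHashes(h, column_hashes[i]); // fold in the remaining columns
	}
	return h;
}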
- gstate.CombineLocalPartition(local_partition, local_append); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink, GroupDataPtr group_data_p, - hash_t hash_bin) - : sink(sink), group_data(std::move(group_data_p)), group_idx(sink.hash_groups.size()), - memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - auto new_group = make_uniq(sink.buffer_manager, sink.partitions, sink.orders, - sink.payload_types, sink.external); - sink.hash_groups.emplace_back(std::move(new_group)); - - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; - - column_ids.reserve(sink.payload_types.size()); - for (column_t i = 0; i < sink.payload_types.size(); ++i) { - column_ids.emplace_back(i); - } - group_data->InitializeScan(chunk_state, column_ids); -} - -PartitionGlobalMergeState::PartitionGlobalMergeState(PartitionGlobalSinkState &sink) - : sink(sink), group_idx(0), memory_per_thread(sink.memory_per_thread), - num_threads(NumericCast(TaskScheduler::GetScheduler(sink.context).NumberOfThreads())), - stage(PartitionSortStage::INIT), total_tasks(0), tasks_assigned(0), tasks_completed(0) { - - const hash_t hash_bin = 0; - hash_group = sink.hash_groups[group_idx].get(); - global_sort = sink.hash_groups[group_idx]->global_sort.get(); - - sink.bin_groups[hash_bin] = group_idx; -} - -void PartitionLocalMergeState::Prepare() { - merge_state->group_data.reset(); - - auto &global_sort = *merge_state->global_sort; - global_sort.PrepareMergePhase(); -} - -void PartitionLocalMergeState::Merge() { - auto &global_sort = *merge_state->global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); -} - -void PartitionLocalMergeState::Sorted() { - merge_state->sink.OnSortedPartition(merge_state->group_idx); -} - -void PartitionLocalMergeState::ExecuteTask() { - switch (stage) { - case PartitionSortStage::SCAN: - Scan(); - break; - case PartitionSortStage::PREPARE: - Prepare(); - break; - case PartitionSortStage::MERGE: - Merge(); - break; - case PartitionSortStage::SORTED: - Sorted(); - break; - default: - throw InternalException("Unexpected PartitionSortStage in ExecuteTask!"); - } - - merge_state->CompleteTask(); - finished = true; -} - -bool PartitionGlobalMergeState::AssignTask(PartitionLocalMergeState &local_state) { - lock_guard guard(lock); - - if (tasks_assigned >= total_tasks) { - return false; - } - - local_state.merge_state = this; - local_state.stage = stage; - local_state.finished = false; - tasks_assigned++; - - return true; -} - -void PartitionGlobalMergeState::CompleteTask() { - lock_guard guard(lock); - - ++tasks_completed; -} - -bool PartitionGlobalMergeState::TryPrepareNextStage() { - lock_guard guard(lock); - - if (tasks_completed < total_tasks) { - return false; - } - - tasks_assigned = tasks_completed = 0; - - switch (stage) { - case PartitionSortStage::INIT: - // If the partitions are unordered, don't scan in parallel - // because it produces non-deterministic orderings. - // This can theoretically happen with ORDER BY, - // but that is something the query should be explicit about. - total_tasks = sink.orders.size() > sink.partitions.size() ? 
num_threads : 1; - stage = PartitionSortStage::SCAN; - return true; - - case PartitionSortStage::SCAN: - total_tasks = 1; - stage = PartitionSortStage::PREPARE; - return true; - - case PartitionSortStage::PREPARE: - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - stage = PartitionSortStage::MERGE; - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::MERGE: - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - break; - } - global_sort->InitializeMergeRound(); - total_tasks = num_threads; - return true; - - case PartitionSortStage::SORTED: - stage = PartitionSortStage::FINISHED; - total_tasks = 0; - return false; - - case PartitionSortStage::FINISHED: - return false; - } - - stage = PartitionSortStage::SORTED; - total_tasks = 1; - - return true; -} - -PartitionGlobalMergeStates::PartitionGlobalMergeStates(PartitionGlobalSinkState &sink) { - // Schedule all the sorts for maximum thread utilisation - if (sink.grouping_data) { - auto &partitions = sink.grouping_data->GetPartitions(); - sink.bin_groups.resize(partitions.size(), partitions.size()); - for (hash_t hash_bin = 0; hash_bin < partitions.size(); ++hash_bin) { - auto &group_data = partitions[hash_bin]; - // Prepare for merge sort phase - if (group_data->Count()) { - auto state = make_uniq(sink, std::move(group_data), hash_bin); - states.emplace_back(std::move(state)); - } - } - } else { - // OVER(ORDER BY...) - // Already sunk into the single global sort, so set up single merge with no data - sink.bin_groups.resize(1, 1); - auto state = make_uniq(sink); - states.emplace_back(std::move(state)); - } - - sink.OnBeginMerge(); -} - -class PartitionMergeTask : public ExecutorTask { -public: - PartitionMergeTask(shared_ptr event_p, ClientContext &context_p, PartitionGlobalMergeStates &hash_groups_p, - PartitionGlobalSinkState &gstate, const PhysicalOperator &op) - : ExecutorTask(context_p, std::move(event_p), op), local_state(gstate), hash_groups(hash_groups_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - struct ExecutorCallback : public PartitionGlobalMergeStates::Callback { - explicit ExecutorCallback(Executor &executor) : executor(executor) { - } - - bool HasError() const override { - return executor.HasError(); - } - - Executor &executor; - }; - - PartitionLocalMergeState local_state; - PartitionGlobalMergeStates &hash_groups; -}; - -bool PartitionGlobalMergeStates::ExecuteTask(PartitionLocalMergeState &local_state, Callback &callback) { - // Loop until all hash groups are done - size_t sorted = 0; - while (sorted < states.size()) { - // First check if there is an unfinished task for this thread - if (callback.HasError()) { - return false; - } - if (!local_state.TaskFinished()) { - local_state.ExecuteTask(); - continue; - } - - // Thread is done with its assigned task, try to fetch new work - for (auto group = sorted; group < states.size(); ++group) { - auto &global_state = states[group]; - if (global_state->IsFinished()) { - // This hash group is done - // Update the high water mark of densely completed groups - if (sorted == group) { - ++sorted; - } - continue; - } - - // Try to assign work for this hash group to this thread - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! 
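[Editor's note] TryPrepareNextStage above advances each hash group through a fixed ladder of stages, re-entering MERGE until a single sorted block remains. A condensed sketch of that ladder, with simplified names for PartitionSortStage and task counts omitted:

// Sketch only: the stage ladder walked by TryPrepareNextStage.
enum class Stage { INIT, SCAN, PREPARE, MERGE, SORTED, FINISHED };

static Stage NextStage(Stage s, bool more_merge_rounds) {
	switch (s) {
	case Stage::INIT:
		return Stage::SCAN; // parallel scan of the group into the sort
	case Stage::SCAN:
		return Stage::PREPARE; // single task: PrepareMergePhase()
	case Stage::PREPARE:
	case Stage::MERGE:
		// keep scheduling merge rounds until one sorted block remains
		return more_merge_rounds ? Stage::MERGE : Stage::SORTED;
	case Stage::SORTED:
		return Stage::FINISHED;
	case Stage::FINISHED:
		return Stage::FINISHED;
	}
	return Stage::FINISHED;
}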
- // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // Hash group global state couldn't assign a task to this thread - // Try to prepare the next stage - if (!global_state->TryPrepareNextStage()) { - // This current hash group is not yet done - // But we were not able to assign a task for it to this thread - // See if the next hash group is better - continue; - } - - // We were able to prepare the next stage for this hash group! - // Try to assign a task once more - if (global_state->AssignTask(local_state)) { - // We assigned a task to this thread! - // Break out of this loop to re-enter the top-level loop and execute the task - break; - } - - // We were able to prepare the next merge round, - // but we were not able to assign a task for it to this thread - // The tasks were assigned to other threads while this thread waited for the lock - // Go to the next iteration to see if another hash group has a task - } - } - - return true; -} - -TaskExecutionResult PartitionMergeTask::ExecuteTask(TaskExecutionMode mode) { - ExecutorCallback callback(executor); - - if (!hash_groups.ExecuteTask(local_state, callback)) { - return TaskExecutionResult::TASK_ERROR; - } - - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -void PartitionMergeEvent::Schedule() { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.emplace_back(make_uniq(shared_from_this(), context, merge_states, gstate, op)); - } - SetTasks(std::move(merge_tasks)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/radix_sort.cpp b/src/duckdb/src/common/sort/radix_sort.cpp deleted file mode 100644 index b193cee61..000000000 --- a/src/duckdb/src/common/sort/radix_sort.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/duckdb_pdqsort.hpp" -#include "duckdb/common/sort/sort.hpp" - -namespace duckdb { - -//! Calls std::sort on strings that are tied by their prefix after the radix sort -static void SortTiedBlobs(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &start, const idx_t &end, - const idx_t &tie_col, bool *ties, const data_ptr_t blob_ptr, const SortLayout &sort_layout) { - const auto row_width = sort_layout.blob_layout.GetRowWidth(); - // Locate the first blob row in question - data_ptr_t row_ptr = dataptr + start * sort_layout.entry_size; - data_ptr_t blob_row_ptr = blob_ptr + Load(row_ptr + sort_layout.comparison_size) * row_width; - if (!Comparators::TieIsBreakable(tie_col, blob_row_ptr, sort_layout)) { - // Quick check to see if ties can be broken - return; - } - // Fill pointer array for sorting - auto ptr_block = make_unsafe_uniq_array_uninitialized(end - start); - auto entry_ptrs = (data_ptr_t *)ptr_block.get(); - for (idx_t i = start; i < end; i++) { - entry_ptrs[i - start] = row_ptr; - row_ptr += sort_layout.entry_size; - } - // Slow pointer-based sorting - const int order = sort_layout.order_types[tie_col] == OrderType::DESCENDING ? 
-1 : 1; - const idx_t &col_idx = sort_layout.sorting_to_blob_col.at(tie_col); - const auto &tie_col_offset = sort_layout.blob_layout.GetOffsets()[col_idx]; - auto logical_type = sort_layout.blob_layout.GetTypes()[col_idx]; - std::sort(entry_ptrs, entry_ptrs + end - start, - [&blob_ptr, &order, &sort_layout, &tie_col_offset, &row_width, &logical_type](const data_ptr_t l, - const data_ptr_t r) { - idx_t left_idx = Load(l + sort_layout.comparison_size); - idx_t right_idx = Load(r + sort_layout.comparison_size); - data_ptr_t left_ptr = blob_ptr + left_idx * row_width + tie_col_offset; - data_ptr_t right_ptr = blob_ptr + right_idx * row_width + tie_col_offset; - return order * Comparators::CompareVal(left_ptr, right_ptr, logical_type) < 0; - }); - // Re-order - auto temp_block = buffer_manager.GetBufferAllocator().Allocate((end - start) * sort_layout.entry_size); - data_ptr_t temp_ptr = temp_block.get(); - for (idx_t i = 0; i < end - start; i++) { - FastMemcpy(temp_ptr, entry_ptrs[i], sort_layout.entry_size); - temp_ptr += sort_layout.entry_size; - } - memcpy(dataptr + start * sort_layout.entry_size, temp_block.get(), (end - start) * sort_layout.entry_size); - // Determine if there are still ties (if this is not the last column) - if (tie_col < sort_layout.column_count - 1) { - data_ptr_t idx_ptr = dataptr + start * sort_layout.entry_size + sort_layout.comparison_size; - // Load current entry - data_ptr_t current_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - for (idx_t i = 0; i < end - start - 1; i++) { - // Load next entry and compare - idx_ptr += sort_layout.entry_size; - data_ptr_t next_ptr = blob_ptr + Load(idx_ptr) * row_width + tie_col_offset; - ties[start + i] = Comparators::CompareVal(current_ptr, next_ptr, logical_type) == 0; - current_ptr = next_ptr; - } - } -} - -//! Identifies sequences of rows that are tied by the prefix of a blob column, and sorts them -static void SortTiedBlobs(BufferManager &buffer_manager, SortedBlock &sb, bool *ties, data_ptr_t dataptr, - const idx_t &count, const idx_t &tie_col, const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - auto &blob_block = *sb.blob_sorting_data->data_blocks.back(); - auto blob_handle = buffer_manager.Pin(blob_block.block); - const data_ptr_t blob_ptr = blob_handle.Ptr(); - - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - SortTiedBlobs(buffer_manager, dataptr, i, j + 1, tie_col, ties, blob_ptr, sort_layout); - i = j; - } -} - -//! Returns whether there are any 'true' values in the ties[] array -static bool AnyTies(bool ties[], const idx_t &count) { - D_ASSERT(!ties[count - 1]); - bool any_ties = false; - for (idx_t i = 0; i < count - 1; i++) { - any_ties = any_ties || ties[i]; - } - return any_ties; -} - -//! Compares subsequent rows to check for ties -static void ComputeTies(data_ptr_t dataptr, const idx_t &count, const idx_t &col_offset, const idx_t &tie_size, - bool ties[], const SortLayout &sort_layout) { - D_ASSERT(!ties[count - 1]); - D_ASSERT(col_offset + tie_size <= sort_layout.comparison_size); - // Align dataptr - dataptr += col_offset; - for (idx_t i = 0; i < count - 1; i++) { - ties[i] = ties[i] && FastMemcmp(dataptr, dataptr + sort_layout.entry_size, tie_size) == 0; - dataptr += sort_layout.entry_size; - } -} - -//! 
Textbook LSD radix sort -void RadixSortLSD(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &sorting_size) { - auto temp_block = buffer_manager.GetBufferAllocator().Allocate(count * row_width); - bool swap = false; - - idx_t counts[SortConstants::VALUES_PER_RADIX]; - for (idx_t r = 1; r <= sorting_size; r++) { - // Init counts to 0 - memset(counts, 0, sizeof(counts)); - // Const some values for convenience - const data_ptr_t source_ptr = swap ? temp_block.get() : dataptr; - const data_ptr_t target_ptr = swap ? dataptr : temp_block.get(); - const idx_t offset = col_offset + sorting_size - r; - // Collect counts - data_ptr_t offset_ptr = source_ptr + offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute offsets from counts - idx_t max_count = counts[0]; - for (idx_t val = 1; val < SortConstants::VALUES_PER_RADIX; val++) { - max_count = MaxValue(max_count, counts[val]); - counts[val] = counts[val] + counts[val - 1]; - } - if (max_count == count) { - continue; - } - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr + (count - 1) * row_width; - for (idx_t i = 0; i < count; i++) { - idx_t &radix_offset = --counts[*(row_ptr + offset)]; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr -= row_width; - } - swap = !swap; - } - // Move data back to original buffer (if it was swapped) - if (swap) { - memcpy(dataptr, temp_block.get(), count * row_width); - } -} - -//! Insertion sort, used when count of values is low -inline void InsertionSort(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, - const idx_t &col_offset, const idx_t &row_width, const idx_t &total_comp_width, - const idx_t &offset, bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? orig_ptr : temp_ptr; - if (count > 1) { - const idx_t total_offset = col_offset + offset; - auto temp_val = make_unsafe_uniq_array_uninitialized(row_width); - const data_ptr_t val = temp_val.get(); - const auto comp_width = total_comp_width - offset; - for (idx_t i = 1; i < count; i++) { - FastMemcpy(val, source_ptr + i * row_width, row_width); - idx_t j = i; - while (j > 0 && - FastMemcmp(source_ptr + (j - 1) * row_width + total_offset, val + total_offset, comp_width) > 0) { - FastMemcpy(source_ptr + j * row_width, source_ptr + (j - 1) * row_width, row_width); - j--; - } - FastMemcpy(source_ptr + j * row_width, val, row_width); - } - } - if (swap) { - memcpy(target_ptr, source_ptr, count * row_width); - } -} - -//! MSD radix sort that switches to insertion sort with low bucket sizes -void RadixSortMSD(const data_ptr_t orig_ptr, const data_ptr_t temp_ptr, const idx_t &count, const idx_t &col_offset, - const idx_t &row_width, const idx_t &comp_width, const idx_t &offset, idx_t locations[], bool swap) { - const data_ptr_t source_ptr = swap ? temp_ptr : orig_ptr; - const data_ptr_t target_ptr = swap ? 
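[Editor's note] RadixSortLSD above is the textbook least-significant-digit counting sort, ping-ponging between the input and a temporary block via the swap flag; the max_count check additionally skips passes where every row shares the same byte. A self-contained sketch for plain uint32_t keys, with that pass-skipping optimization omitted:

#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Sketch only: one counting pass per byte, least significant byte first.
static void RadixSortU32(std::vector<uint32_t> &keys) {
	std::vector<uint32_t> temp(keys.size());
	uint32_t *src = keys.data();
	uint32_t *dst = temp.data();
	for (int shift = 0; shift < 32; shift += 8) {
		std::size_t counts[256] = {0};
		for (std::size_t i = 0; i < keys.size(); i++) { // count byte values
			counts[(src[i] >> shift) & 0xFF]++;
		}
		for (int v = 1; v < 256; v++) { // prefix sum: counts become end offsets
			counts[v] += counts[v - 1];
		}
		for (std::size_t i = keys.size(); i-- > 0;) { // stable backwards scatter
			dst[--counts[(src[i] >> shift) & 0xFF]] = src[i];
		}
		std::swap(src, dst); // ping-pong, like the swap flag above
	}
	// four passes = an even number of swaps, so the result is back in 'keys'
}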
orig_ptr : temp_ptr; - // Init counts to 0 - memset(locations, 0, SortConstants::MSD_RADIX_LOCATIONS * sizeof(idx_t)); - idx_t *counts = locations + 1; - // Collect counts - const idx_t total_offset = col_offset + offset; - data_ptr_t offset_ptr = source_ptr + total_offset; - for (idx_t i = 0; i < count; i++) { - counts[*offset_ptr]++; - offset_ptr += row_width; - } - // Compute locations from counts - idx_t max_count = 0; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - max_count = MaxValue(max_count, counts[radix]); - counts[radix] += locations[radix]; - } - if (max_count != count) { - // Re-order the data in temporary array - data_ptr_t row_ptr = source_ptr; - for (idx_t i = 0; i < count; i++) { - const idx_t &radix_offset = locations[*(row_ptr + total_offset)]++; - FastMemcpy(target_ptr + radix_offset * row_width, row_ptr, row_width); - row_ptr += row_width; - } - swap = !swap; - } - // Check if done - if (offset == comp_width - 1) { - if (swap) { - memcpy(orig_ptr, temp_ptr, count * row_width); - } - return; - } - if (max_count == count) { - RadixSortMSD(orig_ptr, temp_ptr, count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - return; - } - // Recurse - idx_t radix_count = locations[0]; - for (idx_t radix = 0; radix < SortConstants::VALUES_PER_RADIX; radix++) { - const idx_t loc = (locations[radix] - radix_count) * row_width; - if (radix_count > SortConstants::INSERTION_SORT_THRESHOLD) { - RadixSortMSD(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - locations + SortConstants::MSD_RADIX_LOCATIONS, swap); - } else if (radix_count != 0) { - InsertionSort(orig_ptr + loc, temp_ptr + loc, radix_count, col_offset, row_width, comp_width, offset + 1, - swap); - } - radix_count = locations[radix + 1] - locations[radix]; - } -} - -//! Calls different sort functions, depending on the count and sorting sizes -void RadixSort(BufferManager &buffer_manager, const data_ptr_t &dataptr, const idx_t &count, const idx_t &col_offset, - const idx_t &sorting_size, const SortLayout &sort_layout, bool contains_string) { - - if (contains_string) { - auto begin = duckdb_pdqsort::PDQIterator(dataptr, sort_layout.entry_size); - auto end = begin + count; - duckdb_pdqsort::PDQConstants constants(sort_layout.entry_size, col_offset, sorting_size, *end); - return duckdb_pdqsort::pdqsort_branchless(begin, begin + count, constants); - } - - if (count <= SortConstants::INSERTION_SORT_THRESHOLD) { - return InsertionSort(dataptr, nullptr, count, col_offset, sort_layout.entry_size, sorting_size, 0, false); - } - - if (sorting_size <= SortConstants::MSD_RADIX_SORT_SIZE_THRESHOLD) { - return RadixSortLSD(buffer_manager, dataptr, count, col_offset, sort_layout.entry_size, sorting_size); - } - - const auto block_size = buffer_manager.GetBlockSize(); - auto temp_block = - buffer_manager.Allocate(MemoryTag::ORDER_BY, MaxValue(count * sort_layout.entry_size, block_size)); - auto pre_allocated_array = - make_unsafe_uniq_array_uninitialized(sorting_size * SortConstants::MSD_RADIX_LOCATIONS); - RadixSortMSD(dataptr, temp_block.Ptr(), count, col_offset, sort_layout.entry_size, sorting_size, 0, - pre_allocated_array.get(), false); -} - -//! 
Identifies sequences of rows that are tied, and calls radix sort on these -static void SubSortTiedTuples(BufferManager &buffer_manager, const data_ptr_t dataptr, const idx_t &count, - const idx_t &col_offset, const idx_t &sorting_size, bool ties[], - const SortLayout &sort_layout, bool contains_string) { - D_ASSERT(!ties[count - 1]); - for (idx_t i = 0; i < count; i++) { - if (!ties[i]) { - continue; - } - idx_t j; - for (j = i + 1; j < count; j++) { - if (!ties[j]) { - break; - } - } - RadixSort(buffer_manager, dataptr + i * sort_layout.entry_size, j - i + 1, col_offset, sorting_size, - sort_layout, contains_string); - i = j; - } -} - -void LocalSortState::SortInMemory() { - auto &sb = *sorted_blocks.back(); - auto &block = *sb.radix_sorting_data.back(); - const auto &count = block.count; - auto handle = buffer_manager->Pin(block.block); - const auto dataptr = handle.Ptr(); - // Assign an index to each row - data_ptr_t idx_dataptr = dataptr + sort_layout->comparison_size; - for (uint32_t i = 0; i < count; i++) { - Store(i, idx_dataptr); - idx_dataptr += sort_layout->entry_size; - } - // Radix sort and break ties until no more ties, or until all columns are sorted - idx_t sorting_size = 0; - idx_t col_offset = 0; - unsafe_unique_array ties_ptr; - bool *ties = nullptr; - bool contains_string = false; - for (idx_t i = 0; i < sort_layout->column_count; i++) { - sorting_size += sort_layout->column_sizes[i]; - contains_string = contains_string || sort_layout->logical_types[i].InternalType() == PhysicalType::VARCHAR; - if (sort_layout->constant_size[i] && i < sort_layout->column_count - 1) { - // Add columns to the sorting size until we reach a variable size column, or the last column - continue; - } - - if (!ties) { - // This is the first sort - RadixSort(*buffer_manager, dataptr, count, col_offset, sorting_size, *sort_layout, contains_string); - ties_ptr = make_unsafe_uniq_array_uninitialized(count); - ties = ties_ptr.get(); - std::fill_n(ties, count - 1, true); - ties[count - 1] = false; - } else { - // For subsequent sorts, we only have to subsort the tied tuples - SubSortTiedTuples(*buffer_manager, dataptr, count, col_offset, sorting_size, ties, *sort_layout, - contains_string); - } - - contains_string = false; - - if (sort_layout->constant_size[i] && i == sort_layout->column_count - 1) { - // All columns are sorted, no ties to break because last column is constant size - break; - } - - ComputeTies(dataptr, count, col_offset, sorting_size, ties, *sort_layout); - if (!AnyTies(ties, count)) { - // No ties, stop sorting - break; - } - - if (!sort_layout->constant_size[i]) { - SortTiedBlobs(*buffer_manager, sb, ties, dataptr, count, i, *sort_layout); - if (!AnyTies(ties, count)) { - // No more ties after tie-breaking, stop - break; - } - } - - col_offset += sorting_size; - sorting_size = 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sort_state.cpp b/src/duckdb/src/common/sort/sort_state.cpp deleted file mode 100644 index 386f3498e..000000000 --- a/src/duckdb/src/common/sort/sort_state.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/sort/sorted_block.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" - -#include -#include - -namespace duckdb { - -idx_t GetNestedSortingColSize(idx_t &col_size, const LogicalType &type) { - auto physical_type = type.InternalType(); - if 
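[Editor's note] SortInMemory above drives the whole in-memory sort: radix sort on a constant-size key prefix, mark the rows still tied with their neighbour, then sub-sort only the tied runs on the next column. A sketch of that tie bookkeeping, simplified to a memcmp over a fixed-width key slice:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Sketch only: ties[i] stays true only if row i still equals row i + 1 on this key
static void UpdateTies(const uint8_t *rows, std::size_t count, std::size_t entry_size, std::size_t key_offset,
                       std::size_t key_size, std::vector<bool> &ties) {
	for (std::size_t i = 0; i + 1 < count; i++) {
		const uint8_t *a = rows + i * entry_size + key_offset;
		ties[i] = ties[i] && std::memcmp(a, a + entry_size, key_size) == 0;
	}
}

// Visit each maximal tied run and sub-sort rows [begin, end), cf. SubSortTiedTuples
template <class SUBSORT>
static void ForEachTiedRun(const std::vector<bool> &ties, SUBSORT subsort) {
	for (std::size_t i = 0; i < ties.size(); i++) {
		if (!ties[i]) {
			continue;
		}
		std::size_t j = i + 1;
		while (j < ties.size() && ties[j]) {
			j++;
		}
		subsort(i, j + 1); // rows i..j inclusive are tied on the prefix so far
		i = j;
	}
}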
(TypeIsConstantSize(physical_type)) { - col_size += GetTypeIdSize(physical_type); - return 0; - } else { - switch (physical_type) { - case PhysicalType::VARCHAR: { - // Nested strings are between 4 and 11 chars long for alignment - auto size_before_str = col_size; - col_size += 11; - col_size -= (col_size - 12) % 8; - return col_size - size_before_str; - } - case PhysicalType::LIST: - // Lists get 2 bytes (null and empty list) - col_size += 2; - return GetNestedSortingColSize(col_size, ListType::GetChildType(type)); - case PhysicalType::STRUCT: - // Structs get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, StructType::GetChildType(type, 0)); - case PhysicalType::ARRAY: - // Arrays get 1 bytes (null) - col_size++; - return GetNestedSortingColSize(col_size, ArrayType::GetChildType(type)); - default: - throw NotImplementedException("Unable to order column with type %s", type.ToString()); - } - } -} - -SortLayout::SortLayout(const vector &orders) - : column_count(orders.size()), all_constant(true), comparison_size(0), entry_size(0) { - vector blob_layout_types; - for (idx_t i = 0; i < column_count; i++) { - const auto &order = orders[i]; - - order_types.push_back(order.type); - order_by_null_types.push_back(order.null_order); - auto &expr = *order.expression; - logical_types.push_back(expr.return_type); - - auto physical_type = expr.return_type.InternalType(); - constant_size.push_back(TypeIsConstantSize(physical_type)); - - if (order.stats) { - stats.push_back(order.stats.get()); - has_null.push_back(stats.back()->CanHaveNull()); - } else { - stats.push_back(nullptr); - has_null.push_back(true); - } - - idx_t col_size = has_null.back() ? 1 : 0; - prefix_lengths.push_back(0); - if (!TypeIsConstantSize(physical_type) && physical_type != PhysicalType::VARCHAR) { - prefix_lengths.back() = GetNestedSortingColSize(col_size, expr.return_type); - } else if (physical_type == PhysicalType::VARCHAR) { - idx_t size_before = col_size; - if (stats.back() && StringStats::HasMaxStringLength(*stats.back())) { - col_size += StringStats::MaxStringLength(*stats.back()); - if (col_size > 12) { - col_size = 12; - } else { - constant_size.back() = true; - } - } else { - col_size = 12; - } - prefix_lengths.back() = col_size - size_before; - } else { - col_size += GetTypeIdSize(physical_type); - } - - comparison_size += col_size; - column_sizes.push_back(col_size); - } - entry_size = comparison_size + sizeof(uint32_t); - - // 8-byte alignment - if (entry_size % 8 != 0) { - // First assign more bytes to strings instead of aligning - idx_t bytes_to_fill = 8 - (entry_size % 8); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - if (bytes_to_fill == 0) { - break; - } - if (logical_types[col_idx].InternalType() == PhysicalType::VARCHAR && stats[col_idx] && - StringStats::HasMaxStringLength(*stats[col_idx])) { - idx_t diff = StringStats::MaxStringLength(*stats[col_idx]) - prefix_lengths[col_idx]; - if (diff > 0) { - // Increase all sizes accordingly - idx_t increase = MinValue(bytes_to_fill, diff); - column_sizes[col_idx] += increase; - prefix_lengths[col_idx] += increase; - constant_size[col_idx] = increase == diff; - comparison_size += increase; - entry_size += increase; - bytes_to_fill -= increase; - } - } - } - entry_size = AlignValue(entry_size); - } - - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - all_constant = all_constant && constant_size[col_idx]; - if (!constant_size[col_idx]) { - sorting_to_blob_col[col_idx] = blob_layout_types.size(); - 
- blob_layout_types.push_back(logical_types[col_idx]); - } - } - - blob_layout.Initialize(blob_layout_types); -} - -SortLayout SortLayout::GetPrefixComparisonLayout(idx_t num_prefix_cols) const { - SortLayout result; - result.column_count = num_prefix_cols; - result.all_constant = true; - result.comparison_size = 0; - for (idx_t col_idx = 0; col_idx < num_prefix_cols; col_idx++) { - result.order_types.push_back(order_types[col_idx]); - result.order_by_null_types.push_back(order_by_null_types[col_idx]); - result.logical_types.push_back(logical_types[col_idx]); - - result.all_constant = result.all_constant && constant_size[col_idx]; - result.constant_size.push_back(constant_size[col_idx]); - - result.comparison_size += column_sizes[col_idx]; - result.column_sizes.push_back(column_sizes[col_idx]); - - result.prefix_lengths.push_back(prefix_lengths[col_idx]); - result.stats.push_back(stats[col_idx]); - result.has_null.push_back(has_null[col_idx]); - } - result.entry_size = entry_size; - result.blob_layout = blob_layout; - result.sorting_to_blob_col = sorting_to_blob_col; - return result; -} - -LocalSortState::LocalSortState() : initialized(false) { - if (!Radix::IsLittleEndian()) { - throw NotImplementedException("Sorting is not supported on big endian architectures"); - } -} - -void LocalSortState::Initialize(GlobalSortState &global_sort_state, BufferManager &buffer_manager_p) { - sort_layout = &global_sort_state.sort_layout; - payload_layout = &global_sort_state.payload_layout; - buffer_manager = &buffer_manager_p; - const auto block_size = buffer_manager->GetBlockSize(); - - // Radix sorting data - auto entries_per_block = RowDataCollection::EntriesPerBlock(sort_layout->entry_size, block_size); - radix_sorting_data = make_uniq<RowDataCollection>(*buffer_manager, entries_per_block, sort_layout->entry_size); - - // Blob sorting data - if (!sort_layout->all_constant) { - auto blob_row_width = sort_layout->blob_layout.GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(blob_row_width, block_size); - blob_sorting_data = make_uniq<RowDataCollection>(*buffer_manager, entries_per_block, blob_row_width); - blob_sorting_heap = make_uniq<RowDataCollection>(*buffer_manager, block_size, 1U, true); - } - - // Payload data - auto payload_row_width = payload_layout->GetRowWidth(); - entries_per_block = RowDataCollection::EntriesPerBlock(payload_row_width, block_size); - payload_data = make_uniq<RowDataCollection>(*buffer_manager, entries_per_block, payload_row_width); - payload_heap = make_uniq<RowDataCollection>(*buffer_manager, block_size, 1U, true); - initialized = true; -} - -void LocalSortState::SinkChunk(DataChunk &sort, DataChunk &payload) { - D_ASSERT(sort.size() == payload.size()); - // Build and serialize sorting data to radix sortable rows - auto data_pointers = FlatVector::GetData<data_ptr_t>(addresses); - auto handles = radix_sorting_data->Build(sort.size(), data_pointers, nullptr); - for (idx_t sort_col = 0; sort_col < sort.ColumnCount(); sort_col++) { - bool has_null = sort_layout->has_null[sort_col]; - bool nulls_first = sort_layout->order_by_null_types[sort_col] == OrderByNullType::NULLS_FIRST; - bool desc = sort_layout->order_types[sort_col] == OrderType::DESCENDING; - RowOperations::RadixScatter(sort.data[sort_col], sort.size(), sel_ptr, sort.size(), data_pointers, desc, - has_null, nulls_first, sort_layout->prefix_lengths[sort_col], - sort_layout->column_sizes[sort_col]); - } - - // Also fully serialize blob sorting columns (to be able to break ties) - if (!sort_layout->all_constant) { - DataChunk blob_chunk; - blob_chunk.SetCardinality(sort.size()); - for (idx_t sort_col = 
0; sort_col < sort.ColumnCount(); sort_col++) { - if (!sort_layout->constant_size[sort_col]) { - blob_chunk.data.emplace_back(sort.data[sort_col]); - } - } - handles = blob_sorting_data->Build(blob_chunk.size(), data_pointers, nullptr); - auto blob_data = blob_chunk.ToUnifiedFormat(); - RowOperations::Scatter(blob_chunk, blob_data.get(), sort_layout->blob_layout, addresses, *blob_sorting_heap, - sel_ptr, blob_chunk.size()); - D_ASSERT(blob_sorting_heap->keep_pinned); - } - - // Finally, serialize payload data - handles = payload_data->Build(payload.size(), data_pointers, nullptr); - auto input_data = payload.ToUnifiedFormat(); - RowOperations::Scatter(payload, input_data.get(), *payload_layout, addresses, *payload_heap, sel_ptr, - payload.size()); - D_ASSERT(payload_heap->keep_pinned); -} - -idx_t LocalSortState::SizeInBytes() const { - idx_t size_in_bytes = radix_sorting_data->SizeInBytes() + payload_data->SizeInBytes(); - if (!sort_layout->all_constant) { - size_in_bytes += blob_sorting_data->SizeInBytes() + blob_sorting_heap->SizeInBytes(); - } - if (!payload_layout->AllConstant()) { - size_in_bytes += payload_heap->SizeInBytes(); - } - return size_in_bytes; -} - -void LocalSortState::Sort(GlobalSortState &global_sort_state, bool reorder_heap) { - D_ASSERT(radix_sorting_data->count == payload_data->count); - if (radix_sorting_data->count == 0) { - return; - } - // Move all data to a single SortedBlock - sorted_blocks.emplace_back(make_uniq(*buffer_manager, global_sort_state)); - auto &sb = *sorted_blocks.back(); - // Fixed-size sorting data - auto sorting_block = ConcatenateBlocks(*radix_sorting_data); - sb.radix_sorting_data.push_back(std::move(sorting_block)); - // Variable-size sorting data - if (!sort_layout->all_constant) { - auto &blob_data = *blob_sorting_data; - auto new_block = ConcatenateBlocks(blob_data); - sb.blob_sorting_data->data_blocks.push_back(std::move(new_block)); - } - // Payload data - auto payload_block = ConcatenateBlocks(*payload_data); - sb.payload_data->data_blocks.push_back(std::move(payload_block)); - // Now perform the actual sort - SortInMemory(); - // Re-order before the merge sort - ReOrder(global_sort_state, reorder_heap); -} - -unique_ptr LocalSortState::ConcatenateBlocks(RowDataCollection &row_data) { - // Don't copy and delete if there is only one block. 
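// Editorial annotation (not present in the original source): a single block can be handed
// over by moving its unique_ptr, which skips a pin and a memcpy of up to an entire
// buffer-managed block. The general path below allocates one block with capacity for all
// row_data.count entries and copies each source block into it back to back.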
- if (row_data.blocks.size() == 1) { - auto new_block = std::move(row_data.blocks[0]); - row_data.blocks.clear(); - row_data.count = 0; - return new_block; - } - // Create block with the correct capacity - auto &buffer_manager = row_data.buffer_manager; - const idx_t &entry_size = row_data.entry_size; - idx_t capacity = MaxValue((buffer_manager.GetBlockSize() + entry_size - 1) / entry_size, row_data.count); - auto new_block = make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, entry_size); - new_block->count = row_data.count; - auto new_block_handle = buffer_manager.Pin(new_block->block); - data_ptr_t new_block_ptr = new_block_handle.Ptr(); - // Copy the data of the blocks into a single block - for (idx_t i = 0; i < row_data.blocks.size(); i++) { - auto &block = row_data.blocks[i]; - auto block_handle = buffer_manager.Pin(block->block); - memcpy(new_block_ptr, block_handle.Ptr(), block->count * entry_size); - new_block_ptr += block->count * entry_size; - block.reset(); - } - row_data.blocks.clear(); - row_data.count = 0; - return new_block; -} - -void LocalSortState::ReOrder(SortedData &sd, data_ptr_t sorting_ptr, RowDataCollection &heap, GlobalSortState &gstate, - bool reorder_heap) { - sd.swizzled = reorder_heap; - auto &unordered_data_block = sd.data_blocks.back(); - const idx_t count = unordered_data_block->count; - auto unordered_data_handle = buffer_manager->Pin(unordered_data_block->block); - const data_ptr_t unordered_data_ptr = unordered_data_handle.Ptr(); - // Create new block that will hold re-ordered row data - auto ordered_data_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, - unordered_data_block->capacity, unordered_data_block->entry_size); - ordered_data_block->count = count; - auto ordered_data_handle = buffer_manager->Pin(ordered_data_block->block); - data_ptr_t ordered_data_ptr = ordered_data_handle.Ptr(); - // Re-order fixed-size row layout - const idx_t row_width = sd.layout.GetRowWidth(); - const idx_t sorting_entry_size = gstate.sort_layout.entry_size; - for (idx_t i = 0; i < count; i++) { - auto index = Load(sorting_ptr); - FastMemcpy(ordered_data_ptr, unordered_data_ptr + index * row_width, row_width); - ordered_data_ptr += row_width; - sorting_ptr += sorting_entry_size; - } - ordered_data_block->block->SetSwizzling( - sd.layout.AllConstant() || !sd.swizzled ? 
nullptr : "LocalSortState::ReOrder.ordered_data"); - // Replace the unordered data block with the re-ordered data block - sd.data_blocks.clear(); - sd.data_blocks.push_back(std::move(ordered_data_block)); - // Deal with the heap (if necessary) - if (!sd.layout.AllConstant() && reorder_heap) { - // Swizzle the column pointers to offsets - RowOperations::SwizzleColumns(sd.layout, ordered_data_handle.Ptr(), count); - sd.data_blocks.back()->block->SetSwizzling(nullptr); - // Create a single heap block to store the ordered heap - idx_t total_byte_offset = - std::accumulate(heap.blocks.begin(), heap.blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->byte_offset; }); - idx_t heap_block_size = MaxValue(total_byte_offset, buffer_manager->GetBlockSize()); - auto ordered_heap_block = make_uniq(MemoryTag::ORDER_BY, *buffer_manager, heap_block_size, 1U); - ordered_heap_block->count = count; - ordered_heap_block->byte_offset = total_byte_offset; - auto ordered_heap_handle = buffer_manager->Pin(ordered_heap_block->block); - data_ptr_t ordered_heap_ptr = ordered_heap_handle.Ptr(); - // Fill the heap in order - ordered_data_ptr = ordered_data_handle.Ptr(); - const idx_t heap_pointer_offset = sd.layout.GetHeapOffset(); - for (idx_t i = 0; i < count; i++) { - auto heap_row_ptr = Load(ordered_data_ptr + heap_pointer_offset); - auto heap_row_size = Load(heap_row_ptr); - memcpy(ordered_heap_ptr, heap_row_ptr, heap_row_size); - ordered_heap_ptr += heap_row_size; - ordered_data_ptr += row_width; - } - // Swizzle the base pointer to the offset of each row in the heap - RowOperations::SwizzleHeapPointer(sd.layout, ordered_data_handle.Ptr(), ordered_heap_handle.Ptr(), count); - // Move the re-ordered heap to the SortedData, and clear the local heap - sd.heap_blocks.push_back(std::move(ordered_heap_block)); - heap.pinned_blocks.clear(); - heap.blocks.clear(); - heap.count = 0; - } -} - -void LocalSortState::ReOrder(GlobalSortState &gstate, bool reorder_heap) { - auto &sb = *sorted_blocks.back(); - auto sorting_handle = buffer_manager->Pin(sb.radix_sorting_data.back()->block); - const data_ptr_t sorting_ptr = sorting_handle.Ptr() + gstate.sort_layout.comparison_size; - // Re-order variable size sorting columns - if (!gstate.sort_layout.all_constant) { - ReOrder(*sb.blob_sorting_data, sorting_ptr, *blob_sorting_heap, gstate, reorder_heap); - } - // And the payload - ReOrder(*sb.payload_data, sorting_ptr, *payload_heap, gstate, reorder_heap); -} - -GlobalSortState::GlobalSortState(BufferManager &buffer_manager, const vector &orders, - RowLayout &payload_layout) - : buffer_manager(buffer_manager), sort_layout(SortLayout(orders)), payload_layout(payload_layout), - block_capacity(0), external(false) { -} - -void GlobalSortState::AddLocalState(LocalSortState &local_sort_state) { - if (!local_sort_state.radix_sorting_data) { - return; - } - - // Sort accumulated data - // we only re-order the heap when the data is expected to not fit in memory - // re-ordering the heap avoids random access when reading/merging but incurs a significant cost of shuffling data - // when data fits in memory, doing random access on reads is cheaper than re-shuffling - local_sort_state.Sort(*this, external || !local_sort_state.sorted_blocks.empty()); - - // Append local state sorted data to this global state - lock_guard append_guard(lock); - for (auto &sb : local_sort_state.sorted_blocks) { - sorted_blocks.push_back(std::move(sb)); - } - auto &payload_heap = local_sort_state.payload_heap; - for (idx_t i = 0; i < 
payload_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(payload_heap->blocks[i])); - pinned_blocks.push_back(std::move(payload_heap->pinned_blocks[i])); - } - if (!sort_layout.all_constant) { - auto &blob_heap = local_sort_state.blob_sorting_heap; - for (idx_t i = 0; i < blob_heap->blocks.size(); i++) { - heap_blocks.push_back(std::move(blob_heap->blocks[i])); - pinned_blocks.push_back(std::move(blob_heap->pinned_blocks[i])); - } - } -} - -void GlobalSortState::PrepareMergePhase() { - // Determine if we need to do an external sort - idx_t total_heap_size = - std::accumulate(sorted_blocks.begin(), sorted_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr<SortedBlock> &b) { return a + b->HeapSize(); }); - if (external || (pinned_blocks.empty() && total_heap_size * 4 > buffer_manager.GetQueryMaxMemory())) { - external = true; - } - // Use the data that we have to determine which partition size to use during the merge - if (external && total_heap_size > 0) { - // If we have variable size data we need to be conservative, as there might be skew - idx_t max_block_size = 0; - for (auto &sb : sorted_blocks) { - idx_t size_in_bytes = sb->SizeInBytes(); - if (size_in_bytes > max_block_size) { - max_block_size = size_in_bytes; - block_capacity = sb->Count(); - } - } - } else { - for (auto &sb : sorted_blocks) { - block_capacity = MaxValue(block_capacity, sb->Count()); - } - } - // Unswizzle and pin heap blocks if we can fit everything in memory - if (!external) { - for (auto &sb : sorted_blocks) { - sb->blob_sorting_data->Unswizzle(); - sb->payload_data->Unswizzle(); - } - } -} - -void GlobalSortState::InitializeMergeRound() { - D_ASSERT(sorted_blocks_temp.empty()); - // If we reverse this list, the blocks that were merged last will be merged first in the next round - // These are still in memory, therefore this reduces the amount of read/write to disk! - std::reverse(sorted_blocks.begin(), sorted_blocks.end()); - // Uneven number of blocks - keep one on the side - if (sorted_blocks.size() % 2 == 1) { - odd_one_out = std::move(sorted_blocks.back()); - sorted_blocks.pop_back(); - } - // Init merge path indices - pair_idx = 0; - num_pairs = sorted_blocks.size() / 2; - l_start = 0; - r_start = 0; - // Allocate room for merge results - for (idx_t p_idx = 0; p_idx < num_pairs; p_idx++) { - sorted_blocks_temp.emplace_back(); - } -} - -void GlobalSortState::CompleteMergeRound(bool keep_radix_data) { - sorted_blocks.clear(); - for (auto &sorted_block_vector : sorted_blocks_temp) { - sorted_blocks.push_back(make_uniq<SortedBlock>(buffer_manager, *this)); - sorted_blocks.back()->AppendSortedBlocks(sorted_block_vector); - } - sorted_blocks_temp.clear(); - if (odd_one_out) { - sorted_blocks.push_back(std::move(odd_one_out)); - odd_one_out = nullptr; - } - // Only one block left: Done!
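// Editorial annotation (not present in the original source): once a single sorted run
// remains, the fixed-size radix keys and the blob sort keys are no longer needed for
// comparisons, so they are released below and only the payload survives to be scanned.
// keep_radix_data appears to exist for callers that still need key comparisons after the
// merge completes, e.g. the SBIterator used for IEJoin later in this diff.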
- if (sorted_blocks.size() == 1 && !keep_radix_data) { - sorted_blocks[0]->radix_sorting_data.clear(); - sorted_blocks[0]->blob_sorting_data = nullptr; - } -} -void GlobalSortState::Print() { - PayloadScanner scanner(*this, false); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - for (;;) { - scanner.Scan(chunk); - const auto count = chunk.size(); - if (!count) { - break; - } - chunk.Print(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/sort/sorted_block.cpp b/src/duckdb/src/common/sort/sorted_block.cpp deleted file mode 100644 index c4766c956..000000000 --- a/src/duckdb/src/common/sort/sorted_block.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/common/sort/sorted_block.hpp" - -#include "duckdb/common/constants.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" - -#include - -namespace duckdb { - -SortedData::SortedData(SortedDataType type, const RowLayout &layout, BufferManager &buffer_manager, - GlobalSortState &state) - : type(type), layout(layout), swizzled(state.external), buffer_manager(buffer_manager), state(state) { -} - -idx_t SortedData::Count() { - idx_t count = std::accumulate(data_blocks.begin(), data_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!layout.AllConstant() && state.external) { - D_ASSERT(count == std::accumulate(heap_blocks.begin(), heap_blocks.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; })); - } - return count; -} - -void SortedData::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + layout.GetRowWidth() - 1) / layout.GetRowWidth(), state.block_capacity); - data_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, layout.GetRowWidth())); - if (!layout.AllConstant() && state.external) { - heap_blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_size, 1U)); - D_ASSERT(data_blocks.size() == heap_blocks.size()); - } -} - -unique_ptr SortedData::CreateSlice(idx_t start_block_index, idx_t end_block_index, idx_t end_entry_index) { - // Add the corresponding blocks to the result - auto result = make_uniq(type, layout, buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->data_blocks.push_back(data_blocks[i]->Copy()); - if (!layout.AllConstant() && state.external) { - result->heap_blocks.push_back(heap_blocks[i]->Copy()); - } - } - // All of the blocks that come before block with idx = start_block_idx can be reset (other references exist) - for (idx_t i = 0; i < start_block_index; i++) { - data_blocks[i]->block = nullptr; - if (!layout.AllConstant() && state.external) { - heap_blocks[i]->block = nullptr; - } - } - // Use start and end entry indices to set the boundaries - D_ASSERT(end_entry_index <= result->data_blocks.back()->count); - result->data_blocks.back()->count = end_entry_index; - if (!layout.AllConstant() && state.external) { - result->heap_blocks.back()->count = end_entry_index; - } - return result; -} - -void SortedData::Unswizzle() { - if (layout.AllConstant() || !swizzled) { - return; - } - for (idx_t i = 0; i < data_blocks.size(); i++) { - auto &data_block = data_blocks[i]; - auto &heap_block = heap_blocks[i]; - D_ASSERT(data_block->block->IsSwizzled()); - auto data_handle_p = buffer_manager.Pin(data_block->block); - auto heap_handle_p = 
buffer_manager.Pin(heap_block->block); - RowOperations::UnswizzlePointers(layout, data_handle_p.Ptr(), heap_handle_p.Ptr(), data_block->count); - state.heap_blocks.push_back(std::move(heap_block)); - state.pinned_blocks.push_back(std::move(heap_handle_p)); - } - swizzled = false; - heap_blocks.clear(); -} - -SortedBlock::SortedBlock(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), state(state), sort_layout(state.sort_layout), - payload_layout(state.payload_layout) { - blob_sorting_data = make_uniq(SortedDataType::BLOB, sort_layout.blob_layout, buffer_manager, state); - payload_data = make_uniq(SortedDataType::PAYLOAD, payload_layout, buffer_manager, state); -} - -idx_t SortedBlock::Count() const { - idx_t count = std::accumulate(radix_sorting_data.begin(), radix_sorting_data.end(), (idx_t)0, - [](idx_t a, const unique_ptr &b) { return a + b->count; }); - if (!sort_layout.all_constant) { - D_ASSERT(count == blob_sorting_data->Count()); - } - D_ASSERT(count == payload_data->Count()); - return count; -} - -void SortedBlock::InitializeWrite() { - CreateBlock(); - if (!sort_layout.all_constant) { - blob_sorting_data->CreateBlock(); - } - payload_data->CreateBlock(); -} - -void SortedBlock::CreateBlock() { - const auto block_size = buffer_manager.GetBlockSize(); - auto capacity = MaxValue((block_size + sort_layout.entry_size - 1) / sort_layout.entry_size, state.block_capacity); - radix_sorting_data.push_back( - make_uniq(MemoryTag::ORDER_BY, buffer_manager, capacity, sort_layout.entry_size)); -} - -void SortedBlock::AppendSortedBlocks(vector> &sorted_blocks) { - D_ASSERT(Count() == 0); - for (auto &sb : sorted_blocks) { - for (auto &radix_block : sb->radix_sorting_data) { - radix_sorting_data.push_back(std::move(radix_block)); - } - if (!sort_layout.all_constant) { - for (auto &blob_block : sb->blob_sorting_data->data_blocks) { - blob_sorting_data->data_blocks.push_back(std::move(blob_block)); - } - for (auto &heap_block : sb->blob_sorting_data->heap_blocks) { - blob_sorting_data->heap_blocks.push_back(std::move(heap_block)); - } - } - for (auto &payload_data_block : sb->payload_data->data_blocks) { - payload_data->data_blocks.push_back(std::move(payload_data_block)); - } - if (!payload_data->layout.AllConstant()) { - for (auto &payload_heap_block : sb->payload_data->heap_blocks) { - payload_data->heap_blocks.push_back(std::move(payload_heap_block)); - } - } - } -} - -void SortedBlock::GlobalToLocalIndex(const idx_t &global_idx, idx_t &local_block_index, idx_t &local_entry_index) { - if (global_idx == Count()) { - local_block_index = radix_sorting_data.size() - 1; - local_entry_index = radix_sorting_data.back()->count; - return; - } - D_ASSERT(global_idx < Count()); - local_entry_index = global_idx; - for (local_block_index = 0; local_block_index < radix_sorting_data.size(); local_block_index++) { - const idx_t &block_count = radix_sorting_data[local_block_index]->count; - if (local_entry_index >= block_count) { - local_entry_index -= block_count; - } else { - break; - } - } - D_ASSERT(local_entry_index < radix_sorting_data[local_block_index]->count); -} - -unique_ptr SortedBlock::CreateSlice(const idx_t start, const idx_t end, idx_t &entry_idx) { - // Identify blocks/entry indices of this slice - idx_t start_block_index; - idx_t start_entry_index; - GlobalToLocalIndex(start, start_block_index, start_entry_index); - idx_t end_block_index; - idx_t end_entry_index; - GlobalToLocalIndex(end, end_block_index, end_entry_index); - // Add the corresponding 
blocks to the result - auto result = make_uniq(buffer_manager, state); - for (idx_t i = start_block_index; i <= end_block_index; i++) { - result->radix_sorting_data.push_back(radix_sorting_data[i]->Copy()); - } - // Reset all blocks that come before block with idx = start_block_idx (slice holds new reference) - for (idx_t i = 0; i < start_block_index; i++) { - radix_sorting_data[i]->block = nullptr; - } - // Use start and end entry indices to set the boundaries - entry_idx = start_entry_index; - D_ASSERT(end_entry_index <= result->radix_sorting_data.back()->count); - result->radix_sorting_data.back()->count = end_entry_index; - // Same for the var size sorting data - if (!sort_layout.all_constant) { - result->blob_sorting_data = blob_sorting_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - } - // And the payload data - result->payload_data = payload_data->CreateSlice(start_block_index, end_block_index, end_entry_index); - return result; -} - -idx_t SortedBlock::HeapSize() const { - idx_t result = 0; - if (!sort_layout.all_constant) { - for (auto &block : blob_sorting_data->heap_blocks) { - result += block->capacity; - } - } - if (!payload_layout.AllConstant()) { - for (auto &block : payload_data->heap_blocks) { - result += block->capacity; - } - } - return result; -} - -idx_t SortedBlock::SizeInBytes() const { - idx_t bytes = 0; - for (idx_t i = 0; i < radix_sorting_data.size(); i++) { - bytes += radix_sorting_data[i]->capacity * sort_layout.entry_size; - if (!sort_layout.all_constant) { - bytes += blob_sorting_data->data_blocks[i]->capacity * sort_layout.blob_layout.GetRowWidth(); - bytes += blob_sorting_data->heap_blocks[i]->capacity; - } - bytes += payload_data->data_blocks[i]->capacity * payload_layout.GetRowWidth(); - if (!payload_layout.AllConstant()) { - bytes += payload_data->heap_blocks[i]->capacity; - } - } - return bytes; -} - -SBScanState::SBScanState(BufferManager &buffer_manager, GlobalSortState &state) - : buffer_manager(buffer_manager), sort_layout(state.sort_layout), state(state), block_idx(0), entry_idx(0) { -} - -void SBScanState::PinRadix(idx_t block_idx_to) { - auto &radix_sorting_data = sb->radix_sorting_data; - D_ASSERT(block_idx_to < radix_sorting_data.size()); - auto &block = radix_sorting_data[block_idx_to]; - if (!radix_handle.IsValid() || radix_handle.GetBlockHandle() != block->block) { - radix_handle = buffer_manager.Pin(block->block); - } -} - -void SBScanState::PinData(SortedData &sd) { - D_ASSERT(block_idx < sd.data_blocks.size()); - auto &data_handle = sd.type == SortedDataType::BLOB ? blob_sorting_data_handle : payload_data_handle; - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - - auto &data_block = sd.data_blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = buffer_manager.Pin(data_block->block); - } - if (sd.layout.AllConstant() || !state.external) { - return; - } - auto &heap_block = sd.heap_blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = buffer_manager.Pin(heap_block->block); - } -} - -data_ptr_t SBScanState::RadixPtr() const { - return radix_handle.Ptr() + entry_idx * sort_layout.entry_size; -} - -data_ptr_t SBScanState::DataPtr(SortedData &sd) const { - auto &data_handle = sd.type == SortedDataType::BLOB ? 
blob_sorting_data_handle : payload_data_handle; - D_ASSERT(sd.data_blocks[block_idx]->block->Readers() != 0 && - data_handle.GetBlockHandle() == sd.data_blocks[block_idx]->block); - return data_handle.Ptr() + entry_idx * sd.layout.GetRowWidth(); -} - -data_ptr_t SBScanState::HeapPtr(SortedData &sd) const { - return BaseHeapPtr(sd) + Load(DataPtr(sd) + sd.layout.GetHeapOffset()); -} - -data_ptr_t SBScanState::BaseHeapPtr(SortedData &sd) const { - auto &heap_handle = sd.type == SortedDataType::BLOB ? blob_sorting_heap_handle : payload_heap_handle; - D_ASSERT(!sd.layout.AllConstant() && state.external); - D_ASSERT(sd.heap_blocks[block_idx]->block->Readers() != 0 && - heap_handle.GetBlockHandle() == sd.heap_blocks[block_idx]->block); - return heap_handle.Ptr(); -} - -idx_t SBScanState::Remaining() const { - const auto &blocks = sb->radix_sorting_data; - idx_t remaining = 0; - if (block_idx < blocks.size()) { - remaining += blocks[block_idx]->count - entry_idx; - for (idx_t i = block_idx + 1; i < blocks.size(); i++) { - remaining += blocks[i]->count; - } - } - return remaining; -} - -void SBScanState::SetIndices(idx_t block_idx_to, idx_t entry_idx_to) { - block_idx = block_idx_to; - entry_idx = entry_idx_to; -} - -PayloadScanner::PayloadScanner(SortedData &sorted_data, GlobalSortState &global_sort_state, bool flush_p) { - auto count = sorted_data.Count(); - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant()) { - heap->count = count; - } - - if (flush_p) { - // If we are flushing, we can just move the data - rows->blocks = std::move(sorted_data.data_blocks); - if (!layout.AllConstant()) { - heap->blocks = std::move(sorted_data.heap_blocks); - } - } else { - // Not flushing, create references to the blocks - for (auto &block : sorted_data.data_blocks) { - rows->blocks.emplace_back(block->Copy()); - } - if (!layout.AllConstant()) { - for (auto &block : sorted_data.heap_blocks) { - heap->blocks.emplace_back(block->Copy()); - } - } - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, bool flush_p) - : PayloadScanner(*global_sort_state.sorted_blocks[0]->payload_data, global_sort_state, flush_p) { -} - -PayloadScanner::PayloadScanner(GlobalSortState &global_sort_state, idx_t block_idx, bool flush_p) { - auto &sorted_data = *global_sort_state.sorted_blocks[0]->payload_data; - auto count = sorted_data.data_blocks[block_idx]->count; - auto &layout = sorted_data.layout; - const auto block_size = global_sort_state.buffer_manager.GetBlockSize(); - - // Create collections to put the data into so we can use RowDataCollectionScanner - rows = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (flush_p) { - rows->blocks.emplace_back(std::move(sorted_data.data_blocks[block_idx])); - } else { - rows->blocks.emplace_back(sorted_data.data_blocks[block_idx]->Copy()); - } - rows->count = count; - - heap = make_uniq(global_sort_state.buffer_manager, block_size, 1U); - if (!sorted_data.layout.AllConstant() && sorted_data.swizzled) { - if (flush_p) { - heap->blocks.emplace_back(std::move(sorted_data.heap_blocks[block_idx])); - } else { - 
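// Editorial annotation (not present in the original source): Copy() below produces another
// reference to the same underlying block rather than a deep copy, so the non-flushing
// branch leaves the sorted run intact and re-scannable; the flushing branch above instead
// moves ownership of the heap block out of the sorted data.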
heap->blocks.emplace_back(sorted_data.heap_blocks[block_idx]->Copy()); - } - heap->count = count; - } - - scanner = make_uniq(*rows, *heap, layout, global_sort_state.external, flush_p); -} - -void PayloadScanner::Scan(DataChunk &chunk) { - scanner->Scan(chunk); -} - -int SBIterator::ComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for IEJoin!"); - } -} - -static idx_t GetBlockCountWithEmptyCheck(const GlobalSortState &gss) { - D_ASSERT(!gss.sorted_blocks.empty()); - return gss.sorted_blocks[0]->radix_sorting_data.size(); -} - -SBIterator::SBIterator(GlobalSortState &gss, ExpressionType comparison, idx_t entry_idx_p) - : sort_layout(gss.sort_layout), block_count(GetBlockCountWithEmptyCheck(gss)), block_capacity(gss.block_capacity), - entry_size(sort_layout.entry_size), all_constant(sort_layout.all_constant), external(gss.external), - cmp(ComparisonValue(comparison)), scan(gss.buffer_manager, gss), block_ptr(nullptr), entry_ptr(nullptr) { - - scan.sb = gss.sorted_blocks[0].get(); - scan.block_idx = block_count; - SetIndex(entry_idx_p); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/stacktrace.cpp b/src/duckdb/src/common/stacktrace.cpp deleted file mode 100644 index 7a42b35cf..000000000 --- a/src/duckdb/src/common/stacktrace.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include "duckdb/common/stacktrace.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" - -#if defined(__GLIBC__) || defined(__APPLE__) -#include -#include -#endif - -namespace duckdb { - -#if defined(__GLIBC__) || defined(__APPLE__) -static string UnmangleSymbol(string symbol) { - // find the mangled name - idx_t mangle_start = symbol.size(); - idx_t mangle_end = 0; - for (idx_t i = 0; i < symbol.size(); ++i) { - if (symbol[i] == '_') { - mangle_start = i; - break; - } - } - for (idx_t i = mangle_start; i < symbol.size(); i++) { - if (StringUtil::CharacterIsSpace(symbol[i]) || symbol[i] == ')' || symbol[i] == '+') { - mangle_end = i; - break; - } - } - if (mangle_start >= mangle_end) { - return symbol; - } - string mangled_symbol = symbol.substr(mangle_start, mangle_end - mangle_start); - - int status; - auto demangle_result = abi::__cxa_demangle(mangled_symbol.c_str(), nullptr, nullptr, &status); - if (status != 0 || !demangle_result) { - return symbol; - } - string result; - result += symbol.substr(0, mangle_start); - result += demangle_result; - result += symbol.substr(mangle_end); - free(demangle_result); - return result; -} - -static string CleanupStackTrace(string symbol) { -#ifdef __APPLE__ - // structure of frame pointers is [depth] [library] [pointer] [symbol] - // we are only interested in [depth] and [symbol] - - // find the depth - idx_t start; - for (start = 0; start < symbol.size(); start++) { - if (!StringUtil::CharacterIsDigit(symbol[start])) { - break; - } - } - - // now scan forward until we find the frame pointer - idx_t frame_end = symbol.size(); - for (idx_t i = start; i + 1 < symbol.size(); ++i) { - if (symbol[i] == '0' && symbol[i + 1] == 'x') { - idx_t k; - for (k = i + 2; k < symbol.size(); ++k) { - if (!StringUtil::CharacterIsHex(symbol[k])) { - break; - } - } - frame_end = k; - break; - } - } - static constexpr idx_t STACK_TRACE_INDENTATION = 8; - if (frame_end == 
symbol.size() || start >= STACK_TRACE_INDENTATION) { - // frame pointer not found - just preserve the original frame - return symbol; - } - idx_t space_count = STACK_TRACE_INDENTATION - start; - return symbol.substr(0, start) + string(space_count, ' ') + symbol.substr(frame_end, symbol.size() - frame_end); -#else - return symbol; -#endif -} - -string StackTrace::GetStacktracePointers(idx_t max_depth) { - string result; - auto callstack = unique_ptr(new void *[max_depth]); - int frames = backtrace(callstack.get(), NumericCast(max_depth)); - // skip two frames (these are always StackTrace::...) - for (idx_t i = 2; i < NumericCast(frames); i++) { - if (!result.empty()) { - result += ";"; - } - result += to_string(CastPointerToValue(callstack[i])); - } - return result; -} - -string StackTrace::ResolveStacktraceSymbols(const string &pointers) { - auto splits = StringUtil::Split(pointers, ";"); - idx_t frame_count = splits.size(); - auto callstack = unique_ptr(new void *[frame_count]); - for (idx_t i = 0; i < frame_count; i++) { - callstack[i] = cast_uint64_to_pointer(StringUtil::ToUnsigned(splits[i])); - } - string result; - char **strs = backtrace_symbols(callstack.get(), NumericCast(frame_count)); - for (idx_t i = 0; i < frame_count; i++) { - result += CleanupStackTrace(UnmangleSymbol(strs[i])); - result += "\n"; - } - free(reinterpret_cast(strs)); - return "\n" + result; -} - -#else -string StackTrace::GetStacktracePointers(idx_t max_depth) { - return string(); -} - -string StackTrace::ResolveStacktraceSymbols(const string &pointers) { - return string(); -} -#endif - -} // namespace duckdb diff --git a/src/duckdb/src/common/string_util.cpp b/src/duckdb/src/common/string_util.cpp deleted file mode 100644 index f96e3d64e..000000000 --- a/src/duckdb/src/common/string_util.cpp +++ /dev/null @@ -1,841 +0,0 @@ -#include "duckdb/common/string_util.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/stack.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/exception/parser_exception.hpp" -#include "duckdb/common/random_engine.hpp" -#include "jaro_winkler.hpp" -#include "utf8proc_wrapper.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "yyjson.hpp" - -using namespace duckdb_yyjson; // NOLINT - -namespace duckdb { - -string StringUtil::GenerateRandomName(idx_t length) { - RandomEngine engine; - std::stringstream ss; - for (idx_t i = 0; i < length; i++) { - ss << "0123456789abcdef"[engine.NextRandomInteger(0, 15)]; - } - return ss.str(); -} - -bool StringUtil::Contains(const string &haystack, const string &needle) { - return Find(haystack, needle).IsValid(); -} - -optional_idx StringUtil::Find(const string &haystack, const string &needle) { - auto index = haystack.find(needle); - if (index == string::npos) { - return optional_idx(); - } - return optional_idx(index); -} - -bool StringUtil::Contains(const string &haystack, const char &needle_char) { - return (haystack.find(needle_char) != string::npos); -} - -idx_t StringUtil::ToUnsigned(const string &str) { - return std::stoull(str); -} - -void StringUtil::LTrim(string &str) { - auto it = str.begin(); - while (it != str.end() && CharacterIsSpace(*it)) { - it++; - } - str.erase(str.begin(), it); -} - -// Remove trailing ' ', '\f', '\n', '\r', '\t', '\v' -void StringUtil::RTrim(string &str) { - str.erase(find_if(str.rbegin(), str.rend(), [](char ch) { return ch > 0 && !CharacterIsSpace(ch); 
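// Editorial annotation (not present in the original source): find_if over the reverse
// range locates the last character to keep; calling .base() on that reverse iterator
// yields the forward iterator one past it, so the erase() removes exactly the trailing
// whitespace run.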
}).base(), - str.end()); -} - -void StringUtil::RTrim(string &str, const string &chars_to_trim) { - str.erase(find_if(str.rbegin(), str.rend(), - [&chars_to_trim](char ch) { return ch > 0 && chars_to_trim.find(ch) == string::npos; }) - .base(), - str.end()); -} - -void StringUtil::Trim(string &str) { - StringUtil::LTrim(str); - StringUtil::RTrim(str); -} - -bool StringUtil::StartsWith(string str, string prefix) { - if (prefix.size() > str.size()) { - return false; - } - return equal(prefix.begin(), prefix.end(), str.begin()); -} - -bool StringUtil::EndsWith(const string &str, const string &suffix) { - if (suffix.size() > str.size()) { - return false; - } - return equal(suffix.rbegin(), suffix.rend(), str.rbegin()); -} - -string StringUtil::Repeat(const string &str, idx_t n) { - std::ostringstream os; - for (idx_t i = 0; i < n; i++) { - os << str; - } - return (os.str()); -} - -namespace string_util_internal { - -inline void SkipSpaces(const string &str, idx_t &index) { - while (index < str.size() && std::isspace(str[index])) { - index++; - } -} - -inline void ConsumeLetter(const string &str, idx_t &index, char expected) { - if (index >= str.size() || str[index] != expected) { - throw ParserException("Invalid quoted list: %s", str); - } - - index++; -} - -template -inline void TakeWhile(const string &str, idx_t &index, const F &cond, string &taker) { - while (index < str.size() && cond(str[index])) { - taker.push_back(str[index]); - index++; - } -} - -inline string TakePossiblyQuotedItem(const string &str, idx_t &index, char delimiter, char quote) { - string entry; - - if (str[index] == quote) { - index++; - TakeWhile( - str, index, [quote](char c) { return c != quote; }, entry); - ConsumeLetter(str, index, quote); - } else { - TakeWhile( - str, index, [delimiter, quote](char c) { return c != delimiter && c != quote && !std::isspace(c); }, entry); - } - - return entry; -} - -} // namespace string_util_internal - -vector StringUtil::SplitWithQuote(const string &str, char delimiter, char quote) { - vector entries; - idx_t i = 0; - - string_util_internal::SkipSpaces(str, i); - while (i < str.size()) { - if (!entries.empty()) { - string_util_internal::ConsumeLetter(str, i, delimiter); - } - - entries.emplace_back(string_util_internal::TakePossiblyQuotedItem(str, i, delimiter, quote)); - string_util_internal::SkipSpaces(str, i); - } - - return entries; -} - -vector StringUtil::SplitWithParentheses(const string &str, char delimiter, char par_open, char par_close) { - vector result; - string current; - stack parentheses; - - for (size_t i = 0; i < str.size(); ++i) { - char ch = str[i]; - - // stack to keep track if we are within parentheses - if (ch == par_open) { - parentheses.push(ch); - } - if (ch == par_close) { - if (!parentheses.empty()) { - parentheses.pop(); - } else { - throw InvalidInputException("Incongruent parentheses in string: '%s'", str); - } - } - // split if not within parentheses - if (parentheses.empty() && ch == delimiter) { - result.push_back(current); - current.clear(); - } else { - current += ch; - } - } - // Add the last segment - if (!current.empty()) { - result.push_back(current); - } - if (!parentheses.empty()) { - throw InvalidInputException("Incongruent parentheses in string: '%s'", str); - } - return result; -} - -string StringUtil::Join(const vector &input, const string &separator) { - return StringUtil::Join(input, input.size(), separator, [](const string &s) { return s; }); -} - -string StringUtil::Join(const set &input, const string &separator) { - // The result 
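// Editorial annotation (not present in the original source): the explicit iterator loop
// below appends the separator only between elements; e.g. joining the set {"a", "b", "c"}
// with ", " yields "a, b, c".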
- std::string result; - - auto it = input.begin(); - while (it != input.end()) { - result += *it; - it++; - if (it == input.end()) { - break; - } - result += separator; - } - return result; -} - -string StringUtil::BytesToHumanReadableString(idx_t bytes, idx_t multiplier) { - D_ASSERT(multiplier == 1000 || multiplier == 1024); - idx_t array[6] = {}; - const char *unit[2][6] = {{"bytes", "KiB", "MiB", "GiB", "TiB", "PiB"}, {"bytes", "kB", "MB", "GB", "TB", "PB"}}; - - const int sel = (multiplier == 1000); - - array[0] = bytes; - for (idx_t i = 1; i < 6; i++) { - array[i] = array[i - 1] / multiplier; - array[i - 1] %= multiplier; - } - - for (idx_t i = 5; i >= 1; i--) { - if (array[i]) { - // Map 0 -> 0 and (multiplier-1) -> 9 - idx_t fractional_part = (array[i - 1] * 10) / multiplier; - return to_string(array[i]) + "." + to_string(fractional_part) + " " + unit[sel][i]; - } - } - - return to_string(array[0]) + (bytes == 1 ? " byte" : " bytes"); -} - -string StringUtil::Upper(const string &str) { - string copy(str); - transform(copy.begin(), copy.end(), copy.begin(), [](unsigned char c) { return std::toupper(c); }); - return (copy); -} - -string StringUtil::Lower(const string &str) { - string copy(str); - transform(copy.begin(), copy.end(), copy.begin(), - [](unsigned char c) { return StringUtil::CharacterToLower(static_cast(c)); }); - return (copy); -} - -string StringUtil::Title(const string &str) { - string copy; - bool first_character = true; - for (auto c : str) { - bool is_alpha = StringUtil::CharacterIsAlpha(c); - if (is_alpha) { - if (first_character) { - copy += StringUtil::CharacterToUpper(c); - first_character = false; - } else { - copy += StringUtil::CharacterToLower(c); - } - } else { - first_character = true; - copy += c; - } - } - return copy; -} - -bool StringUtil::IsLower(const string &str) { - return str == Lower(str); -} - -bool StringUtil::IsUpper(const string &str) { - return str == Upper(str); -} - -// Jenkins hash function: https://en.wikipedia.org/wiki/Jenkins_hash_function -uint64_t StringUtil::CIHash(const string &str) { - uint32_t hash = 0; - for (auto c : str) { - hash += static_cast(StringUtil::CharacterToLower(static_cast(c))); - hash += hash << 10; - hash ^= hash >> 6; - } - hash += hash << 3; - hash ^= hash >> 11; - hash += hash << 15; - return hash; -} - -bool StringUtil::CIEquals(const string &l1, const string &l2) { - if (l1.size() != l2.size()) { - return false; - } - const auto charmap = ASCII_TO_LOWER_MAP; - for (idx_t c = 0; c < l1.size(); c++) { - if (charmap[(uint8_t)l1[c]] != charmap[(uint8_t)l2[c]]) { - return false; - } - } - return true; -} - -bool StringUtil::CILessThan(const string &s1, const string &s2) { - const auto charmap = ASCII_TO_UPPER_MAP; - - unsigned char u1 {}, u2 {}; - - idx_t length = MinValue(s1.length(), s2.length()); - length += s1.length() != s2.length(); - for (idx_t i = 0; i < length; i++) { - u1 = (unsigned char)s1[i]; - u2 = (unsigned char)s2[i]; - if (charmap[u1] != charmap[u2]) { - break; - } - } - return (charmap[u1] - charmap[u2]) < 0; -} - -idx_t StringUtil::CIFind(vector &vector, const string &search_string) { - for (idx_t i = 0; i < vector.size(); i++) { - const auto &string = vector[i]; - if (CIEquals(string, search_string)) { - return i; - } - } - return DConstants::INVALID_INDEX; -} - -vector StringUtil::Split(const string &str, char delimiter) { - std::stringstream ss(str); - vector lines; - string temp; - while (getline(ss, temp, delimiter)) { - lines.push_back(temp); - } - return (lines); -} - -vector 
StringUtil::Split(const string &input, const string &split) { - vector splits; - - idx_t last = 0; - idx_t input_len = input.size(); - idx_t split_len = split.size(); - while (last <= input_len) { - idx_t next = input.find(split, last); - if (next == string::npos) { - next = input_len; - } - - // Push the substring [last, next) on to splits - string substr = input.substr(last, next - last); - if (!substr.empty()) { - splits.push_back(substr); - } - last = next + split_len; - } - if (splits.empty()) { - splits.push_back(input); - } - return splits; -} - -string StringUtil::Replace(string source, const string &from, const string &to) { - if (from.empty()) { - throw InternalException("Invalid argument to StringUtil::Replace - empty FROM"); - } - idx_t start_pos = 0; - while ((start_pos = source.find(from, start_pos)) != string::npos) { - source.replace(start_pos, from.length(), to); - start_pos += to.length(); // In case 'to' contains 'from', like - // replacing 'x' with 'yx' - } - return source; -} - -vector StringUtil::TopNStrings(vector> scores, idx_t n, double threshold) { - if (scores.empty()) { - return vector(); - } - sort(scores.begin(), scores.end(), [](const pair &a, const pair &b) -> bool { - return a.second > b.second || (a.second == b.second && a.first.size() < b.first.size()); - }); - vector result; - result.push_back(scores[0].first); - for (idx_t i = 1; i < MinValue(scores.size(), n); i++) { - if (scores[i].second < threshold) { - break; - } - result.push_back(scores[i].first); - } - return result; -} - -static double NormalizeScore(idx_t score, idx_t max_score) { - return 1.0 - static_cast(score) / static_cast(max_score); -} - -vector StringUtil::TopNStrings(const vector> &scores, idx_t n, idx_t threshold) { - // obtain the max score to normalize - idx_t max_score = threshold; - for (auto &score : scores) { - if (score.second > max_score) { - max_score = score.second; - } - } - - // normalize - vector> normalized_scores; - for (auto &score : scores) { - normalized_scores.push_back(make_pair(score.first, NormalizeScore(score.second, max_score))); - } - return TopNStrings(std::move(normalized_scores), n, NormalizeScore(threshold, max_score)); -} - -struct LevenshteinArray { - LevenshteinArray(idx_t len1, idx_t len2) : len1(len1) { - dist = make_unsafe_uniq_array(len1 * len2); - } - - idx_t &Score(idx_t i, idx_t j) { - return dist[GetIndex(i, j)]; - } - -private: - idx_t len1; - unsafe_unique_array dist; - - idx_t GetIndex(idx_t i, idx_t j) { - return j * len1 + i; - } -}; - -// adapted from https://en.wikibooks.org/wiki/Algorithm_Implementation/Strings/Levenshtein_distance#C++ -idx_t StringUtil::LevenshteinDistance(const string &s1_p, const string &s2_p, idx_t not_equal_penalty) { - auto s1 = StringUtil::Lower(s1_p); - auto s2 = StringUtil::Lower(s2_p); - idx_t len1 = s1.size(); - idx_t len2 = s2.size(); - if (len1 == 0) { - return len2; - } - if (len2 == 0) { - return len1; - } - LevenshteinArray array(len1 + 1, len2 + 1); - array.Score(0, 0) = 0; - for (idx_t i = 0; i <= len1; i++) { - array.Score(i, 0) = i; - } - for (idx_t j = 0; j <= len2; j++) { - array.Score(0, j) = j; - } - for (idx_t i = 1; i <= len1; i++) { - for (idx_t j = 1; j <= len2; j++) { - // d[i][j] = std::min({ d[i - 1][j] + 1, - // d[i][j - 1] + 1, - // d[i - 1][j - 1] + (s1[i - 1] == s2[j - 1] ? 0 : 1) }); - auto equal = s1[i - 1] == s2[j - 1] ? 
0 : not_equal_penalty; - idx_t adjacent_score1 = array.Score(i - 1, j) + 1; - idx_t adjacent_score2 = array.Score(i, j - 1) + 1; - idx_t adjacent_score3 = array.Score(i - 1, j - 1) + equal; - - idx_t t = MinValue(adjacent_score1, adjacent_score2); - array.Score(i, j) = MinValue(t, adjacent_score3); - } - } - return array.Score(len1, len2); -} - -idx_t StringUtil::SimilarityScore(const string &s1, const string &s2) { - return LevenshteinDistance(s1, s2, 3); -} - -double StringUtil::SimilarityRating(const string &s1, const string &s2) { - return duckdb_jaro_winkler::jaro_winkler_similarity(s1.data(), s1.data() + s1.size(), s2.data(), - s2.data() + s2.size()); -} - -vector StringUtil::TopNLevenshtein(const vector &strings, const string &target, idx_t n, - idx_t threshold) { - vector> scores; - scores.reserve(strings.size()); - for (auto &str : strings) { - if (target.size() < str.size()) { - scores.emplace_back(str, SimilarityScore(str.substr(0, target.size()), target)); - } else { - scores.emplace_back(str, SimilarityScore(str, target)); - } - } - return TopNStrings(scores, n, threshold); -} - -vector StringUtil::TopNJaroWinkler(const vector &strings, const string &target, idx_t n, - double threshold) { - vector> scores; - scores.reserve(strings.size()); - for (auto &str : strings) { - scores.emplace_back(str, SimilarityRating(str, target)); - } - return TopNStrings(scores, n, threshold); -} - -string StringUtil::CandidatesMessage(const vector &candidates, const string &candidate) { - string result_str; - if (!candidates.empty()) { - result_str = "\n" + candidate + ": "; - for (idx_t i = 0; i < candidates.size(); i++) { - if (i > 0) { - result_str += ", "; - } - result_str += "\"" + candidates[i] + "\""; - } - } - return result_str; -} - -string StringUtil::CandidatesErrorMessage(const vector &strings, const string &target, - const string &message_prefix, idx_t n) { - auto closest_strings = StringUtil::TopNLevenshtein(strings, target, n); - return StringUtil::CandidatesMessage(closest_strings, message_prefix); -} - -unordered_map StringUtil::ParseJSONMap(const string &json) { - unordered_map result; - if (json.empty()) { - return result; - } - yyjson_read_flag flags = YYJSON_READ_ALLOW_INVALID_UNICODE; - yyjson_doc *doc = yyjson_read(json.c_str(), json.size(), flags); - if (!doc) { - throw SerializationException("Failed to parse JSON string: %s", json); - } - yyjson_val *root = yyjson_doc_get_root(doc); - if (!root || yyjson_get_type(root) != YYJSON_TYPE_OBJ) { - yyjson_doc_free(doc); - throw SerializationException("Failed to parse JSON string: %s", json); - } - yyjson_obj_iter iter; - yyjson_obj_iter_init(root, &iter); - yyjson_val *key, *value; - while ((key = yyjson_obj_iter_next(&iter))) { - value = yyjson_obj_iter_get_val(key); - if (yyjson_get_type(value) != YYJSON_TYPE_STR) { - yyjson_doc_free(doc); - throw SerializationException("Failed to parse JSON string: %s", json); - } - auto key_val = yyjson_get_str(key); - auto key_len = yyjson_get_len(key); - auto value_val = yyjson_get_str(value); - auto value_len = yyjson_get_len(value); - result.emplace(string(key_val, key_len), string(value_val, value_len)); - } - yyjson_doc_free(doc); - return result; -} - -string ToJsonMapInternal(const unordered_map &map, yyjson_mut_doc *doc, yyjson_mut_val *root) { - for (auto &entry : map) { - auto key = yyjson_mut_strncpy(doc, entry.first.c_str(), entry.first.size()); - auto value = yyjson_mut_strncpy(doc, entry.second.c_str(), entry.second.size()); - yyjson_mut_obj_add(root, key, value); - } - 
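// Editorial annotation (not present in the original source): yyjson_mut_strncpy copies both
// key and value bytes into the document's own arena above, so the source map may go out of
// scope before the write below. YYJSON_WRITE_ALLOW_INVALID_UNICODE mirrors the
// YYJSON_READ_ALLOW_INVALID_UNICODE flag used by ParseJSONMap earlier, so maps containing
// non-UTF8 bytes round-trip instead of failing to serialize.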
yyjson_write_err err; - size_t len; - constexpr yyjson_write_flag flags = YYJSON_WRITE_ALLOW_INVALID_UNICODE; - char *json = yyjson_mut_write_opts(doc, flags, nullptr, &len, &err); - if (!json) { - yyjson_mut_doc_free(doc); - throw SerializationException("Failed to write JSON string: %s", err.msg); - } - // Create a string from the JSON - string result(json, len); - - // Free the JSON and the document - free(json); - yyjson_mut_doc_free(doc); - - // Return the result - return result; -} -string StringUtil::ToJSONMap(const unordered_map<string, string> &map) { - yyjson_mut_doc *doc = yyjson_mut_doc_new(nullptr); - yyjson_mut_val *root = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, root); - - return ToJsonMapInternal(map, doc, root); -} - -string StringUtil::ExceptionToJSONMap(ExceptionType type, const string &message, - const unordered_map<string, string> &map) { - D_ASSERT(map.find("exception_type") == map.end()); - D_ASSERT(map.find("exception_message") == map.end()); - - yyjson_mut_doc *doc = yyjson_mut_doc_new(nullptr); - yyjson_mut_val *root = yyjson_mut_obj(doc); - yyjson_mut_doc_set_root(doc, root); - - auto except_str = Exception::ExceptionTypeToString(type); - yyjson_mut_obj_add_strncpy(doc, root, "exception_type", except_str.c_str(), except_str.size()); - yyjson_mut_obj_add_strncpy(doc, root, "exception_message", message.c_str(), message.size()); - - return ToJsonMapInternal(map, doc, root); -} - -string StringUtil::GetFileName(const string &file_path) { - - idx_t pos = file_path.find_last_of("/\\"); - if (pos == string::npos) { - return file_path; - } - auto end = file_path.size() - 1; - - // If the rest of the string is just slashes or dots, trim them - if (file_path.find_first_not_of("/\\.", pos) == string::npos) { - // Trim the trailing slashes and dots - while (end > 0 && (file_path[end] == '/' || file_path[end] == '.' || file_path[end] == '\\')) { - end--; - } - - // Now find the next slash - pos = file_path.find_last_of("/\\", end); - if (pos == string::npos) { - return file_path.substr(0, end + 1); - } - } - - return file_path.substr(pos + 1, end - pos); -} - -string StringUtil::GetFileExtension(const string &file_name) { - auto name = GetFileName(file_name); - idx_t pos = name.find_last_of('.'); - // We don't consider e.g. 
`.gitignore` to have an extension - if (pos == string::npos || pos == 0) { - return ""; - } - return name.substr(pos + 1); -} - -string StringUtil::GetFileStem(const string &file_name) { - auto name = GetFileName(file_name); - if (name.size() > 1 && name[0] == '.') { - return name; - } - idx_t pos = name.find_last_of('.'); - if (pos == string::npos) { - return name; - } - return name.substr(0, pos); -} - -string StringUtil::GetFilePath(const string &file_path) { - // Trim the trailing slashes - auto end = file_path.size() - 1; - while (end > 0 && (file_path[end] == '/' || file_path[end] == '\\')) { - end--; - } - - auto pos = file_path.find_last_of("/\\", end); - if (pos == string::npos) { - return ""; - } - - while (pos > 0 && (file_path[pos] == '/' || file_path[pos] == '\\')) { - pos--; - } - - return file_path.substr(0, pos + 1); -} - -struct URLEncodeLength { - using RESULT_TYPE = idx_t; - - static void ProcessCharacter(idx_t &result, char) { - result++; - } - - static void ProcessHex(idx_t &result, const char *, idx_t) { - result++; - } -}; - -struct URLEncodeWrite { - using RESULT_TYPE = char *; - - static void ProcessCharacter(char *&result, char c) { - *result = c; - result++; - } - - static void ProcessHex(char *&result, const char *input, idx_t idx) { - uint32_t hex_first = StringUtil::GetHexValue(input[idx + 1]); - uint32_t hex_second = StringUtil::GetHexValue(input[idx + 2]); - uint32_t hex_value = (hex_first << 4) + hex_second; - ProcessCharacter(result, static_cast(hex_value)); - } -}; - -template -void URLEncodeInternal(const char *input, idx_t input_size, typename OP::RESULT_TYPE &result, bool encode_slash) { - // https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-query-string-auth.html - static const char *HEX_DIGIT = "0123456789ABCDEF"; - for (idx_t i = 0; i < input_size; i++) { - char ch = input[i]; - if ((ch >= 'A' && ch <= 'Z') || (ch >= 'a' && ch <= 'z') || (ch >= '0' && ch <= '9') || ch == '_' || - ch == '-' || ch == '~' || ch == '.') { - OP::ProcessCharacter(result, ch); - } else if (ch == '/' && !encode_slash) { - OP::ProcessCharacter(result, ch); - } else { - OP::ProcessCharacter(result, '%'); - OP::ProcessCharacter(result, HEX_DIGIT[static_cast(ch) >> 4]); - OP::ProcessCharacter(result, HEX_DIGIT[static_cast(ch) & 15]); - } - } -} - -idx_t StringUtil::URLEncodeSize(const char *input, idx_t input_size, bool encode_slash) { - idx_t result_length = 0; - URLEncodeInternal(input, input_size, result_length, encode_slash); - return result_length; -} - -void StringUtil::URLEncodeBuffer(const char *input, idx_t input_size, char *output, bool encode_slash) { - URLEncodeInternal(input, input_size, output, encode_slash); -} - -string StringUtil::URLEncode(const string &input, bool encode_slash) { - idx_t result_size = URLEncodeSize(input.c_str(), input.size(), encode_slash); - auto result_data = make_uniq_array(result_size); - URLEncodeBuffer(input.c_str(), input.size(), result_data.get(), encode_slash); - return string(result_data.get(), result_size); -} - -template -void URLDecodeInternal(const char *input, idx_t input_size, typename OP::RESULT_TYPE &result, bool plus_to_space) { - for (idx_t i = 0; i < input_size; i++) { - char ch = input[i]; - if (plus_to_space && ch == '+') { - OP::ProcessCharacter(result, ' '); - } else if (ch == '%' && i + 2 < input_size && StringUtil::CharacterIsHex(input[i + 1]) && - StringUtil::CharacterIsHex(input[i + 2])) { - OP::ProcessHex(result, input, i); - i += 2; - } else { - OP::ProcessCharacter(result, ch); - } - } -} - -idx_t 
StringUtil::URLDecodeSize(const char *input, idx_t input_size, bool plus_to_space) { - idx_t result_length = 0; - URLDecodeInternal(input, input_size, result_length, plus_to_space); - return result_length; -} - -void StringUtil::URLDecodeBuffer(const char *input, idx_t input_size, char *output, bool plus_to_space) { - char *output_start = output; - URLDecodeInternal(input, input_size, output, plus_to_space); - if (!Utf8Proc::IsValid(output_start, NumericCast(output - output_start))) { - throw InvalidInputException("Failed to decode string \"%s\" using URL decoding - decoded value is invalid UTF8", - string(input, input_size)); - } -} - -string StringUtil::URLDecode(const string &input, bool plus_to_space) { - idx_t result_size = URLDecodeSize(input.c_str(), input.size(), plus_to_space); - auto result_data = make_uniq_array(result_size); - URLDecodeBuffer(input.c_str(), input.size(), result_data.get(), plus_to_space); - return string(result_data.get(), result_size); -} - -uint32_t StringUtil::StringToEnum(const EnumStringLiteral enum_list[], idx_t enum_count, const char *enum_name, - const char *str_value) { - for (idx_t i = 0; i < enum_count; i++) { - if (CIEquals(enum_list[i].string, str_value)) { - return enum_list[i].number; - } - } - // string to enum conversion failed - generate candidates - vector candidates; - for (idx_t i = 0; i < enum_count; i++) { - candidates.push_back(enum_list[i].string); - } - auto closest_values = TopNJaroWinkler(candidates, str_value); - auto message = CandidatesMessage(closest_values, "Candidates"); - throw NotImplementedException("Enum value: unrecognized value \"%s\" for enum \"%s\"\n%s", str_value, enum_name, - message); -} - -const char *StringUtil::EnumToString(const EnumStringLiteral enum_list[], idx_t enum_count, const char *enum_name, - uint32_t enum_value) { - for (idx_t i = 0; i < enum_count; i++) { - if (enum_list[i].number == enum_value) { - return enum_list[i].string; - } - } - throw NotImplementedException("Enum value: unrecognized enum value \"%d\" for enum \"%s\"", enum_value, enum_name); -} - -const uint8_t StringUtil::ASCII_TO_UPPER_MAP[] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, - 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, - 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, - 88, 89, 90, 91, 92, 93, 94, 95, 96, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, - 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 123, 124, 125, 126, 127, 128, 129, 130, 131, - 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, - 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, - 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, - 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, - 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239, 240, 241, - 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255}; - -const uint8_t StringUtil::ASCII_TO_LOWER_MAP[] = { - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, - 44, 45, 46, 47, 48, 49, 
-const uint8_t StringUtil::ASCII_TO_UPPER_MAP[] = {
-    0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,
-    16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,
-    32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
-    48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
-    64,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
-    80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,
-    96,  65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,  79,
-    80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  123, 124, 125, 126, 127,
-    128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
-    144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
-    160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
-    176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
-    192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
-    208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
-    224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
-    240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255};
-
-const uint8_t StringUtil::ASCII_TO_LOWER_MAP[] = {
-    0,   1,   2,   3,   4,   5,   6,   7,   8,   9,   10,  11,  12,  13,  14,  15,
-    16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,
-    32,  33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47,
-    48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,
-    64,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
-    112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 91,  92,  93,  94,  95,
-    96,  97,  98,  99,  100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
-    112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
-    128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
-    144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
-    160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
-    176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
-    192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207,
-    208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 222, 223,
-    224, 225, 226, 227, 228, 229, 230, 231, 232, 233, 234, 235, 236, 237, 238, 239,
-    240, 241, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 252, 253, 254, 255};
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer.cpp
deleted file mode 100644
index c924ff260..000000000
--- a/src/duckdb/src/common/tree_renderer.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#include "duckdb/common/tree_renderer.hpp"
-#include "duckdb/common/tree_renderer/text_tree_renderer.hpp"
-#include "duckdb/common/tree_renderer/json_tree_renderer.hpp"
-#include "duckdb/common/tree_renderer/html_tree_renderer.hpp"
-#include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp"
-
-// (a system #include whose header name was lost in text conversion)
-
-namespace duckdb {
-
-unique_ptr<TreeRenderer> TreeRenderer::CreateRenderer(ExplainFormat format) {
-	switch (format) {
-	case ExplainFormat::DEFAULT:
-	case ExplainFormat::TEXT:
-		return make_uniq<TextTreeRenderer>();
-	case ExplainFormat::JSON:
-		return make_uniq<JSONTreeRenderer>();
-	case ExplainFormat::HTML:
-		return make_uniq<HTMLTreeRenderer>();
-	case ExplainFormat::GRAPHVIZ:
-		return make_uniq<GRAPHVIZTreeRenderer>();
-	default:
-		throw NotImplementedException("ExplainFormat %s not implemented", EnumUtil::ToString(format));
-	}
-}
-
-} // namespace duckdb
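// Illustrative sketch (added for clarity, not part of the deleted file): the
// CreateRenderer factory above picks a concrete renderer for an EXPLAIN format.
// Assumes a RenderTree built elsewhere via RenderTree::CreateRenderTree, and
// that TreeRenderer::ToStream is accessible as shown in tree_renderer.cpp below.
#include "duckdb/common/tree_renderer.hpp"
#include <sstream>

static duckdb::string RenderPlanAs(duckdb::RenderTree &tree, duckdb::ExplainFormat format) {
	auto renderer = duckdb::TreeRenderer::CreateRenderer(format);
	std::stringstream ss;
	renderer->ToStream(tree, ss); // sanitizes key names first unless the renderer uses raw keys
	return ss.str();
}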
diff --git a/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp
deleted file mode 100644
index 40a93b692..000000000
--- a/src/duckdb/src/common/tree_renderer/graphviz_tree_renderer.cpp
+++ /dev/null
@@ -1,108 +0,0 @@
-#include "duckdb/common/tree_renderer/graphviz_tree_renderer.hpp"
-
-#include "duckdb/common/pair.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
-#include "duckdb/execution/operator/join/physical_delim_join.hpp"
-#include "duckdb/execution/operator/scan/physical_positional_scan.hpp"
-#include "duckdb/execution/physical_operator.hpp"
-#include "duckdb/parallel/pipeline.hpp"
-#include "duckdb/planner/logical_operator.hpp"
-#include "duckdb/main/query_profiler.hpp"
-#include "utf8proc_wrapper.hpp"
-
-#include <sstream>
-
-namespace duckdb {
-
-string GRAPHVIZTreeRenderer::ToString(const LogicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string GRAPHVIZTreeRenderer::ToString(const PhysicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string GRAPHVIZTreeRenderer::ToString(const ProfilingNode &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string GRAPHVIZTreeRenderer::ToString(const Pipeline &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-void GRAPHVIZTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void GRAPHVIZTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void GRAPHVIZTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void GRAPHVIZTreeRenderer::Render(const Pipeline &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void GRAPHVIZTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) {
-	const string digraph_format = R"(
-digraph G {
-	node [shape=box, style=rounded, fontname="Courier New", fontsize=10];
-%s
-%s
-}
-	)";
-
-	vector<string> nodes;
-	vector<string> edges;
-
-	const string node_format = R"(	node_%d_%d [label="%s"];)";
-
-	for (idx_t y = 0; y < root.height; y++) {
-		for (idx_t x = 0; x < root.width; x++) {
-			auto node = root.GetNode(x, y);
-			if (!node) {
-				continue;
-			}
-
-			// Create Node
-			vector<string> body;
-			body.push_back(node->name);
-			for (auto &item : node->extra_text) {
-				auto &key = item.first;
-				auto &value_raw = item.second;
-
-				auto value = QueryProfiler::JSONSanitize(value_raw);
-				body.push_back(StringUtil::Format("%s:\\n%s", key, value));
-			}
-			nodes.push_back(StringUtil::Format(node_format, x, y, StringUtil::Join(body, "\\n───\\n")));
-
-			// Create Edge(s)
-			for (auto &coord : node->child_positions) {
-				edges.push_back(StringUtil::Format("	node_%d_%d -> node_%d_%d;", x, y, coord.x, coord.y));
-			}
-		}
-	}
-	auto node_lines = StringUtil::Join(nodes, "\n");
-	auto edge_lines = StringUtil::Join(edges, "\n");
-
-	string result = StringUtil::Format(digraph_format, node_lines, edge_lines);
-	ss << result;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp
deleted file mode 100644
index 66a8f2d1d..000000000
--- a/src/duckdb/src/common/tree_renderer/html_tree_renderer.cpp
+++ /dev/null
@@ -1,267 +0,0 @@
-#include "duckdb/common/tree_renderer/html_tree_renderer.hpp"
-
-#include "duckdb/common/pair.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
-#include "duckdb/execution/operator/join/physical_delim_join.hpp"
-#include "duckdb/execution/operator/scan/physical_positional_scan.hpp"
-#include "duckdb/execution/physical_operator.hpp"
-#include "duckdb/parallel/pipeline.hpp"
-#include "duckdb/planner/logical_operator.hpp"
-#include "utf8proc_wrapper.hpp"
-
-#include <sstream>
-
-namespace duckdb {
-
-string HTMLTreeRenderer::ToString(const LogicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string HTMLTreeRenderer::ToString(const PhysicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string HTMLTreeRenderer::ToString(const ProfilingNode &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string HTMLTreeRenderer::ToString(const Pipeline &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-void HTMLTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void HTMLTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void HTMLTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void HTMLTreeRenderer::Render(const Pipeline &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-static string CreateStyleSection(RenderTree &root) {
-	// (the inline CSS returned here was lost in text conversion)
-	return R"()";
-}
-
-static string CreateHeadSection(RenderTree &root) {
-	// (most of the HTML markup in this template was lost in text conversion;
-	// the surviving fragments are the page title and the %s style placeholder)
-	string head_section = R"(
-	DuckDB Query Plan
-	%s
-	)";
-	return StringUtil::Format(head_section, CreateStyleSection(root));
-}
-
-static string CreateGridItemContent(RenderTreeNode &node) {
-	// (HTML wrapper markup lost in text conversion; the %s placeholder survives)
-	const string content_format = R"(
-%s
-	)";
-
-	vector<string> items;
-	for (auto &item : node.extra_text) {
-		auto &key = item.first;
-		auto &value = item.second;
-		if (value.empty()) {
-			continue;
-		}
-		items.push_back(StringUtil::Format(R"(%s)", key));
-		auto splits = StringUtil::Split(value, "\n");
-		for (auto &split : splits) {
-			items.push_back(StringUtil::Format(R"(%s)", split));
-		}
-	}
-	string result;
-	if (!items.empty()) {
-		result = StringUtil::Format(content_format, StringUtil::Join(items, "\n"));
-	}
-	if (!node.child_positions.empty()) {
-		result += ""; // (expand-arrow markup lost in text conversion)
-	}
-	return result;
-}
-
-static string CreateGridItem(RenderTree &root, idx_t x, idx_t y) {
-	// (HTML wrapper markup lost in text conversion; the title and content
-	// %s placeholders survive)
-	const string grid_item_format = R"(
-%s
-%s
-	)";
-
-	auto node = root.GetNode(x, y);
-	if (!node) {
-		return "";
-	}
-
-	auto title = node->name;
-	auto content = CreateGridItemContent(*node);
-	return StringUtil::Format(grid_item_format, title, content);
-}
-
-static string CreateTreeRecursive(RenderTree &root, idx_t x, idx_t y) {
-	string result;
-
-	// (the list markup literals concatenated in this function were lost in text
-	// conversion; the placeholders below are left empty)
-	result += "";
-	result += CreateGridItem(root, x, y);
-	auto node = root.GetNode(x, y);
-	if (!node->child_positions.empty()) {
-		result += "";
-		for (auto &coord : node->child_positions) {
-			result += CreateTreeRecursive(root, coord.x, coord.y);
-		}
-		result += "";
-	}
-	result += "";
-	return result;
-}
-
-static string CreateBodySection(RenderTree &root) {
-	// (HTML body markup lost in text conversion; the %s placeholder that
-	// receives the rendered tree survives)
-	const string body_section = R"(
-%s
-	)";
-	return StringUtil::Format(body_section, CreateTreeRecursive(root, 0, 0));
-}
-
-void HTMLTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) {
-	string result;
-	result += CreateHeadSection(root);
-	result += CreateBodySection(root);
-	ss << result;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp
deleted file mode 100644
index edc9309be..000000000
--- a/src/duckdb/src/common/tree_renderer/json_tree_renderer.cpp
+++ /dev/null
@@ -1,116 +0,0 @@
-#include "duckdb/common/tree_renderer/json_tree_renderer.hpp"
-
-#include "duckdb/common/pair.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp"
-#include "duckdb/execution/operator/join/physical_delim_join.hpp"
-#include "duckdb/execution/operator/scan/physical_positional_scan.hpp"
-#include "duckdb/execution/physical_operator.hpp"
-#include "duckdb/parallel/pipeline.hpp"
-#include "duckdb/planner/logical_operator.hpp"
-#include "utf8proc_wrapper.hpp"
-
-#include "yyjson.hpp"
-
-#include <sstream>
-
-using namespace duckdb_yyjson; // NOLINT
-
-namespace duckdb {
-
-string JSONTreeRenderer::ToString(const LogicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string JSONTreeRenderer::ToString(const PhysicalOperator &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string JSONTreeRenderer::ToString(const ProfilingNode &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-string JSONTreeRenderer::ToString(const Pipeline &op) {
-	std::stringstream ss;
-	Render(op, ss);
-	return ss.str();
-}
-
-void JSONTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void JSONTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void JSONTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void JSONTreeRenderer::Render(const Pipeline &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-static yyjson_mut_val *RenderRecursive(yyjson_mut_doc *doc, RenderTree &tree, idx_t x, idx_t y) {
-	auto node_p = tree.GetNode(x, y);
-	D_ASSERT(node_p);
-	auto &node = *node_p;
-
-	auto object = yyjson_mut_obj(doc);
-	auto children = yyjson_mut_arr(doc);
-	for (auto &child_pos : node.child_positions) {
-		auto child_object = RenderRecursive(doc, tree, child_pos.x, child_pos.y);
-		yyjson_mut_arr_append(children, child_object);
-	}
-	yyjson_mut_obj_add_str(doc, object, "name", node.name.c_str());
-	yyjson_mut_obj_add_val(doc, object, "children", children);
-	auto extra_info = yyjson_mut_obj(doc);
-	for (auto &it : node.extra_text) {
-		auto &key = it.first;
-		auto &value = it.second;
-		auto splits = StringUtil::Split(value, "\n");
-		if (splits.size() > 1) {
-			auto list_items = yyjson_mut_arr(doc);
-			for (auto &split : splits) {
-				yyjson_mut_arr_add_strcpy(doc, list_items, split.c_str());
-			}
-			yyjson_mut_obj_add_val(doc, extra_info, key.c_str(), list_items);
-		} else {
-			yyjson_mut_obj_add_strcpy(doc, extra_info, key.c_str(), value.c_str());
-		}
-	}
-	yyjson_mut_obj_add_val(doc, object, "extra_info", extra_info);
-	return object;
-}
-
-void JSONTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) {
-	auto doc = yyjson_mut_doc_new(nullptr);
-	auto result_obj = yyjson_mut_arr(doc);
-	yyjson_mut_doc_set_root(doc, result_obj);
-
-	auto plan = RenderRecursive(doc, root, 0, 0);
-	yyjson_mut_arr_append(result_obj, plan);
-
-	auto data = yyjson_mut_val_write_opts(result_obj, YYJSON_WRITE_ALLOW_INF_AND_NAN | YYJSON_WRITE_PRETTY, nullptr,
-	                                      nullptr, nullptr);
-	if (!data) {
-		yyjson_mut_doc_free(doc);
-		throw InternalException("The plan could not be rendered as JSON, yyjson failed");
-	}
-	ss << string(data);
-	free(data);
-	yyjson_mut_doc_free(doc);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp
deleted file mode 100644
index fa24ee103..000000000
--- a/src/duckdb/src/common/tree_renderer/text_tree_renderer.cpp
+++ /dev/null
@@ -1,494 +0,0 @@
-#include "duckdb/common/tree_renderer/text_tree_renderer.hpp"
-
-#include "duckdb/common/pair.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/execution/physical_operator.hpp"
-#include "duckdb/parallel/pipeline.hpp"
-#include "duckdb/planner/logical_operator.hpp"
-#include "utf8proc_wrapper.hpp"
-#include "duckdb/common/typedefs.hpp"
-
-#include <sstream>
-
-namespace duckdb {
-
-namespace {
-
-struct StringSegment {
-public:
-	StringSegment(idx_t start, idx_t width) : start(start), width(width) {
-	}
-
-public:
-	idx_t start;
-	idx_t width;
-};
-
-} // namespace
-
-void TextTreeRenderer::RenderTopLayer(RenderTree &root, std::ostream &ss, idx_t y) {
-	for (idx_t x = 0; x < root.width; x++) {
-		if (x * config.node_render_width >= config.maximum_render_width) {
-			break;
-		}
-		if (root.HasNode(x, y)) {
-			ss << config.LTCORNER;
-			ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1);
-			if (y == 0) {
-				// top level node: no node above this one
-				ss << config.HORIZONTAL;
-			} else {
-				// render connection to node above this one
-				ss << config.DMIDDLE;
-			}
-			ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1);
-			ss << config.RTCORNER;
-		} else {
-			bool has_adjacent_nodes = false;
-			for (idx_t i = 0; x + i < root.width; i++) {
-				has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y);
-			}
-			if (!has_adjacent_nodes) {
-				// There are no nodes to the right side of this position
-				// no need to fill the empty space
-				continue;
-			}
-			// there are nodes next to this, fill the space
-			ss << StringUtil::Repeat(" ", config.node_render_width);
-		}
-	}
-	ss << '\n';
-}
-
-static bool NodeHasMultipleChildren(RenderTreeNode &node) {
-	return node.child_positions.size() > 1;
-}
-
-static bool ShouldRenderWhitespace(RenderTree &root, idx_t x, idx_t y) {
-	idx_t found_children = 0;
-	for (;; x--) {
-		auto node = root.GetNode(x, y);
-		if (root.HasNode(x, y + 1)) {
-			found_children++;
-		}
-		if (node) {
-			if (NodeHasMultipleChildren(*node)) {
-				if (found_children < node->child_positions.size()) {
-					return true;
-				}
-			}
-			return false;
-		}
-		if (x == 0) {
-			break;
-		}
-	}
-	return false;
-}
-
-void TextTreeRenderer::RenderBottomLayer(RenderTree &root, std::ostream &ss, idx_t y) {
-	for (idx_t x = 0; x <= root.width; x++) {
-		if (x * config.node_render_width >= config.maximum_render_width) {
-			break;
-		}
-		bool has_adjacent_nodes = false;
-		for (idx_t i = 0; x + i < root.width; i++) {
-			has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y);
-		}
-		auto node = root.GetNode(x, y);
-		if (node) {
-			ss << config.LDCORNER;
-			ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1);
-			if (root.HasNode(x, y + 1)) {
-				// node below this one: connect to that one
-				ss << config.TMIDDLE;
-			} else {
-				// no node below this one: end the box
-				ss << config.HORIZONTAL;
-			}
-			ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2 - 1);
-			ss << config.RDCORNER;
-		} else if (root.HasNode(x, y + 1)) {
-			ss << StringUtil::Repeat(" ", config.node_render_width / 2);
-			ss << config.VERTICAL;
-			if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) {
-				ss << StringUtil::Repeat(" ", config.node_render_width / 2);
-			}
-		} else {
-			if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) {
-				ss << StringUtil::Repeat(" ", config.node_render_width);
-			}
-		}
-	}
-	ss << '\n';
-}
-
-string AdjustTextForRendering(string source, idx_t max_render_width) {
-	const idx_t size = source.size();
-	const char *input = source.c_str();
-
-	idx_t render_width = 0;
-
-	// For every character in the input, create a StringSegment
-	vector<StringSegment> render_widths;
-	idx_t current_position = 0;
-	while (current_position < size) {
-		idx_t char_render_width = Utf8Proc::RenderWidth(input, size, current_position);
-		current_position = Utf8Proc::NextGraphemeCluster(input, size, current_position);
-		render_width += char_render_width;
-		render_widths.push_back(StringSegment(current_position, render_width));
-		if (render_width > max_render_width) {
-			break;
-		}
-	}
-
-	if (render_width > max_render_width) {
-		// need to find a position to truncate
-		for (idx_t pos = render_widths.size(); pos > 0; pos--) {
-			auto &source_range = render_widths[pos - 1];
-			if (source_range.width < max_render_width - 4) {
-				return source.substr(0, source_range.start) + string("...") +
-				       string(max_render_width - source_range.width - 3, ' ');
-			}
-		}
-		source = "...";
-	}
-	// need to pad with spaces
-	idx_t total_spaces = max_render_width - render_width;
-	idx_t half_spaces = total_spaces / 2;
-	idx_t extra_left_space = total_spaces % 2 == 0 ? 0 : 1;
-	return string(half_spaces + extra_left_space, ' ') + source + string(half_spaces, ' ');
-}
-
-void TextTreeRenderer::RenderBoxContent(RenderTree &root, std::ostream &ss, idx_t y) {
-	// we first need to figure out how high our boxes are going to be
-	vector<vector<string>> extra_info;
-	idx_t extra_height = 0;
-	extra_info.resize(root.width);
-	for (idx_t x = 0; x < root.width; x++) {
-		auto node = root.GetNode(x, y);
-		if (node) {
-			SplitUpExtraInfo(node->extra_text, extra_info[x], config.max_extra_lines);
-			if (extra_info[x].size() > extra_height) {
-				extra_height = extra_info[x].size();
-			}
-		}
-	}
-	idx_t halfway_point = (extra_height + 1) / 2;
-	// now we render the actual node
-	for (idx_t render_y = 0; render_y <= extra_height; render_y++) {
-		for (idx_t x = 0; x < root.width; x++) {
-			if (x * config.node_render_width >= config.maximum_render_width) {
-				break;
-			}
-			bool has_adjacent_nodes = false;
-			for (idx_t i = 0; x + i < root.width; i++) {
-				has_adjacent_nodes = has_adjacent_nodes || root.HasNode(x + i, y);
-			}
-			auto node = root.GetNode(x, y);
-			if (!node) {
-				if (render_y == halfway_point) {
-					bool has_child_to_the_right = ShouldRenderWhitespace(root, x, y);
-					if (root.HasNode(x, y + 1)) {
-						// node right below this one
-						ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2);
-						if (has_child_to_the_right) {
-							ss << config.TMIDDLE;
-							// but we have another child to the right! 
keep rendering the line - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width / 2); - } else { - ss << config.RTCORNER; - if (has_adjacent_nodes) { - // only a child below this one: fill the rest with spaces - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - } - } - } else if (has_child_to_the_right) { - // child to the right, but no child right below this one: render a full line - ss << StringUtil::Repeat(config.HORIZONTAL, config.node_render_width); - } else { - if (has_adjacent_nodes) { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } - } else if (render_y >= halfway_point) { - if (root.HasNode(x, y + 1)) { - // we have a node below this empty spot: render a vertical line - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - ss << config.VERTICAL; - if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { - ss << StringUtil::Repeat(" ", config.node_render_width / 2); - } - } else { - if (has_adjacent_nodes || ShouldRenderWhitespace(root, x, y)) { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } - } else { - if (has_adjacent_nodes) { - // empty spot: render spaces - ss << StringUtil::Repeat(" ", config.node_render_width); - } - } - } else { - ss << config.VERTICAL; - // figure out what to render - string render_text; - if (render_y == 0) { - render_text = node->name; - } else { - if (render_y <= extra_info[x].size()) { - render_text = extra_info[x][render_y - 1]; - } - } - if (render_y + 1 == extra_height && render_text.empty()) { - auto entry = node->extra_text.find(RenderTreeNode::CARDINALITY); - if (entry != node->extra_text.end()) { - render_text = entry->second + " Rows"; - } - } - if (render_y == extra_height && render_text.empty()) { - auto timing_entry = node->extra_text.find(RenderTreeNode::TIMING); - if (timing_entry != node->extra_text.end()) { - render_text = "(" + timing_entry->second + ")"; - } else if (node->extra_text.find(RenderTreeNode::CARDINALITY) == node->extra_text.end()) { - // we only render estimated cardinality if there is no real cardinality - auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); - if (entry != node->extra_text.end()) { - render_text = "~" + entry->second + " Rows"; - } - } - if (node->extra_text.find(RenderTreeNode::CARDINALITY) == node->extra_text.end()) { - // we only render estimated cardinality if there is no real cardinality - auto entry = node->extra_text.find(RenderTreeNode::ESTIMATED_CARDINALITY); - if (entry != node->extra_text.end()) { - render_text = "~" + entry->second + " Rows"; - } - } - } - render_text = AdjustTextForRendering(render_text, config.node_render_width - 2); - ss << render_text; - - if (render_y == halfway_point && NodeHasMultipleChildren(*node)) { - ss << config.LMIDDLE; - } else { - ss << config.VERTICAL; - } - } - } - ss << '\n'; - } -} - -string TextTreeRenderer::ToString(const LogicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TextTreeRenderer::ToString(const PhysicalOperator &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TextTreeRenderer::ToString(const ProfilingNode &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -string TextTreeRenderer::ToString(const Pipeline &op) { - std::stringstream ss; - Render(op, ss); - return ss.str(); -} - -void TextTreeRenderer::Render(const LogicalOperator &op, std::ostream &ss) { - auto tree = 
RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void TextTreeRenderer::Render(const PhysicalOperator &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void TextTreeRenderer::Render(const ProfilingNode &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void TextTreeRenderer::Render(const Pipeline &op, std::ostream &ss) {
-	auto tree = RenderTree::CreateRenderTree(op);
-	ToStream(*tree, ss);
-}
-
-void TextTreeRenderer::ToStreamInternal(RenderTree &root, std::ostream &ss) {
-	while (root.width * config.node_render_width > config.maximum_render_width) {
-		if (config.node_render_width - 2 < config.minimum_render_width) {
-			break;
-		}
-		config.node_render_width -= 2;
-	}
-
-	for (idx_t y = 0; y < root.height; y++) {
-		// start by rendering the top layer
-		RenderTopLayer(root, ss, y);
-		// now we render the content of the boxes
-		RenderBoxContent(root, ss, y);
-		// render the bottom layer of each of the boxes
-		RenderBottomLayer(root, ss, y);
-	}
-}
-
-bool TextTreeRenderer::CanSplitOnThisChar(char l) {
-	return (l < '0' || (l > '9' && l < 'A') || (l > 'Z' && l < 'a')) && l != '_';
-}
-
-bool TextTreeRenderer::IsPadding(char l) {
-	return l == ' ' || l == '\t' || l == '\n' || l == '\r';
-}
-
-string TextTreeRenderer::RemovePadding(string l) {
-	idx_t start = 0, end = l.size();
-	while (start < l.size() && IsPadding(l[start])) {
-		start++;
-	}
-	while (end > 0 && IsPadding(l[end - 1])) {
-		end--;
-	}
-	return l.substr(start, end - start);
-}
-
-void TextTreeRenderer::SplitStringBuffer(const string &source, vector<string> &result) {
-	D_ASSERT(Utf8Proc::IsValid(source.c_str(), source.size()));
-	const idx_t max_line_render_size = config.node_render_width - 2;
-	// utf8 in prompt, get render width
-	idx_t character_pos = 0;
-	idx_t start_pos = 0;
-	idx_t render_width = 0;
-	idx_t last_possible_split = 0;
-
-	const idx_t size = source.size();
-	const char *input = source.c_str();
-
-	while (character_pos < size) {
-		size_t char_render_width = Utf8Proc::RenderWidth(input, size, character_pos);
-		idx_t next_character_pos = Utf8Proc::NextGraphemeCluster(input, size, character_pos);
-
-		// Does the next character make us exceed the line length?
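// Note (added for clarity): render_width counts terminal display columns via
// Utf8Proc::RenderWidth rather than bytes, so wide grapheme clusters (e.g. CJK
// characters) are measured correctly when deciding where to wrap.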
-		if (render_width + char_render_width > max_line_render_size) {
-			if (start_pos + 8 > last_possible_split) {
-				// The last character we can split on is one of the first 8 characters of the line
-				// to not create very small lines we instead split on the current character
-				last_possible_split = character_pos;
-			}
-			result.push_back(source.substr(start_pos, last_possible_split - start_pos));
-			render_width = character_pos - last_possible_split;
-			start_pos = last_possible_split;
-			character_pos = last_possible_split;
-		}
-		// check if we can split on this character
-		if (CanSplitOnThisChar(source[character_pos])) {
-			last_possible_split = character_pos;
-		}
-		character_pos = next_character_pos;
-		render_width += char_render_width;
-	}
-	if (size > start_pos) {
-		// append the remainder of the input
-		result.push_back(source.substr(start_pos, size - start_pos));
-	}
-}
-
-void TextTreeRenderer::SplitUpExtraInfo(const InsertionOrderPreservingMap<string> &extra_info, vector<string> &result,
-                                        idx_t max_lines) {
-	if (extra_info.empty()) {
-		return;
-	}
-	for (auto &item : extra_info) {
-		auto &text = item.second;
-		if (!Utf8Proc::IsValid(text.c_str(), text.size())) {
-			return;
-		}
-	}
-	result.push_back(ExtraInfoSeparator());
-
-	bool requires_padding = false;
-	bool was_inlined = false;
-	for (auto &item : extra_info) {
-		string str = RemovePadding(item.second);
-		if (str.empty()) {
-			continue;
-		}
-		bool is_inlined = false;
-		if (!StringUtil::StartsWith(item.first, "__")) {
-			// the name is not internal (i.e. not __text__) - so we display the name in addition to the entry
-			const idx_t available_width = (config.node_render_width - 7);
-			idx_t total_size = item.first.size() + str.size() + 2;
-			bool is_multiline = StringUtil::Contains(str, "\n");
-			if (!is_multiline && total_size < available_width) {
-				// we can inline the full entry - no need for any separators unless the previous entry explicitly
-				// requires it
-				str = item.first + ": " + str;
-				is_inlined = true;
-			} else {
-				str = item.first + ":\n" + str;
-			}
-		}
-		if (is_inlined && was_inlined) {
-			// we can skip the padding if we have multiple inlined entries in a row
-			requires_padding = false;
-		}
-		if (requires_padding) {
-			result.emplace_back();
-		}
-		// cardinality, timing and estimated cardinality are rendered separately
-		// this is to allow alignment horizontally across nodes
-		if (item.first == RenderTreeNode::CARDINALITY) {
-			// cardinality - need to reserve space for cardinality AND timing
-			result.emplace_back();
-			if (extra_info.find(RenderTreeNode::TIMING) != extra_info.end()) {
-				result.emplace_back();
-			}
-			break;
-		}
-		if (item.first == RenderTreeNode::ESTIMATED_CARDINALITY) {
-			// estimated cardinality - reserve space for estimate
-			if (extra_info.find(RenderTreeNode::CARDINALITY) != extra_info.end()) {
-				// if we have a true cardinality render that instead of the estimate
-				result.pop_back();
-				continue;
-			}
-			result.emplace_back();
-			break;
-		}
-		auto splits = StringUtil::Split(str, "\n");
-		if (splits.size() > max_lines) {
-			// truncate this entry
-			vector<string> truncated_splits;
-			for (idx_t i = 0; i < max_lines / 2; i++) {
-				truncated_splits.push_back(std::move(splits[i]));
-			}
-			truncated_splits.push_back("...");
-			for (idx_t i = splits.size() - max_lines / 2; i < splits.size(); i++) {
-				truncated_splits.push_back(std::move(splits[i]));
-			}
-			splits = std::move(truncated_splits);
-		}
-		for (auto &split : splits) {
-			SplitStringBuffer(split, result);
-		}
-		requires_padding = true;
-		was_inlined = is_inlined;
-	}
-}
-
-string TextTreeRenderer::ExtraInfoSeparator() {
-	return StringUtil::Repeat(string(config.HORIZONTAL), (config.node_render_width - 9));
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/tree_renderer/tree_renderer.cpp b/src/duckdb/src/common/tree_renderer/tree_renderer.cpp
deleted file mode 100644
index b0a58b597..000000000
--- a/src/duckdb/src/common/tree_renderer/tree_renderer.cpp
+++ /dev/null
@@ -1,12 +0,0 @@
-#include "duckdb/common/tree_renderer.hpp"
-
-namespace duckdb {
-
-void TreeRenderer::ToStream(RenderTree &root, std::ostream &ss) {
-	if (!UsesRawKeyNames()) {
-		root.SanitizeKeyNames();
-	}
-	return ToStreamInternal(root, ss);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/types.cpp b/src/duckdb/src/common/types.cpp
deleted file mode 100644
index 342a84bc0..000000000
--- a/src/duckdb/src/common/types.cpp
+++ /dev/null
@@ -1,1914 +0,0 @@
-#include "duckdb/common/types.hpp"
-
-#include "duckdb/catalog/catalog.hpp"
-#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_search_path.hpp"
-#include "duckdb/catalog/default/default_types.hpp"
-#include "duckdb/common/enum_util.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/extra_type_info.hpp"
-#include "duckdb/common/limits.hpp"
-#include "duckdb/common/numeric_utils.hpp"
-#include "duckdb/common/operator/comparison_operators.hpp"
-#include "duckdb/common/serializer/deserializer.hpp"
-#include "duckdb/common/serializer/serializer.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/common/type_visitor.hpp"
-#include "duckdb/common/types/decimal.hpp"
-#include "duckdb/common/types/hash.hpp"
-#include "duckdb/common/types/string_type.hpp"
-#include "duckdb/common/types/value.hpp"
-#include "duckdb/common/types/vector.hpp"
-#include "duckdb/common/uhugeint.hpp"
-#include "duckdb/common/unordered_map.hpp"
-#include "duckdb/function/cast_rules.hpp"
-#include "duckdb/main/attached_database.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/client_data.hpp"
-#include "duckdb/main/config.hpp"
-#include "duckdb/main/database.hpp"
-#include "duckdb/main/database_manager.hpp"
-#include "duckdb/parser/keyword_helper.hpp"
-#include "duckdb/parser/parser.hpp"
-
-// (a system #include whose header name was lost in text conversion)
-
-namespace duckdb {
-
-LogicalType::LogicalType() : LogicalType(LogicalTypeId::INVALID) {
-}
-
-LogicalType::LogicalType(LogicalTypeId id) : id_(id) {
-	physical_type_ = GetInternalType();
-}
-LogicalType::LogicalType(LogicalTypeId id, shared_ptr<ExtraTypeInfo> type_info_p)
-    : id_(id), type_info_(std::move(type_info_p)) {
-	physical_type_ = GetInternalType();
-}
-
-LogicalType::LogicalType(const LogicalType &other)
-    : id_(other.id_), physical_type_(other.physical_type_), type_info_(other.type_info_) {
-}
-
-LogicalType::LogicalType(LogicalType &&other) noexcept
-    : id_(other.id_), physical_type_(other.physical_type_), type_info_(std::move(other.type_info_)) {
-}
-
-hash_t LogicalType::Hash() const {
-	return duckdb::Hash((uint8_t)id_);
-}
-
-PhysicalType LogicalType::GetInternalType() {
-	switch (id_) {
-	case LogicalTypeId::BOOLEAN:
-		return PhysicalType::BOOL;
-	case LogicalTypeId::TINYINT:
-		return PhysicalType::INT8;
-	case LogicalTypeId::UTINYINT:
-		return PhysicalType::UINT8;
-	case LogicalTypeId::SMALLINT:
-		return PhysicalType::INT16;
-	case LogicalTypeId::USMALLINT:
-		return PhysicalType::UINT16;
-	case LogicalTypeId::SQLNULL:
-	case LogicalTypeId::DATE:
-	case LogicalTypeId::INTEGER:
-		return PhysicalType::INT32;
-	case LogicalTypeId::UINTEGER:
-		return 
PhysicalType::UINT32; - case LogicalTypeId::BIGINT: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - return PhysicalType::INT64; - case LogicalTypeId::UBIGINT: - return PhysicalType::UINT64; - case LogicalTypeId::UHUGEINT: - return PhysicalType::UINT128; - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UUID: - return PhysicalType::INT128; - case LogicalTypeId::FLOAT: - return PhysicalType::FLOAT; - case LogicalTypeId::DOUBLE: - return PhysicalType::DOUBLE; - case LogicalTypeId::DECIMAL: { - if (!type_info_) { - return PhysicalType::INVALID; - } - auto width = DecimalType::GetWidth(*this); - if (width <= Decimal::MAX_WIDTH_INT16) { - return PhysicalType::INT16; - } else if (width <= Decimal::MAX_WIDTH_INT32) { - return PhysicalType::INT32; - } else if (width <= Decimal::MAX_WIDTH_INT64) { - return PhysicalType::INT64; - } else if (width <= Decimal::MAX_WIDTH_INT128) { - return PhysicalType::INT128; - } else { - throw InternalException("Decimal has a width of %d which is bigger than the maximum supported width of %d", - width, DecimalType::MaxWidth()); - } - } - case LogicalTypeId::VARCHAR: - case LogicalTypeId::CHAR: - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: - return PhysicalType::VARCHAR; - case LogicalTypeId::INTERVAL: - return PhysicalType::INTERVAL; - case LogicalTypeId::UNION: - case LogicalTypeId::STRUCT: - return PhysicalType::STRUCT; - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - return PhysicalType::LIST; - case LogicalTypeId::ARRAY: - return PhysicalType::ARRAY; - case LogicalTypeId::POINTER: - // LCOV_EXCL_START - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - return PhysicalType::UINT32; - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - return PhysicalType::UINT64; - } else { - throw InternalException("Unsupported pointer size"); - } - // LCOV_EXCL_STOP - case LogicalTypeId::VALIDITY: - return PhysicalType::BIT; - case LogicalTypeId::ENUM: { - if (!type_info_) { - return PhysicalType::INVALID; - } - return EnumType::GetPhysicalType(*this); - } - case LogicalTypeId::TABLE: - case LogicalTypeId::LAMBDA: - case LogicalTypeId::ANY: - case LogicalTypeId::INVALID: - case LogicalTypeId::UNKNOWN: - case LogicalTypeId::STRING_LITERAL: - case LogicalTypeId::INTEGER_LITERAL: - return PhysicalType::INVALID; - case LogicalTypeId::USER: - return PhysicalType::UNKNOWN; - case LogicalTypeId::AGGREGATE_STATE: - return PhysicalType::VARCHAR; - default: - throw InternalException("Invalid LogicalType %s", ToString()); - } -} - -// **DEPRECATED**: Use EnumUtil directly instead. 
-string LogicalTypeIdToString(LogicalTypeId type) {
-	return EnumUtil::ToString(type);
-}
-
-constexpr const LogicalTypeId LogicalType::INVALID;
-constexpr const LogicalTypeId LogicalType::SQLNULL;
-constexpr const LogicalTypeId LogicalType::UNKNOWN;
-constexpr const LogicalTypeId LogicalType::BOOLEAN;
-constexpr const LogicalTypeId LogicalType::TINYINT;
-constexpr const LogicalTypeId LogicalType::UTINYINT;
-constexpr const LogicalTypeId LogicalType::SMALLINT;
-constexpr const LogicalTypeId LogicalType::USMALLINT;
-constexpr const LogicalTypeId LogicalType::INTEGER;
-constexpr const LogicalTypeId LogicalType::UINTEGER;
-constexpr const LogicalTypeId LogicalType::BIGINT;
-constexpr const LogicalTypeId LogicalType::UBIGINT;
-constexpr const LogicalTypeId LogicalType::HUGEINT;
-constexpr const LogicalTypeId LogicalType::UHUGEINT;
-constexpr const LogicalTypeId LogicalType::UUID;
-constexpr const LogicalTypeId LogicalType::FLOAT;
-constexpr const LogicalTypeId LogicalType::DOUBLE;
-constexpr const LogicalTypeId LogicalType::DATE;
-
-constexpr const LogicalTypeId LogicalType::TIMESTAMP;
-constexpr const LogicalTypeId LogicalType::TIMESTAMP_MS;
-constexpr const LogicalTypeId LogicalType::TIMESTAMP_NS;
-constexpr const LogicalTypeId LogicalType::TIMESTAMP_S;
-
-constexpr const LogicalTypeId LogicalType::TIME;
-
-constexpr const LogicalTypeId LogicalType::TIME_TZ;
-constexpr const LogicalTypeId LogicalType::TIMESTAMP_TZ;
-
-constexpr const LogicalTypeId LogicalType::HASH;
-constexpr const LogicalTypeId LogicalType::POINTER;
-
-constexpr const LogicalTypeId LogicalType::VARCHAR;
-
-constexpr const LogicalTypeId LogicalType::BLOB;
-constexpr const LogicalTypeId LogicalType::BIT;
-constexpr const LogicalTypeId LogicalType::VARINT;
-
-constexpr const LogicalTypeId LogicalType::INTERVAL;
-constexpr const LogicalTypeId LogicalType::ROW_TYPE;
-
-// TODO these are incomplete and should maybe not exist as such
-constexpr const LogicalTypeId LogicalType::TABLE;
-constexpr const LogicalTypeId LogicalType::LAMBDA;
-
-constexpr const LogicalTypeId LogicalType::ANY;
-
-const vector<LogicalType> LogicalType::Numeric() {
-	vector<LogicalType> types = {LogicalType::TINYINT,   LogicalType::SMALLINT, LogicalType::INTEGER,
-	                             LogicalType::BIGINT,    LogicalType::HUGEINT,  LogicalType::FLOAT,
-	                             LogicalType::DOUBLE,    LogicalTypeId::DECIMAL, LogicalType::UTINYINT,
-	                             LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT,
-	                             LogicalType::UHUGEINT};
-	return types;
-}
-
-const vector<LogicalType> LogicalType::Integral() {
-	vector<LogicalType> types = {LogicalType::TINYINT,   LogicalType::SMALLINT, LogicalType::INTEGER,
-	                             LogicalType::BIGINT,    LogicalType::HUGEINT,  LogicalType::UTINYINT,
-	                             LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT,
-	                             LogicalType::UHUGEINT};
-	return types;
-}
-
-const vector<LogicalType> LogicalType::Real() {
-	vector<LogicalType> types = {LogicalType::FLOAT, LogicalType::DOUBLE};
-	return types;
-}
-
-const vector<LogicalType> LogicalType::AllTypes() {
-	vector<LogicalType> types = {
-	    LogicalType::BOOLEAN,  LogicalType::TINYINT,   LogicalType::SMALLINT,  LogicalType::INTEGER,
-	    LogicalType::BIGINT,   LogicalType::DATE,      LogicalType::TIMESTAMP, LogicalType::DOUBLE,
-	    LogicalType::FLOAT,    LogicalType::VARCHAR,   LogicalType::BLOB,      LogicalType::BIT,
-	    LogicalType::VARINT,   LogicalType::INTERVAL,  LogicalType::HUGEINT,   LogicalTypeId::DECIMAL,
-	    LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER,  LogicalType::UBIGINT,
-	    LogicalType::UHUGEINT, LogicalType::TIME,      LogicalTypeId::LIST,    LogicalTypeId::STRUCT,
-	    LogicalType::TIME_TZ,  LogicalType::TIMESTAMP_TZ, LogicalTypeId::MAP,  LogicalTypeId::UNION,
-	    LogicalType::UUID,     LogicalTypeId::ARRAY};
-	return types;
-}
-
-const PhysicalType ROW_TYPE = PhysicalType::INT64;
-
-// LCOV_EXCL_START
-string TypeIdToString(PhysicalType type) {
-	switch (type) {
-	case PhysicalType::BOOL:
-		return "BOOL";
-	case PhysicalType::INT8:
-		return "INT8";
-	case PhysicalType::INT16:
-		return "INT16";
-	case PhysicalType::INT32:
-		return "INT32";
-	case PhysicalType::INT64:
-		return "INT64";
-	case PhysicalType::UINT8:
-		return "UINT8";
-	case PhysicalType::UINT16:
-		return "UINT16";
-	case PhysicalType::UINT32:
-		return "UINT32";
-	case PhysicalType::UINT64:
-		return "UINT64";
-	case PhysicalType::INT128:
-		return "INT128";
-	case PhysicalType::UINT128:
-		return "UINT128";
-	case PhysicalType::FLOAT:
-		return "FLOAT";
-	case PhysicalType::DOUBLE:
-		return "DOUBLE";
-	case PhysicalType::VARCHAR:
-		return "VARCHAR";
-	case PhysicalType::INTERVAL:
-		return "INTERVAL";
-	case PhysicalType::STRUCT:
-		return "STRUCT";
-	case PhysicalType::LIST:
-		return "LIST";
-	case PhysicalType::ARRAY:
-		return "ARRAY";
-	case PhysicalType::INVALID:
-		return "INVALID";
-	case PhysicalType::BIT:
-		return "BIT";
-	case PhysicalType::UNKNOWN:
-		return "UNKNOWN";
-	}
-	return "INVALID";
-}
-// LCOV_EXCL_STOP
-
-idx_t GetTypeIdSize(PhysicalType type) {
-	switch (type) {
-	case PhysicalType::BIT:
-	case PhysicalType::BOOL:
-		return sizeof(bool);
-	case PhysicalType::INT8:
-		return sizeof(int8_t);
-	case PhysicalType::INT16:
-		return sizeof(int16_t);
-	case PhysicalType::INT32:
-		return sizeof(int32_t);
-	case PhysicalType::INT64:
-		return sizeof(int64_t);
-	case PhysicalType::UINT8:
-		return sizeof(uint8_t);
-	case PhysicalType::UINT16:
-		return sizeof(uint16_t);
-	case PhysicalType::UINT32:
-		return sizeof(uint32_t);
-	case PhysicalType::UINT64:
-		return sizeof(uint64_t);
-	case PhysicalType::INT128:
-		return sizeof(hugeint_t);
-	case PhysicalType::UINT128:
-		return sizeof(uhugeint_t);
-	case PhysicalType::FLOAT:
-		return sizeof(float);
-	case PhysicalType::DOUBLE:
-		return sizeof(double);
-	case PhysicalType::VARCHAR:
-		return sizeof(string_t);
-	case PhysicalType::INTERVAL:
-		return sizeof(interval_t);
-	case PhysicalType::STRUCT:
-	case PhysicalType::UNKNOWN:
-	case PhysicalType::ARRAY:
-		return 0; // no own payload
-	case PhysicalType::LIST:
-		return sizeof(list_entry_t); // offset + len
-	default:
-		throw InternalException("Invalid PhysicalType for GetTypeIdSize");
-	}
-}
-
-bool TypeIsConstantSize(PhysicalType type) {
-	return (type >= PhysicalType::BOOL && type <= PhysicalType::DOUBLE) || type == PhysicalType::INTERVAL ||
-	       type == PhysicalType::INT128 || type == PhysicalType::UINT128;
-}
-bool TypeIsIntegral(PhysicalType type) {
-	return (type >= PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128 ||
-	       type == PhysicalType::UINT128;
-}
-bool TypeIsNumeric(PhysicalType type) {
-	return (type >= PhysicalType::UINT8 && type <= PhysicalType::DOUBLE) || type == PhysicalType::INT128 ||
-	       type == PhysicalType::UINT128;
-}
-bool TypeIsInteger(PhysicalType type) {
-	return (type >= PhysicalType::UINT8 && type <= PhysicalType::INT64) || type == PhysicalType::INT128 ||
-	       type == PhysicalType::UINT128;
-}
-
-static string TypeModifierListToString(const vector<Value> &mod_list) {
-	string result;
-	if (mod_list.empty()) {
-		return result;
-	}
-	result = "(";
-	for (idx_t i = 0; i < mod_list.size(); i++) {
-		result += mod_list[i].ToString();
-		if (i < mod_list.size() - 1) {
-			result += ", ";
-		}
-	}
-	result += ")";
-	return result;
-}
-
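// Illustrative sketch (added for clarity, not part of the deleted file): the
// contract of the size helpers above, assuming the DuckDB headers are available.
#include "duckdb/common/types.hpp"

static void PhysicalTypeSizeExamples() {
	using namespace duckdb;
	// fixed-size types report their in-memory payload width
	D_ASSERT(GetTypeIdSize(PhysicalType::INT32) == sizeof(int32_t));
	// strings are stored as fixed-size string_t handles, however long the data
	D_ASSERT(GetTypeIdSize(PhysicalType::VARCHAR) == sizeof(string_t));
	// STRUCT and ARRAY carry no payload of their own
	D_ASSERT(GetTypeIdSize(PhysicalType::STRUCT) == 0);
	D_ASSERT(TypeIsConstantSize(PhysicalType::INTERVAL));
	D_ASSERT(!TypeIsConstantSize(PhysicalType::VARCHAR));
}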
-string LogicalType::ToString() const { - if (id_ != LogicalTypeId::USER) { - auto alias = GetAlias(); - if (!alias.empty()) { - if (HasExtensionInfo()) { - auto &ext_info = *GetExtensionInfo(); - alias += TypeModifierListToString(ext_info.modifiers); - } - return alias; - } - } - switch (id_) { - case LogicalTypeId::STRUCT: { - if (!type_info_) { - return "STRUCT"; - } - auto is_unnamed = StructType::IsUnnamed(*this); - auto &child_types = StructType::GetChildTypes(*this); - string ret = "STRUCT("; - for (size_t i = 0; i < child_types.size(); i++) { - if (is_unnamed) { - ret += child_types[i].second.ToString(); - } else { - ret += StringUtil::Format("%s %s", SQLIdentifier(child_types[i].first), child_types[i].second); - } - if (i < child_types.size() - 1) { - ret += ", "; - } - } - ret += ")"; - return ret; - } - case LogicalTypeId::LIST: { - if (!type_info_) { - return "LIST"; - } - return ListType::GetChildType(*this).ToString() + "[]"; - } - case LogicalTypeId::MAP: { - if (!type_info_) { - return "MAP"; - } - auto &key_type = MapType::KeyType(*this); - auto &value_type = MapType::ValueType(*this); - return "MAP(" + key_type.ToString() + ", " + value_type.ToString() + ")"; - } - case LogicalTypeId::UNION: { - if (!type_info_) { - return "UNION"; - } - string ret = "UNION("; - size_t count = UnionType::GetMemberCount(*this); - for (size_t i = 0; i < count; i++) { - auto member_name = UnionType::GetMemberName(*this, i); - auto member_type = UnionType::GetMemberType(*this, i).ToString(); - ret += StringUtil::Format("%s %s", SQLIdentifier(member_name), member_type); - if (i < count - 1) { - ret += ", "; - } - } - ret += ")"; - return ret; - } - case LogicalTypeId::ARRAY: { - if (!type_info_) { - return "ARRAY"; - } - auto size = ArrayType::GetSize(*this); - if (size == 0) { - return ArrayType::GetChildType(*this).ToString() + "[ANY]"; - } else { - return ArrayType::GetChildType(*this).ToString() + "[" + to_string(size) + "]"; - } - } - case LogicalTypeId::DECIMAL: { - if (!type_info_) { - return "DECIMAL"; - } - auto width = DecimalType::GetWidth(*this); - auto scale = DecimalType::GetScale(*this); - if (width == 0) { - return "DECIMAL"; - } - return StringUtil::Format("DECIMAL(%d,%d)", width, scale); - } - case LogicalTypeId::ENUM: { - string ret = "ENUM("; - for (idx_t i = 0; i < EnumType::GetSize(*this); i++) { - if (i > 0) { - ret += ", "; - } - ret += KeywordHelper::WriteQuoted(EnumType::GetString(*this, i).GetString(), '\''); - } - ret += ")"; - return ret; - } - case LogicalTypeId::USER: { - string result; - auto &catalog = UserType::GetCatalog(*this); - auto &schema = UserType::GetSchema(*this); - auto &type = UserType::GetTypeName(*this); - auto &mods = UserType::GetTypeModifiers(*this); - - if (!catalog.empty()) { - result = KeywordHelper::WriteOptionallyQuoted(catalog); - } - if (!schema.empty()) { - if (!result.empty()) { - result += "."; - } - result += KeywordHelper::WriteOptionallyQuoted(schema); - } - if (!result.empty()) { - result += "."; - } - result += KeywordHelper::WriteOptionallyQuoted(type); - - if (!mods.empty()) { - result += "("; - for (idx_t i = 0; i < mods.size(); i++) { - result += mods[i].ToString(); - if (i < mods.size() - 1) { - result += ", "; - } - } - result += ")"; - } - - return result; - } - case LogicalTypeId::AGGREGATE_STATE: { - return AggregateStateType::GetTypeName(*this); - } - case LogicalTypeId::SQLNULL: { - return "\"NULL\""; - } - default: - return EnumUtil::ToString(id_); - } -} -// LCOV_EXCL_STOP - -LogicalTypeId 
TransformStringToLogicalTypeId(const string &str) {
-	auto type = DefaultTypeGenerator::GetDefaultType(str);
-	if (type == LogicalTypeId::INVALID) {
-		// This is a User Type, at this point we don't know if its one of the User Defined Types or an error
-		// It is checked in the binder
-		type = LogicalTypeId::USER;
-	}
-	return type;
-}
-
-LogicalType TransformStringToLogicalType(const string &str) {
-	if (StringUtil::Lower(str) == "null") {
-		return LogicalType::SQLNULL;
-	}
-	ColumnList column_list;
-	try {
-		column_list = Parser::ParseColumnList("dummy " + str);
-	} catch (const std::runtime_error &e) {
-		const vector<string> suggested_types {
-		    "BIGINT", "INT8", "LONG", "BIT", "BITSTRING", "BLOB", "BYTEA", "BINARY", "VARBINARY", "BOOLEAN", "BOOL",
-		    "LOGICAL", "DATE", "DECIMAL(prec, scale)", "DOUBLE", "FLOAT8", "FLOAT", "FLOAT4", "REAL", "HUGEINT",
-		    "INTEGER", "INT4", "INT", "SIGNED", "INTERVAL", "SMALLINT", "INT2", "SHORT", "TIME", "TIMESTAMPTZ",
-		    "TIMESTAMP", "DATETIME", "TINYINT", "INT1", "UBIGINT", "UHUGEINT", "UINTEGER", "USMALLINT", "UTINYINT",
-		    "UUID", "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING", "MAP(INTEGER, VARCHAR)",
-		    "UNION(num INTEGER, text VARCHAR)"};
-		std::ostringstream error;
-		error << "Value \"" << str << "\" can not be converted to a DuckDB Type." << '\n';
-		error << "Possible examples as suggestions: " << '\n';
-		auto suggestions = StringUtil::TopNJaroWinkler(suggested_types, str);
-		for (auto &suggestion : suggestions) {
-			error << "* " << suggestion << '\n';
-		}
-		throw InvalidInputException(error.str());
-	}
-	return column_list.GetColumn(LogicalIndex(0)).Type();
-}
-
-LogicalType GetUserTypeRecursive(const LogicalType &type, ClientContext &context) {
-	if (type.id() == LogicalTypeId::USER && type.HasAlias()) {
-		auto &type_entry =
-		    Catalog::GetEntry<TypeCatalogEntry>(context, INVALID_CATALOG, INVALID_SCHEMA, type.GetAlias());
-		return type_entry.user_type;
-	}
-	// Look for LogicalTypeId::USER in nested types
-	if (type.id() == LogicalTypeId::STRUCT) {
-		child_list_t<LogicalType> children;
-		children.reserve(StructType::GetChildCount(type));
-		for (auto &child : StructType::GetChildTypes(type)) {
-			children.emplace_back(child.first, GetUserTypeRecursive(child.second, context));
-		}
-		return LogicalType::STRUCT(children);
-	}
-	if (type.id() == LogicalTypeId::LIST) {
-		return LogicalType::LIST(GetUserTypeRecursive(ListType::GetChildType(type), context));
-	}
-	if (type.id() == LogicalTypeId::MAP) {
-		return LogicalType::MAP(GetUserTypeRecursive(MapType::KeyType(type), context),
-		                        GetUserTypeRecursive(MapType::ValueType(type), context));
-	}
-	// Not LogicalTypeId::USER or a nested type
-	return type;
-}
-
-LogicalType TransformStringToLogicalType(const string &str, ClientContext &context) {
-	return GetUserTypeRecursive(TransformStringToLogicalType(str), context);
-}
-
-bool LogicalType::IsIntegral() const {
-	switch (id_) {
-	case LogicalTypeId::TINYINT:
-	case LogicalTypeId::SMALLINT:
-	case LogicalTypeId::INTEGER:
-	case LogicalTypeId::BIGINT:
-	case LogicalTypeId::UTINYINT:
-	case LogicalTypeId::USMALLINT:
-	case LogicalTypeId::UINTEGER:
-	case LogicalTypeId::UBIGINT:
-	case LogicalTypeId::HUGEINT:
-	case LogicalTypeId::UHUGEINT:
-		return true;
-	default:
-		return false;
-	}
-}
-
-bool LogicalType::IsNumeric() const {
-	switch (id_) {
-	case LogicalTypeId::TINYINT:
-	case LogicalTypeId::SMALLINT:
-	case LogicalTypeId::INTEGER:
-	case LogicalTypeId::BIGINT:
-	case LogicalTypeId::HUGEINT:
-	case LogicalTypeId::FLOAT:
-	case LogicalTypeId::DOUBLE:
-	case LogicalTypeId::DECIMAL:
-	case LogicalTypeId::UTINYINT:
-	case LogicalTypeId::USMALLINT:
-	case LogicalTypeId::UINTEGER:
-	case LogicalTypeId::UBIGINT:
-	case LogicalTypeId::UHUGEINT:
-		return true;
-	default:
-		return false;
-	}
-}
-
-bool LogicalType::IsTemporal() const {
-	switch (id_) {
-	case LogicalTypeId::DATE:
-	case LogicalTypeId::TIME:
-	case LogicalTypeId::TIMESTAMP:
-	case LogicalTypeId::TIME_TZ:
-	case LogicalTypeId::TIMESTAMP_TZ:
-	case LogicalTypeId::TIMESTAMP_SEC:
-	case LogicalTypeId::TIMESTAMP_MS:
-	case LogicalTypeId::TIMESTAMP_NS:
-		return true;
-	default:
-		return false;
-	}
-}
-
-bool LogicalType::IsValid() const {
-	return id() != LogicalTypeId::INVALID && id() != LogicalTypeId::UNKNOWN;
-}
-
-bool LogicalType::IsComplete() const {
-	// Check if type does not contain incomplete types
-	return !TypeVisitor::Contains(*this, [](const LogicalType &type) {
-		switch (type.id()) {
-		case LogicalTypeId::INVALID:
-		case LogicalTypeId::UNKNOWN:
-		case LogicalTypeId::ANY:
-			return true; // These are incomplete by default
-		case LogicalTypeId::LIST:
-		case LogicalTypeId::MAP:
-			if (!type.AuxInfo() || type.AuxInfo()->type != ExtraTypeInfoType::LIST_TYPE_INFO) {
-				return true; // Missing or incorrect type info
-			}
-			break;
-		case LogicalTypeId::STRUCT:
-		case LogicalTypeId::UNION:
-			if (!type.AuxInfo() || type.AuxInfo()->type != ExtraTypeInfoType::STRUCT_TYPE_INFO) {
-				return true; // Missing or incorrect type info
-			}
-			break;
-		case LogicalTypeId::ARRAY:
-			if (!type.AuxInfo() || type.AuxInfo()->type != ExtraTypeInfoType::ARRAY_TYPE_INFO) {
-				return true; // Missing or incorrect type info
-			}
-			break;
-		case LogicalTypeId::DECIMAL:
-			if (!type.AuxInfo() || type.AuxInfo()->type != ExtraTypeInfoType::DECIMAL_TYPE_INFO) {
-				return true; // Missing or incorrect type info
-			}
-			break;
-		case LogicalTypeId::ENUM:
-			if (!type.AuxInfo() || type.AuxInfo()->type != ExtraTypeInfoType::ENUM_TYPE_INFO) {
-				return true; // Missing or incorrect type info
-			}
-			break;
-		default:
-			return false;
-		}
-
-		// Type has type info, check if it is complete
-		D_ASSERT(type.AuxInfo());
-		switch (type.AuxInfo()->type) {
-		case ExtraTypeInfoType::STRUCT_TYPE_INFO:
-			return type.AuxInfo()->Cast<StructTypeInfo>().child_types.empty(); // Cannot be empty
-		case ExtraTypeInfoType::DECIMAL_TYPE_INFO:
-			return DecimalType::GetWidth(type) >= 1 && DecimalType::GetWidth(type) <= Decimal::MAX_WIDTH_DECIMAL &&
-			       DecimalType::GetScale(type) <= DecimalType::GetWidth(type);
-		default:
-			return false; // Nested types are checked by TypeVisitor recursion
-		}
-	});
-}
-
-bool LogicalType::SupportsRegularUpdate() const {
-	switch (id()) {
-	case LogicalTypeId::LIST:
-	case LogicalTypeId::ARRAY:
-	case LogicalTypeId::MAP:
-	case LogicalTypeId::UNION:
-		return false;
-	case LogicalTypeId::STRUCT: {
-		auto &child_types = StructType::GetChildTypes(*this);
-		for (auto &entry : child_types) {
-			if (!entry.second.SupportsRegularUpdate()) {
-				return false;
-			}
-		}
-		return true;
-	}
-	default:
-		return true;
-	}
-}
-
-bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const {
-	switch (id_) {
-	case LogicalTypeId::SQLNULL:
-		width = 0;
-		scale = 0;
-		break;
-	case LogicalTypeId::BOOLEAN:
-		width = 1;
-		scale = 0;
-		break;
-	case LogicalTypeId::TINYINT:
-		// tinyint: [-127, 127] = DECIMAL(3,0)
-		width = 3;
-		scale = 0;
-		break;
-	case LogicalTypeId::SMALLINT:
-		// smallint: [-32767, 32767] = DECIMAL(5,0)
-		width = 5;
-		scale = 0;
-		break;
-	case LogicalTypeId::INTEGER:
-		// integer: [-2147483647, 2147483647] = DECIMAL(10,0)
-		width = 10;
-		scale = 0;
-		break;
-	case LogicalTypeId::BIGINT:
-		// bigint: [-9223372036854775807, 9223372036854775807] = DECIMAL(19,0)
-		width = 19;
-		scale = 0;
-		break;
-	case LogicalTypeId::UTINYINT:
-		// UInt8 — [0 : 255]
-		width = 3;
-		scale = 0;
-		break;
-	case LogicalTypeId::USMALLINT:
-		// UInt16 — [0 : 65535]
-		width = 5;
-		scale = 0;
-		break;
-	case LogicalTypeId::UINTEGER:
-		// UInt32 — [0 : 4294967295]
-		width = 10;
-		scale = 0;
-		break;
-	case LogicalTypeId::UBIGINT:
-		// UInt64 — [0 : 18446744073709551615]
-		width = 20;
-		scale = 0;
-		break;
-	case LogicalTypeId::HUGEINT:
-		// hugeint: max size decimal (38, 0)
-		// note that a hugeint is not guaranteed to fit in this
-		width = 38;
-		scale = 0;
-		break;
-	case LogicalTypeId::UHUGEINT:
-		// uhugeint: max size decimal (38, 0)
-		// note that a uhugeint is not guaranteed to fit in this
-		width = 38;
-		scale = 0;
-		break;
-	case LogicalTypeId::DECIMAL:
-		width = DecimalType::GetWidth(*this);
-		scale = DecimalType::GetScale(*this);
-		break;
-	case LogicalTypeId::INTEGER_LITERAL:
-		return IntegerLiteral::GetType(*this).GetDecimalProperties(width, scale);
-	default:
-		// Nonsense values to ensure initialization
-		width = 255u;
-		scale = 255u;
-		// FIXME(carlo): This should probably be a throw, requires checking the various call-sites
-		return false;
-	}
-	return true;
-}
-
-//! Grows Decimal width/scale when appropriate
-static LogicalType DecimalSizeCheck(const LogicalType &left, const LogicalType &right) {
-	D_ASSERT(left.id() == LogicalTypeId::DECIMAL || right.id() == LogicalTypeId::DECIMAL);
-	D_ASSERT(left.id() != right.id());
-
-	//! Make sure the 'right' is the DECIMAL type
-	if (left.id() == LogicalTypeId::DECIMAL) {
-		return DecimalSizeCheck(right, left);
-	}
-	auto width = DecimalType::GetWidth(right);
-	auto scale = DecimalType::GetScale(right);
-
-	uint8_t other_width;
-	uint8_t other_scale;
-	bool success = left.GetDecimalProperties(other_width, other_scale);
-	if (!success) {
-		throw InternalException("Type provided to DecimalSizeCheck was not a numeric type");
-	}
-	D_ASSERT(other_scale == 0);
-	const auto effective_width = width - scale;
-	if (other_width > effective_width) {
-		auto new_width = NumericCast<uint8_t>(other_width + scale);
-		//! Cap the width at max, if an actual value exceeds this, an exception will be thrown later
-		if (new_width > DecimalType::MaxWidth()) {
-			new_width = DecimalType::MaxWidth();
-		}
-		return LogicalType::DECIMAL(new_width, scale);
-	}
-	return right;
-}
-
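// Worked example (added for clarity): combining INTEGER with DECIMAL(4,2) in
// DecimalSizeCheck above gives other_width = 10 and effective_width = 4 - 2 = 2;
// since 10 > 2, the result is widened to DECIMAL(10 + 2, 2) = DECIMAL(12,2).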
-static LogicalType CombineNumericTypes(const LogicalType &left, const LogicalType &right) {
-	D_ASSERT(left.id() != right.id());
-	if (left.id() > right.id()) {
-		// this method is symmetric
-		// arrange it so the left type is smaller to limit the number of options we need to check
-		return CombineNumericTypes(right, left);
-	}
-	// we can't cast implicitly either way and types are not equal
-	// this happens when left is signed and right is unsigned
-	// e.g. INTEGER and UINTEGER
-	// in this case we need to upcast to make sure the types fit
-
-	if (left.id() == LogicalTypeId::BIGINT || right.id() == LogicalTypeId::UBIGINT) {
-		return LogicalType::HUGEINT;
-	}
-	if (left.id() == LogicalTypeId::INTEGER || right.id() == LogicalTypeId::UINTEGER) {
-		return LogicalType::BIGINT;
-	}
-	if (left.id() == LogicalTypeId::SMALLINT || right.id() == LogicalTypeId::USMALLINT) {
-		return LogicalType::INTEGER;
-	}
-	if (left.id() == LogicalTypeId::TINYINT || right.id() == LogicalTypeId::UTINYINT) {
-		return LogicalType::SMALLINT;
-	}
-
-	// No type is larger than (u)hugeint, so casting to double is required
-	// UHUGEINT is on the left because the enum is lower
-	if (left.id() == LogicalTypeId::UHUGEINT || right.id() == LogicalTypeId::HUGEINT) {
-		return LogicalType::DOUBLE;
-	}
-	throw InternalException("Cannot combine these numeric types (%s & %s)", left.ToString(), right.ToString());
-}
-
-LogicalType LogicalType::NormalizeType(const LogicalType &type) {
-	switch (type.id()) {
-	case LogicalTypeId::STRING_LITERAL:
-		return LogicalType::VARCHAR;
-	case LogicalTypeId::INTEGER_LITERAL:
-		return IntegerLiteral::GetType(type);
-	case LogicalTypeId::UNKNOWN:
-		throw ParameterNotResolvedException();
-	default:
-		return type;
-	}
-}
-
-template <class OP>
-static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) {
-	// left and right are not equal
-	// NULL/unknown (parameter) types always take the other type
-	LogicalTypeId other_types[] = {LogicalTypeId::SQLNULL, LogicalTypeId::UNKNOWN};
-	for (auto &other_type : other_types) {
-		if (left.id() == other_type) {
-			result = LogicalType::NormalizeType(right);
-			return true;
-		} else if (right.id() == other_type) {
-			result = LogicalType::NormalizeType(left);
-			return true;
-		}
-	}
-
-	// for enums, match the varchar rules
-	if (left.id() == LogicalTypeId::ENUM) {
-		return OP::Operation(LogicalType::VARCHAR, right, result);
-	} else if (right.id() == LogicalTypeId::ENUM) {
-		return OP::Operation(left, LogicalType::VARCHAR, result);
-	}
-
-	// for everything but enums - string literals also take the other type
-	if (left.id() == LogicalTypeId::STRING_LITERAL) {
-		result = LogicalType::NormalizeType(right);
-		return true;
-	} else if (right.id() == LogicalTypeId::STRING_LITERAL) {
-		result = LogicalType::NormalizeType(left);
-		return true;
-	}
-
-	// for other types - use implicit cast rules to check if we can combine the types
-	auto left_to_right_cost = CastRules::ImplicitCast(left, right);
-	auto right_to_left_cost = CastRules::ImplicitCast(right, left);
-	if (left_to_right_cost >= 0 && (left_to_right_cost < right_to_left_cost || right_to_left_cost < 0)) {
-		// we can implicitly cast left to right, return right
Depending on the type, we might need to grow the `width` of the DECIMAL type - if (left.id() == LogicalTypeId::DECIMAL) { - result = DecimalSizeCheck(right, left); - } else { - result = left; - } - return true; - } - // for integer literals - rerun the operation with the underlying type - if (left.id() == LogicalTypeId::INTEGER_LITERAL) { - return OP::Operation(IntegerLiteral::GetType(left), right, result); - } - if (right.id() == LogicalTypeId::INTEGER_LITERAL) { - return OP::Operation(left, IntegerLiteral::GetType(right), result); - } - // for unsigned/signed comparisons we have a few fallbacks - if (left.IsNumeric() && right.IsNumeric()) { - result = CombineNumericTypes(left, right); - return true; - } - if (left.id() == LogicalTypeId::BOOLEAN && right.IsIntegral()) { - result = right; - return true; - } - if (right.id() == LogicalTypeId::BOOLEAN && left.IsIntegral()) { - result = left; - return true; - } - return false; -} - -template <class OP> -static bool CombineStructTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) { - auto &left_children = StructType::GetChildTypes(left); - auto &right_children = StructType::GetChildTypes(right); - - auto left_unnamed = StructType::IsUnnamed(left); - auto is_unnamed = left_unnamed || StructType::IsUnnamed(right); - child_list_t<LogicalType> child_types; - - // At least one side is unnamed, so we attempt positional casting. - if (is_unnamed) { - if (left_children.size() != right_children.size()) { - // We can't cast, or create the super-set. - return false; - } - - for (idx_t i = 0; i < left_children.size(); i++) { - LogicalType child_type; - if (!OP::Operation(left_children[i].second, right_children[i].second, child_type)) { - return false; - } - auto &child_name = left_unnamed ? right_children[i].first : left_children[i].first; - child_types.emplace_back(child_name, std::move(child_type)); - } - result = LogicalType::STRUCT(child_types); - return true; - } - - // Create a super-set of the STRUCT fields. - // First, create a name->index map of the right children. - case_insensitive_map_t<idx_t> right_children_map; - for (idx_t i = 0; i < right_children.size(); i++) { - auto &name = right_children[i].first; - right_children_map[name] = i; - } - - for (idx_t i = 0; i < left_children.size(); i++) { - auto &left_child = left_children[i]; - auto right_child_it = right_children_map.find(left_child.first); - - if (right_child_it == right_children_map.end()) { - // We can directly put the left child. - child_types.emplace_back(left_child.first, left_child.second); - continue; - } - - // We need to recurse to ensure the children have a maximum logical type. - LogicalType child_type; - auto &right_child = right_children[right_child_it->second]; - if (!OP::Operation(left_child.second, right_child.second, child_type)) { - return false; - } - child_types.emplace_back(left_child.first, std::move(child_type)); - right_children_map.erase(right_child_it); - } - - // Add all remaining right children.
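// A hypothetical usage sketch (not part of the original source) of the superset rule implemented here:
// combining STRUCT(a INTEGER, b VARCHAR) with STRUCT(b VARCHAR, c DOUBLE) unifies the shared field "b"
// recursively and appends the unmatched fields. Assuming a valid ClientContext `context`:
//
//   child_list_t<LogicalType> l, r;
//   l.emplace_back("a", LogicalType::INTEGER);
//   l.emplace_back("b", LogicalType::VARCHAR);
//   r.emplace_back("b", LogicalType::VARCHAR);
//   r.emplace_back("c", LogicalType::DOUBLE);
//   LogicalType combined;
//   LogicalType::TryGetMaxLogicalType(context, LogicalType::STRUCT(std::move(l)),
//                                     LogicalType::STRUCT(std::move(r)), combined);
//   // combined is STRUCT(a INTEGER, b VARCHAR, c DOUBLE)
// The loop below then appends the remaining right-side fields: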
- for (const auto &right_child_it : right_children_map) { - auto &right_child = right_children[right_child_it.second]; - child_types.emplace_back(right_child.first, right_child.second); - } - - result = LogicalType::STRUCT(child_types); - return true; -} - -template <class OP> -static bool CombineEqualTypes(const LogicalType &left, const LogicalType &right, LogicalType &result) { - // Since both left and right are equal we get the left type as our type_id for checks - auto type_id = left.id(); - switch (type_id) { - case LogicalTypeId::STRING_LITERAL: - // two string literals convert to varchar - result = LogicalType::VARCHAR; - return true; - case LogicalTypeId::INTEGER_LITERAL: - // for two integer literals we unify the underlying types - return OP::Operation(IntegerLiteral::GetType(left), IntegerLiteral::GetType(right), result); - case LogicalTypeId::ENUM: - // If both types are different ENUMs we do a string comparison. - result = left == right ? left : LogicalType::VARCHAR; - return true; - case LogicalTypeId::VARCHAR: - // varchar: use type that has collation (if any) - if (StringType::GetCollation(right).empty()) { - result = left; - } else { - result = right; - } - return true; - case LogicalTypeId::DECIMAL: { - // unify the width/scale so that the resulting decimal always fits - // "width - scale" gives us the number of digits on the left side of the decimal point - // "scale" gives us the number of digits allowed on the right of the decimal point - // taking the max of each of these across the two types gives us the new decimal size - // e.g. DECIMAL(12,4) and DECIMAL(9,6): extra width = max(8, 3) = 8, scale = max(4, 6) = 6 -> DECIMAL(14,6) - auto extra_width_left = DecimalType::GetWidth(left) - DecimalType::GetScale(left); - auto extra_width_right = DecimalType::GetWidth(right) - DecimalType::GetScale(right); - auto extra_width = - MaxValue(NumericCast<uint8_t>(extra_width_left), NumericCast<uint8_t>(extra_width_right)); - auto scale = MaxValue(DecimalType::GetScale(left), DecimalType::GetScale(right)); - auto width = NumericCast<uint8_t>(extra_width + scale); - if (width > DecimalType::MaxWidth()) { - // if the resulting decimal does not fit, we truncate the scale - width = DecimalType::MaxWidth(); - scale = NumericCast<uint8_t>(width - extra_width); - } - result = LogicalType::DECIMAL(width, scale); - return true; - } - case LogicalTypeId::LIST: { - // list: perform max recursively on child type - LogicalType new_child; - if (!OP::Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { - return false; - } - result = LogicalType::LIST(new_child); - return true; - } - case LogicalTypeId::ARRAY: { - LogicalType new_child; - if (!OP::Operation(ArrayType::GetChildType(left), ArrayType::GetChildType(right), new_child)) { - return false; - } - auto new_size = MaxValue(ArrayType::GetSize(left), ArrayType::GetSize(right)); - result = LogicalType::ARRAY(new_child, new_size); - return true; - } - case LogicalTypeId::MAP: { - // map: perform max recursively on child type - LogicalType new_child; - if (!OP::Operation(ListType::GetChildType(left), ListType::GetChildType(right), new_child)) { - return false; - } - result = LogicalType::MAP(new_child); - return true; - } - case LogicalTypeId::STRUCT: { - return CombineStructTypes<OP>(left, right, result); - } - case LogicalTypeId::UNION: { - auto left_member_count = UnionType::GetMemberCount(left); - auto right_member_count = UnionType::GetMemberCount(right); - if (left_member_count != right_member_count) { - // return the "larger" type, with the most members - result = left_member_count > right_member_count ?
left : right; - return true; - } - // otherwise, keep left, don't try to meld the two together. - result = left; - return true; - } - default: - result = left; - return true; - } -} - -template <class OP> -bool TryGetMaxLogicalTypeInternal(const LogicalType &left, const LogicalType &right, LogicalType &result) { - // we always prefer aliased types - if (!left.GetAlias().empty()) { - result = left; - return true; - } - if (!right.GetAlias().empty()) { - result = right; - return true; - } - if (left.id() != right.id()) { - return CombineUnequalTypes<OP>(left, right, result); - } else { - return CombineEqualTypes<OP>(left, right, result); - } -} - -struct TryGetTypeOperation { - static bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) { - return TryGetMaxLogicalTypeInternal<TryGetTypeOperation>(left, right, result); - } -}; - -struct ForceGetTypeOperation { - static bool Operation(const LogicalType &left, const LogicalType &right, LogicalType &result) { - result = LogicalType::ForceMaxLogicalType(left, right); - return true; - } -}; - -bool LogicalType::TryGetMaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right, - LogicalType &result) { - if (DBConfig::GetConfig(context).options.old_implicit_casting) { - result = LogicalType::ForceMaxLogicalType(left, right); - return true; - } - return TryGetMaxLogicalTypeInternal<TryGetTypeOperation>(left, right, result); -} - -static idx_t GetLogicalTypeScore(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::INVALID: - case LogicalTypeId::SQLNULL: - case LogicalTypeId::UNKNOWN: - case LogicalTypeId::ANY: - case LogicalTypeId::STRING_LITERAL: - case LogicalTypeId::INTEGER_LITERAL: - return 0; - // numerics - case LogicalTypeId::BOOLEAN: - return 10; - case LogicalTypeId::UTINYINT: - return 11; - case LogicalTypeId::TINYINT: - return 12; - case LogicalTypeId::USMALLINT: - return 13; - case LogicalTypeId::SMALLINT: - return 14; - case LogicalTypeId::UINTEGER: - return 15; - case LogicalTypeId::INTEGER: - return 16; - case LogicalTypeId::UBIGINT: - return 17; - case LogicalTypeId::BIGINT: - return 18; - case LogicalTypeId::UHUGEINT: - return 19; - case LogicalTypeId::HUGEINT: - return 20; - case LogicalTypeId::DECIMAL: - return 21; - case LogicalTypeId::FLOAT: - return 22; - case LogicalTypeId::DOUBLE: - return 23; - // date/time/timestamp - case LogicalTypeId::TIME: - case LogicalTypeId::TIME_TZ: - return 50; - case LogicalTypeId::DATE: - return 51; - case LogicalTypeId::TIMESTAMP_SEC: - return 52; - case LogicalTypeId::TIMESTAMP_MS: - return 53; - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return 54; - case LogicalTypeId::TIMESTAMP_NS: - return 55; - case LogicalTypeId::INTERVAL: - return 56; - // text/character strings - case LogicalTypeId::CHAR: - return 75; - case LogicalTypeId::VARCHAR: - return 77; - case LogicalTypeId::ENUM: - return 78; - // blob/complex types - case LogicalTypeId::BIT: - return 100; - case LogicalTypeId::BLOB: - return 101; - case LogicalTypeId::UUID: - return 102; - case LogicalTypeId::VARINT: - return 103; - // nested types - case LogicalTypeId::STRUCT: - return 125; - case LogicalTypeId::LIST: - case LogicalTypeId::ARRAY: - return 126; - case LogicalTypeId::MAP: - return 127; - case LogicalTypeId::UNION: - case LogicalTypeId::TABLE: - return 150; - // weirdo types - case LogicalTypeId::LAMBDA: - case LogicalTypeId::AGGREGATE_STATE: - case LogicalTypeId::POINTER: - case LogicalTypeId::VALIDITY: - case LogicalTypeId::USER: - break; - } - return 1000; -} - -LogicalType
LogicalType::ForceMaxLogicalType(const LogicalType &left, const LogicalType &right) { - LogicalType result; - if (TryGetMaxLogicalTypeInternal(left, right, result)) { - return result; - } - // we prefer the type with the highest score - auto left_score = GetLogicalTypeScore(left); - auto right_score = GetLogicalTypeScore(right); - if (left_score < right_score) { - return right; - } else { - return left; - } -} - -LogicalType LogicalType::MaxLogicalType(ClientContext &context, const LogicalType &left, const LogicalType &right) { - LogicalType result; - if (!TryGetMaxLogicalType(context, left, right, result)) { - throw NotImplementedException("Cannot combine types %s and %s - an explicit cast is required", left.ToString(), - right.ToString()); - } - return result; -} - -void LogicalType::Verify() const { -#ifdef DEBUG - switch (id_) { - case LogicalTypeId::DECIMAL: - D_ASSERT(DecimalType::GetWidth(*this) >= 1 && DecimalType::GetWidth(*this) <= Decimal::MAX_WIDTH_DECIMAL); - D_ASSERT(DecimalType::GetScale(*this) >= 0 && DecimalType::GetScale(*this) <= DecimalType::GetWidth(*this)); - break; - case LogicalTypeId::STRUCT: { - // verify child types - case_insensitive_set_t child_names; - bool all_empty = true; - for (auto &entry : StructType::GetChildTypes(*this)) { - if (entry.first.empty()) { - D_ASSERT(all_empty); - } else { - // check for duplicate struct names - all_empty = false; - auto existing_entry = child_names.find(entry.first); - D_ASSERT(existing_entry == child_names.end()); - child_names.insert(entry.first); - } - entry.second.Verify(); - } - break; - } - case LogicalTypeId::LIST: - ListType::GetChildType(*this).Verify(); - break; - case LogicalTypeId::MAP: { - MapType::KeyType(*this).Verify(); - MapType::ValueType(*this).Verify(); - break; - } - default: - break; - } -#endif -} - -bool ApproxEqual(float ldecimal, float rdecimal) { - if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { - return true; - } - if (!Value::FloatIsFinite(ldecimal) || !Value::FloatIsFinite(rdecimal)) { - return ldecimal == rdecimal; - } - float epsilon = static_cast(std::fabs(rdecimal) * 0.01 + 0.00000001); - return std::fabs(ldecimal - rdecimal) <= epsilon; -} - -bool ApproxEqual(double ldecimal, double rdecimal) { - if (Value::IsNan(ldecimal) && Value::IsNan(rdecimal)) { - return true; - } - if (!Value::DoubleIsFinite(ldecimal) || !Value::DoubleIsFinite(rdecimal)) { - return ldecimal == rdecimal; - } - double epsilon = std::fabs(rdecimal) * 0.01 + 0.00000001; - return std::fabs(ldecimal - rdecimal) <= epsilon; -} - -//===--------------------------------------------------------------------===// -// Extra Type Info -//===--------------------------------------------------------------------===// - -LogicalType LogicalType::DeepCopy() const { - LogicalType copy = *this; - if (type_info_) { - copy.type_info_ = type_info_->Copy(); - } - return copy; -} - -void LogicalType::SetAlias(string alias) { - if (!type_info_) { - type_info_ = make_shared_ptr(ExtraTypeInfoType::GENERIC_TYPE_INFO, std::move(alias)); - } else { - type_info_->alias = std::move(alias); - } -} - -string LogicalType::GetAlias() const { - if (id() == LogicalTypeId::USER) { - return UserType::GetTypeName(*this); - } - if (type_info_) { - return type_info_->alias; - } - return string(); -} - -bool LogicalType::HasAlias() const { - if (id() == LogicalTypeId::USER) { - return !UserType::GetTypeName(*this).empty(); - } - if (type_info_ && !type_info_->alias.empty()) { - return true; - } - return false; -} - -bool LogicalType::HasExtensionInfo() 
const { - if (type_info_ && type_info_->extension_info) { - return true; - } - return false; -} - -optional_ptr LogicalType::GetExtensionInfo() const { - if (type_info_ && type_info_->extension_info) { - return type_info_->extension_info.get(); - } - return nullptr; -} - -optional_ptr LogicalType::GetExtensionInfo() { - if (type_info_ && type_info_->extension_info) { - return type_info_->extension_info.get(); - } - return nullptr; -} - -void LogicalType::SetExtensionInfo(unique_ptr info) { - if (!type_info_) { - type_info_ = make_shared_ptr(ExtraTypeInfoType::GENERIC_TYPE_INFO); - } - type_info_->extension_info = std::move(info); -} - -//===--------------------------------------------------------------------===// -// Decimal Type -//===--------------------------------------------------------------------===// -uint8_t DecimalType::GetWidth(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().width; -} - -uint8_t DecimalType::GetScale(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().scale; -} - -uint8_t DecimalType::MaxWidth() { - return DecimalWidth::max; -} - -LogicalType LogicalType::DECIMAL(uint8_t width, uint8_t scale) { - D_ASSERT(width >= scale); - auto type_info = make_shared_ptr(width, scale); - return LogicalType(LogicalTypeId::DECIMAL, std::move(type_info)); -} - -//===--------------------------------------------------------------------===// -// String Type -//===--------------------------------------------------------------------===// -string StringType::GetCollation(const LogicalType &type) { - if (type.id() != LogicalTypeId::VARCHAR) { - return string(); - } - auto info = type.AuxInfo(); - if (!info) { - return string(); - } - if (info->type == ExtraTypeInfoType::GENERIC_TYPE_INFO) { - return string(); - } - return info->Cast().collation; -} - -LogicalType LogicalType::VARCHAR_COLLATION(string collation) { // NOLINT - auto string_info = make_shared_ptr(std::move(collation)); - return LogicalType(LogicalTypeId::VARCHAR, std::move(string_info)); -} - -//===--------------------------------------------------------------------===// -// List Type -//===--------------------------------------------------------------------===// -const LogicalType &ListType::GetChildType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::LIST || type.id() == LogicalTypeId::MAP); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().child_type; -} - -LogicalType LogicalType::LIST(const LogicalType &child) { - auto info = make_shared_ptr(child); - return LogicalType(LogicalTypeId::LIST, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Aggregate State Type -//===--------------------------------------------------------------------===// -const aggregate_state_t &AggregateStateType::GetStateType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().state_type; -} - -const string AggregateStateType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::AGGREGATE_STATE); - auto info = type.AuxInfo(); - if (!info) { - return "AGGREGATE_STATE"; - } - auto aggr_state = info->Cast().state_type; - return "AGGREGATE_STATE<" + aggr_state.function_name + "(" + - StringUtil::Join(aggr_state.bound_argument_types, 
aggr_state.bound_argument_types.size(), ", ", - [](const LogicalType &arg_type) { return arg_type.ToString(); }) + - ")" + "::" + aggr_state.return_type.ToString() + ">"; -} - -//===--------------------------------------------------------------------===// -// Struct Type -//===--------------------------------------------------------------------===// -const child_list_t &StructType::GetChildTypes(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::STRUCT || type.id() == LogicalTypeId::UNION); - - auto info = type.AuxInfo(); - D_ASSERT(info); - return info->Cast().child_types; -} - -const LogicalType &StructType::GetChildType(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - return child_types[index].second; -} - -const string &StructType::GetChildName(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - return child_types[index].first; -} - -idx_t StructType::GetChildIndexUnsafe(const LogicalType &type, const string &name) { - auto &child_types = StructType::GetChildTypes(type); - for (idx_t i = 0; i < child_types.size(); i++) { - if (StringUtil::CIEquals(child_types[i].first, name)) { - return i; - } - } - throw InternalException("Could not find child with name \"%s\" in struct type \"%s\"", name, type.ToString()); -} - -idx_t StructType::GetChildCount(const LogicalType &type) { - return StructType::GetChildTypes(type).size(); -} -bool StructType::IsUnnamed(const LogicalType &type) { - auto &child_types = StructType::GetChildTypes(type); - if (child_types.empty()) { - return false; - } - return child_types[0].first.empty(); // NOLINT -} - -LogicalType LogicalType::STRUCT(child_list_t children) { - auto info = make_shared_ptr(std::move(children)); - return LogicalType(LogicalTypeId::STRUCT, std::move(info)); -} - -LogicalType LogicalType::AGGREGATE_STATE(aggregate_state_t state_type) { // NOLINT - auto info = make_shared_ptr(std::move(state_type)); - return LogicalType(LogicalTypeId::AGGREGATE_STATE, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Map Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::MAP(const LogicalType &child_p) { - D_ASSERT(child_p.id() == LogicalTypeId::STRUCT); - auto &children = StructType::GetChildTypes(child_p); - D_ASSERT(children.size() == 2); - - // We do this to enforce that for every MAP created, the keys are called "key" - // and the values are called "value" - - // This is done because for Vector the keys of the STRUCT are used in equality checks. 
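// A hypothetical illustration (not part of the original source): both construction paths below
// normalize to the same MAP(VARCHAR, INTEGER), whose internal child is always
// STRUCT("key" VARCHAR, "value" INTEGER):
//
//   auto m1 = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::INTEGER);
//   child_list_t<LogicalType> kv;
//   kv.emplace_back("k", LogicalType::VARCHAR);  // renamed to "key" below
//   kv.emplace_back("v", LogicalType::INTEGER);  // renamed to "value" below
//   auto m2 = LogicalType::MAP(LogicalType::STRUCT(std::move(kv)));
//   // MapType::KeyType(m1) == MapType::KeyType(m2) == LogicalType::VARCHAR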
- // Vector::Reference will throw if the types don't match - child_list_t new_children(2); - new_children[0] = children[0]; - new_children[0].first = "key"; - - new_children[1] = children[1]; - new_children[1].first = "value"; - - auto child = LogicalType::STRUCT(std::move(new_children)); - auto info = make_shared_ptr(child); - return LogicalType(LogicalTypeId::MAP, std::move(info)); -} - -LogicalType LogicalType::MAP(LogicalType key, LogicalType value) { - child_list_t child_types; - child_types.emplace_back("key", std::move(key)); - child_types.emplace_back("value", std::move(value)); - return LogicalType::MAP(LogicalType::STRUCT(child_types)); -} - -const LogicalType &MapType::KeyType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::MAP); - return StructType::GetChildTypes(ListType::GetChildType(type))[0].second; -} - -const LogicalType &MapType::ValueType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::MAP); - return StructType::GetChildTypes(ListType::GetChildType(type))[1].second; -} - -//===--------------------------------------------------------------------===// -// Union Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::UNION(child_list_t members) { - D_ASSERT(!members.empty()); - D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); - // union types always have a hidden "tag" field in front - members.insert(members.begin(), {"", LogicalType::UTINYINT}); - auto info = make_shared_ptr(std::move(members)); - return LogicalType(LogicalTypeId::UNION, std::move(info)); -} - -const LogicalType &UnionType::GetMemberType(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - // skip the "tag" field - return child_types[index + 1].second; -} - -const string &UnionType::GetMemberName(const LogicalType &type, idx_t index) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(index < child_types.size()); - // skip the "tag" field - return child_types[index + 1].first; -} - -idx_t UnionType::GetMemberCount(const LogicalType &type) { - // don't count the "tag" field - return StructType::GetChildTypes(type).size() - 1; -} -const child_list_t UnionType::CopyMemberTypes(const LogicalType &type) { - auto child_types = StructType::GetChildTypes(type); - child_types.erase(child_types.begin()); - return child_types; -} - -//===--------------------------------------------------------------------===// -// User Type -//===--------------------------------------------------------------------===// -const string &UserType::GetCatalog(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.AuxInfo(); - return info->Cast().catalog; -} - -const string &UserType::GetSchema(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.AuxInfo(); - return info->Cast().schema; -} - -const string &UserType::GetTypeName(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.AuxInfo(); - return info->Cast().user_type_name; -} - -const vector &UserType::GetTypeModifiers(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.AuxInfo(); - return info->Cast().user_type_modifiers; -} - -vector &UserType::GetTypeModifiers(LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::USER); - auto info = type.GetAuxInfoShrPtr(); - return info->Cast().user_type_modifiers; -} - -LogicalType 
LogicalType::USER(const string &user_type_name) { - auto info = make_shared_ptr(user_type_name); - return LogicalType(LogicalTypeId::USER, std::move(info)); -} - -LogicalType LogicalType::USER(const string &user_type_name, const vector &user_type_mods) { - auto info = make_shared_ptr(user_type_name, user_type_mods); - return LogicalType(LogicalTypeId::USER, std::move(info)); -} - -LogicalType LogicalType::USER(string catalog, string schema, string name, vector user_type_mods) { - auto info = make_shared_ptr(std::move(catalog), std::move(schema), std::move(name), - std::move(user_type_mods)); - return LogicalType(LogicalTypeId::USER, std::move(info)); -} - -//===--------------------------------------------------------------------===// -// Enum Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::ENUM(Vector &ordered_data, idx_t size) { - return EnumTypeInfo::CreateType(ordered_data, size); -} - -LogicalType LogicalType::ENUM(const string &enum_name, Vector &ordered_data, idx_t size) { - return LogicalType::ENUM(ordered_data, size); -} - -const string EnumType::GetValue(const Value &val) { - auto info = val.type().AuxInfo(); - auto &values_insert_order = info->Cast().GetValuesInsertOrder(); - return StringValue::Get(values_insert_order.GetValue(val.GetValue())); -} - -const Vector &EnumType::GetValuesInsertOrder(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto info = type.AuxInfo(); - return info->Cast().GetValuesInsertOrder(); -} - -idx_t EnumType::GetSize(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto info = type.AuxInfo(); - return info->Cast().GetDictSize(); -} - -PhysicalType EnumType::GetPhysicalType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ENUM); - auto aux_info = type.AuxInfo(); - auto &info = aux_info->Cast(); - D_ASSERT(info.GetEnumDictType() == EnumDictType::VECTOR_DICT); - return EnumTypeInfo::DictType(info.GetDictSize()); -} - -//===--------------------------------------------------------------------===// -// JSON Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::JSON() { - auto json_type = LogicalType(LogicalTypeId::VARCHAR); - json_type.SetAlias(JSON_TYPE_NAME); - return json_type; -} - -bool LogicalType::IsJSONType() const { - return id() == LogicalTypeId::VARCHAR && HasAlias() && GetAlias() == JSON_TYPE_NAME; -} - -//===--------------------------------------------------------------------===// -// Array Type -//===--------------------------------------------------------------------===// - -const LogicalType &ArrayType::GetChildType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ARRAY); - auto info = type.AuxInfo(); - return info->Cast().child_type; -} - -idx_t ArrayType::GetSize(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ARRAY); - auto info = type.AuxInfo(); - return info->Cast().size; -} - -bool ArrayType::IsAnySize(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ARRAY); - auto info = type.AuxInfo(); - return info->Cast().size == 0; -} - -LogicalType ArrayType::ConvertToList(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::ARRAY: { - return LogicalType::LIST(ConvertToList(ArrayType::GetChildType(type))); - } - case LogicalTypeId::LIST: - return LogicalType::LIST(ConvertToList(ListType::GetChildType(type))); - case LogicalTypeId::STRUCT: { - auto children = StructType::GetChildTypes(type); 
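// A hypothetical illustration (not part of the original source): ConvertToList rewrites every
// ARRAY into the corresponding LIST, recursing through nested types, e.g.
//   ARRAY(INTEGER, 3)           -> LIST(INTEGER)
//   ARRAY(ARRAY(INTEGER, 3), 2) -> LIST(LIST(INTEGER))
// The loop below applies the same rewrite to each struct child in place: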
- for (auto &child : children) { - child.second = ConvertToList(child.second); - } - return LogicalType::STRUCT(children); - } - case LogicalTypeId::MAP: { - auto key_type = ConvertToList(MapType::KeyType(type)); - auto value_type = ConvertToList(MapType::ValueType(type)); - return LogicalType::MAP(key_type, value_type); - } - case LogicalTypeId::UNION: { - auto children = UnionType::CopyMemberTypes(type); - for (auto &child : children) { - child.second = ConvertToList(child.second); - } - return LogicalType::UNION(children); - } - default: - return type; - } -} - -LogicalType LogicalType::ARRAY(const LogicalType &child, optional_idx size) { - if (!size.IsValid()) { - // Create an incomplete ARRAY type, used for binding - auto info = make_shared_ptr(child, 0); - return LogicalType(LogicalTypeId::ARRAY, std::move(info)); - } else { - auto array_size = size.GetIndex(); - D_ASSERT(array_size > 0); - D_ASSERT(array_size <= ArrayType::MAX_ARRAY_SIZE); - auto info = make_shared_ptr(child, array_size); - return LogicalType(LogicalTypeId::ARRAY, std::move(info)); - } -} - -//===--------------------------------------------------------------------===// -// Any Type -//===--------------------------------------------------------------------===// -LogicalType LogicalType::ANY_PARAMS(LogicalType target, idx_t cast_score) { // NOLINT - auto type_info = make_shared_ptr(std::move(target), cast_score); - return LogicalType(LogicalTypeId::ANY, std::move(type_info)); -} - -LogicalType AnyType::GetTargetType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ANY); - auto info = type.AuxInfo(); - if (!info) { - return LogicalType::ANY; - } - return info->Cast().target_type; -} - -idx_t AnyType::GetCastScore(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::ANY); - auto info = type.AuxInfo(); - if (!info) { - return 5; - } - return info->Cast().cast_score; -} - -//===--------------------------------------------------------------------===// -// Integer Literal Type -//===--------------------------------------------------------------------===// -LogicalType IntegerLiteral::GetType(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::INTEGER_LITERAL); - auto info = type.AuxInfo(); - D_ASSERT(info->type == ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO); - return info->Cast().constant_value.type(); -} - -bool IntegerLiteral::FitsInType(const LogicalType &type, const LogicalType &target) { - D_ASSERT(type.id() == LogicalTypeId::INTEGER_LITERAL); - // we can always cast integer literals to float and double - if (target.id() == LogicalTypeId::FLOAT || target.id() == LogicalTypeId::DOUBLE) { - return true; - } - if (!target.IsIntegral()) { - return false; - } - // we can cast to integral types if the constant value fits within that type - auto info = type.AuxInfo(); - D_ASSERT(info->type == ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO); - auto &literal_info = info->Cast(); - Value copy = literal_info.constant_value; - return copy.DefaultTryCastAs(target); -} - -LogicalType LogicalType::INTEGER_LITERAL(const Value &constant) { // NOLINT - if (!constant.type().IsIntegral()) { - throw InternalException("INTEGER_LITERAL can only be made from literals of integer types"); - } - auto type_info = make_shared_ptr(constant); - return LogicalType(LogicalTypeId::INTEGER_LITERAL, std::move(type_info)); -} - -//===--------------------------------------------------------------------===// -// Logical Type -//===--------------------------------------------------------------------===// - -// the 
destructor needs to know about the extra type info -LogicalType::~LogicalType() { -} - -bool LogicalType::EqualTypeInfo(const LogicalType &rhs) const { - if (type_info_.get() == rhs.type_info_.get()) { - return true; - } - if (type_info_) { - return type_info_->Equals(rhs.type_info_.get()); - } else { - D_ASSERT(rhs.type_info_); - return rhs.type_info_->Equals(type_info_.get()); - } -} - -bool LogicalType::operator==(const LogicalType &rhs) const { - if (id_ != rhs.id_) { - return false; - } - return EqualTypeInfo(rhs); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/batched_data_collection.cpp b/src/duckdb/src/common/types/batched_data_collection.cpp deleted file mode 100644 index fd25dbc1a..000000000 --- a/src/duckdb/src/common/types/batched_data_collection.cpp +++ /dev/null @@ -1,178 +0,0 @@ -#include "duckdb/common/types/batched_data_collection.hpp" - -#include "duckdb/common/optional_ptr.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p) { -} - -BatchedDataCollection::BatchedDataCollection(ClientContext &context_p, vector types_p, batch_map_t batches, - bool buffer_managed_p) - : context(context_p), types(std::move(types_p)), buffer_managed(buffer_managed_p), data(std::move(batches)) { -} - -void BatchedDataCollection::Append(DataChunk &input, idx_t batch_index) { - D_ASSERT(batch_index != DConstants::INVALID_INDEX); - optional_ptr collection; - if (last_collection.collection && last_collection.batch_index == batch_index) { - // we are inserting into the same collection as before: use it directly - collection = last_collection.collection; - } else { - // new collection: check if there is already an entry - D_ASSERT(data.find(batch_index) == data.end()); - unique_ptr new_collection; - if (last_collection.collection) { - new_collection = make_uniq(*last_collection.collection); - } else if (buffer_managed) { - new_collection = make_uniq(BufferManager::GetBufferManager(context), types); - } else { - new_collection = make_uniq(Allocator::DefaultAllocator(), types); - } - last_collection.collection = new_collection.get(); - last_collection.batch_index = batch_index; - new_collection->InitializeAppend(last_collection.append_state); - collection = new_collection.get(); - data.insert(make_pair(batch_index, std::move(new_collection))); - } - collection->Append(last_collection.append_state, input); -} - -void BatchedDataCollection::Merge(BatchedDataCollection &other) { - for (auto &entry : other.data) { - if (data.find(entry.first) != data.end()) { - throw InternalException( - "BatchedDataCollection::Merge error - batch index %d is present in both collections. 
This occurs when " - "batch indexes are not uniquely distributed over threads", - entry.first); - } - data[entry.first] = std::move(entry.second); - } - other.data.clear(); -} - -void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state, const BatchedChunkIteratorRange &range) { - state.range = range; - if (state.range.begin == state.range.end) { - return; - } - state.range.begin->second->InitializeScan(state.scan_state); -} - -void BatchedDataCollection::InitializeScan(BatchedChunkScanState &state) { - auto range = BatchRange(); - return InitializeScan(state, range); -} - -void BatchedDataCollection::Scan(BatchedChunkScanState &state, DataChunk &output) { - while (state.range.begin != state.range.end) { - // check if there is a chunk remaining in this collection - auto collection = state.range.begin->second.get(); - collection->Scan(state.scan_state, output); - if (output.size() > 0) { - return; - } - // there isn't! move to the next collection - state.range.begin->second.reset(); - state.range.begin++; - if (state.range.begin == state.range.end) { - return; - } - state.range.begin->second->InitializeScan(state.scan_state); - } -} - -unique_ptr BatchedDataCollection::FetchCollection() { - unique_ptr result; - for (auto &entry : data) { - if (!result) { - result = std::move(entry.second); - } else { - result->Combine(*entry.second); - } - } - data.clear(); - if (!result) { - // empty result - return make_uniq(Allocator::DefaultAllocator(), types); - } - return result; -} - -const vector &BatchedDataCollection::Types() const { - return types; -} - -idx_t BatchedDataCollection::Count() const { - idx_t count = 0; - for (auto &collection : data) { - count += collection.second->Count(); - } - return count; -} - -idx_t BatchedDataCollection::BatchCount() const { - return data.size(); -} - -idx_t BatchedDataCollection::IndexToBatchIndex(idx_t index) const { - if (index >= data.size()) { - throw InternalException("Index %d is out of range for this collection, it only contains %d batches", index, - data.size()); - } - auto entry = data.begin(); - std::advance(entry, index); - return entry->first; -} - -idx_t BatchedDataCollection::BatchSize(idx_t batch_index) const { - auto &collection = Batch(batch_index); - return collection.Count(); -} - -const ColumnDataCollection &BatchedDataCollection::Batch(idx_t batch_index) const { - auto entry = data.find(batch_index); - if (entry == data.end()) { - throw InternalException("This batched data collection does not contain a collection for batch_index %d", - batch_index); - } - return *entry->second; -} - -BatchedChunkIteratorRange BatchedDataCollection::BatchRange(idx_t begin_idx, idx_t end_idx) { - D_ASSERT(begin_idx < end_idx); - if (end_idx > data.size()) { - // Limit the iterator to the end - end_idx = DConstants::INVALID_INDEX; - } - BatchedChunkIteratorRange range; - range.begin = data.begin(); - std::advance(range.begin, begin_idx); - if (end_idx == DConstants::INVALID_INDEX) { - range.end = data.end(); - } else { - range.end = data.begin(); - std::advance(range.end, end_idx); - } - return range; -} - -string BatchedDataCollection::ToString() const { - string result; - result += "Batched Data Collection\n"; - for (auto &entry : data) { - result += "Batch Index - " + to_string(entry.first) + "\n"; - result += entry.second->ToString() + "\n\n"; - } - return result; -} - -void BatchedDataCollection::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/bit.cpp 
b/src/duckdb/src/common/types/bit.cpp deleted file mode 100644 index 5006d64f2..000000000 --- a/src/duckdb/src/common/types/bit.cpp +++ /dev/null @@ -1,434 +0,0 @@ -#include "duckdb/common/assert.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/typedefs.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/string_type.hpp" - -namespace duckdb { - -// **** helper functions **** -static char ComputePadding(idx_t len) { - return UnsafeNumericCast((8 - (len % 8)) % 8); -} - -idx_t Bit::ComputeBitstringLen(idx_t len) { - idx_t result = len / 8; - if (len % 8 != 0) { - result++; - } - // additional first byte to store info on zero padding - result++; - return result; -} - -static inline idx_t GetBitPadding(const bitstring_t &bit_string) { - auto data = const_data_ptr_cast(bit_string.GetData()); - D_ASSERT(idx_t(data[0]) <= 8); - return data[0]; -} - -static inline idx_t GetBitSize(const string_t &str) { - string error_message; - idx_t str_len; - if (!Bit::TryGetBitStringSize(str, str_len, &error_message)) { - throw ConversionException(error_message); - } - return str_len; -} - -uint8_t Bit::GetFirstByte(const bitstring_t &str) { - D_ASSERT(str.GetSize() > 1); - - auto data = const_data_ptr_cast(str.GetData()); - return data[1] & ((1 << (8 - data[0])) - 1); -} - -void Bit::Finalize(bitstring_t &str) { - // bit strings require all padding bits to be set to 1 - // this method sets all padding bits to 1 - auto padding = GetBitPadding(str); - for (idx_t i = 0; i < idx_t(padding); i++) { - Bit::SetBitInternal(str, i, 1); - } - str.Finalize(); - Bit::Verify(str); -} - -void Bit::SetEmptyBitString(bitstring_t &target, string_t &input) { - char *res_buf = target.GetDataWriteable(); - const char *buf = input.GetData(); - memset(res_buf, 0, input.GetSize()); - res_buf[0] = buf[0]; - Bit::Finalize(target); -} - -void Bit::SetEmptyBitString(bitstring_t &target, idx_t len) { - char *res_buf = target.GetDataWriteable(); - memset(res_buf, 0, target.GetSize()); - res_buf[0] = ComputePadding(len); - Bit::Finalize(target); -} - -// **** casting functions **** -void Bit::ToString(bitstring_t bits, char *output) { - auto data = const_data_ptr_cast(bits.GetData()); - auto len = bits.GetSize(); - - idx_t padding = GetBitPadding(bits); - idx_t output_idx = 0; - for (idx_t bit_idx = padding; bit_idx < 8; bit_idx++) { - output[output_idx++] = data[1] & (1 << (7 - bit_idx)) ? '1' : '0'; - } - for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - output[output_idx++] = data[byte_idx] & (1 << (7 - bit_idx)) ? 
'1' : '0'; - } - } -} - -string Bit::ToString(bitstring_t str) { - auto len = BitLength(str); - auto buffer = make_unsafe_uniq_array_uninitialized(len); - ToString(str, buffer.get()); - return string(buffer.get(), len); -} - -bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_message) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '0' || data[i] == '1') { - str_len++; - } else { - string error = StringUtil::Format("Invalid character encountered in string -> bit conversion: '%s'", - string(const_char_ptr_cast(data) + i, 1)); - HandleCastError::AssignError(error, error_message); - return false; - } - } - if (str_len == 0) { - string error = "Cannot cast empty string to BIT"; - HandleCastError::AssignError(error, error_message); - return false; - } - str_len = ComputeBitstringLen(str_len); - return true; -} - -void Bit::ToBit(string_t str, bitstring_t &output_str) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - auto output = output_str.GetDataWriteable(); - - char byte = 0; - idx_t padded_byte = len % 8; - for (idx_t i = 0; i < padded_byte; i++) { - byte <<= 1; - if (data[i] == '1') { - byte |= 1; - } - } - if (padded_byte != 0) { - *(output++) = UnsafeNumericCast((8 - padded_byte)); // the first byte contains the number of padded zeroes - } - *(output++) = byte; - - for (idx_t byte_idx = padded_byte; byte_idx < len; byte_idx += 8) { - byte = 0; - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - byte <<= 1; - if (data[byte_idx + bit_idx] == '1') { - byte |= 1; - } - } - *(output++) = byte; - } - Bit::Finalize(output_str); -} - -string Bit::ToBit(string_t str) { - auto bit_len = GetBitSize(str); - auto buffer = make_unsafe_uniq_array_uninitialized(bit_len); - bitstring_t output_str(buffer.get(), UnsafeNumericCast(bit_len)); - Bit::ToBit(str, output_str); - return output_str.GetString(); -} - -void Bit::BlobToBit(string_t blob, bitstring_t &output_str) { - auto data = const_data_ptr_cast(blob.GetData()); - auto output = output_str.GetDataWriteable(); - idx_t size = blob.GetSize(); - - *output = 0; // No padding - memcpy(output + 1, data, size); -} - -string Bit::BlobToBit(string_t blob) { - auto buffer = make_unsafe_uniq_array_uninitialized(blob.GetSize() + 1); - bitstring_t output_str(buffer.get(), UnsafeNumericCast(blob.GetSize() + 1)); - Bit::BlobToBit(blob, output_str); - return output_str.GetString(); -} - -void Bit::BitToBlob(bitstring_t bit, string_t &output_blob) { - D_ASSERT(bit.GetSize() == output_blob.GetSize() + 1); - - auto data = const_data_ptr_cast(bit.GetData()); - auto output = output_blob.GetDataWriteable(); - idx_t size = output_blob.GetSize(); - - output[0] = UnsafeNumericCast(GetFirstByte(bit)); - if (size >= 2) { - ++output; - // First byte in bitstring contains amount of padded bits, - // second byte in bitstring is the padded byte, - // therefore the rest of the data starts at data + 2 (third byte) - memcpy(output, data + 2, size - 1); - } -} - -string Bit::BitToBlob(bitstring_t bit) { - D_ASSERT(bit.GetSize() > 1); - - auto buffer = make_unsafe_uniq_array_uninitialized(bit.GetSize() - 1); - string_t output_str(buffer.get(), UnsafeNumericCast(bit.GetSize() - 1)); - Bit::BitToBlob(bit, output_str); - return output_str.GetString(); -} - -// **** scalar functions **** -void Bit::BitString(const string_t &input, idx_t bit_length, bitstring_t &result) { - char *res_buf = result.GetDataWriteable(); - const char *buf = 
input.GetData(); - - auto padding = ComputePadding(bit_length); - res_buf[0] = padding; - auto padding_len = UnsafeNumericCast(padding); - for (idx_t i = 0; i < bit_length; i++) { - if (i < bit_length - input.GetSize()) { - Bit::SetBitInternal(result, i + padding_len, 0); - } else { - idx_t bit = buf[i - (bit_length - input.GetSize())] == '1' ? 1 : 0; - Bit::SetBitInternal(result, i + padding_len, bit); - } - } - Bit::Finalize(result); -} - -void Bit::ExtendBitString(const bitstring_t &input, idx_t bit_length, bitstring_t &result) { - uint8_t *res_buf = reinterpret_cast(result.GetDataWriteable()); - - auto padding = ComputePadding(bit_length); - res_buf[0] = static_cast(padding); - - idx_t original_length = Bit::BitLength(input); - D_ASSERT(bit_length >= original_length); - idx_t shift = bit_length - original_length; - for (idx_t i = 0; i < bit_length; i++) { - if (i < shift) { - Bit::SetBit(result, i, 0); - } else { - idx_t bit = Bit::GetBit(input, i - shift); - Bit::SetBit(result, i, bit); - } - } - Bit::Finalize(result); -} - -idx_t Bit::BitLength(bitstring_t bits) { - return ((bits.GetSize() - 1) * 8) - GetBitPadding(bits); -} - -idx_t Bit::OctetLength(bitstring_t bits) { - return bits.GetSize() - 1; -} - -idx_t Bit::BitCount(bitstring_t bits) { - idx_t count = 0; - const char *buf = bits.GetData(); - for (idx_t byte_idx = 1; byte_idx < OctetLength(bits) + 1; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - count += (buf[byte_idx] & (1 << bit_idx)) ? 1 : 0; - } - } - return count - GetBitPadding(bits); -} - -idx_t Bit::BitPosition(bitstring_t substring, bitstring_t bits) { - const char *buf = bits.GetData(); - auto len = bits.GetSize(); - auto substr_len = BitLength(substring); - idx_t substr_idx = 0; - - for (idx_t bit_idx = GetBitPadding(bits); bit_idx < 8; bit_idx++) { - idx_t bit = buf[1] & (1 << (7 - bit_idx)) ? 1 : 0; - if (bit == GetBit(substring, substr_idx)) { - substr_idx++; - if (substr_idx == substr_len) { - return (bit_idx - GetBitPadding(bits)) - substr_len + 2; - } - } else { - substr_idx = 0; - } - } - - for (idx_t byte_idx = 2; byte_idx < len; byte_idx++) { - for (idx_t bit_idx = 0; bit_idx < 8; bit_idx++) { - idx_t bit = buf[byte_idx] & (1 << (7 - bit_idx)) ? 1 : 0; - if (bit == GetBit(substring, substr_idx)) { - substr_idx++; - if (substr_idx == substr_len) { - return (((byte_idx - 1) * 8) + bit_idx - GetBitPadding(bits)) - substr_len + 2; - } - } else { - substr_idx = 0; - } - } - } - return 0; -} - -idx_t Bit::GetBit(bitstring_t bit_string, idx_t n) { - return Bit::GetBitInternal(bit_string, n + GetBitPadding(bit_string)); -} - -idx_t Bit::GetBitIndex(idx_t n) { - return n / 8 + 1; -} - -idx_t Bit::GetBitInternal(bitstring_t bit_string, idx_t n) { - const char *buf = bit_string.GetData(); - auto idx = Bit::GetBitIndex(n); - D_ASSERT(idx < bit_string.GetSize()); - auto byte = buf[idx] >> (7 - (n % 8)); - return (byte & 1 ? 
1 : 0); -} - -void Bit::SetBit(bitstring_t &bit_string, idx_t n, idx_t new_value) { - SetBitInternal(bit_string, n + GetBitPadding(bit_string), new_value); - Bit::Finalize(bit_string); -} - -void Bit::SetBitInternal(bitstring_t &bit_string, idx_t n, idx_t new_value) { - uint8_t *buf = reinterpret_cast(bit_string.GetDataWriteable()); - - auto idx = Bit::GetBitIndex(n); - D_ASSERT(idx < bit_string.GetSize()); - auto shift_byte = UnsafeNumericCast(1 << (7 - (n % 8))); - if (new_value == 0) { - shift_byte = ~shift_byte; - buf[idx] &= shift_byte; - } else { - buf[idx] |= shift_byte; - } -} - -// **** BITWISE operators **** -void Bit::RightShift(const bitstring_t &bit_string, idx_t shift, bitstring_t &result) { - uint8_t *res_buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *buf = reinterpret_cast(bit_string.GetData()); - - res_buf[0] = buf[0]; - auto padding = GetBitPadding(result); - for (idx_t i = 0; i < Bit::BitLength(result); i++) { - if (i < shift) { - Bit::SetBitInternal(result, i + padding, 0); - } else { - idx_t bit = Bit::GetBit(bit_string, i - shift); - Bit::SetBitInternal(result, i + padding, bit); - } - } - Bit::Finalize(result); -} - -void Bit::LeftShift(const bitstring_t &bit_string, idx_t shift, bitstring_t &result) { - uint8_t *res_buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *buf = reinterpret_cast(bit_string.GetData()); - - res_buf[0] = buf[0]; - auto padding = GetBitPadding(result); - for (idx_t i = 0; i < Bit::BitLength(bit_string); i++) { - if (i < (Bit::BitLength(bit_string) - shift)) { - idx_t bit = Bit::GetBit(bit_string, shift + i); - Bit::SetBitInternal(result, i + padding, bit); - } else { - Bit::SetBitInternal(result, i + padding, 0); - } - } - Bit::Finalize(result); -} - -void Bit::BitwiseAnd(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot AND bit strings of different sizes"); - } - - uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); - const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] & r_buf[i]; - } - Bit::Finalize(result); -} - -void Bit::BitwiseOr(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot OR bit strings of different sizes"); - } - - uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); - const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] | r_buf[i]; - } - Bit::Finalize(result); -} - -void Bit::BitwiseXor(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) { - if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) { - throw InvalidInputException("Cannot XOR bit strings of different sizes"); - } - - uint8_t *buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *r_buf = reinterpret_cast(rhs.GetData()); - const uint8_t *l_buf = reinterpret_cast(lhs.GetData()); - - buf[0] = l_buf[0]; - for (idx_t i = 1; i < lhs.GetSize(); i++) { - buf[i] = l_buf[i] ^ r_buf[i]; - } - Bit::Finalize(result); -} - -void Bit::BitwiseNot(const bitstring_t &input, bitstring_t &result) { - uint8_t *result_buf = reinterpret_cast(result.GetDataWriteable()); - const uint8_t *buf = 
reinterpret_cast(input.GetData()); - - result_buf[0] = buf[0]; - for (idx_t i = 1; i < input.GetSize(); i++) { - result_buf[i] = ~buf[i]; - } - Bit::Finalize(result); -} - -void Bit::Verify(const bitstring_t &input) { -#ifdef DEBUG - // bit strings require all padding bits to be set to 1 - auto padding = GetBitPadding(input); - for (idx_t i = 0; i < padding; i++) { - D_ASSERT(Bit::GetBitInternal(input, i)); - } - // verify bit respects the "normal" string_t rules (i.e. null padding for inlined strings, prefix matches) - input.VerifyCharacters(); -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/blob.cpp b/src/duckdb/src/common/types/blob.cpp deleted file mode 100644 index 11cd47b7e..000000000 --- a/src/duckdb/src/common/types/blob.cpp +++ /dev/null @@ -1,281 +0,0 @@ -#include "duckdb/common/types/blob.hpp" - -#include "duckdb/common/assert.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/string_type.hpp" - -namespace duckdb { - -constexpr const char *Blob::HEX_TABLE; -const int Blob::HEX_MAP[256] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, - -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; - -bool IsRegularCharacter(data_t c) { - return c >= 32 && c <= 126 && c != '\\' && c != '\'' && c != '"'; -} - -idx_t Blob::GetStringSize(string_t blob) { - auto data = const_data_ptr_cast(blob.GetData()); - auto len = blob.GetSize(); - idx_t str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (IsRegularCharacter(data[i])) { - // ascii characters are rendered as-is - str_len++; - } else { - // non-ascii characters are rendered as hexadecimal (e.g. \x00) - str_len += 4; - } - } - return str_len; -} - -void Blob::ToString(string_t blob, char *output) { - auto data = const_data_ptr_cast(blob.GetData()); - auto len = blob.GetSize(); - idx_t str_idx = 0; - for (idx_t i = 0; i < len; i++) { - if (IsRegularCharacter(data[i])) { - // ascii characters are rendered as-is - output[str_idx++] = UnsafeNumericCast(data[i]); - } else { - auto byte_a = data[i] >> 4; - auto byte_b = data[i] & 0x0F; - D_ASSERT(byte_a >= 0 && byte_a < 16); - D_ASSERT(byte_b >= 0 && byte_b < 16); - // non-ascii characters are rendered as hexadecimal (e.g. 
\x00) - output[str_idx++] = '\\'; - output[str_idx++] = 'x'; - output[str_idx++] = Blob::HEX_TABLE[byte_a]; - output[str_idx++] = Blob::HEX_TABLE[byte_b]; - } - } - D_ASSERT(str_idx == GetStringSize(blob)); -} - -string Blob::ToString(string_t blob) { - auto str_len = GetStringSize(blob); - auto buffer = make_unsafe_uniq_array_uninitialized(str_len); - Blob::ToString(blob, buffer.get()); - return string(buffer.get(), str_len); -} - -bool Blob::TryGetBlobSize(string_t str, idx_t &str_len, CastParameters ¶meters) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - str_len = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '\\') { - if (i + 3 >= len) { - string error = StringUtil::Format("Invalid hex escape code encountered in string -> blob conversion of " - "string \"%s\": unterminated escape code at end of blob", - str.GetString()); - HandleCastError::AssignError(error, parameters); - return false; - } - if (data[i + 1] != 'x' || Blob::HEX_MAP[data[i + 2]] < 0 || Blob::HEX_MAP[data[i + 3]] < 0) { - string error = StringUtil::Format( - "Invalid hex escape code encountered in string -> blob conversion of string \"%s\": %s", - str.GetString(), string(const_char_ptr_cast(data) + i, 4)); - HandleCastError::AssignError(error, parameters); - return false; - } - str_len++; - i += 3; - } else if (data[i] <= 127) { - str_len++; - } else { - string error = StringUtil::Format( - "Invalid byte encountered in STRING -> BLOB conversion of string \"%s\". All non-ascii characters " - "must be escaped with hex codes (e.g. \\xAA)", - str.GetString()); - HandleCastError::AssignError(error, parameters); - return false; - } - } - return true; -} - -idx_t Blob::GetBlobSize(string_t str) { - CastParameters parameters; - return GetBlobSize(str, parameters); -} - -idx_t Blob::GetBlobSize(string_t str, CastParameters ¶meters) { - idx_t str_len; - auto result = Blob::TryGetBlobSize(str, str_len, parameters); - if (!result) { - throw InternalException("Blob::TryGetBlobSize failed but no exception was thrown!?"); - } - return str_len; -} - -void Blob::ToBlob(string_t str, data_ptr_t output) { - auto data = const_data_ptr_cast(str.GetData()); - auto len = str.GetSize(); - idx_t blob_idx = 0; - for (idx_t i = 0; i < len; i++) { - if (data[i] == '\\') { - int byte_a = Blob::HEX_MAP[data[i + 2]]; - int byte_b = Blob::HEX_MAP[data[i + 3]]; - D_ASSERT(i + 3 < len); - D_ASSERT(byte_a >= 0 && byte_b >= 0); - D_ASSERT(data[i + 1] == 'x'); - output[blob_idx++] = UnsafeNumericCast((byte_a << 4) + byte_b); - i += 3; - } else if (data[i] <= 127) { - output[blob_idx++] = data_t(data[i]); - } else { - throw ConversionException("Invalid byte encountered in STRING -> BLOB conversion. All non-ascii characters " - "must be escaped with hex codes (e.g. 
\\xAA)"); - } - } - D_ASSERT(blob_idx == GetBlobSize(str)); -} - -string Blob::ToBlob(string_t str) { - CastParameters parameters; - return Blob::ToBlob(str, parameters); -} - -string Blob::ToBlob(string_t str, CastParameters ¶meters) { - auto blob_len = GetBlobSize(str, parameters); - auto buffer = make_unsafe_uniq_array_uninitialized(blob_len); - Blob::ToBlob(str, data_ptr_cast(buffer.get())); - return string(buffer.get(), blob_len); -} - -// base64 functions are adapted from https://gist.github.com/tomykaira/f0fd86b6c73063283afe550bc5d77594 -idx_t Blob::ToBase64Size(string_t blob) { - // every 4 characters in base64 encode 3 bytes, plus (potential) padding at the end - auto input_size = blob.GetSize(); - return ((input_size + 2) / 3) * 4; -} - -void Blob::ToBase64(string_t blob, char *output) { - auto input_data = const_data_ptr_cast(blob.GetData()); - auto input_size = blob.GetSize(); - idx_t out_idx = 0; - idx_t i; - // convert the bulk of the string to base64 - // this happens in steps of 3 bytes -> 4 output bytes - for (i = 0; i + 2 < input_size; i += 3) { - output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2) | ((input_data[i + 2] & 0xC0) >> 6)]; - output[out_idx++] = Blob::BASE64_MAP[input_data[i + 2] & 0x3F]; - } - - if (i < input_size) { - // there are one or two bytes left over: we have to insert padding - // first write the first 6 bits of the first byte - output[out_idx++] = Blob::BASE64_MAP[(input_data[i] >> 2) & 0x3F]; - // now check the character count - if (i == input_size - 1) { - // single byte left over: convert the remainder of that byte and insert padding - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4)]; - output[out_idx++] = Blob::BASE64_PADDING; - } else { - // two bytes left over: convert the second byte as well - output[out_idx++] = Blob::BASE64_MAP[((input_data[i] & 0x3) << 4) | ((input_data[i + 1] & 0xF0) >> 4)]; - output[out_idx++] = Blob::BASE64_MAP[((input_data[i + 1] & 0xF) << 2)]; - } - output[out_idx++] = Blob::BASE64_PADDING; - } -} - -static constexpr int BASE64_DECODING_TABLE[256] = { - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 62, -1, -1, -1, 63, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, - -1, -1, -1, -1, -1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, - 22, 23, 24, 25, -1, -1, -1, -1, -1, -1, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, - 45, 46, 47, 48, 49, 50, 51, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; - -idx_t Blob::FromBase64Size(string_t str) { - auto input_data = str.GetData(); - auto input_size = str.GetSize(); - if (input_size % 4 != 0) { - // valid base64 needs to always be cleanly divisible by 4 - throw ConversionException("Could not decode 
string \"%s\" as base64: length must be a multiple of 4", - str.GetString()); - } - if (input_size < 4) { - // empty string - return 0; - } - auto base_size = input_size / 4 * 3; - // check for padding to figure out the length - if (input_data[input_size - 2] == Blob::BASE64_PADDING) { - // two bytes of padding - return base_size - 2; - } - if (input_data[input_size - 1] == Blob::BASE64_PADDING) { - // one byte of padding - return base_size - 1; - } - // no padding - return base_size; -} - -template -uint32_t DecodeBase64Bytes(const string_t &str, const_data_ptr_t input_data, idx_t base_idx) { - int decoded_bytes[4]; - for (idx_t decode_idx = 0; decode_idx < 4; decode_idx++) { - if (ALLOW_PADDING && decode_idx >= 2 && input_data[base_idx + decode_idx] == Blob::BASE64_PADDING) { - // the last two bytes of a base64 string can have padding: in this case we set the byte to 0 - decoded_bytes[decode_idx] = 0; - } else { - decoded_bytes[decode_idx] = BASE64_DECODING_TABLE[input_data[base_idx + decode_idx]]; - } - if (decoded_bytes[decode_idx] < 0) { - throw ConversionException( - "Could not decode string \"%s\" as base64: invalid byte value '%d' at position %d", str.GetString(), - input_data[base_idx + decode_idx], base_idx + decode_idx); - } - } - return UnsafeNumericCast((decoded_bytes[0] << 3 * 6) + (decoded_bytes[1] << 2 * 6) + - (decoded_bytes[2] << 1 * 6) + (decoded_bytes[3] << 0 * 6)); -} - -void Blob::FromBase64(string_t str, data_ptr_t output, idx_t output_size) { - D_ASSERT(output_size == FromBase64Size(str)); - auto input_data = const_data_ptr_cast(str.GetData()); - auto input_size = str.GetSize(); - if (input_size == 0) { - return; - } - idx_t out_idx = 0; - idx_t i = 0; - for (i = 0; i + 4 < input_size; i += 4) { - auto combined = DecodeBase64Bytes(str, input_data, i); - output[out_idx++] = (combined >> 2 * 8) & 0xFF; - output[out_idx++] = (combined >> 1 * 8) & 0xFF; - output[out_idx++] = (combined >> 0 * 8) & 0xFF; - } - // decode the final four bytes: padding is allowed here - auto combined = DecodeBase64Bytes(str, input_data, i); - output[out_idx++] = (combined >> 2 * 8) & 0xFF; - if (out_idx < output_size) { - output[out_idx++] = (combined >> 1 * 8) & 0xFF; - } - if (out_idx < output_size) { - output[out_idx++] = (combined >> 0 * 8) & 0xFF; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/cast_helpers.cpp b/src/duckdb/src/common/types/cast_helpers.cpp deleted file mode 100644 index 5a5f11210..000000000 --- a/src/duckdb/src/common/types/cast_helpers.cpp +++ /dev/null @@ -1,313 +0,0 @@ -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" - -namespace duckdb { - -const int64_t NumericHelper::POWERS_OF_TEN[] {1, - 10, - 100, - 1000, - 10000, - 100000, - 1000000, - 10000000, - 100000000, - 1000000000, - 10000000000, - 100000000000, - 1000000000000, - 10000000000000, - 100000000000000, - 1000000000000000, - 10000000000000000, - 100000000000000000, - 1000000000000000000}; - -const double NumericHelper::DOUBLE_POWERS_OF_TEN[] {1e0, 1e1, 1e2, 1e3, 1e4, 1e5, 1e6, 1e7, 1e8, 1e9, - 1e10, 1e11, 1e12, 1e13, 1e14, 1e15, 1e16, 1e17, 1e18, 1e19, - 1e20, 1e21, 1e22, 1e23, 1e24, 1e25, 1e26, 1e27, 1e28, 1e29, - 1e30, 1e31, 1e32, 1e33, 1e34, 1e35, 1e36, 1e37, 1e38, 1e39}; - -template <> -int NumericHelper::UnsignedLength(uint8_t value) { - int length = 1; - length += value >= 10; - length += value >= 100; - return length; -} - -template <> -int NumericHelper::UnsignedLength(uint16_t value) { - 
int length = 1; - length += value >= 10; - length += value >= 100; - length += value >= 1000; - length += value >= 10000; - return length; -} - -template <> -int NumericHelper::UnsignedLength(uint32_t value) { - if (value >= 10000) { - int length = 5; - length += value >= 100000; - length += value >= 1000000; - length += value >= 10000000; - length += value >= 100000000; - length += value >= 1000000000; - return length; - } else { - int length = 1; - length += value >= 10; - length += value >= 100; - length += value >= 1000; - return length; - } -} - -template <> -int NumericHelper::UnsignedLength(uint64_t value) { - if (value >= 10000000000ULL) { - if (value >= 1000000000000000ULL) { - int length = 16; - length += value >= 10000000000000000ULL; - length += value >= 100000000000000000ULL; - length += value >= 1000000000000000000ULL; - length += value >= 10000000000000000000ULL; - return length; - } else { - int length = 11; - length += value >= 100000000000ULL; - length += value >= 1000000000000ULL; - length += value >= 10000000000000ULL; - length += value >= 100000000000000ULL; - return length; - } - } else { - if (value >= 100000ULL) { - int length = 6; - length += value >= 1000000ULL; - length += value >= 10000000ULL; - length += value >= 100000000ULL; - length += value >= 1000000000ULL; - return length; - } else { - int length = 1; - length += value >= 10ULL; - length += value >= 100ULL; - length += value >= 1000ULL; - length += value >= 10000ULL; - return length; - } - } -} - -template <> -int NumericHelper::UnsignedLength(hugeint_t value) { - D_ASSERT(value.upper >= 0); - if (value.upper == 0) { - return UnsignedLength(value.lower); - } - // search the length using the POWERS_OF_TEN array - // the length has to be between [17] and [38], because the hugeint is bigger than 2^63 - // we use the same approach as above, but split a bit more because comparisons for hugeints are more expensive - if (value >= Hugeint::POWERS_OF_TEN[27]) { - // [27..38] - if (value >= Hugeint::POWERS_OF_TEN[32]) { - if (value >= Hugeint::POWERS_OF_TEN[36]) { - int length = 37; - length += value >= Hugeint::POWERS_OF_TEN[37]; - length += value >= Hugeint::POWERS_OF_TEN[38]; - return length; - } else { - int length = 33; - length += value >= Hugeint::POWERS_OF_TEN[33]; - length += value >= Hugeint::POWERS_OF_TEN[34]; - length += value >= Hugeint::POWERS_OF_TEN[35]; - return length; - } - } else { - if (value >= Hugeint::POWERS_OF_TEN[30]) { - int length = 31; - length += value >= Hugeint::POWERS_OF_TEN[31]; - length += value >= Hugeint::POWERS_OF_TEN[32]; - return length; - } else { - int length = 28; - length += value >= Hugeint::POWERS_OF_TEN[28]; - length += value >= Hugeint::POWERS_OF_TEN[29]; - return length; - } - } - } else { - // [17..27] - if (value >= Hugeint::POWERS_OF_TEN[22]) { - // [22..27] - if (value >= Hugeint::POWERS_OF_TEN[25]) { - int length = 26; - length += value >= Hugeint::POWERS_OF_TEN[26]; - return length; - } else { - int length = 23; - length += value >= Hugeint::POWERS_OF_TEN[23]; - length += value >= Hugeint::POWERS_OF_TEN[24]; - return length; - } - } else { - // [17..22] - if (value >= Hugeint::POWERS_OF_TEN[20]) { - int length = 21; - length += value >= Hugeint::POWERS_OF_TEN[21]; - return length; - } else { - int length = 18; - length += value >= Hugeint::POWERS_OF_TEN[18]; - length += value >= Hugeint::POWERS_OF_TEN[19]; - return length; - } - } - } -} - -template <> -string_t NumericHelper::FormatSigned(hugeint_t value, Vector &vector) { - int negative = value.upper < 0; - if 
(negative) { - if (value == NumericLimits<hugeint_t>::Minimum()) { - string_t result = StringVector::AddString(vector, Hugeint::HUGEINT_MINIMUM_STRING); - return result; - } - Hugeint::NegateInPlace(value); - } - int length = UnsignedLength(value) + negative; - string_t result = StringVector::EmptyString(vector, NumericCast<idx_t>(length)); - auto dataptr = result.GetDataWriteable(); - auto endptr = dataptr + length; - if (value.upper == 0) { - // small value: format as uint64_t - endptr = NumericHelper::FormatUnsigned(value.lower, endptr); - } else { - endptr = FormatUnsigned(value, endptr); - } - if (negative) { - *--endptr = '-'; - } - D_ASSERT(endptr == dataptr); - result.Finalize(); - return result; -} - -template <> -std::string NumericHelper::ToString(hugeint_t value) { - return Hugeint::ToString(value); -} - -template <> -std::string NumericHelper::ToString(uhugeint_t value) { - return Uhugeint::ToString(value); -} - -template <> -int DecimalToString::DecimalLength(hugeint_t value, uint8_t width, uint8_t scale) { - D_ASSERT(value > NumericLimits<hugeint_t>::Minimum()); - int negative; - - if (value.upper < 0) { - Hugeint::NegateInPlace(value); - negative = 1; - } else { - negative = 0; - } - if (scale == 0) { - // scale is 0: regular number - return NumericHelper::UnsignedLength(value) + negative; - } - // the length is the maximum of: - // scale + 2 OR - // integer length + 1 - // scale + 2 happens when the number is in the range (-1, 1): - // in that case we print "0.XXX", which is the scale, plus "0." (2 chars) - // integer length + 1 happens when the number is outside of that range: - // in that case we print the integer number, plus one extra character ('.') - auto extra_numbers = width > scale ? 2 : 1; - return MaxValue(scale + extra_numbers, NumericHelper::UnsignedLength(value) + 1) + negative; -} - -template <> -string_t DecimalToString::Format(hugeint_t value, uint8_t width, uint8_t scale, Vector &vector) { - int length = DecimalLength(value, width, scale); - string_t result = StringVector::EmptyString(vector, NumericCast<idx_t>(length)); - - auto dst = result.GetDataWriteable(); - - FormatDecimal(value, width, scale, dst, NumericCast<idx_t>(length)); - - result.Finalize(); - return result; -} - -template <> -char *NumericHelper::FormatUnsigned(hugeint_t value, char *ptr) { - while (value.upper > 0) { - // while integer division is slow, hugeint division is MEGA slow - // we want to avoid doing as many divisions as possible - // for that reason we start off doing a division by a large power of ten that uint64_t can hold - // (100000000000000000) - this is the third largest - // the reason we don't use the largest is because that can result in an overflow inside the division - // function - uint64_t remainder; - value = Hugeint::DivModPositive(value, 100000000000000000ULL, remainder); - - auto startptr = ptr; - // now we format the remainder: note that we need to pad with zeros in case - // the remainder is small (i.e. 
less than 10000000000000000) - ptr = NumericHelper::FormatUnsigned(remainder, ptr); - - int format_length = UnsafeNumericCast(startptr - ptr); - // pad with zero - for (int i = format_length; i < 17; i++) { - *--ptr = '0'; - } - } - // once the value falls in the range of a uint64_t, fallback to formatting as uint64_t to avoid hugeint division - return NumericHelper::FormatUnsigned(value.lower, ptr); -} - -template <> -void DecimalToString::FormatDecimal(hugeint_t value, uint8_t width, uint8_t scale, char *dst, idx_t len) { - auto endptr = dst + len; - - int negative = value.upper < 0; - if (negative) { - Hugeint::NegateInPlace(value); - *dst = '-'; - dst++; - } - if (scale == 0) { - // with scale=0 we format the number as a regular number - NumericHelper::FormatUnsigned(value, endptr); - return; - } - - // we write two numbers: - // the numbers BEFORE the decimal (major) - // and the numbers AFTER the decimal (minor) - hugeint_t minor; - hugeint_t major = Hugeint::DivMod(value, Hugeint::POWERS_OF_TEN[scale], minor); - - // write the number after the decimal - dst = NumericHelper::FormatUnsigned(minor, endptr); - // (optionally) pad with zeros and add the decimal point - while (dst > (endptr - scale)) { - *--dst = '0'; - } - *--dst = '.'; - // now write the part before the decimal - D_ASSERT(width > scale || major == 0); - if (width > scale) { - dst = NumericHelper::FormatUnsigned(major, dst); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_allocator.cpp b/src/duckdb/src/common/types/column/column_data_allocator.cpp deleted file mode 100644 index 66a1e612f..000000000 --- a/src/duckdb/src/common/types/column/column_data_allocator.cpp +++ /dev/null @@ -1,297 +0,0 @@ -#include "duckdb/common/types/column/column_data_allocator.hpp" - -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/types/column/column_data_collection_segment.hpp" -#include "duckdb/storage/buffer/block_handle.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -ColumnDataAllocator::ColumnDataAllocator(Allocator &allocator) : type(ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - alloc.allocator = &allocator; -} - -ColumnDataAllocator::ColumnDataAllocator(BufferManager &buffer_manager) - : type(ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { - alloc.buffer_manager = &buffer_manager; -} - -ColumnDataAllocator::ColumnDataAllocator(ClientContext &context, ColumnDataAllocatorType allocator_type) - : type(allocator_type) { - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - alloc.buffer_manager = &BufferManager::GetBufferManager(context); - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - alloc.allocator = &Allocator::Get(context); - break; - default: - throw InternalException("Unrecognized column data allocator type"); - } -} - -ColumnDataAllocator::ColumnDataAllocator(ColumnDataAllocator &other) { - type = other.GetType(); - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - alloc.allocator = other.alloc.allocator; - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - alloc.buffer_manager = other.alloc.buffer_manager; - break; - default: - throw InternalException("Unrecognized column data allocator type"); - } -} - -ColumnDataAllocator::~ColumnDataAllocator() { - if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - return; - } 
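// note on the cleanup below: buffer-managed blocks are not freed here directly;
// marking them destroy-on-unpin returns their memory to the buffer pool, and when
// the amount of data released exceeds the pool's bulk-deallocation flush threshold,
// the allocator additionally flushes so the memory is handed back to the OS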
- for (auto &block : blocks) { - block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); - } - const auto data_size = SizeInBytes(); - blocks.clear(); - if (Allocator::SupportsFlush() && - data_size > alloc.buffer_manager->GetBufferPool().GetAllocatorBulkDeallocationFlushThreshold()) { - Allocator::FlushAll(); - } -} - -BufferHandle ColumnDataAllocator::Pin(uint32_t block_id) { - D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); - shared_ptr handle; - if (shared) { - // we only need to grab the lock when accessing the vector, because vector access is not thread-safe: - // the vector can be resized by another thread while we try to access it - lock_guard guard(lock); - handle = blocks[block_id].handle; - } else { - handle = blocks[block_id].handle; - } - return alloc.buffer_manager->Pin(handle); -} - -BufferHandle ColumnDataAllocator::AllocateBlock(idx_t size) { - D_ASSERT(type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || type == ColumnDataAllocatorType::HYBRID); - auto max_size = MaxValue(size, GetBufferManager().GetBlockSize()); - BlockMetaData data; - data.size = 0; - data.capacity = NumericCast(max_size); - auto pin = alloc.buffer_manager->Allocate(MemoryTag::COLUMN_DATA, max_size, false); - data.handle = pin.GetBlockHandle(); - blocks.push_back(std::move(data)); - if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits - blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); - } - allocated_size += max_size; - return pin; -} - -void ColumnDataAllocator::AllocateEmptyBlock(idx_t size) { - auto allocation_amount = MaxValue(NextPowerOfTwo(size), 4096); - if (!blocks.empty()) { - idx_t last_capacity = blocks.back().capacity; - auto next_capacity = MinValue(last_capacity * 2, last_capacity + Storage::DEFAULT_BLOCK_SIZE); - allocation_amount = MaxValue(next_capacity, allocation_amount); - } - D_ASSERT(type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); - BlockMetaData data; - data.size = 0; - data.capacity = NumericCast(allocation_amount); - data.handle = nullptr; - blocks.push_back(std::move(data)); - allocated_size += allocation_amount; -} - -void ColumnDataAllocator::AssignPointer(uint32_t &block_id, uint32_t &offset, data_ptr_t pointer) { - auto pointer_value = uintptr_t(pointer); - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - block_id = uint32_t(pointer_value); - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - block_id = uint32_t(pointer_value & 0xFFFFFFFF); - offset = uint32_t(pointer_value >> 32); - } else { - throw InternalException("ColumnDataCollection: Architecture not supported!?"); - } -} - -void ColumnDataAllocator::AllocateBuffer(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - D_ASSERT(allocated_data.empty()); - if (blocks.empty() || blocks.back().Capacity() < size) { - auto pinned_block = AllocateBlock(size); - if (chunk_state) { - D_ASSERT(!blocks.empty()); - auto new_block_id = blocks.size() - 1; - chunk_state->handles[new_block_id] = std::move(pinned_block); - } - } - auto &block = blocks.back(); - D_ASSERT(size <= block.capacity - block.size); - block_id = NumericCast(blocks.size() - 1); - if (chunk_state && chunk_state->handles.find(block_id) == chunk_state->handles.end()) { - // not guaranteed to be pinned already by this thread (if shared allocator) - chunk_state->handles[block_id] = alloc.buffer_manager->Pin(blocks[block_id].handle); - } - offset = 
block.size; - block.size += size; -} - -void ColumnDataAllocator::AllocateMemory(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - D_ASSERT(blocks.size() == allocated_data.size()); - if (blocks.empty() || blocks.back().Capacity() < size) { - AllocateEmptyBlock(size); - auto &last_block = blocks.back(); - auto allocated = alloc.allocator->Allocate(last_block.capacity); - allocated_data.push_back(std::move(allocated)); - } - auto &block = blocks.back(); - D_ASSERT(size <= block.capacity - block.size); - AssignPointer(block_id, offset, allocated_data.back().get() + block.size); - block.size += size; -} - -void ColumnDataAllocator::AllocateData(idx_t size, uint32_t &block_id, uint32_t &offset, - ChunkManagementState *chunk_state) { - switch (type) { - case ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR: - case ColumnDataAllocatorType::HYBRID: - if (shared) { - lock_guard guard(lock); - AllocateBuffer(size, block_id, offset, chunk_state); - } else { - AllocateBuffer(size, block_id, offset, chunk_state); - } - break; - case ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR: - D_ASSERT(!shared); - AllocateMemory(size, block_id, offset, chunk_state); - break; - default: - throw InternalException("Unrecognized allocator type"); - } -} - -void ColumnDataAllocator::Initialize(ColumnDataAllocator &other) { - D_ASSERT(other.HasBlocks()); - blocks.push_back(other.blocks.back()); -} - -data_ptr_t ColumnDataAllocator::GetDataPointer(ChunkManagementState &state, uint32_t block_id, uint32_t offset) { - if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - // in-memory allocator: construct pointer from block_id and offset - if (sizeof(uintptr_t) == sizeof(uint32_t)) { - uintptr_t pointer_value = uintptr_t(block_id); - return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer - } else if (sizeof(uintptr_t) == sizeof(uint64_t)) { - uintptr_t pointer_value = (uintptr_t(offset) << 32) | uintptr_t(block_id); - return (data_ptr_t)pointer_value; // NOLINT - convert from pointer value back to pointer - } else { - throw InternalException("ColumnDataCollection: Architecture not supported!?"); - } - } - D_ASSERT(state.handles.find(block_id) != state.handles.end()); - return state.handles[block_id].Ptr() + offset; -} - -void ColumnDataAllocator::UnswizzlePointers(ChunkManagementState &state, Vector &result, idx_t v_offset, uint16_t count, - uint32_t block_id, uint32_t offset) { - D_ASSERT(result.GetType().InternalType() == PhysicalType::VARCHAR); - lock_guard guard(lock); - - auto &validity = FlatVector::Validity(result); - auto strings = FlatVector::GetData(result); - - // find first non-inlined string - auto i = NumericCast(v_offset); - const uint32_t end = NumericCast(v_offset + count); - for (; i < end; i++) { - if (!validity.RowIsValid(i)) { - continue; - } - if (!strings[i].IsInlined()) { - break; - } - } - // at least one string must be non-inlined, otherwise this function should not be called - D_ASSERT(i < end); - - auto base_ptr = char_ptr_cast(GetDataPointer(state, block_id, offset)); - if (strings[i].GetData() == base_ptr) { - // pointers are still valid - return; - } - - // pointer mismatch! 
pointers are invalid, set them correctly - for (; i < end; i++) { - if (!validity.RowIsValid(i)) { - continue; - } - if (strings[i].IsInlined()) { - continue; - } - strings[i].SetPointer(base_ptr); - base_ptr += strings[i].GetSize(); - } -} - -void ColumnDataAllocator::SetDestroyBufferUponUnpin(uint32_t block_id) { - blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); -} - -Allocator &ColumnDataAllocator::GetAllocator() { - if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - return *alloc.allocator; - } - return alloc.buffer_manager->GetBufferAllocator(); -} - -BufferManager &ColumnDataAllocator::GetBufferManager() { - if (type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR) { - throw InternalException("cannot obtain the buffer manager for in memory allocations"); - } - return *alloc.buffer_manager; -} - -void ColumnDataAllocator::InitializeChunkState(ChunkManagementState &state, ChunkMetaData &chunk) { - if (type != ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR && type != ColumnDataAllocatorType::HYBRID) { - // nothing to pin - return; - } - // release any handles that are no longer required - bool found_handle; - do { - found_handle = false; - for (auto it = state.handles.begin(); it != state.handles.end(); it++) { - if (chunk.block_ids.find(NumericCast(it->first)) != chunk.block_ids.end()) { - // still required: do not release - continue; - } - state.handles.erase(it); - found_handle = true; - break; - } - } while (found_handle); - - // grab any handles that are now required - for (auto &block_id : chunk.block_ids) { - if (state.handles.find(block_id) != state.handles.end()) { - // already pinned: don't need to do anything - continue; - } - state.handles[block_id] = Pin(block_id); - } -} - -uint32_t BlockMetaData::Capacity() { - D_ASSERT(size <= capacity); - return capacity - size; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_collection.cpp b/src/duckdb/src/common/types/column/column_data_collection.cpp deleted file mode 100644 index d6e01e5a6..000000000 --- a/src/duckdb/src/common/types/column/column_data_collection.cpp +++ /dev/null @@ -1,1242 +0,0 @@ -#include "duckdb/common/types/column/column_data_collection.hpp" - -#include "duckdb/common/printer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/column/column_data_collection_segment.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -struct ColumnDataMetaData; - -typedef void (*column_data_copy_function_t)(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, - Vector &source, idx_t offset, idx_t copy_count); - -struct ColumnDataCopyFunction { - column_data_copy_function_t function; - vector child_functions; -}; - -struct ColumnDataMetaData { - ColumnDataMetaData(ColumnDataCopyFunction ©_function, ColumnDataCollectionSegment &segment, - ColumnDataAppendState &state, ChunkMetaData &chunk_data, VectorDataIndex vector_data_index) - : copy_function(copy_function), segment(segment), state(state), chunk_data(chunk_data), - vector_data_index(vector_data_index) { - } - ColumnDataMetaData(ColumnDataCopyFunction ©_function, ColumnDataMetaData &parent, - VectorDataIndex vector_data_index) - : copy_function(copy_function), 
segment(parent.segment), state(parent.state), chunk_data(parent.chunk_data), - vector_data_index(vector_data_index) { - } - - ColumnDataCopyFunction ©_function; - ColumnDataCollectionSegment &segment; - ColumnDataAppendState &state; - ChunkMetaData &chunk_data; - VectorDataIndex vector_data_index; - idx_t child_list_size = DConstants::INVALID_INDEX; - - VectorMetaData &GetVectorMetaData() { - return segment.GetVectorData(vector_data_index); - } -}; - -//! Explicitly initialized without types -ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p) { - types.clear(); - count = 0; - this->finished_append = false; - allocator = make_shared_ptr(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(Allocator &allocator_p, vector types_p) { - Initialize(std::move(types_p)); - allocator = make_shared_ptr(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(BufferManager &buffer_manager, vector types_p) { - Initialize(std::move(types_p)); - allocator = make_shared_ptr(buffer_manager); -} - -ColumnDataCollection::ColumnDataCollection(shared_ptr allocator_p, vector types_p) { - Initialize(std::move(types_p)); - this->allocator = std::move(allocator_p); -} - -ColumnDataCollection::ColumnDataCollection(ClientContext &context, vector types_p, - ColumnDataAllocatorType type) - : ColumnDataCollection(make_shared_ptr(context, type), std::move(types_p)) { - D_ASSERT(!types.empty()); -} - -ColumnDataCollection::ColumnDataCollection(ColumnDataCollection &other) - : ColumnDataCollection(other.allocator, other.types) { - other.finished_append = true; - D_ASSERT(!types.empty()); -} - -ColumnDataCollection::~ColumnDataCollection() { -} - -void ColumnDataCollection::Initialize(vector types_p) { - this->types = std::move(types_p); - this->count = 0; - this->finished_append = false; - D_ASSERT(!types.empty()); - copy_functions.reserve(types.size()); - for (auto &type : types) { - copy_functions.push_back(GetCopyFunction(type)); - } -} - -void ColumnDataCollection::CreateSegment() { - segments.emplace_back(make_uniq(allocator, types)); -} - -Allocator &ColumnDataCollection::GetAllocator() const { - return allocator->GetAllocator(); -} - -idx_t ColumnDataCollection::SizeInBytes() const { - idx_t total_size = 0; - for (const auto &segment : segments) { - total_size += segment->SizeInBytes(); - } - return total_size; -} - -idx_t ColumnDataCollection::AllocationSize() const { - idx_t total_size = 0; - for (const auto &segment : segments) { - total_size += segment->AllocationSize(); - } - return total_size; -} - -void ColumnDataCollection::SetPartitionIndex(const idx_t index) { - D_ASSERT(!partition_index.IsValid()); - D_ASSERT(Count() == 0); - partition_index = index; - allocator->SetPartitionIndex(index); -} - -//===--------------------------------------------------------------------===// -// ColumnDataRow -//===--------------------------------------------------------------------===// -ColumnDataRow::ColumnDataRow(DataChunk &chunk_p, idx_t row_index, idx_t base_index) - : chunk(chunk_p), row_index(row_index), base_index(base_index) { -} - -Value ColumnDataRow::GetValue(idx_t column_index) const { - D_ASSERT(column_index < chunk.ColumnCount()); - D_ASSERT(row_index < chunk.size()); - return chunk.data[column_index].GetValue(row_index); -} - -idx_t ColumnDataRow::RowIndex() const { - return base_index + row_index; -} - -//===--------------------------------------------------------------------===// -// ColumnDataRowCollection 
-//===--------------------------------------------------------------------===// -ColumnDataRowCollection::ColumnDataRowCollection(const ColumnDataCollection &collection) { - if (collection.Count() == 0) { - return; - } - // read all the chunks - ColumnDataScanState temp_scan_state; - collection.InitializeScan(temp_scan_state, ColumnDataScanProperties::DISALLOW_ZERO_COPY); - while (true) { - auto chunk = make_uniq(); - collection.InitializeScanChunk(*chunk); - if (!collection.Scan(temp_scan_state, *chunk)) { - break; - } - chunks.push_back(std::move(chunk)); - } - // now create all of the column data rows - rows.reserve(collection.Count()); - idx_t base_row = 0; - for (auto &chunk : chunks) { - for (idx_t row_idx = 0; row_idx < chunk->size(); row_idx++) { - rows.emplace_back(*chunk, row_idx, base_row); - } - base_row += chunk->size(); - } -} - -ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) { - return rows[i]; -} - -const ColumnDataRow &ColumnDataRowCollection::operator[](idx_t i) const { - return rows[i]; -} - -Value ColumnDataRowCollection::GetValue(idx_t column, idx_t index) const { - return rows[index].GetValue(column); -} - -//===--------------------------------------------------------------------===// -// ColumnDataChunkIterator -//===--------------------------------------------------------------------===// -ColumnDataChunkIterationHelper ColumnDataCollection::Chunks() const { - vector column_ids; - for (idx_t i = 0; i < ColumnCount(); i++) { - column_ids.push_back(i); - } - return Chunks(column_ids); -} - -ColumnDataChunkIterationHelper ColumnDataCollection::Chunks(vector column_ids) const { - return ColumnDataChunkIterationHelper(*this, std::move(column_ids)); -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterationHelper(const ColumnDataCollection &collection_p, - vector column_ids_p) - : collection(collection_p), column_ids(std::move(column_ids_p)) { -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterator::ColumnDataChunkIterator( - const ColumnDataCollection *collection_p, vector column_ids_p) - : collection(collection_p), scan_chunk(make_shared_ptr()), row_index(0) { - if (!collection) { - return; - } - collection->InitializeScan(scan_state, std::move(column_ids_p)); - collection->InitializeScanChunk(scan_state, *scan_chunk); - collection->Scan(scan_state, *scan_chunk); -} - -void ColumnDataChunkIterationHelper::ColumnDataChunkIterator::Next() { - if (!collection) { - return; - } - if (!collection->Scan(scan_state, *scan_chunk)) { - collection = nullptr; - row_index = 0; - } else { - row_index += scan_chunk->size(); - } -} - -ColumnDataChunkIterationHelper::ColumnDataChunkIterator & -ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator++() { - Next(); - return *this; -} - -bool ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator!=(const ColumnDataChunkIterator &other) const { - return collection != other.collection || row_index != other.row_index; -} - -DataChunk &ColumnDataChunkIterationHelper::ColumnDataChunkIterator::operator*() const { - return *scan_chunk; -} - -//===--------------------------------------------------------------------===// -// ColumnDataRowIterator -//===--------------------------------------------------------------------===// -ColumnDataRowIterationHelper ColumnDataCollection::Rows() const { - return ColumnDataRowIterationHelper(*this); -} - -ColumnDataRowIterationHelper::ColumnDataRowIterationHelper(const ColumnDataCollection &collection_p) - : collection(collection_p) { -} - 
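For orientation, here is a brief usage sketch of the two iteration helpers this file defines; it assumes a standard DuckDB build, and the function name `InspectCollection` is illustrative:

```cpp
// Minimal sketch of iterating a ColumnDataCollection (illustrative, not a test).
#include "duckdb.hpp"

using namespace duckdb;

void InspectCollection(const ColumnDataCollection &collection) {
	// chunk-wise: each iteration yields a DataChunk of up to STANDARD_VECTOR_SIZE rows
	for (auto &chunk : collection.Chunks()) {
		(void)chunk.size(); // process a vectorized batch here
	}
	// row-wise: convenient for point lookups, but materializes one Value per access
	for (const auto &row : collection.Rows()) {
		auto value = row.GetValue(0); // Value of column 0 at this row
		(void)value;
	}
}
```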
-ColumnDataRowIterationHelper::ColumnDataRowIterator::ColumnDataRowIterator(const ColumnDataCollection *collection_p) - : collection(collection_p), scan_chunk(make_shared_ptr()), current_row(*scan_chunk, 0, 0) { - if (!collection) { - return; - } - collection->InitializeScan(scan_state); - collection->InitializeScanChunk(*scan_chunk); - collection->Scan(scan_state, *scan_chunk); -} - -void ColumnDataRowIterationHelper::ColumnDataRowIterator::Next() { - if (!collection) { - return; - } - current_row.row_index++; - if (current_row.row_index >= scan_chunk->size()) { - current_row.base_index += scan_chunk->size(); - current_row.row_index = 0; - if (!collection->Scan(scan_state, *scan_chunk)) { - // exhausted collection: move iterator to nop state - current_row.base_index = 0; - collection = nullptr; - } - } -} - -ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::begin() { // NOLINT - return ColumnDataRowIterationHelper::ColumnDataRowIterator(collection.Count() == 0 ? nullptr : &collection); -} -ColumnDataRowIterationHelper::ColumnDataRowIterator ColumnDataRowIterationHelper::end() { // NOLINT - return ColumnDataRowIterationHelper::ColumnDataRowIterator(nullptr); -} - -ColumnDataRowIterationHelper::ColumnDataRowIterator &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator++() { - Next(); - return *this; -} - -bool ColumnDataRowIterationHelper::ColumnDataRowIterator::operator!=(const ColumnDataRowIterator &other) const { - return collection != other.collection || current_row.row_index != other.current_row.row_index || - current_row.base_index != other.current_row.base_index; -} - -const ColumnDataRow &ColumnDataRowIterationHelper::ColumnDataRowIterator::operator*() const { - return current_row; -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -void ColumnDataCollection::InitializeAppend(ColumnDataAppendState &state) { - D_ASSERT(!finished_append); - state.current_chunk_state.handles.clear(); - state.vector_data.resize(types.size()); - if (segments.empty()) { - CreateSegment(); - } - auto &segment = *segments.back(); - if (segment.chunk_data.empty()) { - segment.AllocateNewChunk(); - } - segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); -} - -void ColumnDataCopyValidity(const UnifiedVectorFormat &source_data, validity_t *target, idx_t source_offset, - idx_t target_offset, idx_t copy_count) { - ValidityMask validity(target, STANDARD_VECTOR_SIZE); - if (target_offset == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - // FIXME: we can do something more optimized here using bitshifts & bitwise ors - if (!source_data.validity.AllValid()) { - for (idx_t i = 0; i < copy_count; i++) { - auto idx = source_data.sel->get_index(source_offset + i); - if (!source_data.validity.RowIsValid(idx)) { - validity.SetInvalid(target_offset + i); - } - } - } -} - -template -struct BaseValueCopy { - static idx_t TypeSize() { - return sizeof(T); - } - - template - static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, - idx_t source_idx) { - auto result_data = (T *)target; - auto source_data = (T *)source; - result_data[target_idx] = OP::Operation(meta_data, source_data[source_idx]); - } -}; - -template -struct 
StandardValueCopy : public BaseValueCopy { - static T Operation(ColumnDataMetaData &, T input) { - return input; - } -}; - -struct StringValueCopy : public BaseValueCopy { - static string_t Operation(ColumnDataMetaData &meta_data, string_t input) { - return input.IsInlined() ? input : meta_data.segment.heap->AddBlob(input); - } -}; - -struct ConstListValueCopy : public BaseValueCopy { - using TYPE = list_entry_t; - - static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { - input.offset = meta_data.child_list_size; - return input; - } -}; - -struct ListValueCopy : public BaseValueCopy { - using TYPE = list_entry_t; - - static TYPE Operation(ColumnDataMetaData &meta_data, TYPE input) { - input.offset = meta_data.child_list_size; - meta_data.child_list_size += input.length; - return input; - } -}; - -struct StructValueCopy { - static idx_t TypeSize() { - return 0; - } - - template - static void Assign(ColumnDataMetaData &meta_data, data_ptr_t target, data_ptr_t source, idx_t target_idx, - idx_t source_idx) { - } -}; - -template -static void TemplatedColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, - Vector &source, idx_t offset, idx_t count) { - auto &segment = meta_data.segment; - auto &append_state = meta_data.state; - - auto current_index = meta_data.vector_data_index; - idx_t remaining = count; - while (remaining > 0) { - auto ¤t_segment = segment.GetVectorData(current_index); - idx_t append_count = MinValue(STANDARD_VECTOR_SIZE - current_segment.count, remaining); - - auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, - current_segment.offset); - auto validity_data = ColumnDataCollectionSegment::GetValidityPointerForWriting(base_ptr, OP::TypeSize()); - - ValidityMask result_validity(validity_data, STANDARD_VECTOR_SIZE); - if (current_segment.count == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - result_validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_data.sel->get_index(offset + i); - if (source_data.validity.RowIsValid(source_idx)) { - OP::template Assign(meta_data, base_ptr, source_data.data, current_segment.count + i, source_idx); - } else { - result_validity.SetInvalid(current_segment.count + i); - } - } - current_segment.count += append_count; - offset += append_count; - remaining -= append_count; - if (remaining > 0) { - // need to append more, check if we need to allocate a new vector or not - if (!current_segment.next_data.IsValid()) { - segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); - } - D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); - current_index = segment.GetVectorData(current_index).next_data; - } - } -} - -template -static void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - TemplatedColumnDataCopy>(meta_data, source_data, source, offset, copy_count); -} - -template <> -void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - - const auto &allocator_type = meta_data.segment.allocator->GetType(); - if (allocator_type == ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR || - allocator_type == ColumnDataAllocatorType::HYBRID) { - // strings cannot be spilled to disk 
- use StringHeap - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - return; - } - D_ASSERT(allocator_type == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); - - auto &segment = meta_data.segment; - auto &append_state = meta_data.state; - - VectorDataIndex child_index; - if (meta_data.GetVectorMetaData().child_index.IsValid()) { - // find the last child index - child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); - auto next_child_index = segment.GetVectorData(child_index).next_data; - while (next_child_index.IsValid()) { - child_index = next_child_index; - next_child_index = segment.GetVectorData(child_index).next_data; - } - } - - auto current_index = meta_data.vector_data_index; - idx_t remaining = copy_count; - auto block_size = meta_data.segment.allocator->GetBufferManager().GetBlockSize(); - while (remaining > 0) { - // how many values fit in the current string vector - idx_t vector_remaining = - MinValue(STANDARD_VECTOR_SIZE - segment.GetVectorData(current_index).count, remaining); - - // 'append_count' is less if we cannot fit that amount of non-inlined strings on one buffer-managed block - idx_t append_count; - idx_t heap_size = 0; - const auto source_entries = UnifiedVectorFormat::GetData(source_data); - for (append_count = 0; append_count < vector_remaining; append_count++) { - auto source_idx = source_data.sel->get_index(offset + append_count); - if (!source_data.validity.RowIsValid(source_idx)) { - continue; - } - const auto &entry = source_entries[source_idx]; - if (entry.IsInlined()) { - continue; - } - if (heap_size + entry.GetSize() > block_size) { - break; - } - heap_size += entry.GetSize(); - } - - if (vector_remaining != 0 && append_count == 0) { - // The string exceeds Storage::DEFAULT_BLOCK_SIZE, so we allocate one block at a time for long strings. 
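// (illustrative arithmetic) with a 256 KiB block size, a first remaining value that
// is a 300 KiB string makes heap_size + entry.GetSize() exceed block_size on the
// first iteration, so the loop above exits with append_count == 0; this branch then
// admits exactly that one string, and the heap allocation below becomes a single
// dedicated ~300 KiB block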
- auto source_idx = source_data.sel->get_index(offset + append_count); - D_ASSERT(source_data.validity.RowIsValid(source_idx)); - D_ASSERT(!source_entries[source_idx].IsInlined()); - D_ASSERT(source_entries[source_idx].GetSize() > block_size); - heap_size += source_entries[source_idx].GetSize(); - append_count++; - } - - // allocate string heap for the next 'append_count' strings - data_ptr_t heap_ptr = nullptr; - if (heap_size != 0) { - child_index = segment.AllocateStringHeap(heap_size, meta_data.chunk_data, append_state, child_index); - if (!meta_data.GetVectorMetaData().child_index.IsValid()) { - meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); - } - auto &child_segment = segment.GetVectorData(child_index); - heap_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, child_segment.block_id, - child_segment.offset); - } - - auto ¤t_segment = segment.GetVectorData(current_index); - auto base_ptr = segment.allocator->GetDataPointer(append_state.current_chunk_state, current_segment.block_id, - current_segment.offset); - auto validity_data = ColumnDataCollectionSegment::GetValidityPointerForWriting(base_ptr, sizeof(string_t)); - ValidityMask target_validity(validity_data, STANDARD_VECTOR_SIZE); - if (current_segment.count == 0) { - // first time appending to this vector - // all data here is still uninitialized - // initialize the validity mask to set all to valid - target_validity.SetAllValid(STANDARD_VECTOR_SIZE); - } - - auto target_entries = reinterpret_cast(base_ptr); - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_data.sel->get_index(offset + i); - auto target_idx = current_segment.count + i; - if (!source_data.validity.RowIsValid(source_idx)) { - target_validity.SetInvalid(target_idx); - continue; - } - const auto &source_entry = source_entries[source_idx]; - auto &target_entry = target_entries[target_idx]; - if (source_entry.IsInlined()) { - target_entry = source_entry; - } else { - D_ASSERT(heap_ptr != nullptr); - memcpy(heap_ptr, source_entry.GetData(), source_entry.GetSize()); - target_entry = - string_t(const_char_ptr_cast(heap_ptr), UnsafeNumericCast(source_entry.GetSize())); - heap_ptr += source_entry.GetSize(); - } - } - - if (heap_size != 0) { - current_segment.swizzle_data.emplace_back(child_index, current_segment.count, append_count); - } - - current_segment.count += append_count; - offset += append_count; - remaining -= append_count; - - if (vector_remaining - append_count == 0) { - // need to append more, check if we need to allocate a new vector or not - if (!current_segment.next_data.IsValid()) { - segment.AllocateVector(source.GetType(), meta_data.chunk_data, append_state, current_index); - } - D_ASSERT(segment.GetVectorData(current_index).next_data.IsValid()); - current_index = segment.GetVectorData(current_index).next_data; - } - } -} - -template <> -void ColumnDataCopy(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - - auto &segment = meta_data.segment; - - auto &child_vector = ListVector::GetEntry(source); - auto &child_type = child_vector.GetType(); - - if (!meta_data.GetVectorMetaData().child_index.IsValid()) { - auto child_index = segment.AllocateVector(child_type, meta_data.chunk_data, meta_data.state); - meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); - } - - auto &child_function = meta_data.copy_function.child_functions[0]; - auto child_index = 
segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); - - // figure out the current list size by traversing the set of child entries - idx_t current_list_size = 0; - auto current_child_index = child_index; - while (current_child_index.IsValid()) { - auto &child_vdata = segment.GetVectorData(current_child_index); - current_list_size += child_vdata.count; - current_child_index = child_vdata.next_data; - } - - // set the child vector - UnifiedVectorFormat child_vector_data; - ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); - auto info = ListVector::GetConsecutiveChildListInfo(source, offset, copy_count); - - if (info.needs_slicing) { - SelectionVector sel(info.child_list_info.length); - ListVector::GetConsecutiveChildSelVector(source, sel, offset, copy_count); - - auto sliced_child_vector = Vector(child_vector, sel, info.child_list_info.length); - sliced_child_vector.Flatten(info.child_list_info.length); - info.child_list_info.offset = 0; - - sliced_child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); - child_function.function(child_meta_data, child_vector_data, sliced_child_vector, info.child_list_info.offset, - info.child_list_info.length); - - } else { - child_vector.ToUnifiedFormat(info.child_list_info.length, child_vector_data); - child_function.function(child_meta_data, child_vector_data, child_vector, info.child_list_info.offset, - info.child_list_info.length); - } - - // now copy the list entries - meta_data.child_list_size = current_list_size; - if (info.is_constant) { - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - } else { - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - } -} - -void ColumnDataCopyStruct(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - auto &segment = meta_data.segment; - - // copy the NULL values for the main struct vector - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - - auto &child_types = StructType::GetChildTypes(source.GetType()); - // now copy all the child vectors - D_ASSERT(meta_data.GetVectorMetaData().child_index.IsValid()); - auto &child_vectors = StructVector::GetEntries(source); - for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { - auto &child_function = meta_data.copy_function.child_functions[child_idx]; - auto child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index, child_idx); - ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); - - UnifiedVectorFormat child_data; - child_vectors[child_idx]->ToUnifiedFormat(copy_count, child_data); - - child_function.function(child_meta_data, child_data, *child_vectors[child_idx], offset, copy_count); - } -} - -void ColumnDataCopyArray(ColumnDataMetaData &meta_data, const UnifiedVectorFormat &source_data, Vector &source, - idx_t offset, idx_t copy_count) { - - auto &segment = meta_data.segment; - - // copy the NULL values for the main array vector (the same as for a struct vector) - TemplatedColumnDataCopy(meta_data, source_data, source, offset, copy_count); - - auto &child_vector = ArrayVector::GetEntry(source); - auto &child_type = child_vector.GetType(); - auto array_size = ArrayType::GetSize(source.GetType()); - - if (!meta_data.GetVectorMetaData().child_index.IsValid()) { - auto child_index = segment.AllocateVector(child_type, meta_data.chunk_data, meta_data.state); - 
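// first copy into this array column: allocate the backing child vector once and
// register it, so later appends to the same column reuse the child index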
meta_data.GetVectorMetaData().child_index = meta_data.segment.AddChildIndex(child_index); - } - - auto &child_function = meta_data.copy_function.child_functions[0]; - auto child_index = segment.GetChildIndex(meta_data.GetVectorMetaData().child_index); - - auto current_child_index = child_index; - while (current_child_index.IsValid()) { - auto &child_vdata = segment.GetVectorData(current_child_index); - current_child_index = child_vdata.next_data; - } - - UnifiedVectorFormat child_vector_data; - ColumnDataMetaData child_meta_data(child_function, meta_data, child_index); - child_vector.ToUnifiedFormat(copy_count * array_size, child_vector_data); - - // Broadcast and sync the validity of the array vector to the child vector - - if (source_data.validity.IsMaskSet()) { - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = source_data.sel->get_index(offset + i); - if (!source_data.validity.RowIsValid(source_idx)) { - for (idx_t j = 0; j < array_size; j++) { - child_vector_data.validity.SetInvalid(source_idx * array_size + j); - } - } - } - } - - auto is_constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // If the array is constant, we need to copy the child vector n times - if (is_constant) { - for (idx_t i = 0; i < copy_count; i++) { - child_function.function(child_meta_data, child_vector_data, child_vector, 0, array_size); - } - } else { - child_function.function(child_meta_data, child_vector_data, child_vector, offset * array_size, - copy_count * array_size); - } -} - -ColumnDataCopyFunction ColumnDataCollection::GetCopyFunction(const LogicalType &type) { - ColumnDataCopyFunction result; - column_data_copy_function_t function; - switch (type.InternalType()) { - case PhysicalType::BOOL: - function = ColumnDataCopy; - break; - case PhysicalType::INT8: - function = ColumnDataCopy; - break; - case PhysicalType::INT16: - function = ColumnDataCopy; - break; - case PhysicalType::INT32: - function = ColumnDataCopy; - break; - case PhysicalType::INT64: - function = ColumnDataCopy; - break; - case PhysicalType::INT128: - function = ColumnDataCopy; - break; - case PhysicalType::UINT8: - function = ColumnDataCopy; - break; - case PhysicalType::UINT16: - function = ColumnDataCopy; - break; - case PhysicalType::UINT32: - function = ColumnDataCopy; - break; - case PhysicalType::UINT64: - function = ColumnDataCopy; - break; - case PhysicalType::UINT128: - function = ColumnDataCopy; - break; - case PhysicalType::FLOAT: - function = ColumnDataCopy; - break; - case PhysicalType::DOUBLE: - function = ColumnDataCopy; - break; - case PhysicalType::INTERVAL: - function = ColumnDataCopy; - break; - case PhysicalType::VARCHAR: - function = ColumnDataCopy; - break; - case PhysicalType::STRUCT: { - function = ColumnDataCopyStruct; - auto &child_types = StructType::GetChildTypes(type); - for (auto &kv : child_types) { - result.child_functions.push_back(GetCopyFunction(kv.second)); - } - break; - } - case PhysicalType::LIST: { - function = ColumnDataCopy; - auto child_function = GetCopyFunction(ListType::GetChildType(type)); - result.child_functions.push_back(child_function); - break; - } - case PhysicalType::ARRAY: { - function = ColumnDataCopyArray; - auto child_function = GetCopyFunction(ArrayType::GetChildType(type)); - result.child_functions.push_back(child_function); - break; - } - default: - throw InternalException("Unsupported type %s for ColumnDataCollection::GetCopyFunction", - EnumUtil::ToString(type.InternalType())); - } - result.function = function; - return result; -} - -static 
bool IsComplexType(const LogicalType &type) { - switch (type.InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - case PhysicalType::ARRAY: - return true; - default: - return false; - }; -} - -void ColumnDataCollection::Append(ColumnDataAppendState &state, DataChunk &input) { - D_ASSERT(!finished_append); - D_ASSERT(types == input.GetTypes()); - - auto &segment = *segments.back(); - for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { - if (IsComplexType(input.data[vector_idx].GetType())) { - input.data[vector_idx].Flatten(input.size()); - } - input.data[vector_idx].ToUnifiedFormat(input.size(), state.vector_data[vector_idx]); - } - - idx_t remaining = input.size(); - while (remaining > 0) { - auto &chunk_data = segment.chunk_data.back(); - idx_t append_amount = MinValue(remaining, STANDARD_VECTOR_SIZE - chunk_data.count); - if (append_amount > 0) { - idx_t offset = input.size() - remaining; - for (idx_t vector_idx = 0; vector_idx < types.size(); vector_idx++) { - ColumnDataMetaData meta_data(copy_functions[vector_idx], segment, state, chunk_data, - chunk_data.vector_data[vector_idx]); - copy_functions[vector_idx].function(meta_data, state.vector_data[vector_idx], input.data[vector_idx], - offset, append_amount); - } - chunk_data.count += append_amount; - } - remaining -= append_amount; - if (remaining > 0) { - // more to do - // allocate a new chunk - segment.AllocateNewChunk(); - segment.InitializeChunkState(segment.chunk_data.size() - 1, state.current_chunk_state); - } - } - segment.count += input.size(); - count += input.size(); -} - -void ColumnDataCollection::Append(DataChunk &input) { - ColumnDataAppendState state; - InitializeAppend(state); - Append(state, input); -} - -//===--------------------------------------------------------------------===// -// Scan -//===--------------------------------------------------------------------===// -void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, ColumnDataScanProperties properties) const { - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - InitializeScan(state, std::move(column_ids), properties); -} - -void ColumnDataCollection::InitializeScan(ColumnDataScanState &state, vector column_ids, - ColumnDataScanProperties properties) const { - state.chunk_index = 0; - state.segment_index = 0; - state.current_row_index = 0; - state.next_row_index = 0; - state.current_chunk_state.handles.clear(); - state.properties = properties; - state.column_ids = std::move(column_ids); -} - -void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, - ColumnDataScanProperties properties) const { - InitializeScan(state.scan_state, properties); -} - -void ColumnDataCollection::InitializeScan(ColumnDataParallelScanState &state, vector column_ids, - ColumnDataScanProperties properties) const { - InitializeScan(state.scan_state, std::move(column_ids), properties); -} - -bool ColumnDataCollection::Scan(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, - DataChunk &result) const { - result.Reset(); - - idx_t chunk_index; - idx_t segment_index; - idx_t row_index; - { - lock_guard l(state.lock); - if (!NextScanIndex(state.scan_state, chunk_index, segment_index, row_index)) { - return false; - } - } - ScanAtIndex(state, lstate, result, chunk_index, segment_index, row_index); - return true; -} - -void ColumnDataCollection::InitializeScanChunk(DataChunk &chunk) const { - 
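// the scan chunk is initialized with the collection's own allocator and its full
// type list; the overload below instead builds a projected type list from
// state.column_ids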
chunk.Initialize(allocator->GetAllocator(), types); -} - -void ColumnDataCollection::InitializeScanChunk(ColumnDataScanState &state, DataChunk &chunk) const { - D_ASSERT(!state.column_ids.empty()); - vector chunk_types; - chunk_types.reserve(state.column_ids.size()); - for (idx_t i = 0; i < state.column_ids.size(); i++) { - auto column_idx = state.column_ids[i]; - D_ASSERT(column_idx < types.size()); - chunk_types.push_back(types[column_idx]); - } - chunk.Initialize(allocator->GetAllocator(), chunk_types); -} - -bool ColumnDataCollection::NextScanIndex(ColumnDataScanState &state, idx_t &chunk_index, idx_t &segment_index, - idx_t &row_index) const { - row_index = state.current_row_index = state.next_row_index; - // check if we still have collections to scan - if (state.segment_index >= segments.size()) { - // no more data left in the scan - return false; - } - // check within the current collection if we still have chunks to scan - while (state.chunk_index >= segments[state.segment_index]->chunk_data.size()) { - // exhausted all chunks for this internal data structure: move to the next one - state.chunk_index = 0; - state.segment_index++; - state.current_chunk_state.handles.clear(); - if (state.segment_index >= segments.size()) { - return false; - } - } - state.next_row_index += segments[state.segment_index]->chunk_data[state.chunk_index].count; - segment_index = state.segment_index; - chunk_index = state.chunk_index++; - return true; -} - -bool ColumnDataCollection::PrevScanIndex(ColumnDataScanState &state, idx_t &chunk_index, idx_t &segment_index, - idx_t &row_index) const { - // check within the current segment if we still have chunks to scan - // Note that state.chunk_index is 1-indexed, with 0 as undefined. - while (state.chunk_index <= 1) { - if (!state.segment_index) { - return false; - } - - --state.segment_index; - state.chunk_index = segments[state.segment_index]->chunk_data.size() + 1; - state.current_chunk_state.handles.clear(); - } - - --state.chunk_index; - segment_index = state.segment_index; - chunk_index = state.chunk_index - 1; - state.next_row_index = state.current_row_index; - state.current_row_index -= segments[state.segment_index]->chunk_data[chunk_index].count; - row_index = state.current_row_index; - return true; -} - -void ColumnDataCollection::ScanAtIndex(ColumnDataParallelScanState &state, ColumnDataLocalScanState &lstate, - DataChunk &result, idx_t chunk_index, idx_t segment_index, - idx_t row_index) const { - if (segment_index != lstate.current_segment_index) { - lstate.current_chunk_state.handles.clear(); - lstate.current_segment_index = segment_index; - } - auto &segment = *segments[segment_index]; - lstate.current_chunk_state.properties = state.scan_state.properties; - segment.ReadChunk(chunk_index, lstate.current_chunk_state, result, state.scan_state.column_ids); - lstate.current_row_index = row_index; - result.Verify(); -} - -bool ColumnDataCollection::Scan(ColumnDataScanState &state, DataChunk &result) const { - result.Reset(); - - idx_t chunk_index; - idx_t segment_index; - idx_t row_index; - if (!NextScanIndex(state, chunk_index, segment_index, row_index)) { - return false; - } - - // found a chunk to scan -> scan it - auto &segment = *segments[segment_index]; - state.current_chunk_state.properties = state.properties; - segment.ReadChunk(chunk_index, state.current_chunk_state, result, state.column_ids); - result.Verify(); - return true; -} - -bool ColumnDataCollection::Seek(idx_t seek_idx, ColumnDataScanState &state, DataChunk &result) const { - // 
Idempotency: Don't change anything if the row is already in range - if (state.current_row_index <= seek_idx && seek_idx < state.next_row_index) { - return true; - } - - result.Reset(); - - // Linear scan for now. We could use a current_row_index => chunk map at some point - // but most use cases should be pretty local - idx_t chunk_index; - idx_t segment_index; - idx_t row_index; - while (seek_idx < state.current_row_index) { - if (!PrevScanIndex(state, chunk_index, segment_index, row_index)) { - return false; - } - } - while (state.next_row_index <= seek_idx) { - if (!NextScanIndex(state, chunk_index, segment_index, row_index)) { - return false; - } - } - - // found a chunk to scan -> scan it - auto &segment = *segments[segment_index]; - state.current_chunk_state.properties = state.properties; - segment.ReadChunk(chunk_index, state.current_chunk_state, result, state.column_ids); - result.Verify(); - return true; -} - -ColumnDataRowCollection ColumnDataCollection::GetRows() const { - return ColumnDataRowCollection(*this); -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void ColumnDataCollection::Combine(ColumnDataCollection &other) { - if (other.count == 0) { - return; - } - if (types != other.types) { - throw InternalException("Attempting to combine ColumnDataCollections with mismatching types"); - } - this->count += other.count; - this->segments.reserve(segments.size() + other.segments.size()); - for (auto &other_seg : other.segments) { - segments.push_back(std::move(other_seg)); - } - other.Reset(); - Verify(); -} - -//===--------------------------------------------------------------------===// -// Fetch -//===--------------------------------------------------------------------===// -idx_t ColumnDataCollection::ChunkCount() const { - idx_t chunk_count = 0; - for (auto &segment : segments) { - chunk_count += segment->ChunkCount(); - } - return chunk_count; -} - -void ColumnDataCollection::FetchChunk(idx_t chunk_idx, DataChunk &result) const { - D_ASSERT(chunk_idx < ChunkCount()); - for (auto &segment : segments) { - if (chunk_idx >= segment->ChunkCount()) { - chunk_idx -= segment->ChunkCount(); - } else { - segment->FetchChunk(chunk_idx, result); - return; - } - } - throw InternalException("Failed to find chunk in ColumnDataCollection"); -} - -//===--------------------------------------------------------------------===// -// Helpers -//===--------------------------------------------------------------------===// -void ColumnDataCollection::Verify() { -#ifdef DEBUG - // verify counts - idx_t total_segment_count = 0; - for (auto &segment : segments) { - segment->Verify(); - total_segment_count += segment->count; - } - D_ASSERT(total_segment_count == this->count); -#endif -} - -// LCOV_EXCL_START -string ColumnDataCollection::ToString() const { - DataChunk chunk; - InitializeScanChunk(chunk); - - ColumnDataScanState scan_state; - InitializeScan(scan_state); - - string result = StringUtil::Format("ColumnDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); - idx_t chunk_idx = 0; - idx_t row_count = 0; - while (Scan(scan_state, chunk)) { - result += - StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + - chunk.ToString(); - chunk_idx++; - row_count += chunk.size(); - } - - return result; -} -// LCOV_EXCL_STOP - -void ColumnDataCollection::Print() const { - Printer::Print(ToString()); -} - -void 
ColumnDataCollection::Reset() { - count = 0; - segments.clear(); - - // Refreshes the ColumnDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared_ptr(*allocator); -} - -struct ValueResultEquals { - bool operator()(const Value &a, const Value &b) const { - return Value::DefaultValuesAreEqual(a, b); - } -}; - -bool ColumnDataCollection::ResultEquals(const ColumnDataCollection &left, const ColumnDataCollection &right, - string &error_message, bool ordered) { - if (left.ColumnCount() != right.ColumnCount()) { - error_message = "Column count mismatch"; - return false; - } - if (left.Count() != right.Count()) { - error_message = "Row count mismatch"; - return false; - } - auto left_rows = left.GetRows(); - auto right_rows = right.GetRows(); - for (idx_t r = 0; r < left.Count(); r++) { - for (idx_t c = 0; c < left.ColumnCount(); c++) { - auto lvalue = left_rows.GetValue(c, r); - auto rvalue = right_rows.GetValue(c, r); - - if (!Value::DefaultValuesAreEqual(lvalue, rvalue)) { - error_message = - StringUtil::Format("%s <> %s (row: %lld, col: %lld)\n", lvalue.ToString(), rvalue.ToString(), r, c); - break; - } - } - if (!error_message.empty()) { - if (ordered) { - return false; - } else { - break; - } - } - } - if (!error_message.empty()) { - // do an unordered comparison - bool found_all = true; - for (idx_t c = 0; c < left.ColumnCount(); c++) { - std::unordered_multiset lvalues; - for (idx_t r = 0; r < left.Count(); r++) { - auto lvalue = left_rows.GetValue(c, r); - lvalues.insert(lvalue); - } - for (idx_t r = 0; r < right.Count(); r++) { - auto rvalue = right_rows.GetValue(c, r); - auto entry = lvalues.find(rvalue); - if (entry == lvalues.end()) { - found_all = false; - break; - } - lvalues.erase(entry); - } - if (!found_all) { - break; - } - } - if (!found_all) { - return false; - } - error_message = string(); - } - return true; -} - -vector> ColumnDataCollection::GetHeapReferences() { - vector> result(segments.size(), nullptr); - for (idx_t segment_idx = 0; segment_idx < segments.size(); segment_idx++) { - result[segment_idx] = segments[segment_idx]->heap; - } - return result; -} - -ColumnDataAllocatorType ColumnDataCollection::GetAllocatorType() const { - return allocator->GetType(); -} - -const vector> &ColumnDataCollection::GetSegments() const { - return segments; -} - -void ColumnDataCollection::Serialize(Serializer &serializer) const { - vector> values; - values.resize(ColumnCount()); - for (auto &chunk : Chunks()) { - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - for (idx_t r = 0; r < chunk.size(); r++) { - values[c].push_back(chunk.GetValue(c, r)); - } - } - } - serializer.WriteProperty(100, "types", types); - serializer.WriteProperty(101, "values", values); -} - -unique_ptr ColumnDataCollection::Deserialize(Deserializer &deserializer) { - auto types = deserializer.ReadProperty>(100, "types"); - auto values = deserializer.ReadProperty>>(101, "values"); - - auto collection = make_uniq(Allocator::DefaultAllocator(), types); - if (values.empty()) { - return collection; - } - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), types); - - for (idx_t r = 0; r < values[0].size(); r++) { - for (idx_t c = 0; c < types.size(); c++) { - chunk.SetValue(c, chunk.size(), values[c][r]); - } - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - } - if (chunk.size() > 0) { - collection->Append(chunk); - } - return collection; -} - -} // namespace duckdb diff 
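// [editor's sketch] The value-based Serialize/Deserialize pair above round-trips a collection
// through any Serializer. A minimal sketch of that round trip, reusing the MemoryStream and
// BinarySerializer/BinaryDeserializer pattern this same diff uses in DataChunk::Verify below:
//
//	MemoryStream stream;
//	BinarySerializer serializer(stream);
//	serializer.Begin();
//	collection.Serialize(serializer);
//	serializer.End();
//
//	stream.Rewind();
//	BinaryDeserializer deserializer(stream);
//	deserializer.Begin();
//	auto copy = ColumnDataCollection::Deserialize(deserializer); // rebuilt with the default allocator
//	deserializer.End();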
--git a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp b/src/duckdb/src/common/types/column/column_data_collection_segment.cpp
deleted file mode 100644
index 1ec0f6f45..000000000
--- a/src/duckdb/src/common/types/column/column_data_collection_segment.cpp
+++ /dev/null
@@ -1,309 +0,0 @@
-#include "duckdb/common/types/column/column_data_collection_segment.hpp"
-
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-
-namespace duckdb {
-
-ColumnDataCollectionSegment::ColumnDataCollectionSegment(shared_ptr<ColumnDataAllocator> allocator_p,
-                                                         vector<LogicalType> types_p)
-    : allocator(std::move(allocator_p)), types(std::move(types_p)), count(0),
-      heap(make_shared_ptr<StringHeap>(allocator->GetAllocator())) {
-}
-
-idx_t ColumnDataCollectionSegment::GetDataSize(idx_t type_size) {
-	return AlignValue(type_size * STANDARD_VECTOR_SIZE);
-}
-
-validity_t *ColumnDataCollectionSegment::GetValidityPointerForWriting(data_ptr_t base_ptr, idx_t type_size) {
-	return reinterpret_cast<validity_t *>(base_ptr + GetDataSize(type_size));
-}
-
-validity_t *ColumnDataCollectionSegment::GetValidityPointer(data_ptr_t base_ptr, idx_t type_size, idx_t count) {
-	auto validity_mask = reinterpret_cast<validity_t *>(base_ptr + GetDataSize(type_size));
-
-	// Optimized check to see if all entries are valid
-	for (idx_t i = 0; i < (count / ValidityMask::BITS_PER_VALUE); i++) {
-		if (!ValidityMask::AllValid(validity_mask[i])) {
-			return validity_mask;
-		}
-	}
-
-	if ((count % ValidityMask::BITS_PER_VALUE) != 0) {
-		// Create a mask with the lower `bits_to_check` bits set to 1
-		validity_t mask = (1ULL << (count % ValidityMask::BITS_PER_VALUE)) - 1;
-		if ((validity_mask[(count / ValidityMask::BITS_PER_VALUE)] & mask) != mask) {
-			return validity_mask;
-		}
-	}
-	// All entries are valid, no need to initialize the validity mask
-	return nullptr;
-}
-
-VectorDataIndex ColumnDataCollectionSegment::AllocateVectorInternal(const LogicalType &type, ChunkMetaData &chunk_meta,
-                                                                    ChunkManagementState *chunk_state) {
-	VectorMetaData meta_data;
-	meta_data.count = 0;
-
-	auto internal_type = type.InternalType();
-	auto struct_or_array = internal_type == PhysicalType::STRUCT || internal_type == PhysicalType::ARRAY;
-	auto type_size = struct_or_array ?
0 : GetTypeIdSize(internal_type); - - allocator->AllocateData(GetDataSize(type_size) + ValidityMask::STANDARD_MASK_SIZE, meta_data.block_id, - meta_data.offset, chunk_state); - if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR || - allocator->GetType() == ColumnDataAllocatorType::HYBRID) { - chunk_meta.block_ids.insert(meta_data.block_id); - } - - auto index = vector_data.size(); - vector_data.push_back(meta_data); - return VectorDataIndex(index); -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, - ChunkManagementState *chunk_state, - VectorDataIndex prev_index) { - auto index = AllocateVectorInternal(type, chunk_meta, chunk_state); - if (prev_index.IsValid()) { - GetVectorData(prev_index).next_data = index; - } - if (type.InternalType() == PhysicalType::STRUCT) { - // initialize the struct children - auto &child_types = StructType::GetChildTypes(type); - auto base_child_index = ReserveChildren(child_types.size()); - for (idx_t child_idx = 0; child_idx < child_types.size(); child_idx++) { - VectorDataIndex prev_child_index; - if (prev_index.IsValid()) { - prev_child_index = GetChildIndex(GetVectorData(prev_index).child_index, child_idx); - } - auto child_index = AllocateVector(child_types[child_idx].second, chunk_meta, chunk_state, prev_child_index); - SetChildIndex(base_child_index, child_idx, child_index); - } - GetVectorData(index).child_index = base_child_index; - } - return index; -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateVector(const LogicalType &type, ChunkMetaData &chunk_meta, - ColumnDataAppendState &append_state, - VectorDataIndex prev_index) { - return AllocateVector(type, chunk_meta, &append_state.current_chunk_state, prev_index); -} - -VectorDataIndex ColumnDataCollectionSegment::AllocateStringHeap(idx_t size, ChunkMetaData &chunk_meta, - ColumnDataAppendState &append_state, - VectorDataIndex prev_index) { - D_ASSERT(allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR); - D_ASSERT(size != 0); - - VectorMetaData meta_data; - meta_data.count = 0; - allocator->AllocateData(AlignValue(size), meta_data.block_id, meta_data.offset, &append_state.current_chunk_state); - chunk_meta.block_ids.insert(meta_data.block_id); - - VectorDataIndex index(vector_data.size()); - vector_data.push_back(meta_data); - - if (prev_index.IsValid()) { - GetVectorData(prev_index).next_data = index; - } - - return index; -} - -void ColumnDataCollectionSegment::AllocateNewChunk() { - ChunkMetaData meta_data; - meta_data.count = 0; - meta_data.vector_data.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - auto vector_idx = AllocateVector(types[i], meta_data); - meta_data.vector_data.push_back(vector_idx); - } - chunk_data.push_back(std::move(meta_data)); -} - -void ColumnDataCollectionSegment::InitializeChunkState(idx_t chunk_index, ChunkManagementState &state) { - auto &chunk = chunk_data[chunk_index]; - allocator->InitializeChunkState(state, chunk); -} - -VectorDataIndex ColumnDataCollectionSegment::GetChildIndex(VectorChildIndex index, idx_t child_entry) { - D_ASSERT(index.IsValid()); - D_ASSERT(index.index + child_entry < child_indices.size()); - return VectorDataIndex(child_indices[index.index + child_entry]); -} - -VectorChildIndex ColumnDataCollectionSegment::AddChildIndex(VectorDataIndex index) { - auto result = child_indices.size(); - child_indices.push_back(index); - return VectorChildIndex(result); -} - -VectorChildIndex 
ColumnDataCollectionSegment::ReserveChildren(idx_t child_count) { - auto result = child_indices.size(); - for (idx_t i = 0; i < child_count; i++) { - child_indices.emplace_back(); - } - return VectorChildIndex(result); -} - -void ColumnDataCollectionSegment::SetChildIndex(VectorChildIndex base_idx, idx_t child_number, VectorDataIndex index) { - D_ASSERT(base_idx.IsValid()); - D_ASSERT(index.IsValid()); - D_ASSERT(base_idx.index + child_number < child_indices.size()); - child_indices[base_idx.index + child_number] = index; -} - -idx_t ColumnDataCollectionSegment::ReadVectorInternal(ChunkManagementState &state, VectorDataIndex vector_index, - Vector &result) { - auto &vector_type = result.GetType(); - auto internal_type = vector_type.InternalType(); - auto type_size = GetTypeIdSize(internal_type); - auto &vdata = GetVectorData(vector_index); - - auto base_ptr = allocator->GetDataPointer(state, vdata.block_id, vdata.offset); - auto validity_data = GetValidityPointer(base_ptr, type_size, vdata.count); - if (!vdata.next_data.IsValid() && state.properties != ColumnDataScanProperties::DISALLOW_ZERO_COPY) { - // no next data, we can do a zero-copy read of this vector - FlatVector::SetData(result, base_ptr); - FlatVector::Validity(result).Initialize(validity_data, STANDARD_VECTOR_SIZE); - return vdata.count; - } - - // the data for this vector is spread over multiple vector data entries - // we need to copy over the data for each of the vectors - // first figure out how many rows we need to copy by looping over all of the child vector indexes - idx_t vector_count = 0; - auto next_index = vector_index; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - vector_count += current_vdata.count; - next_index = current_vdata.next_data; - } - // resize the result vector - result.Resize(0, vector_count); - next_index = vector_index; - // now perform the copy of each of the vectors - auto target_data = FlatVector::GetData(result); - auto &target_validity = FlatVector::Validity(result); - idx_t current_offset = 0; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - base_ptr = allocator->GetDataPointer(state, current_vdata.block_id, current_vdata.offset); - validity_data = GetValidityPointer(base_ptr, type_size, current_vdata.count); - if (type_size > 0) { - memcpy(target_data + current_offset * type_size, base_ptr, current_vdata.count * type_size); - } - ValidityMask current_validity(validity_data, STANDARD_VECTOR_SIZE); - target_validity.SliceInPlace(current_validity, current_offset, 0, current_vdata.count); - current_offset += current_vdata.count; - next_index = current_vdata.next_data; - } - return vector_count; -} - -idx_t ColumnDataCollectionSegment::ReadVector(ChunkManagementState &state, VectorDataIndex vector_index, - Vector &result) { - auto &vector_type = result.GetType(); - auto internal_type = vector_type.InternalType(); - auto &vdata = GetVectorData(vector_index); - if (vdata.count == 0) { - return 0; - } - auto vcount = ReadVectorInternal(state, vector_index, result); - if (internal_type == PhysicalType::LIST) { - // list: copy child - auto &child_vector = ListVector::GetEntry(result); - auto child_count = ReadVector(state, GetChildIndex(vdata.child_index), child_vector); - ListVector::SetListSize(result, child_count); - } else if (internal_type == PhysicalType::ARRAY) { - auto &child_vector = ArrayVector::GetEntry(result); - auto child_count = ReadVector(state, GetChildIndex(vdata.child_index), child_vector); - (void)child_count; - } else 
if (internal_type == PhysicalType::STRUCT) { - auto &child_vectors = StructVector::GetEntries(result); - for (idx_t child_idx = 0; child_idx < child_vectors.size(); child_idx++) { - auto child_count = - ReadVector(state, GetChildIndex(vdata.child_index, child_idx), *child_vectors[child_idx]); - if (child_count != vcount) { - throw InternalException("Column Data Collection: mismatch in struct child sizes"); - } - } - } else if (internal_type == PhysicalType::VARCHAR) { - if (allocator->GetType() == ColumnDataAllocatorType::BUFFER_MANAGER_ALLOCATOR) { - auto next_index = vector_index; - idx_t offset = 0; - while (next_index.IsValid()) { - auto ¤t_vdata = GetVectorData(next_index); - for (auto &swizzle_segment : current_vdata.swizzle_data) { - auto &string_heap_segment = GetVectorData(swizzle_segment.child_index); - allocator->UnswizzlePointers(state, result, offset + swizzle_segment.offset, swizzle_segment.count, - string_heap_segment.block_id, string_heap_segment.offset); - } - offset += current_vdata.count; - next_index = current_vdata.next_data; - } - } - if (state.properties == ColumnDataScanProperties::DISALLOW_ZERO_COPY) { - VectorOperations::Copy(result, result, vdata.count, 0, 0); - } - } - return vcount; -} - -void ColumnDataCollectionSegment::ReadChunk(idx_t chunk_index, ChunkManagementState &state, DataChunk &chunk, - const vector &column_ids) { - D_ASSERT(chunk.ColumnCount() == column_ids.size()); - D_ASSERT(state.properties != ColumnDataScanProperties::INVALID); - chunk.Reset(); - InitializeChunkState(chunk_index, state); - auto &chunk_meta = chunk_data[chunk_index]; - for (idx_t i = 0; i < column_ids.size(); i++) { - auto vector_idx = column_ids[i]; - D_ASSERT(vector_idx < chunk_meta.vector_data.size()); - ReadVector(state, chunk_meta.vector_data[vector_idx], chunk.data[i]); - } - chunk.SetCardinality(chunk_meta.count); -} - -idx_t ColumnDataCollectionSegment::ChunkCount() const { - return chunk_data.size(); -} - -idx_t ColumnDataCollectionSegment::SizeInBytes() const { - D_ASSERT(!allocator->IsShared()); - return allocator->SizeInBytes() + heap->SizeInBytes(); -} - -idx_t ColumnDataCollectionSegment::AllocationSize() const { - D_ASSERT(!allocator->IsShared()); - return allocator->AllocationSize() + heap->AllocationSize(); -} - -void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result) { - vector column_ids; - column_ids.reserve(types.size()); - for (idx_t i = 0; i < types.size(); i++) { - column_ids.push_back(i); - } - FetchChunk(chunk_idx, result, column_ids); -} - -void ColumnDataCollectionSegment::FetchChunk(idx_t chunk_idx, DataChunk &result, const vector &column_ids) { - D_ASSERT(chunk_idx < chunk_data.size()); - ChunkManagementState state; - state.properties = ColumnDataScanProperties::DISALLOW_ZERO_COPY; - ReadChunk(chunk_idx, state, result, column_ids); -} - -void ColumnDataCollectionSegment::Verify() { -#ifdef DEBUG - idx_t total_count = 0; - for (idx_t i = 0; i < chunk_data.size(); i++) { - total_count += chunk_data[i].count; - } - D_ASSERT(total_count == this->count); -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/column_data_consumer.cpp b/src/duckdb/src/common/types/column/column_data_consumer.cpp deleted file mode 100644 index fa20f1d5b..000000000 --- a/src/duckdb/src/common/types/column/column_data_consumer.cpp +++ /dev/null @@ -1,102 +0,0 @@ -#include "duckdb/common/types/column/column_data_consumer.hpp" - -#include - -namespace duckdb { - -using ChunkReference = ColumnDataConsumer::ChunkReference; - 
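// [editor's sketch] GetValidityPointer in column_data_collection_segment.cpp above returns nullptr
// when every row is valid, letting readers skip materializing a validity mask entirely. Standalone
// restatement of that check, assuming 64-bit validity words (ValidityMask::BITS_PER_VALUE):
//
//	static bool AllRowsValid(const uint64_t *mask_words, uint64_t count) {
//		for (uint64_t i = 0; i < count / 64; i++) {
//			if (mask_words[i] != ~uint64_t(0)) { // a zero bit in a full word: some row is NULL
//				return false;
//			}
//		}
//		if (count % 64 != 0) {
//			// a mask with the low (count % 64) bits set covers the partial tail word
//			uint64_t tail_mask = (uint64_t(1) << (count % 64)) - 1;
//			return (mask_words[count / 64] & tail_mask) == tail_mask;
//		}
//		return true;
//	}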
-ChunkReference::ChunkReference(ColumnDataCollectionSegment *segment_p, uint32_t chunk_index_p) - : segment(segment_p), chunk_index_in_segment(chunk_index_p) { -} - -uint32_t ChunkReference::GetMinimumBlockID() const { - const auto &block_ids = segment->chunk_data[chunk_index_in_segment].block_ids; - return *std::min_element(block_ids.begin(), block_ids.end()); -} - -ColumnDataConsumer::ColumnDataConsumer(ColumnDataCollection &collection_p, vector column_ids) - : collection(collection_p), column_ids(std::move(column_ids)) { -} - -void ColumnDataConsumer::InitializeScan() { - chunk_count = collection.ChunkCount(); - current_chunk_index = 0; - chunk_delete_index = DConstants::INVALID_INDEX; - - // Initialize chunk references and sort them, so we can scan them in a sane order, regardless of how it was created - chunk_references.reserve(chunk_count); - for (auto &segment : collection.GetSegments()) { - for (idx_t chunk_index = 0; chunk_index < segment->chunk_data.size(); chunk_index++) { - chunk_references.emplace_back(segment.get(), chunk_index); - } - } - std::sort(chunk_references.begin(), chunk_references.end()); -} - -bool ColumnDataConsumer::AssignChunk(ColumnDataConsumerScanState &state) { - lock_guard guard(lock); - if (current_chunk_index == chunk_count) { - // All chunks have been assigned - state.current_chunk_state.handles.clear(); - state.chunk_index = DConstants::INVALID_INDEX; - return false; - } - // Assign chunk index - state.chunk_index = current_chunk_index++; - D_ASSERT(chunks_in_progress.find(state.chunk_index) == chunks_in_progress.end()); - chunks_in_progress.insert(state.chunk_index); - return true; -} - -void ColumnDataConsumer::ScanChunk(ColumnDataConsumerScanState &state, DataChunk &chunk) const { - D_ASSERT(state.chunk_index < chunk_count); - auto &chunk_ref = chunk_references[state.chunk_index]; - if (state.allocator != chunk_ref.segment->allocator.get()) { - // Previously scanned a chunk from a different allocator, reset the handles - state.allocator = chunk_ref.segment->allocator.get(); - state.current_chunk_state.handles.clear(); - } - chunk_ref.segment->ReadChunk(chunk_ref.chunk_index_in_segment, state.current_chunk_state, chunk, column_ids); -} - -void ColumnDataConsumer::FinishChunk(ColumnDataConsumerScanState &state) { - D_ASSERT(state.chunk_index < chunk_count); - idx_t delete_index_start; - idx_t delete_index_end; - { - lock_guard guard(lock); - D_ASSERT(chunks_in_progress.find(state.chunk_index) != chunks_in_progress.end()); - delete_index_start = chunk_delete_index; - delete_index_end = *std::min_element(chunks_in_progress.begin(), chunks_in_progress.end()); - chunks_in_progress.erase(state.chunk_index); - chunk_delete_index = delete_index_end; - } - ConsumeChunks(delete_index_start, delete_index_end); -} -void ColumnDataConsumer::ConsumeChunks(idx_t delete_index_start, idx_t delete_index_end) { - for (idx_t chunk_index = delete_index_start; chunk_index < delete_index_end; chunk_index++) { - if (chunk_index == 0) { - continue; - } - auto &prev_chunk_ref = chunk_references[chunk_index - 1]; - auto &curr_chunk_ref = chunk_references[chunk_index]; - auto prev_allocator = prev_chunk_ref.segment->allocator.get(); - auto curr_allocator = curr_chunk_ref.segment->allocator.get(); - auto prev_min_block_id = prev_chunk_ref.GetMinimumBlockID(); - auto curr_min_block_id = curr_chunk_ref.GetMinimumBlockID(); - if (prev_allocator != curr_allocator) { - // Moved to the next allocator, delete all remaining blocks in the previous one - for (uint32_t block_id = 
prev_min_block_id; block_id < prev_allocator->BlockCount(); block_id++) { - prev_allocator->SetDestroyBufferUponUnpin(block_id); - } - continue; - } - // Same allocator, see if we can delete blocks - for (uint32_t block_id = prev_min_block_id; block_id < curr_min_block_id; block_id++) { - prev_allocator->SetDestroyBufferUponUnpin(block_id); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/column/partitioned_column_data.cpp b/src/duckdb/src/common/types/column/partitioned_column_data.cpp deleted file mode 100644 index 78e5b3673..000000000 --- a/src/duckdb/src/common/types/column/partitioned_column_data.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include "duckdb/common/types/column/partitioned_column_data.hpp" - -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -PartitionedColumnData::PartitionedColumnData(PartitionedColumnDataType type_p, ClientContext &context_p, - vector types_p) - : type(type_p), context(context_p), types(std::move(types_p)), - allocators(make_shared_ptr()) { -} - -PartitionedColumnData::PartitionedColumnData(const PartitionedColumnData &other) - : type(other.type), context(other.context), types(other.types), allocators(other.allocators) { -} - -unique_ptr PartitionedColumnData::CreateShared() { - switch (type) { - case PartitionedColumnDataType::RADIX: - return make_uniq(Cast()); - default: - throw NotImplementedException("CreateShared for this type of PartitionedColumnData"); - } -} - -PartitionedColumnData::~PartitionedColumnData() { -} - -void PartitionedColumnData::InitializeAppendState(PartitionedColumnDataAppendState &state) const { - state.partition_sel.Initialize(); - state.slice_chunk.Initialize(BufferAllocator::Get(context), types); - InitializeAppendStateInternal(state); -} - -bool PartitionedColumnData::UseFixedSizeMap() const { - return MaxPartitionIndex() < PartitionedTupleDataAppendState::MAP_THRESHOLD; -} - -unique_ptr PartitionedColumnData::CreatePartitionBuffer() const { - auto result = make_uniq(); - result->Initialize(BufferAllocator::Get(context), types, BufferSize()); - return result; -} - -void PartitionedColumnData::Append(PartitionedColumnDataAppendState &state, DataChunk &input) { - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(state, input); - - // Build the selection vector for the partitions - BuildPartitionSel(state, input.size()); - - // Early out: check if everything belongs to a single partition - const auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); - if (partition_index.IsValid()) { - auto &partition = *partitions[partition_index.GetIndex()]; - auto &partition_append_state = *state.partition_append_states[partition_index.GetIndex()]; - partition.Append(partition_append_state, input); - return; - } - - if (UseFixedSizeMap()) { - AppendInternal(state, input); - } else { - AppendInternal(state, input); - } -} - -void PartitionedColumnData::BuildPartitionSel(PartitionedColumnDataAppendState &state, const idx_t append_count) const { - if (UseFixedSizeMap()) { - BuildPartitionSel(state, append_count); - } else { - BuildPartitionSel(state, append_count); - } -} - -template -MAP_TYPE &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState &) { - throw InternalException("Unknown MAP_TYPE for PartitionedTupleDataGetMap"); -} - -template <> -fixed_size_map_t &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState 
&state) { - return state.fixed_partition_entries; -} - -template <> -perfect_map_t &PartitionedColumnDataGetMap(PartitionedColumnDataAppendState &state) { - return state.partition_entries; -} - -template -void PartitionedColumnData::BuildPartitionSel(PartitionedColumnDataAppendState &state, const idx_t append_count) { - using GETTER = TemplatedMapGetter; - auto &partition_entries = state.GetMap(); - partition_entries.clear(); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - switch (state.partition_indices.GetVectorType()) { - case VectorType::FLAT_VECTOR: - for (idx_t i = 0; i < append_count; i++) { - const auto &partition_index = partition_indices[i]; - auto partition_entry = partition_entries.find(partition_index); - if (partition_entry == partition_entries.end()) { - partition_entries[partition_index] = list_entry_t(0, 1); - } else { - GETTER::GetValue(partition_entry).length++; - } - } - break; - case VectorType::CONSTANT_VECTOR: - partition_entries[partition_indices[0]] = list_entry_t(0, append_count); - break; - default: - throw InternalException("Unexpected VectorType in PartitionedTupleData::Append"); - } - - // Early out: check if everything belongs to a single partition - if (partition_entries.size() == 1) { - return; - } - - // Compute offsets from the counts - idx_t offset = 0; - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - auto &partition_entry = GETTER::GetValue(it); - partition_entry.offset = offset; - offset += partition_entry.length; - } - - // Now initialize a single selection vector that acts as a selection vector for every partition - auto &partition_sel = state.partition_sel; - for (idx_t i = 0; i < append_count; i++) { - const auto &partition_index = partition_indices[i]; - auto &partition_offset = partition_entries[partition_index].offset; - partition_sel[partition_offset++] = UnsafeNumericCast(i); - } -} - -template -void PartitionedColumnData::AppendInternal(PartitionedColumnDataAppendState &state, DataChunk &input) { - using GETTER = TemplatedMapGetter; - const auto &partition_entries = state.GetMap(); - - // Loop through the partitions to append the new data to the partition buffers, and flush the buffers if necessary - SelectionVector partition_sel; - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - const auto &partition_index = GETTER::GetKey(it); - - // Partition, buffer, and append state for this partition index - auto &partition = *partitions[partition_index]; - auto &partition_buffer = *state.partition_buffers[partition_index]; - auto &partition_append_state = *state.partition_append_states[partition_index]; - - // Length and offset into the selection vector for this chunk, for this partition - const auto &partition_entry = GETTER::GetValue(it); - const auto &partition_length = partition_entry.length; - const auto partition_offset = partition_entry.offset - partition_length; - - // Create a selection vector for this partition using the offset into the single selection vector - partition_sel.Initialize(state.partition_sel.data() + partition_offset); - - if (partition_length >= HalfBufferSize()) { - // Slice the input chunk using the selection vector - state.slice_chunk.Reset(); - state.slice_chunk.Slice(input, partition_sel, partition_length); - - // Append it to the partition directly - partition.Append(partition_append_state, state.slice_chunk); - } else { - // Append the input chunk to the partition buffer using the selection vector - 
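// [editor's sketch] BuildPartitionSel above is a counting sort: one pass counts rows per partition,
// a prefix sum turns the counts into start offsets, and a final pass scatters row indices so each
// partition owns a contiguous slice of one shared selection vector. Simplified standalone version
// (plain types and hypothetical names; `offsets` must arrive zero-initialized with
// num_partitions + 1 entries):
//
//	static void BuildSel(const uint32_t *partition_of_row, uint32_t row_count,
//	                     std::vector<uint32_t> &offsets, std::vector<uint32_t> &sel) {
//		for (uint32_t i = 0; i < row_count; i++) {
//			offsets[partition_of_row[i] + 1]++; // counts, shifted one slot right
//		}
//		for (size_t p = 1; p < offsets.size(); p++) {
//			offsets[p] += offsets[p - 1]; // prefix sum: counts -> start offsets
//		}
//		sel.resize(row_count);
//		for (uint32_t i = 0; i < row_count; i++) {
//			sel[offsets[partition_of_row[i]]++] = i; // scatter; each partition's cursor advances
//		}
//	}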
partition_buffer.Append(input, false, &partition_sel, partition_length); - - if (partition_buffer.size() >= HalfBufferSize()) { - // Next batch won't fit in the buffer, flush it to the partition - partition.Append(partition_append_state, partition_buffer); - partition_buffer.Reset(); - partition_buffer.SetCapacity(BufferSize()); - } - } - } -} - -void PartitionedColumnData::FlushAppendState(PartitionedColumnDataAppendState &state) { - for (idx_t i = 0; i < state.partition_buffers.size(); i++) { - if (!state.partition_buffers[i]) { - continue; - } - auto &partition_buffer = *state.partition_buffers[i]; - if (partition_buffer.size() > 0) { - partitions[i]->Append(partition_buffer); - partition_buffer.Reset(); - } - } -} - -void PartitionedColumnData::Combine(PartitionedColumnData &other) { - // Now combine the state's partitions into this - lock_guard guard(lock); - - if (partitions.empty()) { - // This is the first merge, we just copy them over - partitions = std::move(other.partitions); - } else { - D_ASSERT(partitions.size() == other.partitions.size()); - // Combine the append state's partitions into this PartitionedColumnData - for (idx_t i = 0; i < other.partitions.size(); i++) { - if (!other.partitions[i]) { - continue; - } - if (!partitions[i]) { - partitions[i] = std::move(other.partitions[i]); - } else { - partitions[i]->Combine(*other.partitions[i]); - } - } - } -} - -vector> &PartitionedColumnData::GetPartitions() { - return partitions; -} - -void PartitionedColumnData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared_ptr(BufferManager::GetBufferManager(context))); - allocators->allocators.back()->MakeShared(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/conflict_info.cpp b/src/duckdb/src/common/types/conflict_info.cpp deleted file mode 100644 index 44f8aa7f1..000000000 --- a/src/duckdb/src/common/types/conflict_info.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/common/types/constraint_conflict_info.hpp" -#include "duckdb/storage/index.hpp" - -namespace duckdb { - -bool ConflictInfo::ConflictTargetMatches(Index &index) const { - if (only_check_unique && !index.IsUnique()) { - // We only support checking ON CONFLICT for Unique/Primary key constraints - return false; - } - if (column_ids.empty()) { - return true; - } - // Check whether the column ids match - return column_ids == index.GetColumnIdSet(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/conflict_manager.cpp b/src/duckdb/src/common/types/conflict_manager.cpp deleted file mode 100644 index 409d0278f..000000000 --- a/src/duckdb/src/common/types/conflict_manager.cpp +++ /dev/null @@ -1,272 +0,0 @@ -#include "duckdb/common/types/conflict_manager.hpp" -#include "duckdb/storage/index.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/common/types/constraint_conflict_info.hpp" - -namespace duckdb { - -ConflictManager::ConflictManager(VerifyExistenceType lookup_type, idx_t input_size, - optional_ptr conflict_info) - : lookup_type(lookup_type), input_size(input_size), conflict_info(conflict_info), conflicts(input_size, false), - mode(ConflictManagerMode::THROW) { -} - -ManagedSelection &ConflictManager::InternalSelection() { - if (!conflicts.Initialized()) { - conflicts.Initialize(input_size); - } - return conflicts; -} - -const unordered_set &ConflictManager::InternalConflictSet() const { - D_ASSERT(conflict_set); - return *conflict_set; -} - -Vector &ConflictManager::InternalRowIds() { - if (!row_ids) { - row_ids = 
make_uniq(LogicalType::ROW_TYPE, input_size); - } - return *row_ids; -} - -Vector &ConflictManager::InternalIntermediate() { - if (!intermediate_vector) { - intermediate_vector = make_uniq(LogicalType::BOOLEAN, true, true, input_size); - } - return *intermediate_vector; -} - -const ConflictInfo &ConflictManager::GetConflictInfo() const { - D_ASSERT(conflict_info); - return *conflict_info; -} - -void ConflictManager::FinishLookup() { - if (mode == ConflictManagerMode::THROW) { - return; - } - if (!SingleIndexTarget()) { - return; - } - if (conflicts.Count() != 0) { - // We have recorded conflicts from the one index we're interested in - // We set this so we don't duplicate the conflicts when there are duplicate indexes - // that also match our conflict target - single_index_finished = true; - } -} - -void ConflictManager::SetMode(ConflictManagerMode mode) { - // Only allow SCAN when we have conflict info - D_ASSERT(mode != ConflictManagerMode::SCAN || conflict_info != nullptr); - this->mode = mode; -} - -void ConflictManager::AddToConflictSet(idx_t chunk_index) { - if (!conflict_set) { - conflict_set = make_uniq>(); - } - auto &set = *conflict_set; - set.insert(chunk_index); -} - -void ConflictManager::AddConflictInternal(idx_t chunk_index, row_t row_id) { - D_ASSERT(mode == ConflictManagerMode::SCAN); - - // Only when we should not throw on conflict should we get here - D_ASSERT(!ShouldThrow(chunk_index)); - AddToConflictSet(chunk_index); - if (SingleIndexTarget()) { - // If we have identical indexes, only the conflicts of the first index should be recorded - // as the other index(es) would produce the exact same conflicts anyways - if (single_index_finished) { - return; - } - - // We can be more efficient because we don't need to merge conflicts of multiple indexes - auto &selection = InternalSelection(); - auto &row_ids = InternalRowIds(); - auto data = FlatVector::GetData(row_ids); - data[selection.Count()] = row_id; - selection.Append(chunk_index); - } else { - auto &intermediate = InternalIntermediate(); - auto data = FlatVector::GetData(intermediate); - // Mark this index in the chunk as producing a conflict - data[chunk_index] = true; - if (row_id_map.empty()) { - row_id_map.resize(input_size); - } - row_id_map[chunk_index] = row_id; - } -} - -bool ConflictManager::IsConflict(LookupResultType type) { - switch (type) { - case LookupResultType::LOOKUP_NULL: { - if (ShouldIgnoreNulls()) { - return false; - } - // If nulls are not ignored, treat this as a hit instead - return IsConflict(LookupResultType::LOOKUP_HIT); - } - case LookupResultType::LOOKUP_HIT: { - return true; - } - case LookupResultType::LOOKUP_MISS: { - // FIXME: If we record a miss as a conflict when the verify type is APPEND_FK, then we can simplify the checks - // in VerifyForeignKeyConstraint This also means we should not record a hit as a conflict when the verify type - // is APPEND_FK - return false; - } - default: { - throw NotImplementedException("Type not implemented for LookupResultType"); - } - } -} - -bool ConflictManager::AddHit(idx_t chunk_index, row_t row_id) { - D_ASSERT(chunk_index < input_size); - // First check if this causes a conflict - if (!IsConflict(LookupResultType::LOOKUP_HIT)) { - return false; - } - - // Then check if we should throw on a conflict - if (ShouldThrow(chunk_index)) { - return true; - } - if (mode == ConflictManagerMode::THROW) { - // When our mode is THROW, and the chunk index is part of the previously scanned conflicts - // then we ignore the conflict instead - 
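// [editor's sketch] For multi-index conflict targets, Finalize further down compacts the per-row
// boolean intermediate vector plus row_id_map into a dense selection with aligned row ids.
// Simplified compaction loop (plain types, hypothetical names):
//
//	static size_t CompactConflicts(const bool *conflicted, const int64_t *row_id_map,
//	                               size_t input_size, uint32_t *sel, int64_t *row_ids) {
//		size_t count = 0;
//		for (size_t i = 0; i < input_size; i++) {
//			if (conflicted[i]) {
//				sel[count] = uint32_t(i);       // chunk index of the conflicting row
//				row_ids[count] = row_id_map[i]; // matching row id, aligned with sel
//				count++;
//			}
//		}
//		return count;
//	}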
D_ASSERT(!ShouldThrow(chunk_index)); - return false; - } - D_ASSERT(conflict_info); - // Because we don't throw, we need to register the conflict - AddConflictInternal(chunk_index, row_id); - return false; -} - -bool ConflictManager::AddMiss(idx_t chunk_index) { - D_ASSERT(chunk_index < input_size); - return IsConflict(LookupResultType::LOOKUP_MISS); -} - -bool ConflictManager::AddNull(idx_t chunk_index) { - D_ASSERT(chunk_index < input_size); - if (!IsConflict(LookupResultType::LOOKUP_NULL)) { - return false; - } - return AddHit(chunk_index, static_cast(DConstants::INVALID_INDEX)); -} - -bool ConflictManager::SingleIndexTarget() const { - D_ASSERT(conflict_info); - // We are only interested in a specific index - return !conflict_info->column_ids.empty(); -} - -bool ConflictManager::ShouldThrow(idx_t chunk_index) const { - if (mode == ConflictManagerMode::SCAN) { - return false; - } - D_ASSERT(mode == ConflictManagerMode::THROW); - if (conflict_set == nullptr) { - // No conflicts were scanned, so this conflict is not in the set - return true; - } - auto &set = InternalConflictSet(); - if (set.count(chunk_index)) { - return false; - } - // None of the scanned conflicts arose from this insert tuple - return true; -} - -bool ConflictManager::ShouldIgnoreNulls() const { - switch (lookup_type) { - case VerifyExistenceType::APPEND: - return true; - case VerifyExistenceType::APPEND_FK: - return false; - case VerifyExistenceType::DELETE_FK: - return true; - default: - throw InternalException("Type not implemented for VerifyExistenceType"); - } -} - -Vector &ConflictManager::RowIds() { - D_ASSERT(finalized); - return *row_ids; -} - -const ManagedSelection &ConflictManager::Conflicts() const { - D_ASSERT(finalized); - return conflicts; -} - -idx_t ConflictManager::ConflictCount() const { - return conflicts.Count(); -} - -void ConflictManager::AddIndex(BoundIndex &index, optional_ptr delete_index) { - matched_indexes.push_back(index); - matched_delete_indexes.push_back(delete_index); - matched_index_names.insert(index.name); -} - -bool ConflictManager::MatchedIndex(BoundIndex &index) { - return matched_index_names.find(index.name) != matched_index_names.end(); -} - -const vector> &ConflictManager::MatchedIndexes() const { - return matched_indexes; -} - -const vector> &ConflictManager::MatchedDeleteIndexes() const { - return matched_delete_indexes; -} - -void ConflictManager::Finalize() { - D_ASSERT(!finalized); - if (SingleIndexTarget()) { - // Selection vector has been directly populated already, no need to finalize - finalized = true; - return; - } - finalized = true; - if (!intermediate_vector) { - // No conflicts were found, we're done - return; - } - auto &intermediate = InternalIntermediate(); - auto data = FlatVector::GetData(intermediate); - auto &selection = InternalSelection(); - // Create the selection vector from the encountered conflicts - for (idx_t i = 0; i < input_size; i++) { - if (data[i]) { - selection.Append(i); - } - } - // Now create the row_ids Vector, aligned with the selection vector - auto &internal_row_ids = InternalRowIds(); - auto row_id_data = FlatVector::GetData(internal_row_ids); - - for (idx_t i = 0; i < selection.Count(); i++) { - D_ASSERT(!row_id_map.empty()); - auto index = selection[i]; - D_ASSERT(index < row_id_map.size()); - auto row_id = row_id_map[index]; - row_id_data[i] = row_id; - } - intermediate_vector.reset(); -} - -VerifyExistenceType ConflictManager::LookupType() const { - return lookup_type; -} - -} // namespace duckdb diff --git 
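// [editor's note] The SCAN/THROW mode pair in conflict_manager.cpp above implements two-phase
// constraint verification: a SCAN pass records conflicts instead of throwing, and a later THROW
// pass consults the recorded set so that only conflicts absent from the scan raise. The gate
// logic, restated as a sketch with a plain std::unordered_set standing in for the conflict set:
//
//	static bool ShouldRaise(bool scan_mode, const std::unordered_set<size_t> *scanned, size_t chunk_index) {
//		if (scan_mode) {
//			return false; // SCAN: record the conflict, never throw
//		}
//		if (!scanned) {
//			return true; // THROW with no prior scan: every conflict raises
//		}
//		return scanned->count(chunk_index) == 0; // pre-scanned conflicts are handled by ON CONFLICT
//	}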
a/src/duckdb/src/common/types/data_chunk.cpp b/src/duckdb/src/common/types/data_chunk.cpp deleted file mode 100644 index 8b00a95f7..000000000 --- a/src/duckdb/src/common/types/data_chunk.cpp +++ /dev/null @@ -1,396 +0,0 @@ -#include "duckdb/common/types/data_chunk.hpp" - -#include "duckdb/common/array.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/sel_cache.hpp" -#include "duckdb/common/types/vector_cache.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/execution_context.hpp" - -#include "duckdb/common/serializer/memory_stream.hpp" -#include "duckdb/common/serializer/binary_serializer.hpp" -#include "duckdb/common/serializer/binary_deserializer.hpp" - -namespace duckdb { - -DataChunk::DataChunk() : count(0), capacity(STANDARD_VECTOR_SIZE) { -} - -DataChunk::~DataChunk() { -} - -void DataChunk::InitializeEmpty(const vector &types) { - D_ASSERT(data.empty()); - capacity = STANDARD_VECTOR_SIZE; - for (idx_t i = 0; i < types.size(); i++) { - data.emplace_back(types[i], nullptr); - } -} - -void DataChunk::Initialize(ClientContext &context, const vector &types, idx_t capacity_p) { - Initialize(Allocator::Get(context), types, capacity_p); -} - -void DataChunk::Initialize(Allocator &allocator, const vector &types, idx_t capacity_p) { - auto initialize = vector(types.size(), true); - Initialize(allocator, types, initialize, capacity_p); -} - -void DataChunk::Initialize(ClientContext &context, const vector &types, const vector &initialize, - idx_t capacity_p) { - Initialize(Allocator::Get(context), types, initialize, capacity_p); -} - -void DataChunk::Initialize(Allocator &allocator, const vector &types, const vector &initialize, - idx_t capacity_p) { - D_ASSERT(types.size() == initialize.size()); - D_ASSERT(data.empty()); - - capacity = capacity_p; - for (idx_t i = 0; i < types.size(); i++) { - if (!initialize[i]) { - data.emplace_back(types[i], nullptr); - vector_caches.emplace_back(); - continue; - } - - VectorCache cache(allocator, types[i], capacity); - data.emplace_back(cache); - vector_caches.push_back(std::move(cache)); - } -} - -idx_t DataChunk::GetAllocationSize() const { - idx_t total_size = 0; - auto cardinality = size(); - for (auto &vec : data) { - total_size += vec.GetAllocationSize(cardinality); - } - return total_size; -} - -void DataChunk::Reset() { - if (data.empty() || vector_caches.empty()) { - return; - } - if (vector_caches.size() != data.size()) { - throw InternalException("VectorCache and column count mismatch in DataChunk::Reset"); - } - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].ResetFromCache(vector_caches[i]); - } - capacity = STANDARD_VECTOR_SIZE; - SetCardinality(0); -} - -void DataChunk::Destroy() { - data.clear(); - vector_caches.clear(); - capacity = 0; - SetCardinality(0); -} - -Value DataChunk::GetValue(idx_t col_idx, idx_t index) const { - D_ASSERT(index < size()); - return data[col_idx].GetValue(index); -} - -void DataChunk::SetValue(idx_t col_idx, idx_t index, const Value &val) { - data[col_idx].SetValue(index, val); -} - -bool DataChunk::AllConstant() const { - for (auto &v : data) { - if (v.GetVectorType() != VectorType::CONSTANT_VECTOR) { - return false; - } - } - return true; -} - -void 
DataChunk::Reference(DataChunk &chunk) { - D_ASSERT(chunk.ColumnCount() <= ColumnCount()); - SetCapacity(chunk); - SetCardinality(chunk); - for (idx_t i = 0; i < chunk.ColumnCount(); i++) { - data[i].Reference(chunk.data[i]); - } -} - -void DataChunk::Move(DataChunk &chunk) { - SetCardinality(chunk); - SetCapacity(chunk); - data = std::move(chunk.data); - vector_caches = std::move(chunk.vector_caches); - - chunk.Destroy(); -} - -void DataChunk::Copy(DataChunk &other, idx_t offset) const { - D_ASSERT(ColumnCount() == other.ColumnCount()); - D_ASSERT(other.size() == 0); - - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(data[i], other.data[i], size(), offset, 0); - } - other.SetCardinality(size() - offset); -} - -void DataChunk::Copy(DataChunk &other, const SelectionVector &sel, const idx_t source_count, const idx_t offset) const { - D_ASSERT(ColumnCount() == other.ColumnCount()); - D_ASSERT(other.size() == 0); - D_ASSERT(source_count <= size()); - - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(other.data[i].GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(data[i], other.data[i], sel, source_count, offset, 0); - } - other.SetCardinality(source_count - offset); -} - -void DataChunk::Split(DataChunk &other, idx_t split_idx) { - D_ASSERT(other.size() == 0); - D_ASSERT(other.data.empty()); - D_ASSERT(split_idx < data.size()); - const idx_t num_cols = data.size(); - for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { - other.data.push_back(std::move(data[col_idx])); - other.vector_caches.push_back(std::move(vector_caches[col_idx])); - } - for (idx_t col_idx = split_idx; col_idx < num_cols; col_idx++) { - data.pop_back(); - vector_caches.pop_back(); - } - other.SetCapacity(*this); - other.SetCardinality(*this); -} - -void DataChunk::Fuse(DataChunk &other) { - D_ASSERT(other.size() == size()); - const idx_t num_cols = other.data.size(); - for (idx_t col_idx = 0; col_idx < num_cols; ++col_idx) { - data.emplace_back(std::move(other.data[col_idx])); - vector_caches.emplace_back(std::move(other.vector_caches[col_idx])); - } - other.Destroy(); -} - -void DataChunk::ReferenceColumns(DataChunk &other, const vector &column_ids) { - D_ASSERT(ColumnCount() == column_ids.size()); - Reset(); - for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { - auto &other_col = other.data[column_ids[col_idx]]; - auto &this_col = data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - this_col.Reference(other_col); - } - SetCardinality(other.size()); -} - -void DataChunk::Append(const DataChunk &other, bool resize, SelectionVector *sel, idx_t sel_count) { - idx_t new_size = sel ? 
size() + sel_count : size() + other.size(); - if (other.size() == 0) { - return; - } - if (ColumnCount() != other.ColumnCount()) { - throw InternalException("Column counts of appending chunk doesn't match!"); - } - if (new_size > capacity) { - if (resize) { - auto new_capacity = NextPowerOfTwo(new_size); - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Resize(size(), new_capacity); - } - capacity = new_capacity; - } else { - throw InternalException("Can't append chunk to other chunk without resizing"); - } - } - for (idx_t i = 0; i < ColumnCount(); i++) { - D_ASSERT(data[i].GetVectorType() == VectorType::FLAT_VECTOR); - if (sel) { - VectorOperations::Copy(other.data[i], data[i], *sel, sel_count, 0, size()); - } else { - VectorOperations::Copy(other.data[i], data[i], other.size(), 0, size()); - } - } - SetCardinality(new_size); -} - -void DataChunk::Flatten() { - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Flatten(size()); - } -} - -vector DataChunk::GetTypes() const { - vector types; - for (idx_t i = 0; i < ColumnCount(); i++) { - types.push_back(data[i].GetType()); - } - return types; -} - -string DataChunk::ToString() const { - string retval = "Chunk - [" + to_string(ColumnCount()) + " Columns]\n"; - for (idx_t i = 0; i < ColumnCount(); i++) { - retval += "- " + data[i].ToString(size()) + "\n"; - } - return retval; -} - -void DataChunk::Serialize(Serializer &serializer) const { - - // write the count - auto row_count = size(); - serializer.WriteProperty(100, "rows", NumericCast(row_count)); - - // we should never try to serialize empty data chunks - auto column_count = ColumnCount(); - D_ASSERT(column_count); - - // write the types - serializer.WriteList(101, "types", column_count, - [&](Serializer::List &list, idx_t i) { list.WriteElement(data[i].GetType()); }); - - // write the data - serializer.WriteList(102, "columns", column_count, [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { - // Reference the vector to avoid potentially mutating it during serialization - Vector serialized_vector(data[i].GetType()); - serialized_vector.Reference(data[i]); - serialized_vector.Serialize(object, row_count); - }); - }); -} - -void DataChunk::Deserialize(Deserializer &deserializer) { - - // read and set the row count - auto row_count = deserializer.ReadProperty(100, "rows"); - - // read the types - vector types; - deserializer.ReadList(101, "types", [&](Deserializer::List &list, idx_t i) { - auto type = list.ReadElement(); - types.push_back(type); - }); - - // initialize the data chunk - D_ASSERT(!types.empty()); - Initialize(Allocator::DefaultAllocator(), types, MaxValue(row_count, STANDARD_VECTOR_SIZE)); - SetCardinality(row_count); - - // read the data - deserializer.ReadList(102, "columns", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &object) { data[i].Deserialize(object, row_count); }); - }); -} - -void DataChunk::Slice(const SelectionVector &sel_vector, idx_t count_p) { - this->count = count_p; - SelCache merge_cache; - for (idx_t c = 0; c < ColumnCount(); c++) { - data[c].Slice(sel_vector, count_p, merge_cache); - } -} - -void DataChunk::Slice(const DataChunk &other, const SelectionVector &sel, idx_t count_p, idx_t col_offset) { - D_ASSERT(other.ColumnCount() <= col_offset + ColumnCount()); - this->count = count_p; - SelCache merge_cache; - for (idx_t c = 0; c < other.ColumnCount(); c++) { - if (other.data[c].GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // already a dictionary! 
merge the dictionaries - data[col_offset + c].Reference(other.data[c]); - data[col_offset + c].Slice(sel, count_p, merge_cache); - } else { - data[col_offset + c].Slice(other.data[c], sel, count_p); - } - } -} - -void DataChunk::Slice(idx_t offset, idx_t slice_count) { - D_ASSERT(offset + slice_count <= size()); - SelectionVector sel(slice_count); - for (idx_t i = 0; i < slice_count; i++) { - sel.set_index(i, offset + i); - } - Slice(sel, slice_count); -} - -unsafe_unique_array DataChunk::ToUnifiedFormat() { - auto unified_data = make_unsafe_uniq_array(ColumnCount()); - for (idx_t col_idx = 0; col_idx < ColumnCount(); col_idx++) { - data[col_idx].ToUnifiedFormat(size(), unified_data[col_idx]); - } - return unified_data; -} - -void DataChunk::Hash(Vector &result) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - VectorOperations::Hash(data[0], result, size()); - for (idx_t i = 1; i < ColumnCount(); i++) { - VectorOperations::CombineHash(result, data[i], size()); - } -} - -void DataChunk::Hash(vector &column_ids, Vector &result) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - D_ASSERT(!column_ids.empty()); - - VectorOperations::Hash(data[column_ids[0]], result, size()); - for (idx_t i = 1; i < column_ids.size(); i++) { - VectorOperations::CombineHash(result, data[column_ids[i]], size()); - } -} - -void DataChunk::Verify() { -#ifdef DEBUG - D_ASSERT(size() <= capacity); - - // verify that all vectors in this chunk have the chunk selection vector - for (idx_t i = 0; i < ColumnCount(); i++) { - data[i].Verify(size()); - } - - if (!ColumnCount()) { - // don't try to round-trip dummy data chunks with no data - // e.g., these exist in queries like 'SELECT distinct(col0, col1) FROM tbl', where we have groups, but no - // payload so the payload will be such an empty data chunk - return; - } - - // verify that we can round-trip chunk serialization - MemoryStream mem_stream; - BinarySerializer serializer(mem_stream); - - serializer.Begin(); - Serialize(serializer); - serializer.End(); - - mem_stream.Rewind(); - - BinaryDeserializer deserializer(mem_stream); - DataChunk new_chunk; - - deserializer.Begin(); - new_chunk.Deserialize(deserializer); - deserializer.End(); - - D_ASSERT(size() == new_chunk.size()); -#endif -} - -void DataChunk::Print() const { - Printer::Print(ToString()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/date.cpp b/src/duckdb/src/common/types/date.cpp deleted file mode 100644 index 429ee0311..000000000 --- a/src/duckdb/src/common/types/date.cpp +++ /dev/null @@ -1,625 +0,0 @@ -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/assert.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/limits.hpp" -#include -#include -#include - -namespace duckdb { - -static_assert(sizeof(date_t) == sizeof(int32_t), "date_t was padded"); - -const char *Date::PINF = "infinity"; // NOLINT -const char *Date::NINF = "-infinity"; // NOLINT -const char *Date::EPOCH = "epoch"; // NOLINT - -const string_t Date::MONTH_NAMES_ABBREVIATED[] = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", - "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; -const string_t Date::MONTH_NAMES[] = {"January", "February", "March", "April", "May", "June", - "July", "August", 
"September", "October", "November", "December"}; -const string_t Date::DAY_NAMES[] = {"Sunday", "Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday"}; -const string_t Date::DAY_NAMES_ABBREVIATED[] = {"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}; - -const int32_t Date::NORMAL_DAYS[] = {0, 31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; -const int32_t Date::CUMULATIVE_DAYS[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365}; -const int32_t Date::LEAP_DAYS[] = {0, 31, 29, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31}; -const int32_t Date::CUMULATIVE_LEAP_DAYS[] = {0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366}; -const int8_t Date::MONTH_PER_DAY_OF_YEAR[] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, - 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; -const int8_t Date::LEAP_MONTH_PER_DAY_OF_YEAR[] = { - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, - 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, - 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, - 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, - 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, - 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, - 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12}; -const int32_t Date::CUMULATIVE_YEAR_DAYS[] = { - 0, 365, 730, 1096, 1461, 1826, 2191, 2557, 2922, 3287, 3652, 4018, 4383, 4748, - 5113, 5479, 5844, 6209, 6574, 6940, 7305, 7670, 8035, 8401, 8766, 9131, 9496, 9862, - 10227, 10592, 10957, 11323, 11688, 12053, 12418, 12784, 13149, 13514, 13879, 14245, 14610, 14975, - 15340, 15706, 16071, 16436, 16801, 17167, 17532, 
17897, 18262, 18628, 18993, 19358, 19723, 20089, - 20454, 20819, 21184, 21550, 21915, 22280, 22645, 23011, 23376, 23741, 24106, 24472, 24837, 25202, - 25567, 25933, 26298, 26663, 27028, 27394, 27759, 28124, 28489, 28855, 29220, 29585, 29950, 30316, - 30681, 31046, 31411, 31777, 32142, 32507, 32872, 33238, 33603, 33968, 34333, 34699, 35064, 35429, - 35794, 36160, 36525, 36890, 37255, 37621, 37986, 38351, 38716, 39082, 39447, 39812, 40177, 40543, - 40908, 41273, 41638, 42004, 42369, 42734, 43099, 43465, 43830, 44195, 44560, 44926, 45291, 45656, - 46021, 46387, 46752, 47117, 47482, 47847, 48212, 48577, 48942, 49308, 49673, 50038, 50403, 50769, - 51134, 51499, 51864, 52230, 52595, 52960, 53325, 53691, 54056, 54421, 54786, 55152, 55517, 55882, - 56247, 56613, 56978, 57343, 57708, 58074, 58439, 58804, 59169, 59535, 59900, 60265, 60630, 60996, - 61361, 61726, 62091, 62457, 62822, 63187, 63552, 63918, 64283, 64648, 65013, 65379, 65744, 66109, - 66474, 66840, 67205, 67570, 67935, 68301, 68666, 69031, 69396, 69762, 70127, 70492, 70857, 71223, - 71588, 71953, 72318, 72684, 73049, 73414, 73779, 74145, 74510, 74875, 75240, 75606, 75971, 76336, - 76701, 77067, 77432, 77797, 78162, 78528, 78893, 79258, 79623, 79989, 80354, 80719, 81084, 81450, - 81815, 82180, 82545, 82911, 83276, 83641, 84006, 84371, 84736, 85101, 85466, 85832, 86197, 86562, - 86927, 87293, 87658, 88023, 88388, 88754, 89119, 89484, 89849, 90215, 90580, 90945, 91310, 91676, - 92041, 92406, 92771, 93137, 93502, 93867, 94232, 94598, 94963, 95328, 95693, 96059, 96424, 96789, - 97154, 97520, 97885, 98250, 98615, 98981, 99346, 99711, 100076, 100442, 100807, 101172, 101537, 101903, - 102268, 102633, 102998, 103364, 103729, 104094, 104459, 104825, 105190, 105555, 105920, 106286, 106651, 107016, - 107381, 107747, 108112, 108477, 108842, 109208, 109573, 109938, 110303, 110669, 111034, 111399, 111764, 112130, - 112495, 112860, 113225, 113591, 113956, 114321, 114686, 115052, 115417, 115782, 116147, 116513, 116878, 117243, - 117608, 117974, 118339, 118704, 119069, 119435, 119800, 120165, 120530, 120895, 121260, 121625, 121990, 122356, - 122721, 123086, 123451, 123817, 124182, 124547, 124912, 125278, 125643, 126008, 126373, 126739, 127104, 127469, - 127834, 128200, 128565, 128930, 129295, 129661, 130026, 130391, 130756, 131122, 131487, 131852, 132217, 132583, - 132948, 133313, 133678, 134044, 134409, 134774, 135139, 135505, 135870, 136235, 136600, 136966, 137331, 137696, - 138061, 138427, 138792, 139157, 139522, 139888, 140253, 140618, 140983, 141349, 141714, 142079, 142444, 142810, - 143175, 143540, 143905, 144271, 144636, 145001, 145366, 145732, 146097}; - -void Date::ExtractYearOffset(int32_t &n, int32_t &year, int32_t &year_offset) { - year = Date::EPOCH_YEAR; - // first we normalize n to be in the year range [1970, 2370] - // since leap years repeat every 400 years, we can safely normalize just by "shifting" the CumulativeYearDays array - while (n < 0) { - n += Date::DAYS_PER_YEAR_INTERVAL; - year -= Date::YEAR_INTERVAL; - } - while (n >= Date::DAYS_PER_YEAR_INTERVAL) { - n -= Date::DAYS_PER_YEAR_INTERVAL; - year += Date::YEAR_INTERVAL; - } - // interpolation search - // we can find an upper bound of the year by assuming each year has 365 days - year_offset = n / 365; - // because of leap years we might be off by a little bit: compensate by decrementing the year offset until we find - // our year - while (n < Date::CUMULATIVE_YEAR_DAYS[year_offset]) { - year_offset--; - D_ASSERT(year_offset >= 0); - } - year += year_offset; - D_ASSERT(n >= 
Date::CUMULATIVE_YEAR_DAYS[year_offset]); -} - -void Date::Convert(date_t d, int32_t &year, int32_t &month, int32_t &day) { - auto n = d.days; - int32_t year_offset; - Date::ExtractYearOffset(n, year, year_offset); - - day = n - Date::CUMULATIVE_YEAR_DAYS[year_offset]; - D_ASSERT(day >= 0 && day <= 365); - - bool is_leap_year = (Date::CUMULATIVE_YEAR_DAYS[year_offset + 1] - Date::CUMULATIVE_YEAR_DAYS[year_offset]) == 366; - if (is_leap_year) { - month = Date::LEAP_MONTH_PER_DAY_OF_YEAR[day]; - day -= Date::CUMULATIVE_LEAP_DAYS[month - 1]; - } else { - month = Date::MONTH_PER_DAY_OF_YEAR[day]; - day -= Date::CUMULATIVE_DAYS[month - 1]; - } - day++; - D_ASSERT(day > 0 && day <= (is_leap_year ? Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month])); - D_ASSERT(month > 0 && month <= 12); -} - -bool Date::TryFromDate(int32_t year, int32_t month, int32_t day, date_t &result) { - int32_t n = 0; - if (!Date::IsValid(year, month, day)) { - return false; - } - n += Date::IsLeapYear(year) ? Date::CUMULATIVE_LEAP_DAYS[month - 1] : Date::CUMULATIVE_DAYS[month - 1]; - n += day - 1; - if (year < 1970) { - int32_t diff_from_base = 1970 - year; - int32_t year_index = 400 - (diff_from_base % 400); - int32_t fractions = diff_from_base / 400; - n += Date::CUMULATIVE_YEAR_DAYS[year_index]; - n -= Date::DAYS_PER_YEAR_INTERVAL; - n -= fractions * Date::DAYS_PER_YEAR_INTERVAL; - } else if (year >= 2370) { - int32_t diff_from_base = year - 2370; - int32_t year_index = diff_from_base % 400; - int32_t fractions = diff_from_base / 400; - n += Date::CUMULATIVE_YEAR_DAYS[year_index]; - n += Date::DAYS_PER_YEAR_INTERVAL; - n += fractions * Date::DAYS_PER_YEAR_INTERVAL; - } else { - n += Date::CUMULATIVE_YEAR_DAYS[year - 1970]; - } -#ifdef DEBUG - int32_t y, m, d; - Date::Convert(date_t(n), y, m, d); - D_ASSERT(year == y); - D_ASSERT(month == m); - D_ASSERT(day == d); -#endif - result = date_t(n); - return true; -} - -date_t Date::FromDate(int32_t year, int32_t month, int32_t day) { - date_t result; - if (!Date::TryFromDate(year, month, day, result)) { - throw ConversionException("Date out of range: %d-%d-%d", year, month, day); - } - return result; -} - -bool Date::ParseDoubleDigit(const char *buf, idx_t len, idx_t &pos, int32_t &result) { - if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { - result = buf[pos++] - '0'; - if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { - result = (buf[pos++] - '0') + result * 10; - } - return true; - } - return false; -} - -bool Date::TryConvertDateSpecial(const char *buf, idx_t len, idx_t &pos, const char *special) { - auto p = pos; - for (; p < len && *special; ++p) { - const auto s = *special++; - if (!s || StringUtil::CharacterToLower(buf[p]) != s) { - return false; - } - } - if (*special) { - return false; - } - pos = p; - return true; -} - -DateCastResult Date::TryConvertDate(const char *buf, idx_t len, idx_t &pos, date_t &result, bool &special, - bool strict) { - special = false; - pos = 0; - if (len == 0) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - int32_t day = 0; - int32_t month = -1; - int32_t year = 0; - bool yearneg = false; - int sep; - - // skip leading spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - - if (pos >= len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - if (buf[pos] == '-') { - yearneg = true; - pos++; - if (pos >= len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - } - if (!StringUtil::CharacterIsDigit(buf[pos])) { - // Check for special values - if 
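// TryFromDate/Convert above lean on the 400-year Gregorian cycle: every 400
// years span exactly DAYS_PER_YEAR_INTERVAL = 146097 days, so a year is first
// shifted into [1970, 2370) by whole cycles and then resolved through
// CUMULATIVE_YEAR_DAYS. A table-free standalone sketch of the forward
// direction (days since 1970-01-01); correct, but O(|year - 1970|) where the
// deleted code is O(1):
#include <cstdint>

static bool IsLeapYearSketch(int32_t y) {
	return y % 4 == 0 && (y % 100 != 0 || y % 400 == 0);
}

static int32_t DaysSinceEpochSketch(int32_t y, int32_t m, int32_t d) {
	static const int32_t cumulative[] = {0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334};
	int32_t days = 0;
	for (int32_t year = 1970; year < y; year++) { days += IsLeapYearSketch(year) ? 366 : 365; }
	for (int32_t year = y; year < 1970; year++) { days -= IsLeapYearSketch(year) ? 366 : 365; }
	days += cumulative[m - 1] + ((m > 2 && IsLeapYearSketch(y)) ? 1 : 0);
	return days + (d - 1);
}
// DaysSinceEpochSketch(1970, 1, 1) == 0; DaysSinceEpochSketch(2000, 3, 1) == 11017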
(TryConvertDateSpecial(buf, len, pos, PINF)) { - result = yearneg ? date_t::ninfinity() : date_t::infinity(); - } else if (TryConvertDateSpecial(buf, len, pos, EPOCH)) { - result = date_t::epoch(); - } else { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - // skip trailing spaces - parsing must be strict here - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - special = true; - return (pos == len) ? DateCastResult::SUCCESS : DateCastResult::ERROR_INCORRECT_FORMAT; - } - // first parse the year - idx_t year_length = 0; - for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++) { - if (year >= 100000000) { - return DateCastResult::ERROR_RANGE; - } - year = (buf[pos] - '0') + year * 10; - year_length++; - } - if (year_length < 2 && strict) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - if (yearneg) { - year = -year; - } - - if (pos >= len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - // fetch the separator - sep = buf[pos++]; - if (sep != ' ' && sep != '-' && sep != '/' && sep != '\\') { - // invalid separator - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - // parse the month - if (!Date::ParseDoubleDigit(buf, len, pos, month)) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - if (pos >= len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - if (buf[pos++] != sep) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - if (pos >= len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - // now parse the day - if (!Date::ParseDoubleDigit(buf, len, pos, day)) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - - // check for an optional trailing " (BC)"" - if (len - pos >= 5 && StringUtil::CharacterIsSpace(buf[pos]) && buf[pos + 1] == '(' && - StringUtil::CharacterToLower(buf[pos + 2]) == 'b' && StringUtil::CharacterToLower(buf[pos + 3]) == 'c' && - buf[pos + 4] == ')') { - if (yearneg || year == 0) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - year = -year + 1; - pos += 5; - } - - // in strict mode, check remaining string for non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - // check position. if end was not reached, non-space chars remaining - if (pos < len) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - } else { - // in non-strict mode, check for any direct trailing digits - if (pos < len && StringUtil::CharacterIsDigit(buf[pos])) { - return DateCastResult::ERROR_INCORRECT_FORMAT; - } - } - - return Date::TryFromDate(year, month, day, result) ? 
DateCastResult::SUCCESS : DateCastResult::ERROR_RANGE; -} - -string Date::FormatError(const string &str) { - return StringUtil::Format("invalid date field format: \"%s\", " - "expected format is (YYYY-MM-DD)", - str); -} - -string Date::RangeError(const string &str) { - return StringUtil::Format("date field value out of range: \"%s\"", str); -} - -string Date::RangeError(string_t str) { - return RangeError(str.GetString()); -} - -string Date::FormatError(string_t str) { - return FormatError(str.GetString()); -} - -date_t Date::FromCString(const char *buf, idx_t len, bool strict) { - date_t result; - idx_t pos; - bool special = false; - switch (TryConvertDate(buf, len, pos, result, special, strict)) { - case DateCastResult::ERROR_INCORRECT_FORMAT: - throw ConversionException(FormatError(string(buf, len))); - case DateCastResult::ERROR_RANGE: - throw ConversionException(RangeError(string(buf, len))); - case DateCastResult::SUCCESS: - break; - } - return result; -} - -date_t Date::FromString(const string &str, bool strict) { - return Date::FromCString(str.c_str(), str.size(), strict); -} - -string Date::ToString(date_t date) { - // PG displays temporal infinities in lowercase, - // but numerics in Titlecase. - if (date == date_t::infinity()) { - return PINF; - } else if (date == date_t::ninfinity()) { - return NINF; - } - int32_t date_units[3]; - idx_t year_length; - bool add_bc; - Date::Convert(date, date_units[0], date_units[1], date_units[2]); - - auto length = DateToStringCast::Length(date_units, year_length, add_bc); - auto buffer = make_unsafe_uniq_array_uninitialized(length); - DateToStringCast::Format(buffer.get(), date_units, year_length, add_bc); - return string(buffer.get(), length); -} - -string Date::Format(int32_t year, int32_t month, int32_t day) { - return ToString(Date::FromDate(year, month, day)); -} - -bool Date::IsLeapYear(int32_t year) { - return year % 4 == 0 && (year % 100 != 0 || year % 400 == 0); -} - -bool Date::IsValid(int32_t year, int32_t month, int32_t day) { - if (month < 1 || month > 12) { - return false; - } - if (day < 1) { - return false; - } - if (year <= DATE_MIN_YEAR) { - if (year < DATE_MIN_YEAR) { - return false; - } else if (year == DATE_MIN_YEAR) { - if (month < DATE_MIN_MONTH || (month == DATE_MIN_MONTH && day < DATE_MIN_DAY)) { - return false; - } - } - } - if (year >= DATE_MAX_YEAR) { - if (year > DATE_MAX_YEAR) { - return false; - } else if (year == DATE_MAX_YEAR) { - if (month > DATE_MAX_MONTH || (month == DATE_MAX_MONTH && day > DATE_MAX_DAY)) { - return false; - } - } - } - return Date::IsLeapYear(year) ? day <= Date::LEAP_DAYS[month] : day <= Date::NORMAL_DAYS[month]; -} - -int32_t Date::MonthDays(int32_t year, int32_t month) { - D_ASSERT(month >= 1 && month <= 12); - return Date::IsLeapYear(year) ? 
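// IsValid above combines the 1-based month/day ranges, the leap-year rule, and
// the DATE_MIN/DATE_MAX endpoints. A few concrete cases those checks imply
// (illustrative assertions, not taken from the deleted test suite):
#include <cassert>

static void DateValiditySketch() {
	assert(Date::IsValid(2000, 2, 29));  // divisible by 400 -> leap year
	assert(!Date::IsValid(1900, 2, 29)); // divisible by 100 but not 400 -> not leap
	assert(!Date::IsValid(2023, 4, 31)); // April: NORMAL_DAYS[4] == 30
	assert(!Date::IsValid(2023, 13, 1)); // month outside the 1..12 range
}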
Date::LEAP_DAYS[month] : Date::NORMAL_DAYS[month]; -} - -date_t Date::EpochDaysToDate(int32_t epoch) { - return (date_t)epoch; -} - -int32_t Date::EpochDays(date_t date) { - return date.days; -} - -date_t Date::EpochToDate(int64_t epoch) { - return date_t(UnsafeNumericCast(epoch / Interval::SECS_PER_DAY)); -} - -int64_t Date::Epoch(date_t date) { - return ((int64_t)date.days) * Interval::SECS_PER_DAY; -} - -int64_t Date::EpochNanoseconds(date_t date) { - int64_t result; - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY * 1000, - result)) { - throw ConversionException("Could not convert DATE (%s) to nanoseconds", Date::ToString(date)); - } - return result; -} - -int64_t Date::EpochMicroseconds(date_t date) { - int64_t result; - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result)) { - throw ConversionException("Could not convert DATE (%s) to microseconds", Date::ToString(date)); - } - return result; -} - -int64_t Date::EpochMilliseconds(date_t date) { - int64_t result; - const auto MILLIS_PER_DAY = Interval::MICROS_PER_DAY / Interval::MICROS_PER_MSEC; - if (!TryMultiplyOperator::Operation(date.days, MILLIS_PER_DAY, result)) { - throw ConversionException("Could not convert DATE (%s) to milliseconds", Date::ToString(date)); - } - return result; -} - -int32_t Date::ExtractYear(date_t d) { - int32_t year, year_offset; - Date::ExtractYearOffset(d.days, year, year_offset); - return year; -} - -int32_t Date::ExtractMonth(date_t date) { - int32_t out_year, out_month, out_day; - Date::Convert(date, out_year, out_month, out_day); - return out_month; -} - -int32_t Date::ExtractDay(date_t date) { - int32_t out_year, out_month, out_day; - Date::Convert(date, out_year, out_month, out_day); - return out_day; -} - -int32_t Date::ExtractDayOfTheYear(date_t date) { - int32_t year, year_offset; - Date::ExtractYearOffset(date.days, year, year_offset); - return date.days - Date::CUMULATIVE_YEAR_DAYS[year_offset] + 1; -} - -int64_t Date::ExtractJulianDay(date_t date) { - // Julian Day 0 is (-4713, 11, 24) in the proleptic Gregorian calendar. - static const int64_t JULIAN_EPOCH = -2440588; - return date.days - JULIAN_EPOCH; -} - -int32_t Date::ExtractISODayOfTheWeek(date_t date) { - // date of 0 is 1970-01-01, which was a Thursday (4) - // -7 = 4 - // -6 = 5 - // -5 = 6 - // -4 = 7 - // -3 = 1 - // -2 = 2 - // -1 = 3 - // 0 = 4 - // 1 = 5 - // 2 = 6 - // 3 = 7 - // 4 = 1 - // 5 = 2 - // 6 = 3 - // 7 = 4 - if (date.days < 0) { - // negative date: start off at 4 and cycle downwards - return UnsafeNumericCast((7 - ((-int64_t(date.days) + 3) % 7))); - } else { - // positive date: start off at 4 and cycle upwards - return UnsafeNumericCast(((int64_t(date.days) + 3) % 7) + 1); - } -} - -template -static T PythonDivMod(const T &x, const T &y, T &r) { - // D_ASSERT(y > 0); - T quo = x / y; - r = x - quo * y; - if (r < 0) { - --quo; - r += y; - } - // D_ASSERT(0 <= r && r < y); - return quo; -} - -static date_t GetISOWeekOne(int32_t year) { - const auto first_day = Date::FromDate(year, 1, 1); /* ord of 1/1 */ - /* 0 if 1/1 is a Monday, 1 if a Tue, etc. 
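// ExtractISODayOfTheWeek above exploits that epoch day 0 (1970-01-01) is a
// Thursday (ISO weekday 4), branching on the sign to keep the modulo positive.
// An equivalent single-expression sketch using a floored modulo (illustrative,
// not the deleted formulation):
#include <cstdint>

static int32_t ISOWeekdaySketch(int64_t epoch_days) {
	int64_t monday_based = ((epoch_days + 3) % 7 + 7) % 7; // 0 = Monday ... 6 = Sunday
	return static_cast<int32_t>(monday_based + 1);         // ISO: 1 = Monday ... 7 = Sunday
}
// ISOWeekdaySketch(0) == 4 (Thursday), ISOWeekdaySketch(-1) == 3 (Wednesday)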
*/ - const auto first_weekday = Date::ExtractISODayOfTheWeek(first_day) - 1; - /* ordinal of closest Monday at or before 1/1 */ - auto week1_monday = first_day - first_weekday; - - if (first_weekday > 3) { /* if 1/1 was Fri, Sat, Sun */ - week1_monday += 7; - } - - return week1_monday; -} - -static int32_t GetISOYearWeek(const date_t date, int32_t &year) { - int32_t month, day; - Date::Convert(date, year, month, day); - auto week1_monday = GetISOWeekOne(year); - auto week = PythonDivMod((date.days - week1_monday.days), 7, day); - if (week < 0) { - week1_monday = GetISOWeekOne(--year); - week = PythonDivMod((date.days - week1_monday.days), 7, day); - } else if (week >= 52 && date >= GetISOWeekOne(year + 1)) { - ++year; - week = 0; - } - - return week + 1; -} - -void Date::ExtractISOYearWeek(date_t date, int32_t &year, int32_t &week) { - week = GetISOYearWeek(date, year); -} - -int32_t Date::ExtractISOWeekNumber(date_t date) { - int32_t year, week; - ExtractISOYearWeek(date, year, week); - return week; -} - -int32_t Date::ExtractISOYearNumber(date_t date) { - int32_t year, week; - ExtractISOYearWeek(date, year, week); - return year; -} - -int32_t Date::ExtractWeekNumberRegular(date_t date, bool monday_first) { - int32_t year, month, day; - Date::Convert(date, year, month, day); - month -= 1; - day -= 1; - // get the day of the year - auto day_of_the_year = - (Date::IsLeapYear(year) ? Date::CUMULATIVE_LEAP_DAYS[month] : Date::CUMULATIVE_DAYS[month]) + day; - // now figure out the first monday or sunday of the year - // what day is January 1st? - auto day_of_jan_first = Date::ExtractISODayOfTheWeek(Date::FromDate(year, 1, 1)); - // monday = 1, sunday = 7 - int32_t first_week_start; - if (monday_first) { - // have to find next "1" - if (day_of_jan_first == 1) { - // jan 1 is monday: starts immediately - first_week_start = 0; - } else { - // jan 1 is not monday: count days until next monday - first_week_start = 8 - day_of_jan_first; - } - } else { - first_week_start = 7 - day_of_jan_first; - } - if (day_of_the_year < first_week_start) { - // day occurs before first week starts: week 0 - return 0; - } - return ((day_of_the_year - first_week_start) / 7) + 1; -} - -// Returns the date of the monday of the current week. 
-date_t Date::GetMondayOfCurrentWeek(date_t date) { - int32_t dotw = Date::ExtractISODayOfTheWeek(date); - return date - (dotw - 1); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/decimal.cpp b/src/duckdb/src/common/types/decimal.cpp deleted file mode 100644 index 5ecb39a0a..000000000 --- a/src/duckdb/src/common/types/decimal.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "duckdb/common/types/decimal.hpp" - -#include "duckdb/common/types/cast_helpers.hpp" - -namespace duckdb { - -template -string TemplatedDecimalToString(SIGNED value, uint8_t width, uint8_t scale) { - auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array_uninitialized(UnsafeNumericCast(len + 1)); - DecimalToString::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); - return string(data.get(), UnsafeNumericCast(len)); -} - -string Decimal::ToString(int16_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(int32_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(int64_t value, uint8_t width, uint8_t scale) { - return TemplatedDecimalToString(value, width, scale); -} - -string Decimal::ToString(hugeint_t value, uint8_t width, uint8_t scale) { - auto len = DecimalToString::DecimalLength(value, width, scale); - auto data = make_unsafe_uniq_array_uninitialized(UnsafeNumericCast(len + 1)); - DecimalToString::FormatDecimal(value, width, scale, data.get(), UnsafeNumericCast(len)); - return string(data.get(), UnsafeNumericCast(len)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/hash.cpp b/src/duckdb/src/common/types/hash.cpp deleted file mode 100644 index 83a1ef223..000000000 --- a/src/duckdb/src/common/types/hash.cpp +++ /dev/null @@ -1,147 +0,0 @@ -#include "duckdb/common/types/hash.hpp" - -#include "duckdb/common/helper.hpp" -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/uhugeint.hpp" - -#include -#include - -namespace duckdb { - -template <> -hash_t Hash(uint64_t val) { - return MurmurHash64(val); -} - -template <> -hash_t Hash(int64_t val) { - return MurmurHash64((uint64_t)val); -} - -template <> -hash_t Hash(hugeint_t val) { - return MurmurHash64(val.lower) ^ MurmurHash64(static_cast(val.upper)); -} - -template <> -hash_t Hash(uhugeint_t val) { - return MurmurHash64(val.lower) ^ MurmurHash64(val.upper); -} - -template -struct FloatingPointEqualityTransform { - static void OP(T &val) { - if (val == (T)0.0) { - // Turn negative zero into positive zero - val = (T)0.0; - } else if (std::isnan(val)) { - val = std::numeric_limits::quiet_NaN(); - } - } -}; - -template <> -hash_t Hash(float val) { - static_assert(sizeof(float) == sizeof(uint32_t), ""); - FloatingPointEqualityTransform::OP(val); - uint32_t uval = Load(const_data_ptr_cast(&val)); - return MurmurHash64(uval); -} - -template <> -hash_t Hash(double val) { - static_assert(sizeof(double) == sizeof(uint64_t), ""); - FloatingPointEqualityTransform::OP(val); - uint64_t uval = Load(const_data_ptr_cast(&val)); - return MurmurHash64(uval); -} - -template <> -hash_t Hash(interval_t val) { - int64_t months, days, micros; - val.Normalize(months, days, micros); - return Hash(days) ^ Hash(months) ^ Hash(micros); -} - -template <> -hash_t Hash(const char *str) { - return Hash(str, strlen(str)); -} - -template <> -hash_t Hash(string_t val) { - return 
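// Hash(float)/Hash(double) above canonicalize the bit pattern before hashing
// so that values comparing equal also hash equal: -0.0 collapses to +0.0 and
// every NaN to one quiet NaN. The same idea as a standalone sketch (the final
// multiplier is a stand-in mixer; the deleted code feeds the bits to
// MurmurHash64):
#include <cmath>
#include <cstdint>
#include <cstring>
#include <limits>

static uint64_t HashDoubleSketch(double val) {
	if (val == 0.0) {
		val = 0.0; // +0.0 and -0.0 compare equal, so force one bit pattern
	} else if (std::isnan(val)) {
		val = std::numeric_limits<double>::quiet_NaN(); // one canonical NaN
	}
	uint64_t bits;
	std::memcpy(&bits, &val, sizeof(bits)); // portable bit load
	return bits * 0xc6a4a7935bd1e995ULL;
}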
Hash(val.GetData(), val.GetSize()); -} - -template <> -hash_t Hash(char *val) { - return Hash(val); -} - -// MIT License -// Copyright (c) 2018-2021 Martin Ankerl -// https://github.com/martinus/robin-hood-hashing/blob/3.11.5/LICENSE -hash_t HashBytes(void *ptr, size_t len) noexcept { - static constexpr uint64_t M = UINT64_C(0xc6a4a7935bd1e995); - static constexpr uint64_t SEED = UINT64_C(0xe17a1465); - static constexpr unsigned int R = 47; - - auto const *const data64 = static_cast(ptr); - uint64_t h = SEED ^ (len * M); - - size_t const n_blocks = len / 8; - for (size_t i = 0; i < n_blocks; ++i) { - auto k = Load(reinterpret_cast(data64 + i)); - - k *= M; - k ^= k >> R; - k *= M; - - h ^= k; - h *= M; - } - - auto const *const data8 = reinterpret_cast(data64 + n_blocks); - switch (len & 7U) { - case 7: - h ^= static_cast(data8[6]) << 48U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 6: - h ^= static_cast(data8[5]) << 40U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 5: - h ^= static_cast(data8[4]) << 32U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 4: - h ^= static_cast(data8[3]) << 24U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 3: - h ^= static_cast(data8[2]) << 16U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 2: - h ^= static_cast(data8[1]) << 8U; - DUCKDB_EXPLICIT_FALLTHROUGH; - case 1: - h ^= static_cast(data8[0]); - h *= M; - DUCKDB_EXPLICIT_FALLTHROUGH; - default: - break; - } - h ^= h >> R; - h *= M; - h ^= h >> R; - return static_cast(h); -} - -hash_t Hash(const char *val, size_t size) { - return HashBytes((void *)val, size); -} - -hash_t Hash(uint8_t *val, size_t size) { - return HashBytes((void *)val, size); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/hugeint.cpp b/src/duckdb/src/common/types/hugeint.cpp deleted file mode 100644 index b19a03b70..000000000 --- a/src/duckdb/src/common/types/hugeint.cpp +++ /dev/null @@ -1,998 +0,0 @@ -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/windows_undefs.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// String Conversion -//===--------------------------------------------------------------------===// -const hugeint_t Hugeint::POWERS_OF_TEN[] { - hugeint_t(1), - hugeint_t(10), - hugeint_t(100), - hugeint_t(1000), - hugeint_t(10000), - hugeint_t(100000), - hugeint_t(1000000), - hugeint_t(10000000), - hugeint_t(100000000), - hugeint_t(1000000000), - hugeint_t(10000000000), - hugeint_t(100000000000), - hugeint_t(1000000000000), - hugeint_t(10000000000000), - hugeint_t(100000000000000), - hugeint_t(1000000000000000), - hugeint_t(10000000000000000), - hugeint_t(100000000000000000), - hugeint_t(1000000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10), - hugeint_t(1000000000000000000) * hugeint_t(100), - hugeint_t(1000000000000000000) * hugeint_t(1000), - hugeint_t(1000000000000000000) * hugeint_t(10000), - hugeint_t(1000000000000000000) * hugeint_t(100000), - hugeint_t(1000000000000000000) * hugeint_t(1000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000), - 
hugeint_t(1000000000000000000) * hugeint_t(10000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(10000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(100000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(10), - hugeint_t(1000000000000000000) * hugeint_t(1000000000000000000) * hugeint_t(100)}; - -//===--------------------------------------------------------------------===// -// Negate -//===--------------------------------------------------------------------===// - -template <> -void Hugeint::NegateInPlace(hugeint_t &input) { - input.lower = NumericLimits::Maximum() - input.lower + 1ull; - input.upper = -1 - input.upper + (input.lower == 0); -} - -bool Hugeint::TryNegate(hugeint_t input, hugeint_t &result) { - if (input.upper == NumericLimits::Minimum() && input.lower == 0) { - return false; - } - NegateInPlace(input); - result = input; - return true; -} - -hugeint_t Hugeint::Abs(hugeint_t n) { - if (n < 0) { - return Hugeint::Negate(n); - } else { - return n; - } -} - -//===--------------------------------------------------------------------===// -// Divide -//===--------------------------------------------------------------------===// - -static uint8_t PositiveHugeintHighestBit(hugeint_t bits) { - uint8_t out = 0; - if (bits.upper) { - out = 64; - uint64_t up = static_cast(bits.upper); - while (up) { - up >>= 1; - out++; - } - } else { - uint64_t low = bits.lower; - while (low) { - low >>= 1; - out++; - } - } - return out; -} - -static bool PositiveHugeintIsBitSet(hugeint_t lhs, uint8_t bit_position) { - if (bit_position < 64) { - return lhs.lower & (uint64_t(1) << uint64_t(bit_position)); - } else { - return static_cast(lhs.upper) & (uint64_t(1) << uint64_t(bit_position - 64)); - } -} - -static hugeint_t PositiveHugeintLeftShift(hugeint_t lhs, uint32_t amount) { - D_ASSERT(amount > 0 && amount < 64); - hugeint_t result; - result.lower = lhs.lower << amount; - result.upper = - UnsafeNumericCast((UnsafeNumericCast(lhs.upper) << amount) + (lhs.lower >> (64 - amount))); - return result; -} - -hugeint_t Hugeint::DivModPositive(hugeint_t lhs, uint64_t rhs, uint64_t &remainder) { - D_ASSERT(lhs.upper >= 0); - // DivMod code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // initialize the result and remainder to 0 - hugeint_t div_result; - div_result.lower = 0; - div_result.upper = 0; - remainder = 0; - - uint8_t highest_bit_set = PositiveHugeintHighestBit(lhs); - // now iterate over the amount of bits that are set in the LHS - for (uint8_t x = highest_bit_set; x > 0; x--) { - // left-shift the current result and remainder by 1 - div_result = PositiveHugeintLeftShift(div_result, 1); - remainder <<= 1; - // we get the value of the bit at position X, where position 0 is the least-significant bit - if (PositiveHugeintIsBitSet(lhs, x - 1)) { - // increment the remainder - remainder++; - } - if (remainder >= rhs) { - // the remainder has passed the division multiplier: add one to the divide result - remainder -= rhs; - div_result.lower++; - if (div_result.lower == 0) { - // overflow - div_result.upper++; - } - } - } - return 
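// DivModPositive above is textbook binary long division: walk from the highest
// set bit of the dividend down, shifting a running remainder left one bit at a
// time and subtracting the divisor whenever it fits. The same algorithm on
// plain uint64_t for clarity (requires rhs != 0):
#include <cstdint>

static uint64_t BinaryDivModSketch(uint64_t lhs, uint64_t rhs, uint64_t &remainder) {
	uint64_t quotient = 0;
	remainder = 0;
	for (int bit = 63; bit >= 0; bit--) {
		remainder = (remainder << 1) | ((lhs >> bit) & 1); // bring down the next bit
		quotient <<= 1;
		if (remainder >= rhs) { // divisor fits: subtract, set the quotient bit
			remainder -= rhs;
			quotient |= 1;
		}
	}
	return quotient;
}
// BinaryDivModSketch(1000, 7, r) == 142 with r == 6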
div_result; -} - -string Hugeint::ToString(hugeint_t input) { - uint64_t remainder; - string result; - if (input == NumericLimits::Minimum()) { - return string(Hugeint::HUGEINT_MINIMUM_STRING); - } - bool negative = input.upper < 0; - if (negative) { - NegateInPlace(input); - } - while (true) { - if (!input.lower && !input.upper) { - break; - } - input = Hugeint::DivModPositive(input, 10, remainder); - result = string(1, UnsafeNumericCast('0' + remainder)) + result; // NOLINT - } - if (result.empty()) { - // value is zero - return "0"; - } - return negative ? "-" + result : result; -} - -//===--------------------------------------------------------------------===// -// Multiply -//===--------------------------------------------------------------------===// - -// Multiply with overflow checks -bool Hugeint::TryMultiply(hugeint_t lhs, hugeint_t rhs, hugeint_t &result) { - // Check if one of the sides is hugeint_t minimum, as that can't be negated. - // You can only multiply the minimum by 0 or 1, any other value will result in overflow - if (lhs == NumericLimits::Minimum() || rhs == NumericLimits::Minimum()) { - if (lhs == 0 || rhs == 0) { - result = 0; - return true; - } - if (lhs == 1 || rhs == 1) { - result = NumericLimits::Minimum(); - return true; - } - return false; - } - - bool lhs_negative = lhs.upper < 0; - bool rhs_negative = rhs.upper < 0; - if (lhs_negative && !TryNegate(lhs, lhs)) { - return false; - } - if (rhs_negative && !TryNegate(rhs, rhs)) { - return false; - } - -#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) - __uint128_t left = __uint128_t(lhs.lower) + (__uint128_t(lhs.upper) << 64); - __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); - __uint128_t result_i128; - if (__builtin_mul_overflow(left, right, &result_i128)) { - return false; - } - uint64_t upper = uint64_t(result_i128 >> 64); - if (upper & 0x8000000000000000) { - return false; - } - result.upper = int64_t(upper); - result.lower = uint64_t(result_i128 & 0xffffffffffffffff); -#else - // Multiply code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // split values into 4 32-bit parts - uint64_t top[4] = {uint64_t(lhs.upper) >> 32, uint64_t(lhs.upper) & 0xffffffff, lhs.lower >> 32, - lhs.lower & 0xffffffff}; - uint64_t bottom[4] = {uint64_t(rhs.upper) >> 32, uint64_t(rhs.upper) & 0xffffffff, rhs.lower >> 32, - rhs.lower & 0xffffffff}; - uint64_t products[4][4]; - - // multiply each component of the values - for (auto x = 0; x < 4; x++) { - for (auto y = 0; y < 4; y++) { - products[x][y] = top[x] * bottom[y]; - } - } - - // if any of these products are set to a non-zero value, there is always an overflow - if (products[0][0] || products[0][1] || products[0][2] || products[1][0] || products[2][0] || products[1][1]) { - return false; - } - // if the high bits of any of these are set, there is always an overflow - if ((products[0][3] & 0xffffffff80000000) || (products[1][2] & 0xffffffff80000000) || - (products[2][1] & 0xffffffff80000000) || (products[3][0] & 0xffffffff80000000)) { - return false; - } - - // otherwise we merge the result of the different products together in-order - - // first row - uint64_t fourth32 = (products[3][3] & 0xffffffff); - uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); - uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); - uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); - - // second row - third32 += (products[2][3] & 
0xffffffff); - second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); - first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); - - // third row - second32 += (products[1][3] & 0xffffffff); - first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); - - // fourth row - first32 += (products[0][3] & 0xffffffff); - - // move carry to next digit - third32 += fourth32 >> 32; - second32 += third32 >> 32; - first32 += second32 >> 32; - - // check if the combination of the different products resulted in an overflow - if (first32 & 0xffffff80000000) { - return false; - } - - // remove carry from current digit - fourth32 &= 0xffffffff; - third32 &= 0xffffffff; - second32 &= 0xffffffff; - first32 &= 0xffffffff; - - // combine components - result.lower = (third32 << 32) | fourth32; - result.upper = (first32 << 32) | second32; -#endif - if (lhs_negative ^ rhs_negative) { - NegateInPlace(result); - } - return true; -} - -// Multiply without overflow check -template <> -hugeint_t Hugeint::Multiply(hugeint_t lhs, hugeint_t rhs) { - hugeint_t result; - bool lhs_negative = lhs.upper < 0; - bool rhs_negative = rhs.upper < 0; - if (lhs_negative) { - NegateInPlace(lhs); - } - if (rhs_negative) { - NegateInPlace(rhs); - } - -#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) - __uint128_t left = __uint128_t(lhs.lower) + (__uint128_t(lhs.upper) << 64); - __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); - __uint128_t result_i128; - result_i128 = left * right; - uint64_t upper = uint64_t(result_i128 >> 64); - result.upper = int64_t(upper); - result.lower = uint64_t(result_i128 & 0xffffffffffffffff); -#else - // Multiply code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // split values into 4 32-bit parts - uint64_t top[4] = {uint64_t(lhs.upper) >> 32, uint64_t(lhs.upper) & 0xffffffff, lhs.lower >> 32, - lhs.lower & 0xffffffff}; - uint64_t bottom[4] = {uint64_t(rhs.upper) >> 32, uint64_t(rhs.upper) & 0xffffffff, rhs.lower >> 32, - rhs.lower & 0xffffffff}; - uint64_t products[4][4]; - - // multiply each component of the values - for (auto x = 0; x < 4; x++) { - for (auto y = 0; y < 4; y++) { - products[x][y] = top[x] * bottom[y]; - } - } - - // first row - uint64_t fourth32 = (products[3][3] & 0xffffffff); - uint64_t third32 = (products[3][2] & 0xffffffff) + (products[3][3] >> 32); - uint64_t second32 = (products[3][1] & 0xffffffff) + (products[3][2] >> 32); - uint64_t first32 = (products[3][0] & 0xffffffff) + (products[3][1] >> 32); - - // second row - third32 += (products[2][3] & 0xffffffff); - second32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); - first32 += (products[2][1] & 0xffffffff) + (products[2][2] >> 32); - - // third row - second32 += (products[1][3] & 0xffffffff); - first32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); - - // fourth row - first32 += (products[0][3] & 0xffffffff); - - // move carry to next digit - third32 += fourth32 >> 32; - second32 += third32 >> 32; - first32 += second32 >> 32; - - // remove carry from current digit - fourth32 &= 0xffffffff; - third32 &= 0xffffffff; - second32 &= 0xffffffff; - first32 &= 0xffffffff; - - // combine components - result.lower = (third32 << 32) | fourth32; - result.upper = (first32 << 32) | second32; -#endif - if (lhs_negative ^ rhs_negative) { - NegateInPlace(result); - } - return result; -} - -//===--------------------------------------------------------------------===// -// Divide 
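// The fallback multiply above splits each 128-bit operand into four 32-bit
// limbs, forms the partial products, and merges them with carries (the checked
// variant additionally rejects any non-zero high limbs). The same limb scheme
// for a 64 x 64 -> 128-bit product, as a compact sketch:
#include <cstdint>

static void Mul64To128Sketch(uint64_t a, uint64_t b, uint64_t &hi, uint64_t &lo) {
	uint64_t a_lo = a & 0xffffffff, a_hi = a >> 32;
	uint64_t b_lo = b & 0xffffffff, b_hi = b >> 32;
	uint64_t p0 = a_lo * b_lo; // each limb product fits in 64 bits
	uint64_t p1 = a_lo * b_hi;
	uint64_t p2 = a_hi * b_lo;
	uint64_t p3 = a_hi * b_hi;
	uint64_t mid = (p0 >> 32) + (p1 & 0xffffffff) + (p2 & 0xffffffff);
	lo = (mid << 32) | (p0 & 0xffffffff);
	hi = p3 + (p1 >> 32) + (p2 >> 32) + (mid >> 32);
}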
-//===--------------------------------------------------------------------===// - -int Sign(hugeint_t n) { - return ((n > 0) - (n < 0)); -} - -hugeint_t Abs(hugeint_t n) { - D_ASSERT(n != NumericLimits::Minimum()); - return (n * Sign(n)); -} - -static hugeint_t DivModMinimum(hugeint_t lhs, hugeint_t rhs, hugeint_t &remainder) { - D_ASSERT(lhs == NumericLimits::Minimum() || rhs == NumericLimits::Minimum()); - if (rhs == NumericLimits::Minimum()) { - if (lhs == NumericLimits::Minimum()) { - remainder = 0; - return 1; - } - remainder = lhs; - return 0; - } - - // Add 1 to minimum and run through DivMod again - hugeint_t result = Hugeint::DivMod(NumericLimits::Minimum() + 1, rhs, remainder); - - // If the 1 mattered we need to adjust the result, otherwise the remainder - if (Abs(remainder) + 1 == Abs(rhs)) { - result -= Sign(rhs); - remainder = 0; - } else { - remainder -= 1; - } - return result; -} - -// No overflow checks -hugeint_t Hugeint::DivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &remainder) { - if (rhs == 0) { - remainder = lhs; - return hugeint_t(0); - } - - // Check if one of the sides is hugeint_t minimum, as that can't be negated. - if (lhs == NumericLimits::Minimum() || rhs == NumericLimits::Minimum()) { - return DivModMinimum(lhs, rhs, remainder); - } - - bool lhs_negative = lhs.upper < 0; - bool rhs_negative = rhs.upper < 0; - if (lhs_negative) { - Hugeint::NegateInPlace(lhs); - } - if (rhs_negative) { - Hugeint::NegateInPlace(rhs); - } - // DivMod code adapted from: - // https://github.com/calccrypto/uint128_t/blob/master/uint128_t.cpp - - // initialize the result and remainder to 0 - hugeint_t div_result; - div_result.lower = 0; - div_result.upper = 0; - remainder.lower = 0; - remainder.upper = 0; - - uint8_t highest_bit_set = PositiveHugeintHighestBit(lhs); - // now iterate over the amount of bits that are set in the LHS - for (uint8_t x = highest_bit_set; x > 0; x--) { - // left-shift the current result and remainder by 1 - div_result = PositiveHugeintLeftShift(div_result, 1); - remainder = PositiveHugeintLeftShift(remainder, 1); - - // we get the value of the bit at position X, where position 0 is the least-significant bit - if (PositiveHugeintIsBitSet(lhs, x - 1)) { - remainder += 1; - } - if (Hugeint::GreaterThanEquals(remainder, rhs)) { - // the remainder has passed the division multiplier: add one to the divide result - remainder -= rhs; - div_result += 1; - } - } - if (lhs_negative ^ rhs_negative) { - Hugeint::NegateInPlace(div_result); - } - if (lhs_negative) { - Hugeint::NegateInPlace(remainder); - } - return div_result; -} - -bool Hugeint::TryDivMod(hugeint_t lhs, hugeint_t rhs, hugeint_t &result, hugeint_t &remainder) { - // No division by zero - if (rhs == 0) { - return false; - } - - // division only has one reason to overflow: MINIMUM / -1 - if (lhs == NumericLimits::Minimum() && rhs == -1) { - return false; - } - - result = Hugeint::DivMod(lhs, rhs, remainder); - return true; -} - -template <> -hugeint_t Hugeint::Divide(hugeint_t lhs, hugeint_t rhs) { - hugeint_t remainder; - return Hugeint::DivMod(lhs, rhs, remainder); -} - -template <> -hugeint_t Hugeint::Modulo(hugeint_t lhs, hugeint_t rhs) { - hugeint_t remainder; - (void)Hugeint::DivMod(lhs, rhs, remainder); - return remainder; -} - -//===--------------------------------------------------------------------===// -// Add/Subtract -//===--------------------------------------------------------------------===// -bool Hugeint::TryAddInPlace(hugeint_t &lhs, hugeint_t rhs) { - int overflow = lhs.lower + 
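// The carry test used by TryAddInPlace here is the classic unsigned wraparound
// check: for unsigned a and b, (a + b) < a holds exactly when the addition
// overflowed the word. Minimal demonstration:
#include <cstdint>

static bool AddWithCarrySketch(uint64_t a, uint64_t b, uint64_t &sum) {
	sum = a + b;    // unsigned overflow wraps, which is well-defined
	return sum < a; // true iff a carry out of bit 63 occurred
}
// AddWithCarrySketch(UINT64_MAX, 1, s) == true with s == 0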
rhs.lower < lhs.lower; - if (rhs.upper >= 0) { - // RHS is positive: check for overflow - if (lhs.upper > (std::numeric_limits::max() - rhs.upper - overflow)) { - return false; - } - lhs.upper = lhs.upper + overflow + rhs.upper; - } else { - // RHS is negative: check for underflow - if (lhs.upper < std::numeric_limits::min() - rhs.upper - overflow) { - return false; - } - lhs.upper = lhs.upper + (overflow + rhs.upper); - } - lhs.lower += rhs.lower; - return true; -} - -bool Hugeint::TrySubtractInPlace(hugeint_t &lhs, hugeint_t rhs) { - // underflow - int underflow = lhs.lower - rhs.lower > lhs.lower; - if (rhs.upper >= 0) { - // RHS is positive: check for underflow - if (lhs.upper < (std::numeric_limits::min() + rhs.upper + underflow)) { - return false; - } - lhs.upper = (lhs.upper - rhs.upper) - underflow; - } else { - // RHS is negative: check for overflow - if (lhs.upper > std::numeric_limits::min() && - lhs.upper - 1 >= (std::numeric_limits::max() + rhs.upper + underflow)) { - return false; - } - lhs.upper = lhs.upper - (rhs.upper + underflow); - } - lhs.lower -= rhs.lower; - return true; -} - -template <> -hugeint_t Hugeint::Add(hugeint_t lhs, hugeint_t rhs) { - return lhs + rhs; -} - -template <> -hugeint_t Hugeint::Subtract(hugeint_t lhs, hugeint_t rhs) { - return lhs - rhs; -} - -//===--------------------------------------------------------------------===// -// Hugeint Cast/Conversion -//===--------------------------------------------------------------------===// -template -bool HugeintTryCastInteger(hugeint_t input, DST &result) { - switch (input.upper) { - case 0: - // positive number: check if the positive number is in range - if (input.lower <= uint64_t(NumericLimits::Maximum())) { - result = DST(input.lower); - return true; - } - break; - case -1: - if (!SIGNED) { - return false; - } - // negative number: check if the negative number is in range - if (input.lower >= NumericLimits::Maximum() - uint64_t(NumericLimits::Maximum())) { - result = -DST(NumericLimits::Maximum() - input.lower) - 1; - return true; - } - break; - default: - break; - } - return false; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int8_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int16_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int32_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, int64_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint8_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint16_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint32_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uint64_t &result) { - return HugeintTryCastInteger(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, hugeint_t &result) { - result = input; - return true; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, uhugeint_t &result) { - if (input < 0) { - return false; - } - - result.lower = input.lower; - result.upper = UnsafeNumericCast(input.upper); - return true; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, float &result) { - double dbl_result; - Hugeint::TryCast(input, 
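// HugeintTryCastInteger above needs only two upper-word cases because any
// value of a <= 64-bit target type, sign-extended to 128 bits, has upper == 0
// (non-negative) or upper == -1 (negative). Illustrative expectations for the
// TryCast specializations (hypothetical driver, not deleted code):
#include <cstdint>

static void NarrowingCastSketch() {
	int32_t out;
	bool ok1 = duckdb::Hugeint::TryCast(duckdb::hugeint_t(42), out); // true, out == 42
	bool ok2 = duckdb::Hugeint::TryCast(duckdb::hugeint_t(-1), out); // true, out == -1
	bool ok3 = duckdb::Hugeint::TryCast(duckdb::hugeint_t(1) << duckdb::hugeint_t(40), out); // false: 2^40 exceeds int32_t
	(void)ok1; (void)ok2; (void)ok3;
}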
dbl_result); - result = (float)dbl_result; - return true; -} - -template -bool CastBigintToFloating(hugeint_t input, REAL_T &result) { - switch (input.upper) { - case -1: - // special case for upper = -1 to avoid rounding issues in small negative numbers - result = -REAL_T(NumericLimits::Maximum() - input.lower) - 1; - break; - default: - result = REAL_T(input.lower) + REAL_T(input.upper) * (REAL_T(NumericLimits::Maximum()) + 1); - break; - } - return true; -} - -template <> -bool Hugeint::TryCast(hugeint_t input, double &result) { - return CastBigintToFloating(input, result); -} - -template <> -bool Hugeint::TryCast(hugeint_t input, long double &result) { - return CastBigintToFloating(input, result); -} - -template -hugeint_t HugeintConvertInteger(DST input) { - hugeint_t result; - result.lower = (uint64_t)input; - result.upper = (input < 0) * -1; - return result; -} - -template <> -bool Hugeint::TryConvert(int8_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(const char *value, hugeint_t &result) { - auto len = strlen(value); - string_t string_val(value, UnsafeNumericCast(len)); - return TryCast::Operation(string_val, result, true); -} - -template <> -bool Hugeint::TryConvert(int16_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(int32_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(int64_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint8_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint16_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint32_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} -template <> -bool Hugeint::TryConvert(uint64_t value, hugeint_t &result) { - result = HugeintConvertInteger(value); - return true; -} - -template <> -bool Hugeint::TryConvert(hugeint_t value, hugeint_t &result) { - result = value; - return true; -} - -template <> -bool Hugeint::TryConvert(float value, hugeint_t &result) { - return Hugeint::TryConvert(double(value), result); -} - -template -bool ConvertFloatingToBigint(REAL_T value, hugeint_t &result) { - if (!Value::IsFinite(value)) { - return false; - } - if (value <= -170141183460469231731687303715884105728.0 || value >= 170141183460469231731687303715884105727.0) { - return false; - } - bool negative = value < 0; - if (negative) { - value = -value; - } - result.lower = (uint64_t)fmod(value, REAL_T(NumericLimits::Maximum())); - result.upper = (int64_t)(value / REAL_T(NumericLimits::Maximum())); - if (negative) { - Hugeint::NegateInPlace(result); - } - return true; -} - -template <> -bool Hugeint::TryConvert(double value, hugeint_t &result) { - return ConvertFloatingToBigint(value, result); -} - -template <> -bool Hugeint::TryConvert(long double value, hugeint_t &result) { - return ConvertFloatingToBigint(value, result); -} - -//===--------------------------------------------------------------------===// -// hugeint_t operators -//===--------------------------------------------------------------------===// -hugeint_t::hugeint_t(int64_t value) { - auto result = Hugeint::Convert(value); - this->lower = result.lower; - this->upper = 
result.upper; -} - -bool hugeint_t::operator==(const hugeint_t &rhs) const { - return Hugeint::Equals(*this, rhs); -} - -bool hugeint_t::operator!=(const hugeint_t &rhs) const { - return Hugeint::NotEquals(*this, rhs); -} - -bool hugeint_t::operator<(const hugeint_t &rhs) const { - return Hugeint::LessThan(*this, rhs); -} - -bool hugeint_t::operator<=(const hugeint_t &rhs) const { - return Hugeint::LessThanEquals(*this, rhs); -} - -bool hugeint_t::operator>(const hugeint_t &rhs) const { - return Hugeint::GreaterThan(*this, rhs); -} - -bool hugeint_t::operator>=(const hugeint_t &rhs) const { - return Hugeint::GreaterThanEquals(*this, rhs); -} - -hugeint_t hugeint_t::operator+(const hugeint_t &rhs) const { - return hugeint_t(upper + rhs.upper + ((lower + rhs.lower) < lower), lower + rhs.lower); -} - -hugeint_t hugeint_t::operator-(const hugeint_t &rhs) const { - return hugeint_t(upper - rhs.upper - ((lower - rhs.lower) > lower), lower - rhs.lower); -} - -hugeint_t hugeint_t::operator*(const hugeint_t &rhs) const { - hugeint_t result = *this; - result *= rhs; - return result; -} - -hugeint_t hugeint_t::operator/(const hugeint_t &rhs) const { - return Hugeint::Divide(*this, rhs); -} - -hugeint_t hugeint_t::operator%(const hugeint_t &rhs) const { - return Hugeint::Modulo(*this, rhs); -} - -hugeint_t hugeint_t::operator-() const { - return Hugeint::Negate(*this); -} - -hugeint_t hugeint_t::operator>>(const hugeint_t &rhs) const { - hugeint_t result; - uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return hugeint_t(0); - } else if (shift == 0) { - return *this; - } else if (shift == 64) { - result.upper = (upper < 0) ? -1 : 0; - result.lower = uint64_t(upper); - } else if (shift < 64) { - // perform lower shift in unsigned integer, and mask away the most significant bit - result.lower = (uint64_t(upper) << (64 - shift)) | (lower >> shift); - result.upper = upper >> shift; - } else { - D_ASSERT(shift < 128); - result.lower = uint64_t(upper >> (shift - 64)); - result.upper = (upper < 0) ? 
-1 : 0; - } - return result; -} - -hugeint_t hugeint_t::operator<<(const hugeint_t &rhs) const { - if (upper < 0) { - return hugeint_t(0); - } - hugeint_t result; - uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return hugeint_t(0); - } else if (shift == 64) { - result.upper = int64_t(lower); - result.lower = 0; - } else if (shift == 0) { - return *this; - } else if (shift < 64) { - // perform upper shift in unsigned integer, and mask away the most significant bit - uint64_t upper_shift = ((uint64_t(upper) << shift) + (lower >> (64 - shift))) & 0x7FFFFFFFFFFFFFFF; - result.lower = lower << shift; - result.upper = int64_t(upper_shift); - } else { - D_ASSERT(shift < 128); - result.lower = 0; - result.upper = UnsafeNumericCast((lower << (shift - 64)) & 0x7FFFFFFFFFFFFFFF); - } - return result; -} - -hugeint_t hugeint_t::operator&(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower & rhs.lower; - result.upper = upper & rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator|(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower | rhs.lower; - result.upper = upper | rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator^(const hugeint_t &rhs) const { - hugeint_t result; - result.lower = lower ^ rhs.lower; - result.upper = upper ^ rhs.upper; - return result; -} - -hugeint_t hugeint_t::operator~() const { - hugeint_t result; - result.lower = ~lower; - result.upper = ~upper; - return result; -} - -hugeint_t &hugeint_t::operator+=(const hugeint_t &rhs) { - *this = *this + rhs; - return *this; -} -hugeint_t &hugeint_t::operator-=(const hugeint_t &rhs) { - *this = *this - rhs; - return *this; -} -hugeint_t &hugeint_t::operator*=(const hugeint_t &rhs) { - *this = Hugeint::Multiply(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator/=(const hugeint_t &rhs) { - *this = Hugeint::Divide(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator%=(const hugeint_t &rhs) { - *this = Hugeint::Modulo(*this, rhs); - return *this; -} -hugeint_t &hugeint_t::operator>>=(const hugeint_t &rhs) { - *this = *this >> rhs; - return *this; -} -hugeint_t &hugeint_t::operator<<=(const hugeint_t &rhs) { - *this = *this << rhs; - return *this; -} -hugeint_t &hugeint_t::operator&=(const hugeint_t &rhs) { - lower &= rhs.lower; - upper &= rhs.upper; - return *this; -} -hugeint_t &hugeint_t::operator|=(const hugeint_t &rhs) { - lower |= rhs.lower; - upper |= rhs.upper; - return *this; -} -hugeint_t &hugeint_t::operator^=(const hugeint_t &rhs) { - lower ^= rhs.lower; - upper ^= rhs.upper; - return *this; -} - -bool hugeint_t::operator!() const { - return *this == 0; -} - -hugeint_t::operator bool() const { - return *this != 0; -} - -template -static T NarrowCast(const hugeint_t &input) { - // NarrowCast is supposed to truncate (take lower) - return static_cast(input.lower); -} - -hugeint_t::operator uint8_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint16_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint32_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uint64_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int8_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int16_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int32_t() const { - return NarrowCast(*this); -} -hugeint_t::operator int64_t() const { - return NarrowCast(*this); -} -hugeint_t::operator uhugeint_t() const { - return {static_cast(this->upper), this->lower}; -} - -string 
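// The shift operators above handle the regimes explicitly — shift == 0,
// 0 < shift < 64, shift == 64, 64 < shift < 128, and shift >= 128 — so that
// oversized shifts yield 0 instead of the undefined behaviour of builtin
// shifts, and operator>> is arithmetic (sign-replicating). Illustrative
// expectations under those semantics:
#include <cassert>

static void ShiftSketch() {
	using duckdb::hugeint_t;
	assert((hugeint_t(256) >> hugeint_t(4)) == hugeint_t(16));
	assert(((hugeint_t(1) << hugeint_t(100)) >> hugeint_t(100)) == hugeint_t(1)); // round-trips via the upper word
	assert((hugeint_t(-16) >> hugeint_t(2)) == hugeint_t(-4));                    // sign bit replicated
	assert((hugeint_t(1) >> hugeint_t(200)) == hugeint_t(0));                     // oversized shift saturates to 0
}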
hugeint_t::ToString() const { - return Hugeint::ToString(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/hyperloglog.cpp b/src/duckdb/src/common/types/hyperloglog.cpp deleted file mode 100644 index 3ccd1f0d9..000000000 --- a/src/duckdb/src/common/types/hyperloglog.cpp +++ /dev/null @@ -1,270 +0,0 @@ -#include "duckdb/common/types/hyperloglog.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "hyperloglog.hpp" - -#include - -namespace duckdb_hll { -struct robj; // NOLINT -} - -namespace duckdb { - -idx_t HyperLogLog::Count() const { - uint32_t c[Q + 2] = {0}; - ExtractCounts(c); - return static_cast(EstimateCardinality(c)); -} - -//! Algorithm 2 -void HyperLogLog::Merge(const HyperLogLog &other) { - for (idx_t i = 0; i < M; ++i) { - Update(i, other.k[i]); - } -} - -//! Algorithm 4 -void HyperLogLog::ExtractCounts(uint32_t *c) const { - for (idx_t i = 0; i < M; ++i) { - c[k[i]]++; - } -} - -//! Taken from redis code -static double HLLSigma(double x) { - if (x == 1.) { - return std::numeric_limits::infinity(); - } - double z_prime; - double y = 1; - double z = x; - do { - x *= x; - z_prime = z; - z += x * y; - y += y; - } while (z_prime != z); - return z; -} - -//! Taken from redis code -static double HLLTau(double x) { - if (x == 0. || x == 1.) { - return 0.; - } - double z_prime; - double y = 1.0; - double z = 1 - x; - do { - x = sqrt(x); - z_prime = z; - y *= 0.5; - z -= pow(1 - x, 2) * y; - } while (z_prime != z); - return z / 3; -} - -//! Algorithm 6 -int64_t HyperLogLog::EstimateCardinality(uint32_t *c) { - auto z = M * HLLTau((double(M) - c[Q]) / double(M)); - - for (idx_t k = Q; k >= 1; --k) { - z += c[k]; - z *= 0.5; - } - - z += M * HLLSigma(c[0] / double(M)); - - return llroundl(ALPHA * M * M / z); -} - -void HyperLogLog::Update(Vector &input, Vector &hash_vec, const idx_t count) { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - UnifiedVectorFormat hdata; - hash_vec.ToUnifiedFormat(count, hdata); - const auto hashes = UnifiedVectorFormat::GetData(hdata); - - if (hash_vec.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (idata.validity.RowIsValid(0)) { - InsertElement(hashes[0]); - } - } else { - D_ASSERT(hash_vec.GetVectorType() == VectorType::FLAT_VECTOR); - if (idata.validity.AllValid()) { - for (idx_t i = 0; i < count; ++i) { - const auto hash = hashes[i]; - InsertElement(hash); - } - } else { - for (idx_t i = 0; i < count; ++i) { - if (idata.validity.RowIsValid(idata.sel->get_index(i))) { - const auto hash = hashes[i]; - InsertElement(hash); - } - } - } - } -} - -unique_ptr HyperLogLog::Copy() const { - auto result = make_uniq(); - memcpy(result->k, this->k, sizeof(k)); - D_ASSERT(result->Count() == Count()); - return result; -} - -class HLLV1 { -public: - HLLV1() { - hll = duckdb_hll::hll_create(); - duckdb_hll::hllSparseToDense(hll); - } - - ~HLLV1() { - duckdb_hll::hll_destroy(hll); - } - -public: - static idx_t GetSize() { - return duckdb_hll::get_size(); - } - - data_ptr_t GetPtr() const { - return data_ptr_cast((hll)->ptr); - } - - void ToNew(HyperLogLog &new_hll) const { - const idx_t mult = duckdb_hll::num_registers() / HyperLogLog::M; - // Old implementation used more registers, so we compress the registers, losing some accuracy - for (idx_t i = 0; i < HyperLogLog::M; i++) { - uint8_t max_old = 0; - for (idx_t j = 0; j < mult; j++) { - D_ASSERT(i * mult + j < 
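// HyperLogLog::Merge above is lossless for unions: each of the M registers
// keeps the maximum of the two inputs (Update(i, v) retains the larger value),
// so merging two sketches and then counting estimates the cardinality of the
// combined stream. The same merge on plain register arrays, as a sketch:
#include <algorithm>
#include <cstddef>
#include <cstdint>

static void HLLMergeSketch(uint8_t *target, const uint8_t *other, size_t num_registers) {
	for (size_t i = 0; i < num_registers; i++) {
		target[i] = std::max(target[i], other[i]); // max preserves the strongest evidence seen
	}
}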
duckdb_hll::num_registers()); - max_old = MaxValue(max_old, duckdb_hll::get_register(hll, i * mult + j)); - } - new_hll.Update(i, max_old); - } - } - - void FromNew(const HyperLogLog &new_hll) { - const auto new_hll_count = new_hll.Count(); - if (new_hll_count == 0) { - return; - } - - const idx_t mult = duckdb_hll::num_registers() / HyperLogLog::M; - // When going from less to more registers, we cannot just duplicate the registers, - // as each register in the new HLL is the minimum of 'mult' registers in the old HLL. - // Duplicating will make for VERY large over-estimations. Instead, we do the following: - - // Set the first of every 'mult' registers in the old HLL to the value in the new HLL - // This ensures that we can convert NEW to OLD and back to NEW without loss of information - double avg = 0; - for (idx_t i = 0; i < HyperLogLog::M; i++) { - const auto max_new = MinValue(new_hll.GetRegister(i), duckdb_hll::maximum_zeros()); - duckdb_hll::set_register(hll, i * mult, max_new); - avg += static_cast(max_new); - } - avg /= static_cast(HyperLogLog::M); - - // Using the average will ALWAYS overestimate, so we reduce it a bit here - if (avg > 10) { - avg *= 0.75; - } else if (avg > 2) { - avg -= 2; - } - - // Set all other registers to a default value, starting with 0 (the initialization value) - // We optimize the default value in 5 iterations or until OLD count is close to NEW count - double default_val = 0; - for (idx_t opt_idx = 0; opt_idx < 5; opt_idx++) { - if (IsWithinAcceptableRange(new_hll_count, Count())) { - break; - } - - // Delta is half the average, then a quarter, etc. - const double delta = avg / static_cast(1 << (opt_idx + 1)); - if (Count() > new_hll_count) { - default_val = delta > default_val ? 0 : default_val - delta; - } else { - default_val += delta; - } - - // If the default value is, e.g., 3.3, then the first 70% gets value 3, and the rest gets value 4 - const double floor_fraction = 1 - (default_val - floor(default_val)); - for (idx_t i = 0; i < HyperLogLog::M; i++) { - const auto max_new = MinValue(new_hll.GetRegister(i), duckdb_hll::maximum_zeros()); - uint8_t register_value; - if (static_cast(i) / static_cast(HyperLogLog::M) < floor_fraction) { - register_value = ExactNumericCast(floor(default_val)); - } else { - register_value = ExactNumericCast(ceil(default_val)); - } - register_value = MinValue(register_value, max_new); - for (idx_t j = 1; j < mult; j++) { - D_ASSERT(i * mult + j < duckdb_hll::num_registers()); - duckdb_hll::set_register(hll, i * mult + j, register_value); - } - } - } - } - -private: - idx_t Count() const { - size_t result; - if (duckdb_hll::hll_count(hll, &result) != HLL_C_OK) { - throw InternalException("Could not count HLL?"); - } - return result; - } - - bool IsWithinAcceptableRange(const idx_t &new_hll_count, const idx_t &old_hll_count) const { - const auto newd = static_cast(new_hll_count); - const auto oldd = static_cast(old_hll_count); - return MaxValue(newd, oldd) / MinValue(newd, oldd) < ACCEPTABLE_Q_ERROR; - } - -private: - static constexpr double ACCEPTABLE_Q_ERROR = 1.2; - duckdb_hll::robj *hll; -}; - -void HyperLogLog::Serialize(Serializer &serializer) const { - if (serializer.ShouldSerialize(3)) { - serializer.WriteProperty(100, "type", HLLStorageType::HLL_V2); - serializer.WriteProperty(101, "data", k, sizeof(k)); - } else { - auto old = make_uniq(); - old->FromNew(*this); - - serializer.WriteProperty(100, "type", HLLStorageType::HLL_V1); - serializer.WriteProperty(101, "data", old->GetPtr(), old->GetSize()); - } -} - 
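// Serialize above is version-gated: when the target format is new enough
// (ShouldSerialize(3)), the compact V2 register array is written directly;
// otherwise the state is converted back into the legacy V1 layout so older
// readers keep working, and the matching Deserialize dispatches on the stored
// type tag. A generic sketch of the pattern with hypothetical names:
#include <cstdint>

enum class SketchFormat : uint8_t { V1, V2 };

template <class WRITER, class STATE, class LEGACY>
void VersionGatedWriteSketch(WRITER &writer, const STATE &state) {
	if (writer.SupportsVersion(2)) { // hypothetical capability query
		writer.WriteTag(SketchFormat::V2);
		writer.WriteBytes(state.Data(), state.Size());
	} else {
		LEGACY legacy;
		legacy.FromNew(state); // downgrade, possibly losing precision
		writer.WriteTag(SketchFormat::V1);
		writer.WriteBytes(legacy.Data(), legacy.Size());
	}
}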
-unique_ptr HyperLogLog::Deserialize(Deserializer &deserializer) { - auto result = make_uniq(); - auto storage_type = deserializer.ReadProperty(100, "type"); - switch (storage_type) { - case HLLStorageType::HLL_V1: { - auto old = make_uniq(); - deserializer.ReadProperty(101, "data", old->GetPtr(), old->GetSize()); - old->ToNew(*result); - break; - } - case HLLStorageType::HLL_V2: - deserializer.ReadProperty(101, "data", result->k, sizeof(k)); - break; - default: - throw SerializationException("Unknown HyperLogLog storage type!"); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/interval.cpp b/src/duckdb/src/common/types/interval.cpp deleted file mode 100644 index 719e9ee2c..000000000 --- a/src/duckdb/src/common/types/interval.cpp +++ /dev/null @@ -1,532 +0,0 @@ -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/enums/date_part_specifier.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/string_util.hpp" - -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -namespace duckdb { - -bool Interval::FromString(const string &str, interval_t &result) { - string error_message; - return Interval::FromCString(str.c_str(), str.size(), result, &error_message, false); -} - -template -void IntervalTryAddition(T &target, int64_t input, int64_t multiplier, int64_t fraction = 0) { - int64_t addition; - if (!TryMultiplyOperator::Operation(input, multiplier, addition)) { - throw OutOfRangeException("interval value is out of range"); - } - T addition_base = Cast::Operation(addition); - if (!TryAddOperator::Operation(target, addition_base, target)) { - throw OutOfRangeException("interval value is out of range"); - } - if (fraction) { - // Add in (fraction * multiplier) / MICROS_PER_SEC - // This is always in range - addition = (fraction * multiplier) / Interval::MICROS_PER_SEC; - addition_base = Cast::Operation(addition); - if (!TryAddOperator::Operation(target, addition_base, target)) { - throw OutOfRangeException("interval fraction is out of range"); - } - } -} - -bool Interval::FromCString(const char *str, idx_t len, interval_t &result, string *error_message, bool strict) { - idx_t pos = 0; - idx_t start_pos; - bool negative; - bool found_any = false; - int64_t number; - int64_t fraction; - DatePartSpecifier specifier; - string specifier_str; - - result.days = 0; - result.micros = 0; - result.months = 0; - - if (len == 0) { - return false; - } - - switch (str[pos]) { - case '@': - pos++; - goto standard_interval; - case 'P': - case 'p': - pos++; - goto posix_interval; - default: - goto standard_interval; - } -standard_interval: - // start parsing a standard interval (e.g. 2 years 3 months...) 
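[Editor's note] Before the parse loop, it helps to see the input shapes this state machine accepts. The examples below are inferred from the transitions that follow (number, optional fraction, specifier word, optional trailing "ago", a colon switching to time parsing); they are illustrative, not an exhaustive grammar:

#include <cstdint>

// Inputs and the interval_t fields (months, days, micros) they should yield,
// inferred from the state machine below; illustrative only.
struct ExpectedInterval {
    const char *input;
    int32_t months;
    int32_t days;
    int64_t micros;
};
static const ExpectedInterval kIntervalExamples[] = {
    {"2 years 3 months", 27, 0, 0},   // specifier list accumulates per unit
    {"@ 1 day", 0, 1, 0},             // a leading '@' is simply skipped
    {"01:02:03", 0, 0, 3723000000LL}, // a colon switches to time parsing
    {"42 seconds", 0, 0, 42000000LL}, // number followed by a specifier word
    {"1 day ago", 0, -1, 0},          // trailing "ago" negates every field
};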
- for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - // skip spaces - continue; - } else if (c >= '0' && c <= '9') { - // start parsing a positive number - negative = false; - goto interval_parse_number; - } else if (c == '-') { - // negative number - negative = true; - pos++; - goto interval_parse_number; - } else if (c == 'a' || c == 'A') { - // parse the word "ago" as the final specifier - goto interval_parse_ago; - } else { - // unrecognized character, expected a number or end of string - return false; - } - } - goto end_of_string; -interval_parse_number: - start_pos = pos; - for (; pos < len; pos++) { - char c = str[pos]; - if (c >= '0' && c <= '9') { - // the number continues - continue; - } else if (c == ':') { - // colon: we are parsing a time - goto interval_parse_time; - } else { - if (pos == start_pos) { - return false; - } - // finished the number, parse it from the string - string_t nr_string(str + start_pos, UnsafeNumericCast(pos - start_pos)); - number = Cast::Operation(nr_string); - fraction = 0; - if (c == '.') { - // we expect some microseconds - int32_t mult = 100000; - for (++pos; pos < len && StringUtil::CharacterIsDigit(str[pos]); ++pos, mult /= 10) { - if (mult > 0) { - fraction += int64_t(str[pos] - '0') * mult; - } - } - } - if (negative) { - number = -number; - fraction = -fraction; - } - goto interval_parse_identifier; - } - } - goto end_of_string; -interval_parse_time : { - // parse the remainder of the time as a Time type - dtime_t time; - idx_t pos; - if (!Time::TryConvertInterval(str + start_pos, len - start_pos, pos, time)) { - return false; - } - result.micros += time.micros; - found_any = true; - if (negative) { - result.micros = -result.micros; - } - goto end_of_string; -} -interval_parse_identifier: - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - // skip spaces at the start - continue; - } else { - break; - } - } - // now parse the identifier - start_pos = pos; - for (; pos < len; pos++) { - char c = str[pos]; - if ((c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z')) { - // keep parsing the string - continue; - } else { - break; - } - } - specifier_str = string(str + start_pos, pos - start_pos); - - // Special case SS[.FFFFFF] - implied SECONDS/MICROSECONDS - if (specifier_str.empty() && !found_any) { - IntervalTryAddition(result.micros, number, MICROS_PER_SEC); - IntervalTryAddition(result.micros, fraction, 1); - found_any = true; - // parse any trailing whitespace - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - continue; - } else { - return false; - } - } - goto end_of_string; - } - - if (!TryGetDatePartSpecifier(specifier_str, specifier)) { - HandleCastError::AssignError(StringUtil::Format("extract specifier \"%s\" not recognized", specifier_str), - error_message); - return false; - } - // add the specifier to the interval - switch (specifier) { - case DatePartSpecifier::MILLENNIUM: - IntervalTryAddition(result.months, number, MONTHS_PER_MILLENIUM, fraction); - break; - case DatePartSpecifier::CENTURY: - IntervalTryAddition(result.months, number, MONTHS_PER_CENTURY, fraction); - break; - case DatePartSpecifier::DECADE: - IntervalTryAddition(result.months, number, MONTHS_PER_DECADE, fraction); - break; - case DatePartSpecifier::YEAR: - IntervalTryAddition(result.months, number, MONTHS_PER_YEAR, fraction); - break; - case DatePartSpecifier::QUARTER: - IntervalTryAddition(result.months, number, MONTHS_PER_QUARTER, 
fraction); - // Reduce to fraction of a month - fraction *= MONTHS_PER_QUARTER; - fraction %= MICROS_PER_SEC; - IntervalTryAddition(result.days, 0, DAYS_PER_MONTH, fraction); - break; - case DatePartSpecifier::MONTH: - IntervalTryAddition(result.months, number, 1); - IntervalTryAddition(result.days, 0, DAYS_PER_MONTH, fraction); - break; - case DatePartSpecifier::DAY: - IntervalTryAddition(result.days, number, 1); - IntervalTryAddition(result.micros, 0, MICROS_PER_DAY, fraction); - break; - case DatePartSpecifier::WEEK: - IntervalTryAddition(result.days, number, DAYS_PER_WEEK, fraction); - // Reduce to fraction of a day - fraction *= DAYS_PER_WEEK; - fraction %= MICROS_PER_SEC; - IntervalTryAddition(result.micros, 0, MICROS_PER_DAY, fraction); - break; - case DatePartSpecifier::MICROSECONDS: - // Round the fraction - number += (fraction * 2) / MICROS_PER_SEC; - IntervalTryAddition(result.micros, number, 1); - break; - case DatePartSpecifier::MILLISECONDS: - IntervalTryAddition(result.micros, number, MICROS_PER_MSEC, fraction); - break; - case DatePartSpecifier::SECOND: - IntervalTryAddition(result.micros, number, MICROS_PER_SEC, fraction); - break; - case DatePartSpecifier::MINUTE: - IntervalTryAddition(result.micros, number, MICROS_PER_MINUTE, fraction); - break; - case DatePartSpecifier::HOUR: - IntervalTryAddition(result.micros, number, MICROS_PER_HOUR, fraction); - break; - default: - HandleCastError::AssignError( - StringUtil::Format("extract specifier \"%s\" not supported for interval", specifier_str), error_message); - return false; - } - found_any = true; - goto standard_interval; -interval_parse_ago: - D_ASSERT(str[pos] == 'a' || str[pos] == 'A'); - // parse the "ago" string at the end of the interval - if (len - pos < 3) { - return false; - } - pos++; - if (!(str[pos] == 'g' || str[pos] == 'G')) { - return false; - } - pos++; - if (!(str[pos] == 'o' || str[pos] == 'O')) { - return false; - } - pos++; - // parse any trailing whitespace - for (; pos < len; pos++) { - char c = str[pos]; - if (c == ' ' || c == '\t' || c == '\n') { - continue; - } else { - return false; - } - } - // invert all the values - result.months = -result.months; - result.days = -result.days; - result.micros = -result.micros; - goto end_of_string; -end_of_string: - if (!found_any) { - // end of string and no identifiers were found: cannot convert empty interval - return false; - } - return true; -posix_interval: - return false; -} - -string Interval::ToString(const interval_t &interval) { - char buffer[70]; - idx_t length = IntervalToStringCast::Format(interval, buffer); - return string(buffer, length); -} - -int64_t Interval::GetMilli(const interval_t &val) { - int64_t milli_month, milli_day, milli; - if (!TryMultiplyOperator::Operation((int64_t)val.months, Interval::MICROS_PER_MONTH / 1000, milli_month)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - if (!TryMultiplyOperator::Operation((int64_t)val.days, Interval::MICROS_PER_DAY / 1000, milli_day)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - milli = val.micros / 1000; - if (!TryAddOperator::Operation(milli, milli_month, milli)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - if (!TryAddOperator::Operation(milli, milli_day, milli)) { - throw ConversionException("Could not convert Interval to Milliseconds"); - } - return milli; -} - -int64_t Interval::GetMicro(const interval_t &val) { - int64_t micro_month, micro_day, micro_total; - micro_total = 
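[Editor's note] The fraction plumbing above is easiest to see with a worked example. For "1.5 month" the digit loop scales ".5" to 500000 (the first fractional digit gets a multiplier of 100000), the whole part lands in months, and the fraction spills into days via (fraction * DAYS_PER_MONTH) / MICROS_PER_SEC. A self-contained restatement using the same constants (30-day months, 10^6 micros per second):

#include <cstdint>
#include <cstdio>

// Worked example of the fraction spill-over in the MONTH branch above.
int main() {
    const int64_t MICROS_PER_SEC = 1000000;
    const int64_t DAYS_PER_MONTH = 30;

    int64_t number = 1;        // integral part of "1.5"
    int64_t fraction = 500000; // ".5" scaled to micro-units by the digit loop

    int64_t months = number;
    // IntervalTryAddition(result.days, 0, DAYS_PER_MONTH, fraction) reduces to:
    int64_t days = (fraction * DAYS_PER_MONTH) / MICROS_PER_SEC; // = 15
    std::printf("1.5 month => %lld month(s) %lld day(s)\n",
                (long long)months, (long long)days);
    return 0;
}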
val.micros; - if (!TryMultiplyOperator::Operation((int64_t)val.months, MICROS_PER_MONTH, micro_month)) { - throw ConversionException("Could not convert Month to Microseconds"); - } - if (!TryMultiplyOperator::Operation((int64_t)val.days, MICROS_PER_DAY, micro_day)) { - throw ConversionException("Could not convert Day to Microseconds"); - } - if (!TryAddOperator::Operation(micro_total, micro_month, micro_total)) { - throw ConversionException("Could not convert Interval to Microseconds"); - } - if (!TryAddOperator::Operation(micro_total, micro_day, micro_total)) { - throw ConversionException("Could not convert Interval to Microseconds"); - } - - return micro_total; -} - -int64_t Interval::GetNanoseconds(const interval_t &val) { - int64_t nano; - const auto micro_total = GetMicro(val); - if (!TryMultiplyOperator::Operation(micro_total, NANOS_PER_MICRO, nano)) { - throw ConversionException("Could not convert Interval to Nanoseconds"); - } - - return nano; -} - -interval_t Interval::GetAge(timestamp_t timestamp_1, timestamp_t timestamp_2) { - D_ASSERT(Timestamp::IsFinite(timestamp_1) && Timestamp::IsFinite(timestamp_2)); - date_t date1, date2; - dtime_t time1, time2; - - Timestamp::Convert(timestamp_1, date1, time1); - Timestamp::Convert(timestamp_2, date2, time2); - - // and from date extract the years, months and days - int32_t year1, month1, day1; - int32_t year2, month2, day2; - Date::Convert(date1, year1, month1, day1); - Date::Convert(date2, year2, month2, day2); - // finally perform the differences - auto year_diff = year1 - year2; - auto month_diff = month1 - month2; - auto day_diff = day1 - day2; - - // and from time extract hours, minutes, seconds and milliseconds - int32_t hour1, min1, sec1, micros1; - int32_t hour2, min2, sec2, micros2; - Time::Convert(time1, hour1, min1, sec1, micros1); - Time::Convert(time2, hour2, min2, sec2, micros2); - // finally perform the differences - auto hour_diff = hour1 - hour2; - auto min_diff = min1 - min2; - auto sec_diff = sec1 - sec2; - auto micros_diff = micros1 - micros2; - - // flip sign if necessary - bool sign_flipped = false; - if (timestamp_1 < timestamp_2) { - year_diff = -year_diff; - month_diff = -month_diff; - day_diff = -day_diff; - hour_diff = -hour_diff; - min_diff = -min_diff; - sec_diff = -sec_diff; - micros_diff = -micros_diff; - sign_flipped = true; - } - // now propagate any negative field into the next higher field - while (micros_diff < 0) { - micros_diff += MICROS_PER_SEC; - sec_diff--; - } - while (sec_diff < 0) { - sec_diff += SECS_PER_MINUTE; - min_diff--; - } - while (min_diff < 0) { - min_diff += MINS_PER_HOUR; - hour_diff--; - } - while (hour_diff < 0) { - hour_diff += HOURS_PER_DAY; - day_diff--; - } - while (day_diff < 0) { - if (timestamp_1 < timestamp_2) { - day_diff += Date::IsLeapYear(year1) ? Date::LEAP_DAYS[month1] : Date::NORMAL_DAYS[month1]; - month_diff--; - } else { - day_diff += Date::IsLeapYear(year2) ? 
Date::LEAP_DAYS[month2] : Date::NORMAL_DAYS[month2]; - month_diff--; - } - } - while (month_diff < 0) { - month_diff += MONTHS_PER_YEAR; - year_diff--; - } - - // recover sign if necessary - if (sign_flipped) { - year_diff = -year_diff; - month_diff = -month_diff; - day_diff = -day_diff; - hour_diff = -hour_diff; - min_diff = -min_diff; - sec_diff = -sec_diff; - micros_diff = -micros_diff; - } - interval_t interval; - interval.months = year_diff * MONTHS_PER_YEAR + month_diff; - interval.days = day_diff; - interval.micros = Time::FromTime(hour_diff, min_diff, sec_diff, micros_diff).micros; - - return interval; -} - -interval_t Interval::GetDifference(timestamp_t timestamp_1, timestamp_t timestamp_2) { - if (!Timestamp::IsFinite(timestamp_1) || !Timestamp::IsFinite(timestamp_2)) { - throw InvalidInputException("Cannot subtract infinite timestamps"); - } - const auto us_1 = Timestamp::GetEpochMicroSeconds(timestamp_1); - const auto us_2 = Timestamp::GetEpochMicroSeconds(timestamp_2); - int64_t delta_us; - if (!TrySubtractOperator::Operation(us_1, us_2, delta_us)) { - throw ConversionException("Timestamp difference is out of bounds"); - } - return FromMicro(delta_us); -} - -interval_t Interval::FromMicro(int64_t delta_us) { - interval_t result; - result.months = 0; - result.days = UnsafeNumericCast(delta_us / Interval::MICROS_PER_DAY); - result.micros = delta_us % Interval::MICROS_PER_DAY; - - return result; -} - -interval_t Interval::Invert(interval_t interval) { - interval.days = -interval.days; - interval.micros = -interval.micros; - interval.months = -interval.months; - return interval; -} - -date_t Interval::Add(date_t left, interval_t right) { - if (!Date::IsFinite(left)) { - return left; - } - date_t result; - if (right.months != 0) { - int32_t year, month, day; - Date::Convert(left, year, month, day); - int32_t year_diff = right.months / Interval::MONTHS_PER_YEAR; - year += year_diff; - month += right.months - year_diff * Interval::MONTHS_PER_YEAR; - if (month > Interval::MONTHS_PER_YEAR) { - year++; - month -= Interval::MONTHS_PER_YEAR; - } else if (month <= 0) { - year--; - month += Interval::MONTHS_PER_YEAR; - } - day = MinValue(day, Date::MonthDays(year, month)); - result = Date::FromDate(year, month, day); - } else { - result = left; - } - if (right.days != 0) { - if (!TryAddOperator::Operation(result.days, right.days, result.days)) { - throw OutOfRangeException("Date out of range"); - } - } - if (right.micros != 0) { - if (!TryAddOperator::Operation(result.days, int32_t(right.micros / Interval::MICROS_PER_DAY), result.days)) { - throw OutOfRangeException("Date out of range"); - } - } - if (!Date::IsFinite(result)) { - throw OutOfRangeException("Date out of range"); - } - return result; -} - -dtime_t Interval::Add(dtime_t left, interval_t right, date_t &date) { - int64_t diff = right.micros - ((right.micros / Interval::MICROS_PER_DAY) * Interval::MICROS_PER_DAY); - left += diff; - if (left.micros >= Interval::MICROS_PER_DAY) { - left.micros -= Interval::MICROS_PER_DAY; - date.days++; - } else if (left.micros < 0) { - left.micros += Interval::MICROS_PER_DAY; - date.days--; - } - return left; -} - -dtime_tz_t Interval::Add(dtime_tz_t left, interval_t right, date_t &date) { - return dtime_tz_t(Interval::Add(left.time(), right, date), left.offset()); -} - -timestamp_t Interval::Add(timestamp_t left, interval_t right) { - if (!Timestamp::IsFinite(left)) { - return left; - } - date_t date; - dtime_t time; - Timestamp::Convert(left, date, time); - auto new_date = Interval::Add(date, 
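[Editor's note] GetAge and GetDifference answer different questions: GetAge borrows field by field so the result reads like a calendar age, while GetDifference is exact, delegating to FromMicro, which splits a microsecond delta into whole days plus a sub-day remainder and never guesses calendar months. A minimal restatement of the FromMicro split:

#include <cstdint>
#include <cstdio>

// The normalization inside Interval::FromMicro: exact micros -> days + rest.
int main() {
    const int64_t MICROS_PER_DAY = 86400LL * 1000000LL;
    int64_t delta_us = 90061000000LL; // 1 day, 1 hour, 1 minute, 1 second

    int64_t days = delta_us / MICROS_PER_DAY;   // 1
    int64_t micros = delta_us % MICROS_PER_DAY; // 3661000000, i.e. 01:01:01
    std::printf("days=%lld micros=%lld\n", (long long)days, (long long)micros);
    return 0;
}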
right); - auto new_time = Interval::Add(time, right, new_date); - return Timestamp::FromDatetime(new_date, new_time); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/list_segment.cpp b/src/duckdb/src/common/types/list_segment.cpp deleted file mode 100644 index 8145cf07f..000000000 --- a/src/duckdb/src/common/types/list_segment.cpp +++ /dev/null @@ -1,677 +0,0 @@ -#include "duckdb/common/types/list_segment.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -// forward declarations -//===--------------------------------------------------------------------===// -// Primitives -//===--------------------------------------------------------------------===// -template -static idx_t GetAllocationSize(uint16_t capacity) { - return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(T))); -} - -template -static data_ptr_t AllocatePrimitiveData(ArenaAllocator &allocator, uint16_t capacity) { - return allocator.Allocate(GetAllocationSize(capacity)); -} - -template -static T *GetPrimitiveData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + segment->capacity * sizeof(bool)); -} - -template -static const T *GetPrimitiveData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -//===--------------------------------------------------------------------===// -// Strings -//===--------------------------------------------------------------------===// -static idx_t GetStringAllocationSize(uint16_t capacity) { - return AlignValue(sizeof(ListSegment) + (capacity * (sizeof(char)))); -} - -static data_ptr_t AllocateStringData(ArenaAllocator &allocator, uint16_t capacity) { - return allocator.Allocate(GetStringAllocationSize(capacity)); -} - -static char *GetStringData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment)); -} - -//===--------------------------------------------------------------------===// -// Lists -//===--------------------------------------------------------------------===// -static idx_t GetAllocationSizeList(uint16_t capacity) { - return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool) + sizeof(uint64_t)) + sizeof(LinkedList)); -} - -static data_ptr_t AllocateListData(ArenaAllocator &allocator, uint16_t capacity) { - return allocator.Allocate(GetAllocationSizeList(capacity)); -} - -static uint64_t *GetListLengthData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const uint64_t *GetListLengthData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const LinkedList *GetListChildData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * (sizeof(bool) + sizeof(uint64_t))); -} - -static LinkedList *GetListChildData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * (sizeof(bool) + sizeof(uint64_t))); -} - -//===--------------------------------------------------------------------===// -// Array -//===--------------------------------------------------------------------===// -static idx_t GetAllocationSizeArray(uint16_t capacity) { - // Only store the null mask for 
the array segment, length is fixed so we don't need to store it - return AlignValue(sizeof(ListSegment) + capacity * (sizeof(bool)) + sizeof(LinkedList)); -} - -static data_ptr_t AllocateArrayData(ArenaAllocator &allocator, uint16_t capacity) { - return allocator.Allocate(GetAllocationSizeArray(capacity)); -} - -static const LinkedList *GetArrayChildData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static LinkedList *GetArrayChildData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -//===--------------------------------------------------------------------===// -// Structs -//===--------------------------------------------------------------------===// -static idx_t GetAllocationSizeStruct(uint16_t capacity, idx_t child_count) { - return AlignValue(sizeof(ListSegment) + capacity * sizeof(bool) + child_count * sizeof(ListSegment *)); -} - -static data_ptr_t AllocateStructData(ArenaAllocator &allocator, uint16_t capacity, idx_t child_count) { - return allocator.Allocate(GetAllocationSizeStruct(capacity, child_count)); -} - -static ListSegment **GetStructData(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + +sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static const ListSegment *const *GetStructData(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment) + - segment->capacity * sizeof(bool)); -} - -static bool *GetNullMask(ListSegment *segment) { - return reinterpret_cast(data_ptr_cast(segment) + sizeof(ListSegment)); -} - -static const bool *GetNullMask(const ListSegment *segment) { - return reinterpret_cast(const_data_ptr_cast(segment) + sizeof(ListSegment)); -} - -static uint16_t GetCapacityForNewSegment(uint16_t capacity) { - auto next_power_of_two = idx_t(capacity) * 2; - if (next_power_of_two >= NumericLimits::Maximum()) { - return capacity; - } - return uint16_t(next_power_of_two); -} - -//===--------------------------------------------------------------------===// -// Create -//===--------------------------------------------------------------------===// -template -static ListSegment *CreatePrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { - // allocate data and set the header - auto segment = reinterpret_cast(AllocatePrimitiveData(allocator, capacity)); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - return segment; -} - -static ListSegment *CreateVarcharDataSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, - uint16_t capacity) { - // allocate data and set the header - auto segment = reinterpret_cast(AllocateStringData(allocator, capacity)); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - return segment; -} - -static ListSegment *CreateListSegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { - // allocate data and set the header - auto segment = reinterpret_cast(AllocateListData(allocator, capacity)); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - - // create an empty linked list for the child vector - auto linked_child_list = GetListChildData(segment); - LinkedList linked_list(0, nullptr, nullptr); - Store(linked_list, data_ptr_cast(linked_child_list)); - - return segment; -} - -static ListSegment 
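[Editor's note] The getters above all assume one flat allocation per segment: a fixed header, then `capacity` null-mask bytes, then the payload (values, string lengths, or child pointers). A sketch of the primitive layout with a stand-in header type (AlignValue padding omitted for brevity; ListSegmentStub is a hypothetical stand-in carrying only the fields used here):

#include <cstddef>
#include <cstdint>
#include <cstdio>

struct ListSegmentStub {
    uint16_t capacity;
    uint16_t count;
    ListSegmentStub *next;
};

// Total bytes for a primitive segment: header + null mask + values.
template <class T>
size_t AllocationSize(uint16_t capacity) {
    return sizeof(ListSegmentStub) + capacity * (sizeof(bool) + sizeof(T));
}

// Values live directly behind the per-row null mask.
template <class T>
T *PrimitiveData(ListSegmentStub *segment) {
    return reinterpret_cast<T *>(reinterpret_cast<char *>(segment) +
                                 sizeof(ListSegmentStub) +
                                 segment->capacity * sizeof(bool));
}

int main() {
    std::printf("int32 segment, capacity 4: %zu bytes total\n",
                AllocationSize<int32_t>(4));
    return 0;
}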
*CreateStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - uint16_t capacity) { - // allocate data and set header - auto segment = - reinterpret_cast(AllocateStructData(allocator, capacity, functions.child_functions.size())); - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - - // create a child ListSegment with exactly the same capacity for each child vector - auto child_segments = GetStructData(segment); - for (idx_t i = 0; i < functions.child_functions.size(); i++) { - auto child_function = functions.child_functions[i]; - auto child_segment = child_function.create_segment(child_function, allocator, capacity); - Store(child_segment, data_ptr_cast(child_segments + i)); - } - - return segment; -} - -static ListSegment *CreateArraySegment(const ListSegmentFunctions &, ArenaAllocator &allocator, uint16_t capacity) { - // allocate data and set header - auto segment = reinterpret_cast(AllocateArrayData(allocator, capacity)); - - segment->capacity = capacity; - segment->count = 0; - segment->next = nullptr; - - // create an empty linked list for the child vector - auto linked_child_list = GetArrayChildData(segment); - LinkedList linked_list(0, nullptr, nullptr); - Store(linked_list, data_ptr_cast(linked_child_list)); - - return segment; -} - -static ListSegment *GetSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - LinkedList &linked_list) { - ListSegment *segment; - - // determine segment - if (!linked_list.last_segment) { - // empty linked list, create the first (and last) segment - segment = functions.create_segment(functions, allocator, functions.initial_capacity); - linked_list.first_segment = segment; - linked_list.last_segment = segment; - } else if (linked_list.last_segment->capacity == linked_list.last_segment->count) { - // the last segment of the linked list is full, create a new one and append it - auto capacity = GetCapacityForNewSegment(linked_list.last_segment->capacity); - segment = functions.create_segment(functions, allocator, capacity); - linked_list.last_segment->next = segment; - linked_list.last_segment = segment; - } else { - // the last segment of the linked list is not full, append the data to it - segment = linked_list.last_segment; - } - - D_ASSERT(segment); - return segment; -} - -//===--------------------------------------------------------------------===// -// Append -//===--------------------------------------------------------------------===// -template -static void WriteDataToPrimitiveSegment(const ListSegmentFunctions &, ArenaAllocator &, ListSegment *segment, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // write value - if (valid) { - auto segment_data = GetPrimitiveData(segment); - auto input_data_ptr = UnifiedVectorFormat::GetData(input_data.unified); - Store(input_data_ptr[sel_entry_idx], data_ptr_cast(segment_data + segment->count)); - } -} - -static void WriteDataToVarcharSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, - idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = 
input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this string - auto str_length_data = GetListLengthData(segment); - - // we can reconstruct the offset from the length - if (!valid) { - Store(0, data_ptr_cast(str_length_data + segment->count)); - return; - } - auto &str_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - auto str_data = str_entry.GetData(); - idx_t str_size = str_entry.GetSize(); - Store(str_size, data_ptr_cast(str_length_data + segment->count)); - - // write the characters to the linked list of child segments - auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - idx_t current_offset = 0; - while (current_offset < str_size) { - auto child_segment = GetSegment(functions.child_functions.back(), allocator, child_segments); - auto data = GetStringData(child_segment); - idx_t copy_count = MinValue(str_size - current_offset, child_segment->capacity - child_segment->count); - memcpy(data + child_segment->count, str_data + current_offset, copy_count); - current_offset += copy_count; - child_segment->count += copy_count; - } - child_segments.total_capacity += str_size; - // store the updated linked list - Store(child_segments, data_ptr_cast(GetListChildData(segment))); -} - -static void WriteDataToListSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // set the length of this list - auto list_length_data = GetListLengthData(segment); - uint64_t list_length = 0; - - if (valid) { - // get list entry information - const auto &list_entry = UnifiedVectorFormat::GetData(input_data.unified)[sel_entry_idx]; - list_length = list_entry.length; - - // loop over the child vector entries and recurse on them - auto child_segments = Load(data_ptr_cast(GetListChildData(segment))); - D_ASSERT(functions.child_functions.size() == 1); - for (idx_t child_idx = 0; child_idx < list_entry.length; child_idx++) { - auto source_idx_child = list_entry.offset + child_idx; - functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), - source_idx_child); - } - // store the updated linked list - Store(child_segments, data_ptr_cast(GetListChildData(segment))); - } - - Store(list_length, data_ptr_cast(list_length_data + segment->count)); -} - -static void WriteDataToStructSegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // write value - D_ASSERT(input_data.children.size() == functions.child_functions.size()); - auto child_list = GetStructData(segment); - - // write the data of each of the children of the struct - for (idx_t i = 0; i < input_data.children.size(); i++) { - auto child_list_segment = Load(data_ptr_cast(child_list + i)); - auto &child_function = functions.child_functions[i]; - child_function.write_data(child_function, allocator, child_list_segment, 
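[Editor's note] WriteDataToVarcharSegment above spreads one string across a chain of fixed-capacity character segments, topping up whatever room the tail segment still has before chaining a new one. A standalone sketch of that copy loop (the demo capacity of 16 is illustrative; the real chain grows capacities via GetCapacityForNewSegment):

#include <algorithm>
#include <cstring>
#include <string>
#include <vector>

struct CharSegment {
    size_t capacity;
    size_t count = 0;
    std::vector<char> data;
    explicit CharSegment(size_t cap) : capacity(cap), data(cap) {}
};

void AppendString(std::vector<CharSegment> &chain, const std::string &str) {
    size_t offset = 0;
    while (offset < str.size()) {
        if (chain.empty() || chain.back().count == chain.back().capacity) {
            chain.emplace_back(16); // fixed demo capacity, not duckdb's policy
        }
        auto &seg = chain.back();
        size_t copy_count = std::min(str.size() - offset, seg.capacity - seg.count);
        std::memcpy(seg.data.data() + seg.count, str.data() + offset, copy_count);
        seg.count += copy_count;
        offset += copy_count;
    }
}

ReadDataFromVarcharSegment reverses this with the stored lengths, reconstructing offsets by walking the same chain.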
input_data.children[i], entry_idx); - child_list_segment->count++; - } -} - -static void WriteDataToArraySegment(const ListSegmentFunctions &functions, ArenaAllocator &allocator, - ListSegment *segment, RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) { - auto sel_entry_idx = input_data.unified.sel->get_index(entry_idx); - - // write null validity - auto null_mask = GetNullMask(segment); - auto valid = input_data.unified.validity.RowIsValid(sel_entry_idx); - null_mask[segment->count] = !valid; - - // Arrays require there to be values in the child even when the entry is NULL. - auto array_size = ArrayType::GetSize(input_data.logical_type); - auto array_offset = sel_entry_idx * array_size; - - auto child_segments = Load(data_ptr_cast(GetArrayChildData(segment))); - D_ASSERT(functions.child_functions.size() == 1); - for (idx_t elem_idx = array_offset; elem_idx < array_offset + array_size; elem_idx++) { - functions.child_functions[0].AppendRow(allocator, child_segments, input_data.children.back(), elem_idx); - } - // store the updated linked list - Store(child_segments, data_ptr_cast(GetArrayChildData(segment))); -} - -void ListSegmentFunctions::AppendRow(ArenaAllocator &allocator, LinkedList &linked_list, - RecursiveUnifiedVectorFormat &input_data, idx_t &entry_idx) const { - - auto &write_data_to_segment = *this; - auto segment = GetSegment(write_data_to_segment, allocator, linked_list); - write_data_to_segment.write_data(write_data_to_segment, allocator, segment, input_data, entry_idx); - - linked_list.total_capacity++; - segment->count++; -} - -//===--------------------------------------------------------------------===// -// Read -//===--------------------------------------------------------------------===// -template -static void ReadDataFromPrimitiveSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto aggr_vector_data = FlatVector::GetData(result); - - // load values - for (idx_t i = 0; i < segment->count; i++) { - if (aggr_vector_validity.RowIsValid(total_count + i)) { - auto data = GetPrimitiveData(segment); - aggr_vector_data[total_count + i] = Load(const_data_ptr_cast(data + i)); - } - } -} - -static void ReadDataFromVarcharSegment(const ListSegmentFunctions &, const ListSegment *segment, Vector &result, - idx_t &total_count) { - auto &aggr_vector_validity = FlatVector::Validity(result); - - // use length and (reconstructed) offset to get the correct substrings - auto aggr_vector_data = FlatVector::GetData(result); - auto str_length_data = GetListLengthData(segment); - - auto null_mask = GetNullMask(segment); - auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); - auto current_segment = linked_child_list.first_segment; - idx_t child_offset = 0; - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - // set to null - aggr_vector_validity.SetInvalid(total_count + i); - continue; - } - // read the string - auto &result_str = aggr_vector_data[total_count + i]; - auto str_length = Load(const_data_ptr_cast(str_length_data + i)); - // allocate an empty string for the given size - result_str = StringVector::EmptyString(result, str_length); - auto result_data = result_str.GetDataWriteable(); - // copy over the data - idx_t 
current_offset = 0; - while (current_offset < str_length) { - if (!current_segment) { - throw InternalException("Insufficient data to read string"); - } - auto child_data = GetStringData(current_segment); - idx_t max_copy = MinValue(str_length - current_offset, current_segment->capacity - child_offset); - memcpy(result_data + current_offset, child_data + child_offset, max_copy); - current_offset += max_copy; - child_offset += max_copy; - if (child_offset >= current_segment->capacity) { - D_ASSERT(child_offset == current_segment->capacity); - current_segment = current_segment->next; - child_offset = 0; - } - } - - // finalize the str - result_str.Finalize(); - } -} - -static void ReadDataFromListSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto list_vector_data = FlatVector::GetData(result); - - // get the starting offset - idx_t offset = 0; - if (total_count != 0) { - offset = list_vector_data[total_count - 1].offset + list_vector_data[total_count - 1].length; - } - idx_t starting_offset = offset; - - // set length and offsets - auto list_length_data = GetListLengthData(segment); - for (idx_t i = 0; i < segment->count; i++) { - auto list_length = Load(const_data_ptr_cast(list_length_data + i)); - list_vector_data[total_count + i].length = list_length; - list_vector_data[total_count + i].offset = offset; - offset += list_length; - } - - auto &child_vector = ListVector::GetEntry(result); - auto linked_child_list = Load(const_data_ptr_cast(GetListChildData(segment))); - ListVector::Reserve(result, offset); - - // recurse into the linked list of child values - D_ASSERT(functions.child_functions.size() == 1); - functions.child_functions[0].BuildListVector(linked_child_list, child_vector, starting_offset); - ListVector::SetListSize(result, offset); -} - -static void ReadDataFromStructSegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto &children = StructVector::GetEntries(result); - - // recurse into the child segments of each child of the struct - D_ASSERT(children.size() == functions.child_functions.size()); - auto struct_children = GetStructData(segment); - for (idx_t child_count = 0; child_count < children.size(); child_count++) { - auto struct_children_segment = Load(const_data_ptr_cast(struct_children + child_count)); - auto &child_function = functions.child_functions[child_count]; - child_function.read_data(child_function, struct_children_segment, *children[child_count], total_count); - } -} - -static void ReadDataFromArraySegment(const ListSegmentFunctions &functions, const ListSegment *segment, Vector &result, - idx_t &total_count) { - - auto &aggr_vector_validity = FlatVector::Validity(result); - - // set NULLs - auto null_mask = GetNullMask(segment); - for (idx_t i = 0; i < segment->count; i++) { - if (null_mask[i]) { - aggr_vector_validity.SetInvalid(total_count + i); - } - } - - auto &child_vector = ArrayVector::GetEntry(result); - auto 
linked_child_list = Load(const_data_ptr_cast(GetArrayChildData(segment))); - auto array_size = ArrayType::GetSize(result.GetType()); - auto child_size = array_size * total_count; - - // recurse into the linked list of child values - D_ASSERT(functions.child_functions.size() == 1); - functions.child_functions[0].BuildListVector(linked_child_list, child_vector, child_size); -} - -void ListSegmentFunctions::BuildListVector(const LinkedList &linked_list, Vector &result, idx_t total_count) const { - auto &read_data_from_segment = *this; - auto segment = linked_list.first_segment; - while (segment) { - read_data_from_segment.read_data(read_data_from_segment, segment, result, total_count); - total_count += segment->count; - segment = segment->next; - } -} - -//===--------------------------------------------------------------------===// -// Functions -//===--------------------------------------------------------------------===// -template -void SegmentPrimitiveFunction(ListSegmentFunctions &functions) { - functions.create_segment = CreatePrimitiveSegment; - functions.write_data = WriteDataToPrimitiveSegment; - functions.read_data = ReadDataFromPrimitiveSegment; -} - -void GetSegmentDataFunctions(ListSegmentFunctions &functions, const LogicalType &type) { - - if (type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - auto physical_type = type.InternalType(); - switch (physical_type) { - case PhysicalType::BIT: - case PhysicalType::BOOL: - SegmentPrimitiveFunction(functions); - functions.initial_capacity = 8; - break; - case PhysicalType::INT8: - SegmentPrimitiveFunction(functions); - functions.initial_capacity = 8; - break; - case PhysicalType::INT16: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT32: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT64: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT8: - SegmentPrimitiveFunction(functions); - functions.initial_capacity = 8; - break; - case PhysicalType::UINT16: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT32: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT64: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::FLOAT: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::DOUBLE: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INT128: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::UINT128: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::INTERVAL: - SegmentPrimitiveFunction(functions); - break; - case PhysicalType::VARCHAR: { - functions.create_segment = CreateListSegment; - functions.write_data = WriteDataToVarcharSegment; - functions.read_data = ReadDataFromVarcharSegment; - - ListSegmentFunctions child_function; - child_function.create_segment = CreateVarcharDataSegment; - child_function.write_data = nullptr; - child_function.read_data = nullptr; - child_function.initial_capacity = 16; - functions.child_functions.push_back(child_function); - break; - } - case PhysicalType::LIST: { - functions.create_segment = CreateListSegment; - functions.write_data = WriteDataToListSegment; - functions.read_data = ReadDataFromListSegment; - - // recurse - functions.child_functions.emplace_back(); - GetSegmentDataFunctions(functions.child_functions.back(), ListType::GetChildType(type)); - break; - } - case PhysicalType::STRUCT: { - functions.create_segment = CreateStructSegment; - functions.write_data = 
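[Editor's note] BuildListVector above is a fold over the segment chain: each segment emits its rows at the running offset, then advances the offset by its count. A generic restatement with stand-in types (Segment and the emit callback are hypothetical simplifications):

#include <cstddef>

struct Segment {
    size_t count;
    Segment *next;
};

template <class EmitFn>
void BuildVector(const Segment *first, EmitFn emit) {
    size_t total_count = 0;
    for (const Segment *segment = first; segment; segment = segment->next) {
        emit(*segment, total_count); // read segment->count rows at this offset
        total_count += segment->count;
    }
}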
WriteDataToStructSegment; - functions.read_data = ReadDataFromStructSegment; - - // recurse - auto child_types = StructType::GetChildTypes(type); - for (idx_t i = 0; i < child_types.size(); i++) { - functions.child_functions.emplace_back(); - GetSegmentDataFunctions(functions.child_functions.back(), child_types[i].second); - } - break; - } - case PhysicalType::ARRAY: { - functions.create_segment = CreateArraySegment; - functions.write_data = WriteDataToArraySegment; - functions.read_data = ReadDataFromArraySegment; - - // recurse - functions.child_functions.emplace_back(); - GetSegmentDataFunctions(functions.child_functions.back(), ArrayType::GetChildType(type)); - break; - } - default: - throw InternalException("LIST aggregate not yet implemented for " + type.ToString()); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp b/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp deleted file mode 100644 index b77463d8c..000000000 --- a/src/duckdb/src/common/types/row/partitioned_tuple_data.cpp +++ /dev/null @@ -1,372 +0,0 @@ -#include "duckdb/common/types/row/partitioned_tuple_data.hpp" - -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/types/row/tuple_data_iterator.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -PartitionedTupleData::PartitionedTupleData(PartitionedTupleDataType type_p, BufferManager &buffer_manager_p, - const TupleDataLayout &layout_p) - : type(type_p), buffer_manager(buffer_manager_p), layout(layout_p.Copy()), count(0), data_size(0), - allocators(make_shared_ptr()) { -} - -PartitionedTupleData::PartitionedTupleData(const PartitionedTupleData &other) - : type(other.type), buffer_manager(other.buffer_manager), layout(other.layout.Copy()), count(0), data_size(0) { -} - -PartitionedTupleData::~PartitionedTupleData() { -} - -const TupleDataLayout &PartitionedTupleData::GetLayout() const { - return layout; -} - -PartitionedTupleDataType PartitionedTupleData::GetType() const { - return type; -} - -void PartitionedTupleData::InitializeAppendState(PartitionedTupleDataAppendState &state, - TupleDataPinProperties properties) const { - state.partition_sel.Initialize(); - state.reverse_partition_sel.Initialize(); - - InitializeAppendStateInternal(state, properties); -} - -void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel, const idx_t append_count) { - TupleDataCollection::ToUnifiedFormat(state.chunk_state, input); - AppendUnified(state, input, append_sel, append_count); -} - -bool PartitionedTupleData::UseFixedSizeMap() const { - return MaxPartitionIndex() < PartitionedTupleDataAppendState::MAP_THRESHOLD; -} - -void PartitionedTupleData::AppendUnified(PartitionedTupleDataAppendState &state, DataChunk &input, - const SelectionVector &append_sel, const idx_t append_count) { - const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? 
input.size() : append_count; - - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(state, input, append_sel, actual_append_count); - - // Build the selection vector for the partitions - BuildPartitionSel(state, append_sel, actual_append_count); - - // Early out: check if everything belongs to a single partition - const auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); - if (partition_index.IsValid()) { - auto &partition = *partitions[partition_index.GetIndex()]; - auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; - - const auto size_before = partition.SizeInBytes(); - partition.AppendUnified(partition_pin_state, state.chunk_state, input, append_sel, actual_append_count); - data_size += partition.SizeInBytes() - size_before; - } else { - // Compute the heap sizes for the whole chunk - if (!layout.AllConstant()) { - TupleDataCollection::ComputeHeapSizes(state.chunk_state, input, state.partition_sel, actual_append_count); - } - - // Build the buffer space - BuildBufferSpace(state); - - // Now scatter everything in one go - partitions[0]->Scatter(state.chunk_state, input, state.partition_sel, actual_append_count); - } - - count += actual_append_count; - Verify(); -} - -void PartitionedTupleData::Append(PartitionedTupleDataAppendState &state, TupleDataChunkState &input, - const idx_t append_count) { - // Compute partition indices and store them in state.partition_indices - ComputePartitionIndices(input.row_locations, append_count, state.partition_indices); - - // Build the selection vector for the partitions - BuildPartitionSel(state, *FlatVector::IncrementalSelectionVector(), append_count); - - // Early out: check if everything belongs to a single partition - auto partition_index = state.GetPartitionIndexIfSinglePartition(UseFixedSizeMap()); - if (partition_index.IsValid()) { - auto &partition = *partitions[partition_index.GetIndex()]; - auto &partition_pin_state = *state.partition_pin_states[partition_index.GetIndex()]; - - state.chunk_state.heap_sizes.Reference(input.heap_sizes); - - const auto size_before = partition.SizeInBytes(); - partition.Build(partition_pin_state, state.chunk_state, 0, append_count); - data_size += partition.SizeInBytes() - size_before; - - partition.CopyRows(state.chunk_state, input, *FlatVector::IncrementalSelectionVector(), append_count); - } else { - // Build the buffer space - state.chunk_state.heap_sizes.Slice(input.heap_sizes, state.partition_sel, append_count); - state.chunk_state.heap_sizes.Flatten(append_count); - BuildBufferSpace(state); - - // Copy the rows - partitions[0]->CopyRows(state.chunk_state, input, state.partition_sel, append_count); - } - - count += append_count; - Verify(); -} - -void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, - const idx_t append_count) const { - if (UseFixedSizeMap()) { - BuildPartitionSel(state, append_sel, append_count); - } else { - BuildPartitionSel(state, append_sel, append_count); - } -} - -template -void PartitionedTupleData::BuildPartitionSel(PartitionedTupleDataAppendState &state, const SelectionVector &append_sel, - const idx_t append_count) { - using GETTER = TemplatedMapGetter; - auto &partition_entries = state.GetMap(); - const auto partition_indices = FlatVector::GetData(state.partition_indices); - partition_entries.clear(); - switch (state.partition_indices.GetVectorType()) { - case VectorType::FLAT_VECTOR: - for (idx_t i = 0; 
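[Editor's note] BuildPartitionSel is a counting sort over the computed partition indices: count rows per partition, convert the counts to prefix offsets, then scatter row indices so each partition's rows become contiguous in the selection vector. A standalone sketch (an unordered map here, so partition order differs from duckdb's fixed-size and ordered maps, but the grouping is the same):

#include <cstdint>
#include <unordered_map>
#include <vector>

std::vector<uint32_t> BuildPartitionSel(const std::vector<uint32_t> &partition_indices) {
    std::unordered_map<uint32_t, uint32_t> entries; // partition -> count, then offset
    for (auto idx : partition_indices) {
        entries[idx]++;
    }
    uint32_t offset = 0;
    for (auto &entry : entries) { // turn counts into starting offsets
        uint32_t length = entry.second;
        entry.second = offset;
        offset += length;
    }
    std::vector<uint32_t> partition_sel(partition_indices.size());
    for (uint32_t i = 0; i < partition_indices.size(); i++) {
        partition_sel[entries[partition_indices[i]]++] = i; // scatter
    }
    return partition_sel;
}

The templated fixed_size variant swaps the hash map for a flat array once the maximum partition index is below MAP_THRESHOLD, which is why UseFixedSizeMap gates both code paths.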
i < append_count; i++) { - const auto &partition_index = partition_indices[i]; - auto partition_entry = partition_entries.find(partition_index); - if (partition_entry == partition_entries.end()) { - partition_entries[partition_index] = list_entry_t(0, 1); - } else { - GETTER::GetValue(partition_entry).length++; - } - } - break; - case VectorType::CONSTANT_VECTOR: - partition_entries[partition_indices[0]] = list_entry_t(0, append_count); - break; - default: - throw InternalException("Unexpected VectorType in PartitionedTupleData::Append"); - } - - // Early out: check if everything belongs to a single partition - if (partition_entries.size() == 1) { - // This needs to be initialized, even if we go the short path here - for (sel_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - state.reverse_partition_sel[index] = i; - } - return; - } - - // Compute offsets from the counts - idx_t offset = 0; - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - auto &partition_entry = GETTER::GetValue(it); - partition_entry.offset = offset; - offset += partition_entry.length; - } - - // Now initialize a single selection vector that acts as a selection vector for every partition - auto &partition_sel = state.partition_sel; - auto &reverse_partition_sel = state.reverse_partition_sel; - for (idx_t i = 0; i < append_count; i++) { - const auto index = append_sel.get_index(i); - const auto &partition_index = partition_indices[i]; - auto &partition_offset = partition_entries[partition_index].offset; - reverse_partition_sel[index] = UnsafeNumericCast(partition_offset); - partition_sel[partition_offset++] = UnsafeNumericCast(index); - } -} - -void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { - if (UseFixedSizeMap()) { - BuildBufferSpace(state); - } else { - BuildBufferSpace(state); - } -} - -template -void PartitionedTupleData::BuildBufferSpace(PartitionedTupleDataAppendState &state) { - using GETTER = TemplatedMapGetter; - const auto &partition_entries = state.GetMap(); - for (auto it = partition_entries.begin(); it != partition_entries.end(); ++it) { - const auto &partition_index = GETTER::GetKey(it); - - // Partition, pin state for this partition index - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - - // Length and offset for this partition - const auto &partition_entry = GETTER::GetValue(it); - const auto &partition_length = partition_entry.length; - const auto partition_offset = partition_entry.offset - partition_length; - - // Build out the buffer space for this partition - const auto size_before = partition.SizeInBytes(); - partition.Build(partition_pin_state, state.chunk_state, partition_offset, partition_length); - data_size += partition.SizeInBytes() - size_before; - } -} - -void PartitionedTupleData::FlushAppendState(PartitionedTupleDataAppendState &state) { - for (idx_t partition_index = 0; partition_index < partitions.size(); partition_index++) { - auto &partition = *partitions[partition_index]; - auto &partition_pin_state = *state.partition_pin_states[partition_index]; - partition.FinalizePinState(partition_pin_state); - } -} - -void PartitionedTupleData::Combine(PartitionedTupleData &other) { - if (other.Count() == 0) { - return; - } - - // Now combine the state's partitions into this - lock_guard guard(lock); - if (partitions.empty()) { - // This is the first merge, we just copy them over - partitions = 
std::move(other.partitions); - } else { - D_ASSERT(partitions.size() == other.partitions.size()); - // Combine the append state's partitions into this PartitionedTupleData - for (idx_t i = 0; i < other.partitions.size(); i++) { - partitions[i]->Combine(*other.partitions[i]); - } - } - this->count += other.count; - this->data_size += other.data_size; - Verify(); -} - -void PartitionedTupleData::Reset() { - for (auto &partition : partitions) { - partition->Reset(); - } - this->count = 0; - this->data_size = 0; - Verify(); -} - -void PartitionedTupleData::Repartition(PartitionedTupleData &new_partitioned_data) { - D_ASSERT(layout.GetTypes() == new_partitioned_data.layout.GetTypes()); - - if (partitions.size() == new_partitioned_data.partitions.size()) { - new_partitioned_data.Combine(*this); - return; - } - - PartitionedTupleDataAppendState append_state; - new_partitioned_data.InitializeAppendState(append_state); - - for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { - auto &partition = *partitions[partition_idx]; - - if (partition.Count() > 0) { - TupleDataChunkIterator iterator(partition, TupleDataPinProperties::DESTROY_AFTER_DONE, true); - auto &chunk_state = iterator.GetChunkState(); - do { - new_partitioned_data.Append(append_state, chunk_state, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - - RepartitionFinalizeStates(*this, new_partitioned_data, append_state, partition_idx); - } - partitions[partition_idx]->Reset(); - } - new_partitioned_data.FlushAppendState(append_state); - - count = 0; - data_size = 0; - - Verify(); -} - -void PartitionedTupleData::Unpin() { - for (auto &partition : partitions) { - partition->Unpin(); - } -} - -unsafe_vector> &PartitionedTupleData::GetPartitions() { - return partitions; -} - -unique_ptr PartitionedTupleData::GetUnpartitioned() { - auto data_collection = std::move(partitions[0]); - partitions[0] = make_uniq(buffer_manager, layout); - - for (idx_t i = 1; i < partitions.size(); i++) { - data_collection->Combine(*partitions[i]); - } - count = 0; - data_size = 0; - - data_collection->Verify(); - Verify(); - - return data_collection; -} - -idx_t PartitionedTupleData::Count() const { - return count; -} - -idx_t PartitionedTupleData::SizeInBytes() const { - idx_t total_size = 0; - for (auto &partition : partitions) { - total_size += partition->SizeInBytes(); - } - return total_size; -} - -idx_t PartitionedTupleData::PartitionCount() const { - return partitions.size(); -} - -void PartitionedTupleData::GetSizesAndCounts(vector &partition_sizes, vector &partition_counts) const { - D_ASSERT(partition_sizes.size() == PartitionCount()); - D_ASSERT(partition_sizes.size() == partition_counts.size()); - for (idx_t i = 0; i < PartitionCount(); i++) { - auto &partition = *partitions[i]; - partition_sizes[i] += partition.SizeInBytes(); - partition_counts[i] += partition.Count(); - } -} - -void PartitionedTupleData::Verify() const { -#ifdef DEBUG - idx_t total_count = 0; - idx_t total_size = 0; - for (auto &partition : partitions) { - partition->Verify(); - total_count += partition->Count(); - total_size += partition->SizeInBytes(); - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -// LCOV_EXCL_START -string PartitionedTupleData::ToString() { - string result = - StringUtil::Format("PartitionedTupleData - [%llu Partitions, %llu Rows]\n", partitions.size(), Count()); - for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { - result += 
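[Editor's note] The Combine protocol above is worth restating: the first collection to merge donates its partitions wholesale (a move), and every later one merges pairwise under the lock. A shape-only sketch with stand-in types, not duckdb's classes:

#include <memory>
#include <mutex>
#include <vector>

struct Partition {
    size_t count = 0;
    void Combine(Partition &other) { count += other.count; other.count = 0; }
};

struct Partitioned {
    std::mutex lock;
    std::vector<std::unique_ptr<Partition>> partitions;

    void Combine(Partitioned &other) {
        std::lock_guard<std::mutex> guard(lock);
        if (partitions.empty()) {
            partitions = std::move(other.partitions); // first merge: steal
        } else {
            for (size_t i = 0; i < other.partitions.size(); i++) {
                partitions[i]->Combine(*other.partitions[i]); // later: pairwise
            }
        }
    }
};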
StringUtil::Format("Partition %llu: ", partition_idx) + partitions[partition_idx]->ToString(); - } - return result; -} - -void PartitionedTupleData::Print() { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -void PartitionedTupleData::CreateAllocator() { - allocators->allocators.emplace_back(make_shared_ptr(buffer_manager, layout)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection.cpp b/src/duckdb/src/common/types/row/row_data_collection.cpp deleted file mode 100644 index b178b7fb5..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection.cpp +++ /dev/null @@ -1,141 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection.hpp" - -namespace duckdb { - -RowDataCollection::RowDataCollection(BufferManager &buffer_manager, idx_t block_capacity, idx_t entry_size, - bool keep_pinned) - : buffer_manager(buffer_manager), count(0), block_capacity(block_capacity), entry_size(entry_size), - keep_pinned(keep_pinned) { - D_ASSERT(block_capacity * entry_size + entry_size > buffer_manager.GetBlockSize()); -} - -idx_t RowDataCollection::AppendToBlock(RowDataBlock &block, BufferHandle &handle, - vector &append_entries, idx_t remaining, idx_t entry_sizes[]) { - idx_t append_count = 0; - data_ptr_t dataptr; - if (entry_sizes) { - D_ASSERT(entry_size == 1); - // compute how many entries fit if entry size is variable - dataptr = handle.Ptr() + block.byte_offset; - for (idx_t i = 0; i < remaining; i++) { - if (block.byte_offset + entry_sizes[i] > block.capacity) { - if (block.count == 0 && append_count == 0 && entry_sizes[i] > block.capacity) { - // special case: single entry is bigger than block capacity - // resize current block to fit the entry, append it, and move to the next block - block.capacity = entry_sizes[i]; - buffer_manager.ReAllocate(block.block, block.capacity); - dataptr = handle.Ptr(); - append_count++; - block.byte_offset += entry_sizes[i]; - } - break; - } - append_count++; - block.byte_offset += entry_sizes[i]; - } - } else { - append_count = MinValue(remaining, block.capacity - block.count); - dataptr = handle.Ptr() + block.count * entry_size; - } - append_entries.emplace_back(dataptr, append_count); - block.count += append_count; - return append_count; -} - -RowDataBlock &RowDataCollection::CreateBlock() { - blocks.push_back(make_uniq(MemoryTag::ORDER_BY, buffer_manager, block_capacity, entry_size)); - return *blocks.back(); -} - -vector RowDataCollection::Build(idx_t added_count, data_ptr_t key_locations[], idx_t entry_sizes[], - const SelectionVector *sel) { - vector handles; - vector append_entries; - - // first allocate space of where to serialize the keys and payload columns - idx_t remaining = added_count; - { - // first append to the last block (if any) - lock_guard append_lock(rdc_lock); - count += added_count; - - if (!blocks.empty()) { - auto &last_block = *blocks.back(); - if (last_block.count < last_block.capacity) { - // last block has space: pin the buffer of this block - auto handle = buffer_manager.Pin(last_block.block); - // now append to the block - idx_t append_count = AppendToBlock(last_block, handle, append_entries, remaining, entry_sizes); - remaining -= append_count; - handles.push_back(std::move(handle)); - } - } - while (remaining > 0) { - // now for the remaining data, allocate new buffers to store the data and append there - auto &new_block = CreateBlock(); - auto handle = buffer_manager.Pin(new_block.block); - - // offset the entry sizes array if we have added entries already - idx_t 
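[Editor's note] AppendToBlock above packs variable-size entries until one no longer fits, with one special case: if the very first entry of an empty block exceeds the block's capacity, the block is grown to hold exactly that entry. A standalone sketch of the fitting logic (fixed-size entries take the simpler min(remaining, capacity - count) path shown above):

#include <cstddef>
#include <vector>

struct Block {
    size_t capacity;
    size_t byte_offset = 0;
    size_t count = 0;
};

size_t AppendToBlock(Block &block, const std::vector<size_t> &entry_sizes, size_t start) {
    size_t appended = 0;
    for (size_t i = start; i < entry_sizes.size(); i++) {
        if (block.byte_offset + entry_sizes[i] > block.capacity) {
            if (block.count == 0 && appended == 0 && entry_sizes[i] > block.capacity) {
                // oversize single entry: resize the block to fit exactly one
                block.capacity = entry_sizes[i];
                block.byte_offset = entry_sizes[i];
                block.count = 1;
                appended = 1;
            }
            break;
        }
        block.byte_offset += entry_sizes[i];
        block.count++;
        appended++;
    }
    return appended;
}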
*offset_entry_sizes = entry_sizes ? entry_sizes + added_count - remaining : nullptr; - - idx_t append_count = AppendToBlock(new_block, handle, append_entries, remaining, offset_entry_sizes); - D_ASSERT(new_block.count > 0); - remaining -= append_count; - - if (keep_pinned) { - pinned_blocks.push_back(std::move(handle)); - } else { - handles.push_back(std::move(handle)); - } - } - } - // now set up the key_locations based on the append entries - idx_t append_idx = 0; - for (auto &append_entry : append_entries) { - idx_t next = append_idx + append_entry.count; - if (entry_sizes) { - for (; append_idx < next; append_idx++) { - key_locations[append_idx] = append_entry.baseptr; - append_entry.baseptr += entry_sizes[append_idx]; - } - } else { - for (; append_idx < next; append_idx++) { - auto idx = sel->get_index(append_idx); - key_locations[idx] = append_entry.baseptr; - append_entry.baseptr += entry_size; - } - } - } - // return the unique pointers to the handles because they must stay pinned - return handles; -} - -void RowDataCollection::Merge(RowDataCollection &other) { - if (other.count == 0) { - return; - } - RowDataCollection temp(buffer_manager, buffer_manager.GetBlockSize(), 1); - { - // One lock at a time to avoid deadlocks - lock_guard read_lock(other.rdc_lock); - temp.count = other.count; - temp.block_capacity = other.block_capacity; - temp.entry_size = other.entry_size; - temp.blocks = std::move(other.blocks); - temp.pinned_blocks = std::move(other.pinned_blocks); - } - other.Clear(); - - lock_guard write_lock(rdc_lock); - count += temp.count; - block_capacity = MaxValue(block_capacity, temp.block_capacity); - entry_size = MaxValue(entry_size, temp.entry_size); - for (auto &block : temp.blocks) { - blocks.emplace_back(std::move(block)); - } - for (auto &handle : temp.pinned_blocks) { - pinned_blocks.emplace_back(std::move(handle)); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp b/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp deleted file mode 100644 index 9b3a4be06..000000000 --- a/src/duckdb/src/common/types/row/row_data_collection_scanner.cpp +++ /dev/null @@ -1,330 +0,0 @@ -#include "duckdb/common/types/row/row_data_collection_scanner.hpp" - -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/row_data_collection.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -void RowDataCollectionScanner::AlignHeapBlocks(RowDataCollection &swizzled_block_collection, - RowDataCollection &swizzled_string_heap, - RowDataCollection &block_collection, RowDataCollection &string_heap, - const RowLayout &layout) { - if (block_collection.count == 0) { - return; - } - - if (layout.AllConstant()) { - // No heap blocks! 
Just merge fixed-size data - swizzled_block_collection.Merge(block_collection); - return; - } - - // We create one heap block per data block and swizzle the pointers - D_ASSERT(string_heap.keep_pinned == swizzled_string_heap.keep_pinned); - auto &buffer_manager = block_collection.buffer_manager; - auto &heap_blocks = string_heap.blocks; - idx_t heap_block_idx = 0; - idx_t heap_block_remaining = heap_blocks[heap_block_idx]->count; - for (auto &data_block : block_collection.blocks) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - - // Pin the data block and swizzle the pointers within the rows - auto data_handle = buffer_manager.Pin(data_block->block); - auto data_ptr = data_handle.Ptr(); - if (!string_heap.keep_pinned) { - D_ASSERT(!data_block->block->IsSwizzled()); - RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - data_block->block->SetSwizzling(nullptr); - } - // At this point the data block is pinned and the heap pointer is valid - // so we can copy heap data as needed - - // We want to copy as little of the heap data as possible, check how the data and heap blocks line up - if (heap_block_remaining >= data_block->count) { - // Easy: current heap block contains all strings for this data block, just copy (reference) the block - swizzled_string_heap.blocks.emplace_back(heap_blocks[heap_block_idx]->Copy()); - swizzled_string_heap.blocks.back()->count = data_block->count; - - // Swizzle the heap pointer if we are not pinning the heap - auto &heap_block = swizzled_string_heap.blocks.back()->block; - auto heap_handle = buffer_manager.Pin(heap_block); - if (!swizzled_string_heap.keep_pinned) { - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block->count, - NumericCast(heap_offset)); - } else { - swizzled_string_heap.pinned_blocks.emplace_back(std::move(heap_handle)); - } - - // Update counter - heap_block_remaining -= data_block->count; - } else { - // Strings for this data block are spread over the current heap block and the next (and possibly more) - if (string_heap.keep_pinned) { - // The heap is changing underneath the data block, - // so swizzle the string pointers to make them portable. 
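// A minimal, self-contained sketch of the swizzling idea applied below: each
// fixed-size row embeds a pointer into a heap block, and swizzling rewrites that
// absolute address as an offset from the heap block's base, so both blocks can be
// unpinned, spilled, and reloaded at a different address. The helper is
// hypothetical; the real logic lives in RowOperations::SwizzleColumns and
// RowOperations::SwizzleHeapPointer.
#if 0 // illustrative sketch only
#include <cstddef>
#include <cstdint>
#include <cstring>

static void SwizzleSketch(uint8_t *rows, size_t row_width, size_t slot_offset,
                          const uint8_t *heap_base, size_t count) {
	for (size_t i = 0; i < count; i++) {
		uint8_t *slot = rows + i * row_width + slot_offset;
		uint8_t *ptr;
		std::memcpy(&ptr, slot, sizeof(ptr));        // read the absolute heap address
		uint64_t offset = uint64_t(ptr - heap_base); // distance from the block base
		std::memcpy(slot, &offset, sizeof(offset));  // overwrite pointer with offset
	}
}
// Unswizzling is the inverse: load the offset and store heap_base + offset back.
#endif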
- RowOperations::SwizzleColumns(layout, data_ptr, data_block->count); - } - idx_t data_block_remaining = data_block->count; - vector> ptrs_and_sizes; - idx_t total_size = 0; - const auto base_row_ptr = data_ptr; - while (data_block_remaining > 0) { - if (heap_block_remaining == 0) { - heap_block_remaining = heap_blocks[++heap_block_idx]->count; - } - auto next = MinValue(data_block_remaining, heap_block_remaining); - - // Figure out where to start copying strings, and how many bytes we need to copy - auto heap_start_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_end_ptr = - Load(data_ptr + layout.GetHeapOffset() + (next - 1) * layout.GetRowWidth()); - auto size = NumericCast(heap_end_ptr - heap_start_ptr + Load(heap_end_ptr)); - ptrs_and_sizes.emplace_back(heap_start_ptr, size); - D_ASSERT(size <= heap_blocks[heap_block_idx]->byte_offset); - - // Swizzle the heap pointer - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_start_ptr, next, total_size); - total_size += size; - - // Update where we are in the data and heap blocks - data_ptr += next * layout.GetRowWidth(); - data_block_remaining -= next; - heap_block_remaining -= next; - } - - // Finally, we allocate a new heap block and copy data to it - swizzled_string_heap.blocks.emplace_back(make_uniq( - MemoryTag::ORDER_BY, buffer_manager, MaxValue(total_size, buffer_manager.GetBlockSize()), 1U)); - auto new_heap_handle = buffer_manager.Pin(swizzled_string_heap.blocks.back()->block); - auto new_heap_ptr = new_heap_handle.Ptr(); - for (auto &ptr_and_size : ptrs_and_sizes) { - memcpy(new_heap_ptr, ptr_and_size.first, ptr_and_size.second); - new_heap_ptr += ptr_and_size.second; - } - new_heap_ptr = new_heap_handle.Ptr(); - if (swizzled_string_heap.keep_pinned) { - // Since the heap blocks are pinned, we can unswizzle the data again. 
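// The ptrs_and_sizes loop above implements a generic consolidation pattern:
// collect (pointer, size) runs, then concatenate them into one fresh allocation
// while keeping a running offset -- the same offset the rows were swizzled to.
// A condensed sketch of that pattern, with hypothetical names:
#if 0 // illustrative sketch only
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>
#include <vector>

static std::vector<uint8_t> Consolidate(const std::vector<std::pair<const uint8_t *, size_t>> &runs) {
	size_t total = 0;
	for (const auto &run : runs) {
		total += run.second;
	}
	std::vector<uint8_t> heap(total);
	size_t offset = 0; // rows referencing a run were swizzled to this offset
	for (const auto &run : runs) {
		std::memcpy(heap.data() + offset, run.first, run.second);
		offset += run.second;
	}
	return heap;
}
#endif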
- swizzled_string_heap.pinned_blocks.emplace_back(std::move(new_heap_handle)); - RowOperations::UnswizzlePointers(layout, base_row_ptr, new_heap_ptr, data_block->count); - RowOperations::UnswizzleHeapPointer(layout, base_row_ptr, new_heap_ptr, data_block->count); - } - } - } - - // We're done with variable-sized data, now just merge the fixed-size data - swizzled_block_collection.Merge(block_collection); - D_ASSERT(swizzled_block_collection.blocks.size() == swizzled_string_heap.blocks.size()); - - // Update counts and cleanup - swizzled_string_heap.count = string_heap.count; - string_heap.Clear(); -} - -void RowDataCollectionScanner::ScanState::PinData() { - auto &rows = scanner.rows; - D_ASSERT(block_idx < rows.blocks.size()); - auto &data_block = rows.blocks[block_idx]; - if (!data_handle.IsValid() || data_handle.GetBlockHandle() != data_block->block) { - data_handle = rows.buffer_manager.Pin(data_block->block); - } - if (scanner.layout.AllConstant() || !scanner.external) { - return; - } - - auto &heap = scanner.heap; - D_ASSERT(block_idx < heap.blocks.size()); - auto &heap_block = heap.blocks[block_idx]; - if (!heap_handle.IsValid() || heap_handle.GetBlockHandle() != heap_block->block) { - heap_handle = heap.buffer_manager.Pin(heap_block->block); - } -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - ValidateUnscannedBlock(); -} - -RowDataCollectionScanner::RowDataCollectionScanner(RowDataCollection &rows_p, RowDataCollection &heap_p, - const RowLayout &layout_p, bool external_p, idx_t block_idx, - bool flush_p) - : rows(rows_p), heap(heap_p), layout(layout_p), read_state(*this), total_count(rows.count), total_scanned(0), - external(external_p), flush(flush_p), unswizzling(!layout.AllConstant() && external && !heap.keep_pinned) { - - if (unswizzling) { - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - } - - D_ASSERT(block_idx < rows.blocks.size()); - read_state.block_idx = block_idx; - read_state.entry_idx = 0; - - // Pretend that we have scanned up to the start block - // and will stop at the end - auto begin = rows.blocks.begin(); - auto end = begin + NumericCast(block_idx); - total_scanned = - std::accumulate(begin, end, idx_t(0), [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - total_count = total_scanned + (*end)->count; - - ValidateUnscannedBlock(); -} - -void RowDataCollectionScanner::SwizzleBlockInternal(RowDataBlock &data_block, RowDataBlock &heap_block) { - // Pin the data block and swizzle the pointers within the rows - D_ASSERT(!data_block.block->IsSwizzled()); - auto data_handle = rows.buffer_manager.Pin(data_block.block); - auto data_ptr = data_handle.Ptr(); - RowOperations::SwizzleColumns(layout, data_ptr, data_block.count); - data_block.block->SetSwizzling(nullptr); - - // Swizzle the heap pointers - auto heap_handle = heap.buffer_manager.Pin(heap_block.block); - auto heap_ptr = Load(data_ptr + layout.GetHeapOffset()); - auto heap_offset = heap_ptr - heap_handle.Ptr(); - RowOperations::SwizzleHeapPointer(layout, data_ptr, heap_ptr, data_block.count, NumericCast(heap_offset)); -} - -void RowDataCollectionScanner::SwizzleBlock(idx_t 
block_idx) { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - auto &data_block = rows.blocks[block_idx]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[block_idx]); - } -} - -void RowDataCollectionScanner::ReSwizzle() { - if (rows.count == 0) { - return; - } - - if (!unswizzling) { - // No swizzled blocks! - return; - } - - D_ASSERT(rows.blocks.size() == heap.blocks.size()); - for (idx_t i = 0; i < rows.blocks.size(); ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } -} - -void RowDataCollectionScanner::ValidateUnscannedBlock() const { - if (unswizzling && read_state.block_idx < rows.blocks.size() && Remaining()) { - D_ASSERT(rows.blocks[read_state.block_idx]->block->IsSwizzled()); - } -} - -void RowDataCollectionScanner::Scan(DataChunk &chunk) { - auto count = MinValue((idx_t)STANDARD_VECTOR_SIZE, total_count - total_scanned); - if (count == 0) { - chunk.SetCardinality(count); - return; - } - - // Only flush blocks we processed. - const auto flush_block_idx = read_state.block_idx; - - const idx_t &row_width = layout.GetRowWidth(); - // Set up a batch of pointers to scan data from - idx_t scanned = 0; - auto data_pointers = FlatVector::GetData(addresses); - - // We must pin ALL blocks we are going to gather from - vector pinned_blocks; - while (scanned < count) { - read_state.PinData(); - auto &data_block = rows.blocks[read_state.block_idx]; - idx_t next = MinValue(data_block->count - read_state.entry_idx, count - scanned); - const data_ptr_t data_ptr = read_state.data_handle.Ptr() + read_state.entry_idx * row_width; - // Set up the next pointers - data_ptr_t row_ptr = data_ptr; - for (idx_t i = 0; i < next; i++) { - data_pointers[scanned + i] = row_ptr; - row_ptr += row_width; - } - // Unswizzle the offsets back to pointers (if needed) - if (unswizzling) { - RowOperations::UnswizzlePointers(layout, data_ptr, read_state.heap_handle.Ptr(), next); - rows.blocks[read_state.block_idx]->block->SetSwizzling("RowDataCollectionScanner::Scan"); - } - // Update state indices - read_state.entry_idx += next; - scanned += next; - total_scanned += next; - if (read_state.entry_idx == data_block->count) { - // Pin completed blocks so we don't lose them - pinned_blocks.emplace_back(rows.buffer_manager.Pin(data_block->block)); - if (unswizzling) { - auto &heap_block = heap.blocks[read_state.block_idx]; - pinned_blocks.emplace_back(heap.buffer_manager.Pin(heap_block->block)); - } - read_state.block_idx++; - read_state.entry_idx = 0; - ValidateUnscannedBlock(); - } - } - D_ASSERT(scanned == count); - // Deserialize the payload data - for (idx_t col_no = 0; col_no < layout.ColumnCount(); col_no++) { - RowOperations::Gather(addresses, *FlatVector::IncrementalSelectionVector(), chunk.data[col_no], - *FlatVector::IncrementalSelectionVector(), count, layout, col_no); - } - chunk.SetCardinality(count); - chunk.Verify(); - - // Switch to a new set of pinned blocks - read_state.pinned_blocks.swap(pinned_blocks); - - if (flush) { - // Release blocks we have passed. - for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - rows.blocks[i]->block = nullptr; - if (unswizzling) { - heap.blocks[i]->block = nullptr; - } - } - } else if (unswizzling) { - // Reswizzle blocks we have passed so they can be flushed safely. 
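// Two policies govern blocks the scanner has finished with: flush=true drops the
// block handles entirely (a destructive single-pass scan), while flush=false
// re-swizzles them so they can be evicted and scanned again later. A hedged
// usage sketch of the destructive variant, assuming a populated
// RowDataCollection `rows`/`heap` and a matching RowLayout `layout`:
#if 0 // illustrative sketch only
DataChunk chunk;
chunk.Initialize(Allocator::DefaultAllocator(), layout.GetTypes());
RowDataCollectionScanner scanner(rows, heap, layout, /*external=*/true, /*flush=*/true);
while (scanner.Remaining() > 0) {
	scanner.Scan(chunk); // unswizzles the pinned block, gathers it, then releases it
	// ... consume chunk ...
}
#endif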
- for (idx_t i = flush_block_idx; i < read_state.block_idx; ++i) { - auto &data_block = rows.blocks[i]; - if (data_block->block && !data_block->block->IsSwizzled()) { - SwizzleBlockInternal(*data_block, *heap.blocks[i]); - } - } - } -} - -void RowDataCollectionScanner::Reset(bool flush_p) { - flush = flush_p; - total_scanned = 0; - - read_state.block_idx = 0; - read_state.entry_idx = 0; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/row_layout.cpp b/src/duckdb/src/common/types/row/row_layout.cpp deleted file mode 100644 index 3add8e425..000000000 --- a/src/duckdb/src/common/types/row/row_layout.cpp +++ /dev/null @@ -1,62 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/common/types/row_layout.cpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/common/types/row/row_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -RowLayout::RowLayout() : flag_width(0), data_width(0), row_width(0), all_constant(true), heap_pointer_offset(0) { -} - -void RowLayout::Initialize(vector types_p, bool align) { - offsets.clear(); - types = std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (const auto &type : types) { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - - // This enables pointer swizzling for out-of-core computation. - if (!all_constant) { - // When unswizzled, the pointer lives here. - // When swizzled, the pointer is replaced by an offset. - heap_pointer_offset = row_width; - // The 8 byte pointer will be replaced with an 8 byte idx_t when swizzled. - // However, this cannot be sizeof(data_ptr_t), since 32 bit builds use 4 byte pointers. - row_width += sizeof(idx_t); - } - - // Data columns. No alignment required. - for (const auto &type : types) { - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). 
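// A worked example of this computation, assuming a 16-byte string_t: for the
// types (INT32, VARCHAR, BIGINT),
//   flag_width          = (3 + 7) / 8 = 1 byte of validity bits
//   heap_pointer_offset = 1, and the heap pointer slot brings row_width to 9
//   offsets             = {9, 13, 29}   (INT32 +4, VARCHAR +16, BIGINT +8)
//   row_width           = 37, data_width = 36, and AlignValue(37) = 40 if align is set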
- row_width += sizeof(idx_t); - } - } - - data_width = row_width - flag_width; - - // Alignment padding for the next row - if (align) { - row_width = AlignValue(row_width); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp b/src/duckdb/src/common/types/row/tuple_data_allocator.cpp deleted file mode 100644 index ed9145a11..000000000 --- a/src/duckdb/src/common/types/row/tuple_data_allocator.cpp +++ /dev/null @@ -1,528 +0,0 @@ -#include "duckdb/common/types/row/tuple_data_allocator.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/types/row/tuple_data_segment.hpp" -#include "duckdb/common/types/row/tuple_data_states.hpp" -#include "duckdb/storage/buffer/block_handle.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -TupleDataBlock::TupleDataBlock(BufferManager &buffer_manager, idx_t capacity_p) : capacity(capacity_p), size(0) { - auto buffer_handle = buffer_manager.Allocate(MemoryTag::HASH_TABLE, capacity, false); - handle = buffer_handle.GetBlockHandle(); -} - -TupleDataBlock::TupleDataBlock(TupleDataBlock &&other) noexcept : capacity(0), size(0) { - std::swap(handle, other.handle); - std::swap(capacity, other.capacity); - std::swap(size, other.size); -} - -TupleDataBlock &TupleDataBlock::operator=(TupleDataBlock &&other) noexcept { - std::swap(handle, other.handle); - std::swap(capacity, other.capacity); - std::swap(size, other.size); - return *this; -} - -TupleDataAllocator::TupleDataAllocator(BufferManager &buffer_manager, const TupleDataLayout &layout) - : buffer_manager(buffer_manager), layout(layout.Copy()) { -} - -TupleDataAllocator::TupleDataAllocator(TupleDataAllocator &allocator) - : buffer_manager(allocator.buffer_manager), layout(allocator.layout.Copy()) { -} - -void TupleDataAllocator::SetDestroyBufferUponUnpin() { - for (auto &block : row_blocks) { - if (block.handle) { - block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); - } - } - for (auto &block : heap_blocks) { - if (block.handle) { - block.handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); - } - } -} - -TupleDataAllocator::~TupleDataAllocator() { - SetDestroyBufferUponUnpin(); -} - -BufferManager &TupleDataAllocator::GetBufferManager() { - return buffer_manager; -} - -Allocator &TupleDataAllocator::GetAllocator() { - return buffer_manager.GetBufferAllocator(); -} - -const TupleDataLayout &TupleDataAllocator::GetLayout() const { - return layout; -} - -idx_t TupleDataAllocator::RowBlockCount() const { - return row_blocks.size(); -} - -idx_t TupleDataAllocator::HeapBlockCount() const { - return heap_blocks.size(); -} - -void TupleDataAllocator::SetPartitionIndex(const idx_t index) { - D_ASSERT(!partition_index.IsValid()); - D_ASSERT(row_blocks.empty() && heap_blocks.empty()); - partition_index = index; -} - -void TupleDataAllocator::Build(TupleDataSegment &segment, TupleDataPinState &pin_state, - TupleDataChunkState &chunk_state, const idx_t append_offset, const idx_t append_count) { - D_ASSERT(this == segment.allocator.get()); - auto &chunks = segment.chunks; - if (!chunks.empty()) { - ReleaseOrStoreHandles(pin_state, segment, chunks.back(), true); - } - - // Build the chunk parts for the incoming data - chunk_part_indices.clear(); - idx_t offset = 0; - while (offset != append_count) { - if (chunks.empty() || chunks.back().count == STANDARD_VECTOR_SIZE) { - chunks.emplace_back(); - } - auto &chunk = 
chunks.back(); - - // Build the next part - auto next = MinValue(append_count - offset, STANDARD_VECTOR_SIZE - chunk.count); - chunk.AddPart(BuildChunkPart(pin_state, chunk_state, append_offset + offset, next, chunk), layout); - auto &chunk_part = chunk.parts.back(); - next = chunk_part.count; - - segment.count += next; - segment.data_size += chunk_part.count * layout.GetRowWidth(); - if (!layout.AllConstant()) { - segment.data_size += chunk_part.total_heap_size; - } - - if (layout.HasDestructor()) { - const auto base_row_ptr = GetRowPointer(pin_state, chunk_part); - for (auto &aggr_idx : layout.GetAggregateDestructorIndices()) { - const auto aggr_offset = layout.GetOffsets()[layout.ColumnCount() + aggr_idx]; - auto &aggr_fun = layout.GetAggregates()[aggr_idx]; - for (idx_t i = 0; i < next; i++) { - duckdb::FastMemset(base_row_ptr + i * layout.GetRowWidth() + aggr_offset, '\0', - aggr_fun.payload_size); - } - } - } - - offset += next; - chunk_part_indices.emplace_back(chunks.size() - 1, chunk.parts.size() - 1); - } - - // Now initialize the pointers to write the data to - chunk_parts.clear(); - for (auto &indices : chunk_part_indices) { - chunk_parts.emplace_back(segment.chunks[indices.first].parts[indices.second]); - } - InitializeChunkStateInternal(pin_state, chunk_state, append_offset, false, true, false, chunk_parts); - - // To reduce metadata, we try to merge chunk parts where possible - // Due to the way chunk parts are constructed, only the last part of the first chunk is eligible for merging - segment.chunks[chunk_part_indices[0].first].MergeLastChunkPart(layout); - - segment.Verify(); -} - -TupleDataChunkPart TupleDataAllocator::BuildChunkPart(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count, - TupleDataChunk &chunk) { - D_ASSERT(append_count != 0); - TupleDataChunkPart result(*chunk.lock); - const auto block_size = buffer_manager.GetBlockSize(); - - // Allocate row block (if needed) - if (row_blocks.empty() || row_blocks.back().RemainingCapacity() < layout.GetRowWidth()) { - row_blocks.emplace_back(buffer_manager, block_size); - if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits - row_blocks.back().handle->SetEvictionQueueIndex(RadixPartitioning::RadixBits(partition_index.GetIndex())); - } - } - result.row_block_index = NumericCast(row_blocks.size() - 1); - auto &row_block = row_blocks[result.row_block_index]; - result.row_block_offset = NumericCast(row_block.size); - - // Set count (might be reduced later when checking heap space) - result.count = NumericCast(MinValue(row_block.RemainingCapacity(layout.GetRowWidth()), append_count)); - if (!layout.AllConstant()) { - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - - // Compute total heap size first - idx_t total_heap_size = 0; - for (idx_t i = 0; i < result.count; i++) { - const auto &heap_size = heap_sizes[append_offset + i]; - total_heap_size += heap_size; - } - - if (total_heap_size == 0) { - result.SetHeapEmpty(); - } else { - const auto heap_remaining = MaxValue( - heap_blocks.empty() ? 
block_size : heap_blocks.back().RemainingCapacity(), heap_sizes[append_offset]); - - if (total_heap_size <= heap_remaining) { - // Everything fits - result.total_heap_size = NumericCast(total_heap_size); - } else { - // Not everything fits - determine how many we can read next - result.total_heap_size = 0; - for (idx_t i = 0; i < result.count; i++) { - const auto &heap_size = heap_sizes[append_offset + i]; - if (result.total_heap_size + heap_size > heap_remaining) { - result.count = NumericCast(i); - break; - } - result.total_heap_size += heap_size; - } - } - - if (result.total_heap_size == 0) { - result.SetHeapEmpty(); - } else { - // Allocate heap block (if needed) - if (heap_blocks.empty() || heap_blocks.back().RemainingCapacity() < heap_sizes[append_offset]) { - const auto size = MaxValue(block_size, heap_sizes[append_offset]); - heap_blocks.emplace_back(buffer_manager, size); - if (partition_index.IsValid()) { // Set the eviction queue index logarithmically using RadixBits - heap_blocks.back().handle->SetEvictionQueueIndex( - RadixPartitioning::RadixBits(partition_index.GetIndex())); - } - } - result.heap_block_index = NumericCast(heap_blocks.size() - 1); - auto &heap_block = heap_blocks[result.heap_block_index]; - result.heap_block_offset = NumericCast(heap_block.size); - - // Mark this portion of the heap block as filled and set the pointer - heap_block.size += result.total_heap_size; - result.base_heap_ptr = GetBaseHeapPointer(pin_state, result); - } - } - } - D_ASSERT(result.count != 0 && result.count <= STANDARD_VECTOR_SIZE); - - // Mark this portion of the row block as filled - row_block.size += result.count * layout.GetRowWidth(); - - return result; -} - -void TupleDataAllocator::InitializeChunkState(TupleDataSegment &segment, TupleDataPinState &pin_state, - TupleDataChunkState &chunk_state, idx_t chunk_idx, bool init_heap) { - D_ASSERT(this == segment.allocator.get()); - D_ASSERT(chunk_idx < segment.ChunkCount()); - auto &chunk = segment.chunks[chunk_idx]; - - // Release or store any handles that are no longer required: - // We can't release the heap here if the current chunk's heap_block_ids is empty, because if we are iterating with - // PinProperties::DESTROY_AFTER_DONE, we might destroy a heap block that is needed by a later chunk, e.g., - // when chunk 0 needs heap block 0, chunk 1 does not need any heap blocks, and chunk 2 needs heap block 0 again - ReleaseOrStoreHandles(pin_state, segment, chunk, !chunk.heap_block_ids.empty()); - - unsafe_vector> parts; - parts.reserve(chunk.parts.size()); - for (auto &part : chunk.parts) { - parts.emplace_back(part); - } - - InitializeChunkStateInternal(pin_state, chunk_state, 0, true, init_heap, init_heap, parts); -} - -static inline void InitializeHeapSizes(const data_ptr_t row_locations[], idx_t heap_sizes[], const idx_t offset, - const idx_t next, const TupleDataChunkPart &part, const idx_t heap_size_offset) { - // Read the heap sizes from the rows - for (idx_t i = 0; i < next; i++) { - auto idx = offset + i; - heap_sizes[idx] = Load(row_locations[idx] + heap_size_offset); - } - - // Verify total size -#ifdef DEBUG - idx_t total_heap_size = 0; - for (idx_t i = 0; i < next; i++) { - auto idx = offset + i; - total_heap_size += heap_sizes[idx]; - } - D_ASSERT(total_heap_size == part.total_heap_size); -#endif -} - -void TupleDataAllocator::InitializeChunkStateInternal(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - idx_t offset, bool recompute, bool init_heap_pointers, - bool init_heap_sizes, - unsafe_vector> &parts) 
{ - auto row_locations = FlatVector::GetData(chunk_state.row_locations); - auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); - - for (auto &part_ref : parts) { - auto &part = part_ref.get(); - const auto next = part.count; - - // Set up row locations for the scan - const auto row_width = layout.GetRowWidth(); - const auto base_row_ptr = GetRowPointer(pin_state, part); - for (idx_t i = 0; i < next; i++) { - row_locations[offset + i] = base_row_ptr + i * row_width; - } - - if (layout.AllConstant()) { // Can't have a heap - offset += next; - continue; - } - - if (part.total_heap_size == 0) { - if (init_heap_sizes) { // No heap, but we need the heap sizes - InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); - } - offset += next; - continue; - } - - // Check if heap block has changed - re-compute the pointers within each row if so - if (recompute && pin_state.properties != TupleDataPinProperties::ALREADY_PINNED) { - const auto new_base_heap_ptr = GetBaseHeapPointer(pin_state, part); - if (part.base_heap_ptr != new_base_heap_ptr) { - lock_guard guard(part.lock); - const auto old_base_heap_ptr = part.base_heap_ptr; - if (old_base_heap_ptr != new_base_heap_ptr) { - Vector old_heap_ptrs( - Value::POINTER(CastPointerToValue(old_base_heap_ptr + part.heap_block_offset))); - Vector new_heap_ptrs( - Value::POINTER(CastPointerToValue(new_base_heap_ptr + part.heap_block_offset))); - RecomputeHeapPointers(old_heap_ptrs, *ConstantVector::ZeroSelectionVector(), row_locations, - new_heap_ptrs, offset, next, layout, 0); - part.base_heap_ptr = new_base_heap_ptr; - } - } - } - - if (init_heap_sizes) { - InitializeHeapSizes(row_locations, heap_sizes, offset, next, part, layout.GetHeapSizeOffset()); - } - - if (init_heap_pointers) { - // Set the pointers where the heap data will be written (if needed) - heap_locations[offset] = part.base_heap_ptr + part.heap_block_offset; - for (idx_t i = 1; i < next; i++) { - auto idx = offset + i; - heap_locations[idx] = heap_locations[idx - 1] + heap_sizes[idx - 1]; - } - } - - offset += next; - } - D_ASSERT(offset <= STANDARD_VECTOR_SIZE); -} - -static inline void VerifyStrings(const TupleDataLayout &layout, const LogicalTypeId type_id, - const data_ptr_t row_locations[], const idx_t col_idx, const idx_t base_col_offset, - const idx_t col_offset, const idx_t offset, const idx_t count) { -#ifdef DEBUG - if (type_id != LogicalTypeId::VARCHAR) { - // Make sure we don't verify BLOB / AGGREGATE_STATE - return; - } - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - for (idx_t i = 0; i < count; i++) { - const auto &row_location = row_locations[offset + i] + base_col_offset; - ValidityBytes row_mask(row_location, layout.ColumnCount()); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - auto recomputed_string = Load(row_location + col_offset); - recomputed_string.Verify(); - } - } -#endif -} - -void TupleDataAllocator::RecomputeHeapPointers(Vector &old_heap_ptrs, const SelectionVector &old_heap_sel, - const data_ptr_t row_locations[], Vector &new_heap_ptrs, - const idx_t offset, const idx_t count, const TupleDataLayout &layout, - const idx_t base_col_offset) { - const auto old_heap_locations = FlatVector::GetData(old_heap_ptrs); - - UnifiedVectorFormat new_heap_data; - new_heap_ptrs.ToUnifiedFormat(offset + count, new_heap_data); - const auto new_heap_locations = 
UnifiedVectorFormat::GetData(new_heap_data); - const auto new_heap_sel = *new_heap_data.sel; - - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - const auto &col_offset = layout.GetOffsets()[col_idx]; - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto &type = layout.GetTypes()[col_idx]; - switch (type.InternalType()) { - case PhysicalType::VARCHAR: { - for (idx_t i = 0; i < count; i++) { - const auto idx = offset + i; - const auto &row_location = row_locations[idx] + base_col_offset; - ValidityBytes row_mask(row_location, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - continue; - } - - const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; - const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; - - const auto string_location = row_location + col_offset; - if (Load(string_location) > string_t::INLINE_LENGTH) { - const auto string_ptr_location = string_location + string_t::HEADER_SIZE; - const auto string_ptr = Load(string_ptr_location); - const auto diff = string_ptr - old_heap_ptr; - D_ASSERT(diff >= 0); - Store(new_heap_ptr + diff, string_ptr_location); - } - } - VerifyStrings(layout, type.id(), row_locations, col_idx, base_col_offset, col_offset, offset, count); - break; - } - case PhysicalType::LIST: - case PhysicalType::ARRAY: { - for (idx_t i = 0; i < count; i++) { - const auto idx = offset + i; - const auto &row_location = row_locations[idx] + base_col_offset; - ValidityBytes row_mask(row_location, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - continue; - } - - const auto &old_heap_ptr = old_heap_locations[old_heap_sel.get_index(idx)]; - const auto &new_heap_ptr = new_heap_locations[new_heap_sel.get_index(idx)]; - - const auto &list_ptr_location = row_location + col_offset; - const auto list_ptr = Load(list_ptr_location); - const auto diff = list_ptr - old_heap_ptr; - D_ASSERT(diff >= 0); - Store(new_heap_ptr + diff, list_ptr_location); - } - break; - } - case PhysicalType::STRUCT: { - const auto &struct_layout = layout.GetStructLayout(col_idx); - if (!struct_layout.AllConstant()) { - RecomputeHeapPointers(old_heap_ptrs, old_heap_sel, row_locations, new_heap_ptrs, offset, count, - struct_layout, base_col_offset + col_offset); - } - break; - } - default: - continue; - } - } -} - -void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment, - TupleDataChunk &chunk, bool release_heap) { - D_ASSERT(this == segment.allocator.get()); - ReleaseOrStoreHandlesInternal(segment, segment.pinned_row_handles, pin_state.row_handles, chunk.row_block_ids, - row_blocks, pin_state.properties); - if (!layout.AllConstant() && release_heap) { - ReleaseOrStoreHandlesInternal(segment, segment.pinned_heap_handles, pin_state.heap_handles, - chunk.heap_block_ids, heap_blocks, pin_state.properties); - } -} - -void TupleDataAllocator::ReleaseOrStoreHandles(TupleDataPinState &pin_state, TupleDataSegment &segment) { - static TupleDataChunk DUMMY_CHUNK; - ReleaseOrStoreHandles(pin_state, segment, DUMMY_CHUNK, true); -} - -void TupleDataAllocator::ReleaseOrStoreHandlesInternal( - TupleDataSegment &segment, unsafe_vector &pinned_handles, perfect_map_t &handles, - const perfect_set_t &block_ids, unsafe_vector &blocks, TupleDataPinProperties properties) { - bool found_handle; - do { 
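// The found_handle flag below restarts the scan after every erase, since erasing
// invalidates the iterator being used. The same idiom in generic form (with
// C++11's erase-returning-iterator one could write `it = handles.erase(it)`
// instead):
#if 0 // illustrative sketch only
#include <map>

template <class MAP, class PREDICATE>
static void EraseIf(MAP &map, PREDICATE predicate) {
	bool erased;
	do {
		erased = false;
		for (auto it = map.begin(); it != map.end(); ++it) {
			if (predicate(*it)) {
				map.erase(it); // 'it' is now invalid: restart the scan
				erased = true;
				break;
			}
		}
	} while (erased);
}
#endif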
- found_handle = false; - for (auto it = handles.begin(); it != handles.end(); it++) { - const auto block_id = it->first; - if (block_ids.find(block_id) != block_ids.end()) { - // still required: do not release - continue; - } - switch (properties) { - case TupleDataPinProperties::KEEP_EVERYTHING_PINNED: { - lock_guard guard(segment.pinned_handles_lock); - const auto block_count = block_id + 1; - if (block_count > pinned_handles.size()) { - pinned_handles.resize(block_count); - } - pinned_handles[block_id] = std::move(it->second); - break; - } - case TupleDataPinProperties::UNPIN_AFTER_DONE: - case TupleDataPinProperties::ALREADY_PINNED: - break; - case TupleDataPinProperties::DESTROY_AFTER_DONE: - // Prevent it from being added to the eviction queue - blocks[block_id].handle->SetDestroyBufferUpon(DestroyBufferUpon::UNPIN); - // Destroy - blocks[block_id].handle.reset(); - break; - default: - D_ASSERT(properties == TupleDataPinProperties::INVALID); - throw InternalException("Encountered TupleDataPinProperties::INVALID"); - } - handles.erase(it); - found_handle = true; - break; - } - } while (found_handle); -} - -BufferHandle &TupleDataAllocator::PinRowBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - const auto &row_block_index = part.row_block_index; - auto it = pin_state.row_handles.find(row_block_index); - if (it == pin_state.row_handles.end()) { - D_ASSERT(row_block_index < row_blocks.size()); - auto &row_block = row_blocks[row_block_index]; - D_ASSERT(row_block.handle); - D_ASSERT(part.row_block_offset < row_block.size); - D_ASSERT(part.row_block_offset + part.count * layout.GetRowWidth() <= row_block.size); - it = pin_state.row_handles.emplace(row_block_index, buffer_manager.Pin(row_block.handle)).first; - } - return it->second; -} - -BufferHandle &TupleDataAllocator::PinHeapBlock(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - const auto &heap_block_index = part.heap_block_index; - auto it = pin_state.heap_handles.find(heap_block_index); - if (it == pin_state.heap_handles.end()) { - D_ASSERT(heap_block_index < heap_blocks.size()); - auto &heap_block = heap_blocks[heap_block_index]; - D_ASSERT(heap_block.handle); - D_ASSERT(part.heap_block_offset < heap_block.size); - D_ASSERT(part.heap_block_offset + part.total_heap_size <= heap_block.size); - it = pin_state.heap_handles.emplace(heap_block_index, buffer_manager.Pin(heap_block.handle)).first; - } - return it->second; -} - -data_ptr_t TupleDataAllocator::GetRowPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - return PinRowBlock(pin_state, part).Ptr() + part.row_block_offset; -} - -data_ptr_t TupleDataAllocator::GetBaseHeapPointer(TupleDataPinState &pin_state, const TupleDataChunkPart &part) { - return PinHeapBlock(pin_state, part).Ptr(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_collection.cpp b/src/duckdb/src/common/types/row/tuple_data_collection.cpp deleted file mode 100644 index 7db3ba37a..000000000 --- a/src/duckdb/src/common/types/row/tuple_data_collection.cpp +++ /dev/null @@ -1,611 +0,0 @@ -#include "duckdb/common/types/row/tuple_data_collection.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/type_visitor.hpp" -#include "duckdb/common/types/row/tuple_data_allocator.hpp" - -#include - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - 
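// A hedged end-to-end usage sketch for this class, assuming a BufferManager
// `buffer_manager`, a TupleDataLayout `layout` initialized for the input types,
// and a populated DataChunk `input`:
#if 0 // illustrative sketch only
TupleDataCollection collection(buffer_manager, layout);
collection.Append(input, *FlatVector::IncrementalSelectionVector(), input.size());

TupleDataScanState scan_state;
collection.InitializeScan(scan_state);
DataChunk output;
collection.InitializeChunk(output);
while (collection.Scan(scan_state, output)) {
	// ... consume output ...
}
#endif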
-TupleDataCollection::TupleDataCollection(BufferManager &buffer_manager, const TupleDataLayout &layout_p) - : layout(layout_p.Copy()), allocator(make_shared_ptr(buffer_manager, layout)) { - Initialize(); -} - -TupleDataCollection::TupleDataCollection(shared_ptr allocator) - : layout(allocator->GetLayout().Copy()), allocator(std::move(allocator)) { - Initialize(); -} - -TupleDataCollection::~TupleDataCollection() { -} - -void TupleDataCollection::Initialize() { - D_ASSERT(!layout.GetTypes().empty()); - this->count = 0; - this->data_size = 0; - scatter_functions.reserve(layout.ColumnCount()); - gather_functions.reserve(layout.ColumnCount()); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - auto &type = layout.GetTypes()[col_idx]; - scatter_functions.emplace_back(GetScatterFunction(type)); - gather_functions.emplace_back(GetGatherFunction(type)); - } -} - -void GetAllColumnIDsInternal(vector &column_ids, const idx_t column_count) { - column_ids.reserve(column_count); - for (idx_t col_idx = 0; col_idx < column_count; col_idx++) { - column_ids.emplace_back(col_idx); - } -} - -void TupleDataCollection::GetAllColumnIDs(vector &column_ids) { - GetAllColumnIDsInternal(column_ids, layout.ColumnCount()); -} - -const TupleDataLayout &TupleDataCollection::GetLayout() const { - return layout; -} - -const idx_t &TupleDataCollection::Count() const { - return count; -} - -idx_t TupleDataCollection::ChunkCount() const { - idx_t total_chunk_count = 0; - for (const auto &segment : segments) { - total_chunk_count += segment.ChunkCount(); - } - return total_chunk_count; -} - -idx_t TupleDataCollection::SizeInBytes() const { - idx_t total_size = 0; - for (const auto &segment : segments) { - total_size += segment.SizeInBytes(); - } - return total_size; -} - -void TupleDataCollection::Unpin() { - for (auto &segment : segments) { - segment.Unpin(); - } -} - -void TupleDataCollection::SetPartitionIndex(const idx_t index) { - D_ASSERT(!partition_index.IsValid()); - D_ASSERT(Count() == 0); - partition_index = index; - allocator->SetPartitionIndex(index); -} - -// LCOV_EXCL_START -void VerifyAppendColumns(const TupleDataLayout &layout, const vector &column_ids) { -#ifdef DEBUG - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - if (std::find(column_ids.begin(), column_ids.end(), col_idx) != column_ids.end()) { - continue; - } - // This column will not be appended in the first go - verify that it is fixed-size - we cannot resize heap after - const auto physical_type = layout.GetTypes()[col_idx].InternalType(); - D_ASSERT(physical_type != PhysicalType::VARCHAR && physical_type != PhysicalType::LIST && - physical_type != PhysicalType::ARRAY); - if (physical_type == PhysicalType::STRUCT) { - const auto &struct_layout = layout.GetStructLayout(col_idx); - vector struct_column_ids; - struct_column_ids.reserve(struct_layout.ColumnCount()); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - struct_column_ids.emplace_back(struct_col_idx); - } - VerifyAppendColumns(struct_layout, struct_column_ids); - } - } -#endif -} -// LCOV_EXCL_STOP - -void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, TupleDataPinProperties properties) { - vector column_ids; - GetAllColumnIDs(column_ids); - InitializeAppend(append_state, std::move(column_ids), properties); -} - -void TupleDataCollection::InitializeAppend(TupleDataAppendState &append_state, vector column_ids, - TupleDataPinProperties properties) { - 
VerifyAppendColumns(layout, column_ids); - InitializeAppend(append_state.pin_state, properties); - InitializeChunkState(append_state.chunk_state, std::move(column_ids)); -} - -void TupleDataCollection::InitializeAppend(TupleDataPinState &pin_state, TupleDataPinProperties properties) { - pin_state.properties = properties; - if (segments.empty()) { - segments.emplace_back(allocator); - } -} - -static void InitializeVectorFormat(vector &vector_data, const vector &types) { - vector_data.resize(types.size()); - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - const auto &child_list = StructType::GetChildTypes(type); - vector child_types; - child_types.reserve(child_list.size()); - for (const auto &child_entry : child_list) { - child_types.emplace_back(child_entry.second); - } - InitializeVectorFormat(vector_data[col_idx].children, child_types); - break; - } - case PhysicalType::LIST: - InitializeVectorFormat(vector_data[col_idx].children, {ListType::GetChildType(type)}); - break; - case PhysicalType::ARRAY: - InitializeVectorFormat(vector_data[col_idx].children, {ArrayType::GetChildType(type)}); - break; - default: - break; - } - } -} - -void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, vector column_ids) { - TupleDataCollection::InitializeChunkState(chunk_state, layout.GetTypes(), std::move(column_ids)); -} - -void TupleDataCollection::InitializeChunkState(TupleDataChunkState &chunk_state, const vector &types, - vector column_ids) { - if (column_ids.empty()) { - GetAllColumnIDsInternal(column_ids, types.size()); - } - InitializeVectorFormat(chunk_state.vector_data, types); - - chunk_state.cached_cast_vectors.clear(); - chunk_state.cached_cast_vector_cache.clear(); - for (auto &col : column_ids) { - auto &type = types[col]; - if (TypeVisitor::Contains(type, LogicalTypeId::ARRAY)) { - auto cast_type = ArrayType::ConvertToList(type); - chunk_state.cached_cast_vector_cache.push_back( - make_uniq(Allocator::DefaultAllocator(), cast_type)); - chunk_state.cached_cast_vectors.push_back(make_uniq(*chunk_state.cached_cast_vector_cache.back())); - } else { - chunk_state.cached_cast_vectors.emplace_back(); - chunk_state.cached_cast_vector_cache.emplace_back(); - } - } - - chunk_state.column_ids = std::move(column_ids); -} - -void TupleDataCollection::Append(DataChunk &new_chunk, const SelectionVector &append_sel, idx_t append_count) { - TupleDataAppendState append_state; - InitializeAppend(append_state); - Append(append_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(DataChunk &new_chunk, vector column_ids, const SelectionVector &append_sel, - const idx_t append_count) { - TupleDataAppendState append_state; - InitializeAppend(append_state, std::move(column_ids)); - Append(append_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(TupleDataAppendState &append_state, DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { - Append(append_state.pin_state, append_state.chunk_state, new_chunk, append_sel, append_count); -} - -void TupleDataCollection::Append(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { - TupleDataCollection::ToUnifiedFormat(chunk_state, new_chunk); - AppendUnified(pin_state, chunk_state, new_chunk, append_sel, append_count); -} - -void 
TupleDataCollection::AppendUnified(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state,
-                                        DataChunk &new_chunk, const SelectionVector &append_sel,
-                                        const idx_t append_count) {
-	const idx_t actual_append_count = append_count == DConstants::INVALID_INDEX ? new_chunk.size() : append_count;
-	if (actual_append_count == 0) {
-		return;
-	}
-
-	if (!layout.AllConstant()) {
-		TupleDataCollection::ComputeHeapSizes(chunk_state, new_chunk, append_sel, actual_append_count);
-	}
-
-	Build(pin_state, chunk_state, 0, actual_append_count);
-	Scatter(chunk_state, new_chunk, append_sel, actual_append_count);
-}
-
-static inline void ToUnifiedFormatInternal(TupleDataVectorFormat &format, Vector &vector, const idx_t count) {
-	vector.ToUnifiedFormat(count, format.unified);
-	format.original_sel = format.unified.sel;
-	format.original_owned_sel.Initialize(format.unified.owned_sel);
-	switch (vector.GetType().InternalType()) {
-	case PhysicalType::STRUCT: {
-		auto &entries = StructVector::GetEntries(vector);
-		D_ASSERT(format.children.size() == entries.size());
-		for (idx_t struct_col_idx = 0; struct_col_idx < entries.size(); struct_col_idx++) {
-			ToUnifiedFormatInternal(format.children[struct_col_idx], *entries[struct_col_idx], count);
-		}
-		break;
-	}
-	case PhysicalType::LIST:
-		D_ASSERT(format.children.size() == 1);
-		ToUnifiedFormatInternal(format.children[0], ListVector::GetEntry(vector), ListVector::GetListSize(vector));
-		break;
-	case PhysicalType::ARRAY: {
-		D_ASSERT(format.children.size() == 1);
-
-		// For arrays, we cheat a bit and pretend that they are lists by creating and assigning list_entry_t's to the
-		// vector. This allows us to reuse all the list serialization functions for array types too.
-		auto array_size = ArrayType::GetSize(vector.GetType());
-
-		// How many list_entry_t's do we need to cover the whole child array?
-		// Make sure we round up so it's all covered
-		auto child_array_total_size = ArrayVector::GetTotalSize(vector);
-		auto list_entry_t_count =
-		    MaxValue((child_array_total_size + array_size) / array_size, format.unified.validity.Capacity());
-
-		// Create list entries!
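// For a fixed-size ARRAY of size N, entry i gets {offset = i * N, length = N},
// which makes the array child indistinguishable from a list child to the
// scatter/gather code. For example, with array_size = 3 and four rows, the
// synthetic entries are {0, 3}, {3, 3}, {6, 3}, {9, 3}.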
- format.array_list_entries = make_unsafe_uniq_array(list_entry_t_count); - for (idx_t i = 0; i < list_entry_t_count; i++) { - format.array_list_entries[i].length = array_size; - format.array_list_entries[i].offset = i * array_size; - } - format.unified.data = reinterpret_cast(format.array_list_entries.get()); - - ToUnifiedFormatInternal(format.children[0], ArrayVector::GetEntry(vector), child_array_total_size); - break; - } - default: - break; - } -} - -void TupleDataCollection::ToUnifiedFormat(TupleDataChunkState &chunk_state, DataChunk &new_chunk) { - D_ASSERT(chunk_state.vector_data.size() >= chunk_state.column_ids.size()); // Needs InitializeAppend - for (const auto &col_idx : chunk_state.column_ids) { - ToUnifiedFormatInternal(chunk_state.vector_data[col_idx], new_chunk.data[col_idx], new_chunk.size()); - } -} - -void TupleDataCollection::GetVectorData(const TupleDataChunkState &chunk_state, UnifiedVectorFormat result[]) { - const auto &vector_data = chunk_state.vector_data; - for (idx_t i = 0; i < vector_data.size(); i++) { - const auto &source = vector_data[i].unified; - auto &target = result[i]; - target.sel = source.sel; - target.data = source.data; - target.validity = source.validity; - } -} - -void TupleDataCollection::Build(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const idx_t append_offset, const idx_t append_count) { - auto &segment = segments.back(); - const auto size_before = segment.SizeInBytes(); - segment.allocator->Build(segment, pin_state, chunk_state, append_offset, append_count); - data_size += segment.SizeInBytes() - size_before; - count += append_count; - Verify(); -} - -// LCOV_EXCL_START -void VerifyHeapSizes(const data_ptr_t source_locations[], const idx_t heap_sizes[], const SelectionVector &append_sel, - const idx_t append_count, const idx_t heap_size_offset) { -#ifdef DEBUG - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - const auto stored_heap_size = Load(source_locations[idx] + heap_size_offset); - D_ASSERT(stored_heap_size == heap_sizes[idx]); - } -#endif -} -// LCOV_EXCL_STOP - -void TupleDataCollection::CopyRows(TupleDataChunkState &chunk_state, TupleDataChunkState &input, - const SelectionVector &append_sel, const idx_t append_count) const { - const auto source_locations = FlatVector::GetData(input.row_locations); - const auto target_locations = FlatVector::GetData(chunk_state.row_locations); - - // Copy rows - const auto row_width = layout.GetRowWidth(); - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - FastMemcpy(target_locations[i], source_locations[idx], row_width); - } - - // Copy heap if we need to - if (!layout.AllConstant()) { - const auto source_heap_locations = FlatVector::GetData(input.heap_locations); - const auto target_heap_locations = FlatVector::GetData(chunk_state.heap_locations); - const auto heap_sizes = FlatVector::GetData(input.heap_sizes); - VerifyHeapSizes(source_locations, heap_sizes, append_sel, append_count, layout.GetHeapSizeOffset()); - - // Check if we need to copy anything at all - idx_t total_heap_size = 0; - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - total_heap_size += heap_sizes[idx]; - } - if (total_heap_size == 0) { - return; - } - - // Copy heap - for (idx_t i = 0; i < append_count; i++) { - auto idx = append_sel.get_index(i); - FastMemcpy(target_heap_locations[i], source_heap_locations[idx], heap_sizes[idx]); - } - - // Recompute pointers after copying the data - 
TupleDataAllocator::RecomputeHeapPointers(input.heap_locations, append_sel, target_locations, - chunk_state.heap_locations, 0, append_count, layout, 0); - } -} - -void TupleDataCollection::Combine(TupleDataCollection &other) { - if (other.count == 0) { - return; - } - if (this->layout.GetTypes() != other.GetLayout().GetTypes()) { - throw InternalException("Attempting to combine TupleDataCollection with mismatching types"); - } - this->segments.reserve(this->segments.size() + other.segments.size()); - for (auto &other_seg : other.segments) { - AddSegment(std::move(other_seg)); - } - other.Reset(); -} - -void TupleDataCollection::AddSegment(TupleDataSegment &&segment) { - count += segment.count; - data_size += segment.data_size; - segments.emplace_back(std::move(segment)); - Verify(); -} - -void TupleDataCollection::Combine(unique_ptr other) { - Combine(*other); -} - -void TupleDataCollection::Reset() { - count = 0; - data_size = 0; - segments.clear(); - - // Refreshes the TupleDataAllocator to prevent holding on to allocated data unnecessarily - allocator = make_shared_ptr(*allocator); -} - -void TupleDataCollection::InitializeChunk(DataChunk &chunk) const { - chunk.Initialize(allocator->GetAllocator(), layout.GetTypes()); -} - -void TupleDataCollection::InitializeChunk(DataChunk &chunk, const vector &columns) const { - vector chunk_types(columns.size()); - // keep the order of the columns - for (idx_t i = 0; i < columns.size(); i++) { - auto column_idx = columns[i]; - D_ASSERT(column_idx < layout.ColumnCount()); - chunk_types[i] = layout.GetTypes()[column_idx]; - } - chunk.Initialize(allocator->GetAllocator(), chunk_types); -} - -void TupleDataCollection::InitializeScanChunk(TupleDataScanState &state, DataChunk &chunk) const { - auto &column_ids = state.chunk_state.column_ids; - D_ASSERT(!column_ids.empty()); - vector chunk_types; - chunk_types.reserve(column_ids.size()); - for (idx_t i = 0; i < column_ids.size(); i++) { - auto column_idx = column_ids[i]; - D_ASSERT(column_idx < layout.ColumnCount()); - chunk_types.push_back(layout.GetTypes()[column_idx]); - } - chunk.Initialize(allocator->GetAllocator(), chunk_types); -} - -void TupleDataCollection::InitializeScan(TupleDataScanState &state, TupleDataPinProperties properties) const { - vector column_ids; - column_ids.reserve(layout.ColumnCount()); - for (idx_t i = 0; i < layout.ColumnCount(); i++) { - column_ids.push_back(i); - } - InitializeScan(state, std::move(column_ids), properties); -} - -void TupleDataCollection::InitializeScan(TupleDataScanState &state, vector column_ids, - TupleDataPinProperties properties) const { - state.pin_state.row_handles.clear(); - state.pin_state.heap_handles.clear(); - state.pin_state.properties = properties; - state.segment_index = 0; - state.chunk_index = 0; - - auto &chunk_state = state.chunk_state; - - for (auto &col : column_ids) { - auto &type = layout.GetTypes()[col]; - - if (TypeVisitor::Contains(type, LogicalTypeId::ARRAY)) { - auto cast_type = ArrayType::ConvertToList(type); - chunk_state.cached_cast_vector_cache.push_back( - make_uniq(Allocator::DefaultAllocator(), cast_type)); - chunk_state.cached_cast_vectors.push_back(make_uniq(*chunk_state.cached_cast_vector_cache.back())); - } else { - chunk_state.cached_cast_vectors.emplace_back(); - chunk_state.cached_cast_vector_cache.emplace_back(); - } - } - - state.chunk_state.column_ids = std::move(column_ids); -} - -void TupleDataCollection::InitializeScan(TupleDataParallelScanState &gstate, TupleDataPinProperties properties) const { - 
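// A hedged sketch of the intended parallel-scan pattern: one shared
// TupleDataParallelScanState (its internal lock guards the scan cursor) plus one
// TupleDataLocalScanState per worker thread, which keeps the per-thread pins:
#if 0 // illustrative sketch only
TupleDataParallelScanState gstate;
collection.InitializeScan(gstate);
// In each worker thread:
TupleDataLocalScanState lstate;
DataChunk local_chunk;
collection.InitializeScanChunk(gstate.scan_state, local_chunk);
while (collection.Scan(gstate, lstate, local_chunk)) {
	// ... process local_chunk ...
}
#endif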
InitializeScan(gstate.scan_state, properties); -} - -void TupleDataCollection::InitializeScan(TupleDataParallelScanState &state, vector column_ids, - TupleDataPinProperties properties) const { - InitializeScan(state.scan_state, std::move(column_ids), properties); -} - -bool TupleDataCollection::Scan(TupleDataScanState &state, DataChunk &result) { - const auto segment_index_before = state.segment_index; - idx_t segment_index; - idx_t chunk_index; - if (!NextScanIndex(state, segment_index, chunk_index)) { - if (!segments.empty()) { - FinalizePinState(state.pin_state, segments[segment_index_before]); - } - result.SetCardinality(0); - return false; - } - if (segment_index_before != DConstants::INVALID_INDEX && segment_index != segment_index_before) { - FinalizePinState(state.pin_state, segments[segment_index_before]); - } - ScanAtIndex(state.pin_state, state.chunk_state, state.chunk_state.column_ids, segment_index, chunk_index, result); - return true; -} - -bool TupleDataCollection::Scan(TupleDataParallelScanState &gstate, TupleDataLocalScanState &lstate, DataChunk &result) { - lstate.pin_state.properties = gstate.scan_state.pin_state.properties; - - const auto segment_index_before = lstate.segment_index; - { - lock_guard guard(gstate.lock); - if (!NextScanIndex(gstate.scan_state, lstate.segment_index, lstate.chunk_index)) { - if (!segments.empty()) { - FinalizePinState(lstate.pin_state, segments[segment_index_before]); - } - result.SetCardinality(0); - return false; - } - } - if (segment_index_before != DConstants::INVALID_INDEX && segment_index_before != lstate.segment_index) { - FinalizePinState(lstate.pin_state, segments[lstate.segment_index]); - } - ScanAtIndex(lstate.pin_state, lstate.chunk_state, gstate.scan_state.chunk_state.column_ids, lstate.segment_index, - lstate.chunk_index, result); - return true; -} - -bool TupleDataCollection::ScanComplete(const TupleDataScanState &state) const { - if (Count() == 0) { - return true; - } - return state.segment_index == segments.size() - 1 && state.chunk_index == segments.back().ChunkCount(); -} - -void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state, TupleDataSegment &segment) { - segment.allocator->ReleaseOrStoreHandles(pin_state, segment); -} - -void TupleDataCollection::FinalizePinState(TupleDataPinState &pin_state) { - D_ASSERT(!segments.empty()); - FinalizePinState(pin_state, segments.back()); -} - -bool TupleDataCollection::NextScanIndex(TupleDataScanState &state, idx_t &segment_index, idx_t &chunk_index) { - // Check if we still have segments to scan - if (state.segment_index >= segments.size()) { - // No more data left in the scan - return false; - } - // Check within the current segment if we still have chunks to scan - while (state.chunk_index >= segments[state.segment_index].ChunkCount()) { - // Exhausted all chunks for this segment: Move to the next one - state.segment_index++; - state.chunk_index = 0; - if (state.segment_index >= segments.size()) { - return false; - } - } - segment_index = state.segment_index; - chunk_index = state.chunk_index++; - return true; -} -void TupleDataCollection::ScanAtIndex(TupleDataPinState &pin_state, TupleDataChunkState &chunk_state, - const vector &column_ids, idx_t segment_index, idx_t chunk_index, - DataChunk &result) { - auto &segment = segments[segment_index]; - auto &chunk = segment.chunks[chunk_index]; - segment.allocator->InitializeChunkState(segment, pin_state, chunk_state, chunk_index, false); - result.Reset(); - - ResetCachedCastVectors(chunk_state, column_ids); - 
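// The Gather below materializes row-major data back into column vectors, one
// gather function per column. For a fixed-width column at byte offset
// `col_offset`, the inner loop is essentially this sketch (validity handling and
// selection vectors omitted):
#if 0 // illustrative sketch only
template <class T>
static void GatherFixedSketch(const data_ptr_t row_locations[], idx_t count, idx_t col_offset, T result[]) {
	for (idx_t i = 0; i < count; i++) {
		result[i] = Load<T>(row_locations[i] + col_offset);
	}
}
#endif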
Gather(chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(), chunk.count, column_ids, result, - *FlatVector::IncrementalSelectionVector(), chunk_state.cached_cast_vectors); - result.SetCardinality(chunk.count); -} - -void TupleDataCollection::ResetCachedCastVectors(TupleDataChunkState &chunk_state, const vector &column_ids) { - for (idx_t i = 0; i < column_ids.size(); i++) { - if (chunk_state.cached_cast_vectors[i]) { - chunk_state.cached_cast_vectors[i]->ResetFromCache(*chunk_state.cached_cast_vector_cache[i]); - } - } -} - -// LCOV_EXCL_START -string TupleDataCollection::ToString() { - DataChunk chunk; - InitializeChunk(chunk); - - TupleDataScanState scan_state; - InitializeScan(scan_state); - - string result = StringUtil::Format("TupleDataCollection - [%llu Chunks, %llu Rows]\n", ChunkCount(), Count()); - idx_t chunk_idx = 0; - idx_t row_count = 0; - while (Scan(scan_state, chunk)) { - result += - StringUtil::Format("Chunk %llu - [Rows %llu - %llu]\n", chunk_idx, row_count, row_count + chunk.size()) + - chunk.ToString(); - chunk_idx++; - row_count += chunk.size(); - } - - return result; -} - -void TupleDataCollection::Print() { - Printer::Print(ToString()); -} - -void TupleDataCollection::Verify() const { -#ifdef DEBUG - idx_t total_count = 0; - idx_t total_size = 0; - for (const auto &segment : segments) { - segment.Verify(); - total_count += segment.count; - total_size += segment.data_size; - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -void TupleDataCollection::VerifyEverythingPinned() const { -#ifdef DEBUG - for (const auto &segment : segments) { - segment.VerifyEverythingPinned(); - } -#endif -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp b/src/duckdb/src/common/types/row/tuple_data_iterator.cpp deleted file mode 100644 index a209322aa..000000000 --- a/src/duckdb/src/common/types/row/tuple_data_iterator.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "duckdb/common/types/row/tuple_data_iterator.hpp" - -#include "duckdb/common/types/row/tuple_data_allocator.hpp" - -namespace duckdb { - -TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties_p, - bool init_heap) - : TupleDataChunkIterator(collection_p, properties_p, 0, collection_p.ChunkCount(), init_heap) { -} - -TupleDataChunkIterator::TupleDataChunkIterator(TupleDataCollection &collection_p, TupleDataPinProperties properties, - idx_t chunk_idx_from, idx_t chunk_idx_to, bool init_heap_p) - : collection(collection_p), init_heap(init_heap_p) { - state.pin_state.properties = properties; - D_ASSERT(chunk_idx_from < chunk_idx_to); - D_ASSERT(chunk_idx_to <= collection.ChunkCount()); - idx_t overall_chunk_index = 0; - for (idx_t segment_idx = 0; segment_idx < collection.segments.size(); segment_idx++) { - const auto &segment = collection.segments[segment_idx]; - if (chunk_idx_from >= overall_chunk_index && chunk_idx_from <= overall_chunk_index + segment.ChunkCount()) { - // We start in this segment - start_segment_idx = segment_idx; - start_chunk_idx = chunk_idx_from - overall_chunk_index; - } - if (chunk_idx_to >= overall_chunk_index && chunk_idx_to <= overall_chunk_index + segment.ChunkCount()) { - // We end in this segment - end_segment_idx = segment_idx; - end_chunk_idx = chunk_idx_to - overall_chunk_index; - } - overall_chunk_index += segment.ChunkCount(); - } - - Reset(); -} - -void TupleDataChunkIterator::InitializeCurrentChunk() { - 
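// Pin the blocks backing the current chunk and compute its row locations (and heap locations, if init_heap is set) -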
auto &segment = collection.segments[current_segment_idx]; - segment.allocator->InitializeChunkState(segment, state.pin_state, state.chunk_state, current_chunk_idx, init_heap); -} - -bool TupleDataChunkIterator::Done() const { - return current_segment_idx == end_segment_idx && current_chunk_idx == end_chunk_idx; -} - -bool TupleDataChunkIterator::Next() { - D_ASSERT(!Done()); // Must not be called after the iterator is already done - - // Set the next indices and check whether we're at the end of the collection - // NextScanIndex can go past this iterator's 'end', so we have to check the indices again - const auto segment_idx_before = current_segment_idx; - if (!collection.NextScanIndex(state, current_segment_idx, current_chunk_idx) || Done()) { - // Drop the pins / store them if TupleDataPinProperties::KEEP_EVERYTHING_PINNED - collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); - current_segment_idx = end_segment_idx; - current_chunk_idx = end_chunk_idx; - return false; - } - - // Finalize pin state when moving from one segment to the next - if (current_segment_idx != segment_idx_before) { - collection.FinalizePinState(state.pin_state, collection.segments[segment_idx_before]); - } - - InitializeCurrentChunk(); - return true; -} - -void TupleDataChunkIterator::Reset() { - state.segment_index = start_segment_idx; - state.chunk_index = start_chunk_idx; - collection.NextScanIndex(state, current_segment_idx, current_chunk_idx); - InitializeCurrentChunk(); -} - -idx_t TupleDataChunkIterator::GetCurrentChunkCount() const { - return collection.segments[current_segment_idx].chunks[current_chunk_idx].count; -} - -TupleDataChunkState &TupleDataChunkIterator::GetChunkState() { - return state.chunk_state; -} - -data_ptr_t *TupleDataChunkIterator::GetRowLocations() { - return FlatVector::GetData(state.chunk_state.row_locations); -} - -data_ptr_t *TupleDataChunkIterator::GetHeapLocations() { - return FlatVector::GetData(state.chunk_state.heap_locations); -} - -idx_t *TupleDataChunkIterator::GetHeapSizes() { - return FlatVector::GetData(state.chunk_state.heap_sizes); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_layout.cpp b/src/duckdb/src/common/types/row/tuple_data_layout.cpp deleted file mode 100644 index 5dec78e06..000000000 --- a/src/duckdb/src/common/types/row/tuple_data_layout.cpp +++ /dev/null @@ -1,127 +0,0 @@ -#include "duckdb/common/types/row/tuple_data_layout.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -TupleDataLayout::TupleDataLayout() - : flag_width(0), data_width(0), aggr_width(0), row_width(0), all_constant(true), heap_size_offset(0) { -} - -TupleDataLayout TupleDataLayout::Copy() const { - TupleDataLayout result; - result.types = this->types; - result.aggregates = this->aggregates; - if (this->struct_layouts) { - result.struct_layouts = make_uniq>(); - for (const auto &entry : *this->struct_layouts) { - result.struct_layouts->emplace(entry.first, entry.second.Copy()); - } - } - result.flag_width = this->flag_width; - result.data_width = this->data_width; - result.aggr_width = this->aggr_width; - result.row_width = this->row_width; - result.offsets = this->offsets; - result.all_constant = this->all_constant; - result.heap_size_offset = this->heap_size_offset; - result.aggr_destructor_idxs = this->aggr_destructor_idxs; - return result; -} - -void TupleDataLayout::Initialize(vector types_p, Aggregates aggregates_p, bool align, bool heap_offset_p) { - offsets.clear(); - types = 
std::move(types_p); - - // Null mask at the front - 1 bit per value. - flag_width = ValidityBytes::ValidityMaskSize(types.size()); - row_width = flag_width; - - // Whether all columns are constant size. - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - if (type.InternalType() == PhysicalType::STRUCT) { - // structs are recursively stored as a TupleDataLayout again - const auto &child_types = StructType::GetChildTypes(type); - vector child_type_vector; - child_type_vector.reserve(child_types.size()); - for (auto &ct : child_types) { - child_type_vector.emplace_back(ct.second); - } - if (!struct_layouts) { - struct_layouts = make_uniq>(); - } - auto struct_entry = struct_layouts->emplace(col_idx, TupleDataLayout()); - struct_entry.first->second.Initialize(std::move(child_type_vector), false, false); - all_constant = all_constant && struct_entry.first->second.AllConstant(); - } else { - all_constant = all_constant && TypeIsConstantSize(type.InternalType()); - } - } - - // This enables pointer swizzling for out-of-core computation. - if (heap_offset_p && !all_constant) { - heap_size_offset = row_width; - row_width += sizeof(uint32_t); - } - - // Data columns. No alignment required. - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - const auto &type = types[col_idx]; - offsets.push_back(row_width); - const auto internal_type = type.InternalType(); - if (TypeIsConstantSize(internal_type) || internal_type == PhysicalType::VARCHAR) { - row_width += GetTypeIdSize(type.InternalType()); - } else if (internal_type == PhysicalType::STRUCT) { - // Just get the size of the TupleDataLayout of the struct - row_width += GetStructLayout(col_idx).GetRowWidth(); - } else { - // Variable size types use pointers to the actual data (can be swizzled). - // Again, we would use sizeof(data_ptr_t), but this is not guaranteed to be equal to sizeof(idx_t). - row_width += sizeof(idx_t); - } - } - - // Alignment padding for aggregates -#ifndef DUCKDB_ALLOW_UNDEFINED - if (align) { - row_width = AlignValue(row_width); - } -#endif - data_width = row_width - flag_width; - - // Aggregate fields. 
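-// Each aggregate's fixed-size state is stored inline, directly after the data columns; payload sizes -// are assumed to be pre-aligned (checked by the D_ASSERT below)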
- aggregates = std::move(aggregates_p); - for (auto &aggregate : aggregates) { - offsets.push_back(row_width); - row_width += aggregate.payload_size; -#ifndef DUCKDB_ALLOW_UNDEFINED - D_ASSERT(aggregate.payload_size == AlignValue(aggregate.payload_size)); -#endif - } - aggr_width = row_width - data_width - flag_width; - - // Alignment padding for the next row -#ifndef DUCKDB_ALLOW_UNDEFINED - if (align) { - row_width = AlignValue(row_width); - } -#endif - - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - const auto &aggr = aggregates[aggr_idx]; - if (aggr.function.destructor) { - aggr_destructor_idxs.push_back(aggr_idx); - } - } -} - -void TupleDataLayout::Initialize(vector types_p, bool align, bool heap_offset_p) { - Initialize(std::move(types_p), Aggregates(), align, heap_offset_p); -} - -void TupleDataLayout::Initialize(Aggregates aggregates_p, bool align, bool heap_offset_p) { - Initialize(vector(), std::move(aggregates_p), align, heap_offset_p); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp b/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp deleted file mode 100644 index a735d90b0..000000000 --- a/src/duckdb/src/common/types/row/tuple_data_scatter_gather.cpp +++ /dev/null @@ -1,1584 +0,0 @@ -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/type_visitor.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/tuple_data_collection.hpp" -#include "duckdb/common/uhugeint.hpp" - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -template -static constexpr idx_t TupleDataWithinListFixedSize() { - return sizeof(T); -} - -template <> -constexpr idx_t TupleDataWithinListFixedSize() { - return sizeof(uint32_t); -} - -template -static void TupleDataValueStore(const T &source, const data_ptr_t &row_location, const idx_t offset_in_row, - data_ptr_t &) { - Store(source, row_location + offset_in_row); -} - -template <> -inline void TupleDataValueStore(const string_t &source, const data_ptr_t &row_location, const idx_t offset_in_row, - data_ptr_t &heap_location) { -#ifdef DEBUG - source.VerifyCharacters(); -#endif - if (source.IsInlined()) { - Store(source, row_location + offset_in_row); - } else { - FastMemcpy(heap_location, source.GetData(), source.GetSize()); - Store(string_t(const_char_ptr_cast(heap_location), UnsafeNumericCast(source.GetSize())), - row_location + offset_in_row); - heap_location += source.GetSize(); - } -} - -template -static void TupleDataWithinListValueStore(const T &source, const data_ptr_t &location, data_ptr_t &) { - Store(source, location); -} - -template <> -inline void TupleDataWithinListValueStore(const string_t &source, const data_ptr_t &location, - data_ptr_t &heap_location) { -#ifdef DEBUG - source.VerifyCharacters(); -#endif - Store(UnsafeNumericCast(source.GetSize()), location); - FastMemcpy(heap_location, source.GetData(), source.GetSize()); - heap_location += source.GetSize(); -} - -template -void TupleDataValueVerify(const LogicalType &, const T &) { -#ifdef DEBUG - // NOP -#endif -} - -template <> -inline void TupleDataValueVerify(const LogicalType &type, const string_t &value) { -#ifdef DEBUG - if (type.id() == LogicalTypeId::VARCHAR) { - value.Verify(); - } -#endif -} - -template -static T TupleDataWithinListValueLoad(const data_ptr_t &location, data_ptr_t &) { - return Load(location); -} - -template <> -inline string_t TupleDataWithinListValueLoad(const 
data_ptr_t &location, data_ptr_t &heap_location) { - const auto size = Load(location); - string_t result(const_char_ptr_cast(heap_location), size); - heap_location += size; - return result; -} - -static void ResetCombinedListData(vector &vector_data) { -#ifdef DEBUG - for (auto &vd : vector_data) { - vd.combined_list_data = nullptr; - ResetCombinedListData(vd.children); - } -#endif -} - -void TupleDataCollection::ComputeHeapSizes(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) { - ResetCombinedListData(chunk_state.vector_data); - - auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - std::fill_n(heap_sizes, append_count, 0); - - for (idx_t col_idx = 0; col_idx < new_chunk.ColumnCount(); col_idx++) { - auto &source_v = new_chunk.data[col_idx]; - auto &source_format = chunk_state.vector_data[col_idx]; - ComputeHeapSizes(chunk_state.heap_sizes, source_v, source_format, append_sel, append_count); - } -} - -static idx_t StringHeapSize(const string_t &val) { - return val.IsInlined() ? 0 : val.GetSize(); -} - -void TupleDataCollection::ComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, const SelectionVector &append_sel, - const idx_t append_count) { - const auto type = source_v.GetType().InternalType(); - if (type != PhysicalType::VARCHAR && type != PhysicalType::STRUCT && type != PhysicalType::LIST && - type != PhysicalType::ARRAY) { - return; - } - - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - // Source - const auto &source_vector_data = source_format.unified; - const auto &source_sel = *source_vector_data.sel; - const auto &source_validity = source_vector_data.validity; - - switch (type) { - case PhysicalType::VARCHAR: { - // Only non-inlined strings are stored in the heap - const auto source_data = UnifiedVectorFormat::GetData(source_vector_data); - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (source_validity.RowIsValid(source_idx)) { - heap_sizes[i] += StringHeapSize(source_data[source_idx]); - } else { - heap_sizes[i] += StringHeapSize(NullValue()); - } - } - break; - } - case PhysicalType::STRUCT: { - // Recurse through the struct children - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - const auto &struct_source = struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; - ComputeHeapSizes(heap_sizes_v, *struct_source, struct_format, append_sel, append_count); - } - break; - } - case PhysicalType::LIST: { - // Lists are stored entirely in the heap - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (source_validity.RowIsValid(source_idx)) { - heap_sizes[i] += sizeof(uint64_t); // Size of the list - } - } - - // Recurse - D_ASSERT(source_format.children.size() == 1); - auto &child_source_v = ListVector::GetEntry(source_v); - auto &child_format = source_format.children[0]; - WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, append_count, - source_vector_data); - break; - } - case PhysicalType::ARRAY: { - // Arrays are stored entirely in the heap - for (idx_t i = 0; i < append_count; i++) { - auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (source_validity.RowIsValid(source_idx)) { - heap_sizes[i] 
+= sizeof(uint64_t); // Size of the list - } - } - - // Recurse - D_ASSERT(source_format.children.size() == 1); - auto &child_source_v = ArrayVector::GetEntry(source_v); - auto &child_format = source_format.children[0]; - WithinCollectionComputeHeapSizes(heap_sizes_v, child_source_v, child_format, append_sel, append_count, - source_vector_data); - break; - } - default: - throw NotImplementedException("ComputeHeapSizes for %s", EnumUtil::ToString(source_v.GetType().id())); - } -} - -void TupleDataCollection::WithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const UnifiedVectorFormat &list_data) { - auto type = source_v.GetType().InternalType(); - if (TypeIsConstantSize(type)) { - ComputeFixedWithinCollectionHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, - list_data); - return; - } - switch (type) { - case PhysicalType::VARCHAR: - StringWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, - list_data); - break; - case PhysicalType::STRUCT: - StructWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, - list_data); - break; - case PhysicalType::LIST: - CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, - list_data); - break; - case PhysicalType::ARRAY: - CollectionWithinCollectionComputeHeapSizes(heap_sizes_v, source_v, source_format, append_sel, append_count, - list_data); - break; - default: - throw NotImplementedException("WithinListHeapComputeSizes for %s", EnumUtil::ToString(source_v.GetType().id())); - } -} - -void TupleDataCollection::ComputeFixedWithinCollectionHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &, - const SelectionVector &append_sel, - const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // Parent list data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - D_ASSERT(TypeIsConstantSize(source_v.GetType().InternalType())); - const auto type_size = GetTypeIdSize(source_v.GetType().InternalType()); - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list length - const auto &list_length = list_entries[list_idx].length; - if (list_length == 0) { - continue; - } - - // Size is validity mask and all values - auto &heap_size = heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += list_length * type_size; - } -} - -void TupleDataCollection::StringWithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, - const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // Parent list data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - 
const auto &source_validity = source_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - if (list_length == 0) { - continue; - } - - // Size is validity mask and all string sizes - auto &heap_size = heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += list_length * TupleDataWithinListFixedSize(); - - // Plus all the actual strings - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (source_validity.RowIsValid(child_source_idx)) { - heap_size += data[child_source_idx].GetSize(); - } - } - } -} - -void TupleDataCollection::StructWithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, - const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // Parent list data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list length - const auto &list_length = list_entries[list_idx].length; - if (list_length == 0) { - continue; - } - - // Size is just the validity mask - heap_sizes[i] += ValidityBytes::SizeInBytes(list_length); - } - - // Recurse - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - - auto &struct_format = source_format.children[struct_col_idx]; - WithinCollectionComputeHeapSizes(heap_sizes_v, struct_source, struct_format, append_sel, append_count, - list_data); - } -} - -static void ApplySliceRecursive(const Vector &source_v, TupleDataVectorFormat &source_format, - const SelectionVector &combined_sel, const idx_t count) { - D_ASSERT(source_format.combined_list_data); - auto &combined_list_data = *source_format.combined_list_data; - - combined_list_data.selection_data = source_format.original_sel->Slice(combined_sel, count); - source_format.unified.owned_sel.Initialize(combined_list_data.selection_data); - source_format.unified.sel = &source_format.unified.owned_sel; - - if (source_v.GetType().InternalType() == PhysicalType::STRUCT) { - // We have to apply it to the child vectors too - auto &struct_sources = StructVector::GetEntries(source_v); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; -#ifdef DEBUG - D_ASSERT(!struct_format.combined_list_data); -#endif - if (!struct_format.combined_list_data) { - struct_format.combined_list_data = make_uniq(); 
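- // Allocated lazily here; ResetCombinedListData only clears these in DEBUG builds, so in release builds the allocation is reused across calls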
- } - ApplySliceRecursive(struct_source, struct_format, *source_format.unified.sel, count); - } - } -} - -void TupleDataCollection::CollectionWithinCollectionComputeHeapSizes(Vector &heap_sizes_v, const Vector &source_v, - TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, - const idx_t append_count, - const UnifiedVectorFormat &list_data) { - // Parent list data - const auto list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Source - const auto &child_list_data = source_format.unified; - const auto child_list_sel = *child_list_data.sel; - const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); - const auto &child_list_validity = child_list_data.validity; - - // Target - auto heap_sizes = FlatVector::GetData(heap_sizes_v); - - // Figure out actual child list size (can differ from ListVector::GetListSize if dict/const vector), - // and we cannot use ConstantVector::ZeroSelectionVector because it may need to be longer than STANDARD_VECTOR_SIZE - idx_t sum_of_sizes = 0; - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - if (list_length == 0) { - continue; - } - - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); - if (!child_list_validity.RowIsValid(child_list_idx)) { - continue; - } - - const auto &child_list_entry = child_list_entries[child_list_idx]; - const auto &child_list_length = child_list_entry.length; - - sum_of_sizes += child_list_length; - } - } - - const auto child_list_child_count = MaxValue( - sum_of_sizes, source_v.GetType().InternalType() == PhysicalType::LIST ? 
ListVector::GetListSize(source_v) - : ArrayVector::GetTotalSize(source_v)); - - D_ASSERT(source_format.children.size() == 1); - auto &child_format = source_format.children[0]; -#ifdef DEBUG - // In debug mode this should be deleted by ResetCombinedListData - D_ASSERT(!child_format.combined_list_data); -#endif - if (!child_format.combined_list_data) { - child_format.combined_list_data = make_uniq(); - } - auto &combined_list_data = *child_format.combined_list_data; - - // Construct combined list entries and a selection/validity vector for the child list child - SelectionVector combined_sel(child_list_child_count); - for (idx_t i = 0; i < child_list_child_count; i++) { - combined_sel.set_index(i, 0); - } - auto &combined_list_entries = combined_list_data.combined_list_entries; - auto &combined_validity = combined_list_data.combined_validity; - combined_validity.SetAllValid(STANDARD_VECTOR_SIZE); - - idx_t combined_list_offset = 0; - for (idx_t i = 0; i < append_count; i++) { - const auto append_idx = append_sel.get_index(i); - const auto list_idx = list_sel.get_index(append_idx); - if (!list_validity.RowIsValid(list_idx)) { - combined_validity.SetInvalidUnsafe(append_idx); - continue; // Original list entry is invalid - no need to serialize the child list - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - - // Size is the validity mask and the list sizes - auto &heap_size = heap_sizes[i]; - heap_size += ValidityBytes::SizeInBytes(list_length); - heap_size += list_length * sizeof(uint64_t); - - idx_t child_list_size = 0; - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); - if (child_list_validity.RowIsValid(child_list_idx)) { - const auto &child_list_entry = child_list_entries[child_list_idx]; - const auto &child_list_offset = child_list_entry.offset; - const auto &child_list_length = child_list_entry.length; - if (child_list_length == 0) { - continue; - } - - // Add this child's list entries to the combined selection vector - for (idx_t child_value_i = 0; child_value_i < child_list_length; child_value_i++) { - auto idx = combined_list_offset + child_list_size + child_value_i; - auto loc = child_list_offset + child_value_i; - combined_sel.set_index(idx, loc); - } - - child_list_size += child_list_length; - } - } - - // Combine the child list entries into one - auto &combined_list_entry = combined_list_entries[append_idx]; - combined_list_entry.offset = combined_list_offset; - combined_list_entry.length = child_list_size; - combined_list_offset += child_list_size; - } - - // TODO: Template this? - auto &child_source = source_v.GetType().InternalType() == PhysicalType::LIST ? 
ListVector::GetEntry(source_v) - : ArrayVector::GetEntry(source_v); - ApplySliceRecursive(child_source, child_format, combined_sel, child_list_child_count); - - // Create a combined child_list_data to be used as list_data in the recursion - auto &combined_child_list_data = combined_list_data.combined_data; - combined_child_list_data.sel = FlatVector::IncrementalSelectionVector(); - combined_child_list_data.data = data_ptr_cast(combined_list_entries); - combined_child_list_data.validity.Initialize(combined_validity); - - // Recurse - WithinCollectionComputeHeapSizes(heap_sizes_v, child_source, child_format, append_sel, append_count, - combined_child_list_data); -} - -template -static void TemplatedInitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count) { - for (idx_t i = 0; i < append_count; i++) { - Store(T(-1), row_locations[i]); - } -} - -template -static void TemplatedInitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count) { - for (idx_t i = 0; i < append_count; i++) { - memset(row_locations[i], ~0, validity_bytes); - } -} - -static void InitializeValidityMask(const data_ptr_t row_locations[], const idx_t append_count, - const idx_t validity_bytes) { - switch (validity_bytes) { - case 1: - TemplatedInitializeValidityMask(row_locations, append_count); - break; - case 2: - TemplatedInitializeValidityMask(row_locations, append_count); - break; - case 3: - TemplatedInitializeValidityMask<3>(row_locations, append_count); - break; - case 4: - TemplatedInitializeValidityMask(row_locations, append_count); - break; - case 5: - TemplatedInitializeValidityMask<5>(row_locations, append_count); - break; - case 6: - TemplatedInitializeValidityMask<6>(row_locations, append_count); - break; - case 7: - TemplatedInitializeValidityMask<7>(row_locations, append_count); - break; - case 8: - TemplatedInitializeValidityMask(row_locations, append_count); - break; - default: - for (idx_t i = 0; i < append_count; i++) { - FastMemset(row_locations[i], ~0, validity_bytes); - } - } -} - -void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const DataChunk &new_chunk, - const SelectionVector &append_sel, const idx_t append_count) const { -#ifdef DEBUG - Vector heap_locations_copy(LogicalType::POINTER); - if (!layout.AllConstant()) { - const auto heap_locations = FlatVector::GetData(chunk_state.heap_locations); - const auto copied_heap_locations = FlatVector::GetData(heap_locations_copy); - for (idx_t i = 0; i < append_count; i++) { - copied_heap_locations[i] = heap_locations[i]; - } - } -#endif - - const auto row_locations = FlatVector::GetData(chunk_state.row_locations); - - // Set the validity mask for each row before inserting data - InitializeValidityMask(row_locations, append_count, ValidityBytes::SizeInBytes(layout.ColumnCount())); - - if (!layout.AllConstant()) { - // Set the heap size for each row - const auto heap_size_offset = layout.GetHeapSizeOffset(); - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - for (idx_t i = 0; i < append_count; i++) { - Store(UnsafeNumericCast(heap_sizes[i]), row_locations[i] + heap_size_offset); - } - } - - // Write the data - for (const auto &col_idx : chunk_state.column_ids) { - Scatter(chunk_state, new_chunk.data[col_idx], col_idx, append_sel, append_count); - } - -#ifdef DEBUG - // Verify that the size of the data written to the heap is the same as the size we computed it would be - if (!layout.AllConstant()) { - const auto original_heap_locations = 
FlatVector::GetData(heap_locations_copy); - const auto heap_sizes = FlatVector::GetData(chunk_state.heap_sizes); - const auto offset_heap_locations = FlatVector::GetData(chunk_state.heap_locations); - for (idx_t i = 0; i < append_count; i++) { - if (heap_sizes[i] != 0) { - D_ASSERT(offset_heap_locations[i] == original_heap_locations[i] + heap_sizes[i]); - } - } - } -#endif -} - -void TupleDataCollection::Scatter(TupleDataChunkState &chunk_state, const Vector &source, const column_t column_id, - const SelectionVector &append_sel, const idx_t append_count) const { - const auto &scatter_function = scatter_functions[column_id]; - scatter_function.function(source, chunk_state.vector_data[column_id], append_sel, append_count, layout, - chunk_state.row_locations, chunk_state.heap_locations, column_id, - chunk_state.vector_data[column_id].unified, scatter_function.child_functions); -} - -template -static void TupleDataTemplatedScatter(const Vector &, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, const UnifiedVectorFormat &, - const vector &) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &validity = source_data.validity; - - // Target - const auto target_locations = FlatVector::GetData(row_locations); - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - if (validity.AllValid()) { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); - } - } else { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (validity.RowIsValid(source_idx)) { - TupleDataValueStore(data[source_idx], target_locations[i], offset_in_row, target_heap_locations[i]); - } else { - TupleDataValueStore(NullValue(), target_locations[i], offset_in_row, target_heap_locations[i]); - ValidityBytes(target_locations[i], layout.ColumnCount()).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - } -} - -static void TupleDataStructScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &dummy_arg, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto &validity = source_data.validity; - - // Target - const auto target_locations = FlatVector::GetData(row_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Set validity of the STRUCT in this layout - if (!validity.AllValid()) { - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (!validity.RowIsValid(source_idx)) { - ValidityBytes(target_locations[i], 
layout.ColumnCount()).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - } - - // Create a Vector of pointers to the TupleDataLayout of the STRUCT - Vector struct_row_locations(LogicalType::POINTER, append_count); - auto struct_target_locations = FlatVector::GetData(struct_row_locations); - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < append_count; i++) { - struct_target_locations[i] = target_locations[i] + offset_in_row; - } - - const auto &struct_layout = layout.GetStructLayout(col_idx); - auto &struct_sources = StructVector::GetEntries(source); - D_ASSERT(struct_layout.ColumnCount() == struct_sources.size()); - - // Set the validity of the entries within the STRUCTs - InitializeValidityMask(struct_target_locations, append_count, - ValidityBytes::SizeInBytes(struct_layout.ColumnCount())); - - // Recurse through the struct children - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - const auto &struct_source_format = source_format.children[struct_col_idx]; - const auto &struct_scatter_function = child_functions[struct_col_idx]; - struct_scatter_function.function(struct_source, struct_source_format, append_sel, append_count, struct_layout, - struct_row_locations, heap_locations, struct_col_idx, dummy_arg, - struct_scatter_function.child_functions); - } -} - -//------------------------------------------------------------------------------ -// List Scatter -//------------------------------------------------------------------------------ -static void TupleDataListScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &, - const vector &child_functions) { - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &validity = source_data.validity; - - // Target - const auto target_locations = FlatVector::GetData(row_locations); - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Set validity of the LIST in this layout, and store pointer to where it's stored - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (validity.RowIsValid(source_idx)) { - auto &target_heap_location = target_heap_locations[i]; - Store(target_heap_location, target_locations[i] + offset_in_row); - - // Store list length and skip over it - Store(data[source_idx].length, target_heap_location); - target_heap_location += sizeof(uint64_t); - } else { - ValidityBytes(target_locations[i], layout.ColumnCount()).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - - // Recurse - D_ASSERT(child_functions.size() == 1); - auto &child_source = ListVector::GetEntry(source); - auto &child_format = source_format.children[0]; - const auto &child_function = child_functions[0]; - child_function.function(child_source, child_format, append_sel, append_count, layout, row_locations, heap_locations, - col_idx, source_format.unified, child_function.child_functions); -} - 
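-// Rough sketch of the heap region that TupleDataListScatter and the within-collection scatter functions -// produce for one non-NULL LIST row (matching the sizes computed in ComputeHeapSizes above): -// -//   [list length : uint64_t] -//   [child validity mask : ValidityBytes::SizeInBytes(length) bytes] -//   [fixed-size child values : length * TupleDataWithinListFixedSize() bytes] -//   [variable-size tails, e.g. non-inlined string payloads, appended in order] -// -// The row itself stores only a pointer to this region (plus the row's total heap size), which is what -// allows TupleDataAllocator::RecomputeHeapPointers to fix up the pointers after the heap moves. - 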
-//------------------------------------------------------------------------------ -// Array Scatter -//------------------------------------------------------------------------------ -static void TupleDataArrayScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, Vector &heap_locations, - const idx_t col_idx, const UnifiedVectorFormat &, - const vector &child_functions) { - // Source - // The Array vector has fake list_entry_t's set by this point, so this is fine - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &validity = source_data.validity; - - // Target - const auto target_locations = FlatVector::GetData(row_locations); - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Set validity of the LIST in this layout, and store pointer to where it's stored - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < append_count; i++) { - const auto source_idx = source_sel.get_index(append_sel.get_index(i)); - if (validity.RowIsValid(source_idx)) { - auto &target_heap_location = target_heap_locations[i]; - Store(target_heap_location, target_locations[i] + offset_in_row); - - // Store list length and skip over it - Store(data[source_idx].length, target_heap_location); - target_heap_location += sizeof(uint64_t); - } else { - ValidityBytes(target_locations[i], layout.ColumnCount()).SetInvalidUnsafe(entry_idx, idx_in_entry); - } - } - - // Recurse - D_ASSERT(child_functions.size() == 1); - auto &child_source = ArrayVector::GetEntry(source); - auto &child_format = source_format.children[0]; - const auto &child_function = child_functions[0]; - child_function.function(child_source, child_format, append_sel, append_count, layout, row_locations, heap_locations, - col_idx, source_format.unified, child_function.child_functions); -} - -//------------------------------------------------------------------------------ -// Collection Scatter -//------------------------------------------------------------------------------ -template -static void TupleDataTemplatedWithinCollectionScatter(const Vector &, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &, const Vector &, Vector &heap_locations, - const idx_t, const UnifiedVectorFormat &list_data, - const vector &) { - // Parent list data - const auto &list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto data = UnifiedVectorFormat::GetData(source_data); - const auto &source_validity = source_data.validity; - - // Target - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - 
const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - if (list_length == 0) { - continue; - } - - // Initialize validity mask and skip heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location, list_length); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto child_data_location = target_heap_location; - target_heap_location += list_length * TupleDataWithinListFixedSize(); - - // Store the data and validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (source_validity.RowIsValid(child_source_idx)) { - TupleDataWithinListValueStore(data[child_source_idx], - child_data_location + child_i * TupleDataWithinListFixedSize(), - target_heap_location); - } else { - child_mask.SetInvalidUnsafe(child_i); - } - } - } -} - -static void TupleDataStructWithinCollectionScatter(const Vector &source, const TupleDataVectorFormat &source_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { - // Parent list data - const auto &list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Source - const auto &source_data = source_format.unified; - const auto &source_sel = *source_data.sel; - const auto &source_validity = source_data.validity; - - // Target - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - // Initialize the validity of the STRUCTs - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - if (list_length == 0) { - continue; - } - - // Initialize validity mask and skip the heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location, list_length); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Store the validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_source_idx = source_sel.get_index(list_offset + child_i); - if (!source_validity.RowIsValid(child_source_idx)) { - child_mask.SetInvalidUnsafe(child_i); - } - } - } - - // Recurse through the children - auto &struct_sources = StructVector::GetEntries(source); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_sources.size(); struct_col_idx++) { - auto &struct_source = *struct_sources[struct_col_idx]; - auto &struct_format = source_format.children[struct_col_idx]; - const auto &struct_scatter_function = child_functions[struct_col_idx]; - struct_scatter_function.function(struct_source, struct_format, append_sel, append_count, layout, row_locations, - heap_locations, struct_col_idx, list_data, - 
struct_scatter_function.child_functions); - } -} - -template -static void TupleDataCollectionWithinCollectionScatter(const Vector &child_list, - const TupleDataVectorFormat &child_list_format, - const SelectionVector &append_sel, const idx_t append_count, - const TupleDataLayout &layout, const Vector &row_locations, - Vector &heap_locations, const idx_t col_idx, - const UnifiedVectorFormat &list_data, - const vector &child_functions) { - // Parent list data - const auto &list_sel = *list_data.sel; - const auto list_entries = UnifiedVectorFormat::GetData(list_data); - const auto &list_validity = list_data.validity; - - // Source - const auto &child_list_data = child_list_format.unified; - const auto &child_list_sel = *child_list_data.sel; - const auto child_list_entries = UnifiedVectorFormat::GetData(child_list_data); - const auto &child_list_validity = child_list_data.validity; - - // Target - const auto target_heap_locations = FlatVector::GetData(heap_locations); - - for (idx_t i = 0; i < append_count; i++) { - const auto list_idx = list_sel.get_index(append_sel.get_index(i)); - if (!list_validity.RowIsValid(list_idx)) { - continue; // Original list entry is invalid - no need to serialize the child list - } - - // Get the current list entry - const auto &list_entry = list_entries[list_idx]; - const auto &list_offset = list_entry.offset; - const auto &list_length = list_entry.length; - if (list_length == 0) { - continue; - } - - // Initialize validity mask and skip heap pointer over it - auto &target_heap_location = target_heap_locations[i]; - ValidityBytes child_mask(target_heap_location, list_length); - child_mask.SetAllValid(list_length); - target_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto child_data_location = target_heap_location; - target_heap_location += list_length * sizeof(uint64_t); - - for (idx_t child_i = 0; child_i < list_length; child_i++) { - const auto child_list_idx = child_list_sel.get_index(list_offset + child_i); - if (child_list_validity.RowIsValid(child_list_idx)) { - const auto &child_list_length = child_list_entries[child_list_idx].length; - Store(child_list_length, child_data_location + child_i * sizeof(uint64_t)); - } else { - child_mask.SetInvalidUnsafe(child_i); - } - } - } - - // Recurse - D_ASSERT(child_functions.size() == 1); - auto &child_vec = COLLECTION_VECTOR::GetEntry(child_list); - auto &child_format = child_list_format.children[0]; - auto &combined_child_list_data = child_format.combined_list_data->combined_data; - const auto &child_function = child_functions[0]; - child_function.function(child_vec, child_format, append_sel, append_count, layout, row_locations, heap_locations, - col_idx, combined_child_list_data, child_function.child_functions); -} - -//------------------------------------------------------------------------------ -// Get Scatter Function -//------------------------------------------------------------------------------ -template -tuple_data_scatter_function_t TupleDataGetScatterFunction(bool within_collection) { - return within_collection ? 
TupleDataTemplatedWithinCollectionScatter : TupleDataTemplatedScatter; -} - -TupleDataScatterFunction TupleDataCollection::GetScatterFunction(const LogicalType &type, bool within_collection) { - TupleDataScatterFunction result; - switch (type.InternalType()) { - case PhysicalType::BOOL: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INT8: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INT16: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INT32: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INT64: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INT128: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::UINT8: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::UINT16: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::UINT32: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::UINT64: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::UINT128: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::FLOAT: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::DOUBLE: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::INTERVAL: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::VARCHAR: - result.function = TupleDataGetScatterFunction(within_collection); - break; - case PhysicalType::STRUCT: { - result.function = within_collection ? TupleDataStructWithinCollectionScatter : TupleDataStructScatter; - for (const auto &child_type : StructType::GetChildTypes(type)) { - result.child_functions.push_back(GetScatterFunction(child_type.second, within_collection)); - } - break; - } - case PhysicalType::LIST: - result.function = - within_collection ? TupleDataCollectionWithinCollectionScatter : TupleDataListScatter; - result.child_functions.emplace_back(GetScatterFunction(ListType::GetChildType(type), true)); - break; - case PhysicalType::ARRAY: - result.function = - within_collection ? 
TupleDataCollectionWithinCollectionScatter : TupleDataArrayScatter; - result.child_functions.emplace_back(GetScatterFunction(ArrayType::GetChildType(type), true)); - break; - default: - throw InternalException("Unsupported type for TupleDataCollection::GetScatterFunction"); - } - return result; -} - -//------------------------------------------------------------------------------- -// Gather -//------------------------------------------------------------------------------- -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - DataChunk &result, const SelectionVector &target_sel, - vector> &cached_cast_vectors) const { - D_ASSERT(result.ColumnCount() == layout.ColumnCount()); - vector column_ids; - column_ids.reserve(layout.ColumnCount()); - for (idx_t col_idx = 0; col_idx < layout.ColumnCount(); col_idx++) { - column_ids.emplace_back(col_idx); - } - Gather(row_locations, scan_sel, scan_count, column_ids, result, target_sel, cached_cast_vectors); -} - -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - const vector &column_ids, DataChunk &result, - const SelectionVector &target_sel, - vector> &cached_cast_vectors) const { - for (idx_t col_idx = 0; col_idx < column_ids.size(); col_idx++) { - Gather(row_locations, scan_sel, scan_count, column_ids[col_idx], result.data[col_idx], target_sel, - cached_cast_vectors[col_idx].get()); - } -} - -void TupleDataCollection::Gather(Vector &row_locations, const SelectionVector &scan_sel, const idx_t scan_count, - const column_t column_id, Vector &result, const SelectionVector &target_sel, - optional_ptr cached_cast_vector) const { - D_ASSERT(!cached_cast_vector || FlatVector::Validity(*cached_cast_vector).AllValid()); // ResetCachedCastVectors - const auto &gather_function = gather_functions[column_id]; - gather_function.function(layout, row_locations, column_id, scan_sel, scan_count, result, target_sel, - cached_cast_vector, gather_function.child_functions); - Vector::Verify(result, target_sel, scan_count); -} - -template -static void TupleDataTemplatedGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr, - const vector &) { - // Source - const auto source_locations = FlatVector::GetData(row_locations); - - // Target - auto target_data = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < scan_count; i++) { - const auto &source_row = source_locations[scan_sel.get_index(i)]; - const auto target_idx = target_sel.get_index(i); - target_data[target_idx] = Load(source_row + offset_in_row); - ValidityBytes row_mask(source_row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - target_validity.SetInvalid(target_idx); - } -#ifdef DEBUG - else { - TupleDataValueVerify(target.GetType(), target_data[target_idx]); - } -#endif - } -} - -static void TupleDataStructGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr dummy_vector, 
- const vector &child_functions) { - // Source - const auto source_locations = FlatVector::GetData(row_locations); - - // Target - auto &target_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Get validity of the struct and create a Vector of pointers to the start of the TupleDataLayout of the STRUCT - Vector struct_row_locations(LogicalType::POINTER); - auto struct_source_locations = FlatVector::GetData(struct_row_locations); - const auto offset_in_row = layout.GetOffsets()[col_idx]; - for (idx_t i = 0; i < scan_count; i++) { - const auto source_idx = scan_sel.get_index(i); - const auto &source_row = source_locations[source_idx]; - - // Set the validity - ValidityBytes row_mask(source_row, layout.ColumnCount()); - if (!row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - const auto target_idx = target_sel.get_index(i); - target_validity.SetInvalid(target_idx); - } - - // Set the pointer - struct_source_locations[source_idx] = source_row + offset_in_row; - } - - // Get the struct layout and struct entries - const auto &struct_layout = layout.GetStructLayout(col_idx); - auto &struct_targets = StructVector::GetEntries(target); - D_ASSERT(struct_layout.ColumnCount() == struct_targets.size()); - - // Recurse through the struct children - for (idx_t struct_col_idx = 0; struct_col_idx < struct_layout.ColumnCount(); struct_col_idx++) { - auto &struct_target = *struct_targets[struct_col_idx]; - const auto &struct_gather_function = child_functions[struct_col_idx]; - struct_gather_function.function(struct_layout, struct_row_locations, struct_col_idx, scan_sel, scan_count, - struct_target, target_sel, dummy_vector, - struct_gather_function.child_functions); - } -} - -//------------------------------------------------------------------------------ -// List Gather -//------------------------------------------------------------------------------ -static void TupleDataListGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx, - const SelectionVector &scan_sel, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr, - const vector &child_functions) { - // Source - const auto source_locations = FlatVector::GetData(row_locations); - - // Target - const auto target_list_entries = FlatVector::GetData(target); - auto &target_list_validity = FlatVector::Validity(target); - - // Precompute mask indexes - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - // Load pointers to the data from the row - Vector heap_locations(LogicalType::POINTER); - const auto source_heap_locations = FlatVector::GetData(heap_locations); - - const auto offset_in_row = layout.GetOffsets()[col_idx]; - auto list_size_before = ListVector::GetListSize(target); - uint64_t target_list_offset = list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto &source_row = source_locations[scan_sel.get_index(i)]; - ValidityBytes row_mask(source_row, layout.ColumnCount()); - - const auto target_idx = target_sel.get_index(i); - if (row_mask.RowIsValid(row_mask.GetValidityEntryUnsafe(entry_idx), idx_in_entry)) { - auto &source_heap_location = source_heap_locations[i]; - source_heap_location = Load(source_row + offset_in_row); - - // Load list size and skip over - const auto list_length = Load(source_heap_location); - source_heap_location += 
sizeof(uint64_t); - - // Initialize list entry, and increment offset - auto &target_list_entry = target_list_entries[target_idx]; - target_list_entry.offset = target_list_offset; - target_list_entry.length = list_length; - target_list_offset += list_length; - } else { - target_list_validity.SetInvalid(target_idx); - } - } - ListVector::Reserve(target, target_list_offset); - ListVector::SetListSize(target, target_list_offset); - - // Recurse - D_ASSERT(child_functions.size() == 1); - const auto &child_function = child_functions[0]; - child_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, - ListVector::GetEntry(target), target_sel, &target, child_function.child_functions); -} - -//------------------------------------------------------------------------------ -// Collection Gather -//------------------------------------------------------------------------------ -template -static void -TupleDataTemplatedWithinCollectionGather(const TupleDataLayout &, Vector &heap_locations, const idx_t list_size_before, - const SelectionVector &, const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr list_vector, - const vector &) { - // List parent - const auto list_entries = FlatVector::GetData(*list_vector); - const auto &list_validity = FlatVector::Validity(*list_vector); - - // Source - const auto source_heap_locations = FlatVector::GetData(heap_locations); - - // Target - const auto target_data = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - - uint64_t target_offset = list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto target_idx = target_sel.get_index(i); - if (!list_validity.RowIsValid(target_idx)) { - continue; - } - - const auto &list_length = list_entries[target_idx].length; - if (list_length == 0) { - continue; - } - - // Initialize validity mask - auto &source_heap_location = source_heap_locations[i]; - ValidityBytes source_mask(source_heap_location, list_length); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto source_data_location = source_heap_location; - source_heap_location += list_length * TupleDataWithinListFixedSize(); - - // Load the child validity and data belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (source_mask.RowIsValidUnsafe(child_i)) { - auto &target_value = target_data[target_offset + child_i]; - target_value = TupleDataWithinListValueLoad( - source_data_location + child_i * TupleDataWithinListFixedSize(), source_heap_location); - TupleDataValueVerify(target.GetType(), target_value); - } else { - target_validity.SetInvalid(target_offset + child_i); - } - } - target_offset += list_length; - } -} - -static void TupleDataStructWithinCollectionGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, optional_ptr list_vector, - const vector &child_functions) { - // List parent - const auto list_entries = FlatVector::GetData(*list_vector); - const auto &list_validity = FlatVector::Validity(*list_vector); - - // Source - const auto source_heap_locations = FlatVector::GetData(heap_locations); - - // Target - auto &target_validity = FlatVector::Validity(target); - - uint64_t target_offset = list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const 
auto target_idx = target_sel.get_index(i); - if (!list_validity.RowIsValid(target_idx)) { - continue; - } - - const auto &list_length = list_entries[target_idx].length; - if (list_length == 0) { - continue; - } - - // Initialize validity mask and skip over it - auto &source_heap_location = source_heap_locations[i]; - ValidityBytes source_mask(source_heap_location, list_length); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Load the child validity belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (!source_mask.RowIsValidUnsafe(child_i)) { - target_validity.SetInvalid(target_offset + child_i); - } - } - target_offset += list_length; - } - - // Recurse - auto &struct_targets = StructVector::GetEntries(target); - for (idx_t struct_col_idx = 0; struct_col_idx < struct_targets.size(); struct_col_idx++) { - auto &struct_target = *struct_targets[struct_col_idx]; - const auto &struct_gather_function = child_functions[struct_col_idx]; - struct_gather_function.function(layout, heap_locations, list_size_before, scan_sel, scan_count, struct_target, - target_sel, list_vector, struct_gather_function.child_functions); - } -} - -static void TupleDataCollectionWithinCollectionGather(const TupleDataLayout &layout, Vector &heap_locations, - const idx_t list_size_before, const SelectionVector &scan_sel, - const idx_t scan_count, Vector &target, - const SelectionVector &target_sel, - optional_ptr list_vector, - const vector &child_functions) { - // List parent - const auto list_entries = FlatVector::GetData(*list_vector); - const auto &list_validity = FlatVector::Validity(*list_vector); - - // Source - const auto source_heap_locations = FlatVector::GetData(heap_locations); - - // Target - const auto target_list_entries = FlatVector::GetData(target); - auto &target_validity = FlatVector::Validity(target); - const auto child_list_size_before = ListVector::GetListSize(target); - - // We need to create a vector that has the combined list sizes (hugeint_t has same size as list_entry_t) - Vector combined_list_vector(LogicalType::HUGEINT); - FlatVector::SetValidity(combined_list_vector, list_validity); // Has same validity as list parent - const auto combined_list_entries = FlatVector::GetData(combined_list_vector); - - uint64_t target_offset = list_size_before; - uint64_t target_child_offset = child_list_size_before; - for (idx_t i = 0; i < scan_count; i++) { - const auto target_idx = target_sel.get_index(i); - if (!list_validity.RowIsValid(target_idx)) { - continue; - } - - // Set the offset of the combined list entry - auto &combined_list_entry = combined_list_entries[target_idx]; - combined_list_entry.offset = target_child_offset; - - const auto &list_length = list_entries[target_idx].length; - if (list_length == 0) { - combined_list_entry.length = 0; - continue; - } - - // Initialize validity mask and skip over it - auto &source_heap_location = source_heap_locations[i]; - ValidityBytes source_mask(source_heap_location, list_length); - source_heap_location += ValidityBytes::SizeInBytes(list_length); - - // Get the start to the fixed-size data and skip the heap pointer over it - const auto source_data_location = source_heap_location; - source_heap_location += list_length * sizeof(uint64_t); - - // Load the child validity and data belonging to this list entry - for (idx_t child_i = 0; child_i < list_length; child_i++) { - if (source_mask.RowIsValidUnsafe(child_i)) { - auto &target_list_entry = target_list_entries[target_offset + child_i]; - 
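// in the row format a nested list stores only its length; absolute offsets
- // exist only in the flat Vector being rebuilt, so they are assigned fresh below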
- target_list_entry.offset = target_child_offset;
- target_list_entry.length = Load<uint64_t>(source_data_location + child_i * sizeof(uint64_t));
- target_child_offset += target_list_entry.length;
- } else {
- target_validity.SetInvalid(target_offset + child_i);
- }
- }
-
- // Set the length of the combined list entry
- combined_list_entry.length = target_child_offset - combined_list_entry.offset;
-
- target_offset += list_length;
- }
-
- ListVector::Reserve(target, target_child_offset);
- ListVector::SetListSize(target, target_child_offset);
-
- // Recurse
- D_ASSERT(child_functions.size() == 1);
- const auto &child_function = child_functions[0];
- child_function.function(layout, heap_locations, child_list_size_before, scan_sel, scan_count,
- ListVector::GetEntry(target), target_sel, &combined_list_vector,
- child_function.child_functions);
-}
-
-//------------------------------------------------------------------------------
-// Special cases for arrays
-//------------------------------------------------------------------------------
-// A gather function that wraps another gather function and casts the result to the target array type
-static void TupleDataCastToArrayListGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx,
- const SelectionVector &scan_sel, const idx_t scan_count, Vector &target,
- const SelectionVector &target_sel, optional_ptr<Vector> cached_cast_vector,
- const vector<TupleDataGatherFunction> &child_functions) {
- if (cached_cast_vector) {
- // Reuse the cached cast vector
- TupleDataListGather(layout, row_locations, col_idx, scan_sel, scan_count, *cached_cast_vector, target_sel,
- cached_cast_vector, child_functions);
- VectorOperations::DefaultCast(*cached_cast_vector, target, scan_count);
- } else {
- // Otherwise, create a new temporary cast vector
- Vector cast_vector(ArrayType::ConvertToList(target.GetType()));
- TupleDataListGather(layout, row_locations, col_idx, scan_sel, scan_count, cast_vector, target_sel, &cast_vector,
- child_functions);
- VectorOperations::DefaultCast(cast_vector, target, scan_count);
- }
-}
-
-static void TupleDataCastToArrayStructGather(const TupleDataLayout &layout, Vector &row_locations, const idx_t col_idx,
- const SelectionVector &scan_sel, const idx_t scan_count, Vector &target,
- const SelectionVector &target_sel, optional_ptr<Vector> cached_cast_vector,
- const vector<TupleDataGatherFunction> &child_functions) {
-
- if (cached_cast_vector) {
- // Reuse the cached cast vector
- TupleDataStructGather(layout, row_locations, col_idx, scan_sel, scan_count, *cached_cast_vector, target_sel,
- cached_cast_vector, child_functions);
- VectorOperations::DefaultCast(*cached_cast_vector, target, scan_count);
- } else {
- // Otherwise, create a new temporary cast vector
- Vector cast_vector(ArrayType::ConvertToList(target.GetType()));
- TupleDataStructGather(layout, row_locations, col_idx, scan_sel, scan_count, cast_vector, target_sel,
- &cast_vector, child_functions);
- VectorOperations::DefaultCast(cast_vector, target, scan_count);
- }
-}
-
-//------------------------------------------------------------------------------
-// Get Gather Function
-//------------------------------------------------------------------------------
-template <class T>
-tuple_data_gather_function_t TupleDataGetGatherFunction(bool within_collection) {
- return within_collection ?
TupleDataTemplatedWithinCollectionGather<T> : TupleDataTemplatedGather<T>;
-}
-
-static TupleDataGatherFunction TupleDataGetGatherFunctionInternal(const LogicalType &type, bool within_collection) {
- TupleDataGatherFunction result;
- switch (type.InternalType()) {
- case PhysicalType::BOOL:
- result.function = TupleDataGetGatherFunction<bool>(within_collection);
- break;
- case PhysicalType::INT8:
- result.function = TupleDataGetGatherFunction<int8_t>(within_collection);
- break;
- case PhysicalType::INT16:
- result.function = TupleDataGetGatherFunction<int16_t>(within_collection);
- break;
- case PhysicalType::INT32:
- result.function = TupleDataGetGatherFunction<int32_t>(within_collection);
- break;
- case PhysicalType::INT64:
- result.function = TupleDataGetGatherFunction<int64_t>(within_collection);
- break;
- case PhysicalType::INT128:
- result.function = TupleDataGetGatherFunction<hugeint_t>(within_collection);
- break;
- case PhysicalType::UINT8:
- result.function = TupleDataGetGatherFunction<uint8_t>(within_collection);
- break;
- case PhysicalType::UINT16:
- result.function = TupleDataGetGatherFunction<uint16_t>(within_collection);
- break;
- case PhysicalType::UINT32:
- result.function = TupleDataGetGatherFunction<uint32_t>(within_collection);
- break;
- case PhysicalType::UINT64:
- result.function = TupleDataGetGatherFunction<uint64_t>(within_collection);
- break;
- case PhysicalType::UINT128:
- result.function = TupleDataGetGatherFunction<uhugeint_t>(within_collection);
- break;
- case PhysicalType::FLOAT:
- result.function = TupleDataGetGatherFunction<float>(within_collection);
- break;
- case PhysicalType::DOUBLE:
- result.function = TupleDataGetGatherFunction<double>(within_collection);
- break;
- case PhysicalType::INTERVAL:
- result.function = TupleDataGetGatherFunction<interval_t>(within_collection);
- break;
- case PhysicalType::VARCHAR:
- result.function = TupleDataGetGatherFunction<string_t>(within_collection);
- break;
- case PhysicalType::STRUCT: {
- result.function = within_collection ? TupleDataStructWithinCollectionGather : TupleDataStructGather;
- for (const auto &child_type : StructType::GetChildTypes(type)) {
- result.child_functions.push_back(TupleDataGetGatherFunctionInternal(child_type.second, within_collection));
- }
- break;
- }
- case PhysicalType::LIST:
- result.function = within_collection ? TupleDataCollectionWithinCollectionGather : TupleDataListGather;
- result.child_functions.push_back(TupleDataGetGatherFunctionInternal(ListType::GetChildType(type), true));
- break;
- case PhysicalType::ARRAY:
- result.function = within_collection ?
TupleDataCollectionWithinCollectionGather : TupleDataListGather;
- result.child_functions.push_back(TupleDataGetGatherFunctionInternal(ArrayType::GetChildType(type), true));
- break;
- default:
- throw InternalException("Unsupported type for TupleDataCollection::GetGatherFunction");
- }
- return result;
-}
-
-TupleDataGatherFunction TupleDataCollection::GetGatherFunction(const LogicalType &type) {
- if (!type.IsNested()) {
- return TupleDataGetGatherFunctionInternal(type, false);
- }
-
- if (TypeVisitor::Contains(type, LogicalTypeId::ARRAY)) {
- // Special case: we can't handle arrays yet, so we need to replace them with lists when gathering
- const auto new_type = ArrayType::ConvertToList(type);
- TupleDataGatherFunction result;
- // There's only two cases: either the array is within a struct, or it is within a list (or has now become a list)
- switch (new_type.InternalType()) {
- case PhysicalType::LIST:
- result.function = TupleDataCastToArrayListGather;
- result.child_functions.push_back(
- TupleDataGetGatherFunctionInternal(ListType::GetChildType(new_type), true));
- return result;
- case PhysicalType::STRUCT:
- result.function = TupleDataCastToArrayStructGather;
- for (const auto &child_type : StructType::GetChildTypes(new_type)) {
- result.child_functions.push_back(TupleDataGetGatherFunctionInternal(child_type.second, false));
- }
- return result;
- default:
- throw InternalException("Unsupported type for TupleDataCollection::GetGatherFunction");
- }
- }
- return TupleDataGetGatherFunctionInternal(type, false);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/types/row/tuple_data_segment.cpp b/src/duckdb/src/common/types/row/tuple_data_segment.cpp
deleted file mode 100644
index ddec13238..000000000
--- a/src/duckdb/src/common/types/row/tuple_data_segment.cpp
+++ /dev/null
@@ -1,190 +0,0 @@
-#include "duckdb/common/types/row/tuple_data_segment.hpp"
-
-#include "duckdb/common/types/row/tuple_data_allocator.hpp"
-#include "duckdb/storage/buffer/buffer_pool.hpp"
-
-namespace duckdb {
-
-TupleDataChunkPart::TupleDataChunkPart(mutex &lock_p) : lock(lock_p) {
-}
-
-void TupleDataChunkPart::SetHeapEmpty() {
- heap_block_index = INVALID_INDEX;
- heap_block_offset = INVALID_INDEX;
- total_heap_size = 0;
- base_heap_ptr = nullptr;
-}
-
-void SwapTupleDataChunkPart(TupleDataChunkPart &a, TupleDataChunkPart &b) {
- std::swap(a.row_block_index, b.row_block_index);
- std::swap(a.row_block_offset, b.row_block_offset);
- std::swap(a.heap_block_index, b.heap_block_index);
- std::swap(a.heap_block_offset, b.heap_block_offset);
- std::swap(a.base_heap_ptr, b.base_heap_ptr);
- std::swap(a.total_heap_size, b.total_heap_size);
- std::swap(a.count, b.count);
- std::swap(a.lock, b.lock);
-}
-
-TupleDataChunkPart::TupleDataChunkPart(TupleDataChunkPart &&other) noexcept : lock(other.lock) {
- SwapTupleDataChunkPart(*this, other);
-}
-
-TupleDataChunkPart &TupleDataChunkPart::operator=(TupleDataChunkPart &&other) noexcept {
- SwapTupleDataChunkPart(*this, other);
- return *this;
-}
-
-TupleDataChunk::TupleDataChunk() : count(0), lock(make_unsafe_uniq<mutex>()) {
- parts.reserve(2);
-}
-
-static inline void SwapTupleDataChunk(TupleDataChunk &a, TupleDataChunk &b) noexcept {
- std::swap(a.parts, b.parts);
- std::swap(a.row_block_ids, b.row_block_ids);
- std::swap(a.heap_block_ids, b.heap_block_ids);
- std::swap(a.count, b.count);
- std::swap(a.lock, b.lock);
-}
-
-TupleDataChunk::TupleDataChunk(TupleDataChunk &&other) noexcept {
- SwapTupleDataChunk(*this, other);
-}
-
-TupleDataChunk
&TupleDataChunk::operator=(TupleDataChunk &&other) noexcept {
- SwapTupleDataChunk(*this, other);
- return *this;
-}
-
-void TupleDataChunk::AddPart(TupleDataChunkPart &&part, const TupleDataLayout &layout) {
- count += part.count;
- row_block_ids.insert(part.row_block_index);
- if (!layout.AllConstant() && part.total_heap_size > 0) {
- heap_block_ids.insert(part.heap_block_index);
- }
- part.lock = *lock;
- parts.emplace_back(std::move(part));
-}
-
-void TupleDataChunk::Verify() const {
-#ifdef DEBUG
- idx_t total_count = 0;
- for (const auto &part : parts) {
- total_count += part.count;
- }
- D_ASSERT(this->count == total_count);
- D_ASSERT(this->count <= STANDARD_VECTOR_SIZE);
-#endif
-}
-
-void TupleDataChunk::MergeLastChunkPart(const TupleDataLayout &layout) {
- if (parts.size() < 2) {
- return;
- }
-
- auto &second_to_last = parts[parts.size() - 2];
- auto &last = parts[parts.size() - 1];
-
- auto rows_align =
- last.row_block_index == second_to_last.row_block_index &&
- last.row_block_offset == second_to_last.row_block_offset + second_to_last.count * layout.GetRowWidth();
-
- if (!rows_align) { // If rows don't align we can never merge
- return;
- }
-
- if (layout.AllConstant()) { // No heap and rows align - merge
- second_to_last.count += last.count;
- parts.pop_back();
- return;
- }
-
- if (last.heap_block_index == second_to_last.heap_block_index &&
- last.heap_block_offset == second_to_last.heap_block_offset + second_to_last.total_heap_size &&
- last.base_heap_ptr == second_to_last.base_heap_ptr) { // There is a heap and it aligns - merge
- second_to_last.total_heap_size += last.total_heap_size;
- second_to_last.count += last.count;
- parts.pop_back();
- }
-}
-
-TupleDataSegment::TupleDataSegment(shared_ptr<TupleDataAllocator> allocator_p)
- : allocator(std::move(allocator_p)), count(0), data_size(0) {
-}
-
-TupleDataSegment::~TupleDataSegment() {
- lock_guard<mutex> guard(pinned_handles_lock);
- if (allocator) {
- allocator->SetDestroyBufferUponUnpin(); // Prevent blocks from being added to eviction queue
- }
- pinned_row_handles.clear();
- pinned_heap_handles.clear();
- if (Allocator::SupportsFlush() && allocator &&
- data_size > allocator->GetBufferManager().GetBufferPool().GetAllocatorBulkDeallocationFlushThreshold()) {
- Allocator::FlushAll();
- }
- allocator.reset();
-}
-
-void SwapTupleDataSegment(TupleDataSegment &a, TupleDataSegment &b) {
- std::swap(a.allocator, b.allocator);
- std::swap(a.chunks, b.chunks);
- std::swap(a.count, b.count);
- std::swap(a.data_size, b.data_size);
- std::swap(a.pinned_row_handles, b.pinned_row_handles);
- std::swap(a.pinned_heap_handles, b.pinned_heap_handles);
}
-
-TupleDataSegment::TupleDataSegment(TupleDataSegment &&other) noexcept {
- SwapTupleDataSegment(*this, other);
-}
-
-TupleDataSegment &TupleDataSegment::operator=(TupleDataSegment &&other) noexcept {
- SwapTupleDataSegment(*this, other);
- return *this;
-}
-
-idx_t TupleDataSegment::ChunkCount() const {
- return chunks.size();
-}
-
-idx_t TupleDataSegment::SizeInBytes() const {
- return data_size;
-}
-
-void TupleDataSegment::Unpin() {
- lock_guard<mutex> guard(pinned_handles_lock);
- pinned_row_handles.clear();
- pinned_heap_handles.clear();
-}
-
-void TupleDataSegment::Verify() const {
-#ifdef DEBUG
- const auto &layout = allocator->GetLayout();
-
- idx_t total_count = 0;
- idx_t total_size = 0;
- for (const auto &chunk : chunks) {
- chunk.Verify();
- total_count += chunk.count;
-
- total_size += chunk.count * layout.GetRowWidth();
- if (!layout.AllConstant()) {
- for (const auto &part : chunk.parts) {
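- // heap (variable-size) bytes count toward data_size as well, so the
- // D_ASSERT below covers fixed-width row bytes and heap bytes together
-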
total_size += part.total_heap_size; - } - } - } - D_ASSERT(total_count == this->count); - D_ASSERT(total_size == this->data_size); -#endif -} - -void TupleDataSegment::VerifyEverythingPinned() const { -#ifdef DEBUG - D_ASSERT(pinned_row_handles.size() == allocator->RowBlockCount()); - D_ASSERT(pinned_heap_handles.size() == allocator->HeapBlockCount()); -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/selection_vector.cpp b/src/duckdb/src/common/types/selection_vector.cpp deleted file mode 100644 index 7b6fda414..000000000 --- a/src/duckdb/src/common/types/selection_vector.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "duckdb/common/types/selection_vector.hpp" - -#include "duckdb/common/printer.hpp" -#include "duckdb/common/to_string.hpp" - -namespace duckdb { - -SelectionData::SelectionData(idx_t count) { - owned_data = make_unsafe_uniq_array_uninitialized(count); -#ifdef DEBUG - for (idx_t i = 0; i < count; i++) { - owned_data[i] = std::numeric_limits::max(); - } -#endif -} - -// LCOV_EXCL_START -string SelectionVector::ToString(idx_t count) const { - string result = "Selection Vector (" + to_string(count) + ") ["; - for (idx_t i = 0; i < count; i++) { - if (i != 0) { - result += ", "; - } - result += to_string(get_index(i)); - } - result += "]"; - return result; -} - -void SelectionVector::Print(idx_t count) const { - Printer::Print(ToString(count)); -} -// LCOV_EXCL_STOP - -buffer_ptr SelectionVector::Slice(const SelectionVector &sel, idx_t count) const { - auto data = make_buffer(count); - auto result_ptr = data->owned_data.get(); - // for every element, we perform result[i] = target[new[i]] - for (idx_t i = 0; i < count; i++) { - auto new_idx = sel.get_index(i); - auto idx = this->get_index(new_idx); - result_ptr[i] = UnsafeNumericCast(idx); - } - return data; -} - -void SelectionVector::Verify(idx_t count, idx_t vector_size) const { -#ifdef DEBUG - D_ASSERT(vector_size >= 1); - for (idx_t i = 0; i < count; i++) { - auto index = get_index(i); - if (index >= vector_size) { - throw InternalException( - "Provided SelectionVector is invalid, index %d points to %d, which is out of range. 
" - "the valid range (0-%d)", - i, index, vector_size - 1); - } - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/string_heap.cpp b/src/duckdb/src/common/types/string_heap.cpp deleted file mode 100644 index be23433dd..000000000 --- a/src/duckdb/src/common/types/string_heap.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include "duckdb/common/types/string_heap.hpp" - -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "utf8proc_wrapper.hpp" - -#include - -namespace duckdb { - -StringHeap::StringHeap(Allocator &allocator) : allocator(allocator) { -} - -void StringHeap::Destroy() { - allocator.Destroy(); -} - -void StringHeap::Move(StringHeap &other) { - other.allocator.Move(allocator); -} - -string_t StringHeap::AddString(const char *data, idx_t len) { - D_ASSERT(Utf8Proc::Analyze(data, len) != UnicodeType::INVALID); - return AddBlob(data, len); -} - -string_t StringHeap::AddString(const char *data) { - return AddString(data, strlen(data)); -} - -string_t StringHeap::AddString(const string &data) { - return AddString(data.c_str(), data.size()); -} - -string_t StringHeap::AddString(const string_t &data) { - return AddString(data.GetData(), data.GetSize()); -} - -string_t StringHeap::AddBlob(const char *data, idx_t len) { - auto insert_string = EmptyString(len); - auto insert_pos = insert_string.GetDataWriteable(); - memcpy(insert_pos, data, len); - insert_string.Finalize(); - return insert_string; -} - -string_t StringHeap::AddBlob(const string_t &data) { - return AddBlob(data.GetData(), data.GetSize()); -} - -string_t StringHeap::EmptyString(idx_t len) { - D_ASSERT(len > string_t::INLINE_LENGTH); - if (len > string_t::MAX_STRING_SIZE) { - throw OutOfRangeException("Cannot create a string of size: '%d', the maximum supported string size is: '%d'", - len, string_t::MAX_STRING_SIZE); - } - auto insert_pos = const_char_ptr_cast(allocator.Allocate(len)); - return string_t(insert_pos, UnsafeNumericCast(len)); -} - -idx_t StringHeap::SizeInBytes() const { - return allocator.SizeInBytes(); -} - -idx_t StringHeap::AllocationSize() const { - return allocator.AllocationSize(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/string_type.cpp b/src/duckdb/src/common/types/string_type.cpp deleted file mode 100644 index f5a236557..000000000 --- a/src/duckdb/src/common/types/string_type.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "duckdb/common/types/string_type.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/value.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -void string_t::Verify() const { -#ifdef DEBUG - VerifyUTF8(); -#endif - - VerifyCharacters(); -} - -void string_t::VerifyUTF8() const { - auto dataptr = GetData(); - (void)dataptr; - D_ASSERT(dataptr); - - auto utf_type = Utf8Proc::Analyze(dataptr, GetSize()); - (void)utf_type; - D_ASSERT(utf_type != UnicodeType::INVALID); -} - -void string_t::VerifyCharacters() const { - auto dataptr = GetData(); - (void)dataptr; - D_ASSERT(dataptr); - - // verify that the prefix contains the first four characters of the string - for (idx_t i = 0; i < MinValue(PREFIX_LENGTH, GetSize()); i++) { - D_ASSERT(GetPrefix()[i] == dataptr[i]); - } - // verify that for strings with length <= INLINE_LENGTH, the rest of the string is zero - for (idx_t i = GetSize(); i < INLINE_LENGTH; i++) { - D_ASSERT(GetData()[i] == '\0'); - } -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/common/types/time.cpp b/src/duckdb/src/common/types/time.cpp deleted file mode 100644 index fa4d135f5..000000000 --- a/src/duckdb/src/common/types/time.cpp +++ /dev/null @@ -1,347 +0,0 @@ -#include "duckdb/common/types/time.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" - -#include -#include -#include - -namespace duckdb { -static_assert(sizeof(dtime_t) == sizeof(int64_t), "dtime_t was padded"); - -// string format is hh:mm:ss.microsecondsZ -// microseconds and Z are optional -// ISO 8601 - -bool Time::TryConvertInternal(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, - optional_ptr nanos) { - int32_t hour = -1, min = -1, sec = -1, micros = -1; - pos = 0; - - if (len == 0) { - return false; - } - - int sep; - - // skip leading spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - - if (pos >= len) { - return false; - } - - if (!StringUtil::CharacterIsDigit(buf[pos])) { - return false; - } - - // Allow up to 9 digit hours to support intervals - hour = 0; - for (int32_t digits = 9; pos < len && StringUtil::CharacterIsDigit(buf[pos]); ++pos) { - if (digits-- > 0) { - hour = hour * 10 + (buf[pos] - '0'); - } else { - return false; - } - } - - if (pos >= len) { - return false; - } - - // fetch the separator - sep = buf[pos++]; - if (sep != ':') { - // invalid separator - return false; - } - idx_t sep_pos = pos; - if (pos == len && !strict) { - min = 0; - } else { - if (!Date::ParseDoubleDigit(buf, len, pos, min)) { - return false; - } - if (min < 0 || min >= 60) { - return false; - } - } - - if (pos > len) { - return false; - } - if (pos == len && (!strict || sep_pos + 2 == pos)) { - sec = 0; - } else { - if (buf[pos++] != sep) { - return false; - } - - if (pos == len && !strict) { - sec = 0; - } else { - if (!Date::ParseDoubleDigit(buf, len, pos, sec)) { - return false; - } - if (sec < 0 || sec >= 60) { - return false; - } - } - } - - micros = 0; - if (pos < len && buf[pos] == '.') { - pos++; - // we expect some microseconds - int32_t mult = 100000; - if (nanos) { - // do we expect nanoseconds? - mult *= Interval::NANOS_PER_MICRO; - } - for (; pos < len && StringUtil::CharacterIsDigit(buf[pos]); pos++, mult /= 10) { - if (mult > 0) { - micros += (buf[pos] - '0') * mult; - } - } - if (nanos) { - *nanos = UnsafeNumericCast(micros % Interval::NANOS_PER_MICRO); - micros /= Interval::NANOS_PER_MICRO; - } - } - - // in strict mode, check remaining string for non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - // check position. 
if end was not reached, non-space chars remaining - if (pos < len) { - return false; - } - } - - result = Time::FromTime(hour, min, sec, micros); - return true; -} - -bool Time::TryConvertInterval(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, - optional_ptr nanos) { - return Time::TryConvertInternal(buf, len, pos, result, strict, nanos); -} - -bool Time::TryConvertTime(const char *buf, idx_t len, idx_t &pos, dtime_t &result, bool strict, - optional_ptr nanos) { - if (!Time::TryConvertInternal(buf, len, pos, result, strict, nanos)) { - if (!strict) { - // last chance, check if we can parse as timestamp - timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp, nanos) == TimestampCastResult::SUCCESS) { - if (!Timestamp::IsFinite(timestamp)) { - return false; - } - result = Timestamp::GetTime(timestamp); - return true; - } - } - return false; - } - return result.micros <= Interval::MICROS_PER_DAY; -} - -bool Time::TryConvertTimeTZ(const char *buf, idx_t len, idx_t &pos, dtime_tz_t &result, bool &has_offset, bool strict, - optional_ptr nanos) { - dtime_t time_part; - has_offset = false; - if (!Time::TryConvertInternal(buf, len, pos, time_part, false, nanos)) { - if (!strict) { - // last chance, check if we can parse as timestamp - timestamp_t timestamp; - if (Timestamp::TryConvertTimestamp(buf, len, timestamp, nanos) == TimestampCastResult::SUCCESS) { - if (!Timestamp::IsFinite(timestamp)) { - return false; - } - result = dtime_tz_t(Timestamp::GetTime(timestamp), 0); - return true; - } - } - return false; - } - - // skip optional whitespace before offset - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - - // Get the ±HH[:MM] part - int hh = 0; - int mm = 0; - has_offset = (pos < len); - if (has_offset && !Timestamp::TryParseUTCOffset(buf, pos, len, hh, mm)) { - return false; - } - - // Offsets are in seconds in the open interval (-16:00:00, +16:00:00) - int32_t offset = ((hh * Interval::MINS_PER_HOUR) + mm) * Interval::SECS_PER_MINUTE; - - // Check for trailing seconds. - // (PG claims they don't support this but they do...) - if (pos < len && buf[pos] == ':') { - ++pos; - int ss = 0; - if (!Date::ParseDoubleDigit(buf, len, pos, ss)) { - return false; - } - offset += (offset < 0) ? -ss : ss; - } - - if (offset < dtime_tz_t::MIN_OFFSET || offset > dtime_tz_t::MAX_OFFSET) { - return false; - } - - // in strict mode, check remaining string for non-space characters - if (strict) { - // skip trailing spaces - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } - // check position. 
if end was not reached, non-space chars remaining - if (pos < len) { - return false; - } - } - - result = dtime_tz_t(time_part, offset); - - return true; -} - -dtime_t Time::NormalizeTimeTZ(dtime_tz_t timetz) { - date_t date(0); - return Interval::Add(timetz.time(), {0, 0, -timetz.offset() * Interval::MICROS_PER_SEC}, date); -} - -string Time::ConversionError(const string &str) { - return StringUtil::Format("time field value out of range: \"%s\", " - "expected format is ([YYYY-MM-DD ]HH:MM:SS[.MS])", - str); -} - -string Time::ConversionError(string_t str) { - return Time::ConversionError(str.GetString()); -} - -dtime_t Time::FromCString(const char *buf, idx_t len, bool strict, optional_ptr nanos) { - dtime_t result; - idx_t pos; - if (!Time::TryConvertTime(buf, len, pos, result, strict, nanos)) { - throw ConversionException(ConversionError(string(buf, len))); - } - return result; -} - -dtime_t Time::FromString(const string &str, bool strict, optional_ptr nanos) { - return Time::FromCString(str.c_str(), str.size(), strict, nanos); -} - -string Time::ToString(dtime_t time) { - int32_t time_units[4]; - Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); - - char micro_buffer[6]; - auto length = TimeToStringCast::Length(time_units, micro_buffer); - auto buffer = make_unsafe_uniq_array_uninitialized(length); - TimeToStringCast::Format(buffer.get(), length, time_units, micro_buffer); - return string(buffer.get(), length); -} - -string Time::ToUTCOffset(int hour_offset, int minute_offset) { - dtime_t time((hour_offset * Interval::MINS_PER_HOUR + minute_offset) * Interval::MICROS_PER_MINUTE); - - char buffer[1 + 2 + 1 + 2]; - idx_t length = 0; - buffer[length++] = (time.micros < 0 ? '-' : '+'); - time.micros = std::abs(time.micros); - - int32_t time_units[4]; - Time::Convert(time, time_units[0], time_units[1], time_units[2], time_units[3]); - - TimeToStringCast::FormatTwoDigits(buffer + length, time_units[0]); - length += 2; - if (time_units[1]) { - buffer[length++] = ':'; - TimeToStringCast::FormatTwoDigits(buffer + length, time_units[1]); - length += 2; - } - - return string(buffer, length); -} - -dtime_t Time::FromTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { - int64_t result; - result = hour; // hours - result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes - result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds - result = result * Interval::MICROS_PER_SEC + microseconds; // seconds -> microseconds - return dtime_t(result); -} - -int64_t Time::ToNanoTime(int32_t hour, int32_t minute, int32_t second, int32_t nanoseconds) { - int64_t result; - result = hour; // hours - result = result * Interval::MINS_PER_HOUR + minute; // hours -> minutes - result = result * Interval::SECS_PER_MINUTE + second; // minutes -> seconds - result = result * Interval::NANOS_PER_SEC + nanoseconds; // seconds -> nanoseconds - return result; -} - -bool Time::IsValidTime(int32_t hour, int32_t minute, int32_t second, int32_t microseconds) { - if (hour < 0 || hour >= 24) { - return (hour == 24) && (minute == 0) && (second == 0) && (microseconds == 0); - } - if (minute < 0 || minute >= 60) { - return false; - } - if (second < 0 || second > 60) { - return false; - } - if (microseconds < 0 || microseconds > 1000000) { - return false; - } - return true; -} - -void Time::Convert(dtime_t dtime, int32_t &hour, int32_t &min, int32_t &sec, int32_t µs) { - int64_t time = dtime.micros; - hour = int32_t(time / Interval::MICROS_PER_HOUR); - 
- time -= int64_t(hour) * Interval::MICROS_PER_HOUR;
- min = int32_t(time / Interval::MICROS_PER_MINUTE);
- time -= int64_t(min) * Interval::MICROS_PER_MINUTE;
- sec = int32_t(time / Interval::MICROS_PER_SEC);
- time -= int64_t(sec) * Interval::MICROS_PER_SEC;
- micros = int32_t(time);
- D_ASSERT(Time::IsValidTime(hour, min, sec, micros));
-}
-
-dtime_t Time::FromTimeMs(int64_t time_ms) {
- int64_t result;
- if (!TryMultiplyOperator::Operation(time_ms, Interval::MICROS_PER_MSEC, result)) {
- throw ConversionException("Could not convert Time(MS) to Time(US)");
- }
- return dtime_t(result);
-}
-
-dtime_t Time::FromTimeNs(int64_t time_ns) {
- return dtime_t(time_ns / Interval::NANOS_PER_MICRO);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/types/timestamp.cpp b/src/duckdb/src/common/types/timestamp.cpp
deleted file mode 100644
index d596e7530..000000000
--- a/src/duckdb/src/common/types/timestamp.cpp
+++ /dev/null
@@ -1,507 +0,0 @@
-#include "duckdb/common/types/timestamp.hpp"
-
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/types/date.hpp"
-#include "duckdb/common/types/interval.hpp"
-#include "duckdb/common/types/time.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/common/chrono.hpp"
-#include "duckdb/common/operator/add.hpp"
-#include "duckdb/common/operator/multiply.hpp"
-#include "duckdb/common/operator/subtract.hpp"
-#include "duckdb/common/exception/conversion_exception.hpp"
-#include "duckdb/common/limits.hpp"
-#include <ctime>
-
-namespace duckdb {
-
-static_assert(sizeof(timestamp_t) == sizeof(int64_t), "timestamp_t was padded");
-
-// Temporal values need to round down when changing precision,
-// but C/C++ rounds towards 0 when you simply divide.
-// This piece of bit banging solves that problem.
-template <class T>
-static inline T TemporalRound(T value, T scale) {
- const auto negative = int(value < 0);
- return UnsafeNumericCast<T>((value + negative) / scale - negative);
}
-
-// timestamp/datetime uses 64 bits, high 32 bits for date and low 32 bits for time
-// string format is YYYY-MM-DDThh:mm:ssZ
-// T may be a space
-// Z is optional
-// ISO 8601
-
-// arithmetic operators
-timestamp_t timestamp_t::operator+(const double &value) const {
- timestamp_t result;
- if (!TryAddOperator::Operation(this->value, int64_t(value), result.value)) {
- throw OutOfRangeException("Overflow in timestamp addition");
- }
- return result;
-}
-
-int64_t timestamp_t::operator-(const timestamp_t &other) const {
- int64_t result;
- if (!TrySubtractOperator::Operation(value, int64_t(other.value), result)) {
- throw OutOfRangeException("Overflow in timestamp subtraction");
- }
- return result;
-}
-
-// in-place operators
-timestamp_t &timestamp_t::operator+=(const int64_t &delta) {
- if (!TryAddOperator::Operation(value, delta, value)) {
- throw OutOfRangeException("Overflow in timestamp increment");
- }
- return *this;
-}
-
-timestamp_t &timestamp_t::operator-=(const int64_t &delta) {
- if (!TrySubtractOperator::Operation(value, delta, value)) {
- throw OutOfRangeException("Overflow in timestamp decrement");
- }
- return *this;
-}
-
-TimestampCastResult Timestamp::TryConvertTimestampTZ(const char *str, idx_t len, timestamp_t &result, bool &has_offset,
- string_t &tz, optional_ptr<int32_t> nanos) {
- idx_t pos;
- date_t date;
- dtime_t time;
- has_offset = false;
- switch (Date::TryConvertDate(str, len, pos, date, has_offset)) {
- case DateCastResult::ERROR_INCORRECT_FORMAT:
- return TimestampCastResult::ERROR_INCORRECT_FORMAT;
- case DateCastResult::ERROR_RANGE:
- return
TimestampCastResult::ERROR_RANGE; - default: - break; - } - if (pos == len) { - // no time: only a date or special - if (date == date_t::infinity()) { - result = timestamp_t::infinity(); - return TimestampCastResult::SUCCESS; - } else if (date == date_t::ninfinity()) { - result = timestamp_t::ninfinity(); - return TimestampCastResult::SUCCESS; - } - return Timestamp::TryFromDatetime(date, dtime_t(0), result) ? TimestampCastResult::SUCCESS - : TimestampCastResult::ERROR_RANGE; - } - // try to parse a time field - if (str[pos] == ' ' || str[pos] == 'T') { - pos++; - } - idx_t time_pos = 0; - // TryConvertTime may recursively call us, so we opt for a stricter - // operation. Note that we can't pass strict== true here because we - // want to process any suffix. - if (!Time::TryConvertInterval(str + pos, len - pos, time_pos, time, false, nanos)) { - return TimestampCastResult::ERROR_INCORRECT_FORMAT; - } - // We parsed an interval, so make sure it is in range. - if (time.micros > Interval::MICROS_PER_DAY) { - return TimestampCastResult::ERROR_RANGE; - } - pos += time_pos; - if (!Timestamp::TryFromDatetime(date, time, result)) { - return TimestampCastResult::ERROR_RANGE; - } - if (pos < len) { - // skip a "Z" at the end (as per the ISO8601 specs) - int hour_offset, minute_offset; - if (str[pos] == 'Z') { - pos++; - has_offset = true; - } else if (Timestamp::TryParseUTCOffset(str, pos, len, hour_offset, minute_offset)) { - const int64_t delta = hour_offset * Interval::MICROS_PER_HOUR + minute_offset * Interval::MICROS_PER_MINUTE; - if (!TrySubtractOperator::Operation(result.value, delta, result.value)) { - return TimestampCastResult::ERROR_RANGE; - } - has_offset = true; - } else { - // Parse a time zone: / [A-Za-z0-9/_]+/ - if (str[pos++] != ' ') { - return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; - } - auto tz_name = str + pos; - for (; pos < len && CharacterIsTimeZone(str[pos]); ++pos) { - continue; - } - auto tz_len = str + pos - tz_name; - if (tz_len) { - tz = string_t(tz_name, UnsafeNumericCast(tz_len)); - } - // Note that the caller must reinterpret the instant we return to the given time zone - } - - // skip any spaces at the end - while (pos < len && StringUtil::CharacterIsSpace(str[pos])) { - pos++; - } - if (pos < len) { - return TimestampCastResult::ERROR_INCORRECT_FORMAT; - } - } - return TimestampCastResult::SUCCESS; -} - -TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_t &result, - optional_ptr nanos) { - string_t tz(nullptr, 0); - bool has_offset = false; - // We don't understand TZ without an extension, so fail if one was provided. - auto success = TryConvertTimestampTZ(str, len, result, has_offset, tz, nanos); - if (success != TimestampCastResult::SUCCESS) { - return success; - } - if (tz.GetSize() == 0) { - // no timezone provided - success! 
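- // (i.e. the parser consumed the whole input as date, time, and optional numeric offset)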
- return TimestampCastResult::SUCCESS; - } - if (tz.GetSize() == 3) { - // we can ONLY handle UTC without ICU being loaded - auto tz_ptr = tz.GetData(); - if ((tz_ptr[0] == 'u' || tz_ptr[0] == 'U') && (tz_ptr[1] == 't' || tz_ptr[1] == 'T') && - (tz_ptr[2] == 'c' || tz_ptr[2] == 'C')) { - return TimestampCastResult::SUCCESS; - } - } - return TimestampCastResult::ERROR_NON_UTC_TIMEZONE; -} - -bool Timestamp::TryFromTimestampNanos(timestamp_t input, int32_t nanos, timestamp_ns_t &result) { - if (!IsFinite(input)) { - result.value = input.value; - return true; - } - // Scale to ns - if (!TryMultiplyOperator::Operation(input.value, Interval::NANOS_PER_MICRO, result.value)) { - return false; - } - - if (!TryAddOperator::Operation(result.value, int64_t(nanos), result.value)) { - return false; - } - - return IsFinite(result); -} - -TimestampCastResult Timestamp::TryConvertTimestamp(const char *str, idx_t len, timestamp_ns_t &result) { - int32_t nanos = 0; - auto success = TryConvertTimestamp(str, len, result, &nanos); - if (success != TimestampCastResult::SUCCESS) { - return success; - } - if (!TryFromTimestampNanos(result, nanos, result)) { - return TimestampCastResult::ERROR_INCORRECT_FORMAT; - } - return TimestampCastResult::SUCCESS; -} - -string Timestamp::FormatError(const string &str) { - return StringUtil::Format("invalid timestamp field format: \"%s\", " - "expected format is (YYYY-MM-DD HH:MM:SS[.US][±HH:MM| ZONE])", - str); -} - -string Timestamp::UnsupportedTimezoneError(const string &str) { - return StringUtil::Format("timestamp field value \"%s\" has a timestamp that is not UTC.\nUse the TIMESTAMPTZ type " - "with the ICU extension loaded to handle non-UTC timestamps.", - str); -} - -string Timestamp::RangeError(const string &str) { - return StringUtil::Format("timestamp field value out of range: \"%s\"", str); -} - -string Timestamp::FormatError(string_t str) { - return Timestamp::FormatError(str.GetString()); -} - -string Timestamp::UnsupportedTimezoneError(string_t str) { - return Timestamp::UnsupportedTimezoneError(str.GetString()); -} - -string Timestamp::RangeError(string_t str) { - return Timestamp::RangeError(str.GetString()); -} - -timestamp_t Timestamp::FromCString(const char *str, idx_t len, optional_ptr nanos) { - timestamp_t result; - switch (Timestamp::TryConvertTimestamp(str, len, result, nanos)) { - case TimestampCastResult::SUCCESS: - break; - case TimestampCastResult::ERROR_NON_UTC_TIMEZONE: - throw ConversionException(UnsupportedTimezoneError(string(str, len))); - case TimestampCastResult::ERROR_INCORRECT_FORMAT: - throw ConversionException(FormatError(string(str, len))); - case TimestampCastResult::ERROR_RANGE: - throw ConversionException(RangeError(string(str, len))); - } - return result; -} - -bool Timestamp::TryParseUTCOffset(const char *str, idx_t &pos, idx_t len, int &hour_offset, int &minute_offset) { - minute_offset = 0; - idx_t curpos = pos; - // parse the next 3 characters - if (curpos + 3 > len) { - // no characters left to parse - return false; - } - char sign_char = str[curpos]; - if (sign_char != '+' && sign_char != '-') { - // expected either + or - - return false; - } - curpos++; - if (!StringUtil::CharacterIsDigit(str[curpos]) || !StringUtil::CharacterIsDigit(str[curpos + 1])) { - // expected +HH or -HH - return false; - } - hour_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); - if (sign_char == '-') { - hour_offset = -hour_offset; - } - curpos += 2; - - // optional minute specifier: expected either "MM" or ":MM" - if (curpos >= len) { 
- // done, nothing left - pos = curpos; - return true; - } - if (str[curpos] == ':') { - curpos++; - } - if (curpos + 2 > len || !StringUtil::CharacterIsDigit(str[curpos]) || - !StringUtil::CharacterIsDigit(str[curpos + 1])) { - // no MM specifier - pos = curpos; - return true; - } - // we have an MM specifier: parse it - minute_offset = (str[curpos] - '0') * 10 + (str[curpos + 1] - '0'); - if (sign_char == '-') { - minute_offset = -minute_offset; - } - pos = curpos + 2; - return true; -} - -timestamp_t Timestamp::FromString(const string &str) { - return Timestamp::FromCString(str.c_str(), str.size()); -} - -string Timestamp::ToString(timestamp_t timestamp) { - if (timestamp == timestamp_t::infinity()) { - return Date::PINF; - } - if (timestamp == timestamp_t::ninfinity()) { - return Date::NINF; - } - - date_t date; - dtime_t time; - Timestamp::Convert(timestamp, date, time); - return Date::ToString(date) + " " + Time::ToString(time); -} - -date_t Timestamp::GetDate(timestamp_t timestamp) { - if (DUCKDB_UNLIKELY(timestamp == timestamp_t::infinity())) { - return date_t::infinity(); - } - if (DUCKDB_UNLIKELY(timestamp == timestamp_t::ninfinity())) { - return date_t::ninfinity(); - } - return date_t(UnsafeNumericCast((timestamp.value + (timestamp.value < 0)) / Interval::MICROS_PER_DAY - - (timestamp.value < 0))); -} - -dtime_t Timestamp::GetTime(timestamp_t timestamp) { - if (!IsFinite(timestamp)) { - throw ConversionException("Can't get TIME of infinite TIMESTAMP"); - } - date_t date = Timestamp::GetDate(timestamp); - return dtime_t(timestamp.value - (int64_t(date.days) * int64_t(Interval::MICROS_PER_DAY))); -} - -bool Timestamp::TryFromDatetime(date_t date, dtime_t time, timestamp_t &result) { - if (!TryMultiplyOperator::Operation(date.days, Interval::MICROS_PER_DAY, result.value)) { - return false; - } - if (!TryAddOperator::Operation(result.value, time.micros, result.value)) { - return false; - } - return Timestamp::IsFinite(result); -} - -bool Timestamp::TryFromDatetime(date_t date, dtime_tz_t timetz, timestamp_t &result) { - if (!TryFromDatetime(date, timetz.time(), result)) { - return false; - } - // Offset is in seconds - const auto offset = int64_t(timetz.offset() * Interval::MICROS_PER_SEC); - if (!TryAddOperator::Operation(result.value, -offset, result.value)) { - return false; - } - return Timestamp::IsFinite(result); -} - -timestamp_t Timestamp::FromDatetime(date_t date, dtime_t time) { - timestamp_t result; - if (!TryFromDatetime(date, time, result)) { - throw ConversionException("Date and time not in timestamp range"); - } - return result; -} - -void Timestamp::Convert(timestamp_t timestamp, date_t &out_date, dtime_t &out_time) { - out_date = GetDate(timestamp); - int64_t days_micros; - if (!TryMultiplyOperator::Operation(out_date.days, Interval::MICROS_PER_DAY, - days_micros)) { - throw ConversionException("Date out of range in timestamp conversion"); - } - out_time = dtime_t(timestamp.value - days_micros); - D_ASSERT(timestamp == Timestamp::FromDatetime(out_date, out_time)); -} - -void Timestamp::Convert(timestamp_ns_t input, date_t &out_date, dtime_t &out_time, int32_t &out_nanos) { - timestamp_t ms(TemporalRound(input.value, Interval::NANOS_PER_MICRO)); - out_date = Timestamp::GetDate(ms); - int64_t days_nanos; - if (!TryMultiplyOperator::Operation(out_date.days, Interval::NANOS_PER_DAY, - days_nanos)) { - throw ConversionException("Date out of range in timestamp_ns conversion"); - } - - out_time = dtime_t((input.value - days_nanos) / Interval::NANOS_PER_MICRO); - out_nanos 
= UnsafeNumericCast<int32_t>((input.value - days_nanos) % Interval::NANOS_PER_MICRO);
-}
-
-timestamp_t Timestamp::GetCurrentTimestamp() {
- auto now = system_clock::now();
- auto epoch_ms = duration_cast<milliseconds>(now.time_since_epoch()).count();
- return Timestamp::FromEpochMs(epoch_ms);
-}
-
-timestamp_t Timestamp::FromEpochSecondsPossiblyInfinite(int64_t sec) {
- int64_t result;
- if (!TryMultiplyOperator::Operation(sec, Interval::MICROS_PER_SEC, result)) {
- throw ConversionException("Could not convert Timestamp(S) to Timestamp(US)");
- }
- return timestamp_t(result);
-}
-
-timestamp_t Timestamp::FromEpochSeconds(int64_t sec) {
- D_ASSERT(Timestamp::IsFinite(timestamp_t(sec)));
- return FromEpochSecondsPossiblyInfinite(sec);
-}
-
-timestamp_t Timestamp::FromEpochMsPossiblyInfinite(int64_t ms) {
- int64_t result;
- if (!TryMultiplyOperator::Operation(ms, Interval::MICROS_PER_MSEC, result)) {
- throw ConversionException("Could not convert Timestamp(MS) to Timestamp(US)");
- }
- return timestamp_t(result);
-}
-
-timestamp_t Timestamp::FromEpochMs(int64_t ms) {
- D_ASSERT(Timestamp::IsFinite(timestamp_t(ms)));
- return FromEpochMsPossiblyInfinite(ms);
-}
-
-timestamp_t Timestamp::FromEpochMicroSeconds(int64_t micros) {
- return timestamp_t(micros);
-}
-
-timestamp_t Timestamp::FromEpochNanoSecondsPossiblyInfinite(int64_t ns) {
- return timestamp_t(ns / Interval::NANOS_PER_MICRO);
-}
-
-timestamp_t Timestamp::FromEpochNanoSeconds(int64_t ns) {
- D_ASSERT(Timestamp::IsFinite(timestamp_t(ns)));
- return FromEpochNanoSecondsPossiblyInfinite(ns);
-}
-
-timestamp_ns_t Timestamp::TimestampNsFromEpochMillis(int64_t millis) {
- D_ASSERT(Timestamp::IsFinite(timestamp_t(millis)));
- timestamp_ns_t result;
- if (!TryMultiplyOperator::Operation(millis, Interval::NANOS_PER_MSEC, result.value)) {
- throw ConversionException("Could not convert Timestamp(MS) to Timestamp(NS)");
- }
- return result;
-}
-
-timestamp_ns_t Timestamp::TimestampNsFromEpochMicros(int64_t micros) {
- D_ASSERT(Timestamp::IsFinite(timestamp_t(micros)));
- timestamp_ns_t result;
- if (!TryMultiplyOperator::Operation(micros, Interval::NANOS_PER_MICRO, result.value)) {
- throw ConversionException("Could not convert Timestamp(US) to Timestamp(NS)");
- }
- return result;
-}
-
-int64_t Timestamp::GetEpochSeconds(timestamp_t timestamp) {
- D_ASSERT(Timestamp::IsFinite(timestamp));
- return timestamp.value / Interval::MICROS_PER_SEC;
-}
-
-int64_t Timestamp::GetEpochMs(timestamp_t timestamp) {
- D_ASSERT(Timestamp::IsFinite(timestamp));
- return timestamp.value / Interval::MICROS_PER_MSEC;
-}
-
-int64_t Timestamp::GetEpochMicroSeconds(timestamp_t timestamp) {
- return timestamp.value;
-}
-
-bool Timestamp::TryGetEpochNanoSeconds(timestamp_t timestamp, int64_t &result) {
- D_ASSERT(Timestamp::IsFinite(timestamp));
- if (!TryMultiplyOperator::Operation(timestamp.value, Interval::NANOS_PER_MICRO, result)) {
- return false;
- }
- return true;
-}
-
-int64_t Timestamp::GetEpochNanoSeconds(timestamp_t timestamp) {
- int64_t result;
- D_ASSERT(Timestamp::IsFinite(timestamp));
- if (!TryGetEpochNanoSeconds(timestamp, result)) {
- throw ConversionException("Could not convert Timestamp(US) to Timestamp(NS)");
- }
- return result;
-}
-
-int64_t Timestamp::GetEpochNanoSeconds(timestamp_ns_t timestamp) {
- D_ASSERT(Timestamp::IsFinite(timestamp));
- return timestamp.value;
-}
-
-int64_t Timestamp::GetEpochRounded(timestamp_t input, int64_t power_of_ten) {
- D_ASSERT(Timestamp::IsFinite(input));
- // Round away from the epoch.
- // Scale first so we don't overflow.
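- // for non-negative input = q * power_of_ten + r, dividing by power_of_ten / 2 yields
- // 2q (r < half) or 2q + 1 (r >= half); adding 1 and halving then gives q or q + 1,
- // i.e. round-half-away-from-zero (the negative branch below mirrors this)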
- const auto scaling = power_of_ten / 2; - input.value /= scaling; - if (input.value < 0) { - --input.value; - } else { - ++input.value; - } - input.value /= 2; - return input.value; -} - -double Timestamp::GetJulianDay(timestamp_t timestamp) { - double result = double(Timestamp::GetTime(timestamp).micros); - result /= Interval::MICROS_PER_DAY; - result += double(Date::ExtractJulianDay(Timestamp::GetDate(timestamp))); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/uhugeint.cpp b/src/duckdb/src/common/types/uhugeint.cpp deleted file mode 100644 index 7469dd809..000000000 --- a/src/duckdb/src/common/types/uhugeint.cpp +++ /dev/null @@ -1,747 +0,0 @@ -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/windows_undefs.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// String Conversion -//===--------------------------------------------------------------------===// -const uhugeint_t Uhugeint::POWERS_OF_TEN[] { - uhugeint_t(1), - uhugeint_t(10), - uhugeint_t(100), - uhugeint_t(1000), - uhugeint_t(10000), - uhugeint_t(100000), - uhugeint_t(1000000), - uhugeint_t(10000000), - uhugeint_t(100000000), - uhugeint_t(1000000000), - uhugeint_t(10000000000), - uhugeint_t(100000000000), - uhugeint_t(1000000000000), - uhugeint_t(10000000000000), - uhugeint_t(100000000000000), - uhugeint_t(1000000000000000), - uhugeint_t(10000000000000000), - uhugeint_t(100000000000000000), - uhugeint_t(1000000000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(10), - uhugeint_t(1000000000000000000) * uhugeint_t(100), - uhugeint_t(1000000000000000000) * uhugeint_t(1000), - uhugeint_t(1000000000000000000) * uhugeint_t(10000), - uhugeint_t(1000000000000000000) * uhugeint_t(100000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000), - uhugeint_t(1000000000000000000) * uhugeint_t(10000000), - uhugeint_t(1000000000000000000) * uhugeint_t(100000000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(10000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(100000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(10000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(100000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(10000000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(100000000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000000000000), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000000000000) * uhugeint_t(10), - uhugeint_t(1000000000000000000) * uhugeint_t(1000000000000000000) * uhugeint_t(100)}; - -string Uhugeint::ToString(uhugeint_t input) { - uhugeint_t remainder; - string result; - while (true) { - if (!input.lower && !input.upper) { - break; - } - input = Uhugeint::DivMod(input, 10, remainder); - result = string(1, UnsafeNumericCast('0' + remainder.lower)) + result; // NOLINT - } - if (result.empty()) { - // value is zero - return "0"; - } - return result; -} - 
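The loop above peels one decimal digit per DivMod(input, 10) call. For comparison, a minimal standalone sketch of the same idea, using the compiler-provided unsigned __int128 (GCC/Clang) instead of uhugeint_t and Uhugeint::DivMod; the name U128ToString is illustrative, not part of duckdb:

    #include <string>

    // Sketch only: mirrors Uhugeint::ToString's repeated divide-by-10, but on the
    // GCC/Clang built-in unsigned __int128 rather than duckdb's uhugeint_t.
    static std::string U128ToString(unsigned __int128 value) {
        std::string result;
        while (value != 0) {
            // value % 10 plays the role of DivMod's remainder, value / 10 the quotient
            result.insert(result.begin(), char('0' + int(value % 10)));
            value /= 10;
        }
        // the loop emits nothing for zero, matching the "value is zero" case above
        return result.empty() ? std::string("0") : result;
    }

Prepending into a std::string is quadratic in the digit count, but a 128-bit value has at most 39 decimal digits, so the simple form suffices.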
-//===--------------------------------------------------------------------===// -// Negate -//===--------------------------------------------------------------------===// - -template <> -void Uhugeint::NegateInPlace(uhugeint_t &input) { - uhugeint_t result = 0; - result -= input; - input = result; -} - -bool Uhugeint::TryNegate(uhugeint_t input, uhugeint_t &result) { - // unsigned integers can always be negated - Uhugeint::NegateInPlace(input); - result = input; - return true; -} - -//===--------------------------------------------------------------------===// -// Multiply -//===--------------------------------------------------------------------===// -bool Uhugeint::TryMultiply(uhugeint_t lhs, uhugeint_t rhs, uhugeint_t &result) { -#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) - __uint128_t left = __uint128_t(lhs.lower) + (__uint128_t(lhs.upper) << 64); - __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); - __uint128_t result_u128; - if (__builtin_mul_overflow(left, right, &result_u128)) { - return false; - } - result.upper = uint64_t(result_u128 >> 64); - result.lower = uint64_t(result_u128 & 0xffffffffffffffff); -#else - // split values into 4 32-bit parts - uint64_t top[4] = {lhs.upper >> 32, lhs.upper & 0xffffffff, lhs.lower >> 32, lhs.lower & 0xffffffff}; - uint64_t bottom[4] = {rhs.upper >> 32, rhs.upper & 0xffffffff, rhs.lower >> 32, rhs.lower & 0xffffffff}; - uint64_t products[4][4]; - - // multiply each component of the values - for (int y = 3; y > -1; y--) { - for (int x = 3; x > -1; x--) { - products[3 - x][y] = top[x] * bottom[y]; - } - } - - // if any of these products are set to a non-zero value, there is always an overflow - if (products[2][1] || products[1][0] || products[2][0]) { - return false; - } - - // if the high bits of any of these are set, there is always an overflow - if (products[1][1] & 0xffffffff00000000 || products[3][0] & 0xffffffff00000000 || - products[3][3] & 0xffffffff00000000 || products[3][2] & 0xffffffff00000000 || - products[3][1] & 0xffffffff00000000 || products[2][2] & 0xffffffff00000000 || - products[0][0] & 0xffffffff00000000) { - return false; - } - - // first row - uint64_t fourth32 = (products[0][3] & 0xffffffff); - uint64_t third32 = (products[0][2] & 0xffffffff) + (products[0][3] >> 32); - uint64_t second32 = (products[0][1] & 0xffffffff) + (products[0][2] >> 32); - uint64_t first32 = (products[0][0] & 0xffffffff) + (products[0][1] >> 32); - - // second row - third32 += (products[1][3] & 0xffffffff); - second32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); - first32 += (products[1][1] & 0xffffffff) + (products[1][2] >> 32); - - // third row - second32 += (products[2][3] & 0xffffffff); - first32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); - - // fourth row - first32 += (products[3][3] & 0xffffffff); - - // move carry to next digit - third32 += fourth32 >> 32; - second32 += third32 >> 32; - first32 += second32 >> 32; - - // remove carry from current digit - fourth32 &= 0xffffffff; - third32 &= 0xffffffff; - second32 &= 0xffffffff; - first32 &= 0xffffffff; - - // combine components - result.lower = (third32 << 32) | fourth32; - result.upper = (first32 << 32) | second32; -#endif - return true; -} - -// No overflow check, will wrap -template <> -uhugeint_t Uhugeint::Multiply(uhugeint_t lhs, uhugeint_t rhs) { - uhugeint_t result; -#if ((__GNUC__ >= 5) || defined(__clang__)) && defined(__SIZEOF_INT128__) - __uint128_t left = __uint128_t(lhs.lower) + 
(__uint128_t(lhs.upper) << 64); - __uint128_t right = __uint128_t(rhs.lower) + (__uint128_t(rhs.upper) << 64); - __uint128_t result_u128; - - result_u128 = left * right; - result.upper = uint64_t(result_u128 >> 64); - result.lower = uint64_t(result_u128 & 0xffffffffffffffff); -#else - // split values into 4 32-bit parts - uint64_t top[4] = {lhs.upper >> 32, lhs.upper & 0xffffffff, lhs.lower >> 32, lhs.lower & 0xffffffff}; - uint64_t bottom[4] = {rhs.upper >> 32, rhs.upper & 0xffffffff, rhs.lower >> 32, rhs.lower & 0xffffffff}; - uint64_t products[4][4]; - - // multiply each component of the values - for (int y = 3; y > -1; y--) { - for (int x = 3; x > -1; x--) { - products[3 - x][y] = top[x] * bottom[y]; - } - } - - // first row - uint64_t fourth32 = (products[0][3] & 0xffffffff); - uint64_t third32 = (products[0][2] & 0xffffffff) + (products[0][3] >> 32); - uint64_t second32 = (products[0][1] & 0xffffffff) + (products[0][2] >> 32); - uint64_t first32 = (products[0][0] & 0xffffffff) + (products[0][1] >> 32); - - // second row - third32 += (products[1][3] & 0xffffffff); - second32 += (products[1][2] & 0xffffffff) + (products[1][3] >> 32); - first32 += (products[1][1] & 0xffffffff) + (products[1][2] >> 32); - - // third row - second32 += (products[2][3] & 0xffffffff); - first32 += (products[2][2] & 0xffffffff) + (products[2][3] >> 32); - - // fourth row - first32 += (products[3][3] & 0xffffffff); - - // move carry to next digit - third32 += fourth32 >> 32; - second32 += third32 >> 32; - first32 += second32 >> 32; - - // remove carry from current digit - fourth32 &= 0xffffffff; - third32 &= 0xffffffff; - second32 &= 0xffffffff; - first32 &= 0xffffffff; - - // combine components - result.lower = (third32 << 32) | fourth32; - result.upper = (first32 << 32) | second32; -#endif - return result; -} - -//===--------------------------------------------------------------------===// -// Divide -//===--------------------------------------------------------------------===// - -int Sign(uhugeint_t n) { - return (n > 0); -} - -uhugeint_t Abs(uhugeint_t n) { - return (n); -} - -static uint8_t Bits(uhugeint_t x) { - uint8_t out = 0; - if (x.upper) { - out = 64; - for (uint64_t upper = x.upper; upper; upper >>= 1) { - ++out; - } - } else { - for (uint64_t lower = x.lower; lower; lower >>= 1) { - ++out; - } - } - return out; -} - -uhugeint_t Uhugeint::DivMod(uhugeint_t lhs, uhugeint_t rhs, uhugeint_t &remainder) { - if (rhs == 0) { - remainder = lhs; - return uhugeint_t(0); - } - - remainder = uhugeint_t(0); - if (rhs == uhugeint_t(1)) { - return lhs; - } else if (lhs == rhs) { - return uhugeint_t(1); - } else if (lhs == uhugeint_t(0) || lhs < rhs) { - remainder = lhs; - return uhugeint_t(0); - } - - uhugeint_t result = 0; - for (uint8_t idx = Bits(lhs); idx > 0; --idx) { - result <<= 1; - remainder <<= 1; - - if (((lhs >> (idx - 1U)) & 1) != 0) { - remainder += 1; - } - - if (remainder >= rhs) { - remainder -= rhs; - result += 1; - } - } - return result; -} - -template <> -uhugeint_t Uhugeint::Divide(uhugeint_t lhs, uhugeint_t rhs) { - uhugeint_t remainder; - return Uhugeint::DivMod(lhs, rhs, remainder); -} - -template <> -uhugeint_t Uhugeint::Modulo(uhugeint_t lhs, uhugeint_t rhs) { - uhugeint_t remainder; - (void)Uhugeint::DivMod(lhs, rhs, remainder); - return remainder; -} - -//===--------------------------------------------------------------------===// -// Add/Subtract -//===--------------------------------------------------------------------===// -bool Uhugeint::TryAddInPlace(uhugeint_t &lhs, uhugeint_t 
rhs) { - uint64_t new_upper = lhs.upper + rhs.upper; - bool no_overflow = !(new_upper < lhs.upper || new_upper < rhs.upper); - new_upper += (lhs.lower + rhs.lower) < lhs.lower; - if (new_upper < lhs.upper || new_upper < rhs.upper) { - no_overflow = false; - } - lhs.upper = new_upper; - lhs.lower += rhs.lower; - return no_overflow; -} - -bool Uhugeint::TrySubtractInPlace(uhugeint_t &lhs, uhugeint_t rhs) { - uint64_t new_upper = lhs.upper - rhs.upper - ((lhs.lower - rhs.lower) > lhs.lower); - bool no_overflow = !(new_upper > lhs.upper); - lhs.lower -= rhs.lower; - lhs.upper = new_upper; - return no_overflow; -} - -template <> -uhugeint_t Uhugeint::Add(uhugeint_t lhs, uhugeint_t rhs) { - return lhs + rhs; -} - -template <> -uhugeint_t Uhugeint::Subtract(uhugeint_t lhs, uhugeint_t rhs) { - return lhs - rhs; -} - -//===--------------------------------------------------------------------===// -// Cast/Conversion -//===--------------------------------------------------------------------===// -template -bool UhugeintTryCastInteger(uhugeint_t input, DST &result) { - if (input.upper == 0 && input.lower <= uint64_t(NumericLimits::Maximum())) { - result = DST(input.lower); - return true; - } - return false; -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, int8_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, int16_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, int32_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, int64_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, uint8_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, uint16_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, uint32_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, uint64_t &result) { - return UhugeintTryCastInteger(input, result); -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, uhugeint_t &result) { - result = input; - return true; -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, hugeint_t &result) { - if (input > uhugeint_t(NumericLimits::Maximum())) { - return false; - } - - result.lower = input.lower; - result.upper = UnsafeNumericCast(input.upper); - return true; -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, float &result) { - double dbl_result; - Uhugeint::TryCast(input, dbl_result); - result = (float)dbl_result; - return true; -} - -template -bool CastUhugeintToFloating(uhugeint_t input, REAL_T &result) { - result = REAL_T(input.lower) + REAL_T(input.upper) * REAL_T(NumericLimits::Maximum()); - return true; -} - -template <> -bool Uhugeint::TryCast(uhugeint_t input, double &result) { - return CastUhugeintToFloating(input, result); -} - -template -uhugeint_t UhugeintConvertInteger(DST input) { - uhugeint_t result; - result.lower = (uint64_t)input; - result.upper = 0; - return result; -} - -template <> -bool Uhugeint::TryConvert(const char *value, uhugeint_t &result) { - auto len = strlen(value); - string_t string_val(value, UnsafeNumericCast(len)); - return TryCast::Operation(string_val, result, true); -} - -template <> -bool Uhugeint::TryConvert(int8_t value, uhugeint_t &result) { - 
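- // negative inputs cannot be represented by the unsigned uhugeint_t, so fail the conversion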
if (value < 0) { - return false; - } - result = UhugeintConvertInteger(value); - return true; -} - -template <> -bool Uhugeint::TryConvert(int16_t value, uhugeint_t &result) { - if (value < 0) { - return false; - } - result = UhugeintConvertInteger(value); - return true; -} - -template <> -bool Uhugeint::TryConvert(int32_t value, uhugeint_t &result) { - if (value < 0) { - return false; - } - result = UhugeintConvertInteger(value); - return true; -} - -template <> -bool Uhugeint::TryConvert(int64_t value, uhugeint_t &result) { - if (value < 0) { - return false; - } - result = UhugeintConvertInteger(value); - return true; -} -template <> -bool Uhugeint::TryConvert(uint8_t value, uhugeint_t &result) { - result = UhugeintConvertInteger(value); - return true; -} -template <> -bool Uhugeint::TryConvert(uint16_t value, uhugeint_t &result) { - result = UhugeintConvertInteger(value); - return true; -} -template <> -bool Uhugeint::TryConvert(uint32_t value, uhugeint_t &result) { - result = UhugeintConvertInteger(value); - return true; -} -template <> -bool Uhugeint::TryConvert(uint64_t value, uhugeint_t &result) { - result = UhugeintConvertInteger(value); - return true; -} - -template <> -bool Uhugeint::TryConvert(uhugeint_t value, uhugeint_t &result) { - result = value; - return true; -} - -template <> -bool Uhugeint::TryConvert(float value, uhugeint_t &result) { - return Uhugeint::TryConvert(double(value), result); -} - -template -bool ConvertFloatingToUhugeint(REAL_T value, uhugeint_t &result) { - if (!Value::IsFinite(value)) { - return false; - } - if (value < 0 || value >= 340282366920938463463374607431768211456.0) { - return false; - } - result.lower = (uint64_t)fmod(value, REAL_T(NumericLimits::Maximum())); - result.upper = (uint64_t)(value / REAL_T(NumericLimits::Maximum())); - return true; -} - -template <> -bool Uhugeint::TryConvert(double value, uhugeint_t &result) { - return ConvertFloatingToUhugeint(value, result); -} - -template <> -bool Uhugeint::TryConvert(long double value, uhugeint_t &result) { - return ConvertFloatingToUhugeint(value, result); -} - -//===--------------------------------------------------------------------===// -// uhugeint_t operators -//===--------------------------------------------------------------------===// -uhugeint_t::uhugeint_t(uint64_t value) { - this->lower = value; - this->upper = 0; -} - -bool uhugeint_t::operator==(const uhugeint_t &rhs) const { - return Uhugeint::Equals(*this, rhs); -} - -bool uhugeint_t::operator!=(const uhugeint_t &rhs) const { - return Uhugeint::NotEquals(*this, rhs); -} - -bool uhugeint_t::operator<(const uhugeint_t &rhs) const { - return Uhugeint::LessThan(*this, rhs); -} - -bool uhugeint_t::operator<=(const uhugeint_t &rhs) const { - return Uhugeint::LessThanEquals(*this, rhs); -} - -bool uhugeint_t::operator>(const uhugeint_t &rhs) const { - return Uhugeint::GreaterThan(*this, rhs); -} - -bool uhugeint_t::operator>=(const uhugeint_t &rhs) const { - return Uhugeint::GreaterThanEquals(*this, rhs); -} - -uhugeint_t uhugeint_t::operator+(const uhugeint_t &rhs) const { - return uhugeint_t(upper + rhs.upper + ((lower + rhs.lower) < lower), lower + rhs.lower); -} - -uhugeint_t uhugeint_t::operator-(const uhugeint_t &rhs) const { - return uhugeint_t(upper - rhs.upper - ((lower - rhs.lower) > lower), lower - rhs.lower); -} - -uhugeint_t uhugeint_t::operator*(const uhugeint_t &rhs) const { - uhugeint_t result = *this; - result *= rhs; - return result; -} - -uhugeint_t uhugeint_t::operator/(const uhugeint_t &rhs) const { - return 
Uhugeint::Divide(*this, rhs); -} - -uhugeint_t uhugeint_t::operator%(const uhugeint_t &rhs) const { - return Uhugeint::Modulo(*this, rhs); -} - -uhugeint_t uhugeint_t::operator-() const { - return Uhugeint::Negate(*this); -} - -uhugeint_t uhugeint_t::operator>>(const uhugeint_t &rhs) const { - const uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return uhugeint_t(0); - } else if (shift == 0) { - return *this; - } else if (shift == 64) { - return uhugeint_t(0, upper); - } else if (shift < 64) { - return uhugeint_t(upper >> shift, (upper << (64 - shift)) + (lower >> shift)); - } else if ((128 > shift) && (shift > 64)) { - return uhugeint_t(0, (upper >> (shift - 64))); - } - return uhugeint_t(0); -} - -uhugeint_t uhugeint_t::operator<<(const uhugeint_t &rhs) const { - const uint64_t shift = rhs.lower; - if (rhs.upper != 0 || shift >= 128) { - return uhugeint_t(0); - } else if (shift == 0) { - return *this; - } else if (shift == 64) { - return uhugeint_t(lower, 0); - } else if (shift < 64) { - return uhugeint_t((upper << shift) + (lower >> (64 - shift)), lower << shift); - } else if ((128 > shift) && (shift > 64)) { - return uhugeint_t(lower << (shift - 64), 0); - } - return uhugeint_t(0); -} - -uhugeint_t uhugeint_t::operator&(const uhugeint_t &rhs) const { - uhugeint_t result; - result.lower = lower & rhs.lower; - result.upper = upper & rhs.upper; - return result; -} - -uhugeint_t uhugeint_t::operator|(const uhugeint_t &rhs) const { - uhugeint_t result; - result.lower = lower | rhs.lower; - result.upper = upper | rhs.upper; - return result; -} - -uhugeint_t uhugeint_t::operator^(const uhugeint_t &rhs) const { - uhugeint_t result; - result.lower = lower ^ rhs.lower; - result.upper = upper ^ rhs.upper; - return result; -} - -uhugeint_t uhugeint_t::operator~() const { - uhugeint_t result; - result.lower = ~lower; - result.upper = ~upper; - return result; -} - -uhugeint_t &uhugeint_t::operator+=(const uhugeint_t &rhs) { - *this = *this + rhs; - return *this; -} - -uhugeint_t &uhugeint_t::operator-=(const uhugeint_t &rhs) { - *this = *this - rhs; - return *this; -} - -uhugeint_t &uhugeint_t::operator*=(const uhugeint_t &rhs) { - *this = Uhugeint::Multiply(*this, rhs); - return *this; -} - -uhugeint_t &uhugeint_t::operator/=(const uhugeint_t &rhs) { - *this = Uhugeint::Divide(*this, rhs); - return *this; -} - -uhugeint_t &uhugeint_t::operator%=(const uhugeint_t &rhs) { - *this = Uhugeint::Modulo(*this, rhs); - return *this; -} - -uhugeint_t &uhugeint_t::operator>>=(const uhugeint_t &rhs) { - *this = *this >> rhs; - return *this; -} - -uhugeint_t &uhugeint_t::operator<<=(const uhugeint_t &rhs) { - *this = *this << rhs; - return *this; -} - -uhugeint_t &uhugeint_t::operator&=(const uhugeint_t &rhs) { - lower &= rhs.lower; - upper &= rhs.upper; - return *this; -} - -uhugeint_t &uhugeint_t::operator|=(const uhugeint_t &rhs) { - lower |= rhs.lower; - upper |= rhs.upper; - return *this; -} - -uhugeint_t &uhugeint_t::operator^=(const uhugeint_t &rhs) { - lower ^= rhs.lower; - upper ^= rhs.upper; - return *this; -} - -bool uhugeint_t::operator!() const { - return *this == 0; -} - -uhugeint_t::operator bool() const { - return *this != 0; -} - -template -static T NarrowCast(const uhugeint_t &input) { - // NarrowCast is supposed to truncate (take lower) - return static_cast(input.lower); -} - -uhugeint_t::operator uint8_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator uint16_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator uint32_t() const { - return 
NarrowCast(*this); -} -uhugeint_t::operator uint64_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator int8_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator int16_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator int32_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator int64_t() const { - return NarrowCast(*this); -} -uhugeint_t::operator hugeint_t() const { - return {static_cast(this->upper), this->lower}; -} - -string uhugeint_t::ToString() const { - return Uhugeint::ToString(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/uuid.cpp b/src/duckdb/src/common/types/uuid.cpp deleted file mode 100644 index a0196ae81..000000000 --- a/src/duckdb/src/common/types/uuid.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/random_engine.hpp" - -namespace duckdb { - -bool UUID::FromString(const string &str, hugeint_t &result) { - auto hex2char = [](char ch) -> unsigned char { - if (ch >= '0' && ch <= '9') { - return UnsafeNumericCast(ch - '0'); - } - if (ch >= 'a' && ch <= 'f') { - return UnsafeNumericCast(10 + ch - 'a'); - } - if (ch >= 'A' && ch <= 'F') { - return UnsafeNumericCast(10 + ch - 'A'); - } - return 0; - }; - auto is_hex = [](char ch) -> bool { - return (ch >= '0' && ch <= '9') || (ch >= 'a' && ch <= 'f') || (ch >= 'A' && ch <= 'F'); - }; - - if (str.empty()) { - return false; - } - idx_t has_braces = 0; - if (str.front() == '{') { - has_braces = 1; - } - if (has_braces && str.back() != '}') { - return false; - } - - result.lower = 0; - result.upper = 0; - size_t count = 0; - for (size_t i = has_braces; i < str.size() - has_braces; ++i) { - if (str[i] == '-') { - continue; - } - if (count >= 32 || !is_hex(str[i])) { - return false; - } - if (count >= 16) { - result.lower = (result.lower << 4) | hex2char(str[i]); - } else { - result.upper = (result.upper << 4) | hex2char(str[i]); - } - count++; - } - // Flip the first bit to make `order by uuid` same as `order by uuid::varchar` - result.upper ^= NumericLimits::Minimum(); - return count == 32; -} - -void UUID::ToString(hugeint_t input, char *buf) { - auto byte_to_hex = [](uint64_t byte_val, char *buf, idx_t &pos) { - D_ASSERT(byte_val <= 0xFF); - static char const HEX_DIGITS[] = "0123456789abcdef"; - buf[pos++] = HEX_DIGITS[(byte_val >> 4) & 0xf]; - buf[pos++] = HEX_DIGITS[byte_val & 0xf]; - }; - - // Flip back before convert to string - int64_t upper = int64_t(uint64_t(input.upper) ^ (uint64_t(1) << 63)); - idx_t pos = 0; - byte_to_hex(upper >> 56 & 0xFF, buf, pos); - byte_to_hex(upper >> 48 & 0xFF, buf, pos); - byte_to_hex(upper >> 40 & 0xFF, buf, pos); - byte_to_hex(upper >> 32 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(upper >> 24 & 0xFF, buf, pos); - byte_to_hex(upper >> 16 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(upper >> 8 & 0xFF, buf, pos); - byte_to_hex(upper & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(input.lower >> 56 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 48 & 0xFF, buf, pos); - buf[pos++] = '-'; - byte_to_hex(input.lower >> 40 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 32 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 24 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 16 & 0xFF, buf, pos); - byte_to_hex(input.lower >> 8 & 0xFF, buf, pos); - byte_to_hex(input.lower & 0xFF, buf, pos); -} - -hugeint_t UUID::FromUHugeint(uhugeint_t input) { - hugeint_t result; - result.lower = input.lower; - if (input.upper > uint64_t(NumericLimits::Maximum())) { - 
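- // upper half is at least 2^63, so subtracting 2^63 (i.e. Maximum() + 1) lands it in the non-negative int64_t range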
result.upper = int64_t(input.upper - uint64_t(NumericLimits::Maximum()) - 1); - } else { - result.upper = int64_t(input.upper) - NumericLimits::Maximum() - 1; - } - return result; -} - -uhugeint_t UUID::ToUHugeint(hugeint_t input) { - uhugeint_t result; - result.lower = input.lower; - if (input.upper >= 0) { - result.upper = uint64_t(input.upper) + uint64_t(NumericLimits::Maximum()) + 1; - } else { - result.upper = uint64_t(input.upper + NumericLimits::Maximum() + 1); - } - return result; -} - -hugeint_t UUID::GenerateRandomUUID(RandomEngine &engine) { - uint8_t bytes[16]; - for (int i = 0; i < 16; i += 4) { - *reinterpret_cast(bytes + i) = engine.NextRandomInteger(); - } - // variant must be 10xxxxxx - bytes[8] &= 0xBF; - bytes[8] |= 0x80; - // version must be 0100xxxx - bytes[6] &= 0x4F; - bytes[6] |= 0x40; - - hugeint_t result; - result.upper = 0; - result.upper |= ((int64_t)bytes[0] << 56); - result.upper |= ((int64_t)bytes[1] << 48); - result.upper |= ((int64_t)bytes[2] << 40); - result.upper |= ((int64_t)bytes[3] << 32); - result.upper |= ((int64_t)bytes[4] << 24); - result.upper |= ((int64_t)bytes[5] << 16); - result.upper |= ((int64_t)bytes[6] << 8); - result.upper |= bytes[7]; - result.lower = 0; - result.lower |= ((uint64_t)bytes[8] << 56); - result.lower |= ((uint64_t)bytes[9] << 48); - result.lower |= ((uint64_t)bytes[10] << 40); - result.lower |= ((uint64_t)bytes[11] << 32); - result.lower |= ((uint64_t)bytes[12] << 24); - result.lower |= ((uint64_t)bytes[13] << 16); - result.lower |= ((uint64_t)bytes[14] << 8); - result.lower |= bytes[15]; - return result; -} - -hugeint_t UUID::GenerateRandomUUID() { - RandomEngine engine; - return GenerateRandomUUID(engine); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/validity_mask.cpp b/src/duckdb/src/common/types/validity_mask.cpp deleted file mode 100644 index c4a526b59..000000000 --- a/src/duckdb/src/common/types/validity_mask.cpp +++ /dev/null @@ -1,283 +0,0 @@ -#include "duckdb/common/types/validity_mask.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/serializer/write_stream.hpp" -#include "duckdb/common/serializer/read_stream.hpp" -#include "duckdb/common/types/selection_vector.hpp" - -namespace duckdb { - -ValidityData::ValidityData(idx_t count) : TemplatedValidityData(count) { -} -ValidityData::ValidityData(const ValidityMask &original, idx_t count) - : TemplatedValidityData(original.GetData(), count) { -} - -void ValidityMask::Combine(const ValidityMask &other, idx_t count) { - if (other.AllValid()) { - // X & 1 = X - return; - } - if (AllValid()) { - // 1 & Y = Y - Initialize(other); - return; - } - if (validity_mask == other.validity_mask) { - // X & X == X - return; - } - // have to merge - // create a new validity mask that contains the combined mask - auto owned_data = std::move(validity_data); - auto data = GetData(); - auto other_data = other.GetData(); - - Initialize(count); - auto result_data = GetData(); - - auto entry_count = ValidityData::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - result_data[entry_idx] = data[entry_idx] & other_data[entry_idx]; - } -} - -// LCOV_EXCL_START -string ValidityMask::ToString(idx_t count) const { - string result = "Validity Mask (" + to_string(count) + ") ["; - for (idx_t i = 0; i < count; i++) { - result += RowIsValid(i) ? "." 
: "X"; - } - result += "]"; - return result; -} - -string ValidityMask::ToString() const { - return ValidityMask::ToString(capacity); -} -// LCOV_EXCL_STOP - -void ValidityMask::Resize(idx_t new_size) { - idx_t old_size = capacity; - if (new_size <= old_size) { - return; - } - capacity = new_size; - if (validity_mask) { - auto new_size_count = EntryCount(new_size); - auto old_size_count = EntryCount(old_size); - auto new_validity_data = make_buffer(new_size); - auto new_owned_data = new_validity_data->owned_data.get(); - for (idx_t entry_idx = 0; entry_idx < old_size_count; entry_idx++) { - new_owned_data[entry_idx] = validity_mask[entry_idx]; - } - for (idx_t entry_idx = old_size_count; entry_idx < new_size_count; entry_idx++) { - new_owned_data[entry_idx] = ValidityData::MAX_ENTRY; - } - validity_data = std::move(new_validity_data); - validity_mask = validity_data->owned_data.get(); - } -} - -idx_t ValidityMask::Capacity() const { - return capacity; -} - -void ValidityMask::Slice(const ValidityMask &other, idx_t source_offset, idx_t count) { - if (other.AllValid()) { - validity_mask = nullptr; - validity_data.reset(); - return; - } - if (source_offset == 0) { - Initialize(other); - return; - } - ValidityMask new_mask(count); - new_mask.SliceInPlace(other, 0, source_offset, count); - Initialize(new_mask); -} - -bool ValidityMask::IsAligned(idx_t count) { - return count % BITS_PER_VALUE == 0; -} - -void ValidityMask::CopySel(const ValidityMask &other, const SelectionVector &sel, idx_t source_offset, - idx_t target_offset, idx_t copy_count) { - if (!other.IsMaskSet() && !IsMaskSet()) { - // no need to copy anything if neither has any null values - return; - } - - if (!sel.IsSet() && IsAligned(source_offset) && IsAligned(target_offset)) { - // common case where we are shifting into an aligned mask using a flat vector - SliceInPlace(other, target_offset, source_offset, copy_count); - return; - } - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel.get_index(source_offset + i); - Set(target_offset + i, other.RowIsValid(source_idx)); - } -} - -void ValidityMask::SliceInPlace(const ValidityMask &other, idx_t target_offset, idx_t source_offset, idx_t count) { - if (AllValid() && other.AllValid()) { - // Both validity masks are uninitialized, nothing to do - return; - } - EnsureWritable(); - const idx_t ragged = count % BITS_PER_VALUE; - const idx_t entire_units = count / BITS_PER_VALUE; - if (IsAligned(source_offset) && IsAligned(target_offset)) { - auto target_validity = GetData(); - auto source_validity = other.GetData(); - auto source_offset_entries = EntryCount(source_offset); - auto target_offset_entries = EntryCount(target_offset); - if (!source_validity) { - // if source has no validity mask - set all bytes to 1 - memset(target_validity + target_offset_entries, 0xFF, sizeof(validity_t) * entire_units); - } else { - memcpy(target_validity + target_offset_entries, source_validity + source_offset_entries, - sizeof(validity_t) * entire_units); - } - if (ragged) { - auto src_entry = - source_validity ? 
source_validity[source_offset_entries + entire_units] : ValidityBuffer::MAX_ENTRY; - src_entry &= (ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - ragged)); - - target_validity += target_offset_entries + entire_units; - auto tgt_entry = *target_validity; - tgt_entry &= (ValidityBuffer::MAX_ENTRY << ragged); - - *target_validity = tgt_entry | src_entry; - } - return; - } else if (IsAligned(target_offset)) { - // Simple common case where we are shifting into an aligned mask (e.g., 0 in Slice above) - const idx_t tail = source_offset % BITS_PER_VALUE; - const idx_t head = BITS_PER_VALUE - tail; - auto source_validity = other.GetData() + (source_offset / BITS_PER_VALUE); - auto target_validity = this->GetData() + (target_offset / BITS_PER_VALUE); - auto src_entry = *source_validity++; - for (idx_t i = 0; i < entire_units; ++i) { - // Start with head of previous src - validity_t tgt_entry = src_entry >> tail; - src_entry = *source_validity++; - // Add in tail of current src - tgt_entry |= (src_entry << head); - *target_validity++ = tgt_entry; - } - // Finish last ragged entry - if (ragged) { - // Start with head of previous src - validity_t tgt_entry = (src_entry >> tail); - // Add in the tail of the next src, if head was too small - if (head < ragged) { - src_entry = *source_validity++; - tgt_entry |= (src_entry << head); - } - // Mask off the bits that go past the ragged end - tgt_entry &= (ValidityBuffer::MAX_ENTRY >> (BITS_PER_VALUE - ragged)); - // Restore the ragged end of the target - tgt_entry |= *target_validity & (ValidityBuffer::MAX_ENTRY << ragged); - *target_validity++ = tgt_entry; - } - return; - } - - // FIXME: use bitwise operations here -#if 1 - for (idx_t i = 0; i < count; i++) { - Set(target_offset + i, other.RowIsValid(source_offset + i)); - } -#else - // first shift the "whole" units - idx_t entire_units = offset / BITS_PER_VALUE; - idx_t sub_units = offset - entire_units * BITS_PER_VALUE; - if (entire_units > 0) { - idx_t validity_idx; - for (validity_idx = 0; validity_idx + entire_units < STANDARD_ENTRY_COUNT; validity_idx++) { - new_mask.validity_mask[validity_idx] = other.validity_mask[validity_idx + entire_units]; - } - } - // now we shift the remaining sub units - // this gets a bit more complicated because we have to shift over the borders of the entries - // e.g. suppose we have 2 entries of length 4 and we left-shift by two - // 0101|1010 - // a regular left-shift of both gets us: - // 0100|1000 - // we then OR the overflow (right-shifted by BITS_PER_VALUE - offset) together to get the correct result - // 0100|1000 -> - // 0110|1000 - if (sub_units > 0) { - idx_t validity_idx; - for (validity_idx = 0; validity_idx + 1 < STANDARD_ENTRY_COUNT; validity_idx++) { - new_mask.validity_mask[validity_idx] = - (other.validity_mask[validity_idx] >> sub_units) | - (other.validity_mask[validity_idx + 1] << (BITS_PER_VALUE - sub_units)); - } - new_mask.validity_mask[validity_idx] >>= sub_units; - } -#ifdef DEBUG - for (idx_t i = offset; i < STANDARD_VECTOR_SIZE; i++) { - D_ASSERT(new_mask.RowIsValid(i - offset) == other.RowIsValid(i)); - } - Initialize(new_mask); -#endif -#endif -} - -enum class ValiditySerialization : uint8_t { BITMASK = 0, VALID_VALUES = 1, INVALID_VALUES = 2 }; - -void ValidityMask::Write(WriteStream &writer, idx_t count) { - auto valid_values = CountValid(count); - auto invalid_values = count - valid_values; - auto bitmask_bytes = ValidityMask::ValidityMaskSize(count); - auto need_u32 = count >= NumericLimits::Maximum(); - auto bytes_per_value = need_u32 ? 
sizeof(uint32_t) : sizeof(uint16_t); - auto valid_value_size = bytes_per_value * valid_values + sizeof(uint32_t); - auto invalid_value_size = bytes_per_value * invalid_values + sizeof(uint32_t); - if (valid_value_size < bitmask_bytes || invalid_value_size < bitmask_bytes) { - auto serialize_valid = valid_value_size < invalid_value_size; - // serialize (in)valid value indexes as [COUNT][V0][V1][...][VN] - auto flag = serialize_valid ? ValiditySerialization::VALID_VALUES : ValiditySerialization::INVALID_VALUES; - writer.Write(flag); - writer.Write(NumericCast(MinValue(valid_values, invalid_values))); - for (idx_t i = 0; i < count; i++) { - if (RowIsValid(i) == serialize_valid) { - if (need_u32) { - writer.Write(UnsafeNumericCast(i)); - } else { - writer.Write(UnsafeNumericCast(i)); - } - } - } - } else { - // serialize the entire bitmask - writer.Write(ValiditySerialization::BITMASK); - writer.WriteData(const_data_ptr_cast(GetData()), bitmask_bytes); - } -} - -void ValidityMask::Read(ReadStream &reader, idx_t count) { - Initialize(count); - // deserialize the storage type - auto flag = reader.Read(); - if (flag == ValiditySerialization::BITMASK) { - // deserialize the bitmask - reader.ReadData(data_ptr_cast(GetData()), ValidityMask::ValidityMaskSize(count)); - return; - } - auto is_u32 = count >= NumericLimits::Maximum(); - auto is_valid = flag == ValiditySerialization::VALID_VALUES; - auto serialize_count = reader.Read(); - if (is_valid) { - SetAllInvalid(count); - } - for (idx_t i = 0; i < serialize_count; i++) { - idx_t index = is_u32 ? reader.Read() : reader.Read(); - Set(index, is_valid); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/value.cpp b/src/duckdb/src/common/types/value.cpp deleted file mode 100644 index 4e8566c9a..000000000 --- a/src/duckdb/src/common/types/value.cpp +++ /dev/null @@ -1,2257 +0,0 @@ -#include "duckdb/common/types/value.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -#include "duckdb/common/uhugeint.hpp" -#include "utf8proc_wrapper.hpp" -#include "duckdb/common/operator/numeric_binary_operators.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/types/blob.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/main/error_manager.hpp" -#include "duckdb/common/types/varint.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Extra Value Info -//===--------------------------------------------------------------------===// -enum class ExtraValueInfoType : uint8_t { INVALID_TYPE_INFO = 0, STRING_VALUE_INFO = 1, NESTED_VALUE_INFO = 2 }; - -struct ExtraValueInfo { - explicit ExtraValueInfo(ExtraValueInfoType type) : type(type) { - } - 
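- // virtual destructor: infos are held and destroyed through ExtraValueInfo pointers, so derived infos (string, nested) need it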
virtual ~ExtraValueInfo() { - } - - ExtraValueInfoType type; - -public: - bool Equals(ExtraValueInfo *other_p) const { - if (!other_p) { - return false; - } - if (type != other_p->type) { - return false; - } - return EqualsInternal(other_p); - } - - template - T &Get() { - if (type != T::TYPE) { - throw InternalException("ExtraValueInfo type mismatch"); - } - return (T &)*this; - } - -protected: - virtual bool EqualsInternal(ExtraValueInfo *other_p) const { - return true; - } -}; - -//===--------------------------------------------------------------------===// -// String Value Info -//===--------------------------------------------------------------------===// -struct StringValueInfo : public ExtraValueInfo { - static constexpr const ExtraValueInfoType TYPE = ExtraValueInfoType::STRING_VALUE_INFO; - -public: - explicit StringValueInfo(string str_p) - : ExtraValueInfo(ExtraValueInfoType::STRING_VALUE_INFO), str(std::move(str_p)) { - } - - const string &GetString() { - return str; - } - -protected: - bool EqualsInternal(ExtraValueInfo *other_p) const override { - return other_p->Get().str == str; - } - - string str; -}; - -//===--------------------------------------------------------------------===// -// Nested Value Info -//===--------------------------------------------------------------------===// -struct NestedValueInfo : public ExtraValueInfo { - static constexpr const ExtraValueInfoType TYPE = ExtraValueInfoType::NESTED_VALUE_INFO; - -public: - NestedValueInfo() : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO) { - } - explicit NestedValueInfo(vector values_p) - : ExtraValueInfo(ExtraValueInfoType::NESTED_VALUE_INFO), values(std::move(values_p)) { - } - - const vector &GetValues() { - return values; - } - -protected: - bool EqualsInternal(ExtraValueInfo *other_p) const override { - return other_p->Get().values == values; - } - - vector values; -}; -//===--------------------------------------------------------------------===// -// Value -//===--------------------------------------------------------------------===// -Value::Value(LogicalType type) : type_(std::move(type)), is_null(true) { -} - -Value::Value(int32_t val) : type_(LogicalType::INTEGER), is_null(false) { - value_.integer = val; -} - -Value::Value(bool val) : type_(LogicalType::BOOLEAN), is_null(false) { - value_.boolean = val; -} - -Value::Value(int64_t val) : type_(LogicalType::BIGINT), is_null(false) { - value_.bigint = val; -} - -Value::Value(float val) : type_(LogicalType::FLOAT), is_null(false) { - value_.float_ = val; -} - -Value::Value(double val) : type_(LogicalType::DOUBLE), is_null(false) { - value_.double_ = val; -} - -Value::Value(const char *val) : Value(val ? 
string(val) : string()) { -} - -Value::Value(std::nullptr_t val) : Value(LogicalType::VARCHAR) { -} - -Value::Value(string_t val) : Value(val.GetString()) { -} - -Value::Value(string val) : type_(LogicalType::VARCHAR), is_null(false) { - if (!Value::StringIsValid(val.c_str(), val.size())) { - throw ErrorManager::InvalidUnicodeError(val, "value construction"); - } - value_info_ = make_shared_ptr(std::move(val)); -} - -Value::~Value() { -} - -Value::Value(const Value &other) - : type_(other.type_), is_null(other.is_null), value_(other.value_), value_info_(other.value_info_) { -} - -Value::Value(Value &&other) noexcept - : type_(std::move(other.type_)), is_null(other.is_null), value_(other.value_), - value_info_(std::move(other.value_info_)) { -} - -Value &Value::operator=(const Value &other) { - if (this == &other) { - return *this; - } - type_ = other.type_; - is_null = other.is_null; - value_ = other.value_; - value_info_ = other.value_info_; - return *this; -} - -Value &Value::operator=(Value &&other) noexcept { - type_ = std::move(other.type_); - is_null = other.is_null; - value_ = other.value_; - value_info_ = std::move(other.value_info_); - return *this; -} - -Value Value::MinimumValue(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(false); - case LogicalTypeId::TINYINT: - return Value::TINYINT(NumericLimits::Minimum()); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(NumericLimits::Minimum()); - case LogicalTypeId::INTEGER: - case LogicalTypeId::SQLNULL: - return Value::INTEGER(NumericLimits::Minimum()); - case LogicalTypeId::BIGINT: - return Value::BIGINT(NumericLimits::Minimum()); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(NumericLimits::Minimum()); - case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(NumericLimits::Minimum()); - case LogicalTypeId::UUID: - return Value::UUID(NumericLimits::Minimum()); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(NumericLimits::Minimum()); - case LogicalTypeId::USMALLINT: - return Value::USMALLINT(NumericLimits::Minimum()); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(NumericLimits::Minimum()); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(NumericLimits::Minimum()); - case LogicalTypeId::DATE: - return Value::DATE(Date::FromDate(Date::DATE_MIN_YEAR, Date::DATE_MIN_MONTH, Date::DATE_MIN_DAY)); - case LogicalTypeId::TIME: - return Value::TIME(dtime_t(0)); - case LogicalTypeId::TIMESTAMP: { - const auto date = Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY); - return Value::TIMESTAMP(date, dtime_t(0)); - } - case LogicalTypeId::TIMESTAMP_SEC: { - // Get the minimum timestamp and cast it to timestamp_sec_t. - const auto min_ts = MinimumValue(LogicalType::TIMESTAMP).GetValue(); - const auto ts = Cast::Operation(min_ts); - return Value::TIMESTAMPSEC(ts); - } - case LogicalTypeId::TIMESTAMP_MS: { - // Get the minimum timestamp and cast it to timestamp_ms_t. - const auto min_ts = MinimumValue(LogicalType::TIMESTAMP).GetValue(); - const auto ts = Cast::Operation(min_ts); - return Value::TIMESTAMPMS(ts); - } - case LogicalTypeId::TIMESTAMP_NS: { - // Clear the fractional day. 
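- // (dividing the negative minimum truncates toward zero, so rounding to whole days keeps the value representable)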
- auto min_ns = NumericLimits::Minimum(); - min_ns /= Interval::NANOS_PER_DAY; - min_ns *= Interval::NANOS_PER_DAY; - return Value::TIMESTAMPNS(timestamp_ns_t(min_ns)); - } - case LogicalTypeId::TIME_TZ: - // "00:00:00+1559" from the PG docs, but actually 00:00:00+15:59:59 - return Value::TIMETZ(dtime_tz_t(dtime_t(0), dtime_tz_t::MAX_OFFSET)); - case LogicalTypeId::TIMESTAMP_TZ: { - const auto date = Date::FromDate(Timestamp::MIN_YEAR, Timestamp::MIN_MONTH, Timestamp::MIN_DAY); - const auto ts = Timestamp::FromDatetime(date, dtime_t(0)); - return Value::TIMESTAMPTZ(timestamp_tz_t(ts)); - } - case LogicalTypeId::FLOAT: - return Value::FLOAT(NumericLimits::Minimum()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(NumericLimits::Minimum()); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(int16_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(int32_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(int64_t(-NumericHelper::POWERS_OF_TEN[width] + 1), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(-Hugeint::POWERS_OF_TEN[width] + 1, width, scale); - default: - throw InternalException("Unknown decimal type"); - } - } - case LogicalTypeId::ENUM: - return Value::ENUM(0, type); - case LogicalTypeId::VARINT: - return Value::VARINT(Varint::VarcharToVarInt( - "-179769313486231570814527423731704356798070567525844996598917476803157260780028538760589558632766878171540" - "4589535143824642343213268894641827684675467035375169860499105765512820762454900903893289440758685084551339" - "42304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); - default: - throw InvalidTypeException(type, "MinimumValue requires numeric type"); - } -} - -Value Value::MaximumValue(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(true); - case LogicalTypeId::TINYINT: - return Value::TINYINT(NumericLimits::Maximum()); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(NumericLimits::Maximum()); - case LogicalTypeId::INTEGER: - case LogicalTypeId::SQLNULL: - return Value::INTEGER(NumericLimits::Maximum()); - case LogicalTypeId::BIGINT: - return Value::BIGINT(NumericLimits::Maximum()); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(NumericLimits::Maximum()); - case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(NumericLimits::Maximum()); - case LogicalTypeId::UUID: - return Value::UUID(NumericLimits::Maximum()); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(NumericLimits::Maximum()); - case LogicalTypeId::USMALLINT: - return Value::USMALLINT(NumericLimits::Maximum()); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(NumericLimits::Maximum()); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(NumericLimits::Maximum()); - case LogicalTypeId::DATE: - return Value::DATE(Date::FromDate(Date::DATE_MAX_YEAR, Date::DATE_MAX_MONTH, Date::DATE_MAX_DAY)); - case LogicalTypeId::TIME: - // 24:00:00 according to PG - return Value::TIME(dtime_t(Interval::MICROS_PER_DAY)); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t(NumericLimits::Maximum() - 1)); - case LogicalTypeId::TIMESTAMP_SEC: { - // Get the maximum timestamp and cast it to timestamp_s_t. 
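- // (derived from the microsecond maximum so the narrowed seconds value stays inside the valid TIMESTAMP range)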
- const auto max_ts = MaximumValue(LogicalType::TIMESTAMP).GetValue(); - const auto ts = Cast::Operation(max_ts); - return Value::TIMESTAMPSEC(ts); - } - case LogicalTypeId::TIMESTAMP_MS: { - // Get the maximum timestamp and cast it to timestamp_ms_t. - const auto max_ts = MaximumValue(LogicalType::TIMESTAMP).GetValue(); - const auto ts = Cast::Operation(max_ts); - return Value::TIMESTAMPMS(ts); - } - case LogicalTypeId::TIMESTAMP_NS: { - const auto ts = timestamp_ns_t(NumericLimits::Maximum() - 1); - return Value::TIMESTAMPNS(ts); - } - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_tz_t(NumericLimits::Maximum() - 1)); - case LogicalTypeId::TIME_TZ: - // "24:00:00-1559" from the PG docs but actually "24:00:00-15:59:59". - return Value::TIMETZ(dtime_tz_t(dtime_t(Interval::MICROS_PER_DAY), dtime_tz_t::MIN_OFFSET)); - case LogicalTypeId::FLOAT: - return Value::FLOAT(NumericLimits::Maximum()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(NumericLimits::Maximum()); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(int16_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(int32_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(int64_t(NumericHelper::POWERS_OF_TEN[width] - 1), width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(Hugeint::POWERS_OF_TEN[width] - 1, width, scale); - default: - throw InternalException("Unknown decimal type"); - } - } - case LogicalTypeId::ENUM: { - auto enum_size = EnumType::GetSize(type); - return Value::ENUM(enum_size - (enum_size ? 
1 : 0), type); - } - case LogicalTypeId::VARINT: - return Value::VARINT(Varint::VarcharToVarInt( - "1797693134862315708145274237317043567980705675258449965989174768031572607800285387605895586327668781715404" - "5895351438246423432132688946418276846754670353751698604991057655128207624549009038932894407586850845513394" - "2304583236903222948165808559332123348274797826204144723168738177180919299881250404026184124858368")); - default: - throw InvalidTypeException(type, "MaximumValue requires numeric type"); - } -} - -Value Value::Infinity(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::DATE: - return Value::DATE(date_t::infinity()); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t::infinity()); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(timestamp_sec_t(timestamp_t::infinity().value)); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(timestamp_ms_t(timestamp_t::infinity().value)); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_ns_t(timestamp_t::infinity().value)); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_tz_t(timestamp_t::infinity())); - case LogicalTypeId::FLOAT: - return Value::FLOAT(std::numeric_limits::infinity()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(std::numeric_limits::infinity()); - default: - throw InvalidTypeException(type, "Infinity requires numeric type"); - } -} - -Value Value::NegativeInfinity(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::DATE: - return Value::DATE(date_t::ninfinity()); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t::ninfinity()); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(timestamp_sec_t(timestamp_t::ninfinity().value)); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(timestamp_ms_t(timestamp_t::ninfinity().value)); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_ns_t(timestamp_t::ninfinity().value)); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_tz_t(timestamp_t::ninfinity())); - case LogicalTypeId::FLOAT: - return Value::FLOAT(-std::numeric_limits::infinity()); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(-std::numeric_limits::infinity()); - default: - throw InvalidTypeException(type, "NegativeInfinity requires numeric type"); - } -} - -Value Value::BOOLEAN(bool value) { - Value result(LogicalType::BOOLEAN); - result.value_.boolean = value; - result.is_null = false; - return result; -} - -Value Value::TINYINT(int8_t value) { - Value result(LogicalType::TINYINT); - result.value_.tinyint = value; - result.is_null = false; - return result; -} - -Value Value::SMALLINT(int16_t value) { - Value result(LogicalType::SMALLINT); - result.value_.smallint = value; - result.is_null = false; - return result; -} - -Value Value::INTEGER(int32_t value) { - Value result(LogicalType::INTEGER); - result.value_.integer = value; - result.is_null = false; - return result; -} - -Value Value::BIGINT(int64_t value) { - Value result(LogicalType::BIGINT); - result.value_.bigint = value; - result.is_null = false; - return result; -} - -Value Value::HUGEINT(hugeint_t value) { - Value result(LogicalType::HUGEINT); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value Value::UHUGEINT(uhugeint_t value) { - Value result(LogicalType::UHUGEINT); - result.value_.uhugeint = value; - result.is_null = false; - return result; -} - -Value Value::UUID(hugeint_t 
value) { - Value result(LogicalType::UUID); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value Value::UUID(const string &value) { - Value result(LogicalType::UUID); - result.value_.hugeint = UUID::FromString(value); - result.is_null = false; - return result; -} - -Value Value::UTINYINT(uint8_t value) { - Value result(LogicalType::UTINYINT); - result.value_.utinyint = value; - result.is_null = false; - return result; -} - -Value Value::USMALLINT(uint16_t value) { - Value result(LogicalType::USMALLINT); - result.value_.usmallint = value; - result.is_null = false; - return result; -} - -Value Value::UINTEGER(uint32_t value) { - Value result(LogicalType::UINTEGER); - result.value_.uinteger = value; - result.is_null = false; - return result; -} - -Value Value::UBIGINT(uint64_t value) { - Value result(LogicalType::UBIGINT); - result.value_.ubigint = value; - result.is_null = false; - return result; -} - -bool Value::FloatIsFinite(float value) { - return !(std::isnan(value) || std::isinf(value)); -} - -bool Value::DoubleIsFinite(double value) { - return !(std::isnan(value) || std::isinf(value)); -} - -template <> -bool Value::IsNan(float input) { - return std::isnan(input); -} - -template <> -bool Value::IsNan(double input) { - return std::isnan(input); -} - -template <> -bool Value::IsFinite(float input) { - return Value::FloatIsFinite(input); -} - -template <> -bool Value::IsFinite(double input) { - return Value::DoubleIsFinite(input); -} - -template <> -bool Value::IsFinite(date_t input) { - return Date::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_t input) { - return Timestamp::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_sec_t input) { - return Timestamp::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_ms_t input) { - return Timestamp::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_ns_t input) { - return Timestamp::IsFinite(input); -} - -template <> -bool Value::IsFinite(timestamp_tz_t input) { - return Timestamp::IsFinite(input); -} - -bool Value::StringIsValid(const char *str, idx_t length) { - auto utf_type = Utf8Proc::Analyze(str, length); - return utf_type != UnicodeType::INVALID; -} - -Value Value::DECIMAL(int16_t value, uint8_t width, uint8_t scale) { - return Value::DECIMAL(int64_t(value), width, scale); -} - -Value Value::DECIMAL(int32_t value, uint8_t width, uint8_t scale) { - return Value::DECIMAL(int64_t(value), width, scale); -} - -Value Value::DECIMAL(int64_t value, uint8_t width, uint8_t scale) { - auto decimal_type = LogicalType::DECIMAL(width, scale); - Value result(decimal_type); - switch (decimal_type.InternalType()) { - case PhysicalType::INT16: - result.value_.smallint = NumericCast(value); - break; - case PhysicalType::INT32: - result.value_.integer = NumericCast(value); - break; - case PhysicalType::INT64: - result.value_.bigint = value; - break; - default: - result.value_.hugeint = value; - break; - } - result.type_.Verify(); - result.is_null = false; - return result; -} - -Value Value::DECIMAL(hugeint_t value, uint8_t width, uint8_t scale) { - D_ASSERT(width >= Decimal::MAX_WIDTH_INT64 && width <= Decimal::MAX_WIDTH_INT128); - Value result(LogicalType::DECIMAL(width, scale)); - result.value_.hugeint = value; - result.is_null = false; - return result; -} - -Value Value::FLOAT(float value) { - Value result(LogicalType::FLOAT); - result.value_.float_ = value; - result.is_null = false; - return result; -} - -Value Value::DOUBLE(double value) { - Value 
result(LogicalType::DOUBLE); - result.value_.double_ = value; - result.is_null = false; - return result; -} - -Value Value::HASH(hash_t value) { - Value result(LogicalType::HASH); - result.value_.hash = value; - result.is_null = false; - return result; -} - -Value Value::POINTER(uintptr_t value) { - Value result(LogicalType::POINTER); - result.value_.pointer = value; - result.is_null = false; - return result; -} - -Value Value::DATE(date_t value) { - Value result(LogicalType::DATE); - result.value_.date = value; - result.is_null = false; - return result; -} - -Value Value::DATE(int32_t year, int32_t month, int32_t day) { - return Value::DATE(Date::FromDate(year, month, day)); -} - -Value Value::TIME(dtime_t value) { - Value result(LogicalType::TIME); - result.value_.time = value; - result.is_null = false; - return result; -} - -Value Value::TIMETZ(dtime_tz_t value) { - Value result(LogicalType::TIME_TZ); - result.value_.timetz = value; - result.is_null = false; - return result; -} - -Value Value::TIME(int32_t hour, int32_t min, int32_t sec, int32_t micros) { - return Value::TIME(Time::FromTime(hour, min, sec, micros)); -} - -Value Value::TIMESTAMP(timestamp_t value) { - Value result(LogicalType::TIMESTAMP); - result.value_.timestamp = value; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPSEC(timestamp_sec_t timestamp) { - Value result(LogicalType::TIMESTAMP_S); - result.value_.timestamp_s = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPMS(timestamp_ms_t timestamp) { - Value result(LogicalType::TIMESTAMP_MS); - result.value_.timestamp_ms = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPNS(timestamp_ns_t timestamp) { - Value result(LogicalType::TIMESTAMP_NS); - result.value_.timestamp_ns = timestamp; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMPTZ(timestamp_tz_t value) { - Value result(LogicalType::TIMESTAMP_TZ); - result.value_.timestamp_tz = value; - result.is_null = false; - return result; -} - -Value Value::TIMESTAMP(date_t date, dtime_t time) { - return Value::TIMESTAMP(Timestamp::FromDatetime(date, time)); -} - -Value Value::TIMESTAMP(int32_t year, int32_t month, int32_t day, int32_t hour, int32_t min, int32_t sec, - int32_t micros) { - auto date = Date::FromDate(year, month, day); - auto time = Time::FromTime(hour, min, sec, micros); - auto val = Value::TIMESTAMP(date, time); - val.type_ = LogicalType::TIMESTAMP; - return val; -} - -Value Value::STRUCT(const LogicalType &type, vector struct_values) { - Value result; - auto child_types = StructType::GetChildTypes(type); - for (size_t i = 0; i < struct_values.size(); i++) { - struct_values[i] = struct_values[i].DefaultCastAs(child_types[i].second); - } - result.value_info_ = make_shared_ptr(std::move(struct_values)); - result.type_ = type; - result.is_null = false; - return result; -} - -Value Value::STRUCT(child_list_t values) { - child_list_t child_types; - vector struct_values; - for (auto &child : values) { - child_types.push_back(make_pair(std::move(child.first), child.second.type())); - struct_values.push_back(std::move(child.second)); - } - return Value::STRUCT(LogicalType::STRUCT(child_types), std::move(struct_values)); -} - -void MapKeyCheck(unordered_set &unique_keys, const Value &key) { - // NULL key check. - if (key.IsNull()) { - MapVector::EvalMapInvalidReason(MapInvalidReason::NULL_KEY); - } - - // Duplicate key check. 
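- // note: uniqueness is tracked by value hash alone, so distinct keys with colliding hashes are rejected as duplicates too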
- auto key_hash = key.Hash(); - if (unique_keys.find(key_hash) != unique_keys.end()) { - MapVector::EvalMapInvalidReason(MapInvalidReason::DUPLICATE_KEY); - } - unique_keys.insert(key_hash); -} - -Value Value::MAP(const LogicalType &child_type, vector values) { // NOLINT - vector map_keys; - vector map_values; - unordered_set unique_keys; - - for (auto &val : values) { - D_ASSERT(val.type().InternalType() == PhysicalType::STRUCT); - auto &children = StructValue::GetChildren(val); - D_ASSERT(children.size() == 2); - - auto &key = children[0]; - MapKeyCheck(unique_keys, key); - - map_keys.push_back(key); - map_values.push_back(children[1]); - } - - auto &key_type = StructType::GetChildType(child_type, 0); - auto &value_type = StructType::GetChildType(child_type, 1); - return Value::MAP(key_type, value_type, std::move(map_keys), std::move(map_values)); -} - -Value Value::MAP(const LogicalType &key_type, const LogicalType &value_type, vector keys, vector values) { - D_ASSERT(keys.size() == values.size()); - Value result; - - result.type_ = LogicalType::MAP(key_type, value_type); - result.is_null = false; - unordered_set unique_keys; - - for (idx_t i = 0; i < keys.size(); i++) { - child_list_t struct_types; - vector new_children; - struct_types.reserve(2); - new_children.reserve(2); - - struct_types.push_back(make_pair("key", key_type)); - struct_types.push_back(make_pair("value", value_type)); - - auto key = keys[i].DefaultCastAs(key_type); - MapKeyCheck(unique_keys, key); - - new_children.push_back(key); - new_children.push_back(values[i]); - auto struct_type = LogicalType::STRUCT(std::move(struct_types)); - values[i] = Value::STRUCT(struct_type, std::move(new_children)); - } - - result.value_info_ = make_shared_ptr(std::move(values)); - return result; -} - -Value Value::MAP(const unordered_map &kv_pairs) { - Value result; - result.type_ = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - result.is_null = false; - vector pairs; - for (auto &kv : kv_pairs) { - pairs.push_back(Value::STRUCT({{"key", Value(kv.first)}, {"value", Value(kv.second)}})); - } - result.value_info_ = make_shared_ptr(std::move(pairs)); - return result; -} - -Value Value::UNION(child_list_t members, uint8_t tag, Value value) { - D_ASSERT(!members.empty()); - D_ASSERT(members.size() <= UnionType::MAX_UNION_MEMBERS); - D_ASSERT(members.size() > tag); - - D_ASSERT(value.type() == members[tag].second); - - Value result; - result.is_null = false; - // add the tag to the front of the struct - vector union_values; - union_values.emplace_back(Value::UTINYINT(tag)); - for (idx_t i = 0; i < members.size(); i++) { - if (i != tag) { - union_values.emplace_back(members[i].second); - } else { - union_values.emplace_back(nullptr); - } - } - union_values[tag + 1] = std::move(value); - result.value_info_ = make_shared_ptr(std::move(union_values)); - result.type_ = LogicalType::UNION(std::move(members)); - return result; -} - -Value Value::LIST(const LogicalType &child_type, vector values) { - Value result; - result.type_ = LogicalType::LIST(child_type); - result.is_null = false; - for (auto &val : values) { - val = val.DefaultCastAs(child_type); - } - result.value_info_ = make_shared_ptr(std::move(values)); - return result; -} - -Value Value::LIST(vector values) { - if (values.empty()) { - throw InternalException( - "Value::LIST(values) cannot be used to make an empty list - use Value::LIST(type, values) instead"); - } - auto &type = values[0].type(); - return Value::LIST(type, std::move(values)); -} - -Value 
Value::ARRAY(const LogicalType &child_type, vector values) { - Value result; - result.type_ = LogicalType::ARRAY(child_type, values.size()); - for (auto &val : values) { - val = val.DefaultCastAs(child_type); - } - result.value_info_ = make_shared_ptr(std::move(values)); - result.is_null = false; - return result; -} - -Value Value::BLOB(const_data_ptr_t data, idx_t len) { - Value result(LogicalType::BLOB); - result.is_null = false; - result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); - return result; -} - -Value Value::VARINT(const_data_ptr_t data, idx_t len) { - return VARINT(string(const_char_ptr_cast(data), len)); -} - -Value Value::VARINT(const string &data) { - Value result(LogicalType::VARINT); - result.is_null = false; - result.value_info_ = make_shared_ptr(data); - return result; -} - -Value Value::BLOB(const string &data) { - Value result(LogicalType::BLOB); - result.is_null = false; - result.value_info_ = make_shared_ptr(Blob::ToBlob(string_t(data))); - return result; -} - -Value Value::AGGREGATE_STATE(const LogicalType &type, const_data_ptr_t data, idx_t len) { // NOLINT - Value result(type); - result.is_null = false; - result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); - return result; -} - -Value Value::BIT(const_data_ptr_t data, idx_t len) { - Value result(LogicalType::BIT); - result.is_null = false; - result.value_info_ = make_shared_ptr(string(const_char_ptr_cast(data), len)); - return result; -} - -Value Value::BIT(const string &data) { - Value result(LogicalType::BIT); - result.is_null = false; - result.value_info_ = make_shared_ptr(Bit::ToBit(string_t(data))); - return result; -} - -Value Value::ENUM(uint64_t value, const LogicalType &original_type) { - D_ASSERT(original_type.id() == LogicalTypeId::ENUM); - Value result(original_type); - switch (original_type.InternalType()) { - case PhysicalType::UINT8: - result.value_.utinyint = NumericCast(value); - break; - case PhysicalType::UINT16: - result.value_.usmallint = NumericCast(value); - break; - case PhysicalType::UINT32: - result.value_.uinteger = NumericCast(value); - break; - default: - throw InternalException("Incorrect Physical Type for ENUM"); - } - result.is_null = false; - return result; -} - -Value Value::INTERVAL(int32_t months, int32_t days, int64_t micros) { - Value result(LogicalType::INTERVAL); - result.is_null = false; - result.value_.interval.months = months; - result.value_.interval.days = days; - result.value_.interval.micros = micros; - return result; -} - -Value Value::INTERVAL(interval_t interval) { - return Value::INTERVAL(interval.months, interval.days, interval.micros); -} - -//===--------------------------------------------------------------------===// -// CreateValue -//===--------------------------------------------------------------------===// -template <> -Value Value::CreateValue(bool value) { - return Value::BOOLEAN(value); -} - -template <> -Value Value::CreateValue(int8_t value) { - return Value::TINYINT(value); -} - -template <> -Value Value::CreateValue(int16_t value) { - return Value::SMALLINT(value); -} - -template <> -Value Value::CreateValue(int32_t value) { - return Value::INTEGER(value); -} - -template <> -Value Value::CreateValue(int64_t value) { - return Value::BIGINT(value); -} - -template <> -Value Value::CreateValue(uint8_t value) { - return Value::UTINYINT(value); -} - -template <> -Value Value::CreateValue(uint16_t value) { - return Value::USMALLINT(value); -} - -template <> -Value Value::CreateValue(uint32_t value) { - 
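- // unsigned 32-bit integers map to DuckDB's UINTEGER logical type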
-	return Value::UINTEGER(value);
-}
-
-template <>
-Value Value::CreateValue(uint64_t value) {
-	return Value::UBIGINT(value);
-}
-
-template <>
-Value Value::CreateValue(hugeint_t value) {
-	return Value::HUGEINT(value);
-}
-
-template <>
-Value Value::CreateValue(uhugeint_t value) {
-	return Value::UHUGEINT(value);
-}
-
-template <>
-Value Value::CreateValue(date_t value) {
-	return Value::DATE(value);
-}
-
-template <>
-Value Value::CreateValue(dtime_t value) {
-	return Value::TIME(value);
-}
-
-template <>
-Value Value::CreateValue(dtime_tz_t value) {
-	return Value::TIMETZ(value);
-}
-
-template <>
-Value Value::CreateValue(timestamp_t value) {
-	return Value::TIMESTAMP(value);
-}
-
-template <>
-Value Value::CreateValue(timestamp_sec_t value) {
-	return Value::TIMESTAMPSEC(value);
-}
-
-template <>
-Value Value::CreateValue(timestamp_ms_t value) {
-	return Value::TIMESTAMPMS(value);
-}
-
-template <>
-Value Value::CreateValue(timestamp_ns_t value) {
-	return Value::TIMESTAMPNS(value);
-}
-
-template <>
-Value Value::CreateValue(timestamp_tz_t value) {
-	return Value::TIMESTAMPTZ(value);
-}
-
-template <>
-Value Value::CreateValue(const char *value) {
-	return Value(string(value));
-}
-
-template <>
-Value Value::CreateValue(string value) { // NOLINT: required for templating
-	return Value::BLOB(value);
-}
-
-template <>
-Value Value::CreateValue(string_t value) {
-	return Value(value);
-}
-
-template <>
-Value Value::CreateValue(float value) {
-	return Value::FLOAT(value);
-}
-
-template <>
-Value Value::CreateValue(double value) {
-	return Value::DOUBLE(value);
-}
-
-template <>
-Value Value::CreateValue(interval_t value) {
-	return Value::INTERVAL(value);
-}
-
-template <>
-Value Value::CreateValue(Value value) {
-	return value;
-}
-
-//===--------------------------------------------------------------------===//
-// GetValue
-//===--------------------------------------------------------------------===//
-template <class T>
-T Value::GetValueInternal() const {
-	if (IsNull()) {
-		throw InternalException("Calling GetValueInternal on a value that is NULL");
-	}
-	switch (type_.id()) {
-	case LogicalTypeId::BOOLEAN:
-		return Cast::Operation<bool, T>(value_.boolean);
-	case LogicalTypeId::TINYINT:
-		return Cast::Operation<int8_t, T>(value_.tinyint);
-	case LogicalTypeId::SMALLINT:
-		return Cast::Operation<int16_t, T>(value_.smallint);
-	case LogicalTypeId::INTEGER:
-		return Cast::Operation<int32_t, T>(value_.integer);
-	case LogicalTypeId::BIGINT:
-		return Cast::Operation<int64_t, T>(value_.bigint);
-	case LogicalTypeId::HUGEINT:
-	case LogicalTypeId::UUID:
-		return Cast::Operation<hugeint_t, T>(value_.hugeint);
-	case LogicalTypeId::UHUGEINT:
-		return Cast::Operation<uhugeint_t, T>(value_.uhugeint);
-	case LogicalTypeId::DATE:
-		return Cast::Operation<date_t, T>(value_.date);
-	case LogicalTypeId::TIME:
-		return Cast::Operation<dtime_t, T>(value_.time);
-	case LogicalTypeId::TIME_TZ:
-		return Cast::Operation<dtime_tz_t, T>(value_.timetz);
-	case LogicalTypeId::TIMESTAMP:
-		return Cast::Operation<timestamp_t, T>(value_.timestamp);
-	case LogicalTypeId::TIMESTAMP_SEC:
-		return Cast::Operation<timestamp_sec_t, T>(value_.timestamp_s);
-	case LogicalTypeId::TIMESTAMP_MS:
-		return Cast::Operation<timestamp_ms_t, T>(value_.timestamp_ms);
-	case LogicalTypeId::TIMESTAMP_NS:
-		return Cast::Operation<timestamp_ns_t, T>(value_.timestamp_ns);
-	case LogicalTypeId::TIMESTAMP_TZ:
-		return Cast::Operation<timestamp_tz_t, T>(value_.timestamp_tz);
-	case LogicalTypeId::UTINYINT:
-		return Cast::Operation<uint8_t, T>(value_.utinyint);
-	case LogicalTypeId::USMALLINT:
-		return Cast::Operation<uint16_t, T>(value_.usmallint);
-	case LogicalTypeId::UINTEGER:
-		return Cast::Operation<uint32_t, T>(value_.uinteger);
-	case LogicalTypeId::UBIGINT:
-		return Cast::Operation<uint64_t, T>(value_.ubigint);
-	case LogicalTypeId::FLOAT:
-		return Cast::Operation<float, T>(value_.float_);
-	case LogicalTypeId::DOUBLE:
-		return Cast::Operation<double, T>(value_.double_);
-	case LogicalTypeId::VARCHAR:
-		return Cast::Operation<const char *, T>(StringValue::Get(*this).c_str());
-	case LogicalTypeId::INTERVAL:
-		return Cast::Operation<interval_t, T>(value_.interval);
-	case LogicalTypeId::DECIMAL:
-		return DefaultCastAs(LogicalType::DOUBLE).GetValueInternal<T>();
-	case LogicalTypeId::ENUM: {
-		switch (type_.InternalType()) {
-		case PhysicalType::UINT8:
-			return Cast::Operation<uint8_t, T>(value_.utinyint);
-		case PhysicalType::UINT16:
-			return Cast::Operation<uint16_t, T>(value_.usmallint);
-		case PhysicalType::UINT32:
-			return Cast::Operation<uint32_t, T>(value_.uinteger);
-		default:
-			throw InternalException("Invalid Internal Type for ENUMs");
-		}
-	}
-	default:
-		throw NotImplementedException("Unimplemented type \"%s\" for GetValue()", type_.ToString());
-	}
-}
-
-template <>
-bool Value::GetValue() const {
-	return GetValueInternal<bool>();
-}
-template <>
-int8_t Value::GetValue() const {
-	return GetValueInternal<int8_t>();
-}
-template <>
-int16_t Value::GetValue() const {
-	return GetValueInternal<int16_t>();
-}
-template <>
-int32_t Value::GetValue() const {
-	if (type_.id() == LogicalTypeId::DATE) {
-		return value_.integer;
-	}
-	return GetValueInternal<int32_t>();
-}
-template <>
-int64_t Value::GetValue() const {
-	if (IsNull()) {
-		throw InternalException("Calling GetValue on a value that is NULL");
-	}
-	switch (type_.id()) {
-	case LogicalTypeId::TIMESTAMP:
-		return value_.timestamp.value;
-	case LogicalTypeId::TIMESTAMP_SEC:
-		return value_.timestamp_s.value;
-	case LogicalTypeId::TIMESTAMP_MS:
-		return value_.timestamp_ms.value;
-	case LogicalTypeId::TIMESTAMP_NS:
-		return value_.timestamp_ns.value;
-	case LogicalTypeId::TIMESTAMP_TZ:
-		return value_.timestamp_tz.value;
-	case LogicalTypeId::TIME:
-		return value_.bigint;
-	default:
-		return GetValueInternal<int64_t>();
-	}
-}
-template <>
-hugeint_t Value::GetValue() const {
-	return GetValueInternal<hugeint_t>();
-}
-template <>
-uint8_t Value::GetValue() const {
-	return GetValueInternal<uint8_t>();
-}
-template <>
-uint16_t Value::GetValue() const {
-	return GetValueInternal<uint16_t>();
-}
-template <>
-uint32_t Value::GetValue() const {
-	return GetValueInternal<uint32_t>();
-}
-template <>
-uint64_t Value::GetValue() const {
-	return GetValueInternal<uint64_t>();
-}
-template <>
-uhugeint_t Value::GetValue() const {
-	return GetValueInternal<uhugeint_t>();
-}
-template <>
-string Value::GetValue() const {
-	return ToString();
-}
-template <>
-float Value::GetValue() const {
-	return GetValueInternal<float>();
-}
-template <>
-double Value::GetValue() const {
-	return GetValueInternal<double>();
-}
-template <>
-date_t Value::GetValue() const {
-	return GetValueInternal<date_t>();
-}
-template <>
-dtime_t Value::GetValue() const {
-	return GetValueInternal<dtime_t>();
-}
-
-template <>
-timestamp_t Value::GetValue() const {
-	return GetValueInternal<timestamp_t>();
-}
-
-template <>
-timestamp_sec_t Value::GetValue() const {
-	return GetValueInternal<timestamp_sec_t>();
-}
-
-template <>
-timestamp_ms_t Value::GetValue() const {
-	return GetValueInternal<timestamp_ms_t>();
-}
-
-template <>
-timestamp_ns_t Value::GetValue() const {
-	return GetValueInternal<timestamp_ns_t>();
-}
-
-template <>
-timestamp_tz_t Value::GetValue() const {
-	return GetValueInternal<timestamp_tz_t>();
-}
-
-template <>
-dtime_tz_t Value::GetValue() const {
-	return GetValueInternal<dtime_tz_t>();
-}
-
-template <>
-DUCKDB_API interval_t Value::GetValue() const {
-	return GetValueInternal<interval_t>();
-}
-
-template <>
-DUCKDB_API Value Value::GetValue() const {
-	return Value(*this);
-}
-
-uintptr_t Value::GetPointer() const {
D_ASSERT(type() == LogicalType::POINTER); - return value_.pointer; -} - -Value Value::Numeric(const LogicalType &type, int64_t value) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - D_ASSERT(value == 0 || value == 1); - return Value::BOOLEAN(value ? true : false); - case LogicalTypeId::TINYINT: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::TINYINT((int8_t)value); - case LogicalTypeId::SMALLINT: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::SMALLINT((int16_t)value); - case LogicalTypeId::INTEGER: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::INTEGER((int32_t)value); - case LogicalTypeId::BIGINT: - return Value::BIGINT(value); - case LogicalTypeId::UTINYINT: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::UTINYINT((uint8_t)value); - case LogicalTypeId::USMALLINT: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::USMALLINT((uint16_t)value); - case LogicalTypeId::UINTEGER: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::UINTEGER((uint32_t)value); - case LogicalTypeId::UBIGINT: - D_ASSERT(value >= 0); - return Value::UBIGINT(NumericCast(value)); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(value); - case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(NumericCast(value)); - case LogicalTypeId::DECIMAL: - return Value::DECIMAL(value, DecimalType::GetWidth(type), DecimalType::GetScale(type)); - case LogicalTypeId::FLOAT: - return Value((float)value); - case LogicalTypeId::DOUBLE: - return Value((double)value); - case LogicalTypeId::POINTER: - return Value::POINTER(NumericCast(value)); - case LogicalTypeId::DATE: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::DATE(date_t(NumericCast(value))); - case LogicalTypeId::TIME: - return Value::TIME(dtime_t(value)); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(timestamp_t(value)); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(timestamp_sec_t(value)); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(timestamp_ms_t(value)); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(timestamp_ns_t(value)); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(timestamp_tz_t(value)); - case LogicalTypeId::ENUM: - switch (type.InternalType()) { - case PhysicalType::UINT8: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::UTINYINT((uint8_t)value); - case PhysicalType::UINT16: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::USMALLINT((uint16_t)value); - case PhysicalType::UINT32: - D_ASSERT(value >= NumericLimits::Minimum() && value <= NumericLimits::Maximum()); - return Value::UINTEGER((uint32_t)value); - default: - throw InternalException("Enum doesn't accept this physical type"); - } - default: - throw InvalidTypeException(type, "Numeric requires numeric type"); - } -} - -Value Value::Numeric(const LogicalType &type, hugeint_t value) { -#ifdef DEBUG - // perform a throwing cast to verify that the type fits - Value::HUGEINT(value).DefaultCastAs(type); -#endif - switch (type.id()) { - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(value); - case LogicalTypeId::UBIGINT: - return 
Value::UBIGINT(Hugeint::Cast(value)); - default: - return Value::Numeric(type, Hugeint::Cast(value)); - } -} - -Value Value::Numeric(const LogicalType &type, uhugeint_t value) { -#ifdef DEBUG - // perform a throwing cast to verify that the type fits - Value::UHUGEINT(value).DefaultCastAs(type); -#endif - switch (type.id()) { - case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(value); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(Uhugeint::Cast(value)); - default: - return Value::Numeric(type, Uhugeint::Cast(value)); - } -} - -//===--------------------------------------------------------------------===// -// GetValueUnsafe -//===--------------------------------------------------------------------===// -template <> -DUCKDB_API bool Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::BOOL); - return value_.boolean; -} - -template <> -int8_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT8 || type_.InternalType() == PhysicalType::BOOL); - return value_.tinyint; -} - -template <> -int16_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT16); - return value_.smallint; -} - -template <> -int32_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT32); - return value_.integer; -} - -template <> -int64_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.bigint; -} - -template <> -hugeint_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT128); - return value_.hugeint; -} - -template <> -uint8_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::UINT8); - return value_.utinyint; -} - -template <> -uint16_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::UINT16); - return value_.usmallint; -} - -template <> -uint32_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::UINT32); - return value_.uinteger; -} - -template <> -uint64_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::UINT64); - return value_.ubigint; -} - -template <> -uhugeint_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::UINT128); - return value_.uhugeint; -} - -template <> -string Value::GetValueUnsafe() const { - return StringValue::Get(*this); -} - -template <> -DUCKDB_API string_t Value::GetValueUnsafe() const { - return string_t(StringValue::Get(*this)); -} - -template <> -float Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::FLOAT); - return value_.float_; -} - -template <> -double Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::DOUBLE); - return value_.double_; -} - -template <> -date_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT32); - return value_.date; -} - -template <> -dtime_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.time; -} - -template <> -dtime_tz_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timetz; -} - -template <> -timestamp_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timestamp; -} - -template <> -timestamp_sec_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timestamp_s; -} - 
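The checked and unchecked getters above differ in one way: GetValue<T>() routes through GetValueInternal<T>() and may cast between types, while GetValueUnsafe<T>() only asserts that the value's physical type matches and reads the union member directly. A minimal usage sketch (hypothetical, not part of the deleted file; it assumes only the public duckdb.hpp header and the Value API shown above):

#include "duckdb.hpp"

int main() {
	duckdb::Value v = duckdb::Value::INTEGER(42);
	// Checked access: GetValue<T> may cast, so reading an INTEGER as int64_t works.
	int64_t as_bigint = v.GetValue<int64_t>();
	// Unchecked access: GetValueUnsafe<T> requires the exact physical type (INT32 here).
	int32_t raw = v.GetValueUnsafe<int32_t>();
	return (as_bigint == 42 && raw == 42) ? 0 : 1;
}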
-template <> -timestamp_ms_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timestamp_ms; -} - -template <> -timestamp_ns_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timestamp_ns; -} - -template <> -timestamp_tz_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INT64); - return value_.timestamp_tz; -} - -template <> -interval_t Value::GetValueUnsafe() const { - D_ASSERT(type_.InternalType() == PhysicalType::INTERVAL); - return value_.interval; -} - -//===--------------------------------------------------------------------===// -// Hash -//===--------------------------------------------------------------------===// -hash_t Value::Hash() const { - if (IsNull()) { - return 0; - } - Vector input(*this); - Vector result(LogicalType::HASH, 1); - VectorOperations::Hash(input, result, 1); - - auto data = FlatVector::GetData(result); - return data[0]; -} - -string Value::ToString() const { - if (IsNull()) { - return "NULL"; - } - return StringValue::Get(DefaultCastAs(LogicalType::VARCHAR)); -} - -string Value::ToSQLString() const { - if (IsNull()) { - return ToString(); - } - switch (type_.id()) { - case LogicalTypeId::UUID: - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::INTERVAL: - case LogicalTypeId::BLOB: - return "'" + ToString() + "'::" + type_.ToString(); - case LogicalTypeId::VARCHAR: - case LogicalTypeId::ENUM: { - auto str_val = ToString(); - if (str_val.size() == 1 && str_val[0] == '\0') { - return "chr(0)"; - } - return "'" + StringUtil::Replace(ToString(), "'", "''") + "'"; - } - case LogicalTypeId::STRUCT: { - bool is_unnamed = StructType::IsUnnamed(type_); - string ret = is_unnamed ? "(" : "{"; - auto &child_types = StructType::GetChildTypes(type_); - auto &struct_values = StructValue::GetChildren(*this); - for (idx_t i = 0; i < struct_values.size(); i++) { - auto &name = child_types[i].first; - auto &child = struct_values[i]; - if (is_unnamed) { - ret += child.ToSQLString(); - } else { - ret += "'" + name + "': " + child.ToSQLString(); - } - if (i < struct_values.size() - 1) { - ret += ", "; - } - } - ret += is_unnamed ? ")" : "}"; - return ret; - } - case LogicalTypeId::FLOAT: - if (!FloatIsFinite(FloatValue::Get(*this))) { - return "'" + ToString() + "'::" + type_.ToString(); - } - return ToString(); - case LogicalTypeId::DOUBLE: { - double val = DoubleValue::Get(*this); - if (!DoubleIsFinite(val)) { - if (!Value::IsNan(val)) { - // to infinity and beyond - return val < 0 ? 
"-1e1000" : "1e1000"; - } - return "'" + ToString() + "'::" + type_.ToString(); - } - return ToString(); - } - case LogicalTypeId::LIST: { - string ret = "["; - auto &list_values = ListValue::GetChildren(*this); - for (idx_t i = 0; i < list_values.size(); i++) { - auto &child = list_values[i]; - ret += child.ToSQLString(); - if (i < list_values.size() - 1) { - ret += ", "; - } - } - ret += "]"; - return ret; - } - case LogicalTypeId::ARRAY: { - string ret = "["; - auto &array_values = ArrayValue::GetChildren(*this); - for (idx_t i = 0; i < array_values.size(); i++) { - auto &child = array_values[i]; - ret += child.ToSQLString(); - if (i < array_values.size() - 1) { - ret += ", "; - } - } - ret += "]"; - return ret; - } - default: - return ToString(); - } -} - -//===--------------------------------------------------------------------===// -// Type-specific getters -//===--------------------------------------------------------------------===// -bool BooleanValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -int8_t TinyIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -int16_t SmallIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -int32_t IntegerValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -int64_t BigIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -hugeint_t HugeIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -uint8_t UTinyIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -uint16_t USmallIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -uint32_t UIntegerValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -uint64_t UBigIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -uhugeint_t UhugeIntValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -float FloatValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -double DoubleValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -const string &StringValue::Get(const Value &value) { - if (value.is_null) { - throw InternalException("Calling StringValue::Get on a NULL value"); - } - D_ASSERT(value.type().InternalType() == PhysicalType::VARCHAR); - D_ASSERT(value.value_info_); - return value.value_info_->Get().GetString(); -} - -date_t DateValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -dtime_t TimeValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -timestamp_t TimestampValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -timestamp_sec_t TimestampSValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -timestamp_ms_t TimestampMSValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -timestamp_ns_t TimestampNSValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -timestamp_tz_t TimestampTZValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -interval_t IntervalValue::Get(const Value &value) { - return value.GetValueUnsafe(); -} - -const vector &StructValue::GetChildren(const Value &value) { - if (value.is_null) { - throw InternalException("Calling StructValue::GetChildren on a NULL value"); - } - D_ASSERT(value.type().InternalType() == PhysicalType::STRUCT); - D_ASSERT(value.value_info_); - return value.value_info_->Get().GetValues(); -} - -const vector &MapValue::GetChildren(const Value &value) { - if (value.is_null) { - throw 
InternalException("Calling MapValue::GetChildren on a NULL value"); - } - D_ASSERT(value.type().id() == LogicalTypeId::MAP); - D_ASSERT(value.type().InternalType() == PhysicalType::LIST); - D_ASSERT(value.value_info_); - return value.value_info_->Get().GetValues(); -} - -const vector &ListValue::GetChildren(const Value &value) { - if (value.is_null) { - throw InternalException("Calling ListValue::GetChildren on a NULL value"); - } - D_ASSERT(value.type().InternalType() == PhysicalType::LIST); - D_ASSERT(value.value_info_); - return value.value_info_->Get().GetValues(); -} - -const vector &ArrayValue::GetChildren(const Value &value) { - if (value.is_null) { - throw InternalException("Calling ArrayValue::GetChildren on a NULL value"); - } - D_ASSERT(value.type().InternalType() == PhysicalType::ARRAY); - D_ASSERT(value.value_info_); - return value.value_info_->Get().GetValues(); -} - -const Value &UnionValue::GetValue(const Value &value) { - D_ASSERT(value.type().id() == LogicalTypeId::UNION); - auto &children = StructValue::GetChildren(value); - auto tag = children[0].GetValueUnsafe(); - D_ASSERT(tag < children.size() - 1); - return children[tag + 1]; -} - -union_tag_t UnionValue::GetTag(const Value &value) { - D_ASSERT(value.type().id() == LogicalTypeId::UNION); - auto children = StructValue::GetChildren(value); - auto tag = children[0].GetValueUnsafe(); - D_ASSERT(tag < children.size() - 1); - return tag; -} - -const LogicalType &UnionValue::GetType(const Value &value) { - return UnionType::GetMemberType(value.type(), UnionValue::GetTag(value)); -} - -hugeint_t IntegralValue::Get(const Value &value) { - switch (value.type().InternalType()) { - case PhysicalType::INT8: - return TinyIntValue::Get(value); - case PhysicalType::INT16: - return SmallIntValue::Get(value); - case PhysicalType::INT32: - return IntegerValue::Get(value); - case PhysicalType::INT64: - return BigIntValue::Get(value); - case PhysicalType::INT128: - return HugeIntValue::Get(value); - case PhysicalType::UINT8: - return UTinyIntValue::Get(value); - case PhysicalType::UINT16: - return USmallIntValue::Get(value); - case PhysicalType::UINT32: - return UIntegerValue::Get(value); - case PhysicalType::UINT64: - return NumericCast(UBigIntValue::Get(value)); - case PhysicalType::UINT128: - return static_cast(UhugeIntValue::Get(value)); - default: - throw InternalException("Invalid internal type \"%s\" for IntegralValue::Get", value.type().ToString()); - } -} - -//===--------------------------------------------------------------------===// -// Comparison Operators -//===--------------------------------------------------------------------===// -bool Value::operator==(const Value &rhs) const { - return ValueOperations::Equals(*this, rhs); -} - -bool Value::operator!=(const Value &rhs) const { - return ValueOperations::NotEquals(*this, rhs); -} - -bool Value::operator<(const Value &rhs) const { - return ValueOperations::LessThan(*this, rhs); -} - -bool Value::operator>(const Value &rhs) const { - return ValueOperations::GreaterThan(*this, rhs); -} - -bool Value::operator<=(const Value &rhs) const { - return ValueOperations::LessThanEquals(*this, rhs); -} - -bool Value::operator>=(const Value &rhs) const { - return ValueOperations::GreaterThanEquals(*this, rhs); -} - -bool Value::operator==(const int64_t &rhs) const { - return *this == Value::Numeric(type_, rhs); -} - -bool Value::operator!=(const int64_t &rhs) const { - return *this != Value::Numeric(type_, rhs); -} - -bool Value::operator<(const int64_t &rhs) const { - return *this < 
Value::Numeric(type_, rhs); -} - -bool Value::operator>(const int64_t &rhs) const { - return *this > Value::Numeric(type_, rhs); -} - -bool Value::operator<=(const int64_t &rhs) const { - return *this <= Value::Numeric(type_, rhs); -} - -bool Value::operator>=(const int64_t &rhs) const { - return *this >= Value::Numeric(type_, rhs); -} - -bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, - Value &new_value, string *error_message, bool strict) const { - if (type_ == target_type) { - new_value = Copy(); - return true; - } - Vector input(*this); - Vector result(target_type); - if (!VectorOperations::TryCast(set, get_input, input, result, 1, error_message, strict)) { - return false; - } - new_value = result.GetValue(0); - return true; -} - -bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, Value &new_value, string *error_message, - bool strict) const { - GetCastFunctionInput get_input(context); - return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, new_value, error_message, strict); -} - -bool Value::DefaultTryCastAs(const LogicalType &target_type, Value &new_value, string *error_message, - bool strict) const { - CastFunctionSet set; - GetCastFunctionInput get_input; - return TryCastAs(set, get_input, target_type, new_value, error_message, strict); -} - -Value Value::CastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, - bool strict) const { - if (target_type.id() == LogicalTypeId::ANY) { - return *this; - } - Value new_value; - string error_message; - if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) { - throw InvalidInputException("Failed to cast value: %s", error_message); - } - return new_value; -} - -Value Value::CastAs(ClientContext &context, const LogicalType &target_type, bool strict) const { - GetCastFunctionInput get_input(context); - return CastAs(CastFunctionSet::Get(context), get_input, target_type, strict); -} - -Value Value::DefaultCastAs(const LogicalType &target_type, bool strict) const { - CastFunctionSet set; - GetCastFunctionInput get_input; - return CastAs(set, get_input, target_type, strict); -} - -bool Value::TryCastAs(CastFunctionSet &set, GetCastFunctionInput &get_input, const LogicalType &target_type, - bool strict) { - Value new_value; - string error_message; - if (!TryCastAs(set, get_input, target_type, new_value, &error_message, strict)) { - return false; - } - type_ = target_type; - is_null = new_value.is_null; - value_ = new_value.value_; - value_info_ = std::move(new_value.value_info_); - return true; -} - -bool Value::TryCastAs(ClientContext &context, const LogicalType &target_type, bool strict) { - GetCastFunctionInput get_input(context); - return TryCastAs(CastFunctionSet::Get(context), get_input, target_type, strict); -} - -bool Value::DefaultTryCastAs(const LogicalType &target_type, bool strict) { - CastFunctionSet set; - GetCastFunctionInput get_input; - return TryCastAs(set, get_input, target_type, strict); -} - -void Value::Reinterpret(LogicalType new_type) { - this->type_ = std::move(new_type); -} - -const LogicalType &GetChildType(const LogicalType &parent_type, idx_t i) { - switch (parent_type.InternalType()) { - case PhysicalType::LIST: - return ListType::GetChildType(parent_type); - case PhysicalType::STRUCT: - return StructType::GetChildType(parent_type, i); - case PhysicalType::ARRAY: - return ArrayType::GetChildType(parent_type); - default: - throw 
InternalException("Parent type is not a nested type"); - } -} - -bool SerializeTypeMatches(const LogicalType &expected_type, const LogicalType &actual_type) { - if (expected_type.id() != actual_type.id()) { - // type id needs to be the same - return false; - } - if (expected_type.IsNested()) { - // for nested types that is enough - we will recurse into the children and check there again anyway - return true; - } - // otherwise we do a deep comparison of the type (e.g. decimal flags need to be consistent) - return expected_type == actual_type; -} - -void Value::SerializeChildren(Serializer &serializer, const vector &children, const LogicalType &parent_type) { - serializer.WriteObject(102, "value", [&](Serializer &child_serializer) { - child_serializer.WriteList(100, "children", children.size(), [&](Serializer::List &list, idx_t i) { - auto &value_type = GetChildType(parent_type, i); - bool serialize_type = value_type.InternalType() == PhysicalType::INVALID; - if (!serialize_type && !SerializeTypeMatches(value_type, children[i].type())) { - throw InternalException("Error when serializing type - serializing a child of a nested value with type " - "%s, but expected type %s", - children[i].type(), value_type); - } - list.WriteObject([&](Serializer &element_serializer) { - children[i].SerializeInternal(element_serializer, serialize_type); - }); - }); - }); -} - -void Value::SerializeInternal(Serializer &serializer, bool serialize_type) const { - if (serialize_type || !serializer.ShouldSerialize(4)) { - // only the root value needs to serialize its type - // for forwards compatibility reasons, we also serialize the type always when targeting versions < v1.2.0 - serializer.WriteProperty(100, "type", type_); - } - serializer.WriteProperty(101, "is_null", is_null); - if (IsNull()) { - return; - } - switch (type_.InternalType()) { - case PhysicalType::BIT: - throw InternalException("BIT type should not be serialized"); - case PhysicalType::BOOL: - serializer.WriteProperty(102, "value", value_.boolean); - break; - case PhysicalType::INT8: - serializer.WriteProperty(102, "value", value_.tinyint); - break; - case PhysicalType::INT16: - serializer.WriteProperty(102, "value", value_.smallint); - break; - case PhysicalType::INT32: - serializer.WriteProperty(102, "value", value_.integer); - break; - case PhysicalType::INT64: - serializer.WriteProperty(102, "value", value_.bigint); - break; - case PhysicalType::UINT8: - serializer.WriteProperty(102, "value", value_.utinyint); - break; - case PhysicalType::UINT16: - serializer.WriteProperty(102, "value", value_.usmallint); - break; - case PhysicalType::UINT32: - serializer.WriteProperty(102, "value", value_.uinteger); - break; - case PhysicalType::UINT64: - serializer.WriteProperty(102, "value", value_.ubigint); - break; - case PhysicalType::INT128: - serializer.WriteProperty(102, "value", value_.hugeint); - break; - case PhysicalType::UINT128: - serializer.WriteProperty(102, "value", value_.uhugeint); - break; - case PhysicalType::FLOAT: - serializer.WriteProperty(102, "value", value_.float_); - break; - case PhysicalType::DOUBLE: - serializer.WriteProperty(102, "value", value_.double_); - break; - case PhysicalType::INTERVAL: - serializer.WriteProperty(102, "value", value_.interval); - break; - case PhysicalType::VARCHAR: { - if (type_.id() == LogicalTypeId::BLOB) { - auto blob_str = Blob::ToString(StringValue::Get(*this)); - serializer.WriteProperty(102, "value", blob_str); - } else { - serializer.WriteProperty(102, "value", StringValue::Get(*this)); - } - 
} break; - case PhysicalType::LIST: - SerializeChildren(serializer, ListValue::GetChildren(*this), type_); - break; - case PhysicalType::STRUCT: - SerializeChildren(serializer, StructValue::GetChildren(*this), type_); - break; - case PhysicalType::ARRAY: - SerializeChildren(serializer, ArrayValue::GetChildren(*this), type_); - break; - default: - throw NotImplementedException("Unimplemented type for Serialize"); - } -} - -void Value::Serialize(Serializer &serializer) const { - // serialize the value - the top-level value always needs to serialize its type - SerializeInternal(serializer, true); -} - -Value Value::Deserialize(Deserializer &deserializer) { - auto type = deserializer.ReadPropertyWithExplicitDefault(100, "type", LogicalTypeId::INVALID); - if (type.InternalType() == PhysicalType::INVALID) { - type = deserializer.Get(); - } - auto is_null = deserializer.ReadProperty(101, "is_null"); - Value new_value = Value(type); - if (is_null) { - return new_value; - } - new_value.is_null = false; - switch (type.InternalType()) { - case PhysicalType::BIT: - throw InternalException("BIT type should not be deserialized"); - case PhysicalType::BOOL: - new_value.value_.boolean = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::UINT8: - new_value.value_.utinyint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INT8: - new_value.value_.tinyint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::UINT16: - new_value.value_.usmallint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INT16: - new_value.value_.smallint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::UINT32: - new_value.value_.uinteger = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INT32: - new_value.value_.integer = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::UINT64: - new_value.value_.ubigint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INT64: - new_value.value_.bigint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::UINT128: - new_value.value_.uhugeint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INT128: - new_value.value_.hugeint = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::FLOAT: - new_value.value_.float_ = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::DOUBLE: - new_value.value_.double_ = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::INTERVAL: - new_value.value_.interval = deserializer.ReadProperty(102, "value"); - break; - case PhysicalType::VARCHAR: { - auto str = deserializer.ReadProperty(102, "value"); - if (type.id() == LogicalTypeId::BLOB) { - new_value.value_info_ = make_shared_ptr(Blob::ToBlob(str)); - } else { - new_value.value_info_ = make_shared_ptr(str); - } - } break; - case PhysicalType::LIST: { - deserializer.Set(ListType::GetChildType(type)); - deserializer.ReadObject(102, "value", [&](Deserializer &obj) { - auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_shared_ptr(children); - }); - deserializer.Unset(); - } break; - case PhysicalType::STRUCT: { - deserializer.ReadObject(102, "value", [&](Deserializer &obj) { - vector children; - obj.ReadList(100, "children", [&](Deserializer::List &list, idx_t i) { - deserializer.Set(StructType::GetChildType(type, i)); - auto child = list.ReadElement(); - deserializer.Unset(); - children.push_back(std::move(child)); - }); - 
new_value.value_info_ = make_shared_ptr(children); - }); - } break; - case PhysicalType::ARRAY: { - deserializer.Set(ArrayType::GetChildType(type)); - deserializer.ReadObject(102, "value", [&](Deserializer &obj) { - auto children = obj.ReadProperty>(100, "children"); - new_value.value_info_ = make_shared_ptr(children); - }); - deserializer.Unset(); - } break; - default: - throw NotImplementedException("Unimplemented type for Deserialize"); - } - return new_value; -} - -void Value::Print() const { - Printer::Print(ToString()); -} - -bool Value::NotDistinctFrom(const Value &lvalue, const Value &rvalue) { - return ValueOperations::NotDistinctFrom(lvalue, rvalue); -} - -static string SanitizeValue(string input) { - // some results might contain padding spaces, e.g. when rendering - // VARCHAR(10) and the string only has 6 characters, they will be padded - // with spaces to 10 in the rendering. We don't do that here yet as we - // are looking at internal structures. So just ignore any extra spaces - // on the right - StringUtil::RTrim(input); - // for result checking code, replace null bytes with their escaped value (\0) - return StringUtil::Replace(input, string("\0", 1), "\\0"); -} - -bool Value::ValuesAreEqual(CastFunctionSet &set, GetCastFunctionInput &get_input, const Value &result_value, - const Value &value) { - if (result_value.IsNull() != value.IsNull()) { - return false; - } - if (result_value.IsNull() && value.IsNull()) { - // NULL = NULL in checking code - return true; - } - switch (value.type_.id()) { - case LogicalTypeId::FLOAT: { - auto other = result_value.CastAs(set, get_input, LogicalType::FLOAT); - float ldecimal = value.value_.float_; - float rdecimal = other.value_.float_; - return ApproxEqual(ldecimal, rdecimal); - } - case LogicalTypeId::DOUBLE: { - auto other = result_value.CastAs(set, get_input, LogicalType::DOUBLE); - double ldecimal = value.value_.double_; - double rdecimal = other.value_.double_; - return ApproxEqual(ldecimal, rdecimal); - } - case LogicalTypeId::VARCHAR: { - auto other = result_value.CastAs(set, get_input, LogicalType::VARCHAR); - string left = SanitizeValue(StringValue::Get(other)); - string right = SanitizeValue(StringValue::Get(value)); - return left == right; - } - default: - if (result_value.type_.id() == LogicalTypeId::FLOAT || result_value.type_.id() == LogicalTypeId::DOUBLE) { - return Value::ValuesAreEqual(set, get_input, value, result_value); - } - return value == result_value; - } -} - -bool Value::ValuesAreEqual(ClientContext &context, const Value &result_value, const Value &value) { - GetCastFunctionInput get_input(context); - return Value::ValuesAreEqual(CastFunctionSet::Get(context), get_input, result_value, value); -} -bool Value::DefaultValuesAreEqual(const Value &result_value, const Value &value) { - CastFunctionSet set; - GetCastFunctionInput get_input; - return Value::ValuesAreEqual(set, get_input, result_value, value); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/varint.cpp b/src/duckdb/src/common/types/varint.cpp deleted file mode 100644 index b92ded113..000000000 --- a/src/duckdb/src/common/types/varint.cpp +++ /dev/null @@ -1,311 +0,0 @@ -#include "duckdb/common/types/varint.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include - -namespace duckdb { - -void Varint::Verify(const string_t &input) { -#ifdef DEBUG - // Size must be >= 4 - idx_t varint_bytes = input.GetSize(); - if (varint_bytes < 4) { - throw InternalException("Varint number of bytes is invalid, current number of 
bytes is %d", varint_bytes); - } - // Bytes in header must quantify the number of data bytes - auto varint_ptr = input.GetData(); - bool is_negative = (varint_ptr[0] & 0x80) == 0; - uint32_t number_of_bytes = 0; - char mask = 0x7F; - if (is_negative) { - number_of_bytes |= static_cast(~varint_ptr[0] & mask) << 16 & 0xFF0000; - number_of_bytes |= static_cast(~varint_ptr[1]) << 8 & 0xFF00; - ; - number_of_bytes |= static_cast(~varint_ptr[2]) & 0xFF; - } else { - number_of_bytes |= static_cast(varint_ptr[0] & mask) << 16 & 0xFF0000; - number_of_bytes |= static_cast(varint_ptr[1]) << 8 & 0xFF00; - number_of_bytes |= static_cast(varint_ptr[2]) & 0xFF; - } - if (number_of_bytes != varint_bytes - 3) { - throw InternalException("The number of bytes set in the Varint header: %d bytes. Does not " - "match the number of bytes encountered as the varint data: %d bytes.", - number_of_bytes, varint_bytes - 3); - } - // No bytes between 4 and end can be 0, unless total size == 4 - if (varint_bytes > 4) { - if (is_negative) { - if (~varint_ptr[3] == 0) { - throw InternalException("Invalid top data bytes set to 0 for VARINT values"); - } - } else { - if (varint_ptr[3] == 0) { - throw InternalException("Invalid top data bytes set to 0 for VARINT values"); - } - } - } -#endif -} -void Varint::SetHeader(char *blob, uint64_t number_of_bytes, bool is_negative) { - uint32_t header = static_cast(number_of_bytes); - // Set MSBit of 3rd byte - header |= 0x00800000; - if (is_negative) { - header = ~header; - } - // we ignore MSByte of header. - // write the 3 bytes to blob. - blob[0] = static_cast(header >> 16); - blob[1] = static_cast(header >> 8 & 0xFF); - blob[2] = static_cast(header & 0xFF); -} - -// Creates a blob representing the value 0 -string_t Varint::InitializeVarintZero(Vector &result) { - uint32_t blob_size = 1 + VARINT_HEADER_SIZE; - auto blob = StringVector::EmptyString(result, blob_size); - auto writable_blob = blob.GetDataWriteable(); - SetHeader(writable_blob, 1, false); - writable_blob[3] = 0; - blob.Finalize(); - return blob; -} - -string Varint::InitializeVarintZero() { - uint32_t blob_size = 1 + VARINT_HEADER_SIZE; - string result(blob_size, '0'); - SetHeader(&result[0], 1, false); - result[3] = 0; - return result; -} - -int Varint::CharToDigit(char c) { - return c - '0'; -} - -char Varint::DigitToChar(int digit) { - // FIXME: this would be the proper solution: - // return UnsafeNumericCast(digit + '0'); - return static_cast(digit + '0'); -} - -bool Varint::VarcharFormatting(const string_t &value, idx_t &start_pos, idx_t &end_pos, bool &is_negative, - bool &is_zero) { - // If it's empty we error - if (value.Empty()) { - return false; - } - start_pos = 0; - is_zero = false; - - auto int_value_char = value.GetData(); - end_pos = value.GetSize(); - - // If first character is -, we have a negative number, if + we have a + number - is_negative = int_value_char[0] == '-'; - if (is_negative) { - start_pos++; - } - if (int_value_char[0] == '+') { - start_pos++; - } - // Now lets trim 0s - bool at_least_one_zero = false; - while (start_pos < end_pos && int_value_char[start_pos] == '0') { - start_pos++; - at_least_one_zero = true; - } - if (start_pos == end_pos) { - if (at_least_one_zero) { - // This is a 0 value - is_zero = true; - return true; - } - // This is either a '+' or '-'. Hence, invalid. 
-		return false;
-	}
-	idx_t cur_pos = start_pos;
-	// Verify that everything else is numeric
-	while (cur_pos < end_pos && std::isdigit(int_value_char[cur_pos])) {
-		cur_pos++;
-	}
-	if (cur_pos < end_pos) {
-		idx_t possible_end = cur_pos;
-		// This is not a digit; if it is a '.' we might still be fine, otherwise the input is invalid.
-		if (int_value_char[cur_pos] == '.') {
-			cur_pos++;
-		} else {
-			return false;
-		}
-
-		while (cur_pos < end_pos) {
-			if (std::isdigit(int_value_char[cur_pos])) {
-				cur_pos++;
-			} else {
-				// From here on we can only have digits, otherwise the input is invalid.
-				return false;
-			}
-		}
-		// Truncate the fractional part, i.e., floor-cast the value
-		end_pos = possible_end;
-	}
-	return true;
-}
-
-void Varint::GetByteArray(vector<uint8_t> &byte_array, bool &is_negative, const string_t &blob) {
-	if (blob.GetSize() < 4) {
-		throw InvalidInputException("Invalid blob size.");
-	}
-	auto blob_ptr = blob.GetData();
-
-	// Determine if the number is negative
-	is_negative = (blob_ptr[0] & 0x80) == 0;
-	for (idx_t i = 3; i < blob.GetSize(); i++) {
-		if (is_negative) {
-			byte_array.push_back(static_cast<uint8_t>(~blob_ptr[i]));
-		} else {
-			byte_array.push_back(static_cast<uint8_t>(blob_ptr[i]));
-		}
-	}
-}
-
-string Varint::FromByteArray(uint8_t *data, idx_t size, bool is_negative) {
-	string result(VARINT_HEADER_SIZE + size, '0');
-	SetHeader(&result[0], size, is_negative);
-	uint8_t *result_data = reinterpret_cast<uint8_t *>(&result[VARINT_HEADER_SIZE]);
-	if (is_negative) {
-		for (idx_t i = 0; i < size; i++) {
-			result_data[i] = ~data[i];
-		}
-	} else {
-		for (idx_t i = 0; i < size; i++) {
-			result_data[i] = data[i];
-		}
-	}
-	return result;
-}
-
-string Varint::VarIntToVarchar(const string_t &blob) {
-	string decimal_string;
-	vector<uint8_t> byte_array;
-	bool is_negative;
-	GetByteArray(byte_array, is_negative, blob);
-	while (!byte_array.empty()) {
-		string quotient;
-		uint8_t remainder = 0;
-		for (uint8_t byte : byte_array) {
-			int new_value = remainder * 256 + byte;
-			quotient += DigitToChar(new_value / 10);
-			remainder = static_cast<uint8_t>(new_value % 10);
-		}
-		decimal_string += DigitToChar(remainder);
-		// Remove leading zeros from the quotient
-		byte_array.clear();
-		for (char digit : quotient) {
-			if (digit != '0' || !byte_array.empty()) {
-				byte_array.push_back(static_cast<uint8_t>(CharToDigit(digit)));
-			}
-		}
-	}
-	if (is_negative) {
-		decimal_string += '-';
-	}
-	// Reverse the string to get the correct decimal representation
-	std::reverse(decimal_string.begin(), decimal_string.end());
-	return decimal_string;
-}
-
-string Varint::VarcharToVarInt(const string_t &value) {
-	idx_t start_pos, end_pos;
-	bool is_negative, is_zero;
-	if (!VarcharFormatting(value, start_pos, end_pos, is_negative, is_zero)) {
-		throw ConversionException("Could not convert string \'%s\' to Varint", value.GetString());
-	}
-	if (is_zero) {
-		// Return the value 0
-		return InitializeVarintZero();
-	}
-	auto int_value_char = value.GetData();
-	idx_t actual_size = end_pos - start_pos;
-
-	// We initialize the result with space for our header
-	string result(VARINT_HEADER_SIZE, '0');
-	unsafe_vector<uint64_t> digits;
-
-	// The max number a uint64_t can represent is 18,446,744,073,709,551,615, which has 20 digits.
-	// In the division loop below, each chunk is first scaled by a remainder of up to 255 (3 digits),
-	// so a chunk may hold at most 16 digits to guarantee that remainder * 10^16 + chunk still fits
-	// in a uint64_t.
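	// Worked example (editorial sketch, not from the original source): for the
	// 20-digit input "12345678901234567890", number_of_digits = ceil(20.0 / 16.0) = 2.
	// The first loop iteration below grabs the low 16 digits (5678901234567890), the
	// second grabs the remaining high digits (1234), so digits = {5678901234567890, 1234}.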
- constexpr uint8_t max_digits = 16; - const idx_t number_of_digits = static_cast(std::ceil(static_cast(actual_size) / max_digits)); - - // lets convert the string to a uint64_t vector - idx_t cur_end = end_pos; - for (idx_t i = 0; i < number_of_digits; i++) { - idx_t cur_start = static_cast(start_pos) > static_cast(cur_end - max_digits) - ? start_pos - : cur_end - max_digits; - std::string current_number(int_value_char + cur_start, cur_end - cur_start); - digits.push_back(std::stoull(current_number)); - // move cur_end to more digits down the road - cur_end = cur_end - max_digits; - } - - // Now that we have our uint64_t vector, lets start our division process to figure out the new number and remainder - while (!digits.empty()) { - idx_t digit_idx = digits.size() - 1; - uint8_t remainder = 0; - idx_t digits_size = digits.size(); - for (idx_t i = 0; i < digits_size; i++) { - digits[digit_idx] += static_cast(remainder * pow(10, max_digits)); - remainder = static_cast(digits[digit_idx] % 256); - digits[digit_idx] /= 256; - if (digits[digit_idx] == 0 && digit_idx == digits.size() - 1) { - // we can cap this - digits.pop_back(); - } - digit_idx--; - } - if (is_negative) { - result.push_back(static_cast(~remainder)); - } else { - result.push_back(static_cast(remainder)); - } - } - std::reverse(result.begin() + VARINT_HEADER_SIZE, result.end()); - // Set header after we know the size of the varint - SetHeader(&result[0], result.size() - VARINT_HEADER_SIZE, is_negative); - return result; -} - -bool Varint::VarintToDouble(const string_t &blob, double &result, bool &strict) { - result = 0; - - if (blob.GetSize() < 4) { - throw InvalidInputException("Invalid blob size."); - } - auto blob_ptr = blob.GetData(); - - // Determine if the number is negative - bool is_negative = (blob_ptr[0] & 0x80) == 0; - idx_t byte_pos = 0; - for (idx_t i = blob.GetSize() - 1; i > 2; i--) { - if (is_negative) { - result += static_cast(~blob_ptr[i]) * pow(256, static_cast(byte_pos)); - } else { - result += static_cast(blob_ptr[i]) * pow(256, static_cast(byte_pos)); - } - byte_pos++; - } - - if (is_negative) { - result *= -1; - } - if (!std::isfinite(result)) { - // We throw an error - throw ConversionException("Could not convert varint '%s' to Double", VarIntToVarchar(blob)); - } - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector.cpp b/src/duckdb/src/common/types/vector.cpp deleted file mode 100644 index 722b8b699..000000000 --- a/src/duckdb/src/common/types/vector.cpp +++ /dev/null @@ -1,2661 +0,0 @@ -#include "duckdb/common/types/vector.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/assert.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/fsst.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/type_visitor.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/sel_cache.hpp" -#include "duckdb/common/types/value.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/common/types/varint.hpp" -#include "duckdb/common/types/vector_cache.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include 
"duckdb/storage/buffer/buffer_handle.hpp" -#include "duckdb/storage/string_uncompressed.hpp" -#include "fsst.h" - -#include // strlen() on Solaris - -namespace duckdb { - -UnifiedVectorFormat::UnifiedVectorFormat() : sel(nullptr), data(nullptr) { -} - -UnifiedVectorFormat::UnifiedVectorFormat(UnifiedVectorFormat &&other) noexcept : sel(nullptr), data(nullptr) { - bool refers_to_self = other.sel == &other.owned_sel; - std::swap(sel, other.sel); - std::swap(data, other.data); - std::swap(validity, other.validity); - std::swap(owned_sel, other.owned_sel); - if (refers_to_self) { - sel = &owned_sel; - } -} - -UnifiedVectorFormat &UnifiedVectorFormat::operator=(UnifiedVectorFormat &&other) noexcept { - bool refers_to_self = other.sel == &other.owned_sel; - std::swap(sel, other.sel); - std::swap(data, other.data); - std::swap(validity, other.validity); - std::swap(owned_sel, other.owned_sel); - if (refers_to_self) { - sel = &owned_sel; - } - return *this; -} - -Vector::Vector(LogicalType type_p, bool create_data, bool initialize_to_zero, idx_t capacity) - : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(nullptr), validity(capacity) { - if (create_data) { - Initialize(initialize_to_zero, capacity); - } -} - -Vector::Vector(LogicalType type_p, idx_t capacity) : Vector(std::move(type_p), true, false, capacity) { -} - -Vector::Vector(LogicalType type_p, data_ptr_t dataptr) - : vector_type(VectorType::FLAT_VECTOR), type(std::move(type_p)), data(dataptr) { - if (dataptr && !type.IsValid()) { - throw InternalException("Cannot create a vector of type INVALID!"); - } -} - -Vector::Vector(const VectorCache &cache) : type(cache.GetType()) { - ResetFromCache(cache); -} - -Vector::Vector(Vector &other) : type(other.type) { - Reference(other); -} - -Vector::Vector(const Vector &other, const SelectionVector &sel, idx_t count) : type(other.type) { - Slice(other, sel, count); -} - -Vector::Vector(const Vector &other, idx_t offset, idx_t end) : type(other.type) { - Slice(other, offset, end); -} - -Vector::Vector(const Value &value) : type(value.type()) { - Reference(value); -} - -Vector::Vector(Vector &&other) noexcept - : vector_type(other.vector_type), type(std::move(other.type)), data(other.data), - validity(std::move(other.validity)), buffer(std::move(other.buffer)), auxiliary(std::move(other.auxiliary)) { -} - -void Vector::Reference(const Value &value) { - D_ASSERT(GetType().id() == value.type().id()); - this->vector_type = VectorType::CONSTANT_VECTOR; - buffer = VectorBuffer::CreateConstantVector(value.type()); - auto internal_type = value.type().InternalType(); - if (internal_type == PhysicalType::STRUCT) { - auto struct_buffer = make_uniq(); - auto &child_types = StructType::GetChildTypes(value.type()); - auto &child_vectors = struct_buffer->GetChildren(); - for (idx_t i = 0; i < child_types.size(); i++) { - auto vector = - make_uniq(value.IsNull() ? 
Value(child_types[i].second) : StructValue::GetChildren(value)[i]); - child_vectors.push_back(std::move(vector)); - } - auxiliary = shared_ptr(struct_buffer.release()); - if (value.IsNull()) { - SetValue(0, value); - } - } else if (internal_type == PhysicalType::LIST) { - auto list_buffer = make_uniq(value.type()); - auxiliary = shared_ptr(list_buffer.release()); - data = buffer->GetData(); - SetValue(0, value); - } else if (internal_type == PhysicalType::ARRAY) { - auto array_buffer = make_uniq(value.type()); - auxiliary = shared_ptr(array_buffer.release()); - SetValue(0, value); - } else { - auxiliary.reset(); - data = buffer->GetData(); - SetValue(0, value); - } -} - -void Vector::Reference(const Vector &other) { - if (other.GetType().id() != GetType().id()) { - throw InternalException("Vector::Reference used on vector of different type"); - } - D_ASSERT(other.GetType() == GetType()); - Reinterpret(other); -} - -void Vector::ReferenceAndSetType(const Vector &other) { - type = other.GetType(); - Reference(other); -} - -void Vector::Reinterpret(const Vector &other) { - vector_type = other.vector_type; -#ifdef DEBUG - auto &this_type = GetType(); - auto &other_type = other.GetType(); - - auto type_is_same = other_type == this_type; - bool this_is_nested = this_type.IsNested(); - bool other_is_nested = other_type.IsNested(); - - bool not_nested = this_is_nested == false && other_is_nested == false; - bool type_size_equal = GetTypeIdSize(this_type.InternalType()) == GetTypeIdSize(other_type.InternalType()); - //! Either the types are completely identical, or they are not nested and their physical type size is the same - //! The reason nested types are not allowed is because copying the auxiliary buffer does not happen recursively - //! e.g DOUBLE[] to BIGINT[], the type of the LIST would say BIGINT but the child Vector says DOUBLE - D_ASSERT((not_nested && type_size_equal) || type_is_same); -#endif - AssignSharedPointer(buffer, other.buffer); - AssignSharedPointer(auxiliary, other.auxiliary); - data = other.data; - validity = other.validity; -} - -void Vector::ResetFromCache(const VectorCache &cache) { - cache.ResetFromCache(*this); -} - -void Vector::Slice(const Vector &other, idx_t offset, idx_t end) { - D_ASSERT(end >= offset); - if (other.GetVectorType() == VectorType::CONSTANT_VECTOR) { - Reference(other); - return; - } - if (other.GetVectorType() != VectorType::FLAT_VECTOR) { - // we can slice the data directly only for flat vectors - // for non-flat vectors slice using a selection vector instead - idx_t count = end - offset; - SelectionVector sel(count); - for (idx_t i = 0; i < count; i++) { - sel.set_index(i, offset + i); - } - Slice(other, sel, count); - return; - } - - auto internal_type = GetType().InternalType(); - if (internal_type == PhysicalType::STRUCT) { - Vector new_vector(GetType()); - auto &entries = StructVector::GetEntries(new_vector); - auto &other_entries = StructVector::GetEntries(other); - D_ASSERT(entries.size() == other_entries.size()); - for (idx_t i = 0; i < entries.size(); i++) { - entries[i]->Slice(*other_entries[i], offset, end); - } - new_vector.validity.Slice(other.validity, offset, end - offset); - Reference(new_vector); - } else if (internal_type == PhysicalType::ARRAY) { - Vector new_vector(GetType()); - auto &child_vec = ArrayVector::GetEntry(new_vector); - auto &other_child_vec = ArrayVector::GetEntry(other); - D_ASSERT(ArrayType::GetSize(GetType()) == ArrayType::GetSize(other.GetType())); - const auto array_size = ArrayType::GetSize(GetType()); - // 
We need to slice the child vector with the multiplied offset and end - child_vec.Slice(other_child_vec, offset * array_size, end * array_size); - new_vector.validity.Slice(other.validity, offset, end - offset); - Reference(new_vector); - } else { - Reference(other); - if (offset > 0) { - data = data + GetTypeIdSize(internal_type) * offset; - validity.Slice(other.validity, offset, end - offset); - } - } -} - -void Vector::Slice(const Vector &other, const SelectionVector &sel, idx_t count) { - Reference(other); - Slice(sel, count); -} - -void Vector::Slice(const SelectionVector &sel, idx_t count) { - if (GetVectorType() == VectorType::CONSTANT_VECTOR) { - // dictionary on a constant is just a constant - return; - } - if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // already a dictionary, slice the current dictionary - auto ¤t_sel = DictionaryVector::SelVector(*this); - auto dictionary_size = DictionaryVector::DictionarySize(*this); - auto dictionary_id = DictionaryVector::DictionaryId(*this); - auto sliced_dictionary = current_sel.Slice(sel, count); - buffer = make_buffer(std::move(sliced_dictionary)); - if (GetType().InternalType() == PhysicalType::STRUCT) { - auto &child_vector = DictionaryVector::Child(*this); - - Vector new_child(child_vector); - new_child.auxiliary = make_buffer(new_child, sel, count); - auxiliary = make_buffer(std::move(new_child)); - } - if (dictionary_size.IsValid()) { - auto &dict_buffer = buffer->Cast(); - dict_buffer.SetDictionarySize(dictionary_size.GetIndex()); - dict_buffer.SetDictionaryId(std::move(dictionary_id)); - } - return; - } - - if (GetVectorType() == VectorType::FSST_VECTOR) { - Flatten(sel, count); - return; - } - - Vector child_vector(*this); - auto internal_type = GetType().InternalType(); - if (internal_type == PhysicalType::STRUCT) { - child_vector.auxiliary = make_buffer(*this, sel, count); - } - auto child_ref = make_buffer(std::move(child_vector)); - auto dict_buffer = make_buffer(sel); - vector_type = VectorType::DICTIONARY_VECTOR; - buffer = std::move(dict_buffer); - auxiliary = std::move(child_ref); -} - -void Vector::Dictionary(idx_t dictionary_size, const SelectionVector &sel, idx_t count) { - Slice(sel, count); - if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { - buffer->Cast().SetDictionarySize(dictionary_size); - } -} - -void Vector::Dictionary(const Vector &dict, idx_t dictionary_size, const SelectionVector &sel, idx_t count) { - Reference(dict); - Dictionary(dictionary_size, sel, count); -} - -void Vector::Slice(const SelectionVector &sel, idx_t count, SelCache &cache) { - if (GetVectorType() == VectorType::DICTIONARY_VECTOR && GetType().InternalType() != PhysicalType::STRUCT) { - // dictionary vector: need to merge dictionaries - // check if we have a cached entry - auto ¤t_sel = DictionaryVector::SelVector(*this); - auto dictionary_size = DictionaryVector::DictionarySize(*this); - auto dictionary_id = DictionaryVector::DictionaryId(*this); - auto target_data = current_sel.data(); - auto entry = cache.cache.find(target_data); - if (entry != cache.cache.end()) { - // cached entry exists: use that - this->buffer = make_buffer(entry->second->Cast().GetSelVector()); - vector_type = VectorType::DICTIONARY_VECTOR; - } else { - Slice(sel, count); - cache.cache[target_data] = this->buffer; - } - if (dictionary_size.IsValid()) { - auto &dict_buffer = buffer->Cast(); - dict_buffer.SetDictionarySize(dictionary_size.GetIndex()); - dict_buffer.SetDictionaryId(std::move(dictionary_id)); - } - } else { - Slice(sel, count); - } -} 
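To illustrate the slicing path above with a small, hypothetical sketch (not part of the deleted file; it assumes only the public duckdb.hpp header): slicing a flat vector with a selection vector wraps it in a dictionary vector whose selection indexes into the original data instead of copying it.

#include "duckdb.hpp"

int main() {
	using namespace duckdb;
	Vector vec(LogicalType::INTEGER, 4);
	auto data = FlatVector::GetData<int32_t>(vec);
	for (idx_t i = 0; i < 4; i++) {
		data[i] = static_cast<int32_t>(i * 10); // 0, 10, 20, 30
	}
	// Select rows 3 and 1; Slice turns the flat vector into a dictionary vector.
	SelectionVector sel(2);
	sel.set_index(0, 3);
	sel.set_index(1, 1);
	vec.Slice(sel, 2);
	// The sliced vector now reads 30, then 10.
	bool ok = vec.GetValue(0) == Value::INTEGER(30) && vec.GetValue(1) == Value::INTEGER(10);
	return ok ? 0 : 1;
}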
- -void Vector::Initialize(bool initialize_to_zero, idx_t capacity) { - auxiliary.reset(); - validity.Reset(); - auto &type = GetType(); - auto internal_type = type.InternalType(); - if (internal_type == PhysicalType::STRUCT) { - auto struct_buffer = make_uniq(type, capacity); - auxiliary = shared_ptr(struct_buffer.release()); - } else if (internal_type == PhysicalType::LIST) { - auto list_buffer = make_uniq(type, capacity); - auxiliary = shared_ptr(list_buffer.release()); - } else if (internal_type == PhysicalType::ARRAY) { - auto array_buffer = make_uniq(type, capacity); - auxiliary = shared_ptr(array_buffer.release()); - } - auto type_size = GetTypeIdSize(internal_type); - if (type_size > 0) { - buffer = VectorBuffer::CreateStandardVector(type, capacity); - data = buffer->GetData(); - if (initialize_to_zero) { - memset(data, 0, capacity * type_size); - } - } - - if (capacity > validity.Capacity()) { - validity.Resize(capacity); - } -} - -void Vector::FindResizeInfos(vector &resize_infos, const idx_t multiplier) { - - ResizeInfo resize_info(*this, data, buffer.get(), multiplier); - resize_infos.emplace_back(resize_info); - - // Base case. - if (data) { - return; - } - - D_ASSERT(auxiliary); - switch (GetAuxiliary()->GetBufferType()) { - case VectorBufferType::LIST_BUFFER: { - auto &vector_list_buffer = auxiliary->Cast(); - auto &child = vector_list_buffer.GetChild(); - child.FindResizeInfos(resize_infos, multiplier); - break; - } - case VectorBufferType::STRUCT_BUFFER: { - auto &vector_struct_buffer = auxiliary->Cast(); - auto &children = vector_struct_buffer.GetChildren(); - for (auto &child : children) { - child->FindResizeInfos(resize_infos, multiplier); - } - break; - } - case VectorBufferType::ARRAY_BUFFER: { - // We need to multiply the multiplier by the array size because - // the child vectors of ARRAY types are always child_count * array_size. - auto &vector_array_buffer = auxiliary->Cast(); - auto new_multiplier = vector_array_buffer.GetArraySize() * multiplier; - auto &child = vector_array_buffer.GetChild(); - child.FindResizeInfos(resize_infos, new_multiplier); - break; - } - default: - break; - } -} - -void Vector::Resize(idx_t current_size, idx_t new_size) { - // The vector does not contain any data. - if (!buffer) { - buffer = make_buffer(0); - } - - // Obtain the resize information for each (nested) vector. - vector resize_infos; - FindResizeInfos(resize_infos, 1); - - for (auto &resize_info_entry : resize_infos) { - // Resize the validity mask. - auto new_validity_size = new_size * resize_info_entry.multiplier; - resize_info_entry.vec.validity.Resize(new_validity_size); - - // For nested data types, we only need to resize the validity mask. - if (!resize_info_entry.data) { - continue; - } - - auto type_size = GetTypeIdSize(resize_info_entry.vec.GetType().InternalType()); - auto old_size = current_size * type_size * resize_info_entry.multiplier * sizeof(data_t); - auto target_size = new_size * type_size * resize_info_entry.multiplier * sizeof(data_t); - - // We have an upper limit of 128GB for a single vector. - if (target_size > DConstants::MAX_VECTOR_SIZE) { - throw OutOfRangeException("Cannot resize vector to %s: maximum allowed vector size is %s", - StringUtil::BytesToHumanReadableString(target_size), - StringUtil::BytesToHumanReadableString(DConstants::MAX_VECTOR_SIZE)); - } - - // Copy the data buffer to a resized buffer. 
- auto new_data = make_unsafe_uniq_array_uninitialized(target_size); - memcpy(new_data.get(), resize_info_entry.data, old_size); - resize_info_entry.buffer->SetData(std::move(new_data)); - resize_info_entry.vec.data = resize_info_entry.buffer->GetData(); - } -} - -static bool IsStructOrArrayRecursive(const LogicalType &type) { - return TypeVisitor::Contains(type, [](const LogicalType &type) { - auto physical_type = type.InternalType(); - return (physical_type == PhysicalType::STRUCT || physical_type == PhysicalType::ARRAY); - }); -} - -void Vector::SetValue(idx_t index, const Value &val) { - if (GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // dictionary: apply dictionary and forward to child - auto &sel_vector = DictionaryVector::SelVector(*this); - auto &child = DictionaryVector::Child(*this); - return child.SetValue(sel_vector.get_index(index), val); - } - if (!val.IsNull() && val.type() != GetType()) { - SetValue(index, val.DefaultCastAs(GetType())); - return; - } - D_ASSERT(val.IsNull() || (val.type().InternalType() == GetType().InternalType())); - - validity.EnsureWritable(); - validity.Set(index, !val.IsNull()); - auto physical_type = GetType().InternalType(); - if (val.IsNull() && !IsStructOrArrayRecursive(GetType())) { - // for structs and arrays we still need to set the child-entries to NULL - // so we do not bail out yet - return; - } - - switch (physical_type) { - case PhysicalType::BOOL: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT8: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT16: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT32: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT64: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INT128: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT8: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT16: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT32: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT64: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::UINT128: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::FLOAT: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::DOUBLE: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::INTERVAL: - reinterpret_cast(data)[index] = val.GetValueUnsafe(); - break; - case PhysicalType::VARCHAR: { - if (!val.IsNull()) { - reinterpret_cast(data)[index] = StringVector::AddStringOrBlob(*this, StringValue::Get(val)); - } - break; - } - case PhysicalType::STRUCT: { - D_ASSERT(GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR); - - auto &children = StructVector::GetEntries(*this); - if (val.IsNull()) { - for (size_t i = 0; i < children.size(); i++) { - auto &vec_child = children[i]; - vec_child->SetValue(index, Value()); - } - } else { - auto &val_children = StructValue::GetChildren(val); - D_ASSERT(children.size() == val_children.size()); - for (size_t i = 0; i < children.size(); i++) { - auto &vec_child = children[i]; - auto &struct_child = val_children[i]; - vec_child->SetValue(index, struct_child); - } - } - break; - } - case PhysicalType::LIST: { - 
auto offset = ListVector::GetListSize(*this); - if (val.IsNull()) { - auto &entry = reinterpret_cast(data)[index]; - ListVector::PushBack(*this, Value()); - entry.length = 1; - entry.offset = offset; - } else { - auto &val_children = ListValue::GetChildren(val); - if (!val_children.empty()) { - for (idx_t i = 0; i < val_children.size(); i++) { - ListVector::PushBack(*this, val_children[i]); - } - } - //! now set the pointer - auto &entry = reinterpret_cast(data)[index]; - entry.length = val_children.size(); - entry.offset = offset; - } - break; - } - case PhysicalType::ARRAY: { - auto array_size = ArrayType::GetSize(GetType()); - auto &child = ArrayVector::GetEntry(*this); - if (val.IsNull()) { - for (idx_t i = 0; i < array_size; i++) { - child.SetValue(index * array_size + i, Value()); - } - } else { - auto &val_children = ArrayValue::GetChildren(val); - for (idx_t i = 0; i < array_size; i++) { - child.SetValue(index * array_size + i, val_children[i]); - } - } - break; - } - default: - throw InternalException("Unimplemented type for Vector::SetValue"); - } -} - -Value Vector::GetValueInternal(const Vector &v_p, idx_t index_p) { - const Vector *vector = &v_p; - idx_t index = index_p; - bool finished = false; - while (!finished) { - switch (vector->GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - index = 0; - finished = true; - break; - case VectorType::FLAT_VECTOR: - finished = true; - break; - case VectorType::FSST_VECTOR: - finished = true; - break; - // dictionary: apply dictionary and forward to child - case VectorType::DICTIONARY_VECTOR: { - auto &sel_vector = DictionaryVector::SelVector(*vector); - auto &child = DictionaryVector::Child(*vector); - vector = &child; - index = sel_vector.get_index(index); - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*vector, start, increment); - return Value::Numeric(vector->GetType(), start + increment * NumericCast(index)); - } - default: - throw InternalException("Unimplemented vector type for Vector::GetValue"); - } - } - auto data = vector->data; - auto &validity = vector->validity; - auto &type = vector->GetType(); - - if (!validity.RowIsValid(index)) { - return Value(vector->GetType()); - } - - if (vector->GetVectorType() == VectorType::FSST_VECTOR) { - if (vector->GetType().InternalType() != PhysicalType::VARCHAR) { - throw InternalException("FSST Vector with non-string datatype found!"); - } - auto str_compressed = reinterpret_cast(data)[index]; - auto decoder = FSSTVector::GetDecoder(*vector); - auto &decompress_buffer = FSSTVector::GetDecompressBuffer(*vector); - auto string_val = FSSTPrimitives::DecompressValue(decoder, str_compressed.GetData(), str_compressed.GetSize(), - decompress_buffer); - switch (vector->GetType().id()) { - case LogicalTypeId::VARCHAR: - return Value(std::move(string_val)); - case LogicalTypeId::BLOB: - return Value::BLOB_RAW(string_val); - default: - throw InternalException("Unsupported vector type for FSST vector"); - } - } - - switch (vector->GetType().id()) { - case LogicalTypeId::BOOLEAN: - return Value::BOOLEAN(reinterpret_cast(data)[index]); - case LogicalTypeId::TINYINT: - return Value::TINYINT(reinterpret_cast(data)[index]); - case LogicalTypeId::SMALLINT: - return Value::SMALLINT(reinterpret_cast(data)[index]); - case LogicalTypeId::INTEGER: - return Value::INTEGER(reinterpret_cast(data)[index]); - case LogicalTypeId::DATE: - return Value::DATE(reinterpret_cast(data)[index]); - case LogicalTypeId::TIME: - return 
Value::TIME(reinterpret_cast(data)[index]); - case LogicalTypeId::TIME_TZ: - return Value::TIMETZ(reinterpret_cast(data)[index]); - case LogicalTypeId::BIGINT: - return Value::BIGINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UTINYINT: - return Value::UTINYINT(reinterpret_cast(data)[index]); - case LogicalTypeId::USMALLINT: - return Value::USMALLINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UINTEGER: - return Value::UINTEGER(reinterpret_cast(data)[index]); - case LogicalTypeId::UBIGINT: - return Value::UBIGINT(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP: - return Value::TIMESTAMP(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_NS: - return Value::TIMESTAMPNS(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_MS: - return Value::TIMESTAMPMS(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_SEC: - return Value::TIMESTAMPSEC(reinterpret_cast(data)[index]); - case LogicalTypeId::TIMESTAMP_TZ: - return Value::TIMESTAMPTZ(reinterpret_cast(data)[index]); - case LogicalTypeId::HUGEINT: - return Value::HUGEINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UHUGEINT: - return Value::UHUGEINT(reinterpret_cast(data)[index]); - case LogicalTypeId::UUID: - return Value::UUID(reinterpret_cast(data)[index]); - case LogicalTypeId::DECIMAL: { - auto width = DecimalType::GetWidth(type); - auto scale = DecimalType::GetScale(type); - switch (type.InternalType()) { - case PhysicalType::INT16: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - case PhysicalType::INT32: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - case PhysicalType::INT64: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - case PhysicalType::INT128: - return Value::DECIMAL(reinterpret_cast(data)[index], width, scale); - default: - throw InternalException("Physical type '%s' has a width bigger than 38, which is not supported", - TypeIdToString(type.InternalType())); - } - } - case LogicalTypeId::ENUM: { - switch (type.InternalType()) { - case PhysicalType::UINT8: - return Value::ENUM(reinterpret_cast(data)[index], type); - case PhysicalType::UINT16: - return Value::ENUM(reinterpret_cast(data)[index], type); - case PhysicalType::UINT32: - return Value::ENUM(reinterpret_cast(data)[index], type); - default: - throw InternalException("ENUM can only have unsigned integers as physical types"); - } - } - case LogicalTypeId::POINTER: - return Value::POINTER(reinterpret_cast(data)[index]); - case LogicalTypeId::FLOAT: - return Value::FLOAT(reinterpret_cast(data)[index]); - case LogicalTypeId::DOUBLE: - return Value::DOUBLE(reinterpret_cast(data)[index]); - case LogicalTypeId::INTERVAL: - return Value::INTERVAL(reinterpret_cast(data)[index]); - case LogicalTypeId::VARCHAR: { - auto str = reinterpret_cast(data)[index]; - return Value(str.GetString()); - } - case LogicalTypeId::BLOB: { - auto str = reinterpret_cast(data)[index]; - return Value::BLOB(const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case LogicalTypeId::VARINT: { - auto str = reinterpret_cast(data)[index]; - return Value::VARINT(const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case LogicalTypeId::AGGREGATE_STATE: { - auto str = reinterpret_cast(data)[index]; - return Value::AGGREGATE_STATE(vector->GetType(), const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case LogicalTypeId::BIT: { - auto str = reinterpret_cast(data)[index]; - return Value::BIT(const_data_ptr_cast(str.GetData()), str.GetSize()); - } - case 
LogicalTypeId::MAP: { - auto offlen = reinterpret_cast(data)[index]; - auto &child_vec = ListVector::GetEntry(*vector); - duckdb::vector children; - for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { - children.push_back(child_vec.GetValue(i)); - } - return Value::MAP(ListType::GetChildType(type), std::move(children)); - } - case LogicalTypeId::UNION: { - // Remember to pass the original index_p here so we dont slice twice when looking up the tag - // in case this is a dictionary vector - union_tag_t tag; - if (UnionVector::TryGetTag(*vector, index_p, tag)) { - auto value = UnionVector::GetMember(*vector, tag).GetValue(index_p); - auto members = UnionType::CopyMemberTypes(type); - return Value::UNION(members, tag, std::move(value)); - } else { - return Value(vector->GetType()); - } - } - case LogicalTypeId::STRUCT: { - // we can derive the value schema from the vector schema - auto &child_entries = StructVector::GetEntries(*vector); - child_list_t children; - for (idx_t child_idx = 0; child_idx < child_entries.size(); child_idx++) { - auto &struct_child = child_entries[child_idx]; - children.push_back(make_pair(StructType::GetChildName(type, child_idx), struct_child->GetValue(index_p))); - } - return Value::STRUCT(std::move(children)); - } - case LogicalTypeId::LIST: { - auto offlen = reinterpret_cast(data)[index]; - auto &child_vec = ListVector::GetEntry(*vector); - duckdb::vector children; - for (idx_t i = offlen.offset; i < offlen.offset + offlen.length; i++) { - children.push_back(child_vec.GetValue(i)); - } - return Value::LIST(ListType::GetChildType(type), std::move(children)); - } - case LogicalTypeId::ARRAY: { - auto stride = ArrayType::GetSize(type); - auto offset = index * stride; - auto &child_vec = ArrayVector::GetEntry(*vector); - duckdb::vector children; - for (idx_t i = offset; i < offset + stride; i++) { - children.push_back(child_vec.GetValue(i)); - } - return Value::ARRAY(ArrayType::GetChildType(type), std::move(children)); - } - default: - throw InternalException("Unimplemented type for value access"); - } -} - -Value Vector::GetValue(const Vector &v_p, idx_t index_p) { - auto value = GetValueInternal(v_p, index_p); - // set the alias of the type to the correct value, if there is a type alias - if (v_p.GetType().HasAlias()) { - value.GetTypeMutable().CopyAuxInfo(v_p.GetType()); - } - if (v_p.GetType().id() != LogicalTypeId::AGGREGATE_STATE && value.type().id() != LogicalTypeId::AGGREGATE_STATE) { - - D_ASSERT(v_p.GetType() == value.type()); - } - return value; -} - -Value Vector::GetValue(idx_t index) const { - return GetValue(*this, index); -} - -// LCOV_EXCL_START -string VectorTypeToString(VectorType type) { - switch (type) { - case VectorType::FLAT_VECTOR: - return "FLAT"; - case VectorType::FSST_VECTOR: - return "FSST"; - case VectorType::SEQUENCE_VECTOR: - return "SEQUENCE"; - case VectorType::DICTIONARY_VECTOR: - return "DICTIONARY"; - case VectorType::CONSTANT_VECTOR: - return "CONSTANT"; - default: - return "UNKNOWN"; - } -} - -string Vector::ToString(idx_t count) const { - string retval = - VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": " + to_string(count) + " = [ "; - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - case VectorType::DICTIONARY_VECTOR: - for (idx_t i = 0; i < count; i++) { - retval += GetValue(i).ToString() + (i == count - 1 ? 
"" : ", "); - } - break; - case VectorType::FSST_VECTOR: { - for (idx_t i = 0; i < count; i++) { - string_t compressed_string = reinterpret_cast(data)[i]; - auto decoder = FSSTVector::GetDecoder(*this); - auto &decompress_buffer = FSSTVector::GetDecompressBuffer(*this); - Value val = FSSTPrimitives::DecompressValue(decoder, compressed_string.GetData(), - compressed_string.GetSize(), decompress_buffer); - retval += GetValue(i).ToString() + (i == count - 1 ? "" : ", "); - } - } break; - case VectorType::CONSTANT_VECTOR: - retval += GetValue(0).ToString(); - break; - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*this, start, increment); - for (idx_t i = 0; i < count; i++) { - retval += to_string(start + increment * UnsafeNumericCast(i)) + (i == count - 1 ? "" : ", "); - } - break; - } - default: - retval += "UNKNOWN VECTOR TYPE"; - break; - } - retval += "]"; - return retval; -} - -void Vector::Print(idx_t count) const { - Printer::Print(ToString(count)); -} - -// TODO: add the size of validity masks to this -idx_t Vector::GetAllocationSize(idx_t cardinality) const { - if (!type.IsNested()) { - auto physical_size = GetTypeIdSize(type.InternalType()); - return cardinality * physical_size; - } - auto internal_type = type.InternalType(); - switch (internal_type) { - case PhysicalType::LIST: { - auto physical_size = GetTypeIdSize(type.InternalType()); - auto total_size = physical_size * cardinality; - - auto child_cardinality = ListVector::GetListCapacity(*this); - auto &child_entry = ListVector::GetEntry(*this); - total_size += (child_entry.GetAllocationSize(child_cardinality)); - return total_size; - } - case PhysicalType::ARRAY: { - auto child_cardinality = ArrayVector::GetTotalSize(*this); - - auto &child_entry = ArrayVector::GetEntry(*this); - auto total_size = (child_entry.GetAllocationSize(child_cardinality)); - return total_size; - } - case PhysicalType::STRUCT: { - idx_t total_size = 0; - auto &children = StructVector::GetEntries(*this); - for (auto &child : children) { - total_size += child->GetAllocationSize(cardinality); - } - return total_size; - } - default: - throw NotImplementedException("Vector::GetAllocationSize not implemented for type: %s", type.ToString()); - break; - } -} - -string Vector::ToString() const { - string retval = VectorTypeToString(GetVectorType()) + " " + GetType().ToString() + ": (UNKNOWN COUNT) [ "; - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - case VectorType::DICTIONARY_VECTOR: - break; - case VectorType::CONSTANT_VECTOR: - retval += GetValue(0).ToString(); - break; - case VectorType::SEQUENCE_VECTOR: { - break; - } - default: - retval += "UNKNOWN VECTOR TYPE"; - break; - } - retval += "]"; - return retval; -} - -void Vector::Print() const { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -template -static void TemplatedFlattenConstantVector(data_ptr_t data, data_ptr_t old_data, idx_t count) { - auto constant = Load(old_data); - auto output = (T *)data; - for (idx_t i = 0; i < count; i++) { - output[i] = constant; - } -} - -void Vector::Flatten(idx_t count) { - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - // already a flat vector - break; - case VectorType::FSST_VECTOR: { - // Even though count may only be a part of the vector, we need to flatten the whole thing due to the way - // ToUnifiedFormat uses flatten - idx_t total_count = FSSTVector::GetCount(*this); - // create vector to decompress into - Vector other(GetType(), total_count); - // now copy the data of this 
vector to the other vector, decompressing the strings in the process - VectorOperations::Copy(*this, other, total_count, 0, 0); - // create a reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::DICTIONARY_VECTOR: { - // create a new flat vector of this type - Vector other(GetType(), count); - // now copy the data of this vector to the other vector, removing the selection vector in the process - VectorOperations::Copy(*this, other, count, 0, 0); - // create a reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::CONSTANT_VECTOR: { - bool is_null = ConstantVector::IsNull(*this); - // allocate a new buffer for the vector - auto old_buffer = std::move(buffer); - auto old_data = data; - buffer = VectorBuffer::CreateStandardVector(type, MaxValue(STANDARD_VECTOR_SIZE, count)); - if (old_buffer) { - D_ASSERT(buffer->GetAuxiliaryData() == nullptr); - // The old buffer might be relying on the auxiliary data, keep it alive - buffer->MoveAuxiliaryData(*old_buffer); - } - data = buffer->GetData(); - vector_type = VectorType::FLAT_VECTOR; - if (is_null && GetType().InternalType() != PhysicalType::ARRAY) { - // constant NULL, set nullmask - validity.EnsureWritable(); - validity.SetAllInvalid(count); - if (GetType().InternalType() != PhysicalType::STRUCT) { - // for structs we still need to flatten the child vectors as well - return; - } - } - // non-null constant: have to repeat the constant - switch (GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT8: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT16: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT32: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT64: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT8: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT16: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT32: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT64: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INT128: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::UINT128: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::FLOAT: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::DOUBLE: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::INTERVAL: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::VARCHAR: - TemplatedFlattenConstantVector(data, old_data, count); - break; - case PhysicalType::LIST: { - TemplatedFlattenConstantVector(data, old_data, count); - break; - } - case PhysicalType::ARRAY: { - auto &original_child = ArrayVector::GetEntry(*this); - auto array_size = ArrayType::GetSize(GetType()); - auto flattened_buffer = make_uniq(GetType(), count); - auto &new_child = flattened_buffer->GetChild(); - - // Fast path: The array is a constant null - if (is_null) { - // Invalidate the parent array - validity.SetAllInvalid(count); - // Also invalidate the new child array - new_child.validity.SetAllInvalid(count * array_size); - // Recurse - 
new_child.Flatten(count * array_size); - // TODO: the fast path should exit here, but the part below it is somehow required for correctness - // Attach the flattened buffer and return - // auxiliary = shared_ptr(flattened_buffer.release()); - // return; - } - - // Now we need to "unpack" the child vector. - // Basically, do this: - // - // | a1 | | 1 | | a1 | | 1 | - // | 2 | | a2 | | 2 | - // => .. | 1 | - // | 2 | - // ... - - auto child_vec = make_uniq(original_child); - child_vec->Flatten(count * array_size); - - // Create a selection vector - SelectionVector sel(count * array_size); - for (idx_t array_idx = 0; array_idx < count; array_idx++) { - for (idx_t elem_idx = 0; elem_idx < array_size; elem_idx++) { - auto position = array_idx * array_size + elem_idx; - // Broadcast the validity - if (FlatVector::IsNull(*child_vec, elem_idx)) { - FlatVector::SetNull(new_child, position, true); - } - sel.set_index(position, elem_idx); - } - } - - // Copy over the data to the new buffer - VectorOperations::Copy(*child_vec, new_child, sel, count * array_size, 0, 0); - auxiliary = shared_ptr(flattened_buffer.release()); - - break; - } - case PhysicalType::STRUCT: { - auto normalified_buffer = make_uniq(); - - auto &new_children = normalified_buffer->GetChildren(); - - auto &child_entries = StructVector::GetEntries(*this); - for (auto &child : child_entries) { - D_ASSERT(child->GetVectorType() == VectorType::CONSTANT_VECTOR); - auto vector = make_uniq(*child); - vector->Flatten(count); - new_children.push_back(std::move(vector)); - } - auxiliary = shared_ptr(normalified_buffer.release()); - break; - } - default: - throw InternalException("Unimplemented type for VectorOperations::Flatten"); - } - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment, sequence_count; - SequenceVector::GetSequence(*this, start, increment, sequence_count); - auto seq_count = NumericCast(sequence_count); - - buffer = VectorBuffer::CreateStandardVector(GetType(), MaxValue(STANDARD_VECTOR_SIZE, seq_count)); - data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, seq_count, start, increment); - break; - } - default: - throw InternalException("Unimplemented type for normalify"); - } -} - -void Vector::Flatten(const SelectionVector &sel, idx_t count) { - switch (GetVectorType()) { - case VectorType::FLAT_VECTOR: - // already a flat vector - break; - case VectorType::FSST_VECTOR: { - // create a new flat vector of this type - Vector other(GetType(), count); - // copy the data of this vector to the other vector, removing compression and selection vector in the process - VectorOperations::Copy(*this, other, sel, count, 0, 0); - // create a reference to the data in the other vector - this->Reference(other); - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - SequenceVector::GetSequence(*this, start, increment); - - buffer = VectorBuffer::CreateStandardVector(GetType()); - data = buffer->GetData(); - VectorOperations::GenerateSequence(*this, count, sel, start, increment); - break; - } - default: - throw InternalException("Unimplemented type for normalify with selection vector"); - } -} - -void Vector::ToUnifiedFormat(idx_t count, UnifiedVectorFormat &format) { - switch (GetVectorType()) { - case VectorType::DICTIONARY_VECTOR: { - auto &sel = DictionaryVector::SelVector(*this); - format.owned_sel.Initialize(sel); - format.sel = &format.owned_sel; - - auto &child = DictionaryVector::Child(*this); - if (child.GetVectorType() == VectorType::FLAT_VECTOR) { - format.data 
= FlatVector::GetData(child); - format.validity = FlatVector::Validity(child); - } else { - // dictionary with non-flat child: create a new reference to the child and flatten it - Vector child_vector(child); - child_vector.Flatten(sel, count); - auto new_aux = make_buffer(std::move(child_vector)); - - format.data = FlatVector::GetData(new_aux->data); - format.validity = FlatVector::Validity(new_aux->data); - this->auxiliary = std::move(new_aux); - } - break; - } - case VectorType::CONSTANT_VECTOR: - format.sel = ConstantVector::ZeroSelectionVector(count, format.owned_sel); - format.data = ConstantVector::GetData(*this); - format.validity = ConstantVector::Validity(*this); - break; - default: - Flatten(count); - format.sel = FlatVector::IncrementalSelectionVector(); - format.data = FlatVector::GetData(*this); - format.validity = FlatVector::Validity(*this); - break; - } -} - -void Vector::RecursiveToUnifiedFormat(Vector &input, idx_t count, RecursiveUnifiedVectorFormat &data) { - - input.ToUnifiedFormat(count, data.unified); - data.logical_type = input.GetType(); - - if (input.GetType().InternalType() == PhysicalType::LIST) { - auto &child = ListVector::GetEntry(input); - auto child_count = ListVector::GetListSize(input); - data.children.emplace_back(); - Vector::RecursiveToUnifiedFormat(child, child_count, data.children.back()); - - } else if (input.GetType().InternalType() == PhysicalType::ARRAY) { - auto &child = ArrayVector::GetEntry(input); - auto array_size = ArrayType::GetSize(input.GetType()); - auto child_count = count * array_size; - data.children.emplace_back(); - Vector::RecursiveToUnifiedFormat(child, child_count, data.children.back()); - - } else if (input.GetType().InternalType() == PhysicalType::STRUCT) { - auto &children = StructVector::GetEntries(input); - for (idx_t i = 0; i < children.size(); i++) { - data.children.emplace_back(); - } - for (idx_t i = 0; i < children.size(); i++) { - Vector::RecursiveToUnifiedFormat(*children[i], count, data.children[i]); - } - } -} - -void Vector::Sequence(int64_t start, int64_t increment, idx_t count) { - this->vector_type = VectorType::SEQUENCE_VECTOR; - this->buffer = make_buffer(sizeof(int64_t) * 3); - auto data = reinterpret_cast(buffer->GetData()); - data[0] = start; - data[1] = increment; - data[2] = int64_t(count); - validity.Reset(); - auxiliary.reset(); -} - -// FIXME: This should ideally be const -void Vector::Serialize(Serializer &serializer, idx_t count) { - auto &logical_type = GetType(); - - UnifiedVectorFormat vdata; - ToUnifiedFormat(count, vdata); - - const bool has_validity_mask = (count > 0) && !vdata.validity.AllValid(); - serializer.WriteProperty(100, "has_validity_mask", has_validity_mask); - if (has_validity_mask) { - ValidityMask flat_mask(count); - flat_mask.Initialize(); - for (idx_t i = 0; i < count; ++i) { - auto row_idx = vdata.sel->get_index(i); - flat_mask.Set(i, vdata.validity.RowIsValid(row_idx)); - } - serializer.WriteProperty(101, "validity", const_data_ptr_cast(flat_mask.GetData()), - flat_mask.ValidityMaskSize(count)); - } - if (TypeIsConstantSize(logical_type.InternalType())) { - // constant size type: simple copy - idx_t write_size = GetTypeIdSize(logical_type.InternalType()) * count; - auto ptr = make_unsafe_uniq_array_uninitialized(write_size); - VectorOperations::WriteToStorage(*this, count, ptr.get()); - serializer.WriteProperty(102, "data", ptr.get(), write_size); - } else { - switch (logical_type.InternalType()) { - case PhysicalType::VARCHAR: { - auto strings = 
UnifiedVectorFormat::GetData(vdata); - - // Serialize data as a list - serializer.WriteList(102, "data", count, [&](Serializer::List &list, idx_t i) { - auto idx = vdata.sel->get_index(i); - auto str = !vdata.validity.RowIsValid(idx) ? NullValue() : strings[idx]; - list.WriteElement(str); - }); - break; - } - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(*this); - - // Serialize entries as a list - serializer.WriteList(103, "children", entries.size(), [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { entries[i]->Serialize(object, count); }); - }); - break; - } - case PhysicalType::LIST: { - auto &child = ListVector::GetEntry(*this); - auto list_size = ListVector::GetListSize(*this); - - // serialize the list entries in a flat array - auto entries = make_unsafe_uniq_array_uninitialized(count); - auto source_array = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto idx = vdata.sel->get_index(i); - auto source = source_array[idx]; - if (vdata.validity.RowIsValid(idx)) { - entries[i].offset = source.offset; - entries[i].length = source.length; - } else { - entries[i].offset = 0; - entries[i].length = 0; - } - } - serializer.WriteProperty(104, "list_size", list_size); - serializer.WriteList(105, "entries", count, [&](Serializer::List &list, idx_t i) { - list.WriteObject([&](Serializer &object) { - object.WriteProperty(100, "offset", entries[i].offset); - object.WriteProperty(101, "length", entries[i].length); - }); - }); - serializer.WriteObject(106, "child", [&](Serializer &object) { child.Serialize(object, list_size); }); - break; - } - case PhysicalType::ARRAY: { - Vector serialized_vector(*this); - serialized_vector.Flatten(count); - - auto &child = ArrayVector::GetEntry(serialized_vector); - auto array_size = ArrayType::GetSize(serialized_vector.GetType()); - auto child_size = array_size * count; - serializer.WriteProperty(103, "array_size", array_size); - serializer.WriteObject(104, "child", [&](Serializer &object) { child.Serialize(object, child_size); }); - break; - } - default: - throw InternalException("Unimplemented variable width type for Vector::Serialize!"); - } - } -} - -void Vector::Deserialize(Deserializer &deserializer, idx_t count) { - auto &logical_type = GetType(); - - auto &validity = FlatVector::Validity(*this); - auto validity_count = MaxValue(count, STANDARD_VECTOR_SIZE); - validity.Reset(validity_count); - const auto has_validity_mask = deserializer.ReadProperty(100, "has_validity_mask"); - if (has_validity_mask) { - validity.Initialize(validity_count); - deserializer.ReadProperty(101, "validity", data_ptr_cast(validity.GetData()), validity.ValidityMaskSize(count)); - } - - if (TypeIsConstantSize(logical_type.InternalType())) { - // constant size type: read fixed amount of data - auto column_size = GetTypeIdSize(logical_type.InternalType()) * count; - auto ptr = make_unsafe_uniq_array_uninitialized(column_size); - deserializer.ReadProperty(102, "data", ptr.get(), column_size); - - VectorOperations::ReadFromStorage(ptr.get(), count, *this); - } else { - switch (logical_type.InternalType()) { - case PhysicalType::VARCHAR: { - auto strings = FlatVector::GetData(*this); - deserializer.ReadList(102, "data", [&](Deserializer::List &list, idx_t i) { - auto str = list.ReadElement(); - if (validity.RowIsValid(i)) { - strings[i] = StringVector::AddStringOrBlob(*this, str); - } - }); - break; - } - case PhysicalType::STRUCT: { - auto &entries = StructVector::GetEntries(*this); - // 
Deserialize entries as a list - deserializer.ReadList(103, "children", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &obj) { entries[i]->Deserialize(obj, count); }); - }); - break; - } - case PhysicalType::LIST: { - // Read the list size - auto list_size = deserializer.ReadProperty(104, "list_size"); - ListVector::Reserve(*this, list_size); - ListVector::SetListSize(*this, list_size); - - // Read the entries - auto list_entries = FlatVector::GetData(*this); - deserializer.ReadList(105, "entries", [&](Deserializer::List &list, idx_t i) { - list.ReadObject([&](Deserializer &obj) { - list_entries[i].offset = obj.ReadProperty(100, "offset"); - list_entries[i].length = obj.ReadProperty(101, "length"); - }); - }); - - // Read the child vector - deserializer.ReadObject(106, "child", [&](Deserializer &obj) { - auto &child = ListVector::GetEntry(*this); - child.Deserialize(obj, list_size); - }); - break; - } - case PhysicalType::ARRAY: { - auto array_size = deserializer.ReadProperty(103, "array_size"); - deserializer.ReadObject(104, "child", [&](Deserializer &obj) { - auto &child = ArrayVector::GetEntry(*this); - child.Deserialize(obj, array_size * count); - }); - break; - } - default: - throw InternalException("Unimplemented variable width type for Vector::Deserialize!"); - } - } -} - -void Vector::SetVectorType(VectorType vector_type_p) { - vector_type = vector_type_p; - auto physical_type = GetType().InternalType(); - auto flat_or_const = GetVectorType() == VectorType::CONSTANT_VECTOR || GetVectorType() == VectorType::FLAT_VECTOR; - if (TypeIsConstantSize(physical_type) && flat_or_const) { - auxiliary.reset(); - } - if (vector_type == VectorType::CONSTANT_VECTOR && physical_type == PhysicalType::STRUCT) { - auto &entries = StructVector::GetEntries(*this); - for (auto &entry : entries) { - entry->SetVectorType(vector_type); - } - } -} - -void Vector::UTFVerify(const SelectionVector &sel, idx_t count) { -#ifdef DEBUG - if (count == 0) { - return; - } - if (GetType().InternalType() == PhysicalType::VARCHAR) { - // we just touch all the strings and let the sanitizer figure out if any - // of them are deallocated/corrupt - switch (GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - auto string = ConstantVector::GetData(*this); - if (!ConstantVector::IsNull(*this)) { - string->Verify(); - } - break; - } - case VectorType::FLAT_VECTOR: { - auto strings = FlatVector::GetData(*this); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel.get_index(i); - if (validity.RowIsValid(oidx)) { - strings[oidx].Verify(); - } - } - break; - } - default: - break; - } - } -#endif -} - -void Vector::UTFVerify(idx_t count) { - auto flat_sel = FlatVector::IncrementalSelectionVector(); - - UTFVerify(*flat_sel, count); -} - -void Vector::VerifyMap(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - D_ASSERT(vector_p.GetType().id() == LogicalTypeId::MAP); - auto &child = ListType::GetChildType(vector_p.GetType()); - D_ASSERT(StructType::GetChildCount(child) == 2); - D_ASSERT(StructType::GetChildName(child, 0) == "key"); - D_ASSERT(StructType::GetChildName(child, 1) == "value"); - - auto valid_check = MapVector::CheckMapValidity(vector_p, count, sel_p); - D_ASSERT(valid_check == MapInvalidReason::VALID); -#endif // DEBUG -} - -void Vector::VerifyUnion(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - - D_ASSERT(vector_p.GetType().id() == LogicalTypeId::UNION); - auto valid_check = UnionVector::CheckUnionValidity(vector_p, count, 
sel_p); - if (valid_check != UnionInvalidReason::VALID) { - throw InternalException("Union not valid, reason: %s", EnumUtil::ToString(valid_check)); - } -#endif // DEBUG -} - -void Vector::Verify(Vector &vector_p, const SelectionVector &sel_p, idx_t count) { -#ifdef DEBUG - if (count == 0) { - return; - } - Vector *vector = &vector_p; - const SelectionVector *sel = &sel_p; - SelectionVector owned_sel; - auto &type = vector->GetType(); - auto vtype = vector->GetVectorType(); - if (vector->GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(*vector); - D_ASSERT(child.GetVectorType() != VectorType::DICTIONARY_VECTOR); - auto &dict_sel = DictionaryVector::SelVector(*vector); - // merge the selection vectors and verify the child - auto new_buffer = dict_sel.Slice(*sel, count); - owned_sel.Initialize(new_buffer); - sel = &owned_sel; - vector = &child; - vtype = vector->GetVectorType(); - } - if (TypeIsConstantSize(type.InternalType()) && - (vtype == VectorType::CONSTANT_VECTOR || vtype == VectorType::FLAT_VECTOR)) { - D_ASSERT(!vector->auxiliary); - } - if (type.id() == LogicalTypeId::VARCHAR) { - // verify that the string is correct unicode - switch (vtype) { - case VectorType::FLAT_VECTOR: { - auto &validity = FlatVector::Validity(*vector); - auto strings = FlatVector::GetData(*vector); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - strings[oidx].Verify(); - } - } - break; - } - default: - break; - } - } - - if (type.id() == LogicalTypeId::VARINT) { - switch (vtype) { - case VectorType::FLAT_VECTOR: { - auto &validity = FlatVector::Validity(*vector); - auto strings = FlatVector::GetData(*vector); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - Varint::Verify(strings[oidx]); - } - } - } break; - default: - break; - } - } - - if (type.id() == LogicalTypeId::BIT) { - switch (vtype) { - case VectorType::FLAT_VECTOR: { - auto &validity = FlatVector::Validity(*vector); - auto strings = FlatVector::GetData(*vector); - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - auto buf = strings[oidx].GetData(); - D_ASSERT(*buf >= 0 && *buf < 8); - Bit::Verify(strings[oidx]); - } - } - break; - } - default: - break; - } - } - - if (type.InternalType() == PhysicalType::ARRAY) { - // Arrays have the following invariants - // 1. if the array vector is a CONSTANT_VECTOR - // 1.1 The child vector is a FLAT_VECTOR with count = array_size - // 1.2 OR The child vector is a CONSTANT_VECTOR and must be NULL - // 1.3 OR The child vector is a CONSTANT_VECTOR and array_size = 1 - // 2. 
if the array vector is a FLAT_VECTOR, the child vector is a FLAT_VECTOR - // 2.2 the count of the child vector is array_size * (parent)count - - auto &child = ArrayVector::GetEntry(*vector); - auto array_size = ArrayType::GetSize(type); - - if (child.GetVectorType() == VectorType::CONSTANT_VECTOR) { - D_ASSERT(ConstantVector::IsNull(child)); - } else { - D_ASSERT(child.GetVectorType() == VectorType::FLAT_VECTOR); - } - - if (vtype == VectorType::CONSTANT_VECTOR) { - if (!ConstantVector::IsNull(*vector)) { - child.Verify(array_size); - } - } else if (vtype == VectorType::FLAT_VECTOR) { - // Flat vector case - auto &validity = FlatVector::Validity(*vector); - idx_t selected_child_count = 0; - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - selected_child_count += array_size; - } - } - - SelectionVector child_sel(selected_child_count); - idx_t child_count = 0; - for (idx_t i = 0; i < count; i++) { - auto oidx = sel->get_index(i); - if (validity.RowIsValid(oidx)) { - for (idx_t j = 0; j < array_size; j++) { - child_sel.set_index(child_count++, oidx * array_size + j); - } - } - } - Vector::Verify(child, child_sel, child_count); - } - } - - if (type.InternalType() == PhysicalType::STRUCT) { - auto &child_types = StructType::GetChildTypes(type); - D_ASSERT(!child_types.empty()); - - // create a selection vector of the non-null entries of the struct vector - auto &children = StructVector::GetEntries(*vector); - D_ASSERT(child_types.size() == children.size()); - for (idx_t child_idx = 0; child_idx < children.size(); child_idx++) { - D_ASSERT(children[child_idx]->GetType() == child_types[child_idx].second); - Vector::Verify(*children[child_idx], sel_p, count); - if (vtype == VectorType::CONSTANT_VECTOR) { - D_ASSERT(children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR); - if (ConstantVector::IsNull(*vector)) { - D_ASSERT(ConstantVector::IsNull(*children[child_idx])); - } - } - if (vtype != VectorType::FLAT_VECTOR) { - continue; - } - optional_ptr child_validity; - SelectionVector owned_child_sel; - const SelectionVector *child_sel = &owned_child_sel; - if (children[child_idx]->GetVectorType() == VectorType::FLAT_VECTOR) { - child_sel = FlatVector::IncrementalSelectionVector(); - child_validity = &FlatVector::Validity(*children[child_idx]); - } else if (children[child_idx]->GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(*children[child_idx]); - if (child.GetVectorType() != VectorType::FLAT_VECTOR) { - continue; - } - child_validity = &FlatVector::Validity(child); - child_sel = &DictionaryVector::SelVector(*children[child_idx]); - } else if (children[child_idx]->GetVectorType() == VectorType::CONSTANT_VECTOR) { - child_sel = ConstantVector::ZeroSelectionVector(count, owned_child_sel); - child_validity = &ConstantVector::Validity(*children[child_idx]); - } else { - continue; - } - // for any NULL entry in the struct, the child should be NULL as well - auto &validity = FlatVector::Validity(*vector); - for (idx_t i = 0; i < count; i++) { - auto index = sel->get_index(i); - if (!validity.RowIsValid(index)) { - auto child_index = child_sel->get_index(sel_p.get_index(i)); - D_ASSERT(!child_validity->RowIsValid(child_index)); - } - } - } - - if (vector->GetType().id() == LogicalTypeId::UNION) { - // Pass in raw vector - VerifyUnion(vector_p, sel_p, count); - } - } - - if (type.InternalType() == PhysicalType::LIST) { - if (vtype == VectorType::CONSTANT_VECTOR) { - if 
(!ConstantVector::IsNull(*vector)) { - auto &child = ListVector::GetEntry(*vector); - SelectionVector child_sel(ListVector::GetListSize(*vector)); - idx_t child_count = 0; - auto le = ConstantVector::GetData(*vector); - D_ASSERT(le->offset + le->length <= ListVector::GetListSize(*vector)); - for (idx_t k = 0; k < le->length; k++) { - child_sel.set_index(child_count++, le->offset + k); - } - Vector::Verify(child, child_sel, child_count); - } - } else if (vtype == VectorType::FLAT_VECTOR) { - auto &validity = FlatVector::Validity(*vector); - auto &child = ListVector::GetEntry(*vector); - auto child_size = ListVector::GetListSize(*vector); - auto list_data = FlatVector::GetData(*vector); - idx_t total_size = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel->get_index(i); - auto &le = list_data[idx]; - if (validity.RowIsValid(idx)) { - D_ASSERT(le.offset + le.length <= child_size); - total_size += le.length; - } - } - SelectionVector child_sel(total_size); - idx_t child_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel->get_index(i); - auto &le = list_data[idx]; - if (validity.RowIsValid(idx)) { - D_ASSERT(le.offset + le.length <= child_size); - for (idx_t k = 0; k < le.length; k++) { - child_sel.set_index(child_count++, le.offset + k); - } - } - } - Vector::Verify(child, child_sel, child_count); - } - - if (vector->GetType().id() == LogicalTypeId::MAP) { - VerifyMap(*vector, *sel, count); - } - } -#endif -} - -void Vector::Verify(idx_t count) { - auto flat_sel = FlatVector::IncrementalSelectionVector(); - Verify(*this, *flat_sel, count); -} - -void Vector::DebugTransformToDictionary(Vector &vector, idx_t count) { - if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { - // only supported for flat vectors currently - return; - } - // convert vector to dictionary vector - // first create an inverted vector of twice the size with NULL values every other value - // i.e. 
[1, 2, 3] is converted into [NULL, 3, NULL, 2, NULL, 1] - idx_t verify_count = count * 2; - SelectionVector inverted_sel(verify_count); - idx_t offset = 0; - for (idx_t i = 0; i < count; i++) { - idx_t current_index = count - i - 1; - inverted_sel.set_index(offset++, current_index); - inverted_sel.set_index(offset++, current_index); - } - Vector inverted_vector(vector, inverted_sel, verify_count); - inverted_vector.Flatten(verify_count); - // now insert the NULL values at every other position - for (idx_t i = 0; i < count; i++) { - FlatVector::SetNull(inverted_vector, i * 2, true); - } - // construct the selection vector pointing towards the original values - // we start at the back, (verify_count - 1) and move backwards - SelectionVector original_sel(count); - offset = 0; - for (idx_t i = 0; i < count; i++) { - original_sel.set_index(offset++, verify_count - 1 - i * 2); - } - // now slice the inverted vector with the inverted selection vector - vector.Slice(inverted_vector, original_sel, count); - vector.Verify(count); -} - -void Vector::DebugShuffleNestedVector(Vector &vector, idx_t count) { - switch (vector.GetType().id()) { - case LogicalTypeId::STRUCT: { - auto &entries = StructVector::GetEntries(vector); - // recurse into child elements - for (auto &entry : entries) { - Vector::DebugShuffleNestedVector(*entry, count); - } - break; - } - case LogicalTypeId::LIST: { - if (vector.GetVectorType() != VectorType::FLAT_VECTOR) { - break; - } - auto list_entries = FlatVector::GetData(vector); - idx_t child_count = 0; - for (idx_t r = 0; r < count; r++) { - if (FlatVector::IsNull(vector, r)) { - continue; - } - child_count += list_entries[r].length; - } - if (child_count == 0) { - break; - } - auto &child_vector = ListVector::GetEntry(vector); - // reverse the order of all lists - SelectionVector child_sel(child_count); - idx_t position = child_count; - for (idx_t r = 0; r < count; r++) { - if (FlatVector::IsNull(vector, r)) { - continue; - } - // move this list to the back - position -= list_entries[r].length; - for (idx_t k = 0; k < list_entries[r].length; k++) { - child_sel.set_index(position + k, list_entries[r].offset + k); - } - // adjust the offset to this new position - list_entries[r].offset = position; - } - child_vector.Slice(child_sel, child_count); - child_vector.Flatten(child_count); - ListVector::SetListSize(vector, child_count); - - // recurse into child elements - Vector::DebugShuffleNestedVector(child_vector, child_count); - break; - } - default: - break; - } -} - -//===--------------------------------------------------------------------===// -// FlatVector -//===--------------------------------------------------------------------===// -void FlatVector::SetNull(Vector &vector, idx_t idx, bool is_null) { - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); - vector.validity.Set(idx, !is_null); - if (!is_null) { - return; - } - - auto &type = vector.GetType(); - auto internal_type = type.InternalType(); - - // Set all child entries to NULL. - if (internal_type == PhysicalType::STRUCT) { - auto &entries = StructVector::GetEntries(vector); - for (auto &entry : entries) { - FlatVector::SetNull(*entry, idx, is_null); - } - return; - } - - // Set all child entries to NULL. 
- if (internal_type == PhysicalType::ARRAY) { - auto &child = ArrayVector::GetEntry(vector); - auto array_size = ArrayType::GetSize(type); - auto child_offset = idx * array_size; - for (idx_t i = 0; i < array_size; i++) { - FlatVector::SetNull(child, child_offset + i, is_null); - } - } -} - -//===--------------------------------------------------------------------===// -// ConstantVector -//===--------------------------------------------------------------------===// -void ConstantVector::SetNull(Vector &vector, bool is_null) { - D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - vector.validity.Set(0, !is_null); - if (is_null) { - auto &type = vector.GetType(); - auto internal_type = type.InternalType(); - if (internal_type == PhysicalType::STRUCT) { - // set all child entries to null as well - auto &entries = StructVector::GetEntries(vector); - for (auto &entry : entries) { - entry->SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(*entry, is_null); - } - } else if (internal_type == PhysicalType::ARRAY) { - auto &child = ArrayVector::GetEntry(vector); - D_ASSERT(child.GetVectorType() == VectorType::CONSTANT_VECTOR || - child.GetVectorType() == VectorType::FLAT_VECTOR); - auto array_size = ArrayType::GetSize(type); - if (child.GetVectorType() == VectorType::CONSTANT_VECTOR) { - D_ASSERT(array_size == 1); - ConstantVector::SetNull(child, is_null); - } else { - for (idx_t i = 0; i < array_size; i++) { - FlatVector::SetNull(child, i, is_null); - } - } - } - } -} - -const SelectionVector *ConstantVector::ZeroSelectionVector(idx_t count, SelectionVector &owned_sel) { - if (count <= STANDARD_VECTOR_SIZE) { - return ConstantVector::ZeroSelectionVector(); - } - owned_sel.Initialize(count); - for (idx_t i = 0; i < count; i++) { - owned_sel.set_index(i, 0); - } - return &owned_sel; -} - -void ConstantVector::Reference(Vector &vector, Vector &source, idx_t position, idx_t count) { - auto &source_type = source.GetType(); - switch (source_type.InternalType()) { - case PhysicalType::LIST: { - // retrieve the list entry from the source vector - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto list_index = vdata.sel->get_index(position); - if (!vdata.validity.RowIsValid(list_index)) { - // list is null: create null value - Value null_value(source_type); - vector.Reference(null_value); - break; - } - - auto list_data = UnifiedVectorFormat::GetData(vdata); - auto list_entry = list_data[list_index]; - - // add the list entry as the first element of "vector" - // FIXME: we only need to allocate space for 1 tuple here - auto target_data = FlatVector::GetData(vector); - target_data[0] = list_entry; - - // create a reference to the child list of the source vector - auto &child = ListVector::GetEntry(vector); - child.Reference(ListVector::GetEntry(source)); - - ListVector::SetListSize(vector, ListVector::GetListSize(source)); - vector.SetVectorType(VectorType::CONSTANT_VECTOR); - break; - } - case PhysicalType::ARRAY: { - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - auto source_idx = vdata.sel->get_index(position); - if (!vdata.validity.RowIsValid(source_idx)) { - // list is null: create null value - Value null_value(source_type); - vector.Reference(null_value); - break; - } - - // Reference the child vector - auto &target_child = ArrayVector::GetEntry(vector); - auto &source_child = ArrayVector::GetEntry(source); - target_child.Reference(source_child); - - // Only take the element at the given position - auto array_size = 
ArrayType::GetSize(source_type); - SelectionVector sel(array_size); - for (idx_t i = 0; i < array_size; i++) { - sel.set_index(i, array_size * source_idx + i); - } - target_child.Slice(sel, array_size); - target_child.Flatten(array_size); // since its constant we only have to flatten this much - - vector.SetVectorType(VectorType::CONSTANT_VECTOR); - vector.validity.Set(0, true); - break; - } - case PhysicalType::STRUCT: { - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - auto struct_index = vdata.sel->get_index(position); - if (!vdata.validity.RowIsValid(struct_index)) { - // null struct: create null value - Value null_value(source_type); - vector.Reference(null_value); - break; - } - - // struct: pass constant reference into child entries - auto &source_entries = StructVector::GetEntries(source); - auto &target_entries = StructVector::GetEntries(vector); - for (idx_t i = 0; i < source_entries.size(); i++) { - ConstantVector::Reference(*target_entries[i], *source_entries[i], position, count); - } - vector.SetVectorType(VectorType::CONSTANT_VECTOR); - vector.validity.Set(0, true); - break; - } - default: - // default behavior: get a value from the vector and reference it - // this is not that expensive for scalar types - auto value = source.GetValue(position); - vector.Reference(value); - D_ASSERT(vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - break; - } -} - -//===--------------------------------------------------------------------===// -// StringVector -//===--------------------------------------------------------------------===// -string_t StringVector::AddString(Vector &vector, const char *data, idx_t len) { - return StringVector::AddString(vector, string_t(data, UnsafeNumericCast(len))); -} - -string_t StringVector::AddStringOrBlob(Vector &vector, const char *data, idx_t len) { - return StringVector::AddStringOrBlob(vector, string_t(data, UnsafeNumericCast(len))); -} - -string_t StringVector::AddString(Vector &vector, const char *data) { - return StringVector::AddString(vector, string_t(data, UnsafeNumericCast(strlen(data)))); -} - -string_t StringVector::AddString(Vector &vector, const string &data) { - return StringVector::AddString(vector, string_t(data.c_str(), UnsafeNumericCast(data.size()))); -} - -string_t StringVector::AddString(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::VARCHAR || vector.GetType().id() == LogicalTypeId::BIT); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary.get()->Cast(); - return string_buffer.AddString(data); -} - -string_t StringVector::AddStringOrBlob(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary.get()->Cast(); - return string_buffer.AddBlob(data); -} - -string_t StringVector::EmptyString(Vector &vector, idx_t len) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (len <= string_t::INLINE_LENGTH) { - return string_t(UnsafeNumericCast(len)); - } - if 
(!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRING_BUFFER); - auto &string_buffer = vector.auxiliary.get()->Cast(); - return string_buffer.EmptyString(len); -} - -void StringVector::AddHandle(Vector &vector, BufferHandle handle) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - auto &string_buffer = vector.auxiliary->Cast(); - string_buffer.AddHeapReference(make_buffer(std::move(handle))); -} - -void StringVector::AddBuffer(Vector &vector, buffer_ptr buffer) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - D_ASSERT(buffer.get() != vector.auxiliary.get()); - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - auto &string_buffer = vector.auxiliary->Cast(); - string_buffer.AddHeapReference(std::move(buffer)); -} - -void StringVector::AddHeapReference(Vector &vector, Vector &other) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - D_ASSERT(other.GetType().InternalType() == PhysicalType::VARCHAR); - - if (other.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - StringVector::AddHeapReference(vector, DictionaryVector::Child(other)); - return; - } - if (!other.auxiliary) { - return; - } - StringVector::AddBuffer(vector, other.auxiliary); -} - -//===--------------------------------------------------------------------===// -// FSSTVector -//===--------------------------------------------------------------------===// -string_t FSSTVector::AddCompressedString(Vector &vector, const char *data, idx_t len) { - return FSSTVector::AddCompressedString(vector, string_t(data, UnsafeNumericCast(len))); -} - -string_t FSSTVector::AddCompressedString(Vector &vector, string_t data) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (data.IsInlined()) { - // string will be inlined: no need to store in string heap - return data; - } - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary.get()->Cast(); - return fsst_string_buffer.AddBlob(data); -} - -void *FSSTVector::GetDecoder(const Vector &vector) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (!vector.auxiliary) { - throw InternalException("GetDecoder called on FSST Vector without registered buffer"); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.GetDecoder(); -} - -vector &FSSTVector::GetDecompressBuffer(const Vector &vector) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - if (!vector.auxiliary) { - throw InternalException("GetDecompressBuffer called on FSST Vector without registered buffer"); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.GetDecompressBuffer(); -} - -void FSSTVector::RegisterDecoder(Vector &vector, buffer_ptr &duckdb_fsst_decoder, - const idx_t string_block_limit) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - 
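// ---------------------------------------------------------------------------
// The StringVector helpers above all follow one pattern: a string at or below
// the inline threshold is returned as-is inside the string_t, while anything
// longer is copied into a heap that the vector keeps alive through its
// auxiliary buffer, so the returned pointer stays valid. A minimal standalone
// sketch of that pattern; SmallString, StringHeap, and the 12-byte threshold
// are illustrative assumptions, not DuckDB's exact layout or API.
#include <cstddef>
#include <cstring>
#include <memory>
#include <string>
#include <vector>

static constexpr size_t kInlineLength = 12; // assumed inline capacity

struct SmallString {
	size_t len = 0;
	char inlined[kInlineLength] = {}; // used when len <= kInlineLength
	const char *ptr = nullptr;        // otherwise points into the heap below
};

struct StringHeap {
	std::vector<std::unique_ptr<char[]>> blocks; // keeps long strings alive

	SmallString Add(const std::string &s) {
		SmallString result;
		result.len = s.size();
		if (s.size() <= kInlineLength) {
			// inlined: no allocation, cheap to copy around
			std::memcpy(result.inlined, s.data(), s.size());
		} else {
			// heap-backed: the owning heap must outlive the SmallString
			blocks.push_back(std::make_unique<char[]>(s.size()));
			std::memcpy(blocks.back().get(), s.data(), s.size());
			result.ptr = blocks.back().get();
		}
		return result;
	}
};
// ---------------------------------------------------------------------------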
fsst_string_buffer.AddDecoder(duckdb_fsst_decoder, string_block_limit); -} - -void FSSTVector::SetCount(Vector &vector, idx_t count) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - fsst_string_buffer.SetCount(count); -} - -idx_t FSSTVector::GetCount(Vector &vector) { - D_ASSERT(vector.GetType().InternalType() == PhysicalType::VARCHAR); - - if (!vector.auxiliary) { - vector.auxiliary = make_buffer(); - } - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::FSST_BUFFER); - - auto &fsst_string_buffer = vector.auxiliary->Cast(); - return fsst_string_buffer.GetCount(); -} - -void FSSTVector::DecompressVector(const Vector &src, Vector &dst, idx_t src_offset, idx_t dst_offset, idx_t copy_count, - const SelectionVector *sel) { - D_ASSERT(src.GetVectorType() == VectorType::FSST_VECTOR); - D_ASSERT(dst.GetVectorType() == VectorType::FLAT_VECTOR); - auto dst_mask = FlatVector::Validity(dst); - auto ldata = FSSTVector::GetCompressedData(src); - auto tdata = FlatVector::GetData(dst); - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel->get_index(src_offset + i); - auto target_idx = dst_offset + i; - string_t compressed_string = ldata[source_idx]; - if (dst_mask.RowIsValid(target_idx) && compressed_string.GetSize() > 0) { - auto decoder = FSSTVector::GetDecoder(src); - auto &decompress_buffer = FSSTVector::GetDecompressBuffer(src); - tdata[target_idx] = FSSTPrimitives::DecompressValue(decoder, dst, compressed_string.GetData(), - compressed_string.GetSize(), decompress_buffer); - } else { - tdata[target_idx] = string_t(nullptr, 0); - } - } -} - -//===--------------------------------------------------------------------===// -// MapVector -//===--------------------------------------------------------------------===// -Vector &MapVector::GetKeys(Vector &vector) { - auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); - D_ASSERT(entries.size() == 2); - return *entries[0]; -} -Vector &MapVector::GetValues(Vector &vector) { - auto &entries = StructVector::GetEntries(ListVector::GetEntry(vector)); - D_ASSERT(entries.size() == 2); - return *entries[1]; -} - -const Vector &MapVector::GetKeys(const Vector &vector) { - return GetKeys((Vector &)vector); -} -const Vector &MapVector::GetValues(const Vector &vector) { - return GetValues((Vector &)vector); -} - -MapInvalidReason MapVector::CheckMapValidity(Vector &map, idx_t count, const SelectionVector &sel) { - - D_ASSERT(map.GetType().id() == LogicalTypeId::MAP); - - // unify the MAP vector, which is a physical LIST vector - UnifiedVectorFormat map_data; - map.ToUnifiedFormat(count, map_data); - auto map_entries = UnifiedVectorFormat::GetDataNoConst(map_data); - auto maps_length = ListVector::GetListSize(map); - - // unify the child vector containing the keys - auto &keys = MapVector::GetKeys(map); - UnifiedVectorFormat key_data; - keys.ToUnifiedFormat(maps_length, key_data); - - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - - auto mapped_row = sel.get_index(row_idx); - auto map_idx = map_data.sel->get_index(mapped_row); - - if (!map_data.validity.RowIsValid(map_idx)) { - continue; - } - - value_set_t unique_keys; - auto length = map_entries[map_idx].length; - auto offset = map_entries[map_idx].offset; - - for (idx_t child_idx = 0; child_idx < length; child_idx++) { - auto key_idx = 
key_data.sel->get_index(offset + child_idx); - - if (!key_data.validity.RowIsValid(key_idx)) { - return MapInvalidReason::NULL_KEY; - } - - auto value = keys.GetValue(key_idx); - auto unique = unique_keys.insert(value).second; - if (!unique) { - return MapInvalidReason::DUPLICATE_KEY; - } - } - } - - return MapInvalidReason::VALID; -} - -void MapVector::MapConversionVerify(Vector &vector, idx_t count) { - auto reason = MapVector::CheckMapValidity(vector, count); - EvalMapInvalidReason(reason); -} - -void MapVector::EvalMapInvalidReason(MapInvalidReason reason) { - switch (reason) { - case MapInvalidReason::VALID: - return; - case MapInvalidReason::DUPLICATE_KEY: - throw InvalidInputException("Map keys must be unique."); - case MapInvalidReason::NULL_KEY: - throw InvalidInputException("Map keys can not be NULL."); - case MapInvalidReason::NOT_ALIGNED: - throw InvalidInputException("The map key list does not align with the map value list."); - case MapInvalidReason::INVALID_PARAMS: - throw InvalidInputException("Invalid map argument(s). Valid map arguments are a list of key-value pairs (MAP " - "{'key1': 'val1', ...}), two lists (MAP ([1, 2], [10, 11])), or no arguments."); - default: - throw InternalException("MapInvalidReason not implemented"); - } -} - -//===--------------------------------------------------------------------===// -// StructVector -//===--------------------------------------------------------------------===// -vector> &StructVector::GetEntries(Vector &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::STRUCT || vector.GetType().id() == LogicalTypeId::UNION); - - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return StructVector::GetEntries(child); - } - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::STRUCT_BUFFER); - return vector.auxiliary->Cast().GetChildren(); -} - -const vector> &StructVector::GetEntries(const Vector &vector) { - return GetEntries((Vector &)vector); -} - -//===--------------------------------------------------------------------===// -// ListVector -//===--------------------------------------------------------------------===// -template -T &ListVector::GetEntryInternal(T &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == LogicalTypeId::MAP); - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return ListVector::GetEntry(child); - } - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); - return vector.auxiliary->template Cast().GetChild(); -} - -const Vector &ListVector::GetEntry(const Vector &vector) { - return GetEntryInternal(vector); -} - -Vector &ListVector::GetEntry(Vector &vector) { - return GetEntryInternal(vector); -} - -void ListVector::Reserve(Vector &vector, idx_t required_capacity) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST || vector.GetType().id() == LogicalTypeId::MAP); - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::LIST_BUFFER); - 
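// ---------------------------------------------------------------------------
// CheckMapValidity above enforces exactly two rules per map entry: no key may
// be NULL, and no two keys may compare equal. The same check over a plain
// container, with std::nullopt standing in for NULL -- a standalone sketch of
// the rule, not the DuckDB API:
#include <optional>
#include <set>
#include <string>
#include <vector>

enum class MapKeyCheck { VALID, NULL_KEY, DUPLICATE_KEY };

MapKeyCheck CheckKeys(const std::vector<std::optional<std::string>> &keys) {
	std::set<std::string> unique_keys;
	for (auto &key : keys) {
		if (!key.has_value()) {
			return MapKeyCheck::NULL_KEY; // NULL keys are rejected outright
		}
		if (!unique_keys.insert(*key).second) {
			return MapKeyCheck::DUPLICATE_KEY; // second insert of an equal key fails
		}
	}
	return MapKeyCheck::VALID;
}
// ---------------------------------------------------------------------------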
auto &child_buffer = vector.auxiliary->Cast(); - child_buffer.Reserve(required_capacity); -} - -idx_t ListVector::GetListSize(const Vector &vec) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - return ListVector::GetListSize(child); - } - D_ASSERT(vec.auxiliary); - return vec.auxiliary->Cast().GetSize(); -} - -idx_t ListVector::GetListCapacity(const Vector &vec) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - return ListVector::GetListCapacity(child); - } - D_ASSERT(vec.auxiliary); - return vec.auxiliary->Cast().GetCapacity(); -} - -void ListVector::ReferenceEntry(Vector &vector, Vector &other) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::LIST); - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(other.GetType().id() == LogicalTypeId::LIST); - D_ASSERT(other.GetVectorType() == VectorType::FLAT_VECTOR || other.GetVectorType() == VectorType::CONSTANT_VECTOR); - vector.auxiliary = other.auxiliary; -} - -void ListVector::SetListSize(Vector &vec, idx_t size) { - if (vec.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vec); - ListVector::SetListSize(child, size); - return; - } - vec.auxiliary->Cast().SetSize(size); -} - -void ListVector::Append(Vector &target, const Vector &source, idx_t source_size, idx_t source_offset) { - if (source_size - source_offset == 0) { - //! Nothing to add - return; - } - auto &target_buffer = target.auxiliary->Cast(); - target_buffer.Append(source, source_size, source_offset); -} - -void ListVector::Append(Vector &target, const Vector &source, const SelectionVector &sel, idx_t source_size, - idx_t source_offset) { - if (source_size - source_offset == 0) { - //! 
Nothing to add - return; - } - auto &target_buffer = target.auxiliary->Cast(); - target_buffer.Append(source, sel, source_size, source_offset); -} - -void ListVector::PushBack(Vector &target, const Value &insert) { - auto &target_buffer = target.auxiliary.get()->Cast(); - target_buffer.PushBack(insert); -} - -idx_t ListVector::GetConsecutiveChildList(Vector &list, Vector &result, idx_t offset, idx_t count) { - - auto info = ListVector::GetConsecutiveChildListInfo(list, offset, count); - if (info.needs_slicing) { - SelectionVector sel(info.child_list_info.length); - ListVector::GetConsecutiveChildSelVector(list, sel, offset, count); - - result.Slice(sel, info.child_list_info.length); - result.Flatten(info.child_list_info.length); - } - return info.child_list_info.length; -} - -ConsecutiveChildListInfo ListVector::GetConsecutiveChildListInfo(Vector &list, idx_t offset, idx_t count) { - - ConsecutiveChildListInfo info; - UnifiedVectorFormat unified_list_data; - list.ToUnifiedFormat(offset + count, unified_list_data); - auto list_data = UnifiedVectorFormat::GetData(unified_list_data); - - // find the first non-NULL entry - idx_t first_length = 0; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - continue; - } - info.child_list_info.offset = list_data[idx].offset; - first_length = list_data[idx].length; - break; - } - - // small performance improvement for constant vectors - // avoids iterating over all their (constant) elements - if (list.GetVectorType() == VectorType::CONSTANT_VECTOR) { - info.child_list_info.length = first_length; - return info; - } - - // now get the child count and determine whether the children are stored consecutively - // also determine if a flat vector has pseudo constant values (all offsets + length the same) - // this can happen e.g. 
for UNNESTs - bool is_consecutive = true; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - continue; - } - if (list_data[idx].offset != info.child_list_info.offset || list_data[idx].length != first_length) { - info.is_constant = false; - } - if (list_data[idx].offset != info.child_list_info.offset + info.child_list_info.length) { - is_consecutive = false; - } - info.child_list_info.length += list_data[idx].length; - } - - if (info.is_constant) { - info.child_list_info.length = first_length; - } - if (!info.is_constant && !is_consecutive) { - info.needs_slicing = true; - } - - return info; -} - -void ListVector::GetConsecutiveChildSelVector(Vector &list, SelectionVector &sel, idx_t offset, idx_t count) { - UnifiedVectorFormat unified_list_data; - list.ToUnifiedFormat(offset + count, unified_list_data); - auto list_data = UnifiedVectorFormat::GetData(unified_list_data); - - // SelectionVector child_sel(info.second.length); - idx_t entry = 0; - for (idx_t i = offset; i < offset + count; i++) { - auto idx = unified_list_data.sel->get_index(i); - if (!unified_list_data.validity.RowIsValid(idx)) { - continue; - } - for (idx_t k = 0; k < list_data[idx].length; k++) { - // child_sel.set_index(entry++, list_data[idx].offset + k); - sel.set_index(entry++, list_data[idx].offset + k); - } - } - // - // result.Slice(child_sel, info.second.length); - // result.Flatten(info.second.length); - // info.second.offset = 0; -} - -//===--------------------------------------------------------------------===// -// UnionVector -//===--------------------------------------------------------------------===// -const Vector &UnionVector::GetMember(const Vector &vector, idx_t member_index) { - D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); - auto &entries = StructVector::GetEntries(vector); - return *entries[member_index + 1]; // skip the "tag" entry -} - -Vector &UnionVector::GetMember(Vector &vector, idx_t member_index) { - D_ASSERT(member_index < UnionType::GetMemberCount(vector.GetType())); - auto &entries = StructVector::GetEntries(vector); - return *entries[member_index + 1]; // skip the "tag" entry -} - -const Vector &UnionVector::GetTags(const Vector &vector) { - // the tag vector is always the first struct child. - return *StructVector::GetEntries(vector)[0]; -} - -Vector &UnionVector::GetTags(Vector &vector) { - // the tag vector is always the first struct child. 
- return *StructVector::GetEntries(vector)[0]; -} - -void UnionVector::SetToMember(Vector &union_vector, union_tag_t tag, Vector &member_vector, idx_t count, - bool keep_tags_for_null) { - D_ASSERT(union_vector.GetType().id() == LogicalTypeId::UNION); - D_ASSERT(tag < UnionType::GetMemberCount(union_vector.GetType())); - - // Set the union member to the specified vector - UnionVector::GetMember(union_vector, tag).Reference(member_vector); - auto &tag_vector = UnionVector::GetTags(union_vector); - - if (member_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // if the member vector is constant, we can set the union to constant as well - union_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(tag_vector)[0] = tag; - if (keep_tags_for_null) { - ConstantVector::SetNull(union_vector, false); - ConstantVector::SetNull(tag_vector, false); - } else { - ConstantVector::SetNull(union_vector, ConstantVector::IsNull(member_vector)); - ConstantVector::SetNull(tag_vector, ConstantVector::IsNull(member_vector)); - } - - } else { - // otherwise flatten and set to flatvector - member_vector.Flatten(count); - union_vector.SetVectorType(VectorType::FLAT_VECTOR); - - if (member_vector.validity.AllValid()) { - // if the member vector is all valid, we can set the tag to constant - tag_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - auto tag_data = ConstantVector::GetData(tag_vector); - *tag_data = tag; - } else { - tag_vector.SetVectorType(VectorType::FLAT_VECTOR); - if (keep_tags_for_null) { - FlatVector::Validity(tag_vector).SetAllValid(count); - FlatVector::Validity(union_vector).SetAllValid(count); - } else { - // ensure the tags have the same validity as the member - FlatVector::Validity(union_vector) = FlatVector::Validity(member_vector); - FlatVector::Validity(tag_vector) = FlatVector::Validity(member_vector); - } - - auto tag_data = FlatVector::GetData(tag_vector); - memset(tag_data, tag, count); - } - } - - // Set the non-selected members to constant null vectors - for (idx_t i = 0; i < UnionType::GetMemberCount(union_vector.GetType()); i++) { - if (i != tag) { - auto &member = UnionVector::GetMember(union_vector, i); - member.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(member, true); - } - } -} - -bool UnionVector::TryGetTag(const Vector &vector, idx_t index, union_tag_t &result) { - // the tag vector is always the first struct child. - auto &tag_vector = *StructVector::GetEntries(vector)[0]; - if (tag_vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(tag_vector); - auto &dict_sel = DictionaryVector::SelVector(tag_vector); - auto mapped_idx = dict_sel.get_index(index); - if (FlatVector::IsNull(child, mapped_idx)) { - return false; - } else { - result = FlatVector::GetData(child)[mapped_idx]; - return true; - } - } - if (tag_vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(tag_vector)) { - return false; - } else { - result = ConstantVector::GetData(tag_vector)[0]; - return true; - } - } - if (FlatVector::IsNull(tag_vector, index)) { - return false; - } else { - result = FlatVector::GetData(tag_vector)[index]; - return true; - } -} - -//! 
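// ---------------------------------------------------------------------------
// SetToMember and TryGetTag above, together with CheckUnionValidity below,
// maintain one invariant per union row: a tag selects exactly one member, at
// most that member may be non-NULL, and the tag must stay in range. A
// standalone sketch of the invariant for a two-member union, with
// std::nullopt standing in for an invalid (NULL) member slot -- the names and
// fixed member set are illustrative, not DuckDB's representation:
#include <cstdint>
#include <optional>
#include <string>

struct UnionRow {
	uint8_t tag = 0;
	std::optional<int32_t> member0;    // valid only when tag == 0
	std::optional<std::string> member1; // valid only when tag == 1
};

bool IsValidUnionRow(const UnionRow &row) {
	int valid_members = row.member0.has_value() + row.member1.has_value();
	if (valid_members > 1) {
		return false; // VALIDITY_OVERLAP: two members claim the row
	}
	if (row.tag > 1) {
		return false; // TAG_OUT_OF_RANGE
	}
	if (row.member0.has_value() && row.tag != 0) {
		return false; // TAG_MISMATCH
	}
	if (row.member1.has_value() && row.tag != 1) {
		return false; // TAG_MISMATCH
	}
	return true;
}
// ---------------------------------------------------------------------------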
Raw selection vector passed in (not merged with any other selection vectors) -UnionInvalidReason UnionVector::CheckUnionValidity(Vector &vector_p, idx_t count, const SelectionVector &sel_p) { - D_ASSERT(vector_p.GetType().id() == LogicalTypeId::UNION); - - // Will contain the (possibly) merged selection vector - const SelectionVector *sel = &sel_p; - SelectionVector owned_sel; - Vector *vector = &vector_p; - if (vector->GetVectorType() == VectorType::DICTIONARY_VECTOR) { - // In the case of a dictionary vector, unwrap the Vector, and merge the selection vectors. - auto &child = DictionaryVector::Child(*vector); - D_ASSERT(child.GetVectorType() != VectorType::DICTIONARY_VECTOR); - auto &dict_sel = DictionaryVector::SelVector(*vector); - // merge the selection vectors and verify the child - auto new_buffer = dict_sel.Slice(*sel, count); - owned_sel.Initialize(new_buffer); - sel = &owned_sel; - vector = &child; - } else if (vector->GetVectorType() == VectorType::CONSTANT_VECTOR) { - sel = ConstantVector::ZeroSelectionVector(count, owned_sel); - } - - auto member_count = UnionType::GetMemberCount(vector_p.GetType()); - if (member_count == 0) { - return UnionInvalidReason::NO_MEMBERS; - } - - UnifiedVectorFormat vector_vdata; - vector_p.ToUnifiedFormat(count, vector_vdata); - - auto &entries = StructVector::GetEntries(vector_p); - duckdb::vector child_vdata(entries.size()); - for (idx_t entry_idx = 0; entry_idx < entries.size(); entry_idx++) { - auto &child = *entries[entry_idx]; - child.ToUnifiedFormat(count, child_vdata[entry_idx]); - } - - auto &tag_vdata = child_vdata[0]; - - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto mapped_idx = sel->get_index(row_idx); - - if (!vector_vdata.validity.RowIsValid(mapped_idx)) { - continue; - } - - auto tag_idx = tag_vdata.sel->get_index(sel_p.get_index(row_idx)); - if (!tag_vdata.validity.RowIsValid(tag_idx)) { - // we can't have NULL tags! 
- return UnionInvalidReason::NULL_TAG; - } - auto tag = UnifiedVectorFormat::GetData(tag_vdata)[tag_idx]; - if (tag >= member_count) { - return UnionInvalidReason::TAG_OUT_OF_RANGE; - } - - bool found_valid = false; - for (idx_t i = 0; i < member_count; i++) { - auto &member_vdata = child_vdata[1 + i]; // skip the tag - idx_t member_idx = member_vdata.sel->get_index(sel_p.get_index(row_idx)); - if (!member_vdata.validity.RowIsValid(member_idx)) { - continue; - } - if (found_valid) { - return UnionInvalidReason::VALIDITY_OVERLAP; - } - found_valid = true; - if (tag != static_cast(i)) { - return UnionInvalidReason::TAG_MISMATCH; - } - } - } - - return UnionInvalidReason::VALID; -} - -//===--------------------------------------------------------------------===// -// ArrayVector -//===--------------------------------------------------------------------===// -template -T &ArrayVector::GetEntryInternal(T &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::ARRAY); - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return ArrayVector::GetEntry(child); - } - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR || - vector.GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(vector.auxiliary); - D_ASSERT(vector.auxiliary->GetBufferType() == VectorBufferType::ARRAY_BUFFER); - return vector.auxiliary->template Cast().GetChild(); -} - -const Vector &ArrayVector::GetEntry(const Vector &vector) { - return GetEntryInternal(vector); -} - -Vector &ArrayVector::GetEntry(Vector &vector) { - return GetEntryInternal(vector); -} - -idx_t ArrayVector::GetTotalSize(const Vector &vector) { - D_ASSERT(vector.GetType().id() == LogicalTypeId::ARRAY); - D_ASSERT(vector.auxiliary); - if (vector.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - auto &child = DictionaryVector::Child(vector); - return ArrayVector::GetTotalSize(child); - } - return vector.auxiliary->Cast().GetChildSize(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_buffer.cpp b/src/duckdb/src/common/types/vector_buffer.cpp deleted file mode 100644 index aed61d192..000000000 --- a/src/duckdb/src/common/types/vector_buffer.cpp +++ /dev/null @@ -1,150 +0,0 @@ -#include "duckdb/common/types/vector_buffer.hpp" - -#include "duckdb/common/assert.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/storage/buffer/buffer_handle.hpp" - -namespace duckdb { - -buffer_ptr VectorBuffer::CreateStandardVector(PhysicalType type, idx_t capacity) { - return make_buffer(capacity * GetTypeIdSize(type)); -} - -buffer_ptr VectorBuffer::CreateConstantVector(PhysicalType type) { - return make_buffer(GetTypeIdSize(type)); -} - -buffer_ptr VectorBuffer::CreateConstantVector(const LogicalType &type) { - return VectorBuffer::CreateConstantVector(type.InternalType()); -} - -buffer_ptr VectorBuffer::CreateStandardVector(const LogicalType &type, idx_t capacity) { - return VectorBuffer::CreateStandardVector(type.InternalType(), capacity); -} - -VectorStringBuffer::VectorStringBuffer() : VectorBuffer(VectorBufferType::STRING_BUFFER) { -} - -VectorStringBuffer::VectorStringBuffer(VectorBufferType type) : VectorBuffer(type) { -} - -VectorFSSTStringBuffer::VectorFSSTStringBuffer() : VectorStringBuffer(VectorBufferType::FSST_BUFFER) { -} - -VectorStructBuffer::VectorStructBuffer() : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { -} - -VectorStructBuffer::VectorStructBuffer(const LogicalType 
&type, idx_t capacity) - : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { - auto &child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - auto vector = make_uniq(child_type.second, capacity); - children.push_back(std::move(vector)); - } -} - -VectorStructBuffer::VectorStructBuffer(Vector &other, const SelectionVector &sel, idx_t count) - : VectorBuffer(VectorBufferType::STRUCT_BUFFER) { - auto &other_vector = StructVector::GetEntries(other); - for (auto &child_vector : other_vector) { - auto vector = make_uniq(*child_vector, sel, count); - children.push_back(std::move(vector)); - } -} - -VectorStructBuffer::~VectorStructBuffer() { -} - -VectorListBuffer::VectorListBuffer(unique_ptr vector, idx_t initial_capacity) - : VectorBuffer(VectorBufferType::LIST_BUFFER), child(std::move(vector)), capacity(initial_capacity) { -} - -VectorListBuffer::VectorListBuffer(const LogicalType &list_type, idx_t initial_capacity) - : VectorBuffer(VectorBufferType::LIST_BUFFER), - child(make_uniq(ListType::GetChildType(list_type), initial_capacity)), capacity(initial_capacity) { -} - -void VectorListBuffer::Reserve(idx_t to_reserve) { - if (to_reserve > capacity) { - if (to_reserve > DConstants::MAX_VECTOR_SIZE) { - // overflow: throw an exception - throw OutOfRangeException("Cannot resize vector to %d rows: maximum allowed vector size is %s", to_reserve, - StringUtil::BytesToHumanReadableString(DConstants::MAX_VECTOR_SIZE)); - } - idx_t new_capacity = NextPowerOfTwo(to_reserve); - D_ASSERT(new_capacity >= to_reserve); - child->Resize(capacity, new_capacity); - capacity = new_capacity; - } -} - -void VectorListBuffer::Append(const Vector &to_append, idx_t to_append_size, idx_t source_offset) { - Reserve(size + to_append_size - source_offset); - VectorOperations::Copy(to_append, *child, to_append_size, source_offset, size); - size += to_append_size - source_offset; -} - -void VectorListBuffer::Append(const Vector &to_append, const SelectionVector &sel, idx_t to_append_size, - idx_t source_offset) { - Reserve(size + to_append_size - source_offset); - VectorOperations::Copy(to_append, *child, sel, to_append_size, source_offset, size); - size += to_append_size - source_offset; -} - -void VectorListBuffer::PushBack(const Value &insert) { - while (size + 1 > capacity) { - child->Resize(capacity, capacity * 2); - capacity *= 2; - } - child->SetValue(size++, insert); -} - -void VectorListBuffer::SetCapacity(idx_t new_capacity) { - this->capacity = new_capacity; -} - -void VectorListBuffer::SetSize(idx_t new_size) { - this->size = new_size; -} - -VectorListBuffer::~VectorListBuffer() { -} - -VectorArrayBuffer::VectorArrayBuffer(unique_ptr child_vector, idx_t array_size, idx_t initial_capacity) - : VectorBuffer(VectorBufferType::ARRAY_BUFFER), child(std::move(child_vector)), array_size(array_size), - size(initial_capacity) { - D_ASSERT(array_size != 0); -} - -VectorArrayBuffer::VectorArrayBuffer(const LogicalType &array, idx_t initial) - : VectorBuffer(VectorBufferType::ARRAY_BUFFER), - child(make_uniq(ArrayType::GetChildType(array), initial * ArrayType::GetSize(array))), - array_size(ArrayType::GetSize(array)), size(initial) { - // initialize the child array with (array_size * size) ^ - D_ASSERT(!ArrayType::IsAnySize(array)); -} - -VectorArrayBuffer::~VectorArrayBuffer() { -} - -Vector &VectorArrayBuffer::GetChild() { - return *child; -} - -idx_t VectorArrayBuffer::GetArraySize() { - return array_size; -} - -idx_t VectorArrayBuffer::GetChildSize() { - return size * array_size; -} - 
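// ---------------------------------------------------------------------------
// VectorListBuffer::Reserve above grows by rounding the requested capacity up
// to the next power of two (with a hard upper bound), and PushBack applies the
// same idea by doubling; both amortize repeated appends to O(1) per element.
// A standalone sketch of that growth policy, where MAX_CAPACITY merely stands
// in for the real limit (DConstants::MAX_VECTOR_SIZE):
#include <cstddef>
#include <stdexcept>
#include <vector>

static constexpr size_t MAX_CAPACITY = size_t(1) << 40; // assumed bound

size_t NextPowerOfTwo(size_t v) {
	size_t result = 1;
	while (result < v) {
		result <<= 1;
	}
	return result;
}

struct GrowableBuffer {
	std::vector<int> data;
	size_t capacity = 0;

	void Reserve(size_t to_reserve) {
		if (to_reserve <= capacity) {
			return; // already large enough
		}
		if (to_reserve > MAX_CAPACITY) {
			throw std::out_of_range("cannot resize past maximum capacity");
		}
		// round up so repeated appends trigger only O(log n) resizes
		capacity = NextPowerOfTwo(to_reserve);
		data.resize(capacity);
	}
};
// ---------------------------------------------------------------------------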
-ManagedVectorBuffer::ManagedVectorBuffer(BufferHandle handle) - : VectorBuffer(VectorBufferType::MANAGED_BUFFER), handle(std::move(handle)) { -} - -ManagedVectorBuffer::~ManagedVectorBuffer() { -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_cache.cpp b/src/duckdb/src/common/types/vector_cache.cpp deleted file mode 100644 index 49ffe3579..000000000 --- a/src/duckdb/src/common/types/vector_cache.cpp +++ /dev/null @@ -1,142 +0,0 @@ -#include "duckdb/common/types/vector_cache.hpp" - -#include "duckdb/common/allocator.hpp" -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -class VectorCacheBuffer : public VectorBuffer { -public: - explicit VectorCacheBuffer(Allocator &allocator, const LogicalType &type_p, idx_t capacity_p = STANDARD_VECTOR_SIZE) - : VectorBuffer(VectorBufferType::OPAQUE_BUFFER), type(type_p), capacity(capacity_p) { - auto internal_type = type.InternalType(); - switch (internal_type) { - case PhysicalType::LIST: { - // memory for the list offsets - owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); - // child data of the list - auto &child_type = ListType::GetChildType(type); - child_caches.push_back(make_buffer(allocator, child_type, capacity)); - auto child_vector = make_uniq(child_type, false, false); - auxiliary = make_shared_ptr(std::move(child_vector)); - break; - } - case PhysicalType::ARRAY: { - auto &child_type = ArrayType::GetChildType(type); - auto array_size = ArrayType::GetSize(type); - child_caches.push_back(make_buffer(allocator, child_type, array_size * capacity)); - auto child_vector = make_uniq(child_type, true, false, array_size * capacity); - auxiliary = make_shared_ptr(std::move(child_vector), array_size, capacity); - break; - } - case PhysicalType::STRUCT: { - auto &child_types = StructType::GetChildTypes(type); - for (auto &child_type : child_types) { - child_caches.push_back(make_buffer(allocator, child_type.second, capacity)); - } - auto struct_buffer = make_shared_ptr(type); - auxiliary = std::move(struct_buffer); - break; - } - default: - owned_data = allocator.Allocate(capacity * GetTypeIdSize(internal_type)); - break; - } - } - - void ResetFromCache(Vector &result, const buffer_ptr &buffer) { - D_ASSERT(type == result.GetType()); - auto internal_type = type.InternalType(); - result.vector_type = VectorType::FLAT_VECTOR; - AssignSharedPointer(result.buffer, buffer); - result.validity.Reset(capacity); - switch (internal_type) { - case PhysicalType::LIST: { - result.data = owned_data.get(); - // reinitialize the VectorListBuffer - AssignSharedPointer(result.auxiliary, auxiliary); - // propagate through child - auto &child_cache = child_caches[0]->Cast(); - auto &list_buffer = result.auxiliary->Cast(); - list_buffer.SetCapacity(child_cache.capacity); - list_buffer.SetSize(0); - list_buffer.SetAuxiliaryData(nullptr); - - auto &list_child = list_buffer.GetChild(); - child_cache.ResetFromCache(list_child, child_caches[0]); - break; - } - case PhysicalType::ARRAY: { - // fixed size list does not have own data - result.data = nullptr; - // reinitialize the VectorArrayBuffer - // auxiliary->SetAuxiliaryData(nullptr); - AssignSharedPointer(result.auxiliary, auxiliary); - - // propagate through child - auto &child_cache = child_caches[0]->Cast(); - auto &array_child = result.auxiliary->Cast().GetChild(); - child_cache.ResetFromCache(array_child, child_caches[0]); - break; - } - case PhysicalType::STRUCT: { - // struct does not have data - result.data = nullptr; - // reinitialize the 
VectorStructBuffer - auxiliary->SetAuxiliaryData(nullptr); - AssignSharedPointer(result.auxiliary, auxiliary); - // propagate through children - auto &children = result.auxiliary->Cast().GetChildren(); - for (idx_t i = 0; i < children.size(); i++) { - auto &child_cache = child_caches[i]->Cast(); - child_cache.ResetFromCache(*children[i], child_caches[i]); - } - break; - } - default: - // regular type: no aux data and reset data to cached data - result.data = owned_data.get(); - result.auxiliary.reset(); - break; - } - } - - const LogicalType &GetType() { - return type; - } - -private: - //! The type of the vector cache - LogicalType type; - //! Owned data - AllocatedData owned_data; - //! Child caches (if any). Used for nested types. - vector> child_caches; - //! Aux data for the vector (if any) - buffer_ptr auxiliary; - //! Capacity of the vector - idx_t capacity; -}; - -VectorCache::VectorCache() : buffer(nullptr) { -} - -VectorCache::VectorCache(Allocator &allocator, const LogicalType &type_p, const idx_t capacity_p) { - buffer = make_buffer(allocator, type_p, capacity_p); -} - -void VectorCache::ResetFromCache(Vector &result) const { - if (!buffer) { - return; - } - auto &vector_cache = buffer->Cast(); - vector_cache.ResetFromCache(result, buffer); -} - -const LogicalType &VectorCache::GetType() const { - D_ASSERT(buffer); - auto &vector_cache = buffer->Cast(); - return vector_cache.GetType(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/types/vector_constants.cpp b/src/duckdb/src/common/types/vector_constants.cpp deleted file mode 100644 index 78c7fb5b2..000000000 --- a/src/duckdb/src/common/types/vector_constants.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/common/types/vector.hpp" - -namespace duckdb { - -const SelectionVector *ConstantVector::ZeroSelectionVector() { - static const SelectionVector ZERO_SELECTION_VECTOR = - SelectionVector(const_cast(ConstantVector::ZERO_VECTOR)); // NOLINT - return &ZERO_SELECTION_VECTOR; -} - -const SelectionVector *FlatVector::IncrementalSelectionVector() { - static const SelectionVector INCREMENTAL_SELECTION_VECTOR; - return &INCREMENTAL_SELECTION_VECTOR; -} - -const sel_t ConstantVector::ZERO_VECTOR[STANDARD_VECTOR_SIZE] = {0}; - -} // namespace duckdb diff --git a/src/duckdb/src/common/value_operations/comparison_operations.cpp b/src/duckdb/src/common/value_operations/comparison_operations.cpp deleted file mode 100644 index b2d59b0b5..000000000 --- a/src/duckdb/src/common/value_operations/comparison_operations.cpp +++ /dev/null @@ -1,265 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Comparison Operations -//===--------------------------------------------------------------------===// - -struct ValuePositionComparator { - // Return true if the positional Values definitely match. - // Default to the same as the final value - template - static inline bool Definite(const Value &lhs, const Value &rhs) { - return Final(lhs, rhs); - } - - // Select the positional Values that need further testing. 
- // Usually this means Is Not Distinct, as those are the semantics used by Postgres - template - static inline bool Possible(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); - } - - // Return true if the positional Values definitely match in the final position - // This needs to be specialised. - template - static inline bool Final(const Value &lhs, const Value &rhs) { - return false; - } - - // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. - // This essentially means that the existing positions compare equal. - // Default to the same semantics as the OP for idx_t. This works in most cases. - template - static inline bool TieBreak(const idx_t lpos, const idx_t rpos) { - return OP::Operation(lpos, rpos); - } -}; - -// Equals must always check every column -template <> -inline bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return false; -} - -template <> -inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); -} - -// NotEquals must check everything that matched -template <> -inline bool ValuePositionComparator::Possible(const Value &lhs, const Value &rhs) { - return true; -} - -template <> -inline bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::NotDistinctFrom(lhs, rhs); -} - -// Non-strict inequalities must use strict comparisons for Definite -template <> -bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValueOperations::DistinctGreaterThan(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(lhs, rhs); -} - -template <> -bool ValuePositionComparator::Definite(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Definite(rhs, lhs); -} - -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return !ValuePositionComparator::Final(rhs, lhs); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -bool ValuePositionComparator::Final(const Value &lhs, const Value &rhs) { - return ValuePositionComparator::Final(rhs, lhs); -} - -template -static bool TemplatedBooleanOperation(const Value &left, const Value &right) { - const auto &left_type = left.type(); - const auto &right_type = right.type(); - if (left_type != right_type) { - Value left_copy = left; - Value right_copy = right; - - auto comparison_type = LogicalType::ForceMaxLogicalType(left_type, right_type); - if (!left_copy.DefaultTryCastAs(comparison_type) || !right_copy.DefaultTryCastAs(comparison_type)) { - return false; - } - D_ASSERT(left_copy.type() == right_copy.type()); - return TemplatedBooleanOperation(left_copy, right_copy); - } - switch (left_type.InternalType()) { - case PhysicalType::BOOL: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT8: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT16: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT32: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT64: - return 
OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT8: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT16: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT32: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT64: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::UINT128: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INT128: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::FLOAT: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::DOUBLE: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::INTERVAL: - return OP::Operation(left.GetValueUnsafe(), right.GetValueUnsafe()); - case PhysicalType::VARCHAR: - return OP::Operation(StringValue::Get(left), StringValue::Get(right)); - case PhysicalType::STRUCT: { - auto &left_children = StructValue::GetChildren(left); - auto &right_children = StructValue::GetChildren(right); - // this should be enforced by the type - D_ASSERT(left_children.size() == right_children.size()); - idx_t i = 0; - for (; i < left_children.size() - 1; ++i) { - if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { - return true; - } - if (!ValuePositionComparator::Possible(left_children[i], right_children[i])) { - return false; - } - } - return ValuePositionComparator::Final(left_children[i], right_children[i]); - } - case PhysicalType::LIST: { - auto &left_children = ListValue::GetChildren(left); - auto &right_children = ListValue::GetChildren(right); - for (idx_t pos = 0;; ++pos) { - if (pos == left_children.size() || pos == right_children.size()) { - return ValuePositionComparator::TieBreak(left_children.size(), right_children.size()); - } - if (ValuePositionComparator::Definite(left_children[pos], right_children[pos])) { - return true; - } - if (!ValuePositionComparator::Possible(left_children[pos], right_children[pos])) { - return false; - } - } - return false; - } - case PhysicalType::ARRAY: { - auto &left_children = ArrayValue::GetChildren(left); - auto &right_children = ArrayValue::GetChildren(right); - - // Should be enforced by the type - D_ASSERT(left_children.size() == right_children.size()); - - for (idx_t i = 0; i < left_children.size(); i++) { - if (ValuePositionComparator::Definite(left_children[i], right_children[i])) { - return true; - } - if (!ValuePositionComparator::Possible(left_children[i], right_children[i])) { - return false; - } - } - return true; - } - default: - throw InternalException("Unimplemented type for value comparison"); - } -} - -bool ValueOperations::Equals(const Value &left, const Value &right) { - if (left.IsNull() || right.IsNull()) { - throw InternalException("Comparison on NULL values"); - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::NotEquals(const Value &left, const Value &right) { - return !ValueOperations::Equals(left, right); -} - -bool ValueOperations::GreaterThan(const Value &left, const Value &right) { - if (left.IsNull() || right.IsNull()) { - throw InternalException("Comparison on NULL values"); - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::GreaterThanEquals(const Value &left, const Value &right) { - return !ValueOperations::GreaterThan(right, left); -} - 
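// ---------------------------------------------------------------------------
// The LIST and STRUCT cases of TemplatedBooleanOperation above walk both
// sides position by position using the Definite/Possible/Final scheme: return
// early when one position decides the result, stop when a position rules it
// out, and fall back to a length tie-break when one side runs out. A
// standalone sketch of that control flow, specialised to "less than" on int
// lists without NULLs (the NULL-aware Possible check via IS NOT DISTINCT is
// omitted here):
#include <cstddef>
#include <vector>

bool ListLessThan(const std::vector<int> &lhs, const std::vector<int> &rhs) {
	for (size_t pos = 0;; ++pos) {
		if (pos == lhs.size() || pos == rhs.size()) {
			return lhs.size() < rhs.size(); // tie-break on length
		}
		if (lhs[pos] < rhs[pos]) {
			return true; // Definite: later positions cannot change the result
		}
		if (lhs[pos] != rhs[pos]) {
			return false; // not Possible: this position is already greater
		}
		// equal so far: keep scanning
	}
}
// ---------------------------------------------------------------------------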
-bool ValueOperations::LessThan(const Value &left, const Value &right) { - return ValueOperations::GreaterThan(right, left); -} - -bool ValueOperations::LessThanEquals(const Value &left, const Value &right) { - return !ValueOperations::GreaterThan(left, right); -} - -bool ValueOperations::NotDistinctFrom(const Value &left, const Value &right) { - if (left.IsNull() && right.IsNull()) { - return true; - } - if (left.IsNull() != right.IsNull()) { - return false; - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::DistinctFrom(const Value &left, const Value &right) { - return !ValueOperations::NotDistinctFrom(left, right); -} - -bool ValueOperations::DistinctGreaterThan(const Value &left, const Value &right) { - if (left.IsNull() && right.IsNull()) { - return false; - } else if (right.IsNull()) { - return false; - } else if (left.IsNull()) { - return true; - } - return TemplatedBooleanOperation(left, right); -} - -bool ValueOperations::DistinctGreaterThanEquals(const Value &left, const Value &right) { - return !ValueOperations::DistinctGreaterThan(right, left); -} - -bool ValueOperations::DistinctLessThan(const Value &left, const Value &right) { - return ValueOperations::DistinctGreaterThan(right, left); -} - -bool ValueOperations::DistinctLessThanEquals(const Value &left, const Value &right) { - return !ValueOperations::DistinctGreaterThan(left, right); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/boolean_operators.cpp b/src/duckdb/src/common/vector_operations/boolean_operators.cpp deleted file mode 100644 index 667c92085..000000000 --- a/src/duckdb/src/common/vector_operations/boolean_operators.cpp +++ /dev/null @@ -1,177 +0,0 @@ -//===--------------------------------------------------------------------===// -// boolean_operators.cpp -// Description: This file contains the implementation of the boolean -// operations AND OR ! 
-//===--------------------------------------------------------------------===// - -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// AND/OR -//===--------------------------------------------------------------------===// -template -static void TemplatedBooleanNullmask(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType().id() == LogicalTypeId::BOOLEAN && right.GetType().id() == LogicalTypeId::BOOLEAN && - result.GetType().id() == LogicalTypeId::BOOLEAN); - - if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // operation on two constants, result is constant vector - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - auto result_data = ConstantVector::GetData(result); - - bool is_null = OP::Operation(*ldata > 0, *rdata > 0, ConstantVector::IsNull(left), - ConstantVector::IsNull(right), *result_data); - ConstantVector::SetNull(result, is_null); - } else { - // perform generic loop - UnifiedVectorFormat ldata, rdata; - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto left_data = UnifiedVectorFormat::GetData(ldata); // we use uint8 to avoid load of gunk bools - auto right_data = UnifiedVectorFormat::GetData(rdata); - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - if (!ldata.validity.AllValid() || !rdata.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - bool is_null = - OP::Operation(left_data[lidx] > 0, right_data[ridx] > 0, !ldata.validity.RowIsValid(lidx), - !rdata.validity.RowIsValid(ridx), result_data[i]); - result_mask.Set(i, !is_null); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - result_data[i] = OP::SimpleOperation(left_data[lidx], right_data[ridx]); - } - } - } -} - -/* -SQL AND Rules: - -TRUE AND TRUE = TRUE -TRUE AND FALSE = FALSE -TRUE AND NULL = NULL -FALSE AND TRUE = FALSE -FALSE AND FALSE = FALSE -FALSE AND NULL = FALSE -NULL AND TRUE = NULL -NULL AND FALSE = FALSE -NULL AND NULL = NULL - -Basically: -- Only true if both are true -- False if either is false (regardless of NULLs) -- NULL otherwise -*/ -struct TernaryAnd { - static bool SimpleOperation(bool left, bool right) { - return left && right; - } - static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { - if (left_null && right_null) { - // both NULL: - // result is NULL - return true; - } else if (left_null) { - // left is NULL: - // result is FALSE if right is false - // result is NULL if right is true - result = right; - return right; - } else if (right_null) { - // right is NULL: - // result is FALSE if left is false - // result is NULL if left is true - result = left; - return left; - } else { - // no NULL: perform the AND - result = left && right; - return false; - } - } -}; - -void VectorOperations::And(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); -} - -/* -SQL OR 
Rules: - -OR -TRUE OR TRUE = TRUE -TRUE OR FALSE = TRUE -TRUE OR NULL = TRUE -FALSE OR TRUE = TRUE -FALSE OR FALSE = FALSE -FALSE OR NULL = NULL -NULL OR TRUE = TRUE -NULL OR FALSE = NULL -NULL OR NULL = NULL - -Basically: -- Only false if both are false -- True if either is true (regardless of NULLs) -- NULL otherwise -*/ - -struct TernaryOr { - static bool SimpleOperation(bool left, bool right) { - return left || right; - } - static bool Operation(bool left, bool right, bool left_null, bool right_null, bool &result) { - if (left_null && right_null) { - // both NULL: - // result is NULL - return true; - } else if (left_null) { - // left is NULL: - // result is TRUE if right is true - // result is NULL if right is false - result = right; - return !right; - } else if (right_null) { - // right is NULL: - // result is TRUE if left is true - // result is NULL if left is false - result = left; - return !left; - } else { - // no NULL: perform the OR - result = left || right; - return false; - } - } -}; - -void VectorOperations::Or(Vector &left, Vector &right, Vector &result, idx_t count) { - TemplatedBooleanNullmask(left, right, result, count); -} - -struct NotOperator { - template - static inline TR Operation(TA left) { - return !left; - } -}; - -void VectorOperations::Not(Vector &input, Vector &result, idx_t count) { - D_ASSERT(input.GetType() == LogicalType::BOOLEAN && result.GetType() == LogicalType::BOOLEAN); - UnaryExecutor::Execute(input, result, count); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/comparison_operators.cpp b/src/duckdb/src/common/vector_operations/comparison_operators.cpp deleted file mode 100644 index 9c5c3d662..000000000 --- a/src/duckdb/src/common/vector_operations/comparison_operators.cpp +++ /dev/null @@ -1,308 +0,0 @@ -//===--------------------------------------------------------------------===// -// comparison_operators.cpp -// Description: This file contains the implementation of the comparison -// operations == != >= <= > < -//===--------------------------------------------------------------------===// - -#include "duckdb/common/operator/comparison_operators.hpp" - -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -#include "duckdb/common/likely.hpp" - -namespace duckdb { - -template -bool EqualsFloat(T left, T right) { - if (DUCKDB_UNLIKELY(Value::IsNan(left) && Value::IsNan(right))) { - return true; - } - return left == right; -} - -template <> -bool Equals::Operation(const float &left, const float &right) { - return EqualsFloat(left, right); -} - -template <> -bool Equals::Operation(const double &left, const double &right) { - return EqualsFloat(left, right); -} - -template -bool GreaterThanFloat(T left, T right) { - // handle nans - // nan is always bigger than everything else - bool left_is_nan = Value::IsNan(left); - bool right_is_nan = Value::IsNan(right); - // if right is nan, there is no number that is bigger than right - if (DUCKDB_UNLIKELY(right_is_nan)) { - return false; - } - // if left is nan, but right is not, left is always bigger - if (DUCKDB_UNLIKELY(left_is_nan)) { - return true; - } - return left > right; -} - -template <> -bool GreaterThan::Operation(const float &left, const float &right) { - return GreaterThanFloat(left, right); -} - -template <> -bool GreaterThan::Operation(const double &left, const double &right) { - return GreaterThanFloat(left, right); -} - -template -bool 
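// ---------------------------------------------------------------------------
// TernaryAnd and TernaryOr above implement the SQL truth tables spelled out
// in the comments: AND is false as soon as either input is false, OR is true
// as soon as either input is true, and NULL wins otherwise. The same logic in
// a standalone sketch, using std::optional<bool> with std::nullopt playing
// NULL (illustrative names, not the DuckDB API):
#include <cassert>
#include <optional>

using nullable_bool = std::optional<bool>;

nullable_bool SqlAnd(nullable_bool l, nullable_bool r) {
	if (l.has_value() && r.has_value()) {
		return *l && *r; // no NULLs: plain AND
	}
	if ((l.has_value() && !*l) || (r.has_value() && !*r)) {
		return false; // FALSE AND NULL = FALSE
	}
	return std::nullopt; // TRUE AND NULL = NULL, NULL AND NULL = NULL
}

nullable_bool SqlOr(nullable_bool l, nullable_bool r) {
	if (l.has_value() && r.has_value()) {
		return *l || *r; // no NULLs: plain OR
	}
	if ((l.has_value() && *l) || (r.has_value() && *r)) {
		return true; // TRUE OR NULL = TRUE
	}
	return std::nullopt; // FALSE OR NULL = NULL, NULL OR NULL = NULL
}

int main() {
	assert(SqlAnd(false, std::nullopt) == nullable_bool(false));
	assert(SqlAnd(true, std::nullopt) == std::nullopt);
	assert(SqlOr(true, std::nullopt) == nullable_bool(true));
	assert(SqlOr(false, std::nullopt) == std::nullopt);
}
// ---------------------------------------------------------------------------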
GreaterThanEqualsFloat(T left, T right) { - // handle nans - // nan is always bigger than everything else - bool left_is_nan = Value::IsNan(left); - bool right_is_nan = Value::IsNan(right); - // if right is nan, there is no bigger number - // we only return true if left is also nan (in which case the numbers are equal) - if (DUCKDB_UNLIKELY(right_is_nan)) { - return left_is_nan; - } - // if left is nan, but right is not, left is always bigger - if (DUCKDB_UNLIKELY(left_is_nan)) { - return true; - } - return left >= right; -} - -template <> -bool GreaterThanEquals::Operation(const float &left, const float &right) { - return GreaterThanEqualsFloat(left, right); -} - -template <> -bool GreaterThanEquals::Operation(const double &left, const double &right) { - return GreaterThanEqualsFloat(left, right); -} - -struct ComparisonSelector { - template - static idx_t Select(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - throw NotImplementedException("Unknown comparison operation!"); - } -}; - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel, &null_mask); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel, &null_mask); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel, &null_mask); -} - -template <> -inline idx_t -ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel, &null_mask); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::GreaterThan(right, left, sel, count, true_sel, false_sel, &null_mask); -} - -template <> -inline idx_t ComparisonSelector::Select(Vector &left, Vector &right, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel, ValidityMask &null_mask) { - return VectorOperations::GreaterThanEquals(right, left, sel, count, true_sel, false_sel, &null_mask); -} - -static void ComparesNotNull(UnifiedVectorFormat &ldata, UnifiedVectorFormat &rdata, ValidityMask &vresult, - idx_t count) { - for (idx_t i = 0; i < count; ++i) { - auto lidx = ldata.sel->get_index(i); - auto ridx = rdata.sel->get_index(i); - if (!ldata.validity.RowIsValid(lidx) || !rdata.validity.RowIsValid(ridx)) { - vresult.SetInvalid(i); - } - } -} - -template -static void NestedComparisonExecutor(Vector &left, Vector &right, Vector &result, idx_t count) { - const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; - const auto 
right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; - - if ((left_constant && ConstantVector::IsNull(left)) || (right_constant && ConstantVector::IsNull(right))) { - // either left or right is constant NULL: result is constant NULL - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - if (left_constant && right_constant) { - // both sides are constant, and neither is NULL so just compare one element. - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto &result_validity = ConstantVector::Validity(result); - SelectionVector true_sel(1); - auto match_count = ComparisonSelector::Select(left, right, nullptr, 1, &true_sel, nullptr, result_validity); - // since we are dealing with nested types where the values are not NULL, the result is always valid (i.e true or - // false) - result_validity.SetAllValid(1); - auto result_data = ConstantVector::GetData(result); - result_data[0] = match_count > 0; - return; - } - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - - UnifiedVectorFormat leftv, rightv; - left.ToUnifiedFormat(count, leftv); - right.ToUnifiedFormat(count, rightv); - if (!leftv.validity.AllValid() || !rightv.validity.AllValid()) { - ComparesNotNull(leftv, rightv, result_validity, count); - } - ValidityMask original_mask; - original_mask.SetAllValid(count); - original_mask.Copy(result_validity, count); - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - idx_t match_count = - ComparisonSelector::Select(left, right, nullptr, count, &true_sel, &false_sel, result_validity); - - for (idx_t i = 0; i < match_count; ++i) { - const auto idx = true_sel.get_index(i); - result_data[idx] = true; - // if the row was valid during the null check, set it to valid here as well - if (original_mask.RowIsValid(idx)) { - result_validity.SetValid(idx); - } - } - - const idx_t no_match_count = count - match_count; - for (idx_t i = 0; i < no_match_count; ++i) { - const auto idx = false_sel.get_index(i); - result_data[idx] = false; - if (original_mask.RowIsValid(idx)) { - result_validity.SetValid(idx); - } - } -} - -struct ComparisonExecutor { -private: - template - static inline void TemplatedExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - BinaryExecutor::Execute(left, right, result, count); - } - -public: - template - static inline void Execute(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType().InternalType() == right.GetType().InternalType() && - result.GetType() == LogicalType::BOOLEAN); - // the inplace loops take the result as the last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT16: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT32: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT64: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT8: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT16: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT32: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::UINT64: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INT128: - TemplatedExecute(left, right, 
result, count); - break; - case PhysicalType::UINT128: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::FLOAT: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::DOUBLE: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::INTERVAL: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::VARCHAR: - TemplatedExecute(left, right, result, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - NestedComparisonExecutor(left, right, result, count); - break; - default: - throw InternalException("Invalid type for comparison"); - } - } -}; - -void VectorOperations::Equals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::NotEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::GreaterThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::LessThanEquals(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(right, left, result, count); -} - -void VectorOperations::GreaterThan(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(left, right, result, count); -} - -void VectorOperations::LessThan(Vector &left, Vector &right, Vector &result, idx_t count) { - ComparisonExecutor::Execute(right, left, result, count); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/generators.cpp b/src/duckdb/src/common/vector_operations/generators.cpp deleted file mode 100644 index 2fc5b67c6..000000000 --- a/src/duckdb/src/common/vector_operations/generators.cpp +++ /dev/null @@ -1,91 +0,0 @@ -//===--------------------------------------------------------------------===// -// generators.cpp -// Description: This file contains the implementation of different generators -//===--------------------------------------------------------------------===// - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/numeric_utils.hpp" - -namespace duckdb { - -template -void TemplatedGenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { - D_ASSERT(result.GetType().IsNumeric()); - if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { - throw InternalException("Sequence start or increment out of type range"); - } - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto value = T(start); - for (idx_t i = 0; i < count; i++) { - if (i > 0) { - value += increment; - } - result_data[i] = value; - } -} - -void VectorOperations::GenerateSequence(Vector &result, idx_t count, int64_t start, int64_t increment) { - if (!result.GetType().IsNumeric()) { - throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); - } - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::INT16: - TemplatedGenerateSequence(result, count, start, increment); - break; - case PhysicalType::INT32: - TemplatedGenerateSequence(result, count, start, increment); - break; - 
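// (Worked example: each TemplatedGenerateSequence call above fills row i with
// start + i * increment, accumulating in the target type after the range check;
// e.g. start = 4, increment = 3 over 4 rows yields {4, 7, 10, 13}.)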
case PhysicalType::INT64: - TemplatedGenerateSequence(result, count, start, increment); - break; - default: - throw NotImplementedException("Unimplemented type for generate sequence"); - } -} - -template -void TemplatedGenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, - int64_t increment) { - D_ASSERT(result.GetType().IsNumeric()); - if (start > NumericLimits::Maximum() || increment > NumericLimits::Maximum()) { - throw InternalException("Sequence start or increment out of type range"); - } - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto value = static_cast(start); - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - result_data[idx] = static_cast(value + static_cast(increment) * idx); - } -} - -void VectorOperations::GenerateSequence(Vector &result, idx_t count, const SelectionVector &sel, int64_t start, - int64_t increment) { - if (!result.GetType().IsNumeric()) { - throw InvalidTypeException(result.GetType(), "Can only generate sequences for numeric values!"); - } - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT16: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT32: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - case PhysicalType::INT64: - TemplatedGenerateSequence(result, count, sel, start, increment); - break; - default: - throw NotImplementedException("Unimplemented type for generate sequence"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp b/src/duckdb/src/common/vector_operations/is_distinct_from.cpp deleted file mode 100644 index b250eddc7..000000000 --- a/src/duckdb/src/common/vector_operations/is_distinct_from.cpp +++ /dev/null @@ -1,1261 +0,0 @@ -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" - -namespace duckdb { - -struct DistinctBinaryLambdaWrapper { - template - static inline RESULT_TYPE Operation(LEFT_TYPE left, RIGHT_TYPE right, bool is_left_null, bool is_right_null) { - return OP::template Operation(left, right, is_left_null, is_right_null); - } -}; - -template -static void DistinctExecuteGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - RESULT_TYPE *__restrict result_data, const SelectionVector *__restrict lsel, - const SelectionVector *__restrict rsel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, ValidityMask &result_mask) { - for (idx_t i = 0; i < count; i++) { - auto lindex = lsel->get_index(i); - auto rindex = rsel->get_index(i); - auto lentry = ldata[lindex]; - auto rentry = rdata[rindex]; - result_data[i] = - OP::template Operation(lentry, rentry, !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex)); - } -} - -template -static void DistinctExecuteConstant(Vector &left, Vector &right, Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - auto result_data = ConstantVector::GetData(result); - *result_data = - OP::template Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right)); -} - -template -static void DistinctExecuteGeneric(Vector &left, Vector &right, Vector &result, idx_t count) { - 
if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - DistinctExecuteConstant(left, right, result); - } else { - UnifiedVectorFormat ldata, rdata; - - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - DistinctExecuteGenericLoop( - UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), - result_data, ldata.sel, rdata.sel, count, ldata.validity, rdata.validity, FlatVector::Validity(result)); - } -} - -template -static void DistinctExecuteSwitch(Vector &left, Vector &right, Vector &result, idx_t count) { - DistinctExecuteGeneric(left, right, result, count); -} - -template -static void DistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - DistinctExecuteSwitch(left, right, result, count); -} - -#ifndef DUCKDB_SMALLER_BINARY -template -#else -template -#endif -static inline idx_t -DistinctSelectGenericLoop(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { -#ifdef DUCKDB_SMALLER_BINARY - bool HAS_TRUE_SEL = true_sel; - bool HAS_FALSE_SEL = false_sel; -#endif - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - auto result_idx = result_sel->get_index(i); - auto lindex = lsel->get_index(i); - auto rindex = rsel->get_index(i); -#ifndef DUCKDB_SMALLER_BINARY - if (NO_NULL) { - if (OP::Operation(ldata[lindex], rdata[rindex], false, false)) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } else -#endif - { - if (OP::Operation(ldata[lindex], rdata[rindex], !lmask.RowIsValid(lindex), !rmask.RowIsValid(rindex))) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} - -#ifndef DUCKDB_SMALLER_BINARY -template -static inline idx_t -DistinctSelectGenericLoopSelSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else if (true_sel) { - return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DistinctSelectGenericLoop( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } -} -#endif - -template -static inline idx_t -DistinctSelectGenericLoopSwitch(const LEFT_TYPE *__restrict ldata, const RIGHT_TYPE *__restrict rdata, - const SelectionVector *__restrict lsel, const SelectionVector *__restrict rsel, - const SelectionVector *__restrict result_sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, 
SelectionVector *true_sel, SelectionVector *false_sel) { -#ifndef DUCKDB_SMALLER_BINARY - if (!lmask.AllValid() || !rmask.AllValid()) { - return DistinctSelectGenericLoopSelSwitch( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } else { - return DistinctSelectGenericLoopSelSwitch( - ldata, rdata, lsel, rsel, result_sel, count, lmask, rmask, true_sel, false_sel); - } -#else - return DistinctSelectGenericLoop(ldata, rdata, lsel, rsel, result_sel, count, lmask, - rmask, true_sel, false_sel); -#endif -} - -template -static idx_t DistinctSelectGeneric(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - UnifiedVectorFormat ldata, rdata; - - left.ToUnifiedFormat(count, ldata); - right.ToUnifiedFormat(count, rdata); - - return DistinctSelectGenericLoopSwitch( - UnifiedVectorFormat::GetData(ldata), UnifiedVectorFormat::GetData(rdata), ldata.sel, - rdata.sel, sel, count, ldata.validity, rdata.validity, true_sel, false_sel); -} - -#ifndef DUCKDB_SMALLER_BINARY -template -static inline idx_t DistinctSelectFlatLoop(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, SelectionVector *false_sel) { - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - idx_t result_idx = sel->get_index(i); - idx_t lidx = LEFT_CONSTANT ? 0 : i; - idx_t ridx = RIGHT_CONSTANT ? 0 : i; - const bool lnull = !lmask.RowIsValid(lidx); - const bool rnull = !rmask.RowIsValid(ridx); - bool comparison_result = OP::Operation(ldata[lidx], rdata[ridx], lnull, rnull); - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count, result_idx); - true_count += comparison_result; - } - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count, result_idx); - false_count += !comparison_result; - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} - -template -static inline idx_t DistinctSelectFlatLoopSelSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, - SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } else if (true_sel) { - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DistinctSelectFlatLoop( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); - } -} - -template -static inline idx_t DistinctSelectFlatLoopSwitch(LEFT_TYPE *__restrict ldata, RIGHT_TYPE *__restrict rdata, - const SelectionVector *sel, idx_t count, ValidityMask &lmask, - ValidityMask &rmask, SelectionVector *true_sel, - SelectionVector *false_sel) { - return DistinctSelectFlatLoopSelSwitch( - ldata, rdata, sel, count, lmask, rmask, true_sel, false_sel); -} - -template -static idx_t DistinctSelectFlat(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - auto ldata = FlatVector::GetData(left); - auto rdata = FlatVector::GetData(right); - if (LEFT_CONSTANT) { - ValidityMask validity; - if (ConstantVector::IsNull(left)) { - validity.SetAllInvalid(1); - } - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, validity, FlatVector::Validity(right), true_sel, 
false_sel); - } else if (RIGHT_CONSTANT) { - ValidityMask validity; - if (ConstantVector::IsNull(right)) { - validity.SetAllInvalid(1); - } - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::Validity(left), validity, true_sel, false_sel); - } else { - return DistinctSelectFlatLoopSwitch( - ldata, rdata, sel, count, FlatVector::Validity(left), FlatVector::Validity(right), true_sel, false_sel); - } -} -#endif - -template -static idx_t DistinctSelectConstant(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - auto ldata = ConstantVector::GetData(left); - auto rdata = ConstantVector::GetData(right); - - // both sides are constant, return either 0 or the count - // in this case we do not fill in the result selection vector at all - if (!OP::Operation(*ldata, *rdata, ConstantVector::IsNull(left), ConstantVector::IsNull(right))) { - if (false_sel) { - for (idx_t i = 0; i < count; i++) { - false_sel->set_index(i, sel->get_index(i)); - } - } - return 0; - } else { - if (true_sel) { - for (idx_t i = 0; i < count; i++) { - true_sel->set_index(i, sel->get_index(i)); - } - } - return count; - } -} - -static void UpdateNullMask(Vector &vec, const SelectionVector &sel, idx_t count, ValidityMask &null_mask) { - UnifiedVectorFormat vdata; - vec.ToUnifiedFormat(count, vdata); - - if (vdata.validity.AllValid()) { - return; - } - - for (idx_t i = 0; i < count; ++i) { - const auto ridx = sel.get_index(i); - const auto vidx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(vidx)) { - null_mask.SetInvalid(ridx); - } - } -} - -template -static idx_t DistinctSelect(Vector &left, Vector &right, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel, - optional_ptr null_mask) { - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - // TODO: Push this down? - if (null_mask) { - UpdateNullMask(left, *sel, count, *null_mask); - UpdateNullMask(right, *sel, count, *null_mask); - } - - if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return DistinctSelectConstant(left, right, sel, count, true_sel, false_sel); -#ifndef DUCKDB_SMALLER_BINARY - } else if (left.GetVectorType() == VectorType::CONSTANT_VECTOR && - right.GetVectorType() == VectorType::FLAT_VECTOR) { - return DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); - } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && - right.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return DistinctSelectFlat(left, right, sel, count, true_sel, false_sel); - } else if (left.GetVectorType() == VectorType::FLAT_VECTOR && right.GetVectorType() == VectorType::FLAT_VECTOR) { - return DistinctSelectFlat(left, right, sel, count, true_sel, - false_sel); -#endif - } else { - return DistinctSelectGeneric(left, right, sel, count, true_sel, false_sel); - } -} - -template -static idx_t DistinctSelectNotNull(Vector &left, Vector &right, const idx_t count, idx_t &true_count, - const SelectionVector &sel, SelectionVector &maybe_vec, OptionalSelection &true_opt, - OptionalSelection &false_opt, optional_ptr null_mask) { - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(count, lvdata); - right.ToUnifiedFormat(count, rvdata); - - auto &lmask = lvdata.validity; - auto &rmask = rvdata.validity; - - idx_t remaining = 0; - if (lmask.AllValid() && rmask.AllValid()) { - // None are NULL, distinguish values. 
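// (When NULLs are present, the loop further below decides each row as soon as
// either side is NULL; only the remaining rows need a value comparison. A
// scalar sketch of that rule, illustrative only:
//
//     // IS DISTINCT FROM on the NULL flags alone:
//     //   NULL vs NULL  -> not distinct (false)
//     //   NULL vs value -> distinct (true)
//     static bool DistinctFromNullFlags(bool lnull, bool rnull) {
//         return lnull != rnull;
//     }
//
// which is what OP::Operation(false, false, lnull, rnull) evaluates to when
// OP is DistinctFrom.)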
- for (idx_t i = 0; i < count; ++i) { - const auto idx = sel.get_index(i); - maybe_vec.set_index(remaining++, idx); - } - return remaining; - } - - // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) - SelectionVector slicer(count); - true_count = 0; - idx_t false_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto result_idx = sel.get_index(i); - const auto lidx = lvdata.sel->get_index(i); - const auto ridx = rvdata.sel->get_index(i); - const auto lnull = !lmask.RowIsValid(lidx); - const auto rnull = !rmask.RowIsValid(ridx); - if (lnull || rnull) { - // If either is NULL then we can immediately distinguish them - if (null_mask) { - null_mask->SetInvalid(result_idx); - } - if (!OP::Operation(false, false, lnull, rnull)) { - false_opt.Append(false_count, result_idx); - } else { - true_opt.Append(true_count, result_idx); - } - } else { - // Neither is NULL, distinguish values. - slicer.set_index(remaining, i); - maybe_vec.set_index(remaining++, result_idx); - } - } - - true_opt.Advance(true_count); - false_opt.Advance(false_count); - - if (remaining && remaining < count) { - left.Slice(slicer, remaining); - right.Slice(slicer, remaining); - } - - return remaining; -} - -struct PositionComparator { - // Select the rows that definitely match. - // Default to the same as the final row - template - static idx_t Definite(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - optional_ptr true_sel, SelectionVector &false_sel, - optional_ptr null_mask) { - return Final(left, right, sel, count, true_sel, &false_sel, null_mask); - } - - // Select the possible rows that need further testing. - // Usually this means Is Not Distinct, as those are the semantics used by Postgres - template - static idx_t Possible(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - SelectionVector &true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::NestedEquals(left, right, &sel, count, &true_sel, false_sel, null_mask); - } - - // Select the matching rows for the final position. - // This needs to be specialised. - template - static idx_t Final(Vector &left, Vector &right, const SelectionVector &sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return 0; - } - - // Tie-break based on length when one of the sides has been exhausted, returning true if the LHS matches. - // This essentially means that the existing positions compare equal. - // Default to the same semantics as the OP for idx_t. This works in most cases. 
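// (A scalar model of this positional protocol, as an illustrative sketch for
// "less than" semantics; the vectorised code below performs the same
// Definite/Possible/Final/TieBreak steps across many lists at once:
//
//     static bool ListLessThan(const int *l, idx_t nl, const int *r, idx_t nr) {
//         for (idx_t pos = 0;; ++pos) {
//             if (pos == nl || pos == nr) {
//                 return nl < nr;  // TieBreak: shared prefix equal, shorter sorts first
//             }
//             if (l[pos] < r[pos]) {
//                 return true;     // Definite: strictly less at this position
//             }
//             if (l[pos] != r[pos]) {
//                 return false;    // not Possible: strictly greater, no match
//             }
//             // equal at this position: Possible, advance to the next one
//         }
//     }
// )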
- template - static bool TieBreak(const idx_t lpos, const idx_t rpos) { - return OP::Operation(lpos, rpos, false, false); - } -}; - -// NotDistinctFrom must always check every column -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - SelectionVector &false_sel, - optional_ptr null_mask) { - return 0; -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::NestedEquals(left, right, &sel, count, true_sel, false_sel, null_mask); -} - -// DistinctFrom must check everything that matched -template <> -idx_t PositionComparator::Possible(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, SelectionVector &true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return count; -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::NestedNotEquals(left, right, &sel, count, true_sel, false_sel, null_mask); -} - -// Non-strict inequalities must use strict comparisons for Definite -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - SelectionVector &false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, &false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThanEquals(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Definite(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - SelectionVector &false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, &false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThanEquals(left, right, &sel, count, true_sel, false_sel, null_mask); -} - -// Strict inequalities just use strict for both Definite and Final -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - // DistinctGreaterThan has NULLs last - return VectorOperations::DistinctGreaterThan(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, const SelectionVector &sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return 
VectorOperations::DistinctGreaterThan(left, right, &sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t PositionComparator::Final(Vector &left, Vector &right, - const SelectionVector &sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - // DistinctLessThan has NULLs last - return VectorOperations::DistinctLessThan(right, left, &sel, count, true_sel, false_sel, null_mask); -} - -using StructEntries = vector>; - -static void ExtractNestedSelection(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - OptionalSelection &opt) { - - for (idx_t i = 0; i < count;) { - const auto slice_idx = slice_sel.get_index(i); - const auto result_idx = sel.get_index(slice_idx); - opt.Append(i, result_idx); - } - opt.Advance(count); -} - -static void ExtractNestedMask(const SelectionVector &slice_sel, const idx_t count, const SelectionVector &sel, - ValidityMask *child_mask, optional_ptr null_mask) { - - if (!child_mask) { - return; - } - - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - const auto result_idx = sel.get_index(slice_idx); - if (child_mask && !child_mask->RowIsValid(slice_idx)) { - null_mask->SetInvalid(result_idx); - } - } - - child_mask->Reset(null_mask->Capacity()); -} - -static void DensifyNestedSelection(const SelectionVector &dense_sel, const idx_t count, SelectionVector &slice_sel) { - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, dense_sel.get_index(i)); - } -} - -template -static idx_t DistinctSelectStruct(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, - OptionalSelection &true_opt, OptionalSelection &false_opt, - optional_ptr null_mask) { - if (count == 0) { - return 0; - } - - // Avoid allocating in the 99% of the cases where we don't need to. - StructEntries lsliced, rsliced; - auto &lchildren = StructVector::GetEntries(left); - auto &rchildren = StructVector::GetEntries(right); - D_ASSERT(lchildren.size() == rchildren.size()); - - // In order to reuse the comparators, we have to track what passed and failed internally. - // To do that, we need local SVs that we then merge back into the real ones after every pass. - const auto vcount = count; - SelectionVector slice_sel(count); - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, i); - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - ValidityMask child_validity; - ValidityMask *child_mask = nullptr; - if (null_mask) { - child_mask = &child_validity; - child_mask->Reset(null_mask->Capacity()); - } - - idx_t match_count = 0; - for (idx_t col_no = 0; col_no < lchildren.size(); ++col_no) { - // Slice the children to maintain density - Vector lchild(*lchildren[col_no]); - lchild.Flatten(vcount); - lchild.Slice(slice_sel, count); - - Vector rchild(*rchildren[col_no]); - rchild.Flatten(vcount); - rchild.Slice(slice_sel, count); - - // Find everything that definitely matches - auto true_count = - PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel, child_mask); - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - if (true_count > 0) { - auto false_count = count - true_count; - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? 
true_sel : slice_sel, true_count, sel, true_opt); - - // Remove the definite matches from the slicing vector - DensifyNestedSelection(false_sel, false_count, slice_sel); - - match_count += true_count; - count -= true_count; - } - - if (col_no != lchildren.size() - 1) { - // Find what might match on the next position - true_count = - PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel, child_mask); - auto false_count = count - true_count; - - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); - - // Remove any definite failures from the slicing vector - if (false_count) { - DensifyNestedSelection(true_sel, true_count, slice_sel); - } - - count = true_count; - } else { - true_count = - PositionComparator::Final(lchild, rchild, slice_sel, count, &true_sel, &false_sel, child_mask); - auto false_count = count - true_count; - - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); - - match_count += true_count; - } - } - return match_count; -} - -static void PositionListCursor(SelectionVector &cursor, UnifiedVectorFormat &vdata, const idx_t pos, - const SelectionVector &slice_sel, const idx_t count) { - const auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - - const auto lidx = vdata.sel->get_index(slice_idx); - const auto &entry = data[lidx]; - cursor.set_index(i, entry.offset + pos); - } -} - -template -static idx_t DistinctSelectList(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, - OptionalSelection &true_opt, OptionalSelection &false_opt, - optional_ptr null_mask) { - if (count == 0) { - return count; - } - - // Create dictionary views of the children so we can vectorise the positional comparisons. - SelectionVector lcursor(count); - SelectionVector rcursor(count); - - Vector lentry_flattened(ListVector::GetEntry(left)); - Vector rentry_flattened(ListVector::GetEntry(right)); - lentry_flattened.Flatten(ListVector::GetListSize(left)); - rentry_flattened.Flatten(ListVector::GetListSize(right)); - Vector lchild(lentry_flattened, lcursor, count); - Vector rchild(rentry_flattened, rcursor, count); - - // To perform the positional comparison, we use a vectorisation of the following algorithm: - // bool CompareLists(T *left, idx_t nleft, T *right, nright) { - // for (idx_t pos = 0; ; ++pos) { - // if (nleft == pos || nright == pos) - // return OP::TieBreak(nleft, nright); - // if (OP::Definite(*left, *right)) - // return true; - // if (!OP::Maybe(*left, *right)) - // return false; - // } - // ++left, ++right; - // } - // } - - // Get pointers to the list entries - UnifiedVectorFormat lvdata; - left.ToUnifiedFormat(count, lvdata); - const auto ldata = UnifiedVectorFormat::GetData(lvdata); - - UnifiedVectorFormat rvdata; - right.ToUnifiedFormat(count, rvdata); - const auto rdata = UnifiedVectorFormat::GetData(rvdata); - - // In order to reuse the comparators, we have to track what passed and failed internally. 
- // To do that, we need local SVs that we then merge back into the real ones after every pass. - SelectionVector slice_sel(count); - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, i); - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - ValidityMask child_validity; - ValidityMask *child_mask = nullptr; - if (null_mask) { - child_mask = &child_validity; - child_mask->Reset(null_mask->Capacity()); - } - - idx_t match_count = 0; - for (idx_t pos = 0; count > 0; ++pos) { - // Set up the cursors for the current position - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - - // Tie-break the pairs where one of the LISTs is exhausted. - idx_t true_count = 0; - idx_t false_count = 0; - idx_t maybe_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - const auto lidx = lvdata.sel->get_index(slice_idx); - const auto &lentry = ldata[lidx]; - const auto ridx = rvdata.sel->get_index(slice_idx); - const auto &rentry = rdata[ridx]; - if (lentry.length == pos || rentry.length == pos) { - const auto idx = sel.get_index(slice_idx); - if (PositionComparator::TieBreak(lentry.length, rentry.length)) { - true_opt.Append(true_count, idx); - } else { - false_opt.Append(false_count, idx); - } - } else { - true_sel.set_index(maybe_count++, slice_idx); - } - } - true_opt.Advance(true_count); - false_opt.Advance(false_count); - match_count += true_count; - - // Redensify the list cursors - if (maybe_count < count) { - count = maybe_count; - DensifyNestedSelection(true_sel, count, slice_sel); - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - } - - // Find everything that definitely matches - true_count = - PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel, child_mask); - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - if (true_count) { - false_count = count - true_count; - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - match_count += true_count; - - // Redensify the list cursors - count -= true_count; - DensifyNestedSelection(false_sel, count, slice_sel); - PositionListCursor(lcursor, lvdata, pos, slice_sel, count); - PositionListCursor(rcursor, rvdata, pos, slice_sel, count); - } - - // Find what might match on the next position - true_count = - PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel, child_mask); - false_count = count - true_count; - - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? 
false_sel : slice_sel, false_count, sel, false_opt); - - if (false_count) { - DensifyNestedSelection(true_sel, true_count, slice_sel); - } - count = true_count; - } - - return match_count; -} - -static void PositionArrayCursor(SelectionVector &cursor, UnifiedVectorFormat &vdata, const idx_t pos, - const SelectionVector &slice_sel, const idx_t count, idx_t array_size) { - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - const auto lidx = vdata.sel->get_index(slice_idx); - const auto offset = array_size * lidx; - cursor.set_index(i, offset + pos); - } -} - -template -static idx_t DistinctSelectArray(Vector &left, Vector &right, idx_t count, const SelectionVector &sel, - OptionalSelection &true_opt, OptionalSelection &false_opt, - optional_ptr null_mask) { - if (count == 0) { - return count; - } - - // FIXME: This function can probably be optimized since we know the array size is fixed for every entry. - - D_ASSERT(ArrayType::GetSize(left.GetType()) == ArrayType::GetSize(right.GetType())); - auto array_size = ArrayType::GetSize(left.GetType()); - - // Create dictionary views of the children so we can vectorise the positional comparisons. - SelectionVector lcursor(count); - SelectionVector rcursor(count); - - Vector lentry_flattened(ArrayVector::GetEntry(left)); - Vector rentry_flattened(ArrayVector::GetEntry(right)); - lentry_flattened.Flatten(ArrayVector::GetTotalSize(left)); - rentry_flattened.Flatten(ArrayVector::GetTotalSize(right)); - Vector lchild(lentry_flattened, lcursor, count); - Vector rchild(rentry_flattened, rcursor, count); - - // Get pointers to the list entries - UnifiedVectorFormat lvdata; - left.ToUnifiedFormat(count, lvdata); - - UnifiedVectorFormat rvdata; - right.ToUnifiedFormat(count, rvdata); - - // In order to reuse the comparators, we have to track what passed and failed internally. - // To do that, we need local SVs that we then merge back into the real ones after every pass. - SelectionVector slice_sel(count); - for (idx_t i = 0; i < count; ++i) { - slice_sel.set_index(i, i); - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - ValidityMask child_validity; - ValidityMask *child_mask = nullptr; - if (null_mask) { - child_mask = &child_validity; - child_mask->Reset(null_mask->Capacity()); - } - - idx_t match_count = 0; - for (idx_t pos = 0; count > 0; ++pos) { - // Set up the cursors for the current position - PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); - PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - - // Tie-break the pairs where one of the LISTs is exhausted. 
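// (Because every ARRAY row has the same fixed length, both sides are always
// exhausted at pos == array_size, so the tie-break below compares equal
// lengths. Under the default idx_t semantics of OP this gives, e.g.:
//     DistinctFrom:    TieBreak(n, n) == (n != n) == false
//     NotDistinctFrom: TieBreak(n, n) == (n == n) == true
// i.e. pairs whose entire prefix compared equal match exactly for the
// non-strict operators.)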
- idx_t true_count = 0; - idx_t false_count = 0; - idx_t maybe_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto slice_idx = slice_sel.get_index(i); - if (array_size == pos) { - const auto idx = sel.get_index(slice_idx); - if (PositionComparator::TieBreak(array_size, array_size)) { - true_opt.Append(true_count, idx); - } else { - false_opt.Append(false_count, idx); - } - } else { - true_sel.set_index(maybe_count++, slice_idx); - } - } - true_opt.Advance(true_count); - false_opt.Advance(false_count); - match_count += true_count; - - // Redensify the list cursors - if (maybe_count < count) { - count = maybe_count; - DensifyNestedSelection(true_sel, count, slice_sel); - PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); - PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - } - - // Find everything that definitely matches - true_count = - PositionComparator::Definite(lchild, rchild, slice_sel, count, &true_sel, false_sel, child_mask); - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - if (true_count) { - false_count = count - true_count; - - // Extract the definite matches into the true result - ExtractNestedSelection(false_count ? true_sel : slice_sel, true_count, sel, true_opt); - match_count += true_count; - - // Redensify the list cursors - count -= true_count; - DensifyNestedSelection(false_sel, count, slice_sel); - PositionArrayCursor(lcursor, lvdata, pos, slice_sel, count, array_size); - PositionArrayCursor(rcursor, rvdata, pos, slice_sel, count, array_size); - } - - // Find what might match on the next position - true_count = - PositionComparator::Possible(lchild, rchild, slice_sel, count, true_sel, &false_sel, child_mask); - false_count = count - true_count; - - // Extract any NULLs we found - ExtractNestedMask(slice_sel, count, sel, child_mask, null_mask); - - // Extract the definite failures into the false result - ExtractNestedSelection(true_count ? false_sel : slice_sel, false_count, sel, false_opt); - - if (false_count) { - DensifyNestedSelection(true_sel, true_count, slice_sel); - } - count = true_count; - } - - return match_count; -} - -template -static idx_t DistinctSelectNested(Vector &left, Vector &right, optional_ptr sel, - const idx_t count, optional_ptr true_sel, - optional_ptr false_sel, optional_ptr null_mask) { - // The Select operations all use a dense pair of input vectors to partition - // a selection vector in a single pass. But to implement progressive comparisons, - // we have to make multiple passes, so we need to keep track of the original input positions - // and then scatter the output selections when we are done. 
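// (Sketch of that bookkeeping: every pass runs on a dense subset described by
// a slice selection, so emitting a result first maps the dense position back
// to the caller's row. Hypothetical helper for illustration only:
//
//     inline void EmitResult(const SelectionVector &slice_sel, const SelectionVector &sel,
//                            idx_t dense_pos, idx_t &out_count, SelectionVector &out) {
//         const auto slice_idx = slice_sel.get_index(dense_pos); // position within this pass
//         const auto result_idx = sel.get_index(slice_idx);      // original input row
//         out.set_index(out_count++, result_idx);                // buffered for the final scatter
//     }
// )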
- if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - // Make buffered selections for progressive comparisons - // TODO: Remove unnecessary allocations - SelectionVector true_vec(count); - OptionalSelection true_opt(&true_vec); - - SelectionVector false_vec(count); - OptionalSelection false_opt(&false_vec); - - SelectionVector maybe_vec(count); - - // Handle NULL nested values - Vector l_not_null(left); - Vector r_not_null(right); - - idx_t match_count = 0; - auto unknown = DistinctSelectNotNull(l_not_null, r_not_null, count, match_count, *sel, maybe_vec, true_opt, - false_opt, null_mask); - - switch (left.GetType().InternalType()) { - case PhysicalType::LIST: - match_count += - DistinctSelectList(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); - break; - case PhysicalType::STRUCT: - match_count += - DistinctSelectStruct(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); - break; - case PhysicalType::ARRAY: - match_count += - DistinctSelectArray(l_not_null, r_not_null, unknown, maybe_vec, true_opt, false_opt, null_mask); - break; - default: - throw NotImplementedException("Unimplemented type for DISTINCT"); - } - - // Copy the buffered selections to the output selections - if (true_sel) { - DensifyNestedSelection(true_vec, match_count, *true_sel); - } - - if (false_sel) { - DensifyNestedSelection(false_vec, count - match_count, *false_sel); - } - - return match_count; -} - -template -static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count); - -template -static inline void TemplatedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - DistinctExecute(left, right, result, count); -} -template -static void ExecuteDistinct(Vector &left, Vector &right, Vector &result, idx_t count) { - D_ASSERT(left.GetType() == right.GetType() && result.GetType() == LogicalType::BOOLEAN); - // the inplace loops take the result as the last parameter - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT16: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT32: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT64: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT8: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT16: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT32: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT64: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INT128: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::UINT128: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::FLOAT: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::DOUBLE: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::INTERVAL: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::VARCHAR: - TemplatedDistinctExecute(left, right, result, count); - break; - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - NestedDistinctExecute(left, right, result, count); - break; - default: - throw 
InternalException("Invalid type for distinct comparison"); - } -} - -template -static idx_t TemplatedDistinctSelectOperation(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::INT16: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::INT32: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::INT64: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::UINT8: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::UINT16: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::UINT32: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::UINT64: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::INT128: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::UINT128: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), - false_sel.get(), null_mask); - case PhysicalType::FLOAT: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::DOUBLE: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::INTERVAL: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), - false_sel.get(), null_mask); - case PhysicalType::VARCHAR: - return DistinctSelect(left, right, sel.get(), count, true_sel.get(), false_sel.get(), - null_mask); - case PhysicalType::STRUCT: - case PhysicalType::LIST: - case PhysicalType::ARRAY: - return DistinctSelectNested(left, right, sel, count, true_sel, false_sel, null_mask); - default: - throw InternalException("Invalid type for distinct selection"); - } -} - -template -static void NestedDistinctExecute(Vector &left, Vector &right, Vector &result, idx_t count) { - const auto left_constant = left.GetVectorType() == VectorType::CONSTANT_VECTOR; - const auto right_constant = right.GetVectorType() == VectorType::CONSTANT_VECTOR; - - if (left_constant && right_constant) { - // both sides are constant, so just compare one element. 
- result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto result_data = ConstantVector::GetData(result); - SelectionVector true_sel(1); - auto match_count = TemplatedDistinctSelectOperation(left, right, nullptr, 1, &true_sel, nullptr, nullptr); - result_data[0] = match_count > 0; - return; - } - - SelectionVector true_sel(count); - SelectionVector false_sel(count); - - // DISTINCT is either true or false - idx_t match_count = - TemplatedDistinctSelectOperation(left, right, nullptr, count, &true_sel, &false_sel, nullptr); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < match_count; ++i) { - const auto idx = true_sel.get_index(i); - result_data[idx] = true; - } - - const idx_t no_match_count = count - match_count; - for (idx_t i = 0; i < no_match_count; ++i) { - const auto idx = false_sel.get_index(i); - result_data[idx] = false; - } -} - -void VectorOperations::DistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { - ExecuteDistinct(left, right, result, count); -} - -void VectorOperations::NotDistinctFrom(Vector &left, Vector &right, Vector &result, idx_t count) { - ExecuteDistinct(left, right, result, count); -} - -// true := A != B with nulls being equal -idx_t VectorOperations::DistinctFrom(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel, - nullptr); -} -// true := A == B with nulls being equal -idx_t VectorOperations::NotDistinctFrom(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel) { - return count - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel, - nullptr); -} - -// true := A > B with nulls being maximal -idx_t VectorOperations::DistinctGreaterThan(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel, - null_mask); -} - -// true := A > B with nulls being minimal -idx_t VectorOperations::DistinctGreaterThanNullsFirst(Vector &left, Vector &right, - optional_ptr sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, - false_sel, null_mask); -} - -// true := A >= B with nulls being maximal -idx_t VectorOperations::DistinctGreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return count - TemplatedDistinctSelectOperation(right, left, sel, count, false_sel, - true_sel, null_mask); -} -// true := A < B with nulls being maximal -idx_t VectorOperations::DistinctLessThan(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedDistinctSelectOperation(right, left, sel, count, true_sel, false_sel, - null_mask); -} - -// true := A < B with nulls being minimal -idx_t VectorOperations::DistinctLessThanNullsFirst(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedDistinctSelectOperation(right, left, sel, count, true_sel, - false_sel, nullptr); -} - -// true := A <= B with 
nulls being maximal -idx_t VectorOperations::DistinctLessThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return count - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, - true_sel, null_mask); -} - -// true := A != B with nulls being equal, inputs selected -idx_t VectorOperations::NestedNotEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, optional_ptr null_mask) { - return TemplatedDistinctSelectOperation(left, right, sel, count, true_sel, false_sel, - null_mask); -} -// true := A == B with nulls being equal, inputs selected -idx_t VectorOperations::NestedEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return count - TemplatedDistinctSelectOperation(left, right, sel, count, false_sel, true_sel, - null_mask); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/null_operations.cpp b/src/duckdb/src/common/vector_operations/null_operations.cpp deleted file mode 100644 index dd34ac8ed..000000000 --- a/src/duckdb/src/common/vector_operations/null_operations.cpp +++ /dev/null @@ -1,113 +0,0 @@ -//===--------------------------------------------------------------------===// -// null_operators.cpp -// Description: This file contains the implementation of the -// IS NULL/NOT IS NULL operators -//===--------------------------------------------------------------------===// - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -namespace duckdb { - -template -void IsNullLoop(Vector &input, Vector &result, idx_t count) { - D_ASSERT(result.GetType() == LogicalType::BOOLEAN); - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - auto result_data = ConstantVector::GetData(result); - *result_data = INVERSE ? !ConstantVector::IsNull(input) : ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - result_data[i] = INVERSE ? 
data.validity.RowIsValid(idx) : !data.validity.RowIsValid(idx); - } - } -} - -void VectorOperations::IsNotNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, count); -} - -void VectorOperations::IsNull(Vector &input, Vector &result, idx_t count) { - IsNullLoop(input, result, count); -} - -bool VectorOperations::HasNotNull(Vector &input, idx_t count) { - if (count == 0) { - return false; - } - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return !ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - if (data.validity.AllValid()) { - return true; - } - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - if (data.validity.RowIsValid(idx)) { - return true; - } - } - return false; - } -} - -bool VectorOperations::HasNull(Vector &input, idx_t count) { - if (count == 0) { - return false; - } - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - return ConstantVector::IsNull(input); - } else { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - - if (data.validity.AllValid()) { - return false; - } - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - if (!data.validity.RowIsValid(idx)) { - return true; - } - } - return false; - } -} - -idx_t VectorOperations::CountNotNull(Vector &input, const idx_t count) { - idx_t valid = 0; - - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(count, vdata); - if (vdata.validity.AllValid()) { - return count; - } - switch (input.GetVectorType()) { - case VectorType::FLAT_VECTOR: - valid += vdata.validity.CountValid(count); - break; - case VectorType::CONSTANT_VECTOR: - valid += vdata.validity.CountValid(1) * count; - break; - default: - for (idx_t i = 0; i < count; ++i) { - const auto row_idx = vdata.sel->get_index(i); - valid += idx_t(vdata.validity.RowIsValid(row_idx)); - } - break; - } - - return valid; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp b/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp deleted file mode 100644 index 863f3ba8d..000000000 --- a/src/duckdb/src/common/vector_operations/numeric_inplace_operators.cpp +++ /dev/null @@ -1,40 +0,0 @@ -//===--------------------------------------------------------------------===// -// numeric_inplace_operators.cpp -// Description: This file contains the implementation of numeric inplace ops -// += *= /= -= %= -//===--------------------------------------------------------------------===// - -#include "duckdb/common/vector_operations/vector_operations.hpp" - -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// In-Place Addition -//===--------------------------------------------------------------------===// - -void VectorOperations::AddInPlace(Vector &input, int64_t right, idx_t count) { - D_ASSERT(input.GetType().id() == LogicalTypeId::POINTER); - if (right == 0) { - return; - } - switch (input.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - D_ASSERT(!ConstantVector::IsNull(input)); - auto data = ConstantVector::GetData(input); - *data += UnsafeNumericCast(right); - break; - } - default: { - D_ASSERT(input.GetVectorType() == VectorType::FLAT_VECTOR); - auto data = FlatVector::GetData(input); - for (idx_t i = 0; i < count; i++) { - data[i] = UnsafeNumericCast(UnsafeNumericCast(data[i]) + right); - } - break; - } - } -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/common/vector_operations/vector_cast.cpp b/src/duckdb/src/common/vector_operations/vector_cast.cpp deleted file mode 100644 index e49e51361..000000000 --- a/src/duckdb/src/common/vector_operations/vector_cast.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "duckdb/common/limits.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/function/scalar_function.hpp" - -namespace duckdb { - -bool VectorOperations::TryCast(CastFunctionSet &set, GetCastFunctionInput &input, Vector &source, Vector &result, - idx_t count, string *error_message, bool strict, const bool nullify_parent) { - auto cast_function = set.GetCastFunction(source.GetType(), result.GetType(), input); - unique_ptr local_state; - if (cast_function.init_local_state) { - CastLocalStateParameters lparameters(input.context, cast_function.cast_data); - local_state = cast_function.init_local_state(lparameters); - } - CastParameters parameters(cast_function.cast_data.get(), strict, error_message, local_state.get(), nullify_parent); - return cast_function.function(source, result, count, parameters); -} - -bool VectorOperations::DefaultTryCast(Vector &source, Vector &result, idx_t count, string *error_message, bool strict) { - CastFunctionSet set; - GetCastFunctionInput input; - return VectorOperations::TryCast(set, input, source, result, count, error_message, strict); -} - -void VectorOperations::DefaultCast(Vector &source, Vector &result, idx_t count, bool strict) { - VectorOperations::DefaultTryCast(source, result, count, nullptr, strict); -} - -bool VectorOperations::TryCast(ClientContext &context, Vector &source, Vector &result, idx_t count, - string *error_message, bool strict, const bool nullify_parent) { - auto &config = DBConfig::GetConfig(context); - auto &set = config.GetCastFunctions(); - GetCastFunctionInput get_input(context); - return VectorOperations::TryCast(set, get_input, source, result, count, error_message, strict, nullify_parent); -} - -void VectorOperations::Cast(ClientContext &context, Vector &source, Vector &result, idx_t count, bool strict) { - VectorOperations::TryCast(context, source, result, count, nullptr, strict); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_copy.cpp b/src/duckdb/src/common/vector_operations/vector_copy.cpp deleted file mode 100644 index 0880f23be..000000000 --- a/src/duckdb/src/common/vector_operations/vector_copy.cpp +++ /dev/null @@ -1,281 +0,0 @@ -//===--------------------------------------------------------------------===// -// copy.cpp -// Description: This file contains the implementation of the different copy -// functions -//===--------------------------------------------------------------------===// - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/storage/segment/uncompressed.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -namespace duckdb { - -template -static void TemplatedCopy(const Vector &source, const SelectionVector &sel, Vector &target, idx_t source_offset, - idx_t target_offset, idx_t copy_count) { - auto ldata = FlatVector::GetData(source); - auto tdata = FlatVector::GetData(target); - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel.get_index(source_offset + i); - tdata[target_offset + i] = ldata[source_idx]; - } -} - -static const ValidityMask 
&ExtractValidityMask(const Vector &v) { - switch (v.GetVectorType()) { - case VectorType::FLAT_VECTOR: - return FlatVector::Validity(v); - case VectorType::FSST_VECTOR: - return FSSTVector::Validity(v); - default: - throw InternalException("Unsupported vector type in vector copy"); - } -} - -void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, - idx_t source_offset, idx_t target_offset, idx_t copy_count) { - - SelectionVector owned_sel; - const SelectionVector *sel = &sel_p; - - const Vector *source = &source_p; - bool finished = false; - while (!finished) { - switch (source->GetVectorType()) { - case VectorType::DICTIONARY_VECTOR: { - // dictionary vector: merge selection vectors - auto &child = DictionaryVector::Child(*source); - auto &dict_sel = DictionaryVector::SelVector(*source); - // merge the selection vectors and verify the child - auto new_buffer = dict_sel.Slice(*sel, source_count); - owned_sel.Initialize(new_buffer); - sel = &owned_sel; - source = &child; - break; - } - case VectorType::SEQUENCE_VECTOR: { - int64_t start, increment; - Vector seq(source->GetType()); - SequenceVector::GetSequence(*source, start, increment); - VectorOperations::GenerateSequence(seq, source_count, *sel, start, increment); - VectorOperations::Copy(seq, target, *sel, source_count, source_offset, target_offset); - return; - } - case VectorType::CONSTANT_VECTOR: - sel = ConstantVector::ZeroSelectionVector(copy_count, owned_sel); - finished = true; - break; - case VectorType::FSST_VECTOR: - finished = true; - break; - case VectorType::FLAT_VECTOR: - finished = true; - break; - default: - throw NotImplementedException("FIXME unimplemented vector type for VectorOperations::Copy"); - } - } - - if (copy_count == 0) { - return; - } - - // Allow copying of a single value to constant vectors - const auto target_vector_type = target.GetVectorType(); - if (copy_count == 1 && target_vector_type == VectorType::CONSTANT_VECTOR) { - target_offset = 0; - target.SetVectorType(VectorType::FLAT_VECTOR); - } - D_ASSERT(target.GetVectorType() == VectorType::FLAT_VECTOR); - - // first copy the nullmask - auto &tmask = FlatVector::Validity(target); - if (source->GetVectorType() == VectorType::CONSTANT_VECTOR) { - const bool valid = !ConstantVector::IsNull(*source); - for (idx_t i = 0; i < copy_count; i++) { - tmask.Set(target_offset + i, valid); - } - } else { - auto &smask = ExtractValidityMask(*source); - tmask.CopySel(smask, *sel, source_offset, target_offset, copy_count); - } - - D_ASSERT(sel); - - // For FSST Vectors we decompress instead of copying. 
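// (Annotation, not part of the deleted file: FSST vectors store compressed
// strings plus a symbol table, so copying their raw payload into a flat
// target would leak compressed bytes. Instead, the branch below has
// DecompressVector materialize the selected rows directly into `target` at
// `target_offset`, after which Copy can return early.)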
-	if (source->GetVectorType() == VectorType::FSST_VECTOR) {
-		FSSTVector::DecompressVector(*source, target, source_offset, target_offset, copy_count, sel);
-		return;
-	}
-
-	// now copy over the data
-	switch (source->GetType().InternalType()) {
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		TemplatedCopy<int8_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::INT16:
-		TemplatedCopy<int16_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::INT32:
-		TemplatedCopy<int32_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::INT64:
-		TemplatedCopy<int64_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::UINT8:
-		TemplatedCopy<uint8_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::UINT16:
-		TemplatedCopy<uint16_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::UINT32:
-		TemplatedCopy<uint32_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::UINT64:
-		TemplatedCopy<uint64_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::INT128:
-		TemplatedCopy<hugeint_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::UINT128:
-		TemplatedCopy<uhugeint_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::FLOAT:
-		TemplatedCopy<float>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::DOUBLE:
-		TemplatedCopy<double>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::INTERVAL:
-		TemplatedCopy<interval_t>(*source, *sel, target, source_offset, target_offset, copy_count);
-		break;
-	case PhysicalType::VARCHAR: {
-		auto ldata = FlatVector::GetData<string_t>(*source);
-		auto tdata = FlatVector::GetData<string_t>(target);
-		for (idx_t i = 0; i < copy_count; i++) {
-			auto source_idx = sel->get_index(source_offset + i);
-			auto target_idx = target_offset + i;
-			if (tmask.RowIsValid(target_idx)) {
-				tdata[target_idx] = StringVector::AddStringOrBlob(target, ldata[source_idx]);
-			}
-		}
-		break;
-	}
-	case PhysicalType::STRUCT: {
-		auto &source_children = StructVector::GetEntries(*source);
-		auto &target_children = StructVector::GetEntries(target);
-		D_ASSERT(source_children.size() == target_children.size());
-		for (idx_t i = 0; i < source_children.size(); i++) {
-			VectorOperations::Copy(*source_children[i], *target_children[i], sel_p, source_count, source_offset,
-			                       target_offset, copy_count);
-		}
-		break;
-	}
-	case PhysicalType::ARRAY: {
-		D_ASSERT(target.GetType().InternalType() == PhysicalType::ARRAY);
-		D_ASSERT(ArrayType::GetSize(source->GetType()) == ArrayType::GetSize(target.GetType()));
-
-		auto &source_child = ArrayVector::GetEntry(*source);
-		auto &target_child = ArrayVector::GetEntry(target);
-		auto array_size = ArrayType::GetSize(source->GetType());
-
-		// Create a selection vector for the child elements
-		SelectionVector child_sel(source_count * array_size);
-		for (idx_t i = 0; i < copy_count; i++) {
-			auto source_idx = sel->get_index(source_offset + i);
-			for (idx_t j = 0; j < array_size; j++) {
-				child_sel.set_index((source_offset * array_size) + (i * array_size + j), source_idx * array_size + j);
-			}
-		}
-		VectorOperations::Copy(source_child, target_child, child_sel, source_count * array_size,
-		                       source_offset * array_size, target_offset * array_size);
-		break;
-	}
-	case
PhysicalType::LIST: { - D_ASSERT(target.GetType().InternalType() == PhysicalType::LIST); - - auto &source_child = ListVector::GetEntry(*source); - auto sdata = FlatVector::GetData(*source); - auto tdata = FlatVector::GetData(target); - - if (target_vector_type == VectorType::CONSTANT_VECTOR) { - // If we are only writing one value, then the copied values (if any) are contiguous - // and we can just Append from the offset position - if (!tmask.RowIsValid(target_offset)) { - break; - } - auto source_idx = sel->get_index(source_offset); - auto &source_entry = sdata[source_idx]; - const idx_t source_child_size = source_entry.length + source_entry.offset; - - //! overwrite constant target vectors. - ListVector::SetListSize(target, 0); - ListVector::Append(target, source_child, source_child_size, source_entry.offset); - - auto &target_entry = tdata[target_offset]; - target_entry.length = source_entry.length; - target_entry.offset = 0; - } else { - //! if the source has list offsets, we need to append them to the target - //! build a selection vector for the copied child elements - vector child_rows; - for (idx_t i = 0; i < copy_count; ++i) { - if (tmask.RowIsValid(target_offset + i)) { - auto source_idx = sel->get_index(source_offset + i); - auto &source_entry = sdata[source_idx]; - for (idx_t j = 0; j < source_entry.length; ++j) { - child_rows.emplace_back(source_entry.offset + j); - } - } - } - idx_t source_child_size = child_rows.size(); - SelectionVector child_sel(child_rows.data()); - - idx_t old_target_child_len = ListVector::GetListSize(target); - - //! append to list itself - ListVector::Append(target, source_child, child_sel, source_child_size); - - //! now write the list offsets - for (idx_t i = 0; i < copy_count; i++) { - auto source_idx = sel->get_index(source_offset + i); - auto &source_entry = sdata[source_idx]; - auto &target_entry = tdata[target_offset + i]; - - target_entry.length = source_entry.length; - target_entry.offset = old_target_child_len; - if (tmask.RowIsValid(target_offset + i)) { - old_target_child_len += target_entry.length; - } - } - } - break; - } - default: - throw NotImplementedException("Unimplemented type '%s' for copy!", - TypeIdToString(source->GetType().InternalType())); - } - - if (target_vector_type != VectorType::FLAT_VECTOR) { - target.SetVectorType(target_vector_type); - } -} - -void VectorOperations::Copy(const Vector &source_p, Vector &target, const SelectionVector &sel_p, idx_t source_count, - idx_t source_offset, idx_t target_offset) { - D_ASSERT(source_offset <= source_count); - D_ASSERT(source_p.GetType() == target.GetType()); - idx_t copy_count = source_count - source_offset; - VectorOperations::Copy(source_p, target, sel_p, source_count, source_offset, target_offset, copy_count); -} - -void VectorOperations::Copy(const Vector &source, Vector &target, idx_t source_count, idx_t source_offset, - idx_t target_offset) { - VectorOperations::Copy(source, target, *FlatVector::IncrementalSelectionVector(), source_count, source_offset, - target_offset); -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/vector_operations/vector_hash.cpp b/src/duckdb/src/common/vector_operations/vector_hash.cpp deleted file mode 100644 index c82422c23..000000000 --- a/src/duckdb/src/common/vector_operations/vector_hash.cpp +++ /dev/null @@ -1,465 +0,0 @@ -//===--------------------------------------------------------------------===// -// hash.cpp -// Description: This file contains the vectorized hash implementations 
-//===--------------------------------------------------------------------===// - -#include "duckdb/common/types/hash.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" - -namespace duckdb { - -struct HashOp { - static const hash_t NULL_HASH = 0xbf58476d1ce4e5b9; - - template - static inline hash_t Operation(T input, bool is_null) { - return is_null ? NULL_HASH : duckdb::Hash(input); - } -}; - -static inline hash_t CombineHashScalar(hash_t a, hash_t b) { - a ^= a >> 32; - a *= 0xd6e8feb86659fd93U; - return a ^ b; -} - -template -static inline void TightLoopHash(const T *__restrict ldata, hash_t *__restrict result_data, const SelectionVector *rsel, - idx_t count, const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - result_data[ridx] = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - result_data[ridx] = duckdb::Hash(ldata[idx]); - } - } -} - -template -static inline void TemplatedLoopHash(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto ldata = ConstantVector::GetData(input); - auto result_data = ConstantVector::GetData(result); - *result_data = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); - } else { - result.SetVectorType(VectorType::FLAT_VECTOR); - - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - TightLoopHash(UnifiedVectorFormat::GetData(idata), FlatVector::GetData(result), rsel, - count, idata.sel, idata.validity); - } -} - -template -static inline void StructLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - auto &children = StructVector::GetEntries(input); - - D_ASSERT(!children.empty()); - idx_t col_no = 0; - if (HAS_RSEL) { - if (FIRST_HASH) { - VectorOperations::Hash(*children[col_no++], hashes, *rsel, count); - } else { - VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); - } - while (col_no < children.size()) { - VectorOperations::CombineHash(hashes, *children[col_no++], *rsel, count); - } - } else { - if (FIRST_HASH) { - VectorOperations::Hash(*children[col_no++], hashes, count); - } else { - VectorOperations::CombineHash(hashes, *children[col_no++], count); - } - while (col_no < children.size()) { - VectorOperations::CombineHash(hashes, *children[col_no++], count); - } - } -} - -template -static inline void ListLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - // FIXME: if we want to be more efficient we shouldn't flatten, but the logic here currently requires it - hashes.Flatten(count); - auto hdata = FlatVector::GetData(hashes); - - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - const auto ldata = UnifiedVectorFormat::GetData(idata); - - // Hash the children into a temporary - auto &child = ListVector::GetEntry(input); - const auto child_count = ListVector::GetListSize(input); - - Vector child_hashes(LogicalType::HASH, child_count); - if (child_count > 0) { - VectorOperations::Hash(child, 
child_hashes, child_count); - child_hashes.Flatten(child_count); - } - auto chdata = FlatVector::GetData(child_hashes); - - // Reduce the number of entries to check to the non-empty ones - SelectionVector unprocessed(count); - SelectionVector cursor(HAS_RSEL ? STANDARD_VECTOR_SIZE : count); - idx_t remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const idx_t ridx = HAS_RSEL ? rsel->get_index(i) : i; - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = ldata[lidx]; - if (idata.validity.RowIsValid(lidx) && entry.length > 0) { - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, entry.offset); - } else if (FIRST_HASH) { - hdata[ridx] = HashOp::NULL_HASH; - } - // Empty or NULL non-first elements have no effect. - } - - count = remaining; - if (count == 0) { - return; - } - - // Merge the first position hash into the main hash - idx_t position = 1; - if (FIRST_HASH) { - remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const auto ridx = unprocessed.get_index(i); - const auto cidx = cursor.get_index(ridx); - hdata[ridx] = chdata[cidx]; - - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = ldata[lidx]; - if (entry.length > position) { - // Entry still has values to hash - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, cidx + 1); - } - } - count = remaining; - if (count == 0) { - return; - } - ++position; - } - - // Combine the hashes for the remaining positions until there are none left - for (;; ++position) { - remaining = 0; - for (idx_t i = 0; i < count; ++i) { - const auto ridx = unprocessed.get_index(i); - const auto cidx = cursor.get_index(ridx); - hdata[ridx] = CombineHashScalar(hdata[ridx], chdata[cidx]); - - const auto lidx = idata.sel->get_index(ridx); - const auto &entry = ldata[lidx]; - if (entry.length > position) { - // Entry still has values to hash - unprocessed.set_index(remaining++, ridx); - cursor.set_index(ridx, cidx + 1); - } - } - - count = remaining; - if (count == 0) { - break; - } - } -} - -template -static inline void ArrayLoopHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - hashes.Flatten(count); - auto hdata = FlatVector::GetData(hashes); - - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - // Hash the children into a temporary - auto &child = ArrayVector::GetEntry(input); - auto array_size = ArrayType::GetSize(input.GetType()); - - auto is_flat = input.GetVectorType() == VectorType::FLAT_VECTOR; - auto is_constant = input.GetVectorType() == VectorType::CONSTANT_VECTOR; - - if (!HAS_RSEL && (is_flat || is_constant)) { - // Fast path for contiguous vectors with no selection vector - auto child_count = array_size * (is_constant ? 1 : count); - - Vector child_hashes(LogicalType::HASH, child_count); - VectorOperations::Hash(child, child_hashes, child_count); - child_hashes.Flatten(child_count); - auto chdata = FlatVector::GetData(child_hashes); - - for (idx_t i = 0; i < count; i++) { - auto lidx = idata.sel->get_index(i); - if (idata.validity.RowIsValid(lidx)) { - if (FIRST_HASH) { - hdata[i] = 0; - } - for (idx_t j = 0; j < array_size; j++) { - auto offset = lidx * array_size + j; - hdata[i] = CombineHashScalar(hdata[i], chdata[offset]); - } - } else if (FIRST_HASH) { - hdata[i] = HashOp::NULL_HASH; - } - } - } else { - // Hash the arrays one-by-one - SelectionVector array_sel(array_size); - Vector array_hashes(LogicalType::HASH, array_size); - for (idx_t i = 0; i < count; i++) { - const auto ridx = HAS_RSEL ? 
rsel->get_index(i) : i; - const auto lidx = idata.sel->get_index(ridx); - - if (idata.validity.RowIsValid(lidx)) { - // Create a selection vector for the array - for (idx_t j = 0; j < array_size; j++) { - array_sel.set_index(j, lidx * array_size + j); - } - - // Hash the array slice - Vector dict_vec(child, array_sel, array_size); - VectorOperations::Hash(dict_vec, array_hashes, array_size); - auto ahdata = FlatVector::GetData(array_hashes); - - if (FIRST_HASH) { - hdata[ridx] = 0; - } - // Combine the hashes of the array - for (idx_t j = 0; j < array_size; j++) { - hdata[ridx] = CombineHashScalar(hdata[ridx], ahdata[j]); - // Clear the hash for the next iteration - ahdata[j] = 0; - } - } else if (FIRST_HASH) { - hdata[ridx] = HashOp::NULL_HASH; - } - } - } -} - -template -static inline void HashTypeSwitch(Vector &input, Vector &result, const SelectionVector *rsel, idx_t count) { - D_ASSERT(result.GetType().id() == LogicalType::HASH); - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT16: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT32: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT64: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT8: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT16: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT32: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT64: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INT128: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::UINT128: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::FLOAT: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::DOUBLE: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::INTERVAL: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::VARCHAR: - TemplatedLoopHash(input, result, rsel, count); - break; - case PhysicalType::STRUCT: - StructLoopHash(input, result, rsel, count); - break; - case PhysicalType::LIST: - ListLoopHash(input, result, rsel, count); - break; - case PhysicalType::ARRAY: - ArrayLoopHash(input, result, rsel, count); - break; - default: - throw InvalidTypeException(input.GetType(), "Invalid type for hash"); - } -} - -void VectorOperations::Hash(Vector &input, Vector &result, idx_t count) { - HashTypeSwitch(input, result, nullptr, count); -} - -void VectorOperations::Hash(Vector &input, Vector &result, const SelectionVector &sel, idx_t count) { - HashTypeSwitch(input, result, &sel, count); -} - -template -static inline void TightLoopCombineHashConstant(const T *__restrict ldata, hash_t constant_hash, - hash_t *__restrict hash_data, const SelectionVector *rsel, idx_t count, - const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? 
rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = duckdb::Hash(ldata[idx]); - hash_data[ridx] = CombineHashScalar(constant_hash, other_hash); - } - } -} - -template -static inline void TightLoopCombineHash(const T *__restrict ldata, hash_t *__restrict hash_data, - const SelectionVector *rsel, idx_t count, - const SelectionVector *__restrict sel_vector, ValidityMask &mask) { - if (!mask.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = HashOp::Operation(ldata[idx], !mask.RowIsValid(idx)); - hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); - } - } else { - for (idx_t i = 0; i < count; i++) { - auto ridx = HAS_RSEL ? rsel->get_index(i) : i; - auto idx = sel_vector->get_index(ridx); - auto other_hash = duckdb::Hash(ldata[idx]); - hash_data[ridx] = CombineHashScalar(hash_data[ridx], other_hash); - } - } -} - -template -void TemplatedLoopCombineHash(Vector &input, Vector &hashes, const SelectionVector *rsel, idx_t count) { - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR && hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto ldata = ConstantVector::GetData(input); - auto hash_data = ConstantVector::GetData(hashes); - - auto other_hash = HashOp::Operation(*ldata, ConstantVector::IsNull(input)); - *hash_data = CombineHashScalar(*hash_data, other_hash); - } else { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - if (hashes.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // mix constant with non-constant, first get the constant value - auto constant_hash = *ConstantVector::GetData(hashes); - // now re-initialize the hashes vector to an empty flat vector - hashes.SetVectorType(VectorType::FLAT_VECTOR); - TightLoopCombineHashConstant(UnifiedVectorFormat::GetData(idata), constant_hash, - FlatVector::GetData(hashes), rsel, count, idata.sel, - idata.validity); - } else { - D_ASSERT(hashes.GetVectorType() == VectorType::FLAT_VECTOR); - TightLoopCombineHash(UnifiedVectorFormat::GetData(idata), - FlatVector::GetData(hashes), rsel, count, idata.sel, - idata.validity); - } - } -} - -template -static inline void CombineHashTypeSwitch(Vector &hashes, Vector &input, const SelectionVector *rsel, idx_t count) { - D_ASSERT(hashes.GetType().id() == LogicalType::HASH); - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT16: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT32: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT64: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT8: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT16: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT32: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT64: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::INT128: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::UINT128: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::FLOAT: - TemplatedLoopCombineHash(input, hashes, rsel, count); - break; - case PhysicalType::DOUBLE: - TemplatedLoopCombineHash(input, 
hashes, rsel, count);
-		break;
-	case PhysicalType::INTERVAL:
-		TemplatedLoopCombineHash<HAS_RSEL, interval_t>(input, hashes, rsel, count);
-		break;
-	case PhysicalType::VARCHAR:
-		TemplatedLoopCombineHash<HAS_RSEL, string_t>(input, hashes, rsel, count);
-		break;
-	case PhysicalType::STRUCT:
-		StructLoopHash<HAS_RSEL, false>(input, hashes, rsel, count);
-		break;
-	case PhysicalType::LIST:
-		ListLoopHash<HAS_RSEL, false>(input, hashes, rsel, count);
-		break;
-	case PhysicalType::ARRAY:
-		ArrayLoopHash<HAS_RSEL, false>(input, hashes, rsel, count);
-		break;
-	default:
-		throw InvalidTypeException(input.GetType(), "Invalid type for hash");
-	}
-}
-
-void VectorOperations::CombineHash(Vector &hashes, Vector &input, idx_t count) {
-	CombineHashTypeSwitch<false>(hashes, input, nullptr, count);
-}
-
-void VectorOperations::CombineHash(Vector &hashes, Vector &input, const SelectionVector &rsel, idx_t count) {
-	CombineHashTypeSwitch<true>(hashes, input, &rsel, count);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/vector_operations/vector_storage.cpp b/src/duckdb/src/common/vector_operations/vector_storage.cpp
deleted file mode 100644
index 9c399519d..000000000
--- a/src/duckdb/src/common/vector_operations/vector_storage.cpp
+++ /dev/null
@@ -1,132 +0,0 @@
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/types/null_value.hpp"
-#include "duckdb/common/uhugeint.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-
-namespace duckdb {
-
-template <class T>
-static void CopyToStorageLoop(UnifiedVectorFormat &vdata, idx_t count, data_ptr_t target) {
-	auto ldata = UnifiedVectorFormat::GetData<T>(vdata);
-	auto result_data = (T *)target;
-	for (idx_t i = 0; i < count; i++) {
-		auto idx = vdata.sel->get_index(i);
-		if (!vdata.validity.RowIsValid(idx)) {
-			result_data[i] = NullValue<T>();
-		} else {
-			result_data[i] = ldata[idx];
-		}
-	}
-}
-
-void VectorOperations::WriteToStorage(Vector &source, idx_t count, data_ptr_t target) {
-	if (count == 0) {
-		return;
-	}
-	UnifiedVectorFormat vdata;
-	source.ToUnifiedFormat(count, vdata);
-
-	switch (source.GetType().InternalType()) {
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		CopyToStorageLoop<int8_t>(vdata, count, target);
-		break;
-	case PhysicalType::INT16:
-		CopyToStorageLoop<int16_t>(vdata, count, target);
-		break;
-	case PhysicalType::INT32:
-		CopyToStorageLoop<int32_t>(vdata, count, target);
-		break;
-	case PhysicalType::INT64:
-		CopyToStorageLoop<int64_t>(vdata, count, target);
-		break;
-	case PhysicalType::UINT8:
-		CopyToStorageLoop<uint8_t>(vdata, count, target);
-		break;
-	case PhysicalType::UINT16:
-		CopyToStorageLoop<uint16_t>(vdata, count, target);
-		break;
-	case PhysicalType::UINT32:
-		CopyToStorageLoop<uint32_t>(vdata, count, target);
-		break;
-	case PhysicalType::UINT64:
-		CopyToStorageLoop<uint64_t>(vdata, count, target);
-		break;
-	case PhysicalType::INT128:
-		CopyToStorageLoop<hugeint_t>(vdata, count, target);
-		break;
-	case PhysicalType::UINT128:
-		CopyToStorageLoop<uhugeint_t>(vdata, count, target);
-		break;
-	case PhysicalType::FLOAT:
-		CopyToStorageLoop<float>(vdata, count, target);
-		break;
-	case PhysicalType::DOUBLE:
-		CopyToStorageLoop<double>(vdata, count, target);
-		break;
-	case PhysicalType::INTERVAL:
-		CopyToStorageLoop<interval_t>(vdata, count, target);
-		break;
-	default:
-		throw NotImplementedException("Unimplemented type for WriteToStorage");
-	}
-}
-
-template <class T>
-static void ReadFromStorageLoop(data_ptr_t source, idx_t count, Vector &result) {
-	auto ldata = (T *)source;
-	auto result_data = FlatVector::GetData<T>(result);
-	for (idx_t i = 0; i < count; i++) {
-		result_data[i] = ldata[i];
-	}
-}
-
-void VectorOperations::ReadFromStorage(data_ptr_t source, idx_t count, Vector &result) {
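// (Annotation, not part of the deleted file: ReadFromStorage is the inverse
// of WriteToStorage above. A minimal round-trip sketch, with buffer sizing
// and a populated `in` vector assumed:
//   Vector in(LogicalType::INTEGER), out(LogicalType::INTEGER);
//   auto buf = make_unsafe_uniq_array<data_t>(count * sizeof(int32_t));
//   VectorOperations::WriteToStorage(in, count, buf.get());
//   VectorOperations::ReadFromStorage(buf.get(), count, out);
// Validity is not round-tripped: WriteToStorage encodes NULL rows as the
// NullValue<T> sentinel, and ReadFromStorage restores values only.)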
-	result.SetVectorType(VectorType::FLAT_VECTOR);
-	switch (result.GetType().InternalType()) {
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		ReadFromStorageLoop<int8_t>(source, count, result);
-		break;
-	case PhysicalType::INT16:
-		ReadFromStorageLoop<int16_t>(source, count, result);
-		break;
-	case PhysicalType::INT32:
-		ReadFromStorageLoop<int32_t>(source, count, result);
-		break;
-	case PhysicalType::INT64:
-		ReadFromStorageLoop<int64_t>(source, count, result);
-		break;
-	case PhysicalType::UINT8:
-		ReadFromStorageLoop<uint8_t>(source, count, result);
-		break;
-	case PhysicalType::UINT16:
-		ReadFromStorageLoop<uint16_t>(source, count, result);
-		break;
-	case PhysicalType::UINT32:
-		ReadFromStorageLoop<uint32_t>(source, count, result);
-		break;
-	case PhysicalType::UINT64:
-		ReadFromStorageLoop<uint64_t>(source, count, result);
-		break;
-	case PhysicalType::INT128:
-		ReadFromStorageLoop<hugeint_t>(source, count, result);
-		break;
-	case PhysicalType::UINT128:
-		ReadFromStorageLoop<uhugeint_t>(source, count, result);
-		break;
-	case PhysicalType::FLOAT:
-		ReadFromStorageLoop<float>(source, count, result);
-		break;
-	case PhysicalType::DOUBLE:
-		ReadFromStorageLoop<double>(source, count, result);
-		break;
-	case PhysicalType::INTERVAL:
-		ReadFromStorageLoop<interval_t>(source, count, result);
-		break;
-	default:
-		throw NotImplementedException("Unimplemented type for ReadFromStorage");
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/common/virtual_file_system.cpp b/src/duckdb/src/common/virtual_file_system.cpp
deleted file mode 100644
index 74892a4e0..000000000
--- a/src/duckdb/src/common/virtual_file_system.cpp
+++ /dev/null
@@ -1,198 +0,0 @@
-#include "duckdb/common/virtual_file_system.hpp"
-#include "duckdb/common/gzip_file_system.hpp"
-#include "duckdb/common/pipe_file_system.hpp"
-#include "duckdb/common/string_util.hpp"
-
-namespace duckdb {
-
-VirtualFileSystem::VirtualFileSystem() : default_fs(FileSystem::CreateLocal()) {
-	VirtualFileSystem::RegisterSubSystem(FileCompressionType::GZIP, make_uniq<GZipFileSystem>());
-}
-
-unique_ptr<FileHandle> VirtualFileSystem::OpenFile(const string &path, FileOpenFlags flags,
-                                                   optional_ptr<FileOpener> opener) {
-	auto compression = flags.Compression();
-	if (compression == FileCompressionType::AUTO_DETECT) {
-		// auto-detect compression settings based on file name
-		auto lower_path = StringUtil::Lower(path);
-		if (StringUtil::EndsWith(lower_path, ".tmp")) {
-			// strip .tmp
-			lower_path = lower_path.substr(0, lower_path.length() - 4);
-		}
-		if (IsFileCompressed(path, FileCompressionType::GZIP)) {
-			compression = FileCompressionType::GZIP;
-		} else if (IsFileCompressed(path, FileCompressionType::ZSTD)) {
-			compression = FileCompressionType::ZSTD;
-		} else {
-			compression = FileCompressionType::UNCOMPRESSED;
-		}
-	}
-	// open the base file handle in UNCOMPRESSED mode
-	flags.SetCompression(FileCompressionType::UNCOMPRESSED);
-	auto file_handle = FindFileSystem(path).OpenFile(path, flags, opener);
-	if (!file_handle) {
-		return nullptr;
-	}
-	if (file_handle->GetType() == FileType::FILE_TYPE_FIFO) {
-		file_handle = PipeFileSystem::OpenPipe(std::move(file_handle));
-	} else if (compression != FileCompressionType::UNCOMPRESSED) {
-		auto entry = compressed_fs.find(compression);
-		if (entry == compressed_fs.end()) {
-			throw NotImplementedException(
-			    "Attempting to open a compressed file, but the compression type is not supported");
-		}
-		file_handle = entry->second->OpenCompressedFile(std::move(file_handle), flags.OpenForWriting());
-	}
-	return file_handle;
-}
-
-void VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) {
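// (Annotation, not part of the deleted file: the handle-based operations
// below dispatch through handle.file_system rather than FindFileSystem(path).
// Once OpenFile has wrapped the handle (gzip/zstd decompression, pipe), the
// wrapping filesystem must service all subsequent reads, writes, and seeks,
// so the original path string is no longer consulted.)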
handle.file_system.Read(handle, buffer, nr_bytes, location); -} - -void VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - handle.file_system.Write(handle, buffer, nr_bytes, location); -} - -int64_t VirtualFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - return handle.file_system.Read(handle, buffer, nr_bytes); -} - -int64_t VirtualFileSystem::Write(FileHandle &handle, void *buffer, int64_t nr_bytes) { - return handle.file_system.Write(handle, buffer, nr_bytes); -} - -int64_t VirtualFileSystem::GetFileSize(FileHandle &handle) { - return handle.file_system.GetFileSize(handle); -} -time_t VirtualFileSystem::GetLastModifiedTime(FileHandle &handle) { - return handle.file_system.GetLastModifiedTime(handle); -} -FileType VirtualFileSystem::GetFileType(FileHandle &handle) { - return handle.file_system.GetFileType(handle); -} - -void VirtualFileSystem::Truncate(FileHandle &handle, int64_t new_size) { - handle.file_system.Truncate(handle, new_size); -} - -void VirtualFileSystem::FileSync(FileHandle &handle) { - handle.file_system.FileSync(handle); -} - -// need to look up correct fs for this -bool VirtualFileSystem::DirectoryExists(const string &directory, optional_ptr opener) { - return FindFileSystem(directory).DirectoryExists(directory, opener); -} -void VirtualFileSystem::CreateDirectory(const string &directory, optional_ptr opener) { - FindFileSystem(directory).CreateDirectory(directory, opener); -} - -void VirtualFileSystem::RemoveDirectory(const string &directory, optional_ptr opener) { - FindFileSystem(directory).RemoveDirectory(directory, opener); -} - -bool VirtualFileSystem::ListFiles(const string &directory, const std::function &callback, - FileOpener *opener) { - return FindFileSystem(directory).ListFiles(directory, callback, opener); -} - -void VirtualFileSystem::MoveFile(const string &source, const string &target, optional_ptr opener) { - FindFileSystem(source).MoveFile(source, target, opener); -} - -bool VirtualFileSystem::FileExists(const string &filename, optional_ptr opener) { - return FindFileSystem(filename).FileExists(filename, opener); -} - -bool VirtualFileSystem::IsPipe(const string &filename, optional_ptr opener) { - return FindFileSystem(filename).IsPipe(filename, opener); -} - -void VirtualFileSystem::RemoveFile(const string &filename, optional_ptr opener) { - FindFileSystem(filename).RemoveFile(filename, opener); -} - -string VirtualFileSystem::PathSeparator(const string &path) { - return FindFileSystem(path).PathSeparator(path); -} - -vector VirtualFileSystem::Glob(const string &path, FileOpener *opener) { - return FindFileSystem(path).Glob(path, opener); -} - -void VirtualFileSystem::RegisterSubSystem(unique_ptr fs) { - sub_systems.push_back(std::move(fs)); -} - -void VirtualFileSystem::UnregisterSubSystem(const string &name) { - for (auto sub_system = sub_systems.begin(); sub_system != sub_systems.end(); sub_system++) { - if (sub_system->get()->GetName() == name) { - sub_systems.erase(sub_system); - return; - } - } - throw InvalidInputException("Could not find filesystem with name %s", name); -} - -void VirtualFileSystem::RegisterSubSystem(FileCompressionType compression_type, unique_ptr fs) { - compressed_fs[compression_type] = std::move(fs); -} - -vector VirtualFileSystem::ListSubSystems() { - vector names(sub_systems.size()); - for (idx_t i = 0; i < sub_systems.size(); i++) { - names[i] = sub_systems[i]->GetName(); - } - return names; -} - -std::string VirtualFileSystem::GetName() const { - 
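// (Annotation, not part of the deleted file: names are how filesystems are
// identified in this class. UnregisterSubSystem and SetDisabledFileSystems
// below match on GetName(), so disabling e.g. "LocalFileSystem" (the assumed
// name of the default_fs created by FileSystem::CreateLocal) would block the
// fallback in FindFileSystem.)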
return "VirtualFileSystem"; -} - -void VirtualFileSystem::SetDisabledFileSystems(const vector &names) { - unordered_set new_disabled_file_systems; - for (auto &name : names) { - if (name.empty()) { - continue; - } - if (new_disabled_file_systems.find(name) != new_disabled_file_systems.end()) { - throw InvalidInputException("Duplicate disabled file system \"%s\"", name); - } - new_disabled_file_systems.insert(name); - } - for (auto &disabled_fs : disabled_file_systems) { - if (new_disabled_file_systems.find(disabled_fs) == new_disabled_file_systems.end()) { - throw InvalidInputException("File system \"%s\" has been disabled previously, it cannot be re-enabled", - disabled_fs); - } - } - disabled_file_systems = std::move(new_disabled_file_systems); -} - -FileSystem &VirtualFileSystem::FindFileSystem(const string &path) { - auto &fs = FindFileSystemInternal(path); - if (!disabled_file_systems.empty() && disabled_file_systems.find(fs.GetName()) != disabled_file_systems.end()) { - throw PermissionException("File system %s has been disabled by configuration", fs.GetName()); - } - return fs; -} - -FileSystem &VirtualFileSystem::FindFileSystemInternal(const string &path) { - FileSystem *fs = nullptr; - for (auto &sub_system : sub_systems) { - if (sub_system->CanHandleFile(path)) { - if (sub_system->IsManuallySet()) { - return *sub_system; - } - fs = sub_system.get(); - } - } - if (fs) { - return *fs; - } - return *default_fs; -} - -} // namespace duckdb diff --git a/src/duckdb/src/common/windows_util.cpp b/src/duckdb/src/common/windows_util.cpp deleted file mode 100644 index 8be7a3a04..000000000 --- a/src/duckdb/src/common/windows_util.cpp +++ /dev/null @@ -1,53 +0,0 @@ -#include "duckdb/common/windows_util.hpp" - -namespace duckdb { - -#ifdef DUCKDB_WINDOWS - -std::wstring WindowsUtil::UTF8ToUnicode(const char *input) { - idx_t result_size; - - result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, nullptr, 0); - if (result_size == 0) { - throw IOException("Failure in MultiByteToWideChar"); - } - auto buffer = make_unsafe_uniq_array(result_size); - result_size = MultiByteToWideChar(CP_UTF8, 0, input, -1, buffer.get(), result_size); - if (result_size == 0) { - throw IOException("Failure in MultiByteToWideChar"); - } - return std::wstring(buffer.get(), result_size); -} - -static string WideCharToMultiByteWrapper(LPCWSTR input, uint32_t code_page) { - idx_t result_size; - - result_size = WideCharToMultiByte(code_page, 0, input, -1, 0, 0, 0, 0); - if (result_size == 0) { - throw IOException("Failure in WideCharToMultiByte"); - } - auto buffer = make_unsafe_uniq_array(result_size); - result_size = WideCharToMultiByte(code_page, 0, input, -1, buffer.get(), result_size, 0, 0); - if (result_size == 0) { - throw IOException("Failure in WideCharToMultiByte"); - } - return string(buffer.get(), result_size - 1); -} - -string WindowsUtil::UnicodeToUTF8(LPCWSTR input) { - return WideCharToMultiByteWrapper(input, CP_UTF8); -} - -static string WindowsUnicodeToMBCS(LPCWSTR unicode_text, int use_ansi) { - uint32_t code_page = use_ansi ? 
CP_ACP : CP_OEMCP; - return WideCharToMultiByteWrapper(unicode_text, code_page); -} - -string WindowsUtil::UTF8ToMBCS(const char *input, bool use_ansi) { - auto unicode = WindowsUtil::UTF8ToUnicode(input); - return WindowsUnicodeToMBCS(unicode.c_str(), use_ansi); -} - -#endif - -} // namespace duckdb diff --git a/src/duckdb/src/execution/adaptive_filter.cpp b/src/duckdb/src/execution/adaptive_filter.cpp deleted file mode 100644 index 6077ae3ff..000000000 --- a/src/duckdb/src/execution/adaptive_filter.cpp +++ /dev/null @@ -1,113 +0,0 @@ -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/execution/adaptive_filter.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/vector.hpp" - -namespace duckdb { - -AdaptiveFilter::AdaptiveFilter(const Expression &expr) : observe_interval(10), execute_interval(20), warmup(true) { - auto &conj_expr = expr.Cast(); - D_ASSERT(conj_expr.children.size() > 1); - for (idx_t idx = 0; idx < conj_expr.children.size(); idx++) { - permutation.push_back(idx); - if (conj_expr.children[idx]->CanThrow()) { - disable_permutations = true; - } - if (idx != conj_expr.children.size() - 1) { - swap_likeliness.push_back(100); - } - } - right_random_border = 100 * (conj_expr.children.size() - 1); -} - -AdaptiveFilter::AdaptiveFilter(const TableFilterSet &table_filters) - : observe_interval(10), execute_interval(20), warmup(true) { - for (idx_t idx = 0; idx < table_filters.filters.size(); idx++) { - permutation.push_back(idx); - swap_likeliness.push_back(100); - } - swap_likeliness.pop_back(); - right_random_border = 100 * (table_filters.filters.size() - 1); -} - -AdaptiveFilterState AdaptiveFilter::BeginFilter() const { - if (permutation.size() <= 1 || disable_permutations) { - return AdaptiveFilterState(); - } - AdaptiveFilterState state; - state.start_time = high_resolution_clock::now(); - return state; -} - -void AdaptiveFilter::EndFilter(AdaptiveFilterState state) { - if (permutation.size() <= 1 || disable_permutations) { - // nothing to permute - return; - } - auto end_time = high_resolution_clock::now(); - AdaptRuntimeStatistics(duration_cast>(end_time - state.start_time).count()); -} - -void AdaptiveFilter::AdaptRuntimeStatistics(double duration) { - iteration_count++; - runtime_sum += duration; - - D_ASSERT(!disable_permutations); - if (!warmup) { - // the last swap was observed - if (observe && iteration_count == observe_interval) { - // keep swap if runtime decreased, else reverse swap - if (prev_mean - (runtime_sum / static_cast(iteration_count)) <= 0) { - // reverse swap because runtime didn't decrease - std::swap(permutation[swap_idx], permutation[swap_idx + 1]); - - // decrease swap likeliness, but make sure there is always a small likeliness left - if (swap_likeliness[swap_idx] > 1) { - swap_likeliness[swap_idx] /= 2; - } - } else { - // keep swap because runtime decreased, reset likeliness - swap_likeliness[swap_idx] = 100; - } - observe = false; - - // reset values - iteration_count = 0; - runtime_sum = 0.0; - } else if (!observe && iteration_count == execute_interval) { - // save old mean to evaluate swap - prev_mean = runtime_sum / static_cast(iteration_count); - - // get swap index and swap likeliness - // a <= i <= b - auto random_number = generator.NextRandomInteger(1, NumericCast(right_random_border)); - - swap_idx = random_number / 100; // index to be swapped - idx_t likeliness = random_number - 100 * swap_idx; // random number between [0, 100) - - // 
check if swap is going to happen - if (swap_likeliness[swap_idx] > likeliness) { // always true for the first swap of an index - // swap - std::swap(permutation[swap_idx], permutation[swap_idx + 1]); - - // observe whether swap will be applied - observe = true; - } - - // reset values - iteration_count = 0; - runtime_sum = 0.0; - } - } else { - if (iteration_count == 5) { - // initially set all values - iteration_count = 0; - runtime_sum = 0.0; - observe = false; - warmup = false; - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/aggregate_hashtable.cpp b/src/duckdb/src/execution/aggregate_hashtable.cpp deleted file mode 100644 index 198b77fac..000000000 --- a/src/duckdb/src/execution/aggregate_hashtable.cpp +++ /dev/null @@ -1,854 +0,0 @@ -#include "duckdb/execution/aggregate_hashtable.hpp" - -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/types/row/tuple_data_iterator.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/ht_entry.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -using ValidityBytes = TupleDataLayout::ValidityBytes; - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types, vector payload_types, - const vector &bindings, - idx_t initial_capacity, idx_t radix_bits) - : GroupedAggregateHashTable(context, allocator, std::move(group_types), std::move(payload_types), - AggregateObject::CreateAggregateObjects(bindings), initial_capacity, radix_bits) { -} - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types) - : GroupedAggregateHashTable(context, allocator, std::move(group_types), {}, vector()) { -} - -GroupedAggregateHashTable::AggregateHTAppendState::AggregateHTAppendState() - : ht_offsets(LogicalType::UBIGINT), hash_salts(LogicalType::HASH), group_compare_vector(STANDARD_VECTOR_SIZE), - no_match_vector(STANDARD_VECTOR_SIZE), empty_vector(STANDARD_VECTOR_SIZE), new_groups(STANDARD_VECTOR_SIZE), - addresses(LogicalType::POINTER) { -} - -GroupedAggregateHashTable::GroupedAggregateHashTable(ClientContext &context, Allocator &allocator, - vector group_types_p, - vector payload_types_p, - vector aggregate_objects_p, - idx_t initial_capacity, idx_t radix_bits) - : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - radix_bits(radix_bits), count(0), capacity(0), skip_lookups(false), - aggregate_allocator(make_shared_ptr(allocator)) { - - // Append hash column to the end and initialise the row layout - group_types_p.emplace_back(LogicalType::HASH); - layout.Initialize(std::move(group_types_p), std::move(aggregate_objects_p)); - - hash_offset = layout.GetOffsets()[layout.ColumnCount() - 1]; - - // Partitioned data and pointer table - InitializePartitionedData(); - if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD) { - InitializeUnpartitionedData(); - } - Resize(initial_capacity); - - // Predicates - predicates.resize(layout.ColumnCount() - 1, ExpressionType::COMPARE_NOT_DISTINCT_FROM); - row_matcher.Initialize(true, layout, predicates); -} - -void 
GroupedAggregateHashTable::InitializePartitionedData() { - if (!partitioned_data || - RadixPartitioning::RadixBitsOfPowerOfTwo(partitioned_data->PartitionCount()) != radix_bits) { - D_ASSERT(!partitioned_data || partitioned_data->Count() == 0); - partitioned_data = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); - } else { - partitioned_data->Reset(); - } - - D_ASSERT(GetLayout().GetAggrWidth() == layout.GetAggrWidth()); - D_ASSERT(GetLayout().GetDataWidth() == layout.GetDataWidth()); - D_ASSERT(GetLayout().GetRowWidth() == layout.GetRowWidth()); - - partitioned_data->InitializeAppendState(state.partitioned_append_state, - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); -} - -void GroupedAggregateHashTable::InitializeUnpartitionedData() { - D_ASSERT(radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD); - if (!unpartitioned_data) { - unpartitioned_data = - make_uniq(buffer_manager, layout, 0ULL, layout.ColumnCount() - 1); - } else { - unpartitioned_data->Reset(); - } - unpartitioned_data->InitializeAppendState(state.unpartitioned_append_state, - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); -} - -const PartitionedTupleData &GroupedAggregateHashTable::GetPartitionedData() const { - return *partitioned_data; -} - -unique_ptr GroupedAggregateHashTable::AcquirePartitionedData() { - // Flush/unpin partitioned data - partitioned_data->FlushAppendState(state.partitioned_append_state); - partitioned_data->Unpin(); - - if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD) { - // Flush/unpin unpartitioned data and append to partitioned data - if (unpartitioned_data) { - unpartitioned_data->FlushAppendState(state.unpartitioned_append_state); - unpartitioned_data->Unpin(); - unpartitioned_data->Repartition(*partitioned_data); - } - InitializeUnpartitionedData(); - } - - // Return and re-initialize - auto result = std::move(partitioned_data); - InitializePartitionedData(); - return result; -} - -void GroupedAggregateHashTable::Abandon() { - if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD) { - // Flush/unpin unpartitioned data and append to partitioned data - if (unpartitioned_data) { - unpartitioned_data->FlushAppendState(state.unpartitioned_append_state); - unpartitioned_data->Unpin(); - unpartitioned_data->Repartition(*partitioned_data); - } - InitializeUnpartitionedData(); - } - - // Start over - ClearPointerTable(); - count = 0; - - // Resetting the id ensures the dict state is reset properly when needed - state.dict_state.dictionary_id = string(); -} - -void GroupedAggregateHashTable::Repartition() { - auto old = AcquirePartitionedData(); - D_ASSERT(old->GetPartitions().size() != partitioned_data->GetPartitions().size()); - old->Repartition(*partitioned_data); -} - -shared_ptr GroupedAggregateHashTable::GetAggregateAllocator() { - return aggregate_allocator; -} - -GroupedAggregateHashTable::~GroupedAggregateHashTable() { - Destroy(); -} - -void GroupedAggregateHashTable::Destroy() { - if (!partitioned_data || partitioned_data->Count() == 0 || !layout.HasDestructor()) { - return; - } - - // There are aggregates with destructors: Call the destructor for each of the aggregates - // Currently does not happen because aggregate destructors are called while scanning in RadixPartitionedHashTable - // LCOV_EXCL_START - RowOperationsState row_state(*aggregate_allocator); - for (auto &data_collection : partitioned_data->GetPartitions()) { - if (data_collection->Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(*data_collection, 
TupleDataPinProperties::DESTROY_AFTER_DONE, false); - auto &row_locations = iterator.GetChunkState().row_locations; - do { - RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - data_collection->Reset(); - } - // LCOV_EXCL_STOP -} - -const TupleDataLayout &GroupedAggregateHashTable::GetLayout() const { - return partitioned_data->GetLayout(); -} - -idx_t GroupedAggregateHashTable::Count() const { - return count; -} - -idx_t GroupedAggregateHashTable::InitialCapacity() { - return STANDARD_VECTOR_SIZE * 2ULL; -} - -idx_t GroupedAggregateHashTable::GetCapacityForCount(idx_t count) { - count = MaxValue(InitialCapacity(), count); - return NextPowerOfTwo(LossyNumericCast(static_cast(count) * LOAD_FACTOR)); -} - -idx_t GroupedAggregateHashTable::Capacity() const { - return capacity; -} - -idx_t GroupedAggregateHashTable::ResizeThreshold() const { - return ResizeThreshold(Capacity()); -} - -idx_t GroupedAggregateHashTable::ResizeThreshold(const idx_t capacity) { - return LossyNumericCast(static_cast(capacity) / LOAD_FACTOR); -} - -idx_t GroupedAggregateHashTable::ApplyBitMask(hash_t hash) const { - return hash & bitmask; -} - -void GroupedAggregateHashTable::Verify() { -#ifdef DEBUG - if (skip_lookups) { - return; - } - idx_t total_count = 0; - for (idx_t i = 0; i < capacity; i++) { - const auto &entry = entries[i]; - if (!entry.IsOccupied()) { - continue; - } - auto hash = Load(entry.GetPointer() + hash_offset); - D_ASSERT(entry.GetSalt() == ht_entry_t::ExtractSalt(hash)); - total_count++; - } - D_ASSERT(total_count == Count()); -#endif -} - -void GroupedAggregateHashTable::ClearPointerTable() { - std::fill_n(entries, capacity, ht_entry_t()); -} - -void GroupedAggregateHashTable::SetRadixBits(idx_t radix_bits_p) { - radix_bits = radix_bits_p; -} - -idx_t GroupedAggregateHashTable::GetRadixBits() const { - return radix_bits; -} - -idx_t GroupedAggregateHashTable::GetSinkCount() const { - return sink_count; -} - -void GroupedAggregateHashTable::SkipLookups() { - skip_lookups = true; -} - -void GroupedAggregateHashTable::Resize(idx_t size) { - D_ASSERT(size >= STANDARD_VECTOR_SIZE); - D_ASSERT(IsPowerOfTwo(size)); - if (Count() != 0 && size < capacity) { - throw InternalException("Cannot downsize a non-empty hash table!"); - } - - capacity = size; - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); - ClearPointerTable(); - bitmask = capacity - 1; - - if (Count() != 0) { - ReinsertTuples(*partitioned_data); - if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD) { - ReinsertTuples(*unpartitioned_data); - } - } - - Verify(); -} - -void GroupedAggregateHashTable::ReinsertTuples(PartitionedTupleData &data) { - for (auto &data_collection : data.GetPartitions()) { - if (data_collection->Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::ALREADY_PINNED, false); - const auto row_locations = iterator.GetRowLocations(); - do { - for (idx_t i = 0; i < iterator.GetCurrentChunkCount(); i++) { - const auto &row_location = row_locations[i]; - const auto hash = Load(row_location + hash_offset); - - // Find an empty entry - auto ht_offset = ApplyBitMask(hash); - D_ASSERT(ht_offset == hash % capacity); - while (entries[ht_offset].IsOccupied()) { - IncrementAndWrap(ht_offset, bitmask); - } - auto &entry = entries[ht_offset]; - D_ASSERT(!entry.IsOccupied()); - entry.SetSalt(ht_entry_t::ExtractSalt(hash)); - 
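// (Annotation, not part of the deleted file: an ht_entry_t packs a small salt
// taken from the hash's upper bits together with the row pointer into a single
// 8-byte slot. Probes can then reject non-matching entries by comparing salts,
// as Verify() does above, before ever dereferencing the pointer.)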
entry.SetPointer(row_location); - D_ASSERT(entry.IsOccupied()); - } - } while (iterator.Next()); - } -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, AggregateType filter) { - unsafe_vector aggregate_filter; - - auto &aggregates = layout.GetAggregates(); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - if (aggregate.aggr_type == filter) { - aggregate_filter.push_back(i); - } - } - return AddChunk(groups, payload, aggregate_filter); -} - -GroupedAggregateHashTable::AggregateDictionaryState::AggregateDictionaryState() - : hashes(LogicalType::HASH), new_dictionary_pointers(LogicalType::POINTER), unique_entries(STANDARD_VECTOR_SIZE) { -} - -optional_idx GroupedAggregateHashTable::TryAddDictionaryGroups(DataChunk &groups, DataChunk &payload, - const unsafe_vector &filter) { - static constexpr idx_t MAX_DICTIONARY_SIZE_THRESHOLD = 20000; - static constexpr idx_t DICTIONARY_THRESHOLD = 2; - // dictionary vector - check if this is a duplicate eliminated dictionary from the storage - auto &dict_col = groups.data[0]; - auto opt_dict_size = DictionaryVector::DictionarySize(dict_col); - if (!opt_dict_size.IsValid()) { - // dict size not known - this is not a dictionary that comes from the storage - return optional_idx(); - } - idx_t dict_size = opt_dict_size.GetIndex(); - auto &dictionary_id = DictionaryVector::DictionaryId(dict_col); - if (dictionary_id.empty()) { - // dictionary has no id, we can't cache across vectors - // only use dictionary compression if there are fewer entries than groups - if (dict_size >= groups.size() * DICTIONARY_THRESHOLD) { - // dictionary is too large - use regular aggregation - return optional_idx(); - } - } else { - // dictionary has an id - we can cache across vectors - // use a much larger limit for dictionary - if (dict_size >= MAX_DICTIONARY_SIZE_THRESHOLD) { - // dictionary is too large - use regular aggregation - return optional_idx(); - } - } - auto &dictionary_vector = DictionaryVector::Child(dict_col); - auto &offsets = DictionaryVector::SelVector(dict_col); - auto &dict_state = state.dict_state; - if (dict_state.dictionary_id.empty() || dict_state.dictionary_id != dictionary_id) { - // new dictionary - initialize the index state - if (dict_size > dict_state.capacity) { - dict_state.dictionary_addresses = make_uniq(LogicalType::POINTER, dict_size); - dict_state.found_entry = make_unsafe_uniq_array(dict_size); - dict_state.capacity = dict_size; - } - memset(dict_state.found_entry.get(), 0, dict_size * sizeof(bool)); - dict_state.dictionary_id = dictionary_id; - } else if (dict_size > dict_state.capacity) { - throw InternalException("AggregateHT - using cached dictionary data but dictionary has changed (dictionary id " - "%s - dict size %d, current capacity %d)", - dict_state.dictionary_id, dict_size, dict_state.capacity); - } - - auto &found_entry = dict_state.found_entry; - auto &unique_entries = dict_state.unique_entries; - idx_t unique_count = 0; - // for each of the dictionary entries - check if we have already done a look-up into the hash table - // if we have, we can just use the cached group pointers - for (idx_t i = 0; i < groups.size(); i++) { - auto dict_idx = offsets.get_index(i); - unique_entries.set_index(unique_count, dict_idx); - unique_count += !found_entry[dict_idx]; - found_entry[dict_idx] = true; - } - auto &new_dictionary_pointers = dict_state.new_dictionary_pointers; - idx_t new_group_count = 0; - if (unique_count > 0) { - auto &unique_values = 
dict_state.unique_values; - if (unique_values.ColumnCount() == 0) { - unique_values.InitializeEmpty(groups.GetTypes()); - } - // slice the dictionary - unique_values.data[0].Slice(dictionary_vector, unique_entries, unique_count); - unique_values.SetCardinality(unique_count); - // now we know which entries we are going to add - hash them - auto &hashes = dict_state.hashes; - unique_values.Hash(hashes); - - // add the dictionary groups to the hash table - new_group_count = FindOrCreateGroups(unique_values, hashes, new_dictionary_pointers, state.new_groups); - } - auto &aggregates = layout.GetAggregates(); - if (aggregates.empty()) { - // early-out - no aggregates to update - return new_group_count; - } - - // set the addresses that we found for each of the unique groups in the main addresses vector - auto new_dict_addresses = FlatVector::GetData(new_dictionary_pointers); - // for each of the new groups, add them to the global (cached) list of addresses for the dictionary - auto &dictionary_addresses = *dict_state.dictionary_addresses; - auto dict_addresses = FlatVector::GetData(dictionary_addresses); - for (idx_t i = 0; i < unique_count; i++) { - auto dict_idx = unique_entries.get_index(i); - dict_addresses[dict_idx] = new_dict_addresses[i] + layout.GetAggrOffset(); - } - // now set up the addresses for the aggregates - auto result_addresses = FlatVector::GetData(state.addresses); - for (idx_t i = 0; i < groups.size(); i++) { - auto dict_idx = offsets.get_index(i); - result_addresses[i] = dict_addresses[dict_idx]; - } - - // finally process the aggregates - UpdateAggregates(payload, filter); - - return new_group_count; -} - -optional_idx GroupedAggregateHashTable::TryAddConstantGroups(DataChunk &groups, DataChunk &payload, - const unsafe_vector &filter) { -#ifndef DEBUG - if (groups.size() <= 1) { - // this only has a point if we have multiple groups - return optional_idx(); - } -#endif - auto &dict_state = state.dict_state; - auto &unique_values = dict_state.unique_values; - if (unique_values.ColumnCount() == 0) { - unique_values.InitializeEmpty(groups.GetTypes()); - } - // slice the dictionary - unique_values.Reference(groups); - unique_values.SetCardinality(1); - unique_values.Flatten(); - - auto &hashes = dict_state.hashes; - unique_values.Hash(hashes); - - // add the single constant group to the hash table - auto &new_dictionary_pointers = dict_state.new_dictionary_pointers; - auto new_group_count = FindOrCreateGroups(unique_values, hashes, new_dictionary_pointers, state.new_groups); - - auto &aggregates = layout.GetAggregates(); - if (aggregates.empty()) { - // early-out - no aggregates to update - return new_group_count; - } - - auto new_dict_addresses = FlatVector::GetData(new_dictionary_pointers); - auto result_addresses = FlatVector::GetData(state.addresses); - uintptr_t aggregate_address = new_dict_addresses[0] + layout.GetAggrOffset(); - for (idx_t i = 0; i < payload.size(); i++) { - result_addresses[i] = aggregate_address; - } - - // process the aggregates - // FIXME: we can use simple_update here if the aggregates support it - UpdateAggregates(payload, filter); - - return new_group_count; -} - -optional_idx GroupedAggregateHashTable::TryAddCompressedGroups(DataChunk &groups, DataChunk &payload, - const unsafe_vector &filter) { - // all groups must be compressed - if (groups.AllConstant()) { - return TryAddConstantGroups(groups, payload, filter); - } - if (groups.ColumnCount() == 1 && groups.data[0].GetVectorType() == VectorType::DICTIONARY_VECTOR) { - return 
TryAddDictionaryGroups(groups, payload, filter); - } - return optional_idx(); -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload, const unsafe_vector &filter) { - sink_count += groups.size(); - - // check if we can use an optimized path that utilizes compressed vectors - auto result = TryAddCompressedGroups(groups, payload, filter); - if (result.IsValid()) { - return result.GetIndex(); - } - // otherwise append the raw values - Vector hashes(LogicalType::HASH); - groups.Hash(hashes); - - return AddChunk(groups, hashes, payload, filter); -} - -void GroupedAggregateHashTable::UpdateAggregates(DataChunk &payload, const unsafe_vector &filter) { - // Now every cell has an entry, update the aggregates - auto &aggregates = layout.GetAggregates(); - idx_t filter_idx = 0; - idx_t payload_idx = 0; - RowOperationsState row_state(*aggregate_allocator); - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggr = aggregates[i]; - if (filter_idx >= filter.size() || i < filter[filter_idx]) { - // Skip all the aggregates that are not in the filter - payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); - continue; - } - D_ASSERT(i == filter[filter_idx]); - - if (aggr.aggr_type != AggregateType::DISTINCT && aggr.filter) { - RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(i), aggr, state.addresses, payload, - payload_idx); - } else { - RowOperations::UpdateStates(row_state, aggr, state.addresses, payload, payload_idx, payload.size()); - } - - // Move to the next aggregate - payload_idx += aggr.child_count; - VectorOperations::AddInPlace(state.addresses, NumericCast(aggr.payload_size), payload.size()); - filter_idx++; - } - - Verify(); -} - -idx_t GroupedAggregateHashTable::AddChunk(DataChunk &groups, Vector &group_hashes, DataChunk &payload, - const unsafe_vector &filter) { - if (groups.size() == 0) { - return 0; - } - -#ifdef DEBUG - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - for (idx_t i = 0; i < groups.ColumnCount(); i++) { - D_ASSERT(groups.GetTypes()[i] == layout.GetTypes()[i]); - } -#endif - - const auto new_group_count = FindOrCreateGroups(groups, group_hashes, state.addresses, state.new_groups); - VectorOperations::AddInPlace(state.addresses, NumericCast(layout.GetAggrOffset()), payload.size()); - - UpdateAggregates(payload, filter); - - return new_group_count; -} - -void GroupedAggregateHashTable::FetchAggregates(DataChunk &groups, DataChunk &result) { -#ifdef DEBUG - groups.Verify(); - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - for (idx_t i = 0; i < result.ColumnCount(); i++) { - D_ASSERT(result.data[i].GetType() == payload_types[i]); - } -#endif - - result.SetCardinality(groups); - if (groups.size() == 0) { - return; - } - - // find the groups associated with the addresses - // FIXME: this should not use the FindOrCreateGroups, creating them is unnecessary - Vector addresses(LogicalType::POINTER); - FindOrCreateGroups(groups, addresses); - // now fetch the aggregates - RowOperationsState row_state(*aggregate_allocator); - RowOperations::FinalizeStates(row_state, layout, addresses, result, 0); -} - -idx_t GroupedAggregateHashTable::FindOrCreateGroupsInternal(DataChunk &groups, Vector &group_hashes_v, - Vector &addresses_v, SelectionVector &new_groups_out) { - D_ASSERT(groups.ColumnCount() + 1 == layout.ColumnCount()); - D_ASSERT(group_hashes_v.GetType() == LogicalType::HASH); - D_ASSERT(state.ht_offsets.GetVectorType() == 
VectorType::FLAT_VECTOR); - D_ASSERT(state.ht_offsets.GetType() == LogicalType::UBIGINT); - D_ASSERT(addresses_v.GetType() == LogicalType::POINTER); - D_ASSERT(state.hash_salts.GetType() == LogicalType::HASH); - - // Need to fit the entire vector, and resize at threshold - const auto chunk_size = groups.size(); - if (Count() + chunk_size > capacity || Count() + chunk_size > ResizeThreshold()) { - Verify(); - Resize(capacity * 2); - } - D_ASSERT(capacity - Count() >= chunk_size); // we need to be able to fit at least one vector of data - - // we start out with all entries [0, 1, 2, ..., chunk_size] - const SelectionVector *sel_vector = FlatVector::IncrementalSelectionVector(); - - // Make a chunk that references the groups and the hashes and convert to unified format - if (state.group_chunk.ColumnCount() == 0) { - state.group_chunk.InitializeEmpty(layout.GetTypes()); - } - D_ASSERT(state.group_chunk.ColumnCount() == layout.GetTypes().size()); - for (idx_t grp_idx = 0; grp_idx < groups.ColumnCount(); grp_idx++) { - state.group_chunk.data[grp_idx].Reference(groups.data[grp_idx]); - } - state.group_chunk.data[groups.ColumnCount()].Reference(group_hashes_v); - state.group_chunk.SetCardinality(groups); - - // convert all vectors to unified format - TupleDataCollection::ToUnifiedFormat(state.partitioned_append_state.chunk_state, state.group_chunk); - if (!state.group_data) { - state.group_data = make_unsafe_uniq_array_uninitialized(state.group_chunk.ColumnCount()); - } - TupleDataCollection::GetVectorData(state.partitioned_append_state.chunk_state, state.group_data.get()); - - group_hashes_v.Flatten(chunk_size); - const auto hashes = FlatVector::GetData(group_hashes_v); - - addresses_v.Flatten(chunk_size); - const auto addresses = FlatVector::GetData(addresses_v); - - if (skip_lookups) { - // Just appending now - partitioned_data->AppendUnified(state.partitioned_append_state, state.group_chunk, - *FlatVector::IncrementalSelectionVector(), chunk_size); - RowOperations::InitializeStates(layout, state.partitioned_append_state.chunk_state.row_locations, - *FlatVector::IncrementalSelectionVector(), chunk_size); - - const auto row_locations = - FlatVector::GetData(state.partitioned_append_state.chunk_state.row_locations); - const auto &row_sel = state.partitioned_append_state.reverse_partition_sel; - for (idx_t i = 0; i < chunk_size; i++) { - const auto &row_idx = row_sel[i]; - const auto &row_location = row_locations[row_idx]; - addresses[i] = row_location; - } - count += chunk_size; - return chunk_size; - } - - // Compute the entry in the table based on the hash using a modulo, - // and precompute the hash salts for faster comparison below - const auto ht_offsets = FlatVector::GetData(state.ht_offsets); - const auto hash_salts = FlatVector::GetData(state.hash_salts); - - // We also compute the occupied count, which is essentially useless. - // However, this loop is branchless, while the main lookup loop below is not. - // So, by doing the lookups here, we better amortize cache misses. 
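An aside on the offset computation in the loop below: ApplyBitMask implements the modulo with a bitwise AND, which is only valid because the hash table capacity is kept at a power of two (the D_ASSERT in the loop checks exactly this equivalence), and IncrementAndWrap reuses the same mask for probe wrap-around. A minimal standalone sketch of the trick, illustrative only, with local names mirroring the `capacity` and `bitmask` members used here:

    // sketch, not DuckDB code: with a power-of-two capacity,
    // `hash % capacity` equals `hash & (capacity - 1)`
    #include <cassert>
    #include <cstdint>

    int main() {
        const uint64_t capacity = 1ULL << 10; // assumed power of two, as the HT maintains
        const uint64_t bitmask = capacity - 1;
        const uint64_t hash = 0x9E3779B97F4A7C15ULL;
        assert((hash & bitmask) == hash % capacity); // initial slot, no division needed
        uint64_t offset = capacity - 1;
        offset = (offset + 1) & bitmask; // linear-probe wrap-around back to slot 0
        return static_cast<int>(offset);
    }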
- idx_t occupied_count = 0; - for (idx_t r = 0; r < chunk_size; r++) { - const auto &hash = hashes[r]; - auto &ht_offset = ht_offsets[r]; - ht_offset = ApplyBitMask(hash); - occupied_count += entries[ht_offset].IsOccupied(); // Lookup - D_ASSERT(ht_offset == hash % capacity); - hash_salts[r] = ht_entry_t::ExtractSalt(hash); - } - - idx_t new_group_count = 0; - idx_t remaining_entries = chunk_size; - idx_t iteration_count; - for (iteration_count = 0; remaining_entries > 0 && iteration_count < capacity; iteration_count++) { - idx_t new_entry_count = 0; - idx_t need_compare_count = 0; - idx_t no_match_count = 0; - - // For each remaining entry, figure out whether or not it belongs to a full or empty group - for (idx_t i = 0; i < remaining_entries; i++) { - const auto index = sel_vector->get_index(i); - const auto salt = hash_salts[index]; - auto &ht_offset = ht_offsets[index]; - - idx_t inner_iteration_count; - for (inner_iteration_count = 0; inner_iteration_count < capacity; inner_iteration_count++) { - auto &entry = entries[ht_offset]; - if (!entry.IsOccupied()) { // Unoccupied: claim it - entry.SetSalt(salt); - state.empty_vector.set_index(new_entry_count++, index); - new_groups_out.set_index(new_group_count++, index); - break; - } - - if (DUCKDB_LIKELY(entry.GetSalt() == salt)) { // Matching salt: compare groups - state.group_compare_vector.set_index(need_compare_count++, index); - break; - } - - // Linear probing - IncrementAndWrap(ht_offset, bitmask); - } - if (DUCKDB_UNLIKELY(inner_iteration_count == capacity)) { - throw InternalException("Maximum inner iteration count reached in GroupedAggregateHashTable"); - } - } - - if (DUCKDB_UNLIKELY(occupied_count > new_entry_count + need_compare_count)) { - // We use the useless occupied_count we summed above here so the variable is used, - // and the compiler cannot optimize away the vectorized lookups above. This should never be triggered. 
- throw InternalException("Internal validation failed in GroupedAggregateHashTable"); - } - occupied_count = 0; // Have to set to 0 for next iterations - - if (new_entry_count != 0) { - // Append everything that belongs to an empty group - optional_ptr data; - optional_ptr append_state; - if (radix_bits >= UNPARTITIONED_RADIX_BITS_THRESHOLD && - new_entry_count / RadixPartitioning::NumberOfPartitions(radix_bits) <= 4) { - TupleDataCollection::ToUnifiedFormat(state.unpartitioned_append_state.chunk_state, state.group_chunk); - data = unpartitioned_data.get(); - append_state = &state.unpartitioned_append_state; - } else { - data = partitioned_data.get(); - append_state = &state.partitioned_append_state; - } - data->AppendUnified(*append_state, state.group_chunk, state.empty_vector, new_entry_count); - RowOperations::InitializeStates(layout, append_state->chunk_state.row_locations, - *FlatVector::IncrementalSelectionVector(), new_entry_count); - - // Set the entry pointers in the 1st part of the HT now that the data has been appended - const auto row_locations = FlatVector::GetData(append_state->chunk_state.row_locations); - const auto &row_sel = append_state->reverse_partition_sel; - for (idx_t new_entry_idx = 0; new_entry_idx < new_entry_count; new_entry_idx++) { - const auto &index = state.empty_vector[new_entry_idx]; - const auto &row_idx = row_sel[index]; - const auto &row_location = row_locations[row_idx]; - - auto &entry = entries[ht_offsets[index]]; - - entry.SetPointer(row_location); - addresses[index] = row_location; - } - } - - if (need_compare_count != 0) { - // Get the pointers to the rows that need to be compared - for (idx_t need_compare_idx = 0; need_compare_idx < need_compare_count; need_compare_idx++) { - const auto &index = state.group_compare_vector[need_compare_idx]; - const auto &entry = entries[ht_offsets[index]]; - addresses[index] = entry.GetPointer(); - } - - // Perform group comparisons - row_matcher.Match(state.group_chunk, state.partitioned_append_state.chunk_state.vector_data, - state.group_compare_vector, need_compare_count, layout, addresses_v, - &state.no_match_vector, no_match_count); - } - - // Linear probing: each of the entries that do not match move to the next entry in the HT - for (idx_t i = 0; i < no_match_count; i++) { - const auto &index = state.no_match_vector[i]; - auto &ht_offset = ht_offsets[index]; - IncrementAndWrap(ht_offset, bitmask); - } - sel_vector = &state.no_match_vector; - remaining_entries = no_match_count; - } - if (iteration_count == capacity) { - throw InternalException("Maximum outer iteration count reached in GroupedAggregateHashTable"); - } - - count += new_group_count; - return new_group_count; -} - -// this is to support distinct aggregations where we need to record whether we -// have already seen a value for a group -idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &group_hashes, Vector &addresses_out, - SelectionVector &new_groups_out) { - return FindOrCreateGroupsInternal(groups, group_hashes, addresses_out, new_groups_out); -} - -void GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses) { - // create a dummy new_groups sel vector - FindOrCreateGroups(groups, addresses, state.new_groups); -} - -idx_t GroupedAggregateHashTable::FindOrCreateGroups(DataChunk &groups, Vector &addresses_out, - SelectionVector &new_groups_out) { - Vector hashes(LogicalType::HASH); - groups.Hash(hashes); - return FindOrCreateGroups(groups, hashes, addresses_out, new_groups_out); -} - 
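The lookup loop in FindOrCreateGroupsInternal above resolves each group in two stages: a cheap salt check against bits cached in the hash table entry itself, and a full group comparison (row_matcher.Match) only when the salt matches. A condensed sketch of that per-slot decision, using a simplified stand-in for the real ht_entry_t layout (illustrative only):

    // condensed sketch of the probe decision above; `Entry` is a
    // simplified stand-in for duckdb::ht_entry_t, not the real layout
    #include <cstdint>

    struct Entry {
        uint16_t salt = 0;
        bool occupied = false;
        bool IsOccupied() const { return occupied; }
        uint16_t GetSalt() const { return salt; }
    };

    enum class Probe { CLAIM_EMPTY, COMPARE_GROUPS, KEEP_PROBING };

    // one step of linear probing: claim an empty slot, do a full key
    // comparison on a salt match, otherwise advance to the next slot
    // (wrap-around is handled by the caller's bitmask)
    Probe ProbeStep(const Entry &entry, uint16_t salt) {
        if (!entry.IsOccupied()) {
            return Probe::CLAIM_EMPTY; // new group: append row, set pointer
        }
        if (entry.GetSalt() == salt) {
            return Probe::COMPARE_GROUPS; // possible match: compare group values
        }
        return Probe::KEEP_PROBING; // salt mismatch: definitely not our group
    }

    int main() {
        Entry empty;
        Entry full;
        full.occupied = true;
        full.salt = 7;
        bool ok = ProbeStep(empty, 7) == Probe::CLAIM_EMPTY &&
                  ProbeStep(full, 7) == Probe::COMPARE_GROUPS &&
                  ProbeStep(full, 9) == Probe::KEEP_PROBING;
        return ok ? 0 : 1;
    }

The point of the salt is that a mismatch is the common outcome on collisions, so most probe steps are decided without ever touching the row data.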
-struct FlushMoveState {
-    explicit FlushMoveState(TupleDataCollection &collection_p)
-        : collection(collection_p), hashes(LogicalType::HASH), group_addresses(LogicalType::POINTER),
-          new_groups_sel(STANDARD_VECTOR_SIZE) {
-        const auto &layout = collection.GetLayout();
-        vector<column_t> column_ids;
-        column_ids.reserve(layout.ColumnCount() - 1);
-        for (idx_t col_idx = 0; col_idx < layout.ColumnCount() - 1; col_idx++) {
-            column_ids.emplace_back(col_idx);
-        }
-        collection.InitializeScan(scan_state, column_ids, TupleDataPinProperties::DESTROY_AFTER_DONE);
-        collection.InitializeScanChunk(scan_state, groups);
-        hash_col_idx = layout.ColumnCount() - 1;
-    }
-
-    bool Scan() {
-        if (collection.Scan(scan_state, groups)) {
-            collection.Gather(scan_state.chunk_state.row_locations, *FlatVector::IncrementalSelectionVector(),
-                              groups.size(), hash_col_idx, hashes, *FlatVector::IncrementalSelectionVector(), nullptr);
-            return true;
-        }
-
-        collection.FinalizePinState(scan_state.pin_state);
-        return false;
-    }
-
-    TupleDataCollection &collection;
-    TupleDataScanState scan_state;
-    DataChunk groups;
-
-    idx_t hash_col_idx;
-    Vector hashes;
-
-    Vector group_addresses;
-    SelectionVector new_groups_sel;
-};
-
-void GroupedAggregateHashTable::Combine(GroupedAggregateHashTable &other) {
-    auto other_partitioned_data = other.AcquirePartitionedData();
-    auto other_data = other_partitioned_data->GetUnpartitioned();
-    Combine(*other_data);
-
-    // Inherit ownership to all stored aggregate allocators
-    stored_allocators.emplace_back(other.aggregate_allocator);
-    for (const auto &stored_allocator : other.stored_allocators) {
-        stored_allocators.emplace_back(stored_allocator);
-    }
-}
-
-void GroupedAggregateHashTable::Combine(TupleDataCollection &other_data, optional_ptr<atomic<double>> progress) {
-    D_ASSERT(other_data.GetLayout().GetAggrWidth() == layout.GetAggrWidth());
-    D_ASSERT(other_data.GetLayout().GetDataWidth() == layout.GetDataWidth());
-    D_ASSERT(other_data.GetLayout().GetRowWidth() == layout.GetRowWidth());
-
-    if (other_data.Count() == 0) {
-        return;
-    }
-
-    FlushMoveState fm_state(other_data);
-    RowOperationsState row_state(*aggregate_allocator);
-
-    idx_t chunk_idx = 0;
-    const auto chunk_count = other_data.ChunkCount();
-    while (fm_state.Scan()) {
-        const auto input_chunk_size = fm_state.groups.size();
-        FindOrCreateGroups(fm_state.groups, fm_state.hashes, fm_state.group_addresses, fm_state.new_groups_sel);
-        RowOperations::CombineStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations,
-                                     fm_state.group_addresses, input_chunk_size);
-        if (layout.HasDestructor()) {
-            RowOperations::DestroyStates(row_state, layout, fm_state.scan_state.chunk_state.row_locations,
-                                         input_chunk_size);
-        }
-
-        if (progress) {
-            *progress = static_cast<double>(++chunk_idx) / static_cast<double>(chunk_count);
-        }
-    }
-
-    Verify();
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/base_aggregate_hashtable.cpp b/src/duckdb/src/execution/base_aggregate_hashtable.cpp
deleted file mode 100644
index eec99f9e3..000000000
--- a/src/duckdb/src/execution/base_aggregate_hashtable.cpp
+++ /dev/null
@@ -1,15 +0,0 @@
-#include "duckdb/execution/base_aggregate_hashtable.hpp"
-#include "duckdb/planner/expression/bound_aggregate_expression.hpp"
-#include "duckdb/storage/buffer_manager.hpp"
-
-namespace duckdb {
-
-BaseAggregateHashTable::BaseAggregateHashTable(ClientContext &context, Allocator &allocator,
-                                               const vector<BoundAggregateExpression *> &aggregates,
-                                               vector<LogicalType> payload_types_p)
-    : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)),
payload_types(std::move(payload_types_p)) { - filter_set.Initialize(context, aggregates, payload_types); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/column_binding_resolver.cpp b/src/duckdb/src/execution/column_binding_resolver.cpp deleted file mode 100644 index 3a931a3f5..000000000 --- a/src/duckdb/src/execution/column_binding_resolver.cpp +++ /dev/null @@ -1,202 +0,0 @@ -#include "duckdb/execution/column_binding_resolver.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/planner/expression/bound_columnref_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_any_join.hpp" -#include "duckdb/planner/operator/logical_comparison_join.hpp" -#include "duckdb/planner/operator/logical_create_index.hpp" -#include "duckdb/planner/operator/logical_extension_operator.hpp" -#include "duckdb/planner/operator/logical_insert.hpp" - -namespace duckdb { - -ColumnBindingResolver::ColumnBindingResolver(bool verify_only) : verify_only(verify_only) { -} - -void ColumnBindingResolver::VisitOperator(LogicalOperator &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: { - // special case: comparison join - auto &comp_join = op.Cast(); - // first get the bindings of the LHS and resolve the LHS expressions - VisitOperator(*comp_join.children[0]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.left); - } - // visit the duplicate eliminated columns on the LHS, if any - for (auto &expr : comp_join.duplicate_eliminated_columns) { - VisitExpression(&expr); - } - // then get the bindings of the RHS and resolve the RHS expressions - VisitOperator(*comp_join.children[1]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.right); - } - // finally update the bindings with the result bindings of the join - bindings = op.GetColumnBindings(); - return; - } - case LogicalOperatorType::LOGICAL_DELIM_JOIN: { - auto &comp_join = op.Cast(); - // depending on whether the delim join has been flipped, get the appropriate bindings - if (comp_join.delim_flipped) { - VisitOperator(*comp_join.children[1]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.right); - } - } else { - VisitOperator(*comp_join.children[0]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.left); - } - } - // visit the duplicate eliminated columns - for (auto &expr : comp_join.duplicate_eliminated_columns) { - VisitExpression(&expr); - } - // now get the other side - if (comp_join.delim_flipped) { - VisitOperator(*comp_join.children[0]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.left); - } - } else { - VisitOperator(*comp_join.children[1]); - for (auto &cond : comp_join.conditions) { - VisitExpression(&cond.right); - } - } - // finally update the bindings with the result bindings of the join - bindings = op.GetColumnBindings(); - return; - } - case LogicalOperatorType::LOGICAL_ANY_JOIN: { - // ANY join, this join is different because we evaluate the expression on the bindings of BOTH join sides at - // once i.e. 
we set the bindings first to the bindings of the entire join, and then resolve the expressions of
-        // this operator
-        VisitOperatorChildren(op);
-        bindings = op.GetColumnBindings();
-        auto &any_join = op.Cast<LogicalAnyJoin>();
-        if (any_join.join_type == JoinType::SEMI || any_join.join_type == JoinType::ANTI) {
-            auto right_bindings = op.children[1]->GetColumnBindings();
-            bindings.insert(bindings.end(), right_bindings.begin(), right_bindings.end());
-        }
-        if (any_join.join_type == JoinType::RIGHT_SEMI || any_join.join_type == JoinType::RIGHT_ANTI) {
-            throw InternalException("RIGHT SEMI/ANTI any join not supported yet");
-        }
-        VisitOperatorExpressions(op);
-        return;
-    }
-    case LogicalOperatorType::LOGICAL_CREATE_INDEX: {
-        // CREATE INDEX statement, add the columns of the table with table index 0 to the binding set
-        // afterwards bind the expressions of the CREATE INDEX statement
-        auto &create_index = op.Cast<LogicalCreateIndex>();
-        bindings = LogicalOperator::GenerateColumnBindings(0, create_index.table.GetColumns().LogicalColumnCount());
-        VisitOperatorExpressions(op);
-        return;
-    }
-    case LogicalOperatorType::LOGICAL_GET: {
-        //! We first need to update the current set of bindings and then visit operator expressions
-        bindings = op.GetColumnBindings();
-        VisitOperatorExpressions(op);
-        return;
-    }
-    case LogicalOperatorType::LOGICAL_INSERT: {
-        //! We want to execute the normal path, but also add a dummy 'excluded' binding if there is a
-        // ON CONFLICT DO UPDATE clause
-        auto &insert_op = op.Cast<LogicalInsert>();
-        if (insert_op.action_type != OnConflictAction::THROW) {
-            // Get the bindings from the children
-            VisitOperatorChildren(op);
-            auto column_count = insert_op.table.GetColumns().PhysicalColumnCount();
-            auto dummy_bindings = LogicalOperator::GenerateColumnBindings(insert_op.excluded_table_index, column_count);
-            // Now insert our dummy bindings at the start of the bindings,
-            // so the first 'column_count' indices of the chunk are reserved for our 'excluded' columns
-            bindings.insert(bindings.begin(), dummy_bindings.begin(), dummy_bindings.end());
-            if (insert_op.on_conflict_condition) {
-                VisitExpression(&insert_op.on_conflict_condition);
-            }
-            if (insert_op.do_update_condition) {
-                VisitExpression(&insert_op.do_update_condition);
-            }
-            VisitOperatorExpressions(op);
-            bindings = op.GetColumnBindings();
-            return;
-        }
-        break;
-    }
-    case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: {
-        auto &ext_op = op.Cast<LogicalExtensionOperator>();
-        ext_op.ResolveColumnBindings(*this, bindings);
-        return;
-    }
-    default:
-        break;
-    }
-
-    // general case
-    // first visit the children of this operator
-    VisitOperatorChildren(op);
-    // now visit the expressions of this operator to resolve any bound column references
-    VisitOperatorExpressions(op);
-    // finally update the current set of bindings to the current set of column bindings
-    bindings = op.GetColumnBindings();
-}
-
-unique_ptr<Expression> ColumnBindingResolver::VisitReplace(BoundColumnRefExpression &expr,
-                                                           unique_ptr<Expression> *expr_ptr) {
-    D_ASSERT(expr.depth == 0);
-    // check the current set of column bindings to see which index corresponds to the column reference
-    for (idx_t i = 0; i < bindings.size(); i++) {
-        if (expr.binding == bindings[i]) {
-            if (verify_only) {
-                // in verification mode
-                return nullptr;
-            }
-            return make_uniq<BoundReferenceExpression>(expr.GetAlias(), expr.return_type, i);
-        }
-    }
-    // LCOV_EXCL_START
-    // could not bind the column reference, this should never happen and indicates a bug in the code
-    // generate an error message
-    throw InternalException("Failed to bind column reference \"%s\" [%d.%d] (bindings: %s)",
expr.GetAlias(), - expr.binding.table_index, expr.binding.column_index, - LogicalOperator::ColumnBindingsToString(bindings)); - // LCOV_EXCL_STOP -} - -unordered_set ColumnBindingResolver::VerifyInternal(LogicalOperator &op) { - unordered_set result; - for (auto &child : op.children) { - auto child_indexes = VerifyInternal(*child); - for (auto index : child_indexes) { - D_ASSERT(index != DConstants::INVALID_INDEX); - if (result.find(index) != result.end()) { - throw InternalException("Duplicate table index \"%lld\" found", index); - } - result.insert(index); - } - } - auto indexes = op.GetTableIndex(); - for (auto index : indexes) { - D_ASSERT(index != DConstants::INVALID_INDEX); - if (result.find(index) != result.end()) { - throw InternalException("Duplicate table index \"%lld\" found", index); - } - result.insert(index); - } - return result; -} - -void ColumnBindingResolver::Verify(LogicalOperator &op) { -#ifdef DEBUG - ColumnBindingResolver resolver(true); - resolver.VisitOperator(op); - VerifyInternal(op); -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor.cpp b/src/duckdb/src/execution/expression_executor.cpp deleted file mode 100644 index 567deb326..000000000 --- a/src/duckdb/src/execution/expression_executor.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" - -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/execution_context.hpp" -#include "duckdb/storage/statistics/base_statistics.hpp" -#include "duckdb/planner/expression/list.hpp" - -namespace duckdb { - -ExpressionExecutor::ExpressionExecutor(ClientContext &context) : context(&context) { -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression *expression) - : ExpressionExecutor(context) { - D_ASSERT(expression); - AddExpression(*expression); -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const Expression &expression) - : ExpressionExecutor(context) { - AddExpression(expression); -} - -ExpressionExecutor::ExpressionExecutor(ClientContext &context, const vector> &exprs) - : ExpressionExecutor(context) { - D_ASSERT(exprs.size() > 0); - for (auto &expr : exprs) { - AddExpression(*expr); - } -} - -ExpressionExecutor::ExpressionExecutor(const vector> &exprs) : context(nullptr) { - D_ASSERT(exprs.size() > 0); - for (auto &expr : exprs) { - AddExpression(*expr); - } -} - -ExpressionExecutor::ExpressionExecutor() : context(nullptr) { -} - -bool ExpressionExecutor::HasContext() { - return context; -} - -ClientContext &ExpressionExecutor::GetContext() { - if (!context) { - throw InternalException("Calling ExpressionExecutor::GetContext on an expression executor without a context"); - } - return *context; -} - -Allocator &ExpressionExecutor::GetAllocator() { - return context ? 
Allocator::Get(*context) : Allocator::DefaultAllocator(); -} - -void ExpressionExecutor::AddExpression(const Expression &expr) { - expressions.push_back(&expr); - auto state = make_uniq(); - Initialize(expr, *state); - state->Verify(); - states.push_back(std::move(state)); -} - -void ExpressionExecutor::Initialize(const Expression &expression, ExpressionExecutorState &state) { - state.executor = this; - state.root_state = InitializeState(expression, state); -} - -void ExpressionExecutor::Execute(DataChunk *input, DataChunk &result) { - SetChunk(input); - D_ASSERT(expressions.size() == result.ColumnCount()); - D_ASSERT(!expressions.empty()); - - for (idx_t i = 0; i < expressions.size(); i++) { - ExecuteExpression(i, result.data[i]); - } - result.SetCardinality(input ? input->size() : 1); - result.Verify(); -} - -void ExpressionExecutor::ExecuteExpression(DataChunk &input, Vector &result) { - SetChunk(&input); - ExecuteExpression(result); -} - -idx_t ExpressionExecutor::SelectExpression(DataChunk &input, SelectionVector &sel) { - D_ASSERT(expressions.size() == 1); - SetChunk(&input); - idx_t selected_tuples = Select(*expressions[0], states[0]->root_state.get(), nullptr, input.size(), &sel, nullptr); - return selected_tuples; -} - -void ExpressionExecutor::ExecuteExpression(Vector &result) { - D_ASSERT(expressions.size() == 1); - ExecuteExpression(0, result); -} - -void ExpressionExecutor::ExecuteExpression(idx_t expr_idx, Vector &result) { - D_ASSERT(expr_idx < expressions.size()); - D_ASSERT(result.GetType().id() == expressions[expr_idx]->return_type.id()); - Execute(*expressions[expr_idx], states[expr_idx]->root_state.get(), nullptr, chunk ? chunk->size() : 1, result); -} - -Value ExpressionExecutor::EvaluateScalar(ClientContext &context, const Expression &expr, bool allow_unfoldable) { - D_ASSERT(allow_unfoldable || expr.IsFoldable()); - D_ASSERT(expr.IsScalar()); - // use an ExpressionExecutor to execute the expression - ExpressionExecutor executor(context, expr); - - Vector result(expr.return_type); - executor.ExecuteExpression(result); - - D_ASSERT(allow_unfoldable || result.GetVectorType() == VectorType::CONSTANT_VECTOR); - auto result_value = result.GetValue(0); - D_ASSERT(result_value.type().InternalType() == expr.return_type.InternalType()); - return result_value; -} - -bool ExpressionExecutor::TryEvaluateScalar(ClientContext &context, const Expression &expr, Value &result) { - try { - result = EvaluateScalar(context, expr); - return true; - } catch (InternalException &ex) { - throw; - } catch (...) 
{
-        return false;
-    }
-}
-
-void ExpressionExecutor::Verify(const Expression &expr, Vector &vector, idx_t count) {
-    D_ASSERT(expr.return_type.id() == vector.GetType().id());
-    vector.Verify(count);
-    if (expr.verification_stats) {
-        expr.verification_stats->Verify(vector, count);
-    }
-#ifdef DUCKDB_VERIFY_DICTIONARY_EXPRESSION
-    Vector::DebugTransformToDictionary(vector, count);
-#endif
-}
-
-unique_ptr<ExpressionState> ExpressionExecutor::InitializeState(const Expression &expr,
-                                                                ExpressionExecutorState &state) {
-    switch (expr.GetExpressionClass()) {
-    case ExpressionClass::BOUND_REF:
-        return InitializeState(expr.Cast<BoundReferenceExpression>(), state);
-    case ExpressionClass::BOUND_BETWEEN:
-        return InitializeState(expr.Cast<BoundBetweenExpression>(), state);
-    case ExpressionClass::BOUND_CASE:
-        return InitializeState(expr.Cast<BoundCaseExpression>(), state);
-    case ExpressionClass::BOUND_CAST:
-        return InitializeState(expr.Cast<BoundCastExpression>(), state);
-    case ExpressionClass::BOUND_COMPARISON:
-        return InitializeState(expr.Cast<BoundComparisonExpression>(), state);
-    case ExpressionClass::BOUND_CONJUNCTION:
-        return InitializeState(expr.Cast<BoundConjunctionExpression>(), state);
-    case ExpressionClass::BOUND_CONSTANT:
-        return InitializeState(expr.Cast<BoundConstantExpression>(), state);
-    case ExpressionClass::BOUND_FUNCTION:
-        return InitializeState(expr.Cast<BoundFunctionExpression>(), state);
-    case ExpressionClass::BOUND_OPERATOR:
-        return InitializeState(expr.Cast<BoundOperatorExpression>(), state);
-    case ExpressionClass::BOUND_PARAMETER:
-        return InitializeState(expr.Cast<BoundParameterExpression>(), state);
-    default:
-        throw InternalException("Attempting to initialize state of expression of unknown type!");
-    }
-}
-
-void ExpressionExecutor::Execute(const Expression &expr, ExpressionState *state, const SelectionVector *sel,
-                                 idx_t count, Vector &result) {
-#ifdef DEBUG
-    // The result vector must either be used for the first time, or it must have been reset since its last use.
-    // Otherwise, its validity mask can still contain previous (now incorrect) data.
-    if (result.GetVectorType() == VectorType::FLAT_VECTOR) {
-
-        // We do not initialize vector caches for these expressions.
- if (expr.GetExpressionClass() != ExpressionClass::BOUND_REF && - expr.GetExpressionClass() != ExpressionClass::BOUND_CONSTANT && - expr.GetExpressionClass() != ExpressionClass::BOUND_PARAMETER) { - D_ASSERT(FlatVector::Validity(result).CheckAllValid(count)); - } - } -#endif - - if (count == 0) { - return; - } - if (result.GetType().id() != expr.return_type.id()) { - throw InternalException( - "ExpressionExecutor::Execute called with a result vector of type %s that does not match expression type %s", - result.GetType(), expr.return_type); - } - switch (expr.GetExpressionClass()) { - case ExpressionClass::BOUND_BETWEEN: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_REF: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CASE: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CAST: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_COMPARISON: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CONJUNCTION: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_CONSTANT: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_FUNCTION: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_OPERATOR: - Execute(expr.Cast(), state, sel, count, result); - break; - case ExpressionClass::BOUND_PARAMETER: - Execute(expr.Cast(), state, sel, count, result); - break; - default: - throw InternalException("Attempting to execute expression of unknown type!"); - } - Verify(expr, result, count); -} - -idx_t ExpressionExecutor::Select(const Expression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - if (count == 0) { - return 0; - } - D_ASSERT(true_sel || false_sel); - D_ASSERT(expr.return_type.id() == LogicalTypeId::BOOLEAN); - switch (expr.GetExpressionClass()) { -#ifndef DUCKDB_SMALLER_BINARY - case ExpressionClass::BOUND_BETWEEN: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); -#endif - case ExpressionClass::BOUND_COMPARISON: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); - case ExpressionClass::BOUND_CONJUNCTION: - return Select(expr.Cast(), state, sel, count, true_sel, false_sel); - default: - return DefaultSelect(expr, state, sel, count, true_sel, false_sel); - } -} - -template -static inline idx_t DefaultSelectLoop(const SelectionVector *bsel, const uint8_t *__restrict bdata, ValidityMask &mask, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - idx_t true_count = 0, false_count = 0; - for (idx_t i = 0; i < count; i++) { - auto bidx = bsel->get_index(i); - auto result_idx = sel->get_index(i); - if ((NO_NULL || mask.RowIsValid(bidx)) && bdata[bidx] > 0) { - if (HAS_TRUE_SEL) { - true_sel->set_index(true_count++, result_idx); - } - } else { - if (HAS_FALSE_SEL) { - false_sel->set_index(false_count++, result_idx); - } - } - } - if (HAS_TRUE_SEL) { - return true_count; - } else { - return count - false_count; - } -} - -template -static inline idx_t DefaultSelectSwitch(UnifiedVectorFormat &idata, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - if (true_sel && false_sel) { - return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, 
count, true_sel, false_sel); - } else if (true_sel) { - return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, count, true_sel, false_sel); - } else { - D_ASSERT(false_sel); - return DefaultSelectLoop(idata.sel, UnifiedVectorFormat::GetData(idata), - idata.validity, sel, count, true_sel, false_sel); - } -} - -idx_t ExpressionExecutor::DefaultSelect(const Expression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { - // generic selection of boolean expression: - // resolve the true/false expression first - // then use that to generate the selection vector - bool intermediate_bools[STANDARD_VECTOR_SIZE]; - Vector intermediate(LogicalType::BOOLEAN, data_ptr_cast(intermediate_bools)); - Execute(expr, state, sel, count, intermediate); - - UnifiedVectorFormat idata; - intermediate.ToUnifiedFormat(count, idata); - - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - if (!idata.validity.AllValid()) { - return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); - } else { - return DefaultSelectSwitch(idata, sel, count, true_sel, false_sel); - } -} - -vector> &ExpressionExecutor::GetStates() { - return states; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_between.cpp b/src/duckdb/src/execution/expression_executor/execute_between.cpp deleted file mode 100644 index 341835136..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_between.cpp +++ /dev/null @@ -1,163 +0,0 @@ -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_between_expression.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" - -namespace duckdb { - -#ifndef DUCKDB_SMALLER_BINARY -struct BothInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThanEquals::Operation(input, lower) && LessThanEquals::Operation(input, upper); - } -}; - -struct LowerInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThanEquals::Operation(input, lower) && LessThan::Operation(input, upper); - } -}; - -struct UpperInclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThan::Operation(input, lower) && LessThanEquals::Operation(input, upper); - } -}; - -struct ExclusiveBetweenOperator { - template - static inline bool Operation(T input, T lower, T upper) { - return GreaterThan::Operation(input, lower) && LessThan::Operation(input, upper); - } -}; - -template -static idx_t BetweenLoopTypeSwitch(Vector &input, Vector &lower, Vector &upper, const SelectionVector *sel, idx_t count, - SelectionVector *true_sel, SelectionVector *false_sel) { - switch (input.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT16: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT32: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INT64: - return TernaryExecutor::Select(input, lower, upper, sel, count, 
true_sel, - false_sel); - case PhysicalType::INT128: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT8: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT16: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT32: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT64: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::UINT128: - return TernaryExecutor::Select(input, lower, upper, sel, count, - true_sel, false_sel); - case PhysicalType::FLOAT: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, false_sel); - case PhysicalType::DOUBLE: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::VARCHAR: - return TernaryExecutor::Select(input, lower, upper, sel, count, true_sel, - false_sel); - case PhysicalType::INTERVAL: - return TernaryExecutor::Select(input, lower, upper, sel, count, - true_sel, false_sel); - default: - throw InvalidTypeException(input.GetType(), "Invalid type for BETWEEN"); - } -} -#endif - -unique_ptr ExpressionExecutor::InitializeState(const BoundBetweenExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(*expr.input); - result->AddChild(*expr.lower); - result->AddChild(*expr.upper); - - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, Vector &result) { - // resolve the children - state->intermediate_chunk.Reset(); - - auto &input = state->intermediate_chunk.data[0]; - auto &lower = state->intermediate_chunk.data[1]; - auto &upper = state->intermediate_chunk.data[2]; - - Execute(*expr.input, state->child_states[0].get(), sel, count, input); - Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); - Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); - - Vector intermediate1(LogicalType::BOOLEAN); - Vector intermediate2(LogicalType::BOOLEAN); - - if (expr.upper_inclusive && expr.lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); - } else if (expr.lower_inclusive) { - VectorOperations::GreaterThanEquals(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); - } else if (expr.upper_inclusive) { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThanEquals(input, upper, intermediate2, count); - } else { - VectorOperations::GreaterThan(input, lower, intermediate1, count); - VectorOperations::LessThan(input, upper, intermediate2, count); - } - VectorOperations::And(intermediate1, intermediate2, result, count); -} - -idx_t ExpressionExecutor::Select(const BoundBetweenExpression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, SelectionVector *true_sel, SelectionVector *false_sel) { -#ifdef DUCKDB_SMALLER_BINARY - throw InternalException("ExpressionExecutor::Select not available with DUCKDB_SMALLER_BINARY"); -#else - // resolve the children - Vector input(state->intermediate_chunk.data[0]); - Vector lower(state->intermediate_chunk.data[1]); - 
Vector upper(state->intermediate_chunk.data[2]); - - Execute(*expr.input, state->child_states[0].get(), sel, count, input); - Execute(*expr.lower, state->child_states[1].get(), sel, count, lower); - Execute(*expr.upper, state->child_states[2].get(), sel, count, upper); - - if (expr.upper_inclusive && expr.lower_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else if (expr.lower_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else if (expr.upper_inclusive) { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, - false_sel); - } else { - return BetweenLoopTypeSwitch(input, lower, upper, sel, count, true_sel, false_sel); - } -#endif -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_case.cpp b/src/duckdb/src/execution/expression_executor/execute_case.cpp deleted file mode 100644 index cdeae3116..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_case.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_case_expression.hpp" - -namespace duckdb { - -struct CaseExpressionState : public ExpressionState { - CaseExpressionState(const Expression &expr, ExpressionExecutorState &root) - : ExpressionState(expr, root), true_sel(STANDARD_VECTOR_SIZE), false_sel(STANDARD_VECTOR_SIZE) { - } - - SelectionVector true_sel; - SelectionVector false_sel; -}; - -unique_ptr ExpressionExecutor::InitializeState(const BoundCaseExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &case_check : expr.case_checks) { - result->AddChild(*case_check.when_expr); - result->AddChild(*case_check.then_expr); - } - result->AddChild(*expr.else_expr); - - result->Finalize(); - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundCaseExpression &expr, ExpressionState *state_p, const SelectionVector *sel, - idx_t count, Vector &result) { - auto &state = state_p->Cast(); - - state.intermediate_chunk.Reset(); - - // first execute the check expression - auto current_true_sel = &state.true_sel; - auto current_false_sel = &state.false_sel; - auto current_sel = sel; - idx_t current_count = count; - for (idx_t i = 0; i < expr.case_checks.size(); i++) { - auto &case_check = expr.case_checks[i]; - auto &intermediate_result = state.intermediate_chunk.data[i * 2 + 1]; - auto check_state = state.child_states[i * 2].get(); - auto then_state = state.child_states[i * 2 + 1].get(); - - idx_t tcount = - Select(*case_check.when_expr, check_state, current_sel, current_count, current_true_sel, current_false_sel); - if (tcount == 0) { - // everything is false: do nothing - continue; - } - idx_t fcount = current_count - tcount; - if (fcount == 0 && current_count == count) { - // everything is true in the first CHECK statement - // we can skip the entire case and only execute the TRUE side - Execute(*case_check.then_expr, then_state, sel, count, result); - return; - } else { - // we need to execute and then fill in the desired tuples in the result - Execute(*case_check.then_expr, then_state, current_true_sel, tcount, intermediate_result); - FillSwitch(intermediate_result, result, *current_true_sel, NumericCast(tcount)); - } - // continue with the false tuples - current_sel = current_false_sel; - current_count = 
fcount; - if (fcount == 0) { - // everything is true: we are done - break; - } - } - if (current_count > 0) { - auto else_state = state.child_states.back().get(); - if (current_count == count) { - // everything was false, we can just evaluate the else expression directly - Execute(*expr.else_expr, else_state, sel, count, result); - return; - } else { - auto &intermediate_result = state.intermediate_chunk.data[expr.case_checks.size() * 2]; - - D_ASSERT(current_sel); - Execute(*expr.else_expr, else_state, current_sel, current_count, intermediate_result); - FillSwitch(intermediate_result, result, *current_sel, NumericCast(current_count)); - } - } - if (sel) { - result.Slice(*sel, count); - } -} - -template -void TemplatedFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - result.SetVectorType(VectorType::FLAT_VECTOR); - auto res = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto data = ConstantVector::GetData(vector); - if (ConstantVector::IsNull(vector)) { - for (idx_t i = 0; i < count; i++) { - result_mask.SetInvalid(sel.get_index(i)); - } - } else { - for (idx_t i = 0; i < count; i++) { - res[sel.get_index(i)] = *data; - } - } - } else { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - auto data = UnifiedVectorFormat::GetData(vdata); - for (idx_t i = 0; i < count; i++) { - auto source_idx = vdata.sel->get_index(i); - auto res_idx = sel.get_index(i); - - res[res_idx] = data[source_idx]; - result_mask.Set(res_idx, vdata.validity.RowIsValid(source_idx)); - } - } -} - -void ValidityFillLoop(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - result.SetVectorType(VectorType::FLAT_VECTOR); - auto &result_mask = FlatVector::Validity(result); - if (vector.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(vector)) { - for (idx_t i = 0; i < count; i++) { - result_mask.SetInvalid(sel.get_index(i)); - } - } - } else { - UnifiedVectorFormat vdata; - vector.ToUnifiedFormat(count, vdata); - if (vdata.validity.AllValid()) { - return; - } - for (idx_t i = 0; i < count; i++) { - auto source_idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(source_idx)) { - result_mask.SetInvalid(sel.get_index(i)); - } - } - } -} - -void ExpressionExecutor::FillSwitch(Vector &vector, Vector &result, const SelectionVector &sel, sel_t count) { - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT16: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT32: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT64: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT8: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT16: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT32: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT64: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::INT128: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::UINT128: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::FLOAT: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::DOUBLE: - TemplatedFillLoop(vector, 
result, sel, count); - break; - case PhysicalType::INTERVAL: - TemplatedFillLoop(vector, result, sel, count); - break; - case PhysicalType::VARCHAR: - TemplatedFillLoop(vector, result, sel, count); - StringVector::AddHeapReference(result, vector); - break; - case PhysicalType::STRUCT: { - auto &vector_entries = StructVector::GetEntries(vector); - auto &result_entries = StructVector::GetEntries(result); - ValidityFillLoop(vector, result, sel, count); - D_ASSERT(vector_entries.size() == result_entries.size()); - for (idx_t i = 0; i < vector_entries.size(); i++) { - FillSwitch(*vector_entries[i], *result_entries[i], sel, count); - } - break; - } - case PhysicalType::LIST: { - idx_t offset = ListVector::GetListSize(result); - auto &list_child = ListVector::GetEntry(vector); - ListVector::Append(result, list_child, ListVector::GetListSize(vector)); - - // all the false offsets need to be incremented by true_child.count - TemplatedFillLoop(vector, result, sel, count); - if (offset == 0) { - break; - } - - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto result_idx = sel.get_index(i); - result_data[result_idx].offset += offset; - } - - Vector::Verify(result, sel, count); - break; - } - default: - throw NotImplementedException("Unimplemented type for case expression: %s", result.GetType().ToString()); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_cast.cpp b/src/duckdb/src/execution/expression_executor/execute_cast.cpp deleted file mode 100644 index 0627dcf51..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_cast.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundCastExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(*expr.child); - result->Finalize(); - - if (expr.bound_cast.init_local_state) { - auto context_ptr = root.executor->HasContext() ? 
&root.executor->GetContext() : nullptr; - CastLocalStateParameters parameters(context_ptr, expr.bound_cast.cast_data); - result->local_state = expr.bound_cast.init_local_state(parameters); - } - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundCastExpression &expr, ExpressionState *state, const SelectionVector *sel, - idx_t count, Vector &result) { - auto lstate = ExecuteFunctionState::GetFunctionState(*state); - - // resolve the child - state->intermediate_chunk.Reset(); - - auto &child = state->intermediate_chunk.data[0]; - auto child_state = state->child_states[0].get(); - - Execute(*expr.child, child_state, sel, count, child); - if (expr.try_cast) { - string error_message; - CastParameters parameters(expr.bound_cast.cast_data.get(), false, &error_message, lstate); - parameters.query_location = expr.GetQueryLocation(); - expr.bound_cast.function(child, result, count, parameters); - } else { - // cast it to the type specified by the cast expression - D_ASSERT(result.GetType() == expr.return_type); - CastParameters parameters(expr.bound_cast.cast_data.get(), false, nullptr, lstate); - parameters.query_location = expr.GetQueryLocation(); - expr.bound_cast.function(child, result, count, parameters); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp b/src/duckdb/src/execution/expression_executor/execute_comparison.cpp deleted file mode 100644 index 6e78de49c..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_comparison.cpp +++ /dev/null @@ -1,382 +0,0 @@ -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" - -#include - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundComparisonExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->AddChild(*expr.left); - result->AddChild(*expr.right); - - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundComparisonExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - // resolve the children - state->intermediate_chunk.Reset(); - auto &left = state->intermediate_chunk.data[0]; - auto &right = state->intermediate_chunk.data[1]; - - Execute(*expr.left, state->child_states[0].get(), sel, count, left); - Execute(*expr.right, state->child_states[1].get(), sel, count, right); - - switch (expr.GetExpressionType()) { - case ExpressionType::COMPARE_EQUAL: - VectorOperations::Equals(left, right, result, count); - break; - case ExpressionType::COMPARE_NOTEQUAL: - VectorOperations::NotEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHAN: - VectorOperations::LessThan(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHAN: - VectorOperations::GreaterThan(left, right, result, count); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - VectorOperations::LessThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - VectorOperations::GreaterThanEquals(left, right, result, count); - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - VectorOperations::DistinctFrom(left, right, result, count); - 
break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - VectorOperations::NotDistinctFrom(left, right, result, count); - break; - default: - throw InternalException("Unknown comparison type!"); - } -} - -static void UpdateNullMask(Vector &vec, optional_ptr sel, idx_t count, ValidityMask &null_mask) { - UnifiedVectorFormat vdata; - vec.ToUnifiedFormat(count, vdata); - - if (vdata.validity.AllValid()) { - return; - } - - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - for (idx_t i = 0; i < count; ++i) { - const auto ridx = sel->get_index(i); - const auto vidx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(vidx)) { - null_mask.SetInvalid(ridx); - } - } -} - -template -static idx_t NestedSelectOperation(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask); - -template -static idx_t TemplatedSelectOperation(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - if (null_mask) { - UpdateNullMask(left, sel, count, *null_mask); - UpdateNullMask(right, sel, count, *null_mask); - } - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::INT16: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::INT32: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::INT64: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::UINT8: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::UINT16: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::UINT32: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::UINT64: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::INT128: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::UINT128: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::FLOAT: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), false_sel.get()); - case PhysicalType::DOUBLE: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::INTERVAL: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::VARCHAR: - return BinaryExecutor::Select(left, right, sel.get(), count, true_sel.get(), - false_sel.get()); - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - return NestedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); - default: - throw InternalException("Invalid type for comparison"); - } -} - -struct NestedSelector { - // Select the matching rows for the values of a nested type that are not both NULL. 
- // Those semantics are the same as the corresponding non-distinct comparator - template - static idx_t Select(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - throw InvalidTypeException(left.GetType(), "Invalid operation for nested SELECT"); - } -}; - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::NestedEquals(left, right, sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::NestedNotEquals(left, right, sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctLessThan(left, right, sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, - optional_ptr sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctLessThanEquals(left, right, sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThan(left, right, sel, count, true_sel, false_sel, null_mask); -} - -template <> -idx_t NestedSelector::Select(Vector &left, Vector &right, - optional_ptr sel, idx_t count, - optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return VectorOperations::DistinctGreaterThanEquals(left, right, sel, count, true_sel, false_sel, null_mask); -} - -static inline idx_t SelectNotNull(Vector &left, Vector &right, const idx_t count, const SelectionVector &sel, - SelectionVector &maybe_vec, OptionalSelection &false_opt, - optional_ptr null_mask) { - - UnifiedVectorFormat lvdata, rvdata; - left.ToUnifiedFormat(count, lvdata); - right.ToUnifiedFormat(count, rvdata); - - auto &lmask = lvdata.validity; - auto &rmask = rvdata.validity; - - // For top-level comparisons, NULL semantics are in effect, - // so filter out any NULLs - idx_t remaining = 0; - if (lmask.AllValid() && rmask.AllValid()) { - // None are NULL, distinguish values. - for (idx_t i = 0; i < count; ++i) { - const auto idx = sel.get_index(i); - maybe_vec.set_index(remaining++, idx); - } - return remaining; - } - - // Slice the Vectors down to the rows that are not determined (i.e., neither is NULL) - SelectionVector slicer(count); - idx_t false_count = 0; - for (idx_t i = 0; i < count; ++i) { - const auto result_idx = sel.get_index(i); - const auto lidx = lvdata.sel->get_index(i); - const auto ridx = rvdata.sel->get_index(i); - if (!lmask.RowIsValid(lidx) || !rmask.RowIsValid(ridx)) { - if (null_mask) { - null_mask->SetInvalid(result_idx); - } - false_opt.Append(false_count, result_idx); - } else { - // Neither is NULL, distinguish values. 
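SelectNotNull above implements top-level SQL comparison semantics: if either side of a pair is NULL, the row can never match, so it is routed to the false selection (and recorded in null_mask) before the nested comparator ever runs. A small illustration of that rule using std::optional as a stand-in for nullable vector values (sketch only, not DuckDB code):

    // sketch of top-level comparison NULL semantics: a row where either
    // side is NULL is filtered out before any value comparison happens
    #include <optional>

    bool TopLevelEquals(std::optional<int> lhs, std::optional<int> rhs, bool &is_null) {
        is_null = !lhs.has_value() || !rhs.has_value();
        if (is_null) {
            return false; // routed to the false side, marked in the null mask
        }
        return *lhs == *rhs; // only non-NULL pairs reach the real comparator
    }

    int main() {
        bool is_null = false;
        bool match = TopLevelEquals(std::nullopt, 42, is_null); // is_null: true, match: false
        return match ? 1 : 0;
    }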
- slicer.set_index(remaining, i); - maybe_vec.set_index(remaining++, result_idx); - } - } - false_opt.Advance(false_count); - - if (remaining && remaining < count) { - left.Slice(slicer, remaining); - right.Slice(slicer, remaining); - } - - return remaining; -} - -static void ScatterSelection(optional_ptr target, const idx_t count, - const SelectionVector &dense_vec) { - if (target) { - for (idx_t i = 0; i < count; ++i) { - target->set_index(i, dense_vec.get_index(i)); - } - } -} - -template -static idx_t NestedSelectOperation(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - // The Select operations all use a dense pair of input vectors to partition - // a selection vector in a single pass. But to implement progressive comparisons, - // we have to make multiple passes, so we need to keep track of the original input positions - // and then scatter the output selections when we are done. - if (!sel) { - sel = FlatVector::IncrementalSelectionVector(); - } - - // Make buffered selections for progressive comparisons - // TODO: Remove unnecessary allocations - SelectionVector true_vec(count); - OptionalSelection true_opt(&true_vec); - - SelectionVector false_vec(count); - OptionalSelection false_opt(&false_vec); - - SelectionVector maybe_vec(count); - - // Handle NULL nested values - Vector l_not_null(left); - Vector r_not_null(right); - - auto match_count = SelectNotNull(l_not_null, r_not_null, count, *sel, maybe_vec, false_opt, null_mask); - auto no_match_count = count - match_count; - count = match_count; - - // Now that we have handled the NULLs, we can use the recursive nested comparator for the rest. - match_count = - NestedSelector::Select(l_not_null, r_not_null, &maybe_vec, count, optional_ptr(true_opt), - optional_ptr(false_opt), null_mask); - no_match_count += (count - match_count); - - // Copy the buffered selections to the output selections - ScatterSelection(true_sel, match_count, true_vec); - ScatterSelection(false_sel, no_match_count, false_vec); - - return match_count; -} - -idx_t VectorOperations::Equals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); -} - -idx_t VectorOperations::NotEquals(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); -} - -idx_t VectorOperations::GreaterThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); -} - -idx_t VectorOperations::GreaterThanEquals(Vector &left, Vector &right, optional_ptr sel, - idx_t count, optional_ptr true_sel, - optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedSelectOperation(left, right, sel, count, true_sel, false_sel, null_mask); -} - -idx_t VectorOperations::LessThan(Vector &left, Vector &right, optional_ptr sel, idx_t count, - optional_ptr true_sel, optional_ptr false_sel, - optional_ptr null_mask) { - return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel, null_mask); -} - -idx_t VectorOperations::LessThanEquals(Vector &left, Vector &right, optional_ptr sel, - 
idx_t count, optional_ptr true_sel, - optional_ptr false_sel, optional_ptr null_mask) { - return TemplatedSelectOperation(right, left, sel, count, true_sel, false_sel, null_mask); -} - -idx_t ExpressionExecutor::Select(const BoundComparisonExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - // resolve the children - state->intermediate_chunk.Reset(); - auto &left = state->intermediate_chunk.data[0]; - auto &right = state->intermediate_chunk.data[1]; - - Execute(*expr.left, state->child_states[0].get(), sel, count, left); - Execute(*expr.right, state->child_states[1].get(), sel, count, right); - - switch (expr.GetExpressionType()) { - case ExpressionType::COMPARE_EQUAL: - return VectorOperations::Equals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOTEQUAL: - return VectorOperations::NotEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHAN: - return VectorOperations::LessThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_GREATERTHAN: - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_DISTINCT_FROM: - return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, false_sel); - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, false_sel); - default: - throw InternalException("Unknown comparison type!"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp b/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp deleted file mode 100644 index 1b2bc3a4e..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_conjunction.cpp +++ /dev/null @@ -1,140 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/execution/adaptive_filter.hpp" - -#include - -namespace duckdb { - -struct ConjunctionState : public ExpressionState { - ConjunctionState(const Expression &expr, ExpressionExecutorState &root) : ExpressionState(expr, root) { - adaptive_filter = make_uniq(expr); - } - unique_ptr adaptive_filter; -}; - -unique_ptr ExpressionExecutor::InitializeState(const BoundConjunctionExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(*child); - } - - result->Finalize(); - return std::move(result); -} - -void ExpressionExecutor::Execute(const BoundConjunctionExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - // execute the children - state->intermediate_chunk.Reset(); - for (idx_t i = 0; i < expr.children.size(); i++) { - auto ¤t_result = state->intermediate_chunk.data[i]; - Execute(*expr.children[i], state->child_states[i].get(), sel, count, current_result); - if (i == 0) { - // move the result - result.Reference(current_result); - } else { - Vector intermediate(LogicalType::BOOLEAN); - // 
AND/OR together - switch (expr.GetExpressionType()) { - case ExpressionType::CONJUNCTION_AND: - VectorOperations::And(current_result, result, intermediate, count); - break; - case ExpressionType::CONJUNCTION_OR: - VectorOperations::Or(current_result, result, intermediate, count); - break; - default: - throw InternalException("Unknown conjunction type!"); - } - result.Reference(intermediate); - } - } -} - -idx_t ExpressionExecutor::Select(const BoundConjunctionExpression &expr, ExpressionState *state_p, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel, - SelectionVector *false_sel) { - auto &state = state_p->Cast(); - - if (expr.GetExpressionType() == ExpressionType::CONJUNCTION_AND) { - // get runtime statistics - auto filter_state = state.adaptive_filter->BeginFilter(); - const SelectionVector *current_sel = sel; - idx_t current_count = count; - idx_t false_count = 0; - - unique_ptr temp_true, temp_false; - if (false_sel) { - temp_false = make_uniq(STANDARD_VECTOR_SIZE); - } - if (!true_sel) { - temp_true = make_uniq(STANDARD_VECTOR_SIZE); - true_sel = temp_true.get(); - } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], - state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, - current_count, true_sel, temp_false.get()); - idx_t fcount = current_count - tcount; - if (fcount > 0 && false_sel) { - // move failing tuples into the false_sel - for (idx_t i = 0; i < fcount; i++) { - false_sel->set_index(false_count++, temp_false->get_index(i)); - } - } - current_count = tcount; - if (current_count == 0) { - break; - } - if (current_count < count) { - // tuples were filtered out: move on to using the true_sel to only evaluate passing tuples in subsequent - // iterations - current_sel = true_sel; - } - } - // adapt runtime statistics - state.adaptive_filter->EndFilter(filter_state); - return current_count; - } else { - // get runtime statistics - auto filter_state = state.adaptive_filter->BeginFilter(); - - const SelectionVector *current_sel = sel; - idx_t current_count = count; - idx_t result_count = 0; - - unique_ptr temp_true, temp_false; - if (true_sel) { - temp_true = make_uniq(STANDARD_VECTOR_SIZE); - } - if (!false_sel) { - temp_false = make_uniq(STANDARD_VECTOR_SIZE); - false_sel = temp_false.get(); - } - for (idx_t i = 0; i < expr.children.size(); i++) { - idx_t tcount = Select(*expr.children[state.adaptive_filter->permutation[i]], - state.child_states[state.adaptive_filter->permutation[i]].get(), current_sel, - current_count, temp_true.get(), false_sel); - if (tcount > 0) { - if (true_sel) { - // tuples passed, move them into the actual result vector - for (idx_t i = 0; i < tcount; i++) { - true_sel->set_index(result_count++, temp_true->get_index(i)); - } - } - // now move on to check only the non-passing tuples - current_count -= tcount; - current_sel = false_sel; - } - } - - // adapt runtime statistics - state.adaptive_filter->EndFilter(filter_state); - return result_count; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_constant.cpp b/src/duckdb/src/execution/expression_executor/execute_constant.cpp deleted file mode 100644 index cd9b463b8..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_constant.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include
"duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundConstantExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundConstantExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.value.type() == expr.return_type); - result.Reference(expr.value); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_function.cpp b/src/duckdb/src/execution/expression_executor/execute_function.cpp deleted file mode 100644 index 5a95d27a3..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_function.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -ExecuteFunctionState::ExecuteFunctionState(const Expression &expr, ExpressionExecutorState &root) - : ExpressionState(expr, root) { -} - -ExecuteFunctionState::~ExecuteFunctionState() { -} - -unique_ptr ExpressionExecutor::InitializeState(const BoundFunctionExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(*child); - } - - result->Finalize(); - if (expr.function.init_local_state) { - result->local_state = expr.function.init_local_state(*result, expr, expr.bind_info.get()); - } - return std::move(result); -} - -static void VerifyNullHandling(const BoundFunctionExpression &expr, DataChunk &args, Vector &result) { -#ifdef DEBUG - if (args.data.empty() || expr.function.null_handling != FunctionNullHandling::DEFAULT_NULL_HANDLING) { - return; - } - - // Combine all the argument validity masks into a flat validity mask - idx_t count = args.size(); - ValidityMask combined_mask(count); - for (auto &arg : args.data) { - UnifiedVectorFormat arg_data; - arg.ToUnifiedFormat(count, arg_data); - - for (idx_t i = 0; i < count; i++) { - auto idx = arg_data.sel->get_index(i); - if (!arg_data.validity.RowIsValid(idx)) { - combined_mask.SetInvalid(i); - } - } - } - - // Default is that if any of the arguments are NULL, the result is also NULL - UnifiedVectorFormat result_data; - result.ToUnifiedFormat(count, result_data); - for (idx_t i = 0; i < count; i++) { - if (!combined_mask.RowIsValid(i)) { - auto idx = result_data.sel->get_index(i); - D_ASSERT(!result_data.validity.RowIsValid(idx)); - } - } -#endif -} - -void ExpressionExecutor::Execute(const BoundFunctionExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - state->intermediate_chunk.Reset(); - auto &arguments = state->intermediate_chunk; - if (!state->types.empty()) { - for (idx_t i = 0; i < expr.children.size(); i++) { - D_ASSERT(state->types[i] == expr.children[i]->return_type); - Execute(*expr.children[i], state->child_states[i].get(), sel, count, arguments.data[i]); -#ifdef DEBUG - if (expr.children[i]->return_type.id() == LogicalTypeId::VARCHAR) { - arguments.data[i].UTFVerify(count); - } -#endif - } - } - arguments.SetCardinality(count); - arguments.Verify(); - - D_ASSERT(expr.function.function); - // #ifdef DEBUG - expr.function.function(arguments, *state, result); - - VerifyNullHandling(expr, arguments, result); - D_ASSERT(result.GetType() == expr.return_type); 
-} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_operator.cpp b/src/duckdb/src/execution/expression_executor/execute_operator.cpp deleted file mode 100644 index b543679e8..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_operator.cpp +++ /dev/null @@ -1,140 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundOperatorExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - for (auto &child : expr.children) { - result->AddChild(*child); - } - - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundOperatorExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - // special handling for special snowflake 'IN' - // IN has n children - if (expr.GetExpressionType() == ExpressionType::COMPARE_IN || - expr.GetExpressionType() == ExpressionType::COMPARE_NOT_IN) { - if (expr.children.size() < 2) { - throw InvalidInputException("IN needs at least two children"); - } - - Vector left(expr.children[0]->return_type); - // eval left side - Execute(*expr.children[0], state->child_states[0].get(), sel, count, left); - - // init result to false - Vector intermediate(LogicalType::BOOLEAN); - Value false_val = Value::BOOLEAN(false); - intermediate.Reference(false_val); - - // in rhs is a list of constants - // for every child, OR the result of the comparison with the left - // to get the overall result. - for (idx_t child = 1; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->return_type); - Vector comp_res(LogicalType::BOOLEAN); - - Execute(*expr.children[child], state->child_states[child].get(), sel, count, vector_to_check); - VectorOperations::Equals(left, vector_to_check, comp_res, count); - - if (child == 1) { - // first child: move to result - intermediate.Reference(comp_res); - } else { - // otherwise OR together - Vector new_result(LogicalType::BOOLEAN, true, false); - VectorOperations::Or(intermediate, comp_res, new_result, count); - intermediate.Reference(new_result); - } - } - if (expr.GetExpressionType() == ExpressionType::COMPARE_NOT_IN) { - // NOT IN: invert result - VectorOperations::Not(intermediate, result, count); - } else { - // directly use the result - result.Reference(intermediate); - } - } else if (expr.GetExpressionType() == ExpressionType::OPERATOR_COALESCE) { - SelectionVector sel_a(count); - SelectionVector sel_b(count); - SelectionVector slice_sel(count); - SelectionVector result_sel(count); - SelectionVector *next_sel = &sel_a; - const SelectionVector *current_sel = sel; - idx_t remaining_count = count; - idx_t next_count; - for (idx_t child = 0; child < expr.children.size(); child++) { - Vector vector_to_check(expr.children[child]->return_type); - Execute(*expr.children[child], state->child_states[child].get(), current_sel, remaining_count, - vector_to_check); - - UnifiedVectorFormat vdata; - vector_to_check.ToUnifiedFormat(remaining_count, vdata); - - idx_t result_count = 0; - next_count = 0; - for (idx_t i = 0; i < remaining_count; i++) { - auto base_idx = current_sel ? 
current_sel->get_index(i) : i; - auto idx = vdata.sel->get_index(i); - if (vdata.validity.RowIsValid(idx)) { - slice_sel.set_index(result_count, i); - result_sel.set_index(result_count++, base_idx); - } else { - next_sel->set_index(next_count++, base_idx); - } - } - if (result_count > 0) { - vector_to_check.Slice(slice_sel, result_count); - FillSwitch(vector_to_check, result, result_sel, NumericCast(result_count)); - } - current_sel = next_sel; - next_sel = next_sel == &sel_a ? &sel_b : &sel_a; - remaining_count = next_count; - if (next_count == 0) { - break; - } - } - if (remaining_count > 0) { - for (idx_t i = 0; i < remaining_count; i++) { - FlatVector::SetNull(result, current_sel->get_index(i), true); - } - } - if (sel) { - result.Slice(*sel, count); - } else if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - } else if (expr.children.size() == 1) { - state->intermediate_chunk.Reset(); - auto &child = state->intermediate_chunk.data[0]; - - Execute(*expr.children[0], state->child_states[0].get(), sel, count, child); - switch (expr.GetExpressionType()) { - case ExpressionType::OPERATOR_NOT: { - VectorOperations::Not(child, result, count); - break; - } - case ExpressionType::OPERATOR_IS_NULL: { - VectorOperations::IsNull(child, result, count); - break; - } - case ExpressionType::OPERATOR_IS_NOT_NULL: { - VectorOperations::IsNotNull(child, result, count); - break; - } - default: - throw NotImplementedException("Unsupported operator type with 1 child!"); - } - } else { - throw NotImplementedException("operator"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_parameter.cpp b/src/duckdb/src/execution/expression_executor/execute_parameter.cpp deleted file mode 100644 index c03ca9347..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_parameter.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundParameterExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundParameterExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - D_ASSERT(expr.parameter_data); - D_ASSERT(expr.parameter_data->return_type == expr.return_type); - D_ASSERT(expr.parameter_data->GetValue().type() == expr.return_type); - result.Reference(expr.parameter_data->GetValue()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor/execute_reference.cpp b/src/duckdb/src/execution/expression_executor/execute_reference.cpp deleted file mode 100644 index 88fdfa63d..000000000 --- a/src/duckdb/src/execution/expression_executor/execute_reference.cpp +++ /dev/null @@ -1,25 +0,0 @@ -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -namespace duckdb { - -unique_ptr ExpressionExecutor::InitializeState(const BoundReferenceExpression &expr, - ExpressionExecutorState &root) { - auto result = make_uniq(expr, root); - result->Finalize(); - return result; -} - -void ExpressionExecutor::Execute(const BoundReferenceExpression &expr, ExpressionState *state, - const SelectionVector *sel, idx_t count, Vector &result) { - 
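// [Editor's note] Worked example of the two paths below: for chunk->data[expr.index]
// = [10, 20, 30] and sel = {2, 0} with count = 2, Slice produces the zero-copy
// reordered view [30, 10]; without a selection vector, Reference aliases the column
// directly. Either way, evaluating a column reference never copies data.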
D_ASSERT(expr.index != DConstants::INVALID_INDEX); - D_ASSERT(expr.index < chunk->ColumnCount()); - - if (sel) { - result.Slice(chunk->data[expr.index], *sel, count); - } else { - result.Reference(chunk->data[expr.index]); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/expression_executor_state.cpp b/src/duckdb/src/execution/expression_executor_state.cpp deleted file mode 100644 index 070a399db..000000000 --- a/src/duckdb/src/execution/expression_executor_state.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/execution/expression_executor_state.hpp" - -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -void ExpressionState::AddChild(Expression &child_expr) { - types.push_back(child_expr.return_type); - auto child_state = ExpressionExecutor::InitializeState(child_expr, root); - child_states.push_back(std::move(child_state)); - - auto expr_class = child_expr.GetExpressionClass(); - auto initialize_child = expr_class != ExpressionClass::BOUND_REF && expr_class != ExpressionClass::BOUND_CONSTANT && - expr_class != ExpressionClass::BOUND_PARAMETER; - initialize.push_back(initialize_child); -} - -void ExpressionState::Finalize() { - if (types.empty()) { - return; - } - intermediate_chunk.Initialize(GetAllocator(), types, initialize); -} - -Allocator &ExpressionState::GetAllocator() { - return root.executor->GetAllocator(); -} - -bool ExpressionState::HasContext() { - return root.executor->HasContext(); -} - -ClientContext &ExpressionState::GetContext() { - if (!HasContext()) { - throw BinderException("Cannot use %s in this context", (expr.Cast()).function.name); - } - return root.executor->GetContext(); -} - -ExpressionState::ExpressionState(const Expression &expr, ExpressionExecutorState &root) : expr(expr), root(root) { -} - -ExpressionExecutorState::ExpressionExecutorState() { -} - -void ExpressionState::Verify(ExpressionExecutorState &root_executor) { - D_ASSERT(&root_executor == &root); - for (auto &entry : child_states) { - entry->Verify(root_executor); - } -} - -void ExpressionExecutorState::Verify() { - D_ASSERT(executor); - root_state->Verify(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/art.cpp b/src/duckdb/src/execution/index/art/art.cpp deleted file mode 100644 index a92f2eebf..000000000 --- a/src/duckdb/src/execution/index/art/art.cpp +++ /dev/null @@ -1,1436 +0,0 @@ -#include "duckdb/execution/index/art/art.hpp" - -#include "duckdb/common/types/conflict_manager.hpp" -#include "duckdb/common/unordered_map.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/index/art/art_key.hpp" -#include "duckdb/execution/index/art/base_leaf.hpp" -#include "duckdb/execution/index/art/base_node.hpp" -#include "duckdb/execution/index/art/iterator.hpp" -#include "duckdb/execution/index/art/leaf.hpp" -#include "duckdb/execution/index/art/node256.hpp" -#include "duckdb/execution/index/art/node256_leaf.hpp" -#include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/execution/index/art/prefix.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/planner/expression/bound_between_expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include 
"duckdb/storage/arena_allocator.hpp" -#include "duckdb/storage/metadata/metadata_reader.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/storage/table_io_manager.hpp" - -namespace duckdb { - -struct ARTIndexScanState : public IndexScanState { - //! The predicates to scan. - //! A single predicate for point lookups, and two predicates for range scans. - Value values[2]; - //! The expressions over the scan predicates. - ExpressionType expressions[2]; - bool checked = false; - //! All scanned row IDs. - unsafe_vector row_ids; -}; - -//===--------------------------------------------------------------------===// -// ART -//===--------------------------------------------------------------------===// - -ART::ART(const string &name, const IndexConstraintType index_constraint_type, const vector &column_ids, - TableIOManager &table_io_manager, const vector> &unbound_expressions, - AttachedDatabase &db, - const shared_ptr, ALLOCATOR_COUNT>> &allocators_ptr, - const IndexStorageInfo &info) - : BoundIndex(name, ART::TYPE_NAME, index_constraint_type, column_ids, table_io_manager, unbound_expressions, db), - allocators(allocators_ptr), owns_data(false), append_mode(ARTAppendMode::DEFAULT) { - - // FIXME: Use the new byte representation function to support nested types. - for (idx_t i = 0; i < types.size(); i++) { - switch (types[i]) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - case PhysicalType::INT128: - case PhysicalType::UINT8: - case PhysicalType::UINT16: - case PhysicalType::UINT32: - case PhysicalType::UINT64: - case PhysicalType::UINT128: - case PhysicalType::FLOAT: - case PhysicalType::DOUBLE: - case PhysicalType::VARCHAR: - break; - default: - throw InvalidTypeException(logical_types[i], "Invalid type for index key."); - } - } - - // Initialize the allocators. - SetPrefixCount(info); - if (!allocators) { - owns_data = true; - auto prefix_size = NumericCast(prefix_count) + NumericCast(Prefix::METADATA_SIZE); - auto &block_manager = table_io_manager.GetIndexBlockManager(); - - array, ALLOCATOR_COUNT> allocator_array = { - make_unsafe_uniq(prefix_size, block_manager), - make_unsafe_uniq(sizeof(Leaf), block_manager), - make_unsafe_uniq(sizeof(Node4), block_manager), - make_unsafe_uniq(sizeof(Node16), block_manager), - make_unsafe_uniq(sizeof(Node48), block_manager), - make_unsafe_uniq(sizeof(Node256), block_manager), - make_unsafe_uniq(sizeof(Node7Leaf), block_manager), - make_unsafe_uniq(sizeof(Node15Leaf), block_manager), - make_unsafe_uniq(sizeof(Node256Leaf), block_manager), - }; - allocators = - make_shared_ptr, ALLOCATOR_COUNT>>(std::move(allocator_array)); - } - - if (!info.IsValid()) { - // We create a new ART. - return; - } - - if (info.root_block_ptr.IsValid()) { - // Backwards compatibility. - Deserialize(info.root_block_ptr); - return; - } - - // Set the root node and initialize the allocators. 
- tree.Set(info.root); - InitAllocators(info); -} - -//===--------------------------------------------------------------------===// -// Initialize Scans -//===--------------------------------------------------------------------===// - -static unique_ptr InitializeScanSinglePredicate(const Value &value, - const ExpressionType expression_type) { - auto result = make_uniq(); - result->values[0] = value; - result->expressions[0] = expression_type; - return std::move(result); -} - -static unique_ptr InitializeScanTwoPredicates(const Value &low_value, - const ExpressionType low_expression_type, - const Value &high_value, - const ExpressionType high_expression_type) { - auto result = make_uniq(); - result->values[0] = low_value; - result->expressions[0] = low_expression_type; - result->values[1] = high_value; - result->expressions[1] = high_expression_type; - return std::move(result); -} - -unique_ptr ART::TryInitializeScan(const Expression &expr, const Expression &filter_expr) { - Value low_value, high_value, equal_value; - ExpressionType low_comparison_type = ExpressionType::INVALID, high_comparison_type = ExpressionType::INVALID; - - // Try to find a matching index for any of the filter expressions. - ComparisonExpressionMatcher matcher; - - // Match on a comparison type. - matcher.expr_type = make_uniq(); - - // Match on a constant comparison with the indexed expression. - matcher.matchers.push_back(make_uniq(expr)); - matcher.matchers.push_back(make_uniq()); - matcher.policy = SetMatcher::Policy::UNORDERED; - - vector> bindings; - auto filter_match = - matcher.Match(const_cast(filter_expr), bindings); // NOLINT: Match does not alter the expr. - if (filter_match) { - // This is a range or equality comparison with a constant value, so we can use the index. - // bindings[0] = the expression - // bindings[1] = the index expression - // bindings[2] = the constant - auto &comparison = bindings[0].get().Cast(); - auto constant_value = bindings[2].get().Cast().value; - auto comparison_type = comparison.GetExpressionType(); - - if (comparison.left->GetExpressionType() == ExpressionType::VALUE_CONSTANT) { - // The expression is on the right side, we flip the comparison expression. - comparison_type = FlipComparisonExpression(comparison_type); - } - - if (comparison_type == ExpressionType::COMPARE_EQUAL) { - // An equality value overrides any other bounds. - equal_value = constant_value; - } else if (comparison_type == ExpressionType::COMPARE_GREATERTHANOREQUALTO || - comparison_type == ExpressionType::COMPARE_GREATERTHAN) { - // This is a lower bound. - low_value = constant_value; - low_comparison_type = comparison_type; - } else { - // This is an upper bound. - high_value = constant_value; - high_comparison_type = comparison_type; - } - - } else if (filter_expr.GetExpressionType() == ExpressionType::COMPARE_BETWEEN) { - auto &between = filter_expr.Cast(); - if (!between.input->Equals(expr)) { - // The expression does not match the index expression. - return nullptr; - } - - if (between.lower->GetExpressionType() != ExpressionType::VALUE_CONSTANT || - between.upper->GetExpressionType() != ExpressionType::VALUE_CONSTANT) { - // Not a constant expression. - return nullptr; - } - - low_value = between.lower->Cast().value; - low_comparison_type = between.lower_inclusive ? ExpressionType::COMPARE_GREATERTHANOREQUALTO - : ExpressionType::COMPARE_GREATERTHAN; - high_value = (between.upper->Cast()).value; - high_comparison_type = - between.upper_inclusive ? 
ExpressionType::COMPARE_LESSTHANOREQUALTO : ExpressionType::COMPARE_LESSTHAN; - } - - // We cannot use an index scan. - if (equal_value.IsNull() && low_value.IsNull() && high_value.IsNull()) { - return nullptr; - } - - // Initialize the index scan state and return it. - if (!equal_value.IsNull()) { - // Equality predicate. - return InitializeScanSinglePredicate(equal_value, ExpressionType::COMPARE_EQUAL); - } - if (!low_value.IsNull() && !high_value.IsNull()) { - // Two-sided predicate. - return InitializeScanTwoPredicates(low_value, low_comparison_type, high_value, high_comparison_type); - } - if (!low_value.IsNull()) { - // Greater-than predicate. - return InitializeScanSinglePredicate(low_value, low_comparison_type); - } - // Less-than predicate. - return InitializeScanSinglePredicate(high_value, high_comparison_type); -} - -//===--------------------------------------------------------------------===// -// ART Keys -//===--------------------------------------------------------------------===// - -template -static void TemplatedGenerateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { - D_ASSERT(keys.size() >= count); - - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - auto input_data = UnifiedVectorFormat::GetData(data); - - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - if (IS_NOT_NULL || data.validity.RowIsValid(idx)) { - ARTKey::CreateARTKey(allocator, keys[i], input_data[idx]); - continue; - } - - // We need to reset the key value in the reusable keys vector. - keys[i] = ARTKey(); - } -} - -template -static void ConcatenateKeys(ArenaAllocator &allocator, Vector &input, idx_t count, unsafe_vector &keys) { - UnifiedVectorFormat data; - input.ToUnifiedFormat(count, data); - auto input_data = UnifiedVectorFormat::GetData(data); - - for (idx_t i = 0; i < count; i++) { - auto idx = data.sel->get_index(i); - - if (IS_NOT_NULL) { - auto other_key = ARTKey::CreateARTKey(allocator, input_data[idx]); - keys[i].Concat(allocator, other_key); - continue; - } - - // A previous column entry was NULL. - if (keys[i].Empty()) { - continue; - } - - // This column entry is NULL, so we set the whole key to NULL. - if (!data.validity.RowIsValid(idx)) { - keys[i] = ARTKey(); - continue; - } - - // Concatenate the keys.
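// [Editor's note] Worked example of the NULL semantics above for a compound key over
// (a, b): row (1, 'x') concatenates both encodings into one ARTKey; row (NULL, 'x')
// was already reset by TemplatedGenerateKeys, stays Empty(), and is skipped here;
// row (1, NULL) is reset to the empty key in the branch above. A single NULL in any
// column therefore NULLs the entire compound key, which the insert/delete/lookup
// loops later skip.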
- auto other_key = ARTKey::CreateARTKey(allocator, input_data[idx]); - keys[i].Concat(allocator, other_key); - } -} - -template -void GenerateKeysInternal(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { - switch (input.data[0].GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::INT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT8: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT16: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT32: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT64: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::UINT128: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::FLOAT: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::DOUBLE: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - case PhysicalType::VARCHAR: - TemplatedGenerateKeys(allocator, input.data[0], input.size(), keys); - break; - default: - throw InternalException("Invalid type for index"); - } - - // We concatenate the keys for each remaining column of a compound key. 
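// [Editor's sketch] Why ARTKey bytes can be compared with plain memcmp: numeric
// values are stored most-significant-byte first with the sign bit flipped, so byte
// order equals numeric order. Conceptual encoding for int32 (standard radix-key
// technique; not a verbatim copy of the ARTKey internals):
//
//   static void EncodeInt32(int32_t v, uint8_t out[4]) {
//       uint32_t u = static_cast<uint32_t>(v) ^ 0x80000000u; // flip the sign bit
//       out[0] = static_cast<uint8_t>(u >> 24);              // big-endian layout
//       out[1] = static_cast<uint8_t>(u >> 16);
//       out[2] = static_cast<uint8_t>(u >> 8);
//       out[3] = static_cast<uint8_t>(u);
//   }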
- for (idx_t i = 1; i < input.ColumnCount(); i++) { - switch (input.data[i].GetType().InternalType()) { - case PhysicalType::BOOL: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::INT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT8: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT16: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT32: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT64: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::UINT128: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::FLOAT: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::DOUBLE: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - case PhysicalType::VARCHAR: - ConcatenateKeys(allocator, input.data[i], input.size(), keys); - break; - default: - throw InternalException("Invalid type for index"); - } - } -} - -template <> -void ART::GenerateKeys<>(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { - GenerateKeysInternal(allocator, input, keys); -} - -template <> -void ART::GenerateKeys(ArenaAllocator &allocator, DataChunk &input, unsafe_vector &keys) { - GenerateKeysInternal(allocator, input, keys); -} - -void ART::GenerateKeyVectors(ArenaAllocator &allocator, DataChunk &input, Vector &row_ids, unsafe_vector &keys, - unsafe_vector &row_id_keys) { - GenerateKeys<>(allocator, input, keys); - - DataChunk row_id_chunk; - row_id_chunk.Initialize(Allocator::DefaultAllocator(), vector {LogicalType::ROW_TYPE}, input.size()); - row_id_chunk.data[0].Reference(row_ids); - row_id_chunk.SetCardinality(input.size()); - GenerateKeys<>(allocator, row_id_chunk, row_id_keys); -} - -//===--------------------------------------------------------------------===// -// Construct from sorted data. -//===--------------------------------------------------------------------===// - -bool ART::ConstructInternal(const unsafe_vector &keys, const unsafe_vector &row_ids, Node &node, - ARTKeySection §ion) { - D_ASSERT(section.start < keys.size()); - D_ASSERT(section.end < keys.size()); - D_ASSERT(section.start <= section.end); - - auto &start = keys[section.start]; - auto &end = keys[section.end]; - D_ASSERT(start.len != 0); - - // Increment the depth until we reach a leaf or find a mismatching byte. - auto prefix_depth = section.depth; - while (start.len != section.depth && start.ByteMatches(end, section.depth)) { - section.depth++; - } - - if (start.len == section.depth) { - // We reached a leaf. All the bytes of start_key and end_key match. 
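// [Editor's note] Worked example of the section recursion here: for sorted keys
// {ab, ab, ac}, the loop above advances depth across the shared byte 'a'; the bytes
// 'b' vs 'c' then split the section into two child sections of a new node. The 'b'
// child (keys 0-1) exhausts its key and becomes one leaf holding two row IDs, which
// is exactly the case that fails construction below when IsUnique() holds.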
- auto row_id_count = section.end - section.start + 1; - if (IsUnique() && row_id_count != 1) { - return false; - } - - reference ref(node); - auto count = UnsafeNumericCast(start.len - prefix_depth); - Prefix::New(*this, ref, start, prefix_depth, count); - if (row_id_count == 1) { - Leaf::New(ref, row_ids[section.start].GetRowId()); - } else { - Leaf::New(*this, ref, row_ids, section.start, row_id_count); - } - return true; - } - - // Create a new node and recurse. - unsafe_vector children; - section.GetChildSections(children, keys); - - // Create the prefix. - reference ref(node); - auto prefix_length = section.depth - prefix_depth; - Prefix::New(*this, ref, start, prefix_depth, prefix_length); - - // Create the node. - Node::New(*this, ref, Node::GetNodeType(children.size())); - for (auto &child : children) { - Node new_child; - auto success = ConstructInternal(keys, row_ids, new_child, child); - Node::InsertChild(*this, ref, child.key_byte, new_child); - if (!success) { - return false; - } - } - return true; -} - -bool ART::Construct(unsafe_vector &keys, unsafe_vector &row_ids, const idx_t row_count) { - ARTKeySection section(0, row_count - 1, 0, 0); - if (!ConstructInternal(keys, row_ids, tree, section)) { - return false; - } - -#ifdef DEBUG - unsafe_vector row_ids_debug; - Iterator it(*this); - it.FindMinimum(tree); - ARTKey empty_key = ARTKey(); - it.Scan(empty_key, NumericLimits().Maximum(), row_ids_debug, false); - D_ASSERT(row_count == row_ids_debug.size()); -#endif - return true; -} - -//===--------------------------------------------------------------------===// -// Insert and Constraint Checking -//===--------------------------------------------------------------------===// - -ErrorData ART::Insert(IndexLock &l, DataChunk &chunk, Vector &row_ids) { - return Insert(l, chunk, row_ids, nullptr); -} - -ErrorData ART::Insert(IndexLock &l, DataChunk &chunk, Vector &row_ids, optional_ptr delete_index) { - D_ASSERT(row_ids.GetType().InternalType() == ROW_TYPE); - auto row_count = chunk.size(); - - ArenaAllocator allocator(BufferAllocator::Get(db)); - unsafe_vector keys(row_count); - unsafe_vector row_id_keys(row_count); - GenerateKeyVectors(allocator, chunk, row_ids, keys, row_id_keys); - - optional_ptr delete_art; - if (delete_index) { - delete_art = delete_index->Cast(); - } - - auto conflict_type = ARTConflictType::NO_CONFLICT; - optional_idx conflict_idx; - auto was_empty = !tree.HasMetadata(); - - // Insert the entries into the index. - for (idx_t i = 0; i < row_count; i++) { - if (keys[i].Empty()) { - continue; - } - conflict_type = Insert(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus(), delete_art); - if (conflict_type != ARTConflictType::NO_CONFLICT) { - conflict_idx = i; - break; - } - } - - // Remove any previously inserted entries. - if (conflict_type != ARTConflictType::NO_CONFLICT) { - D_ASSERT(conflict_idx.IsValid()); - for (idx_t i = 0; i < conflict_idx.GetIndex(); i++) { - if (keys[i].Empty()) { - continue; - } - Erase(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus()); - } - } - - if (was_empty) { - // All nodes are in-memory. 
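// [Editor's sketch] Caller-side shape of the ErrorData protocol used by Insert above
// (hedged illustration; acquiring the IndexLock is not shown in this file):
//
//   auto error = art.Insert(lock, chunk, row_ids, delete_index);
//   if (error.HasError()) {
//       error.Throw(); // re-raises the TransactionException / ConstraintException
//   }
//
// Note the all-or-nothing behaviour: on the first conflict, the loop above erases
// every entry it had already inserted before the error is returned.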
- VerifyAllocationsInternal(); - } - - if (conflict_type == ARTConflictType::TRANSACTION) { - auto msg = AppendRowError(chunk, conflict_idx.GetIndex()); - return ErrorData(TransactionException("write-write conflict on key: \"%s\"", msg)); - } - - if (conflict_type == ARTConflictType::CONSTRAINT) { - auto msg = AppendRowError(chunk, conflict_idx.GetIndex()); - return ErrorData(ConstraintException("PRIMARY KEY or UNIQUE constraint violation: duplicate key \"%s\"", msg)); - } - -#ifdef DEBUG - for (idx_t i = 0; i < row_count; i++) { - if (keys[i].Empty()) { - continue; - } - D_ASSERT(Lookup(tree, keys[i], 0)); - } -#endif - return ErrorData(); -} - -ErrorData ART::Append(IndexLock &l, DataChunk &chunk, Vector &row_ids) { - // Execute all column expressions before inserting the data chunk. - DataChunk expr_chunk; - expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(chunk, expr_chunk); - - // Now insert the data chunk. - return Insert(l, expr_chunk, row_ids, nullptr); -} - -ErrorData ART::AppendWithDeleteIndex(IndexLock &l, DataChunk &chunk, Vector &row_ids, - optional_ptr delete_index) { - // Execute all column expressions before inserting the data chunk. - DataChunk expr_chunk; - expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(chunk, expr_chunk); - - // Now insert the data chunk. - return Insert(l, expr_chunk, row_ids, delete_index); -} - -void ART::VerifyAppend(DataChunk &chunk, optional_ptr delete_index, optional_ptr manager) { - if (manager) { - D_ASSERT(manager->LookupType() == VerifyExistenceType::APPEND); - return VerifyConstraint(chunk, delete_index, *manager); - } - ConflictManager local_manager(VerifyExistenceType::APPEND, chunk.size()); - VerifyConstraint(chunk, delete_index, local_manager); -} - -void ART::InsertIntoEmpty(Node &node, const ARTKey &key, const idx_t depth, const ARTKey &row_id, - const GateStatus status) { - D_ASSERT(depth <= key.len); - D_ASSERT(!node.HasMetadata()); - - if (status == GateStatus::GATE_SET) { - Leaf::New(node, row_id.GetRowId()); - return; - } - - reference ref(node); - auto count = key.len - depth; - - Prefix::New(*this, ref, key, depth, count); - Leaf::New(ref, row_id.GetRowId()); -} - -ARTConflictType ART::InsertIntoInlined(Node &node, const ARTKey &key, const idx_t depth, const ARTKey &row_id, - const GateStatus status, optional_ptr delete_art) { - - if (!IsUnique() || append_mode == ARTAppendMode::INSERT_DUPLICATES) { - Leaf::InsertIntoInlined(*this, node, row_id, depth, status); - return ARTConflictType::NO_CONFLICT; - } - - if (!delete_art) { - if (append_mode == ARTAppendMode::IGNORE_DUPLICATES) { - return ARTConflictType::NO_CONFLICT; - } - return ARTConflictType::CONSTRAINT; - } - - // Lookup in the delete_art. - auto delete_leaf = delete_art->Lookup(delete_art->tree, key, 0); - if (!delete_leaf) { - return ARTConflictType::CONSTRAINT; - } - - // The row ID has changed. - // Thus, the local index has a newer (local) row ID, and this is a constraint violation. - D_ASSERT(delete_leaf->GetType() == NType::LEAF_INLINED); - auto deleted_row_id = delete_leaf->GetRowId(); - auto this_row_id = node.GetRowId(); - if (deleted_row_id != this_row_id) { - return ARTConflictType::CONSTRAINT; - } - - // The deleted key and its row ID match the current key and its row ID. 
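// [Editor's note] Worked example of the case decided here: inside one transaction,
//   DELETE FROM t WHERE pk = 42;     -- row ID r1 is recorded in the delete ART
//   INSERT INTO t VALUES (42, ...);  -- same key, new row ID r2
// The main ART still holds (42 -> r1). Because the delete ART maps 42 to that same
// r1, the apparent duplicate is a legal transaction-local replacement and the insert
// below proceeds; any other row ID is a genuine constraint violation.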
- Leaf::InsertIntoInlined(*this, node, row_id, depth, status); - return ARTConflictType::NO_CONFLICT; -} - -ARTConflictType ART::InsertIntoNode(Node &node, const ARTKey &key, const idx_t depth, const ARTKey &row_id, - const GateStatus status, optional_ptr delete_art) { - D_ASSERT(depth < key.len); - auto child = node.GetChildMutable(*this, key[depth]); - - // Recurse, if a child exists at key[depth]. - if (child) { - D_ASSERT(child->HasMetadata()); - auto conflict_type = Insert(*child, key, depth + 1, row_id, status, delete_art); - node.ReplaceChild(*this, key[depth], *child); - return conflict_type; - } - - // Create an inlined prefix at key[depth]. - if (status == GateStatus::GATE_SET) { - Node remainder; - auto byte = key[depth]; - auto conflict_type = Insert(remainder, key, depth + 1, row_id, status, delete_art); - Node::InsertChild(*this, node, byte, remainder); - return conflict_type; - } - - // Insert an inlined leaf at key[depth]. - Node leaf; - reference ref(leaf); - - // Create the prefix. - if (depth + 1 < key.len) { - auto count = key.len - depth - 1; - Prefix::New(*this, ref, key, depth + 1, count); - } - - // Create the inlined leaf. - Leaf::New(ref, row_id.GetRowId()); - Node::InsertChild(*this, node, key[depth], leaf); - return ARTConflictType::NO_CONFLICT; -} - -ARTConflictType ART::Insert(Node &node, const ARTKey &key, idx_t depth, const ARTKey &row_id, const GateStatus status, - optional_ptr delete_art) { - if (!node.HasMetadata()) { - InsertIntoEmpty(node, key, depth, row_id, status); - return ARTConflictType::NO_CONFLICT; - } - - // Enter a nested leaf. - if (status == GateStatus::GATE_NOT_SET && node.GetGateStatus() == GateStatus::GATE_SET) { - if (IsUnique()) { - // Unique indexes can contain duplicates if another transaction DELETEs and - // re-INSERTs the same key. In that case, the previous value must be kept - // alive until no other transaction depends on it anymore. - - // We restrict this transactionality to two-value leaves, so any subsequent - // incoming transaction must fail here. - return ARTConflictType::TRANSACTION; - } - return Insert(node, row_id, 0, row_id, GateStatus::GATE_SET, delete_art); - } - - auto type = node.GetType(); - switch (type) { - case NType::LEAF_INLINED: { - return InsertIntoInlined(node, key, depth, row_id, status, delete_art); - } - case NType::LEAF: { - Leaf::TransformToNested(*this, node); - return Insert(node, key, depth, row_id, status, delete_art); - } - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: { - // Row IDs are unique, so there are never any duplicate byte conflicts here.
- auto byte = key[Prefix::ROW_ID_COUNT]; - Node::InsertChild(*this, node, byte); - return ARTConflictType::NO_CONFLICT; - } - case NType::NODE_4: - case NType::NODE_16: - case NType::NODE_48: - case NType::NODE_256: - return InsertIntoNode(node, key, depth, row_id, status, delete_art); - case NType::PREFIX: - return Prefix::Insert(*this, node, key, depth, row_id, status, delete_art); - default: - throw InternalException("Invalid node type for ART::Insert."); - } -} - -//===--------------------------------------------------------------------===// -// Drop and Delete -//===--------------------------------------------------------------------===// - -void ART::CommitDrop(IndexLock &index_lock) { - for (auto &allocator : *allocators) { - allocator->Reset(); - } - tree.Clear(); -} - -void ART::Delete(IndexLock &state, DataChunk &input, Vector &row_ids) { - auto row_count = input.size(); - - DataChunk expr_chunk; - expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expr_chunk); - - ArenaAllocator allocator(BufferAllocator::Get(db)); - unsafe_vector keys(row_count); - unsafe_vector row_id_keys(row_count); - GenerateKeyVectors(allocator, expr_chunk, row_ids, keys, row_id_keys); - - for (idx_t i = 0; i < row_count; i++) { - if (keys[i].Empty()) { - continue; - } - Erase(tree, keys[i], 0, row_id_keys[i], tree.GetGateStatus()); - } - - if (!tree.HasMetadata()) { - // No more allocations. - VerifyAllocationsInternal(); - } - -#ifdef DEBUG - for (idx_t i = 0; i < row_count; i++) { - if (keys[i].Empty()) { - continue; - } - auto leaf = Lookup(tree, keys[i], 0); - if (leaf && leaf->GetType() == NType::LEAF_INLINED) { - D_ASSERT(leaf->GetRowId() != row_id_keys[i].GetRowId()); - } - } -#endif -} - -void ART::Erase(Node &node, reference key, idx_t depth, reference row_id, - GateStatus status) { - if (!node.HasMetadata()) { - return; - } - - // Traverse the prefix. - reference next(node); - if (next.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, next, key, depth); - - // Prefixes don't match: nothing to erase. - if (next.get().GetType() == NType::PREFIX && next.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { - return; - } - } - - // Delete the row ID from the leaf. - // This is the root node, which can be a leaf with possible prefix nodes. - if (next.get().GetType() == NType::LEAF_INLINED) { - if (next.get().GetRowId() == row_id.get().GetRowId()) { - Node::Free(*this, node); - } - return; - } - - // Transform a deprecated leaf. - if (next.get().GetType() == NType::LEAF) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - Leaf::TransformToNested(*this, next); - } - - // Enter a nested leaf. - if (status == GateStatus::GATE_NOT_SET && next.get().GetGateStatus() == GateStatus::GATE_SET) { - return Erase(next, row_id, 0, row_id, GateStatus::GATE_SET); - } - - D_ASSERT(depth < key.get().len); - if (next.get().IsLeafNode()) { - auto byte = key.get()[depth]; - if (next.get().HasByte(*this, byte)) { - Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); - } - return; - } - - auto child = next.get().GetChildMutable(*this, key.get()[depth]); - if (!child) { - // No child at the byte: nothing to erase. - return; - } - - // Transform a deprecated leaf. - if (child->GetType() == NType::LEAF) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - Leaf::TransformToNested(*this, *child); - } - - // Enter a nested leaf. 
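// [Editor's note] A "gate" (GATE_SET) marks the border of a nested leaf: past it,
// the indexed bytes are no longer column-key bytes but row-ID bytes, i.e. the
// subtree is a small secondary ART holding the row-ID set of one duplicated key.
// That is why the recursive calls at gates restart with (row_id, depth 0) as their key.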
- if (status == GateStatus::GATE_NOT_SET && child->GetGateStatus() == GateStatus::GATE_SET) { - Erase(*child, row_id, 0, row_id, GateStatus::GATE_SET); - if (!child->HasMetadata()) { - Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); - } else { - next.get().ReplaceChild(*this, key.get()[depth], *child); - } - return; - } - - auto temp_depth = depth + 1; - reference ref(*child); - - if (ref.get().GetType() == NType::PREFIX) { - Prefix::TraverseMutable(*this, ref, key, temp_depth); - - // Prefixes don't match: nothing to erase. - if (ref.get().GetType() == NType::PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { - return; - } - } - - if (ref.get().GetType() == NType::LEAF_INLINED) { - if (ref.get().GetRowId() == row_id.get().GetRowId()) { - Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); - } - return; - } - - // Recurse. - Erase(*child, key, depth + 1, row_id, status); - if (!child->HasMetadata()) { - Node::DeleteChild(*this, next, node, key.get()[depth], status, key.get()); - } else { - next.get().ReplaceChild(*this, key.get()[depth], *child); - } -} - -//===--------------------------------------------------------------------===// -// Point and range lookups -//===--------------------------------------------------------------------===// - -const unsafe_optional_ptr ART::Lookup(const Node &node, const ARTKey &key, idx_t depth) { - reference ref(node); - while (ref.get().HasMetadata()) { - - // Return the leaf. - if (ref.get().IsAnyLeaf() || ref.get().GetGateStatus() == GateStatus::GATE_SET) { - return unsafe_optional_ptr(ref.get()); - } - - // Traverse the prefix. - if (ref.get().GetType() == NType::PREFIX) { - Prefix::Traverse(*this, ref, key, depth); - if (ref.get().GetType() == NType::PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { - // Prefix mismatch, return nullptr. - return nullptr; - } - continue; - } - - // Get the child node. - D_ASSERT(depth < key.len); - auto child = ref.get().GetChild(*this, key[depth]); - - // No child at the matching byte, return nullptr. - if (!child) { - return nullptr; - } - - // Continue in the child. - ref = *child; - D_ASSERT(ref.get().HasMetadata()); - depth++; - } - - return nullptr; -} - -bool ART::SearchEqual(ARTKey &key, idx_t max_count, unsafe_vector &row_ids) { - auto leaf = Lookup(tree, key, 0); - if (!leaf) { - return true; - } - - Iterator it(*this); - it.FindMinimum(*leaf); - ARTKey empty_key = ARTKey(); - return it.Scan(empty_key, max_count, row_ids, false); -} - -bool ART::SearchGreater(ARTKey &key, bool equal, idx_t max_count, unsafe_vector &row_ids) { - if (!tree.HasMetadata()) { - return true; - } - - // Find the lowest value that satisfies the predicate. - Iterator it(*this); - - // Early-out, if the maximum value in the ART is lower than the lower bound. - if (!it.LowerBound(tree, key, equal, 0)) { - return true; - } - - // We continue the scan. We do not check the bounds as any value following this value is - // greater and satisfies our predicate. - return it.Scan(ARTKey(), max_count, row_ids, false); -} - -bool ART::SearchLess(ARTKey &upper_bound, bool equal, idx_t max_count, unsafe_vector &row_ids) { - if (!tree.HasMetadata()) { - return true; - } - - // Find the minimum value in the ART: we start scanning from this value. - Iterator it(*this); - it.FindMinimum(tree); - - // Early-out, if the minimum value is higher than the upper bound. 
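// [Editor's sketch] All the range entry points here compose the same two primitives:
// position an Iterator, then Scan toward a bound. An inclusive two-sided scan, in
// essence (hedged; assumes the same access to `tree` these member functions have):
//
//   unsafe_vector row_ids;
//   Iterator it(*this);
//   if (it.LowerBound(tree, lower_bound, true, 0)) {     // first key >= lower_bound
//       it.Scan(upper_bound, max_count, row_ids, true);  // stop once key > upper_bound
//   }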
- if (it.current_key.GreaterThan(upper_bound, equal, it.GetNestedDepth())) { - return true; - } - - // Continue the scan until we reach the upper bound. - return it.Scan(upper_bound, max_count, row_ids, equal); -} - -bool ART::SearchCloseRange(ARTKey &lower_bound, ARTKey &upper_bound, bool left_equal, bool right_equal, idx_t max_count, - unsafe_vector &row_ids) { - // Find the first node that satisfies the left predicate. - Iterator it(*this); - - // Early-out, if the maximum value in the ART is lower than the lower bound. - if (!it.LowerBound(tree, lower_bound, left_equal, 0)) { - return true; - } - - // Continue the scan until we reach the upper bound. - return it.Scan(upper_bound, max_count, row_ids, right_equal); -} - -bool ART::Scan(IndexScanState &state, const idx_t max_count, unsafe_vector &row_ids) { - auto &scan_state = state.Cast(); - D_ASSERT(scan_state.values[0].type().InternalType() == types[0]); - ArenaAllocator arena_allocator(Allocator::Get(db)); - auto key = ARTKey::CreateKey(arena_allocator, types[0], scan_state.values[0]); - - if (scan_state.values[1].IsNull()) { - // Single predicate. - lock_guard l(lock); - switch (scan_state.expressions[0]) { - case ExpressionType::COMPARE_EQUAL: - return SearchEqual(key, max_count, row_ids); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return SearchGreater(key, true, max_count, row_ids); - case ExpressionType::COMPARE_GREATERTHAN: - return SearchGreater(key, false, max_count, row_ids); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return SearchLess(key, true, max_count, row_ids); - case ExpressionType::COMPARE_LESSTHAN: - return SearchLess(key, false, max_count, row_ids); - default: - throw InternalException("Index scan type not implemented"); - } - } - - // Two predicates. - lock_guard l(lock); - D_ASSERT(scan_state.values[1].type().InternalType() == types[0]); - auto upper_bound = ARTKey::CreateKey(arena_allocator, types[0], scan_state.values[1]); - bool left_equal = scan_state.expressions[0] == ExpressionType ::COMPARE_GREATERTHANOREQUALTO; - bool right_equal = scan_state.expressions[1] == ExpressionType ::COMPARE_LESSTHANOREQUALTO; - return SearchCloseRange(key, upper_bound, left_equal, right_equal, max_count, row_ids); -} - -//===--------------------------------------------------------------------===// -// More Constraint Checking -//===--------------------------------------------------------------------===// - -string ART::GenerateErrorKeyName(DataChunk &input, idx_t row_idx) { - DataChunk expr_chunk; - expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(input, expr_chunk); - - string key_name; - for (idx_t k = 0; k < expr_chunk.ColumnCount(); k++) { - if (k > 0) { - key_name += ", "; - } - key_name += unbound_expressions[k]->GetName() + ": " + expr_chunk.data[k].GetValue(row_idx).ToString(); - } - return key_name; -} - -string ART::GenerateConstraintErrorMessage(VerifyExistenceType verify_type, const string &key_name) { - switch (verify_type) { - case VerifyExistenceType::APPEND: { - // APPEND to PK/UNIQUE table, but node/key already exists in PK/UNIQUE table. - string type = IsPrimary() ? "primary key" : "unique"; - return StringUtil::Format("Duplicate key \"%s\" violates %s constraint.", key_name, type); - } - case VerifyExistenceType::APPEND_FK: { - // APPEND_FK to FK table, node/key does not exist in PK/UNIQUE table. 
-		return StringUtil::Format(
-		    "Violates foreign key constraint because key \"%s\" does not exist in the referenced table", key_name);
-	}
-	case VerifyExistenceType::DELETE_FK: {
-		// DELETE_FK that still exists in a FK table, i.e., not a valid delete.
-		return StringUtil::Format("Violates foreign key constraint because key \"%s\" is still referenced by a foreign "
-		                          "key in a different table",
-		                          key_name);
-	}
-	default:
-		throw NotImplementedException("Type not implemented for VerifyExistenceType");
-	}
-}
-
-void ART::VerifyLeaf(const Node &leaf, const ARTKey &key, optional_ptr<ART> delete_art, ConflictManager &manager,
-                     optional_idx &conflict_idx, idx_t i) {
-	// Fast path, the leaf is inlined, and the delete ART does not exist.
-	if (leaf.GetType() == NType::LEAF_INLINED && !delete_art) {
-		if (manager.AddHit(i, leaf.GetRowId())) {
-			conflict_idx = i;
-		}
-		return;
-	}
-
-	// Get the delete_leaf.
-	// All leaves in the delete ART are inlined.
-	// The delete ART may not exist here, so guard the lookup instead of
-	// dereferencing a null pointer.
-	auto deleted_leaf = delete_art ? delete_art->Lookup(delete_art->tree, key, 0) : nullptr;
-
-	// The leaf is inlined, and the same key does not exist in the delete ART.
-	if (leaf.GetType() == NType::LEAF_INLINED && !deleted_leaf) {
-		if (manager.AddHit(i, leaf.GetRowId())) {
-			conflict_idx = i;
-		}
-		return;
-	}
-
-	// The leaf is inlined, and the same key exists in the delete ART.
-	if (leaf.GetType() == NType::LEAF_INLINED && deleted_leaf) {
-		auto deleted_row_id = deleted_leaf->GetRowId();
-		auto this_row_id = leaf.GetRowId();
-
-		if (deleted_row_id == this_row_id) {
-			if (manager.AddMiss(i)) {
-				conflict_idx = i;
-			}
-			return;
-		}
-
-		if (manager.AddHit(i, this_row_id)) {
-			conflict_idx = i;
-		}
-		return;
-	}
-
-	// Scan the first two row IDs in the leaf.
-	Iterator it(*this);
-	it.FindMinimum(leaf);
-	ARTKey empty_key = ARTKey();
-	unsafe_vector<row_t> row_ids;
-	it.Scan(empty_key, 2, row_ids, false);
-
-	if (!deleted_leaf) {
-		// Report both row IDs, not the first one twice.
-		if (manager.AddHit(i, row_ids[0]) || manager.AddHit(i, row_ids[1])) {
-			conflict_idx = i;
-		}
-		return;
-	}
-
-	auto deleted_row_id = deleted_leaf->GetRowId();
-
-	if (deleted_row_id == row_ids[0] || deleted_row_id == row_ids[1]) {
-		if (manager.AddMiss(i)) {
-			conflict_idx = i;
-		}
-		return;
-	}
-
-	if (manager.AddHit(i, row_ids[0]) || manager.AddHit(i, row_ids[1])) {
-		conflict_idx = i;
-	}
-}
-
-void ART::VerifyConstraint(DataChunk &chunk, optional_ptr<BoundIndex> delete_index, ConflictManager &manager) {
-	// Lock the index during constraint checking.
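-	// The lookups below must not race with concurrent appends or vacuums, which
-	// could move or free the buffers that the probed nodes live in.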
- lock_guard l(lock); - - DataChunk expr_chunk; - expr_chunk.Initialize(Allocator::DefaultAllocator(), logical_types); - ExecuteExpressions(chunk, expr_chunk); - - ArenaAllocator arena_allocator(BufferAllocator::Get(db)); - unsafe_vector keys(expr_chunk.size()); - GenerateKeys<>(arena_allocator, expr_chunk, keys); - - optional_ptr delete_art; - if (delete_index) { - delete_art = delete_index->Cast(); - } - - optional_idx conflict_idx; - for (idx_t i = 0; !conflict_idx.IsValid() && i < chunk.size(); i++) { - if (keys[i].Empty()) { - if (manager.AddNull(i)) { - conflict_idx = i; - } - continue; - } - - auto leaf = Lookup(tree, keys[i], 0); - if (!leaf) { - if (manager.AddMiss(i)) { - conflict_idx = i; - } - continue; - } - VerifyLeaf(*leaf, keys[i], delete_art, manager, conflict_idx, i); - } - - manager.FinishLookup(); - if (!conflict_idx.IsValid()) { - return; - } - - auto key_name = GenerateErrorKeyName(chunk, conflict_idx.GetIndex()); - auto exception_msg = GenerateConstraintErrorMessage(manager.LookupType(), key_name); - throw ConstraintException(exception_msg); -} - -string ART::GetConstraintViolationMessage(VerifyExistenceType verify_type, idx_t failed_index, DataChunk &input) { - auto key_name = GenerateErrorKeyName(input, failed_index); - auto exception_msg = GenerateConstraintErrorMessage(verify_type, key_name); - return exception_msg; -} - -//===--------------------------------------------------------------------===// -// Storage and Memory -//===--------------------------------------------------------------------===// - -void ART::TransformToDeprecated() { - auto idx = Node::GetAllocatorIdx(NType::PREFIX); - auto &block_manager = (*allocators)[idx]->block_manager; - unsafe_unique_ptr deprecated_allocator; - - if (prefix_count != Prefix::DEPRECATED_COUNT) { - auto prefix_size = NumericCast(Prefix::DEPRECATED_COUNT) + NumericCast(Prefix::METADATA_SIZE); - deprecated_allocator = make_unsafe_uniq(prefix_size, block_manager); - } - - // Transform all leaves, and possibly the prefixes. - if (tree.HasMetadata()) { - Node::TransformToDeprecated(*this, tree, deprecated_allocator); - } - - // Replace the prefix allocator with the deprecated allocator. - if (deprecated_allocator) { - prefix_count = Prefix::DEPRECATED_COUNT; - - D_ASSERT((*allocators)[idx]->IsEmpty()); - (*allocators)[idx]->Reset(); - (*allocators)[idx] = std::move(deprecated_allocator); - } -} - -IndexStorageInfo ART::GetStorageInfo(const case_insensitive_map_t &options, const bool to_wal) { - // If the storage format uses deprecated leaf storage, - // then we need to transform all nested leaves before serialization. - auto v1_0_0_option = options.find("v1_0_0_storage"); - bool v1_0_0_storage = v1_0_0_option == options.end() || v1_0_0_option->second != Value(false); - if (v1_0_0_storage) { - TransformToDeprecated(); - } - - IndexStorageInfo info(name); - info.root = tree.Get(); - info.options = options; - - for (auto &allocator : *allocators) { - allocator->RemoveEmptyBuffers(); - } - -#ifdef DEBUG - if (v1_0_0_storage) { - D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_7_LEAF)]->IsEmpty()); - D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_15_LEAF)]->IsEmpty()); - D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::NODE_256_LEAF)]->IsEmpty()); - D_ASSERT((*allocators)[Node::GetAllocatorIdx(NType::PREFIX)]->GetSegmentSize() == - Prefix::DEPRECATED_COUNT + Prefix::METADATA_SIZE); - } -#endif - - auto allocator_count = v1_0_0_storage ? 
DEPRECATED_ALLOCATOR_COUNT : ALLOCATOR_COUNT; - if (!to_wal) { - // Store the data on disk as partial blocks and set the block ids. - WritePartialBlocks(v1_0_0_storage); - - } else { - // Set the correct allocation sizes and get the map containing all buffers. - for (idx_t i = 0; i < allocator_count; i++) { - info.buffers.push_back((*allocators)[i]->InitSerializationToWAL()); - } - } - - for (idx_t i = 0; i < allocator_count; i++) { - info.allocator_infos.push_back((*allocators)[i]->GetInfo()); - } - return info; -} - -void ART::WritePartialBlocks(const bool v1_0_0_storage) { - auto &block_manager = table_io_manager.GetIndexBlockManager(); - PartialBlockManager partial_block_manager(block_manager, PartialBlockType::FULL_CHECKPOINT); - - idx_t allocator_count = v1_0_0_storage ? DEPRECATED_ALLOCATOR_COUNT : ALLOCATOR_COUNT; - for (idx_t i = 0; i < allocator_count; i++) { - (*allocators)[i]->SerializeBuffers(partial_block_manager); - } - partial_block_manager.FlushPartialBlocks(); -} - -void ART::InitAllocators(const IndexStorageInfo &info) { - for (idx_t i = 0; i < info.allocator_infos.size(); i++) { - (*allocators)[i]->Init(info.allocator_infos[i]); - } -} - -void ART::Deserialize(const BlockPointer &pointer) { - D_ASSERT(pointer.IsValid()); - - auto &metadata_manager = table_io_manager.GetMetadataManager(); - MetadataReader reader(metadata_manager, pointer); - tree = reader.Read(); - - for (idx_t i = 0; i < DEPRECATED_ALLOCATOR_COUNT; i++) { - (*allocators)[i]->Deserialize(metadata_manager, reader.Read()); - } -} - -void ART::SetPrefixCount(const IndexStorageInfo &info) { - auto numeric_max = NumericLimits().Maximum(); - auto max_aligned = AlignValueFloor(numeric_max - Prefix::METADATA_SIZE); - - if (info.IsValid() && info.root_block_ptr.IsValid()) { - prefix_count = Prefix::DEPRECATED_COUNT; - return; - } - - if (info.IsValid()) { - auto serialized_count = info.allocator_infos[0].segment_size - Prefix::METADATA_SIZE; - prefix_count = NumericCast(serialized_count); - return; - } - - if (!IsUnique()) { - prefix_count = Prefix::ROW_ID_COUNT; - return; - } - - idx_t compound_size = 0; - for (const auto &type : types) { - compound_size += GetTypeIdSize(type); - } - - auto aligned = AlignValue(compound_size) - 1; - if (aligned > NumericCast(max_aligned)) { - prefix_count = max_aligned; - return; - } - - prefix_count = NumericCast(aligned); -} - -idx_t ART::GetInMemorySize(IndexLock &index_lock) { - D_ASSERT(owns_data); - - idx_t in_memory_size = 0; - for (auto &allocator : *allocators) { - in_memory_size += allocator->GetInMemorySize(); - } - return in_memory_size; -} - -//===--------------------------------------------------------------------===// -// Vacuum -//===--------------------------------------------------------------------===// - -void ART::InitializeVacuum(unordered_set &indexes) { - for (idx_t i = 0; i < allocators->size(); i++) { - if ((*allocators)[i]->InitializeVacuum()) { - indexes.insert(NumericCast(i)); - } - } -} - -void ART::FinalizeVacuum(const unordered_set &indexes) { - for (const auto &idx : indexes) { - (*allocators)[idx]->FinalizeVacuum(); - } -} - -void ART::Vacuum(IndexLock &state) { - D_ASSERT(owns_data); - - if (!tree.HasMetadata()) { - for (auto &allocator : *allocators) { - allocator->Reset(); - } - return; - } - - // True, if an allocator needs a vacuum, false otherwise. - unordered_set indexes; - InitializeVacuum(indexes); - - // Skip vacuum, if no allocators require it. 
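-	// indexes holds the positions of exactly those allocators that reported a
-	// worthwhile vacuum in InitializeVacuum. E.g., if only the prefix allocator
-	// qualifies, then tree.Vacuum below only moves prefix nodes.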
- if (indexes.empty()) { - return; - } - - // Traverse the allocated memory of the tree to perform a vacuum. - tree.Vacuum(*this, indexes); - - // Finalize the vacuum operation. - FinalizeVacuum(indexes); -} - -//===--------------------------------------------------------------------===// -// Merging -//===--------------------------------------------------------------------===// - -void ART::InitializeMerge(unsafe_vector &upper_bounds) { - D_ASSERT(owns_data); - for (auto &allocator : *allocators) { - upper_bounds.emplace_back(allocator->GetUpperBoundBufferId()); - } -} - -bool ART::MergeIndexes(IndexLock &state, BoundIndex &other_index) { - auto &other_art = other_index.Cast(); - if (!other_art.tree.HasMetadata()) { - return true; - } - - if (other_art.owns_data) { - if (tree.HasMetadata()) { - // Fully deserialize other_index, and traverse it to increment its buffer IDs. - unsafe_vector upper_bounds; - InitializeMerge(upper_bounds); - other_art.tree.InitMerge(other_art, upper_bounds); - } - - // Merge the node storage. - for (idx_t i = 0; i < allocators->size(); i++) { - (*allocators)[i]->Merge(*(*other_art.allocators)[i]); - } - } - - // Merge the ARTs. - D_ASSERT(tree.GetGateStatus() == other_art.tree.GetGateStatus()); - if (!tree.Merge(*this, other_art.tree, tree.GetGateStatus())) { - return false; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Verification -//===--------------------------------------------------------------------===// - -string ART::VerifyAndToString(IndexLock &state, const bool only_verify) { - return VerifyAndToStringInternal(only_verify); -} - -string ART::VerifyAndToStringInternal(const bool only_verify) { - if (tree.HasMetadata()) { - return "ART: " + tree.VerifyAndToString(*this, only_verify); - } - return "[empty]"; -} - -void ART::VerifyAllocations(IndexLock &state) { - return VerifyAllocationsInternal(); -} - -void ART::VerifyAllocationsInternal() { -#ifdef DEBUG - unordered_map node_counts; - for (idx_t i = 0; i < allocators->size(); i++) { - node_counts[NumericCast(i)] = 0; - } - - if (tree.HasMetadata()) { - tree.VerifyAllocations(*this, node_counts); - } - - for (idx_t i = 0; i < allocators->size(); i++) { - auto segment_count = (*allocators)[i]->GetSegmentCount(); - D_ASSERT(segment_count == node_counts[NumericCast(i)]); - } -#endif -} - -constexpr const char *ART::TYPE_NAME; - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/art_key.cpp b/src/duckdb/src/execution/index/art/art_key.cpp deleted file mode 100644 index d5769f0f5..000000000 --- a/src/duckdb/src/execution/index/art/art_key.cpp +++ /dev/null @@ -1,182 +0,0 @@ -#include "duckdb/execution/index/art/art_key.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// ARTKey -//===--------------------------------------------------------------------===// - -ARTKey::ARTKey() : len(0) { -} - -ARTKey::ARTKey(const data_ptr_t data, idx_t len) : len(len), data(data) { -} - -ARTKey::ARTKey(ArenaAllocator &allocator, idx_t len) : len(len) { - data = allocator.Allocate(len); -} - -template <> -ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, string_t value) { - auto string_data = const_data_ptr_cast(value.GetData()); - auto string_len = value.GetSize(); - - // We escape \00 and \01. 
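-	// For example, the bytes {'a', \00, 'b'} become {'a', \01, \00, 'b', \00}:
-	// each \00 or \01 is preceded by the escape byte \01, and the key ends with
-	// a \00 terminator, so no stored key can be a proper prefix of another key.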
- idx_t escape_count = 0; - for (idx_t i = 0; i < string_len; i++) { - if (string_data[i] <= 1) { - escape_count++; - } - } - - idx_t len = string_len + escape_count + 1; - auto data = allocator.Allocate(len); - - // Copy over the data and add escapes. - idx_t pos = 0; - for (idx_t i = 0; i < string_len; i++) { - if (string_data[i] <= 1) { - // Add escape. - data[pos++] = '\01'; - } - data[pos++] = string_data[i]; - } - - // End with a null-terminator. - data[pos] = '\0'; - return ARTKey(data, len); -} - -template <> -ARTKey ARTKey::CreateARTKey(ArenaAllocator &allocator, const char *value) { - return ARTKey::CreateARTKey(allocator, string_t(value, UnsafeNumericCast(strlen(value)))); -} - -template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, ARTKey &key, string_t value) { - key = ARTKey::CreateARTKey(allocator, value); -} - -template <> -void ARTKey::CreateARTKey(ArenaAllocator &allocator, ARTKey &key, const char *value) { - ARTKey::CreateARTKey(allocator, key, string_t(value, UnsafeNumericCast(strlen(value)))); -} - -ARTKey ARTKey::CreateKey(ArenaAllocator &allocator, PhysicalType type, Value &value) { - D_ASSERT(type == value.type().InternalType()); - switch (type) { - case PhysicalType::BOOL: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::INT8: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::INT16: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::INT32: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::INT64: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::UINT8: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::UINT16: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::UINT32: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::UINT64: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::INT128: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::UINT128: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::FLOAT: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::DOUBLE: - return ARTKey::CreateARTKey(allocator, value); - case PhysicalType::VARCHAR: - return ARTKey::CreateARTKey(allocator, value); - default: - throw InternalException("Invalid type for the ART key."); - } -} - -bool ARTKey::operator>(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(len, key.len); i++) { - if (data[i] > key.data[i]) { - return true; - } else if (data[i] < key.data[i]) { - return false; - } - } - return len > key.len; -} - -bool ARTKey::operator>=(const ARTKey &key) const { - for (idx_t i = 0; i < MinValue(len, key.len); i++) { - if (data[i] > key.data[i]) { - return true; - } else if (data[i] < key.data[i]) { - return false; - } - } - return len >= key.len; -} - -bool ARTKey::operator==(const ARTKey &key) const { - if (len != key.len) { - return false; - } - for (idx_t i = 0; i < len; i++) { - if (data[i] != key.data[i]) { - return false; - } - } - return true; -} - -void ARTKey::Concat(ArenaAllocator &allocator, const ARTKey &other) { - auto compound_data = allocator.Allocate(len + other.len); - memcpy(compound_data, data, len); - memcpy(compound_data + len, other.data, other.len); - len += other.len; - data = compound_data; -} - -row_t ARTKey::GetRowId() const { - D_ASSERT(len == sizeof(row_t)); - return Radix::DecodeData(data); -} - -idx_t ARTKey::GetMismatchPos(const ARTKey &other, const idx_t start) const { - D_ASSERT(len <= other.len); - 
D_ASSERT(start <= len);
-	for (idx_t i = start; i < other.len; i++) {
-		if (data[i] != other.data[i]) {
-			return i;
-		}
-	}
-	return DConstants::INVALID_INDEX;
-}
-
-//===--------------------------------------------------------------------===//
-// ARTKeySection
-//===--------------------------------------------------------------------===//
-
-ARTKeySection::ARTKeySection(idx_t start, idx_t end, idx_t depth, data_t byte)
-    : start(start), end(end), depth(depth), key_byte(byte) {
-}
-
-ARTKeySection::ARTKeySection(idx_t start, idx_t end, const unsafe_vector<ARTKey> &keys, const ARTKeySection &section)
-    : start(start), end(end), depth(section.depth + 1), key_byte(keys[end].data[section.depth]) {
-}
-
-void ARTKeySection::GetChildSections(unsafe_vector<ARTKeySection> &sections, const unsafe_vector<ARTKey> &keys) {
-	auto child_idx = start;
-	for (idx_t i = start + 1; i <= end; i++) {
-		if (keys[i - 1].data[depth] != keys[i].data[depth]) {
-			sections.emplace_back(child_idx, i - 1, keys, *this);
-			child_idx = i;
-		}
-	}
-	sections.emplace_back(child_idx, end, keys, *this);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/index/art/base_leaf.cpp b/src/duckdb/src/execution/index/art/base_leaf.cpp
deleted file mode 100644
index 5034a0399..000000000
--- a/src/duckdb/src/execution/index/art/base_leaf.cpp
+++ /dev/null
@@ -1,168 +0,0 @@
-#include "duckdb/execution/index/art/base_leaf.hpp"
-
-#include "duckdb/execution/index/art/art_key.hpp"
-#include "duckdb/execution/index/art/base_node.hpp"
-#include "duckdb/execution/index/art/leaf.hpp"
-#include "duckdb/execution/index/art/prefix.hpp"
-#include "duckdb/execution/index/art/node256_leaf.hpp"
-
-namespace duckdb {
-
-//===--------------------------------------------------------------------===//
-// BaseLeaf
-//===--------------------------------------------------------------------===//
-
-template
-void BaseLeaf::InsertByteInternal(BaseLeaf &n, const uint8_t byte) {
-	// Still space. Insert the child.
-	uint8_t child_pos = 0;
-	while (child_pos < n.count && n.key[child_pos] < byte) {
-		child_pos++;
-	}
-
-	// Move children backwards to make space.
-	for (uint8_t i = n.count; i > child_pos; i--) {
-		n.key[i] = n.key[i - 1];
-	}
-
-	n.key[child_pos] = byte;
-	n.count++;
-}
-
-template
-BaseLeaf &BaseLeaf::DeleteByteInternal(ART &art, Node &node, const uint8_t byte) {
-	auto &n = Node::Ref(art, node, node.GetType());
-	uint8_t child_pos = 0;
-
-	for (; child_pos < n.count; child_pos++) {
-		if (n.key[child_pos] == byte) {
-			break;
-		}
-	}
-	n.count--;
-
-	// Possibly move children backwards.
-	for (uint8_t i = child_pos; i < n.count; i++) {
-		n.key[i] = n.key[i + 1];
-	}
-	return n;
-}
-
-//===--------------------------------------------------------------------===//
-// Node7Leaf
-//===--------------------------------------------------------------------===//
-
-void Node7Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) {
-	// The node is full. Grow to Node15.
-	auto &n7 = Node::Ref(art, node, NODE_7_LEAF);
-	if (n7.count == CAPACITY) {
-		auto node7 = node;
-		Node15Leaf::GrowNode7Leaf(art, node, node7);
-		Node15Leaf::InsertByte(art, node, byte);
-		return;
-	}
-
-	// Still space. Insert the child.
-	// InsertByteInternal computes the insert position itself.
-	InsertByteInternal(n7, byte);
-}
-
-void Node7Leaf::DeleteByte(ART &art, Node &node, Node &prefix, const uint8_t byte, const ARTKey &row_id) {
-	auto &n7 = DeleteByteInternal(art, node, byte);
-
-	// Compress one-way nodes.
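-	// All row IDs below this node share their leading bytes with the deleted
-	// row ID. So, if a single byte remains, the surviving row ID equals the
-	// deleted row ID with its last byte swapped for the remaining n7.key[0],
-	// and the node collapses back into an inlined leaf.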
- if (n7.count == 1) { - D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); - - // Get the remaining row ID. - auto remainder = UnsafeNumericCast(row_id.GetRowId()) & AND_LAST_BYTE; - remainder |= UnsafeNumericCast(n7.key[0]); - - n7.count--; - Node::Free(art, node); - - if (prefix.GetType() == NType::PREFIX) { - Node::Free(art, prefix); - Leaf::New(prefix, UnsafeNumericCast(remainder)); - } else { - Leaf::New(node, UnsafeNumericCast(remainder)); - } - } -} - -void Node7Leaf::ShrinkNode15Leaf(ART &art, Node &node7_leaf, Node &node15_leaf) { - auto &n7 = New(art, node7_leaf); - auto &n15 = Node::Ref(art, node15_leaf, NType::NODE_15_LEAF); - node7_leaf.SetGateStatus(node15_leaf.GetGateStatus()); - - n7.count = n15.count; - for (uint8_t i = 0; i < n15.count; i++) { - n7.key[i] = n15.key[i]; - } - - n15.count = 0; - Node::Free(art, node15_leaf); -} - -//===--------------------------------------------------------------------===// -// Node15Leaf -//===--------------------------------------------------------------------===// - -void Node15Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) { - // The node is full. Grow to Node256Leaf. - auto &n15 = Node::Ref(art, node, NODE_15_LEAF); - if (n15.count == CAPACITY) { - auto node15 = node; - Node256Leaf::GrowNode15Leaf(art, node, node15); - Node256Leaf::InsertByte(art, node, byte); - return; - } - - InsertByteInternal(n15, byte); -} - -void Node15Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) { - auto &n15 = DeleteByteInternal(art, node, byte); - - // Shrink node to Node7. - if (n15.count < Node7Leaf::CAPACITY) { - auto node15 = node; - Node7Leaf::ShrinkNode15Leaf(art, node, node15); - } -} - -void Node15Leaf::GrowNode7Leaf(ART &art, Node &node15_leaf, Node &node7_leaf) { - auto &n7 = Node::Ref(art, node7_leaf, NType::NODE_7_LEAF); - auto &n15 = New(art, node15_leaf); - node15_leaf.SetGateStatus(node7_leaf.GetGateStatus()); - - n15.count = n7.count; - for (uint8_t i = 0; i < n7.count; i++) { - n15.key[i] = n7.key[i]; - } - - n7.count = 0; - Node::Free(art, node7_leaf); -} - -void Node15Leaf::ShrinkNode256Leaf(ART &art, Node &node15_leaf, Node &node256_leaf) { - auto &n15 = New(art, node15_leaf); - auto &n256 = Node::Ref(art, node256_leaf, NType::NODE_256_LEAF); - node15_leaf.SetGateStatus(node256_leaf.GetGateStatus()); - - ValidityMask mask(&n256.mask[0], Node256::CAPACITY); - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (mask.RowIsValid(i)) { - n15.key[n15.count] = UnsafeNumericCast(i); - n15.count++; - } - } - - Node::Free(art, node256_leaf); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/base_node.cpp b/src/duckdb/src/execution/index/art/base_node.cpp deleted file mode 100644 index 228d12e80..000000000 --- a/src/duckdb/src/execution/index/art/base_node.cpp +++ /dev/null @@ -1,163 +0,0 @@ -#include "duckdb/execution/index/art/base_node.hpp" - -#include "duckdb/execution/index/art/leaf.hpp" -#include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/execution/index/art/prefix.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// BaseNode -//===--------------------------------------------------------------------===// - -template -void BaseNode::InsertChildInternal(BaseNode &n, const uint8_t byte, const Node child) { - // Still space. Insert the child. - uint8_t child_pos = 0; - while (child_pos < n.count && n.key[child_pos] < byte) { - child_pos++; - } - - // Move children backwards to make space. 
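-	// For example, inserting byte 0x20 into keys {0x10, 0x30, 0x40} shifts 0x30
-	// and 0x40 one slot to the right and writes 0x20 to position 1, keeping the
-	// key and children arrays aligned and sorted by key byte.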
- for (uint8_t i = n.count; i > child_pos; i--) { - n.key[i] = n.key[i - 1]; - n.children[i] = n.children[i - 1]; - } - - n.key[child_pos] = byte; - n.children[child_pos] = child; - n.count++; -} - -template -BaseNode &BaseNode::DeleteChildInternal(ART &art, Node &node, const uint8_t byte) { - auto &n = Node::Ref(art, node, TYPE); - - uint8_t child_pos = 0; - for (; child_pos < n.count; child_pos++) { - if (n.key[child_pos] == byte) { - break; - } - } - - // Free the child and decrease the count. - Node::Free(art, n.children[child_pos]); - n.count--; - - // Possibly move children backwards. - for (uint8_t i = child_pos; i < n.count; i++) { - n.key[i] = n.key[i + 1]; - n.children[i] = n.children[i + 1]; - } - return n; -} - -//===--------------------------------------------------------------------===// -// Node4 -//===--------------------------------------------------------------------===// - -void Node4::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - // The node is full. Grow to Node16. - auto &n = Node::Ref(art, node, NODE_4); - if (n.count == CAPACITY) { - auto node4 = node; - Node16::GrowNode4(art, node, node4); - Node16::InsertChild(art, node, byte, child); - return; - } - - InsertChildInternal(n, byte, child); -} - -void Node4::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, const GateStatus status) { - auto &n = DeleteChildInternal(art, node, byte); - - // Compress one-way nodes. - if (n.count == 1) { - n.count--; - - auto child = n.children[0]; - auto remainder = n.key[0]; - auto old_status = node.GetGateStatus(); - - Node::Free(art, node); - Prefix::Concat(art, prefix, remainder, old_status, child, status); - } -} - -void Node4::ShrinkNode16(ART &art, Node &node4, Node &node16) { - auto &n4 = New(art, node4); - auto &n16 = Node::Ref(art, node16, NType::NODE_16); - node4.SetGateStatus(node16.GetGateStatus()); - - n4.count = n16.count; - for (uint8_t i = 0; i < n16.count; i++) { - n4.key[i] = n16.key[i]; - n4.children[i] = n16.children[i]; - } - - n16.count = 0; - Node::Free(art, node16); -} - -//===--------------------------------------------------------------------===// -// Node16 -//===--------------------------------------------------------------------===// - -void Node16::DeleteChild(ART &art, Node &node, const uint8_t byte) { - auto &n = DeleteChildInternal(art, node, byte); - - // Shrink node to Node4. - if (n.count < Node4::CAPACITY) { - auto node16 = node; - Node4::ShrinkNode16(art, node, node16); - } -} - -void Node16::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - // The node is full. Grow to Node48. 
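-	// Growing swaps the representation in place: the handle is copied, a Node48
-	// is allocated under the original handle, all 16 entries move over, and the
-	// insert is retried on the grown node (mirroring Node4 to Node16 above).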
- auto &n16 = Node::Ref(art, node, NODE_16); - if (n16.count == CAPACITY) { - auto node16 = node; - Node48::GrowNode16(art, node, node16); - Node48::InsertChild(art, node, byte, child); - return; - } - - InsertChildInternal(n16, byte, child); -} - -void Node16::GrowNode4(ART &art, Node &node16, Node &node4) { - auto &n4 = Node::Ref(art, node4, NType::NODE_4); - auto &n16 = New(art, node16); - node16.SetGateStatus(node4.GetGateStatus()); - - n16.count = n4.count; - for (uint8_t i = 0; i < n4.count; i++) { - n16.key[i] = n4.key[i]; - n16.children[i] = n4.children[i]; - } - - n4.count = 0; - Node::Free(art, node4); -} - -void Node16::ShrinkNode48(ART &art, Node &node16, Node &node48) { - auto &n16 = New(art, node16); - auto &n48 = Node::Ref(art, node48, NType::NODE_48); - node16.SetGateStatus(node48.GetGateStatus()); - - n16.count = 0; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (n48.child_index[i] != Node48::EMPTY_MARKER) { - n16.key[n16.count] = UnsafeNumericCast(i); - n16.children[n16.count] = n48.children[n48.child_index[i]]; - n16.count++; - } - } - - n48.count = 0; - Node::Free(art, node48); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/iterator.cpp b/src/duckdb/src/execution/index/art/iterator.cpp deleted file mode 100644 index 689029a02..000000000 --- a/src/duckdb/src/execution/index/art/iterator.cpp +++ /dev/null @@ -1,284 +0,0 @@ -#include "duckdb/execution/index/art/iterator.hpp" - -#include "duckdb/common/limits.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/index/art/node.hpp" -#include "duckdb/execution/index/art/prefix.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// IteratorKey -//===--------------------------------------------------------------------===// - -bool IteratorKey::Contains(const ARTKey &key) const { - if (Size() < key.len) { - return false; - } - for (idx_t i = 0; i < key.len; i++) { - if (key_bytes[i] != key.data[i]) { - return false; - } - } - return true; -} - -bool IteratorKey::GreaterThan(const ARTKey &key, const bool equal, const uint8_t nested_depth) const { - for (idx_t i = 0; i < MinValue(Size(), key.len); i++) { - if (key_bytes[i] > key.data[i]) { - return true; - } else if (key_bytes[i] < key.data[i]) { - return false; - } - } - - // Returns true, if current_key is greater than (or equal to) key. - D_ASSERT(Size() >= nested_depth); - auto this_len = Size() - nested_depth; - return equal ? this_len > key.len : this_len >= key.len; -} - -//===--------------------------------------------------------------------===// -// Iterator -//===--------------------------------------------------------------------===// - -bool Iterator::Scan(const ARTKey &upper_bound, const idx_t max_count, unsafe_vector &row_ids, const bool equal) { - bool has_next; - do { - // An empty upper bound indicates that no upper bound exists. 
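-		// Inside a gate (status == GateStatus::GATE_SET), the traversed bytes
-		// encode the row IDs of a nested leaf rather than index key bytes, so the
-		// upper bound only applies once the iterator has left the gate again.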
- if (!upper_bound.Empty() && status == GateStatus::GATE_NOT_SET) { - if (current_key.GreaterThan(upper_bound, equal, nested_depth)) { - return true; - } - } - - switch (last_leaf.GetType()) { - case NType::LEAF_INLINED: - if (row_ids.size() + 1 > max_count) { - return false; - } - row_ids.push_back(last_leaf.GetRowId()); - break; - case NType::LEAF: - if (!Leaf::DeprecatedGetRowIds(art, last_leaf, row_ids, max_count)) { - return false; - } - break; - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: { - uint8_t byte = 0; - while (last_leaf.GetNextByte(art, byte)) { - if (row_ids.size() + 1 > max_count) { - return false; - } - row_id[ROW_ID_SIZE - 1] = byte; - ARTKey key(&row_id[0], ROW_ID_SIZE); - row_ids.push_back(key.GetRowId()); - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - } - break; - } - default: - throw InternalException("Invalid leaf type for index scan."); - } - - has_next = Next(); - } while (has_next); - return true; -} - -void Iterator::FindMinimum(const Node &node) { - D_ASSERT(node.HasMetadata()); - - // Found the minimum. - if (node.IsAnyLeaf()) { - last_leaf = node; - return; - } - - // We are passing a gate node. - if (node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - status = GateStatus::GATE_SET; - nested_depth = 0; - } - - // Traverse the prefix. - if (node.GetType() == NType::PREFIX) { - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = prefix.data[i]; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); - } - } - nodes.emplace(node, 0); - return FindMinimum(*prefix.ptr); - } - - // Go to the leftmost entry in the current node. - uint8_t byte = 0; - auto next = node.GetNextChild(art, byte); - D_ASSERT(next); - - // Recurse on the leftmost node. - current_key.Push(byte); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth] = byte; - nested_depth++; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); - } - nodes.emplace(node, byte); - FindMinimum(*next); -} - -bool Iterator::LowerBound(const Node &node, const ARTKey &key, const bool equal, idx_t depth) { - if (!node.HasMetadata()) { - return false; - } - - // We found any leaf node, or a gate. - if (node.IsAnyLeaf() || node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - D_ASSERT(current_key.Size() == key.len); - if (!equal && current_key.Contains(key)) { - return Next(); - } - - if (node.GetGateStatus() == GateStatus::GATE_SET) { - FindMinimum(node); - } else { - last_leaf = node; - } - return true; - } - - D_ASSERT(node.GetGateStatus() == GateStatus::GATE_NOT_SET); - if (node.GetType() != NType::PREFIX) { - auto next_byte = key[depth]; - auto child = node.GetNextChild(art, next_byte); - - // The key is greater than any key in this subtree. - if (!child) { - return Next(); - } - - current_key.Push(next_byte); - nodes.emplace(node, next_byte); - - // We return the minimum because all keys are greater than the lower bound. - if (next_byte > key[depth]) { - FindMinimum(*child); - return true; - } - - // We recurse into the child. - return LowerBound(*child, key, equal, depth + 1); - } - - // Push back all prefix bytes. 
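-	// After pushing the prefix bytes, the loop further below compares them to
-	// the key. For example, with prefix bytes {0x61, 0x62} and key bytes
-	// {0x61, 0x63} at this depth, the comparison hits 0x62 < 0x63: every key in
-	// this subtree is less than the lower bound, so Next() yields the result.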
- Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - current_key.Push(prefix.data[i]); - } - nodes.emplace(node, 0); - - // We compare the prefix bytes with the key bytes. - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - // We found a prefix byte that is less than its corresponding key byte. - // I.e., the subsequent node is lesser than the key. Thus, the next node - // is the lower bound. - if (prefix.data[i] < key[depth + i]) { - return Next(); - } - - // We found a prefix byte that is greater than its corresponding key byte. - // I.e., the subsequent node is greater than the key. Thus, the minimum is - // the lower bound. - if (prefix.data[i] > key[depth + i]) { - FindMinimum(*prefix.ptr); - return true; - } - } - - // The prefix matches the key. We recurse into the child. - depth += prefix.data[Prefix::Count(art)]; - return LowerBound(*prefix.ptr, key, equal, depth); -} - -bool Iterator::Next() { - while (!nodes.empty()) { - auto &top = nodes.top(); - D_ASSERT(!top.node.IsAnyLeaf()); - - if (top.node.GetType() == NType::PREFIX) { - PopNode(); - continue; - } - - if (top.byte == NumericLimits::Maximum()) { - // No more children of this node. - // Move up the tree by popping the key byte of the current node. - PopNode(); - continue; - } - - top.byte++; - auto next_node = top.node.GetNextChild(art, top.byte); - if (!next_node) { - // No more children of this node. - // Move up the tree by popping the key byte of the current node. - PopNode(); - continue; - } - - current_key.Pop(1); - current_key.Push(top.byte); - if (status == GateStatus::GATE_SET) { - row_id[nested_depth - 1] = top.byte; - } - - FindMinimum(*next_node); - return true; - } - return false; -} - -void Iterator::PopNode() { - auto gate_status = nodes.top().node.GetGateStatus(); - - // Pop the byte and the node. - if (nodes.top().node.GetType() != NType::PREFIX) { - current_key.Pop(1); - if (status == GateStatus::GATE_SET) { - nested_depth--; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); - } - - } else { - // Pop all prefix bytes and the node. - Prefix prefix(art, nodes.top().node); - auto prefix_byte_count = prefix.data[Prefix::Count(art)]; - current_key.Pop(prefix_byte_count); - - if (status == GateStatus::GATE_SET) { - nested_depth -= prefix_byte_count; - D_ASSERT(nested_depth < Prefix::ROW_ID_SIZE); - } - } - nodes.pop(); - - // We are popping a gate node. 
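-	// Leaving the gate switches the iterator back to key-space semantics: the
-	// bytes that follow belong to index keys again, not to nested row IDs.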
- if (gate_status == GateStatus::GATE_SET) { - D_ASSERT(status == GateStatus::GATE_SET); - status = GateStatus::GATE_NOT_SET; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/leaf.cpp b/src/duckdb/src/execution/index/art/leaf.cpp deleted file mode 100644 index 5cde0d5d0..000000000 --- a/src/duckdb/src/execution/index/art/leaf.cpp +++ /dev/null @@ -1,248 +0,0 @@ -#include "duckdb/execution/index/art/leaf.hpp" - -#include "duckdb/common/types.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/index/art/art_key.hpp" -#include "duckdb/execution/index/art/base_leaf.hpp" -#include "duckdb/execution/index/art/base_node.hpp" -#include "duckdb/execution/index/art/iterator.hpp" -#include "duckdb/execution/index/art/node.hpp" -#include "duckdb/execution/index/art/prefix.hpp" - -namespace duckdb { - -void Leaf::New(Node &node, const row_t row_id) { - D_ASSERT(row_id < MAX_ROW_ID_LOCAL); - node.Clear(); - node.SetMetadata(static_cast(INLINED)); - node.SetRowId(row_id); -} - -void Leaf::New(ART &art, reference &node, const unsafe_vector &row_ids, const idx_t start, - const idx_t count) { - D_ASSERT(count > 1); - D_ASSERT(!node.get().HasMetadata()); - - // We cannot recurse into the leaf during Construct(...) because row IDs are not sorted. - for (idx_t i = 0; i < count; i++) { - idx_t offset = start + i; - art.Insert(node, row_ids[offset], 0, row_ids[offset], GateStatus::GATE_SET, nullptr); - } - node.get().SetGateStatus(GateStatus::GATE_SET); -} - -void Leaf::MergeInlined(ART &art, Node &l_node, Node &r_node) { - D_ASSERT(r_node.GetType() == INLINED); - - ArenaAllocator arena_allocator(Allocator::Get(art.db)); - auto key = ARTKey::CreateARTKey(arena_allocator, r_node.GetRowId()); - art.Insert(l_node, key, 0, key, l_node.GetGateStatus(), nullptr); - r_node.Clear(); -} - -void Leaf::InsertIntoInlined(ART &art, Node &node, const ARTKey &row_id, idx_t depth, const GateStatus status) { - D_ASSERT(node.GetType() == INLINED); - - ArenaAllocator allocator(Allocator::Get(art.db)); - auto key = ARTKey::CreateARTKey(allocator, node.GetRowId()); - - GateStatus new_status; - if (status == GateStatus::GATE_NOT_SET || node.GetGateStatus() == GateStatus::GATE_SET) { - new_status = GateStatus::GATE_SET; - } else { - new_status = GateStatus::GATE_NOT_SET; - } - - if (new_status == GateStatus::GATE_SET) { - depth = 0; - } - node.Clear(); - - // Get the mismatching position. - D_ASSERT(row_id.len == key.len); - auto pos = row_id.GetMismatchPos(key, depth); - D_ASSERT(pos != DConstants::INVALID_INDEX); - D_ASSERT(pos >= depth); - auto byte = row_id.data[pos]; - - // Create the (optional) prefix and the node. - reference next(node); - auto count = pos - depth; - if (count != 0) { - Prefix::New(art, next, row_id, depth, count); - } - if (pos == Prefix::ROW_ID_COUNT) { - Node7Leaf::New(art, next); - } else { - Node4::New(art, next); - } - - // Create the children. - Node row_id_node; - Leaf::New(row_id_node, row_id.GetRowId()); - Node remainder; - if (pos != Prefix::ROW_ID_COUNT) { - Leaf::New(remainder, key.GetRowId()); - } - - Node::InsertChild(art, next, key[pos], remainder); - Node::InsertChild(art, next, byte, row_id_node); - node.SetGateStatus(new_status); -} - -void Leaf::TransformToNested(ART &art, Node &node) { - D_ASSERT(node.GetType() == LEAF); - - ArenaAllocator allocator(Allocator::Get(art.db)); - Node root = Node(); - - // Temporarily disable constraint checking. 
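-	// The reinserts below add multiple row IDs under a single index key. In a
-	// unique index, the default append mode would report them as duplicates, so
-	// INSERT_DUPLICATES suppresses that check until DEFAULT is restored below.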
- if (art.IsUnique() && art.append_mode == ARTAppendMode::DEFAULT) { - art.append_mode = ARTAppendMode::INSERT_DUPLICATES; - } - - // Move all row IDs into the nested leaf. - reference leaf_ref(node); - while (leaf_ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, leaf_ref, LEAF); - for (uint8_t i = 0; i < leaf.count; i++) { - auto row_id = ARTKey::CreateARTKey(allocator, leaf.row_ids[i]); - auto conflict_type = art.Insert(root, row_id, 0, row_id, GateStatus::GATE_SET, nullptr); - if (conflict_type != ARTConflictType::NO_CONFLICT) { - throw InternalException("invalid conflict type in Leaf::TransformToNested"); - } - } - leaf_ref = leaf.ptr; - } - - art.append_mode = ARTAppendMode::DEFAULT; - root.SetGateStatus(GateStatus::GATE_SET); - Node::Free(art, node); - node = root; -} - -void Leaf::TransformToDeprecated(ART &art, Node &node) { - D_ASSERT(node.GetGateStatus() == GateStatus::GATE_SET || node.GetType() == LEAF); - - // Early-out, if we never transformed this leaf. - if (node.GetGateStatus() == GateStatus::GATE_NOT_SET) { - return; - } - - // Collect all row IDs and free the nested leaf. - unsafe_vector row_ids; - Iterator it(art); - it.FindMinimum(node); - ARTKey empty_key = ARTKey(); - it.Scan(empty_key, NumericLimits().Maximum(), row_ids, false); - Node::Free(art, node); - D_ASSERT(row_ids.size() > 1); - - // Create the deprecated leaves. - idx_t remaining = row_ids.size(); - idx_t copy_count = 0; - reference ref(node); - while (remaining) { - ref.get() = Node::GetAllocator(art, LEAF).New(); - ref.get().SetMetadata(static_cast(LEAF)); - - auto &leaf = Node::Ref(art, ref, LEAF); - auto min = MinValue(UnsafeNumericCast(LEAF_SIZE), remaining); - leaf.count = UnsafeNumericCast(min); - - for (uint8_t i = 0; i < leaf.count; i++) { - leaf.row_ids[i] = row_ids[copy_count + i]; - } - - copy_count += leaf.count; - remaining -= leaf.count; - - ref = leaf.ptr; - leaf.ptr.Clear(); - } -} - -//===--------------------------------------------------------------------===// -// Deprecated code paths. 
-//===--------------------------------------------------------------------===// - -void Leaf::DeprecatedFree(ART &art, Node &node) { - D_ASSERT(node.GetType() == LEAF); - - Node next; - while (node.HasMetadata()) { - next = Node::Ref(art, node, LEAF).ptr; - Node::GetAllocator(art, LEAF).Free(node); - node = next; - } - node.Clear(); -} - -bool Leaf::DeprecatedGetRowIds(ART &art, const Node &node, unsafe_vector &row_ids, const idx_t max_count) { - D_ASSERT(node.GetType() == LEAF); - - reference ref(node); - while (ref.get().HasMetadata()) { - - auto &leaf = Node::Ref(art, ref, LEAF); - if (row_ids.size() + leaf.count > max_count) { - return false; - } - for (uint8_t i = 0; i < leaf.count; i++) { - row_ids.push_back(leaf.row_ids[i]); - } - ref = leaf.ptr; - } - return true; -} - -void Leaf::DeprecatedVacuum(ART &art, Node &node) { - D_ASSERT(node.HasMetadata()); - D_ASSERT(node.GetType() == LEAF); - - auto &allocator = Node::GetAllocator(art, LEAF); - reference ref(node); - while (ref.get().HasMetadata()) { - if (allocator.NeedsVacuum(ref)) { - ref.get() = allocator.VacuumPointer(ref); - ref.get().SetMetadata(static_cast(LEAF)); - } - auto &leaf = Node::Ref(art, ref, LEAF); - ref = leaf.ptr; - } -} - -string Leaf::DeprecatedVerifyAndToString(ART &art, const Node &node, const bool only_verify) { - D_ASSERT(node.GetType() == LEAF); - - string str = ""; - reference ref(node); - - while (ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref, LEAF); - D_ASSERT(leaf.count <= LEAF_SIZE); - - str += "Leaf [count: " + to_string(leaf.count) + ", row IDs: "; - for (uint8_t i = 0; i < leaf.count; i++) { - str += to_string(leaf.row_ids[i]) + "-"; - } - str += "] "; - ref = leaf.ptr; - } - - return only_verify ? "" : str; -} - -void Leaf::DeprecatedVerifyAllocations(ART &art, unordered_map &node_counts) const { - auto idx = Node::GetAllocatorIdx(LEAF); - node_counts[idx]++; - - reference ref(ptr); - while (ref.get().HasMetadata()) { - auto &leaf = Node::Ref(art, ref, LEAF); - node_counts[idx]++; - ref = leaf.ptr; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node.cpp b/src/duckdb/src/execution/index/art/node.cpp deleted file mode 100644 index 25c1dd5ff..000000000 --- a/src/duckdb/src/execution/index/art/node.cpp +++ /dev/null @@ -1,765 +0,0 @@ -#include "duckdb/execution/index/art/node.hpp" - -#include "duckdb/common/limits.hpp" -#include "duckdb/common/swap.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/index/art/art_key.hpp" -#include "duckdb/execution/index/art/base_leaf.hpp" -#include "duckdb/execution/index/art/base_node.hpp" -#include "duckdb/execution/index/art/iterator.hpp" -#include "duckdb/execution/index/art/leaf.hpp" -#include "duckdb/execution/index/art/node256.hpp" -#include "duckdb/execution/index/art/node256_leaf.hpp" -#include "duckdb/execution/index/art/node48.hpp" -#include "duckdb/execution/index/art/prefix.hpp" -#include "duckdb/storage/table_io_manager.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// New and free -//===--------------------------------------------------------------------===// - -void Node::New(ART &art, Node &node, NType type) { - switch (type) { - case NType::NODE_7_LEAF: - Node7Leaf::New(art, node); - break; - case NType::NODE_15_LEAF: - Node15Leaf::New(art, node); - break; - case NType::NODE_256_LEAF: - Node256Leaf::New(art, node); - break; - case NType::NODE_4: - Node4::New(art, node); - break; - case NType::NODE_16: - 
Node16::New(art, node); - break; - case NType::NODE_48: - Node48::New(art, node); - break; - case NType::NODE_256: - Node256::New(art, node); - break; - default: - throw InternalException("Invalid node type for New: %d.", static_cast(type)); - } -} - -void Node::Free(ART &art, Node &node) { - if (!node.HasMetadata()) { - return node.Clear(); - } - - // Free the children. - auto type = node.GetType(); - switch (type) { - case NType::PREFIX: - return Prefix::Free(art, node); - case NType::LEAF: - return Leaf::DeprecatedFree(art, node); - case NType::NODE_4: - Node4::Free(art, node); - break; - case NType::NODE_16: - Node16::Free(art, node); - break; - case NType::NODE_48: - Node48::Free(art, node); - break; - case NType::NODE_256: - Node256::Free(art, node); - break; - case NType::LEAF_INLINED: - return node.Clear(); - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - break; - } - - GetAllocator(art, type).Free(node); - node.Clear(); -} - -//===--------------------------------------------------------------------===// -// Allocators -//===--------------------------------------------------------------------===// - -FixedSizeAllocator &Node::GetAllocator(const ART &art, const NType type) { - return *(*art.allocators)[GetAllocatorIdx(type)]; -} - -uint8_t Node::GetAllocatorIdx(const NType type) { - switch (type) { - case NType::PREFIX: - return 0; - case NType::LEAF: - return 1; - case NType::NODE_4: - return 2; - case NType::NODE_16: - return 3; - case NType::NODE_48: - return 4; - case NType::NODE_256: - return 5; - case NType::NODE_7_LEAF: - return 6; - case NType::NODE_15_LEAF: - return 7; - case NType::NODE_256_LEAF: - return 8; - default: - throw InternalException("Invalid node type for GetAllocatorIdx: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// Inserts -//===--------------------------------------------------------------------===// - -void Node::ReplaceChild(const ART &art, const uint8_t byte, const Node child) const { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch (type) { - case NType::NODE_4: - return Node4::ReplaceChild(Ref(art, *this, type), byte, child); - case NType::NODE_16: - return Node16::ReplaceChild(Ref(art, *this, type), byte, child); - case NType::NODE_48: - return Ref(art, *this, type).ReplaceChild(byte, child); - case NType::NODE_256: - return Ref(art, *this, type).ReplaceChild(byte, child); - default: - throw InternalException("Invalid node type for ReplaceChild: %d.", static_cast(type)); - } -} - -void Node::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - D_ASSERT(node.HasMetadata()); - - auto type = node.GetType(); - switch (type) { - case NType::NODE_4: - return Node4::InsertChild(art, node, byte, child); - case NType::NODE_16: - return Node16::InsertChild(art, node, byte, child); - case NType::NODE_48: - return Node48::InsertChild(art, node, byte, child); - case NType::NODE_256: - return Node256::InsertChild(art, node, byte, child); - case NType::NODE_7_LEAF: - return Node7Leaf::InsertByte(art, node, byte); - case NType::NODE_15_LEAF: - return Node15Leaf::InsertByte(art, node, byte); - case NType::NODE_256_LEAF: - return Node256Leaf::InsertByte(art, node, byte); - default: - throw InternalException("Invalid node type for InsertChild: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// Delete 
-//===--------------------------------------------------------------------===// - -void Node::DeleteChild(ART &art, Node &node, Node &prefix, const uint8_t byte, const GateStatus status, - const ARTKey &row_id) { - D_ASSERT(node.HasMetadata()); - - auto type = node.GetType(); - switch (type) { - case NType::NODE_4: - return Node4::DeleteChild(art, node, prefix, byte, status); - case NType::NODE_16: - return Node16::DeleteChild(art, node, byte); - case NType::NODE_48: - return Node48::DeleteChild(art, node, byte); - case NType::NODE_256: - return Node256::DeleteChild(art, node, byte); - case NType::NODE_7_LEAF: - return Node7Leaf::DeleteByte(art, node, prefix, byte, row_id); - case NType::NODE_15_LEAF: - return Node15Leaf::DeleteByte(art, node, byte); - case NType::NODE_256_LEAF: - return Node256Leaf::DeleteByte(art, node, byte); - default: - throw InternalException("Invalid node type for DeleteChild: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// Get child and byte. -//===--------------------------------------------------------------------===// - -template -unsafe_optional_ptr GetChildInternal(ART &art, NODE &node, const uint8_t byte) { - D_ASSERT(node.HasMetadata()); - - auto type = node.GetType(); - switch (type) { - case NType::NODE_4: - return Node4::GetChild(Node::Ref(art, node, type), byte); - case NType::NODE_16: - return Node16::GetChild(Node::Ref(art, node, type), byte); - case NType::NODE_48: - return Node48::GetChild(Node::Ref(art, node, type), byte); - case NType::NODE_256: { - return Node256::GetChild(Node::Ref(art, node, type), byte); - } - default: - throw InternalException("Invalid node type for GetChildInternal: %d.", static_cast(type)); - } -} - -const unsafe_optional_ptr Node::GetChild(ART &art, const uint8_t byte) const { - return GetChildInternal(art, *this, byte); -} - -unsafe_optional_ptr Node::GetChildMutable(ART &art, const uint8_t byte) const { - return GetChildInternal(art, *this, byte); -} - -template -unsafe_optional_ptr GetNextChildInternal(ART &art, NODE &node, uint8_t &byte) { - D_ASSERT(node.HasMetadata()); - - auto type = node.GetType(); - switch (type) { - case NType::NODE_4: - return Node4::GetNextChild(Node::Ref(art, node, type), byte); - case NType::NODE_16: - return Node16::GetNextChild(Node::Ref(art, node, type), byte); - case NType::NODE_48: - return Node48::GetNextChild(Node::Ref(art, node, type), byte); - case NType::NODE_256: - return Node256::GetNextChild(Node::Ref(art, node, type), byte); - default: - throw InternalException("Invalid node type for GetNextChildInternal: %d.", static_cast(type)); - } -} - -const unsafe_optional_ptr Node::GetNextChild(ART &art, uint8_t &byte) const { - return GetNextChildInternal(art, *this, byte); -} - -unsafe_optional_ptr Node::GetNextChildMutable(ART &art, uint8_t &byte) const { - return GetNextChildInternal(art, *this, byte); -} - -bool Node::HasByte(ART &art, uint8_t &byte) const { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch (type) { - case NType::NODE_7_LEAF: - return Ref(art, *this, NType::NODE_7_LEAF).HasByte(byte); - case NType::NODE_15_LEAF: - return Ref(art, *this, NType::NODE_15_LEAF).HasByte(byte); - case NType::NODE_256_LEAF: - return Ref(art, *this, NType::NODE_256_LEAF).HasByte(byte); - default: - throw InternalException("Invalid node type for GetNextByte: %d.", static_cast(type)); - } -} - -bool Node::GetNextByte(ART &art, uint8_t &byte) const { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch 
(type) { - case NType::NODE_7_LEAF: - return Ref(art, *this, NType::NODE_7_LEAF).GetNextByte(byte); - case NType::NODE_15_LEAF: - return Ref(art, *this, NType::NODE_15_LEAF).GetNextByte(byte); - case NType::NODE_256_LEAF: - return Ref(art, *this, NType::NODE_256_LEAF).GetNextByte(byte); - default: - throw InternalException("Invalid node type for GetNextByte: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// Utility -//===--------------------------------------------------------------------===// - -idx_t GetCapacity(NType type) { - switch (type) { - case NType::NODE_4: - return Node4::CAPACITY; - case NType::NODE_7_LEAF: - return Node7Leaf::CAPACITY; - case NType::NODE_15_LEAF: - return Node15Leaf::CAPACITY; - case NType::NODE_16: - return Node16::CAPACITY; - case NType::NODE_48: - return Node48::CAPACITY; - case NType::NODE_256_LEAF: - return Node256::CAPACITY; - case NType::NODE_256: - return Node256::CAPACITY; - default: - throw InternalException("Invalid node type for GetCapacity: %d.", static_cast(type)); - } -} - -NType Node::GetNodeType(idx_t count) { - if (count <= Node4::CAPACITY) { - return NType::NODE_4; - } else if (count <= Node16::CAPACITY) { - return NType::NODE_16; - } else if (count <= Node48::CAPACITY) { - return NType::NODE_48; - } - return NType::NODE_256; -} - -bool Node::IsNode() const { - switch (GetType()) { - case NType::NODE_4: - case NType::NODE_16: - case NType::NODE_48: - case NType::NODE_256: - return true; - default: - return false; - } -} - -bool Node::IsLeafNode() const { - switch (GetType()) { - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - return true; - default: - return false; - } -} - -bool Node::IsAnyLeaf() const { - if (IsLeafNode()) { - return true; - } - - switch (GetType()) { - case NType::LEAF_INLINED: - case NType::LEAF: - return true; - default: - return false; - } -} - -//===--------------------------------------------------------------------===// -// Merge -//===--------------------------------------------------------------------===// - -void Node::InitMerge(ART &art, const unsafe_vector &upper_bounds) { - D_ASSERT(HasMetadata()); - auto type = GetType(); - - switch (type) { - case NType::PREFIX: - return Prefix::InitializeMerge(art, *this, upper_bounds); - case NType::LEAF: - throw InternalException("Failed to initialize merge due to deprecated ART storage."); - case NType::NODE_4: - InitMergeInternal(art, Ref(art, *this, type), upper_bounds); - break; - case NType::NODE_16: - InitMergeInternal(art, Ref(art, *this, type), upper_bounds); - break; - case NType::NODE_48: - InitMergeInternal(art, Ref(art, *this, type), upper_bounds); - break; - case NType::NODE_256: - InitMergeInternal(art, Ref(art, *this, type), upper_bounds); - break; - case NType::LEAF_INLINED: - return; - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - break; - } - - auto idx = GetAllocatorIdx(type); - IncreaseBufferId(upper_bounds[idx]); -} - -bool Node::MergeNormalNodes(ART &art, Node &l_node, Node &r_node, uint8_t &byte, const GateStatus status) { - // Merge N4, N16, N48, N256 nodes. 
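-	// For each byte present in r_node: if l_node has no child under that byte,
-	// the child moves over wholesale; otherwise, the two children are merged
-	// recursively. E.g., merging children at bytes {0x01, 0x03} into children
-	// at {0x01, 0x02} moves the child at 0x03 and recurses on those at 0x01.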
- D_ASSERT(l_node.IsNode() && r_node.IsNode()); - D_ASSERT(l_node.GetGateStatus() == r_node.GetGateStatus()); - - auto r_child = r_node.GetNextChildMutable(art, byte); - while (r_child) { - auto l_child = l_node.GetChildMutable(art, byte); - if (!l_child) { - Node::InsertChild(art, l_node, byte, *r_child); - r_node.ReplaceChild(art, byte); - } else { - if (!l_child->MergeInternal(art, *r_child, status)) { - return false; - } - } - - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - r_child = r_node.GetNextChildMutable(art, byte); - } - - Node::Free(art, r_node); - return true; -} - -void Node::MergeLeafNodes(ART &art, Node &l_node, Node &r_node, uint8_t &byte) { - // Merge N7, N15, N256 leaf nodes. - D_ASSERT(l_node.IsLeafNode() && r_node.IsLeafNode()); - D_ASSERT(l_node.GetGateStatus() == GateStatus::GATE_NOT_SET); - D_ASSERT(r_node.GetGateStatus() == GateStatus::GATE_NOT_SET); - - auto has_next = r_node.GetNextByte(art, byte); - while (has_next) { - // Row IDs are always unique. - Node::InsertChild(art, l_node, byte); - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - has_next = r_node.GetNextByte(art, byte); - } - - Node::Free(art, r_node); -} - -bool Node::MergeNodes(ART &art, Node &other, GateStatus status) { - // Merge the smaller node into the bigger node. - if (GetType() < other.GetType()) { - swap(*this, other); - } - - uint8_t byte = 0; - if (IsNode()) { - return MergeNormalNodes(art, *this, other, byte, status); - } - MergeLeafNodes(art, *this, other, byte); - return true; -} - -bool Node::Merge(ART &art, Node &other, const GateStatus status) { - if (HasMetadata()) { - return MergeInternal(art, other, status); - } - - *this = other; - other = Node(); - return true; -} - -bool Node::PrefixContainsOther(ART &art, Node &l_node, Node &r_node, const uint8_t pos, const GateStatus status) { - // r_node's prefix contains l_node's prefix. l_node must be a node with child nodes. - D_ASSERT(l_node.IsNode()); - - // Check if the next byte (pos) in r_node exists in l_node. - auto byte = Prefix::GetByte(art, r_node, pos); - auto child = l_node.GetChildMutable(art, byte); - - // Reduce r_node's prefix to the bytes after pos. - Prefix::Reduce(art, r_node, pos); - if (child) { - return child->MergeInternal(art, r_node, status); - } - - Node::InsertChild(art, l_node, byte, r_node); - r_node.Clear(); - return true; -} - -void Node::MergeIntoNode4(ART &art, Node &l_node, Node &r_node, const uint8_t pos) { - Node l_child; - auto l_byte = Prefix::GetByte(art, l_node, pos); - - reference ref(l_node); - auto status = Prefix::Split(art, ref, l_child, pos); - Node4::New(art, ref); - ref.get().SetGateStatus(status); - - Node4::InsertChild(art, ref, l_byte, l_child); - - auto r_byte = Prefix::GetByte(art, r_node, pos); - Prefix::Reduce(art, r_node, pos); - Node4::InsertChild(art, ref, r_byte, r_node); - r_node.Clear(); -} - -bool Node::MergePrefixes(ART &art, Node &other, const GateStatus status) { - reference l_node(*this); - reference r_node(other); - auto pos = DConstants::INVALID_INDEX; - - if (l_node.get().GetType() == NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - // Traverse prefixes. Possibly change the referenced nodes. - if (!Prefix::Traverse(art, l_node, r_node, pos, status)) { - return false; - } - if (pos == DConstants::INVALID_INDEX) { - return true; - } - - } else { - // l_prefix contains r_prefix. 
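-		// Exactly one of the two nodes is still a prefix here. Swap so that the
-		// prefix ends up on the right-hand side; the divergence position is 0.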
- if (l_node.get().GetType() == NType::PREFIX) { - swap(*this, other); - } - pos = 0; - } - - D_ASSERT(pos != DConstants::INVALID_INDEX); - if (l_node.get().GetType() != NType::PREFIX && r_node.get().GetType() == NType::PREFIX) { - return PrefixContainsOther(art, l_node, r_node, UnsafeNumericCast(pos), status); - } - - // The prefixes differ. - MergeIntoNode4(art, l_node, r_node, UnsafeNumericCast(pos)); - return true; -} - -bool Node::MergeInternal(ART &art, Node &other, const GateStatus status) { - D_ASSERT(HasMetadata()); - D_ASSERT(other.HasMetadata()); - - // Merge inlined leaves. - if (GetType() == NType::LEAF_INLINED) { - swap(*this, other); - } - if (other.GetType() == NType::LEAF_INLINED) { - D_ASSERT(status == GateStatus::GATE_NOT_SET); - D_ASSERT(other.GetGateStatus() == GateStatus::GATE_SET || other.GetType() == NType::LEAF_INLINED); - D_ASSERT(GetType() == NType::LEAF_INLINED || GetGateStatus() == GateStatus::GATE_SET); - - if (art.IsUnique()) { - return false; - } - Leaf::MergeInlined(art, *this, other); - return true; - } - - // Enter a gate. - if (GetGateStatus() == GateStatus::GATE_SET && status == GateStatus::GATE_NOT_SET) { - D_ASSERT(other.GetGateStatus() == GateStatus::GATE_SET); - D_ASSERT(GetType() != NType::LEAF_INLINED); - D_ASSERT(other.GetType() != NType::LEAF_INLINED); - - // Get all row IDs. - unsafe_vector row_ids; - Iterator it(art); - it.FindMinimum(other); - ARTKey empty_key = ARTKey(); - it.Scan(empty_key, NumericLimits().Maximum(), row_ids, false); - Node::Free(art, other); - D_ASSERT(row_ids.size() > 1); - - // Insert all row IDs. - ArenaAllocator allocator(Allocator::Get(art.db)); - for (idx_t i = 0; i < row_ids.size(); i++) { - auto row_id = ARTKey::CreateARTKey(allocator, row_ids[i]); - art.Insert(*this, row_id, 0, row_id, GateStatus::GATE_SET, nullptr); - } - return true; - } - - // Merge N4, N16, N48, N256 nodes. - if (IsNode() && other.IsNode()) { - return MergeNodes(art, other, status); - } - // Merge N7, N15, N256 leaf nodes. - if (IsLeafNode() && other.IsLeafNode()) { - D_ASSERT(status == GateStatus::GATE_SET); - return MergeNodes(art, other, status); - } - - // Merge prefixes. 
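When a merge enters a gate (a nested ART keyed on row IDs), the code above does not merge structurally: it scans every row ID out of the right-hand gate, frees it, and reinserts the row IDs into the left-hand side. A rough standalone sketch of that flatten-and-reinsert strategy, with std::set standing in for the nested ART:

#include <cstdint>
#include <set>

// Merge two "gates" by flattening one side and reinserting, mirroring the
// scan / free / insert sequence above. std::set is a stand-in for the ART.
void MergeGates(std::set<int64_t> &left, std::set<int64_t> &right) {
	// scan all row IDs out of the right gate
	std::set<int64_t> row_ids(right.begin(), right.end());
	right.clear(); // corresponds to Node::Free(art, other)
	// reinsert them into the left gate one by one
	for (auto row_id : row_ids) {
		left.insert(row_id);
	}
}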
- return MergePrefixes(art, other, status); -} - -//===--------------------------------------------------------------------===// -// Vacuum -//===--------------------------------------------------------------------===// - -void Node::Vacuum(ART &art, const unordered_set &indexes) { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch (type) { - case NType::LEAF_INLINED: - return; - case NType::PREFIX: - return Prefix::Vacuum(art, *this, indexes); - case NType::LEAF: - if (indexes.find(GetAllocatorIdx(type)) == indexes.end()) { - return; - } - return Leaf::DeprecatedVacuum(art, *this); - default: - break; - } - - auto idx = GetAllocatorIdx(type); - auto &allocator = GetAllocator(art, type); - auto needs_vacuum = indexes.find(idx) != indexes.end() && allocator.NeedsVacuum(*this); - if (needs_vacuum) { - auto status = GetGateStatus(); - *this = allocator.VacuumPointer(*this); - SetMetadata(static_cast(type)); - SetGateStatus(status); - } - - switch (type) { - case NType::NODE_4: - return VacuumInternal(art, Ref(art, *this, type), indexes); - case NType::NODE_16: - return VacuumInternal(art, Ref(art, *this, type), indexes); - case NType::NODE_48: - return VacuumInternal(art, Ref(art, *this, type), indexes); - case NType::NODE_256: - return VacuumInternal(art, Ref(art, *this, type), indexes); - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - return; - default: - throw InternalException("Invalid node type for Vacuum: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// TransformToDeprecated -//===--------------------------------------------------------------------===// - -void Node::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { - D_ASSERT(node.HasMetadata()); - - if (node.GetGateStatus() == GateStatus::GATE_SET) { - D_ASSERT(node.GetType() != NType::LEAF_INLINED); - return Leaf::TransformToDeprecated(art, node); - } - - auto type = node.GetType(); - switch (type) { - case NType::PREFIX: - return Prefix::TransformToDeprecated(art, node, allocator); - case NType::LEAF_INLINED: - return; - case NType::LEAF: - return; - case NType::NODE_4: - return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); - case NType::NODE_16: - return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); - case NType::NODE_48: - return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); - case NType::NODE_256: - return TransformToDeprecatedInternal(art, InMemoryRef(art, node, type), allocator); - default: - throw InternalException("Invalid node type for TransformToDeprecated: %d.", static_cast(type)); - } -} - -//===--------------------------------------------------------------------===// -// Verification -//===--------------------------------------------------------------------===// - -string Node::VerifyAndToString(ART &art, const bool only_verify) const { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch (type) { - case NType::LEAF_INLINED: - return only_verify ? "" : "Inlined Leaf [row ID: " + to_string(GetRowId()) + "]"; - case NType::LEAF: - return Leaf::DeprecatedVerifyAndToString(art, *this, only_verify); - case NType::PREFIX: { - auto str = Prefix::VerifyAndToString(art, *this, only_verify); - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? 
"" : "\n" + str; - } - default: - break; - } - - string str = "Node" + to_string(GetCapacity(type)) + ": [ "; - uint8_t byte = 0; - - if (IsLeafNode()) { - str = "Leaf " + str; - auto has_byte = GetNextByte(art, byte); - while (has_byte) { - str += to_string(byte) + "-"; - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - has_byte = GetNextByte(art, byte); - } - } else { - auto child = GetNextChild(art, byte); - while (child) { - str += "(" + to_string(byte) + ", " + child->VerifyAndToString(art, only_verify) + ")"; - if (byte == NumericLimits::Maximum()) { - break; - } - byte++; - child = GetNextChild(art, byte); - } - } - - if (GetGateStatus() == GateStatus::GATE_SET) { - str = "Gate [ " + str + " ]"; - } - return only_verify ? "" : "\n" + str + "]"; -} - -void Node::VerifyAllocations(ART &art, unordered_map &node_counts) const { - D_ASSERT(HasMetadata()); - - auto type = GetType(); - switch (type) { - case NType::PREFIX: - return Prefix::VerifyAllocations(art, *this, node_counts); - case NType::LEAF: - return Ref(art, *this, type).DeprecatedVerifyAllocations(art, node_counts); - case NType::LEAF_INLINED: - return; - case NType::NODE_4: - VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); - break; - case NType::NODE_16: - VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); - break; - case NType::NODE_48: - VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); - break; - case NType::NODE_256: - VerifyAllocationsInternal(art, Ref(art, *this, type), node_counts); - break; - case NType::NODE_7_LEAF: - case NType::NODE_15_LEAF: - case NType::NODE_256_LEAF: - break; - } - - node_counts[GetAllocatorIdx(type)]++; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/node256.cpp b/src/duckdb/src/execution/index/art/node256.cpp deleted file mode 100644 index f08717e13..000000000 --- a/src/duckdb/src/execution/index/art/node256.cpp +++ /dev/null @@ -1,78 +0,0 @@ -#include "duckdb/execution/index/art/node256.hpp" - -#include "duckdb/execution/index/art/node48.hpp" - -namespace duckdb { - -Node256 &Node256::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NODE_256).New(); - node.SetMetadata(static_cast(NODE_256)); - auto &n256 = Node::Ref(art, node, NODE_256); - - n256.count = 0; - for (uint16_t i = 0; i < CAPACITY; i++) { - n256.children[i].Clear(); - } - - return n256; -} - -void Node256::Free(ART &art, Node &node) { - auto &n256 = Node::Ref(art, node, NODE_256); - if (!n256.count) { - return; - } - - Iterator(n256, [&](Node &child) { Node::Free(art, child); }); -} - -void Node256::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - auto &n256 = Node::Ref(art, node, NODE_256); - n256.count++; - n256.children[byte] = child; -} - -void Node256::DeleteChild(ART &art, Node &node, const uint8_t byte) { - auto &n256 = Node::Ref(art, node, NODE_256); - - // Free the child and decrease the count. - Node::Free(art, n256.children[byte]); - n256.count--; - - // Shrink to Node48. 
-	if (n256.count <= SHRINK_THRESHOLD) {
-		auto node256 = node;
-		Node48::ShrinkNode256(art, node, node256);
-	}
-}
-
-void Node256::ReplaceChild(const uint8_t byte, const Node child) {
-	D_ASSERT(count > SHRINK_THRESHOLD);
-
-	auto status = children[byte].GetGateStatus();
-	children[byte] = child;
-	if (status == GateStatus::GATE_SET && child.HasMetadata()) {
-		children[byte].SetGateStatus(status);
-	}
-}
-
-Node256 &Node256::GrowNode48(ART &art, Node &node256, Node &node48) {
-	auto &n48 = Node::Ref<Node48>(art, node48, NType::NODE_48);
-	auto &n256 = New(art, node256);
-	node256.SetGateStatus(node48.GetGateStatus());
-
-	n256.count = n48.count;
-	for (uint16_t i = 0; i < CAPACITY; i++) {
-		if (n48.child_index[i] != Node48::EMPTY_MARKER) {
-			n256.children[i] = n48.children[n48.child_index[i]];
-		} else {
-			n256.children[i].Clear();
-		}
-	}
-
-	n48.count = 0;
-	Node::Free(art, node48);
-	return n256;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/index/art/node256_leaf.cpp b/src/duckdb/src/execution/index/art/node256_leaf.cpp
deleted file mode 100644
index 5c74674f4..000000000
--- a/src/duckdb/src/execution/index/art/node256_leaf.cpp
+++ /dev/null
@@ -1,71 +0,0 @@
-#include "duckdb/execution/index/art/node256_leaf.hpp"
-
-#include "duckdb/execution/index/art/base_leaf.hpp"
-#include "duckdb/execution/index/art/node48.hpp"
-
-namespace duckdb {
-
-Node256Leaf &Node256Leaf::New(ART &art, Node &node) {
-	node = Node::GetAllocator(art, NODE_256_LEAF).New();
-	node.SetMetadata(static_cast<uint8_t>(NODE_256_LEAF));
-	auto &n256 = Node::Ref<Node256Leaf>(art, node, NODE_256_LEAF);
-
-	n256.count = 0;
-	ValidityMask mask(&n256.mask[0], Node256::CAPACITY);
-	mask.SetAllInvalid(CAPACITY);
-	return n256;
-}
-
-void Node256Leaf::InsertByte(ART &art, Node &node, const uint8_t byte) {
-	auto &n256 = Node::Ref<Node256Leaf>(art, node, NODE_256_LEAF);
-	n256.count++;
-	ValidityMask mask(&n256.mask[0], Node256::CAPACITY);
-	mask.SetValid(byte);
-}
-
-void Node256Leaf::DeleteByte(ART &art, Node &node, const uint8_t byte) {
-	auto &n256 = Node::Ref<Node256Leaf>(art, node, NODE_256_LEAF);
-	n256.count--;
-	ValidityMask mask(&n256.mask[0], Node256::CAPACITY);
-	mask.SetInvalid(byte);
-
-	// Shrink node to Node15
-	if (n256.count <= Node48::SHRINK_THRESHOLD) {
-		auto node256 = node;
-		Node15Leaf::ShrinkNode256Leaf(art, node, node256);
-	}
-}
-
-bool Node256Leaf::HasByte(uint8_t &byte) {
-	ValidityMask v_mask(&mask[0], Node256::CAPACITY);
-	return v_mask.RowIsValid(byte);
-}
-
-bool Node256Leaf::GetNextByte(uint8_t &byte) {
-	ValidityMask v_mask(&mask[0], Node256::CAPACITY);
-	for (uint16_t i = byte; i < CAPACITY; i++) {
-		if (v_mask.RowIsValid(i)) {
-			byte = UnsafeNumericCast<uint8_t>(i);
-			return true;
-		}
-	}
-	return false;
-}
-
-Node256Leaf &Node256Leaf::GrowNode15Leaf(ART &art, Node &node256_leaf, Node &node15_leaf) {
-	auto &n15 = Node::Ref<Node15Leaf>(art, node15_leaf, NType::NODE_15_LEAF);
-	auto &n256 = New(art, node256_leaf);
-	node256_leaf.SetGateStatus(node15_leaf.GetGateStatus());
-
-	n256.count = n15.count;
-	ValidityMask mask(&n256.mask[0], Node256::CAPACITY);
-	for (uint8_t i = 0; i < n15.count; i++) {
-		mask.SetValid(n15.key[i]);
-	}
-
-	n15.count = 0;
-	Node::Free(art, node15_leaf);
-	return n256;
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/index/art/node48.cpp b/src/duckdb/src/execution/index/art/node48.cpp
deleted file mode 100644
index f9ad0460c..000000000
--- a/src/duckdb/src/execution/index/art/node48.cpp
+++ /dev/null
@@ -1,130 +0,0 @@
-#include "duckdb/execution/index/art/node48.hpp"
-
-#include
"duckdb/execution/index/art/base_node.hpp" -#include "duckdb/execution/index/art/node256.hpp" - -namespace duckdb { - -Node48 &Node48::New(ART &art, Node &node) { - node = Node::GetAllocator(art, NODE_48).New(); - node.SetMetadata(static_cast(NODE_48)); - auto &n48 = Node::Ref(art, node, NODE_48); - - n48.count = 0; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - n48.child_index[i] = EMPTY_MARKER; - } - for (uint8_t i = 0; i < CAPACITY; i++) { - n48.children[i].Clear(); - } - - return n48; -} - -void Node48::Free(ART &art, Node &node) { - auto &n48 = Node::Ref(art, node, NODE_48); - if (!n48.count) { - return; - } - - Iterator(n48, [&](Node &child) { Node::Free(art, child); }); -} - -void Node48::InsertChild(ART &art, Node &node, const uint8_t byte, const Node child) { - auto &n48 = Node::Ref(art, node, NODE_48); - - // The node is full. Grow to Node256. - if (n48.count == CAPACITY) { - auto node48 = node; - Node256::GrowNode48(art, node, node48); - Node256::InsertChild(art, node, byte, child); - return; - } - - // Still space. Insert the child. - uint8_t child_pos = n48.count; - if (n48.children[child_pos].HasMetadata()) { - // Find an empty position in the node list. - child_pos = 0; - while (n48.children[child_pos].HasMetadata()) { - child_pos++; - } - } - - n48.children[child_pos] = child; - n48.child_index[byte] = child_pos; - n48.count++; -} - -void Node48::DeleteChild(ART &art, Node &node, const uint8_t byte) { - auto &n48 = Node::Ref(art, node, NODE_48); - - // Free the child and decrease the count. - Node::Free(art, n48.children[n48.child_index[byte]]); - n48.child_index[byte] = EMPTY_MARKER; - n48.count--; - - // Shrink to Node16. - if (n48.count < SHRINK_THRESHOLD) { - auto node48 = node; - Node16::ShrinkNode48(art, node, node48); - } -} - -void Node48::ReplaceChild(const uint8_t byte, const Node child) { - D_ASSERT(count >= SHRINK_THRESHOLD); - - auto status = children[child_index[byte]].GetGateStatus(); - children[child_index[byte]] = child; - if (status == GateStatus::GATE_SET && child.HasMetadata()) { - children[child_index[byte]].SetGateStatus(status); - } -} - -Node48 &Node48::GrowNode16(ART &art, Node &node48, Node &node16) { - auto &n16 = Node::Ref(art, node16, NType::NODE_16); - auto &n48 = New(art, node48); - node48.SetGateStatus(node16.GetGateStatus()); - - n48.count = n16.count; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - n48.child_index[i] = EMPTY_MARKER; - } - for (uint8_t i = 0; i < n16.count; i++) { - n48.child_index[n16.key[i]] = i; - n48.children[i] = n16.children[i]; - } - for (uint8_t i = n16.count; i < CAPACITY; i++) { - n48.children[i].Clear(); - } - - n16.count = 0; - Node::Free(art, node16); - return n48; -} - -Node48 &Node48::ShrinkNode256(ART &art, Node &node48, Node &node256) { - auto &n48 = New(art, node48); - auto &n256 = Node::Ref(art, node256, NType::NODE_256); - node48.SetGateStatus(node256.GetGateStatus()); - - n48.count = 0; - for (uint16_t i = 0; i < Node256::CAPACITY; i++) { - if (!n256.children[i].HasMetadata()) { - n48.child_index[i] = EMPTY_MARKER; - continue; - } - n48.child_index[i] = n48.count; - n48.children[n48.count] = n256.children[i]; - n48.count++; - } - for (uint8_t i = n48.count; i < CAPACITY; i++) { - n48.children[i].Clear(); - } - - n256.count = 0; - Node::Free(art, node256); - return n48; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/plan_art.cpp b/src/duckdb/src/execution/index/art/plan_art.cpp deleted file mode 100644 index ce459b290..000000000 --- 
a/src/duckdb/src/execution/index/art/plan_art.cpp +++ /dev/null @@ -1,89 +0,0 @@ -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/operator/filter/physical_filter.hpp" -#include "duckdb/execution/operator/order/physical_order.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/operator/schema/physical_create_art_index.hpp" -#include "duckdb/planner/expression/bound_operator_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_create_index.hpp" - -namespace duckdb { - -unique_ptr ART::CreatePlan(PlanIndexInput &input) { - auto &op = input.op; - - // PROJECTION on indexed columns. - vector new_column_types; - vector> select_list; - for (idx_t i = 0; i < op.expressions.size(); i++) { - new_column_types.push_back(op.expressions[i]->return_type); - select_list.push_back(std::move(op.expressions[i])); - } - new_column_types.emplace_back(LogicalType::ROW_TYPE); - select_list.push_back(make_uniq(LogicalType::ROW_TYPE, op.info->scan_types.size() - 1)); - - auto projection = make_uniq(new_column_types, std::move(select_list), op.estimated_cardinality); - projection->children.push_back(std::move(input.table_scan)); - - // Optional NOT NULL filter. - unique_ptr prev_operator; - auto is_alter = op.alter_table_info != nullptr; - if (!is_alter) { - vector filter_types; - vector> filter_select_list; - auto not_null_type = ExpressionType::OPERATOR_IS_NOT_NULL; - - for (idx_t i = 0; i < new_column_types.size() - 1; i++) { - filter_types.push_back(new_column_types[i]); - auto is_not_null_expr = make_uniq(not_null_type, LogicalType::BOOLEAN); - auto bound_ref = make_uniq(new_column_types[i], i); - is_not_null_expr->children.push_back(std::move(bound_ref)); - filter_select_list.push_back(std::move(is_not_null_expr)); - } - - prev_operator = - make_uniq(std::move(filter_types), std::move(filter_select_list), op.estimated_cardinality); - prev_operator->types.emplace_back(LogicalType::ROW_TYPE); - prev_operator->children.push_back(std::move(projection)); - - } else { - prev_operator = std::move(projection); - } - - // Determine whether to push an ORDER BY operator. - auto sort = true; - if (op.unbound_expressions.size() > 1) { - sort = false; - } else if (op.unbound_expressions[0]->return_type.InternalType() == PhysicalType::VARCHAR) { - sort = false; - } - - // CREATE INDEX operator. - auto physical_create_index = make_uniq( - op, op.table, op.info->column_ids, std::move(op.info), std::move(op.unbound_expressions), - op.estimated_cardinality, sort, std::move(op.alter_table_info)); - - if (!sort) { - physical_create_index->children.push_back(std::move(prev_operator)); - return std::move(physical_create_index); - } - - // ORDER BY operator. 
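The plan above only inserts a physical ORDER BY when the index has a single, non-VARCHAR key column; for that shape, feeding the index builder sorted keys is cheaper than unsorted point inserts. A compact sketch of just that decision rule, with key types reduced to an enum rather than the planner's real type objects:

#include <vector>

enum class KeyType { INTEGER, DOUBLE, VARCHAR };

// Mirror of the decision above: sorting pays off only for a single
// fixed-size key column; multi-column or VARCHAR keys skip the sort.
bool ShouldSortBeforeBuild(const std::vector<KeyType> &key_columns) {
	if (key_columns.size() != 1) {
		return false;
	}
	return key_columns[0] != KeyType::VARCHAR;
}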
- vector orders; - vector projections; - for (idx_t i = 0; i < new_column_types.size() - 1; i++) { - auto col_expr = make_uniq_base(new_column_types[i], i); - orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(col_expr)); - projections.emplace_back(i); - } - projections.emplace_back(new_column_types.size() - 1); - - auto physical_order = - make_uniq(new_column_types, std::move(orders), std::move(projections), op.estimated_cardinality); - - physical_order->children.push_back(std::move(prev_operator)); - physical_create_index->children.push_back(std::move(physical_order)); - return std::move(physical_create_index); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/art/prefix.cpp b/src/duckdb/src/execution/index/art/prefix.cpp deleted file mode 100644 index f0821be04..000000000 --- a/src/duckdb/src/execution/index/art/prefix.cpp +++ /dev/null @@ -1,519 +0,0 @@ -#include "duckdb/execution/index/art/prefix.hpp" - -#include "duckdb/common/swap.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/execution/index/art/art_key.hpp" -#include "duckdb/execution/index/art/base_leaf.hpp" -#include "duckdb/execution/index/art/base_node.hpp" -#include "duckdb/execution/index/art/leaf.hpp" -#include "duckdb/execution/index/art/node.hpp" - -namespace duckdb { - -Prefix::Prefix(const ART &art, const Node ptr_p, const bool is_mutable, const bool set_in_memory) { - if (!set_in_memory) { - data = Node::GetAllocator(art, PREFIX).Get(ptr_p, is_mutable); - } else { - data = Node::GetAllocator(art, PREFIX).GetIfLoaded(ptr_p); - if (!data) { - ptr = nullptr; - in_memory = false; - return; - } - } - ptr = reinterpret_cast(data + Count(art) + 1); - in_memory = true; -} - -Prefix::Prefix(unsafe_unique_ptr &allocator, const Node ptr_p, const idx_t count) { - data = allocator->Get(ptr_p, true); - ptr = reinterpret_cast(data + count + 1); - in_memory = true; -} - -idx_t Prefix::GetMismatchWithOther(const Prefix &l_prefix, const Prefix &r_prefix, const idx_t max_count) { - for (idx_t i = 0; i < max_count; i++) { - if (l_prefix.data[i] != r_prefix.data[i]) { - return i; - } - } - return DConstants::INVALID_INDEX; -} - -idx_t Prefix::GetMismatchWithKey(ART &art, const Node &node, const ARTKey &key, idx_t &depth) { - Prefix prefix(art, node); - for (idx_t i = 0; i < prefix.data[Prefix::Count(art)]; i++) { - if (prefix.data[i] != key[depth]) { - return i; - } - depth++; - } - return DConstants::INVALID_INDEX; -} - -uint8_t Prefix::GetByte(const ART &art, const Node &node, const uint8_t pos) { - D_ASSERT(node.GetType() == PREFIX); - Prefix prefix(art, node); - return prefix.data[pos]; -} - -Prefix Prefix::NewInternal(ART &art, Node &node, const data_ptr_t data, const uint8_t count, const idx_t offset, - const NType type) { - node = Node::GetAllocator(art, type).New(); - node.SetMetadata(static_cast(type)); - - Prefix prefix(art, node, true); - prefix.data[Count(art)] = count; - if (data) { - D_ASSERT(count); - memcpy(prefix.data, data + offset, count); - } - return prefix; -} - -void Prefix::New(ART &art, reference &ref, const ARTKey &key, const idx_t depth, idx_t count) { - idx_t offset = 0; - - while (count) { - auto min = MinValue(UnsafeNumericCast(Count(art)), count); - auto this_count = UnsafeNumericCast(min); - auto prefix = NewInternal(art, ref, key.data, this_count, offset + depth, PREFIX); - - ref = *prefix.ptr; - offset += this_count; - count -= this_count; - } -} - -void Prefix::Free(ART &art, Node &node) { - Node next; - - while 
(node.HasMetadata() && node.GetType() == PREFIX) { - Prefix prefix(art, node, true); - next = *prefix.ptr; - Node::GetAllocator(art, PREFIX).Free(node); - node = next; - } - - Node::Free(art, node); - node.Clear(); -} - -void Prefix::InitializeMerge(ART &art, Node &node, const unsafe_vector &upper_bounds) { - auto buffer_count = upper_bounds[Node::GetAllocatorIdx(PREFIX)]; - Node next = node; - Prefix prefix(art, next, true); - - while (next.GetType() == PREFIX) { - next = *prefix.ptr; - if (prefix.ptr->GetType() == PREFIX) { - prefix.ptr->IncreaseBufferId(buffer_count); - prefix = Prefix(art, next, true); - } - } - - node.IncreaseBufferId(buffer_count); - prefix.ptr->InitMerge(art, upper_bounds); -} - -void Prefix::Concat(ART &art, Node &parent, uint8_t byte, const GateStatus old_status, const Node &child, - const GateStatus status) { - D_ASSERT(!parent.IsAnyLeaf()); - D_ASSERT(child.HasMetadata()); - - if (old_status == GateStatus::GATE_SET) { - // Concat Node4. - D_ASSERT(status == GateStatus::GATE_SET); - return ConcatGate(art, parent, byte, child); - } - if (child.GetGateStatus() == GateStatus::GATE_SET) { - // Concat Node4. - D_ASSERT(status == GateStatus::GATE_NOT_SET); - return ConcatChildIsGate(art, parent, byte, child); - } - - if (status == GateStatus::GATE_SET && child.GetType() == NType::LEAF_INLINED) { - auto row_id = child.GetRowId(); - Free(art, parent); - Leaf::New(parent, row_id); - return; - } - - if (parent.GetType() != PREFIX) { - auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); - if (child.GetType() == PREFIX) { - prefix.Append(art, child); - } else { - *prefix.ptr = child; - } - return; - } - - auto tail = GetTail(art, parent); - tail = tail.Append(art, byte); - - if (child.GetType() == PREFIX) { - tail.Append(art, child); - } else { - *tail.ptr = child; - } -} - -template -idx_t TraverseInternal(ART &art, reference &node, const ARTKey &key, idx_t &depth, - const bool is_mutable = false) { - D_ASSERT(node.get().HasMetadata()); - D_ASSERT(node.get().GetType() == NType::PREFIX); - - while (node.get().GetType() == NType::PREFIX) { - auto pos = Prefix::GetMismatchWithKey(art, node, key, depth); - if (pos != DConstants::INVALID_INDEX) { - return pos; - } - - Prefix prefix(art, node, is_mutable); - node = *prefix.ptr; - if (node.get().GetGateStatus() == GateStatus::GATE_SET) { - break; - } - } - return DConstants::INVALID_INDEX; -} - -idx_t Prefix::Traverse(ART &art, reference &node, const ARTKey &key, idx_t &depth) { - return TraverseInternal(art, node, key, depth); -} - -idx_t Prefix::TraverseMutable(ART &art, reference &node, const ARTKey &key, idx_t &depth) { - return TraverseInternal(art, node, key, depth, true); -} - -bool Prefix::Traverse(ART &art, reference &l_node, reference &r_node, idx_t &pos, const GateStatus status) { - D_ASSERT(l_node.get().HasMetadata()); - D_ASSERT(r_node.get().HasMetadata()); - - Prefix l_prefix(art, l_node, true); - Prefix r_prefix(art, r_node, true); - - idx_t max_count = MinValue(l_prefix.data[Count(art)], r_prefix.data[Count(art)]); - pos = GetMismatchWithOther(l_prefix, r_prefix, max_count); - if (pos != DConstants::INVALID_INDEX) { - return true; - } - - // Match. - if (l_prefix.data[Count(art)] == r_prefix.data[Count(art)]) { - auto r_child = *r_prefix.ptr; - r_prefix.ptr->Clear(); - Node::Free(art, r_node); - return l_prefix.ptr->MergeInternal(art, r_child, status); - } - - pos = max_count; - if (r_prefix.ptr->GetType() != PREFIX && r_prefix.data[Count(art)] == max_count) { - // l_prefix contains r_prefix. 
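InitializeMerge above rebases every pointer in one tree by the other tree's buffer count, so that after the merge the two allocators' buffer lists can simply be concatenated. A standalone sketch of that pointer rebasing, using an assumed simplified pointer encoding (buffer id plus offset), not DuckDB's actual IndexPointer layout:

#include <cstdint>
#include <vector>

// Simplified index pointer: a buffer id plus an offset within that buffer.
struct IdxPtr {
	uint32_t buffer_id;
	uint32_t offset;
};

// Rebase all pointers of the right-hand tree so its buffers can be appended
// after the left-hand tree's buffers, which occupy ids [0, shift).
void RebasePointers(std::vector<IdxPtr> &rhs_pointers, uint32_t shift) {
	for (auto &ptr : rhs_pointers) {
		ptr.buffer_id += shift; // same offset, shifted buffer id space
	}
}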
- swap(l_node.get(), r_node.get()); - l_node = *r_prefix.ptr; - return true; - } - // r_prefix contains l_prefix. - l_node = *l_prefix.ptr; - return true; -} - -void Prefix::Reduce(ART &art, Node &node, const idx_t pos) { - D_ASSERT(node.HasMetadata()); - D_ASSERT(pos < Count(art)); - - Prefix prefix(art, node); - if (pos == idx_t(prefix.data[Count(art)] - 1)) { - auto next = *prefix.ptr; - prefix.ptr->Clear(); - Node::Free(art, node); - node = next; - return; - } - - for (idx_t i = 0; i < Count(art) - pos - 1; i++) { - prefix.data[i] = prefix.data[pos + i + 1]; - } - - prefix.data[Count(art)] -= pos + 1; - prefix.Append(art, *prefix.ptr); -} - -GateStatus Prefix::Split(ART &art, reference &node, Node &child, const uint8_t pos) { - D_ASSERT(node.get().HasMetadata()); - - Prefix prefix(art, node, true); - - // The split is at the last prefix byte. Decrease the count and return. - if (pos + 1 == Count(art)) { - prefix.data[Count(art)]--; - node = *prefix.ptr; - child = *prefix.ptr; - return GateStatus::GATE_NOT_SET; - } - - if (pos + 1 < prefix.data[Count(art)]) { - // Create a new prefix and - // 1. copy the remaining bytes of this prefix. - // 2. append remaining prefix nodes. - auto new_prefix = NewInternal(art, child, nullptr, 0, 0, PREFIX); - new_prefix.data[Count(art)] = prefix.data[Count(art)] - pos - 1; - memcpy(new_prefix.data, prefix.data + pos + 1, new_prefix.data[Count(art)]); - - if (prefix.ptr->GetType() == PREFIX && prefix.ptr->GetGateStatus() == GateStatus::GATE_NOT_SET) { - new_prefix.Append(art, *prefix.ptr); - } else { - *new_prefix.ptr = *prefix.ptr; - } - - } else if (pos + 1 == prefix.data[Count(art)]) { - // No prefix bytes after the split. - child = *prefix.ptr; - } - - // Set the new count of this node. - prefix.data[Count(art)] = pos; - - // No bytes left before the split, free this node. - if (pos == 0) { - auto old_status = node.get().GetGateStatus(); - prefix.ptr->Clear(); - Node::Free(art, node); - return old_status; - } - - // There are bytes left before the split. - // The subsequent node replaces the split byte. - node = *prefix.ptr; - return GateStatus::GATE_NOT_SET; -} - -ARTConflictType Prefix::Insert(ART &art, Node &node, const ARTKey &key, idx_t depth, const ARTKey &row_id, - const GateStatus status, optional_ptr delete_art) { - reference next(node); - auto pos = TraverseMutable(art, next, key, depth); - - // We recurse into the next node, if - // (1) the prefix matches the key. - // (2) we reach a gate. - if (pos == DConstants::INVALID_INDEX) { - if (next.get().GetType() != NType::PREFIX || next.get().GetGateStatus() == GateStatus::GATE_SET) { - return art.Insert(next, key, depth, row_id, status, delete_art); - } - } - - Node remainder; - auto byte = GetByte(art, next, UnsafeNumericCast(pos)); - auto split_status = Split(art, next, remainder, UnsafeNumericCast(pos)); - Node4::New(art, next); - next.get().SetGateStatus(split_status); - - // Insert the remaining prefix into the new Node4. - Node4::InsertChild(art, next, byte, remainder); - - if (status == GateStatus::GATE_SET) { - D_ASSERT(pos != ROW_ID_COUNT); - Node new_row_id; - Leaf::New(new_row_id, key.GetRowId()); - Node::InsertChild(art, next, key[depth], new_row_id); - return ARTConflictType::NO_CONFLICT; - } - - Node leaf; - reference ref(leaf); - if (depth + 1 < key.len) { - // Create the prefix. - auto count = key.len - depth - 1; - Prefix::New(art, ref, key, depth + 1, count); - } - // Create the inlined leaf. 
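Reduce above drops the first pos + 1 bytes of a prefix by shifting the survivors to the front and lowering the count, which this layout keeps in the byte at data[capacity]. A minimal sketch of the same operation on a raw byte array:

#include <cstddef>
#include <cstdint>
#include <cstring>

// Drop the bytes up to and including `pos` from a prefix whose count byte
// sits at data[capacity], the layout used above.
void ReducePrefixBytes(uint8_t *data, size_t capacity, size_t pos) {
	uint8_t count = data[capacity];
	size_t remaining = count - pos - 1;
	memmove(data, data + pos + 1, remaining); // shift the survivors to the front
	data[capacity] = static_cast<uint8_t>(remaining);
}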
- Leaf::New(ref, row_id.GetRowId()); - Node4::InsertChild(art, next, key[depth], leaf); - return ARTConflictType::NO_CONFLICT; -} - -string Prefix::VerifyAndToString(ART &art, const Node &node, const bool only_verify) { - string str = ""; - reference ref(node); - - Iterator(art, ref, true, false, [&](Prefix &prefix) { - D_ASSERT(prefix.data[Count(art)] != 0); - D_ASSERT(prefix.data[Count(art)] <= Count(art)); - - str += " Prefix :[ "; - for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { - str += to_string(prefix.data[i]) + "-"; - } - str += " ] "; - }); - - auto child = ref.get().VerifyAndToString(art, only_verify); - return only_verify ? "" : str + child; -} - -void Prefix::VerifyAllocations(ART &art, const Node &node, unordered_map &node_counts) { - auto idx = Node::GetAllocatorIdx(PREFIX); - reference ref(node); - Iterator(art, ref, false, false, [&](Prefix &prefix) { node_counts[idx]++; }); - return ref.get().VerifyAllocations(art, node_counts); -} - -void Prefix::Vacuum(ART &art, Node &node, const unordered_set &indexes) { - bool set = indexes.find(Node::GetAllocatorIdx(PREFIX)) != indexes.end(); - auto &allocator = Node::GetAllocator(art, PREFIX); - - reference ref(node); - while (ref.get().GetType() == PREFIX) { - if (set && allocator.NeedsVacuum(ref)) { - auto status = ref.get().GetGateStatus(); - ref.get() = allocator.VacuumPointer(ref); - ref.get().SetMetadata(static_cast(PREFIX)); - ref.get().SetGateStatus(status); - } - Prefix prefix(art, ref, true); - ref = *prefix.ptr; - } - - ref.get().Vacuum(art, indexes); -} - -void Prefix::TransformToDeprecated(ART &art, Node &node, unsafe_unique_ptr &allocator) { - // Early-out, if we do not need any transformations. - if (!allocator) { - reference ref(node); - while (ref.get().GetType() == PREFIX && ref.get().GetGateStatus() == GateStatus::GATE_NOT_SET) { - Prefix prefix(art, ref, true, true); - if (!prefix.in_memory) { - return; - } - ref = *prefix.ptr; - } - return Node::TransformToDeprecated(art, ref, allocator); - } - - // We need to create a new prefix (chain). 
- Node new_node; - new_node = allocator->New(); - new_node.SetMetadata(static_cast(PREFIX)); - Prefix new_prefix(allocator, new_node, DEPRECATED_COUNT); - - Node current_node = node; - while (current_node.GetType() == PREFIX && current_node.GetGateStatus() == GateStatus::GATE_NOT_SET) { - Prefix prefix(art, current_node, true, true); - if (!prefix.in_memory) { - return; - } - - for (idx_t i = 0; i < prefix.data[Count(art)]; i++) { - new_prefix = new_prefix.TransformToDeprecatedAppend(art, allocator, prefix.data[i]); - } - - *new_prefix.ptr = *prefix.ptr; - prefix.ptr->Clear(); - Node::Free(art, current_node); - current_node = *new_prefix.ptr; - } - - node = new_node; - return Node::TransformToDeprecated(art, *new_prefix.ptr, allocator); -} - -Prefix Prefix::Append(ART &art, const uint8_t byte) { - if (data[Count(art)] != Count(art)) { - data[data[Count(art)]] = byte; - data[Count(art)]++; - return *this; - } - - auto prefix = NewInternal(art, *ptr, nullptr, 0, 0, PREFIX); - return prefix.Append(art, byte); -} - -void Prefix::Append(ART &art, Node other) { - D_ASSERT(other.HasMetadata()); - - Prefix prefix = *this; - while (other.GetType() == PREFIX) { - if (other.GetGateStatus() == GateStatus::GATE_SET) { - *prefix.ptr = other; - return; - } - - Prefix other_prefix(art, other, true); - for (idx_t i = 0; i < other_prefix.data[Count(art)]; i++) { - prefix = prefix.Append(art, other_prefix.data[i]); - } - - *prefix.ptr = *other_prefix.ptr; - Node::GetAllocator(art, PREFIX).Free(other); - other = *prefix.ptr; - } -} - -Prefix Prefix::GetTail(ART &art, const Node &node) { - Prefix prefix(art, node, true); - while (prefix.ptr->GetType() == PREFIX) { - prefix = Prefix(art, *prefix.ptr, true); - } - return prefix; -} - -void Prefix::ConcatGate(ART &art, Node &parent, uint8_t byte, const Node &child) { - D_ASSERT(child.HasMetadata()); - Node new_prefix = Node(); - - // Inside gates, inlined row IDs are not prefixed. - if (child.GetType() == NType::LEAF_INLINED) { - Leaf::New(new_prefix, child.GetRowId()); - - } else if (child.GetType() == PREFIX) { - // At least one more row ID in this gate. - auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); - prefix.ptr->Clear(); - prefix.Append(art, child); - new_prefix.SetGateStatus(GateStatus::GATE_SET); - - } else { - // At least one more row ID in this gate. - auto prefix = NewInternal(art, new_prefix, &byte, 1, 0, PREFIX); - *prefix.ptr = child; - new_prefix.SetGateStatus(GateStatus::GATE_SET); - } - - if (parent.GetType() != PREFIX) { - parent = new_prefix; - return; - } - *GetTail(art, parent).ptr = new_prefix; -} - -void Prefix::ConcatChildIsGate(ART &art, Node &parent, uint8_t byte, const Node &child) { - // Create a new prefix and point it to the gate. 
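Append above writes into the current prefix node until its fixed capacity is exhausted, then allocates a continuation node and keeps writing there, forming a chain that GetTail later walks. A rough sketch of that overflow-into-a-new-node behavior with plain heap nodes (capacity 15 is an arbitrary placeholder):

#include <cstdint>
#include <memory>

// A prefix node with fixed capacity and a pointer to a continuation node.
struct PrefixNode {
	static constexpr uint8_t CAPACITY = 15;
	uint8_t bytes[CAPACITY];
	uint8_t count = 0;
	std::unique_ptr<PrefixNode> next;
};

// Append a byte at the chain's tail, growing the chain when the tail is full,
// and return the node that now ends the chain.
PrefixNode &AppendByte(PrefixNode &tail, uint8_t byte) {
	if (tail.count < PrefixNode::CAPACITY) {
		tail.bytes[tail.count++] = byte;
		return tail;
	}
	tail.next = std::make_unique<PrefixNode>();
	return AppendByte(*tail.next, byte);
}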
- if (parent.GetType() != PREFIX) { - auto prefix = NewInternal(art, parent, &byte, 1, 0, PREFIX); - *prefix.ptr = child; - return; - } - - auto tail = GetTail(art, parent); - tail = tail.Append(art, byte); - *tail.ptr = child; -} - -Prefix Prefix::TransformToDeprecatedAppend(ART &art, unsafe_unique_ptr &allocator, uint8_t byte) { - if (data[DEPRECATED_COUNT] != DEPRECATED_COUNT) { - data[data[DEPRECATED_COUNT]] = byte; - data[DEPRECATED_COUNT]++; - return *this; - } - - *ptr = allocator->New(); - ptr->SetMetadata(static_cast(PREFIX)); - Prefix prefix(allocator, *ptr, DEPRECATED_COUNT); - return prefix.TransformToDeprecatedAppend(art, allocator, byte); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/bound_index.cpp b/src/duckdb/src/execution/index/bound_index.cpp deleted file mode 100644 index 199cc4bdd..000000000 --- a/src/duckdb/src/execution/index/bound_index.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "duckdb/execution/index/bound_index.hpp" - -#include "duckdb/common/radix.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/planner/expression/bound_columnref_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/storage/table/append_state.hpp" - -namespace duckdb { - -//------------------------------------------------------------------------------- -// Bound index -//------------------------------------------------------------------------------- - -BoundIndex::BoundIndex(const string &name, const string &index_type, IndexConstraintType index_constraint_type, - const vector &column_ids, TableIOManager &table_io_manager, - const vector> &unbound_expressions_p, AttachedDatabase &db) - : Index(column_ids, table_io_manager, db), name(name), index_type(index_type), - index_constraint_type(index_constraint_type) { - - for (auto &expr : unbound_expressions_p) { - types.push_back(expr->return_type.InternalType()); - logical_types.push_back(expr->return_type); - unbound_expressions.emplace_back(expr->Copy()); - bound_expressions.push_back(BindExpression(expr->Copy())); - executor.AddExpression(*bound_expressions.back()); - } -} - -void BoundIndex::InitializeLock(IndexLock &state) { - state.index_lock = unique_lock(lock); -} - -ErrorData BoundIndex::Append(DataChunk &chunk, Vector &row_ids) { - IndexLock l; - InitializeLock(l); - return Append(l, chunk, row_ids); -} - -ErrorData BoundIndex::AppendWithDeleteIndex(IndexLock &l, DataChunk &chunk, Vector &row_ids, - optional_ptr delete_index) { - // Fallback to the old Append. 
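The BoundIndex entry points above all follow one idiom: the public method acquires the index lock, then forwards to an overload that expects the lock to be held. A minimal sketch of that lock-then-delegate pattern:

#include <mutex>

class Index {
public:
	// Public entry point: acquire the lock, then delegate.
	void Vacuum() {
		std::unique_lock<std::mutex> guard(lock);
		VacuumWithLock(guard);
	}

private:
	// Internal overload: the caller proves it holds the lock by passing the guard.
	void VacuumWithLock(std::unique_lock<std::mutex> &) {
		// ... actual work; the lock is guaranteed to be held here ...
	}

	std::mutex lock;
};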
- return Append(l, chunk, row_ids); -} - -ErrorData BoundIndex::AppendWithDeleteIndex(DataChunk &chunk, Vector &row_ids, optional_ptr delete_index) { - IndexLock l; - InitializeLock(l); - return AppendWithDeleteIndex(l, chunk, row_ids, delete_index); -} - -void BoundIndex::VerifyAppend(DataChunk &chunk, optional_ptr delete_index, - optional_ptr manager) { - throw NotImplementedException("this implementation of VerifyAppend does not exist."); -} - -void BoundIndex::VerifyConstraint(DataChunk &chunk, optional_ptr delete_index, ConflictManager &manager) { - throw NotImplementedException("this implementation of VerifyConstraint does not exist."); -} - -void BoundIndex::CommitDrop() { - IndexLock index_lock; - InitializeLock(index_lock); - CommitDrop(index_lock); -} - -void BoundIndex::Delete(DataChunk &entries, Vector &row_identifiers) { - IndexLock state; - InitializeLock(state); - Delete(state, entries, row_identifiers); -} - -ErrorData BoundIndex::Insert(IndexLock &l, DataChunk &chunk, Vector &row_ids, optional_ptr delete_index) { - throw NotImplementedException("this implementation of Insert does not exist."); -} - -bool BoundIndex::MergeIndexes(BoundIndex &other_index) { - IndexLock state; - InitializeLock(state); - return MergeIndexes(state, other_index); -} - -string BoundIndex::VerifyAndToString(const bool only_verify) { - IndexLock state; - InitializeLock(state); - return VerifyAndToString(state, only_verify); -} - -void BoundIndex::VerifyAllocations() { - IndexLock state; - InitializeLock(state); - return VerifyAllocations(state); -} - -void BoundIndex::Vacuum() { - IndexLock state; - InitializeLock(state); - Vacuum(state); -} - -idx_t BoundIndex::GetInMemorySize() { - IndexLock state; - InitializeLock(state); - return GetInMemorySize(state); -} - -void BoundIndex::ExecuteExpressions(DataChunk &input, DataChunk &result) { - executor.Execute(input, result); -} - -unique_ptr BoundIndex::BindExpression(unique_ptr expr) { - if (expr->GetExpressionType() == ExpressionType::BOUND_COLUMN_REF) { - auto &bound_colref = expr->Cast(); - return make_uniq(expr->return_type, column_ids[bound_colref.binding.column_index]); - } - ExpressionIterator::EnumerateChildren( - *expr, [this](unique_ptr &expr) { expr = BindExpression(std::move(expr)); }); - return expr; -} - -bool BoundIndex::IndexIsUpdated(const vector &column_ids_p) const { - for (auto &column : column_ids_p) { - if (column_id_set.find(column.index) != column_id_set.end()) { - return true; - } - } - return false; -} - -IndexStorageInfo BoundIndex::GetStorageInfo(const case_insensitive_map_t &options, const bool to_wal) { - throw NotImplementedException("The implementation of this index serialization does not exist."); -} - -string BoundIndex::AppendRowError(DataChunk &input, idx_t index) { - string error; - for (idx_t c = 0; c < input.ColumnCount(); c++) { - if (c > 0) { - error += ", "; - } - error += input.GetValue(c, index).ToString(); - } - return error; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/fixed_size_allocator.cpp b/src/duckdb/src/execution/index/fixed_size_allocator.cpp deleted file mode 100644 index 24ea09181..000000000 --- a/src/duckdb/src/execution/index/fixed_size_allocator.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/execution/index/fixed_size_allocator.hpp" - -#include "duckdb/storage/metadata/metadata_reader.hpp" - -namespace duckdb { - -FixedSizeAllocator::FixedSizeAllocator(const idx_t segment_size, BlockManager &block_manager) - : block_manager(block_manager), 
buffer_manager(block_manager.buffer_manager), segment_size(segment_size), - total_segment_count(0) { - - if (segment_size > block_manager.GetBlockSize() - sizeof(validity_t)) { - throw InternalException("The maximum segment size of fixed-size allocators is " + - to_string(block_manager.GetBlockSize() - sizeof(validity_t))); - } - - // calculate how many segments fit into one buffer (available_segments_per_buffer) - - idx_t bits_per_value = sizeof(validity_t) * 8; - idx_t byte_count = 0; - - bitmask_count = 0; - available_segments_per_buffer = 0; - - while (byte_count < block_manager.GetBlockSize()) { - if (!bitmask_count || (bitmask_count * bits_per_value) % available_segments_per_buffer == 0) { - // we need to add another validity_t value to the bitmask, to allow storing another - // bits_per_value segments on a buffer - bitmask_count++; - byte_count += sizeof(validity_t); - } - - auto remaining_bytes = block_manager.GetBlockSize() - byte_count; - auto remaining_segments = MinValue(remaining_bytes / segment_size, bits_per_value); - - if (remaining_segments == 0) { - break; - } - - available_segments_per_buffer += remaining_segments; - byte_count += remaining_segments * segment_size; - } - - bitmask_offset = bitmask_count * sizeof(validity_t); -} - -IndexPointer FixedSizeAllocator::New() { - // no more segments available - if (buffers_with_free_space.empty()) { - - // add a new buffer - auto buffer_id = GetAvailableBufferId(); - FixedSizeBuffer new_buffer(block_manager); - buffers.insert(make_pair(buffer_id, std::move(new_buffer))); - buffers_with_free_space.insert(buffer_id); - - // set the bitmask - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - ValidityMask mask(reinterpret_cast(buffer.Get()), available_segments_per_buffer); - - // zero-initialize the bitmask to avoid leaking memory to disk - auto data = mask.GetData(); - for (idx_t i = 0; i < bitmask_count; i++) { - data[i] = 0; - } - - // initializing the bitmask of the new buffer - mask.SetAllValid(available_segments_per_buffer); - } - - // return a pointer to a free segment - D_ASSERT(!buffers_with_free_space.empty()); - auto buffer_id = uint32_t(*buffers_with_free_space.begin()); - - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - auto offset = buffer.GetOffset(bitmask_count, available_segments_per_buffer); - - total_segment_count++; - buffer.segment_count++; - if (buffer.segment_count == available_segments_per_buffer) { - buffers_with_free_space.erase(buffer_id); - } - - // zero-initialize that segment - auto buffer_ptr = buffer.Get(); - auto offset_in_buffer = buffer_ptr + offset * segment_size + bitmask_offset; - memset(offset_in_buffer, 0, segment_size); - - return IndexPointer(buffer_id, offset); -} - -void FixedSizeAllocator::Free(const IndexPointer ptr) { - - auto buffer_id = ptr.GetBufferId(); - auto offset = ptr.GetOffset(); - - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - - auto bitmask_ptr = reinterpret_cast(buffer.Get()); - ValidityMask mask(bitmask_ptr, offset + 1); // FIXME - D_ASSERT(!mask.RowIsValid(offset)); - mask.SetValid(offset); - - D_ASSERT(total_segment_count > 0); - D_ASSERT(buffer.segment_count > 0); - - // adjust the allocator fields - buffers_with_free_space.insert(buffer_id); - total_segment_count--; - buffer.segment_count--; -} - -void FixedSizeAllocator::Reset() { - for (auto &buffer : buffers) { - buffer.second.Destroy(); - } - 
buffers.clear(); - buffers_with_free_space.clear(); - total_segment_count = 0; -} - -idx_t FixedSizeAllocator::GetInMemorySize() const { - idx_t memory_usage = 0; - for (auto &buffer : buffers) { - if (buffer.second.InMemory()) { - memory_usage += block_manager.GetBlockSize(); - } - } - return memory_usage; -} - -idx_t FixedSizeAllocator::GetUpperBoundBufferId() const { - idx_t upper_bound_id = 0; - for (auto &buffer : buffers) { - if (buffer.first >= upper_bound_id) { - upper_bound_id = buffer.first + 1; - } - } - return upper_bound_id; -} - -void FixedSizeAllocator::Merge(FixedSizeAllocator &other) { - - D_ASSERT(segment_size == other.segment_size); - - // remember the buffer count and merge the buffers - idx_t upper_bound_id = GetUpperBoundBufferId(); - for (auto &buffer : other.buffers) { - buffers.insert(make_pair(buffer.first + upper_bound_id, std::move(buffer.second))); - } - other.buffers.clear(); - - // merge the buffers with free spaces - for (auto &buffer_id : other.buffers_with_free_space) { - buffers_with_free_space.insert(buffer_id + upper_bound_id); - } - other.buffers_with_free_space.clear(); - - // add the total allocations - total_segment_count += other.total_segment_count; -} - -bool FixedSizeAllocator::InitializeVacuum() { - - // NOTE: we do not vacuum buffers that are not in memory. We might consider changing this - // in the future, although buffers on disk should almost never be eligible for a vacuum - - if (total_segment_count == 0) { - Reset(); - return false; - } - RemoveEmptyBuffers(); - - // determine if a vacuum is necessary - multimap temporary_vacuum_buffers; - D_ASSERT(vacuum_buffers.empty()); - idx_t available_segments_in_memory = 0; - - for (auto &buffer : buffers) { - buffer.second.vacuum = false; - if (buffer.second.InMemory()) { - auto available_segments_in_buffer = available_segments_per_buffer - buffer.second.segment_count; - available_segments_in_memory += available_segments_in_buffer; - temporary_vacuum_buffers.emplace(available_segments_in_buffer, buffer.first); - } - } - - // no buffers in memory - if (temporary_vacuum_buffers.empty()) { - return false; - } - - auto excess_buffer_count = available_segments_in_memory / available_segments_per_buffer; - - // calculate the vacuum threshold adaptively - D_ASSERT(excess_buffer_count < temporary_vacuum_buffers.size()); - idx_t memory_usage = GetInMemorySize(); - idx_t excess_memory_usage = excess_buffer_count * block_manager.GetBlockSize(); - auto excess_percentage = double(excess_memory_usage) / double(memory_usage); - auto threshold = double(VACUUM_THRESHOLD) / 100.0; - if (excess_percentage < threshold) { - return false; - } - - D_ASSERT(excess_buffer_count <= temporary_vacuum_buffers.size()); - D_ASSERT(temporary_vacuum_buffers.size() <= buffers.size()); - - // erasing from a multimap, we vacuum the buffers with the most free spaces (least full) - while (temporary_vacuum_buffers.size() != excess_buffer_count) { - temporary_vacuum_buffers.erase(temporary_vacuum_buffers.begin()); - } - - // adjust the buffers, and erase all to-be-vacuumed buffers from the available buffer list - for (auto &vacuum_buffer : temporary_vacuum_buffers) { - auto buffer_id = vacuum_buffer.second; - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - buffers.find(buffer_id)->second.vacuum = true; - buffers_with_free_space.erase(buffer_id); - } - - for (auto &vacuum_buffer : temporary_vacuum_buffers) { - vacuum_buffers.insert(vacuum_buffer.second); - } - - return true; -} - -void FixedSizeAllocator::FinalizeVacuum() { - - for 
(auto &buffer_id : vacuum_buffers) { - D_ASSERT(buffers.find(buffer_id) != buffers.end()); - auto &buffer = buffers.find(buffer_id)->second; - D_ASSERT(buffer.InMemory()); - buffer.Destroy(); - buffers.erase(buffer_id); - } - vacuum_buffers.clear(); -} - -IndexPointer FixedSizeAllocator::VacuumPointer(const IndexPointer ptr) { - - // we do not need to adjust the bitmask of the old buffer, because we will free the entire - // buffer after the vacuum operation - - auto new_ptr = New(); - // new increases the allocation count, we need to counter that here - total_segment_count--; - - memcpy(Get(new_ptr), Get(ptr), segment_size); - return new_ptr; -} - -FixedSizeAllocatorInfo FixedSizeAllocator::GetInfo() const { - - FixedSizeAllocatorInfo info; - info.segment_size = segment_size; - - for (const auto &buffer : buffers) { - info.buffer_ids.push_back(buffer.first); - info.block_pointers.push_back(buffer.second.block_pointer); - info.segment_counts.push_back(buffer.second.segment_count); - info.allocation_sizes.push_back(buffer.second.allocation_size); - } - - for (auto &buffer_id : buffers_with_free_space) { - info.buffers_with_free_space.push_back(buffer_id); - } - - return info; -} - -void FixedSizeAllocator::SerializeBuffers(PartialBlockManager &partial_block_manager) { - for (auto &buffer : buffers) { - buffer.second.Serialize(partial_block_manager, available_segments_per_buffer, segment_size, bitmask_offset); - } -} - -vector FixedSizeAllocator::InitSerializationToWAL() { - - vector buffer_infos; - for (auto &buffer : buffers) { - buffer.second.SetAllocationSize(available_segments_per_buffer, segment_size, bitmask_offset); - buffer_infos.emplace_back(buffer.second.Get(), buffer.second.allocation_size); - } - return buffer_infos; -} - -void FixedSizeAllocator::Init(const FixedSizeAllocatorInfo &info) { - segment_size = info.segment_size; - total_segment_count = 0; - - for (idx_t i = 0; i < info.buffer_ids.size(); i++) { - - // read all FixedSizeBuffer data - auto buffer_id = info.buffer_ids[i]; - auto buffer_block_pointer = info.block_pointers[i]; - auto segment_count = info.segment_counts[i]; - auto allocation_size = info.allocation_sizes[i]; - - // create the FixedSizeBuffer - FixedSizeBuffer new_buffer(block_manager, segment_count, allocation_size, buffer_block_pointer); - buffers.insert(make_pair(buffer_id, std::move(new_buffer))); - total_segment_count += segment_count; - } - - for (const auto &buffer_id : info.buffers_with_free_space) { - buffers_with_free_space.insert(buffer_id); - } -} - -void FixedSizeAllocator::Deserialize(MetadataManager &metadata_manager, const BlockPointer &block_pointer) { - - MetadataReader reader(metadata_manager, block_pointer); - segment_size = reader.Read(); - auto buffer_count = reader.Read(); - auto buffers_with_free_space_count = reader.Read(); - - total_segment_count = 0; - - for (idx_t i = 0; i < buffer_count; i++) { - auto buffer_id = reader.Read(); - auto buffer_block_pointer = reader.Read(); - auto segment_count = reader.Read(); - auto allocation_size = reader.Read(); - FixedSizeBuffer new_buffer(block_manager, segment_count, allocation_size, buffer_block_pointer); - buffers.insert(make_pair(buffer_id, std::move(new_buffer))); - total_segment_count += segment_count; - } - for (idx_t i = 0; i < buffers_with_free_space_count; i++) { - buffers_with_free_space.insert(reader.Read()); - } -} - -idx_t FixedSizeAllocator::GetAvailableBufferId() const { - idx_t buffer_id = buffers.size(); - while (buffers.find(buffer_id) != buffers.end()) { - 
D_ASSERT(buffer_id > 0); - buffer_id--; - } - return buffer_id; -} - -void FixedSizeAllocator::RemoveEmptyBuffers() { - - auto buffer_it = buffers.begin(); - while (buffer_it != buffers.end()) { - if (buffer_it->second.segment_count != 0) { - buffer_it++; - continue; - } - - buffers_with_free_space.erase(buffer_it->first); - buffer_it->second.Destroy(); - buffer_it = buffers.erase(buffer_it); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/fixed_size_buffer.cpp b/src/duckdb/src/execution/index/fixed_size_buffer.cpp deleted file mode 100644 index 336d96396..000000000 --- a/src/duckdb/src/execution/index/fixed_size_buffer.cpp +++ /dev/null @@ -1,247 +0,0 @@ -#include "duckdb/execution/index/fixed_size_buffer.hpp" - -#include "duckdb/storage/block_manager.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// PartialBlockForIndex -//===--------------------------------------------------------------------===// - -PartialBlockForIndex::PartialBlockForIndex(PartialBlockState state, BlockManager &block_manager, - const shared_ptr &block_handle) - : PartialBlock(state, block_manager, block_handle) { -} - -void PartialBlockForIndex::Flush(const idx_t free_space_left) { - FlushInternal(free_space_left); - block_handle = block_manager.ConvertToPersistent(state.block_id, std::move(block_handle)); - Clear(); -} - -void PartialBlockForIndex::Merge(PartialBlock &other, idx_t offset, idx_t other_size) { - throw InternalException("no merge for PartialBlockForIndex"); -} - -void PartialBlockForIndex::Clear() { - block_handle.reset(); -} - -//===--------------------------------------------------------------------===// -// FixedSizeBuffer -//===--------------------------------------------------------------------===// - -constexpr idx_t FixedSizeBuffer::BASE[]; -constexpr uint8_t FixedSizeBuffer::SHIFT[]; - -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager) - : block_manager(block_manager), segment_count(0), allocation_size(0), dirty(false), vacuum(false), block_pointer(), - block_handle(nullptr) { - - auto &buffer_manager = block_manager.buffer_manager; - buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.GetBlockSize(), false); - block_handle = buffer_handle.GetBlockHandle(); -} - -FixedSizeBuffer::FixedSizeBuffer(BlockManager &block_manager, const idx_t segment_count, const idx_t allocation_size, - const BlockPointer &block_pointer) - : block_manager(block_manager), segment_count(segment_count), allocation_size(allocation_size), dirty(false), - vacuum(false), block_pointer(block_pointer) { - - D_ASSERT(block_pointer.IsValid()); - block_handle = block_manager.RegisterBlock(block_pointer.block_id); - D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); -} - -void FixedSizeBuffer::Destroy() { - if (InMemory()) { - // we can have multiple readers on a pinned block, and unpinning the buffer handle - // decrements the reader count on the underlying block handle (Destroy() unpins) - buffer_handle.Destroy(); - } - if (OnDisk()) { - // marking a block as modified decreases the reference count of multi-use blocks - block_manager.MarkBlockAsModified(block_pointer.block_id); - } -} - -void FixedSizeBuffer::Serialize(PartialBlockManager &partial_block_manager, const idx_t available_segments, - const idx_t segment_size, const idx_t bitmask_offset) { - - // Early-out, if the block is already on disk and not in memory. 
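The FixedSizeAllocator constructor further up interleaves validity bitmask words with segment data in a single block: each added validity_t word buys room for another 64 segments, and the loop stops once the block is full. A simplified standalone version of that computation (not the exact loop above, which also handles partial word reuse):

#include <cstddef>
#include <cstdint>

// How many fixed-size segments fit in one block when every group of up to
// 64 segments costs one extra 8-byte bitmask word.
size_t SegmentsPerBuffer(size_t block_size, size_t segment_size) {
	const size_t bits_per_word = 64;
	size_t byte_count = 0, segments = 0;
	while (byte_count < block_size) {
		byte_count += sizeof(uint64_t); // one more bitmask word
		size_t remaining = (block_size - byte_count) / segment_size;
		if (remaining == 0) {
			break;
		}
		size_t batch = remaining < bits_per_word ? remaining : bits_per_word;
		segments += batch;
		byte_count += batch * segment_size;
	}
	return segments;
}

// e.g. SegmentsPerBuffer(262144, 32) == 8160 under this simplified scheme.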
- if (!InMemory()) { - if (!OnDisk() || dirty) { - throw InternalException("invalid or missing buffer in FixedSizeAllocator"); - } - return; - } - - // Early-out, if the buffer is already on disk and not dirty. - if (!dirty && OnDisk()) { - return; - } - - // Adjust the allocation size. - D_ASSERT(segment_count != 0); - SetAllocationSize(available_segments, segment_size, bitmask_offset); - - // the buffer is in memory, so we copied it onto a new buffer when pinning - D_ASSERT(InMemory()); - if (OnDisk()) { - block_manager.MarkBlockAsModified(block_pointer.block_id); - } - - // now we write the changes, first get a partial block allocation - PartialBlockAllocation allocation = - partial_block_manager.GetBlockAllocation(NumericCast(allocation_size)); - block_pointer.block_id = allocation.state.block_id; - block_pointer.offset = allocation.state.offset; - - auto &buffer_manager = block_manager.buffer_manager; - - if (allocation.partial_block) { - // copy to an existing partial block - D_ASSERT(block_pointer.offset > 0); - auto &p_block_for_index = allocation.partial_block->Cast(); - auto dst_handle = buffer_manager.Pin(p_block_for_index.block_handle); - memcpy(dst_handle.Ptr() + block_pointer.offset, buffer_handle.Ptr(), allocation_size); - SetUninitializedRegions(p_block_for_index, segment_size, block_pointer.offset, bitmask_offset, - available_segments); - - } else { - // create a new block that can potentially be used as a partial block - D_ASSERT(block_handle); - D_ASSERT(!block_pointer.offset); - auto p_block_for_index = make_uniq(allocation.state, block_manager, block_handle); - SetUninitializedRegions(*p_block_for_index, segment_size, block_pointer.offset, bitmask_offset, - available_segments); - allocation.partial_block = std::move(p_block_for_index); - } - - // resetting this buffer - buffer_handle.Destroy(); - - // register the partial block - partial_block_manager.RegisterPartialBlock(std::move(allocation)); - - block_handle = block_manager.RegisterBlock(block_pointer.block_id); - D_ASSERT(block_handle->BlockId() < MAXIMUM_BLOCK); - - // we persist any changes, so the buffer is no longer dirty - dirty = false; -} - -void FixedSizeBuffer::Pin() { - auto &buffer_manager = block_manager.buffer_manager; - D_ASSERT(block_pointer.IsValid()); - D_ASSERT(block_handle && block_handle->BlockId() < MAXIMUM_BLOCK); - D_ASSERT(!dirty); - - buffer_handle = buffer_manager.Pin(block_handle); - - // Copy the (partial) data into a new (not yet disk-backed) buffer handle. 
- shared_ptr new_block_handle; - auto new_buffer_handle = buffer_manager.Allocate(MemoryTag::ART_INDEX, block_manager.GetBlockSize(), false); - new_block_handle = new_buffer_handle.GetBlockHandle(); - memcpy(new_buffer_handle.Ptr(), buffer_handle.Ptr() + block_pointer.offset, allocation_size); - - buffer_handle = std::move(new_buffer_handle); - block_handle = std::move(new_block_handle); -} - -uint32_t FixedSizeBuffer::GetOffset(const idx_t bitmask_count, const idx_t available_segments) { - - // get the bitmask data - auto bitmask_ptr = reinterpret_cast(Get()); - ValidityMask mask(bitmask_ptr, available_segments); - auto data = mask.GetData(); - - // fills up a buffer sequentially before searching for free bits - if (mask.RowIsValid(segment_count)) { - mask.SetInvalid(segment_count); - return UnsafeNumericCast(segment_count); - } - - for (idx_t entry_idx = 0; entry_idx < bitmask_count; entry_idx++) { - // get an entry with free bits - if (data[entry_idx] == 0) { - continue; - } - - // find the position of the free bit - auto entry = data[entry_idx]; - idx_t first_valid_bit = 0; - - // this loop finds the position of the rightmost set bit in entry and stores it - // in first_valid_bit - for (idx_t i = 0; i < 6; i++) { - // set the left half of the bits of this level to zero and test if the entry is still not zero - if (entry & BASE[i]) { - // first valid bit is in the rightmost s[i] bits - // permanently set the left half of the bits to zero - entry &= BASE[i]; - } else { - // first valid bit is in the leftmost s[i] bits - // shift by s[i] for the next iteration and add s[i] to the position of the rightmost set bit - entry >>= SHIFT[i]; - first_valid_bit += SHIFT[i]; - } - } - D_ASSERT(entry); - - auto prev_bits = entry_idx * sizeof(validity_t) * 8; - D_ASSERT(mask.RowIsValid(prev_bits + first_valid_bit)); - mask.SetInvalid(prev_bits + first_valid_bit); - return UnsafeNumericCast(prev_bits + first_valid_bit); - } - - throw InternalException("Invalid bitmask for FixedSizeAllocator"); -} - -void FixedSizeBuffer::SetAllocationSize(const idx_t available_segments, const idx_t segment_size, - const idx_t bitmask_offset) { - if (!dirty) { - return; - } - - // We traverse from the back. A binary search would be faster. - // However, buffers are often (almost) full, so the overhead is acceptable. 
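GetOffset above locates the rightmost set bit of a 64-bit bitmask word (a set bit marks a free slot) with a six-step binary search over the BASE/SHIFT masks. A standalone sketch of the same search; on C++20, std::countr_zero(entry) computes the identical value in one call:

#include <cassert>
#include <cstdint>

// 0-based position of the rightmost set bit: the quantity the BASE/SHIFT
// binary search above computes. Requires entry != 0.
int RightmostSetBit(uint64_t entry) {
	assert(entry != 0);
	int pos = 0;
	for (int shift = 32; shift > 0; shift >>= 1) {
		uint64_t low_mask = (1ULL << shift) - 1;
		if ((entry & low_mask) == 0) {
			// the low half is empty, so the bit lives in the high half:
			// shift it down and remember how far we moved
			entry >>= shift;
			pos += shift;
		}
	}
	return pos; // e.g. RightmostSetBit(0b1000) == 3
}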
- auto bitmask_ptr = reinterpret_cast(Get()); - ValidityMask mask(bitmask_ptr, available_segments); - - auto max_offset = available_segments; - for (idx_t i = available_segments; i > 0; i--) { - if (!mask.RowIsValid(i - 1)) { - max_offset = i; - break; - } - } - allocation_size = max_offset * segment_size + bitmask_offset; -} - -void FixedSizeBuffer::SetUninitializedRegions(PartialBlockForIndex &p_block_for_index, const idx_t segment_size, - const idx_t offset, const idx_t bitmask_offset, - const idx_t available_segments) { - - // this function calls Get() on the buffer - D_ASSERT(InMemory()); - - auto bitmask_ptr = reinterpret_cast(Get()); - ValidityMask mask(bitmask_ptr, available_segments); - - idx_t i = 0; - idx_t max_offset = offset + allocation_size; - idx_t current_offset = offset + bitmask_offset; - while (current_offset < max_offset) { - - if (mask.RowIsValid(i)) { - D_ASSERT(current_offset + segment_size <= max_offset); - p_block_for_index.AddUninitializedRegion(current_offset, current_offset + segment_size); - } - current_offset += segment_size; - i++; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/index_type_set.cpp b/src/duckdb/src/execution/index/index_type_set.cpp deleted file mode 100644 index 4fe7cda4f..000000000 --- a/src/duckdb/src/execution/index/index_type_set.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "duckdb/execution/index/index_type.hpp" -#include "duckdb/execution/index/index_type_set.hpp" -#include "duckdb/execution/index/art/art.hpp" - -namespace duckdb { - -IndexTypeSet::IndexTypeSet() { - - // Register the ART index type by default - IndexType art_index_type; - art_index_type.name = ART::TYPE_NAME; - art_index_type.create_instance = ART::Create; - art_index_type.create_plan = ART::CreatePlan; - - RegisterIndexType(art_index_type); -} - -optional_ptr IndexTypeSet::FindByName(const string &name) { - lock_guard g(lock); - auto entry = functions.find(name); - if (entry == functions.end()) { - return nullptr; - } - return &entry->second; -} - -void IndexTypeSet::RegisterIndexType(const IndexType &index_type) { - lock_guard g(lock); - if (functions.find(index_type.name) != functions.end()) { - throw CatalogException("Index type with name \"%s\" already exists!", index_type.name.c_str()); - } - functions[index_type.name] = index_type; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/index/unbound_index.cpp b/src/duckdb/src/execution/index/unbound_index.cpp deleted file mode 100644 index b8173d751..000000000 --- a/src/duckdb/src/execution/index/unbound_index.cpp +++ /dev/null @@ -1,30 +0,0 @@ -#include "duckdb/execution/index/unbound_index.hpp" -#include "duckdb/parser/parsed_data/create_index_info.hpp" -#include "duckdb/storage/table_io_manager.hpp" -#include "duckdb/storage/block_manager.hpp" -#include "duckdb/storage/index_storage_info.hpp" - -namespace duckdb { - -//------------------------------------------------------------------------------- -// Unbound index -//------------------------------------------------------------------------------- - -UnboundIndex::UnboundIndex(unique_ptr create_info, IndexStorageInfo storage_info_p, - TableIOManager &table_io_manager, AttachedDatabase &db) - : Index(create_info->Cast().column_ids, table_io_manager, db), create_info(std::move(create_info)), - storage_info(std::move(storage_info_p)) { -} - -void UnboundIndex::CommitDrop() { - auto &block_manager = table_io_manager.GetIndexBlockManager(); - for (auto &info : storage_info.allocator_infos) { - for (auto &block : 
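RegisterIndexType above is the hook for plugging in custom index types; by analogy with the ART registration in the IndexTypeSet constructor, registering another type could look like this sketch (MyIndex and its callbacks are hypothetical, only the IndexType fields come from the code above):

// Hypothetical registration, mirroring the ART setup in IndexTypeSet().
IndexType my_index_type;
my_index_type.name = "MY_INDEX"; // must be unique, otherwise CatalogException
my_index_type.create_instance = MyIndex::Create; // hypothetical factory callback
my_index_type.create_plan = MyIndex::CreatePlan; // hypothetical plan callback
index_type_set.RegisterIndexType(my_index_type);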
info.block_pointers) { - if (block.IsValid()) { - block_manager.MarkBlockAsModified(block.block_id); - } - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/join_hashtable.cpp b/src/duckdb/src/execution/join_hashtable.cpp deleted file mode 100644 index 14abdc61e..000000000 --- a/src/duckdb/src/execution/join_hashtable.cpp +++ /dev/null @@ -1,1679 +0,0 @@ -#include "duckdb/execution/join_hashtable.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/ht_entry.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { -using ValidityBytes = JoinHashTable::ValidityBytes; -using ScanStructure = JoinHashTable::ScanStructure; -using ProbeSpill = JoinHashTable::ProbeSpill; -using ProbeSpillLocalState = JoinHashTable::ProbeSpillLocalAppendState; - -JoinHashTable::SharedState::SharedState() - : rhs_row_locations(LogicalType::POINTER), salt_v(LogicalType::UBIGINT), salt_match_sel(STANDARD_VECTOR_SIZE), - key_no_match_sel(STANDARD_VECTOR_SIZE) { -} - -JoinHashTable::ProbeState::ProbeState() - : SharedState(), ht_offsets_v(LogicalType::UBIGINT), ht_offsets_dense_v(LogicalType::UBIGINT), - non_empty_sel(STANDARD_VECTOR_SIZE) { -} - -JoinHashTable::InsertState::InsertState(const JoinHashTable &ht) - : SharedState(), remaining_sel(STANDARD_VECTOR_SIZE), key_match_sel(STANDARD_VECTOR_SIZE) { - ht.data_collection->InitializeChunk(lhs_data, ht.equality_predicate_columns); - ht.data_collection->InitializeChunkState(chunk_state, ht.equality_predicate_columns); -} - -JoinHashTable::JoinHashTable(ClientContext &context, const vector &conditions_p, - vector btypes, JoinType type_p, const vector &output_columns_p) - : buffer_manager(BufferManager::GetBufferManager(context)), conditions(conditions_p), - build_types(std::move(btypes)), output_columns(output_columns_p), entry_size(0), tuple_size(0), - vfound(Value::BOOLEAN(false)), join_type(type_p), finalized(false), has_null(false), - radix_bits(INITIAL_RADIX_BITS) { - for (idx_t i = 0; i < conditions.size(); ++i) { - auto &condition = conditions[i]; - D_ASSERT(condition.left->return_type == condition.right->return_type); - auto type = condition.left->return_type; - if (condition.comparison == ExpressionType::COMPARE_EQUAL || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - - // ensure that all equality conditions are at the front, - // and that all other conditions are at the back - D_ASSERT(equality_types.size() == condition_types.size()); - equality_types.push_back(type); - equality_predicates.push_back(condition.comparison); - equality_predicate_columns.push_back(i); - - } else { - // all non-equality conditions are at the back - non_equality_predicates.push_back(condition.comparison); - non_equality_predicate_columns.push_back(i); - } - - null_values_are_equal.push_back(condition.comparison == ExpressionType::COMPARE_DISTINCT_FROM || - condition.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM); - - condition_types.push_back(type); - } - // at least one equality is necessary - D_ASSERT(!equality_types.empty()); - - // Types for the layout - vector layout_types(condition_types); - layout_types.insert(layout_types.end(), build_types.begin(), build_types.end()); - if (PropagatesBuildSide(join_type)) { - // full/right outer joins need an extra bool to keep track of whether or not a tuple has found a matching entry - // we 
place the bool before the NEXT pointer - layout_types.emplace_back(LogicalType::BOOLEAN); - } - layout_types.emplace_back(LogicalType::HASH); - layout.Initialize(layout_types, false); - - // Initialize the row matcher that are used for filtering during the probing only if there are non-equality - if (!non_equality_predicates.empty()) { - - row_matcher_probe = unique_ptr(new RowMatcher()); - row_matcher_probe_no_match_sel = unique_ptr(new RowMatcher()); - - row_matcher_probe->Initialize(false, layout, non_equality_predicates, non_equality_predicate_columns); - row_matcher_probe_no_match_sel->Initialize(true, layout, non_equality_predicates, - non_equality_predicate_columns); - - needs_chain_matcher = true; - } else { - needs_chain_matcher = false; - } - - chains_longer_than_one = false; - row_matcher_build.Initialize(true, layout, equality_predicates); - - const auto &offsets = layout.GetOffsets(); - tuple_size = offsets[condition_types.size() + build_types.size()]; - pointer_offset = offsets.back(); - entry_size = layout.GetRowWidth(); - - data_collection = make_uniq(buffer_manager, layout); - sink_collection = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); - - dead_end = make_unsafe_uniq_array_uninitialized(layout.GetRowWidth()); - memset(dead_end.get(), 0, layout.GetRowWidth()); - - if (join_type == JoinType::SINGLE) { - auto &config = ClientConfig::GetConfig(context); - single_join_error_on_multiple_rows = config.scalar_subquery_error_on_multiple_rows; - } - - InitializePartitionMasks(); -} - -JoinHashTable::~JoinHashTable() { -} - -void JoinHashTable::Merge(JoinHashTable &other) { - { - lock_guard guard(data_lock); - data_collection->Combine(*other.data_collection); - } - - if (join_type == JoinType::MARK) { - auto &info = correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - has_null = has_null || other.has_null; - if (!info.correlated_types.empty()) { - auto &other_info = other.correlated_mark_join_info; - info.correlated_counts->Combine(*other_info.correlated_counts); - } - } - - sink_collection->Combine(*other.sink_collection); -} - -static void ApplyBitmaskAndGetSaltBuild(Vector &hashes_v, Vector &salt_v, const idx_t &count, const idx_t &bitmask) { - if (hashes_v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - auto &hash = *ConstantVector::GetData(hashes_v); - salt_v.SetVectorType(VectorType::CONSTANT_VECTOR); - - *ConstantVector::GetData(salt_v) = ht_entry_t::ExtractSalt(hash); - salt_v.Flatten(count); - - hash = hash & bitmask; - hashes_v.Flatten(count); - } else { - hashes_v.Flatten(count); - auto salts = FlatVector::GetData(salt_v); - auto hashes = FlatVector::GetData(hashes_v); - for (idx_t i = 0; i < count; i++) { - salts[i] = ht_entry_t::ExtractSalt(hashes[i]); - hashes[i] &= bitmask; - } - } -} - -//! Gets a pointer to the entry in the HT for each of the hashes_v using linear probing. Will update the key_match_sel -//! 
vector and the count argument to the number and position of the matches -template -static inline void GetRowPointersInternal(DataChunk &keys, TupleDataChunkState &key_state, - JoinHashTable::ProbeState &state, Vector &hashes_v, - const SelectionVector &sel, idx_t &count, JoinHashTable &ht, - ht_entry_t *entries, Vector &pointers_result_v, SelectionVector &match_sel) { - UnifiedVectorFormat hashes_v_unified; - hashes_v.ToUnifiedFormat(count, hashes_v_unified); - - auto hashes = UnifiedVectorFormat::GetData(hashes_v_unified); - auto salts = FlatVector::GetData(state.salt_v); - - auto ht_offsets = FlatVector::GetData(state.ht_offsets_v); - auto ht_offsets_dense = FlatVector::GetData(state.ht_offsets_dense_v); - - idx_t non_empty_count = 0; - - // first, filter out the empty rows and calculate the offset - for (idx_t i = 0; i < count; i++) { - const auto row_index = sel.get_index(i); - auto uvf_index = hashes_v_unified.sel->get_index(row_index); - auto ht_offset = hashes[uvf_index] & ht.bitmask; - ht_offsets_dense[i] = ht_offset; - ht_offsets[row_index] = ht_offset; - } - - // have a dense loop to have as few instructions as possible while producing cache misses as this is the - // first location where we access the big entries array - for (idx_t i = 0; i < count; i++) { - idx_t ht_offset = ht_offsets_dense[i]; - auto &entry = entries[ht_offset]; - bool occupied = entry.IsOccupied(); - state.non_empty_sel.set_index(non_empty_count, i); - non_empty_count += occupied; - } - - for (idx_t i = 0; i < non_empty_count; i++) { - // transform the dense index to the actual index in the sel vector - idx_t dense_index = state.non_empty_sel.get_index(i); - const auto row_index = sel.get_index(dense_index); - state.non_empty_sel.set_index(i, row_index); - - if (USE_SALTS) { - auto uvf_index = hashes_v_unified.sel->get_index(row_index); - auto hash = hashes[uvf_index]; - hash_t row_salt = ht_entry_t::ExtractSalt(hash); - salts[row_index] = row_salt; - } - } - - auto pointers_result = FlatVector::GetData(pointers_result_v); - auto row_ptr_insert_to = FlatVector::GetData(state.rhs_row_locations); - - const SelectionVector *remaining_sel = &state.non_empty_sel; - idx_t remaining_count = non_empty_count; - - idx_t &match_count = count; - match_count = 0; - - while (remaining_count > 0) { - idx_t salt_match_count = 0; - idx_t key_no_match_count = 0; - - // for each entry, linear probing until - // a) an empty entry is found -> return nullptr (do nothing, as vector is zeroed) - // b) an entry is found where the salt matches -> need to compare the keys - for (idx_t i = 0; i < remaining_count; i++) { - const auto row_index = remaining_sel->get_index(i); - - idx_t &ht_offset = ht_offsets[row_index]; - bool occupied; - ht_entry_t entry; - - if (USE_SALTS) { - hash_t row_salt = salts[row_index]; - // increment the ht_offset of the entry as long as next entry is occupied and salt does not match - while (true) { - entry = entries[ht_offset]; - occupied = entry.IsOccupied(); - bool salt_match = entry.GetSalt() == row_salt; - - // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next - // entry - if (!occupied || salt_match) { - break; - } - - IncrementAndWrap(ht_offset, ht.bitmask); - } - } else { - entry = entries[ht_offset]; - occupied = entry.IsOccupied(); - } - - // the entries we need to process in the next iteration are the ones that are occupied and the row_salt - // does not match, the ones that are empty need no further processing - 
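The dense loop above relies on a branch-free compaction idiom: always store the candidate index, then advance the output position only when the predicate holds, so branch mispredictions never stall the cache-missing loads. A minimal standalone sketch of the idiom (hypothetical names):

#include <cstddef>
#include <cstdint>

// Sketch: densely writes every index i with keep[i] == true into out and
// returns how many were kept; the store itself happens on every iteration.
size_t CompactIndices(const bool *keep, uint32_t *out, size_t count) {
	size_t kept = 0;
	for (size_t i = 0; i < count; i++) {
		out[kept] = static_cast<uint32_t>(i); // speculative store
		kept += keep[i] ? 1 : 0;              // branch-free advance
	}
	return kept;
}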
state.salt_match_sel.set_index(salt_match_count, row_index); - salt_match_count += occupied; - - // entry might be empty, so the pointer in the entry is nullptr, but this does not matter as the row - // will not be compared anyway as with an empty entry we are already done - row_ptr_insert_to[row_index] = entry.GetPointerOrNull(); - } - - if (salt_match_count != 0) { - // Perform row comparisons, after function call salt_match_sel will point to the keys that match - idx_t key_match_count = ht.row_matcher_build.Match(keys, key_state.vector_data, state.salt_match_sel, - salt_match_count, ht.layout, state.rhs_row_locations, - &state.key_no_match_sel, key_no_match_count); - - D_ASSERT(key_match_count + key_no_match_count == salt_match_count); - - // Set a pointer to the matching row - for (idx_t i = 0; i < key_match_count; i++) { - const auto row_index = state.salt_match_sel.get_index(i); - pointers_result[row_index] = row_ptr_insert_to[row_index]; - - match_sel.set_index(match_count, row_index); - match_count++; - } - - // Linear probing: each of the entries that do not match move to the next entry in the HT - for (idx_t i = 0; i < key_no_match_count; i++) { - const auto row_index = state.key_no_match_sel.get_index(i); - auto &ht_offset = ht_offsets[row_index]; - - IncrementAndWrap(ht_offset, ht.bitmask); - } - } - - remaining_sel = &state.key_no_match_sel; - remaining_count = key_no_match_count; - } -} - -inline bool JoinHashTable::UseSalt() const { - // only use salt for large hash tables - return this->capacity > USE_SALT_THRESHOLD; -} - -void JoinHashTable::GetRowPointers(DataChunk &keys, TupleDataChunkState &key_state, ProbeState &state, Vector &hashes_v, - const SelectionVector &sel, idx_t &count, Vector &pointers_result_v, - SelectionVector &match_sel) { - if (UseSalt()) { - GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, *this, entries, pointers_result_v, - match_sel); - } else { - GetRowPointersInternal(keys, key_state, state, hashes_v, sel, count, *this, entries, pointers_result_v, - match_sel); - } -} - -void JoinHashTable::Hash(DataChunk &keys, const SelectionVector &sel, idx_t count, Vector &hashes) { - if (count == keys.size()) { - // no null values are filtered: use regular hash functions - VectorOperations::Hash(keys.data[0], hashes, keys.size()); - for (idx_t i = 1; i < equality_types.size(); i++) { - VectorOperations::CombineHash(hashes, keys.data[i], keys.size()); - } - } else { - // null values were filtered: use selection vector - VectorOperations::Hash(keys.data[0], hashes, sel, count); - for (idx_t i = 1; i < equality_types.size(); i++) { - VectorOperations::CombineHash(hashes, keys.data[i], sel, count); - } - } -} - -static idx_t FilterNullValues(UnifiedVectorFormat &vdata, const SelectionVector &sel, idx_t count, - SelectionVector &result) { - idx_t result_count = 0; - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto key_idx = vdata.sel->get_index(idx); - if (vdata.validity.RowIsValid(key_idx)) { - result.set_index(result_count++, idx); - } - } - return result_count; -} - -void JoinHashTable::Build(PartitionedTupleDataAppendState &append_state, DataChunk &keys, DataChunk &payload) { - D_ASSERT(!finalized); - D_ASSERT(keys.size() == payload.size()); - if (keys.size() == 0) { - return; - } - // special case: correlated mark join - if (join_type == JoinType::MARK && !correlated_mark_join_info.correlated_types.empty()) { - auto &info = correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - // Correlated MARK join - 
// for the correlated mark join we need to keep track of COUNT(*) and COUNT(COLUMN) for each of the correlated - // columns push into the aggregate hash table - D_ASSERT(info.correlated_counts); - info.group_chunk.SetCardinality(keys); - for (idx_t i = 0; i < info.correlated_types.size(); i++) { - info.group_chunk.data[i].Reference(keys.data[i]); - } - if (info.correlated_payload.data.empty()) { - vector types; - types.push_back(keys.data[info.correlated_types.size()].GetType()); - info.correlated_payload.InitializeEmpty(types); - } - info.correlated_payload.SetCardinality(keys); - info.correlated_payload.data[0].Reference(keys.data[info.correlated_types.size()]); - info.correlated_counts->AddChunk(info.group_chunk, info.correlated_payload, AggregateType::NON_DISTINCT); - } - - // build a chunk to append to the data collection [keys, payload, (optional "found" boolean), hash] - DataChunk source_chunk; - source_chunk.InitializeEmpty(layout.GetTypes()); - for (idx_t i = 0; i < keys.ColumnCount(); i++) { - source_chunk.data[i].Reference(keys.data[i]); - } - idx_t col_offset = keys.ColumnCount(); - D_ASSERT(build_types.size() == payload.ColumnCount()); - for (idx_t i = 0; i < payload.ColumnCount(); i++) { - source_chunk.data[col_offset + i].Reference(payload.data[i]); - } - col_offset += payload.ColumnCount(); - if (PropagatesBuildSide(join_type)) { - // for FULL/RIGHT OUTER joins initialize the "found" boolean to false - source_chunk.data[col_offset].Reference(vfound); - col_offset++; - } - Vector hash_values(LogicalType::HASH); - source_chunk.data[col_offset].Reference(hash_values); - source_chunk.SetCardinality(keys); - - // ToUnifiedFormat the source chunk - TupleDataCollection::ToUnifiedFormat(append_state.chunk_state, source_chunk); - - // prepare the keys for processing - const SelectionVector *current_sel; - SelectionVector sel(STANDARD_VECTOR_SIZE); - idx_t added_count = PrepareKeys(keys, append_state.chunk_state.vector_data, current_sel, sel, true); - if (added_count < keys.size()) { - has_null = true; - } - if (added_count == 0) { - return; - } - - // hash the keys and obtain an entry in the list - // note that we only hash the keys used in the equality comparison - Hash(keys, *current_sel, added_count, hash_values); - - // Re-reference and ToUnifiedFormat the hash column after computing it - source_chunk.data[col_offset].Reference(hash_values); - hash_values.ToUnifiedFormat(source_chunk.size(), append_state.chunk_state.vector_data.back().unified); - - // We already called TupleDataCollection::ToUnifiedFormat, so we can AppendUnified here - sink_collection->AppendUnified(append_state, source_chunk, *current_sel, added_count); -} - -idx_t JoinHashTable::PrepareKeys(DataChunk &keys, vector &vector_data, - const SelectionVector *¤t_sel, SelectionVector &sel, bool build_side) { - // figure out which keys are NULL, and create a selection vector out of them - current_sel = FlatVector::IncrementalSelectionVector(); - idx_t added_count = keys.size(); - if (build_side && PropagatesBuildSide(join_type)) { - // in case of a right or full outer join, we cannot remove NULL keys from the build side - return added_count; - } - - for (idx_t col_idx = 0; col_idx < keys.ColumnCount(); col_idx++) { - // see internal issue 3717. 
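The column loop below only filters NULL keys when the comparison does not treat NULLs as equal; that rule is fixed in the constructor, where null_values_are_equal is populated per condition. Restated as a sketch (same predicate, factored into a helper that does not exist in the original):

// Sketch: NULL join keys survive the build only for (NOT) DISTINCT FROM,
// mirroring the null_values_are_equal initialization in the constructor.
static bool NullValuesAreEqualFor(ExpressionType comparison) {
	return comparison == ExpressionType::COMPARE_DISTINCT_FROM ||
	       comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM;
}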
- if (join_type == JoinType::MARK && !correlated_mark_join_info.correlated_types.empty()) {
- continue;
- }
- if (null_values_are_equal[col_idx]) {
- continue;
- }
- auto &col_key_data = vector_data[col_idx].unified;
- if (col_key_data.validity.AllValid()) {
- continue;
- }
- added_count = FilterNullValues(col_key_data, *current_sel, added_count, sel);
- // null values are NOT equal for this column, filter them out
- current_sel = &sel;
- }
- return added_count;
-}
-
-static void StorePointer(const const_data_ptr_t &pointer, const data_ptr_t &target) {
- Store<uint64_t>(cast_pointer_to_uint64(pointer), target);
-}
-
-static data_ptr_t LoadPointer(const const_data_ptr_t &source) {
- return cast_uint64_to_pointer(Load<uint64_t>(source));
-}
-
-//! If we try to insert into an entry that we expect to be empty, but it was filled in the meantime, the insert will
-//! not happen, and we need to return the pointer to the row with which the new entry would have collided. In any
-//! other case we return a nullptr
-template <bool PARALLEL, bool EXPECT_EMPTY>
-static inline data_ptr_t InsertRowToEntry(atomic<ht_entry_t> &entry, const data_ptr_t &row_ptr_to_insert,
- const hash_t &salt, const idx_t &pointer_offset) {
- const ht_entry_t desired_entry(salt, row_ptr_to_insert);
- if (PARALLEL) {
- if (EXPECT_EMPTY) {
- // Add nullptr to the end of the list to mark the end
- StorePointer(nullptr, row_ptr_to_insert + pointer_offset);
-
- ht_entry_t expected_entry;
- entry.compare_exchange_strong(expected_entry, desired_entry, std::memory_order_acquire,
- std::memory_order_relaxed);
-
- // The expected entry is updated with the encountered entry by the compare exchange
- // So, this returns a nullptr if it was empty, and a non-null if it was not (which cancels the insert)
- return expected_entry.GetPointerOrNull();
- } else {
- // At this point we know that the keys match, so we can try to insert until we succeed
- ht_entry_t expected_entry = entry.load(std::memory_order_relaxed);
- D_ASSERT(expected_entry.IsOccupied());
- do {
- data_ptr_t current_row_pointer = expected_entry.GetPointer();
- StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset);
- } while (!entry.compare_exchange_weak(expected_entry, desired_entry, std::memory_order_release,
- std::memory_order_relaxed));
-
- return nullptr;
- }
- } else {
- // If we are not in parallel mode, we can just do the operation without any checks
- data_ptr_t current_row_pointer = entry.load(std::memory_order_relaxed).GetPointerOrNull();
- StorePointer(current_row_pointer, row_ptr_to_insert + pointer_offset);
- entry = desired_entry;
- return nullptr;
- }
-}
-static inline void PerformKeyComparison(JoinHashTable::InsertState &state, JoinHashTable &ht,
- const TupleDataCollection &data_collection, Vector &row_locations,
- const idx_t count, idx_t &key_match_count, idx_t &key_no_match_count) {
- // Get the data for the rows that need to be compared
- state.lhs_data.Reset();
- state.lhs_data.SetCardinality(count); // the right size
-
- // The target selection vector says where to write the results into the lhs_data, we just want to write
- // sequentially as otherwise we trigger a bug in the Gather function
- data_collection.ResetCachedCastVectors(state.chunk_state, ht.equality_predicate_columns);
- data_collection.Gather(row_locations, state.salt_match_sel, count, ht.equality_predicate_columns, state.lhs_data,
- *FlatVector::IncrementalSelectionVector(), state.chunk_state.cached_cast_vectors);
- TupleDataCollection::ToUnifiedFormat(state.chunk_state, state.lhs_data);
-
- for (idx_t i = 0; i < count; i++) {
-
state.key_match_sel.set_index(i, i); - } - - // Perform row comparisons - key_match_count = - ht.row_matcher_build.Match(state.lhs_data, state.chunk_state.vector_data, state.key_match_sel, count, ht.layout, - state.rhs_row_locations, &state.key_no_match_sel, key_no_match_count); - - D_ASSERT(key_match_count + key_no_match_count == count); -} - -template -static inline void InsertMatchesAndIncrementMisses(atomic entries[], JoinHashTable::InsertState &state, - JoinHashTable &ht, const data_ptr_t lhs_row_locations[], - idx_t ht_offsets[], const hash_t hash_salts[], - const idx_t capacity_mask, const idx_t key_match_count, - const idx_t key_no_match_count) { - if (key_match_count != 0) { - ht.chains_longer_than_one = true; - } - - // Insert the rows that match - for (idx_t i = 0; i < key_match_count; i++) { - const auto need_compare_idx = state.key_match_sel.get_index(i); - const auto entry_index = state.salt_match_sel.get_index(need_compare_idx); - - const auto &ht_offset = ht_offsets[entry_index]; - auto &entry = entries[ht_offset]; - const auto row_ptr_to_insert = lhs_row_locations[entry_index]; - - const auto salt = hash_salts[entry_index]; - InsertRowToEntry(entry, row_ptr_to_insert, salt, ht.pointer_offset); - } - - // Linear probing: each of the entries that do not match move to the next entry in the HT - for (idx_t i = 0; i < key_no_match_count; i++) { - const auto need_compare_idx = state.key_no_match_sel.get_index(i); - const auto entry_index = state.salt_match_sel.get_index(need_compare_idx); - - auto &ht_offset = ht_offsets[entry_index]; - IncrementAndWrap(ht_offset, capacity_mask); - - state.remaining_sel.set_index(i, entry_index); - } -} - -template -static void InsertHashesLoop(atomic entries[], Vector &row_locations, Vector &hashes_v, const idx_t &count, - JoinHashTable::InsertState &state, const TupleDataCollection &data_collection, - JoinHashTable &ht) { - D_ASSERT(hashes_v.GetType().id() == LogicalType::HASH); - ApplyBitmaskAndGetSaltBuild(hashes_v, state.salt_v, count, ht.bitmask); - - // the salts offset for each row to insert - const auto ht_offsets = FlatVector::GetData(hashes_v); - const auto hash_salts = FlatVector::GetData(state.salt_v); - // the row locations of the rows that are already in the hash table - const auto rhs_row_locations = FlatVector::GetData(state.rhs_row_locations); - // the row locations of the rows that are to be inserted - const auto lhs_row_locations = FlatVector::GetData(row_locations); - - // we start off with the entire chunk - idx_t remaining_count = count; - const auto *remaining_sel = FlatVector::IncrementalSelectionVector(); - - if (PropagatesBuildSide(ht.join_type)) { - // if we propagate the build side, we may have added rows with NULL keys to the HT - // these may need to be filtered out depending on the comparison type (exactly like PrepareKeys does) - for (idx_t col_idx = 0; col_idx < ht.conditions.size(); col_idx++) { - // if null values are NOT equal for this column we filter them out - if (ht.NullValuesAreEqual(col_idx)) { - continue; - } - - idx_t entry_idx; - idx_t idx_in_entry; - ValidityBytes::GetEntryIndex(col_idx, entry_idx, idx_in_entry); - - idx_t new_remaining_count = 0; - for (idx_t i = 0; i < remaining_count; i++) { - const auto idx = remaining_sel->get_index(i); - if (ValidityBytes(lhs_row_locations[idx], count).RowIsValidUnsafe(col_idx)) { - state.remaining_sel.set_index(new_remaining_count++, idx); - } - } - remaining_count = new_remaining_count; - remaining_sel = &state.remaining_sel; - } - } - - // use the ht 
bitmask to make the modulo operation faster but keep the salt bits intact
- idx_t capacity_mask = ht.bitmask | ht_entry_t::SALT_MASK;
- while (remaining_count > 0) {
- idx_t salt_match_count = 0;
-
- // iterate over each entry to find out whether it belongs to an existing list or will start a new list
- for (idx_t i = 0; i < remaining_count; i++) {
- const idx_t row_index = remaining_sel->get_index(i);
- auto &ht_offset = ht_offsets[row_index];
- auto &salt = hash_salts[row_index];
-
- // increment the ht_offset of the entry as long as next entry is occupied and salt does not match
- ht_entry_t entry;
- bool occupied;
- while (true) {
- atomic<ht_entry_t> &atomic_entry = entries[ht_offset];
- entry = atomic_entry.load(std::memory_order_relaxed);
- occupied = entry.IsOccupied();
-
- // condition for incrementing the ht_offset: occupied and row_salt does not match -> move to next entry
- if (!occupied) {
- break;
- }
- if (entry.GetSalt() == salt) {
- break;
- }
-
- IncrementAndWrap(ht_offset, capacity_mask);
- }
-
- if (!occupied) { // insert into free
- auto &atomic_entry = entries[ht_offset];
- const auto row_ptr_to_insert = lhs_row_locations[row_index];
- const auto potential_collided_ptr =
- InsertRowToEntry<PARALLEL, true>(atomic_entry, row_ptr_to_insert, salt, ht.pointer_offset);
-
- if (PARALLEL) {
- // if the insertion was not successful, the entry was occupied in the meantime, so we have to
- // compare the keys and insert the row to the next entry
- if (potential_collided_ptr) {
- state.salt_match_sel.set_index(salt_match_count, row_index);
- rhs_row_locations[salt_match_count] = potential_collided_ptr;
- salt_match_count += 1;
- }
- }
-
- } else { // compare with full entry
- state.salt_match_sel.set_index(salt_match_count, row_index);
- rhs_row_locations[salt_match_count] = entry.GetPointer();
- salt_match_count += 1;
- }
- }
-
- // at this step, for all the rows to insert we stepped either until we found an empty entry or an entry with
- // a matching salt, we now need to compare the keys for the ones that have a matching salt
- idx_t key_no_match_count = 0;
- if (salt_match_count != 0) {
- idx_t key_match_count = 0;
- PerformKeyComparison(state, ht, data_collection, row_locations, salt_match_count, key_match_count,
- key_no_match_count);
- InsertMatchesAndIncrementMisses<PARALLEL>(entries, state, ht, lhs_row_locations, ht_offsets, hash_salts,
- capacity_mask, key_match_count, key_no_match_count);
- }
-
- // update the overall selection vector to only point to the entries that still need to be inserted,
- // as there was no match found for them yet
- remaining_sel = &state.remaining_sel;
- remaining_count = key_no_match_count;
- }
-}
-
-void JoinHashTable::InsertHashes(Vector &hashes_v, const idx_t count, TupleDataChunkState &chunk_state,
- InsertState &insert_state, bool parallel) {
- auto atomic_entries = reinterpret_cast<atomic<ht_entry_t> *>(this->entries);
- auto row_locations = chunk_state.row_locations;
- if (parallel) {
- InsertHashesLoop<true>(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this);
- } else {
- InsertHashesLoop<false>(atomic_entries, row_locations, hashes_v, count, insert_state, *data_collection, *this);
- }
-}
-
-void JoinHashTable::InitializePointerTable() {
- capacity = PointerTableCapacity(Count());
- D_ASSERT(IsPowerOfTwo(capacity));
-
- if (hash_map.get()) {
- // There is already a hash map
- auto current_capacity =
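A note on InsertRowToEntry above: its parallel non-empty branch is the classic lock-free list push. Isolated from the hash-table specifics, the pattern looks like this sketch with std::atomic:

#include <atomic>

struct Node {
	Node *next;
};

// Sketch: prepend node to the chain by linking it to the current head and
// swinging the head over with a CAS; expected is refreshed on each failure.
void PushFront(std::atomic<Node *> &head, Node *node) {
	Node *expected = head.load(std::memory_order_relaxed);
	do {
		node->next = expected;
	} while (!head.compare_exchange_weak(expected, node, std::memory_order_release,
	                                     std::memory_order_relaxed));
}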
hash_map.GetSize() / sizeof(ht_entry_t); - if (capacity > current_capacity) { - // Need more space - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); - } else { - // Just use the current hash map - capacity = current_capacity; - } - } else { - // Allocate a hash map - hash_map = buffer_manager.GetBufferAllocator().Allocate(capacity * sizeof(ht_entry_t)); - entries = reinterpret_cast(hash_map.get()); - } - D_ASSERT(hash_map.GetSize() == capacity * sizeof(ht_entry_t)); - - // initialize HT with all-zero entries - std::fill_n(entries, capacity, ht_entry_t()); - - bitmask = capacity - 1; -} - -void JoinHashTable::Finalize(idx_t chunk_idx_from, idx_t chunk_idx_to, bool parallel) { - // Pointer table should be allocated - D_ASSERT(hash_map.get()); - - Vector hashes(LogicalType::HASH); - auto hash_data = FlatVector::GetData(hashes); - - TupleDataChunkIterator iterator(*data_collection, TupleDataPinProperties::KEEP_EVERYTHING_PINNED, chunk_idx_from, - chunk_idx_to, false); - const auto row_locations = iterator.GetRowLocations(); - - InsertState insert_state(*this); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = 0; i < count; i++) { - hash_data[i] = Load(row_locations[i] + pointer_offset); - } - TupleDataChunkState &chunk_state = iterator.GetChunkState(); - - InsertHashes(hashes, count, chunk_state, insert_state, parallel); - } while (iterator.Next()); -} - -void JoinHashTable::InitializeScanStructure(ScanStructure &scan_structure, DataChunk &keys, - TupleDataChunkState &key_state, const SelectionVector *¤t_sel) { - D_ASSERT(Count() > 0); // should be handled before - D_ASSERT(finalized); - - // set up the scan structure - scan_structure.is_null = false; - scan_structure.finished = false; - if (join_type != JoinType::INNER) { - memset(scan_structure.found_match.get(), 0, sizeof(bool) * STANDARD_VECTOR_SIZE); - } - - // first prepare the keys for probing - TupleDataCollection::ToUnifiedFormat(key_state, keys); - scan_structure.count = PrepareKeys(keys, key_state.vector_data, current_sel, scan_structure.sel_vector, false); -} - -void JoinHashTable::Probe(ScanStructure &scan_structure, DataChunk &keys, TupleDataChunkState &key_state, - ProbeState &probe_state, optional_ptr precomputed_hashes) { - const SelectionVector *current_sel; - InitializeScanStructure(scan_structure, keys, key_state, current_sel); - if (scan_structure.count == 0) { - return; - } - - if (precomputed_hashes) { - GetRowPointers(keys, key_state, probe_state, *precomputed_hashes, *current_sel, scan_structure.count, - scan_structure.pointers, scan_structure.sel_vector); - } else { - Vector hashes(LogicalType::HASH); - // hash all the keys - Hash(keys, *current_sel, scan_structure.count, hashes); - - // now initialize the pointers of the scan structure based on the hashes - GetRowPointers(keys, key_state, probe_state, hashes, *current_sel, scan_structure.count, - scan_structure.pointers, scan_structure.sel_vector); - } -} - -ScanStructure::ScanStructure(JoinHashTable &ht_p, TupleDataChunkState &key_state_p) - : key_state(key_state_p), pointers(LogicalType::POINTER), count(0), sel_vector(STANDARD_VECTOR_SIZE), - chain_match_sel_vector(STANDARD_VECTOR_SIZE), chain_no_match_sel_vector(STANDARD_VECTOR_SIZE), - found_match(make_unsafe_uniq_array_uninitialized(STANDARD_VECTOR_SIZE)), ht(ht_p), finished(false), - is_null(true), rhs_pointers(LogicalType::POINTER), lhs_sel_vector(STANDARD_VECTOR_SIZE), last_match_count(0), - 
last_sel_vector(STANDARD_VECTOR_SIZE) {
-}
-
-void ScanStructure::Next(DataChunk &keys, DataChunk &left, DataChunk &result) {
- D_ASSERT(keys.size() == left.size());
- if (finished) {
- return;
- }
- switch (ht.join_type) {
- case JoinType::INNER:
- case JoinType::RIGHT:
- NextInnerJoin(keys, left, result);
- break;
- case JoinType::SEMI:
- NextSemiJoin(keys, left, result);
- break;
- case JoinType::MARK:
- NextMarkJoin(keys, left, result);
- break;
- case JoinType::ANTI:
- NextAntiJoin(keys, left, result);
- break;
- case JoinType::RIGHT_ANTI:
- case JoinType::RIGHT_SEMI:
- NextRightSemiOrAntiJoin(keys);
- break;
- case JoinType::OUTER:
- case JoinType::LEFT:
- NextLeftJoin(keys, left, result);
- break;
- case JoinType::SINGLE:
- NextSingleJoin(keys, left, result);
- break;
- default:
- throw InternalException("Unhandled join type in JoinHashTable");
- }
-}
-
-bool ScanStructure::PointersExhausted() const {
- // AdvancePointers creates a "new_count" for every pointer advanced during the
- // previous advance pointers call. If no pointers are advanced, new_count = 0.
- // count is then set to new_count.
- return count == 0;
-}
-
-idx_t ScanStructure::ResolvePredicates(DataChunk &keys, SelectionVector &match_sel, SelectionVector *no_match_sel) {
-
- // Initialize match_sel to the current sel_vector
- for (idx_t i = 0; i < this->count; ++i) {
- match_sel.set_index(i, this->sel_vector.get_index(i));
- }
-
- // If there is a matcher for the probing side because of non-equality predicates, use it
- if (ht.needs_chain_matcher) {
- idx_t no_match_count = 0;
- auto &matcher = no_match_sel ? ht.row_matcher_probe_no_match_sel : ht.row_matcher_probe;
- D_ASSERT(matcher);
-
- // we need to only use the vectors with the indices of the columns that are used in the probe phase, namely
- // the non-equality columns
- return matcher->Match(keys, key_state.vector_data, match_sel, this->count, ht.layout, pointers, no_match_sel,
- no_match_count, ht.non_equality_predicate_columns);
- } else {
- // there is no chain matcher, so every remaining row is a match
- return this->count;
- }
-}
-
-idx_t ScanStructure::ScanInnerJoin(DataChunk &keys, SelectionVector &result_vector) {
- while (true) {
- // resolve the equality_predicates for this set of keys
- idx_t result_count = ResolvePredicates(keys, result_vector, nullptr);
-
- // after doing all the comparisons set the found_match vector
- if (found_match) {
- for (idx_t i = 0; i < result_count; i++) {
- auto idx = result_vector.get_index(i);
- found_match[idx] = true;
- }
- }
- if (result_count > 0) {
- return result_count;
- }
- // no matches found: check the next set of pointers
- AdvancePointers();
- if (this->count == 0) {
- return 0;
- }
- }
-}
-
-void ScanStructure::AdvancePointers(const SelectionVector &sel, const idx_t sel_count) {
-
- if (!ht.chains_longer_than_one) {
- this->count = 0;
- return;
- }
-
- // now for all the pointers, we move on to the next set of pointers
- idx_t new_count = 0;
- auto ptrs = FlatVector::GetData<data_ptr_t>(this->pointers);
- for (idx_t i = 0; i < sel_count; i++) {
- auto idx = sel.get_index(i);
- ptrs[idx] = LoadPointer(ptrs[idx] + ht.pointer_offset);
- if (ptrs[idx]) {
- this->sel_vector.set_index(new_count++, idx);
- }
- }
- this->count = new_count;
-}
-
-void ScanStructure::AdvancePointers() {
- AdvancePointers(this->sel_vector, this->count);
-}
-
-void ScanStructure::GatherResult(Vector &result, const SelectionVector &result_vector,
- const SelectionVector &sel_vector, const idx_t count, const idx_t col_no) {
-
ht.data_collection->Gather(pointers, sel_vector, count, col_no, result, result_vector, nullptr); -} - -void ScanStructure::GatherResult(Vector &result, const SelectionVector &sel_vector, const idx_t count, - const idx_t col_idx) { - GatherResult(result, *FlatVector::IncrementalSelectionVector(), sel_vector, count, col_idx); -} - -void ScanStructure::GatherResult(Vector &result, const idx_t count, const idx_t col_idx) { - ht.data_collection->Gather(rhs_pointers, *FlatVector::IncrementalSelectionVector(), count, col_idx, result, - *FlatVector::IncrementalSelectionVector(), nullptr); -} - -void ScanStructure::UpdateCompactionBuffer(idx_t base_count, SelectionVector &result_vector, idx_t result_count) { - // matches were found - // record the result - // on the LHS, we store result vector - for (idx_t i = 0; i < result_count; i++) { - lhs_sel_vector.set_index(base_count + i, result_vector.get_index(i)); - } - - // on the RHS, we collect their pointers - VectorOperations::Copy(pointers, rhs_pointers, result_vector, result_count, 0, base_count); -} - -void ScanStructure::NextInnerJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - if (ht.join_type != JoinType::RIGHT_SEMI && ht.join_type != JoinType::RIGHT_ANTI) { - D_ASSERT(result.ColumnCount() == left.ColumnCount() + ht.output_columns.size()); - } - - idx_t base_count = 0; - idx_t result_count; - while (this->count > 0) { - // if we have saved the match result, we need not call ScanInnerJoin again - if (last_match_count == 0) { - result_count = ScanInnerJoin(keys, chain_match_sel_vector); - } else { - chain_match_sel_vector.Initialize(last_sel_vector); - result_count = last_match_count; - last_match_count = 0; - } - - if (result_count > 0) { - // the result chunk cannot contain more data, we record the match result for future use - if (base_count + result_count > STANDARD_VECTOR_SIZE) { - last_sel_vector.Initialize(chain_match_sel_vector); - last_match_count = result_count; - break; - } - - if (PropagatesBuildSide(ht.join_type)) { - // full/right outer join: mark join matches as FOUND in the HT - auto ptrs = FlatVector::GetData(pointers); - for (idx_t i = 0; i < result_count; i++) { - auto idx = chain_match_sel_vector.get_index(i); - // NOTE: threadsan reports this as a data race because this can be set concurrently by separate - // threads Technically it is, but it does not matter, since the only value that can be written is - // "true" - Store(true, ptrs[idx] + ht.tuple_size); - } - } - - if (ht.join_type != JoinType::RIGHT_SEMI && ht.join_type != JoinType::RIGHT_ANTI) { - // Fast Path: if there is NO more than one element in the chain, we construct the result chunk directly - if (!ht.chains_longer_than_one) { - // matches were found - // on the LHS, we create a slice using the result vector - result.Slice(left, chain_match_sel_vector, result_count); - - // on the RHS, we need to fetch the data from the hash table - for (idx_t i = 0; i < ht.output_columns.size(); i++) { - auto &vector = result.data[left.ColumnCount() + i]; - const auto output_col_idx = ht.output_columns[i]; - D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]); - GatherResult(vector, chain_match_sel_vector, result_count, output_col_idx); - } - - AdvancePointers(); - return; - } - - // Common Path: use a buffer to store temporary data - UpdateCompactionBuffer(base_count, chain_match_sel_vector, result_count); - base_count += result_count; - } - } - AdvancePointers(); - } - - if (base_count > 0) { - // create result chunk, we have two steps: - // 1) 
slice LHS vectors
- result.Slice(left, lhs_sel_vector, base_count);
-
- // 2) gather RHS vectors
- for (idx_t i = 0; i < ht.output_columns.size(); i++) {
- auto &vector = result.data[left.ColumnCount() + i];
- const auto output_col_idx = ht.output_columns[i];
- D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]);
- GatherResult(vector, base_count, output_col_idx);
- }
- }
-}
-
-void ScanStructure::ScanKeyMatches(DataChunk &keys) {
- // we handle the semi-join, anti-join and mark-join differently from the inner join:
- // since there can be at most STANDARD_VECTOR_SIZE results,
- // we handle the entire chunk in one call to Next().
- // for every pointer, we keep chasing pointers and doing comparisons.
- // this results in a boolean array indicating whether or not the tuple has a match
- // Start with the scan selection
-
- while (this->count > 0) {
- // resolve the equality_predicates for the current set of pointers
- idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, &chain_no_match_sel_vector);
- idx_t no_match_count = this->count - match_count;
-
- // mark each of the matches as found
- for (idx_t i = 0; i < match_count; i++) {
- found_match[chain_match_sel_vector.get_index(i)] = true;
- }
- // continue searching for the ones where we did not find a match yet
- AdvancePointers(chain_no_match_sel_vector, no_match_count);
- }
-}
-
-template <bool MATCH>
-void ScanStructure::NextSemiOrAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) {
- D_ASSERT(left.ColumnCount() == result.ColumnCount());
- // create the selection vector from the matches that were found
- SelectionVector sel(STANDARD_VECTOR_SIZE);
- idx_t result_count = 0;
- for (idx_t i = 0; i < keys.size(); i++) {
- if (found_match[i] == MATCH) {
- // part of the result
- sel.set_index(result_count++, i);
- }
- }
- // construct the final result
- if (result_count > 0) {
- // we only return the columns on the left side
- // reference the columns of the left side from the result
- result.Slice(left, sel, result_count);
- } else {
- D_ASSERT(result.size() == 0);
- }
-}
-
-void ScanStructure::NextSemiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) {
- // first scan for key matches
- ScanKeyMatches(keys);
- // then construct the result from all tuples with a match
- NextSemiOrAntiJoin<true>(keys, left, result);
-
- finished = true;
-}
-
-void ScanStructure::NextAntiJoin(DataChunk &keys, DataChunk &left, DataChunk &result) {
- // first scan for key matches
- ScanKeyMatches(keys);
- // then construct the result from all tuples that did not find a match
- NextSemiOrAntiJoin<false>(keys, left, result);
-
- finished = true;
-}
-
-void ScanStructure::NextRightSemiOrAntiJoin(DataChunk &keys) {
- const auto ptrs = FlatVector::GetData<data_ptr_t>(pointers);
- while (!PointersExhausted()) {
- // resolve the equality_predicates for this set of keys
- idx_t result_count = ResolvePredicates(keys, chain_match_sel_vector, nullptr);
-
- // for each match, fully follow the chain
- for (idx_t i = 0; i < result_count; i++) {
- const auto idx = chain_match_sel_vector.get_index(i);
- auto &ptr = ptrs[idx];
- if (Load<bool>(ptr + ht.tuple_size)) { // Early out: chain has been fully marked as found before
- ptr = ht.dead_end.get();
- continue;
- }
-
- // Fully mark chain as found
- while (true) {
- // NOTE: threadsan reports this as a data race because this can be set concurrently by separate threads
- // Technically it is, but it does not matter, since the only value that can be written is "true"
- Store<bool>(true, ptr + ht.tuple_size);
- auto next_ptr =
LoadPointer(ptr + ht.pointer_offset); - if (!next_ptr) { - break; - } - ptr = next_ptr; - } - } - - // check the next set of pointers - AdvancePointers(); - } - - finished = true; -} - -void ScanStructure::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &child, DataChunk &result) { - // for the initial set of columns we just reference the left side - result.SetCardinality(child); - for (idx_t i = 0; i < child.ColumnCount(); i++) { - result.data[i].Reference(child.data[i]); - } - auto &mark_vector = result.data.back(); - mark_vector.SetVectorType(VectorType::FLAT_VECTOR); - // first we set the NULL values from the join keys - // if there is any NULL in the keys, the result is NULL - auto bool_result = FlatVector::GetData(mark_vector); - auto &mask = FlatVector::Validity(mark_vector); - for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { - if (ht.null_values_are_equal[col_idx]) { - continue; - } - UnifiedVectorFormat jdata; - join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); - if (!jdata.validity.AllValid()) { - for (idx_t i = 0; i < join_keys.size(); i++) { - auto jidx = jdata.sel->get_index(i); - mask.Set(i, jdata.validity.RowIsValidUnsafe(jidx)); - } - } - } - // now set the remaining entries to either true or false based on whether a match was found - if (found_match) { - for (idx_t i = 0; i < child.size(); i++) { - bool_result[i] = found_match[i]; - } - } else { - memset(bool_result, 0, sizeof(bool) * child.size()); - } - // if the right side contains NULL values, the result of any FALSE becomes NULL - if (ht.has_null) { - for (idx_t i = 0; i < child.size(); i++) { - if (!bool_result[i]) { - mask.SetInvalid(i); - } - } - } -} - -void ScanStructure::NextMarkJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - D_ASSERT(result.ColumnCount() == left.ColumnCount() + 1); - D_ASSERT(result.data.back().GetType() == LogicalType::BOOLEAN); - // this method should only be called for a non-empty HT - D_ASSERT(ht.Count() > 0); - - ScanKeyMatches(keys); - if (ht.correlated_mark_join_info.correlated_types.empty()) { - ConstructMarkJoinResult(keys, left, result); - } else { - auto &info = ht.correlated_mark_join_info; - lock_guard mj_lock(info.mj_lock); - - // there are correlated columns - // first we fetch the counts from the aggregate hashtable corresponding to these entries - D_ASSERT(keys.ColumnCount() == info.group_chunk.ColumnCount() + 1); - info.group_chunk.SetCardinality(keys); - for (idx_t i = 0; i < info.group_chunk.ColumnCount(); i++) { - info.group_chunk.data[i].Reference(keys.data[i]); - } - info.correlated_counts->FetchAggregates(info.group_chunk, info.result_chunk); - - // for the initial set of columns we just reference the left side - result.SetCardinality(left); - for (idx_t i = 0; i < left.ColumnCount(); i++) { - result.data[i].Reference(left.data[i]); - } - // create the result matching vector - auto &last_key = keys.data.back(); - auto &result_vector = result.data.back(); - // first set the nullmask based on whether or not there were NULL values in the join key - result_vector.SetVectorType(VectorType::FLAT_VECTOR); - auto bool_result = FlatVector::GetData(result_vector); - auto &mask = FlatVector::Validity(result_vector); - switch (last_key.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: - if (ConstantVector::IsNull(last_key)) { - mask.SetAllInvalid(left.size()); - } - break; - case VectorType::FLAT_VECTOR: - mask.Copy(FlatVector::Validity(last_key), left.size()); - break; - default: { - UnifiedVectorFormat kdata; - 
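ConstructMarkJoinResult above implements SQL's three-valued IN semantics; summarizing the cases it handles (a restatement of the code, not additional behavior):

// Per probe row, the mark column is:
//   probe key contains NULL (and NULLs are not equal) -> NULL
//   a matching build row was found                    -> TRUE
//   no match, but the build side contained a NULL key -> NULL
//   no match and no NULLs on the build side           -> FALSE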
last_key.ToUnifiedFormat(keys.size(), kdata); - for (idx_t i = 0; i < left.size(); i++) { - auto kidx = kdata.sel->get_index(i); - mask.Set(i, kdata.validity.RowIsValid(kidx)); - } - break; - } - } - - auto count_star = FlatVector::GetData(info.result_chunk.data[0]); - auto count = FlatVector::GetData(info.result_chunk.data[1]); - // set the entries to either true or false based on whether a match was found - for (idx_t i = 0; i < left.size(); i++) { - D_ASSERT(count_star[i] >= count[i]); - bool_result[i] = found_match ? found_match[i] : false; - if (!bool_result[i] && count_star[i] > count[i]) { - // RHS has NULL value and result is false: set to null - mask.SetInvalid(i); - } - if (count_star[i] == 0) { - // count == 0, set nullmask to false (we know the result is false now) - mask.SetValid(i); - } - } - } - finished = true; -} - -void ScanStructure::NextLeftJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - // a LEFT OUTER JOIN is identical to an INNER JOIN except all tuples that do - // not have a match must return at least one tuple (with the right side set - // to NULL in every column) - NextInnerJoin(keys, left, result); - if (result.size() == 0) { - // no entries left from the normal join - // fill in the result of the remaining left tuples - // together with NULL values on the right-hand side - idx_t remaining_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < left.size(); i++) { - if (!found_match[i]) { - sel.set_index(remaining_count++, i); - } - } - if (remaining_count > 0) { - // have remaining tuples - // slice the left side with tuples that did not find a match - result.Slice(left, sel, remaining_count); - - // now set the right side to NULL - for (idx_t i = left.ColumnCount(); i < result.ColumnCount(); i++) { - Vector &vec = result.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - } - finished = true; - } -} - -void ScanStructure::NextSingleJoin(DataChunk &keys, DataChunk &left, DataChunk &result) { - // single join - // this join is similar to the semi join except that - // (1) we actually return data from the RHS and - // (2) we return NULL for that data if there is no match - // (3) if single_join_error_on_multiple_rows is set, we need to keep looking for duplicates after fetching - idx_t result_count = 0; - SelectionVector result_sel(STANDARD_VECTOR_SIZE); - - while (this->count > 0) { - // resolve the equality_predicates for the current set of pointers - idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, &chain_no_match_sel_vector); - idx_t no_match_count = this->count - match_count; - - // mark each of the matches as found - for (idx_t i = 0; i < match_count; i++) { - // found a match for this index - auto index = chain_match_sel_vector.get_index(i); - found_match[index] = true; - result_sel.set_index(result_count++, index); - } - // continue searching for the ones where we did not find a match yet - AdvancePointers(chain_no_match_sel_vector, no_match_count); - } - // reference the columns of the left side from the result - D_ASSERT(left.ColumnCount() > 0); - for (idx_t i = 0; i < left.ColumnCount(); i++) { - result.data[i].Reference(left.data[i]); - } - // now fetch the data from the RHS - for (idx_t i = 0; i < ht.output_columns.size(); i++) { - auto &vector = result.data[left.ColumnCount() + i]; - // set NULL entries for every entry that was not found - for (idx_t j = 0; j < left.size(); j++) { - if (!found_match[j]) { - FlatVector::SetNull(vector, j, true); 
- } - } - const auto output_col_idx = ht.output_columns[i]; - D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]); - GatherResult(vector, result_sel, result_sel, result_count, output_col_idx); - } - result.SetCardinality(left.size()); - - // like the SEMI, ANTI and MARK join types, the SINGLE join only ever does one pass over the HT per input chunk - finished = true; - - if (ht.single_join_error_on_multiple_rows && result_count > 0) { - // we need to throw an error if there are multiple rows per key - // advance pointers for those rows - AdvancePointers(result_sel, result_count); - - // now resolve the predicates - idx_t match_count = ResolvePredicates(keys, chain_match_sel_vector, nullptr); - if (match_count > 0) { - // we found at least one duplicate row - throw - throw InvalidInputException( - "More than one row returned by a subquery used as an expression - scalar subqueries can only " - "return a single row.\n\nUse \"SET scalar_subquery_error_on_multiple_rows=false\" to revert to " - "previous behavior of returning a random row."); - } - - this->count = 0; - } -} - -void JoinHashTable::ScanFullOuter(JoinHTScanState &state, Vector &addresses, DataChunk &result) const { - // scan the HT starting from the current position and check which rows from the build side did not find a match - auto key_locations = FlatVector::GetData(addresses); - idx_t found_entries = 0; - - auto &iterator = state.iterator; - if (iterator.Done()) { - return; - } - - // When scanning Full Outer for right semi joins, we only propagate matches that have true - // Right Semi Joins do not propagate values during the probe phase, since we do not want to - // duplicate RHS rows. - bool match_propagation_value = false; - if (join_type == JoinType::RIGHT_SEMI) { - match_propagation_value = true; - } - - const auto row_locations = iterator.GetRowLocations(); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = state.offset_in_chunk; i < count; i++) { - auto found_match = Load(row_locations[i] + tuple_size); - if (found_match == match_propagation_value) { - key_locations[found_entries++] = row_locations[i]; - if (found_entries == STANDARD_VECTOR_SIZE) { - state.offset_in_chunk = i + 1; - break; - } - } - } - if (found_entries == STANDARD_VECTOR_SIZE) { - break; - } - state.offset_in_chunk = 0; - } while (iterator.Next()); - - // now gather from the found rows - if (found_entries == 0) { - return; - } - result.SetCardinality(found_entries); - - idx_t left_column_count = result.ColumnCount() - output_columns.size(); - if (join_type == JoinType::RIGHT_SEMI || join_type == JoinType::RIGHT_ANTI) { - left_column_count = 0; - } - const auto &sel_vector = *FlatVector::IncrementalSelectionVector(); - // set the left side as a constant NULL - for (idx_t i = 0; i < left_column_count; i++) { - Vector &vec = result.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - - // gather the values from the RHS - for (idx_t i = 0; i < output_columns.size(); i++) { - auto &vector = result.data[left_column_count + i]; - const auto output_col_idx = output_columns[i]; - D_ASSERT(vector.GetType() == layout.GetTypes()[output_col_idx]); - data_collection->Gather(addresses, sel_vector, found_entries, output_col_idx, vector, sel_vector, nullptr); - } -} - -idx_t JoinHashTable::FillWithHTOffsets(JoinHTScanState &state, Vector &addresses) { - // iterate over HT - auto key_locations = FlatVector::GetData(addresses); - idx_t key_count = 0; - - auto &iterator = 
state.iterator; - const auto row_locations = iterator.GetRowLocations(); - do { - const auto count = iterator.GetCurrentChunkCount(); - for (idx_t i = 0; i < count; i++) { - key_locations[key_count + i] = row_locations[i]; - } - key_count += count; - } while (iterator.Next()); - - return key_count; -} - -idx_t JoinHashTable::GetTotalSize(const vector &partition_sizes, const vector &partition_counts, - idx_t &max_partition_size, idx_t &max_partition_count) const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - - idx_t total_size = 0; - idx_t total_count = 0; - idx_t max_partition_ht_size = 0; - max_partition_size = 0; - max_partition_count = 0; - for (idx_t i = 0; i < num_partitions; i++) { - total_size += partition_sizes[i]; - total_count += partition_counts[i]; - - auto partition_size = partition_sizes[i] + PointerTableSize(partition_counts[i]); - if (partition_size > max_partition_ht_size) { - max_partition_ht_size = partition_size; - max_partition_size = partition_sizes[i]; - max_partition_count = partition_counts[i]; - } - } - - if (total_count == 0) { - return 0; - } - - return total_size + PointerTableSize(total_count); -} - -idx_t JoinHashTable::GetTotalSize(const vector> &local_hts, idx_t &max_partition_size, - idx_t &max_partition_count) const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - vector partition_sizes(num_partitions, 0); - vector partition_counts(num_partitions, 0); - for (auto &ht : local_hts) { - ht->GetSinkCollection().GetSizesAndCounts(partition_sizes, partition_counts); - } - - return GetTotalSize(partition_sizes, partition_counts, max_partition_size, max_partition_count); -} - -idx_t JoinHashTable::GetRemainingSize() const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - auto &partitions = sink_collection->GetPartitions(); - - idx_t count = 0; - idx_t data_size = 0; - for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { - if (completed_partitions.RowIsValidUnsafe(partition_idx)) { - continue; - } - count += partitions[partition_idx]->Count(); - data_size += partitions[partition_idx]->SizeInBytes(); - } - - return data_size + PointerTableSize(count); -} - -void JoinHashTable::Unpartition() { - data_collection = sink_collection->GetUnpartitioned(); -} - -void JoinHashTable::SetRepartitionRadixBits(const idx_t max_ht_size, const idx_t max_partition_size, - const idx_t max_partition_count) { - D_ASSERT(max_partition_size + PointerTableSize(max_partition_count) > max_ht_size); - - const auto max_added_bits = RadixPartitioning::MAX_RADIX_BITS - radix_bits; - idx_t added_bits = 1; - for (; added_bits < max_added_bits; added_bits++) { - double partition_multiplier = static_cast(RadixPartitioning::NumberOfPartitions(added_bits)); - - auto new_estimated_size = static_cast(max_partition_size) / partition_multiplier; - auto new_estimated_count = static_cast(max_partition_count) / partition_multiplier; - auto new_estimated_ht_size = - new_estimated_size + static_cast(PointerTableSize(LossyNumericCast(new_estimated_count))); - - if (new_estimated_ht_size <= static_cast(max_ht_size) / 4) { - // Aim for an estimated partition size of max_ht_size / 4 - break; - } - } - radix_bits += added_bits; - sink_collection = - make_uniq(buffer_manager, layout, radix_bits, layout.ColumnCount() - 1); - - // Need to initialize again after changing the number of bits - InitializePartitionMasks(); -} - -void JoinHashTable::InitializePartitionMasks() { - const 
auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - - current_partitions.Initialize(num_partitions); - current_partitions.SetAllInvalid(num_partitions); - - completed_partitions.Initialize(num_partitions); - completed_partitions.SetAllInvalid(num_partitions); -} - -idx_t JoinHashTable::CurrentPartitionCount() const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - D_ASSERT(current_partitions.Capacity() == num_partitions); - return current_partitions.CountValid(num_partitions); -} - -idx_t JoinHashTable::FinishedPartitionCount() const { - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - D_ASSERT(completed_partitions.Capacity() == num_partitions); - // We already marked the active partitions as done, so we have to subtract them here - return completed_partitions.CountValid(num_partitions) - CurrentPartitionCount(); -} - -void JoinHashTable::Repartition(JoinHashTable &global_ht) { - auto new_sink_collection = - make_uniq<RadixPartitionedTupleData>(buffer_manager, layout, global_ht.radix_bits, layout.ColumnCount() - 1); - sink_collection->Repartition(*new_sink_collection); - sink_collection = std::move(new_sink_collection); - global_ht.Merge(*this); -} - -void JoinHashTable::Reset() { - data_collection->Reset(); - hash_map.Reset(); - current_partitions.SetAllInvalid(RadixPartitioning::NumberOfPartitions(radix_bits)); - finalized = false; -} - -bool JoinHashTable::PrepareExternalFinalize(const idx_t max_ht_size) { - if (finalized) { - Reset(); - } - - const auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - D_ASSERT(current_partitions.Capacity() == num_partitions); - D_ASSERT(completed_partitions.Capacity() == num_partitions); - D_ASSERT(current_partitions.CheckAllInvalid(num_partitions)); - - if (completed_partitions.CheckAllValid(num_partitions)) { - return false; // All partitions are done - } - - // Create vector with unfinished partition indices - auto &partitions = sink_collection->GetPartitions(); - auto min_partition_size = NumericLimits<idx_t>::Maximum(); - vector<idx_t> partition_indices; - partition_indices.reserve(num_partitions); - for (idx_t partition_idx = 0; partition_idx < num_partitions; partition_idx++) { - if (completed_partitions.RowIsValidUnsafe(partition_idx)) { - continue; - } - partition_indices.push_back(partition_idx); - // Keep track of min partition size - const auto size = - partitions[partition_idx]->SizeInBytes() + PointerTableSize(partitions[partition_idx]->Count()); - min_partition_size = MinValue(min_partition_size, size); - } - - // Sort partitions by size, from small to large - std::stable_sort(partition_indices.begin(), partition_indices.end(), [&](const idx_t &lhs, const idx_t &rhs) { - const auto lhs_size = partitions[lhs]->SizeInBytes() + PointerTableSize(partitions[lhs]->Count()); - const auto rhs_size = partitions[rhs]->SizeInBytes() + PointerTableSize(partitions[rhs]->Count()); - // We divide by min_partition_size, effectively rounding everything down to a multiple of min_partition_size - // Makes it so minor differences in partition sizes don't mess up the original order - // Retaining as much of the original order as possible reduces I/O (partition idx determines eviction queue idx) - return lhs_size / min_partition_size < rhs_size / min_partition_size; - }); - - // Determine which partitions should go next - idx_t count = 0; - idx_t data_size = 0; - for (const auto &partition_idx : partition_indices) { - D_ASSERT(!completed_partitions.RowIsValidUnsafe(partition_idx));
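The selection loop below is a greedy packing pass: the unfinished partitions were just ordered smallest-first, and partitions are now accepted while the projected hash-table size (data size plus pointer table) stays within the budget, with at least one always accepted so the build is guaranteed to make progress. The same policy reduced to plain integers, as a hedged sketch (PickPartitions and all names here are hypothetical):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    // Greedily select partition indices, smallest first, until adding another
    // partition would exceed 'budget'; always select at least one.
    std::vector<size_t> PickPartitions(std::vector<size_t> sizes, size_t budget) {
        std::vector<size_t> order(sizes.size());
        for (size_t i = 0; i < order.size(); i++) {
            order[i] = i;
        }
        std::stable_sort(order.begin(), order.end(),
                         [&](size_t l, size_t r) { return sizes[l] < sizes[r]; });
        std::vector<size_t> picked;
        size_t total = 0;
        for (auto idx : order) {
            if (!picked.empty() && total + sizes[idx] > budget) {
                break; // over budget, but we already have at least one partition
            }
            total += sizes[idx];
            picked.push_back(idx);
        }
        return picked;
    }

- const 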
auto incl_count = count + partitions[partition_idx]->Count(); - const auto incl_data_size = data_size + partitions[partition_idx]->SizeInBytes(); - const auto incl_ht_size = incl_data_size + PointerTableSize(incl_count); - if (count > 0 && incl_ht_size > max_ht_size) { - break; // Always add at least one partition - } - count = incl_count; - data_size = incl_data_size; - current_partitions.SetValidUnsafe(partition_idx); // Mark as currently active - data_collection->Combine(*partitions[partition_idx]); // Move partition to the main data collection - completed_partitions.SetValidUnsafe(partition_idx); // Also already mark as done - } - D_ASSERT(Count() == count); - - return true; -} - -void JoinHashTable::ProbeAndSpill(ScanStructure &scan_structure, DataChunk &probe_keys, TupleDataChunkState &key_state, - ProbeState &probe_state, DataChunk &probe_chunk, ProbeSpill &probe_spill, - ProbeSpillLocalAppendState &spill_state, DataChunk &spill_chunk) { - // hash all the keys - Vector hashes(LogicalType::HASH); - Hash(probe_keys, *FlatVector::IncrementalSelectionVector(), probe_keys.size(), hashes); - - // find out which keys we can match with the current pinned partitions - SelectionVector true_sel(STANDARD_VECTOR_SIZE); - SelectionVector false_sel(STANDARD_VECTOR_SIZE); - const auto true_count = - RadixPartitioning::Select(hashes, FlatVector::IncrementalSelectionVector(), probe_keys.size(), radix_bits, - current_partitions, &true_sel, &false_sel); - const auto false_count = probe_keys.size() - true_count; - - // can't probe these values right now, append to spill - spill_chunk.Reset(); - spill_chunk.Reference(probe_chunk); - spill_chunk.data.back().Reference(hashes); - spill_chunk.Slice(false_sel, false_count); - probe_spill.Append(spill_chunk, spill_state); - - // slice the stuff we CAN probe right now - hashes.Slice(true_sel, true_count); - probe_keys.Slice(true_sel, true_count); - probe_chunk.Slice(true_sel, true_count); - - const SelectionVector *current_sel; - InitializeScanStructure(scan_structure, probe_keys, key_state, current_sel); - if (scan_structure.count == 0) { - return; - } - - // now initialize the pointers of the scan structure based on the hashes - GetRowPointers(probe_keys, key_state, probe_state, hashes, *current_sel, scan_structure.count, - scan_structure.pointers, scan_structure.sel_vector); -} - -ProbeSpill::ProbeSpill(JoinHashTable &ht, ClientContext &context, const vector &probe_types) - : ht(ht), context(context), probe_types(probe_types) { - global_partitions = - make_uniq(context, probe_types, ht.radix_bits, probe_types.size() - 1); - column_ids.reserve(probe_types.size()); - for (column_t column_id = 0; column_id < probe_types.size(); column_id++) { - column_ids.emplace_back(column_id); - } -} - -ProbeSpillLocalState ProbeSpill::RegisterThread() { - ProbeSpillLocalAppendState result; - lock_guard guard(lock); - local_partitions.emplace_back(global_partitions->CreateShared()); - local_partition_append_states.emplace_back(make_uniq()); - local_partitions.back()->InitializeAppendState(*local_partition_append_states.back()); - - result.local_partition = local_partitions.back().get(); - result.local_partition_append_state = local_partition_append_states.back().get(); - return result; -} - -void ProbeSpill::Append(DataChunk &chunk, ProbeSpillLocalAppendState &local_state) { - local_state.local_partition->Append(*local_state.local_partition_append_state, chunk); -} - -void ProbeSpill::Finalize() { - D_ASSERT(local_partitions.size() == local_partition_append_states.size()); - 
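ProbeAndSpill, shown a little earlier, splits every probe chunk using nothing but the hash bits: rows whose radix partition is currently loaded are probed immediately, while the rest are appended to the spill and probed in a later round once their partition is resident. A sketch of that split on flat arrays; SplitProbe and the top-bits partitioning scheme are illustrative assumptions, the real bit layout lives in RadixPartitioning:

    #include <cstdint>
    #include <vector>

    // Route each row either to the immediate probe or to the spill, based on
    // which radix partition its hash falls into. Assumes 1 <= radix_bits <= 63
    // and partition_loaded.size() == 2^radix_bits.
    void SplitProbe(const std::vector<uint64_t> &hashes, uint64_t radix_bits,
                    const std::vector<bool> &partition_loaded,
                    std::vector<size_t> &probe_now, std::vector<size_t> &spill) {
        const uint64_t shift = 64 - radix_bits; // one common scheme: top bits pick the partition
        for (size_t i = 0; i < hashes.size(); i++) {
            const auto partition_idx = static_cast<size_t>(hashes[i] >> shift);
            if (partition_loaded[partition_idx]) {
                probe_now.push_back(i); // partition is in memory, probe right away
            } else {
                spill.push_back(i); // defer until the partition is loaded
            }
        }
    }
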
for (idx_t i = 0; i < local_partition_append_states.size(); i++) { - local_partitions[i]->FlushAppendState(*local_partition_append_states[i]); - } - for (auto &local_partition : local_partitions) { - global_partitions->Combine(*local_partition); - } - local_partitions.clear(); - local_partition_append_states.clear(); -} - -void ProbeSpill::PrepareNextProbe() { - global_spill_collection.reset(); - auto &partitions = global_partitions->GetPartitions(); - if (partitions.empty() || ht.current_partitions.CheckAllInvalid(partitions.size())) { - // Can't probe, just make an empty one - global_spill_collection = - make_uniq(BufferManager::GetBufferManager(context), probe_types); - } else { - // Move current partitions to the global spill collection - for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { - if (!ht.current_partitions.RowIsValidUnsafe(partition_idx)) { - continue; - } - auto &partition = partitions[partition_idx]; - if (!global_spill_collection) { - global_spill_collection = std::move(partition); - } else if (partition->Count() != 0) { - global_spill_collection->Combine(*partition); - } - partition.reset(); - } - } - consumer = make_uniq(*global_spill_collection, column_ids); - consumer->InitializeScan(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp deleted file mode 100644 index 583c5946f..000000000 --- a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_inner.cpp +++ /dev/null @@ -1,193 +0,0 @@ -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/execution/nested_loop_join.hpp" - -namespace duckdb { - -struct InitialNestedLoopJoin { - template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { - using MATCH_OP = ComparisonOperationWrapper; - - // initialize phase of nested loop join - // fill lvector and rvector with matches from the base vectors - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(left_size, left_data); - right.ToUnifiedFormat(right_size, right_data); - - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - idx_t result_count = 0; - for (; rpos < right_size; rpos++) { - idx_t right_position = right_data.sel->get_index(rpos); - bool right_is_valid = right_data.validity.RowIsValid(right_position); - for (; lpos < left_size; lpos++) { - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! 
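This early return is what makes the operator restartable: lpos and rpos live outside the function, so when the output chunk fills mid-scan, the next invocation resumes at the exact (left, right) position. The cursor pattern in isolation, as a hedged sketch with hypothetical names:

    #include <cstddef>
    #include <utility>
    #include <vector>

    // Quadratic scan over left x right that can stop and resume: the capacity
    // check happens before the element is examined, so (lpos, rpos) always
    // point at the first unprocessed pair when the function returns early.
    template <class T>
    size_t NestedLoopMatches(const std::vector<T> &left, const std::vector<T> &right,
                             size_t &lpos, size_t &rpos,
                             std::vector<std::pair<size_t, size_t>> &out, size_t capacity) {
        size_t produced = 0;
        for (; rpos < right.size(); rpos++) {
            for (; lpos < left.size(); lpos++) {
                if (produced == capacity) {
                    return produced; // out of space; resume at (lpos, rpos) next call
                }
                if (left[lpos] == right[rpos]) {
                    out.emplace_back(lpos, rpos);
                    produced++;
                }
            }
            lpos = 0; // restart the left side for the next right row
        }
        return produced;
    }
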
- return result_count; - } - idx_t left_position = left_data.sel->get_index(lpos); - bool left_is_valid = left_data.validity.RowIsValid(left_position); - if (MATCH_OP::Operation(ldata[left_position], rdata[right_position], !left_is_valid, !right_is_valid)) { - // emit tuple - lvector.set_index(result_count, lpos); - rvector.set_index(result_count, rpos); - result_count++; - } - } - lpos = 0; - } - return result_count; - } -}; - -struct RefineNestedLoopJoin { - template - static idx_t Operation(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, idx_t &rpos, - SelectionVector &lvector, SelectionVector &rvector, idx_t current_match_count) { - using MATCH_OP = ComparisonOperationWrapper; - - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(left_size, left_data); - right.ToUnifiedFormat(right_size, right_data); - - // refine phase of the nested loop join - // refine lvector and rvector based on matches of subsequent conditions (in case there are multiple conditions - // in the join) - D_ASSERT(current_match_count > 0); - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - idx_t result_count = 0; - for (idx_t i = 0; i < current_match_count; i++) { - auto lidx = lvector.get_index(i); - auto ridx = rvector.get_index(i); - auto left_idx = left_data.sel->get_index(lidx); - auto right_idx = right_data.sel->get_index(ridx); - bool left_is_valid = left_data.validity.RowIsValid(left_idx); - bool right_is_valid = right_data.validity.RowIsValid(right_idx); - if (MATCH_OP::Operation(ldata[left_idx], rdata[right_idx], !left_is_valid, !right_is_valid)) { - lvector.set_index(result_count, lidx); - rvector.set_index(result_count, ridx); - result_count++; - } - } - return result_count; - } -}; - -template -static idx_t NestedLoopJoinTypeSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, - idx_t current_match_count) { - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT16: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT32: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INT64: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::UINT8: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::UINT16: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::UINT32: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::UINT64: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::INT128: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::UINT128: - return NLTYPE::template 
Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::FLOAT: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::DOUBLE: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, rvector, - current_match_count); - case PhysicalType::INTERVAL: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case PhysicalType::VARCHAR: - return NLTYPE::template Operation(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - default: - throw InternalException("Unimplemented type for join!"); - } -} - -template -idx_t NestedLoopJoinComparisonSwitch(Vector &left, Vector &right, idx_t left_size, idx_t right_size, idx_t &lpos, - idx_t &rpos, SelectionVector &lvector, SelectionVector &rvector, - idx_t current_match_count, ExpressionType comparison_type) { - D_ASSERT(left.GetType() == right.GetType()); - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_NOTEQUAL: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_LESSTHAN: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_GREATERTHAN: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, - lvector, rvector, current_match_count); - case ExpressionType::COMPARE_DISTINCT_FROM: - return NestedLoopJoinTypeSwitch(left, right, left_size, right_size, lpos, rpos, lvector, - rvector, current_match_count); - default: - throw NotImplementedException("Unimplemented comparison type for join!"); - } -} - -idx_t NestedLoopJoinInner::Perform(idx_t &lpos, idx_t &rpos, DataChunk &left_conditions, DataChunk &right_conditions, - SelectionVector &lvector, SelectionVector &rvector, - const vector &conditions) { - D_ASSERT(left_conditions.ColumnCount() == right_conditions.ColumnCount()); - if (lpos >= left_conditions.size() || rpos >= right_conditions.size()) { - return 0; - } - // for the first condition, lvector and rvector are not set yet - // we initialize them using the InitialNestedLoopJoin - idx_t match_count = NestedLoopJoinComparisonSwitch( - left_conditions.data[0], right_conditions.data[0], left_conditions.size(), right_conditions.size(), lpos, rpos, - lvector, rvector, 0, conditions[0].comparison); - // now resolve the rest of the conditions - for (idx_t i = 1; i < conditions.size(); i++) { - // check if we have run out of tuples to compare - if (match_count == 0) { - return 0; - } - // if not, get the vectors to compare - Vector &l = left_conditions.data[i]; - Vector &r = right_conditions.data[i]; - // then we refine the currently obtained results using the RefineNestedLoopJoin - match_count = 
NestedLoopJoinComparisonSwitch( - l, r, left_conditions.size(), right_conditions.size(), lpos, rpos, lvector, rvector, match_count, - conditions[i].comparison); - } - return match_count; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp b/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp deleted file mode 100644 index 4bb1c5f29..000000000 --- a/src/duckdb/src/execution/nested_loop_join/nested_loop_join_mark.cpp +++ /dev/null @@ -1,168 +0,0 @@ -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/nested_loop_join.hpp" - -namespace duckdb { - -template -static void TemplatedMarkJoin(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { - using MATCH_OP = ComparisonOperationWrapper; - - UnifiedVectorFormat left_data, right_data; - left.ToUnifiedFormat(lcount, left_data); - right.ToUnifiedFormat(rcount, right_data); - - auto ldata = UnifiedVectorFormat::GetData(left_data); - auto rdata = UnifiedVectorFormat::GetData(right_data); - for (idx_t i = 0; i < lcount; i++) { - if (found_match[i]) { - continue; - } - auto lidx = left_data.sel->get_index(i); - const auto left_null = !left_data.validity.RowIsValid(lidx); - if (!MATCH_OP::COMPARE_NULL && left_null) { - continue; - } - for (idx_t j = 0; j < rcount; j++) { - auto ridx = right_data.sel->get_index(j); - const auto right_null = !right_data.validity.RowIsValid(ridx); - if (!MATCH_OP::COMPARE_NULL && right_null) { - continue; - } - if (MATCH_OP::template Operation(ldata[lidx], rdata[ridx], left_null, right_null)) { - found_match[i] = true; - break; - } - } - } -} - -static void MarkJoinNested(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], - ExpressionType comparison_type) { - Vector left_reference(left.GetType()); - for (idx_t i = 0; i < lcount; i++) { - if (found_match[i]) { - continue; - } - ConstantVector::Reference(left_reference, left, i, rcount); - idx_t count; - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - count = VectorOperations::Equals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_NOTEQUAL: - count = VectorOperations::NotEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_LESSTHAN: - count = VectorOperations::LessThan(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_GREATERTHAN: - count = VectorOperations::GreaterThan(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - count = VectorOperations::LessThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - count = VectorOperations::GreaterThanEquals(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_DISTINCT_FROM: - count = VectorOperations::DistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - count = VectorOperations::NotDistinctFrom(left_reference, right, nullptr, rcount, nullptr, nullptr); - break; - default: - throw InternalException("Unsupported comparison type for MarkJoinNested"); - } - if (count > 0) { - found_match[i] = true; - } - } -} - -template -static void 
MarkJoinSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[]) { - switch (left.GetType().InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT16: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT32: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT64: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::INT128: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT8: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT16: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT32: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT64: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::UINT128: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::FLOAT: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::DOUBLE: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - case PhysicalType::VARCHAR: - return TemplatedMarkJoin(left, right, lcount, rcount, found_match); - default: - throw NotImplementedException("Unimplemented type for mark join!"); - } -} - -static void MarkJoinComparisonSwitch(Vector &left, Vector &right, idx_t lcount, idx_t rcount, bool found_match[], - ExpressionType comparison_type) { - switch (left.GetType().InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - case PhysicalType::ARRAY: - return MarkJoinNested(left, right, lcount, rcount, found_match, comparison_type); - default: - break; - } - D_ASSERT(left.GetType() == right.GetType()); - switch (comparison_type) { - case ExpressionType::COMPARE_EQUAL: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_NOTEQUAL: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_LESSTHAN: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_GREATERTHAN: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - case ExpressionType::COMPARE_DISTINCT_FROM: - return MarkJoinSwitch(left, right, lcount, rcount, found_match); - default: - throw NotImplementedException("Unimplemented comparison type for join!"); - } -} - -void NestedLoopJoinMark::Perform(DataChunk &left, ColumnDataCollection &right, bool found_match[], - const vector &conditions) { - // initialize a new temporary selection vector for the left chunk - // loop over all chunks in the RHS - ColumnDataScanState scan_state; - right.InitializeScan(scan_state); - - DataChunk scan_chunk; - right.InitializeScanChunk(scan_chunk); - - while (right.Scan(scan_state, scan_chunk)) { - for (idx_t i = 0; i < conditions.size(); i++) { - MarkJoinComparisonSwitch(left.data[i], scan_chunk.data[i], left.size(), scan_chunk.size(), found_match, - conditions[i].comparison); - } - } -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp b/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp deleted file mode 100644 index 53fe0368a..000000000 --- a/src/duckdb/src/execution/operator/aggregate/aggregate_object.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" - -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -AggregateObject::AggregateObject(AggregateFunction function, FunctionData *bind_data, idx_t child_count, - idx_t payload_size, AggregateType aggr_type, PhysicalType return_type, - Expression *filter) - : function(std::move(function)), - bind_data_wrapper(bind_data ? make_shared_ptr(bind_data->Copy()) : nullptr), - child_count(child_count), payload_size(payload_size), aggr_type(aggr_type), return_type(return_type), - filter(filter) { -} - -AggregateObject::AggregateObject(BoundAggregateExpression *aggr) - : AggregateObject(aggr->function, aggr->bind_info.get(), aggr->children.size(), - AlignValue(aggr->function.state_size(aggr->function)), aggr->aggr_type, - aggr->return_type.InternalType(), aggr->filter.get()) { -} - -AggregateObject::AggregateObject(const BoundWindowExpression &window) - : AggregateObject(*window.aggregate, window.bind_info.get(), window.children.size(), - AlignValue(window.aggregate->state_size(*window.aggregate)), - window.distinct ? AggregateType::DISTINCT : AggregateType::NON_DISTINCT, - window.return_type.InternalType(), window.filter_expr.get()) { -} - -vector AggregateObject::CreateAggregateObjects(const vector &bindings) { - vector aggregates; - aggregates.reserve(bindings.size()); - for (auto &binding : bindings) { - aggregates.emplace_back(binding); - } - return aggregates; -} - -AggregateFilterData::AggregateFilterData(ClientContext &context, Expression &filter_expr, - const vector &payload_types) - : filter_executor(context, &filter_expr), true_sel(STANDARD_VECTOR_SIZE) { - if (payload_types.empty()) { - return; - } - filtered_payload.Initialize(Allocator::Get(context), payload_types); -} - -idx_t AggregateFilterData::ApplyFilter(DataChunk &payload) { - filtered_payload.Reset(); - - auto count = filter_executor.SelectExpression(payload, true_sel); - filtered_payload.Slice(payload, true_sel, count); - return count; -} - -AggregateFilterDataSet::AggregateFilterDataSet() { -} - -void AggregateFilterDataSet::Initialize(ClientContext &context, const vector &aggregates, - const vector &payload_types) { - bool has_filters = false; - for (auto &aggregate : aggregates) { - if (aggregate.filter) { - has_filters = true; - break; - } - } - if (!has_filters) { - // no filters: nothing to do - return; - } - filter_data.resize(aggregates.size()); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggr = aggregates[aggr_idx]; - if (aggr.filter) { - filter_data[aggr_idx] = make_uniq(context, *aggr.filter, payload_types); - } - } -} - -AggregateFilterData &AggregateFilterDataSet::GetFilterData(idx_t aggr_idx) { - D_ASSERT(aggr_idx < filter_data.size()); - D_ASSERT(filter_data[aggr_idx]); - return *filter_data[aggr_idx]; -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp deleted file mode 100644 index d1d0d20d6..000000000 --- a/src/duckdb/src/execution/operator/aggregate/distinct_aggregate_data.cpp +++ 
/dev/null @@ -1,216 +0,0 @@ -#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/common/algorithm.hpp" - -namespace duckdb { - -//! Shared information about a collection of distinct aggregates -DistinctAggregateCollectionInfo::DistinctAggregateCollectionInfo(const vector> &aggregates, - vector indices) - : indices(std::move(indices)), aggregates(aggregates) { - table_count = CreateTableIndexMap(); - - const idx_t aggregate_count = aggregates.size(); - - total_child_count = 0; - for (idx_t i = 0; i < aggregate_count; i++) { - auto &aggregate = aggregates[i]->Cast(); - - if (!aggregate.IsDistinct()) { - continue; - } - total_child_count += aggregate.children.size(); - } -} - -//! Stateful data for the distinct aggregates - -DistinctAggregateState::DistinctAggregateState(const DistinctAggregateData &data, ClientContext &client) - : child_executor(client) { - - radix_states.resize(data.info.table_count); - distinct_output_chunks.resize(data.info.table_count); - - idx_t aggregate_count = data.info.aggregates.size(); - for (idx_t i = 0; i < aggregate_count; i++) { - auto &aggregate = data.info.aggregates[i]->Cast(); - - // Initialize the child executor and get the payload types for every aggregate - for (auto &child : aggregate.children) { - child_executor.AddExpression(*child); - } - if (!aggregate.IsDistinct()) { - continue; - } - D_ASSERT(data.info.table_map.count(i)); - idx_t table_idx = data.info.table_map.at(i); - if (data.radix_tables[table_idx] == nullptr) { - //! This table is unused because the aggregate shares its data with another - continue; - } - - // Get the global sinkstate for the aggregate - auto &radix_table = *data.radix_tables[table_idx]; - radix_states[table_idx] = radix_table.GetGlobalSinkState(client); - - // Fill the chunk_types (group_by + children) - vector chunk_types; - for (auto &group_type : data.grouped_aggregate_data[table_idx]->group_types) { - chunk_types.push_back(group_type); - } - - // This is used in Finalize to get the data from the radix table - distinct_output_chunks[table_idx] = make_uniq(); - distinct_output_chunks[table_idx]->Initialize(client, chunk_types); - } -} - -//! Persistent + shared (read-only) data for the distinct aggregates -DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info) - : DistinctAggregateData(info, {}, nullptr) { -} - -DistinctAggregateData::DistinctAggregateData(const DistinctAggregateCollectionInfo &info, const GroupingSet &groups, - const vector> *group_expressions) - : info(info) { - grouped_aggregate_data.resize(info.table_count); - radix_tables.resize(info.table_count); - grouping_sets.resize(info.table_count); - - for (auto &i : info.indices) { - auto &aggregate = info.aggregates[i]->Cast(); - - D_ASSERT(info.table_map.count(i)); - idx_t table_idx = info.table_map.at(i); - if (radix_tables[table_idx] != nullptr) { - //! This aggregate shares a table with another aggregate, and the table is already initialized - continue; - } - // The grouping set contains the indices of the chunk that correspond to the data vector - // that will be used to figure out in which bucket the payload should be put - auto &grouping_set = grouping_sets[table_idx]; - //! 
Populate the group with the children of the aggregate - for (auto &group : groups) { - grouping_set.insert(group); - } - idx_t group_by_size = group_expressions ? group_expressions->size() : 0; - for (idx_t set_idx = 0; set_idx < aggregate.children.size(); set_idx++) { - grouping_set.insert(set_idx + group_by_size); - } - // Create the hashtable for the aggregate - grouped_aggregate_data[table_idx] = make_uniq<GroupedAggregateData>(); - grouped_aggregate_data[table_idx]->InitializeDistinct(info.aggregates[i], group_expressions); - radix_tables[table_idx] = - make_uniq<RadixPartitionedHashTable>(grouping_set, *grouped_aggregate_data[table_idx]); - - // Fill the chunk_types (only contains the payload of the distinct aggregates) - vector<LogicalType> chunk_types; - for (auto &child_p : aggregate.children) { - chunk_types.push_back(child_p->return_type); - } - } -} - -using aggr_ref_t = reference<BoundAggregateExpression>; - -struct FindMatchingAggregate { - explicit FindMatchingAggregate(const aggr_ref_t &aggr) : aggr_r(aggr) { - } - bool operator()(const aggr_ref_t other_r) { - auto &other = other_r.get(); - auto &aggr = aggr_r.get(); - if (other.children.size() != aggr.children.size()) { - return false; - } - if (!Expression::Equals(aggr.filter, other.filter)) { - return false; - } - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &other_child = other.children[i]->Cast<BoundReferenceExpression>(); - auto &aggr_child = aggr.children[i]->Cast<BoundReferenceExpression>(); - if (other_child.index != aggr_child.index) { - return false; - } - } - return true; - } - const aggr_ref_t aggr_r; -}; - -idx_t DistinctAggregateCollectionInfo::CreateTableIndexMap() { - vector<aggr_ref_t> table_inputs; - - D_ASSERT(table_map.empty()); - for (auto &agg_idx : indices) { - D_ASSERT(agg_idx < aggregates.size()); - auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>(); - - auto matching_inputs = - std::find_if(table_inputs.begin(), table_inputs.end(), FindMatchingAggregate(std::ref(aggregate))); - if (matching_inputs != table_inputs.end()) { - //! Assign the existing table to the aggregate - auto found_idx = NumericCast<idx_t>(std::distance(table_inputs.begin(), matching_inputs)); - table_map[agg_idx] = found_idx; - continue; - } - //! Create a new table and assign its index to the aggregate - table_map[agg_idx] = table_inputs.size(); - table_inputs.push_back(std::ref(aggregate)); - } - //! Every distinct aggregate needs to be assigned an index - D_ASSERT(table_map.size() == indices.size()); - //! There cannot be more tables than there are distinct aggregates - D_ASSERT(table_inputs.size() <= indices.size()); - - return table_inputs.size(); -} - -bool DistinctAggregateCollectionInfo::AnyDistinct() const { - return !indices.empty(); -} - -const unsafe_vector<idx_t> &DistinctAggregateCollectionInfo::Indices() const { - return this->indices; -} - -static vector<idx_t> GetDistinctIndices(vector<unique_ptr<Expression>> &aggregates) { - vector<idx_t> distinct_indices; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast<BoundAggregateExpression>(); - if (aggr.IsDistinct()) { - distinct_indices.push_back(i); - } - } - return distinct_indices; -} - -unique_ptr<DistinctAggregateCollectionInfo> -DistinctAggregateCollectionInfo::Create(vector<unique_ptr<Expression>> &aggregates) { - vector<idx_t> indices = GetDistinctIndices(aggregates); - if (indices.empty()) { - return nullptr; - } - return make_uniq<DistinctAggregateCollectionInfo>(aggregates, std::move(indices)); -} - -bool DistinctAggregateData::IsDistinct(idx_t index) const { - bool is_distinct = !radix_tables.empty() && info.table_map.count(index); -
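CreateTableIndexMap above is what lets two aggregates such as COUNT(DISTINCT x) and SUM(DISTINCT x) share one deduplication table: aggregates are compared on their input column indices and filter, and each is mapped to the first equivalent table. The mapping idea on plain index lists, as a sketch (MapAggregatesToTables is a hypothetical name, and the key is simplified to the column-index list alone):

    #include <cstddef>
    #include <map>
    #include <vector>

    // Map each aggregate to a table index; aggregates with identical input
    // column lists reuse the table of the first equivalent aggregate.
    std::map<size_t, size_t> MapAggregatesToTables(const std::vector<std::vector<size_t>> &agg_inputs) {
        std::map<size_t, size_t> table_map;            // aggregate idx -> table idx
        std::vector<std::vector<size_t>> table_inputs; // one entry per created table
        for (size_t agg_idx = 0; agg_idx < agg_inputs.size(); agg_idx++) {
            size_t table_idx = table_inputs.size();
            for (size_t i = 0; i < table_inputs.size(); i++) {
                if (table_inputs[i] == agg_inputs[agg_idx]) {
                    table_idx = i; // reuse the existing table
                    break;
                }
            }
            if (table_idx == table_inputs.size()) {
                table_inputs.push_back(agg_inputs[agg_idx]); // create a new table
            }
            table_map[agg_idx] = table_idx;
        }
        return table_map;
    }

#ifdef DEBUG - //! Make sure that if it is distinct, it's also in the indices - //! 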
And if it's not distinct, that it's also not in the indices - bool found = false; - for (auto &idx : info.indices) { - if (idx == index) { - found = true; - break; - } - } - D_ASSERT(found == is_distinct); -#endif - return is_distinct; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp b/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp deleted file mode 100644 index 088a0f59a..000000000 --- a/src/duckdb/src/execution/operator/aggregate/grouped_aggregate_data.cpp +++ /dev/null @@ -1,96 +0,0 @@ -#include "duckdb/execution/operator/aggregate/grouped_aggregate_data.hpp" - -namespace duckdb { - -idx_t GroupedAggregateData::GroupCount() const { - return groups.size(); -} - -const vector> &GroupedAggregateData::GetGroupingFunctions() const { - return grouping_functions; -} - -void GroupedAggregateData::InitializeGroupby(vector> groups, - vector> expressions, - vector> grouping_functions) { - InitializeGroupbyGroups(std::move(groups)); - vector payload_types_filters; - - SetGroupingFunctions(grouping_functions); - - filter_count = 0; - for (auto &expr : expressions) { - D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - D_ASSERT(expr->IsAggregate()); - auto &aggr = expr->Cast(); - bindings.push_back(&aggr); - - aggregate_return_types.push_back(aggr.return_type); - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - } - if (aggr.filter) { - filter_count++; - payload_types_filters.push_back(aggr.filter->return_type); - } - if (!aggr.function.combine) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); - } - aggregates.push_back(std::move(expr)); - } - for (const auto &pay_filters : payload_types_filters) { - payload_types.push_back(pay_filters); - } -} - -void GroupedAggregateData::InitializeDistinct(const unique_ptr &aggregate, - const vector> *groups_p) { - auto &aggr = aggregate->Cast(); - D_ASSERT(aggr.IsDistinct()); - - // Add the (empty in ungrouped case) groups of the aggregates - InitializeDistinctGroups(groups_p); - - // bindings.push_back(&aggr); - filter_count = 0; - aggregate_return_types.push_back(aggr.return_type); - for (idx_t i = 0; i < aggr.children.size(); i++) { - auto &child = aggr.children[i]; - group_types.push_back(child->return_type); - groups.push_back(child->Copy()); - payload_types.push_back(child->return_type); - if (aggr.filter) { - filter_count++; - } - } - if (!aggr.function.combine) { - throw InternalException("Aggregate function %s is missing a combine method", aggr.function.name); - } -} - -void GroupedAggregateData::InitializeDistinctGroups(const vector> *groups_p) { - if (!groups_p) { - return; - } - for (auto &expr : *groups_p) { - group_types.push_back(expr->return_type); - groups.push_back(expr->Copy()); - } -} - -void GroupedAggregateData::InitializeGroupbyGroups(vector> groups) { - // Add all the expressions of the group by clause - for (auto &expr : groups) { - group_types.push_back(expr->return_type); - } - this->groups = std::move(groups); -} - -void GroupedAggregateData::SetGroupingFunctions(vector> &functions) { - grouping_functions.reserve(functions.size()); - for (idx_t i = 0; i < functions.size(); i++) { - grouping_functions.push_back(std::move(functions[i])); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp deleted file 
mode 100644 index 6636389ea..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_hash_aggregate.cpp +++ /dev/null @@ -1,928 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" - -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/atomic.hpp" -#include "duckdb/common/optional_idx.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/interrupt.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/task_scheduler.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/parallel/executor_task.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -namespace duckdb { - -HashAggregateGroupingData::HashAggregateGroupingData(GroupingSet &grouping_set_p, - const GroupedAggregateData &grouped_aggregate_data, - unique_ptr &info) - : table_data(grouping_set_p, grouped_aggregate_data) { - if (info) { - distinct_data = make_uniq(*info, grouping_set_p, &grouped_aggregate_data.groups); - } -} - -bool HashAggregateGroupingData::HasDistinct() const { - return distinct_data != nullptr; -} - -HashAggregateGroupingGlobalState::HashAggregateGroupingGlobalState(const HashAggregateGroupingData &data, - ClientContext &context) { - table_state = data.table_data.GetGlobalSinkState(context); - if (data.HasDistinct()) { - distinct_state = make_uniq(*data.distinct_data, context); - } -} - -HashAggregateGroupingLocalState::HashAggregateGroupingLocalState(const PhysicalHashAggregate &op, - const HashAggregateGroupingData &data, - ExecutionContext &context) { - table_state = data.table_data.GetLocalSinkState(context); - if (!data.HasDistinct()) { - return; - } - auto &distinct_data = *data.distinct_data; - - auto &distinct_indices = op.distinct_collection_info->Indices(); - D_ASSERT(!distinct_indices.empty()); - - distinct_states.resize(op.distinct_collection_info->aggregates.size()); - auto &table_map = op.distinct_collection_info->table_map; - - for (auto &idx : distinct_indices) { - idx_t table_idx = table_map[idx]; - auto &radix_table = distinct_data.radix_tables[table_idx]; - if (radix_table == nullptr) { - // This aggregate has identical input as another aggregate, so no table is created for it - continue; - } - // Initialize the states of the radix tables used for the distinct aggregates - distinct_states[table_idx] = radix_table->GetLocalSinkState(context); - } -} - -static vector CreateGroupChunkTypes(vector> &groups) { - set group_indices; - - if (groups.empty()) { - return {}; - } - - for (auto &group : groups) { - D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref = group->Cast(); - group_indices.insert(bound_ref.index); - } - idx_t highest_index = *group_indices.rbegin(); - vector types(highest_index + 1, LogicalType::SQLNULL); - for (auto &group : groups) { - auto &bound_ref = group->Cast(); - types[bound_ref.index] = bound_ref.return_type; - } - return types; -} - -bool PhysicalHashAggregate::CanSkipRegularSink() const { - if (!filter_indexes.empty()) { - // If we have filters, we can't skip the regular sink, because we might 
lose groups otherwise. - return false; - } - if (grouped_aggregate_data.aggregates.empty()) { - // When there are no aggregates, we have to add to the main ht right away - return false; - } - if (!non_distinct_filter.empty()) { - return false; - } - return true; -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, idx_t estimated_cardinality) - : PhysicalHashAggregate(context, std::move(types), std::move(expressions), {}, estimated_cardinality) { -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, - vector> groups_p, idx_t estimated_cardinality) - : PhysicalHashAggregate(context, std::move(types), std::move(expressions), std::move(groups_p), {}, {}, - estimated_cardinality) { -} - -PhysicalHashAggregate::PhysicalHashAggregate(ClientContext &context, vector types, - vector> expressions, - vector> groups_p, - vector grouping_sets_p, - vector> grouping_functions_p, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::HASH_GROUP_BY, std::move(types), estimated_cardinality), - grouping_sets(std::move(grouping_sets_p)) { - // get a list of all aggregates to be computed - const idx_t group_count = groups_p.size(); - if (grouping_sets.empty()) { - GroupingSet set; - for (idx_t i = 0; i < group_count; i++) { - set.insert(i); - } - grouping_sets.push_back(std::move(set)); - } - input_group_types = CreateGroupChunkTypes(groups_p); - - grouped_aggregate_data.InitializeGroupby(std::move(groups_p), std::move(expressions), - std::move(grouping_functions_p)); - - auto &aggregates = grouped_aggregate_data.aggregates; - // filter_indexes must be pre-built, not lazily instantiated in parallel... - // Because everything that lives in this class should be read-only at execution time - idx_t aggregate_input_idx = 0; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); - if (aggr.aggr_type == AggregateType::DISTINCT) { - distinct_filter.push_back(i); - } else if (aggr.aggr_type == AggregateType::NON_DISTINCT) { - non_distinct_filter.push_back(i); - } else { // LCOV_EXCL_START - throw NotImplementedException("AggregateType not implemented in PhysicalHashAggregate"); - } // LCOV_EXCL_STOP - } - - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]; - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - if (!filter_indexes.count(aggr.filter.get())) { - // Replace the bound reference expression's index with the corresponding index of the payload chunk - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx; - } - aggregate_input_idx++; - } - } - - distinct_collection_info = DistinctAggregateCollectionInfo::Create(grouped_aggregate_data.aggregates); - - for (idx_t i = 0; i < grouping_sets.size(); i++) { - groupings.emplace_back(grouping_sets[i], grouped_aggregate_data, distinct_collection_info); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class HashAggregateGlobalSinkState : public GlobalSinkState { -public: - HashAggregateGlobalSinkState(const PhysicalHashAggregate &op, ClientContext &context) { - grouping_states.reserve(op.groupings.size()); - for (idx_t i = 0; i < op.groupings.size(); i++) { - auto &grouping 
= op.groupings[i]; - grouping_states.emplace_back(grouping, context); - } - vector filter_types; - for (auto &aggr : op.grouped_aggregate_data.aggregates) { - auto &aggregate = aggr->Cast(); - for (auto &child : aggregate.children) { - payload_types.push_back(child->return_type); - } - if (aggregate.filter) { - filter_types.push_back(aggregate.filter->return_type); - } - } - payload_types.reserve(payload_types.size() + filter_types.size()); - payload_types.insert(payload_types.end(), filter_types.begin(), filter_types.end()); - } - - vector grouping_states; - vector payload_types; - //! Whether or not the aggregate is finished - bool finished = false; -}; - -class HashAggregateLocalSinkState : public LocalSinkState { -public: - HashAggregateLocalSinkState(const PhysicalHashAggregate &op, ExecutionContext &context) { - - auto &payload_types = op.grouped_aggregate_data.payload_types; - if (!payload_types.empty()) { - aggregate_input_chunk.InitializeEmpty(payload_types); - } - - grouping_states.reserve(op.groupings.size()); - for (auto &grouping : op.groupings) { - grouping_states.emplace_back(op, grouping, context); - } - // The filter set is only needed here for the distinct aggregates - // the filtering of data for the regular aggregates is done within the hashtable - vector aggregate_objects; - for (auto &aggregate : op.grouped_aggregate_data.aggregates) { - auto &aggr = aggregate->Cast(); - aggregate_objects.emplace_back(&aggr); - } - - filter_set.Initialize(context.client, aggregate_objects, payload_types); - } - - DataChunk aggregate_input_chunk; - vector grouping_states; - AggregateFilterDataSet filter_set; -}; - -void PhysicalHashAggregate::SetMultiScan(GlobalSinkState &state) { - auto &gstate = state.Cast(); - for (auto &grouping_state : gstate.grouping_states) { - RadixPartitionedHashTable::SetMultiScan(*grouping_state.table_state); - if (!grouping_state.distinct_state) { - continue; - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -unique_ptr PhysicalHashAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashAggregate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context); -} - -void PhysicalHashAggregate::SinkDistinctGrouping(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, - idx_t grouping_idx) const { - auto &sink = input.local_state.Cast(); - auto &global_sink = input.global_state.Cast(); - - auto &grouping_gstate = global_sink.grouping_states[grouping_idx]; - auto &grouping_lstate = sink.grouping_states[grouping_idx]; - auto &distinct_info = *distinct_collection_info; - - auto &distinct_state = grouping_gstate.distinct_state; - auto &distinct_data = groupings[grouping_idx].distinct_data; - - DataChunk empty_chunk; - - // Create an empty filter for Sink, since we don't need to update any aggregate states here - unsafe_vector empty_filter; - - for (idx_t &idx : distinct_info.indices) { - auto &aggregate = grouped_aggregate_data.aggregates[idx]->Cast(); - - D_ASSERT(distinct_info.table_map.count(idx)); - idx_t table_idx = distinct_info.table_map[idx]; - if (!distinct_data->radix_tables[table_idx]) { - continue; - } - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = 
*grouping_lstate.distinct_states[table_idx]; - - InterruptState interrupt_state; - OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, interrupt_state}; - - if (aggregate.filter) { - DataChunk filter_chunk; - auto &filtered_data = sink.filter_set.GetFilterData(idx); - filter_chunk.InitializeEmpty(filtered_data.filtered_payload.GetTypes()); - - // Add the filter Vector (BOOL) - auto it = filter_indexes.find(aggregate.filter.get()); - D_ASSERT(it != filter_indexes.end()); - D_ASSERT(it->second < chunk.data.size()); - auto &filter_bound_ref = aggregate.filter->Cast<BoundReferenceExpression>(); - filter_chunk.data[filter_bound_ref.index].Reference(chunk.data[it->second]); - filter_chunk.SetCardinality(chunk.size()); - - // We can't use the AggregateFilterData::ApplyFilter method, because the chunk we need to - // apply the filter to also has the groups, and the filtered_data.filtered_payload does not have those. - SelectionVector sel_vec(STANDARD_VECTOR_SIZE); - idx_t count = filtered_data.filter_executor.SelectExpression(filter_chunk, sel_vec); - - if (count == 0) { - continue; - } - - // Because the 'input' chunk needs to be re-used after this, we need to create - // a duplicate of it that we can apply the filter to - DataChunk filtered_input; - filtered_input.InitializeEmpty(chunk.GetTypes()); - - for (idx_t group_idx = 0; group_idx < grouped_aggregate_data.groups.size(); group_idx++) { - auto &group = grouped_aggregate_data.groups[group_idx]; - auto &bound_ref = group->Cast<BoundReferenceExpression>(); - auto &col = filtered_input.data[bound_ref.index]; - col.Reference(chunk.data[bound_ref.index]); - col.Slice(sel_vec, count); - } - for (idx_t child_idx = 0; child_idx < aggregate.children.size(); child_idx++) { - auto &child = aggregate.children[child_idx]; - auto &bound_ref = child->Cast<BoundReferenceExpression>(); - auto &col = filtered_input.data[bound_ref.index]; - col.Reference(chunk.data[bound_ref.index]); - col.Slice(sel_vec, count); - } - filtered_input.SetCardinality(count); - - radix_table.Sink(context, filtered_input, sink_input, empty_chunk, empty_filter); - } else { - radix_table.Sink(context, chunk, sink_input, empty_chunk, empty_filter); - } - } -} - -void PhysicalHashAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - for (idx_t i = 0; i < groupings.size(); i++) { - SinkDistinctGrouping(context, chunk, input, i); - } -} - -SinkResultType PhysicalHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &local_state = input.local_state.Cast<HashAggregateLocalSinkState>(); - auto &global_state = input.global_state.Cast<HashAggregateGlobalSinkState>(); - - if (distinct_collection_info) { - SinkDistinct(context, chunk, input); - } - - if (CanSkipRegularSink()) { - return SinkResultType::NEED_MORE_INPUT; - } - - DataChunk &aggregate_input_chunk = local_state.aggregate_input_chunk; - auto &aggregates = grouped_aggregate_data.aggregates; - idx_t aggregate_input_idx = 0; - - // Populate the aggregate child vectors - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast<BoundAggregateExpression>(); - for (auto &child_expr : aggr.children) { - D_ASSERT(child_expr->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref_expr = child_expr->Cast<BoundReferenceExpression>(); - D_ASSERT(bound_ref_expr.index < chunk.data.size()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); - } - }
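Both population loops in Sink rely on Vector reference semantics: the dense aggregate input chunk never copies data, its columns simply point into the columns of the probe chunk, in the order dictated by the aggregates. A zero-copy remapping sketch with hypothetical types (MiniChunk, BuildPayload):

    #include <cstddef>
    #include <vector>

    // A payload "chunk" that borrows its columns from the input, no copying.
    struct MiniChunk {
        std::vector<const std::vector<double> *> columns;
    };

    // Payload column k is a reference to input column mapping[k].
    MiniChunk BuildPayload(const std::vector<std::vector<double>> &input,
                           const std::vector<size_t> &mapping) {
        MiniChunk payload;
        payload.columns.reserve(mapping.size());
        for (auto col_idx : mapping) {
            payload.columns.push_back(&input[col_idx]); // reference, not a copy
        }
        return payload;
    }

- // Populate the filter vectors - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast<BoundAggregateExpression>(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); - D_ASSERT(it != 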
filter_indexes.end()); - D_ASSERT(it->second < chunk.data.size()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); - } - } - - aggregate_input_chunk.SetCardinality(chunk.size()); - aggregate_input_chunk.Verify(); - - // For every grouping set there is one radix_table - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_global_state = global_state.grouping_states[i]; - auto &grouping_local_state = local_state.grouping_states[i]; - InterruptState interrupt_state; - OperatorSinkInput sink_input {*grouping_global_state.table_state, *grouping_local_state.table_state, - interrupt_state}; - - auto &grouping = groupings[i]; - auto &table = grouping.table_data; - table.Sink(context, chunk, sink_input, aggregate_input_chunk, non_distinct_filter); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void PhysicalHashAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { - - auto &global_sink = input.global_state.Cast(); - auto &sink = input.local_state.Cast(); - - if (!distinct_collection_info) { - return; - } - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_gstate = global_sink.grouping_states[i]; - auto &grouping_lstate = sink.grouping_states[i]; - - auto &distinct_data = groupings[i].distinct_data; - auto &distinct_state = grouping_gstate.distinct_state; - - const auto table_count = distinct_data->radix_tables.size(); - for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { - if (!distinct_data->radix_tables[table_idx]) { - continue; - } - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = *grouping_lstate.distinct_states[table_idx]; - - radix_table.Combine(context, radix_global_sink, radix_local_sink); - } - } -} - -SinkCombineResultType PhysicalHashAggregate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &llstate = input.local_state.Cast(); - - OperatorSinkCombineInput combine_distinct_input {gstate, llstate, input.interrupt_state}; - CombineDistinct(context, combine_distinct_input); - - if (CanSkipRegularSink()) { - return SinkCombineResultType::FINISHED; - } - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping_gstate = gstate.grouping_states[i]; - auto &grouping_lstate = llstate.grouping_states[i]; - - auto &grouping = groupings[i]; - auto &table = grouping.table_data; - table.Combine(context, *grouping_gstate.table_state, *grouping_lstate.table_state); - } - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -class HashAggregateFinalizeEvent : public BasePipelineEvent { -public: - //! 
"Regular" Finalize Event that is scheduled after combining the thread-local distinct HTs - HashAggregateFinalizeEvent(ClientContext &context, Pipeline *pipeline_p, const PhysicalHashAggregate &op_p, - HashAggregateGlobalSinkState &gstate_p) - : BasePipelineEvent(*pipeline_p), context(context), op(op_p), gstate(gstate_p) { - } - -public: - void Schedule() override; - -private: - ClientContext &context; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; -}; - -class HashAggregateFinalizeTask : public ExecutorTask { -public: - HashAggregateFinalizeTask(ClientContext &context, Pipeline &pipeline, shared_ptr event_p, - const PhysicalHashAggregate &op, HashAggregateGlobalSinkState &state_p) - : ExecutorTask(pipeline.executor, std::move(event_p)), context(context), pipeline(pipeline), op(op), - gstate(state_p) { - } - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - ClientContext &context; - Pipeline &pipeline; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; -}; - -void HashAggregateFinalizeEvent::Schedule() { - vector> tasks; - tasks.push_back(make_uniq(context, *pipeline, shared_from_this(), op, gstate)); - D_ASSERT(!tasks.empty()); - SetTasks(std::move(tasks)); -} - -TaskExecutionResult HashAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - op.FinalizeInternal(pipeline, *event, context, gstate, false); - D_ASSERT(!gstate.finished); - gstate.finished = true; - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -class HashAggregateDistinctFinalizeEvent : public BasePipelineEvent { -public: - //! Distinct Finalize Event that is scheduled if we have distinct aggregates - HashAggregateDistinctFinalizeEvent(ClientContext &context, Pipeline &pipeline_p, const PhysicalHashAggregate &op_p, - HashAggregateGlobalSinkState &gstate_p) - : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p) { - } - -public: - void Schedule() override; - void FinishEvent() override; - -private: - idx_t CreateGlobalSources(); - -private: - ClientContext &context; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; - -public: - //! 
The GlobalSourceStates for all the radix tables of the distinct aggregates - vector>> global_source_states; -}; - -class HashAggregateDistinctFinalizeTask : public ExecutorTask { -public: - HashAggregateDistinctFinalizeTask(Pipeline &pipeline, shared_ptr event_p, const PhysicalHashAggregate &op, - HashAggregateGlobalSinkState &state_p) - : ExecutorTask(pipeline.executor, std::move(event_p)), pipeline(pipeline), op(op), gstate(state_p) { - } - -public: - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - TaskExecutionResult AggregateDistinctGrouping(const idx_t grouping_idx); - -private: - Pipeline &pipeline; - - const PhysicalHashAggregate &op; - HashAggregateGlobalSinkState &gstate; - - unique_ptr local_sink_state; - idx_t grouping_idx = 0; - unique_ptr radix_table_lstate; - bool blocked = false; - idx_t aggregation_idx = 0; - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; -}; - -void HashAggregateDistinctFinalizeEvent::Schedule() { - auto n_tasks = CreateGlobalSources(); - n_tasks = MinValue(n_tasks, NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())); - vector> tasks; - for (idx_t i = 0; i < n_tasks; i++) { - tasks.push_back(make_uniq(*pipeline, shared_from_this(), op, gstate)); - } - SetTasks(std::move(tasks)); -} - -idx_t HashAggregateDistinctFinalizeEvent::CreateGlobalSources() { - auto &aggregates = op.grouped_aggregate_data.aggregates; - global_source_states.reserve(op.groupings.size()); - - idx_t n_tasks = 0; - for (idx_t grouping_idx = 0; grouping_idx < op.groupings.size(); grouping_idx++) { - auto &grouping = op.groupings[grouping_idx]; - auto &distinct_state = *gstate.grouping_states[grouping_idx].distinct_state; - auto &distinct_data = *grouping.distinct_data; - - vector> aggregate_sources; - aggregate_sources.reserve(aggregates.size()); - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]; - auto &aggr = aggregate->Cast(); - - if (!aggr.IsDistinct()) { - aggregate_sources.push_back(nullptr); - continue; - } - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - - auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table_p = distinct_data.radix_tables[table_idx]; - n_tasks += radix_table_p->MaxThreads(*distinct_state.radix_states[table_idx]); - aggregate_sources.push_back(radix_table_p->GetGlobalSourceState(context)); - } - global_source_states.push_back(std::move(aggregate_sources)); - } - - return MaxValue(n_tasks, 1); -} - -void HashAggregateDistinctFinalizeEvent::FinishEvent() { - // Now that everything is added to the main ht, we can actually finalize - auto new_event = make_shared_ptr(context, pipeline.get(), op, gstate); - this->InsertEvent(std::move(new_event)); -} - -TaskExecutionResult HashAggregateDistinctFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - for (; grouping_idx < op.groupings.size(); grouping_idx++) { - auto res = AggregateDistinctGrouping(grouping_idx); - if (res == TaskExecutionResult::TASK_BLOCKED) { - return res; - } - D_ASSERT(res == TaskExecutionResult::TASK_FINISHED); - aggregation_idx = 0; - payload_idx = 0; - next_payload_idx = 0; - local_sink_state = nullptr; - } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -TaskExecutionResult HashAggregateDistinctFinalizeTask::AggregateDistinctGrouping(const idx_t grouping_idx) { - D_ASSERT(op.distinct_collection_info); - auto &info = *op.distinct_collection_info; - - auto &grouping_data = op.groupings[grouping_idx]; - auto &grouping_state = 
gstate.grouping_states[grouping_idx]; - D_ASSERT(grouping_state.distinct_state); - auto &distinct_state = *grouping_state.distinct_state; - auto &distinct_data = *grouping_data.distinct_data; - - auto &aggregates = info.aggregates; - - // Thread-local contexts - ThreadContext thread_context(executor.context); - ExecutionContext execution_context(executor.context, thread_context, &pipeline); - - // Sink state to sink into global HTs - InterruptState interrupt_state(shared_from_this()); - auto &global_sink_state = *grouping_state.table_state; - if (!local_sink_state) { - local_sink_state = grouping_data.table_data.GetLocalSinkState(execution_context); - } - OperatorSinkInput sink_input {global_sink_state, *local_sink_state, interrupt_state}; - - // Create a chunk that mimics the 'input' chunk in Sink, for storing the group vectors - DataChunk group_chunk; - if (!op.input_group_types.empty()) { - group_chunk.Initialize(executor.context, op.input_group_types); - } - - const idx_t group_by_size = op.grouped_aggregate_data.groups.size(); - - DataChunk aggregate_input_chunk; - if (!gstate.payload_types.empty()) { - aggregate_input_chunk.Initialize(executor.context, gstate.payload_types); - } - - const auto &finalize_event = event->Cast<HashAggregateDistinctFinalizeEvent>(); - - auto &agg_idx = aggregation_idx; - for (; agg_idx < op.grouped_aggregate_data.aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast<BoundAggregateExpression>(); - - if (!blocked) { - // Forward the payload idx - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - } - - // If the aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - const auto &table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table = distinct_data.radix_tables[table_idx]; - - auto &sink = *distinct_state.radix_states[table_idx]; - if (!blocked) { - radix_table_lstate = radix_table->GetLocalSourceState(execution_context); - } - auto &local_source = *radix_table_lstate; - OperatorSourceInput source_input {*finalize_event.global_source_states[grouping_idx][agg_idx], local_source, - interrupt_state}; - - // Create a duplicate of the output_chunk, because of multi-threading we can't alter the original - DataChunk output_chunk; - output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); - - // Fetch all the data from the aggregate ht, and Sink it into the main ht - while (true) { - output_chunk.Reset(); - group_chunk.Reset(); - aggregate_input_chunk.Reset(); - - auto res = radix_table->GetData(execution_context, output_chunk, sink, source_input); - if (res == SourceResultType::FINISHED) { - D_ASSERT(output_chunk.size() == 0); - break; - } else if (res == SourceResultType::BLOCKED) { - blocked = true; - return TaskExecutionResult::TASK_BLOCKED; - } - - auto &grouped_aggregate_data = *distinct_data.grouped_aggregate_data[table_idx]; - for (idx_t group_idx = 0; group_idx < group_by_size; group_idx++) { - auto &group = grouped_aggregate_data.groups[group_idx]; - auto &bound_ref_expr = group->Cast<BoundReferenceExpression>(); - group_chunk.data[bound_ref_expr.index].Reference(output_chunk.data[group_idx]); - } - group_chunk.SetCardinality(output_chunk); - - for (idx_t child_idx = 0; child_idx < grouped_aggregate_data.groups.size() - group_by_size; child_idx++) { - aggregate_input_chunk.data[payload_idx + child_idx].Reference( - output_chunk.data[group_by_size + child_idx]); - } -
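- // (The distinct HT scan returns [group columns..., argument columns...]; the loop above slices the - // argument columns back into this aggregate's payload slots before they are re-sunk into the main HT below.)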
aggregate_input_chunk.SetCardinality(output_chunk); - - // Sink it into the main ht - grouping_data.table_data.Sink(execution_context, group_chunk, sink_input, aggregate_input_chunk, {agg_idx}); - } - blocked = false; - } - grouping_data.table_data.Combine(execution_context, global_sink_state, *local_sink_state); - return TaskExecutionResult::TASK_FINISHED; -} - -SinkFinalizeType PhysicalHashAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - D_ASSERT(distinct_collection_info); - - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping = groupings[i]; - auto &distinct_data = *grouping.distinct_data; - auto &distinct_state = *gstate.grouping_states[i].distinct_state; - - for (idx_t table_idx = 0; table_idx < distinct_data.radix_tables.size(); table_idx++) { - if (!distinct_data.radix_tables[table_idx]) { - continue; - } - auto &radix_table = distinct_data.radix_tables[table_idx]; - auto &radix_state = *distinct_state.radix_states[table_idx]; - radix_table->Finalize(context, radix_state); - } - } - auto new_event = make_shared_ptr(context, pipeline, *this, gstate); - event.InsertEvent(std::move(new_event)); - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalHashAggregate::FinalizeInternal(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p, bool check_distinct) const { - auto &gstate = gstate_p.Cast(); - - if (check_distinct && distinct_collection_info) { - // There are distinct aggregates - // If these are partitioned those need to be combined first - // Then we Finalize again, skipping this step - return FinalizeDistinct(pipeline, event, context, gstate_p); - } - - for (idx_t i = 0; i < groupings.size(); i++) { - auto &grouping = groupings[i]; - auto &grouping_gstate = gstate.grouping_states[i]; - grouping.table_data.Finalize(context, *grouping_gstate.table_state); - } - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalHashAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - return FinalizeInternal(pipeline, event, context, input.global_state, true); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class HashAggregateGlobalSourceState : public GlobalSourceState { -public: - HashAggregateGlobalSourceState(ClientContext &context, const PhysicalHashAggregate &op) : op(op), state_index(0) { - for (auto &grouping : op.groupings) { - auto &rt = grouping.table_data; - radix_states.push_back(rt.GetGlobalSourceState(context)); - } - } - - const PhysicalHashAggregate &op; - atomic state_index; - - vector> radix_states; - -public: - idx_t MaxThreads() override { - // If there are no tables, we only need one thread. 
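- // Otherwise the answer is the sum of what each grouping set's radix table can scan in parallel, - // clamped below to at least one thread.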
- if (op.groupings.empty()) { - return 1; - } - - auto &ht_state = op.sink_state->Cast<HashAggregateGlobalSinkState>(); - idx_t threads = 0; - for (size_t sidx = 0; sidx < op.groupings.size(); ++sidx) { - auto &grouping = op.groupings[sidx]; - auto &grouping_gstate = ht_state.grouping_states[sidx]; - threads += grouping.table_data.MaxThreads(*grouping_gstate.table_state); - } - return MaxValue<idx_t>(1, threads); - } -}; - -unique_ptr<GlobalSourceState> PhysicalHashAggregate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq<HashAggregateGlobalSourceState>(context, *this); -} - -class HashAggregateLocalSourceState : public LocalSourceState { -public: - explicit HashAggregateLocalSourceState(ExecutionContext &context, const PhysicalHashAggregate &op) { - for (auto &grouping : op.groupings) { - auto &rt = grouping.table_data; - radix_states.push_back(rt.GetLocalSourceState(context)); - } - } - - optional_idx radix_idx; - vector<unique_ptr<LocalSourceState>> radix_states; -}; - -unique_ptr<LocalSourceState> PhysicalHashAggregate::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq<HashAggregateLocalSourceState>(context, *this); -} - -SourceResultType PhysicalHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink_gstate = sink_state->Cast<HashAggregateGlobalSinkState>(); - auto &gstate = input.global_state.Cast<HashAggregateGlobalSourceState>(); - auto &lstate = input.local_state.Cast<HashAggregateLocalSourceState>(); - while (true) { - if (!lstate.radix_idx.IsValid()) { - lstate.radix_idx = gstate.state_index.load(); - } - const auto radix_idx = lstate.radix_idx.GetIndex(); - if (radix_idx >= groupings.size()) { - break; - } - - auto &grouping = groupings[radix_idx]; - auto &radix_table = grouping.table_data; - auto &grouping_gstate = sink_gstate.grouping_states[radix_idx]; - - OperatorSourceInput source_input {*gstate.radix_states[radix_idx], *lstate.radix_states[radix_idx], - input.interrupt_state}; - auto res = radix_table.GetData(context, chunk, *grouping_gstate.table_state, source_input); - if (res == SourceResultType::BLOCKED) { - return res; - } - if (chunk.size() != 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } - - // move to the next table - auto guard = gstate.Lock(); - lstate.radix_idx = lstate.radix_idx.GetIndex() + 1; - if (lstate.radix_idx.GetIndex() > gstate.state_index) { - // we have not yet worked on the table - // move the global index forwards - gstate.state_index = lstate.radix_idx.GetIndex(); - } - lstate.radix_idx = gstate.state_index.load(); - } - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -ProgressData PhysicalHashAggregate::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &sink_gstate = sink_state->Cast(); - auto &gstate = gstate_p.Cast(); - ProgressData progress; - for (idx_t radix_idx = 0; radix_idx < groupings.size(); radix_idx++) { - progress.Add(groupings[radix_idx].table_data.GetProgress( - context, *sink_gstate.grouping_states[radix_idx].table_state, *gstate.radix_states[radix_idx])); - } - return progress; -} - -InsertionOrderPreservingMap PhysicalHashAggregate::ParamsToString() const { - InsertionOrderPreservingMap result; - auto &groups = grouped_aggregate_data.groups; - auto &aggregates = grouped_aggregate_data.aggregates; - string groups_info; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - groups_info += "\n"; - } - groups_info += groups[i]->GetName(); - } - result["Groups"] = groups_info; - - string aggregate_info; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]->Cast(); - if (i > 0) { - aggregate_info += "\n"; - } - aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); - } - } - result["Aggregates"] = aggregate_info; - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp deleted file mode 100644 index 32bf4ecc2..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_partitioned_aggregate.cpp +++ /dev/null @@ -1,226 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp" -#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" -#include "duckdb/common/types/value_map.hpp" - -namespace duckdb { - -PhysicalPartitionedAggregate::PhysicalPartitionedAggregate(ClientContext &context, vector types, - vector> aggregates_p, - vector> groups_p, - vector partitions_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::PARTITIONED_AGGREGATE, std::move(types), estimated_cardinality), - partitions(std::move(partitions_p)), groups(std::move(groups_p)), aggregates(std::move(aggregates_p)) { -} - -OperatorPartitionInfo PhysicalPartitionedAggregate::RequiredPartitionInfo() const { - return OperatorPartitionInfo::PartitionColumns(partitions); -} -//===--------------------------------------------------------------------===// -// Global State -//===--------------------------------------------------------------------===// -class PartitionedAggregateLocalSinkState : public LocalSinkState { -public: - PartitionedAggregateLocalSinkState(const PhysicalPartitionedAggregate &op, const vector &child_types, - ExecutionContext &context) - : execute_state(context.client, op.aggregates, child_types) { - } - - //! The current partition - Value current_partition; - //! The local aggregate state for the current partition - unique_ptr state; - //! The ungrouped aggregate execute state - UngroupedAggregateExecuteState execute_state; -}; - -class PartitionedAggregateGlobalSinkState : public GlobalSinkState { -public: - PartitionedAggregateGlobalSinkState(const PhysicalPartitionedAggregate &op, ClientContext &context) - : op(op), aggregate_result(BufferAllocator::Get(context), op.types) { - } - - mutex lock; - const PhysicalPartitionedAggregate &op; - //! 
The per-partition aggregate states - value_map_t> aggregate_states; - //! Final aggregate result - ColumnDataCollection aggregate_result; - - GlobalUngroupedAggregateState &GetOrCreatePartition(ClientContext &context, const Value &partition) { - lock_guard l(lock); - // find the state that corresponds to this partition and combine - auto entry = aggregate_states.find(partition); - if (entry != aggregate_states.end()) { - return *entry->second; - } - // no state yet for this partition - allocate a new one - auto new_global_state = make_uniq(BufferAllocator::Get(context), op.aggregates); - auto &result = *new_global_state; - aggregate_states.insert(make_pair(partition, std::move(new_global_state))); - return result; - } - - void Combine(ClientContext &context, PartitionedAggregateLocalSinkState &lstate) { - if (!lstate.state) { - // no aggregate state - return; - } - auto &global_state = GetOrCreatePartition(context, lstate.current_partition); - global_state.Combine(*lstate.state); - // clear the local aggregate state - lstate.state.reset(); - } -}; - -unique_ptr PhysicalPartitionedAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -//===--------------------------------------------------------------------===// -// Local State -//===--------------------------------------------------------------------===// - -unique_ptr PhysicalPartitionedAggregate::GetLocalSinkState(ExecutionContext &context) const { - D_ASSERT(sink_state); - return make_uniq(*this, children[0]->GetTypes(), context); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalPartitionedAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (!lstate.state) { - // the local state is not yet initialized for this partition - // initialize the partition - child_list_t partition_values; - for (idx_t partition_idx = 0; partition_idx < groups.size(); partition_idx++) { - auto column_name = to_string(partition_idx); - auto &partition = input.local_state.partition_info.partition_data[partition_idx]; - D_ASSERT(Value::NotDistinctFrom(partition.min_val, partition.max_val)); - partition_values.emplace_back(make_pair(std::move(column_name), partition.min_val)); - } - lstate.current_partition = Value::STRUCT(std::move(partition_values)); - - // initialize the state - auto &global_aggregate_state = gstate.GetOrCreatePartition(context.client, lstate.current_partition); - lstate.state = make_uniq(global_aggregate_state); - } - - // perform the aggregation - lstate.execute_state.Sink(*lstate.state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Next Batch -//===--------------------------------------------------------------------===// -SinkNextBatchType PhysicalPartitionedAggregate::NextBatch(ExecutionContext &context, - OperatorSinkNextBatchInput &input) const { - // flush the local state - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // finalize and reset the current state (if any) - gstate.Combine(context.client, lstate); - return SinkNextBatchType::READY; -} - -//===--------------------------------------------------------------------===// -// Combine 
-//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalPartitionedAggregate::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast<PartitionedAggregateGlobalSinkState>(); - auto &lstate = input.local_state.Cast<PartitionedAggregateLocalSinkState>(); - gstate.Combine(context.client, lstate); - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPartitionedAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast<PartitionedAggregateGlobalSinkState>(); - ColumnDataAppendState append_state; - gstate.aggregate_result.InitializeAppend(append_state); - // finalize each of the partitions and append to a ColumnDataCollection - DataChunk chunk; - chunk.Initialize(context, types); - for (auto &entry : gstate.aggregate_states) { - chunk.Reset(); - // reference the partitions - auto &partitions = StructValue::GetChildren(entry.first); - for (idx_t partition_idx = 0; partition_idx < partitions.size(); partition_idx++) { - chunk.data[partition_idx].Reference(partitions[partition_idx]); - } - // finalize the aggregates - entry.second->Finalize(chunk, partitions.size()); - - // append to the CDC - gstate.aggregate_result.Append(append_state, chunk); - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PartitionedAggregateGlobalSourceState : public GlobalSourceState { -public: - explicit PartitionedAggregateGlobalSourceState(PartitionedAggregateGlobalSinkState &gstate) { - gstate.aggregate_result.InitializeScan(scan_state); - } - - ColumnDataScanState scan_state; - - idx_t MaxThreads() override { - return 1; - } -}; - -unique_ptr<GlobalSourceState> PhysicalPartitionedAggregate::GetGlobalSourceState(ClientContext &context) const { - auto &gstate = sink_state->Cast<PartitionedAggregateGlobalSinkState>(); - return make_uniq<PartitionedAggregateGlobalSourceState>(gstate); -} - -SourceResultType PhysicalPartitionedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast<PartitionedAggregateGlobalSinkState>(); - auto &gsource = input.global_state.Cast<PartitionedAggregateGlobalSourceState>(); - gstate.aggregate_result.Scan(gsource.scan_state, chunk); - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// ParamsToString -//===--------------------------------------------------------------------===// -InsertionOrderPreservingMap PhysicalPartitionedAggregate::ParamsToString() const { - InsertionOrderPreservingMap result; - string groups_info; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - groups_info += "\n"; - } - groups_info += groups[i]->GetName(); - } - result["Groups"] = groups_info; - string aggregate_info; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]->Cast(); - if (i > 0) { - aggregate_info += "\n"; - } - aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); - } - } - result["Aggregates"] = aggregate_info; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp deleted file mode 100644 index d00b6fd9e..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_perfecthash_aggregate.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp" - -#include "duckdb/execution/perfect_aggregate_hashtable.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -PhysicalPerfectHashAggregate::PhysicalPerfectHashAggregate(ClientContext &context, vector types_p, - vector> aggregates_p, - vector> groups_p, - const vector> &group_stats, - vector required_bits_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::PERFECT_HASH_GROUP_BY, std::move(types_p), estimated_cardinality), - groups(std::move(groups_p)), aggregates(std::move(aggregates_p)), required_bits(std::move(required_bits_p)) { - D_ASSERT(groups.size() == group_stats.size()); - group_minima.reserve(group_stats.size()); - for (auto &stats : group_stats) { - D_ASSERT(stats); - auto &nstats = *stats; - D_ASSERT(NumericStats::HasMin(nstats)); - group_minima.push_back(NumericStats::Min(nstats)); - } - for (auto &expr : groups) { - group_types.push_back(expr->return_type); - } - - vector bindings; - vector payload_types_filters; - for (auto &expr : aggregates) { - D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - D_ASSERT(expr->IsAggregate()); - auto &aggr = expr->Cast(); - bindings.push_back(&aggr); - - D_ASSERT(!aggr.IsDistinct()); - D_ASSERT(aggr.function.combine); - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - } - if (aggr.filter) { - payload_types_filters.push_back(aggr.filter->return_type); - } - } - for (const auto &pay_filters : payload_types_filters) { - payload_types.push_back(pay_filters); - } - aggregate_objects = AggregateObject::CreateAggregateObjects(bindings); - - // filter_indexes must be pre-built, not lazily instantiated in parallel... 
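- // A sketch of the resulting layout (hypothetical aggregates, not from this diff): for SUM(a), AVG(b) FILTER (WHERE f), - // the combined aggregate input is [a, b, f]; the first loop below reserves indexes 0..1 for the aggregate children, - // and the second loop rebinds the filter's BoundReferenceExpression to index 2, so every thread sees the same - // precomputed offsets instead of racing to assign them lazily.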
- idx_t aggregate_input_idx = 0; - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - aggregate_input_idx += aggr.children.size(); - } - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto &bound_ref_expr = aggr.filter->Cast(); - auto it = filter_indexes.find(aggr.filter.get()); - if (it == filter_indexes.end()) { - filter_indexes[aggr.filter.get()] = bound_ref_expr.index; - bound_ref_expr.index = aggregate_input_idx++; - } else { - ++aggregate_input_idx; - } - } - } -} - -unique_ptr PhysicalPerfectHashAggregate::CreateHT(Allocator &allocator, - ClientContext &context) const { - return make_uniq(context, allocator, group_types, payload_types, aggregate_objects, - group_minima, required_bits); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class PerfectHashAggregateGlobalState : public GlobalSinkState { -public: - PerfectHashAggregateGlobalState(const PhysicalPerfectHashAggregate &op, ClientContext &context) - : ht(op.CreateHT(Allocator::Get(context), context)) { - } - - //! The lock for updating the global aggregate state - mutex lock; - //! The global aggregate hash table - unique_ptr ht; -}; - -class PerfectHashAggregateLocalState : public LocalSinkState { -public: - PerfectHashAggregateLocalState(const PhysicalPerfectHashAggregate &op, ExecutionContext &context) - : ht(op.CreateHT(Allocator::Get(context.client), context.client)) { - group_chunk.InitializeEmpty(op.group_types); - if (!op.payload_types.empty()) { - aggregate_input_chunk.InitializeEmpty(op.payload_types); - } - } - - //! The local aggregate hash table - unique_ptr ht; - DataChunk group_chunk; - DataChunk aggregate_input_chunk; -}; - -unique_ptr PhysicalPerfectHashAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalPerfectHashAggregate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*this, context); -} - -SinkResultType PhysicalPerfectHashAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - DataChunk &group_chunk = lstate.group_chunk; - DataChunk &aggregate_input_chunk = lstate.aggregate_input_chunk; - - for (idx_t group_idx = 0; group_idx < groups.size(); group_idx++) { - auto &group = groups[group_idx]; - D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref_expr = group->Cast(); - group_chunk.data[group_idx].Reference(chunk.data[bound_ref_expr.index]); - } - idx_t aggregate_input_idx = 0; - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - for (auto &child_expr : aggr.children) { - D_ASSERT(child_expr->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref_expr = child_expr->Cast(); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[bound_ref_expr.index]); - } - } - for (auto &aggregate : aggregates) { - auto &aggr = aggregate->Cast(); - if (aggr.filter) { - auto it = filter_indexes.find(aggr.filter.get()); - D_ASSERT(it != filter_indexes.end()); - aggregate_input_chunk.data[aggregate_input_idx++].Reference(chunk.data[it->second]); - } - } - - group_chunk.SetCardinality(chunk.size()); - - aggregate_input_chunk.SetCardinality(chunk.size()); - - group_chunk.Verify(); - aggregate_input_chunk.Verify(); - D_ASSERT(aggregate_input_chunk.ColumnCount() == 0 || 
group_chunk.size() == aggregate_input_chunk.size()); - - lstate.ht->AddChunk(group_chunk, aggregate_input_chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalPerfectHashAggregate::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - - lock_guard l(gstate.lock); - gstate.ht->Combine(*lstate.ht); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PerfectHashAggregateState : public GlobalSourceState { -public: - PerfectHashAggregateState() : ht_scan_position(0) { - } - - //! The current position to scan the HT for output tuples - idx_t ht_scan_position; -}; - -unique_ptr PhysicalPerfectHashAggregate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalPerfectHashAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &gstate = sink_state->Cast(); - - gstate.ht->Scan(state.ht_scan_position, chunk); - - if (chunk.size() > 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else { - return SourceResultType::FINISHED; - } -} - -InsertionOrderPreservingMap PhysicalPerfectHashAggregate::ParamsToString() const { - InsertionOrderPreservingMap result; - string groups_info; - for (idx_t i = 0; i < groups.size(); i++) { - if (i > 0) { - groups_info += "\n"; - } - groups_info += groups[i]->GetName(); - } - result["Groups"] = groups_info; - - string aggregate_info; - for (idx_t i = 0; i < aggregates.size(); i++) { - if (i > 0) { - aggregate_info += "\n"; - } - aggregate_info += aggregates[i]->GetName(); - auto &aggregate = aggregates[i]->Cast(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); - } - } - result["Aggregates"] = aggregate_info; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp deleted file mode 100644 index 8d78e6653..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_streaming_window.cpp +++ /dev/null @@ -1,647 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" - -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/aggregate_function.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -PhysicalStreamingWindow::PhysicalStreamingWindow(vector types, vector> select_list, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { -} - -class StreamingWindowGlobalState : public GlobalOperatorState { -public: - StreamingWindowGlobalState() : row_number(1) { - } - - //! The next row number. 
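- //! (Atomic, presumably so that parallel Execute calls can each claim a disjoint range of row numbers; - //! the WINDOW_ROW_NUMBER case in ExecuteFunctions reads it once per chunk and then advances it by the chunk size.)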
- std::atomic row_number; -}; - -class StreamingWindowState : public OperatorState { -public: - struct AggregateState { - AggregateState(ClientContext &client, BoundWindowExpression &wexpr, Allocator &allocator) - : wexpr(wexpr), arena_allocator(Allocator::DefaultAllocator()), executor(client), filter_executor(client), - statev(LogicalType::POINTER, data_ptr_cast(&state_ptr)), hashes(LogicalType::HASH), - addresses(LogicalType::POINTER) { - D_ASSERT(wexpr.GetExpressionType() == ExpressionType::WINDOW_AGGREGATE); - auto &aggregate = *wexpr.aggregate; - bind_data = wexpr.bind_info.get(); - dtor = aggregate.destructor; - state.resize(aggregate.state_size(aggregate)); - state_ptr = state.data(); - aggregate.initialize(aggregate, state.data()); - for (auto &child : wexpr.children) { - arg_types.push_back(child->return_type); - executor.AddExpression(*child); - } - if (!arg_types.empty()) { - arg_chunk.Initialize(allocator, arg_types); - arg_cursor.Initialize(allocator, arg_types); - } - if (wexpr.filter_expr) { - filter_executor.AddExpression(*wexpr.filter_expr); - filter_sel.Initialize(); - } - if (wexpr.distinct) { - distinct = make_uniq(client, allocator, arg_types); - distinct_args.Initialize(allocator, arg_types); - distinct_sel.Initialize(); - } - } - - ~AggregateState() { - if (dtor) { - AggregateInputData aggr_input_data(bind_data, arena_allocator); - state_ptr = state.data(); - dtor(statev, aggr_input_data, 1); - } - } - - void Execute(ExecutionContext &context, DataChunk &input, Vector &result); - - //! The aggregate expression - BoundWindowExpression &wexpr; - //! The allocator to use for aggregate data structures - ArenaAllocator arena_allocator; - //! Reusable executor for the children - ExpressionExecutor executor; - //! Shared executor for FILTER clauses - ExpressionExecutor filter_executor; - //! The single aggregate state we update row-by-row - vector state; - //! The pointer to the state stored in the state vector - data_ptr_t state_ptr = nullptr; - //! The state vector for the single state - Vector statev; - //! The aggregate binding data (if any) - FunctionData *bind_data = nullptr; - //! The aggregate state destructor (if any) - aggregate_destructor_t dtor = nullptr; - //! The inputs rows that pass the FILTER - SelectionVector filter_sel; - //! The number of unfiltered rows so far for COUNT(*) - int64_t unfiltered = 0; - //! Argument types - vector arg_types; - //! Argument value buffer - DataChunk arg_chunk; - //! Argument cursor (a one element slice of arg_chunk) - DataChunk arg_cursor; - - //! Hash table for accumulating the distinct values - unique_ptr distinct; - //! Filtered arguments for checking distinctness - DataChunk distinct_args; - //! Reusable hash vector - Vector hashes; - //! Rows that produced new distinct values - SelectionVector distinct_sel; - //! Pointers to groups in the hash table. 
- Vector addresses; - }; - - struct LeadLagState { - // Fixed size - static constexpr idx_t MAX_BUFFER = 2048U; - - static bool ComputeOffset(ClientContext &context, BoundWindowExpression &wexpr, int64_t &offset) { - offset = 1; - if (wexpr.offset_expr) { - if (wexpr.offset_expr->HasParameter() || !wexpr.offset_expr->IsFoldable()) { - return false; - } - auto offset_value = ExpressionExecutor::EvaluateScalar(context, *wexpr.offset_expr); - if (offset_value.IsNull()) { - return false; - } - Value bigint_value; - if (!offset_value.DefaultTryCastAs(LogicalType::BIGINT, bigint_value, nullptr, false)) { - return false; - } - offset = bigint_value.GetValue(); - } - - // We can only support LEAD and LAG values within one standard vector - if (wexpr.GetExpressionType() == ExpressionType::WINDOW_LEAD) { - offset = -offset; - } - return idx_t(std::abs(offset)) < MAX_BUFFER; - } - - static bool ComputeDefault(ClientContext &context, BoundWindowExpression &wexpr, Value &result) { - if (!wexpr.default_expr) { - result = Value(wexpr.return_type); - return true; - } - - if (wexpr.default_expr && (wexpr.default_expr->HasParameter() || !wexpr.default_expr->IsFoldable())) { - return false; - } - auto dflt_value = ExpressionExecutor::EvaluateScalar(context, *wexpr.default_expr); - return dflt_value.DefaultTryCastAs(wexpr.return_type, result, nullptr, false); - } - - LeadLagState(ClientContext &context, BoundWindowExpression &wexpr) - : wexpr(wexpr), executor(context, *wexpr.children[0]), prev(wexpr.return_type), temp(wexpr.return_type) { - ComputeOffset(context, wexpr, offset); - ComputeDefault(context, wexpr, dflt); - - curr_chunk.Initialize(context, {wexpr.return_type}); - - buffered = idx_t(std::abs(offset)); - prev.Reference(dflt); - prev.Flatten(buffered); - temp.Initialize(false, buffered); - } - - void Execute(ExecutionContext &context, DataChunk &input, DataChunk &delayed, Vector &result) { - if (offset >= 0) { - ExecuteLag(context, input, result); - } else { - ExecuteLead(context, input, delayed, result); - } - } - - void ExecuteLag(ExecutionContext &context, DataChunk &input, Vector &result) { - D_ASSERT(offset >= 0); - auto &curr = curr_chunk.data[0]; - curr_chunk.Reset(); - executor.Execute(input, curr_chunk); - const idx_t count = input.size(); - // Copy prev[0, buffered] => result[0, buffered] - idx_t source_count = MinValue(buffered, count); - VectorOperations::Copy(prev, result, source_count, 0, 0); - // Special case when we have buffered enough values for the output - if (count < buffered) { - // Shift down incomplete buffers - // Copy prev[buffered-count, buffered] => temp[0, count] - source_count = buffered - count; - FlatVector::Validity(temp).Reset(); - VectorOperations::Copy(prev, temp, buffered, source_count, 0); - - // Copy temp[0, count] => prev[0, count] - FlatVector::Validity(prev).Reset(); - VectorOperations::Copy(temp, prev, count, 0, 0); - // Copy curr[0, buffered-count] => prev[count, buffered] - VectorOperations::Copy(curr, prev, source_count, 0, count); - } else { - // Copy input values beyond what we have buffered - source_count = count - buffered; - // Copy curr[0, count-buffered] => result[buffered, count] - VectorOperations::Copy(curr, result, source_count, 0, buffered); - // Copy curr[count-buffered, count] => prev[0, buffered] - FlatVector::Validity(prev).Reset(); - VectorOperations::Copy(curr, prev, count, source_count, 0); - } - } - - void ExecuteLead(ExecutionContext &context, DataChunk &input, DataChunk &delayed, Vector &result) { - // We treat input || delayed as a 
logical unified buffer - D_ASSERT(offset < 0); - // Input has been set up with the number of rows we CAN produce. - const idx_t count = input.size(); - auto &curr = curr_chunk.data[0]; - // Copy unified[buffered:count] => result[pos:] - idx_t pos = 0; - idx_t unified_offset = buffered; - if (unified_offset < count) { - curr_chunk.Reset(); - executor.Execute(input, curr_chunk); - VectorOperations::Copy(curr, result, count, unified_offset, pos); - pos += count - unified_offset; - unified_offset = count; - } - // Copy unified[unified_offset:] => result[pos:] - idx_t unified_count = count + delayed.size(); - if (unified_offset < unified_count) { - curr_chunk.Reset(); - executor.Execute(delayed, curr_chunk); - idx_t delayed_offset = unified_offset - count; - // Only copy as many values as we need - idx_t delayed_count = MinValue(delayed.size(), delayed_offset + (count - pos)); - VectorOperations::Copy(curr, result, delayed_count, delayed_offset, pos); - pos += delayed_count - delayed_offset; - } - // Copy default[:count-pos] => result[pos:] - if (pos < count) { - const idx_t defaulted = count - pos; - VectorOperations::Copy(prev, result, defaulted, 0, pos); - } - } - - //! The aggregate expression - BoundWindowExpression &wexpr; - //! Cache the executor to cut down on memory allocation - ExpressionExecutor executor; - //! The constant offset - int64_t offset; - //! The number of rows we have buffered - idx_t buffered; - //! The constant default value - Value dflt; - //! The current set of values - DataChunk curr_chunk; - //! The previous set of values - Vector prev; - //! The copy buffer - Vector temp; - }; - - explicit StreamingWindowState(ClientContext &client) : initialized(false), allocator(Allocator::Get(client)) { - } - - ~StreamingWindowState() override { - } - - void Initialize(ClientContext &context, DataChunk &input, const vector> &expressions) { - const_vectors.resize(expressions.size()); - aggregate_states.resize(expressions.size()); - lead_lag_states.resize(expressions.size()); - - for (idx_t expr_idx = 0; expr_idx < expressions.size(); expr_idx++) { - auto &expr = *expressions[expr_idx]; - auto &wexpr = expr.Cast(); - switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: - aggregate_states[expr_idx] = make_uniq(context, wexpr, allocator); - break; - case ExpressionType::WINDOW_FIRST_VALUE: { - // Just execute the expression once - ExpressionExecutor executor(context); - executor.AddExpression(*wexpr.children[0]); - DataChunk result; - result.Initialize(Allocator::Get(context), {wexpr.children[0]->return_type}); - executor.Execute(input, result); - - const_vectors[expr_idx] = make_uniq(result.GetValue(0, 0)); - break; - } - case ExpressionType::WINDOW_PERCENT_RANK: { - const_vectors[expr_idx] = make_uniq(Value((double)0)); - break; - } - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: { - const_vectors[expr_idx] = make_uniq(Value((int64_t)1)); - break; - } - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_LEAD: { - lead_lag_states[expr_idx] = make_uniq(context, wexpr); - const auto offset = lead_lag_states[expr_idx]->offset; - if (offset < 0) { - lead_count = MaxValue(idx_t(-offset), lead_count); - } - break; - } - default: - break; - } - } - if (lead_count) { - delayed.Initialize(context, input.GetTypes(), lead_count + STANDARD_VECTOR_SIZE); - shifted.Initialize(context, input.GetTypes(), lead_count + STANDARD_VECTOR_SIZE); - } - initialized = true; - } - -public: - //! 
We can't initialise until we have an input chunk - bool initialized; - //! The values that are determined by the first row. - vector> const_vectors; - //! Aggregation states - vector> aggregate_states; - Allocator &allocator; - //! Lead/Lag states - vector> lead_lag_states; - //! The number of rows ahead to buffer for LEAD - idx_t lead_count = 0; - //! A buffer for delayed input - DataChunk delayed; - //! A buffer for shifting delayed input - DataChunk shifted; -}; - -bool PhysicalStreamingWindow::IsStreamingFunction(ClientContext &context, unique_ptr &expr) { - auto &wexpr = expr->Cast(); - if (!wexpr.partitions.empty() || !wexpr.orders.empty() || wexpr.ignore_nulls || !wexpr.arg_orders.empty() || - wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { - return false; - } - switch (wexpr.GetExpressionType()) { - // TODO: add more expression types here? - case ExpressionType::WINDOW_AGGREGATE: - // We can stream aggregates if they are "running totals" - return wexpr.start == WindowBoundary::UNBOUNDED_PRECEDING && wexpr.end == WindowBoundary::CURRENT_ROW_ROWS; - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: - case ExpressionType::WINDOW_ROW_NUMBER: - return true; - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_LEAD: { - // We can stream LEAD/LAG if the arguments are constant and the delta is less than a block behind - Value dflt; - int64_t offset; - return StreamingWindowState::LeadLagState::ComputeDefault(context, wexpr, dflt) && - StreamingWindowState::LeadLagState::ComputeOffset(context, wexpr, offset); - } - default: - return false; - } -} - -unique_ptr PhysicalStreamingWindow::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalStreamingWindow::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client); -} - -void StreamingWindowState::AggregateState::Execute(ExecutionContext &context, DataChunk &input, Vector &result) { - // Establish the aggregation environment - const idx_t count = input.size(); - auto &aggregate = *wexpr.aggregate; - auto &aggr_state = *this; - auto &statev = aggr_state.statev; - - // Compute the FILTER mask (if any) - ValidityMask filter_mask; - auto filtered = count; - auto &filter_sel = aggr_state.filter_sel; - if (wexpr.filter_expr) { - filtered = filter_executor.SelectExpression(input, filter_sel); - if (filtered < count) { - filter_mask.Initialize(count); - filter_mask.SetAllInvalid(count); - for (idx_t f = 0; f < filtered; ++f) { - filter_mask.SetValid(filter_sel.get_index(f)); - } - } - } - - // Check for COUNT(*) - if (wexpr.children.empty()) { - D_ASSERT(GetTypeIdSize(result.GetType().InternalType()) == sizeof(int64_t)); - auto data = FlatVector::GetData(result); - auto &unfiltered = aggr_state.unfiltered; - for (idx_t i = 0; i < count; ++i) { - unfiltered += int64_t(filter_mask.RowIsValid(i)); - data[i] = unfiltered; - } - return; - } - - // Compute the arguments - auto &arg_chunk = aggr_state.arg_chunk; - executor.Execute(input, arg_chunk); - arg_chunk.Flatten(); - - // Update the distinct hash table - ValidityMask distinct_mask; - if (aggr_state.distinct) { - auto &distinct_args = aggr_state.distinct_args; - distinct_args.Reference(arg_chunk); - if (wexpr.filter_expr) { - distinct_args.Slice(filter_sel, filtered); - } - idx_t distinct = 0; - auto &distinct_sel = aggr_state.distinct_sel; - if (filtered) { - // FindOrCreateGroups assumes 
non-empty input - auto &hashes = aggr_state.hashes; - distinct_args.Hash(hashes); - - auto &addresses = aggr_state.addresses; - distinct = aggr_state.distinct->FindOrCreateGroups(distinct_args, hashes, addresses, distinct_sel); - } - - // Translate the distinct selection from filtered row numbers - // back to input row numbers. We need to produce output for all input rows, - // so we filter out the rows whose argument values have already been seen - if (distinct < filtered) { - distinct_mask.Initialize(count); - distinct_mask.SetAllInvalid(count); - for (idx_t d = 0; d < distinct; ++d) { - const auto f = distinct_sel.get_index(d); - distinct_mask.SetValid(filter_sel.get_index(f)); - } - } - } - - // Iterate through them using a single SV - sel_t s = 0; - SelectionVector sel(&s); - auto &arg_cursor = aggr_state.arg_cursor; - arg_cursor.Reset(); - arg_cursor.Slice(sel, 1); - // This doesn't work for STRUCTs because the SV - // is not copied to the children when you slice - vector<column_t> structs; - for (column_t col_idx = 0; col_idx < arg_chunk.ColumnCount(); ++col_idx) { - auto &col_vec = arg_cursor.data[col_idx]; - DictionaryVector::Child(col_vec).Reference(arg_chunk.data[col_idx]); - if (col_vec.GetType().InternalType() == PhysicalType::STRUCT) { - structs.emplace_back(col_idx); - } - } - - // Update the state and finalize it one row at a time. - AggregateInputData aggr_input_data(wexpr.bind_info.get(), aggr_state.arena_allocator); - for (idx_t i = 0; i < count; ++i) { - sel.set_index(0, i); - for (const auto struct_idx : structs) { - arg_cursor.data[struct_idx].Slice(arg_chunk.data[struct_idx], sel, 1); - } - if (filter_mask.RowIsValid(i) && distinct_mask.RowIsValid(i)) { - aggregate.update(arg_cursor.data.data(), aggr_input_data, arg_cursor.ColumnCount(), statev, 1); - } - aggregate.finalize(statev, aggr_input_data, result, 1, i); - } -} - -void PhysicalStreamingWindow::ExecuteFunctions(ExecutionContext &context, DataChunk &chunk, DataChunk &delayed, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast<StreamingWindowGlobalState>(); - auto &state = state_p.Cast<StreamingWindowState>(); - - // Compute window functions - const idx_t count = chunk.size(); - const column_t input_width = children[0]->GetTypes().size(); - for (column_t expr_idx = 0; expr_idx < select_list.size(); expr_idx++) { - column_t col_idx = input_width + expr_idx; - auto &expr = *select_list[expr_idx]; - auto &result = chunk.data[col_idx]; - switch (expr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: - state.aggregate_states[expr_idx]->Execute(context, chunk, result); - break; - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_PERCENT_RANK: - case ExpressionType::WINDOW_RANK: - case ExpressionType::WINDOW_RANK_DENSE: { - // Reference constant vector - chunk.data[col_idx].Reference(*state.const_vectors[expr_idx]); - break; - } - case ExpressionType::WINDOW_ROW_NUMBER: { - // Set row numbers - int64_t start_row = gstate.row_number; - auto rdata = FlatVector::GetData<int64_t>(chunk.data[col_idx]); - for (idx_t i = 0; i < count; i++) { - rdata[i] = NumericCast<int64_t>(start_row + NumericCast<int64_t>(i)); - } - break; - } - case ExpressionType::WINDOW_LAG: - case ExpressionType::WINDOW_LEAD: - state.lead_lag_states[expr_idx]->Execute(context, chunk, delayed, result); - break; - default: - throw NotImplementedException("%s for StreamingWindow", ExpressionTypeToString(expr.GetExpressionType())); - } - } - gstate.row_number += NumericCast<int64_t>(count); -} - -void PhysicalStreamingWindow::ExecuteInput(ExecutionContext &context, DataChunk &delayed, DataChunk &input, - DataChunk &chunk, 
GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - - // Put payload columns in place - for (idx_t col_idx = 0; col_idx < input.data.size(); col_idx++) { - chunk.data[col_idx].Reference(input.data[col_idx]); - } - idx_t count = input.size(); - - // Handle LEAD - if (state.lead_count > 0) { - // Nothing delayed yet, so just truncate and copy the delayed values - count -= state.lead_count; - input.Copy(delayed, count); - } - chunk.SetCardinality(count); - - ExecuteFunctions(context, chunk, state.delayed, gstate_p, state_p); -} - -void PhysicalStreamingWindow::ExecuteShifted(ExecutionContext &context, DataChunk &delayed, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &shifted = state.shifted; - - idx_t i = input.size(); - idx_t d = delayed.size(); - shifted.Reset(); - // shifted = delayed - delayed.Copy(shifted); - delayed.Reset(); - for (idx_t col_idx = 0; col_idx < delayed.data.size(); ++col_idx) { - // chunk[0:i] = shifted[0:i] - chunk.data[col_idx].Reference(shifted.data[col_idx]); - // delayed[0:i] = chunk[i:d-i] - VectorOperations::Copy(shifted.data[col_idx], delayed.data[col_idx], d, i, 0); - // delayed[d-i:d] = input[0:i] - VectorOperations::Copy(input.data[col_idx], delayed.data[col_idx], i, 0, d - i); - } - chunk.SetCardinality(i); - delayed.SetCardinality(d); - - ExecuteFunctions(context, chunk, delayed, gstate_p, state_p); -} - -void PhysicalStreamingWindow::ExecuteDelayed(ExecutionContext &context, DataChunk &delayed, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - // Put payload columns in place - for (idx_t col_idx = 0; col_idx < delayed.data.size(); col_idx++) { - chunk.data[col_idx].Reference(delayed.data[col_idx]); - } - idx_t count = delayed.size(); - chunk.SetCardinality(count); - - ExecuteFunctions(context, chunk, input, gstate_p, state_p); -} - -OperatorResultType PhysicalStreamingWindow::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &state = state_p.Cast(); - if (!state.initialized) { - state.Initialize(context.client, input, select_list); - } - - auto &delayed = state.delayed; - // We can Reset delayed now that no one can be referencing it. - if (!delayed.size()) { - delayed.Reset(); - } - const idx_t available = delayed.size() + input.size(); - if (available <= state.lead_count) { - // If we don't have enough to produce a single row, - // then just delay more rows, return nothing - // and ask for more data. - delayed.Append(input); - chunk.SetCardinality(0); - return OperatorResultType::NEED_MORE_INPUT; - } else if (input.size() < delayed.size()) { - // If we can't consume all of the delayed values, - // we need to split them instead of referencing them all - ExecuteShifted(context, delayed, input, chunk, gstate_p, state_p); - // We delayed the unused input so ask for more - return OperatorResultType::NEED_MORE_INPUT; - } else if (delayed.size()) { - // We have enough delayed rows so flush them - ExecuteDelayed(context, delayed, input, chunk, gstate_p, state_p); - // Defer resetting delayed as it may be referenced. - delayed.SetCardinality(0); - // Come back to process the input - return OperatorResultType::HAVE_MORE_OUTPUT; - } else { - // No delayed rows, so emit what we can and delay the rest. 
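- // Worked example (hypothetical sizes, assuming lead_count = 3): with delayed empty and a 1024-row input, - // available = 1024 > 3, so this branch fires; ExecuteInput emits the first 1021 rows and copies the trailing - // 3 rows into delayed, and the next chunk (or FinalExecute) supplies their LEAD values.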
- ExecuteInput(context, delayed, input, chunk, gstate_p, state_p); - return OperatorResultType::NEED_MORE_INPUT; - } -} - -OperatorFinalizeResultType PhysicalStreamingWindow::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - - if (state.initialized && state.lead_count) { - auto &delayed = state.delayed; - // There are no more input rows - auto &input = state.shifted; - input.Reset(); - ExecuteDelayed(context, delayed, input, chunk, gstate_p, state_p); - } - - return OperatorFinalizeResultType::FINISHED; -} - -InsertionOrderPreservingMap PhysicalStreamingWindow::ParamsToString() const { - InsertionOrderPreservingMap result; - string projections; - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - projections += "\n"; - } - projections += select_list[i]->GetName(); - } - result["Projections"] = projections; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp b/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp deleted file mode 100644 index 0ec4bd2a3..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_ungrouped_aggregate.cpp +++ /dev/null @@ -1,675 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" - -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/unordered_set.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/aggregate/aggregate_object.hpp" -#include "duckdb/execution/operator/aggregate/distinct_aggregate_data.hpp" -#include "duckdb/execution/radix_partitioned_hashtable.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/interrupt.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/parallel/executor_task.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" - -#include - -namespace duckdb { - -PhysicalUngroupedAggregate::PhysicalUngroupedAggregate(vector types, - vector> expressions, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::UNGROUPED_AGGREGATE, std::move(types), estimated_cardinality), - aggregates(std::move(expressions)) { - - distinct_collection_info = DistinctAggregateCollectionInfo::Create(aggregates); - if (!distinct_collection_info) { - return; - } - distinct_data = make_uniq(*distinct_collection_info); -} - -//===--------------------------------------------------------------------===// -// Ungrouped Aggregate State -//===--------------------------------------------------------------------===// -UngroupedAggregateState::UngroupedAggregateState(const vector> &aggregate_expressions) - : aggregate_expressions(aggregate_expressions) { - counts = make_uniq_array>(aggregate_expressions.size()); - for (idx_t i = 0; i < aggregate_expressions.size(); i++) { - auto &aggregate = aggregate_expressions[i]; - D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = aggregate->Cast(); - auto state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); - 
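- // (The state buffer is allocated uninitialized; the initialize callback below is what puts it into a valid empty state.)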
aggr.function.initialize(aggr.function, state.get()); - aggregate_data.push_back(std::move(state)); - bind_data.push_back(aggr.bind_info.get()); - destructors.push_back(aggr.function.destructor); -#ifdef DEBUG - counts[i] = 0; -#endif - } -} -UngroupedAggregateState::~UngroupedAggregateState() { - D_ASSERT(destructors.size() == aggregate_data.size()); - for (idx_t i = 0; i < destructors.size(); i++) { - if (!destructors[i]) { - continue; - } - Vector state_vector(Value::POINTER(CastPointerToValue(aggregate_data[i].get()))); - state_vector.SetVectorType(VectorType::FLAT_VECTOR); - - ArenaAllocator allocator(Allocator::DefaultAllocator()); - AggregateInputData aggr_input_data(bind_data[i], allocator); - destructors[i](state_vector, aggr_input_data, 1); - } -} - -void UngroupedAggregateState::Move(UngroupedAggregateState &other) { - other.aggregate_data = std::move(aggregate_data); - other.destructors = std::move(destructors); -} - -//===--------------------------------------------------------------------===// -// Global State -//===--------------------------------------------------------------------===// -class UngroupedAggregateGlobalSinkState : public GlobalSinkState { -public: - UngroupedAggregateGlobalSinkState(const PhysicalUngroupedAggregate &op, ClientContext &client) - : state(BufferAllocator::Get(client), op.aggregates), finished(false) { - if (op.distinct_data) { - distinct_state = make_uniq(*op.distinct_data, client); - } - } - - //! The global aggregate state - GlobalUngroupedAggregateState state; - //! Whether or not the aggregate is finished - bool finished; - //! The data related to the distinct aggregates (if there are any) - unique_ptr distinct_state; -}; - -ArenaAllocator &GlobalUngroupedAggregateState::CreateAllocator() const { - lock_guard glock(lock); - stored_allocators.emplace_back(make_uniq(client_allocator)); - return *stored_allocators.back(); -} - -void GlobalUngroupedAggregateState::Combine(LocalUngroupedAggregateState &other) { - lock_guard glock(lock); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - - if (aggregate.IsDistinct()) { - continue; - } - - Vector source_state(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get()))); - Vector dest_state(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); - - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); - aggregate.function.combine(source_state, dest_state, aggr_input_data, 1); -#ifdef DEBUG - state.counts[aggr_idx] += other.state.counts[aggr_idx]; -#endif - } -} - -void GlobalUngroupedAggregateState::CombineDistinct(LocalUngroupedAggregateState &other, - DistinctAggregateData &distinct_data) { - lock_guard glock(lock); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { - if (!distinct_data.IsDistinct(aggr_idx)) { - continue; - } - - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator, - AggregateCombineType::ALLOW_DESTRUCTIVE); - - Vector state_vec(Value::POINTER(CastPointerToValue(other.state.aggregate_data[aggr_idx].get()))); - Vector combined_vec(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); - aggregate.function.combine(state_vec, combined_vec, aggr_input_data, 1); -#ifdef DEBUG - state.counts[aggr_idx] += other.state.counts[aggr_idx]; -#endif 
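- // Note: ALLOW_DESTRUCTIVE above lets combine() move the source state instead of copying it; this is - // presumably safe because the thread-local state is discarded once it has been combined into the global state.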
- } -} - -//===--------------------------------------------------------------------===// -// Ungrouped Aggregate Execute State -//===--------------------------------------------------------------------===// -UngroupedAggregateExecuteState::UngroupedAggregateExecuteState(ClientContext &context, - const vector> &aggregates, - const vector &child_types) - : aggregates(aggregates), child_executor(context), aggregate_input_chunk(), filter_set() { - vector payload_types; - vector aggregate_objects; - auto &allocator = BufferAllocator::Get(context); - for (auto &aggregate : aggregates) { - D_ASSERT(aggregate->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = aggregate->Cast(); - // initialize the payload chunk - for (auto &child : aggr.children) { - payload_types.push_back(child->return_type); - child_executor.AddExpression(*child); - } - aggregate_objects.emplace_back(&aggr); - } - if (!payload_types.empty()) { // for select count(*) from t; there is no payload at all - aggregate_input_chunk.Initialize(allocator, payload_types); - } - filter_set.Initialize(context, aggregate_objects, child_types); -} - -void UngroupedAggregateExecuteState::Reset() { - aggregate_input_chunk.Reset(); -} - -void UngroupedAggregateExecuteState::Sink(LocalUngroupedAggregateState &state, DataChunk &input) { - DataChunk &payload_chunk = aggregate_input_chunk; - - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; - - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]->Cast(); - - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - if (aggregate.IsDistinct()) { - continue; - } - - idx_t payload_cnt = 0; - // resolve the filter (if any) - if (aggregate.filter) { - auto &filtered_data = filter_set.GetFilterData(aggr_idx); - auto count = filtered_data.ApplyFilter(input); - - child_executor.SetChunk(filtered_data.filtered_payload); - payload_chunk.SetCardinality(count); - } else { - child_executor.SetChunk(input); - payload_chunk.SetCardinality(input); - } - - // resolve the child expressions of the aggregate (if any) - for (idx_t i = 0; i < aggregate.children.size(); ++i) { - child_executor.ExecuteExpression(payload_idx + payload_cnt, payload_chunk.data[payload_idx + payload_cnt]); - payload_cnt++; - } - - state.Sink(payload_chunk, payload_idx, aggr_idx); - } -} - -//===--------------------------------------------------------------------===// -// Local State -//===--------------------------------------------------------------------===// -LocalUngroupedAggregateState::LocalUngroupedAggregateState(GlobalUngroupedAggregateState &gstate) - : allocator(gstate.CreateAllocator()), state(gstate.state.aggregate_expressions) { -} - -class UngroupedAggregateLocalSinkState : public LocalSinkState { -public: - UngroupedAggregateLocalSinkState(const PhysicalUngroupedAggregate &op, const vector &child_types, - UngroupedAggregateGlobalSinkState &gstate_p, ExecutionContext &context) - : state(gstate_p.state), execute_state(context.client, op.aggregates, child_types) { - auto &gstate = gstate_p.Cast(); - InitializeDistinctAggregates(op, gstate, context); - } - - //! The local aggregate state - LocalUngroupedAggregateState state; - //! The ungrouped aggregate execute state - UngroupedAggregateExecuteState execute_state; - //! 
The local sink states of the distinct aggregates hash tables - vector> radix_states; - -public: - void InitializeDistinctAggregates(const PhysicalUngroupedAggregate &op, - const UngroupedAggregateGlobalSinkState &gstate, ExecutionContext &context) { - - if (!op.distinct_data) { - return; - } - auto &data = *op.distinct_data; - auto &state = *gstate.distinct_state; - D_ASSERT(!data.radix_tables.empty()); - - const idx_t aggregate_count = state.radix_states.size(); - radix_states.resize(aggregate_count); - - auto &distinct_info = *op.distinct_collection_info; - - for (auto &idx : distinct_info.indices) { - idx_t table_idx = distinct_info.table_map[idx]; - if (data.radix_tables[table_idx] == nullptr) { - // This aggregate has identical input as another aggregate, so no table is created for it - continue; - } - auto &radix_table = *data.radix_tables[table_idx]; - radix_states[table_idx] = radix_table.GetLocalSinkState(context); - } - } -}; - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -bool PhysicalUngroupedAggregate::SinkOrderDependent() const { - for (auto &expr : aggregates) { - auto &aggr = expr->Cast(); - if (aggr.function.order_dependent == AggregateOrderDependent::ORDER_DEPENDENT) { - return true; - } - } - return false; -} - -unique_ptr PhysicalUngroupedAggregate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalUngroupedAggregate::GetLocalSinkState(ExecutionContext &context) const { - D_ASSERT(sink_state); - auto &gstate = sink_state->Cast(); - return make_uniq(*this, children[0]->GetTypes(), gstate, context); -} - -void PhysicalUngroupedAggregate::SinkDistinct(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &sink = input.local_state.Cast(); - auto &global_sink = input.global_state.Cast(); - D_ASSERT(distinct_data); - auto &distinct_state = *global_sink.distinct_state; - auto &distinct_info = *distinct_collection_info; - auto &distinct_indices = distinct_info.Indices(); - - DataChunk empty_chunk; - - auto &distinct_filter = distinct_info.Indices(); - - for (auto &idx : distinct_indices) { - auto &aggregate = aggregates[idx]->Cast(); - - idx_t table_idx = distinct_info.table_map[idx]; - if (!distinct_data->radix_tables[table_idx]) { - // This distinct aggregate shares its data with another - continue; - } - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state.radix_states[table_idx]; - auto &radix_local_sink = *sink.radix_states[table_idx]; - OperatorSinkInput sink_input {radix_global_sink, radix_local_sink, input.interrupt_state}; - - if (aggregate.filter) { - // The hashtable can apply a filter, but only on the payload - // And in our case, we need to filter the groups (the distinct aggr children) - - // Apply the filter before inserting into the hashtable - auto &filtered_data = sink.execute_state.filter_set.GetFilterData(idx); - idx_t count = filtered_data.ApplyFilter(chunk); - filtered_data.filtered_payload.SetCardinality(count); - - radix_table.Sink(context, filtered_data.filtered_payload, sink_input, empty_chunk, distinct_filter); - } else { - radix_table.Sink(context, chunk, sink_input, empty_chunk, distinct_filter); - } - } -} - -SinkResultType PhysicalUngroupedAggregate::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) 
const { - auto &sink = input.local_state.Cast(); - - // perform the aggregation inside the local state - sink.execute_state.Reset(); - - if (distinct_data) { - SinkDistinct(context, chunk, input); - } - - sink.execute_state.Sink(sink.state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -void LocalUngroupedAggregateState::Sink(DataChunk &payload_chunk, idx_t payload_idx, idx_t aggr_idx) { -#ifdef DEBUG - state.counts[aggr_idx] += payload_chunk.size(); -#endif - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - idx_t payload_cnt = aggregate.children.size(); - D_ASSERT(payload_idx + payload_cnt <= payload_chunk.data.size()); - auto start_of_input = payload_cnt == 0 ? nullptr : &payload_chunk.data[payload_idx]; - AggregateInputData aggr_input_data(state.bind_data[aggr_idx], allocator); - aggregate.function.simple_update(start_of_input, aggr_input_data, payload_cnt, state.aggregate_data[aggr_idx].get(), - payload_chunk.size()); -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -void PhysicalUngroupedAggregate::CombineDistinct(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - if (!distinct_data) { - return; - } - auto &distinct_state = gstate.distinct_state; - auto table_count = distinct_data->radix_tables.size(); - for (idx_t table_idx = 0; table_idx < table_count; table_idx++) { - D_ASSERT(distinct_data->radix_tables[table_idx]); - auto &radix_table = *distinct_data->radix_tables[table_idx]; - auto &radix_global_sink = *distinct_state->radix_states[table_idx]; - auto &radix_local_sink = *lstate.radix_states[table_idx]; - - radix_table.Combine(context, radix_global_sink, radix_local_sink); - } -} - -SinkCombineResultType PhysicalUngroupedAggregate::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - D_ASSERT(!gstate.finished); - - // finalize: combine the local state into the global state - // all aggregates are combinable: we might be doing a parallel aggregate - // use the combine method to combine the partial aggregates - OperatorSinkCombineInput distinct_input {gstate, lstate, input.interrupt_state}; - CombineDistinct(context, distinct_input); - - gstate.state.Combine(lstate.state); - - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -class UngroupedDistinctAggregateFinalizeEvent : public BasePipelineEvent { -public: - UngroupedDistinctAggregateFinalizeEvent(ClientContext &context, const PhysicalUngroupedAggregate &op_p, - UngroupedAggregateGlobalSinkState &gstate_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), context(context), op(op_p), gstate(gstate_p), tasks_scheduled(0), - tasks_done(0) { - } - -public: - void Schedule() override; - void FinalizeTask() { - lock_guard finalize(lock); - D_ASSERT(!gstate.finished); - D_ASSERT(tasks_done < tasks_scheduled); - if (++tasks_done == tasks_scheduled) { - gstate.finished = true; - } - } - -private: - ClientContext &context; - - const 
PhysicalUngroupedAggregate &op; - UngroupedAggregateGlobalSinkState &gstate; - - mutex lock; - idx_t tasks_scheduled; - idx_t tasks_done; - -public: - vector> global_source_states; -}; - -class UngroupedDistinctAggregateFinalizeTask : public ExecutorTask { -public: - UngroupedDistinctAggregateFinalizeTask(Executor &executor, shared_ptr event_p, - const PhysicalUngroupedAggregate &op, - UngroupedAggregateGlobalSinkState &state_p) - : ExecutorTask(executor, std::move(event_p)), op(op), gstate(state_p), aggregate_state(gstate.state) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override; - -private: - TaskExecutionResult AggregateDistinct(); - -private: - const PhysicalUngroupedAggregate &op; - UngroupedAggregateGlobalSinkState &gstate; - - // Distinct aggregation state - LocalUngroupedAggregateState aggregate_state; - idx_t aggregation_idx = 0; - unique_ptr radix_table_lstate; - bool blocked = false; -}; - -void UngroupedDistinctAggregateFinalizeEvent::Schedule() { - D_ASSERT(gstate.distinct_state); - auto &aggregates = op.aggregates; - auto &distinct_data = *op.distinct_data; - - idx_t n_tasks = 0; - idx_t payload_idx = 0; - idx_t next_payload_idx = 0; - for (idx_t agg_idx = 0; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast(); - - // Forward the payload idx - payload_idx = next_payload_idx; - next_payload_idx = payload_idx + aggregate.children.size(); - - // If aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - global_source_states.push_back(nullptr); - continue; - } - D_ASSERT(distinct_data.info.table_map.count(agg_idx)); - - // Create global state for scanning - auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table_p = *distinct_data.radix_tables[table_idx]; - n_tasks += radix_table_p.MaxThreads(*gstate.distinct_state->radix_states[table_idx]); - global_source_states.push_back(radix_table_p.GetGlobalSourceState(context)); - } - n_tasks = MaxValue(n_tasks, 1); - n_tasks = MinValue(n_tasks, NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())); - - vector> tasks; - for (idx_t i = 0; i < n_tasks; i++) { - tasks.push_back( - make_uniq(pipeline->executor, shared_from_this(), op, gstate)); - tasks_scheduled++; - } - SetTasks(std::move(tasks)); -} - -TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::ExecuteTask(TaskExecutionMode mode) { - auto res = AggregateDistinct(); - if (res == TaskExecutionResult::TASK_BLOCKED) { - return res; - } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -TaskExecutionResult UngroupedDistinctAggregateFinalizeTask::AggregateDistinct() { - D_ASSERT(gstate.distinct_state); - auto &distinct_state = *gstate.distinct_state; - auto &distinct_data = *op.distinct_data; - - auto &aggregates = op.aggregates; - auto &state = aggregate_state; - - // Thread-local contexts - ThreadContext thread_context(executor.context); - ExecutionContext execution_context(executor.context, thread_context, nullptr); - - auto &finalize_event = event->Cast(); - - // Now loop through the distinct aggregates, scanning the distinct HTs - - // This needs to be preserved in case the radix_table.GetData blocks - auto &agg_idx = aggregation_idx; - - for (; agg_idx < aggregates.size(); agg_idx++) { - auto &aggregate = aggregates[agg_idx]->Cast(); - - // If aggregate is not distinct, skip it - if (!distinct_data.IsDistinct(agg_idx)) { - continue; - } - - const auto table_idx = distinct_data.info.table_map.at(agg_idx); - auto &radix_table = 
*distinct_data.radix_tables[table_idx]; - if (!blocked) { - // Because we can block, we need to make sure we preserve this state - radix_table_lstate = radix_table.GetLocalSourceState(execution_context); - } - auto &lstate = *radix_table_lstate; - - auto &sink = *distinct_state.radix_states[table_idx]; - InterruptState interrupt_state(shared_from_this()); - OperatorSourceInput source_input {*finalize_event.global_source_states[agg_idx], lstate, interrupt_state}; - - DataChunk output_chunk; - output_chunk.Initialize(executor.context, distinct_state.distinct_output_chunks[table_idx]->GetTypes()); - - DataChunk payload_chunk; - payload_chunk.InitializeEmpty(distinct_data.grouped_aggregate_data[table_idx]->group_types); - payload_chunk.SetCardinality(0); - - while (true) { - output_chunk.Reset(); - - auto res = radix_table.GetData(execution_context, output_chunk, sink, source_input); - if (res == SourceResultType::FINISHED) { - D_ASSERT(output_chunk.size() == 0); - break; - } else if (res == SourceResultType::BLOCKED) { - blocked = true; - return TaskExecutionResult::TASK_BLOCKED; - } - - // We don't need to resolve the filter, we already did this in Sink - idx_t payload_cnt = aggregate.children.size(); - for (idx_t i = 0; i < payload_cnt; i++) { - payload_chunk.data[i].Reference(output_chunk.data[i]); - } - payload_chunk.SetCardinality(output_chunk); - - // Update the aggregate state - state.Sink(payload_chunk, 0, agg_idx); - } - blocked = false; - } - - // After scanning the distinct HTs, we can combine the thread-local agg states with the thread-global state - gstate.state.CombineDistinct(state, distinct_data); - finalize_event.FinalizeTask(); - return TaskExecutionResult::TASK_FINISHED; -} - -SinkFinalizeType PhysicalUngroupedAggregate::FinalizeDistinct(Pipeline &pipeline, Event &event, ClientContext &context, - GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast<UngroupedAggregateGlobalSinkState>(); - D_ASSERT(distinct_data); - auto &distinct_state = *gstate.distinct_state; - - for (idx_t table_idx = 0; table_idx < distinct_data->radix_tables.size(); table_idx++) { - auto &radix_table_p = distinct_data->radix_tables[table_idx]; - auto &radix_state = *distinct_state.radix_states[table_idx]; - radix_table_p->Finalize(context, radix_state); - } - auto new_event = make_shared_ptr<UngroupedDistinctAggregateFinalizeEvent>(context, *this, gstate, pipeline); - event.InsertEvent(std::move(new_event)); - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalUngroupedAggregate::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast<UngroupedAggregateGlobalSinkState>(); - - if (distinct_data) { - return FinalizeDistinct(pipeline, event, context, input.global_state); - } - - D_ASSERT(!gstate.finished); - gstate.finished = true; - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -void VerifyNullHandling(DataChunk &chunk, UngroupedAggregateState &state, - const vector<unique_ptr<Expression>> &aggregates) { -#ifdef DEBUG - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggr = aggregates[aggr_idx]->Cast<BoundAggregateExpression>(); - if (state.counts[aggr_idx] == 0 && aggr.function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { - // Default is when 0 values go in, NULL comes out - UnifiedVectorFormat vdata; - chunk.data[aggr_idx].ToUnifiedFormat(1, vdata); - D_ASSERT(!vdata.validity.RowIsValid(vdata.sel->get_index(0))); - } - } -#endif -} - -void 
GlobalUngroupedAggregateState::Finalize(DataChunk &result, idx_t column_offset) { - result.SetCardinality(1); - for (idx_t aggr_idx = 0; aggr_idx < state.aggregate_expressions.size(); aggr_idx++) { - auto &aggregate = state.aggregate_expressions[aggr_idx]->Cast(); - - Vector state_vector(Value::POINTER(CastPointerToValue(state.aggregate_data[aggr_idx].get()))); - AggregateInputData aggr_input_data(aggregate.bind_info.get(), allocator); - aggregate.function.finalize(state_vector, aggr_input_data, result.data[column_offset + aggr_idx], 1, 0); - } -} - -SourceResultType PhysicalUngroupedAggregate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - D_ASSERT(gstate.finished); - - // initialize the result chunk with the aggregate values - gstate.state.Finalize(chunk); - VerifyNullHandling(chunk, gstate.state.state, aggregates); - - return SourceResultType::FINISHED; -} - -InsertionOrderPreservingMap PhysicalUngroupedAggregate::ParamsToString() const { - InsertionOrderPreservingMap result; - string aggregate_info; - for (idx_t i = 0; i < aggregates.size(); i++) { - auto &aggregate = aggregates[i]->Cast(); - if (i > 0) { - aggregate_info += "\n"; - } - aggregate_info += aggregates[i]->GetName(); - if (aggregate.filter) { - aggregate_info += " Filter: " + aggregate.filter->GetName(); - } - } - result["Aggregates"] = aggregate_info; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp b/src/duckdb/src/execution/operator/aggregate/physical_window.cpp deleted file mode 100644 index 8b8b2a162..000000000 --- a/src/duckdb/src/execution/operator/aggregate/physical_window.cpp +++ /dev/null @@ -1,1014 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_window.hpp" - -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/function/window/window_aggregate_function.hpp" -#include "duckdb/function/window/window_executor.hpp" -#include "duckdb/function/window/window_rank_function.hpp" -#include "duckdb/function/window/window_rownumber_function.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/function/window/window_value_function.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" -// -#include - -namespace duckdb { - -// Global sink state -class WindowGlobalSinkState; - -enum WindowGroupStage : uint8_t { SINK, FINALIZE, GETDATA, DONE }; - -class WindowHashGroup { -public: - using HashGroupPtr = unique_ptr; - using OrderMasks = PartitionGlobalHashGroup::OrderMasks; - using ExecutorGlobalStatePtr = unique_ptr; - using ExecutorGlobalStates = vector; - using ExecutorLocalStatePtr = unique_ptr; - using ExecutorLocalStates = vector; - using ThreadLocalStates = vector; - - WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p); - - ExecutorGlobalStates &Initialize(WindowGlobalSinkState &gstate); - - // Scan all of the blocks during the build phase - unique_ptr GetBuildScanner(idx_t block_idx) const { - if (!rows) { - return nullptr; - } - return make_uniq(*rows, *heap, layout, external, block_idx, false); - } - - // Scan a single block during the evaluate phase - unique_ptr GetEvaluateScanner(idx_t block_idx) const { - // Second pass can flush - D_ASSERT(rows); - return make_uniq(*rows, *heap, layout, external, block_idx, true); - } - - // The processing stage for this group - WindowGroupStage GetStage() const { - return stage; - } - - bool TryPrepareNextStage() { - lock_guard 
prepare_guard(lock); - switch (stage.load()) { - case WindowGroupStage::SINK: - if (sunk == count) { - stage = WindowGroupStage::FINALIZE; - return true; - } - return false; - case WindowGroupStage::FINALIZE: - if (finalized == blocks) { - stage = WindowGroupStage::GETDATA; - return true; - } - return false; - default: - // never block in GETDATA - return true; - } - } - - //! The hash partition data - HashGroupPtr hash_group; - //! The size of the group - idx_t count = 0; - //! The number of blocks in the group - idx_t blocks = 0; - unique_ptr rows; - unique_ptr heap; - RowLayout layout; - //! The partition boundary mask - ValidityMask partition_mask; - //! The order boundary mask - OrderMasks order_masks; - //! The fully materialised data collection - unique_ptr collection; - //! External paging - bool external; - // The processing stage for this group - atomic stage; - //! The function global states for this hash group - ExecutorGlobalStates gestates; - //! Executor local states, one per thread - ThreadLocalStates thread_states; - - //! The bin number - idx_t hash_bin; - //! Single threading lock - mutex lock; - //! Count of sunk rows - std::atomic sunk; - //! Count of finalized blocks - std::atomic finalized; - //! The number of tasks left before we should be deleted - std::atomic tasks_remaining; - //! The output ordering batch index this hash group starts at - idx_t batch_base; - -private: - void MaterializeSortedData(); -}; - -class WindowPartitionGlobalSinkState; - -class WindowGlobalSinkState : public GlobalSinkState { -public: - using ExecutorPtr = unique_ptr; - using Executors = vector; - - WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context); - - //! Parent operator - const PhysicalWindow &op; - //! Execution context - ClientContext &context; - //! The partitioned sunk data - unique_ptr global_partition; - //! The execution functions - Executors executors; - //! The shared expressions library - WindowSharedExpressions shared; -}; - -class WindowPartitionGlobalSinkState : public PartitionGlobalSinkState { -public: - using WindowHashGroupPtr = unique_ptr; - - WindowPartitionGlobalSinkState(WindowGlobalSinkState &gsink, const BoundWindowExpression &wexpr) - : PartitionGlobalSinkState(gsink.context, wexpr.partitions, wexpr.orders, gsink.op.children[0]->types, - wexpr.partitions_stats, gsink.op.estimated_cardinality), - gsink(gsink) { - } - ~WindowPartitionGlobalSinkState() override = default; - - void OnBeginMerge() override { - PartitionGlobalSinkState::OnBeginMerge(); - window_hash_groups.resize(hash_groups.size()); - } - - void OnSortedPartition(const idx_t group_idx) override { - PartitionGlobalSinkState::OnSortedPartition(group_idx); - window_hash_groups[group_idx] = make_uniq(gsink, group_idx); - } - - //! Operator global sink state - WindowGlobalSinkState &gsink; - //! 
The sorted hash groups - vector window_hash_groups; -}; - -// Per-thread sink state -class WindowLocalSinkState : public LocalSinkState { -public: - WindowLocalSinkState(ClientContext &context, const WindowGlobalSinkState &gstate) - : local_partition(context, *gstate.global_partition) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); - } - - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; -}; - -// this implements a sorted window functions variant -PhysicalWindow::PhysicalWindow(vector types, vector> select_list_p, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list_p)), - order_idx(0), is_order_dependent(false) { - - idx_t max_orders = 0; - for (idx_t i = 0; i < select_list.size(); ++i) { - auto &expr = select_list[i]; - D_ASSERT(expr->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &bound_window = expr->Cast(); - if (bound_window.partitions.empty() && bound_window.orders.empty()) { - is_order_dependent = true; - } - - if (bound_window.orders.size() > max_orders) { - order_idx = i; - max_orders = bound_window.orders.size(); - } - } -} - -static unique_ptr WindowExecutorFactory(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared, WindowAggregationMode mode) { - switch (wexpr.GetExpressionType()) { - case ExpressionType::WINDOW_AGGREGATE: - return make_uniq(wexpr, context, shared, mode); - case ExpressionType::WINDOW_ROW_NUMBER: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_RANK_DENSE: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_RANK: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_PERCENT_RANK: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_CUME_DIST: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_NTILE: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_FIRST_VALUE: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_LAST_VALUE: - return make_uniq(wexpr, context, shared); - case ExpressionType::WINDOW_NTH_VALUE: - return make_uniq(wexpr, context, shared); - break; - default: - throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.GetExpressionType())); - } -} - -WindowGlobalSinkState::WindowGlobalSinkState(const PhysicalWindow &op, ClientContext &context) - : op(op), context(context) { - - D_ASSERT(op.select_list[op.order_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[op.order_idx]->Cast(); - - const auto mode = DBConfig::GetConfig(context).options.window_mode; - for (idx_t expr_idx = 0; expr_idx < op.select_list.size(); ++expr_idx) { - D_ASSERT(op.select_list[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.select_list[expr_idx]->Cast(); - auto wexec = WindowExecutorFactory(wexpr, context, shared, mode); - executors.emplace_back(std::move(wexec)); - } - - global_partition = make_uniq(*this, wexpr); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalWindow::Sink(ExecutionContext &context, DataChunk &chunk, 
OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast<WindowLocalSinkState>(); - - lstate.Sink(chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalWindow::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast<WindowLocalSinkState>(); - lstate.Combine(); - - return SinkCombineResultType::FINISHED; -} - -unique_ptr<LocalSinkState> PhysicalWindow::GetLocalSinkState(ExecutionContext &context) const { - auto &gstate = sink_state->Cast<WindowGlobalSinkState>(); - return make_uniq<WindowLocalSinkState>(context.client, gstate); -} - -unique_ptr<GlobalSinkState> PhysicalWindow::GetGlobalSinkState(ClientContext &context) const { - return make_uniq<WindowGlobalSinkState>(*this, context); -} - -SinkFinalizeType PhysicalWindow::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &state = input.global_state.Cast<WindowGlobalSinkState>(); - - // Did we get any data? - if (!state.global_partition->count) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Do we have any sorting to schedule? - if (state.global_partition->rows) { - D_ASSERT(!state.global_partition->grouping_data); - return state.global_partition->rows->count ? SinkFinalizeType::READY : SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Find the first group to sort - if (!state.global_partition->HasMergeTasks()) { - // Empty input! - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr<PartitionMergeEvent>(*state.global_partition, pipeline, *this); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class WindowGlobalSourceState : public GlobalSourceState { -public: - using ScannerPtr = unique_ptr<RowDataCollectionScanner>; - - struct Task { - Task(WindowGroupStage stage, idx_t group_idx, idx_t max_idx) - : stage(stage), group_idx(group_idx), thread_idx(0), max_idx(max_idx) { - } - WindowGroupStage stage; - //! The hash group - idx_t group_idx; - //! The thread index (for local state) - idx_t thread_idx; - //! The total block index count - idx_t max_idx; - //! The first block index count - idx_t begin_idx = 0; - //! The end block index count - idx_t end_idx = 0; - }; - using TaskPtr = optional_ptr<Task>; - - WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p); - - //! Build task list - void CreateTaskList(); - - //! Are there any more tasks? - bool HasMoreTasks() const { - return !stopped && next_task < tasks.size(); - } - bool HasUnfinishedTasks() const { - return !stopped && finished < tasks.size(); - } - //! Try to advance the group stage - bool TryPrepareNextStage(); - //! Get the next task given the current state - bool TryNextTask(TaskPtr &task); - //! Finish a task - void FinishTask(TaskPtr task); - - //! Context for executing computations - ClientContext &context; - //! All the sunk data - WindowGlobalSinkState &gsink; - //! The total number of blocks to process - idx_t total_blocks = 0; - //! The number of local states - atomic<idx_t> locals; - //! The list of tasks - vector<Task> tasks; - //! The next task - atomic<idx_t> next_task; - //! The number of finished tasks - atomic<idx_t> finished; - //! Stop producing tasks - atomic<bool> stopped; - //! 
The number of rows returned - atomic returned; - -public: - idx_t MaxThreads() override { - return total_blocks; - } -}; - -WindowGlobalSourceState::WindowGlobalSourceState(ClientContext &context_p, WindowGlobalSinkState &gsink_p) - : context(context_p), gsink(gsink_p), locals(0), next_task(0), finished(0), stopped(false), returned(0) { - auto &gpart = gsink.global_partition; - auto &window_hash_groups = gsink.global_partition->window_hash_groups; - - if (window_hash_groups.empty()) { - // OVER() - if (gpart->rows && !gpart->rows->blocks.empty()) { - // We need to construct the single WindowHashGroup here because the sort tasks will not be run. - window_hash_groups.emplace_back(make_uniq(gsink, idx_t(0))); - total_blocks = gpart->rows->blocks.size(); - } - } else { - idx_t batch_base = 0; - for (auto &window_hash_group : window_hash_groups) { - if (!window_hash_group) { - continue; - } - auto &rows = window_hash_group->rows; - if (!rows) { - continue; - } - - const auto block_count = window_hash_group->rows->blocks.size(); - window_hash_group->batch_base = batch_base; - batch_base += block_count; - } - total_blocks = batch_base; - } -} - -void WindowGlobalSourceState::CreateTaskList() { - // Check whether we have a task list outside the mutex. - if (next_task.load()) { - return; - } - - auto guard = Lock(); - - auto &window_hash_groups = gsink.global_partition->window_hash_groups; - if (!tasks.empty()) { - return; - } - - // Sort the groups from largest to smallest - if (window_hash_groups.empty()) { - return; - } - - using PartitionBlock = std::pair; - vector partition_blocks; - for (idx_t group_idx = 0; group_idx < window_hash_groups.size(); ++group_idx) { - auto &window_hash_group = window_hash_groups[group_idx]; - partition_blocks.emplace_back(window_hash_group->rows->blocks.size(), group_idx); - } - std::sort(partition_blocks.begin(), partition_blocks.end(), std::greater()); - - // Schedule the largest group on as many threads as possible - const auto threads = locals.load(); - const auto &max_block = partition_blocks.front(); - const auto per_thread = (max_block.first + threads - 1) / threads; - if (!per_thread) { - throw InternalException("No blocks per thread! %ld threads, %ld groups, %ld blocks, %ld hash group", threads, - partition_blocks.size(), max_block.first, max_block.second); - } - - // TODO: Generate dynamically instead of building a big list? 
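// The `per_thread` value computed above is plain ceiling division: the largest group's blocks
// are cut into at most `threads` contiguous ranges, and the loop below then emits one Task per
// range, per stage, per group. A minimal sketch of the same splitting arithmetic (standalone;
// `blocks` and `threads` are illustrative names, not from the original file):
//
//     idx_t per_thread = (blocks + threads - 1) / threads; // ceil(blocks / threads)
//     for (idx_t begin = 0; begin < blocks; begin += per_thread) {
//         idx_t end = MinValue<idx_t>(begin + per_thread, blocks);
//         // one Task covers block range [begin, end) for a given (stage, group) pair
//     }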
- vector states {WindowGroupStage::SINK, WindowGroupStage::FINALIZE, WindowGroupStage::GETDATA}; - for (const auto &b : partition_blocks) { - auto &window_hash_group = *window_hash_groups[b.second]; - for (const auto &state : states) { - idx_t thread_count = 0; - for (Task task(state, b.second, b.first); task.begin_idx < task.max_idx; task.begin_idx += per_thread) { - task.end_idx = MinValue(task.begin_idx + per_thread, task.max_idx); - tasks.emplace_back(task); - window_hash_group.tasks_remaining++; - thread_count = ++task.thread_idx; - } - window_hash_group.thread_states.resize(thread_count); - } - } -} - -void WindowHashGroup::MaterializeSortedData() { - auto &global_sort_state = *hash_group->global_sort; - if (global_sort_state.sorted_blocks.empty()) { - return; - } - - // scan the sorted row data - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - auto &sb = *global_sort_state.sorted_blocks[0]; - - // Free up some memory before allocating more - sb.radix_sorting_data.clear(); - sb.blob_sorting_data = nullptr; - - // Move the sorting row blocks into our RDCs - auto &buffer_manager = global_sort_state.buffer_manager; - auto &sd = *sb.payload_data; - - // Data blocks are required - D_ASSERT(!sd.data_blocks.empty()); - auto &block = sd.data_blocks[0]; - rows = make_uniq(buffer_manager, block->capacity, block->entry_size); - rows->blocks = std::move(sd.data_blocks); - rows->count = std::accumulate(rows->blocks.begin(), rows->blocks.end(), idx_t(0), - [&](idx_t c, const unique_ptr &b) { return c + b->count; }); - - // Heap blocks are optional, but we want both for iteration. - if (!sd.heap_blocks.empty()) { - auto &block = sd.heap_blocks[0]; - heap = make_uniq(buffer_manager, block->capacity, block->entry_size); - heap->blocks = std::move(sd.heap_blocks); - hash_group.reset(); - } else { - heap = make_uniq(buffer_manager, buffer_manager.GetBlockSize(), 1U, true); - } - heap->count = std::accumulate(heap->blocks.begin(), heap->blocks.end(), idx_t(0), - [&](idx_t c, const unique_ptr &b) { return c + b->count; }); -} - -WindowHashGroup::WindowHashGroup(WindowGlobalSinkState &gstate, const idx_t hash_bin_p) - : count(0), blocks(0), stage(WindowGroupStage::SINK), hash_bin(hash_bin_p), sunk(0), finalized(0), - tasks_remaining(0), batch_base(0) { - // There are three types of partitions: - // 1. No partition (no sorting) - // 2. One partition (sorting, but no hashing) - // 3. Multiple partitions (sorting and hashing) - - // How big is the partition? 
- auto &gpart = *gstate.global_partition; - layout.Initialize(gpart.payload_types); - if (hash_bin < gpart.hash_groups.size() && gpart.hash_groups[hash_bin]) { - count = gpart.hash_groups[hash_bin]->count; - } else if (gpart.rows && !hash_bin) { - count = gpart.count; - } else { - return; - } - - // Initialise masks to false - partition_mask.Initialize(count); - partition_mask.SetAllInvalid(count); - - const auto &executors = gstate.executors; - for (auto &wexec : executors) { - auto &wexpr = wexec->wexpr; - auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; - if (order_mask.IsMaskSet()) { - continue; - } - order_mask.Initialize(count); - order_mask.SetAllInvalid(count); - } - - // Scan the sorted data into new Collections - external = gpart.external; - if (gpart.rows && !hash_bin) { - // Simple mask - partition_mask.SetValidUnsafe(0); - for (auto &order_mask : order_masks) { - order_mask.second.SetValidUnsafe(0); - } - // No partition - align the heap blocks with the row blocks - rows = gpart.rows->CloneEmpty(gpart.rows->keep_pinned); - heap = gpart.strings->CloneEmpty(gpart.strings->keep_pinned); - RowDataCollectionScanner::AlignHeapBlocks(*rows, *heap, *gpart.rows, *gpart.strings, layout); - external = true; - } else if (hash_bin < gpart.hash_groups.size()) { - // Overwrite the collections with the sorted data - D_ASSERT(gpart.hash_groups[hash_bin].get()); - hash_group = std::move(gpart.hash_groups[hash_bin]); - hash_group->ComputeMasks(partition_mask, order_masks); - external = hash_group->global_sort->external; - MaterializeSortedData(); - } - - if (rows) { - blocks = rows->blocks.size(); - } - - // Set up the collection for any fully materialised data - const auto &shared = WindowSharedExpressions::GetSortedExpressions(gstate.shared.coll_shared); - vector types; - for (auto &expr : shared) { - types.emplace_back(expr->return_type); - } - auto &buffer_manager = BufferManager::GetBufferManager(gstate.context); - collection = make_uniq(buffer_manager, count, types); -} - -// Per-thread scan state -class WindowLocalSourceState : public LocalSourceState { -public: - using Task = WindowGlobalSourceState::Task; - using TaskPtr = optional_ptr; - - explicit WindowLocalSourceState(WindowGlobalSourceState &gsource); - - //! Does the task have more work to do? - bool TaskFinished() const { - return !task || task->begin_idx == task->end_idx; - } - //! Assign the next task - bool TryAssignTask(); - //! Execute a step in the current task - void ExecuteTask(DataChunk &chunk); - - //! The shared source state - WindowGlobalSourceState &gsource; - //! The current batch index (for output reordering) - idx_t batch_index; - //! The task this thread is working on - TaskPtr task; - //! The current source being processed - optional_ptr window_hash_group; - //! The scan cursor - unique_ptr scanner; - //! Buffer for the inputs - DataChunk input_chunk; - //! Buffer for window results - DataChunk output_chunk; - -protected: - void Sink(); - void Finalize(); - void GetData(DataChunk &chunk); - - //! Storage and evaluation for the fully materialised data - unique_ptr builder; - ExpressionExecutor coll_exec; - DataChunk coll_chunk; - - //! Storage and evaluation for chunks used in the sink/build phase - ExpressionExecutor sink_exec; - DataChunk sink_chunk; - - //! 
Storage and evaluation for chunks used in the evaluate phase - ExpressionExecutor eval_exec; - DataChunk eval_chunk; -}; - -WindowHashGroup::ExecutorGlobalStates &WindowHashGroup::Initialize(WindowGlobalSinkState &gsink) { - // Single-threaded building as this is mostly memory allocation - lock_guard gestate_guard(lock); - const auto &executors = gsink.executors; - if (gestates.size() == executors.size()) { - return gestates; - } - - // These can be large so we defer building them until we are ready. - for (auto &wexec : executors) { - auto &wexpr = wexec->wexpr; - auto &order_mask = order_masks[wexpr.partitions.size() + wexpr.orders.size()]; - gestates.emplace_back(wexec->GetGlobalState(count, partition_mask, order_mask)); - } - - return gestates; -} - -void WindowLocalSourceState::Sink() { - D_ASSERT(task); - D_ASSERT(task->stage == WindowGroupStage::SINK); - - auto &gsink = gsource.gsink; - const auto &executors = gsink.executors; - - // Create the global state for each function - // These can be large so we defer building them until we are ready. - auto &gestates = window_hash_group->Initialize(gsink); - - // Set up the local states - auto &local_states = window_hash_group->thread_states.at(task->thread_idx); - if (local_states.empty()) { - for (idx_t w = 0; w < executors.size(); ++w) { - local_states.emplace_back(executors[w]->GetLocalState(*gestates[w])); - } - } - - // First pass over the input without flushing - for (; task->begin_idx < task->end_idx; ++task->begin_idx) { - scanner = window_hash_group->GetBuildScanner(task->begin_idx); - if (!scanner) { - break; - } - while (true) { - // TODO: Try to align on validity mask boundaries by starting ragged? - idx_t input_idx = scanner->Scanned(); - input_chunk.Reset(); - scanner->Scan(input_chunk); - if (input_chunk.size() == 0) { - break; - } - - // Compute fully materialised expressions - if (coll_chunk.data.empty()) { - coll_chunk.SetCardinality(input_chunk); - } else { - coll_chunk.Reset(); - coll_exec.Execute(input_chunk, coll_chunk); - auto collection = window_hash_group->collection.get(); - if (!builder || &builder->collection != collection) { - builder = make_uniq(*collection); - } - - builder->Sink(coll_chunk, input_idx); - } - - // Compute sink expressions - if (sink_chunk.data.empty()) { - sink_chunk.SetCardinality(input_chunk); - } else { - sink_chunk.Reset(); - sink_exec.Execute(input_chunk, sink_chunk); - } - - for (idx_t w = 0; w < executors.size(); ++w) { - executors[w]->Sink(sink_chunk, coll_chunk, input_idx, *gestates[w], *local_states[w]); - } - - window_hash_group->sunk += input_chunk.size(); - } - - // External scanning assumes all blocks are swizzled. - scanner->SwizzleBlock(task->begin_idx); - scanner.reset(); - } -} - -void WindowLocalSourceState::Finalize() { - D_ASSERT(task); - D_ASSERT(task->stage == WindowGroupStage::FINALIZE); - - // First finalize the collection (so the executors can use it) - auto &gsink = gsource.gsink; - if (window_hash_group->collection) { - window_hash_group->collection->Combine(gsink.shared.coll_validity); - } - - // Finalize all the executors. - // Parallel finalisation is handled internally by the executor, - // and should not return until all threads have completed work. 
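// A note on how the stage hand-off works: each FINALIZE task below adds the block range it
// completed to the group's atomic `finalized` counter, and TryPrepareNextStage (earlier in
// this file) only flips the group from FINALIZE to GETDATA once the counter reaches `blocks`.
// A condensed sketch of that counting-barrier pattern (illustrative, using the WindowHashGroup
// members shown above):
//
//     window_hash_group->finalized += (task->end_idx - task->begin_idx);
//     // elsewhere, under the group lock:
//     if (finalized == blocks) { stage = WindowGroupStage::GETDATA; }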
- const auto &executors = gsink.executors; - auto &gestates = window_hash_group->gestates; - auto &local_states = window_hash_group->thread_states.at(task->thread_idx); - for (idx_t w = 0; w < executors.size(); ++w) { - executors[w]->Finalize(*gestates[w], *local_states[w], window_hash_group->collection); - } - - // Mark this range as done - window_hash_group->finalized += (task->end_idx - task->begin_idx); - task->begin_idx = task->end_idx; -} - -WindowLocalSourceState::WindowLocalSourceState(WindowGlobalSourceState &gsource) - : gsource(gsource), batch_index(0), coll_exec(gsource.context), sink_exec(gsource.context), - eval_exec(gsource.context) { - auto &gsink = gsource.gsink; - auto &global_partition = *gsink.global_partition; - - input_chunk.Initialize(global_partition.allocator, global_partition.payload_types); - - vector output_types; - for (auto &wexec : gsink.executors) { - auto &wexpr = wexec->wexpr; - output_types.emplace_back(wexpr.return_type); - } - output_chunk.Initialize(global_partition.allocator, output_types); - - auto &shared = gsink.shared; - shared.PrepareCollection(coll_exec, coll_chunk); - shared.PrepareSink(sink_exec, sink_chunk); - shared.PrepareEvaluate(eval_exec, eval_chunk); - - ++gsource.locals; -} - -bool WindowGlobalSourceState::TryNextTask(TaskPtr &task) { - auto guard = Lock(); - if (next_task >= tasks.size() || stopped) { - task = nullptr; - return false; - } - - // If the next task matches the current state of its group, then we can use it - // Otherwise block. - task = &tasks[next_task]; - - auto &gpart = *gsink.global_partition; - auto &window_hash_group = gpart.window_hash_groups[task->group_idx]; - auto group_stage = window_hash_group->GetStage(); - - if (task->stage == group_stage) { - ++next_task; - return true; - } - - task = nullptr; - return false; -} - -void WindowGlobalSourceState::FinishTask(TaskPtr task) { - if (!task) { - return; - } - - auto &gpart = *gsink.global_partition; - auto &finished_hash_group = gpart.window_hash_groups[task->group_idx]; - D_ASSERT(finished_hash_group); - - if (!--finished_hash_group->tasks_remaining) { - finished_hash_group.reset(); - } -} - -bool WindowLocalSourceState::TryAssignTask() { - // Because downstream operators may be using our internal buffers, - // we can't "finish" a task until we are about to get the next one. - - // Scanner first, as it may be referencing sort blocks in the hash group - scanner.reset(); - gsource.FinishTask(task); - - return gsource.TryNextTask(task); -} - -bool WindowGlobalSourceState::TryPrepareNextStage() { - if (next_task >= tasks.size() || stopped) { - return true; - } - - auto task = &tasks[next_task]; - auto window_hash_group = gsink.global_partition->window_hash_groups[task->group_idx].get(); - return window_hash_group->TryPrepareNextStage(); -} - -void WindowLocalSourceState::ExecuteTask(DataChunk &result) { - auto &gsink = gsource.gsink; - - // Update the hash group - window_hash_group = gsink.global_partition->window_hash_groups[task->group_idx].get(); - - // Process the new state - switch (task->stage) { - case WindowGroupStage::SINK: - Sink(); - D_ASSERT(TaskFinished()); - break; - case WindowGroupStage::FINALIZE: - Finalize(); - D_ASSERT(TaskFinished()); - break; - case WindowGroupStage::GETDATA: - D_ASSERT(!TaskFinished()); - GetData(result); - break; - default: - throw InternalException("Invalid window source state."); - } - - // Count this task as finished. 
- if (TaskFinished()) { - ++gsource.finished; - } -} - -void WindowLocalSourceState::GetData(DataChunk &result) { - D_ASSERT(window_hash_group->GetStage() == WindowGroupStage::GETDATA); - - if (!scanner || !scanner->Remaining()) { - scanner = window_hash_group->GetEvaluateScanner(task->begin_idx); - batch_index = window_hash_group->batch_base + task->begin_idx; - } - - const auto position = scanner->Scanned(); - input_chunk.Reset(); - scanner->Scan(input_chunk); - - const auto &executors = gsource.gsink.executors; - auto &gestates = window_hash_group->gestates; - auto &local_states = window_hash_group->thread_states.at(task->thread_idx); - output_chunk.Reset(); - for (idx_t expr_idx = 0; expr_idx < executors.size(); ++expr_idx) { - auto &executor = *executors[expr_idx]; - auto &gstate = *gestates[expr_idx]; - auto &lstate = *local_states[expr_idx]; - auto &result = output_chunk.data[expr_idx]; - if (eval_chunk.data.empty()) { - eval_chunk.SetCardinality(input_chunk); - } else { - eval_chunk.Reset(); - eval_exec.Execute(input_chunk, eval_chunk); - } - executor.Evaluate(position, eval_chunk, result, lstate, gstate); - } - output_chunk.SetCardinality(input_chunk); - output_chunk.Verify(); - - idx_t out_idx = 0; - result.SetCardinality(input_chunk); - for (idx_t col_idx = 0; col_idx < input_chunk.ColumnCount(); col_idx++) { - result.data[out_idx++].Reference(input_chunk.data[col_idx]); - } - for (idx_t col_idx = 0; col_idx < output_chunk.ColumnCount(); col_idx++) { - result.data[out_idx++].Reference(output_chunk.data[col_idx]); - } - - // If we are done with this block, move to the next one - if (!scanner->Remaining()) { - ++task->begin_idx; - } - - // If that was the last block, release our local state memory. - if (TaskFinished()) { - local_states.clear(); - } - result.Verify(); -} - -unique_ptr<LocalSourceState> PhysicalWindow::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gsource_p) const { - auto &gsource = gsource_p.Cast<WindowGlobalSourceState>(); - return make_uniq<WindowLocalSourceState>(gsource); -} - -unique_ptr<GlobalSourceState> PhysicalWindow::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast<WindowGlobalSinkState>(); - return make_uniq<WindowGlobalSourceState>(context, gsink); -} - -bool PhysicalWindow::SupportsPartitioning(const OperatorPartitionInfo &partition_info) const { - if (partition_info.RequiresPartitionColumns()) { - return false; - } - // We can only preserve order for single partitioning, - // otherwise work stealing causes out-of-order batch numbers - auto &wexpr = select_list[order_idx]->Cast<BoundWindowExpression>(); - return wexpr.partitions.empty(); // NOLINT -} - -OrderPreservationType PhysicalWindow::SourceOrder() const { - auto &wexpr = select_list[order_idx]->Cast<BoundWindowExpression>(); - if (!wexpr.partitions.empty()) { - // if we have partitions the window order is not defined - return OrderPreservationType::NO_ORDER; - } - // without partitions we can maintain order - if (wexpr.orders.empty()) { - // if we have no orders we maintain insertion order - return OrderPreservationType::INSERTION_ORDER; - } - // otherwise we can maintain the fixed order - return OrderPreservationType::FIXED_ORDER; -} - -ProgressData PhysicalWindow::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const { - auto &gsource = gsource_p.Cast<WindowGlobalSourceState>(); - const auto returned = gsource.returned.load(); - - auto &gsink = gsource.gsink; - const auto count = gsink.global_partition->count.load(); - ProgressData res; - if (count) { - res.done = double(returned); - res.total = double(count); - } else { - res.SetInvalid(); - } - return res; -} - -OperatorPartitionData 
PhysicalWindow::GetPartitionData(ExecutionContext &context, DataChunk &chunk, - GlobalSourceState &gstate_p, LocalSourceState &lstate_p, - const OperatorPartitionInfo &partition_info) const { - if (partition_info.RequiresPartitionColumns()) { - throw InternalException("PhysicalWindow::GetPartitionData: partition columns not supported"); - } - auto &lstate = lstate_p.Cast(); - return OperatorPartitionData(lstate.batch_index); -} - -SourceResultType PhysicalWindow::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = input.local_state.Cast(); - - gsource.CreateTaskList(); - - while (gsource.HasUnfinishedTasks() && chunk.size() == 0) { - if (!lsource.TaskFinished() || lsource.TryAssignTask()) { - try { - lsource.ExecuteTask(chunk); - } catch (...) { - gsource.stopped = true; - throw; - } - } else { - auto guard = gsource.Lock(); - if (!gsource.HasMoreTasks()) { - // no more tasks - exit - gsource.UnblockTasks(guard); - break; - } - if (gsource.TryPrepareNextStage()) { - // we successfully prepared the next stage - unblock tasks - gsource.UnblockTasks(guard); - } else { - // there are more tasks available, but we can't execute them yet - // block the source - return gsource.BlockSource(guard, input.interrupt_state); - } - } - } - - gsource.returned += chunk.size(); - - if (chunk.size() == 0) { - return SourceResultType::FINISHED; - } - return SourceResultType::HAVE_MORE_OUTPUT; -} - -InsertionOrderPreservingMap PhysicalWindow::ParamsToString() const { - InsertionOrderPreservingMap result; - string projections; - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - projections += "\n"; - } - projections += select_list[i]->GetName(); - } - result["Projections"] = projections; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp deleted file mode 100644 index c18c5f61e..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer.cpp +++ /dev/null @@ -1,94 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_buffer.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -CSVBuffer::CSVBuffer(ClientContext &context, idx_t buffer_size_p, CSVFileHandle &file_handle, - idx_t &global_csv_current_position, idx_t file_number_p) - : context(context), requested_size(buffer_size_p), file_number(file_number_p), can_seek(file_handle.CanSeek()), - is_pipe(file_handle.IsPipe()) { - AllocateBuffer(buffer_size_p); - auto buffer = Ptr(); - actual_buffer_size = file_handle.Read(buffer, buffer_size_p); - while (actual_buffer_size < buffer_size_p && !file_handle.FinishedReading()) { - // We keep reading until this block is full - actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size_p - actual_buffer_size); - } - global_csv_start = global_csv_current_position; - last_buffer = file_handle.FinishedReading(); -} - -CSVBuffer::CSVBuffer(CSVFileHandle &file_handle, ClientContext &context, idx_t buffer_size, - idx_t global_csv_current_position, idx_t file_number_p, idx_t buffer_idx_p) - : context(context), requested_size(buffer_size), global_csv_start(global_csv_current_position), - file_number(file_number_p), can_seek(file_handle.CanSeek()), is_pipe(file_handle.IsPipe()), - buffer_idx(buffer_idx_p) { - AllocateBuffer(buffer_size); - auto buffer = handle.Ptr(); - 
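// Note: a single file_handle.Read call may return fewer bytes than requested (pipes and
// compressed streams hand data back in chunks), which is why the constructors here keep
// reading in a loop until the block is full or the file is exhausted.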
actual_buffer_size = file_handle.Read(handle.Ptr(), buffer_size); - while (actual_buffer_size < buffer_size && !file_handle.FinishedReading()) { - // We keep reading until this block is full - actual_buffer_size += file_handle.Read(&buffer[actual_buffer_size], buffer_size - actual_buffer_size); - } - last_buffer = file_handle.FinishedReading(); -} - -shared_ptr CSVBuffer::Next(CSVFileHandle &file_handle, idx_t buffer_size, idx_t file_number_p, - bool &has_seaked) { - if (has_seaked) { - // This means that at some point a reload was done, and we are currently on the incorrect position in our file - // handle - file_handle.Seek(global_csv_start + actual_buffer_size); - has_seaked = false; - } - auto next_csv_buffer = make_shared_ptr( - file_handle, context, buffer_size, global_csv_start + actual_buffer_size, file_number_p, buffer_idx + 1); - if (next_csv_buffer->GetBufferSize() == 0) { - // We are done reading - return nullptr; - } - return next_csv_buffer; -} - -void CSVBuffer::AllocateBuffer(idx_t buffer_size) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - bool can_destroy = !is_pipe; - handle = buffer_manager.Allocate(MemoryTag::CSV_READER, MaxValue(buffer_manager.GetBlockSize(), buffer_size), - can_destroy); - block = handle.GetBlockHandle(); -} - -idx_t CSVBuffer::GetBufferSize() const { - return actual_buffer_size; -} - -void CSVBuffer::Reload(CSVFileHandle &file_handle) { - AllocateBuffer(actual_buffer_size); - // If we can seek, we seek and return the correct pointers - file_handle.Seek(global_csv_start); - file_handle.Read(handle.Ptr(), actual_buffer_size); -} - -shared_ptr CSVBuffer::Pin(CSVFileHandle &file_handle, bool &has_seeked) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - if (!is_pipe && block->IsUnloaded()) { - // We have to reload it from disk - block = nullptr; - Reload(file_handle); - has_seeked = true; - } - return make_shared_ptr(buffer_manager.Pin(block), actual_buffer_size, requested_size, last_buffer, - file_number, buffer_idx); -} - -void CSVBuffer::Unpin() { - if (handle.IsValid()) { - handle.Destroy(); - } -} - -bool CSVBuffer::IsCSVFileLastBuffer() const { - return last_buffer; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp deleted file mode 100644 index 811689917..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_buffer_manager.cpp +++ /dev/null @@ -1,160 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_buffer_manager.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_buffer.hpp" -#include "duckdb/function/table/read_csv.hpp" -namespace duckdb { - -CSVBufferManager::CSVBufferManager(ClientContext &context_p, const CSVReaderOptions &options, const string &file_path_p, - const idx_t file_idx_p, bool per_file_single_threaded_p, - unique_ptr file_handle_p) - : context(context_p), per_file_single_threaded(per_file_single_threaded_p), file_idx(file_idx_p), - file_path(file_path_p), buffer_size(options.buffer_size_option.GetValue()) { - D_ASSERT(!file_path.empty()); - if (file_handle_p) { - file_handle = std::move(file_handle_p); - } else { - file_handle = ReadCSV::OpenCSV(file_path, options, context); - } - is_pipe = file_handle->IsPipe(); - skip_rows = options.dialect_options.skip_rows.GetValue(); - Initialize(); -} - -void CSVBufferManager::UnpinBuffer(const idx_t cache_idx) { - if (cache_idx < 
cached_buffers.size()) { - cached_buffers[cache_idx]->Unpin(); - } -} - -void CSVBufferManager::Initialize() { - if (cached_buffers.empty()) { - cached_buffers.emplace_back( - make_shared_ptr<CSVBuffer>(context, buffer_size, *file_handle, global_csv_pos, file_idx)); - last_buffer = cached_buffers.front(); - } -} - -bool CSVBufferManager::ReadNextAndCacheIt() { - D_ASSERT(last_buffer); - for (idx_t i = 0; i < 2; i++) { - if (!last_buffer->IsCSVFileLastBuffer()) { - auto maybe_last_buffer = last_buffer->Next(*file_handle, buffer_size, file_idx, has_seeked); - if (!maybe_last_buffer) { - last_buffer->last_buffer = true; - return false; - } - last_buffer = std::move(maybe_last_buffer); - bytes_read += last_buffer->GetBufferSize(); - cached_buffers.emplace_back(last_buffer); - return true; - } - } - return false; -} - -shared_ptr<CSVBufferHandle> CSVBufferManager::GetBuffer(const idx_t pos) { - lock_guard<mutex> parallel_lock(main_mutex); - if (pos == 0 && done && cached_buffers.empty()) { - if (is_pipe) { - throw InvalidInputException("Recursive CTEs are not allowed when using piped csv files"); - } - // This is a recursive CTE, we have to reset our whole buffer - done = false; - file_handle->Reset(); - Initialize(); - } - while (pos >= cached_buffers.size()) { - if (done) { - return nullptr; - } - if (!ReadNextAndCacheIt()) { - done = true; - } - } - if (pos != 0 && (sniffing || file_handle->CanSeek() || per_file_single_threaded)) { - // We don't need to unpin the buffers here if we are not sniffing since we - // control it per-thread on the scan - if (cached_buffers[pos - 1]) { - cached_buffers[pos - 1]->Unpin(); - } - } - return cached_buffers[pos]->Pin(*file_handle, has_seeked); -} - -void CSVBufferManager::ResetBuffer(const idx_t buffer_idx) { - lock_guard<mutex> parallel_lock(main_mutex); - if (buffer_idx >= cached_buffers.size()) { - // Nothing to reset - return; - } - D_ASSERT(cached_buffers[buffer_idx]); - if (buffer_idx == 0 && cached_buffers.size() > 1) { - cached_buffers[buffer_idx].reset(); - idx_t cur_buffer = buffer_idx + 1; - while (reset_when_possible.find(cur_buffer) != reset_when_possible.end()) { - cached_buffers[cur_buffer].reset(); - reset_when_possible.erase(cur_buffer); - cur_buffer++; - } - return; - } - // We only reset if the previous one was also already reset - if (buffer_idx > 0 && !cached_buffers[buffer_idx - 1]) { - if (cached_buffers[buffer_idx]->last_buffer) { - // We clear the whole shebang - cached_buffers.clear(); - reset_when_possible.clear(); - return; - } - cached_buffers[buffer_idx].reset(); - idx_t cur_buffer = buffer_idx + 1; - while (reset_when_possible.find(cur_buffer) != reset_when_possible.end()) { - cached_buffers[cur_buffer].reset(); - reset_when_possible.erase(cur_buffer); - cur_buffer++; - } - } else { - reset_when_possible.insert(buffer_idx); - } -} - -idx_t CSVBufferManager::GetBufferSize() const { - return buffer_size; -} - -idx_t CSVBufferManager::BufferCount() const { - return cached_buffers.size(); -} - -bool CSVBufferManager::Done() const { - return done; -} - -void CSVBufferManager::ResetBufferManager() { - if (!file_handle->IsPipe()) { - // If this is not a pipe we reset the buffer manager and restart it when doing the actual scan - cached_buffers.clear(); - reset_when_possible.clear(); - file_handle->Reset(); - last_buffer = nullptr; - done = false; - global_csv_pos = 0; - Initialize(); - } -} - -string CSVBufferManager::GetFilePath() const { - return file_path; -} - -bool CSVBufferManager::IsBlockUnloaded(idx_t block_idx) { - if (block_idx < cached_buffers.size()) { - return 
cached_buffers[block_idx]->IsUnloaded(); - } - return false; -} - -idx_t CSVBufferManager::GetBytesRead() const { - return bytes_read; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp b/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp deleted file mode 100644 index 73fbb4ed0..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/buffer_manager/csv_file_handle.cpp +++ /dev/null @@ -1,126 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_file_handle.hpp" -#include "duckdb/common/exception/binder_exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/compressed_file_system.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" - -namespace duckdb { - -CSVFileHandle::CSVFileHandle(DBConfig &config, unique_ptr file_handle_p, const string &path_p, - const CSVReaderOptions &options) - : compression_type(options.compression), file_handle(std::move(file_handle_p)), - encoder(config, options.encoding, options.buffer_size_option.GetValue()), path(path_p) { - can_seek = file_handle->CanSeek(); - on_disk_file = file_handle->OnDiskFile(); - file_size = file_handle->GetFileSize(); - is_pipe = file_handle->IsPipe(); - compression_type = file_handle->GetFileCompressionType(); -} - -unique_ptr CSVFileHandle::OpenFileHandle(FileSystem &fs, Allocator &allocator, const string &path, - FileCompressionType compression) { - auto file_handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_READ | compression); - if (file_handle->CanSeek()) { - file_handle->Reset(); - } - return file_handle; -} - -unique_ptr CSVFileHandle::OpenFile(DBConfig &config, FileSystem &fs, Allocator &allocator, - const string &path, const CSVReaderOptions &options) { - auto file_handle = OpenFileHandle(fs, allocator, path, options.compression); - return make_uniq(config, std::move(file_handle), path, options); -} - -double CSVFileHandle::GetProgress() const { - return static_cast(file_handle->GetProgress()); -} - -bool CSVFileHandle::CanSeek() const { - return can_seek; -} - -void CSVFileHandle::Seek(const idx_t position) const { - if (!can_seek) { - if (is_pipe) { - throw InternalException("Trying to seek a piped CSV File."); - } - throw InternalException("Trying to seek a compressed CSV File."); - } - file_handle->Seek(position); -} - -bool CSVFileHandle::OnDiskFile() const { - return on_disk_file; -} - -void CSVFileHandle::Reset() { - file_handle->Reset(); - finished = false; - requested_bytes = 0; -} - -bool CSVFileHandle::IsPipe() const { - return is_pipe; -} - -idx_t CSVFileHandle::FileSize() const { - return file_size; -} - -bool CSVFileHandle::FinishedReading() const { - return finished; -} - -idx_t CSVFileHandle::Read(void *buffer, idx_t nr_bytes) { - requested_bytes += nr_bytes; - // if this is a plain file source OR we can seek we are not caching anything - idx_t bytes_read = 0; - if (encoder.encoding_name == "utf-8") { - bytes_read = static_cast(file_handle->Read(buffer, nr_bytes)); - } else { - bytes_read = encoder.Encode(*file_handle, static_cast(buffer), nr_bytes); - } - if (!finished) { - finished = bytes_read == 0; - } - uncompressed_bytes_read += static_cast(bytes_read); - return UnsafeNumericCast(bytes_read); -} - -string CSVFileHandle::ReadLine() { - bool carriage_return = false; - string result; - char buffer[1]; - while (true) { - idx_t bytes_read = Read(buffer, 1); - if (bytes_read == 0) { - return result; - } 
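The Read() implementation above only pays for transcoding when the file is not already UTF-8. A minimal sketch of that dispatch, with RawSource and Transcode as hypothetical stand-ins for FileHandle and CSVEncoder::Encode:

#include <cstddef>
#include <string>

// Hypothetical stand-in for a FileHandle: reads raw bytes from memory.
struct RawSource {
    const char *data;
    size_t size;
    size_t offset = 0;
    size_t Read(char *out, size_t n) {
        if (n > size - offset) {
            n = size - offset;
        }
        for (size_t i = 0; i < n; i++) {
            out[i] = data[offset + i];
        }
        offset += n;
        return n;
    }
};

// Hypothetical stand-in for CSVEncoder::Encode; a real encoder would convert
// (e.g.) UTF-16 input bytes into UTF-8 output here.
static size_t Transcode(RawSource &src, char *out, size_t n) {
    return src.Read(out, n);
}

// Mirrors the dispatch above: UTF-8 goes straight into the target buffer, any
// other encoding is routed through the transcoder first.
static size_t ReadDecoded(RawSource &src, const std::string &encoding, char *out, size_t n) {
    if (encoding == "utf-8") {
        return src.Read(out, n);
    }
    return Transcode(src, out, n);
}

int main() {
    char buf[8];
    RawSource src {"a,b,c", 5};
    return ReadDecoded(src, "utf-8", buf, sizeof(buf)) == 5 ? 0 : 1;
}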
- if (carriage_return) { - if (buffer[0] != '\n') { - if (!file_handle->CanSeek()) { - throw BinderException( - "Carriage return newlines not supported when reading CSV files in which we cannot seek"); - } - file_handle->Seek(file_handle->SeekPosition() - 1); - return result; - } - } - if (buffer[0] == '\n') { - return result; - } - if (buffer[0] != '\r') { - result += buffer[0]; - } else { - carriage_return = true; - } - } -} - -string CSVFileHandle::GetFilePath() { - return path; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/encode/csv_encoder.cpp b/src/duckdb/src/execution/operator/csv_scanner/encode/csv_encoder.cpp deleted file mode 100644 index 89fc5df04..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/encode/csv_encoder.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/encode/csv_encoder.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/function/encoding_function.hpp" - -namespace duckdb { - -void CSVEncoderBuffer::Initialize(idx_t encoded_size) { - encoded_buffer_size = encoded_size; - encoded_buffer = std::unique_ptr(new char[encoded_size]); -} - -char *CSVEncoderBuffer::Ptr() const { - return encoded_buffer.get(); -} - -idx_t CSVEncoderBuffer::GetCapacity() const { - return encoded_buffer_size; -} - -idx_t CSVEncoderBuffer::GetSize() const { - return actual_encoded_buffer_size; -} - -void CSVEncoderBuffer::SetSize(const idx_t buffer_size) { - D_ASSERT(buffer_size <= encoded_buffer_size); - actual_encoded_buffer_size = buffer_size; -} - -bool CSVEncoderBuffer::HasDataToRead() const { - return cur_pos < actual_encoded_buffer_size; -} - -void CSVEncoderBuffer::Reset() { - cur_pos = 0; - actual_encoded_buffer_size = 0; -} - -CSVEncoder::CSVEncoder(DBConfig &config, const string &encoding_name_to_find, idx_t buffer_size) { - encoding_name = StringUtil::Lower(encoding_name_to_find); - auto function = config.GetEncodeFunction(encoding_name_to_find); - if (!function) { - auto loaded_encodings = config.GetLoadedEncodedFunctions(); - std::ostringstream error; - error << "The CSV Reader does not support the encoding: \"" << encoding_name_to_find << "\"\n"; - error << "The currently supported encodings are: " << '\n'; - for (auto &encoding_function : loaded_encodings) { - error << "* " << encoding_function.get().GetType() << '\n'; - } - throw InvalidInputException(error.str()); - } - // We ensure that the encoded buffer size is an even number to make the two byte lookup on utf-16 work - idx_t encoded_buffer_size = buffer_size % 2 != 0 ? buffer_size - 1 : buffer_size; - D_ASSERT(encoded_buffer_size > 0); - encoded_buffer.Initialize(encoded_buffer_size); - remaining_bytes_buffer.Initialize(function->GetBytesPerIteration()); - encoding_function = function; -} - -idx_t CSVEncoder::Encode(FileHandle &file_handle_input, char *output_buffer, const idx_t decoded_buffer_size) { - idx_t output_buffer_pos = 0; - // Check if we have some left-overs. These can either be - // 1. missing decoded bytes - if (remaining_bytes_buffer.HasDataToRead()) { - D_ASSERT(remaining_bytes_buffer.cur_pos == 0); - const auto remaining_bytes_buffer_ptr = remaining_bytes_buffer.Ptr(); - for (; remaining_bytes_buffer.cur_pos < remaining_bytes_buffer.GetSize(); remaining_bytes_buffer.cur_pos++) { - output_buffer[output_buffer_pos++] = remaining_bytes_buffer_ptr[remaining_bytes_buffer.cur_pos]; - } - remaining_bytes_buffer.Reset(); - } - // 2. 
remaining encoded buffer - if (encoded_buffer.HasDataToRead()) { - encoding_function->GetFunction()( - encoded_buffer.Ptr(), encoded_buffer.cur_pos, encoded_buffer.GetSize(), output_buffer, output_buffer_pos, - decoded_buffer_size, remaining_bytes_buffer.Ptr(), remaining_bytes_buffer.actual_encoded_buffer_size); - } - // 3. a new encoded buffer from the file - while (output_buffer_pos < decoded_buffer_size) { - idx_t current_decoded_buffer_start = output_buffer_pos; - encoded_buffer.Reset(); - auto actual_encoded_bytes = - static_cast(file_handle_input.Read(encoded_buffer.Ptr(), encoded_buffer.GetCapacity())); - encoded_buffer.SetSize(actual_encoded_bytes); - encoding_function->GetFunction()( - encoded_buffer.Ptr(), encoded_buffer.cur_pos, encoded_buffer.GetSize(), output_buffer, output_buffer_pos, - decoded_buffer_size, remaining_bytes_buffer.Ptr(), remaining_bytes_buffer.actual_encoded_buffer_size); - if (output_buffer_pos == current_decoded_buffer_start) { - return output_buffer_pos; - } - } - return output_buffer_pos; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp deleted file mode 100644 index 757598e14..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/base_scanner.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/base_scanner.hpp" - -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" - -namespace duckdb { - -ScannerResult::ScannerResult(CSVStates &states_p, CSVStateMachine &state_machine_p, idx_t result_size_p) - : result_size(result_size_p), state_machine(state_machine_p), states(states_p) { -} - -BaseScanner::BaseScanner(shared_ptr buffer_manager_p, shared_ptr state_machine_p, - shared_ptr error_handler_p, bool sniffing_p, - shared_ptr csv_file_scan_p, CSVIterator iterator_p) - : csv_file_scan(std::move(csv_file_scan_p)), sniffing(sniffing_p), error_handler(std::move(error_handler_p)), - state_machine(std::move(state_machine_p)), buffer_manager(std::move(buffer_manager_p)), iterator(iterator_p) { - D_ASSERT(buffer_manager); - D_ASSERT(state_machine); - // Initialize current buffer handle - cur_buffer_handle = buffer_manager->GetBuffer(iterator.GetBufferIdx()); - if (!cur_buffer_handle) { - buffer_handle_ptr = nullptr; - } else { - buffer_handle_ptr = cur_buffer_handle->Ptr(); - } -} - -bool BaseScanner::FinishedFile() const { - if (!cur_buffer_handle) { - return true; - } - // we have to scan to infinity, so we must check if we are done checking the whole file - if (!buffer_manager->Done()) { - return false; - } - // If yes, are we in the last buffer? - if (iterator.pos.buffer_idx != buffer_manager->BufferCount()) { - return false; - } - // If yes, are we in the last position? 
- return iterator.pos.buffer_pos + 1 == cur_buffer_handle->actual_size; -} - -CSVIterator BaseScanner::SkipCSVRows(shared_ptr<CSVBufferManager> buffer_manager, - const shared_ptr<CSVStateMachine> &state_machine, idx_t rows_to_skip) { - if (rows_to_skip == 0) { - return {}; - } - auto error_handler = make_shared_ptr<CSVErrorHandler>(); - SkipScanner row_skipper(std::move(buffer_manager), state_machine, error_handler, rows_to_skip); - row_skipper.ParseChunk(); - return row_skipper.GetIterator(); -} - -CSVIterator &BaseScanner::GetIterator() { - return iterator; -} - -void BaseScanner::SetIterator(const CSVIterator &it) { - iterator = it; -} - -ScannerResult &BaseScanner::ParseChunk() { - throw InternalException("ParseChunk() from CSV Base Scanner is not implemented"); -} - -ScannerResult &BaseScanner::GetResult() { - throw InternalException("GetResult() from CSV Base Scanner is not implemented"); -} - -void BaseScanner::Initialize() { - throw InternalException("Initialize() from CSV Base Scanner is not implemented"); -} - -void BaseScanner::FinalizeChunkProcess() { - throw InternalException("FinalizeChunkProcess() from CSV Base Scanner is not implemented"); -} - -CSVStateMachine &BaseScanner::GetStateMachine() const { - return *state_machine; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp deleted file mode 100644 index 7c4b7bb9c..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/column_count_scanner.cpp +++ /dev/null @@ -1,163 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/column_count_scanner.hpp" - -namespace duckdb { - -ColumnCountResult::ColumnCountResult(CSVStates &states, CSVStateMachine &state_machine, idx_t result_size) - : ScannerResult(states, state_machine, result_size) { - column_counts.resize(result_size); -} - -void ColumnCountResult::AddValue(ColumnCountResult &result, idx_t buffer_pos) { - result.current_column_count++; -} - -inline void ColumnCountResult::InternalAddRow() { - const idx_t column_count = current_column_count + 1; - column_counts[result_position].number_of_columns = column_count; - rows_per_column_count[column_count]++; - current_column_count = 0; -} - -idx_t ColumnCountResult::GetMostFrequentColumnCount() const { - if (rows_per_column_count.empty()) { - return 1; - } - idx_t column_count = 0; - idx_t current_max = 0; - for (auto &rpc : rows_per_column_count) { - if (rpc.second > current_max) { - current_max = rpc.second; - column_count = rpc.first; - } else if (rpc.second == current_max) { - // We pick the largest to break the tie - if (rpc.first > column_count) { - column_count = rpc.first; - } - } - } - return column_count; -} - -bool ColumnCountResult::AddRow(ColumnCountResult &result, idx_t buffer_pos) { - result.InternalAddRow(); - if (!result.states.EmptyLastValue()) { - idx_t col_count_idx = result.result_position; - for (idx_t i = 0; i < result.result_position + 1; i++) { - if (!result.column_counts[col_count_idx].last_value_always_empty) { - break; - } - result.column_counts[col_count_idx--].last_value_always_empty = false; - } - } - result.result_position++; - if (result.result_position >= result.result_size) { - // We sniffed enough rows - return true; - } - return false; -} - -void ColumnCountResult::SetComment(ColumnCountResult &result, idx_t buffer_pos) { - if (!result.states.WasStandard()) { - result.cur_line_starts_as_comment = true; - } - result.comment = true; -} - -bool ColumnCountResult::UnsetComment(ColumnCountResult
&result, idx_t buffer_pos) { - // If we are unsetting a comment, it means this row started with a comment char. - // We add the row but tag it as a comment - bool done = result.AddRow(result, buffer_pos); - if (result.cur_line_starts_as_comment) { - result.column_counts[result.result_position - 1].is_comment = true; - } else { - result.column_counts[result.result_position - 1].is_mid_comment = true; - } - result.comment = false; - result.cur_line_starts_as_comment = false; - return done; -} - -void ColumnCountResult::InvalidState(ColumnCountResult &result) { - result.result_position = 0; - result.error = true; -} - -bool ColumnCountResult::EmptyLine(ColumnCountResult &result, idx_t buffer_pos) { - // nop - return false; -} - -void ColumnCountResult::QuotedNewLine(ColumnCountResult &result) { - // nop -} - -ColumnCountScanner::ColumnCountScanner(shared_ptr buffer_manager, - const shared_ptr &state_machine, - shared_ptr error_handler, idx_t result_size_p, - CSVIterator iterator) - : BaseScanner(std::move(buffer_manager), state_machine, std::move(error_handler), true, nullptr, iterator), - result(states, *state_machine, result_size_p), column_count(1), result_size(result_size_p) { - sniffing = true; -} - -unique_ptr ColumnCountScanner::UpgradeToStringValueScanner() { - idx_t rows_to_skip = - std::max(state_machine->dialect_options.skip_rows.GetValue(), state_machine->dialect_options.rows_until_header); - auto iterator = SkipCSVRows(buffer_manager, state_machine, rows_to_skip); - if (iterator.done) { - CSVIterator it {}; - return make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true, it, - result_size); - } - return make_uniq(0U, buffer_manager, state_machine, error_handler, nullptr, true, iterator, - result_size); -} - -ColumnCountResult &ColumnCountScanner::ParseChunk() { - result.result_position = 0; - column_count = 1; - ParseChunkInternal(result); - return result; -} - -ColumnCountResult &ColumnCountScanner::GetResult() { - return result; -} - -void ColumnCountScanner::Initialize() { - states.Initialize(); -} - -void ColumnCountScanner::FinalizeChunkProcess() { - if (result.result_position == result.result_size || result.error) { - // We are done - return; - } - // We run until we have a full chunk, or we are done scanning - while (!FinishedFile() && result.result_position < result.result_size && !result.error) { - if (iterator.pos.buffer_pos == cur_buffer_handle->actual_size) { - // Move to next buffer - cur_buffer_handle = buffer_manager->GetBuffer(++iterator.pos.buffer_idx); - if (!cur_buffer_handle) { - buffer_handle_ptr = nullptr; - if (states.EmptyLine() || states.NewRow() || states.IsCurrentNewRow() || states.IsNotSet()) { - return; - } - // This means we reached the end of the file, we must add a last line if there is any to be added - if (result.comment) { - // If it's a comment we add the last line via unsetcomment - result.UnsetComment(result, NumericLimits::Maximum()); - } else { - // OW, we do a regular AddRow - result.AddRow(result, NumericLimits::Maximum()); - } - return; - } - iterator.pos.buffer_pos = 0; - buffer_handle_ptr = cur_buffer_handle->Ptr(); - } - Process(result); - } -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp deleted file mode 100644 index f99e46a9a..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/csv_schema.cpp +++ /dev/null @@ -1,217 +0,0 @@ -#include 
"duckdb/execution/operator/csv_scanner/csv_schema.hpp" - -namespace duckdb { - -struct TypeIdxPair { - TypeIdxPair(LogicalType type_p, idx_t idx_p) : type(std::move(type_p)), idx(idx_p) { - } - TypeIdxPair() { - } - LogicalType type; - idx_t idx {}; -}; - -// We only really care about types that can be set in the sniffer_auto, or are sniffed by default -// If the user manually sets them, we should never get a cast issue from the sniffer! -bool CSVSchema::CanWeCastIt(LogicalTypeId source, LogicalTypeId destination) { - if (destination == LogicalTypeId::VARCHAR || source == destination) { - // We can always cast to varchar - // And obviously don't have to do anything if they are equal. - return true; - } - switch (source) { - case LogicalTypeId::SQLNULL: - return true; - case LogicalTypeId::TINYINT: - return destination == LogicalTypeId::SMALLINT || destination == LogicalTypeId::INTEGER || - destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || - destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; - case LogicalTypeId::SMALLINT: - return destination == LogicalTypeId::INTEGER || destination == LogicalTypeId::BIGINT || - destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || - destination == LogicalTypeId::DOUBLE; - case LogicalTypeId::INTEGER: - return destination == LogicalTypeId::BIGINT || destination == LogicalTypeId::DECIMAL || - destination == LogicalTypeId::FLOAT || destination == LogicalTypeId::DOUBLE; - case LogicalTypeId::BIGINT: - return destination == LogicalTypeId::DECIMAL || destination == LogicalTypeId::FLOAT || - destination == LogicalTypeId::DOUBLE; - case LogicalTypeId::FLOAT: - return destination == LogicalTypeId::DOUBLE; - default: - return false; - } -} - -void CSVSchema::MergeSchemas(CSVSchema &other, bool null_padding) { - // TODO: We could also merge names, maybe by giving preference to non-generated names? 
- const vector candidates_by_specificity = {LogicalType::BOOLEAN, LogicalType::BIGINT, - LogicalType::DOUBLE, LogicalType::VARCHAR}; - for (idx_t i = 0; i < columns.size() && i < other.columns.size(); i++) { - auto this_type = columns[i].type.id(); - auto other_type = other.columns[i].type.id(); - if (columns[i].type != other.columns[i].type) { - if (CanWeCastIt(this_type, other_type)) { - // If we can cast this to other, this becomes other - columns[i].type = other.columns[i].type; - } else if (!CanWeCastIt(other_type, this_type)) { - // If we can't cast this to other or other to this, we see which parent they can be both cast to - for (const auto &type : candidates_by_specificity) { - if (CanWeCastIt(this_type, type.id()) && CanWeCastIt(other_type, type.id())) { - columns[i].type = type; - break; - } - } - } - } - } - - if (null_padding && other.columns.size() > columns.size()) { - for (idx_t i = columns.size(); i < other.columns.size(); i++) { - auto name = other.columns[i].name; - auto type = other.columns[i].type; - columns.push_back({name, type}); - name_idx_map[name] = i; - } - } -} - -CSVSchema::CSVSchema(vector &names, vector &types, const string &file_path, idx_t rows_read_p, - const bool empty_p) - : rows_read(rows_read_p), empty(empty_p) { - Initialize(names, types, file_path); -} - -void CSVSchema::Initialize(const vector &names, const vector &types, const string &file_path_p) { - if (!columns.empty()) { - throw InternalException("CSV Schema is already populated, this should not happen."); - } - file_path = file_path_p; - D_ASSERT(names.size() == types.size() && !names.empty()); - for (idx_t i = 0; i < names.size(); i++) { - // Populate our little schema - auto name = names.at(i); - auto type = types.at(i); - columns.push_back({name, type}); - name_idx_map[names[i]] = i; - } -} - -vector CSVSchema::GetNames() const { - vector names; - for (auto &column : columns) { - names.push_back(column.name); - } - return names; -} - -vector CSVSchema::GetTypes() const { - vector types; - for (auto &column : columns) { - types.push_back(column.type); - } - return types; -} - -bool CSVSchema::Empty() const { - return columns.empty(); -} - -bool CSVSchema::MatchColumns(const CSVSchema &other) const { - return other.columns.size() == columns.size() || empty || other.empty; -} - -string CSVSchema::GetPath() const { - return file_path; -} - -idx_t CSVSchema::GetColumnCount() const { - return columns.size(); -} - -idx_t CSVSchema::GetRowsRead() const { - return rows_read; -} - -bool CSVSchema::SchemasMatch(string &error_message, SnifferResult &sniffer_result, const string &cur_file_path, - bool is_minimal_sniffer) const { - D_ASSERT(sniffer_result.names.size() == sniffer_result.return_types.size()); - bool match = true; - unordered_map current_schema; - - for (idx_t i = 0; i < sniffer_result.names.size(); i++) { - // Populate our little schema - current_schema[sniffer_result.names[i]] = {sniffer_result.return_types[i], i}; - } - if (is_minimal_sniffer) { - auto min_sniffer = static_cast(sniffer_result); - if (!min_sniffer.more_than_one_row) { - bool min_sniff_match = true; - // If we don't have more than one row, either the names must match or the types must match. - for (auto &column : columns) { - if (current_schema.find(column.name) == current_schema.end()) { - min_sniff_match = false; - break; - } - } - if (min_sniff_match) { - return true; - } - // Otherwise, the types must match. 
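When neither column type widens into the other, the MergeSchemas loop above falls back to the first entry of candidates_by_specificity that both sides can be cast to. A worked sketch of that resolution order, using a deliberately simplified widening rule (hypothetical names):

#include <cassert>
#include <vector>

enum class T { BOOLEAN, BIGINT, DOUBLE, VARCHAR };

// Simplified widening rule standing in for CanWeCastIt: exact matches, casts
// to VARCHAR, and BIGINT -> DOUBLE.
static bool CanWiden(T src, T dst) {
    return src == dst || dst == T::VARCHAR || (src == T::BIGINT && dst == T::DOUBLE);
}

// Mirrors the merge fallback: prefer a direct widening in either direction,
// otherwise take the first candidate (ordered by specificity) that both sides
// widen into.
static T MergeTypes(T a, T b) {
    if (CanWiden(a, b)) {
        return b;
    }
    if (CanWiden(b, a)) {
        return a;
    }
    const std::vector<T> candidates = {T::BOOLEAN, T::BIGINT, T::DOUBLE, T::VARCHAR};
    for (T c : candidates) {
        if (CanWiden(a, c) && CanWiden(b, c)) {
            return c;
        }
    }
    return T::VARCHAR;
}

int main() {
    assert(MergeTypes(T::BIGINT, T::DOUBLE) == T::DOUBLE);    // one side widens directly
    assert(MergeTypes(T::BOOLEAN, T::BIGINT) == T::VARCHAR);  // no common numeric parent
}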
- min_sniff_match = true; - if (sniffer_result.return_types.size() == columns.size()) { - idx_t return_type_idx = 0; - for (auto &column : columns) { - if (column.type != sniffer_result.return_types[return_type_idx++]) { - min_sniff_match = false; - break; - } - } - } else { - min_sniff_match = false; - } - if (min_sniff_match) { - // If we got here, we have the right types but the wrong names, lets fix the names - idx_t sniff_name_idx = 0; - for (auto &column : columns) { - sniffer_result.names[sniff_name_idx++] = column.name; - } - return true; - } - } - // If we got to this point, the minimal sniffer doesn't match, we throw an error. - } - // Here we check if the schema of a given file matched our original schema - // We consider it's not a match if: - // 1. The file misses columns that were defined in the original schema. - // 2. They have a column match, but the types do not match. - std::ostringstream error; - error << "Schema mismatch between globbed files." - << "\n"; - error << "Main file schema: " << file_path << "\n"; - error << "Current file: " << cur_file_path << "\n"; - - for (auto &column : columns) { - if (current_schema.find(column.name) == current_schema.end()) { - error << "Column with name: \"" << column.name << "\" is missing" - << "\n"; - match = false; - } else { - if (!CanWeCastIt(current_schema[column.name].type.id(), column.type.id())) { - error << "Column with name: \"" << column.name - << "\" is expected to have type: " << column.type.ToString(); - error << " But has type: " << current_schema[column.name].type.ToString() << "\n"; - match = false; - } - } - } - - // Lets suggest some potential fixes - error << "Potential Fix: Since your schema has a mismatch, consider setting union_by_name=true."; - if (!match) { - error_message = error.str(); - } - return match; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp deleted file mode 100644 index b0511bec6..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/scanner_boundary.cpp +++ /dev/null @@ -1,137 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/scanner_boundary.hpp" - -namespace duckdb { - -CSVPosition::CSVPosition(idx_t buffer_idx_p, idx_t buffer_pos_p) : buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p) { -} -CSVPosition::CSVPosition() { -} - -CSVBoundary::CSVBoundary(idx_t buffer_idx_p, idx_t buffer_pos_p, idx_t boundary_idx_p, idx_t end_pos_p) - : buffer_idx(buffer_idx_p), buffer_pos(buffer_pos_p), boundary_idx(boundary_idx_p), end_pos(end_pos_p) { -} -CSVBoundary::CSVBoundary() : buffer_idx(0), buffer_pos(0), boundary_idx(0), end_pos(NumericLimits::Maximum()) { -} - -CSVIterator::CSVIterator() : is_set(false) { -} - -void CSVBoundary::Print() const { -#ifndef DUCKDB_DISABLE_PRINT - std::cout << "---Boundary: " << boundary_idx << " ---" << '\n'; - std::cout << "Buffer Index: " << buffer_idx << '\n'; - std::cout << "Buffer Pos: " << buffer_pos << '\n'; - std::cout << "End Pos: " << end_pos << '\n'; - std::cout << "------------" << end_pos << '\n'; -#endif -} - -void CSVIterator::Print() const { -#ifndef DUCKDB_DISABLE_PRINT - boundary.Print(); - std::cout << "Is set: " << is_set << '\n'; -#endif -} - -idx_t CSVIterator::BytesPerThread(const CSVReaderOptions &reader_options) { - const idx_t buffer_size = reader_options.buffer_size_option.GetValue(); - const idx_t max_row_size = reader_options.maximum_line_size.GetValue(); - const idx_t 
bytes_per_thread = buffer_size / CSVBuffer::ROWS_PER_BUFFER * ROWS_PER_THREAD; - if (bytes_per_thread < max_row_size) { - // If we are setting up the buffer size directly, we must make sure each thread will read the full buffer. - return max_row_size; - } - return bytes_per_thread; -} - -bool CSVIterator::Next(CSVBufferManager &buffer_manager, const CSVReaderOptions &reader_options) { - if (!is_set) { - return false; - } - const auto bytes_per_thread = BytesPerThread(reader_options); - - // If we are calling next this is not the first one anymore - first_one = false; - boundary.boundary_idx++; - // This is our start buffer - auto buffer = buffer_manager.GetBuffer(boundary.buffer_idx); - if (buffer->is_last_buffer && boundary.buffer_pos + bytes_per_thread > buffer->actual_size) { - // 1) We are done with the current file - return false; - } else if (boundary.buffer_pos + bytes_per_thread >= buffer->actual_size) { - // 2) We still have data to scan in this file, we set the iterator accordingly. - // We must move the buffer - boundary.buffer_idx++; - boundary.buffer_pos = 0; - // Verify this buffer really exists - auto next_buffer = buffer_manager.GetBuffer(boundary.buffer_idx); - if (!next_buffer) { - return false; - } - - } else { - // 3) We are not done with the current buffer, hence we just move where we start within the buffer - boundary.buffer_pos += bytes_per_thread; - } - boundary.end_pos = boundary.buffer_pos + bytes_per_thread; - SetCurrentPositionToBoundary(); - return true; -} - -bool CSVIterator::IsBoundarySet() const { - return is_set; -} -idx_t CSVIterator::GetEndPos() const { - return boundary.end_pos; -} - -idx_t CSVIterator::GetBufferIdx() const { - return pos.buffer_idx; -} - -idx_t CSVIterator::GetBoundaryIdx() const { - return boundary.boundary_idx; -} - -void CSVIterator::SetCurrentPositionToBoundary() { - pos.buffer_idx = boundary.buffer_idx; - pos.buffer_pos = boundary.buffer_pos; -} - -void CSVIterator::SetCurrentBoundaryToPosition(bool single_threaded, const CSVReaderOptions &reader_options) { - if (single_threaded) { - is_set = false; - return; - } - const auto bytes_per_thread = BytesPerThread(reader_options); - - boundary.buffer_idx = pos.buffer_idx; - if (pos.buffer_pos == 0) { - boundary.end_pos = bytes_per_thread; - } else { - boundary.end_pos = ((pos.buffer_pos + bytes_per_thread - 1) / bytes_per_thread) * bytes_per_thread; - } - - boundary.buffer_pos = boundary.end_pos - bytes_per_thread; - is_set = true; -} - -void CSVIterator::SetStart(idx_t start) { - boundary.buffer_pos = start; -} - -void CSVIterator::SetEnd(idx_t pos) { - boundary.end_pos = pos; -} - -void CSVIterator::CheckIfDone() { - if (IsBoundarySet() && (pos.buffer_idx > boundary.buffer_idx || pos.buffer_pos > boundary.buffer_pos)) { - done = true; - } -} - -idx_t CSVIterator::GetGlobalCurrentPos() const { - return pos.buffer_pos + buffer_size * pos.buffer_idx; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp deleted file mode 100644 index 3afe22d6b..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/skip_scanner.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" -#include "duckdb/execution/operator/csv_scanner/column_count_scanner.hpp" - -namespace duckdb { - -SkipResult::SkipResult(CSVStates &states, CSVStateMachine &state_machine, idx_t rows_to_skip_p) - : ScannerResult(states, state_machine, 
STANDARD_VECTOR_SIZE), rows_to_skip(rows_to_skip_p) { -} - -void SkipResult::AddValue(SkipResult &result, const idx_t buffer_pos) { - // nop -} - -inline void SkipResult::InternalAddRow() { - row_count++; -} - -void SkipResult::QuotedNewLine(SkipResult &result) { - // nop -} - -bool SkipResult::UnsetComment(SkipResult &result, idx_t buffer_pos) { - // If we are unsetting a comment, it means this row started with a comment char. - // We add the row but tag it as a comment - bool done = result.AddRow(result, buffer_pos); - result.comment = false; - return done; -} - -bool SkipResult::AddRow(SkipResult &result, const idx_t buffer_pos) { - result.InternalAddRow(); - if (result.row_count >= result.rows_to_skip) { - // We skipped enough rows - return true; - } - return false; -} - -void SkipResult::InvalidState(SkipResult &result) { - // nop -} - -bool SkipResult::EmptyLine(SkipResult &result, const idx_t buffer_pos) { - if (result.state_machine.dialect_options.num_cols == 1) { - return AddRow(result, buffer_pos); - } - return false; -} - -SkipScanner::SkipScanner(shared_ptr buffer_manager, const shared_ptr &state_machine, - shared_ptr error_handler, idx_t rows_to_skip) - : BaseScanner(std::move(buffer_manager), state_machine, std::move(error_handler)), - result(states, *state_machine, rows_to_skip) { -} - -SkipResult &SkipScanner::ParseChunk() { - ParseChunkInternal(result); - return result; -} - -SkipResult &SkipScanner::GetResult() { - return result; -} - -void SkipScanner::Initialize() { - states.Initialize(); -} - -void SkipScanner::FinalizeChunkProcess() { - // We continue skipping until we skipped enough rows, or we have nothing else to read. - while (!FinishedFile() && result.row_count < result.rows_to_skip) { - cur_buffer_handle = buffer_manager->GetBuffer(++iterator.pos.buffer_idx); - if (cur_buffer_handle) { - iterator.pos.buffer_pos = 0; - buffer_handle_ptr = cur_buffer_handle->Ptr(); - Process(result); - } - } - // Skip Carriage Return - if (state_machine->options.dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON && - states.states[1] == CSVState::CARRIAGE_RETURN) { - iterator.pos.buffer_pos++; - } - iterator.done = FinishedFile(); -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp deleted file mode 100644 index c1ab79dfe..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/scanner/string_value_scanner.cpp +++ /dev/null @@ -1,1839 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/string_value_scanner.hpp" - -#include "duckdb/common/operator/decimal_cast_operators.hpp" -#include "duckdb/common/operator/double_cast_operator.hpp" -#include "duckdb/common/operator/integer_cast_operator.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_casting.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp" -#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/main/client_data.hpp" -#include "utf8proc_wrapper.hpp" - -#include - -namespace duckdb { - -constexpr idx_t StringValueScanner::LINE_FINDER_ID; - -StringValueResult::StringValueResult(CSVStates &states, CSVStateMachine &state_machine, - const shared_ptr &buffer_handle, Allocator &buffer_allocator, - idx_t result_size_p, idx_t buffer_position, CSVErrorHandler &error_hander_p, - CSVIterator 
&iterator_p, bool store_line_size_p, - shared_ptr csv_file_scan_p, idx_t &lines_read_p, bool sniffing_p, - string path_p, idx_t scan_id) - : ScannerResult(states, state_machine, result_size_p), - number_of_columns(NumericCast(state_machine.dialect_options.num_cols)), - null_padding(state_machine.options.null_padding), ignore_errors(state_machine.options.ignore_errors.GetValue()), - extra_delimiter_bytes(state_machine.dialect_options.state_machine_options.delimiter.GetValue().size() - 1), - error_handler(error_hander_p), iterator(iterator_p), store_line_size(store_line_size_p), - csv_file_scan(std::move(csv_file_scan_p)), lines_read(lines_read_p), - current_errors(scan_id, state_machine.options.IgnoreErrors()), sniffing(sniffing_p), path(std::move(path_p)) { - // Vector information - D_ASSERT(number_of_columns > 0); - if (!buffer_handle) { - // It Was Over Before It Even Began - D_ASSERT(iterator.done); - return; - } - buffer_handles[buffer_handle->buffer_idx] = buffer_handle; - // Buffer Information - buffer_ptr = buffer_handle->Ptr(); - buffer_size = buffer_handle->actual_size; - last_position = {buffer_handle->buffer_idx, buffer_position, buffer_size}; - requested_size = buffer_handle->requested_size; - // Current Result information - current_line_position.begin = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, buffer_handle->actual_size}; - current_line_position.end = current_line_position.begin; - // Fill out Parse Types - vector logical_types; - parse_types = make_unsafe_uniq_array(number_of_columns); - LogicalType varchar_type = LogicalType::VARCHAR; - if (!csv_file_scan) { - for (idx_t i = 0; i < number_of_columns; i++) { - parse_types[i] = ParseTypeInfo(varchar_type, true); - logical_types.emplace_back(LogicalType::VARCHAR); - string name = "Column_" + to_string(i); - names.emplace_back(name); - } - } else { - if (csv_file_scan->file_types.size() > number_of_columns) { - throw InvalidInputException( - "Mismatch between the number of columns (%d) in the CSV file and what is expected in the scanner (%d).", - number_of_columns, csv_file_scan->file_types.size()); - } - bool icu_loaded = csv_file_scan->buffer_manager->context.db->ExtensionIsLoaded("icu"); - for (idx_t i = 0; i < csv_file_scan->file_types.size(); i++) { - auto &type = csv_file_scan->file_types[i]; - if (type.IsJSONType()) { - type = LogicalType::VARCHAR; - } - if (StringValueScanner::CanDirectlyCast(type, icu_loaded)) { - parse_types[i] = ParseTypeInfo(type, true); - logical_types.emplace_back(type); - } else { - parse_types[i] = ParseTypeInfo(varchar_type, type.id() == LogicalTypeId::VARCHAR || type.IsNested()); - logical_types.emplace_back(LogicalType::VARCHAR); - } - } - names = csv_file_scan->names; - if (!csv_file_scan->projected_columns.empty()) { - projecting_columns = false; - projected_columns = make_unsafe_uniq_array(number_of_columns); - for (idx_t col_idx = 0; col_idx < number_of_columns; col_idx++) { - if (csv_file_scan->projected_columns.find(col_idx) == csv_file_scan->projected_columns.end()) { - // Column is not projected - projecting_columns = true; - projected_columns[col_idx] = false; - } else { - projected_columns[col_idx] = true; - } - } - } - if (!projecting_columns) { - for (idx_t j = logical_types.size(); j < number_of_columns; j++) { - // This can happen if we have sneaky null columns at the end that we wish to ignore - parse_types[j] = ParseTypeInfo(varchar_type, true); - logical_types.emplace_back(LogicalType::VARCHAR); - } - } - } - - // Initialize Parse Chunk - 
parse_chunk.Initialize(buffer_allocator, logical_types, result_size); - for (auto &col : parse_chunk.data) { - vector_ptr.push_back(FlatVector::GetData(col)); - validity_mask.push_back(&FlatVector::Validity(col)); - } - - // Setup the NullStr information - null_str_count = state_machine.options.null_str.size(); - null_str_ptr = make_unsafe_uniq_array_uninitialized(null_str_count); - null_str_size = make_unsafe_uniq_array_uninitialized(null_str_count); - for (idx_t i = 0; i < null_str_count; i++) { - null_str_ptr[i] = state_machine.options.null_str[i].c_str(); - null_str_size[i] = state_machine.options.null_str[i].size(); - } - date_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::DATE).GetValue(); - timestamp_format = state_machine.options.dialect_options.date_format.at(LogicalTypeId::TIMESTAMP).GetValue(); - decimal_separator = state_machine.options.decimal_separator[0]; - - if (iterator.first_one) { - lines_read += - state_machine.dialect_options.skip_rows.GetValue() + state_machine.dialect_options.header.GetValue(); - if (lines_read == 0) { - SkipBOM(); - } - } -} - -StringValueResult::~StringValueResult() { - // We have to insert the lines read by this scanner - error_handler.Insert(iterator.GetBoundaryIdx(), lines_read); - if (!iterator.done) { - // Some operators, like Limit, might cause a future error to incorrectly report the wrong error line - // Better to print nothing to print something wrong - error_handler.DontPrintErrorLine(); - } -} - -inline bool IsValueNull(const char *null_str_ptr, const char *value_ptr, const idx_t size) { - for (idx_t i = 0; i < size; i++) { - if (null_str_ptr[i] != value_ptr[i]) { - return false; - } - } - return true; -} - -bool StringValueResult::HandleTooManyColumnsError(const char *value_ptr, const idx_t size) { - if (cur_col_id >= number_of_columns) { - bool error = true; - if (cur_col_id == number_of_columns && ((quoted && state_machine.options.allow_quoted_nulls) || !quoted)) { - // we make an exception if the first over-value is null - bool is_value_null = false; - for (idx_t i = 0; i < null_str_count; i++) { - is_value_null = is_value_null || IsValueNull(null_str_ptr[i], value_ptr, size); - } - error = !is_value_null; - } - if (error) { - // We error pointing to the current value error. 
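Values are tested against the configured null strings byte by byte, with the lengths tracked separately, since the underlying buffers are not null-terminated. A small self-contained sketch of that comparison (hypothetical names; the real IsValueNull leaves the size check to its caller):

#include <cassert>
#include <cstddef>

// Byte-wise compare against one candidate null representation whose length is
// tracked out of band.
static bool MatchesNullStr(const char *null_str, size_t null_size, const char *value, size_t value_size) {
    if (null_size != value_size) {
        return false;
    }
    for (size_t i = 0; i < value_size; i++) {
        if (null_str[i] != value[i]) {
            return false;
        }
    }
    return true;
}

int main() {
    // Hypothetical configuration: both "" and "NULL" mean SQL NULL.
    const char *null_strs[] = {"", "NULL"};
    const size_t null_sizes[] = {0, 4};
    const char *value = "NULL";
    bool is_null = false;
    for (size_t i = 0; i < 2; i++) {
        is_null = is_null || MatchesNullStr(null_strs[i], null_sizes[i], value, 4);
    }
    assert(is_null);
}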
- current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); - cur_col_id++; - } - // We had an error - return true; - } - return false; -} - -void StringValueResult::SetComment(StringValueResult &result, idx_t buffer_pos) { - if (!result.comment) { - result.position_before_comment = buffer_pos; - result.comment = true; - } -} - -bool StringValueResult::UnsetComment(StringValueResult &result, idx_t buffer_pos) { - bool done = false; - if (result.last_position.buffer_pos < result.position_before_comment) { - bool all_empty = true; - for (idx_t i = result.last_position.buffer_pos; i < result.position_before_comment; i++) { - if (result.buffer_ptr[i] != ' ') { - all_empty = false; - break; - } - } - if (!all_empty) { - done = AddRow(result, result.position_before_comment); - } - } else { - if (result.cur_col_id != 0) { - done = AddRow(result, result.position_before_comment); - } - } - if (result.number_of_rows == 0) { - result.first_line_is_comment = true; - } - result.comment = false; - if (result.state_machine.dialect_options.state_machine_options.new_line.GetValue() != NewLineIdentifier::CARRY_ON) { - result.last_position.buffer_pos = buffer_pos + 1; - } else { - result.last_position.buffer_pos = buffer_pos + 2; - } - result.cur_col_id = 0; - result.chunk_col_id = 0; - return done; -} - -static void SanitizeError(string &value) { - std::vector char_array(value.begin(), value.end()); - char_array.push_back('\0'); // Null-terminate the character array - Utf8Proc::MakeValid(&char_array[0], char_array.size()); - value = {char_array.begin(), char_array.end() - 1}; -} - -void StringValueResult::AddValueToVector(const char *value_ptr, const idx_t size, bool allocate) { - if (HandleTooManyColumnsError(value_ptr, size)) { - return; - } - if (cur_col_id >= number_of_columns) { - bool error = true; - if (cur_col_id == number_of_columns && ((quoted && state_machine.options.allow_quoted_nulls) || !quoted)) { - // we make an exception if the first over-value is null - bool is_value_null = false; - for (idx_t i = 0; i < null_str_count; i++) { - is_value_null = is_value_null || IsValueNull(null_str_ptr[i], value_ptr, size); - } - error = !is_value_null; - } - if (error) { - // We error pointing to the current value error. - current_errors.Insert(TOO_MANY_COLUMNS, cur_col_id, chunk_col_id, last_position); - cur_col_id++; - } - return; - } - - if (projecting_columns) { - if (!projected_columns[cur_col_id]) { - cur_col_id++; - return; - } - } - - if (((quoted && state_machine.options.allow_quoted_nulls) || !quoted)) { - // Check for the occurrence of escaped null string like \N only if RFC 4180 conformance is disabled - const bool check_unquoted_escaped_null = - state_machine.state_machine_options.rfc_4180.GetValue() == false && escaped && !quoted && size == 1; - for (idx_t i = 0; i < null_str_count; i++) { - bool is_null = false; - if (null_str_size[i] == 2 && null_str_ptr[i][0] == state_machine.state_machine_options.escape.GetValue()) { - is_null = check_unquoted_escaped_null && null_str_ptr[i][1] == value_ptr[0]; - } else if (size == null_str_size[i] && !check_unquoted_escaped_null) { - is_null = IsValueNull(null_str_ptr[i], value_ptr, size); - } - if (is_null) { - bool empty = false; - if (chunk_col_id < state_machine.options.force_not_null.size()) { - empty = state_machine.options.force_not_null[chunk_col_id]; - } - if (empty) { - if (parse_types[chunk_col_id].type_id != LogicalTypeId::VARCHAR) { - // If it is not a varchar, empty values are not accepted, we must error. 
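The branch above implements the force_not_null contract: a value that matched a null string becomes an empty string for a forced-not-null VARCHAR column, raises a cast error for a forced-not-null column of any other type, and otherwise becomes NULL. A compact sketch of that decision, with hypothetical names:

#include <cstdio>

enum class Action { SET_EMPTY_STRING, CAST_ERROR, SET_NULL };

// Mirrors the null-handling branch: forced-not-null VARCHAR columns get '',
// forced-not-null non-VARCHAR columns error, everything else gets its
// validity bit cleared.
static Action OnNullValue(bool force_not_null, bool is_varchar) {
    if (force_not_null) {
        return is_varchar ? Action::SET_EMPTY_STRING : Action::CAST_ERROR;
    }
    return Action::SET_NULL;
}

int main() {
    std::printf("%d %d %d\n",
                static_cast<int>(OnNullValue(true, true)),   // forced VARCHAR -> ''
                static_cast<int>(OnNullValue(true, false)),  // forced INTEGER -> cast error
                static_cast<int>(OnNullValue(false, true))); // regular column -> NULL
}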
- current_errors.Insert(CAST_ERROR, cur_col_id, chunk_col_id, last_position); - } else { - static_cast(vector_ptr[chunk_col_id])[number_of_rows] = string_t(); - } - } else { - if (chunk_col_id == number_of_columns) { - // We check for a weird case, where we ignore an extra value, if it is a null value - return; - } - validity_mask[chunk_col_id]->SetInvalid(static_cast(number_of_rows)); - } - cur_col_id++; - chunk_col_id++; - return; - } - } - } - bool success = true; - switch (parse_types[chunk_col_id].type_id) { - case LogicalTypeId::BOOLEAN: - success = - TryCastStringBool(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::TINYINT: - success = TrySimpleIntegerCast(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - false); - break; - case LogicalTypeId::SMALLINT: - success = TrySimpleIntegerCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::INTEGER: - success = TrySimpleIntegerCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::BIGINT: - success = TrySimpleIntegerCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::UTINYINT: - success = TrySimpleIntegerCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::USMALLINT: - success = TrySimpleIntegerCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::UINTEGER: - success = TrySimpleIntegerCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::UBIGINT: - success = TrySimpleIntegerCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - case LogicalTypeId::DOUBLE: - success = - TryDoubleCast(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - false, state_machine.options.decimal_separator[0]); - break; - case LogicalTypeId::FLOAT: - success = TryDoubleCast(value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - false, state_machine.options.decimal_separator[0]); - break; - case LogicalTypeId::DATE: { - if (!date_format.Empty()) { - success = date_format.TryParseDate(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows]); - } else { - idx_t pos; - bool special; - success = Date::TryConvertDate(value_ptr, size, pos, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], special, - false) == DateCastResult::SUCCESS; - } - break; - } - case LogicalTypeId::TIME: { - idx_t pos; - success = Time::TryConvertTime(value_ptr, size, pos, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], false); - break; - } - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: { - if (!timestamp_format.Empty()) { - success = timestamp_format.TryParseTimestamp( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows]); - } else { - success = Timestamp::TryConvertTimestamp( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows]) == - TimestampCastResult::SUCCESS; - } - break; - } - case LogicalTypeId::DECIMAL: { - if (decimal_separator == ',') { - switch (parse_types[chunk_col_id].internal_type) { - case PhysicalType::INT16: - success = TryDecimalStringCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - 
parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT32: - success = TryDecimalStringCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT64: - success = TryDecimalStringCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT128: - success = TryDecimalStringCast( - value_ptr, size, static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - default: - throw InternalException("Invalid Physical Type for Decimal Value. Physical Type: " + - TypeIdToString(parse_types[chunk_col_id].internal_type)); - } - - } else if (decimal_separator == '.') { - switch (parse_types[chunk_col_id].internal_type) { - case PhysicalType::INT16: - success = TryDecimalStringCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT32: - success = TryDecimalStringCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT64: - success = TryDecimalStringCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - case PhysicalType::INT128: - success = TryDecimalStringCast(value_ptr, size, - static_cast(vector_ptr[chunk_col_id])[number_of_rows], - parse_types[chunk_col_id].width, parse_types[chunk_col_id].scale); - break; - default: - throw InternalException("Invalid Physical Type for Decimal Value. Physical Type: " + - TypeIdToString(parse_types[chunk_col_id].internal_type)); - } - } else { - throw InvalidInputException("Decimals can only have ',' and '.' as decimal separators"); - } - break; - } - default: { - // By default, we add a string - // We only evaluate if a string is utf8 valid, if it's actually a varchar - if (parse_types[chunk_col_id].validate_utf8 && - !Utf8Proc::IsValid(value_ptr, UnsafeNumericCast(size))) { - bool force_error = !state_machine.options.ignore_errors.GetValue() && sniffing; - // Invalid unicode, we must error - if (force_error) { - HandleUnicodeError(cur_col_id, last_position); - } - // If we got here, we are ignoring errors, hence we must ignore this line. - current_errors.Insert(INVALID_UNICODE, cur_col_id, chunk_col_id, last_position); - break; - } - if (allocate) { - // If it's a value produced over multiple buffers, we must allocate - static_cast(vector_ptr[chunk_col_id])[number_of_rows] = StringVector::AddStringOrBlob( - parse_chunk.data[chunk_col_id], string_t(value_ptr, UnsafeNumericCast(size))); - } else { - static_cast(vector_ptr[chunk_col_id])[number_of_rows] = - string_t(value_ptr, UnsafeNumericCast(size)); - } - break; - } - } - if (!success) { - current_errors.Insert(CAST_ERROR, cur_col_id, chunk_col_id, last_position); - if (!state_machine.options.IgnoreErrors()) { - // We have to write the cast error message. 
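The four-way dispatch above exists because DECIMAL(width, scale) values are stored in the narrowest integer that fits the width. A sketch of the usual width-to-storage rule; the exact cutoffs here are an assumption based on DuckDB's standard decimal layout:

#include <cassert>
#include <cstdint>
#include <stdexcept>

// Assumed mapping: width <= 4 -> int16, <= 9 -> int32, <= 18 -> int64,
// <= 38 -> int128 (hugeint).
static int DecimalStorageBytes(uint8_t width) {
    if (width <= 4) {
        return 2;
    }
    if (width <= 9) {
        return 4;
    }
    if (width <= 18) {
        return 8;
    }
    if (width <= 38) {
        return 16;
    }
    throw std::invalid_argument("DECIMAL width out of range");
}

int main() {
    assert(DecimalStorageBytes(4) == 2);   // DECIMAL(4,1) fits in int16
    assert(DecimalStorageBytes(18) == 8);  // DECIMAL(18,3) fits in int64
    assert(DecimalStorageBytes(38) == 16); // DECIMAL(38,10) needs int128
}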
- std::ostringstream error; - // Casting Error Message - error << "Could not convert string \"" << std::string(value_ptr, size) << "\" to \'" - << LogicalTypeIdToString(parse_types[chunk_col_id].type_id) << "\'"; - auto error_string = error.str(); - SanitizeError(error_string); - - current_errors.ModifyErrorMessageOfLastError(error_string); - } - } - cur_col_id++; - chunk_col_id++; -} - -DataChunk &StringValueResult::ToChunk() { - if (number_of_rows < 0) { - throw InternalException("CSVScanner: ToChunk() function. Has a negative number of rows, this indicates an " - "issue with the error handler."); - } - parse_chunk.SetCardinality(static_cast(number_of_rows)); - return parse_chunk; -} - -void StringValueResult::Reset() { - if (number_of_rows == 0) { - return; - } - number_of_rows = 0; - cur_col_id = 0; - chunk_col_id = 0; - for (auto &v : validity_mask) { - v->SetAllValid(result_size); - } - // We keep a reference to the buffer from our current iteration if it already exists - shared_ptr cur_buffer; - if (buffer_handles.find(iterator.GetBufferIdx()) != buffer_handles.end()) { - cur_buffer = buffer_handles[iterator.GetBufferIdx()]; - } - buffer_handles.clear(); - if (cur_buffer) { - buffer_handles[cur_buffer->buffer_idx] = cur_buffer; - } - current_errors.Reset(); - borked_rows.clear(); -} - -void StringValueResult::AddQuotedValue(StringValueResult &result, const idx_t buffer_pos) { - if (!result.unquoted) { - result.current_errors.Insert(UNTERMINATED_QUOTES, result.cur_col_id, result.chunk_col_id, result.last_position); - } - AddPossiblyEscapedValue(result, buffer_pos, result.buffer_ptr + result.quoted_position + 1, - buffer_pos - result.quoted_position - 2, buffer_pos < result.last_position.buffer_pos + 2); - result.quoted = false; -} - -void StringValueResult::AddPossiblyEscapedValue(StringValueResult &result, const idx_t buffer_pos, - const char *value_ptr, const idx_t length, const bool empty) { - if (result.escaped) { - if (result.projecting_columns) { - if (!result.projected_columns[result.cur_col_id]) { - result.cur_col_id++; - result.escaped = false; - return; - } - } - if (!result.HandleTooManyColumnsError(value_ptr, length)) { - // If it's an escaped value we have to remove all the escapes, this is not really great - // If we are going to escape, this vector must be a varchar vector - if (result.parse_chunk.data[result.chunk_col_id].GetType() != LogicalType::VARCHAR) { - result.current_errors.Insert(CAST_ERROR, result.cur_col_id, result.chunk_col_id, result.last_position); - if (!result.state_machine.options.IgnoreErrors()) { - // We have to write the cast error message. 
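Below, RemoveEscape strips the escape characters out of an escaped value before it is added to the vector. A minimal self-contained sketch of what such a pass has to do, assuming a single-byte escape character (the real routine also knows the quote character and writes into a target Vector):

#include <cassert>
#include <string>

// Drop each escape byte and keep the byte it protects.
static std::string RemoveEscapes(const std::string &input, char escape) {
    std::string result;
    result.reserve(input.size());
    for (size_t i = 0; i < input.size(); i++) {
        if (input[i] == escape && i + 1 < input.size()) {
            i++; // skip the escape, keep the escaped byte
        }
        result += input[i];
    }
    return result;
}

int main() {
    assert(RemoveEscapes("a\\\"b", '\\') == "a\"b"); // escaped quote
    assert(RemoveEscapes("a\\\\b", '\\') == "a\\b"); // escaped escape
}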
- std::ostringstream error; - // Casting Error Message - - error << "Could not convert string \"" << std::string(value_ptr, length) << "\" to \'" - << LogicalTypeIdToString(result.parse_types[result.chunk_col_id].type_id) << "\'"; - auto error_string = error.str(); - SanitizeError(error_string); - result.current_errors.ModifyErrorMessageOfLastError(error_string); - } - result.cur_col_id++; - result.chunk_col_id++; - } else { - auto value = StringValueScanner::RemoveEscape( - value_ptr, length, result.state_machine.dialect_options.state_machine_options.escape.GetValue(), - result.state_machine.dialect_options.state_machine_options.quote.GetValue(), - result.parse_chunk.data[result.chunk_col_id]); - result.AddValueToVector(value.GetData(), value.GetSize()); - } - } - } else { - if (empty) { - // empty value - auto value = string_t(); - result.AddValueToVector(value.GetData(), value.GetSize()); - } else { - result.AddValueToVector(value_ptr, length); - } - } - result.escaped = false; -} - -inline idx_t StringValueResult::HandleMultiDelimiter(const idx_t buffer_pos) const { - idx_t size = buffer_pos - last_position.buffer_pos - extra_delimiter_bytes; - if (buffer_pos < last_position.buffer_pos + extra_delimiter_bytes) { - // If this is a scenario where the value is null, that is fine (e.g., delim = '||' and line is: A||) - if (buffer_pos == last_position.buffer_pos) { - size = 0; - } else { - // Otherwise something went wrong. - throw InternalException( - "Value size is lower than the number of extra delimiter bytes in the HandleMultiDelimiter(). " - "buffer_pos = %d, last_position.buffer_pos = %d, extra_delimiter_bytes = %d", - buffer_pos, last_position.buffer_pos, extra_delimiter_bytes); - } - } - return size; -} - -void StringValueResult::AddValue(StringValueResult &result, const idx_t buffer_pos) { - if (result.last_position.buffer_pos > buffer_pos) { - return; - } - if (result.quoted) { - AddQuotedValue(result, buffer_pos - result.extra_delimiter_bytes); - } else if (result.escaped) { - AddPossiblyEscapedValue(result, buffer_pos, result.buffer_ptr + result.last_position.buffer_pos, - buffer_pos - result.last_position.buffer_pos, false); - } else { - result.AddValueToVector(result.buffer_ptr + result.last_position.buffer_pos, - result.HandleMultiDelimiter(buffer_pos)); - } - result.last_position.buffer_pos = buffer_pos + 1; -} - -void StringValueResult::HandleUnicodeError(idx_t col_idx, LinePosition &error_position) { - - bool first_nl; - auto borked_line = current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); - LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); - if (current_line_position.begin == error_position) { - auto csv_error = CSVError::InvalidUTF8(state_machine.options, col_idx, lines_per_batch, borked_line, - current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - error_position.GetGlobalPosition(requested_size, first_nl), path); - error_handler.Error(csv_error, true); - } else { - auto csv_error = CSVError::InvalidUTF8(state_machine.options, col_idx, lines_per_batch, borked_line, - current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - error_position.GetGlobalPosition(requested_size), path); - error_handler.Error(csv_error, true); - } -} - -bool LineError::HandleErrors(StringValueResult &result) { - bool skip_sniffing = false; - for (auto &cur_error : current_errors) { - if (cur_error.type == CSVErrorType::INVALID_UNICODE) { - skip_sniffing = true; - } - } - skip_sniffing = 
result.sniffing && skip_sniffing; - - if ((ignore_errors || skip_sniffing) && is_error_in_line && !result.figure_out_new_line) { - result.RemoveLastLine(); - Reset(); - return true; - } - // Reconstruct CSV Line - for (auto &cur_error : current_errors) { - LinesPerBoundary lines_per_batch(result.iterator.GetBoundaryIdx(), result.lines_read); - bool first_nl = false; - auto borked_line = result.current_line_position.ReconstructCurrentLine(first_nl, result.buffer_handles, - result.PrintErrorLine()); - CSVError csv_error; - auto col_idx = cur_error.col_idx; - auto &line_pos = cur_error.error_position; - - switch (cur_error.type) { - case TOO_MANY_COLUMNS: - case TOO_FEW_COLUMNS: - if (result.current_line_position.begin == line_pos) { - csv_error = CSVError::IncorrectColumnAmountError( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); - } else { - csv_error = CSVError::IncorrectColumnAmountError( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.path); - } - break; - case INVALID_UNICODE: { - if (result.current_line_position.begin == line_pos) { - csv_error = CSVError::InvalidUTF8( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); - } else { - csv_error = CSVError::InvalidUTF8( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.path); - } - break; - } - case UNTERMINATED_QUOTES: - if (result.current_line_position.begin == line_pos) { - csv_error = CSVError::UnterminatedQuotesError( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); - } else { - csv_error = CSVError::UnterminatedQuotesError( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.path); - } - break; - case CAST_ERROR: - if (result.current_line_position.begin == line_pos) { - csv_error = CSVError::CastError( - result.state_machine.options, result.names[cur_error.col_idx], cur_error.error_message, - cur_error.col_idx, borked_line, lines_per_batch, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl), - result.parse_types[cur_error.chunk_idx].type_id, result.path); - } else { - csv_error = CSVError::CastError( - result.state_machine.options, result.names[cur_error.col_idx], cur_error.error_message, - cur_error.col_idx, borked_line, lines_per_batch, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.parse_types[cur_error.chunk_idx].type_id, - result.path); - } - break; - case MAXIMUM_LINE_SIZE: - 
csv_error = CSVError::LineSizeError( - result.state_machine.options, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), result.path); - break; - case INVALID_STATE: - if (result.current_line_position.begin == line_pos) { - csv_error = CSVError::InvalidState( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size, first_nl), result.path); - } else { - csv_error = CSVError::InvalidState( - result.state_machine.options, col_idx, lines_per_batch, borked_line, - result.current_line_position.begin.GetGlobalPosition(result.requested_size, first_nl), - line_pos.GetGlobalPosition(result.requested_size), result.path); - } - break; - default: - throw InvalidInputException("CSV Error not allowed when inserting row"); - } - result.error_handler.Error(csv_error); - } - if (is_error_in_line && scan_id != StringValueScanner::LINE_FINDER_ID) { - if (result.sniffing) { - // If we are sniffing, we just remove the line - result.RemoveLastLine(); - } else { - // Otherwise, we add it to the borked rows to remove it later and just clean up the column variables. - result.borked_rows.insert(static_cast(result.number_of_rows)); - result.cur_col_id = 0; - result.chunk_col_id = 0; - } - Reset(); - return true; - } - return false; -} - -void StringValueResult::QuotedNewLine(StringValueResult &result) { - result.quoted_new_line = true; -} - -void StringValueResult::NullPaddingQuotedNewlineCheck() const { - // We do some checks for null_padding correctness - if (state_machine.options.null_padding && iterator.IsBoundarySet() && quoted_new_line) { - // If null_padding is set, we found a quoted new line, and we are scanning the file in parallel, we error. - LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); - auto csv_error = CSVError::NullPaddingFail(state_machine.options, lines_per_batch, path); - error_handler.Error(csv_error); - } -} - -//! 
Reconstructs the current line to be used in error messages -string FullLinePosition::ReconstructCurrentLine(bool &first_char_nl, - unordered_map> &buffer_handles, - bool reconstruct_line) const { - if (!reconstruct_line) { - return {}; - } - string result; - if (end.buffer_idx == begin.buffer_idx) { - if (buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { - throw InternalException("CSV Buffer is not available to reconstruct CSV Line, please open an issue with " - "your query and dataset."); - } - auto buffer = buffer_handles[begin.buffer_idx]->Ptr(); - first_char_nl = buffer[begin.buffer_pos] == '\n' || buffer[begin.buffer_pos] == '\r'; - for (idx_t i = begin.buffer_pos + first_char_nl; i < end.buffer_pos; i++) { - result += buffer[i]; - } - } else { - if (buffer_handles.find(begin.buffer_idx) == buffer_handles.end() || - buffer_handles.find(end.buffer_idx) == buffer_handles.end()) { - throw InternalException("CSV Buffer is not available to reconstruct CSV Line, please open an issue with " - "your query and dataset."); - } - auto first_buffer = buffer_handles[begin.buffer_idx]->Ptr(); - auto first_buffer_size = buffer_handles[begin.buffer_idx]->actual_size; - auto second_buffer = buffer_handles[end.buffer_idx]->Ptr(); - first_char_nl = first_buffer[begin.buffer_pos] == '\n' || first_buffer[begin.buffer_pos] == '\r'; - for (idx_t i = begin.buffer_pos + first_char_nl; i < first_buffer_size; i++) { - result += first_buffer[i]; - } - for (idx_t i = 0; i < end.buffer_pos; i++) { - result += second_buffer[i]; - } - } - // sanitize borked line - SanitizeError(result); - return result; -} - -bool StringValueResult::AddRowInternal() { - LinePosition current_line_start = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, buffer_size}; - idx_t current_line_size = current_line_start - current_line_position.end; - if (store_line_size) { - error_handler.NewMaxLineSize(current_line_size); - } - current_line_position.begin = current_line_position.end; - current_line_position.end = current_line_start; - if (current_line_size > state_machine.options.maximum_line_size.GetValue()) { - current_errors.Insert(MAXIMUM_LINE_SIZE, 1, chunk_col_id, last_position, current_line_size); - } - if (!state_machine.options.null_padding) { - for (idx_t col_idx = cur_col_id; col_idx < number_of_columns; col_idx++) { - current_errors.Insert(TOO_FEW_COLUMNS, col_idx - 1, chunk_col_id, last_position); - } - } - - if (current_errors.HandleErrors(*this)) { - line_positions_per_row[static_cast(number_of_rows)] = current_line_position; - number_of_rows++; - if (static_cast(number_of_rows) >= result_size) { - // We have a full chunk - return true; - } - return false; - } - NullPaddingQuotedNewlineCheck(); - quoted_new_line = false; - // We need to check if we are getting the correct number of columns here. - // If columns are correct, we add it, and that's it. 
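- // In effect this is a three-way decision: a row with exactly number_of_columns values is added as-is; - // a short row is either null-padded (when null_padding is set) or reported as an incorrect column amount; - // rows with too many values were already flagged as TOO_MANY_COLUMNS when their values were added.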
- if (cur_col_id != number_of_columns) { - // We have too few columns: - if (null_padding) { - while (cur_col_id < number_of_columns) { - bool empty = false; - if (cur_col_id < state_machine.options.force_not_null.size()) { - empty = state_machine.options.force_not_null[cur_col_id]; - } - if (projecting_columns) { - if (!projected_columns[cur_col_id]) { - cur_col_id++; - continue; - } - } - if (empty) { - static_cast(vector_ptr[chunk_col_id])[number_of_rows] = string_t(); - } else { - validity_mask[chunk_col_id]->SetInvalid(static_cast(number_of_rows)); - } - cur_col_id++; - chunk_col_id++; - } - } else { - // If we are not null-padding, this is an error - if (!state_machine.options.IgnoreErrors()) { - bool first_nl = false; - auto borked_line = - current_line_position.ReconstructCurrentLine(first_nl, buffer_handles, PrintErrorLine()); - LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), lines_read); - if (current_line_position.begin == last_position) { - auto csv_error = CSVError::IncorrectColumnAmountError( - state_machine.options, cur_col_id - 1, lines_per_batch, borked_line, - current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - last_position.GetGlobalPosition(requested_size, first_nl), path); - error_handler.Error(csv_error); - } else { - auto csv_error = CSVError::IncorrectColumnAmountError( - state_machine.options, cur_col_id - 1, lines_per_batch, borked_line, - current_line_position.begin.GetGlobalPosition(requested_size, first_nl), - last_position.GetGlobalPosition(requested_size), path); - error_handler.Error(csv_error); - } - } - // If we reach this point, we are ignoring errors, so we delete this line - RemoveLastLine(); - } - } - line_positions_per_row[static_cast(number_of_rows)] = current_line_position; - cur_col_id = 0; - chunk_col_id = 0; - number_of_rows++; - if (static_cast(number_of_rows) >= result_size) { - // We have a full chunk - return true; - } - return false; -} - -bool StringValueResult::AddRow(StringValueResult &result, const idx_t buffer_pos) { - if (result.last_position.buffer_pos <= buffer_pos) { - // We add the value - if (result.quoted) { - AddQuotedValue(result, buffer_pos); - } else { - char *value_ptr = result.buffer_ptr + result.last_position.buffer_pos; - idx_t size = buffer_pos - result.last_position.buffer_pos; - if (result.escaped) { - AddPossiblyEscapedValue(result, buffer_pos, value_ptr, size, size == 0); - } else { - result.AddValueToVector(value_ptr, size); - } - } - if (result.state_machine.dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON) { - if (result.states.states[1] == CSVState::RECORD_SEPARATOR) { - // Even though the dialect is marked as carry-on (\r\n), this row ends in a bare \n, so we only skip one character - result.last_position.buffer_pos = buffer_pos + 1; - } else { - result.last_position.buffer_pos = buffer_pos + 2; - } - } else { - result.last_position.buffer_pos = buffer_pos + 1; - } - } - - // We add the row - return result.AddRowInternal(); -} - -void StringValueResult::InvalidState(StringValueResult &result) { - if (result.quoted) { - result.current_errors.Insert(UNTERMINATED_QUOTES, result.cur_col_id, result.chunk_col_id, result.last_position); - } else { - result.current_errors.Insert(INVALID_STATE, result.cur_col_id, result.chunk_col_id, result.last_position); - } -} - -bool StringValueResult::EmptyLine(StringValueResult &result, const idx_t buffer_pos) { - // We care about empty lines if this is a single column csv file - result.last_position = {result.iterator.pos.buffer_idx, result.iterator.pos.buffer_pos + 1, 
result.buffer_size}; - if (result.states.IsCarriageReturn() && - result.state_machine.dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON) { - result.last_position.buffer_pos++; - } - if (result.number_of_columns == 1) { - for (idx_t i = 0; i < result.null_str_count; i++) { - if (result.null_str_size[i] == 0) { - bool empty = false; - if (!result.state_machine.options.force_not_null.empty()) { - empty = result.state_machine.options.force_not_null[0]; - } - if (empty) { - static_cast(result.vector_ptr[0])[result.number_of_rows] = string_t(); - } else { - result.validity_mask[0]->SetInvalid(static_cast(result.number_of_rows)); - } - result.number_of_rows++; - } - } - if (static_cast(result.number_of_rows) >= result.result_size) { - // We have a full chunk - return true; - } - } - return false; -} - -StringValueScanner::StringValueScanner(idx_t scanner_idx_p, const shared_ptr &buffer_manager, - const shared_ptr &state_machine, - const shared_ptr &error_handler, - const shared_ptr &csv_file_scan, bool sniffing, - const CSVIterator &boundary, idx_t result_size) - : BaseScanner(buffer_manager, state_machine, error_handler, sniffing, csv_file_scan, boundary), - scanner_idx(scanner_idx_p), - result(states, *state_machine, cur_buffer_handle, BufferAllocator::Get(buffer_manager->context), result_size, - iterator.pos.buffer_pos, *error_handler, iterator, - buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, - buffer_manager->GetFilePath(), scanner_idx_p) { - iterator.buffer_size = state_machine->options.buffer_size_option.GetValue(); -} - -StringValueScanner::StringValueScanner(const shared_ptr &buffer_manager, - const shared_ptr &state_machine, - const shared_ptr &error_handler, idx_t result_size, - const CSVIterator &boundary) - : BaseScanner(buffer_manager, state_machine, error_handler, false, nullptr, boundary), scanner_idx(0), - result(states, *state_machine, cur_buffer_handle, Allocator::DefaultAllocator(), result_size, - iterator.pos.buffer_pos, *error_handler, iterator, - buffer_manager->context.client_data->debug_set_max_line_length, csv_file_scan, lines_read, sniffing, - buffer_manager->GetFilePath(), 0) { - iterator.buffer_size = state_machine->options.buffer_size_option.GetValue(); -} - -unique_ptr StringValueScanner::GetCSVScanner(ClientContext &context, CSVReaderOptions &options) { - auto state_machine = make_shared_ptr(options, options.dialect_options.state_machine_options, - CSVStateMachineCache::Get(context)); - - state_machine->dialect_options.num_cols = options.dialect_options.num_cols; - state_machine->dialect_options.header = options.dialect_options.header; - auto buffer_manager = make_shared_ptr(context, options, options.file_path, 0); - idx_t rows_to_skip = state_machine->options.GetSkipRows() + state_machine->options.GetHeader(); - rows_to_skip = std::max(rows_to_skip, state_machine->dialect_options.rows_until_header + - state_machine->dialect_options.header.GetValue()); - auto it = BaseScanner::SkipCSVRows(buffer_manager, state_machine, rows_to_skip); - auto scanner = make_uniq(buffer_manager, state_machine, make_shared_ptr(), - STANDARD_VECTOR_SIZE, it); - scanner->csv_file_scan = make_shared_ptr(context, options.file_path, options); - scanner->csv_file_scan->InitializeProjection(); - return scanner; -} - -bool StringValueScanner::FinishedIterator() const { - return iterator.done; -} - -StringValueResult &StringValueScanner::ParseChunk() { - result.Reset(); - ParseChunkInternal(result); - return 
result; -} - -void StringValueScanner::Flush(DataChunk &insert_chunk) { - auto &process_result = ParseChunk(); - // First, get the parsed chunk - auto &parse_chunk = process_result.ToChunk(); - // We have to check if we ran into an error - error_handler->ErrorIfNeeded(); - if (parse_chunk.size() == 0) { - return; - } - // convert the columns in the parsed chunk to the types of the table - insert_chunk.SetCardinality(parse_chunk); - - // We keep track of the borked lines, in case we are ignoring errors - D_ASSERT(csv_file_scan); - - auto &reader_data = csv_file_scan->reader_data; - // Now do the casting - for (idx_t c = 0; c < reader_data.column_ids.size(); c++) { - idx_t col_idx = c; - idx_t result_idx = reader_data.column_mapping[c]; - if (!csv_file_scan->projection_ids.empty()) { - result_idx = reader_data.column_mapping[csv_file_scan->projection_ids[c].second]; - } - if (col_idx >= parse_chunk.ColumnCount()) { - throw InvalidInputException("Mismatch between the schema of different files"); - } - auto &parse_vector = parse_chunk.data[col_idx]; - auto &result_vector = insert_chunk.data[result_idx]; - auto &type = result_vector.GetType(); - auto &parse_type = parse_vector.GetType(); - if (!type.IsJSONType() && - (type == LogicalType::VARCHAR || (type != LogicalType::VARCHAR && parse_type != LogicalType::VARCHAR))) { - // reinterpret rather than reference - result_vector.Reinterpret(parse_vector); - } else { - string error_message; - idx_t line_error = 0; - if (VectorOperations::TryCast(buffer_manager->context, parse_vector, result_vector, parse_chunk.size(), - &error_message, false, true)) { - continue; - } - // An error happened; to propagate it we need to figure out the exact line where the casting failed. - UnifiedVectorFormat inserted_column_data; - result_vector.ToUnifiedFormat(parse_chunk.size(), inserted_column_data); - UnifiedVectorFormat parse_column_data; - parse_vector.ToUnifiedFormat(parse_chunk.size(), parse_column_data); - - for (; line_error < parse_chunk.size(); line_error++) { - if (!inserted_column_data.validity.RowIsValid(line_error) && - parse_column_data.validity.RowIsValid(line_error)) { - break; - } - } - { - - if (state_machine->options.ignore_errors.GetValue()) { - vector row; - for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { - row.push_back(parse_chunk.GetValue(col, line_error)); - } - } - if (!state_machine->options.IgnoreErrors()) { - LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), - lines_read - parse_chunk.size() + line_error); - bool first_nl; - auto borked_line = result.line_positions_per_row[line_error].ReconstructCurrentLine( - first_nl, result.buffer_handles, result.PrintErrorLine()); - std::ostringstream error; - error << "Could not convert string \"" << parse_vector.GetValue(line_error) << "\" to \'" - << type.ToString() << "\'"; - string error_msg = error.str(); - SanitizeError(error_msg); - auto csv_error = CSVError::CastError( - state_machine->options, csv_file_scan->names[col_idx], error_msg, col_idx, borked_line, - lines_per_batch, - result.line_positions_per_row[line_error].begin.GetGlobalPosition(result.result_size, first_nl), - optional_idx::Invalid(), result_vector.GetType().id(), result.path); - error_handler->Error(csv_error); - } - } - result.borked_rows.insert(line_error++); - D_ASSERT(state_machine->options.ignore_errors.GetValue()); - // We are ignoring errors. 
We must continue but ignoring borked-rows - for (; line_error < parse_chunk.size(); line_error++) { - if (!inserted_column_data.validity.RowIsValid(line_error) && - parse_column_data.validity.RowIsValid(line_error)) { - result.borked_rows.insert(line_error); - vector row; - for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { - row.push_back(parse_chunk.GetValue(col, line_error)); - } - if (!state_machine->options.IgnoreErrors()) { - LinesPerBoundary lines_per_batch(iterator.GetBoundaryIdx(), - lines_read - parse_chunk.size() + line_error); - bool first_nl; - auto borked_line = result.line_positions_per_row[line_error].ReconstructCurrentLine( - first_nl, result.buffer_handles, result.PrintErrorLine()); - std::ostringstream error; - // Casting Error Message - error << "Could not convert string \"" << parse_vector.GetValue(line_error) << "\" to \'" - << LogicalTypeIdToString(type.id()) << "\'"; - string error_msg = error.str(); - SanitizeError(error_msg); - auto csv_error = - CSVError::CastError(state_machine->options, csv_file_scan->names[col_idx], error_msg, - col_idx, borked_line, lines_per_batch, - result.line_positions_per_row[line_error].begin.GetGlobalPosition( - result.result_size, first_nl), - optional_idx::Invalid(), result_vector.GetType().id(), result.path); - error_handler->Error(csv_error); - } - } - } - } - } - if (!result.borked_rows.empty()) { - // We must remove the borked lines from our chunk - SelectionVector successful_rows(parse_chunk.size()); - idx_t sel_idx = 0; - for (idx_t row_idx = 0; row_idx < parse_chunk.size(); row_idx++) { - if (result.borked_rows.find(row_idx) == result.borked_rows.end()) { - successful_rows.set_index(sel_idx++, row_idx); - } - } - // Now we slice the result - insert_chunk.Slice(successful_rows, sel_idx); - } -} - -void StringValueScanner::Initialize() { - states.Initialize(); - - if (result.result_size != 1 && !(sniffing && state_machine->options.null_padding && - !state_machine->options.dialect_options.skip_rows.IsSetByUser())) { - SetStart(); - } else { - start_pos = iterator.GetGlobalCurrentPos(); - } - - result.last_position = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, cur_buffer_handle->actual_size}; - result.current_line_position.begin = result.last_position; - result.current_line_position.end = result.current_line_position.begin; -} - -void StringValueScanner::ProcessExtraRow() { - result.NullPaddingQuotedNewlineCheck(); - idx_t to_pos = cur_buffer_handle->actual_size; - while (iterator.pos.buffer_pos < to_pos) { - state_machine->Transition(states, buffer_handle_ptr[iterator.pos.buffer_pos]); - switch (states.states[1]) { - case CSVState::INVALID: - result.InvalidState(result); - iterator.pos.buffer_pos++; - return; - case CSVState::RECORD_SEPARATOR: - if (states.states[0] == CSVState::RECORD_SEPARATOR) { - result.EmptyLine(result, iterator.pos.buffer_pos); - iterator.pos.buffer_pos++; - lines_read++; - return; - } else if (states.states[0] != CSVState::CARRIAGE_RETURN) { - if (result.IsCommentSet(result)) { - result.UnsetComment(result, iterator.pos.buffer_pos); - } else { - result.AddRow(result, iterator.pos.buffer_pos); - } - iterator.pos.buffer_pos++; - lines_read++; - return; - } - lines_read++; - iterator.pos.buffer_pos++; - break; - case CSVState::CARRIAGE_RETURN: - if (states.states[0] != CSVState::RECORD_SEPARATOR) { - if (result.IsCommentSet(result)) { - result.UnsetComment(result, iterator.pos.buffer_pos); - } else { - result.AddRow(result, iterator.pos.buffer_pos); - } - iterator.pos.buffer_pos++; - 
lines_read++; - return; - } else { - result.EmptyLine(result, iterator.pos.buffer_pos); - iterator.pos.buffer_pos++; - lines_read++; - return; - } - break; - case CSVState::DELIMITER: - result.AddValue(result, iterator.pos.buffer_pos); - iterator.pos.buffer_pos++; - break; - case CSVState::QUOTED: - if (states.states[0] == CSVState::UNQUOTED) { - result.SetEscaped(result); - } - result.SetQuoted(result, iterator.pos.buffer_pos); - iterator.pos.buffer_pos++; - while (state_machine->transition_array - .skip_quoted[static_cast(buffer_handle_ptr[iterator.pos.buffer_pos])] && - iterator.pos.buffer_pos < to_pos - 1) { - iterator.pos.buffer_pos++; - } - break; - case CSVState::ESCAPE: - case CSVState::UNQUOTED_ESCAPE: - case CSVState::ESCAPED_RETURN: - result.SetEscaped(result); - iterator.pos.buffer_pos++; - break; - case CSVState::STANDARD: - iterator.pos.buffer_pos++; - while (state_machine->transition_array - .skip_standard[static_cast(buffer_handle_ptr[iterator.pos.buffer_pos])] && - iterator.pos.buffer_pos < to_pos - 1) { - iterator.pos.buffer_pos++; - } - break; - case CSVState::UNQUOTED: { - result.SetUnquoted(result); - iterator.pos.buffer_pos++; - break; - } - case CSVState::COMMENT: - result.SetComment(result, iterator.pos.buffer_pos); - iterator.pos.buffer_pos++; - while (state_machine->transition_array - .skip_comment[static_cast(buffer_handle_ptr[iterator.pos.buffer_pos])] && - iterator.pos.buffer_pos < to_pos - 1) { - iterator.pos.buffer_pos++; - } - break; - case CSVState::QUOTED_NEW_LINE: - result.quoted_new_line = true; - result.NullPaddingQuotedNewlineCheck(); - iterator.pos.buffer_pos++; - break; - default: - iterator.pos.buffer_pos++; - break; - } - } -} - -string_t StringValueScanner::RemoveEscape(const char *str_ptr, idx_t end, char escape, char quote, Vector &vector) { - // Figure out the exact size - idx_t str_pos = 0; - bool just_escaped = false; - for (idx_t cur_pos = 0; cur_pos < end; cur_pos++) { - if (str_ptr[cur_pos] == escape && !just_escaped) { - just_escaped = true; - } else if (str_ptr[cur_pos] == quote) { - if (just_escaped) { - str_pos++; - } - just_escaped = false; - } else { - just_escaped = false; - str_pos++; - } - } - - auto removed_escapes = StringVector::EmptyString(vector, str_pos); - auto removed_escapes_ptr = removed_escapes.GetDataWriteable(); - // Allocate string and copy it - str_pos = 0; - just_escaped = false; - for (idx_t cur_pos = 0; cur_pos < end; cur_pos++) { - const char c = str_ptr[cur_pos]; - if (c == escape && !just_escaped) { - just_escaped = true; - } else if (str_ptr[cur_pos] == quote) { - if (just_escaped) { - removed_escapes_ptr[str_pos++] = c; - } - just_escaped = false; - } else { - just_escaped = false; - removed_escapes_ptr[str_pos++] = c; - } - } - removed_escapes.Finalize(); - return removed_escapes; -} - -void StringValueScanner::ProcessOverBufferValue() { - // Process first string - if (result.last_position.buffer_pos != previous_buffer_handle->actual_size) { - states.Initialize(); - } - - string over_buffer_string; - auto previous_buffer = previous_buffer_handle->Ptr(); - idx_t j = 0; - result.quoted = false; - for (idx_t i = result.last_position.buffer_pos; i < previous_buffer_handle->actual_size; i++) { - state_machine->Transition(states, previous_buffer[i]); - if (states.EmptyLine() || states.IsCurrentNewRow()) { - continue; - } - if (states.NewRow() || states.NewValue()) { - break; - } else { - if (!result.comment) { - over_buffer_string += previous_buffer[i]; - } - } - if (states.IsQuoted()) { - 
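- // j is an offset inside over_buffer_string rather than a buffer position, so the quoted position - // recorded here is relative to the reconstructed over-buffer value.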
result.SetQuoted(result, j); - } - if (states.IsUnquoted()) { - result.SetUnquoted(result); - } - if (states.IsEscaped() && result.state_machine.dialect_options.state_machine_options.escape != '\0') { - result.escaped = true; - } - if (states.IsComment()) { - result.comment = true; - } - if (states.IsInvalid()) { - result.InvalidState(result); - } - j++; - } - if (over_buffer_string.empty() && - state_machine->dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON) { - if (buffer_handle_ptr[iterator.pos.buffer_pos] == '\n') { - iterator.pos.buffer_pos++; - } - } - // Second buffer - for (; iterator.pos.buffer_pos < cur_buffer_handle->actual_size; iterator.pos.buffer_pos++) { - state_machine->Transition(states, buffer_handle_ptr[iterator.pos.buffer_pos]); - if (states.EmptyLine()) { - if (state_machine->dialect_options.num_cols == 1) { - break; - } else { - continue; - } - } - if (states.NewRow() || states.NewValue()) { - break; - } else { - if (!result.comment && !states.IsComment()) { - over_buffer_string += buffer_handle_ptr[iterator.pos.buffer_pos]; - } - } - if (states.IsQuoted()) { - result.SetQuoted(result, j); - } - if (states.IsComment()) { - result.comment = true; - } - if (states.IsEscaped() && result.state_machine.dialect_options.state_machine_options.escape != '\0') { - result.escaped = true; - } - if (states.IsInvalid()) { - result.InvalidState(result); - } - j++; - } - bool skip_value = false; - if (result.projecting_columns) { - if (!result.projected_columns[result.cur_col_id] && result.cur_col_id != result.number_of_columns) { - result.cur_col_id++; - skip_value = true; - } - } - if (!skip_value) { - string_t value; - if (result.quoted) { - value = string_t(over_buffer_string.c_str() + result.quoted_position, - UnsafeNumericCast(over_buffer_string.size() - 1 - result.quoted_position)); - if (result.escaped) { - if (!result.HandleTooManyColumnsError(over_buffer_string.c_str(), over_buffer_string.size())) { - const auto str_ptr = over_buffer_string.c_str() + result.quoted_position; - value = RemoveEscape(str_ptr, over_buffer_string.size() - 2, - state_machine->dialect_options.state_machine_options.escape.GetValue(), - state_machine->dialect_options.state_machine_options.quote.GetValue(), - result.parse_chunk.data[result.chunk_col_id]); - } - } - } else { - value = string_t(over_buffer_string.c_str(), UnsafeNumericCast(over_buffer_string.size())); - if (result.escaped) { - if (!result.HandleTooManyColumnsError(over_buffer_string.c_str(), over_buffer_string.size())) { - value = RemoveEscape(over_buffer_string.c_str(), over_buffer_string.size(), - state_machine->dialect_options.state_machine_options.escape.GetValue(), - state_machine->dialect_options.state_machine_options.quote.GetValue(), - result.parse_chunk.data[result.chunk_col_id]); - } - } - } - if (states.EmptyLine() && state_machine->dialect_options.num_cols == 1) { - result.EmptyLine(result, iterator.pos.buffer_pos); - } else if (!states.IsNotSet() && (!result.comment || !value.Empty())) { - idx_t value_size = value.GetSize(); - if (states.IsDelimiter()) { - idx_t extra_delimiter_bytes = - result.state_machine.dialect_options.state_machine_options.delimiter.GetValue().size() - 1; - if (extra_delimiter_bytes > value_size) { - throw InternalException( - "Value size is lower than the number of extra delimiter bytes in the ProcessOverBufferValue()"); - } - value_size -= extra_delimiter_bytes; - } - result.AddValueToVector(value.GetData(), value_size, true); - } - } else { - if (states.EmptyLine() && 
state_machine->dialect_options.num_cols == 1) { - result.EmptyLine(result, iterator.pos.buffer_pos); - } - } - - if (states.NewRow() && !states.IsNotSet()) { - if (result.IsCommentSet(result)) { - result.UnsetComment(result, iterator.pos.buffer_pos); - } else { - result.AddRowInternal(); - } - lines_read++; - } - - if (iterator.pos.buffer_pos >= cur_buffer_handle->actual_size && cur_buffer_handle->is_last_buffer) { - result.added_last_line = true; - } - if (states.IsCarriageReturn() && - state_machine->dialect_options.state_machine_options.new_line == NewLineIdentifier::CARRY_ON) { - result.last_position = {iterator.pos.buffer_idx, ++iterator.pos.buffer_pos + 1, result.buffer_size}; - } else { - result.last_position = {iterator.pos.buffer_idx, ++iterator.pos.buffer_pos, result.buffer_size}; - } - // Be sure to reset the quoted and escaped variables - result.quoted = false; - result.escaped = false; -} - -bool StringValueScanner::MoveToNextBuffer() { - if (iterator.pos.buffer_pos >= cur_buffer_handle->actual_size) { - previous_buffer_handle = cur_buffer_handle; - cur_buffer_handle = buffer_manager->GetBuffer(++iterator.pos.buffer_idx); - if (!cur_buffer_handle) { - iterator.pos.buffer_idx--; - buffer_handle_ptr = nullptr; - // We do not care if it's a quoted new line on the last row of our file. - result.quoted_new_line = false; - // This means we reached the end of the file, we must add a last line if there is any to be added - if (states.EmptyLine() || states.NewRow() || result.added_last_line || states.IsCurrentNewRow() || - states.IsNotSet()) { - if (result.cur_col_id == result.number_of_columns) { - result.number_of_rows++; - } - result.cur_col_id = 0; - result.chunk_col_id = 0; - return false; - } else if (states.NewValue()) { - // we add the value - result.AddValue(result, previous_buffer_handle->actual_size); - // And an extra empty value to represent what comes after the delimiter - if (result.IsCommentSet(result)) { - result.UnsetComment(result, iterator.pos.buffer_pos); - } else { - result.AddRow(result, previous_buffer_handle->actual_size); - } - lines_read++; - } else if (states.IsQuotedCurrent() && - state_machine->dialect_options.state_machine_options.rfc_4180.GetValue()) { - // Unterminated quote - LinePosition current_line_start = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, - result.buffer_size}; - result.current_line_position.begin = result.current_line_position.end; - result.current_line_position.end = current_line_start; - result.InvalidState(result); - } else { - if (result.IsCommentSet(result)) { - result.UnsetComment(result, iterator.pos.buffer_pos); - } else { - if (result.quoted && states.IsDelimiterBytes() && - state_machine->dialect_options.state_machine_options.rfc_4180.GetValue()) { - result.current_errors.Insert(UNTERMINATED_QUOTES, result.cur_col_id, result.chunk_col_id, - result.last_position); - } - result.AddRow(result, previous_buffer_handle->actual_size); - } - lines_read++; - } - return false; - } - result.buffer_handles[cur_buffer_handle->buffer_idx] = cur_buffer_handle; - - iterator.pos.buffer_pos = 0; - buffer_handle_ptr = cur_buffer_handle->Ptr(); - // Handle over-buffer value - ProcessOverBufferValue(); - result.buffer_ptr = buffer_handle_ptr; - result.buffer_size = cur_buffer_handle->actual_size; - return true; - } - return false; -} - -void StringValueResult::SkipBOM() const { - if (buffer_size >= 3 && buffer_ptr[0] == '\xEF' && buffer_ptr[1] == '\xBB' && buffer_ptr[2] == '\xBF' && - iterator.pos.buffer_pos == 0) { - iterator.pos.buffer_pos 
= 3; - } -} - -void StringValueResult::RemoveLastLine() { - // potentially de-nullify values - for (idx_t i = 0; i < chunk_col_id; i++) { - validity_mask[i]->SetValid(static_cast(number_of_rows)); - } - // reset column trackers - cur_col_id = 0; - chunk_col_id = 0; - // decrement row counter - number_of_rows--; -} -bool StringValueResult::PrintErrorLine() const { - // To print a line, the result size must be different than one (a result size of one means this is a SetStart() - // run trying to figure out newlines), and we must either not be ignoring errors or be storing them in a rejects table. - return result_size != 1 && - (state_machine.options.store_rejects.GetValue() || !state_machine.options.ignore_errors.GetValue()); -} - -bool StringValueScanner::FirstValueEndsOnQuote(CSVIterator iterator) const { - CSVStates current_state; - current_state.Initialize(CSVState::STANDARD); - const idx_t to_pos = iterator.GetEndPos(); - while (iterator.pos.buffer_pos < to_pos) { - state_machine->Transition(current_state, buffer_handle_ptr[iterator.pos.buffer_pos++]); - if ((current_state.IsState(CSVState::DELIMITER) || current_state.IsState(CSVState::CARRIAGE_RETURN) || - current_state.IsState(CSVState::RECORD_SEPARATOR))) { - return buffer_handle_ptr[iterator.pos.buffer_pos - 2] == - state_machine->dialect_options.state_machine_options.quote.GetValue(); - } - } - return false; -} - -bool StringValueScanner::SkipUntilState(CSVState initial_state, CSVState until_state, CSVIterator &current_iterator, - bool &quoted) const { - CSVStates current_state; - current_state.Initialize(initial_state); - bool first_column = true; - const idx_t to_pos = current_iterator.GetEndPos(); - while (current_iterator.pos.buffer_pos < to_pos) { - state_machine_strict->Transition(current_state, buffer_handle_ptr[current_iterator.pos.buffer_pos++]); - if (current_state.IsState(CSVState::STANDARD) || current_state.IsState(CSVState::STANDARD_NEWLINE)) { - while (current_iterator.pos.buffer_pos + 8 < to_pos) { - uint64_t value = Load( - reinterpret_cast(&buffer_handle_ptr[current_iterator.pos.buffer_pos])); - if (ContainsZeroByte((value ^ state_machine_strict->transition_array.delimiter) & - (value ^ state_machine_strict->transition_array.new_line) & - (value ^ state_machine_strict->transition_array.carriage_return) & - (value ^ state_machine_strict->transition_array.comment))) { - break; - } - current_iterator.pos.buffer_pos += 8; - } - while (state_machine_strict->transition_array - .skip_standard[static_cast(buffer_handle_ptr[current_iterator.pos.buffer_pos])] && - current_iterator.pos.buffer_pos < to_pos - 1) { - current_iterator.pos.buffer_pos++; - } - } - if (current_state.IsState(CSVState::QUOTED)) { - while (current_iterator.pos.buffer_pos + 8 < to_pos) { - uint64_t value = Load( - reinterpret_cast(&buffer_handle_ptr[current_iterator.pos.buffer_pos])); - if (ContainsZeroByte((value ^ state_machine_strict->transition_array.quote) & - (value ^ state_machine_strict->transition_array.escape))) { - break; - } - current_iterator.pos.buffer_pos += 8; - } - - while (state_machine_strict->transition_array - .skip_quoted[static_cast(buffer_handle_ptr[current_iterator.pos.buffer_pos])] && - current_iterator.pos.buffer_pos < to_pos - 1) { - current_iterator.pos.buffer_pos++; - } - } - if ((current_state.IsState(CSVState::DELIMITER) || current_state.IsState(CSVState::CARRIAGE_RETURN) || - current_state.IsState(CSVState::RECORD_SEPARATOR)) && - first_column) { - if (buffer_handle_ptr[current_iterator.pos.buffer_pos - 1] == - 
state_machine_strict->dialect_options.state_machine_options.quote.GetValue()) { - quoted = true; - } - } - if (current_state.WasState(CSVState::DELIMITER)) { - first_column = false; - } - if (current_state.IsState(until_state)) { - return true; - } - if (current_state.IsState(CSVState::INVALID)) { - return false; - } - } - return false; -} - -bool StringValueScanner::CanDirectlyCast(const LogicalType &type, bool icu_loaded) { - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DATE: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIME: - case LogicalTypeId::DECIMAL: - case LogicalType::VARCHAR: - case LogicalType::BOOLEAN: - return true; - case LogicalType::TIMESTAMP_TZ: - // We only try to do a direct cast of timestamp_tz if the ICU extension is not loaded; otherwise, it needs to go - // through string -> timestamp_tz casting - return !icu_loaded; - default: - return false; - } -} - -bool StringValueScanner::IsRowValid(CSVIterator &current_iterator) const { - if (iterator.pos.buffer_pos == cur_buffer_handle->actual_size) { - return false; - } - constexpr idx_t result_size = 1; - auto scan_finder = make_uniq(StringValueScanner::LINE_FINDER_ID, buffer_manager, - state_machine_strict, make_shared_ptr(), - csv_file_scan, false, current_iterator, result_size); - auto &tuples = scan_finder->ParseChunk(); - current_iterator.pos = scan_finder->GetIteratorPosition(); - bool has_error = false; - if (tuples.current_errors.HasError()) { - if (tuples.current_errors.Size() != 1 || !tuples.current_errors.HasErrorType(MAXIMUM_LINE_SIZE)) { - // We ignore maximum line size errors - has_error = true; - } - } - return (tuples.number_of_rows == 1 || tuples.first_line_is_comment) && !has_error && tuples.borked_rows.empty(); -} - -ValidRowInfo StringValueScanner::TryRow(CSVState state, idx_t start_pos, idx_t end_pos) const { - auto current_iterator = iterator; - current_iterator.SetStart(start_pos); - current_iterator.SetEnd(end_pos); - bool quoted = false; - if (SkipUntilState(state, CSVState::RECORD_SEPARATOR, current_iterator, quoted)) { - auto iterator_start = current_iterator; - idx_t current_pos = current_iterator.pos.buffer_pos; - current_iterator.SetEnd(iterator.GetEndPos()); - if (IsRowValid(current_iterator)) { - if (!quoted) { - quoted = FirstValueEndsOnQuote(iterator_start); - } - return {true, current_pos, current_iterator.pos.buffer_idx, current_iterator.pos.buffer_pos, quoted}; - } - } - return {false, current_iterator.pos.buffer_pos, current_iterator.pos.buffer_idx, current_iterator.pos.buffer_pos, - quoted}; -} - -void StringValueScanner::SetStart() { - start_pos = iterator.GetGlobalCurrentPos(); - if (iterator.first_one) { - if (result.store_line_size) { - result.error_handler.NewMaxLineSize(iterator.pos.buffer_pos); - } - return; - } - if (iterator.GetEndPos() > cur_buffer_handle->actual_size) { - iterator.SetEnd(cur_buffer_handle->actual_size); - } - if (!state_machine_strict) { - // We need to initialize our strict state machine - auto &state_machine_cache = CSVStateMachineCache::Get(buffer_manager->context); - auto state_options = state_machine->state_machine_options; - // To set the state machine to be strict we ensure that rfc_4180 is set to true - if 
(!state_options.rfc_4180.IsSetByUser()) { - state_options.rfc_4180 = true; - } - state_machine_strict = - make_shared_ptr(state_machine_cache.Get(state_options), state_machine->options); - } - // At this point we have 3 options: - // 1. We are at the start of a valid line - ValidRowInfo best_row = TryRow(CSVState::STANDARD_NEWLINE, iterator.pos.buffer_pos, iterator.GetEndPos()); - // 2. We are in the middle of a quoted value - if (state_machine->dialect_options.state_machine_options.quote.GetValue() != '\0') { - idx_t end_pos = iterator.GetEndPos(); - if (best_row.is_valid && best_row.end_buffer_idx == iterator.pos.buffer_idx) { - // If we got a valid row from the standard state, we limit our search up to that. - end_pos = best_row.end_pos; - } - auto quoted_row = TryRow(CSVState::QUOTED, iterator.pos.buffer_pos, end_pos); - if (quoted_row.is_valid && (!best_row.is_valid || best_row.last_state_quote)) { - best_row = quoted_row; - } - if (!best_row.is_valid && !quoted_row.is_valid && best_row.start_pos < quoted_row.start_pos) { - best_row = quoted_row; - } - } - // 3. We are in an escaped value - if (!best_row.is_valid && state_machine->dialect_options.state_machine_options.escape.GetValue() != '\0' && - state_machine->dialect_options.state_machine_options.quote.GetValue() != '\0') { - auto escape_row = TryRow(CSVState::ESCAPE, iterator.pos.buffer_pos, iterator.GetEndPos()); - if (escape_row.is_valid) { - best_row = escape_row; - } else { - if (best_row.start_pos < escape_row.start_pos) { - best_row = escape_row; - } - } - } - if (!best_row.is_valid) { - bool is_this_the_end = - best_row.start_pos >= cur_buffer_handle->actual_size && cur_buffer_handle->is_last_buffer; - if (is_this_the_end) { - iterator.pos.buffer_pos = best_row.start_pos; - iterator.done = true; - } else { - bool mock; - if (!SkipUntilState(CSVState::STANDARD_NEWLINE, CSVState::RECORD_SEPARATOR, iterator, mock)) { - iterator.CheckIfDone(); - } - } - } else { - iterator.pos.buffer_pos = best_row.start_pos; - bool is_this_the_end = - best_row.start_pos >= cur_buffer_handle->actual_size && cur_buffer_handle->is_last_buffer; - if (is_this_the_end) { - iterator.done = true; - } - } - - // 4. We have an error, if we have an error, we let life go on, the scanner will either ignore it - // or throw. - result.last_position = {iterator.pos.buffer_idx, iterator.pos.buffer_pos, result.buffer_size}; - start_pos = iterator.GetGlobalCurrentPos(); -} - -void StringValueScanner::FinalizeChunkProcess() { - if (static_cast(result.number_of_rows) >= result.result_size || iterator.done) { - // We are done - if (!sniffing) { - if (csv_file_scan) { - csv_file_scan->bytes_read += bytes_read; - bytes_read = 0; - } - } - return; - } - // If we are not done we have two options. - // 1) If a boundary is set. - if (iterator.IsBoundarySet()) { - bool found_error = false; - CSVErrorType type; - if (!result.current_errors.HasErrorType(UNTERMINATED_QUOTES) && - !result.current_errors.HasErrorType(INVALID_STATE)) { - iterator.done = true; - } else { - found_error = true; - if (result.current_errors.HasErrorType(UNTERMINATED_QUOTES)) { - type = UNTERMINATED_QUOTES; - } else { - type = INVALID_STATE; - } - } - // We read until the next line or until we have nothing else to read. 
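- // (With a boundary set, a scanner only runs past its assigned range to finish the row it has already started.)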
- // Move to next buffer - if (!cur_buffer_handle) { - return; - } - bool moved = MoveToNextBuffer(); - if (cur_buffer_handle) { - if (moved && result.cur_col_id > 0) { - ProcessExtraRow(); - } else if (!moved) { - ProcessExtraRow(); - } - if (cur_buffer_handle->is_last_buffer && iterator.pos.buffer_pos >= cur_buffer_handle->actual_size) { - MoveToNextBuffer(); - } - } else { - if (result.current_errors.HasErrorType(UNTERMINATED_QUOTES)) { - found_error = true; - type = UNTERMINATED_QUOTES; - } else if (result.current_errors.HasErrorType(INVALID_STATE)) { - found_error = true; - type = INVALID_STATE; - } - if (result.current_errors.HandleErrors(result)) { - result.number_of_rows++; - } - } - if (states.IsQuotedCurrent() && !found_error && - state_machine->dialect_options.state_machine_options.rfc_4180.GetValue()) { - // If we finish the execution of a buffer, and we end in a quoted state, it means we have unterminated - // quotes - result.current_errors.Insert(type, result.cur_col_id, result.chunk_col_id, result.last_position); - if (result.current_errors.HandleErrors(result)) { - result.number_of_rows++; - } - } - if (!iterator.done) { - if (iterator.pos.buffer_pos >= iterator.GetEndPos() || iterator.pos.buffer_idx > iterator.GetBufferIdx() || - FinishedFile()) { - iterator.done = true; - } - } - } else { - // 2) If a boundary is not set - // We read until the chunk is complete, or we have nothing else to read. - while (!FinishedFile() && static_cast(result.number_of_rows) < result.result_size) { - MoveToNextBuffer(); - if (static_cast(result.number_of_rows) >= result.result_size) { - return; - } - if (cur_buffer_handle) { - Process(result); - } - } - iterator.done = FinishedFile(); - if (result.null_padding && result.number_of_rows < STANDARD_VECTOR_SIZE && result.chunk_col_id > 0) { - while (result.chunk_col_id < result.parse_chunk.ColumnCount()) { - result.validity_mask[result.chunk_col_id++]->SetInvalid(static_cast(result.number_of_rows)); - result.cur_col_id++; - } - result.number_of_rows++; - } - } -} - -ValidatorLine StringValueScanner::GetValidationLine() { - return {start_pos, result.iterator.GetGlobalCurrentPos()}; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp deleted file mode 100644 index c8740ebde..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/csv_sniffer.cpp +++ /dev/null @@ -1,272 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/common/types/value.hpp" - -namespace duckdb { - -CSVSniffer::CSVSniffer(CSVReaderOptions &options_p, shared_ptr buffer_manager_p, - CSVStateMachineCache &state_machine_cache_p, bool default_null_to_varchar_p) - : state_machine_cache(state_machine_cache_p), options(options_p), buffer_manager(std::move(buffer_manager_p)), - lines_sniffed(0), default_null_to_varchar(default_null_to_varchar_p) { - // Initialize Format Candidates - for (const auto &format_template : format_template_candidates) { - auto &logical_type = format_template.first; - best_format_candidates[logical_type].clear(); - } - // Initialize max columns found to either 0 or however many were set - max_columns_found = set_columns.Size(); - error_handler = make_shared_ptr(options.ignore_errors.GetValue()); - detection_error_handler = make_shared_ptr(true); - if (options.columns_set) { - set_columns = SetColumns(&options.sql_type_list, &options.name_list); - } -} - -bool 
SetColumns::IsSet() const { - if (!types) { - return false; - } - return !types->empty(); -} - -idx_t SetColumns::Size() const { - if (!types) { - return 0; - } - return types->size(); -} - -template -void MatchAndReplace(CSVOption &original, CSVOption &sniffed, const string &name, string &error) { - if (original.IsSetByUser()) { - // We verify that the user input matches the sniffed value - if (original != sniffed) { - error += "CSV Sniffer: Sniffer detected value different than the user input for the " + name; - error += " options \n Set: " + original.FormatValue() + ", Sniffed: " + sniffed.FormatValue() + "\n"; - } - } else { - // We replace the value of original with the sniffed value - original.Set(sniffed.GetValue(), false); - } -} - -void MatchAndReplaceUserSetVariables(DialectOptions &original, DialectOptions &sniffed, string &error, bool found_date, - bool found_timestamp) { - MatchAndReplace(original.header, sniffed.header, "Header", error); - if (sniffed.state_machine_options.new_line.GetValue() != NewLineIdentifier::NOT_SET) { - // If the sniffed new line is not set (e.g., a single-line file), we don't try to replace and match. - MatchAndReplace(original.state_machine_options.new_line, sniffed.state_machine_options.new_line, "New Line", - error); - } - MatchAndReplace(original.skip_rows, sniffed.skip_rows, "Skip Rows", error); - MatchAndReplace(original.state_machine_options.delimiter, sniffed.state_machine_options.delimiter, "Delimiter", - error); - MatchAndReplace(original.state_machine_options.quote, sniffed.state_machine_options.quote, "Quote", error); - MatchAndReplace(original.state_machine_options.escape, sniffed.state_machine_options.escape, "Escape", error); - MatchAndReplace(original.state_machine_options.comment, sniffed.state_machine_options.comment, "Comment", error); - if (found_date) { - MatchAndReplace(original.date_format[LogicalTypeId::DATE], sniffed.date_format[LogicalTypeId::DATE], - "Date Format", error); - } - if (found_timestamp) { - MatchAndReplace(original.date_format[LogicalTypeId::TIMESTAMP], sniffed.date_format[LogicalTypeId::TIMESTAMP], - "Timestamp Format", error); - } -} -// Set the CSV Options in the reference -void CSVSniffer::SetResultOptions() const { - bool found_date = false; - bool found_timestamp = false; - for (auto &type : detected_types) { - if (type == LogicalType::DATE) { - found_date = true; - } else if (type == LogicalType::TIMESTAMP) { - found_timestamp = true; - } - } - MatchAndReplaceUserSetVariables(options.dialect_options, best_candidate->GetStateMachine().dialect_options, - options.sniffer_user_mismatch_error, found_date, found_timestamp); - options.dialect_options.num_cols = best_candidate->GetStateMachine().dialect_options.num_cols; - options.dialect_options.rows_until_header = best_candidate->GetStateMachine().dialect_options.rows_until_header; -} - -AdaptiveSnifferResult CSVSniffer::MinimalSniff() { - if (set_columns.IsSet()) { - // Nothing to see here - return AdaptiveSnifferResult(*set_columns.types, *set_columns.names, true); - } - // Return Types detected - vector return_types; - // Column Names detected - buffer_manager->sniffing = true; - constexpr idx_t result_size = STANDARD_VECTOR_SIZE; - - auto state_machine = - make_shared_ptr(options, options.dialect_options.state_machine_options, state_machine_cache); - ColumnCountScanner count_scanner(buffer_manager, state_machine, error_handler, result_size); - auto &sniffed_column_counts = count_scanner.ParseChunk(); - if (sniffed_column_counts.result_position == 0) { - // The file 
is empty, so we just return - return {{}, {}, false}; - } - - state_machine->dialect_options.num_cols = sniffed_column_counts[0].number_of_columns; - options.dialect_options.num_cols = sniffed_column_counts[0].number_of_columns; - - // First figure out the number of columns on this configuration - auto scanner = count_scanner.UpgradeToStringValueScanner(); - scanner->error_handler->SetIgnoreErrors(true); - // Parse chunk and read csv with info candidate - auto &data_chunk = scanner->ParseChunk().ToChunk(); - idx_t start_row = 0; - if (sniffed_column_counts.result_position == 2) { - // If equal to two, we will only use the second row for type checking - start_row = 1; - } - - // Gather Types - for (idx_t i = 0; i < state_machine->dialect_options.num_cols; i++) { - best_sql_types_candidates_per_column_idx[i] = state_machine->options.auto_type_candidates; - } - SniffTypes(data_chunk, *state_machine, best_sql_types_candidates_per_column_idx, start_row); - - // Possibly Gather Header - vector potential_header; - for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { - auto &cur_vector = data_chunk.data[col_idx]; - const auto vector_data = FlatVector::GetData(cur_vector); - auto &validity = FlatVector::Validity(cur_vector); - HeaderValue val; - if (validity.RowIsValid(0)) { - val = HeaderValue(vector_data[0]); - } - potential_header.emplace_back(val); - } - - auto names = DetectHeaderInternal(buffer_manager->context, potential_header, *state_machine, set_columns, - best_sql_types_candidates_per_column_idx, options, *error_handler); - - for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { - LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); - if (best_sql_types_candidates_per_column_idx[column_idx].size() == options.auto_type_candidates.size()) { - d_type = LogicalType::VARCHAR; - } - detected_types.push_back(d_type); - } - return {detected_types, names, sniffed_column_counts.result_position > 1}; -} - -SnifferResult CSVSniffer::AdaptiveSniff(const CSVSchema &file_schema) { - auto min_sniff_res = MinimalSniff(); - bool run_full = error_handler->AnyErrors() || detection_error_handler->AnyErrors(); - // Check if we are happy with the result or if we need to do more sniffing - if (!error_handler->AnyErrors() && !detection_error_handler->AnyErrors()) { - // If we got no errors, we also run full if schemas do not match. - if (!set_columns.IsSet() && !options.file_options.AnySet()) { - string error; - run_full = !file_schema.SchemasMatch(error, min_sniff_res, options.file_path, true); - } - } - if (run_full) { - // We run the full sniffer - auto full_sniffer = SniffCSV(); - if (!set_columns.IsSet() && !options.file_options.AnySet()) { - string error; - if (!file_schema.SchemasMatch(error, full_sniffer, options.file_path, false) && - !options.ignore_errors.GetValue()) { - throw InvalidInputException(error); - } - } - return full_sniffer; - } - return min_sniff_res.ToSnifferResult(); -} - -SnifferResult CSVSniffer::SniffCSV(const bool force_match) { - buffer_manager->sniffing = true; - // 1. Dialect Detection - DetectDialect(); - if (buffer_manager->file_handle->compression_type != FileCompressionType::UNCOMPRESSED && - buffer_manager->IsBlockUnloaded(0)) { - buffer_manager->ResetBufferManager(); - } - // 2. Type Detection - DetectTypes(); - // 3. Type Refinement - RefineTypes(); - // 4. Header Detection - DetectHeader(); - // 5. 
Type Replacement - ReplaceTypes(); - - // We reset the buffer for compressed files - // This is done because we can't easily seek on compressed files, if a buffer goes out of scope we must read from - // the start - if (buffer_manager->file_handle->compression_type != FileCompressionType::UNCOMPRESSED) { - buffer_manager->ResetBufferManager(); - } - buffer_manager->sniffing = false; - if (best_candidate->error_handler->AnyErrors() && !options.ignore_errors.GetValue()) { - best_candidate->error_handler->ErrorIfTypeExists(MAXIMUM_LINE_SIZE); - } - D_ASSERT(best_sql_types_candidates_per_column_idx.size() == names.size()); - // We are done, Set the CSV Options in the reference. Construct and return the result. - SetResultOptions(); - options.auto_detect = true; - // Check if everything matches - auto &error = options.sniffer_user_mismatch_error; - if (set_columns.IsSet()) { - bool match = true; - // Columns and their types were set, let's validate they match - if (options.dialect_options.header.GetValue()) { - // If the header exists it should match - string header_error = "The Column names set by the user do not match the ones found by the sniffer. \n"; - auto &set_names = *set_columns.names; - if (set_names.size() == names.size()) { - for (idx_t i = 0; i < set_columns.Size(); i++) { - if (set_names[i] != names[i]) { - header_error += "Column at position: " + to_string(i) + ", Set name: " + set_names[i] + - ", Sniffed Name: " + names[i] + "\n"; - match = false; - } - } - } - - if (!match) { - error += header_error; - } - } - match = true; - string type_error = "The Column types set by the user do not match the ones found by the sniffer. \n"; - auto &set_types = *set_columns.types; - if (detected_types.size() == set_columns.Size()) { - for (idx_t i = 0; i < set_columns.Size(); i++) { - if (set_types[i] != detected_types[i]) { - type_error += "Column at position: " + to_string(i) + " Set type: " + set_types[i].ToString() + - " Sniffed type: " + detected_types[i].ToString() + "\n"; - detected_types[i] = set_types[i]; - manually_set[i] = true; - match = false; - } - } - } - - if (!match) { - error += type_error; - } - - if (!error.empty() && force_match) { - throw InvalidInputException(error); - } - options.was_type_manually_set = manually_set; - } - if (!error.empty() && force_match) { - throw InvalidInputException(error); - } - options.was_type_manually_set = manually_set; - if (set_columns.IsSet()) { - return SnifferResult(*set_columns.types, *set_columns.names); - } - return SnifferResult(detected_types, names); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp deleted file mode 100644 index 5cced646d..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/dialect_detection.cpp +++ /dev/null @@ -1,597 +0,0 @@ -#include "duckdb/common/shared_ptr.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" - -namespace duckdb { - -constexpr idx_t CSVReaderOptions::sniff_size; - -bool IsQuoteDefault(char quote) { - if (quote == '\"' || quote == '\'' || quote == '\0') { - return true; - } - return false; -} - -vector DialectCandidates::GetDefaultDelimiter() { - return {",", "|", ";", "\t"}; -} - -vector> DialectCandidates::GetDefaultQuote() { - return {{'\0'}, {'\"', '\''}, {'\"'}}; -} - -vector 
DialectCandidates::GetDefaultQuoteRule() { - return {QuoteRule::NO_QUOTES, QuoteRule::QUOTES_OTHER, QuoteRule::QUOTES_RFC}; -} - -vector> DialectCandidates::GetDefaultEscape() { - return {{'\0'}, {'\\'}, {'\"', '\0', '\''}}; -} - -vector DialectCandidates::GetDefaultComment() { - return {'#', '\0'}; -} - -string DialectCandidates::Print() { - std::ostringstream search_space; - - search_space << "Delimiter Candidates: "; - for (idx_t i = 0; i < delim_candidates.size(); i++) { - search_space << "\'" << delim_candidates[i] << "\'"; - if (i < delim_candidates.size() - 1) { - search_space << ", "; - } - } - search_space << "\n"; - search_space << "Quote/Escape Candidates: "; - for (uint8_t i = 0; i < static_cast(quote_rule_candidates.size()); i++) { - auto quote_candidate = quote_candidates_map[i]; - auto escape_candidate = escape_candidates_map[i]; - for (idx_t j = 0; j < quote_candidate.size(); j++) { - for (idx_t k = 0; k < escape_candidate.size(); k++) { - search_space << "[\'"; - if (quote_candidate[j] == '\0') { - search_space << "(no quote)"; - } else { - search_space << quote_candidate[j]; - } - search_space << "\',\'"; - if (escape_candidate[k] == '\0') { - search_space << "(no escape)"; - } else { - search_space << escape_candidate[k]; - } - search_space << "\']"; - if (k < escape_candidate.size() - 1) { - search_space << ","; - } - } - if (j < quote_candidate.size() - 1) { - search_space << ","; - } - } - if (i < quote_rule_candidates.size() - 1) { - search_space << ","; - } - } - search_space << "\n"; - - search_space << "Comment Candidates: "; - for (idx_t i = 0; i < comment_candidates.size(); i++) { - search_space << "\'" << comment_candidates[i] << "\'"; - if (i < comment_candidates.size() - 1) { - search_space << ", "; - } - } - search_space << "\n"; - - return search_space.str(); -} - -DialectCandidates::DialectCandidates(const CSVStateMachineOptions &options) { - // assert that quotes, escapes, and rules have equal size - const auto default_quote = GetDefaultQuote(); - const auto default_escape = GetDefaultEscape(); - const auto default_quote_rule = GetDefaultQuoteRule(); - const auto default_delimiter = GetDefaultDelimiter(); - const auto default_comment = GetDefaultComment(); - - D_ASSERT(default_quote.size() == default_quote_rule.size() && default_quote_rule.size() == default_escape.size()); - // fill the escapes - for (idx_t i = 0; i < default_quote_rule.size(); i++) { - escape_candidates_map[static_cast(default_quote_rule[i])] = default_escape[i]; - } - - if (options.delimiter.IsSetByUser()) { - // user provided a delimiter: use that delimiter - delim_candidates = {options.delimiter.GetValue()}; - } else { - // no delimiter provided: try standard/common delimiters - delim_candidates = default_delimiter; - } - if (options.comment.IsSetByUser()) { - // user provided comment character: use that as a comment - comment_candidates = {options.comment.GetValue()}; - } else { - // no comment provided: try standard/common comments - comment_candidates = default_comment; - } - if (options.quote.IsSetByUser()) { - // user provided quote: use that quote rule - for (auto &quote_rule : default_quote_rule) { - quote_candidates_map[static_cast(quote_rule)] = {options.quote.GetValue()}; - } - // also add it as an escape rule - if (!IsQuoteDefault(options.quote.GetValue())) { - escape_candidates_map[static_cast(QuoteRule::QUOTES_RFC)].emplace_back(options.quote.GetValue()); - } - } else { - // no quote rule provided: use standard/common quotes - for (idx_t i = 0; i < default_quote_rule.size(); 
i++) { - quote_candidates_map[static_cast<uint8_t>(default_quote_rule[i])] = {default_quote[i]}; - } - } - if (options.escape.IsSetByUser()) { - // user provided escape: use that escape rule - if (options.escape == '\0') { - quote_rule_candidates = {QuoteRule::QUOTES_RFC}; - } else { - quote_rule_candidates = {QuoteRule::QUOTES_OTHER}; - } - escape_candidates_map[static_cast<uint8_t>(quote_rule_candidates[0])] = {options.escape.GetValue()}; - } else { - // no escape provided: try standard/common escapes - quote_rule_candidates = default_quote_rule; - } -} - -void CSVSniffer::GenerateStateMachineSearchSpace(vector<unique_ptr<ColumnCountScanner>> &column_count_scanners, - const DialectCandidates &dialect_candidates) { - // Generate state machines for all option combinations - NewLineIdentifier new_line_id; - if (options.dialect_options.state_machine_options.new_line.IsSetByUser()) { - new_line_id = options.dialect_options.state_machine_options.new_line.GetValue(); - } else { - new_line_id = DetectNewLineDelimiter(*buffer_manager); - } - // We only sniff RFC 4180 rules, unless manually set by user. - bool rfc_4180 = true; - if (options.dialect_options.state_machine_options.rfc_4180.IsSetByUser()) { - rfc_4180 = options.dialect_options.state_machine_options.rfc_4180.GetValue(); - } - CSVIterator first_iterator; - bool iterator_set = false; - for (const auto quote_rule : dialect_candidates.quote_rule_candidates) { - const auto &quote_candidates = dialect_candidates.quote_candidates_map.at(static_cast<uint8_t>(quote_rule)); - for (const auto &quote : quote_candidates) { - for (const auto &delimiter : dialect_candidates.delim_candidates) { - const auto &escape_candidates = - dialect_candidates.escape_candidates_map.at(static_cast<uint8_t>(quote_rule)); - for (const auto &escape : escape_candidates) { - for (const auto &comment : dialect_candidates.comment_candidates) { - D_ASSERT(buffer_manager); - CSVStateMachineOptions state_machine_options(delimiter, quote, escape, comment, new_line_id, - rfc_4180); - auto sniffing_state_machine = - make_shared_ptr<CSVStateMachine>(options, state_machine_options, state_machine_cache); - if (options.dialect_options.skip_rows.IsSetByUser()) { - if (!iterator_set) { - first_iterator = BaseScanner::SkipCSVRows(buffer_manager, sniffing_state_machine, - options.dialect_options.skip_rows.GetValue()); - iterator_set = true; - } - column_count_scanners.emplace_back(make_uniq<ColumnCountScanner>( - buffer_manager, std::move(sniffing_state_machine), detection_error_handler, - CSVReaderOptions::sniff_size, first_iterator)); - continue; - } - column_count_scanners.emplace_back( - make_uniq<ColumnCountScanner>(buffer_manager, std::move(sniffing_state_machine), - detection_error_handler, CSVReaderOptions::sniff_size)); - } - } - } - } -} - -// Returns true if a comment is acceptable -bool AreCommentsAcceptable(const ColumnCountResult &result, idx_t num_cols, bool comment_set_by_user) { - if (comment_set_by_user) { - return true; - } - // For a comment to be acceptable, we want at least a 3/5ths majority of unmatched columns - constexpr double min_majority = 0.6; - // Detected comments are all lines that started with a comment character. - double detected_comments = 0; - // Whether at least one comment is a full-line comment - bool has_full_line_comment = false;
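Editor's note: GenerateStateMachineSearchSpace above enumerates the full cartesian product of delimiter, quote, escape (grouped per quote rule), and comment candidates, and instantiates one state machine per combination. The sketch below reproduces just that enumeration with plain STL containers; the candidate sets mirror the defaults listed earlier in this file, while all other names are illustrative stand-ins rather than DuckDB's internal API.

```cpp
#include <iostream>
#include <string>
#include <vector>

int main() {
	// Candidate sets mirroring GetDefaultDelimiter/Quote/Escape/Comment above;
	// index i of quotes/escapes corresponds to quote rule i
	// (NO_QUOTES, QUOTES_OTHER, QUOTES_RFC).
	std::vector<std::string> delims = {",", "|", ";", "\t"};
	std::vector<std::vector<char>> quotes = {{'\0'}, {'"', '\''}, {'"'}};
	std::vector<std::vector<char>> escapes = {{'\0'}, {'\\'}, {'"', '\0', '\''}};
	std::vector<char> comments = {'#', '\0'};

	std::size_t combinations = 0;
	for (std::size_t rule = 0; rule < quotes.size(); rule++) {
		for (char quote : quotes[rule]) {
			for (const std::string &delim : delims) {
				for (char escape : escapes[rule]) {
					for (char comment : comments) {
						(void)quote, (void)delim, (void)escape, (void)comment;
						combinations++; // one state machine candidate per combination
					}
				}
			}
		}
	}
	// 4 delimiters * (1 + 2 + 3) quote/escape pairs * 2 comments = 48 candidates
	std::cout << "state machines to try: " << combinations << "\n";
	return 0;
}
```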
- // Valid comments are all lines where the number of columns does not fit our expected number of columns. - double valid_comments = 0; - for (idx_t i = 0; i < result.result_position; i++) { - if (result.column_counts[i].is_comment || result.column_counts[i].is_mid_comment) { - detected_comments++; - if (result.column_counts[i].number_of_columns != num_cols && result.column_counts[i].is_comment) { - has_full_line_comment = true; - valid_comments++; - } - if (result.column_counts[i].number_of_columns == num_cols && result.column_counts[i].is_mid_comment) { - valid_comments++; - } - } - } - // If we do not encounter at least one full-line comment, we do not consider this comment option. - if (valid_comments == 0 || !has_full_line_comment) { - // this is only valid if our comment character is \0 - if (result.state_machine.state_machine_options.comment.GetValue() == '\0') { - return true; - } - return false; - } - - return valid_comments / detected_comments >= min_majority; -} - -void CSVSniffer::AnalyzeDialectCandidate(unique_ptr<ColumnCountScanner> scanner, idx_t &rows_read, - idx_t &best_consistent_rows, idx_t &prev_padding_count, - idx_t &min_ignored_rows) { - // The sniffed_column_counts variable keeps track of the number of columns found for each row - auto &sniffed_column_counts = scanner->ParseChunk(); - idx_t dirty_notes = 0; - idx_t dirty_notes_minus_comments = 0; - if (sniffed_column_counts.error) { - // This candidate has an error (i.e., over maximum line size or never unquoting quoted values) - return; - } - idx_t consistent_rows = 0; - idx_t num_cols = sniffed_column_counts.result_position == 0 ? 1 : sniffed_column_counts[0].number_of_columns; - const bool ignore_errors = options.ignore_errors.GetValue(); - // If we are ignoring errors and not null_padding, we pick the most frequent number of columns as the right one - const bool use_most_frequent_columns = ignore_errors && !options.null_padding; - if (use_most_frequent_columns) { - num_cols = sniffed_column_counts.GetMostFrequentColumnCount(); - } - idx_t padding_count = 0; - idx_t comment_rows = 0; - idx_t ignored_rows = 0; - const bool allow_padding = options.null_padding; - bool first_valid = false; - if (sniffed_column_counts.result_position > rows_read) { - rows_read = sniffed_column_counts.result_position; - } - if (set_columns.IsCandidateUnacceptable(num_cols, options.null_padding, ignore_errors, - sniffed_column_counts[0].last_value_always_empty)) { - // Not acceptable - return; - } - idx_t header_idx = 0; - for (idx_t row = 0; row < sniffed_column_counts.result_position; row++) { - if (set_columns.IsCandidateUnacceptable(sniffed_column_counts[row].number_of_columns, options.null_padding, - ignore_errors, sniffed_column_counts[row].last_value_always_empty)) { - // Not acceptable - return; - } - if (sniffed_column_counts[row].is_comment) { - comment_rows++; - } else if (sniffed_column_counts[row].last_value_always_empty && - sniffed_column_counts[row].number_of_columns == - sniffed_column_counts[header_idx].number_of_columns + 1) { - // we allow for the first row to miss one column IF last_value_always_empty is true - // This is so we can sniff files that have an extra delimiter on the data part, - // e.g., C1|C2\n1|2|\n3|4|
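Editor's note: the acceptance rule in AreCommentsAcceptable above is easier to follow outside diff form. This is a simplified model under assumed semantics (the RowCount struct is hypothetical): at least one full-line comment must deviate from the expected column count, and at least 60% (min_majority) of detected comment lines must be consistent with actually being comments. The real function additionally accepts the no-comment case when the comment character is '\0'.

```cpp
#include <iostream>
#include <vector>

// Hypothetical per-row summary; the real sniffer keeps richer state.
struct RowCount {
	int columns;
	bool full_line_comment; // the line started with the comment character
	bool mid_comment;       // a comment appeared after some values
};

// Simplified model of AreCommentsAcceptable.
bool CommentsAcceptable(const std::vector<RowCount> &rows, int num_cols) {
	double detected = 0, valid = 0;
	bool has_full_line = false;
	for (const auto &row : rows) {
		if (!row.full_line_comment && !row.mid_comment) {
			continue; // not a comment line
		}
		detected++;
		if (row.full_line_comment && row.columns != num_cols) {
			has_full_line = true;
			valid++;
		} else if (row.mid_comment && row.columns == num_cols) {
			valid++;
		}
	}
	if (valid == 0 || !has_full_line) {
		// the real code still accepts this case when the comment char is '\0'
		return false;
	}
	return valid / detected >= 0.6;
}

int main() {
	// A 3-column file with one full-line comment and one mid-line comment.
	std::vector<RowCount> rows = {{3, false, false}, {1, true, false}, {3, false, true}};
	std::cout << std::boolalpha << CommentsAcceptable(rows, 3) << "\n"; // true
	return 0;
}
```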
- consistent_rows++; - } else if (num_cols < sniffed_column_counts[row].number_of_columns && - (!options.dialect_options.skip_rows.IsSetByUser() || comment_rows > 0) && - (!set_columns.IsSet() || options.null_padding) && (!first_valid || (!use_most_frequent_columns))) { - // all rows up to this point will need padding - if (!first_valid) { - first_valid = true; - sniffed_column_counts.state_machine.dialect_options.rows_until_header = row; - } - padding_count = 0; - // we use the maximum num_cols that we find - num_cols = sniffed_column_counts[row].number_of_columns; - dirty_notes = row; - dirty_notes_minus_comments = dirty_notes - comment_rows; - header_idx = row; - consistent_rows = 1; - } else if (sniffed_column_counts[row].number_of_columns == num_cols || (use_most_frequent_columns)) { - if (!first_valid) { - first_valid = true; - sniffed_column_counts.state_machine.dialect_options.rows_until_header = row; - dirty_notes = row; - } - if (sniffed_column_counts[row].number_of_columns != num_cols) { - ignored_rows++; - } - consistent_rows++; - } else if (num_cols >= sniffed_column_counts[row].number_of_columns) { - // we are missing some columns; we can parse this as long as we add padding - padding_count++; - } - } - - if (sniffed_column_counts.state_machine.options.dialect_options.skip_rows.IsSetByUser()) { - sniffed_column_counts.state_machine.dialect_options.rows_until_header += - sniffed_column_counts.state_machine.options.dialect_options.skip_rows.GetValue(); - } - // Calculate the total number of consistent rows after adding padding. - consistent_rows += padding_count; - - // Whether there are more values (rows) available that are consistent, exceeding the current best. - const bool more_values = consistent_rows > best_consistent_rows && num_cols >= max_columns_found; - - const bool more_columns = consistent_rows == best_consistent_rows && num_cols > max_columns_found; - - // If additional padding is required when compared to the previous padding count. - const bool require_more_padding = padding_count > prev_padding_count; - - // If less padding is now required when compared to the previous padding count. - const bool require_less_padding = padding_count < prev_padding_count; - - // If there was only a single column before, and the new number of columns exceeds that. - const bool single_column_before = max_columns_found < 2 && num_cols > max_columns_found * candidates.size(); - - // If the number of rows is consistent with the calculated value after accounting for skipped rows and the - // start row. - const bool rows_consistent = - consistent_rows + (dirty_notes_minus_comments - options.dialect_options.skip_rows.GetValue()) + comment_rows == - sniffed_column_counts.result_position - options.dialect_options.skip_rows.GetValue(); - // Whether there is more than one consistent row. - const bool more_than_one_row = consistent_rows > 1; - - // Whether there is more than one column. - const bool more_than_one_column = num_cols > 1; - - // If the start position is valid. - const bool start_good = !candidates.empty() && - dirty_notes <= candidates.front()->GetStateMachine().dialect_options.skip_rows.GetValue();
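Editor's note: a toy rendition of the row bookkeeping performed above, using invented names and plain ints. Rows matching the candidate's width count as consistent, and narrower rows can be rescued by null padding, which is why consistent_rows += padding_count appears in the real code.

```cpp
#include <iostream>
#include <vector>

int main() {
	// Sniffed widths for five rows of a hypothetical candidate; one row is short.
	std::vector<int> column_counts = {3, 3, 2, 3, 3};
	int num_cols = column_counts.front(); // expected width for this candidate
	int consistent_rows = 0, padding_count = 0;
	for (int cols : column_counts) {
		if (cols == num_cols) {
			consistent_rows++;
		} else if (cols < num_cols) {
			padding_count++; // parsable if the missing values are padded with NULLs
		}
	}
	consistent_rows += padding_count; // padded rows still count as consistent
	std::cout << "consistent=" << consistent_rows << " padding=" << padding_count << "\n";
	// prints consistent=5 padding=1; a rival needing less padding wins ties
	return 0;
}
```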
- // If padding happened but it is not allowed. - const bool invalid_padding = !allow_padding && padding_count > 0; - - const bool comments_are_acceptable = AreCommentsAcceptable( - sniffed_column_counts, num_cols, options.dialect_options.state_machine_options.comment.IsSetByUser()); - - const bool quoted = - scanner->ever_quoted && - sniffed_column_counts.state_machine.dialect_options.state_machine_options.quote.GetValue() != '\0'; - - // For our columns to match, we either don't have them manually set, or they match in value with the sniffed value - const bool columns_match_set = - num_cols == set_columns.Size() || - (num_cols == set_columns.Size() + 1 && sniffed_column_counts[0].last_value_always_empty) || - !set_columns.IsSet(); - - // If rows are consistent and no invalid padding happens, this is the best suitable candidate if one of the - // following is valid: - // - There's a single column before. - // - There are more values and no additional padding is required. - // - There's more than one column and less padding is required. - if (columns_match_set && (rows_consistent || (set_columns.IsSet() && ignore_errors)) && - (single_column_before || ((more_values || more_columns) && !require_more_padding) || - (more_than_one_column && require_less_padding) || quoted) && - !invalid_padding && comments_are_acceptable) { - if (!candidates.empty() && set_columns.IsSet() && max_columns_found == set_columns.Size() && - consistent_rows <= best_consistent_rows) { - // We already have a candidate that fits our requirements better - if (candidates.front()->ever_quoted || !scanner->ever_quoted) { - return; - } - } - auto &sniffing_state_machine = scanner->GetStateMachine(); - - if (!candidates.empty() && candidates.front()->ever_quoted) { - // Give preference to candidates that saw actual quoted values. - if (!scanner->ever_quoted) { - return; - } else { - // Give preference to the one that saw escaped values - if (!scanner->ever_escaped && candidates.front()->ever_escaped) { - return; - } - if (best_consistent_rows == consistent_rows && num_cols >= max_columns_found) { - // If neither has been escaped, this might get resolved later on.
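Editor's note: the nested conditions above encode a tie-break order that is easier to see in isolation. This hypothetical comparator captures only the quoting/escaping preferences (the padding, consistency, and column-count terms are omitted):

```cpp
#include <iostream>

// Simplified tie-break between two dialect candidates with the same column
// count and consistency: a dialect that actually saw quoted values beats one
// that never did, and among quoted dialects one that saw escapes wins.
struct Candidate {
	bool ever_quoted;
	bool ever_escaped;
};

bool PreferChallenger(const Candidate &incumbent, const Candidate &challenger) {
	if (incumbent.ever_quoted != challenger.ever_quoted) {
		return challenger.ever_quoted;
	}
	if (incumbent.ever_escaped != challenger.ever_escaped) {
		return challenger.ever_escaped;
	}
	return false; // otherwise keep the incumbent
}

int main() {
	Candidate incumbent {false, false}, challenger {true, false};
	std::cout << std::boolalpha << PreferChallenger(incumbent, challenger) << "\n"; // true
	return 0;
}
```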
- sniffing_state_machine.dialect_options.num_cols = num_cols; - candidates.emplace_back(std::move(scanner)); - max_columns_found = num_cols; - return; - } - } - } - if (max_columns_found == num_cols && ignored_rows > min_ignored_rows) { - return; - } - if (quoted && num_cols < max_columns_found) { - for (auto &candidate : candidates) { - if (candidate->ever_quoted) { - return; - } - } - } - best_consistent_rows = consistent_rows; - max_columns_found = num_cols; - prev_padding_count = padding_count; - min_ignored_rows = ignored_rows; - - if (options.dialect_options.skip_rows.IsSetByUser()) { - // If skip rows is set by the user and we found dirty notes, we only accept it if either null_padding or - // ignore_errors is set, or we have comments - if (dirty_notes != 0 && !options.null_padding && !options.ignore_errors.GetValue() && comment_rows == 0) { - return; - } - sniffing_state_machine.dialect_options.skip_rows = options.dialect_options.skip_rows.GetValue(); - } else if (!options.null_padding) { - sniffing_state_machine.dialect_options.skip_rows = dirty_notes; - } - - candidates.clear(); - sniffing_state_machine.dialect_options.num_cols = num_cols; - lines_sniffed = sniffed_column_counts.result_position; - candidates.emplace_back(std::move(scanner)); - return; - } - // If there's more than one row and column, the start is good, rows are consistent, - // no additional padding is required, there is no invalid padding, and there is not yet a candidate - // with the same quote, we add this state_machine as a suitable candidate. - if (columns_match_set && more_than_one_row && more_than_one_column && start_good && rows_consistent && - !require_more_padding && !invalid_padding && num_cols == max_columns_found && comments_are_acceptable) { - auto &sniffing_state_machine = scanner->GetStateMachine(); - - bool same_quote_is_candidate = false; - for (const auto &candidate : candidates) { - if (sniffing_state_machine.dialect_options.state_machine_options.quote == - candidate->GetStateMachine().dialect_options.state_machine_options.quote) { - same_quote_is_candidate = true; - } - } - if (!same_quote_is_candidate) { - if (options.dialect_options.skip_rows.IsSetByUser()) { - // If skip rows is set by the user and we found dirty notes, we only accept it if either null_padding or - // ignore_errors is set - if (dirty_notes != 0 && !options.null_padding && !options.ignore_errors.GetValue()) { - return; - } - sniffing_state_machine.dialect_options.skip_rows = options.dialect_options.skip_rows.GetValue(); - } else if (!options.null_padding) { - sniffing_state_machine.dialect_options.skip_rows = dirty_notes; - } - sniffing_state_machine.dialect_options.num_cols = num_cols; - lines_sniffed = sniffed_column_counts.result_position; - candidates.emplace_back(std::move(scanner)); - } - } -} - -bool CSVSniffer::RefineCandidateNextChunk(ColumnCountScanner &candidate) const { - auto &sniffed_column_counts = candidate.ParseChunk(); - for (idx_t i = 0; i < sniffed_column_counts.result_position; i++) { - if (set_columns.IsSet()) { - return !set_columns.IsCandidateUnacceptable(sniffed_column_counts[i].number_of_columns, - options.null_padding, options.ignore_errors.GetValue(), - sniffed_column_counts[i].last_value_always_empty); - } - if (max_columns_found != sniffed_column_counts[i].number_of_columns && - (!options.null_padding && !options.ignore_errors.GetValue() && !sniffed_column_counts[i].is_comment)) { - return false; - } - } - return true; -} - -void CSVSniffer::RefineCandidates() { - // It's very frequent that more than one dialect can parse a CSV file; hence, here we run each state machine - // fully on the whole sample dataset, and when/if it fails we go to the next one. - if (candidates.empty()) { - // No candidates to refine - return; - } - if (candidates.size() == 1 || candidates[0]->FinishedFile()) { - // Only one candidate (nothing to refine) or all candidates were already checked - return; - } - - for (idx_t i = 1; i <= options.sample_size_chunks; i++) { - vector<unique_ptr<ColumnCountScanner>> successful_candidates; - bool done = false; - for (auto &cur_candidate : candidates) { - const bool finished_file = cur_candidate->FinishedFile(); - if (successful_candidates.empty()) { - lines_sniffed += cur_candidate->GetResult().result_position; - } - if (finished_file || i == options.sample_size_chunks) { - // we finished the file or our chunk sample successfully - if (!cur_candidate->GetResult().error) { - successful_candidates.push_back(std::move(cur_candidate)); - } - done = true; - continue; - } - if (RefineCandidateNextChunk(*cur_candidate) && !cur_candidate->GetResult().error) { - successful_candidates.push_back(std::move(cur_candidate)); - } - } - candidates = std::move(successful_candidates); - if (done) { - break; - } - } - // If we have multiple candidates with quotes set, we give preference to the ones - // that have actually quoted values; otherwise we choose quotes = \0 - vector<unique_ptr<ColumnCountScanner>> successful_candidates = std::move(candidates); - if (!successful_candidates.empty()) { - for (idx_t i = 0; i < successful_candidates.size(); i++) { - unique_ptr<ColumnCountScanner> cc_best_candidate = std::move(successful_candidates[i]); - if (cc_best_candidate->state_machine->state_machine_options.quote != '\0' && - cc_best_candidate->ever_quoted) { - candidates.clear(); - candidates.push_back(std::move(cc_best_candidate)); - return; - } - candidates.push_back(std::move(cc_best_candidate)); - } - } -} - -NewLineIdentifier CSVSniffer::DetectNewLineDelimiter(CSVBufferManager &buffer_manager) { - // Get first buffer - auto buffer = buffer_manager.GetBuffer(0); - auto buffer_ptr = buffer->Ptr(); - bool carriage_return = false; - bool n = false; - for (idx_t i = 0; i < buffer->actual_size; i++) { - if (buffer_ptr[i] == '\r') { - carriage_return = true; - } else if (buffer_ptr[i] == '\n') { - n = true; - break; - } else if (carriage_return) { - break; - } - } - if (carriage_return && n) { - return NewLineIdentifier::CARRY_ON; - } - if (carriage_return) { - return NewLineIdentifier::SINGLE_R; - } - return NewLineIdentifier::SINGLE_N; -} - -// Dialect Detection consists of four steps: -// 1. Generate a search space of all possible dialects -// 2. Generate a state machine for each dialect -// 3. Analyze the first chunk of the file and find the best dialect candidates -// 4.
Analyze the remaining chunks of the file and find the best dialect candidate -void CSVSniffer::DetectDialect() { - // Variables for Dialect Detection - DialectCandidates dialect_candidates(options.dialect_options.state_machine_options); - // Number of rows read - idx_t rows_read = 0; - // Best number of consistent rows (i.e., rows presenting all columns) - idx_t best_consistent_rows = 0; - // How many rows required padding (i.e., were missing some columns) - idx_t prev_padding_count = 0; - // Minimum number of ignored rows - idx_t best_ignored_rows = 0; - // Vector of CSV State Machines - vector<unique_ptr<ColumnCountScanner>> csv_state_machines; - // Step 1: Generate state machines - GenerateStateMachineSearchSpace(csv_state_machines, dialect_candidates); - // Step 2: Analyze all candidates on the first chunk - for (auto &state_machine : csv_state_machines) { - AnalyzeDialectCandidate(std::move(state_machine), rows_read, best_consistent_rows, prev_padding_count, - best_ignored_rows); - } - // Step 3: Loop over candidates and find if they can still produce good results for the remaining chunks - RefineCandidates(); - - // if no dialect candidate was found, we throw an exception - if (candidates.empty()) { - auto error = CSVError::SniffingError(options, dialect_candidates.Print()); - error_handler->Error(error, true); - } -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp deleted file mode 100644 index 424468c55..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/header_detection.cpp +++ /dev/null @@ -1,343 +0,0 @@ -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" - -#include "utf8proc.hpp" - -namespace duckdb { -// Helper function to generate column names -static string GenerateColumnName(const idx_t total_cols, const idx_t col_number, const string &prefix = "column") { - auto max_digits = NumericHelper::UnsignedLength(total_cols - 1); - auto digits = NumericHelper::UnsignedLength(col_number); - string leading_zeros = string(NumericCast<idx_t>(max_digits - digits), '0'); - string value = to_string(col_number); - return string(prefix + leading_zeros + value); -} - -// Helper function for UTF-8 aware space trimming -static string TrimWhitespace(const string &col_name) { - utf8proc_int32_t codepoint; - const auto str = reinterpret_cast<const utf8proc_uint8_t *>(col_name.c_str()); - const idx_t size = col_name.size(); - // Find the first character that is not left trimmed - idx_t begin = 0; - while (begin < size) { - auto bytes = utf8proc_iterate(str + begin, NumericCast<utf8proc_ssize_t>(size - begin), &codepoint); - D_ASSERT(bytes > 0); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - break; - } - begin += NumericCast<idx_t>(bytes); - } - - // Find the last character that is not right trimmed - idx_t end = begin; - for (auto next = begin; next < col_name.size();) { - auto bytes = utf8proc_iterate(str + next, NumericCast<utf8proc_ssize_t>(size - next), &codepoint); - D_ASSERT(bytes > 0); - next += NumericCast<idx_t>(bytes); - if (utf8proc_category(codepoint) != UTF8PROC_CATEGORY_ZS) { - end = next; - } - } - - // return the trimmed string - return col_name.substr(begin, end - begin); -} - -static string NormalizeColumnName(const string &col_name) { - // normalize UTF8 characters to NFKD - auto nfkd = utf8proc_NFKD(reinterpret_cast<const utf8proc_uint8_t *>(col_name.c_str()), - NumericCast<utf8proc_ssize_t>(col_name.size())); - const
string col_name_nfkd = string(const_char_ptr_cast(nfkd), strlen(const_char_ptr_cast(nfkd))); - free(nfkd); - - // only keep ASCII characters 0-9 a-z A-Z and replace spaces with regular whitespace - string col_name_ascii = ""; - for (idx_t i = 0; i < col_name_nfkd.size(); i++) { - if (col_name_nfkd[i] == '_' || (col_name_nfkd[i] >= '0' && col_name_nfkd[i] <= '9') || - (col_name_nfkd[i] >= 'A' && col_name_nfkd[i] <= 'Z') || - (col_name_nfkd[i] >= 'a' && col_name_nfkd[i] <= 'z')) { - col_name_ascii += col_name_nfkd[i]; - } else if (StringUtil::CharacterIsSpace(col_name_nfkd[i])) { - col_name_ascii += " "; - } - } - - // trim whitespace and replace remaining whitespace by _ - string col_name_trimmed = TrimWhitespace(col_name_ascii); - string col_name_cleaned = ""; - bool in_whitespace = false; - for (idx_t i = 0; i < col_name_trimmed.size(); i++) { - if (col_name_trimmed[i] == ' ') { - if (!in_whitespace) { - col_name_cleaned += "_"; - in_whitespace = true; - } - } else { - col_name_cleaned += col_name_trimmed[i]; - in_whitespace = false; - } - } - - // don't leave string empty; if not empty, make lowercase - if (col_name_cleaned.empty()) { - col_name_cleaned = "_"; - } else { - col_name_cleaned = StringUtil::Lower(col_name_cleaned); - } - - // prepend _ if name starts with a digit or is a reserved keyword - auto keyword = KeywordHelper::KeywordCategoryType(col_name_cleaned); - if (keyword == KeywordCategory::KEYWORD_TYPE_FUNC || keyword == KeywordCategory::KEYWORD_RESERVED || - (col_name_cleaned[0] >= '0' && col_name_cleaned[0] <= '9')) { - col_name_cleaned = "_" + col_name_cleaned; - } - return col_name_cleaned; -} - -static void ReplaceNames(vector &detected_names, CSVStateMachine &state_machine, - unordered_map> &best_sql_types_candidates_per_column_idx, - CSVReaderOptions &options, const vector &best_header_row, - CSVErrorHandler &error_handler) { - auto &dialect_options = state_machine.dialect_options; - if (!options.columns_set) { - if (options.file_options.hive_partitioning || options.file_options.union_by_name || options.multi_file_reader) { - // Just do the replacement - for (idx_t i = 0; i < MinValue(detected_names.size(), options.name_list.size()); i++) { - detected_names[i] = options.name_list[i]; - } - return; - } - if (options.name_list.size() > dialect_options.num_cols) { - if (options.null_padding) { - // we increase our types - idx_t col = 0; - for (idx_t i = dialect_options.num_cols; i < options.name_list.size(); i++) { - detected_names.push_back(GenerateColumnName(options.name_list.size(), col++)); - best_sql_types_candidates_per_column_idx[i] = {LogicalType::VARCHAR}; - } - - dialect_options.num_cols = options.name_list.size(); - - } else { - // we throw an error - const auto error = CSVError::HeaderSniffingError( - options, best_header_row, options.name_list.size(), - state_machine.dialect_options.state_machine_options.delimiter.GetValue()); - error_handler.Error(error); - } - } - for (idx_t i = 0; i < options.name_list.size(); i++) { - detected_names[i] = options.name_list[i]; - } - } -} - -// If our columns were set by the user, we verify if their names match with the first row -bool CSVSniffer::DetectHeaderWithSetColumn(ClientContext &context, vector &best_header_row, - const SetColumns &set_columns, CSVReaderOptions &options) { - bool has_header = true; - - std::ostringstream error; - // User set the names, we must check if they match the first row - // We do a +1 to check for situations where the csv file has an extra all null column - if (set_columns.Size() != 
best_header_row.size() && set_columns.Size() + 1 != best_header_row.size()) { - return false; - } - - // Let's do a match-aroo - for (idx_t i = 0; i < set_columns.Size(); i++) { - if (best_header_row[i].IsNull()) { - return false; - } - if (best_header_row[i].value != (*set_columns.names)[i]) { - error << "Header mismatch at position: " << i << "\n"; - error << "Expected name: \"" << (*set_columns.names)[i] << "\", "; - error << "Actual name: \"" << best_header_row[i].value << "\"." - << "\n"; - has_header = false; - break; - } - } - - if (!has_header) { - bool all_varchar = true; - bool first_row_consistent = true; - // We verify if the types are consistent - for (idx_t col = 0; col < set_columns.Size(); col++) { - // try cast to sql_type of column - const auto &sql_type = (*set_columns.types)[col]; - if (sql_type != LogicalType::VARCHAR) { - all_varchar = false; - if (!CSVSniffer::CanYouCastIt(context, best_header_row[col].value, sql_type, options.dialect_options, - best_header_row[col].IsNull(), options.decimal_separator[0])) { - first_row_consistent = false; - } - } - } - if (!first_row_consistent) { - options.sniffer_user_mismatch_error += error.str(); - } - if (all_varchar) { - return true; - } - return !first_row_consistent; - } - return has_header; -} - -bool EmptyHeader(const string &col_name, bool is_null, bool normalize) { - if (col_name.empty() || is_null) { - return true; - } - if (normalize) { - // normalize has special logic to trim white spaces and generate names - return false; - } - // check if it's all white spaces - for (auto &c : col_name) { - if (!StringUtil::CharacterIsSpace(c)) { - return false; - } - } - // if we are not normalizing the name and is all white spaces, then we generate a name - return true; -} - -vector -CSVSniffer::DetectHeaderInternal(ClientContext &context, vector &best_header_row, - CSVStateMachine &state_machine, const SetColumns &set_columns, - unordered_map> &best_sql_types_candidates_per_column_idx, - CSVReaderOptions &options, CSVErrorHandler &error_handler) { - vector detected_names; - auto &dialect_options = state_machine.dialect_options; - dialect_options.num_cols = best_sql_types_candidates_per_column_idx.size(); - if (best_header_row.empty()) { - dialect_options.header = false; - for (idx_t col = 0; col < dialect_options.num_cols; col++) { - detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); - } - // If the user provided names, we must replace our header with the user provided names - ReplaceNames(detected_names, state_machine, best_sql_types_candidates_per_column_idx, options, best_header_row, - error_handler); - return detected_names; - } - // information for header detection - // check if header row is all null and/or consistent with detected column data types - // If null-padding is not allowed and there is a mismatch between our header candidate and the number of columns - // We can't detect the dialect/type options properly - if (!options.null_padding && best_sql_types_candidates_per_column_idx.size() != best_header_row.size()) { - if (options.ignore_errors.GetValue()) { - dialect_options.header = false; - for (idx_t col = 0; col < dialect_options.num_cols; col++) { - detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); - } - dialect_options.rows_until_header += 1; - ReplaceNames(detected_names, state_machine, best_sql_types_candidates_per_column_idx, options, - best_header_row, error_handler); - return detected_names; - } - auto error = - CSVError::HeaderSniffingError(options, 
best_header_row, best_sql_types_candidates_per_column_idx.size(), - state_machine.dialect_options.state_machine_options.delimiter.GetValue()); - error_handler.Error(error); - } - bool has_header; - - if (set_columns.IsSet()) { - has_header = DetectHeaderWithSetColumn(context, best_header_row, set_columns, options); - } else { - bool first_row_consistent = true; - bool all_varchar = true; - bool first_row_nulls = true; - for (idx_t col = 0; col < best_header_row.size(); col++) { - if (!best_header_row[col].IsNull()) { - first_row_nulls = false; - } - // try cast to sql_type of column - const auto &sql_type = best_sql_types_candidates_per_column_idx[col].back(); - if (sql_type != LogicalType::VARCHAR) { - all_varchar = false; - if (!CanYouCastIt(context, best_header_row[col].value, sql_type, dialect_options, - best_header_row[col].IsNull(), options.decimal_separator[0])) { - first_row_consistent = false; - } - } - } - // Our header is only false if types are not all varchar, and rows are consistent - if (all_varchar || first_row_nulls) { - has_header = true; - } else { - has_header = !first_row_consistent; - } - } - - if (options.dialect_options.header.IsSetByUser()) { - // Header is defined by user, use that. - has_header = options.dialect_options.header.GetValue(); - } - // update parser info, and read, generate & set col_names based on previous findings - if (has_header) { - dialect_options.header = true; - if (options.null_padding && !options.dialect_options.skip_rows.IsSetByUser()) { - if (dialect_options.skip_rows.GetValue() > 0) { - dialect_options.skip_rows = dialect_options.skip_rows.GetValue() - 1; - } - } - case_insensitive_map_t name_collision_count; - - // get header names from CSV - for (idx_t col = 0; col < best_header_row.size(); col++) { - string &col_name = best_header_row[col].value; - - // generate name if field is empty - if (EmptyHeader(col_name, best_header_row[col].is_null, options.normalize_names)) { - col_name = GenerateColumnName(dialect_options.num_cols, col); - } - - // normalize names or at least trim whitespace - if (options.normalize_names) { - col_name = NormalizeColumnName(col_name); - } else { - col_name = TrimWhitespace(col_name); - } - - // avoid duplicate header names - while (name_collision_count.find(col_name) != name_collision_count.end()) { - name_collision_count[col_name] += 1; - col_name = col_name + "_" + to_string(name_collision_count[col_name]); - } - detected_names.push_back(col_name); - name_collision_count[col_name] = 0; - } - if (best_header_row.size() < dialect_options.num_cols && options.null_padding) { - for (idx_t col = best_header_row.size(); col < dialect_options.num_cols; col++) { - detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); - } - } else if (best_header_row.size() < dialect_options.num_cols) { - throw InternalException("Detected header has number of columns inferior to dialect detection"); - } - - } else { - dialect_options.header = false; - for (idx_t col = 0; col < dialect_options.num_cols; col++) { - detected_names.push_back(GenerateColumnName(dialect_options.num_cols, col)); - } - } - - // If the user provided names, we must replace our header with the user provided names - ReplaceNames(detected_names, state_machine, best_sql_types_candidates_per_column_idx, options, best_header_row, - error_handler); - return detected_names; -} -void CSVSniffer::DetectHeader() { - auto &sniffer_state_machine = best_candidate->GetStateMachine(); - names = DetectHeaderInternal(buffer_manager->context, 
best_header_row, sniffer_state_machine, set_columns, - best_sql_types_candidates_per_column_idx, options, *error_handler); - for (idx_t i = max_columns_found; i < names.size(); i++) { - detected_types.push_back(LogicalType::VARCHAR); - } - max_columns_found = names.size(); -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp deleted file mode 100644 index d6b2b1a6b..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_detection.cpp +++ /dev/null @@ -1,493 +0,0 @@ -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/operator/decimal_cast_operators.hpp" -#include "duckdb/common/operator/double_cast_operator.hpp" -#include "duckdb/common/operator/integer_cast_operator.hpp" -#include "duckdb/common/string.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" - -namespace duckdb { -struct TryCastFloatingOperator { - template - static bool Operation(string_t input) { - T result; - string error_message; - CastParameters parameters(false, &error_message); - return OP::Operation(input, result, parameters); - } -}; - -static bool StartsWithNumericDate(string &separator, const string_t &value) { - auto begin = value.GetData(); - auto end = begin + value.GetSize(); - - // StrpTimeFormat::Parse will skip whitespace, so we can too - auto field1 = std::find_if_not(begin, end, StringUtil::CharacterIsSpace); - if (field1 == end) { - return false; - } - - // first numeric field must start immediately - if (!StringUtil::CharacterIsDigit(*field1)) { - return false; - } - auto literal1 = std::find_if_not(field1, end, StringUtil::CharacterIsDigit); - if (literal1 == end) { - return false; - } - - // second numeric field must exist - auto field2 = std::find_if(literal1, end, StringUtil::CharacterIsDigit); - if (field2 == end) { - return false; - } - auto literal2 = std::find_if_not(field2, end, StringUtil::CharacterIsDigit); - if (literal2 == end) { - return false; - } - - // third numeric field must exist - auto field3 = std::find_if(literal2, end, StringUtil::CharacterIsDigit); - if (field3 == end) { - return false; - } - - // second literal must match first - if (((field3 - literal2) != (field2 - literal1)) || - strncmp(literal1, literal2, NumericCast((field2 - literal1))) != 0) { - return false; - } - - // copy the literal as the separator, escaping percent signs - separator.clear(); - while (literal1 < field2) { - const auto literal_char = *literal1++; - if (literal_char == '%') { - separator.push_back(literal_char); - } - separator.push_back(literal_char); - } - - return true; -} - -string GenerateDateFormat(const string &separator, const char *format_template) { - string format_specifier = format_template; - auto amount_of_dashes = NumericCast(std::count(format_specifier.begin(), format_specifier.end(), '-')); - // All our date formats must have at least one - - D_ASSERT(amount_of_dashes); - string result; - result.reserve(format_specifier.size() - amount_of_dashes + (amount_of_dashes * separator.size())); - for (auto &character : format_specifier) { - if (character == '-') { - result += separator; - } else { - result += character; - } - } - return result; -} - -void CSVSniffer::SetDateFormat(CSVStateMachine &candidate, const string &format_specifier, - const LogicalTypeId &sql_type) { - StrpTimeFormat strpformat; - StrTimeFormat::ParseFormatSpecifier(format_specifier, 
strpformat); - candidate.dialect_options.date_format[sql_type].Set(strpformat, false); -} - -idx_t CSVSniffer::LinesSniffed() const { - return lines_sniffed; -} - -bool CSVSniffer::CanYouCastIt(ClientContext &context, const string_t value, const LogicalType &type, - const DialectOptions &dialect_options, const bool is_null, const char decimal_separator) { - if (is_null) { - return true; - } - auto value_ptr = value.GetData(); - auto value_size = value.GetSize(); - switch (type.id()) { - case LogicalTypeId::BOOLEAN: { - bool dummy_value; - return TryCastStringBool(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::TINYINT: { - int8_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, false); - } - case LogicalTypeId::SMALLINT: { - int16_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::INTEGER: { - int32_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::BIGINT: { - int64_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::UTINYINT: { - uint8_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::USMALLINT: { - uint16_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::UINTEGER: { - uint32_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::UBIGINT: { - uint64_t dummy_value; - return TrySimpleIntegerCast(value_ptr, value_size, dummy_value, true); - } - case LogicalTypeId::DOUBLE: { - double dummy_value; - return TryDoubleCast(value_ptr, value_size, dummy_value, true, decimal_separator); - } - case LogicalTypeId::FLOAT: { - float dummy_value; - return TryDoubleCast(value_ptr, value_size, dummy_value, true, decimal_separator); - } - case LogicalTypeId::DATE: { - if (!dialect_options.date_format.find(LogicalTypeId::DATE)->second.GetValue().Empty()) { - date_t result; - string error_message; - return dialect_options.date_format.find(LogicalTypeId::DATE) - ->second.GetValue() - .TryParseDate(value, result, error_message); - } - idx_t pos; - bool special; - date_t dummy_value; - return Date::TryConvertDate(value_ptr, value_size, pos, dummy_value, special, true) == DateCastResult::SUCCESS; - } - case LogicalTypeId::TIMESTAMP: { - timestamp_t dummy_value; - if (!dialect_options.date_format.find(LogicalTypeId::TIMESTAMP)->second.GetValue().Empty()) { - string error_message; - return dialect_options.date_format.find(LogicalTypeId::TIMESTAMP) - ->second.GetValue() - .TryParseTimestamp(value, dummy_value, error_message); - } - return Timestamp::TryConvertTimestamp(value_ptr, value_size, dummy_value) == TimestampCastResult::SUCCESS; - } - case LogicalTypeId::TIME: { - idx_t pos; - dtime_t dummy_value; - return Time::TryConvertTime(value_ptr, value_size, pos, dummy_value, true); - } - case LogicalTypeId::DECIMAL: { - uint8_t width, scale; - type.GetDecimalProperties(width, scale); - if (decimal_separator == ',') { - switch (type.InternalType()) { - case PhysicalType::INT16: { - int16_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT32: { - int32_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT64: { - int64_t dummy_value; - return 
TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT128: { - hugeint_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - default: - throw InternalException("Invalid Physical Type for Decimal Value. Physical Type: " + - TypeIdToString(type.InternalType())); - } - - } else if (decimal_separator == '.') { - switch (type.InternalType()) { - case PhysicalType::INT16: { - int16_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT32: { - int32_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT64: { - int64_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - case PhysicalType::INT128: { - hugeint_t dummy_value; - return TryDecimalStringCast(value_ptr, value_size, dummy_value, width, scale); - } - - default: - throw InternalException("Invalid Physical Type for Decimal Value. Physical Type: " + - TypeIdToString(type.InternalType())); - } - } - throw InvalidInputException("Decimals can only have ',' and '.' as decimal separators"); - } - case LogicalTypeId::VARCHAR: - return true; - default: { - // We do Value Try Cast for non-basic types. - Value new_value; - string error_message; - Value str_value(value); - return str_value.TryCastAs(context, type, new_value, &error_message, true); - } - } -} - -void CSVSniffer::InitializeDateAndTimeStampDetection(CSVStateMachine &candidate, const string &separator, - const LogicalType &sql_type) { - auto &format_candidate = format_candidates[sql_type.id()]; - if (!format_candidate.initialized) { - format_candidate.initialized = true; - // if user set a format, we add that as well - auto user_format = options.dialect_options.date_format.find(sql_type.id()); - if (user_format->second.IsSetByUser()) { - format_candidate.format.emplace_back(user_format->second.GetValue().format_specifier); - } else { - auto entry = format_template_candidates.find(sql_type.id()); - if (entry != format_template_candidates.end()) { - const auto &format_template_list = entry->second; - for (const auto &t : format_template_list) { - const auto format_string = GenerateDateFormat(separator, t); - // don't parse ISO 8601 - if (format_string.find("%Y-%m-%d") == string::npos) { - format_candidate.format.emplace_back(format_string); - } - } - } - } - // order by preference - original_format_candidates = format_candidates; - } - // initialise the first candidate - // all formats are constructed to be valid - SetDateFormat(candidate, format_candidate.format.back(), sql_type.id()); -} - -bool ValidSeparator(const string &separator) { - // We use https://en.wikipedia.org/wiki/List_of_date_formats_by_country as reference - return separator == "-" || separator == "." 
|| separator == "/" || separator == " "; -} -void CSVSniffer::DetectDateAndTimeStampFormats(CSVStateMachine &candidate, const LogicalType &sql_type, - const string &separator, const string_t &dummy_val) { - if (!ValidSeparator(separator)) { - return; - } - // If it is the first time running date/timestamp detection, we must initialize the format variables - InitializeDateAndTimeStampDetection(candidate, separator, sql_type); - // generate date format candidates the first time through - auto &type_format_candidates = format_candidates[sql_type.id()].format; - // check all formats and keep the first one that works - StrpTimeFormat::ParseResult result; - auto save_format_candidates = type_format_candidates; - const bool had_format_candidates = !save_format_candidates.empty(); - const bool initial_format_candidates = - save_format_candidates.size() == original_format_candidates.at(sql_type.id()).format.size(); - const bool is_set_by_user = options.dialect_options.date_format.find(sql_type.id())->second.IsSetByUser(); - while (!type_format_candidates.empty() && !is_set_by_user) { - // avoid using exceptions for flow control... - auto &current_format = candidate.dialect_options.date_format[sql_type.id()].GetValue(); - if (current_format.Parse(dummy_val, result, true)) { - format_candidates[sql_type.id()].had_match = true; - break; - } - // doesn't work - move to the next one - type_format_candidates.pop_back(); - if (!type_format_candidates.empty()) { - SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); - } - } - // if none match, then this is not a value of type sql_type, - if (type_format_candidates.empty()) { - // so restore the candidates that did work, - // or throw them out if they were generated by this value. - if (had_format_candidates) { - if (initial_format_candidates && !format_candidates[sql_type.id()].had_match) { - // we reset the whole thing because we tried to sniff the wrong type. - format_candidates[sql_type.id()].initialized = false; - format_candidates[sql_type.id()].format.clear(); - SetDateFormat(candidate, "", sql_type.id()); - return; - } - type_format_candidates.swap(save_format_candidates); - SetDateFormat(candidate, type_format_candidates.back(), sql_type.id()); - } - } -} - -void CSVSniffer::SniffTypes(DataChunk &data_chunk, CSVStateMachine &state_machine, - unordered_map<idx_t, vector<LogicalType>> &info_sql_types_candidates, - idx_t start_idx_detection) { - const idx_t chunk_size = data_chunk.size(); - HasType has_type; - for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { - auto &cur_vector = data_chunk.data[col_idx]; - D_ASSERT(cur_vector.GetVectorType() == VectorType::FLAT_VECTOR); - D_ASSERT(cur_vector.GetType() == LogicalType::VARCHAR); - auto vector_data = FlatVector::GetData<string_t>(cur_vector); - auto null_mask = FlatVector::Validity(cur_vector); - auto &col_type_candidates = info_sql_types_candidates[col_idx]; - for (idx_t row_idx = start_idx_detection; row_idx < chunk_size; row_idx++) { - // col_type_candidates can't be empty since anything in a CSV file should at least be a string - // and we validate utf-8 compatibility when creating the type - D_ASSERT(!col_type_candidates.empty()); - auto cur_top_candidate = col_type_candidates.back(); - // try cast from string to sql_type - while (col_type_candidates.size() > 1) { - const auto &sql_type = col_type_candidates.back(); - // try formatting for date types if the user did not specify one, and it starts with numeric - // values.
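Editor's note: DetectDateAndTimeStampFormats relies on GenerateDateFormat (defined earlier in this file) to splice the sniffed separator into dash-based templates; ISO %Y-%m-%d candidates are skipped because the default path already handles them. A minimal sketch of the splicing, with an illustrative template list:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Splice a sniffed separator into a dash-based date format template,
// mirroring what GenerateDateFormat does for a value such as "12/31/2024".
std::string SpliceSeparator(const std::string &format_template, const std::string &separator) {
	std::string result;
	for (char c : format_template) {
		if (c == '-') {
			result += separator;
		} else {
			result += c;
		}
	}
	return result;
}

int main() {
	// An illustrative subset of the per-type template list.
	const std::vector<std::string> templates = {"%d-%m-%Y", "%m-%d-%Y", "%Y-%m-%d"};
	for (const auto &t : templates) {
		std::cout << SpliceSeparator(t, "/") << "\n"; // %d/%m/%Y, %m/%d/%Y, %Y/%m/%d
	}
	return 0;
}
```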
- string separator; - // If Value is not Null, Has a numeric date format, and the current investigated candidate is - // either a timestamp or a date - if (null_mask.RowIsValid(row_idx) && StartsWithNumericDate(separator, vector_data[row_idx]) && - ((col_type_candidates.back().id() == LogicalTypeId::TIMESTAMP && !has_type.timestamp) || - (col_type_candidates.back().id() == LogicalTypeId::DATE && !has_type.date))) { - DetectDateAndTimeStampFormats(state_machine, sql_type, separator, vector_data[row_idx]); - } - // try cast from string to sql_type - if (sql_type == LogicalType::VARCHAR) { - // Nothing to convert it to - continue; - } - if (CanYouCastIt(buffer_manager->context, vector_data[row_idx], sql_type, state_machine.dialect_options, - !null_mask.RowIsValid(row_idx), state_machine.options.decimal_separator[0])) { - break; - } - - if (row_idx != start_idx_detection && - (cur_top_candidate == LogicalType::BOOLEAN || cur_top_candidate == LogicalType::DATE || - cur_top_candidate == LogicalType::TIME || cur_top_candidate == LogicalType::TIMESTAMP)) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; - } - col_type_candidates.pop_back(); - } - } - if (col_type_candidates.back().id() == LogicalTypeId::DATE) { - has_type.date = true; - } - if (col_type_candidates.back().id() == LogicalTypeId::TIMESTAMP) { - has_type.timestamp = true; - } - } -} - -// If we have a predefined date/timestamp format we set it -void CSVSniffer::SetUserDefinedDateTimeFormat(CSVStateMachine &candidate) const { - const vector data_time_formats {LogicalTypeId::DATE, LogicalTypeId::TIMESTAMP}; - for (auto &date_time_format : data_time_formats) { - auto &user_option = options.dialect_options.date_format.at(date_time_format); - if (user_option.IsSetByUser()) { - SetDateFormat(candidate, user_option.GetValue().format_specifier, date_time_format); - } - } -} -void CSVSniffer::DetectTypes() { - idx_t min_varchar_cols = max_columns_found + 1; - idx_t min_errors = NumericLimits::Maximum(); - vector return_types; - // check which info candidate leads to minimum amount of non-varchar columns... 
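Editor's note: the narrowing loop above walks each column's candidate list from most specific (back) to least specific, popping every type the current value cannot cast to, with VARCHAR as the catch-all at the bottom. A self-contained toy version, where string type names and a naive CanCast stand in for LogicalType and CanYouCastIt:

```cpp
#include <iostream>
#include <string>
#include <vector>

// Naive stand-in for CanYouCastIt: accepts digit-ish strings for BIGINT/DOUBLE.
bool CanCast(const std::string &value, const std::string &type) {
	if (type == "VARCHAR") {
		return true; // everything is at least a string
	}
	if (type == "BIGINT") {
		return !value.empty() && value.find_first_not_of("0123456789-") == std::string::npos;
	}
	if (type == "DOUBLE") {
		return !value.empty() && value.find_first_not_of("0123456789-.eE") == std::string::npos;
	}
	return false;
}

int main() {
	// Back of the vector is the current, most specific candidate.
	std::vector<std::string> candidates = {"VARCHAR", "DOUBLE", "BIGINT"};
	const std::vector<std::string> values = {"42", "3.14", "n/a"};
	for (const std::string &value : values) {
		while (candidates.size() > 1 && !CanCast(value, candidates.back())) {
			candidates.pop_back(); // this value rejects the type, fall back
		}
	}
	std::cout << "detected type: " << candidates.back() << "\n"; // VARCHAR: "n/a" killed the rest
	return 0;
}
```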
- for (auto &candidate_cc : candidates) { - auto &sniffing_state_machine = candidate_cc->GetStateMachine(); - unordered_map<idx_t, vector<LogicalType>> info_sql_types_candidates; - for (idx_t i = 0; i < max_columns_found; i++) { - info_sql_types_candidates[i] = sniffing_state_machine.options.auto_type_candidates; - } - D_ASSERT(max_columns_found > 0); - - // Set all return_types to VARCHAR, so we can do datatype detection based on VARCHAR values - return_types.clear(); - return_types.assign(max_columns_found, LogicalType::VARCHAR); - - // Reset candidate for parsing - auto candidate = candidate_cc->UpgradeToStringValueScanner(); - SetUserDefinedDateTimeFormat(*candidate->state_machine); - // Parse chunk and read csv with info candidate - auto &data_chunk = candidate->ParseChunk().ToChunk(); - if (candidate->error_handler->AnyErrors() && !candidate->error_handler->HasError(MAXIMUM_LINE_SIZE) && - !candidate->state_machine->options.ignore_errors.GetValue()) { - continue; - } - idx_t start_idx_detection = 0; - idx_t chunk_size = data_chunk.size(); - if (chunk_size > 1 && - (!options.dialect_options.header.IsSetByUser() || - (options.dialect_options.header.IsSetByUser() && options.dialect_options.header.GetValue()))) { - // This means we have more than one row, hence we can use the first row to detect if we have a header - start_idx_detection = 1; - } - // First line where we start our type detection - SniffTypes(data_chunk, sniffing_state_machine, info_sql_types_candidates, start_idx_detection); - - // Count the number of varchar columns - idx_t varchar_cols = 0; - for (idx_t col = 0; col < info_sql_types_candidates.size(); col++) { - auto &col_type_candidates = info_sql_types_candidates[col]; - // check number of varchar columns - const auto &col_type = col_type_candidates.back(); - if (col_type == LogicalType::VARCHAR) { - varchar_cols++; - } - } - - // it's good if the dialect creates more non-varchar columns, but only if we sacrifice less than 30% of the - // maximum number of columns found. - const idx_t number_of_errors = candidate->error_handler->GetSize(); - if (!best_candidate || (varchar_cols < min_varchar_cols && - static_cast<double>(info_sql_types_candidates.size()) > - (static_cast<double>(max_columns_found) * 0.7) && - (!options.ignore_errors.GetValue() || number_of_errors < min_errors))) { - min_errors = number_of_errors; - best_header_row.clear(); - // we have a new best candidate for the options - best_candidate = std::move(candidate); - min_varchar_cols = varchar_cols; - best_sql_types_candidates_per_column_idx = info_sql_types_candidates; - for (auto &format_candidate : format_candidates) { - best_format_candidates[format_candidate.first] = format_candidate.second.format; - } - if (chunk_size > 0) { - for (idx_t col_idx = 0; col_idx < data_chunk.ColumnCount(); col_idx++) { - auto &cur_vector = data_chunk.data[col_idx]; - auto vector_data = FlatVector::GetData<string_t>(cur_vector); - auto null_mask = FlatVector::Validity(cur_vector); - if (null_mask.RowIsValid(0)) { - auto value = HeaderValue(vector_data[0]); - best_header_row.push_back(value); - } else { - best_header_row.push_back({}); - } - } - } - } - } - if (!best_candidate) { - DialectCandidates dialect_candidates(options.dialect_options.state_machine_options); - auto error = CSVError::SniffingError(options, dialect_candidates.Print()); - error_handler->Error(error, true); - } - // Assert that it's all good at this point.
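Editor's note: the candidate-selection condition reconstructed above trades VARCHAR columns against the number of columns retained. A small illustration of that 70% rule (the error-count term is omitted; names are illustrative):

```cpp
#include <iostream>

// A candidate with fewer VARCHAR columns only wins if it still explains at
// least 70% of the widest column count seen, so we never trade most of the
// table away for nicer types.
bool BetterCandidate(int varchar_cols, int min_varchar_cols, int candidate_cols, int max_columns_found) {
	return varchar_cols < min_varchar_cols &&
	       static_cast<double>(candidate_cols) > static_cast<double>(max_columns_found) * 0.7;
}

int main() {
	std::cout << std::boolalpha
	          << BetterCandidate(1, 3, 8, 10) << "\n"  // true: fewer VARCHARs, 80% of columns kept
	          << BetterCandidate(1, 3, 6, 10) << "\n"; // false: only 60% of columns kept
	return 0;
}
```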
- D_ASSERT(best_candidate && !best_format_candidates.empty()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp deleted file mode 100644 index 8d3e26845..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_refinement.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_casting.hpp" - -namespace duckdb { -bool CSVSniffer::TryCastVector(Vector &parse_chunk_col, idx_t size, const LogicalType &sql_type) { - auto &sniffing_state_machine = best_candidate->GetStateMachine(); - // try vector-cast from string to sql_type - Vector dummy_result(sql_type, size); - if (!sniffing_state_machine.dialect_options.date_format[LogicalTypeId::DATE].GetValue().Empty() && - sql_type.id() == LogicalTypeId::DATE) { - // use the date format to cast the chunk - string error_message; - CastParameters parameters(false, &error_message); - idx_t line_error; - return CSVCast::TryCastDateVector(sniffing_state_machine.dialect_options.date_format, parse_chunk_col, - dummy_result, size, parameters, line_error); - } - if (!sniffing_state_machine.dialect_options.date_format[LogicalTypeId::TIMESTAMP].GetValue().Empty() && - sql_type.id() == LogicalTypeId::TIMESTAMP) { - // use the timestamp format to cast the chunk - string error_message; - CastParameters parameters(false, &error_message); - return CSVCast::TryCastTimestampVector(sniffing_state_machine.dialect_options.date_format, parse_chunk_col, - dummy_result, size, parameters); - } - if ((sql_type.id() == LogicalTypeId::DOUBLE || sql_type.id() == LogicalTypeId::FLOAT) && - options.decimal_separator == ",") { - string error_message; - CastParameters parameters(false, &error_message); - idx_t line_error; - return CSVCast::TryCastFloatingVectorCommaSeparated(options, parse_chunk_col, dummy_result, size, parameters, - sql_type, line_error); - } - if (sql_type.id() == LogicalTypeId::DECIMAL && options.decimal_separator == ",") { - string error_message; - CastParameters parameters(false, &error_message); - idx_t line_error; - return CSVCast::TryCastDecimalVectorCommaSeparated(options, parse_chunk_col, dummy_result, size, parameters, - sql_type, line_error); - } - // target type is not varchar: perform a cast - string error_message; - return VectorOperations::DefaultTryCast(parse_chunk_col, dummy_result, size, &error_message, true); -} - -void CSVSniffer::RefineTypes() { - auto &sniffing_state_machine = best_candidate->GetStateMachine(); - // if data types were provided, exit here if number of columns does not match - detected_types.assign(sniffing_state_machine.dialect_options.num_cols, LogicalType::VARCHAR); - if (sniffing_state_machine.options.all_varchar) { - // return all types varchar - return; - } - for (idx_t i = 1; i < sniffing_state_machine.options.sample_size_chunks; i++) { - bool finished_file = best_candidate->FinishedFile(); - if (finished_file) { - // we finished the file: stop - // set sql types - detected_types.clear(); - for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { - LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); - if (best_sql_types_candidates_per_column_idx[column_idx].size() == - sniffing_state_machine.options.auto_type_candidates.size()) { - d_type = LogicalType::VARCHAR; - } - 
detected_types.push_back(d_type); - } - return; - } - auto &parse_chunk = best_candidate->ParseChunk().ToChunk(); - - for (idx_t col = 0; col < parse_chunk.ColumnCount(); col++) { - vector &col_type_candidates = best_sql_types_candidates_per_column_idx[col]; - bool is_bool_type = col_type_candidates.back() == LogicalType::BOOLEAN; - while (col_type_candidates.size() > 1) { - const auto &sql_type = col_type_candidates.back(); - if (TryCastVector(parse_chunk.data[col], parse_chunk.size(), sql_type)) { - break; - } - if (col_type_candidates.back() == LogicalType::BOOLEAN && is_bool_type) { - // If we thought this was a boolean value (i.e., T,F, True, False) and it is not, we - // immediately pop to varchar. - while (col_type_candidates.back() != LogicalType::VARCHAR) { - col_type_candidates.pop_back(); - } - break; - } - col_type_candidates.pop_back(); - } - } - // reset parse chunk for the next iteration - parse_chunk.Reset(); - parse_chunk.SetCapacity(CSVReaderOptions::sniff_size); - } - detected_types.clear(); - // set sql types - for (idx_t column_idx = 0; column_idx < best_sql_types_candidates_per_column_idx.size(); column_idx++) { - LogicalType d_type = best_sql_types_candidates_per_column_idx[column_idx].back(); - if (best_sql_types_candidates_per_column_idx[column_idx].size() == - best_candidate->GetStateMachine().options.auto_type_candidates.size() && - default_null_to_varchar) { - d_type = LogicalType::VARCHAR; - } - detected_types.push_back(d_type); - } -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp b/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp deleted file mode 100644 index a693144de..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/sniffer/type_replacement.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" - -namespace duckdb { -void CSVSniffer::ReplaceTypes() { - auto &sniffing_state_machine = best_candidate->GetStateMachine(); - manually_set = vector(detected_types.size(), false); - if (sniffing_state_machine.options.sql_type_list.empty() || sniffing_state_machine.options.columns_set) { - return; - } - // user-defined types were supplied for certain columns - // override the types - if (!sniffing_state_machine.options.sql_types_per_column.empty()) { - // types supplied as name -> value map - idx_t found = 0; - for (idx_t i = 0; i < names.size(); i++) { - auto it = sniffing_state_machine.options.sql_types_per_column.find(names[i]); - if (it != sniffing_state_machine.options.sql_types_per_column.end()) { - best_sql_types_candidates_per_column_idx[i] = { - sniffing_state_machine.options.sql_type_list[it->second]}; - detected_types[i] = sniffing_state_machine.options.sql_type_list[it->second]; - manually_set[i] = true; - found++; - } - } - if (!sniffing_state_machine.options.file_options.union_by_name && - found < sniffing_state_machine.options.sql_types_per_column.size()) { - auto error_msg = CSVError::ColumnTypesError(options.sql_types_per_column, names); - error_handler->Error(error_msg); - } - return; - } - // types supplied as list - if (names.size() < sniffing_state_machine.options.sql_type_list.size()) { - throw BinderException("read_csv: %d types were provided, but CSV file only has %d columns", - sniffing_state_machine.options.sql_type_list.size(), names.size()); - } - for (idx_t i = 0; i < sniffing_state_machine.options.sql_type_list.size(); i++) { - detected_types[i] = 
sniffing_state_machine.options.sql_type_list[i]; - manually_set[i] = true; - } -} -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine.cpp b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine.cpp deleted file mode 100644 index eae140f7d..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_state_machine.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "utf8proc_wrapper.hpp" -#include "duckdb/main/error_manager.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp" - -namespace duckdb { - -CSVStateMachine::CSVStateMachine(CSVReaderOptions &options_p, const CSVStateMachineOptions &state_machine_options_p, - CSVStateMachineCache &csv_state_machine_cache) - : transition_array(csv_state_machine_cache.Get(state_machine_options_p)), - state_machine_options(state_machine_options_p), options(options_p) { - dialect_options.state_machine_options = state_machine_options; -} - -CSVStateMachine::CSVStateMachine(const StateMachine &transition_array_p, const CSVReaderOptions &options_p) - : transition_array(transition_array_p), state_machine_options(options_p.dialect_options.state_machine_options), - options(options_p), dialect_options(options.dialect_options) { - dialect_options.state_machine_options = state_machine_options; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp b/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp deleted file mode 100644 index 29fda8863..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/state_machine/csv_state_machine_cache.cpp +++ /dev/null @@ -1,444 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_state_machine.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_state_machine_cache.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" - -namespace duckdb { - -void InitializeTransitionArray(StateMachine &transition_array, const CSVState cur_state, const CSVState state) { - for (uint32_t i = 0; i < StateMachine::NUM_TRANSITIONS; i++) { - transition_array[i][static_cast(cur_state)] = state; - } -} - -// Shift and OR to replicate across all bytes -void ShiftAndReplicateBits(uint64_t &value) { - value |= value << 8; - value |= value << 16; - value |= value << 32; -} -void CSVStateMachineCache::Insert(const CSVStateMachineOptions &state_machine_options) { - D_ASSERT(state_machine_cache.find(state_machine_options) == state_machine_cache.end()); - // Initialize transition array with default values to the Standard option - auto &transition_array = state_machine_cache[state_machine_options]; - - for (uint32_t i = 0; i < StateMachine::NUM_STATES; i++) { - const auto cur_state = static_cast(i); - switch (cur_state) { - case CSVState::MAYBE_QUOTED: - case CSVState::QUOTED: - case CSVState::QUOTED_NEW_LINE: - case CSVState::ESCAPE: - InitializeTransitionArray(transition_array, cur_state, CSVState::QUOTED); - break; - case CSVState::UNQUOTED: - if (state_machine_options.rfc_4180.GetValue()) { - // If we have an unquoted state, following rfc 4180, our base state is invalid - InitializeTransitionArray(transition_array, cur_state, CSVState::INVALID); - } else { - // This will allow us to accept unescaped quotes - 
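ShiftAndReplicateBits above smears one byte across all eight lanes of a 64-bit word; the scanner later uses these replicated constants to test eight input bytes per comparison. Below is a self-contained illustration using the classic SWAR zero-lane test; the two masks are the usual bit-twiddling constants, not DuckDB identifiers.

#include <cstdint>
#include <cstdio>
#include <cstring>

// Mirror of ShiftAndReplicateBits: copy one byte into all eight lanes.
static uint64_t Replicate(uint8_t b) {
    uint64_t v = b;
    v |= v << 8;
    v |= v << 16;
    v |= v << 32;
    return v;
}

// SWAR test: does any of the eight bytes in `word` equal the replicated byte?
static bool ContainsByte(uint64_t word, uint64_t replicated) {
    const uint64_t x = word ^ replicated; // matching lanes become zero
    return ((x - 0x0101010101010101ULL) & ~x & 0x8080808080808080ULL) != 0;
}

int main() {
    const char buf[8] = {'a', 'b', ',', 'c', 'd', 'e', 'f', 'g'};
    uint64_t word;
    std::memcpy(&word, buf, sizeof(word));
    std::printf("%d\n", ContainsByte(word, Replicate(','))); // 1: delimiter present
}

A lane that equals the replicated byte XORs to zero, and the (x - 0x01...) & ~x & 0x80... expression is non-zero exactly when some lane is zero.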
InitializeTransitionArray(transition_array, cur_state, CSVState::UNQUOTED); - } - break; - case CSVState::COMMENT: - InitializeTransitionArray(transition_array, cur_state, CSVState::COMMENT); - break; - default: - InitializeTransitionArray(transition_array, cur_state, CSVState::STANDARD); - break; - } - } - - const auto delimiter_value = state_machine_options.delimiter.GetValue(); - const auto delimiter_first_byte = static_cast(delimiter_value[0]); - const auto quote = static_cast(state_machine_options.quote.GetValue()); - const auto escape = static_cast(state_machine_options.escape.GetValue()); - const auto comment = static_cast(state_machine_options.comment.GetValue()); - - const auto new_line_id = state_machine_options.new_line.GetValue(); - - const bool multi_byte_delimiter = delimiter_value.size() != 1; - - const bool enable_unquoted_escape = state_machine_options.rfc_4180.GetValue() == false && - state_machine_options.quote != state_machine_options.escape && - state_machine_options.escape != '\0'; - // Now set values depending on configuration - // 1) Standard/Invalid State - const vector std_inv {static_cast(CSVState::STANDARD), static_cast(CSVState::INVALID), - static_cast(CSVState::STANDARD_NEWLINE)}; - for (const auto &state : std_inv) { - if (multi_byte_delimiter) { - transition_array[delimiter_first_byte][state] = CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][state] = CSVState::DELIMITER; - } - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][state] = CSVState::CARRIAGE_RETURN; - if (state == static_cast(CSVState::STANDARD_NEWLINE)) { - transition_array[static_cast('\n')][state] = CSVState::STANDARD; - } else if (!state_machine_options.rfc_4180.GetValue()) { - transition_array[static_cast('\n')][state] = CSVState::RECORD_SEPARATOR; - } else { - transition_array[static_cast('\n')][state] = CSVState::INVALID; - } - } else { - transition_array[static_cast('\r')][state] = CSVState::RECORD_SEPARATOR; - transition_array[static_cast('\n')][state] = CSVState::RECORD_SEPARATOR; - } - if (comment != '\0') { - transition_array[comment][state] = CSVState::COMMENT; - } - if (enable_unquoted_escape) { - transition_array[escape][state] = CSVState::UNQUOTED_ESCAPE; - } - } - // 2) Field Separator State - if (quote != '\0') { - transition_array[quote][static_cast(CSVState::DELIMITER)] = CSVState::QUOTED; - } - if (delimiter_first_byte != ' ') { - transition_array[' '][static_cast(CSVState::DELIMITER)] = CSVState::EMPTY_SPACE; - } - - const vector delimiter_states { - static_cast(CSVState::DELIMITER), static_cast(CSVState::DELIMITER_FIRST_BYTE), - static_cast(CSVState::DELIMITER_SECOND_BYTE), static_cast(CSVState::DELIMITER_THIRD_BYTE)}; - - // These are the same transitions for all delimiter states - for (auto &state : delimiter_states) { - if (multi_byte_delimiter) { - transition_array[delimiter_first_byte][state] = CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][state] = CSVState::DELIMITER; - } - transition_array[static_cast('\n')][state] = CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][state] = CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][state] = CSVState::RECORD_SEPARATOR; - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::DELIMITER)] = CSVState::COMMENT; - } - } - // Deal other multi-byte delimiters - if (delimiter_value.size() == 2) { - 
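Insert above builds the lookup table in two passes: every byte of a state's row first gets a default transition, then only the special bytes (delimiter, quote, newline, comment) are overridden, so the hot loop can do a single unconditional lookup per input byte. A toy [byte][state] table with the same fill-then-override shape:

#include <array>
#include <cstdint>
#include <cstdio>
#include <string>

// Toy three-state machine laid out like the cached transition array above.
enum State : uint8_t { FIELD, DELIM, RECORD, NUM_STATES };
constexpr int NUM_BYTES = 256;

int main() {
    std::array<std::array<uint8_t, NUM_STATES>, NUM_BYTES> t {};
    // Pass 1: default every byte of every state's row.
    for (int b = 0; b < NUM_BYTES; b++) {
        for (uint8_t s = 0; s < NUM_STATES; s++) {
            t[b][s] = FIELD; // by default any byte keeps us inside a field
        }
    }
    // Pass 2: override only the special bytes.
    for (uint8_t s = 0; s < NUM_STATES; s++) {
        t[','][s] = DELIM;
        t['\n'][s] = RECORD;
    }
    uint8_t state = FIELD;
    int delims = 0, records = 0;
    for (unsigned char c : std::string("a,b\nc,d\n")) {
        state = t[c][state]; // one unconditional lookup per byte, no branches
        delims += state == DELIM;
        records += state == RECORD;
    }
    std::printf("%d delimiters, %d records\n", delims, records); // 2 and 2
}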
transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_FIRST_BYTE)] = CSVState::DELIMITER; - } else if (delimiter_value.size() == 3) { - if (delimiter_value[0] == delimiter_value[1]) { - transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_SECOND_BYTE)] = CSVState::DELIMITER_SECOND_BYTE; - } - transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_FIRST_BYTE)] = CSVState::DELIMITER_SECOND_BYTE; - transition_array[static_cast(delimiter_value[2])] - [static_cast(CSVState::DELIMITER_SECOND_BYTE)] = CSVState::DELIMITER; - } else if (delimiter_value.size() == 4) { - if (delimiter_value[0] == delimiter_value[2]) { - transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_THIRD_BYTE)] = CSVState::DELIMITER_SECOND_BYTE; - } - if (delimiter_value[0] == delimiter_value[1] && delimiter_value[1] == delimiter_value[2]) { - transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_THIRD_BYTE)] = CSVState::DELIMITER_THIRD_BYTE; - } - transition_array[static_cast(delimiter_value[1])] - [static_cast(CSVState::DELIMITER_FIRST_BYTE)] = CSVState::DELIMITER_SECOND_BYTE; - transition_array[static_cast(delimiter_value[2])] - [static_cast(CSVState::DELIMITER_SECOND_BYTE)] = CSVState::DELIMITER_THIRD_BYTE; - transition_array[static_cast(delimiter_value[3])] - [static_cast(CSVState::DELIMITER_THIRD_BYTE)] = CSVState::DELIMITER; - } - if (enable_unquoted_escape) { - transition_array[escape][static_cast(CSVState::DELIMITER)] = CSVState::UNQUOTED_ESCAPE; - } - - // 3) Record Separator State - if (multi_byte_delimiter) { - transition_array[delimiter_first_byte][static_cast(CSVState::RECORD_SEPARATOR)] = - CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::DELIMITER; - } - transition_array[static_cast('\n')][static_cast(CSVState::RECORD_SEPARATOR)] = - CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::RECORD_SEPARATOR)] = - CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][static_cast(CSVState::RECORD_SEPARATOR)] = - CSVState::RECORD_SEPARATOR; - } - if (quote != '\0') { - transition_array[quote][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::QUOTED; - } - if (delimiter_first_byte != ' ') { - transition_array[' '][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::EMPTY_SPACE; - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::COMMENT; - } - if (enable_unquoted_escape) { - transition_array[escape][static_cast(CSVState::RECORD_SEPARATOR)] = CSVState::UNQUOTED_ESCAPE; - } - - // 4) Carriage Return State - transition_array[static_cast('\n')][static_cast(CSVState::CARRIAGE_RETURN)] = - CSVState::RECORD_SEPARATOR; - transition_array[static_cast('\r')][static_cast(CSVState::CARRIAGE_RETURN)] = - CSVState::CARRIAGE_RETURN; - if (quote != '\0') { - transition_array[quote][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::QUOTED; - } - if (delimiter_first_byte != ' ') { - transition_array[' '][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::EMPTY_SPACE; - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::COMMENT; - } - if (enable_unquoted_escape) { - transition_array[escape][static_cast(CSVState::CARRIAGE_RETURN)] = CSVState::UNQUOTED_ESCAPE; - } - - // 5) 
Quoted State - transition_array[quote][static_cast(CSVState::QUOTED)] = CSVState::UNQUOTED; - transition_array['\n'][static_cast(CSVState::QUOTED)] = CSVState::QUOTED_NEW_LINE; - transition_array['\r'][static_cast(CSVState::QUOTED)] = CSVState::QUOTED_NEW_LINE; - - if (state_machine_options.quote != state_machine_options.escape && - state_machine_options.escape.GetValue() != '\0') { - transition_array[escape][static_cast(CSVState::QUOTED)] = CSVState::ESCAPE; - } - // 6) Unquoted State - transition_array[static_cast('\n')][static_cast(CSVState::UNQUOTED)] = CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::UNQUOTED)] = - CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][static_cast(CSVState::UNQUOTED)] = - CSVState::RECORD_SEPARATOR; - } - if (multi_byte_delimiter) { - transition_array[delimiter_first_byte][static_cast(CSVState::UNQUOTED)] = - CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][static_cast(CSVState::UNQUOTED)] = CSVState::DELIMITER; - } - if (state_machine_options.quote == state_machine_options.escape) { - transition_array[quote][static_cast(CSVState::UNQUOTED)] = CSVState::QUOTED; - } - if (state_machine_options.rfc_4180 == false) { - if (escape == '\0') { - // If escape is defined, it limits a bit how relaxed quotes can be in a reliable way. - transition_array[quote][static_cast(CSVState::UNQUOTED)] = CSVState::MAYBE_QUOTED; - } else { - transition_array[quote][static_cast(CSVState::UNQUOTED)] = CSVState::QUOTED; - } - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::UNQUOTED)] = CSVState::COMMENT; - } - if (delimiter_first_byte != ' ' && quote != ' ' && escape != ' ' && comment != ' ') { - // If space is not a special character, we can safely ignore it in an unquoted state - transition_array[' '][static_cast(CSVState::UNQUOTED)] = CSVState::UNQUOTED; - } - - // 8) Not Set - if (multi_byte_delimiter) { - transition_array[delimiter_first_byte][static_cast(CSVState::NOT_SET)] = - CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][static_cast(CSVState::NOT_SET)] = CSVState::DELIMITER; - } - transition_array[static_cast('\n')][static_cast(CSVState::NOT_SET)] = CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::NOT_SET)] = - CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][static_cast(CSVState::NOT_SET)] = - CSVState::RECORD_SEPARATOR; - } - if (quote != '\0') { - transition_array[quote][static_cast(CSVState::NOT_SET)] = CSVState::QUOTED; - } - if (delimiter_first_byte != ' ') { - transition_array[' '][static_cast(CSVState::NOT_SET)] = CSVState::EMPTY_SPACE; - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::NOT_SET)] = CSVState::COMMENT; - } - if (enable_unquoted_escape) { - transition_array[escape][static_cast(CSVState::NOT_SET)] = CSVState::UNQUOTED_ESCAPE; - } - - // 9) Quoted NewLine - transition_array[quote][static_cast(CSVState::QUOTED_NEW_LINE)] = CSVState::UNQUOTED; - if (state_machine_options.quote != state_machine_options.escape && - state_machine_options.escape.GetValue() != '\0') { - transition_array[escape][static_cast(CSVState::QUOTED_NEW_LINE)] = CSVState::ESCAPE; - } - - // 10) Empty Value State (Not first value) - if (multi_byte_delimiter) { - 
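The QUOTED and UNQUOTED wiring above encodes RFC 4180's doubled-quote escape: inside a quoted field, a quote either closes the field or, when immediately followed by a second quote, stays quoted and yields a literal quote. A minimal unquoting sketch of just that rule (field splitting and the relaxed rfc_4180=false paths are deliberately left out):

#include <cstddef>
#include <cstdio>
#include <string>

// Minimal RFC 4180-style unquoting of one field.
static std::string UnquoteField(const std::string &in) {
    std::string out;
    bool quoted = false; // are we in the QUOTED state?
    for (size_t i = 0; i < in.size(); i++) {
        const char c = in[i];
        if (!quoted) {
            if (c == '"') {
                quoted = true; // -> QUOTED
            } else {
                out += c;
            }
        } else if (c == '"') {
            if (i + 1 < in.size() && in[i + 1] == '"') {
                out += '"'; // doubled quote: emit one quote, stay QUOTED
                i++;
            } else {
                quoted = false; // closing quote: QUOTED -> UNQUOTED
            }
        } else {
            out += c; // inside quotes, delimiters and newlines are plain data
        }
    }
    return out;
}

int main() {
    std::printf("%s\n", UnquoteField("\"say \"\"hi\"\", then leave\"").c_str());
    // prints: say "hi", then leave
}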
transition_array[delimiter_first_byte][static_cast(CSVState::EMPTY_SPACE)] = - CSVState::DELIMITER_FIRST_BYTE; - } else { - transition_array[delimiter_first_byte][static_cast(CSVState::EMPTY_SPACE)] = CSVState::DELIMITER; - } - transition_array[static_cast('\n')][static_cast(CSVState::EMPTY_SPACE)] = - CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::EMPTY_SPACE)] = - CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][static_cast(CSVState::EMPTY_SPACE)] = - CSVState::RECORD_SEPARATOR; - } - if (quote != '\0') { - transition_array[quote][static_cast(CSVState::EMPTY_SPACE)] = CSVState::QUOTED; - } - if (comment != '\0') { - transition_array[comment][static_cast(CSVState::EMPTY_SPACE)] = CSVState::COMMENT; - } - if (enable_unquoted_escape) { - transition_array[escape][static_cast(CSVState::EMPTY_SPACE)] = CSVState::UNQUOTED_ESCAPE; - } - - // 11) Comment State - transition_array[static_cast('\n')][static_cast(CSVState::COMMENT)] = CSVState::RECORD_SEPARATOR; - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::COMMENT)] = - CSVState::CARRIAGE_RETURN; - } else { - transition_array[static_cast('\r')][static_cast(CSVState::COMMENT)] = - CSVState::RECORD_SEPARATOR; - } - - // 12) Unquoted Escape State - if (enable_unquoted_escape) { - // Any character can be escaped, so default to STANDARD - if (new_line_id == NewLineIdentifier::CARRY_ON) { - transition_array[static_cast('\r')][static_cast(CSVState::UNQUOTED_ESCAPE)] = - CSVState::ESCAPED_RETURN; - } - } - - // 13) Escaped Return State - if (enable_unquoted_escape) { - // The new state is STANDARD for \r + \n and \r + ordinary character. - // Other special characters need to be handled. 
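States 12 and 13 above exist so that files outside RFC 4180 can escape special bytes outside of quotes, including an escaped carriage return. A linewise sketch of unquoted escaping, assuming backslash as the escape byte purely for illustration:

#include <cstdio>
#include <string>
#include <vector>

// Split one line on `delim`, honouring an unquoted escape byte.
static std::vector<std::string> SplitEscaped(const std::string &line, char delim, char esc) {
    std::vector<std::string> fields(1);
    bool escaped = false; // analogous to sitting in UNQUOTED_ESCAPE
    for (const char c : line) {
        if (escaped) {
            fields.back() += c; // the escaped byte is always literal
            escaped = false;
        } else if (c == esc) {
            escaped = true;
        } else if (c == delim) {
            fields.emplace_back();
        } else {
            fields.back() += c;
        }
    }
    return fields;
}

int main() {
    for (const auto &field : SplitEscaped("a\\,a,b", ',', '\\')) {
        std::printf("[%s]\n", field.c_str()); // [a,a] then [b]
    }
}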
- transition_array[delimiter_first_byte][static_cast<uint8_t>(CSVState::ESCAPED_RETURN)] = CSVState::DELIMITER;
- if (new_line_id == NewLineIdentifier::CARRY_ON) {
- transition_array[static_cast<uint8_t>('\r')][static_cast<uint8_t>(CSVState::ESCAPED_RETURN)] =
- CSVState::CARRIAGE_RETURN;
- } else {
- transition_array[static_cast<uint8_t>('\r')][static_cast<uint8_t>(CSVState::ESCAPED_RETURN)] =
- CSVState::RECORD_SEPARATOR;
- }
- if (comment != '\0') {
- transition_array[comment][static_cast<uint8_t>(CSVState::ESCAPED_RETURN)] = CSVState::COMMENT;
- }
- transition_array[escape][static_cast<uint8_t>(CSVState::ESCAPED_RETURN)] = CSVState::UNQUOTED_ESCAPE;
- }
-
- // 14) Maybe quoted
- transition_array[quote][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] = CSVState::MAYBE_QUOTED;
-
- transition_array[static_cast<uint8_t>('\n')][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] =
- CSVState::RECORD_SEPARATOR;
- if (new_line_id == NewLineIdentifier::CARRY_ON) {
- transition_array[static_cast<uint8_t>('\r')][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] =
- CSVState::CARRIAGE_RETURN;
- } else {
- transition_array[static_cast<uint8_t>('\r')][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] =
- CSVState::RECORD_SEPARATOR;
- }
- if (multi_byte_delimiter) {
- transition_array[delimiter_first_byte][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] =
- CSVState::DELIMITER_FIRST_BYTE;
- } else {
- transition_array[delimiter_first_byte][static_cast<uint8_t>(CSVState::MAYBE_QUOTED)] = CSVState::DELIMITER;
- }
-
- // Initialize characters we can skip during processing, for Standard and Quoted states
- for (idx_t i = 0; i < StateMachine::NUM_TRANSITIONS; i++) {
- transition_array.skip_standard[i] = true;
- transition_array.skip_quoted[i] = true;
- transition_array.skip_comment[i] = true;
- }
- // For standard states we only care for delimiters \r and \n
- transition_array.skip_standard[delimiter_first_byte] = false;
- transition_array.skip_standard[static_cast<uint8_t>('\n')] = false;
- transition_array.skip_standard[static_cast<uint8_t>('\r')] = false;
- transition_array.skip_standard[comment] = false;
- if (enable_unquoted_escape) {
- transition_array.skip_standard[escape] = false;
- }
-
- // For quoted we only care about quote, escape and for delimiters \r and \n
- transition_array.skip_quoted[quote] = false;
- transition_array.skip_quoted[escape] = false;
- transition_array.skip_quoted[static_cast<uint8_t>('\n')] = false;
- transition_array.skip_quoted[static_cast<uint8_t>('\r')] = false;
-
- transition_array.skip_comment[static_cast<uint8_t>('\r')] = false;
- transition_array.skip_comment[static_cast<uint8_t>('\n')] = false;
-
- transition_array.delimiter = delimiter_first_byte;
- transition_array.new_line = static_cast<uint64_t>('\n');
- transition_array.carriage_return = static_cast<uint64_t>('\r');
- transition_array.quote = quote;
- transition_array.escape = escape;
-
- // Shift and OR to replicate across all bytes
- ShiftAndReplicateBits(transition_array.delimiter);
- ShiftAndReplicateBits(transition_array.new_line);
- ShiftAndReplicateBits(transition_array.carriage_return);
- ShiftAndReplicateBits(transition_array.quote);
- ShiftAndReplicateBits(transition_array.escape);
- ShiftAndReplicateBits(transition_array.comment);
-}
-
-CSVStateMachineCache::CSVStateMachineCache() {
- auto default_quote = DialectCandidates::GetDefaultQuote();
- auto default_escape = DialectCandidates::GetDefaultEscape();
- auto default_quote_rule = DialectCandidates::GetDefaultQuoteRule();
- auto default_delimiter = DialectCandidates::GetDefaultDelimiter();
- auto default_comment = DialectCandidates::GetDefaultComment();
-
- for (auto quote_rule : default_quote_rule) {
- const auto &quote_candidates = default_quote[static_cast<uint8_t>(quote_rule)];
- for (const auto &quote : quote_candidates) {
- for (const auto &delimiter : default_delimiter) {
- const auto &escape_candidates = default_escape[static_cast<uint8_t>(quote_rule)];
- for (const auto &escape : escape_candidates) {
- for (const auto &comment : default_comment) {
- for (const bool rfc_4180 : {true, false}) {
- Insert({delimiter, quote, escape, comment, NewLineIdentifier::SINGLE_N, rfc_4180});
- Insert({delimiter, quote, escape, comment, NewLineIdentifier::SINGLE_R, rfc_4180});
- Insert({delimiter, quote, escape, comment, NewLineIdentifier::CARRY_ON, rfc_4180});
- }
- }
- }
- }
- }
- }
-}
-
-const StateMachine &CSVStateMachineCache::Get(const CSVStateMachineOptions &state_machine_options) {
- // Custom State Machine, we need to create it and cache it first
- lock_guard<mutex> parallel_lock(main_mutex);
- if (state_machine_cache.find(state_machine_options) == state_machine_cache.end()) {
- Insert(state_machine_options);
- }
- const auto &transition_array = state_machine_cache[state_machine_options];
- return transition_array;
-}
-
-CSVStateMachineCache &CSVStateMachineCache::Get(ClientContext &context) {
-
- auto &cache = ObjectCache::GetObjectCache(context);
- return *cache.GetOrCreate<CSVStateMachineCache>(CSVStateMachineCache::ObjectType());
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp
deleted file mode 100644
index 093e698d0..000000000
--- a/src/duckdb/src/execution/operator/csv_scanner/table_function/csv_file_scanner.cpp
+++ /dev/null
@@ -1,241 +0,0 @@
-#include "duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp"
-
-#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp"
-#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp"
-#include "duckdb/function/table/read_csv.hpp"
-
-namespace duckdb {
-CSVUnionData::~CSVUnionData() {
-}
-
-CSVFileScan::CSVFileScan(ClientContext &context, shared_ptr<CSVBufferManager> buffer_manager_p,
- shared_ptr<CSVStateMachine> state_machine_p, const CSVReaderOptions &options_p,
- const ReadCSVData &bind_data, const vector<ColumnIndex> &column_ids, CSVSchema &file_schema)
- : file_path(options_p.file_path), file_idx(0), buffer_manager(std::move(buffer_manager_p)),
- state_machine(std::move(state_machine_p)), file_size(buffer_manager->file_handle->FileSize()),
- error_handler(make_shared_ptr<CSVErrorHandler>(options_p.ignore_errors.GetValue())),
- on_disk_file(buffer_manager->file_handle->OnDiskFile()), options(options_p) {
-
- auto multi_file_reader = MultiFileReader::CreateDefault("CSV Scan");
- if (bind_data.initial_reader.get()) {
- auto &union_reader = *bind_data.initial_reader;
- names = union_reader.GetNames();
- options = union_reader.options;
- types = union_reader.GetTypes();
- multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
- bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
- InitializeFileNamesTypes();
- return;
- }
- if (!bind_data.column_info.empty()) {
- // Serialized Union By name
- names = bind_data.column_info[0].names;
- types = bind_data.column_info[0].types;
- multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types,
- bind_data.return_names, column_ids, nullptr, file_path, context, nullptr);
- InitializeFileNamesTypes();
- return;
- }
- names = bind_data.csv_names;
- types = bind_data.csv_types;
- file_schema.Initialize(names, types, file_path);
- multi_file_reader->InitializeReader(*this,
options.file_options, bind_data.reader_bind, bind_data.return_types, - bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); - - InitializeFileNamesTypes(); - SetStart(); -} - -void CSVFileScan::SetStart() { - idx_t rows_to_skip = options.GetSkipRows() + state_machine->dialect_options.header.GetValue(); - rows_to_skip = std::max(rows_to_skip, state_machine->dialect_options.rows_until_header + - state_machine->dialect_options.header.GetValue()); - if (rows_to_skip == 0) { - start_iterator.first_one = true; - return; - } - SkipScanner skip_scanner(buffer_manager, state_machine, error_handler, rows_to_skip); - skip_scanner.ParseChunk(); - start_iterator = skip_scanner.GetIterator(); -} - -CSVFileScan::CSVFileScan(ClientContext &context, const string &file_path_p, const CSVReaderOptions &options_p, - idx_t file_idx_p, const ReadCSVData &bind_data, const vector &column_ids, - CSVSchema &file_schema, bool per_file_single_threaded) - : file_path(file_path_p), file_idx(file_idx_p), - error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { - auto multi_file_reader = MultiFileReader::CreateDefault("CSV Scan"); - if (file_idx == 0 && bind_data.initial_reader) { - auto &union_reader = *bind_data.initial_reader; - // Initialize Buffer Manager - buffer_manager = union_reader.buffer_manager; - // Initialize On Disk and Size of file - on_disk_file = union_reader.on_disk_file; - file_size = union_reader.file_size; - names = union_reader.GetNames(); - options = union_reader.options; - types = union_reader.GetTypes(); - state_machine = union_reader.state_machine; - multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, - bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); - - InitializeFileNamesTypes(); - SetStart(); - return; - } - - // Initialize Buffer Manager - buffer_manager = make_shared_ptr(context, options, file_path, file_idx, per_file_single_threaded); - // Initialize On Disk and Size of file - on_disk_file = buffer_manager->file_handle->OnDiskFile(); - file_size = buffer_manager->file_handle->FileSize(); - // Initialize State Machine - auto &state_machine_cache = CSVStateMachineCache::Get(context); - - if (file_idx < bind_data.column_info.size()) { - // (Serialized) Union By name - names = bind_data.column_info[file_idx].names; - types = bind_data.column_info[file_idx].types; - if (file_idx < bind_data.union_readers.size()) { - // union readers - use cached options - D_ASSERT(names == bind_data.union_readers[file_idx]->names); - D_ASSERT(types == bind_data.union_readers[file_idx]->types); - options = bind_data.union_readers[file_idx]->options; - } else { - // Serialized union by name - sniff again - options.dialect_options.num_cols = names.size(); - if (options.auto_detect) { - CSVSniffer sniffer(options, buffer_manager, state_machine_cache); - sniffer.SniffCSV(); - } - } - state_machine = make_shared_ptr( - state_machine_cache.Get(options.dialect_options.state_machine_options), options); - - multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, - bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); - InitializeFileNamesTypes(); - SetStart(); - return; - } - // Sniff it! 
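SetStart above reduces "where does the data begin" to integer arithmetic: the user-requested skip plus one line for the header, but never less than the junk rows the sniffer already found above the header. The same computation in isolation, with invented sample values:

#include <algorithm>
#include <cstdio>

int main() {
    const int user_skip_rows = 2;    // e.g. skip=2 supplied by the user
    const bool has_header = true;    // header=true
    const int rows_until_header = 3; // sniffer: three junk lines precede the header
    int rows_to_skip = user_skip_rows + (has_header ? 1 : 0);
    rows_to_skip = std::max(rows_to_skip, rows_until_header + (has_header ? 1 : 0));
    std::printf("skip %d rows before parsing data\n", rows_to_skip); // 4
}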
- names = bind_data.csv_names; - types = bind_data.csv_types; - if (options.auto_detect && bind_data.files.size() > 1) { - if (file_schema.Empty()) { - CSVSniffer sniffer(options, buffer_manager, state_machine_cache); - auto result = sniffer.SniffCSV(); - file_schema.Initialize(bind_data.csv_names, bind_data.csv_types, options.file_path); - } else if (file_idx > 0 && buffer_manager->file_handle->FileSize() > 0) { - options.file_path = file_path; - CSVSniffer sniffer(options, buffer_manager, state_machine_cache, false); - auto result = sniffer.AdaptiveSniff(file_schema); - names = result.names; - types = result.return_types; - } - } - if (options.dialect_options.num_cols == 0) { - // We need to define the number of columns, if the sniffer is not running this must be in the sql_type_list - options.dialect_options.num_cols = options.sql_type_list.size(); - } - if (options.dialect_options.state_machine_options.new_line == NewLineIdentifier::NOT_SET) { - options.dialect_options.state_machine_options.new_line = CSVSniffer::DetectNewLineDelimiter(*buffer_manager); - } - state_machine = make_shared_ptr( - state_machine_cache.Get(options.dialect_options.state_machine_options), options); - multi_file_reader->InitializeReader(*this, options.file_options, bind_data.reader_bind, bind_data.return_types, - bind_data.return_names, column_ids, nullptr, file_path, context, nullptr); - InitializeFileNamesTypes(); - SetStart(); -} - -CSVFileScan::CSVFileScan(ClientContext &context, const string &file_name, const CSVReaderOptions &options_p) - : file_path(file_name), file_idx(0), - error_handler(make_shared_ptr(options_p.ignore_errors.GetValue())), options(options_p) { - buffer_manager = make_shared_ptr(context, options, file_path, file_idx); - // Initialize On Disk and Size of file - on_disk_file = buffer_manager->file_handle->OnDiskFile(); - file_size = buffer_manager->file_handle->FileSize(); - // Sniff it (We only really care about dialect detection, if types or number of columns are different this will - // error out during scanning) - auto &state_machine_cache = CSVStateMachineCache::Get(context); - // We sniff file if it has not been sniffed yet and either auto-detect is on, or union by name is on - if ((options.auto_detect || options.file_options.union_by_name) && options.dialect_options.num_cols == 0) { - CSVSniffer sniffer(options, buffer_manager, state_machine_cache); - auto sniffer_result = sniffer.SniffCSV(); - if (names.empty()) { - names = sniffer_result.names; - types = sniffer_result.return_types; - } - } - if (options.dialect_options.num_cols == 0) { - // We need to define the number of columns, if the sniffer is not running this must be in the sql_type_list - options.dialect_options.num_cols = options.sql_type_list.size(); - } - // Initialize State Machine - state_machine = make_shared_ptr( - state_machine_cache.Get(options.dialect_options.state_machine_options), options); - SetStart(); -} - -void CSVFileScan::InitializeFileNamesTypes() { - if (reader_data.empty_columns && reader_data.column_ids.empty()) { - // This means that the columns from this file are irrelevant. 
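For multi-file scans, the constructor above sniffs the first file in full and then only adaptively re-checks later files against the schema it already trusts. A toy version of that control flow, with sniffing shrunk down to counting delimiters in the first line:

#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

// Stand-in for the sniffer: derive a "schema" (here just a column count).
static size_t SniffColumnCount(const std::string &first_line, char delim) {
    size_t cols = 1;
    for (const char c : first_line) {
        cols += c == delim;
    }
    return cols;
}

int main() {
    // First line of each file in the scan; the values are invented.
    const std::vector<std::string> first_lines = {"a,b,c", "1,2,3", "x,y"};
    size_t schema_cols = 0;
    for (size_t file_idx = 0; file_idx < first_lines.size(); file_idx++) {
        const size_t cols = SniffColumnCount(first_lines[file_idx], ',');
        if (file_idx == 0) {
            schema_cols = cols; // the first file defines the trusted schema
        } else if (cols != schema_cols) {
            // adaptive re-check failed: this file deviates from the schema
            std::printf("file %zu has %zu columns, expected %zu\n", file_idx, cols, schema_cols);
        }
    }
}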
- // just read the first column - file_types.emplace_back(LogicalType::VARCHAR); - projected_columns.insert(0); - projection_ids.emplace_back(0, 0); - return; - } - - for (idx_t i = 0; i < reader_data.column_ids.size(); i++) { - idx_t result_idx = reader_data.column_ids[i]; - file_types.emplace_back(types[result_idx]); - projected_columns.insert(result_idx); - projection_ids.emplace_back(result_idx, i); - } - - if (reader_data.column_ids.empty()) { - file_types = types; - } - - // We need to be sure that our types are also following the cast_map - if (!reader_data.cast_map.empty()) { - for (idx_t i = 0; i < reader_data.column_ids.size(); i++) { - if (reader_data.cast_map.find(reader_data.column_ids[i]) != reader_data.cast_map.end()) { - file_types[i] = reader_data.cast_map[reader_data.column_ids[i]]; - } - } - } - - // We sort the types on the order of the parsed chunk - std::sort(projection_ids.begin(), projection_ids.end()); - vector sorted_types; - for (idx_t i = 0; i < projection_ids.size(); ++i) { - sorted_types.push_back(file_types[projection_ids[i].second]); - } - file_types = sorted_types; -} - -const string &CSVFileScan::GetFileName() const { - return file_path; -} -const vector &CSVFileScan::GetNames() { - return names; -} -const vector &CSVFileScan::GetTypes() { - return types; -} - -void CSVFileScan::InitializeProjection() { - for (idx_t i = 0; i < options.dialect_options.num_cols; i++) { - reader_data.column_ids.push_back(i); - reader_data.column_mapping.push_back(i); - } -} - -void CSVFileScan::Finish() { - buffer_manager.reset(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp b/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp deleted file mode 100644 index b1a4a6165..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/table_function/global_csv_state.cpp +++ /dev/null @@ -1,307 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/global_csv_state.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/csv_scanner/scanner_boundary.hpp" -#include "duckdb/execution/operator/csv_scanner/skip_scanner.hpp" -#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" -#include "duckdb/main/appender.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -CSVGlobalState::CSVGlobalState(ClientContext &context_p, const shared_ptr &buffer_manager, - const CSVReaderOptions &options, idx_t system_threads_p, const vector &files, - vector column_ids_p, const ReadCSVData &bind_data_p) - : context(context_p), system_threads(system_threads_p), column_ids(std::move(column_ids_p)), - sniffer_mismatch_error(options.sniffer_user_mismatch_error), bind_data(bind_data_p) { - - if (buffer_manager && buffer_manager->GetFilePath() == files[0]) { - auto state_machine = make_shared_ptr( - CSVStateMachineCache::Get(context).Get(options.dialect_options.state_machine_options), options); - // If we already have a buffer manager, we don't need to reconstruct it to the first file - file_scans.emplace_back(make_uniq(context, buffer_manager, state_machine, options, bind_data, - column_ids, file_schema)); - } else { - // If not we need to construct it for the first file - file_scans.emplace_back( - make_uniq(context, files[0], options, 0U, bind_data, column_ids, file_schema, false)); - } - idx_t cur_file_idx = 0; - while (file_scans.back()->start_iterator.done && file_scans.size() < files.size()) { - 
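InitializeFileNamesTypes above records (file column, projected slot) pairs and re-sorts them so the projected types line up with on-disk column order no matter what order the query requested. The reorder on its own, with hypothetical type names:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    // On-disk schema of the file (hypothetical type names).
    const std::vector<std::string> types = {"INT", "VARCHAR", "DATE", "DOUBLE"};
    // The query asks for file column 2 first, then file column 0.
    const std::vector<size_t> column_ids = {2, 0};

    std::vector<std::pair<size_t, size_t>> projection_ids; // (file column, projected slot)
    std::vector<std::string> file_types;
    for (size_t i = 0; i < column_ids.size(); i++) {
        file_types.push_back(types[column_ids[i]]);
        projection_ids.emplace_back(column_ids[i], i);
    }
    std::sort(projection_ids.begin(), projection_ids.end()); // back to on-disk order
    std::vector<std::string> sorted_types;
    for (const auto &p : projection_ids) {
        sorted_types.push_back(file_types[p.second]);
    }
    for (const auto &t : sorted_types) {
        std::printf("%s ", t.c_str()); // INT DATE
    }
    std::printf("\n");
}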
cur_file_idx++; - file_scans.emplace_back(make_uniq(context, files[cur_file_idx], options, cur_file_idx, bind_data, - column_ids, file_schema, false)); - } - // There are situations where we only support single threaded scanning - bool many_csv_files = files.size() > 1 && files.size() > system_threads * 2; - single_threaded = many_csv_files || !options.parallel; - last_file_idx = 0; - scanner_idx = 0; - running_threads = CSVGlobalState::MaxThreads(); - current_boundary = file_scans.back()->start_iterator; - current_boundary.SetCurrentBoundaryToPosition(single_threaded, options); - if (current_boundary.done && context.client_data->debug_set_max_line_length) { - context.client_data->debug_max_line_length = current_boundary.pos.buffer_pos; - } - current_buffer_in_use = - make_shared_ptr(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); -} - -bool CSVGlobalState::IsDone() const { - lock_guard parallel_lock(main_mutex); - return current_boundary.done; -} - -double CSVGlobalState::GetProgress(const ReadCSVData &bind_data_p) const { - lock_guard parallel_lock(main_mutex); - idx_t total_files = bind_data.files.size(); - // get the progress WITHIN the current file - double percentage = 0; - if (file_scans.front()->file_size == 0) { - percentage = 1.0; - } else { - // for compressed files, read bytes may greater than files size. - for (auto &file : file_scans) { - double file_progress; - if (!file->buffer_manager) { - // We are done with this file, so it's 100% - file_progress = 1.0; - } else if (file->buffer_manager->file_handle->compression_type == FileCompressionType::GZIP || - file->buffer_manager->file_handle->compression_type == FileCompressionType::ZSTD) { - // This file is not done, and is a compressed file - file_progress = file->buffer_manager->file_handle->GetProgress(); - } else { - file_progress = static_cast(file->bytes_read); - } - // This file is an uncompressed file, so we use the more price bytes_read from the scanner - percentage += (static_cast(1) / static_cast(total_files)) * - std::min(1.0, file_progress / static_cast(file->file_size)); - } - } - return percentage * 100; -} - -unique_ptr CSVGlobalState::Next(optional_ptr previous_scanner) { - if (previous_scanner) { - // We have to insert information for validation - lock_guard parallel_lock(main_mutex); - validator.Insert(previous_scanner->csv_file_scan->file_idx, previous_scanner->scanner_idx, - previous_scanner->GetValidationLine()); - } - if (single_threaded) { - { - lock_guard parallel_lock(main_mutex); - if (previous_scanner) { - // Cleanup previous scanner. 
- previous_scanner->buffer_tracker.reset(); - current_buffer_in_use.reset(); - previous_scanner->csv_file_scan->Finish(); - } - } - idx_t cur_idx; - bool empty_file = false; - do { - { - lock_guard parallel_lock(main_mutex); - cur_idx = last_file_idx++; - if (cur_idx >= bind_data.files.size()) { - // No more files to scan - return nullptr; - } - if (cur_idx == 0) { - D_ASSERT(!previous_scanner); - auto current_file = file_scans.front(); - return make_uniq(scanner_idx++, current_file->buffer_manager, - current_file->state_machine, current_file->error_handler, - current_file, false, current_boundary); - } - } - auto file_scan = make_shared_ptr(context, bind_data.files[cur_idx], bind_data.options, cur_idx, - bind_data, column_ids, file_schema, true); - empty_file = file_scan->file_size == 0; - - if (!empty_file) { - lock_guard parallel_lock(main_mutex); - file_scans.emplace_back(std::move(file_scan)); - auto current_file = file_scans.back(); - current_boundary = current_file->start_iterator; - current_boundary.SetCurrentBoundaryToPosition(single_threaded, bind_data.options); - current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, - current_boundary.GetBufferIdx()); - - return make_uniq(scanner_idx++, current_file->buffer_manager, - current_file->state_machine, current_file->error_handler, - current_file, false, current_boundary); - } - } while (empty_file); - } - lock_guard parallel_lock(main_mutex); - if (finished) { - return nullptr; - } - if (current_buffer_in_use->buffer_idx != current_boundary.GetBufferIdx()) { - current_buffer_in_use = - make_shared_ptr(*file_scans.back()->buffer_manager, current_boundary.GetBufferIdx()); - } - // We first create the scanner for the current boundary - auto ¤t_file = *file_scans.back(); - auto csv_scanner = - make_uniq(scanner_idx++, current_file.buffer_manager, current_file.state_machine, - current_file.error_handler, file_scans.back(), false, current_boundary); - threads_per_file[csv_scanner->csv_file_scan->file_idx]++; - if (previous_scanner) { - threads_per_file[previous_scanner->csv_file_scan->file_idx]--; - if (threads_per_file[previous_scanner->csv_file_scan->file_idx] == 0) { - previous_scanner->buffer_tracker.reset(); - previous_scanner->csv_file_scan->Finish(); - } - } - csv_scanner->buffer_tracker = current_buffer_in_use; - - // We then produce the next boundary - if (!current_boundary.Next(*current_file.buffer_manager, bind_data.options)) { - // This means we are done scanning the current file - do { - auto current_file_idx = file_scans.back()->file_idx + 1; - if (current_file_idx < bind_data.files.size()) { - // If we have a next file we have to construct the file scan for that - file_scans.emplace_back(make_shared_ptr(context, bind_data.files[current_file_idx], - bind_data.options, current_file_idx, bind_data, - column_ids, file_schema, false)); - // And re-start the boundary-iterator - current_boundary = file_scans.back()->start_iterator; - current_boundary.SetCurrentBoundaryToPosition(single_threaded, bind_data.options); - current_buffer_in_use = make_shared_ptr(*file_scans.back()->buffer_manager, - current_boundary.GetBufferIdx()); - } else { - // If not we are done with this CSV Scanning - finished = true; - break; - } - } while (current_boundary.done); - } - // We initialize the scan - return csv_scanner; -} - -idx_t CSVGlobalState::MaxThreads() const { - // We initialize max one thread per our set bytes per thread limit - if (single_threaded || !file_scans.front()->on_disk_file) { - return system_threads; - } - 
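Next above is, at its core, a lock-protected dispenser of scanner boundaries: each caller claims the next unprocessed byte range, and the iterator rolls over to the next file once the current one runs out. A stripped-down single-file version of that pattern:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <mutex>
#include <optional>
#include <utility>

struct BoundaryDispenser {
    std::mutex m;
    size_t next_start = 0;
    size_t file_size = 0;
    size_t bytes_per_thread = 0;

    // Hand the caller the next unclaimed [start, end) byte range, or nothing.
    std::optional<std::pair<size_t, size_t>> Next() {
        std::lock_guard<std::mutex> lock(m);
        if (next_start >= file_size) {
            return std::nullopt; // file exhausted: caller should move on
        }
        const size_t start = next_start;
        const size_t end = std::min(file_size, start + bytes_per_thread);
        next_start = end;
        return std::make_pair(start, end);
    }
};

int main() {
    BoundaryDispenser dispenser;
    dispenser.file_size = 10;
    dispenser.bytes_per_thread = 4;
    while (auto boundary = dispenser.Next()) {
        std::printf("[%zu, %zu)\n", boundary->first, boundary->second); // [0,4) [4,8) [8,10)
    }
}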
const idx_t bytes_per_thread = CSVIterator::BytesPerThread(file_scans.front()->options); - const idx_t total_threads = file_scans.front()->file_size / bytes_per_thread + 1; - if (total_threads < system_threads) { - return total_threads; - } - return system_threads; -} - -void CSVGlobalState::DecrementThread() { - lock_guard parallel_lock(main_mutex); - D_ASSERT(running_threads > 0); - running_threads--; - if (running_threads == 0) { - const bool ignore_or_store_errors = - bind_data.options.ignore_errors.GetValue() || bind_data.options.store_rejects.GetValue(); - if (!single_threaded && !ignore_or_store_errors) { - // If we are running multithreaded and not ignoring errors, we must run the validator - validator.Verify(); - } - for (const auto &file : file_scans) { - file->error_handler->ErrorIfNeeded(); - } - FillRejectsTable(); - if (context.client_data->debug_set_max_line_length) { - context.client_data->debug_max_line_length = file_scans[0]->error_handler->GetMaxLineLength(); - } - } -} - -void FillScanErrorTable(InternalAppender &scan_appender, idx_t scan_idx, idx_t file_idx, CSVFileScan &file) { - CSVReaderOptions &options = file.options; - // Add the row to the rejects table - scan_appender.BeginRow(); - // 1. Scan Idx - scan_appender.Append(scan_idx); - // 2. File Idx - scan_appender.Append(file_idx); - // 3. File Path - scan_appender.Append(string_t(file.file_path)); - // 4. Delimiter - scan_appender.Append(string_t(options.dialect_options.state_machine_options.delimiter.FormatValue())); - // 5. Quote - scan_appender.Append(string_t(options.dialect_options.state_machine_options.quote.FormatValue())); - // 6. Escape - scan_appender.Append(string_t(options.dialect_options.state_machine_options.escape.FormatValue())); - // 7. NewLine Delimiter - scan_appender.Append(string_t(options.NewLineIdentifierToString())); - // 8. Skip Rows - scan_appender.Append(Value::UINTEGER(NumericCast(options.dialect_options.skip_rows.GetValue()))); - // 9. Has Header - scan_appender.Append(Value::BOOLEAN(options.dialect_options.header.GetValue())); - // 10. List> {'col1': 'INTEGER', 'col2': 'VARCHAR'} - std::ostringstream columns; - columns << "{"; - for (idx_t i = 0; i < file.types.size(); i++) { - columns << "'" << file.names[i] << "': '" << file.types[i].ToString() << "'"; - if (i != file.types.size() - 1) { - columns << ","; - } - } - columns << "}"; - scan_appender.Append(string_t(columns.str())); - // 11. Date Format - auto date_format = options.dialect_options.date_format[LogicalType::DATE].GetValue(); - if (!date_format.Empty()) { - scan_appender.Append(string_t(date_format.format_specifier)); - } else { - scan_appender.Append(Value()); - } - - // 12. Timestamp Format - auto timestamp_format = options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue(); - if (!timestamp_format.Empty()) { - scan_appender.Append(string_t(timestamp_format.format_specifier)); - } else { - scan_appender.Append(Value()); - } - - // 13. 
The Extra User Arguments - if (options.user_defined_parameters.empty()) { - scan_appender.Append(Value()); - } else { - scan_appender.Append(string_t(options.user_defined_parameters)); - } - // Finish the row to the rejects table - scan_appender.EndRow(); -} - -void CSVGlobalState::FillRejectsTable() const { - auto &options = bind_data.options; - - if (options.store_rejects.GetValue()) { - auto limit = options.rejects_limit; - auto rejects = CSVRejectsTable::GetOrCreate(context, options.rejects_scan_name.GetValue(), - options.rejects_table_name.GetValue()); - lock_guard lock(rejects->write_lock); - auto &errors_table = rejects->GetErrorsTable(context); - auto &scans_table = rejects->GetScansTable(context); - InternalAppender errors_appender(context, errors_table); - InternalAppender scans_appender(context, scans_table); - idx_t scan_idx = context.transaction.GetActiveQuery(); - for (auto &file : file_scans) { - const idx_t file_idx = rejects->GetCurrentFileIndex(scan_idx); - auto file_name = file->file_path; - file->error_handler->FillRejectsTable(errors_appender, file_idx, scan_idx, *file, *rejects, bind_data, - limit); - if (rejects->count != 0) { - rejects->count = 0; - FillScanErrorTable(scans_appender, scan_idx, file_idx, *file); - } - } - errors_appender.Close(); - scans_appender.Close(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp deleted file mode 100644 index f77a93ac0..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_error.cpp +++ /dev/null @@ -1,552 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_error.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/table/read_csv.hpp" -#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp" -#include "duckdb/main/appender.hpp" -#include - -namespace duckdb { - -LinesPerBoundary::LinesPerBoundary() { -} -LinesPerBoundary::LinesPerBoundary(idx_t boundary_idx_p, idx_t lines_in_batch_p) - : boundary_idx(boundary_idx_p), lines_in_batch(lines_in_batch_p) { -} - -CSVErrorHandler::CSVErrorHandler(bool ignore_errors_p) : ignore_errors(ignore_errors_p) { -} - -void CSVErrorHandler::ThrowError(const CSVError &csv_error) { - std::ostringstream error; - if (PrintLineNumber(csv_error)) { - error << "CSV Error on Line: " << GetLineInternal(csv_error.error_info) << '\n'; - if (!csv_error.csv_row.empty()) { - error << "Original Line: " << csv_error.csv_row << '\n'; - } - } - if (csv_error.full_error_message.empty()) { - error << csv_error.error_message; - } else { - error << csv_error.full_error_message; - } - - switch (csv_error.type) { - case CAST_ERROR: - throw ConversionException(error.str()); - case COLUMN_NAME_TYPE_MISMATCH: - throw BinderException(error.str()); - case NULLPADDED_QUOTED_NEW_VALUE: - throw ParameterNotAllowedException(error.str()); - default: - throw InvalidInputException(error.str()); - } -} - -void CSVErrorHandler::Error(const CSVError &csv_error, bool force_error) { - lock_guard parallel_lock(main_mutex); - if ((ignore_errors && !force_error) || (PrintLineNumber(csv_error) && !CanGetLine(csv_error.GetBoundaryIndex()))) { - // We store this error, we can't throw it now, or we are ignoring it - errors.push_back(csv_error); - return; - } - // Otherwise we can throw directly - ThrowError(csv_error); -} - 
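Column 10 of the scans row above serialises the schema into a struct-like string such as {'col1': 'INTEGER','col2': 'VARCHAR'}. The string building lifted out on its own, with sample values:

#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
    const std::vector<std::string> names = {"id", "name"};
    const std::vector<std::string> types = {"INTEGER", "VARCHAR"};
    std::ostringstream columns;
    columns << "{";
    for (size_t i = 0; i < types.size(); i++) {
        columns << "'" << names[i] << "': '" << types[i] << "'";
        if (i != types.size() - 1) {
            columns << ",";
        }
    }
    columns << "}";
    std::cout << columns.str() << '\n'; // {'id': 'INTEGER','name': 'VARCHAR'}
}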
-void CSVErrorHandler::ErrorIfNeeded() { - lock_guard parallel_lock(main_mutex); - if (ignore_errors || errors.empty()) { - // Nothing to error - return; - } - - if (CanGetLine(errors[0].error_info.boundary_idx)) { - ThrowError(errors[0]); - } -} - -void CSVErrorHandler::ErrorIfTypeExists(CSVErrorType error_type) { - lock_guard parallel_lock(main_mutex); - for (auto &error : errors) { - if (error.type == error_type) { - // If it's a maximum line size error, we can do it now. - ThrowError(error); - } - } -} - -void CSVErrorHandler::Insert(idx_t boundary_idx, idx_t rows) { - lock_guard parallel_lock(main_mutex); - if (lines_per_batch_map.find(boundary_idx) == lines_per_batch_map.end()) { - lines_per_batch_map[boundary_idx] = {boundary_idx, rows}; - } else { - lines_per_batch_map[boundary_idx].lines_in_batch += rows; - } -} - -void CSVErrorHandler::NewMaxLineSize(idx_t scan_line_size) { - lock_guard parallel_lock(main_mutex); - max_line_length = std::max(scan_line_size, max_line_length); -} - -bool CSVErrorHandler::AnyErrors() { - lock_guard parallel_lock(main_mutex); - return !errors.empty(); -} - -bool CSVErrorHandler::HasError(const CSVErrorType error_type) { - lock_guard parallel_lock(main_mutex); - for (const auto &er : errors) { - if (er.type == error_type) { - return true; - } - } - return false; -} - -idx_t CSVErrorHandler::GetSize() { - lock_guard parallel_lock(main_mutex); - return errors.size(); -} - -bool IsCSVErrorAcceptedReject(CSVErrorType type) { - switch (type) { - case CSVErrorType::INVALID_STATE: - case CSVErrorType::CAST_ERROR: - case CSVErrorType::TOO_MANY_COLUMNS: - case CSVErrorType::TOO_FEW_COLUMNS: - case CSVErrorType::MAXIMUM_LINE_SIZE: - case CSVErrorType::UNTERMINATED_QUOTES: - case CSVErrorType::INVALID_UNICODE: - return true; - default: - return false; - } -} -string CSVErrorTypeToEnum(CSVErrorType type) { - switch (type) { - case CSVErrorType::CAST_ERROR: - return "CAST"; - case CSVErrorType::TOO_FEW_COLUMNS: - return "MISSING COLUMNS"; - case CSVErrorType::TOO_MANY_COLUMNS: - return "TOO MANY COLUMNS"; - case CSVErrorType::MAXIMUM_LINE_SIZE: - return "LINE SIZE OVER MAXIMUM"; - case CSVErrorType::UNTERMINATED_QUOTES: - return "UNQUOTED VALUE"; - case CSVErrorType::INVALID_UNICODE: - return "INVALID UNICODE"; - case CSVErrorType::INVALID_STATE: - return "INVALID STATE"; - default: - throw InternalException("CSV Error is not valid to be stored in a Rejects Table"); - } -} - -void CSVErrorHandler::FillRejectsTable(InternalAppender &errors_appender, const idx_t file_idx, const idx_t scan_idx, - const CSVFileScan &file, CSVRejectsTable &rejects, const ReadCSVData &bind_data, - const idx_t limit) { - lock_guard parallel_lock(main_mutex); - // We first insert the file into the file scans table - for (auto &error : file.error_handler->errors) { - if (!IsCSVErrorAcceptedReject(error.type)) { - continue; - } - // short circuit if we already have too many rejects - if (limit == 0 || rejects.count < limit) { - if (limit != 0 && rejects.count >= limit) { - break; - } - rejects.count++; - const auto row_line = file.error_handler->GetLineInternal(error.error_info); - const auto col_idx = error.column_idx; - // Add the row to the rejects table - errors_appender.BeginRow(); - // 1. Scan ID - errors_appender.Append(scan_idx); - // 2. File ID - errors_appender.Append(file_idx); - // 3. Row Line - errors_appender.Append(row_line); - // 4. Byte Position of the row error - errors_appender.Append(error.row_byte_position + 1); - // 5. 
Byte Position where error occurred - if (!error.byte_position.IsValid()) { - // This means this error comes from a flush, and we don't support this yet, so we give it - // a null - errors_appender.Append(Value()); - } else { - errors_appender.Append(error.byte_position.GetIndex() + 1); - } - // 6. Column Index - if (error.type == CSVErrorType::MAXIMUM_LINE_SIZE) { - errors_appender.Append(Value()); - } else { - errors_appender.Append(col_idx + 1); - } - // 7. Column Name (If Applicable) - switch (error.type) { - case CSVErrorType::TOO_MANY_COLUMNS: - case CSVErrorType::MAXIMUM_LINE_SIZE: - errors_appender.Append(Value()); - break; - case CSVErrorType::TOO_FEW_COLUMNS: - D_ASSERT(bind_data.return_names.size() > col_idx + 1); - errors_appender.Append(string_t(bind_data.return_names[col_idx + 1])); - break; - default: - errors_appender.Append(string_t(bind_data.return_names[col_idx])); - } - // 8. Error Type - errors_appender.Append(string_t(CSVErrorTypeToEnum(error.type))); - // 9. Original CSV Line - errors_appender.Append(string_t(error.csv_row)); - // 10. Full Error Message - errors_appender.Append(string_t(error.error_message)); - errors_appender.EndRow(); - } - } -} - -idx_t CSVErrorHandler::GetMaxLineLength() { - lock_guard parallel_lock(main_mutex); - return max_line_length; -} - -void CSVErrorHandler::DontPrintErrorLine() { - lock_guard parallel_lock(main_mutex); - print_line = false; -} - -void CSVErrorHandler::SetIgnoreErrors(bool ignore_errors_p) { - lock_guard parallel_lock(main_mutex); - ignore_errors = ignore_errors_p; -} - -CSVError::CSVError(string error_message_p, CSVErrorType type_p, LinesPerBoundary error_info_p) - : error_message(std::move(error_message_p)), type(type_p), error_info(error_info_p) { -} - -CSVError::CSVError(string error_message_p, CSVErrorType type_p, idx_t column_idx_p, string csv_row_p, - LinesPerBoundary error_info_p, idx_t row_byte_position, optional_idx byte_position_p, - const CSVReaderOptions &reader_options, const string &fixes, const string ¤t_path) - : error_message(std::move(error_message_p)), type(type_p), column_idx(column_idx_p), csv_row(std::move(csv_row_p)), - error_info(error_info_p), row_byte_position(row_byte_position), byte_position(byte_position_p) { - // What were the options - std::ostringstream error; - if (reader_options.ignore_errors.GetValue()) { - RemoveNewLine(error_message); - } - error << error_message << '\n'; - error << fixes << '\n'; - error << reader_options.ToString(current_path); - error << '\n'; - full_error_message = error.str(); -} - -CSVError CSVError::ColumnTypesError(case_insensitive_map_t sql_types_per_column, const vector &names) { - for (idx_t i = 0; i < names.size(); i++) { - auto it = sql_types_per_column.find(names[i]); - if (it != sql_types_per_column.end()) { - sql_types_per_column.erase(names[i]); - } - } - if (sql_types_per_column.empty()) { - return CSVError("", COLUMN_NAME_TYPE_MISMATCH, {}); - } - string exception = "COLUMN_TYPES error: Columns with names: "; - for (auto &col : sql_types_per_column) { - exception += "\"" + col.first + "\","; - } - exception.pop_back(); - exception += " do not exist in the CSV File"; - return CSVError(exception, COLUMN_NAME_TYPE_MISMATCH, {}); -} - -void CSVError::RemoveNewLine(string &error) { - error = StringUtil::Split(error, "\n")[0]; -} - -CSVError CSVError::CastError(const CSVReaderOptions &options, string &column_name, string &cast_error, idx_t column_idx, - string &csv_row, LinesPerBoundary error_info, idx_t row_byte_position, - optional_idx byte_position, 
LogicalTypeId type, const string ¤t_path) { - std::ostringstream error; - // Which column - error << "Error when converting column \"" << column_name << "\". "; - // What was the cast error - error << cast_error << '\n'; - std::ostringstream how_to_fix_it; - how_to_fix_it << "Column " << column_name << " is being converted as type " << LogicalTypeIdToString(type) << '\n'; - if (!options.WasTypeManuallySet(column_idx)) { - how_to_fix_it << "This type was auto-detected from the CSV file." << '\n'; - how_to_fix_it << "Possible solutions:" << '\n'; - how_to_fix_it << "* Override the type for this column manually by setting the type explicitly, e.g. types={'" - << column_name << "': 'VARCHAR'}" << '\n'; - how_to_fix_it - << "* Set the sample size to a larger value to enable the auto-detection to scan more values, e.g. " - "sample_size=-1" - << '\n'; - how_to_fix_it << "* Use a COPY statement to automatically derive types from an existing table." << '\n'; - } else { - how_to_fix_it - << "This type was either manually set or derived from an existing table. Select a different type to " - "correctly parse this column." - << '\n'; - } - - return CSVError(error.str(), CAST_ERROR, column_idx, csv_row, error_info, row_byte_position, byte_position, options, - how_to_fix_it.str(), current_path); -} - -CSVError CSVError::LineSizeError(const CSVReaderOptions &options, LinesPerBoundary error_info, string &csv_row, - idx_t byte_position, const string ¤t_path) { - std::ostringstream error; - error << "Maximum line size of " << options.maximum_line_size.GetValue() << " bytes exceeded. "; - error << "Actual Size:" << csv_row.size() << " bytes." << '\n'; - - std::ostringstream how_to_fix_it; - how_to_fix_it << "Possible Solution: Change the maximum length size, e.g., max_line_size=" << csv_row.size() + 2 - << "\n"; - - return CSVError(error.str(), MAXIMUM_LINE_SIZE, 0, csv_row, error_info, byte_position, byte_position, options, - how_to_fix_it.str(), current_path); -} - -CSVError CSVError::InvalidState(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, - string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string ¤t_path) { - std::ostringstream error; - error << "The CSV Parser state machine reached an invalid state.\nThis can happen when is not possible to parse " - "your CSV File with the given options, or the CSV File is not RFC 4180 compliant "; - - std::ostringstream how_to_fix_it; - how_to_fix_it << "Possible fixes:" << '\n'; - how_to_fix_it << "* Enable scanning files that are not RFC 4180 compliant (rfc_4180=false)." << '\n'; - - return CSVError(error.str(), INVALID_STATE, current_column, csv_row, error_info, row_byte_position, byte_position, - options, how_to_fix_it.str(), current_path); -} -CSVError CSVError::HeaderSniffingError(const CSVReaderOptions &options, const vector &best_header_row, - const idx_t column_count, const string &delimiter) { - std::ostringstream error; - // 1. Which file - error << "Error when sniffing file \"" << options.file_path << "\"." << '\n'; - // 2. What's the error - error << "It was not possible to detect the CSV Header, due to the header having less columns than expected" - << '\n'; - // 2.1 What's the expected number of columns - error << "Number of expected columns: " << column_count << ". 
Actual number of columns " << best_header_row.size() - << '\n'; - // 2.2 What was the detected row - error << "Detected row as Header:" << '\n'; - for (idx_t i = 0; i < best_header_row.size(); i++) { - if (best_header_row[i].is_null) { - error << "NULL"; - } else { - error << best_header_row[i].value; - } - if (i < best_header_row.size() - 1) { - error << delimiter << " "; - } - } - error << "\n"; - - // 3. Suggest how to fix it! - error << "Possible fixes:" << '\n'; - // header - if (!options.dialect_options.header.IsSetByUser()) { - error << "* Set header (header = true) if your CSV has a header, or (header = false) if it doesn't" << '\n'; - } else { - error << "* Header is set to \'" << options.dialect_options.header.GetValue() << "\'. Consider unsetting it." - << '\n'; - } - // skip_rows - if (!options.dialect_options.skip_rows.IsSetByUser()) { - error << "* Set skip (skip=${n}) to skip ${n} lines at the top of the file" << '\n'; - } else { - error << "* Skip is set to \'" << options.dialect_options.skip_rows.GetValue() << "\'. Consider unsetting it." - << '\n'; - } - // ignore_errors - if (!options.ignore_errors.GetValue()) { - error << "* Enable ignore errors (ignore_errors=true) to ignore potential errors" << '\n'; - } - // null_padding - if (!options.null_padding) { - error << "* Enable null padding (null_padding=true) to pad missing columns with NULL values" << '\n'; - } - - return CSVError(error.str(), SNIFFING, {}); -} - -CSVError CSVError::SniffingError(const CSVReaderOptions &options, const string &search_space) { - std::ostringstream error; - // 1. Which file - error << "Error when sniffing file \"" << options.file_path << "\"." << '\n'; - // 2. What's the error - error << "It was not possible to automatically detect the CSV Parsing dialect/types" << '\n'; - - // 2. What was the search space? - error << "The search space used was:" << '\n'; - error << search_space; - // 3. Suggest how to fix it! - error << "Possible fixes:" << '\n'; - // 3.1 Inform the reader of the dialect - // delimiter - if (!options.dialect_options.state_machine_options.delimiter.IsSetByUser()) { - error << "* Set delimiter (e.g., delim=\',\')" << '\n'; - } else { - error << "* Delimiter is set to \'" << options.dialect_options.state_machine_options.delimiter.GetValue() - << "\'. Consider unsetting it." << '\n'; - } - // quote - if (!options.dialect_options.state_machine_options.quote.IsSetByUser()) { - error << "* Set quote (e.g., quote=\'\"\')" << '\n'; - } else { - error << "* Quote is set to \'" << options.dialect_options.state_machine_options.quote.GetValue() - << "\'. Consider unsetting it." << '\n'; - } - // escape - if (!options.dialect_options.state_machine_options.escape.IsSetByUser()) { - error << "* Set escape (e.g., escape=\'\"\')" << '\n'; - } else { - error << "* Escape is set to \'" << options.dialect_options.state_machine_options.escape.GetValue() - << "\'. Consider unsetting it." << '\n'; - } - // comment - if (!options.dialect_options.state_machine_options.comment.IsSetByUser()) { - error << "* Set comment (e.g., comment=\'#\')" << '\n'; - } else { - error << "* Comment is set to \'" << options.dialect_options.state_machine_options.comment.GetValue() - << "\'. Consider unsetting it." 
- } - // 3.2 skip_rows - if (!options.dialect_options.skip_rows.IsSetByUser()) { - error << "* Set skip (skip=${n}) to skip ${n} lines at the top of the file" << '\n'; - } - // 3.3 ignore_errors - if (!options.ignore_errors.GetValue()) { - error << "* Enable ignore errors (ignore_errors=true) to ignore potential errors" << '\n'; - } - // 3.4 null_padding - if (!options.null_padding) { - error << "* Enable null padding (null_padding=true) to pad missing columns with NULL values" << '\n'; - } - error << "* Check you are using the correct file compression, otherwise set it (e.g., compression = \'zstd\')" - << '\n'; - error << "* Be sure that the maximum line size is set to an appropriate value, otherwise set it (e.g., " - "max_line_size=10000000)" - << "\n"; - - if (options.dialect_options.state_machine_options.rfc_4180.GetValue() || - !options.dialect_options.state_machine_options.rfc_4180.IsSetByUser()) { - error << "* Enable scanning files that are not RFC 4180 compliant (rfc_4180=false). " << '\n'; - } - return CSVError(error.str(), SNIFFING, {}); -} - -CSVError CSVError::NullPaddingFail(const CSVReaderOptions &options, LinesPerBoundary error_info, - const string &current_path) { - std::ostringstream error; - error << "The parallel scanner does not support null_padding in conjunction with quoted new lines. Please " - "disable the parallel CSV reader with parallel=false" - << '\n'; - // What were the options - error << options.ToString(current_path); - return CSVError(error.str(), NULLPADDED_QUOTED_NEW_VALUE, error_info); -} - -CSVError CSVError::UnterminatedQuotesError(const CSVReaderOptions &options, idx_t current_column, - LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string &current_path) { - std::ostringstream error; - error << "Value with unterminated quote found." << '\n'; - std::ostringstream how_to_fix_it; - how_to_fix_it << "Possible fixes:" << '\n'; - how_to_fix_it << "* Enable ignore errors (ignore_errors=true) to skip this row" << '\n'; - how_to_fix_it << "* Set quote to empty or to a different value (e.g., quote=\'\')" << '\n'; - return CSVError(error.str(), UNTERMINATED_QUOTES, current_column, csv_row, error_info, row_byte_position, - byte_position, options, how_to_fix_it.str(), current_path); -} - -CSVError CSVError::IncorrectColumnAmountError(const CSVReaderOptions &options, idx_t actual_columns, - LinesPerBoundary error_info, string &csv_row, idx_t row_byte_position, - optional_idx byte_position, const string &current_path) { - std::ostringstream error; - // We don't have a fix for this - std::ostringstream how_to_fix_it; - how_to_fix_it << "Possible fixes:" << '\n'; - if (!options.null_padding) { - how_to_fix_it << "* Enable null padding (null_padding=true) to replace missing values with NULL" << '\n'; - } - if (!options.ignore_errors.GetValue()) { - how_to_fix_it << "* Enable ignore errors (ignore_errors=true) to skip this row" << '\n'; - } - // How many columns were expected and how many were found - error << "Expected Number of Columns: " << options.dialect_options.num_cols << " Found: " << actual_columns + 1; - idx_t byte_pos = byte_position.GetIndex() == 0 ? 0 : byte_position.GetIndex() - 1;
- if (actual_columns >= options.dialect_options.num_cols) { - return CSVError(error.str(), TOO_MANY_COLUMNS, actual_columns, csv_row, error_info, row_byte_position, byte_pos, - options, how_to_fix_it.str(), current_path); - } else { - return CSVError(error.str(), TOO_FEW_COLUMNS, actual_columns, csv_row, error_info, row_byte_position, byte_pos, - options, how_to_fix_it.str(), current_path); - } -} - -CSVError CSVError::InvalidUTF8(const CSVReaderOptions &options, idx_t current_column, LinesPerBoundary error_info, - string &csv_row, idx_t row_byte_position, optional_idx byte_position, - const string &current_path) { - std::ostringstream error; - // What's the error - error << "Invalid unicode (byte sequence mismatch) detected." << '\n'; - std::ostringstream how_to_fix_it; - how_to_fix_it << "Possible Solution: Enable ignore errors (ignore_errors=true) to skip this row" << '\n'; - return CSVError(error.str(), INVALID_UNICODE, current_column, csv_row, error_info, row_byte_position, byte_position, - options, how_to_fix_it.str(), current_path); -} - -bool CSVErrorHandler::PrintLineNumber(const CSVError &error) const { - if (!print_line) { - return false; - } - switch (error.type) { - case CAST_ERROR: - case UNTERMINATED_QUOTES: - case TOO_FEW_COLUMNS: - case TOO_MANY_COLUMNS: - case MAXIMUM_LINE_SIZE: - case NULLPADDED_QUOTED_NEW_VALUE: - case INVALID_UNICODE: - return true; - default: - return false; - } -} - -bool CSVErrorHandler::CanGetLine(idx_t boundary_index) { - for (idx_t i = 0; i < boundary_index; i++) { - if (lines_per_batch_map.find(i) == lines_per_batch_map.end()) { - return false; - } - } - return true; -} - -idx_t CSVErrorHandler::GetLine(const LinesPerBoundary &error_info) { - lock_guard<mutex> parallel_lock(main_mutex); - return GetLineInternal(error_info); -} -idx_t CSVErrorHandler::GetLineInternal(const LinesPerBoundary &error_info) { - // We start from one, since the lines are 1-indexed - idx_t current_line = 1 + error_info.lines_in_batch; - for (idx_t boundary_idx = 0; boundary_idx < error_info.boundary_idx; boundary_idx++) { - current_line += lines_per_batch_map[boundary_idx].lines_in_batch; - } - return current_line; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp deleted file mode 100644 index ac18d42eb..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_reader_options.cpp +++ /dev/null @@ -1,765 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" -#include "duckdb/common/bind_helpers.hpp" -#include "duckdb/common/vector_size.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/set.hpp" - -namespace duckdb { - -CSVReaderOptions::CSVReaderOptions(const CSVOption<char> single_byte_delimiter, - const CSVOption<string> &multi_byte_delimiter) { - if (multi_byte_delimiter.GetValue().empty()) { - const char single_byte_value = single_byte_delimiter.GetValue(); - const string value(1, single_byte_value); - dialect_options.state_machine_options.delimiter = value; - } else { - dialect_options.state_machine_options.delimiter = multi_byte_delimiter; - } -} -static bool ParseBoolean(const Value &value, const string &loption); - -static bool ParseBoolean(const vector<Value> &set, const string &loption) { - if (set.empty()) {
- // no option specified: default to true - return true; - } - if (set.size() > 1) { - throw BinderException("\"%s\" expects a single argument as a boolean value (e.g. TRUE or 1)", loption); - } - return ParseBoolean(set[0], loption); -} - -static bool ParseBoolean(const Value &value, const string &loption) { - if (value.IsNull()) { - throw BinderException("\"%s\" expects a non-null boolean value (e.g. TRUE or 1)", loption); - } - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - return ParseBoolean(children, loption); - } - if (value.type() == LogicalType::FLOAT || value.type() == LogicalType::DOUBLE || - value.type().id() == LogicalTypeId::DECIMAL) { - throw BinderException("\"%s\" expects a boolean value (e.g. TRUE or 1)", loption); - } - return BooleanValue::Get(value.DefaultCastAs(LogicalType::BOOLEAN)); -} - -static string ParseString(const Value &value, const string &loption) { - if (value.IsNull()) { - return string(); - } - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - if (children.size() != 1) { - throw BinderException("\"%s\" expects a single argument as a string value", loption); - } - return ParseString(children[0], loption); - } - if (value.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("\"%s\" expects a string argument!", loption); - } - return value.GetValue<string>(); -} - -static int64_t ParseInteger(const Value &value, const string &loption) { - if (value.IsNull()) { - throw BinderException("\"%s\" expects a non-null integer value", loption); - } - if (value.type().id() == LogicalTypeId::LIST) { - auto &children = ListValue::GetChildren(value); - if (children.size() != 1) { - // no option specified or multiple options specified - throw BinderException("\"%s\" expects a single argument as an integer value", loption); - } - return ParseInteger(children[0], loption); - } - return value.GetValue<int64_t>(); -} - -bool CSVReaderOptions::GetHeader() const { - return this->dialect_options.header.GetValue(); -} - -void CSVReaderOptions::SetHeader(bool input) { - this->dialect_options.header.Set(input); -} - -void CSVReaderOptions::SetCompression(const string &compression_p) { - this->compression = FileCompressionTypeFromString(compression_p); -} - -string CSVReaderOptions::GetEscape() const { - return std::string(1, this->dialect_options.state_machine_options.escape.GetValue()); -} - -void CSVReaderOptions::SetEscape(const string &input) { - auto escape_str = input; - if (escape_str.size() > 1) { - throw InvalidInputException("The escape option cannot exceed a size of 1 byte."); - } - if (escape_str.empty()) { - escape_str = string("\0", 1); - } - this->dialect_options.state_machine_options.escape.Set(escape_str[0]); -} - -idx_t CSVReaderOptions::GetSkipRows() const { - return NumericCast<idx_t>(this->dialect_options.skip_rows.GetValue()); -} - -void CSVReaderOptions::SetSkipRows(int64_t skip_rows) { - if (skip_rows < 0) { - throw InvalidInputException("The skip_rows option of the read_csv scanner must be equal to or greater than 0"); - } - dialect_options.skip_rows.Set(NumericCast<idx_t>(skip_rows)); -} - -string CSVReaderOptions::GetDelimiter() const { - return this->dialect_options.state_machine_options.delimiter.GetValue(); -} - -void CSVReaderOptions::SetDelimiter(const string &input) { - auto delim_str = StringUtil::Replace(input, "\\t", "\t"); - if (delim_str.size() > 4) { - throw InvalidInputException("The delimiter option cannot exceed a size of 4 bytes."); - } - if (input.empty()) { - delim_str = string("\0", 1);
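- // Like SetEscape above and SetQuote/SetComment below, an empty option is stored as the single byte "\0",
- // which serves as the "no such character" sentinel for the state machine.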
- } - this->dialect_options.state_machine_options.delimiter.Set(delim_str); -} - -string CSVReaderOptions::GetQuote() const { - return std::string(1, this->dialect_options.state_machine_options.quote.GetValue()); -} - -void CSVReaderOptions::SetQuote(const string &quote_p) { - auto quote_str = quote_p; - if (quote_str.size() > 1) { - throw InvalidInputException("The quote option cannot exceed a size of 1 byte."); - } - if (quote_str.empty()) { - quote_str = string("\0", 1); - } - this->dialect_options.state_machine_options.quote.Set(quote_str[0]); -} - -string CSVReaderOptions::GetComment() const { - return std::string(1, this->dialect_options.state_machine_options.comment.GetValue()); -} - -void CSVReaderOptions::SetComment(const string &comment_p) { - auto comment_str = comment_p; - if (comment_str.size() > 1) { - throw InvalidInputException("The comment option cannot exceed a size of 1 byte."); - } - if (comment_str.empty()) { - comment_str = string("\0", 1); - } - this->dialect_options.state_machine_options.comment.Set(comment_str[0]); -} - -string CSVReaderOptions::GetNewline() const { - switch (dialect_options.state_machine_options.new_line.GetValue()) { - case NewLineIdentifier::CARRY_ON: - return "\\r\\n"; - case NewLineIdentifier::SINGLE_R: - return "\\r"; - case NewLineIdentifier::SINGLE_N: - return "\\n"; - case NewLineIdentifier::NOT_SET: - return ""; - default: - throw NotImplementedException("New line type not supported"); - } -} - -void CSVReaderOptions::SetNewline(const string &input) { - if (input == "\\n") { - dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::SINGLE_N); - } else if (input == "\\r") { - dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::SINGLE_R); - } else if (input == "\\r\\n") { - dialect_options.state_machine_options.new_line.Set(NewLineIdentifier::CARRY_ON); - } else { - throw InvalidInputException("This is not accepted as a newline: " + input); - } -} - -bool CSVReaderOptions::GetRFC4180() const { - return this->dialect_options.state_machine_options.rfc_4180.GetValue(); -} - -void CSVReaderOptions::SetRFC4180(bool input) { - this->dialect_options.state_machine_options.rfc_4180.Set(input); -} - -bool CSVReaderOptions::IgnoreErrors() const { - return ignore_errors.GetValue() && !store_rejects.GetValue(); -} - -char CSVReaderOptions::GetSingleByteDelimiter() const { - return dialect_options.state_machine_options.delimiter.GetValue()[0]; -} - -string CSVReaderOptions::GetMultiByteDelimiter() const { - return dialect_options.state_machine_options.delimiter.GetValue(); -} - -void CSVReaderOptions::SetDateFormat(LogicalTypeId type, const string &format, bool read_format) { - string error; - if (read_format) { - StrpTimeFormat strpformat; - error = StrTimeFormat::ParseFormatSpecifier(format, strpformat); - dialect_options.date_format[type].Set(strpformat); - } else { - write_date_format[type] = Value(format); - } - if (!error.empty()) { - throw InvalidInputException("Could not parse DATEFORMAT: %s", error.c_str()); - } -} - -void CSVReaderOptions::SetReadOption(const string &loption, const Value &value, vector<string> &expected_names) { - if (SetBaseOption(loption, value)) { - return; - } - if (loption == "auto_detect") { - auto_detect = ParseBoolean(value, loption); - } else if (loption == "sample_size") { - const auto sample_size_option = ParseInteger(value, loption); - if (sample_size_option < 1 && sample_size_option != -1) { - throw BinderException("Unsupported parameter for SAMPLE_SIZE: cannot be smaller than 1"); - } - if (sample_size_option == -1) { - // If -1, we basically read the whole thing - sample_size_chunks = NumericLimits<idx_t>().Maximum(); - } else { - sample_size_chunks = NumericCast<idx_t>(sample_size_option / STANDARD_VECTOR_SIZE); - if (sample_size_option % STANDARD_VECTOR_SIZE != 0) { - sample_size_chunks++; - } - } -
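- // For example, sample_size=10000 rounds up to 5 chunks with the usual default vector size of 2048,
- // so up to 10240 rows are sampled (STANDARD_VECTOR_SIZE is a build-time constant).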
- } else if (loption == "skip") { - SetSkipRows(ParseInteger(value, loption)); - } else if (loption == "max_line_size" || loption == "maximum_line_size") { - auto line_size = ParseInteger(value, loption); - if (line_size < 0) { - throw BinderException("Invalid value for MAX_LINE_SIZE parameter: it cannot be smaller than 0"); - } - maximum_line_size.Set(NumericCast<idx_t>(line_size)); - } else if (loption == "date_format" || loption == "dateformat") { - string format = ParseString(value, loption); - SetDateFormat(LogicalTypeId::DATE, format, true); - } else if (loption == "timestamp_format" || loption == "timestampformat") { - string format = ParseString(value, loption); - SetDateFormat(LogicalTypeId::TIMESTAMP, format, true); - } else if (loption == "ignore_errors") { - ignore_errors.Set(ParseBoolean(value, loption)); - } else if (loption == "buffer_size") { - buffer_size_option.Set(NumericCast<idx_t>(ParseInteger(value, loption))); - if (buffer_size_option == 0) { - throw InvalidInputException("Buffer Size option must be higher than 0"); - } - } else if (loption == "decimal_separator") { - decimal_separator = ParseString(value, loption); - if (decimal_separator != "." && decimal_separator != ",") { - throw BinderException("Unsupported parameter for DECIMAL_SEPARATOR: should be '.' or ','"); - } - } else if (loption == "null_padding") { - null_padding = ParseBoolean(value, loption); - } else if (loption == "parallel") { - parallel = ParseBoolean(value, loption); - } else if (loption == "allow_quoted_nulls") { - allow_quoted_nulls = ParseBoolean(value, loption); - } else if (loption == "store_rejects") { - store_rejects.Set(ParseBoolean(value, loption)); - } else if (loption == "force_not_null") { - if (!expected_names.empty()) { - force_not_null = ParseColumnList(value, expected_names, loption); - } else { - if (value.IsNull()) { - throw BinderException("Invalid value for 'force_not_null' parameter"); - } - // Get the list of columns to use as a recovery key - auto &children = ListValue::GetChildren(value); - for (auto &child : children) { - auto col_name = child.GetValue<string>(); - force_not_null_names.insert(col_name); - } - } - - } else if (loption == "rejects_table") { - // skip, handled in SetRejectsOptions - auto table_name = ParseString(value, loption); - if (table_name.empty()) { - throw BinderException("REJECTS_TABLE option cannot be empty"); - } - rejects_table_name.Set(table_name); - } else if (loption == "rejects_scan") { - // skip, handled in SetRejectsOptions - auto table_name = ParseString(value, loption); - if (table_name.empty()) { - throw BinderException("rejects_scan option cannot be empty"); - } - rejects_scan_name.Set(table_name); - } else if (loption == "rejects_limit") { - auto limit = ParseInteger(value, loption); - if (limit < 0) { - throw BinderException("Unsupported parameter for REJECTS_LIMIT: cannot be negative"); - } - rejects_limit = NumericCast<idx_t>(limit); - } else if (loption == "encoding") { - encoding = ParseString(value, loption); - } else { - throw BinderException("Unrecognized option for CSV reader \"%s\"", loption); - } -} - -void CSVReaderOptions::SetWriteOption(const string &loption, const Value &value) { - if (loption == "new_line") {
- // Steal this from SetBaseOption so we can write different newlines (e.g., format JSON ARRAY) - write_newline = ParseString(value, loption); - return; - } - - if (SetBaseOption(loption, value, true)) { - return; - } - - if (loption == "force_quote") { - force_quote = ParseColumnList(value, name_list, loption); - } else if (loption == "date_format" || loption == "dateformat") { - string format = ParseString(value, loption); - SetDateFormat(LogicalTypeId::DATE, format, false); - } else if (loption == "timestamp_format" || loption == "timestampformat") { - string format = ParseString(value, loption); - if (StringUtil::Lower(format) == "iso") { - format = "%Y-%m-%dT%H:%M:%S.%fZ"; - } - SetDateFormat(LogicalTypeId::TIMESTAMP, format, false); - SetDateFormat(LogicalTypeId::TIMESTAMP_TZ, format, false); - } else if (loption == "prefix") { - prefix = ParseString(value, loption); - } else if (loption == "suffix") { - suffix = ParseString(value, loption); - } else { - throw BinderException("Unrecognized option for CSV writer \"%s\"", loption); - } -} - -bool CSVReaderOptions::SetBaseOption(const string &loption, const Value &value, bool write_option) { - // Make sure this function was only called after the option was turned into lowercase - D_ASSERT(!std::any_of(loption.begin(), loption.end(), ::isupper)); - - if (StringUtil::StartsWith(loption, "delim") || StringUtil::StartsWith(loption, "sep")) { - SetDelimiter(ParseString(value, loption)); - } else if (loption == "quote") { - SetQuote(ParseString(value, loption)); - } else if (loption == "comment") { - SetComment(ParseString(value, loption)); - } else if (loption == "new_line") { - SetNewline(ParseString(value, loption)); - } else if (loption == "escape") { - SetEscape(ParseString(value, loption)); - } else if (loption == "header") { - SetHeader(ParseBoolean(value, loption)); - } else if (loption == "nullstr" || loption == "null") { - auto &child_type = value.type(); - null_str.clear(); - if (child_type.id() != LogicalTypeId::LIST && child_type.id() != LogicalTypeId::VARCHAR) { - throw BinderException("CSV Reader function option %s requires a string or a list as input", loption); - } - if (!null_str.empty()) { - throw BinderException("CSV Reader function option nullstr can only be supplied once"); - } - if (child_type.id() == LogicalTypeId::LIST) { - auto &list_child = ListType::GetChildType(child_type); - const vector<Value> *children = nullptr; - if (list_child.id() == LogicalTypeId::LIST) { - // This can happen if it comes from a copy FROM/TO - auto &list_grandchild = ListType::GetChildType(list_child); - auto &children_ref = ListValue::GetChildren(value); - if (list_grandchild.id() != LogicalTypeId::VARCHAR || children_ref.size() != 1) { - throw BinderException("CSV Reader function option %s requires a non-empty list of possible null " - "strings (varchar) as input", - loption); - } - children = &ListValue::GetChildren(children_ref.back()); - } else if (list_child.id() != LogicalTypeId::VARCHAR) { - throw BinderException("CSV Reader function option %s requires a non-empty list of possible null " - "strings (varchar) as input", - loption); - } - if (!children) { - children = &ListValue::GetChildren(value); - } - for (auto &child : *children) { - if (child.IsNull()) { - throw BinderException( - "CSV Reader function option %s does not accept NULL values as a valid nullstr option", loption); - } - null_str.push_back(StringValue::Get(child)); - } - } else { - null_str.push_back(StringValue::Get(ParseString(value, loption))); - } - if (null_str.size() > 1 && write_option) {
BinderException("CSV Writer function option %s only accepts one nullstr value.", loption); - } - - } else if (loption == "compression") { - SetCompression(ParseString(value, loption)); - } else if (loption == "rfc_4180") { - SetRFC4180(ParseBoolean(value, loption)); - } else { - // unrecognized option in base CSV - return false; - } - return true; -} - -template -string FormatOptionLine(const string &name, const CSVOption &option) { - return name + " = " + option.FormatValue() + " " + option.FormatSet() + "\n "; -} - -bool CSVReaderOptions::WasTypeManuallySet(idx_t i) const { - if (i >= was_type_manually_set.size()) { - return false; - } - return was_type_manually_set[i]; -} - -string CSVReaderOptions::ToString(const string ¤t_file_path) const { - auto &delimiter = dialect_options.state_machine_options.delimiter; - auto "e = dialect_options.state_machine_options.quote; - auto &escape = dialect_options.state_machine_options.escape; - auto &comment = dialect_options.state_machine_options.comment; - auto &new_line = dialect_options.state_machine_options.new_line; - auto &rfc_4180 = dialect_options.state_machine_options.rfc_4180; - auto &skip_rows = dialect_options.skip_rows; - - auto &header = dialect_options.header; - string error = " file = " + current_file_path + "\n "; - // Let's first print options that can either be set by the user or by the sniffer - // delimiter - error += FormatOptionLine("delimiter", delimiter); - // quote - error += FormatOptionLine("quote", quote); - // escape - error += FormatOptionLine("escape", escape); - // newline - error += FormatOptionLine("new_line", new_line); - // has_header - error += FormatOptionLine("header", header); - // skip_rows - error += FormatOptionLine("skip_rows", skip_rows); - // comment - error += FormatOptionLine("comment", comment); - // rfc_4180 - error += FormatOptionLine("rfc_4180", rfc_4180); - // date format - error += FormatOptionLine("date_format", dialect_options.date_format.at(LogicalType::DATE)); - // timestamp format - error += FormatOptionLine("timestamp_format", dialect_options.date_format.at(LogicalType::TIMESTAMP)); - - // Now we do options that can only be set by the user, that might hold some general significance - // null padding - error += "null_padding = " + std::to_string(null_padding) + "\n "; - // sample_size - error += "sample_size = " + std::to_string(sample_size_chunks * STANDARD_VECTOR_SIZE) + "\n "; - // ignore_errors - error += "ignore_errors = " + ignore_errors.FormatValue() + "\n "; - // all_varchar - error += "all_varchar = " + std::to_string(all_varchar) + "\n"; - - // Add information regarding sniffer mismatches (if any) - error += sniffer_user_mismatch_error; - return error; -} - -static Value StringVectorToValue(const vector &vec) { - vector content; - content.reserve(vec.size()); - for (auto &item : vec) { - content.push_back(Value(item)); - } - return Value::LIST(LogicalType::VARCHAR, std::move(content)); -} - -static uint8_t GetCandidateSpecificity(const LogicalType &candidate_type) { - //! 
- //! Const ht with accepted auto_types and their weights in specificity - const duckdb::unordered_map<uint8_t, uint8_t> auto_type_candidates_specificity { - {static_cast<uint8_t>(LogicalTypeId::VARCHAR), 0}, {static_cast<uint8_t>(LogicalTypeId::DOUBLE), 1}, - {static_cast<uint8_t>(LogicalTypeId::FLOAT), 2}, {static_cast<uint8_t>(LogicalTypeId::DECIMAL), 3}, - {static_cast<uint8_t>(LogicalTypeId::BIGINT), 4}, {static_cast<uint8_t>(LogicalTypeId::INTEGER), 5}, - {static_cast<uint8_t>(LogicalTypeId::SMALLINT), 6}, {static_cast<uint8_t>(LogicalTypeId::TINYINT), 7}, - {static_cast<uint8_t>(LogicalTypeId::TIMESTAMP), 8}, {static_cast<uint8_t>(LogicalTypeId::DATE), 9}, - {static_cast<uint8_t>(LogicalTypeId::TIME), 10}, {static_cast<uint8_t>(LogicalTypeId::BOOLEAN), 11}, - {static_cast<uint8_t>(LogicalTypeId::SQLNULL), 12}}; - - auto id = static_cast<uint8_t>(candidate_type.id()); - auto it = auto_type_candidates_specificity.find(id); - if (it == auto_type_candidates_specificity.end()) { - throw BinderException("Auto Type Candidate of type %s is not accepted as a valid input", - EnumUtil::ToString(candidate_type.id())); - } - return it->second; -} -bool StoreUserDefinedParameter(const string &option) { - if (option == "column_types" || option == "types" || option == "dtypes" || option == "auto_detect" || - option == "auto_type_candidates" || option == "columns" || option == "names") { - // We don't store options related to types, names and auto-detection since these are either irrelevant to our - // prompt or are covered by the columns option. - return false; - } - return true; -} - -void CSVReaderOptions::Verify() { - if (rejects_table_name.IsSetByUser() && !store_rejects.GetValue() && store_rejects.IsSetByUser()) { - throw BinderException("REJECTS_TABLE option is only supported when store_rejects is not manually set to false"); - } - if (rejects_scan_name.IsSetByUser() && !store_rejects.GetValue() && store_rejects.IsSetByUser()) { - throw BinderException("REJECTS_SCAN option is only supported when store_rejects is not manually set to false"); - } - if (rejects_scan_name.IsSetByUser() || rejects_table_name.IsSetByUser()) { - // Ensure we set store_rejects to true automagically - store_rejects.Set(true, false); - } - // Validate rejects_table options - if (store_rejects.GetValue()) { - if (!ignore_errors.GetValue() && ignore_errors.IsSetByUser()) { - throw BinderException( - "STORE_REJECTS option is only supported when IGNORE_ERRORS is not manually set to false"); - } - // Ensure we set ignore errors to true automagically - ignore_errors.Set(true, false); - if (file_options.union_by_name) { - throw BinderException("REJECTS_TABLE option is not supported when UNION_BY_NAME is set to true"); - } - } - if (rejects_limit != 0 && !store_rejects.GetValue()) { - throw BinderException("REJECTS_LIMIT option is only supported when REJECTS_TABLE is set to a table name"); - } - // Validate CSV Buffer and max_line_size do not conflict. - if (buffer_size_option.IsSetByUser() && maximum_line_size.IsSetByUser()) { - if (buffer_size_option.GetValue() < maximum_line_size.GetValue()) { - throw BinderException("BUFFER_SIZE option was set to %d, while MAX_LINE_SIZE was set to %d. BUFFER_SIZE " - "must always be set to a value bigger than MAX_LINE_SIZE", - buffer_size_option.GetValue(), maximum_line_size.GetValue()); - } - } else if (maximum_line_size.IsSetByUser() && maximum_line_size.GetValue() > max_line_size_default) { - // If the max line size is set by the user and bigger than we have by default, we make it part of our buffer - // size decision.
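- // For example, raising max_line_size to 32000000 bytes sizes each buffer at ROWS_PER_BUFFER * 32000000,
- // so a buffer can still hold ROWS_PER_BUFFER maximum-length lines (ROWS_PER_BUFFER is a CSVBuffer constant).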
- buffer_size_option.Set(CSVBuffer::ROWS_PER_BUFFER * maximum_line_size.GetValue(), false); - } -} - -bool GetBooleanValue(const pair &option) { - if (option.second.IsNull()) { - throw BinderException("read_csv %s cannot be NULL", option.first); - } - return BooleanValue::Get(option.second); -} - -void CSVReaderOptions::FromNamedParameters(const named_parameter_map_t &in, ClientContext &context) { - map ordered_user_defined_parameters; - for (auto &kv : in) { - if (MultiFileReader().ParseOption(kv.first, kv.second, file_options, context)) { - continue; - } - auto loption = StringUtil::Lower(kv.first); - // skip variables that are specific to auto-detection - if (StoreUserDefinedParameter(loption)) { - ordered_user_defined_parameters[loption] = kv.second.ToSQLString(); - } - if (loption == "columns") { - if (!name_list.empty()) { - throw BinderException("read_csv column_names/names can only be supplied once"); - } - columns_set = true; - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::STRUCT) { - throw BinderException("read_csv columns requires a struct as input"); - } - auto &struct_children = StructValue::GetChildren(kv.second); - D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - auto &name = StructType::GetChildName(child_type, i); - auto &val = struct_children[i]; - name_list.push_back(name); - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv requires a type specification as string"); - } - sql_types_per_column[name] = i; - sql_type_list.emplace_back(TransformStringToLogicalType(StringValue::Get(val), context)); - } - if (name_list.empty()) { - throw BinderException("read_csv requires at least a single column as input!"); - } - } else if (loption == "auto_type_candidates") { - auto_type_candidates.clear(); - map candidate_types; - // We always have the extremes of Null and Varchar, so we can default to varchar if the - // sniffer is not able to confidently detect that column type - candidate_types[GetCandidateSpecificity(LogicalType::VARCHAR)] = LogicalType::VARCHAR; - candidate_types[GetCandidateSpecificity(LogicalType::SQLNULL)] = LogicalType::SQLNULL; - - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::LIST) { - throw BinderException("read_csv auto_types requires a list as input"); - } - auto &list_children = ListValue::GetChildren(kv.second); - if (list_children.empty()) { - throw BinderException("auto_type_candidates requires at least one type"); - } - for (auto &child : list_children) { - if (child.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("auto_type_candidates requires a type specification as string"); - } - auto candidate_type = TransformStringToLogicalType(StringValue::Get(child), context); - candidate_types[GetCandidateSpecificity(candidate_type)] = candidate_type; - } - for (auto &candidate_type : candidate_types) { - auto_type_candidates.emplace_back(candidate_type.second); - } - } else if (loption == "column_names" || loption == "names") { - unordered_set column_names; - if (!name_list.empty()) { - throw BinderException("read_csv column_names/names can only be supplied once"); - } - if (kv.second.IsNull()) { - throw BinderException("read_csv %s cannot be NULL", kv.first); - } - auto &children = ListValue::GetChildren(kv.second); - for (auto &child : children) { - name_list.push_back(StringValue::Get(child)); - } - for (auto &name : name_list) { - bool empty = true; - for (auto &c 
: name) { - if (!StringUtil::CharacterIsSpace(c)) { - empty = false; - break; - } - } - if (empty) { - throw BinderException("read_csv %s cannot have empty (or all whitespace) value", kv.first); - } - if (column_names.find(name) != column_names.end()) { - throw BinderException("read_csv %s must have unique values. \"%s\" is repeated.", kv.first, name); - } - column_names.insert(name); - } - } else if (loption == "column_types" || loption == "types" || loption == "dtypes") { - auto &child_type = kv.second.type(); - if (child_type.id() != LogicalTypeId::STRUCT && child_type.id() != LogicalTypeId::LIST) { - throw BinderException("read_csv %s requires a struct or list as input", kv.first); - } - if (!sql_type_list.empty()) { - throw BinderException("read_csv column_types/types/dtypes can only be supplied once"); - } - vector sql_type_names; - if (child_type.id() == LogicalTypeId::STRUCT) { - auto &struct_children = StructValue::GetChildren(kv.second); - D_ASSERT(StructType::GetChildCount(child_type) == struct_children.size()); - for (idx_t i = 0; i < struct_children.size(); i++) { - auto &name = StructType::GetChildName(child_type, i); - auto &val = struct_children[i]; - if (val.type().id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv %s requires a type specification as string", kv.first); - } - sql_type_names.push_back(StringValue::Get(val)); - sql_types_per_column[name] = i; - } - } else { - auto &list_child = ListType::GetChildType(child_type); - if (list_child.id() != LogicalTypeId::VARCHAR) { - throw BinderException("read_csv %s requires a list of types (varchar) as input", kv.first); - } - auto &children = ListValue::GetChildren(kv.second); - for (auto &child : children) { - sql_type_names.push_back(StringValue::Get(child)); - } - } - sql_type_list.reserve(sql_type_names.size()); - for (auto &sql_type : sql_type_names) { - auto def_type = TransformStringToLogicalType(sql_type, context); - if (def_type.id() == LogicalTypeId::USER) { - throw BinderException("Unrecognized type \"%s\" for read_csv %s definition", sql_type, kv.first); - } - sql_type_list.push_back(std::move(def_type)); - } - } else if (loption == "all_varchar") { - all_varchar = GetBooleanValue(kv); - } else if (loption == "normalize_names") { - normalize_names = GetBooleanValue(kv); - } else { - SetReadOption(loption, kv.second, name_list); - } - } - for (auto &udf_parameter : ordered_user_defined_parameters) { - user_defined_parameters += udf_parameter.first + "=" + udf_parameter.second + ", "; - } - if (user_defined_parameters.size() >= 2) { - user_defined_parameters.erase(user_defined_parameters.size() - 2); - } -} -//! 
This function is used to remember options set by the sniffer, for use in ReadCSVRelation -void CSVReaderOptions::ToNamedParameters(named_parameter_map_t &named_params) const { - auto &delimiter = dialect_options.state_machine_options.delimiter; - auto &quote = dialect_options.state_machine_options.quote; - auto &escape = dialect_options.state_machine_options.escape; - auto &comment = dialect_options.state_machine_options.comment; - auto &rfc_4180 = dialect_options.state_machine_options.rfc_4180; - auto &header = dialect_options.header; - if (delimiter.IsSetByUser()) { - named_params["delim"] = Value(GetDelimiter()); - } - if (dialect_options.state_machine_options.new_line.IsSetByUser()) { - named_params["new_line"] = Value(GetNewline()); - } - if (quote.IsSetByUser()) { - named_params["quote"] = Value(GetQuote()); - } - if (escape.IsSetByUser()) { - named_params["escape"] = Value(GetEscape()); - } - if (comment.IsSetByUser()) { - named_params["comment"] = Value(GetComment()); - } - if (header.IsSetByUser()) { - named_params["header"] = Value(GetHeader()); - } - if (rfc_4180.IsSetByUser()) { - named_params["rfc_4180"] = Value(GetRFC4180()); - } - named_params["max_line_size"] = Value::BIGINT(NumericCast<int64_t>(maximum_line_size.GetValue())); - if (dialect_options.skip_rows.IsSetByUser()) { - named_params["skip"] = Value::UBIGINT(GetSkipRows()); - } - named_params["null_padding"] = Value::BOOLEAN(null_padding); - named_params["parallel"] = Value::BOOLEAN(parallel); - if (!dialect_options.date_format.at(LogicalType::DATE).GetValue().format_specifier.empty()) { - named_params["dateformat"] = - Value(dialect_options.date_format.at(LogicalType::DATE).GetValue().format_specifier); - } - if (!dialect_options.date_format.at(LogicalType::TIMESTAMP).GetValue().format_specifier.empty()) { - named_params["timestampformat"] = - Value(dialect_options.date_format.at(LogicalType::TIMESTAMP).GetValue().format_specifier); - } - - named_params["normalize_names"] = Value::BOOLEAN(normalize_names); - if (!name_list.empty() && !named_params.count("columns") && !named_params.count("column_names") && - !named_params.count("names")) { - named_params["column_names"] = StringVectorToValue(name_list); - } - named_params["all_varchar"] = Value::BOOLEAN(all_varchar); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/csv_scanner/util/csv_validator.cpp b/src/duckdb/src/execution/operator/csv_scanner/util/csv_validator.cpp deleted file mode 100644 index 1b8d3f348..000000000 --- a/src/duckdb/src/execution/operator/csv_scanner/util/csv_validator.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include "duckdb/execution/operator/csv_scanner/csv_validator.hpp" -#include <sstream> - -namespace duckdb { - -void ThreadLines::Insert(idx_t thread_idx, ValidatorLine line_info) { - D_ASSERT(thread_lines.find(thread_idx) == thread_lines.end()); - thread_lines.insert({thread_idx, line_info}); -} - -string ThreadLines::Print() const { - string result; - for (auto &line : thread_lines) { - result += "{start_pos: " + std::to_string(line.second.start_pos) + - ", end_pos: " + std::to_string(line.second.end_pos) + "}"; - } - return result; -} -
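-// Verify checks that the byte ranges scanned by the individual threads line up: each range must start where the
-// previous one ended, within error_margin; otherwise rows may have been skipped or read twice.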
-void ThreadLines::Verify() const { - bool initialized = false; - idx_t last_end_pos = 0; - for (auto &line_info : thread_lines) { - if (!initialized) { - // First run, we just set initialized to true - initialized = true; - } else { - if (line_info.second.start_pos == line_info.second.end_pos) { - last_end_pos = line_info.second.end_pos; - continue; - } - if (last_end_pos + error_margin < line_info.second.start_pos || - line_info.second.start_pos < last_end_pos - error_margin) { - std::ostringstream error; - error << "The Parallel CSV Reader currently does not support a full read on this file." << '\n'; - error << "To correctly parse this file, please run with the single threaded reader (i.e., parallel = " - "false)" - << '\n'; - throw NotImplementedException(error.str()); - } - } - last_end_pos = line_info.second.end_pos; - } -} - -void CSVValidator::Insert(idx_t file_idx, idx_t thread_idx, ValidatorLine line_info) { - if (per_file_thread_lines.size() <= file_idx) { - per_file_thread_lines.resize(file_idx + 1); - } - per_file_thread_lines[file_idx].Insert(thread_idx, line_info); -} - -void CSVValidator::Verify() const { - for (auto &file : per_file_thread_lines) { - file.Verify(); - } -} - -string CSVValidator::Print(idx_t file_idx) const { - return per_file_thread_lines[file_idx].Print(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/filter/physical_filter.cpp b/src/duckdb/src/execution/operator/filter/physical_filter.cpp deleted file mode 100644 index e2a25d9fd..000000000 --- a/src/duckdb/src/execution/operator/filter/physical_filter.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "duckdb/execution/operator/filter/physical_filter.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/parallel/thread_context.hpp" -namespace duckdb { - -PhysicalFilter::PhysicalFilter(vector<LogicalType> types, vector<unique_ptr<Expression>> select_list, - idx_t estimated_cardinality) - : CachingPhysicalOperator(PhysicalOperatorType::FILTER, std::move(types), estimated_cardinality) { - D_ASSERT(select_list.size() > 0); - if (select_list.size() > 1) { - // create a big AND out of the expressions - auto conjunction = make_uniq<BoundConjunctionExpression>(ExpressionType::CONJUNCTION_AND); - for (auto &expr : select_list) { - conjunction->children.push_back(std::move(expr)); - } - expression = std::move(conjunction); - } else { - expression = std::move(select_list[0]); - } -} - -class FilterState : public CachingOperatorState { -public: - explicit FilterState(ExecutionContext &context, Expression &expr) - : executor(context.client, expr), sel(STANDARD_VECTOR_SIZE) { - } - - ExpressionExecutor executor; - SelectionVector sel; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op); - } -}; - -unique_ptr<OperatorState> PhysicalFilter::GetOperatorState(ExecutionContext &context) const { - return make_uniq<FilterState>(context, *expression); -} - -OperatorResultType PhysicalFilter::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast<FilterState>(); - idx_t result_count = state.executor.SelectExpression(input, state.sel); - if (result_count == input.size()) { - // nothing was filtered: skip adding any selection vectors - chunk.Reference(input); - } else { - chunk.Slice(input, state.sel, result_count); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -InsertionOrderPreservingMap<string> PhysicalFilter::ParamsToString() const { - InsertionOrderPreservingMap<string> result; - result["__expression__"] = expression->GetName(); - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp deleted file mode 100644 index 
c489124ee..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_batch_collector.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_batch_collector.hpp" - -#include "duckdb/common/types/batched_data_collection.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/materialized_query_result.hpp" - -namespace duckdb { - -PhysicalBatchCollector::PhysicalBatchCollector(PreparedStatementData &data) : PhysicalResultCollector(data) { -} - -SinkResultType PhysicalBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &state = input.local_state.Cast(); - state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBatchCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &state = input.local_state.Cast(); - - lock_guard lock(gstate.glock); - gstate.data.Merge(state.data); - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalBatchCollector::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto collection = gstate.data.FetchCollection(); - D_ASSERT(collection); - auto result = make_uniq(statement_type, properties, names, std::move(collection), - context.GetClientProperties()); - gstate.result = std::move(result); - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalBatchCollector::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -unique_ptr PhysicalBatchCollector::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalBatchCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - D_ASSERT(gstate.result); - return std::move(gstate.result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp deleted file mode 100644 index f881da2a9..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_batch_collector.cpp +++ /dev/null @@ -1,109 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp" - -#include "duckdb/common/types/batched_data_collection.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/materialized_query_result.hpp" -#include "duckdb/main/buffered_data/buffered_data.hpp" -#include "duckdb/main/buffered_data/batched_buffered_data.hpp" -#include "duckdb/main/stream_query_result.hpp" - -namespace duckdb { - -PhysicalBufferedBatchCollector::PhysicalBufferedBatchCollector(PreparedStatementData &data) - : PhysicalResultCollector(data) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BufferedBatchCollectorGlobalState : public GlobalSinkState { -public: - weak_ptr context; - shared_ptr buffered_data; -}; - -BufferedBatchCollectorLocalState::BufferedBatchCollectorLocalState() { -} - -SinkResultType PhysicalBufferedBatchCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - 
lstate.current_batch = lstate.partition_info.batch_index.GetIndex(); - auto batch = lstate.partition_info.batch_index.GetIndex(); - auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); - - auto &buffered_data = gstate.buffered_data->Cast(); - buffered_data.UpdateMinBatchIndex(min_batch_index); - - if (buffered_data.ShouldBlockBatch(batch)) { - auto callback_state = input.interrupt_state; - buffered_data.BlockSink(callback_state, batch); - return SinkResultType::BLOCKED; - } - - // FIXME: if we want to make this more accurate, we should grab a reservation on the buffer space - // while we're unlocked some other thread could also append, causing us to potentially cross our buffer size - - buffered_data.Append(chunk, batch); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkNextBatchType PhysicalBufferedBatchCollector::NextBatch(ExecutionContext &context, - OperatorSinkNextBatchInput &input) const { - - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto batch = lstate.current_batch; - auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); - auto new_index = lstate.partition_info.batch_index.GetIndex(); - - auto &buffered_data = gstate.buffered_data->Cast(); - buffered_data.CompleteBatch(batch); - lstate.current_batch = new_index; - // FIXME: this can move from the buffer to the read queue, increasing the 'read_queue_byte_count' - // We might want to block here if 'read_queue_byte_count' has already reached the ReadQueueCapacity() - // So we don't completely disregard the 'streaming_buffer_size' that was set - buffered_data.UpdateMinBatchIndex(min_batch_index); - return SinkNextBatchType::READY; -} - -SinkCombineResultType PhysicalBufferedBatchCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto min_batch_index = lstate.partition_info.min_batch_index.GetIndex(); - auto &buffered_data = gstate.buffered_data->Cast(); - - // FIXME: this can move from the buffer to the read queue, increasing the 'read_queue_byte_count' - // We might want to block here if 'read_queue_byte_count' has already reached the ReadQueueCapacity() - // So we don't completely disregard the 'streaming_buffer_size' that was set - buffered_data.UpdateMinBatchIndex(min_batch_index); - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalBufferedBatchCollector::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); - return std::move(state); -} - -unique_ptr PhysicalBufferedBatchCollector::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); - return std::move(state); -} - -unique_ptr PhysicalBufferedBatchCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - auto cc = gstate.context.lock(); - auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), - gstate.buffered_data); - return std::move(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp deleted file mode 100644 index d52b7bc49..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_buffered_collector.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include 
"duckdb/execution/operator/helper/physical_buffered_collector.hpp" -#include "duckdb/main/stream_query_result.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -PhysicalBufferedCollector::PhysicalBufferedCollector(PreparedStatementData &data, bool parallel) - : PhysicalResultCollector(data), parallel(parallel) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BufferedCollectorGlobalState : public GlobalSinkState { -public: - mutex glock; - //! This is weak to avoid creating a cyclical reference - weak_ptr context; - shared_ptr buffered_data; -}; - -class BufferedCollectorLocalState : public LocalSinkState {}; - -SinkResultType PhysicalBufferedCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - (void)lstate; - - lock_guard l(gstate.glock); - auto &buffered_data = gstate.buffered_data->Cast(); - - if (buffered_data.BufferIsFull()) { - auto callback_state = input.interrupt_state; - buffered_data.BlockSink(callback_state); - return SinkResultType::BLOCKED; - } - buffered_data.Append(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBufferedCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalBufferedCollector::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - state->context = context.shared_from_this(); - state->buffered_data = make_shared_ptr(state->context); - return std::move(state); -} - -unique_ptr PhysicalBufferedCollector::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); - return std::move(state); -} - -unique_ptr PhysicalBufferedCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - lock_guard l(gstate.glock); - // FIXME: maybe we want to check if the execution was successful before creating the StreamQueryResult ? 
- auto cc = gstate.context.lock(); - auto result = make_uniq(statement_type, properties, types, names, cc->GetClientProperties(), - gstate.buffered_data); - return std::move(result); -} - -bool PhysicalBufferedCollector::ParallelSink() const { - return parallel; -} - -bool PhysicalBufferedCollector::SinkOrderDependent() const { - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp b/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp deleted file mode 100644 index 3a1dbe8b6..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_create_secret.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_create_secret.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/secret/secret_manager.hpp" - -namespace duckdb { - -SourceResultType PhysicalCreateSecret::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - auto &secret_manager = SecretManager::Get(client); - - secret_manager.CreateSecret(client, info); - - chunk.SetValue(0, 0, true); - chunk.SetCardinality(1); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_execute.cpp b/src/duckdb/src/execution/operator/helper/physical_execute.cpp deleted file mode 100644 index 22cab9be8..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_execute.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_execute.hpp" - -#include "duckdb/parallel/meta_pipeline.hpp" - -namespace duckdb { - -PhysicalExecute::PhysicalExecute(PhysicalOperator &plan) - : PhysicalOperator(PhysicalOperatorType::EXECUTE, plan.types, idx_t(-1)), plan(plan) { -} - -vector> PhysicalExecute::GetChildren() const { - return {plan}; -} - -void PhysicalExecute::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // EXECUTE statement: build pipeline on child - meta_pipeline.Build(plan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp b/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp deleted file mode 100644 index 8d2c12811..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_explain_analyze.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_explain_analyze.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/query_profiler.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class ExplainAnalyzeStateGlobalState : public GlobalSinkState { -public: - string analyzed_plan; -}; - -SinkResultType PhysicalExplainAnalyze::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - return SinkResultType::NEED_MORE_INPUT; -} - -SinkFinalizeType PhysicalExplainAnalyze::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &profiler = QueryProfiler::Get(context); - gstate.analyzed_plan = profiler.ToString(format); - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalExplainAnalyze::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - 
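-// Note that Sink above deliberately discards the incoming chunks: EXPLAIN ANALYZE only needs the side effects of
-// executing the plan, and the profiler output captured in Finalize is what the source phase below returns.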
-//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalExplainAnalyze::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - - chunk.SetValue(0, 0, Value("analyzed_plan")); - chunk.SetValue(1, 0, Value(gstate.analyzed_plan)); - chunk.SetCardinality(1); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_limit.cpp deleted file mode 100644 index 936cd311c..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_limit.cpp +++ /dev/null @@ -1,244 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_limit.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/types/batched_data_collection.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -PhysicalLimit::PhysicalLimit(vector types, BoundLimitNode limit_val_p, BoundLimitNode offset_val_p, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::LIMIT, std::move(types), estimated_cardinality), - limit_val(std::move(limit_val_p)), offset_val(std::move(offset_val_p)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class LimitGlobalState : public GlobalSinkState { -public: - explicit LimitGlobalState(ClientContext &context, const PhysicalLimit &op) : data(context, op.types, true) { - limit = 0; - offset = 0; - } - - mutex glock; - idx_t limit; - idx_t offset; - BatchedDataCollection data; -}; - -class LimitLocalState : public LocalSinkState { -public: - explicit LimitLocalState(ClientContext &context, const PhysicalLimit &op) - : current_offset(0), data(context, op.types, true) { - PhysicalLimit::SetInitialLimits(op.limit_val, op.offset_val, limit, offset); - } - - idx_t current_offset; - optional_idx limit; - optional_idx offset; - BatchedDataCollection data; -}; - -unique_ptr PhysicalLimit::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalLimit::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void PhysicalLimit::SetInitialLimits(const BoundLimitNode &limit_val, const BoundLimitNode &offset_val, - optional_idx &limit, optional_idx &offset) { - switch (limit_val.Type()) { - case LimitNodeType::CONSTANT_VALUE: - limit = limit_val.GetConstantValue(); - break; - case LimitNodeType::UNSET: - limit = MAX_LIMIT_VALUE; - break; - default: - break; - } - switch (offset_val.Type()) { - case LimitNodeType::CONSTANT_VALUE: - offset = offset_val.GetConstantValue(); - break; - case LimitNodeType::UNSET: - offset = 0; - break; - default: - break; - } -} - -bool PhysicalLimit::ComputeOffset(ExecutionContext &context, DataChunk &input, optional_idx &limit, - optional_idx &offset, idx_t current_offset, idx_t &max_element, - const BoundLimitNode &limit_val, const BoundLimitNode &offset_val) { - if (!limit.IsValid()) { - Value val = GetDelimiter(context, input, limit_val.GetValueExpression()); - if (!val.IsNull()) { - limit = val.GetValue(); - } else { - limit = MAX_LIMIT_VALUE; - } - if 
(limit.GetIndex() > MAX_LIMIT_VALUE) { - throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", limit.GetIndex(), MAX_LIMIT_VALUE); - } - } - if (!offset.IsValid()) { - Value val = GetDelimiter(context, input, offset_val.GetValueExpression()); - if (!val.IsNull()) { - offset = val.GetValue(); - } else { - offset = 0; - } - if (offset.GetIndex() > MAX_LIMIT_VALUE) { - throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", offset.GetIndex(), MAX_LIMIT_VALUE); - } - } - max_element = limit.GetIndex() + offset.GetIndex(); - if (limit == 0 || current_offset >= max_element) { - return false; - } - return true; -} - -SinkResultType PhysicalLimit::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - - D_ASSERT(chunk.size() > 0); - auto &state = input.local_state.Cast(); - auto &limit = state.limit; - auto &offset = state.offset; - - idx_t max_element; - if (!ComputeOffset(context, chunk, limit, offset, state.current_offset, max_element, limit_val, offset_val)) { - return SinkResultType::FINISHED; - } - auto max_cardinality = max_element - state.current_offset; - if (max_cardinality < chunk.size()) { - chunk.SetCardinality(max_cardinality); - } - state.data.Append(chunk, state.partition_info.batch_index.GetIndex()); - state.current_offset += chunk.size(); - if (state.current_offset == max_element) { - return SinkResultType::FINISHED; - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalLimit::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &state = input.local_state.Cast(); - - lock_guard lock(gstate.glock); - if (state.limit.IsValid()) { - gstate.limit = state.limit.GetIndex(); - } - if (state.offset.IsValid()) { - gstate.offset = state.offset.GetIndex(); - } - gstate.data.Merge(state.data); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class LimitSourceState : public GlobalSourceState { -public: - LimitSourceState() { - initialized = false; - current_offset = 0; - } - - bool initialized; - idx_t current_offset; - BatchedChunkScanState scan_state; -}; - -unique_ptr PhysicalLimit::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalLimit::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - auto &state = input.global_state.Cast(); - while (state.current_offset < gstate.limit + gstate.offset) { - if (!state.initialized) { - gstate.data.InitializeScan(state.scan_state); - state.initialized = true; - } - gstate.data.Scan(state.scan_state, chunk); - if (chunk.size() == 0) { - return SourceResultType::FINISHED; - } - if (HandleOffset(chunk, state.current_offset, gstate.offset, gstate.limit)) { - break; - } - } - - return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; -} - -bool PhysicalLimit::HandleOffset(DataChunk &input, idx_t &current_offset, idx_t offset, idx_t limit) { - idx_t max_element = limit + offset; - if (limit == DConstants::INVALID_INDEX) { - max_element = DConstants::INVALID_INDEX; - } - idx_t input_size = input.size(); - if (current_offset < offset) { - // we are not yet at the offset point - if (current_offset + input.size() > offset) { - // however we will reach it in this chunk - // we have to copy part of the chunk with an offset - idx_t start_position = offset - current_offset; - auto chunk_count = MinValue(limit, input.size() - start_position); - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < chunk_count; i++) { - sel.set_index(i, start_position + i); - } - // set up a slice of the input chunks - input.Slice(input, sel, chunk_count); - } else { - current_offset += input_size; - return false; - } - } else { - // have to copy either the entire chunk or part of it - idx_t chunk_count; - if (current_offset + input.size() >= max_element) { - // have to limit the count of the chunk - chunk_count = max_element - current_offset; - } else { - // we copy the entire chunk - chunk_count = input.size(); - } - // instead of copying we just change the pointer in the current chunk - input.Reference(input); - input.SetCardinality(chunk_count); - } - - current_offset += input_size; - return true; -} - -Value PhysicalLimit::GetDelimiter(ExecutionContext &context, DataChunk &input, const Expression &expr) { - DataChunk limit_chunk; - vector types {expr.return_type}; - auto &allocator = Allocator::Get(context.client); - limit_chunk.Initialize(allocator, types); - ExpressionExecutor limit_executor(context.client, &expr); - auto input_size = input.size(); - input.SetCardinality(1); - limit_executor.Execute(input, limit_chunk); - input.SetCardinality(input_size); - auto limit_value = limit_chunk.GetValue(0, 0); - return limit_value; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp b/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp deleted file mode 100644 index 69794be09..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_limit_percent.cpp +++ /dev/null @@ -1,166 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_limit_percent.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/helper/physical_limit.hpp" - -namespace duckdb { - -PhysicalLimitPercent::PhysicalLimitPercent(vector types, BoundLimitNode limit_val_p, - BoundLimitNode offset_val_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::LIMIT_PERCENT, std::move(types), estimated_cardinality), - limit_val(std::move(limit_val_p)), offset_val(std::move(offset_val_p)) { - D_ASSERT(limit_val.Type() == LimitNodeType::CONSTANT_PERCENTAGE || - limit_val.Type() == LimitNodeType::EXPRESSION_PERCENTAGE); -} -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class LimitPercentGlobalState : public GlobalSinkState { -public: - explicit LimitPercentGlobalState(ClientContext &context, const PhysicalLimitPercent &op) - : current_offset(0), data(context, op.GetTypes()) { - switch (op.limit_val.Type()) { - case LimitNodeType::CONSTANT_PERCENTAGE: -
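// ---------------------------------------------------------------------------
// Aside (illustration only; not part of the deleted DuckDB source): the core
// of PhysicalLimit::HandleOffset above is interval arithmetic — intersect the
// rows covered by the current chunk with the window [OFFSET, OFFSET + LIMIT).
// A minimal standalone sketch of that arithmetic; all names are hypothetical.
#include <algorithm>
#include <cstdint>

struct ChunkSlice {
	uint64_t skip; // rows to drop from the front of this chunk
	uint64_t emit; // rows to pass through after skipping
};

static ChunkSlice SliceForLimit(uint64_t current_offset, uint64_t chunk_size, uint64_t offset, uint64_t limit) {
	const uint64_t window_end = offset + limit; // first row index past the window
	const uint64_t chunk_end = current_offset + chunk_size;
	// intersect [current_offset, chunk_end) with [offset, window_end)
	const uint64_t begin = std::max(current_offset, offset);
	const uint64_t end = std::min(chunk_end, window_end);
	ChunkSlice slice {0, 0};
	if (begin < end) {
		slice.skip = begin - current_offset;
		slice.emit = end - begin;
	}
	return slice;
}
// Example: OFFSET 5 LIMIT 3 over 4-row chunks — the chunk starting at row 4
// yields {skip = 1, emit = 3}; chunks entirely before or after the window
// yield {skip = 0, emit = 0}. The deleted operator realizes the same slice
// with a SelectionVector instead of copying rows.
// ---------------------------------------------------------------------------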
this->limit_percent = op.limit_val.GetConstantPercentage(); - this->is_limit_set = true; - break; - case LimitNodeType::EXPRESSION_PERCENTAGE: - this->is_limit_set = false; - break; - default: - throw InternalException("Unsupported type for limit value in PhysicalLimitPercent"); - } - switch (op.offset_val.Type()) { - case LimitNodeType::CONSTANT_VALUE: - this->offset = op.offset_val.GetConstantValue(); - break; - case LimitNodeType::UNSET: - this->offset = 0; - break; - case LimitNodeType::EXPRESSION_VALUE: - break; - default: - throw InternalException("Unsupported type for offset value in PhysicalLimitPercent"); - } - } - - idx_t current_offset; - double limit_percent; - optional_idx offset; - ColumnDataCollection data; - - bool is_limit_set = false; -}; - -unique_ptr PhysicalLimitPercent::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); - } - -SinkResultType PhysicalLimitPercent::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - D_ASSERT(chunk.size() > 0); - auto &state = input.global_state.Cast(); - auto &limit_percent = state.limit_percent; - auto &offset = state.offset; - - // get the next chunk from the child - if (!state.is_limit_set) { - Value val = PhysicalLimit::GetDelimiter(context, chunk, limit_val.GetPercentageExpression()); - if (!val.IsNull()) { - limit_percent = val.GetValue(); - } else { - limit_percent = 100.0; - } - if (limit_percent < 0.0) { - throw BinderException("Percentage value(%f) can't be negative", limit_percent); - } - state.is_limit_set = true; - } - if (!state.offset.IsValid()) { - Value val = PhysicalLimit::GetDelimiter(context, chunk, offset_val.GetValueExpression()); - if (!val.IsNull()) { - offset = val.GetValue(); - } else { - offset = 0; - } - if (offset.GetIndex() > 1ULL << 62ULL) { - throw BinderException("Max value %lld for LIMIT/OFFSET is %lld", offset.GetIndex(), 1ULL << 62ULL); - } - } - - if (!PhysicalLimit::HandleOffset(chunk, state.current_offset, offset.GetIndex(), NumericLimits::Maximum())) { - return SinkResultType::NEED_MORE_INPUT; - } - - state.data.Append(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class LimitPercentOperatorState : public GlobalSourceState {
public: - explicit LimitPercentOperatorState(const PhysicalLimitPercent &op) : current_offset(0) { - D_ASSERT(op.sink_state); - auto &gstate = op.sink_state->Cast(); - gstate.data.InitializeScan(scan_state); - } - - ColumnDataScanState scan_state; - optional_idx limit; - idx_t current_offset; -}; - -unique_ptr PhysicalLimitPercent::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalLimitPercent::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - auto &state = input.global_state.Cast(); - auto &percent_limit = gstate.limit_percent; - auto &offset = gstate.offset; - auto &limit = state.limit; - auto &current_offset = state.current_offset; - - if (!limit.IsValid()) { - if (!gstate.is_limit_set) { - // no limit value and we have not set limit_percent - // we are running LIMIT % with a subquery over an empty table - D_ASSERT(gstate.data.Count() == 0); - return SourceResultType::FINISHED; - } - idx_t count = gstate.data.Count(); - if (count > 0) { - count += offset.GetIndex(); - } - if
(Value::IsNan(percent_limit) || percent_limit < 0 || percent_limit > 100) { - throw OutOfRangeException("Limit percent out of range, should be between 0% and 100%"); - } - auto limit_percentage = idx_t(percent_limit / 100.0 * double(count)); - if (limit_percentage > count) { - limit = count; - } else { - limit = idx_t(limit_percentage); - } - if (limit == 0) { - return SourceResultType::FINISHED; - } - } - - if (current_offset >= limit.GetIndex()) { - return SourceResultType::FINISHED; - } - if (!gstate.data.Scan(state.scan_state, chunk)) { - return SourceResultType::FINISHED; - } - - PhysicalLimit::HandleOffset(chunk, current_offset, 0, limit.GetIndex()); - - return SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_load.cpp b/src/duckdb/src/execution/operator/helper/physical_load.cpp deleted file mode 100644 index 5f0e7a027..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_load.cpp +++ /dev/null @@ -1,47 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_load.hpp" -#include "duckdb/main/extension_helper.hpp" - -namespace duckdb { - -static void InstallFromRepository(ClientContext &context, const LoadInfo &info) { - ExtensionRepository repository; - if (!info.repository.empty() && info.repo_is_alias) { - auto repository_url = ExtensionRepository::TryGetRepositoryUrl(info.repository); - // This has been checked during bind, so it should not fail here - if (repository_url.empty()) { - throw InternalException("The repository alias failed to resolve"); - } - repository = ExtensionRepository(info.repository, repository_url); - } else if (!info.repository.empty()) { - repository = ExtensionRepository::GetRepositoryByUrl(info.repository); - } - - ExtensionInstallOptions options; - options.force_install = info.load_type == LoadType::FORCE_INSTALL; - options.throw_on_origin_mismatch = true; - options.version = info.version; - options.repository = repository; - - ExtensionHelper::InstallExtension(context, info.filename, options); -} - -SourceResultType PhysicalLoad::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - if (info->load_type == LoadType::INSTALL || info->load_type == LoadType::FORCE_INSTALL) { - if (info->repository.empty()) { - ExtensionInstallOptions options; - options.force_install = info->load_type == LoadType::FORCE_INSTALL; - options.throw_on_origin_mismatch = true; - options.version = info->version; - ExtensionHelper::InstallExtension(context.client, info->filename, options); - } else { - InstallFromRepository(context.client, *info); - } - - } else { - ExtensionHelper::LoadExternalExtension(context.client, info->filename); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp deleted file mode 100644 index 2492a3a89..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_materialized_collector.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_materialized_collector.hpp" - -#include "duckdb/main/materialized_query_result.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -PhysicalMaterializedCollector::PhysicalMaterializedCollector(PreparedStatementData &data, bool parallel) - : PhysicalResultCollector(data), parallel(parallel) { -} - -SinkResultType 
PhysicalMaterializedCollector::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.collection->Append(lstate.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalMaterializedCollector::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (lstate.collection->Count() == 0) { - return SinkCombineResultType::FINISHED; - } - - lock_guard l(gstate.glock); - if (!gstate.collection) { - gstate.collection = std::move(lstate.collection); - } else { - gstate.collection->Combine(*lstate.collection); - } - - return SinkCombineResultType::FINISHED; -} - -unique_ptr PhysicalMaterializedCollector::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - state->context = context.shared_from_this(); - return std::move(state); -} - -unique_ptr PhysicalMaterializedCollector::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); - state->collection = make_uniq(Allocator::DefaultAllocator(), types); - state->collection->InitializeAppend(state->append_state); - return std::move(state); -} - -unique_ptr PhysicalMaterializedCollector::GetResult(GlobalSinkState &state) { - auto &gstate = state.Cast(); - if (!gstate.collection) { - gstate.collection = make_uniq(Allocator::DefaultAllocator(), types); - } - auto result = make_uniq(statement_type, properties, names, std::move(gstate.collection), - gstate.context->GetClientProperties()); - return std::move(result); -} - -bool PhysicalMaterializedCollector::ParallelSink() const { - return parallel; -} - -bool PhysicalMaterializedCollector::SinkOrderDependent() const { - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_pragma.cpp b/src/duckdb/src/execution/operator/helper/physical_pragma.cpp deleted file mode 100644 index 34a5cb6fb..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_pragma.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_pragma.hpp" - -namespace duckdb { - -SourceResultType PhysicalPragma::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - FunctionParameters parameters {info->parameters, info->named_parameters}; - info->function.function(client, parameters); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp b/src/duckdb/src/execution/operator/helper/physical_prepare.cpp deleted file mode 100644 index 784d6ada4..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_prepare.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_prepare.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -SourceResultType PhysicalPrepare::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - - // store the prepared statement in the context - ClientData::Get(client).prepared_statements[name] = prepared; - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp deleted file mode 100644 index 32785a7ab..000000000 --- 
a/src/duckdb/src/execution/operator/helper/physical_reservoir_sample.cpp +++ /dev/null @@ -1,103 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" -#include "duckdb/execution/reservoir_sample.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -class SampleGlobalSinkState : public GlobalSinkState { -public: - explicit SampleGlobalSinkState(Allocator &allocator, SampleOptions &options) { - if (options.is_percentage) { - auto percentage = options.sample_size.GetValue(); - if (percentage == 0) { - return; - } - sample = make_uniq(allocator, percentage, - static_cast(options.seed.GetIndex())); - } else { - auto size = NumericCast(options.sample_size.GetValue()); - if (size == 0) { - return; - } - sample = make_uniq(allocator, size, static_cast(options.seed.GetIndex())); - } - } - - //! The lock for updating the global aggoregate state - //! Also used to update the global sample when percentages are used - mutex lock; - //! The reservoir sample - unique_ptr sample; -}; - -unique_ptr PhysicalReservoirSample::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(Allocator::Get(context), *options); -} - -SinkResultType PhysicalReservoirSample::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &global_state = input.global_state.Cast(); - // Percentage only has a global sample. - lock_guard glock(global_state.lock); - if (!global_state.sample) { - // always gather full thread percentage - auto &allocator = Allocator::Get(context.client); - if (options->is_percentage) { - double percentage = options->sample_size.GetValue(); - if (percentage == 0) { - return SinkResultType::FINISHED; - } - global_state.sample = make_uniq(allocator, percentage, - static_cast(options->seed.GetIndex())); - } else { - idx_t num_samples = options->sample_size.GetValue(); - if (num_samples == 0) { - return SinkResultType::FINISHED; - } - global_state.sample = - make_uniq(allocator, num_samples, static_cast(options->seed.GetIndex())); - } - } - global_state.sample->AddToReservoir(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalReservoirSample::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalReservoirSample::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalReservoirSample::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink = this->sink_state->Cast(); - lock_guard glock(sink.lock); - if (!sink.sample) { - return SourceResultType::FINISHED; - } - auto sample_chunk = sink.sample->GetChunk(); - if (!sample_chunk) { - return SourceResultType::FINISHED; - } - chunk.Move(*sample_chunk); - - return SourceResultType::HAVE_MORE_OUTPUT; -} - -InsertionOrderPreservingMap PhysicalReservoirSample::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Sample Size"] = options->sample_size.ToString() + (options->is_percentage ? 
"%" : " rows"); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_reset.cpp b/src/duckdb/src/execution/operator/helper/physical_reset.cpp deleted file mode 100644 index b6a219a48..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_reset.cpp +++ /dev/null @@ -1,76 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_reset.hpp" - -#include "duckdb/common/string_util.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -void PhysicalReset::ResetExtensionVariable(ExecutionContext &context, DBConfig &config, - ExtensionOption &extension_option) const { - if (extension_option.set_function) { - extension_option.set_function(context.client, scope, extension_option.default_value); - } - if (scope == SetScope::GLOBAL) { - config.ResetOption(name); - } else { - auto &client_config = ClientConfig::GetConfig(context.client); - client_config.set_variables[name] = extension_option.default_value; - } -} - -SourceResultType PhysicalReset::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - if (scope == SetScope::VARIABLE) { - auto &client_config = ClientConfig::GetConfig(context.client); - client_config.ResetUserVariable(name); - return SourceResultType::FINISHED; - } - auto &config = DBConfig::GetConfig(context.client); - config.CheckLock(name); - auto option = DBConfig::GetOptionByName(name); - if (!option) { - // check if this is an extra extension variable - auto entry = config.extension_parameters.find(name); - if (entry == config.extension_parameters.end()) { - Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); - D_ASSERT(entry != config.extension_parameters.end()); - } - ResetExtensionVariable(context, config, entry->second); - return SourceResultType::FINISHED; - } - - // Transform scope - SetScope variable_scope = scope; - if (variable_scope == SetScope::AUTOMATIC) { - if (option->set_local) { - variable_scope = SetScope::SESSION; - } else { - D_ASSERT(option->set_global); - variable_scope = SetScope::GLOBAL; - } - } - - switch (variable_scope) { - case SetScope::GLOBAL: { - if (!option->set_global) { - throw CatalogException("option \"%s\" cannot be reset globally", name); - } - auto &db = DatabaseInstance::GetDatabase(context.client); - config.ResetOption(&db, *option); - break; - } - case SetScope::SESSION: - if (!option->reset_local) { - throw CatalogException("option \"%s\" cannot be reset locally", name); - } - option->reset_local(context.client); - break; - default: - throw InternalException("Unsupported SetScope for variable"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp b/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp deleted file mode 100644 index 5dcb356e7..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_result_collector.cpp +++ /dev/null @@ -1,65 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_result_collector.hpp" - -#include "duckdb/execution/operator/helper/physical_batch_collector.hpp" -#include "duckdb/execution/operator/helper/physical_buffered_batch_collector.hpp" -#include "duckdb/execution/operator/helper/physical_materialized_collector.hpp" -#include "duckdb/execution/operator/helper/physical_buffered_collector.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" 
-#include "duckdb/main/config.hpp" -#include "duckdb/main/prepared_statement_data.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/main/query_result.hpp" -#include "duckdb/parallel/pipeline.hpp" - -namespace duckdb { - -PhysicalResultCollector::PhysicalResultCollector(PreparedStatementData &data) - : PhysicalOperator(PhysicalOperatorType::RESULT_COLLECTOR, {LogicalType::BOOLEAN}, 0), - statement_type(data.statement_type), properties(data.properties), plan(*data.plan), names(data.names) { - this->types = data.types; -} - -unique_ptr PhysicalResultCollector::GetResultCollector(ClientContext &context, - PreparedStatementData &data) { - if (!PhysicalPlanGenerator::PreserveInsertionOrder(context, *data.plan)) { - // the plan is not order preserving, so we just use the parallel materialized collector - if (data.is_streaming) { - return make_uniq_base(data, true); - } - return make_uniq_base(data, true); - } else if (!PhysicalPlanGenerator::UseBatchIndex(context, *data.plan)) { - // the plan is order preserving, but we cannot use the batch index: use a single-threaded result collector - if (data.is_streaming) { - return make_uniq_base(data, false); - } - return make_uniq_base(data, false); - } else { - // we care about maintaining insertion order and the sources all support batch indexes - // use a batch collector - if (data.is_streaming) { - return make_uniq_base(data); - } - return make_uniq_base(data); - } -} - -vector> PhysicalResultCollector::GetChildren() const { - return {plan}; -} - -void PhysicalResultCollector::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // operator is a sink, build a pipeline - sink_state.reset(); - - D_ASSERT(children.empty()); - - // single operator: the operator becomes the data source of the current pipeline - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - // we create a new pipeline starting from the child - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(plan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_set.cpp b/src/duckdb/src/execution/operator/helper/physical_set.cpp deleted file mode 100644 index 4a321a52e..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_set.cpp +++ /dev/null @@ -1,75 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_set.hpp" - -#include "duckdb/common/string_util.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -void PhysicalSet::SetExtensionVariable(ClientContext &context, ExtensionOption &extension_option, const string &name, - SetScope scope, const Value &value) { - auto &config = DBConfig::GetConfig(context); - auto &target_type = extension_option.type; - Value target_value = value.CastAs(context, target_type); - if (extension_option.set_function) { - extension_option.set_function(context, scope, target_value); - } - if (scope == SetScope::GLOBAL) { - config.SetOption(name, std::move(target_value)); - } else { - auto &client_config = ClientConfig::GetConfig(context); - client_config.set_variables[name] = std::move(target_value); - } -} - -SourceResultType PhysicalSet::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &config = DBConfig::GetConfig(context.client); - // check if we are allowed to change the configuration option - config.CheckLock(name); - auto option = DBConfig::GetOptionByName(name); - if (!option) { - // 
check if this is an extra extension variable - auto entry = config.extension_parameters.find(name); - if (entry == config.extension_parameters.end()) { - Catalog::AutoloadExtensionByConfigName(context.client, name); - entry = config.extension_parameters.find(name); - D_ASSERT(entry != config.extension_parameters.end()); - } - SetExtensionVariable(context.client, entry->second, name, scope, value); - return SourceResultType::FINISHED; - } - SetScope variable_scope = scope; - if (variable_scope == SetScope::AUTOMATIC) { - if (option->set_local) { - variable_scope = SetScope::SESSION; - } else { - D_ASSERT(option->set_global); - variable_scope = SetScope::GLOBAL; - } - } - - Value input_val = value.CastAs(context.client, DBConfig::ParseLogicalType(option->parameter_type)); - switch (variable_scope) { - case SetScope::GLOBAL: { - if (!option->set_global) { - throw CatalogException("option \"%s\" cannot be set globally", name); - } - auto &db = DatabaseInstance::GetDatabase(context.client); - auto &config = DBConfig::GetConfig(context.client); - config.SetOption(&db, *option, input_val); - break; - } - case SetScope::SESSION: - if (!option->set_local) { - throw CatalogException("option \"%s\" cannot be set locally", name); - } - option->set_local(context.client, input_val); - break; - default: - throw InternalException("Unsupported SetScope for variable"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp b/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp deleted file mode 100644 index 37482d2c8..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_set_variable.cpp +++ /dev/null @@ -1,39 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_set_variable.hpp" -#include "duckdb/main/client_config.hpp" - -namespace duckdb { - -PhysicalSetVariable::PhysicalSetVariable(string name_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::SET_VARIABLE, {LogicalType::BOOLEAN}, estimated_cardinality), - name(std::move(name_p)) { -} - -SourceResultType PhysicalSetVariable::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - return SourceResultType::FINISHED; -} - -class SetVariableGlobalState : public GlobalSinkState { -public: - SetVariableGlobalState() { - } - - bool is_set = false; -}; - -unique_ptr PhysicalSetVariable::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - -SinkResultType PhysicalSetVariable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - if (chunk.size() != 1 || gstate.is_set) { - throw InvalidInputException("PhysicalSetVariable can only handle a single value"); - } - auto &config = ClientConfig::GetConfig(context.client); - config.SetUserVariable(name, chunk.GetValue(0, 0)); - gstate.is_set = true; - return SinkResultType::NEED_MORE_INPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp b/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp deleted file mode 100644 index 881375b95..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_streaming_limit.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" -#include "duckdb/execution/operator/helper/physical_limit.hpp" - -namespace duckdb { - 
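// ---------------------------------------------------------------------------
// Aside (illustration only; not part of the deleted DuckDB source): both
// PhysicalSet and PhysicalReset above resolve SetScope::AUTOMATIC the same
// way — session scope when the option has a local setter, global otherwise —
// and then reject the resolved scope if the matching setter is missing. A
// hypothetical distillation of that rule:
enum class Scope { AUTOMATIC, SESSION, GLOBAL };

static Scope ResolveScope(Scope requested, bool has_local_setter) {
	if (requested != Scope::AUTOMATIC) {
		return requested; // caller asked for an explicit scope
	}
	// AUTOMATIC resolves to SESSION when a local setter exists, else GLOBAL
	return has_local_setter ? Scope::SESSION : Scope::GLOBAL;
}
// In the deleted operators, applying the resolved scope to an option without
// the corresponding setter raises a CatalogException ("cannot be set
// globally" / "cannot be set locally").
// ---------------------------------------------------------------------------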
-PhysicalStreamingLimit::PhysicalStreamingLimit(vector types, BoundLimitNode limit_val_p, - BoundLimitNode offset_val_p, idx_t estimated_cardinality, bool parallel) - : PhysicalOperator(PhysicalOperatorType::STREAMING_LIMIT, std::move(types), estimated_cardinality), - limit_val(std::move(limit_val_p)), offset_val(std::move(offset_val_p)), parallel(parallel) { -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class StreamingLimitOperatorState : public OperatorState { -public: - explicit StreamingLimitOperatorState(const PhysicalStreamingLimit &op) { - PhysicalLimit::SetInitialLimits(op.limit_val, op.offset_val, limit, offset); - } - - optional_idx limit; - optional_idx offset; -}; - -class StreamingLimitGlobalState : public GlobalOperatorState { -public: - StreamingLimitGlobalState() : current_offset(0) { - } - - std::atomic current_offset; -}; - -unique_ptr PhysicalStreamingLimit::GetOperatorState(ExecutionContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalStreamingLimit::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalStreamingLimit::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - auto &limit = state.limit; - auto &offset = state.offset; - idx_t current_offset = gstate.current_offset.fetch_add(input.size()); - idx_t max_element; - if (!PhysicalLimit::ComputeOffset(context, input, limit, offset, current_offset, max_element, limit_val, - offset_val)) { - return OperatorResultType::FINISHED; - } - if (PhysicalLimit::HandleOffset(input, current_offset, offset.GetIndex(), limit.GetIndex())) { - chunk.Reference(input); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -OrderPreservationType PhysicalStreamingLimit::OperatorOrder() const { - return OrderPreservationType::FIXED_ORDER; -} - -bool PhysicalStreamingLimit::ParallelOperator() const { - return parallel; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp b/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp deleted file mode 100644 index 309256244..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_streaming_sample.cpp +++ /dev/null @@ -1,77 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" -#include "duckdb/common/random_engine.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -PhysicalStreamingSample::PhysicalStreamingSample(vector types, SampleMethod method, double percentage, - int64_t seed, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::STREAMING_SAMPLE, std::move(types), estimated_cardinality), method(method), - percentage(percentage / 100), seed(seed) { -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class StreamingSampleOperatorState : public OperatorState { -public: - explicit StreamingSampleOperatorState(int64_t seed) : random(seed) { - } - - RandomEngine random; -}; - -void PhysicalStreamingSample::SystemSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { - // system sampling: we throw one dice 
per chunk - auto &state = state_p.Cast(); - double rand = state.random.NextRandom(); - if (rand <= percentage) { - // rand is smaller than sample_size: output chunk - result.Reference(input); - } -} - -void PhysicalStreamingSample::BernoulliSample(DataChunk &input, DataChunk &result, OperatorState &state_p) const { - // bernoulli sampling: we throw one dice per tuple - // then slice the result chunk - auto &state = state_p.Cast(); - idx_t result_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < input.size(); i++) { - double rand = state.random.NextRandom(); - if (rand <= percentage) { - sel.set_index(result_count++, i); - } - } - if (result_count > 0) { - result.Slice(input, sel, result_count); - } -} - -unique_ptr PhysicalStreamingSample::GetOperatorState(ExecutionContext &context) const { - return make_uniq(seed); -} - -OperatorResultType PhysicalStreamingSample::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - switch (method) { - case SampleMethod::BERNOULLI_SAMPLE: - BernoulliSample(input, chunk, state); - break; - case SampleMethod::SYSTEM_SAMPLE: - SystemSample(input, chunk, state); - break; - default: - throw InternalException("Unsupported sample method for streaming sample"); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -InsertionOrderPreservingMap PhysicalStreamingSample::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Sample Method"] = EnumUtil::ToString(method) + ": " + to_string(100 * percentage) + "%"; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp b/src/duckdb/src/execution/operator/helper/physical_transaction.cpp deleted file mode 100644 index a8ad410db..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_transaction.cpp +++ /dev/null @@ -1,79 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_transaction.hpp" - -#include "duckdb/common/exception/transaction_exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/main/valid_checker.hpp" -#include "duckdb/transaction/meta_transaction.hpp" -#include "duckdb/transaction/transaction_manager.hpp" - -namespace duckdb { - -SourceResultType PhysicalTransaction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &client = context.client; - - auto type = info->type; - if (type == TransactionType::COMMIT && ValidChecker::IsInvalidated(client.ActiveTransaction())) { - // transaction is invalidated - turn COMMIT into ROLLBACK - type = TransactionType::ROLLBACK; - } - switch (type) { - case TransactionType::BEGIN_TRANSACTION: { - if (client.transaction.IsAutoCommit()) { - // start the active transaction - // if autocommit is active, we have already called - // BeginTransaction by setting autocommit to false we - // prevent it from being closed after this query, hence - // preserving the transaction context for the next query - client.transaction.SetAutoCommit(false); - auto &config = DBConfig::GetConfig(context.client); - if (info->modifier == TransactionModifierType::TRANSACTION_READ_ONLY) { - client.transaction.SetReadOnly(); - } - if (config.options.immediate_transaction_mode) { - // if immediate transaction mode is enabled then start all transactions immediately - auto databases = DatabaseManager::Get(client).GetDatabases(client); - for (auto 
db : databases) { - context.client.transaction.ActiveTransaction().GetTransaction(db.get()); - } - } - } else { - throw TransactionException("cannot start a transaction within a transaction"); - } - break; - } - case TransactionType::COMMIT: { - if (client.transaction.IsAutoCommit()) { - throw TransactionException("cannot commit - no transaction is active"); - } else { - // explicitly commit the current transaction - client.transaction.Commit(); - } - break; - } - case TransactionType::ROLLBACK: { - if (client.transaction.IsAutoCommit()) { - throw TransactionException("cannot rollback - no transaction is active"); - } else { - // Explicitly rollback the current transaction - // If it is because of an invalidated transaction, we need to rollback with an error - auto &valid_checker = ValidChecker::Get(client.transaction.ActiveTransaction()); - if (valid_checker.IsInvalidated()) { - ErrorData error(ExceptionType::TRANSACTION, valid_checker.InvalidatedMessage()); - client.transaction.Rollback(error); - } else { - client.transaction.Rollback(nullptr); - } - } - break; - } - default: - throw NotImplementedException("Unrecognized transaction type!"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp b/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp deleted file mode 100644 index d29632d12..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_update_extensions.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_update_extensions.hpp" -#include "duckdb/main/extension_helper.hpp" - -namespace duckdb { - -SourceResultType PhysicalUpdateExtensions::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &data = input.global_state.Cast(); - - if (data.offset >= data.update_result_entries.size()) { - // finished returning values - return SourceResultType::FINISHED; - } - - idx_t count = 0; - while (data.offset < data.update_result_entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.update_result_entries[data.offset]; - - // return values: - idx_t col = 0; - // extension_name LogicalType::VARCHAR - chunk.SetValue(col++, count, Value(entry.extension_name)); - // repository LogicalType::VARCHAR - chunk.SetValue(col++, count, Value(entry.repository)); - // update_result - chunk.SetValue(col++, count, Value(EnumUtil::ToString(entry.tag))); - // previous_version LogicalType::VARCHAR - chunk.SetValue(col++, count, Value(entry.prev_version)); - // current_version LogicalType::VARCHAR - chunk.SetValue(col++, count, Value(entry.installed_version)); - - data.offset++; - count++; - } - chunk.SetCardinality(count); - - return data.offset >= data.update_result_entries.size() ? 
SourceResultType::FINISHED - : SourceResultType::HAVE_MORE_OUTPUT; -} - -unique_ptr PhysicalUpdateExtensions::GetGlobalSourceState(ClientContext &context) const { - auto res = make_uniq(); - - if (info->extensions_to_update.empty()) { - // Update all - res->update_result_entries = ExtensionHelper::UpdateExtensions(context); - } else { - // Update extensions in extensions_to_update - for (const auto &ext : info->extensions_to_update) { - res->update_result_entries.emplace_back(ExtensionHelper::UpdateExtension(context, ext)); - } - } - - return std::move(res); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp b/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp deleted file mode 100644 index 345745df8..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_vacuum.cpp +++ /dev/null @@ -1,107 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_vacuum.hpp" - -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/statistics/distinct_statistics.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" - -namespace duckdb { - -PhysicalVacuum::PhysicalVacuum(unique_ptr info_p, optional_ptr table, - unordered_map column_id_map, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::VACUUM, {LogicalType::BOOLEAN}, estimated_cardinality), - info(std::move(info_p)), table(table), column_id_map(std::move(column_id_map)) { -} - -class VacuumLocalSinkState : public LocalSinkState { -public: - explicit VacuumLocalSinkState(VacuumInfo &info, optional_ptr table) : hashes(LogicalType::HASH) { - for (const auto &column_name : info.columns) { - auto &column = table->GetColumn(column_name); - if (DistinctStatistics::TypeIsSupported(column.GetType())) { - column_distinct_stats.push_back(make_uniq()); - } else { - column_distinct_stats.push_back(nullptr); - } - } - }; - - vector> column_distinct_stats; - Vector hashes; -}; - -unique_ptr PhysicalVacuum::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(*info, table); -} - -class VacuumGlobalSinkState : public GlobalSinkState { -public: - explicit VacuumGlobalSinkState(VacuumInfo &info, optional_ptr table) { - for (const auto &column_name : info.columns) { - auto &column = table->GetColumn(column_name); - if (DistinctStatistics::TypeIsSupported(column.GetType())) { - column_distinct_stats.push_back(make_uniq()); - } else { - column_distinct_stats.push_back(nullptr); - } - } - }; - - mutex stats_lock; - vector> column_distinct_stats; -}; - -unique_ptr PhysicalVacuum::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*info, table); -} - -SinkResultType PhysicalVacuum::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - D_ASSERT(lstate.column_distinct_stats.size() == column_id_map.size()); - - for (idx_t col_idx = 0; col_idx < chunk.data.size(); col_idx++) { - if (!DistinctStatistics::TypeIsSupported(chunk.data[col_idx].GetType())) { - continue; - } - lstate.column_distinct_stats[col_idx]->Update(chunk.data[col_idx], chunk.size(), lstate.hashes); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalVacuum::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &g_state = input.global_state.Cast(); - auto &l_state = input.local_state.Cast(); - - lock_guard lock(g_state.stats_lock); - 
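// ---------------------------------------------------------------------------
// Aside (illustration only; not part of the deleted DuckDB source): this
// Combine, like most sinks in this diff, folds a thread-local state into the
// shared global state under a mutex. A hypothetical minimal shape of that
// protocol, with plain counters standing in for DuckDB's DistinctStatistics:
#include <cstdint>
#include <mutex>
#include <vector>

struct LocalStats {
	std::vector<uint64_t> per_column_counts;
};

struct GlobalStats {
	std::mutex lock;
	std::vector<uint64_t> per_column_counts;

	void Combine(const LocalStats &local) {
		std::lock_guard<std::mutex> guard(lock);
		if (per_column_counts.size() < local.per_column_counts.size()) {
			per_column_counts.resize(local.per_column_counts.size(), 0);
		}
		for (size_t i = 0; i < local.per_column_counts.size(); i++) {
			// merge one worker's partial result into the shared state
			per_column_counts[i] += local.per_column_counts[i];
		}
	}
};
// Combine runs once per worker thread, so the lock is cheap relative to the
// per-chunk Sink work, which needs no synchronization at all.
// ---------------------------------------------------------------------------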
D_ASSERT(g_state.column_distinct_stats.size() == l_state.column_distinct_stats.size()); - - for (idx_t col_idx = 0; col_idx < g_state.column_distinct_stats.size(); col_idx++) { - if (g_state.column_distinct_stats[col_idx]) { - D_ASSERT(l_state.column_distinct_stats[col_idx]); - g_state.column_distinct_stats[col_idx]->Merge(*l_state.column_distinct_stats[col_idx]); - } - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalVacuum::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &sink = input.global_state.Cast(); - - auto tbl = table; - for (idx_t col_idx = 0; col_idx < sink.column_distinct_stats.size(); col_idx++) { - tbl->GetStorage().SetDistinct(column_id_map.at(col_idx), std::move(sink.column_distinct_stats[col_idx])); - } - - return SinkFinalizeType::READY; -} - -SourceResultType PhysicalVacuum::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // NOP - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp b/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp deleted file mode 100644 index ea19d7ca9..000000000 --- a/src/duckdb/src/execution/operator/helper/physical_verify_vector.cpp +++ /dev/null @@ -1,234 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_verify_vector.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" - -namespace duckdb { - -PhysicalVerifyVector::PhysicalVerifyVector(unique_ptr child) - : PhysicalOperator(PhysicalOperatorType::VERIFY_VECTOR, child->types, child->estimated_cardinality) { - children.push_back(std::move(child)); -} - -class VerifyVectorState : public OperatorState { -public: - explicit VerifyVectorState() : const_idx(0) { - } - - idx_t const_idx; -}; - -OperatorResultType VerifyEmitConstantVectors(const DataChunk &input, DataChunk &chunk, OperatorState &state_p) { - auto &state = state_p.Cast(); - D_ASSERT(state.const_idx < input.size()); - - // Ensure that we don't alter the input data while another thread is still using it. 
- DataChunk copied_input; - copied_input.Initialize(Allocator::DefaultAllocator(), input.GetTypes()); - input.Copy(copied_input); - - // emit constant vectors at the current index - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - ConstantVector::Reference(chunk.data[c], copied_input.data[c], state.const_idx, 1); - } - chunk.SetCardinality(1); - state.const_idx++; - if (state.const_idx >= copied_input.size()) { - state.const_idx = 0; - return OperatorResultType::NEED_MORE_INPUT; - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType VerifyEmitDictionaryVectors(const DataChunk &input, DataChunk &chunk, OperatorState &state) { - input.Copy(chunk); - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - Vector::DebugTransformToDictionary(chunk.data[c], chunk.size()); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -struct ConstantOrSequenceInfo { - vector values; - bool is_constant = true; -}; - -OperatorResultType VerifyEmitSequenceVector(const DataChunk &input, DataChunk &chunk, OperatorState &state_p) { - auto &state = state_p.Cast(); - D_ASSERT(state.const_idx < input.size()); - - // find the longest length sequence or constant vector to emit - vector infos; - idx_t max_length = 0; - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - bool can_be_sequence = false; - switch (chunk.data[c].GetType().id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: { - can_be_sequence = true; - break; - } - default: { - break; - } - } - bool can_be_constant = true; - switch (chunk.data[c].GetType().id()) { - case LogicalTypeId::INTERVAL: - can_be_constant = false; - break; - default: - break; - } - ConstantOrSequenceInfo info; - info.is_constant = true; - for (idx_t k = state.const_idx; k < input.size(); k++) { - auto val = input.data[c].GetValue(k); - if (info.values.empty()) { - info.values.push_back(std::move(val)); - } else if (info.is_constant) { - if (!ValueOperations::DistinctFrom(val, info.values[0]) && can_be_constant) { - // found the same value! continue - info.values.push_back(std::move(val)); - continue; - } - // not the same value - can we convert this into a sequence vector? - if (!can_be_sequence) { - break; - } - // we can only convert to a sequence if we have only gathered one value - // otherwise we would have multiple identical values here already - if (info.values.size() > 1) { - break; - } - // cannot create a sequence with null values - if (val.IsNull() || info.values[0].IsNull()) { - break; - } - // check if the increment fits in the target type - // i.e. 
we cannot have a sequence vector with an increment of 200 in `int8_t` - auto increment = hugeint_t(val.GetValue()) - hugeint_t(info.values[0].GetValue()); - bool increment_fits = true; - switch (chunk.data[c].GetType().id()) { - case LogicalTypeId::TINYINT: { - int8_t result; - if (!Hugeint::TryCast(increment, result)) { - increment_fits = false; - } - break; - } - case LogicalTypeId::SMALLINT: { - int16_t result; - if (!Hugeint::TryCast(increment, result)) { - increment_fits = false; - } - break; - } - case LogicalTypeId::INTEGER: { - int32_t result; - if (!Hugeint::TryCast(increment, result)) { - increment_fits = false; - } - break; - } - case LogicalTypeId::BIGINT: { - int64_t result; - if (!Hugeint::TryCast(increment, result)) { - increment_fits = false; - } - break; - } - default: - throw InternalException("Unsupported sequence type"); - } - if (!increment_fits) { - break; - } - info.values.push_back(std::move(val)); - info.is_constant = false; - continue; - } else { - D_ASSERT(info.values.size() >= 2); - // sequence vector - check if this value is on the trajectory - if (val.IsNull()) { - // not on trajectory - this value is null - break; - } - int64_t start = info.values[0].GetValue(); - int64_t increment = info.values[1].GetValue() - start; - int64_t last_value = info.values.back().GetValue(); - if (hugeint_t(val.GetValue()) == hugeint_t(last_value) + hugeint_t(increment)) { - // value still fits in the sequence - info.values.push_back(std::move(val)); - continue; - } - // value no longer fits into the sequence - break - break; - } - } - if (info.values.size() > max_length) { - max_length = info.values.size(); - } - infos.push_back(std::move(info)); - } - // go over each of the columns again and construct either (1) a dictionary vector, or (2) a constant/sequence vector - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - auto &info = infos[c]; - if (info.values.size() != max_length) { - // dictionary vector - SelectionVector sel(max_length); - for (idx_t k = 0; k < max_length; k++) { - sel.set_index(k, state.const_idx + k); - } - chunk.data[c].Slice(input.data[c], sel, max_length); - } else if (info.is_constant) { - // constant vector - chunk.data[c].Reference(info.values[0]); - } else { - // sequence vector - int64_t start = info.values[0].GetValue(); - int64_t increment = info.values[1].GetValue() - start; - chunk.data[c].Sequence(start, increment, max_length); - } - } - chunk.SetCardinality(max_length); - state.const_idx += max_length; - if (state.const_idx >= input.size()) { - state.const_idx = 0; - return OperatorResultType::NEED_MORE_INPUT; - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType VerifyEmitNestedShuffleVector(const DataChunk &input, DataChunk &chunk, OperatorState &state) { - input.Copy(chunk); - for (idx_t c = 0; c < chunk.ColumnCount(); c++) { - Vector::DebugShuffleNestedVector(chunk.data[c], chunk.size()); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalVerifyVector::GetOperatorState(ExecutionContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalVerifyVector::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { -#ifdef DUCKDB_VERIFY_CONSTANT_OPERATOR - return VerifyEmitConstantVectors(input, chunk, state); -#endif -#ifdef DUCKDB_VERIFY_DICTIONARY_OPERATOR - return VerifyEmitDictionaryVectors(input, chunk, state); -#endif -#ifdef DUCKDB_VERIFY_SEQUENCE_OPERATOR - return VerifyEmitSequenceVector(input, 
chunk, state); -#endif -#ifdef DUCKDB_VERIFY_NESTED_SHUFFLE - return VerifyEmitNestedShuffleVector(input, chunk, state); -#endif - throw InternalException("PhysicalVerifyVector created but no verification code present"); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp b/src/duckdb/src/execution/operator/join/outer_join_marker.cpp deleted file mode 100644 index 645c702ad..000000000 --- a/src/duckdb/src/execution/operator/join/outer_join_marker.cpp +++ /dev/null @@ -1,108 +0,0 @@ -#include "duckdb/execution/operator/join/outer_join_marker.hpp" - -namespace duckdb { - -OuterJoinMarker::OuterJoinMarker(bool enabled_p) : enabled(enabled_p), count(0) { -} - -void OuterJoinMarker::Initialize(idx_t count_p) { - if (!enabled) { - return; - } - this->count = count_p; - found_match = make_unsafe_uniq_array_uninitialized(count); - Reset(); -} - -void OuterJoinMarker::Reset() { - if (!enabled) { - return; - } - memset(found_match.get(), 0, sizeof(bool) * count); -} - -void OuterJoinMarker::SetMatch(idx_t position) { - if (!enabled) { - return; - } - D_ASSERT(position < count); - found_match[position] = true; -} - -void OuterJoinMarker::SetMatches(const SelectionVector &sel, idx_t count, idx_t base_idx) { - if (!enabled) { - return; - } - for (idx_t i = 0; i < count; i++) { - auto idx = sel.get_index(i); - auto pos = base_idx + idx; - D_ASSERT(pos < this->count); - found_match[pos] = true; - } -} - -void OuterJoinMarker::ConstructLeftJoinResult(DataChunk &left, DataChunk &result) { - if (!enabled) { - return; - } - D_ASSERT(count == STANDARD_VECTOR_SIZE); - SelectionVector remaining_sel(STANDARD_VECTOR_SIZE); - idx_t remaining_count = 0; - for (idx_t i = 0; i < left.size(); i++) { - if (!found_match[i]) { - remaining_sel.set_index(remaining_count++, i); - } - } - if (remaining_count > 0) { - result.Slice(left, remaining_sel, remaining_count); - for (idx_t idx = left.ColumnCount(); idx < result.ColumnCount(); idx++) { - result.data[idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[idx], true); - } - } -} - -idx_t OuterJoinMarker::MaxThreads() const { - return count / (STANDARD_VECTOR_SIZE * 10ULL); -} - -void OuterJoinMarker::InitializeScan(ColumnDataCollection &data, OuterJoinGlobalScanState &gstate) { - gstate.data = &data; - data.InitializeScan(gstate.global_scan); -} - -void OuterJoinMarker::InitializeScan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate) { - D_ASSERT(gstate.data); - lstate.match_sel.Initialize(STANDARD_VECTOR_SIZE); - gstate.data->InitializeScanChunk(lstate.scan_chunk); -} - -void OuterJoinMarker::Scan(OuterJoinGlobalScanState &gstate, OuterJoinLocalScanState &lstate, DataChunk &result) { - D_ASSERT(gstate.data); - // fill in NULL values for the LHS - while (gstate.data->Scan(gstate.global_scan, lstate.local_scan, lstate.scan_chunk)) { - idx_t result_count = 0; - // figure out which tuples didn't find a match in the RHS - for (idx_t i = 0; i < lstate.scan_chunk.size(); i++) { - if (!found_match[lstate.local_scan.current_row_index + i]) { - lstate.match_sel.set_index(result_count++, i); - } - } - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - idx_t left_column_count = result.ColumnCount() - lstate.scan_chunk.ColumnCount(); - for (idx_t i = 0; i < left_column_count; i++) { - result.data[i].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[i], true); - } - for (idx_t col_idx = left_column_count; 
col_idx < result.ColumnCount(); col_idx++) { - result.data[col_idx].Slice(lstate.scan_chunk.data[col_idx - left_column_count], lstate.match_sel, - result_count); - } - result.SetCardinality(result_count); - return; - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp b/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp deleted file mode 100644 index 123d94161..000000000 --- a/src/duckdb/src/execution/operator/join/perfect_hash_join_executor.cpp +++ /dev/null @@ -1,401 +0,0 @@ -#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" - -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/execution/operator/join/physical_hash_join.hpp" - -namespace duckdb { - -PerfectHashJoinExecutor::PerfectHashJoinExecutor(const PhysicalHashJoin &join_p, JoinHashTable &ht_p) - : join(join_p), ht(ht_p) { -} - -//===--------------------------------------------------------------------===// -// Initialize -//===--------------------------------------------------------------------===// -bool ExtractNumericValue(Value val, hugeint_t &result) { - if (!val.type().IsIntegral()) { - switch (val.type().InternalType()) { - case PhysicalType::INT8: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::INT16: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::INT32: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::INT64: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::INT128: - result = val.GetValueUnsafe(); - break; - case PhysicalType::UINT8: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::UINT16: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::UINT32: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::UINT64: - result = Hugeint::Convert(val.GetValueUnsafe()); - break; - case PhysicalType::UINT128: { - const auto uhugeint_val = val.GetValueUnsafe(); - if (uhugeint_val > NumericCast(NumericLimits::Maximum())) { - return false; - } - result.lower = uhugeint_val.lower; - result.upper = NumericCast(uhugeint_val.upper); - break; - } - default: - return false; - } - } else { - if (!val.DefaultTryCastAs(LogicalType::HUGEINT)) { - return false; - } - result = val.GetValue(); - } - return true; -} - -bool PerfectHashJoinExecutor::CanDoPerfectHashJoin(const PhysicalHashJoin &op, const Value &min, const Value &max) { - if (perfect_join_statistics.is_build_small) { - return true; // Already true based on static statistics - } - - // We only do this optimization for inner joins with one integer equality condition - const auto key_type = op.conditions[0].left->return_type; - if (op.join_type != JoinType::INNER || op.conditions.size() != 1 || - op.conditions[0].comparison != ExpressionType::COMPARE_EQUAL || !TypeIsInteger(key_type.InternalType())) { - return false; - } - - // We bail out if there are nested types on the RHS - for (auto &type : op.children[1]->types) { - switch (type.InternalType()) { - case PhysicalType::STRUCT: - case PhysicalType::LIST: - case PhysicalType::ARRAY: - return false; - default: - break; - } - } - - // And when the build range is smaller than the threshold - perfect_join_statistics.build_min = min; - perfect_join_statistics.build_max = max; - hugeint_t min_value, max_value; - if (!ExtractNumericValue(perfect_join_statistics.build_min, min_value) || - 
!ExtractNumericValue(perfect_join_statistics.build_max, max_value)) { - return false; - } - if (max_value < min_value) { - return false; // Empty table - } - - hugeint_t build_range; - if (!TrySubtractOperator::Operation(max_value, min_value, build_range)) { - return false; - } - - // The max size our build must have to run the perfect HJ - static constexpr idx_t MAX_BUILD_SIZE = 1048576; - if (build_range > Hugeint::Convert(MAX_BUILD_SIZE)) { - return false; - } - perfect_join_statistics.build_range = NumericCast(build_range); - - // If count is larger than range (duplicates), we bail out - if (ht.Count() > perfect_join_statistics.build_range) { - return false; - } - - perfect_join_statistics.is_build_small = true; - return true; -} - -//===--------------------------------------------------------------------===// -// Build -//===--------------------------------------------------------------------===// -bool PerfectHashJoinExecutor::BuildPerfectHashTable(LogicalType &key_type) { - // First, allocate memory for each build column - auto build_size = perfect_join_statistics.build_range + 1; - for (const auto &type : join.rhs_output_columns.col_types) { - perfect_hash_table.emplace_back(type, build_size); - } - - // and for duplicate checking - bitmap_build_idx = make_unsafe_uniq_array_uninitialized(build_size); - memset(bitmap_build_idx.get(), 0, sizeof(bool) * build_size); // set false - - // Now fill columns with build data - return FullScanHashTable(key_type); -} - -bool PerfectHashJoinExecutor::FullScanHashTable(LogicalType &key_type) { - auto &data_collection = ht.GetDataCollection(); - - // TODO: In a parallel finalize, take an exclusive lock and have each thread do one part of the code below. - Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples - - idx_t key_count = 0; - if (data_collection.ChunkCount() > 0) { - JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); - - // Go through all the blocks and fill in the key addresses - key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); - } - - // Scan the build keys in the hash table - Vector build_vector(key_type, key_count); - data_collection.Gather(tuples_addresses, *FlatVector::IncrementalSelectionVector(), key_count, 0, build_vector, - *FlatVector::IncrementalSelectionVector(), nullptr); - - // Now fill the selection vector using the build keys and create a sequential vector - // TODO: add check for fast path when probe is part of build domain - SelectionVector sel_build(key_count + 1); - SelectionVector sel_tuples(key_count + 1); - bool success = FillSelectionVectorSwitchBuild(build_vector, sel_build, sel_tuples, key_count); - - // early out - if (!success) { - return false; - } - if (unique_keys == perfect_join_statistics.build_range + 1 && !ht.has_null) { - perfect_join_statistics.is_build_dense = true; - } - key_count = unique_keys; // do not consider keys out of the range - - // Full scan the remaining build columns and fill the perfect hash table - const auto build_size = perfect_join_statistics.build_range + 1; - for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { - auto &vector = perfect_hash_table[i]; - const auto output_col_idx = ht.output_columns[i]; - D_ASSERT(vector.GetType() == ht.layout.GetTypes()[output_col_idx]); - if (build_size > STANDARD_VECTOR_SIZE) { - auto &col_mask = FlatVector::Validity(vector); - col_mask.Initialize(build_size); - } -
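The eligibility checks and the dense build above reduce the join to direct array indexing over the key range. A minimal standalone sketch of the same idea, using plain int64 keys and hypothetical names rather than DuckDB's API:

    #include <cstdint>
    #include <vector>

    // Sketch: keys map to dense slots via key - min; a bitmap catches duplicate
    // keys, which (like an over-wide range) force a fall back to a regular hash join.
    struct PerfectTable {
        int64_t min_key = 0;
        std::vector<bool> occupied; // one slot per key in [min, max]
    };

    bool TryBuildPerfectTable(const std::vector<int64_t> &keys, int64_t min_key, int64_t max_key,
                              PerfectTable &out) {
        static constexpr uint64_t MAX_BUILD_SIZE = 1048576;
        if (max_key < min_key) {
            return false; // empty build side
        }
        const uint64_t range = uint64_t(max_key) - uint64_t(min_key);
        if (range > MAX_BUILD_SIZE) {
            return false; // key domain too wide for a dense table
        }
        out.min_key = min_key;
        out.occupied.assign(range + 1, false);
        for (const auto key : keys) {
            const auto slot = uint64_t(key) - uint64_t(min_key);
            if (out.occupied[slot]) {
                return false; // duplicate key: bail out, as above
            }
            out.occupied[slot] = true;
        }
        return true;
    }

A probe then only has to test min_key <= k <= max_key and read occupied[k - min_key], which is what the dictionary-vector probe further below exploits.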
data_collection.Gather(tuples_addresses, sel_tuples, key_count, output_col_idx, vector, sel_build, nullptr); - } - - return true; -} - -bool PerfectHashJoinExecutor::FillSelectionVectorSwitchBuild(Vector &source, SelectionVector &sel_vec, - SelectionVector &seq_sel_vec, idx_t count) { - switch (source.GetType().InternalType()) { - case PhysicalType::INT8: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT16: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT32: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT64: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::INT128: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT8: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT16: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT32: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT64: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - case PhysicalType::UINT128: - return TemplatedFillSelectionVectorBuild(source, sel_vec, seq_sel_vec, count); - default: - throw NotImplementedException("Type not supported for perfect hash join"); - } -} - -template -bool PerfectHashJoinExecutor::TemplatedFillSelectionVectorBuild(Vector &source, SelectionVector &sel_vec, - SelectionVector &seq_sel_vec, idx_t count) { - if (perfect_join_statistics.build_min.IsNull() || perfect_join_statistics.build_max.IsNull()) { - return false; - } - auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); - auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); - UnifiedVectorFormat vector_data; - source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); - // generate the selection vector - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - auto data_idx = vector_data.sel->get_index(i); - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - sel_vec.set_index(sel_idx, idx); - if (bitmap_build_idx[idx]) { - return false; - } else { - bitmap_build_idx[idx] = true; - unique_keys++; - } - seq_sel_vec.set_index(sel_idx++, i); - } - } - return true; -} - -//===--------------------------------------------------------------------===// -// Probe -//===--------------------------------------------------------------------===// -class PerfectHashJoinState : public OperatorState { -public: - PerfectHashJoinState(ClientContext &context, const PhysicalHashJoin &join) : probe_executor(context) { - join_keys.Initialize(Allocator::Get(context), join.condition_types); - for (auto &cond : join.conditions) { - probe_executor.AddExpression(*cond.left); - } - build_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - probe_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - seq_sel_vec.Initialize(STANDARD_VECTOR_SIZE); - } - - DataChunk join_keys; - ExpressionExecutor probe_executor; - SelectionVector build_sel_vec; - SelectionVector probe_sel_vec; - SelectionVector seq_sel_vec; -}; - -unique_ptr 
PerfectHashJoinExecutor::GetOperatorState(ExecutionContext &context) { - auto state = make_uniq(context.client, join); - return std::move(state); -} - -OperatorResultType PerfectHashJoinExecutor::ProbePerfectHashTable(ExecutionContext &context, DataChunk &input, - DataChunk &lhs_output_columns, DataChunk &result, - OperatorState &state_p) { - auto &state = state_p.Cast(); - // keeps track of how many probe keys have a match - idx_t probe_sel_count = 0; - - // fetch the join keys from the chunk - state.join_keys.Reset(); - state.probe_executor.Execute(input, state.join_keys); - // select the keys that are in the min-max range - auto &keys_vec = state.join_keys.data[0]; - auto keys_count = state.join_keys.size(); - // TODO: add check for fast path when probe is part of build domain - FillSelectionVectorSwitchProbe(keys_vec, state.build_sel_vec, state.probe_sel_vec, keys_count, probe_sel_count); - - // If build is dense and probe is in build's domain, just reference probe - if (perfect_join_statistics.is_build_dense && keys_count == probe_sel_count) { - result.Reference(lhs_output_columns); - } else { - // otherwise, filter out the values that do not match - result.Slice(lhs_output_columns, state.probe_sel_vec, probe_sel_count, 0); - } - // on the build side, we need to fetch the data and build dictionary vectors with the sel_vec - for (idx_t i = 0; i < join.rhs_output_columns.col_types.size(); i++) { - auto &result_vector = result.data[lhs_output_columns.ColumnCount() + i]; - D_ASSERT(result_vector.GetType() == ht.layout.GetTypes()[ht.output_columns[i]]); - auto &build_vec = perfect_hash_table[i]; - result_vector.Reference(build_vec); - result_vector.Slice(state.build_sel_vec, probe_sel_count); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -void PerfectHashJoinExecutor::FillSelectionVectorSwitchProbe(Vector &source, SelectionVector &build_sel_vec, - SelectionVector &probe_sel_vec, idx_t count, - idx_t &probe_sel_count) { - switch (source.GetType().InternalType()) { - case PhysicalType::INT8: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT16: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT32: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT64: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::INT128: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT8: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT16: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT32: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT64: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - case PhysicalType::UINT128: - TemplatedFillSelectionVectorProbe(source, build_sel_vec, probe_sel_vec, count, probe_sel_count); - break; - default: - throw NotImplementedException("Type not supported"); - } -} - -template -void PerfectHashJoinExecutor::TemplatedFillSelectionVectorProbe(Vector &source,
SelectionVector &build_sel_vec, - SelectionVector &probe_sel_vec, idx_t count, - idx_t &probe_sel_count) { - auto min_value = perfect_join_statistics.build_min.GetValueUnsafe(); - auto max_value = perfect_join_statistics.build_max.GetValueUnsafe(); - - UnifiedVectorFormat vector_data; - source.ToUnifiedFormat(count, vector_data); - auto data = reinterpret_cast(vector_data.data); - auto validity_mask = &vector_data.validity; - // build selection vector for non-dense build - if (validity_mask->AllValid()) { - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - // retrieve value from vector - auto data_idx = vector_data.sel->get_index(i); - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { - build_sel_vec.set_index(sel_idx, idx); - probe_sel_vec.set_index(sel_idx++, i); - probe_sel_count++; - } - } - } - } else { - for (idx_t i = 0, sel_idx = 0; i < count; ++i) { - // retrieve value from vector - auto data_idx = vector_data.sel->get_index(i); - if (!validity_mask->RowIsValid(data_idx)) { - continue; - } - auto input_value = data[data_idx]; - // add index to selection vector if value in the range - if (min_value <= input_value && input_value <= max_value) { - auto idx = (idx_t)(input_value - min_value); // subtract min value to get the idx position - // check for matches in the build - if (bitmap_build_idx[idx]) { - build_sel_vec.set_index(sel_idx, idx); - probe_sel_vec.set_index(sel_idx++, i); - probe_sel_count++; - } - } - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp b/src/duckdb/src/execution/operator/join/physical_asof_join.cpp deleted file mode 100644 index 91dda01e6..000000000 --- a/src/duckdb/src/execution/operator/join/physical_asof_join.cpp +++ /dev/null @@ -1,876 +0,0 @@ -#include "duckdb/execution/operator/join/physical_asof_join.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/join/outer_join_marker.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/thread_context.hpp" - -#include <thread> - -namespace duckdb { - -PhysicalAsOfJoin::PhysicalAsOfJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right) - : PhysicalComparisonJoin(op, PhysicalOperatorType::ASOF_JOIN, std::move(op.conditions), op.join_type, - op.estimated_cardinality), - comparison_type(ExpressionType::INVALID) { - - // Convert the conditions into partitions and sorts - for (auto &cond : conditions) { - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - switch (cond.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - null_sensitive.emplace_back(lhs_orders.size()); - lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST,
std::move(left)); - rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - comparison_type = cond.comparison; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - // Always put NULLS LAST so they can be ignored. - null_sensitive.emplace_back(lhs_orders.size()); - lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - comparison_type = cond.comparison; - break; - case ExpressionType::COMPARE_EQUAL: - null_sensitive.emplace_back(lhs_orders.size()); - DUCKDB_EXPLICIT_FALLTHROUGH; - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - lhs_partitions.emplace_back(std::move(left)); - rhs_partitions.emplace_back(std::move(right)); - break; - default: - throw NotImplementedException("Unsupported join condition for ASOF join"); - } - } - D_ASSERT(!lhs_orders.empty()); - D_ASSERT(!rhs_orders.empty()); - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - // Fill out the right projection map. - right_projection_map = op.right_projection_map; - if (right_projection_map.empty()) { - const auto right_count = children[1]->types.size(); - right_projection_map.reserve(right_count); - for (column_t i = 0; i < right_count; ++i) { - right_projection_map.emplace_back(i); - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class AsOfGlobalSinkState : public GlobalSinkState { -public: - AsOfGlobalSinkState(ClientContext &context, const PhysicalAsOfJoin &op) - : rhs_sink(context, op.rhs_partitions, op.rhs_orders, op.children[1]->types, {}, op.estimated_cardinality), - is_outer(IsRightOuterJoin(op.join_type)), has_null(false) { - } - - idx_t Count() const { - return rhs_sink.count; - } - - PartitionLocalSinkState *RegisterBuffer(ClientContext &context) { - lock_guard guard(lock); - lhs_buffers.emplace_back(make_uniq(context, *lhs_sink)); - return lhs_buffers.back().get(); - } - - PartitionGlobalSinkState rhs_sink; - - // One per partition - const bool is_outer; - vector right_outers; - bool has_null; - - // Left side buffering - unique_ptr lhs_sink; - - mutex lock; - vector> lhs_buffers; -}; - -class AsOfLocalSinkState : public LocalSinkState { -public: - explicit AsOfLocalSinkState(ClientContext &context, PartitionGlobalSinkState &gstate_p) - : local_partition(context, gstate_p) { - } - - void Sink(DataChunk &input_chunk) { - local_partition.Sink(input_chunk); - } - - void Combine() { - local_partition.Combine(); - } - - PartitionLocalSinkState local_partition; -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalAsOfJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS - auto &gsink = sink_state->Cast(); - return make_uniq(context.client, gsink.rhs_sink); -} - -SinkResultType PhysicalAsOfJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - - lstate.Sink(chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalAsOfJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.Combine(); - return SinkCombineResultType::FINISHED; 
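To make the inequality searchable, each ASOF comparison fixes a sort direction: >= and > sort ascending so the match for each LHS row is the last RHS key not exceeding the LHS key, while <= and < mirror this with a descending sort. A standalone illustration of that mapping (the enum names here are illustrative, not DuckDB's):

    // Mirrors the switch in the PhysicalAsOfJoin constructor above: inequality
    // conditions choose the sort order; equality conditions become partitions.
    enum class AsOfComparison { GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual };
    enum class SortDirection { Ascending, Descending };

    SortDirection AsOfSortDirection(AsOfComparison cmp) {
        switch (cmp) {
        case AsOfComparison::GreaterThan:
        case AsOfComparison::GreaterThanOrEqual:
            return SortDirection::Ascending; // match = greatest RHS key <= LHS key
        case AsOfComparison::LessThan:
        case AsOfComparison::LessThanOrEqual:
        default:
            return SortDirection::Descending; // mirrored predecessor lookup
        }
    }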
-} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalAsOfJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - - // The data is all in so we can initialise the left partitioning. - const vector> partitions_stats; - gstate.lhs_sink = make_uniq(context, lhs_partitions, lhs_orders, children[0]->types, - partitions_stats, 0U); - gstate.lhs_sink->SyncPartitioning(gstate.rhs_sink); - - // Find the first group to sort - if (!gstate.rhs_sink.HasMergeTasks() && EmptyResultIfRHSIsEmpty()) { - // Empty input! - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Schedule all the sorts for maximum thread utilisation - auto new_event = make_shared_ptr(gstate.rhs_sink, pipeline, *this); - event.InsertEvent(std::move(new_event)); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class AsOfGlobalState : public GlobalOperatorState { -public: - explicit AsOfGlobalState(AsOfGlobalSinkState &gsink) { - // for FULL/RIGHT OUTER JOIN, initialize right_outers to false for every tuple - auto &rhs_partition = gsink.rhs_sink; - auto &right_outers = gsink.right_outers; - right_outers.reserve(rhs_partition.hash_groups.size()); - for (const auto &hash_group : rhs_partition.hash_groups) { - right_outers.emplace_back(OuterJoinMarker(gsink.is_outer)); - right_outers.back().Initialize(hash_group->count); - } - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalOperatorState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); -} - -class AsOfLocalState : public CachingOperatorState { -public: - AsOfLocalState(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), lhs_executor(context), - left_outer(IsLeftOuterJoin(op.join_type)), fetch_next_left(true) { - lhs_keys.Initialize(allocator, op.join_key_types); - for (const auto &cond : op.conditions) { - lhs_executor.AddExpression(*cond.left); - } - - lhs_payload.Initialize(allocator, op.children[0]->types); - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); - - auto &gsink = op.sink_state->Cast(); - lhs_partition_sink = gsink.RegisterBuffer(context); - } - - bool Sink(DataChunk &input); - OperatorResultType ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk); - - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; - - ExpressionExecutor lhs_executor; - DataChunk lhs_keys; - ValidityMask lhs_valid_mask; - SelectionVector lhs_sel; - DataChunk lhs_payload; - - OuterJoinMarker left_outer; - bool fetch_next_left; - - optional_ptr lhs_partition_sink; -}; - -bool AsOfLocalState::Sink(DataChunk &input) { - // Compute the join keys - lhs_keys.Reset(); - lhs_executor.Execute(input, lhs_keys); - lhs_keys.Flatten(); - - // Combine the NULLs - const auto count = input.size(); - lhs_valid_mask.Reset(); - for (auto col_idx : op.null_sensitive) { - auto &col = lhs_keys.data[col_idx]; - UnifiedVectorFormat unified; - col.ToUnifiedFormat(count, unified); - lhs_valid_mask.Combine(unified.validity, count); - } - - // Convert the mask to a selection vector - // and mark all the rows that cannot match for 
early return. - idx_t lhs_valid = 0; - const auto entry_count = lhs_valid_mask.EntryCount(count); - idx_t base_idx = 0; - left_outer.Reset(); - for (idx_t entry_idx = 0; entry_idx < entry_count;) { - const auto validity_entry = lhs_valid_mask.GetValidityEntry(entry_idx++); - const auto next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - for (; base_idx < next; ++base_idx) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - base_idx = next; - } else { - const auto start = base_idx; - for (; base_idx < next; ++base_idx) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - lhs_sel.set_index(lhs_valid++, base_idx); - left_outer.SetMatch(base_idx); - } - } - } - } - - // Slice the keys to the ones we can match - lhs_payload.Reset(); - if (lhs_valid == count) { - lhs_payload.Reference(input); - lhs_payload.SetCardinality(input); - } else { - lhs_payload.Slice(input, lhs_sel, lhs_valid); - lhs_payload.SetCardinality(lhs_valid); - - // Flush the ones that can't match - fetch_next_left = false; - } - - lhs_partition_sink->Sink(lhs_payload); - - return false; -} - -OperatorResultType AsOfLocalState::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk) { - input.Verify(); - Sink(input); - - // If there were any unmatchable rows, return them now so we can forget about them. - if (!fetch_next_left) { - fetch_next_left = true; - left_outer.ConstructLeftJoinResult(input, chunk); - left_outer.Reset(); - } - - // Just keep asking for data and buffering it - return OperatorResultType::NEED_MORE_INPUT; -} - -OperatorResultType PhysicalAsOfJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &lstate_p) const { - auto &gsink = sink_state->Cast(); - auto &lstate = lstate_p.Cast(); - - if (gsink.rhs_sink.count == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gsink.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - return lstate.ExecuteInternal(context, input, chunk); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class AsOfProbeBuffer { -public: - using Orders = vector; - - static bool IsExternal(ClientContext &context) { - return ClientConfig::GetConfig(context).force_external; - } - - AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op); - -public: - void ResolveJoin(bool *found_matches, idx_t *matches = nullptr); - bool Scanning() const { - return lhs_scanner.get(); - } - void BeginLeftScan(hash_t scan_bin); - bool NextLeft(); - void EndScan(); - - // resolve joins that output max N elements (SEMI, ANTI, MARK) - void ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk); - // resolve joins that can potentially output N*M elements (INNER, LEFT, FULL) - void ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk); - // Chunk may be empty - void GetData(ExecutionContext &context, DataChunk &chunk); - bool HasMoreData() const { - return !fetch_next_left || (lhs_scanner && lhs_scanner->Remaining()); - } - - ClientContext &context; - Allocator &allocator; - const PhysicalAsOfJoin &op; - BufferManager &buffer_manager; - const bool force_external; - const idx_t 
memory_per_thread; - Orders lhs_orders; - - // LHS scanning - SelectionVector lhs_sel; - optional_ptr left_hash; - OuterJoinMarker left_outer; - unique_ptr left_itr; - unique_ptr lhs_scanner; - DataChunk lhs_payload; - - // RHS scanning - optional_ptr right_hash; - optional_ptr right_outer; - unique_ptr right_itr; - unique_ptr rhs_scanner; - DataChunk rhs_payload; - - idx_t lhs_match_count; - bool fetch_next_left; -}; - -AsOfProbeBuffer::AsOfProbeBuffer(ClientContext &context, const PhysicalAsOfJoin &op) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(IsExternal(context)), - memory_per_thread(op.GetMaxThreadMemory(context)), left_outer(IsLeftOuterJoin(op.join_type)), - fetch_next_left(true) { - vector> partition_stats; - Orders partitions; // Not used. - PartitionGlobalSinkState::GenerateOrderings(partitions, lhs_orders, op.lhs_partitions, op.lhs_orders, - partition_stats); - - // We sort the row numbers of the incoming block, not the rows - lhs_payload.Initialize(allocator, op.children[0]->types); - rhs_payload.Initialize(allocator, op.children[1]->types); - - lhs_sel.Initialize(); - left_outer.Initialize(STANDARD_VECTOR_SIZE); -} - -void AsOfProbeBuffer::BeginLeftScan(hash_t scan_bin) { - auto &gsink = op.sink_state->Cast(); - auto &lhs_sink = *gsink.lhs_sink; - const auto left_group = lhs_sink.bin_groups[scan_bin]; - if (left_group >= lhs_sink.bin_groups.size()) { - return; - } - - auto iterator_comp = ExpressionType::INVALID; - switch (op.comparison_type) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_GREATERTHAN: - iterator_comp = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - iterator_comp = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHAN: - iterator_comp = ExpressionType::COMPARE_GREATERTHAN; - break; - default: - throw NotImplementedException("Unsupported comparison type for ASOF join"); - } - - left_hash = lhs_sink.hash_groups[left_group].get(); - auto &left_sort = *(left_hash->global_sort); - if (left_sort.sorted_blocks.empty()) { - return; - } - lhs_scanner = make_uniq(left_sort, false); - left_itr = make_uniq(left_sort, iterator_comp); - - // We are only probing the corresponding right side bin, which may be empty - // If they are empty, we leave the iterator as null so we can emit left matches - auto &rhs_sink = gsink.rhs_sink; - const auto right_group = rhs_sink.bin_groups[scan_bin]; - if (right_group < rhs_sink.bin_groups.size()) { - right_hash = rhs_sink.hash_groups[right_group].get(); - right_outer = gsink.right_outers.data() + right_group; - auto &right_sort = *(right_hash->global_sort); - right_itr = make_uniq(right_sort, iterator_comp); - rhs_scanner = make_uniq(right_sort, false); - } -} - -bool AsOfProbeBuffer::NextLeft() { - if (!HasMoreData()) { - return false; - } - - // Scan the next sorted chunk - lhs_payload.Reset(); - left_itr->SetIndex(lhs_scanner->Scanned()); - lhs_scanner->Scan(lhs_payload); - - return true; -} - -void AsOfProbeBuffer::EndScan() { - right_hash = nullptr; - right_itr.reset(); - rhs_scanner.reset(); - right_outer = nullptr; - - left_hash = nullptr; - left_itr.reset(); - lhs_scanner.reset(); -} - -void AsOfProbeBuffer::ResolveJoin(bool *found_match, idx_t *matches) { - // If there was no right partition, there are no matches - lhs_match_count = 0; - 
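For context, the NULL filtering in AsOfLocalState::Sink above follows the usual validity-mask pattern: process 64 rows per mask entry, with fast paths when an entry is all-valid or none-valid. A simplified standalone sketch (plain vectors and hypothetical names; validity holds one 64-bit entry per 64 rows):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Compact the ids of valid rows, mirroring the all-valid / none-valid
    // fast paths of the entry-wise loop in Sink above.
    std::vector<uint32_t> ValidRows(const std::vector<uint64_t> &validity, uint32_t count) {
        std::vector<uint32_t> sel;
        sel.reserve(count);
        for (uint32_t base = 0; base < count; base += 64) {
            const uint64_t entry = validity[base / 64];
            const uint32_t next = std::min(base + 64, count);
            if (entry == ~uint64_t(0)) { // all valid: take every row
                for (uint32_t i = base; i < next; i++) {
                    sel.push_back(i);
                }
            } else if (entry == 0) { // none valid: skip the whole entry
                continue;
            } else { // mixed: test bit by bit
                for (uint32_t i = base; i < next; i++) {
                    if (entry & (uint64_t(1) << (i - base))) {
                        sel.push_back(i);
                    }
                }
            }
        }
        return sel;
    }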
left_outer.Reset(); - if (!right_itr) { - return; - } - - const auto count = lhs_payload.size(); - const auto left_base = left_itr->GetIndex(); - // Searching for right <= left - for (idx_t i = 0; i < count; ++i) { - left_itr->SetIndex(left_base + i); - - // If right > left, then there is no match - if (!right_itr->Compare(*left_itr)) { - continue; - } - - // Exponential search forward for a non-matching value using radix iterators - // (We use exponential search to avoid thrashing the block manager on large probes) - idx_t bound = 1; - idx_t begin = right_itr->GetIndex(); - right_itr->SetIndex(begin + bound); - while (right_itr->GetIndex() < right_hash->count) { - if (right_itr->Compare(*left_itr)) { - // If right <= left, jump ahead - bound *= 2; - right_itr->SetIndex(begin + bound); - } else { - break; - } - } - - // Binary search for the first non-matching value using radix iterators - // The previous value (which we know exists) is the match - auto first = begin + bound / 2; - auto last = MinValue(begin + bound, right_hash->count); - while (first < last) { - const auto mid = first + (last - first) / 2; - right_itr->SetIndex(mid); - if (right_itr->Compare(*left_itr)) { - // If right <= left, new lower bound - first = mid + 1; - } else { - last = mid; - } - } - right_itr->SetIndex(--first); - - // Check partitions for strict equality - if (right_hash->ComparePartitions(*left_itr, *right_itr)) { - continue; - } - - // Emit match data - right_outer->SetMatch(first); - left_outer.SetMatch(i); - if (found_match) { - found_match[i] = true; - } - if (matches) { - matches[i] = first; - } - lhs_sel.set_index(lhs_match_count++, i); - } -} - -unique_ptr PhysicalAsOfJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void AsOfProbeBuffer::ResolveSimpleJoin(ExecutionContext &context, DataChunk &chunk) { - // perform the actual join - bool found_match[STANDARD_VECTOR_SIZE] = {false}; - ResolveJoin(found_match); - - // now construct the result based on the join result - switch (op.join_type) { - case JoinType::SEMI: - PhysicalJoin::ConstructSemiJoinResult(lhs_payload, chunk, found_match); - break; - case JoinType::ANTI: - PhysicalJoin::ConstructAntiJoinResult(lhs_payload, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented join type for AsOf join"); - } -} - -void AsOfProbeBuffer::ResolveComplexJoin(ExecutionContext &context, DataChunk &chunk) { - // perform the actual join - idx_t matches[STANDARD_VECTOR_SIZE]; - ResolveJoin(nullptr, matches); - - for (idx_t i = 0; i < lhs_match_count; ++i) { - const auto idx = lhs_sel[i]; - const auto match_pos = matches[idx]; - // Skip to the range containing the match - while (match_pos >= rhs_scanner->Scanned()) { - rhs_payload.Reset(); - rhs_scanner->Scan(rhs_payload); - } - // Append the individual values - // TODO: Batch the copies - const auto source_offset = match_pos - (rhs_scanner->Scanned() - rhs_payload.size()); - for (column_t col_idx = 0; col_idx < op.right_projection_map.size(); ++col_idx) { - const auto rhs_idx = op.right_projection_map[col_idx]; - auto &source = rhs_payload.data[rhs_idx]; - auto &target = chunk.data[lhs_payload.ColumnCount() + col_idx]; - VectorOperations::Copy(source, target, source_offset + 1, source_offset, i); - } - } - - // Slice the left payload into the result - for (column_t i = 0; i < lhs_payload.ColumnCount(); ++i) { - chunk.data[i].Slice(lhs_payload.data[i], lhs_sel, lhs_match_count); - } - chunk.SetCardinality(lhs_match_count); - 
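The search above gallops: it doubles a bound until the comparison fails, then binary-searches the final interval, so a large probe touches O(log n) positions instead of scanning. The same predecessor lookup over a plain sorted array, as a standalone sketch:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Find the index of the last element <= target, or SIZE_MAX if none exists;
    // exponential ("galloping") phase first, then binary search, as above.
    size_t LastMatch(const std::vector<int> &sorted, int target) {
        if (sorted.empty() || sorted[0] > target) {
            return SIZE_MAX; // nothing matches
        }
        size_t bound = 1;
        while (bound < sorted.size() && sorted[bound] <= target) {
            bound *= 2; // jump ahead while we still match
        }
        size_t first = bound / 2; // known to match
        size_t last = std::min(bound, sorted.size());
        while (first < last) { // find the first non-matching element
            const size_t mid = first + (last - first) / 2;
            if (sorted[mid] <= target) {
                first = mid + 1;
            } else {
                last = mid;
            }
        }
        return first - 1; // the predecessor is the match
    }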
- // If we are doing a left join, come back for the NULLs - fetch_next_left = !left_outer.Enabled(); -} - -void AsOfProbeBuffer::GetData(ExecutionContext &context, DataChunk &chunk) { - // Handle dangling left join results from current chunk - if (!fetch_next_left) { - fetch_next_left = true; - if (left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - left_outer.ConstructLeftJoinResult(lhs_payload, chunk); - left_outer.Reset(); - } - return; - } - - // Stop if there is no more data - if (!NextLeft()) { - return; - } - - switch (op.join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, chunk); - break; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::OUTER: - ResolveComplexJoin(context, chunk); - break; - default: - throw NotImplementedException("Unimplemented type for as-of join!"); - } -} - -class AsOfGlobalSourceState : public GlobalSourceState { -public: - explicit AsOfGlobalSourceState(AsOfGlobalSinkState &gsink_p) - : gsink(gsink_p), next_combine(0), combined(0), merged(0), mergers(0), next_left(0), flushed(0), next_right(0) { - } - - PartitionGlobalMergeStates &GetMergeStates() { - lock_guard guard(lock); - if (!merge_states) { - merge_states = make_uniq(*gsink.lhs_sink); - } - return *merge_states; - } - - AsOfGlobalSinkState &gsink; - //! The next buffer to combine - atomic next_combine; - //! The number of combined buffers - atomic combined; - //! The number of merged buffers - atomic merged; - //! The number of merging threads - atomic mergers; - //! The next buffer to flush - atomic next_left; - //! The number of flushed buffers - atomic flushed; - //! The right outer output read position. - atomic next_right; - //! The merge handler - mutex lock; - unique_ptr merge_states; - -public: - idx_t MaxThreads() override { - return gsink.lhs_buffers.size(); - } -}; - -unique_ptr PhysicalAsOfJoin::GetGlobalSourceState(ClientContext &context) const { - auto &gsink = sink_state->Cast(); - return make_uniq(gsink); -} - -class AsOfLocalSourceState : public LocalSourceState { -public: - using HashGroupPtr = unique_ptr; - - AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, ClientContext &client_p); - - // Return true if we were not interrupted (another thread died) - bool CombineLeftPartitions(); - bool MergeLeftPartitions(); - - idx_t BeginRightScan(const idx_t hash_bin); - - AsOfGlobalSourceState &gsource; - ClientContext &client; - - //! The left side partition being probed - AsOfProbeBuffer probe_buffer; - - //! The read partition - idx_t hash_bin; - HashGroupPtr hash_group; - //! The read cursor - unique_ptr scanner; - //!
Pointer to the matches - const bool *found_match = {}; -}; - -AsOfLocalSourceState::AsOfLocalSourceState(AsOfGlobalSourceState &gsource, const PhysicalAsOfJoin &op, - ClientContext &client_p) - : gsource(gsource), client(client_p), probe_buffer(gsource.gsink.lhs_sink->context, op) { - gsource.mergers++; -} - -bool AsOfLocalSourceState::CombineLeftPartitions() { - const auto buffer_count = gsource.gsink.lhs_buffers.size(); - while (gsource.combined < buffer_count && !client.interrupted) { - const auto next_combine = gsource.next_combine++; - if (next_combine < buffer_count) { - gsource.gsink.lhs_buffers[next_combine]->Combine(); - ++gsource.combined; - } else { - TaskScheduler::GetScheduler(client).YieldThread(); - } - } - - return !client.interrupted; -} - -bool AsOfLocalSourceState::MergeLeftPartitions() { - PartitionGlobalMergeStates::Callback local_callback; - PartitionLocalMergeState local_merge(*gsource.gsink.lhs_sink); - gsource.GetMergeStates().ExecuteTask(local_merge, local_callback); - gsource.merged++; - while (gsource.merged < gsource.mergers && !client.interrupted) { - TaskScheduler::GetScheduler(client).YieldThread(); - } - return !client.interrupted; -} - -idx_t AsOfLocalSourceState::BeginRightScan(const idx_t hash_bin_p) { - hash_bin = hash_bin_p; - - hash_group = std::move(gsource.gsink.rhs_sink.hash_groups[hash_bin]); - if (hash_group->global_sort->sorted_blocks.empty()) { - return 0; - } - scanner = make_uniq(*hash_group->global_sort); - found_match = gsource.gsink.right_outers[hash_bin].GetMatches(); - - return scanner->Remaining(); -} - -unique_ptr PhysicalAsOfJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - auto &gsource = gstate.Cast(); - return make_uniq(gsource, *this, context.client); -} - -SourceResultType PhysicalAsOfJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gsource = input.global_state.Cast(); - auto &lsource = input.local_state.Cast(); - auto &rhs_sink = gsource.gsink.rhs_sink; - auto &client = context.client; - - // Step 1: Combine the partitions - if (!lsource.CombineLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 2: Sort on all threads - if (!lsource.MergeLeftPartitions()) { - return SourceResultType::FINISHED; - } - - // Step 3: Join the partitions - auto &lhs_sink = *gsource.gsink.lhs_sink; - const auto left_bins = lhs_sink.grouping_data ? lhs_sink.grouping_data->GetPartitions().size() : 1; - while (gsource.flushed < left_bins) { - // Make sure we have something to flush - if (!lsource.probe_buffer.Scanning()) { - const auto left_bin = gsource.next_left++; - if (left_bin < left_bins) { - // More to flush - lsource.probe_buffer.BeginLeftScan(left_bin); - } else if (!IsRightOuterJoin(join_type) || client.interrupted) { - return SourceResultType::FINISHED; - } else { - // Wait for all threads to finish - // TODO: How to implement a spin wait correctly? - // Returning BLOCKED seems to hang the system. 
- TaskScheduler::GetScheduler(client).YieldThread(); - continue; - } - } - - lsource.probe_buffer.GetData(context, chunk); - if (chunk.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else if (lsource.probe_buffer.HasMoreData()) { - // Join the next partition - continue; - } else { - lsource.probe_buffer.EndScan(); - gsource.flushed++; - } - } - - // Step 4: Emit right join matches - if (!IsRightOuterJoin(join_type)) { - return SourceResultType::FINISHED; - } - - auto &hash_groups = rhs_sink.hash_groups; - const auto right_groups = hash_groups.size(); - - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), rhs_sink.payload_types); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - - while (chunk.size() == 0) { - // Move to the next bin if we are done. - while (!lsource.scanner || !lsource.scanner->Remaining()) { - lsource.scanner.reset(); - lsource.hash_group.reset(); - auto hash_bin = gsource.next_right++; - if (hash_bin >= right_groups) { - return SourceResultType::FINISHED; - } - - for (; hash_bin < hash_groups.size(); hash_bin = gsource.next_right++) { - if (hash_groups[hash_bin]) { - break; - } - } - lsource.BeginRightScan(hash_bin); - } - const auto rhs_position = lsource.scanner->Scanned(); - lsource.scanner->Scan(rhs_chunk); - - const auto count = rhs_chunk.size(); - if (count == 0) { - return SourceResultType::FINISHED; - } - - // figure out which tuples didn't find a match in the RHS - auto found_match = lsource.found_match; - idx_t result_count = 0; - for (idx_t i = 0; i < count; i++) { - if (!found_match[rhs_position + i]) { - rsel.set_index(result_count++, i); - } - } - - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0]->types.size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - for (idx_t col_idx = 0; col_idx < right_projection_map.size(); ++col_idx) { - const auto rhs_idx = right_projection_map[col_idx]; - chunk.data[left_column_count + col_idx].Slice(rhs_chunk.data[rhs_idx], rsel, result_count); - } - chunk.SetCardinality(result_count); - break; - } - } - - return chunk.size() > 0 ? 
SourceResultType::HAVE_MORE_OUTPUT : SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp b/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp deleted file mode 100644 index deabb6f89..000000000 --- a/src/duckdb/src/execution/operator/join/physical_blockwise_nl_join.cpp +++ /dev/null @@ -1,277 +0,0 @@ -#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/join/outer_join_marker.hpp" -#include "duckdb/execution/operator/join/physical_comparison_join.hpp" -#include "duckdb/execution/operator/join/physical_cross_product.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -PhysicalBlockwiseNLJoin::PhysicalBlockwiseNLJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, unique_ptr condition, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalJoin(op, PhysicalOperatorType::BLOCKWISE_NL_JOIN, join_type, estimated_cardinality), - condition(std::move(condition)) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); - // MARK and SINGLE joins not handled - D_ASSERT(join_type != JoinType::MARK); - D_ASSERT(join_type != JoinType::SINGLE); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinLocalState : public LocalSinkState { -public: - BlockwiseNLJoinLocalState() { - } -}; - -class BlockwiseNLJoinGlobalState : public GlobalSinkState { -public: - explicit BlockwiseNLJoinGlobalState(ClientContext &context, const PhysicalBlockwiseNLJoin &op) - : right_chunks(context, op.children[1]->GetTypes()), right_outer(PropagatesBuildSide(op.join_type)) { - } - - mutex lock; - ColumnDataCollection right_chunks; - OuterJoinMarker right_outer; -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalBlockwiseNLJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(); -} - -SinkResultType PhysicalBlockwiseNLJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - lock_guard nl_lock(gstate.lock); - gstate.right_chunks.Append(chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalBlockwiseNLJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - gstate.right_outer.Initialize(gstate.right_chunks.Count()); - - if (gstate.right_chunks.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinState : public CachingOperatorState { -public: - explicit BlockwiseNLJoinState(ExecutionContext &context, 
ColumnDataCollection &rhs, - const PhysicalBlockwiseNLJoin &op) - : op(op), cross_product(rhs), left_outer(IsLeftOuterJoin(op.join_type)), match_sel(STANDARD_VECTOR_SIZE), - executor(context.client, *op.condition) { - left_outer.Initialize(STANDARD_VECTOR_SIZE); - ResetMatches(); - } - - const PhysicalBlockwiseNLJoin &op; - CrossProductExecutor cross_product; - OuterJoinMarker left_outer; - SelectionVector match_sel; - ExpressionExecutor executor; - DataChunk intermediate_chunk; - bool found_match[STANDARD_VECTOR_SIZE]; - - void ResetMatches() { - if (op.join_type != JoinType::SEMI && op.join_type != JoinType::ANTI) { - return; - } - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; i++) { - found_match[i] = false; - } - } -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetOperatorState(ExecutionContext &context) const { - auto &gstate = sink_state->Cast(); - auto result = make_uniq(context, gstate.right_chunks, *this); - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - vector intermediate_types; - for (auto &type : children[0]->types) { - intermediate_types.emplace_back(type); - } - for (auto &type : children[1]->types) { - intermediate_types.emplace_back(type); - } - result->intermediate_chunk.Initialize(Allocator::DefaultAllocator(), intermediate_types); - } - if (join_type == JoinType::RIGHT_ANTI || join_type == JoinType::RIGHT_SEMI) { - throw NotImplementedException("physical blockwise RIGHT_SEMI/RIGHT_ANTI join not yet implemented"); - } - return std::move(result); -} - -OperatorResultType PhysicalBlockwiseNLJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - D_ASSERT(input.size() > 0); - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - if (gstate.right_chunks.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - PhysicalComparisonJoin::ConstructEmptyJoinResult(join_type, false, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - DataChunk *intermediate_chunk = &chunk; - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - intermediate_chunk = &state.intermediate_chunk; - intermediate_chunk->Reset(); - } - - // now perform the actual join - // we perform a cross product, then execute the expression directly on the cross product result - idx_t result_count = 0; - - auto result = state.cross_product.Execute(input, *intermediate_chunk); - if (result == OperatorResultType::NEED_MORE_INPUT) { - // exhausted input, have to pull new LHS chunk - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - state.left_outer.ConstructLeftJoinResult(input, *intermediate_chunk); - state.left_outer.Reset(); - } - - if (join_type == JoinType::SEMI) { - PhysicalJoin::ConstructSemiJoinResult(input, chunk, state.found_match); - } - if (join_type == JoinType::ANTI) { - PhysicalJoin::ConstructAntiJoinResult(input, chunk, state.found_match); - } - state.ResetMatches(); - - return OperatorResultType::NEED_MORE_INPUT; - } - - // now perform the computation - result_count = state.executor.SelectExpression(*intermediate_chunk, state.match_sel); - - // handle anti and semi joins with different logic - if (result_count > 0) { - // found a match! 
- // handle anti semi join conditions first - if (join_type == JoinType::ANTI || join_type == JoinType::SEMI) { - if (state.cross_product.ScanLHS()) { - state.found_match[state.cross_product.PositionInChunk()] = true; - } else { - for (idx_t i = 0; i < result_count; i++) { - state.found_match[state.match_sel.get_index(i)] = true; - } - } - intermediate_chunk->Reset(); - // trick the loop to continue as semi and anti joins will never produce more output than - // the LHS cardinality - result_count = 0; - } else { - // check if the cross product is scanning the LHS or the RHS in its entirety - if (!state.cross_product.ScanLHS()) { - // set the match flags in the LHS - state.left_outer.SetMatches(state.match_sel, result_count); - // set the match flag in the RHS - gstate.right_outer.SetMatch(state.cross_product.ScanPosition() + state.cross_product.PositionInChunk()); - } else { - // set the match flag in the LHS - state.left_outer.SetMatch(state.cross_product.PositionInChunk()); - // set the match flags in the RHS - gstate.right_outer.SetMatches(state.match_sel, result_count, state.cross_product.ScanPosition()); - } - intermediate_chunk->Slice(state.match_sel, result_count); - } - } else { - // no result: reset the chunk - intermediate_chunk->Reset(); - } - - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -InsertionOrderPreservingMap PhysicalBlockwiseNLJoin::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Join Type"] = EnumUtil::ToString(join_type); - result["Condition"] = condition->GetName(); - return result; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class BlockwiseNLJoinGlobalScanState : public GlobalSourceState { -public: - explicit BlockwiseNLJoinGlobalScanState(const PhysicalBlockwiseNLJoin &op) : op(op) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(sink.right_chunks, scan_state); - } - - const PhysicalBlockwiseNLJoin &op; - OuterJoinGlobalScanState scan_state; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.right_outer.MaxThreads(); - } -}; - -class BlockwiseNLJoinLocalScanState : public LocalSourceState { -public: - explicit BlockwiseNLJoinLocalScanState(const PhysicalBlockwiseNLJoin &op, BlockwiseNLJoinGlobalScanState &gstate) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(gstate.scan_state, scan_state); - } - - OuterJoinLocalScanState scan_state; -}; - -unique_ptr PhysicalBlockwiseNLJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalBlockwiseNLJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, gstate.Cast()); -} - -SourceResultType PhysicalBlockwiseNLJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - D_ASSERT(PropagatesBuildSide(join_type)); - // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output - sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp b/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp deleted file mode 100644 index 1091e2165..000000000 --- a/src/duckdb/src/execution/operator/join/physical_comparison_join.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "duckdb/execution/operator/join/physical_comparison_join.hpp" - -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -PhysicalComparisonJoin::PhysicalComparisonJoin(LogicalOperator &op, PhysicalOperatorType type, - vector conditions_p, JoinType join_type, - idx_t estimated_cardinality) - : PhysicalJoin(op, type, join_type, estimated_cardinality), conditions(std::move(conditions_p)) { - ReorderConditions(conditions); -} - -InsertionOrderPreservingMap PhysicalComparisonJoin::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Join Type"] = EnumUtil::ToString(join_type); - string condition_info; - for (idx_t i = 0; i < conditions.size(); i++) { - auto &join_condition = conditions[i]; - if (i > 0) { - condition_info += "\n"; - } - condition_info += - StringUtil::Format("%s %s %s", join_condition.left->GetName(), - ExpressionTypeToOperator(join_condition.comparison), join_condition.right->GetName()); - // string op = ExpressionTypeToOperator(it.comparison); - // extra_info += it.left->GetName() + " " + op + " " + it.right->GetName() + "\n"; - } - result["Conditions"] = condition_info; - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -void PhysicalComparisonJoin::ReorderConditions(vector &conditions) { - // we reorder conditions so the ones with COMPARE_EQUAL occur first - // check if this is already the case - bool is_ordered = true; - bool seen_non_equal = false; - for (auto &cond : conditions) { - if (cond.comparison == ExpressionType::COMPARE_EQUAL || - cond.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - if (seen_non_equal) { - is_ordered = false; - break; - } - } else { - seen_non_equal = true; - } - } - if (is_ordered) { - // no need to re-order - return; - } - // gather lists of equal/other conditions - vector equal_conditions; - vector other_conditions; - for (auto &cond : conditions) { - if (cond.comparison == ExpressionType::COMPARE_EQUAL || - cond.comparison == ExpressionType::COMPARE_NOT_DISTINCT_FROM) { - equal_conditions.push_back(std::move(cond)); - } else { - other_conditions.push_back(std::move(cond)); - } - } - conditions.clear(); - // reconstruct the sorted conditions - for (auto &cond : equal_conditions) { - conditions.push_back(std::move(cond)); - } - for (auto &cond : other_conditions) { - conditions.push_back(std::move(cond)); - } -} - -void PhysicalComparisonJoin::ConstructEmptyJoinResult(JoinType join_type, bool has_null, DataChunk &input, - DataChunk &result) { - // empty hash table, special case - if (join_type == JoinType::ANTI) { - // anti join with empty hash table, NOP join - // return the input - D_ASSERT(input.ColumnCount() == result.ColumnCount()); - result.Reference(input); - } else if (join_type == JoinType::MARK) { - // MARK join with empty hash table - D_ASSERT(join_type == JoinType::MARK); - D_ASSERT(result.ColumnCount() == input.ColumnCount() + 1); - auto &result_vector = result.data.back(); - D_ASSERT(result_vector.GetType() == LogicalType::BOOLEAN); - // for every data vector, we just reference the child chunk - result.SetCardinality(input); - for (idx_t i = 0; i < 
input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // for the MARK vector: - // if the HT has no NULL values (i.e. empty result set), return a vector that has false for every input - // entry; if the HT has NULL values (i.e. result set had values, but all were NULL), return a vector that - // has NULL for every input entry - if (!has_null) { - auto bool_result = FlatVector::GetData(result_vector); - for (idx_t i = 0; i < result.size(); i++) { - bool_result[i] = false; - } - } else { - FlatVector::Validity(result_vector).SetAllInvalid(result.size()); - } - } else if (join_type == JoinType::LEFT || join_type == JoinType::OUTER || join_type == JoinType::SINGLE) { - // LEFT/FULL OUTER/SINGLE join and build side is empty - // for the LHS we reference the data - result.SetCardinality(input.size()); - for (idx_t i = 0; i < input.ColumnCount(); i++) { - result.data[i].Reference(input.data[i]); - } - // for the RHS - for (idx_t k = input.ColumnCount(); k < result.ColumnCount(); k++) { - result.data[k].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[k], true); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp b/src/duckdb/src/execution/operator/join/physical_cross_product.cpp deleted file mode 100644 index a1175017e..000000000 --- a/src/duckdb/src/execution/operator/join/physical_cross_product.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "duckdb/execution/operator/join/physical_cross_product.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/operator/join/physical_join.hpp" - -namespace duckdb { - -PhysicalCrossProduct::PhysicalCrossProduct(vector types, unique_ptr left, - unique_ptr right, idx_t estimated_cardinality) - : CachingPhysicalOperator(PhysicalOperatorType::CROSS_PRODUCT, std::move(types), estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class CrossProductGlobalState : public GlobalSinkState { -public: - explicit CrossProductGlobalState(ClientContext &context, const PhysicalCrossProduct &op) - : rhs_materialized(context, op.children[1]->GetTypes()) { - rhs_materialized.InitializeAppend(append_state); - } - - ColumnDataCollection rhs_materialized; - ColumnDataAppendState append_state; - mutex rhs_lock; -}; - -unique_ptr PhysicalCrossProduct::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkResultType PhysicalCrossProduct::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &sink = input.global_state.Cast(); - lock_guard client_guard(sink.rhs_lock); - sink.rhs_materialized.Append(sink.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -CrossProductExecutor::CrossProductExecutor(ColumnDataCollection &rhs) - : rhs(rhs), position_in_chunk(0), initialized(false), finished(false) { - rhs.InitializeScanChunk(scan_chunk); -} - -void CrossProductExecutor::Reset(DataChunk &input, DataChunk &output) { - initialized = true; - finished = false; - scan_input_chunk =
false; - rhs.InitializeScan(scan_state); - position_in_chunk = 0; - scan_chunk.Reset(); -} - -bool CrossProductExecutor::NextValue(DataChunk &input, DataChunk &output) { - if (!initialized) { - // not initialized yet: initialize the scan - Reset(input, output); - } - position_in_chunk++; - idx_t chunk_size = scan_input_chunk ? input.size() : scan_chunk.size(); - if (position_in_chunk < chunk_size) { - return true; - } - // fetch the next chunk - rhs.Scan(scan_state, scan_chunk); - position_in_chunk = 0; - if (scan_chunk.size() == 0) { - return false; - } - // the way the cross product works is that we keep one chunk constantly referenced - // while iterating over the other chunk one value at a time - // the second one is the chunk we are "scanning" - - // for the engine, it is better if we emit larger chunks - // hence the chunk that we keep constantly referenced should be the larger of the two - scan_input_chunk = input.size() < scan_chunk.size(); - return true; -} - -OperatorResultType CrossProductExecutor::Execute(DataChunk &input, DataChunk &output) { - if (rhs.Count() == 0) { - // no RHS: empty result - return OperatorResultType::FINISHED; - } - if (!NextValue(input, output)) { - // ran out of entries on the RHS - // reset the RHS and move to the next chunk on the LHS - initialized = false; - return OperatorResultType::NEED_MORE_INPUT; - } - - // set up the constant chunk - auto &constant_chunk = scan_input_chunk ? scan_chunk : input; - auto col_count = constant_chunk.ColumnCount(); - auto col_offset = scan_input_chunk ? input.ColumnCount() : 0; - output.SetCardinality(constant_chunk.size()); - for (idx_t i = 0; i < col_count; i++) { - output.data[col_offset + i].Reference(constant_chunk.data[i]); - } - - // for the chunk that we are scanning, scan a single value from that chunk - auto &scan = scan_input_chunk ? input : scan_chunk; - col_count = scan.ColumnCount(); - col_offset = scan_input_chunk ? 
0 : input.ColumnCount(); - for (idx_t i = 0; i < col_count; i++) { - ConstantVector::Reference(output.data[col_offset + i], scan.data[i], position_in_chunk, scan.size()); - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -class CrossProductOperatorState : public CachingOperatorState { -public: - explicit CrossProductOperatorState(ColumnDataCollection &rhs) : executor(rhs) { - } - - CrossProductExecutor executor; -}; - -unique_ptr PhysicalCrossProduct::GetOperatorState(ExecutionContext &context) const { - auto &sink = sink_state->Cast(); - return make_uniq(sink.rhs_materialized); -} - -OperatorResultType PhysicalCrossProduct::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - return state.executor.Execute(input, chunk); -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalCrossProduct::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalCrossProduct::GetSources() const { - return children[0]->GetSources(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_delim_join.cpp deleted file mode 100644 index 1da39904d..000000000 --- a/src/duckdb/src/execution/operator/join/physical_delim_join.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "duckdb/execution/operator/join/physical_delim_join.hpp" - -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" - -namespace duckdb { - -PhysicalDelimJoin::PhysicalDelimJoin(PhysicalOperatorType type, vector types, - unique_ptr original_join, - vector> delim_scans, idx_t estimated_cardinality, - optional_idx delim_idx) - : PhysicalOperator(type, std::move(types), estimated_cardinality), join(std::move(original_join)), - delim_scans(std::move(delim_scans)), delim_idx(delim_idx) { - D_ASSERT(type == PhysicalOperatorType::LEFT_DELIM_JOIN || type == PhysicalOperatorType::RIGHT_DELIM_JOIN); -} - -vector> PhysicalDelimJoin::GetChildren() const { - vector> result; - for (auto &child : children) { - result.push_back(*child); - } - result.push_back(*join); - result.push_back(*distinct); - return result; -} - -InsertionOrderPreservingMap PhysicalDelimJoin::ParamsToString() const { - auto result = join->ParamsToString(); - result["Delim Index"] = StringUtil::Format("%llu", delim_idx.GetIndex()); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp b/src/duckdb/src/execution/operator/join/physical_hash_join.cpp deleted file mode 100644 index 3c01a4acf..000000000 --- a/src/duckdb/src/execution/operator/join/physical_hash_join.cpp +++ /dev/null @@ -1,1391 +0,0 @@ -#include "duckdb/execution/operator/join/physical_hash_join.hpp" - -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/types/value_map.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/aggregate/ungrouped_aggregate_state.hpp" -#include "duckdb/function/aggregate/distributive_function_utils.hpp" -#include "duckdb/function/aggregate/distributive_functions.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/query_profiler.hpp" 
-#include "duckdb/optimizer/filter_combiner.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/executor_task.hpp" -#include "duckdb/parallel/interrupt.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" -#include "duckdb/planner/filter/constant_filter.hpp" -#include "duckdb/planner/filter/in_filter.hpp" -#include "duckdb/planner/filter/null_filter.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/table_filter.hpp" -#include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/storage/temporary_memory_manager.hpp" - -namespace duckdb { - -PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - const vector &left_projection_map, const vector &right_projection_map, - vector delim_types, idx_t estimated_cardinality, - unique_ptr pushdown_info_p) - : PhysicalComparisonJoin(op, PhysicalOperatorType::HASH_JOIN, std::move(cond), join_type, estimated_cardinality), - delim_types(std::move(delim_types)) { - - filter_pushdown = std::move(pushdown_info_p); - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - // Collect condition types, and which conditions are just references (so we won't duplicate them in the payload) - unordered_map build_columns_in_conditions; - for (idx_t cond_idx = 0; cond_idx < conditions.size(); cond_idx++) { - auto &condition = conditions[cond_idx]; - condition_types.push_back(condition.left->return_type); - if (condition.right->GetExpressionClass() == ExpressionClass::BOUND_REF) { - build_columns_in_conditions.emplace(condition.right->Cast().index, cond_idx); - } - } - - auto &lhs_input_types = children[0]->GetTypes(); - - // Create a projection map for the LHS (if it was empty), for convenience - lhs_output_columns.col_idxs = left_projection_map; - if (lhs_output_columns.col_idxs.empty()) { - lhs_output_columns.col_idxs.reserve(lhs_input_types.size()); - for (idx_t i = 0; i < lhs_input_types.size(); i++) { - lhs_output_columns.col_idxs.emplace_back(i); - } - } - - for (auto &lhs_col : lhs_output_columns.col_idxs) { - auto &lhs_col_type = lhs_input_types[lhs_col]; - lhs_output_columns.col_types.push_back(lhs_col_type); - } - - // For ANTI, SEMI and MARK join, we only need to store the keys, so for these the payload/RHS types are empty - if (join_type == JoinType::ANTI || join_type == JoinType::SEMI || join_type == JoinType::MARK) { - return; - } - - auto &rhs_input_types = children[1]->GetTypes(); - - // Create a projection map for the RHS (if it was empty), for convenience - auto right_projection_map_copy = right_projection_map; - if (right_projection_map_copy.empty()) { - right_projection_map_copy.reserve(rhs_input_types.size()); - for (idx_t i = 0; i < rhs_input_types.size(); i++) { - right_projection_map_copy.emplace_back(i); - } - } - - // Now fill payload expressions/types and RHS columns/types - for (auto &rhs_col : right_projection_map_copy) { - auto &rhs_col_type = rhs_input_types[rhs_col]; - - auto it = build_columns_in_conditions.find(rhs_col); - if (it == build_columns_in_conditions.end()) { - // This rhs column is not a join key - payload_columns.col_idxs.push_back(rhs_col); - 
payload_columns.col_types.push_back(rhs_col_type); - rhs_output_columns.col_idxs.push_back(condition_types.size() + payload_columns.col_types.size() - 1); - } else { - // This rhs column is a join key - rhs_output_columns.col_idxs.push_back(it->second); - } - rhs_output_columns.col_types.push_back(rhs_col_type); - } -} - -PhysicalHashJoin::PhysicalHashJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - idx_t estimated_cardinality) - : PhysicalHashJoin(op, std::move(left), std::move(right), std::move(cond), join_type, {}, {}, {}, - estimated_cardinality, nullptr) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -JoinFilterGlobalState::~JoinFilterGlobalState() { -} - -JoinFilterLocalState::~JoinFilterLocalState() { -} - -unique_ptr JoinFilterPushdownInfo::GetGlobalState(ClientContext &context, - const PhysicalOperator &op) const { - // clear any previously set filters - // we can have previous filters for this operator in case of e.g. recursive CTEs - for (auto &info : probe_info) { - info.dynamic_filters->ClearFilters(op); - } - auto result = make_uniq(); - result->global_aggregate_state = - make_uniq(BufferAllocator::Get(context), min_max_aggregates); - return result; -} - -class HashJoinGlobalSinkState : public GlobalSinkState { -public: - HashJoinGlobalSinkState(const PhysicalHashJoin &op_p, ClientContext &context_p) - : context(context_p), op(op_p), - num_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), - temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), finalized(false), - active_local_states(0), total_size(0), max_partition_size(0), max_partition_count(0), - probe_side_requirement(0), scanned_data(false) { - hash_table = op.InitializeHashTable(context); - - // For perfect hash join - perfect_join_executor = make_uniq(op, *hash_table); - bool use_perfect_hash = false; - if (op.conditions.size() == 1 && !op.join_stats.empty() && op.join_stats[1] && - TypeIsIntegral(op.join_stats[1]->GetType().InternalType()) && NumericStats::HasMinMax(*op.join_stats[1])) { - use_perfect_hash = perfect_join_executor->CanDoPerfectHashJoin(op, NumericStats::Min(*op.join_stats[1]), - NumericStats::Max(*op.join_stats[1])); - } - // For external hash join - external = ClientConfig::GetConfig(context).GetSetting(context); - // Set probe types - probe_types = op.children[0]->types; - probe_types.emplace_back(LogicalType::HASH); - - if (op.filter_pushdown) { - if (op.filter_pushdown->probe_info.empty() && use_perfect_hash) { - // We only needed the min/max to check for a perfect HJ, and we already know we can use one - skip_filter_pushdown = true; - } - global_filter_state = op.filter_pushdown->GetGlobalState(context, op); - } - } - - void ScheduleFinalize(Pipeline &pipeline, Event &event); - void InitializeProbeSpill(); - -public: - ClientContext &context; - const PhysicalHashJoin &op; - - const idx_t num_threads; - //! Temporary memory state for managing this operator's memory usage - unique_ptr temporary_memory_state; - - //! Global HT used by the join - unique_ptr hash_table; - //! The perfect hash join executor (if any) - unique_ptr perfect_join_executor; - //! Whether or not the hash table has been finalized - bool finalized; - //! The number of active local states - atomic active_local_states; - - //!
Whether we are doing an external hash join, plus the sizes used to decide this - bool external; - idx_t total_size; - idx_t max_partition_size; - idx_t max_partition_count; - idx_t probe_side_requirement; - - //! Hash tables built by each thread - vector> local_hash_tables; - - //! Excess probe data gathered during Sink - vector probe_types; - unique_ptr probe_spill; - - //! Whether or not we have started scanning data using GetData - atomic scanned_data; - - bool skip_filter_pushdown = false; - unique_ptr global_filter_state; -}; - -unique_ptr JoinFilterPushdownInfo::GetLocalState(JoinFilterGlobalState &gstate) const { - auto result = make_uniq(); - result->local_aggregate_state = make_uniq(*gstate.global_aggregate_state); - return result; -} - -class HashJoinLocalSinkState : public LocalSinkState { -public: - HashJoinLocalSinkState(const PhysicalHashJoin &op, ClientContext &context, HashJoinGlobalSinkState &gstate) - : join_key_executor(context) { - auto &allocator = BufferAllocator::Get(context); - - for (auto &cond : op.conditions) { - join_key_executor.AddExpression(*cond.right); - } - join_keys.Initialize(allocator, op.condition_types); - - if (!op.payload_columns.col_types.empty()) { - payload_chunk.Initialize(allocator, op.payload_columns.col_types); - } - - hash_table = op.InitializeHashTable(context); - hash_table->GetSinkCollection().InitializeAppendState(append_state); - - gstate.active_local_states++; - - if (op.filter_pushdown) { - local_filter_state = op.filter_pushdown->GetLocalState(*gstate.global_filter_state); - } - } - -public: - PartitionedTupleDataAppendState append_state; - - ExpressionExecutor join_key_executor; - DataChunk join_keys; - - DataChunk payload_chunk; - - //! Thread-local HT - unique_ptr hash_table; - - unique_ptr local_filter_state; -}; - -unique_ptr PhysicalHashJoin::InitializeHashTable(ClientContext &context) const { - auto result = make_uniq(context, conditions, payload_columns.col_types, join_type, - rhs_output_columns.col_idxs); - if (!delim_types.empty() && join_type == JoinType::MARK) { - // correlated MARK join - if (delim_types.size() + 1 == conditions.size()) { - // the correlated MARK join has one more condition than the number of correlated columns - // this is the case in a correlated ANY() expression - // in this case we need to keep track of additional entries, namely: - // - (1) the total number of elements per group - // - (2) the number of non-null elements per group - // we need these to correctly deal with the cases of either: - // - (1) the group being empty [in which case the result is always false, even if the comparison is NULL] - // - (2) the group containing a NULL value [in which case FALSE becomes NULL] - auto &info = result->correlated_mark_join_info; - - vector delim_payload_types; - vector correlated_aggregates; - unique_ptr aggr; - - // jury-rigging the GroupedAggregateHashTable - // we need a count_star and a count to get counts with and without NULLs - - FunctionBinder function_binder(context); - aggr = function_binder.BindAggregateFunction(CountStarFun::GetFunction(), {}, nullptr, - AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); - delim_payload_types.push_back(aggr->return_type); - info.correlated_aggregates.push_back(std::move(aggr)); - - auto count_fun = CountFunctionBase::GetFunction(); - vector> children; - // this is a dummy but we need it to make the hash table understand what's going on - children.push_back(make_uniq_base(count_fun.return_type, 0U)); - aggr = function_binder.BindAggregateFunction(count_fun,
std::move(children), nullptr, - AggregateType::NON_DISTINCT); - correlated_aggregates.push_back(&*aggr); - delim_payload_types.push_back(aggr->return_type); - info.correlated_aggregates.push_back(std::move(aggr)); - - auto &allocator = BufferAllocator::Get(context); - info.correlated_counts = make_uniq(context, allocator, delim_types, - delim_payload_types, correlated_aggregates); - info.correlated_types = delim_types; - info.group_chunk.Initialize(allocator, delim_types); - info.result_chunk.Initialize(allocator, delim_payload_types); - } - } - return result; -} - -unique_ptr PhysicalHashJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashJoin::GetLocalSinkState(ExecutionContext &context) const { - auto &gstate = sink_state->Cast(); - return make_uniq(*this, context.client, gstate); -} - -void JoinFilterPushdownInfo::Sink(DataChunk &chunk, JoinFilterLocalState &lstate) const { - // if we are pushing any filters into a probe-side, compute the min/max over the columns that we are pushing - for (idx_t pushdown_idx = 0; pushdown_idx < join_condition.size(); pushdown_idx++) { - auto join_condition_idx = join_condition[pushdown_idx]; - for (idx_t i = 0; i < 2; i++) { - idx_t aggr_idx = pushdown_idx * 2 + i; - lstate.local_aggregate_state->Sink(chunk, join_condition_idx, aggr_idx); - } - } -} - -SinkResultType PhysicalHashJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // resolve the join keys for the right chunk - lstate.join_keys.Reset(); - lstate.join_key_executor.Execute(chunk, lstate.join_keys); - - if (filter_pushdown && !gstate.skip_filter_pushdown) { - filter_pushdown->Sink(lstate.join_keys, *lstate.local_filter_state); - } - - if (payload_columns.col_types.empty()) { // there are only keys: place an empty chunk in the payload - lstate.payload_chunk.SetCardinality(chunk.size()); - } else { // there are payload columns - lstate.payload_chunk.ReferenceColumns(chunk, payload_columns.col_idxs); - } - - // build the HT - lstate.hash_table->Build(lstate.append_state, lstate.join_keys, lstate.payload_chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -void JoinFilterPushdownInfo::Combine(JoinFilterGlobalState &gstate, JoinFilterLocalState &lstate) const { - gstate.global_aggregate_state->Combine(*lstate.local_aggregate_state); -} - -SinkCombineResultType PhysicalHashJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - lstate.hash_table->GetSinkCollection().FlushAppendState(lstate.append_state); - auto guard = gstate.Lock(); - gstate.local_hash_tables.push_back(std::move(lstate.hash_table)); - if (gstate.local_hash_tables.size() == gstate.active_local_states) { - // Set to 0 until PrepareFinalize - gstate.temporary_memory_state->SetZero(); - } - - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - if (filter_pushdown && !gstate.skip_filter_pushdown) { - filter_pushdown->Combine(*gstate.global_filter_state, *lstate.local_filter_state); - } - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -static idx_t 
GetTupleWidth(const vector &types, bool &all_constant) { - idx_t tuple_width = 0; - all_constant = true; - for (auto &type : types) { - tuple_width += GetTypeIdSize(type.InternalType()); - all_constant &= TypeIsConstantSize(type.InternalType()); - } - return tuple_width + AlignValue(types.size()) / 8 + GetTypeIdSize(PhysicalType::UINT64); -} - -static idx_t GetPartitioningSpaceRequirement(ClientContext &context, const vector &types, - const idx_t radix_bits, const idx_t num_threads) { - auto &buffer_manager = BufferManager::GetBufferManager(context); - bool all_constant; - idx_t tuple_width = GetTupleWidth(types, all_constant); - - auto tuples_per_block = buffer_manager.GetBlockSize() / tuple_width; - auto blocks_per_chunk = (STANDARD_VECTOR_SIZE + tuples_per_block) / tuples_per_block + 1; - if (!all_constant) { - blocks_per_chunk += 2; - } - auto size_per_partition = blocks_per_chunk * buffer_manager.GetBlockAllocSize(); - auto num_partitions = RadixPartitioning::NumberOfPartitions(radix_bits); - - return num_threads * num_partitions * size_per_partition; -} - -void PhysicalHashJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &global_state) const { - auto &gstate = global_state.Cast(); - const auto &ht = *gstate.hash_table; - - gstate.total_size = - ht.GetTotalSize(gstate.local_hash_tables, gstate.max_partition_size, gstate.max_partition_count); - gstate.probe_side_requirement = - GetPartitioningSpaceRequirement(context, children[0]->types, ht.GetRadixBits(), gstate.num_threads); - const auto max_partition_ht_size = - gstate.max_partition_size + JoinHashTable::PointerTableSize(gstate.max_partition_count); - gstate.temporary_memory_state->SetMinimumReservation(max_partition_ht_size + gstate.probe_side_requirement); - - bool all_constant; - gstate.temporary_memory_state->SetMaterializationPenalty(GetTupleWidth(children[0]->types, all_constant)); - gstate.temporary_memory_state->SetRemainingSize(gstate.total_size); -} - -class HashJoinFinalizeTask : public ExecutorTask { -public: - HashJoinFinalizeTask(shared_ptr event_p, ClientContext &context, HashJoinGlobalSinkState &sink_p, - idx_t chunk_idx_from_p, idx_t chunk_idx_to_p, bool parallel_p, const PhysicalOperator &op_p) - : ExecutorTask(context, std::move(event_p), op_p), sink(sink_p), chunk_idx_from(chunk_idx_from_p), - chunk_idx_to(chunk_idx_to_p), parallel(parallel_p) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - sink.hash_table->Finalize(chunk_idx_from, chunk_idx_to, parallel); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - HashJoinGlobalSinkState &sink; - idx_t chunk_idx_from; - idx_t chunk_idx_to; - bool parallel; -}; - -class HashJoinFinalizeEvent : public BasePipelineEvent { -public: - HashJoinFinalizeEvent(Pipeline &pipeline_p, HashJoinGlobalSinkState &sink) - : BasePipelineEvent(pipeline_p), sink(sink) { - } - - HashJoinGlobalSinkState &sink; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - vector> finalize_tasks; - auto &ht = *sink.hash_table; - const auto chunk_count = ht.GetDataCollection().ChunkCount(); - const auto num_threads = NumericCast(sink.num_threads); - - // If the data is very skewed (many of the exact same key), our finalize will become slow, - // due to completely slamming the same atomic using compare-and-swaps. - // We can detect this because we partition the data, and go for a single-threaded finalize instead. 
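(Aside: to make the scheduling heuristic in the comment above concrete, here is a minimal standalone sketch of the decision. `FinalizePlan` and `PlanFinalize` are hypothetical names, not DuckDB API; the constants mirror those defined on the event below, and the `verify_parallelism` debug setting is omitted.)

```cpp
#include <cstdint>

// Hypothetical helper: whether to finalize in parallel, and how many chunks
// each task should cover.
struct FinalizePlan {
	bool parallel;
	uint64_t chunks_per_task;
};

FinalizePlan PlanFinalize(uint64_t num_threads, uint64_t ht_count, uint64_t chunk_count,
                          uint64_t max_partition_ht_size, uint64_t total_size) {
	constexpr uint64_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576;
	constexpr uint64_t CHUNKS_PER_TASK = 64;
	constexpr double SKEW_SINGLE_THREADED_THRESHOLD = 0.33;
	// Approximate skew by the share of the build data held by the largest partition
	const double skew = static_cast<double>(max_partition_ht_size) / static_cast<double>(total_size);
	if (num_threads == 1 ||
	    (ht_count < PARALLEL_CONSTRUCT_THRESHOLD && skew > SKEW_SINGLE_THREADED_THRESHOLD)) {
		// One task builds the whole pointer table: avoids compare-and-swap
		// contention when many duplicate keys hash to the same buckets
		return {false, chunk_count};
	}
	// Otherwise hand out fixed-size ranges of chunks to parallel tasks
	return {true, CHUNKS_PER_TASK};
}
```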
- const auto max_partition_ht_size = - sink.max_partition_size + JoinHashTable::PointerTableSize(sink.max_partition_count); - const auto skew = static_cast(max_partition_ht_size) / static_cast(sink.total_size); - - if (num_threads == 1 || (ht.Count() < PARALLEL_CONSTRUCT_THRESHOLD && skew > SKEW_SINGLE_THREADED_THRESHOLD && - !context.config.verify_parallelism)) { - // Single-threaded finalize - finalize_tasks.push_back( - make_uniq(shared_from_this(), context, sink, 0U, chunk_count, false, sink.op)); - } else { - // Parallel finalize - const idx_t chunks_per_task = context.config.verify_parallelism ? 1 : CHUNKS_PER_TASK; - for (idx_t chunk_idx = 0; chunk_idx < chunk_count; chunk_idx += chunks_per_task) { - auto chunk_idx_to = MinValue(chunk_idx + chunks_per_task, chunk_count); - finalize_tasks.push_back(make_uniq(shared_from_this(), context, sink, chunk_idx, - chunk_idx_to, true, sink.op)); - } - } - SetTasks(std::move(finalize_tasks)); - } - - void FinishEvent() override { - sink.hash_table->GetDataCollection().VerifyEverythingPinned(); - sink.hash_table->finalized = true; - } - - static constexpr idx_t PARALLEL_CONSTRUCT_THRESHOLD = 1048576; - static constexpr idx_t CHUNKS_PER_TASK = 64; - static constexpr double SKEW_SINGLE_THREADED_THRESHOLD = 0.33; -}; - -void HashJoinGlobalSinkState::ScheduleFinalize(Pipeline &pipeline, Event &event) { - if (hash_table->Count() == 0) { - hash_table->finalized = true; - return; - } - hash_table->InitializePointerTable(); - auto new_event = make_shared_ptr(pipeline, *this); - event.InsertEvent(std::move(new_event)); -} - -void HashJoinGlobalSinkState::InitializeProbeSpill() { - auto guard = Lock(); - if (!probe_spill) { - probe_spill = make_uniq(*hash_table, context, probe_types); - } -} - -class HashJoinRepartitionTask : public ExecutorTask { -public: - HashJoinRepartitionTask(shared_ptr event_p, ClientContext &context, JoinHashTable &global_ht, - JoinHashTable &local_ht, const PhysicalOperator &op_p) - : ExecutorTask(context, std::move(event_p), op_p), global_ht(global_ht), local_ht(local_ht) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - local_ht.Repartition(global_ht); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - JoinHashTable &global_ht; - JoinHashTable &local_ht; -}; - -class HashJoinRepartitionEvent : public BasePipelineEvent { -public: - HashJoinRepartitionEvent(Pipeline &pipeline_p, const PhysicalHashJoin &op_p, HashJoinGlobalSinkState &sink, - vector> &local_hts) - : BasePipelineEvent(pipeline_p), op(op_p), sink(sink), local_hts(local_hts) { - } - - const PhysicalHashJoin &op; - HashJoinGlobalSinkState &sink; - vector> &local_hts; - -public: - void Schedule() override { - D_ASSERT(sink.hash_table->GetRadixBits() > JoinHashTable::INITIAL_RADIX_BITS); - auto block_size = sink.hash_table->buffer_manager.GetBlockSize(); - - idx_t total_size = 0; - idx_t total_count = 0; - for (auto &local_ht : local_hts) { - auto &sink_collection = local_ht->GetSinkCollection(); - total_size += sink_collection.SizeInBytes(); - total_count += sink_collection.Count(); - } - auto total_blocks = (total_size + block_size - 1) / block_size; - auto count_per_block = total_count / total_blocks; - auto blocks_per_vector = MaxValue(STANDARD_VECTOR_SIZE / count_per_block, 2); - - // Assume 8 blocks per partition per thread (4 input, 4 output) - auto partition_multiplier = - RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits() - JoinHashTable::INITIAL_RADIX_BITS); - auto thread_memory 
= 2 * blocks_per_vector * partition_multiplier * block_size; - auto repartition_threads = MaxValue(sink.temporary_memory_state->GetReservation() / thread_memory, 1); - - if (repartition_threads < local_hts.size()) { - // Limit the number of threads working on repartitioning based on our memory reservation - for (idx_t thread_idx = repartition_threads; thread_idx < local_hts.size(); thread_idx++) { - local_hts[thread_idx % repartition_threads]->Merge(*local_hts[thread_idx]); - } - local_hts.resize(repartition_threads); - } - - auto &context = pipeline->GetClientContext(); - - vector> partition_tasks; - partition_tasks.reserve(local_hts.size()); - for (auto &local_ht : local_hts) { - partition_tasks.push_back( - make_uniq(shared_from_this(), context, *sink.hash_table, *local_ht, op)); - } - SetTasks(std::move(partition_tasks)); - } - - void FinishEvent() override { - local_hts.clear(); - - // Minimum reservation is now the new smallest partition size - const auto num_partitions = RadixPartitioning::NumberOfPartitions(sink.hash_table->GetRadixBits()); - vector partition_sizes(num_partitions, 0); - vector partition_counts(num_partitions, 0); - sink.total_size = sink.hash_table->GetTotalSize(partition_sizes, partition_counts, sink.max_partition_size, - sink.max_partition_count); - sink.probe_side_requirement = - GetPartitioningSpaceRequirement(sink.context, op.types, sink.hash_table->GetRadixBits(), sink.num_threads); - - sink.temporary_memory_state->SetMinimumReservation(sink.max_partition_size + - JoinHashTable::PointerTableSize(sink.max_partition_count) + - sink.probe_side_requirement); - sink.temporary_memory_state->UpdateReservation(executor.context); - - D_ASSERT(sink.temporary_memory_state->GetReservation() >= sink.probe_side_requirement); - sink.hash_table->PrepareExternalFinalize(sink.temporary_memory_state->GetReservation() - - sink.probe_side_requirement); - sink.ScheduleFinalize(*pipeline, *this); - } -}; - -void JoinFilterPushdownInfo::PushInFilter(const JoinFilterPushdownFilter &info, JoinHashTable &ht, - const PhysicalOperator &op, idx_t filter_idx, idx_t filter_col_idx) const { - // generate an "OR" filter (e.g. x=1 OR x=535 OR x=997) - // first scan the entire vector at the probe side - // FIXME: this code is duplicated from PerfectHashJoinExecutor::FullScanHashTable - auto build_idx = join_condition[filter_idx]; - auto &data_collection = ht.GetDataCollection(); - - Vector tuples_addresses(LogicalType::POINTER, ht.Count()); // allocate space for all the tuples - - JoinHTScanState join_ht_state(data_collection, 0, data_collection.ChunkCount(), - TupleDataPinProperties::KEEP_EVERYTHING_PINNED); - - // Go through all the blocks and fill in the key addresses - idx_t key_count = ht.FillWithHTOffsets(join_ht_state, tuples_addresses); - - // Scan the build keys in the hash table - Vector build_vector(ht.layout.GetTypes()[build_idx], key_count); - data_collection.Gather(tuples_addresses, *FlatVector::IncrementalSelectionVector(), key_count, build_idx, - build_vector, *FlatVector::IncrementalSelectionVector(), nullptr); - - // generate the OR-clause - note that we only need to consider unique values here (so we use a set) - value_set_t unique_ht_values; - for (idx_t k = 0; k < key_count; k++) { - unique_ht_values.insert(build_vector.GetValue(k)); - } - vector in_list(unique_ht_values.begin(), unique_ht_values.end()); - - // generating the OR filter only makes sense if the range is - // not dense and the range does not contain NULL - // e.g.
if we have the values [0, 1, 2, 3, 4] - the min/max is fully equivalent to the OR filter - if (FilterCombiner::ContainsNull(in_list) || FilterCombiner::IsDenseRange(in_list)) { - return; - } - - // generate the OR filter - auto in_filter = make_uniq(std::move(in_list)); - in_filter->origin_is_hash_join = true; - - // we push the OR filter as an OptionalFilter so that we can use it for zonemap pruning only - // the IN-list is expensive to execute otherwise - auto filter = make_uniq(std::move(in_filter)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(filter)); -} - -unique_ptr JoinFilterPushdownInfo::Finalize(ClientContext &context, JoinHashTable &ht, - JoinFilterGlobalState &gstate, - const PhysicalOperator &op) const { - // finalize the min/max aggregates - vector min_max_types; - for (auto &aggr_expr : min_max_aggregates) { - min_max_types.push_back(aggr_expr->return_type); - } - auto final_min_max = make_uniq(); - final_min_max->Initialize(Allocator::DefaultAllocator(), min_max_types); - - gstate.global_aggregate_state->Finalize(*final_min_max); - - if (probe_info.empty()) { - return final_min_max; // There are no table sources into which we can push down filters - } - - auto dynamic_or_filter_threshold = ClientConfig::GetSetting(context); - // create a filter for each of the aggregates - for (idx_t filter_idx = 0; filter_idx < join_condition.size(); filter_idx++) { - for (auto &info : probe_info) { - auto filter_col_idx = info.columns[filter_idx].probe_column_index.column_index; - auto min_idx = filter_idx * 2; - auto max_idx = min_idx + 1; - - auto min_val = final_min_max->data[min_idx].GetValue(0); - auto max_val = final_min_max->data[max_idx].GetValue(0); - if (min_val.IsNull() || max_val.IsNull()) { - // min/max is NULL - // this can happen in case all values in the RHS column are NULL, but they are still pushed into the - // hash table e.g.
because they are part of a RIGHT join - continue; - } - // if the HT is small we can generate a complete "OR" filter - if (ht.Count() > 1 && ht.Count() <= dynamic_or_filter_threshold) { - PushInFilter(info, ht, op, filter_idx, filter_col_idx); - } - - if (Value::NotDistinctFrom(min_val, max_val)) { - // min = max - generate an equality filter - auto constant_filter = make_uniq(ExpressionType::COMPARE_EQUAL, std::move(min_val)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(constant_filter)); - } else { - // min != max - generate a range filter - auto greater_equals = - make_uniq(ExpressionType::COMPARE_GREATERTHANOREQUALTO, std::move(min_val)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(greater_equals)); - auto less_equals = - make_uniq(ExpressionType::COMPARE_LESSTHANOREQUALTO, std::move(max_val)); - info.dynamic_filters->PushFilter(op, filter_col_idx, std::move(less_equals)); - } - } - } - - return final_min_max; -} - -SinkFinalizeType PhysicalHashJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &sink = input.global_state.Cast(); - auto &ht = *sink.hash_table; - - sink.temporary_memory_state->UpdateReservation(context); - sink.external = sink.temporary_memory_state->GetReservation() < sink.total_size; - if (sink.external) { - // External Hash Join - sink.perfect_join_executor.reset(); - - const auto max_partition_ht_size = - sink.max_partition_size + JoinHashTable::PointerTableSize(sink.max_partition_count); - const auto very_very_skewed = // No point in repartitioning if it's this skewed - static_cast(max_partition_ht_size) >= 0.8 * static_cast(sink.total_size); - if (!very_very_skewed && - (max_partition_ht_size + sink.probe_side_requirement) > sink.temporary_memory_state->GetReservation()) { - // We have to repartition - ht.SetRepartitionRadixBits(sink.temporary_memory_state->GetReservation(), sink.max_partition_size, - sink.max_partition_count); - auto new_event = make_shared_ptr(pipeline, *this, sink, sink.local_hash_tables); - event.InsertEvent(std::move(new_event)); - } else { - // No repartitioning! 
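(Aside: the external-join handling in Finalize boils down to a three-way choice. A hedged summary sketch follows; `FinalizeStrategy` and `ChooseStrategy` are hypothetical names, with the 0.8 cutoff taken from the `very_very_skewed` check above.)

```cpp
#include <cstdint>

enum class FinalizeStrategy { IN_MEMORY, EXTERNAL, EXTERNAL_REPARTITION };

// Hypothetical summary of the decision made in Finalize.
FinalizeStrategy ChooseStrategy(uint64_t reservation, uint64_t total_size,
                                uint64_t max_partition_ht_size, uint64_t probe_side_requirement) {
	if (reservation >= total_size) {
		// The whole build side fits in memory: build a single hash table
		return FinalizeStrategy::IN_MEMORY;
	}
	// If one partition already holds almost all of the data, further
	// repartitioning cannot shrink the per-round memory footprint
	const bool very_very_skewed =
	    static_cast<double>(max_partition_ht_size) >= 0.8 * static_cast<double>(total_size);
	if (!very_very_skewed && max_partition_ht_size + probe_side_requirement > reservation) {
		return FinalizeStrategy::EXTERNAL_REPARTITION;
	}
	// Process the existing partitions one round at a time
	return FinalizeStrategy::EXTERNAL;
}
```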
- for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); - } - sink.local_hash_tables.clear(); - D_ASSERT(sink.temporary_memory_state->GetReservation() >= sink.probe_side_requirement); - sink.hash_table->PrepareExternalFinalize(sink.temporary_memory_state->GetReservation() - - sink.probe_side_requirement); - sink.ScheduleFinalize(pipeline, event); - } - sink.finalized = true; - return SinkFinalizeType::READY; - } - - // In-memory Hash Join - for (auto &local_ht : sink.local_hash_tables) { - ht.Merge(*local_ht); - } - sink.local_hash_tables.clear(); - ht.Unpartition(); - - Value min; - Value max; - if (filter_pushdown && !sink.skip_filter_pushdown && ht.Count() > 0) { - auto final_min_max = filter_pushdown->Finalize(context, ht, *sink.global_filter_state, *this); - min = final_min_max->data[0].GetValue(0); - max = final_min_max->data[1].GetValue(0); - } else if (TypeIsIntegral(conditions[0].right->return_type.InternalType())) { - min = Value::MinimumValue(conditions[0].right->return_type); - max = Value::MaximumValue(conditions[0].right->return_type); - } - - // check for possible perfect hash table - auto use_perfect_hash = sink.perfect_join_executor->CanDoPerfectHashJoin(*this, min, max); - if (use_perfect_hash) { - D_ASSERT(ht.equality_types.size() == 1); - auto key_type = ht.equality_types[0]; - use_perfect_hash = sink.perfect_join_executor->BuildPerfectHashTable(key_type); - } - // In case of a large build side or duplicates, use regular hash join - if (!use_perfect_hash) { - sink.perfect_join_executor.reset(); - sink.ScheduleFinalize(pipeline, event); - } - sink.finalized = true; - if (ht.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class HashJoinOperatorState : public CachingOperatorState { -public: - explicit HashJoinOperatorState(ClientContext &context, HashJoinGlobalSinkState &sink) - : probe_executor(context), scan_structure(*sink.hash_table, join_key_state) { - } - - DataChunk lhs_join_keys; - TupleDataChunkState join_key_state; - DataChunk lhs_output; - - ExpressionExecutor probe_executor; - JoinHashTable::ScanStructure scan_structure; - unique_ptr perfect_hash_join_state; - - JoinHashTable::ProbeSpillLocalAppendState spill_state; - JoinHashTable::ProbeState probe_state; - //! 
Chunk to sink data into for external join - DataChunk spill_chunk; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op); - } -}; - -unique_ptr PhysicalHashJoin::GetOperatorState(ExecutionContext &context) const { - auto &allocator = BufferAllocator::Get(context.client); - auto &sink = sink_state->Cast(); - auto state = make_uniq(context.client, sink); - state->lhs_join_keys.Initialize(allocator, condition_types); - if (!lhs_output_columns.col_types.empty()) { - state->lhs_output.Initialize(allocator, lhs_output_columns.col_types); - } - if (sink.perfect_join_executor) { - state->perfect_hash_join_state = sink.perfect_join_executor->GetOperatorState(context); - } else { - for (auto &cond : conditions) { - state->probe_executor.AddExpression(*cond.left); - } - TupleDataCollection::InitializeChunkState(state->join_key_state, condition_types); - } - if (sink.external) { - state->spill_chunk.Initialize(allocator, sink.probe_types); - sink.InitializeProbeSpill(); - } - - return std::move(state); -} - -OperatorResultType PhysicalHashJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &sink = sink_state->Cast(); - D_ASSERT(sink.finalized); - D_ASSERT(!sink.scanned_data); - - if (sink.hash_table->Count() == 0) { - if (EmptyResultIfRHSIsEmpty()) { - return OperatorResultType::FINISHED; - } - state.lhs_output.ReferenceColumns(input, lhs_output_columns.col_idxs); - ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, state.lhs_output, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } - - if (sink.perfect_join_executor) { - D_ASSERT(!sink.external); - state.lhs_output.ReferenceColumns(input, lhs_output_columns.col_idxs); - return sink.perfect_join_executor->ProbePerfectHashTable(context, input, state.lhs_output, chunk, - *state.perfect_hash_join_state); - } - - if (sink.external && !state.initialized) { - // some initialization for external hash join - if (!sink.probe_spill) { - sink.InitializeProbeSpill(); - } - state.spill_state = sink.probe_spill->RegisterThread(); - state.initialized = true; - } - - if (state.scan_structure.is_null) { - // probe the HT, start by resolving the join keys for the left chunk - state.lhs_join_keys.Reset(); - state.probe_executor.Execute(input, state.lhs_join_keys); - - // perform the actual probe - if (sink.external) { - sink.hash_table->ProbeAndSpill(state.scan_structure, state.lhs_join_keys, state.join_key_state, - state.probe_state, input, *sink.probe_spill, state.spill_state, - state.spill_chunk); - } else { - sink.hash_table->Probe(state.scan_structure, state.lhs_join_keys, state.join_key_state, state.probe_state); - } - } - - state.lhs_output.ReferenceColumns(input, lhs_output_columns.col_idxs); - state.scan_structure.Next(state.lhs_join_keys, state.lhs_output, chunk); - - if (state.scan_structure.PointersExhausted() && chunk.size() == 0) { - state.scan_structure.is_null = true; - return OperatorResultType::NEED_MORE_INPUT; - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -enum class HashJoinSourceStage : uint8_t { INIT, BUILD, PROBE, SCAN_HT, DONE }; - -class HashJoinLocalSourceState; - -class HashJoinGlobalSourceState : public GlobalSourceState { 
-public: - HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context); - - //! Initialize this source state using the info in the sink - void Initialize(HashJoinGlobalSinkState &sink); - //! Try to prepare the next stage - bool TryPrepareNextStage(HashJoinGlobalSinkState &sink); - //! Prepare the next build/probe/scan_ht stage for external hash join (must hold lock) - void PrepareBuild(HashJoinGlobalSinkState &sink); - void PrepareProbe(HashJoinGlobalSinkState &sink); - void PrepareScanHT(HashJoinGlobalSinkState &sink); - //! Assigns a task to a local source state - bool AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate); - - idx_t MaxThreads() override { - D_ASSERT(op.sink_state); - auto &gstate = op.sink_state->Cast(); - - idx_t count; - if (gstate.probe_spill) { - count = probe_count; - } else if (PropagatesBuildSide(op.join_type)) { - count = gstate.hash_table->Count(); - } else { - return 0; - } - return count / ((idx_t)STANDARD_VECTOR_SIZE * parallel_scan_chunk_count); - } - -public: - const PhysicalHashJoin &op; - - //! For synchronizing the external hash join - atomic global_stage; - - //! For HT build synchronization - idx_t build_chunk_idx = DConstants::INVALID_INDEX; - idx_t build_chunk_count; - idx_t build_chunk_done; - idx_t build_chunks_per_thread = DConstants::INVALID_INDEX; - - //! For probe synchronization - atomic probe_chunk_count; - idx_t probe_chunk_done; - - //! To determine the number of threads - idx_t probe_count; - idx_t parallel_scan_chunk_count; - - //! For full/outer synchronization - idx_t full_outer_chunk_idx = DConstants::INVALID_INDEX; - atomic full_outer_chunk_count; - atomic full_outer_chunk_done; - idx_t full_outer_chunks_per_thread = DConstants::INVALID_INDEX; - - vector blocked_tasks; -}; - -class HashJoinLocalSourceState : public LocalSourceState { -public: - HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, Allocator &allocator); - - //! Do the work this thread has been assigned - void ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - //! Whether this thread has finished the work it has been assigned - bool TaskFinished() const; - //! Build, probe and scan for external hash join - void ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate); - void ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - void ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, DataChunk &chunk); - -public: - //! The stage that this thread was assigned work for - HashJoinSourceStage local_stage; - //! Vector with pointers here so we don't have to re-initialize - Vector addresses; - - //! Chunks assigned to this thread for building the pointer table - idx_t build_chunk_idx_from = DConstants::INVALID_INDEX; - idx_t build_chunk_idx_to = DConstants::INVALID_INDEX; - - //! Local scan state for probe spill - ColumnDataConsumerScanState probe_local_scan; - //! Chunks for holding the scanned probe collection - DataChunk lhs_probe_chunk; - DataChunk lhs_join_keys; - DataChunk lhs_output; - TupleDataChunkState join_key_state; - ExpressionExecutor lhs_join_key_executor; - - //! Scan structure for the external probe - JoinHashTable::ScanStructure scan_structure; - JoinHashTable::ProbeState probe_state; - bool empty_ht_probe_in_progress = false; - - //! 
Chunks assigned to this thread for a full/outer scan - idx_t full_outer_chunk_idx_from = DConstants::INVALID_INDEX; - idx_t full_outer_chunk_idx_to = DConstants::INVALID_INDEX; - unique_ptr full_outer_scan_state; -}; - -unique_ptr PhysicalHashJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this, context); -} - -unique_ptr PhysicalHashJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, sink_state->Cast(), - BufferAllocator::Get(context.client)); -} - -HashJoinGlobalSourceState::HashJoinGlobalSourceState(const PhysicalHashJoin &op, const ClientContext &context) - : op(op), global_stage(HashJoinSourceStage::INIT), build_chunk_count(0), build_chunk_done(0), probe_chunk_count(0), - probe_chunk_done(0), probe_count(op.children[0]->estimated_cardinality), - parallel_scan_chunk_count(context.config.verify_parallelism ? 1 : 120) { -} - -void HashJoinGlobalSourceState::Initialize(HashJoinGlobalSinkState &sink) { - auto guard = Lock(); - if (global_stage != HashJoinSourceStage::INIT) { - // Another thread initialized - return; - } - - // Finalize the probe spill - if (sink.probe_spill) { - sink.probe_spill->Finalize(); - } - - global_stage = HashJoinSourceStage::PROBE; - TryPrepareNextStage(sink); -} - -bool HashJoinGlobalSourceState::TryPrepareNextStage(HashJoinGlobalSinkState &sink) { - switch (global_stage.load()) { - case HashJoinSourceStage::BUILD: - if (build_chunk_done == build_chunk_count) { - sink.hash_table->GetDataCollection().VerifyEverythingPinned(); - sink.hash_table->finalized = true; - PrepareProbe(sink); - return true; - } - break; - case HashJoinSourceStage::PROBE: - if (probe_chunk_done == probe_chunk_count) { - if (PropagatesBuildSide(op.join_type)) { - PrepareScanHT(sink); - } else { - PrepareBuild(sink); - } - return true; - } - break; - case HashJoinSourceStage::SCAN_HT: - if (full_outer_chunk_done == full_outer_chunk_count) { - PrepareBuild(sink); - return true; - } - break; - default: - break; - } - return false; -} - -void HashJoinGlobalSourceState::PrepareBuild(HashJoinGlobalSinkState &sink) { - D_ASSERT(global_stage != HashJoinSourceStage::BUILD); - auto &ht = *sink.hash_table; - - // Update remaining size - sink.temporary_memory_state->SetRemainingSizeAndUpdateReservation(sink.context, ht.GetRemainingSize() + - sink.probe_side_requirement); - - // Try to put the next partitions in the block collection of the HT - D_ASSERT(!sink.external || sink.temporary_memory_state->GetReservation() >= sink.probe_side_requirement); - if (!sink.external || - !ht.PrepareExternalFinalize(sink.temporary_memory_state->GetReservation() - sink.probe_side_requirement)) { - global_stage = HashJoinSourceStage::DONE; - sink.temporary_memory_state->SetZero(); - return; - } - - auto &data_collection = ht.GetDataCollection(); - if (data_collection.Count() == 0 && op.EmptyResultIfRHSIsEmpty()) { - PrepareBuild(sink); - return; - } - - build_chunk_idx = 0; - build_chunk_count = data_collection.ChunkCount(); - build_chunk_done = 0; - - if (sink.context.config.verify_parallelism) { - build_chunks_per_thread = 1; - } else { - const auto max_partition_ht_size = - sink.max_partition_size + JoinHashTable::PointerTableSize(sink.max_partition_count); - const auto skew = static_cast(max_partition_ht_size) / static_cast(sink.total_size); - - if (skew > HashJoinFinalizeEvent::SKEW_SINGLE_THREADED_THRESHOLD) { - build_chunks_per_thread = build_chunk_count; // This forces single-threaded building - } else { - 
build_chunks_per_thread = // Same task size as in HashJoinFinalizeEvent - MaxValue(MinValue(build_chunk_count, HashJoinFinalizeEvent::CHUNKS_PER_TASK), 1); - } - } - - ht.InitializePointerTable(); - - global_stage = HashJoinSourceStage::BUILD; -} - -void HashJoinGlobalSourceState::PrepareProbe(HashJoinGlobalSinkState &sink) { - sink.probe_spill->PrepareNextProbe(); - const auto &consumer = *sink.probe_spill->consumer; - - probe_chunk_count = consumer.Count() == 0 ? 0 : consumer.ChunkCount(); - probe_chunk_done = 0; - - global_stage = HashJoinSourceStage::PROBE; - if (probe_chunk_count == 0) { - TryPrepareNextStage(sink); - return; - } -} - -void HashJoinGlobalSourceState::PrepareScanHT(HashJoinGlobalSinkState &sink) { - D_ASSERT(global_stage != HashJoinSourceStage::SCAN_HT); - auto &ht = *sink.hash_table; - - auto &data_collection = ht.GetDataCollection(); - full_outer_chunk_idx = 0; - full_outer_chunk_count = data_collection.ChunkCount(); - full_outer_chunk_done = 0; - - full_outer_chunks_per_thread = - MaxValue((full_outer_chunk_count + sink.num_threads - 1) / sink.num_threads, 1); - - global_stage = HashJoinSourceStage::SCAN_HT; -} - -bool HashJoinGlobalSourceState::AssignTask(HashJoinGlobalSinkState &sink, HashJoinLocalSourceState &lstate) { - D_ASSERT(lstate.TaskFinished()); - - auto guard = Lock(); - switch (global_stage.load()) { - case HashJoinSourceStage::BUILD: - if (build_chunk_idx != build_chunk_count) { - lstate.local_stage = global_stage; - lstate.build_chunk_idx_from = build_chunk_idx; - build_chunk_idx = MinValue(build_chunk_count, build_chunk_idx + build_chunks_per_thread); - lstate.build_chunk_idx_to = build_chunk_idx; - return true; - } - break; - case HashJoinSourceStage::PROBE: - if (sink.probe_spill->consumer && sink.probe_spill->consumer->AssignChunk(lstate.probe_local_scan)) { - lstate.local_stage = global_stage; - lstate.empty_ht_probe_in_progress = false; - return true; - } - break; - case HashJoinSourceStage::SCAN_HT: - if (full_outer_chunk_idx != full_outer_chunk_count) { - lstate.local_stage = global_stage; - lstate.full_outer_chunk_idx_from = full_outer_chunk_idx; - full_outer_chunk_idx = - MinValue(full_outer_chunk_count, full_outer_chunk_idx + full_outer_chunks_per_thread); - lstate.full_outer_chunk_idx_to = full_outer_chunk_idx; - return true; - } - break; - case HashJoinSourceStage::DONE: - break; - default: - throw InternalException("Unexpected HashJoinSourceStage in AssignTask!"); - } - return false; -} - -HashJoinLocalSourceState::HashJoinLocalSourceState(const PhysicalHashJoin &op, const HashJoinGlobalSinkState &sink, - Allocator &allocator) - : local_stage(HashJoinSourceStage::INIT), addresses(LogicalType::POINTER), lhs_join_key_executor(sink.context), - scan_structure(*sink.hash_table, join_key_state) { - auto &chunk_state = probe_local_scan.current_chunk_state; - chunk_state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - - lhs_probe_chunk.Initialize(allocator, sink.probe_types); - lhs_join_keys.Initialize(allocator, op.condition_types); - lhs_output.Initialize(allocator, op.lhs_output_columns.col_types); - TupleDataCollection::InitializeChunkState(join_key_state, op.condition_types); - - for (auto &cond : op.conditions) { - lhs_join_key_executor.AddExpression(*cond.left); - } -} - -void HashJoinLocalSourceState::ExecuteTask(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - switch (local_stage) { - case HashJoinSourceStage::BUILD: - ExternalBuild(sink, gstate); - break; - case 
HashJoinSourceStage::PROBE: - ExternalProbe(sink, gstate, chunk); - break; - case HashJoinSourceStage::SCAN_HT: - ExternalScanHT(sink, gstate, chunk); - break; - default: - throw InternalException("Unexpected HashJoinSourceStage in ExecuteTask!"); - } -} - -bool HashJoinLocalSourceState::TaskFinished() const { - switch (local_stage) { - case HashJoinSourceStage::INIT: - case HashJoinSourceStage::BUILD: - return true; - case HashJoinSourceStage::PROBE: - return scan_structure.is_null && !empty_ht_probe_in_progress; - case HashJoinSourceStage::SCAN_HT: - return full_outer_scan_state == nullptr; - default: - throw InternalException("Unexpected HashJoinSourceStage in TaskFinished!"); - } -} - -void HashJoinLocalSourceState::ExternalBuild(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate) { - D_ASSERT(local_stage == HashJoinSourceStage::BUILD); - - auto &ht = *sink.hash_table; - ht.Finalize(build_chunk_idx_from, build_chunk_idx_to, true); - - auto guard = gstate.Lock(); - gstate.build_chunk_done += build_chunk_idx_to - build_chunk_idx_from; -} - -void HashJoinLocalSourceState::ExternalProbe(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - D_ASSERT(local_stage == HashJoinSourceStage::PROBE && sink.hash_table->finalized); - - if (!scan_structure.is_null) { - // Still have elements remaining (i.e. we got >STANDARD_VECTOR_SIZE elements in the previous probe) - scan_structure.Next(lhs_join_keys, lhs_output, chunk); - if (chunk.size() != 0 || !scan_structure.PointersExhausted()) { - return; - } - } - - if (!scan_structure.is_null || empty_ht_probe_in_progress) { - // Previous probe is done - scan_structure.is_null = true; - empty_ht_probe_in_progress = false; - sink.probe_spill->consumer->FinishChunk(probe_local_scan); - auto guard = gstate.Lock(); - gstate.probe_chunk_done++; - return; - } - - // Scan input chunk for next probe - sink.probe_spill->consumer->ScanChunk(probe_local_scan, lhs_probe_chunk); - - // Get the probe chunk columns/hashes - lhs_join_keys.Reset(); - lhs_join_key_executor.Execute(lhs_probe_chunk, lhs_join_keys); - lhs_output.ReferenceColumns(lhs_probe_chunk, sink.op.lhs_output_columns.col_idxs); - - if (sink.hash_table->Count() == 0 && !gstate.op.EmptyResultIfRHSIsEmpty()) { - gstate.op.ConstructEmptyJoinResult(sink.hash_table->join_type, sink.hash_table->has_null, lhs_output, chunk); - empty_ht_probe_in_progress = true; - return; - } - - // Perform the probe - auto precomputed_hashes = &lhs_probe_chunk.data.back(); - sink.hash_table->Probe(scan_structure, lhs_join_keys, join_key_state, probe_state, precomputed_hashes); - scan_structure.Next(lhs_join_keys, lhs_output, chunk); -} - -void HashJoinLocalSourceState::ExternalScanHT(HashJoinGlobalSinkState &sink, HashJoinGlobalSourceState &gstate, - DataChunk &chunk) { - D_ASSERT(local_stage == HashJoinSourceStage::SCAN_HT); - - if (!full_outer_scan_state) { - full_outer_scan_state = make_uniq(sink.hash_table->GetDataCollection(), - full_outer_chunk_idx_from, full_outer_chunk_idx_to); - } - sink.hash_table->ScanFullOuter(*full_outer_scan_state, addresses, chunk); - - if (chunk.size() == 0) { - full_outer_scan_state = nullptr; - auto guard = gstate.Lock(); - gstate.full_outer_chunk_done += full_outer_chunk_idx_to - full_outer_chunk_idx_from; - } -} - -SourceResultType PhysicalHashJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = 
input.local_state.Cast(); - sink.scanned_data = true; - - if (!sink.external && !PropagatesBuildSide(join_type)) { - auto guard = gstate.Lock(); - if (gstate.global_stage != HashJoinSourceStage::DONE) { - gstate.global_stage = HashJoinSourceStage::DONE; - sink.hash_table->Reset(); - sink.temporary_memory_state->SetZero(); - } - return SourceResultType::FINISHED; - } - - if (gstate.global_stage == HashJoinSourceStage::INIT) { - gstate.Initialize(sink); - } - - // Any call to GetData must produce tuples, otherwise the pipeline executor thinks that we're done - // Therefore, we loop until we've produced tuples, or until the operator is actually done - while (gstate.global_stage != HashJoinSourceStage::DONE && chunk.size() == 0) { - if (!lstate.TaskFinished() || gstate.AssignTask(sink, lstate)) { - lstate.ExecuteTask(sink, gstate, chunk); - } else { - auto guard = gstate.Lock(); - if (gstate.TryPrepareNextStage(sink) || gstate.global_stage == HashJoinSourceStage::DONE) { - gstate.UnblockTasks(guard); - } else { - return gstate.BlockSource(guard, input.interrupt_state); - } - } - } - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -ProgressData PhysicalHashJoin::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &sink = sink_state->Cast(); - auto &gstate = gstate_p.Cast(); - - ProgressData res; - - if (!sink.external) { - if (PropagatesBuildSide(join_type)) { - res.done = static_cast(gstate.full_outer_chunk_done); - res.total = static_cast(gstate.full_outer_chunk_count); - return res; - } - res.done = 0.0; - res.total = 1.0; - return res; - } - - const auto &ht = *sink.hash_table; - const auto num_partitions = static_cast(RadixPartitioning::NumberOfPartitions(ht.GetRadixBits())); - - res.done = static_cast(ht.FinishedPartitionCount()); - res.total = num_partitions; - - const auto probe_chunk_done = static_cast(gstate.probe_chunk_done); - const auto probe_chunk_count = static_cast(gstate.probe_chunk_count); - if (probe_chunk_count != 0) { - // Progress of the current round of probing - auto probe_progress = probe_chunk_done / probe_chunk_count; - // Weighed by the number of partitions - probe_progress *= static_cast(ht.CurrentPartitionCount()); - // Add it to the progress - res.done += probe_progress; - } - - return res; -} - -InsertionOrderPreservingMap PhysicalHashJoin::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Join Type"] = EnumUtil::ToString(join_type); - - string condition_info; - for (idx_t i = 0; i < conditions.size(); i++) { - auto &join_condition = conditions[i]; - if (i > 0) { - condition_info += "\n"; - } - condition_info += - StringUtil::Format("%s %s %s", join_condition.left->GetName(), - ExpressionTypeToOperator(join_condition.comparison), join_condition.right->GetName()); - } - result["Conditions"] = condition_info; - - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp b/src/duckdb/src/execution/operator/join/physical_iejoin.cpp deleted file mode 100644 index c5d5810fa..000000000 --- a/src/duckdb/src/execution/operator/join/physical_iejoin.cpp +++ /dev/null @@ -1,1057 +0,0 @@ -#include "duckdb/execution/operator/join/physical_iejoin.hpp" - -#include "duckdb/common/atomic.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/sort.hpp" -#include 
"duckdb/common/sort/sorted_block.hpp" -#include "duckdb/common/thread.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -namespace duckdb { - -PhysicalIEJoin::PhysicalIEJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right, vector cond, JoinType join_type, - idx_t estimated_cardinality) - : PhysicalRangeJoin(op, PhysicalOperatorType::IE_JOIN, std::move(left), std::move(right), std::move(cond), - join_type, estimated_cardinality) { - - // 1. let L1 (resp. L2) be the array of column X (resp. Y) - D_ASSERT(conditions.size() >= 2); - for (idx_t i = 0; i < 2; ++i) { - auto &cond = conditions[i]; - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - // Convert the conditions to sort orders - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - auto sense = OrderType::INVALID; - - // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order - // 3. else if (op1 ∈ {<, ≤}) sort L1 in ascending order - // 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order - // 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order - switch (cond.comparison) { - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - sense = i ? OrderType::ASCENDING : OrderType::DESCENDING; - break; - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - sense = i ? OrderType::DESCENDING : OrderType::ASCENDING; - break; - default: - throw NotImplementedException("Unimplemented join type for IEJoin"); - } - lhs_orders.emplace_back(sense, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(sense, OrderByNullType::NULLS_LAST, std::move(right)); - } - - for (idx_t i = 2; i < conditions.size(); ++i) { - auto &cond = conditions[i]; - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class IEJoinLocalState : public LocalSinkState { -public: - using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - - IEJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) - : table(context, op, child) { - } - - //! 
The local sort state - LocalSortedTable table; -}; - -class IEJoinGlobalState : public GlobalSinkState { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - IEJoinGlobalState(ClientContext &context, const PhysicalIEJoin &op) : child(0) { - tables.resize(2); - RowLayout lhs_layout; - lhs_layout.Initialize(op.children[0]->types); - vector lhs_order; - lhs_order.emplace_back(op.lhs_orders[0].Copy()); - tables[0] = make_uniq(context, lhs_order, lhs_layout, op); - - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1]->types); - vector rhs_order; - rhs_order.emplace_back(op.rhs_orders[0].Copy()); - tables[1] = make_uniq(context, rhs_order, rhs_layout, op); - } - - IEJoinGlobalState(IEJoinGlobalState &prev) : tables(std::move(prev.tables)), child(prev.child + 1) { - state = prev.state; - } - - void Sink(DataChunk &input, IEJoinLocalState &lstate) { - auto &table = *tables[child]; - auto &global_sort_state = table.global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - - // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - } - - vector> tables; - size_t child; -}; - -unique_ptr PhysicalIEJoin::GetGlobalSinkState(ClientContext &context) const { - D_ASSERT(!sink_state); - return make_uniq(context, *this); -} - -unique_ptr PhysicalIEJoin::GetLocalSinkState(ExecutionContext &context) const { - idx_t sink_child = 0; - if (sink_state) { - const auto &ie_sink = sink_state->Cast(); - sink_child = ie_sink.child; - } - return make_uniq(context.client, *this, sink_child); -} - -SinkResultType PhysicalIEJoin::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - gstate.Sink(chunk, lstate); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalIEJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.tables[gstate.child]->Combine(lstate.table); - auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalIEJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &table = *gstate.tables[gstate.child]; - auto &global_sort_state = table.global_sort_state; - - if ((gstate.child == 1 && PropagatesBuildSide(join_type)) || (gstate.child == 0 && IsLeftOuterJoin(join_type))) { - // for FULL/LEFT/RIGHT OUTER JOIN, initialize found_match to false for every tuple - table.IntializeMatches(); - } - if (gstate.child == 1 && global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Sort the current input child - table.Finalize(pipeline, event); - - // Move to the next input child - ++gstate.child; - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -OperatorResultType PhysicalIEJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - return OperatorResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -struct IEJoinUnion { - using SortedTable = PhysicalRangeJoin::GlobalSortedTable; - - static idx_t AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx); - - static void Sort(SortedTable &table) { - auto &global_sort_state = table.global_sort_state; - global_sort_state.PrepareMergePhase(); - while (global_sort_state.sorted_blocks.size() > 1) { - global_sort_state.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort_state, global_sort_state.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort_state.CompleteMergeRound(true); - } - } - - template - static vector ExtractColumn(SortedTable &table, idx_t col_idx) { - vector result; - result.reserve(table.count); - - auto &gstate = table.global_sort_state; - auto &blocks = *gstate.sorted_blocks[0]->payload_data; - PayloadScanner scanner(blocks, gstate, false); - - DataChunk payload; - payload.Initialize(Allocator::DefaultAllocator(), gstate.payload_layout.GetTypes()); - for (;;) { - payload.Reset(); - scanner.Scan(payload); - const auto count = payload.size(); - if (!count) { - break; - } - - const auto data_ptr = FlatVector::GetData(payload.data[col_idx]); - result.insert(result.end(), data_ptr, data_ptr + count); - } - - return result; - } - - IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, SortedTable &t2, - const idx_t b2); - - idx_t SearchL1(idx_t pos); - bool NextRow(); - - //! Inverted loop - idx_t JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel); - - //! L1 - unique_ptr l1; - //! L2 - unique_ptr l2; - - //! Li - vector li; - //! P - vector p; - - //! B - vector bit_array; - ValidityMask bit_mask; - - //! Bloom Filter - static constexpr idx_t BLOOM_CHUNK_BITS = 1024; - idx_t bloom_count; - vector bloom_array; - ValidityMask bloom_filter; - - //! 
Iteration state - idx_t n; - idx_t i; - idx_t j; - unique_ptr op1; - unique_ptr off1; - unique_ptr op2; - unique_ptr off2; - int64_t lrid; -}; - -idx_t IEJoinUnion::AppendKey(SortedTable &table, ExpressionExecutor &executor, SortedTable &marked, int64_t increment, - int64_t base, const idx_t block_idx) { - LocalSortState local_sort_state; - local_sort_state.Initialize(marked.global_sort_state, marked.global_sort_state.buffer_manager); - - // Reading - const auto valid = table.count - table.has_null; - auto &gstate = table.global_sort_state; - PayloadScanner scanner(gstate, block_idx); - auto table_idx = block_idx * gstate.block_capacity; - - DataChunk scanned; - scanned.Initialize(Allocator::DefaultAllocator(), scanner.GetPayloadTypes()); - - // Writing - auto types = local_sort_state.sort_layout->logical_types; - const idx_t payload_idx = types.size(); - - const auto &payload_types = local_sort_state.payload_layout->GetTypes(); - types.insert(types.end(), payload_types.begin(), payload_types.end()); - const idx_t rid_idx = types.size() - 1; - - DataChunk keys; - DataChunk payload; - keys.Initialize(Allocator::DefaultAllocator(), types); - - idx_t inserted = 0; - for (auto rid = base; table_idx < valid;) { - scanned.Reset(); - scanner.Scan(scanned); - - // NULLs are at the end, so stop when we reach them - auto scan_count = scanned.size(); - if (table_idx + scan_count > valid) { - scan_count = valid - table_idx; - scanned.SetCardinality(scan_count); - } - if (scan_count == 0) { - break; - } - table_idx += scan_count; - - // Compute the input columns from the payload - keys.Reset(); - keys.Split(payload, rid_idx); - executor.Execute(scanned, keys); - - // Mark the rid column - payload.data[0].Sequence(rid, increment, scan_count); - payload.SetCardinality(scan_count); - keys.Fuse(payload); - rid += increment * UnsafeNumericCast(scan_count); - - // Sort on the sort columns (which will no longer be needed) - keys.Split(payload, payload_idx); - local_sort_state.SinkChunk(keys, payload); - inserted += scan_count; - keys.Fuse(payload); - - // Flush when we have enough data - if (local_sort_state.SizeInBytes() >= marked.memory_per_thread) { - local_sort_state.Sort(marked.global_sort_state, true); - } - } - marked.global_sort_state.AddLocalState(local_sort_state); - marked.count += inserted; - - return inserted; -} - -IEJoinUnion::IEJoinUnion(ClientContext &context, const PhysicalIEJoin &op, SortedTable &t1, const idx_t b1, - SortedTable &t2, const idx_t b2) - : n(0), i(0) { - // input : query Q with 2 join predicates t1.X op1 t2.X' and t1.Y op2 t2.Y', tables T, T' of sizes m and n resp. - // output: a list of tuple pairs (ti , tj) - // Note that T/T' are already sorted on X/X' and contain the payload data - // We only join the two block numbers and use the sizes of the blocks as the counts - - // 0. Filter out tables with no overlap - if (!t1.BlockSize(b1) || !t2.BlockSize(b2)) { - return; - } - - const auto &cmp1 = op.conditions[0].comparison; - SBIterator bounds1(t1.global_sort_state, cmp1); - SBIterator bounds2(t2.global_sort_state, cmp1); - - // t1.X[0] op1 t2.X'[-1] - bounds1.SetIndex(bounds1.block_capacity * b1); - bounds2.SetIndex(bounds2.block_capacity * b2 + t2.BlockSize(b2) - 1); - if (!bounds1.Compare(bounds2)) { - return; - } - - // 1. let L1 (resp. L2) be the array of column X (resp. Y ) - const auto &order1 = op.lhs_orders[0]; - const auto &order2 = op.lhs_orders[1]; - - // 2. if (op1 ∈ {>, ≥}) sort L1 in descending order - // 3. 
else if (op1 ∈ {<, ≤}) sort L1 in ascending order
-
-	// For the union algorithm, we make a unified table with the keys and the rids as the payload:
-	//	X/X', Y/Y', R/R'/Li
-	// The first position is the sort key.
-	vector<LogicalType> types;
-	types.emplace_back(order2.expression->return_type);
-	types.emplace_back(LogicalType::BIGINT);
-	RowLayout payload_layout;
-	payload_layout.Initialize(types);
-
-	// Sort on the first expression
-	auto ref = make_uniq<BoundReferenceExpression>(order1.expression->return_type, 0U);
-	vector<BoundOrderByNode> orders;
-	orders.emplace_back(order1.type, order1.null_order, std::move(ref));
-	// The goal is to make i (from the left table) < j (from the right table),
-	// if value[i] and value[j] match condition 1.
-	// Add a column from_left to solve the problem when there exist multiple equal values in l1.
-	// If the operator is a loose inequality, make t1.from_left (== true) sort BEFORE t2.from_left (== false).
-	// Otherwise, make t1.from_left (== true) sort AFTER t2.from_left (== false).
-	// For example, if t1.time <= t2.time
-	// | value     | 1     | 1     | 1     | 1     |
-	// | --------- | ----- | ----- | ----- | ----- |
-	// | from_left | T(l1) | T(l2) | F(r1) | F(r2) |
-	// if t1.time < t2.time
-	// | value     | 1     | 1     | 1     | 1     |
-	// | --------- | ----- | ----- | ----- | ----- |
-	// | from_left | F(r2) | F(r1) | T(l2) | T(l1) |
-	// With this OrderType, whenever value[i] (from the left table) and value[j] (from the right table) match
-	// the condition (t1.time <= t2.time or t1.time < t2.time), from_left forces i < j, i.e., the correct order.
-	auto from_left = make_uniq<BoundConstantExpression>(Value::BOOLEAN(true));
-	orders.emplace_back(SBIterator::ComparisonValue(cmp1) == 0 ? OrderType::DESCENDING : OrderType::ASCENDING,
-	                    OrderByNullType::ORDER_DEFAULT, std::move(from_left));
-
-	l1 = make_uniq<SortedTable>(context, orders, payload_layout, op);
-
-	// LHS has positive rids
-	ExpressionExecutor l_executor(context);
-	l_executor.AddExpression(*order1.expression);
-	// add const column true
-	auto left_const = make_uniq<BoundConstantExpression>(Value::BOOLEAN(true));
-	l_executor.AddExpression(*left_const);
-	l_executor.AddExpression(*order2.expression);
-	AppendKey(t1, l_executor, *l1, 1, 1, b1);
-
-	// RHS has negative rids
-	ExpressionExecutor r_executor(context);
-	r_executor.AddExpression(*op.rhs_orders[0].expression);
-	// add const column false
-	auto right_const = make_uniq<BoundConstantExpression>(Value::BOOLEAN(false));
-	r_executor.AddExpression(*right_const);
-	r_executor.AddExpression(*op.rhs_orders[1].expression);
-	AppendKey(t2, r_executor, *l1, -1, -1, b2);
-
-	if (l1->global_sort_state.sorted_blocks.empty()) {
-		return;
-	}
-
-	Sort(*l1);
-
-	op1 = make_uniq<SBIterator>(l1->global_sort_state, cmp1);
-	off1 = make_uniq<SBIterator>(l1->global_sort_state, cmp1);
-
-	// We don't actually need the L1 column, just its sort key, which is in the sort blocks
-	li = ExtractColumn<int64_t>(*l1, types.size() - 1);
-
-	// 4. if (op2 ∈ {>, ≥}) sort L2 in ascending order
-	// 5. else if (op2 ∈ {<, ≤}) sort L2 in descending order
-
-	// We sort on Y/Y' to obtain the sort keys and the permutation array.
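The comments above introduce the unified L1 table (signed rids plus the from_left tie-breaker), and the next step derives the permutation array P of L2 with respect to L1. As a reading aid, here is a minimal standalone sketch of what P contains; this is not DuckDB code (DuckDB derives P via a second sort whose BIGINT payload carries L1 positions), and every name in it is invented for illustration:

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
	// One tuple per row: (x, y), as if both join sides were already unioned.
	std::vector<std::pair<int, int>> rows = {{3, 40}, {1, 10}, {2, 30}, {4, 20}};
	const size_t n = rows.size();

	// L1: row ids sorted on X (ascending, as for op1 in {<, <=}).
	// L2: row ids sorted on Y (ascending, as for op2 in {>, >=}).
	std::vector<size_t> l1(n), l2(n);
	for (size_t i = 0; i < n; i++) {
		l1[i] = l2[i] = i;
	}
	std::sort(l1.begin(), l1.end(), [&](size_t a, size_t b) { return rows[a].first < rows[b].first; });
	std::sort(l2.begin(), l2.end(), [&](size_t a, size_t b) { return rows[a].second < rows[b].second; });

	// pos_in_l1[rid] = position that row occupies in L1.
	std::vector<size_t> pos_in_l1(n);
	for (size_t i = 0; i < n; i++) {
		pos_in_l1[l1[i]] = i;
	}

	// P[i] = position in L1 of the i-th tuple of L2. Walking L2 in order and
	// marking/probing a bit array at P[i] is the shape of the scan performed below.
	std::vector<size_t> p(n);
	for (size_t i = 0; i < n; i++) {
		p[i] = pos_in_l1[l2[i]];
	}
	for (size_t i = 0; i < n; i++) {
		printf("P[%zu] = %zu\n", i, p[i]);
	}
	return 0;
}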
- // For this we just need a two-column table of Y, P - types.clear(); - types.emplace_back(LogicalType::BIGINT); - payload_layout.Initialize(types); - - // Sort on the first expression - orders.clear(); - ref = make_uniq(order2.expression->return_type, 0U); - orders.emplace_back(order2.type, order2.null_order, std::move(ref)); - - ExpressionExecutor executor(context); - executor.AddExpression(*orders[0].expression); - - l2 = make_uniq(context, orders, payload_layout, op); - for (idx_t base = 0, block_idx = 0; block_idx < l1->BlockCount(); ++block_idx) { - base += AppendKey(*l1, executor, *l2, 1, NumericCast(base), block_idx); - } - - Sort(*l2); - - // We don't actually need the L2 column, just its sort key, which is in the sort blocks - - // 6. compute the permutation array P of L2 w.r.t. L1 - p = ExtractColumn(*l2, types.size() - 1); - - // 7. initialize bit-array B (|B| = n), and set all bits to 0 - n = l2->count.load(); - bit_array.resize(ValidityMask::EntryCount(n), 0); - bit_mask.Initialize(bit_array.data(), n); - - // Bloom filter - bloom_count = (n + (BLOOM_CHUNK_BITS - 1)) / BLOOM_CHUNK_BITS; - bloom_array.resize(ValidityMask::EntryCount(bloom_count), 0); - bloom_filter.Initialize(bloom_array.data(), bloom_count); - - // 11. for(i←1 to n) do - const auto &cmp2 = op.conditions[1].comparison; - op2 = make_uniq(l2->global_sort_state, cmp2); - off2 = make_uniq(l2->global_sort_state, cmp2); - i = 0; - j = 0; - (void)NextRow(); -} - -bool IEJoinUnion::NextRow() { - for (; i < n; ++i) { - // 12. pos ← P[i] - auto pos = p[i]; - lrid = li[pos]; - if (lrid < 0) { - continue; - } - - // 16. B[pos] ← 1 - op2->SetIndex(i); - for (; off2->GetIndex() < n; ++(*off2)) { - if (!off2->Compare(*op2)) { - break; - } - const auto p2 = p[off2->GetIndex()]; - if (li[p2] < 0) { - // Only mark rhs matches. - bit_mask.SetValid(p2); - bloom_filter.SetValid(p2 / BLOOM_CHUNK_BITS); - } - } - - // 9. if (op1 ∈ {≤,≥} and op2 ∈ {≤,≥}) eqOff = 0 - // 10. else eqOff = 1 - // No, because there could be more than one equal value. - // Find the leftmost off1 where L1[pos] op1 L1[off1..n] - // These are the rows that satisfy the op1 condition - // and that is where we should start scanning B from - j = pos; - - return true; - } - return false; -} - -static idx_t NextValid(const ValidityMask &bits, idx_t j, const idx_t n) { - if (j >= n) { - return n; - } - - // We can do a first approximation by checking entries one at a time - // which gives 64:1. - idx_t entry_idx, idx_in_entry; - bits.GetEntryIndex(j, entry_idx, idx_in_entry); - auto entry = bits.GetValidityEntry(entry_idx++); - - // Trim the bits before the start position - entry &= (ValidityMask::ValidityBuffer::MAX_ENTRY << idx_in_entry); - - // Check the non-ragged entries - for (const auto entry_count = bits.EntryCount(n); entry_idx < entry_count; ++entry_idx) { - if (entry) { - for (; idx_in_entry < bits.BITS_PER_VALUE; ++idx_in_entry, ++j) { - if (bits.RowIsValid(entry, idx_in_entry)) { - return j; - } - } - } else { - j += bits.BITS_PER_VALUE - idx_in_entry; - } - - entry = bits.GetValidityEntry(entry_idx); - idx_in_entry = 0; - } - - // Check the final entry - for (; j < n; ++idx_in_entry, ++j) { - if (bits.RowIsValid(entry, idx_in_entry)) { - return j; - } - } - - return j; -} - -idx_t IEJoinUnion::JoinComplexBlocks(SelectionVector &lsel, SelectionVector &rsel) { - // 8. initialize join result as an empty list for tuple pairs - idx_t result_count = 0; - - // 11. for(i←1 to n) do - while (i < n) { - // 13. for (j ← pos+eqOff to n) do - for (;;) { - // 14. 
if B[j] = 1 then
-
-			// Use the Bloom filter to find candidate blocks
-			while (j < n) {
-				auto bloom_begin = NextValid(bloom_filter, j / BLOOM_CHUNK_BITS, bloom_count) * BLOOM_CHUNK_BITS;
-				auto bloom_end = MinValue(n, bloom_begin + BLOOM_CHUNK_BITS);
-
-				j = MaxValue(j, bloom_begin);
-				j = NextValid(bit_mask, j, bloom_end);
-				if (j < bloom_end) {
-					break;
-				}
-			}
-
-			if (j >= n) {
-				break;
-			}
-
-			// Filter out tuples with the same sign (they come from the same table)
-			const auto rrid = li[j];
-			++j;
-
-			D_ASSERT(lrid > 0 && rrid < 0);
-			// 15. add tuples w.r.t. (L1[j], L1[i]) to join result
-			lsel.set_index(result_count, sel_t(+lrid - 1));
-			rsel.set_index(result_count, sel_t(-rrid - 1));
-			++result_count;
-			if (result_count == STANDARD_VECTOR_SIZE) {
-				// out of space!
-				return result_count;
-			}
-		}
-		++i;
-
-		if (!NextRow()) {
-			break;
-		}
-	}
-
-	return result_count;
-}
-
-class IEJoinLocalSourceState : public LocalSourceState {
-public:
-	explicit IEJoinLocalSourceState(ClientContext &context, const PhysicalIEJoin &op)
-	    : op(op), true_sel(STANDARD_VECTOR_SIZE), left_executor(context), right_executor(context),
-	      left_matches(nullptr), right_matches(nullptr) {
-		auto &allocator = Allocator::Get(context);
-		unprojected.Initialize(allocator, op.unprojected_types);
-
-		if (op.conditions.size() < 3) {
-			return;
-		}
-
-		vector<LogicalType> left_types;
-		vector<LogicalType> right_types;
-		for (idx_t i = 2; i < op.conditions.size(); ++i) {
-			const auto &cond = op.conditions[i];
-
-			left_types.push_back(cond.left->return_type);
-			left_executor.AddExpression(*cond.left);
-
-			right_types.push_back(cond.right->return_type);
-			right_executor.AddExpression(*cond.right);
-		}
-
-		left_keys.Initialize(allocator, left_types);
-		right_keys.Initialize(allocator, right_types);
-	}
-
-	idx_t SelectOuterRows(bool *matches) {
-		idx_t count = 0;
-		for (; outer_idx < outer_count; ++outer_idx) {
-			if (!matches[outer_idx]) {
-				true_sel.set_index(count++, outer_idx);
-				if (count >= STANDARD_VECTOR_SIZE) {
-					outer_idx++;
-					break;
-				}
-			}
-		}
-
-		return count;
-	}
-
-	const PhysicalIEJoin &op;
-
-	// Joining
-	unique_ptr<IEJoinUnion> joiner;
-
-	idx_t left_base;
-	idx_t left_block_index;
-
-	idx_t right_base;
-	idx_t right_block_index;
-
-	// Trailing predicates
-	SelectionVector true_sel;
-
-	ExpressionExecutor left_executor;
-	DataChunk left_keys;
-
-	ExpressionExecutor right_executor;
-	DataChunk right_keys;
-
-	DataChunk unprojected;
-
-	// Outer joins
-	idx_t outer_idx;
-	idx_t outer_count;
-	bool *left_matches;
-	bool *right_matches;
-};
-
-void PhysicalIEJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &result, LocalSourceState &state_p) const {
-	auto &state = state_p.Cast<IEJoinLocalSourceState>();
-	auto &ie_sink = sink_state->Cast<IEJoinGlobalState>();
-	auto &left_table = *ie_sink.tables[0];
-	auto &right_table = *ie_sink.tables[1];
-
-	const auto left_cols = children[0]->GetTypes().size();
-	auto &chunk = state.unprojected;
-	do {
-		SelectionVector lsel(STANDARD_VECTOR_SIZE);
-		SelectionVector rsel(STANDARD_VECTOR_SIZE);
-		auto result_count = state.joiner->JoinComplexBlocks(lsel, rsel);
-		if (result_count == 0) {
-			// exhausted this pair
-			return;
-		}
-
-		// found matches: extract them
-
-		chunk.Reset();
-		SliceSortedPayload(chunk, left_table.global_sort_state, state.left_block_index, lsel, result_count, 0);
-		SliceSortedPayload(chunk, right_table.global_sort_state, state.right_block_index, rsel, result_count,
-		                   left_cols);
-		chunk.SetCardinality(result_count);
-
-		auto sel = FlatVector::IncrementalSelectionVector();
-		if (conditions.size() > 2) {
-			// If there are more expressions to compute,
-			// split the result chunk into the left and right halves
-			// so we can compute the values for comparison.
-			const auto tail_cols = conditions.size() - 2;
-
-			DataChunk right_chunk;
-			chunk.Split(right_chunk, left_cols);
-			state.left_executor.SetChunk(chunk);
-			state.right_executor.SetChunk(right_chunk);
-
-			auto tail_count = result_count;
-			auto true_sel = &state.true_sel;
-			for (size_t cmp_idx = 0; cmp_idx < tail_cols; ++cmp_idx) {
-				auto &left = state.left_keys.data[cmp_idx];
-				state.left_executor.ExecuteExpression(cmp_idx, left);
-
-				auto &right = state.right_keys.data[cmp_idx];
-				state.right_executor.ExecuteExpression(cmp_idx, right);
-
-				if (tail_count < result_count) {
-					left.Slice(*sel, tail_count);
-					right.Slice(*sel, tail_count);
-				}
-				tail_count = SelectJoinTail(conditions[cmp_idx + 2].comparison, left, right, sel, tail_count, true_sel);
-				sel = true_sel;
-			}
-			chunk.Fuse(right_chunk);
-
-			if (tail_count < result_count) {
-				result_count = tail_count;
-				chunk.Slice(*sel, result_count);
-			}
-		}
-
-		// We need all of the data to compute other predicates,
-		// but we only return what is in the projection map
-		ProjectResult(chunk, result);
-
-		// found matches: mark the found matches if required
-		if (left_table.found_match) {
-			for (idx_t i = 0; i < result_count; i++) {
-				left_table.found_match[state.left_base + lsel[sel->get_index(i)]] = true;
-			}
-		}
-		if (right_table.found_match) {
-			for (idx_t i = 0; i < result_count; i++) {
-				right_table.found_match[state.right_base + rsel[sel->get_index(i)]] = true;
-			}
-		}
-		result.Verify();
-	} while (result.size() == 0);
-}
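ResolveComplexJoin above filters any third and later join conditions by threading a shrinking selection vector through SelectJoinTail: each predicate is evaluated only over the rows that survived the previous one. A minimal sketch of that cascade pattern, with std::vector standing in for DuckDB's SelectionVector and every helper name hypothetical:

#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

using Sel = std::vector<size_t>; // stand-in for DuckDB's SelectionVector

// Keep only the selected rows that also satisfy the next predicate.
static Sel SelectTail(const std::function<bool(size_t)> &pred, const Sel &sel) {
	Sel out;
	for (auto row : sel) {
		if (pred(row)) {
			out.push_back(row);
		}
	}
	return out;
}

int main() {
	std::vector<int> a = {1, 5, 3, 7};
	std::vector<int> b = {2, 2, 2, 2};
	// Start from all row indices (DuckDB starts from an incremental selection vector).
	Sel sel = {0, 1, 2, 3};
	// Third predicate: a > b; fourth predicate: a < 6 -- applied as a cascade.
	sel = SelectTail([&](size_t i) { return a[i] > b[i]; }, sel);
	sel = SelectTail([&](size_t i) { return a[i] < 6; }, sel);
	for (auto i : sel) {
		printf("row %zu survives\n", i); // rows 1 and 2
	}
	return 0;
}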
-class IEJoinGlobalSourceState : public GlobalSourceState {
-public:
-	explicit IEJoinGlobalSourceState(const PhysicalIEJoin &op, IEJoinGlobalState &gsink)
-	    : op(op), gsink(gsink), initialized(false), next_pair(0), completed(0), left_outers(0), next_left(0),
-	      right_outers(0), next_right(0) {
-	}
-
-	void Initialize() {
-		auto guard = Lock();
-		if (initialized) {
-			return;
-		}
-
-		// Compute the starting row for each block
-		// (In theory these are all the same size, but you never know...)
-		auto &left_table = *gsink.tables[0];
-		const auto left_blocks = left_table.BlockCount();
-		idx_t left_base = 0;
-
-		for (size_t lhs = 0; lhs < left_blocks; ++lhs) {
-			left_bases.emplace_back(left_base);
-			left_base += left_table.BlockSize(lhs);
-		}
-
-		auto &right_table = *gsink.tables[1];
-		const auto right_blocks = right_table.BlockCount();
-		idx_t right_base = 0;
-		for (size_t rhs = 0; rhs < right_blocks; ++rhs) {
-			right_bases.emplace_back(right_base);
-			right_base += right_table.BlockSize(rhs);
-		}
-
-		// Outer join block counts
-		if (left_table.found_match) {
-			left_outers = left_blocks;
-		}
-
-		if (right_table.found_match) {
-			right_outers = right_blocks;
-		}
-
-		// Ready for action
-		initialized = true;
-	}
-
-public:
-	idx_t MaxThreads() override {
-		// We can't leverage any more threads than block pairs.
-		const auto &sink_state = (op.sink_state->Cast<IEJoinGlobalState>());
-		return sink_state.tables[0]->BlockCount() * sink_state.tables[1]->BlockCount();
-	}
-
-	void GetNextPair(ClientContext &client, IEJoinLocalSourceState &lstate) {
-		auto &left_table = *gsink.tables[0];
-		auto &right_table = *gsink.tables[1];
-
-		const auto left_blocks = left_table.BlockCount();
-		const auto right_blocks = right_table.BlockCount();
-		const auto pair_count = left_blocks * right_blocks;
-
-		// Regular block pair
-		const auto i = next_pair++;
-		if (i < pair_count) {
-			const auto b1 = i / right_blocks;
-			const auto b2 = i % right_blocks;
-
-			lstate.left_block_index = b1;
-			lstate.left_base = left_bases[b1];
-
-			lstate.right_block_index = b2;
-			lstate.right_base = right_bases[b2];
-
-			lstate.joiner = make_uniq<IEJoinUnion>(client, op, left_table, b1, right_table, b2);
-			return;
-		}
-
-		// Outer joins
-		if (!left_outers && !right_outers) {
-			return;
-		}
-
-		// Spin wait for regular blocks to finish(!)
-		while (completed < pair_count) {
-			std::this_thread::yield();
-		}
-
-		// Left outer blocks
-		const auto l = next_left++;
-		if (l < left_outers) {
-			lstate.joiner = nullptr;
-			lstate.left_block_index = l;
-			lstate.left_base = left_bases[l];
-
-			lstate.left_matches = left_table.found_match.get() + lstate.left_base;
-			lstate.outer_idx = 0;
-			lstate.outer_count = left_table.BlockSize(l);
-			return;
-		} else {
-			lstate.left_matches = nullptr;
-		}
-
-		// Right outer blocks
-		const auto r = next_right++;
-		if (r < right_outers) {
-			lstate.joiner = nullptr;
-			lstate.right_block_index = r;
-			lstate.right_base = right_bases[r];
-
-			lstate.right_matches = right_table.found_match.get() + lstate.right_base;
-			lstate.outer_idx = 0;
-			lstate.outer_count = right_table.BlockSize(r);
-			return;
-		} else {
-			lstate.right_matches = nullptr;
-		}
-	}
-
-	void PairCompleted(ClientContext &client, IEJoinLocalSourceState &lstate) {
-		lstate.joiner.reset();
-		++completed;
-		GetNextPair(client, lstate);
-	}
-
-	ProgressData GetProgress() const {
-		auto &left_table = *gsink.tables[0];
-		auto &right_table = *gsink.tables[1];
-
-		const auto left_blocks = left_table.BlockCount();
-		const auto right_blocks = right_table.BlockCount();
-		const auto pair_count = left_blocks * right_blocks;
-
-		const auto count = pair_count + left_outers + right_outers;
-
-		const auto l = MinValue(next_left.load(), left_outers.load());
-		const auto r = MinValue(next_right.load(), right_outers.load());
-		const auto returned = completed.load() + l + r;
-
-		ProgressData res;
-		if (count) {
-			res.done = double(returned);
-			res.total = double(count);
-		} else {
-			res.SetInvalid();
-		}
-		return res;
-	}
-
-	const PhysicalIEJoin &op;
-	IEJoinGlobalState &gsink;
-
-	bool initialized;
-
-	// Join queue state
-	atomic<idx_t> next_pair;
-	atomic<idx_t> completed;
-
-	// Block base row number
-	vector<idx_t> left_bases;
-	vector<idx_t> right_bases;
-
-	// Outer joins
-	atomic<idx_t> left_outers;
-	atomic<idx_t> next_left;
-
-	atomic<idx_t> right_outers;
-	atomic<idx_t> next_right;
-};
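GetNextPair above reduces task scheduling to a single atomic counter: task i maps to the block pair (i / right_blocks, i % right_blocks), so any number of threads claim disjoint pairs without locks, and MaxThreads is naturally capped by the pair count. A minimal sketch of that claim loop (sizes illustrative, names invented):

#include <atomic>
#include <cstddef>
#include <cstdio>

int main() {
	const size_t left_blocks = 2;
	const size_t right_blocks = 3;
	const size_t pair_count = left_blocks * right_blocks;
	std::atomic<size_t> next_pair {0};

	// Each worker thread would run this loop; the fetch-add hands out each pair exactly once.
	for (;;) {
		const size_t i = next_pair++;
		if (i >= pair_count) {
			break; // all block pairs claimed
		}
		const size_t b1 = i / right_blocks; // left block index
		const size_t b2 = i % right_blocks; // right block index
		printf("pair %zu -> (b1=%zu, b2=%zu)\n", i, b1, b2);
	}
	return 0;
}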
-
-unique_ptr<GlobalSourceState> PhysicalIEJoin::GetGlobalSourceState(ClientContext &context) const {
-	auto &gsink = sink_state->Cast<IEJoinGlobalState>();
-	return make_uniq<IEJoinGlobalSourceState>(*this, gsink);
-}
-
-unique_ptr<LocalSourceState> PhysicalIEJoin::GetLocalSourceState(ExecutionContext &context,
-                                                                 GlobalSourceState &gstate) const {
-	return make_uniq<IEJoinLocalSourceState>(context.client, *this);
-}
-
-ProgressData PhysicalIEJoin::GetProgress(ClientContext &context, GlobalSourceState &gsource_p) const {
-	auto &gsource = gsource_p.Cast<IEJoinGlobalSourceState>();
-	return gsource.GetProgress();
-}
-
-SourceResultType PhysicalIEJoin::GetData(ExecutionContext &context, DataChunk &result,
OperatorSourceInput &input) const { - auto &ie_sink = sink_state->Cast(); - auto &ie_gstate = input.global_state.Cast(); - auto &ie_lstate = input.local_state.Cast(); - - ie_gstate.Initialize(); - - if (!ie_lstate.joiner && !ie_lstate.left_matches && !ie_lstate.right_matches) { - ie_gstate.GetNextPair(context.client, ie_lstate); - } - - // Process INNER results - while (ie_lstate.joiner) { - ResolveComplexJoin(context, result, ie_lstate); - - if (result.size()) { - return SourceResultType::HAVE_MORE_OUTPUT; - } - - ie_gstate.PairCompleted(context.client, ie_lstate); - } - - // Process LEFT OUTER results - const auto left_cols = children[0]->GetTypes().size(); - while (ie_lstate.left_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.left_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[0]->global_sort_state, ie_lstate.left_block_index, ie_lstate.true_sel, - count); - - // Fill in NULLs to the right - for (auto col_idx = left_cols; col_idx < chunk.ColumnCount(); ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; - } - - // Process RIGHT OUTER results - while (ie_lstate.right_matches) { - const idx_t count = ie_lstate.SelectOuterRows(ie_lstate.right_matches); - if (!count) { - ie_gstate.GetNextPair(context.client, ie_lstate); - continue; - } - - auto &chunk = ie_lstate.unprojected; - chunk.Reset(); - SliceSortedPayload(chunk, ie_sink.tables[1]->global_sort_state, ie_lstate.right_block_index, ie_lstate.true_sel, - count, left_cols); - - // Fill in NULLs to the left - for (idx_t col_idx = 0; col_idx < left_cols; ++col_idx) { - chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[col_idx], true); - } - - ProjectResult(chunk, result); - result.SetCardinality(count); - result.Verify(); - - break; - } - - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalIEJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - D_ASSERT(children.size() == 2); - if (meta_pipeline.HasRecursiveCTE()) { - throw NotImplementedException("IEJoins are not supported in recursive CTEs yet"); - } - - // becomes a source after both children fully sink their data - meta_pipeline.GetState().SetPipelineSource(current, *this); - - // Create one child meta pipeline that will hold the LHS and RHS pipelines - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - - // Build out LHS - auto lhs_pipeline = child_meta_pipeline.GetBasePipeline(); - children[0]->BuildPipelines(*lhs_pipeline, child_meta_pipeline); - - // Build out RHS - auto &rhs_pipeline = child_meta_pipeline.CreatePipeline(); - children[1]->BuildPipelines(rhs_pipeline, child_meta_pipeline); - - // Despite having the same sink, RHS and everything created after it need their own (same) PipelineFinishEvent - child_meta_pipeline.AddFinishEvent(rhs_pipeline); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_join.cpp b/src/duckdb/src/execution/operator/join/physical_join.cpp deleted file mode 100644 index bb7011a3a..000000000 --- a/src/duckdb/src/execution/operator/join/physical_join.cpp +++ /dev/null @@ -1,97 +0,0 @@ -#include "duckdb/execution/operator/join/physical_join.hpp" - -#include "duckdb/execution/operator/join/physical_hash_join.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" - -namespace duckdb { - -PhysicalJoin::PhysicalJoin(LogicalOperator &op, PhysicalOperatorType type, JoinType join_type, - idx_t estimated_cardinality) - : CachingPhysicalOperator(type, op.types, estimated_cardinality), join_type(join_type) { -} - -bool PhysicalJoin::EmptyResultIfRHSIsEmpty() const { - // empty RHS with INNER, RIGHT or SEMI join means empty result set - switch (join_type) { - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::SEMI: - case JoinType::RIGHT_SEMI: - case JoinType::RIGHT_ANTI: - return true; - default: - return false; - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalJoin::BuildJoinPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline, PhysicalOperator &op, - bool build_rhs) { - op.op_state.reset(); - op.sink_state.reset(); - - // 'current' is the probe pipeline: add this operator - auto &state = meta_pipeline.GetState(); - state.AddPipelineOperator(current, op); - - // save the last added pipeline to set up dependencies later (in case we need to add a child pipeline) - vector> pipelines_so_far; - meta_pipeline.GetPipelines(pipelines_so_far, false); - auto &last_pipeline = *pipelines_so_far.back(); - - vector> dependencies; - optional_ptr last_child_ptr; - if (build_rhs) { - // on the RHS (build side), we construct a child MetaPipeline with this operator as its sink - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, op, MetaPipelineType::JOIN_BUILD); - child_meta_pipeline.Build(*op.children[1]); - if (op.children[1]->CanSaturateThreads(current.GetClientContext())) { - // if the build side can saturate all available threads, - // we 
don't just make the LHS pipeline depend on the RHS, but recursively all LHS children too. - // this prevents breadth-first plan evaluation - child_meta_pipeline.GetPipelines(dependencies, false); - last_child_ptr = meta_pipeline.GetLastChild(); - } - } - - // continue building the current pipeline on the LHS (probe side) - op.children[0]->BuildPipelines(current, meta_pipeline); - - if (last_child_ptr) { - // the pointer was set, set up the dependencies - meta_pipeline.AddRecursiveDependencies(dependencies, *last_child_ptr); - } - - switch (op.type) { - case PhysicalOperatorType::POSITIONAL_JOIN: - // Positional joins are always outer - meta_pipeline.CreateChildPipeline(current, op, last_pipeline); - return; - case PhysicalOperatorType::CROSS_PRODUCT: - return; - default: - break; - } - - // Join can become a source operator if it's RIGHT/OUTER, or if the hash join goes out-of-core - if (op.Cast().IsSource()) { - meta_pipeline.CreateChildPipeline(current, op, last_pipeline); - } -} - -void PhysicalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalJoin::GetSources() const { - auto result = children[0]->GetSources(); - if (IsSource()) { - result.push_back(*this); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp deleted file mode 100644 index 49f259abc..000000000 --- a/src/duckdb/src/execution/operator/join/physical_left_delim_join.cpp +++ /dev/null @@ -1,143 +0,0 @@ -#include "duckdb/execution/operator/join/physical_left_delim_join.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_join.hpp" -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" - -namespace duckdb { - -PhysicalLeftDelimJoin::PhysicalLeftDelimJoin(vector types, unique_ptr original_join, - vector> delim_scans, - idx_t estimated_cardinality, optional_idx delim_idx) - : PhysicalDelimJoin(PhysicalOperatorType::LEFT_DELIM_JOIN, std::move(types), std::move(original_join), - std::move(delim_scans), estimated_cardinality, delim_idx) { - D_ASSERT(join->children.size() == 2); - // now for the original join - // we take its left child, this is the side that we will duplicate eliminate - children.push_back(std::move(join->children[0])); - - // we replace it with a PhysicalColumnDataScan, that scans the ColumnDataCollection that we keep cached - // the actual chunk collection to scan will be created in the LeftDelimJoinGlobalState - auto cached_chunk_scan = make_uniq( - children[0]->GetTypes(), PhysicalOperatorType::COLUMN_DATA_SCAN, estimated_cardinality, nullptr); - if (delim_idx.IsValid()) { - cached_chunk_scan->cte_index = delim_idx.GetIndex(); - } - join->children[0] = std::move(cached_chunk_scan); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class LeftDelimJoinGlobalState : public GlobalSinkState { -public: - explicit LeftDelimJoinGlobalState(ClientContext &context, const PhysicalLeftDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0]->GetTypes()) { - 
D_ASSERT(!delim_join.delim_scans.empty()); - // set up the delim join chunk to scan in the original join - auto &cached_chunk_scan = delim_join.join->children[0]->Cast(); - cached_chunk_scan.collection = &lhs_data; - } - - ColumnDataCollection lhs_data; - mutex lhs_lock; - - void Merge(ColumnDataCollection &input) { - lock_guard guard(lhs_lock); - lhs_data.Combine(input); - } -}; - -class LeftDelimJoinLocalState : public LocalSinkState { -public: - explicit LeftDelimJoinLocalState(ClientContext &context, const PhysicalLeftDelimJoin &delim_join) - : lhs_data(context, delim_join.children[0]->GetTypes()) { - lhs_data.InitializeAppend(append_state); - } - - unique_ptr distinct_state; - ColumnDataCollection lhs_data; - ColumnDataAppendState append_state; - - void Append(DataChunk &input) { - lhs_data.Append(input); - } -}; - -unique_ptr PhysicalLeftDelimJoin::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(context, *this); - distinct->sink_state = distinct->GetGlobalSinkState(context); - if (delim_scans.size() > 1) { - PhysicalHashAggregate::SetMultiScan(*distinct->sink_state); - } - return std::move(state); -} - -unique_ptr PhysicalLeftDelimJoin::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(context.client, *this); - state->distinct_state = distinct->GetLocalSinkState(context); - return std::move(state); -} - -SinkResultType PhysicalLeftDelimJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.lhs_data.Append(lstate.append_state, chunk); - OperatorSinkInput distinct_sink_input {*distinct->sink_state, *lstate.distinct_state, input.interrupt_state}; - distinct->Sink(context, chunk, distinct_sink_input); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalLeftDelimJoin::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - gstate.Merge(lstate.lhs_data); - - OperatorSinkCombineInput distinct_combine_input {*distinct->sink_state, *lstate.distinct_state, - input.interrupt_state}; - distinct->Combine(context, distinct_combine_input); - - return SinkCombineResultType::FINISHED; -} - -void PhysicalLeftDelimJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { - distinct->PrepareFinalize(context, *distinct->sink_state); -} - -SinkFinalizeType PhysicalLeftDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, - OperatorSinkFinalizeInput &input) const { - // finalize the distinct HT - D_ASSERT(distinct); - - OperatorSinkFinalizeInput finalize_input {*distinct->sink_state, input.interrupt_state}; - distinct->Finalize(pipeline, event, client, finalize_input); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalLeftDelimJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - - D_ASSERT(type == PhysicalOperatorType::LEFT_DELIM_JOIN); - // recurse into the actual join - // any pipelines in there depend on the main pipeline - // any scan of the duplicate eliminated data on the RHS depends on this pipeline - // we 
add an entry to the mapping of (PhysicalOperator*) -> (Pipeline*) - auto &state = meta_pipeline.GetState(); - for (auto &delim_scan : delim_scans) { - state.delim_join_dependencies.insert( - make_pair(delim_scan, reference(*child_meta_pipeline.GetBasePipeline()))); - } - join->BuildPipelines(current, meta_pipeline); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp b/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp deleted file mode 100644 index 022337e5a..000000000 --- a/src/duckdb/src/execution/operator/join/physical_nested_loop_join.cpp +++ /dev/null @@ -1,470 +0,0 @@ -#include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/nested_loop_join.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/execution/operator/join/outer_join_marker.hpp" - -namespace duckdb { - -PhysicalNestedLoopJoin::PhysicalNestedLoopJoin(LogicalOperator &op, unique_ptr left, - unique_ptr right, vector cond, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalComparisonJoin(op, PhysicalOperatorType::NESTED_LOOP_JOIN, std::move(cond), join_type, - estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -bool PhysicalJoin::HasNullValues(DataChunk &chunk) { - for (idx_t col_idx = 0; col_idx < chunk.ColumnCount(); col_idx++) { - UnifiedVectorFormat vdata; - chunk.data[col_idx].ToUnifiedFormat(chunk.size(), vdata); - - if (vdata.validity.AllValid()) { - continue; - } - for (idx_t i = 0; i < chunk.size(); i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - return true; - } - } - } - return false; -} - -template -static void ConstructSemiOrAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { - D_ASSERT(left.ColumnCount() == result.ColumnCount()); - // create the selection vector from the matches that were found - idx_t result_count = 0; - SelectionVector sel(STANDARD_VECTOR_SIZE); - for (idx_t i = 0; i < left.size(); i++) { - if (found_match[i] == MATCH) { - sel.set_index(result_count++, i); - } - } - // construct the final result - if (result_count > 0) { - // we only return the columns on the left side - // project them using the result selection vector - // reference the columns of the left side from the result - result.Slice(left, sel, result_count); - } else { - result.SetCardinality(0); - } -} - -void PhysicalJoin::ConstructSemiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { - ConstructSemiOrAntiJoinResult(left, result, found_match); -} - -void PhysicalJoin::ConstructAntiJoinResult(DataChunk &left, DataChunk &result, bool found_match[]) { - ConstructSemiOrAntiJoinResult(left, result, found_match); -} - -void PhysicalJoin::ConstructMarkJoinResult(DataChunk &join_keys, DataChunk &left, DataChunk &result, bool found_match[], - bool has_null) { - // for the initial set of columns we just reference the left side - result.SetCardinality(left); - for (idx_t i = 0; i < left.ColumnCount(); i++) { - result.data[i].Reference(left.data[i]); - } - auto &mark_vector = result.data.back(); - mark_vector.SetVectorType(VectorType::FLAT_VECTOR); - // first we set the NULL values from the join keys - // if there is any NULL in the keys, the 
result is NULL - auto bool_result = FlatVector::GetData(mark_vector); - auto &mask = FlatVector::Validity(mark_vector); - for (idx_t col_idx = 0; col_idx < join_keys.ColumnCount(); col_idx++) { - UnifiedVectorFormat jdata; - join_keys.data[col_idx].ToUnifiedFormat(join_keys.size(), jdata); - if (!jdata.validity.AllValid()) { - for (idx_t i = 0; i < join_keys.size(); i++) { - auto jidx = jdata.sel->get_index(i); - mask.Set(i, jdata.validity.RowIsValid(jidx)); - } - } - } - // now set the remaining entries to either true or false based on whether a match was found - if (found_match) { - for (idx_t i = 0; i < left.size(); i++) { - bool_result[i] = found_match[i]; - } - } else { - memset(bool_result, 0, sizeof(bool) * left.size()); - } - // if the right side contains NULL values, the result of any FALSE becomes NULL - if (has_null) { - for (idx_t i = 0; i < left.size(); i++) { - if (!bool_result[i]) { - mask.SetInvalid(i); - } - } - } -} - -bool PhysicalNestedLoopJoin::IsSupported(const vector &conditions, JoinType join_type) { - if (join_type == JoinType::MARK) { - return true; - } - for (auto &cond : conditions) { - if (cond.left->return_type.InternalType() == PhysicalType::STRUCT || - cond.left->return_type.InternalType() == PhysicalType::LIST || - cond.left->return_type.InternalType() == PhysicalType::ARRAY) { - return false; - } - } - // To avoid situations like https://github.com/duckdb/duckdb/issues/10046 - // If there is an equality in the conditions, a hash join is planned - // with one condition, we can use mark join logic, otherwise we should use physical blockwise nl join - if (join_type == JoinType::SEMI || join_type == JoinType::ANTI) { - return conditions.size() == 1; - } - return true; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class NestedLoopJoinLocalState : public LocalSinkState { -public: - explicit NestedLoopJoinLocalState(ClientContext &context, const vector &conditions) - : rhs_executor(context) { - vector condition_types; - for (auto &cond : conditions) { - rhs_executor.AddExpression(*cond.right); - condition_types.push_back(cond.right->return_type); - } - right_condition.Initialize(Allocator::Get(context), condition_types); - } - - //! The chunk holding the right condition - DataChunk right_condition; - //! The executor of the RHS condition - ExpressionExecutor rhs_executor; -}; - -class NestedLoopJoinGlobalState : public GlobalSinkState { -public: - explicit NestedLoopJoinGlobalState(ClientContext &context, const PhysicalNestedLoopJoin &op) - : right_payload_data(context, op.children[1]->types), right_condition_data(context, op.GetJoinTypes()), - has_null(false), right_outer(PropagatesBuildSide(op.join_type)) { - } - - mutex nj_lock; - //! Materialized data of the RHS - ColumnDataCollection right_payload_data; - //! Materialized join condition of the RHS - ColumnDataCollection right_condition_data; - //! Whether or not the RHS of the nested loop join has NULL values - atomic has_null; - //! 
A bool indicating for each tuple in the RHS if they found a match (only used in FULL OUTER JOIN) - OuterJoinMarker right_outer; -}; - -vector PhysicalNestedLoopJoin::GetJoinTypes() const { - vector result; - for (auto &op : conditions) { - result.push_back(op.right->return_type); - } - return result; -} - -SinkResultType PhysicalNestedLoopJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &nlj_state = input.local_state.Cast(); - - // resolve the join expression of the right side - nlj_state.right_condition.Reset(); - nlj_state.rhs_executor.Execute(chunk, nlj_state.right_condition); - - // if we have not seen any NULL values yet, and we are performing a MARK join, check if there are NULL values in - // this chunk - if (join_type == JoinType::MARK && !gstate.has_null) { - if (HasNullValues(nlj_state.right_condition)) { - gstate.has_null = true; - } - } - - // append the payload data and the conditions - lock_guard nj_guard(gstate.nj_lock); - gstate.right_payload_data.Append(chunk); - gstate.right_condition_data.Append(nlj_state.right_condition); - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalNestedLoopJoin::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalNestedLoopJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - gstate.right_outer.Initialize(gstate.right_payload_data.Count()); - if (gstate.right_payload_data.Count() == 0 && EmptyResultIfRHSIsEmpty()) { - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - return SinkFinalizeType::READY; -} - -unique_ptr PhysicalNestedLoopJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalNestedLoopJoin::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, conditions); -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class PhysicalNestedLoopJoinState : public CachingOperatorState { -public: - PhysicalNestedLoopJoinState(ClientContext &context, const PhysicalNestedLoopJoin &op, - const vector &conditions) - : fetch_next_left(true), fetch_next_right(false), lhs_executor(context), left_tuple(0), right_tuple(0), - left_outer(IsLeftOuterJoin(op.join_type)) { - vector condition_types; - for (auto &cond : conditions) { - lhs_executor.AddExpression(*cond.left); - condition_types.push_back(cond.left->return_type); - } - auto &allocator = Allocator::Get(context); - left_condition.Initialize(allocator, condition_types); - right_condition.Initialize(allocator, condition_types); - right_payload.Initialize(allocator, op.children[1]->GetTypes()); - left_outer.Initialize(STANDARD_VECTOR_SIZE); - } - - bool fetch_next_left; - bool fetch_next_right; - DataChunk left_condition; - //! 
The executor of the LHS condition - ExpressionExecutor lhs_executor; - - ColumnDataScanState condition_scan_state; - ColumnDataScanState payload_scan_state; - DataChunk right_condition; - DataChunk right_payload; - - idx_t left_tuple; - idx_t right_tuple; - - OuterJoinMarker left_outer; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op); - } -}; - -unique_ptr PhysicalNestedLoopJoin::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context.client, *this, conditions); -} - -OperatorResultType PhysicalNestedLoopJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &gstate = sink_state->Cast(); - - if (gstate.right_payload_data.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gstate.has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - switch (join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, input, chunk, state_p); - return OperatorResultType::NEED_MORE_INPUT; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::OUTER: - case JoinType::RIGHT: - return ResolveComplexJoin(context, input, chunk, state_p); - default: - throw NotImplementedException("Unimplemented type " + JoinTypeToString(join_type) + " for nested loop join!"); - } -} - -void PhysicalNestedLoopJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - // resolve the left join condition for the current chunk - state.left_condition.Reset(); - state.lhs_executor.Execute(input, state.left_condition); - - bool found_match[STANDARD_VECTOR_SIZE] = {false}; - NestedLoopJoinMark::Perform(state.left_condition, gstate.right_condition_data, found_match, conditions); - switch (join_type) { - case JoinType::MARK: - // now construct the mark join result from the found matches - PhysicalJoin::ConstructMarkJoinResult(state.left_condition, input, chunk, found_match, gstate.has_null); - break; - case JoinType::SEMI: - // construct the semi join result from the found matches - PhysicalJoin::ConstructSemiJoinResult(input, chunk, found_match); - break; - case JoinType::ANTI: - // construct the anti join result from the found matches - PhysicalJoin::ConstructAntiJoinResult(input, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented type for simple nested loop join!"); - } -} - -OperatorResultType PhysicalNestedLoopJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - idx_t match_count; - do { - if (state.fetch_next_right) { - // we exhausted the chunk on the right: move to the next chunk on the right - state.left_tuple = 0; - state.right_tuple = 0; - state.fetch_next_right = false; - // check if we exhausted all chunks on the RHS - if (gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition)) { - if (!gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload)) { - throw InternalException("Nested loop join: payload and conditions 
are unaligned!?"); - } - if (state.right_condition.size() != state.right_payload.size()) { - throw InternalException("Nested loop join: payload and conditions are unaligned!?"); - } - } else { - // we exhausted all chunks on the right: move to the next chunk on the left - state.fetch_next_left = true; - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - state.left_outer.ConstructLeftJoinResult(input, chunk); - state.left_outer.Reset(); - } - return OperatorResultType::NEED_MORE_INPUT; - } - } - if (state.fetch_next_left) { - // resolve the left join condition for the current chunk - state.left_condition.Reset(); - state.lhs_executor.Execute(input, state.left_condition); - - state.left_tuple = 0; - state.right_tuple = 0; - gstate.right_condition_data.InitializeScan(state.condition_scan_state); - gstate.right_condition_data.Scan(state.condition_scan_state, state.right_condition); - - gstate.right_payload_data.InitializeScan(state.payload_scan_state); - gstate.right_payload_data.Scan(state.payload_scan_state, state.right_payload); - state.fetch_next_left = false; - } - // now we have a left and a right chunk that we can join together - // note that we only get here in the case of a LEFT, INNER or FULL join - auto &left_chunk = input; - auto &right_condition = state.right_condition; - auto &right_payload = state.right_payload; - - // sanity check - left_chunk.Verify(); - right_condition.Verify(); - right_payload.Verify(); - - // now perform the join - SelectionVector lvector(STANDARD_VECTOR_SIZE), rvector(STANDARD_VECTOR_SIZE); - match_count = NestedLoopJoinInner::Perform(state.left_tuple, state.right_tuple, state.left_condition, - right_condition, lvector, rvector, conditions); - // we have finished resolving the join conditions - if (match_count > 0) { - // we have matching tuples! 
- // construct the result - state.left_outer.SetMatches(lvector, match_count); - gstate.right_outer.SetMatches(rvector, match_count, state.condition_scan_state.current_row_index); - - chunk.Slice(input, lvector, match_count); - chunk.Slice(right_payload, rvector, match_count, input.ColumnCount()); - } - - // check if we exhausted the RHS, if we did we need to move to the next right chunk in the next iteration - if (state.right_tuple >= right_condition.size()) { - state.fetch_next_right = true; - } - } while (match_count == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class NestedLoopJoinGlobalScanState : public GlobalSourceState { -public: - explicit NestedLoopJoinGlobalScanState(const PhysicalNestedLoopJoin &op) : op(op) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(sink.right_payload_data, scan_state); - } - - const PhysicalNestedLoopJoin &op; - OuterJoinGlobalScanState scan_state; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.right_outer.MaxThreads(); - } -}; - -class NestedLoopJoinLocalScanState : public LocalSourceState { -public: - explicit NestedLoopJoinLocalScanState(const PhysicalNestedLoopJoin &op, NestedLoopJoinGlobalScanState &gstate) { - D_ASSERT(op.sink_state); - auto &sink = op.sink_state->Cast(); - sink.right_outer.InitializeScan(gstate.scan_state, scan_state); - } - - OuterJoinLocalScanState scan_state; -}; - -unique_ptr PhysicalNestedLoopJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -unique_ptr PhysicalNestedLoopJoin::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(*this, gstate.Cast()); -} - -SourceResultType PhysicalNestedLoopJoin::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - D_ASSERT(PropagatesBuildSide(join_type)); - // check if we need to scan any unmatched tuples from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan chunks we still need to output - sink.right_outer.Scan(gstate.scan_state, lstate.scan_state, chunk); - - return chunk.size() == 0 ?
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp b/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp deleted file mode 100644 index 8216d91ab..000000000 --- a/src/duckdb/src/execution/operator/join/physical_piecewise_merge_join.cpp +++ /dev/null @@ -1,768 +0,0 @@ -#include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/operator/join/outer_join_marker.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/thread_context.hpp" - -namespace duckdb { - -PhysicalPiecewiseMergeJoin::PhysicalPiecewiseMergeJoin(LogicalComparisonJoin &op, unique_ptr left, - unique_ptr right, vector cond, - JoinType join_type, idx_t estimated_cardinality) - : PhysicalRangeJoin(op, PhysicalOperatorType::PIECEWISE_MERGE_JOIN, std::move(left), std::move(right), - std::move(cond), join_type, estimated_cardinality) { - - for (auto &cond : conditions) { - D_ASSERT(cond.left->return_type == cond.right->return_type); - join_key_types.push_back(cond.left->return_type); - - // Convert the conditions to sort orders - auto left = cond.left->Copy(); - auto right = cond.right->Copy(); - switch (cond.comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - lhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - break; - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - lhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_LAST, std::move(right)); - break; - case ExpressionType::COMPARE_NOTEQUAL: - case ExpressionType::COMPARE_DISTINCT_FROM: - // Allowed in multi-predicate joins, but can't be first/sort. - D_ASSERT(!lhs_orders.empty()); - lhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(left)); - rhs_orders.emplace_back(OrderType::INVALID, OrderByNullType::NULLS_LAST, std::move(right)); - break; - - default: - // COMPARE EQUAL not supported with merge join - throw NotImplementedException("Unimplemented join type for merge join"); - } - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class MergeJoinLocalState : public LocalSinkState { -public: - explicit MergeJoinLocalState(ClientContext &context, const PhysicalRangeJoin &op, const idx_t child) - : table(context, op, child) { - } - - //! 
The local sort state - PhysicalRangeJoin::LocalSortedTable table; -}; - -class MergeJoinGlobalState : public GlobalSinkState { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - MergeJoinGlobalState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op) { - RowLayout rhs_layout; - rhs_layout.Initialize(op.children[1]->types); - vector rhs_order; - rhs_order.emplace_back(op.rhs_orders[0].Copy()); - table = make_uniq(context, rhs_order, rhs_layout, op); - } - - inline idx_t Count() const { - return table->count; - } - - void Sink(DataChunk &input, MergeJoinLocalState &lstate) { - auto &global_sort_state = table->global_sort_state; - auto &local_sort_state = lstate.table.local_sort_state; - - // Sink the data into the local sort state - lstate.table.Sink(input, global_sort_state); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= table->memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - } - - unique_ptr table; -}; - -unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr PhysicalPiecewiseMergeJoin::GetLocalSinkState(ExecutionContext &context) const { - // We only sink the RHS - return make_uniq(context.client, *this, 1U); -} - -SinkResultType PhysicalPiecewiseMergeJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - gstate.Sink(chunk, lstate); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalPiecewiseMergeJoin::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.table->Combine(lstate.table); - auto &client_profiler = QueryProfiler::Get(context.client); - - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalPiecewiseMergeJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &global_sort_state = gstate.table->global_sort_state; - - if (PropagatesBuildSide(join_type)) { - // for FULL/RIGHT OUTER JOIN, initialize found_match to false for every tuple - gstate.table->IntializeMatches(); - } - if (global_sort_state.sorted_blocks.empty() && EmptyResultIfRHSIsEmpty()) { - // Empty input! 
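
This EmptyResultIfRHSIsEmpty check encodes which join types can short-circuit entirely when the build side turns out to be empty. A hedged sketch of that mapping (the enum and function are stand-ins for illustration; DuckDB's own predicate also covers right-semi/right-anti variants):

#include <cassert>

enum class SketchJoinType { INNER, LEFT, OUTER, RIGHT, SEMI, ANTI, MARK };

// Joins that cannot produce any row from an empty RHS may finish immediately;
// the rest must still emit (NULL-padded, anti, or false-marked) LHS rows.
static bool EmptyResultIfRHSIsEmptySketch(SketchJoinType t) {
	switch (t) {
	case SketchJoinType::INNER:
	case SketchJoinType::SEMI:
	case SketchJoinType::RIGHT:
		return true;
	default:
		return false;
	}
}

int main() {
	assert(EmptyResultIfRHSIsEmptySketch(SketchJoinType::INNER));
	assert(!EmptyResultIfRHSIsEmptySketch(SketchJoinType::ANTI)); // every LHS row survives
	return 0;
}
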
- return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Sort the current input child - gstate.table->Finalize(pipeline, event); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -class PiecewiseMergeJoinState : public CachingOperatorState { -public: - using LocalSortedTable = PhysicalRangeJoin::LocalSortedTable; - - PiecewiseMergeJoinState(ClientContext &context, const PhysicalPiecewiseMergeJoin &op, bool force_external) - : context(context), allocator(Allocator::Get(context)), op(op), - buffer_manager(BufferManager::GetBufferManager(context)), force_external(force_external), - left_outer(IsLeftOuterJoin(op.join_type)), left_position(0), first_fetch(true), finished(true), - right_position(0), right_chunk_index(0), rhs_executor(context) { - vector condition_types; - for (auto &order : op.lhs_orders) { - condition_types.push_back(order.expression->return_type); - } - left_outer.Initialize(STANDARD_VECTOR_SIZE); - lhs_layout.Initialize(op.children[0]->types); - lhs_payload.Initialize(allocator, op.children[0]->types); - - lhs_order.emplace_back(op.lhs_orders[0].Copy()); - - // Set up shared data for multiple predicates - sel.Initialize(STANDARD_VECTOR_SIZE); - condition_types.clear(); - for (auto &order : op.rhs_orders) { - rhs_executor.AddExpression(*order.expression); - condition_types.push_back(order.expression->return_type); - } - rhs_keys.Initialize(allocator, condition_types); - } - - ClientContext &context; - Allocator &allocator; - const PhysicalPiecewiseMergeJoin &op; - BufferManager &buffer_manager; - bool force_external; - - // Block sorting - DataChunk lhs_payload; - OuterJoinMarker left_outer; - vector lhs_order; - RowLayout lhs_layout; - unique_ptr lhs_local_table; - unique_ptr lhs_global_state; - unique_ptr scanner; - - // Simple scans - idx_t left_position; - - // Complex scans - bool first_fetch; - bool finished; - idx_t right_position; - idx_t right_chunk_index; - idx_t right_base; - idx_t prev_left_index; - - // Secondary predicate shared data - SelectionVector sel; - DataChunk rhs_keys; - DataChunk rhs_input; - ExpressionExecutor rhs_executor; - vector payload_heap_handles; - -public: - void ResolveJoinKeys(DataChunk &input) { - // sort by join key - lhs_global_state = make_uniq(buffer_manager, lhs_order, lhs_layout); - lhs_local_table = make_uniq(context, op, 0U); - lhs_local_table->Sink(input, *lhs_global_state); - - // Set external (can be forced with the PRAGMA) - lhs_global_state->external = force_external; - lhs_global_state->AddLocalState(lhs_local_table->local_sort_state); - lhs_global_state->PrepareMergePhase(); - while (lhs_global_state->sorted_blocks.size() > 1) { - MergeSorter merge_sorter(*lhs_global_state, buffer_manager); - merge_sorter.PerformInMergeRound(); - lhs_global_state->CompleteMergeRound(); - } - - // Scan the sorted payload - D_ASSERT(lhs_global_state->sorted_blocks.size() == 1); - - scanner = make_uniq(*lhs_global_state->sorted_blocks[0]->payload_data, *lhs_global_state); - lhs_payload.Reset(); - scanner->Scan(lhs_payload); - - // Recompute the sorted keys from the sorted input - lhs_local_table->keys.Reset(); - lhs_local_table->executor.Execute(lhs_payload, lhs_local_table->keys); - } - - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - if (lhs_local_table) { - context.thread.profiler.Flush(op); - } - } -}; - -unique_ptr 
PhysicalPiecewiseMergeJoin::GetOperatorState(ExecutionContext &context) const { - auto &config = ClientConfig::GetConfig(context.client); - return make_uniq(context.client, *this, config.force_external); -} - -static inline idx_t SortedBlockNotNull(const idx_t base, const idx_t count, const idx_t not_null) { - return MinValue(base + count, MaxValue(base, not_null)) - base; -} - -static int MergeJoinComparisonValue(ExpressionType comparison) { - switch (comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_GREATERTHAN: - return -1; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return 0; - default: - throw InternalException("Unimplemented comparison type for merge join!"); - } -} - -struct BlockMergeInfo { - GlobalSortState &state; - //! The block being scanned - const idx_t block_idx; - //! The number of not-NULL values in the block (they are at the end) - const idx_t not_null; - //! The current offset in the block - idx_t &entry_idx; - SelectionVector result; - - BlockMergeInfo(GlobalSortState &state, idx_t block_idx, idx_t &entry_idx, idx_t not_null) - : state(state), block_idx(block_idx), not_null(not_null), entry_idx(entry_idx), result(STANDARD_VECTOR_SIZE) { - } -}; - -static void MergeJoinPinSortingBlock(SBScanState &scan, const idx_t block_idx) { - scan.SetIndices(block_idx, 0); - scan.PinRadix(block_idx); - - auto &sd = *scan.sb->blob_sorting_data; - if (block_idx < sd.data_blocks.size()) { - scan.PinData(sd); - } -} - -static data_ptr_t MergeJoinRadixPtr(SBScanState &scan, const idx_t entry_idx) { - scan.entry_idx = entry_idx; - return scan.RadixPtr(); -} - -static idx_t MergeJoinSimpleBlocks(PiecewiseMergeJoinState &lstate, MergeJoinGlobalState &rstate, bool *found_match, - const ExpressionType comparison) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - auto &lsort = *lstate.lhs_global_state; - auto &rsort = rstate.table->global_sort_state; - D_ASSERT(lsort.sort_layout.all_constant == rsort.sort_layout.all_constant); - const auto all_constant = lsort.sort_layout.all_constant; - D_ASSERT(lsort.external == rsort.external); - const auto external = lsort.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(lsort.sorted_blocks.size() == 1); - SBScanState lread(lsort.buffer_manager, lsort); - lread.sb = lsort.sorted_blocks[0].get(); - - const idx_t l_block_idx = 0; - idx_t l_entry_idx = 0; - const auto lhs_not_null = lstate.lhs_local_table->count - lstate.lhs_local_table->has_null; - MergeJoinPinSortingBlock(lread, l_block_idx); - auto l_ptr = MergeJoinRadixPtr(lread, l_entry_idx); - - D_ASSERT(rsort.sorted_blocks.size() == 1); - SBScanState rread(rsort.buffer_manager, rsort); - rread.sb = rsort.sorted_blocks[0].get(); - - const auto cmp_size = lsort.sort_layout.comparison_size; - const auto entry_size = lsort.sort_layout.entry_size; - - idx_t right_base = 0; - for (idx_t r_block_idx = 0; r_block_idx < rread.sb->radix_sorting_data.size(); r_block_idx++) { - // we only care about the BIGGEST value in each of the RHS data blocks - // because we want to figure out if the LHS values are less than [or equal] to ANY value - // get the biggest value from the RHS chunk - MergeJoinPinSortingBlock(rread, r_block_idx); - - auto &rblock = *rread.sb->radix_sorting_data[r_block_idx]; - const auto r_not_null = - SortedBlockNotNull(right_base, rblock.count, rstate.table->count - rstate.table->has_null); - if 
(r_not_null == 0) { - break; - } - const auto r_entry_idx = r_not_null - 1; - right_base += rblock.count; - - auto r_ptr = MergeJoinRadixPtr(rread, r_entry_idx); - - // now we start from the current lpos value and check if we found a new value that is [<= OR <] the max RHS - // value - while (true) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l_entry_idx; - rread.entry_idx = r_entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, lsort.sort_layout, external); - } - - if (comp_res <= cmp) { - // found a match for lpos, set it in the found_match vector - found_match[l_entry_idx] = true; - l_entry_idx++; - l_ptr += entry_size; - if (l_entry_idx >= lhs_not_null) { - // early out: we exhausted the entire LHS and they all match - return 0; - } - } else { - // we found no match: any subsequent value from the LHS we scan now will be bigger and thus also not - // match move to the next RHS chunk - break; - } - } - } - return 0; -} - -void PhysicalPiecewiseMergeJoin::ResolveSimpleJoin(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - - state.ResolveJoinKeys(input); - auto &lhs_table = *state.lhs_local_table; - - // perform the actual join - bool found_match[STANDARD_VECTOR_SIZE]; - memset(found_match, 0, sizeof(found_match)); - MergeJoinSimpleBlocks(state, gstate, found_match, conditions[0].comparison); - - // use the sorted payload - const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - auto &payload = state.lhs_payload; - - // now construct the result based on the join result - switch (join_type) { - case JoinType::MARK: { - // The only part of the join keys that is actually used is the validity mask. - // Since the payload is sorted, we can just set the tail end of the validity masks to invalid.
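
The loop just below rebuilds each key's validity mask purely from position: the keys were sorted NULLS LAST, so rows [0, lhs_not_null) are valid and everything after is NULL. The same idea in isolation, with std::vector<bool> standing in for a ValidityMask:

#include <cstddef>
#include <vector>

// With NULLS LAST ordering, validity is a function of position alone:
// the first `not_null` rows are valid, the tail is NULL.
std::vector<bool> TailValidity(size_t count, size_t not_null) {
	std::vector<bool> valid(count, false);
	for (size_t i = 0; i < not_null; ++i) {
		valid[i] = true;
	}
	return valid;
}
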
- for (auto &key : lhs_table.keys.data) { - key.Flatten(lhs_table.keys.size()); - auto &mask = FlatVector::Validity(key); - if (mask.AllValid()) { - continue; - } - mask.SetAllValid(lhs_not_null); - for (idx_t i = lhs_not_null; i < lhs_table.count; ++i) { - mask.SetInvalid(i); - } - } - // So we make a set of keys that have the validity mask set for the - PhysicalJoin::ConstructMarkJoinResult(lhs_table.keys, payload, chunk, found_match, gstate.table->has_null); - break; - } - case JoinType::SEMI: - PhysicalJoin::ConstructSemiJoinResult(payload, chunk, found_match); - break; - case JoinType::ANTI: - PhysicalJoin::ConstructAntiJoinResult(payload, chunk, found_match); - break; - default: - throw NotImplementedException("Unimplemented join type for merge join"); - } -} - -static idx_t MergeJoinComplexBlocks(BlockMergeInfo &l, BlockMergeInfo &r, const ExpressionType comparison, - idx_t &prev_left_index) { - const auto cmp = MergeJoinComparisonValue(comparison); - - // The sort parameters should all be the same - D_ASSERT(l.state.sort_layout.all_constant == r.state.sort_layout.all_constant); - const auto all_constant = r.state.sort_layout.all_constant; - D_ASSERT(l.state.external == r.state.external); - const auto external = l.state.external; - - // There should only be one sorted block if they have been sorted - D_ASSERT(l.state.sorted_blocks.size() == 1); - SBScanState lread(l.state.buffer_manager, l.state); - lread.sb = l.state.sorted_blocks[0].get(); - D_ASSERT(lread.sb->radix_sorting_data.size() == 1); - MergeJoinPinSortingBlock(lread, l.block_idx); - auto l_start = MergeJoinRadixPtr(lread, 0); - auto l_ptr = MergeJoinRadixPtr(lread, l.entry_idx); - - D_ASSERT(r.state.sorted_blocks.size() == 1); - SBScanState rread(r.state.buffer_manager, r.state); - rread.sb = r.state.sorted_blocks[0].get(); - - if (r.entry_idx >= r.not_null) { - return 0; - } - - MergeJoinPinSortingBlock(rread, r.block_idx); - auto r_ptr = MergeJoinRadixPtr(rread, r.entry_idx); - - const auto cmp_size = l.state.sort_layout.comparison_size; - const auto entry_size = l.state.sort_layout.entry_size; - - idx_t result_count = 0; - while (true) { - if (l.entry_idx < prev_left_index) { - // left side smaller: found match - l.result.set_index(result_count, sel_t(l.entry_idx)); - r.result.set_index(result_count, sel_t(r.entry_idx)); - result_count++; - // move left side forward - l.entry_idx++; - l_ptr += entry_size; - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! - break; - } - continue; - } - if (l.entry_idx < l.not_null) { - int comp_res; - if (all_constant) { - comp_res = FastMemcmp(l_ptr, r_ptr, cmp_size); - } else { - lread.entry_idx = l.entry_idx; - rread.entry_idx = r.entry_idx; - comp_res = Comparators::CompareTuple(lread, rread, l_ptr, r_ptr, l.state.sort_layout, external); - } - if (comp_res <= cmp) { - // left side smaller: found match - l.result.set_index(result_count, sel_t(l.entry_idx)); - r.result.set_index(result_count, sel_t(r.entry_idx)); - result_count++; - // move left side forward - l.entry_idx++; - l_ptr += entry_size; - if (result_count == STANDARD_VECTOR_SIZE) { - // out of space! 
- break; - } - continue; - } - } - - prev_left_index = l.entry_idx; - // right side smaller or equal, or left side exhausted: move - // right pointer forward reset left side to start - r.entry_idx++; - if (r.entry_idx >= r.not_null) { - break; - } - r_ptr += entry_size; - - l_ptr = l_start; - l.entry_idx = 0; - } - - return result_count; -} - -OperatorResultType PhysicalPiecewiseMergeJoin::ResolveComplexJoin(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, OperatorState &state_p) const { - auto &state = state_p.Cast(); - auto &gstate = sink_state->Cast(); - auto &rsorted = *gstate.table->global_sort_state.sorted_blocks[0]; - const auto left_cols = input.ColumnCount(); - const auto tail_cols = conditions.size() - 1; - - state.payload_heap_handles.clear(); - do { - if (state.first_fetch) { - state.ResolveJoinKeys(input); - - state.right_chunk_index = 0; - state.right_base = 0; - state.left_position = 0; - state.prev_left_index = 0; - state.right_position = 0; - state.first_fetch = false; - state.finished = false; - } - if (state.finished) { - if (state.left_outer.Enabled()) { - // left join: before we move to the next chunk, see if we need to output any vectors that didn't - // have a match found - state.left_outer.ConstructLeftJoinResult(state.lhs_payload, chunk); - state.left_outer.Reset(); - } - state.first_fetch = true; - state.finished = false; - return OperatorResultType::NEED_MORE_INPUT; - } - - auto &lhs_table = *state.lhs_local_table; - const auto lhs_not_null = lhs_table.count - lhs_table.has_null; - BlockMergeInfo left_info(*state.lhs_global_state, 0, state.left_position, lhs_not_null); - - const auto &rblock = *rsorted.radix_sorting_data[state.right_chunk_index]; - const auto rhs_not_null = - SortedBlockNotNull(state.right_base, rblock.count, gstate.table->count - gstate.table->has_null); - BlockMergeInfo right_info(gstate.table->global_sort_state, state.right_chunk_index, state.right_position, - rhs_not_null); - - idx_t result_count = - MergeJoinComplexBlocks(left_info, right_info, conditions[0].comparison, state.prev_left_index); - if (result_count == 0) { - // exhausted this chunk on the right side - // move to the next right chunk - state.left_position = 0; - state.right_position = 0; - state.right_base += rsorted.radix_sorting_data[state.right_chunk_index]->count; - state.right_chunk_index++; - if (state.right_chunk_index >= rsorted.radix_sorting_data.size()) { - state.finished = true; - } - } else { - // found matches: extract them - chunk.Reset(); - for (idx_t c = 0; c < state.lhs_payload.ColumnCount(); ++c) { - chunk.data[c].Slice(state.lhs_payload.data[c], left_info.result, result_count); - } - state.payload_heap_handles.push_back(SliceSortedPayload(chunk, right_info.state, right_info.block_idx, - right_info.result, result_count, left_cols)); - chunk.SetCardinality(result_count); - - auto sel = FlatVector::IncrementalSelectionVector(); - if (tail_cols) { - // If there are more expressions to compute, - // split the result chunk into the left and right halves - // so we can compute the values for comparison.
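
Each secondary condition below is evaluated with SelectJoinTail (defined further down, in physical_range_join.cpp), which narrows the current selection vector; only rows that survive every predicate remain in the result. A distilled sketch of that progressive filtering, with predicates as plain callables instead of DuckDB vector comparisons:

#include <cstddef>
#include <functional>
#include <vector>

using RowPredicate = std::function<bool(size_t)>; // row index -> keep?

// Applies each tail predicate in turn, compacting the selection in place,
// just as the conditions[1..n] loop shrinks `sel` and `tail_count`.
size_t SelectTail(const std::vector<RowPredicate> &tail_conditions, std::vector<size_t> &sel) {
	for (const auto &cond : tail_conditions) {
		size_t kept = 0;
		for (size_t i = 0; i < sel.size(); ++i) {
			if (cond(sel[i])) {
				sel[kept++] = sel[i]; // keep, preserving order
			}
		}
		sel.resize(kept);
	}
	return sel.size();
}
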
- chunk.Split(state.rhs_input, left_cols); - state.rhs_executor.SetChunk(state.rhs_input); - state.rhs_keys.Reset(); - - auto tail_count = result_count; - for (size_t cmp_idx = 1; cmp_idx < conditions.size(); ++cmp_idx) { - Vector left(lhs_table.keys.data[cmp_idx]); - left.Slice(left_info.result, result_count); - - auto &right = state.rhs_keys.data[cmp_idx]; - state.rhs_executor.ExecuteExpression(cmp_idx, right); - - if (tail_count < result_count) { - left.Slice(*sel, tail_count); - right.Slice(*sel, tail_count); - } - tail_count = - SelectJoinTail(conditions[cmp_idx].comparison, left, right, sel, tail_count, &state.sel); - sel = &state.sel; - } - chunk.Fuse(state.rhs_input); - - if (tail_count < result_count) { - result_count = tail_count; - if (result_count == 0) { - // Need to reset here otherwise we may use the non-flat chunk when constructing LEFT/OUTER - chunk.Reset(); - } else { - chunk.Slice(*sel, result_count); - } - } - } - - // found matches: mark the found matches if required - if (state.left_outer.Enabled()) { - for (idx_t i = 0; i < result_count; i++) { - state.left_outer.SetMatch(left_info.result[sel->get_index(i)]); - } - } - if (gstate.table->found_match) { - // Absolute position of the block + start position inside that block - for (idx_t i = 0; i < result_count; i++) { - gstate.table->found_match[state.right_base + right_info.result[sel->get_index(i)]] = true; - } - } - chunk.SetCardinality(result_count); - chunk.Verify(); - } - } while (chunk.size() == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType PhysicalPiecewiseMergeJoin::ExecuteInternal(ExecutionContext &context, DataChunk &input, - DataChunk &chunk, GlobalOperatorState &gstate_p, - OperatorState &state) const { - auto &gstate = sink_state->Cast(); - - if (gstate.Count() == 0) { - // empty RHS - if (!EmptyResultIfRHSIsEmpty()) { - ConstructEmptyJoinResult(join_type, gstate.table->has_null, input, chunk); - return OperatorResultType::NEED_MORE_INPUT; - } else { - return OperatorResultType::FINISHED; - } - } - - input.Verify(); - switch (join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::MARK: - // simple joins can have max STANDARD_VECTOR_SIZE matches per chunk - ResolveSimpleJoin(context, input, chunk, state); - return OperatorResultType::NEED_MORE_INPUT; - case JoinType::LEFT: - case JoinType::INNER: - case JoinType::RIGHT: - case JoinType::OUTER: - return ResolveComplexJoin(context, input, chunk, state); - default: - throw NotImplementedException("Unimplemented type for piecewise merge loop join!"); - } -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PiecewiseJoinScanState : public GlobalSourceState { -public: - explicit PiecewiseJoinScanState(const PhysicalPiecewiseMergeJoin &op) : op(op), right_outer_position(0) { - } - - mutex lock; - const PhysicalPiecewiseMergeJoin &op; - unique_ptr scanner; - idx_t right_outer_position; - -public: - idx_t MaxThreads() override { - auto &sink = op.sink_state->Cast(); - return sink.Count() / (STANDARD_VECTOR_SIZE * idx_t(10)); - } -}; - -unique_ptr PhysicalPiecewiseMergeJoin::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalPiecewiseMergeJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - D_ASSERT(PropagatesBuildSide(join_type)); - // check if we need to scan any unmatched tuples 
from the RHS for the full/right outer join - auto &sink = sink_state->Cast(); - auto &state = input.global_state.Cast(); - - lock_guard l(state.lock); - if (!state.scanner) { - // Initialize scanner (if not yet initialized) - auto &sort_state = sink.table->global_sort_state; - if (sort_state.sorted_blocks.empty()) { - return SourceResultType::FINISHED; - } - state.scanner = make_uniq(*sort_state.sorted_blocks[0]->payload_data, sort_state); - } - - // if the LHS is exhausted in a FULL/RIGHT OUTER JOIN, we scan the found_match for any chunks we - // still need to output - const auto found_match = sink.table->found_match.get(); - - DataChunk rhs_chunk; - rhs_chunk.Initialize(Allocator::Get(context.client), sink.table->global_sort_state.payload_layout.GetTypes()); - SelectionVector rsel(STANDARD_VECTOR_SIZE); - for (;;) { - // Read the next sorted chunk - state.scanner->Scan(rhs_chunk); - - const auto count = rhs_chunk.size(); - if (count == 0) { - return result.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; - } - - idx_t result_count = 0; - // figure out which tuples didn't find a match in the RHS - for (idx_t i = 0; i < count; i++) { - if (!found_match[state.right_outer_position + i]) { - rsel.set_index(result_count++, i); - } - } - state.right_outer_position += count; - - if (result_count > 0) { - // if there were any tuples that didn't find a match, output them - const idx_t left_column_count = children[0]->types.size(); - for (idx_t col_idx = 0; col_idx < left_column_count; ++col_idx) { - result.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result.data[col_idx], true); - } - const idx_t right_column_count = children[1]->types.size(); - ; - for (idx_t col_idx = 0; col_idx < right_column_count; ++col_idx) { - result.data[left_column_count + col_idx].Slice(rhs_chunk.data[col_idx], rsel, result_count); - } - result.SetCardinality(result_count); - break; - } - } - - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp b/src/duckdb/src/execution/operator/join/physical_positional_join.cpp deleted file mode 100644 index bcf4b498b..000000000 --- a/src/duckdb/src/execution/operator/join/physical_positional_join.cpp +++ /dev/null @@ -1,196 +0,0 @@ -#include "duckdb/execution/operator/join/physical_positional_join.hpp" - -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/operator/join/physical_join.hpp" - -namespace duckdb { - -PhysicalPositionalJoin::PhysicalPositionalJoin(vector types, unique_ptr left, - unique_ptr right, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::POSITIONAL_JOIN, std::move(types), estimated_cardinality) { - children.push_back(std::move(left)); - children.push_back(std::move(right)); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class PositionalJoinGlobalState : public GlobalSinkState { -public: - explicit PositionalJoinGlobalState(ClientContext &context, const PhysicalPositionalJoin &op) - : rhs(context, op.children[1]->GetTypes()), initialized(false), source_offset(0), exhausted(false) { - rhs.InitializeAppend(append_state); - } - - ColumnDataCollection rhs; - ColumnDataAppendState append_state; - mutex rhs_lock; - - bool initialized; - ColumnDataScanState scan_state; - DataChunk source; - idx_t source_offset; - bool exhausted; - - void InitializeScan(); - idx_t Refill(); - idx_t CopyData(DataChunk &output, const idx_t count, const idx_t col_offset); - void Execute(DataChunk &input, DataChunk &output); - void GetData(DataChunk &output); -}; - -unique_ptr PhysicalPositionalJoin::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SinkResultType PhysicalPositionalJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &sink = input.global_state.Cast(); - lock_guard client_guard(sink.rhs_lock); - sink.rhs.Append(sink.append_state, chunk); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -void PositionalJoinGlobalState::InitializeScan() { - if (!initialized) { - // not initialized yet: initialize the scan - initialized = true; - rhs.InitializeScanChunk(source); - rhs.InitializeScan(scan_state); - } -} - -idx_t PositionalJoinGlobalState::Refill() { - if (source_offset >= source.size()) { - if (!exhausted) { - source.Reset(); - rhs.Scan(scan_state, source); - } - source_offset = 0; - } - - const auto available = source.size() - source_offset; - if (!available) { - if (!exhausted) { - source.Reset(); - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - auto &vec = source.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - exhausted = true; - } - } - - return available; -} - -idx_t PositionalJoinGlobalState::CopyData(DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { - // Fast track: aligned and has enough data - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - output.data[col_offset + i].Reference(source.data[i]); - } - source_offset += count; - 
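
Refill and CopyData implement the positional semantics: rows pair up strictly by position, and once one side runs out it is padded with constant NULL vectors. Distilled to scalar form, with std::optional standing in for SQL NULL (these are illustrative types, not the DuckDB ones):

#include <algorithm>
#include <cstddef>
#include <optional>
#include <utility>
#include <vector>

template <class L, class R>
std::vector<std::pair<std::optional<L>, std::optional<R>>> PositionalJoin(const std::vector<L> &lhs,
                                                                          const std::vector<R> &rhs) {
	std::vector<std::pair<std::optional<L>, std::optional<R>>> out;
	const size_t n = std::max(lhs.size(), rhs.size());
	for (size_t i = 0; i < n; ++i) {
		// the exhausted side contributes NULL, mirroring the constant NULL vectors
		out.emplace_back(i < lhs.size() ? std::optional<L>(lhs[i]) : std::nullopt,
		                 i < rhs.size() ? std::optional<R>(rhs[i]) : std::nullopt);
	}
	return out;
}
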
} else { - // Copy data - for (idx_t target_offset = 0; target_offset < count;) { - const auto needed = count - target_offset; - const auto available = exhausted ? needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, - target_offset); - } - target_offset += copy_size; - source_offset += copy_size; - Refill(); - } - } - - return source.ColumnCount(); -} - -void PositionalJoinGlobalState::Execute(DataChunk &input, DataChunk &output) { - lock_guard client_guard(rhs_lock); - - // Reference the input and assume it will be full - const auto col_offset = input.ColumnCount(); - for (idx_t i = 0; i < col_offset; ++i) { - output.data[i].Reference(input.data[i]); - } - - // Copy or reference the RHS columns - const auto count = input.size(); - InitializeScan(); - Refill(); - CopyData(output, count, col_offset); - - output.SetCardinality(count); -} - -OperatorResultType PhysicalPositionalJoin::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &sink = sink_state->Cast(); - sink.Execute(input, chunk); - return OperatorResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -void PositionalJoinGlobalState::GetData(DataChunk &output) { - lock_guard client_guard(rhs_lock); - - InitializeScan(); - Refill(); - - // LHS exhausted - if (exhausted) { - // RHS exhausted too, so we are done - output.SetCardinality(0); - return; - } - - // LHS is all NULL - const auto col_offset = output.ColumnCount() - source.ColumnCount(); - for (idx_t i = 0; i < col_offset; ++i) { - auto &vec = output.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - - // RHS still has data, so copy it - const auto count = MinValue(STANDARD_VECTOR_SIZE, source.size() - source_offset); - CopyData(output, count, col_offset); - output.SetCardinality(count); -} - -SourceResultType PhysicalPositionalJoin::GetData(ExecutionContext &context, DataChunk &result, - OperatorSourceInput &input) const { - auto &sink = sink_state->Cast(); - sink.GetData(result); - - return result.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalPositionalJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *this); -} - -vector> PhysicalPositionalJoin::GetSources() const { - auto result = children[0]->GetSources(); - if (IsSource()) { - result.push_back(*this); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_range_join.cpp b/src/duckdb/src/execution/operator/join/physical_range_join.cpp deleted file mode 100644 index f7701f845..000000000 --- a/src/duckdb/src/execution/operator/join/physical_range_join.cpp +++ /dev/null @@ -1,402 +0,0 @@ -#include "duckdb/execution/operator/join/physical_range_join.hpp" - -#include "duckdb/common/fast_mem.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/sort/comparators.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/validity_mask.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/parallel/executor_task.hpp" - -#include - -namespace duckdb { - -PhysicalRangeJoin::LocalSortedTable::LocalSortedTable(ClientContext &context, const PhysicalRangeJoin &op, - const idx_t child) - : op(op), executor(context), has_null(0), count(0) { - // Initialize order clause expression executor and key DataChunk - vector types; - for (const auto &cond : op.conditions) { - const auto &expr = child ? 
cond.right : cond.left; - executor.AddExpression(*expr); - - types.push_back(expr->return_type); - } - auto &allocator = Allocator::Get(context); - keys.Initialize(allocator, types); -} - -void PhysicalRangeJoin::LocalSortedTable::Sink(DataChunk &input, GlobalSortState &global_sort_state) { - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, global_sort_state.buffer_manager); - } - - // Obtain sorting columns - keys.Reset(); - executor.Execute(input, keys); - - // Do not operate on primary key directly to avoid modifying the input chunk - Vector primary = keys.data[0]; - // Count the NULLs so we can exclude them later - has_null += MergeNulls(primary, op.conditions); - count += keys.size(); - - // Only sort the primary key - DataChunk join_head; - join_head.data.emplace_back(primary); - join_head.SetCardinality(keys.size()); - - // Sink the data into the local sort state - local_sort_state.SinkChunk(join_head, input); -} - -PhysicalRangeJoin::GlobalSortedTable::GlobalSortedTable(ClientContext &context, const vector &orders, - RowLayout &payload_layout, const PhysicalOperator &op_p) - : op(op_p), global_sort_state(BufferManager::GetBufferManager(context), orders, payload_layout), has_null(0), - count(0), memory_per_thread(0) { - - // Set external (can be forced with the PRAGMA) - auto &config = ClientConfig::GetConfig(context); - global_sort_state.external = config.force_external; - memory_per_thread = PhysicalRangeJoin::GetMaxThreadMemory(context); -} - -void PhysicalRangeJoin::GlobalSortedTable::Combine(LocalSortedTable <able) { - global_sort_state.AddLocalState(ltable.local_sort_state); - has_null += ltable.has_null; - count += ltable.count; -} - -void PhysicalRangeJoin::GlobalSortedTable::IntializeMatches() { - found_match = make_unsafe_uniq_array_uninitialized(Count()); - memset(found_match.get(), 0, sizeof(bool) * Count()); -} - -void PhysicalRangeJoin::GlobalSortedTable::Print() { - global_sort_state.Print(); -} - -class RangeJoinMergeTask : public ExecutorTask { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - RangeJoinMergeTask(shared_ptr event_p, ClientContext &context, GlobalSortedTable &table) - : ExecutorTask(context, std::move(event_p), table.op), context(context), table(table) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize iejoin sorted and iterate until done - auto &global_sort_state = table.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); - - return TaskExecutionResult::TASK_FINISHED; - } - -private: - ClientContext &context; - GlobalSortedTable &table; -}; - -class RangeJoinMergeEvent : public BasePipelineEvent { -public: - using GlobalSortedTable = PhysicalRangeJoin::GlobalSortedTable; - -public: - RangeJoinMergeEvent(GlobalSortedTable &table_p, Pipeline &pipeline_p) - : BasePipelineEvent(pipeline_p), table(table_p) { - } - - GlobalSortedTable &table; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> iejoin_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - iejoin_tasks.push_back(make_uniq(shared_from_this(), context, table)); - } - 
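
Schedule above fans out one RangeJoinMergeTask per scheduler thread, and FinishEvent below starts another full round whenever more than one sorted run remains. Assuming pairwise merging per round (a simplification of what MergeSorter does), the number of rounds is logarithmic in the number of runs:

#include <cstddef>

// Number of merge rounds needed to reduce `runs` sorted runs to one,
// assuming each round merges runs in pairs.
size_t MergeRounds(size_t runs) {
	size_t rounds = 0;
	while (runs > 1) {
		runs = (runs + 1) / 2; // each round halves the number of runs
		++rounds;
	}
	return rounds;
}

For example, 64 thread-local runs converge after 6 such rounds, each scheduled as its own event.
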
SetTasks(std::move(iejoin_tasks)); - } - - void FinishEvent() override { - auto &global_sort_state = table.global_sort_state; - - global_sort_state.CompleteMergeRound(true); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - table.ScheduleMergeTasks(*pipeline, *this); - } - } -}; - -void PhysicalRangeJoin::GlobalSortedTable::ScheduleMergeTasks(Pipeline &pipeline, Event &event) { - // Initialize global sort state for a round of merging - global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(*this, pipeline); - event.InsertEvent(std::move(new_event)); -} - -void PhysicalRangeJoin::GlobalSortedTable::Finalize(Pipeline &pipeline, Event &event) { - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - ScheduleMergeTasks(pipeline, event); - } -} - -PhysicalRangeJoin::PhysicalRangeJoin(LogicalComparisonJoin &op, PhysicalOperatorType type, - unique_ptr left, unique_ptr right, - vector cond, JoinType join_type, idx_t estimated_cardinality) - : PhysicalComparisonJoin(op, type, std::move(cond), join_type, estimated_cardinality) { - // Reorder the conditions so that ranges are at the front. - // TODO: use stats to improve the choice? - // TODO: Prefer fixed length types? - if (conditions.size() > 1) { - vector conditions_p(conditions.size()); - std::swap(conditions_p, conditions); - idx_t range_position = 0; - idx_t other_position = conditions_p.size(); - for (idx_t i = 0; i < conditions_p.size(); ++i) { - switch (conditions_p[i].comparison) { - case ExpressionType::COMPARE_LESSTHAN: - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - conditions[range_position++] = std::move(conditions_p[i]); - break; - default: - conditions[--other_position] = std::move(conditions_p[i]); - break; - } - } - } - - children.push_back(std::move(left)); - children.push_back(std::move(right)); - - // Fill out the left projection map. - left_projection_map = op.left_projection_map; - if (left_projection_map.empty()) { - const auto left_count = children[0]->types.size(); - left_projection_map.reserve(left_count); - for (column_t i = 0; i < left_count; ++i) { - left_projection_map.emplace_back(i); - } - } - // Fill out the right projection map. 
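
Both projection-map fills (the left one above, the right one below) follow the same convention: an empty map from the planner means "keep every column", so it is expanded into the identity mapping. In isolation:

#include <cstddef>
#include <numeric>
#include <vector>

// An empty projection map defaults to emitting every child column in order.
std::vector<size_t> ResolveProjectionMap(std::vector<size_t> map, size_t child_width) {
	if (map.empty()) {
		map.resize(child_width);
		std::iota(map.begin(), map.end(), size_t(0)); // 0, 1, ..., child_width - 1
	}
	return map;
}
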
- right_projection_map = op.right_projection_map; - if (right_projection_map.empty()) { - const auto right_count = children[1]->types.size(); - right_projection_map.reserve(right_count); - for (column_t i = 0; i < right_count; ++i) { - right_projection_map.emplace_back(i); - } - } - - // Construct the unprojected type layout from the children's types - unprojected_types = children[0]->GetTypes(); - auto &types = children[1]->GetTypes(); - unprojected_types.insert(unprojected_types.end(), types.begin(), types.end()); -} - -idx_t PhysicalRangeJoin::LocalSortedTable::MergeNulls(Vector &primary, const vector &conditions) { - // Merge the validity masks of the comparison keys into the primary - // Return the number of NULLs in the resulting chunk - D_ASSERT(keys.ColumnCount() > 0); - const auto count = keys.size(); - - size_t all_constant = 0; - for (auto &v : keys.data) { - if (v.GetVectorType() == VectorType::CONSTANT_VECTOR) { - ++all_constant; - } - } - - if (all_constant == keys.data.size()) { - // Either all NULL or no NULLs - if (ConstantVector::IsNull(primary)) { - // Primary is already NULL - return count; - } - for (size_t c = 1; c < keys.data.size(); ++c) { - // Skip comparisons that accept NULLs - if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) { - continue; - } - auto &v = keys.data[c]; - if (ConstantVector::IsNull(v)) { - // Create a new validity mask to avoid modifying original mask - auto &pvalidity = ConstantVector::Validity(primary); - ValidityMask pvalidity_copy = ConstantVector::Validity(primary); - pvalidity.Copy(pvalidity_copy, count); - ConstantVector::SetNull(primary, true); - return count; - } - } - return 0; - } else if (keys.ColumnCount() > 1) { - // Flatten the primary, as it will need to merge arbitrary validity masks - primary.Flatten(count); - auto &pvalidity = FlatVector::Validity(primary); - // Make a copy of validity to avoid modifying original mask - ValidityMask pvalidity_copy = FlatVector::Validity(primary); - pvalidity.Copy(pvalidity_copy, count); - - D_ASSERT(keys.ColumnCount() == conditions.size()); - for (size_t c = 1; c < keys.data.size(); ++c) { - // Skip comparisons that accept NULLs - if (conditions[c].comparison == ExpressionType::COMPARE_DISTINCT_FROM) { - continue; - } - // ToUnifiedFormat the rest, as the sort code will do this anyway. 
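
The FLAT_VECTOR branch below merges whole 64-bit validity entries at a time instead of testing rows individually. A sketch of that entry-wise AND, with raw uint64_t words standing in for ValidityMask storage (one bit per row, 64 rows per entry):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// A row stays valid only if it is valid in both masks.
void MergeValidity(std::vector<uint64_t> &primary, const std::vector<uint64_t> &other) {
	const size_t entries = std::min(primary.size(), other.size());
	for (size_t i = 0; i < entries; ++i) {
		primary[i] &= other[i];
	}
}
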
- auto &v = keys.data[c]; - UnifiedVectorFormat vdata; - v.ToUnifiedFormat(count, vdata); - auto &vvalidity = vdata.validity; - if (vvalidity.AllValid()) { - continue; - } - pvalidity.EnsureWritable(); - switch (v.GetVectorType()) { - case VectorType::FLAT_VECTOR: { - // Merge entire entries - auto pmask = pvalidity.GetData(); - const auto entry_count = pvalidity.EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; ++entry_idx) { - pmask[entry_idx] &= vvalidity.GetValidityEntry(entry_idx); - } - break; - } - case VectorType::CONSTANT_VECTOR: - // All or nothing - if (ConstantVector::IsNull(v)) { - pvalidity.SetAllInvalid(count); - return count; - } - break; - default: - // One by one - for (idx_t i = 0; i < count; ++i) { - const auto idx = vdata.sel->get_index(i); - if (!vvalidity.RowIsValidUnsafe(idx)) { - pvalidity.SetInvalidUnsafe(i); - } - } - break; - } - } - return count - pvalidity.CountValid(count); - } else { - return count - VectorOperations::CountNotNull(primary, count); - } -} - -void PhysicalRangeJoin::ProjectResult(DataChunk &chunk, DataChunk &result) const { - const auto left_projected = left_projection_map.size(); - for (idx_t i = 0; i < left_projected; ++i) { - result.data[i].Reference(chunk.data[left_projection_map[i]]); - } - const auto left_width = children[0]->types.size(); - for (idx_t i = 0; i < right_projection_map.size(); ++i) { - result.data[left_projected + i].Reference(chunk.data[left_width + right_projection_map[i]]); - } - result.SetCardinality(chunk); -} - -BufferHandle PhysicalRangeJoin::SliceSortedPayload(DataChunk &payload, GlobalSortState &state, const idx_t block_idx, - const SelectionVector &result, const idx_t result_count, - const idx_t left_cols) { - // There should only be one sorted block if they have been sorted - D_ASSERT(state.sorted_blocks.size() == 1); - SBScanState read_state(state.buffer_manager, state); - read_state.sb = state.sorted_blocks[0].get(); - auto &sorted_data = *read_state.sb->payload_data; - - read_state.SetIndices(block_idx, 0); - read_state.PinData(sorted_data); - const auto data_ptr = read_state.DataPtr(sorted_data); - data_ptr_t heap_ptr = nullptr; - - // Set up a batch of pointers to scan data from - Vector addresses(LogicalType::POINTER, result_count); - auto data_pointers = FlatVector::GetData(addresses); - - // Set up the data pointers for the values that are actually referenced - const idx_t &row_width = sorted_data.layout.GetRowWidth(); - - auto prev_idx = result.get_index(0); - SelectionVector gsel(result_count); - idx_t addr_count = 0; - gsel.set_index(0, addr_count); - data_pointers[addr_count] = data_ptr + prev_idx * row_width; - for (idx_t i = 1; i < result_count; ++i) { - const auto row_idx = result.get_index(i); - if (row_idx != prev_idx) { - data_pointers[++addr_count] = data_ptr + row_idx * row_width; - prev_idx = row_idx; - } - gsel.set_index(i, addr_count); - } - ++addr_count; - - // Unswizzle the offsets back to pointers (if needed) - if (!sorted_data.layout.AllConstant() && state.external) { - heap_ptr = read_state.payload_heap_handle.Ptr(); - } - - // Deserialize the payload data - auto sel = FlatVector::IncrementalSelectionVector(); - for (idx_t col_no = 0; col_no < sorted_data.layout.ColumnCount(); col_no++) { - auto &col = payload.data[left_cols + col_no]; - RowOperations::Gather(addresses, *sel, col, *sel, addr_count, sorted_data.layout, col_no, 0, heap_ptr); - col.Slice(gsel, result_count); - } - - return std::move(read_state.payload_heap_handle); -} - -idx_t 
PhysicalRangeJoin::SelectJoinTail(const ExpressionType &condition, Vector &left, Vector &right, - const SelectionVector *sel, idx_t count, SelectionVector *true_sel) { - switch (condition) { - case ExpressionType::COMPARE_NOTEQUAL: - return VectorOperations::NotEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_LESSTHAN: - return VectorOperations::LessThan(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_GREATERTHAN: - return VectorOperations::GreaterThan(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - return VectorOperations::LessThanEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - return VectorOperations::GreaterThanEquals(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_DISTINCT_FROM: - return VectorOperations::DistinctFrom(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_NOT_DISTINCT_FROM: - return VectorOperations::NotDistinctFrom(left, right, sel, count, true_sel, nullptr); - case ExpressionType::COMPARE_EQUAL: - return VectorOperations::Equals(left, right, sel, count, true_sel, nullptr); - default: - throw InternalException("Unsupported comparison type for PhysicalRangeJoin"); - } - - return count; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp b/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp deleted file mode 100644 index 60aaeaca8..000000000 --- a/src/duckdb/src/execution/operator/join/physical_right_delim_join.cpp +++ /dev/null @@ -1,126 +0,0 @@ -#include "duckdb/execution/operator/join/physical_right_delim_join.hpp" - -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_join.hpp" -#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" - -namespace duckdb { - -PhysicalRightDelimJoin::PhysicalRightDelimJoin(vector types, unique_ptr original_join, - vector> delim_scans, - idx_t estimated_cardinality, optional_idx delim_idx) - : PhysicalDelimJoin(PhysicalOperatorType::RIGHT_DELIM_JOIN, std::move(types), std::move(original_join), - std::move(delim_scans), estimated_cardinality, delim_idx) { - D_ASSERT(join->children.size() == 2); - // now for the original join - // we take its right child, this is the side that we will duplicate eliminate - children.push_back(std::move(join->children[1])); - - // we replace it with a PhysicalDummyScan, which contains no data, just the types, it won't be scanned anyway - join->children[1] = make_uniq(children[0]->GetTypes(), estimated_cardinality); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class RightDelimJoinGlobalState : public GlobalSinkState {}; - -class RightDelimJoinLocalState : public LocalSinkState { -public: - unique_ptr join_state; - unique_ptr distinct_state; -}; - -unique_ptr PhysicalRightDelimJoin::GetGlobalSinkState(ClientContext &context) const { - auto state = make_uniq(); - join->sink_state = join->GetGlobalSinkState(context); - distinct->sink_state = distinct->GetGlobalSinkState(context); - if (delim_scans.size() > 1) { - 
PhysicalHashAggregate::SetMultiScan(*distinct->sink_state); - } - return std::move(state); -} - -unique_ptr PhysicalRightDelimJoin::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(); - state->join_state = join->GetLocalSinkState(context); - state->distinct_state = distinct->GetLocalSinkState(context); - return std::move(state); -} - -SinkResultType PhysicalRightDelimJoin::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - - OperatorSinkInput join_sink_input {*join->sink_state, *lstate.join_state, input.interrupt_state}; - join->Sink(context, chunk, join_sink_input); - - OperatorSinkInput distinct_sink_input {*distinct->sink_state, *lstate.distinct_state, input.interrupt_state}; - distinct->Sink(context, chunk, distinct_sink_input); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalRightDelimJoin::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - - OperatorSinkCombineInput join_combine_input {*join->sink_state, *lstate.join_state, input.interrupt_state}; - join->Combine(context, join_combine_input); - - OperatorSinkCombineInput distinct_combine_input {*distinct->sink_state, *lstate.distinct_state, - input.interrupt_state}; - distinct->Combine(context, distinct_combine_input); - - return SinkCombineResultType::FINISHED; -} - -void PhysicalRightDelimJoin::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { - join->PrepareFinalize(context, *join->sink_state); - distinct->PrepareFinalize(context, *distinct->sink_state); -} - -SinkFinalizeType PhysicalRightDelimJoin::Finalize(Pipeline &pipeline, Event &event, ClientContext &client, - OperatorSinkFinalizeInput &input) const { - D_ASSERT(join); - D_ASSERT(distinct); - - OperatorSinkFinalizeInput join_finalize_input {*join->sink_state, input.interrupt_state}; - join->Finalize(pipeline, event, client, join_finalize_input); - - OperatorSinkFinalizeInput distinct_finalize_input {*distinct->sink_state, input.interrupt_state}; - distinct->Finalize(pipeline, event, client, distinct_finalize_input); - - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalRightDelimJoin::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - - D_ASSERT(type == PhysicalOperatorType::RIGHT_DELIM_JOIN); - // recurse into the actual join - // any pipelines in there depend on the main pipeline - // any scan of the duplicate eliminated data on the LHS depends on this pipeline - // we add an entry to the mapping of (PhysicalOperator*) -> (Pipeline*) - auto &state = meta_pipeline.GetState(); - for (auto &delim_scan : delim_scans) { - state.delim_join_dependencies.insert( - make_pair(delim_scan, reference(*child_meta_pipeline.GetBasePipeline()))); - } - - // Build join pipelines without building the RHS (already built in the Sink of this op) - PhysicalJoin::BuildJoinPipelines(current, meta_pipeline, *join, false); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/order/physical_order.cpp b/src/duckdb/src/execution/operator/order/physical_order.cpp deleted file mode 
100644 index 71294e9b7..000000000 --- a/src/duckdb/src/execution/operator/order/physical_order.cpp +++ /dev/null @@ -1,291 +0,0 @@ -#include "duckdb/execution/operator/order/physical_order.hpp" - -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/executor_task.hpp" -#include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/common/shared_ptr.hpp" - -namespace duckdb { - -PhysicalOrder::PhysicalOrder(vector types, vector orders, vector projections, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::ORDER_BY, std::move(types), estimated_cardinality), - orders(std::move(orders)), projections(std::move(projections)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class OrderGlobalSinkState : public GlobalSinkState { -public: - OrderGlobalSinkState(BufferManager &buffer_manager, const PhysicalOrder &order, RowLayout &payload_layout) - : order(order), global_sort_state(buffer_manager, order.orders, payload_layout) { - } - - const PhysicalOrder ℴ - //! Global sort state - GlobalSortState global_sort_state; - //! Memory usage per thread - idx_t memory_per_thread; -}; - -class OrderLocalSinkState : public LocalSinkState { -public: - OrderLocalSinkState(ClientContext &context, const PhysicalOrder &op) : key_executor(context) { - // Initialize order clause expression executor and DataChunk - vector key_types; - for (auto &order : op.orders) { - key_types.push_back(order.expression->return_type); - key_executor.AddExpression(*order.expression); - } - auto &allocator = Allocator::Get(context); - keys.Initialize(allocator, key_types); - payload.Initialize(allocator, op.types); - } - -public: - //! The local sort state - LocalSortState local_sort_state; - //! Key expression executor, and chunk to hold the vectors - ExpressionExecutor key_executor; - DataChunk keys; - //! 
Payload chunk to hold the vectors - DataChunk payload; -}; - -unique_ptr PhysicalOrder::GetGlobalSinkState(ClientContext &context) const { - // Get the payload layout from the return types - RowLayout payload_layout; - payload_layout.Initialize(types); - auto state = make_uniq(BufferManager::GetBufferManager(context), *this, payload_layout); - // Set external (can be force with the PRAGMA) - state->global_sort_state.external = ClientConfig::GetConfig(context).force_external; - state->memory_per_thread = GetMaxThreadMemory(context); - return std::move(state); -} - -unique_ptr PhysicalOrder::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -SinkResultType PhysicalOrder::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - auto &global_sort_state = gstate.global_sort_state; - auto &local_sort_state = lstate.local_sort_state; - - // Initialize local state (if necessary) - if (!local_sort_state.initialized) { - local_sort_state.Initialize(global_sort_state, BufferManager::GetBufferManager(context.client)); - } - - // Obtain sorting columns - auto &keys = lstate.keys; - keys.Reset(); - lstate.key_executor.Execute(chunk, keys); - - auto &payload = lstate.payload; - payload.ReferenceColumns(chunk, projections); - - // Sink the data into the local sort state - keys.Verify(); - chunk.Verify(); - local_sort_state.SinkChunk(keys, payload); - - // When sorting data reaches a certain size, we sort it - if (local_sort_state.SizeInBytes() >= gstate.memory_per_thread) { - local_sort_state.Sort(global_sort_state, true); - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalOrder::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - gstate.global_sort_state.AddLocalState(lstate.local_sort_state); - - return SinkCombineResultType::FINISHED; -} - -class PhysicalOrderMergeTask : public ExecutorTask { -public: - PhysicalOrderMergeTask(shared_ptr event_p, ClientContext &context, OrderGlobalSinkState &state, - const PhysicalOperator &op_p) - : ExecutorTask(context, std::move(event_p), op_p), context(context), state(state) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - // Initialize merge sorted and iterate until done - auto &global_sort_state = state.global_sort_state; - MergeSorter merge_sorter(global_sort_state, BufferManager::GetBufferManager(context)); - merge_sorter.PerformInMergeRound(); - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - ClientContext &context; - OrderGlobalSinkState &state; -}; - -class OrderMergeEvent : public BasePipelineEvent { -public: - OrderMergeEvent(OrderGlobalSinkState &gstate_p, Pipeline &pipeline_p, const PhysicalOperator &op_p) - : BasePipelineEvent(pipeline_p), gstate(gstate_p), op(op_p) { - } - - OrderGlobalSinkState &gstate; - const PhysicalOperator &op; - -public: - void Schedule() override { - auto &context = pipeline->GetClientContext(); - - // Schedule tasks equal to the number of threads, which will each merge multiple partitions - auto &ts = TaskScheduler::GetScheduler(context); - auto num_threads = NumericCast(ts.NumberOfThreads()); - - vector> merge_tasks; - for (idx_t tnum = 0; tnum < num_threads; tnum++) { - merge_tasks.push_back(make_uniq(shared_from_this(), context, gstate, op)); - } - 
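// Standalone sketch of the run-generation policy in PhysicalOrder::Sink above: rows
// accumulate in a thread-local buffer and are sorted into an immutable "run" whenever
// the buffer exceeds a per-thread memory budget. The types and the byte estimate are
// simplified placeholders, not DuckDB's sort internals.
#include <algorithm>
#include <cstddef>
#include <vector>

struct LocalRunBuilder {
	std::size_t memory_budget_bytes;
	std::vector<int> buffer;
	std::vector<std::vector<int>> sorted_runs;

	explicit LocalRunBuilder(std::size_t budget) : memory_budget_bytes(budget) {
	}

	void Sink(const std::vector<int> &chunk) {
		buffer.insert(buffer.end(), chunk.begin(), chunk.end());
		// same trigger as `local_sort_state.SizeInBytes() >= gstate.memory_per_thread`
		if (buffer.size() * sizeof(int) >= memory_budget_bytes) {
			FlushRun();
		}
	}

	void FlushRun() {
		if (buffer.empty()) {
			return;
		}
		std::sort(buffer.begin(), buffer.end());
		sorted_runs.push_back(std::move(buffer));
		buffer.clear();
	}
};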
SetTasks(std::move(merge_tasks)); - } - - void FinishEvent() override { - auto &global_sort_state = gstate.global_sort_state; - - global_sort_state.CompleteMergeRound(); - if (global_sort_state.sorted_blocks.size() > 1) { - // Multiple blocks remaining: Schedule the next round - PhysicalOrder::ScheduleMergeTasks(*pipeline, *this, gstate); - } - } -}; - -SinkFinalizeType PhysicalOrder::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &state = input.global_state.Cast(); - auto &global_sort_state = state.global_sort_state; - - if (global_sort_state.sorted_blocks.empty()) { - // Empty input! - return SinkFinalizeType::NO_OUTPUT_POSSIBLE; - } - - // Prepare for merge sort phase - global_sort_state.PrepareMergePhase(); - - // Start the merge phase or finish if a merge is not necessary - if (global_sort_state.sorted_blocks.size() > 1) { - PhysicalOrder::ScheduleMergeTasks(pipeline, event, state); - } - return SinkFinalizeType::READY; -} - -void PhysicalOrder::ScheduleMergeTasks(Pipeline &pipeline, Event &event, OrderGlobalSinkState &state) { - // Initialize global sort state for a round of merging - state.global_sort_state.InitializeMergeRound(); - auto new_event = make_shared_ptr(state, pipeline, state.order); - event.InsertEvent(std::move(new_event)); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class PhysicalOrderGlobalSourceState : public GlobalSourceState { -public: - explicit PhysicalOrderGlobalSourceState(OrderGlobalSinkState &sink) : next_batch_index(0) { - auto &global_sort_state = sink.global_sort_state; - if (global_sort_state.sorted_blocks.empty()) { - total_batches = 0; - } else { - D_ASSERT(global_sort_state.sorted_blocks.size() == 1); - total_batches = global_sort_state.sorted_blocks[0]->payload_data->data_blocks.size(); - } - } - - idx_t MaxThreads() override { - return total_batches; - } - -public: - atomic next_batch_index; - idx_t total_batches; -}; - -unique_ptr PhysicalOrder::GetGlobalSourceState(ClientContext &context) const { - auto &sink = this->sink_state->Cast(); - return make_uniq(sink); -} - -class PhysicalOrderLocalSourceState : public LocalSourceState { -public: - explicit PhysicalOrderLocalSourceState(PhysicalOrderGlobalSourceState &gstate) - : batch_index(gstate.next_batch_index++) { - } - -public: - idx_t batch_index; - unique_ptr scanner; -}; - -unique_ptr PhysicalOrder::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - return make_uniq(gstate); -} - -SourceResultType PhysicalOrder::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - if (lstate.scanner && lstate.scanner->Remaining() == 0) { - lstate.batch_index = gstate.next_batch_index++; - lstate.scanner = nullptr; - } - - if (lstate.batch_index >= gstate.total_batches) { - return SourceResultType::FINISHED; - } - - if (!lstate.scanner) { - auto &sink = this->sink_state->Cast(); - auto &global_sort_state = sink.global_sort_state; - lstate.scanner = make_uniq(global_sort_state, lstate.batch_index, true); - } - - lstate.scanner->Scan(chunk); - - return chunk.size() == 0 ? 
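// Sketch of the cascaded merge that OrderMergeEvent::FinishEvent drives above: each
// round merges runs pairwise, and another round is scheduled while more than one run
// remains. std::merge stands in for DuckDB's parallel MergeSorter.
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <vector>

std::vector<std::vector<int>> MergeRound(std::vector<std::vector<int>> runs) {
	std::vector<std::vector<int>> next;
	for (std::size_t i = 0; i < runs.size(); i += 2) {
		if (i + 1 == runs.size()) {
			next.push_back(std::move(runs[i])); // odd run carries over to the next round
			continue;
		}
		std::vector<int> merged;
		std::merge(runs[i].begin(), runs[i].end(), runs[i + 1].begin(), runs[i + 1].end(),
		           std::back_inserter(merged));
		next.push_back(std::move(merged));
	}
	return next;
}

std::vector<int> MergeAll(std::vector<std::vector<int>> runs) {
	// mirrors "if (sorted_blocks.size() > 1) ScheduleMergeTasks(...)": keep running
	// rounds until a single sorted block remains
	while (runs.size() > 1) {
		runs = MergeRound(std::move(runs));
	}
	return runs.empty() ? std::vector<int>() : std::move(runs.front());
}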
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -OperatorPartitionData PhysicalOrder::GetPartitionData(ExecutionContext &context, DataChunk &chunk, - GlobalSourceState &gstate_p, LocalSourceState &lstate_p, - const OperatorPartitionInfo &partition_info) const { - if (partition_info.RequiresPartitionColumns()) { - throw InternalException("PhysicalOrder::GetPartitionData: partition columns not supported"); - } - auto &lstate = lstate_p.Cast(); - return OperatorPartitionData(lstate.batch_index); -} - -InsertionOrderPreservingMap PhysicalOrder::ParamsToString() const { - InsertionOrderPreservingMap result; - string orders_info; - for (idx_t i = 0; i < orders.size(); i++) { - if (i > 0) { - orders_info += "\n"; - } - orders_info += orders[i].expression->ToString() + " "; - orders_info += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; - } - result["__order_by__"] = orders_info; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/order/physical_top_n.cpp b/src/duckdb/src/execution/operator/order/physical_top_n.cpp deleted file mode 100644 index 669c25a5e..000000000 --- a/src/duckdb/src/execution/operator/order/physical_top_n.cpp +++ /dev/null @@ -1,585 +0,0 @@ -#include "duckdb/execution/operator/order/physical_top_n.hpp" - -#include "duckdb/common/assert.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/create_sort_key.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/planner/filter/dynamic_filter.hpp" - -namespace duckdb { - -PhysicalTopN::PhysicalTopN(vector types, vector orders, idx_t limit, idx_t offset, - shared_ptr dynamic_filter_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::TOP_N, std::move(types), estimated_cardinality), orders(std::move(orders)), - limit(limit), offset(offset), dynamic_filter(std::move(dynamic_filter_p)) { -} - -PhysicalTopN::~PhysicalTopN() { -} - -//===--------------------------------------------------------------------===// -// Heaps -//===--------------------------------------------------------------------===// -class TopNHeap; - -struct TopNEntry { - string_t sort_key; - idx_t index; - - bool operator<(const TopNEntry &other) const { - return sort_key < other.sort_key; - } -}; - -struct TopNScanState { - TopNScanState() : pos(0), sel(STANDARD_VECTOR_SIZE) { - } - - idx_t pos; - vector scan_order; - SelectionVector sel; -}; - -struct TopNBoundaryValue { - explicit TopNBoundaryValue(const PhysicalTopN &op) - : op(op), boundary_vector(op.orders[0].expression->return_type), - boundary_modifiers(op.orders[0].type, op.orders[0].null_order) { - } - - const PhysicalTopN &op; - mutex lock; - string boundary_value; - bool is_set = false; - Vector boundary_vector; - OrderModifiers boundary_modifiers; - - string GetBoundaryValue() { - lock_guard l(lock); - return boundary_value; - } - - void UpdateValue(string_t boundary_val) { - unique_lock l(lock); - if (!is_set || boundary_val < string_t(boundary_value)) { - boundary_value = boundary_val.GetString(); - is_set = true; - if (op.dynamic_filter) { - CreateSortKeyHelpers::DecodeSortKey(boundary_val, boundary_vector, 0, boundary_modifiers); - auto new_dynamic_value = boundary_vector.GetValue(0); - l.unlock(); - op.dynamic_filter->SetValue(std::move(new_dynamic_value)); - } - } - } -}; - -class TopNHeap { -public: - TopNHeap(ClientContext &context, const vector &payload_types, const vector &orders, - idx_t limit, idx_t offset); - TopNHeap(ExecutionContext &context, const vector 
&payload_types, - const vector &orders, idx_t limit, idx_t offset); - TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset); - - Allocator &allocator; - BufferManager &buffer_manager; - unsafe_vector heap; - const vector &payload_types; - const vector &orders; - vector modifiers; - idx_t limit; - idx_t offset; - idx_t heap_size; - ExpressionExecutor executor; - DataChunk sort_chunk; - DataChunk heap_data; - DataChunk payload_chunk; - DataChunk sort_keys; - StringHeap sort_key_heap; - - SelectionVector matching_sel; - - DataChunk compare_chunk; - //! Cached global boundary value as a set of constant vectors - DataChunk boundary_values; - //! Cached global boundary value in sort-key format - string boundary_val; - SelectionVector final_sel; - SelectionVector true_sel; - SelectionVector false_sel; - SelectionVector new_remaining_sel; - -public: - void Sink(DataChunk &input, optional_ptr boundary_value = nullptr); - void Combine(TopNHeap &other); - void Reduce(); - void Finalize(); - - void InitializeScan(TopNScanState &state, bool exclude_offset); - void Scan(TopNScanState &state, DataChunk &chunk); - - bool CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload, TopNBoundaryValue &boundary_val); - void AddSmallHeap(DataChunk &input, Vector &sort_keys_vec); - void AddLargeHeap(DataChunk &input, Vector &sort_keys_vec); - -public: - idx_t ReduceThreshold() const { - return MaxValue(STANDARD_VECTOR_SIZE * 5ULL, 2ULL * heap_size); - } - - idx_t InitialHeapAllocSize() const { - return MinValue(STANDARD_VECTOR_SIZE * 100ULL, ReduceThreshold()) + STANDARD_VECTOR_SIZE; - } - -private: - inline bool EntryShouldBeAdded(const string_t &sort_key) { - if (heap.size() < heap_size) { - // heap is full - check the latest entry - return true; - } - if (sort_key < heap.front().sort_key) { - // sort key is smaller than current max value - return true; - } - // heap is full and there is no room for the entry - return false; - } - - inline void AddEntryToHeap(const TopNEntry &entry) { - if (heap.size() >= heap_size) { - std::pop_heap(heap.begin(), heap.end()); - heap.pop_back(); - } - heap.push_back(entry); - std::push_heap(heap.begin(), heap.end()); - } -}; - -//===--------------------------------------------------------------------===// -// TopNHeap -//===--------------------------------------------------------------------===// -TopNHeap::TopNHeap(ClientContext &context, Allocator &allocator, const vector &payload_types_p, - const vector &orders_p, idx_t limit, idx_t offset) - : allocator(allocator), buffer_manager(BufferManager::GetBufferManager(context)), payload_types(payload_types_p), - orders(orders_p), limit(limit), offset(offset), heap_size(limit + offset), executor(context), - matching_sel(STANDARD_VECTOR_SIZE), final_sel(STANDARD_VECTOR_SIZE), true_sel(STANDARD_VECTOR_SIZE), - false_sel(STANDARD_VECTOR_SIZE), new_remaining_sel(STANDARD_VECTOR_SIZE) { - // initialize the executor and the sort_chunk - vector sort_types; - for (auto &order : orders) { - auto &expr = order.expression; - sort_types.push_back(expr->return_type); - executor.AddExpression(*expr); - modifiers.emplace_back(order.type, order.null_order); - } - heap.reserve(InitialHeapAllocSize()); - vector sort_keys_type {LogicalType::BLOB}; - sort_keys.Initialize(allocator, sort_keys_type); - heap_data.Initialize(allocator, payload_types, InitialHeapAllocSize()); - payload_chunk.Initialize(allocator, payload_types); - sort_chunk.Initialize(allocator, 
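// Standalone sketch of the bounded heap that TopNHeap maintains above: a max-heap of at
// most k keys, where a new key is admitted only if the heap is not yet full or the key
// beats the current worst entry (the heap front), matching EntryShouldBeAdded and
// AddEntryToHeap. Keys are plain strings here instead of DuckDB sort keys.
#include <algorithm>
#include <cstddef>
#include <string>
#include <vector>

struct BoundedTopK {
	std::size_t k;
	std::vector<std::string> heap; // max-heap: front() is the current worst kept key

	explicit BoundedTopK(std::size_t k) : k(k) {
	}

	bool ShouldAdd(const std::string &key) const {
		return heap.size() < k || key < heap.front();
	}

	void Add(const std::string &key) {
		if (!ShouldAdd(key)) {
			return;
		}
		if (heap.size() >= k) {
			// evict the current maximum before inserting the better key
			std::pop_heap(heap.begin(), heap.end());
			heap.pop_back();
		}
		heap.push_back(key);
		std::push_heap(heap.begin(), heap.end());
	}
};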
sort_types); - compare_chunk.Initialize(allocator, sort_types); - boundary_values.Initialize(allocator, sort_types); -} - -TopNHeap::TopNHeap(ClientContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : TopNHeap(context, BufferAllocator::Get(context), payload_types, orders, limit, offset) { -} - -TopNHeap::TopNHeap(ExecutionContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : TopNHeap(context.client, Allocator::Get(context.client), payload_types, orders, limit, offset) { -} - -void TopNHeap::AddSmallHeap(DataChunk &input, Vector &sort_keys_vec) { - // insert the sort keys into the priority queue - constexpr idx_t BASE_INDEX = NumericLimits::Maximum(); - - bool any_added = false; - auto sort_key_values = FlatVector::GetData(sort_keys_vec); - for (idx_t r = 0; r < input.size(); r++) { - auto &sort_key = sort_key_values[r]; - if (!EntryShouldBeAdded(sort_key)) { - continue; - } - // replace the previous top entry with the new entry - TopNEntry entry; - entry.sort_key = sort_key; - entry.index = BASE_INDEX + r; - AddEntryToHeap(entry); - any_added = true; - } - if (!any_added) { - // early-out: no matches - return; - } - - // for all matching entries we need to copy over the corresponding payload values - idx_t match_count = 0; - for (auto &entry : heap) { - if (entry.index < BASE_INDEX) { - continue; - } - // this entry was added in this chunk - // if not inlined - copy over the string to the string heap - if (!entry.sort_key.IsInlined()) { - entry.sort_key = sort_key_heap.AddBlob(entry.sort_key); - } - // to finalize the addition of this entry we need to move over the payload data - matching_sel.set_index(match_count, entry.index - BASE_INDEX); - entry.index = heap_data.size() + match_count; - match_count++; - } - - // copy over the input rows to the payload chunk - heap_data.Append(input, true, &matching_sel, match_count); -} - -void TopNHeap::AddLargeHeap(DataChunk &input, Vector &sort_keys_vec) { - auto sort_key_values = FlatVector::GetData(sort_keys_vec); - idx_t base_index = heap_data.size(); - idx_t match_count = 0; - for (idx_t r = 0; r < input.size(); r++) { - auto &sort_key = sort_key_values[r]; - if (!EntryShouldBeAdded(sort_key)) { - continue; - } - // replace the previous top entry with the new entry - TopNEntry entry; - entry.sort_key = sort_key.IsInlined() ? 
sort_key : sort_key_heap.AddBlob(sort_key); - entry.index = base_index + match_count; - AddEntryToHeap(entry); - matching_sel.set_index(match_count++, r); - } - if (match_count == 0) { - // early-out: no matches - return; - } - - // copy over the input rows to the payload chunk - heap_data.Append(input, true, &matching_sel, match_count); -} - -bool TopNHeap::CheckBoundaryValues(DataChunk &sort_chunk, DataChunk &payload, TopNBoundaryValue &global_boundary) { - // get the global boundary value - auto current_boundary_val = global_boundary.GetBoundaryValue(); - if (current_boundary_val.empty()) { - // no boundary value (yet) - don't do anything - return true; - } - if (current_boundary_val != boundary_val) { - // new boundary value - decode - boundary_val = std::move(current_boundary_val); - boundary_values.Reset(); - CreateSortKeyHelpers::DecodeSortKey(string_t(boundary_val), boundary_values, 0, modifiers); - for (auto &col : boundary_values.data) { - col.SetVectorType(VectorType::CONSTANT_VECTOR); - } - } - boundary_values.SetCardinality(sort_chunk.size()); - - // we have boundary values - // from these boundary values, determine which values we should insert (if any) - idx_t final_count = 0; - - SelectionVector remaining_sel(nullptr); - idx_t remaining_count = sort_chunk.size(); - for (idx_t i = 0; i < orders.size(); i++) { - if (remaining_sel.data()) { - compare_chunk.data[i].Slice(sort_chunk.data[i], remaining_sel, remaining_count); - } else { - compare_chunk.data[i].Reference(sort_chunk.data[i]); - } - bool is_last = i + 1 == orders.size(); - idx_t true_count; - if (orders[i].null_order == OrderByNullType::NULLS_LAST) { - if (orders[i].type == OrderType::ASCENDING) { - true_count = VectorOperations::DistinctLessThan(compare_chunk.data[i], boundary_values.data[i], - &remaining_sel, remaining_count, &true_sel, &false_sel); - } else { - true_count = VectorOperations::DistinctGreaterThanNullsFirst(compare_chunk.data[i], - boundary_values.data[i], &remaining_sel, - remaining_count, &true_sel, &false_sel); - } - } else { - D_ASSERT(orders[i].null_order == OrderByNullType::NULLS_FIRST); - if (orders[i].type == OrderType::ASCENDING) { - true_count = VectorOperations::DistinctLessThanNullsFirst(compare_chunk.data[i], - boundary_values.data[i], &remaining_sel, - remaining_count, &true_sel, &false_sel); - } else { - true_count = - VectorOperations::DistinctGreaterThan(compare_chunk.data[i], boundary_values.data[i], - &remaining_sel, remaining_count, &true_sel, &false_sel); - } - } - - if (true_count > 0) { - memcpy(final_sel.data() + final_count, true_sel.data(), true_count * sizeof(sel_t)); - final_count += true_count; - } - idx_t false_count = remaining_count - true_count; - if (!is_last && false_count > 0) { - // check what we should continue to check - compare_chunk.data[i].Slice(sort_chunk.data[i], false_sel, false_count); - remaining_count = VectorOperations::NotDistinctFrom(compare_chunk.data[i], boundary_values.data[i], - &false_sel, false_count, &new_remaining_sel, nullptr); - remaining_sel.Initialize(new_remaining_sel); - } else { - break; - } - } - if (final_count == 0) { - return false; - } - if (final_count < sort_chunk.size()) { - sort_chunk.Slice(final_sel, final_count); - payload.Slice(final_sel, final_count); - } - return true; -} - -void TopNHeap::Sink(DataChunk &input, optional_ptr global_boundary) { - static constexpr idx_t SMALL_HEAP_THRESHOLD = 100; - - // compute the ordering values for the new chunk - sort_chunk.Reset(); - executor.Execute(input, sort_chunk); - - if 
(global_boundary) { - // if we have a global boundary value check which rows pass before doing anything - if (!CheckBoundaryValues(sort_chunk, input, *global_boundary)) { - // nothing in this chunk can be in the final result - return; - } - } - - // construct the sort key from the sort chunk - sort_keys.Reset(); - auto &sort_keys_vec = sort_keys.data[0]; - CreateSortKeyHelpers::CreateSortKey(sort_chunk, modifiers, sort_keys_vec); - - if (heap_size <= SMALL_HEAP_THRESHOLD) { - AddSmallHeap(input, sort_keys_vec); - } else { - AddLargeHeap(input, sort_keys_vec); - } - - // if we modified the heap we might be able to update the global boundary - // note that the global boundary only applies to FULL heaps - if (heap.size() >= heap_size && global_boundary) { - global_boundary->UpdateValue(heap.front().sort_key); - } -} - -void TopNHeap::Combine(TopNHeap &other) { - other.Finalize(); - - idx_t match_count = 0; - // merge the heap of other into this - for (idx_t i = 0; i < other.heap.size(); i++) { - // heap is full - check the latest entry - auto &other_entry = other.heap[i]; - auto &sort_key = other_entry.sort_key; - if (!EntryShouldBeAdded(sort_key)) { - continue; - } - // add this entry - TopNEntry new_entry; - new_entry.sort_key = sort_key.IsInlined() ? sort_key : sort_key_heap.AddBlob(sort_key); - new_entry.index = heap_data.size() + match_count; - AddEntryToHeap(new_entry); - - matching_sel.set_index(match_count++, other_entry.index); - if (match_count >= STANDARD_VECTOR_SIZE) { - // flush - heap_data.Append(other.heap_data, true, &matching_sel, match_count); - match_count = 0; - } - } - if (match_count > 0) { - // flush - heap_data.Append(other.heap_data, true, &matching_sel, match_count); - match_count = 0; - } - Reduce(); -} - -void TopNHeap::Finalize() { -} - -void TopNHeap::Reduce() { - if (heap_data.size() < ReduceThreshold()) { - // only reduce when we pass the reduce threshold - return; - } - // we have too many values in the heap - reduce them - StringHeap new_sort_heap; - DataChunk new_heap_data; - new_heap_data.Initialize(allocator, payload_types, heap.size()); - - SelectionVector new_payload_sel(heap.size()); - for (idx_t i = 0; i < heap.size(); i++) { - auto &entry = heap[i]; - // the entry is not inlined - move the sort key to the new sort heap - if (!entry.sort_key.IsInlined()) { - entry.sort_key = new_sort_heap.AddBlob(entry.sort_key); - } - // move this heap entry to position X in the payload chunk - new_payload_sel.set_index(i, entry.index); - entry.index = i; - } - - // copy over the data from the current payload chunk to the new payload chunk - new_heap_data.Slice(heap_data, new_payload_sel, heap.size()); - new_heap_data.Flatten(); - - sort_key_heap.Destroy(); - sort_key_heap.Move(new_sort_heap); - heap_data.Reference(new_heap_data); -} - -void TopNHeap::InitializeScan(TopNScanState &state, bool exclude_offset) { - auto heap_copy = heap; - // traverse the rest of the heap - state.scan_order.resize(heap_copy.size()); - while (!heap_copy.empty()) { - std::pop_heap(heap_copy.begin(), heap_copy.end()); - state.scan_order[heap_copy.size() - 1] = UnsafeNumericCast(heap_copy.back().index); - heap_copy.pop_back(); - } - state.pos = exclude_offset ? 
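// Sketch of the dynamic pruning that TopNBoundaryValue enables above: once the heap is
// full, its worst key is published as a shared boundary, and incoming rows whose keys
// cannot beat it are discarded before touching the heap (CheckBoundaryValues). This is
// a simplified stand-in; DuckDB additionally decodes the boundary back into column
// values to drive a dynamic table filter.
#include <mutex>
#include <string>

struct SharedBoundary {
	std::mutex lock;
	bool is_set = false;
	std::string worst_key;

	void Update(const std::string &key) {
		std::lock_guard<std::mutex> guard(lock);
		if (!is_set || key < worst_key) { // the boundary only ever tightens
			worst_key = key;
			is_set = true;
		}
	}

	bool CanPrune(const std::string &key) {
		std::lock_guard<std::mutex> guard(lock);
		return is_set && !(key < worst_key); // key cannot make it into the top-k
	}
};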
offset : 0; -} - -void TopNHeap::Scan(TopNScanState &state, DataChunk &chunk) { - if (state.pos >= state.scan_order.size()) { - return; - } - SelectionVector sel(state.scan_order.data() + state.pos); - idx_t count = MinValue(STANDARD_VECTOR_SIZE, state.scan_order.size() - state.pos); - state.pos += STANDARD_VECTOR_SIZE; - - chunk.Reset(); - chunk.Slice(heap_data, sel, count); -} - -class TopNGlobalState : public GlobalSinkState { -public: - TopNGlobalState(ClientContext &context, const PhysicalTopN &op) - : heap(context, op.types, op.orders, op.limit, op.offset), boundary_value(op) { - } - - mutex lock; - TopNHeap heap; - TopNBoundaryValue boundary_value; -}; - -class TopNLocalState : public LocalSinkState { -public: - TopNLocalState(ExecutionContext &context, const vector &payload_types, - const vector &orders, idx_t limit, idx_t offset) - : heap(context, payload_types, orders, limit, offset) { - } - - TopNHeap heap; -}; - -unique_ptr PhysicalTopN::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context, types, orders, limit, offset); -} - -unique_ptr PhysicalTopN::GetGlobalSinkState(ClientContext &context) const { - if (dynamic_filter) { - dynamic_filter->Reset(); - } - return make_uniq(context, *this); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalTopN::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - // append to the local sink state - auto &gstate = input.global_state.Cast(); - auto &sink = input.local_state.Cast(); - sink.heap.Sink(chunk, &gstate.boundary_value); - sink.heap.Reduce(); - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalTopN::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - - // scan the local top N and append it to the global heap - lock_guard glock(gstate.lock); - gstate.heap.Combine(lstate.heap); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalTopN::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - // global finalize: compute the final top N - gstate.heap.Finalize(); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class TopNOperatorState : public GlobalSourceState { -public: - TopNScanState state; - bool initialized = false; -}; - -unique_ptr PhysicalTopN::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -SourceResultType PhysicalTopN::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - if (limit == 0) { - return SourceResultType::FINISHED; - } - auto &state = input.global_state.Cast(); - auto &gstate = sink_state->Cast(); - - if (!state.initialized) { - gstate.heap.InitializeScan(state.state, true); - 
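// Sketch of TopNHeap::InitializeScan above: repeatedly popping a max-heap yields keys
// from worst to best, so writing each pop result into the output array from back to
// front produces ascending order without a separate sort pass.
#include <algorithm>
#include <string>
#include <vector>

std::vector<std::string> SortedOrder(std::vector<std::string> heap /* a valid max-heap */) {
	std::vector<std::string> ordered(heap.size());
	while (!heap.empty()) {
		std::pop_heap(heap.begin(), heap.end()); // moves the current max to the back
		ordered[heap.size() - 1] = std::move(heap.back());
		heap.pop_back();
	}
	return ordered;
}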
state.initialized = true; - } - gstate.heap.Scan(state.state, chunk); - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -InsertionOrderPreservingMap PhysicalTopN::ParamsToString() const { - InsertionOrderPreservingMap result; - result["Top"] = to_string(limit); - if (offset > 0) { - result["Offset"] = to_string(offset); - } - - string orders_info; - for (idx_t i = 0; i < orders.size(); i++) { - if (i > 0) { - orders_info += "\n"; - } - orders_info += orders[i].expression->ToString() + " "; - orders_info += orders[i].type == OrderType::DESCENDING ? "DESC" : "ASC"; - } - result["Order By"] = orders_info; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp b/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp deleted file mode 100644 index 0674181f2..000000000 --- a/src/duckdb/src/execution/operator/persistent/csv_rejects_table.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "duckdb/main/appender.hpp" -#include "duckdb/parser/parsed_data/create_table_info.hpp" -#include "duckdb/function/table/read_csv.hpp" -#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" - -namespace duckdb { - -TableCatalogEntry &CSVRejectsTable::GetErrorsTable(ClientContext &context) { - auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, errors_table); - return table_entry; -} - -TableCatalogEntry &CSVRejectsTable::GetScansTable(ClientContext &context) { - auto &temp_catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto &table_entry = temp_catalog.GetEntry(context, TEMP_CATALOG, DEFAULT_SCHEMA, scan_table); - return table_entry; -} - -idx_t CSVRejectsTable::GetCurrentFileIndex(idx_t query_id) { - if (current_query_id != query_id) { - current_query_id = query_id; - current_file_idx = 0; - } - return current_file_idx++; -} - -shared_ptr CSVRejectsTable::GetOrCreate(ClientContext &context, const string &rejects_scan, - const string &rejects_error) { - // Check that these names can't be the same - if (rejects_scan == rejects_error) { - throw BinderException("The names of the rejects scan and rejects error tables can't be the same. Use different " - "names for these tables."); - } - auto key = - "CSV_REJECTS_TABLE_CACHE_ENTRY_" + StringUtil::Upper(rejects_scan) + "_" + StringUtil::Upper(rejects_error); - auto &cache = ObjectCache::GetObjectCache(context); - auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - auto rejects_scan_exist = catalog.GetEntry(context, CatalogType::TABLE_ENTRY, DEFAULT_SCHEMA, rejects_scan, - OnEntryNotFound::RETURN_NULL) != nullptr; - auto rejects_error_exist = catalog.GetEntry(context, CatalogType::TABLE_ENTRY, DEFAULT_SCHEMA, rejects_error, - OnEntryNotFound::RETURN_NULL) != nullptr; - if ((rejects_scan_exist || rejects_error_exist) && !cache.Get(key)) { - std::ostringstream error; - if (rejects_scan_exist) { - error << "Reject Scan Table name \"" << rejects_scan << "\" is already in use. "; - } - if (rejects_error_exist) { - error << "Reject Error Table name \"" << rejects_error << "\" is already in use. 
"; - } - error << "Either drop the used name(s), or give other name options in the CSV Reader function.\n"; - throw BinderException(error.str()); - } - - return cache.GetOrCreate(key, rejects_scan, rejects_error); -} - -void CSVRejectsTable::InitializeTable(ClientContext &context, const ReadCSVData &data) { - // (Re)Create the temporary rejects table - auto &catalog = Catalog::GetCatalog(context, TEMP_CATALOG); - - // Create CSV_ERROR_TYPE ENUM - string enum_name = "CSV_ERROR_TYPE"; - constexpr uint8_t number_of_accepted_errors = 7; - Vector order_errors(LogicalType::VARCHAR, number_of_accepted_errors); - order_errors.SetValue(0, "CAST"); - order_errors.SetValue(1, "MISSING COLUMNS"); - order_errors.SetValue(2, "TOO MANY COLUMNS"); - order_errors.SetValue(3, "UNQUOTED VALUE"); - order_errors.SetValue(4, "LINE SIZE OVER MAXIMUM"); - order_errors.SetValue(5, "INVALID UNICODE"); - order_errors.SetValue(6, "INVALID STATE"); - - LogicalType enum_type = LogicalType::ENUM(enum_name, order_errors, number_of_accepted_errors); - auto type_info = make_uniq(enum_name, enum_type); - type_info->temporary = true; - type_info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - catalog.CreateType(context, *type_info); - - // Create Rejects Scans Table - { - auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, scan_table); - info->temporary = true; - info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - // 0. Scan ID - info->columns.AddColumn(ColumnDefinition("scan_id", LogicalType::UBIGINT)); - // 1. File ID (within the scan) - info->columns.AddColumn(ColumnDefinition("file_id", LogicalType::UBIGINT)); - // 2. File Path - info->columns.AddColumn(ColumnDefinition("file_path", LogicalType::VARCHAR)); - // 3. Delimiter - info->columns.AddColumn(ColumnDefinition("delimiter", LogicalType::VARCHAR)); - // 4. Quote - info->columns.AddColumn(ColumnDefinition("quote", LogicalType::VARCHAR)); - // 5. Escape - info->columns.AddColumn(ColumnDefinition("escape", LogicalType::VARCHAR)); - // 6. NewLine Delimiter - info->columns.AddColumn(ColumnDefinition("newline_delimiter", LogicalType::VARCHAR)); - // 7. Skip Rows - info->columns.AddColumn(ColumnDefinition("skip_rows", LogicalType::UINTEGER)); - // 8. Has Header - info->columns.AddColumn(ColumnDefinition("has_header", LogicalType::BOOLEAN)); - // 9. List> - info->columns.AddColumn(ColumnDefinition("columns", LogicalType::VARCHAR)); - // 10. Date Format - info->columns.AddColumn(ColumnDefinition("date_format", LogicalType::VARCHAR)); - // 11. Timestamp Format - info->columns.AddColumn(ColumnDefinition("timestamp_format", LogicalType::VARCHAR)); - // 12. CSV read function with all the options used - info->columns.AddColumn(ColumnDefinition("user_arguments", LogicalType::VARCHAR)); - catalog.CreateTable(context, std::move(info)); - } - { - // Create Rejects Error Table - auto info = make_uniq(TEMP_CATALOG, DEFAULT_SCHEMA, errors_table); - info->temporary = true; - info->on_conflict = OnCreateConflict::IGNORE_ON_CONFLICT; - // 0. Scan ID - info->columns.AddColumn(ColumnDefinition("scan_id", LogicalType::UBIGINT)); - // 1. File ID (within the scan) - info->columns.AddColumn(ColumnDefinition("file_id", LogicalType::UBIGINT)); - // 2. Row Line - info->columns.AddColumn(ColumnDefinition("line", LogicalType::UBIGINT)); - // 3. Byte Position of the start of the line - info->columns.AddColumn(ColumnDefinition("line_byte_position", LogicalType::UBIGINT)); - // 4. 
Byte Position where error occurred - info->columns.AddColumn(ColumnDefinition("byte_position", LogicalType::UBIGINT)); - // 5. Column Index (If Applicable) - info->columns.AddColumn(ColumnDefinition("column_idx", LogicalType::UBIGINT)); - // 6. Column Name (If Applicable) - info->columns.AddColumn(ColumnDefinition("column_name", LogicalType::VARCHAR)); - // 7. Error Type - info->columns.AddColumn(ColumnDefinition("error_type", enum_type)); - // 8. Original CSV Line - info->columns.AddColumn(ColumnDefinition("csv_line", LogicalType::VARCHAR)); - // 9. Full Error Message - info->columns.AddColumn(ColumnDefinition("error_message", LogicalType::VARCHAR)); - catalog.CreateTable(context, std::move(info)); - } - - count = 0; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp deleted file mode 100644 index 4effccaff..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_copy_to_file.cpp +++ /dev/null @@ -1,629 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" - -#include "duckdb/common/allocator.hpp" -#include "duckdb/common/queue.hpp" -#include "duckdb/common/types/batched_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/operator/persistent/batch_memory_manager.hpp" -#include "duckdb/execution/operator/persistent/batch_task_manager.hpp" -#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" -#include "duckdb/parallel/base_pipeline_event.hpp" -#include "duckdb/parallel/executor_task.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -#include - -namespace duckdb { - -struct ActiveFlushGuard { - explicit ActiveFlushGuard(atomic &bool_value_p) : bool_value(bool_value_p) { - bool_value = true; - } - ~ActiveFlushGuard() { - bool_value = false; - } - - atomic &bool_value; -}; - -PhysicalBatchCopyToFile::PhysicalBatchCopyToFile(vector types, CopyFunction function_p, - unique_ptr bind_data_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_COPY_TO_FILE, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)) { - if (!function.flush_batch || !function.prepare_batch) { - throw InternalException("PhysicalFixedBatchCopy created for copy function that does not have " - "prepare_batch/flush_batch defined"); - } -} - -//===--------------------------------------------------------------------===// -// States -//===--------------------------------------------------------------------===// -class BatchCopyTask { -public: - virtual ~BatchCopyTask() { - } - - virtual void Execute(const PhysicalBatchCopyToFile &op, ClientContext &context, GlobalSinkState &gstate_p) = 0; -}; - -struct FixedRawBatchData { - FixedRawBatchData(idx_t memory_usage_p, unique_ptr collection_p) - : memory_usage(memory_usage_p), collection(std::move(collection_p)) { - } - - idx_t memory_usage; - unique_ptr collection; -}; - -struct FixedPreparedBatchData { - idx_t memory_usage; - unique_ptr prepared_data; -}; - -class FixedBatchCopyGlobalState : public GlobalSinkState { -public: - // heuristic - we need at least 4MB of cache space per column per thread we launch - static constexpr const idx_t MINIMUM_MEMORY_PER_COLUMN_PER_THREAD = 4ULL * 1024ULL * 1024ULL; - -public: - explicit FixedBatchCopyGlobalState(ClientContext &context_p, unique_ptr global_state, - idx_t 
minimum_memory_per_thread) - : memory_manager(context_p, minimum_memory_per_thread), rows_copied(0), global_state(std::move(global_state)), - batch_size(0), scheduled_batch_index(0), flushed_batch_index(0), any_flushing(false), any_finished(false), - minimum_memory_per_thread(minimum_memory_per_thread) { - } - - BatchMemoryManager memory_manager; - BatchTaskManager task_manager; - mutex lock; - mutex flush_lock; - //! The total number of rows copied to the file - atomic rows_copied; - //! Global copy state - unique_ptr global_state; - //! The desired batch size (if any) - idx_t batch_size; - //! Unpartitioned batches - map> raw_batches; - //! The prepared batch data by batch index - ready to flush - map> batch_data; - //! The index of the latest batch index that has been scheduled - atomic scheduled_batch_index; - //! The index of the latest batch index that has been flushed - atomic flushed_batch_index; - //! Whether or not any thread is flushing - atomic any_flushing; - //! Whether or not any threads are finished - atomic any_finished; - //! Minimum memory per thread - idx_t minimum_memory_per_thread; - - void AddBatchData(idx_t batch_index, unique_ptr new_batch, idx_t memory_usage) { - // move the batch data to the set of prepared batch data - lock_guard l(lock); - auto prepared_data = make_uniq(); - prepared_data->prepared_data = std::move(new_batch); - prepared_data->memory_usage = memory_usage; - auto entry = batch_data.insert(make_pair(batch_index, std::move(prepared_data))); - if (!entry.second) { - throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); - } - } - - idx_t MaxThreads(idx_t source_max_threads) override { - // try to request 4MB per column per thread - memory_manager.SetMemorySize(source_max_threads * minimum_memory_per_thread); - // cap the concurrent threads working on this task based on the amount of available memory - return MinValue(source_max_threads, memory_manager.AvailableMemory() / minimum_memory_per_thread + 1); - } -}; - -enum class FixedBatchCopyState : uint8_t { SINKING_DATA = 1, PROCESSING_TASKS = 2 }; - -class FixedBatchCopyLocalState : public LocalSinkState { -public: - explicit FixedBatchCopyLocalState(unique_ptr local_state_p) - : local_state(std::move(local_state_p)), rows_copied(0), local_memory_usage(0) { - } - - //! Local copy state - unique_ptr local_state; - //! The current collection we are appending to - unique_ptr collection; - //! The append state of the collection - ColumnDataAppendState append_state; - //! How many rows have been copied in total - idx_t rows_copied; - //! Memory usage of the thread-local collection - idx_t local_memory_usage; - //! The current batch index - optional_idx batch_index; - //! 
Current task - FixedBatchCopyState current_task = FixedBatchCopyState::SINKING_DATA; - - void InitializeCollection(ClientContext &context, const PhysicalOperator &op) { - collection = make_uniq(BufferAllocator::Get(context), op.children[0]->types); - collection->InitializeAppend(append_state); - local_memory_usage = 0; - } -}; - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalBatchCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - auto &state = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - auto &memory_manager = gstate.memory_manager; - auto batch_index = state.partition_info.batch_index.GetIndex(); - if (state.current_task == FixedBatchCopyState::PROCESSING_TASKS) { - ExecuteTasks(context.client, gstate); - FlushBatchData(context.client, gstate); - - if (!memory_manager.IsMinimumBatchIndex(batch_index) && memory_manager.OutOfMemory(batch_index)) { - auto guard = memory_manager.Lock(); - if (!memory_manager.IsMinimumBatchIndex(batch_index)) { - // no tasks to process, we are not the minimum batch index and we have no memory available to buffer - // block the task for now - return memory_manager.BlockSink(guard, input.interrupt_state); - } - } - state.current_task = FixedBatchCopyState::SINKING_DATA; - } - if (!memory_manager.IsMinimumBatchIndex(batch_index)) { - memory_manager.UpdateMinBatchIndex(state.partition_info.min_batch_index.GetIndex()); - - // we are not processing the current min batch index - // check if we have exceeded the maximum number of unflushed rows - if (memory_manager.OutOfMemory(batch_index)) { - // out-of-memory - stop sinking chunks and instead assist in processing tasks for the minimum batch index - state.current_task = FixedBatchCopyState::PROCESSING_TASKS; - return Sink(context, chunk, input); - } - } - if (!state.collection) { - state.InitializeCollection(context.client, *this); - state.batch_index = batch_index; - } - state.rows_copied += chunk.size(); - state.collection->Append(state.append_state, chunk); - auto new_memory_usage = state.collection->AllocationSize(); - if (new_memory_usage > state.local_memory_usage) { - // memory usage increased - add to global state - memory_manager.IncreaseUnflushedMemory(new_memory_usage - state.local_memory_usage); - state.local_memory_usage = new_memory_usage; - } else if (new_memory_usage < state.local_memory_usage) { - throw InternalException("PhysicalFixedBatchCopy - memory usage decreased somehow?"); - } - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalBatchCopyToFile::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - auto &state = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - auto &memory_manager = gstate.memory_manager; - gstate.rows_copied += state.rows_copied; - - // add any final remaining local batches - AddLocalBatch(context.client, gstate, state); - - if (!gstate.any_finished) { - // signal that this thread is finished processing batches and that we should move on to Finalize - lock_guard l(gstate.lock); - gstate.any_finished = true; - } - memory_manager.UpdateMinBatchIndex(state.partition_info.min_batch_index.GetIndex()); - ExecuteTasks(context.client, gstate); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// 
ProcessRemainingBatchesEvent -//===--------------------------------------------------------------------===// -class ProcessRemainingBatchesTask : public ExecutorTask { -public: - ProcessRemainingBatchesTask(Executor &executor, shared_ptr event_p, FixedBatchCopyGlobalState &state_p, - ClientContext &context, const PhysicalBatchCopyToFile &op) - : ExecutorTask(executor, std::move(event_p)), op(op), gstate(state_p), context(context) { - } - - TaskExecutionResult ExecuteTask(TaskExecutionMode mode) override { - while (op.ExecuteTask(context, gstate)) { - op.FlushBatchData(context, gstate); - } - event->FinishTask(); - return TaskExecutionResult::TASK_FINISHED; - } - -private: - const PhysicalBatchCopyToFile &op; - FixedBatchCopyGlobalState &gstate; - ClientContext &context; -}; - -class ProcessRemainingBatchesEvent : public BasePipelineEvent { -public: - ProcessRemainingBatchesEvent(const PhysicalBatchCopyToFile &op_p, FixedBatchCopyGlobalState &gstate_p, - Pipeline &pipeline_p, ClientContext &context) - : BasePipelineEvent(pipeline_p), op(op_p), gstate(gstate_p), context(context) { - } - const PhysicalBatchCopyToFile &op; - FixedBatchCopyGlobalState &gstate; - ClientContext &context; - -public: - void Schedule() override { - vector> tasks; - for (idx_t i = 0; i < idx_t(TaskScheduler::GetScheduler(context).NumberOfThreads()); i++) { - auto process_task = - make_uniq(pipeline->executor, shared_from_this(), gstate, context, op); - tasks.push_back(std::move(process_task)); - } - D_ASSERT(!tasks.empty()); - SetTasks(std::move(tasks)); - } - - void FinishEvent() override { - //! Now that all batches are processed we finish flushing the file to disk - op.FinalFlush(context, gstate); - } -}; -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalBatchCopyToFile::FinalFlush(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - if (gstate.task_manager.TaskCount() != 0) { - throw InternalException("Unexecuted tasks are remaining in PhysicalFixedBatchCopy::FinalFlush!?"); - } - - FlushBatchData(context, gstate_p); - if (gstate.scheduled_batch_index != gstate.flushed_batch_index) { - throw InternalException("Not all batches were flushed to disk - incomplete file?"); - } - if (function.copy_to_finalize) { - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - - if (use_tmp_file) { - PhysicalCopyToFile::MoveTmpFile(context, file_path); - } - } - gstate.memory_manager.FinalCheck(); - return SinkFinalizeType::READY; -} - -SinkFinalizeType PhysicalBatchCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto min_batch_index = idx_t(NumericLimits::Maximum()); - // repartition any remaining batches - RepartitionBatches(context, input.global_state, min_batch_index, true); - // check if we have multiple tasks to execute - if (gstate.task_manager.TaskCount() <= 1) { - // we don't - just execute the remaining task and finish flushing to disk - ExecuteTasks(context, input.global_state); - FinalFlush(context, input.global_state); - } else { - // we have multiple tasks remaining - launch an event to execute the tasks in parallel - auto new_event = make_shared_ptr(*this, gstate, pipeline, context); - event.InsertEvent(std::move(new_event)); - } - return SinkFinalizeType::READY; -} - 
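// Sketch of the MaxThreads heuristic used by FixedBatchCopyGlobalState above: the
// number of concurrent sink threads is capped by how many minimum-size memory slices
// fit in the available budget, so the operator never launches more threads than memory
// can feed. The arithmetic matches MinValue(source_max_threads, memory / per_thread + 1).
#include <algorithm>
#include <cstdint>

uint64_t CapThreads(uint64_t source_max_threads, uint64_t available_memory, uint64_t min_memory_per_thread) {
	return std::min(source_max_threads, available_memory / min_memory_per_thread + 1);
}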
-//===--------------------------------------------------------------------===// -// Tasks -//===--------------------------------------------------------------------===// -class RepartitionedFlushTask : public BatchCopyTask { -public: - RepartitionedFlushTask() { - } - - void Execute(const PhysicalBatchCopyToFile &op, ClientContext &context, GlobalSinkState &gstate_p) override { - op.FlushBatchData(context, gstate_p); - } -}; - -class PrepareBatchTask : public BatchCopyTask { -public: - PrepareBatchTask(idx_t batch_index, unique_ptr batch_data_p) - : batch_index(batch_index), batch_data(std::move(batch_data_p)) { - } - - idx_t batch_index; - unique_ptr batch_data; - - void Execute(const PhysicalBatchCopyToFile &op, ClientContext &context, GlobalSinkState &gstate_p) override { - auto &gstate = gstate_p.Cast(); - auto memory_usage = batch_data->memory_usage; - auto prepared_batch = - op.function.prepare_batch(context, *op.bind_data, *gstate.global_state, std::move(batch_data->collection)); - gstate.AddBatchData(batch_index, std::move(prepared_batch), memory_usage); - if (batch_index == gstate.flushed_batch_index) { - gstate.task_manager.AddTask(make_uniq()); - } - } -}; - -//===--------------------------------------------------------------------===// -// Batch Data Handling -//===--------------------------------------------------------------------===// -void PhysicalBatchCopyToFile::AddRawBatchData(ClientContext &context, GlobalSinkState &gstate_p, idx_t batch_index, - unique_ptr raw_batch) const { - auto &gstate = gstate_p.Cast(); - - // add the batch index to the set of raw batches - lock_guard l(gstate.lock); - auto entry = gstate.raw_batches.insert(make_pair(batch_index, std::move(raw_batch))); - if (!entry.second) { - throw InternalException("Duplicate batch index %llu encountered in PhysicalFixedBatchCopy", batch_index); - } -} - -static bool CorrectSizeForBatch(idx_t collection_size, idx_t desired_size) { - if (desired_size == 0) { - // a batch size of 0 indicates we are happy with any batch size - return true; - } - return idx_t(AbsValue(int64_t(collection_size) - int64_t(desired_size))) < STANDARD_VECTOR_SIZE; -} - -void PhysicalBatchCopyToFile::RepartitionBatches(ClientContext &context, GlobalSinkState &gstate_p, idx_t min_index, - bool final) const { - auto &gstate = gstate_p.Cast(); - auto &task_manager = gstate.task_manager; - - // repartition batches until the min index is reached - lock_guard l(gstate.lock); - if (gstate.raw_batches.empty()) { - return; - } - if (!final) { - if (gstate.any_finished) { - // we only repartition in ::NextBatch if all threads are still busy processing batches - // otherwise we might end up repartitioning a lot of data with only a few threads remaining - // which causes erratic performance - return; - } - // if this is not the final flush we first check if we have enough data to merge past the batch threshold - idx_t candidate_rows = 0; - for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end(); entry++) { - if (entry->first >= min_index) { - // we have exceeded the minimum batch - break; - } - candidate_rows += entry->second->collection->Count(); - } - if (candidate_rows < gstate.batch_size) { - // not enough rows - cancel! 
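// Standalone sketch of RepartitionBatches above: a stream of variably sized collections
// is re-cut into batches of roughly `batch_size` rows, with any undersized remainder
// carried forward (or emitted on the final flush). Rows are ints here; DuckDB operates
// on ColumnDataCollections and tracks their memory usage as it goes.
#include <cstddef>
#include <vector>

std::vector<std::vector<int>> Repartition(std::vector<std::vector<int>> inputs, std::size_t batch_size, bool final) {
	std::vector<std::vector<int>> out;
	std::vector<int> pending;
	for (auto &input : inputs) {
		for (int row : input) {
			pending.push_back(row);
			if (pending.size() >= batch_size) {
				out.push_back(std::move(pending));
				pending.clear();
			}
		}
	}
	if (!pending.empty() && final) {
		out.push_back(std::move(pending)); // the final flush emits the remainder
	}
	// when !final, the remainder would instead be re-queued, as with gstate.raw_batches
	return out;
}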
- return; - } - } - // gather all collections we can repartition - idx_t max_batch_index = 0; - vector> raw_batches; - for (auto entry = gstate.raw_batches.begin(); entry != gstate.raw_batches.end();) { - if (entry->first >= min_index) { - break; - } - max_batch_index = entry->first; - raw_batches.push_back(std::move(entry->second)); - entry = gstate.raw_batches.erase(entry); - } - unique_ptr append_batch; - ColumnDataAppendState append_state; - // now perform the actual repartitioning - for (auto ¤t_batch : raw_batches) { - if (!append_batch) { - auto current_count = current_batch->collection->Count(); - if (CorrectSizeForBatch(current_count, gstate.batch_size)) { - // the collection is ~approximately equal to the batch size (off by at most one vector) - // use it directly - task_manager.AddTask( - make_uniq(gstate.scheduled_batch_index++, std::move(current_batch))); - current_batch.reset(); - } else if (current_count < gstate.batch_size) { - // the collection is smaller than the batch size - use it as a starting point - append_batch = std::move(current_batch); - current_batch.reset(); - } else { - // the collection is too large for a batch - we need to repartition - // create an empty collection - auto new_collection = - make_uniq(BufferAllocator::Get(context), children[0]->types); - append_batch = make_uniq(0U, std::move(new_collection)); - } - if (append_batch) { - append_batch->collection->InitializeAppend(append_state); - } - } - if (!current_batch) { - // we have consumed the collection already - no need to append - continue; - } - auto ¤t_collection = *current_batch->collection; - append_batch->memory_usage += current_batch->memory_usage; - // iterate the collection while appending - for (auto &chunk : current_collection.Chunks()) { - // append the chunk to the collection - append_batch->collection->Append(append_state, chunk); - if (append_batch->collection->Count() < gstate.batch_size) { - // the collection is still under the batch size - continue - continue; - } - // the collection is full - move it to the result and create a new one - task_manager.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(append_batch))); - - auto new_collection = make_uniq(BufferAllocator::Get(context), children[0]->types); - append_batch = make_uniq(0U, std::move(new_collection)); - append_batch->collection->InitializeAppend(append_state); - } - } - if (append_batch && append_batch->collection->Count() > 0) { - // if there are any remaining batches that are not filled up to the batch size - // AND this is not the final collection - // re-add it to the set of raw (to-be-merged) batches - if (final || CorrectSizeForBatch(append_batch->collection->Count(), gstate.batch_size)) { - task_manager.AddTask(make_uniq(gstate.scheduled_batch_index++, std::move(append_batch))); - } else { - gstate.raw_batches[max_batch_index] = std::move(append_batch); - } - } -} - -void PhysicalBatchCopyToFile::FlushBatchData(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - auto &memory_manager = gstate.memory_manager; - - // flush batch data to disk (if there are any to flush) - // grab the flush lock - we can only call flush_batch with this lock - // otherwise the data might end up in the wrong order - { - lock_guard l(gstate.flush_lock); - if (gstate.any_flushing) { - return; - } - gstate.any_flushing = true; - } - ActiveFlushGuard active_flush(gstate.any_flushing); - while (true) { - unique_ptr batch_data; - { - lock_guard l(gstate.lock); - if (gstate.batch_data.empty()) 
{ - // no batch data left to flush - break; - } - auto entry = gstate.batch_data.begin(); - if (entry->first != gstate.flushed_batch_index) { - // this entry is not yet ready to be flushed - break; - } - if (entry->first < gstate.flushed_batch_index) { - throw InternalException("Batch index was out of order!?"); - } - batch_data = std::move(entry->second); - gstate.batch_data.erase(entry); - } - function.flush_batch(context, *bind_data, *gstate.global_state, *batch_data->prepared_data); - batch_data->prepared_data.reset(); - memory_manager.ReduceUnflushedMemory(batch_data->memory_usage); - gstate.flushed_batch_index++; - } -} - -//===--------------------------------------------------------------------===// -// Tasks -//===--------------------------------------------------------------------===// -bool PhysicalBatchCopyToFile::ExecuteTask(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - auto task = gstate.task_manager.GetTask(); - if (!task) { - return false; - } - task->Execute(*this, context, gstate_p); - return true; -} - -void PhysicalBatchCopyToFile::ExecuteTasks(ClientContext &context, GlobalSinkState &gstate_p) const { - while (ExecuteTask(context, gstate_p)) { - } -} - -//===--------------------------------------------------------------------===// -// Next Batch -//===--------------------------------------------------------------------===// -void PhysicalBatchCopyToFile::AddLocalBatch(ClientContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate) const { - auto &state = lstate.Cast(); - auto &gstate = gstate_p.Cast(); - auto &memory_manager = gstate.memory_manager; - if (!state.collection || state.collection->Count() == 0) { - return; - } - // we finished processing this batch - // start flushing data - auto min_batch_index = state.partition_info.min_batch_index.GetIndex(); - // push the raw batch data into the set of unprocessed batches - auto raw_batch = make_uniq(state.local_memory_usage, std::move(state.collection)); - AddRawBatchData(context, gstate, state.batch_index.GetIndex(), std::move(raw_batch)); - // attempt to repartition to our desired batch size - RepartitionBatches(context, gstate, min_batch_index); - // unblock tasks so they can help process batches (if any are blocked) - bool any_unblocked; - { - auto guard = memory_manager.Lock(); - any_unblocked = memory_manager.UnblockTasks(guard); - } - // if any threads were unblocked they can pick up execution of the tasks - // otherwise we will execute a task and flush here - if (!any_unblocked) { - //! Execute a single repartition task - ExecuteTask(context, gstate); - //! 
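// Sketch of the ordering protocol in FlushBatchData above: prepared batches arrive out
// of order, are parked in an ordered map, and are written only while the smallest
// parked index equals the next index expected on disk. The payload type is a
// placeholder for DuckDB's prepared batch data.
#include <cstdint>
#include <map>
#include <string>
#include <vector>

struct OrderedFlusher {
	uint64_t next_to_flush = 0;
	std::map<uint64_t, std::string> parked;
	std::vector<std::string> file; // stands in for the output stream

	void Park(uint64_t batch_index, std::string payload) {
		parked.emplace(batch_index, std::move(payload));
		// drain the head of the map while it is contiguous with what is on disk
		while (!parked.empty() && parked.begin()->first == next_to_flush) {
			file.push_back(std::move(parked.begin()->second));
			parked.erase(parked.begin());
			next_to_flush++;
		}
	}
};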
Flush batch data to disk (if any is ready) - FlushBatchData(context, gstate); - } -} - -SinkNextBatchType PhysicalBatchCopyToFile::NextBatch(ExecutionContext &context, - OperatorSinkNextBatchInput &input) const { - auto &lstate = input.local_state; - auto &state = lstate.Cast(); - auto &gstate = input.global_state.Cast(); - auto &memory_manager = gstate.memory_manager; - - // add the previously finished batch (if any) to the state - AddLocalBatch(context.client, gstate, state); - - // update the minimum batch index - memory_manager.UpdateMinBatchIndex(state.partition_info.min_batch_index.GetIndex()); - state.batch_index = lstate.partition_info.batch_index.GetIndex(); - - state.InitializeCollection(context.client, *this); - return SinkNextBatchType::READY; -} - -unique_ptr PhysicalBatchCopyToFile::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(function.copy_to_initialize_local(context, *bind_data)); -} - -unique_ptr PhysicalBatchCopyToFile::GetGlobalSinkState(ClientContext &context) const { - // request memory based on the minimum amount of memory per column - auto minimum_memory_per_thread = - FixedBatchCopyGlobalState::MINIMUM_MEMORY_PER_COLUMN_PER_THREAD * children[0]->types.size(); - auto result = make_uniq( - context, function.copy_to_initialize_global(context, *bind_data, file_path), minimum_memory_per_thread); - result->batch_size = function.desired_batch_size ? function.desired_batch_size(context, *bind_data) : 0; - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalBatchCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &g = sink_state->Cast(); - - chunk.SetCardinality(1); - switch (return_type) { - case CopyFunctionReturnType::CHANGED_ROWS: - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); - break; - case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: { - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); - auto fp = use_tmp_file ? 
PhysicalCopyToFile::GetNonTmpFile(context.client, file_path) : file_path; - chunk.SetValue(1, 0, Value::LIST(LogicalType::VARCHAR, {fp})); - break; - } - default: - throw NotImplementedException("Unknown CopyFunctionReturnType"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp deleted file mode 100644 index 2e546c477..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_batch_insert.cpp +++ /dev/null @@ -1,656 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" - -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/execution/operator/persistent/batch_memory_manager.hpp" -#include "duckdb/execution/operator/persistent/batch_task_manager.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table/append_state.hpp" -#include "duckdb/storage/table/row_group_collection.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/storage/table_io_manager.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/transaction/local_storage.hpp" - -namespace duckdb { - -PhysicalBatchInsert::PhysicalBatchInsert(vector types_p, TableCatalogEntry &table, - physical_index_vector_t column_index_map_p, - vector> bound_defaults_p, - vector> bound_constraints_p, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_INSERT, std::move(types_p), estimated_cardinality), - column_index_map(std::move(column_index_map_p)), insert_table(&table), insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults_p)), bound_constraints(std::move(bound_constraints_p)) { -} - -PhysicalBatchInsert::PhysicalBatchInsert(LogicalOperator &op, SchemaCatalogEntry &schema, - unique_ptr info_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::BATCH_CREATE_TABLE_AS, op.types, estimated_cardinality), - insert_table(nullptr), schema(&schema), info(std::move(info_p)) { - PhysicalInsert::GetInsertInfo(*info, insert_types, bound_defaults); -} - -//===--------------------------------------------------------------------===// -// CollectionMerger -//===--------------------------------------------------------------------===// -enum class RowGroupBatchType : uint8_t { FLUSHED, NOT_FLUSHED }; - -class CollectionMerger { -public: - explicit CollectionMerger(ClientContext &context) : context(context) { - } - - ClientContext &context; - vector> current_collections; - RowGroupBatchType batch_type = RowGroupBatchType::NOT_FLUSHED; - -public: - void AddCollection(unique_ptr collection, RowGroupBatchType type) { - current_collections.push_back(std::move(collection)); - if (type == RowGroupBatchType::FLUSHED) { - batch_type = RowGroupBatchType::FLUSHED; - if (current_collections.size() > 1) { - throw InternalException("Cannot merge flushed collections"); - } - } - } - - bool Empty() { - return current_collections.empty(); - } - - unique_ptr Flush(OptimisticDataWriter &writer) { - if (Empty()) { - return nullptr; - } - unique_ptr new_collection = std::move(current_collections[0]); - if (current_collections.size() > 1) { - // we have gathered multiple collections: create one big collection and merge that - auto &types = new_collection->GetTypes(); - TableAppendState append_state; - new_collection->InitializeAppend(append_state); - - DataChunk 
scan_chunk; - scan_chunk.Initialize(context, types); - - vector column_ids; - for (idx_t i = 0; i < types.size(); i++) { - column_ids.emplace_back(i); - } - for (auto &collection : current_collections) { - if (!collection) { - continue; - } - TableScanState scan_state; - scan_state.Initialize(column_ids); - collection->InitializeScan(scan_state.local_state, column_ids, nullptr); - - while (true) { - scan_chunk.Reset(); - scan_state.local_state.ScanCommitted(scan_chunk, TableScanType::TABLE_SCAN_COMMITTED_ROWS); - if (scan_chunk.size() == 0) { - break; - } - auto new_row_group = new_collection->Append(scan_chunk, append_state); - if (new_row_group) { - writer.WriteNewRowGroup(*new_collection); - } - } - } - new_collection->FinalizeAppend(TransactionData(0, 0), append_state); - writer.WriteLastRowGroup(*new_collection); - } else if (batch_type == RowGroupBatchType::NOT_FLUSHED) { - writer.WriteLastRowGroup(*new_collection); - } - current_collections.clear(); - return new_collection; - } -}; - -struct RowGroupBatchEntry { - RowGroupBatchEntry(idx_t batch_idx, unique_ptr collection_p, RowGroupBatchType type) - : batch_idx(batch_idx), total_rows(collection_p->GetTotalRows()), unflushed_memory(0), - collection(std::move(collection_p)), type(type) { - if (type == RowGroupBatchType::NOT_FLUSHED) { - unflushed_memory = collection->GetAllocationSize(); - } - } - - idx_t batch_idx; - idx_t total_rows; - idx_t unflushed_memory; - unique_ptr collection; - RowGroupBatchType type; -}; - -//===--------------------------------------------------------------------===// -// States -//===--------------------------------------------------------------------===// -class BatchInsertTask { -public: - virtual ~BatchInsertTask() { - } - - virtual void Execute(const PhysicalBatchInsert &op, ClientContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) = 0; -}; - -class BatchInsertGlobalState : public GlobalSinkState { -public: - explicit BatchInsertGlobalState(ClientContext &context, DuckTableEntry &table, idx_t minimum_memory_per_thread) - : memory_manager(context, minimum_memory_per_thread), table(table), insert_count(0), - optimistically_written(false), minimum_memory_per_thread(minimum_memory_per_thread) { - row_group_size = table.GetStorage().GetRowGroupSize(); - } - - BatchMemoryManager memory_manager; - BatchTaskManager task_manager; - mutex lock; - DuckTableEntry &table; - idx_t row_group_size; - idx_t insert_count; - vector collections; - idx_t next_start = 0; - atomic optimistically_written; - idx_t minimum_memory_per_thread; - - bool ReadyToMerge(idx_t count) const; - void ScheduleMergeTasks(idx_t min_batch_index); - unique_ptr MergeCollections(ClientContext &context, - vector merge_collections, - OptimisticDataWriter &writer); - void AddCollection(ClientContext &context, idx_t batch_index, idx_t min_batch_index, - unique_ptr current_collection, - optional_ptr writer = nullptr); - - idx_t MaxThreads(idx_t source_max_threads) override { - // try to request 4MB per column per thread - memory_manager.SetMemorySize(source_max_threads * minimum_memory_per_thread); - // cap the concurrent threads working on this task based on the amount of available memory - return MinValue(source_max_threads, memory_manager.AvailableMemory() / minimum_memory_per_thread + 1); - } -}; - -class BatchInsertLocalState : public LocalSinkState { -public: - BatchInsertLocalState(ClientContext &context, const vector &types, - const vector> &bound_defaults) - : default_executor(context, bound_defaults) { - 
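Aside: `MaxThreads` above turns a per-thread memory reservation into a cap on parallelism — the memory manager is sized for `source_max_threads * minimum_memory_per_thread`, and the thread count is clamped so each active thread can hold at least one reservation. The arithmetic in isolation (a sketch, not the operator's interface):

```cpp
#include <algorithm>
#include <cstdint>

// Cap the number of worker threads so that each one can reserve
// `min_memory_per_thread` bytes out of `available_memory`.
uint64_t CapThreadsByMemory(uint64_t source_max_threads, uint64_t available_memory,
                            uint64_t min_memory_per_thread) {
	// The +1 mirrors the operator above: always allow at least one thread,
	// even when the budget is smaller than a single reservation.
	return std::min(source_max_threads, available_memory / min_memory_per_thread + 1);
}
```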
insert_chunk.Initialize(Allocator::Get(context), types); - } - - DataChunk insert_chunk; - ExpressionExecutor default_executor; - idx_t current_index; - TableAppendState current_append_state; - unique_ptr current_collection; - optional_ptr writer; - unique_ptr constraint_state; - - void CreateNewCollection(DuckTableEntry &table, const vector &insert_types) { - auto table_info = table.GetStorage().GetDataTableInfo(); - auto &io_manager = TableIOManager::Get(table.GetStorage()); - current_collection = make_uniq(std::move(table_info), io_manager, insert_types, - NumericCast(MAX_ROW_ID)); - current_collection->InitializeEmpty(); - current_collection->InitializeAppend(current_append_state); - } -}; - -//===--------------------------------------------------------------------===// -// Merge Task -//===--------------------------------------------------------------------===// -class MergeCollectionTask : public BatchInsertTask { -public: - MergeCollectionTask(vector merge_collections_p, idx_t merged_batch_index) - : merge_collections(std::move(merge_collections_p)), merged_batch_index(merged_batch_index) { - } - - vector merge_collections; - idx_t merged_batch_index; - - void Execute(const PhysicalBatchInsert &op, ClientContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) override { - auto &gstate = gstate_p.Cast(); - auto &lstate = lstate_p.Cast(); - // merge together the collections - D_ASSERT(lstate.writer); - auto final_collection = gstate.MergeCollections(context, std::move(merge_collections), *lstate.writer); - // add the merged-together collection to the set of batch indexes - lock_guard l(gstate.lock); - RowGroupBatchEntry new_entry(merged_batch_index, std::move(final_collection), RowGroupBatchType::FLUSHED); - auto it = std::lower_bound( - gstate.collections.begin(), gstate.collections.end(), new_entry, - [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { return a.batch_idx < b.batch_idx; }); - if (it->batch_idx != merged_batch_index) { - throw InternalException("Merged batch index was no longer present in collection"); - } - it->collection = std::move(new_entry.collection); - } -}; - -struct BatchMergeTask { - explicit BatchMergeTask(idx_t start_index) : start_index(start_index), end_index(0), total_count(0) { - } - - idx_t start_index; - idx_t end_index; - idx_t total_count; -}; - -bool BatchInsertGlobalState::ReadyToMerge(idx_t count) const { - // we try to merge so the count fits nicely into row groups - if (count >= row_group_size / 10 * 9 && count <= row_group_size) { - // 90%-100% of row group size - return true; - } - if (count >= row_group_size / 10 * 18 && count <= row_group_size * 2) { - // 180%-200% of row group size - return true; - } - if (count >= row_group_size / 10 * 27 && count <= row_group_size * 3) { - // 270%-300% of row group size - return true; - } - if (count >= row_group_size / 10 * 36) { - // >360% of row group size - return true; - } - return false; -} - -void BatchInsertGlobalState::ScheduleMergeTasks(idx_t min_batch_index) { - idx_t current_idx; - - vector to_be_scheduled_tasks; - - BatchMergeTask current_task(next_start); - for (current_idx = current_task.start_index; current_idx < collections.size(); current_idx++) { - auto &entry = collections[current_idx]; - if (entry.batch_idx > min_batch_index) { - // this entry is AFTER the min_batch_index - // finished - if (ReadyToMerge(current_task.total_count)) { - current_task.end_index = current_idx; - to_be_scheduled_tasks.push_back(current_task); - } - break; - } - if (entry.type 
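Aside: `ReadyToMerge` above accepts a pending row count only when it lands within 10% of a whole multiple of the row-group size (90-100%, 180-200%, 270-300%) or reaches at least 360% of it, so merged collections produce full row groups rather than fragments. The same windows as a standalone predicate (a sketch under those assumptions):

```cpp
#include <cstdint>

// Accept counts that fill 1x, 2x or 3x a row group to within 10 percent,
// or anything of at least 3.6x - mirroring the heuristic above.
bool FitsRowGroups(uint64_t count, uint64_t row_group_size) {
	for (uint64_t multiple = 1; multiple <= 3; multiple++) {
		uint64_t lower = row_group_size / 10 * 9 * multiple; // 90% of N row groups
		uint64_t upper = row_group_size * multiple;          // 100% of N row groups
		if (count >= lower && count <= upper) {
			return true;
		}
	}
	return count >= row_group_size / 10 * 36; // >= 360% of a row group
}
```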
== RowGroupBatchType::FLUSHED) { - // already flushed: cannot flush anything here - if (current_task.total_count > 0) { - current_task.end_index = current_idx; - to_be_scheduled_tasks.push_back(current_task); - } - current_task.start_index = current_idx + 1; - if (current_task.start_index > next_start) { - // avoid checking this segment again in the future - next_start = current_task.start_index; - } - current_task.total_count = 0; - continue; - } - // not flushed - add to set of indexes to flush - current_task.total_count += entry.total_rows; - if (ReadyToMerge(current_task.total_count)) { - // create a task to merge these collections - current_task.end_index = current_idx + 1; - to_be_scheduled_tasks.push_back(current_task); - current_task.start_index = current_idx + 1; - current_task.total_count = 0; - } - } - - if (to_be_scheduled_tasks.empty()) { - return; - } - for (auto &scheduled_task : to_be_scheduled_tasks) { - D_ASSERT(scheduled_task.total_count > 0); - D_ASSERT(current_idx > scheduled_task.start_index); - idx_t merged_batch_index = collections[scheduled_task.start_index].batch_idx; - vector merge_collections; - for (idx_t idx = scheduled_task.start_index; idx < scheduled_task.end_index; idx++) { - auto &entry = collections[idx]; - if (!entry.collection || entry.type == RowGroupBatchType::FLUSHED) { - throw InternalException("Adding a row group collection that should not be flushed"); - } - RowGroupBatchEntry added_entry(collections[scheduled_task.start_index].batch_idx, - std::move(entry.collection), RowGroupBatchType::FLUSHED); - added_entry.unflushed_memory = entry.unflushed_memory; - merge_collections.push_back(std::move(added_entry)); - entry.total_rows = scheduled_task.total_count; - entry.type = RowGroupBatchType::FLUSHED; - } - task_manager.AddTask(make_uniq(std::move(merge_collections), merged_batch_index)); - } - // erase in reverse order - for (idx_t i = to_be_scheduled_tasks.size(); i > 0; i--) { - auto &scheduled_task = to_be_scheduled_tasks[i - 1]; - if (scheduled_task.start_index + 1 < scheduled_task.end_index) { - // erase all entries except the first one - collections.erase(collections.begin() + NumericCast(scheduled_task.start_index) + 1, - collections.begin() + NumericCast(scheduled_task.end_index)); - } - } -} - -unique_ptr BatchInsertGlobalState::MergeCollections(ClientContext &context, - vector merge_collections, - OptimisticDataWriter &writer) { - D_ASSERT(!merge_collections.empty()); - CollectionMerger merger(context); - idx_t written_data = 0; - for (auto &entry : merge_collections) { - merger.AddCollection(std::move(entry.collection), RowGroupBatchType::NOT_FLUSHED); - written_data += entry.unflushed_memory; - } - optimistically_written = true; - memory_manager.ReduceUnflushedMemory(written_data); - return merger.Flush(writer); -} - -void BatchInsertGlobalState::AddCollection(ClientContext &context, idx_t batch_index, idx_t min_batch_index, - unique_ptr current_collection, - optional_ptr writer) { - if (batch_index < min_batch_index) { - throw InternalException("Batch index of the added collection (%llu) is smaller than the min batch index (%llu)", - batch_index, min_batch_index); - } - auto new_count = current_collection->GetTotalRows(); - auto batch_type = new_count < row_group_size ? 
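Aside: after scheduling, `ScheduleMergeTasks` erases the consumed entries from `collections` in reverse task order; removing the later ranges first keeps the earlier start/end indexes valid. The pattern in isolation (hypothetical element type; ranges assumed sorted ascending and non-overlapping, as they are above):

```cpp
#include <cstddef>
#include <vector>

struct Range { size_t start; size_t end; }; // half-open [start, end)

// Erase several index ranges from a vector. Walking the ranges back-to-front
// means indexes of ranges not yet erased are unaffected by prior erasures.
template <class T>
void EraseRangesInReverse(std::vector<T> &items, const std::vector<Range> &ranges) {
	for (size_t i = ranges.size(); i > 0; i--) {
		auto &r = ranges[i - 1];
		items.erase(items.begin() + static_cast<std::ptrdiff_t>(r.start),
		            items.begin() + static_cast<std::ptrdiff_t>(r.end));
	}
}
```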
RowGroupBatchType::NOT_FLUSHED : RowGroupBatchType::FLUSHED; - if (batch_type == RowGroupBatchType::FLUSHED && writer) { - writer->WriteLastRowGroup(*current_collection); - } - lock_guard l(lock); - insert_count += new_count; - // add the collection to the batch index - RowGroupBatchEntry new_entry(batch_index, std::move(current_collection), batch_type); - if (batch_type == RowGroupBatchType::NOT_FLUSHED) { - memory_manager.IncreaseUnflushedMemory(new_entry.unflushed_memory); - } - - auto it = std::lower_bound( - collections.begin(), collections.end(), new_entry, - [&](const RowGroupBatchEntry &a, const RowGroupBatchEntry &b) { return a.batch_idx < b.batch_idx; }); - if (it != collections.end() && it->batch_idx == new_entry.batch_idx) { - throw InternalException("PhysicalBatchInsert::AddCollection error: batch index %d is present in multiple " - "collections. This occurs when " - "batch indexes are not uniquely distributed over threads", - batch_index); - } - collections.insert(it, std::move(new_entry)); - if (writer) { - ScheduleMergeTasks(min_batch_index); - } -} - -//===--------------------------------------------------------------------===// -// States -//===--------------------------------------------------------------------===// -unique_ptr PhysicalBatchInsert::GetGlobalSinkState(ClientContext &context) const { - optional_ptr table; - if (info) { - // CREATE TABLE AS - D_ASSERT(!insert_table); - auto &catalog = schema->catalog; - auto created_table = catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info); - table = &created_table->Cast(); - } else { - D_ASSERT(insert_table); - D_ASSERT(insert_table->IsDuckTable()); - table = insert_table.get_mutable(); - } - // heuristic - we start off by allocating 4MB of cache space per column - static constexpr const idx_t MINIMUM_MEMORY_PER_COLUMN = 4ULL * 1024ULL * 1024ULL; - auto minimum_memory_per_thread = table->GetColumns().PhysicalColumnCount() * MINIMUM_MEMORY_PER_COLUMN; - auto result = make_uniq(context, table->Cast(), minimum_memory_per_thread); - return std::move(result); -} - -unique_ptr PhysicalBatchInsert::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, insert_types, bound_defaults); -} - -//===--------------------------------------------------------------------===// -// Tasks -//===--------------------------------------------------------------------===// -bool PhysicalBatchInsert::ExecuteTask(ClientContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) const { - auto &gstate = gstate_p.Cast(); - auto task = gstate.task_manager.GetTask(); - if (!task) { - return false; - } - task->Execute(*this, context, gstate_p, lstate_p); - return true; -} - -void PhysicalBatchInsert::ExecuteTasks(ClientContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) const { - while (ExecuteTask(context, gstate_p, lstate_p)) { - } -} - -//===--------------------------------------------------------------------===// -// NextBatch -//===--------------------------------------------------------------------===// -SinkNextBatchType PhysicalBatchInsert::NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &memory_manager = gstate.memory_manager; - - auto batch_index = lstate.partition_info.batch_index.GetIndex(); - if (lstate.current_collection) { - if (lstate.current_index == batch_index) { - throw InternalException("NextBatch called 
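Aside: `AddCollection` above keeps `collections` sorted by batch index by inserting with `std::lower_bound`, and treats an equal key as a hard error since batch indexes must be distributed uniquely over threads. A compact version of that insert-or-reject pattern (hypothetical entry type):

```cpp
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

struct Entry {
	uint64_t batch_idx;
	// ... payload ...
};

// Insert into a vector kept sorted by batch_idx; duplicate keys are rejected.
void InsertSortedUnique(std::vector<Entry> &entries, Entry entry) {
	auto it = std::lower_bound(
	    entries.begin(), entries.end(), entry,
	    [](const Entry &a, const Entry &b) { return a.batch_idx < b.batch_idx; });
	if (it != entries.end() && it->batch_idx == entry.batch_idx) {
		throw std::runtime_error("duplicate batch index");
	}
	entries.insert(it, std::move(entry));
}
```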
with the same batch index?"); - } - // batch index has changed: move the old collection to the global state and create a new collection - TransactionData tdata(0, 0); - lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); - gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), - std::move(lstate.current_collection), lstate.writer); - - bool any_unblocked; - { - auto guard = memory_manager.Lock(); - any_unblocked = memory_manager.UnblockTasks(guard); - } - if (!any_unblocked) { - ExecuteTasks(context.client, gstate, lstate); - } - lstate.current_collection.reset(); - } - lstate.current_index = batch_index; - - // unblock any blocked tasks - auto guard = memory_manager.Lock(); - memory_manager.UnblockTasks(guard); - - return SinkNextBatchType::READY; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalBatchInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &memory_manager = gstate.memory_manager; - - auto &table = gstate.table; - PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); - - auto batch_index = lstate.partition_info.batch_index.GetIndex(); - // check if we should process this batch - if (!memory_manager.IsMinimumBatchIndex(batch_index)) { - memory_manager.UpdateMinBatchIndex(lstate.partition_info.min_batch_index.GetIndex()); - - // we are not processing the current min batch index - // check if we have exceeded the maximum number of unflushed rows - if (memory_manager.OutOfMemory(batch_index)) { - // out-of-memory - // execute tasks while we wait (if any are available) - ExecuteTasks(context.client, gstate, lstate); - - auto guard = memory_manager.Lock(); - if (!memory_manager.IsMinimumBatchIndex(batch_index)) { - // we are not the minimum batch index and we have no memory available to buffer - block the task for - // now - return memory_manager.BlockSink(guard, input.interrupt_state); - } - } - } - if (!lstate.current_collection) { - lock_guard l(gstate.lock); - // no collection yet: create a new one - lstate.CreateNewCollection(table, insert_types); - if (!lstate.writer) { - lstate.writer = &table.GetStorage().CreateOptimisticWriter(context.client); - } - } - - if (lstate.current_index != batch_index) { - throw InternalException("Current batch differs from batch - but NextBatch was not called!?"); - } - - if (!lstate.constraint_state) { - lstate.constraint_state = table.GetStorage().InitializeConstraintState(table, bound_constraints); - } - auto &storage = table.GetStorage(); - storage.VerifyAppendConstraints(*lstate.constraint_state, context.client, lstate.insert_chunk, nullptr, nullptr); - - auto new_row_group = lstate.current_collection->Append(lstate.insert_chunk, lstate.current_append_state); - if (new_row_group) { - // we have already written to disk - flush the next row group as well - lstate.writer->WriteNewRowGroup(*lstate.current_collection); - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -SinkCombineResultType PhysicalBatchInsert::Combine(ExecutionContext &context, 
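Aside: the `Sink` path above implements backpressure — a thread working on a batch other than the current minimum may buffer rows only while memory remains; otherwise it drains pending merge tasks and then blocks until the minimum batch index advances. A simplified model of the decision (hypothetical state struct, not the real `BatchMemoryManager`):

```cpp
#include <cstdint>

// Hypothetical, simplified view of the batch memory manager's state.
struct BatchMemoryState {
	uint64_t min_batch_index;
	uint64_t unflushed_bytes;
	uint64_t memory_budget;
};

enum class SinkDecision { Proceed, Block };

// Only threads on a batch other than the current minimum may be blocked;
// the minimum batch must always make progress or the pipeline would stall.
SinkDecision DecideSink(const BatchMemoryState &s, uint64_t batch_index) {
	if (batch_index == s.min_batch_index) {
		return SinkDecision::Proceed;
	}
	return s.unflushed_bytes >= s.memory_budget ? SinkDecision::Block
	                                            : SinkDecision::Proceed;
}
```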
OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - auto &memory_manager = gstate.memory_manager; - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - - memory_manager.UpdateMinBatchIndex(lstate.partition_info.min_batch_index.GetIndex()); - - if (lstate.current_collection) { - TransactionData tdata(0, 0); - lstate.current_collection->FinalizeAppend(tdata, lstate.current_append_state); - if (lstate.current_collection->GetTotalRows() > 0) { - gstate.AddCollection(context.client, lstate.current_index, lstate.partition_info.min_batch_index.GetIndex(), - std::move(lstate.current_collection)); - } - } - if (lstate.writer) { - lock_guard l(gstate.lock); - gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); - } - - // unblock any blocked tasks - auto guard = memory_manager.Lock(); - memory_manager.UnblockTasks(guard); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -SinkFinalizeType PhysicalBatchInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &memory_manager = gstate.memory_manager; - - if (gstate.optimistically_written || gstate.insert_count >= gstate.row_group_size) { - // we have written data to disk optimistically or are inserting a large amount of data - // perform a final pass over all of the row groups and merge them together - vector> mergers; - unique_ptr current_merger; - - auto &storage = gstate.table.GetStorage(); - for (auto &entry : gstate.collections) { - if (entry.type == RowGroupBatchType::NOT_FLUSHED) { - // this collection has not been flushed: add it to the merge set - if (!current_merger) { - current_merger = make_uniq(context); - } - current_merger->AddCollection(std::move(entry.collection), entry.type); - memory_manager.ReduceUnflushedMemory(entry.unflushed_memory); - } else { - // this collection has been flushed: it does not need to be merged - // create a separate collection merger only for this entry - if (current_merger) { - // we have small collections remaining: flush them - mergers.push_back(std::move(current_merger)); - current_merger.reset(); - } - auto larger_merger = make_uniq(context); - larger_merger->AddCollection(std::move(entry.collection), entry.type); - mergers.push_back(std::move(larger_merger)); - } - } - if (current_merger) { - mergers.push_back(std::move(current_merger)); - } - - // now that we have created all of the mergers, perform the actual merging - vector> final_collections; - final_collections.reserve(mergers.size()); - auto &writer = storage.CreateOptimisticWriter(context); - for (auto &merger : mergers) { - final_collections.push_back(merger->Flush(writer)); - } - - // finally, merge the row groups into the local storage - for (auto &collection : final_collections) { - storage.LocalMerge(context, *collection); - } - storage.FinalizeOptimisticWriter(context, writer); - } else { - // we are writing a small amount of data to disk - // append directly to transaction local storage - auto &table = gstate.table; - auto &storage = table.GetStorage(); - LocalAppendState append_state; - storage.InitializeLocalAppend(append_state, table, context, 
bound_constraints); - auto &transaction = DuckTransaction::Get(context, table.catalog); - for (auto &entry : gstate.collections) { - if (entry.type != RowGroupBatchType::NOT_FLUSHED) { - throw InternalException("Encountered a flushed batch"); - } - - memory_manager.ReduceUnflushedMemory(entry.unflushed_memory); - entry.collection->Scan(transaction, [&](DataChunk &insert_chunk) { - storage.LocalAppend(append_state, context, insert_chunk, false); - return true; - }); - } - storage.FinalizeLocalAppend(append_state); - } - memory_manager.FinalCheck(); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalBatchInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &insert_gstate = sink_state->Cast(); - - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(insert_gstate.insert_count))); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp deleted file mode 100644 index 6ad354ea2..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_database.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_copy_database.hpp" -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/parsed_data/bound_create_table_info.hpp" -#include "duckdb/parser/parsed_data/create_schema_info.hpp" -#include "duckdb/parser/parsed_data/create_macro_info.hpp" -#include "duckdb/parser/parsed_data/create_table_info.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/parsed_data/create_index_info.hpp" - -namespace duckdb { - -PhysicalCopyDatabase::PhysicalCopyDatabase(vector types, idx_t estimated_cardinality, - unique_ptr info_p) - : PhysicalOperator(PhysicalOperatorType::COPY_DATABASE, std::move(types), estimated_cardinality), - info(std::move(info_p)) { -} - -PhysicalCopyDatabase::~PhysicalCopyDatabase() { -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCopyDatabase::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->target_database); - for (auto &create_info : info->entries) { - switch (create_info->type) { - case CatalogType::SCHEMA_ENTRY: - catalog.CreateSchema(context.client, create_info->Cast()); - break; - case CatalogType::VIEW_ENTRY: - catalog.CreateView(context.client, create_info->Cast()); - break; - case CatalogType::SEQUENCE_ENTRY: - catalog.CreateSequence(context.client, create_info->Cast()); - break; - case CatalogType::TYPE_ENTRY: - catalog.CreateType(context.client, create_info->Cast()); - break; - case CatalogType::MACRO_ENTRY: - case CatalogType::TABLE_MACRO_ENTRY: - catalog.CreateFunction(context.client, create_info->Cast()); - break; - case CatalogType::TABLE_ENTRY: { - auto binder = Binder::CreateBinder(context.client); - auto bound_info = 
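Aside: `Finalize` above picks between two commit strategies — if anything was optimistically written, or the total insert spans at least one row group, collections are merged through the optimistic writer; otherwise everything is replayed into transaction-local storage as a plain append. The branch condition in isolation (a sketch of the decision only):

```cpp
#include <cstdint>

enum class CommitPath { OptimisticMerge, LocalAppend };

// Large or already-flushed inserts go through the optimistic writer;
// small ones are cheaper to replay into transaction-local storage.
CommitPath ChooseCommitPath(bool optimistically_written, uint64_t insert_count,
                            uint64_t row_group_size) {
	if (optimistically_written || insert_count >= row_group_size) {
		return CommitPath::OptimisticMerge;
	}
	return CommitPath::LocalAppend;
}
```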
binder->BindCreateTableInfo(std::move(create_info)); - catalog.CreateTable(context.client, *bound_info); - break; - } - case CatalogType::INDEX_ENTRY: { - catalog.CreateIndex(context.client, create_info->Cast()); - break; - } - default: - throw NotImplementedException("Entry type %s not supported in PhysicalCopyDatabase", - CatalogTypeToString(create_info->type)); - } - } - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp b/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp deleted file mode 100644 index fa85d670b..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_copy_to_file.cpp +++ /dev/null @@ -1,559 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" - -#include "duckdb/common/file_opener.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/hive_partitioning.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/uuid.hpp" -#include "duckdb/common/value_operations/value_operations.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/planner/operator/logical_copy_to_file.hpp" - -#include - -namespace duckdb { - -struct PartitionWriteInfo { - unique_ptr global_state; - idx_t active_writes = 0; -}; - -struct VectorOfValuesHashFunction { - uint64_t operator()(const vector &values) const { - hash_t result = 0; - for (auto &val : values) { - result ^= val.Hash(); - } - return result; - } -}; - -struct VectorOfValuesEquality { - bool operator()(const vector &a, const vector &b) const { - if (a.size() != b.size()) { - return false; - } - for (idx_t i = 0; i < a.size(); i++) { - if (ValueOperations::DistinctFrom(a[i], b[i])) { - return false; - } - } - return true; - } -}; - -template -using vector_of_value_map_t = unordered_map, T, VectorOfValuesHashFunction, VectorOfValuesEquality>; - -class CopyToFunctionGlobalState : public GlobalSinkState { -public: - explicit CopyToFunctionGlobalState(ClientContext &context, unique_ptr global_state) - : rows_copied(0), last_file_offset(0), global_state(std::move(global_state)) { - max_open_files = ClientConfig::GetConfig(context).partitioned_write_max_open_files; - } - StorageLock lock; - atomic rows_copied; - atomic last_file_offset; - unique_ptr global_state; - //! Created directories - unordered_set created_directories; - //! shared state for HivePartitionedColumnData - shared_ptr partition_state; - //! File names - vector file_names; - //! 
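Aside: the partitioned write above keys per-partition state by the vector of partition values, which needs the custom hash (XOR of per-value hashes) and the `DistinctFrom`-based equality for `unordered_map`. The same idiom with standard types (note that XOR-combining is order-insensitive — a deliberate simplicity/quality tradeoff):

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

struct VectorHash {
	size_t operator()(const std::vector<std::string> &key) const {
		size_t h = 0;
		for (auto &v : key) {
			h ^= std::hash<std::string>()(v); // combine per-element hashes
		}
		return h;
	}
};

struct VectorEq {
	bool operator()(const std::vector<std::string> &a,
	                const std::vector<std::string> &b) const {
		return a == b;
	}
};

template <class T>
using partition_map_t =
    std::unordered_map<std::vector<std::string>, T, VectorHash, VectorEq>;

// usage: partition_map_t<int> open_writers; open_writers[{"2024", "us"}] = 1;
```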
Max open files - idx_t max_open_files; - - void CreateDir(const string &dir_path, FileSystem &fs) { - if (created_directories.find(dir_path) != created_directories.end()) { - // already attempted to create this directory - return; - } - if (!fs.DirectoryExists(dir_path)) { - fs.CreateDirectory(dir_path); - } - created_directories.insert(dir_path); - } - - string GetOrCreateDirectory(const vector &cols, const vector &names, const vector &values, - string path, FileSystem &fs) { - CreateDir(path, fs); - for (idx_t i = 0; i < cols.size(); i++) { - const auto &partition_col_name = names[cols[i]]; - const auto &partition_value = values[i]; - string p_dir; - p_dir += HivePartitioning::Escape(partition_col_name); - p_dir += "="; - p_dir += HivePartitioning::Escape(partition_value.ToString()); - path = fs.JoinPath(path, p_dir); - CreateDir(path, fs); - } - return path; - } - - void AddFileName(const StorageLockKey &l, const string &file_name) { - D_ASSERT(l.GetType() == StorageLockType::EXCLUSIVE); - file_names.emplace_back(file_name); - } - - void FinalizePartition(ClientContext &context, const PhysicalCopyToFile &op, PartitionWriteInfo &info) { - if (!info.global_state) { - // already finalized - return; - } - // finalize the partition - op.function.copy_to_finalize(context, *op.bind_data, *info.global_state); - info.global_state.reset(); - } - - void FinalizePartitions(ClientContext &context, const PhysicalCopyToFile &op) { - // finalize any remaining partitions - for (auto &entry : active_partitioned_writes) { - FinalizePartition(context, op, *entry.second); - } - } - - PartitionWriteInfo &GetPartitionWriteInfo(ExecutionContext &context, const PhysicalCopyToFile &op, - const vector &values) { - auto global_lock = lock.GetExclusiveLock(); - // check if we have already started writing this partition - auto active_write_entry = active_partitioned_writes.find(values); - if (active_write_entry != active_partitioned_writes.end()) { - // we have - continue writing in this partition - active_write_entry->second->active_writes++; - return *active_write_entry->second; - } - // check if we need to close any writers before we can continue - if (active_partitioned_writes.size() >= max_open_files) { - // we need to! 
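Aside: `GetOrCreateDirectory` above builds the Hive layout one level per partition column — `<root>/col1=val1/col2=val2/...` — escaping names and values and creating each directory as it descends. A sketch of the path construction alone (escaping and directory creation elided):

```cpp
#include <string>
#include <vector>

// Build "<root>/name0=value0/name1=value1/..." for a Hive-partitioned write.
// Real code must escape names and values (HivePartitioning::Escape above)
// and create each directory level as it goes.
std::string BuildHivePath(std::string root, const std::vector<std::string> &names,
                          const std::vector<std::string> &values) {
	for (size_t i = 0; i < names.size() && i < values.size(); i++) {
		root += "/" + names[i] + "=" + values[i];
	}
	return root;
}
```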
try to close writers - for (auto &entry : active_partitioned_writes) { - if (entry.second->active_writes == 0) { - // we can evict this entry - evict the partition - FinalizePartition(context.client, op, *entry.second); - ++previous_partitions[entry.first]; - active_partitioned_writes.erase(entry.first); - break; - } - } - } - idx_t offset = 0; - auto prev_offset = previous_partitions.find(values); - if (prev_offset != previous_partitions.end()) { - offset = prev_offset->second; - } - auto &fs = FileSystem::GetFileSystem(context.client); - // Create a writer for the current file - auto trimmed_path = op.GetTrimmedPath(context.client); - string hive_path = GetOrCreateDirectory(op.partition_columns, op.names, values, trimmed_path, fs); - string full_path(op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset)); - if (op.overwrite_mode == CopyOverwriteMode::COPY_APPEND) { - // when appending, we first check if the file exists - while (fs.FileExists(full_path)) { - // file already exists - re-generate name - if (!op.filename_pattern.HasUUID()) { - throw InternalException("CopyOverwriteMode::COPY_APPEND without {uuid} - and file exists"); - } - full_path = op.filename_pattern.CreateFilename(fs, hive_path, op.file_extension, offset); - } - } - if (op.return_type == CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST) { - AddFileName(*global_lock, full_path); - } - // initialize writes - auto info = make_uniq(); - info->global_state = op.function.copy_to_initialize_global(context.client, *op.bind_data, full_path); - auto &result = *info; - info->active_writes = 1; - // store in active write map - active_partitioned_writes.insert(make_pair(values, std::move(info))); - return result; - } - - void FinishPartitionWrite(PartitionWriteInfo &info) { - auto global_lock = lock.GetExclusiveLock(); - info.active_writes--; - } - -private: - //! The active writes per partition (for partitioned write) - vector_of_value_map_t> active_partitioned_writes; - vector_of_value_map_t previous_partitions; -}; - -string PhysicalCopyToFile::GetTrimmedPath(ClientContext &context) const { - auto &fs = FileSystem::GetFileSystem(context); - string trimmed_path = file_path; - StringUtil::RTrim(trimmed_path, fs.PathSeparator(trimmed_path)); - return trimmed_path; -} - -class CopyToFunctionLocalState : public LocalSinkState { -public: - explicit CopyToFunctionLocalState(unique_ptr local_state) : local_state(std::move(local_state)) { - } - unique_ptr global_state; - unique_ptr local_state; - - //! 
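Aside: when the number of active partitioned writers reaches `max_open_files`, the code above evicts one partition whose `active_writes` count is zero, finalizing its file and remembering how many files that key already produced so the next file gets a fresh offset. A reduced model of that eviction (hypothetical writer type and close call):

```cpp
#include <cstdint>
#include <map>
#include <string>

struct Writer {
	uint64_t active_writes = 0;
	// ... open file state ...
};

// Evict one idle writer to stay under the open-file budget. Returns true if
// an idle writer was found and closed; partitions with in-flight writes are
// never evicted. `files_written` tracks per-partition file offsets.
bool EvictIdleWriter(std::map<std::string, Writer> &open_writers,
                     std::map<std::string, uint64_t> &files_written) {
	for (auto it = open_writers.begin(); it != open_writers.end(); ++it) {
		if (it->second.active_writes == 0) {
			// CloseFile(it->second); // hypothetical finalize call
			files_written[it->first]++;
			open_writers.erase(it);
			return true;
		}
	}
	return false; // every writer is busy - caller proceeds past the soft cap
}
```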
Buffers the tuples in partitions before writing - unique_ptr part_buffer; - unique_ptr part_buffer_append_state; - - idx_t append_count = 0; - - void InitializeAppendState(ClientContext &context, const PhysicalCopyToFile &op, - CopyToFunctionGlobalState &gstate) { - part_buffer = make_uniq(context, op.expected_types, op.partition_columns, - gstate.partition_state); - part_buffer_append_state = make_uniq(); - part_buffer->InitializeAppendState(*part_buffer_append_state); - append_count = 0; - } - - void AppendToPartition(ExecutionContext &context, const PhysicalCopyToFile &op, CopyToFunctionGlobalState &g, - DataChunk &chunk) { - if (!part_buffer) { - // re-initialize the append - InitializeAppendState(context.client, op, g); - } - part_buffer->Append(*part_buffer_append_state, chunk); - append_count += chunk.size(); - if (append_count >= ClientConfig::GetConfig(context.client).partitioned_write_flush_threshold) { - // flush all cached partitions - FlushPartitions(context, op, g); - } - } - - void ResetAppendState() { - part_buffer_append_state.reset(); - part_buffer.reset(); - append_count = 0; - } - - void SetDataWithoutPartitions(DataChunk &chunk, const DataChunk &source, const vector &col_types, - const vector &part_cols) { - D_ASSERT(source.ColumnCount() == col_types.size()); - auto types = LogicalCopyToFile::GetTypesWithoutPartitions(col_types, part_cols, false); - chunk.InitializeEmpty(types); - set part_col_set(part_cols.begin(), part_cols.end()); - idx_t new_col_id = 0; - for (idx_t col_idx = 0; col_idx < source.ColumnCount(); col_idx++) { - if (part_col_set.find(col_idx) == part_col_set.end()) { - chunk.data[new_col_id].Reference(source.data[col_idx]); - new_col_id++; - } - } - chunk.SetCardinality(source.size()); - } - - void FlushPartitions(ExecutionContext &context, const PhysicalCopyToFile &op, CopyToFunctionGlobalState &g) { - if (!part_buffer) { - return; - } - part_buffer->FlushAppendState(*part_buffer_append_state); - auto &partitions = part_buffer->GetPartitions(); - auto partition_key_map = part_buffer->GetReverseMap(); - - for (idx_t i = 0; i < partitions.size(); i++) { - auto entry = partition_key_map.find(i); - if (entry == partition_key_map.end()) { - continue; - } - // get the partition write info for this buffer - auto &info = g.GetPartitionWriteInfo(context, op, entry->second->values); - - auto local_copy_state = op.function.copy_to_initialize_local(context, *op.bind_data); - // push the chunks into the write state - for (auto &chunk : partitions[i]->Chunks()) { - if (op.write_partition_columns) { - op.function.copy_to_sink(context, *op.bind_data, *info.global_state, *local_copy_state, chunk); - } else { - DataChunk filtered_chunk; - SetDataWithoutPartitions(filtered_chunk, chunk, op.expected_types, op.partition_columns); - op.function.copy_to_sink(context, *op.bind_data, *info.global_state, *local_copy_state, - filtered_chunk); - } - } - op.function.copy_to_combine(context, *op.bind_data, *info.global_state, *local_copy_state); - local_copy_state.reset(); - partitions[i].reset(); - g.FinishPartitionWrite(info); - } - ResetAppendState(); - } -}; - -unique_ptr PhysicalCopyToFile::CreateFileState(ClientContext &context, GlobalSinkState &sink, - StorageLockKey &global_lock) const { - auto &g = sink.Cast(); - idx_t this_file_offset = g.last_file_offset++; - auto &fs = FileSystem::GetFileSystem(context); - string output_path(filename_pattern.CreateFilename(fs, file_path, file_extension, this_file_offset)); - if (return_type == 
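Aside: `AppendToPartition` above buffers rows in a partitioned column-data buffer and flushes all partitions once the buffered count crosses `partitioned_write_flush_threshold`, amortizing per-partition writer overhead over larger chunks. The control flow in miniature (buffering itself elided; `flush_all` is a stand-in):

```cpp
#include <cstdint>
#include <functional>

// Buffer rows and flush once a threshold is crossed - the shape of
// AppendToPartition/FlushPartitions above.
struct ThresholdBuffer {
	uint64_t buffered = 0;
	uint64_t flush_threshold;
	std::function<void()> flush_all; // flushes and resets all partitions

	void Append(uint64_t rows) {
		buffered += rows;
		if (buffered >= flush_threshold) {
			flush_all();
			buffered = 0;
		}
	}
};
```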
CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST) { - g.AddFileName(global_lock, output_path); - } - return function.copy_to_initialize_global(context, *bind_data, output_path); -} - -unique_ptr PhysicalCopyToFile::GetLocalSinkState(ExecutionContext &context) const { - if (partition_output) { - auto &g = sink_state->Cast(); - - auto state = make_uniq(nullptr); - state->InitializeAppendState(context.client, *this, g); - return std::move(state); - } - auto res = make_uniq(function.copy_to_initialize_local(context, *bind_data)); - return std::move(res); -} - -void CheckDirectory(FileSystem &fs, const string &file_path, CopyOverwriteMode overwrite_mode) { - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE_OR_IGNORE || - overwrite_mode == CopyOverwriteMode::COPY_APPEND) { - // with overwrite or ignore we fully ignore the presence of any files instead of erasing them - return; - } - if (fs.IsRemoteFile(file_path) && overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { - // we can only remove files for local file systems currently - // as remote file systems (e.g. S3) do not support RemoveFile - throw NotImplementedException("OVERWRITE is not supported for remote file systems"); - } - vector file_list; - vector directory_list; - directory_list.push_back(file_path); - for (idx_t dir_idx = 0; dir_idx < directory_list.size(); dir_idx++) { - auto directory = directory_list[dir_idx]; - fs.ListFiles(directory, [&](const string &path, bool is_directory) { - auto full_path = fs.JoinPath(directory, path); - if (is_directory) { - directory_list.emplace_back(std::move(full_path)); - } else { - file_list.emplace_back(std::move(full_path)); - } - }); - } - if (file_list.empty()) { - return; - } - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { - for (auto &file : file_list) { - fs.RemoveFile(file); - } - } else { - throw IOException("Directory \"%s\" is not empty! Enable OVERWRITE option to overwrite files", file_path); - } -} - -unique_ptr PhysicalCopyToFile::GetGlobalSinkState(ClientContext &context) const { - if (partition_output || per_thread_output || rotate) { - auto &fs = FileSystem::GetFileSystem(context); - if (fs.FileExists(file_path)) { - // the target file exists AND is a file (not a directory) - if (fs.IsRemoteFile(file_path)) { - // for remote files we cannot do anything - as we cannot delete the file - throw IOException("Cannot write to \"%s\" - it exists and is a file, not a directory!", file_path); - } else { - // for local files we can remove the file if OVERWRITE_OR_IGNORE is enabled - if (overwrite_mode == CopyOverwriteMode::COPY_OVERWRITE) { - fs.RemoveFile(file_path); - } else { - throw IOException("Cannot write to \"%s\" - it exists and is a file, not a directory! 
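Aside: `CheckDirectory` above collects every file beneath the target by treating `directory_list` as a growing worklist — indexing with `dir_idx` while appending newly discovered directories gives an iterative breadth-first traversal with no recursion. The same pattern with a callback-based lister standing in for `fs.ListFiles`:

```cpp
#include <functional>
#include <string>
#include <vector>

// List every file under `root`, breadth-first, without recursion: the
// vector doubles as the traversal queue. `list_dir` invokes its callback
// once per entry with (path, is_directory) - a stand-in for fs.ListFiles.
std::vector<std::string> CollectFiles(
    const std::string &root,
    const std::function<void(const std::string &,
                             std::function<void(std::string, bool)>)> &list_dir) {
	std::vector<std::string> files;
	std::vector<std::string> dirs {root};
	for (size_t i = 0; i < dirs.size(); i++) { // dirs grows while we scan it
		// copy: push_back below may reallocate and invalidate references into dirs
		std::string current = dirs[i];
		list_dir(current, [&](std::string path, bool is_dir) {
			(is_dir ? dirs : files).push_back(std::move(path));
		});
	}
	return files;
}
```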
Enable " - "OVERWRITE option to overwrite the file", - file_path); - } - } - } - // what if the target exists and is a directory - if (!fs.DirectoryExists(file_path)) { - fs.CreateDirectory(file_path); - } else { - CheckDirectory(fs, file_path, overwrite_mode); - } - - auto state = make_uniq(context, nullptr); - if (!per_thread_output && rotate) { - auto global_lock = state->lock.GetExclusiveLock(); - state->global_state = CreateFileState(context, *state, *global_lock); - } - - if (partition_output) { - state->partition_state = make_shared_ptr(); - } - - return std::move(state); - } - - auto state = make_uniq( - context, function.copy_to_initialize_global(context, *bind_data, file_path)); - if (use_tmp_file) { - auto global_lock = state->lock.GetExclusiveLock(); - state->AddFileName(*global_lock, file_path); - } else { - state->file_names.emplace_back(file_path); - } - return std::move(state); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -void PhysicalCopyToFile::MoveTmpFile(ClientContext &context, const string &tmp_file_path) { - auto &fs = FileSystem::GetFileSystem(context); - auto file_path = GetNonTmpFile(context, tmp_file_path); - if (fs.FileExists(file_path)) { - fs.RemoveFile(file_path); - } - fs.MoveFile(tmp_file_path, file_path); -} - -string PhysicalCopyToFile::GetNonTmpFile(ClientContext &context, const string &tmp_file_path) { - auto &fs = FileSystem::GetFileSystem(context); - - auto path = StringUtil::GetFilePath(tmp_file_path); - auto base = StringUtil::GetFileName(tmp_file_path); - - auto prefix = base.find("tmp_"); - if (prefix == 0) { - base = base.substr(4); - } - - return fs.JoinPath(path, base); -} - -PhysicalCopyToFile::PhysicalCopyToFile(vector types, CopyFunction function_p, - unique_ptr bind_data, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::COPY_TO_FILE, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data)), parallel(false) { -} - -SinkResultType PhysicalCopyToFile::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &g = input.global_state.Cast(); - auto &l = input.local_state.Cast(); - - g.rows_copied += chunk.size(); - - if (partition_output) { - l.AppendToPartition(context, *this, g, chunk); - return SinkResultType::NEED_MORE_INPUT; - } - - if (per_thread_output) { - auto &gstate = l.global_state; - if (!gstate) { - // Lazily create file state here to prevent creating empty files - auto global_lock = g.lock.GetExclusiveLock(); - gstate = CreateFileState(context.client, *sink_state, *global_lock); - } else if (rotate && function.rotate_next_file(*gstate, *bind_data, file_size_bytes)) { - function.copy_to_finalize(context.client, *bind_data, *gstate); - auto global_lock = g.lock.GetExclusiveLock(); - gstate = CreateFileState(context.client, *sink_state, *global_lock); - } - function.copy_to_sink(context, *bind_data, *gstate, *l.local_state, chunk); - return SinkResultType::NEED_MORE_INPUT; - } - - if (!file_size_bytes.IsValid() && !rotate) { - function.copy_to_sink(context, *bind_data, *g.global_state, *l.local_state, chunk); - return SinkResultType::NEED_MORE_INPUT; - } - - // FILE_SIZE_BYTES/rotate is set, but threads write to the same file, synchronize using lock - auto &gstate = g.global_state; - auto global_lock = g.lock.GetExclusiveLock(); - if (rotate && function.rotate_next_file(*gstate, *bind_data, 
file_size_bytes)) { - auto owned_gstate = std::move(gstate); - gstate = CreateFileState(context.client, *sink_state, *global_lock); - global_lock.reset(); - function.copy_to_finalize(context.client, *bind_data, *owned_gstate); - } else { - global_lock.reset(); - } - - global_lock = g.lock.GetSharedLock(); - function.copy_to_sink(context, *bind_data, *gstate, *l.local_state, chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalCopyToFile::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &g = input.global_state.Cast(); - auto &l = input.local_state.Cast(); - - if (partition_output) { - // flush all remaining partitions - l.FlushPartitions(context, *this, g); - } else if (function.copy_to_combine) { - if (per_thread_output) { - // For PER_THREAD_OUTPUT, we can combine/finalize immediately (if there is a gstate) - if (l.global_state) { - function.copy_to_combine(context, *bind_data, *l.global_state, *l.local_state); - function.copy_to_finalize(context.client, *bind_data, *l.global_state); - } - } else if (rotate) { - // File in global state may change with FILE_SIZE_BYTES/rotate, need to grab lock - auto lock = g.lock.GetSharedLock(); - function.copy_to_combine(context, *bind_data, *g.global_state, *l.local_state); - } else { - function.copy_to_combine(context, *bind_data, *g.global_state, *l.local_state); - } - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalCopyToFile::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - if (partition_output) { - // finalize any outstanding partitions - gstate.FinalizePartitions(context, *this); - return SinkFinalizeType::READY; - } - if (per_thread_output) { - // already happened in combine - if (NumericCast(gstate.rows_copied.load()) == 0 && sink_state != nullptr) { - // no rows from source, write schema to file - auto global_lock = gstate.lock.GetExclusiveLock(); - gstate.global_state = CreateFileState(context, *sink_state, *global_lock); - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - } - return SinkFinalizeType::READY; - } - if (function.copy_to_finalize) { - function.copy_to_finalize(context, *bind_data, *gstate.global_state); - - if (use_tmp_file) { - D_ASSERT(!per_thread_output); - D_ASSERT(!partition_output); - D_ASSERT(!file_size_bytes.IsValid()); - D_ASSERT(!rotate); - MoveTmpFile(context, file_path); - } - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalCopyToFile::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &g = sink_state->Cast(); - - chunk.SetCardinality(1); - switch (return_type) { - case CopyFunctionReturnType::CHANGED_ROWS: - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); - break; - case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST: - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.rows_copied.load()))); - chunk.SetValue(1, 0, Value::LIST(LogicalType::VARCHAR, g.file_names)); - break; - default: - throw NotImplementedException("Unknown CopyFunctionReturnType"); - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp 
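Aside: the rotation path above holds the exclusive lock only long enough to swap in a fresh global file state, then releases it before finalizing the retired file, so other threads can keep writing to the new file while the old one is closed. The same lock-swap shape with `std::mutex` and a hypothetical file handle:

```cpp
#include <memory>
#include <mutex>
#include <utility>

struct FileState { /* open file handle, bytes written, ... */ };

// Swap in a new file under the lock; finalize the old file outside it so
// writers to the new file are not stalled by the (potentially slow) close.
void RotateFile(std::mutex &m, std::unique_ptr<FileState> &current,
                std::unique_ptr<FileState> fresh) {
	std::unique_ptr<FileState> retired;
	{
		std::lock_guard<std::mutex> guard(m);
		retired = std::exchange(current, std::move(fresh));
	}
	// FinalizeFile(*retired); // hypothetical: flush footer, close handle
}
```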
b/src/duckdb/src/execution/operator/persistent/physical_delete.cpp deleted file mode 100644 index e92e2ec66..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_delete.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_delete.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/atomic.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/execution/index/bound_index.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table/delete_state.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/transaction/duck_transaction.hpp" - -namespace duckdb { - -PhysicalDelete::PhysicalDelete(vector types, TableCatalogEntry &tableref, DataTable &table, - vector> bound_constraints, idx_t row_id_index, - idx_t estimated_cardinality, bool return_chunk) - : PhysicalOperator(PhysicalOperatorType::DELETE_OPERATOR, std::move(types), estimated_cardinality), - tableref(tableref), table(table), bound_constraints(std::move(bound_constraints)), row_id_index(row_id_index), - return_chunk(return_chunk) { -} -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class DeleteGlobalState : public GlobalSinkState { -public: - explicit DeleteGlobalState(ClientContext &context, const vector &return_types, - TableCatalogEntry &table, const vector> &bound_constraints) - : deleted_count(0), return_collection(context, return_types), has_unique_indexes(false) { - - // We need to append deletes to the local delete-ART. - auto &storage = table.GetStorage(); - if (storage.HasUniqueIndexes()) { - storage.InitializeLocalStorage(delete_index_append_state, table, context, bound_constraints); - has_unique_indexes = true; - } - } - - mutex delete_lock; - idx_t deleted_count; - ColumnDataCollection return_collection; - LocalAppendState delete_index_append_state; - bool has_unique_indexes; -}; - -class DeleteLocalState : public LocalSinkState { -public: - DeleteLocalState(ClientContext &context, TableCatalogEntry &table, - const vector> &bound_constraints) { - delete_chunk.Initialize(Allocator::Get(context), table.GetTypes()); - auto &storage = table.GetStorage(); - delete_state = storage.InitializeDelete(table, context, bound_constraints); - } - -public: - DataChunk delete_chunk; - unique_ptr delete_state; -}; - -SinkResultType PhysicalDelete::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &g_state = input.global_state.Cast(); - auto &l_state = input.local_state.Cast(); - - auto &transaction = DuckTransaction::Get(context.client, table.db); - auto &row_ids = chunk.data[row_id_index]; - - vector column_ids; - for (idx_t i = 0; i < table.ColumnCount(); i++) { - column_ids.emplace_back(i); - }; - auto fetch_state = ColumnFetchState(); - - lock_guard delete_guard(g_state.delete_lock); - if (!return_chunk && !g_state.has_unique_indexes) { - g_state.deleted_count += table.Delete(*l_state.delete_state, context.client, row_ids, chunk.size()); - return SinkResultType::NEED_MORE_INPUT; - } - - // Fetch the to-be-deleted chunk. - l_state.delete_chunk.Reset(); - row_ids.Flatten(chunk.size()); - table.Fetch(transaction, l_state.delete_chunk, column_ids, row_ids, chunk.size(), fetch_state); - - // Append the deleted row IDs to the delete indexes. 
- // If we only delete local row IDs, then the delete_chunk is empty. - if (g_state.has_unique_indexes && l_state.delete_chunk.size() != 0) { - auto &local_storage = LocalStorage::Get(context.client, table.db); - auto storage = local_storage.GetStorage(table); - storage->delete_indexes.Scan([&](Index &index) { - if (!index.IsBound() || !index.IsUnique()) { - return false; - } - auto &bound_index = index.Cast(); - auto error = bound_index.Append(l_state.delete_chunk, row_ids); - if (error.HasError()) { - throw InternalException("failed to update delete ART in physical delete: ", error.Message()); - } - return false; - }); - } - - // Append the return_chunk to the return collection. - if (return_chunk) { - g_state.return_collection.Append(l_state.delete_chunk); - } - - g_state.deleted_count += table.Delete(*l_state.delete_state, context.client, row_ids, chunk.size()); - return SinkResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalDelete::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, GetTypes(), tableref, bound_constraints); -} - -unique_ptr PhysicalDelete::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, tableref, bound_constraints); -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class DeleteSourceState : public GlobalSourceState { -public: - explicit DeleteSourceState(const PhysicalDelete &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalDelete::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalDelete::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &g = sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.deleted_count))); - return SourceResultType::FINISHED; - } - - g.return_collection.Scan(state.scan_state, chunk); - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_export.cpp b/src/duckdb/src/execution/operator/persistent/physical_export.cpp deleted file mode 100644 index 66afb4c97..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_export.cpp +++ /dev/null @@ -1,285 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_export.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parser/keyword_helper.hpp" -#include "duckdb/transaction/transaction.hpp" -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/catalog/dependency_manager.hpp" - -#include -#include - -namespace duckdb { - -void ReorderTableEntries(catalog_entry_vector_t &tables); - -using std::stringstream; - -PhysicalExport::PhysicalExport(vector types, CopyFunction function, unique_ptr info, - idx_t estimated_cardinality, unique_ptr exported_tables) - : PhysicalOperator(PhysicalOperatorType::EXPORT, std::move(types), estimated_cardinality), - function(std::move(function)), info(std::move(info)), exported_tables(std::move(exported_tables)) { -} - -static void WriteCatalogEntries(stringstream &ss, catalog_entry_vector_t &entries) { - for (auto &entry : entries) { - if (entry.get().internal) { - continue; - } - auto create_info = entry.get().GetInfo(); - try { - // Strip the catalog from the info - create_info->catalog.clear(); - auto to_string = create_info->ToString(); - ss << to_string; - } catch (const NotImplementedException &) { - ss << entry.get().ToSQL(); - } - ss << '\n'; - } - ss << '\n'; -} - -static void WriteStringStreamToFile(FileSystem &fs, stringstream &ss, const string &path) { - auto ss_string = ss.str(); - auto handle = fs.OpenFile(path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW | - FileLockType::WRITE_LOCK); - fs.Write(*handle, (void *)ss_string.c_str(), NumericCast(ss_string.size())); - handle.reset(); -} - -static void WriteCopyStatement(FileSystem &fs, stringstream &ss, CopyInfo &info, ExportedTableData &exported_table, - CopyFunction const &function) { - ss << "COPY "; - - //! 
NOTE: The catalog is explicitly not set here - if (exported_table.schema_name != DEFAULT_SCHEMA && !exported_table.schema_name.empty()) { - ss << KeywordHelper::WriteOptionallyQuoted(exported_table.schema_name) << "."; - } - - auto file_path = StringUtil::Replace(exported_table.file_path, "\\", "/"); - ss << StringUtil::Format("%s FROM %s (", SQLIdentifier(exported_table.table_name), SQLString(file_path)); - // write the copy options - ss << "FORMAT '" << info.format << "'"; - if (info.format == "csv") { - // insert default csv options, if not specified - if (info.options.find("header") == info.options.end()) { - info.options["header"].push_back(Value::INTEGER(1)); - } - if (info.options.find("delimiter") == info.options.end() && info.options.find("sep") == info.options.end() && - info.options.find("delim") == info.options.end()) { - info.options["delimiter"].push_back(Value(",")); - } - if (info.options.find("quote") == info.options.end()) { - info.options["quote"].push_back(Value("\"")); - } - info.options.erase("force_not_null"); - for (auto &not_null_column : exported_table.not_null_columns) { - info.options["force_not_null"].push_back(not_null_column); - } - } - for (auto &copy_option : info.options) { - if (copy_option.first == "force_quote") { - continue; - } - if (copy_option.second.empty()) { - // empty options are interpreted as TRUE - copy_option.second.push_back(true); - } - ss << ", " << copy_option.first << " "; - if (copy_option.second.size() == 1) { - ss << copy_option.second[0].ToSQLString(); - } else { - // For Lists - ss << "("; - for (idx_t i = 0; i < copy_option.second.size(); i++) { - ss << copy_option.second[i].ToSQLString(); - if (i != copy_option.second.size() - 1) { - ss << ", "; - } - } - ss << ")"; - } - } - ss << ");" << '\n'; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class ExportSourceState : public GlobalSourceState { -public: - ExportSourceState() : finished(false) { - } - - bool finished; -}; - -unique_ptr<GlobalSourceState> PhysicalExport::GetGlobalSourceState(ClientContext &context) const { - return make_uniq<ExportSourceState>(); -} - -void PhysicalExport::ExtractEntries(ClientContext &context, vector<reference<SchemaCatalogEntry>> &schema_list, - ExportEntries &result) { - for (auto &schema_p : schema_list) { - auto &schema = schema_p.get(); - auto &catalog = schema.ParentCatalog(); - if (catalog.IsSystemCatalog() || catalog.IsTemporaryCatalog()) { - continue; - } - if (!schema.internal) { - result.schemas.push_back(schema); - } - schema.Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - if (entry.type != CatalogType::TABLE_ENTRY) { - result.views.push_back(entry); - } - if (entry.type == CatalogType::TABLE_ENTRY) { - result.tables.push_back(entry); - } - }); - schema.Scan(context, CatalogType::SEQUENCE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - result.sequences.push_back(entry); - }); - schema.Scan(context, CatalogType::TYPE_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - result.custom_types.push_back(entry); - }); - schema.Scan(context, CatalogType::INDEX_ENTRY, [&](CatalogEntry &entry) { - if (entry.internal) { - return; - } - result.indexes.push_back(entry); - }); - schema.Scan(context, CatalogType::MACRO_ENTRY, [&](CatalogEntry &entry) { - if (!entry.internal && entry.type == CatalogType::MACRO_ENTRY) { - result.macros.push_back(entry); - } - }); - schema.Scan(context, 
CatalogType::TABLE_MACRO_ENTRY, [&](CatalogEntry &entry) { - if (!entry.internal && entry.type == CatalogType::TABLE_MACRO_ENTRY) { - result.macros.push_back(entry); - } - }); - } -} - -static void AddEntries(catalog_entry_vector_t &all_entries, catalog_entry_vector_t &to_add) { - for (auto &entry : to_add) { - all_entries.push_back(entry); - } - to_add.clear(); -} - -catalog_entry_vector_t PhysicalExport::GetNaiveExportOrder(ClientContext &context, Catalog &catalog) { - // gather all catalog types to export - ExportEntries entries; - auto schema_list = catalog.GetSchemas(context); - PhysicalExport::ExtractEntries(context, schema_list, entries); - - ReorderTableEntries(entries.tables); - - // order macros by timestamp so nested macros are imported nicely - sort(entries.macros.begin(), entries.macros.end(), - [](const reference<CatalogEntry> &lhs, const reference<CatalogEntry> &rhs) { - return lhs.get().oid < rhs.get().oid; - }); - - catalog_entry_vector_t catalog_entries; - idx_t size = 0; - size += entries.schemas.size(); - size += entries.custom_types.size(); - size += entries.sequences.size(); - size += entries.tables.size(); - size += entries.views.size(); - size += entries.indexes.size(); - size += entries.macros.size(); - catalog_entries.reserve(size); - AddEntries(catalog_entries, entries.schemas); - AddEntries(catalog_entries, entries.sequences); - AddEntries(catalog_entries, entries.custom_types); - AddEntries(catalog_entries, entries.tables); - AddEntries(catalog_entries, entries.macros); - AddEntries(catalog_entries, entries.views); - AddEntries(catalog_entries, entries.indexes); - return catalog_entries; -} - -SourceResultType PhysicalExport::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast<ExportSourceState>(); - if (state.finished) { - return SourceResultType::FINISHED; - } - - auto &ccontext = context.client; - auto &fs = FileSystem::GetFileSystem(ccontext); - - auto &catalog = Catalog::GetCatalog(ccontext, info->catalog); - - catalog_entry_vector_t catalog_entries; - catalog_entries = GetNaiveExportOrder(context.client, catalog); - auto dependency_manager = catalog.GetDependencyManager(); - if (dependency_manager) { - dependency_manager->ReorderEntries(catalog_entries, ccontext); - } - - // write the schema.sql file - stringstream ss; - WriteCatalogEntries(ss, catalog_entries); - WriteStringStreamToFile(fs, ss, fs.JoinPath(info->file_path, "schema.sql")); - - // write the load.sql file - // for every table, we write a COPY INTO statement with the specified options - stringstream load_ss; - for (idx_t i = 0; i < exported_tables->data.size(); i++) { - auto exported_table_info = exported_tables->data[i].table_data; - WriteCopyStatement(fs, load_ss, *info, exported_table_info, function); - } - WriteStringStreamToFile(fs, load_ss, fs.JoinPath(info->file_path, "load.sql")); - state.finished = true; - - return SourceResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -SinkResultType PhysicalExport::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - // nop - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalExport::BuildPipelines(Pipeline &current, MetaPipeline &meta_pipeline) { - // 
EXPORT has an optional child - // we only need to schedule child pipelines if there is a child - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - if (children.empty()) { - return; - } - PhysicalOperator::BuildPipelines(current, meta_pipeline); -} - -vector> PhysicalExport::GetSources() const { - return {*this}; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp b/src/duckdb/src/execution/operator/persistent/physical_insert.cpp deleted file mode 100644 index dac3491fe..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_insert.cpp +++ /dev/null @@ -1,768 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_insert.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parser/parsed_data/create_table_info.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/storage/table_io_manager.hpp" -#include "duckdb/transaction/local_storage.hpp" -#include "duckdb/parser/statement/insert_statement.hpp" -#include "duckdb/parser/statement/update_statement.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/common/types/conflict_manager.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/storage/table/append_state.hpp" -#include "duckdb/storage/table/update_state.hpp" -#include "duckdb/function/create_sort_key.hpp" - -namespace duckdb { - -PhysicalInsert::PhysicalInsert( - vector types_p, TableCatalogEntry &table, physical_index_vector_t column_index_map, - vector> bound_defaults, vector> bound_constraints_p, - vector> set_expressions, vector set_columns, vector set_types, - idx_t estimated_cardinality, bool return_chunk, bool parallel, OnConflictAction action_type, - unique_ptr on_conflict_condition_p, unique_ptr do_update_condition_p, - unordered_set conflict_target_p, vector columns_to_fetch_p, bool update_is_del_and_insert) - : PhysicalOperator(PhysicalOperatorType::INSERT, std::move(types_p), estimated_cardinality), - column_index_map(std::move(column_index_map)), insert_table(&table), insert_types(table.GetTypes()), - bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints_p)), - return_chunk(return_chunk), parallel(parallel), action_type(action_type), - set_expressions(std::move(set_expressions)), set_columns(std::move(set_columns)), set_types(std::move(set_types)), - on_conflict_condition(std::move(on_conflict_condition_p)), do_update_condition(std::move(do_update_condition_p)), - conflict_target(std::move(conflict_target_p)), update_is_del_and_insert(update_is_del_and_insert) { - - if (action_type == OnConflictAction::THROW) { - return; - } - - D_ASSERT(this->set_expressions.size() == this->set_columns.size()); - - // One or more columns are referenced from the existing table, - // we use the 'insert_types' to figure out which types these columns have - types_to_fetch = vector(columns_to_fetch_p.size(), LogicalType::SQLNULL); - for (idx_t i = 0; i < columns_to_fetch_p.size(); i++) { - auto &id = columns_to_fetch_p[i]; - D_ASSERT(id < insert_types.size()); - types_to_fetch[i] = 
insert_types[id]; - columns_to_fetch.emplace_back(id); - } -} - -PhysicalInsert::PhysicalInsert(LogicalOperator &op, SchemaCatalogEntry &schema, unique_ptr info_p, - idx_t estimated_cardinality, bool parallel) - : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE_AS, op.types, estimated_cardinality), insert_table(nullptr), - return_chunk(false), schema(&schema), info(std::move(info_p)), parallel(parallel), - action_type(OnConflictAction::THROW), update_is_del_and_insert(false) { - GetInsertInfo(*info, insert_types, bound_defaults); -} - -void PhysicalInsert::GetInsertInfo(const BoundCreateTableInfo &info, vector &insert_types, - vector> &bound_defaults) { - auto &create_info = info.base->Cast(); - for (auto &col : create_info.columns.Physical()) { - insert_types.push_back(col.GetType()); - bound_defaults.push_back(make_uniq(Value(col.GetType()))); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -InsertGlobalState::InsertGlobalState(ClientContext &context, const vector &return_types, - DuckTableEntry &table) - : table(table), insert_count(0), initialized(false), return_collection(context, return_types) { -} - -InsertLocalState::InsertLocalState(ClientContext &context, const vector &types, - const vector> &bound_defaults, - const vector> &bound_constraints) - : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { - - auto &allocator = Allocator::Get(context); - insert_chunk.Initialize(allocator, types); - update_chunk.Initialize(allocator, types); - append_chunk.Initialize(allocator, types); -} - -ConstraintState &InsertLocalState::GetConstraintState(DataTable &table, TableCatalogEntry &table_ref) { - if (!constraint_state) { - constraint_state = table.InitializeConstraintState(table_ref, bound_constraints); - } - return *constraint_state; -} - -TableDeleteState &InsertLocalState::GetDeleteState(DataTable &table, TableCatalogEntry &table_ref, - ClientContext &context) { - if (!delete_state) { - delete_state = table.InitializeDelete(table_ref, context, bound_constraints); - } - return *delete_state; -} - -unique_ptr PhysicalInsert::GetGlobalSinkState(ClientContext &context) const { - optional_ptr table; - if (info) { - // CREATE TABLE AS - D_ASSERT(!insert_table); - auto &catalog = schema->catalog; - table = &catalog.CreateTable(catalog.GetCatalogTransaction(context), *schema.get_mutable(), *info) - ->Cast(); - } else { - D_ASSERT(insert_table); - D_ASSERT(insert_table->IsDuckTable()); - table = insert_table.get_mutable(); - } - auto result = make_uniq(context, GetTypes(), table->Cast()); - return std::move(result); -} - -unique_ptr PhysicalInsert::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, insert_types, bound_defaults, bound_constraints); -} - -void PhysicalInsert::ResolveDefaults(const TableCatalogEntry &table, DataChunk &chunk, - const physical_index_vector_t &column_index_map, - ExpressionExecutor &default_executor, DataChunk &result) { - chunk.Flatten(); - default_executor.SetChunk(chunk); - - result.Reset(); - result.SetCardinality(chunk); - - if (!column_index_map.empty()) { - // columns specified by the user, use column_index_map - for (auto &col : table.GetColumns().Physical()) { - auto storage_idx = col.StorageOid(); - auto mapped_index = column_index_map[col.Physical()]; - if (mapped_index == DConstants::INVALID_INDEX) { - // insert default value - 
default_executor.ExecuteExpression(storage_idx, result.data[storage_idx]); - } else { - // get value from child chunk - D_ASSERT((idx_t)mapped_index < chunk.ColumnCount()); - D_ASSERT(result.data[storage_idx].GetType() == chunk.data[mapped_index].GetType()); - result.data[storage_idx].Reference(chunk.data[mapped_index]); - } - } - } else { - // no columns specified, just append directly - for (idx_t i = 0; i < result.ColumnCount(); i++) { - D_ASSERT(result.data[i].GetType() == chunk.data[i].GetType()); - result.data[i].Reference(chunk.data[i]); - } - } -} - -bool AllConflictsMeetCondition(DataChunk &result) { - result.Flatten(); - auto data = FlatVector::GetData(result.data[0]); - for (idx_t i = 0; i < result.size(); i++) { - if (!data[i]) { - return false; - } - } - return true; -} - -void CheckOnConflictCondition(ExecutionContext &context, DataChunk &conflicts, const unique_ptr &condition, - DataChunk &result) { - ExpressionExecutor executor(context.client, *condition); - result.Initialize(context.client, {LogicalType::BOOLEAN}); - executor.Execute(conflicts, result); - result.SetCardinality(conflicts.size()); -} - -static void CombineExistingAndInsertTuples(DataChunk &result, DataChunk &scan_chunk, DataChunk &input_chunk, - ClientContext &client, const PhysicalInsert &op) { - auto &types_to_fetch = op.types_to_fetch; - auto &insert_types = op.insert_types; - - if (types_to_fetch.empty()) { - // We have not scanned the initial table, so we can just duplicate the initial chunk - result.Initialize(client, input_chunk.GetTypes()); - result.Reference(input_chunk); - result.SetCardinality(input_chunk); - return; - } - vector combined_types; - combined_types.reserve(insert_types.size() + types_to_fetch.size()); - combined_types.insert(combined_types.end(), insert_types.begin(), insert_types.end()); - combined_types.insert(combined_types.end(), types_to_fetch.begin(), types_to_fetch.end()); - - result.Initialize(client, combined_types); - result.Reset(); - // Add the VALUES list - for (idx_t i = 0; i < insert_types.size(); i++) { - idx_t col_idx = i; - auto &other_col = input_chunk.data[i]; - auto &this_col = result.data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - this_col.Reference(other_col); - } - // Add the columns from the original conflicting tuples - for (idx_t i = 0; i < types_to_fetch.size(); i++) { - idx_t col_idx = i + insert_types.size(); - auto &other_col = scan_chunk.data[i]; - auto &this_col = result.data[col_idx]; - D_ASSERT(other_col.GetType() == this_col.GetType()); - this_col.Reference(other_col); - } - // This is guaranteed by the requirement of a conflict target to have a condition or set expressions - // Only when we have any sort of condition or SET expression that references the existing table is this possible - // to not be true. 
- // We can have a SET expression without a conflict target ONLY if there is only 1 Index on the table - // In which case this also can't cause a discrepancy between existing tuple count and insert tuple count - D_ASSERT(input_chunk.size() == scan_chunk.size()); - result.SetCardinality(input_chunk.size()); -} - -static void CreateUpdateChunk(ExecutionContext &context, DataChunk &chunk, TableCatalogEntry &table, Vector &row_ids, - DataChunk &update_chunk, const PhysicalInsert &op) { - - auto &do_update_condition = op.do_update_condition; - auto &set_types = op.set_types; - auto &set_expressions = op.set_expressions; - // Check the optional condition for the DO UPDATE clause, to filter which rows will be updated - if (do_update_condition) { - DataChunk do_update_filter_result; - do_update_filter_result.Initialize(context.client, {LogicalType::BOOLEAN}); - ExpressionExecutor where_executor(context.client, *do_update_condition); - where_executor.Execute(chunk, do_update_filter_result); - do_update_filter_result.SetCardinality(chunk.size()); - do_update_filter_result.Flatten(); - - ManagedSelection selection(chunk.size()); - - auto where_data = FlatVector::GetData<bool>(do_update_filter_result.data[0]); - for (idx_t i = 0; i < chunk.size(); i++) { - if (where_data[i]) { - selection.Append(i); - } - } - if (selection.Count() != selection.Size()) { - // Not all conflicts met the condition, need to filter out the ones that don't - chunk.Slice(selection.Selection(), selection.Count()); - chunk.SetCardinality(selection.Count()); - // Also apply this Slice to the to-update row_ids - row_ids.Slice(selection.Selection(), selection.Count()); - } - } - - // Execute the SET expressions - update_chunk.Initialize(context.client, set_types); - ExpressionExecutor executor(context.client, set_expressions); - executor.Execute(chunk, update_chunk); - update_chunk.SetCardinality(chunk); -} - -template <bool GLOBAL> -static idx_t PerformOnConflictAction(InsertLocalState &lstate, ExecutionContext &context, DataChunk &chunk, - TableCatalogEntry &table, Vector &row_ids, const PhysicalInsert &op) { - // Early-out, if we do nothing on conflicting rows. - if (op.action_type == OnConflictAction::NOTHING) { - return 0; - } - - auto &set_columns = op.set_columns; - DataChunk update_chunk; - CreateUpdateChunk(context, chunk, table, row_ids, update_chunk, op); - auto &data_table = table.GetStorage(); - - // Perform the UPDATE on the (global) storage. - if (!op.update_is_del_and_insert) { - if (GLOBAL) { - auto update_state = data_table.InitializeUpdate(table, context.client, op.bound_constraints); - data_table.Update(*update_state, context.client, row_ids, set_columns, update_chunk); - return update_chunk.size(); - } - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - local_storage.Update(data_table, row_ids, set_columns, update_chunk); - return update_chunk.size(); - } - - // Arrange the columns in the standard table order. 
- DataChunk &append_chunk = lstate.append_chunk; - append_chunk.SetCardinality(update_chunk); - for (idx_t i = 0; i < append_chunk.ColumnCount(); i++) { - append_chunk.data[i].Reference(chunk.data[i]); - } - for (idx_t i = 0; i < set_columns.size(); i++) { - append_chunk.data[set_columns[i].index].Reference(update_chunk.data[i]); - } - - if (GLOBAL) { - auto &delete_state = lstate.GetDeleteState(data_table, table, context.client); - data_table.Delete(delete_state, context.client, row_ids, update_chunk.size()); - } else { - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - local_storage.Delete(data_table, row_ids, update_chunk.size()); - } - - data_table.LocalAppend(table, context.client, append_chunk, op.bound_constraints, row_ids, append_chunk); - return update_chunk.size(); -} - -// TODO: should we use a hash table to keep track of this instead? -static void RegisterUpdatedRows(InsertLocalState &lstate, const Vector &row_ids, idx_t count) { - // Insert all rows, if any of the rows has already been updated before, we throw an error - auto data = FlatVector::GetData(row_ids); - - auto &updated_rows = lstate.updated_rows; - for (idx_t i = 0; i < count; i++) { - auto result = updated_rows.insert(data[i]); - if (result.second == false) { - // This is following postgres behavior: - throw InvalidInputException( - "ON CONFLICT DO UPDATE can not update the same row twice in the same command. Ensure that no rows " - "proposed for insertion within the same command have duplicate constrained values"); - } - } -} - -static void CheckDistinctnessInternal(ValidityMask &valid, vector> &sort_keys, idx_t count, - map> &result) { - for (idx_t i = 0; i < count; i++) { - bool has_conflicts = false; - for (idx_t j = i + 1; j < count; j++) { - if (!valid.RowIsValid(j)) { - // Already a conflict - continue; - } - bool matches = true; - for (auto &sort_key : sort_keys) { - auto &this_row = FlatVector::GetData(sort_key.get())[i]; - auto &other_row = FlatVector::GetData(sort_key.get())[j]; - if (this_row != other_row) { - matches = false; - break; - } - } - if (matches) { - auto &row_ids = result[i]; - has_conflicts = true; - row_ids.push_back(j); - valid.SetInvalid(j); - } - } - if (has_conflicts) { - valid.SetInvalid(i); - } - } -} - -void PrepareSortKeys(DataChunk &input, unordered_map> &sort_keys, - const unordered_set &column_ids) { - OrderModifiers order_modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - for (auto &it : column_ids) { - auto &sort_key = sort_keys[it]; - if (sort_key != nullptr) { - continue; - } - auto &column = input.data[it]; - sort_key = make_uniq(LogicalType::BLOB); - CreateSortKeyHelpers::CreateSortKey(column, input.size(), order_modifiers, *sort_key); - } -} - -static map> CheckDistinctness(DataChunk &input, ConflictInfo &info, - unordered_set &matched_indexes) { - map> conflicts; - unordered_map> sort_keys; - //! 
Register which rows have already caused a conflict - ValidityMask valid(input.size()); - - auto &column_ids = info.column_ids; - if (column_ids.empty()) { - for (auto index : matched_indexes) { - auto &index_column_ids = index->GetColumnIdSet(); - PrepareSortKeys(input, sort_keys, index_column_ids); - vector<reference<Vector>> columns; - for (auto &idx : index_column_ids) { - columns.push_back(*sort_keys[idx]); - } - CheckDistinctnessInternal(valid, columns, input.size(), conflicts); - } - } else { - PrepareSortKeys(input, sort_keys, column_ids); - vector<reference<Vector>> columns; - for (auto &idx : column_ids) { - columns.push_back(*sort_keys[idx]); - } - CheckDistinctnessInternal(valid, columns, input.size(), conflicts); - } - return conflicts; -} - -template <bool GLOBAL> -static void VerifyOnConflictCondition(ExecutionContext &context, DataChunk &combined_chunk, - const unique_ptr<Expression> &on_conflict_condition, - ConstraintState &constraint_state, DataChunk &tuples, DataTable &data_table, - LocalStorage &local_storage) { - if (!on_conflict_condition) { - return; - } - DataChunk conflict_condition_result; - CheckOnConflictCondition(context, combined_chunk, on_conflict_condition, conflict_condition_result); - bool conditions_met = AllConflictsMeetCondition(conflict_condition_result); - if (conditions_met) { - return; - } - - // We need to throw. Filter all tuples that passed, and verify again with those that violate the constraint. - ManagedSelection sel(combined_chunk.size()); - auto data = FlatVector::GetData<bool>(conflict_condition_result.data[0]); - for (idx_t i = 0; i < combined_chunk.size(); i++) { - if (!data[i]) { - // This tuple did not meet the condition. - sel.Append(i); - } - } - combined_chunk.Slice(sel.Selection(), sel.Count()); - - // Verify and throw. - if (GLOBAL) { - data_table.VerifyAppendConstraints(constraint_state, context.client, combined_chunk, nullptr, nullptr); - throw InternalException("VerifyAppendConstraints was expected to throw but didn't"); - } - - auto &indexes = local_storage.GetIndexes(data_table); - auto storage = local_storage.GetStorage(data_table); - DataTable::VerifyUniqueIndexes(indexes, storage, tuples, nullptr); - throw InternalException("VerifyUniqueIndexes was expected to throw but didn't"); -} - -template <bool GLOBAL> -static idx_t HandleInsertConflicts(TableCatalogEntry &table, ExecutionContext &context, InsertLocalState &lstate, - DataChunk &tuples, const PhysicalInsert &op) { - auto &types_to_fetch = op.types_to_fetch; - auto &on_conflict_condition = op.on_conflict_condition; - auto &conflict_target = op.conflict_target; - auto &columns_to_fetch = op.columns_to_fetch; - auto &data_table = table.GetStorage(); - - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - - ConflictInfo conflict_info(conflict_target); - ConflictManager conflict_manager(VerifyExistenceType::APPEND, tuples.size(), &conflict_info); - if (GLOBAL) { - auto &constraint_state = lstate.GetConstraintState(data_table, table); - auto storage = local_storage.GetStorage(data_table); - data_table.VerifyAppendConstraints(constraint_state, context.client, tuples, storage, &conflict_manager); - } else { - auto &indexes = local_storage.GetIndexes(data_table); - auto storage = local_storage.GetStorage(data_table); - DataTable::VerifyUniqueIndexes(indexes, storage, tuples, &conflict_manager); - } - - conflict_manager.Finalize(); - if (conflict_manager.ConflictCount() == 0) { - // No conflicts found, 0 updates performed - return 0; - } - idx_t affected_tuples = 0; - - auto &conflicts = conflict_manager.Conflicts(); - auto &row_ids = 
conflict_manager.RowIds(); - - DataChunk conflict_chunk; // contains only the conflicting values - DataChunk scan_chunk; // contains the original values, that caused the conflict - DataChunk combined_chunk; // contains conflict_chunk + scan_chunk (wide) - - // Filter out everything but the conflicting rows - conflict_chunk.Initialize(context.client, tuples.GetTypes()); - conflict_chunk.Reference(tuples); - conflict_chunk.Slice(conflicts.Selection(), conflicts.Count()); - conflict_chunk.SetCardinality(conflicts.Count()); - - // Holds the pins for the fetched rows - unique_ptr<ColumnFetchState> fetch_state; - if (!types_to_fetch.empty()) { - D_ASSERT(scan_chunk.size() == 0); - // When these values are required for the conditions or the SET expressions, - // then we scan the existing table for the conflicting tuples, using the rowids - scan_chunk.Initialize(context.client, types_to_fetch); - fetch_state = make_uniq<ColumnFetchState>(); - if (GLOBAL) { - auto &transaction = DuckTransaction::Get(context.client, table.catalog); - data_table.Fetch(transaction, scan_chunk, columns_to_fetch, row_ids, conflicts.Count(), *fetch_state); - } else { - local_storage.FetchChunk(data_table, row_ids, conflicts.Count(), columns_to_fetch, scan_chunk, - *fetch_state); - } - } - - // Splice the Input chunk and the fetched chunk together - CombineExistingAndInsertTuples(combined_chunk, scan_chunk, conflict_chunk, context.client, op); - - auto &constraint_state = lstate.GetConstraintState(data_table, table); - VerifyOnConflictCondition<GLOBAL>(context, combined_chunk, on_conflict_condition, constraint_state, tuples, - data_table, local_storage); - - if (&tuples == &lstate.update_chunk) { - // Allow updating duplicate rows for the 'update_chunk' - RegisterUpdatedRows(lstate, row_ids, combined_chunk.size()); - } - - affected_tuples += PerformOnConflictAction<GLOBAL>(lstate, context, combined_chunk, table, row_ids, op); - - // Remove the conflicting tuples from the insert chunk - SelectionVector sel_vec(tuples.size()); - idx_t new_size = SelectionVector::Inverted(conflicts.Selection(), sel_vec, conflicts.Count(), tuples.size()); - tuples.Slice(sel_vec, new_size); - tuples.SetCardinality(new_size); - return affected_tuples; -} - -idx_t PhysicalInsert::OnConflictHandling(TableCatalogEntry &table, ExecutionContext &context, InsertGlobalState &gstate, - InsertLocalState &lstate) const { - auto &data_table = table.GetStorage(); - auto &local_storage = LocalStorage::Get(context.client, data_table.db); - - if (action_type == OnConflictAction::THROW) { - auto &constraint_state = lstate.GetConstraintState(data_table, table); - auto storage = local_storage.GetStorage(data_table); - data_table.VerifyAppendConstraints(constraint_state, context.client, lstate.insert_chunk, storage, nullptr); - return 0; - } - - ConflictInfo conflict_info(conflict_target); - - auto &global_indexes = data_table.GetDataTableInfo()->GetIndexes(); - auto &local_indexes = local_storage.GetIndexes(data_table); - - unordered_set<BoundIndex *> matched_indexes; - if (conflict_info.column_ids.empty()) { - // We care about every index that applies to the table if no ON CONFLICT (...) 
target is given - global_indexes.Scan([&](Index &index) { - if (!index.IsUnique()) { - return false; - } - if (conflict_info.ConflictTargetMatches(index)) { - D_ASSERT(index.IsBound()); - auto &bound_index = index.Cast<BoundIndex>(); - matched_indexes.insert(&bound_index); - } - return false; - }); - local_indexes.Scan([&](Index &index) { - if (!index.IsUnique()) { - return false; - } - if (conflict_info.ConflictTargetMatches(index)) { - D_ASSERT(index.IsBound()); - auto &bound_index = index.Cast<BoundIndex>(); - matched_indexes.insert(&bound_index); - } - return false; - }); - } - - auto inner_conflicts = CheckDistinctness(lstate.insert_chunk, conflict_info, matched_indexes); - idx_t count = lstate.insert_chunk.size(); - if (!inner_conflicts.empty()) { - // We have at least one inner conflict, filter it out - ManagedSelection sel_vec(count); - ValidityMask not_a_conflict(count); - set<idx_t> last_occurrences_of_conflict; - for (idx_t i = 0; i < count; i++) { - auto it = inner_conflicts.find(i); - if (it != inner_conflicts.end()) { - auto &conflicts = it->second; - auto conflict_it = conflicts.begin(); - for (; conflict_it != conflicts.end();) { - auto &idx = *conflict_it; - not_a_conflict.SetInvalid(idx); - conflict_it++; - if (conflict_it == conflicts.end()) { - last_occurrences_of_conflict.insert(idx); - } - } - } - if (not_a_conflict.RowIsValid(i)) { - sel_vec.Append(i); - } - } - if (action_type == OnConflictAction::UPDATE) { - ManagedSelection last_occurrences(last_occurrences_of_conflict.size()); - for (auto &idx : last_occurrences_of_conflict) { - last_occurrences.Append(idx); - } - - lstate.update_chunk.Reference(lstate.insert_chunk); - lstate.update_chunk.Slice(last_occurrences.Selection(), last_occurrences.Count()); - lstate.update_chunk.SetCardinality(last_occurrences.Count()); - } - - lstate.insert_chunk.Slice(sel_vec.Selection(), sel_vec.Count()); - lstate.insert_chunk.SetCardinality(sel_vec.Count()); - } - - // Check whether any conflicts arise, and if they all meet the conflict_target + condition - // If that's not the case - We throw the first error - idx_t updated_tuples = 0; - updated_tuples += HandleInsertConflicts<true>(table, context, lstate, lstate.insert_chunk, *this); - // Also check the transaction-local storage+ART so we can detect conflicts within this transaction - updated_tuples += HandleInsertConflicts<false>(table, context, lstate, lstate.insert_chunk, *this); - - return updated_tuples; -} - -SinkResultType PhysicalInsert::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast<InsertGlobalState>(); - auto &lstate = input.local_state.Cast<InsertLocalState>(); - - auto &table = gstate.table; - auto &storage = table.GetStorage(); - PhysicalInsert::ResolveDefaults(table, chunk, column_index_map, lstate.default_executor, lstate.insert_chunk); - - if (!parallel) { - if (!gstate.initialized) { - storage.InitializeLocalAppend(gstate.append_state, table, context.client, bound_constraints); - gstate.initialized = true; - } - - if (action_type != OnConflictAction::NOTHING && return_chunk) { - // If the action is UPDATE or REPLACE, we will always create either an APPEND or an INSERT - // for NOTHING we don't create either an APPEND or an INSERT for the tuple - // so it should not be added to the RETURNING chunk - gstate.return_collection.Append(lstate.insert_chunk); - } - idx_t updated_tuples = OnConflictHandling(table, context, gstate, lstate); - if (action_type == OnConflictAction::NOTHING && return_chunk) { - // Because we didn't add to the RETURNING chunk yet - // we add the 
tuples that did not get filtered out now - gstate.return_collection.Append(lstate.insert_chunk); - } - gstate.insert_count += lstate.insert_chunk.size(); - gstate.insert_count += updated_tuples; - storage.LocalAppend(gstate.append_state, context.client, lstate.insert_chunk, true); - if (action_type == OnConflictAction::UPDATE && lstate.update_chunk.size() != 0) { - // Flush the append so we can target the data we just appended with the update - storage.FinalizeLocalAppend(gstate.append_state); - gstate.initialized = false; - (void)HandleInsertConflicts<true>(table, context, lstate, lstate.update_chunk, *this); - (void)HandleInsertConflicts<false>(table, context, lstate, lstate.update_chunk, *this); - // All of the tuples should have been turned into an update, leaving the chunk empty afterwards - D_ASSERT(lstate.update_chunk.size() == 0); - } - } else { - D_ASSERT(!return_chunk); - // parallel append - if (!lstate.local_collection) { - lock_guard<mutex> l(gstate.lock); - auto table_info = storage.GetDataTableInfo(); - auto &io_manager = TableIOManager::Get(table.GetStorage()); - lstate.local_collection = make_uniq<RowGroupCollection>(std::move(table_info), io_manager, insert_types, - NumericCast<idx_t>(MAX_ROW_ID)); - lstate.local_collection->InitializeEmpty(); - lstate.local_collection->InitializeAppend(lstate.local_append_state); - lstate.writer = &gstate.table.GetStorage().CreateOptimisticWriter(context.client); - } - OnConflictHandling(table, context, gstate, lstate); - D_ASSERT(action_type != OnConflictAction::UPDATE); - - auto new_row_group = lstate.local_collection->Append(lstate.insert_chunk, lstate.local_append_state); - if (new_row_group) { - lstate.writer->WriteNewRowGroup(*lstate.local_collection); - } - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalInsert::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &gstate = input.global_state.Cast<InsertGlobalState>(); - auto &lstate = input.local_state.Cast<InsertLocalState>(); - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - - if (!parallel || !lstate.local_collection) { - return SinkCombineResultType::FINISHED; - } - - auto &table = gstate.table; - auto &storage = table.GetStorage(); - const idx_t row_group_size = storage.GetRowGroupSize(); - - // parallel append: finalize the append - TransactionData tdata(0, 0); - lstate.local_collection->FinalizeAppend(tdata, lstate.local_append_state); - - auto append_count = lstate.local_collection->GetTotalRows(); - - lock_guard<mutex> lock(gstate.lock); - gstate.insert_count += append_count; - if (append_count < row_group_size) { - // we have few rows - append to the local storage directly - storage.InitializeLocalAppend(gstate.append_state, table, context.client, bound_constraints); - auto &transaction = DuckTransaction::Get(context.client, table.catalog); - lstate.local_collection->Scan(transaction, [&](DataChunk &insert_chunk) { - storage.LocalAppend(gstate.append_state, context.client, insert_chunk, false); - return true; - }); - storage.FinalizeLocalAppend(gstate.append_state); - } else { - // we have written rows to disk optimistically - merge directly into the transaction-local storage - lstate.writer->WriteLastRowGroup(*lstate.local_collection); - lstate.writer->FinalFlush(); - gstate.table.GetStorage().LocalMerge(context.client, *lstate.local_collection); - gstate.table.GetStorage().FinalizeOptimisticWriter(context.client, *lstate.writer); - } - - return SinkCombineResultType::FINISHED; -} - 
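The non-parallel Sink path above routes conflicting rows through HandleInsertConflicts twice, once against global storage (<true>) and once against the transaction-local storage (<false>), before appending what remains. A minimal sketch of the observable behavior through DuckDB's public C++ API; the table name and values are illustrative, not taken from this diff:

#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr); // in-memory database
	duckdb::Connection con(db);
	con.Query("CREATE TABLE t (id INTEGER PRIMARY KEY, v INTEGER)");
	con.Query("INSERT INTO t VALUES (1, 10)");
	// Conflict on id = 1: PerformOnConflictAction executes the SET expressions
	// and updates the existing row instead of appending a new one.
	con.Query("INSERT INTO t VALUES (1, 20) ON CONFLICT (id) DO UPDATE SET v = excluded.v");
	// OnConflictAction::NOTHING: the conflicting row is filtered out of the
	// insert chunk and no update chunk is produced.
	con.Query("INSERT INTO t VALUES (1, 30) ON CONFLICT DO NOTHING");
	con.Query("SELECT v FROM t WHERE id = 1")->Print(); // prints 20
	return 0;
}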
-SinkFinalizeType PhysicalInsert::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - auto &gstate = input.global_state.Cast(); - if (!parallel && gstate.initialized) { - auto &table = gstate.table; - auto &storage = table.GetStorage(); - storage.FinalizeLocalAppend(gstate.append_state); - } - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class InsertSourceState : public GlobalSourceState { -public: - explicit InsertSourceState(const PhysicalInsert &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalInsert::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalInsert::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &insert_gstate = sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(insert_gstate.insert_count))); - return SourceResultType::FINISHED; - } - - insert_gstate.return_collection.Scan(state.scan_state, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/persistent/physical_update.cpp b/src/duckdb/src/execution/operator/persistent/physical_update.cpp deleted file mode 100644 index 88d30c5e6..000000000 --- a/src/duckdb/src/execution/operator/persistent/physical_update.cpp +++ /dev/null @@ -1,260 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_update.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table/delete_state.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/storage/table/update_state.hpp" -#include "duckdb/transaction/duck_transaction.hpp" - -namespace duckdb { - -PhysicalUpdate::PhysicalUpdate(vector types, TableCatalogEntry &tableref, DataTable &table, - vector columns, vector> expressions, - vector> bound_defaults, - vector> bound_constraints, idx_t estimated_cardinality, - bool return_chunk) - : PhysicalOperator(PhysicalOperatorType::UPDATE, std::move(types), estimated_cardinality), tableref(tableref), - table(table), columns(std::move(columns)), expressions(std::move(expressions)), - bound_defaults(std::move(bound_defaults)), bound_constraints(std::move(bound_constraints)), - return_chunk(return_chunk), index_update(false) { - - auto &indexes = table.GetDataTableInfo().get()->GetIndexes(); - auto index_columns = indexes.GetRequiredColumns(); - - unordered_set update_columns; - for (const auto col : this->columns) { - update_columns.insert(col.index); - } - - for (const auto &col : table.Columns()) { - if (index_columns.find(col.Logical().index) == index_columns.end()) { - continue; - } - if 
(update_columns.find(col.Physical().index) == update_columns.end()) { - continue; - } - index_update = true; - break; - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class UpdateGlobalState : public GlobalSinkState { -public: - explicit UpdateGlobalState(ClientContext &context, const vector &return_types) - : updated_count(0), return_collection(context, return_types) { - } - - mutex lock; - idx_t updated_count; - unordered_set updated_rows; - ColumnDataCollection return_collection; -}; - -class UpdateLocalState : public LocalSinkState { -public: - UpdateLocalState(ClientContext &context, const vector> &expressions, - const vector &table_types, const vector> &bound_defaults, - const vector> &bound_constraints) - : default_executor(context, bound_defaults), bound_constraints(bound_constraints) { - - // Initialize the update chunk. - auto &allocator = Allocator::Get(context); - vector update_types; - update_types.reserve(expressions.size()); - for (auto &expr : expressions) { - update_types.push_back(expr->return_type); - } - update_chunk.Initialize(allocator, update_types); - - // Initialize the mock and delete chunk. - mock_chunk.Initialize(allocator, table_types); - delete_chunk.Initialize(allocator, table_types); - } - - DataChunk update_chunk; - DataChunk mock_chunk; - DataChunk delete_chunk; - ExpressionExecutor default_executor; - unique_ptr delete_state; - unique_ptr update_state; - const vector> &bound_constraints; - - TableDeleteState &GetDeleteState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { - if (!delete_state) { - delete_state = table.InitializeDelete(tableref, context, bound_constraints); - } - return *delete_state; - } - - TableUpdateState &GetUpdateState(DataTable &table, TableCatalogEntry &tableref, ClientContext &context) { - if (!update_state) { - update_state = table.InitializeUpdate(tableref, context, bound_constraints); - } - return *update_state; - } -}; - -SinkResultType PhysicalUpdate::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &g_state = input.global_state.Cast(); - auto &l_state = input.local_state.Cast(); - - chunk.Flatten(); - l_state.default_executor.SetChunk(chunk); - - DataChunk &update_chunk = l_state.update_chunk; - update_chunk.Reset(); - update_chunk.SetCardinality(chunk); - - for (idx_t i = 0; i < expressions.size(); i++) { - // Default expression, set to the default value of the column. - if (expressions[i]->GetExpressionType() == ExpressionType::VALUE_DEFAULT) { - l_state.default_executor.ExecuteExpression(columns[i].index, update_chunk.data[i]); - continue; - } - - D_ASSERT(expressions[i]->GetExpressionType() == ExpressionType::BOUND_REF); - auto &binding = expressions[i]->Cast(); - update_chunk.data[i].Reference(chunk.data[binding.index]); - } - - lock_guard glock(g_state.lock); - auto &row_ids = chunk.data[chunk.ColumnCount() - 1]; - DataChunk &mock_chunk = l_state.mock_chunk; - - // Regular in-place update. 
- if (!update_is_del_and_insert) { - if (return_chunk) { - mock_chunk.SetCardinality(update_chunk); - for (idx_t i = 0; i < columns.size(); i++) { - mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); - } - } - auto &update_state = l_state.GetUpdateState(table, tableref, context.client); - table.Update(update_state, context.client, row_ids, columns, update_chunk); - - if (return_chunk) { - g_state.return_collection.Append(mock_chunk); - } - g_state.updated_count += chunk.size(); - return SinkResultType::NEED_MORE_INPUT; - } - - // We update an index or a complex type, so we need to split the UPDATE into DELETE + INSERT. - - // Keep track of the rows that have not yet been deleted in this UPDATE. - // This is required since we might see the same row_id multiple times, e.g., - // during an UPDATE containing joins. - SelectionVector sel(update_chunk.size()); - idx_t update_count = 0; - auto row_id_data = FlatVector::GetData(row_ids); - - for (idx_t i = 0; i < update_chunk.size(); i++) { - auto row_id = row_id_data[i]; - if (g_state.updated_rows.find(row_id) == g_state.updated_rows.end()) { - g_state.updated_rows.insert(row_id); - sel.set_index(update_count++, i); - } - } - - // The update chunk now contains exactly those rows that we are deleting. - Vector del_row_ids(row_ids); - if (update_count != update_chunk.size()) { - update_chunk.Slice(sel, update_count); - del_row_ids.Slice(row_ids, sel, update_count); - } - - auto &delete_chunk = index_update ? l_state.delete_chunk : l_state.mock_chunk; - delete_chunk.Reset(); - delete_chunk.SetCardinality(update_count); - - if (index_update) { - auto &transaction = DuckTransaction::Get(context.client, table.db); - vector column_ids; - for (idx_t i = 0; i < table.ColumnCount(); i++) { - column_ids.emplace_back(i); - }; - // We need to fetch the previous index keys to add them to the delete index. - auto fetch_state = ColumnFetchState(); - table.Fetch(transaction, delete_chunk, column_ids, row_ids, update_count, fetch_state); - } - - auto &delete_state = l_state.GetDeleteState(table, tableref, context.client); - table.Delete(delete_state, context.client, del_row_ids, update_count); - - // Arrange the columns in the standard table order. 
- mock_chunk.SetCardinality(update_count); - for (idx_t i = 0; i < columns.size(); i++) { - mock_chunk.data[columns[i].index].Reference(update_chunk.data[i]); - } - - table.LocalAppend(tableref, context.client, mock_chunk, bound_constraints, del_row_ids, delete_chunk); - if (return_chunk) { - g_state.return_collection.Append(mock_chunk); - } - - g_state.updated_count += chunk.size(); - return SinkResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalUpdate::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, GetTypes()); -} - -unique_ptr PhysicalUpdate::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, expressions, table.GetTypes(), bound_defaults, - bound_constraints); -} - -SinkCombineResultType PhysicalUpdate::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &client_profiler = QueryProfiler::Get(context.client); - context.thread.profiler.Flush(*this); - client_profiler.Flush(context.thread.profiler); - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -class UpdateSourceState : public GlobalSourceState { -public: - explicit UpdateSourceState(const PhysicalUpdate &op) { - if (op.return_chunk) { - D_ASSERT(op.sink_state); - auto &g = op.sink_state->Cast(); - g.return_collection.InitializeScan(scan_state); - } - } - - ColumnDataScanState scan_state; -}; - -unique_ptr PhysicalUpdate::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*this); -} - -SourceResultType PhysicalUpdate::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &state = input.global_state.Cast(); - auto &g = sink_state->Cast(); - if (!return_chunk) { - chunk.SetCardinality(1); - chunk.SetValue(0, 0, Value::BIGINT(NumericCast(g.updated_count))); - return SourceResultType::FINISHED; - } - - g.return_collection.Scan(state.scan_state, chunk); - - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp b/src/duckdb/src/execution/operator/projection/physical_pivot.cpp deleted file mode 100644 index d3def4c0c..000000000 --- a/src/duckdb/src/execution/operator/projection/physical_pivot.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_pivot.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -PhysicalPivot::PhysicalPivot(vector types_p, unique_ptr child, - BoundPivotInfo bound_pivot_p) - : PhysicalOperator(PhysicalOperatorType::PIVOT, std::move(types_p), child->estimated_cardinality), - bound_pivot(std::move(bound_pivot_p)) { - children.push_back(std::move(child)); - for (idx_t p = 0; p < bound_pivot.pivot_values.size(); p++) { - auto entry = pivot_map.find(bound_pivot.pivot_values[p]); - if (entry != pivot_map.end()) { - continue; - } - pivot_map[bound_pivot.pivot_values[p]] = bound_pivot.group_count + p; - } - // extract the empty aggregate expressions - ArenaAllocator allocator(Allocator::DefaultAllocator()); - for (auto &aggr_expr : bound_pivot.aggregates) { - auto &aggr = aggr_expr->Cast(); - // for each aggregate, initialize an empty aggregate state and finalize it immediately - auto state = make_unsafe_uniq_array(aggr.function.state_size(aggr.function)); - aggr.function.initialize(aggr.function, state.get()); - Vector state_vector(Value::POINTER(CastPointerToValue(state.get()))); - Vector result_vector(aggr_expr->return_type); - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - aggr.function.finalize(state_vector, aggr_input_data, result_vector, 1, 0); - empty_aggregates.push_back(result_vector.GetValue(0)); - } -} - -OperatorResultType PhysicalPivot::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - // copy the groups as-is - input.Flatten(); - for (idx_t i = 0; i < bound_pivot.group_count; i++) { - chunk.data[i].Reference(input.data[i]); - } - auto pivot_column_lists = FlatVector::GetData(input.data.back()); - auto &pivot_column_values = ListVector::GetEntry(input.data.back()); - auto pivot_columns = FlatVector::GetData(pivot_column_values); - - // initialize all aggregate columns with the empty aggregate value - // if there are multiple aggregates the columns are in order of [AGGR1][AGGR2][AGGR1][AGGR2] - // so we need to alternate the empty_aggregate that we use - idx_t aggregate = 0; - for (idx_t c = bound_pivot.group_count; c < chunk.ColumnCount(); c++) { - chunk.data[c].Reference(empty_aggregates[aggregate]); - chunk.data[c].Flatten(input.size()); - aggregate++; - if (aggregate >= empty_aggregates.size()) { - aggregate = 0; - } - } - - // move the pivots to the given columns - for (idx_t r = 0; r < input.size(); r++) { - auto list = pivot_column_lists[r]; - for (idx_t l = 0; l < list.length; l++) { - // figure out the column value number of this list - auto &column_name = pivot_columns[list.offset + l]; - auto entry = pivot_map.find(column_name); - if (entry == pivot_map.end()) { - // column entry not found in map - that means this element is explicitly excluded from the pivot list - continue; - } - auto column_idx = entry->second; - for (idx_t aggr = 0; aggr < empty_aggregates.size(); aggr++) { - auto pivot_value_lists = FlatVector::GetData(input.data[bound_pivot.group_count + aggr]); - auto &pivot_value_child = 
ListVector::GetEntry(input.data[bound_pivot.group_count + aggr]); - if (list.length != pivot_value_lists[r].length) { - throw InternalException("Pivot - unaligned lists between values and columns!?"); - } - chunk.data[column_idx + aggr].SetValue(r, pivot_value_child.GetValue(pivot_value_lists[r].offset + l)); - } - } - } - chunk.SetCardinality(input.size()); - return OperatorResultType::NEED_MORE_INPUT; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_projection.cpp b/src/duckdb/src/execution/operator/projection/physical_projection.cpp deleted file mode 100644 index 5d6dcb13f..000000000 --- a/src/duckdb/src/execution/operator/projection/physical_projection.cpp +++ /dev/null @@ -1,87 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -namespace duckdb { - -class ProjectionState : public OperatorState { -public: - explicit ProjectionState(ExecutionContext &context, const vector> &expressions) - : executor(context.client, expressions) { - } - - ExpressionExecutor executor; - -public: - void Finalize(const PhysicalOperator &op, ExecutionContext &context) override { - context.thread.profiler.Flush(op); - } -}; - -PhysicalProjection::PhysicalProjection(vector types, vector> select_list, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::PROJECTION, std::move(types), estimated_cardinality), - select_list(std::move(select_list)) { -} - -OperatorResultType PhysicalProjection::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - state.executor.Execute(input, chunk); - return OperatorResultType::NEED_MORE_INPUT; -} - -unique_ptr PhysicalProjection::GetOperatorState(ExecutionContext &context) const { - return make_uniq(context, select_list); -} - -unique_ptr -PhysicalProjection::CreateJoinProjection(vector proj_types, const vector &lhs_types, - const vector &rhs_types, const vector &left_projection_map, - const vector &right_projection_map, const idx_t estimated_cardinality) { - - vector> proj_selects; - proj_selects.reserve(proj_types.size()); - - if (left_projection_map.empty()) { - for (storage_t i = 0; i < lhs_types.size(); ++i) { - proj_selects.emplace_back(make_uniq(lhs_types[i], i)); - } - } else { - for (auto i : left_projection_map) { - proj_selects.emplace_back(make_uniq(lhs_types[i], i)); - } - } - const auto left_cols = lhs_types.size(); - - if (right_projection_map.empty()) { - for (storage_t i = 0; i < rhs_types.size(); ++i) { - proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); - } - - } else { - for (auto i : right_projection_map) { - proj_selects.emplace_back(make_uniq(rhs_types[i], left_cols + i)); - } - } - - return make_uniq(std::move(proj_types), std::move(proj_selects), estimated_cardinality); -} - -InsertionOrderPreservingMap PhysicalProjection::ParamsToString() const { - InsertionOrderPreservingMap result; - string projections; - for (idx_t i = 0; i < select_list.size(); i++) { - if (i > 0) { - projections += "\n"; - } - auto &expr = select_list[i]; - projections += expr->GetName(); - } - result["__projections__"] = projections; - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp b/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp deleted file mode 100644 index fa150693e..000000000 --- a/src/duckdb/src/execution/operator/projection/physical_tableinout_function.cpp +++ /dev/null @@ -1,138 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_tableinout_function.hpp" - -namespace duckdb { - -class TableInOutLocalState : public OperatorState { -public: - TableInOutLocalState() : row_index(0), new_row(true) { - } - - unique_ptr local_state; - idx_t row_index; - bool new_row; - DataChunk input_chunk; -}; - -class TableInOutGlobalState : public GlobalOperatorState { -public: - TableInOutGlobalState() { - } - - unique_ptr global_state; -}; - -PhysicalTableInOutFunction::PhysicalTableInOutFunction(vector types, TableFunction function_p, - unique_ptr bind_data_p, - vector column_ids_p, idx_t estimated_cardinality, - vector project_input_p) - : PhysicalOperator(PhysicalOperatorType::INOUT_FUNCTION, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)), column_ids(std::move(column_ids_p)), - projected_input(std::move(project_input_p)) { -} - -unique_ptr PhysicalTableInOutFunction::GetOperatorState(ExecutionContext &context) const { - auto &gstate = op_state->Cast(); - auto result = make_uniq(); - if (function.init_local) { - TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); - result->local_state = function.init_local(context, input, gstate.global_state.get()); - } - if (!projected_input.empty()) { - vector input_types; - auto &child_types = children[0]->types; - idx_t input_length = child_types.size() - projected_input.size(); - for (idx_t k = 0; k < input_length; k++) { - input_types.push_back(child_types[k]); - } - for (idx_t k = 0; k < projected_input.size(); k++) { - D_ASSERT(projected_input[k] >= input_length); - } - result->input_chunk.Initialize(context.client, input_types); - } - return std::move(result); -} - -unique_ptr PhysicalTableInOutFunction::GetGlobalOperatorState(ClientContext &context) const { - auto result = make_uniq(); - if (function.init_global) { - TableFunctionInitInput input(bind_data.get(), column_ids, vector(), nullptr); - result->global_state = function.init_global(context, input); - } - return std::move(result); -} - -OperatorResultType PhysicalTableInOutFunction::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate_p, OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - if (projected_input.empty()) { - // straightforward case - no need to project input - return function.in_out_function(context, data, input, chunk); - } - // when project_input is set we execute the input function row-by-row - if (state.new_row) { - if (state.row_index >= input.size()) { - // finished processing this chunk - state.new_row = true; - state.row_index = 0; - return OperatorResultType::NEED_MORE_INPUT; - } - // we are processing a new row: fetch the data for the current row - state.input_chunk.Reset(); - // set up the input data to the table in-out function - for (idx_t col_idx = 0; col_idx < state.input_chunk.ColumnCount(); col_idx++) { - ConstantVector::Reference(state.input_chunk.data[col_idx], input.data[col_idx], state.row_index, 1); - } - state.input_chunk.SetCardinality(1); - 
state.row_index++; - state.new_row = false; - } - // set up the output data in "chunk" - D_ASSERT(chunk.ColumnCount() > projected_input.size()); - D_ASSERT(state.row_index > 0); - idx_t base_idx = chunk.ColumnCount() - projected_input.size(); - for (idx_t project_idx = 0; project_idx < projected_input.size(); project_idx++) { - auto source_idx = projected_input[project_idx]; - auto target_idx = base_idx + project_idx; - ConstantVector::Reference(chunk.data[target_idx], input.data[source_idx], state.row_index - 1, 1); - } - auto result = function.in_out_function(context, data, state.input_chunk, chunk); - if (result == OperatorResultType::FINISHED) { - return result; - } - if (result == OperatorResultType::NEED_MORE_INPUT) { - // we finished processing this row: move to the next row - state.new_row = true; - } - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -InsertionOrderPreservingMap PhysicalTableInOutFunction::ParamsToString() const { - InsertionOrderPreservingMap result; - if (function.to_string) { - TableFunctionToStringInput input(function, bind_data.get()); - auto to_string_result = function.to_string(input); - for (const auto &it : to_string_result) { - result[it.first] = it.second; - } - } else { - result["Name"] = function.name; - } - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -OperatorFinalizeResultType PhysicalTableInOutFunction::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate_p, - OperatorState &state_p) const { - auto &gstate = gstate_p.Cast(); - auto &state = state_p.Cast(); - if (!projected_input.empty()) { - throw InternalException("FinalExecute not supported for project_input"); - } - TableFunctionInput data(bind_data.get(), state.local_state.get(), gstate.global_state.get()); - return function.in_out_function_final(context, data, chunk); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp b/src/duckdb/src/execution/operator/projection/physical_unnest.cpp deleted file mode 100644 index 4f32fd291..000000000 --- a/src/duckdb/src/execution/operator/projection/physical_unnest.cpp +++ /dev/null @@ -1,387 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_unnest.hpp" - -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_unnest_expression.hpp" - -namespace duckdb { - -class UnnestOperatorState : public OperatorState { -public: - UnnestOperatorState(ClientContext &context, const vector> &select_list) - : current_row(0), list_position(0), longest_list_length(DConstants::INVALID_INDEX), first_fetch(true), - executor(context) { - - // for each UNNEST in the select_list, we add the child expression to the expression executor - // and set the return type in the list_data chunk, which will contain the evaluated expression results - vector list_data_types; - for (auto &exp : select_list) { - D_ASSERT(exp->GetExpressionType() == ExpressionType::BOUND_UNNEST); - auto &bue = exp->Cast(); - list_data_types.push_back(bue.child->return_type); - executor.AddExpression(*bue.child.get()); - } - - auto &allocator = Allocator::Get(context); - list_data.Initialize(allocator, list_data_types); - - list_vector_data.resize(list_data.ColumnCount()); - list_child_data.resize(list_data.ColumnCount()); - } - 
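// The fields below form the scan state for UNNEST: current_row is the input
// row currently being expanded, list_position is the offset already emitted
// within that row's longest list, and longest_list_length caches how many
// output rows the row produces in total (recomputed per row, with
// DConstants::INVALID_INDEX as the "not yet computed" sentinel).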
- idx_t current_row; - idx_t list_position; - idx_t longest_list_length; - bool first_fetch; - - ExpressionExecutor executor; - DataChunk list_data; - vector list_vector_data; - vector list_child_data; - -public: - //! Reset the fields of the unnest operator state - void Reset(); - //! Set the longest list's length for the current row - void SetLongestListLength(); -}; - -void UnnestOperatorState::Reset() { - current_row = 0; - list_position = 0; - longest_list_length = DConstants::INVALID_INDEX; - first_fetch = true; -} - -void UnnestOperatorState::SetLongestListLength() { - longest_list_length = 0; - for (idx_t col_idx = 0; col_idx < list_data.ColumnCount(); col_idx++) { - - auto &vector_data = list_vector_data[col_idx]; - auto current_idx = vector_data.sel->get_index(current_row); - - if (vector_data.validity.RowIsValid(current_idx)) { - - // check if this list is longer - auto list_data_entries = UnifiedVectorFormat::GetData(vector_data); - auto list_entry = list_data_entries[current_idx]; - if (list_entry.length > longest_list_length) { - longest_list_length = list_entry.length; - } - } - } -} - -PhysicalUnnest::PhysicalUnnest(vector types, vector> select_list, - idx_t estimated_cardinality, PhysicalOperatorType type) - : PhysicalOperator(type, std::move(types), estimated_cardinality), select_list(std::move(select_list)) { - D_ASSERT(!this->select_list.empty()); -} - -static void UnnestNull(idx_t start, idx_t end, Vector &result) { - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto &validity = FlatVector::Validity(result); - for (idx_t i = start; i < end; i++) { - validity.SetInvalid(i); - } - - const auto &logical_type = result.GetType(); - if (logical_type.InternalType() == PhysicalType::STRUCT) { - const auto &struct_children = StructVector::GetEntries(result); - for (auto &child : struct_children) { - UnnestNull(start, end, *child); - } - } else if (logical_type.InternalType() == PhysicalType::ARRAY) { - auto &array_child = ArrayVector::GetEntry(result); - auto array_size = ArrayType::GetSize(logical_type); - UnnestNull(start * array_size, end * array_size, array_child); - } -} - -template -static void TemplatedUnnest(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { - - auto source_data = UnifiedVectorFormat::GetData(vector_data); - auto &source_mask = vector_data.validity; - - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = start; i < end; i++) { - auto source_idx = vector_data.sel->get_index(i); - auto target_idx = i - start; - if (source_mask.RowIsValid(source_idx)) { - result_data[target_idx] = source_data[source_idx]; - result_mask.SetValid(target_idx); - } else { - result_mask.SetInvalid(target_idx); - } - } -} - -static void UnnestValidity(UnifiedVectorFormat &vector_data, idx_t start, idx_t end, Vector &result) { - - auto &source_mask = vector_data.validity; - D_ASSERT(result.GetVectorType() == VectorType::FLAT_VECTOR); - auto &result_mask = FlatVector::Validity(result); - - for (idx_t i = start; i < end; i++) { - auto source_idx = vector_data.sel->get_index(i); - auto target_idx = i - start; - result_mask.Set(target_idx, source_mask.RowIsValid(source_idx)); - } -} - -static void UnnestVector(UnifiedVectorFormat &child_vector_data, Vector &child_vector, idx_t list_size, idx_t start, - idx_t end, Vector &result) { - - D_ASSERT(child_vector.GetType() == result.GetType()); - switch 
(result.GetType().InternalType()) {
-	case PhysicalType::BOOL:
-	case PhysicalType::INT8:
-		TemplatedUnnest<int8_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::INT16:
-		TemplatedUnnest<int16_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::INT32:
-		TemplatedUnnest<int32_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::INT64:
-		TemplatedUnnest<int64_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::INT128:
-		TemplatedUnnest<hugeint_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::UINT8:
-		TemplatedUnnest<uint8_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::UINT16:
-		TemplatedUnnest<uint16_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::UINT32:
-		TemplatedUnnest<uint32_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::UINT64:
-		TemplatedUnnest<uint64_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::UINT128:
-		TemplatedUnnest<uhugeint_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::FLOAT:
-		TemplatedUnnest<float>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::DOUBLE:
-		TemplatedUnnest<double>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::INTERVAL:
-		TemplatedUnnest<interval_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::VARCHAR:
-		TemplatedUnnest<string_t>(child_vector_data, start, end, result);
-		break;
-	case PhysicalType::LIST: {
-		// the child vector of result now references the child vector source
-		// FIXME: only reference relevant children (start - end) instead of all
-		auto &target = ListVector::GetEntry(result);
-		target.Reference(ListVector::GetEntry(child_vector));
-		ListVector::SetListSize(result, ListVector::GetListSize(child_vector));
-		// unnest
-		TemplatedUnnest<list_entry_t>(child_vector_data, start, end, result);
-		break;
-	}
-	case PhysicalType::STRUCT: {
-		auto &child_vector_entries = StructVector::GetEntries(child_vector);
-		auto &result_entries = StructVector::GetEntries(result);
-
-		// set the validity mask for the 'outer' struct vector before unnesting its children
-		UnnestValidity(child_vector_data, start, end, result);
-
-		for (idx_t i = 0; i < child_vector_entries.size(); i++) {
-			UnifiedVectorFormat child_vector_entries_data;
-			child_vector_entries[i]->ToUnifiedFormat(list_size, child_vector_entries_data);
-			UnnestVector(child_vector_entries_data, *child_vector_entries[i], list_size, start, end,
-			             *result_entries[i]);
-		}
-		break;
-	}
-	case PhysicalType::ARRAY: {
-		auto array_size = ArrayType::GetSize(child_vector.GetType());
-		auto &source_array = ArrayVector::GetEntry(child_vector);
-		auto &target_array = ArrayVector::GetEntry(result);
-
-		UnnestValidity(child_vector_data, start, end, result);
-
-		UnifiedVectorFormat child_array_data;
-		source_array.ToUnifiedFormat(list_size * array_size, child_array_data);
-		UnnestVector(child_array_data, source_array, list_size * array_size, start * array_size, end * array_size,
-		             target_array);
-		break;
-	}
-	default:
-		throw InternalException("Unimplemented type for UNNEST.");
-	}
-}
-
-static void PrepareInput(UnnestOperatorState &state, DataChunk &input,
-                         const vector<unique_ptr<Expression>> &select_list) {
-
-	state.list_data.Reset();
-	// execute the expressions inside each UNNEST in the select_list to get the list data
-	// execution results (lists) are kept in state.list_data chunk
-	state.executor.Execute(input, state.list_data);
-
-	// verify incoming lists
-	state.list_data.Verify();
-	D_ASSERT(input.size() == state.list_data.size());
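// UnifiedVectorFormat, used heavily above, is DuckDB's uniform read view over
// any vector layout: a selection vector mapping logical row to physical index,
// a validity mask, and a typed data pointer. The canonical access pattern is
// roughly the following (a sketch, assuming a LIST vector):
//
//	UnifiedVectorFormat format;
//	vec.ToUnifiedFormat(count, format);
//	auto entries = UnifiedVectorFormat::GetData<list_entry_t>(format);
//	for (idx_t i = 0; i < count; i++) {
//		auto idx = format.sel->get_index(i);
//		if (format.validity.RowIsValid(idx)) {
//			auto entry = entries[idx]; // entry.offset / entry.length into the child vector
//		}
//	}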
D_ASSERT(state.list_data.ColumnCount() == select_list.size()); - D_ASSERT(state.list_vector_data.size() == state.list_data.ColumnCount()); - D_ASSERT(state.list_child_data.size() == state.list_data.ColumnCount()); - - // get the UnifiedVectorFormat of each list_data vector (LIST vectors for the different UNNESTs) - // both for the vector itself and its child vector - for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { - - auto &list_vector = state.list_data.data[col_idx]; - list_vector.ToUnifiedFormat(state.list_data.size(), state.list_vector_data[col_idx]); - - if (list_vector.GetType() == LogicalType::SQLNULL) { - // UNNEST(NULL): SQLNULL vectors don't have child vectors, but we need to point to the child vector of - // each vector, so we just get the UnifiedVectorFormat of the vector itself - auto &child_vector = list_vector; - child_vector.ToUnifiedFormat(0, state.list_child_data[col_idx]); - } else { - auto list_size = ListVector::GetListSize(list_vector); - auto &child_vector = ListVector::GetEntry(list_vector); - child_vector.ToUnifiedFormat(list_size, state.list_child_data[col_idx]); - } - } - - state.first_fetch = false; -} - -unique_ptr PhysicalUnnest::GetOperatorState(ExecutionContext &context) const { - return PhysicalUnnest::GetState(context, select_list); -} - -unique_ptr PhysicalUnnest::GetState(ExecutionContext &context, - const vector> &select_list) { - return make_uniq(context.client, select_list); -} - -OperatorResultType PhysicalUnnest::ExecuteInternal(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - OperatorState &state_p, - const vector> &select_list, - bool include_input) { - - auto &state = state_p.Cast(); - - do { - // reset validities, if previous loop iteration contained UNNEST(NULL) - if (include_input) { - chunk.Reset(); - } - - // prepare the input data by executing any expressions and getting the - // UnifiedVectorFormat of each LIST vector (list_vector_data) and its child vector (list_child_data) - if (state.first_fetch) { - PrepareInput(state, input, select_list); - } - - // finished with all rows of this input chunk, reset - if (state.current_row >= input.size()) { - state.Reset(); - return OperatorResultType::NEED_MORE_INPUT; - } - - // each UNNEST in the select_list contains a list (or NULL) for this row, find the longest list - // because this length determines how many times we need to repeat for the current row - if (state.longest_list_length == DConstants::INVALID_INDEX) { - state.SetLongestListLength(); - } - D_ASSERT(state.longest_list_length != DConstants::INVALID_INDEX); - - // we emit chunks of either STANDARD_VECTOR_SIZE or smaller - auto this_chunk_len = MinValue(STANDARD_VECTOR_SIZE, state.longest_list_length - state.list_position); - chunk.SetCardinality(this_chunk_len); - - // if we include other projection input columns, e.g. SELECT 1, UNNEST([1, 2]);, then - // we need to add them as a constant vector to the resulting chunk - // FIXME: emit multiple unnested rows. 
Currently, we never emit a chunk containing multiple unnested input rows, - // so setting a constant vector for the value at state.current_row is fine - idx_t col_offset = 0; - if (include_input) { - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - ConstantVector::Reference(chunk.data[col_idx], input.data[col_idx], state.current_row, input.size()); - } - col_offset = input.ColumnCount(); - } - - // unnest the lists - for (idx_t col_idx = 0; col_idx < state.list_data.ColumnCount(); col_idx++) { - - auto &result_vector = chunk.data[col_idx + col_offset]; - - if (state.list_data.data[col_idx].GetType() == LogicalType::SQLNULL) { - // UNNEST(NULL) - chunk.SetCardinality(0); - break; - } - - auto &vector_data = state.list_vector_data[col_idx]; - auto current_idx = vector_data.sel->get_index(state.current_row); - - if (!vector_data.validity.RowIsValid(current_idx)) { - UnnestNull(0, this_chunk_len, result_vector); - continue; - } - - auto list_data = UnifiedVectorFormat::GetData(vector_data); - auto list_entry = list_data[current_idx]; - - idx_t list_count = 0; - if (state.list_position < list_entry.length) { - // there are still list_count elements to unnest - list_count = MinValue(this_chunk_len, list_entry.length - state.list_position); - - auto &list_vector = state.list_data.data[col_idx]; - auto &child_vector = ListVector::GetEntry(list_vector); - auto list_size = ListVector::GetListSize(list_vector); - auto &child_vector_data = state.list_child_data[col_idx]; - - auto base_offset = list_entry.offset + state.list_position; - UnnestVector(child_vector_data, child_vector, list_size, base_offset, base_offset + list_count, - result_vector); - } - - // fill the rest with NULLs - if (list_count != this_chunk_len) { - UnnestNull(list_count, this_chunk_len, result_vector); - } - } - - chunk.Verify(); - - state.list_position += this_chunk_len; - if (state.list_position == state.longest_list_length) { - state.current_row++; - state.longest_list_length = DConstants::INVALID_INDEX; - state.list_position = 0; - } - - // we only emit one unnested row (that contains data) at a time - } while (chunk.size() == 0); - return OperatorResultType::HAVE_MORE_OUTPUT; -} - -OperatorResultType PhysicalUnnest::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &, OperatorState &state) const { - return ExecuteInternal(context, input, chunk, state, select_list); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp deleted file mode 100644 index e864f3a2e..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_column_data_scan.cpp +++ /dev/null @@ -1,130 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_delim_join.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" - -namespace duckdb { - -PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, PhysicalOperatorType op_type, - idx_t estimated_cardinality, - optionally_owned_ptr collection_p) - : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(std::move(collection_p)), - cte_index(DConstants::INVALID_INDEX) { -} - -PhysicalColumnDataScan::PhysicalColumnDataScan(vector types, 
PhysicalOperatorType op_type, - idx_t estimated_cardinality, idx_t cte_index) - : PhysicalOperator(op_type, std::move(types), estimated_cardinality), collection(nullptr), cte_index(cte_index) { -} - -class PhysicalColumnDataGlobalScanState : public GlobalSourceState { -public: - explicit PhysicalColumnDataGlobalScanState(const ColumnDataCollection &collection) - : max_threads(MaxValue(collection.ChunkCount(), 1)) { - collection.InitializeScan(global_scan_state); - } - - idx_t MaxThreads() override { - return max_threads; - } - -public: - ColumnDataParallelScanState global_scan_state; - - const idx_t max_threads; -}; - -class PhysicalColumnDataLocalScanState : public LocalSourceState { -public: - ColumnDataLocalScanState local_scan_state; -}; - -unique_ptr PhysicalColumnDataScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(*collection); -} - -unique_ptr PhysicalColumnDataScan::GetLocalSourceState(ExecutionContext &, - GlobalSourceState &) const { - return make_uniq(); -} - -SourceResultType PhysicalColumnDataScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - collection->Scan(gstate.global_scan_state, lstate.local_scan_state, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalColumnDataScan::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - // check if there is any additional action we need to do depending on the type - auto &state = meta_pipeline.GetState(); - switch (type) { - case PhysicalOperatorType::DELIM_SCAN: { - auto entry = state.delim_join_dependencies.find(*this); - D_ASSERT(entry != state.delim_join_dependencies.end()); - // this chunk scan introduces a dependency to the current pipeline - // namely a dependency on the duplicate elimination pipeline to finish - auto delim_dependency = entry->second.get().shared_from_this(); - auto delim_sink = state.GetPipelineSink(*delim_dependency); - D_ASSERT(delim_sink); - D_ASSERT(delim_sink->type == PhysicalOperatorType::LEFT_DELIM_JOIN || - delim_sink->type == PhysicalOperatorType::RIGHT_DELIM_JOIN); - auto &delim_join = delim_sink->Cast(); - current.AddDependency(delim_dependency); - state.SetPipelineSource(current, delim_join.distinct->Cast()); - return; - } - case PhysicalOperatorType::CTE_SCAN: { - auto entry = state.cte_dependencies.find(*this); - D_ASSERT(entry != state.cte_dependencies.end()); - // this chunk scan introduces a dependency to the current pipeline - // namely a dependency on the CTE pipeline to finish - auto cte_dependency = entry->second.get().shared_from_this(); - auto cte_sink = state.GetPipelineSink(*cte_dependency); - (void)cte_sink; - D_ASSERT(cte_sink); - D_ASSERT(cte_sink->type == PhysicalOperatorType::CTE); - current.AddDependency(cte_dependency); - state.SetPipelineSource(current, *this); - return; - } - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: - if (!meta_pipeline.HasRecursiveCTE()) { - throw InternalException("Recursive CTE scan found without recursive CTE node"); - } - break; - default: - break; - } - D_ASSERT(children.empty()); - state.SetPipelineSource(current, *this); -} - -InsertionOrderPreservingMap PhysicalColumnDataScan::ParamsToString() const { - 
InsertionOrderPreservingMap result; - switch (type) { - case PhysicalOperatorType::DELIM_SCAN: - if (delim_index.IsValid()) { - result["Delim Index"] = StringUtil::Format("%llu", delim_index.GetIndex()); - } - break; - case PhysicalOperatorType::CTE_SCAN: - case PhysicalOperatorType::RECURSIVE_CTE_SCAN: { - result["CTE Index"] = StringUtil::Format("%llu", cte_index); - break; - } - default: - break; - } - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp deleted file mode 100644 index 1a620803b..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_dummy_scan.cpp +++ /dev/null @@ -1,13 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" - -namespace duckdb { - -SourceResultType PhysicalDummyScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // return a single row on the first call to the dummy scan - chunk.SetCardinality(1); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp b/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp deleted file mode 100644 index 2e7d006bf..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_empty_result.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_empty_result.hpp" - -namespace duckdb { - -SourceResultType PhysicalEmptyResult::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp deleted file mode 100644 index c0e91ae2b..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_expression_scan.cpp +++ /dev/null @@ -1,78 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" - -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/parallel/thread_context.hpp" - -namespace duckdb { - -class ExpressionScanState : public OperatorState { -public: - explicit ExpressionScanState(Allocator &allocator, const PhysicalExpressionScan &op) : expression_index(0) { - temp_chunk.Initialize(allocator, op.GetTypes()); - } - - //! The current position in the scan - idx_t expression_index; - //! 
Temporary chunk for evaluating expressions - DataChunk temp_chunk; -}; - -unique_ptr PhysicalExpressionScan::GetOperatorState(ExecutionContext &context) const { - return make_uniq(Allocator::Get(context.client), *this); -} - -OperatorResultType PhysicalExpressionScan::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - - for (; chunk.size() + input.size() <= STANDARD_VECTOR_SIZE && state.expression_index < expressions.size(); - state.expression_index++) { - state.temp_chunk.Reset(); - EvaluateExpression(context.client, state.expression_index, &input, chunk, &state.temp_chunk); - } - if (state.expression_index < expressions.size()) { - return OperatorResultType::HAVE_MORE_OUTPUT; - } else { - state.expression_index = 0; - return OperatorResultType::NEED_MORE_INPUT; - } -} - -void PhysicalExpressionScan::EvaluateExpression(ClientContext &context, idx_t expression_idx, - optional_ptr child_chunk, DataChunk &result, - optional_ptr temp_chunk_ptr) const { - if (temp_chunk_ptr) { - EvaluateExpressionInternal(context, expression_idx, child_chunk, result, *temp_chunk_ptr); - } else { - DataChunk temp_chunk; - temp_chunk.Initialize(Allocator::Get(context), GetTypes()); - EvaluateExpressionInternal(context, expression_idx, child_chunk, result, temp_chunk); - } -} - -void PhysicalExpressionScan::EvaluateExpressionInternal(ClientContext &context, idx_t expression_idx, - optional_ptr child_chunk, DataChunk &result, - DataChunk &temp_chunk) const { - ExpressionExecutor executor(context, expressions[expression_idx]); - if (child_chunk) { - child_chunk->Verify(); - executor.Execute(*child_chunk, temp_chunk); - } else { - executor.Execute(temp_chunk); - } - // Need to append because "executor" might be holding state (e.g., strings), which go out of scope here - result.Append(temp_chunk); -} - -bool PhysicalExpressionScan::IsFoldable() const { - for (auto &expr_list : expressions) { - for (auto &expr : expr_list) { - if (!expr->IsFoldable()) { - return false; - } - } - } - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp deleted file mode 100644 index c1e2707b2..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_positional_scan.cpp +++ /dev/null @@ -1,220 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parallel/interrupt.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/transaction/transaction.hpp" - -#include - -namespace duckdb { - -PhysicalPositionalScan::PhysicalPositionalScan(vector types, unique_ptr left, - unique_ptr right) - : PhysicalOperator(PhysicalOperatorType::POSITIONAL_SCAN, std::move(types), - MaxValue(left->estimated_cardinality, right->estimated_cardinality)) { - - // Manage the children ourselves - if (left->type == PhysicalOperatorType::TABLE_SCAN) { - child_tables.emplace_back(std::move(left)); - } else if (left->type == PhysicalOperatorType::POSITIONAL_SCAN) { - auto &left_scan = left->Cast(); - child_tables = std::move(left_scan.child_tables); - } else { - throw InternalException("Invalid left input for PhysicalPositionalScan"); - } - - if (right->type == PhysicalOperatorType::TABLE_SCAN) { - 
child_tables.emplace_back(std::move(right)); - } else if (right->type == PhysicalOperatorType::POSITIONAL_SCAN) { - auto &right_scan = right->Cast(); - auto &right_tables = right_scan.child_tables; - child_tables.reserve(child_tables.size() + right_tables.size()); - std::move(right_tables.begin(), right_tables.end(), std::back_inserter(child_tables)); - } else { - throw InternalException("Invalid right input for PhysicalPositionalScan"); - } -} - -class PositionalScanGlobalSourceState : public GlobalSourceState { -public: - PositionalScanGlobalSourceState(ClientContext &context, const PhysicalPositionalScan &op) { - for (const auto &table : op.child_tables) { - global_states.emplace_back(table->GetGlobalSourceState(context)); - } - } - - vector> global_states; - - idx_t MaxThreads() override { - return 1; - } -}; - -class PositionalTableScanner { -public: - PositionalTableScanner(ExecutionContext &context, PhysicalOperator &table_p, GlobalSourceState &gstate_p) - : table(table_p), global_state(gstate_p), source_offset(0), exhausted(false) { - local_state = table.GetLocalSourceState(context, gstate_p); - source.Initialize(Allocator::Get(context.client), table.types); - } - - idx_t Refill(ExecutionContext &context) { - if (source_offset >= source.size()) { - if (!exhausted) { - source.Reset(); - - InterruptState interrupt_state; - OperatorSourceInput source_input {global_state, *local_state, interrupt_state}; - auto source_result = table.GetData(context, source, source_input); - if (source_result == SourceResultType::BLOCKED) { - throw NotImplementedException( - "Unexpected interrupt from table Source in PositionalTableScanner refill"); - } - } - source_offset = 0; - } - - const auto available = source.size() - source_offset; - if (!available) { - if (!exhausted) { - source.Reset(); - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - auto &vec = source.data[i]; - vec.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(vec, true); - } - exhausted = true; - } - } - - return available; - } - - idx_t CopyData(ExecutionContext &context, DataChunk &output, const idx_t count, const idx_t col_offset) { - if (!source_offset && (source.size() >= count || exhausted)) { - // Fast track: aligned and has enough data - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - output.data[col_offset + i].Reference(source.data[i]); - } - source_offset += count; - } else { - // Copy data - for (idx_t target_offset = 0; target_offset < count;) { - const auto needed = count - target_offset; - const auto available = exhausted ? 
needed : (source.size() - source_offset); - const auto copy_size = MinValue(needed, available); - const auto source_count = source_offset + copy_size; - for (idx_t i = 0; i < source.ColumnCount(); ++i) { - VectorOperations::Copy(source.data[i], output.data[col_offset + i], source_count, source_offset, - target_offset); - } - target_offset += copy_size; - source_offset += copy_size; - Refill(context); - } - } - - return source.ColumnCount(); - } - - ProgressData GetProgress(ClientContext &context) { - return table.GetProgress(context, global_state); - } - - PhysicalOperator &table; - GlobalSourceState &global_state; - unique_ptr local_state; - DataChunk source; - idx_t source_offset; - bool exhausted; -}; - -class PositionalScanLocalSourceState : public LocalSourceState { -public: - PositionalScanLocalSourceState(ExecutionContext &context, PositionalScanGlobalSourceState &gstate, - const PhysicalPositionalScan &op) { - for (size_t i = 0; i < op.child_tables.size(); ++i) { - auto &child = *op.child_tables[i]; - auto &global_state = *gstate.global_states[i]; - scanners.emplace_back(make_uniq(context, child, global_state)); - } - } - - vector> scanners; -}; - -unique_ptr PhysicalPositionalScan::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context, gstate.Cast(), *this); -} - -unique_ptr PhysicalPositionalScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SourceResultType PhysicalPositionalScan::GetData(ExecutionContext &context, DataChunk &output, - OperatorSourceInput &input) const { - auto &lstate = input.local_state.Cast(); - - // Find the longest source block - idx_t count = 0; - for (auto &scanner : lstate.scanners) { - count = MaxValue(count, scanner->Refill(context)); - } - - // All done? 
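// Exhausted children keep producing constant-NULL columns (see Refill above),
// so shorter inputs are padded rather than truncated, and count only reaches 0
// once every child is exhausted. Expected behaviour, as an assumed example:
//
//	-- t1 holds 1, 2, 3 and t2 holds 'a', 'b'
//	-- SELECT * FROM t1 POSITIONAL JOIN t2;
//	-- yields (1, 'a'), (2, 'b'), (3, NULL)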
- if (!count) { - return SourceResultType::FINISHED; - } - - // Copy or reference the source columns - idx_t col_offset = 0; - for (auto &scanner : lstate.scanners) { - col_offset += scanner->CopyData(context, output, count, col_offset); - } - - output.SetCardinality(count); - return SourceResultType::HAVE_MORE_OUTPUT; -} - -ProgressData PhysicalPositionalScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - - ProgressData res; - - for (size_t t = 0; t < child_tables.size(); ++t) { - res.Add(child_tables[t]->GetProgress(context, *gstate.global_states[t])); - } - - return res; -} - -bool PhysicalPositionalScan::Equals(const PhysicalOperator &other_p) const { - if (type != other_p.type) { - return false; - } - - auto &other = other_p.Cast(); - if (child_tables.size() != other.child_tables.size()) { - return false; - } - for (size_t i = 0; i < child_tables.size(); ++i) { - if (!child_tables[i]->Equals(*other.child_tables[i])) { - return false; - } - } - - return true; -} - -vector> PhysicalPositionalScan::GetChildren() const { - auto result = PhysicalOperator::GetChildren(); - for (auto &entry : child_tables) { - result.push_back(*entry); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp b/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp deleted file mode 100644 index 0ea996d34..000000000 --- a/src/duckdb/src/execution/operator/scan/physical_table_scan.cpp +++ /dev/null @@ -1,267 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" - -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/transaction/transaction.hpp" - -#include - -namespace duckdb { - -PhysicalTableScan::PhysicalTableScan(vector types, TableFunction function_p, - unique_ptr bind_data_p, vector returned_types_p, - vector column_ids_p, vector projection_ids_p, - vector names_p, unique_ptr table_filters_p, - idx_t estimated_cardinality, ExtraOperatorInfo extra_info, - vector parameters_p) - : PhysicalOperator(PhysicalOperatorType::TABLE_SCAN, std::move(types), estimated_cardinality), - function(std::move(function_p)), bind_data(std::move(bind_data_p)), returned_types(std::move(returned_types_p)), - column_ids(std::move(column_ids_p)), projection_ids(std::move(projection_ids_p)), names(std::move(names_p)), - table_filters(std::move(table_filters_p)), extra_info(extra_info), parameters(std::move(parameters_p)) { -} - -class TableScanGlobalSourceState : public GlobalSourceState { -public: - TableScanGlobalSourceState(ClientContext &context, const PhysicalTableScan &op) { - if (op.dynamic_filters && op.dynamic_filters->HasFilters()) { - table_filters = op.dynamic_filters->GetFinalTableFilters(op, op.table_filters.get()); - } - - if (op.function.init_global) { - auto filters = table_filters ? 
*table_filters : GetTableFilters(op); - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, filters, - op.extra_info.sample_options); - - global_state = op.function.init_global(context, input); - if (global_state) { - max_threads = global_state->MaxThreads(); - } - } else { - max_threads = 1; - } - if (op.function.in_out_function) { - // this is an in-out function, we need to setup the input chunk - vector input_types; - for (auto ¶m : op.parameters) { - input_types.push_back(param.type()); - } - input_chunk.Initialize(context, input_types); - for (idx_t c = 0; c < op.parameters.size(); c++) { - input_chunk.data[c].SetValue(0, op.parameters[c]); - } - input_chunk.SetCardinality(1); - } - } - - idx_t max_threads = 0; - unique_ptr global_state; - bool in_out_final = false; - DataChunk input_chunk; - //! Combined table filters, if we have dynamic filters - unique_ptr table_filters; - - optional_ptr GetTableFilters(const PhysicalTableScan &op) const { - return table_filters ? table_filters.get() : op.table_filters.get(); - } - idx_t MaxThreads() override { - return max_threads; - } -}; - -class TableScanLocalSourceState : public LocalSourceState { -public: - TableScanLocalSourceState(ExecutionContext &context, TableScanGlobalSourceState &gstate, - const PhysicalTableScan &op) { - if (op.function.init_local) { - TableFunctionInitInput input(op.bind_data.get(), op.column_ids, op.projection_ids, - gstate.GetTableFilters(op), op.extra_info.sample_options); - local_state = op.function.init_local(context, input, gstate.global_state.get()); - } - } - - unique_ptr local_state; -}; - -unique_ptr PhysicalTableScan::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(context, gstate.Cast(), *this); -} - -unique_ptr PhysicalTableScan::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -SourceResultType PhysicalTableScan::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - D_ASSERT(!column_ids.empty()); - auto &g_state = input.global_state.Cast(); - auto &l_state = input.local_state.Cast(); - - TableFunctionInput data(bind_data.get(), l_state.local_state.get(), g_state.global_state.get()); - - if (function.function) { - function.function(context.client, data, chunk); - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; - } - - if (g_state.in_out_final) { - function.in_out_function_final(context, data, chunk); - } - function.in_out_function(context, data, g_state.input_chunk, chunk); - if (chunk.size() == 0 && function.in_out_function_final) { - function.in_out_function_final(context, data, chunk); - g_state.in_out_final = true; - } - return chunk.size() == 0 ? 
SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -ProgressData PhysicalTableScan::GetProgress(ClientContext &context, GlobalSourceState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - ProgressData res; - if (function.table_scan_progress) { - double table_progress = function.table_scan_progress(context, bind_data.get(), gstate.global_state.get()); - if (table_progress < 0.0) { - res.SetInvalid(); - } else { - res.done = table_progress; - res.total = 100.0; - // Assume cardinality is always 1e3 - res.Normalize(1e3); - } - } else { - // if table_scan_progress is not implemented we don't support this function yet in the progress bar - res.SetInvalid(); - } - return res; -} - -bool PhysicalTableScan::SupportsPartitioning(const OperatorPartitionInfo &partition_info) const { - if (!function.get_partition_data) { - return false; - } - // FIXME: actually check if partition info is supported - return true; -} - -OperatorPartitionData PhysicalTableScan::GetPartitionData(ExecutionContext &context, DataChunk &chunk, - GlobalSourceState &gstate_p, LocalSourceState &lstate, - const OperatorPartitionInfo &partition_info) const { - D_ASSERT(SupportsPartitioning(partition_info)); - D_ASSERT(function.get_partition_data); - auto &gstate = gstate_p.Cast(); - auto &state = lstate.Cast(); - TableFunctionGetPartitionInput input(bind_data.get(), state.local_state.get(), gstate.global_state.get(), - partition_info); - return function.get_partition_data(context.client, input); -} - -string PhysicalTableScan::GetName() const { - return StringUtil::Upper(function.name + " " + function.extra_info); -} - -void AddProjectionNames(const ColumnIndex &index, const string &name, const LogicalType &type, string &result) { - if (!index.HasChildren()) { - // base case - no children projected out - if (!result.empty()) { - result += "\n"; - } - result += name; - return; - } - auto &child_types = StructType::GetChildTypes(type); - for (auto &child_index : index.GetChildIndexes()) { - auto &ele = child_types[child_index.GetPrimaryIndex()]; - AddProjectionNames(child_index, name + "." + ele.first, ele.second, result); - } -} - -InsertionOrderPreservingMap PhysicalTableScan::ParamsToString() const { - InsertionOrderPreservingMap result; - if (function.to_string) { - TableFunctionToStringInput input(function, bind_data.get()); - auto to_string_result = function.to_string(input); - for (const auto &it : to_string_result) { - result[it.first] = it.second; - } - } else { - result["Function"] = StringUtil::Upper(function.name); - } - if (function.projection_pushdown) { - string projections; - idx_t projected_column_count = function.filter_prune ? projection_ids.size() : column_ids.size(); - for (idx_t i = 0; i < projected_column_count; i++) { - auto base_index = function.filter_prune ? 
projection_ids[i] : i; - auto &column_index = column_ids[base_index]; - auto column_id = column_index.GetPrimaryIndex(); - if (column_id >= names.size()) { - continue; - } - AddProjectionNames(column_index, names[column_id], returned_types[column_id], projections); - } - result["Projections"] = projections; - } - if (function.filter_pushdown && table_filters) { - string filters_info; - bool first_item = true; - for (auto &f : table_filters->filters) { - auto &column_index = f.first; - auto &filter = f.second; - if (column_index < names.size()) { - if (!first_item) { - filters_info += "\n"; - } - first_item = false; - - const auto col_id = column_ids[column_index].GetPrimaryIndex(); - if (col_id == COLUMN_IDENTIFIER_ROW_ID) { - filters_info += filter->ToString("rowid"); - } else { - filters_info += filter->ToString(names[col_id]); - } - } - } - result["Filters"] = filters_info; - } - if (extra_info.sample_options) { - result["Sample Method"] = "System: " + extra_info.sample_options->sample_size.ToString() + "%"; - } - if (!extra_info.file_filters.empty()) { - result["File Filters"] = extra_info.file_filters; - if (extra_info.filtered_files.IsValid() && extra_info.total_files.IsValid()) { - result["Scanning Files"] = StringUtil::Format("%llu/%llu", extra_info.filtered_files.GetIndex(), - extra_info.total_files.GetIndex()); - } - } - - SetEstimatedCardinality(result, estimated_cardinality); - return result; -} - -bool PhysicalTableScan::Equals(const PhysicalOperator &other_p) const { - if (type != other_p.type) { - return false; - } - auto &other = other_p.Cast(); - if (function.function != other.function.function) { - return false; - } - if (column_ids != other.column_ids) { - return false; - } - if (!FunctionData::Equals(bind_data.get(), other.bind_data.get())) { - return false; - } - return true; -} - -bool PhysicalTableScan::ParallelSource() const { - if (!function.function) { - // table in-out functions cannot be executed in parallel as part of a PhysicalTableScan - // since they have only a single input row - return false; - } - return true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_alter.cpp b/src/duckdb/src/execution/operator/schema/physical_alter.cpp deleted file mode 100644 index 0d463e5ad..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_alter.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_alter.hpp" -#include "duckdb/parser/parsed_data/alter_table_info.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalAlter::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.Alter(context.client, *info); - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_attach.cpp b/src/duckdb/src/execution/operator/schema/physical_attach.cpp deleted file mode 100644 index 523f0d57e..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_attach.cpp +++ /dev/null @@ -1,82 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_attach.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/main/attached_database.hpp" -#include 
"duckdb/main/database_manager.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/parser/parsed_data/attach_info.hpp" -#include "duckdb/storage/storage_extension.hpp" -#include "duckdb/main/database_path_and_type.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalAttach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - // parse the options - auto &config = DBConfig::GetConfig(context.client); - AttachOptions options(info, config.options.access_mode); - - // get the name and path of the database - auto &name = info->name; - auto &path = info->path; - if (options.db_type.empty()) { - DBPathAndType::ExtractExtensionPrefix(path, options.db_type); - } - if (name.empty()) { - auto &fs = FileSystem::GetFileSystem(context.client); - name = AttachedDatabase::ExtractDatabaseName(path, fs); - } - - // check ATTACH IF NOT EXISTS - auto &db_manager = DatabaseManager::Get(context.client); - if (info->on_conflict == OnCreateConflict::IGNORE_ON_CONFLICT) { - // constant-time lookup in the catalog for the db name - auto existing_db = db_manager.GetDatabase(context.client, name); - if (existing_db) { - if ((existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_WRITE) || - (!existing_db->IsReadOnly() && options.access_mode == AccessMode::READ_ONLY)) { - - auto existing_mode = existing_db->IsReadOnly() ? AccessMode::READ_ONLY : AccessMode::READ_WRITE; - auto existing_mode_str = EnumUtil::ToString(existing_mode); - auto attached_mode = EnumUtil::ToString(options.access_mode); - throw BinderException("Database \"%s\" is already attached in %s mode, cannot re-attach in %s mode", - name, existing_mode_str, attached_mode); - } - if (!options.default_table.name.empty()) { - existing_db->GetCatalog().SetDefaultTable(options.default_table.schema, options.default_table.name); - } - return SourceResultType::FINISHED; - } - } - - string extension = ""; - if (FileSystem::IsRemoteFile(path, extension)) { - if (!ExtensionHelper::TryAutoLoadExtension(context.client, extension)) { - throw MissingExtensionException("Attaching path '%s' requires extension '%s' to be loaded", path, - extension); - } - if (options.access_mode == AccessMode::AUTOMATIC) { - // Attaching of remote files gets bumped to READ_ONLY - // This is due to the fact that on most (all?) remote files writes to DB are not available - // and having this raised later is not super helpful - options.access_mode = AccessMode::READ_ONLY; - } - } - - // Get the database type and attach the database. - db_manager.GetDatabaseType(context.client, *info, config, options); - auto attached_db = db_manager.AttachDatabase(context.client, *info, options); - - //! Initialize the database. 
- const auto storage_options = info->GetStorageOptions(); - attached_db->Initialize(storage_options); - if (!options.default_table.name.empty()) { - attached_db->GetCatalog().SetDefaultTable(options.default_table.schema, options.default_table.name); - } - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp b/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp deleted file mode 100644 index e34e49de4..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_art_index.cpp +++ /dev/null @@ -1,229 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_art_index.hpp" - -#include "duckdb/catalog/catalog_entry/duck_index_entry.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/execution/index/art/art_key.hpp" -#include "duckdb/execution/index/bound_index.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/storage/table/append_state.hpp" -#include "duckdb/common/exception/transaction_exception.hpp" - -namespace duckdb { - -PhysicalCreateARTIndex::PhysicalCreateARTIndex(LogicalOperator &op, TableCatalogEntry &table_p, - const vector &column_ids, unique_ptr info, - vector> unbound_expressions, - idx_t estimated_cardinality, const bool sorted, - unique_ptr alter_table_info) - : PhysicalOperator(PhysicalOperatorType::CREATE_INDEX, op.types, estimated_cardinality), - table(table_p.Cast()), info(std::move(info)), unbound_expressions(std::move(unbound_expressions)), - sorted(sorted), alter_table_info(std::move(alter_table_info)) { - - // Convert the logical column ids to physical column ids. - for (auto &column_id : column_ids) { - storage_ids.push_back(table.GetColumns().LogicalToPhysical(LogicalIndex(column_id)).index); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// - -class CreateARTIndexGlobalSinkState : public GlobalSinkState { -public: - unique_ptr global_index; -}; - -class CreateARTIndexLocalSinkState : public LocalSinkState { -public: - explicit CreateARTIndexLocalSinkState(ClientContext &context) : arena_allocator(Allocator::Get(context)) {}; - - unique_ptr local_index; - ArenaAllocator arena_allocator; - - DataChunk key_chunk; - unsafe_vector keys; - vector key_column_ids; - - DataChunk row_id_chunk; - unsafe_vector row_ids; -}; - -unique_ptr PhysicalCreateARTIndex::GetGlobalSinkState(ClientContext &context) const { - // Create the global sink state and add the global index. - auto state = make_uniq(); - auto &storage = table.GetStorage(); - state->global_index = make_uniq(info->index_name, info->constraint_type, storage_ids, - TableIOManager::Get(storage), unbound_expressions, storage.db); - return (std::move(state)); -} - -unique_ptr PhysicalCreateARTIndex::GetLocalSinkState(ExecutionContext &context) const { - // Create the local sink state and add the local index. - auto state = make_uniq(context.client); - auto &storage = table.GetStorage(); - state->local_index = make_uniq(info->index_name, info->constraint_type, storage_ids, - TableIOManager::Get(storage), unbound_expressions, storage.db); - - // Initialize the local sink state. 
- state->keys.resize(STANDARD_VECTOR_SIZE); - state->row_ids.resize(STANDARD_VECTOR_SIZE); - state->key_chunk.Initialize(Allocator::Get(context.client), state->local_index->logical_types); - state->row_id_chunk.Initialize(Allocator::Get(context.client), vector {LogicalType::ROW_TYPE}); - for (idx_t i = 0; i < state->key_chunk.ColumnCount(); i++) { - state->key_column_ids.push_back(i); - } - return std::move(state); -} - -SinkResultType PhysicalCreateARTIndex::SinkUnsorted(OperatorSinkInput &input) const { - - auto &l_state = input.local_state.Cast(); - auto row_count = l_state.key_chunk.size(); - auto &art = l_state.local_index->Cast(); - - // Insert each key and its corresponding row ID. - for (idx_t i = 0; i < row_count; i++) { - auto status = art.tree.GetGateStatus(); - auto conflict_type = art.Insert(art.tree, l_state.keys[i], 0, l_state.row_ids[i], status, nullptr); - D_ASSERT(conflict_type != ARTConflictType::TRANSACTION); - if (conflict_type == ARTConflictType::CONSTRAINT) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkResultType PhysicalCreateARTIndex::SinkSorted(OperatorSinkInput &input) const { - - auto &l_state = input.local_state.Cast(); - auto &storage = table.GetStorage(); - auto &l_index = l_state.local_index; - - // Construct an ART for this chunk. - auto art = make_uniq(info->index_name, l_index->GetConstraintType(), l_index->GetColumnIds(), - l_index->table_io_manager, l_index->unbound_expressions, storage.db, - l_index->Cast().allocators); - if (!art->Construct(l_state.keys, l_state.row_ids, l_state.key_chunk.size())) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - // Merge the ART into the local ART. - if (!l_index->MergeIndexes(*art)) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkResultType PhysicalCreateARTIndex::Sink(ExecutionContext &context, DataChunk &chunk, - OperatorSinkInput &input) const { - - D_ASSERT(chunk.ColumnCount() >= 2); - auto &l_state = input.local_state.Cast(); - l_state.arena_allocator.Reset(); - l_state.key_chunk.ReferenceColumns(chunk, l_state.key_column_ids); - - // Check for NULLs, if we are creating a PRIMARY KEY. - // FIXME: Later, we want to ensure that we skip the NULL check for any non-PK alter. - if (alter_table_info) { - auto row_count = l_state.key_chunk.size(); - for (idx_t i = 0; i < l_state.key_chunk.ColumnCount(); i++) { - if (VectorOperations::HasNull(l_state.key_chunk.data[i], row_count)) { - throw ConstraintException("NOT NULL constraint failed: %s", info->index_name); - } - } - } - - ART::GenerateKeyVectors(l_state.arena_allocator, l_state.key_chunk, chunk.data[chunk.ColumnCount() - 1], - l_state.keys, l_state.row_ids); - - if (sorted) { - return SinkSorted(input); - } - return SinkUnsorted(input); -} - -SinkCombineResultType PhysicalCreateARTIndex::Combine(ExecutionContext &context, - OperatorSinkCombineInput &input) const { - - auto &g_state = input.global_state.Cast(); - auto &l_state = input.local_state.Cast(); - - // Merge the local index into the global index. 
- if (!g_state.global_index->MergeIndexes(*l_state.local_index)) { - throw ConstraintException("Data contains duplicates on indexed column(s)"); - } - - return SinkCombineResultType::FINISHED; -} - -SinkFinalizeType PhysicalCreateARTIndex::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - - // Here, we set the resulting global index as the newly created index of the table. - auto &state = input.global_state.Cast(); - - // Vacuum excess memory and verify. - state.global_index->Vacuum(); - D_ASSERT(!state.global_index->VerifyAndToString(true).empty()); - state.global_index->VerifyAllocations(); - - auto &storage = table.GetStorage(); - if (!storage.IsRoot()) { - throw TransactionException("cannot add an index to a table that has been altered"); - } - - auto &schema = table.schema; - info->column_ids = storage_ids; - - // FIXME: We should check for catalog exceptions prior to index creation, and later double-check. - if (!alter_table_info) { - // Ensure that the index does not yet exist in the catalog. - auto entry = schema.GetEntry(schema.GetCatalogTransaction(context), CatalogType::INDEX_ENTRY, info->index_name); - if (entry) { - if (info->on_conflict != OnCreateConflict::IGNORE_ON_CONFLICT) { - throw CatalogException("Index with name \"%s\" already exists!", info->index_name); - } - // IF NOT EXISTS on existing index. We are done. - return SinkFinalizeType::READY; - } - - auto index_entry = schema.CreateIndex(schema.GetCatalogTransaction(context), *info, table).get(); - D_ASSERT(index_entry); - auto &index = index_entry->Cast(); - index.initial_index_size = state.global_index->GetInMemorySize(); - - } else { - // Ensure that there are no other indexes with that name on this table. - auto &indexes = storage.GetDataTableInfo()->GetIndexes(); - indexes.Scan([&](Index &index) { - if (index.GetIndexName() == info->index_name) { - throw CatalogException("an index with that name already exists for this table: %s", info->index_name); - } - return false; - }); - - auto &catalog = Catalog::GetCatalog(context, info->catalog); - catalog.Alter(context, *alter_table_info); - } - - // Add the index to the storage. 
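// Build flow recap: Sink() accumulates keys into a thread-local ART (sorted
// chunks are bulk-constructed, unsorted ones inserted row by row), Combine()
// merges each thread-local ART into the global one, and Finalize() vacuums,
// registers the index in the catalog, and publishes it to storage. Roughly:
//
//	Sink:     keys/row_ids -> local_index            (per thread)
//	Combine:  global_index.MergeIndexes(local_index)
//	Finalize: global_index.Vacuum(); storage.AddIndex(global_index)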
- storage.AddIndex(std::move(state.global_index)); - return SinkFinalizeType::READY; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// - -SourceResultType PhysicalCreateARTIndex::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_function.cpp b/src/duckdb/src/execution/operator/schema/physical_create_function.cpp deleted file mode 100644 index 2521b2082..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_function.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_function.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateFunction::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateFunction(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp b/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp deleted file mode 100644 index b0b031390..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_schema.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_schema.hpp" -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/exception/binder_exception.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSchema::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - if (catalog.IsSystemCatalog()) { - throw BinderException("Cannot create schema in system catalog"); - } - catalog.CreateSchema(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp b/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp deleted file mode 100644 index 80c4a2ffa..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_sequence.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_sequence.hpp" -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateSequence::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateSequence(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/execution/operator/schema/physical_create_table.cpp b/src/duckdb/src/execution/operator/schema/physical_create_table.cpp deleted file mode 100644 index 3220c4356..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_table.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_table.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/storage/data_table.hpp" - -namespace duckdb { - -PhysicalCreateTable::PhysicalCreateTable(LogicalOperator &op, SchemaCatalogEntry &schema, - unique_ptr info, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CREATE_TABLE, op.types, estimated_cardinality), schema(schema), - info(std::move(info)) { -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateTable::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = schema.catalog; - catalog.CreateTable(catalog.GetCatalogTransaction(context.client), schema, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_type.cpp b/src/duckdb/src/execution/operator/schema/physical_create_type.cpp deleted file mode 100644 index 68bc258b3..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_type.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_type.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/common/string_map_set.hpp" - -namespace duckdb { - -PhysicalCreateType::PhysicalCreateType(unique_ptr info_p, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CREATE_TYPE, {LogicalType::BIGINT}, estimated_cardinality), - info(std::move(info_p)) { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class CreateTypeGlobalState : public GlobalSinkState { -public: - explicit CreateTypeGlobalState(ClientContext &context) : result(LogicalType::VARCHAR) { - } - Vector result; - idx_t size = 0; - idx_t capacity = STANDARD_VECTOR_SIZE; - string_set_t found_strings; -}; - -unique_ptr PhysicalCreateType::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context); -} - -SinkResultType PhysicalCreateType::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - idx_t total_row_count = gstate.size + chunk.size(); - if (total_row_count > NumericLimits::Maximum()) { - throw InvalidInputException("Attempted to create ENUM of size %llu, which exceeds the maximum size of %llu", - total_row_count, NumericLimits::Maximum()); - } - UnifiedVectorFormat sdata; - chunk.data[0].ToUnifiedFormat(chunk.size(), sdata); - - if (total_row_count > gstate.capacity) { - // We must resize our result vector - gstate.result.Resize(gstate.capacity, gstate.capacity * 2); - gstate.capacity *= 2; - } - - auto src_ptr = UnifiedVectorFormat::GetData(sdata); - auto result_ptr = FlatVector::GetData(gstate.result); - // 
If the input vector contains a NULL value, we throw an exception - for (idx_t i = 0; i < chunk.size(); i++) { - idx_t idx = sdata.sel->get_index(i); - if (!sdata.validity.RowIsValid(idx)) { - throw InvalidInputException("Attempted to create ENUM type with NULL value!"); - } - auto str = src_ptr[idx]; - auto entry = gstate.found_strings.find(src_ptr[idx]); - if (entry != gstate.found_strings.end()) { - // entry was already found - skip - continue; - } - auto owned_string = StringVector::AddStringOrBlob(gstate.result, str.GetData(), str.GetSize()); - gstate.found_strings.insert(owned_string); - result_ptr[gstate.size++] = owned_string; - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateType::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - if (IsSink()) { - D_ASSERT(info->type == LogicalType::INVALID); - auto &g_sink_state = sink_state->Cast(); - info->type = LogicalType::ENUM(g_sink_state.result, g_sink_state.size); - } - - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateType(context.client, *info); - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_create_view.cpp b/src/duckdb/src/execution/operator/schema/physical_create_view.cpp deleted file mode 100644 index 948adad14..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_create_view.cpp +++ /dev/null @@ -1,17 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_view.hpp" -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalCreateView::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.CreateView(context.client, *info); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_detach.cpp b/src/duckdb/src/execution/operator/schema/physical_detach.cpp deleted file mode 100644 index 480890c3a..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_detach.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_detach.hpp" -#include "duckdb/parser/parsed_data/detach_info.hpp" -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/storage/storage_extension.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalDetach::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &db_manager = DatabaseManager::Get(context.client); - db_manager.DetachDatabase(context.client, info->name, info->if_not_found); - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/schema/physical_drop.cpp b/src/duckdb/src/execution/operator/schema/physical_drop.cpp
deleted file mode 100644 index cfb6841a8..000000000 --- a/src/duckdb/src/execution/operator/schema/physical_drop.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_drop.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/secret/secret_manager.hpp" -#include "duckdb/catalog/catalog_search_path.hpp" -#include "duckdb/main/settings.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalDrop::GetData(ExecutionContext &context, DataChunk &chunk, OperatorSourceInput &input) const { - switch (info->type) { - case CatalogType::PREPARED_STATEMENT: { - // DEALLOCATE silently ignores errors - auto &statements = ClientData::Get(context.client).prepared_statements; - if (statements.find(info->name) != statements.end()) { - statements.erase(info->name); - } - break; - } - case CatalogType::SCHEMA_ENTRY: { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.DropEntry(context.client, *info); - - // Check if the dropped schema was set as the current schema - auto &client_data = ClientData::Get(context.client); - auto &default_entry = client_data.catalog_search_path->GetDefault(); - auto ¤t_catalog = default_entry.catalog; - auto ¤t_schema = default_entry.schema; - D_ASSERT(info->name != DEFAULT_SCHEMA); - - if (info->catalog == current_catalog && current_schema == info->name) { - // Reset the schema to default - SchemaSetting::SetLocal(context.client, DEFAULT_SCHEMA); - } - break; - } - case CatalogType::SECRET_ENTRY: { - // Note: the schema param is used to optionally pass the storage to drop from - D_ASSERT(info->extra_drop_info); - auto &extra_info = info->extra_drop_info->Cast(); - SecretManager::Get(context.client) - .DropSecretByName(context.client, info->name, info->if_not_found, extra_info.persist_mode, - extra_info.secret_storage); - break; - } - default: { - auto &catalog = Catalog::GetCatalog(context.client, info->catalog); - catalog.DropEntry(context.client, *info); - break; - } - } - - return SourceResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_cte.cpp b/src/duckdb/src/execution/operator/set/physical_cte.cpp deleted file mode 100644 index fad76bbdd..000000000 --- a/src/duckdb/src/execution/operator/set/physical_cte.cpp +++ /dev/null @@ -1,112 +0,0 @@ -#include "duckdb/execution/operator/set/physical_cte.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" - -namespace duckdb { - -PhysicalCTE::PhysicalCTE(string ctename, idx_t table_index, vector types, unique_ptr top, - unique_ptr bottom, idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::CTE, std::move(types), estimated_cardinality), table_index(table_index), - ctename(std::move(ctename)) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -PhysicalCTE::~PhysicalCTE() { -} - -//===--------------------------------------------------------------------===// -// Sink 
-//===--------------------------------------------------------------------===// -class CTEGlobalState : public GlobalSinkState { -public: - explicit CTEGlobalState(ClientContext &context, const PhysicalCTE &op) : working_table_ref(op.working_table.get()) { - } - optional_ptr working_table_ref; - - mutex lhs_lock; - - void MergeIT(ColumnDataCollection &input) { - lock_guard guard(lhs_lock); - working_table_ref->Combine(input); - } -}; - -class CTELocalState : public LocalSinkState { -public: - explicit CTELocalState(ClientContext &context, const PhysicalCTE &op) - : lhs_data(context, op.working_table->Types()) { - lhs_data.InitializeAppend(append_state); - } - - unique_ptr distinct_state; - ColumnDataCollection lhs_data; - ColumnDataAppendState append_state; - - void Append(DataChunk &input) { - lhs_data.Append(input); - } -}; - -unique_ptr PhysicalCTE::GetGlobalSinkState(ClientContext &context) const { - working_table->Reset(); - return make_uniq(context, *this); -} - -unique_ptr PhysicalCTE::GetLocalSinkState(ExecutionContext &context) const { - auto state = make_uniq(context.client, *this); - return std::move(state); -} - -SinkResultType PhysicalCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &lstate = input.local_state.Cast(); - lstate.lhs_data.Append(lstate.append_state, chunk); - - return SinkResultType::NEED_MORE_INPUT; -} - -SinkCombineResultType PhysicalCTE::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - auto &lstate = input.local_state.Cast(); - auto &gstate = input.global_state.Cast(); - gstate.MergeIT(lstate.lhs_data); - - return SinkCombineResultType::FINISHED; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - D_ASSERT(children.size() == 2); - op_state.reset(); - sink_state.reset(); - - auto &state = meta_pipeline.GetState(); - - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - - for (auto &cte_scan : cte_scans) { - state.cte_dependencies.insert(make_pair(cte_scan, reference(*child_meta_pipeline.GetBasePipeline()))); - } - - children[1]->BuildPipelines(current, meta_pipeline); -} - -vector> PhysicalCTE::GetSources() const { - return children[1]->GetSources(); -} - -InsertionOrderPreservingMap PhysicalCTE::ParamsToString() const { - InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; - result["Table Index"] = StringUtil::Format("%llu", table_index); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp b/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp deleted file mode 100644 index 328e0822b..000000000 --- a/src/duckdb/src/execution/operator/set/physical_recursive_cte.cpp +++ /dev/null @@ -1,233 +0,0 @@ -#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" - -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/execution/executor.hpp" -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/parallel/event.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include 
"duckdb/parallel/task_scheduler.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -PhysicalRecursiveCTE::PhysicalRecursiveCTE(string ctename, idx_t table_index, vector types, bool union_all, - unique_ptr top, unique_ptr bottom, - idx_t estimated_cardinality) - : PhysicalOperator(PhysicalOperatorType::RECURSIVE_CTE, std::move(types), estimated_cardinality), - ctename(std::move(ctename)), table_index(table_index), union_all(union_all) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -PhysicalRecursiveCTE::~PhysicalRecursiveCTE() { -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -class RecursiveCTEState : public GlobalSinkState { -public: - explicit RecursiveCTEState(ClientContext &context, const PhysicalRecursiveCTE &op) - : intermediate_table(context, op.GetTypes()), new_groups(STANDARD_VECTOR_SIZE) { - ht = make_uniq(context, BufferAllocator::Get(context), op.types, - vector(), vector()); - } - - unique_ptr ht; - - bool intermediate_empty = true; - mutex intermediate_table_lock; - ColumnDataCollection intermediate_table; - ColumnDataScanState scan_state; - bool initialized = false; - bool finished_scan = false; - SelectionVector new_groups; -}; - -unique_ptr PhysicalRecursiveCTE::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -idx_t PhysicalRecursiveCTE::ProbeHT(DataChunk &chunk, RecursiveCTEState &state) const { - Vector dummy_addresses(LogicalType::POINTER); - - // Use the HT to eliminate duplicate rows - idx_t new_group_count = state.ht->FindOrCreateGroups(chunk, dummy_addresses, state.new_groups); - - // we only return entries we have not seen before (i.e. 
new groups) - chunk.Slice(state.new_groups, new_group_count); - - return new_group_count; -} - -SinkResultType PhysicalRecursiveCTE::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - auto &gstate = input.global_state.Cast(); - - lock_guard guard(gstate.intermediate_table_lock); - if (!union_all) { - idx_t match_count = ProbeHT(chunk, gstate); - if (match_count > 0) { - gstate.intermediate_table.Append(chunk); - } - } else { - gstate.intermediate_table.Append(chunk); - } - return SinkResultType::NEED_MORE_INPUT; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -SourceResultType PhysicalRecursiveCTE::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - auto &gstate = sink_state->Cast(); - if (!gstate.initialized) { - gstate.intermediate_table.InitializeScan(gstate.scan_state); - gstate.finished_scan = false; - gstate.initialized = true; - } - while (chunk.size() == 0) { - if (!gstate.finished_scan) { - // scan any chunks we have collected so far - gstate.intermediate_table.Scan(gstate.scan_state, chunk); - if (chunk.size() == 0) { - gstate.finished_scan = true; - } else { - break; - } - } else { - // we have run out of chunks - // now we need to recurse - // we set up the working table as the data we gathered in this iteration of the recursion - working_table->Reset(); - working_table->Combine(gstate.intermediate_table); - // and we clear the intermediate table - gstate.finished_scan = false; - gstate.intermediate_table.Reset(); - // now we need to re-execute all of the pipelines that depend on the recursion - ExecuteRecursivePipelines(context); - - // check if we obtained any results - // if not, we are done - if (gstate.intermediate_table.Count() == 0) { - gstate.finished_scan = true; - break; - } - // set up the scan again - gstate.intermediate_table.InitializeScan(gstate.scan_state); - } - } - - return chunk.size() == 0 ? SourceResultType::FINISHED : SourceResultType::HAVE_MORE_OUTPUT; -} - -void PhysicalRecursiveCTE::ExecuteRecursivePipelines(ExecutionContext &context) const { - if (!recursive_meta_pipeline) { - throw InternalException("Missing meta pipeline for recursive CTE"); - } - D_ASSERT(recursive_meta_pipeline->HasRecursiveCTE()); - - // get and reset pipelines - vector> pipelines; - recursive_meta_pipeline->GetPipelines(pipelines, true); - for (auto &pipeline : pipelines) { - auto sink = pipeline->GetSink(); - if (sink.get() != this) { - sink->sink_state.reset(); - } - for (auto &op_ref : pipeline->GetOperators()) { - auto &op = op_ref.get(); - op.op_state.reset(); - } - pipeline->ClearSource(); - } - - // get the MetaPipelines in the recursive_meta_pipeline and reschedule them - vector> meta_pipelines; - recursive_meta_pipeline->GetMetaPipelines(meta_pipelines, true, false); - auto &executor = recursive_meta_pipeline->GetExecutor(); - vector> events; - executor.ReschedulePipelines(meta_pipelines, events); - - while (true) { - executor.WorkOnTasks(); - if (executor.HasError()) { - executor.ThrowException(); - } - bool finished = true; - for (auto &event : events) { - if (!event->IsFinished()) { - finished = false; - break; - } - } - if (finished) { - // all pipelines finished: done! 
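- // Note: WorkOnTasks() lets this thread help execute the rescheduled
- // pipelines instead of blocking on them. Once every event reports
- // finished, control returns to GetData(), which checks whether this
- // iteration produced new rows and rescans the refilled intermediate table.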
- break; - } - } -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// - -static void GatherColumnDataScans(const PhysicalOperator &op, vector> &delim_scans) { - if (op.type == PhysicalOperatorType::DELIM_SCAN || op.type == PhysicalOperatorType::CTE_SCAN) { - delim_scans.push_back(op); - } - for (auto &child : op.children) { - GatherColumnDataScans(*child, delim_scans); - } -} - -void PhysicalRecursiveCTE::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - recursive_meta_pipeline.reset(); - - auto &state = meta_pipeline.GetState(); - state.SetPipelineSource(current, *this); - - auto &executor = meta_pipeline.GetExecutor(); - executor.AddRecursiveCTE(*this); - - // the LHS of the recursive CTE is our initial state - auto &initial_state_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - initial_state_pipeline.Build(*children[0]); - - // the RHS is the recursive pipeline - recursive_meta_pipeline = make_shared_ptr(executor, state, this); - recursive_meta_pipeline->SetRecursiveCTE(); - recursive_meta_pipeline->Build(*children[1]); - - vector> ops; - GatherColumnDataScans(*children[1], ops); - - for (auto op : ops) { - auto entry = state.cte_dependencies.find(op); - if (entry == state.cte_dependencies.end()) { - continue; - } - // this chunk scan introduces a dependency to the current pipeline - // namely a dependency on the CTE pipeline to finish - auto cte_dependency = entry->second.get().shared_from_this(); - current.AddDependency(cte_dependency); - } -} - -vector> PhysicalRecursiveCTE::GetSources() const { - return {*this}; -} - -InsertionOrderPreservingMap PhysicalRecursiveCTE::ParamsToString() const { - InsertionOrderPreservingMap result; - result["CTE Name"] = ctename; - result["Table Index"] = StringUtil::Format("%llu", table_index); - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/operator/set/physical_union.cpp b/src/duckdb/src/execution/operator/set/physical_union.cpp deleted file mode 100644 index 1194b9539..000000000 --- a/src/duckdb/src/execution/operator/set/physical_union.cpp +++ /dev/null @@ -1,89 +0,0 @@ -#include "duckdb/execution/operator/set/physical_union.hpp" - -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" - -namespace duckdb { - -PhysicalUnion::PhysicalUnion(vector types, unique_ptr top, - unique_ptr bottom, idx_t estimated_cardinality, bool allow_out_of_order) - : PhysicalOperator(PhysicalOperatorType::UNION, std::move(types), estimated_cardinality), - allow_out_of_order(allow_out_of_order) { - children.push_back(std::move(top)); - children.push_back(std::move(bottom)); -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalUnion::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - sink_state.reset(); - - // order matters if any of the downstream operators are order dependent, - // or if the sink preserves order, but does not support batch indices to do so - auto sink = meta_pipeline.GetSink(); - bool order_matters = false; - if (!allow_out_of_order) { - order_matters = true; - } - if (current.IsOrderDependent()) { - order_matters = true; - } - if (sink) { - if 
(sink->SinkOrderDependent()) { - order_matters = true; - } - auto partition_info = sink->RequiredPartitionInfo(); - if (partition_info.batch_index) { - order_matters = true; - } - if (!sink->ParallelSink()) { - order_matters = true; - } - } - - // create a union pipeline that has identical dependencies to 'current' - auto &union_pipeline = meta_pipeline.CreateUnionPipeline(current, order_matters); - - // continue with the current pipeline - children[0]->BuildPipelines(current, meta_pipeline); - - vector> dependencies; - optional_ptr last_child_ptr; - const auto can_saturate_threads = children[0]->CanSaturateThreads(current.GetClientContext()); - if (order_matters || can_saturate_threads) { - // we add dependencies if order matters: union_pipeline comes after all pipelines created by building current - dependencies = meta_pipeline.AddDependenciesFrom(union_pipeline, union_pipeline, false); - // we also add dependencies if the LHS child can saturate all available threads - // in that case, we recursively make all RHS children depend on the LHS. - // This prevents breadth-first plan evaluation - if (can_saturate_threads) { - last_child_ptr = meta_pipeline.GetLastChild(); - } - } - - // build the union pipeline - children[1]->BuildPipelines(union_pipeline, meta_pipeline); - - if (last_child_ptr) { - // the pointer was set, set up the dependencies - meta_pipeline.AddRecursiveDependencies(dependencies, *last_child_ptr); - } - - // Assign proper batch index to the union pipeline - // This needs to happen after the pipelines have been built because unions can be nested - meta_pipeline.AssignNextBatchIndex(union_pipeline); -} - -vector> PhysicalUnion::GetSources() const { - vector> result; - for (auto &child : children) { - auto child_sources = child->GetSources(); - result.insert(result.end(), child_sources.begin(), child_sources.end()); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp b/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp deleted file mode 100644 index c378e61ed..000000000 --- a/src/duckdb/src/execution/perfect_aggregate_hashtable.cpp +++ /dev/null @@ -1,318 +0,0 @@ -#include "duckdb/execution/perfect_aggregate_hashtable.hpp" - -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" - -namespace duckdb { - -PerfectAggregateHashTable::PerfectAggregateHashTable(ClientContext &context, Allocator &allocator, - const vector &group_types_p, - vector payload_types_p, - vector aggregate_objects_p, - vector group_minima_p, vector required_bits_p) - : BaseAggregateHashTable(context, allocator, aggregate_objects_p, std::move(payload_types_p)), - addresses(LogicalType::POINTER), required_bits(std::move(required_bits_p)), total_required_bits(0), - group_minima(std::move(group_minima_p)), sel(STANDARD_VECTOR_SIZE), - aggregate_allocator(make_uniq(allocator)) { - for (auto &group_bits : required_bits) { - total_required_bits += group_bits; - } - // the total amount of groups we allocate space for is 2^required_bits - total_groups = (uint64_t)1 << total_required_bits; - // we don't need to store the groups in a perfect hash table, since the group keys can be deduced by their location - grouping_columns = group_types_p.size(); - layout.Initialize(std::move(aggregate_objects_p)); - tuple_size = layout.GetRowWidth(); - - // allocate and null initialize the data - owned_data = make_unsafe_uniq_array_uninitialized(tuple_size * 
total_groups); - data = owned_data.get(); - - // set up the empty payloads for every tuple, and initialize the "occupied" flag to false - group_is_set = make_unsafe_uniq_array_uninitialized(total_groups); - memset(group_is_set.get(), 0, total_groups * sizeof(bool)); - - // initialize the hash table for each entry - auto address_data = FlatVector::GetData(addresses); - idx_t init_count = 0; - for (idx_t i = 0; i < total_groups; i++) { - address_data[init_count] = uintptr_t(data) + (tuple_size * i); - init_count++; - if (init_count == STANDARD_VECTOR_SIZE) { - RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); - init_count = 0; - } - } - RowOperations::InitializeStates(layout, addresses, *FlatVector::IncrementalSelectionVector(), init_count); -} - -PerfectAggregateHashTable::~PerfectAggregateHashTable() { - Destroy(); -} - -template -static void ComputeGroupLocationTemplated(UnifiedVectorFormat &group_data, Value &min, uintptr_t *address_data, - idx_t current_shift, idx_t count) { - auto data = UnifiedVectorFormat::GetData(group_data); - auto min_val = min.GetValueUnsafe(); - if (!group_data.validity.AllValid()) { - for (idx_t i = 0; i < count; i++) { - auto index = group_data.sel->get_index(i); - // check if the value is NULL - // NULL groups are considered as "0" in the hash table - // that is to say, they have no effect on the position of the element (because 0 << shift is 0) - // we only need to handle non-null values here - if (group_data.validity.RowIsValid(index)) { - D_ASSERT(data[index] >= min_val); - auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); - address_data[i] += adjusted_value << current_shift; - } - } - } else { - // no null values: we can directly compute the addresses - for (idx_t i = 0; i < count; i++) { - auto index = group_data.sel->get_index(i); - auto adjusted_value = UnsafeNumericCast((data[index] - min_val) + 1); - address_data[i] += adjusted_value << current_shift; - } - } -} - -static void ComputeGroupLocation(Vector &group, Value &min, uintptr_t *address_data, idx_t current_shift, idx_t count) { - UnifiedVectorFormat vdata; - group.ToUnifiedFormat(count, vdata); - - switch (group.GetType().InternalType()) { - case PhysicalType::INT8: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT16: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT32: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::INT64: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT8: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT16: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT32: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - case PhysicalType::UINT64: - ComputeGroupLocationTemplated(vdata, min, address_data, current_shift, count); - break; - default: - throw InternalException("Unsupported group type for perfect aggregate hash table"); - } -} - -void PerfectAggregateHashTable::AddChunk(DataChunk &groups, DataChunk &payload) { - // first we need to find the location in the HT of each of the groups - auto address_data = FlatVector::GetData(addresses); - // zero-initialize the address data - 
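- // Worked example (illustrative values): with two group columns that need
- // 3 and 2 bits, total_required_bits is 5 and total_groups is 2^5 = 32. A
- // row whose first group value is min0 + 4 and whose second group is NULL
- // maps to slot ((4 + 1) << 2) + 0 = 20: NULL contributes 0, while non-NULL
- // values are stored one-indexed, matching the "+ 1" in
- // ComputeGroupLocationTemplated above.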
memset(address_data, 0, groups.size() * sizeof(uintptr_t)); - D_ASSERT(groups.ColumnCount() == group_minima.size()); - - // then compute the actual group location by iterating over each of the groups - idx_t current_shift = total_required_bits; - for (idx_t i = 0; i < groups.ColumnCount(); i++) { - current_shift -= required_bits[i]; - ComputeGroupLocation(groups.data[i], group_minima[i], address_data, current_shift, groups.size()); - } - // now we have the HT entry number for every tuple - // compute the actual pointer to the data by adding it to the base HT pointer and multiplying by the tuple size - for (idx_t i = 0; i < groups.size(); i++) { - const auto group = address_data[i]; - if (group >= total_groups) { - throw InvalidInputException("Perfect hash aggregate: aggregate group %llu exceeded total groups %llu. This " - "likely means that the statistics in your data source are corrupt.\n* PRAGMA " - "disable_optimizer to disable optimizations that rely on correct statistics", - group, total_groups); - } - group_is_set[group] = true; - address_data[i] = uintptr_t(data) + group * tuple_size; - } - - // after finding the group location we update the aggregates - idx_t payload_idx = 0; - auto &aggregates = layout.GetAggregates(); - RowOperationsState row_state(*aggregate_allocator); - for (idx_t aggr_idx = 0; aggr_idx < aggregates.size(); aggr_idx++) { - auto &aggregate = aggregates[aggr_idx]; - auto input_count = (idx_t)aggregate.child_count; - if (aggregate.filter) { - RowOperations::UpdateFilteredStates(row_state, filter_set.GetFilterData(aggr_idx), aggregate, addresses, - payload, payload_idx); - } else { - RowOperations::UpdateStates(row_state, aggregate, addresses, payload, payload_idx, payload.size()); - } - // move to the next aggregate - payload_idx += input_count; - VectorOperations::AddInPlace(addresses, UnsafeNumericCast(aggregate.payload_size), payload.size()); - } -} - -void PerfectAggregateHashTable::Combine(PerfectAggregateHashTable &other) { - D_ASSERT(total_groups == other.total_groups); - D_ASSERT(tuple_size == other.tuple_size); - - Vector source_addresses(LogicalType::POINTER); - Vector target_addresses(LogicalType::POINTER); - auto source_addresses_ptr = FlatVector::GetData(source_addresses); - auto target_addresses_ptr = FlatVector::GetData(target_addresses); - - // iterate over all entries of both hash tables and call combine for all entries that can be combined - data_ptr_t source_ptr = other.data; - data_ptr_t target_ptr = data; - idx_t combine_count = 0; - RowOperationsState row_state(*aggregate_allocator); - for (idx_t i = 0; i < total_groups; i++) { - auto has_entry_source = other.group_is_set[i]; - // we only have any work to do if the source has an entry for this group - if (has_entry_source) { - group_is_set[i] = true; - source_addresses_ptr[combine_count] = source_ptr; - target_addresses_ptr[combine_count] = target_ptr; - combine_count++; - if (combine_count == STANDARD_VECTOR_SIZE) { - RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); - combine_count = 0; - } - } - source_ptr += tuple_size; - target_ptr += tuple_size; - } - RowOperations::CombineStates(row_state, layout, source_addresses, target_addresses, combine_count); - - // FIXME: after moving the arena allocator, we currently have to ensure that the pointer is not nullptr, because the - // FIXME: Destroy()-function of the hash table expects an allocator in some cases (e.g., for sorted aggregates) - 
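- // Keeping the source's allocator alive below matters because the combined
- // aggregate states may still reference memory owned by that arena; it can
- // only be released safely after Destroy() has run.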
stored_allocators.push_back(std::move(other.aggregate_allocator)); - other.aggregate_allocator = make_uniq(allocator); -} - -template -static void ReconstructGroupVectorTemplated(uint32_t group_values[], Value &min, idx_t mask, idx_t shift, - idx_t entry_count, Vector &result) { - auto data = FlatVector::GetData(result); - auto &validity_mask = FlatVector::Validity(result); - auto min_data = min.GetValueUnsafe(); - for (idx_t i = 0; i < entry_count; i++) { - // extract the value of this group from the total group index - auto group_index = UnsafeNumericCast((group_values[i] >> shift) & mask); - if (group_index == 0) { - // if it is 0, the value is NULL - validity_mask.SetInvalid(i); - } else { - // otherwise we add the value (minus 1) to the min value - data[i] = UnsafeNumericCast(UnsafeNumericCast(min_data) + - UnsafeNumericCast(group_index) - 1); - } - } -} - -static void ReconstructGroupVector(uint32_t group_values[], Value &min, idx_t required_bits, idx_t shift, - idx_t entry_count, Vector &result) { - // construct the mask for this entry - idx_t mask = ((uint64_t)1 << required_bits) - 1; - switch (result.GetType().InternalType()) { - case PhysicalType::INT8: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT16: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT32: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::INT64: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT8: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT16: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT32: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - case PhysicalType::UINT64: - ReconstructGroupVectorTemplated(group_values, min, mask, shift, entry_count, result); - break; - default: - throw InternalException("Invalid type for perfect aggregate HT group"); - } -} - -void PerfectAggregateHashTable::Scan(idx_t &scan_position, DataChunk &result) { - auto data_pointers = FlatVector::GetData(addresses); - uint32_t group_values[STANDARD_VECTOR_SIZE]; - - // iterate over the HT until we either have exhausted the entire HT, or collected a full vector of entries - idx_t entry_count = 0; - for (; scan_position < total_groups; scan_position++) { - if (group_is_set[scan_position]) { - // this group is set: add it to the set of groups to extract - data_pointers[entry_count] = data + tuple_size * scan_position; - group_values[entry_count] = NumericCast(scan_position); - entry_count++; - if (entry_count == STANDARD_VECTOR_SIZE) { - scan_position++; - break; - } - } - } - if (entry_count == 0) { - // no entries found - return; - } - // first reconstruct the groups from the group index - idx_t shift = total_required_bits; - for (idx_t i = 0; i < grouping_columns; i++) { - shift -= required_bits[i]; - ReconstructGroupVector(group_values, group_minima[i], required_bits[i], shift, entry_count, result.data[i]); - } - // then construct the payloads - result.SetCardinality(entry_count); - RowOperationsState row_state(*aggregate_allocator); - RowOperations::FinalizeStates(row_state, layout, addresses, result, grouping_columns); -} - -void PerfectAggregateHashTable::Destroy() { - // check if there is
any destructor to call - bool has_destructor = false; - for (auto &aggr : layout.GetAggregates()) { - if (aggr.function.destructor) { - has_destructor = true; - } - } - if (!has_destructor) { - return; - } - // there are aggregates with destructors: loop over the hash table - // and call the destructor method for each of the aggregates - auto data_pointers = FlatVector::GetData(addresses); - idx_t count = 0; - - // iterate over all initialised slots of the hash table - RowOperationsState row_state(*aggregate_allocator); - data_ptr_t payload_ptr = data; - for (idx_t i = 0; i < total_groups; i++) { - data_pointers[count++] = payload_ptr; - if (count == STANDARD_VECTOR_SIZE) { - RowOperations::DestroyStates(row_state, layout, addresses, count); - count = 0; - } - payload_ptr += tuple_size; - } - RowOperations::DestroyStates(row_state, layout, addresses, count); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_operator.cpp b/src/duckdb/src/execution/physical_operator.cpp deleted file mode 100644 index c5119620b..000000000 --- a/src/duckdb/src/execution/physical_operator.cpp +++ /dev/null @@ -1,355 +0,0 @@ -#include "duckdb/execution/physical_operator.hpp" - -#include "duckdb/common/printer.hpp" -#include "duckdb/common/render_tree.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/tree_renderer.hpp" -#include "duckdb/execution/execution_context.hpp" -#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parallel/meta_pipeline.hpp" -#include "duckdb/parallel/pipeline.hpp" -#include "duckdb/parallel/thread_context.hpp" -#include "duckdb/storage/buffer/buffer_pool.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -string PhysicalOperator::GetName() const { - return PhysicalOperatorToString(type); -} - -string PhysicalOperator::ToString(ExplainFormat format) const { - auto renderer = TreeRenderer::CreateRenderer(format); - stringstream ss; - auto tree = RenderTree::CreateRenderTree(*this); - renderer->ToStream(*tree, ss); - return ss.str(); -} - -// LCOV_EXCL_START -void PhysicalOperator::Print() const { - Printer::Print(ToString()); -} -// LCOV_EXCL_STOP - -vector> PhysicalOperator::GetChildren() const { - vector> result; - for (auto &child : children) { - result.push_back(*child); - } - return result; -} - -void PhysicalOperator::SetEstimatedCardinality(InsertionOrderPreservingMap &result, - idx_t estimated_cardinality) { - result[RenderTreeNode::ESTIMATED_CARDINALITY] = StringUtil::Format("%llu", estimated_cardinality); -} - -idx_t PhysicalOperator::EstimatedThreadCount() const { - idx_t result = 0; - if (children.empty()) { - // Terminal operator, e.g., base table, these decide the degree of parallelism of pipelines - result = MaxValue(estimated_cardinality / (DEFAULT_ROW_GROUP_SIZE * 2), 1); - } else if (type == PhysicalOperatorType::UNION) { - // We can run union pipelines in parallel, so we sum up the thread count of the children - for (auto &child : children) { - result += child->EstimatedThreadCount(); - } - } else { - // For other operators we take the maximum of the children - for (auto &child : children) { - result = MaxValue(child->EstimatedThreadCount(), result); - } - } - return result; -} - -bool PhysicalOperator::CanSaturateThreads(ClientContext &context) const { -#ifdef DEBUG - // In debug mode we always return true here so that the code that depends on it is well-tested - return true; -#else - const auto num_threads = 
NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); - return EstimatedThreadCount() >= num_threads; -#endif -} - -//===--------------------------------------------------------------------===// -// Operator -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -unique_ptr PhysicalOperator::GetOperatorState(ExecutionContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalOperatorState(ClientContext &context) const { - return make_uniq(); -} - -OperatorResultType PhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - throw InternalException("Calling Execute on a node that is not an operator!"); -} - -OperatorFinalizeResultType PhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state) const { - throw InternalException("Calling FinalExecute on a node that is not an operator!"); -} -// LCOV_EXCL_STOP - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -unique_ptr PhysicalOperator::GetLocalSourceState(ExecutionContext &context, - GlobalSourceState &gstate) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(); -} - -// LCOV_EXCL_START -SourceResultType PhysicalOperator::GetData(ExecutionContext &context, DataChunk &chunk, - OperatorSourceInput &input) const { - throw InternalException("Calling GetData on a node that is not a source!"); -} - -OperatorPartitionData PhysicalOperator::GetPartitionData(ExecutionContext &context, DataChunk &chunk, - GlobalSourceState &gstate, LocalSourceState &lstate, - const OperatorPartitionInfo &partition_info) const { - throw InternalException("Calling GetPartitionData on a node that does not support it"); -} - -ProgressData PhysicalOperator::GetProgress(ClientContext &context, GlobalSourceState &gstate) const { - ProgressData res; - res.SetInvalid(); - return res; -} -// LCOV_EXCL_STOP - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -// LCOV_EXCL_START -SinkResultType PhysicalOperator::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input) const { - throw InternalException("Calling Sink on a node that is not a sink!"); -} - -// LCOV_EXCL_STOP - -SinkCombineResultType PhysicalOperator::Combine(ExecutionContext &context, OperatorSinkCombineInput &input) const { - return SinkCombineResultType::FINISHED; -} - -void PhysicalOperator::PrepareFinalize(ClientContext &context, GlobalSinkState &sink_state) const { -} - -SinkFinalizeType PhysicalOperator::Finalize(Pipeline &pipeline, Event &event, ClientContext &context, - OperatorSinkFinalizeInput &input) const { - return SinkFinalizeType::READY; -} - -SinkNextBatchType PhysicalOperator::NextBatch(ExecutionContext &context, OperatorSinkNextBatchInput &input) const { - return SinkNextBatchType::READY; -} - -unique_ptr PhysicalOperator::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(); -} - -unique_ptr PhysicalOperator::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(); -} - -idx_t PhysicalOperator::GetMaxThreadMemory(ClientContext &context) { - // Memory usage per thread 
should scale with max mem / num threads - // We take 1/4th of this, to be conservative - auto max_memory = BufferManager::GetBufferManager(context).GetQueryMaxMemory(); - auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); - return (max_memory / num_threads) / 4; -} - -bool PhysicalOperator::OperatorCachingAllowed(ExecutionContext &context) { - if (!context.client.config.enable_caching_operators) { - return false; - } else if (!context.pipeline) { - return false; - } else if (!context.pipeline->GetSink()) { - return false; - } else if (context.pipeline->IsOrderDependent()) { - return false; - } else { - auto partition_info = context.pipeline->GetSink()->RequiredPartitionInfo(); - if (partition_info.AnyRequired()) { - return false; - } - } - - return true; -} - -//===--------------------------------------------------------------------===// -// Pipeline Construction -//===--------------------------------------------------------------------===// -void PhysicalOperator::BuildPipelines(Pipeline ¤t, MetaPipeline &meta_pipeline) { - op_state.reset(); - - auto &state = meta_pipeline.GetState(); - if (IsSink()) { - // operator is a sink, build a pipeline - sink_state.reset(); - D_ASSERT(children.size() == 1); - - // single operator: the operator becomes the data source of the current pipeline - state.SetPipelineSource(current, *this); - - // we create a new pipeline starting from the child - auto &child_meta_pipeline = meta_pipeline.CreateChildMetaPipeline(current, *this); - child_meta_pipeline.Build(*children[0]); - } else { - // operator is not a sink! recurse in children - if (children.empty()) { - // source - state.SetPipelineSource(current, *this); - } else { - if (children.size() != 1) { - throw InternalException("Operator not supported in BuildPipelines"); - } - state.AddPipelineOperator(current, *this); - children[0]->BuildPipelines(current, meta_pipeline); - } - } -} - -vector> PhysicalOperator::GetSources() const { - vector> result; - if (IsSink()) { - D_ASSERT(children.size() == 1); - result.push_back(*this); - return result; - } else { - if (children.empty()) { - // source - result.push_back(*this); - return result; - } else { - if (children.size() != 1) { - throw InternalException("Operator not supported in GetSource"); - } - return children[0]->GetSources(); - } - } -} - -bool PhysicalOperator::AllSourcesSupportBatchIndex() const { - auto sources = GetSources(); - for (auto &source : sources) { - if (!source.get().SupportsPartitioning(OperatorPartitionInfo::BatchIndex())) { - return false; - } - } - return true; -} - -void PhysicalOperator::Verify() { -#ifdef DEBUG - auto sources = GetSources(); - D_ASSERT(!sources.empty()); - for (auto &child : children) { - child->Verify(); - } -#endif -} - -bool CachingPhysicalOperator::CanCacheType(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::ARRAY: - return false; - case LogicalTypeId::STRUCT: { - auto &entries = StructType::GetChildTypes(type); - for (auto &entry : entries) { - if (!CanCacheType(entry.second)) { - return false; - } - } - return true; - } - default: - return true; - } -} - -CachingPhysicalOperator::CachingPhysicalOperator(PhysicalOperatorType type, vector types_p, - idx_t estimated_cardinality) - : PhysicalOperator(type, std::move(types_p), estimated_cardinality) { - - caching_supported = true; - for (auto &col_type : types) { - if (!CanCacheType(col_type)) { - caching_supported = false; - break; - } - } -} - 
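- // The Execute() implementation below buffers chunks smaller than
- // CACHE_THRESHOLD into cached_chunk and only emits once the cache is
- // nearly full or the child is finished; FinalExecute() flushes whatever
- // remains. This trades an extra copy for fuller vectors downstream.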
-OperatorResultType CachingPhysicalOperator::Execute(ExecutionContext &context, DataChunk &input, DataChunk &chunk, - GlobalOperatorState &gstate, OperatorState &state_p) const { - auto &state = state_p.Cast(); - - // Execute child operator - auto child_result = ExecuteInternal(context, input, chunk, gstate, state); - -#if STANDARD_VECTOR_SIZE >= 128 - if (!state.initialized) { - state.initialized = true; - state.can_cache_chunk = caching_supported && PhysicalOperator::OperatorCachingAllowed(context); - } - if (!state.can_cache_chunk) { - return child_result; - } - // TODO chunk size of 0 should not result in a cache being created! - if (chunk.size() < CACHE_THRESHOLD) { - // we have filtered out a significant amount of tuples - // add this chunk to the cache and continue - - if (!state.cached_chunk) { - state.cached_chunk = make_uniq(); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); - } - - state.cached_chunk->Append(chunk); - - if (state.cached_chunk->size() >= (STANDARD_VECTOR_SIZE - CACHE_THRESHOLD) || - child_result == OperatorResultType::FINISHED) { - // chunk cache full: return it - chunk.Move(*state.cached_chunk); - state.cached_chunk->Initialize(Allocator::Get(context.client), chunk.GetTypes()); - return child_result; - } else { - // chunk cache not full return empty result - chunk.Reset(); - } - } -#endif - - return child_result; -} - -OperatorFinalizeResultType CachingPhysicalOperator::FinalExecute(ExecutionContext &context, DataChunk &chunk, - GlobalOperatorState &gstate, - OperatorState &state_p) const { - auto &state = state_p.Cast(); - if (state.cached_chunk) { - chunk.Move(*state.cached_chunk); - state.cached_chunk.reset(); - } else { - chunk.SetCardinality(0); - } - return OperatorFinalizeResultType::FINISHED; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp b/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp deleted file mode 100644 index c61781a33..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_aggregate.cpp +++ /dev/null @@ -1,332 +0,0 @@ -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/aggregate/physical_perfecthash_aggregate.hpp" -#include "duckdb/execution/operator/aggregate/physical_ungrouped_aggregate.hpp" -#include "duckdb/execution/operator/aggregate/physical_partitioned_aggregate.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parser/expression/comparison_expression.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_aggregate.hpp" - -namespace duckdb { - -static uint32_t RequiredBitsForValue(uint32_t n) { - idx_t required_bits = 0; - while (n > 0) { - n >>= 1; - required_bits++; - } - return UnsafeNumericCast(required_bits); -} - -template -hugeint_t GetRangeHugeint(const BaseStatistics &nstats) { - return Hugeint::Convert(NumericStats::GetMax(nstats)) - Hugeint::Convert(NumericStats::GetMin(nstats)); -} - -static bool CanUsePartitionedAggregate(ClientContext &context, 
LogicalAggregate &op, PhysicalOperator &child, - vector &partition_columns) { - if (op.grouping_sets.size() > 1 || !op.grouping_functions.empty()) { - return false; - } - for (auto &expression : op.expressions) { - auto &aggregate = expression->Cast(); - if (aggregate.IsDistinct()) { - // distinct aggregates are not supported in partitioned hash aggregates - return false; - } - } - // check if the source is partitioned by the aggregate columns - // figure out the columns we are grouping by - for (auto &group_expr : op.groups) { - // only support bound reference here - if (group_expr->GetExpressionType() != ExpressionType::BOUND_REF) { - return false; - } - auto &ref = group_expr->Cast(); - partition_columns.push_back(ref.index); - } - // traverse the children of the aggregate to find the source operator - reference child_ref(child); - while (child_ref.get().type != PhysicalOperatorType::TABLE_SCAN) { - auto &child_op = child_ref.get(); - switch (child_op.type) { - case PhysicalOperatorType::PROJECTION: { - // recompute partition columns - auto &projection = child_op.Cast(); - vector new_columns; - for (auto &partition_col : partition_columns) { - // we only support bound reference here - auto &expr = projection.select_list[partition_col]; - if (expr->GetExpressionType() != ExpressionType::BOUND_REF) { - return false; - } - auto &ref = expr->Cast(); - new_columns.push_back(ref.index); - } - // continue into child node with new columns - partition_columns = std::move(new_columns); - child_ref = *child_op.children[0]; - break; - } - case PhysicalOperatorType::FILTER: - // continue into child operators - child_ref = *child_op.children[0]; - break; - default: - // unsupported operator for partition pass-through - return false; - } - } - auto &table_scan = child_ref.get().Cast(); - if (!table_scan.function.get_partition_info) { - // this source does not expose partition information - skip - return false; - } - // get the base columns by projecting over the projection_ids/column_ids - if (!table_scan.projection_ids.empty()) { - for (auto &partition_col : partition_columns) { - partition_col = table_scan.projection_ids[partition_col]; - } - } - vector base_columns; - for (const auto &partition_idx : partition_columns) { - auto col_idx = partition_idx; - col_idx = table_scan.column_ids[col_idx].GetPrimaryIndex(); - base_columns.push_back(col_idx); - } - // check if the source operator is partitioned by the grouping columns - TableFunctionPartitionInput input(table_scan.bind_data.get(), base_columns); - auto partition_info = table_scan.function.get_partition_info(context, input); - if (partition_info != TablePartitionInfo::SINGLE_VALUE_PARTITIONS) { - // we only support single-value partitions currently - return false; - } - // we have single value partitions! 
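- // Here "single value" presumably means that each partition of the source
- // contains exactly one combination of the grouping columns, so no group
- // spans partitions and each partition can be aggregated and finalized
- // independently.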
-	return true;
-}
-
-static bool CanUsePerfectHashAggregate(ClientContext &context, LogicalAggregate &op, vector<idx_t> &bits_per_group) {
-	if (op.grouping_sets.size() > 1 || !op.grouping_functions.empty()) {
-		return false;
-	}
-	idx_t perfect_hash_bits = 0;
-	if (op.group_stats.empty()) {
-		op.group_stats.resize(op.groups.size());
-	}
-	for (idx_t group_idx = 0; group_idx < op.groups.size(); group_idx++) {
-		auto &group = op.groups[group_idx];
-		auto &stats = op.group_stats[group_idx];
-
-		switch (group->return_type.InternalType()) {
-		case PhysicalType::INT8:
-		case PhysicalType::INT16:
-		case PhysicalType::INT32:
-		case PhysicalType::INT64:
-		case PhysicalType::UINT8:
-		case PhysicalType::UINT16:
-		case PhysicalType::UINT32:
-		case PhysicalType::UINT64:
-			break;
-		default:
-			// we only support simple integer types for perfect hashing
-			return false;
-		}
-		// check if the group has stats available
-		auto &group_type = group->return_type;
-		if (!stats) {
-			// no stats, but we might still be able to use perfect hashing if the type is small enough
-			// for small types we can just set the stats to [type_min, type_max]
-			switch (group_type.InternalType()) {
-			case PhysicalType::INT8:
-			case PhysicalType::INT16:
-			case PhysicalType::UINT8:
-			case PhysicalType::UINT16:
-				break;
-			default:
-				// type is too large and there are no stats: skip perfect hashing
-				return false;
-			}
-			// construct stats with the min and max value of the type
-			stats = NumericStats::CreateUnknown(group_type).ToUnique();
-			NumericStats::SetMin(*stats, Value::MinimumValue(group_type));
-			NumericStats::SetMax(*stats, Value::MaximumValue(group_type));
-		}
-		auto &nstats = *stats;
-
-		if (!NumericStats::HasMinMax(nstats)) {
-			return false;
-		}
-
-		if (NumericStats::Max(*stats) < NumericStats::Min(*stats)) {
-			// May result in underflow
-			return false;
-		}
-
-		// we have a min and a max value for the stats: use that to figure out how many bits we have
-		// we add two here, one for the NULL value, and one to make the computation one-indexed
-		// (e.g. if min and max are the same, we still need one entry in total)
-		hugeint_t range_h;
-		switch (group_type.InternalType()) {
-		case PhysicalType::INT8:
-			range_h = GetRangeHugeint<int8_t>(nstats);
-			break;
-		case PhysicalType::INT16:
-			range_h = GetRangeHugeint<int16_t>(nstats);
-			break;
-		case PhysicalType::INT32:
-			range_h = GetRangeHugeint<int32_t>(nstats);
-			break;
-		case PhysicalType::INT64:
-			range_h = GetRangeHugeint<int64_t>(nstats);
-			break;
-		case PhysicalType::UINT8:
-			range_h = GetRangeHugeint<uint8_t>(nstats);
-			break;
-		case PhysicalType::UINT16:
-			range_h = GetRangeHugeint<uint16_t>(nstats);
-			break;
-		case PhysicalType::UINT32:
-			range_h = GetRangeHugeint<uint32_t>(nstats);
-			break;
-		case PhysicalType::UINT64:
-			range_h = GetRangeHugeint<uint64_t>(nstats);
-			break;
-		default:
-			throw InternalException("Unsupported type for perfect hash (should be caught before)");
-		}
-
-		uint64_t range;
-		if (!Hugeint::TryCast(range_h, range)) {
-			return false;
-		}
-
-		// bail out on any range bigger than 2^32
-		if (range >= NumericLimits<uint32_t>::Maximum()) {
-			return false;
-		}
-
-		range += 2;
-		// figure out how many bits we need
-		idx_t required_bits = RequiredBitsForValue(UnsafeNumericCast<uint32_t>(range));
-		bits_per_group.push_back(required_bits);
-		perfect_hash_bits += required_bits;
-		// check if we have exceeded the bits for the hash
-		if (perfect_hash_bits > ClientConfig::GetConfig(context).perfect_ht_threshold) {
-			// too many bits for perfect hash
-			return false;
-		}
-	}
-	for (auto &expression : op.expressions) {
-		auto &aggregate = expression->Cast<BoundAggregateExpression>();
-		if (aggregate.IsDistinct() || !aggregate.function.combine) {
-			// distinct aggregates are not supported in perfect hash aggregates
-			return false;
-		}
-	}
-	return true;
-}
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalAggregate &op) {
-	unique_ptr<PhysicalOperator> groupby;
-	D_ASSERT(op.children.size() == 1);
-
-	auto plan = CreatePlan(*op.children[0]);
-
-	plan = ExtractAggregateExpressions(std::move(plan), op.expressions, op.groups);
-
-	bool can_use_simple_aggregation = true;
-	for (auto &expression : op.expressions) {
-		auto &aggregate = expression->Cast<BoundAggregateExpression>();
-		if (!aggregate.function.simple_update) {
-			// unsupported aggregate for simple aggregation: use hash aggregation
-			can_use_simple_aggregation = false;
-			break;
-		}
-	}
-	if (op.groups.empty() && op.grouping_sets.size() <= 1) {
-		// no groups, check if we can use a simple aggregation
-		// special case: aggregate entire columns together
-		if (can_use_simple_aggregation) {
-			groupby = make_uniq_base<PhysicalOperator, PhysicalUngroupedAggregate>(op.types, std::move(op.expressions),
-			                                                                       op.estimated_cardinality);
-		} else {
-			groupby = make_uniq_base<PhysicalOperator, PhysicalHashAggregate>(
-			    context, op.types, std::move(op.expressions), op.estimated_cardinality);
-		}
-	} else {
-		// groups! create a GROUP BY aggregator
-		// use a partitioned or perfect hash aggregate if possible
-		vector<column_t> partition_columns;
-		vector<idx_t> required_bits;
-		if (can_use_simple_aggregation && CanUsePartitionedAggregate(context, op, *plan, partition_columns)) {
-			groupby = make_uniq_base<PhysicalOperator, PhysicalPartitionedAggregate>(
-			    context, op.types, std::move(op.expressions), std::move(op.groups), std::move(partition_columns),
-			    op.estimated_cardinality);
-		} else if (CanUsePerfectHashAggregate(context, op, required_bits)) {
-			groupby = make_uniq_base<PhysicalOperator, PhysicalPerfectHashAggregate>(
-			    context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.group_stats),
-			    std::move(required_bits), op.estimated_cardinality);
-		} else {
-			groupby = make_uniq_base<PhysicalOperator, PhysicalHashAggregate>(
-			    context, op.types, std::move(op.expressions), std::move(op.groups), std::move(op.grouping_sets),
-			    std::move(op.grouping_functions), op.estimated_cardinality);
-		}
-	}
-	groupby->children.push_back(std::move(plan));
-	return groupby;
-}
-
-unique_ptr<PhysicalOperator>
-PhysicalPlanGenerator::ExtractAggregateExpressions(unique_ptr<PhysicalOperator> child,
-                                                   vector<unique_ptr<Expression>> &aggregates,
-                                                   vector<unique_ptr<Expression>> &groups) {
-	vector<unique_ptr<Expression>> expressions;
-	vector<LogicalType> types;
-
-	// bind sorted aggregates
-	for (auto &aggr : aggregates) {
-		auto &bound_aggr = aggr->Cast<BoundAggregateExpression>();
-		if (bound_aggr.order_bys) {
-			// sorted aggregate!
-			FunctionBinder::BindSortedAggregate(context, bound_aggr, groups);
-		}
-	}
-	for (auto &group : groups) {
-		auto ref = make_uniq<BoundReferenceExpression>(group->return_type, expressions.size());
-		types.push_back(group->return_type);
-		expressions.push_back(std::move(group));
-		group = std::move(ref);
-	}
-	for (auto &aggr : aggregates) {
-		auto &bound_aggr = aggr->Cast<BoundAggregateExpression>();
-		for (auto &child : bound_aggr.children) {
-			auto ref = make_uniq<BoundReferenceExpression>(child->return_type, expressions.size());
-			types.push_back(child->return_type);
-			expressions.push_back(std::move(child));
-			child = std::move(ref);
-		}
-		if (bound_aggr.filter) {
-			auto &filter = bound_aggr.filter;
-			auto ref = make_uniq<BoundReferenceExpression>(filter->return_type, expressions.size());
-			types.push_back(filter->return_type);
-			expressions.push_back(std::move(filter));
-			bound_aggr.filter = std::move(ref);
-		}
-	}
-	if (expressions.empty()) {
-		return child;
-	}
-	auto projection =
-	    make_uniq<PhysicalProjection>(std::move(types), std::move(expressions), child->estimated_cardinality);
-	projection->children.push_back(std::move(child));
-	return std::move(projection);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_any_join.cpp b/src/duckdb/src/execution/physical_plan/plan_any_join.cpp
deleted file mode 100644
index 5e8ee6223..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_any_join.cpp
+++ /dev/null
@@ -1,20 +0,0 @@
-#include "duckdb/execution/operator/join/physical_blockwise_nl_join.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/planner/operator/logical_any_join.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalAnyJoin &op) {
-	// first visit the child nodes
-	D_ASSERT(op.children.size() == 2);
-	D_ASSERT(op.condition);
-
-	auto left = CreatePlan(*op.children[0]);
-	auto right = CreatePlan(*op.children[1]);
-
-	// create the blockwise NL join
-	return make_uniq<PhysicalBlockwiseNLJoin>(op, std::move(left), std::move(right), std::move(op.condition),
-	                                          op.join_type, op.estimated_cardinality);
-}
-
-} // namespace duckdb
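
Looking back at the perfect-hash sizing in the aggregate planner a little further up: the planner sums, over all group columns, the bits needed to address (max - min) + 2 slots, and falls back to a regular hash aggregate once that sum exceeds perfect_ht_threshold. A minimal standalone C++ sketch of the bit computation (illustrative only, not DuckDB code; the helper name is invented):

    #include <cstdint>

    // Number of bits needed to address `value` distinct slots; `value` already
    // includes the +2 adjustment for the NULL slot and one-indexing that the
    // planner above applies before calling RequiredBitsForValue.
    static uint8_t RequiredBits(uint64_t value) {
        uint8_t bits = 0;
        while (value > 0) { // position of the highest set bit
            bits++;
            value >>= 1;
        }
        return bits;
    }

    // Example: a group column with min 3 and max 10 spans a range of 8 values,
    // 8 + 2 = 10 slots -> 4 bits; three such groups need 12 bits in total.
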
diff --git a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp b/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp
deleted file mode 100644
index 927defa4f..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_asof_join.cpp
+++ /dev/null
@@ -1,125 +0,0 @@
-#include "duckdb/execution/operator/aggregate/physical_window.hpp"
-#include "duckdb/execution/operator/join/physical_asof_join.hpp"
-#include "duckdb/execution/operator/join/physical_iejoin.hpp"
-#include "duckdb/execution/operator/projection/physical_projection.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/planner/expression/bound_constant_expression.hpp"
-#include "duckdb/planner/expression/bound_reference_expression.hpp"
-#include "duckdb/planner/expression/bound_window_expression.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::PlanAsOfJoin(LogicalComparisonJoin &op) {
-	// now visit the children
-	D_ASSERT(op.children.size() == 2);
-	idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context);
-	idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context);
-	auto left = CreatePlan(*op.children[0]);
-	auto right = CreatePlan(*op.children[1]);
-	D_ASSERT(left && right);
-
-	// Validate
-	vector<idx_t> equi_indexes;
-	auto asof_idx = op.conditions.size();
-	for (size_t c = 0; c < op.conditions.size(); ++c) {
-		auto &cond = op.conditions[c];
-		switch (cond.comparison) {
-		case ExpressionType::COMPARE_EQUAL:
-		case ExpressionType::COMPARE_NOT_DISTINCT_FROM:
-			equi_indexes.emplace_back(c);
-			break;
-		case ExpressionType::COMPARE_GREATERTHANOREQUALTO:
-		case ExpressionType::COMPARE_GREATERTHAN:
-		case ExpressionType::COMPARE_LESSTHANOREQUALTO:
-		case ExpressionType::COMPARE_LESSTHAN:
-			D_ASSERT(asof_idx == op.conditions.size());
-			asof_idx = c;
-			break;
-		default:
-			throw InternalException("Invalid ASOF JOIN comparison");
-		}
-	}
-	D_ASSERT(asof_idx < op.conditions.size());
-
-	if (!ClientConfig::GetConfig(context).force_asof_iejoin) {
-		return make_uniq<PhysicalAsOfJoin>(op, std::move(left), std::move(right));
-	}
-
-	// Strip extra column from rhs projections
-	auto &right_projection_map = op.right_projection_map;
-	if (right_projection_map.empty()) {
-		const auto right_count = right->types.size();
-		right_projection_map.reserve(right_count);
-		for (column_t i = 0; i < right_count; ++i) {
-			right_projection_map.emplace_back(i);
-		}
-	}
-
-	// Debug implementation: IEJoin of Window
-	// LEAD(asof_column, 1, infinity) OVER (PARTITION BY equi_column... ORDER BY asof_column) AS asof_end
-	auto &asof_comp = op.conditions[asof_idx];
-	auto &asof_column = asof_comp.right;
-	auto asof_type = asof_column->return_type;
-	auto asof_end = make_uniq<BoundWindowExpression>(ExpressionType::WINDOW_LEAD, asof_type, nullptr, nullptr);
-	asof_end->children.emplace_back(asof_column->Copy());
-	// TODO: If infinities are not supported for a type, fake them by looking at LHS statistics?
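
An aside on the rewrite this hunk sets up: for the typical comparison lhs.ts >= rhs.ts, LEAD(rhs.ts, 1, +infinity) over the equality partition turns every right-hand row into a half-open validity interval [ts, next_ts), so the single ASOF condition becomes the pair of range conditions an IEJoin can evaluate. A standalone C++ sketch of that interval construction (illustrative only, not DuckDB code; names are invented):

    #include <cstddef>
    #include <limits>
    #include <utility>
    #include <vector>

    // For the ascending asof keys of one partition, pair each key with the next
    // key, i.e. LEAD(key, 1, +infinity): a probe value p matches entry i of a
    // ">="-style ASOF join iff intervals[i].first <= p < intervals[i].second.
    std::vector<std::pair<double, double>> BuildAsOfIntervals(const std::vector<double> &sorted_keys) {
        std::vector<std::pair<double, double>> intervals;
        intervals.reserve(sorted_keys.size());
        for (std::size_t i = 0; i < sorted_keys.size(); i++) {
            const double begin = sorted_keys[i];
            const double end =
                i + 1 < sorted_keys.size() ? sorted_keys[i + 1] : std::numeric_limits<double>::infinity();
            intervals.emplace_back(begin, end); // half-open interval [begin, end)
        }
        return intervals;
    }
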
- asof_end->offset_expr = make_uniq(Value::BIGINT(1)); - for (auto equi_idx : equi_indexes) { - asof_end->partitions.emplace_back(op.conditions[equi_idx].right->Copy()); - } - switch (asof_comp.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - case ExpressionType::COMPARE_GREATERTHAN: - asof_end->orders.emplace_back(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); - asof_end->default_expr = make_uniq(Value::Infinity(asof_type)); - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - case ExpressionType::COMPARE_LESSTHAN: - asof_end->orders.emplace_back(OrderType::DESCENDING, OrderByNullType::NULLS_FIRST, asof_column->Copy()); - asof_end->default_expr = make_uniq(Value::NegativeInfinity(asof_type)); - break; - default: - throw InternalException("Invalid ASOF JOIN ordering for WINDOW"); - } - - asof_end->start = WindowBoundary::UNBOUNDED_PRECEDING; - asof_end->end = WindowBoundary::CURRENT_ROW_ROWS; - - vector> window_select; - window_select.emplace_back(std::move(asof_end)); - - auto &window_types = op.children[1]->types; - window_types.emplace_back(asof_type); - - auto window = make_uniq(window_types, std::move(window_select), rhs_cardinality); - window->children.emplace_back(std::move(right)); - - // IEJoin(left, window, conditions || asof_comp ~op asof_end) - JoinCondition asof_upper; - asof_upper.left = asof_comp.left->Copy(); - asof_upper.right = make_uniq(asof_type, window_types.size() - 1); - switch (asof_comp.comparison) { - case ExpressionType::COMPARE_GREATERTHANOREQUALTO: - asof_upper.comparison = ExpressionType::COMPARE_LESSTHAN; - break; - case ExpressionType::COMPARE_GREATERTHAN: - asof_upper.comparison = ExpressionType::COMPARE_LESSTHANOREQUALTO; - break; - case ExpressionType::COMPARE_LESSTHANOREQUALTO: - asof_upper.comparison = ExpressionType::COMPARE_GREATERTHAN; - break; - case ExpressionType::COMPARE_LESSTHAN: - asof_upper.comparison = ExpressionType::COMPARE_GREATERTHANOREQUALTO; - break; - default: - throw InternalException("Invalid ASOF JOIN comparison for IEJoin"); - } - - op.conditions.emplace_back(std::move(asof_upper)); - - return make_uniq(op, std::move(left), std::move(window), std::move(op.conditions), op.join_type, - lhs_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp b/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp deleted file mode 100644 index 46305675b..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_column_data_get.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_column_data_get.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalColumnDataGet &op) { - D_ASSERT(op.children.size() == 0); - D_ASSERT(op.collection); - - return make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, op.estimated_cardinality, - std::move(op.collection)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp b/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp deleted file mode 100644 index e8d310519..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_comparison_join.cpp +++ /dev/null @@ -1,122 +0,0 @@ -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/execution/operator/join/perfect_hash_join_executor.hpp" -#include 
"duckdb/execution/operator/join/physical_blockwise_nl_join.hpp" -#include "duckdb/execution/operator/join/physical_cross_product.hpp" -#include "duckdb/execution/operator/join/physical_hash_join.hpp" -#include "duckdb/execution/operator/join/physical_iejoin.hpp" -#include "duckdb/execution/operator/join/physical_nested_loop_join.hpp" -#include "duckdb/execution/operator/join/physical_piecewise_merge_join.hpp" -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/function/table/table_scan.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/operator/logical_comparison_join.hpp" -#include "duckdb/transaction/duck_transaction.hpp" - -namespace duckdb { - -static void RewriteJoinCondition(Expression &expr, idx_t offset) { - if (expr.GetExpressionType() == ExpressionType::BOUND_REF) { - auto &ref = expr.Cast(); - ref.index += offset; - } - ExpressionIterator::EnumerateChildren(expr, [&](Expression &child) { RewriteJoinCondition(child, offset); }); -} - -unique_ptr PhysicalPlanGenerator::PlanComparisonJoin(LogicalComparisonJoin &op) { - // now visit the children - D_ASSERT(op.children.size() == 2); - idx_t lhs_cardinality = op.children[0]->EstimateCardinality(context); - idx_t rhs_cardinality = op.children[1]->EstimateCardinality(context); - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - left->estimated_cardinality = lhs_cardinality; - right->estimated_cardinality = rhs_cardinality; - D_ASSERT(left && right); - - if (op.conditions.empty()) { - // no conditions: insert a cross product - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); - } - - idx_t has_range = 0; - bool has_equality = op.HasEquality(has_range); - bool can_merge = has_range > 0; - bool can_iejoin = has_range >= 2 && recursive_cte_tables.empty(); - switch (op.join_type) { - case JoinType::SEMI: - case JoinType::ANTI: - case JoinType::RIGHT_ANTI: - case JoinType::RIGHT_SEMI: - case JoinType::MARK: - can_merge = can_merge && op.conditions.size() == 1; - can_iejoin = false; - break; - default: - break; - } - auto &client_config = ClientConfig::GetConfig(context); - - // TODO: Extend PWMJ to handle all comparisons and projection maps - const auto prefer_range_joins = client_config.prefer_range_joins && can_iejoin; - - unique_ptr plan; - if (has_equality && !prefer_range_joins) { - // Equality join with small number of keys : possible perfect join optimization - plan = make_uniq( - op, std::move(left), std::move(right), std::move(op.conditions), op.join_type, op.left_projection_map, - op.right_projection_map, std::move(op.mark_types), op.estimated_cardinality, std::move(op.filter_pushdown)); - plan->Cast().join_stats = std::move(op.join_stats); - } else { - D_ASSERT(op.left_projection_map.empty()); - if (left->estimated_cardinality <= client_config.nested_loop_join_threshold || - right->estimated_cardinality <= client_config.nested_loop_join_threshold) { - can_iejoin = false; - can_merge = false; - } - if (can_merge && can_iejoin) { - if (left->estimated_cardinality <= client_config.merge_join_threshold || - right->estimated_cardinality <= client_config.merge_join_threshold) { - can_iejoin = false; - } - } - if (can_iejoin) { - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, 
op.estimated_cardinality); - } else if (can_merge) { - // range join: use piecewise merge join - plan = - make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.estimated_cardinality); - } else if (PhysicalNestedLoopJoin::IsSupported(op.conditions, op.join_type)) { - // inequality join: use nested loop - plan = make_uniq(op, std::move(left), std::move(right), std::move(op.conditions), - op.join_type, op.estimated_cardinality); - } else { - for (auto &cond : op.conditions) { - RewriteJoinCondition(*cond.right, left->types.size()); - } - auto condition = JoinCondition::CreateExpression(std::move(op.conditions)); - plan = make_uniq(op, std::move(left), std::move(right), std::move(condition), - op.join_type, op.estimated_cardinality); - } - } - return plan; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalComparisonJoin &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - return PlanAsOfJoin(op); - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - return PlanComparisonJoin(op); - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - return PlanDelimJoin(op); - default: - throw InternalException("Unrecognized operator type for LogicalComparisonJoin"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_copy_database.cpp b/src/duckdb/src/execution/physical_plan/plan_copy_database.cpp deleted file mode 100644 index 68eaf5619..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_copy_database.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_copy_database.hpp" -#include "duckdb/execution/operator/persistent/physical_copy_database.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyDatabase &op) { - auto node = make_uniq(op.types, op.estimated_cardinality, std::move(op.info)); - return std::move(node); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp b/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp deleted file mode 100644 index a2981f61f..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_copy_to_file.cpp +++ /dev/null @@ -1,67 +0,0 @@ -#include "duckdb/execution/operator/persistent/physical_batch_copy_to_file.hpp" -#include "duckdb/execution/operator/persistent/physical_copy_to_file.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_copy_to_file.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCopyToFile &op) { - auto plan = CreatePlan(*op.children[0]); - bool preserve_insertion_order = PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); - bool supports_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); - auto &fs = FileSystem::GetFileSystem(context); - op.file_path = fs.ExpandPath(op.file_path); - if (op.use_tmp_file) { - auto path = StringUtil::GetFilePath(op.file_path); - auto base = StringUtil::GetFileName(op.file_path); - op.file_path = fs.JoinPath(path, "tmp_" + base); - } - if (op.per_thread_output || op.file_size_bytes.IsValid() || op.rotate || op.partition_output || - !op.partition_columns.empty() || op.overwrite_mode != CopyOverwriteMode::COPY_ERROR_ON_CONFLICT) { - // hive-partitioning/per-thread output does not care about insertion order, and does not support batch indexes - preserve_insertion_order = false; - supports_batch_index = false; - } - auto mode = 
CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; - if (op.function.execution_mode) { - mode = op.function.execution_mode(preserve_insertion_order, supports_batch_index); - } - if (mode == CopyFunctionExecutionMode::BATCH_COPY_TO_FILE) { - if (!supports_batch_index) { - throw InternalException("BATCH_COPY_TO_FILE can only be used if batch indexes are supported"); - } - // batched copy to file - auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), - op.estimated_cardinality); - copy->file_path = op.file_path; - copy->use_tmp_file = op.use_tmp_file; - copy->children.push_back(std::move(plan)); - copy->return_type = op.return_type; - return std::move(copy); - } - - // COPY from select statement to file - auto copy = make_uniq(op.types, op.function, std::move(op.bind_data), op.estimated_cardinality); - copy->file_path = op.file_path; - copy->use_tmp_file = op.use_tmp_file; - copy->overwrite_mode = op.overwrite_mode; - copy->filename_pattern = op.filename_pattern; - copy->file_extension = op.file_extension; - copy->per_thread_output = op.per_thread_output; - if (op.file_size_bytes.IsValid()) { - copy->file_size_bytes = op.file_size_bytes; - } - copy->rotate = op.rotate; - copy->return_type = op.return_type; - copy->partition_output = op.partition_output; - copy->partition_columns = op.partition_columns; - copy->write_partition_columns = op.write_partition_columns; - copy->names = op.names; - copy->expected_types = op.expected_types; - copy->parallel = mode == CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; - - copy->children.push_back(std::move(plan)); - return std::move(copy); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_create.cpp b/src/duckdb/src/execution/physical_plan/plan_create.cpp deleted file mode 100644 index 42a1652e6..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_create.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "duckdb/execution/operator/schema/physical_create_function.hpp" -#include "duckdb/execution/operator/schema/physical_create_schema.hpp" -#include "duckdb/execution/operator/schema/physical_create_sequence.hpp" -#include "duckdb/execution/operator/schema/physical_create_type.hpp" -#include "duckdb/execution/operator/schema/physical_create_view.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/parser/parsed_data/create_type_info.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/operator/logical_create.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCreate &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_CREATE_TYPE: { - unique_ptr create = make_uniq( - unique_ptr_cast(std::move(op.info)), op.estimated_cardinality); - if (!op.children.empty()) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - create->children.push_back(std::move(plan)); - } - return create; - } - default: - throw NotImplementedException("Unimplemented type for 
logical simple create");
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_create_index.cpp b/src/duckdb/src/execution/physical_plan/plan_create_index.cpp
deleted file mode 100644
index e2017c233..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_create_index.cpp
+++ /dev/null
@@ -1,45 +0,0 @@
-#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/database.hpp"
-#include "duckdb/planner/expression/bound_operator_expression.hpp"
-#include "duckdb/planner/expression/bound_reference_expression.hpp"
-#include "duckdb/planner/operator/logical_create_index.hpp"
-#include "duckdb/planner/operator/logical_get.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalCreateIndex &op) {
-	// Ensure that all expressions contain valid scalar functions.
-	// E.g., get_current_timestamp(), random(), and sequence values cannot be index keys.
-	for (idx_t i = 0; i < op.unbound_expressions.size(); i++) {
-		auto &expr = op.unbound_expressions[i];
-		if (!expr->IsConsistent()) {
-			throw BinderException("Index keys cannot contain expressions with side effects.");
-		}
-	}
-
-	// If we get here and the index type is not a valid index type, we throw an exception.
-	const auto index_type = context.db->config.GetIndexTypes().FindByName(op.info->index_type);
-	if (!index_type) {
-		throw BinderException("Unknown index type: " + op.info->index_type);
-	}
-	if (!index_type->create_plan) {
-		throw InternalException("Index type '%s' is missing a create_plan function", op.info->index_type);
-	}
-
-	// Add a dependency for the entire table on which we create the index.
-	dependencies.AddDependency(op.table);
-	D_ASSERT(op.info->scan_types.size() - 1 <= op.info->names.size());
-	D_ASSERT(op.info->scan_types.size() - 1 <= op.info->column_ids.size());
-
-	// Generate a physical plan for the parallel index creation.
-	// TABLE SCAN  - PROJECTION - (optional) NOT NULL FILTER - (optional) ORDER BY - CREATE INDEX
-	D_ASSERT(op.children.size() == 1);
-	auto table_scan = CreatePlan(*op.children[0]);
-
-	PlanIndexInput input(context, op, table_scan);
-	return index_type->create_plan(input);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_create_secret.cpp b/src/duckdb/src/execution/physical_plan/plan_create_secret.cpp
deleted file mode 100644
index b4818c809..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_create_secret.cpp
+++ /dev/null
@@ -1,11 +0,0 @@
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/planner/operator/logical_create_secret.hpp"
-#include "duckdb/execution/operator/helper/physical_create_secret.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalCreateSecret &op) {
-	return make_uniq<PhysicalCreateSecret>(op.info, op.estimated_cardinality);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_create_table.cpp b/src/duckdb/src/execution/physical_plan/plan_create_table.cpp
deleted file mode 100644
index 06b728ee4..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_create_table.cpp
+++ /dev/null
@@ -1,51 +0,0 @@
-#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
-#include "duckdb/catalog/duck_catalog.hpp"
-#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp"
-#include "duckdb/execution/operator/persistent/physical_insert.hpp"
-#include "duckdb/execution/operator/schema/physical_create_table.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/main/config.hpp"
-#include "duckdb/parallel/task_scheduler.hpp"
-#include "duckdb/parser/parsed_data/create_table_info.hpp"
-#include "duckdb/planner/constraints/bound_check_constraint.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-#include "duckdb/planner/operator/logical_create_table.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> DuckCatalog::PlanCreateTableAs(ClientContext &context, LogicalCreateTable &op,
-                                                            unique_ptr<PhysicalOperator> plan) {
-	bool parallel_streaming_insert = !PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan);
-	bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan);
-	auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads();
-	unique_ptr<PhysicalOperator> create;
-	if (!parallel_streaming_insert && use_batch_index) {
-		create = make_uniq<PhysicalBatchInsert>(op, op.schema, std::move(op.info), 0U);
-
-	} else {
-		create = make_uniq<PhysicalInsert>(op, op.schema, std::move(op.info), 0U,
-		                                   parallel_streaming_insert && num_threads > 1);
-	}
-
-	D_ASSERT(op.children.size() == 1);
-	create->children.push_back(std::move(plan));
-	return create;
-}
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalCreateTable &op) {
-	const auto &create_info = op.info->base->Cast<CreateTableInfo>();
-	auto &catalog = op.info->schema.catalog;
-	auto existing_entry = catalog.GetEntry<TableCatalogEntry>(context, create_info.schema, create_info.table,
-	                                                          OnEntryNotFound::RETURN_NULL);
-	bool replace = op.info->Base().on_conflict == OnCreateConflict::REPLACE_ON_CONFLICT;
-	if ((!existing_entry || replace) && !op.children.empty()) {
-		auto plan = CreatePlan(*op.children[0]);
-		return op.schema.catalog.PlanCreateTableAs(context, op, std::move(plan));
-	} else {
-		return make_uniq<PhysicalCreateTable>(op, op.schema, std::move(op.info), op.estimated_cardinality);
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp b/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp
deleted file mode 100644
index dac220706..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_cross_product.cpp
+++ /dev/null
@@ -1,15 +0,0 @@
-#include "duckdb/execution/operator/join/physical_cross_product.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/planner/operator/logical_cross_product.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalCrossProduct &op) {
-	D_ASSERT(op.children.size() == 2);
-
-	auto left = CreatePlan(*op.children[0]);
-	auto right = CreatePlan(*op.children[1]);
-	return make_uniq<PhysicalCrossProduct>(op.types, std::move(left), std::move(right), op.estimated_cardinality);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/execution/physical_plan/plan_cte.cpp b/src/duckdb/src/execution/physical_plan/plan_cte.cpp
deleted file mode 100644
index 9c6596279..000000000
--- a/src/duckdb/src/execution/physical_plan/plan_cte.cpp
+++ /dev/null
@@ -1,35 +0,0 @@
-#include "duckdb/common/types/column/column_data_collection.hpp"
-#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp"
-#include "duckdb/execution/operator/set/physical_cte.hpp"
-#include "duckdb/execution/physical_plan_generator.hpp"
-#include "duckdb/planner/expression/bound_reference_expression.hpp"
-#include "duckdb/planner/operator/logical_cteref.hpp"
-#include "duckdb/planner/operator/logical_materialized_cte.hpp"
-
-namespace duckdb {
-
-unique_ptr<PhysicalOperator> PhysicalPlanGenerator::CreatePlan(LogicalMaterializedCTE &op) {
-	D_ASSERT(op.children.size() == 2);
-
-	// Create the working_table that the PhysicalCTE will use for evaluation.
-	auto working_table = make_shared_ptr<ColumnDataCollection>(context, op.children[0]->types);
-
-	// Add the ColumnDataCollection to the context of this PhysicalPlanGenerator
-	recursive_cte_tables[op.table_index] = working_table;
-	materialized_ctes[op.table_index] = vector<const_reference<PhysicalOperator>>();
-
-	// Create the plan for the left side. This is the materialization.
-	auto left = CreatePlan(*op.children[0]);
-	// Initialize an empty vector to collect the scan operators.
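
An aside on the data flow in this hunk before the right child is planned: working_table is deliberately shared state. The left child materializes the CTE result into it exactly once, and every scan generated under the right child reads from that same collection. A toy standalone C++ sketch of the ownership pattern (illustrative only, not DuckDB code):

    #include <memory>
    #include <vector>

    // Stand-in for a ColumnDataCollection shared between producer and scans.
    struct WorkingTable {
        std::vector<int> rows;
    };

    int main() {
        auto working_table = std::make_shared<WorkingTable>();
        // "left" side: materialize the CTE result once
        working_table->rows = {1, 2, 3};
        // "right" side: every CTE scan holds a reference to the same buffer
        auto scan_a = working_table;
        auto scan_b = working_table;
        return static_cast<int>(scan_a->rows.size() + scan_b->rows.size()); // 6
    }
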
- auto right = CreatePlan(*op.children[1]); - - unique_ptr cte; - cte = make_uniq(op.ctename, op.table_index, right->types, std::move(left), std::move(right), - op.estimated_cardinality); - cte->working_table = working_table; - cte->cte_scans = materialized_ctes[op.table_index]; - - return std::move(cte); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delete.cpp b/src/duckdb/src/execution/physical_plan/plan_delete.cpp deleted file mode 100644 index d6afee808..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_delete.cpp +++ /dev/null @@ -1,32 +0,0 @@ -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/execution/operator/persistent/physical_delete.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_delete.hpp" -#include "duckdb/catalog/duck_catalog.hpp" - -namespace duckdb { - -unique_ptr DuckCatalog::PlanDelete(ClientContext &context, LogicalDelete &op, - unique_ptr plan) { - // get the index of the row_id column - auto &bound_ref = op.expressions[0]->Cast(); - - auto del = make_uniq(op.types, op.table, op.table.GetStorage(), std::move(op.bound_constraints), - bound_ref.index, op.estimated_cardinality, op.return_chunk); - del->children.push_back(std::move(plan)); - return std::move(del); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelete &op) { - D_ASSERT(op.children.size() == 1); - D_ASSERT(op.expressions.size() == 1); - D_ASSERT(op.expressions[0]->GetExpressionType() == ExpressionType::BOUND_REF); - - auto plan = CreatePlan(*op.children[0]); - - dependencies.AddDependency(op.table); - return op.table.catalog.PlanDelete(context, op, std::move(plan)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp deleted file mode 100644 index 1b45efe21..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_delim_get.cpp +++ /dev/null @@ -1,16 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_delim_get.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDelimGet &op) { - D_ASSERT(op.children.empty()); - - // create a PhysicalChunkScan without an owned_collection, the collection will be added later - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::DELIM_SCAN, - op.estimated_cardinality, nullptr); - return std::move(chunk_scan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp b/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp deleted file mode 100644 index 9755f8330..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_delim_join.cpp +++ /dev/null @@ -1,66 +0,0 @@ -#include "duckdb/common/enum_util.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/join/physical_hash_join.hpp" -#include "duckdb/execution/operator/join/physical_left_delim_join.hpp" -#include "duckdb/execution/operator/join/physical_right_delim_join.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -#include 
"duckdb/execution/operator/scan/physical_column_data_scan.hpp" - -namespace duckdb { - -static void GatherDelimScans(PhysicalOperator &op, vector> &delim_scans, - idx_t delim_index) { - if (op.type == PhysicalOperatorType::DELIM_SCAN) { - auto &scan = op.Cast(); - scan.delim_index = optional_idx(delim_index); - delim_scans.push_back(op); - } - for (auto &child : op.children) { - GatherDelimScans(*child, delim_scans, delim_index); - } -} - -unique_ptr PhysicalPlanGenerator::PlanDelimJoin(LogicalComparisonJoin &op) { - // first create the underlying join - auto plan = PlanComparisonJoin(op); - // this should create a join, not a cross product - D_ASSERT(plan && plan->type != PhysicalOperatorType::CROSS_PRODUCT); - // duplicate eliminated join - // first gather the scans on the duplicate eliminated data set from the delim side - const idx_t delim_idx = op.delim_flipped ? 0 : 1; - vector> delim_scans; - GatherDelimScans(*plan->children[delim_idx], delim_scans, ++this->delim_index); - if (delim_scans.empty()) { - // no duplicate eliminated scans in the delim side! - // in this case we don't need to create a delim join - // just push the normal join - return plan; - } - vector delim_types; - vector> distinct_groups, distinct_expressions; - for (auto &delim_expr : op.duplicate_eliminated_columns) { - D_ASSERT(delim_expr->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref = delim_expr->Cast(); - delim_types.push_back(bound_ref.return_type); - distinct_groups.push_back(make_uniq(bound_ref.return_type, bound_ref.index)); - } - // now create the duplicate eliminated join - unique_ptr delim_join; - if (op.delim_flipped) { - delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality, - optional_idx(this->delim_index)); - } else { - delim_join = make_uniq(op.types, std::move(plan), delim_scans, op.estimated_cardinality, - optional_idx(this->delim_index)); - } - // we still have to create the DISTINCT clause that is used to generate the duplicate eliminated chunk - delim_join->distinct = make_uniq(context, delim_types, std::move(distinct_expressions), - std::move(distinct_groups), op.estimated_cardinality); - - return std::move(delim_join); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp b/src/duckdb/src/execution/physical_plan/plan_distinct.cpp deleted file mode 100644 index 60fdb6584..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_distinct.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/function/aggregate/distributive_function_utils.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_distinct.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/optimizer/rule/ordered_aggregate_optimizer.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDistinct &op) { - D_ASSERT(op.children.size() == 1); - auto child = CreatePlan(*op.children[0]); - auto &distinct_targets = op.distinct_targets; - D_ASSERT(child); - D_ASSERT(!distinct_targets.empty()); - - auto &types = child->GetTypes(); - vector> groups, aggregates, projections; - idx_t group_count = distinct_targets.size(); - unordered_map group_by_references; - vector 
aggregate_types; - // creates one group per distinct_target - for (idx_t i = 0; i < distinct_targets.size(); i++) { - auto &target = distinct_targets[i]; - if (target->GetExpressionType() == ExpressionType::BOUND_REF) { - auto &bound_ref = target->Cast(); - group_by_references[bound_ref.index] = i; - } - aggregate_types.push_back(target->return_type); - groups.push_back(std::move(target)); - } - bool requires_projection = false; - if (types.size() != group_count) { - requires_projection = true; - } - // we need to create one aggregate per column in the select_list - for (idx_t i = 0; i < types.size(); ++i) { - auto logical_type = types[i]; - // check if we can directly refer to a group, or if we need to push an aggregate with FIRST - auto entry = group_by_references.find(i); - if (entry != group_by_references.end()) { - auto group_index = entry->second; - // entry is found: can directly refer to a group - projections.push_back(make_uniq(logical_type, group_index)); - if (group_index != i) { - // we require a projection only if this group element is out of order - requires_projection = true; - } - } else { - if (op.distinct_type == DistinctType::DISTINCT && op.order_by) { - throw InternalException("Entry that is not a group, but not a DISTINCT ON aggregate"); - } - // entry is not one of the groups: need to push a FIRST aggregate - auto bound = make_uniq(logical_type, i); - vector> first_children; - first_children.push_back(std::move(bound)); - - FunctionBinder function_binder(context); - auto first_aggregate = - function_binder.BindAggregateFunction(FirstFunctionGetter::GetFunction(logical_type), - std::move(first_children), nullptr, AggregateType::NON_DISTINCT); - first_aggregate->order_bys = op.order_by ? op.order_by->Copy() : nullptr; - - if (ClientConfig::GetConfig(context).enable_optimizer) { - bool changes_made = false; - auto new_expr = OrderedAggregateOptimizer::Apply(context, *first_aggregate, groups, changes_made); - if (new_expr) { - D_ASSERT(new_expr->return_type == first_aggregate->return_type); - D_ASSERT(new_expr->GetExpressionType() == ExpressionType::BOUND_AGGREGATE); - first_aggregate = unique_ptr_cast(std::move(new_expr)); - } - } - // add the projection - projections.push_back(make_uniq(logical_type, group_count + aggregates.size())); - // push it to the list of aggregates - aggregate_types.push_back(logical_type); - aggregates.push_back(std::move(first_aggregate)); - requires_projection = true; - } - } - - child = ExtractAggregateExpressions(std::move(child), aggregates, groups); - - // we add a physical hash aggregation in the plan to select the distinct groups - auto groupby = make_uniq(context, aggregate_types, std::move(aggregates), std::move(groups), - child->estimated_cardinality); - groupby->children.push_back(std::move(child)); - if (!requires_projection) { - return std::move(groupby); - } - - // we add a physical projection on top of the aggregation to project all members in the select list - auto aggr_projection = make_uniq(types, std::move(projections), groupby->estimated_cardinality); - aggr_projection->children.push_back(std::move(groupby)); - return std::move(aggr_projection); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp b/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp deleted file mode 100644 index ed561cf4e..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_dummy_scan.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" -#include 
"duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_dummy_scan.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalDummyScan &op) { - D_ASSERT(op.children.size() == 0); - return make_uniq(op.types, op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp b/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp deleted file mode 100644 index 7313f034a..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_empty_result.cpp +++ /dev/null @@ -1,12 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_empty_result.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_empty_result.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalEmptyResult &op) { - D_ASSERT(op.children.size() == 0); - return make_uniq(op.types, op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_execute.cpp b/src/duckdb/src/execution/physical_plan/plan_execute.cpp deleted file mode 100644 index 735cf9fc1..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_execute.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_execute.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_execute.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExecute &op) { - if (!op.prepared->plan) { - D_ASSERT(op.children.size() == 1); - auto owned_plan = CreatePlan(*op.children[0]); - auto execute = make_uniq(*owned_plan); - execute->owned_plan = std::move(owned_plan); - execute->prepared = std::move(op.prepared); - return std::move(execute); - } else { - D_ASSERT(op.children.size() == 0); - return make_uniq(*op.prepared->plan); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_explain.cpp b/src/duckdb/src/execution/physical_plan/plan_explain.cpp deleted file mode 100644 index 3f9d2a4b4..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_explain.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include "duckdb/common/tree_renderer.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/operator/helper/physical_explain_analyze.hpp" -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/operator/logical_explain.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExplain &op) { - D_ASSERT(op.children.size() == 1); - auto logical_plan_opt = op.children[0]->ToString(op.explain_format); - auto plan = CreatePlan(*op.children[0]); - if (op.explain_type == ExplainType::EXPLAIN_ANALYZE) { - auto result = make_uniq(op.types, op.explain_format); - result->children.push_back(std::move(plan)); - return std::move(result); - } - - op.physical_plan = plan->ToString(op.explain_format); - // the output of the explain - vector keys, values; - switch (ClientConfig::GetConfig(context).explain_output_type) { - case ExplainOutputType::OPTIMIZED_ONLY: - keys = {"logical_opt"}; - values = {logical_plan_opt}; - break; - case ExplainOutputType::PHYSICAL_ONLY: - keys = {"physical_plan"}; - values = {op.physical_plan}; - break; - default: - keys = {"logical_plan", "logical_opt", "physical_plan"}; - values = 
{op.logical_plan_unopt, logical_plan_opt, op.physical_plan}; - } - - // create a ColumnDataCollection from the output - auto &allocator = Allocator::Get(context); - vector plan_types {LogicalType::VARCHAR, LogicalType::VARCHAR}; - auto collection = - make_uniq(context, plan_types, ColumnDataAllocatorType::IN_MEMORY_ALLOCATOR); - - DataChunk chunk; - chunk.Initialize(allocator, op.types); - for (idx_t i = 0; i < keys.size(); i++) { - chunk.SetValue(0, chunk.size(), Value(keys[i])); - chunk.SetValue(1, chunk.size(), Value(values[i])); - chunk.SetCardinality(chunk.size() + 1); - if (chunk.size() == STANDARD_VECTOR_SIZE) { - collection->Append(chunk); - chunk.Reset(); - } - } - collection->Append(chunk); - - // create a chunk scan to output the result - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - op.estimated_cardinality, std::move(collection)); - return std::move(chunk_scan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_export.cpp b/src/duckdb/src/execution/physical_plan/plan_export.cpp deleted file mode 100644 index ff04115e3..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_export.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/execution/operator/persistent/physical_export.hpp" -#include "duckdb/planner/operator/logical_export.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExport &op) { - auto export_node = make_uniq(op.types, op.function, std::move(op.copy_info), - op.estimated_cardinality, std::move(op.exported_tables)); - // plan the underlying copy statements, if any - if (!op.children.empty()) { - auto plan = CreatePlan(*op.children[0]); - export_node->children.push_back(std::move(plan)); - } - return std::move(export_node); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp b/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp deleted file mode 100644 index b63726742..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_expression_get.cpp +++ /dev/null @@ -1,38 +0,0 @@ -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_expression_get.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalExpressionGet &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - - auto expr_scan = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - expr_scan->children.push_back(std::move(plan)); - if (!expr_scan->IsFoldable()) { - return std::move(expr_scan); - } - auto &allocator = Allocator::Get(context); - // simple expression scan (i.e. 
no subqueries to evaluate and no prepared statement parameters) - // we can evaluate all the expressions right now and turn this into a chunk collection scan - auto chunk_scan = make_uniq(op.types, PhysicalOperatorType::COLUMN_DATA_SCAN, - expr_scan->expressions.size(), - make_uniq(context, op.types)); - - DataChunk chunk; - chunk.Initialize(allocator, op.types); - - ColumnDataAppendState append_state; - chunk_scan->collection->InitializeAppend(append_state); - for (idx_t expression_idx = 0; expression_idx < expr_scan->expressions.size(); expression_idx++) { - chunk.Reset(); - expr_scan->EvaluateExpression(context, expression_idx, nullptr, chunk); - chunk_scan->collection->Append(append_state, chunk); - } - return std::move(chunk_scan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_filter.cpp b/src/duckdb/src/execution/physical_plan/plan_filter.cpp deleted file mode 100644 index 50c1253d0..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_filter.cpp +++ /dev/null @@ -1,36 +0,0 @@ -#include "duckdb/execution/operator/filter/physical_filter.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_filter.hpp" -#include "duckdb/planner/operator/logical_get.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalFilter &op) { - D_ASSERT(op.children.size() == 1); - unique_ptr plan = CreatePlan(*op.children[0]); - if (!op.expressions.empty()) { - D_ASSERT(plan->types.size() > 0); - // create a filter if there is anything to filter - auto filter = make_uniq(plan->types, std::move(op.expressions), op.estimated_cardinality); - filter->children.push_back(std::move(plan)); - plan = std::move(filter); - } - if (op.HasProjectionMap()) { - // there is a projection map, generate a physical projection - vector> select_list; - for (idx_t i = 0; i < op.projection_map.size(); i++) { - select_list.push_back(make_uniq(op.types[i], op.projection_map[i])); - } - auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); - proj->children.push_back(std::move(plan)); - plan = std::move(proj); - } - return plan; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_get.cpp b/src/duckdb/src/execution/physical_plan/plan_get.cpp deleted file mode 100644 index 3b5d940eb..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_get.cpp +++ /dev/null @@ -1,195 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_expression_scan.hpp" -#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/operator/projection/physical_tableinout_function.hpp" -#include "duckdb/execution/operator/scan/physical_table_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/function/table/table_scan.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include 
"duckdb/planner/expression/bound_conjunction_expression.hpp" -#include "duckdb/execution/operator/filter/physical_filter.hpp" - -namespace duckdb { - -unique_ptr CreateTableFilterSet(TableFilterSet &table_filters, const vector &column_ids) { - // create the table filter map - auto table_filter_set = make_uniq(); - for (auto &table_filter : table_filters.filters) { - // find the relative column index from the absolute column index into the table - optional_idx column_index; - for (idx_t i = 0; i < column_ids.size(); i++) { - if (table_filter.first == column_ids[i].GetPrimaryIndex()) { - column_index = i; - break; - } - } - if (!column_index.IsValid()) { - throw InternalException("Could not find column index for table filter"); - } - table_filter_set->filters[column_index.GetIndex()] = std::move(table_filter.second); - } - return table_filter_set; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalGet &op) { - auto column_ids = op.GetColumnIds(); - if (!op.children.empty()) { - auto child_node = CreatePlan(std::move(op.children[0])); - // this is for table producing functions that consume subquery results - // push a projection node with casts if required - if (child_node->types.size() < op.input_table_types.size()) { - throw InternalException( - "Mismatch between input table types and child node types - expected %llu but got %llu", - op.input_table_types.size(), child_node->types.size()); - } - vector return_types; - vector> expressions; - bool any_cast_required = false; - for (idx_t proj_idx = 0; proj_idx < child_node->types.size(); proj_idx++) { - auto ref = make_uniq(child_node->types[proj_idx], proj_idx); - auto &target_type = - proj_idx < op.input_table_types.size() ? op.input_table_types[proj_idx] : child_node->types[proj_idx]; - if (child_node->types[proj_idx] != target_type) { - // cast is required - push a cast - any_cast_required = true; - auto cast = BoundCastExpression::AddCastToType(context, std::move(ref), target_type); - expressions.push_back(std::move(cast)); - } else { - expressions.push_back(std::move(ref)); - } - return_types.push_back(target_type); - } - if (any_cast_required) { - auto proj = make_uniq(std::move(return_types), std::move(expressions), - child_node->estimated_cardinality); - proj->children.push_back(std::move(child_node)); - child_node = std::move(proj); - } - - auto node = make_uniq(op.types, op.function, std::move(op.bind_data), column_ids, - op.estimated_cardinality, std::move(op.projected_input)); - node->children.push_back(std::move(child_node)); - return std::move(node); - } - if (!op.projected_input.empty()) { - throw InternalException("LogicalGet::project_input can only be set for table-in-out functions"); - } - - unique_ptr table_filters; - if (!op.table_filters.filters.empty()) { - table_filters = CreateTableFilterSet(op.table_filters, column_ids); - } - - if (op.function.dependency) { - op.function.dependency(dependencies, op.bind_data.get()); - } - unique_ptr filter; - - auto &projection_ids = op.projection_ids; - - if (table_filters && op.function.supports_pushdown_type) { - vector> select_list; - unique_ptr unsupported_filter; - unordered_set to_remove; - for (auto &entry : table_filters->filters) { - auto column_id = column_ids[entry.first].GetPrimaryIndex(); - auto &type = op.returned_types[column_id]; - if (!op.function.supports_pushdown_type(type)) { - idx_t column_id_filter = entry.first; - bool found_projection = false; - for (idx_t i = 0; i < projection_ids.size(); i++) { - if (column_ids[projection_ids[i]] == 
column_ids[entry.first]) { - column_id_filter = i; - found_projection = true; - break; - } - } - if (!found_projection) { - projection_ids.push_back(entry.first); - column_id_filter = projection_ids.size() - 1; - } - auto column = make_uniq(type, column_id_filter); - select_list.push_back(entry.second->ToExpression(*column)); - to_remove.insert(entry.first); - } - } - for (auto &col : to_remove) { - table_filters->filters.erase(col); - } - - if (!select_list.empty()) { - vector filter_types; - for (auto &c : projection_ids) { - auto column_id = column_ids[c].GetPrimaryIndex(); - filter_types.push_back(op.returned_types[column_id]); - } - filter = make_uniq(filter_types, std::move(select_list), op.estimated_cardinality); - } - } - op.ResolveOperatorTypes(); - // create the table scan node - if (!op.function.projection_pushdown) { - // function does not support projection pushdown - auto node = make_uniq( - op.returned_types, op.function, std::move(op.bind_data), op.returned_types, column_ids, vector(), - op.names, std::move(table_filters), op.estimated_cardinality, op.extra_info, std::move(op.parameters)); - // first check if an additional projection is necessary - if (column_ids.size() == op.returned_types.size()) { - bool projection_necessary = false; - for (idx_t i = 0; i < column_ids.size(); i++) { - if (column_ids[i].GetPrimaryIndex() != i) { - projection_necessary = true; - break; - } - } - if (!projection_necessary) { - // a projection is not necessary if all columns have been requested in-order - // in that case we just return the node - if (filter) { - filter->children.push_back(std::move(node)); - return std::move(filter); - } - return std::move(node); - } - } - // push a projection on top that does the projection - vector types; - vector> expressions; - for (auto &column_id : column_ids) { - if (column_id.IsRowIdColumn()) { - types.emplace_back(op.GetRowIdType()); - // Now how to make that a constant expression. 
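
An aside on the in-order check earlier in this hunk: when a scan function cannot push down projections, a projection operator is needed on top of it unless the requested column ids are exactly 0..n-1 over all n returned columns. A standalone C++ sketch of that decision (illustrative only, not DuckDB code; names are invented):

    #include <cstddef>
    #include <vector>

    // A projection on top of the scan can be skipped only when every column is
    // requested, in table order; any subset or reordering needs a projection.
    bool NeedsProjection(const std::vector<std::size_t> &column_ids, std::size_t table_width) {
        if (column_ids.size() != table_width) {
            return true; // only a subset of the columns was requested
        }
        for (std::size_t i = 0; i < column_ids.size(); i++) {
            if (column_ids[i] != i) {
                return true; // columns requested out of order
            }
        }
        return false;
    }
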
- expressions.push_back(make_uniq(Value(op.GetRowIdType()))); - } else { - auto col_id = column_id.GetPrimaryIndex(); - auto type = op.returned_types[col_id]; - types.push_back(type); - expressions.push_back(make_uniq(type, col_id)); - } - } - unique_ptr projection = - make_uniq(std::move(types), std::move(expressions), op.estimated_cardinality); - if (filter) { - filter->children.push_back(std::move(node)); - projection->children.push_back(std::move(filter)); - } else { - projection->children.push_back(std::move(node)); - } - return std::move(projection); - } else { - auto node = make_uniq(op.types, op.function, std::move(op.bind_data), op.returned_types, - column_ids, op.projection_ids, op.names, std::move(table_filters), - op.estimated_cardinality, op.extra_info, std::move(op.parameters)); - node->dynamic_filters = op.dynamic_filters; - if (filter) { - filter->children.push_back(std::move(node)); - return std::move(filter); - } - return std::move(node); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_insert.cpp b/src/duckdb/src/execution/physical_plan/plan_insert.cpp deleted file mode 100644 index 2342a1ef0..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_insert.cpp +++ /dev/null @@ -1,121 +0,0 @@ -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/execution/operator/persistent/physical_insert.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_insert.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/execution/operator/persistent/physical_batch_insert.hpp" -#include "duckdb/parallel/task_scheduler.hpp" -#include "duckdb/catalog/duck_catalog.hpp" - -namespace duckdb { - -static OrderPreservationType OrderPreservationRecursive(PhysicalOperator &op) { - if (op.IsSource()) { - return op.SourceOrder(); - } - - idx_t child_idx = 0; - for (auto &child : op.children) { - // Do not take the materialization phase of physical CTEs into account - if (op.type == PhysicalOperatorType::CTE && child_idx == 0) { - continue; - } - auto child_preservation = OrderPreservationRecursive(*child); - if (child_preservation != OrderPreservationType::INSERTION_ORDER) { - return child_preservation; - } - child_idx++; - } - return OrderPreservationType::INSERTION_ORDER; -} - -bool PhysicalPlanGenerator::PreserveInsertionOrder(ClientContext &context, PhysicalOperator &plan) { - auto &config = DBConfig::GetConfig(context); - - auto preservation_type = OrderPreservationRecursive(plan); - if (preservation_type == OrderPreservationType::FIXED_ORDER) { - // always need to maintain preservation order - return true; - } - if (preservation_type == OrderPreservationType::NO_ORDER) { - // never need to preserve order - return false; - } - // preserve insertion order - check flags - if (!config.options.preserve_insertion_order) { - // preserving insertion order is disabled by config - return false; - } - return true; -} - -bool PhysicalPlanGenerator::PreserveInsertionOrder(PhysicalOperator &plan) { - return PreserveInsertionOrder(context, plan); -} - -bool PhysicalPlanGenerator::UseBatchIndex(ClientContext &context, PhysicalOperator &plan) { - // TODO: always preserve order if query contains ORDER BY - auto &scheduler = TaskScheduler::GetScheduler(context); - if (scheduler.NumberOfThreads() == 1) { - // batch index usage only makes sense if we are using multiple threads - return false; - } - if (!plan.AllSourcesSupportBatchIndex()) { - // 
batch index is not supported - return false; - } - return true; -} - -bool PhysicalPlanGenerator::UseBatchIndex(PhysicalOperator &plan) { - return UseBatchIndex(context, plan); -} - -unique_ptr DuckCatalog::PlanInsert(ClientContext &context, LogicalInsert &op, - unique_ptr plan) { - bool parallel_streaming_insert = !PhysicalPlanGenerator::PreserveInsertionOrder(context, *plan); - bool use_batch_index = PhysicalPlanGenerator::UseBatchIndex(context, *plan); - auto num_threads = TaskScheduler::GetScheduler(context).NumberOfThreads(); - if (op.return_chunk) { - // not supported for RETURNING (yet?) - parallel_streaming_insert = false; - use_batch_index = false; - } - if (op.action_type != OnConflictAction::THROW) { - // We don't support ON CONFLICT clause in batch insertion operation currently - use_batch_index = false; - } - if (op.action_type == OnConflictAction::UPDATE) { - // When we potentially need to perform updates, we have to check that row is not updated twice - // that currently needs to be done for every chunk, which would add a huge bottleneck to parallelized insertion - parallel_streaming_insert = false; - } - unique_ptr insert; - if (use_batch_index && !parallel_streaming_insert) { - insert = make_uniq(op.types, op.table, op.column_index_map, std::move(op.bound_defaults), - std::move(op.bound_constraints), op.estimated_cardinality); - } else { - insert = make_uniq( - op.types, op.table, op.column_index_map, std::move(op.bound_defaults), std::move(op.bound_constraints), - std::move(op.expressions), std::move(op.set_columns), std::move(op.set_types), op.estimated_cardinality, - op.return_chunk, parallel_streaming_insert && num_threads > 1, op.action_type, - std::move(op.on_conflict_condition), std::move(op.do_update_condition), std::move(op.on_conflict_filter), - std::move(op.columns_to_fetch), op.update_is_del_and_insert); - } - D_ASSERT(plan); - insert->children.push_back(std::move(plan)); - return insert; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalInsert &op) { - unique_ptr plan; - if (!op.children.empty()) { - D_ASSERT(op.children.size() == 1); - plan = CreatePlan(*op.children[0]); - } - dependencies.AddDependency(op.table); - return op.table.catalog.PlanInsert(context, op, std::move(plan)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_limit.cpp b/src/duckdb/src/execution/physical_plan/plan_limit.cpp deleted file mode 100644 index 508f1c889..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_limit.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_limit.hpp" -#include "duckdb/execution/operator/helper/physical_streaming_limit.hpp" -#include "duckdb/execution/operator/helper/physical_limit_percent.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/planner/operator/logical_limit.hpp" - -namespace duckdb { - -bool UseBatchLimit(PhysicalOperator &child_node, BoundLimitNode &limit_val, BoundLimitNode &offset_val) { -#ifdef DUCKDB_ALTERNATIVE_VERIFY - return true; -#else - // we only want to use the batch limit when we are executing a complex query (e.g. 
involving a filter or join) - // if we are doing a limit over a table scan we are otherwise scanning a lot of rows just to throw them away - reference current_ref(child_node); - bool finished = false; - while (!finished) { - auto &current_op = current_ref.get(); - switch (current_op.type) { - case PhysicalOperatorType::TABLE_SCAN: - return false; - case PhysicalOperatorType::PROJECTION: - current_ref = *current_op.children[0]; - break; - default: - finished = true; - break; - } - } - // we only use batch limit when we are computing a small amount of values - // as the batch limit materializes this many rows PER thread - static constexpr const idx_t BATCH_LIMIT_THRESHOLD = 10000; - - if (limit_val.Type() != LimitNodeType::CONSTANT_VALUE) { - return false; - } - if (offset_val.Type() == LimitNodeType::EXPRESSION_VALUE) { - return false; - } - idx_t total_offset = limit_val.GetConstantValue(); - if (offset_val.Type() == LimitNodeType::CONSTANT_VALUE) { - total_offset += offset_val.GetConstantValue(); - } - return total_offset <= BATCH_LIMIT_THRESHOLD; -#endif -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalLimit &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - unique_ptr limit; - switch (op.limit_val.Type()) { - case LimitNodeType::EXPRESSION_PERCENTAGE: - case LimitNodeType::CONSTANT_PERCENTAGE: - limit = make_uniq(op.types, std::move(op.limit_val), std::move(op.offset_val), - op.estimated_cardinality); - break; - default: - if (!PreserveInsertionOrder(*plan)) { - // use parallel streaming limit if insertion order is not important - limit = make_uniq(op.types, std::move(op.limit_val), std::move(op.offset_val), - op.estimated_cardinality, true); - } else { - // maintaining insertion order is important - if (UseBatchIndex(*plan) && UseBatchLimit(*plan, op.limit_val, op.offset_val)) { - // source supports batch index: use parallel batch limit - limit = make_uniq(op.types, std::move(op.limit_val), std::move(op.offset_val), - op.estimated_cardinality); - } else { - // source does not support batch index: use a non-parallel streaming limit - limit = make_uniq(op.types, std::move(op.limit_val), std::move(op.offset_val), - op.estimated_cardinality, false); - } - } - break; - } - - limit->children.push_back(std::move(plan)); - return limit; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_order.cpp b/src/duckdb/src/execution/physical_plan/plan_order.cpp deleted file mode 100644 index 449e7ee81..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_order.cpp +++ /dev/null @@ -1,28 +0,0 @@ -#include "duckdb/execution/operator/order/physical_order.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_order.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOrder &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - if (!op.orders.empty()) { - vector projection_map; - if (op.HasProjectionMap()) { - projection_map = std::move(op.projection_map); - } else { - for (idx_t i = 0; i < plan->types.size(); i++) { - projection_map.push_back(i); - } - } - auto order = make_uniq(op.types, std::move(op.orders), std::move(projection_map), - op.estimated_cardinality); - order->children.push_back(std::move(plan)); - plan = std::move(order); - } - return plan; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_pivot.cpp b/src/duckdb/src/execution/physical_plan/plan_pivot.cpp
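
Stepping back to plan_limit.cpp above: the operator choice reduces to a small decision tree. A compact restatement, using illustrative enums and flags rather than DuckDB's operator classes (the real UseBatchLimit additionally refuses non-constant limits and bare table scans under projections, as shown above):

```cpp
// Sketch of the LIMIT operator selection: percentage limits always use the
// percent operator; when order need not be preserved, a parallel streaming
// limit; otherwise a batch limit only when the source supports batch indexes
// and limit+offset is small enough to materialize per thread.
#include <cstdint>
#include <iostream>

enum class LimitKind { PERCENT, BATCH, STREAMING_PARALLEL, STREAMING_SERIAL };

LimitKind ChooseLimitOperator(bool is_percent, bool preserve_order, bool batch_index_ok,
                              uint64_t limit_plus_offset) {
	constexpr uint64_t BATCH_LIMIT_THRESHOLD = 10000; // per-thread materialization cap
	if (is_percent) {
		return LimitKind::PERCENT;
	}
	if (!preserve_order) {
		return LimitKind::STREAMING_PARALLEL;
	}
	if (batch_index_ok && limit_plus_offset <= BATCH_LIMIT_THRESHOLD) {
		return LimitKind::BATCH;
	}
	return LimitKind::STREAMING_SERIAL;
}

int main() {
	std::cout << static_cast<int>(ChooseLimitOperator(false, true, true, 100)) << '\n';      // BATCH
	std::cout << static_cast<int>(ChooseLimitOperator(false, true, true, 1u << 20)) << '\n'; // STREAMING_SERIAL
}
```
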
deleted file mode 100644 index bca3bc909..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_pivot.cpp +++ /dev/null @@ -1,14 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_pivot.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_pivot.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPivot &op) { - D_ASSERT(op.children.size() == 1); - auto child_plan = CreatePlan(*op.children[0]); - auto pivot = make_uniq(std::move(op.types), std::move(child_plan), std::move(op.bound_pivot)); - return std::move(pivot); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp b/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp deleted file mode 100644 index 6fd20ed4e..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_positional_join.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include "duckdb/execution/operator/join/physical_positional_join.hpp" -#include "duckdb/execution/operator/scan/physical_positional_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_positional_join.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPositionalJoin &op) { - D_ASSERT(op.children.size() == 2); - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - switch (left->type) { - case PhysicalOperatorType::TABLE_SCAN: - case PhysicalOperatorType::POSITIONAL_SCAN: - switch (right->type) { - case PhysicalOperatorType::TABLE_SCAN: - case PhysicalOperatorType::POSITIONAL_SCAN: - return make_uniq(op.types, std::move(left), std::move(right)); - default: - break; - } - break; - default: - break; - } - - return make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_pragma.cpp b/src/duckdb/src/execution/physical_plan/plan_pragma.cpp deleted file mode 100644 index c0303965f..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_pragma.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_pragma.hpp" - -#include "duckdb/execution/operator/helper/physical_pragma.hpp" -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPragma &op) { - return make_uniq(std::move(op.info), op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_prepare.cpp b/src/duckdb/src/execution/physical_plan/plan_prepare.cpp deleted file mode 100644 index 191377b1f..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_prepare.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "duckdb/execution/operator/scan/physical_dummy_scan.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_prepare.hpp" -#include "duckdb/execution/operator/helper/physical_prepare.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalPrepare &op) { - D_ASSERT(op.children.size() <= 1); - - // generate physical plan only when all parameters are bound (otherwise the physical plan won't be used anyway) - if (op.prepared->properties.bound_all_parameters && !op.children.empty()) { - auto plan = CreatePlan(*op.children[0]); - op.prepared->types = plan->types; - op.prepared->plan = std::move(plan); - } - - return make_uniq(op.name, std::move(op.prepared), 
op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_projection.cpp b/src/duckdb/src/execution/physical_plan/plan_projection.cpp deleted file mode 100644 index f5f262f55..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_projection.cpp +++ /dev/null @@ -1,44 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_projection.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalProjection &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - -#ifdef DEBUG - for (auto &expr : op.expressions) { - D_ASSERT(!expr->IsWindow()); - D_ASSERT(!expr->IsAggregate()); - } -#endif - if (plan->types.size() == op.types.size()) { - // check if this projection can be omitted entirely - // this happens if a projection simply emits the columns in the same order - // e.g. PROJECTION(#0, #1, #2, #3, ...) - bool omit_projection = true; - for (idx_t i = 0; i < op.types.size(); i++) { - if (op.expressions[i]->GetExpressionType() == ExpressionType::BOUND_REF) { - auto &bound_ref = op.expressions[i]->Cast(); - if (bound_ref.index == i) { - continue; - } - } - omit_projection = false; - break; - } - if (omit_projection) { - // the projection only directly projects the child' columns: omit it entirely - return plan; - } - } - - auto projection = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - projection->children.push_back(std::move(plan)); - return std::move(projection); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp b/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp deleted file mode 100644 index 7933dd166..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_recursive_cte.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/operator/scan/physical_column_data_scan.hpp" -#include "duckdb/execution/operator/set/physical_recursive_cte.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/operator/logical_cteref.hpp" -#include "duckdb/planner/operator/logical_recursive_cte.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalRecursiveCTE &op) { - D_ASSERT(op.children.size() == 2); - - // Create the working_table that the PhysicalRecursiveCTE will use for evaluation. - auto working_table = make_shared_ptr(context, op.types); - - // Add the ColumnDataCollection to the context of this PhysicalPlanGenerator - recursive_cte_tables[op.table_index] = working_table; - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - - auto cte = make_uniq(op.ctename, op.table_index, op.types, op.union_all, std::move(left), - std::move(right), op.estimated_cardinality); - cte->working_table = working_table; - - return std::move(cte); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalCTERef &op) { - D_ASSERT(op.children.empty()); - - // Check if this LogicalCTERef is supposed to scan a materialized CTE. - if (op.materialized_cte == CTEMaterialize::CTE_MATERIALIZE_ALWAYS) { - // Lookup if there is a materialized CTE for the cte_index. 
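
The lookup that follows relies on a registry pattern: CreatePlan(LogicalRecursiveCTE) registers the shared working table under its table index, and every later CTE reference resolves it by index or fails. The pattern in isolation, with stand-in types (WorkingTable is illustrative, not DuckDB's ColumnDataCollection):

```cpp
// Sketch of the CTE registry: register a shared collection under an index at
// plan time, resolve it by index at every reference, error if missing.
#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <unordered_map>
#include <vector>

using WorkingTable = std::vector<int>; // stand-in for ColumnDataCollection

int main() {
	std::unordered_map<uint64_t, std::shared_ptr<WorkingTable>> recursive_cte_tables;

	// CreatePlan(LogicalRecursiveCTE): create and register the working table.
	uint64_t table_index = 3;
	recursive_cte_tables[table_index] = std::make_shared<WorkingTable>();

	// CreatePlan(LogicalCTERef): the scan must find the registered table.
	auto it = recursive_cte_tables.find(table_index);
	if (it == recursive_cte_tables.end()) {
		throw std::runtime_error("Referenced recursive CTE does not exist.");
	}
	std::cout << "found working table with " << it->second->size() << " rows\n";
}
```
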
- auto materialized_cte = materialized_ctes.find(op.cte_index); - - // If this check fails, this is a reference to a materialized recursive CTE. - if (materialized_cte != materialized_ctes.end()) { - auto chunk_scan = make_uniq(op.chunk_types, PhysicalOperatorType::CTE_SCAN, - op.estimated_cardinality, op.cte_index); - - auto cte = recursive_cte_tables.find(op.cte_index); - if (cte == recursive_cte_tables.end()) { - throw InvalidInputException("Referenced materialized CTE does not exist."); - } - chunk_scan->collection = cte->second.get(); - materialized_cte->second.push_back(*chunk_scan.get()); - - return std::move(chunk_scan); - } - } - - // CreatePlan of a LogicalRecursiveCTE must have happened before. - auto cte = recursive_cte_tables.find(op.cte_index); - if (cte == recursive_cte_tables.end()) { - throw InvalidInputException("Referenced recursive CTE does not exist."); - } - - auto chunk_scan = make_uniq( - cte->second.get()->Types(), PhysicalOperatorType::RECURSIVE_CTE_SCAN, op.estimated_cardinality, op.cte_index); - - chunk_scan->collection = cte->second.get(); - - return std::move(chunk_scan); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_reset.cpp b/src/duckdb/src/execution/physical_plan/plan_reset.cpp deleted file mode 100644 index 2f8aa3698..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_reset.cpp +++ /dev/null @@ -1,11 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_reset.hpp" -#include "duckdb/execution/operator/helper/physical_reset.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalReset &op) { - return make_uniq(op.name, op.scope, op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_sample.cpp b/src/duckdb/src/execution/physical_plan/plan_sample.cpp deleted file mode 100644 index be5578477..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_sample.cpp +++ /dev/null @@ -1,42 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_reservoir_sample.hpp" -#include "duckdb/execution/operator/helper/physical_streaming_sample.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_sample.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/random_engine.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSample &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - unique_ptr sample; - if (!op.sample_options->seed.IsValid()) { - auto &random_engine = RandomEngine::Get(context); - op.sample_options->SetSeed(random_engine.NextRandomInteger()); - } - switch (op.sample_options->method) { - case SampleMethod::RESERVOIR_SAMPLE: - sample = make_uniq(op.types, std::move(op.sample_options), op.estimated_cardinality); - break; - case SampleMethod::SYSTEM_SAMPLE: - case SampleMethod::BERNOULLI_SAMPLE: - if (!op.sample_options->is_percentage) { - throw ParserException("Sample method %s cannot be used with a discrete sample count, either switch to " - "reservoir sampling or use a sample_size", - EnumUtil::ToString(op.sample_options->method)); - } - sample = make_uniq( - op.types, op.sample_options->method, op.sample_options->sample_size.GetValue(), - static_cast(op.sample_options->seed.GetIndex()), op.estimated_cardinality); - break; - default: - throw InternalException("Unimplemented sample method"); - } - 
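
The switch above separates blocking reservoir sampling (a fixed-size result must be buffered) from streaming percentage sampling, which only needs a per-row coin flip (Bernoulli) or per-block decision (System). A standalone Bernoulli sketch with an explicitly seeded RNG, mirroring the seed handling above but not DuckDB's RandomEngine:

```cpp
// Why a percentage sample can use a *streaming* operator: each row is kept or
// dropped independently, so rows are emitted immediately with no buffering.
#include <cstdint>
#include <iostream>
#include <random>
#include <vector>

std::vector<int> BernoulliSample(const std::vector<int> &rows, double percentage, uint64_t seed) {
	std::mt19937_64 gen(seed); // seeded so the sample is reproducible
	std::bernoulli_distribution keep(percentage / 100.0);
	std::vector<int> out;
	for (int row : rows) {
		if (keep(gen)) {
			out.push_back(row); // emit immediately: no materialization required
		}
	}
	return out;
}

int main() {
	std::vector<int> rows(1000);
	for (int i = 0; i < 1000; i++) {
		rows[i] = i;
	}
	std::cout << BernoulliSample(rows, 10.0, 42).size() << " rows kept\n"; // roughly 100
}
```
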
sample->children.push_back(std::move(plan)); - return sample; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_set.cpp b/src/duckdb/src/execution/physical_plan/plan_set.cpp deleted file mode 100644 index fe09cb317..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_set.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_set.hpp" -#include "duckdb/execution/operator/helper/physical_set.hpp" -#include "duckdb/execution/operator/helper/physical_set_variable.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSet &op) { - if (!op.children.empty()) { - // set variable - auto child = CreatePlan(*op.children[0]); - auto set_variable = make_uniq(std::move(op.name), op.estimated_cardinality); - set_variable->children.push_back(std::move(child)); - return std::move(set_variable); - } - // set config setting - return make_uniq(op.name, op.value, op.scope, op.estimated_cardinality); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp b/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp deleted file mode 100644 index dada24fcf..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_set_operation.cpp +++ /dev/null @@ -1,123 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/execution/operator/aggregate/physical_window.hpp" -#include "duckdb/execution/operator/join/physical_hash_join.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/operator/set/physical_union.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" -#include "duckdb/planner/operator/logical_set_operation.hpp" - -namespace duckdb { - -static vector> CreatePartitionedRowNumExpression(const vector &types) { - vector> res; - auto expr = - make_uniq(ExpressionType::WINDOW_ROW_NUMBER, LogicalType::BIGINT, nullptr, nullptr); - expr->start = WindowBoundary::UNBOUNDED_PRECEDING; - expr->end = WindowBoundary::UNBOUNDED_FOLLOWING; - for (idx_t i = 0; i < types.size(); i++) { - expr->partitions.push_back(make_uniq(types[i], i)); - } - res.push_back(std::move(expr)); - return res; -} - -static JoinCondition CreateNotDistinctComparison(const LogicalType &type, idx_t i) { - JoinCondition cond; - cond.left = make_uniq(type, i); - cond.right = make_uniq(type, i); - cond.comparison = ExpressionType::COMPARE_NOT_DISTINCT_FROM; - return cond; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSetOperation &op) { - D_ASSERT(op.children.size() == 2); - - unique_ptr result; - - auto left = CreatePlan(*op.children[0]); - auto right = CreatePlan(*op.children[1]); - - if (left->GetTypes() != right->GetTypes()) { - throw InvalidInputException("Type mismatch for SET OPERATION"); - } - - switch (op.type) { - case LogicalOperatorType::LOGICAL_UNION: - // UNION - result = make_uniq(op.types, std::move(left), std::move(right), op.estimated_cardinality, - op.allow_out_of_order); - break; - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: { - auto &types = left->GetTypes(); - vector conditions; - // create equality condition for all columns - for (idx_t i = 0; i < types.size(); i++) { - conditions.push_back(CreateNotDistinctComparison(types[i], i)); - } - // 
For EXCEPT ALL / INTERSECT ALL we push a window operator with a ROW_NUMBER into the scans and join to get bag - // semantics. - if (op.setop_all) { - vector window_types = types; - window_types.push_back(LogicalType::BIGINT); - - auto window_left = make_uniq(window_types, CreatePartitionedRowNumExpression(types), - left->estimated_cardinality); - window_left->children.push_back(std::move(left)); - left = std::move(window_left); - - auto window_right = make_uniq(window_types, CreatePartitionedRowNumExpression(types), - right->estimated_cardinality); - window_right->children.push_back(std::move(right)); - right = std::move(window_right); - - // add window expression result to join condition - conditions.push_back(CreateNotDistinctComparison(LogicalType::BIGINT, types.size())); - // join (created below) now includes the row number result column - op.types.push_back(LogicalType::BIGINT); - } - - // EXCEPT is ANTI join - // INTERSECT is SEMI join - - JoinType join_type = op.type == LogicalOperatorType::LOGICAL_EXCEPT ? JoinType::ANTI : JoinType::SEMI; - result = make_uniq(op, std::move(left), std::move(right), std::move(conditions), join_type, - op.estimated_cardinality); - - // For EXCEPT ALL / INTERSECT ALL we need to remove the row number column again - if (op.setop_all) { - vector> projection_select_list; - for (idx_t i = 0; i < types.size(); i++) { - projection_select_list.push_back(make_uniq(types[i], i)); - } - auto projection = - make_uniq(types, std::move(projection_select_list), op.estimated_cardinality); - projection->children.push_back(std::move(result)); - result = std::move(projection); - } - break; - } - default: - throw InternalException("Unexpected operator type for set operation"); - } - - // if the ALL specifier is not given, we have to ensure distinct results. 
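
Both halves of this file's strategy can be seen in miniature: with ALL, tagging each row with its per-value occurrence number (the partitioned ROW_NUMBER above) turns a bag comparison into a plain set comparison on (row, occurrence) pairs, and the IS NOT DISTINCT FROM conditions make NULLs compare equal; without ALL, grouping on every column is exactly a DISTINCT, which is what gets pushed next. A sketch of the occurrence-numbering trick, with illustrative helpers rather than DuckDB's window and hash-join operators:

```cpp
// INTERSECT ALL / EXCEPT ALL via occurrence numbering: number each duplicate
// within its value, then SEMI/ANTI "join" on (value, occurrence) pairs.
#include <iostream>
#include <map>
#include <set>
#include <utility>
#include <vector>

static std::set<std::pair<int, int>> Number(const std::vector<int> &rows) {
	std::map<int, int> seen; // value -> occurrences so far
	std::set<std::pair<int, int>> out;
	for (int v : rows) {
		out.emplace(v, ++seen[v]); // (value, row_number within its partition)
	}
	return out;
}

int main() {
	auto l = Number({1, 1, 2, 3});
	auto r = Number({1, 2, 2});
	int intersect_all = 0, except_all = 0;
	for (auto &p : l) {
		if (r.count(p)) {
			intersect_all++; // SEMI join on (value, occurrence)
		} else {
			except_all++;    // ANTI join on (value, occurrence)
		}
	}
	std::cout << "INTERSECT ALL rows: " << intersect_all << '\n'; // 2 (one 1, one 2)
	std::cout << "EXCEPT ALL rows: " << except_all << '\n';       // 2 (one 1, one 3)
}
```
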
Hence, push a GROUP BY ALL - if (!op.setop_all) { // no ALL, use distinct semantics - auto &types = result->GetTypes(); - vector> groups, aggregates /* left empty */; - for (idx_t i = 0; i < types.size(); i++) { - groups.push_back(make_uniq(types[i], i)); - } - auto groupby = make_uniq(context, op.types, std::move(aggregates), std::move(groups), - result->estimated_cardinality); - groupby->children.push_back(std::move(result)); - result = std::move(groupby); - } - - D_ASSERT(result); - return (result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_simple.cpp b/src/duckdb/src/execution/physical_plan/plan_simple.cpp deleted file mode 100644 index 7c03ff4f0..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_simple.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_load.hpp" -#include "duckdb/execution/operator/helper/physical_transaction.hpp" -#include "duckdb/execution/operator/helper/physical_update_extensions.hpp" -#include "duckdb/execution/operator/helper/physical_vacuum.hpp" -#include "duckdb/execution/operator/schema/physical_alter.hpp" -#include "duckdb/execution/operator/schema/physical_attach.hpp" -#include "duckdb/execution/operator/schema/physical_create_schema.hpp" -#include "duckdb/execution/operator/schema/physical_create_view.hpp" -#include "duckdb/execution/operator/schema/physical_detach.hpp" -#include "duckdb/execution/operator/schema/physical_drop.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/operator/logical_simple.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalSimple &op) { - switch (op.type) { - case LogicalOperatorType::LOGICAL_ALTER: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_DROP: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_TRANSACTION: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_LOAD: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_ATTACH: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_DETACH: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS: - return make_uniq(unique_ptr_cast(std::move(op.info)), - op.estimated_cardinality); - default: - throw NotImplementedException("Unimplemented type for logical simple operator"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_top_n.cpp b/src/duckdb/src/execution/physical_plan/plan_top_n.cpp deleted file mode 100644 index 700021876..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_top_n.cpp +++ /dev/null @@ -1,19 +0,0 @@ -#include "duckdb/execution/operator/order/physical_top_n.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_top_n.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalTopN &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - auto top_n = - make_uniq(op.types, std::move(op.orders), NumericCast(op.limit), - NumericCast(op.offset), std::move(op.dynamic_filter), 
op.estimated_cardinality); - top_n->children.push_back(std::move(plan)); - return std::move(top_n); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_unnest.cpp b/src/duckdb/src/execution/physical_plan/plan_unnest.cpp deleted file mode 100644 index 992da4306..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_unnest.cpp +++ /dev/null @@ -1,15 +0,0 @@ -#include "duckdb/execution/operator/projection/physical_unnest.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_unnest.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUnnest &op) { - D_ASSERT(op.children.size() == 1); - auto plan = CreatePlan(*op.children[0]); - auto unnest = make_uniq(op.types, std::move(op.expressions), op.estimated_cardinality); - unnest->children.push_back(std::move(plan)); - return std::move(unnest); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_update.cpp b/src/duckdb/src/execution/physical_plan/plan_update.cpp deleted file mode 100644 index b35914233..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_update.cpp +++ /dev/null @@ -1,29 +0,0 @@ -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/execution/operator/persistent/physical_update.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/operator/logical_update.hpp" -#include "duckdb/catalog/duck_catalog.hpp" - -namespace duckdb { - -unique_ptr DuckCatalog::PlanUpdate(ClientContext &context, LogicalUpdate &op, - unique_ptr plan) { - auto update = make_uniq(op.types, op.table, op.table.GetStorage(), op.columns, - std::move(op.expressions), std::move(op.bound_defaults), - std::move(op.bound_constraints), op.estimated_cardinality, op.return_chunk); - - update->update_is_del_and_insert = op.update_is_del_and_insert; - update->children.push_back(std::move(plan)); - return std::move(update); -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalUpdate &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); - - dependencies.AddDependency(op.table); - return op.table.catalog.PlanUpdate(context, op, std::move(plan)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_vacuum.cpp b/src/duckdb/src/execution/physical_plan/plan_vacuum.cpp deleted file mode 100644 index 3f1f88b3d..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_vacuum.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/execution/operator/helper/physical_vacuum.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/planner/logical_operator.hpp" -#include "duckdb/planner/operator/logical_vacuum.hpp" - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalVacuum &op) { - auto result = make_uniq(unique_ptr_cast(std::move(op.info)), op.table, - std::move(op.column_id_map), op.estimated_cardinality); - if (!op.children.empty()) { - auto child = CreatePlan(*op.children[0]); - result->children.push_back(std::move(child)); - } - return std::move(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan/plan_window.cpp b/src/duckdb/src/execution/physical_plan/plan_window.cpp deleted file mode 100644 index f294d9324..000000000 --- a/src/duckdb/src/execution/physical_plan/plan_window.cpp +++ /dev/null @@ -1,153 +0,0 @@ -#include "duckdb/execution/operator/aggregate/physical_streaming_window.hpp" -#include 
"duckdb/execution/operator/aggregate/physical_window.hpp" -#include "duckdb/execution/operator/projection/physical_projection.hpp" -#include "duckdb/execution/physical_plan_generator.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" -#include "duckdb/planner/operator/logical_window.hpp" - -#include - -namespace duckdb { - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalWindow &op) { - D_ASSERT(op.children.size() == 1); - - auto plan = CreatePlan(*op.children[0]); -#ifdef DEBUG - for (auto &expr : op.expressions) { - D_ASSERT(expr->IsWindow()); - } -#endif - - op.estimated_cardinality = op.EstimateCardinality(context); - - // Slice types - auto types = op.types; - const auto input_width = types.size() - op.expressions.size(); - types.resize(input_width); - - // Identify streaming windows - const bool enable_optimizer = ClientConfig::GetConfig(context).enable_optimizer; - vector blocking_windows; - vector streaming_windows; - for (idx_t expr_idx = 0; expr_idx < op.expressions.size(); expr_idx++) { - if (enable_optimizer && PhysicalStreamingWindow::IsStreamingFunction(context, op.expressions[expr_idx])) { - streaming_windows.push_back(expr_idx); - } else { - blocking_windows.push_back(expr_idx); - } - } - - // Process the window functions by sharing the partition/order definitions - unordered_map projection_map; - vector> window_expressions; - idx_t blocking_count = 0; - auto output_pos = input_width; - while (!blocking_windows.empty() || !streaming_windows.empty()) { - const bool process_streaming = blocking_windows.empty(); - auto &remaining = process_streaming ? streaming_windows : blocking_windows; - blocking_count += process_streaming ? 0 : 1; - - // Find all functions that share the partitioning of the first remaining expression - auto over_idx = remaining[0]; - - vector matching; - vector unprocessed; - for (const auto &expr_idx : remaining) { - D_ASSERT(op.expressions[expr_idx]->GetExpressionClass() == ExpressionClass::BOUND_WINDOW); - auto &wexpr = op.expressions[expr_idx]->Cast(); - - // Just record the first one (it defines the partition) - if (over_idx == expr_idx) { - matching.emplace_back(expr_idx); - continue; - } - - // If it is in a different partition, skip it - const auto &over_expr = op.expressions[over_idx]->Cast(); - if (!over_expr.PartitionsAreEquivalent(wexpr)) { - unprocessed.emplace_back(expr_idx); - continue; - } - - // CSE Elimination: Search for a previous match - bool cse = false; - for (idx_t i = 0; i < matching.size(); ++i) { - const auto match_idx = matching[i]; - auto &match_expr = op.expressions[match_idx]->Cast(); - if (wexpr.Equals(match_expr)) { - projection_map[input_width + expr_idx] = output_pos + i; - cse = true; - break; - } - } - if (cse) { - continue; - } - - // Is there a common sort prefix? 
- const auto prefix = over_expr.GetSharedOrders(wexpr); - if (prefix != MinValue(over_expr.orders.size(), wexpr.orders.size())) { - unprocessed.emplace_back(expr_idx); - continue; - } - matching.emplace_back(expr_idx); - - // Switch to the longer prefix - if (prefix < wexpr.orders.size()) { - over_idx = expr_idx; - } - } - remaining.swap(unprocessed); - - // Remember the projection order - for (const auto &expr_idx : matching) { - projection_map[input_width + expr_idx] = output_pos++; - } - - window_expressions.emplace_back(std::move(matching)); - } - - // Build the window operators - for (idx_t i = 0; i < window_expressions.size(); ++i) { - // Extract the matching expressions - const auto &matching = window_expressions[i]; - vector> select_list; - for (const auto &expr_idx : matching) { - select_list.emplace_back(std::move(op.expressions[expr_idx])); - types.emplace_back(op.types[input_width + expr_idx]); - } - - // Chain the new window operator on top of the plan - unique_ptr window; - if (i < blocking_count) { - window = make_uniq(types, std::move(select_list), op.estimated_cardinality); - } else { - window = make_uniq(types, std::move(select_list), op.estimated_cardinality); - } - window->children.push_back(std::move(plan)); - plan = std::move(window); - } - - // Put everything back into place if it moved - if (!projection_map.empty()) { - vector> select_list(op.types.size()); - // The inputs don't move - for (idx_t i = 0; i < input_width; ++i) { - select_list[i] = make_uniq(op.types[i], i); - } - // The outputs have been rearranged - for (const auto &p : projection_map) { - select_list[p.first] = make_uniq(op.types[p.first], p.second); - } - auto proj = make_uniq(op.types, std::move(select_list), op.estimated_cardinality); - proj->children.push_back(std::move(plan)); - plan = std::move(proj); - } - - return plan; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/physical_plan_generator.cpp b/src/duckdb/src/execution/physical_plan_generator.cpp deleted file mode 100644 index 1ac6f61b2..000000000 --- a/src/duckdb/src/execution/physical_plan_generator.cpp +++ /dev/null @@ -1,221 +0,0 @@ -#include "duckdb/execution/physical_plan_generator.hpp" - -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/execution/column_binding_resolver.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/query_profiler.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/operator/logical_extension_operator.hpp" -#include "duckdb/planner/operator/list.hpp" -#include "duckdb/execution/operator/helper/physical_verify_vector.hpp" - -namespace duckdb { - -PhysicalPlanGenerator::PhysicalPlanGenerator(ClientContext &context) : context(context) { -} - -PhysicalPlanGenerator::~PhysicalPlanGenerator() { -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(unique_ptr op) { - auto &profiler = QueryProfiler::Get(context); - - // first resolve column references - profiler.StartPhase(MetricsType::PHYSICAL_PLANNER_COLUMN_BINDING); - ColumnBindingResolver resolver; - resolver.VisitOperator(*op); - profiler.EndPhase(); - - // now resolve types of all the operators - profiler.StartPhase(MetricsType::PHYSICAL_PLANNER_RESOLVE_TYPES); - op->ResolveOperatorTypes(); - profiler.EndPhase(); - - // then create the main physical plan - profiler.StartPhase(MetricsType::PHYSICAL_PLANNER_CREATE_PLAN); - auto plan = 
CreatePlan(*op); - profiler.EndPhase(); - - plan->Verify(); - return plan; -} - -unique_ptr PhysicalPlanGenerator::CreatePlan(LogicalOperator &op) { - op.estimated_cardinality = op.EstimateCardinality(context); - unique_ptr plan = nullptr; - - switch (op.type) { - case LogicalOperatorType::LOGICAL_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PROJECTION: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EMPTY_RESULT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_FILTER: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_AGGREGATE_AND_GROUP_BY: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_WINDOW: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UNNEST: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_LIMIT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_SAMPLE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ORDER_BY: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_TOP_N: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_COPY_TO_FILE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DUMMY_SCAN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ANY_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_ASOF_JOIN: - case LogicalOperatorType::LOGICAL_DELIM_JOIN: - case LogicalOperatorType::LOGICAL_COMPARISON_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CROSS_PRODUCT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_POSITIONAL_JOIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UNION: - case LogicalOperatorType::LOGICAL_EXCEPT: - case LogicalOperatorType::LOGICAL_INTERSECT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_INSERT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DELETE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CHUNK_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DELIM_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPRESSION_GET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UPDATE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_TABLE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_INDEX: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_SECRET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPLAIN: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_DISTINCT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PREPARE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXECUTE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CREATE_VIEW: - case LogicalOperatorType::LOGICAL_CREATE_SEQUENCE: - case LogicalOperatorType::LOGICAL_CREATE_SCHEMA: - case LogicalOperatorType::LOGICAL_CREATE_MACRO: - case LogicalOperatorType::LOGICAL_CREATE_TYPE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PRAGMA: - plan = CreatePlan(op.Cast()); - 
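
The long switch around this point is a type-tag dispatch: one case per logical operator type, a cast to the concrete class, then the matching CreatePlan overload. The shape of that pattern in miniature, with simplified stand-ins rather than DuckDB's class hierarchy (DuckDB's Cast is additionally checked under D_ASSERT):

```cpp
// Type-tag dispatch from a logical operator base class to per-operator overloads.
#include <iostream>
#include <stdexcept>
#include <string>

enum class LogicalOperatorType { GET, PROJECTION, INVALID };

struct LogicalOp {
	LogicalOperatorType type;
	explicit LogicalOp(LogicalOperatorType t) : type(t) {}
	virtual ~LogicalOp() = default;
	template <class T>
	T &Cast() { return static_cast<T &>(*this); }
};
struct LogicalGet : LogicalOp { LogicalGet() : LogicalOp(LogicalOperatorType::GET) {} };
struct LogicalProjection : LogicalOp { LogicalProjection() : LogicalOp(LogicalOperatorType::PROJECTION) {} };

static std::string CreatePlan(LogicalGet &) { return "PhysicalTableScan"; }
static std::string CreatePlan(LogicalProjection &) { return "PhysicalProjection"; }

static std::string CreatePlan(LogicalOp &op) {
	switch (op.type) {
	case LogicalOperatorType::GET:
		return CreatePlan(op.Cast<LogicalGet>());
	case LogicalOperatorType::PROJECTION:
		return CreatePlan(op.Cast<LogicalProjection>());
	default:
		throw std::runtime_error("Unimplemented logical operator type!");
	}
}

int main() {
	LogicalGet get;
	std::cout << CreatePlan(static_cast<LogicalOp &>(get)) << '\n'; // PhysicalTableScan
}
```
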
break; - case LogicalOperatorType::LOGICAL_VACUUM: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_TRANSACTION: - case LogicalOperatorType::LOGICAL_ALTER: - case LogicalOperatorType::LOGICAL_DROP: - case LogicalOperatorType::LOGICAL_LOAD: - case LogicalOperatorType::LOGICAL_ATTACH: - case LogicalOperatorType::LOGICAL_DETACH: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_RECURSIVE_CTE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_MATERIALIZED_CTE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_CTE_REF: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXPORT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_SET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_RESET: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_PIVOT: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_COPY_DATABASE: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_UPDATE_EXTENSIONS: - plan = CreatePlan(op.Cast()); - break; - case LogicalOperatorType::LOGICAL_EXTENSION_OPERATOR: - plan = op.Cast().CreatePlan(context, *this); - - if (!plan) { - throw InternalException("Missing PhysicalOperator for Extension Operator"); - } - break; - case LogicalOperatorType::LOGICAL_JOIN: - case LogicalOperatorType::LOGICAL_DEPENDENT_JOIN: - case LogicalOperatorType::LOGICAL_INVALID: { - throw NotImplementedException("Unimplemented logical operator type!"); - } - } - if (!plan) { - throw InternalException("Physical plan generator - no plan generated"); - } - - plan->estimated_cardinality = op.estimated_cardinality; -#ifdef DUCKDB_VERIFY_VECTOR_OPERATOR - auto verify = make_uniq(std::move(plan)); - plan = std::move(verify); -#endif - - return plan; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp b/src/duckdb/src/execution/radix_partitioned_hashtable.cpp deleted file mode 100644 index cf6038fa4..000000000 --- a/src/duckdb/src/execution/radix_partitioned_hashtable.cpp +++ /dev/null @@ -1,947 +0,0 @@ -#include "duckdb/execution/radix_partitioned_hashtable.hpp" - -#include "duckdb/common/radix_partitioning.hpp" -#include "duckdb/common/row_operations/row_operations.hpp" -#include "duckdb/common/types/row/tuple_data_collection.hpp" -#include "duckdb/common/types/row/tuple_data_iterator.hpp" -#include "duckdb/execution/aggregate_hashtable.hpp" -#include "duckdb/execution/executor.hpp" -#include "duckdb/execution/ht_entry.hpp" -#include "duckdb/execution/operator/aggregate/physical_hash_aggregate.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/storage/temporary_memory_manager.hpp" - -namespace duckdb { - -RadixPartitionedHashTable::RadixPartitionedHashTable(GroupingSet &grouping_set_p, const GroupedAggregateData &op_p) - : grouping_set(grouping_set_p), op(op_p) { - auto groups_count = op.GroupCount(); - for (idx_t i = 0; i < groups_count; i++) { - if (grouping_set.find(i) == grouping_set.end()) { - null_groups.push_back(i); - } - } - if (grouping_set.empty()) { - // Fake a single group with a constant value for aggregation without groups - group_types.emplace_back(LogicalType::TINYINT); - } - for (auto &entry : grouping_set) { - D_ASSERT(entry < op.group_types.size()); - group_types.push_back(op.group_types[entry]); - } - SetGroupingValues(); - - auto group_types_copy = 
group_types; - group_types_copy.emplace_back(LogicalType::HASH); - layout.Initialize(std::move(group_types_copy), AggregateObject::CreateAggregateObjects(op.bindings)); -} - -void RadixPartitionedHashTable::SetGroupingValues() { - // Compute the GROUPING values: - // For each parameter to the GROUPING clause, we check if the hash table groups on this particular group - // If it does, we return 0, otherwise we return 1 - // We then use bitshifts to combine these values - auto &grouping_functions = op.GetGroupingFunctions(); - for (auto &grouping : grouping_functions) { - int64_t grouping_value = 0; - D_ASSERT(grouping.size() < sizeof(int64_t) * 8); - for (idx_t i = 0; i < grouping.size(); i++) { - if (grouping_set.find(grouping[i]) == grouping_set.end()) { - // We don't group on this value! - grouping_value += 1LL << (grouping.size() - (i + 1)); - } - } - grouping_values.push_back(Value::BIGINT(grouping_value)); - } -} - -const TupleDataLayout &RadixPartitionedHashTable::GetLayout() const { - return layout; -} - -unique_ptr RadixPartitionedHashTable::CreateHT(ClientContext &context, const idx_t capacity, - const idx_t radix_bits) const { - return make_uniq(context, BufferAllocator::Get(context), group_types, op.payload_types, - op.bindings, capacity, radix_bits); -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -enum class AggregatePartitionState : uint8_t { - //! Can be finalized - READY_TO_FINALIZE = 0, - //! Finalize is in progress - FINALIZE_IN_PROGRESS = 1, - //! Finalized, ready to scan - READY_TO_SCAN = 2 -}; - -struct AggregatePartition : StateWithBlockableTasks { - explicit AggregatePartition(unique_ptr data_p) - : state(AggregatePartitionState::READY_TO_FINALIZE), data(std::move(data_p)), progress(0) { - } - - AggregatePartitionState state; - - unique_ptr data; - atomic progress; -}; - -class RadixHTGlobalSinkState; - -struct RadixHTConfig { -public: - explicit RadixHTConfig(RadixHTGlobalSinkState &sink); - - void SetRadixBits(const idx_t &radix_bits_p); - bool SetRadixBitsToExternal(); - idx_t GetRadixBits() const; - -private: - void SetRadixBitsInternal(idx_t radix_bits_p, bool external); - idx_t InitialSinkRadixBits() const; - idx_t MaximumSinkRadixBits() const; - idx_t SinkCapacity() const; - -private: - //! The global sink state - RadixHTGlobalSinkState &sink; - -public: - //! Number of threads (from TaskScheduler) - const idx_t number_of_threads; - //! Width of tuples - const idx_t row_width; - //! Capacity of HTs during the Sink - const idx_t sink_capacity; - -private: - //! Assume (1 << 15) = 32KB L1 cache per core, divided by two because hyperthreading - static constexpr idx_t L1_CACHE_SIZE = 32768 / 2; - //! Assume (1 << 20) = 1MB L2 cache per core, divided by two because hyperthreading - static constexpr idx_t L2_CACHE_SIZE = 1048576 / 2; - //! Assume (1 << 20) + (1 << 19) = 1.5MB L3 cache per core (shared), divided by two because hyperthreading - static constexpr idx_t L3_CACHE_SIZE = 1572864 / 2; - - //! Sink radix bits to initialize with - static constexpr idx_t MAXIMUM_INITIAL_SINK_RADIX_BITS = 4; - //! Maximum Sink radix bits (independent of threads) - static constexpr idx_t MAXIMUM_FINAL_SINK_RADIX_BITS = 8; - - //! Current thread-global sink radix bits - atomic sink_radix_bits; - //! Maximum Sink radix bits (set based on number of threads) - const idx_t maximum_sink_radix_bits; - - //! Thresholds at which we reduce the sink radix bits - //! 
This needed to reduce cache misses when we have very wide rows - static constexpr idx_t ROW_WIDTH_THRESHOLD_ONE = 32; - static constexpr idx_t ROW_WIDTH_THRESHOLD_TWO = 64; - -public: - //! If we have this many or less threads, we grow the HT, otherwise we abandon - static constexpr idx_t GROW_STRATEGY_THREAD_THRESHOLD = 2; - //! If we fill this many blocks per partition, we trigger a repartition - static constexpr double BLOCK_FILL_FACTOR = 1.8; - //! By how many bits to repartition if a repartition is triggered - static constexpr idx_t REPARTITION_RADIX_BITS = 2; -}; - -class RadixHTGlobalSinkState : public GlobalSinkState { -public: - RadixHTGlobalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - - //! Destroys aggregate states (if multi-scan) - ~RadixHTGlobalSinkState() override; - void Destroy(); - -public: - ClientContext &context; - //! Temporary memory state for managing this hash table's memory usage - unique_ptr temporary_memory_state; - idx_t minimum_reservation; - - //! Whether we've called Finalize - bool finalized; - //! Whether we are doing an external aggregation - atomic external; - //! Threads that have called Sink - atomic active_threads; - //! Number of threads (from TaskScheduler) - const idx_t number_of_threads; - //! If any thread has called combine - atomic any_combined; - - //! The radix HT - const RadixPartitionedHashTable &radix_ht; - //! Config for partitioning - RadixHTConfig config; - - //! Uncombined partitioned data that will be put into the AggregatePartitions - unique_ptr uncombined_data; - //! Allocators used during the Sink/Finalize - vector> stored_allocators; - idx_t stored_allocators_size; - - //! Partitions that are finalized during GetData - vector> partitions; - //! For keeping track of progress - atomic finalize_done; - - //! Pin properties when scanning - TupleDataPinProperties scan_pin_properties; - //! Total count before combining - idx_t count_before_combining; - //! 
Maximum partition size if all unique - idx_t max_partition_size; -}; - -RadixHTGlobalSinkState::RadixHTGlobalSinkState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht_p) - : context(context_p), temporary_memory_state(TemporaryMemoryManager::Get(context).Register(context)), - finalized(false), external(false), active_threads(0), - number_of_threads(NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads())), - any_combined(false), radix_ht(radix_ht_p), config(*this), stored_allocators_size(0), finalize_done(0), - scan_pin_properties(TupleDataPinProperties::DESTROY_AFTER_DONE), count_before_combining(0), - max_partition_size(0) { - - // Compute minimum reservation - auto block_alloc_size = BufferManager::GetBufferManager(context).GetBlockAllocSize(); - auto tuples_per_block = block_alloc_size / radix_ht.GetLayout().GetRowWidth(); - idx_t ht_count = - LossyNumericCast(static_cast(config.sink_capacity) / GroupedAggregateHashTable::LOAD_FACTOR); - auto num_partitions = RadixPartitioning::NumberOfPartitions(config.GetRadixBits()); - auto count_per_partition = ht_count / num_partitions; - auto blocks_per_partition = (count_per_partition + tuples_per_block) / tuples_per_block + 1; - if (!radix_ht.GetLayout().AllConstant()) { - blocks_per_partition += 2; - } - auto ht_size = blocks_per_partition * block_alloc_size + config.sink_capacity * sizeof(ht_entry_t); - - // This really is the minimum reservation that we can do - auto num_threads = NumericCast(TaskScheduler::GetScheduler(context).NumberOfThreads()); - minimum_reservation = num_threads * ht_size; - - temporary_memory_state->SetMinimumReservation(minimum_reservation); - temporary_memory_state->SetRemainingSizeAndUpdateReservation(context, minimum_reservation); -} - -RadixHTGlobalSinkState::~RadixHTGlobalSinkState() { - Destroy(); -} - -// LCOV_EXCL_START -void RadixHTGlobalSinkState::Destroy() { - if (scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE || count_before_combining == 0 || - partitions.empty()) { - // Already destroyed / empty - return; - } - - TupleDataLayout layout = partitions[0]->data->GetLayout().Copy(); - if (!layout.HasDestructor()) { - return; // No destructors, exit - } - - // There are aggregates with destructors: Call the destructor for each of the aggregates - auto guard = Lock(); - RowOperationsState row_state(*stored_allocators.back()); - for (auto &partition : partitions) { - auto &data_collection = *partition->data; - if (data_collection.Count() == 0) { - continue; - } - TupleDataChunkIterator iterator(data_collection, TupleDataPinProperties::DESTROY_AFTER_DONE, false); - auto &row_locations = iterator.GetChunkState().row_locations; - do { - RowOperations::DestroyStates(row_state, layout, row_locations, iterator.GetCurrentChunkCount()); - } while (iterator.Next()); - data_collection.Reset(); - } -} -// LCOV_EXCL_STOP - -RadixHTConfig::RadixHTConfig(RadixHTGlobalSinkState &sink_p) - : sink(sink_p), number_of_threads(sink.number_of_threads), row_width(sink.radix_ht.GetLayout().GetRowWidth()), - sink_capacity(SinkCapacity()), sink_radix_bits(InitialSinkRadixBits()), - maximum_sink_radix_bits(MaximumSinkRadixBits()) { -} - -void RadixHTConfig::SetRadixBits(const idx_t &radix_bits_p) { - SetRadixBitsInternal(MinValue(radix_bits_p, maximum_sink_radix_bits), false); -} - -bool RadixHTConfig::SetRadixBitsToExternal() { - SetRadixBitsInternal(MAXIMUM_FINAL_SINK_RADIX_BITS, true); - return sink.external; -} - -idx_t RadixHTConfig::GetRadixBits() const { - return sink_radix_bits; -} 
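
SetRadixBitsInternal, which follows, only ever raises the radix-bit count, and it guards the update with a check, lock, recheck sequence so that concurrent sinks neither thrash the setting nor take the lock on the common path. The idiom in isolation, simplified (the real code also returns early once any thread has started combining):

```cpp
// Double-checked update of a monotonically increasing setting: a cheap
// unlocked check, then the lock, then the same check again before writing.
#include <atomic>
#include <cstdint>
#include <iostream>
#include <mutex>

struct Config {
	std::atomic<uint64_t> sink_radix_bits{4};
	std::mutex lock;

	void SetRadixBits(uint64_t radix_bits_p) {
		if (sink_radix_bits >= radix_bits_p) {
			return; // fast path: already at least this partitioned
		}
		std::lock_guard<std::mutex> guard(lock);
		if (sink_radix_bits >= radix_bits_p) {
			return; // another thread raised it while we waited for the lock
		}
		sink_radix_bits = radix_bits_p; // only ever increases
	}
};

int main() {
	Config config;
	config.SetRadixBits(6);
	config.SetRadixBits(5); // no-op: bits never decrease
	std::cout << config.sink_radix_bits << '\n'; // 6
}
```
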
- -void RadixHTConfig::SetRadixBitsInternal(const idx_t radix_bits_p, bool external) { - if (sink_radix_bits >= radix_bits_p || sink.any_combined) { - return; - } - - auto guard = sink.Lock(); - if (sink_radix_bits >= radix_bits_p || sink.any_combined) { - return; - } - - if (external) { - sink.external = true; - } - sink_radix_bits = radix_bits_p; -} - -idx_t RadixHTConfig::InitialSinkRadixBits() const { - return MinValue(RadixPartitioning::RadixBitsOfPowerOfTwo(NextPowerOfTwo(number_of_threads)), - MAXIMUM_INITIAL_SINK_RADIX_BITS); -} - -idx_t RadixHTConfig::MaximumSinkRadixBits() const { - if (number_of_threads <= GROW_STRATEGY_THREAD_THRESHOLD) { - return InitialSinkRadixBits(); // Don't repartition unless we go external - } - // If rows are very wide we have to reduce the number of partitions, otherwise cache misses get out of hand - if (row_width >= ROW_WIDTH_THRESHOLD_TWO) { - return MAXIMUM_FINAL_SINK_RADIX_BITS - 2; - } - if (row_width >= ROW_WIDTH_THRESHOLD_ONE) { - return MAXIMUM_FINAL_SINK_RADIX_BITS - 1; - } - return MAXIMUM_FINAL_SINK_RADIX_BITS; -} - -idx_t RadixHTConfig::SinkCapacity() const { - // Compute cache size per active thread (assuming cache is shared) - const auto total_shared_cache_size = number_of_threads * L3_CACHE_SIZE; - const auto cache_per_active_thread = L1_CACHE_SIZE + L2_CACHE_SIZE + total_shared_cache_size / number_of_threads; - - // Divide cache per active thread by entry size, round up to next power of two, to get capacity - const auto size_per_entry = LossyNumericCast(sizeof(ht_entry_t) * GroupedAggregateHashTable::LOAD_FACTOR) + - MinValue(row_width, ROW_WIDTH_THRESHOLD_TWO); - const auto capacity = NextPowerOfTwo(cache_per_active_thread / size_per_entry); - - // Capacity must be at least the minimum capacity - return MaxValue(capacity, GroupedAggregateHashTable::InitialCapacity()); -} - -class RadixHTLocalSinkState : public LocalSinkState { -public: - RadixHTLocalSinkState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - -public: - //! Thread-local HT that is re-used after abandoning - unique_ptr ht; - //! Chunk with group columns - DataChunk group_chunk; - - //! 
Data that is abandoned ends up here (only if we're doing external aggregation) - unique_ptr abandoned_data; -}; - -RadixHTLocalSinkState::RadixHTLocalSinkState(ClientContext &, const RadixPartitionedHashTable &radix_ht) { - // If there are no groups we create a fake group so everything has the same group - group_chunk.InitializeEmpty(radix_ht.group_types); - if (radix_ht.grouping_set.empty()) { - group_chunk.data[0].Reference(Value::TINYINT(42)); - } -} - -unique_ptr RadixPartitionedHashTable::GetGlobalSinkState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr RadixPartitionedHashTable::GetLocalSinkState(ExecutionContext &context) const { - return make_uniq(context.client, *this); -} - -void RadixPartitionedHashTable::PopulateGroupChunk(DataChunk &group_chunk, DataChunk &input_chunk) const { - idx_t chunk_index = 0; - // Populate the group_chunk - for (auto &group_idx : grouping_set) { - // Retrieve the expression containing the index in the input chunk - auto &group = op.groups[group_idx]; - D_ASSERT(group->GetExpressionType() == ExpressionType::BOUND_REF); - auto &bound_ref_expr = group->Cast(); - // Reference from input_chunk[group.index] -> group_chunk[chunk_index] - group_chunk.data[chunk_index++].Reference(input_chunk.data[bound_ref_expr.index]); - } - group_chunk.SetCardinality(input_chunk.size()); - group_chunk.Verify(); -} - -void MaybeRepartition(ClientContext &context, RadixHTGlobalSinkState &gstate, RadixHTLocalSinkState &lstate) { - auto &config = gstate.config; - auto &ht = *lstate.ht; - - // Check if we're approaching the memory limit - auto &temporary_memory_state = *gstate.temporary_memory_state; - const auto aggregate_allocator_size = ht.GetAggregateAllocator()->AllocationSize(); - const auto total_size = - aggregate_allocator_size + ht.GetPartitionedData().SizeInBytes() + ht.Capacity() * sizeof(ht_entry_t); - idx_t thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; - if (total_size > thread_limit) { - // We're over the thread memory limit - if (!gstate.external) { - // We haven't yet triggered out-of-core behavior, but maybe we don't have to, grab the lock and check again - auto guard = gstate.Lock(); - thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; - if (total_size > thread_limit) { - // Out-of-core would be triggered below, update minimum reservation and try to increase the reservation - temporary_memory_state.SetMinimumReservation(aggregate_allocator_size * gstate.number_of_threads + - gstate.minimum_reservation); - auto remaining_size = - MaxValue(gstate.number_of_threads * total_size, temporary_memory_state.GetRemainingSize()); - temporary_memory_state.SetRemainingSizeAndUpdateReservation(context, 2 * remaining_size); - thread_limit = temporary_memory_state.GetReservation() / gstate.number_of_threads; - } - } - } - - if (total_size > thread_limit) { - if (gstate.config.SetRadixBitsToExternal()) { - // We're approaching the memory limit, unpin the data - if (!lstate.abandoned_data) { - lstate.abandoned_data = make_uniq( - BufferManager::GetBufferManager(context), gstate.radix_ht.GetLayout(), config.GetRadixBits(), - gstate.radix_ht.GetLayout().ColumnCount() - 1); - } - ht.SetRadixBits(gstate.config.GetRadixBits()); - ht.AcquirePartitionedData()->Repartition(*lstate.abandoned_data); - } - } - - // We can go external when there are few threads, but we shouldn't repartition here - if (gstate.number_of_threads <= RadixHTConfig::GROW_STRATEGY_THREAD_THRESHOLD) { - 
return; - } - - const auto partition_count = ht.GetPartitionedData().PartitionCount(); - const auto current_radix_bits = RadixPartitioning::RadixBitsOfPowerOfTwo(partition_count); - D_ASSERT(current_radix_bits <= config.GetRadixBits()); - - const auto block_size = BufferManager::GetBufferManager(context).GetBlockSize(); - const auto row_size_per_partition = - ht.GetPartitionedData().Count() * ht.GetPartitionedData().GetLayout().GetRowWidth() / partition_count; - if (row_size_per_partition > LossyNumericCast(config.BLOCK_FILL_FACTOR * static_cast(block_size))) { - // We crossed our block filling threshold, try to increment radix bits - config.SetRadixBits(current_radix_bits + config.REPARTITION_RADIX_BITS); - } - - const auto global_radix_bits = config.GetRadixBits(); - if (current_radix_bits == global_radix_bits) { - return; // We're already on the right number of radix bits - } - - // We're out-of-sync with the global radix bits, repartition - ht.SetRadixBits(global_radix_bits); - ht.Repartition(); -} - -void RadixPartitionedHashTable::Sink(ExecutionContext &context, DataChunk &chunk, OperatorSinkInput &input, - DataChunk &payload_input, const unsafe_vector &filter) const { - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - if (!lstate.ht) { - lstate.ht = CreateHT(context.client, gstate.config.sink_capacity, gstate.config.GetRadixBits()); - gstate.active_threads++; - } - - auto &group_chunk = lstate.group_chunk; - PopulateGroupChunk(group_chunk, chunk); - - auto &ht = *lstate.ht; - ht.AddChunk(group_chunk, payload_input, filter); - - if (ht.Count() + STANDARD_VECTOR_SIZE < GroupedAggregateHashTable::ResizeThreshold(gstate.config.sink_capacity)) { - return; // We can fit another chunk - } - - if (gstate.number_of_threads > RadixHTConfig::GROW_STRATEGY_THREAD_THRESHOLD || gstate.external) { - // 'Reset' the HT without taking its data, we can just keep appending to the same collection - // This only works because we never resize the HT - // We don't do this when running with 1 or 2 threads, it only makes sense when there's many threads - ht.Abandon(); - - // Once we've inserted more than SKIP_LOOKUP_THRESHOLD tuples, - // and more than UNIQUE_PERCENTAGE_THRESHOLD were unique, - // we set the HT to skip doing lookups, which makes it blindly append data to the HT. - // This speeds up adding data, at the cost of no longer de-duplicating. 
- // The data will be de-duplicated later anyway - static constexpr idx_t SKIP_LOOKUP_THRESHOLD = 262144; - static constexpr double UNIQUE_PERCENTAGE_THRESHOLD = 0.95; - const auto unique_percentage = - static_cast(ht.GetPartitionedData().Count()) / static_cast(ht.GetSinkCount()); - if (ht.GetSinkCount() > SKIP_LOOKUP_THRESHOLD && unique_percentage > UNIQUE_PERCENTAGE_THRESHOLD) { - ht.SkipLookups(); - } - } - - // Check if we need to repartition - const auto radix_bits_before = ht.GetRadixBits(); - MaybeRepartition(context.client, gstate, lstate); - const auto repartitioned = radix_bits_before != ht.GetRadixBits(); - - if (repartitioned && ht.Count() != 0) { - // We repartitioned, but we didn't clear the pointer table / reset the count because we're on 1 or 2 threads - ht.Abandon(); - if (gstate.external) { - ht.Resize(gstate.config.sink_capacity); - } - } - - // TODO: combine early and often -} - -void RadixPartitionedHashTable::Combine(ExecutionContext &context, GlobalSinkState &gstate_p, - LocalSinkState &lstate_p) const { - auto &gstate = gstate_p.Cast(); - auto &lstate = lstate_p.Cast(); - if (!lstate.ht) { - return; - } - - // Set any_combined, then check one last time whether we need to repartition - gstate.any_combined = true; - MaybeRepartition(context.client, gstate, lstate); - - auto &ht = *lstate.ht; - auto lstate_data = ht.AcquirePartitionedData(); - if (lstate.abandoned_data) { - D_ASSERT(gstate.external); - D_ASSERT(lstate.abandoned_data->PartitionCount() == lstate.ht->GetPartitionedData().PartitionCount()); - D_ASSERT(lstate.abandoned_data->PartitionCount() == - RadixPartitioning::NumberOfPartitions(gstate.config.GetRadixBits())); - lstate.abandoned_data->Combine(*lstate_data); - } else { - lstate.abandoned_data = std::move(lstate_data); - } - - auto guard = gstate.Lock(); - if (gstate.uncombined_data) { - gstate.uncombined_data->Combine(*lstate.abandoned_data); - } else { - gstate.uncombined_data = std::move(lstate.abandoned_data); - } - gstate.stored_allocators.emplace_back(ht.GetAggregateAllocator()); - gstate.stored_allocators_size += gstate.stored_allocators.back()->AllocationSize(); -} - -void RadixPartitionedHashTable::Finalize(ClientContext &context, GlobalSinkState &gstate_p) const { - auto &gstate = gstate_p.Cast(); - - if (gstate.uncombined_data) { - auto &uncombined_data = *gstate.uncombined_data; - gstate.count_before_combining = uncombined_data.Count(); - - // If true there is no need to combine, it was all done by a single thread in a single HT - const auto single_ht = !gstate.external && gstate.active_threads == 1 && gstate.number_of_threads == 1; - - auto &uncombined_partition_data = uncombined_data.GetPartitions(); - const auto n_partitions = uncombined_partition_data.size(); - gstate.partitions.reserve(n_partitions); - for (idx_t i = 0; i < n_partitions; i++) { - auto &partition = uncombined_partition_data[i]; - auto partition_size = - partition->SizeInBytes() + - GroupedAggregateHashTable::GetCapacityForCount(partition->Count()) * sizeof(ht_entry_t); - gstate.max_partition_size = MaxValue(gstate.max_partition_size, partition_size); - - gstate.partitions.emplace_back(make_uniq(std::move(partition))); - if (single_ht) { - gstate.finalize_done++; - gstate.partitions.back()->progress = 1; - gstate.partitions.back()->state = AggregatePartitionState::READY_TO_SCAN; - } - } - } else { - gstate.count_before_combining = 0; - } - - // Minimum of combining one partition at a time - gstate.temporary_memory_state->SetMinimumReservation(gstate.stored_allocators_size 
+ gstate.max_partition_size); - // Set size to 0 until the scan actually starts - gstate.temporary_memory_state->SetZero(); - gstate.finalized = true; -} - -//===--------------------------------------------------------------------===// -// Source -//===--------------------------------------------------------------------===// -idx_t RadixPartitionedHashTable::MaxThreads(GlobalSinkState &sink_p) const { - auto &sink = sink_p.Cast<RadixHTGlobalSinkState>(); - if (sink.partitions.empty()) { - return 0; - } - - const auto max_threads = MinValue<idx_t>( - NumericCast<idx_t>(TaskScheduler::GetScheduler(sink.context).NumberOfThreads()), sink.partitions.size()); - sink.temporary_memory_state->SetRemainingSizeAndUpdateReservation( - sink.context, sink.stored_allocators_size + max_threads * sink.max_partition_size); - - // we cannot spill aggregate state memory - const auto usable_memory = sink.temporary_memory_state->GetReservation() > sink.stored_allocators_size - ? sink.temporary_memory_state->GetReservation() - sink.stored_allocators_size - : 0; - // This many partitions will fit given our reservation (at least 1) - const auto partitions_fit = MaxValue<idx_t>(usable_memory / sink.max_partition_size, 1); - - // Minimum of the two - return MinValue<idx_t>(partitions_fit, max_threads); -} - -void RadixPartitionedHashTable::SetMultiScan(GlobalSinkState &sink_p) { - auto &sink = sink_p.Cast<RadixHTGlobalSinkState>(); - sink.scan_pin_properties = TupleDataPinProperties::UNPIN_AFTER_DONE; -} - -enum class RadixHTSourceTaskType : uint8_t { NO_TASK, FINALIZE, SCAN }; - -class RadixHTLocalSourceState; - -class RadixHTGlobalSourceState : public GlobalSourceState { -public: - RadixHTGlobalSourceState(ClientContext &context, const RadixPartitionedHashTable &radix_ht); - - //! Assigns a task to a local source state - SourceResultType AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate, - InterruptState &interrupt_state); - -public: - //! The client context - ClientContext &context; - //! For synchronizing the source phase - atomic<bool> finished; - - //! Column ids for scanning - vector<column_t> column_ids; - - //! For synchronizing tasks - idx_t task_idx; - atomic<idx_t> task_done; -}; - -enum class RadixHTScanStatus : uint8_t { INIT, IN_PROGRESS, DONE }; - -class RadixHTLocalSourceState : public LocalSourceState { -public: - explicit RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht); - -public: - //! Do the work this thread has been assigned - void ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); - //! Whether this thread has finished the work it has been assigned - bool TaskFinished(); - -private: - //! Execute the finalize or scan task - void Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate); - void Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk); - -public: - //! Current task and index - RadixHTSourceTaskType task; - idx_t task_idx; - - //! Thread-local HT that is re-used to Finalize - unique_ptr<GroupedAggregateHashTable> ht; - //! Current status of a Scan - RadixHTScanStatus scan_status; - -private: - //! Allocator and layout for finalizing state - TupleDataLayout layout; - ArenaAllocator aggregate_allocator; - - //!
State and chunk for scanning - TupleDataScanState scan_state; - DataChunk scan_chunk; -}; - -unique_ptr RadixPartitionedHashTable::GetGlobalSourceState(ClientContext &context) const { - return make_uniq(context, *this); -} - -unique_ptr RadixPartitionedHashTable::GetLocalSourceState(ExecutionContext &context) const { - return make_uniq(context, *this); -} - -RadixHTGlobalSourceState::RadixHTGlobalSourceState(ClientContext &context_p, const RadixPartitionedHashTable &radix_ht) - : context(context_p), finished(false), task_idx(0), task_done(0) { - for (column_t column_id = 0; column_id < radix_ht.group_types.size(); column_id++) { - column_ids.push_back(column_id); - } -} - -SourceResultType RadixHTGlobalSourceState::AssignTask(RadixHTGlobalSinkState &sink, RadixHTLocalSourceState &lstate, - InterruptState &interrupt_state) { - // First, try to get a partition index - auto guard = sink.Lock(); - if (finished || task_idx == sink.partitions.size()) { - lstate.ht.reset(); - return SourceResultType::FINISHED; - } - lstate.task_idx = task_idx++; - - // We got a partition index - auto &partition = *sink.partitions[lstate.task_idx]; - auto partition_guard = partition.Lock(); - switch (partition.state) { - case AggregatePartitionState::READY_TO_FINALIZE: - partition.state = AggregatePartitionState::FINALIZE_IN_PROGRESS; - lstate.task = RadixHTSourceTaskType::FINALIZE; - return SourceResultType::HAVE_MORE_OUTPUT; - case AggregatePartitionState::FINALIZE_IN_PROGRESS: - lstate.task = RadixHTSourceTaskType::SCAN; - lstate.scan_status = RadixHTScanStatus::INIT; - return partition.BlockSource(partition_guard, interrupt_state); - case AggregatePartitionState::READY_TO_SCAN: - lstate.task = RadixHTSourceTaskType::SCAN; - lstate.scan_status = RadixHTScanStatus::INIT; - return SourceResultType::HAVE_MORE_OUTPUT; - default: - throw InternalException("Unexpected AggregatePartitionState in RadixHTLocalSourceState::Finalize!"); - } -} - -RadixHTLocalSourceState::RadixHTLocalSourceState(ExecutionContext &context, const RadixPartitionedHashTable &radix_ht) - : task(RadixHTSourceTaskType::NO_TASK), task_idx(DConstants::INVALID_INDEX), scan_status(RadixHTScanStatus::DONE), - layout(radix_ht.GetLayout().Copy()), aggregate_allocator(BufferAllocator::Get(context.client)) { - auto &allocator = BufferAllocator::Get(context.client); - auto scan_chunk_types = radix_ht.group_types; - for (auto &aggr_type : radix_ht.op.aggregate_return_types) { - scan_chunk_types.push_back(aggr_type); - } - scan_chunk.Initialize(allocator, scan_chunk_types); -} - -void RadixHTLocalSourceState::ExecuteTask(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, - DataChunk &chunk) { - D_ASSERT(task != RadixHTSourceTaskType::NO_TASK); - switch (task) { - case RadixHTSourceTaskType::FINALIZE: - Finalize(sink, gstate); - break; - case RadixHTSourceTaskType::SCAN: - Scan(sink, gstate, chunk); - break; - default: - throw InternalException("Unexpected RadixHTSourceTaskType in ExecuteTask!"); - } -} - -void RadixHTLocalSourceState::Finalize(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate) { - D_ASSERT(task == RadixHTSourceTaskType::FINALIZE); - D_ASSERT(scan_status != RadixHTScanStatus::IN_PROGRESS); - auto &partition = *sink.partitions[task_idx]; - - if (!ht) { - // This capacity would always be sufficient for all data - const auto capacity = GroupedAggregateHashTable::GetCapacityForCount(partition.data->Count()); - - // However, we will limit the initial capacity so we don't do a huge over-allocation - const auto n_threads = 
NumericCast(TaskScheduler::GetScheduler(gstate.context).NumberOfThreads()); - const auto memory_limit = BufferManager::GetBufferManager(gstate.context).GetMaxMemory(); - const idx_t thread_limit = LossyNumericCast(0.6 * double(memory_limit) / double(n_threads)); - - const idx_t size_per_entry = partition.data->SizeInBytes() / MaxValue(partition.data->Count(), 1) + - idx_t(GroupedAggregateHashTable::LOAD_FACTOR * sizeof(ht_entry_t)); - // but not lower than the initial capacity - const auto capacity_limit = - MaxValue(NextPowerOfTwo(thread_limit / size_per_entry), GroupedAggregateHashTable::InitialCapacity()); - - ht = sink.radix_ht.CreateHT(gstate.context, MinValue(capacity, capacity_limit), 0); - } else { - ht->Abandon(); - } - - // Now combine the uncombined data using this thread's HT - ht->Combine(*partition.data, &partition.progress); - partition.progress = 1; - - // Move the combined data back to the partition - partition.data = - make_uniq(BufferManager::GetBufferManager(gstate.context), sink.radix_ht.GetLayout()); - partition.data->Combine(*ht->AcquirePartitionedData()->GetPartitions()[0]); - - // Update thread-global state - auto guard = sink.Lock(); - sink.stored_allocators.emplace_back(ht->GetAggregateAllocator()); - if (task_idx == sink.partitions.size()) { - ht.reset(); - } - const auto finalizes_done = ++sink.finalize_done; - D_ASSERT(finalizes_done <= sink.partitions.size()); - if (finalizes_done == sink.partitions.size()) { - // All finalizes are done, set remaining size to 0 - sink.temporary_memory_state->SetZero(); - } - - // Update partition state - auto partition_guard = partition.Lock(); - partition.state = AggregatePartitionState::READY_TO_SCAN; - partition.UnblockTasks(partition_guard); - - // This thread will scan the partition - task = RadixHTSourceTaskType::SCAN; - scan_status = RadixHTScanStatus::INIT; -} - -void RadixHTLocalSourceState::Scan(RadixHTGlobalSinkState &sink, RadixHTGlobalSourceState &gstate, DataChunk &chunk) { - D_ASSERT(task == RadixHTSourceTaskType::SCAN); - D_ASSERT(scan_status != RadixHTScanStatus::DONE); - - auto &partition = *sink.partitions[task_idx]; - D_ASSERT(partition.state == AggregatePartitionState::READY_TO_SCAN); - auto &data_collection = *partition.data; - - if (scan_status == RadixHTScanStatus::INIT) { - data_collection.InitializeScan(scan_state, gstate.column_ids, sink.scan_pin_properties); - scan_status = RadixHTScanStatus::IN_PROGRESS; - } - - if (!data_collection.Scan(scan_state, scan_chunk)) { - if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE) { - data_collection.Reset(); - } - scan_status = RadixHTScanStatus::DONE; - auto guard = sink.Lock(); - if (++gstate.task_done == sink.partitions.size()) { - gstate.finished = true; - } - return; - } - - RowOperationsState row_state(aggregate_allocator); - const auto group_cols = layout.ColumnCount() - 1; - RowOperations::FinalizeStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk, group_cols); - - if (sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE && layout.HasDestructor()) { - RowOperations::DestroyStates(row_state, layout, scan_state.chunk_state.row_locations, scan_chunk.size()); - } - - auto &radix_ht = sink.radix_ht; - idx_t chunk_index = 0; - for (auto &entry : radix_ht.grouping_set) { - chunk.data[entry].Reference(scan_chunk.data[chunk_index++]); - } - for (auto null_group : radix_ht.null_groups) { - chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); - 
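// Editorial note (not part of the original source): group columns that do not
// belong to this grouping set are emitted as constant NULL vectors, which is the
// GROUPING SETS semantics. E.g., under GROUP BY GROUPING SETS ((a), (b)), the scan
// of the (a)-set hash table returns column b (hypothetical index b_idx) as a single
// constant NULL instead of materializing a NULL per row:
//
//   chunk.data[b_idx].SetVectorType(VectorType::CONSTANT_VECTOR);
//   ConstantVector::SetNull(chunk.data[b_idx], true);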
ConstantVector::SetNull(chunk.data[null_group], true); - } - D_ASSERT(radix_ht.grouping_set.size() + radix_ht.null_groups.size() == radix_ht.op.GroupCount()); - for (idx_t col_idx = 0; col_idx < radix_ht.op.aggregates.size(); col_idx++) { - chunk.data[radix_ht.op.GroupCount() + col_idx].Reference( - scan_chunk.data[radix_ht.group_types.size() + col_idx]); - } - D_ASSERT(radix_ht.op.grouping_functions.size() == radix_ht.grouping_values.size()); - for (idx_t i = 0; i < radix_ht.op.grouping_functions.size(); i++) { - chunk.data[radix_ht.op.GroupCount() + radix_ht.op.aggregates.size() + i].Reference(radix_ht.grouping_values[i]); - } - chunk.SetCardinality(scan_chunk); - D_ASSERT(chunk.size() != 0); -} - -bool RadixHTLocalSourceState::TaskFinished() { - switch (task) { - case RadixHTSourceTaskType::FINALIZE: - return true; - case RadixHTSourceTaskType::SCAN: - return scan_status == RadixHTScanStatus::DONE; - default: - D_ASSERT(task == RadixHTSourceTaskType::NO_TASK); - return true; - } -} - -SourceResultType RadixPartitionedHashTable::GetData(ExecutionContext &context, DataChunk &chunk, - GlobalSinkState &sink_p, OperatorSourceInput &input) const { - auto &sink = sink_p.Cast(); - D_ASSERT(sink.finalized); - - auto &gstate = input.global_state.Cast(); - auto &lstate = input.local_state.Cast(); - D_ASSERT(sink.scan_pin_properties == TupleDataPinProperties::UNPIN_AFTER_DONE || - sink.scan_pin_properties == TupleDataPinProperties::DESTROY_AFTER_DONE); - - if (gstate.finished) { - return SourceResultType::FINISHED; - } - - if (sink.count_before_combining == 0) { - if (grouping_set.empty()) { - // Special case hack to sort out aggregating from empty intermediates for aggregations without groups - D_ASSERT(chunk.ColumnCount() == null_groups.size() + op.aggregates.size() + op.grouping_functions.size()); - // For each column in the aggregates, set to initial state - chunk.SetCardinality(1); - for (auto null_group : null_groups) { - chunk.data[null_group].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(chunk.data[null_group], true); - } - ArenaAllocator allocator(BufferAllocator::Get(context.client)); - for (idx_t i = 0; i < op.aggregates.size(); i++) { - D_ASSERT(op.aggregates[i]->GetExpressionClass() == ExpressionClass::BOUND_AGGREGATE); - auto &aggr = op.aggregates[i]->Cast(); - auto aggr_state = make_unsafe_uniq_array_uninitialized(aggr.function.state_size(aggr.function)); - aggr.function.initialize(aggr.function, aggr_state.get()); - - AggregateInputData aggr_input_data(aggr.bind_info.get(), allocator); - Vector state_vector(Value::POINTER(CastPointerToValue(aggr_state.get()))); - aggr.function.finalize(state_vector, aggr_input_data, chunk.data[null_groups.size() + i], 1, 0); - if (aggr.function.destructor) { - aggr.function.destructor(state_vector, aggr_input_data, 1); - } - } - // Place the grouping values (all the groups of the grouping_set condensed into a single value) - // Behind the null groups + aggregates - for (idx_t i = 0; i < op.grouping_functions.size(); i++) { - chunk.data[null_groups.size() + op.aggregates.size() + i].Reference(grouping_values[i]); - } - } - gstate.finished = true; - return SourceResultType::FINISHED; - } - - while (!gstate.finished && chunk.size() == 0) { - if (lstate.TaskFinished()) { - const auto res = gstate.AssignTask(sink, lstate, input.interrupt_state); - if (res != SourceResultType::HAVE_MORE_OUTPUT) { - D_ASSERT(res == SourceResultType::FINISHED || res == SourceResultType::BLOCKED); - return res; - } - } - lstate.ExecuteTask(sink, 
gstate, chunk); - } - - if (chunk.size() != 0) { - return SourceResultType::HAVE_MORE_OUTPUT; - } else { - return SourceResultType::FINISHED; - } -} - -ProgressData RadixPartitionedHashTable::GetProgress(ClientContext &, GlobalSinkState &sink_p, - GlobalSourceState &gstate_p) const { - auto &sink = sink_p.Cast(); - auto &gstate = gstate_p.Cast(); - - // Get partition combine progress, weigh it 2x - ProgressData progress; - for (auto &partition : sink.partitions) { - progress.done += 2.0 * partition->progress; - } - - // Get scan progress, weigh it 1x - progress.done += 1.0 * double(gstate.task_done); - - // Divide by 3x for the weights, and the number of partitions to get a value between 0 and 1 again - progress.total += 3.0 * double(sink.partitions.size()); - - return progress; -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp b/src/duckdb/src/execution/sample/base_reservoir_sample.cpp deleted file mode 100644 index 0f0fcdf7a..000000000 --- a/src/duckdb/src/execution/sample/base_reservoir_sample.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#include "duckdb/execution/reservoir_sample.hpp" -#include - -namespace duckdb { - -double BaseReservoirSampling::GetMinWeightFromTuplesSeen(idx_t rows_seen_total) { - // this function was obtained using https://mycurvefit.com. Inputting multiple x, y values into - // The - switch (rows_seen_total) { - case 0: - return 0; - case 1: - return 0.000161; - case 2: - return 0.530136; - case 3: - return 0.693454; - default: { - return (0.99 - 0.355 * std::exp(-0.07 * static_cast(rows_seen_total))); - } - } -} - -BaseReservoirSampling::BaseReservoirSampling(int64_t seed) : random(seed) { - next_index_to_sample = 0; - min_weight_threshold = 0; - min_weighted_entry_index = 0; - num_entries_to_skip_b4_next_sample = 0; - num_entries_seen_total = 0; -} - -BaseReservoirSampling::BaseReservoirSampling() : BaseReservoirSampling(1) { -} - -unique_ptr BaseReservoirSampling::Copy() { - auto ret = make_uniq(1); - ret->reservoir_weights = reservoir_weights; - ret->next_index_to_sample = next_index_to_sample; - ret->min_weight_threshold = min_weight_threshold; - ret->min_weighted_entry_index = min_weighted_entry_index; - ret->num_entries_to_skip_b4_next_sample = num_entries_to_skip_b4_next_sample; - ret->num_entries_seen_total = num_entries_seen_total; - return ret; -} - -void BaseReservoirSampling::InitializeReservoirWeights(idx_t cur_size, idx_t sample_size) { - //! 1: The first m items of V are inserted into R - //! first we need to check if the reservoir already has "m" elements - //! 2. For each item vi ∈ R: Calculate a key ki = random(0, 1) - //! we then define the threshold to enter the reservoir T_w as the minimum key of R - //! we use a priority queue to extract the minimum key in O(1) time - if (cur_size == sample_size) { - //! 2. For each item vi ∈ R: Calculate a key ki = random(0, 1) - //! we then define the threshold to enter the reservoir T_w as the minimum key of R - //! we use a priority queue to extract the minimum key in O(1) time - for (idx_t i = 0; i < sample_size; i++) { - double k_i = random.NextRandom(); - reservoir_weights.emplace(-k_i, i); - } - SetNextEntry(); - } -} - -void BaseReservoirSampling::SetNextEntry() { - D_ASSERT(!reservoir_weights.empty()); - //! 4. Let r = random(0, 1) and Xw = log(r) / log(T_w) - auto &min_key = reservoir_weights.top(); - double t_w = -min_key.first; - double r = random.NextRandom32(); - double x_w = log(r) / log(t_w); - //! 5. 
From the current item v_c skip items until item v_i, such that: - //! 6. w_c + w_{c+1} + ... + w_{i-1} < X_w <= w_c + w_{c+1} + ... + w_{i-1} + w_i - //! since all our weights are 1 (uniform sampling), we can just determine the number of elements to skip - min_weight_threshold = t_w; - min_weighted_entry_index = min_key.second; - next_index_to_sample = MaxValue<idx_t>(1, idx_t(round(x_w))); - num_entries_to_skip_b4_next_sample = 0; -} - -void BaseReservoirSampling::ReplaceElementWithIndex(idx_t entry_index, double with_weight, bool pop) { - - if (pop) { - reservoir_weights.pop(); - } - double r2 = with_weight; - //! now we insert the new weight into the reservoir - reservoir_weights.emplace(-r2, entry_index); - //! we update the min entry with the new min entry in the reservoir - SetNextEntry(); -} - -void BaseReservoirSampling::ReplaceElement(double with_weight) { - //! replace the entry in the reservoir - //! pop the minimum entry - reservoir_weights.pop(); - //! now update the reservoir - //! 8. Let t_w = T_w, r2 = random(t_w, 1) and v_i's key: k_i = (r2)^(1/w_i) - //! 9. The new threshold T_w is the new minimum key of R - //! we generate a random number between (min_weight_threshold, 1) - double r2 = random.NextRandom(min_weight_threshold, 1); - - //! if we are merging two reservoir samples use the weight passed - if (with_weight >= 0) { - r2 = with_weight; - } - //! now we insert the new weight into the reservoir - reservoir_weights.emplace(-r2, min_weighted_entry_index); - //! we update the min entry with the new min entry in the reservoir - SetNextEntry(); -} - -void BaseReservoirSampling::UpdateMinWeightThreshold() { - if (!reservoir_weights.empty()) { - min_weight_threshold = -reservoir_weights.top().first; - min_weighted_entry_index = reservoir_weights.top().second; - return; - } - min_weight_threshold = 1; -} - -void BaseReservoirSampling::FillWeights(SelectionVector &sel, idx_t &sel_size) { - if (!reservoir_weights.empty()) { - return; - } - D_ASSERT(reservoir_weights.empty()); - auto num_entries_seen_normalized = num_entries_seen_total / FIXED_SAMPLE_SIZE; - auto min_weight = GetMinWeightFromTuplesSeen(num_entries_seen_normalized); - for (idx_t i = 0; i < sel_size; i++) { - auto weight = random.NextRandom(min_weight, 1); - reservoir_weights.emplace(-weight, i); - } - D_ASSERT(reservoir_weights.size() <= sel_size); - SetNextEntry(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/execution/sample/reservoir_sample.cpp b/src/duckdb/src/execution/sample/reservoir_sample.cpp deleted file mode 100644 index ba777b609..000000000 --- a/src/duckdb/src/execution/sample/reservoir_sample.cpp +++ /dev/null @@ -1,930 +0,0 @@ -#include "duckdb/execution/reservoir_sample.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include - -namespace duckdb { - -std::pair<double, idx_t> BlockingSample::PopFromWeightQueue() { - D_ASSERT(base_reservoir_sample && !base_reservoir_sample->reservoir_weights.empty()); - auto ret = base_reservoir_sample->reservoir_weights.top(); - base_reservoir_sample->reservoir_weights.pop(); - - base_reservoir_sample->UpdateMinWeightThreshold(); - D_ASSERT(base_reservoir_sample->min_weight_threshold > 0); - return ret; -} - -double BlockingSample::GetMinWeightThreshold() { - return base_reservoir_sample->min_weight_threshold; -} - -idx_t BlockingSample::GetPriorityQueueSize() { - return base_reservoir_sample->reservoir_weights.size(); -} - -void BlockingSample::Destroy() { - destroyed = true; -} - -void ReservoirChunk::Serialize(Serializer &serializer)
const { - chunk.Serialize(serializer); -} - -unique_ptr<ReservoirChunk> ReservoirChunk::Deserialize(Deserializer &deserializer) { - auto result = make_uniq<ReservoirChunk>(); - result->chunk.Deserialize(deserializer); - return result; -} - -unique_ptr<ReservoirChunk> ReservoirChunk::Copy() const { - auto copy = make_uniq<ReservoirChunk>(); - copy->chunk.Initialize(Allocator::DefaultAllocator(), chunk.GetTypes()); - - chunk.Copy(copy->chunk); - return copy; -} - -ReservoirSample::ReservoirSample(idx_t sample_count, unique_ptr<ReservoirChunk> reservoir_chunk) - : ReservoirSample(Allocator::DefaultAllocator(), sample_count, 1) { - if (reservoir_chunk) { - this->reservoir_chunk = std::move(reservoir_chunk); - sel_size = this->reservoir_chunk->chunk.size(); - sel = SelectionVector(0, sel_size); - ExpandSerializedSample(); - } - stats_sample = true; -} - -ReservoirSample::ReservoirSample(Allocator &allocator, idx_t sample_count, int64_t seed) - : BlockingSample(seed), sample_count(sample_count), allocator(allocator) { - base_reservoir_sample = make_uniq<BaseReservoirSampling>(seed); - type = SampleType::RESERVOIR_SAMPLE; - reservoir_chunk = nullptr; - stats_sample = false; - sel = SelectionVector(sample_count); - sel_size = 0; -} - -idx_t ReservoirSample::GetSampleCount() { - return sample_count; -} - -idx_t ReservoirSample::NumSamplesCollected() const { - if (!reservoir_chunk) { - return 0; - } - return reservoir_chunk->chunk.size(); -} - -SamplingState ReservoirSample::GetSamplingState() const { - if (base_reservoir_sample->reservoir_weights.empty()) { - return SamplingState::RANDOM; - } - return SamplingState::RESERVOIR; -} - -idx_t ReservoirSample::GetActiveSampleCount() const { - switch (GetSamplingState()) { - case SamplingState::RANDOM: - return sel_size; - case SamplingState::RESERVOIR: - return base_reservoir_sample->reservoir_weights.size(); - default: - throw InternalException("Sampling State is INVALID"); - } -} - -idx_t ReservoirSample::GetTuplesSeen() const { - return base_reservoir_sample->num_entries_seen_total; -} - -DataChunk &ReservoirSample::Chunk() { - D_ASSERT(reservoir_chunk); - return reservoir_chunk->chunk; -} - -unique_ptr<DataChunk> ReservoirSample::GetChunk() { - if (destroyed || !reservoir_chunk || Chunk().size() == 0) { - return nullptr; - } - // cannot destroy internal samples. - auto ret = make_uniq<DataChunk>(); - - SelectionVector ret_sel(STANDARD_VECTOR_SIZE); - idx_t collected_samples = GetActiveSampleCount(); - - if (collected_samples == 0) { - return nullptr; - } - - idx_t samples_remaining; - idx_t return_chunk_size; - if (collected_samples > STANDARD_VECTOR_SIZE) { - samples_remaining = collected_samples - STANDARD_VECTOR_SIZE; - return_chunk_size = STANDARD_VECTOR_SIZE; - } else { - samples_remaining = 0; - return_chunk_size = collected_samples; - } - - for (idx_t i = samples_remaining; i < collected_samples; i++) { - // pop samples and reduce size of selection vector.
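// Editorial note (not part of the original source): returned rows are consumed by
// this loop. In RESERVOIR mode the emitted row is whichever entry currently tops
// the min-weight priority queue; in RANDOM mode rows are taken from the tail of
// the selection vector. A sketch of one iteration, with Pop() as a hypothetical
// shorthand for PopFromWeightQueue():
//
//   auto row = reservoir_mode ? sel.get_index(Pop().second) // weighted entry
//                             : sel.get_index(i);           // tail entry
//   ret_sel.set_index(out_idx, row);
//   sel_size -= 1;                                          // the sample shrinks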
- if (GetSamplingState() == SamplingState::RESERVOIR) { - auto top = PopFromWeightQueue(); - ret_sel.set_index(i - samples_remaining, sel.get_index(top.second)); - } else { - ret_sel.set_index(i - samples_remaining, sel.get_index(i)); - } - sel_size -= 1; - } - - auto reservoir_types = Chunk().GetTypes(); - - ret->Initialize(allocator, reservoir_types, STANDARD_VECTOR_SIZE); - ret->Slice(Chunk(), ret_sel, return_chunk_size); - ret->SetCardinality(return_chunk_size); - return ret; -} - -unique_ptr ReservoirSample::CreateNewSampleChunk(vector &types, idx_t size) const { - auto new_sample_chunk = make_uniq(); - new_sample_chunk->chunk.Initialize(Allocator::DefaultAllocator(), types, size); - - // set the NULL columns correctly - for (idx_t col_idx = 0; col_idx < types.size(); col_idx++) { - if (!ValidSampleType(types[col_idx]) && stats_sample) { - new_sample_chunk->chunk.data[col_idx].SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(new_sample_chunk->chunk.data[col_idx], true); - } - } - return new_sample_chunk; -} - -void ReservoirSample::Vacuum() { - Verify(); - if (NumSamplesCollected() <= FIXED_SAMPLE_SIZE || !reservoir_chunk || destroyed) { - // sample is destroyed or too small to shrink - return; - } - - auto ret = Copy(); - auto ret_reservoir = duckdb::unique_ptr_cast(std::move(ret)); - reservoir_chunk = std::move(ret_reservoir->reservoir_chunk); - sel = std::move(ret_reservoir->sel); - sel_size = ret_reservoir->sel_size; - - Verify(); - // We should only have one sample chunk now. - D_ASSERT(Chunk().size() > 0 && Chunk().size() <= sample_count); -} - -unique_ptr ReservoirSample::Copy() const { - - auto ret = make_uniq(sample_count); - ret->stats_sample = stats_sample; - - ret->base_reservoir_sample = base_reservoir_sample->Copy(); - ret->destroyed = destroyed; - - if (!reservoir_chunk || destroyed) { - return unique_ptr_cast(std::move(ret)); - } - - D_ASSERT(reservoir_chunk); - - // create a new sample chunk to store new samples - auto types = reservoir_chunk->chunk.GetTypes(); - // how many values should be copied - idx_t values_to_copy = MinValue(GetActiveSampleCount(), sample_count); - - auto new_sample_chunk = CreateNewSampleChunk(types, GetReservoirChunkCapacity()); - - SelectionVector sel_copy(sel); - - ret->reservoir_chunk = std::move(new_sample_chunk); - ret->UpdateSampleAppend(ret->reservoir_chunk->chunk, reservoir_chunk->chunk, sel_copy, values_to_copy); - ret->sel = SelectionVector(values_to_copy); - for (idx_t i = 0; i < values_to_copy; i++) { - ret->sel.set_index(i, i); - } - ret->sel_size = sel_size; - D_ASSERT(ret->reservoir_chunk->chunk.size() <= sample_count); - ret->Verify(); - return unique_ptr_cast(std::move(ret)); -} - -void ReservoirSample::ConvertToReservoirSample() { - D_ASSERT(sel_size <= sample_count); - base_reservoir_sample->FillWeights(sel, sel_size); -} - -vector ReservoirSample::GetRandomizedVector(uint32_t range, uint32_t size) const { - vector ret; - ret.reserve(range); - for (uint32_t i = 0; i < range; i++) { - ret.push_back(i); - } - if (size == FIXED_SAMPLE_SIZE) { - std::shuffle(ret.begin(), ret.end(), base_reservoir_sample->random); - return ret; - } - for (uint32_t i = 0; i < size; i++) { - uint32_t random_shuffle = base_reservoir_sample->random.NextRandomInteger32(i, range); - if (random_shuffle == i) { - // leave the value where it is - continue; - } - uint32_t tmp = ret[random_shuffle]; - // basically replacing the tuple that was at index actual_sample_indexes[random_shuffle] - ret[random_shuffle] = ret[i]; - ret[i] = tmp; 
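// Editorial note (not part of the original source): this branch is a partial
// Fisher-Yates shuffle. Only the first `size` positions are randomized, which
// costs O(size) swaps instead of shuffling the entire range:
//
//   for (uint32_t i = 0; i < size; i++) {
//       swap(ret[i], ret[uniform_random_in(i, range)]); // uniform_random_in is illustrative
//   }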
- } - return ret; -} - -void ReservoirSample::SimpleMerge(ReservoirSample &other) { - D_ASSERT(GetPriorityQueueSize() == 0); - D_ASSERT(other.GetPriorityQueueSize() == 0); - D_ASSERT(GetSamplingState() == SamplingState::RANDOM); - D_ASSERT(other.GetSamplingState() == SamplingState::RANDOM); - - if (other.GetActiveSampleCount() == 0 && other.GetTuplesSeen() == 0) { - return; - } - - if (GetActiveSampleCount() == 0 && GetTuplesSeen() == 0) { - sel = SelectionVector(other.sel); - sel_size = other.sel_size; - base_reservoir_sample->num_entries_seen_total = other.GetTuplesSeen(); - return; - } - - idx_t total_seen = GetTuplesSeen() + other.GetTuplesSeen(); - - auto weight_tuples_this = static_cast<double>(GetTuplesSeen()) / static_cast<double>(total_seen); - auto weight_tuples_other = static_cast<double>(other.GetTuplesSeen()) / static_cast<double>(total_seen); - - // If the weights don't add up to 1, most likely a simple merge occurred and no new samples were added. - // If that is the case, add the missing weight to the lower-weighted sample to adjust. - // This avoids cases where, e.g., a 20k-row table grows by another 20k rows one row at a time: - // the missing weights would otherwise accumulate, and the adjustment keeps the distribution even. - if (weight_tuples_this + weight_tuples_other < 1) { - weight_tuples_other += 1 - (weight_tuples_other + weight_tuples_this); - } - - idx_t keep_from_this = 0; - idx_t keep_from_other = 0; - D_ASSERT(stats_sample); - D_ASSERT(sample_count == FIXED_SAMPLE_SIZE); - D_ASSERT(sample_count == other.sample_count); - auto sample_count_double = static_cast<double>(sample_count); - - if (weight_tuples_this > weight_tuples_other) { - keep_from_this = MinValue<idx_t>(static_cast<idx_t>(round(sample_count_double * weight_tuples_this)), - GetActiveSampleCount()); - keep_from_other = MinValue<idx_t>(sample_count - keep_from_this, other.GetActiveSampleCount()); - } else { - keep_from_other = MinValue<idx_t>(static_cast<idx_t>(round(sample_count_double * weight_tuples_other)), - other.GetActiveSampleCount()); - keep_from_this = MinValue<idx_t>(sample_count - keep_from_other, GetActiveSampleCount()); - } - - D_ASSERT(keep_from_this <= GetActiveSampleCount()); - D_ASSERT(keep_from_other <= other.GetActiveSampleCount()); - D_ASSERT(keep_from_other + keep_from_this <= FIXED_SAMPLE_SIZE); - idx_t size_after_merge = MinValue<idx_t>(keep_from_other + keep_from_this, FIXED_SAMPLE_SIZE); - - // Check if appending the other samples to this will go over the sample chunk size - if (reservoir_chunk->chunk.size() + keep_from_other > GetReservoirChunkCapacity()) { - Vacuum(); - } - - D_ASSERT(size_after_merge <= other.GetActiveSampleCount() + GetActiveSampleCount()); - SelectionVector chunk_sel(keep_from_other); - auto offset = reservoir_chunk->chunk.size(); - for (idx_t i = keep_from_this; i < size_after_merge; i++) { - if (i >= GetActiveSampleCount()) { - sel.set_index(GetActiveSampleCount(), offset); - sel_size += 1; - } else { - sel.set_index(i, offset); - } - chunk_sel.set_index(i - keep_from_this, other.sel.get_index(i - keep_from_this)); - offset += 1; - } - - D_ASSERT(GetActiveSampleCount() == size_after_merge); - - // Copy the rows that make it to the sample from other and put them into this. - UpdateSampleAppend(reservoir_chunk->chunk, other.reservoir_chunk->chunk, chunk_sel, keep_from_other); - base_reservoir_sample->num_entries_seen_total += other.GetTuplesSeen(); - - // if THIS has too many samples now, we convert it to a slower sample.
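// Editorial example (not part of the original source): with FIXED_SAMPLE_SIZE = 2048
// (the fixed sample size referenced elsewhere in this file) and an illustrative
// FAST_TO_SLOW_THRESHOLD of 10, the check below fires once 2048 * 10 = 20480 tuples
// have been seen; from then on every sampled row carries an explicit weight in the
// priority queue instead of relying on cheap random replacement.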
- if (GetTuplesSeen() >= FIXED_SAMPLE_SIZE * FAST_TO_SLOW_THRESHOLD) { - ConvertToReservoirSample(); - } - Verify(); -} - -void ReservoirSample::WeightedMerge(ReservoirSample &other_sample) { - D_ASSERT(GetSamplingState() == SamplingState::RESERVOIR); - D_ASSERT(other_sample.GetSamplingState() == SamplingState::RESERVOIR); - - // Find out how many samples we want to keep. - idx_t total_samples = GetActiveSampleCount() + other_sample.GetActiveSampleCount(); - idx_t total_samples_seen = - base_reservoir_sample->num_entries_seen_total + other_sample.base_reservoir_sample->num_entries_seen_total; - idx_t num_samples_to_keep = MinValue<idx_t>(total_samples, MinValue<idx_t>(sample_count, total_samples_seen)); - - D_ASSERT(GetActiveSampleCount() <= num_samples_to_keep); - D_ASSERT(total_samples <= FIXED_SAMPLE_SIZE * 2); - - // pop from the base_reservoir weights until there are num_samples_to_keep left. - vector<idx_t> this_indexes_to_replace; - for (idx_t i = num_samples_to_keep; i < total_samples; i++) { - auto min_weight_this = base_reservoir_sample->min_weight_threshold; - auto min_weight_other = other_sample.base_reservoir_sample->min_weight_threshold; - // the min weight threshold is always positive - if (min_weight_this > min_weight_other) { - // pop from other - other_sample.base_reservoir_sample->reservoir_weights.pop(); - other_sample.base_reservoir_sample->UpdateMinWeightThreshold(); - } else { - auto top_this = PopFromWeightQueue(); - this_indexes_to_replace.push_back(top_this.second); - base_reservoir_sample->UpdateMinWeightThreshold(); - } - } - - D_ASSERT(other_sample.GetPriorityQueueSize() + GetPriorityQueueSize() <= FIXED_SAMPLE_SIZE); - D_ASSERT(other_sample.GetPriorityQueueSize() + GetPriorityQueueSize() == num_samples_to_keep); - D_ASSERT(other_sample.reservoir_chunk->chunk.GetTypes() == reservoir_chunk->chunk.GetTypes()); - - // Prepare a selection vector to copy data from the other sample chunk to this sample chunk - SelectionVector sel_other(other_sample.GetPriorityQueueSize()); - D_ASSERT(GetPriorityQueueSize() <= num_samples_to_keep); - D_ASSERT(other_sample.GetPriorityQueueSize() >= this_indexes_to_replace.size()); - idx_t chunk_offset = 0; - - // Now push weights from other.base_reservoir_sample to this - // Depending on how many sample values "this" has, we either need to add to the selection vector - // or replace values in "this"'s selection vector - idx_t i = 0; - while (other_sample.GetPriorityQueueSize() > 0) { - auto other_top = other_sample.PopFromWeightQueue(); - idx_t index_for_new_pair = chunk_offset + reservoir_chunk->chunk.size(); - - // update the sel used to copy values from other to this - sel_other.set_index(chunk_offset, other_top.second); - if (i < this_indexes_to_replace.size()) { - auto replacement_index = this_indexes_to_replace[i]; - sel.set_index(replacement_index, index_for_new_pair); - other_top.second = replacement_index; - } else { - sel.set_index(sel_size, index_for_new_pair); - other_top.second = sel_size; - sel_size += 1; - } - - // make sure that the sample indexes are (this.sample_chunk.size() + chunk_offset) - base_reservoir_sample->reservoir_weights.push(other_top); - chunk_offset += 1; - i += 1; - } - - D_ASSERT(GetPriorityQueueSize() == num_samples_to_keep); - - base_reservoir_sample->UpdateMinWeightThreshold(); - D_ASSERT(base_reservoir_sample->min_weight_threshold > 0); - base_reservoir_sample->num_entries_seen_total = GetTuplesSeen() + other_sample.GetTuplesSeen(); - - UpdateSampleAppend(reservoir_chunk->chunk, other_sample.reservoir_chunk->chunk,
sel_other, chunk_offset); - if (reservoir_chunk->chunk.size() > FIXED_SAMPLE_SIZE * (FIXED_SAMPLE_SIZE_MULTIPLIER - 3)) { - Vacuum(); - } - - Verify(); -} - -void ReservoirSample::Merge(unique_ptr<BlockingSample> other) { - if (destroyed || other->destroyed) { - Destroy(); - return; - } - - D_ASSERT(other->type == SampleType::RESERVOIR_SAMPLE); - auto &other_sample = other->Cast<ReservoirSample>(); - - // if the other sample has not collected anything yet, return - if (!other_sample.reservoir_chunk || other_sample.reservoir_chunk->chunk.size() == 0) { - return; - } - - // this has not collected samples, take over the other - if (!reservoir_chunk || reservoir_chunk->chunk.size() == 0) { - base_reservoir_sample = std::move(other->base_reservoir_sample); - reservoir_chunk = std::move(other_sample.reservoir_chunk); - sel = SelectionVector(other_sample.sel); - sel_size = other_sample.sel_size; - Verify(); - return; - } - //! Both samples are still in the "fast sampling" method - if (GetSamplingState() == SamplingState::RANDOM && other_sample.GetSamplingState() == SamplingState::RANDOM) { - SimpleMerge(other_sample); - return; - } - - // One or neither of the samples is in the "Fast Sampling" method. - // When this is the case, switch both to slow sampling - ConvertToReservoirSample(); - other_sample.ConvertToReservoirSample(); - WeightedMerge(other_sample); -} - -void ReservoirSample::ShuffleSel(SelectionVector &sel, idx_t range, idx_t size) const { - auto randomized = GetRandomizedVector(static_cast<uint32_t>(range), static_cast<uint32_t>(size)); - SelectionVector original_sel(range); - for (idx_t i = 0; i < range; i++) { - original_sel.set_index(i, sel.get_index(i)); - } - for (idx_t i = 0; i < size; i++) { - sel.set_index(i, original_sel.get_index(randomized[i])); - } -} - -void ReservoirSample::NormalizeWeights() { - vector<std::pair<double, idx_t>> tmp_weights; - while (!base_reservoir_sample->reservoir_weights.empty()) { - auto top = base_reservoir_sample->reservoir_weights.top(); - tmp_weights.push_back(std::move(top)); - base_reservoir_sample->reservoir_weights.pop(); - } - std::sort(tmp_weights.begin(), tmp_weights.end(), - [&](std::pair<double, idx_t> a, std::pair<double, idx_t> b) { return a.second < b.second; }); - for (idx_t i = 0; i < tmp_weights.size(); i++) { - base_reservoir_sample->reservoir_weights.emplace(tmp_weights.at(i).first, i); - } - base_reservoir_sample->SetNextEntry(); -} - -void ReservoirSample::EvictOverBudgetSamples() { - Verify(); - if (!reservoir_chunk || destroyed) { - return; - } - - // since this is for serialization, we really need to make sure we keep a - // minimum of 1% of the rows or 2048 rows - idx_t num_samples_to_keep = - MinValue<idx_t>(FIXED_SAMPLE_SIZE, static_cast<idx_t>(SAVE_PERCENTAGE * static_cast<double>(GetTuplesSeen()))); - - if (num_samples_to_keep <= 0) { - reservoir_chunk->chunk.SetCardinality(0); - return; - } - - if (num_samples_to_keep == sample_count) { - return; - } - - // if we over-sampled, make sure we only keep the highest-percentage samples - std::unordered_set<idx_t> selections_to_delete; - - while (num_samples_to_keep < GetPriorityQueueSize()) { - auto top = PopFromWeightQueue(); - D_ASSERT(top.second < sel_size); - selections_to_delete.emplace(top.second); - } - - // set up reservoir chunk for the reservoir sample - D_ASSERT(reservoir_chunk->chunk.size() <= sample_count); - // create a new sample chunk to store new samples - auto types = reservoir_chunk->chunk.GetTypes(); - D_ASSERT(num_samples_to_keep <= sample_count); - D_ASSERT(stats_sample); - D_ASSERT(sample_count == FIXED_SAMPLE_SIZE); - auto new_reservoir_chunk = CreateNewSampleChunk(types, sample_count); - - // The
current selection vector can potentially have 2048 valid mappings. - // If we need to save a sample with less rows than that, we need to do the following - // 1. Create a new selection vector that doesn't point to the rows we are evicting - SelectionVector new_sel(num_samples_to_keep); - idx_t offset = 0; - for (idx_t i = 0; i < num_samples_to_keep + selections_to_delete.size(); i++) { - if (selections_to_delete.find(i) == selections_to_delete.end()) { - D_ASSERT(i - offset < num_samples_to_keep); - new_sel.set_index(i - offset, sel.get_index(i)); - } else { - offset += 1; - } - } - // 2. Update row_ids in our weights so that they don't store rows ids to - // indexes in the selection vector that have been evicted. - if (!selections_to_delete.empty()) { - NormalizeWeights(); - } - - D_ASSERT(reservoir_chunk->chunk.GetTypes() == new_reservoir_chunk->chunk.GetTypes()); - - UpdateSampleAppend(new_reservoir_chunk->chunk, reservoir_chunk->chunk, new_sel, num_samples_to_keep); - // set the cardinality - new_reservoir_chunk->chunk.SetCardinality(num_samples_to_keep); - reservoir_chunk = std::move(new_reservoir_chunk); - sel_size = num_samples_to_keep; - base_reservoir_sample->UpdateMinWeightThreshold(); -} - -void ReservoirSample::ExpandSerializedSample() { - if (!reservoir_chunk) { - return; - } - - auto types = reservoir_chunk->chunk.GetTypes(); - auto new_res_chunk = CreateNewSampleChunk(types, GetReservoirChunkCapacity()); - auto copy_count = reservoir_chunk->chunk.size(); - SelectionVector tmp_sel = SelectionVector(0, copy_count); - UpdateSampleAppend(new_res_chunk->chunk, reservoir_chunk->chunk, tmp_sel, copy_count); - new_res_chunk->chunk.SetCardinality(copy_count); - std::swap(reservoir_chunk, new_res_chunk); -} - -idx_t ReservoirSample::GetReservoirChunkCapacity() const { - return sample_count + (FIXED_SAMPLE_SIZE_MULTIPLIER * FIXED_SAMPLE_SIZE); -} - -idx_t ReservoirSample::FillReservoir(DataChunk &chunk) { - - idx_t ingested_count = 0; - if (!reservoir_chunk) { - if (chunk.size() > FIXED_SAMPLE_SIZE) { - throw InternalException("Creating sample with DataChunk that is larger than the fixed sample size"); - } - auto types = chunk.GetTypes(); - // create a new sample chunk to store new samples - reservoir_chunk = CreateNewSampleChunk(types, GetReservoirChunkCapacity()); - } - - idx_t actual_sample_index_start = GetActiveSampleCount(); - D_ASSERT(reservoir_chunk->chunk.ColumnCount() == chunk.ColumnCount()); - - if (reservoir_chunk->chunk.size() < sample_count) { - ingested_count = MinValue(sample_count - reservoir_chunk->chunk.size(), chunk.size()); - auto random_other_sel = - GetRandomizedVector(static_cast(ingested_count), static_cast(ingested_count)); - SelectionVector sel_for_input_chunk(ingested_count); - for (idx_t i = 0; i < ingested_count; i++) { - sel.set_index(actual_sample_index_start + i, actual_sample_index_start + i); - sel_for_input_chunk.set_index(i, random_other_sel[i]); - } - UpdateSampleAppend(reservoir_chunk->chunk, chunk, sel_for_input_chunk, ingested_count); - sel_size += ingested_count; - } - D_ASSERT(GetActiveSampleCount() <= sample_count); - D_ASSERT(GetActiveSampleCount() >= ingested_count); - // always return how many tuples were ingested - return ingested_count; -} - -void ReservoirSample::Destroy() { - destroyed = true; -} - -SelectionVectorHelper ReservoirSample::GetReplacementIndexes(idx_t sample_chunk_offset, - idx_t theoretical_chunk_length) { - if (GetSamplingState() == SamplingState::RANDOM) { - return GetReplacementIndexesFast(sample_chunk_offset, 
theoretical_chunk_length); - } - return GetReplacementIndexesSlow(sample_chunk_offset, theoretical_chunk_length); -} - -SelectionVectorHelper ReservoirSample::GetReplacementIndexesFast(idx_t sample_chunk_offset, idx_t chunk_length) { - - // how much weight do the other tuples have compared to the ones in this chunk? - auto weight_tuples_other = static_cast<double>(chunk_length) / static_cast<double>(GetTuplesSeen() + chunk_length); - auto num_to_pop = static_cast<uint32_t>(round(weight_tuples_other * static_cast<double>(sample_count))); - D_ASSERT(num_to_pop <= sample_count); - D_ASSERT(num_to_pop <= sel_size); - SelectionVectorHelper ret; - - if (num_to_pop == 0) { - ret.sel = SelectionVector(num_to_pop); - ret.size = 0; - return ret; - } - std::unordered_map<idx_t, idx_t> replacement_indexes; - SelectionVector chunk_sel(num_to_pop); - - auto random_indexes_chunk = GetRandomizedVector(static_cast<uint32_t>(chunk_length), num_to_pop); - auto random_sel_indexes = GetRandomizedVector(static_cast<uint32_t>(sel_size), num_to_pop); - for (idx_t i = 0; i < num_to_pop; i++) { - // update the selection vector for the reservoir sample - chunk_sel.set_index(i, random_indexes_chunk[i]); - // sel is not guaranteed to be random, so we update the indexes according to our - // random sel indexes. - sel.set_index(random_sel_indexes[i], sample_chunk_offset + i); - } - - D_ASSERT(sel_size == sample_count); - - ret.sel = SelectionVector(chunk_sel); - ret.size = num_to_pop; - return ret; -} - -SelectionVectorHelper ReservoirSample::GetReplacementIndexesSlow(const idx_t sample_chunk_offset, - const idx_t chunk_length) { - idx_t remaining = chunk_length; - std::unordered_map<idx_t, idx_t> ret_map; - idx_t sample_chunk_index = 0; - - idx_t base_offset = 0; - - while (true) { - idx_t offset = - base_reservoir_sample->next_index_to_sample - base_reservoir_sample->num_entries_to_skip_b4_next_sample; - if (offset >= remaining) { - // not in this chunk! increment current count and go to the next chunk - base_reservoir_sample->num_entries_to_skip_b4_next_sample += remaining; - break; - } - // in this chunk!
replace the element - // ret[index_in_new_chunk] = index_in_sample_chunk (the sample chunk offset will be applied later) - // D_ASSERT(sample_chunk_index == ret.size()); - ret_map[base_offset + offset] = sample_chunk_index; - double r2 = base_reservoir_sample->random.NextRandom32(base_reservoir_sample->min_weight_threshold, 1); - // replace element in our max_heap - // first get the top most pair - const auto top = PopFromWeightQueue(); - const auto index = top.second; - const auto index_in_sample_chunk = sample_chunk_offset + sample_chunk_index; - sel.set_index(index, index_in_sample_chunk); - base_reservoir_sample->ReplaceElementWithIndex(index, r2, false); - - sample_chunk_index += 1; - // shift the chunk forward - remaining -= offset; - base_offset += offset; - } - - // create selection vector to return - SelectionVector ret_sel(ret_map.size()); - D_ASSERT(sel_size == sample_count); - for (auto &kv : ret_map) { - ret_sel.set_index(kv.second, kv.first); - } - SelectionVectorHelper ret; - ret.sel = SelectionVector(ret_sel); - ret.size = static_cast(ret_map.size()); - return ret; -} - -void ReservoirSample::Finalize() { -} - -bool ReservoirSample::ValidSampleType(const LogicalType &type) { - return type.IsNumeric(); -} - -void ReservoirSample::UpdateSampleAppend(DataChunk &this_, DataChunk &other, SelectionVector &other_sel, - idx_t append_count) const { - idx_t new_size = this_.size() + append_count; - if (other.size() == 0) { - return; - } - D_ASSERT(this_.GetTypes() == other.GetTypes()); - - // UpdateSampleAppend(this_, other, other_sel, append_count); - D_ASSERT(this_.GetTypes() == other.GetTypes()); - auto types = reservoir_chunk->chunk.GetTypes(); - - for (idx_t i = 0; i < reservoir_chunk->chunk.ColumnCount(); i++) { - auto col_type = types[i]; - if (ValidSampleType(col_type) || !stats_sample) { - D_ASSERT(this_.data[i].GetVectorType() == VectorType::FLAT_VECTOR); - VectorOperations::Copy(other.data[i], this_.data[i], other_sel, append_count, 0, this_.size()); - } - } - this_.SetCardinality(new_size); -} - -void ReservoirSample::AddToReservoir(DataChunk &chunk) { - if (destroyed || chunk.size() == 0) { - return; - } - - idx_t tuples_consumed = FillReservoir(chunk); - base_reservoir_sample->num_entries_seen_total += tuples_consumed; - D_ASSERT(sample_count == 0 || reservoir_chunk->chunk.size() >= 1); - - if (tuples_consumed == chunk.size()) { - return; - } - - // the chunk filled the first FIXED_SAMPLE_SIZE chunk but still has tuples remaining - // slice the chunk and call AddToReservoir again. 
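// Editorial sketch (not part of the original source) of the slice-and-recurse
// pattern below, with hypothetical sizes: if FillReservoir consumed 100 rows of a
// 1000-row chunk, the remaining 900 rows are sliced out and fed back through
// AddToReservoir; the recursion terminates because the reservoir is now full:
//
//   idx_t consumed = FillReservoir(chunk);         // e.g. 100
//   if (consumed != 0 && consumed != chunk.size()) {
//       DataChunk rest;                            // rows [consumed, chunk.size())
//       // ... slice via a SelectionVector, exactly as in the code below ...
//       AddToReservoir(rest);
//   }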
- if (tuples_consumed != chunk.size() && tuples_consumed != 0) { - // FillReservoir consumed some of the chunk to reach FIXED_SAMPLE_SIZE, - // but tuples remain in the chunk, - // so we slice off the remainder and call AddToReservoir with it - auto slice = make_uniq<DataChunk>(); - auto samples_remaining = chunk.size() - tuples_consumed; - auto types = chunk.GetTypes(); - SelectionVector input_sel(samples_remaining); - for (idx_t i = 0; i < samples_remaining; i++) { - input_sel.set_index(i, tuples_consumed + i); - } - slice->Initialize(Allocator::DefaultAllocator(), types, samples_remaining); - slice->Slice(chunk, input_sel, samples_remaining); - slice->SetCardinality(samples_remaining); - AddToReservoir(*slice); - return; - } - - // at this point we should have collected at least sample_count samples - D_ASSERT(GetActiveSampleCount() >= sample_count); - - auto chunk_sel = GetReplacementIndexes(reservoir_chunk->chunk.size(), chunk.size()); - - if (chunk_sel.size == 0) { - // not adding any samples - return; - } - idx_t size = chunk_sel.size; - D_ASSERT(size <= chunk.size()); - - UpdateSampleAppend(reservoir_chunk->chunk, chunk, chunk_sel.sel, size); - - base_reservoir_sample->num_entries_seen_total += chunk.size(); - D_ASSERT(base_reservoir_sample->reservoir_weights.size() == 0 || - base_reservoir_sample->reservoir_weights.size() == sample_count); - - Verify(); - - // if we are over the threshold, we need to switch to slow sampling. - if (GetSamplingState() == SamplingState::RANDOM && GetTuplesSeen() >= FIXED_SAMPLE_SIZE * FAST_TO_SLOW_THRESHOLD) { - ConvertToReservoirSample(); - } - if (reservoir_chunk->chunk.size() >= (GetReservoirChunkCapacity() - (static_cast<idx_t>(FIXED_SAMPLE_SIZE) * 3))) { - Vacuum(); - } -} - -void ReservoirSample::Verify() { -#ifdef DEBUG - if (destroyed) { - return; - } - if (GetPriorityQueueSize() == 0) { - D_ASSERT(GetActiveSampleCount() <= sample_count); - D_ASSERT(GetTuplesSeen() >= GetActiveSampleCount()); - return; - } - if (NumSamplesCollected() > sample_count) { - D_ASSERT(GetPriorityQueueSize() == sample_count); - } else if (NumSamplesCollected() <= sample_count && GetPriorityQueueSize() > 0) { - // it's possible to collect more samples than your priority queue size. - // see sample_converts_to_reservoir_sample.test - D_ASSERT(NumSamplesCollected() >= GetPriorityQueueSize()); - } - auto base_reservoir_copy = base_reservoir_sample->Copy(); - std::unordered_map<idx_t, idx_t> index_count; - while (!base_reservoir_copy->reservoir_weights.empty()) { - auto &pair = base_reservoir_copy->reservoir_weights.top(); - if (index_count.find(pair.second) == index_count.end()) { - index_count[pair.second] = 1; - base_reservoir_copy->reservoir_weights.pop(); - } else { - index_count[pair.second] += 1; - base_reservoir_copy->reservoir_weights.pop(); - throw InternalException("Duplicate selection index in reservoir weights"); - } - } - // TODO: Verify the Sel as well. No duplicate indices.
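// Editorial sketch (not part of the original source) of how the TODO above could be
// implemented, mirroring the duplicate check on the reservoir weights:
//
//   std::unordered_set<idx_t> seen_indices;
//   for (idx_t i = 0; i < sel_size; i++) {
//       if (!seen_indices.insert(sel.get_index(i)).second) {
//           throw InternalException("Duplicate index in sample selection vector");
//       }
//   }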
- - if (reservoir_chunk) { - reservoir_chunk->chunk.Verify(); - } -#endif -} - -ReservoirSamplePercentage::ReservoirSamplePercentage(double percentage, int64_t seed, idx_t reservoir_sample_size) - : BlockingSample(seed), allocator(Allocator::DefaultAllocator()), sample_percentage(percentage / 100.0), - reservoir_sample_size(reservoir_sample_size), current_count(0), is_finalized(false) { - current_sample = make_uniq(allocator, reservoir_sample_size, base_reservoir_sample->random()); - type = SampleType::RESERVOIR_PERCENTAGE_SAMPLE; -} - -ReservoirSamplePercentage::ReservoirSamplePercentage(Allocator &allocator, double percentage, int64_t seed) - : BlockingSample(seed), allocator(allocator), sample_percentage(percentage / 100.0), current_count(0), - is_finalized(false) { - reservoir_sample_size = (idx_t)(sample_percentage * RESERVOIR_THRESHOLD); - current_sample = make_uniq(allocator, reservoir_sample_size, base_reservoir_sample->random()); - type = SampleType::RESERVOIR_PERCENTAGE_SAMPLE; -} - -ReservoirSamplePercentage::ReservoirSamplePercentage(double percentage, int64_t seed) - : ReservoirSamplePercentage(Allocator::DefaultAllocator(), percentage, seed) { -} - -void ReservoirSamplePercentage::AddToReservoir(DataChunk &input) { - base_reservoir_sample->num_entries_seen_total += input.size(); - if (current_count + input.size() > RESERVOIR_THRESHOLD) { - // we don't have enough space in our current reservoir - // first check what we still need to append to the current sample - idx_t append_to_current_sample_count = RESERVOIR_THRESHOLD - current_count; - idx_t append_to_next_sample = input.size() - append_to_current_sample_count; - if (append_to_current_sample_count > 0) { - // we have elements remaining, first add them to the current sample - if (append_to_next_sample > 0) { - // we need to also add to the next sample - DataChunk new_chunk; - new_chunk.InitializeEmpty(input.GetTypes()); - new_chunk.Slice(input, *FlatVector::IncrementalSelectionVector(), append_to_current_sample_count); - new_chunk.Flatten(); - current_sample->AddToReservoir(new_chunk); - } else { - input.Flatten(); - input.SetCardinality(append_to_current_sample_count); - current_sample->AddToReservoir(input); - } - } - if (append_to_next_sample > 0) { - // slice the input for the remainder - SelectionVector sel(append_to_next_sample); - for (idx_t i = append_to_current_sample_count; i < append_to_next_sample + append_to_current_sample_count; - i++) { - sel.set_index(i - append_to_current_sample_count, i); - } - input.Slice(sel, append_to_next_sample); - } - // now our first sample is filled: append it to the set of finished samples - finished_samples.push_back(std::move(current_sample)); - - // allocate a new sample, and potentially add the remainder of the current input to that sample - current_sample = make_uniq(allocator, reservoir_sample_size, base_reservoir_sample->random()); - if (append_to_next_sample > 0) { - current_sample->AddToReservoir(input); - } - current_count = append_to_next_sample; - } else { - // we can just append to the current sample - current_count += input.size(); - current_sample->AddToReservoir(input); - } -} - -unique_ptr ReservoirSamplePercentage::GetChunk() { - // reservoir sample percentage should never stay - if (!is_finalized) { - Finalize(); - } - while (!finished_samples.empty()) { - auto &front = finished_samples.front(); - auto chunk = front->GetChunk(); - if (chunk && chunk->size() > 0) { - return chunk; - } - // move to the next sample - 
finished_samples.erase(finished_samples.begin()); - } - return nullptr; -} - -unique_ptr<BlockingSample> ReservoirSamplePercentage::Copy() const { - throw InternalException("Cannot call Copy on ReservoirSamplePercentage"); -} - -void ReservoirSamplePercentage::Finalize() { - // need to finalize the current sample, if any - // we are finalizing, so we are starting to return chunks. Our last chunk has - // sample_percentage * RESERVOIR_THRESHOLD entries that hold samples. - // if our current count is less than the sample_percentage * RESERVOIR_THRESHOLD - // then we have sampled too much for the current_sample and we need to redo the sample - // otherwise we can just push the current sample back - // Imagine sampling 70% of 100 rows (so 70 rows): we allocate sample_percentage * RESERVOIR_THRESHOLD - // entries up front, so the last (partial) sample generally has to be shrunk to - // sample_percentage * current_count entries before it can be returned. - auto sampled_more_than_required = - static_cast<double>(current_count) > sample_percentage * RESERVOIR_THRESHOLD || finished_samples.empty(); - if (current_count > 0 && sampled_more_than_required) { - // create a new sample - auto new_sample_size = static_cast<idx_t>(round(sample_percentage * static_cast<double>(current_count))); - auto new_sample = make_uniq<ReservoirSample>(allocator, new_sample_size, base_reservoir_sample->random()); - while (true) { - auto chunk = current_sample->GetChunk(); - if (!chunk || chunk->size() == 0) { - break; - } - new_sample->AddToReservoir(*chunk); - } - finished_samples.push_back(std::move(new_sample)); - } else { - finished_samples.push_back(std::move(current_sample)); - } - // when finalizing, current_sample is null. All samples are now in finished samples. - current_sample = nullptr; - is_finalized = true; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/count.cpp b/src/duckdb/src/function/aggregate/distributive/count.cpp deleted file mode 100644 index 30824b7a2..000000000 --- a/src/duckdb/src/function/aggregate/distributive/count.cpp +++ /dev/null @@ -1,255 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/aggregate/distributive_functions.hpp" -#include "duckdb/function/aggregate/distributive_function_utils.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" - -namespace duckdb { - -struct BaseCountFunction { - template <class STATE> - static void Initialize(STATE &state) { - state = 0; - } - - template <class STATE, class OP> - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - target += source; - } - - template <class T, class STATE> - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - target = state; - } -}; - -struct CountStarFunction : public BaseCountFunction { - template <class STATE, class OP> - static void Operation(STATE &state, AggregateInputData &, idx_t idx) { - state += 1; - } - - template <class STATE, class OP> - static void ConstantOperation(STATE &state, AggregateInputData &, idx_t count) { - state += count; - } - - template <typename RESULT_TYPE> - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, const_data_ptr_t, - data_ptr_t l_state, const SubFrames &frames, Vector &result, idx_t rid) { - D_ASSERT(partition.column_ids.empty()); - - auto data = FlatVector::GetData<RESULT_TYPE>(result); - RESULT_TYPE total = 0; - for (const auto &frame : frames) { - const auto begin = frame.start; - const auto end = frame.end; - - // Slice to any filtered rows - if (partition.filter_mask.AllValid()) { - total += end - begin; - continue; - } - for (auto i = begin; i < end; ++i) { - total += partition.filter_mask.RowIsValid(i); - } - } - 
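// at this point, total holds the number of rows across all subframes that pass the filter mask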
data[rid] = total; - } -}; - -struct CountFunction : public BaseCountFunction { - using STATE = int64_t; - - static void Operation(STATE &state) { - state += 1; - } - - static void ConstantOperation(STATE &state, idx_t count) { - state += UnsafeNumericCast(count); - } - - static bool IgnoreNull() { - return true; - } - - static inline void CountFlatLoop(STATE **__restrict states, ValidityMask &mask, idx_t count) { - if (!mask.AllValid()) { - idx_t base_idx = 0; - auto entry_count = ValidityMask::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - auto validity_entry = mask.GetValidityEntry(entry_idx); - idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - // all valid: perform operation - for (; base_idx < next; base_idx++) { - CountFunction::Operation(*states[base_idx]); - } - } else if (ValidityMask::NoneValid(validity_entry)) { - // nothing valid: skip all - base_idx = next; - continue; - } else { - // partially valid: need to check individual elements for validity - idx_t start = base_idx; - for (; base_idx < next; base_idx++) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - CountFunction::Operation(*states[base_idx]); - } - } - } - } - } else { - for (idx_t i = 0; i < count; i++) { - CountFunction::Operation(*states[i]); - } - } - } - - static inline void CountScatterLoop(STATE **__restrict states, const SelectionVector &isel, - const SelectionVector &ssel, ValidityMask &mask, idx_t count) { - if (!mask.AllValid()) { - // potential NULL values - for (idx_t i = 0; i < count; i++) { - auto idx = isel.get_index(i); - auto sidx = ssel.get_index(i); - if (mask.RowIsValid(idx)) { - CountFunction::Operation(*states[sidx]); - } - } - } else { - // quick path: no NULL values - for (idx_t i = 0; i < count; i++) { - auto sidx = ssel.get_index(i); - CountFunction::Operation(*states[sidx]); - } - } - } - - static void CountScatter(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, - idx_t count) { - auto &input = inputs[0]; - if (input.GetVectorType() == VectorType::FLAT_VECTOR && states.GetVectorType() == VectorType::FLAT_VECTOR) { - auto sdata = FlatVector::GetData(states); - CountFlatLoop(sdata, FlatVector::Validity(input), count); - } else { - UnifiedVectorFormat idata, sdata; - input.ToUnifiedFormat(count, idata); - states.ToUnifiedFormat(count, sdata); - CountScatterLoop(reinterpret_cast(sdata.data), *idata.sel, *sdata.sel, idata.validity, count); - } - } - - static inline void CountFlatUpdateLoop(STATE &result, ValidityMask &mask, idx_t count) { - idx_t base_idx = 0; - auto entry_count = ValidityMask::EntryCount(count); - for (idx_t entry_idx = 0; entry_idx < entry_count; entry_idx++) { - auto validity_entry = mask.GetValidityEntry(entry_idx); - idx_t next = MinValue(base_idx + ValidityMask::BITS_PER_VALUE, count); - if (ValidityMask::AllValid(validity_entry)) { - // all valid - result += UnsafeNumericCast(next - base_idx); - base_idx = next; - } else if (ValidityMask::NoneValid(validity_entry)) { - // nothing valid: skip all - base_idx = next; - continue; - } else { - // partially valid: need to check individual elements for validity - idx_t start = base_idx; - for (; base_idx < next; base_idx++) { - if (ValidityMask::RowIsValid(validity_entry, base_idx - start)) { - result++; - } - } - } - } - } - - static inline void CountUpdateLoop(STATE &result, ValidityMask &mask, idx_t count, - const SelectionVector &sel_vector) { - if 
(mask.AllValid()) { - // no NULL values - result += UnsafeNumericCast(count); - return; - } - for (idx_t i = 0; i < count; i++) { - auto idx = sel_vector.get_index(i); - if (mask.RowIsValid(idx)) { - result++; - } - } - } - - static void CountUpdate(Vector inputs[], AggregateInputData &, idx_t input_count, data_ptr_t state_p, idx_t count) { - auto &input = inputs[0]; - auto &result = *reinterpret_cast(state_p); - switch (input.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - if (!ConstantVector::IsNull(input)) { - // if the constant is not null increment the state - result += UnsafeNumericCast(count); - } - break; - } - case VectorType::FLAT_VECTOR: { - CountFlatUpdateLoop(result, FlatVector::Validity(input), count); - break; - } - case VectorType::SEQUENCE_VECTOR: { - // sequence vectors cannot have NULL values - result += UnsafeNumericCast(count); - break; - } - default: { - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - CountUpdateLoop(result, idata.validity, count, *idata.sel); - break; - } - } - } -}; - -AggregateFunction CountFunctionBase::GetFunction() { - AggregateFunction fun({LogicalType(LogicalTypeId::ANY)}, LogicalType::BIGINT, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, CountFunction::CountScatter, - AggregateFunction::StateCombine, - AggregateFunction::StateFinalize, - FunctionNullHandling::SPECIAL_HANDLING, CountFunction::CountUpdate); - fun.name = "count"; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - return fun; -} - -AggregateFunction CountStarFun::GetFunction() { - auto fun = AggregateFunction::NullaryAggregate(LogicalType::BIGINT); - fun.name = "count_star"; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - fun.window = CountStarFunction::Window; - return fun; -} - -unique_ptr CountPropagateStats(ClientContext &context, BoundAggregateExpression &expr, - AggregateStatisticsInput &input) { - if (!expr.IsDistinct() && !input.child_stats[0].CanHaveNull()) { - // count on a column without null values: use count star - expr.function = CountStarFun::GetFunction(); - expr.function.name = "count_star"; - expr.children.clear(); - } - return nullptr; -} - -AggregateFunctionSet CountFun::GetFunctions() { - AggregateFunction count_function = CountFunctionBase::GetFunction(); - count_function.statistics = CountPropagateStats; - AggregateFunctionSet count("count"); - count.AddFunction(count_function); - // the count function can also be called without arguments - count.AddFunction(CountStarFun::GetFunction()); - return count; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp b/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp deleted file mode 100644 index 4e7455efa..000000000 --- a/src/duckdb/src/function/aggregate/distributive/first_last_any.cpp +++ /dev/null @@ -1,378 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/aggregate/distributive_functions.hpp" -#include "duckdb/function/aggregate/distributive_function_utils.hpp" -#include "duckdb/function/create_sort_key.hpp" -#include "duckdb/planner/expression.hpp" - -namespace duckdb { - -template -struct FirstState { - T value; - bool is_set; - bool is_null; -}; - -struct FirstFunctionBase { - template - static void Initialize(STATE &state) { - state.is_set = false; - state.is_null = false; - } - - static bool 
IgnoreNull() { - return false; - } -}; - -template -struct FirstFunction : public FirstFunctionBase { - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (LAST || !state.is_set) { - if (!unary_input.RowIsValid()) { - if (!SKIP_NULLS) { - state.is_set = true; - } - state.is_null = true; - } else { - state.is_set = true; - state.is_null = false; - state.value = input; - } - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - Operation(state, input, unary_input); - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!target.is_set) { - target = source; - } - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -template -struct FirstFunctionStringBase : public FirstFunctionBase { - template - static void SetValue(STATE &state, AggregateInputData &input_data, string_t value, bool is_null) { - if (LAST && state.is_set) { - Destroy(state, input_data); - } - if (is_null) { - if (!SKIP_NULLS) { - state.is_set = true; - state.is_null = true; - } - } else { - state.is_set = true; - state.is_null = false; - if ((COMBINE && !LAST) || value.IsInlined()) { - // We use the aggregate allocator for 'first', so the allocation is already done when combining - // Of course, if the value is inlined, we also don't need to allocate - state.value = value; - } else { - // non-inlined string, need to allocate space for it - auto len = value.GetSize(); - auto ptr = LAST ? new char[len] : char_ptr_cast(input_data.allocator.Allocate(len)); - memcpy(ptr, value.GetData(), len); - - state.value = string_t(ptr, UnsafeNumericCast(len)); - } - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (source.is_set && (LAST || !target.is_set)) { - SetValue(target, input_data, source.value, source.is_null); - } - } - - template - static void Destroy(STATE &state, AggregateInputData &) { - if (state.is_set && !state.is_null && !state.value.IsInlined()) { - delete[] state.value.GetData(); - } - } -}; - -template -struct FirstFunctionString : FirstFunctionStringBase { - template - static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (LAST || !state.is_set) { - FirstFunctionStringBase::template SetValue(state, unary_input.input, input, - !unary_input.RowIsValid()); - } - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - Operation(state, input, unary_input); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { - finalize_data.ReturnNull(); - } else { - target = StringVector::AddStringOrBlob(finalize_data.result, state.value); - } - } -}; - -template -struct FirstVectorFunction : FirstFunctionStringBase { - using STATE = FirstState; - - static void Update(Vector inputs[], AggregateInputData &input_data, idx_t, Vector &state_vector, idx_t count) { - auto &input = inputs[0]; - UnifiedVectorFormat idata; - input.ToUnifiedFormat(count, idata); - - UnifiedVectorFormat sdata; - state_vector.ToUnifiedFormat(count, sdata); - - sel_t assign_sel[STANDARD_VECTOR_SIZE]; - idx_t assign_count = 
0; - - auto states = UnifiedVectorFormat::GetData(sdata); - for (idx_t i = 0; i < count; i++) { - const auto idx = idata.sel->get_index(i); - bool is_null = !idata.validity.RowIsValid(idx); - if (SKIP_NULLS && is_null) { - continue; - } - auto &state = *states[sdata.sel->get_index(i)]; - if (!LAST && state.is_set) { - continue; - } - assign_sel[assign_count++] = NumericCast(i); - } - if (assign_count == 0) { - // fast path - nothing to set - return; - } - - Vector sort_key(LogicalType::BLOB); - OrderModifiers modifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST); - // slice with a selection vector and generate sort keys - if (assign_count == count) { - CreateSortKeyHelpers::CreateSortKey(input, count, modifiers, sort_key); - } else { - SelectionVector sel(assign_sel); - Vector sliced_input(input, sel, assign_count); - CreateSortKeyHelpers::CreateSortKey(sliced_input, assign_count, modifiers, sort_key); - } - auto sort_key_data = FlatVector::GetData(sort_key); - - // now assign sort keys - for (idx_t i = 0; i < assign_count; i++) { - const auto state_idx = sdata.sel->get_index(assign_sel[i]); - auto &state = *states[state_idx]; - if (!LAST && state.is_set) { - continue; - } - - const auto idx = idata.sel->get_index(assign_sel[i]); - bool is_null = !idata.validity.RowIsValid(idx); - FirstFunctionStringBase::template SetValue(state, input_data, sort_key_data[i], - is_null); - } - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.is_set || state.is_null) { - finalize_data.ReturnNull(); - } else { - CreateSortKeyHelpers::DecodeSortKey(state.value, finalize_data.result, finalize_data.result_idx, - OrderModifiers(OrderType::ASCENDING, OrderByNullType::NULLS_LAST)); - } - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; - } -}; - -template -static void FirstFunctionSimpleUpdate(Vector inputs[], AggregateInputData &aggregate_input_data, idx_t input_count, - data_ptr_t state, idx_t count) { - auto agg_state = reinterpret_cast *>(state); - if (LAST || !agg_state->is_set) { - // For FIRST, this skips looping over the input once the aggregate state has been set - // FIXME: for LAST we could loop from the back of the Vector instead - AggregateFunction::UnaryUpdate, T, FirstFunction>(inputs, aggregate_input_data, - input_count, state, count); - } -} - -template -static AggregateFunction GetFirstAggregateTemplated(LogicalType type) { - auto result = AggregateFunction::UnaryAggregate, T, T, FirstFunction>(type, type); - result.simple_update = FirstFunctionSimpleUpdate; - return result; -} - -template -static AggregateFunction GetFirstFunction(const LogicalType &type); - -template -AggregateFunction GetDecimalFirstFunction(const LogicalType &type) { - D_ASSERT(type.id() == LogicalTypeId::DECIMAL); - switch (type.InternalType()) { - case PhysicalType::INT16: - return GetFirstFunction(LogicalType::SMALLINT); - case PhysicalType::INT32: - return GetFirstFunction(LogicalType::INTEGER); - case PhysicalType::INT64: - return GetFirstFunction(LogicalType::BIGINT); - default: - return GetFirstFunction(LogicalType::HUGEINT); - } -} -template -static AggregateFunction GetFirstFunction(const LogicalType &type) { - if (type.id() == LogicalTypeId::DECIMAL) { - type.Verify(); - AggregateFunction function = GetDecimalFirstFunction(type); - function.arguments[0] = type; - 
function.return_type = type; - return function; - } - switch (type.InternalType()) { - case PhysicalType::BOOL: - case PhysicalType::INT8: - return GetFirstAggregateTemplated(type); - case PhysicalType::INT16: - return GetFirstAggregateTemplated(type); - case PhysicalType::INT32: - return GetFirstAggregateTemplated(type); - case PhysicalType::INT64: - return GetFirstAggregateTemplated(type); - case PhysicalType::UINT8: - return GetFirstAggregateTemplated(type); - case PhysicalType::UINT16: - return GetFirstAggregateTemplated(type); - case PhysicalType::UINT32: - return GetFirstAggregateTemplated(type); - case PhysicalType::UINT64: - return GetFirstAggregateTemplated(type); - case PhysicalType::INT128: - return GetFirstAggregateTemplated(type); - case PhysicalType::UINT128: - return GetFirstAggregateTemplated(type); - case PhysicalType::FLOAT: - return GetFirstAggregateTemplated(type); - case PhysicalType::DOUBLE: - return GetFirstAggregateTemplated(type); - case PhysicalType::INTERVAL: - return GetFirstAggregateTemplated(type); - case PhysicalType::VARCHAR: - if (LAST) { - return AggregateFunction::UnaryAggregateDestructor, string_t, string_t, - FirstFunctionString>(type, type); - } else { - return AggregateFunction::UnaryAggregate, string_t, string_t, - FirstFunctionString>(type, type); - } - default: { - using OP = FirstVectorFunction; - using STATE = FirstState; - return AggregateFunction( - {type}, type, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - OP::Update, AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, - nullptr, OP::Bind, LAST ? AggregateFunction::StateDestroy : nullptr, nullptr, nullptr); - } - } -} - -AggregateFunction FirstFunctionGetter::GetFunction(const LogicalType &type) { - auto fun = GetFirstFunction(type); - fun.name = "first"; - return fun; -} - -template -unique_ptr BindDecimalFirst(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto decimal_type = arguments[0]->return_type; - auto name = std::move(function.name); - function = GetFirstFunction(decimal_type); - function.name = std::move(name); - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - function.return_type = decimal_type; - return nullptr; -} - -template -static AggregateFunction GetFirstOperator(const LogicalType &type) { - if (type.id() == LogicalTypeId::DECIMAL) { - throw InternalException("FIXME: this shouldn't happen..."); - } - return GetFirstFunction(type); -} - -template -unique_ptr BindFirst(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - auto input_type = arguments[0]->return_type; - auto name = std::move(function.name); - function = GetFirstOperator(input_type); - function.name = std::move(name); - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - if (function.bind) { - return function.bind(context, function, arguments); - } else { - return nullptr; - } -} - -template -static void AddFirstOperator(AggregateFunctionSet &set) { - set.AddFunction(AggregateFunction({LogicalTypeId::DECIMAL}, LogicalTypeId::DECIMAL, nullptr, nullptr, nullptr, - nullptr, nullptr, nullptr, BindDecimalFirst)); - set.AddFunction(AggregateFunction({LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, nullptr, - nullptr, BindFirst)); -} - -AggregateFunctionSet FirstFun::GetFunctions() { - AggregateFunctionSet first("first"); - AddFirstOperator(first); - return first; -} - -AggregateFunctionSet LastFun::GetFunctions() { - 
AggregateFunctionSet last("last"); - AddFirstOperator(last); - return last; -} - -AggregateFunctionSet AnyValueFun::GetFunctions() { - AggregateFunctionSet any_value("any_value"); - AddFirstOperator(any_value); - return any_value; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/distributive/minmax.cpp b/src/duckdb/src/function/aggregate/distributive/minmax.cpp deleted file mode 100644 index b862bf6d9..000000000 --- a/src/duckdb/src/function/aggregate/distributive/minmax.cpp +++ /dev/null @@ -1,556 +0,0 @@ -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/aggregate/distributive_functions.hpp" -#include "duckdb/function/aggregate/distributive_function_utils.hpp" -#include "duckdb/function/aggregate/minmax_n_helpers.hpp" -#include "duckdb/function/aggregate/sort_key_helpers.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/planner/expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" - -namespace duckdb { - -template -struct MinMaxState { - T value; - bool isset; -}; - -template -static AggregateFunction GetUnaryAggregate(LogicalType type) { - switch (type.InternalType()) { - case PhysicalType::BOOL: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); - case PhysicalType::INT8: - return AggregateFunction::UnaryAggregate, int8_t, int8_t, OP>(type, type); - case PhysicalType::INT16: - return AggregateFunction::UnaryAggregate, int16_t, int16_t, OP>(type, type); - case PhysicalType::INT32: - return AggregateFunction::UnaryAggregate, int32_t, int32_t, OP>(type, type); - case PhysicalType::INT64: - return AggregateFunction::UnaryAggregate, int64_t, int64_t, OP>(type, type); - case PhysicalType::UINT8: - return AggregateFunction::UnaryAggregate, uint8_t, uint8_t, OP>(type, type); - case PhysicalType::UINT16: - return AggregateFunction::UnaryAggregate, uint16_t, uint16_t, OP>(type, type); - case PhysicalType::UINT32: - return AggregateFunction::UnaryAggregate, uint32_t, uint32_t, OP>(type, type); - case PhysicalType::UINT64: - return AggregateFunction::UnaryAggregate, uint64_t, uint64_t, OP>(type, type); - case PhysicalType::INT128: - return AggregateFunction::UnaryAggregate, hugeint_t, hugeint_t, OP>(type, type); - case PhysicalType::UINT128: - return AggregateFunction::UnaryAggregate, uhugeint_t, uhugeint_t, OP>(type, type); - case PhysicalType::FLOAT: - return AggregateFunction::UnaryAggregate, float, float, OP>(type, type); - case PhysicalType::DOUBLE: - return AggregateFunction::UnaryAggregate, double, double, OP>(type, type); - case PhysicalType::INTERVAL: - return AggregateFunction::UnaryAggregate, interval_t, interval_t, OP>(type, type); - default: - throw InternalException("Unimplemented type for min/max aggregate"); - } -} - -struct MinMaxBase { - template - static void Initialize(STATE &state) { - state.isset = false; - } - - template - static void ConstantOperation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input, - idx_t count) { - if (!state.isset) { - OP::template Assign(state, input, unary_input.input); - state.isset = true; - } else { - OP::template Execute(state, input, unary_input.input); - } - } - - template - 
static void Operation(STATE &state, const INPUT_TYPE &input, AggregateUnaryInput &unary_input) { - if (!state.isset) { - OP::template Assign(state, input, unary_input.input); - state.isset = true; - } else { - OP::template Execute(state, input, unary_input.input); - } - } - - static bool IgnoreNull() { - return true; - } -}; - -struct NumericMinMaxBase : public MinMaxBase { - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &) { - state.value = input; - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = state.value; - } - } -}; - -struct MinOperation : public NumericMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (LessThan::Operation(input, state.value)) { - state.value = input; - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - target = source; - } else if (GreaterThan::Operation(target.value, source.value)) { - target.value = source.value; - } - } -}; - -struct MaxOperation : public NumericMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &) { - if (GreaterThan::Operation(input, state.value)) { - state.value = input; - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - target = source; - } else if (LessThan::Operation(target.value, source.value)) { - target.value = source.value; - } - } -}; - -struct MinMaxStringState : MinMaxState { - void Destroy() { - if (isset && !value.IsInlined()) { - delete[] value.GetData(); - } - } - - void Assign(string_t input) { - if (input.IsInlined()) { - // inlined string - we can directly store it into the string_t without having to allocate anything - Destroy(); - value = input; - } else { - // non-inlined string, need to allocate space for it somehow - auto len = input.GetSize(); - char *ptr; - if (!isset || value.GetSize() < len) { - // we cannot fit this into the current slot - destroy it and re-allocate - Destroy(); - ptr = new char[len]; - } else { - // this fits into the current slot - take over the pointer - ptr = value.GetDataWriteable(); - } - memcpy(ptr, input.GetData(), len); - - value = string_t(ptr, UnsafeNumericCast(len)); - } - } -}; - -struct StringMinMaxBase : public MinMaxBase { - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.Destroy(); - } - - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - state.Assign(input); - } - - template - static void Finalize(STATE &state, T &target, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - target = StringVector::AddStringOrBlob(finalize_data.result, state.value); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - if (!target.isset) { - // target is NULL, use source value directly - Assign(target, source.value, input_data); - target.isset = true; - } else { - OP::template 
Execute(target, source.value, input_data); - } - } -}; - -struct MinOperationString : public StringMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (LessThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } -}; - -struct MaxOperationString : public StringMinMaxBase { - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (GreaterThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } -}; - -template -struct VectorMinMaxBase { - static constexpr OrderType ORDER_TYPE = ORDER_TYPE_TEMPLATED; - - static bool IgnoreNull() { - return true; - } - - template - static void Initialize(STATE &state) { - state.isset = false; - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.Destroy(); - } - - template - static void Assign(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - state.Assign(input); - } - - template - static void Execute(STATE &state, INPUT_TYPE input, AggregateInputData &input_data) { - if (!state.isset) { - Assign(state, input, input_data); - state.isset = true; - return; - } - if (LessThan::Operation(input, state.value)) { - Assign(state, input, input_data); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &input_data) { - if (!source.isset) { - // source is NULL, nothing to do - return; - } - OP::template Execute(target, source.value, input_data); - } - - template - static void Finalize(STATE &state, AggregateFinalizeData &finalize_data) { - if (!state.isset) { - finalize_data.ReturnNull(); - } else { - CreateSortKeyHelpers::DecodeSortKey(state.value, finalize_data.result, finalize_data.result_idx, - OrderModifiers(ORDER_TYPE, OrderByNullType::NULLS_LAST)); - } - } - - static unique_ptr Bind(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; - } -}; - -struct MinOperationVector : VectorMinMaxBase {}; - -struct MaxOperationVector : VectorMinMaxBase {}; - -template -static AggregateFunction GetMinMaxFunction(const LogicalType &type) { - return AggregateFunction( - {type}, LogicalType::BLOB, AggregateFunction::StateSize, AggregateFunction::StateInitialize, - AggregateSortKeyHelpers::UnaryUpdate, - AggregateFunction::StateCombine, AggregateFunction::StateVoidFinalize, nullptr, OP::Bind, - AggregateFunction::StateDestroy); -} - -template -static AggregateFunction GetMinMaxOperator(const LogicalType &type) { - auto internal_type = type.InternalType(); - switch (internal_type) { - case PhysicalType::VARCHAR: - return AggregateFunction::UnaryAggregateDestructor(type, - type); - case PhysicalType::LIST: - case PhysicalType::STRUCT: - case PhysicalType::ARRAY: - return GetMinMaxFunction(type); - default: - return GetUnaryAggregate(type); - } -} - -template -unique_ptr BindMinMax(ClientContext &context, AggregateFunction &function, - vector> &arguments) { - if (arguments[0]->return_type.id() == LogicalTypeId::VARCHAR) { - auto str_collation = StringType::GetCollation(arguments[0]->return_type); - if (!str_collation.empty() || !DBConfig::GetConfig(context).options.collation.empty()) { - // If aggr function is min/max and uses collations, replace bound_function with arg_min/arg_max - // to make sure the result's correctness. 
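// For example, under a NOCASE collation min(col) is rebound below to arg_min(col, col_collated),
// where col_collated is a copy of col with the collation expression pushed on top: the comparison
// then runs on the collated key while the original value is returned. (col_collated is an
// illustrative name; the actual copy is created via PushCollation below.)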
- string function_name = function.name == "min" ? "arg_min" : "arg_max"; - QueryErrorContext error_context; - auto func = Catalog::GetEntry(context, CatalogType::AGGREGATE_FUNCTION_ENTRY, "", "", function_name, - OnEntryNotFound::RETURN_NULL, error_context); - if (!func) { - throw NotImplementedException( - "Failure while binding function \"%s\" using collations - arg_min/arg_max do not exist in the " - "catalog - load the core_functions module to fix this issue", - function.name); - } - - auto &func_entry = func->Cast<AggregateFunctionCatalogEntry>(); - - FunctionBinder function_binder(context); - vector<LogicalType> types {arguments[0]->return_type, arguments[0]->return_type}; - ErrorData error; - auto best_function = function_binder.BindFunction(func_entry.name, func_entry.functions, types, error); - if (!best_function.IsValid()) { - throw BinderException(string("Failed to find a corresponding function for collation min/max: ") + - error.Message()); - } - function = func_entry.functions.GetFunctionByOffset(best_function.GetIndex()); - - // Create a copied child and PushCollation for it. - arguments.push_back(arguments[0]->Copy()); - ExpressionBinder::PushCollation(context, arguments[1], arguments[0]->return_type); - - // Bind the function like arg_min/arg_max. - function.arguments[0] = arguments[0]->return_type; - function.return_type = arguments[0]->return_type; - return nullptr; - } - } - - auto input_type = arguments[0]->return_type; - if (input_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - auto name = std::move(function.name); - function = GetMinMaxOperator<OP, OP_STRING, OP_VECTOR>(input_type); - function.name = std::move(name); - function.order_dependent = AggregateOrderDependent::NOT_ORDER_DEPENDENT; - function.distinct_dependent = AggregateDistinctDependent::NOT_DISTINCT_DEPENDENT; - if (function.bind) { - return function.bind(context, function, arguments); - } else { - return nullptr; - } -} - -template <class OP, class OP_STRING, class OP_VECTOR> -static AggregateFunction GetMinMaxOperator(string name) { - return AggregateFunction(std::move(name), {LogicalType::ANY}, LogicalType::ANY, nullptr, nullptr, nullptr, nullptr, - nullptr, nullptr, BindMinMax<OP, OP_STRING, OP_VECTOR>); -} - -AggregateFunction MinFunction::GetFunction() { - return GetMinMaxOperator<MinOperation, MinOperationString, MinOperationVector>("min"); -} - -AggregateFunction MaxFunction::GetFunction() { - return GetMinMaxOperator<MaxOperation, MaxOperationString, MaxOperationVector>("max"); -} - -//--------------------------------------------------- -// MinMaxN -//--------------------------------------------------- - -template <class A, class COMPARATOR> -class MinMaxNState { -public: - using VAL_TYPE = A; - using T = typename VAL_TYPE::TYPE; - - UnaryAggregateHeap<T, COMPARATOR> heap; - bool is_initialized = false; - - void Initialize(idx_t nval) { - heap.Initialize(nval); - is_initialized = true; - } - - static const T &GetValue(const T &val) { - return val; - } -}; - -template <class STATE> -static void MinMaxNUpdate(Vector inputs[], AggregateInputData &aggr_input, idx_t input_count, Vector &state_vector, - idx_t count) { - - auto &val_vector = inputs[0]; - auto &n_vector = inputs[1]; - - UnifiedVectorFormat val_format; - UnifiedVectorFormat n_format; - UnifiedVectorFormat state_format; - auto val_extra_state = STATE::VAL_TYPE::CreateExtraState(val_vector, count); - - STATE::VAL_TYPE::PrepareData(val_vector, count, val_extra_state, val_format); - - n_vector.ToUnifiedFormat(count, n_format); - state_vector.ToUnifiedFormat(count, state_format); - - auto states = UnifiedVectorFormat::GetData<STATE *>(state_format); - - for (idx_t i = 0; i < count; i++) { - const auto val_idx = val_format.sel->get_index(i); - if (!val_format.validity.RowIsValid(val_idx)) { - continue; - } - const auto 
state_idx = state_format.sel->get_index(i); - auto &state = *states[state_idx]; - - // Initialize the heap if necessary and add the input to the heap - if (!state.is_initialized) { - static constexpr int64_t MAX_N = 1000000; - const auto nidx = n_format.sel->get_index(i); - if (!n_format.validity.RowIsValid(nidx)) { - throw InvalidInputException("Invalid input for MIN/MAX: n value cannot be NULL"); - } - const auto nval = UnifiedVectorFormat::GetData<int64_t>(n_format)[nidx]; - if (nval <= 0) { - throw InvalidInputException("Invalid input for MIN/MAX: n value must be > 0"); - } - if (nval >= MAX_N) { - throw InvalidInputException("Invalid input for MIN/MAX: n value must be < %d", MAX_N); - } - state.Initialize(UnsafeNumericCast<idx_t>(nval)); - } - - // Now add the input to the heap - auto val_val = STATE::VAL_TYPE::Create(val_format, val_idx); - state.heap.Insert(aggr_input.allocator, val_val); - } -} - -template <class VAL_TYPE, class COMPARATOR> -static void SpecializeMinMaxNFunction(AggregateFunction &function) { - using STATE = MinMaxNState<VAL_TYPE, COMPARATOR>; - using OP = MinMaxNOperation; - - function.state_size = AggregateFunction::StateSize<STATE>; - function.initialize = AggregateFunction::StateInitialize<STATE, OP>; - function.combine = AggregateFunction::StateCombine<STATE, OP>; - function.destructor = AggregateFunction::StateDestroy<STATE, OP>; - - function.finalize = MinMaxNOperation::Finalize<STATE>; - function.update = MinMaxNUpdate<STATE>; -} - -template <class COMPARATOR> -static void SpecializeMinMaxNFunction(PhysicalType arg_type, AggregateFunction &function) { - switch (arg_type) { - case PhysicalType::VARCHAR: - SpecializeMinMaxNFunction<MinMaxStringValue, COMPARATOR>(function); - break; - case PhysicalType::INT32: - SpecializeMinMaxNFunction<MinMaxFixedValue<int32_t>, COMPARATOR>(function); - break; - case PhysicalType::INT64: - SpecializeMinMaxNFunction<MinMaxFixedValue<int64_t>, COMPARATOR>(function); - break; - case PhysicalType::FLOAT: - SpecializeMinMaxNFunction<MinMaxFixedValue<float>, COMPARATOR>(function); - break; - case PhysicalType::DOUBLE: - SpecializeMinMaxNFunction<MinMaxFixedValue<double>, COMPARATOR>(function); - break; - default: - SpecializeMinMaxNFunction<MinMaxFallbackValue, COMPARATOR>(function); - break; - } -} - -template <class COMPARATOR> -unique_ptr<FunctionData> MinMaxNBind(ClientContext &context, AggregateFunction &function, - vector<unique_ptr<Expression>> &arguments) { - - for (auto &arg : arguments) { - if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - } - - const auto val_type = arguments[0]->return_type.InternalType(); - - // Specialize the function based on the input types - SpecializeMinMaxNFunction<COMPARATOR>(val_type, function); - - function.return_type = LogicalType::LIST(arguments[0]->return_type); - return nullptr; -} - -template <class COMPARATOR> -static AggregateFunction GetMinMaxNFunction() { - return AggregateFunction({LogicalTypeId::ANY, LogicalType::BIGINT}, LogicalType::LIST(LogicalType::ANY), nullptr, - nullptr, nullptr, nullptr, nullptr, nullptr, MinMaxNBind<COMPARATOR>, nullptr); -} - -//--------------------------------------------------- -// Function Registration -//--------------------------------------------------- -AggregateFunctionSet MinFun::GetFunctions() { - AggregateFunctionSet min("min"); - min.AddFunction(MinFunction::GetFunction()); - min.AddFunction(GetMinMaxNFunction<LessThan>()); - return min; -} - -AggregateFunctionSet MaxFun::GetFunctions() { - AggregateFunctionSet max("max"); - max.AddFunction(MaxFunction::GetFunction()); - max.AddFunction(GetMinMaxNFunction<GreaterThan>()); - return max; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp b/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp deleted file mode 100644 index aa49db798..000000000 --- 
a/src/duckdb/src/function/aggregate/sorted_aggregate_function.cpp +++ /dev/null @@ -1,818 +0,0 @@ -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/types/list_segment.hpp" -#include "duckdb/function/aggregate_function.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" -#include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/parser/expression_map.hpp" -#include "duckdb/function/aggregate/distributive_functions.hpp" - -namespace duckdb { - -struct SortedAggregateBindData : public FunctionData { - using Expressions = vector>; - using BindInfoPtr = unique_ptr; - using OrderBys = vector; - - SortedAggregateBindData(ClientContext &context, Expressions &children, AggregateFunction &aggregate, - BindInfoPtr &bind_info, OrderBys &order_bys) - : buffer_manager(BufferManager::GetBufferManager(context)), function(aggregate), - bind_info(std::move(bind_info)), threshold(ClientConfig::GetConfig(context).ordered_aggregate_threshold), - external(ClientConfig::GetConfig(context).force_external) { - arg_types.reserve(children.size()); - arg_funcs.reserve(children.size()); - for (const auto &child : children) { - arg_types.emplace_back(child->return_type); - ListSegmentFunctions funcs; - GetSegmentDataFunctions(funcs, arg_types.back()); - arg_funcs.emplace_back(std::move(funcs)); - } - sort_types.reserve(order_bys.size()); - sort_funcs.reserve(order_bys.size()); - for (auto &order : order_bys) { - orders.emplace_back(order.Copy()); - sort_types.emplace_back(order.expression->return_type); - ListSegmentFunctions funcs; - GetSegmentDataFunctions(funcs, sort_types.back()); - sort_funcs.emplace_back(std::move(funcs)); - } - sorted_on_args = (children.size() == order_bys.size()); - for (size_t i = 0; sorted_on_args && i < children.size(); ++i) { - sorted_on_args = children[i]->Equals(*order_bys[i].expression); - } - } - - SortedAggregateBindData(ClientContext &context, BoundAggregateExpression &expr) - : SortedAggregateBindData(context, expr.children, expr.function, expr.bind_info, expr.order_bys->orders) { - } - - SortedAggregateBindData(ClientContext &context, BoundWindowExpression &expr) - : SortedAggregateBindData(context, expr.children, *expr.aggregate, expr.bind_info, expr.arg_orders) { - } - - SortedAggregateBindData(const SortedAggregateBindData &other) - : buffer_manager(other.buffer_manager), function(other.function), arg_types(other.arg_types), - arg_funcs(other.arg_funcs), sort_types(other.sort_types), sort_funcs(other.sort_funcs), - sorted_on_args(other.sorted_on_args), threshold(other.threshold), external(other.external) { - if (other.bind_info) { - bind_info = other.bind_info->Copy(); - } - for (auto &order : other.orders) { - orders.emplace_back(order.Copy()); - } - } - - unique_ptr Copy() const override { - return make_uniq(*this); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - if (bind_info && other.bind_info) { - if (!bind_info->Equals(*other.bind_info)) { - return false; - } - } else if (bind_info || other.bind_info) { - return false; - } - if (function != other.function) { - return false; - } - if (orders.size() != other.orders.size()) { - return false; - } - for (size_t i = 0; i < orders.size(); ++i) { - if 
(!orders[i].Equals(other.orders[i])) { - return false; - } - } - return true; - } - - BufferManager &buffer_manager; - AggregateFunction function; - vector arg_types; - unique_ptr bind_info; - vector arg_funcs; - - vector orders; - vector sort_types; - vector sort_funcs; - bool sorted_on_args; - - //! The sort flush threshold - const idx_t threshold; - const bool external; -}; - -struct SortedAggregateState { - // Linked list equivalent of DataChunk - using LinkedLists = vector; - using LinkedChunkFunctions = vector; - - //! Capacities of the various levels of buffering - static const idx_t CHUNK_CAPACITY = STANDARD_VECTOR_SIZE; - static const idx_t LIST_CAPACITY = MinValue(16, CHUNK_CAPACITY); - - SortedAggregateState() : count(0), nsel(0), offset(0) { - } - - static inline void InitializeLinkedList(LinkedLists &linked, const vector &types) { - if (linked.empty() && !types.empty()) { - linked.resize(types.size(), LinkedList()); - } - } - - inline void InitializeLinkedLists(const SortedAggregateBindData &order_bind) { - InitializeLinkedList(sort_linked, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeLinkedList(arg_linked, order_bind.arg_types); - } - } - - static inline void InitializeChunk(unique_ptr &chunk, const vector &types) { - if (!chunk && !types.empty()) { - chunk = make_uniq(); - chunk->Initialize(Allocator::DefaultAllocator(), types); - } - } - - void InitializeChunks(const SortedAggregateBindData &order_bind) { - // Lazy instantiation of the buffer chunks - InitializeChunk(sort_chunk, order_bind.sort_types); - if (!order_bind.sorted_on_args) { - InitializeChunk(arg_chunk, order_bind.arg_types); - } - } - - static inline void FlushLinkedList(const LinkedChunkFunctions &funcs, LinkedLists &linked, DataChunk &chunk) { - idx_t total_count = 0; - for (column_t i = 0; i < linked.size(); ++i) { - funcs[i].BuildListVector(linked[i], chunk.data[i], total_count); - chunk.SetCardinality(linked[i].total_capacity); - } - } - - void FlushLinkedLists(const SortedAggregateBindData &order_bind) { - InitializeChunks(order_bind); - FlushLinkedList(order_bind.sort_funcs, sort_linked, *sort_chunk); - if (arg_chunk) { - FlushLinkedList(order_bind.arg_funcs, arg_linked, *arg_chunk); - } - } - - void InitializeCollections(const SortedAggregateBindData &order_bind) { - ordering = make_uniq(order_bind.buffer_manager, order_bind.sort_types); - ordering_append = make_uniq(); - ordering->InitializeAppend(*ordering_append); - - if (!order_bind.sorted_on_args) { - arguments = make_uniq(order_bind.buffer_manager, order_bind.arg_types); - arguments_append = make_uniq(); - arguments->InitializeAppend(*arguments_append); - } - } - - void FlushChunks(const SortedAggregateBindData &order_bind) { - D_ASSERT(sort_chunk); - ordering->Append(*ordering_append, *sort_chunk); - sort_chunk->Reset(); - - if (arguments) { - D_ASSERT(arg_chunk); - arguments->Append(*arguments_append, *arg_chunk); - arg_chunk->Reset(); - } - } - - void Resize(const SortedAggregateBindData &order_bind, idx_t n) { - count = n; - - // Establish the current buffering - if (count <= LIST_CAPACITY) { - InitializeLinkedLists(order_bind); - } - - if (count > LIST_CAPACITY && !sort_chunk && !ordering) { - FlushLinkedLists(order_bind); - } - - if (count > CHUNK_CAPACITY && !ordering) { - InitializeCollections(order_bind); - FlushChunks(order_bind); - } - } - - static void LinkedAppend(const LinkedChunkFunctions &functions, ArenaAllocator &allocator, DataChunk &input, - LinkedLists &linked, SelectionVector &sel, idx_t nsel) { - 
const auto count = input.size(); - for (column_t c = 0; c < input.ColumnCount(); ++c) { - auto &func = functions[c]; - auto &linked_list = linked[c]; - RecursiveUnifiedVectorFormat input_data; - Vector::RecursiveToUnifiedFormat(input.data[c], count, input_data); - for (idx_t i = 0; i < nsel; ++i) { - idx_t sidx = sel.get_index(i); - func.AppendRow(allocator, linked_list, input_data, sidx); - } - } - } - - static void LinkedAbsorb(LinkedLists &source, LinkedLists &target) { - D_ASSERT(source.size() == target.size()); - for (column_t i = 0; i < source.size(); ++i) { - auto &src = source[i]; - if (!src.total_capacity) { - break; - } - - auto &tgt = target[i]; - if (!tgt.total_capacity) { - tgt = src; - } else { - // append the linked list - tgt.last_segment->next = src.first_segment; - tgt.last_segment = src.last_segment; - tgt.total_capacity += src.total_capacity; - } - } - } - - void Update(const AggregateInputData &aggr_input_data, DataChunk &sort_input, DataChunk &arg_input) { - const auto &order_bind = aggr_input_data.bind_data->Cast(); - Resize(order_bind, count + sort_input.size()); - - sel.Initialize(nullptr); - nsel = sort_input.size(); - - if (ordering) { - // Using collections - ordering->Append(*ordering_append, sort_input); - if (arguments) { - arguments->Append(*arguments_append, arg_input); - } - } else if (sort_chunk) { - // Still using data chunks - sort_chunk->Append(sort_input); - if (arg_chunk) { - arg_chunk->Append(arg_input); - } - } else { - // Still using linked lists - LinkedAppend(order_bind.sort_funcs, aggr_input_data.allocator, sort_input, sort_linked, sel, nsel); - if (!arg_linked.empty()) { - LinkedAppend(order_bind.arg_funcs, aggr_input_data.allocator, arg_input, arg_linked, sel, nsel); - } - } - - nsel = 0; - offset = 0; - } - - void UpdateSlice(const AggregateInputData &aggr_input_data, DataChunk &sort_input, DataChunk &arg_input) { - const auto &order_bind = aggr_input_data.bind_data->Cast(); - Resize(order_bind, count + nsel); - - if (ordering) { - // Using collections - D_ASSERT(sort_chunk); - sort_chunk->Slice(sort_input, sel, nsel); - if (arg_chunk) { - arg_chunk->Slice(arg_input, sel, nsel); - } - FlushChunks(order_bind); - } else if (sort_chunk) { - // Still using data chunks - sort_chunk->Append(sort_input, true, &sel, nsel); - if (arg_chunk) { - arg_chunk->Append(arg_input, true, &sel, nsel); - } - } else { - // Still using linked lists - LinkedAppend(order_bind.sort_funcs, aggr_input_data.allocator, sort_input, sort_linked, sel, nsel); - if (!arg_linked.empty()) { - LinkedAppend(order_bind.arg_funcs, aggr_input_data.allocator, arg_input, arg_linked, sel, nsel); - } - } - - nsel = 0; - offset = 0; - } - - void Swap(SortedAggregateState &other) { - std::swap(count, other.count); - - std::swap(arguments, other.arguments); - std::swap(arguments_append, other.arguments_append); - std::swap(ordering, other.ordering); - std::swap(ordering_append, other.ordering_append); - - std::swap(sort_chunk, other.sort_chunk); - std::swap(arg_chunk, other.arg_chunk); - - std::swap(sort_linked, other.sort_linked); - std::swap(arg_linked, other.arg_linked); - } - - void Absorb(const SortedAggregateBindData &order_bind, SortedAggregateState &other) { - if (!other.count) { - return; - } else if (!count) { - Swap(other); - return; - } - - // Change to a state large enough for all the data - Resize(order_bind, count + other.count); - - // 3x3 matrix. 
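// (source and target can each be buffered as linked lists, data chunks, or collections,
// giving nine source/target combinations in total)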
- // We can simplify the logic a bit because the target is already set for the final capacity - if (!sort_chunk) { - // If the combined count is still linked lists, - // then just move the pointers. - // Note that this assumes ArenaAllocator is shared and the memory will not vanish under us. - LinkedAbsorb(other.sort_linked, sort_linked); - if (!arg_linked.empty()) { - LinkedAbsorb(other.arg_linked, arg_linked); - } - - other.Reset(); - return; - } - - if (!other.sort_chunk) { - other.FlushLinkedLists(order_bind); - } - - if (!ordering) { - // Still using chunks, which means the source is using chunks or lists - D_ASSERT(sort_chunk); - D_ASSERT(other.sort_chunk); - sort_chunk->Append(*other.sort_chunk); - if (arg_chunk) { - D_ASSERT(other.arg_chunk); - arg_chunk->Append(*other.arg_chunk); - } - } else { - // Using collections, so source could be using anything. - if (other.ordering) { - ordering->Combine(*other.ordering); - if (arguments) { - D_ASSERT(other.arguments); - arguments->Combine(*other.arguments); - } - } else { - ordering->Append(*other.sort_chunk); - if (arguments) { - D_ASSERT(other.arg_chunk); - arguments->Append(*other.arg_chunk); - } - } - } - - // Free all memory as we have absorbed it. - other.Reset(); - } - - void PrefixSortBuffer(DataChunk &prefixed) { - for (column_t col_idx = 0; col_idx < sort_chunk->ColumnCount(); ++col_idx) { - prefixed.data[col_idx + 1].Reference(sort_chunk->data[col_idx]); - } - prefixed.SetCardinality(*sort_chunk); - } - - void Finalize(const SortedAggregateBindData &order_bind, DataChunk &prefixed, LocalSortState &local_sort) { - if (arguments) { - ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - ColumnDataScanState arg_state; - arguments->InitializeScan(arg_state); - for (sort_chunk->Reset(); ordering->Scan(sort_state, *sort_chunk); sort_chunk->Reset()) { - PrefixSortBuffer(prefixed); - arg_chunk->Reset(); - arguments->Scan(arg_state, *arg_chunk); - local_sort.SinkChunk(prefixed, *arg_chunk); - } - } else if (ordering) { - ColumnDataScanState sort_state; - ordering->InitializeScan(sort_state); - for (sort_chunk->Reset(); ordering->Scan(sort_state, *sort_chunk); sort_chunk->Reset()) { - PrefixSortBuffer(prefixed); - local_sort.SinkChunk(prefixed, *sort_chunk); - } - } else { - // Force chunks so we can sort - if (!sort_chunk) { - FlushLinkedLists(order_bind); - } - - PrefixSortBuffer(prefixed); - if (arg_chunk) { - local_sort.SinkChunk(prefixed, *arg_chunk); - } else { - local_sort.SinkChunk(prefixed, *sort_chunk); - } - } - - Reset(); - } - - void Reset() { - // Release all memory - ordering.reset(); - arguments.reset(); - - sort_chunk.reset(); - arg_chunk.reset(); - - sort_linked.clear(); - arg_linked.clear(); - - count = 0; - } - - idx_t count; - - unique_ptr arguments; - unique_ptr arguments_append; - unique_ptr ordering; - unique_ptr ordering_append; - - unique_ptr sort_chunk; - unique_ptr arg_chunk; - - LinkedLists sort_linked; - LinkedLists arg_linked; - - // Selection for scattering - SelectionVector sel; - idx_t nsel; - idx_t offset; -}; - -struct SortedAggregateFunction { - template - static void Initialize(STATE &state) { - new (&state) STATE(); - } - - template - static void Destroy(STATE &state, AggregateInputData &aggr_input_data) { - state.~STATE(); - } - - static void ProjectInputs(Vector inputs[], const SortedAggregateBindData &order_bind, idx_t input_count, - idx_t count, DataChunk &arg_input, DataChunk &sort_input) { - idx_t col = 0; - - if (!order_bind.sorted_on_args) { - 
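// the incoming vectors hold the argument columns first, followed by the sort columns;
// reference the argument columns into arg_input before the sort columns are split off below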
arg_input.InitializeEmpty(order_bind.arg_types); - for (auto &dst : arg_input.data) { - dst.Reference(inputs[col++]); - } - arg_input.SetCardinality(count); - } - - sort_input.InitializeEmpty(order_bind.sort_types); - for (auto &dst : sort_input.data) { - dst.Reference(inputs[col++]); - } - sort_input.SetCardinality(count); - } - - static void SimpleUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, data_ptr_t state, - idx_t count) { - const auto order_bind = aggr_input_data.bind_data->Cast(); - DataChunk arg_input; - DataChunk sort_input; - ProjectInputs(inputs, order_bind, input_count, count, arg_input, sort_input); - - const auto order_state = reinterpret_cast(state); - order_state->Update(aggr_input_data, sort_input, arg_input); - } - - static void ScatterUpdate(Vector inputs[], AggregateInputData &aggr_input_data, idx_t input_count, Vector &states, - idx_t count) { - if (!count) { - return; - } - - // Append the arguments to the two sub-collections - const auto &order_bind = aggr_input_data.bind_data->Cast(); - DataChunk arg_inputs; - DataChunk sort_inputs; - ProjectInputs(inputs, order_bind, input_count, count, arg_inputs, sort_inputs); - - // We have to scatter the chunks one at a time - // so build a selection vector for each one. - UnifiedVectorFormat svdata; - states.ToUnifiedFormat(count, svdata); - - // Size the selection vector for each state. - auto sdata = UnifiedVectorFormat::GetDataNoConst(svdata); - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - order_state->nsel++; - } - - // Build the selection vector for each state. - vector sel_data(count); - idx_t start = 0; - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - if (!order_state->offset) { - // First one - order_state->offset = start; - order_state->sel.Initialize(sel_data.data() + order_state->offset); - start += order_state->nsel; - } - sel_data[order_state->offset++] = UnsafeNumericCast(sidx); - } - - // Append nonempty slices to the arguments - for (idx_t i = 0; i < count; ++i) { - auto sidx = svdata.sel->get_index(i); - auto order_state = sdata[sidx]; - if (!order_state->nsel) { - continue; - } - - order_state->UpdateSlice(aggr_input_data, sort_inputs, arg_inputs); - } - } - - template - static void Combine(const STATE &source, STATE &target, AggregateInputData &aggr_input_data) { - auto &order_bind = aggr_input_data.bind_data->Cast(); - auto &other = const_cast(source); // NOLINT: absorb explicitly allows destruction - target.Absorb(order_bind, other); - } - - static void Window(AggregateInputData &aggr_input_data, const WindowPartitionInput &partition, - const_data_ptr_t g_state, data_ptr_t l_state, const SubFrames &subframes, Vector &result, - idx_t rid) { - throw InternalException("Sorted aggregates should not be generated for window clauses"); - } - - static void Finalize(Vector &states, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - const idx_t offset) { - auto &order_bind = aggr_input_data.bind_data->Cast(); - auto &buffer_manager = order_bind.buffer_manager; - RowLayout payload_layout; - payload_layout.Initialize(order_bind.arg_types); - DataChunk chunk; - chunk.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); - DataChunk sliced; - sliced.Initialize(Allocator::DefaultAllocator(), order_bind.arg_types); - - // Reusable inner state - auto &aggr = order_bind.function; - vector agg_state(aggr.state_size(aggr)); - Vector 
agg_state_vec(Value::POINTER(CastPointerToValue(agg_state.data()))); - - // State variables - auto bind_info = order_bind.bind_info.get(); - AggregateInputData aggr_bind_info(bind_info, aggr_input_data.allocator); - - // Inner aggregate APIs - auto initialize = aggr.initialize; - auto destructor = aggr.destructor; - auto simple_update = aggr.simple_update; - auto update = aggr.update; - auto finalize = aggr.finalize; - - auto sdata = FlatVector::GetData(states); - - vector state_unprocessed(count, 0); - for (idx_t i = 0; i < count; ++i) { - state_unprocessed[i] = sdata[i]->count; - } - - // Sort the input payloads on (state_idx ASC, orders) - vector orders; - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, - make_uniq(Value::USMALLINT(0)))); - for (const auto &order : order_bind.orders) { - orders.emplace_back(order.Copy()); - } - - auto global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = order_bind.external; - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - - DataChunk prefixed; - prefixed.Initialize(Allocator::DefaultAllocator(), global_sort->sort_layout.logical_types); - - // Go through the states accumulating values to sort until we hit the sort threshold - idx_t unsorted_count = 0; - idx_t sorted = 0; - for (idx_t finalized = 0; finalized < count;) { - if (unsorted_count < order_bind.threshold) { - auto state = sdata[finalized]; - prefixed.Reset(); - prefixed.data[0].Reference(Value::USMALLINT(UnsafeNumericCast(finalized))); - state->Finalize(order_bind, prefixed, *local_sort); - unsorted_count += state_unprocessed[finalized]; - - // Go to the next aggregate unless this is the last one - if (++finalized < count) { - continue; - } - } - - // If they were all empty (filtering) flush them - // (This can only happen on the last range) - if (!unsorted_count) { - break; - } - - // Sort all the data - global_sort->AddLocalState(*local_sort); - global_sort->PrepareMergePhase(); - while (global_sort->sorted_blocks.size() > 1) { - global_sort->InitializeMergeRound(); - MergeSorter merge_sorter(*global_sort, global_sort->buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort->CompleteMergeRound(false); - } - - auto scanner = make_uniq(*global_sort); - initialize(aggr, agg_state.data()); - while (scanner->Remaining()) { - chunk.Reset(); - scanner->Scan(chunk); - idx_t consumed = 0; - - // Distribute the scanned chunk to the aggregates - while (consumed < chunk.size()) { - // Find the next aggregate that needs data - for (; !state_unprocessed[sorted]; ++sorted) { - // Finalize a single value at the next offset - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - - initialize(aggr, agg_state.data()); - } - const auto input_count = MinValue(state_unprocessed[sorted], chunk.size() - consumed); - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - sliced.data[col_idx].Slice(chunk.data[col_idx], consumed, consumed + input_count); - } - sliced.SetCardinality(input_count); - - // These are all simple updates, so use it if available - if (simple_update) { - simple_update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state.data(), - sliced.size()); - } else { - // We are only updating a constant state - agg_state_vec.SetVectorType(VectorType::CONSTANT_VECTOR); - 
update(sliced.data.data(), aggr_bind_info, sliced.data.size(), agg_state_vec, sliced.size()); - } - - consumed += input_count; - state_unprocessed[sorted] -= input_count; - } - } - - // Finalize the last state for this sort - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - ++sorted; - - // Stop if we are done - if (finalized >= count) { - break; - } - - // Create a new sort - scanner.reset(); - global_sort = make_uniq(buffer_manager, orders, payload_layout); - global_sort->external = order_bind.external; - local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - unsorted_count = 0; - } - - for (; sorted < count; ++sorted) { - initialize(aggr, agg_state.data()); - - // Finalize a single value at the next offset - agg_state_vec.SetVectorType(states.GetVectorType()); - finalize(agg_state_vec, aggr_bind_info, result, 1, sorted + offset); - - if (destructor) { - destructor(agg_state_vec, aggr_bind_info, 1); - } - } - - result.Verify(count); - } -}; - -void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundAggregateExpression &expr, - const vector> &groups) { - if (!expr.order_bys || expr.order_bys->orders.empty() || expr.children.empty()) { - // not a sorted aggregate: return - return; - } - // Remove unnecessary ORDER BY clauses and return if nothing remains - if (context.config.enable_optimizer) { - if (expr.order_bys->Simplify(groups)) { - expr.order_bys.reset(); - return; - } - } - auto &bound_function = expr.function; - auto &children = expr.children; - auto &order_bys = *expr.order_bys; - auto sorted_bind = make_uniq(context, expr); - - if (!sorted_bind->sorted_on_args) { - // The arguments are the children plus the sort columns. - for (auto &order : order_bys.orders) { - children.emplace_back(std::move(order.expression)); - } - } - - vector arguments; - arguments.reserve(children.size()); - for (const auto &child : children) { - arguments.emplace_back(child->return_type); - } - - // Replace the aggregate with the wrapper - AggregateFunction ordered_aggregate( - bound_function.name, arguments, bound_function.return_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - SortedAggregateFunction::ScatterUpdate, - AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, bound_function.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, - AggregateFunction::StateDestroy, nullptr, - SortedAggregateFunction::Window); - - expr.function = std::move(ordered_aggregate); - expr.bind_info = std::move(sorted_bind); - expr.order_bys.reset(); -} - -void FunctionBinder::BindSortedAggregate(ClientContext &context, BoundWindowExpression &expr) { - if (expr.arg_orders.empty() || expr.children.empty()) { - // not a sorted aggregate: return - return; - } - // Remove unnecessary ORDER BY clauses and return if nothing remains - if (context.config.enable_optimizer) { - if (BoundOrderModifier::Simplify(expr.arg_orders, expr.partitions)) { - expr.arg_orders.clear(); - return; - } - } - auto &aggregate = *expr.aggregate; - auto &children = expr.children; - auto &arg_orders = expr.arg_orders; - auto sorted_bind = make_uniq(context, expr); - - if (!sorted_bind->sorted_on_args) { - // The arguments are the children plus the sort columns. 
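// NOTE (editorial sketch, not part of the original source): the loop below folds
// the sort keys into the wrapper's argument list, e.g. my_agg(x ORDER BY y) is
// rebound as wrapper_agg(x, y) (names hypothetical). SortedAggregateBindData
// records which trailing columns are sort keys so ProjectInputs can split payload
// and keys apart again on every update, and the ORDER BY modifier itself is
// cleared once the wrapper is installed.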
- for (auto &order : arg_orders) { - children.emplace_back(std::move(order.expression)); - } - } - - vector arguments; - arguments.reserve(children.size()); - for (const auto &child : children) { - arguments.emplace_back(child->return_type); - } - - // Replace the aggregate with the wrapper - AggregateFunction ordered_aggregate( - aggregate.name, arguments, aggregate.return_type, AggregateFunction::StateSize, - AggregateFunction::StateInitialize, - SortedAggregateFunction::ScatterUpdate, - AggregateFunction::StateCombine, - SortedAggregateFunction::Finalize, aggregate.null_handling, SortedAggregateFunction::SimpleUpdate, nullptr, - AggregateFunction::StateDestroy, nullptr, - SortedAggregateFunction::Window); - - aggregate = std::move(ordered_aggregate); - expr.bind_info = std::move(sorted_bind); - expr.arg_orders.clear(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/aggregate_function.cpp b/src/duckdb/src/function/aggregate_function.cpp deleted file mode 100644 index dd3bc0018..000000000 --- a/src/duckdb/src/function/aggregate_function.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include "duckdb/function/aggregate_function.hpp" - -namespace duckdb { - -AggregateFunctionInfo::~AggregateFunctionInfo() { -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/built_in_functions.cpp b/src/duckdb/src/function/built_in_functions.cpp deleted file mode 100644 index 00e7eac64..000000000 --- a/src/duckdb/src/function/built_in_functions.cpp +++ /dev/null @@ -1,171 +0,0 @@ -#include "duckdb/function/built_in_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/main/extension_entries.hpp" -#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" -#include "duckdb/parser/parsed_data/create_collation_info.hpp" -#include "duckdb/parser/parsed_data/create_copy_function_info.hpp" -#include "duckdb/parser/parsed_data/create_pragma_function_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -BuiltinFunctions::BuiltinFunctions(CatalogTransaction transaction, Catalog &catalog) - : transaction(transaction), catalog(catalog) { -} - -BuiltinFunctions::~BuiltinFunctions() { -} - -void BuiltinFunctions::AddCollation(string name, ScalarFunction function, bool combinable, - bool not_required_for_equality) { - CreateCollationInfo info(std::move(name), std::move(function), combinable, not_required_for_equality); - info.internal = true; - catalog.CreateCollation(transaction, info); -} - -void BuiltinFunctions::AddFunction(AggregateFunctionSet set) { - CreateAggregateFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(AggregateFunction function) { - CreateAggregateFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(PragmaFunction function) { - CreatePragmaFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreatePragmaFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(const string &name, PragmaFunctionSet functions) { - CreatePragmaFunctionInfo info(name, std::move(functions)); - info.internal = true; - catalog.CreatePragmaFunction(transaction, info); -} - -void 
BuiltinFunctions::AddFunction(ScalarFunction function) { - CreateScalarFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(const vector &names, ScalarFunction function) { // NOLINT: false positive - for (auto &name : names) { - function.name = name; - AddFunction(function); - } -} - -void BuiltinFunctions::AddFunction(ScalarFunctionSet set) { - CreateScalarFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(TableFunction function) { - CreateTableFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateTableFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(TableFunctionSet set) { - CreateTableFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateTableFunction(transaction, info); -} - -void BuiltinFunctions::AddFunction(CopyFunction function) { - CreateCopyFunctionInfo info(std::move(function)); - info.internal = true; - catalog.CreateCopyFunction(transaction, info); -} - -struct ExtensionFunctionInfo : public ScalarFunctionInfo { - explicit ExtensionFunctionInfo(string extension_p) : extension(std::move(extension_p)) { - } - - string extension; -}; - -unique_ptr BindExtensionFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // if this is triggered we are trying to call a method that is present in an extension - // but the extension is not loaded - // try to autoload the extension - // first figure out which extension we need to auto-load - auto &function_info = bound_function.function_info->Cast(); - auto &extension_name = function_info.extension; - auto &db = *context.db; - - if (!ExtensionHelper::CanAutoloadExtension(extension_name)) { - throw BinderException("Trying to call function \"%s\" which is present in extension \"%s\" - but the extension " - "is not loaded and could not be auto-loaded", - bound_function.name, extension_name); - } - // auto-load the extension - ExtensionHelper::AutoLoadExtension(db, extension_name); - - // now find the function in the catalog - auto &catalog = Catalog::GetSystemCatalog(db); - auto &function_entry = catalog.GetEntry(context, DEFAULT_SCHEMA, bound_function.name); - // override the function with the extension function - bound_function = function_entry.functions.GetFunctionByArguments(context, bound_function.arguments); - // call the original bind (if any) - if (!bound_function.bind) { - return nullptr; - } - return bound_function.bind(context, bound_function, arguments); -} - -void BuiltinFunctions::AddExtensionFunction(ScalarFunctionSet set) { - CreateScalarFunctionInfo info(std::move(set)); - info.internal = true; - catalog.CreateFunction(transaction, info); -} - -void BuiltinFunctions::RegisterExtensionOverloads() { -#ifdef GENERATE_EXTENSION_ENTRIES - // do not insert auto loading placeholders when generating extension entries - return; -#endif - ScalarFunctionSet current_set; - for (auto &entry : EXTENSION_FUNCTION_OVERLOADS) { - vector arguments; - auto splits = StringUtil::Split(entry.signature, ">"); - auto return_type = DBConfig::ParseLogicalType(splits[1]); - auto argument_splits = StringUtil::Split(splits[0], ","); - for (auto ¶m : argument_splits) { - arguments.push_back(DBConfig::ParseLogicalType(param)); - } - if (entry.type != CatalogType::SCALAR_FUNCTION_ENTRY) { - throw InternalException( - "Extension function overloads only supported for 
scalar functions currently - %s has a different type", - entry.name); - } - - ScalarFunction function(entry.name, std::move(arguments), std::move(return_type), nullptr, - BindExtensionFunction); - function.function_info = make_shared_ptr(entry.extension); - if (current_set.name != entry.name) { - if (!current_set.name.empty()) { - // create set of functions - AddExtensionFunction(current_set); - } - current_set = ScalarFunctionSet(entry.name); - } - // add this function to the set of function overloads - current_set.AddFunction(std::move(function)); - } - AddExtensionFunction(std::move(current_set)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/array_casts.cpp b/src/duckdb/src/function/cast/array_casts.cpp deleted file mode 100644 index 2357a2c2c..000000000 --- a/src/duckdb/src/function/cast/array_casts.cpp +++ /dev/null @@ -1,224 +0,0 @@ -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -namespace duckdb { - -unique_ptr ArrayBoundCastData::BindArrayToArrayCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto &source_child_type = ArrayType::GetChildType(source); - auto &result_child_type = ArrayType::GetChildType(target); - auto child_cast = input.GetCastFunction(source_child_type, result_child_type); - return make_uniq(std::move(child_cast)); -} - -static unique_ptr BindArrayToListCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::ARRAY); - D_ASSERT(target.id() == LogicalTypeId::LIST); - - vector child_cast_info; - auto &source_child_type = ArrayType::GetChildType(source); - auto &result_child_type = ListType::GetChildType(target); - auto child_cast = input.GetCastFunction(source_child_type, result_child_type); - return make_uniq(std::move(child_cast)); -} - -unique_ptr ArrayBoundCastData::InitArrayLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - if (!cast_data.child_cast_info.init_local_state) { - return nullptr; - } - CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); - return cast_data.child_cast_info.init_local_state(child_parameters); -} - -//------------------------------------------------------------------------------ -// ARRAY -> ARRAY -//------------------------------------------------------------------------------ -static bool ArrayToArrayCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - - auto source_array_size = ArrayType::GetSize(source.GetType()); - auto target_array_size = ArrayType::GetSize(result.GetType()); - if (source_array_size != target_array_size) { - // Cant cast between arrays of different sizes - auto msg = StringUtil::Format("Cannot cast array of size %u to array of size %u", source_array_size, - target_array_size); - HandleCastError::AssignError(msg, parameters); - if (!parameters.strict) { - // if this was a TRY_CAST, we know every row will fail, so just return null - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return false; - } - } - - auto &cast_data = parameters.cast_data->Cast(); - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - if (ConstantVector::IsNull(source)) { - ConstantVector::SetNull(result, 
true); - } - - auto &source_cc = ArrayVector::GetEntry(source); - auto &result_cc = ArrayVector::GetEntry(result); - - // If the array vector is constant, the child vector must be flat (or constant if array size is 1) - D_ASSERT(source_cc.GetVectorType() == VectorType::FLAT_VECTOR || source_array_size == 1); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_ok = cast_data.child_cast_info.function(source_cc, result_cc, source_array_size, child_parameters); - return all_ok; - } else { - // Flatten if not constant - source.Flatten(count); - result.SetVectorType(VectorType::FLAT_VECTOR); - - FlatVector::SetValidity(result, FlatVector::Validity(source)); - auto &source_cc = ArrayVector::GetEntry(source); - auto &result_cc = ArrayVector::GetEntry(result); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_ok = - cast_data.child_cast_info.function(source_cc, result_cc, count * source_array_size, child_parameters); - return all_ok; - } -} - -//------------------------------------------------------------------------------ -// ARRAY -> VARCHAR -//------------------------------------------------------------------------------ -static bool ArrayToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto is_constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - - auto size = ArrayType::GetSize(source.GetType()); - Vector varchar_list(LogicalType::ARRAY(LogicalType::VARCHAR, size), count); - ArrayToArrayCast(source, varchar_list, count, parameters); - - varchar_list.Flatten(count); - auto &validity = FlatVector::Validity(varchar_list); - auto &child = ArrayVector::GetEntry(varchar_list); - - child.Flatten(count); - auto &child_validity = FlatVector::Validity(child); - - auto in_data = FlatVector::GetData(child); - auto out_data = FlatVector::GetData(result); - - static constexpr const idx_t SEP_LENGTH = 2; - static constexpr const idx_t NULL_LENGTH = 4; - - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - - // First pass, compute the length - idx_t array_varchar_length = 2; - for (idx_t j = 0; j < size; j++) { - auto elem_idx = (i * size) + j; - auto elem = in_data[elem_idx]; - if (j > 0) { - array_varchar_length += SEP_LENGTH; - } - array_varchar_length += child_validity.RowIsValid(elem_idx) ? 
elem.GetSize() : NULL_LENGTH; - } - - out_data[i] = StringVector::EmptyString(result, array_varchar_length); - auto dataptr = out_data[i].GetDataWriteable(); - idx_t offset = 0; - dataptr[offset++] = '['; - - // Second pass, write the actual data - for (idx_t j = 0; j < size; j++) { - auto elem_idx = (i * size) + j; - auto elem = in_data[elem_idx]; - if (j > 0) { - memcpy(dataptr + offset, ", ", SEP_LENGTH); - offset += SEP_LENGTH; - } - if (child_validity.RowIsValid(elem_idx)) { - auto len = elem.GetSize(); - memcpy(dataptr + offset, elem.GetData(), len); - offset += len; - } else { - memcpy(dataptr + offset, "NULL", NULL_LENGTH); - offset += NULL_LENGTH; - } - } - dataptr[offset++] = ']'; - out_data[i].Finalize(); - } - - if (is_constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - return true; -} - -//------------------------------------------------------------------------------ -// ARRAY -> LIST -//------------------------------------------------------------------------------ -static bool ArrayToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - - // FIXME: dont flatten - source.Flatten(count); - - auto array_size = ArrayType::GetSize(source.GetType()); - auto child_count = count * array_size; - - ListVector::Reserve(result, child_count); - ListVector::SetListSize(result, child_count); - - auto &source_child = ArrayVector::GetEntry(source); - auto &result_child = ListVector::GetEntry(result); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_ok = cast_data.child_cast_info.function(source_child, result_child, child_count, child_parameters); - - auto list_data = ListVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - if (FlatVector::IsNull(source, i)) { - FlatVector::SetNull(result, i, true); - continue; - } - - list_data[i].offset = i * array_size; - list_data[i].length = array_size; - } - - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - return all_ok; -} - -BoundCastInfo DefaultCasts::ArrayCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::VARCHAR: { - auto size = ArrayType::GetSize(source); - return BoundCastInfo( - ArrayToVarcharCast, - ArrayBoundCastData::BindArrayToArrayCast(input, source, LogicalType::ARRAY(LogicalType::VARCHAR, size)), - ArrayBoundCastData::InitArrayLocalState); - } - case LogicalTypeId::ARRAY: - return BoundCastInfo(ArrayToArrayCast, ArrayBoundCastData::BindArrayToArrayCast(input, source, target), - ArrayBoundCastData::InitArrayLocalState); - case LogicalTypeId::LIST: - return BoundCastInfo(ArrayToListCast, BindArrayToListCast(input, source, target), - ArrayBoundCastData::InitArrayLocalState); - default: - return DefaultCasts::TryVectorNullCast; - }; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/bit_cast.cpp b/src/duckdb/src/function/cast/bit_cast.cpp deleted file mode 100644 index e542a1427..000000000 --- a/src/duckdb/src/function/cast/bit_cast.cpp +++ /dev/null @@ -1,51 +0,0 @@ -#include "duckdb/common/hugeint.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -namespace duckdb { - -BoundCastInfo DefaultCasts::BitCastSwitch(BindCastInput &input, const LogicalType &source, const 
LogicalType &target) {
-    // now switch on the result type
-    switch (target.id()) {
-    // Numerics
-    case LogicalTypeId::BOOLEAN:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::TINYINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::SMALLINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::INTEGER:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::BIGINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::UTINYINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::USMALLINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::UINTEGER:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::UBIGINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::HUGEINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::UHUGEINT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::FLOAT:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-    case LogicalTypeId::DOUBLE:
-        return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop);
-
-    case LogicalTypeId::BLOB:
-        return BoundCastInfo(&VectorCastHelpers::StringCast);
-
-    case LogicalTypeId::VARCHAR:
-        return BoundCastInfo(&VectorCastHelpers::StringCast);
-
-    default:
-        return DefaultCasts::TryVectorNullCast;
-    }
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/cast/blob_cast.cpp b/src/duckdb/src/function/cast/blob_cast.cpp
deleted file mode 100644
index 170a733d2..000000000
--- a/src/duckdb/src/function/cast/blob_cast.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-#include "duckdb/function/cast/default_casts.hpp"
-#include "duckdb/function/cast/vector_cast_helpers.hpp"
-
-namespace duckdb {
-
-BoundCastInfo DefaultCasts::BlobCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) {
-    // now switch on the result type
-    switch (target.id()) {
-    case LogicalTypeId::VARCHAR:
-        // blob to varchar
-        return BoundCastInfo(&VectorCastHelpers::StringCast);
-    case LogicalTypeId::AGGREGATE_STATE:
-        return DefaultCasts::ReinterpretCast;
-    case LogicalTypeId::BIT:
-        return BoundCastInfo(&VectorCastHelpers::StringCast);
-
-    default:
-        return DefaultCasts::TryVectorNullCast;
-    }
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/cast/cast_function_set.cpp b/src/duckdb/src/function/cast/cast_function_set.cpp
deleted file mode 100644
index 48b8bef77..000000000
--- a/src/duckdb/src/function/cast/cast_function_set.cpp
+++ /dev/null
@@ -1,215 +0,0 @@
-#include "duckdb/function/cast/cast_function_set.hpp"
-
-#include "duckdb/common/pair.hpp"
-#include "duckdb/common/types/type_map.hpp"
-#include "duckdb/function/cast_rules.hpp"
-#include "duckdb/planner/collation_binding.hpp"
-#include "duckdb/main/config.hpp"
-
-namespace duckdb {
-
-BindCastInput::BindCastInput(CastFunctionSet &function_set, optional_ptr info,
-                             optional_ptr context)
-    : function_set(function_set), info(info), context(context) {
-}
-
-BoundCastInfo BindCastInput::GetCastFunction(const LogicalType &source, const LogicalType &target) {
-    GetCastFunctionInput input(context);
-    input.query_location = query_location;
-    return function_set.GetCastFunction(source, target, input);
-}
-
-BindCastFunction::BindCastFunction(bind_cast_function_t function_p, unique_ptr info_p)
-    : function(function_p),
info(std::move(info_p)) { -} - -CastFunctionSet::CastFunctionSet() : map_info(nullptr) { - bind_functions.emplace_back(DefaultCasts::GetDefaultCastFunction); -} - -CastFunctionSet::CastFunctionSet(DBConfig &config_p) : CastFunctionSet() { - this->config = &config_p; -} - -CastFunctionSet &CastFunctionSet::Get(ClientContext &context) { - return DBConfig::GetConfig(context).GetCastFunctions(); -} - -CollationBinding &CollationBinding::Get(ClientContext &context) { - return DBConfig::GetConfig(context).GetCollationBinding(); -} - -CastFunctionSet &CastFunctionSet::Get(DatabaseInstance &db) { - return DBConfig::GetConfig(db).GetCastFunctions(); -} - -CollationBinding &CollationBinding::Get(DatabaseInstance &db) { - return DBConfig::GetConfig(db).GetCollationBinding(); -} - -BoundCastInfo CastFunctionSet::GetCastFunction(const LogicalType &source, const LogicalType &target, - GetCastFunctionInput &get_input) { - if (source == target) { - return DefaultCasts::NopCast; - } - // the first function is the default - // we iterate the set of bind functions backwards - for (idx_t i = bind_functions.size(); i > 0; i--) { - auto &bind_function = bind_functions[i - 1]; - BindCastInput input(*this, bind_function.info.get(), get_input.context); - input.query_location = get_input.query_location; - auto result = bind_function.function(input, source, target); - if (result.function) { - // found a cast function! return it - return result; - } - } - // no cast found: return the default null cast - return DefaultCasts::TryVectorNullCast; -} - -struct MapCastNode { - MapCastNode(BoundCastInfo info, int64_t implicit_cast_cost) - : cast_info(std::move(info)), bind_function(nullptr), implicit_cast_cost(implicit_cast_cost) { - } - MapCastNode(bind_cast_function_t func, int64_t implicit_cast_cost) - : cast_info(nullptr), bind_function(func), implicit_cast_cost(implicit_cast_cost) { - } - - BoundCastInfo cast_info; - bind_cast_function_t bind_function; - int64_t implicit_cast_cost; -}; - -template -static auto RelaxedTypeMatch(type_map_t &map, const LogicalType &type) -> decltype(map.find(type)) { - D_ASSERT(map.find(type) == map.end()); // we shouldn't be here - switch (type.id()) { - case LogicalTypeId::LIST: - return map.find(LogicalType::LIST(LogicalType::ANY)); - case LogicalTypeId::STRUCT: - return map.find(LogicalType::STRUCT({{"any", LogicalType::ANY}})); - case LogicalTypeId::MAP: - for (auto it = map.begin(); it != map.end(); it++) { - const auto &entry_type = it->first; - if (entry_type.id() != LogicalTypeId::MAP) { - continue; - } - auto &entry_key_type = MapType::KeyType(entry_type); - auto &entry_val_type = MapType::ValueType(entry_type); - if ((entry_key_type == LogicalType::ANY || entry_key_type == MapType::KeyType(type)) && - (entry_val_type == LogicalType::ANY || entry_val_type == MapType::ValueType(type))) { - return it; - } - } - return map.end(); - case LogicalTypeId::UNION: - return map.find(LogicalType::UNION({{"any", LogicalType::ANY}})); - case LogicalTypeId::ARRAY: - return map.find(LogicalType::ARRAY(LogicalType::ANY, optional_idx())); - default: - return map.find(LogicalType::ANY); - } -} - -struct MapCastInfo : public BindCastInfo { -public: - const optional_ptr GetEntry(const LogicalType &source, const LogicalType &target) { - auto source_type_id_entry = casts.find(source.id()); - if (source_type_id_entry == casts.end()) { - source_type_id_entry = casts.find(LogicalTypeId::ANY); - if (source_type_id_entry == casts.end()) { - return nullptr; - } - } - - auto &source_type_entries = 
source_type_id_entry->second; - auto source_type_entry = source_type_entries.find(source); - if (source_type_entry == source_type_entries.end()) { - source_type_entry = RelaxedTypeMatch(source_type_entries, source); - if (source_type_entry == source_type_entries.end()) { - return nullptr; - } - } - - auto &target_type_id_entries = source_type_entry->second; - auto target_type_id_entry = target_type_id_entries.find(target.id()); - if (target_type_id_entry == target_type_id_entries.end()) { - target_type_id_entry = target_type_id_entries.find(LogicalTypeId::ANY); - if (target_type_id_entry == target_type_id_entries.end()) { - return nullptr; - } - } - - auto &target_type_entries = target_type_id_entry->second; - auto target_type_entry = target_type_entries.find(target); - if (target_type_entry == target_type_entries.end()) { - target_type_entry = RelaxedTypeMatch(target_type_entries, target); - if (target_type_entry == target_type_entries.end()) { - return nullptr; - } - } - - return &target_type_entry->second; - } - - void AddEntry(const LogicalType &source, const LogicalType &target, MapCastNode node) { - casts[source.id()][source][target.id()].insert(make_pair(target, std::move(node))); - } - -private: - type_id_map_t>>> casts; -}; - -int64_t CastFunctionSet::ImplicitCastCost(const LogicalType &source, const LogicalType &target) { - // check if a cast has been registered - if (map_info) { - auto entry = map_info->GetEntry(source, target); - if (entry) { - return entry->implicit_cast_cost; - } - } - // if not, fallback to the default implicit cast rules - auto score = CastRules::ImplicitCast(source, target); - if (score < 0 && config && config->options.old_implicit_casting) { - if (source.id() != LogicalTypeId::BLOB && target.id() == LogicalTypeId::VARCHAR) { - score = 149; - } - } - return score; -} - -BoundCastInfo MapCastFunction(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(input.info); - auto &map_info = input.info->Cast(); - auto entry = map_info.GetEntry(source, target); - if (entry) { - if (entry->bind_function) { - return entry->bind_function(input, source, target); - } - return entry->cast_info.Copy(); - } - return nullptr; -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, BoundCastInfo function, - int64_t implicit_cast_cost) { - RegisterCastFunction(source, target, MapCastNode(std::move(function), implicit_cast_cost)); -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, - bind_cast_function_t bind_function, int64_t implicit_cast_cost) { - RegisterCastFunction(source, target, MapCastNode(bind_function, implicit_cast_cost)); -} - -void CastFunctionSet::RegisterCastFunction(const LogicalType &source, const LogicalType &target, MapCastNode node) { - if (!map_info) { - // create the cast map and the cast map function - auto info = make_uniq(); - map_info = info.get(); - bind_functions.emplace_back(MapCastFunction, std::move(info)); - } - map_info->AddEntry(source, target, std::move(node)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/decimal_cast.cpp b/src/duckdb/src/function/cast/decimal_cast.cpp deleted file mode 100644 index 21f80bf04..000000000 --- a/src/duckdb/src/function/cast/decimal_cast.cpp +++ /dev/null @@ -1,324 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -#include "duckdb/common/vector_operations/general_cast.hpp" -#include 
"duckdb/common/types/decimal.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/common/types/cast_helpers.hpp" - -namespace duckdb { - -template -static bool FromDecimalCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &source_type = source.GetType(); - auto width = DecimalType::GetWidth(source_type); - auto scale = DecimalType::GetScale(source_type); - switch (source_type.InternalType()) { - case PhysicalType::INT16: - return VectorCastHelpers::TemplatedDecimalCast(source, result, count, - parameters, width, scale); - case PhysicalType::INT32: - return VectorCastHelpers::TemplatedDecimalCast(source, result, count, - parameters, width, scale); - case PhysicalType::INT64: - return VectorCastHelpers::TemplatedDecimalCast(source, result, count, - parameters, width, scale); - case PhysicalType::INT128: - return VectorCastHelpers::TemplatedDecimalCast(source, result, count, - parameters, width, scale); - default: - throw InternalException("Unimplemented internal type for decimal"); - } -} - -template -struct DecimalScaleInput { - DecimalScaleInput(Vector &result_p, FACTOR_TYPE factor_p, CastParameters ¶meters) - : result(result_p), vector_cast_data(result, parameters), factor(factor_p) { - } - DecimalScaleInput(Vector &result_p, LIMIT_TYPE limit_p, FACTOR_TYPE factor_p, CastParameters ¶meters, - uint8_t source_width_p, uint8_t source_scale_p) - : result(result_p), vector_cast_data(result, parameters), limit(limit_p), factor(factor_p), - source_width(source_width_p), source_scale(source_scale_p) { - } - - Vector &result; - VectorTryCastData vector_cast_data; - LIMIT_TYPE limit; - FACTOR_TYPE factor; - uint8_t source_width; - uint8_t source_scale; -}; - -struct DecimalScaleUpOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - return Cast::Operation(input) * data->factor; - } -}; - -struct DecimalScaleUpCheckOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = (DecimalScaleInput *)dataptr; - if (input >= data->limit || input <= -data->limit) { - auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", - Decimal::ToString(input, data->source_width, data->source_scale), - data->result.GetType().ToString()); - return HandleVectorCastError::Operation(std::move(error), mask, idx, data->vector_cast_data); - } - return Cast::Operation(input) * data->factor; - } -}; - -template -bool TemplatedDecimalScaleUp(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto source_width = DecimalType::GetWidth(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - auto result_width = DecimalType::GetWidth(result.GetType()); - D_ASSERT(result_scale >= source_scale); - idx_t scale_difference = result_scale - source_scale; - DEST multiply_factor = UnsafeNumericCast(POWERS_DEST::POWERS_OF_TEN[scale_difference]); - idx_t target_width = result_width - scale_difference; - if (source_width < target_width) { - DecimalScaleInput input(result, multiply_factor, parameters); - // type will always fit: no need to check limit - UnaryExecutor::GenericExecute(source, result, count, &input); - return true; - } else { - // type might not fit: check limit - auto limit = 
UnsafeNumericCast(POWERS_SOURCE::POWERS_OF_TEN[target_width]); - DecimalScaleInput input(result, limit, multiply_factor, parameters, source_width, source_scale); - UnaryExecutor::GenericExecute(source, result, count, &input, - parameters.error_message); - return input.vector_cast_data.all_converted; - } -} - -struct DecimalScaleDownOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - // We need to round here, not truncate. - auto data = (DecimalScaleInput *)dataptr; - // Scale first so we don't overflow when rounding. - const auto scaling = data->factor / 2; - input /= scaling; - if (input < 0) { - input -= 1; - } else { - input += 1; - } - return Cast::Operation(input / 2); - } -}; - -// This function detects if we can scale a decimal down to another. -template -bool CanScaleDownDecimal(INPUT_TYPE input, DecimalScaleInput &data) { - int64_t divisor = UnsafeNumericCast(NumericHelper::POWERS_OF_TEN[data.source_scale]); - auto value = input % divisor; - auto rounded_input = input; - if (rounded_input < 0) { - rounded_input *= -1; - value *= -1; - } - if (value >= divisor / 2) { - rounded_input += divisor; - } - return rounded_input < data.limit && rounded_input > -data.limit; -} - -template <> -bool CanScaleDownDecimal(hugeint_t input, DecimalScaleInput &data) { - auto divisor = UnsafeNumericCast(Hugeint::POWERS_OF_TEN[data.source_scale]); - hugeint_t value = input % divisor; - hugeint_t rounded_input = input; - if (rounded_input < 0) { - rounded_input *= -1; - value *= -1; - } - if (value >= divisor / 2) { - rounded_input += divisor; - } - return rounded_input < data.limit && rounded_input > -data.limit; -} - -struct DecimalScaleDownCheckOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = static_cast *>(dataptr); - if (!CanScaleDownDecimal(input, *data)) { - auto error = StringUtil::Format("Casting value \"%s\" to type %s failed: value is out of range!", - Decimal::ToString(input, data->source_width, data->source_scale), - data->result.GetType().ToString()); - return HandleVectorCastError::Operation(std::move(error), mask, idx, data->vector_cast_data); - } - return DecimalScaleDownOperator::Operation(input, mask, idx, dataptr); - } -}; - -template -bool TemplatedDecimalScaleDown(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto source_width = DecimalType::GetWidth(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - auto result_width = DecimalType::GetWidth(result.GetType()); - D_ASSERT(result_scale < source_scale); - idx_t scale_difference = source_scale - result_scale; - idx_t target_width = result_width + scale_difference; - auto divide_factor = UnsafeNumericCast(POWERS_SOURCE::POWERS_OF_TEN[scale_difference]); - if (source_width < target_width) { - DecimalScaleInput input(result, divide_factor, parameters); - // type will always fit: no need to check limit - UnaryExecutor::GenericExecute(source, result, count, &input); - return true; - } else { - // type might not fit: check limit - auto limit = UnsafeNumericCast(POWERS_SOURCE::POWERS_OF_TEN[target_width]); - DecimalScaleInput input(result, limit, divide_factor, parameters, source_width, source_scale); - UnaryExecutor::GenericExecute(source, result, count, &input, - parameters.error_message); - return input.vector_cast_data.all_converted; - } -} - -template -static bool 
DecimalDecimalCastSwitch(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto source_scale = DecimalType::GetScale(source.GetType()); - auto result_scale = DecimalType::GetScale(result.GetType()); - source.GetType().Verify(); - result.GetType().Verify(); - - // we need to either multiply or divide by the difference in scales - if (result_scale >= source_scale) { - // multiply - switch (result.GetType().InternalType()) { - case PhysicalType::INT16: - return TemplatedDecimalScaleUp(source, result, count, - parameters); - case PhysicalType::INT32: - return TemplatedDecimalScaleUp(source, result, count, - parameters); - case PhysicalType::INT64: - return TemplatedDecimalScaleUp(source, result, count, - parameters); - case PhysicalType::INT128: - return TemplatedDecimalScaleUp(source, result, count, - parameters); - default: - throw NotImplementedException("Unimplemented internal type for decimal"); - } - } else { - // divide - switch (result.GetType().InternalType()) { - case PhysicalType::INT16: - return TemplatedDecimalScaleDown(source, result, count, parameters); - case PhysicalType::INT32: - return TemplatedDecimalScaleDown(source, result, count, parameters); - case PhysicalType::INT64: - return TemplatedDecimalScaleDown(source, result, count, parameters); - case PhysicalType::INT128: - return TemplatedDecimalScaleDown(source, result, count, parameters); - default: - throw NotImplementedException("Unimplemented internal type for decimal"); - } - } -} - -struct DecimalCastInput { - DecimalCastInput(Vector &result_p, uint8_t width_p, uint8_t scale_p) - : result(result_p), width(width_p), scale(scale_p) { - } - - Vector &result; - uint8_t width; - uint8_t scale; -}; - -struct StringCastFromDecimalOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, ValidityMask &mask, idx_t idx, void *dataptr) { - auto data = reinterpret_cast(dataptr); - return StringCastFromDecimal::Operation(input, data->width, data->scale, data->result); - } -}; - -template -static bool DecimalToStringCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &source_type = source.GetType(); - auto width = DecimalType::GetWidth(source_type); - auto scale = DecimalType::GetScale(source_type); - DecimalCastInput input(result, width, scale); - - UnaryExecutor::GenericExecute(source, result, count, (void *)&input); - return true; -} - -BoundCastInfo DefaultCasts::DecimalCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::BOOLEAN: - return FromDecimalCast; - case LogicalTypeId::TINYINT: - return FromDecimalCast; - case LogicalTypeId::SMALLINT: - return FromDecimalCast; - case LogicalTypeId::INTEGER: - return FromDecimalCast; - case LogicalTypeId::BIGINT: - return FromDecimalCast; - case LogicalTypeId::UTINYINT: - return FromDecimalCast; - case LogicalTypeId::USMALLINT: - return FromDecimalCast; - case LogicalTypeId::UINTEGER: - return FromDecimalCast; - case LogicalTypeId::UBIGINT: - return FromDecimalCast; - case LogicalTypeId::HUGEINT: - return FromDecimalCast; - case LogicalTypeId::UHUGEINT: - return FromDecimalCast; - case LogicalTypeId::DECIMAL: { - // decimal to decimal cast - // first we need to figure out the source and target internal types - switch (source.InternalType()) { - case PhysicalType::INT16: - return DecimalDecimalCastSwitch; - case PhysicalType::INT32: - return DecimalDecimalCastSwitch; - case PhysicalType::INT64: - return 
DecimalDecimalCastSwitch; - case PhysicalType::INT128: - return DecimalDecimalCastSwitch; - default: - throw NotImplementedException("Unimplemented internal type for decimal in decimal_decimal cast"); - } - } - case LogicalTypeId::FLOAT: - return FromDecimalCast; - case LogicalTypeId::DOUBLE: - return FromDecimalCast; - case LogicalTypeId::VARCHAR: { - switch (source.InternalType()) { - case PhysicalType::INT16: - return DecimalToStringCast; - case PhysicalType::INT32: - return DecimalToStringCast; - case PhysicalType::INT64: - return DecimalToStringCast; - case PhysicalType::INT128: - return DecimalToStringCast; - default: - throw InternalException("Unimplemented internal decimal type"); - } - } - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/default_casts.cpp b/src/duckdb/src/function/cast/default_casts.cpp deleted file mode 100644 index d0da44327..000000000 --- a/src/duckdb/src/function/cast/default_casts.cpp +++ /dev/null @@ -1,155 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" - -#include "duckdb/common/likely.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/cast_helpers.hpp" - -#include "duckdb/common/types/null_value.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -namespace duckdb { - -BindCastInfo::~BindCastInfo() { -} - -BoundCastData::~BoundCastData() { -} - -BoundCastInfo::BoundCastInfo(cast_function_t function_p, unique_ptr cast_data_p, - init_cast_local_state_t init_local_state_p) - : function(function_p), init_local_state(init_local_state_p), cast_data(std::move(cast_data_p)) { -} - -BoundCastInfo BoundCastInfo::Copy() const { - return BoundCastInfo(function, cast_data ? 
cast_data->Copy() : nullptr, init_local_state);
-}
-
-bool DefaultCasts::NopCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-    result.Reference(source);
-    return true;
-}
-
-void HandleCastError::AssignError(const string &error_message, CastParameters &parameters) {
-    AssignError(error_message, parameters.error_message, parameters.query_location);
-}
-
-static string UnimplementedCastMessage(const LogicalType &source_type, const LogicalType &target_type) {
-    return StringUtil::Format("Unimplemented type for cast (%s -> %s)", source_type.ToString(), target_type.ToString());
-}
-
-// NULL cast only works if all values in source are NULL, otherwise an unimplemented cast exception is thrown
-bool DefaultCasts::TryVectorNullCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-    bool success = true;
-    if (VectorOperations::HasNotNull(source, count)) {
-        HandleCastError::AssignError(UnimplementedCastMessage(source.GetType(), result.GetType()), parameters);
-        success = false;
-    }
-    result.SetVectorType(VectorType::CONSTANT_VECTOR);
-    ConstantVector::SetNull(result, true);
-    return success;
-}
-
-bool DefaultCasts::ReinterpretCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-    result.Reinterpret(source);
-    return true;
-}
-
-static bool AggregateStateToBlobCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-    if (result.GetType().id() != LogicalTypeId::BLOB) {
-        throw TypeMismatchException(source.GetType(), result.GetType(),
-                                    "Cannot cast AGGREGATE_STATE to anything but BLOB");
-    }
-    result.Reinterpret(source);
-    return true;
-}
-
-static bool NullTypeCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-    // cast a NULL to another type, just copy the properties and change the type
-    result.SetVectorType(VectorType::CONSTANT_VECTOR);
-    ConstantVector::SetNull(result, true);
-    return true;
-}
-
-BoundCastInfo DefaultCasts::GetDefaultCastFunction(BindCastInput &input, const LogicalType &source,
-                                                   const LogicalType &target) {
-    D_ASSERT(source != target);
-
-    // first check if we're casting to a union
-    if (source.id() != LogicalTypeId::UNION && source.id() != LogicalTypeId::SQLNULL &&
-        target.id() == LogicalTypeId::UNION) {
-        return ImplicitToUnionCast(input, source, target);
-    }
-
-    // else, switch on source type
-    switch (source.id()) {
-    case LogicalTypeId::BOOLEAN:
-    case LogicalTypeId::TINYINT:
-    case LogicalTypeId::SMALLINT:
-    case LogicalTypeId::INTEGER:
-    case LogicalTypeId::BIGINT:
-    case LogicalTypeId::UTINYINT:
-    case LogicalTypeId::USMALLINT:
-    case LogicalTypeId::UINTEGER:
-    case LogicalTypeId::UBIGINT:
-    case LogicalTypeId::UHUGEINT:
-    case LogicalTypeId::HUGEINT:
-    case LogicalTypeId::FLOAT:
-    case LogicalTypeId::DOUBLE:
-        return NumericCastSwitch(input, source, target);
-    case LogicalTypeId::POINTER:
-        return PointerCastSwitch(input, source, target);
-    case LogicalTypeId::UUID:
-        return UUIDCastSwitch(input, source, target);
-    case LogicalTypeId::DECIMAL:
-        return DecimalCastSwitch(input, source, target);
-    case LogicalTypeId::DATE:
-        return DateCastSwitch(input, source, target);
-    case LogicalTypeId::TIME:
-        return TimeCastSwitch(input, source, target);
-    case LogicalTypeId::TIME_TZ:
-        return TimeTzCastSwitch(input, source, target);
-    case LogicalTypeId::TIMESTAMP:
-        return TimestampCastSwitch(input, source, target);
-    case LogicalTypeId::TIMESTAMP_TZ:
-        return TimestampTzCastSwitch(input, source, target);
-    case LogicalTypeId::TIMESTAMP_NS:
-        return
TimestampNsCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_MS: - return TimestampMsCastSwitch(input, source, target); - case LogicalTypeId::TIMESTAMP_SEC: - return TimestampSecCastSwitch(input, source, target); - case LogicalTypeId::INTERVAL: - return IntervalCastSwitch(input, source, target); - case LogicalTypeId::VARCHAR: - return StringCastSwitch(input, source, target); - case LogicalTypeId::BLOB: - return BlobCastSwitch(input, source, target); - case LogicalTypeId::BIT: - return BitCastSwitch(input, source, target); - case LogicalTypeId::SQLNULL: - return NullTypeCast; - case LogicalTypeId::MAP: - return MapCastSwitch(input, source, target); - case LogicalTypeId::STRUCT: - return StructCastSwitch(input, source, target); - case LogicalTypeId::LIST: - return ListCastSwitch(input, source, target); - case LogicalTypeId::UNION: - return UnionCastSwitch(input, source, target); - case LogicalTypeId::ENUM: - return EnumCastSwitch(input, source, target); - case LogicalTypeId::ARRAY: - return ArrayCastSwitch(input, source, target); - case LogicalTypeId::VARINT: - return VarintCastSwitch(input, source, target); - case LogicalTypeId::AGGREGATE_STATE: - return AggregateStateToBlobCast; - default: - return nullptr; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/enum_casts.cpp b/src/duckdb/src/function/cast/enum_casts.cpp deleted file mode 100644 index 7e22cb932..000000000 --- a/src/duckdb/src/function/cast/enum_casts.cpp +++ /dev/null @@ -1,147 +0,0 @@ -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" - -namespace duckdb { - -template -bool EnumEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &enum_dictionary = EnumType::GetValuesInsertOrder(source.GetType()); - auto dictionary_data = FlatVector::GetData(enum_dictionary); - auto res_enum_type = result.GetType(); - - VectorTryCastData vector_cast_data(result, parameters); - UnaryExecutor::ExecuteWithNulls( - source, result, count, [&](SRC_TYPE value, ValidityMask &mask, idx_t row_idx) { - auto key = EnumType::GetPos(res_enum_type, dictionary_data[value]); - if (key == -1) { - if (!parameters.error_message) { - return HandleVectorCastError::Operation(CastExceptionText(value), - mask, row_idx, vector_cast_data); - } else { - mask.SetInvalid(row_idx); - } - return RES_TYPE(); - } else { - return UnsafeNumericCast(key); - } - }); - return vector_cast_data.all_converted; -} - -template -BoundCastInfo EnumEnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.InternalType()) { - case PhysicalType::UINT8: - return EnumEnumCast; - case PhysicalType::UINT16: - return EnumEnumCast; - case PhysicalType::UINT32: - return EnumEnumCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } -} - -template -static bool EnumToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &enum_dictionary = EnumType::GetValuesInsertOrder(source.GetType()); - auto dictionary_data = FlatVector::GetData(enum_dictionary); - - UnaryExecutor::Execute(source, result, count, - [&](SRC enum_idx) { return dictionary_data[enum_idx]; }); - return true; -} - -struct EnumBoundCastData : public BoundCastData { - EnumBoundCastData(BoundCastInfo to_varchar_cast, BoundCastInfo from_varchar_cast) - : 
to_varchar_cast(std::move(to_varchar_cast)), from_varchar_cast(std::move(from_varchar_cast)) { - } - - BoundCastInfo to_varchar_cast; - BoundCastInfo from_varchar_cast; - -public: - unique_ptr Copy() const override { - return make_uniq(to_varchar_cast.Copy(), from_varchar_cast.Copy()); - } -}; - -unique_ptr BindEnumCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto to_varchar_cast = input.GetCastFunction(source, LogicalType::VARCHAR); - auto from_varchar_cast = input.GetCastFunction(LogicalType::VARCHAR, target); - return make_uniq(std::move(to_varchar_cast), std::move(from_varchar_cast)); -} - -struct EnumCastLocalState : public FunctionLocalState { -public: - unique_ptr to_varchar_local; - unique_ptr from_varchar_local; -}; - -static unique_ptr InitEnumCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - if (cast_data.from_varchar_cast.init_local_state) { - CastLocalStateParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data); - result->from_varchar_local = cast_data.from_varchar_cast.init_local_state(from_varchar_params); - } - if (cast_data.to_varchar_cast.init_local_state) { - CastLocalStateParameters from_varchar_params(parameters, cast_data.to_varchar_cast.cast_data); - result->from_varchar_local = cast_data.to_varchar_cast.init_local_state(from_varchar_params); - } - return std::move(result); -} - -static bool EnumToAnyCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - Vector varchar_cast(LogicalType::VARCHAR, count); - - // cast to varchar - CastParameters to_varchar_params(parameters, cast_data.to_varchar_cast.cast_data, lstate.to_varchar_local); - cast_data.to_varchar_cast.function(source, varchar_cast, count, to_varchar_params); - - // cast from varchar to the target - CastParameters from_varchar_params(parameters, cast_data.from_varchar_cast.cast_data, lstate.from_varchar_local); - cast_data.from_varchar_cast.function(varchar_cast, result, count, from_varchar_params); - return true; -} - -BoundCastInfo DefaultCasts::EnumCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto enum_physical_type = source.InternalType(); - switch (target.id()) { - case LogicalTypeId::ENUM: { - // This means they are both ENUMs, but of different types. 
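// NOTE (editorial sketch, not part of the original source): enum vectors store
// dictionary positions as UINT8/UINT16/UINT32 depending on dictionary size, so an
// ENUM -> ENUM cast dispatches twice: the switch below picks the source width, and
// EnumEnumCastSwitch picks the target width, instantiating an EnumEnumCast that
// maps each source position to its string and looks that string up in the target
// dictionary (invalidating the row, or raising an error, when the value is missing).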
- switch (enum_physical_type) { - case PhysicalType::UINT8: - return EnumEnumCastSwitch(input, source, target); - case PhysicalType::UINT16: - return EnumEnumCastSwitch(input, source, target); - case PhysicalType::UINT32: - return EnumEnumCastSwitch(input, source, target); - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - } - case LogicalTypeId::VARCHAR: - switch (enum_physical_type) { - case PhysicalType::UINT8: - return EnumToVarcharCast; - case PhysicalType::UINT16: - return EnumToVarcharCast; - case PhysicalType::UINT32: - return EnumToVarcharCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - default: { - return BoundCastInfo(EnumToAnyCast, BindEnumCast(input, source, target), InitEnumCastLocalState); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/list_casts.cpp b/src/duckdb/src/function/cast/list_casts.cpp deleted file mode 100644 index a2326f860..000000000 --- a/src/duckdb/src/function/cast/list_casts.cpp +++ /dev/null @@ -1,294 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" -#include "duckdb/common/operator/cast_operators.hpp" - -namespace duckdb { - -unique_ptr ListBoundCastData::BindListToListCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto &source_child_type = ListType::GetChildType(source); - auto &result_child_type = ListType::GetChildType(target); - auto child_cast = input.GetCastFunction(source_child_type, result_child_type); - return make_uniq(std::move(child_cast)); -} - -static unique_ptr BindListToArrayCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - auto &source_child_type = ListType::GetChildType(source); - auto &result_child_type = ArrayType::GetChildType(target); - auto child_cast = input.GetCastFunction(source_child_type, result_child_type); - return make_uniq(std::move(child_cast)); -} - -unique_ptr ListBoundCastData::InitListLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - if (!cast_data.child_cast_info.init_local_state) { - return nullptr; - } - CastLocalStateParameters child_parameters(parameters, cast_data.child_cast_info.cast_data); - return cast_data.child_cast_info.init_local_state(child_parameters); -} - -bool ListCast::ListToListCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - - // only handle constant and flat vectors here for now - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(source.GetVectorType()); - const bool is_null = ConstantVector::IsNull(source); - ConstantVector::SetNull(result, is_null); - - if (!is_null) { - auto ldata = ConstantVector::GetData(source); - auto tdata = ConstantVector::GetData(result); - *tdata = *ldata; - } - } else { - source.Flatten(count); - result.SetVectorType(VectorType::FLAT_VECTOR); - FlatVector::SetValidity(result, FlatVector::Validity(source)); - - auto ldata = FlatVector::GetData(source); - auto tdata = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - tdata[i] = ldata[i]; - } - } - auto &source_cc = ListVector::GetEntry(source); - auto source_size = ListVector::GetListSize(source); - - ListVector::Reserve(result, 
source_size); - auto &append_vector = ListVector::GetEntry(result); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_succeeded = cast_data.child_cast_info.function(source_cc, append_vector, source_size, child_parameters); - ListVector::SetListSize(result, source_size); - D_ASSERT(ListVector::GetListSize(result) == source_size); - return all_succeeded; -} - -static bool ListToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast the child vector to varchar - Vector varchar_list(LogicalType::LIST(LogicalType::VARCHAR), count); - ListCast::ListToListCast(source, varchar_list, count, parameters); - - // now construct the actual varchar vector - varchar_list.Flatten(count); - auto &child = ListVector::GetEntry(varchar_list); - auto list_data = FlatVector::GetData(varchar_list); - auto &validity = FlatVector::Validity(varchar_list); - - child.Flatten(ListVector::GetListSize(varchar_list)); - auto child_data = FlatVector::GetData(child); - auto &child_validity = FlatVector::Validity(child); - - auto result_data = FlatVector::GetData(result); - static constexpr const idx_t SEP_LENGTH = 2; - static constexpr const idx_t NULL_LENGTH = 4; - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - auto list = list_data[i]; - // figure out how long the result needs to be - idx_t list_length = 2; // "[" and "]" - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (list_idx > 0) { - list_length += SEP_LENGTH; // ", " - } - // string length, or "NULL" - list_length += child_validity.RowIsValid(idx) ? 
child_data[idx].GetSize() : NULL_LENGTH; - } - result_data[i] = StringVector::EmptyString(result, list_length); - auto dataptr = result_data[i].GetDataWriteable(); - idx_t offset = 0; - dataptr[offset++] = '['; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (list_idx > 0) { - memcpy(dataptr + offset, ", ", SEP_LENGTH); - offset += SEP_LENGTH; - } - if (child_validity.RowIsValid(idx)) { - auto len = child_data[idx].GetSize(); - memcpy(dataptr + offset, child_data[idx].GetData(), len); - offset += len; - } else { - memcpy(dataptr + offset, "NULL", NULL_LENGTH); - offset += NULL_LENGTH; - } - } - dataptr[offset] = ']'; - result_data[i].Finalize(); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -static bool ListToArrayCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) { - auto &cast_data = parameters.cast_data->Cast<ListBoundCastData>(); - auto array_size = ArrayType::GetSize(result.GetType()); - - // only handle constant and flat vectors here for now - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(source.GetVectorType()); - if (ConstantVector::IsNull(source)) { - ConstantVector::SetNull(result, true); - return true; - } - - auto ldata = ConstantVector::GetData<list_entry_t>(source)[0]; - if (!ConstantVector::IsNull(source) && ldata.length != array_size) { - // Can't cast to array, list size mismatch - auto msg = StringUtil::Format("Cannot cast list with length %llu to array with length %u", ldata.length, - array_size); - HandleCastError::AssignError(msg, parameters); - ConstantVector::SetNull(result, true); - return false; - } - - auto &source_cc = ListVector::GetEntry(source); - auto &result_cc = ArrayVector::GetEntry(result); - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - - if (ldata.offset == 0) { - // Fast path: offset is zero, we can just cast `array_size` elements of the child vectors directly - // Since the list was constant, there can only be one sequence of data in the child vector - return cast_data.child_cast_info.function(source_cc, result_cc, array_size, child_parameters); - } - - // Else, we need to copy the range we want to cast to a new vector and cast that - // In theory we could slice the source child to create a dictionary, but we would then have to flatten the - // result child which is going to allocate a temp vector and perform a copy anyway. Since we just want to copy a - // single contiguous range with a single offset, this is simpler. 
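// Worked example: a constant list entry {offset = 5, length = 3} cast to INTEGER[3]
// copies child rows 5..7 into a fresh 3-row vector below, so the child cast then runs
// on a zero-offset, contiguous range.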
- - Vector payload_vector(source_cc.GetType(), array_size); - VectorOperations::Copy(source_cc, payload_vector, ldata.offset + array_size, ldata.offset, 0); - return cast_data.child_cast_info.function(payload_vector, result_cc, array_size, child_parameters); - - } else { - source.Flatten(count); - result.SetVectorType(VectorType::FLAT_VECTOR); - - auto child_type = ArrayType::GetChildType(result.GetType()); - auto &source_cc = ListVector::GetEntry(source); - auto &result_cc = ArrayVector::GetEntry(result); - auto ldata = FlatVector::GetData<list_entry_t>(source); - - auto child_count = array_size * count; - SelectionVector child_sel(child_count); - - bool all_ok = true; - - for (idx_t i = 0; i < count; i++) { - if (FlatVector::IsNull(source, i)) { - FlatVector::SetNull(result, i, true); - for (idx_t array_elem = 0; array_elem < array_size; array_elem++) { - FlatVector::SetNull(result_cc, i * array_size + array_elem, true); - child_sel.set_index(i * array_size + array_elem, 0); - } - } else if (ldata[i].length != array_size) { - if (all_ok) { - all_ok = false; - auto msg = StringUtil::Format("Cannot cast list with length %llu to array with length %u", - ldata[i].length, array_size); - HandleCastError::AssignError(msg, parameters); - } - FlatVector::SetNull(result, i, true); - for (idx_t array_elem = 0; array_elem < array_size; array_elem++) { - FlatVector::SetNull(result_cc, i * array_size + array_elem, true); - child_sel.set_index(i * array_size + array_elem, 0); - } - } else { - for (idx_t array_elem = 0; array_elem < array_size; array_elem++) { - child_sel.set_index(i * array_size + array_elem, ldata[i].offset + array_elem); - } - } - } - - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - - // Fast path: No lists are null - // We can just cast the child vector directly - // Note: It's worth doing a CheckAllValid here, the slow path is significantly more expensive - if (FlatVector::Validity(result).CheckAllValid(count)) { - Vector payload_vector(result_cc.GetType(), child_count); - - bool ok = cast_data.child_cast_info.function(source_cc, payload_vector, child_count, child_parameters); - if (all_ok && !ok) { - all_ok = false; - HandleCastError::AssignError(*child_parameters.error_message, parameters); - } - // Now do the actual copy onto the result vector, making sure to slice properly in case the lists are out of - // order - VectorOperations::Copy(payload_vector, result_cc, child_sel, child_count, 0, 0); - return all_ok; - } - - // Slow path: Some lists are null, so we need to copy the data list by list to the right place - auto list_data = FlatVector::GetData<list_entry_t>(source); - DataChunk cast_chunk; - cast_chunk.Initialize(Allocator::DefaultAllocator(), {source_cc.GetType(), result_cc.GetType()}, array_size); - - for (idx_t i = 0; i < count; i++) { - if (FlatVector::IsNull(result, i)) { - // We've already failed to cast this list above (e.g. length mismatch), so there's nothing to do here. 
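// (The validation pass above already nulled this row and all of its array_size child
// slots, so the copy can simply be skipped.)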
- continue; - } else { - auto &list_cast_input = cast_chunk.data[0]; - auto &list_cast_output = cast_chunk.data[1]; - auto list_entry = list_data[i]; - - VectorOperations::Copy(source_cc, list_cast_input, list_entry.offset + array_size, list_entry.offset, - 0); - - bool ok = - cast_data.child_cast_info.function(list_cast_input, list_cast_output, array_size, child_parameters); - if (all_ok && !ok) { - all_ok = false; - HandleCastError::AssignError(*child_parameters.error_message, parameters); - } - VectorOperations::Copy(list_cast_output, result_cc, array_size, 0, i * array_size); - - // Reset the cast_chunk - cast_chunk.Reset(); - } - } - - return all_ok; - } -} - -BoundCastInfo DefaultCasts::ListCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::LIST: - return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::VARCHAR: - return BoundCastInfo( - ListToVarcharCast, - ListBoundCastData::BindListToListCast(input, source, LogicalType::LIST(LogicalType::VARCHAR)), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::ARRAY: - return BoundCastInfo(ListToArrayCast, BindListToArrayCast(input, source, target), - ListBoundCastData::InitListLocalState); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/map_cast.cpp b/src/duckdb/src/function/cast/map_cast.cpp deleted file mode 100644 index 20c14efa0..000000000 --- a/src/duckdb/src/function/cast/map_cast.cpp +++ /dev/null @@ -1,94 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" - -namespace duckdb { - -unique_ptr<BoundCastData> MapBoundCastData::BindMapToMapCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector<BoundCastInfo> child_cast_info; - auto source_key = MapType::KeyType(source); - auto target_key = MapType::KeyType(target); - auto source_val = MapType::ValueType(source); - auto target_val = MapType::ValueType(target); - auto key_cast = input.GetCastFunction(source_key, target_key); - auto value_cast = input.GetCastFunction(source_val, target_val); - return make_uniq<MapBoundCastData>(std::move(key_cast), std::move(value_cast)); -} - -static bool MapToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - Vector varchar_map(varchar_type, count); - - // since map's physical type is a list, the ListCast can be utilized - ListCast::ListToListCast(source, varchar_map, count, parameters); - - varchar_map.Flatten(count); - auto &validity = FlatVector::Validity(varchar_map); - auto &key_str = MapVector::GetKeys(varchar_map); - auto &val_str = MapVector::GetValues(varchar_map); - - key_str.Flatten(ListVector::GetListSize(source)); - val_str.Flatten(ListVector::GetListSize(source)); - - auto list_data = ListVector::GetData(varchar_map); - auto key_data = FlatVector::GetData<string_t>(key_str); - auto val_data = FlatVector::GetData<string_t>(val_str); - auto &key_validity = FlatVector::Validity(key_str); - auto &val_validity = FlatVector::Validity(val_str); - auto &struct_validity = FlatVector::Validity(ListVector::GetEntry(varchar_map)); - - auto result_data = 
FlatVector::GetData<string_t>(result); - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - auto list = list_data[i]; - string ret = "{"; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - if (list_idx > 0) { - ret += ", "; - } - auto idx = list.offset + list_idx; - - if (!struct_validity.RowIsValid(idx)) { - ret += "NULL"; - continue; - } - if (!key_validity.RowIsValid(idx)) { - // throw InternalException("Error in map: key validity invalid?!"); - ret += "invalid"; - continue; - } - ret += key_data[idx].GetString(); - ret += "="; - ret += val_validity.RowIsValid(idx) ? val_data[idx].GetString() : "NULL"; - } - ret += "}"; - result_data[i] = StringVector::AddString(result, ret); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -BoundCastInfo DefaultCasts::MapCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::MAP: - return BoundCastInfo(ListCast::ListToListCast, ListBoundCastData::BindListToListCast(input, source, target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::VARCHAR: { - auto varchar_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - return BoundCastInfo(MapToVarcharCast, ListBoundCastData::BindListToListCast(input, source, varchar_type), - ListBoundCastData::InitListLocalState); - } - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/numeric_casts.cpp b/src/duckdb/src/function/cast/numeric_casts.cpp deleted file mode 100644 index bdb999ffc..000000000 --- a/src/duckdb/src/function/cast/numeric_casts.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/common/operator/string_cast.hpp" -#include "duckdb/common/operator/numeric_cast.hpp" -#include "duckdb/common/types/varint.hpp" - -namespace duckdb { - -template <class SRC> -static BoundCastInfo InternalNumericCastSwitch(const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::BOOLEAN: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, bool, duckdb::NumericTryCast>); - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, int8_t, duckdb::NumericTryCast>); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, int16_t, duckdb::NumericTryCast>); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, int32_t, duckdb::NumericTryCast>); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, int64_t, duckdb::NumericTryCast>); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, uint8_t, duckdb::NumericTryCast>); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, uint16_t, duckdb::NumericTryCast>); - case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, uint32_t, duckdb::NumericTryCast>); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, uint64_t, duckdb::NumericTryCast>); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, hugeint_t, duckdb::NumericTryCast>); - case LogicalTypeId::UHUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, uhugeint_t, duckdb::NumericTryCast>); - case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, float, duckdb::NumericTryCast>); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop<SRC, double, duckdb::NumericTryCast>); - case LogicalTypeId::DECIMAL: - return BoundCastInfo(&VectorCastHelpers::ToDecimalCast<SRC>); - case LogicalTypeId::VARCHAR: - return 
BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::BIT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::VARINT: - return Varint::NumericToVarintCastSwitch(source); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::NumericCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (source.id()) { - case LogicalTypeId::BOOLEAN: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::TINYINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::SMALLINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::INTEGER: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::BIGINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UTINYINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::USMALLINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UINTEGER: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UBIGINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::HUGEINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::UHUGEINT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::FLOAT: - return InternalNumericCastSwitch(source, target); - case LogicalTypeId::DOUBLE: - return InternalNumericCastSwitch(source, target); - default: - throw InternalException("NumericCastSwitch called with non-numeric argument"); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/pointer_cast.cpp b/src/duckdb/src/function/cast/pointer_cast.cpp deleted file mode 100644 index 33eead299..000000000 --- a/src/duckdb/src/function/cast/pointer_cast.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -namespace duckdb { - -BoundCastInfo DefaultCasts::PointerCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // pointer to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return nullptr; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/string_cast.cpp b/src/duckdb/src/function/cast/string_cast.cpp deleted file mode 100644 index f3a19c427..000000000 --- a/src/duckdb/src/function/cast/string_cast.cpp +++ /dev/null @@ -1,526 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/vector.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" -#include "duckdb/common/types/varint.hpp" - -namespace duckdb { - -template -bool StringEnumCastLoop(const string_t *source_data, ValidityMask &source_mask, const LogicalType &source_type, - T *result_data, ValidityMask &result_mask, const LogicalType &result_type, idx_t count, - VectorTryCastData &vector_cast_data, const SelectionVector *sel) { - for (idx_t i = 0; i < count; i++) { - idx_t source_idx = i; - if (sel) { - source_idx = sel->get_index(i); - } - if (source_mask.RowIsValid(source_idx)) { - auto pos = EnumType::GetPos(result_type, 
source_data[source_idx]); - if (pos == -1) { - result_data[i] = HandleVectorCastError::Operation( - CastExceptionText(source_data[source_idx]), result_mask, i, vector_cast_data); - } else { - result_data[i] = UnsafeNumericCast(pos); - } - } else { - result_mask.SetInvalid(i); - } - } - return vector_cast_data.all_converted; -} - -template -bool StringEnumCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); - switch (source.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - auto source_data = ConstantVector::GetData(source); - auto source_mask = ConstantVector::Validity(source); - auto result_data = ConstantVector::GetData(result); - auto &result_mask = ConstantVector::Validity(result); - - VectorTryCastData vector_cast_data(result, parameters); - return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, - result.GetType(), 1, vector_cast_data, nullptr); - } - default: { - UnifiedVectorFormat vdata; - source.ToUnifiedFormat(count, vdata); - - result.SetVectorType(VectorType::FLAT_VECTOR); - - auto source_data = UnifiedVectorFormat::GetData(vdata); - auto source_sel = vdata.sel; - auto source_mask = vdata.validity; - auto result_data = FlatVector::GetData(result); - auto &result_mask = FlatVector::Validity(result); - - VectorTryCastData vector_cast_data(result, parameters); - return StringEnumCastLoop(source_data, source_mask, source.GetType(), result_data, result_mask, - result.GetType(), count, vector_cast_data, source_sel); - } - } -} - -static BoundCastInfo VectorStringCastNumericSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::ENUM: { - switch (target.InternalType()) { - case PhysicalType::UINT8: - return StringEnumCast; - case PhysicalType::UINT16: - return StringEnumCast; - case PhysicalType::UINT32: - return StringEnumCast; - default: - throw InternalException("ENUM can only have unsigned integers (except UINT64) as physical types"); - } - } - case LogicalTypeId::BOOLEAN: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::UHUGEINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::INTERVAL: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::DECIMAL: - return 
BoundCastInfo(&VectorCastHelpers::ToDecimalCast); - default: - return DefaultCasts::TryVectorNullCast; - } -} - -//===--------------------------------------------------------------------===// -// string -> list casting -//===--------------------------------------------------------------------===// -bool VectorStringToList::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - idx_t total_list_size = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - continue; - } - total_list_size += VectorStringToList::CountPartsList(source_data[idx]); - } - - Vector varchar_vector(LogicalType::VARCHAR, total_list_size); - - ListVector::Reserve(result, total_list_size); - ListVector::SetListSize(result, total_list_size); - - auto list_data = ListVector::GetData(result); - auto child_data = FlatVector::GetData(varchar_vector); - - VectorTryCastData vector_cast_data(result, parameters); - idx_t total = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(i); - continue; - } - - list_data[i].offset = total; - if (!VectorStringToList::SplitStringList(source_data[idx], child_data, total, varchar_vector)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type LIST"; - HandleVectorCastError::Operation(text, result_mask, i, vector_cast_data); - } - list_data[i].length = total - list_data[i].offset; // length is the amount of parts coming from this string - } - D_ASSERT(total_list_size == total); - - auto &result_child = ListVector::GetEntry(result); - auto &cast_data = parameters.cast_data->Cast(); - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool all_converted = - cast_data.child_cast_info.function(varchar_vector, result_child, total_list_size, child_parameters) && - vector_cast_data.all_converted; - if (!all_converted && parameters.nullify_parent) { - UnifiedVectorFormat inserted_column_data; - result_child.ToUnifiedFormat(total_list_size, inserted_column_data); - UnifiedVectorFormat parse_column_data; - varchar_vector.ToUnifiedFormat(total_list_size, parse_column_data); - // Something went wrong in the conversion, we need to nullify the parent - for (idx_t i = 0; i < count; i++) { - for (idx_t j = list_data[i].offset; j < list_data[i].offset + list_data[i].length; j++) { - if (!inserted_column_data.validity.RowIsValid(j) && parse_column_data.validity.RowIsValid(j)) { - result_mask.SetInvalid(i); - break; - } - } - } - } - return all_converted; -} - -static LogicalType InitVarcharStructType(const LogicalType &target) { - child_list_t child_types; - for (auto &child : StructType::GetChildTypes(target)) { - child_types.push_back(make_pair(child.first, LogicalType::VARCHAR)); - } - - return LogicalType::STRUCT(child_types); -} - -//===--------------------------------------------------------------------===// -// string -> struct casting -//===--------------------------------------------------------------------===// -bool VectorStringToStruct::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - auto 
varchar_struct_type = InitVarcharStructType(result.GetType()); - Vector varchar_vector(varchar_struct_type, count); - auto &child_vectors = StructVector::GetEntries(varchar_vector); - auto &result_children = StructVector::GetEntries(result); - auto is_unnamed = StructType::IsUnnamed(result.GetType()); - - string_map_t child_names; - vector> child_masks; - for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { - if (!is_unnamed) { - child_names.insert({StructType::GetChildName(result.GetType(), child_idx), child_idx}); - } - child_masks.emplace_back(FlatVector::Validity(*child_vectors[child_idx])); - child_masks[child_idx].get().SetAllInvalid(count); - } - - VectorTryCastData vector_cast_data(result, parameters); - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(i); - continue; - } - if (is_unnamed) { - throw ConversionException("Casting strings to unnamed structs is unsupported"); - } - if (!VectorStringToStruct::SplitStruct(source_data[idx], child_vectors, i, child_names, child_masks)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type STRUCT"; - for (auto &child_mask : child_masks) { - child_mask.get().SetInvalid(i); // some values may have already been found and set valid - } - HandleVectorCastError::Operation(text, result_mask, i, vector_cast_data); - } - } - - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - D_ASSERT(cast_data.child_cast_info.size() == result_children.size()); - - for (idx_t child_idx = 0; child_idx < result_children.size(); child_idx++) { - auto &child_varchar_vector = *child_vectors[child_idx]; - auto &result_child_vector = *result_children[child_idx]; - auto &child_cast_info = cast_data.child_cast_info[child_idx]; - CastParameters child_parameters(parameters, child_cast_info.cast_data, lstate.local_states[child_idx]); - if (!child_cast_info.function(child_varchar_vector, result_child_vector, count, child_parameters)) { - vector_cast_data.all_converted = false; - } - } - return vector_cast_data.all_converted; -} - -//===--------------------------------------------------------------------===// -// string -> map casting -//===--------------------------------------------------------------------===// -unique_ptr InitMapCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - if (cast_data.key_cast.init_local_state) { - CastLocalStateParameters child_params(parameters, cast_data.key_cast.cast_data); - result->key_state = cast_data.key_cast.init_local_state(child_params); - } - if (cast_data.value_cast.init_local_state) { - CastLocalStateParameters child_params(parameters, cast_data.value_cast.cast_data); - result->value_state = cast_data.value_cast.init_local_state(child_params); - } - return std::move(result); -} - -bool VectorStringToMap::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - idx_t total_elements = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - continue; - } - total_elements += (VectorStringToMap::CountPartsMap(source_data[idx]) + 1) / 2; - } - - Vector 
varchar_key_vector(LogicalType::VARCHAR, total_elements); - Vector varchar_val_vector(LogicalType::VARCHAR, total_elements); - auto child_key_data = FlatVector::GetData(varchar_key_vector); - auto child_val_data = FlatVector::GetData(varchar_val_vector); - - ListVector::Reserve(result, total_elements); - ListVector::SetListSize(result, total_elements); - auto list_data = ListVector::GetData(result); - - VectorTryCastData vector_cast_data(result, parameters); - idx_t total = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - result_mask.SetInvalid(i); - continue; - } - - list_data[i].offset = total; - if (!VectorStringToMap::SplitStringMap(source_data[idx], child_key_data, child_val_data, total, - varchar_key_vector, varchar_val_vector)) { - string text = "Type VARCHAR with value '" + source_data[idx].GetString() + - "' can't be cast to the destination type MAP"; - FlatVector::SetNull(result, i, true); - HandleVectorCastError::Operation(text, result_mask, i, vector_cast_data); - } - list_data[i].length = total - list_data[i].offset; - } - D_ASSERT(total_elements == total); - - auto &result_key_child = MapVector::GetKeys(result); - auto &result_val_child = MapVector::GetValues(result); - auto &cast_data = parameters.cast_data->Cast(); - auto &lstate = parameters.local_state->Cast(); - - CastParameters key_params(parameters, cast_data.key_cast.cast_data, lstate.key_state); - if (!cast_data.key_cast.function(varchar_key_vector, result_key_child, total_elements, key_params)) { - vector_cast_data.all_converted = false; - } - CastParameters val_params(parameters, cast_data.value_cast.cast_data, lstate.value_state); - if (!cast_data.value_cast.function(varchar_val_vector, result_val_child, total_elements, val_params)) { - vector_cast_data.all_converted = false; - } - - if (!vector_cast_data.all_converted) { - auto &key_validity = FlatVector::Validity(result_key_child); - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - if (!result_mask.RowIsValid(row_idx)) { - continue; - } - auto list = list_data[row_idx]; - for (idx_t list_idx = 0; list_idx < list.length; list_idx++) { - auto idx = list.offset + list_idx; - if (!key_validity.RowIsValid(idx)) { - result_mask.SetInvalid(row_idx); - } - } - } - } - MapVector::MapConversionVerify(result, count); - return vector_cast_data.all_converted; -} - -//===--------------------------------------------------------------------===// -// string -> array casting -//===--------------------------------------------------------------------===// -bool VectorStringToArray::StringToNestedTypeCastLoop(const string_t *source_data, ValidityMask &source_mask, - Vector &result, ValidityMask &result_mask, idx_t count, - CastParameters ¶meters, const SelectionVector *sel) { - idx_t array_size = ArrayType::GetSize(result.GetType()); - bool all_lengths_match = true; - - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - if (!source_mask.RowIsValid(idx)) { - continue; - } - auto str_array_size = VectorStringToList::CountPartsList(source_data[idx]); - if (array_size != str_array_size) { - if (all_lengths_match) { - all_lengths_match = false; - auto msg = - StringUtil::Format("Type VARCHAR with value '%s' can't be cast to the destination type ARRAY[%u]" - ", the size of the array must match the destination type", - source_data[idx].GetString(), array_size); - if (parameters.strict) { - throw ConversionException(msg); - } - 
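// Non-strict mode: the message is handed to HandleCastError::AssignError below, which
// raises for a plain CAST but records the message when a TRY_CAST supplies an error sink;
// either way the offending row is then invalidated via result_mask.SetInvalid.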
HandleCastError::AssignError(msg, parameters); - } - result_mask.SetInvalid(i); - } - } - - auto child_count = array_size * count; - Vector varchar_vector(LogicalType::VARCHAR, child_count); - auto child_data = FlatVector::GetData<string_t>(varchar_vector); - - VectorTryCastData vector_cast_data(result, parameters); - idx_t total = 0; - for (idx_t i = 0; i < count; i++) { - idx_t idx = i; - if (sel) { - idx = sel->get_index(i); - } - - if (!source_mask.RowIsValid(idx) || !result_mask.RowIsValid(i)) { - // The source is null, or there was a size-mismatch above, so don't try to split the string - result_mask.SetInvalid(i); - - // Null the entire array - for (idx_t j = 0; j < array_size; j++) { - FlatVector::SetNull(varchar_vector, i * array_size + j, true); - } - - total += array_size; - continue; - } - - if (!VectorStringToList::SplitStringList(source_data[idx], child_data, total, varchar_vector)) { - auto text = StringUtil::Format("Type VARCHAR with value '%s' can't be cast to the destination type ARRAY", - source_data[idx].GetString()); - HandleVectorCastError::Operation(text, result_mask, i, vector_cast_data); - } - } - D_ASSERT(total == child_count); - - auto &result_child = ArrayVector::GetEntry(result); - auto &cast_data = parameters.cast_data->Cast<ArrayBoundCastData>(); - CastParameters child_parameters(parameters, cast_data.child_cast_info.cast_data, parameters.local_state); - bool cast_result = cast_data.child_cast_info.function(varchar_vector, result_child, child_count, child_parameters); - - return all_lengths_match && cast_result && vector_cast_data.all_converted; -} - -template <class T> -bool StringToNestedTypeCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) { - D_ASSERT(source.GetType().id() == LogicalTypeId::VARCHAR); - - switch (source.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - auto source_data = ConstantVector::GetData<string_t>(source); - auto &source_mask = ConstantVector::Validity(source); - auto &result_mask = FlatVector::Validity(result); - auto ret = T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, 1, parameters, nullptr); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - return ret; - } - default: { - UnifiedVectorFormat unified_source; - - source.ToUnifiedFormat(count, unified_source); - auto source_sel = unified_source.sel; - auto source_data = UnifiedVectorFormat::GetData<string_t>(unified_source); - auto &source_mask = unified_source.validity; - auto &result_mask = FlatVector::Validity(result); - - return T::StringToNestedTypeCastLoop(source_data, source_mask, result, result_mask, count, parameters, - source_sel); - } - } -} - -BoundCastInfo DefaultCasts::StringCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::DATE: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIME: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIME_TZ: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - return BoundCastInfo(&VectorCastHelpers::TryCastErrorLoop); - case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo( - &VectorCastHelpers::TryCastStrictLoop); - case LogicalTypeId::BLOB: - return 
BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::BIT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::UUID: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::SQLNULL: - return &DefaultCasts::TryVectorNullCast; - case LogicalTypeId::VARCHAR: - return &DefaultCasts::ReinterpretCast; - case LogicalTypeId::LIST: - // the second argument allows for a secondary casting function to be passed in the CastParameters - return BoundCastInfo( - &StringToNestedTypeCast<VectorStringToList>, - ListBoundCastData::BindListToListCast(input, LogicalType::LIST(LogicalType::VARCHAR), target), - ListBoundCastData::InitListLocalState); - case LogicalTypeId::ARRAY: - // the second argument allows for a secondary casting function to be passed in the CastParameters - return BoundCastInfo(&StringToNestedTypeCast<VectorStringToArray>, - ArrayBoundCastData::BindArrayToArrayCast( - input, LogicalType::ARRAY(LogicalType::VARCHAR, optional_idx()), target), - ArrayBoundCastData::InitArrayLocalState); - case LogicalTypeId::STRUCT: - return BoundCastInfo(&StringToNestedTypeCast<VectorStringToStruct>, - StructBoundCastData::BindStructToStructCast(input, InitVarcharStructType(target), target), - StructBoundCastData::InitStructCastLocalState); - case LogicalTypeId::MAP: - return BoundCastInfo(&StringToNestedTypeCast<VectorStringToMap>, - MapBoundCastData::BindMapToMapCast( - input, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR), target), - InitMapCastLocalState); - case LogicalTypeId::VARINT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - default: - return VectorStringCastNumericSwitch(input, source, target); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/struct_cast.cpp b/src/duckdb/src/function/cast/struct_cast.cpp deleted file mode 100644 index 051581c54..000000000 --- a/src/duckdb/src/function/cast/struct_cast.cpp +++ /dev/null @@ -1,235 +0,0 @@ -#include "duckdb/common/exception/binder_exception.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" - -namespace duckdb { - -unique_ptr<BoundCastData> StructBoundCastData::BindStructToStructCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector<BoundCastInfo> child_cast_info; - auto &source_children = StructType::GetChildTypes(source); - auto &target_children = StructType::GetChildTypes(target); - - auto target_is_unnamed = StructType::IsUnnamed(target); - auto source_is_unnamed = StructType::IsUnnamed(source); - - auto is_unnamed = target_is_unnamed || source_is_unnamed; - if (is_unnamed && source_children.size() != target_children.size()) { - throw TypeMismatchException(input.query_location, source, target, "Cannot cast STRUCTs of different size"); - } - - case_insensitive_map_t<idx_t> target_children_map; - if (!is_unnamed) { - for (idx_t i = 0; i < target_children.size(); i++) { - auto &name = target_children[i].first; - if (target_children_map.find(name) != target_children_map.end()) { - throw NotImplementedException("Error while casting - duplicate name \"%s\" in struct", name); - } - target_children_map[name] = i; - } - } - - vector<idx_t> source_indexes; - vector<idx_t> target_indexes; - vector<idx_t> target_null_indexes; - bool has_any_match = is_unnamed; - - for (idx_t i = 0; i < source_children.size(); i++) { - auto &source_child = source_children[i]; - auto target_idx = i; - - // Map to the correct index for named structs. 
- if (!is_unnamed) { - auto target_child = target_children_map.find(source_child.first); - if (target_child == target_children_map.end()) { - // Skip any children that have no target. - continue; - } - target_idx = target_child->second; - target_children_map.erase(target_child); - has_any_match = true; - } - - source_indexes.push_back(i); - target_indexes.push_back(target_idx); - auto child_cast = input.GetCastFunction(source_child.second, target_children[target_idx].second); - child_cast_info.push_back(std::move(child_cast)); - } - - if (!has_any_match) { - throw BinderException("STRUCT to STRUCT cast must have at least one matching member"); - } - - // The remaining target children have no match in the source struct. - // Thus, they become NULL. - for (const auto &target_child : target_children_map) { - target_null_indexes.push_back(target_child.second); - } - - return make_uniq(std::move(child_cast_info), target, std::move(source_indexes), - std::move(target_indexes), std::move(target_null_indexes)); -} - -unique_ptr StructBoundCastData::InitStructCastLocalState(CastLocalStateParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto result = make_uniq(); - - for (auto &entry : cast_data.child_cast_info) { - unique_ptr child_state; - if (entry.init_local_state) { - CastLocalStateParameters child_params(parameters, entry.cast_data); - child_state = entry.init_local_state(child_params); - } - result->local_states.push_back(std::move(child_state)); - } - return std::move(result); -} - -static bool StructToStructCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto &cast_data = parameters.cast_data->Cast(); - auto &l_state = parameters.local_state->Cast(); - - auto &source_vectors = StructVector::GetEntries(source); - auto &target_children = StructVector::GetEntries(result); - - bool all_converted = true; - for (idx_t i = 0; i < cast_data.source_indexes.size(); i++) { - auto source_idx = cast_data.source_indexes[i]; - auto target_idx = cast_data.target_indexes[i]; - - auto &source_vector = *source_vectors[source_idx]; - auto &target_vector = *target_children[target_idx]; - - auto &child_cast_info = cast_data.child_cast_info[i]; - CastParameters child_parameters(parameters, child_cast_info.cast_data, l_state.local_states[i]); - auto success = child_cast_info.function(source_vector, target_vector, count, child_parameters); - if (!success) { - all_converted = false; - } - } - - if (!cast_data.target_null_indexes.empty()) { - for (idx_t i = 0; i < cast_data.target_null_indexes.size(); i++) { - auto target_idx = cast_data.target_null_indexes[i]; - auto &target_vector = *target_children[target_idx]; - - target_vector.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(target_vector, true); - } - } - - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - return all_converted; - } - - source.Flatten(count); - auto &result_validity = FlatVector::Validity(result); - result_validity = FlatVector::Validity(source); - result.Verify(count); - return all_converted; -} - -static bool StructToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast all child elements to varchar - auto &cast_data = parameters.cast_data->Cast(); - Vector varchar_struct(cast_data.target, count); - 
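// cast_data.target here is the all-VARCHAR mirror of the source struct built in
// StructCastSwitch below, so the regular STRUCT -> STRUCT machinery does the per-child
// string conversion, e.g. STRUCT(a INTEGER, b VARCHAR) is first cast to
// STRUCT(a VARCHAR, b VARCHAR).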
StructToStructCast(source, varchar_struct, count, parameters); - - // now construct the actual varchar vector - varchar_struct.Flatten(count); - bool is_unnamed = StructType::IsUnnamed(source.GetType()); - auto &child_types = StructType::GetChildTypes(source.GetType()); - auto &children = StructVector::GetEntries(varchar_struct); - auto &validity = FlatVector::Validity(varchar_struct); - auto result_data = FlatVector::GetData(result); - static constexpr const idx_t SEP_LENGTH = 2; - static constexpr const idx_t NAME_SEP_LENGTH = 4; - static constexpr const idx_t NULL_LENGTH = 4; - for (idx_t i = 0; i < count; i++) { - if (!validity.RowIsValid(i)) { - FlatVector::SetNull(result, i, true); - continue; - } - idx_t string_length = 2; // {} - for (idx_t c = 0; c < children.size(); c++) { - if (c > 0) { - string_length += SEP_LENGTH; - } - children[c]->Flatten(count); - auto &child_validity = FlatVector::Validity(*children[c]); - auto data = FlatVector::GetData(*children[c]); - auto &name = child_types[c].first; - if (!is_unnamed) { - string_length += name.size() + NAME_SEP_LENGTH; // "'{name}': " - } - string_length += child_validity.RowIsValid(i) ? data[i].GetSize() : NULL_LENGTH; - } - result_data[i] = StringVector::EmptyString(result, string_length); - auto dataptr = result_data[i].GetDataWriteable(); - idx_t offset = 0; - dataptr[offset++] = is_unnamed ? '(' : '{'; - for (idx_t c = 0; c < children.size(); c++) { - if (c > 0) { - memcpy(dataptr + offset, ", ", SEP_LENGTH); - offset += SEP_LENGTH; - } - auto &child_validity = FlatVector::Validity(*children[c]); - auto data = FlatVector::GetData(*children[c]); - if (!is_unnamed) { - auto &name = child_types[c].first; - // "'{name}': " - dataptr[offset++] = '\''; - memcpy(dataptr + offset, name.c_str(), name.size()); - offset += name.size(); - dataptr[offset++] = '\''; - dataptr[offset++] = ':'; - dataptr[offset++] = ' '; - } - // value - if (child_validity.RowIsValid(i)) { - auto len = data[i].GetSize(); - memcpy(dataptr + offset, data[i].GetData(), len); - offset += len; - } else { - memcpy(dataptr + offset, "NULL", NULL_LENGTH); - offset += NULL_LENGTH; - } - } - dataptr[offset++] = is_unnamed ? 
')' : '}'; - result_data[i].Finalize(); - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - return true; -} - -BoundCastInfo DefaultCasts::StructCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - switch (target.id()) { - case LogicalTypeId::STRUCT: - return BoundCastInfo(StructToStructCast, StructBoundCastData::BindStructToStructCast(input, source, target), - StructBoundCastData::InitStructCastLocalState); - case LogicalTypeId::VARCHAR: { - // bind a cast in which we convert all child entries to VARCHAR entries - auto &struct_children = StructType::GetChildTypes(source); - child_list_t varchar_children; - for (auto &child_entry : struct_children) { - varchar_children.push_back(make_pair(child_entry.first, LogicalType::VARCHAR)); - } - auto varchar_type = LogicalType::STRUCT(varchar_children); - return BoundCastInfo(StructToVarcharCast, - StructBoundCastData::BindStructToStructCast(input, source, varchar_type), - StructBoundCastData::InitStructCastLocalState); - } - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/time_casts.cpp b/src/duckdb/src/function/cast/time_casts.cpp deleted file mode 100644 index b9587ad1c..000000000 --- a/src/duckdb/src/function/cast/time_casts.cpp +++ /dev/null @@ -1,215 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/common/operator/string_cast.hpp" -namespace duckdb { - -BoundCastInfo DefaultCasts::DateCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // date to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - // date to timestamp - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_SEC: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - case LogicalTypeId::TIMESTAMP_MS: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimeCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // time to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME_TZ: - // time to time with time zone - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimeTzCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // time with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME: - // time with time zone to time - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); 
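// The precision-changing cases below are integer rescalings of the epoch value, e.g.
// 1700000000000000 us -> 1700000000000 ms -> 1700000000 s, while TIMESTAMP <-> TIMESTAMP_TZ
// is a bit-identical reinterpret since both store microseconds since the epoch.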
- case LogicalTypeId::DATE: - // timestamp to date - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME: - // timestamp to time - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME_TZ: - // timestamp to time_tz - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_TZ: - // timestamp (us) to timestamp with time zone - return ReinterpretCast; - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (us) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_MS: - // timestamp (us) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_SEC: - // timestamp (us) to timestamp (s) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampTzCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp with time zone to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::TIME_TZ: - // timestamp with time zone to time with time zone. - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp with time zone to timestamp (us) - return ReinterpretCast; - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampNsCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (ns) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DATE: - // timestamp (ns) to date - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME: - // timestamp (ns) to time - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp (ns) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_TZ: - // timestamp (ns) to timestamp with time zone (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampMsCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (ms) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DATE: - // timestamp (ms) to date - return BoundCastInfo(&VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME: - // timestamp (ms) to time - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp (ms) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (ms) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_TZ: - // timestamp (ms) to timestamp with timezone (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::TimestampSecCastSwitch(BindCastInput &input, const 
LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // timestamp (s) to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DATE: - // timestamp (s) to date - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIME: - // timestamp (s) to time - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_MS: - // timestamp (s) to timestamp (ms) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP: - // timestamp (s) to timestamp (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_TZ: - // timestamp (s) to timestamp with timezone (us) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - case LogicalTypeId::TIMESTAMP_NS: - // timestamp (s) to timestamp (ns) - return BoundCastInfo( - &VectorCastHelpers::TemplatedCastLoop); - default: - return TryVectorNullCast; - } -} -BoundCastInfo DefaultCasts::IntervalCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // interval to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/union/from_struct.cpp b/src/duckdb/src/function/cast/union/from_struct.cpp deleted file mode 100644 index c6fcd2a03..000000000 --- a/src/duckdb/src/function/cast/union/from_struct.cpp +++ /dev/null @@ -1,134 +0,0 @@ -#include "duckdb/function/cast/bound_cast_data.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" - -namespace duckdb { - -bool StructToUnionCast::AllowImplicitCastFromStruct(const LogicalType &source, const LogicalType &target) { - if (source.id() != LogicalTypeId::STRUCT) { - return false; - } - auto target_fields = StructType::GetChildTypes(target); - auto fields = StructType::GetChildTypes(source); - if (target_fields.size() != fields.size()) { - // Struct should have the same amount of fields as the union - return false; - } - for (idx_t i = 0; i < target_fields.size(); i++) { - auto &target_field = target_fields[i].second; - auto &target_field_name = target_fields[i].first; - auto &field = fields[i].second; - auto &field_name = fields[i].first; - if (i == 0) { - // For the tag field we don't accept a type substitute as varchar - if (target_field != field) { - return false; - } - continue; - } - if (!StringUtil::CIEquals(target_field_name, field_name)) { - return false; - } - if (target_field != field && field != LogicalType::VARCHAR) { - // We allow the field to be VARCHAR, since unsupported types get cast to VARCHAR by EXPORT DATABASE (format - // PARQUET) i.e. UNION(a BIT) becomes STRUCT(a VARCHAR) - return false; - } - } - return true; -} - -// Physical Cast execution - -bool StructToUnionCast::Cast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) { - auto &cast_data = parameters.cast_data->Cast<StructBoundCastData>(); - auto &lstate = parameters.local_state->Cast<StructCastLocalState>(); - - D_ASSERT(source.GetType().id() == LogicalTypeId::STRUCT); - D_ASSERT(result.GetType().id() == LogicalTypeId::UNION); - D_ASSERT(cast_data.target.id() == LogicalTypeId::UNION); - - auto &source_children = StructVector::GetEntries(source); - auto &target_children = StructVector::GetEntries(result); - - for (idx_t i = 0; i < 
source_children.size(); i++) { - auto &result_child_vector = *target_children[i]; - auto &source_child_vector = *source_children[i]; - CastParameters child_parameters(parameters, cast_data.child_cast_info[i].cast_data, lstate.local_states[i]); - auto converted = - cast_data.child_cast_info[i].function(source_child_vector, result_child_vector, count, child_parameters); - (void)converted; - D_ASSERT(converted); - // we flatten the child because we use FlatVector::SetNull below and we may get non-flat from source/cast - result_child_vector.Flatten(count); - } - - if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, ConstantVector::IsNull(source)); - - // if the tag is NULL, the union should be NULL - auto &tag_vec = *target_children[0]; - ConstantVector::SetNull(result, ConstantVector::IsNull(tag_vec)); - } else { - // if the tag is NULL, the union should be NULL - auto &tag_vec = *target_children[0]; - UnifiedVectorFormat source_data, tag_data; - source.ToUnifiedFormat(count, source_data); - tag_vec.ToUnifiedFormat(count, tag_data); - - for (idx_t i = 0; i < count; i++) { - if (!source_data.validity.RowIsValid(source_data.sel->get_index(i)) || - !tag_data.validity.RowIsValid(tag_data.sel->get_index(i))) { - FlatVector::SetNull(result, i, true); - } - } - } - - auto check_tags = UnionVector::CheckUnionValidity(result, count); - switch (check_tags) { - case UnionInvalidReason::TAG_OUT_OF_RANGE: - throw ConversionException("One or more of the tags do not point to a valid union member"); - case UnionInvalidReason::VALIDITY_OVERLAP: - throw ConversionException("One or more rows in the produced UNION have validity set for more than 1 member"); - case UnionInvalidReason::TAG_MISMATCH: - throw ConversionException( - "One or more rows in the produced UNION have tags that don't point to the valid member"); - case UnionInvalidReason::NULL_TAG: - throw ConversionException("One or more rows in the produced UNION have a NULL tag"); - case UnionInvalidReason::VALID: - break; - default: - throw InternalException("Struct to union cast failed for unknown reason"); - } - - result.Verify(count); - return true; -} - -// Bind cast - -unique_ptr StructToUnionCast::BindData(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - vector child_cast_info; - D_ASSERT(source.id() == LogicalTypeId::STRUCT); - D_ASSERT(target.id() == LogicalTypeId::UNION); - - auto result_child_count = StructType::GetChildCount(target); - D_ASSERT(result_child_count == StructType::GetChildCount(source)); - - for (idx_t i = 0; i < result_child_count; i++) { - auto &source_child = StructType::GetChildType(source, i); - auto &target_child = StructType::GetChildType(target, i); - - auto child_cast = input.GetCastFunction(source_child, target_child); - child_cast_info.push_back(std::move(child_cast)); - } - return make_uniq(std::move(child_cast_info), target); -} - -BoundCastInfo StructToUnionCast::Bind(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - auto cast_data = StructToUnionCast::BindData(input, source, target); - return BoundCastInfo(&StructToUnionCast::Cast, std::move(cast_data), StructBoundCastData::InitStructCastLocalState); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/union_casts.cpp b/src/duckdb/src/function/cast/union_casts.cpp deleted file mode 100644 index 72b171745..000000000 --- a/src/duckdb/src/function/cast/union_casts.cpp +++ /dev/null @@ -1,365 
+0,0 @@ -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/function/cast/cast_function_set.hpp" -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/function/cast/bound_cast_data.hpp" - -#include // for std::sort - -namespace duckdb { - -//-------------------------------------------------------------------------------------------------- -// ??? -> UNION -//-------------------------------------------------------------------------------------------------- -// if the source can be implicitly cast to a member of the target union, the cast is valid - -unique_ptr BindToUnionCast(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - D_ASSERT(target.id() == LogicalTypeId::UNION); - - vector candidates; - - for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(target); member_idx++) { - auto member_type = UnionType::GetMemberType(target, member_idx); - auto member_name = UnionType::GetMemberName(target, member_idx); - auto member_cast_cost = input.function_set.ImplicitCastCost(source, member_type); - if (member_cast_cost != -1) { - auto member_cast_info = input.GetCastFunction(source, member_type); - candidates.emplace_back(member_idx, member_name, member_type, member_cast_cost, - std::move(member_cast_info)); - } - }; - - // no possible casts found! - if (candidates.empty()) { - auto message = StringUtil::Format( - "Type %s can't be cast as %s. %s can't be implicitly cast to any of the union member types: ", - source.ToString(), target.ToString(), source.ToString()); - - auto member_count = UnionType::GetMemberCount(target); - for (idx_t member_idx = 0; member_idx < member_count; member_idx++) { - auto member_type = UnionType::GetMemberType(target, member_idx); - message += member_type.ToString(); - if (member_idx < member_count - 1) { - message += ", "; - } - } - throw ConversionException(message); - } - - // sort the candidate casts by cost - std::sort(candidates.begin(), candidates.end(), UnionBoundCastData::SortByCostAscending); - - // select the lowest possible cost cast - auto &selected_cast = candidates[0]; - auto selected_cost = candidates[0].cost; - - // check if the cast is ambiguous (2 or more casts have the same cost) - if (candidates.size() > 1 && candidates[1].cost == selected_cost) { - - // collect all the ambiguous types - auto message = StringUtil::Format( - "Type %s can't be cast as %s. The cast is ambiguous, multiple possible members in target: ", source, - target); - for (size_t i = 0; i < candidates.size(); i++) { - if (candidates[i].cost == selected_cost) { - message += StringUtil::Format("'%s (%s)'", candidates[i].name, candidates[i].type.ToString()); - if (i < candidates.size() - 1) { - message += ", "; - } - } - } - message += ". 
Disambiguate the target type by using the 'union_value(<tag> := <arg>)' function to promote the "
-	              "source value to a single member union before casting.";
-		throw ConversionException(message);
-	}
-
-	// otherwise, return the selected cast
-	return make_uniq<UnionBoundCastData>(std::move(selected_cast));
-}
-
-unique_ptr<FunctionLocalState> InitToUnionLocalState(CastLocalStateParameters &parameters) {
-	auto &cast_data = parameters.cast_data->Cast<UnionBoundCastData>();
-	if (!cast_data.member_cast_info.init_local_state) {
-		return nullptr;
-	}
-	CastLocalStateParameters child_parameters(parameters, cast_data.member_cast_info.cast_data);
-	return cast_data.member_cast_info.init_local_state(child_parameters);
-}
-
-static bool ToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-	D_ASSERT(result.GetType().id() == LogicalTypeId::UNION);
-	auto &cast_data = parameters.cast_data->Cast<UnionBoundCastData>();
-	auto &selected_member_vector = UnionVector::GetMember(result, cast_data.tag);
-
-	CastParameters child_parameters(parameters, cast_data.member_cast_info.cast_data, parameters.local_state);
-	if (!cast_data.member_cast_info.function(source, selected_member_vector, count, child_parameters)) {
-		return false;
-	}
-
-	// cast succeeded, create union vector
-	UnionVector::SetToMember(result, cast_data.tag, selected_member_vector, count, true);
-
-	result.Verify(count);
-
-	return true;
-}
-
-BoundCastInfo DefaultCasts::ImplicitToUnionCast(BindCastInput &input, const LogicalType &source,
-                                                const LogicalType &target) {
-
-	D_ASSERT(target.id() == LogicalTypeId::UNION);
-	if (StructToUnionCast::AllowImplicitCastFromStruct(source, target)) {
-		return StructToUnionCast::Bind(input, source, target);
-	}
-	auto cast_data = BindToUnionCast(input, source, target);
-	return BoundCastInfo(&ToUnionCast, std::move(cast_data), InitToUnionLocalState);
-}
-
-//--------------------------------------------------------------------------------------------------
-// UNION -> UNION
-//--------------------------------------------------------------------------------------------------
-// if the source member tags are a subset of the target member tags, and all the source members can be
-// implicitly cast to the corresponding target members, the cast is valid.
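As a rough illustration of the subset rule just stated (the concrete VALID/INVALID cases follow below): at bind time this boils down to building a source-to-target tag map by member name. A minimal standalone C++ sketch, not DuckDB code; it uses plain std::string names and exact matching where the real code compares case-insensitively:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

// Returns false when a source member has no same-named target member,
// i.e. when the source members are not a subset of the target members.
bool BuildTagMap(const std::vector<std::string> &source_members,
                 const std::vector<std::string> &target_members,
                 std::vector<size_t> &tag_map) {
	tag_map.resize(source_members.size());
	for (size_t src = 0; src < source_members.size(); src++) {
		bool found = false;
		for (size_t tgt = 0; tgt < target_members.size(); tgt++) {
			if (source_members[src] == target_members[tgt]) {
				tag_map[src] = tgt; // remember where this member lives in the target
				found = true;
				break;
			}
		}
		if (!found) {
			return false; // source member missing from target: cast is invalid
		}
	}
	return true;
}

int main() {
	std::vector<size_t> tag_map;
	std::cout << BuildTagMap({"a", "b"}, {"a", "c", "b"}, tag_map) << "\n"; // 1: subset
	std::cout << BuildTagMap({"a", "d"}, {"a", "c", "b"}, tag_map) << "\n"; // 0: 'd' missing
}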
-// -// VALID: UNION(A, B) -> UNION(A, B, C) -// VALID: UNION(A, B) -> UNION(A, C) if B can be implicitly cast to C -// -// INVALID: UNION(A, B, C) -> UNION(A, B) -// INVALID: UNION(A, B) -> UNION(A, C) if B can't be implicitly cast to C -// INVALID: UNION(A, B, D) -> UNION(A, B, C) - -struct UnionUnionBoundCastData : public BoundCastData { - - // mapping from source member index to target member index - // these are always the same size as the source member count - // (since all source members must be present in the target) - vector tag_map; - vector member_casts; - - LogicalType target_type; - - UnionUnionBoundCastData(vector tag_map, vector member_casts, LogicalType target_type) - : tag_map(std::move(tag_map)), member_casts(std::move(member_casts)), target_type(std::move(target_type)) { - } - -public: - unique_ptr Copy() const override { - vector member_casts_copy; - for (auto &member_cast : member_casts) { - member_casts_copy.push_back(member_cast.Copy()); - } - return make_uniq(tag_map, std::move(member_casts_copy), target_type); - } -}; - -unique_ptr BindUnionToUnionCast(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::UNION); - D_ASSERT(target.id() == LogicalTypeId::UNION); - - auto source_member_count = UnionType::GetMemberCount(source); - - auto tag_map = vector(source_member_count); - auto member_casts = vector(); - - for (idx_t source_idx = 0; source_idx < source_member_count; source_idx++) { - auto &source_member_type = UnionType::GetMemberType(source, source_idx); - auto &source_member_name = UnionType::GetMemberName(source, source_idx); - - bool found = false; - for (idx_t target_idx = 0; target_idx < UnionType::GetMemberCount(target); target_idx++) { - auto &target_member_name = UnionType::GetMemberName(target, target_idx); - - // found a matching member - if (StringUtil::CIEquals(source_member_name, target_member_name)) { - auto &target_member_type = UnionType::GetMemberType(target, target_idx); - tag_map[source_idx] = target_idx; - member_casts.push_back(input.GetCastFunction(source_member_type, target_member_type)); - found = true; - break; - } - } - if (!found) { - // no matching member tag found in the target set - auto message = - StringUtil::Format("Type %s can't be cast as %s. 
The member '%s' is not present in target union",
-			                   source.ToString(), target.ToString(), source_member_name);
-			throw ConversionException(message);
-		}
-	}
-
-	return make_uniq<UnionUnionBoundCastData>(tag_map, std::move(member_casts), target);
-}
-
-unique_ptr<FunctionLocalState> InitUnionToUnionLocalState(CastLocalStateParameters &parameters) {
-	auto &cast_data = parameters.cast_data->Cast<UnionUnionBoundCastData>();
-	auto result = make_uniq<StructCastLocalState>();
-
-	for (auto &entry : cast_data.member_casts) {
-		unique_ptr<FunctionLocalState> child_state;
-		if (entry.init_local_state) {
-			CastLocalStateParameters child_params(parameters, entry.cast_data);
-			child_state = entry.init_local_state(child_params);
-		}
-		result->local_states.push_back(std::move(child_state));
-	}
-	return std::move(result);
-}
-
-static bool UnionToUnionCast(Vector &source, Vector &result, idx_t count, CastParameters &parameters) {
-	auto &cast_data = parameters.cast_data->Cast<UnionUnionBoundCastData>();
-	auto &lstate = parameters.local_state->Cast<StructCastLocalState>();
-
-	auto source_member_count = UnionType::GetMemberCount(source.GetType());
-	auto target_member_count = UnionType::GetMemberCount(result.GetType());
-
-	auto target_member_is_mapped = vector<bool>(target_member_count);
-
-	// Perform the casts from source to target members
-	for (idx_t member_idx = 0; member_idx < source_member_count; member_idx++) {
-		auto target_member_idx = cast_data.tag_map[member_idx];
-
-		auto &source_member_vector = UnionVector::GetMember(source, member_idx);
-		auto &target_member_vector = UnionVector::GetMember(result, target_member_idx);
-		auto &member_cast = cast_data.member_casts[member_idx];
-
-		CastParameters child_parameters(parameters, member_cast.cast_data, lstate.local_states[member_idx]);
-		if (!member_cast.function(source_member_vector, target_member_vector, count, child_parameters)) {
-			return false;
-		}
-
-		target_member_is_mapped[target_member_idx] = true;
-	}
-
-	// All member casts succeeded!
-
-	// Set the unmapped target members to constant NULL.
-	// If we cast UNION(A, B) -> UNION(A, B, C) we need to invalidate C so that
-	// the invariants of the result union hold. (only member columns "selected"
-	// by the rowwise corresponding tag in the tag vector should be valid)
-	for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) {
-		if (!target_member_is_mapped[target_member_idx]) {
-			auto &target_member_vector = UnionVector::GetMember(result, target_member_idx);
-			target_member_vector.SetVectorType(VectorType::CONSTANT_VECTOR);
-			ConstantVector::SetNull(target_member_vector, true);
-		}
-	}
-
-	// Update the tags in the result vector
-	auto &source_tag_vector = UnionVector::GetTags(source);
-	auto &result_tag_vector = UnionVector::GetTags(result);
-
-	if (source.GetVectorType() == VectorType::CONSTANT_VECTOR) {
-		// Constant vector case optimization
-		result.SetVectorType(VectorType::CONSTANT_VECTOR);
-		if (ConstantVector::IsNull(source)) {
-			ConstantVector::SetNull(result, true);
-		} else {
-			// map the tag
-			auto source_tag = ConstantVector::GetData<union_tag_t>(source_tag_vector)[0];
-			auto mapped_tag = cast_data.tag_map[source_tag];
-			ConstantVector::GetData<union_tag_t>(result_tag_vector)[0] = UnsafeNumericCast<union_tag_t>(mapped_tag);
-		}
-	} else {
-		// Otherwise, use the unified vector format to access the source vector.
-
-		// Ensure that all the result members are flat vectors
-		// This is not always the case, e.g. when a member is cast using the default TryNullCast function
-		// the resulting member vector will be a constant null vector.
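To make the remapping below concrete, here is a minimal standalone sketch (plain C++ with hypothetical tag values, not the DuckDB vector machinery): each row's stored tag is translated through tag_map, and any target member that no source member maps to is marked NULL-only:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
	// Built at bind time: tag_map[source_tag] == target_tag.
	// Example: UNION(a, b) -> UNION(a, c, b) gives tag_map = {0, 2}.
	std::vector<uint8_t> tag_map = {0, 2};
	std::vector<uint8_t> source_tags = {1, 0, 1}; // per-row source tags
	std::vector<bool> target_member_is_mapped(3, false);
	for (uint8_t target_tag : tag_map) {
		target_member_is_mapped[target_tag] = true;
	}
	// Rewrite every row's tag into the target's tag space.
	for (uint8_t source_tag : source_tags) {
		std::cout << "row tag " << int(source_tag) << " -> " << int(tag_map[source_tag]) << "\n";
	}
	// Members never selected by any tag must be invalidated.
	for (size_t member = 0; member < target_member_is_mapped.size(); member++) {
		if (!target_member_is_mapped[member]) {
			// the real cast sets this member to a constant NULL vector
			std::cout << "target member " << member << " is unmapped -> NULL\n";
		}
	}
}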
- for (idx_t target_member_idx = 0; target_member_idx < target_member_count; target_member_idx++) { - UnionVector::GetMember(result, target_member_idx).Flatten(count); - } - - // We assume that a union tag vector validity matches the union vector validity. - UnifiedVectorFormat source_tag_format; - source_tag_vector.ToUnifiedFormat(count, source_tag_format); - - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto source_row_idx = source_tag_format.sel->get_index(row_idx); - if (source_tag_format.validity.RowIsValid(source_row_idx)) { - // map the tag - auto source_tag = (UnifiedVectorFormat::GetData(source_tag_format))[source_row_idx]; - auto target_tag = cast_data.tag_map[source_tag]; - FlatVector::GetData(result_tag_vector)[row_idx] = - UnsafeNumericCast(target_tag); - } else { - - // Issue: The members of the result is not always flatvectors - // In the case of TryNullCast, the result member is constant. - FlatVector::SetNull(result, row_idx, true); - } - } - } - - result.Verify(count); - - return true; -} - -static bool UnionToVarcharCast(Vector &source, Vector &result, idx_t count, CastParameters ¶meters) { - auto constant = source.GetVectorType() == VectorType::CONSTANT_VECTOR; - // first cast all union members to varchar - auto &cast_data = parameters.cast_data->Cast(); - Vector varchar_union(cast_data.target_type, count); - - UnionToUnionCast(source, varchar_union, count, parameters); - - // now construct the actual varchar vector - // varchar_union.Flatten(count); - auto &tag_vector = UnionVector::GetTags(varchar_union); - UnifiedVectorFormat tag_format; - tag_vector.ToUnifiedFormat(count, tag_format); - - auto result_data = FlatVector::GetData(result); - - for (idx_t i = 0; i < count; i++) { - auto tag_idx = tag_format.sel->get_index(i); - if (!tag_format.validity.RowIsValid(tag_idx)) { - FlatVector::SetNull(result, i, true); - continue; - } - - auto tag = UnifiedVectorFormat::GetData(tag_format)[tag_idx]; - auto &member = UnionVector::GetMember(varchar_union, tag); - UnifiedVectorFormat member_vdata; - member.ToUnifiedFormat(count, member_vdata); - auto mapped_idx = member_vdata.sel->get_index(i); - auto member_valid = member_vdata.validity.RowIsValid(mapped_idx); - if (member_valid) { - auto member_str = (UnifiedVectorFormat::GetData(member_vdata))[mapped_idx]; - result_data[i] = StringVector::AddString(result, member_str); - } else { - result_data[i] = StringVector::AddString(result, "NULL"); - } - } - - if (constant) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(count); - return true; -} - -BoundCastInfo DefaultCasts::UnionCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::UNION); - switch (target.id()) { - case LogicalTypeId::VARCHAR: { - // bind a cast in which we convert all members to VARCHAR first - child_list_t varchar_members; - for (idx_t member_idx = 0; member_idx < UnionType::GetMemberCount(source); member_idx++) { - varchar_members.push_back(make_pair(UnionType::GetMemberName(source, member_idx), LogicalType::VARCHAR)); - } - auto varchar_type = LogicalType::UNION(std::move(varchar_members)); - return BoundCastInfo(UnionToVarcharCast, BindUnionToUnionCast(input, source, varchar_type), - InitUnionToUnionLocalState); - } - case LogicalTypeId::UNION: - return BoundCastInfo(UnionToUnionCast, BindUnionToUnionCast(input, source, target), InitUnionToUnionLocalState); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git 
a/src/duckdb/src/function/cast/uuid_casts.cpp b/src/duckdb/src/function/cast/uuid_casts.cpp deleted file mode 100644 index c2267b51e..000000000 --- a/src/duckdb/src/function/cast/uuid_casts.cpp +++ /dev/null @@ -1,18 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -namespace duckdb { - -BoundCastInfo DefaultCasts::UUIDCastSwitch(BindCastInput &input, const LogicalType &source, const LogicalType &target) { - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - // uuid to varchar - return BoundCastInfo(&VectorCastHelpers::StringCast); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/varint_casts.cpp b/src/duckdb/src/function/cast/varint_casts.cpp deleted file mode 100644 index 0f4346883..000000000 --- a/src/duckdb/src/function/cast/varint_casts.cpp +++ /dev/null @@ -1,283 +0,0 @@ -#include "duckdb/function/cast/default_casts.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/function/cast/vector_cast_helpers.hpp" -#include "duckdb/common/types/varint.hpp" -#include - -namespace duckdb { - -template -string_t IntToVarInt(Vector &result, T int_value) { - // Determine if the number is negative - bool is_negative = int_value < 0; - // Determine the number of data bytes - uint64_t abs_value; - if (is_negative) { - if (int_value == std::numeric_limits::min()) { - abs_value = static_cast(std::numeric_limits::max()) + 1; - } else { - abs_value = static_cast(std::abs(static_cast(int_value))); - } - } else { - abs_value = static_cast(int_value); - } - uint32_t data_byte_size; - if (abs_value != NumericLimits::Maximum()) { - data_byte_size = (abs_value == 0) ? 1 : static_cast(std::ceil(std::log2(abs_value + 1) / 8.0)); - } else { - data_byte_size = static_cast(std::ceil(std::log2(abs_value) / 8.0)); - } - - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; - auto blob = StringVector::EmptyString(result, blob_size); - auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); - - // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; - for (int i = static_cast(data_byte_size) - 1; i >= 0; --i) { - if (is_negative) { - writable_blob[wb_idx++] = static_cast(~(abs_value >> i * 8 & 0xFF)); - } else { - writable_blob[wb_idx++] = static_cast(abs_value >> i * 8 & 0xFF); - } - } - blob.Finalize(); - return blob; -} - -template <> -string_t HugeintCastToVarInt::Operation(uhugeint_t int_value, Vector &result) { - uint32_t data_byte_size; - if (int_value.upper != NumericLimits::Maximum()) { - data_byte_size = - (int_value.upper == 0) ? 
0 : static_cast(std::ceil(std::log2(int_value.upper + 1) / 8.0)); - } else { - data_byte_size = static_cast(std::ceil(std::log2(int_value.upper) / 8.0)); - } - - uint32_t upper_byte_size = data_byte_size; - if (data_byte_size > 0) { - // If we have at least one byte on the upper side, the bottom side is complete - data_byte_size += 8; - } else { - if (int_value.lower != NumericLimits::Maximum()) { - data_byte_size += static_cast(std::ceil(std::log2(int_value.lower + 1) / 8.0)); - } else { - data_byte_size += static_cast(std::ceil(std::log2(int_value.lower) / 8.0)); - } - } - if (data_byte_size == 0) { - data_byte_size++; - } - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; - auto blob = StringVector::EmptyString(result, blob_size); - auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, false); - - // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; - for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { - writable_blob[wb_idx++] = static_cast(int_value.upper >> i * 8 & 0xFF); - } - for (int i = static_cast(data_byte_size - upper_byte_size) - 1; i >= 0; --i) { - writable_blob[wb_idx++] = static_cast(int_value.lower >> i * 8 & 0xFF); - } - blob.Finalize(); - return blob; -} - -template <> -string_t HugeintCastToVarInt::Operation(hugeint_t int_value, Vector &result) { - // Determine if the number is negative - bool is_negative = int_value.upper >> 63 & 1; - if (is_negative) { - // We must check if it's -170141183460469231731687303715884105728, since it's not possible to negate it - // without overflowing - if (int_value == NumericLimits::Minimum()) { - uhugeint_t u_int_value {0x8000000000000000, 0}; - auto cast_value = Operation(u_int_value, result); - // We have to do all the bit flipping. - auto writable_value_ptr = cast_value.GetDataWriteable(); - Varint::SetHeader(writable_value_ptr, cast_value.GetSize() - Varint::VARINT_HEADER_SIZE, is_negative); - for (idx_t i = Varint::VARINT_HEADER_SIZE; i < cast_value.GetSize(); i++) { - writable_value_ptr[i] = static_cast(~writable_value_ptr[i]); - } - cast_value.Finalize(); - return cast_value; - } - int_value = -int_value; - } - // Determine the number of data bytes - uint64_t abs_value_upper = static_cast(int_value.upper); - - uint32_t data_byte_size; - if (abs_value_upper != NumericLimits::Maximum()) { - data_byte_size = - (abs_value_upper == 0) ? 
0 : static_cast(std::ceil(std::log2(abs_value_upper + 1) / 8.0)); - } else { - data_byte_size = static_cast(std::ceil(std::log2(abs_value_upper) / 8.0)); - } - - uint32_t upper_byte_size = data_byte_size; - if (data_byte_size > 0) { - // If we have at least one byte on the upper side, the bottom side is complete - data_byte_size += 8; - } else { - if (int_value.lower != NumericLimits::Maximum()) { - data_byte_size += static_cast(std::ceil(std::log2(int_value.lower + 1) / 8.0)); - } else { - data_byte_size += static_cast(std::ceil(std::log2(int_value.lower) / 8.0)); - } - } - - if (data_byte_size == 0) { - data_byte_size++; - } - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; - auto blob = StringVector::EmptyString(result, blob_size); - auto writable_blob = blob.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); - - // Add data bytes to the blob, starting off after header bytes - idx_t wb_idx = Varint::VARINT_HEADER_SIZE; - for (int i = static_cast(upper_byte_size) - 1; i >= 0; --i) { - if (is_negative) { - writable_blob[wb_idx++] = static_cast(~(abs_value_upper >> i * 8 & 0xFF)); - } else { - writable_blob[wb_idx++] = static_cast(abs_value_upper >> i * 8 & 0xFF); - } - } - for (int i = static_cast(data_byte_size - upper_byte_size) - 1; i >= 0; --i) { - if (is_negative) { - writable_blob[wb_idx++] = static_cast(~(int_value.lower >> i * 8 & 0xFF)); - } else { - writable_blob[wb_idx++] = static_cast(int_value.lower >> i * 8 & 0xFF); - } - } - blob.Finalize(); - return blob; -} - -// Varchar to Varint -// TODO: This is a slow quadratic algorithm, we can still optimize it further. -template <> -bool TryCastToVarInt::Operation(string_t input_value, string_t &result_value, Vector &result, - CastParameters ¶meters) { - auto blob_string = Varint::VarcharToVarInt(input_value); - - uint32_t blob_size = static_cast(blob_string.size()); - result_value = StringVector::EmptyString(result, blob_size); - auto writable_blob = result_value.GetDataWriteable(); - - // Write string_blob into blob - for (idx_t i = 0; i < blob_string.size(); i++) { - writable_blob[i] = blob_string[i]; - } - result_value.Finalize(); - return true; -} - -template -bool DoubleToVarInt(T double_value, string_t &result_value, Vector &result) { - // Check if we can cast it - if (!std::isfinite(double_value)) { - // We can't cast inf -inf nan - return false; - } - // Determine if the number is negative - bool is_negative = double_value < 0; - // Determine the number of data bytes - double abs_value = std::abs(double_value); - - if (abs_value == 0) { - // Return Value 0 - result_value = Varint::InitializeVarintZero(result); - return true; - } - vector value; - while (abs_value > 0) { - double quotient = abs_value / 256; - double truncated = floor(quotient); - uint8_t byte = static_cast(abs_value - truncated * 256); - abs_value = truncated; - if (is_negative) { - value.push_back(static_cast(~byte)); - } else { - value.push_back(static_cast(byte)); - } - } - uint32_t data_byte_size = static_cast(value.size()); - uint32_t blob_size = data_byte_size + Varint::VARINT_HEADER_SIZE; - result_value = StringVector::EmptyString(result, blob_size); - auto writable_blob = result_value.GetDataWriteable(); - Varint::SetHeader(writable_blob, data_byte_size, is_negative); - // Add data bytes to the blob, starting off after header bytes - idx_t blob_string_idx = value.size() - 1; - for (idx_t i = Varint::VARINT_HEADER_SIZE; i < blob_size; i++) { - writable_blob[i] = value[blob_string_idx--]; - } - 
result_value.Finalize(); - return true; -} - -template <> -bool TryCastToVarInt::Operation(double double_value, string_t &result_value, Vector &result, - CastParameters ¶meters) { - return DoubleToVarInt(double_value, result_value, result); -} - -template <> -bool TryCastToVarInt::Operation(float double_value, string_t &result_value, Vector &result, - CastParameters ¶meters) { - return DoubleToVarInt(double_value, result_value, result); -} - -BoundCastInfo Varint::NumericToVarintCastSwitch(const LogicalType &source) { - // now switch on the result type - switch (source.id()) { - case LogicalTypeId::TINYINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::UTINYINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::SMALLINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::USMALLINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::INTEGER: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::UINTEGER: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::BIGINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::UBIGINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::UHUGEINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::FLOAT: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastStringLoop); - case LogicalTypeId::HUGEINT: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DECIMAL: - default: - return DefaultCasts::TryVectorNullCast; - } -} - -BoundCastInfo DefaultCasts::VarintCastSwitch(BindCastInput &input, const LogicalType &source, - const LogicalType &target) { - D_ASSERT(source.id() == LogicalTypeId::VARINT); - // now switch on the result type - switch (target.id()) { - case LogicalTypeId::VARCHAR: - return BoundCastInfo(&VectorCastHelpers::StringCast); - case LogicalTypeId::DOUBLE: - return BoundCastInfo(&VectorCastHelpers::TryCastLoop); - default: - return TryVectorNullCast; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast/vector_cast_helpers.cpp b/src/duckdb/src/function/cast/vector_cast_helpers.cpp deleted file mode 100644 index c7aa523ea..000000000 --- a/src/duckdb/src/function/cast/vector_cast_helpers.cpp +++ /dev/null @@ -1,360 +0,0 @@ -#include "duckdb/function/cast/vector_cast_helpers.hpp" - -namespace duckdb { - -// ------- Helper functions for splitting string nested types ------- -static bool IsNull(const char *buf, idx_t start_pos, Vector &child, idx_t row_idx) { - if ((buf[start_pos] == 'N' || buf[start_pos] == 'n') && (buf[start_pos + 1] == 'U' || buf[start_pos + 1] == 'u') && - (buf[start_pos + 2] == 'L' || buf[start_pos + 2] == 'l') && - (buf[start_pos + 3] == 'L' || buf[start_pos + 3] == 'l')) { - FlatVector::SetNull(child, row_idx, true); - return true; - } - return false; -} - -inline static void SkipWhitespace(const char *buf, idx_t &pos, idx_t len) { - while (pos < len && StringUtil::CharacterIsSpace(buf[pos])) { - pos++; - } -} - -static bool SkipToCloseQuotes(idx_t &pos, const char *buf, idx_t &len) { - char quote = buf[pos]; - pos++; - bool escaped = false; - - while (pos < len) { - if (buf[pos] == '\\') { - escaped = !escaped; - } else { - if (buf[pos] == quote && !escaped) { - return true; - } - escaped = false; - } - pos++; - } - return false; -} - -static bool 
SkipToClose(idx_t &idx, const char *buf, idx_t &len, idx_t &lvl, char close_bracket) { - idx++; - - vector brackets; - brackets.push_back(close_bracket); - while (idx < len) { - if (buf[idx] == '"' || buf[idx] == '\'') { - if (!SkipToCloseQuotes(idx, buf, len)) { - return false; - } - } else if (buf[idx] == '{') { - brackets.push_back('}'); - } else if (buf[idx] == '[') { - brackets.push_back(']'); - lvl++; - } else if (buf[idx] == brackets.back()) { - if (buf[idx] == ']') { - lvl--; - } - brackets.pop_back(); - if (brackets.empty()) { - return true; - } - } - idx++; - } - return false; -} - -static idx_t StringTrim(const char *buf, idx_t &start_pos, idx_t pos) { - idx_t trailing_whitespace = 0; - while (pos > start_pos && StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { - trailing_whitespace++; - } - if ((buf[start_pos] == '"' && buf[pos - trailing_whitespace - 1] == '"') || - (buf[start_pos] == '\'' && buf[pos - trailing_whitespace - 1] == '\'')) { - start_pos++; - trailing_whitespace++; - } - return (pos - trailing_whitespace); -} - -struct CountPartOperation { - idx_t count = 0; - - bool HandleKey(const char *buf, idx_t start_pos, idx_t pos) { - count++; - return true; - } - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - count++; - } -}; - -// ------- LIST SPLIT ------- -struct SplitStringListOperation { - SplitStringListOperation(string_t *child_data, idx_t &child_start, Vector &child) - : child_data(child_data), child_start(child_start), child(child) { - } - - string_t *child_data; - idx_t &child_start; - Vector &child; - - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, child, child_start)) { - child_start++; - return; - } - if (start_pos > pos) { - pos = start_pos; - } - child_data[child_start] = StringVector::AddString(child, buf + start_pos, pos - start_pos); - child_start++; - } -}; - -template -static bool SplitStringListInternal(const string_t &input, OP &state) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t lvl = 1; - idx_t pos = 0; - bool seen_value = false; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '[') { - return false; - } - - SkipWhitespace(buf, ++pos, len); - idx_t start_pos = pos; - while (pos < len) { - if (buf[pos] == '[') { - if (!SkipToClose(pos, buf, len, ++lvl, ']')) { - return false; - } - } else if ((buf[pos] == '"' || buf[pos] == '\'') && pos == start_pos) { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - idx_t struct_lvl = 0; - SkipToClose(pos, buf, len, struct_lvl, '}'); - } else if (buf[pos] == ',' || buf[pos] == ']') { - idx_t trailing_whitespace = 0; - while (StringUtil::CharacterIsSpace(buf[pos - trailing_whitespace - 1])) { - trailing_whitespace++; - } - if (buf[pos] != ']' || start_pos != pos || seen_value) { - state.HandleValue(buf, start_pos, pos - trailing_whitespace); - seen_value = true; - } - if (buf[pos] == ']') { - lvl--; - break; - } - SkipWhitespace(buf, ++pos, len); - start_pos = pos; - continue; - } - pos++; - } - SkipWhitespace(buf, ++pos, len); - return (pos == len && lvl == 0); -} - -bool VectorStringToList::SplitStringList(const string_t &input, string_t *child_data, idx_t &child_start, - Vector &child) { - SplitStringListOperation state(child_data, child_start, child); - return SplitStringListInternal(input, state); -} - -idx_t VectorStringToList::CountPartsList(const string_t &input) { - CountPartOperation state; - SplitStringListInternal(input, 
state); - return state.count; -} - -// ------- MAP SPLIT ------- -struct SplitStringMapOperation { - SplitStringMapOperation(string_t *child_key_data, string_t *child_val_data, idx_t &child_start, Vector &varchar_key, - Vector &varchar_val) - : child_key_data(child_key_data), child_val_data(child_val_data), child_start(child_start), - varchar_key(varchar_key), varchar_val(varchar_val) { - } - - string_t *child_key_data; - string_t *child_val_data; - idx_t &child_start; - Vector &varchar_key; - Vector &varchar_val; - - bool HandleKey(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_key, child_start)) { - FlatVector::SetNull(varchar_val, child_start, true); - child_start++; - return false; - } - child_key_data[child_start] = StringVector::AddString(varchar_key, buf + start_pos, pos - start_pos); - return true; - } - - void HandleValue(const char *buf, idx_t start_pos, idx_t pos) { - if ((pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_val, child_start)) { - child_start++; - return; - } - child_val_data[child_start] = StringVector::AddString(varchar_val, buf + start_pos, pos - start_pos); - child_start++; - } -}; - -template -static bool FindKeyOrValueMap(const char *buf, idx_t len, idx_t &pos, OP &state, bool key) { - auto start_pos = pos; - idx_t lvl = 0; - while (pos < len) { - if (buf[pos] == '"' || buf[pos] == '\'') { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - SkipToClose(pos, buf, len, lvl, '}'); - } else if (buf[pos] == '[') { - SkipToClose(pos, buf, len, lvl, ']'); - } else if (key && buf[pos] == '=') { - idx_t end_pos = StringTrim(buf, start_pos, pos); - return state.HandleKey(buf, start_pos, end_pos); // put string in KEY_child_vector - } else if (!key && (buf[pos] == ',' || buf[pos] == '}')) { - idx_t end_pos = StringTrim(buf, start_pos, pos); - state.HandleValue(buf, start_pos, end_pos); // put string in VALUE_child_vector - return true; - } - pos++; - } - return false; -} - -template -static bool SplitStringMapInternal(const string_t &input, OP &state) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t pos = 0; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '{') { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (pos == len) { - return false; - } - if (buf[pos] == '}') { - SkipWhitespace(buf, ++pos, len); - return (pos == len); - } - while (pos < len) { - if (!FindKeyOrValueMap(buf, len, pos, state, true)) { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (!FindKeyOrValueMap(buf, len, pos, state, false)) { - return false; - } - SkipWhitespace(buf, ++pos, len); - } - return true; -} - -bool VectorStringToMap::SplitStringMap(const string_t &input, string_t *child_key_data, string_t *child_val_data, - idx_t &child_start, Vector &varchar_key, Vector &varchar_val) { - SplitStringMapOperation state(child_key_data, child_val_data, child_start, varchar_key, varchar_val); - return SplitStringMapInternal(input, state); -} - -idx_t VectorStringToMap::CountPartsMap(const string_t &input) { - CountPartOperation state; - SplitStringMapInternal(input, state); - return state.count; -} - -// ------- STRUCT SPLIT ------- -static bool FindKeyStruct(const char *buf, idx_t len, idx_t &pos) { - while (pos < len) { - if (buf[pos] == ':') { - return true; - } - pos++; - } - return false; -} - -static bool FindValueStruct(const char *buf, idx_t len, idx_t &pos, Vector &varchar_child, idx_t &row_idx, - ValidityMask &child_mask) { - auto 
start_pos = pos; - idx_t lvl = 0; - while (pos < len) { - if (buf[pos] == '"' || buf[pos] == '\'') { - SkipToCloseQuotes(pos, buf, len); - } else if (buf[pos] == '{') { - SkipToClose(pos, buf, len, lvl, '}'); - } else if (buf[pos] == '[') { - SkipToClose(pos, buf, len, lvl, ']'); - } else if (buf[pos] == ',' || buf[pos] == '}') { - idx_t end_pos = StringTrim(buf, start_pos, pos); - if ((end_pos - start_pos) == 4 && IsNull(buf, start_pos, varchar_child, row_idx)) { - return true; - } - FlatVector::GetData(varchar_child)[row_idx] = - StringVector::AddString(varchar_child, buf + start_pos, end_pos - start_pos); - child_mask.SetValid(row_idx); // any child not set to valid will remain invalid - return true; - } - pos++; - } - return false; -} - -bool VectorStringToStruct::SplitStruct(const string_t &input, vector> &varchar_vectors, - idx_t &row_idx, string_map_t &child_names, - vector> &child_masks) { - const char *buf = input.GetData(); - idx_t len = input.GetSize(); - idx_t pos = 0; - idx_t child_idx; - - SkipWhitespace(buf, pos, len); - if (pos == len || buf[pos] != '{') { - return false; - } - SkipWhitespace(buf, ++pos, len); - if (buf[pos] == '}') { - pos++; - } else { - while (pos < len) { - auto key_start = pos; - if (!FindKeyStruct(buf, len, pos)) { - return false; - } - auto key_end = StringTrim(buf, key_start, pos); - if (key_start >= key_end) { - // empty key name unsupported - return false; - } - string_t found_key(buf + key_start, UnsafeNumericCast(key_end - key_start)); - - auto it = child_names.find(found_key); - if (it == child_names.end()) { - return false; // false key - } - child_idx = it->second; - SkipWhitespace(buf, ++pos, len); - if (!FindValueStruct(buf, len, pos, *varchar_vectors[child_idx], row_idx, child_masks[child_idx].get())) { - return false; - } - SkipWhitespace(buf, ++pos, len); - } - } - SkipWhitespace(buf, pos, len); - return (pos == len); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/cast_rules.cpp b/src/duckdb/src/function/cast_rules.cpp deleted file mode 100644 index 951ecc935..000000000 --- a/src/duckdb/src/function/cast_rules.cpp +++ /dev/null @@ -1,585 +0,0 @@ -#include "duckdb/function/cast_rules.hpp" -#include "duckdb/common/helper.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/case_insensitive_map.hpp" - -namespace duckdb { - -//! 
The target type determines the preferred implicit casts -static int64_t TargetTypeCost(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BIGINT: - return 101; - case LogicalTypeId::INTEGER: - return 102; - case LogicalTypeId::HUGEINT: - return 103; - case LogicalTypeId::DOUBLE: - return 104; - case LogicalTypeId::DECIMAL: - return 105; - case LogicalTypeId::TIMESTAMP_NS: - return 119; - case LogicalTypeId::TIMESTAMP: - return 120; - case LogicalTypeId::TIMESTAMP_MS: - return 121; - case LogicalTypeId::TIMESTAMP_SEC: - return 122; - case LogicalTypeId::TIMESTAMP_TZ: - return 123; - case LogicalTypeId::VARCHAR: - return 149; - case LogicalTypeId::STRUCT: - case LogicalTypeId::MAP: - case LogicalTypeId::LIST: - case LogicalTypeId::UNION: - case LogicalTypeId::ARRAY: - return 160; - case LogicalTypeId::ANY: - return int64_t(AnyType::GetCastScore(type)); - default: - return 110; - } -} - -static int64_t ImplicitCastTinyint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastSmallint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastInteger(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastBigint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUTinyint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUSmallint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUInteger(const LogicalType &to) { - switch (to.id()) { - - case LogicalTypeId::UBIGINT: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUBigint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case 
LogicalTypeId::DOUBLE: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastFloat(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::DOUBLE: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastDouble(const LogicalType &to) { - switch (to.id()) { - default: - return -1; - } -} - -static int64_t ImplicitCastDecimal(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastHugeint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastUhugeint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::DECIMAL: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastDate(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_SEC: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastEnum(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::VARCHAR: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastTimestampSec(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastTimestampMS(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastTimestampNS(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP: - // we allow casting ALL timestamps, including nanosecond ones, to TimestampNS - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastTimestamp(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::TIMESTAMP_NS: - return TargetTypeCost(to); - case LogicalTypeId::TIMESTAMP_TZ: - return TargetTypeCost(to); - default: - return -1; - } -} - -static int64_t ImplicitCastVarint(const LogicalType &to) { - switch (to.id()) { - case LogicalTypeId::DOUBLE: - return TargetTypeCost(to); - default: - return -1; - } -} - -bool LogicalTypeIsValid(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::STRUCT: - case LogicalTypeId::UNION: - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::ARRAY: - case LogicalTypeId::DECIMAL: - // these types are only valid with auxiliary info - if (!type.AuxInfo()) { - return false; - } - break; - default: - break; - } - switch (type.id()) { - case LogicalTypeId::ANY: - case LogicalTypeId::INVALID: - case LogicalTypeId::UNKNOWN: - return false; - case LogicalTypeId::STRUCT: { - auto child_count = StructType::GetChildCount(type); - for (idx_t i = 0; i < child_count; i++) { - if (!LogicalTypeIsValid(StructType::GetChildType(type, i))) { - return false; - } - } - return true; - } - case LogicalTypeId::UNION: { 
-		auto member_count = UnionType::GetMemberCount(type);
-		for (idx_t i = 0; i < member_count; i++) {
-			if (!LogicalTypeIsValid(UnionType::GetMemberType(type, i))) {
-				return false;
-			}
-		}
-		return true;
-	}
-	case LogicalTypeId::LIST:
-	case LogicalTypeId::MAP:
-		return LogicalTypeIsValid(ListType::GetChildType(type));
-	case LogicalTypeId::ARRAY:
-		return LogicalTypeIsValid(ArrayType::GetChildType(type));
-	default:
-		return true;
-	}
-}
-
-int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to) {
-	if (from.id() == LogicalTypeId::SQLNULL || to.id() == LogicalTypeId::ANY) {
-		// NULL expression can be cast to anything
-		return TargetTypeCost(to);
-	}
-	if (from.id() == LogicalTypeId::UNKNOWN) {
-		// parameter expression can be cast to anything for no cost
-		return 0;
-	}
-	if (from.id() == LogicalTypeId::STRING_LITERAL) {
-		// string literals can be cast to any type for low cost as long as the type is valid
-		// i.e. we cannot cast to LIST(ANY) as we don't know what "ANY" should be
-		// we cannot cast to DECIMAL without precision/width specified
-		if (!LogicalTypeIsValid(to)) {
-			return -1;
-		}
-		if (to.id() == LogicalTypeId::VARCHAR && to.GetAlias().empty()) {
-			return 1;
-		}
-		return 20;
-	}
-	if (from.id() == LogicalTypeId::INTEGER_LITERAL) {
-		// the integer literal has an underlying type - this type always matches
-		if (IntegerLiteral::GetType(from).id() == to.id()) {
-			return 0;
-		}
-		// integer literals can be cast to any other integer type for a low cost, but only if the literal fits
-		if (IntegerLiteral::FitsInType(from, to)) {
-			// to avoid ties we prefer BIGINT, INT, ...
-			auto target_cost = TargetTypeCost(to);
-			if (target_cost < 100) {
-				throw InternalException("Integer literal implicit cast - TargetTypeCost should be >= 100");
-			}
-			return target_cost - 90;
-		}
-		// in any other case we use the casting rules of the preferred type of the literal
-		return CastRules::ImplicitCast(IntegerLiteral::GetType(from), to);
-	}
-	if (from.GetAlias() != to.GetAlias()) {
-		// if aliases are different, an implicit cast is not possible
-		return -1;
-	}
-	if (from.id() == LogicalTypeId::LIST && to.id() == LogicalTypeId::LIST) {
-		// Lists can be cast if their child types can be cast
-		auto child_cost = ImplicitCast(ListType::GetChildType(from), ListType::GetChildType(to));
-		if (child_cost >= 1) {
-			// subtract one from the cost because we prefer LIST[X] -> LIST[VARCHAR] over LIST[X] -> VARCHAR
-			child_cost--;
-		}
-		return child_cost;
-	}
-	if (from.id() == LogicalTypeId::ARRAY && to.id() == LogicalTypeId::ARRAY) {
-		// Arrays can be cast if their child types can be cast and the source and target have the same size,
-		// or the target type has an unknown (any) size.
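A standalone sketch of that size rule, with std::optional standing in for the "any size" marker (hypothetical helper, not the ArrayType API; the real implementation continues right below):

#include <iostream>
#include <optional>

// An ARRAY cast is only considered when the sizes match exactly,
// or when the target leaves its size open (ANY size).
bool ArraySizeCompatible(size_t from_size, std::optional<size_t> to_size) {
	return !to_size.has_value() || *to_size == from_size;
}

int main() {
	std::cout << ArraySizeCompatible(3, 3) << "\n";            // 1: same size
	std::cout << ArraySizeCompatible(3, std::nullopt) << "\n"; // 1: ANY size target
	std::cout << ArraySizeCompatible(3, 4) << "\n";            // 0: size mismatch
}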
- auto from_size = ArrayType::GetSize(from); - auto to_size = ArrayType::GetSize(to); - auto to_is_any_size = ArrayType::IsAnySize(to); - if (from_size == to_size || to_is_any_size) { - auto child_cost = ImplicitCast(ArrayType::GetChildType(from), ArrayType::GetChildType(to)); - if (child_cost >= 100) { - // subtract one from the cost because we prefer ARRAY[X] -> ARRAY[VARCHAR] over ARRAY[X] -> VARCHAR - child_cost--; - } - return child_cost; - } - return -1; // Not possible if the sizes are different - } - if (from.id() == LogicalTypeId::ARRAY && to.id() == LogicalTypeId::LIST) { - // Arrays can be cast to lists for the cost of casting the child type - auto child_cost = ImplicitCast(ArrayType::GetChildType(from), ListType::GetChildType(to)); - if (child_cost < 0) { - return -1; - } - // add 1 because we prefer ARRAY->ARRAY casts over ARRAY->LIST casts - return child_cost + 1; - } - if (from.id() == LogicalTypeId::LIST && (to.id() == LogicalTypeId::ARRAY && !ArrayType::IsAnySize(to))) { - // Lists can be cast to arrays for the cost of casting the child type, if the target size is known - // there is no way for us to resolve the size at bind-time without inspecting the list values. - // TODO: if we can access the expression we could resolve the size if the list is constant. - return ImplicitCast(ListType::GetChildType(from), ArrayType::GetChildType(to)); - } - if (from.id() == LogicalTypeId::UNION && to.id() == LogicalTypeId::UNION) { - // Check that the target union type is fully resolved. - if (to.AuxInfo() == nullptr) { - // If not, try anyway and let the actual cast logic handle it. - // This is to allow passing unions into functions that take a generic union type (without specifying member - // types) as an argument. - return 0; - } - // Unions can be cast if the source tags are a subset of the target tags - // in which case the most expensive cost is used - int64_t cost = -1; - for (idx_t from_member_idx = 0; from_member_idx < UnionType::GetMemberCount(from); from_member_idx++) { - auto &from_member_name = UnionType::GetMemberName(from, from_member_idx); - - bool found = false; - for (idx_t to_member_idx = 0; to_member_idx < UnionType::GetMemberCount(to); to_member_idx++) { - auto &to_member_name = UnionType::GetMemberName(to, to_member_idx); - - if (StringUtil::CIEquals(from_member_name, to_member_name)) { - auto &from_member_type = UnionType::GetMemberType(from, from_member_idx); - auto &to_member_type = UnionType::GetMemberType(to, to_member_idx); - - auto child_cost = ImplicitCast(from_member_type, to_member_type); - cost = MaxValue(cost, child_cost); - found = true; - break; - } - } - if (!found) { - return -1; - } - } - return cost; - } - if (from.id() == LogicalTypeId::STRUCT && to.id() == LogicalTypeId::STRUCT) { - if (to.AuxInfo() == nullptr) { - // If this struct is not fully resolved, we'll leave it to the actual cast logic to handle it. 
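Stepping back to the UNION -> UNION costing a few lines up (the struct handling resumes right below): the overall cost is the maximum over the per-member costs, and the cast fails outright if any member cannot be matched. A minimal standalone sketch under those assumptions, with plain int64_t costs rather than DuckDB types:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Each entry is the implicit-cast cost of one source member to its matched
// target member; -1 means the member is missing or not castable.
int64_t UnionCastCost(const std::vector<int64_t> &member_costs) {
	int64_t cost = -1;
	for (int64_t member_cost : member_costs) {
		if (member_cost < 0) {
			return -1; // one bad member invalidates the whole union cast
		}
		cost = std::max(cost, member_cost); // the priciest member dominates
	}
	return cost;
}

int main() {
	std::cout << UnionCastCost({0, 104}) << "\n"; // 104: most expensive member wins
	std::cout << UnionCastCost({0, -1}) << "\n";  // -1: cast rejected
}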
- return 0; - } - - auto &source_children = StructType::GetChildTypes(from); - auto &target_children = StructType::GetChildTypes(to); - - if (source_children.size() != target_children.size()) { - // different number of children: not possible - return -1; - } - - auto target_is_unnamed = StructType::IsUnnamed(to); - auto source_is_unnamed = StructType::IsUnnamed(from); - auto named_struct_cast = !source_is_unnamed && !target_is_unnamed; - - int64_t cost = -1; - if (named_struct_cast) { - - // Collect the target members in a map for easy lookup - case_insensitive_map_t target_members; - for (idx_t target_idx = 0; target_idx < target_children.size(); target_idx++) { - auto &target_name = target_children[target_idx].first; - if (target_members.find(target_name) != target_members.end()) { - // duplicate name in target struct - return -1; - } - target_members[target_name] = target_idx; - } - // Match the source members to the target members by name - for (idx_t source_idx = 0; source_idx < source_children.size(); source_idx++) { - auto &source_child = source_children[source_idx]; - auto entry = target_members.find(source_child.first); - if (entry == target_members.end()) { - // element in source struct was not found in target struct - return -1; - } - auto target_idx = entry->second; - target_members.erase(entry); - auto child_cost = ImplicitCast(source_child.second, target_children[target_idx].second); - if (child_cost == -1) { - return -1; - } - cost = MaxValue(cost, child_cost); - } - } else { - // Match the source members to the target members by position - for (idx_t i = 0; i < source_children.size(); i++) { - auto &source_child = source_children[i]; - auto &target_child = target_children[i]; - auto child_cost = ImplicitCast(source_child.second, target_child.second); - if (child_cost == -1) { - return -1; - } - cost = MaxValue(cost, child_cost); - } - } - return cost; - } - - if (from.id() == to.id()) { - // arguments match: do nothing - return 0; - } - - // Special case: Anything can be cast to a union if the source type is a member of the union - if (to.id() == LogicalTypeId::UNION) { - // check that the union type is fully resolved. - if (to.AuxInfo() == nullptr) { - return -1; - } - // check if the union contains something castable from the source type - // in which case the least expensive (most specific) cast should be used - bool found = false; - auto cost = NumericLimits::Maximum(); - for (idx_t i = 0; i < UnionType::GetMemberCount(to); i++) { - auto target_member = UnionType::GetMemberType(to, i); - auto target_cost = ImplicitCast(from, target_member); - if (target_cost != -1) { - found = true; - cost = MinValue(cost, target_cost); - } - } - return found ? 
cost : -1; - } - - switch (from.id()) { - case LogicalTypeId::TINYINT: - return ImplicitCastTinyint(to); - case LogicalTypeId::SMALLINT: - return ImplicitCastSmallint(to); - case LogicalTypeId::INTEGER: - return ImplicitCastInteger(to); - case LogicalTypeId::BIGINT: - return ImplicitCastBigint(to); - case LogicalTypeId::UTINYINT: - return ImplicitCastUTinyint(to); - case LogicalTypeId::USMALLINT: - return ImplicitCastUSmallint(to); - case LogicalTypeId::UINTEGER: - return ImplicitCastUInteger(to); - case LogicalTypeId::UBIGINT: - return ImplicitCastUBigint(to); - case LogicalTypeId::HUGEINT: - return ImplicitCastHugeint(to); - case LogicalTypeId::UHUGEINT: - return ImplicitCastUhugeint(to); - case LogicalTypeId::FLOAT: - return ImplicitCastFloat(to); - case LogicalTypeId::DOUBLE: - return ImplicitCastDouble(to); - case LogicalTypeId::DATE: - return ImplicitCastDate(to); - case LogicalTypeId::DECIMAL: - return ImplicitCastDecimal(to); - case LogicalTypeId::ENUM: - return ImplicitCastEnum(to); - case LogicalTypeId::TIMESTAMP_SEC: - return ImplicitCastTimestampSec(to); - case LogicalTypeId::TIMESTAMP_MS: - return ImplicitCastTimestampMS(to); - case LogicalTypeId::TIMESTAMP_NS: - return ImplicitCastTimestampNS(to); - case LogicalTypeId::TIMESTAMP: - return ImplicitCastTimestamp(to); - case LogicalTypeId::VARINT: - return ImplicitCastVarint(to); - default: - return -1; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/compression_config.cpp b/src/duckdb/src/function/compression_config.cpp deleted file mode 100644 index deece1a85..000000000 --- a/src/duckdb/src/function/compression_config.cpp +++ /dev/null @@ -1,106 +0,0 @@ -#include "duckdb/common/pair.hpp" -#include "duckdb/function/compression/compression.hpp" -#include "duckdb/function/compression_function.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -typedef CompressionFunction (*get_compression_function_t)(PhysicalType type); -typedef bool (*compression_supports_type_t)(const PhysicalType physical_type); - -struct DefaultCompressionMethod { - CompressionType type; - get_compression_function_t get_function; - compression_supports_type_t supports_type; -}; - -static const DefaultCompressionMethod internal_compression_methods[] = { - {CompressionType::COMPRESSION_CONSTANT, ConstantFun::GetFunction, ConstantFun::TypeIsSupported}, - {CompressionType::COMPRESSION_UNCOMPRESSED, UncompressedFun::GetFunction, UncompressedFun::TypeIsSupported}, - {CompressionType::COMPRESSION_RLE, RLEFun::GetFunction, RLEFun::TypeIsSupported}, - {CompressionType::COMPRESSION_BITPACKING, BitpackingFun::GetFunction, BitpackingFun::TypeIsSupported}, - {CompressionType::COMPRESSION_DICTIONARY, DictionaryCompressionFun::GetFunction, - DictionaryCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_CHIMP, ChimpCompressionFun::GetFunction, ChimpCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_PATAS, PatasCompressionFun::GetFunction, PatasCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_ALP, AlpCompressionFun::GetFunction, AlpCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_ALPRD, AlpRDCompressionFun::GetFunction, AlpRDCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_FSST, FSSTFun::GetFunction, FSSTFun::TypeIsSupported}, - {CompressionType::COMPRESSION_ZSTD, ZSTDFun::GetFunction, ZSTDFun::TypeIsSupported}, - {CompressionType::COMPRESSION_ROARING, RoaringCompressionFun::GetFunction, RoaringCompressionFun::TypeIsSupported}, - 
{CompressionType::COMPRESSION_EMPTY, EmptyValidityCompressionFun::GetFunction, - EmptyValidityCompressionFun::TypeIsSupported}, - {CompressionType::COMPRESSION_AUTO, nullptr, nullptr}}; - -static optional_ptr FindCompressionFunction(CompressionFunctionSet &set, CompressionType type, - const PhysicalType physical_type) { - auto &functions = set.functions; - auto comp_entry = functions.find(type); - if (comp_entry != functions.end()) { - auto &type_functions = comp_entry->second; - auto type_entry = type_functions.find(physical_type); - if (type_entry != type_functions.end()) { - return &type_entry->second; - } - } - return nullptr; -} - -static optional_ptr LoadCompressionFunction(CompressionFunctionSet &set, CompressionType type, - const PhysicalType physical_type) { - for (idx_t i = 0; internal_compression_methods[i].get_function; i++) { - const auto &method = internal_compression_methods[i]; - if (method.type == type) { - if (!method.supports_type(physical_type)) { - return nullptr; - } - // The type is supported. We create the function and insert it into the set of available functions. - auto function = method.get_function(physical_type); - set.functions[type].insert(make_pair(physical_type, function)); - return FindCompressionFunction(set, type, physical_type); - } - } - throw InternalException("Unsupported compression function type"); -} - -static void TryLoadCompression(DBConfig &config, vector> &result, CompressionType type, - const PhysicalType physical_type) { - auto function = config.GetCompressionFunction(type, physical_type); - if (!function) { - return; - } - result.push_back(*function); -} - -vector> DBConfig::GetCompressionFunctions(const PhysicalType physical_type) { - vector> result; - TryLoadCompression(*this, result, CompressionType::COMPRESSION_UNCOMPRESSED, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_RLE, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_BITPACKING, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_DICTIONARY, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_CHIMP, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_PATAS, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALP, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ALPRD, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_FSST, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ZSTD, physical_type); - TryLoadCompression(*this, result, CompressionType::COMPRESSION_ROARING, physical_type); - return result; -} - -optional_ptr DBConfig::GetCompressionFunction(CompressionType type, - const PhysicalType physical_type) { - lock_guard l(compression_functions->lock); - - // Check if the function is already loaded into the global compression functions. - auto function = FindCompressionFunction(*compression_functions, type, physical_type); - if (function) { - return function; - } - - // We could not find the function in the global compression functions, - // so we attempt loading it. 
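The find-or-load pattern used here (the fall-through return follows below) caches constructed functions keyed by compression type and physical type, so each function is built at most once. A minimal standalone sketch with a plain std::map and hypothetical key/value types, not the DuckDB registry:

#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <utility>

struct Function {
	std::string name;
};

std::mutex cache_lock;
std::map<std::pair<int, int>, Function> cache;

Function *GetFunction(int type, int physical_type) {
	std::lock_guard<std::mutex> guard(cache_lock); // registry is shared, so lock it
	auto key = std::make_pair(type, physical_type);
	auto it = cache.find(key);
	if (it == cache.end()) {
		// cache miss: construct the function once and insert it
		it = cache.emplace(key, Function{"compression_" + std::to_string(type)}).first;
	}
	// hand back a pointer into the cache so repeated lookups are cheap
	return &it->second;
}

int main() {
	std::cout << GetFunction(1, 2)->name << "\n";
}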
-	return LoadCompressionFunction(*compression_functions, type, physical_type);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/copy_function.cpp b/src/duckdb/src/function/copy_function.cpp
deleted file mode 100644
index ac2bc7545..000000000
--- a/src/duckdb/src/function/copy_function.cpp
+++ /dev/null
@@ -1,27 +0,0 @@
-#include "duckdb/function/copy_function.hpp"
-
-namespace duckdb {
-
-vector<string> GetCopyFunctionReturnNames(CopyFunctionReturnType return_type) {
-	switch (return_type) {
-	case CopyFunctionReturnType::CHANGED_ROWS:
-		return {"Count"};
-	case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST:
-		return {"Count", "Files"};
-	default:
-		throw NotImplementedException("Unknown CopyFunctionReturnType");
-	}
-}
-
-vector<LogicalType> GetCopyFunctionReturnLogicalTypes(CopyFunctionReturnType return_type) {
-	switch (return_type) {
-	case CopyFunctionReturnType::CHANGED_ROWS:
-		return {LogicalType::BIGINT};
-	case CopyFunctionReturnType::CHANGED_ROWS_AND_FILE_LIST:
-		return {LogicalType::BIGINT, LogicalType::LIST(LogicalType::VARCHAR)};
-	default:
-		throw NotImplementedException("Unknown CopyFunctionReturnType");
-	}
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/encoding_function.cpp b/src/duckdb/src/function/encoding_function.cpp
deleted file mode 100644
index 644652c46..000000000
--- a/src/duckdb/src/function/encoding_function.cpp
+++ /dev/null
@@ -1,134 +0,0 @@
-#include "duckdb/function/encoding_function.hpp"
-#include "duckdb/main/config.hpp"
-
-namespace duckdb {
-
-struct DefaultEncodeMethod {
-	string name;
-	encode_t encode_function;
-	idx_t ratio;
-	idx_t bytes_per_iteration;
-};
-
-void DecodeUTF16ToUTF8(const char *source_buffer, idx_t &source_buffer_current_position, const idx_t source_buffer_size,
-                       char *target_buffer, idx_t &target_buffer_current_position, const idx_t target_buffer_size,
-                       char *remaining_bytes_buffer, idx_t &remaining_bytes_size) {
-
-	for (; source_buffer_current_position < source_buffer_size; source_buffer_current_position += 2) {
-		if (target_buffer_current_position == target_buffer_size) {
-			// We are done
-			return;
-		}
-		const uint16_t ch =
-		    static_cast<uint16_t>(static_cast<uint8_t>(source_buffer[source_buffer_current_position]) |
-		                          (static_cast<uint8_t>(source_buffer[source_buffer_current_position + 1]) << 8));
-		if (ch >= 0xD800 && ch <= 0xDFFF) {
-			throw InvalidInputException("File is not utf-16 encoded");
-		}
-		if (ch <= 0x007F) {
-			// 1-byte UTF-8 for ASCII characters
-			target_buffer[target_buffer_current_position++] = static_cast<char>(ch & 0x7F);
-		} else if (ch <= 0x07FF) {
-			// 2-byte UTF-8
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0xC0 | (ch >> 6));
-			if (target_buffer_current_position == target_buffer_size) {
-				// We are done, but we have to store one byte for the next chunk!
-				source_buffer_current_position += 2;
-				remaining_bytes_buffer[0] = static_cast<char>(0x80 | (ch & 0x3F));
-				remaining_bytes_size = 1;
-				return;
-			}
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0x80 | (ch & 0x3F));
-		} else {
-			// 3-byte UTF-8
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0xE0 | (ch >> 12));
-			if (target_buffer_current_position == target_buffer_size) {
-				// We are done, but we have to store two bytes for the next chunk!
-				source_buffer_current_position += 2;
-				remaining_bytes_buffer[0] = static_cast<char>(0x80 | ((ch >> 6) & 0x3F));
-				remaining_bytes_buffer[1] = static_cast<char>(0x80 | (ch & 0x3F));
-				remaining_bytes_size = 2;
-				return;
-			}
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0x80 | ((ch >> 6) & 0x3F));
-			if (target_buffer_current_position == target_buffer_size) {
-				// We are done, but we have to store one byte for the next chunk!
-				source_buffer_current_position += 2;
-				remaining_bytes_buffer[0] = static_cast<char>(0x80 | (ch & 0x3F));
-				remaining_bytes_size = 1;
-				return;
-			}
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0x80 | (ch & 0x3F));
-		}
-	}
-}
-
-void DecodeLatin1ToUTF8(const char *source_buffer, idx_t &source_buffer_current_position,
-                        const idx_t source_buffer_size, char *target_buffer, idx_t &target_buffer_current_position,
-                        const idx_t target_buffer_size, char *remaining_bytes_buffer, idx_t &remaining_bytes_size) {
-	for (; source_buffer_current_position < source_buffer_size; source_buffer_current_position++) {
-		if (target_buffer_current_position == target_buffer_size) {
-			// We are done
-			return;
-		}
-		const unsigned char ch = static_cast<unsigned char>(source_buffer[source_buffer_current_position]);
-		if (ch > 0x7F && ch <= 0x9F) {
-			throw InvalidInputException("File is not latin-1 encoded");
-		}
-		if (ch <= 0x7F) {
-			// ASCII: 1 byte in UTF-8
-			target_buffer[target_buffer_current_position++] = static_cast<char>(ch);
-		} else {
-			// Non-ASCII: 2 bytes in UTF-8
-			target_buffer[target_buffer_current_position++] = static_cast<char>(0xc2 + (ch > 0xbf));
-			if (target_buffer_current_position == target_buffer_size) {
-				// We are done, but we have to store one byte for the next chunk!
-				source_buffer_current_position++;
-				remaining_bytes_buffer[0] = static_cast<char>((ch & 0x3f) + 0x80);
-				remaining_bytes_size = 1;
-				return;
-			}
-			target_buffer[target_buffer_current_position++] = static_cast<char>((ch & 0x3f) + 0x80);
-		}
-	}
-}
-
-void DecodeUTF8(const char *source_buffer, idx_t &source_buffer_current_position, const idx_t source_buffer_size,
-                char *target_buffer, idx_t &target_buffer_current_position, const idx_t target_buffer_size,
-                char *remaining_bytes_buffer, idx_t &remaining_bytes_size) {
-	throw InternalException("Decode UTF8 is not a valid function, and should be verified one level up.");
-}
-
-void EncodingFunctionSet::Initialize(DBConfig &config) {
-	config.RegisterEncodeFunction({"utf-8", DecodeUTF8, 1, 1});
-	config.RegisterEncodeFunction({"latin-1", DecodeLatin1ToUTF8, 2, 1});
-	config.RegisterEncodeFunction({"utf-16", DecodeUTF16ToUTF8, 2, 2});
-}
-
-void DBConfig::RegisterEncodeFunction(const EncodingFunction &function) const {
-	lock_guard<mutex> l(encoding_functions->lock);
-	const auto decode_type = function.GetType();
-	if (encoding_functions->functions.find(decode_type) != encoding_functions->functions.end()) {
-		throw InvalidInputException("Decoding function with name %s already registered", decode_type);
-	}
-	encoding_functions->functions[decode_type] = function;
-}
-
-optional_ptr<EncodingFunction> DBConfig::GetEncodeFunction(const string &name) const {
-	lock_guard<mutex> l(encoding_functions->lock);
-	// Check if the function is already loaded into the global compression functions.
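
The decoders above share one chunked contract: consume source bytes while the target buffer has room, and when the target fills in the middle of a multi-byte UTF-8 sequence, park the unwritten continuation bytes in remaining_bytes_buffer so the caller can emit them before the next chunk. A self-contained sketch of that contract for the latin-1 case, with std::size_t standing in for idx_t:

#include <cstddef>

void DecodeLatin1Chunk(const char *src, std::size_t &src_pos, const std::size_t src_size,
                       char *dst, std::size_t &dst_pos, const std::size_t dst_size,
                       char *carry, std::size_t &carry_size) {
	for (; src_pos < src_size && dst_pos < dst_size; src_pos++) {
		const unsigned char ch = static_cast<unsigned char>(src[src_pos]);
		if (ch <= 0x7F) {
			dst[dst_pos++] = static_cast<char>(ch); // ASCII maps 1:1
			continue;
		}
		dst[dst_pos++] = static_cast<char>(0xC2 + (ch > 0xBF)); // UTF-8 lead byte
		if (dst_pos == dst_size) {
			src_pos++;                                        // the source byte is consumed
			carry[0] = static_cast<char>((ch & 0x3F) + 0x80); // park the continuation byte
			carry_size = 1;
			return;
		}
		dst[dst_pos++] = static_cast<char>((ch & 0x3F) + 0x80); // continuation byte
	}
}

The caller is expected to flush carry into the next target chunk before resuming the decode.
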
- if (encoding_functions->functions.find(name) != encoding_functions->functions.end()) { - return &encoding_functions->functions[name]; - } - return nullptr; -} - -vector> DBConfig::GetLoadedEncodedFunctions() const { - lock_guard l(encoding_functions->lock); - vector> result; - for (auto &function : encoding_functions->functions) { - result.push_back(function.second); - } - return result; -} -} // namespace duckdb diff --git a/src/duckdb/src/function/function.cpp b/src/duckdb/src/function/function.cpp deleted file mode 100644 index 3307d3c54..000000000 --- a/src/duckdb/src/function/function.cpp +++ /dev/null @@ -1,160 +0,0 @@ -#include "duckdb/function/function.hpp" - -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/hash.hpp" -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/parser/parsed_data/pragma_info.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/main/extension_entries.hpp" - -namespace duckdb { - -FunctionData::~FunctionData() { -} - -bool FunctionData::Equals(const FunctionData *left, const FunctionData *right) { - if (left == right) { - return true; - } - if (!left || !right) { - return false; - } - return left->Equals(*right); -} - -TableFunctionData::~TableFunctionData() { -} - -unique_ptr TableFunctionData::Copy() const { - throw InternalException("Copy not supported for TableFunctionData"); -} - -bool TableFunctionData::Equals(const FunctionData &other) const { - return false; -} - -Function::Function(string name_p) : name(std::move(name_p)) { -} -Function::~Function() { -} - -SimpleFunction::SimpleFunction(string name_p, vector arguments_p, LogicalType varargs_p) - : Function(std::move(name_p)), arguments(std::move(arguments_p)), varargs(std::move(varargs_p)) { -} - -SimpleFunction::~SimpleFunction() { -} - -string SimpleFunction::ToString() const { - return Function::CallToString(name, arguments, varargs); -} - -bool SimpleFunction::HasVarArgs() const { - return varargs.id() != LogicalTypeId::INVALID; -} - -SimpleNamedParameterFunction::SimpleNamedParameterFunction(string name_p, vector arguments_p, - LogicalType varargs_p) - : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)) { -} - -SimpleNamedParameterFunction::~SimpleNamedParameterFunction() { -} - -string SimpleNamedParameterFunction::ToString() const { - return Function::CallToString(name, arguments, named_parameters); -} - -bool SimpleNamedParameterFunction::HasNamedParameters() const { - return !named_parameters.empty(); -} - -BaseScalarFunction::BaseScalarFunction(string name_p, vector arguments_p, LogicalType return_type_p, - FunctionStability stability, LogicalType varargs_p, - FunctionNullHandling null_handling, FunctionErrors errors) - : SimpleFunction(std::move(name_p), std::move(arguments_p), std::move(varargs_p)), - return_type(std::move(return_type_p)), stability(stability), null_handling(null_handling), errors(errors), - collation_handling(FunctionCollationHandling::PROPAGATE_COLLATIONS) { -} - -BaseScalarFunction::~BaseScalarFunction() { -} - -string BaseScalarFunction::ToString() const { - return Function::CallToString(name, arguments, varargs, return_type); -} - -// add your initializer for new functions here -void BuiltinFunctions::Initialize() { - RegisterTableScanFunctions(); - RegisterSQLiteFunctions(); - 
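
FunctionData::Equals above is a null-tolerant comparison helper; a toy mirror of the rule it encodes (same pointer or two nulls compare equal, exactly one null compares unequal, otherwise defer to the per-object comparison):

struct BindData {
	int value;
	bool Equals(const BindData &other) const { return value == other.value; }
};

bool BindDataEquals(const BindData *left, const BindData *right) {
	if (left == right) {
		return true; // same object, or both null
	}
	if (!left || !right) {
		return false; // exactly one side is null
	}
	return left->Equals(*right);
}
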
RegisterReadFunctions(); - RegisterTableFunctions(); - RegisterArrowFunctions(); - - RegisterPragmaFunctions(); - - // initialize collations - AddCollation("nocase", LowerFun::GetFunction(), true); - AddCollation("noaccent", StripAccentsFun::GetFunction(), true); - AddCollation("nfc", NFCNormalizeFun::GetFunction()); - - RegisterExtensionOverloads(); -} - -hash_t BaseScalarFunction::Hash() const { - hash_t hash = return_type.Hash(); - for (auto &arg : arguments) { - hash = duckdb::CombineHash(hash, arg.Hash()); - } - return hash; -} - -string Function::CallToString(const string &name, const vector &arguments, const LogicalType &varargs) { - string result = name + "("; - vector string_arguments; - for (auto &arg : arguments) { - string_arguments.push_back(arg.ToString()); - } - if (varargs.IsValid()) { - string_arguments.push_back("[" + varargs.ToString() + "...]"); - } - result += StringUtil::Join(string_arguments, ", "); - return result + ")"; -} - -string Function::CallToString(const string &name, const vector &arguments, const LogicalType &varargs, - const LogicalType &return_type) { - string result = CallToString(name, arguments, varargs); - result += " -> " + return_type.ToString(); - return result; -} - -string Function::CallToString(const string &name, const vector &arguments, - const named_parameter_type_map_t &named_parameters) { - vector input_arguments; - input_arguments.reserve(arguments.size() + named_parameters.size()); - for (auto &arg : arguments) { - input_arguments.push_back(arg.ToString()); - } - for (auto &kv : named_parameters) { - input_arguments.push_back(StringUtil::Format("%s : %s", kv.first, kv.second.ToString())); - } - return StringUtil::Format("%s(%s)", name, StringUtil::Join(input_arguments, ", ")); -} - -void Function::EraseArgument(SimpleFunction &bound_function, vector> &arguments, - idx_t argument_index) { - if (bound_function.original_arguments.empty()) { - bound_function.original_arguments = bound_function.arguments; - } - D_ASSERT(arguments.size() == bound_function.arguments.size()); - D_ASSERT(argument_index < arguments.size()); - arguments.erase_at(argument_index); - bound_function.arguments.erase_at(argument_index); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/function_binder.cpp b/src/duckdb/src/function/function_binder.cpp deleted file mode 100644 index 671441825..000000000 --- a/src/duckdb/src/function/function_binder.cpp +++ /dev/null @@ -1,487 +0,0 @@ -#include "duckdb/function/function_binder.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/aggregate_function.hpp" -#include "duckdb/function/cast_rules.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/parser/parsed_data/create_secret_info.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" - -namespace duckdb { - -FunctionBinder::FunctionBinder(ClientContext &context_p) : binder(nullptr), context(context_p) { -} -FunctionBinder::FunctionBinder(Binder &binder_p) : binder(&binder_p), context(binder_p.context) { -} - -optional_idx FunctionBinder::BindVarArgsFunctionCost(const SimpleFunction &func, 
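
The CallToString overloads above render signatures for error messages. A sketch of the same rendering with plain strings standing in for LogicalType, assuming the "name(args, [varargs...]) -> return" shape shown:

#include <string>
#include <vector>

std::string RenderCall(const std::string &name, const std::vector<std::string> &args,
                       const std::string &varargs, const std::string &ret) {
	std::string result = name + "(";
	for (std::size_t i = 0; i < args.size(); i++) {
		result += (i ? ", " : "") + args[i];
	}
	if (!varargs.empty()) {
		result += (args.empty() ? "[" : ", [") + varargs + "...]"; // vararg tail
	}
	return result + ") -> " + ret;
}
// RenderCall("length", {"VARCHAR"}, "", "BIGINT") == "length(VARCHAR) -> BIGINT"
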
const vector &arguments) { - if (arguments.size() < func.arguments.size()) { - // not enough arguments to fulfill the non-vararg part of the function - return optional_idx(); - } - idx_t cost = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - LogicalType arg_type = i < func.arguments.size() ? func.arguments[i] : func.varargs; - if (arguments[i] == arg_type) { - // arguments match: do nothing - continue; - } - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], arg_type); - if (cast_cost >= 0) { - // we can implicitly cast, add the cost to the total cost - cost += idx_t(cast_cost); - } else { - // we can't implicitly cast: throw an error - return optional_idx(); - } - } - return cost; -} - -optional_idx FunctionBinder::BindFunctionCost(const SimpleFunction &func, const vector &arguments) { - if (func.HasVarArgs()) { - // special case varargs function - return BindVarArgsFunctionCost(func, arguments); - } - if (func.arguments.size() != arguments.size()) { - // invalid argument count: check the next function - return optional_idx(); - } - idx_t cost = 0; - bool has_parameter = false; - for (idx_t i = 0; i < arguments.size(); i++) { - if (arguments[i].id() == LogicalTypeId::UNKNOWN) { - has_parameter = true; - continue; - } - int64_t cast_cost = CastFunctionSet::Get(context).ImplicitCastCost(arguments[i], func.arguments[i]); - if (cast_cost >= 0) { - // we can implicitly cast, add the cost to the total cost - cost += idx_t(cast_cost); - } else { - // we can't implicitly cast: throw an error - return optional_idx(); - } - } - if (has_parameter) { - // all arguments are implicitly castable and there is a parameter - return 0 as cost - return 0; - } - return cost; -} - -template -vector FunctionBinder::BindFunctionsFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error) { - optional_idx best_function; - idx_t lowest_cost = NumericLimits::Maximum(); - vector candidate_functions; - for (idx_t f_idx = 0; f_idx < functions.functions.size(); f_idx++) { - auto &func = functions.functions[f_idx]; - // check the arguments of the function - auto bind_cost = BindFunctionCost(func, arguments); - if (!bind_cost.IsValid()) { - // auto casting was not possible - continue; - } - auto cost = bind_cost.GetIndex(); - if (cost == lowest_cost) { - candidate_functions.push_back(f_idx); - continue; - } - if (cost > lowest_cost) { - continue; - } - candidate_functions.clear(); - lowest_cost = cost; - best_function = f_idx; - } - if (!best_function.IsValid()) { - // no matching function was found, throw an error - vector candidates; - for (auto &f : functions.functions) { - candidates.push_back(f.ToString()); - } - error = ErrorData(BinderException::NoMatchingFunction(name, arguments, candidates)); - return candidate_functions; - } - candidate_functions.push_back(best_function.GetIndex()); - return candidate_functions; -} - -template -optional_idx FunctionBinder::MultipleCandidateException(const string &name, FunctionSet &functions, - vector &candidate_functions, - const vector &arguments, ErrorData &error) { - D_ASSERT(functions.functions.size() > 1); - // there are multiple possible function definitions - // throw an exception explaining which overloads are there - string call_str = Function::CallToString(name, arguments); - string candidate_str; - for (auto &conf : candidate_functions) { - T f = functions.GetFunctionByOffset(conf); - candidate_str += "\t" + f.ToString() + "\n"; - } - error = ErrorData( - ExceptionType::BINDER, - 
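
BindFunctionCost and BindFunctionsFromArguments above implement cost-based overload resolution: each candidate is priced by summing implicit-cast costs, equal-cost candidates accumulate as ties, and only a unique minimum wins. A condensed sketch with toy integer type ids; ToyCastCost is an invented stand-in for CastFunctionSet::ImplicitCastCost:

#include <cstdint>
#include <limits>
#include <vector>

// Widening (to > from) costs the distance; narrowing is disallowed.
static int64_t ToyCastCost(int from, int to) {
	if (from == to) {
		return 0;
	}
	return to > from ? to - from : -1;
}

// Returns the index of the single cheapest overload, or -1 when nothing
// matches or the minimum is shared (the MultipleCandidateException case).
int BindOverload(const std::vector<std::vector<int>> &overloads, const std::vector<int> &args) {
	int best = -1;
	std::size_t ties = 0;
	int64_t lowest = std::numeric_limits<int64_t>::max();
	for (std::size_t f = 0; f < overloads.size(); f++) {
		if (overloads[f].size() != args.size()) {
			continue; // wrong arity
		}
		int64_t cost = 0;
		bool castable = true;
		for (std::size_t i = 0; i < args.size(); i++) {
			auto c = ToyCastCost(args[i], overloads[f][i]);
			if (c < 0) { castable = false; break; } // no implicit cast exists
			cost += c;
		}
		if (!castable || cost > lowest) {
			continue;
		}
		if (cost == lowest) { ties++; continue; }
		lowest = cost;
		best = static_cast<int>(f);
		ties = 1;
	}
	return ties == 1 ? best : -1;
}
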
StringUtil::Format("Could not choose a best candidate function for the function call \"%s\". In order to " - "select one, please add explicit type casts.\n\tCandidate functions:\n%s", - call_str, candidate_str)); - return optional_idx(); -} - -template -optional_idx FunctionBinder::BindFunctionFromArguments(const string &name, FunctionSet &functions, - const vector &arguments, ErrorData &error) { - auto candidate_functions = BindFunctionsFromArguments(name, functions, arguments, error); - if (candidate_functions.empty()) { - // no candidates - return optional_idx(); - } - if (candidate_functions.size() > 1) { - // multiple candidates, check if there are any unknown arguments - bool has_parameters = false; - for (auto &arg_type : arguments) { - if (arg_type.id() == LogicalTypeId::UNKNOWN) { - //! there are! we could not resolve parameters in this case - throw ParameterNotResolvedException(); - } - } - if (!has_parameters) { - return MultipleCandidateException(name, functions, candidate_functions, arguments, error); - } - } - return candidate_functions[0]; -} - -optional_idx FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -optional_idx FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -optional_idx FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - const vector &arguments, ErrorData &error) { - return BindFunctionFromArguments(name, functions, arguments, error); -} - -optional_idx FunctionBinder::BindFunction(const string &name, PragmaFunctionSet &functions, vector ¶meters, - ErrorData &error) { - vector types; - for (auto &value : parameters) { - types.push_back(value.type()); - } - auto entry = BindFunctionFromArguments(name, functions, types, error); - if (!entry.IsValid()) { - error.Throw(); - } - auto candidate_function = functions.GetFunctionByOffset(entry.GetIndex()); - // cast the input parameters - for (idx_t i = 0; i < parameters.size(); i++) { - auto target_type = - i < candidate_function.arguments.size() ? 
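
The pragma BindFunction above casts each parameter to the matched overload's declared argument type, falling back to the vararg type past the end of the declared list. The index rule in miniature, with toy integer type ids:

#include <cstddef>
#include <vector>

// Arguments beyond the declared list are cast to the vararg type.
int TargetType(const std::vector<int> &declared, int varargs, std::size_t i) {
	return i < declared.size() ? declared[i] : varargs;
}
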
candidate_function.arguments[i] : candidate_function.varargs; - parameters[i] = parameters[i].CastAs(context, target_type); - } - return entry; -} - -vector FunctionBinder::GetLogicalTypesFromExpressions(vector> &arguments) { - vector types; - types.reserve(arguments.size()); - for (auto &argument : arguments) { - types.push_back(ExpressionBinder::GetExpressionReturnType(*argument)); - } - return types; -} - -optional_idx FunctionBinder::BindFunction(const string &name, ScalarFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -optional_idx FunctionBinder::BindFunction(const string &name, AggregateFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -optional_idx FunctionBinder::BindFunction(const string &name, TableFunctionSet &functions, - vector> &arguments, ErrorData &error) { - auto types = GetLogicalTypesFromExpressions(arguments); - return BindFunction(name, functions, types, error); -} - -enum class LogicalTypeComparisonResult : uint8_t { IDENTICAL_TYPE, TARGET_IS_ANY, DIFFERENT_TYPES }; - -LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const LogicalType &target_type) { - if (target_type.id() == LogicalTypeId::ANY) { - return LogicalTypeComparisonResult::TARGET_IS_ANY; - } - if (source_type == target_type) { - return LogicalTypeComparisonResult::IDENTICAL_TYPE; - } - if (source_type.id() == LogicalTypeId::LIST && target_type.id() == LogicalTypeId::LIST) { - return RequiresCast(ListType::GetChildType(source_type), ListType::GetChildType(target_type)); - } - if (source_type.id() == LogicalTypeId::ARRAY && target_type.id() == LogicalTypeId::ARRAY) { - return RequiresCast(ArrayType::GetChildType(source_type), ArrayType::GetChildType(target_type)); - } - return LogicalTypeComparisonResult::DIFFERENT_TYPES; -} - -bool TypeRequiresPrepare(const LogicalType &type) { - if (type.id() == LogicalTypeId::ANY) { - return true; - } - if (type.id() == LogicalTypeId::LIST) { - return TypeRequiresPrepare(ListType::GetChildType(type)); - } - return false; -} - -LogicalType PrepareTypeForCastRecursive(const LogicalType &type) { - if (type.id() == LogicalTypeId::ANY) { - return AnyType::GetTargetType(type); - } - if (type.id() == LogicalTypeId::LIST) { - return LogicalType::LIST(PrepareTypeForCastRecursive(ListType::GetChildType(type))); - } - return type; -} - -void PrepareTypeForCast(LogicalType &type) { - if (!TypeRequiresPrepare(type)) { - return; - } - type = PrepareTypeForCastRecursive(type); -} - -void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector> &children) { - for (auto &arg : function.arguments) { - PrepareTypeForCast(arg); - } - PrepareTypeForCast(function.varargs); - - for (idx_t i = 0; i < children.size(); i++) { - auto target_type = i < function.arguments.size() ? 
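
RequiresCast above compares types structurally: ANY absorbs everything, and LIST/ARRAY recurse into their child types. A toy version of that recursion over an assumed minimal Type node:

#include <memory>

struct Type {
	enum Kind { INT, VARCHAR, ANY, LIST } kind;
	std::shared_ptr<Type> child; // set when kind == LIST
};

bool NeedsCast(const Type &src, const Type &dst) {
	if (dst.kind == Type::ANY) {
		return false; // ANY accepts any source type without a cast
	}
	if (src.kind == Type::LIST && dst.kind == Type::LIST) {
		return NeedsCast(*src.child, *dst.child); // recurse into the element type
	}
	return src.kind != dst.kind;
}
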
function.arguments[i] : function.varargs; - if (target_type.id() == LogicalTypeId::STRING_LITERAL || target_type.id() == LogicalTypeId::INTEGER_LITERAL) { - throw InternalException( - "Function %s returned a STRING_LITERAL or INTEGER_LITERAL type - return an explicit type instead", - function.name); - } - target_type.Verify(); - // don't cast lambda children, they get removed before execution - if (children[i]->return_type.id() == LogicalTypeId::LAMBDA) { - continue; - } - // check if the type of child matches the type of function argument - // if not we need to add a cast - auto cast_result = RequiresCast(children[i]->return_type, target_type); - // except for one special case: if the function accepts ANY argument - // in that case we don't add a cast - if (cast_result == LogicalTypeComparisonResult::DIFFERENT_TYPES) { - children[i] = BoundCastExpression::AddCastToType(context, std::move(children[i]), target_type); - } - } -} - -unique_ptr FunctionBinder::BindScalarFunction(const string &schema, const string &name, - vector> children, ErrorData &error, - bool is_operator, optional_ptr binder) { - // bind the function - auto &function = - Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, schema, name); - D_ASSERT(function.type == CatalogType::SCALAR_FUNCTION_ENTRY); - return BindScalarFunction(function.Cast(), std::move(children), error, is_operator, - binder); -} - -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunctionCatalogEntry &func, - vector> children, ErrorData &error, - bool is_operator, optional_ptr binder) { - // bind the function - auto best_function = BindFunction(func.name, func.functions, children, error); - if (!best_function.IsValid()) { - return nullptr; - } - - // found a matching function! - auto bound_function = func.functions.GetFunctionByOffset(best_function.GetIndex()); - - // If any of the parameters are NULL, the function will just be replaced with a NULL constant. - // We try to give the NULL constant the correct type, but we have to do this without binding the function, - // because functions with DEFAULT_NULL_HANDLING should not have to deal with NULL inputs in their bind code. - // Some functions may have an invalid default return type, as they must be bound to infer the return type. - // In those cases, we default to SQLNULL. - const auto return_type_if_null = - bound_function.return_type.IsComplete() ? 
bound_function.return_type : LogicalType::SQLNULL; - if (bound_function.null_handling == FunctionNullHandling::DEFAULT_NULL_HANDLING) { - for (auto &child : children) { - if (child->return_type == LogicalTypeId::SQLNULL) { - return make_uniq(Value(return_type_if_null)); - } - if (!child->IsFoldable()) { - continue; - } - Value result; - if (!ExpressionExecutor::TryEvaluateScalar(context, *child, result)) { - continue; - } - if (result.IsNull()) { - return make_uniq(Value(return_type_if_null)); - } - } - } - return BindScalarFunction(bound_function, std::move(children), is_operator, binder); -} - -bool RequiresCollationPropagation(const LogicalType &type) { - return type.id() == LogicalTypeId::VARCHAR && !type.HasAlias(); -} - -string ExtractCollation(const vector> &children) { - string collation; - for (auto &arg : children) { - if (!RequiresCollationPropagation(arg->return_type)) { - // not a varchar column - continue; - } - auto child_collation = StringType::GetCollation(arg->return_type); - if (collation.empty()) { - collation = child_collation; - } else if (!child_collation.empty() && collation != child_collation) { - throw BinderException("Cannot combine types with different collation!"); - } - } - return collation; -} - -void PropagateCollations(ClientContext &, ScalarFunction &bound_function, vector> &children) { - if (!RequiresCollationPropagation(bound_function.return_type)) { - // we only need to propagate if the function returns a varchar - return; - } - auto collation = ExtractCollation(children); - if (collation.empty()) { - // no collation to propagate - return; - } - // propagate the collation to the return type - auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - bound_function.return_type = std::move(collation_type); -} - -void PushCollations(ClientContext &context, ScalarFunction &bound_function, vector> &children, - CollationType type) { - auto collation = ExtractCollation(children); - if (collation.empty()) { - // no collation to push - return; - } - // push collation into the return type if required - auto collation_type = LogicalType::VARCHAR_COLLATION(std::move(collation)); - if (RequiresCollationPropagation(bound_function.return_type)) { - bound_function.return_type = collation_type; - } - // push collations to the children - for (auto &arg : children) { - if (RequiresCollationPropagation(arg->return_type)) { - // if this is a varchar type - propagate the collation - arg->return_type = collation_type; - } - // now push the actual collation handling - ExpressionBinder::PushCollation(context, arg, arg->return_type, type); - } -} - -void HandleCollations(ClientContext &context, ScalarFunction &bound_function, - vector> &children) { - switch (bound_function.collation_handling) { - case FunctionCollationHandling::IGNORE_COLLATIONS: - // explicitly ignoring collation handling - break; - case FunctionCollationHandling::PROPAGATE_COLLATIONS: - PropagateCollations(context, bound_function, children); - break; - case FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS: - // first propagate, then push collations to the children - PushCollations(context, bound_function, children, CollationType::COMBINABLE_COLLATIONS); - break; - default: - throw InternalException("Unrecognized collation handling"); - } -} - -unique_ptr FunctionBinder::BindScalarFunction(ScalarFunction bound_function, - vector> children, bool is_operator, - optional_ptr binder) { - unique_ptr bind_info; - - if (bound_function.bind) { - bind_info = bound_function.bind(context, 
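
ExtractCollation above combines the collations of all VARCHAR children: the first non-empty collation wins, and two different non-empty collations raise an error. The same rule sketched over plain strings, with std::runtime_error standing in for BinderException:

#include <stdexcept>
#include <string>
#include <vector>

std::string CombineCollations(const std::vector<std::string> &child_collations) {
	std::string collation;
	for (auto &c : child_collations) {
		if (c.empty()) {
			continue; // child carries no collation
		}
		if (collation.empty()) {
			collation = c; // first collation wins
		} else if (collation != c) {
			throw std::runtime_error("Cannot combine types with different collation!");
		}
	}
	return collation;
}
// CombineCollations({"", "nocase", "nocase"}) == "nocase"
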
bound_function, children); - } else if (bound_function.bind_extended) { - if (!binder) { - throw InternalException("Function '%s' has a 'bind_extended' but the FunctionBinder was created without " - "a reference to a Binder", - bound_function.name); - } - ScalarFunctionBindInput bind_input(*binder); - bind_info = bound_function.bind_extended(bind_input, bound_function, children); - } - - if (bound_function.get_modified_databases && binder) { - auto &properties = binder->GetStatementProperties(); - FunctionModifiedDatabasesInput input(bind_info, properties); - bound_function.get_modified_databases(context, input); - } - HandleCollations(context, bound_function, children); - - // check if we need to add casts to the children - CastToFunctionArguments(bound_function, children); - - auto return_type = bound_function.return_type; - unique_ptr result; - auto result_func = make_uniq(std::move(return_type), std::move(bound_function), - std::move(children), std::move(bind_info), is_operator); - if (result_func->function.bind_expression) { - // if a bind_expression callback is registered - call it and emit the resulting expression - FunctionBindExpressionInput input(context, result_func->bind_info.get(), *result_func); - result = result_func->function.bind_expression(input); - } - if (!result) { - result = std::move(result_func); - } - return result; -} - -unique_ptr FunctionBinder::BindAggregateFunction(AggregateFunction bound_function, - vector> children, - unique_ptr filter, - AggregateType aggr_type) { - unique_ptr bind_info; - if (bound_function.bind) { - bind_info = bound_function.bind(context, bound_function, children); - // we may have lost some arguments in the bind - children.resize(MinValue(bound_function.arguments.size(), children.size())); - } - - // check if we need to add casts to the children - CastToFunctionArguments(bound_function, children); - - return make_uniq(std::move(bound_function), std::move(children), std::move(filter), - std::move(bind_info), aggr_type); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/function_list.cpp b/src/duckdb/src/function/function_list.cpp deleted file mode 100644 index 70d869703..000000000 --- a/src/duckdb/src/function/function_list.cpp +++ /dev/null @@ -1,177 +0,0 @@ -#include "duckdb/function/function_list.hpp" - -#include "duckdb/function/aggregate/distributive_functions.hpp" -#include "duckdb/function/scalar/compressed_materialization_functions.hpp" -#include "duckdb/function/scalar/date_functions.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/function/scalar/map_functions.hpp" -#include "duckdb/function/scalar/operator_functions.hpp" -#include "duckdb/function/scalar/sequence_functions.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar/struct_functions.hpp" -#include "duckdb/function/scalar/system_functions.hpp" -#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" - -namespace duckdb { - -// Scalar Function -#define DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, _PARAM::GetFunction, nullptr, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Scalar Function Set -#define 
DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, _PARAM::GetFunctions, nullptr, nullptr } -#define DUCKDB_SCALAR_FUNCTION_SET(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_SCALAR_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_SCALAR_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function -#define DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, _PARAM::GetFunction, nullptr } -#define DUCKDB_AGGREGATE_FUNCTION(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_BASE(_PARAM::ALIAS, _PARAM::Name) -// Aggregate Function Set -#define DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _NAME) \ - { _NAME, _PARAM::Parameters, _PARAM::Description, _PARAM::Example, nullptr, nullptr, nullptr, _PARAM::GetFunctions } -#define DUCKDB_AGGREGATE_FUNCTION_SET(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM, _PARAM::Name) -#define DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(_PARAM) DUCKDB_AGGREGATE_FUNCTION_SET_BASE(_PARAM::ALIAS, _PARAM::Name) -#define FINAL_FUNCTION \ - { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr } - -// this list is generated by scripts/generate_functions.py -static const StaticFunctionDefinition function[] = { - DUCKDB_SCALAR_FUNCTION(NotLikeFun), - DUCKDB_SCALAR_FUNCTION(NotILikeFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorModuloFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorMultiplyFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorAddFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorSubtractFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorFloatDivideFun), - DUCKDB_SCALAR_FUNCTION_SET(OperatorIntegerDivideFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUbigintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUintegerFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUsmallintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalCompressIntegralUtinyintFun), - DUCKDB_SCALAR_FUNCTION(InternalCompressStringHugeintFun), - DUCKDB_SCALAR_FUNCTION(InternalCompressStringUbigintFun), - DUCKDB_SCALAR_FUNCTION(InternalCompressStringUintegerFun), - DUCKDB_SCALAR_FUNCTION(InternalCompressStringUsmallintFun), - DUCKDB_SCALAR_FUNCTION(InternalCompressStringUtinyintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralBigintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralHugeintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralIntegerFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralSmallintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUbigintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUhugeintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUintegerFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressIntegralUsmallintFun), - DUCKDB_SCALAR_FUNCTION_SET(InternalDecompressStringFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(AddFun), - DUCKDB_AGGREGATE_FUNCTION_SET(AnyValueFun), - DUCKDB_AGGREGATE_FUNCTION_SET_ALIAS(ArbitraryFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayCatFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayConcatFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayContainsFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayExtractFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayHasFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayIndexofFun), - DUCKDB_SCALAR_FUNCTION_SET(ArrayLengthFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayPositionFun), - 
DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ArrayResizeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArraySelectFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayWhereFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ArrayZipFun), - DUCKDB_SCALAR_FUNCTION_SET(BitLengthFun), - DUCKDB_SCALAR_FUNCTION(CombineFun), - DUCKDB_SCALAR_FUNCTION(ConcatFun), - DUCKDB_SCALAR_FUNCTION(ConcatWsFun), - DUCKDB_SCALAR_FUNCTION(ConstantOrNullFun), - DUCKDB_SCALAR_FUNCTION_SET(ContainsFun), - DUCKDB_AGGREGATE_FUNCTION_SET(CountFun), - DUCKDB_AGGREGATE_FUNCTION(CountStarFun), - DUCKDB_SCALAR_FUNCTION(CreateSortKeyFun), - DUCKDB_SCALAR_FUNCTION(CurrvalFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(DivideFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(EndsWithFun), - DUCKDB_SCALAR_FUNCTION(ErrorFun), - DUCKDB_SCALAR_FUNCTION(FinalizeFun), - DUCKDB_AGGREGATE_FUNCTION_SET(FirstFun), - DUCKDB_SCALAR_FUNCTION(GetVariableFun), - DUCKDB_SCALAR_FUNCTION(IlikeEscapeFun), - DUCKDB_AGGREGATE_FUNCTION_SET(LastFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(LcaseFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(LenFun), - DUCKDB_SCALAR_FUNCTION_SET(LengthFun), - DUCKDB_SCALAR_FUNCTION_SET(LengthGraphemeFun), - DUCKDB_SCALAR_FUNCTION(LikeEscapeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListCatFun), - DUCKDB_SCALAR_FUNCTION(ListConcatFun), - DUCKDB_SCALAR_FUNCTION(ListContainsFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ListElementFun), - DUCKDB_SCALAR_FUNCTION_SET(ListExtractFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListHasFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(ListIndexofFun), - DUCKDB_SCALAR_FUNCTION(ListPositionFun), - DUCKDB_SCALAR_FUNCTION_SET(ListResizeFun), - DUCKDB_SCALAR_FUNCTION(ListSelectFun), - DUCKDB_SCALAR_FUNCTION(ListWhereFun), - DUCKDB_SCALAR_FUNCTION(ListZipFun), - DUCKDB_SCALAR_FUNCTION(LowerFun), - DUCKDB_SCALAR_FUNCTION(MapContainsFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MaxFun), - DUCKDB_SCALAR_FUNCTION_SET(MD5Fun), - DUCKDB_SCALAR_FUNCTION_SET(MD5NumberFun), - DUCKDB_AGGREGATE_FUNCTION_SET(MinFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(ModFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(MultiplyFun), - DUCKDB_SCALAR_FUNCTION(NextvalFun), - DUCKDB_SCALAR_FUNCTION(NFCNormalizeFun), - DUCKDB_SCALAR_FUNCTION(NotIlikeEscapeFun), - DUCKDB_SCALAR_FUNCTION(NotLikeEscapeFun), - DUCKDB_SCALAR_FUNCTION_SET(OctetLengthFun), - DUCKDB_SCALAR_FUNCTION(PrefixFun), - DUCKDB_SCALAR_FUNCTION(RegexpEscapeFun), - DUCKDB_SCALAR_FUNCTION_SET(RegexpExtractFun), - DUCKDB_SCALAR_FUNCTION_SET(RegexpExtractAllFun), - DUCKDB_SCALAR_FUNCTION_SET(RegexpFun), - DUCKDB_SCALAR_FUNCTION_SET(RegexpMatchesFun), - DUCKDB_SCALAR_FUNCTION_SET(RegexpReplaceFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(RegexpSplitToArrayFun), - DUCKDB_SCALAR_FUNCTION(RowFun), - DUCKDB_SCALAR_FUNCTION_SET(SHA1Fun), - DUCKDB_SCALAR_FUNCTION_SET(SHA256Fun), - DUCKDB_SCALAR_FUNCTION_ALIAS(SplitFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StrSplitFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(StrSplitRegexFun), - DUCKDB_SCALAR_FUNCTION_SET(StrfTimeFun), - DUCKDB_SCALAR_FUNCTION(StringSplitFun), - DUCKDB_SCALAR_FUNCTION_SET(StringSplitRegexFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(StringToArrayFun), - DUCKDB_SCALAR_FUNCTION(StripAccentsFun), - DUCKDB_SCALAR_FUNCTION(StrlenFun), - DUCKDB_SCALAR_FUNCTION_SET(StrpTimeFun), - DUCKDB_SCALAR_FUNCTION(StructConcatFun), - DUCKDB_SCALAR_FUNCTION_SET(StructExtractFun), - DUCKDB_SCALAR_FUNCTION_SET(StructExtractAtFun), - DUCKDB_SCALAR_FUNCTION(StructPackFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(SubstrFun), - DUCKDB_SCALAR_FUNCTION_SET(SubstringFun), - DUCKDB_SCALAR_FUNCTION_SET(SubstringGraphemeFun), - DUCKDB_SCALAR_FUNCTION_SET_ALIAS(SubtractFun), - 
DUCKDB_SCALAR_FUNCTION(SuffixFun), - DUCKDB_SCALAR_FUNCTION_SET(TryStrpTimeFun), - DUCKDB_SCALAR_FUNCTION_ALIAS(UcaseFun), - DUCKDB_SCALAR_FUNCTION(UpperFun), - DUCKDB_SCALAR_FUNCTION(ConcatOperatorFun), - DUCKDB_SCALAR_FUNCTION(LikeFun), - DUCKDB_SCALAR_FUNCTION(ILikeFun), - DUCKDB_SCALAR_FUNCTION(GlobPatternFun), - FINAL_FUNCTION -}; - -const StaticFunctionDefinition *FunctionList::GetInternalFunctionList() { - return function; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/function_set.cpp b/src/duckdb/src/function/function_set.cpp deleted file mode 100644 index cf48c14e4..000000000 --- a/src/duckdb/src/function/function_set.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include "duckdb/function/function_set.hpp" -#include "duckdb/function/function_binder.hpp" - -namespace duckdb { - -ScalarFunctionSet::ScalarFunctionSet() : FunctionSet("") { -} - -ScalarFunctionSet::ScalarFunctionSet(string name) : FunctionSet(std::move(name)) { -} - -ScalarFunctionSet::ScalarFunctionSet(ScalarFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -ScalarFunction ScalarFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { - ErrorData error; - FunctionBinder binder(context); - auto index = binder.BindFunction(name, *this, arguments, error); - if (!index.IsValid()) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); - } - return GetFunctionByOffset(index.GetIndex()); -} - -AggregateFunctionSet::AggregateFunctionSet() : FunctionSet("") { -} - -AggregateFunctionSet::AggregateFunctionSet(string name) : FunctionSet(std::move(name)) { -} - -AggregateFunctionSet::AggregateFunctionSet(AggregateFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -AggregateFunction AggregateFunctionSet::GetFunctionByArguments(ClientContext &context, - const vector &arguments) { - ErrorData error; - FunctionBinder binder(context); - auto index = binder.BindFunction(name, *this, arguments, error); - if (!index.IsValid()) { - // check if the arguments are a prefix of any of the arguments - // this is used for functions such as quantile or string_agg that delete part of their arguments during bind - // FIXME: we should come up with a better solution here - for (auto &func : functions) { - if (arguments.size() >= func.arguments.size()) { - continue; - } - bool is_prefix = true; - for (idx_t k = 0; k < arguments.size(); k++) { - if (arguments[k].id() != func.arguments[k].id()) { - is_prefix = false; - break; - } - } - if (is_prefix) { - return func; - } - } - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); - } - return GetFunctionByOffset(index.GetIndex()); -} - -TableFunctionSet::TableFunctionSet(string name) : FunctionSet(std::move(name)) { -} - -TableFunctionSet::TableFunctionSet(TableFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -TableFunction TableFunctionSet::GetFunctionByArguments(ClientContext &context, const vector &arguments) { - ErrorData error; - FunctionBinder binder(context); - auto index = binder.BindFunction(name, *this, arguments, error); - if (!index.IsValid()) { - throw InternalException("Failed to find function %s(%s)\n%s", name, StringUtil::ToString(arguments, ","), - error.Message()); - } - return GetFunctionByOffset(index.GetIndex()); -} - -PragmaFunctionSet::PragmaFunctionSet(string name) 
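
AggregateFunctionSet::GetFunctionByArguments above has a fallback for aggregates that erase trailing arguments during bind (quantile, string_agg): a function still matches when the caller's argument types form a strict prefix of its declared ones. The prefix test in isolation, with toy integer type ids:

#include <cstddef>
#include <vector>

bool IsArgumentPrefix(const std::vector<int> &args, const std::vector<int> &func_args) {
	if (args.size() >= func_args.size()) {
		return false; // must be strictly shorter than the declared list
	}
	for (std::size_t i = 0; i < args.size(); i++) {
		if (args[i] != func_args[i]) {
			return false;
		}
	}
	return true;
}
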
: FunctionSet(std::move(name)) { -} - -PragmaFunctionSet::PragmaFunctionSet(PragmaFunction fun) : FunctionSet(std::move(fun.name)) { - functions.push_back(std::move(fun)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/macro_function.cpp b/src/duckdb/src/function/macro_function.cpp deleted file mode 100644 index d70150c88..000000000 --- a/src/duckdb/src/function/macro_function.cpp +++ /dev/null @@ -1,138 +0,0 @@ - -#include "duckdb/function/macro_function.hpp" - -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/scalar_macro_function.hpp" -#include "duckdb/function/table_macro_function.hpp" -#include "duckdb/parser/expression/columnref_expression.hpp" -#include "duckdb/parser/expression/comparison_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" - -namespace duckdb { - -MacroFunction::MacroFunction(MacroType type) : type(type) { -} - -string FormatMacroFunction(MacroFunction &function, const string &name) { - string result; - result = name + "("; - string parameters; - for (auto ¶m : function.parameters) { - if (!parameters.empty()) { - parameters += ", "; - } - parameters += param->Cast().GetColumnName(); - } - for (auto &named_param : function.default_parameters) { - if (!parameters.empty()) { - parameters += ", "; - } - parameters += named_param.first; - parameters += " := "; - parameters += named_param.second->ToString(); - } - result += parameters + ")"; - return result; -} - -MacroBindResult MacroFunction::BindMacroFunction(const vector> &functions, const string &name, - FunctionExpression &function_expr, - vector> &positionals, - unordered_map> &defaults) { - // separate positional and default arguments - for (auto &arg : function_expr.children) { - if (!arg->GetAlias().empty()) { - // default argument - if (defaults.count(arg->GetAlias())) { - return MacroBindResult(StringUtil::Format("Duplicate default parameters %s!", arg->GetAlias())); - } - defaults[arg->GetAlias()] = std::move(arg); - } else if (!defaults.empty()) { - return MacroBindResult("Positional parameters cannot come after parameters with a default value!"); - } else { - // positional argument - positionals.push_back(std::move(arg)); - } - } - - // check for each macro function if it matches the number of positional arguments - optional_idx result_idx; - for (idx_t function_idx = 0; function_idx < functions.size(); function_idx++) { - if (functions[function_idx]->parameters.size() == positionals.size()) { - // found a matching function - result_idx = function_idx; - break; - } - } - if (!result_idx.IsValid()) { - // no matching function found - string error; - if (functions.size() == 1) { - // we only have one function - print the old more detailed error message - auto ¯o_def = *functions[0]; - auto ¶meters = macro_def.parameters; - error = StringUtil::Format("Macro function %s requires ", FormatMacroFunction(macro_def, name)); - error += parameters.size() == 1 ? "a single positional argument" - : StringUtil::Format("%i positional arguments", parameters.size()); - error += ", but "; - error += positionals.size() == 1 ? 
"a single positional argument was" - : StringUtil::Format("%i positional arguments were", positionals.size()); - error += " provided."; - } else { - // we have multiple functions - list all candidates - error += StringUtil::Format("Macro \"%s\" does not support %llu parameters.\n", name, positionals.size()); - error += "Candidate macros:"; - for (auto &function : functions) { - error += "\n\t" + FormatMacroFunction(*function, name); - } - } - return MacroBindResult(error); - } - // found a matching index - check if the default values exist within the macro - auto macro_idx = result_idx.GetIndex(); - auto ¯o_def = *functions[macro_idx]; - for (auto &default_val : defaults) { - auto entry = macro_def.default_parameters.find(default_val.first); - if (entry == macro_def.default_parameters.end()) { - string error = - StringUtil::Format("Macro \"%s\" does not have a named parameter \"%s\"\n", name, default_val.first); - error += "\nMacro definition: " + FormatMacroFunction(macro_def, name); - return MacroBindResult(error); - } - } - // Add the default values for parameters that have defaults, that were not explicitly assigned to - for (auto it = macro_def.default_parameters.begin(); it != macro_def.default_parameters.end(); it++) { - auto ¶meter_name = it->first; - auto ¶meter_default = it->second; - if (!defaults.count(parameter_name)) { - // This parameter was not set yet, set it with the default value - defaults[parameter_name] = parameter_default->Copy(); - } - } - return MacroBindResult(macro_idx); -} - -void MacroFunction::CopyProperties(MacroFunction &other) const { - other.type = type; - for (auto ¶m : parameters) { - other.parameters.push_back(param->Copy()); - } - for (auto &kv : default_parameters) { - other.default_parameters[kv.first] = kv.second->Copy(); - } -} - -string MacroFunction::ToSQL() const { - vector param_strings; - for (auto ¶m : parameters) { - param_strings.push_back(param->ToString()); - } - for (auto &named_param : default_parameters) { - param_strings.push_back(StringUtil::Format("%s := %s", named_param.first, named_param.second->ToString())); - } - return StringUtil::Format("(%s) AS ", StringUtil::Join(param_strings, ", ")); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/pragma/pragma_functions.cpp b/src/duckdb/src/function/pragma/pragma_functions.cpp deleted file mode 100644 index 635828066..000000000 --- a/src/duckdb/src/function/pragma/pragma_functions.cpp +++ /dev/null @@ -1,163 +0,0 @@ -#include "duckdb/function/pragma/pragma_functions.hpp" - -#include "duckdb/common/enums/output_type.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/logging/http_logger.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/query_profiler.hpp" -#include "duckdb/main/secret/secret_manager.hpp" -#include "duckdb/parallel/task_scheduler.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/storage/buffer_manager.hpp" -#include "duckdb/storage/storage_manager.hpp" - -#include - -namespace duckdb { - -static void PragmaEnableProfilingStatement(ClientContext &context, const FunctionParameters ¶meters) { - auto &config = ClientConfig::GetConfig(context); - config.enable_profiler = true; - config.emit_profiler_output = true; -} - -void RegisterEnableProfiling(BuiltinFunctions &set) { - PragmaFunctionSet functions(""); - functions.AddFunction(PragmaFunction::PragmaStatement(string(), PragmaEnableProfilingStatement)); - - 
set.AddFunction("enable_profile", functions); - set.AddFunction("enable_profiling", functions); -} - -static void PragmaDisableProfiling(ClientContext &context, const FunctionParameters ¶meters) { - auto &config = ClientConfig::GetConfig(context); - config.enable_profiler = false; -} - -static void PragmaEnableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_progress_bar = true; -} - -static void PragmaDisableProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_progress_bar = false; -} - -static void PragmaEnablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).print_progress_bar = true; -} - -static void PragmaDisablePrintProgressBar(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).print_progress_bar = false; -} - -static void PragmaEnableVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).query_verification_enabled = true; - ClientConfig::GetConfig(context).verify_serializer = true; -} - -static void PragmaDisableVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).query_verification_enabled = false; - ClientConfig::GetConfig(context).verify_serializer = false; -} - -static void PragmaVerifySerializer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_serializer = true; -} - -static void PragmaDisableVerifySerializer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_serializer = false; -} - -static void PragmaEnableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_external = true; -} - -static void PragmaDisableExternalVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_external = false; -} - -static void PragmaEnableFetchRowVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_fetch_row = true; -} - -static void PragmaDisableFetchRowVerification(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_fetch_row = false; -} - -static void PragmaEnableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_parallelism = true; -} - -static void PragmaForceCheckpoint(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.force_checkpoint = true; -} - -static void PragmaDisableForceParallelism(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).verify_parallelism = false; -} - -static void PragmaEnableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { -} - -static void PragmaDisableObjectCache(ClientContext &context, const FunctionParameters ¶meters) { -} - -static void PragmaEnableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.checkpoint_on_shutdown = true; -} - -static void PragmaDisableCheckpointOnShutdown(ClientContext &context, const FunctionParameters ¶meters) { - DBConfig::GetConfig(context).options.checkpoint_on_shutdown = false; -} - -static void 
PragmaEnableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = true; -} - -static void PragmaDisableOptimizer(ClientContext &context, const FunctionParameters ¶meters) { - ClientConfig::GetConfig(context).enable_optimizer = false; -} - -void PragmaFunctions::RegisterFunction(BuiltinFunctions &set) { - RegisterEnableProfiling(set); - - set.AddFunction(PragmaFunction::PragmaStatement("disable_profile", PragmaDisableProfiling)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_profiling", PragmaDisableProfiling)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_verification", PragmaEnableVerification)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verification", PragmaDisableVerification)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_external", PragmaEnableExternalVerification)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_external", PragmaDisableExternalVerification)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_fetch_row", PragmaEnableFetchRowVerification)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_fetch_row", PragmaDisableFetchRowVerification)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_serializer", PragmaVerifySerializer)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_serializer", PragmaDisableVerifySerializer)); - - set.AddFunction(PragmaFunction::PragmaStatement("verify_parallelism", PragmaEnableForceParallelism)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_verify_parallelism", PragmaDisableForceParallelism)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_object_cache", PragmaEnableObjectCache)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_object_cache", PragmaDisableObjectCache)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_optimizer", PragmaEnableOptimizer)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_optimizer", PragmaDisableOptimizer)); - - set.AddFunction(PragmaFunction::PragmaStatement("force_checkpoint", PragmaForceCheckpoint)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_progress_bar", PragmaEnableProgressBar)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_progress_bar", PragmaDisableProgressBar)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_print_progress_bar", PragmaEnablePrintProgressBar)); - set.AddFunction(PragmaFunction::PragmaStatement("disable_print_progress_bar", PragmaDisablePrintProgressBar)); - - set.AddFunction(PragmaFunction::PragmaStatement("enable_checkpoint_on_shutdown", PragmaEnableCheckpointOnShutdown)); - set.AddFunction( - PragmaFunction::PragmaStatement("disable_checkpoint_on_shutdown", PragmaDisableCheckpointOnShutdown)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/pragma/pragma_queries.cpp b/src/duckdb/src/function/pragma/pragma_queries.cpp deleted file mode 100644 index 66dd6ace7..000000000 --- a/src/duckdb/src/function/pragma/pragma_queries.cpp +++ /dev/null @@ -1,222 +0,0 @@ -#include "duckdb/catalog/catalog_search_path.hpp" -#include "duckdb/common/constants.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/pragma/pragma_functions.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/main/extension_helper.hpp" -#include 
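
The file above registers each toggle as a paired enable_/disable_ pragma whose body flips one config flag. The registration idea in miniature, with an assumed toy Config and registry map:

#include <functional>
#include <map>
#include <stdexcept>
#include <string>

struct ToyConfig { bool enable_optimizer = true; };

using PragmaFn = std::function<void(ToyConfig &)>;
static std::map<std::string, PragmaFn> pragma_registry = {
    {"enable_optimizer",  [](ToyConfig &c) { c.enable_optimizer = true; }},
    {"disable_optimizer", [](ToyConfig &c) { c.enable_optimizer = false; }},
};

void RunPragma(ToyConfig &config, const std::string &name) {
	pragma_registry.at(name)(config); // throws std::out_of_range for unknown pragmas
}
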
"duckdb/parser/parser.hpp" -#include "duckdb/parser/qualified_name.hpp" -#include "duckdb/parser/statement/copy_statement.hpp" -#include "duckdb/parser/statement/export_statement.hpp" - -namespace duckdb { - -string PragmaTableInfo(ClientContext &context, const FunctionParameters ¶meters) { - return StringUtil::Format("SELECT * FROM pragma_table_info(%s);", - KeywordHelper::WriteQuoted(parameters.values[0].ToString(), '\'')); -} - -string PragmaShowTables() { - // clang-format off - return R"EOF( - with "tables" as - ( - SELECT table_name as "name" - FROM duckdb_tables - where in_search_path(database_name, schema_name) - ), "views" as - ( - SELECT view_name as "name" - FROM duckdb_views - where in_search_path(database_name, schema_name) - ), db_objects as - ( - SELECT "name" FROM "tables" - UNION ALL - SELECT "name" FROM "views" - ) - SELECT "name" - FROM db_objects - ORDER BY "name";)EOF"; - // clang-format on -} - -string PragmaShowTables(ClientContext &context, const FunctionParameters ¶meters) { - return PragmaShowTables(); -} - -string PragmaShowTablesExpanded() { - return R"( - SELECT - t.database_name AS database, - t.schema_name AS schema, - t.table_name AS name, - LIST(c.column_name order by c.column_index) AS column_names, - LIST(c.data_type order by c.column_index) AS column_types, - FIRST(t.temporary) AS temporary, - FROM duckdb_tables t - JOIN duckdb_columns c - USING (table_oid) - GROUP BY database, schema, name - - UNION ALL - - SELECT - v.database_name AS database, - v.schema_name AS schema, - v.view_name AS name, - LIST(c.column_name order by c.column_index) AS column_names, - LIST(c.data_type order by c.column_index) AS column_types, - FIRST(v.temporary) AS temporary, - FROM duckdb_views v - JOIN duckdb_columns c - ON (v.view_oid=c.table_oid) - GROUP BY database, schema, name - - ORDER BY database, schema, name - )"; -} - -string PragmaShowTablesExpanded(ClientContext &context, const FunctionParameters ¶meters) { - return PragmaShowTablesExpanded(); -} - -string PragmaShowDatabases() { - return "SELECT database_name FROM duckdb_databases() WHERE NOT internal ORDER BY database_name;"; -} - -string PragmaShowDatabases(ClientContext &context, const FunctionParameters ¶meters) { - return PragmaShowDatabases(); -} - -string PragmaShowVariables() { - return "SELECT * FROM duckdb_variables() ORDER BY name"; -} -string PragmaAllProfiling(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_last_profiling_output() JOIN pragma_detailed_profiling_output() ON " - "(pragma_last_profiling_output.operator_id);"; -} - -string PragmaDatabaseList(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_database_list;"; -} - -string PragmaCollations(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT * FROM pragma_collations() ORDER BY 1;"; -} - -string PragmaFunctionsQuery(ClientContext &context, const FunctionParameters ¶meters) { - return "SELECT function_name AS name, upper(function_type) AS type, parameter_types AS parameters, varargs, " - "return_type, has_side_effects AS side_effects" - " FROM duckdb_functions()" - " WHERE function_type IN ('scalar', 'aggregate')" - " ORDER BY 1;"; -} - -string PragmaShow(const string &table_name) { - return StringUtil::Format("SELECT * FROM pragma_show(%s);", KeywordHelper::WriteQuoted(table_name, '\'')); -} - -string PragmaShow(ClientContext &context, const FunctionParameters ¶meters) { - return PragmaShow(parameters.values[0].ToString()); -} - -string 
-string PragmaVersion(ClientContext &context, const FunctionParameters &parameters) { - return "SELECT * FROM pragma_version();"; -} - -string PragmaExtensionVersions(ClientContext &context, const FunctionParameters &parameters) { - return "select extension_name, extension_version, install_mode, installed_from from duckdb_extensions() where " - "installed"; -} - -string PragmaPlatform(ClientContext &context, const FunctionParameters &parameters) { - return "SELECT * FROM pragma_platform();"; -} - -string PragmaImportDatabase(ClientContext &context, const FunctionParameters &parameters) { - auto &fs = FileSystem::GetFileSystem(context); - - string final_query; - // read the "schema.sql" and "load.sql" files - vector<string> files = {"schema.sql", "load.sql"}; - for (auto &file : files) { - auto file_path = fs.JoinPath(parameters.values[0].ToString(), file); - auto handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_READ); - auto fsize = fs.GetFileSize(*handle); - auto buffer = make_unsafe_uniq_array<char>(UnsafeNumericCast<idx_t>(fsize)); - fs.Read(*handle, buffer.get(), fsize); - auto query = string(buffer.get(), UnsafeNumericCast<idx_t>(fsize)); - // Replace the placeholder with the path provided to IMPORT - if (file == "load.sql") { - Parser parser; - parser.ParseQuery(query); - auto copy_statements = std::move(parser.statements); - query.clear(); - for (auto &statement_p : copy_statements) { - D_ASSERT(statement_p->type == StatementType::COPY_STATEMENT); - auto &statement = statement_p->Cast<CopyStatement>(); - auto &info = *statement.info; - auto file_name = fs.ExtractName(info.file_path); - info.file_path = fs.JoinPath(parameters.values[0].ToString(), file_name); - query += statement.ToString() + ";"; - } - } - final_query += query; - } - return final_query; -} - -string PragmaCopyDatabase(ClientContext &context, const FunctionParameters &parameters) { - string copy_stmt = "COPY FROM DATABASE "; - copy_stmt += KeywordHelper::WriteOptionallyQuoted(parameters.values[0].ToString()); - copy_stmt += " TO "; - copy_stmt += KeywordHelper::WriteOptionallyQuoted(parameters.values[1].ToString()); - string final_query; - final_query += copy_stmt + " (SCHEMA);\n"; - final_query += copy_stmt + " (DATA);"; - return final_query; -} - -string PragmaDatabaseSize(ClientContext &context, const FunctionParameters &parameters) { - return "SELECT * FROM pragma_database_size();"; -} - -string PragmaStorageInfo(ClientContext &context, const FunctionParameters &parameters) { - return StringUtil::Format("SELECT * FROM pragma_storage_info('%s');", parameters.values[0].ToString()); -} - -string PragmaMetadataInfo(ClientContext &context, const FunctionParameters &parameters) { - return "SELECT * FROM pragma_metadata_info();"; -} - -string PragmaUserAgent(ClientContext &context, const FunctionParameters &parameters) { - return "SELECT * FROM pragma_user_agent()"; -} -
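For reference, PragmaCopyDatabase above expands `PRAGMA copy_database('db1', 'db2')` (database names illustrative) into two statements, copying the schema definitions before any data so that every target table exists when data is inserted:

// Generated text, assuming the database names need no quoting:
//   COPY FROM DATABASE db1 TO db2 (SCHEMA);
//   COPY FROM DATABASE db1 TO db2 (DATA);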
set.AddFunction(PragmaFunction::PragmaStatement("collations", PragmaCollations)); - set.AddFunction(PragmaFunction::PragmaCall("show", PragmaShow, {LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaStatement("version", PragmaVersion)); - set.AddFunction(PragmaFunction::PragmaStatement("extension_versions", PragmaExtensionVersions)); - set.AddFunction(PragmaFunction::PragmaStatement("platform", PragmaPlatform)); - set.AddFunction(PragmaFunction::PragmaStatement("database_size", PragmaDatabaseSize)); - set.AddFunction(PragmaFunction::PragmaStatement("functions", PragmaFunctionsQuery)); - set.AddFunction(PragmaFunction::PragmaCall("import_database", PragmaImportDatabase, {LogicalType::VARCHAR})); - set.AddFunction( - PragmaFunction::PragmaCall("copy_database", PragmaCopyDatabase, {LogicalType::VARCHAR, LogicalType::VARCHAR})); - set.AddFunction(PragmaFunction::PragmaStatement("all_profiling_output", PragmaAllProfiling)); - set.AddFunction(PragmaFunction::PragmaStatement("user_agent", PragmaUserAgent)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/pragma_function.cpp b/src/duckdb/src/function/pragma_function.cpp deleted file mode 100644 index 531cdfb43..000000000 --- a/src/duckdb/src/function/pragma_function.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "duckdb/function/pragma_function.hpp" -#include "duckdb/common/string_util.hpp" - -namespace duckdb { - -PragmaFunction::PragmaFunction(string name, PragmaType pragma_type, pragma_query_t query, pragma_function_t function, - vector arguments, LogicalType varargs) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments), std::move(varargs)), type(pragma_type), - query(query), function(function) { -} - -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_query_t query, vector arguments, - LogicalType varargs) { - return PragmaFunction(name, PragmaType::PRAGMA_CALL, query, nullptr, std::move(arguments), std::move(varargs)); -} - -PragmaFunction PragmaFunction::PragmaCall(const string &name, pragma_function_t function, vector arguments, - LogicalType varargs) { - return PragmaFunction(name, PragmaType::PRAGMA_CALL, nullptr, function, std::move(arguments), std::move(varargs)); -} - -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_query_t query) { - vector types; - return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, query, nullptr, std::move(types), LogicalType::INVALID); -} - -PragmaFunction PragmaFunction::PragmaStatement(const string &name, pragma_function_t function) { - vector types; - return PragmaFunction(name, PragmaType::PRAGMA_STATEMENT, nullptr, function, std::move(types), - LogicalType::INVALID); -} - -string PragmaFunction::ToString() const { - switch (type) { - case PragmaType::PRAGMA_STATEMENT: - return StringUtil::Format("PRAGMA %s", name); - case PragmaType::PRAGMA_CALL: { - return StringUtil::Format("PRAGMA %s", SimpleNamedParameterFunction::ToString()); - } - default: - return "UNKNOWN"; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/register_function_list.cpp b/src/duckdb/src/function/register_function_list.cpp deleted file mode 100644 index 1aeebbbbb..000000000 --- a/src/duckdb/src/function/register_function_list.cpp +++ /dev/null @@ -1,54 +0,0 @@ -#include "duckdb/catalog/default/default_types.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/function_list.hpp" -#include "duckdb/function/register_function_list_helper.hpp" -#include 
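ToString above reflects the two registration styles used throughout PragmaQueries::RegisterFunction: a PRAGMA_STATEMENT takes no arguments, while a PRAGMA_CALL carries positional arguments. Illustrative invocations (the table name is an example):

// PragmaStatement form:  PRAGMA database_list;
// PragmaCall form:       PRAGMA table_info('my_table');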
"duckdb/parser/parsed_data/create_aggregate_function_info.hpp" -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" - -namespace duckdb { - -template -static void FillExtraInfo(const StaticFunctionDefinition &function, T &info) { - info.internal = true; - FillFunctionDescriptions(function, info); -} - -static void RegisterFunctionList(Catalog &catalog, CatalogTransaction transaction, - const StaticFunctionDefinition *functions) { - for (idx_t i = 0; functions[i].name; i++) { - auto &function = functions[i]; - if (function.get_function || function.get_function_set) { - // scalar function - ScalarFunctionSet result; - if (function.get_function) { - result.AddFunction(function.get_function()); - } else { - result = function.get_function_set(); - } - result.name = function.name; - CreateScalarFunctionInfo info(result); - FillExtraInfo(function, info); - catalog.CreateFunction(transaction, info); - } else if (function.get_aggregate_function || function.get_aggregate_function_set) { - // aggregate function - AggregateFunctionSet result; - if (function.get_aggregate_function) { - result.AddFunction(function.get_aggregate_function()); - } else { - result = function.get_aggregate_function_set(); - } - result.name = function.name; - CreateAggregateFunctionInfo info(result); - FillExtraInfo(function, info); - catalog.CreateFunction(transaction, info); - } else { - throw InternalException("Do not know how to register function of this type"); - } - } -} - -void FunctionList::RegisterFunctions(Catalog &catalog, CatalogTransaction transaction) { - RegisterFunctionList(catalog, transaction, FunctionList::GetInternalFunctionList()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp deleted file mode 100644 index 78410a59f..000000000 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_integral.cpp +++ /dev/null @@ -1,282 +0,0 @@ -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/function/scalar/compressed_materialization_functions.hpp" -#include "duckdb/function/scalar/compressed_materialization_utils.hpp" - -namespace duckdb { - -static string IntegralCompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_compress_integral_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -struct TemplatedIntegralCompress { - static inline RESULT_TYPE Operation(const INPUT_TYPE &input, const INPUT_TYPE &min_val) { - D_ASSERT(min_val <= input); - return UnsafeNumericCast(input - min_val); - } -}; - -template -struct TemplatedIntegralCompress { - static inline RESULT_TYPE Operation(const hugeint_t &input, const hugeint_t &min_val) { - D_ASSERT(min_val <= input); - return UnsafeNumericCast((input - min_val).lower); - } -}; - -template -struct TemplatedIntegralCompress { - static inline RESULT_TYPE Operation(const uhugeint_t &input, const uhugeint_t &min_val) { - D_ASSERT(min_val <= input); - return UnsafeNumericCast((input - min_val).lower); - } -}; - -template -static void IntegralCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); - const auto min_val = 
ConstantVector::GetData(args.data[1])[0]; - UnaryExecutor::Execute( - args.data[0], result, args.size(), - [&](const INPUT_TYPE &input) { - return TemplatedIntegralCompress::Operation(input, min_val); - }, - FunctionErrors::CANNOT_ERROR); -} - -template -static scalar_function_t GetIntegralCompressFunction(const LogicalType &input_type, const LogicalType &result_type) { - return IntegralCompressFunction; -} - -template -static scalar_function_t GetIntegralCompressFunctionResultSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::UTINYINT: - return GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralCompressFunction(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralCompressFunction(input_type, result_type); - default: - throw InternalException("Unexpected result type in GetIntegralCompressFunctionResultSwitch"); - } -} - -static scalar_function_t GetIntegralCompressFunctionInputSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (input_type.id()) { - case LogicalTypeId::SMALLINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::INTEGER: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::BIGINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::HUGEINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UHUGEINT: - return GetIntegralCompressFunctionResultSwitch(input_type, result_type); - default: - throw InternalException("Unexpected input type in GetIntegralCompressFunctionInputSwitch"); - } -} - -static string IntegralDecompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_decompress_integral_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -struct TemplatedIntegralDecompress { - static inline RESULT_TYPE Operation(const INPUT_TYPE &input, const RESULT_TYPE &min_val) { - return min_val + UnsafeNumericCast(input); - } -}; - -template -struct TemplatedIntegralDecompress { - static inline hugeint_t Operation(const INPUT_TYPE &input, const hugeint_t &min_val) { - return min_val + hugeint_t(0, input); - } -}; - -template -struct TemplatedIntegralDecompress { - static inline uhugeint_t Operation(const INPUT_TYPE &input, const uhugeint_t &min_val) { - return min_val + uhugeint_t(0, input); - } -}; - -template -static void IntegralDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - D_ASSERT(args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR); - D_ASSERT(args.data[1].GetType() == result.GetType()); - const auto min_val = ConstantVector::GetData(args.data[1])[0]; - UnaryExecutor::Execute( - args.data[0], result, args.size(), - [&](const INPUT_TYPE &input) { - return TemplatedIntegralDecompress::Operation(input, min_val); - }, - FunctionErrors::CANNOT_ERROR); -} - -template 
-static scalar_function_t GetIntegralDecompressFunction(const LogicalType &input_type, const LogicalType &result_type) { - return IntegralDecompressFunction; -} - -template -static scalar_function_t GetIntegralDecompressFunctionResultSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::SMALLINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::INTEGER: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::BIGINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::HUGEINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralDecompressFunction(input_type, result_type); - case LogicalTypeId::UHUGEINT: - return GetIntegralDecompressFunction(input_type, result_type); - default: - throw InternalException("Unexpected input type in GetIntegralDecompressFunctionSetSwitch"); - } -} - -static scalar_function_t GetIntegralDecompressFunctionInputSwitch(const LogicalType &input_type, - const LogicalType &result_type) { - switch (input_type.id()) { - case LogicalTypeId::UTINYINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::USMALLINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UINTEGER: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - case LogicalTypeId::UBIGINT: - return GetIntegralDecompressFunctionResultSwitch(input_type, result_type); - default: - throw InternalException("Unexpected result type in GetIntegralDecompressFunctionInputSwitch"); - } -} - -static void CMIntegralSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); -} - -template -unique_ptr CMIntegralDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments = deserializer.ReadProperty>(100, "arguments"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GET_FUNCTION(function.arguments[0], return_type); - return nullptr; -} - -ScalarFunction CMIntegralCompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction result(IntegralCompressFunctionName(result_type), {input_type, input_type}, result_type, - GetIntegralCompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; - return result; -} - -static ScalarFunctionSet GetIntegralCompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralCompressFunctionName(result_type)); - for (const auto &input_type : LogicalType::Integral()) { - if (GetTypeIdSize(result_type.InternalType()) < GetTypeIdSize(input_type.InternalType())) { - set.AddFunction(CMIntegralCompressFun::GetFunction(input_type, result_type)); - } - } - return set; -} - -ScalarFunction CMIntegralDecompressFun::GetFunction(const LogicalType &input_type, const LogicalType &result_type) { - ScalarFunction 
result(IntegralDecompressFunctionName(result_type), {input_type, result_type}, result_type, - GetIntegralDecompressFunctionInputSwitch(input_type, result_type), CMUtils::Bind); - result.serialize = CMIntegralSerialize; - result.deserialize = CMIntegralDeserialize; - return result; -} - -static ScalarFunctionSet GetIntegralDecompressFunctionSet(const LogicalType &result_type) { - ScalarFunctionSet set(IntegralDecompressFunctionName(result_type)); - for (const auto &input_type : CMUtils::IntegralTypes()) { - if (GetTypeIdSize(result_type.InternalType()) > GetTypeIdSize(input_type.InternalType())) { - set.AddFunction(CMIntegralDecompressFun::GetFunction(input_type, result_type)); - } - } - return set; -} - -ScalarFunctionSet InternalCompressIntegralUtinyintFun::GetFunctions() { - return GetIntegralCompressFunctionSet(LogicalType(LogicalTypeId::UTINYINT)); -} - -ScalarFunctionSet InternalCompressIntegralUsmallintFun::GetFunctions() { - return GetIntegralCompressFunctionSet(LogicalType(LogicalTypeId::USMALLINT)); -} - -ScalarFunctionSet InternalCompressIntegralUintegerFun::GetFunctions() { - return GetIntegralCompressFunctionSet(LogicalType(LogicalTypeId::UINTEGER)); -} - -ScalarFunctionSet InternalCompressIntegralUbigintFun::GetFunctions() { - return GetIntegralCompressFunctionSet(LogicalType(LogicalTypeId::UBIGINT)); -} - -ScalarFunctionSet InternalDecompressIntegralSmallintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::SMALLINT)); -} - -ScalarFunctionSet InternalDecompressIntegralIntegerFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::INTEGER)); -} - -ScalarFunctionSet InternalDecompressIntegralBigintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::BIGINT)); -} - -ScalarFunctionSet InternalDecompressIntegralHugeintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::HUGEINT)); -} - -ScalarFunctionSet InternalDecompressIntegralUsmallintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::USMALLINT)); -} - -ScalarFunctionSet InternalDecompressIntegralUintegerFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::UINTEGER)); -} - -ScalarFunctionSet InternalDecompressIntegralUbigintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::UBIGINT)); -} - -ScalarFunctionSet InternalDecompressIntegralUhugeintFun::GetFunctions() { - return GetIntegralDecompressFunctionSet(LogicalType(LogicalTypeId::UHUGEINT)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp b/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp deleted file mode 100644 index d1d0734d1..000000000 --- a/src/duckdb/src/function/scalar/compressed_materialization/compress_string.cpp +++ /dev/null @@ -1,268 +0,0 @@ -#include "duckdb/common/bswap.hpp" -#include "duckdb/function/scalar/compressed_materialization_functions.hpp" -#include "duckdb/function/scalar/compressed_materialization_utils.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/serializer/deserializer.hpp" - -namespace duckdb { - -static string StringCompressFunctionName(const LogicalType &result_type) { - return StringUtil::Format("__internal_compress_string_%s", - StringUtil::Lower(LogicalTypeIdToString(result_type.id()))); -} - -template -static inline void 
TemplatedReverseMemCpy(const data_ptr_t &__restrict dest, const const_data_ptr_t &__restrict src) { - for (idx_t i = 0; i < LENGTH; i++) { - dest[i] = src[LENGTH - 1 - i]; - } -} - -static inline void ReverseMemCpy(const data_ptr_t &__restrict dest, const const_data_ptr_t &__restrict src, - const idx_t &length) { - for (idx_t i = 0; i < length; i++) { - dest[i] = src[length - 1 - i]; - } -} - -template -static inline RESULT_TYPE StringCompressInternal(const string_t &input) { - RESULT_TYPE result; - const auto result_ptr = data_ptr_cast(&result); - if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { - TemplatedReverseMemCpy(result_ptr, const_data_ptr_cast(input.GetPrefix())); - } else if (input.IsInlined()) { - static constexpr auto REMAINDER = sizeof(RESULT_TYPE) - string_t::INLINE_LENGTH; - TemplatedReverseMemCpy(result_ptr + REMAINDER, const_data_ptr_cast(input.GetPrefix())); - memset(result_ptr, '\0', REMAINDER); - } else { - const auto remainder = sizeof(RESULT_TYPE) - input.GetSize(); - ReverseMemCpy(result_ptr + remainder, data_ptr_cast(input.GetPointer()), input.GetSize()); - memset(result_ptr, '\0', remainder); - } - result_ptr[0] = UnsafeNumericCast(input.GetSize()); - return result; -} - -template -static inline RESULT_TYPE StringCompress(const string_t &input) { - D_ASSERT(input.GetSize() < sizeof(RESULT_TYPE)); - return StringCompressInternal(input); -} - -template -static inline RESULT_TYPE MiniStringCompress(const string_t &input) { - if (sizeof(RESULT_TYPE) <= string_t::INLINE_LENGTH) { - return UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPrefix())); - } else if (input.GetSize() == 0) { - return 0; - } else { - return UnsafeNumericCast(input.GetSize() + *const_data_ptr_cast(input.GetPointer())); - } -} - -template <> -inline uint8_t StringCompress(const string_t &input) { - D_ASSERT(input.GetSize() <= sizeof(uint8_t)); - return MiniStringCompress(input); -} - -template -static void StringCompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::Execute(args.data[0], result, args.size(), StringCompress, - FunctionErrors::CANNOT_ERROR); -} - -template -static scalar_function_t GetStringCompressFunction(const LogicalType &result_type) { - return StringCompressFunction; -} - -static scalar_function_t GetStringCompressFunctionSwitch(const LogicalType &result_type) { - switch (result_type.id()) { - case LogicalTypeId::UTINYINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::USMALLINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::UINTEGER: - return GetStringCompressFunction(result_type); - case LogicalTypeId::UBIGINT: - return GetStringCompressFunction(result_type); - case LogicalTypeId::HUGEINT: - return GetStringCompressFunction(result_type); - default: - throw InternalException("Unexpected type in GetStringCompressFunctionSwitch"); - } -} - -static string StringDecompressFunctionName() { - return "__internal_decompress_string"; -} - -struct StringDecompressLocalState : public FunctionLocalState { -public: - explicit StringDecompressLocalState(ClientContext &context) : allocator(Allocator::Get(context)) { - } - - static unique_ptr Init(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - return make_uniq(state.GetContext()); - } - -public: - ArenaAllocator allocator; -}; - -template -static inline string_t StringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - const auto input_ptr = const_data_ptr_cast(&input); - 
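	// Layout sketch (illustrative, little-endian host): compressing the inlined
	// string "ab" to a uint32_t reverse-copies the prefix and then stores the
	// length in byte 0, yielding the bytes {0x02, 0x00, 'b', 'a'}. The first
	// character lands in the most significant byte, so unsigned integer
	// comparison orders keys by string prefix first, with the length byte as
	// the final tiebreaker; byte 0 read back below is therefore the original
	// string length.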
string_t result(input_ptr[0]); - if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { - const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr); - memset(result_ptr + sizeof(INPUT_TYPE) - 1, '\0', string_t::INLINE_LENGTH - sizeof(INPUT_TYPE) + 1); - } else if (result.GetSize() <= string_t::INLINE_LENGTH) { - static constexpr auto REMAINDER = sizeof(INPUT_TYPE) - string_t::INLINE_LENGTH; - const auto result_ptr = data_ptr_cast(result.GetPrefixWriteable()); - TemplatedReverseMemCpy(result_ptr, input_ptr + REMAINDER); - } else { - result.SetPointer(char_ptr_cast(allocator.Allocate(sizeof(INPUT_TYPE)))); - TemplatedReverseMemCpy(data_ptr_cast(result.GetPointer()), input_ptr); - memcpy(result.GetPrefixWriteable(), result.GetPointer(), string_t::PREFIX_LENGTH); - } - return result; -} - -template -static inline string_t MiniStringDecompress(const INPUT_TYPE &input, ArenaAllocator &allocator) { - if (input == 0) { - string_t result(uint32_t(0)); - memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); - return result; - } - - string_t result(1); - if (sizeof(INPUT_TYPE) <= string_t::INLINE_LENGTH) { - memset(result.GetPrefixWriteable(), '\0', string_t::INLINE_BYTES); - *data_ptr_cast(result.GetPrefixWriteable()) = input - 1; - } else { - result.SetPointer(char_ptr_cast(allocator.Allocate(1))); - *data_ptr_cast(result.GetPointer()) = input - 1; - memset(result.GetPrefixWriteable(), '\0', string_t::PREFIX_LENGTH); - *result.GetPrefixWriteable() = *result.GetPointer(); - } - return result; -} - -template <> -inline string_t StringDecompress(const uint8_t &input, ArenaAllocator &allocator) { - return MiniStringDecompress(input, allocator); -} - -template -static void StringDecompressFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &allocator = ExecuteFunctionState::GetFunctionState(state)->Cast().allocator; - allocator.Reset(); - UnaryExecutor::Execute( - args.data[0], result, args.size(), - [&](const INPUT_TYPE &input) { return StringDecompress(input, allocator); }, - FunctionErrors::CANNOT_ERROR); -} - -template -static scalar_function_t GetStringDecompressFunction(const LogicalType &input_type) { - return StringDecompressFunction; -} - -static scalar_function_t GetStringDecompressFunctionSwitch(const LogicalType &input_type) { - switch (input_type.id()) { - case LogicalTypeId::UTINYINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::USMALLINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::UINTEGER: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::UBIGINT: - return GetStringDecompressFunction(input_type); - case LogicalTypeId::HUGEINT: - return GetStringDecompressFunction(input_type); - default: - throw InternalException("Unexpected type in GetStringDecompressFunctionSwitch"); - } -} - -static void CMStringCompressSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); - serializer.WriteProperty(101, "return_type", function.return_type); -} - -unique_ptr CMStringCompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments = deserializer.ReadProperty>(100, "arguments"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - function.function = GetStringCompressFunctionSwitch(return_type); - return nullptr; -} - -ScalarFunction CMStringCompressFun::GetFunction(const 
LogicalType &result_type) { - ScalarFunction result(StringCompressFunctionName(result_type), {LogicalType::VARCHAR}, result_type, - GetStringCompressFunctionSwitch(result_type), CMUtils::Bind); - result.serialize = CMStringCompressSerialize; - result.deserialize = CMStringCompressDeserialize; - return result; -} - -static void CMStringDecompressSerialize(Serializer &serializer, const optional_ptr bind_data, - const ScalarFunction &function) { - serializer.WriteProperty(100, "arguments", function.arguments); -} - -unique_ptr CMStringDecompressDeserialize(Deserializer &deserializer, ScalarFunction &function) { - function.arguments = deserializer.ReadProperty>(100, "arguments"); - function.function = GetStringDecompressFunctionSwitch(function.arguments[0]); - function.return_type = deserializer.Get(); - return nullptr; -} - -ScalarFunction CMStringDecompressFun::GetFunction(const LogicalType &input_type) { - ScalarFunction result(StringDecompressFunctionName(), {input_type}, LogicalType::VARCHAR, - GetStringDecompressFunctionSwitch(input_type), CMUtils::Bind, nullptr, nullptr, - StringDecompressLocalState::Init); - result.serialize = CMStringDecompressSerialize; - result.deserialize = CMStringDecompressDeserialize; - return result; -} - -static ScalarFunctionSet GetStringDecompressFunctionSet() { - ScalarFunctionSet set(StringDecompressFunctionName()); - for (const auto &input_type : CMUtils::StringTypes()) { - set.AddFunction(CMStringDecompressFun::GetFunction(input_type)); - } - return set; -} - -ScalarFunction InternalCompressStringUtinyintFun::GetFunction() { - return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::UTINYINT)); -} - -ScalarFunction InternalCompressStringUsmallintFun::GetFunction() { - return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::USMALLINT)); -} - -ScalarFunction InternalCompressStringUintegerFun::GetFunction() { - return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::UINTEGER)); -} - -ScalarFunction InternalCompressStringUbigintFun::GetFunction() { - return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::UBIGINT)); -} - -ScalarFunction InternalCompressStringHugeintFun::GetFunction() { - return CMStringCompressFun::GetFunction(LogicalType(LogicalTypeId::HUGEINT)); -} - -ScalarFunctionSet InternalDecompressStringFun::GetFunctions() { - return GetStringDecompressFunctionSet(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/compressed_materialization_utils.cpp b/src/duckdb/src/function/scalar/compressed_materialization_utils.cpp deleted file mode 100644 index 2d09a7e7f..000000000 --- a/src/duckdb/src/function/scalar/compressed_materialization_utils.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "duckdb/function/scalar/compressed_materialization_utils.hpp" - -namespace duckdb { - -const vector CMUtils::IntegralTypes() { - return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT}; -} - -const vector CMUtils::StringTypes() { - return {LogicalType::UTINYINT, LogicalType::USMALLINT, LogicalType::UINTEGER, LogicalType::UBIGINT, - LogicalType::HUGEINT}; -} - -// LCOV_EXCL_START -unique_ptr CMUtils::Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - throw BinderException("Compressed materialization functions are for internal use only!"); -} -// LCOV_EXCL_STOP - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/create_sort_key.cpp b/src/duckdb/src/function/scalar/create_sort_key.cpp deleted file mode 100644 
index 1cb2acdde..000000000 --- a/src/duckdb/src/function/scalar/create_sort_key.cpp +++ /dev/null @@ -1,1037 +0,0 @@ -#include "duckdb/function/create_sort_key.hpp" - -#include "duckdb/common/enums/order_type.hpp" -#include "duckdb/common/radix.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" - -namespace duckdb { - -struct CreateSortKeyBindData : public FunctionData { - vector modifiers; - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return modifiers == other.modifiers; - } - unique_ptr Copy() const override { - auto result = make_uniq(); - result->modifiers = modifiers; - return std::move(result); - } -}; - -unique_ptr CreateSortKeyBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments.size() % 2 != 0) { - throw BinderException( - "Arguments to create_sort_key must be [key1, sort_specifier1, key2, sort_specifier2, ...]"); - } - auto result = make_uniq(); - for (idx_t i = 1; i < arguments.size(); i += 2) { - if (!arguments[i]->IsFoldable()) { - throw BinderException("sort_specifier must be a constant value - but got %s", arguments[i]->ToString()); - } - - // Rebind to return a date if we are truncating that far - Value sort_specifier = ExpressionExecutor::EvaluateScalar(context, *arguments[i]); - if (sort_specifier.IsNull()) { - throw BinderException("sort_specifier cannot be NULL"); - } - auto sort_specifier_str = sort_specifier.ToString(); - result->modifiers.push_back(OrderModifiers::Parse(sort_specifier_str)); - } - // push collations - for (idx_t i = 0; i < arguments.size(); i += 2) { - ExpressionBinder::PushCollation(context, arguments[i], arguments[i]->return_type); - } - // check if all types are constant - bool all_constant = true; - idx_t constant_size = 0; - for (idx_t i = 0; i < arguments.size(); i += 2) { - auto physical_type = arguments[i]->return_type.InternalType(); - if (!TypeIsConstantSize(physical_type)) { - all_constant = false; - } else { - // we always add one byte for the validity - constant_size += GetTypeIdSize(physical_type) + 1; - } - } - if (all_constant) { - if (constant_size <= sizeof(int64_t)) { - bound_function.return_type = LogicalType::BIGINT; - } - } - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Operators -//===--------------------------------------------------------------------===// -struct SortKeyVectorData { - static constexpr data_t NULL_FIRST_BYTE = 1; - static constexpr data_t NULL_LAST_BYTE = 2; - static constexpr data_t STRING_DELIMITER = 0; - static constexpr data_t LIST_DELIMITER = 0; - static constexpr data_t BLOB_ESCAPE_CHARACTER = 1; - - SortKeyVectorData(Vector &input, idx_t size, OrderModifiers modifiers) : vec(input) { - if (size != 0) { - input.ToUnifiedFormat(size, format); - } - this->size = size; - - null_byte = NULL_FIRST_BYTE; - valid_byte = NULL_LAST_BYTE; - if (modifiers.null_type == OrderByNullType::NULLS_LAST) { - std::swap(null_byte, valid_byte); - } - - // NULLS FIRST/NULLS LAST passed in by the user are only respected at the top level - // within nested types NULLS LAST/NULLS FIRST is dependent on ASC/DESC order instead - // don't blame me this is what Postgres does - auto child_null_type = - modifiers.order_type == OrderType::ASCENDING ? 
OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; - OrderModifiers child_modifiers(modifiers.order_type, child_null_type); - switch (input.GetType().InternalType()) { - case PhysicalType::STRUCT: { - auto &children = StructVector::GetEntries(input); - for (auto &child : children) { - child_data.push_back(make_uniq(*child, size, child_modifiers)); - } - break; - } - case PhysicalType::ARRAY: { - auto &child_entry = ArrayVector::GetEntry(input); - auto array_size = ArrayType::GetSize(input.GetType()); - child_data.push_back(make_uniq(child_entry, size * array_size, child_modifiers)); - break; - } - case PhysicalType::LIST: { - auto &child_entry = ListVector::GetEntry(input); - auto child_size = size == 0 ? 0 : ListVector::GetListSize(input); - child_data.push_back(make_uniq(child_entry, child_size, child_modifiers)); - break; - } - default: - break; - } - } - // disable copy constructors - SortKeyVectorData(const SortKeyVectorData &other) = delete; - SortKeyVectorData &operator=(const SortKeyVectorData &) = delete; - - void Initialize() { - } - - PhysicalType GetPhysicalType() { - return vec.GetType().InternalType(); - } - - Vector &vec; - idx_t size; - UnifiedVectorFormat format; - vector> child_data; - data_t null_byte; - data_t valid_byte; -}; - -template -struct SortKeyConstantOperator { - using TYPE = T; - - static idx_t GetEncodeLength(TYPE input) { - return sizeof(T); - } - - static idx_t Encode(data_ptr_t result, TYPE input) { - Radix::EncodeData(result, input); - return sizeof(T); - } - - static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { - auto result_data = FlatVector::GetData(result); - if (flip_bytes) { - // descending order - so flip bytes - data_t flipped_bytes[sizeof(T)]; - for (idx_t b = 0; b < sizeof(T); b++) { - flipped_bytes[b] = ~input[b]; - } - result_data[result_idx] = Radix::DecodeData(flipped_bytes); - } else { - result_data[result_idx] = Radix::DecodeData(input); - } - return sizeof(T); - } -}; - -struct SortKeyVarcharOperator { - using TYPE = string_t; - - static idx_t GetEncodeLength(TYPE input) { - return input.GetSize() + 1; - } - - static idx_t Encode(data_ptr_t result, TYPE input) { - auto input_data = const_data_ptr_cast(input.GetDataUnsafe()); - auto input_size = input.GetSize(); - for (idx_t r = 0; r < input_size; r++) { - result[r] = input_data[r] + 1; - } - result[input_size] = SortKeyVectorData::STRING_DELIMITER; // null-byte delimiter - return input_size + 1; - } - - static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { - auto result_data = FlatVector::GetData(result); - // iterate until we encounter the string delimiter to figure out the string length - data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; - if (flip_bytes) { - string_delimiter = ~string_delimiter; - } - idx_t pos; - for (pos = 0; input[pos] != string_delimiter; pos++) { - } - idx_t str_len = pos; - // now allocate the string data and fill it with the decoded data - result_data[result_idx] = StringVector::EmptyString(result, str_len); - auto str_data = data_ptr_cast(result_data[result_idx].GetDataWriteable()); - for (pos = 0; pos < str_len; pos++) { - if (flip_bytes) { - str_data[pos] = (~input[pos]) - 1; - } else { - str_data[pos] = input[pos] - 1; - } - } - result_data[result_idx].Finalize(); - return pos + 1; - } -}; - -struct SortKeyBlobOperator { - using TYPE = string_t; - - static idx_t GetEncodeLength(TYPE input) { - auto input_data = data_ptr_t(input.GetDataUnsafe()); - 
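	// Escape sketch (illustrative): payload bytes 0x00 and 0x01 collide with
	// the string delimiter and the escape marker, so the blob {0x00, 0x42}
	// encodes as {0x01, 0x00, 0x42, 0x00}: each 0x00/0x01 byte is preceded by
	// the 0x01 escape, and the key still ends with a plain 0x00 delimiter,
	// preserving memcmp ordering of the encoded keys.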
auto input_size = input.GetSize(); - idx_t escaped_characters = 0; - for (idx_t r = 0; r < input_size; r++) { - if (input_data[r] <= 1) { - // we escape both \x00 and \x01 - escaped_characters++; - } - } - return input.GetSize() + escaped_characters + 1; - } - - static idx_t Encode(data_ptr_t result, TYPE input) { - auto input_data = data_ptr_t(input.GetDataUnsafe()); - auto input_size = input.GetSize(); - idx_t result_offset = 0; - for (idx_t r = 0; r < input_size; r++) { - if (input_data[r] <= 1) { - // we escape both \x00 and \x01 with \x01 - result[result_offset++] = SortKeyVectorData::BLOB_ESCAPE_CHARACTER; - result[result_offset++] = input_data[r]; - } else { - result[result_offset++] = input_data[r]; - } - } - result[result_offset++] = SortKeyVectorData::STRING_DELIMITER; // null-byte delimiter - return result_offset; - } - - static idx_t Decode(const_data_ptr_t input, Vector &result, idx_t result_idx, bool flip_bytes) { - auto result_data = FlatVector::GetData(result); - // scan until we find the delimiter, keeping in mind escapes - data_t string_delimiter = SortKeyVectorData::STRING_DELIMITER; - data_t escape_character = SortKeyVectorData::BLOB_ESCAPE_CHARACTER; - if (flip_bytes) { - string_delimiter = ~string_delimiter; - escape_character = ~escape_character; - } - idx_t blob_len = 0; - idx_t pos; - for (pos = 0; input[pos] != string_delimiter; pos++) { - blob_len++; - if (input[pos] == escape_character) { - // escape character - skip the next byte - pos++; - } - } - // now allocate the blob data and fill it with the decoded data - result_data[result_idx] = StringVector::EmptyString(result, blob_len); - auto str_data = data_ptr_cast(result_data[result_idx].GetDataWriteable()); - for (idx_t input_pos = 0, result_pos = 0; input_pos < pos; input_pos++) { - if (input[input_pos] == escape_character) { - // if we encounter an escape character - copy the NEXT byte - input_pos++; - } - if (flip_bytes) { - str_data[result_pos++] = ~input[input_pos]; - } else { - str_data[result_pos++] = input[input_pos]; - } - } - result_data[result_idx].Finalize(); - return pos + 1; - } -}; - -struct SortKeyListEntry { - static bool IsArray() { - return false; - } - - static list_entry_t GetListEntry(SortKeyVectorData &vector_data, idx_t idx) { - auto data = UnifiedVectorFormat::GetData(vector_data.format); - return data[idx]; - } -}; - -struct SortKeyArrayEntry { - static bool IsArray() { - return true; - } - - static list_entry_t GetListEntry(SortKeyVectorData &vector_data, idx_t idx) { - auto array_size = ArrayType::GetSize(vector_data.vec.GetType()); - return list_entry_t(array_size * idx, array_size); - } -}; - -struct SortKeyChunk { - SortKeyChunk(idx_t start, idx_t end) : start(start), end(end), has_result_index(false) { - } - SortKeyChunk(idx_t start, idx_t end, idx_t result_index) - : start(start), end(end), result_index(result_index), has_result_index(true) { - } - - idx_t start; - idx_t end; - idx_t result_index; - bool has_result_index; - - inline idx_t GetResultIndex(idx_t r) { - return has_result_index ? 
result_index : r; - } -}; - -//===--------------------------------------------------------------------===// -// Get Sort Key Length -//===--------------------------------------------------------------------===// -struct SortKeyLengthInfo { - explicit SortKeyLengthInfo(idx_t size) : constant_length(0) { - variable_lengths.resize(size, 0); - } - - idx_t constant_length; - unsafe_vector variable_lengths; -}; - -static void GetSortKeyLengthRecursive(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyLengthInfo &result); - -template -void TemplatedGetSortKeyLength(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyLengthInfo &result) { - auto &format = vector_data.format; - auto data = UnifiedVectorFormat::GetData(vector_data.format); - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto idx = format.sel->get_index(r); - auto result_index = chunk.GetResultIndex(r); - result.variable_lengths[result_index]++; // every value is prefixed by a validity byte - - if (!format.validity.RowIsValid(idx)) { - continue; - } - result.variable_lengths[result_index] += OP::GetEncodeLength(data[idx]); - } -} - -void GetSortKeyLengthStruct(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyLengthInfo &result) { - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto result_index = chunk.GetResultIndex(r); - result.variable_lengths[result_index]++; // every struct is prefixed by a validity byte - } - // now recursively call GetSortKeyLength on the child elements - for (auto &child_data : vector_data.child_data) { - GetSortKeyLengthRecursive(*child_data, chunk, result); - } -} - -template -void GetSortKeyLengthList(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyLengthInfo &result) { - auto &child_data = vector_data.child_data[0]; - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto idx = vector_data.format.sel->get_index(r); - auto result_index = chunk.GetResultIndex(r); - result.variable_lengths[result_index]++; // every list is prefixed by a validity byte - - if (!vector_data.format.validity.RowIsValid(idx)) { - if (!OP::IsArray()) { - // for arrays we need to fill in the child vector for all elements, even if the top-level array is NULL - continue; - } - } - auto list_entry = OP::GetListEntry(vector_data, idx); - // for each non-null list we have an "end of list" delimiter - result.variable_lengths[result_index]++; - if (list_entry.length > 0) { - // recursively call GetSortKeyLength for the children of this list - SortKeyChunk child_chunk(list_entry.offset, list_entry.offset + list_entry.length, result_index); - GetSortKeyLengthRecursive(*child_data, child_chunk, result); - } - } -} - -static void GetSortKeyLengthRecursive(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyLengthInfo &result) { - auto physical_type = vector_data.GetPhysicalType(); - // handle variable lengths - switch (physical_type) { - case PhysicalType::BOOL: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::UINT8: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INT8: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::UINT16: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INT16: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::UINT32: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INT32: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - 
break; - case PhysicalType::UINT64: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INT64: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::FLOAT: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::DOUBLE: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INTERVAL: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::UINT128: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::INT128: - TemplatedGetSortKeyLength>(vector_data, chunk, result); - break; - case PhysicalType::VARCHAR: - if (vector_data.vec.GetType().id() == LogicalTypeId::VARCHAR) { - TemplatedGetSortKeyLength(vector_data, chunk, result); - } else { - TemplatedGetSortKeyLength(vector_data, chunk, result); - } - break; - case PhysicalType::STRUCT: - GetSortKeyLengthStruct(vector_data, chunk, result); - break; - case PhysicalType::LIST: - GetSortKeyLengthList(vector_data, chunk, result); - break; - case PhysicalType::ARRAY: - GetSortKeyLengthList(vector_data, chunk, result); - break; - default: - throw NotImplementedException("Unsupported physical type %s in GetSortKeyLength", physical_type); - } -} - -static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo &result, SortKeyChunk chunk) { - // top-level method - auto physical_type = vector_data.GetPhysicalType(); - if (TypeIsConstantSize(physical_type)) { - // every row is prefixed by a validity byte - result.constant_length += 1; - result.constant_length += GetTypeIdSize(physical_type); - return; - } - GetSortKeyLengthRecursive(vector_data, chunk, result); -} - -static void GetSortKeyLength(SortKeyVectorData &vector_data, SortKeyLengthInfo &result) { - GetSortKeyLength(vector_data, result, SortKeyChunk(0, vector_data.size)); -} - -//===--------------------------------------------------------------------===// -// Construct Sort Key -//===--------------------------------------------------------------------===// -struct SortKeyConstructInfo { - SortKeyConstructInfo(OrderModifiers modifiers_p, unsafe_vector &offsets, data_ptr_t *result_data) - : modifiers(modifiers_p), offsets(offsets), result_data(result_data) { - flip_bytes = modifiers.order_type == OrderType::DESCENDING; - } - - OrderModifiers modifiers; - unsafe_vector &offsets; - data_ptr_t *result_data; - bool flip_bytes; -}; - -static void ConstructSortKeyRecursive(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyConstructInfo &info); - -template -void TemplatedConstructSortKey(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyConstructInfo &info) { - auto data = UnifiedVectorFormat::GetData(vector_data.format); - auto &offsets = info.offsets; - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto result_index = chunk.GetResultIndex(r); - auto idx = vector_data.format.sel->get_index(r); - auto &offset = offsets[result_index]; - auto result_ptr = info.result_data[result_index]; - if (!vector_data.format.validity.RowIsValid(idx)) { - // NULL value - write the null byte and skip - result_ptr[offset++] = vector_data.null_byte; - continue; - } - // valid value - write the validity byte - result_ptr[offset++] = vector_data.valid_byte; - idx_t encode_len = OP::Encode(result_ptr + offset, data[idx]); - if (info.flip_bytes) { - // descending order - so flip bytes - for (idx_t b = offset; b < offset + encode_len; b++) { - result_ptr[b] = ~result_ptr[b]; - } - } - offset 
+= encode_len; - } -} - -void ConstructSortKeyStruct(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyConstructInfo &info) { - bool list_of_structs = chunk.has_result_index; - // write the validity data of the struct - auto &offsets = info.offsets; - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto result_index = chunk.GetResultIndex(r); - auto idx = vector_data.format.sel->get_index(r); - auto &offset = offsets[result_index]; - auto result_ptr = info.result_data[result_index]; - if (!vector_data.format.validity.RowIsValid(idx)) { - // NULL value - write the null byte and skip - result_ptr[offset++] = vector_data.null_byte; - } else { - // valid value - write the validity byte - result_ptr[offset++] = vector_data.valid_byte; - } - if (list_of_structs) { - // for a list of structs we need to write the child data for every iteration - // since the final layout needs to be - // [struct1][struct2][...] - for (auto &child : vector_data.child_data) { - SortKeyChunk child_chunk(r, r + 1, result_index); - ConstructSortKeyRecursive(*child, child_chunk, info); - } - } - } - if (!list_of_structs) { - for (auto &child : vector_data.child_data) { - ConstructSortKeyRecursive(*child, chunk, info); - } - } -} - -template -void ConstructSortKeyList(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyConstructInfo &info) { - auto &offsets = info.offsets; - for (idx_t r = chunk.start; r < chunk.end; r++) { - auto result_index = chunk.GetResultIndex(r); - auto idx = vector_data.format.sel->get_index(r); - auto &offset = offsets[result_index]; - auto result_ptr = info.result_data[result_index]; - if (!vector_data.format.validity.RowIsValid(idx)) { - // NULL value - write the null byte and skip - result_ptr[offset++] = vector_data.null_byte; - if (!OP::IsArray()) { - // for arrays we always write the child elements - also if the top-level array is NULL - continue; - } - } else { - // valid value - write the validity byte - result_ptr[offset++] = vector_data.valid_byte; - } - - auto list_entry = OP::GetListEntry(vector_data, idx); - // recurse and write the list elements - if (list_entry.length > 0) { - SortKeyChunk child_chunk(list_entry.offset, list_entry.offset + list_entry.length, result_index); - ConstructSortKeyRecursive(*vector_data.child_data[0], child_chunk, info); - } - - // write the end-of-list delimiter - result_ptr[offset++] = static_cast(info.flip_bytes ? 
~SortKeyVectorData::LIST_DELIMITER - : SortKeyVectorData::LIST_DELIMITER); - } -} - -static void ConstructSortKeyRecursive(SortKeyVectorData &vector_data, SortKeyChunk chunk, SortKeyConstructInfo &info) { - switch (vector_data.GetPhysicalType()) { - case PhysicalType::BOOL: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::UINT8: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INT8: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::UINT16: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INT16: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::UINT32: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INT32: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::UINT64: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INT64: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::FLOAT: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::DOUBLE: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INTERVAL: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::UINT128: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::INT128: - TemplatedConstructSortKey>(vector_data, chunk, info); - break; - case PhysicalType::VARCHAR: - if (vector_data.vec.GetType().id() == LogicalTypeId::VARCHAR) { - TemplatedConstructSortKey(vector_data, chunk, info); - } else { - TemplatedConstructSortKey(vector_data, chunk, info); - } - break; - case PhysicalType::STRUCT: - ConstructSortKeyStruct(vector_data, chunk, info); - break; - case PhysicalType::LIST: - ConstructSortKeyList(vector_data, chunk, info); - break; - case PhysicalType::ARRAY: - ConstructSortKeyList(vector_data, chunk, info); - break; - default: - throw NotImplementedException("Unsupported type %s in ConstructSortKey", vector_data.vec.GetType()); - } -} - -static void ConstructSortKey(SortKeyVectorData &vector_data, SortKeyConstructInfo &info) { - ConstructSortKeyRecursive(vector_data, SortKeyChunk(0, vector_data.size), info); -} - -static void PrepareSortData(Vector &result, idx_t size, SortKeyLengthInfo &key_lengths, data_ptr_t *data_pointers) { - switch (result.GetType().id()) { - case LogicalTypeId::BLOB: { - auto result_data = FlatVector::GetData(result); - for (idx_t r = 0; r < size; r++) { - auto blob_size = key_lengths.variable_lengths[r] + key_lengths.constant_length; - result_data[r] = StringVector::EmptyString(result, blob_size); - data_pointers[r] = data_ptr_cast(result_data[r].GetDataWriteable()); -#ifdef DEBUG - memset(data_pointers[r], 0xFF, blob_size); -#endif - } - break; - } - case LogicalTypeId::BIGINT: { - auto result_data = FlatVector::GetData(result); - for (idx_t r = 0; r < size; r++) { - result_data[r] = 0; - data_pointers[r] = data_ptr_cast(&result_data[r]); - } - break; - } - default: - throw InternalException("Unsupported key type for CreateSortKey"); - } -} - -static void FinalizeSortData(Vector &result, idx_t size) { - switch (result.GetType().id()) { - case LogicalTypeId::BLOB: { - auto result_data = FlatVector::GetData(result); - // call Finalize on the result - for (idx_t r = 0; r < size; r++) { - result_data[r].Finalize(); - } - break; - } - case LogicalTypeId::BIGINT: { - 
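	// Note (sketch): the fixed-size keys were radix-encoded most-significant-
	// byte first into this int64 buffer; the BSwap below converts each key to
	// native byte order so that ordinary BIGINT comparisons agree with the
	// memcmp order of the encoded bytes (the leading validity byte keeps the
	// swapped value non-negative).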
auto result_data = FlatVector::GetData(result); - for (idx_t r = 0; r < size; r++) { - result_data[r] = BSwap(result_data[r]); - } - break; - } - default: - throw InternalException("Unsupported key type for CreateSortKey"); - } -} - -static void CreateSortKeyInternal(vector> &sort_key_data, - const vector &modifiers, Vector &result, idx_t row_count) { - // two phases - // a) get the length of the final sorted key - // b) allocate the sorted key and construct - // we do all of this in a vectorized manner - SortKeyLengthInfo key_lengths(row_count); - for (auto &vector_data : sort_key_data) { - GetSortKeyLength(*vector_data, key_lengths); - } - // allocate the empty sort keys - auto data_pointers = unique_ptr(new data_ptr_t[row_count]); - PrepareSortData(result, row_count, key_lengths, data_pointers.get()); - - unsafe_vector offsets; - offsets.resize(row_count, 0); - // now construct the sort keys - for (idx_t c = 0; c < sort_key_data.size(); c++) { - SortKeyConstructInfo info(modifiers[c], offsets, data_pointers.get()); - ConstructSortKey(*sort_key_data[c], info); - } - FinalizeSortData(result, row_count); -} - -void CreateSortKeyHelpers::CreateSortKey(Vector &input, idx_t input_count, OrderModifiers order_modifier, - Vector &result) { - // prepare the sort key data - vector modifiers {order_modifier}; - vector> sort_key_data; - sort_key_data.push_back(make_uniq(input, input_count, order_modifier)); - - CreateSortKeyInternal(sort_key_data, modifiers, result, input_count); -} - -void CreateSortKeyHelpers::CreateSortKey(DataChunk &input, const vector &modifiers, Vector &result) { - vector> sort_key_data; - D_ASSERT(modifiers.size() == input.ColumnCount()); - for (idx_t r = 0; r < modifiers.size(); r++) { - sort_key_data.push_back(make_uniq(input.data[r], input.size(), modifiers[r])); - } - CreateSortKeyInternal(sort_key_data, modifiers, result, input.size()); -} - -void CreateSortKeyHelpers::CreateSortKeyWithValidity(Vector &input, Vector &result, const OrderModifiers &modifiers, - const idx_t count) { - CreateSortKey(input, count, modifiers, result); - UnifiedVectorFormat format; - input.ToUnifiedFormat(count, format); - auto &validity = FlatVector::Validity(result); - - for (idx_t i = 0; i < count; i++) { - auto idx = format.sel->get_index(i); - if (!format.validity.RowIsValid(idx)) { - validity.SetInvalid(i); - } - } -} - -static void CreateSortKeyFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &bind_data = state.expr.Cast().bind_info->Cast(); - - // prepare the sort key data - vector> sort_key_data; - for (idx_t c = 0; c < args.ColumnCount(); c += 2) { - sort_key_data.push_back(make_uniq(args.data[c], args.size(), bind_data.modifiers[c / 2])); - } - CreateSortKeyInternal(sort_key_data, bind_data.modifiers, result, args.size()); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -//===--------------------------------------------------------------------===// -// Decode Sort Key -//===--------------------------------------------------------------------===// -struct DecodeSortKeyVectorData { - DecodeSortKeyVectorData(const LogicalType &type, OrderModifiers modifiers) - : flip_bytes(modifiers.order_type == OrderType::DESCENDING) { - null_byte = SortKeyVectorData::NULL_FIRST_BYTE; - valid_byte = SortKeyVectorData::NULL_LAST_BYTE; - if (modifiers.null_type == OrderByNullType::NULLS_LAST) { - std::swap(null_byte, valid_byte); - } - - // NULLS FIRST/NULLS LAST passed in by the user are only respected at the top level - // within 
nested types NULLS LAST/NULLS FIRST is dependent on ASC/DESC order instead - // don't blame me this is what Postgres does - auto child_null_type = - modifiers.order_type == OrderType::ASCENDING ? OrderByNullType::NULLS_LAST : OrderByNullType::NULLS_FIRST; - OrderModifiers child_modifiers(modifiers.order_type, child_null_type); - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - auto &children = StructType::GetChildTypes(type); - for (auto &child_type : children) { - child_data.emplace_back(child_type.second, child_modifiers); - } - break; - } - case PhysicalType::ARRAY: { - auto &child_type = ArrayType::GetChildType(type); - child_data.emplace_back(child_type, child_modifiers); - break; - } - case PhysicalType::LIST: { - auto &child_type = ListType::GetChildType(type); - child_data.emplace_back(child_type, child_modifiers); - break; - } - default: - break; - } - } - - data_t null_byte; - data_t valid_byte; - vector child_data; - bool flip_bytes; -}; - -struct DecodeSortKeyData { - explicit DecodeSortKeyData(string_t &sort_key) - : data(const_data_ptr_cast(sort_key.GetData())), size(sort_key.GetSize()), position(0) { - } - - const_data_ptr_t data; - idx_t size; - idx_t position; -}; - -void DecodeSortKeyRecursive(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx); - -template -void TemplatedDecodeSortKey(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx) { - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; - if (validity_byte == vector_data.null_byte) { - // NULL value - FlatVector::Validity(result).SetInvalid(result_idx); - return; - } - idx_t increment = OP::Decode(decode_data.data + decode_data.position, result, result_idx, vector_data.flip_bytes); - decode_data.position += increment; -} - -void DecodeSortKeyStruct(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx) { - // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; - if (validity_byte == vector_data.null_byte) { - // entire struct is NULL - // note that we still deserialize the children - FlatVector::Validity(result).SetInvalid(result_idx); - } - // recurse into children - auto &child_entries = StructVector::GetEntries(result); - for (idx_t c = 0; c < child_entries.size(); c++) { - auto &child_entry = child_entries[c]; - DecodeSortKeyRecursive(decode_data, vector_data.child_data[c], *child_entry, result_idx); - } -} - -void DecodeSortKeyList(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx) { - // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; - if (validity_byte == vector_data.null_byte) { - // entire list is NULL - FlatVector::Validity(result).SetInvalid(result_idx); - return; - } - // list is valid - decode child elements - // we don't know how many there will be - // decode child elements until we encounter the list delimiter - auto list_delimiter = SortKeyVectorData::LIST_DELIMITER; - if (vector_data.flip_bytes) { - list_delimiter = ~list_delimiter; - } - auto list_data = FlatVector::GetData(result); - auto &child_vector = ListVector::GetEntry(result); - // get the current list size - auto start_list_size = ListVector::GetListSize(result); - auto new_list_size = start_list_size; - // loop until we find 
the list delimiter - while (decode_data.data[decode_data.position] != list_delimiter) { - // found a valid entry here - decode it - // first reserve space for it - new_list_size++; - ListVector::Reserve(result, new_list_size); - - // now decode the entry - DecodeSortKeyRecursive(decode_data, vector_data.child_data[0], child_vector, new_list_size - 1); - } - // skip the list delimiter - decode_data.position++; - // set the list_entry_t information and update the list size - list_data[result_idx].length = new_list_size - start_list_size; - list_data[result_idx].offset = start_list_size; - ListVector::SetListSize(result, new_list_size); -} - -void DecodeSortKeyArray(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx) { - // check if the top-level is valid or not - auto validity_byte = decode_data.data[decode_data.position]; - decode_data.position++; - if (validity_byte == vector_data.null_byte) { - // entire array is NULL - // note that we still read the child elements - FlatVector::Validity(result).SetInvalid(result_idx); - } - // array is valid - decode child elements - // arrays need to encode exactly array_size child elements - // however the decoded data still contains a list delimiter - // we use this delimiter to verify we successfully decoded the entire array - auto list_delimiter = SortKeyVectorData::LIST_DELIMITER; - if (vector_data.flip_bytes) { - list_delimiter = ~list_delimiter; - } - auto &child_vector = ArrayVector::GetEntry(result); - auto array_size = ArrayType::GetSize(result.GetType()); - - idx_t found_elements = 0; - auto child_start = array_size * result_idx; - // loop until we find the list delimiter - while (decode_data.data[decode_data.position] != list_delimiter) { - found_elements++; - if (found_elements > array_size) { - // error - found too many elements - break; - } - // now decode the entry - DecodeSortKeyRecursive(decode_data, vector_data.child_data[0], child_vector, child_start + found_elements - 1); - } - // skip the list delimiter - decode_data.position++; - if (found_elements != array_size) { - throw InvalidInputException("Failed to decode array - found %d elements but expected %d", found_elements, - array_size); - } -} - -void DecodeSortKeyRecursive(DecodeSortKeyData &decode_data, DecodeSortKeyVectorData &vector_data, Vector &result, - idx_t result_idx) { - switch (result.GetType().InternalType()) { - case PhysicalType::BOOL: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::UINT8: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INT8: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::UINT16: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INT16: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::UINT32: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INT32: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::UINT64: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INT64: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::FLOAT: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case 
PhysicalType::DOUBLE: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INTERVAL: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::UINT128: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::INT128: - TemplatedDecodeSortKey>(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::VARCHAR: - if (result.GetType().id() == LogicalTypeId::VARCHAR) { - TemplatedDecodeSortKey(decode_data, vector_data, result, result_idx); - } else { - TemplatedDecodeSortKey(decode_data, vector_data, result, result_idx); - } - break; - case PhysicalType::STRUCT: - DecodeSortKeyStruct(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::LIST: - DecodeSortKeyList(decode_data, vector_data, result, result_idx); - break; - case PhysicalType::ARRAY: - DecodeSortKeyArray(decode_data, vector_data, result, result_idx); - break; - default: - throw NotImplementedException("Unsupported type %s in DecodeSortKey", result.GetType()); - } -} - -void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, Vector &result, idx_t result_idx, - OrderModifiers modifiers) { - DecodeSortKeyVectorData sort_key_data(result.GetType(), modifiers); - DecodeSortKeyData decode_data(sort_key); - DecodeSortKeyRecursive(decode_data, sort_key_data, result, result_idx); -} - -void CreateSortKeyHelpers::DecodeSortKey(string_t sort_key, DataChunk &result, idx_t result_idx, - const vector &modifiers) { - DecodeSortKeyData decode_data(sort_key); - D_ASSERT(modifiers.size() == result.ColumnCount()); - for (idx_t c = 0; c < result.ColumnCount(); c++) { - auto &vec = result.data[c]; - DecodeSortKeyVectorData vector_data(vec.GetType(), modifiers[c]); - DecodeSortKeyRecursive(decode_data, vector_data, vec, result_idx); - } -} - -ScalarFunction CreateSortKeyFun::GetFunction() { - ScalarFunction sort_key_function("create_sort_key", {LogicalType::ANY}, LogicalType::BLOB, CreateSortKeyFunction, - CreateSortKeyBind); - sort_key_function.varargs = LogicalType::ANY; - sort_key_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return sort_key_function; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/date/strftime.cpp b/src/duckdb/src/function/scalar/date/strftime.cpp deleted file mode 100644 index 32222f191..000000000 --- a/src/duckdb/src/function/scalar/date/strftime.cpp +++ /dev/null @@ -1,333 +0,0 @@ -#include "duckdb/function/scalar/strftime_format.hpp" - -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/function/scalar/date_functions.hpp" - -#include -#include - -namespace duckdb { - -struct StrfTimeBindData : public FunctionData { - explicit StrfTimeBindData(StrfTimeFormat format_p, string format_string_p, bool is_null) - : format(std::move(format_p)), format_string(std::move(format_string_p)), is_null(is_null) { - } - - StrfTimeFormat format; - string format_string; - bool is_null; - - unique_ptr Copy() const override { - return make_uniq(format, format_string, is_null); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return format_string == other.format_string; - } -}; - -template -static unique_ptr StrfTimeBindFunction(ClientContext &context, 
ScalarFunction &bound_function, - vector> &arguments) { - auto format_idx = REVERSED ? 0U : 1U; - auto &format_arg = arguments[format_idx]; - if (format_arg->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!format_arg->IsFoldable()) { - throw InvalidInputException(*format_arg, "strftime format must be a constant"); - } - Value options_str = ExpressionExecutor::EvaluateScalar(context, *format_arg); - auto format_string = options_str.GetValue(); - StrfTimeFormat format; - bool is_null = options_str.IsNull(); - if (!is_null) { - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException(*format_arg, "Failed to parse format specifier %s: %s", format_string, error); - } - } - return make_uniq(format, format_string, is_null); -} - -template -static void StrfTimeFunctionDate(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.is_null) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - info.format.ConvertDateVector(args.data[REVERSED ? 1 : 0], result, args.size()); -} - -template -static void StrfTimeFunctionTimestamp(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.is_null) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - info.format.ConvertTimestampVector(args.data[REVERSED ? 1 : 0], result, args.size()); -} - -template -static void StrfTimeFunctionTimestampNS(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.is_null) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - info.format.ConvertTimestampNSVector(args.data[REVERSED ? 
1 : 0], result, args.size()); -} - -ScalarFunctionSet StrfTimeFun::GetFunctions() { - ScalarFunctionSet strftime("strftime"); - - strftime.AddFunction(ScalarFunction({LogicalType::DATE, LogicalType::VARCHAR}, LogicalType::VARCHAR, - StrfTimeFunctionDate, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP, LogicalType::VARCHAR}, LogicalType::VARCHAR, - StrfTimeFunctionTimestamp, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::TIMESTAMP_NS, LogicalType::VARCHAR}, LogicalType::VARCHAR, - StrfTimeFunctionTimestampNS, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::DATE}, LogicalType::VARCHAR, - StrfTimeFunctionDate, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP}, LogicalType::VARCHAR, - StrfTimeFunctionTimestamp, StrfTimeBindFunction)); - strftime.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::TIMESTAMP_NS}, LogicalType::VARCHAR, - StrfTimeFunctionTimestampNS, StrfTimeBindFunction)); - return strftime; -} - -StrpTimeFormat::StrpTimeFormat() { -} - -StrpTimeFormat::StrpTimeFormat(const string &format_string) { - if (format_string.empty()) { - return; - } - StrTimeFormat::ParseFormatSpecifier(format_string, *this); -} - -struct StrpTimeBindData : public FunctionData { - StrpTimeBindData(const StrpTimeFormat &format, const string &format_string) - : formats(1, format), format_strings(1, format_string) { - } - - StrpTimeBindData(vector formats_p, vector format_strings_p) - : formats(std::move(formats_p)), format_strings(std::move(format_strings_p)) { - } - - vector formats; - vector format_strings; - - unique_ptr Copy() const override { - return make_uniq(formats, format_strings); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return format_strings == other.format_strings; - } -}; - -template -inline T StrpTimeResult(StrpTimeFormat::ParseResult &parsed) { - return parsed.ToTimestamp(); -} - -template <> -inline timestamp_ns_t StrpTimeResult(StrpTimeFormat::ParseResult &parsed) { - return parsed.ToTimestampNS(); -} - -template -inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, T &result, string &error) { - return format.TryParseTimestamp(input, result, error); -} - -template <> -inline bool StrpTimeTryResult(StrpTimeFormat &format, string_t &input, timestamp_ns_t &result, string &error) { - return format.TryParseTimestampNS(input, result, error); -} - -struct StrpTimeFunction { - - template - static void Parse(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // There is a bizarre situation where the format column is foldable but not constant - // (i.e., the statistics tell us it has only one value) - // We have to check whether that value is NULL - const auto count = args.size(); - UnifiedVectorFormat format_unified; - args.data[1].ToUnifiedFormat(count, format_unified); - - if (!format_unified.validity.RowIsValid(0)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - UnaryExecutor::Execute(args.data[0], result, args.size(), [&](string_t input) { - StrpTimeFormat::ParseResult result; - for (auto &format : info.formats) { - if (format.Parse(input, result)) { - return StrpTimeResult(result); - } - } - throw InvalidInputException(result.FormatError(input, 
info.formats[0].format_specifier)); - }); - } - - template - static void TryParse(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (args.data[1].GetVectorType() == VectorType::CONSTANT_VECTOR && ConstantVector::IsNull(args.data[1])) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - UnaryExecutor::ExecuteWithNulls(args.data[0], result, args.size(), - [&](string_t input, ValidityMask &mask, idx_t idx) { - T result; - string error; - for (auto &format : info.formats) { - if (StrpTimeTryResult(format, input, result, error)) { - return result; - } - } - - mask.SetInvalid(idx); - return T(); - }); - } - - static unique_ptr Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[1]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[1]->IsFoldable()) { - throw InvalidInputException(*arguments[0], "strptime format must be a constant"); - } - Value format_value = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - string format_string; - StrpTimeFormat format; - if (format_value.IsNull()) { - return make_uniq(format, format_string); - } else if (format_value.type().id() == LogicalTypeId::VARCHAR) { - format_string = format_value.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, - error); - } - if (format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET)) { - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } else if (format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED)) { - bound_function.return_type = LogicalType::TIMESTAMP_NS; - if (bound_function.name == "strptime") { - bound_function.function = Parse; - } else { - bound_function.function = TryParse; - } - } - return make_uniq(format, format_string); - } else if (format_value.type() == LogicalType::LIST(LogicalType::VARCHAR)) { - const auto &children = ListValue::GetChildren(format_value); - if (children.empty()) { - throw InvalidInputException(*arguments[0], "strptime format list must not be empty"); - } - vector format_strings; - vector formats; - bool has_offset = false; - bool has_nanos = false; - - for (const auto &child : children) { - format_string = child.ToString(); - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException(*arguments[0], "Failed to parse format specifier %s: %s", format_string, - error); - } - has_offset = has_offset || format.HasFormatSpecifier(StrTimeSpecifier::UTC_OFFSET); - has_nanos = has_nanos || format.HasFormatSpecifier(StrTimeSpecifier::NANOSECOND_PADDED); - format_strings.emplace_back(format_string); - formats.emplace_back(format); - } - - if (has_offset) { - // If any format has UTC offsets, then we have to produce TSTZ - bound_function.return_type = LogicalType::TIMESTAMP_TZ; - } else if (has_nanos) { - // If any format has nanoseconds, then we have to produce TSNS - // unless there is an offset, in which case we produce - bound_function.return_type = LogicalType::TIMESTAMP_NS; - if (bound_function.name == "strptime") { - bound_function.function = Parse; - } else { - bound_function.function = TryParse; - } - } - return 
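As an aside, the list branch handled below is what makes multi-format parsing work. A usage sketch (not part of the original diff) via DuckDB's C++ client API; the formats are tried left to right and the first successful parse wins:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// '%m/%d/%Y' fails on this input, so the second format is used.
	auto result = con.Query(
	    "SELECT strptime('2023-07-23', ['%m/%d/%Y', '%Y-%m-%d']) AS ts");
	result->Print();
	return 0;
}
```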
make_uniq(formats, format_strings); - } else { - throw InvalidInputException(*arguments[0], "strptime format must be a string"); - } - } -}; - -ScalarFunctionSet StrpTimeFun::GetFunctions() { - ScalarFunctionSet strptime("strptime"); - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::Parse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(fun); - strptime.AddFunction(fun); - - fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, - StrpTimeFunction::Parse, StrpTimeFunction::Bind); - BaseScalarFunction::SetReturnsError(fun); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - strptime.AddFunction(fun); - return strptime; -} - -ScalarFunctionSet TryStrpTimeFun::GetFunctions() { - ScalarFunctionSet try_strptime("try_strptime"); - - const auto list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto fun = ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::TIMESTAMP, - StrpTimeFunction::TryParse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - try_strptime.AddFunction(fun); - - fun = ScalarFunction({LogicalType::VARCHAR, list_type}, LogicalType::TIMESTAMP, - StrpTimeFunction::TryParse, StrpTimeFunction::Bind); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - try_strptime.AddFunction(fun); - - return try_strptime; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp b/src/duckdb/src/function/scalar/generic/constant_or_null.cpp deleted file mode 100644 index 32b9f855c..000000000 --- a/src/duckdb/src/function/scalar/generic/constant_or_null.cpp +++ /dev/null @@ -1,105 +0,0 @@ -#include "duckdb/function/scalar/generic_common.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/common/operator/comparison_operators.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -struct ConstantOrNullBindData : public FunctionData { - explicit ConstantOrNullBindData(Value val) : value(std::move(val)) { - } - - Value value; - -public: - unique_ptr Copy() const override { - return make_uniq(value); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return value == other.value; - } -}; - -static void ConstantOrNullFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - result.Reference(info.value); - for (idx_t idx = 1; idx < args.ColumnCount(); idx++) { - switch (args.data[idx].GetVectorType()) { - case VectorType::FLAT_VECTOR: { - auto &input_mask = FlatVector::Validity(args.data[idx]); - if (!input_mask.AllValid()) { - // there are null values: need to merge them into the result - result.Flatten(args.size()); - auto &result_mask = FlatVector::Validity(result); - result_mask.Combine(input_mask, args.size()); - } - break; - } - case VectorType::CONSTANT_VECTOR: { - if (ConstantVector::IsNull(args.data[idx])) { - // input is constant null, return constant null - result.Reference(info.value); - ConstantVector::SetNull(result, true); - return; - } - break; - } - default: { - UnifiedVectorFormat vdata; - args.data[idx].ToUnifiedFormat(args.size(), vdata); - if 
(!vdata.validity.AllValid()) { - result.Flatten(args.size()); - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < args.size(); i++) { - if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - result_mask.SetInvalid(i); - } - } - } - break; - } - } - } -} - -unique_ptr ConstantOrNull::Bind(Value value) { - return make_uniq(std::move(value)); -} - -bool ConstantOrNull::IsConstantOrNull(BoundFunctionExpression &expr, const Value &val) { - if (expr.function.name != "constant_or_null") { - return false; - } - D_ASSERT(expr.bind_info); - auto &bind_data = expr.bind_info->Cast(); - D_ASSERT(bind_data.value.type() == val.type()); - return bind_data.value == val; -} - -unique_ptr ConstantOrNullBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[0]->IsFoldable()) { - throw BinderException("ConstantOrNull requires a constant input"); - } - D_ASSERT(arguments.size() >= 2); - auto value = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - bound_function.return_type = arguments[0]->return_type; - return make_uniq(std::move(value)); -} - -ScalarFunction ConstantOrNullFun::GetFunction() { - auto fun = ScalarFunction("constant_or_null", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, - ConstantOrNullFunction); - fun.bind = ConstantOrNullBind; - fun.varargs = LogicalType::ANY; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic/error.cpp b/src/duckdb/src/function/scalar/generic/error.cpp deleted file mode 100644 index 2d42cfced..000000000 --- a/src/duckdb/src/function/scalar/generic/error.cpp +++ /dev/null @@ -1,23 +0,0 @@ -#include "duckdb/function/scalar/generic_functions.hpp" - -#include - -namespace duckdb { - -struct ErrorOperator { - template - static inline TR Operation(const TA &input) { - throw InvalidInputException(input.GetString()); - } -}; - -ScalarFunction ErrorFun::GetFunction() { - auto fun = ScalarFunction("error", {LogicalType::VARCHAR}, LogicalType::SQLNULL, - ScalarFunction::UnaryFunction); - // Set the function with side effects to avoid the optimization. 
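To illustrate the comment above: with a constant argument, a non-VOLATILE function could be folded at bind time, which would raise the error even for queries that never reach the call at runtime. A sketch of the observable behaviour (not part of the original diff), using the C++ client API:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// No row satisfies i < 0, so error() must never fire; VOLATILE keeps the
	// optimizer from evaluating the constant call eagerly.
	auto result = con.Query(
	    "SELECT CASE WHEN i < 0 THEN error('negative!') ELSE i END "
	    "FROM range(3) t(i)");
	result->Print();
	return 0;
}
```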
- fun.stability = FunctionStability::VOLATILE; - BaseScalarFunction::SetReturnsError(fun); - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/generic/getvariable.cpp b/src/duckdb/src/function/scalar/generic/getvariable.cpp deleted file mode 100644 index 14d32954d..000000000 --- a/src/duckdb/src/function/scalar/generic/getvariable.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/transaction/meta_transaction.hpp" - -namespace duckdb { - -struct GetVariableBindData : FunctionData { - explicit GetVariableBindData(Value value_p) : value(std::move(value_p)) { - } - - Value value; - - bool Equals(const FunctionData &other_p) const override { - const auto &other = other_p.Cast(); - return Value::NotDistinctFrom(value, other.value); - } - - unique_ptr Copy() const override { - return make_uniq(value); - } -}; - -static unique_ptr GetVariableBind(ClientContext &context, ScalarFunction &function, - vector> &arguments) { - if (!arguments[0]->IsFoldable()) { - throw NotImplementedException("getvariable requires a constant input"); - } - if (arguments[0]->HasParameter()) { - throw ParameterNotResolvedException(); - } - Value value; - auto variable_name = ExpressionExecutor::EvaluateScalar(context, *arguments[0]); - if (!variable_name.IsNull()) { - ClientConfig::GetConfig(context).GetUserVariable(variable_name.ToString(), value); - } - function.return_type = value.type(); - return make_uniq(std::move(value)); -} - -unique_ptr BindGetVariableExpression(FunctionBindExpressionInput &input) { - if (!input.bind_data) { - // unknown type - throw InternalException("input.bind_data should be set"); - } - auto &bind_data = input.bind_data->Cast(); - // emit a constant expression - return make_uniq(bind_data.value); -} - -ScalarFunction GetVariableFun::GetFunction() { - ScalarFunction getvar("getvariable", {LogicalType::VARCHAR}, LogicalType::ANY, nullptr, GetVariableBind, nullptr); - getvar.bind_expression = BindGetVariableExpression; - return getvar; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/contains_or_position.cpp b/src/duckdb/src/function/scalar/list/contains_or_position.cpp deleted file mode 100644 index 309d78c39..000000000 --- a/src/duckdb/src/function/scalar/list/contains_or_position.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/function/scalar/list/contains_or_position.hpp" - -namespace duckdb { - -template -static void ListSearchFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto target_count = input.size(); - auto &list_vec = input.data[0]; - auto &source_vec = ListVector::GetEntry(list_vec); - auto &target_vec = input.data[1]; - - ListSearchOp(list_vec, source_vec, target_vec, result, target_count); - - if (target_count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ListSearchBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - // If the first argument is an array, cast it to a list - arguments[0] = 
BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - const auto &list = arguments[0]->return_type; - const auto &value = arguments[1]->return_type; - - const auto list_is_param = list.id() == LogicalTypeId::UNKNOWN; - const auto value_is_param = value.id() == LogicalTypeId::UNKNOWN; - - if (list_is_param) { - if (!value_is_param) { - // only list is a parameter, cast it to a list of value type - bound_function.arguments[0] = LogicalType::LIST(value); - bound_function.arguments[1] = value; - } - } else if (value_is_param) { - // only value is a parameter: we expect the child type of list - bound_function.arguments[0] = list; - bound_function.arguments[1] = ListType::GetChildType(list); - } else { - LogicalType max_child_type; - if (!LogicalType::TryGetMaxLogicalType(context, ListType::GetChildType(list), value, max_child_type)) { - throw BinderException( - "%s: Cannot match element of type '%s' in a list of type '%s' - an explicit cast is required", - bound_function.name, value.ToString(), list.ToString()); - } - - bound_function.arguments[0] = LogicalType::LIST(max_child_type); - bound_function.arguments[1] = max_child_type; - } - return make_uniq(bound_function.return_type); -} - -ScalarFunction ListContainsFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::BOOLEAN, - ListSearchFunction, ListSearchBind); -} - -ScalarFunction ListPositionFun::GetFunction() { - return ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::ANY}, LogicalType::INTEGER, - ListSearchFunction, ListSearchBind); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_extract.cpp b/src/duckdb/src/function/scalar/list/list_extract.cpp deleted file mode 100644 index 058ce5f42..000000000 --- a/src/duckdb/src/function/scalar/list/list_extract.cpp +++ /dev/null @@ -1,190 +0,0 @@ -#include "duckdb/common/pair.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/uhugeint.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/storage/statistics/list_stats.hpp" - -namespace duckdb { - -static optional_idx TryGetChildOffset(const list_entry_t &list_entry, const int64_t offset) { - // 1-based indexing - if (offset == 0) { - return optional_idx::Invalid(); - } - - const auto index_offset = (offset > 0) ? 
offset - 1 : offset;
-    if (index_offset < 0) {
-        const auto signed_list_length = UnsafeNumericCast<int64_t>(list_entry.length);
-        if (signed_list_length + index_offset < 0) {
-            return optional_idx::Invalid();
-        }
-        return optional_idx(list_entry.offset + UnsafeNumericCast<idx_t>(signed_list_length + index_offset));
-    }
-
-    const auto unsigned_offset = UnsafeNumericCast<idx_t>(index_offset);
-
-    // Check that the offset is within the list
-    if (unsigned_offset >= list_entry.length) {
-        return optional_idx::Invalid();
-    }
-
-    return optional_idx(list_entry.offset + unsigned_offset);
-}
-
-static void ExecuteListExtract(Vector &result, Vector &list, Vector &offsets, const idx_t count) {
-    D_ASSERT(list.GetType().id() == LogicalTypeId::LIST);
-    UnifiedVectorFormat list_data;
-    UnifiedVectorFormat offsets_data;
-
-    list.ToUnifiedFormat(count, list_data);
-    offsets.ToUnifiedFormat(count, offsets_data);
-
-    const auto list_ptr = UnifiedVectorFormat::GetData<list_entry_t>(list_data);
-    const auto offsets_ptr = UnifiedVectorFormat::GetData<int64_t>(offsets_data);
-
-    UnifiedVectorFormat child_data;
-    auto &child_vector = ListVector::GetEntry(list);
-    auto child_count = ListVector::GetListSize(list);
-    child_vector.ToUnifiedFormat(child_count, child_data);
-
-    SelectionVector sel(count);
-    vector<idx_t> invalid_offsets;
-
-    optional_idx first_valid_child_idx;
-    for (idx_t i = 0; i < count; i++) {
-        const auto list_index = list_data.sel->get_index(i);
-        const auto offsets_index = offsets_data.sel->get_index(i);
-
-        if (!list_data.validity.RowIsValid(list_index) || !offsets_data.validity.RowIsValid(offsets_index)) {
-            invalid_offsets.push_back(i);
-            continue;
-        }
-
-        const auto child_offset = TryGetChildOffset(list_ptr[list_index], offsets_ptr[offsets_index]);
-
-        if (!child_offset.IsValid()) {
-            invalid_offsets.push_back(i);
-            continue;
-        }
-
-        const auto child_idx = child_data.sel->get_index(child_offset.GetIndex());
-        sel.set_index(i, child_idx);
-
-        if (!first_valid_child_idx.IsValid()) {
-            // Save the first valid child as a dummy index to copy in VectorOperations::Copy later
-            first_valid_child_idx = child_idx;
-        }
-    }
-
-    if (first_valid_child_idx.IsValid()) {
-        // Only copy if we found at least one valid child
-        for (const auto &invalid_offset : invalid_offsets) {
-            sel.set_index(invalid_offset, first_valid_child_idx.GetIndex());
-        }
-        VectorOperations::Copy(child_vector, result, sel, count, 0, 0);
-    }
-
-    // Copying the vectors also copies the validity mask, so here we explicitly set the rows with invalid offsets to NULL.
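The subscript rule implemented by TryGetChildOffset above, restated as a standalone sketch over plain standard types (an assumed simplification of DuckDB's list_entry_t and optional_idx):

```cpp
#include <cstdint>
#include <optional>

struct ListEntry {
	uint64_t offset; // start of the list in the child vector
	uint64_t length; // number of elements in the list
};

// Returns the child-vector offset for a 1-based subscript, where negative
// subscripts count from the back; nullopt means "out of range -> NULL".
std::optional<uint64_t> ResolveSubscript(ListEntry list, int64_t subscript) {
	if (subscript == 0) {
		return std::nullopt; // there is no element 0 in 1-based indexing
	}
	if (subscript < 0) {
		int64_t idx = static_cast<int64_t>(list.length) + subscript;
		if (idx < 0) {
			return std::nullopt;
		}
		return list.offset + static_cast<uint64_t>(idx);
	}
	uint64_t idx = static_cast<uint64_t>(subscript - 1);
	if (idx >= list.length) {
		return std::nullopt;
	}
	return list.offset + idx;
}
```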
- for (const auto &invalid_idx : invalid_offsets) { - FlatVector::SetNull(result, invalid_idx, true); - } - - if (count == 1 || (list.GetVectorType() == VectorType::CONSTANT_VECTOR && - offsets.GetVectorType() == VectorType::CONSTANT_VECTOR)) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(count); -} - -static void ExecuteStringExtract(Vector &result, Vector &input_vector, Vector &subscript_vector, const idx_t count) { - BinaryExecutor::Execute( - input_vector, subscript_vector, result, count, - [&](string_t input_string, int64_t subscript) { return SubstringUnicode(result, input_string, subscript, 1); }); -} - -static void ListExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 2); - auto count = args.size(); - - Vector &base = args.data[0]; - Vector &subscript = args.data[1]; - - switch (base.GetType().id()) { - case LogicalTypeId::LIST: - ExecuteListExtract(result, base, subscript, count); - break; - case LogicalTypeId::VARCHAR: - ExecuteStringExtract(result, base, subscript, count); - break; - case LogicalTypeId::SQLNULL: - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - break; - default: - throw NotImplementedException("Specifier type not implemented"); - } -} - -static unique_ptr ListExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id()); - // list extract returns the child type of the list as return type - auto child_type = ListType::GetChildType(arguments[0]->return_type); - - bound_function.return_type = child_type; - bound_function.arguments[0] = LogicalType::LIST(child_type); - return make_uniq(bound_function.return_type); -} - -static unique_ptr ListExtractStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &list_child_stats = ListStats::GetChildStats(child_stats[0]); - auto child_copy = list_child_stats.Copy(); - // list_extract always pushes a NULL, since if the offset is out of range for a list it inserts a null - child_copy.Set(StatsInfo::CAN_HAVE_NULL_VALUES); - return child_copy.ToUnique(); -} - -ScalarFunctionSet ListExtractFun::GetFunctions() { - ScalarFunctionSet list_extract_set("list_extract"); - - // the arguments and return types are actually set in the binder function - ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, - ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); - - ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); - BaseScalarFunction::SetReturnsError(lfun); - BaseScalarFunction::SetReturnsError(sfun); - list_extract_set.AddFunction(lfun); - list_extract_set.AddFunction(sfun); - return list_extract_set; -} - -ScalarFunctionSet ArrayExtractFun::GetFunctions() { - ScalarFunctionSet array_extract_set("array_extract"); - - // the arguments and return types are actually set in the binder function - ScalarFunction lfun({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, LogicalType::ANY, - ListExtractFunction, ListExtractBind, nullptr, ListExtractStats); - - ScalarFunction sfun({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, ListExtractFunction); - - 
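A usage sketch (not part of the original diff) of the NULL-on-out-of-range behaviour that ListExtractStats accounts for above:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	auto result = con.Query(
	    "SELECT list_extract([10, 20, 30], 1), "
	    "list_extract([10, 20, 30], -1), "
	    "list_extract([10, 20, 30], 5)");
	// Expected: 10, 30, NULL - subscripts are 1-based, negative subscripts
	// count from the back, and out-of-range subscripts yield NULL.
	result->Print();
	return 0;
}
```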
array_extract_set.AddFunction(lfun); - array_extract_set.AddFunction(sfun); - array_extract_set.AddFunction(GetKeyExtractFunction()); - array_extract_set.AddFunction(GetIndexExtractFunction()); - return array_extract_set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_resize.cpp b/src/duckdb/src/function/scalar/list/list_resize.cpp deleted file mode 100644 index 019dfcc72..000000000 --- a/src/duckdb/src/function/scalar/list/list_resize.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" - -namespace duckdb { - -void ListResizeFunction(DataChunk &args, ExpressionState &, Vector &result) { - - // Early-out, if the return value is a constant NULL. - if (result.GetType().id() == LogicalTypeId::SQLNULL) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - - auto &lists = args.data[0]; - auto &new_sizes = args.data[1]; - auto row_count = args.size(); - - UnifiedVectorFormat lists_data; - lists.ToUnifiedFormat(row_count, lists_data); - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto list_entries = UnifiedVectorFormat::GetData(lists_data); - - auto &child_vector = ListVector::GetEntry(lists); - UnifiedVectorFormat child_data; - child_vector.ToUnifiedFormat(row_count, child_data); - - UnifiedVectorFormat new_sizes_data; - new_sizes.ToUnifiedFormat(row_count, new_sizes_data); - D_ASSERT(new_sizes.GetType().id() == LogicalTypeId::UBIGINT); - auto new_size_entries = UnifiedVectorFormat::GetData(new_sizes_data); - - // Get the new size of the result child vector. - // We skip rows with NULL values in the input lists. - idx_t child_vector_size = 0; - for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - auto list_idx = lists_data.sel->get_index(row_idx); - auto new_size_idx = new_sizes_data.sel->get_index(row_idx); - - if (lists_data.validity.RowIsValid(list_idx) && new_sizes_data.validity.RowIsValid(new_size_idx)) { - child_vector_size += new_size_entries[new_size_idx]; - } - } - ListVector::Reserve(result, child_vector_size); - ListVector::SetListSize(result, child_vector_size); - - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_entries = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - auto &result_child_vector = ListVector::GetEntry(result); - - // Get the default values, if provided. - UnifiedVectorFormat default_data; - optional_ptr default_vector; - if (args.ColumnCount() == 3) { - default_vector = &args.data[2]; - default_vector->ToUnifiedFormat(row_count, default_data); - } - - idx_t offset = 0; - for (idx_t row_idx = 0; row_idx < row_count; row_idx++) { - - auto list_idx = lists_data.sel->get_index(row_idx); - auto new_size_idx = new_sizes_data.sel->get_index(row_idx); - - // Set to NULL, if the list is NULL. - if (!lists_data.validity.RowIsValid(list_idx)) { - result_validity.SetInvalid(row_idx); - continue; - } - - idx_t new_size = 0; - if (new_sizes_data.validity.RowIsValid(new_size_idx)) { - new_size = new_size_entries[new_size_idx]; - } - - // If new_size >= length, then we copy [0, length) values. - // If new_size < length, then we copy [0, new_size) values. 
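The copy/pad rule in the comments above, restated as a standalone sketch over std::vector (an assumed simplification, not DuckDB's vector machinery): copy min(old_length, new_size) existing elements, then pad with the default value (or NULL) up to new_size.

```cpp
#include <algorithm>
#include <cstddef>
#include <optional>
#include <vector>

template <class T>
std::vector<std::optional<T>> Resize(const std::vector<std::optional<T>> &input,
                                     size_t new_size, std::optional<T> fill) {
	std::vector<std::optional<T>> result;
	result.reserve(new_size);
	size_t copy_count = std::min(input.size(), new_size);
	for (size_t i = 0; i < copy_count; i++) {
		result.push_back(input[i]); // keep the existing prefix
	}
	for (size_t i = copy_count; i < new_size; i++) {
		result.push_back(fill); // pad with the default value, or NULL
	}
	return result;
}
```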
-        auto copy_count = MinValue<idx_t>(list_entries[list_idx].length, new_size);
-
-        // Set the result entry.
-        result_entries[row_idx].offset = offset;
-        result_entries[row_idx].length = new_size;
-
-        // Copy the child vector's values.
-        // The number of elements to copy is later determined like so: source_count - source_offset.
-        idx_t source_offset = list_entries[list_idx].offset;
-        idx_t source_count = source_offset + copy_count;
-        VectorOperations::Copy(child_vector, result_child_vector, source_count, source_offset, offset);
-        offset += copy_count;
-
-        // Fill the remaining space with the default values.
-        if (copy_count < new_size) {
-            idx_t remaining_count = new_size - copy_count;
-
-            if (default_vector) {
-                auto default_idx = default_data.sel->get_index(row_idx);
-                if (default_data.validity.RowIsValid(default_idx)) {
-                    SelectionVector sel(remaining_count);
-                    for (idx_t j = 0; j < remaining_count; j++) {
-                        sel.set_index(j, row_idx);
-                    }
-                    VectorOperations::Copy(*default_vector, result_child_vector, sel, remaining_count, 0, offset);
-                    offset += remaining_count;
-                    continue;
-                }
-            }
-
-            // Fill the remaining space with NULL.
-            for (idx_t j = copy_count; j < new_size; j++) {
-                FlatVector::SetNull(result_child_vector, offset, true);
-                offset++;
-            }
-        }
-    }
-
-    if (args.AllConstant()) {
-        result.SetVectorType(VectorType::CONSTANT_VECTOR);
-    }
-}
-
-static unique_ptr<FunctionData> ListResizeBind(ClientContext &context, ScalarFunction &bound_function,
-                                               vector<unique_ptr<Expression>> &arguments) {
-    D_ASSERT(bound_function.arguments.size() == 2 || arguments.size() == 3);
-    bound_function.arguments[1] = LogicalType::UBIGINT;
-
-    // If the first argument is an array, cast it to a list.
-    arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0]));
-
-    // Early-out, if the first argument is a constant NULL.
-    if (arguments[0]->return_type == LogicalType::SQLNULL) {
-        bound_function.arguments[0] = LogicalType::SQLNULL;
-        bound_function.return_type = LogicalType::SQLNULL;
-        return make_uniq<VariableReturnBindData>(bound_function.return_type);
-    }
-
-    // Early-out, if the first argument is a prepared statement.
-    if (arguments[0]->return_type == LogicalType::UNKNOWN) {
-        bound_function.return_type = arguments[0]->return_type;
-        return make_uniq<VariableReturnBindData>(bound_function.return_type);
-    }
-
-    // Attempt implicit casting, if the default type does not match the list's child type.
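A usage sketch (not part of the original diff) of the two- and three-argument forms bound here, via the C++ client API:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	auto result = con.Query(
	    "SELECT list_resize([1, 2, 3], 5), " // [1, 2, 3, NULL, NULL]
	    "list_resize([1, 2, 3], 5, 0), "     // [1, 2, 3, 0, 0]
	    "list_resize([1, 2, 3], 2)");        // [1, 2]
	result->Print();
	return 0;
}
```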
- if (bound_function.arguments.size() == 3 && - ListType::GetChildType(arguments[0]->return_type) != arguments[2]->return_type && - arguments[2]->return_type != LogicalTypeId::SQLNULL) { - bound_function.arguments[2] = ListType::GetChildType(arguments[0]->return_type); - } - - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); -} - -ScalarFunctionSet ListResizeFun::GetFunctions() { - ScalarFunction simple_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - simple_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(simple_fun); - ScalarFunction default_value_fun({LogicalType::LIST(LogicalTypeId::ANY), LogicalTypeId::ANY, LogicalTypeId::ANY}, - LogicalType::LIST(LogicalTypeId::ANY), ListResizeFunction, ListResizeBind); - default_value_fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - BaseScalarFunction::SetReturnsError(default_value_fun); - ScalarFunctionSet list_resize_set("list_resize"); - list_resize_set.AddFunction(simple_fun); - list_resize_set.AddFunction(default_value_fun); - return list_resize_set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_select.cpp b/src/duckdb/src/function/scalar/list/list_select.cpp deleted file mode 100644 index 55c6a9f1d..000000000 --- a/src/duckdb/src/function/scalar/list/list_select.cpp +++ /dev/null @@ -1,183 +0,0 @@ -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/function/scalar/list_functions.hpp" - -namespace duckdb { - -struct SetSelectionVectorSelect { - static void SetSelectionVector(SelectionVector &selection_vector, ValidityMask &validity_mask, - ValidityMask &input_validity, Vector &selection_entry, idx_t child_idx, - idx_t &target_offset, idx_t selection_offset, idx_t input_offset, - idx_t target_length) { - auto sel_idx = selection_entry.GetValue(selection_offset + child_idx).GetValue() - 1; - if (sel_idx >= 0 && sel_idx < UnsafeNumericCast(target_length)) { - auto sel_idx_unsigned = UnsafeNumericCast(sel_idx); - selection_vector.set_index(target_offset, input_offset + sel_idx_unsigned); - if (!input_validity.RowIsValid(input_offset + sel_idx_unsigned)) { - validity_mask.SetInvalid(target_offset); - } - } else { - selection_vector.set_index(target_offset, 0); - validity_mask.SetInvalid(target_offset); - } - target_offset++; - } - - static void GetResultLength(DataChunk &args, idx_t &result_length, const list_entry_t *selection_data, - Vector selection_entry, idx_t selection_idx) { - result_length += selection_data[selection_idx].length; - } -}; - -struct SetSelectionVectorWhere { - static void SetSelectionVector(SelectionVector &selection_vector, ValidityMask &validity_mask, - ValidityMask &input_validity, Vector &selection_entry, idx_t child_idx, - idx_t &target_offset, idx_t selection_offset, idx_t input_offset, - idx_t target_length) { - if (!selection_entry.GetValue(selection_offset + child_idx).GetValue()) { - return; - } - - selection_vector.set_index(target_offset, input_offset + child_idx); - if (!input_validity.RowIsValid(input_offset + child_idx)) { - 
validity_mask.SetInvalid(target_offset); - } - - if (child_idx >= target_length) { - selection_vector.set_index(target_offset, 0); - validity_mask.SetInvalid(target_offset); - } - - target_offset++; - } - - static void GetResultLength(DataChunk &args, idx_t &result_length, const list_entry_t *selection_data, - Vector selection_entry, idx_t selection_idx) { - for (idx_t child_idx = 0; child_idx < selection_data[selection_idx].length; child_idx++) { - if (selection_entry.GetValue(selection_data[selection_idx].offset + child_idx).IsNull()) { - throw InvalidInputException("NULLs are not allowed as list elements in the second input parameter."); - } - if (selection_entry.GetValue(selection_data[selection_idx].offset + child_idx).GetValue()) { - result_length++; - } - } - } -}; - -template -static void ListSelectFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.data.size() == 2); - Vector &list = args.data[0]; - Vector &selection_list = args.data[1]; - idx_t count = args.size(); - - list_entry_t *result_data; - result_data = FlatVector::GetData(result); - auto &result_entry = ListVector::GetEntry(result); - - UnifiedVectorFormat selection_lists; - selection_list.ToUnifiedFormat(count, selection_lists); - auto selection_lists_data = UnifiedVectorFormat::GetData(selection_lists); - auto &selection_entry = ListVector::GetEntry(selection_list); - - UnifiedVectorFormat input_list; - list.ToUnifiedFormat(count, input_list); - auto input_lists_data = UnifiedVectorFormat::GetData(input_list); - auto &input_entry = ListVector::GetEntry(list); - auto &input_validity = FlatVector::Validity(input_entry); - - idx_t result_length = 0; - for (idx_t i = 0; i < count; i++) { - idx_t input_idx = input_list.sel->get_index(i); - idx_t selection_idx = selection_lists.sel->get_index(i); - if (input_list.validity.RowIsValid(input_idx) && selection_lists.validity.RowIsValid(selection_idx)) { - OP::GetResultLength(args, result_length, selection_lists_data, selection_entry, selection_idx); - } - } - - ListVector::Reserve(result, result_length); - SelectionVector result_selection_vec = SelectionVector(result_length); - ValidityMask entry_validity_mask = ValidityMask(result_length); - ValidityMask &result_validity_mask = FlatVector::Validity(result); - - idx_t offset = 0; - for (idx_t j = 0; j < count; j++) { - // Get length and offset of selection list for current output row - auto selection_list_idx = selection_lists.sel->get_index(j); - idx_t selection_len = 0; - idx_t selection_offset = 0; - if (selection_lists.validity.RowIsValid(selection_list_idx)) { - selection_len = selection_lists_data[selection_list_idx].length; - selection_offset = selection_lists_data[selection_list_idx].offset; - } else { - result_validity_mask.SetInvalid(j); - continue; - } - // Get length and offset of input list for current output row - auto input_list_idx = input_list.sel->get_index(j); - idx_t input_length = 0; - idx_t input_offset = 0; - if (input_list.validity.RowIsValid(input_list_idx)) { - input_length = input_lists_data[input_list_idx].length; - input_offset = input_lists_data[input_list_idx].offset; - } else { - result_validity_mask.SetInvalid(j); - continue; - } - result_data[j].offset = offset; - // Set all selected values in the result - for (idx_t child_idx = 0; child_idx < selection_len; child_idx++) { - if (selection_entry.GetValue(selection_offset + child_idx).IsNull()) { - throw InvalidInputException("NULLs are not allowed as list elements in the second input parameter."); - } - 
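The two selection strategies above differ only in how they interpret the second list: SELECT picks elements by 1-based index, WHERE keeps elements whose mask entry is true. A standalone sketch (an assumed simplification over std::vector, not the original diff):

```cpp
#include <cstdint>
#include <optional>
#include <vector>

template <class T>
using NullableList = std::vector<std::optional<T>>;

template <class T>
NullableList<T> ListSelect(const NullableList<T> &input, const std::vector<int64_t> &sel) {
	NullableList<T> result;
	for (auto idx : sel) {
		// 1-based; out-of-range indices become NULL, mirroring the code above
		if (idx >= 1 && static_cast<size_t>(idx) <= input.size()) {
			result.push_back(input[static_cast<size_t>(idx - 1)]);
		} else {
			result.push_back(std::nullopt);
		}
	}
	return result;
}

template <class T>
NullableList<T> ListWhere(const NullableList<T> &input, const std::vector<bool> &mask) {
	NullableList<T> result;
	for (size_t i = 0; i < input.size() && i < mask.size(); i++) {
		if (mask[i]) {
			result.push_back(input[i]); // keep only elements with a true mask entry
		}
	}
	return result;
}
```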
OP::SetSelectionVector(result_selection_vec, entry_validity_mask, input_validity, selection_entry, - child_idx, offset, selection_offset, input_offset, input_length); - } - result_data[j].length = offset - result_data[j].offset; - } - result_entry.Slice(input_entry, result_selection_vec, count); - result_entry.Flatten(offset); - ListVector::SetListSize(result, offset); - FlatVector::SetValidity(result_entry, entry_validity_mask); - result.SetVectorType(args.AllConstant() ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); -} - -static unique_ptr ListSelectBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - // If the first argument is an array, cast it to a list - arguments[0] = BoundCastExpression::AddArrayCastToList(context, std::move(arguments[0])); - - LogicalType child_type; - if (arguments[0]->return_type == LogicalTypeId::UNKNOWN || arguments[1]->return_type == LogicalTypeId::UNKNOWN) { - bound_function.arguments[0] = LogicalTypeId::UNKNOWN; - bound_function.return_type = LogicalType::SQLNULL; - return make_uniq(bound_function.return_type); - } - - D_ASSERT(LogicalTypeId::LIST == arguments[0]->return_type.id() || - LogicalTypeId::SQLNULL == arguments[0]->return_type.id()); - - bound_function.return_type = arguments[0]->return_type; - return make_uniq(bound_function.return_type); -} -ScalarFunction ListWhereFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::LIST(LogicalTypeId::ANY), LogicalType::LIST(LogicalType::BOOLEAN)}, - LogicalType::LIST(LogicalTypeId::ANY), ListSelectFunction, - ListSelectBind); - return fun; -} - -ScalarFunction ListSelectFun::GetFunction() { - auto fun = ScalarFunction({LogicalType::LIST(LogicalTypeId::ANY), LogicalType::LIST(LogicalType::BIGINT)}, - LogicalType::LIST(LogicalTypeId::ANY), ListSelectFunction, - ListSelectBind); - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/list/list_zip.cpp b/src/duckdb/src/function/scalar/list/list_zip.cpp deleted file mode 100644 index ef39a989d..000000000 --- a/src/duckdb/src/function/scalar/list/list_zip.cpp +++ /dev/null @@ -1,170 +0,0 @@ -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/common/to_string.hpp" - -namespace duckdb { - -static void ListZipFunction(DataChunk &args, ExpressionState &state, Vector &result) { - idx_t count = args.size(); - idx_t args_size = args.ColumnCount(); - auto *result_data = FlatVector::GetData(result); - auto &result_struct = ListVector::GetEntry(result); - auto &struct_entries = StructVector::GetEntries(result_struct); - bool truncate_flags_set = false; - - // Check flag - if (args.data.back().GetType().id() == LogicalTypeId::BOOLEAN) { - truncate_flags_set = true; - args_size--; - } - - vector input_lists; - input_lists.resize(args.ColumnCount()); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - args.data[i].ToUnifiedFormat(count, input_lists[i]); - } - - // Handling output row for each input row - idx_t result_size = 0; - vector lengths; - for (idx_t j = 0; j < count; j++) { - // Is flag for current row set - bool truncate_to_shortest = false; - if (truncate_flags_set) { 
- auto &flag_vec = input_lists.back(); - idx_t flag_idx = flag_vec.sel->get_index(j); - if (flag_vec.validity.RowIsValid(flag_idx)) { - truncate_to_shortest = UnifiedVectorFormat::GetData(flag_vec)[flag_idx]; - } - } - - // Calculation of the outgoing list size - idx_t len = truncate_to_shortest ? NumericLimits::Maximum() : 0; - for (idx_t i = 0; i < args_size; i++) { - idx_t curr_size; - if (args.data[i].GetType() == LogicalType::SQLNULL || ListVector::GetListSize(args.data[i]) == 0) { - curr_size = 0; - } else { - idx_t sel_idx = input_lists[i].sel->get_index(j); - auto curr_data = UnifiedVectorFormat::GetData(input_lists[i]); - curr_size = input_lists[i].validity.RowIsValid(sel_idx) ? curr_data[sel_idx].length : 0; - } - - // Dependent on flag using gt or lt - if (truncate_to_shortest) { - len = len > curr_size ? curr_size : len; - } else { - len = len < curr_size ? curr_size : len; - } - } - lengths.push_back(len); - result_size += len; - } - - ListVector::SetListSize(result, result_size); - ListVector::Reserve(result, result_size); - vector selections; - vector masks; - for (idx_t i = 0; i < args_size; i++) { - selections.push_back(SelectionVector(result_size)); - masks.push_back(ValidityMask(result_size)); - } - - idx_t offset = 0; - for (idx_t j = 0; j < count; j++) { - idx_t len = lengths[j]; - for (idx_t i = 0; i < args_size; i++) { - auto &curr = input_lists[i]; - idx_t sel_idx = curr.sel->get_index(j); - idx_t curr_off = 0; - idx_t curr_len = 0; - - // Copying values from the given lists - if (curr.validity.RowIsValid(sel_idx)) { - auto input_lists_data = UnifiedVectorFormat::GetData(curr); - curr_off = input_lists_data[sel_idx].offset; - curr_len = input_lists_data[sel_idx].length; - auto copy_len = len < curr_len ? len : curr_len; - idx_t entry = offset; - for (idx_t k = 0; k < copy_len; k++) { - if (!FlatVector::Validity(ListVector::GetEntry(args.data[i])).RowIsValid(curr_off + k)) { - masks[i].SetInvalid(entry + k); - } - selections[i].set_index(entry + k, curr_off + k); - } - } - - // Set NULL values for list that are shorter than the output list - if (len > curr_len) { - for (idx_t d = curr_len; d < len; d++) { - masks[i].SetInvalid(d + offset); - selections[i].set_index(d + offset, 0); - } - } - } - result_data[j].length = len; - result_data[j].offset = offset; - offset += len; - } - for (idx_t child_idx = 0; child_idx < args_size; child_idx++) { - if (args.data[child_idx].GetType() != LogicalType::SQLNULL) { - struct_entries[child_idx]->Slice(ListVector::GetEntry(args.data[child_idx]), selections[child_idx], - result_size); - } - struct_entries[child_idx]->Flatten(result_size); - FlatVector::SetValidity((*struct_entries[child_idx]), masks[child_idx]); - } - result.SetVectorType(args.AllConstant() ? 
VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); -} - -static unique_ptr ListZipBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - child_list_t struct_children; - - // The last argument could be a flag to be set if we want a minimal list or a maximal list - idx_t size = arguments.size(); - if (size == 0) { - throw BinderException("Provide at least one argument to " + bound_function.name); - } - if (arguments[size - 1]->return_type.id() == LogicalTypeId::BOOLEAN) { - if (--size == 0) { - throw BinderException("Provide at least one list argument to " + bound_function.name); - } - } - - case_insensitive_set_t struct_names; - for (idx_t i = 0; i < size; i++) { - auto &child = arguments[i]; - switch (child->return_type.id()) { - case LogicalTypeId::LIST: - case LogicalTypeId::ARRAY: - child = BoundCastExpression::AddArrayCastToList(context, std::move(child)); - struct_children.push_back(make_pair(string(), ListType::GetChildType(child->return_type))); - break; - case LogicalTypeId::SQLNULL: - struct_children.push_back(make_pair(string(), LogicalTypeId::SQLNULL)); - break; - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - default: - throw BinderException("Parameter type needs to be List"); - } - } - bound_function.return_type = LogicalType::LIST(LogicalType::STRUCT(struct_children)); - return make_uniq(bound_function.return_type); -} - -ScalarFunction ListZipFun::GetFunction() { - - auto fun = ScalarFunction({}, LogicalType::LIST(LogicalTypeId::STRUCT), ListZipFunction, ListZipBind); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/map/map_contains.cpp b/src/duckdb/src/function/scalar/map/map_contains.cpp deleted file mode 100644 index 068e67bc7..000000000 --- a/src/duckdb/src/function/scalar/map/map_contains.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "duckdb/function/scalar/list/contains_or_position.hpp" -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/function/scalar/map_functions.hpp" - -namespace duckdb { - -static void MapContainsFunction(DataChunk &input, ExpressionState &state, Vector &result) { - const auto count = input.size(); - - auto &map_vec = input.data[0]; - auto &key_vec = MapVector::GetKeys(map_vec); - auto &arg_vec = input.data[1]; - - ListSearchOp(map_vec, key_vec, arg_vec, result, count); - - if (count == 1) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr MapContainsBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - - const auto &map = arguments[0]->return_type; - const auto &key = arguments[1]->return_type; - - if (map.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - if (key.id() == LogicalTypeId::UNKNOWN) { - // Infer the argument type from the map type - bound_function.arguments[0] = map; - bound_function.arguments[1] = MapType::KeyType(map); - } else { - LogicalType max_child_type; - if (!LogicalType::TryGetMaxLogicalType(context, MapType::KeyType(map), key, max_child_type)) { - throw BinderException( - "%s: Cannot match element of type '%s' in a map of type '%s' - an explicit cast is required", - bound_function.name, key.ToString(), map.ToString()); - } - - bound_function.arguments[0] = LogicalType::MAP(max_child_type, MapType::ValueType(map)); - bound_function.arguments[1] = 
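// Hedged example of the unification above: probing a MAP with INTEGER keys
// using a SMALLINT value unifies the key types to INTEGER, so both the map
// argument and the probe argument are rebound with INTEGER keys. If no common
// key type exists, the BinderException above fires and an explicit cast is
// required.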
max_child_type; - } - return nullptr; -} - -ScalarFunction MapContainsFun::GetFunction() { - ScalarFunction fun("map_contains", {LogicalType::MAP(LogicalType::ANY, LogicalType::ANY), LogicalType::ANY}, - LogicalType::BOOLEAN, MapContainsFunction, MapContainsBind); - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/nested_functions.cpp b/src/duckdb/src/function/scalar/nested_functions.cpp deleted file mode 100644 index 2d5359c4e..000000000 --- a/src/duckdb/src/function/scalar/nested_functions.cpp +++ /dev/null @@ -1,40 +0,0 @@ -#include "duckdb/function/scalar/nested_functions.hpp" - -namespace duckdb { - -void MapUtil::ReinterpretMap(Vector &result, Vector &input, idx_t count) { - UnifiedVectorFormat input_data; - input.ToUnifiedFormat(count, input_data); - // Copy the list validity - FlatVector::SetValidity(result, input_data.validity); - - // Copy the struct validity - UnifiedVectorFormat input_struct_data; - ListVector::GetEntry(input).ToUnifiedFormat(count, input_struct_data); - auto &result_struct = ListVector::GetEntry(result); - FlatVector::SetValidity(result_struct, input_struct_data.validity); - - // Copy the list size - auto list_size = ListVector::GetListSize(input); - ListVector::SetListSize(result, list_size); - - // Copy the list buffer (the list_entry_t data) - result.CopyBuffer(input); - - auto &input_keys = MapVector::GetKeys(input); - auto &result_keys = MapVector::GetKeys(result); - result_keys.Reference(input_keys); - - auto &input_values = MapVector::GetValues(input); - auto &result_values = MapVector::GetValues(result); - result_values.Reference(input_values); - - if (input.GetVectorType() == VectorType::DICTIONARY_VECTOR) { - result.Slice(*input_data.sel, count); - } - - // Set the right vector type - result.SetVectorType(input.GetVectorType()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operator/add.cpp b/src/duckdb/src/function/scalar/operator/add.cpp deleted file mode 100644 index fa2f1a9e7..000000000 --- a/src/duckdb/src/function/scalar/operator/add.cpp +++ /dev/null @@ -1,295 +0,0 @@ -#include "duckdb/common/operator/add.hpp" - -#include "duckdb/common/limits.hpp" -#include "duckdb/common/types/value.hpp" - -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// + [add] -//===--------------------------------------------------------------------===// -template <> -float AddOperator::Operation(float left, float right) { - auto result = left + right; - return result; -} - -template <> -double AddOperator::Operation(double left, double right) { - auto result = left + right; - return result; -} - -template <> -interval_t AddOperator::Operation(interval_t left, interval_t right) { - left.months = AddOperatorOverflowCheck::Operation(left.months, right.months); - left.days = AddOperatorOverflowCheck::Operation(left.days, right.days); - left.micros = AddOperatorOverflowCheck::Operation(left.micros, right.micros); - return left; -} - -template <> -date_t AddOperator::Operation(date_t left, int32_t right) { - date_t result; - if (!TryAddOperator::Operation(left, right, result)) { - throw OutOfRangeException("Date out of range"); - } - return result; -} - -template <> -date_t AddOperator::Operation(int32_t left, date_t right) { - return 
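// Editorial note (hedged) on the date/timestamp additions that follow: DuckDB
// represents 'infinity' and '-infinity' dates and timestamps as sentinel
// values, so adding a time-of-day or an interval to an infinite date
// short-circuits to the matching infinite timestamp before any range-checked
// arithmetic runs.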
AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(date_t left, dtime_t right) { - if (left == date_t::infinity()) { - return timestamp_t::infinity(); - } else if (left == date_t::ninfinity()) { - return timestamp_t::ninfinity(); - } - timestamp_t result; - if (!Timestamp::TryFromDatetime(left, right, result)) { - throw OutOfRangeException("Timestamp out of range"); - } - return result; -} - -template <> -timestamp_t AddOperator::Operation(date_t left, dtime_tz_t right) { - if (left == date_t::infinity()) { - return timestamp_t::infinity(); - } else if (left == date_t::ninfinity()) { - return timestamp_t::ninfinity(); - } - timestamp_t result; - if (!Timestamp::TryFromDatetime(left, right, result)) { - throw OutOfRangeException("Timestamp with time zone out of range"); - } - return result; -} - -template <> -timestamp_t AddOperator::Operation(dtime_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(dtime_tz_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(date_t left, interval_t right) { - if (left == date_t::infinity()) { - return timestamp_t::infinity(); - } - if (left == date_t::ninfinity()) { - return timestamp_t::ninfinity(); - } - return Interval::Add(Timestamp::FromDatetime(left, dtime_t(0)), right); -} - -template <> -timestamp_t AddOperator::Operation(interval_t left, date_t right) { - return AddOperator::Operation(right, left); -} - -template <> -timestamp_t AddOperator::Operation(timestamp_t left, interval_t right) { - return Interval::Add(left, right); -} - -template <> -timestamp_t AddOperator::Operation(interval_t left, timestamp_t right) { - return AddOperator::Operation(right, left); -} - -//===--------------------------------------------------------------------===// -// + [add] with overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedAddition { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = AddOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TryAddOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} -template <> -bool TryAddOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} -template <> -bool TryAddOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (NumericLimits::Maximum() - left < right) { - return false; - } - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(date_t left, int32_t right, date_t &result) { - if (left == date_t::infinity() || left == date_t::ninfinity()) { - result = date_t(left); - return true; - } - int32_t days; - if (!TryAddOperator::Operation(left.days, right, days)) { - return false; - } - result.days = days; - if (!Value::IsFinite(result)) { - return false; - } - return true; -} - -template <> -bool TryAddOperator::Operation(int8_t left, int8_t 
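// Minimal sketch of the widening trick in OverflowCheckedAddition above
// (illustrative, assuming two's-complement int8_t/int16_t):
//   int8_t  left = 100, right = 100;
//   int16_t wide = int16_t(left) + int16_t(right); // 200, cannot wrap in int16
//   // 200 > 127 (the int8_t maximum) -> overflow is reported
// The sum is computed in the wider UTYPE, range-checked against the narrow
// type's limits, and only then narrowed back to SRCTYPE.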
right, int8_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedAddition::Operation(left, right, result); -} - -template <> -bool TryAddOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_add_overflow(left, right, &result)) { - return false; - } -#else - // https://blog.regehr.org/archives/1139 - result = int64_t((uint64_t)left + (uint64_t)right); - if ((left < 0 && right < 0 && result >= 0) || (left >= 0 && right >= 0 && result < 0)) { - return false; - } -#endif - return true; -} - -template <> -bool TryAddOperator::Operation(uhugeint_t left, uhugeint_t right, uhugeint_t &result) { - if (!Uhugeint::TryAddInPlace(left, right)) { - return false; - } - result = left; - return true; -} - -template <> -bool TryAddOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - if (!Hugeint::TryAddInPlace(left, right)) { - return false; - } - result = left; - return true; -} - -//===--------------------------------------------------------------------===// -// add decimal with overflow check -//===--------------------------------------------------------------------===// -template -bool TryDecimalAddTemplated(T left, T right, T &result) { - if (right < 0) { - if (min - right > left) { - return false; - } - } else { - if (max - right < left) { - return false; - } - } - result = left + right; - return true; -} - -template <> -bool TryDecimalAdd::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalAddTemplated(left, right, result); -} - -template <> -bool TryDecimalAdd::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - if (!TryAddOperator::Operation(left, right, result)) { - return false; - } - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalAddOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalAdd::Operation(left, right, result)) { - throw OutOfRangeException("Overflow in addition of DECIMAL(38) (%s + %s);", left.ToString(), right.ToString()); - } - return result; -} - -//===--------------------------------------------------------------------===// -// add time operator -//===--------------------------------------------------------------------===// -template <> -dtime_t AddTimeOperator::Operation(dtime_t left, interval_t right) { - date_t date(0); - return Interval::Add(left, right, date); -} - -template <> -dtime_t AddTimeOperator::Operation(interval_t left, dtime_t right) { - return AddTimeOperator::Operation(right, left); -} - -template <> -dtime_tz_t AddTimeOperator::Operation(dtime_tz_t left, interval_t right) { - date_t date(0); - return Interval::Add(left, right, date); -} - -template <> -dtime_tz_t AddTimeOperator::Operation(interval_t left, dtime_tz_t right) { - return AddTimeOperator::Operation(right, left); 
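// Note on the int64_t addition check earlier in this file (hedged summary):
// signed overflow is undefined behaviour in C++, so the non-builtin fallback
// adds in uint64_t space, where wraparound is well defined, then inspects the
// signs. Example: left = INT64_MAX, right = 1 wraps to INT64_MIN, giving a
// negative result from two non-negative operands, which is impossible without
// overflow, so the operation is rejected. Compilers providing
// __builtin_add_overflow take the cheaper intrinsic path instead.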
-} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operator/arithmetic.cpp b/src/duckdb/src/function/scalar/operator/arithmetic.cpp deleted file mode 100644 index a3c11ed63..000000000 --- a/src/duckdb/src/function/scalar/operator/arithmetic.cpp +++ /dev/null @@ -1,1119 +0,0 @@ -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/operator/numeric_binary_operators.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/decimal.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/operators.hpp" -#include "duckdb/function/scalar/operator_functions.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -#include - -namespace duckdb { - -template -static scalar_function_t GetScalarIntegerFunction(PhysicalType type) { - scalar_function_t function; - switch (type) { - case PhysicalType::INT8: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT16: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT32: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT64: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::INT128: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT8: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT16: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT32: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT64: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::UINT128: - function = &ScalarFunction::BinaryFunction; - break; - default: - throw NotImplementedException("Unimplemented type for GetScalarBinaryFunction: %s", TypeIdToString(type)); - } - return function; -} - -template -static scalar_function_t GetScalarBinaryFunction(PhysicalType type) { - scalar_function_t function; - switch (type) { - case PhysicalType::FLOAT: - function = &ScalarFunction::BinaryFunction; - break; - case PhysicalType::DOUBLE: - function = &ScalarFunction::BinaryFunction; - break; - default: - function = GetScalarIntegerFunction(type); - break; - } - return function; -} - -//===--------------------------------------------------------------------===// -// + [add] -//===--------------------------------------------------------------------===// -struct AddPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, - Value &new_max) { - T min, max; - // new min is min+min - if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMin(rstats), min)) { - return true; - } - // new max is max+max - if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMax(rstats), max)) { - return true; - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -struct SubtractPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics 
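// Worked example for the statistics propagators here (editorial, illustrative):
// addition uses [lmin + rmin, lmax + rmax], subtraction uses
// [lmin - rmax, lmax - rmin]. With l in [0, 10] and r in [3, 5]:
//   l + r in [3, 15]      l - r in [-5, 7]
// If any corner computation overflows, the propagator returns true and the
// bounds are left unknown (NULL min/max).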
&rstats, Value &new_min, - Value &new_max) { - T min, max; - if (!OP::Operation(NumericStats::GetMin(lstats), NumericStats::GetMax(rstats), min)) { - return true; - } - if (!OP::Operation(NumericStats::GetMax(lstats), NumericStats::GetMin(rstats), max)) { - return true; - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -struct DecimalArithmeticBindData : public FunctionData { - DecimalArithmeticBindData() : check_overflow(false) { - } - - unique_ptr Copy() const override { - auto res = make_uniq(); - res->check_overflow = check_overflow; - return std::move(res); - } - - bool Equals(const FunctionData &other_p) const override { - auto other = other_p.Cast(); - return other.check_overflow == check_overflow; - } - - bool check_overflow; -}; - -template -static unique_ptr PropagateNumericStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 2); - // can only propagate stats if the children have stats - auto &lstats = child_stats[0]; - auto &rstats = child_stats[1]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(lstats) && NumericStats::HasMinMax(rstats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT16: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT32: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - case PhysicalType::INT64: - potential_overflow = - PROPAGATE::template Operation(expr.return_type, lstats, rstats, new_min, new_max); - break; - default: - return nullptr; - } - } - if (potential_overflow) { - new_min = Value(expr.return_type); - new_max = Value(expr.return_type); - } else { - // no potential overflow: replace with non-overflowing operator - if (input.bind_data) { - auto &bind_data = input.bind_data->Cast(); - bind_data.check_overflow = false; - } - expr.function.function = GetScalarIntegerFunction(expr.return_type.InternalType()); - } - auto result = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(result, new_min); - NumericStats::SetMax(result, new_max); - result.CombineValidity(lstats, rstats); - return result.ToUnique(); -} - -template -unique_ptr BindDecimalArithmetic(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto bind_data = make_uniq(); - - // get the max width and scale of the input arguments - uint8_t max_width = 0, max_scale = 0, max_width_over_scale = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { - continue; - } - uint8_t width, scale; - auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); - if (!can_convert) { - throw InternalException("Could not convert type %s to a decimal.", arguments[i]->return_type.ToString()); - } - max_width = MaxValue(width, max_width); - max_scale = MaxValue(scale, max_scale); - max_width_over_scale = MaxValue(width - scale, max_width_over_scale); - } - D_ASSERT(max_width > 0); - uint8_t required_width = MaxValue(max_scale + max_width_over_scale, max_width); - if (!IS_MODULO) { - // for addition/subtraction, we add 1 to the width to ensure we don't 
overflow - required_width = NumericCast(required_width + 1); - if (required_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64) { - // we don't automatically promote past the hugeint boundary to avoid the large hugeint performance penalty - bind_data->check_overflow = true; - required_width = Decimal::MAX_WIDTH_INT64; - } - } - if (required_width > Decimal::MAX_WIDTH_DECIMAL) { - // target width does not fit in decimal at all: truncate the scale and perform overflow detection - bind_data->check_overflow = true; - required_width = Decimal::MAX_WIDTH_DECIMAL; - } - // arithmetic between two decimal arguments: check the types of the input arguments - LogicalType result_type = LogicalType::DECIMAL(required_width, max_scale); - // we cast all input types to the specified type - for (idx_t i = 0; i < arguments.size(); i++) { - // first check if the cast is necessary - // if the argument has a matching scale and internal type as the output type, no casting is necessary - auto &argument_type = arguments[i]->return_type; - uint8_t width, scale; - argument_type.GetDecimalProperties(width, scale); - if (scale == DecimalType::GetScale(result_type) && argument_type.InternalType() == result_type.InternalType()) { - bound_function.arguments[i] = argument_type; - } else { - bound_function.arguments[i] = result_type; - } - } - bound_function.return_type = result_type; - return bind_data; -} - -template -unique_ptr BindDecimalAddSubtract(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto bind_data = BindDecimalArithmetic(context, bound_function, arguments); - - // now select the physical function to execute - auto &result_type = bound_function.return_type; - if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } - if (result_type.InternalType() != PhysicalType::INT128 && result_type.InternalType() != PhysicalType::UINT128) { - if (IS_SUBTRACT) { - bound_function.statistics = - PropagateNumericStats; - } else { - bound_function.statistics = PropagateNumericStats; - } - } - return std::move(bind_data); -} - -static void SerializeDecimalArithmetic(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "check_overflow", bind_data.check_overflow); - serializer.WriteProperty(101, "return_type", function.return_type); - serializer.WriteProperty(102, "arguments", function.arguments); -} - -// TODO this is partially duplicated from the bind -template -unique_ptr DeserializeDecimalArithmetic(Deserializer &deserializer, ScalarFunction &bound_function) { - - // // re-change the function pointers - auto check_overflow = deserializer.ReadProperty(100, "check_overflow"); - auto return_type = deserializer.ReadProperty(101, "return_type"); - auto arguments = deserializer.ReadProperty>(102, "arguments"); - if (check_overflow) { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(return_type.InternalType()); - } - bound_function.statistics = nullptr; // TODO we likely dont want to do stats prop again - bound_function.return_type = return_type; - bound_function.arguments = arguments; - - auto bind_data = make_uniq(); - bind_data->check_overflow = check_overflow; - return std::move(bind_data); -} - -unique_ptr 
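// Hedged worked example for BindDecimalArithmetic above: the result scale is
// max(s_i), the result width is max(max(w_i - s_i) + max(s_i), max(w_i)), plus
// one extra digit for addition/subtraction. So
//   DECIMAL(4,1) + DECIMAL(3,2) -> scale 2, integer digits max(3, 1) = 3
//                               -> width 3 + 2 + 1 = 6 -> DECIMAL(6,2)
// Widths that would cross the int64 boundary are clamped to DECIMAL(18, s)
// with overflow checking enabled, rather than promoting to hugeint.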
NopDecimalBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - bound_function.return_type = arguments[0]->return_type; - bound_function.arguments[0] = arguments[0]->return_type; - return nullptr; -} - -ScalarFunction AddFunction::GetFunction(const LogicalType &type) { - D_ASSERT(type.IsNumeric()); - if (type.id() == LogicalTypeId::DECIMAL) { - return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction, NopDecimalBind); - } else { - return ScalarFunction("+", {type}, type, ScalarFunction::NopFunction); - } -} - -ScalarFunction AddFunction::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { - if (left_type.IsNumeric() && left_type.id() == right_type.id()) { - if (left_type.id() == LogicalTypeId::DECIMAL) { - auto function = ScalarFunction("+", {left_type, right_type}, left_type, nullptr, - BindDecimalAddSubtract); - BaseScalarFunction::SetReturnsError(function); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; - return function; - } else if (left_type.IsIntegral()) { - ScalarFunction function("+", {left_type, right_type}, left_type, - GetScalarIntegerFunction(left_type.InternalType()), - nullptr, nullptr, - PropagateNumericStats); - BaseScalarFunction::SetReturnsError(function); - return function; - } else { - ScalarFunction function("+", {left_type, right_type}, left_type, - GetScalarBinaryFunction(left_type.InternalType())); - BaseScalarFunction::SetReturnsError(function); - return function; - } - } - - switch (left_type.id()) { - case LogicalTypeId::DATE: - if (right_type.id() == LogicalTypeId::INTEGER) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::TIME) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::TIME_TZ) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::INTEGER: - if (right_type.id() == LogicalTypeId::DATE) { - ScalarFunction function("+", {left_type, right_type}, right_type, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::INTERVAL: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::DATE) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - BaseScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::TIME) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - 
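// Illustrative typing summary for this dispatch (editorial, not original code):
//   DATE + INTEGER   -> DATE          (day arithmetic)
//   DATE + INTERVAL  -> TIMESTAMP
//   DATE + TIME      -> TIMESTAMP
//   DATE + TIME_TZ   -> TIMESTAMP_TZ
//   INTERVAL + TIME  -> TIME          (the day carry is absorbed, so the
//                                      clock value effectively wraps)
// Each symmetric pairing simply delegates to the (right, left) overload.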
return function; - } else if (right_type.id() == LogicalTypeId::TIME_TZ) { - ScalarFunction function( - "+", {left_type, right_type}, LogicalType::TIME_TZ, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::TIMESTAMP) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIME: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::DATE) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIME_TZ: - if (right_type.id() == LogicalTypeId::DATE) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP_TZ, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function( - "+", {left_type, right_type}, LogicalType::TIME_TZ, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIMESTAMP: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("+", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - default: - break; - } - // LCOV_EXCL_START - throw NotImplementedException("AddFunction for types %s, %s", EnumUtil::ToString(left_type.id()), - EnumUtil::ToString(right_type.id())); - // LCOV_EXCL_STOP -} - -ScalarFunctionSet OperatorAddFun::GetFunctions() { - ScalarFunctionSet add("+"); - for (auto &type : LogicalType::Numeric()) { - // unary add function is a nop, but only exists for numeric types - add.AddFunction(AddFunction::GetFunction(type)); - // binary add function adds two numbers together - add.AddFunction(AddFunction::GetFunction(type, type)); - } - // we can add integers to dates - add.AddFunction(AddFunction::GetFunction(LogicalType::DATE, LogicalType::INTEGER)); - add.AddFunction(AddFunction::GetFunction(LogicalType::INTEGER, LogicalType::DATE)); - // we can add intervals together - add.AddFunction(AddFunction::GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); - // we can add intervals to dates/times/timestamps - add.AddFunction(AddFunction::GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); - add.AddFunction(AddFunction::GetFunction(LogicalType::INTERVAL, LogicalType::DATE)); - - add.AddFunction(AddFunction::GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); - add.AddFunction(AddFunction::GetFunction(LogicalType::INTERVAL, LogicalType::TIME)); - - add.AddFunction(AddFunction::GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); - add.AddFunction(AddFunction::GetFunction(LogicalType::INTERVAL, LogicalType::TIMESTAMP)); - - add.AddFunction(AddFunction::GetFunction(LogicalType::TIME_TZ, LogicalType::INTERVAL)); - add.AddFunction(AddFunction::GetFunction(LogicalType::INTERVAL, LogicalType::TIME_TZ)); - - // we can add times to dates - 
add.AddFunction(AddFunction::GetFunction(LogicalType::TIME, LogicalType::DATE)); - add.AddFunction(AddFunction::GetFunction(LogicalType::DATE, LogicalType::TIME)); - - // we can add times with time zones (offsets) to dates - add.AddFunction(AddFunction::GetFunction(LogicalType::TIME_TZ, LogicalType::DATE)); - add.AddFunction(AddFunction::GetFunction(LogicalType::DATE, LogicalType::TIME_TZ)); - - // we can add lists together - add.AddFunction(ListConcatFun::GetFunction()); - - return add; -} - -//===--------------------------------------------------------------------===// -// - [subtract] -//===--------------------------------------------------------------------===// -struct NegateOperator { - template - static bool CanNegate(T input) { - using Limits = NumericLimits; - return !(Limits::IsSigned() && Limits::Minimum() == input); - } - - template - static inline TR Operation(TA input) { - auto cast = (TR)input; - if (!CanNegate(cast)) { - throw OutOfRangeException("Overflow in negation of integer!"); - } - return -cast; - } -}; - -template <> -bool NegateOperator::CanNegate(float input) { - return true; -} - -template <> -bool NegateOperator::CanNegate(double input) { - return true; -} - -template <> -interval_t NegateOperator::Operation(interval_t input) { - interval_t result; - result.months = NegateOperator::Operation(input.months); - result.days = NegateOperator::Operation(input.days); - result.micros = NegateOperator::Operation(input.micros); - return result; -} - -struct DecimalNegateBindData : public FunctionData { - DecimalNegateBindData() : bound_type(LogicalTypeId::INVALID) { - } - - unique_ptr Copy() const override { - auto res = make_uniq(); - res->bound_type = bound_type; - return std::move(res); - } - - bool Equals(const FunctionData &other_p) const override { - auto other = other_p.Cast(); - return other.bound_type == bound_type; - } - - LogicalTypeId bound_type; -}; - -unique_ptr DecimalNegateBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto bind_data = make_uniq(); - - auto &decimal_type = arguments[0]->return_type; - auto width = DecimalType::GetWidth(decimal_type); - if (width <= Decimal::MAX_WIDTH_INT16) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::SMALLINT); - } else if (width <= Decimal::MAX_WIDTH_INT32) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::INTEGER); - } else if (width <= Decimal::MAX_WIDTH_INT64) { - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::BIGINT); - } else { - D_ASSERT(width <= Decimal::MAX_WIDTH_INT128); - bound_function.function = ScalarFunction::GetScalarUnaryFunction(LogicalTypeId::HUGEINT); - } - decimal_type.Verify(); - bound_function.arguments[0] = decimal_type; - bound_function.return_type = decimal_type; - return nullptr; -} - -struct NegatePropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &istats, Value &new_min, Value &new_max) { - auto max_value = NumericStats::GetMax(istats); - auto min_value = NumericStats::GetMin(istats); - if (!NegateOperator::CanNegate(min_value) || !NegateOperator::CanNegate(max_value)) { - return true; - } - // new min is -max - new_min = Value::Numeric(type, NegateOperator::Operation(max_value)); - // new max is -min - new_max = Value::Numeric(type, NegateOperator::Operation(min_value)); - return false; - } -}; - -static unique_ptr NegateBindStatistics(ClientContext &context, FunctionStatisticsInput &input) { - auto 
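// Background for CanNegate above (hedged note): in two's complement the most
// negative value has no positive counterpart, so negation must reject it:
//   int8_t x = -128;   // NumericLimits<int8_t>::Minimum()
//   -x would be +128, which does not fit in int8_t -> OutOfRangeException
// float and double can always be negated, hence their specializations
// return true unconditionally.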
&child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - auto &istats = child_stats[0]; - Value new_min, new_max; - bool potential_overflow = true; - if (NumericStats::HasMinMax(istats)) { - switch (expr.return_type.InternalType()) { - case PhysicalType::INT8: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT16: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT32: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - case PhysicalType::INT64: - potential_overflow = - NegatePropagateStatistics::Operation(expr.return_type, istats, new_min, new_max); - break; - default: - return nullptr; - } - } - if (potential_overflow) { - new_min = Value(expr.return_type); - new_max = Value(expr.return_type); - } - auto stats = NumericStats::CreateEmpty(expr.return_type); - NumericStats::SetMin(stats, new_min); - NumericStats::SetMax(stats, new_max); - stats.CopyValidity(istats); - return stats.ToUnique(); -} - -ScalarFunction SubtractFunction::GetFunction(const LogicalType &type) { - if (type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction func("-", {type}, type, ScalarFunction::UnaryFunction); - ScalarFunction::SetReturnsError(func); - return func; - } else if (type.id() == LogicalTypeId::DECIMAL) { - ScalarFunction func("-", {type}, type, nullptr, DecimalNegateBind, nullptr, NegateBindStatistics); - return func; - } else { - D_ASSERT(type.IsNumeric()); - ScalarFunction func("-", {type}, type, ScalarFunction::GetScalarUnaryFunction(type), nullptr, - nullptr, NegateBindStatistics); - ScalarFunction::SetReturnsError(func); - return func; - } -} - -ScalarFunction SubtractFunction::GetFunction(const LogicalType &left_type, const LogicalType &right_type) { - if (left_type.IsNumeric() && left_type.id() == right_type.id()) { - if (left_type.id() == LogicalTypeId::DECIMAL) { - ScalarFunction function("-", {left_type, right_type}, left_type, nullptr, - BindDecimalAddSubtract); - ScalarFunction::SetReturnsError(function); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; - return function; - } else if (left_type.IsIntegral()) { - ScalarFunction function( - "-", {left_type, right_type}, left_type, - GetScalarIntegerFunction(left_type.InternalType()), nullptr, nullptr, - PropagateNumericStats); - ScalarFunction::SetReturnsError(function); - return function; - - } else { - ScalarFunction function("-", {left_type, right_type}, left_type, - GetScalarBinaryFunction(left_type.InternalType())); - ScalarFunction::SetReturnsError(function); - return function; - } - } - - switch (left_type.id()) { - case LogicalTypeId::DATE: - if (right_type.id() == LogicalTypeId::DATE) { - ScalarFunction function("-", {left_type, right_type}, LogicalType::BIGINT, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - - } else if (right_type.id() == LogicalTypeId::INTEGER) { - ScalarFunction function("-", {left_type, right_type}, LogicalType::DATE, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("-", {left_type, right_type}, LogicalType::TIMESTAMP, - 
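// Illustrative typing summary for binary subtraction (editorial annotation):
//   DATE - DATE                -> BIGINT   (difference in days)
//   DATE - INTEGER             -> DATE
//   DATE/TIMESTAMP - INTERVAL  -> TIMESTAMP
//   TIMESTAMP - TIMESTAMP      -> INTERVAL
// In the operator implementation, subtracting an interval is performed by
// adding its inversion, so overflow behaviour matches the addition path.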
ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIMESTAMP: - if (right_type.id() == LogicalTypeId::TIMESTAMP) { - ScalarFunction function( - "-", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } else if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function( - "-", {left_type, right_type}, LogicalType::TIMESTAMP, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::INTERVAL: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function( - "-", {left_type, right_type}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIME: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function("-", {left_type, right_type}, LogicalType::TIME, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - case LogicalTypeId::TIME_TZ: - if (right_type.id() == LogicalTypeId::INTERVAL) { - ScalarFunction function( - "-", {left_type, right_type}, LogicalType::TIME_TZ, - ScalarFunction::BinaryFunction); - ScalarFunction::SetReturnsError(function); - return function; - } - break; - default: - break; - } - // LCOV_EXCL_START - throw NotImplementedException("SubtractFun for types %s, %s", EnumUtil::ToString(left_type.id()), - EnumUtil::ToString(right_type.id())); - // LCOV_EXCL_STOP -} - -ScalarFunctionSet OperatorSubtractFun::GetFunctions() { - ScalarFunctionSet subtract("-"); - for (auto &type : LogicalType::Numeric()) { - // unary subtract function, negates the input (i.e. 
multiplies by -1) - subtract.AddFunction(SubtractFunction::GetFunction(type)); - // binary subtract function "a - b", subtracts b from a - subtract.AddFunction(SubtractFunction::GetFunction(type, type)); - } - // we can subtract dates from each other - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::DATE, LogicalType::DATE)); - // we can subtract integers from dates - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::DATE, LogicalType::INTEGER)); - // we can subtract timestamps from each other - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::TIMESTAMP, LogicalType::TIMESTAMP)); - // we can subtract intervals from each other - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::INTERVAL, LogicalType::INTERVAL)); - // we can subtract intervals from dates/times/timestamps, but not the other way around - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::DATE, LogicalType::INTERVAL)); - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::TIME, LogicalType::INTERVAL)); - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::TIMESTAMP, LogicalType::INTERVAL)); - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::TIME_TZ, LogicalType::INTERVAL)); - // we can negate intervals - subtract.AddFunction(SubtractFunction::GetFunction(LogicalType::INTERVAL)); - - return subtract; -} - -//===--------------------------------------------------------------------===// -// * [multiply] -//===--------------------------------------------------------------------===// -struct MultiplyPropagateStatistics { - template - static bool Operation(LogicalType type, BaseStatistics &lstats, BaseStatistics &rstats, Value &new_min, - Value &new_max) { - // statistics propagation on the multiplication is slightly less straightforward because of negative numbers - // the new min/max depend on the signs of the input types - // if both are positive the result is [lmin * rmin][lmax * rmax] - // if lmin/lmax are negative the result is [lmin * rmax][lmax * rmin] - // etc - // rather than doing all this switcheroo we just multiply all combinations of lmin/lmax with rmin/rmax - // and check what the minimum/maximum value is - T lvals[] {NumericStats::GetMin(lstats), NumericStats::GetMax(lstats)}; - T rvals[] {NumericStats::GetMin(rstats), NumericStats::GetMax(rstats)}; - T min = NumericLimits::Maximum(); - T max = NumericLimits::Minimum(); - // multiplications - for (idx_t l = 0; l < 2; l++) { - for (idx_t r = 0; r < 2; r++) { - T result; - if (!OP::Operation(lvals[l], rvals[r], result)) { - // potential overflow - return true; - } - if (result < min) { - min = result; - } - if (result > max) { - max = result; - } - } - } - new_min = Value::Numeric(type, min); - new_max = Value::Numeric(type, max); - return false; - } -}; - -unique_ptr BindDecimalMultiply(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - auto bind_data = make_uniq(); - - uint8_t result_width = 0, result_scale = 0; - uint8_t max_width = 0; - for (idx_t i = 0; i < arguments.size(); i++) { - if (arguments[i]->return_type.id() == LogicalTypeId::UNKNOWN) { - continue; - } - uint8_t width, scale; - auto can_convert = arguments[i]->return_type.GetDecimalProperties(width, scale); - if (!can_convert) { - throw InternalException("Could not convert type %s to a decimal?", arguments[i]->return_type.ToString()); - } - if (width > max_width) { - max_width = width; - } - result_width += width; - result_scale += scale; - } - D_ASSERT(max_width > 
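// Worked example for MultiplyPropagateStatistics above (illustrative): with
// l in [-2, 3] and r in [-5, 4] the four corner products are
//   (-2)(-5) = 10, (-2)(4) = -8, (3)(-5) = -15, (3)(4) = 12
// so the propagated bounds are [-15, 12]. Trying all four corners avoids the
// sign case analysis described in the comment.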
0); - if (result_scale > Decimal::MAX_WIDTH_DECIMAL) { - throw OutOfRangeException( - "Needed scale %d to accurately represent the multiplication result, but this is out of range of the " - "DECIMAL type. Max scale is %d; could not perform an accurate multiplication. Either add a cast to DOUBLE, " - "or add an explicit cast to a decimal with a lower scale.", - result_scale, Decimal::MAX_WIDTH_DECIMAL); - } - if (result_width > Decimal::MAX_WIDTH_INT64 && max_width <= Decimal::MAX_WIDTH_INT64 && - result_scale < Decimal::MAX_WIDTH_INT64) { - bind_data->check_overflow = true; - result_width = Decimal::MAX_WIDTH_INT64; - } - if (result_width > Decimal::MAX_WIDTH_DECIMAL) { - bind_data->check_overflow = true; - result_width = Decimal::MAX_WIDTH_DECIMAL; - } - LogicalType result_type = LogicalType::DECIMAL(result_width, result_scale); - // since our scale is the summation of our input scales, we do not need to cast to the result scale - // however, we might need to cast to the correct internal type - for (idx_t i = 0; i < arguments.size(); i++) { - auto &argument_type = arguments[i]->return_type; - if (argument_type.InternalType() == result_type.InternalType()) { - bound_function.arguments[i] = argument_type; - } else { - uint8_t width, scale; - if (!argument_type.GetDecimalProperties(width, scale)) { - scale = 0; - } - - bound_function.arguments[i] = LogicalType::DECIMAL(result_width, scale); - } - } - result_type.Verify(); - bound_function.return_type = result_type; - // now select the physical function to execute - if (bind_data->check_overflow) { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } else { - bound_function.function = GetScalarBinaryFunction(result_type.InternalType()); - } - if (result_type.InternalType() != PhysicalType::INT128) { - bound_function.statistics = - PropagateNumericStats; - } - return std::move(bind_data); -} - -ScalarFunctionSet OperatorMultiplyFun::GetFunctions() { - ScalarFunctionSet multiply("*"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - ScalarFunction function({type, type}, type, nullptr, BindDecimalMultiply); - function.serialize = SerializeDecimalArithmetic; - function.deserialize = DeserializeDecimalArithmetic; - multiply.AddFunction(function); - } else if (TypeIsIntegral(type.InternalType())) { - multiply.AddFunction(ScalarFunction( - {type, type}, type, GetScalarIntegerFunction(type.InternalType()), - nullptr, nullptr, - PropagateNumericStats)); - } else { - multiply.AddFunction( - ScalarFunction({type, type}, type, GetScalarBinaryFunction(type.InternalType()))); - } - } - multiply.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction)); - multiply.AddFunction( - ScalarFunction({LogicalType::BIGINT, LogicalType::INTERVAL}, LogicalType::INTERVAL, - ScalarFunction::BinaryFunction)); - for (auto &func : multiply.functions) { - ScalarFunction::SetReturnsError(func); - } - - return multiply; -} - -//===--------------------------------------------------------------------===// -// / [divide] -//===--------------------------------------------------------------------===// -template <> -float DivideOperator::Operation(float left, float right) { - auto result = left / right; - return result; -} - -template <> -double DivideOperator::Operation(double left, double right) { - auto result = left / right; - return result; -} - -template <> -hugeint_t DivideOperator::Operation(hugeint_t left, hugeint_t right) 
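// Hedged worked example for the decimal multiply bind above: widths and scales
// add, because (a * 10^-s1) * (b * 10^-s2) = a * b * 10^-(s1 + s2). So
//   DECIMAL(4,1) * DECIMAL(3,2) -> DECIMAL(7,3)
// If the combined width exceeds the DECIMAL maximum, the width is clamped and
// check_overflow selects the checked multiply path instead.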
{ - if (right.lower == 0 && right.upper == 0) { - throw InternalException("Hugeint division by zero!"); - } - return left / right; -} - -template <> -interval_t DivideOperator::Operation(interval_t left, int64_t right) { - left.days = UnsafeNumericCast(left.days / right); - left.months = UnsafeNumericCast(left.months / right); - left.micros /= right; - return left; -} - -struct BinaryNumericDivideWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (left == NumericLimits::Minimum() && right == -1) { - throw OutOfRangeException("Overflow in division of %d / %d", left, right); - } else if (right == 0) { - mask.SetInvalid(idx); - return left; - } else { - return OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -struct BinaryZeroIsNullWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (right == 0) { - mask.SetInvalid(idx); - return left; - } else { - return OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -struct BinaryNumericDivideHugeintWrapper { - template - static inline RESULT_TYPE Operation(FUNC fun, LEFT_TYPE left, RIGHT_TYPE right, ValidityMask &mask, idx_t idx) { - if (left == NumericLimits::Minimum() && right == -1) { - throw OutOfRangeException("Overflow in division of %s / %s", left.ToString(), right.ToString()); - } else if (right == 0) { - mask.SetInvalid(idx); - return left; - } else { - return OP::template Operation(left, right); - } - } - - static bool AddsNulls() { - return true; - } -}; - -template -static void BinaryScalarFunctionIgnoreZero(DataChunk &input, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute(input.data[0], input.data[1], result, input.size()); -} - -template -static scalar_function_t GetBinaryFunctionIgnoreZero(PhysicalType type) { - switch (type) { - case PhysicalType::INT8: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::INT16: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::INT32: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::INT64: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::UINT8: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::UINT16: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::UINT32: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::UINT64: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::INT128: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::UINT128: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::FLOAT: - return BinaryScalarFunctionIgnoreZero; - case PhysicalType::DOUBLE: - return BinaryScalarFunctionIgnoreZero; - default: - throw NotImplementedException("Unimplemented type for GetScalarUnaryFunction"); - } -} - -template -unique_ptr BindBinaryFloatingPoint(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto &config = ClientConfig::GetConfig(context); - if (config.ieee_floating_point_ops) { - bound_function.function = GetScalarBinaryFunction(bound_function.return_type.InternalType()); - } else { - bound_function.function = GetBinaryFunctionIgnoreZero(bound_function.return_type.InternalType()); - } - return nullptr; -} - -ScalarFunctionSet OperatorFloatDivideFun::GetFunctions() { - ScalarFunctionSet fp_divide("/"); - fp_divide.AddFunction(ScalarFunction({LogicalType::FLOAT, 
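// Behavioural note (hedged, derived from the wrappers above): with
// ieee_floating_point_ops disabled, a zero divisor does not raise an error;
// BinaryZeroIsNullWrapper invalidates the row, so x / 0 yields NULL. With the
// setting enabled, BindBinaryFloatingPoint picks the plain binary function and
// float/double division follows IEEE 754 (1.0/0.0 -> inf, 0.0/0.0 -> nan).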
LogicalType::FLOAT}, LogicalType::FLOAT, nullptr, - BindBinaryFloatingPoint)); - fp_divide.AddFunction(ScalarFunction({LogicalType::DOUBLE, LogicalType::DOUBLE}, LogicalType::DOUBLE, nullptr, - BindBinaryFloatingPoint)); - fp_divide.AddFunction( - ScalarFunction({LogicalType::INTERVAL, LogicalType::BIGINT}, LogicalType::INTERVAL, - BinaryScalarFunctionIgnoreZero)); - for (auto &func : fp_divide.functions) { - ScalarFunction::SetReturnsError(func); - } - return fp_divide; -} - -ScalarFunctionSet OperatorIntegerDivideFun::GetFunctions() { - ScalarFunctionSet full_divide("//"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::DECIMAL) { - continue; - } else { - full_divide.AddFunction( - ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero(type.InternalType()))); - } - } - for (auto &func : full_divide.functions) { - ScalarFunction::SetReturnsError(func); - } - return full_divide; -} - -//===--------------------------------------------------------------------===// -// % [modulo] -//===--------------------------------------------------------------------===// -template -unique_ptr BindDecimalModulo(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto bind_data = BindDecimalArithmetic(context, bound_function, arguments); - // now select the physical function to execute - if (bind_data->check_overflow) { - // fallback to DOUBLE if the decimal type is not guaranteed to fit within the max decimal width - for (auto &arg : bound_function.arguments) { - arg = LogicalType::DOUBLE; - } - bound_function.return_type = LogicalType::DOUBLE; - } - auto &result_type = bound_function.return_type; - bound_function.function = GetBinaryFunctionIgnoreZero(result_type.InternalType()); - return std::move(bind_data); -} - -template <> -float ModuloOperator::Operation(float left, float right) { - auto result = std::fmod(left, right); - return result; -} - -template <> -double ModuloOperator::Operation(double left, double right) { - auto result = std::fmod(left, right); - return result; -} - -template <> -hugeint_t ModuloOperator::Operation(hugeint_t left, hugeint_t right) { - if (right.lower == 0 && right.upper == 0) { - throw InternalException("Hugeint division by zero!"); - } - return left % right; -} - -ScalarFunctionSet OperatorModuloFun::GetFunctions() { - ScalarFunctionSet modulo("%"); - for (auto &type : LogicalType::Numeric()) { - if (type.id() == LogicalTypeId::FLOAT || type.id() == LogicalTypeId::DOUBLE) { - modulo.AddFunction(ScalarFunction({type, type}, type, nullptr, BindBinaryFloatingPoint)); - } else if (type.id() == LogicalTypeId::DECIMAL) { - modulo.AddFunction(ScalarFunction({type, type}, type, nullptr, BindDecimalModulo)); - } else { - modulo.AddFunction( - ScalarFunction({type, type}, type, GetBinaryFunctionIgnoreZero(type.InternalType()))); - } - } - for (auto &func : modulo.functions) { - ScalarFunction::SetReturnsError(func); - } - - return modulo; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operator/multiply.cpp b/src/duckdb/src/function/scalar/operator/multiply.cpp deleted file mode 100644 index c2ae20767..000000000 --- a/src/duckdb/src/function/scalar/operator/multiply.cpp +++ /dev/null @@ -1,242 +0,0 @@ -#include "duckdb/common/operator/multiply.hpp" - -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/types/value.hpp" -#include 
"duckdb/common/windows_undefs.hpp" - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// * [multiply] -//===--------------------------------------------------------------------===// -template <> -float MultiplyOperator::Operation(float left, float right) { - auto result = left * right; - return result; -} - -template <> -double MultiplyOperator::Operation(double left, double right) { - auto result = left * right; - return result; -} - -template <> -interval_t MultiplyOperator::Operation(interval_t left, int64_t right) { - const auto right32 = Cast::Operation(right); - left.months = MultiplyOperatorOverflowCheck::Operation(left.months, right32); - left.days = MultiplyOperatorOverflowCheck::Operation(left.days, right32); - left.micros = MultiplyOperatorOverflowCheck::Operation(left.micros, right); - return left; -} - -template <> -interval_t MultiplyOperator::Operation(int64_t left, interval_t right) { - return MultiplyOperator::Operation(right, left); -} - -//===--------------------------------------------------------------------===// -// * [multiply] with overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedMultiply { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = MultiplyOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TryMultiplyOperator::Operation(uint8_t left, uint8_t right, uint8_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} -template <> -bool TryMultiplyOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (left > right) { - std::swap(left, right); - } - if (left > NumericLimits::Maximum()) { - return false; - } - uint32_t c = right >> 32; - uint32_t d = NumericLimits::Maximum() & right; - uint64_t r = left * c; - uint64_t s = left * d; - if (r > NumericLimits::Maximum()) { - return false; - } - r <<= 32; - if (NumericLimits::Maximum() - s < r) { - return false; - } - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int8_t left, int8_t right, int8_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedMultiply::Operation(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_mul_overflow(left, right, &result)) { - return false; - } -#else - if (left == std::numeric_limits::min()) { - if (right == 0) { - result = 0; - return true; - } - if (right == 1) { - result = left; - return true; - } - return false; - } - if 
(right == std::numeric_limits::min()) { - if (left == 0) { - result = 0; - return true; - } - if (left == 1) { - result = right; - return true; - } - return false; - } - uint64_t left_non_negative = uint64_t(std::abs(left)); - uint64_t right_non_negative = uint64_t(std::abs(right)); - // split values into 2 32-bit parts - uint64_t left_high_bits = left_non_negative >> 32; - uint64_t left_low_bits = left_non_negative & 0xffffffff; - uint64_t right_high_bits = right_non_negative >> 32; - uint64_t right_low_bits = right_non_negative & 0xffffffff; - - // check the high bits of both - // the high bits define the overflow - if (left_high_bits == 0) { - if (right_high_bits != 0) { - // only the right has high bits set - // multiply the high bits of right with the low bits of left - // multiply the low bits, and carry any overflow to the high bits - // then check for any overflow - auto low_low = left_low_bits * right_low_bits; - auto low_high = left_low_bits * right_high_bits; - auto high_bits = low_high + (low_low >> 32); - if (high_bits & 0xffffff80000000) { - // there is! abort - return false; - } - } - } else if (right_high_bits == 0) { - // only the left has high bits set - // multiply the high bits of left with the low bits of right - // multiply the low bits, and carry any overflow to the high bits - // then check for any overflow - auto low_low = left_low_bits * right_low_bits; - auto high_low = left_high_bits * right_low_bits; - auto high_bits = high_low + (low_low >> 32); - if (high_bits & 0xffffff80000000) { - // there is! abort - return false; - } - } else { - // both left and right have high bits set: guaranteed overflow - // abort! - return false; - } - // now we know that there is no overflow, we can just perform the multiplication - result = left * right; -#endif - return true; -} - -template <> -bool TryMultiplyOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - return Hugeint::TryMultiply(left, right, result); -} - -template <> -bool TryMultiplyOperator::Operation(uhugeint_t left, uhugeint_t right, uhugeint_t &result) { - return Uhugeint::TryMultiply(left, right, result); -} - -//===--------------------------------------------------------------------===// -// multiply decimal with overflow check -//===--------------------------------------------------------------------===// -template -bool TryDecimalMultiplyTemplated(T left, T right, T &result) { - if (!TryMultiplyOperator::Operation(left, right, result) || result < min || result > max) { - return false; - } - return true; -} - -template <> -bool TryDecimalMultiply::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalMultiplyTemplated(left, right, result); -} - -template <> -bool TryDecimalMultiply::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - if (!TryMultiplyOperator::Operation(left, right, result)) { - return false; - } - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalMultiplyOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalMultiply::Operation(left, right, result)) { - throw 
OutOfRangeException("Overflow in multiplication of DECIMAL(38) (%s * %s). You might want to add an " - "explicit cast to a decimal with a smaller scale.", - left.ToString(), right.ToString()); - } - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/operator/subtract.cpp b/src/duckdb/src/function/scalar/operator/subtract.cpp deleted file mode 100644 index 3a102b669..000000000 --- a/src/duckdb/src/function/scalar/operator/subtract.cpp +++ /dev/null @@ -1,243 +0,0 @@ -#include "duckdb/common/operator/subtract.hpp" - -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/common/types/uhugeint.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/interval.hpp" -#include "duckdb/common/types/value.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// - [subtract] -//===--------------------------------------------------------------------===// -template <> -float SubtractOperator::Operation(float left, float right) { - auto result = left - right; - return result; -} - -template <> -double SubtractOperator::Operation(double left, double right) { - auto result = left - right; - return result; -} - -template <> -int64_t SubtractOperator::Operation(date_t left, date_t right) { - return int64_t(left.days) - int64_t(right.days); -} - -template <> -date_t SubtractOperator::Operation(date_t left, int32_t right) { - if (!Date::IsFinite(left)) { - return left; - } - int32_t days; - if (!TrySubtractOperator::Operation(left.days, right, days)) { - throw OutOfRangeException("Date out of range"); - } - - date_t result(days); - if (!Date::IsFinite(result)) { - throw OutOfRangeException("Date out of range"); - } - return result; -} - -template <> -interval_t SubtractOperator::Operation(interval_t left, interval_t right) { - interval_t result; - if (!TrySubtractOperator::Operation(left.months, right.months, result.months)) { - throw OutOfRangeException("Interval months subtraction out of range"); - } - if (!TrySubtractOperator::Operation(left.days, right.days, result.days)) { - throw OutOfRangeException("Interval days subtraction out of range"); - } - if (!TrySubtractOperator::Operation(left.micros, right.micros, result.micros)) { - throw OutOfRangeException("Interval micros subtraction out of range"); - } - return result; -} - -template <> -timestamp_t SubtractOperator::Operation(date_t left, interval_t right) { - return AddOperator::Operation(left, Interval::Invert(right)); -} - -template <> -timestamp_t SubtractOperator::Operation(timestamp_t left, interval_t right) { - return AddOperator::Operation(left, Interval::Invert(right)); -} - -template <> -interval_t SubtractOperator::Operation(timestamp_t left, timestamp_t right) { - return Interval::GetDifference(left, right); -} - -//===--------------------------------------------------------------------===// -// - [subtract] with overflow check -//===--------------------------------------------------------------------===// -struct OverflowCheckedSubtract { - template - static inline bool Operation(SRCTYPE left, SRCTYPE right, SRCTYPE &result) { - UTYPE uresult = SubtractOperator::Operation(UTYPE(left), UTYPE(right)); - if (uresult < NumericLimits::Minimum() || uresult > NumericLimits::Maximum()) { - return false; - } - result = SRCTYPE(uresult); - return true; - } -}; - -template <> -bool TrySubtractOperator::Operation(uint8_t left, uint8_t right, uint8_t 
&result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation<uint8_t, uint16_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint16_t left, uint16_t right, uint16_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation<uint16_t, uint32_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint32_t left, uint32_t right, uint32_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation<uint32_t, uint64_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(uint64_t left, uint64_t right, uint64_t &result) { - if (right > left) { - return false; - } - return OverflowCheckedSubtract::Operation<uint64_t, uint64_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int8_t left, int8_t right, int8_t &result) { - return OverflowCheckedSubtract::Operation<int8_t, int16_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int16_t left, int16_t right, int16_t &result) { - return OverflowCheckedSubtract::Operation<int16_t, int32_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int32_t left, int32_t right, int32_t &result) { - return OverflowCheckedSubtract::Operation<int32_t, int64_t>(left, right, result); -} - -template <> -bool TrySubtractOperator::Operation(int64_t left, int64_t right, int64_t &result) { -#if (__GNUC__ >= 5) || defined(__clang__) - if (__builtin_sub_overflow(left, right, &result)) { - return false; - } -#else - if (right < 0) { - if (NumericLimits<int64_t>::Maximum() + right < left) { - return false; - } - } else { - if (NumericLimits<int64_t>::Minimum() + right > left) { - return false; - } - } - result = left - right; -#endif - return true; -} - -template <> -bool TrySubtractOperator::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - result = left; - return Hugeint::TrySubtractInPlace(result, right); -} - -template <> -bool TrySubtractOperator::Operation(uhugeint_t left, uhugeint_t right, uhugeint_t &result) { - result = left; - return Uhugeint::TrySubtractInPlace(result, right); -} - -//===--------------------------------------------------------------------===// -// subtract decimal with overflow check -//===--------------------------------------------------------------------===// -template <class T, T min, T max> -bool TryDecimalSubtractTemplated(T left, T right, T &result) { - if (right < 0) { - if (max + right < left) { - return false; - } - } else { - if (min + right > left) { - return false; - } - } - result = left - right; - return true; -} - -template <> -bool TryDecimalSubtract::Operation(int16_t left, int16_t right, int16_t &result) { - return TryDecimalSubtractTemplated<int16_t, -9999, 9999>(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(int32_t left, int32_t right, int32_t &result) { - return TryDecimalSubtractTemplated<int32_t, -999999999, 999999999>(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(int64_t left, int64_t right, int64_t &result) { - return TryDecimalSubtractTemplated<int64_t, -999999999999999999, 999999999999999999>(left, right, result); -} - -template <> -bool TryDecimalSubtract::Operation(hugeint_t left, hugeint_t right, hugeint_t &result) { - if (!TrySubtractOperator::Operation(left, right, result)) { - return false; - } - if (result <= -Hugeint::POWERS_OF_TEN[38] || result >= Hugeint::POWERS_OF_TEN[38]) { - return false; - } - return true; -} - -template <> -hugeint_t DecimalSubtractOverflowCheck::Operation(hugeint_t left, hugeint_t right) { - hugeint_t result; - if (!TryDecimalSubtract::Operation(left, right, result)) { - throw OutOfRangeException("Overflow in subtract 
of DECIMAL(38) (%s - %s);", left.ToString(), right.ToString()); - } - return result; -} - -//===--------------------------------------------------------------------===// -// subtract time operator -//===--------------------------------------------------------------------===// -template <> -dtime_t SubtractTimeOperator::Operation(dtime_t left, interval_t right) { - right.micros = -right.micros; - return AddTimeOperator::Operation(left, right); -} - -template <> -dtime_tz_t SubtractTimeOperator::Operation(dtime_tz_t left, interval_t right) { - right.micros = -right.micros; - return AddTimeOperator::Operation(left, right); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/pragma_functions.cpp b/src/duckdb/src/function/scalar/pragma_functions.cpp deleted file mode 100644 index dfce174c4..000000000 --- a/src/duckdb/src/function/scalar/pragma_functions.cpp +++ /dev/null @@ -1,10 +0,0 @@ -#include "duckdb/function/pragma/pragma_functions.hpp" - -namespace duckdb { - -void BuiltinFunctions::RegisterPragmaFunctions() { - Register(); - Register(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/sequence/nextval.cpp b/src/duckdb/src/function/scalar/sequence/nextval.cpp deleted file mode 100644 index 6738383dc..000000000 --- a/src/duckdb/src/function/scalar/sequence/nextval.cpp +++ /dev/null @@ -1,162 +0,0 @@ -#include "duckdb/function/scalar/sequence_functions.hpp" -#include "duckdb/function/scalar/sequence_utils.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/transaction/meta_transaction.hpp" -#include "duckdb/planner/binder.hpp" - -namespace duckdb { - -struct CurrentSequenceValueOperator { - static int64_t Operation(DuckTransaction &, SequenceCatalogEntry &seq) { - return seq.CurrentValue(); - } -}; - -struct NextSequenceValueOperator { - static int64_t Operation(DuckTransaction &transaction, SequenceCatalogEntry &seq) { - return seq.NextValue(transaction); - } -}; - -SequenceCatalogEntry &BindSequence(Binder &binder, string &catalog, string &schema, const string &name) { - // fetch the sequence from the catalog - Binder::BindSchemaOrCatalog(binder.context, catalog, schema); - return binder.EntryRetriever() - .GetEntry(CatalogType::SEQUENCE_ENTRY, catalog, schema, name) - ->Cast(); -} - -SequenceCatalogEntry &BindSequenceFromContext(ClientContext &context, string &catalog, string &schema, - const string &name) { - Binder::BindSchemaOrCatalog(context, catalog, schema); - return Catalog::GetEntry(context, catalog, schema, name); -} - -SequenceCatalogEntry &BindSequence(Binder &binder, const string &name) { - auto qname = QualifiedName::Parse(name); - return BindSequence(binder, qname.catalog, qname.schema, qname.name); -} - -struct NextValLocalState : public FunctionLocalState { - explicit NextValLocalState(DuckTransaction &transaction, SequenceCatalogEntry &sequence) - : transaction(transaction), sequence(sequence) { - } - - DuckTransaction &transaction; - SequenceCatalogEntry 
&sequence; -}; - -unique_ptr NextValLocalFunction(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - if (!bind_data) { - return nullptr; - } - auto &context = state.GetContext(); - auto &info = bind_data->Cast(); - auto &sequence = info.sequence; - auto &transaction = DuckTransaction::Get(context, sequence.catalog); - return make_uniq(transaction, sequence); -} - -template -static void NextValFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - if (!func_expr.bind_info) { - // no bind info - return null - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - // sequence to use is hard coded - // increment the sequence - result.SetVectorType(VectorType::FLAT_VECTOR); - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - // get the next value from the sequence - result_data[i] = OP::Operation(lstate.transaction, lstate.sequence); - } -} - -static unique_ptr NextValBind(ScalarFunctionBindInput &bind_input, ScalarFunction &, - vector> &arguments) { - if (arguments[0]->HasParameter() || arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - if (!arguments[0]->IsFoldable()) { - throw NotImplementedException( - "currval/nextval requires a constant sequence - non-constant sequences are no longer supported"); - } - auto &binder = bind_input.binder; - // parameter to nextval function is a foldable constant - // evaluate the constant and perform the catalog lookup already - auto seqname = ExpressionExecutor::EvaluateScalar(binder.context, *arguments[0]); - if (seqname.IsNull()) { - return nullptr; - } - auto &seq = BindSequence(binder, seqname.ToString()); - return make_uniq(seq); -} - -void Serialize(Serializer &serializer, const optional_ptr bind_data, const ScalarFunction &) { - auto &next_val_bind_data = bind_data->Cast(); - serializer.WritePropertyWithDefault(100, "sequence_create_info", next_val_bind_data.create_info); -} - -unique_ptr Deserialize(Deserializer &deserializer, ScalarFunction &) { - auto create_info = deserializer.ReadPropertyWithExplicitDefault>(100, "sequence_create_info", - unique_ptr()); - if (!create_info) { - return nullptr; - } - auto &seq_info = create_info->Cast(); - auto &context = deserializer.Get(); - auto &sequence = BindSequenceFromContext(context, seq_info.catalog, seq_info.schema, seq_info.name); - return make_uniq(sequence); -} - -void NextValModifiedDatabases(ClientContext &context, FunctionModifiedDatabasesInput &input) { - if (!input.bind_data) { - return; - } - auto &seq = input.bind_data->Cast(); - input.properties.RegisterDBModify(seq.sequence.ParentCatalog(), context); -} - -ScalarFunction NextvalFun::GetFunction() { - ScalarFunction next_val("nextval", {LogicalType::VARCHAR}, LogicalType::BIGINT, - NextValFunction, nullptr, nullptr); - next_val.bind_extended = NextValBind; - next_val.stability = FunctionStability::VOLATILE; - next_val.serialize = Serialize; - next_val.deserialize = Deserialize; - next_val.get_modified_databases = NextValModifiedDatabases; - next_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(next_val); - return next_val; -} - -ScalarFunction CurrvalFun::GetFunction() { - ScalarFunction curr_val("currval", {LogicalType::VARCHAR}, LogicalType::BIGINT, - NextValFunction, nullptr, nullptr); - 
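The local-state plumbing above is what lets nextval/currval resolve the transaction and the sequence catalog entry once per executor instance rather than once per row. A minimal standalone sketch of that pattern, using stand-in types rather than DuckDB's actual classes:

#include <cstddef>
#include <cstdint>
#include <vector>

struct Sequence {            // stand-in for SequenceCatalogEntry
	int64_t value = 0;
	int64_t NextValue() { return ++value; }
};

struct LocalState {          // stand-in for NextValLocalState
	explicit LocalState(Sequence &seq) : seq(seq) {}
	Sequence &seq;           // bound once, before execution starts
};

// Called once per chunk; the per-row loop touches no catalog structures.
void NextValSketch(LocalState &state, std::vector<int64_t> &result, std::size_t count) {
	for (std::size_t i = 0; i < count; i++) {
		result[i] = state.seq.NextValue();
	}
}

int main() {
	Sequence seq;
	LocalState state(seq);                 // the init_local_state step
	std::vector<int64_t> out(4);
	NextValSketch(state, out, out.size()); // out = {1, 2, 3, 4}
}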
curr_val.bind_extended = NextValBind; - curr_val.stability = FunctionStability::VOLATILE; - curr_val.serialize = Serialize; - curr_val.deserialize = Deserialize; - curr_val.init_local_state = NextValLocalFunction; - BaseScalarFunction::SetReturnsError(curr_val); - return curr_val; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/strftime_format.cpp b/src/duckdb/src/function/scalar/strftime_format.cpp deleted file mode 100644 index 8ab46ace7..000000000 --- a/src/duckdb/src/function/scalar/strftime_format.cpp +++ /dev/null @@ -1,1594 +0,0 @@ -#include "duckdb/function/scalar/strftime_format.hpp" -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/types/cast_helpers.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/time.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/multiply.hpp" - -#include - -namespace duckdb { - -idx_t StrfTimepecifierSize(StrTimeSpecifier specifier) { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: - return 3; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - case StrTimeSpecifier::WEEKDAY_ISO: - return 1; - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::HOUR_24_PADDED: - case StrTimeSpecifier::HOUR_12_PADDED: - case StrTimeSpecifier::MINUTE_PADDED: - case StrTimeSpecifier::SECOND_PADDED: - case StrTimeSpecifier::AM_PM: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - return 2; - case StrTimeSpecifier::NANOSECOND_PADDED: - return 9; - case StrTimeSpecifier::MICROSECOND_PADDED: - return 6; - case StrTimeSpecifier::MILLISECOND_PADDED: - return 3; - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - return 3; - case StrTimeSpecifier::YEAR_ISO: - return 4; - default: - return 0; - } -} - -void StrTimeFormat::AddLiteral(string literal) { - constant_size += literal.size(); - literals.push_back(std::move(literal)); -} - -void StrTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { - AddLiteral(std::move(preceding_literal)); - specifiers.push_back(specifier); -} - -void StrfTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { - is_date_specifier.push_back(IsDateSpecifier(specifier)); - idx_t specifier_size = StrfTimepecifierSize(specifier); - if (specifier_size == 0) { - // variable length specifier - var_length_specifiers.push_back(specifier); - } else { - // constant size specifier - constant_size += specifier_size; - } - StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); -} - -idx_t StrfTimeFormat::GetSpecifierLength(StrTimeSpecifier specifier, date_t date, int32_t data[8], - const char *tz_name) { - switch (specifier) { - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - return Date::DAY_NAMES[Date::ExtractISODayOfTheWeek(date) % 7].GetSize(); - case StrTimeSpecifier::FULL_MONTH_NAME: - return Date::MONTH_NAMES[data[1] - 1].GetSize(); - case StrTimeSpecifier::YEAR_DECIMAL: { - auto year = data[0]; - // Be consistent with WriteStandardSpecifier - if (0 <= year && year <= 9999) { - return 4; - } else { - return UnsafeNumericCast(NumericHelper::SignedLength(year)); 
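The YEAR_DECIMAL branch above encodes the sizing rule the two-pass formatter relies on: years in [0, 9999] always occupy four characters, while anything else occupies the length of its decimal representation, sign included. A hedged sketch of just that rule (YearLength is an illustrative name, not a DuckDB function):

#include <cassert>
#include <cstdint>
#include <cstdlib>

static std::size_t YearLength(int32_t year) {
	if (year >= 0 && year <= 9999) {
		return 4; // zero-padded to four digits, e.g. "0042"
	}
	std::size_t len = year < 0 ? 1 : 0; // account for the '-' sign
	for (long long v = std::llabs(year); v > 0; v /= 10) {
		len++; // one character per decimal digit
	}
	return len;
}

int main() {
	assert(YearLength(2024) == 4);
	assert(YearLength(42) == 4);     // rendered as "0042"
	assert(YearLength(123456) == 6); // rendered as "123456"
	assert(YearLength(-5) == 2);     // rendered as "-5"
}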
- } - } - case StrTimeSpecifier::MONTH_DECIMAL: { - idx_t len = 1; - auto month = data[1]; - len += month >= 10; - return len; - } - case StrTimeSpecifier::UTC_OFFSET: - // ±HH or ±HH:MM - return (data[7] % 60) ? 6 : 3; - case StrTimeSpecifier::TZ_NAME: - if (tz_name) { - return strlen(tz_name); - } - // empty for now - return 0; - case StrTimeSpecifier::HOUR_24_DECIMAL: - case StrTimeSpecifier::HOUR_12_DECIMAL: - case StrTimeSpecifier::MINUTE_DECIMAL: - case StrTimeSpecifier::SECOND_DECIMAL: { - // time specifiers - idx_t len = 1; - int32_t hour = data[3], min = data[4], sec = data[5]; - switch (specifier) { - case StrTimeSpecifier::HOUR_24_DECIMAL: - len += hour >= 10; - break; - case StrTimeSpecifier::HOUR_12_DECIMAL: - hour = hour % 12; - if (hour == 0) { - hour = 12; - } - len += hour >= 10; - break; - case StrTimeSpecifier::MINUTE_DECIMAL: - len += min >= 10; - break; - case StrTimeSpecifier::SECOND_DECIMAL: - len += sec >= 10; - break; - default: - throw InternalException("Time specifier mismatch"); - } - return len; - } - case StrTimeSpecifier::DAY_OF_MONTH: - return UnsafeNumericCast(NumericHelper::UnsignedLength(UnsafeNumericCast(data[2]))); - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return UnsafeNumericCast( - NumericHelper::UnsignedLength(UnsafeNumericCast(Date::ExtractDayOfTheYear(date)))); - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - return UnsafeNumericCast( - NumericHelper::UnsignedLength(UnsafeNumericCast(AbsValue(data[0]) % 100))); - default: - throw InternalException("Unimplemented specifier for GetSpecifierLength"); - } -} - -//! Returns the total length of the date formatted by this format specifier -idx_t StrfTimeFormat::GetLength(date_t date, int32_t data[8], const char *tz_name) const { - idx_t size = constant_size; - if (!var_length_specifiers.empty()) { - for (auto &specifier : var_length_specifiers) { - size += GetSpecifierLength(specifier, date, data, tz_name); - } - } - return size; -} - -idx_t StrfTimeFormat::GetLength(date_t date, dtime_t time, int32_t utc_offset, const char *tz_name) { - if (!var_length_specifiers.empty()) { - int32_t data[8]; - Date::Convert(date, data[0], data[1], data[2]); - Time::Convert(time, data[3], data[4], data[5], data[6]); - data[6] *= Interval::NANOS_PER_MICRO; - data[7] = utc_offset; - return GetLength(date, data, tz_name); - } - return constant_size; -} - -char *StrfTimeFormat::WriteString(char *target, const string_t &str) const { - idx_t size = str.GetSize(); - memcpy(target, str.GetData(), size); - return target + size; -} - -// write a value in the range of 0..99 unpadded (e.g. "1", "2", ... 
"98", "99") -char *StrfTimeFormat::Write2(char *target, uint8_t value) const { - D_ASSERT(value < 100); - if (value >= 10) { - return WritePadded2(target, value); - } else { - *target = char(uint8_t('0') + value); - return target + 1; - } -} - -// write a value in the range of 0..99 padded to 2 digits -char *StrfTimeFormat::WritePadded2(char *target, uint32_t value) const { - D_ASSERT(value < 100); - auto index = static_cast(value * 2); - *target++ = duckdb_fmt::internal::data::digits[index]; - *target++ = duckdb_fmt::internal::data::digits[index + 1]; - return target; -} - -// write a value in the range of 0..999 padded -char *StrfTimeFormat::WritePadded3(char *target, uint32_t value) const { - D_ASSERT(value < 1000); - if (value >= 100) { - WritePadded2(target + 1, value % 100); - *target = char(uint8_t('0') + value / 100); - return target + 3; - } else { - *target = '0'; - target++; - return WritePadded2(target, value); - } -} - -// write a value in the range of 0..999999... padded to the given number of digits -char *StrfTimeFormat::WritePadded(char *target, uint32_t value, size_t padding) const { - D_ASSERT(padding > 1); - if (padding % 2) { - uint32_t decimals = value % 1000u; - WritePadded3(target + padding - 3, decimals); - value /= 1000; - padding -= 3; - } - for (size_t i = 0; i < padding / 2; i++) { - uint32_t decimals = value % 100u; - WritePadded2(target + padding - 2 * (i + 1), decimals); - value /= 100; - } - return target + padding; -} - -bool StrfTimeFormat::IsDateSpecifier(StrTimeSpecifier specifier) { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - case StrTimeSpecifier::WEEKDAY_DECIMAL: - case StrTimeSpecifier::WEEKDAY_ISO: - case StrTimeSpecifier::YEAR_ISO: - return true; - default: - return false; - } -} - -char *StrfTimeFormat::WriteDateSpecifier(StrTimeSpecifier specifier, date_t date, char *target) const { - switch (specifier) { - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: { - auto dow = Date::ExtractISODayOfTheWeek(date); - target = WriteString(target, Date::DAY_NAMES_ABBREVIATED[dow % 7]); - break; - } - case StrTimeSpecifier::FULL_WEEKDAY_NAME: { - auto dow = Date::ExtractISODayOfTheWeek(date); - target = WriteString(target, Date::DAY_NAMES[dow % 7]); - break; - } - case StrTimeSpecifier::WEEKDAY_DECIMAL: { - auto dow = Date::ExtractISODayOfTheWeek(date); - *target = char('0' + uint8_t(dow % 7)); - target++; - break; - } - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: { - int32_t doy = Date::ExtractDayOfTheYear(date); - target = WritePadded3(target, UnsafeNumericCast(doy)); - break; - } - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, true))); - break; - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractWeekNumberRegular(date, false))); - break; - case StrTimeSpecifier::WEEK_NUMBER_ISO: - target = WritePadded2(target, UnsafeNumericCast(Date::ExtractISOWeekNumber(date))); - break; - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { - auto doy = UnsafeNumericCast(Date::ExtractDayOfTheYear(date)); - target += NumericHelper::UnsignedLength(doy); - NumericHelper::FormatUnsigned(doy, target); - break; - 
} - case StrTimeSpecifier::YEAR_ISO: - target = WritePadded(target, UnsafeNumericCast(Date::ExtractISOYearNumber(date)), 4); - break; - case StrTimeSpecifier::WEEKDAY_ISO: - *target = char('0' + uint8_t(Date::ExtractISODayOfTheWeek(date))); - target++; - break; - default: - throw InternalException("Unimplemented date specifier for strftime"); - } - return target; -} - -char *StrfTimeFormat::WriteStandardSpecifier(StrTimeSpecifier specifier, int32_t data[], const char *tz_name, - size_t tz_len, char *target) const { - // data contains [0] year, [1] month, [2] day, [3] hour, [4] minute, [5] second, [6] ns, [7] utc - switch (specifier) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[2])); - break; - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { - auto &month_name = Date::MONTH_NAMES_ABBREVIATED[data[1] - 1]; - return WriteString(target, month_name); - } - case StrTimeSpecifier::FULL_MONTH_NAME: { - auto &month_name = Date::MONTH_NAMES[data[1] - 1]; - return WriteString(target, month_name); - } - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[1])); - break; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - target = WritePadded2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); - break; - case StrTimeSpecifier::YEAR_DECIMAL: - if (data[0] >= 0 && data[0] <= 9999) { - target = WritePadded(target, UnsafeNumericCast(data[0]), 4); - } else { - int32_t year = data[0]; - if (data[0] < 0) { - *target = '-'; - year = -year; - target++; - } - auto len = NumericHelper::UnsignedLength(UnsafeNumericCast(year)); - NumericHelper::FormatUnsigned(year, target + len); - target += len; - } - break; - case StrTimeSpecifier::HOUR_24_PADDED: { - target = WritePadded2(target, UnsafeNumericCast(data[3])); - break; - } - case StrTimeSpecifier::HOUR_12_PADDED: { - int hour = data[3] % 12; - if (hour == 0) { - hour = 12; - } - target = WritePadded2(target, UnsafeNumericCast(hour)); - break; - } - case StrTimeSpecifier::AM_PM: - *target++ = data[3] >= 12 ? 'P' : 'A'; - *target++ = 'M'; - break; - case StrTimeSpecifier::MINUTE_PADDED: { - target = WritePadded2(target, UnsafeNumericCast(data[4])); - break; - } - case StrTimeSpecifier::SECOND_PADDED: - target = WritePadded2(target, UnsafeNumericCast(data[5])); - break; - case StrTimeSpecifier::NANOSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6]), 9); - break; - case StrTimeSpecifier::MICROSECOND_PADDED: - target = WritePadded(target, UnsafeNumericCast(data[6] / Interval::NANOS_PER_MICRO), 6); - break; - case StrTimeSpecifier::MILLISECOND_PADDED: - target = WritePadded3(target, UnsafeNumericCast(data[6] / Interval::NANOS_PER_MSEC)); - break; - case StrTimeSpecifier::UTC_OFFSET: { - *target++ = (data[7] < 0) ? 
'-' : '+'; - - auto offset = abs(data[7]); - auto offset_hours = offset / Interval::MINS_PER_HOUR; - auto offset_minutes = offset % Interval::MINS_PER_HOUR; - target = WritePadded2(target, UnsafeNumericCast(offset_hours)); - if (offset_minutes) { - *target++ = ':'; - target = WritePadded2(target, UnsafeNumericCast(offset_minutes)); - } - break; - } - case StrTimeSpecifier::TZ_NAME: - if (tz_name) { - memcpy(target, tz_name, tz_len); - target += strlen(tz_name); - } - break; - case StrTimeSpecifier::DAY_OF_MONTH: { - target = Write2(target, UnsafeNumericCast(data[2] % 100)); - break; - } - case StrTimeSpecifier::MONTH_DECIMAL: { - target = Write2(target, UnsafeNumericCast(data[1])); - break; - } - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: { - target = Write2(target, UnsafeNumericCast(AbsValue(data[0]) % 100)); - break; - } - case StrTimeSpecifier::HOUR_24_DECIMAL: { - target = Write2(target, UnsafeNumericCast(data[3])); - break; - } - case StrTimeSpecifier::HOUR_12_DECIMAL: { - int hour = data[3] % 12; - if (hour == 0) { - hour = 12; - } - target = Write2(target, UnsafeNumericCast(hour)); - break; - } - case StrTimeSpecifier::MINUTE_DECIMAL: { - target = Write2(target, UnsafeNumericCast(data[4])); - break; - } - case StrTimeSpecifier::SECOND_DECIMAL: { - target = Write2(target, UnsafeNumericCast(data[5])); - break; - } - default: - throw InternalException("Unimplemented specifier for WriteStandardSpecifier in strftime"); - } - return target; -} - -void StrfTimeFormat::FormatStringNS(date_t date, int32_t data[8], const char *tz_name, char *target) const { - D_ASSERT(specifiers.size() + 1 == literals.size()); - idx_t i; - for (i = 0; i < specifiers.size(); i++) { - // first copy the current literal - memcpy(target, literals[i].c_str(), literals[i].size()); - target += literals[i].size(); - // now copy the specifier - if (is_date_specifier[i]) { - target = WriteDateSpecifier(specifiers[i], date, target); - } else { - auto tz_len = tz_name ? 
strlen(tz_name) : 0; - target = WriteStandardSpecifier(specifiers[i], data, tz_name, tz_len, target); - } - } - // copy the final literal into the target - memcpy(target, literals[i].c_str(), literals[i].size()); -} - -void StrfTimeFormat::FormatString(date_t date, int32_t data[8], const char *tz_name, char *target) { - data[6] *= Interval::NANOS_PER_MICRO; - FormatStringNS(date, data, tz_name, target); - data[6] /= Interval::NANOS_PER_MICRO; -} - -void StrfTimeFormat::FormatString(date_t date, dtime_t time, char *target) { - int32_t data[8]; // year, month, day, hour, min, sec, µs, offset - Date::Convert(date, data[0], data[1], data[2]); - Time::Convert(time, data[3], data[4], data[5], data[6]); - data[7] = 0; - - FormatString(date, data, nullptr, target); -} - -string StrfTimeFormat::Format(timestamp_t timestamp, const string &format_str) { - StrfTimeFormat format; - format.ParseFormatSpecifier(format_str, format); - - auto date = Timestamp::GetDate(timestamp); - auto time = Timestamp::GetTime(timestamp); - - auto len = format.GetLength(date, time, 0, nullptr); - auto result = make_unsafe_uniq_array_uninitialized(len); - format.FormatString(date, time, result.get()); - return string(result.get(), len); -} - -string StrTimeFormat::ParseFormatSpecifier(const string &format_string, StrTimeFormat &format) { - if (format_string.empty()) { - return "Empty format string"; - } - format.format_specifier = format_string; - format.specifiers.clear(); - format.literals.clear(); - format.numeric_width.clear(); - format.constant_size = 0; - idx_t pos = 0; - string current_literal; - for (idx_t i = 0; i < format_string.size(); i++) { - if (format_string[i] == '%') { - if (i + 1 == format_string.size()) { - return "Trailing format character %"; - } - if (i > pos) { - // push the previous string to the current literal - current_literal += format_string.substr(pos, i - pos); - } - char format_char = format_string[++i]; - if (format_char == '%') { - // special case: %% - // set the pos for the next literal and continue - pos = i; - continue; - } - StrTimeSpecifier specifier; - if (format_char == '-' && i + 1 < format_string.size()) { - format_char = format_string[++i]; - switch (format_char) { - case 'd': - specifier = StrTimeSpecifier::DAY_OF_MONTH; - break; - case 'm': - specifier = StrTimeSpecifier::MONTH_DECIMAL; - break; - case 'y': - specifier = StrTimeSpecifier::YEAR_WITHOUT_CENTURY; - break; - case 'H': - specifier = StrTimeSpecifier::HOUR_24_DECIMAL; - break; - case 'I': - specifier = StrTimeSpecifier::HOUR_12_DECIMAL; - break; - case 'M': - specifier = StrTimeSpecifier::MINUTE_DECIMAL; - break; - case 'S': - specifier = StrTimeSpecifier::SECOND_DECIMAL; - break; - case 'j': - specifier = StrTimeSpecifier::DAY_OF_YEAR_DECIMAL; - break; - default: - return "Unrecognized format for strftime/strptime: %-" + string(1, format_char); - } - } else { - switch (format_char) { - case 'a': - specifier = StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME; - break; - case 'A': - specifier = StrTimeSpecifier::FULL_WEEKDAY_NAME; - break; - case 'w': - specifier = StrTimeSpecifier::WEEKDAY_DECIMAL; - break; - case 'u': - specifier = StrTimeSpecifier::WEEKDAY_ISO; - break; - case 'd': - specifier = StrTimeSpecifier::DAY_OF_MONTH_PADDED; - break; - case 'h': - case 'b': - specifier = StrTimeSpecifier::ABBREVIATED_MONTH_NAME; - break; - case 'B': - specifier = StrTimeSpecifier::FULL_MONTH_NAME; - break; - case 'm': - specifier = StrTimeSpecifier::MONTH_DECIMAL_PADDED; - break; - case 'y': - specifier = 
StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED; - break; - case 'Y': - specifier = StrTimeSpecifier::YEAR_DECIMAL; - break; - case 'G': - specifier = StrTimeSpecifier::YEAR_ISO; - break; - case 'H': - specifier = StrTimeSpecifier::HOUR_24_PADDED; - break; - case 'I': - specifier = StrTimeSpecifier::HOUR_12_PADDED; - break; - case 'p': - specifier = StrTimeSpecifier::AM_PM; - break; - case 'M': - specifier = StrTimeSpecifier::MINUTE_PADDED; - break; - case 'S': - specifier = StrTimeSpecifier::SECOND_PADDED; - break; - case 'n': - specifier = StrTimeSpecifier::NANOSECOND_PADDED; - break; - case 'f': - specifier = StrTimeSpecifier::MICROSECOND_PADDED; - break; - case 'g': - specifier = StrTimeSpecifier::MILLISECOND_PADDED; - break; - case 'z': - specifier = StrTimeSpecifier::UTC_OFFSET; - break; - case 'Z': - specifier = StrTimeSpecifier::TZ_NAME; - break; - case 'j': - specifier = StrTimeSpecifier::DAY_OF_YEAR_PADDED; - break; - case 'U': - specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST; - break; - case 'W': - specifier = StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST; - break; - case 'V': - specifier = StrTimeSpecifier::WEEK_NUMBER_ISO; - break; - case 'c': - case 'x': - case 'X': - case 'T': { - string subformat; - if (format_char == 'c') { - // %c: Locale’s appropriate date and time representation. - // we push the ISO timestamp representation here - subformat = "%Y-%m-%d %H:%M:%S"; - } else if (format_char == 'x') { - // %x - Locale’s appropriate date representation. - // we push the ISO date format here - subformat = "%Y-%m-%d"; - } else if (format_char == 'X' || format_char == 'T') { - // %X - Locale’s appropriate time representation. - // we push the ISO time format here - subformat = "%H:%M:%S"; - } - // parse the subformat in a separate format specifier - StrfTimeFormat locale_format; - string error = StrTimeFormat::ParseFormatSpecifier(subformat, locale_format); - if (!error.empty()) { - throw InternalException("Failed to bind sub-format specifier \"%s\": %s", subformat, error); - } - // add the previous literal to the first literal of the subformat - locale_format.literals[0] = std::move(current_literal) + locale_format.literals[0]; - current_literal = ""; - // now push the subformat into the current format specifier - for (idx_t i = 0; i < locale_format.specifiers.size(); i++) { - format.AddFormatSpecifier(std::move(locale_format.literals[i]), locale_format.specifiers[i]); - } - pos = i + 1; - continue; - } - default: - return "Unrecognized format for strftime/strptime: %" + string(1, format_char); - } - } - format.AddFormatSpecifier(std::move(current_literal), specifier); - current_literal = ""; - pos = i + 1; - } - } - // add the final literal - if (pos < format_string.size()) { - current_literal += format_string.substr(pos, format_string.size() - pos); - } - format.AddLiteral(std::move(current_literal)); - return string(); -} - -void StrfTimeFormat::ConvertDateVector(Vector &input, Vector &result, idx_t count) { - D_ASSERT(input.GetType().id() == LogicalTypeId::DATE); - D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); - UnaryExecutor::ExecuteWithNulls( - input, result, count, [&](date_t input, ValidityMask &mask, idx_t idx) { - if (Date::IsFinite(input)) { - dtime_t time(0); - idx_t len = GetLength(input, time, 0, nullptr); - string_t target = StringVector::EmptyString(result, len); - FormatString(input, time, target.GetDataWriteable()); - target.Finalize(); - return target; - } else { - return StringVector::AddString(result, Date::ToString(input)); - } - 
}); -} - -string_t StrfTimeFormat::ConvertTimestampValue(const timestamp_t &input, Vector &result) const { - if (Timestamp::IsFinite(input)) { - date_t date; - dtime_t time; - Timestamp::Convert(input, date, time); - - int32_t data[8]; // year, month, day, hour, min, sec, ns, offset - Date::Convert(date, data[0], data[1], data[2]); - Time::Convert(time, data[3], data[4], data[5], data[6]); - data[6] *= Interval::NANOS_PER_MICRO; - data[7] = 0; - const char *tz_name = nullptr; - - idx_t len = GetLength(date, data, tz_name); - string_t target = StringVector::EmptyString(result, len); - FormatStringNS(date, data, tz_name, target.GetDataWriteable()); - target.Finalize(); - return target; - } else { - return StringVector::AddString(result, Timestamp::ToString(input)); - } -} - -string_t StrfTimeFormat::ConvertTimestampValue(const timestamp_ns_t &input, Vector &result) const { - if (Timestamp::IsFinite(input)) { - date_t date; - dtime_t time; - int32_t nanos; - Timestamp::Convert(input, date, time, nanos); - - int32_t data[8]; // year, month, day, hour, min, sec, ns, offset - Date::Convert(date, data[0], data[1], data[2]); - Time::Convert(time, data[3], data[4], data[5], data[6]); - data[6] *= Interval::NANOS_PER_MICRO; - data[6] += nanos; - data[7] = 0; - const char *tz_name = nullptr; - - idx_t len = GetLength(date, data, tz_name); - string_t target = StringVector::EmptyString(result, len); - FormatStringNS(date, data, tz_name, target.GetDataWriteable()); - target.Finalize(); - return target; - } else { - return StringVector::AddString(result, Timestamp::ToString(input)); - } -} - -void StrfTimeFormat::ConvertTimestampVector(Vector &input, Vector &result, idx_t count) { - D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP || input.GetType().id() == LogicalTypeId::TIMESTAMP_TZ); - D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); - UnaryExecutor::ExecuteWithNulls( - input, result, count, - [&](timestamp_t input, ValidityMask &mask, idx_t idx) { return ConvertTimestampValue(input, result); }); -} - -void StrfTimeFormat::ConvertTimestampNSVector(Vector &input, Vector &result, idx_t count) { - D_ASSERT(input.GetType().id() == LogicalTypeId::TIMESTAMP_NS); - D_ASSERT(result.GetType().id() == LogicalTypeId::VARCHAR); - UnaryExecutor::ExecuteWithNulls( - input, result, count, - [&](timestamp_ns_t input, ValidityMask &mask, idx_t idx) { return ConvertTimestampValue(input, result); }); -} - -void StrpTimeFormat::AddFormatSpecifier(string preceding_literal, StrTimeSpecifier specifier) { - numeric_width.push_back(NumericSpecifierWidth(specifier)); - StrTimeFormat::AddFormatSpecifier(std::move(preceding_literal), specifier); -} - -int StrpTimeFormat::NumericSpecifierWidth(StrTimeSpecifier specifier) { - switch (specifier) { - case StrTimeSpecifier::WEEKDAY_DECIMAL: - case StrTimeSpecifier::WEEKDAY_ISO: - return 1; - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::HOUR_24_PADDED: - case StrTimeSpecifier::HOUR_24_DECIMAL: - case StrTimeSpecifier::HOUR_12_PADDED: - case StrTimeSpecifier::HOUR_12_DECIMAL: - case StrTimeSpecifier::MINUTE_PADDED: - case StrTimeSpecifier::MINUTE_DECIMAL: - case StrTimeSpecifier::SECOND_PADDED: - case StrTimeSpecifier::SECOND_DECIMAL: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case 
StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - return 2; - case StrTimeSpecifier::MILLISECOND_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - return 3; - case StrTimeSpecifier::YEAR_DECIMAL: - case StrTimeSpecifier::YEAR_ISO: - return 4; - case StrTimeSpecifier::MICROSECOND_PADDED: - return 6; - case StrTimeSpecifier::NANOSECOND_PADDED: - return 9; - default: - return -1; - } -} - -enum class TimeSpecifierAMOrPM : uint8_t { TIME_SPECIFIER_NONE = 0, TIME_SPECIFIER_AM = 1, TIME_SPECIFIER_PM = 2 }; - -int32_t StrpTimeFormat::TryParseCollection(const char *data, idx_t &pos, idx_t size, const string_t collection[], - idx_t collection_count) const { - for (idx_t c = 0; c < collection_count; c++) { - auto &entry = collection[c]; - auto entry_data = entry.GetData(); - auto entry_size = entry.GetSize(); - // check if this entry matches - if (pos + entry_size > size) { - // too big: can't match - continue; - } - // compare the characters - idx_t i; - for (i = 0; i < entry_size; i++) { - if (std::tolower(entry_data[i]) != std::tolower(data[pos + i])) { - break; - } - } - if (i == entry_size) { - // full match - pos += entry_size; - return UnsafeNumericCast(c); - } - } - return -1; -} - -bool StrpTimeFormat::Parse(const char *data, size_t size, ParseResult &result, bool strict) const { - auto &result_data = result.data; - auto &error_message = result.error_message; - auto &error_position = result.error_position; - - // initialize the result - result_data[0] = 1900; - result_data[1] = 1; - result_data[2] = 1; - result_data[3] = 0; - result_data[4] = 0; - result_data[5] = 0; - result_data[6] = 0; - result_data[7] = 0; - // skip leading spaces - while (StringUtil::CharacterIsSpace(*data)) { - data++; - size--; - } - // Check for specials - // Precheck for alphas for performance. - idx_t pos = 0; - result.is_special = false; - if (size > 4) { - if (StringUtil::CharacterIsAlpha(*data)) { - if (Date::TryConvertDateSpecial(data, size, pos, Date::PINF)) { - result.is_special = true; - result.special = date_t::infinity(); - } else if (Date::TryConvertDateSpecial(data, size, pos, Date::EPOCH)) { - result.is_special = true; - result.special = date_t::epoch(); - } - } else if (*data == '-' && Date::TryConvertDateSpecial(data, size, pos, Date::NINF)) { - result.is_special = true; - result.special = date_t::ninfinity(); - } - } - if (result.is_special) { - // skip trailing spaces - while (pos < size && StringUtil::CharacterIsSpace(data[pos])) { - pos++; - } - if (pos != size) { - error_message = "Special timestamp did not match: trailing characters"; - error_position = pos; - return false; - } - return true; - } - - TimeSpecifierAMOrPM ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE; - - // Year offset state (Year+W/j) - auto offset_specifier = StrTimeSpecifier::WEEKDAY_DECIMAL; - uint64_t weekno = 0; - uint64_t weekday = 0; - uint64_t yearday = 0; - bool has_weekday = false; - - // ISO state (%G/%V/%u) - // Out of range values to detect multiple specifications - uint64_t iso_year = 10000; - uint64_t iso_week = 54; - uint64_t iso_weekday = 8; - - for (idx_t i = 0;; i++) { - D_ASSERT(i < literals.size()); - // first compare the literal - const auto &literal = literals[i]; - for (size_t l = 0; l < literal.size();) { - // Match runs of spaces to runs of spaces. 
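Expanding on the comment above: a run of spaces in the format literal matches any non-empty run of spaces in the input, so format and input do not need to agree on exact spacing. A minimal standalone sketch of that matching rule (MatchLiteral is an illustrative helper, not the DuckDB API):

#include <cassert>
#include <cctype>
#include <string>

static bool MatchLiteral(const std::string &literal, const std::string &input, std::size_t &pos) {
	for (std::size_t l = 0; l < literal.size();) {
		if (std::isspace(static_cast<unsigned char>(literal[l]))) {
			if (pos >= input.size() || !std::isspace(static_cast<unsigned char>(input[pos]))) {
				return false; // the format requires at least one space here
			}
			while (pos < input.size() && std::isspace(static_cast<unsigned char>(input[pos]))) {
				pos++; // consume the whole space run in the input
			}
			while (l < literal.size() && std::isspace(static_cast<unsigned char>(literal[l]))) {
				l++;   // ...and the whole space run in the format
			}
			continue;
		}
		if (pos >= input.size() || input[pos++] != literal[l++]) {
			return false; // ordinary characters must match exactly
		}
	}
	return true;
}

int main() {
	std::size_t pos = 0;
	assert(MatchLiteral(" / ", "   /   ", pos) && pos == 7);
	pos = 0;
	assert(!MatchLiteral(" ", "x", pos));
}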
- if (StringUtil::CharacterIsSpace(literal[l])) { - if (!StringUtil::CharacterIsSpace(data[pos])) { - error_message = "Space does not match, expected " + literals[i]; - error_position = pos; - return false; - } - for (++pos; pos < size && StringUtil::CharacterIsSpace(data[pos]); ++pos) { - continue; - } - for (++l; l < literal.size() && StringUtil::CharacterIsSpace(literal[l]); ++l) { - continue; - } - continue; - } - // literal does not match - if (data[pos++] != literal[l++]) { - error_message = "Literal does not match, expected " + literal; - error_position = pos; - return false; - } - } - if (i == specifiers.size()) { - break; - } - // now parse the specifier - if (numeric_width[i] > 0) { - // numeric specifier: parse a number - uint64_t number = 0; - size_t start_pos = pos; - size_t end_pos = start_pos + UnsafeNumericCast(numeric_width[i]); - while (pos < size && pos < end_pos && StringUtil::CharacterIsDigit(data[pos])) { - number = number * 10 + UnsafeNumericCast(data[pos]) - '0'; - pos++; - } - if (pos == start_pos) { - // expected a number here - error_message = "Expected a number"; - error_position = start_pos; - return false; - } - switch (specifiers[i]) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - if (number < 1 || number > 31) { - error_message = "Day out of range, expected a value between 1 and 31"; - error_position = start_pos; - return false; - } - // day of the month - result_data[2] = UnsafeNumericCast(number); - offset_specifier = specifiers[i]; - break; - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - if (number < 1 || number > 12) { - error_message = "Month out of range, expected a value between 1 and 12"; - error_position = start_pos; - return false; - } - // month number - result_data[1] = UnsafeNumericCast(number); - offset_specifier = specifiers[i]; - break; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - switch (offset_specifier) { - case StrTimeSpecifier::YEAR_ISO: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - // Override - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - default: - break; - } - // year without century.. - // Python uses 69 as a crossover point (i.e. >= 69 is 19.., < 69 is 20..) 
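Concretely, the pivot noted above maps two-digit years 69-99 to 1969-1999 and 00-68 to 2000-2068, matching Python's strptime behavior. A small sketch of exactly that mapping:

#include <cassert>
#include <cstdint>

static int32_t ExpandTwoDigitYear(uint64_t number) { // number in [0..99]
	// >= 69 lands in the 1900s, < 69 lands in the 2000s
	return number >= 69 ? int32_t(1900 + number) : int32_t(2000 + number);
}

int main() {
	assert(ExpandTwoDigitYear(69) == 1969);
	assert(ExpandTwoDigitYear(99) == 1999);
	assert(ExpandTwoDigitYear(0) == 2000);
	assert(ExpandTwoDigitYear(68) == 2068);
}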
- if (pos - start_pos < 2 && strict) { - return false; - } - if (number >= 100) { - // %y only supports numbers between [0..99] - error_message = "Year without century out of range, expected a value between 0 and 99"; - error_position = start_pos; - return false; - } - if (number >= 69) { - result_data[0] = int32_t(1900 + number); - } else { - result_data[0] = int32_t(2000 + number); - } - break; - case StrTimeSpecifier::YEAR_DECIMAL: - switch (offset_specifier) { - case StrTimeSpecifier::YEAR_ISO: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - // Override - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - default: - break; - } - if (pos - start_pos < 2 && strict) { - return false; - } - // year as full number - result_data[0] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::YEAR_ISO: - switch (offset_specifier) { - // y/m/d overrides G/V/u but does not conflict - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::YEAR_DECIMAL: - // Just validate, don't use - break; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - case StrTimeSpecifier::YEAR_ISO: - case StrTimeSpecifier::WEEK_NUMBER_ISO: - // Already parsing ISO - if (iso_year <= 9999) { - error_message = "Multiple ISO year offsets specified"; - error_position = start_pos; - return false; - } - break; - default: - error_message = "Incompatible ISO year offset specified"; - error_position = start_pos; - return false; - break; - } - if (number > 9999) { - // %G only supports numbers between [0..9999] - error_message = "ISO Year out of range, expected a value between 0000 and 9999"; - error_position = start_pos; - return false; - } - iso_year = number; - break; - case StrTimeSpecifier::HOUR_24_PADDED: - case StrTimeSpecifier::HOUR_24_DECIMAL: - if (number >= 24) { - error_message = "Hour out of range, expected a value between 0 and 23"; - error_position = start_pos; - return false; - } - // hour as full number - result_data[3] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::HOUR_12_PADDED: - case StrTimeSpecifier::HOUR_12_DECIMAL: - if (number < 1 || number > 12) { - error_message = "Hour12 out of range, expected a value between 1 and 12"; - error_position = start_pos; - return false; - } - // 12-hour number: start off by just storing the number - result_data[3] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::MINUTE_PADDED: - case StrTimeSpecifier::MINUTE_DECIMAL: - if (number >= 60) { - error_message = "Minutes out of range, expected a value between 0 and 59"; - error_position = start_pos; - return false; - } - // minutes - result_data[4] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::SECOND_PADDED: - case StrTimeSpecifier::SECOND_DECIMAL: - if (number >= 60) { - error_message = "Seconds out of range, expected a value between 0 and 59"; - error_position = start_pos; - return false; - } - // seconds - result_data[5] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::NANOSECOND_PADDED: - D_ASSERT(number < Interval::NANOS_PER_SEC); // enforced by the length of the number - // nanoseconds - result_data[6] = UnsafeNumericCast(number); - break; - case StrTimeSpecifier::MICROSECOND_PADDED: - D_ASSERT(number < 
Interval::MICROS_PER_SEC); // enforced by the length of the number - // nanoseconds - result_data[6] = UnsafeNumericCast(number * Interval::NANOS_PER_MICRO); - break; - case StrTimeSpecifier::MILLISECOND_PADDED: - D_ASSERT(number < Interval::MSECS_PER_SEC); // enforced by the length of the number - // nanoseconds - result_data[6] = UnsafeNumericCast(number * Interval::NANOS_PER_MSEC); - break; - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: - // m/d overrides WU/w but does not conflict - switch (offset_specifier) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - // Just validate, don't use - break; - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::YEAR_DECIMAL: - // Switch to offset parsing - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - default: - error_message = "Multiple week offsets specified"; - error_position = start_pos; - return false; - } - if (number > 53) { - error_message = "Week out of range, expected a value between 0 and 53"; - error_position = start_pos; - return false; - } - weekno = number; - break; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - if (number > 6) { - error_message = "Weekday out of range, expected a value between 0 and 6"; - error_position = start_pos; - return false; - } - has_weekday = true; - weekday = number; - break; - case StrTimeSpecifier::WEEK_NUMBER_ISO: - // y/m/d overrides G/V/u but does not conflict - switch (offset_specifier) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::YEAR_DECIMAL: - // Just validate, don't use - break; - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - case StrTimeSpecifier::WEEK_NUMBER_ISO: - case StrTimeSpecifier::YEAR_ISO: - // Already parsing ISO - if (iso_week <= 53) { - error_message = "Multiple ISO week offsets specified"; - error_position = start_pos; - return false; - } - break; - default: - error_message = "Incompatible ISO week offset specified"; - error_position = start_pos; - return false; - } - if (number < 1 || number > 53) { - error_message = "ISO week offset out of range, expected a value between 1 and 53"; - error_position = start_pos; - return false; - } - iso_week = number; - break; - case StrTimeSpecifier::WEEKDAY_ISO: - if (iso_weekday <= 7) { - error_message = "Multiple ISO weekday offsets specified"; - error_position = start_pos; - return false; - } - if (number < 1 || number > 7) { - error_message = "ISO weekday offset out of range, expected a value between 1 and 7"; - error_position = start_pos; - return false; - } - iso_weekday = number; - break; - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: - // m/d overrides j but does not conflict - switch (offset_specifier) { - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - // Just validate, don't use - break; - case 
StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::YEAR_DECIMAL: - // Switch to offset parsing - case StrTimeSpecifier::WEEKDAY_DECIMAL: - // First offset specifier - offset_specifier = specifiers[i]; - break; - default: - error_message = "Multiple year offsets specified"; - error_position = start_pos; - return false; - } - if (number < 1 || number > 366) { - error_message = "Year day out of range, expected a value between 1 and 366"; - error_position = start_pos; - return false; - } - yearday = number; - break; - default: - throw NotImplementedException("Unsupported specifier for strptime"); - } - } else { - switch (specifiers[i]) { - case StrTimeSpecifier::AM_PM: { - // parse the next 2 characters - if (pos + 2 > size) { - // no characters left to parse - error_message = "Expected AM/PM"; - error_position = pos; - return false; - } - char pa_char = char(std::tolower(data[pos])); - char m_char = char(std::tolower(data[pos + 1])); - if (m_char != 'm') { - error_message = "Expected AM/PM"; - error_position = pos; - return false; - } - if (pa_char == 'p') { - ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_PM; - } else if (pa_char == 'a') { - ampm = TimeSpecifierAMOrPM::TIME_SPECIFIER_AM; - } else { - error_message = "Expected AM/PM"; - error_position = pos; - return false; - } - pos += 2; - break; - } - // we parse weekday names, but we don't use them as information - case StrTimeSpecifier::ABBREVIATED_WEEKDAY_NAME: - if (TryParseCollection(data, pos, size, Date::DAY_NAMES_ABBREVIATED, 7) < 0) { - error_message = "Expected an abbreviated day name (Mon, Tue, Wed, Thu, Fri, Sat, Sun)"; - error_position = pos; - return false; - } - break; - case StrTimeSpecifier::FULL_WEEKDAY_NAME: - if (TryParseCollection(data, pos, size, Date::DAY_NAMES, 7) < 0) { - error_message = "Expected a full day name (Monday, Tuesday, etc...)"; - error_position = pos; - return false; - } - break; - case StrTimeSpecifier::ABBREVIATED_MONTH_NAME: { - int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES_ABBREVIATED, 12); - if (month < 0) { - error_message = "Expected an abbreviated month name (Jan, Feb, Mar, etc..)"; - error_position = pos; - return false; - } - result_data[1] = month + 1; - break; - } - case StrTimeSpecifier::FULL_MONTH_NAME: { - int32_t month = TryParseCollection(data, pos, size, Date::MONTH_NAMES, 12); - if (month < 0) { - error_message = "Expected a full month name (January, February, etc...)"; - error_position = pos; - return false; - } - result_data[1] = month + 1; - break; - } - case StrTimeSpecifier::UTC_OFFSET: { - int hour_offset, minute_offset; - if (!Timestamp::TryParseUTCOffset(data, pos, size, hour_offset, minute_offset)) { - error_message = "Expected +HH[MM] or -HH[MM]"; - error_position = pos; - return false; - } - result_data[7] = hour_offset * Interval::MINS_PER_HOUR + minute_offset; - break; - } - case StrTimeSpecifier::TZ_NAME: { - // skip leading spaces - while (pos < size && StringUtil::CharacterIsSpace(data[pos])) { - pos++; - } - const auto tz_begin = data + pos; - // stop when we encounter a non-tz character - while (pos < size && Timestamp::CharacterIsTimeZone(data[pos])) { - pos++; - } - const auto tz_end = data + pos; - // Can't fully validate without a list - caller's responsibility. - // But tz must not be empty. 
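As the comment above says, only tokenization happens at this point; whether the scanned name denotes a real zone is left to the caller. A hedged sketch of such a scan, where the accepted character set is an assumption approximating IANA-style names rather than a quote of Timestamp::CharacterIsTimeZone:

#include <cassert>
#include <cctype>
#include <string>

static bool IsTimeZoneChar(char c) { // assumption: IANA-style name characters
	return std::isalnum(static_cast<unsigned char>(c)) || c == '/' || c == '_' || c == '+' || c == '-';
}

static std::string ScanTimeZone(const std::string &input, std::size_t &pos) {
	while (pos < input.size() && input[pos] == ' ') {
		pos++; // skip leading spaces
	}
	const std::size_t begin = pos;
	while (pos < input.size() && IsTimeZoneChar(input[pos])) {
		pos++; // consume the zone-name token
	}
	return input.substr(begin, pos - begin); // empty string => error upstream
}

int main() {
	std::size_t pos = 0;
	assert(ScanTimeZone("  America/New_York", pos) == "America/New_York");
	pos = 0;
	assert(ScanTimeZone(" !", pos).empty()); // caught by the emptiness check
}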
- if (tz_end == tz_begin) { - error_message = "Empty Time Zone name"; - error_position = UnsafeNumericCast(tz_begin - data); - return false; - } - result.tz.assign(tz_begin, tz_end); - break; - } - default: - throw NotImplementedException("Unsupported specifier for strptime"); - } - } - } - // skip trailing spaces - while (pos < size && StringUtil::CharacterIsSpace(data[pos])) { - pos++; - } - if (pos != size) { - error_message = "Full specifier did not match: trailing characters"; - error_position = pos; - return false; - } - if (ampm != TimeSpecifierAMOrPM::TIME_SPECIFIER_NONE) { - if (result_data[3] > 12) { - error_message = - "Invalid hour: " + to_string(result_data[3]) + " AM/PM, expected an hour within the range [0..12]"; - return false; - } - // adjust the hours based on the AM or PM specifier - if (ampm == TimeSpecifierAMOrPM::TIME_SPECIFIER_AM) { - // AM: 12AM=0, 1AM=1, 2AM=2, ..., 11AM=11 - if (result_data[3] == 12) { - result_data[3] = 0; - } - } else { - // PM: 12PM=12, 1PM=13, 2PM=14, ..., 11PM=23 - if (result_data[3] != 12) { - result_data[3] += 12; - } - } - } - switch (offset_specifier) { - case StrTimeSpecifier::YEAR_ISO: - case StrTimeSpecifier::WEEK_NUMBER_ISO: { - // Default to 1900-01-01 - iso_year = (iso_year > 9999) ? 1900 : iso_year; - iso_week = (iso_week > 53) ? 1 : iso_week; - iso_weekday = (iso_weekday > 7) ? 1 : iso_weekday; - // Gregorian and ISO agree on the year of January 4 - auto jan4 = Date::FromDate(UnsafeNumericCast(iso_year), 1, 4); - // ISO Week 1 starts on the previous Monday - auto week1 = Date::GetMondayOfCurrentWeek(jan4); - // ISO Week N starts N-1 weeks later - auto iso_date = week1 + UnsafeNumericCast((iso_week - 1) * 7 + (iso_weekday - 1)); - Date::Convert(iso_date, result_data[0], result_data[1], result_data[2]); - break; - } - case StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST: - case StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST: { - // Adjust weekday to be 0-based for the week type - if (has_weekday) { - weekday = (weekday + 7 - - static_cast(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_MON_FIRST)) % - 7; - } - // Get the start of week 1, move back 7 days and then weekno * 7 + weekday gives the date - const auto jan1 = Date::FromDate(result_data[0], 1, 1); - auto yeardate = Date::GetMondayOfCurrentWeek(jan1); - yeardate -= int(offset_specifier == StrTimeSpecifier::WEEK_NUMBER_PADDED_SUN_FIRST); - // Is there a week 0? - yeardate -= 7 * int(yeardate >= jan1); - yeardate += UnsafeNumericCast(weekno * 7 + weekday); - Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]); - break; - } - case StrTimeSpecifier::DAY_OF_YEAR_PADDED: - case StrTimeSpecifier::DAY_OF_YEAR_DECIMAL: { - auto yeardate = Date::FromDate(result_data[0], 1, 1); - yeardate += UnsafeNumericCast(yearday - 1); - Date::Convert(yeardate, result_data[0], result_data[1], result_data[2]); - break; - } - case StrTimeSpecifier::DAY_OF_MONTH_PADDED: - case StrTimeSpecifier::DAY_OF_MONTH: - case StrTimeSpecifier::MONTH_DECIMAL_PADDED: - case StrTimeSpecifier::MONTH_DECIMAL: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY_PADDED: - case StrTimeSpecifier::YEAR_WITHOUT_CENTURY: - case StrTimeSpecifier::YEAR_DECIMAL: - // m/d overrides UWVwu/j - break; - default: - D_ASSERT(offset_specifier == StrTimeSpecifier::WEEKDAY_DECIMAL); - break; - } - - return true; -} - -//! 
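The ISO branch above leans on one identity: the Gregorian and ISO calendars agree on the year of January 4, so ISO week 1 is the Monday-based week containing it. A compact standalone rendering of the same date arithmetic (the epoch-day helper is mine, using Howard Hinnant's civil-date algorithm; DuckDB routes this through its Date class instead):

```cpp
#include <cstdint>

// Days since 1970-01-01 for a proleptic Gregorian date.
static int64_t DaysFromCivil(int64_t y, unsigned m, unsigned d) {
	y -= m <= 2;
	const int64_t era = (y >= 0 ? y : y - 399) / 400;
	const unsigned yoe = static_cast<unsigned>(y - era * 400);           // [0, 399]
	const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365]
	const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy;          // [0, 146096]
	return era * 146097 + static_cast<int64_t>(doe) - 719468;
}

// ISO year/week/weekday (1=Mon..7=Sun) -> epoch days, via the January 4 trick.
static int64_t ISOToEpochDays(int64_t iso_year, int64_t iso_week, int64_t iso_weekday) {
	const int64_t jan4 = DaysFromCivil(iso_year, 1, 4);
	const int64_t dow = ((jan4 + 3) % 7 + 7) % 7; // 0=Monday (epoch day 0 was a Thursday)
	const int64_t week1_monday = jan4 - dow;      // start of ISO week 1
	return week1_monday + (iso_week - 1) * 7 + (iso_weekday - 1);
}
// e.g. ISOToEpochDays(2020, 1, 1) == 18260, i.e. 2019-12-30
```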
Parses a timestamp using the given specifier -bool StrpTimeFormat::Parse(string_t str, ParseResult &result, bool strict) const { - auto data = str.GetData(); - idx_t size = str.GetSize(); - return Parse(data, size, result, strict); -} - -StrpTimeFormat::ParseResult StrpTimeFormat::Parse(const string &format_string, const string &text) { - StrpTimeFormat format; - format.format_specifier = format_string; - string error = StrTimeFormat::ParseFormatSpecifier(format_string, format); - if (!error.empty()) { - throw InvalidInputException("Failed to parse format specifier %s: %s", format_string, error); - } - StrpTimeFormat::ParseResult result; - if (!format.Parse(text, result)) { - throw InvalidInputException("Failed to parse string \"%s\" with format specifier \"%s\"", text, format_string); - } - return result; -} - -bool StrTimeFormat::Empty() const { - return format_specifier.empty(); -} - -string StrpTimeFormat::FormatStrpTimeError(const string &input, optional_idx position) { - if (!position.IsValid()) { - return string(); - } - return input + "\n" + string(position.GetIndex(), ' ') + "^"; -} - -date_t StrpTimeFormat::ParseResult::ToDate() { - if (is_special) { - return special; - } - return Date::FromDate(data[0], data[1], data[2]); -} - -bool StrpTimeFormat::ParseResult::TryToDate(date_t &result) { - return Date::TryFromDate(data[0], data[1], data[2], result); -} - -int32_t StrpTimeFormat::ParseResult::GetMicros() const { - return UnsafeNumericCast((data[6] + Interval::NANOS_PER_MICRO / 2) / Interval::NANOS_PER_MICRO); -} - -dtime_t StrpTimeFormat::ParseResult::ToTime() { - const auto hour_offset = data[7] / Interval::MINS_PER_HOUR; - const auto mins_offset = data[7] % Interval::MINS_PER_HOUR; - return Time::FromTime(data[3] - hour_offset, data[4] - mins_offset, data[5], GetMicros()); -} - -int64_t StrpTimeFormat::ParseResult::ToTimeNS() { - const int32_t hour_offset = data[7] / Interval::MINS_PER_HOUR; - const int32_t mins_offset = data[7] % Interval::MINS_PER_HOUR; - return Time::ToNanoTime(data[3] - hour_offset, data[4] - mins_offset, data[5], data[6]); -} - -bool StrpTimeFormat::ParseResult::TryToTime(dtime_t &result) { - if (data[7]) { - return false; - } - result = Time::FromTime(data[3], data[4], data[5], GetMicros()); - return true; -} - -timestamp_t StrpTimeFormat::ParseResult::ToTimestamp() { - if (is_special) { - if (special == date_t::infinity()) { - return timestamp_t::infinity(); - } else if (special == date_t::ninfinity()) { - return timestamp_t::ninfinity(); - } - return Timestamp::FromDatetime(special, dtime_t(0)); - } - - date_t date = ToDate(); - dtime_t time = ToTime(); - return Timestamp::FromDatetime(date, time); -} - -bool StrpTimeFormat::ParseResult::TryToTimestamp(timestamp_t &result) { - date_t date; - if (!TryToDate(date)) { - return false; - } - dtime_t time = ToTime(); - return Timestamp::TryFromDatetime(date, time, result); -} - -timestamp_ns_t StrpTimeFormat::ParseResult::ToTimestampNS() { - timestamp_ns_t result; - if (is_special) { - if (special == date_t::infinity()) { - result.value = timestamp_t::infinity().value; - } else if (special == date_t::ninfinity()) { - result.value = timestamp_t::ninfinity().value; - } else { - result.value = special.days * Interval::NANOS_PER_DAY; - } - } else { - // Don't use rounded µs - const auto date = ToDate(); - const auto time = ToTimeNS(); - if (!TryMultiplyOperator::Operation(date.days, Interval::NANOS_PER_DAY, - result.value)) { - throw ConversionException("Date out of nanosecond range: %d-%d-%d", data[0], data[1], 
data[2]); - } - if (!TryAddOperator::Operation(result.value, time, result.value)) { - throw ConversionException("Overflow exception in date/time -> timestamp_ns conversion"); - } - } - - return result; -} - -bool StrpTimeFormat::ParseResult::TryToTimestampNS(timestamp_ns_t &result) { - date_t date; - if (!TryToDate(date)) { - return false; - } - - // Don't use rounded µs - const auto time = ToTimeNS(); - if (!TryMultiplyOperator::Operation(date.days, Interval::NANOS_PER_DAY, result.value)) { - return false; - } - if (!TryAddOperator::Operation(result.value, time, result.value)) { - return false; - } - return Timestamp::IsFinite(result); -} - -string StrpTimeFormat::ParseResult::FormatError(string_t input, const string &format_specifier) { - return StringUtil::Format("Could not parse string \"%s\" according to format specifier \"%s\"\n%s\nError: %s", - input.GetString(), format_specifier, - FormatStrpTimeError(input.GetString(), error_position), error_message); -} - -bool StrpTimeFormat::TryParseDate(string_t input, date_t &result, string &error_message) const { - ParseResult parse_result; - if (!Parse(input, parse_result)) { - error_message = parse_result.FormatError(input, format_specifier); - return false; - } - return parse_result.TryToDate(result); -} - -bool StrpTimeFormat::TryParseDate(const char *data, size_t size, date_t &result) const { - ParseResult parse_result; - if (!Parse(data, size, parse_result)) { - return false; - } - return parse_result.TryToDate(result); -} - -bool StrpTimeFormat::TryParseTime(string_t input, dtime_t &result, string &error_message) const { - ParseResult parse_result; - if (!Parse(input, parse_result)) { - error_message = parse_result.FormatError(input, format_specifier); - return false; - } - return parse_result.TryToTime(result); -} - -bool StrpTimeFormat::TryParseTimestamp(string_t input, timestamp_t &result, string &error_message) const { - ParseResult parse_result; - if (!Parse(input, parse_result)) { - error_message = parse_result.FormatError(input, format_specifier); - return false; - } - return parse_result.TryToTimestamp(result); -} - -bool StrpTimeFormat::TryParseTimestamp(const char *data, size_t size, timestamp_t &result) const { - ParseResult parse_result; - if (!Parse(data, size, parse_result)) { - return false; - } - return parse_result.TryToTimestamp(result); -} - -bool StrpTimeFormat::TryParseTimestampNS(string_t input, timestamp_ns_t &result, string &error_message) const { - ParseResult parse_result; - if (!Parse(input, parse_result)) { - error_message = parse_result.FormatError(input, format_specifier); - return false; - } - return parse_result.TryToTimestampNS(result); -} - -bool StrpTimeFormat::TryParseTimestampNS(const char *data, size_t size, timestamp_ns_t &result) const { - ParseResult parse_result; - if (!Parse(data, size, parse_result)) { - return false; - } - return parse_result.TryToTimestampNS(result); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/caseconvert.cpp b/src/duckdb/src/function/scalar/string/caseconvert.cpp deleted file mode 100644 index d0d850f5a..000000000 --- a/src/duckdb/src/function/scalar/string/caseconvert.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include 
"duckdb/planner/expression/bound_function_expression.hpp" - -#include "utf8proc_wrapper.hpp" - -#include - -namespace duckdb { - -template -static string_t ASCIICaseConvert(Vector &result, const char *input_data, idx_t input_length) { - idx_t output_length = input_length; - auto result_str = StringVector::EmptyString(result, output_length); - auto result_data = result_str.GetDataWriteable(); - for (idx_t i = 0; i < input_length; i++) { - result_data[i] = UnsafeNumericCast(IS_UPPER ? StringUtil::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] - : StringUtil::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]); - } - result_str.Finalize(); - return result_str; -} - -template -static idx_t GetResultLength(const char *input_data, idx_t input_length) { - idx_t output_length = 0; - for (idx_t i = 0; i < input_length;) { - if (input_data[i] & 0x80) { - // unicode - int sz = 0; - auto codepoint = Utf8Proc::UTF8ToCodepoint(input_data + i, sz); - auto converted_codepoint = - IS_UPPER ? Utf8Proc::CodepointToUpper(codepoint) : Utf8Proc::CodepointToLower(codepoint); - auto new_sz = Utf8Proc::CodepointLength(converted_codepoint); - D_ASSERT(new_sz >= 0); - output_length += UnsafeNumericCast(new_sz); - i += UnsafeNumericCast(sz); - } else { - // ascii - output_length++; - i++; - } - } - return output_length; -} - -template -static void CaseConvert(const char *input_data, idx_t input_length, char *result_data) { - for (idx_t i = 0; i < input_length;) { - if (input_data[i] & 0x80) { - // non-ascii character - int sz = 0, new_sz = 0; - auto codepoint = Utf8Proc::UTF8ToCodepoint(input_data + i, sz); - auto converted_codepoint = - IS_UPPER ? Utf8Proc::CodepointToUpper(codepoint) : Utf8Proc::CodepointToLower(codepoint); - auto success = Utf8Proc::CodepointToUtf8(converted_codepoint, new_sz, result_data); - D_ASSERT(success); - (void)success; - result_data += new_sz; - i += UnsafeNumericCast(sz); - } else { - // ascii - *result_data = UnsafeNumericCast(IS_UPPER ? 
StringUtil::ASCII_TO_UPPER_MAP[uint8_t(input_data[i])] - : StringUtil::ASCII_TO_LOWER_MAP[uint8_t(input_data[i])]); - result_data++; - i++; - } - } -} - -idx_t LowerLength(const char *input_data, idx_t input_length) { - return GetResultLength(input_data, input_length); -} - -void LowerCase(const char *input_data, idx_t input_length, char *result_data) { - CaseConvert(input_data, input_length, result_data); -} - -template -static string_t UnicodeCaseConvert(Vector &result, const char *input_data, idx_t input_length) { - // first figure out the output length - idx_t output_length = GetResultLength(input_data, input_length); - auto result_str = StringVector::EmptyString(result, output_length); - auto result_data = result_str.GetDataWriteable(); - - CaseConvert(input_data, input_length, result_data); - result_str.Finalize(); - return result_str; -} - -template -struct CaseConvertOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - return UnicodeCaseConvert(result, input_data, input_length); - } -}; - -template -static void CaseConvertFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, args.size()); -} - -template -struct CaseConvertOperatorASCII { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - return ASCIICaseConvert(result, input_data, input_length); - } -}; - -template -static void CaseConvertFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString>(args.data[0], result, - args.size()); -} - -template -static unique_ptr CaseConvertPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = CaseConvertFunctionASCII; - } - return nullptr; -} - -ScalarFunction LowerFun::GetFunction() { - return ScalarFunction("lower", {LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction, nullptr, - nullptr, CaseConvertPropagateStats); -} - -ScalarFunction UpperFun::GetFunction() { - return ScalarFunction("upper", {LogicalType::VARCHAR}, LogicalType::VARCHAR, CaseConvertFunction, nullptr, - nullptr, CaseConvertPropagateStats); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/concat.cpp b/src/duckdb/src/function/scalar/string/concat.cpp deleted file mode 100644 index a6a495a95..000000000 --- a/src/duckdb/src/function/scalar/string/concat.cpp +++ /dev/null @@ -1,367 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/string_functions.hpp" - -#include "duckdb/planner/expression/bound_cast_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -#include - -namespace duckdb { - -struct ConcatFunctionData : public FunctionData { - ConcatFunctionData(const LogicalType &return_type_p, bool is_operator_p) - : return_type(return_type_p), is_operator(is_operator_p) { - } - ~ConcatFunctionData() override; - - 
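The deleted caseconvert.cpp above is deliberately two-pass: one UTF-8 walk to size the output (case mapping can change byte length), then a single allocation and a second walk to write. The same structure over raw utf8proc calls, as a sketch rather than DuckDB's Utf8Proc wrapper:

```cpp
#include <string>
#include <utf8proc.h>

// Two-pass UTF-8 uppercase: size first, then one allocation, then convert.
static std::string UpperUtf8(const std::string &in) {
	const auto *src = reinterpret_cast<const utf8proc_uint8_t *>(in.data());
	const auto len = static_cast<utf8proc_ssize_t>(in.size());
	utf8proc_ssize_t out_len = 0;
	for (utf8proc_ssize_t i = 0; i < len;) {
		utf8proc_int32_t cp;
		const auto sz = utf8proc_iterate(src + i, len - i, &cp);
		if (sz < 0) {
			return in; // invalid UTF-8: bail out in this sketch
		}
		utf8proc_uint8_t buf[4];
		out_len += utf8proc_encode_char(utf8proc_toupper(cp), buf);
		i += sz;
	}
	std::string out(static_cast<size_t>(out_len), '\0');
	auto *dst = reinterpret_cast<utf8proc_uint8_t *>(&out[0]);
	for (utf8proc_ssize_t i = 0; i < len;) {
		utf8proc_int32_t cp;
		const auto sz = utf8proc_iterate(src + i, len - i, &cp);
		dst += utf8proc_encode_char(utf8proc_toupper(cp), dst);
		i += sz;
	}
	return out;
}
```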
LogicalType return_type; - - bool is_operator = false; - -public: - bool Equals(const FunctionData &other_p) const override; - unique_ptr Copy() const override; -}; - -ConcatFunctionData::~ConcatFunctionData() { -} - -bool ConcatFunctionData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return return_type == other.return_type && is_operator == other.is_operator; -} - -unique_ptr ConcatFunctionData::Copy() const { - return make_uniq(return_type, is_operator); -} - -static void StringConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - // iterate over the vectors to count how large the final string will be - idx_t constant_lengths = 0; - vector result_lengths(args.size(), 0); - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto &input = args.data[col_idx]; - D_ASSERT(input.GetType().InternalType() == PhysicalType::VARCHAR); - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - if (ConstantVector::IsNull(input)) { - // constant null, skip - continue; - } - auto input_data = ConstantVector::GetData(input); - constant_lengths += input_data->GetSize(); - } else { - // non-constant vector: set the result type to a flat vector - result.SetVectorType(VectorType::FLAT_VECTOR); - // now get the lengths of each of the input elements - UnifiedVectorFormat vdata; - input.ToUnifiedFormat(args.size(), vdata); - - auto input_data = UnifiedVectorFormat::GetData(vdata); - // now add the length of each vector to the result length - for (idx_t i = 0; i < args.size(); i++) { - auto idx = vdata.sel->get_index(i); - if (!vdata.validity.RowIsValid(idx)) { - continue; - } - result_lengths[i] += input_data[idx].GetSize(); - } - } - } - - // first we allocate the empty strings for each of the values - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < args.size(); i++) { - // allocate an empty string of the required size - idx_t str_length = constant_lengths + result_lengths[i]; - result_data[i] = StringVector::EmptyString(result, str_length); - // we reuse the result_lengths vector to store the currently appended size - result_lengths[i] = 0; - } - - // now that the empty space for the strings has been allocated, perform the concatenation - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - auto &input = args.data[col_idx]; - - // loop over the vector and concat to all results - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - // constant vector - if (ConstantVector::IsNull(input)) { - // constant null, skip - continue; - } - // append the constant vector to each of the strings - auto input_data = ConstantVector::GetData(input); - auto input_ptr = input_data->GetData(); - auto input_len = input_data->GetSize(); - for (idx_t i = 0; i < args.size(); i++) { - memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len); - result_lengths[i] += input_len; - } - } else { - // standard vector - UnifiedVectorFormat idata; - input.ToUnifiedFormat(args.size(), idata); - - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < args.size(); i++) { - auto idx = idata.sel->get_index(i); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - auto input_ptr = input_data[idx].GetData(); - auto input_len = input_data[idx].GetSize(); - memcpy(result_data[i].GetDataWriteable() + result_lengths[i], input_ptr, input_len); - result_lengths[i] += input_len; - } - } - } - for (idx_t i = 0; i < 
args.size(); i++) { - result_data[i].Finalize(); - } -} - -static void ConcatOperator(DataChunk &args, ExpressionState &state, Vector &result) { - BinaryExecutor::Execute( - args.data[0], args.data[1], result, args.size(), [&](string_t a, string_t b) { - auto a_data = a.GetData(); - auto b_data = b.GetData(); - auto a_length = a.GetSize(); - auto b_length = b.GetSize(); - - auto target_length = a_length + b_length; - auto target = StringVector::EmptyString(result, target_length); - auto target_data = target.GetDataWriteable(); - - memcpy(target_data, a_data, a_length); - memcpy(target_data + a_length, b_data, b_length); - target.Finalize(); - return target; - }); -} - -struct ListConcatInputData { - ListConcatInputData(Vector &input, Vector &child_vec) : input(input), child_vec(child_vec) { - } - - UnifiedVectorFormat vdata; - Vector &input; - Vector &child_vec; - UnifiedVectorFormat child_vdata; - const list_entry_t *input_entries = nullptr; -}; - -static void ListConcatFunction(DataChunk &args, ExpressionState &state, Vector &result, bool is_operator) { - auto count = args.size(); - - auto result_entries = FlatVector::GetData(result); - vector input_data; - for (auto &input : args.data) { - if (!is_operator && input.GetType().id() == LogicalTypeId::SQLNULL) { - // LIST_CONCAT ignores NULL values - continue; - } - - auto &child_vec = ListVector::GetEntry(input); - ListConcatInputData data(input, child_vec); - input.ToUnifiedFormat(count, data.vdata); - - data.input_entries = UnifiedVectorFormat::GetData(data.vdata); - auto list_size = ListVector::GetListSize(input); - - child_vec.ToUnifiedFormat(list_size, data.child_vdata); - - input_data.push_back(std::move(data)); - } - - auto &result_validity = FlatVector::Validity(result); - idx_t offset = 0; - for (idx_t i = 0; i < count; i++) { - auto &result_entry = result_entries[i]; - result_entry.offset = offset; - result_entry.length = 0; - for (auto &data : input_data) { - auto list_index = data.vdata.sel->get_index(i); - if (!data.vdata.validity.RowIsValid(list_index)) { - // LIST_CONCAT ignores NULL values, but || does not - if (is_operator) { - result_validity.SetInvalid(i); - } - continue; - } - const auto &list_entry = data.input_entries[list_index]; - result_entry.length += list_entry.length; - ListVector::Append(result, data.child_vec, *data.child_vdata.sel, list_entry.offset + list_entry.length, - list_entry.offset); - } - offset += result_entry.length; - } - ListVector::SetListSize(result, offset); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void ConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - if (info.return_type.id() == LogicalTypeId::LIST) { - return ListConcatFunction(args, state, result, info.is_operator); - } else if (info.is_operator) { - return ConcatOperator(args, state, result); - } - return StringConcatFunction(args, state, result); -} - -static void SetArgumentType(ScalarFunction &bound_function, const LogicalType &type, bool is_operator) { - if (is_operator) { - bound_function.arguments[0] = type; - bound_function.arguments[1] = type; - bound_function.return_type = type; - return; - } - - for (auto &arg : bound_function.arguments) { - arg = type; - } - bound_function.varargs = type; - bound_function.return_type = type; -} - -static unique_ptr BindListConcat(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, bool is_operator) { - 
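StringConcatFunction above follows the same discipline: sum per-row lengths first, allocate each result exactly once, then memcpy column by column while reusing the length array as a write cursor. Reduced to std::string for illustration (the container shapes are assumptions of the sketch):

```cpp
#include <cstring>
#include <string>
#include <vector>

// columns[c][r] is the value of column c in row r; every column has row_count entries.
static std::vector<std::string> ConcatRows(const std::vector<std::vector<std::string>> &columns, size_t row_count) {
	std::vector<size_t> lengths(row_count, 0);
	for (const auto &col : columns) { // pass 1: total length per row
		for (size_t i = 0; i < row_count; i++) {
			lengths[i] += col[i].size();
		}
	}
	std::vector<std::string> result(row_count);
	for (size_t i = 0; i < row_count; i++) {
		result[i].resize(lengths[i]); // one allocation per row
		lengths[i] = 0;               // reuse as the current write offset
	}
	for (const auto &col : columns) { // pass 2: copy column by column
		for (size_t i = 0; i < row_count; i++) {
			memcpy(&result[i][0] + lengths[i], col[i].data(), col[i].size());
			lengths[i] += col[i].size();
		}
	}
	return result;
}
```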
LogicalType child_type = LogicalType::SQLNULL; - bool all_null = true; - for (auto &arg : arguments) { - auto &return_type = arg->return_type; - if (return_type == LogicalTypeId::SQLNULL) { - // we mimic postgres behaviour: list_concat(NULL, my_list) = my_list - continue; - } - all_null = false; - LogicalType next_type = LogicalTypeId::INVALID; - switch (return_type.id()) { - case LogicalTypeId::UNKNOWN: - throw ParameterNotResolvedException(); - case LogicalTypeId::LIST: - next_type = ListType::GetChildType(return_type); - break; - case LogicalTypeId::ARRAY: - next_type = ArrayType::GetChildType(return_type); - break; - default: { - string type_list; - for (idx_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) { - if (!type_list.empty()) { - if (arg_idx + 1 == arguments.size()) { - // last argument - type_list += " and "; - } else { - type_list += ", "; - } - } - type_list += arguments[arg_idx]->return_type.ToString(); - } - throw BinderException(*arg, "Cannot concatenate types %s - an explicit cast is required", type_list); - } - } - if (!LogicalType::TryGetMaxLogicalType(context, child_type, next_type, child_type)) { - throw BinderException(*arg, - "Cannot concatenate lists of types %s[] and %s[] - an explicit cast is required", - child_type.ToString(), next_type.ToString()); - } - } - if (all_null) { - // all arguments are NULL - SetArgumentType(bound_function, LogicalTypeId::SQLNULL, is_operator); - return make_uniq(bound_function.return_type, is_operator); - } - auto list_type = LogicalType::LIST(child_type); - - SetArgumentType(bound_function, list_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); -} - -static unique_ptr BindConcatFunctionInternal(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, - bool is_operator) { - bool list_concat = false; - // blob concat is only supported for the concat operator - regular concat converts to varchar - bool all_blob = is_operator ? true : false; - for (auto &arg : arguments) { - if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - if (arg->return_type.id() == LogicalTypeId::LIST || arg->return_type.id() == LogicalTypeId::ARRAY) { - list_concat = true; - } - if (arg->return_type.id() != LogicalTypeId::BLOB) { - all_blob = false; - } - } - if (list_concat) { - return BindListConcat(context, bound_function, arguments, is_operator); - } - auto return_type = all_blob ? LogicalType::BLOB : LogicalType::VARCHAR; - - // we can now assume that the input is a string or castable to a string - SetArgumentType(bound_function, return_type, is_operator); - return make_uniq(bound_function.return_type, is_operator); -} - -static unique_ptr BindConcatFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return BindConcatFunctionInternal(context, bound_function, arguments, false); -} - -static unique_ptr BindConcatOperator(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return BindConcatFunctionInternal(context, bound_function, arguments, true); -} - -static unique_ptr ListConcatStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto stats = child_stats[0].ToUnique(); - for (idx_t i = 1; i < child_stats.size(); i++) { - stats->Merge(child_stats[i]); - } - return stats; -} - -ScalarFunction ListConcatFun::GetFunction() { - // The arguments and return types are set in the binder function. 
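The bind logic above (BLOB stays BLOB only under the operator, everything else converges on VARCHAR) is easy to observe from DuckDB's embedded C++ API. A sketch, assuming a standard DuckDB build with `duckdb.hpp` available:

```cpp
#include "duckdb.hpp"

int main() {
	duckdb::DuckDB db(nullptr);
	duckdb::Connection con(db);
	// expected per the binder: op_type = BLOB, fn_type = VARCHAR
	auto res = con.Query("SELECT typeof('\\xAA'::BLOB || '\\xBB'::BLOB) AS op_type, "
	                     "       typeof(concat('\\xAA'::BLOB, '\\xBB'::BLOB)) AS fn_type");
	res->Print();
	return 0;
}
```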
- auto fun = ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::LIST(LogicalType::ANY)}, - LogicalType::LIST(LogicalType::ANY), ConcatFunction, BindConcatFunction, nullptr, - ListConcatStats); - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -// the concat operator and concat function have different behavior regarding NULLs -// this is strange but seems consistent with postgresql and mysql -// (sqlite does not support the concat function, only the concat operator) - -// the concat operator behaves as one would expect: any NULL value present results in a NULL -// i.e. NULL || 'hello' = NULL -// the concat function, however, treats NULL values as an empty string -// i.e. concat(NULL, 'hello') = 'hello' -ScalarFunction ConcatFun::GetFunction() { - ScalarFunction concat = - ScalarFunction("concat", {LogicalType::ANY}, LogicalType::ANY, ConcatFunction, BindConcatFunction); - concat.varargs = LogicalType::ANY; - concat.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return concat; -} - -ScalarFunction ConcatOperatorFun::GetFunction() { - ScalarFunction concat_op = ScalarFunction("||", {LogicalType::ANY, LogicalType::ANY}, LogicalType::ANY, - ConcatFunction, BindConcatOperator); - return concat_op; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/concat_ws.cpp b/src/duckdb/src/function/scalar/string/concat_ws.cpp deleted file mode 100644 index ebc1e8b3a..000000000 --- a/src/duckdb/src/function/scalar/string/concat_ws.cpp +++ /dev/null @@ -1,149 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" - -#include - -namespace duckdb { - -static void TemplatedConcatWS(DataChunk &args, const string_t *sep_data, const SelectionVector &sep_sel, - const SelectionVector &rsel, idx_t count, Vector &result) { - vector result_lengths(args.size(), 0); - vector has_results(args.size(), false); - - // we overallocate here, but this is important for static analysis - auto orrified_data = make_unsafe_uniq_array_uninitialized(args.ColumnCount()); - - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - args.data[col_idx].ToUnifiedFormat(args.size(), orrified_data[col_idx - 1]); - } - - // first figure out the lengths - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &idata = orrified_data[col_idx - 1]; - - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - auto sep_idx = sep_sel.get_index(ridx); - auto idx = idata.sel->get_index(ridx); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - if (has_results[ridx]) { - result_lengths[ridx] += sep_data[sep_idx].GetSize(); - } - result_lengths[ridx] += input_data[idx].GetSize(); - has_results[ridx] = true; - } - } - - // first we allocate the empty strings for each of the values - auto result_data = FlatVector::GetData(result); - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - // allocate an empty string of the required size - result_data[ridx] = StringVector::EmptyString(result, result_lengths[ridx]); - // we reuse the result_lengths vector to store the currently appended size - result_lengths[ridx] = 0; - has_results[ridx] = false; - } - - // now that the empty space for the strings has been allocated, perform the concatenation - for (idx_t col_idx = 1; col_idx < args.ColumnCount(); col_idx++) { - auto &idata = orrified_data[col_idx - 1]; - auto input_data = UnifiedVectorFormat::GetData(idata); - for (idx_t i = 0; i < count; i++) 
{ - auto ridx = rsel.get_index(i); - auto sep_idx = sep_sel.get_index(ridx); - auto idx = idata.sel->get_index(ridx); - if (!idata.validity.RowIsValid(idx)) { - continue; - } - if (has_results[ridx]) { - auto sep_size = sep_data[sep_idx].GetSize(); - auto sep_ptr = sep_data[sep_idx].GetData(); - memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], sep_ptr, sep_size); - result_lengths[ridx] += sep_size; - } - auto input_ptr = input_data[idx].GetData(); - auto input_len = input_data[idx].GetSize(); - memcpy(result_data[ridx].GetDataWriteable() + result_lengths[ridx], input_ptr, input_len); - result_lengths[ridx] += input_len; - has_results[ridx] = true; - } - } - for (idx_t i = 0; i < count; i++) { - auto ridx = rsel.get_index(i); - result_data[ridx].Finalize(); - } -} - -static void ConcatWSFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &separator = args.data[0]; - UnifiedVectorFormat vdata; - separator.ToUnifiedFormat(args.size(), vdata); - - result.SetVectorType(VectorType::CONSTANT_VECTOR); - for (idx_t col_idx = 0; col_idx < args.ColumnCount(); col_idx++) { - if (args.data[col_idx].GetVectorType() != VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::FLAT_VECTOR); - break; - } - } - switch (separator.GetVectorType()) { - case VectorType::CONSTANT_VECTOR: { - if (ConstantVector::IsNull(separator)) { - // constant NULL as separator: return constant NULL vector - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::SetNull(result, true); - return; - } - // no null values - auto sel = FlatVector::IncrementalSelectionVector(); - TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, *sel, args.size(), result); - return; - } - default: { - // default case: loop over nullmask and create a non-null selection vector - idx_t not_null_count = 0; - SelectionVector not_null_vector(STANDARD_VECTOR_SIZE); - auto &result_mask = FlatVector::Validity(result); - for (idx_t i = 0; i < args.size(); i++) { - if (!vdata.validity.RowIsValid(vdata.sel->get_index(i))) { - result_mask.SetInvalid(i); - } else { - not_null_vector.set_index(not_null_count++, i); - } - } - TemplatedConcatWS(args, UnifiedVectorFormat::GetData(vdata), *vdata.sel, not_null_vector, - not_null_count, result); - return; - } - } -} - -static unique_ptr BindConcatWSFunction(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - for (auto &arg : bound_function.arguments) { - arg = LogicalType::VARCHAR; - } - bound_function.varargs = LogicalType::VARCHAR; - return nullptr; -} - -ScalarFunction ConcatWsFun::GetFunction() { - // concat_ws functions similarly to the concat function, except the result is NULL if the separator is NULL - // if the separator is not NULL, however, NULL values are counted as empty string - // there is one separate rule: there are no separators added between NULL values, - // so the NULL value and empty string are different! 
- // e.g.: - // concat_ws(',', NULL, NULL) = "" - // concat_ws(',', '', '') = "," - - ScalarFunction concat_ws = ScalarFunction("concat_ws", {LogicalType::VARCHAR, LogicalType::ANY}, - LogicalType::VARCHAR, ConcatWSFunction, BindConcatWSFunction); - concat_ws.varargs = LogicalType::ANY; - concat_ws.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return ScalarFunction(concat_ws); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/contains.cpp b/src/duckdb/src/function/scalar/string/contains.cpp deleted file mode 100644 index b34d62c78..000000000 --- a/src/duckdb/src/function/scalar/string/contains.cpp +++ /dev/null @@ -1,173 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/scalar/list_functions.hpp" -#include "duckdb/function/scalar/map_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -template -static idx_t ContainsUnaligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t base_offset) { - if (NEEDLE_SIZE > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // contains for a small unaligned needle (3/5/6/7 bytes) - // we perform unsigned integer comparisons to check for equality of the entire needle in a single comparison - // this implementation is inspired by the memmem implementation of freebsd - - // first we set up the needle and the first NEEDLE_SIZE characters of the haystack as UNSIGNED integers - UNSIGNED needle_entry = 0; - UNSIGNED haystack_entry = 0; - const UNSIGNED start = (sizeof(UNSIGNED) * 8) - 8; - const UNSIGNED shift = (sizeof(UNSIGNED) - NEEDLE_SIZE) * 8; - for (idx_t i = 0; i < NEEDLE_SIZE; i++) { - needle_entry |= UNSIGNED(needle[i]) << UNSIGNED(start - i * 8); - haystack_entry |= UNSIGNED(haystack[i]) << UNSIGNED(start - i * 8); - } - // now we perform the actual search - for (idx_t offset = NEEDLE_SIZE; offset < haystack_size; offset++) { - // for this position we first compare the haystack with the needle - if (haystack_entry == needle_entry) { - return base_offset + offset - NEEDLE_SIZE; - } - // now we adjust the haystack entry by - // (1) removing the left-most character (shift by 8) - // (2) adding the next character (bitwise or, with potential shift) - // this shift is only necessary if the needle size is not aligned with the unsigned integer size - // (e.g. 
needle size 3, unsigned integer size 4, we need to shift by 1) - haystack_entry = (haystack_entry << 8) | ((UNSIGNED(haystack[offset])) << shift); - } - if (haystack_entry == needle_entry) { - return base_offset + haystack_size - NEEDLE_SIZE; - } - return DConstants::INVALID_INDEX; -} - -template -static idx_t ContainsAligned(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t base_offset) { - if (sizeof(UNSIGNED) > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // contains for a small needle aligned with unsigned integer (2/4/8) - // similar to ContainsUnaligned, but simpler because we only need to do a reinterpret cast - auto needle_entry = Load(needle); - for (idx_t offset = 0; offset <= haystack_size - sizeof(UNSIGNED); offset++) { - // for this position we first compare the haystack with the needle - auto haystack_entry = Load(haystack + offset); - if (needle_entry == haystack_entry) { - return base_offset + offset; - } - } - return DConstants::INVALID_INDEX; -} - -idx_t ContainsGeneric(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, - idx_t needle_size, idx_t base_offset) { - if (needle_size > haystack_size) { - // needle is bigger than haystack: haystack cannot contain needle - return DConstants::INVALID_INDEX; - } - // this implementation is inspired by Raphael Javaux's faststrstr (https://github.com/RaphaelJ/fast_strstr) - // generic contains; note that we can't use strstr because we don't have null-terminated strings anymore - // we keep track of a shifting window sum of all characters with window size equal to needle_size - // this shifting sum is used to avoid calling into memcmp; - // we only need to call into memcmp when the window sum is equal to the needle sum - // when that happens, the characters are potentially the same and we call into memcmp to check if they are - uint32_t sums_diff = 0; - for (idx_t i = 0; i < needle_size; i++) { - sums_diff += haystack[i]; - sums_diff -= needle[i]; - } - idx_t offset = 0; - while (true) { - if (sums_diff == 0 && haystack[offset] == needle[0]) { - if (memcmp(haystack + offset, needle, needle_size) == 0) { - return base_offset + offset; - } - } - if (offset >= haystack_size - needle_size) { - return DConstants::INVALID_INDEX; - } - sums_diff -= haystack[offset]; - sums_diff += haystack[offset + needle_size]; - offset++; - } -} - -idx_t FindStrInStr(const unsigned char *haystack, idx_t haystack_size, const unsigned char *needle, idx_t needle_size) { - D_ASSERT(needle_size > 0); - // start off by performing a memchr to find the first character of the - auto location = memchr(haystack, needle[0], haystack_size); - if (location == nullptr) { - return DConstants::INVALID_INDEX; - } - idx_t base_offset = UnsafeNumericCast(const_uchar_ptr_cast(location) - haystack); - haystack_size -= base_offset; - haystack = const_uchar_ptr_cast(location); - // switch algorithm depending on needle size - switch (needle_size) { - case 1: - return base_offset; - case 2: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - case 3: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 4: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - case 5: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 6: - return ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 7: - return 
ContainsUnaligned(haystack, haystack_size, needle, base_offset); - case 8: - return ContainsAligned(haystack, haystack_size, needle, base_offset); - default: - return ContainsGeneric(haystack, haystack_size, needle, needle_size, base_offset); - } -} - -idx_t FindStrInStr(const string_t &haystack_s, const string_t &needle_s) { - auto haystack = const_uchar_ptr_cast(haystack_s.GetData()); - auto haystack_size = haystack_s.GetSize(); - auto needle = const_uchar_ptr_cast(needle_s.GetData()); - auto needle_size = needle_s.GetSize(); - if (needle_size == 0) { - // empty needle: always true - return 0; - } - return FindStrInStr(haystack, haystack_size, needle, needle_size); -} - -struct ContainsOperator { - template - static inline TR Operation(TA left, TB right) { - return FindStrInStr(left, right) != DConstants::INVALID_INDEX; - } -}; - -ScalarFunction GetStringContains() { - ScalarFunction string_fun("contains", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction); - string_fun.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return string_fun; -} - -ScalarFunctionSet ContainsFun::GetFunctions() { - auto string_fun = GetStringContains(); - auto list_fun = ListContainsFun::GetFunction(); - auto map_fun = MapContainsFun::GetFunction(); - ScalarFunctionSet set("contains"); - set.AddFunction(string_fun); - set.AddFunction(list_fun); - set.AddFunction(map_fun); - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/length.cpp b/src/duckdb/src/function/scalar/string/length.cpp deleted file mode 100644 index 538646419..000000000 --- a/src/duckdb/src/function/scalar/string/length.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "utf8proc.hpp" - -namespace duckdb { - -// length returns the number of unicode codepoints -struct StringLengthOperator { - template - static inline TR Operation(TA input) { - return Length(input); - } -}; - -struct GraphemeCountOperator { - template - static inline TR Operation(TA input) { - return GraphemeCount(input); - } -}; - -// strlen returns the size in bytes -struct StrLenOperator { - template - static inline TR Operation(TA input) { - return UnsafeNumericCast(input.GetSize()); - } -}; - -struct OctetLenOperator { - template - static inline TR Operation(TA input) { - return UnsafeNumericCast(Bit::OctetLength(input)); - } -}; - -// bitlen returns the size in bits -struct BitLenOperator { - template - static inline TR Operation(TA input) { - return UnsafeNumericCast(8 * input.GetSize()); - } -}; - -// bitstringlen returns the amount of bits in a bitstring -struct BitStringLenOperator { - template - static inline TR Operation(TA input) { - return UnsafeNumericCast(Bit::BitLength(input)); - } -}; - -static unique_ptr LengthPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() == 1); - // can only propagate stats if the children have stats - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::UnaryFunction; - } - 
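ContainsGeneric above is a sum-filtered scan: it maintains the difference between the window's byte sum and the needle's byte sum, and only pays for a memcmp when that difference hits zero. The same idea, freestanding:

```cpp
#include <cstdint>
#include <cstring>

// Returns the match offset, or SIZE_MAX when absent.
// Precondition (as in the original): 0 < needle_size <= haystack_size.
static size_t SumFilteredFind(const unsigned char *haystack, size_t haystack_size,
                              const unsigned char *needle, size_t needle_size) {
	uint32_t sums_diff = 0; // window sum minus needle sum, wrapping mod 2^32
	for (size_t i = 0; i < needle_size; i++) {
		sums_diff += haystack[i];
		sums_diff -= needle[i];
	}
	size_t offset = 0;
	while (true) {
		// cheap filter first; memcmp only on a sum collision
		if (sums_diff == 0 && haystack[offset] == needle[0] &&
		    memcmp(haystack + offset, needle, needle_size) == 0) {
			return offset;
		}
		if (offset + needle_size >= haystack_size) {
			return SIZE_MAX;
		}
		sums_diff -= haystack[offset];               // drop the leftmost byte
		sums_diff += haystack[offset + needle_size]; // add the next byte
		offset++;
	}
}
```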
return nullptr; -} - -//------------------------------------------------------------------ -// ARRAY / LIST LENGTH -//------------------------------------------------------------------ -static void ListLengthFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - D_ASSERT(input.GetType().id() == LogicalTypeId::LIST); - UnaryExecutor::Execute( - input, result, args.size(), [](list_entry_t input) { return UnsafeNumericCast(input.length); }); - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static void ArrayLengthFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnifiedVectorFormat format; - args.data[0].ToUnifiedFormat(args.size(), format); - - // for arrays the length is constant - result.SetVectorType(VectorType::CONSTANT_VECTOR); - ConstantVector::GetData(result)[0] = static_cast(ArrayType::GetSize(input.GetType())); - - // but we do need to take null values into account - if (format.validity.AllValid()) { - // if there are no null values we can just return the constant - return; - } - // otherwise we flatten and inherit the null values of the parent - result.Flatten(args.size()); - auto &result_validity = FlatVector::Validity(result); - for (idx_t r = 0; r < args.size(); r++) { - auto idx = format.sel->get_index(r); - if (!format.validity.RowIsValid(idx)) { - result_validity.SetInvalid(r); - } - } - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ArrayOrListLengthBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->HasParameter() || arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - const auto &arg_type = arguments[0]->return_type.id(); - if (arg_type == LogicalTypeId::ARRAY) { - bound_function.function = ArrayLengthFunction; - } else if (arg_type == LogicalTypeId::LIST) { - bound_function.function = ListLengthFunction; - } else { - // Unreachable - throw BinderException("length can only be used on arrays or lists"); - } - bound_function.arguments[0] = arguments[0]->return_type; - return nullptr; -} - -//------------------------------------------------------------------ -// ARRAY / LIST WITH DIMENSION -//------------------------------------------------------------------ -static void ListLengthBinaryFunction(DataChunk &args, ExpressionState &, Vector &result) { - auto type = args.data[0].GetType(); - auto &input = args.data[0]; - auto &dimension = args.data[1]; - BinaryExecutor::Execute( - input, dimension, result, args.size(), [](list_entry_t input, int64_t dimension) { - if (dimension != 1) { - throw NotImplementedException("array_length for lists with dimensions other than 1 not implemented"); - } - return UnsafeNumericCast(input.length); - }); - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -struct ArrayLengthBinaryFunctionData : public FunctionData { - vector dimensions; - - unique_ptr Copy() const override { - auto copy = make_uniq(); - copy->dimensions = dimensions; - return std::move(copy); - } - - bool Equals(const FunctionData &other) const override { - auto &other_data = other.Cast(); - return dimensions == other_data.dimensions; - } -}; - -static void ArrayLengthBinaryFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto type = args.data[0].GetType(); - auto &dimension = args.data[1]; - - auto &expr = 
state.expr.Cast(); - auto &data = expr.bind_info->Cast(); - auto &dimensions = data.dimensions; - auto max_dimension = static_cast(dimensions.size()); - - UnaryExecutor::Execute(dimension, result, args.size(), [&](int64_t dimension) { - if (dimension < 1 || dimension > max_dimension) { - throw OutOfRangeException(StringUtil::Format( - "array_length dimension '%lld' out of range (min: '1', max: '%lld')", dimension, max_dimension)); - } - return dimensions[UnsafeNumericCast(dimension - 1)]; - }); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -static unique_ptr ArrayOrListLengthBinaryBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - if (arguments[0]->HasParameter() || arguments[0]->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - auto type = arguments[0]->return_type; - if (type.id() == LogicalTypeId::ARRAY) { - bound_function.arguments[0] = type; - bound_function.function = ArrayLengthBinaryFunction; - - // If the input is an array, the dimensions are constant, so we can calculate them at bind time - vector dimensions; - while (true) { - if (type.id() == LogicalTypeId::ARRAY) { - dimensions.push_back(UnsafeNumericCast(ArrayType::GetSize(type))); - type = ArrayType::GetChildType(type); - } else { - break; - } - } - auto data = make_uniq(); - data->dimensions = dimensions; - return std::move(data); - - } else if (type.id() == LogicalTypeId::LIST) { - bound_function.function = ListLengthBinaryFunction; - bound_function.arguments[0] = type; - return nullptr; - } else { - // Unreachable - throw BinderException("array_length can only be used on arrays or lists"); - } -} - -ScalarFunctionSet LengthFun::GetFunctions() { - ScalarFunctionSet length("length"); - length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction, nullptr, - nullptr, LengthPropagateStats)); - length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - length.AddFunction( - ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::BIGINT, nullptr, ArrayOrListLengthBind)); - return (length); -} - -ScalarFunctionSet LengthGraphemeFun::GetFunctions() { - ScalarFunctionSet length_grapheme("length_grapheme"); - length_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction, - nullptr, nullptr, LengthPropagateStats)); - return (length_grapheme); -} - -ScalarFunctionSet ArrayLengthFun::GetFunctions() { - ScalarFunctionSet array_length("array_length"); - array_length.AddFunction( - ScalarFunction({LogicalType::LIST(LogicalType::ANY)}, LogicalType::BIGINT, nullptr, ArrayOrListLengthBind)); - array_length.AddFunction(ScalarFunction({LogicalType::LIST(LogicalType::ANY), LogicalType::BIGINT}, - LogicalType::BIGINT, nullptr, ArrayOrListLengthBinaryBind)); - for (auto &func : array_length.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return (array_length); -} - -ScalarFunction StrlenFun::GetFunction() { - return ScalarFunction("strlen", {LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction); -} - -ScalarFunctionSet BitLengthFun::GetFunctions() { - ScalarFunctionSet bit_length("bit_length"); - bit_length.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::BIGINT, - ScalarFunction::UnaryFunction)); - bit_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT, - 
ScalarFunction::UnaryFunction<string_t, int64_t, BitStringLenOperator>));
-	return (bit_length);
-}
-
-ScalarFunctionSet OctetLengthFun::GetFunctions() {
-	// length for BLOB type
-	ScalarFunctionSet octet_length("octet_length");
-	octet_length.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::BIGINT,
-	                                        ScalarFunction::UnaryFunction<string_t, int64_t, StrLenOperator>));
-	octet_length.AddFunction(ScalarFunction({LogicalType::BIT}, LogicalType::BIGINT,
-	                                        ScalarFunction::UnaryFunction<string_t, int64_t, OctetLenOperator>));
-	return (octet_length);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/scalar/string/like.cpp b/src/duckdb/src/function/scalar/string/like.cpp
deleted file mode 100644
index 5e279cf28..000000000
--- a/src/duckdb/src/function/scalar/string/like.cpp
+++ /dev/null
@@ -1,587 +0,0 @@
-#include "duckdb/common/exception.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/common/vector_operations/vector_operations.hpp"
-#include "duckdb/function/scalar/string_common.hpp"
-#include "duckdb/function/scalar/string_functions.hpp"
-#include "duckdb/planner/expression/bound_function_expression.hpp"
-
-#include "duckdb/execution/expression_executor.hpp"
-
-namespace duckdb {
-
-struct StandardCharacterReader {
-	static void NextCharacter(const char *sdata, idx_t slen, idx_t &sidx) {
-		sidx++;
-		while (sidx < slen && !IsCharacter(sdata[sidx])) {
-			sidx++;
-		}
-	}
-
-	static char Operation(const char *data, idx_t pos) {
-		return data[pos];
-	}
-};
-
-struct ASCIILCaseReader {
-	static void NextCharacter(const char *sdata, idx_t slen, idx_t &sidx) {
-		sidx++;
-	}
-
-	static char Operation(const char *data, idx_t pos) {
-		return (char)StringUtil::ASCII_TO_LOWER_MAP[(uint8_t)data[pos]];
-	}
-};
-
-template <char PERCENTAGE, char UNDERSCORE, bool HAS_ESCAPE, class READER = StandardCharacterReader>
-bool TemplatedLikeOperator(const char *sdata, idx_t slen, const char *pdata, idx_t plen, char escape) {
-	idx_t pidx = 0;
-	idx_t sidx = 0;
-	for (; pidx < plen && sidx < slen; pidx++) {
-		char pchar = READER::Operation(pdata, pidx);
-		char schar = READER::Operation(sdata, sidx);
-		if (HAS_ESCAPE && pchar == escape) {
-			pidx++;
-			if (pidx == plen) {
-				throw SyntaxException("Like pattern must not end with escape character!");
-			}
-			if (pdata[pidx] != schar) {
-				return false;
-			}
-			sidx++;
-		} else if (pchar == UNDERSCORE) {
-			READER::NextCharacter(sdata, slen, sidx);
-		} else if (pchar == PERCENTAGE) {
-			pidx++;
-			while (pidx < plen && pdata[pidx] == PERCENTAGE) {
-				pidx++;
-			}
-			if (pidx == plen) {
-				return true; /* tail is acceptable */
-			}
-			for (; sidx < slen; sidx++) {
-				if (TemplatedLikeOperator<PERCENTAGE, UNDERSCORE, HAS_ESCAPE, READER>(
-				        sdata + sidx, slen - sidx, pdata + pidx, plen - pidx, escape)) {
-					return true;
-				}
-			}
-			return false;
-		} else if (pchar == schar) {
-			sidx++;
-		} else {
-			return false;
-		}
-	}
-	while (pidx < plen && pdata[pidx] == PERCENTAGE) {
-		pidx++;
-	}
-	return pidx == plen && sidx == slen;
-}
-
-struct LikeSegment {
-	explicit LikeSegment(string pattern) : pattern(std::move(pattern)) {
-	}
-
-	string pattern;
-};
-
-struct LikeMatcher : public FunctionData {
-	LikeMatcher(string like_pattern_p, vector<LikeSegment> segments, bool has_start_percentage,
-	            bool has_end_percentage)
-	    : like_pattern(std::move(like_pattern_p)), segments(std::move(segments)),
-	      has_start_percentage(has_start_percentage), has_end_percentage(has_end_percentage) {
-	}
-
-	bool Match(string_t &str) {
-		auto str_data = const_uchar_ptr_cast(str.GetData());
-		auto str_len = str.GetSize();
-		idx_t segment_idx = 0;
-		idx_t end_idx = segments.size() - 1;
-		if (!has_start_percentage) {
-			// no start percentage: match the first part of the string directly
-			auto &segment = segments[0];
-			if (str_len < segment.pattern.size()) {
-				return false;
-			}
-			if (memcmp(str_data, segment.pattern.c_str(), segment.pattern.size()) != 0) {
-				return false;
-			}
-			str_data += segment.pattern.size();
-			str_len -= segment.pattern.size();
-			segment_idx++;
-			if (segments.size() == 1) {
-				// only one segment, and it matches
-				// we have a match if there is an end percentage, OR if the memcmp was an exact match (remaining str is
-				// empty)
-				return has_end_percentage || str_len == 0;
-			}
-		}
-		// main match loop: for every segment in the middle, use Contains to find the needle in the haystack
-		for (; segment_idx < end_idx; segment_idx++) {
-			auto &segment = segments[segment_idx];
-			// find the pattern of the current segment
-			idx_t next_offset =
-			    FindStrInStr(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), segment.pattern.size());
-			if (next_offset == DConstants::INVALID_INDEX) {
-				// could not find this pattern in the string: no match
-				return false;
-			}
-			idx_t offset = next_offset + segment.pattern.size();
-			str_data += offset;
-			str_len -= offset;
-		}
-		if (!has_end_percentage) {
-			end_idx--;
-			// no end percentage: match the final segment now
-			auto &segment = segments.back();
-			if (str_len < segment.pattern.size()) {
-				return false;
-			}
-			if (memcmp(str_data + str_len - segment.pattern.size(), segment.pattern.c_str(), segment.pattern.size()) !=
-			    0) {
-				return false;
-			}
-			return true;
-		} else {
-			auto &segment = segments.back();
-			// find the pattern of the current segment
-			idx_t next_offset =
-			    FindStrInStr(str_data, str_len, const_uchar_ptr_cast(segment.pattern.c_str()), segment.pattern.size());
-			return next_offset != DConstants::INVALID_INDEX;
-		}
-	}
-
-	static unique_ptr<LikeMatcher> CreateLikeMatcher(string like_pattern, char escape = '\0') {
-		vector<LikeSegment> segments;
-		idx_t last_non_pattern = 0;
-		bool has_start_percentage = false;
-		bool has_end_percentage = false;
-		for (idx_t i = 0; i < like_pattern.size(); i++) {
-			auto ch = like_pattern[i];
-			if (ch == escape || ch == '%' || ch == '_') {
-				// special character, push a constant pattern
-				if (i > last_non_pattern) {
-					segments.emplace_back(like_pattern.substr(last_non_pattern, i - last_non_pattern));
-				}
-				last_non_pattern = i + 1;
-				if (ch == escape || ch == '_') {
-					// escape or underscore: could not create efficient like matcher
-					// FIXME: we could handle escaped percentages here
-					return nullptr;
-				} else {
-					// percentage
-					if (i == 0) {
-						has_start_percentage = true;
-					}
-					if (i + 1 == like_pattern.size()) {
-						has_end_percentage = true;
-					}
-				}
-			}
-		}
-		if (last_non_pattern < like_pattern.size()) {
-			segments.emplace_back(like_pattern.substr(last_non_pattern, like_pattern.size() - last_non_pattern));
-		}
-		if (segments.empty()) {
-			return nullptr;
-		}
-		return make_uniq<LikeMatcher>(std::move(like_pattern), std::move(segments), has_start_percentage,
-		                              has_end_percentage);
-	}
-
-	unique_ptr<FunctionData> Copy() const override {
-		return make_uniq<LikeMatcher>(like_pattern, segments, has_start_percentage, has_end_percentage);
-	}
-
-	bool Equals(const FunctionData &other_p) const override {
-		auto &other = other_p.Cast<LikeMatcher>();
-		return like_pattern == other.like_pattern;
-	}
-
-private:
-	string like_pattern;
-	vector<LikeSegment> segments;
-	bool has_start_percentage;
-	bool has_end_percentage;
-};
-
-static unique_ptr<FunctionData> LikeBindFunction(ClientContext &context, ScalarFunction &bound_function,
-                                                 vector<unique_ptr<Expression>> &arguments) {
-	// pattern is the second argument. If it's constant, we can already prepare the pattern and store it for later.
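The matcher above turns a pure-`%` pattern into literal segments located left to right with substring search, with the first and last segments optionally anchored. An intuition-level std::string rendition (no `_` or escapes, which is exactly when CreateLikeMatcher succeeds; empty-pattern edge cases are ignored):

```cpp
#include <string>
#include <vector>

static bool SegmentLikeMatch(const std::string &str, const std::string &pattern) {
	// split the pattern on '%' into literal segments
	std::vector<std::string> segments;
	size_t last = 0;
	for (size_t i = 0; i <= pattern.size(); i++) {
		if (i == pattern.size() || pattern[i] == '%') {
			if (i > last) {
				segments.push_back(pattern.substr(last, i - last));
			}
			last = i + 1;
		}
	}
	const bool anchored_start = !pattern.empty() && pattern.front() != '%';
	const bool anchored_end = !pattern.empty() && pattern.back() != '%';
	if (segments.empty()) {
		return !anchored_start; // pattern consisted only of '%'
	}
	size_t pos = 0;
	if (anchored_start) { // first segment must be a prefix
		if (str.compare(0, segments.front().size(), segments.front()) != 0) {
			return false;
		}
		pos = segments.front().size();
	}
	const size_t begin = anchored_start ? 1 : 0;
	const size_t end = anchored_end ? segments.size() - 1 : segments.size();
	for (size_t s = begin; s < end; s++) { // middle segments: leftmost search
		const size_t found = str.find(segments[s], pos);
		if (found == std::string::npos) {
			return false;
		}
		pos = found + segments[s].size();
	}
	if (anchored_end) { // last segment must be a suffix beyond what was consumed
		const auto &tail = segments.back();
		if (anchored_start && segments.size() == 1) {
			return str == tail; // no '%' at all: exact match
		}
		return str.size() >= pos + tail.size() &&
		       str.compare(str.size() - tail.size(), tail.size(), tail) == 0;
	}
	return true;
}
```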
- D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - for (auto &arg : arguments) { - if (arg->return_type.id() == LogicalTypeId::VARCHAR && !StringType::GetCollation(arg->return_type).empty()) { - return nullptr; - } - } - if (arguments[1]->IsFoldable()) { - Value pattern_str = ExpressionExecutor::EvaluateScalar(context, *arguments[1]); - return LikeMatcher::CreateLikeMatcher(pattern_str.ToString()); - } - return nullptr; -} - -bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen, char escape) { - return TemplatedLikeOperator<'%', '_', true>(s, slen, pattern, plen, escape); -} - -bool LikeOperatorFunction(const char *s, idx_t slen, const char *pattern, idx_t plen) { - return TemplatedLikeOperator<'%', '_', false>(s, slen, pattern, plen, '\0'); -} - -bool LikeOperatorFunction(string_t &s, string_t &pat) { - return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize()); -} - -bool LikeOperatorFunction(string_t &s, string_t &pat, char escape) { - return LikeOperatorFunction(s.GetData(), s.GetSize(), pat.GetData(), pat.GetSize(), escape); -} - -bool Glob(const char *string, idx_t slen, const char *pattern, idx_t plen, bool allow_question_mark) { - idx_t sidx = 0; - idx_t pidx = 0; -main_loop : { - // main matching loop - while (sidx < slen && pidx < plen) { - char s = string[sidx]; - char p = pattern[pidx]; - switch (p) { - case '*': { - // asterisk: match any set of characters - // skip any subsequent asterisks - pidx++; - while (pidx < plen && pattern[pidx] == '*') { - pidx++; - } - // if the asterisk is the last character, the pattern always matches - if (pidx == plen) { - return true; - } - // recursively match the remainder of the pattern - for (; sidx < slen; sidx++) { - if (Glob(string + sidx, slen - sidx, pattern + pidx, plen - pidx)) { - return true; - } - } - return false; - } - case '?': - // when enabled: matches anything but null - if (allow_question_mark) { - break; - } - DUCKDB_EXPLICIT_FALLTHROUGH; - case '[': - pidx++; - goto parse_bracket; - case '\\': - // escape character, next character needs to match literally - pidx++; - // check that we still have a character remaining - if (pidx == plen) { - return false; - } - p = pattern[pidx]; - if (s != p) { - return false; - } - break; - default: - // not a control character: characters need to match literally - if (s != p) { - return false; - } - break; - } - sidx++; - pidx++; - } - while (pidx < plen && pattern[pidx] == '*') { - pidx++; - } - // we are finished only if we have consumed the full pattern - return pidx == plen && sidx == slen; -} -parse_bracket : { - // inside a bracket - if (pidx == plen) { - return false; - } - // check the first character - // if it is an exclamation mark we need to invert our logic - char p = pattern[pidx]; - char s = string[sidx]; - bool invert = false; - if (p == '!') { - invert = true; - pidx++; - } - bool found_match = invert; - idx_t start_pos = pidx; - bool found_closing_bracket = false; - // now check the remainder of the pattern - while (pidx < plen) { - p = pattern[pidx]; - // if the first character is a closing bracket, we match it literally - // otherwise it indicates an end of bracket - if (p == ']' && pidx > start_pos) { - // end of bracket found: we are done - found_closing_bracket = true; - pidx++; - break; - } - // we either match a range (a-b) or a single character (a) - // check if the next character is a dash - if (pidx + 1 == plen) { - // no next character! 
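- // (with nothing left to read, the bracket can never be closed: we break out, - // found_closing_bracket stays false, and the match fails below)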
- break; - } - bool matches; - if (pattern[pidx + 1] == '-') { - // range! find the next character in the range - if (pidx + 2 == plen) { - break; - } - char next_char = pattern[pidx + 2]; - // check if the current character is within the range - matches = s >= p && s <= next_char; - // shift the pattern forward past the range - pidx += 3; - } else { - // no range! perform a direct match - matches = p == s; - // shift the pattern forward past the character - pidx++; - } - if (found_match == invert && matches) { - // found a match! set the found_matches flag - // we keep on pattern matching after this until we reach the end bracket - // however, we don't need to update the found_match flag anymore - found_match = !invert; - } - } - if (!found_closing_bracket) { - // no end of bracket: invalid pattern - return false; - } - if (!found_match) { - // did not match the bracket: return false; - return false; - } - // finished the bracket matching: move forward - sidx++; - goto main_loop; -} -} - -static char GetEscapeChar(string_t escape) { - // Only one escape character should be allowed - if (escape.GetSize() > 1) { - throw SyntaxException("Invalid escape string. Escape string must be empty or one character."); - } - return escape.GetSize() == 0 ? '\0' : *escape.GetData(); -} - -struct LikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - char escape_char = GetEscapeChar(escape); - return LikeOperatorFunction(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize(), escape_char); - } -}; - -struct NotLikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - return !LikeEscapeOperator::Operation(str, pattern, escape); - } -}; - -struct LikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return LikeOperatorFunction(str, pattern); - } -}; - -bool ILikeOperatorFunction(string_t &str, string_t &pattern, char escape = '\0') { - auto str_data = str.GetData(); - auto str_size = str.GetSize(); - auto pat_data = pattern.GetData(); - auto pat_size = pattern.GetSize(); - - // lowercase both the str and the pattern - idx_t str_llength = LowerLength(str_data, str_size); - auto str_ldata = make_unsafe_uniq_array_uninitialized(str_llength); - LowerCase(str_data, str_size, str_ldata.get()); - - idx_t pat_llength = LowerLength(pat_data, pat_size); - auto pat_ldata = make_unsafe_uniq_array_uninitialized(pat_llength); - LowerCase(pat_data, pat_size, pat_ldata.get()); - string_t str_lcase(str_ldata.get(), UnsafeNumericCast(str_llength)); - string_t pat_lcase(pat_ldata.get(), UnsafeNumericCast(pat_llength)); - return LikeOperatorFunction(str_lcase, pat_lcase, escape); -} - -struct ILikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - char escape_char = GetEscapeChar(escape); - return ILikeOperatorFunction(str, pattern, escape_char); - } -}; - -struct NotILikeEscapeOperator { - template - static inline bool Operation(TA str, TB pattern, TC escape) { - return !ILikeEscapeOperator::Operation(str, pattern, escape); - } -}; - -struct ILikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return ILikeOperatorFunction(str, pattern); - } -}; - -struct NotLikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return !LikeOperatorFunction(str, pattern); - } -}; - -struct NotILikeOperator { - template - static inline TR Operation(TA str, TB pattern) { - return !ILikeOperator::Operation(str, pattern); - } -}; - 
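- // ILikeOperatorASCII below is the pure-ASCII fast path for ILIKE: when string statistics prove - // the input cannot contain Unicode (see ILikePropagateStats further down), reading both sides - // through ASCII_TO_LOWER_MAP is equivalent to the full lowercase-then-match above. A sketch of - // the idea, with hypothetical values and assuming ASCII-only input: - //   TemplatedLikeOperator<'%', '_', false, ASCIILCaseReader>("FooBar", 6, "foo%", 4, '\0') - // returns true, because string and pattern are both lowercased byte-by-byte as they are read.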
-struct ILikeOperatorASCII { - template - static inline TR Operation(TA str, TB pattern) { - return TemplatedLikeOperator<'%', '_', false, ASCIILCaseReader>(str.GetData(), str.GetSize(), pattern.GetData(), - pattern.GetSize(), '\0'); - } -}; - -struct NotILikeOperatorASCII { - template - static inline TR Operation(TA str, TB pattern) { - return !ILikeOperatorASCII::Operation(str, pattern); - } -}; - -struct GlobOperator { - template - static inline TR Operation(TA str, TB pattern) { - return Glob(str.GetData(), str.GetSize(), pattern.GetData(), pattern.GetSize()); - } -}; - -// This can be moved to the scalar_function class -template -static void LikeEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &str = args.data[0]; - auto &pattern = args.data[1]; - auto &escape = args.data[2]; - - TernaryExecutor::Execute( - str, pattern, escape, result, args.size(), FUNC::template Operation); -} - -template -static unique_ptr ILikePropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - D_ASSERT(child_stats.size() >= 1); - // can only propagate stats if the children have stats - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = ScalarFunction::BinaryFunction; - } - return nullptr; -} - -template -static void RegularLikeFunction(DataChunk &input, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - if (func_expr.bind_info) { - auto &matcher = func_expr.bind_info->Cast(); - // use fast like matcher - UnaryExecutor::Execute(input.data[0], result, input.size(), [&](string_t input) { - return INVERT ? !matcher.Match(input) : matcher.Match(input); - }); - } else { - // use generic like matcher - BinaryExecutor::ExecuteStandard(input.data[0], input.data[1], result, - input.size()); - } -} - -ScalarFunction NotLikeFun::GetFunction() { - ScalarFunction not_like("!~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegularLikeFunction, LikeBindFunction); - not_like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return not_like; -} - -ScalarFunction GlobPatternFun::GetFunction() { - ScalarFunction glob("~~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction); - glob.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return glob; -} - -ScalarFunction ILikeFun::GetFunction() { - ScalarFunction ilike("~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, nullptr, - ILikePropagateStats); - ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return ilike; -} - -ScalarFunction NotILikeFun::GetFunction() { - ScalarFunction not_ilike("!~~*", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - ScalarFunction::BinaryFunction, nullptr, - nullptr, ILikePropagateStats); - not_ilike.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return not_ilike; -} - -ScalarFunction LikeFun::GetFunction() { - ScalarFunction like("~~", {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegularLikeFunction, LikeBindFunction); - like.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return like; -} - -ScalarFunction NotLikeEscapeFun::GetFunction() { - ScalarFunction not_like_escape("not_like_escape", - {LogicalType::VARCHAR, LogicalType::VARCHAR, 
LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); - not_like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return not_like_escape; -} - -ScalarFunction IlikeEscapeFun::GetFunction() { - ScalarFunction ilike_escape("ilike_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); - ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return ilike_escape; -} - -ScalarFunction NotIlikeEscapeFun::GetFunction() { - ScalarFunction not_ilike_escape("not_ilike_escape", - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); - not_ilike_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return not_ilike_escape; -} -ScalarFunction LikeEscapeFun::GetFunction() { - ScalarFunction like_escape("like_escape", {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::BOOLEAN, LikeEscapeFunction); - like_escape.collation_handling = FunctionCollationHandling::PUSH_COMBINABLE_COLLATIONS; - return like_escape; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/md5.cpp b/src/duckdb/src/function/scalar/string/md5.cpp deleted file mode 100644 index 837f97c12..000000000 --- a/src/duckdb/src/function/scalar/string/md5.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "duckdb/common/crypto/md5.hpp" - -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" - -namespace duckdb { - -struct MD5Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto hash = StringVector::EmptyString(result, MD5Context::MD5_HASH_LENGTH_TEXT); - MD5Context context; - context.Add(input); - context.FinishHex(hash.GetDataWriteable()); - hash.Finalize(); - return hash; - } -}; - -struct MD5Number128Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input) { - data_t digest[MD5Context::MD5_HASH_LENGTH_BINARY]; - - MD5Context context; - context.Add(input); - context.Finish(digest); - return *reinterpret_cast(digest); - } -}; - -static void MD5Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::ExecuteString(input, result, args.size()); -} - -static void MD5NumberFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::Execute(input, result, args.size()); -} - -ScalarFunctionSet MD5Fun::GetFunctions() { - ScalarFunctionSet set("md5"); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, MD5Function)); - set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, MD5Function)); - return set; -} - -ScalarFunctionSet MD5NumberFun::GetFunctions() { - ScalarFunctionSet set("md5_number"); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::HUGEINT, MD5NumberFunction)); - set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::HUGEINT, MD5NumberFunction)); - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/nfc_normalize.cpp b/src/duckdb/src/function/scalar/string/nfc_normalize.cpp deleted file mode 100644 index 92a061494..000000000 --- a/src/duckdb/src/function/scalar/string/nfc_normalize.cpp +++ /dev/null @@ -1,34 +0,0 @@ -#include "duckdb/function/scalar/string_common.hpp" 
-#include "duckdb/function/scalar/string_functions.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -struct NFCNormalizeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto input_data = input.GetData(); - auto input_length = input.GetSize(); - if (IsAscii(input_data, input_length)) { - return input; - } - auto normalized_str = Utf8Proc::Normalize(input_data, input_length); - D_ASSERT(normalized_str); - auto result_str = StringVector::AddString(result, normalized_str); - free(normalized_str); - return result_str; - } -}; - -static void NFCNormalizeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction NFCNormalizeFun::GetFunction() { - return ScalarFunction("nfc_normalize", {LogicalType::VARCHAR}, LogicalType::VARCHAR, NFCNormalizeFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/prefix.cpp b/src/duckdb/src/function/scalar/string/prefix.cpp deleted file mode 100644 index 2b46610fa..000000000 --- a/src/duckdb/src/function/scalar/string/prefix.cpp +++ /dev/null @@ -1,68 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/common/types/string_type.hpp" - -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -static bool PrefixFunction(const string_t &str, const string_t &pattern); - -struct PrefixOperator { - template - static inline TR Operation(TA left, TB right) { - return PrefixFunction(left, right); - } -}; -static bool PrefixFunction(const string_t &str, const string_t &pattern) { - auto str_length = str.GetSize(); - auto patt_length = pattern.GetSize(); - if (patt_length > str_length) { - return false; - } - if (patt_length <= string_t::PREFIX_LENGTH) { - // short prefix - if (patt_length == 0) { - // length = 0, return true - return true; - } - - // prefix early out - const char *str_pref = str.GetPrefix(); - const char *patt_pref = pattern.GetPrefix(); - for (idx_t i = 0; i < patt_length; ++i) { - if (str_pref[i] != patt_pref[i]) { - return false; - } - } - return true; - } else { - // prefix early out - const char *str_pref = str.GetPrefix(); - const char *patt_pref = pattern.GetPrefix(); - for (idx_t i = 0; i < string_t::PREFIX_LENGTH; ++i) { - if (str_pref[i] != patt_pref[i]) { - // early out - return false; - } - } - // compare the rest of the prefix - const char *str_data = str.GetData(); - const char *patt_data = pattern.GetData(); - D_ASSERT(patt_length <= str_length); - for (idx_t i = string_t::PREFIX_LENGTH; i < patt_length; ++i) { - if (str_data[i] != patt_data[i]) { - return false; - } - } - return true; - } -} - -ScalarFunction PrefixFun::GetFunction() { - return ScalarFunction("prefix", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp.cpp b/src/duckdb/src/function/scalar/string/regexp.cpp deleted file mode 100644 index 2383e3b57..000000000 --- a/src/duckdb/src/function/scalar/string/regexp.cpp +++ /dev/null @@ -1,470 +0,0 @@ -#include "duckdb/function/scalar/regexp.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/binary_executor.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" -#include 
"duckdb/common/vector_operations/unary_executor.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -using regexp_util::CreateStringPiece; -using regexp_util::Extract; -using regexp_util::ParseRegexOptions; -using regexp_util::TryParseConstantPattern; - -static bool RegexOptionsEquals(const duckdb_re2::RE2::Options &opt_a, const duckdb_re2::RE2::Options &opt_b) { - return opt_a.case_sensitive() == opt_b.case_sensitive(); -} - -RegexpBaseBindData::RegexpBaseBindData() : constant_pattern(false) { -} -RegexpBaseBindData::RegexpBaseBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern) - : options(options), constant_string(std::move(constant_string_p)), constant_pattern(constant_pattern) { -} - -RegexpBaseBindData::~RegexpBaseBindData() { -} - -bool RegexpBaseBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return constant_pattern == other.constant_pattern && constant_string == other.constant_string && - RegexOptionsEquals(options, other.options); -} - -unique_ptr RegexInitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data) { - auto &info = bind_data->Cast(); - if (info.constant_pattern) { - return make_uniq(info); - } - return nullptr; -} - -//===--------------------------------------------------------------------===// -// Regexp Matches -//===--------------------------------------------------------------------===// -RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern) { - if (constant_pattern) { - auto pattern = make_uniq(constant_string, options); - if (!pattern->ok()) { - throw InvalidInputException(pattern->error()); - } - - range_success = pattern->PossibleMatchRange(&range_min, &range_max, 1000); - } else { - range_success = false; - } -} - -RegexpMatchesBindData::RegexpMatchesBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, string range_min_p, string range_max_p, - bool range_success) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), range_min(std::move(range_min_p)), - range_max(std::move(range_max_p)), range_success(range_success) { -} - -unique_ptr RegexpMatchesBindData::Copy() const { - return make_uniq(options, constant_string, constant_pattern, range_min, range_max, - range_success); -} - -unique_ptr RegexpMatchesBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - // pattern is the second argument. If its constant, we can already prepare the pattern and store it for later. 
- D_ASSERT(arguments.size() == 2 || arguments.size() == 3); - RE2::Options options; - options.set_log_errors(false); - if (arguments.size() == 3) { - ParseRegexOptions(context, *arguments[2], options); - } - - string constant_string; - bool constant_pattern; - constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - return make_uniq(options, std::move(constant_string), constant_pattern); -} - -struct RegexPartialMatch { - static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { - return duckdb_re2::RE2::PartialMatch(input, re); - } -}; - -struct RegexFullMatch { - static inline bool Operation(const duckdb_re2::StringPiece &input, duckdb_re2::RE2 &re) { - return duckdb_re2::RE2::FullMatch(input, re); - } -}; - -template -static void RegexpMatchesFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { - return OP::Operation(CreateStringPiece(input), lstate.constant_pattern); - }); - } else { - BinaryExecutor::Execute(strings, patterns, result, args.size(), - [&](string_t input, string_t pattern) { - RE2 re(CreateStringPiece(pattern), info.options); - if (!re.ok()) { - throw InvalidInputException(re.error()); - } - return OP::Operation(CreateStringPiece(input), re); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Replace -//===--------------------------------------------------------------------===// -RegexpReplaceBindData::RegexpReplaceBindData() : global_replace(false) { -} - -RegexpReplaceBindData::RegexpReplaceBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, bool global_replace) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), global_replace(global_replace) { -} - -unique_ptr RegexpReplaceBindData::Copy() const { - auto copy = make_uniq(options, constant_string, constant_pattern, global_replace); - return std::move(copy); -} - -bool RegexpReplaceBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return RegexpBaseBindData::Equals(other) && global_replace == other.global_replace; -} - -static unique_ptr RegexReplaceBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - auto data = make_uniq(); - - data->constant_pattern = TryParseConstantPattern(context, *arguments[1], data->constant_string); - if (arguments.size() == 4) { - ParseRegexOptions(context, *arguments[3], data->options, &data->global_replace); - } - data->options.set_log_errors(false); - return std::move(data); -} - -static void RegexReplaceFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - auto &replaces = args.data[2]; - - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - BinaryExecutor::Execute( - strings, replaces, result, args.size(), [&](string_t input, string_t replace) { - std::string sstring = input.GetString(); - if (info.global_replace) { - RE2::GlobalReplace(&sstring, lstate.constant_pattern, 
CreateStringPiece(replace)); - } else { - RE2::Replace(&sstring, lstate.constant_pattern, CreateStringPiece(replace)); - } - return StringVector::AddString(result, sstring); - }); - } else { - TernaryExecutor::Execute( - strings, patterns, replaces, result, args.size(), [&](string_t input, string_t pattern, string_t replace) { - RE2 re(CreateStringPiece(pattern), info.options); - std::string sstring = input.GetString(); - if (info.global_replace) { - RE2::GlobalReplace(&sstring, re, CreateStringPiece(replace)); - } else { - RE2::Replace(&sstring, re, CreateStringPiece(replace)); - } - return StringVector::AddString(result, sstring); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Extract -//===--------------------------------------------------------------------===// -RegexpExtractBindData::RegexpExtractBindData() { -} - -RegexpExtractBindData::RegexpExtractBindData(duckdb_re2::RE2::Options options, string constant_string_p, - bool constant_pattern, string group_string_p) - : RegexpBaseBindData(options, std::move(constant_string_p), constant_pattern), - group_string(std::move(group_string_p)), rewrite(group_string) { -} - -unique_ptr RegexpExtractBindData::Copy() const { - return make_uniq(options, constant_string, constant_pattern, group_string); -} - -bool RegexpExtractBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return RegexpBaseBindData::Equals(other) && group_string == other.group_string; -} - -static void RegexExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - UnaryExecutor::Execute(strings, result, args.size(), [&](string_t input) { - return Extract(input, result, lstate.constant_pattern, info.rewrite); - }); - } else { - BinaryExecutor::Execute(strings, patterns, result, args.size(), - [&](string_t input, string_t pattern) { - RE2 re(CreateStringPiece(pattern), info.options); - return Extract(input, result, re, info.rewrite); - }); - } -} - -//===--------------------------------------------------------------------===// -// Regexp Extract Struct -//===--------------------------------------------------------------------===// -static void RegexExtractStructFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - - const auto count = args.size(); - auto &input = args.data[0]; - - auto &child_entries = StructVector::GetEntries(result); - const auto groupSize = child_entries.size(); - // Reference the 'input' StringBuffer, because we won't need to allocate new data - // for the result, all returned strings are substrings of the originals - for (auto &child_entry : child_entries) { - child_entry->SetAuxiliary(input.GetAuxiliary()); - } - - vector argv(groupSize); - vector groups(groupSize); - vector ws(groupSize); - for (size_t i = 0; i < groupSize; ++i) { - groups[i] = &argv[i]; - argv[i] = &ws[i]; - } - - if (input.GetVectorType() == VectorType::CONSTANT_VECTOR) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - - if (ConstantVector::IsNull(input)) { - ConstantVector::SetNull(result, true); - } else { - ConstantVector::SetNull(result, false); - auto idata = ConstantVector::GetData(input); - auto str = 
CreateStringPiece(idata[0]); - auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), - UnsafeNumericCast(groups.size())); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - ConstantVector::SetNull(*child_entry, false); - auto &extracted = ws[col]; - auto cdata = ConstantVector::GetData(*child_entry); - cdata[0] = string_t(extracted.data(), UnsafeNumericCast(match ? extracted.size() : 0)); - } - } - } else { - UnifiedVectorFormat iunified; - input.ToUnifiedFormat(count, iunified); - - const auto &ivalidity = iunified.validity; - auto idata = UnifiedVectorFormat::GetData(iunified); - - // Start with a valid flat vector - result.SetVectorType(VectorType::FLAT_VECTOR); - - // Start with valid children - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - child_entry->SetVectorType(VectorType::FLAT_VECTOR); - } - - for (idx_t i = 0; i < count; ++i) { - const auto idx = iunified.sel->get_index(i); - if (ivalidity.RowIsValid(idx)) { - auto str = CreateStringPiece(idata[idx]); - auto match = duckdb_re2::RE2::PartialMatchN(str, lstate.constant_pattern, groups.data(), - UnsafeNumericCast(groups.size())); - for (size_t col = 0; col < child_entries.size(); ++col) { - auto &child_entry = child_entries[col]; - auto cdata = FlatVector::GetData(*child_entry); - auto &extracted = ws[col]; - cdata[i] = string_t(extracted.data(), UnsafeNumericCast(match ? extracted.size() : 0)); - } - } else { - FlatVector::SetNull(result, i, true); - } - } - } -} - -static unique_ptr RegexExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - - duckdb_re2::RE2::Options options; - - string constant_string; - bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - - if (arguments.size() >= 4) { - ParseRegexOptions(context, *arguments[3], options); - } - - string group_string = "\\0"; - if (arguments.size() >= 3) { - if (arguments[2]->HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!arguments[2]->IsFoldable()) { - throw InvalidInputException("Group specification field must be a constant!"); - } - Value group = ExpressionExecutor::EvaluateScalar(context, *arguments[2]); - if (group.IsNull()) { - group_string = ""; - } else if (group.type().id() == LogicalTypeId::LIST) { - if (!constant_pattern) { - throw BinderException("%s with LIST requires a constant pattern", bound_function.name); - } - auto &list_children = ListValue::GetChildren(group); - if (list_children.empty()) { - throw BinderException("%s requires non-empty lists of capture names", bound_function.name); - } - case_insensitive_set_t name_collision_set; - child_list_t struct_children; - for (const auto &child : list_children) { - if (child.IsNull()) { - throw BinderException("NULL group name in %s", bound_function.name); - } - const auto group_name = child.ToString(); - if (name_collision_set.find(group_name) != name_collision_set.end()) { - throw BinderException("Duplicate group name \"%s\" in %s", group_name, bound_function.name); - } - name_collision_set.insert(group_name); - struct_children.emplace_back(make_pair(group_name, LogicalType::VARCHAR)); - } - bound_function.return_type = LogicalType::STRUCT(struct_children); - - duckdb_re2::StringPiece constant_piece(constant_string.c_str(), constant_string.size()); - RE2 constant_pattern(constant_piece, options); - if 
(size_t(constant_pattern.NumberOfCapturingGroups()) < list_children.size()) { - throw BinderException("Not enough group names in %s", bound_function.name); - } - } else { - auto group_idx = group.GetValue(); - if (group_idx < 0 || group_idx > 9) { - throw InvalidInputException("Group index must be between 0 and 9!"); - } - group_string = "\\" + to_string(group_idx); - } - } - - return make_uniq(options, std::move(constant_string), constant_pattern, - std::move(group_string)); -} - -ScalarFunctionSet RegexpFun::GetFunctions() { - ScalarFunctionSet regexp_full_match("regexp_full_match"); - regexp_full_match.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_full_match.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - return (regexp_full_match); -} - -ScalarFunctionSet RegexpMatchesFun::GetFunctions() { - ScalarFunctionSet regexp_partial_match("regexp_matches"); - regexp_partial_match.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, RegexpMatchesFunction, - RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING)); - regexp_partial_match.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::BOOLEAN, - RegexpMatchesFunction, RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - for (auto &func : regexp_partial_match.functions) { - BaseScalarFunction::SetReturnsError(func); - } - return (regexp_partial_match); -} - -ScalarFunctionSet RegexpReplaceFun::GetFunctions() { - ScalarFunctionSet regexp_replace("regexp_replace"); - regexp_replace.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, - LogicalType::VARCHAR, RegexReplaceFunction, RegexReplaceBind, nullptr, - nullptr, RegexInitLocalState)); - regexp_replace.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - RegexReplaceFunction, RegexReplaceBind, nullptr, nullptr, RegexInitLocalState)); - return (regexp_replace); -} - -ScalarFunctionSet RegexpExtractFun::GetFunctions() { - ScalarFunctionSet regexp_extract("regexp_extract"); - regexp_extract.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::VARCHAR, - RegexExtractFunction, RegexExtractBind, nullptr, nullptr, - RegexInitLocalState, LogicalType::INVALID, FunctionStability::CONSISTENT, - FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, - LogicalType::VARCHAR, RegexExtractFunction, RegexExtractBind, nullptr, - nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, 
LogicalType::VARCHAR}, LogicalType::VARCHAR, - RegexExtractFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - // REGEXP_EXTRACT(, , [[, ]...]) - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR)}, LogicalType::VARCHAR, - RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - // REGEXP_EXTRACT(, , [[, ]...], ) - regexp_extract.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::LIST(LogicalType::VARCHAR), LogicalType::VARCHAR}, - LogicalType::VARCHAR, RegexExtractStructFunction, RegexExtractBind, nullptr, nullptr, RegexInitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - return (regexp_extract); -} - -ScalarFunctionSet RegexpExtractAllFun::GetFunctions() { - ScalarFunctionSet regexp_extract_all("regexp_extract_all"); - regexp_extract_all.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR}, LogicalType::LIST(LogicalType::VARCHAR), - RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract_all.AddFunction(ScalarFunction( - {LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::LIST(LogicalType::VARCHAR), - RegexpExtractAll::Execute, RegexpExtractAll::Bind, nullptr, nullptr, RegexpExtractAll::InitLocalState, - LogicalType::INVALID, FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - regexp_extract_all.AddFunction( - ScalarFunction({LogicalType::VARCHAR, LogicalType::VARCHAR, LogicalType::INTEGER, LogicalType::VARCHAR}, - LogicalType::LIST(LogicalType::VARCHAR), RegexpExtractAll::Execute, RegexpExtractAll::Bind, - nullptr, nullptr, RegexpExtractAll::InitLocalState, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING)); - return (regexp_extract_all); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp deleted file mode 100644 index 144dcff03..000000000 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_extract_all.cpp +++ /dev/null @@ -1,246 +0,0 @@ -#include "duckdb/function/scalar/regexp.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "re2/re2.h" - -namespace duckdb { - -using regexp_util::CreateStringPiece; -using regexp_util::Extract; -using regexp_util::ParseRegexOptions; -using regexp_util::TryParseConstantPattern; - -unique_ptr -RegexpExtractAll::InitLocalState(ExpressionState &state, const BoundFunctionExpression &expr, FunctionData *bind_data) { - auto &info = bind_data->Cast(); - if (info.constant_pattern) { - return make_uniq(info, true); - } - return nullptr; -} - -// Forwards startpos automatically -bool ExtractAll(duckdb_re2::StringPiece &input, duckdb_re2::RE2 &pattern, idx_t *startpos, - duckdb_re2::StringPiece *groups, int ngroups) { - - D_ASSERT(pattern.ok()); - 
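- // (this invariant is established by the caller: Execute() rejects patterns that fail to - // parse and sizes the group buffer from NumberOfCapturingGroups() before calling us)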
D_ASSERT(pattern.NumberOfCapturingGroups() == ngroups); - - if (!pattern.Match(input, *startpos, input.size(), pattern.UNANCHORED, groups, ngroups + 1)) { - return false; - } - idx_t consumed = static_cast<idx_t>(groups[0].end() - (input.begin() + *startpos)); - if (!consumed) { - // Empty match found, have to manually forward the input - // to avoid an infinite loop - // FIXME: support unicode characters - consumed++; - while (*startpos + consumed < input.length() && !IsCharacter(input[*startpos + consumed])) { - consumed++; - } - } - *startpos += consumed; - return true; -} - -void ExtractSingleTuple(const string_t &string, duckdb_re2::RE2 &pattern, int32_t group, RegexStringPieceArgs &args, - Vector &result, idx_t row) { - auto input = CreateStringPiece(string); - - auto &child_vector = ListVector::GetEntry(result); - auto list_content = FlatVector::GetData<string_t>(child_vector); - auto &child_validity = FlatVector::Validity(child_vector); - - auto current_list_size = ListVector::GetListSize(result); - auto current_list_capacity = ListVector::GetListCapacity(result); - - auto result_data = FlatVector::GetData<list_entry_t>(result); - auto &list_entry = result_data[row]; - list_entry.offset = current_list_size; - - if (group < 0) { - list_entry.length = 0; - return; - } - // If the requested group index is out of bounds - // we want to throw only if there is a match - bool throw_on_group_found = (idx_t)group > args.size; - - idx_t startpos = 0; - for (idx_t iteration = 0; - ExtractAll(input, pattern, &startpos, args.group_buffer, UnsafeNumericCast<int>(args.size)); iteration++) { - if (!iteration && throw_on_group_found) { - throw InvalidInputException("Pattern has %d groups. Cannot access group %d", args.size, group); - } - - // Make sure we have enough room for the new entries - if (current_list_size + 1 >= current_list_capacity) { - ListVector::Reserve(result, current_list_capacity * 2); - current_list_capacity = ListVector::GetListCapacity(result); - list_content = FlatVector::GetData<string_t>(child_vector); - } - - // Write the captured groups into the list-child vector - auto &match_group = args.group_buffer[group]; - - idx_t child_idx = current_list_size; - if (match_group.empty()) { - // This group was not matched - list_content[child_idx] = string_t(string.GetData(), 0); - if (match_group.begin() == nullptr) { - // This group is optional - child_validity.SetInvalid(child_idx); - } - } else { - // Every group is a substring of the original, we can find out the offset using the pointer - // the 'match_group' address is guaranteed to be bigger than that of the source - D_ASSERT(const_char_ptr_cast(match_group.begin()) >= string.GetData()); - auto offset = UnsafeNumericCast<idx_t>(match_group.begin() - string.GetData()); - list_content[child_idx] = - string_t(string.GetData() + offset, UnsafeNumericCast<uint32_t>(match_group.size())); - } - current_list_size++; - if (startpos > input.size()) { - // Empty match found at the end of the string - break; - } - } - list_entry.length = current_list_size - list_entry.offset; - ListVector::SetListSize(result, current_list_size); -} - -static bool GetGroupIndex(DataChunk &args, idx_t row, int32_t &result) { - if (args.ColumnCount() < 3) { - result = 0; - return true; - } - UnifiedVectorFormat format; - args.data[2].ToUnifiedFormat(args.size(), format); - idx_t index = format.sel->get_index(row); - if (!format.validity.RowIsValid(index)) { - return false; - } - result = UnifiedVectorFormat::GetData<int32_t>(format)[index]; - return true; -} - -duckdb_re2::RE2 &GetPattern(const RegexpBaseBindData &info, 
ExpressionState &state, - unique_ptr &pattern_p) { - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - return lstate.constant_pattern; - } - D_ASSERT(pattern_p); - return *pattern_p; -} - -RegexStringPieceArgs &GetGroupsBuffer(const RegexpBaseBindData &info, ExpressionState &state, - unique_ptr &groups_p) { - if (info.constant_pattern) { - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - return lstate.group_buffer; - } - D_ASSERT(groups_p); - return *groups_p; -} - -void RegexpExtractAll::Execute(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - const auto &info = func_expr.bind_info->Cast(); - - auto &strings = args.data[0]; - auto &patterns = args.data[1]; - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - auto &output_child = ListVector::GetEntry(result); - - UnifiedVectorFormat strings_data; - strings.ToUnifiedFormat(args.size(), strings_data); - - UnifiedVectorFormat pattern_data; - patterns.ToUnifiedFormat(args.size(), pattern_data); - - ListVector::Reserve(result, STANDARD_VECTOR_SIZE); - // Reference the 'strings' StringBuffer, because we won't need to allocate new data - // for the result, all returned strings are substrings of the originals - output_child.SetAuxiliary(strings.GetAuxiliary()); - - // Avoid doing extra work if all the inputs are constant - idx_t tuple_count = args.AllConstant() ? 1 : args.size(); - - unique_ptr non_const_args; - unique_ptr stored_re; - if (!info.constant_pattern) { - non_const_args = make_uniq(); - } else { - // Verify that the constant pattern is valid - auto &re = GetPattern(info, state, stored_re); - auto group_count_p = re.NumberOfCapturingGroups(); - if (group_count_p == -1) { - throw InvalidInputException("Pattern failed to parse, error: '%s'", re.error()); - } - } - - for (idx_t row = 0; row < tuple_count; row++) { - bool pattern_valid = true; - if (!info.constant_pattern) { - // Check if the pattern is NULL or not, - // and compile the pattern if it's not constant - auto pattern_idx = pattern_data.sel->get_index(row); - if (!pattern_data.validity.RowIsValid(pattern_idx)) { - pattern_valid = false; - } else { - auto &pattern_p = UnifiedVectorFormat::GetData(pattern_data)[pattern_idx]; - auto pattern_strpiece = CreateStringPiece(pattern_p); - stored_re = make_uniq(pattern_strpiece, info.options); - - // Increase the size of the args buffer if needed - auto group_count_p = stored_re->NumberOfCapturingGroups(); - if (group_count_p == -1) { - throw InvalidInputException("Pattern failed to parse, error: '%s'", stored_re->error()); - } - non_const_args->SetSize(UnsafeNumericCast(group_count_p)); - } - } - - auto string_idx = strings_data.sel->get_index(row); - int32_t group_index; - if (!pattern_valid || !strings_data.validity.RowIsValid(string_idx) || !GetGroupIndex(args, row, group_index)) { - // If something is NULL, the result is NULL - // FIXME: do we even need 'SPECIAL_HANDLING'? 
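- // (SPECIAL_HANDLING means the executor does not short-circuit NULL inputs for us, - // which is presumably why the NULL list entry has to be written out by hand here)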
- auto result_data = FlatVector::GetData(result); - auto &result_validity = FlatVector::Validity(result); - result_data[row].length = 0; - result_data[row].offset = ListVector::GetListSize(result); - result_validity.SetInvalid(row); - continue; - } - - auto &re = GetPattern(info, state, stored_re); - auto &groups = GetGroupsBuffer(info, state, non_const_args); - auto &string = UnifiedVectorFormat::GetData(strings_data)[string_idx]; - ExtractSingleTuple(string, re, group_index, groups, result, row); - } - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } -} - -unique_ptr RegexpExtractAll::Bind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(arguments.size() >= 2); - - duckdb_re2::RE2::Options options; - - string constant_string; - bool constant_pattern = TryParseConstantPattern(context, *arguments[1], constant_string); - - if (arguments.size() >= 4) { - ParseRegexOptions(context, *arguments[3], options); - } - return make_uniq(options, std::move(constant_string), constant_pattern, ""); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp b/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp deleted file mode 100644 index 4e485195c..000000000 --- a/src/duckdb/src/function/scalar/string/regexp/regexp_util.cpp +++ /dev/null @@ -1,83 +0,0 @@ -#include "duckdb/function/scalar/regexp.hpp" -#include "duckdb/execution/expression_executor.hpp" - -namespace duckdb { - -namespace regexp_util { - -bool TryParseConstantPattern(ClientContext &context, Expression &expr, string &constant_string) { - if (!expr.IsFoldable()) { - return false; - } - Value pattern_str = ExpressionExecutor::EvaluateScalar(context, expr); - if (!pattern_str.IsNull() && pattern_str.type().id() == LogicalTypeId::VARCHAR) { - constant_string = StringValue::Get(pattern_str); - return true; - } - return false; -} - -void ParseRegexOptions(const string &options, duckdb_re2::RE2::Options &result, bool *global_replace) { - for (idx_t i = 0; i < options.size(); i++) { - switch (options[i]) { - case 'c': - // case-sensitive matching - result.set_case_sensitive(true); - break; - case 'i': - // case-insensitive matching - result.set_case_sensitive(false); - break; - case 'l': - // literal matching - result.set_literal(true); - break; - case 'm': - case 'n': - case 'p': - // newline-sensitive matching - result.set_dot_nl(false); - break; - case 's': - // non-newline-sensitive matching - result.set_dot_nl(true); - break; - case 'g': - // global replace, only available for regexp_replace - if (global_replace) { - *global_replace = true; - } else { - throw InvalidInputException("Option 'g' (global replace) is only valid for regexp_replace"); - } - break; - case ' ': - case '\t': - case '\n': - // ignore whitespace - break; - default: - throw InvalidInputException("Unrecognized Regex option %c", options[i]); - } - } -} - -void ParseRegexOptions(ClientContext &context, Expression &expr, RE2::Options &target, bool *global_replace) { - if (expr.HasParameter()) { - throw ParameterNotResolvedException(); - } - if (!expr.IsFoldable()) { - throw InvalidInputException("Regex options field must be a constant"); - } - Value options_str = ExpressionExecutor::EvaluateScalar(context, expr); - if (options_str.IsNull()) { - throw InvalidInputException("Regex options field must not be NULL"); - } - if (options_str.type().id() != LogicalTypeId::VARCHAR) { - throw InvalidInputException("Regex options field must be a string"); - } 
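- // all validation passed: delegate to the character-based option parser above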
- ParseRegexOptions(StringValue::Get(options_str), target, global_replace); -} - -} // namespace regexp_util - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/regexp_escape.cpp b/src/duckdb/src/function/scalar/string/regexp_escape.cpp deleted file mode 100644 index 3d72fe681..000000000 --- a/src/duckdb/src/function/scalar/string/regexp_escape.cpp +++ /dev/null @@ -1,22 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "re2/re2.h" - -namespace duckdb { - -struct EscapeOperator { - template - static RESULT_TYPE Operation(INPUT_TYPE &input, Vector &result) { - auto escaped_pattern = RE2::QuoteMeta(input.GetString()); - return StringVector::AddString(result, escaped_pattern); - } -}; - -static void RegexpEscapeFunction(DataChunk &args, ExpressionState &state, Vector &result) { - UnaryExecutor::ExecuteString(args.data[0], result, args.size()); -} - -ScalarFunction RegexpEscapeFun::GetFunction() { - return ScalarFunction("regexp_escape", {LogicalType::VARCHAR}, LogicalType::VARCHAR, RegexpEscapeFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/sha1.cpp b/src/duckdb/src/function/scalar/string/sha1.cpp deleted file mode 100644 index c59dcf252..000000000 --- a/src/duckdb/src/function/scalar/string/sha1.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "mbedtls_wrapper.hpp" - -namespace duckdb { - -struct SHA1Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto hash = StringVector::EmptyString(result, duckdb_mbedtls::MbedTlsWrapper::SHA1_HASH_LENGTH_TEXT); - - duckdb_mbedtls::MbedTlsWrapper::SHA1State state; - state.AddString(input.GetString()); - state.FinishHex(hash.GetDataWriteable()); - - hash.Finalize(); - return hash; - } -}; - -static void SHA1Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::ExecuteString(input, result, args.size()); -} - -ScalarFunctionSet SHA1Fun::GetFunctions() { - ScalarFunctionSet set("sha1"); - set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA1Function)); - set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, SHA1Function)); - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/sha256.cpp b/src/duckdb/src/function/scalar/string/sha256.cpp deleted file mode 100644 index a48ccf93f..000000000 --- a/src/duckdb/src/function/scalar/string/sha256.cpp +++ /dev/null @@ -1,35 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/unary_executor.hpp" -#include "mbedtls_wrapper.hpp" - -namespace duckdb { - -struct SHA256Operator { - template - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - auto hash = StringVector::EmptyString(result, duckdb_mbedtls::MbedTlsWrapper::SHA256_HASH_LENGTH_TEXT); - - duckdb_mbedtls::MbedTlsWrapper::SHA256State state; - state.AddString(input.GetString()); - state.FinishHex(hash.GetDataWriteable()); - - hash.Finalize(); - return hash; - } -}; - -static void SHA256Function(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input = args.data[0]; - - UnaryExecutor::ExecuteString(input, result, args.size()); -} - -ScalarFunctionSet SHA256Fun::GetFunctions() { - ScalarFunctionSet set("sha256"); - 
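- // both overloads below hash the raw bytes of the input; the VARCHAR and BLOB - // variants only differ in the logical argument type they accept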
set.AddFunction(ScalarFunction({LogicalType::VARCHAR}, LogicalType::VARCHAR, SHA256Function)); - set.AddFunction(ScalarFunction({LogicalType::BLOB}, LogicalType::VARCHAR, SHA256Function)); - return set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/string_split.cpp b/src/duckdb/src/function/scalar/string/string_split.cpp deleted file mode 100644 index 9673eca96..000000000 --- a/src/duckdb/src/function/scalar/string/string_split.cpp +++ /dev/null @@ -1,197 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/common/types/vector.hpp" -#include "duckdb/common/vector_size.hpp" -#include "duckdb/function/scalar/regexp.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -struct StringSplitInput { - StringSplitInput(Vector &result_list, Vector &result_child, idx_t offset) - : result_list(result_list), result_child(result_child), offset(offset) { - } - - Vector &result_list; - Vector &result_child; - idx_t offset; - - void AddSplit(const char *split_data, idx_t split_size, idx_t list_idx) { - auto list_entry = offset + list_idx; - if (list_entry >= ListVector::GetListCapacity(result_list)) { - ListVector::SetListSize(result_list, offset + list_idx); - ListVector::Reserve(result_list, ListVector::GetListCapacity(result_list) * 2); - } - FlatVector::GetData(result_child)[list_entry] = - string_t(split_data, UnsafeNumericCast(split_size)); - } -}; - -struct RegularStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - match_size = delim_size; - if (delim_size == 0) { - return 0; - } - return FindStrInStr(const_uchar_ptr_cast(input_data), input_size, const_uchar_ptr_cast(delim_data), delim_size); - } -}; - -struct ConstantRegexpStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - D_ASSERT(data); - auto regex = reinterpret_cast(data); - duckdb_re2::StringPiece match; - if (!regex->Match(duckdb_re2::StringPiece(input_data, input_size), 0, input_size, RE2::UNANCHORED, &match, 1)) { - return DConstants::INVALID_INDEX; - } - match_size = match.size(); - return UnsafeNumericCast(match.data() - input_data); - } -}; - -struct RegexpStringSplit { - static idx_t Find(const char *input_data, idx_t input_size, const char *delim_data, idx_t delim_size, - idx_t &match_size, void *data) { - duckdb_re2::RE2 regex(duckdb_re2::StringPiece(delim_data, delim_size)); - if (!regex.ok()) { - throw InvalidInputException(regex.error()); - } - return ConstantRegexpStringSplit::Find(input_data, input_size, delim_data, delim_size, match_size, ®ex); - } -}; - -struct StringSplitter { - template - static idx_t Split(string_t input, string_t delim, StringSplitInput &state, void *data) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - auto delim_data = delim.GetData(); - auto delim_size = delim.GetSize(); - idx_t list_idx = 0; - while (input_size > 0) { - idx_t match_size = 0; - auto pos = OP::Find(input_data, input_size, delim_data, delim_size, match_size, data); - if (pos > input_size) { - break; - } - if (match_size == 0 && pos == 0) { - // special case: 0 length match and pos is 0 - // move to the next character - for (pos++; pos < input_size; pos++) { - if 
(IsCharacter(input_data[pos])) { - break; - } - } - if (pos == input_size) { - break; - } - } - D_ASSERT(input_size >= pos + match_size); - state.AddSplit(input_data, pos, list_idx); - - list_idx++; - input_data += (pos + match_size); - input_size -= (pos + match_size); - } - state.AddSplit(input_data, input_size, list_idx); - list_idx++; - return list_idx; - } -}; - -template -static void StringSplitExecutor(DataChunk &args, ExpressionState &state, Vector &result, void *data = nullptr) { - UnifiedVectorFormat input_data; - args.data[0].ToUnifiedFormat(args.size(), input_data); - auto inputs = UnifiedVectorFormat::GetData(input_data); - - UnifiedVectorFormat delim_data; - args.data[1].ToUnifiedFormat(args.size(), delim_data); - auto delims = UnifiedVectorFormat::GetData(delim_data); - - D_ASSERT(result.GetType().id() == LogicalTypeId::LIST); - - result.SetVectorType(VectorType::FLAT_VECTOR); - ListVector::SetListSize(result, 0); - - auto list_struct_data = FlatVector::GetData(result); - - // count all the splits and set up the list entries - auto &child_entry = ListVector::GetEntry(result); - auto &result_mask = FlatVector::Validity(result); - idx_t total_splits = 0; - for (idx_t i = 0; i < args.size(); i++) { - auto input_idx = input_data.sel->get_index(i); - auto delim_idx = delim_data.sel->get_index(i); - if (!input_data.validity.RowIsValid(input_idx)) { - result_mask.SetInvalid(i); - continue; - } - StringSplitInput split_input(result, child_entry, total_splits); - if (!delim_data.validity.RowIsValid(delim_idx)) { - // delim is NULL: copy the complete entry - split_input.AddSplit(inputs[input_idx].GetData(), inputs[input_idx].GetSize(), 0); - list_struct_data[i].length = 1; - list_struct_data[i].offset = total_splits; - total_splits++; - continue; - } - auto list_length = StringSplitter::Split(inputs[input_idx], delims[delim_idx], split_input, data); - list_struct_data[i].length = list_length; - list_struct_data[i].offset = total_splits; - total_splits += list_length; - } - ListVector::SetListSize(result, total_splits); - D_ASSERT(ListVector::GetListSize(result) == total_splits); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - StringVector::AddHeapReference(child_entry, args.data[0]); -} - -static void StringSplitFunction(DataChunk &args, ExpressionState &state, Vector &result) { - StringSplitExecutor(args, state, result, nullptr); -} - -static void StringSplitRegexFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - if (info.constant_pattern) { - // fast path: pre-compiled regex - auto &lstate = ExecuteFunctionState::GetFunctionState(state)->Cast(); - StringSplitExecutor(args, state, result, &lstate.constant_pattern); - } else { - // slow path: have to re-compile regex for every row - StringSplitExecutor(args, state, result); - } -} - -ScalarFunction StringSplitFun::GetFunction() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - - ScalarFunction string_split({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitFunction); - string_split.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return string_split; -} - -ScalarFunctionSet StringSplitRegexFun::GetFunctions() { - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - ScalarFunctionSet regexp_split; - ScalarFunction regex_fun({LogicalType::VARCHAR, LogicalType::VARCHAR}, varchar_list_type, StringSplitRegexFunction, - 
RegexpMatchesBind, nullptr, nullptr, RegexInitLocalState, LogicalType::INVALID, - FunctionStability::CONSISTENT, FunctionNullHandling::SPECIAL_HANDLING); - regexp_split.AddFunction(regex_fun); - // regexp options - regex_fun.arguments.emplace_back(LogicalType::VARCHAR); - regexp_split.AddFunction(regex_fun); - return regexp_split; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/strip_accents.cpp b/src/duckdb/src/function/scalar/string/strip_accents.cpp deleted file mode 100644 index 2ab7ca497..000000000 --- a/src/duckdb/src/function/scalar/string/strip_accents.cpp +++ /dev/null @@ -1,45 +0,0 @@ -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/function/scalar/string_functions.hpp" - -#include "utf8proc.hpp" - -namespace duckdb { - -bool IsAscii(const char *input, idx_t n) { - for (idx_t i = 0; i < n; i++) { - if (input[i] & 0x80) { - // non-ascii character - return false; - } - } - return true; -} - -struct StripAccentsOperator { - template <class INPUT_TYPE, class RESULT_TYPE> - static RESULT_TYPE Operation(INPUT_TYPE input, Vector &result) { - if (IsAscii(input.GetData(), input.GetSize())) { - return input; - } - - // non-ascii, perform collation - auto stripped = utf8proc_remove_accents((const utf8proc_uint8_t *)input.GetData(), - UnsafeNumericCast<utf8proc_ssize_t>(input.GetSize())); - auto result_str = StringVector::AddString(result, const_char_ptr_cast(stripped)); - free(stripped); - return result_str; - } -}; - -static void StripAccentsFunction(DataChunk &args, ExpressionState &state, Vector &result) { - D_ASSERT(args.ColumnCount() == 1); - - UnaryExecutor::ExecuteString<string_t, string_t, StripAccentsOperator>(args.data[0], result, args.size()); - StringVector::AddHeapReference(result, args.data[0]); -} - -ScalarFunction StripAccentsFun::GetFunction() { - return ScalarFunction("strip_accents", {LogicalType::VARCHAR}, LogicalType::VARCHAR, StripAccentsFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/substring.cpp b/src/duckdb/src/function/scalar/string/substring.cpp deleted file mode 100644 index 58c93624a..000000000 --- a/src/duckdb/src/function/scalar/string/substring.cpp +++ /dev/null @@ -1,337 +0,0 @@ -#include "duckdb/function/scalar/string_common.hpp" -#include "duckdb/function/scalar/string_functions.hpp" - -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/common/vector_operations/ternary_executor.hpp" - -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "utf8proc.hpp" -#include "duckdb/common/types/blob.hpp" - -namespace duckdb { - -static const int64_t SUPPORTED_UPPER_BOUND = NumericLimits<uint32_t>::Maximum(); -static const int64_t SUPPORTED_LOWER_BOUND = -SUPPORTED_UPPER_BOUND - 1; - -static inline void AssertInSupportedRange(idx_t input_size, int64_t offset, int64_t length) { - - if (input_size > (uint64_t)SUPPORTED_UPPER_BOUND) { - throw OutOfRangeException("Substring input size is too large (> %d)", SUPPORTED_UPPER_BOUND); - } - if (offset < SUPPORTED_LOWER_BOUND) { - throw OutOfRangeException("Substring offset outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); - } - if (offset > SUPPORTED_UPPER_BOUND) { - throw OutOfRangeException("Substring offset outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); - } - if (length < SUPPORTED_LOWER_BOUND) { - throw OutOfRangeException("Substring length outside of supported range (< %d)", SUPPORTED_LOWER_BOUND); - } - if (length > SUPPORTED_UPPER_BOUND) { - throw 
OutOfRangeException("Substring length outside of supported range (> %d)", SUPPORTED_UPPER_BOUND); - } -} - -string_t SubstringEmptyString(Vector &result) { - auto result_string = StringVector::EmptyString(result, 0); - result_string.Finalize(); - return result_string; -} - -string_t SubstringSlice(Vector &result, const char *input_data, int64_t offset, int64_t length) { - auto result_string = StringVector::EmptyString(result, UnsafeNumericCast<idx_t>(length)); - auto result_data = result_string.GetDataWriteable(); - memcpy(result_data, input_data + offset, UnsafeNumericCast<idx_t>(length)); - result_string.Finalize(); - return result_string; -} - -// compute start and end characters from the given input size and offset/length -bool SubstringStartEnd(int64_t input_size, int64_t offset, int64_t length, int64_t &start, int64_t &end) { - if (length == 0) { - return false; - } - if (offset > 0) { - // positive offset: scan from start - start = MinValue<int64_t>(input_size, offset - 1); - } else if (offset < 0) { - // negative offset: scan from end (i.e. start = end + offset) - start = MaxValue<int64_t>(input_size + offset, 0); - } else { - // offset = 0: special case, we start 1 character BEHIND the first character - start = 0; - length--; - if (length <= 0) { - return false; - } - } - if (length > 0) { - // positive length: go forward (i.e. end = start + offset) - end = MinValue<int64_t>(input_size, start + length); - } else { - // negative length: go backwards (i.e. end = start, start = start + length) - end = start; - start = MaxValue<int64_t>(0, start + length); - } - if (start == end) { - return false; - } - D_ASSERT(start < end); - return true; -} - -string_t SubstringASCII(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - int64_t start, end; - if (!SubstringStartEnd(UnsafeNumericCast<int64_t>(input_size), offset, length, start, end)) { - return SubstringEmptyString(result); - } - return SubstringSlice(result, input_data, start, UnsafeNumericCast<int64_t>(end - start)); -} - -string_t SubstringUnicode(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - if (length == 0) { - return SubstringEmptyString(result); - } - // first figure out which direction we need to scan - idx_t start_pos; - idx_t end_pos; - if (offset < 0) { - start_pos = 0; - end_pos = DConstants::INVALID_INDEX; - - // negative offset: scan backwards - int64_t start, end; - - // we express start and end as unicode codepoints from the back - offset--; - if (length < 0) { - // negative length - start = -offset - length; - end = -offset; - } else { - // positive length - start = -offset; - end = -offset - length; - } - if (end <= 0) { - end_pos = input_size; - } - int64_t current_character = 0; - for (idx_t i = input_size; i > 0; i--) { - if (IsCharacter(input_data[i - 1])) { - current_character++; - if (current_character == start) { - start_pos = i; - break; - } else if (current_character == end) { - end_pos = i; - } - } - } - while (!IsCharacter(input_data[start_pos])) { - start_pos++; - } - while (end_pos < input_size && !IsCharacter(input_data[end_pos])) { - end_pos++; - } - - if (end_pos == DConstants::INVALID_INDEX) { - return SubstringEmptyString(result); - } - } else { - start_pos = DConstants::INVALID_INDEX; - end_pos = input_size; - - // positive offset: scan forwards - int64_t start, end; - 
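- // Worked example for the forward scan below: substring('héllo', 2, 3) takes this branch with - // offset = 2 and length = 3; after offset-- the codepoint window is start = 1, end = 4. The loop - // then maps those codepoint positions to byte positions (start_pos = 1, end_pos = 5), skipping - // UTF-8 continuation bytes via IsCharacter, so SubstringSlice returns bytes [1, 5), i.e. 'éll'.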
- // we express start and end as unicode codepoints from the front - offset--; - if (length < 0) { - // negative length - start = MaxValue(0, offset + length); - end = offset; - } else { - // positive length - start = MaxValue(0, offset); - end = offset + length; - } - - int64_t current_character = 0; - for (idx_t i = 0; i < input_size; i++) { - if (IsCharacter(input_data[i])) { - if (current_character == start) { - start_pos = i; - } else if (current_character == end) { - end_pos = i; - break; - } - current_character++; - } - } - if (start_pos == DConstants::INVALID_INDEX || end == 0 || end <= start) { - return SubstringEmptyString(result); - } - } - D_ASSERT(end_pos >= start_pos); - // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, UnsafeNumericCast(start_pos), - UnsafeNumericCast(end_pos - start_pos)); -} - -string_t SubstringGrapheme(Vector &result, string_t input, int64_t offset, int64_t length) { - auto input_data = input.GetData(); - auto input_size = input.GetSize(); - - AssertInSupportedRange(input_size, offset, length); - - // we don't know yet if the substring is ascii, but we assume it is (for now) - // first get the start and end as if this was an ascii string - int64_t start, end; - if (!SubstringStartEnd(UnsafeNumericCast(input_size), offset, length, start, end)) { - return SubstringEmptyString(result); - } - - // now check if all the characters between 0 and end are ascii characters - // note that we scan one further to check for a potential combining diacritics (e.g. i + diacritic is ï) - bool is_ascii = true; - idx_t ascii_end = MinValue(UnsafeNumericCast(end + 1), input_size); - for (idx_t i = 0; i < ascii_end; i++) { - if (input_data[i] & 0x80) { - // found a non-ascii character: eek - is_ascii = false; - break; - } - } - if (is_ascii) { - // all characters are ascii, we can just slice the substring - return SubstringSlice(result, input_data, start, end - start); - } - // if the characters are not ascii, we need to scan grapheme clusters - // first figure out which direction we need to scan - // offset = 0 case is taken care of in SubstringStartEnd - if (offset < 0) { - // negative offset, this case is more difficult - // we first need to count the number of characters in the string - idx_t num_characters = Utf8Proc::GraphemeCount(input_data, input_size); - // now call substring start and end again, but with the number of unicode characters this time - SubstringStartEnd(UnsafeNumericCast(num_characters), offset, length, start, end); - } - - // now scan the graphemes of the string to find the positions of the start and end characters - int64_t current_character = 0; - idx_t start_pos = DConstants::INVALID_INDEX, end_pos = input_size; - for (auto cluster : Utf8Proc::GraphemeClusters(input_data, input_size)) { - if (current_character == start) { - start_pos = cluster.start; - } else if (current_character == end) { - end_pos = cluster.start; - break; - } - current_character++; - } - if (start_pos == DConstants::INVALID_INDEX) { - return SubstringEmptyString(result); - } - // after we have found these, we can slice the substring - return SubstringSlice(result, input_data, UnsafeNumericCast(start_pos), - UnsafeNumericCast(end_pos - start_pos)); -} - -struct SubstringUnicodeOp { - static string_t Substring(Vector &result, string_t input, int64_t offset, int64_t length) { - return SubstringUnicode(result, input, offset, length); - } -}; - -struct SubstringGraphemeOp { - static string_t Substring(Vector &result, string_t input, 
int64_t offset, int64_t length) { - return SubstringGrapheme(result, input, offset, length); - } -}; - -template -static void SubstringFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; - if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; - - TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), - [&](string_t input_string, int64_t offset, int64_t length) { - return OP::Substring(result, input_string, offset, length); - }); - } else { - BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { - return OP::Substring(result, input_string, offset, NumericLimits::Maximum()); - }); - } -} - -static void SubstringFunctionASCII(DataChunk &args, ExpressionState &state, Vector &result) { - auto &input_vector = args.data[0]; - auto &offset_vector = args.data[1]; - if (args.ColumnCount() == 3) { - auto &length_vector = args.data[2]; - - TernaryExecutor::Execute( - input_vector, offset_vector, length_vector, result, args.size(), - [&](string_t input_string, int64_t offset, int64_t length) { - return SubstringASCII(result, input_string, offset, length); - }); - } else { - BinaryExecutor::Execute( - input_vector, offset_vector, result, args.size(), [&](string_t input_string, int64_t offset) { - return SubstringASCII(result, input_string, offset, NumericLimits::Maximum()); - }); - } -} - -static unique_ptr SubstringPropagateStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - // can only propagate stats if the children have stats - // we only care about the stats of the first child (i.e. 
the string) - if (!StringStats::CanContainUnicode(child_stats[0])) { - expr.function.function = SubstringFunctionASCII; - } - return nullptr; -} - -ScalarFunctionSet SubstringFun::GetFunctions() { - ScalarFunctionSet substr("substring"); - substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::VARCHAR, SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - substr.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - return (substr); -} - -ScalarFunctionSet SubstringGraphemeFun::GetFunctions() { - ScalarFunctionSet substr_grapheme("substring_grapheme"); - substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT, LogicalType::BIGINT}, - LogicalType::VARCHAR, SubstringFunction, nullptr, - nullptr, SubstringPropagateStats)); - substr_grapheme.AddFunction(ScalarFunction({LogicalType::VARCHAR, LogicalType::BIGINT}, LogicalType::VARCHAR, - SubstringFunction, nullptr, nullptr, - SubstringPropagateStats)); - return (substr_grapheme); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/string/suffix.cpp b/src/duckdb/src/function/scalar/string/suffix.cpp deleted file mode 100644 index 21175f61d..000000000 --- a/src/duckdb/src/function/scalar/string/suffix.cpp +++ /dev/null @@ -1,43 +0,0 @@ -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/common/types/string_type.hpp" - -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -static bool SuffixFunction(const string_t &str, const string_t &suffix); - -struct SuffixOperator { - template - static inline TR Operation(TA left, TB right) { - return SuffixFunction(left, right); - } -}; - -static bool SuffixFunction(const string_t &str, const string_t &suffix) { - auto suffix_size = suffix.GetSize(); - auto str_size = str.GetSize(); - if (suffix_size > str_size) { - return false; - } - - auto suffix_data = suffix.GetData(); - auto str_data = str.GetData(); - auto suf_idx = UnsafeNumericCast(suffix_size) - 1; - idx_t str_idx = str_size - 1; - for (; suf_idx >= 0; --suf_idx, --str_idx) { - if (suffix_data[suf_idx] != str_data[str_idx]) { - return false; - } - } - return true; -} - -ScalarFunction SuffixFun::GetFunction() { - return ScalarFunction("suffix", // name of the function - {LogicalType::VARCHAR, LogicalType::VARCHAR}, // argument list - LogicalType::BOOLEAN, // return type - ScalarFunction::BinaryFunction); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/struct/struct_concat.cpp b/src/duckdb/src/function/scalar/struct/struct_concat.cpp deleted file mode 100644 index f5ed780e8..000000000 --- a/src/duckdb/src/function/scalar/struct/struct_concat.cpp +++ /dev/null @@ -1,115 +0,0 @@ -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/struct_functions.hpp" -#include "duckdb/common/case_insensitive_map.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/storage/statistics/struct_stats.hpp" - -namespace duckdb { - -static void StructConcatFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &result_cols = StructVector::GetEntries(result); - idx_t offset = 0; - - if (!args.AllConstant()) { - // Unless all arguments are constant, we flatten the input to make sure it's homogeneous - args.Flatten(); - } - - for (auto &arg : args.data) 
{ - const auto &child_cols = StructVector::GetEntries(arg); - for (auto &child_col : child_cols) { - result_cols[offset++]->Reference(*child_col); - } - } - D_ASSERT(offset == result_cols.size()); - - if (args.AllConstant()) { - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } - - result.Verify(args.size()); -} - -static unique_ptr StructConcatBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // collect names and deconflict, construct return type - if (arguments.empty()) { - throw InvalidInputException("struct_concat: At least one argument is required"); - } - - child_list_t combined_children; - case_insensitive_set_t name_set; - - bool has_unnamed = false; - - for (idx_t arg_idx = 0; arg_idx < arguments.size(); arg_idx++) { - const auto &arg = arguments[arg_idx]; - - if (arg->return_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - - if (arg->return_type.id() != LogicalTypeId::STRUCT) { - throw InvalidInputException("struct_concat: Argument at position \"%d\" is not a STRUCT", arg_idx + 1); - } - - const auto &child_types = StructType::GetChildTypes(arg->return_type); - for (const auto &child : child_types) { - if (!child.first.empty()) { - auto it = name_set.find(child.first); - if (it != name_set.end()) { - if (*it == child.first) { - throw InvalidInputException("struct_concat: Arguments contain duplicate STRUCT entry \"%s\"", - child.first); - } - throw InvalidInputException( - "struct_concat: Arguments contain case-insensitive duplicate STRUCT entry \"%s\" and \"%s\"", - child.first, *it); - } - name_set.insert(child.first); - } else { - has_unnamed = true; - } - combined_children.push_back(child); - } - } - - if (has_unnamed && !name_set.empty()) { - throw InvalidInputException("struct_concat: Cannot mix named and unnamed STRUCTs"); - } - - bound_function.return_type = LogicalType::STRUCT(combined_children); - return nullptr; -} - -unique_ptr StructConcatStats(ClientContext &context, FunctionStatisticsInput &input) { - const auto &expr = input.expr; - - auto &arg_stats = input.child_stats; - auto &arg_exprs = input.expr.children; - - auto struct_stats = StructStats::CreateUnknown(expr.return_type); - idx_t struct_index = 0; - - for (idx_t arg_idx = 0; arg_idx < arg_exprs.size(); arg_idx++) { - auto &arg_stat = arg_stats[arg_idx]; - auto &arg_type = arg_exprs[arg_idx]->return_type; - for (idx_t child_idx = 0; child_idx < StructType::GetChildCount(arg_type); child_idx++) { - auto &child_stat = StructStats::GetChildStats(arg_stat, child_idx); - StructStats::SetChildStats(struct_stats, struct_index++, child_stat); - } - } - return struct_stats.ToUnique(); -} - -ScalarFunction StructConcatFun::GetFunction() { - ScalarFunction fun("struct_concat", {}, LogicalTypeId::STRUCT, StructConcatFunction, StructConcatBind, nullptr, - StructConcatStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - return fun; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/struct/struct_extract.cpp b/src/duckdb/src/function/scalar/struct/struct_extract.cpp deleted file mode 100644 index 20cc74157..000000000 --- a/src/duckdb/src/function/scalar/struct/struct_extract.cpp +++ /dev/null @@ -1,179 +0,0 @@ -#include "duckdb/common/string_util.hpp" -#include "duckdb/execution/expression_executor.hpp" -#include "duckdb/function/scalar/struct_functions.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include 
"duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/storage/statistics/struct_stats.hpp" -#include "duckdb/function/scalar/struct_utils.hpp" - -namespace duckdb { - -static void StructExtractFunction(DataChunk &args, ExpressionState &state, Vector &result) { - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - - // this should be guaranteed by the binder - auto &vec = args.data[0]; - - vec.Verify(args.size()); - auto &children = StructVector::GetEntries(vec); - D_ASSERT(info.index < children.size()); - auto &struct_child = children[info.index]; - result.Reference(*struct_child); - result.Verify(args.size()); -} - -static unique_ptr StructExtractBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - D_ASSERT(bound_function.arguments.size() == 2); - auto &child_type = arguments[0]->return_type; - if (child_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - D_ASSERT(LogicalTypeId::STRUCT == child_type.id()); - auto &struct_children = StructType::GetChildTypes(child_type); - if (struct_children.empty()) { - throw InternalException("Can't extract something from an empty struct"); - } - if (StructType::IsUnnamed(child_type)) { - throw BinderException( - "struct_extract with a string key cannot be used on an unnamed struct, use a numeric index instead"); - } - bound_function.arguments[0] = child_type; - - auto &key_child = arguments[1]; - if (key_child->HasParameter()) { - throw ParameterNotResolvedException(); - } - - if (key_child->return_type.id() != LogicalTypeId::VARCHAR || !key_child->IsFoldable()) { - throw BinderException("Key name for struct_extract needs to be a constant string"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - D_ASSERT(key_val.type().id() == LogicalTypeId::VARCHAR); - auto &key_str = StringValue::Get(key_val); - if (key_val.IsNull() || key_str.empty()) { - throw BinderException("Key name for struct_extract needs to be neither NULL nor empty"); - } - string key = StringUtil::Lower(key_str); - - LogicalType return_type; - idx_t key_index = 0; - bool found_key = false; - - for (size_t i = 0; i < struct_children.size(); i++) { - auto &child = struct_children[i]; - if (StringUtil::Lower(child.first) == key) { - found_key = true; - key_index = i; - return_type = child.second; - break; - } - } - - if (!found_key) { - vector candidates; - candidates.reserve(struct_children.size()); - for (auto &struct_child : struct_children) { - candidates.push_back(struct_child.first); - } - auto closest_settings = StringUtil::TopNJaroWinkler(candidates, key); - auto message = StringUtil::CandidatesMessage(closest_settings, "Candidate Entries"); - throw BinderException("Could not find key \"%s\" in struct\n%s", key, message); - } - - bound_function.return_type = std::move(return_type); - return GetBindData(key_index); -} - -static unique_ptr StructExtractBindInternal(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments, - bool struct_extract) { - D_ASSERT(bound_function.arguments.size() == 2); - auto &child_type = arguments[0]->return_type; - if (child_type.id() == LogicalTypeId::UNKNOWN) { - throw ParameterNotResolvedException(); - } - D_ASSERT(LogicalTypeId::STRUCT == child_type.id()); - auto &struct_children = StructType::GetChildTypes(child_type); - if (struct_children.empty()) { - throw InternalException("Can't extract something from an empty 
struct"); - } - if (struct_extract && !StructType::IsUnnamed(child_type)) { - throw BinderException( - "struct_extract with an integer key can only be used on unnamed structs, use a string key instead"); - } - bound_function.arguments[0] = child_type; - - auto &key_child = arguments[1]; - if (key_child->HasParameter()) { - throw ParameterNotResolvedException(); - } - - if (!key_child->IsFoldable()) { - throw BinderException("Key index for struct_extract needs to be a constant value"); - } - Value key_val = ExpressionExecutor::EvaluateScalar(context, *key_child); - auto index = key_val.GetValue(); - if (index <= 0 || idx_t(index) > struct_children.size()) { - throw BinderException("Key index %lld for struct_extract out of range - expected an index between 1 and %llu", - index, struct_children.size()); - } - bound_function.return_type = struct_children[NumericCast(index - 1)].second; - return GetBindData(NumericCast(index - 1)); -} - -static unique_ptr StructExtractBindIndex(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return StructExtractBindInternal(context, bound_function, arguments, true); -} - -static unique_ptr StructExtractAtBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - return StructExtractBindInternal(context, bound_function, arguments, false); -} - -static unique_ptr PropagateStructExtractStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &bind_data = input.bind_data; - - auto &info = bind_data->Cast(); - auto struct_child_stats = StructStats::GetChildStats(child_stats[0]); - return struct_child_stats[info.index].ToUnique(); -} - -unique_ptr GetBindData(idx_t index) { - return make_uniq(index); -} - -ScalarFunction GetKeyExtractFunction() { - return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::VARCHAR}, LogicalType::ANY, - StructExtractFunction, StructExtractBind, nullptr, PropagateStructExtractStats); -} - -ScalarFunction GetIndexExtractFunction() { - return ScalarFunction("struct_extract", {LogicalTypeId::STRUCT, LogicalType::BIGINT}, LogicalType::ANY, - StructExtractFunction, StructExtractBindIndex); -} - -ScalarFunction GetExtractAtFunction() { - return ScalarFunction("struct_extract_at", {LogicalTypeId::STRUCT, LogicalType::BIGINT}, LogicalType::ANY, - StructExtractFunction, StructExtractAtBind); -} - -ScalarFunctionSet StructExtractFun::GetFunctions() { - // the arguments and return types are actually set in the binder function - ScalarFunctionSet struct_extract_set("struct_extract"); - struct_extract_set.AddFunction(GetKeyExtractFunction()); - struct_extract_set.AddFunction(GetIndexExtractFunction()); - return struct_extract_set; -} - -ScalarFunctionSet StructExtractAtFun::GetFunctions() { - ScalarFunctionSet struct_extractat_set("struct_extract_at"); - struct_extractat_set.AddFunction(GetExtractAtFunction()); - return struct_extractat_set; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/struct/struct_pack.cpp b/src/duckdb/src/function/scalar/struct/struct_pack.cpp deleted file mode 100644 index 51a3e34ca..000000000 --- a/src/duckdb/src/function/scalar/struct/struct_pack.cpp +++ /dev/null @@ -1,92 +0,0 @@ -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/scalar/struct_functions.hpp" -#include "duckdb/common/case_insensitive_map.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include 
"duckdb/planner/expression/bound_function_expression.hpp" -#include "duckdb/planner/expression_binder.hpp" -#include "duckdb/storage/statistics/struct_stats.hpp" - -namespace duckdb { - -static void StructPackFunction(DataChunk &args, ExpressionState &state, Vector &result) { -#ifdef DEBUG - auto &func_expr = state.expr.Cast(); - auto &info = func_expr.bind_info->Cast(); - // this should never happen if the binder below is sane - D_ASSERT(args.ColumnCount() == StructType::GetChildTypes(info.stype).size()); -#endif - bool all_const = true; - auto &child_entries = StructVector::GetEntries(result); - for (idx_t i = 0; i < args.ColumnCount(); i++) { - if (args.data[i].GetVectorType() != VectorType::CONSTANT_VECTOR) { - all_const = false; - } - // same holds for this - child_entries[i]->Reference(args.data[i]); - } - result.SetVectorType(all_const ? VectorType::CONSTANT_VECTOR : VectorType::FLAT_VECTOR); - result.Verify(args.size()); -} - -template -static unique_ptr StructPackBind(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - case_insensitive_set_t name_collision_set; - - // collect names and deconflict, construct return type - if (arguments.empty()) { - throw InvalidInputException("Can't pack nothing into a struct"); - } - child_list_t struct_children; - for (idx_t i = 0; i < arguments.size(); i++) { - auto &child = arguments[i]; - string alias; - if (IS_STRUCT_PACK) { - if (child->GetAlias().empty()) { - throw BinderException("Need named argument for struct pack, e.g. STRUCT_PACK(a := b)"); - } - alias = child->GetAlias(); - if (name_collision_set.find(alias) != name_collision_set.end()) { - throw BinderException("Duplicate struct entry name \"%s\"", alias); - } - name_collision_set.insert(alias); - } - struct_children.push_back(make_pair(alias, arguments[i]->return_type)); - } - - // this is more for completeness reasons - bound_function.return_type = LogicalType::STRUCT(struct_children); - return make_uniq(bound_function.return_type); -} - -unique_ptr StructPackStats(ClientContext &context, FunctionStatisticsInput &input) { - auto &child_stats = input.child_stats; - auto &expr = input.expr; - auto struct_stats = StructStats::CreateUnknown(expr.return_type); - for (idx_t i = 0; i < child_stats.size(); i++) { - StructStats::SetChildStats(struct_stats, i, child_stats[i]); - } - return struct_stats.ToUnique(); -} - -template -ScalarFunction GetStructPackFunction() { - ScalarFunction fun(IS_STRUCT_PACK ? 
"struct_pack" : "row", {}, LogicalTypeId::STRUCT, StructPackFunction, - StructPackBind, nullptr, StructPackStats); - fun.varargs = LogicalType::ANY; - fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - fun.serialize = VariableReturnBindData::Serialize; - fun.deserialize = VariableReturnBindData::Deserialize; - return fun; -} - -ScalarFunction StructPackFun::GetFunction() { - return GetStructPackFunction(); -} - -ScalarFunction RowFun::GetFunction() { - return GetStructPackFunction(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar/system/aggregate_export.cpp b/src/duckdb/src/function/scalar/system/aggregate_export.cpp deleted file mode 100644 index 64fee3e8c..000000000 --- a/src/duckdb/src/function/scalar/system/aggregate_export.cpp +++ /dev/null @@ -1,363 +0,0 @@ -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/function/function_binder.hpp" -#include "duckdb/function/scalar/generic_common.hpp" -#include "duckdb/function/scalar/system_functions.hpp" -#include "duckdb/function/scalar/generic_functions.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/planner/expression/bound_aggregate_expression.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_function_expression.hpp" - -namespace duckdb { - -// aggregate state export -struct ExportAggregateBindData : public FunctionData { - AggregateFunction aggr; - idx_t state_size; - - explicit ExportAggregateBindData(AggregateFunction aggr_p, idx_t state_size_p) - : aggr(std::move(aggr_p)), state_size(state_size_p) { - } - - unique_ptr Copy() const override { - return make_uniq(aggr, state_size); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return aggr == other.aggr && state_size == other.state_size; - } - - static ExportAggregateBindData &GetFrom(ExpressionState &state) { - auto &func_expr = state.expr.Cast(); - return func_expr.bind_info->Cast(); - } -}; - -struct CombineState : public FunctionLocalState { - idx_t state_size; - - unsafe_unique_array state_buffer0, state_buffer1; - Vector state_vector0, state_vector1; - - ArenaAllocator allocator; - - explicit CombineState(idx_t state_size_p) - : state_size(state_size_p), state_buffer0(make_unsafe_uniq_array(state_size_p)), - state_buffer1(make_unsafe_uniq_array(state_size_p)), - state_vector0(Value::POINTER(CastPointerToValue(state_buffer0.get()))), - state_vector1(Value::POINTER(CastPointerToValue(state_buffer1.get()))), - allocator(Allocator::DefaultAllocator()) { - } -}; - -static unique_ptr InitCombineState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.state_size); -} - -struct FinalizeState : public FunctionLocalState { - idx_t state_size; - unsafe_unique_array state_buffer; - Vector addresses; - - ArenaAllocator allocator; - - explicit FinalizeState(idx_t state_size_p) - : state_size(state_size_p), - state_buffer(make_unsafe_uniq_array(STANDARD_VECTOR_SIZE * AlignValue(state_size_p))), - addresses(LogicalType::POINTER), allocator(Allocator::DefaultAllocator()) { - } -}; - -static unique_ptr InitFinalizeState(ExpressionState &state, const BoundFunctionExpression &expr, - FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.state_size); -} - -static void AggregateStateFinalize(DataChunk 
&input, ExpressionState &state_p, Vector &result) { - auto &bind_data = ExportAggregateBindData::GetFrom(state_p); - auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); - local_state.allocator.Reset(); - - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr)); - D_ASSERT(input.data.size() == 1); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); - auto aligned_state_size = AlignValue(bind_data.state_size); - - auto state_vec_ptr = FlatVector::GetData(local_state.addresses); - - UnifiedVectorFormat state_data; - input.data[0].ToUnifiedFormat(input.size(), state_data); - for (idx_t i = 0; i < input.size(); i++) { - auto state_idx = state_data.sel->get_index(i); - auto state_entry = UnifiedVectorFormat::GetData(state_data) + state_idx; - auto target_ptr = char_ptr_cast(local_state.state_buffer.get()) + aligned_state_size * i; - - if (state_data.validity.RowIsValid(state_idx)) { - D_ASSERT(state_entry->GetSize() == bind_data.state_size); - memcpy((void *)target_ptr, state_entry->GetData(), bind_data.state_size); - } else { - // create a dummy state because finalize does not understand NULLs in its input - // we put the NULL back in explicitly below - bind_data.aggr.initialize(bind_data.aggr, data_ptr_cast(target_ptr)); - } - state_vec_ptr[i] = data_ptr_cast(target_ptr); - } - - AggregateInputData aggr_input_data(nullptr, local_state.allocator); - bind_data.aggr.finalize(local_state.addresses, aggr_input_data, result, input.size(), 0); - - for (idx_t i = 0; i < input.size(); i++) { - auto state_idx = state_data.sel->get_index(i); - if (!state_data.validity.RowIsValid(state_idx)) { - FlatVector::SetNull(result, i, true); - } - } -} - -static void AggregateStateCombine(DataChunk &input, ExpressionState &state_p, Vector &result) { - auto &bind_data = ExportAggregateBindData::GetFrom(state_p); - auto &local_state = ExecuteFunctionState::GetFunctionState(state_p)->Cast(); - local_state.allocator.Reset(); - - D_ASSERT(bind_data.state_size == bind_data.aggr.state_size(bind_data.aggr)); - - D_ASSERT(input.data.size() == 2); - D_ASSERT(input.data[0].GetType().id() == LogicalTypeId::AGGREGATE_STATE); - D_ASSERT(input.data[0].GetType() == result.GetType()); - - if (input.data[0].GetType().InternalType() != input.data[1].GetType().InternalType()) { - throw IOException("Aggregate state combine type mismatch, expect %s, got %s", - input.data[0].GetType().ToString(), input.data[1].GetType().ToString()); - } - - UnifiedVectorFormat state0_data, state1_data; - input.data[0].ToUnifiedFormat(input.size(), state0_data); - input.data[1].ToUnifiedFormat(input.size(), state1_data); - - auto result_ptr = FlatVector::GetData(result); - - for (idx_t i = 0; i < input.size(); i++) { - auto state0_idx = state0_data.sel->get_index(i); - auto state1_idx = state1_data.sel->get_index(i); - - auto &state0 = UnifiedVectorFormat::GetData(state0_data)[state0_idx]; - auto &state1 = UnifiedVectorFormat::GetData(state1_data)[state1_idx]; - - // if both are NULL, we return NULL. 
If either of them is not, the result is that one - if (!state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { - FlatVector::SetNull(result, i, true); - continue; - } - if (state0_data.validity.RowIsValid(state0_idx) && !state1_data.validity.RowIsValid(state1_idx)) { - result_ptr[i] = - StringVector::AddStringOrBlob(result, const_char_ptr_cast(state0.GetData()), bind_data.state_size); - continue; - } - if (!state0_data.validity.RowIsValid(state0_idx) && state1_data.validity.RowIsValid(state1_idx)) { - result_ptr[i] = - StringVector::AddStringOrBlob(result, const_char_ptr_cast(state1.GetData()), bind_data.state_size); - continue; - } - - // we actually have to combine - if (state0.GetSize() != bind_data.state_size || state1.GetSize() != bind_data.state_size) { - throw IOException("Aggregate state size mismatch, expect %llu, got %llu and %llu", bind_data.state_size, - state0.GetSize(), state1.GetSize()); - } - - memcpy(local_state.state_buffer0.get(), state0.GetData(), bind_data.state_size); - memcpy(local_state.state_buffer1.get(), state1.GetData(), bind_data.state_size); - - AggregateInputData aggr_input_data(nullptr, local_state.allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - bind_data.aggr.combine(local_state.state_vector0, local_state.state_vector1, aggr_input_data, 1); - - result_ptr[i] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(local_state.state_buffer1.get()), - bind_data.state_size); - } -} - -static unique_ptr BindAggregateState(ClientContext &context, ScalarFunction &bound_function, - vector> &arguments) { - - // grab the aggregate type and bind the aggregate again - - // the aggregate name and types are in the logical type of the aggregate state, make sure its sane - auto &arg_return_type = arguments[0]->return_type; - for (auto &arg_type : bound_function.arguments) { - arg_type = arg_return_type; - } - - if (arg_return_type.id() != LogicalTypeId::AGGREGATE_STATE) { - throw BinderException("Can only FINALIZE aggregate state, not %s", arg_return_type.ToString()); - } - // combine - if (arguments.size() == 2 && arguments[0]->return_type != arguments[1]->return_type && - arguments[1]->return_type.id() != LogicalTypeId::BLOB) { - throw BinderException("Cannot COMBINE aggregate states from different functions, %s <> %s", - arguments[0]->return_type.ToString(), arguments[1]->return_type.ToString()); - } - - // following error states are only reachable when someone messes up creating the state_type - // which is impossible from SQL - - auto state_type = AggregateStateType::GetStateType(arg_return_type); - - // now we can look up the function in the catalog again and bind it - auto &func = Catalog::GetSystemCatalog(context).GetEntry(context, CatalogType::SCALAR_FUNCTION_ENTRY, - DEFAULT_SCHEMA, state_type.function_name); - if (func.type != CatalogType::AGGREGATE_FUNCTION_ENTRY) { - throw InternalException("Could not find aggregate %s", state_type.function_name); - } - auto &aggr = func.Cast(); - - ErrorData error; - - FunctionBinder function_binder(context); - auto best_function = - function_binder.BindFunction(aggr.name, aggr.functions, state_type.bound_argument_types, error); - if (!best_function.IsValid()) { - throw InternalException("Could not re-bind exported aggregate %s: %s", state_type.function_name, - error.Message()); - } - auto bound_aggr = aggr.functions.GetFunctionByOffset(best_function.GetIndex()); - if (bound_aggr.bind) { - // FIXME: this is really hacky - // but the aggregate state export needs a rework 
around how it handles more complex aggregates anyway - vector> args; - args.reserve(state_type.bound_argument_types.size()); - for (auto &arg_type : state_type.bound_argument_types) { - args.push_back(make_uniq(Value(arg_type))); - } - auto bind_info = bound_aggr.bind(context, bound_aggr, args); - if (bind_info) { - throw BinderException("Aggregate function with bind info not supported yet in aggregate state export"); - } - } - - if (bound_aggr.return_type != state_type.return_type || bound_aggr.arguments != state_type.bound_argument_types) { - throw InternalException("Type mismatch for exported aggregate %s", state_type.function_name); - } - - if (bound_function.name == "finalize") { - bound_function.return_type = bound_aggr.return_type; - } else { - D_ASSERT(bound_function.name == "combine"); - bound_function.return_type = arg_return_type; - } - - return make_uniq(bound_aggr, bound_aggr.state_size(bound_aggr)); -} - -static void ExportAggregateFinalize(Vector &state, AggregateInputData &aggr_input_data, Vector &result, idx_t count, - idx_t offset) { - D_ASSERT(offset == 0); - auto &bind_data = aggr_input_data.bind_data->Cast(); - auto state_size = bind_data.aggregate->function.state_size(bind_data.aggregate->function); - auto blob_ptr = FlatVector::GetData(result); - auto addresses_ptr = FlatVector::GetData(state); - for (idx_t row_idx = 0; row_idx < count; row_idx++) { - auto data_ptr = addresses_ptr[row_idx]; - blob_ptr[row_idx] = StringVector::AddStringOrBlob(result, const_char_ptr_cast(data_ptr), state_size); - } -} - -ExportAggregateFunctionBindData::ExportAggregateFunctionBindData(unique_ptr aggregate_p) { - D_ASSERT(aggregate_p->GetExpressionType() == ExpressionType::BOUND_AGGREGATE); - aggregate = unique_ptr_cast(std::move(aggregate_p)); -} - -unique_ptr ExportAggregateFunctionBindData::Copy() const { - return make_uniq(aggregate->Copy()); -} - -bool ExportAggregateFunctionBindData::Equals(const FunctionData &other_p) const { - auto &other = other_p.Cast(); - return aggregate->Equals(*other.aggregate); -} - -static void ExportStateAggregateSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const AggregateFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -static unique_ptr ExportStateAggregateDeserialize(Deserializer &deserializer, - AggregateFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); -} - -static void ExportStateScalarSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const ScalarFunction &function) { - throw NotImplementedException("FIXME: export state serialize"); -} - -static unique_ptr ExportStateScalarDeserialize(Deserializer &deserializer, ScalarFunction &function) { - throw NotImplementedException("FIXME: export state deserialize"); -} - -unique_ptr -ExportAggregateFunction::Bind(unique_ptr child_aggregate) { - auto &bound_function = child_aggregate->function; - if (!bound_function.combine) { - throw BinderException("Cannot use EXPORT_STATE for non-combinable function %s", bound_function.name); - } - if (bound_function.bind) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom binders"); - } - if (bound_function.destructor) { - throw BinderException("Cannot use EXPORT_STATE on aggregate functions with custom destructors"); - } - // this should be required - D_ASSERT(bound_function.state_size); - D_ASSERT(bound_function.finalize); - - D_ASSERT(child_aggregate->function.return_type.id() != LogicalTypeId::INVALID); 
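- // The export wrapper built below reuses the child aggregate's state_size/initialize/update/combine - // callbacks unchanged and only swaps in ExportAggregateFinalize, so finalizing emits the raw state - // bytes as an AGGREGATE_STATE value instead of the aggregate's result; the 'finalize' and 'combine' - // scalar functions above later re-bind such a state to the original aggregate through the catalog. - // Illustrative usage (a sketch; the EXPORT_STATE modifier syntax is assumed from DuckDB's aggregate - // state tests, not from this file): - // SELECT finalize(combine(s1, s2)) FROM (SELECT sum(i) EXPORT_STATE AS s1, sum(j) EXPORT_STATE AS s2 FROM t);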
-#ifdef DEBUG - for (auto &arg_type : child_aggregate->function.arguments) { - D_ASSERT(arg_type.id() != LogicalTypeId::INVALID); - } -#endif - auto export_bind_data = make_uniq(child_aggregate->Copy()); - aggregate_state_t state_type(child_aggregate->function.name, child_aggregate->function.return_type, - child_aggregate->function.arguments); - auto return_type = LogicalType::AGGREGATE_STATE(std::move(state_type)); - - auto export_function = - AggregateFunction("aggregate_state_export_" + bound_function.name, bound_function.arguments, return_type, - bound_function.state_size, bound_function.initialize, bound_function.update, - bound_function.combine, ExportAggregateFinalize, bound_function.simple_update, - /* can't bind this again */ nullptr, /* no dynamic state yet */ nullptr, - /* can't propagate statistics */ nullptr, nullptr); - export_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - export_function.serialize = ExportStateAggregateSerialize; - export_function.deserialize = ExportStateAggregateDeserialize; - - return make_uniq(export_function, std::move(child_aggregate->children), - std::move(child_aggregate->filter), std::move(export_bind_data), - child_aggregate->aggr_type); -} - -ScalarFunction FinalizeFun::GetFunction() { - auto result = ScalarFunction("finalize", {LogicalTypeId::AGGREGATE_STATE}, LogicalTypeId::INVALID, - AggregateStateFinalize, BindAggregateState, nullptr, nullptr, InitFinalizeState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.serialize = ExportStateScalarSerialize; - result.deserialize = ExportStateScalarDeserialize; - return result; -} - -ScalarFunction CombineFun::GetFunction() { - auto result = - ScalarFunction("combine", {LogicalTypeId::AGGREGATE_STATE, LogicalTypeId::ANY}, LogicalTypeId::AGGREGATE_STATE, - AggregateStateCombine, BindAggregateState, nullptr, nullptr, InitCombineState); - result.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - result.serialize = ExportStateScalarSerialize; - result.deserialize = ExportStateScalarDeserialize; - return result; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar_function.cpp b/src/duckdb/src/function/scalar_function.cpp deleted file mode 100644 index a627643fd..000000000 --- a/src/duckdb/src/function/scalar_function.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include "duckdb/function/scalar_function.hpp" - -namespace duckdb { - -FunctionLocalState::~FunctionLocalState() { -} - -ScalarFunctionInfo::~ScalarFunctionInfo() { -} - -ScalarFunction::ScalarFunction(string name, vector arguments, LogicalType return_type, - scalar_function_t function, bind_scalar_function_t bind, - bind_scalar_function_extended_t bind_extended, function_statistics_t statistics, - init_local_state_t init_local_state, LogicalType varargs, FunctionStability side_effects, - FunctionNullHandling null_handling, bind_lambda_function_t bind_lambda) - : BaseScalarFunction(std::move(name), std::move(arguments), std::move(return_type), side_effects, - std::move(varargs), null_handling), - function(std::move(function)), bind(bind), bind_extended(bind_extended), init_local_state(init_local_state), - statistics(statistics), bind_lambda(bind_lambda), bind_expression(nullptr), get_modified_databases(nullptr), - serialize(nullptr), deserialize(nullptr) { -} - -ScalarFunction::ScalarFunction(vector arguments, LogicalType return_type, scalar_function_t function, - bind_scalar_function_t bind, bind_scalar_function_extended_t bind_extended, - function_statistics_t statistics, 
init_local_state_t init_local_state, - LogicalType varargs, FunctionStability side_effects, FunctionNullHandling null_handling, - bind_lambda_function_t bind_lambda) - : ScalarFunction(string(), std::move(arguments), std::move(return_type), std::move(function), bind, bind_extended, - statistics, init_local_state, std::move(varargs), side_effects, null_handling, bind_lambda) { -} - -bool ScalarFunction::operator==(const ScalarFunction &rhs) const { - return name == rhs.name && arguments == rhs.arguments && return_type == rhs.return_type && varargs == rhs.varargs && - bind == rhs.bind && bind_extended == rhs.bind_extended && statistics == rhs.statistics && - bind_lambda == rhs.bind_lambda; -} - -bool ScalarFunction::operator!=(const ScalarFunction &rhs) const { - return !(*this == rhs); -} - -bool ScalarFunction::Equal(const ScalarFunction &rhs) const { - // number of types - if (this->arguments.size() != rhs.arguments.size()) { - return false; - } - // argument types - for (idx_t i = 0; i < this->arguments.size(); ++i) { - if (this->arguments[i] != rhs.arguments[i]) { - return false; - } - } - // return type - if (this->return_type != rhs.return_type) { - return false; - } - // varargs - if (this->varargs != rhs.varargs) { - return false; - } - - return true; // they are equal -} - -void ScalarFunction::NopFunction(DataChunk &input, ExpressionState &state, Vector &result) { - D_ASSERT(input.ColumnCount() >= 1); - result.Reference(input.data[0]); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/scalar_macro_function.cpp b/src/duckdb/src/function/scalar_macro_function.cpp deleted file mode 100644 index 9f3d79c3f..000000000 --- a/src/duckdb/src/function/scalar_macro_function.cpp +++ /dev/null @@ -1,52 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/scalar_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// - -#include "duckdb/function/scalar_macro_function.hpp" - -#include "duckdb/function/macro_function.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/parsed_expression_iterator.hpp" - -namespace duckdb { - -ScalarMacroFunction::ScalarMacroFunction(unique_ptr expression) - : MacroFunction(MacroType::SCALAR_MACRO), expression(std::move(expression)) { -} - -ScalarMacroFunction::ScalarMacroFunction(void) : MacroFunction(MacroType::SCALAR_MACRO) { -} - -unique_ptr ScalarMacroFunction::Copy() const { - auto result = make_uniq(); - result->expression = expression->Copy(); - CopyProperties(*result); - - return std::move(result); -} - -void RemoveQualificationRecursive(unique_ptr &expr) { - if (expr->GetExpressionType() == ExpressionType::COLUMN_REF) { - auto &col_ref = expr->Cast(); - auto &col_names = col_ref.column_names; - if (col_names.size() == 2 && col_names[0].find(DummyBinding::DUMMY_NAME) != string::npos) { - col_names.erase(col_names.begin()); - } - } else { - ParsedExpressionIterator::EnumerateChildren( - *expr, [](unique_ptr &child) { RemoveQualificationRecursive(child); }); - } -} - -string ScalarMacroFunction::ToSQL() const { - // In case of nested macro's we need to fix it a bit - auto expression_copy = expression->Copy(); - RemoveQualificationRecursive(expression_copy); - return MacroFunction::ToSQL() + StringUtil::Format("(%s)", expression_copy->ToString()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow.cpp b/src/duckdb/src/function/table/arrow.cpp deleted file 
mode 100644 index 5d719e8f7..000000000 --- a/src/duckdb/src/function/table/arrow.cpp +++ /dev/null @@ -1,621 +0,0 @@ -#include "duckdb/common/arrow/arrow.hpp" - -#include "duckdb.hpp" -#include "duckdb/common/arrow/arrow_wrapper.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/vector_buffer.hpp" -#include "duckdb/function/table/arrow.hpp" -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" -#include "duckdb/function/table/arrow/arrow_type_info.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/parser/parsed_data/create_table_function_info.hpp" -#include "duckdb/parser/tableref/table_function_ref.hpp" -#include "utf8proc_wrapper.hpp" -#include "duckdb/common/arrow/schema_metadata.hpp" - -namespace duckdb { - -static unique_ptr CreateListType(ArrowSchema &child, ArrowVariableSizeType size_type, bool view) { - auto child_type = ArrowTableFunction::GetArrowLogicalType(child); - - unique_ptr type_info; - auto type = LogicalType::LIST(child_type->GetDuckType()); - if (view) { - type_info = ArrowListInfo::ListView(std::move(child_type), size_type); - } else { - type_info = ArrowListInfo::List(std::move(child_type), size_type); - } - return make_uniq(type, std::move(type_info)); -} - -static unique_ptr GetArrowExtensionType(const ArrowSchemaMetadata &extension_type, const string &format) { - auto arrow_extension = extension_type.GetExtensionName(); - // Check for arrow canonical extensions - if (arrow_extension == "arrow.uuid") { - if (format != "w:16") { - std::ostringstream error; - error - << "arrow.uuid must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly defined as:" - << format; - return make_uniq(error.str()); - } - return make_uniq(LogicalType::UUID); - } else if (arrow_extension == "arrow.json") { - if (format == "u") { - return make_uniq(LogicalType::JSON(), make_uniq(ArrowVariableSizeType::NORMAL)); - } else if (format == "U") { - return make_uniq(LogicalType::JSON(), - make_uniq(ArrowVariableSizeType::SUPER_SIZE)); - } else if (format == "vu") { - return make_uniq(LogicalType::JSON(), make_uniq(ArrowVariableSizeType::VIEW)); - } else { - std::ostringstream error; - error - << "arrow.json must be of a varchar format (i.e., \'u\',\'U\' or \'vu\'). It is incorrectly defined as:" - << format; - return make_uniq(error.str()); - } - } - // Check for DuckDB canonical extensions - else if (extension_type.IsNonCanonicalType("hugeint")) { - if (format != "w:16") { - std::ostringstream error; - error << "DuckDB hugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly " - "defined as:" - << format; - return make_uniq(error.str()); - } - return make_uniq(LogicalType::HUGEINT); - } else if (extension_type.IsNonCanonicalType("uhugeint")) { - if (format != "w:16") { - std::ostringstream error; - error << "DuckDB uhugeint must be a fixed-size binary of 16 bytes (i.e., \'w:16\'). It is incorrectly " - "defined as:" - << format; - return make_uniq(error.str()); - } - return make_uniq(LogicalType::UHUGEINT); - } else if (extension_type.IsNonCanonicalType("time_tz")) { - if (format != "w:8") { - std::ostringstream error; - error << "DuckDB time_tz must be a fixed-size binary of 8 bytes (i.e., \'w:8\'). 
It is incorrectly defined " - "as:" - << format; - return make_uniq(error.str()); - } - return make_uniq(LogicalType::TIME_TZ, - make_uniq(ArrowDateTimeType::MICROSECONDS)); - } else if (extension_type.IsNonCanonicalType("bit")) { - if (format != "z" && format != "Z") { - std::ostringstream error; - error << "DuckDB bit must be a blob (i.e., \'z\' or \'Z\'). It is incorrectly defined as:" << format; - return make_uniq(error.str()); - } else if (format == "z") { - auto type_info = make_uniq(ArrowVariableSizeType::NORMAL); - return make_uniq(LogicalType::BIT, std::move(type_info)); - } - auto type_info = make_uniq(ArrowVariableSizeType::SUPER_SIZE); - return make_uniq(LogicalType::BIT, std::move(type_info)); - - } else if (extension_type.IsNonCanonicalType("varint")) { - if (format != "z" && format != "Z") { - std::ostringstream error; - error << "DuckDB bit must be a blob (i.e., \'z\'). It is incorrectly defined as:" << format; - return make_uniq(error.str()); - } - unique_ptr type_info; - if (format == "z") { - type_info = make_uniq(ArrowVariableSizeType::NORMAL); - } else { - type_info = make_uniq(ArrowVariableSizeType::SUPER_SIZE); - } - return make_uniq(LogicalType::VARINT, std::move(type_info)); - } else { - std::ostringstream error; - error << "Arrow Type with extension name: " << arrow_extension << " and format: " << format - << ", is not currently supported in DuckDB."; - return make_uniq(error.str(), true); - } -} -static unique_ptr GetArrowLogicalTypeNoDictionary(ArrowSchema &schema) { - auto format = string(schema.format); - // Let's first figure out if this type is an extension type - ArrowSchemaMetadata schema_metadata(schema.metadata); - if (schema_metadata.HasExtension()) { - return GetArrowExtensionType(schema_metadata, format); - } - // If not, we just check the format itself - if (format == "n") { - return make_uniq(LogicalType::SQLNULL); - } else if (format == "b") { - return make_uniq(LogicalType::BOOLEAN); - } else if (format == "c") { - return make_uniq(LogicalType::TINYINT); - } else if (format == "s") { - return make_uniq(LogicalType::SMALLINT); - } else if (format == "i") { - return make_uniq(LogicalType::INTEGER); - } else if (format == "l") { - return make_uniq(LogicalType::BIGINT); - } else if (format == "C") { - return make_uniq(LogicalType::UTINYINT); - } else if (format == "S") { - return make_uniq(LogicalType::USMALLINT); - } else if (format == "I") { - return make_uniq(LogicalType::UINTEGER); - } else if (format == "L") { - return make_uniq(LogicalType::UBIGINT); - } else if (format == "f") { - return make_uniq(LogicalType::FLOAT); - } else if (format == "g") { - return make_uniq(LogicalType::DOUBLE); - } else if (format[0] == 'd') { //! this can be either decimal128 or decimal 256 (e.g., d:38,0) - auto extra_info = StringUtil::Split(format, ':'); - if (extra_info.size() != 2) { - throw InvalidInputException( - "Decimal format of Arrow object is incomplete, it is missing the scale and width. Current format: %s", - format); - } - auto parameters = StringUtil::Split(extra_info[1], ","); - // Parameters must always be 2 or 3 values (i.e., width, scale and an optional bit-width) - if (parameters.size() != 2 && parameters.size() != 3) { - throw InvalidInputException( - "Decimal format of Arrow object is incomplete, it is missing the scale or width. 
Current format: %s", - format); - } - uint64_t width = std::stoull(parameters[0]); - uint64_t scale = std::stoull(parameters[1]); - uint64_t bitwidth = 128; - if (parameters.size() == 3) { - // We have a bit-width defined - bitwidth = std::stoull(parameters[2]); - } - if (width > 38 || bitwidth > 128) { - throw NotImplementedException("Unsupported Internal Arrow Type for Decimal %s", format); - } - return make_uniq(LogicalType::DECIMAL(NumericCast(width), NumericCast(scale))); - } else if (format == "u") { - return make_uniq(LogicalType::VARCHAR, make_uniq(ArrowVariableSizeType::NORMAL)); - } else if (format == "U") { - return make_uniq(LogicalType::VARCHAR, - make_uniq(ArrowVariableSizeType::SUPER_SIZE)); - } else if (format == "vu") { - return make_uniq(LogicalType::VARCHAR, make_uniq(ArrowVariableSizeType::VIEW)); - } else if (format == "tsn:") { - return make_uniq(LogicalTypeId::TIMESTAMP_NS); - } else if (format == "tsu:") { - return make_uniq(LogicalTypeId::TIMESTAMP); - } else if (format == "tsm:") { - return make_uniq(LogicalTypeId::TIMESTAMP_MS); - } else if (format == "tss:") { - return make_uniq(LogicalTypeId::TIMESTAMP_SEC); - } else if (format == "tdD") { - return make_uniq(LogicalType::DATE, make_uniq(ArrowDateTimeType::DAYS)); - } else if (format == "tdm") { - return make_uniq(LogicalType::DATE, make_uniq(ArrowDateTimeType::MILLISECONDS)); - } else if (format == "tts") { - return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::SECONDS)); - } else if (format == "ttm") { - return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::MILLISECONDS)); - } else if (format == "ttu") { - return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::MICROSECONDS)); - } else if (format == "ttn") { - return make_uniq(LogicalType::TIME, make_uniq(ArrowDateTimeType::NANOSECONDS)); - } else if (format == "tDs") { - return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::SECONDS)); - } else if (format == "tDm") { - return make_uniq(LogicalType::INTERVAL, - make_uniq(ArrowDateTimeType::MILLISECONDS)); - } else if (format == "tDu") { - return make_uniq(LogicalType::INTERVAL, - make_uniq(ArrowDateTimeType::MICROSECONDS)); - } else if (format == "tDn") { - return make_uniq(LogicalType::INTERVAL, - make_uniq(ArrowDateTimeType::NANOSECONDS)); - } else if (format == "tiD") { - return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::DAYS)); - } else if (format == "tiM") { - return make_uniq(LogicalType::INTERVAL, make_uniq(ArrowDateTimeType::MONTHS)); - } else if (format == "tin") { - return make_uniq(LogicalType::INTERVAL, - make_uniq(ArrowDateTimeType::MONTH_DAY_NANO)); - } else if (format == "+l") { - return CreateListType(*schema.children[0], ArrowVariableSizeType::NORMAL, false); - } else if (format == "+L") { - return CreateListType(*schema.children[0], ArrowVariableSizeType::SUPER_SIZE, false); - } else if (format == "+vl") { - return CreateListType(*schema.children[0], ArrowVariableSizeType::NORMAL, true); - } else if (format == "+vL") { - return CreateListType(*schema.children[0], ArrowVariableSizeType::SUPER_SIZE, true); - } else if (format[0] == '+' && format[1] == 'w') { - std::string parameters = format.substr(format.find(':') + 1); - auto fixed_size = NumericCast(std::stoi(parameters)); - auto child_type = ArrowTableFunction::GetArrowLogicalType(*schema.children[0]); - - auto array_type = LogicalType::ARRAY(child_type->GetDuckType(), fixed_size); - auto type_info = make_uniq(std::move(child_type), fixed_size); - return make_uniq(array_type, 
std::move(type_info)); - } else if (format == "+s") { - child_list_t child_types; - vector> children; - if (schema.n_children == 0) { - throw InvalidInputException( - "Attempted to convert a STRUCT with no fields to DuckDB which is not supported"); - } - for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { - children.emplace_back(ArrowTableFunction::GetArrowLogicalType(*schema.children[type_idx])); - child_types.emplace_back(schema.children[type_idx]->name, children.back()->GetDuckType()); - } - auto type_info = make_uniq(std::move(children)); - auto struct_type = make_uniq(LogicalType::STRUCT(std::move(child_types)), std::move(type_info)); - return struct_type; - } else if (format[0] == '+' && format[1] == 'u') { - if (format[2] != 's') { - throw NotImplementedException("Unsupported Internal Arrow Type: \"%c\" Union", format[2]); - } - D_ASSERT(format[3] == ':'); - - std::string prefix = "+us:"; - // TODO: what are these type ids actually for? - auto type_ids = StringUtil::Split(format.substr(prefix.size()), ','); - - child_list_t members; - vector> children; - if (schema.n_children == 0) { - throw InvalidInputException("Attempted to convert a UNION with no fields to DuckDB which is not supported"); - } - for (idx_t type_idx = 0; type_idx < (idx_t)schema.n_children; type_idx++) { - auto type = schema.children[type_idx]; - - children.emplace_back(ArrowTableFunction::GetArrowLogicalType(*type)); - members.emplace_back(type->name, children.back()->GetDuckType()); - } - - auto type_info = make_uniq(std::move(children)); - auto union_type = make_uniq(LogicalType::UNION(members), std::move(type_info)); - return union_type; - } else if (format == "+r") { - child_list_t members; - vector> children; - idx_t n_children = idx_t(schema.n_children); - D_ASSERT(n_children == 2); - D_ASSERT(string(schema.children[0]->name) == "run_ends"); - D_ASSERT(string(schema.children[1]->name) == "values"); - for (idx_t i = 0; i < n_children; i++) { - auto type = schema.children[i]; - children.emplace_back(ArrowTableFunction::GetArrowLogicalType(*type)); - members.emplace_back(type->name, children.back()->GetDuckType()); - } - - auto type_info = make_uniq(std::move(children)); - auto struct_type = make_uniq(LogicalType::STRUCT(members), std::move(type_info)); - struct_type->SetRunEndEncoded(); - return struct_type; - } else if (format == "+m") { - auto &arrow_struct_type = *schema.children[0]; - D_ASSERT(arrow_struct_type.n_children == 2); - auto key_type = ArrowTableFunction::GetArrowLogicalType(*arrow_struct_type.children[0]); - auto value_type = ArrowTableFunction::GetArrowLogicalType(*arrow_struct_type.children[1]); - child_list_t key_value; - key_value.emplace_back(std::make_pair("key", key_type->GetDuckType())); - key_value.emplace_back(std::make_pair("value", value_type->GetDuckType())); - - auto map_type = LogicalType::MAP(key_type->GetDuckType(), value_type->GetDuckType()); - vector> children; - children.reserve(2); - children.push_back(std::move(key_type)); - children.push_back(std::move(value_type)); - auto inner_struct = make_uniq(LogicalType::STRUCT(std::move(key_value)), - make_uniq(std::move(children))); - auto map_type_info = ArrowListInfo::List(std::move(inner_struct), ArrowVariableSizeType::NORMAL); - return make_uniq(map_type, std::move(map_type_info)); - } else if (format == "z") { - auto type_info = make_uniq(ArrowVariableSizeType::NORMAL); - return make_uniq(LogicalType::BLOB, std::move(type_info)); - } else if (format == "Z") { - auto type_info = 
make_uniq(ArrowVariableSizeType::SUPER_SIZE); - return make_uniq(LogicalType::BLOB, std::move(type_info)); - } else if (format[0] == 'w') { - string parameters = format.substr(format.find(':') + 1); - auto fixed_size = NumericCast(std::stoi(parameters)); - auto type_info = make_uniq(fixed_size); - return make_uniq(LogicalType::BLOB, std::move(type_info)); - } else if (format[0] == 't' && format[1] == 's') { - // Timestamp with Timezone - // TODO right now we just get the UTC value. We probably want to support this properly in the future - unique_ptr type_info; - if (format[2] == 'n') { - type_info = make_uniq(ArrowDateTimeType::NANOSECONDS); - } else if (format[2] == 'u') { - type_info = make_uniq(ArrowDateTimeType::MICROSECONDS); - } else if (format[2] == 'm') { - type_info = make_uniq(ArrowDateTimeType::MILLISECONDS); - } else if (format[2] == 's') { - type_info = make_uniq(ArrowDateTimeType::SECONDS); - } else { - throw NotImplementedException(" Timestamptz precision of not accepted"); - } - return make_uniq(LogicalType::TIMESTAMP_TZ, std::move(type_info)); - } else { - throw NotImplementedException("Unsupported Internal Arrow Type %s", format); - } -} - -unique_ptr ArrowTableFunction::GetArrowLogicalType(ArrowSchema &schema) { - auto arrow_type = GetArrowLogicalTypeNoDictionary(schema); - if (schema.dictionary) { - auto dictionary = GetArrowLogicalType(*schema.dictionary); - arrow_type->SetDictionary(std::move(dictionary)); - } - return arrow_type; -} - -void ArrowTableFunction::PopulateArrowTableType(ArrowTableType &arrow_table, ArrowSchemaWrapper &schema_p, - vector &names, vector &return_types) { - for (idx_t col_idx = 0; col_idx < (idx_t)schema_p.arrow_schema.n_children; col_idx++) { - auto &schema = *schema_p.arrow_schema.children[col_idx]; - if (!schema.release) { - throw InvalidInputException("arrow_scan: released schema passed"); - } - auto arrow_type = GetArrowLogicalType(schema); - return_types.emplace_back(arrow_type->GetDuckType(true)); - arrow_table.AddColumn(col_idx, std::move(arrow_type)); - auto name = string(schema.name); - if (name.empty()) { - name = string("v") + to_string(col_idx); - } - names.push_back(name); - } -} - -unique_ptr ArrowTableFunction::ArrowScanBindDumb(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, - vector &names) { - auto bind_data = ArrowScanBind(context, input, return_types, names); - auto &arrow_bind_data = bind_data->Cast(); - arrow_bind_data.projection_pushdown_enabled = false; - return bind_data; -} - -unique_ptr ArrowTableFunction::ArrowScanBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - if (input.inputs[0].IsNull() || input.inputs[1].IsNull() || input.inputs[2].IsNull()) { - throw BinderException("arrow_scan: pointers cannot be null"); - } - auto &ref = input.ref; - - shared_ptr dependency; - if (ref.external_dependency) { - // This was created during the replacement scan for Python (see python_replacement_scan.cpp) - // this object is the owning reference to 'stream_factory_ptr' and has to be kept alive. 
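// (Annotation:) concretely, arrow_scan's three POINTER arguments, cast a few lines below, are
// (1) the opaque stream-factory pointer, (2) a 'produce' callback that builds an
// ArrowArrayStream from that factory, and (3) a 'get_schema' callback that exports the
// factory's ArrowSchema. DuckDB only ever holds untyped pointers to the factory, so fetching
// the dependency here and moving it into the bind data below is what actually pins the
// factory's lifetime to the scan.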
- dependency = ref.external_dependency->GetDependency("replacement_cache"); - D_ASSERT(dependency); - } - - auto stream_factory_ptr = input.inputs[0].GetPointer(); - auto stream_factory_produce = (stream_factory_produce_t)input.inputs[1].GetPointer(); // NOLINT - auto stream_factory_get_schema = (stream_factory_get_schema_t)input.inputs[2].GetPointer(); // NOLINT - - auto res = make_uniq(stream_factory_produce, stream_factory_ptr, std::move(dependency)); - - auto &data = *res; - stream_factory_get_schema(reinterpret_cast(stream_factory_ptr), data.schema_root.arrow_schema); - PopulateArrowTableType(res->arrow_table, data.schema_root, names, return_types); - QueryResult::DeduplicateColumns(names); - res->all_types = return_types; - if (return_types.empty()) { - throw InvalidInputException("Provided table/dataframe must have at least one column"); - } - return std::move(res); -} - -unique_ptr ProduceArrowScan(const ArrowScanFunctionData &function, - const vector &column_ids, TableFilterSet *filters) { - //! Generate Projection Pushdown Vector - ArrowStreamParameters parameters; - D_ASSERT(!column_ids.empty()); - auto &arrow_types = function.arrow_table.GetColumns(); - for (idx_t idx = 0; idx < column_ids.size(); idx++) { - auto col_idx = column_ids[idx]; - if (col_idx != COLUMN_IDENTIFIER_ROW_ID) { - auto &schema = *function.schema_root.arrow_schema.children[col_idx]; - arrow_types.at(col_idx)->ThrowIfInvalid(); - parameters.projected_columns.projection_map[idx] = schema.name; - parameters.projected_columns.columns.emplace_back(schema.name); - parameters.projected_columns.filter_to_col[idx] = col_idx; - } - } - parameters.filters = filters; - return function.scanner_producer(function.stream_factory_ptr, parameters); -} - -idx_t ArrowTableFunction::ArrowScanMaxThreads(ClientContext &context, const FunctionData *bind_data_p) { - return context.db->NumberOfThreads(); -} - -bool ArrowTableFunction::ArrowScanParallelStateNext(ClientContext &context, const FunctionData *bind_data_p, - ArrowScanLocalState &state, ArrowScanGlobalState ¶llel_state) { - lock_guard parallel_lock(parallel_state.main_mutex); - if (parallel_state.done) { - return false; - } - state.Reset(); - state.batch_index = ++parallel_state.batch_index; - - auto current_chunk = parallel_state.stream->GetNextChunk(); - while (current_chunk->arrow_array.length == 0 && current_chunk->arrow_array.release) { - current_chunk = parallel_state.stream->GetNextChunk(); - } - state.chunk = std::move(current_chunk); - //! have we run out of chunks? 
we are done - if (!state.chunk->arrow_array.release) { - parallel_state.done = true; - return false; - } - return true; -} - -unique_ptr ArrowTableFunction::ArrowScanInitGlobal(ClientContext &context, - TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - result->stream = ProduceArrowScan(bind_data, input.column_ids, input.filters.get()); - result->max_threads = ArrowScanMaxThreads(context, input.bind_data.get()); - if (!input.projection_ids.empty()) { - result->projection_ids = input.projection_ids; - for (const auto &col_idx : input.column_ids) { - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - result->scanned_types.emplace_back(LogicalType::ROW_TYPE); - } else { - result->scanned_types.push_back(bind_data.all_types[col_idx]); - } - } - } - return std::move(result); -} - -unique_ptr -ArrowTableFunction::ArrowScanInitLocalInternal(ClientContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - auto &global_state = global_state_p->Cast(); - auto current_chunk = make_uniq(); - auto result = make_uniq(std::move(current_chunk)); - result->column_ids = input.column_ids; - result->filters = input.filters.get(); - auto &bind_data = input.bind_data->Cast(); - if (!bind_data.projection_pushdown_enabled) { - result->column_ids.clear(); - } else if (!input.projection_ids.empty()) { - auto &asgs = global_state_p->Cast(); - result->all_columns.Initialize(context, asgs.scanned_types); - } - if (!ArrowScanParallelStateNext(context, input.bind_data.get(), *result, global_state)) { - return nullptr; - } - return std::move(result); -} - -unique_ptr ArrowTableFunction::ArrowScanInitLocal(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - return ArrowScanInitLocalInternal(context.client, input, global_state_p); -} - -void ArrowTableFunction::ArrowScanFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - if (!data_p.local_state) { - return; - } - auto &data = data_p.bind_data->CastNoConst(); // FIXME - auto &state = data_p.local_state->Cast(); - auto &global_state = data_p.global_state->Cast(); - - //! 
Out of tuples in this chunk - if (state.chunk_offset >= (idx_t)state.chunk->arrow_array.length) { - if (!ArrowScanParallelStateNext(context, data_p.bind_data.get(), state, global_state)) { - return; - } - } - auto output_size = - MinValue(STANDARD_VECTOR_SIZE, NumericCast(state.chunk->arrow_array.length) - state.chunk_offset); - data.lines_read += output_size; - if (global_state.CanRemoveFilterColumns()) { - state.all_columns.Reset(); - state.all_columns.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), state.all_columns, data.lines_read - output_size); - output.ReferenceColumns(state.all_columns, global_state.projection_ids); - } else { - output.SetCardinality(output_size); - ArrowToDuckDB(state, data.arrow_table.GetColumns(), output, data.lines_read - output_size); - } - - output.Verify(); - state.chunk_offset += output.size(); -} - -unique_ptr ArrowTableFunction::ArrowScanCardinality(ClientContext &context, const FunctionData *data) { - return make_uniq(); -} - -OperatorPartitionData ArrowTableFunction::ArrowGetPartitionData(ClientContext &context, - TableFunctionGetPartitionInput &input) { - if (input.partition_info.RequiresPartitionColumns()) { - throw InternalException("ArrowTableFunction::GetPartitionData: partition columns not supported"); - } - auto &state = input.local_state->Cast(); - return OperatorPartitionData(state.batch_index); -} - -bool ArrowTableFunction::ArrowPushdownType(const LogicalType &type) { - switch (type.id()) { - case LogicalTypeId::BOOLEAN: - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_TZ: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::VARCHAR: - case LogicalTypeId::BLOB: - return true; - case LogicalTypeId::DECIMAL: { - switch (type.InternalType()) { - case PhysicalType::INT16: - case PhysicalType::INT32: - case PhysicalType::INT64: - return true; - default: - return false; - } - } break; - case LogicalTypeId::STRUCT: { - auto struct_types = StructType::GetChildTypes(type); - for (auto &struct_type : struct_types) { - if (!ArrowPushdownType(struct_type.second)) { - return false; - } - } - return true; - } - default: - return false; - } -} - -void ArrowTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction arrow("arrow_scan", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, - ArrowScanFunction, ArrowScanBind, ArrowScanInitGlobal, ArrowScanInitLocal); - arrow.cardinality = ArrowScanCardinality; - arrow.get_partition_data = ArrowGetPartitionData; - arrow.projection_pushdown = true; - arrow.filter_pushdown = true; - arrow.filter_prune = true; - arrow.supports_pushdown_type = ArrowPushdownType; - set.AddFunction(arrow); - - TableFunction arrow_dumb("arrow_scan_dumb", {LogicalType::POINTER, LogicalType::POINTER, LogicalType::POINTER}, - ArrowScanFunction, ArrowScanBindDumb, ArrowScanInitGlobal, ArrowScanInitLocal); - arrow_dumb.cardinality = ArrowScanCardinality; - arrow_dumb.get_partition_data = ArrowGetPartitionData; - arrow_dumb.projection_pushdown = false; - arrow_dumb.filter_pushdown = false; - arrow_dumb.filter_prune = false; - 
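// (Annotation:) arrow_scan_dumb registers the same scan callbacks as arrow_scan but with
// projection and filter pushdown switched off (ArrowScanBindDumb additionally clears
// projection_pushdown_enabled), so every column is materialized and all filtering happens
// inside DuckDB; presumably this serves as a reference path against which the
// pushdown-enabled variant can be tested.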
set.AddFunction(arrow_dumb); -} - -void BuiltinFunctions::RegisterArrowFunctions() { - ArrowTableFunction::RegisterFunction(*this); -} -} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp b/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp deleted file mode 100644 index 749ebc29c..000000000 --- a/src/duckdb/src/function/table/arrow/arrow_array_scan_state.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "duckdb/function/table/arrow.hpp" -#include "duckdb/common/printer.hpp" -#include "duckdb/common/types/arrow_aux_data.hpp" - -namespace duckdb { - -ArrowArrayScanState::ArrowArrayScanState(ArrowScanLocalState &state) : state(state) { - arrow_dictionary = nullptr; -} - -ArrowArrayScanState &ArrowArrayScanState::GetChild(idx_t child_idx) { - auto it = children.find(child_idx); - if (it == children.end()) { - auto child_p = make_uniq(state); - auto &child = *child_p; - child.owned_data = owned_data; - children.emplace(child_idx, std::move(child_p)); - return child; - } - if (!it->second->owned_data) { - // Propagate down the ownership, for dictionaries in children - D_ASSERT(owned_data); - it->second->owned_data = owned_data; - } - return *it->second; -} - -void ArrowArrayScanState::AddDictionary(unique_ptr dictionary_p, ArrowArray *arrow_dict) { - dictionary = std::move(dictionary_p); - D_ASSERT(owned_data); - D_ASSERT(arrow_dict); - arrow_dictionary = arrow_dict; - // Make sure the data referenced by the dictionary stays alive - dictionary->GetBuffer()->SetAuxiliaryData(make_uniq(owned_data)); -} - -bool ArrowArrayScanState::HasDictionary() const { - return dictionary != nullptr; -} - -bool ArrowArrayScanState::CacheOutdated(ArrowArray *dictionary) const { - if (!dictionary) { - // Not cached - return true; - } - if (dictionary == arrow_dictionary.get()) { - // Already cached, not outdated - return false; - } - return true; -} - -Vector &ArrowArrayScanState::GetDictionary() { - D_ASSERT(HasDictionary()); - return *dictionary; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp b/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp deleted file mode 100644 index 3f23e94ee..000000000 --- a/src/duckdb/src/function/table/arrow/arrow_duck_schema.cpp +++ /dev/null @@ -1,103 +0,0 @@ -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" -#include "duckdb/common/arrow/arrow.hpp" -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -void ArrowTableType::AddColumn(idx_t index, unique_ptr type) { - D_ASSERT(arrow_convert_data.find(index) == arrow_convert_data.end()); - arrow_convert_data.emplace(std::make_pair(index, std::move(type))); -} - -const arrow_column_map_t &ArrowTableType::GetColumns() const { - return arrow_convert_data; -} - -void ArrowType::SetDictionary(unique_ptr dictionary) { - D_ASSERT(!this->dictionary_type); - dictionary_type = std::move(dictionary); -} - -bool ArrowType::HasDictionary() const { - return dictionary_type != nullptr; -} - -const ArrowType &ArrowType::GetDictionary() const { - D_ASSERT(dictionary_type); - return *dictionary_type; -} - -void ArrowType::SetRunEndEncoded() { - D_ASSERT(type_info); - D_ASSERT(type_info->type == ArrowTypeInfoType::STRUCT); - auto &struct_info = type_info->Cast(); - D_ASSERT(struct_info.ChildCount() == 2); - - auto actual_type = struct_info.GetChild(1).GetDuckType(); - // Override the duckdb type to the actual type - type = actual_type; - run_end_encoded = true; -} - -bool ArrowType::RunEndEncoded() const { - return 
run_end_encoded; -} - -void ArrowType::ThrowIfInvalid() const { - if (type.id() == LogicalTypeId::INVALID) { - if (not_implemented) { - throw NotImplementedException(error_message); - } - throw InvalidInputException(error_message); - } -} - -LogicalType ArrowType::GetDuckType(bool use_dictionary) const { - if (use_dictionary && dictionary_type) { - return dictionary_type->GetDuckType(); - } - if (!use_dictionary) { - return type; - } - // Dictionaries can exist in arbitrarily nested schemas - // have to reconstruct the type - auto id = type.id(); - switch (id) { - case LogicalTypeId::STRUCT: { - auto &struct_info = type_info->Cast(); - child_list_t new_children; - for (idx_t i = 0; i < struct_info.ChildCount(); i++) { - auto &child = struct_info.GetChild(i); - auto &child_name = StructType::GetChildName(type, i); - new_children.emplace_back(std::make_pair(child_name, child.GetDuckType(true))); - } - return LogicalType::STRUCT(std::move(new_children)); - } - case LogicalTypeId::LIST: { - auto &list_info = type_info->Cast(); - auto &child = list_info.GetChild(); - return LogicalType::LIST(child.GetDuckType(true)); - } - case LogicalTypeId::MAP: { - auto &list_info = type_info->Cast(); - auto &struct_child = list_info.GetChild(); - auto struct_type = struct_child.GetDuckType(true); - return LogicalType::MAP(StructType::GetChildType(struct_type, 0), StructType::GetChildType(struct_type, 1)); - } - case LogicalTypeId::UNION: { - auto &union_info = type_info->Cast(); - child_list_t new_children; - for (idx_t i = 0; i < union_info.ChildCount(); i++) { - auto &child = union_info.GetChild(i); - auto &child_name = UnionType::GetMemberName(type, i); - new_children.emplace_back(std::make_pair(child_name, child.GetDuckType(true))); - } - return LogicalType::UNION(std::move(new_children)); - } - default: { - return type; - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow/arrow_type_info.cpp b/src/duckdb/src/function/table/arrow/arrow_type_info.cpp deleted file mode 100644 index e012f1b5c..000000000 --- a/src/duckdb/src/function/table/arrow/arrow_type_info.cpp +++ /dev/null @@ -1,135 +0,0 @@ -#include "duckdb/function/table/arrow/arrow_type_info.hpp" -#include "duckdb/function/table/arrow/arrow_duck_schema.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// ArrowTypeInfo -//===--------------------------------------------------------------------===// - -ArrowTypeInfo::ArrowTypeInfo(ArrowTypeInfoType type) : type(type) { -} - -ArrowTypeInfo::~ArrowTypeInfo() { -} - -//===--------------------------------------------------------------------===// -// ArrowStructInfo -//===--------------------------------------------------------------------===// - -ArrowStructInfo::ArrowStructInfo(vector> children) - : ArrowTypeInfo(ArrowTypeInfoType::STRUCT), children(std::move(children)) { -} - -idx_t ArrowStructInfo::ChildCount() const { - return children.size(); -} - -ArrowStructInfo::~ArrowStructInfo() { -} - -const ArrowType &ArrowStructInfo::GetChild(idx_t index) const { - D_ASSERT(index < children.size()); - return *children[index]; -} - -const vector> &ArrowStructInfo::GetChildren() const { - return children; -} - -//===--------------------------------------------------------------------===// -// ArrowDateTimeInfo -//===--------------------------------------------------------------------===// - -ArrowDateTimeInfo::ArrowDateTimeInfo(ArrowDateTimeType size) - : ArrowTypeInfo(ArrowTypeInfoType::DATE_TIME), size_type(size) { -} 
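// --- illustrative sketch, not part of the deleted file ------------------------------------
// How the dictionary plumbing above composes, assuming the ArrowType constructor shapes used
// throughout this patch (exact signatures may differ): an Arrow dictionary-encoded column
// physically stores integer indices, while the attached dictionary carries the value type
// that queries actually see.
static void DictionaryTypeSketch() {
	auto indices = make_uniq<ArrowType>(LogicalType::INTEGER);
	indices->SetDictionary(
	    make_uniq<ArrowType>(LogicalType::VARCHAR, make_uniq<ArrowStringInfo>(ArrowVariableSizeType::NORMAL)));
	// Physical (index) type vs. logical (value) type:
	D_ASSERT(indices->GetDuckType(false) == LogicalType::INTEGER);
	D_ASSERT(indices->GetDuckType(true) == LogicalType::VARCHAR);
	// For nested types without their own dictionary, GetDuckType(true) instead recurses
	// through the ArrowTypeInfo children (see the STRUCT/LIST/MAP/UNION cases above), so a
	// dictionary value type buried arbitrarily deep still surfaces in the reconstructed
	// LogicalType.
}
// -------------------------------------------------------------------------------------------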
- -ArrowDateTimeInfo::~ArrowDateTimeInfo() { -} - -ArrowDateTimeType ArrowDateTimeInfo::GetDateTimeType() const { - return size_type; -} - -//===--------------------------------------------------------------------===// -// ArrowStringInfo -//===--------------------------------------------------------------------===// - -ArrowStringInfo::ArrowStringInfo(ArrowVariableSizeType size) - : ArrowTypeInfo(ArrowTypeInfoType::STRING), size_type(size), fixed_size(0) { - D_ASSERT(size != ArrowVariableSizeType::FIXED_SIZE); -} - -ArrowStringInfo::~ArrowStringInfo() { -} - -ArrowStringInfo::ArrowStringInfo(idx_t fixed_size) - : ArrowTypeInfo(ArrowTypeInfoType::STRING), size_type(ArrowVariableSizeType::FIXED_SIZE), fixed_size(fixed_size) { -} - -ArrowVariableSizeType ArrowStringInfo::GetSizeType() const { - return size_type; -} - -idx_t ArrowStringInfo::FixedSize() const { - D_ASSERT(size_type == ArrowVariableSizeType::FIXED_SIZE); - return fixed_size; -} - -//===--------------------------------------------------------------------===// -// ArrowListInfo -//===--------------------------------------------------------------------===// - -ArrowListInfo::ArrowListInfo(unique_ptr child, ArrowVariableSizeType size) - : ArrowTypeInfo(ArrowTypeInfoType::LIST), size_type(size), child(std::move(child)) { -} - -ArrowListInfo::~ArrowListInfo() { -} - -unique_ptr ArrowListInfo::ListView(unique_ptr child, ArrowVariableSizeType size) { - D_ASSERT(size == ArrowVariableSizeType::SUPER_SIZE || size == ArrowVariableSizeType::NORMAL); - auto list_info = unique_ptr(new ArrowListInfo(std::move(child), size)); - list_info->is_view = true; - return list_info; -} - -unique_ptr ArrowListInfo::List(unique_ptr child, ArrowVariableSizeType size) { - D_ASSERT(size == ArrowVariableSizeType::SUPER_SIZE || size == ArrowVariableSizeType::NORMAL); - return unique_ptr(new ArrowListInfo(std::move(child), size)); -} - -ArrowVariableSizeType ArrowListInfo::GetSizeType() const { - return size_type; -} - -bool ArrowListInfo::IsView() const { - return is_view; -} - -ArrowType &ArrowListInfo::GetChild() const { - return *child; -} - -//===--------------------------------------------------------------------===// -// ArrowArrayInfo -//===--------------------------------------------------------------------===// - -ArrowArrayInfo::ArrowArrayInfo(unique_ptr child, idx_t fixed_size) - : ArrowTypeInfo(ArrowTypeInfoType::ARRAY), child(std::move(child)), fixed_size(fixed_size) { - D_ASSERT(fixed_size > 0); -} - -ArrowArrayInfo::~ArrowArrayInfo() { -} - -idx_t ArrowArrayInfo::FixedSize() const { - return fixed_size; -} - -ArrowType &ArrowArrayInfo::GetChild() const { - return *child; -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/arrow_conversion.cpp b/src/duckdb/src/function/table/arrow_conversion.cpp deleted file mode 100644 index e09f81aac..000000000 --- a/src/duckdb/src/function/table/arrow_conversion.cpp +++ /dev/null @@ -1,1416 +0,0 @@ -#include "duckdb/common/exception/conversion_exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/common/operator/multiply.hpp" -#include "duckdb/common/types/arrow_aux_data.hpp" -#include "duckdb/common/types/arrow_string_view_type.hpp" -#include "duckdb/common/types/hugeint.hpp" -#include "duckdb/function/scalar/nested_functions.hpp" -#include "duckdb/function/table/arrow.hpp" - -#include "duckdb/common/bswap.hpp" - -namespace duckdb { - -namespace { - -enum class ArrowArrayPhysicalType : uint8_t { DICTIONARY_ENCODED, RUN_END_ENCODED, DEFAULT }; - -ArrowArrayPhysicalType 
GetArrowArrayPhysicalType(const ArrowType &type) { - if (type.HasDictionary()) { - return ArrowArrayPhysicalType::DICTIONARY_ENCODED; - } - if (type.RunEndEncoded()) { - return ArrowArrayPhysicalType::RUN_END_ENCODED; - } - return ArrowArrayPhysicalType::DEFAULT; -} - -} // namespace - -#if STANDARD_VECTOR_SIZE > 64 -static void ShiftRight(unsigned char *ar, int size, int shift) { - int carry = 0; - while (shift--) { - for (int i = size - 1; i >= 0; --i) { - int next = (ar[i] & 1) ? 0x80 : 0; - ar[i] = UnsafeNumericCast(carry | (ar[i] >> 1)); - carry = next; - } - } -} -#endif - -idx_t GetEffectiveOffset(const ArrowArray &array, int64_t parent_offset, const ArrowScanLocalState &state, - int64_t nested_offset = -1) { - if (nested_offset != -1) { - // The parent of this array is a list - // We just ignore the parent offset, it's already applied to the list - return UnsafeNumericCast(array.offset + nested_offset); - } - // Parent offset is set in the case of a struct, it applies to all child arrays - // 'chunk_offset' is how much of the chunk we've already scanned, in case the chunk size exceeds - // STANDARD_VECTOR_SIZE - return UnsafeNumericCast(array.offset + parent_offset) + state.chunk_offset; -} - -template -T *ArrowBufferData(ArrowArray &array, idx_t buffer_idx) { - return (T *)array.buffers[buffer_idx]; // NOLINT -} - -static void GetValidityMask(ValidityMask &mask, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, - int64_t parent_offset, int64_t nested_offset = -1, bool add_null = false) { - // In certains we don't need to or cannot copy arrow's validity mask to duckdb. - // - // The conditions where we do want to copy arrow's mask to duckdb are: - // 1. nulls exist - // 2. n_buffers > 0, meaning the array's arrow type is not `null` - // 3. the validity buffer (the first buffer) is not a nullptr - if (array.null_count != 0 && array.n_buffers > 0 && array.buffers[0]) { - auto bit_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - mask.EnsureWritable(); -#if STANDARD_VECTOR_SIZE > 64 - auto n_bitmask_bytes = (size + 8 - 1) / 8; - if (bit_offset % 8 == 0) { - //! just memcpy nullmask - memcpy((void *)mask.GetData(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes); - } else { - //! need to re-align nullmask - vector temp_nullmask(n_bitmask_bytes + 1); - memcpy(temp_nullmask.data(), ArrowBufferData(array, 0) + bit_offset / 8, n_bitmask_bytes + 1); - ShiftRight(temp_nullmask.data(), NumericCast(n_bitmask_bytes + 1), - NumericCast(bit_offset % 8ull)); //! why this has to be a right shift is a mystery to me - memcpy((void *)mask.GetData(), data_ptr_cast(temp_nullmask.data()), n_bitmask_bytes); - } -#else - auto byte_offset = bit_offset / 8; - auto source_data = ArrowBufferData(array, 0); - bit_offset %= 8; - for (idx_t i = 0; i < size; i++) { - mask.Set(i, source_data[byte_offset] & (1 << bit_offset)); - bit_offset++; - if (bit_offset == 8) { - bit_offset = 0; - byte_offset++; - } - } -#endif - } - if (add_null) { - //! We are setting a validity mask of the data part of dictionary vector - //! For some reason, Nulls are allowed to be indexes, hence we need to set the last element here to be null - //! 
We might have to resize the mask - mask.Resize(size + 1); - mask.SetInvalid(size); - } -} - -static void SetValidityMask(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, - int64_t parent_offset, int64_t nested_offset, bool add_null = false) { - D_ASSERT(vector.GetVectorType() == VectorType::FLAT_VECTOR); - auto &mask = FlatVector::Validity(vector); - GetValidityMask(mask, array, scan_state, size, parent_offset, nested_offset, add_null); -} - -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - -static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset = -1, - ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - -static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset = -1, - const ValidityMask *parent_mask = nullptr, uint64_t parent_offset = 0); - -namespace { - -struct ArrowListOffsetData { - idx_t list_size = 0; - idx_t start_offset = 0; -}; - -} // namespace - -template -static ArrowListOffsetData ConvertArrowListOffsetsTemplated(Vector &vector, ArrowArray &array, idx_t size, - idx_t effective_offset) { - ArrowListOffsetData result; - auto &start_offset = result.start_offset; - auto &list_size = result.list_size; - - if (size == 0) { - start_offset = 0; - list_size = 0; - return result; - } - - idx_t cur_offset = 0; - auto offsets = ArrowBufferData(array, 1) + effective_offset; - start_offset = offsets[0]; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = cur_offset; - le.length = offsets[i + 1] - offsets[i]; - cur_offset += le.length; - } - list_size = offsets[size]; - list_size -= start_offset; - return result; -} - -template -static ArrowListOffsetData ConvertArrowListViewOffsetsTemplated(Vector &vector, ArrowArray &array, idx_t size, - idx_t effective_offset) { - ArrowListOffsetData result; - auto &start_offset = result.start_offset; - auto &list_size = result.list_size; - - list_size = 0; - auto offsets = ArrowBufferData(array, 1) + effective_offset; - auto sizes = ArrowBufferData(array, 2) + effective_offset; - - // In ListArrays the offsets have to be sequential - // ListViewArrays do not have this same constraint - // for that reason we need to keep track of the lowest offset, so we can skip all the data that comes before it - // when we scan the child data - - auto lowest_offset = size ? offsets[0] : 0; - auto list_data = FlatVector::GetData(vector); - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = offsets[i]; - le.length = sizes[i]; - list_size += le.length; - if (sizes[i] != 0) { - lowest_offset = MinValue(lowest_offset, offsets[i]); - } - } - start_offset = lowest_offset; - if (start_offset) { - // We start scanning the child data at the 'start_offset' so we need to fix up the created list entries - for (idx_t i = 0; i < size; i++) { - auto &le = list_data[i]; - le.offset = le.offset <= start_offset ? 
0 : le.offset - start_offset; - } - } - return result; -} - -static ArrowListOffsetData ConvertArrowListOffsets(Vector &vector, ArrowArray &array, idx_t size, - const ArrowType &arrow_type, idx_t effective_offset) { - auto &list_info = arrow_type.GetTypeInfo(); - auto size_type = list_info.GetSizeType(); - if (list_info.IsView()) { - if (size_type == ArrowVariableSizeType::NORMAL) { - return ConvertArrowListViewOffsetsTemplated(vector, array, size, effective_offset); - } else { - D_ASSERT(size_type == ArrowVariableSizeType::SUPER_SIZE); - return ConvertArrowListViewOffsetsTemplated(vector, array, size, effective_offset); - } - } else { - if (size_type == ArrowVariableSizeType::NORMAL) { - return ConvertArrowListOffsetsTemplated(vector, array, size, effective_offset); - } else { - D_ASSERT(size_type == ArrowVariableSizeType::SUPER_SIZE); - return ConvertArrowListOffsetsTemplated(vector, array, size, effective_offset); - } - } -} - -static void ArrowToDuckDBList(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, - int64_t parent_offset) { - auto &scan_state = array_state.state; - - auto &list_info = arrow_type.GetTypeInfo(); - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); - - auto effective_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - auto list_data = ConvertArrowListOffsets(vector, array, size, arrow_type, effective_offset); - auto &start_offset = list_data.start_offset; - auto &list_size = list_data.list_size; - - ListVector::Reserve(vector, list_size); - ListVector::SetListSize(vector, list_size); - auto &child_vector = ListVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, list_size, array.offset, - NumericCast(start_offset)); - auto &list_mask = FlatVector::Validity(vector); - if (parent_mask) { - //! 
Since this List is owned by a struct we must guarantee their validity map matches on Null - if (!parent_mask->AllValid()) { - for (idx_t i = 0; i < size; i++) { - if (!parent_mask->RowIsValid(i)) { - list_mask.SetInvalid(i); - } - } - } - } - auto &child_state = array_state.GetChild(0); - auto &child_array = *array.children[0]; - auto &child_type = list_info.GetChild(); - - if (list_size == 0 && start_offset == 0) { - D_ASSERT(!child_array.dictionary); - ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, -1); - return; - } - - auto array_physical_type = GetArrowArrayPhysicalType(child_type); - switch (array_physical_type) { - case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - // TODO: add support for offsets - ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); - break; - case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); - break; - case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child_vector, child_array, child_state, list_size, child_type, - NumericCast(start_offset)); - break; - default: - throw NotImplementedException("ArrowArrayPhysicalType not recognized"); - } -} - -static void ArrowToDuckDBArray(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, const ValidityMask *parent_mask, - int64_t parent_offset) { - - auto &array_info = arrow_type.GetTypeInfo(); - auto &scan_state = array_state.state; - auto array_size = array_info.FixedSize(); - auto child_count = array_size * size; - auto child_offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) * array_size; - - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); - - auto &child_vector = ArrayVector::GetEntry(vector); - SetValidityMask(child_vector, *array.children[0], scan_state, child_count, array.offset, - NumericCast(child_offset)); - - auto &array_mask = FlatVector::Validity(vector); - if (parent_mask) { - //! 
Since this List is owned by a struct we must guarantee their validity map matches on Null - if (!parent_mask->AllValid()) { - for (idx_t i = 0; i < size; i++) { - if (!parent_mask->RowIsValid(i)) { - array_mask.SetInvalid(i); - } - } - } - } - - // Broadcast the validity mask to the child vector - if (!array_mask.AllValid()) { - auto &child_validity_mask = FlatVector::Validity(child_vector); - for (idx_t i = 0; i < size; i++) { - if (!array_mask.RowIsValid(i)) { - for (idx_t j = 0; j < array_size; j++) { - child_validity_mask.SetInvalid(i * array_size + j); - } - } - } - } - - auto &child_state = array_state.GetChild(0); - auto &child_array = *array.children[0]; - auto &child_type = array_info.GetChild(); - if (child_count == 0 && child_offset == 0) { - D_ASSERT(!child_array.dictionary); - ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, -1); - } else { - if (child_array.dictionary) { - ColumnArrowToDuckDBDictionary(child_vector, child_array, child_state, child_count, child_type, - NumericCast(child_offset)); - } else { - ColumnArrowToDuckDB(child_vector, child_array, child_state, child_count, child_type, - NumericCast(child_offset)); - } - } -} - -static void ArrowToDuckDBBlob(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, int64_t parent_offset) { - SetValidityMask(vector, array, scan_state, size, parent_offset, nested_offset); - auto &string_info = arrow_type.GetTypeInfo(); - auto size_type = string_info.GetSizeType(); - if (size_type == ArrowVariableSizeType::FIXED_SIZE) { - auto fixed_size = string_info.FixedSize(); - //! Have to check validity mask before setting this up - idx_t offset = GetEffectiveOffset(array, parent_offset, scan_state, nested_offset) * fixed_size; - auto cdata = ArrowBufferData(array, 1); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offset; - auto blob_len = fixed_size; - FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); - offset += blob_len; - } - } else if (size_type == ArrowVariableSizeType::NORMAL) { - auto offsets = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - auto cdata = ArrowBufferData(array, 2); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offsets[row_idx]; - auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; - FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); - } - } else { - //! 
Check if last offset is higher than max uint32 - if (ArrowBufferData(array, 1)[array.length] > NumericLimits::Maximum()) { // LCOV_EXCL_START - throw ConversionException("DuckDB does not support Blobs over 4GB"); - } // LCOV_EXCL_STOP - auto offsets = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - auto cdata = ArrowBufferData(array, 2); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto bptr = cdata + offsets[row_idx]; - auto blob_len = offsets[row_idx + 1] - offsets[row_idx]; - FlatVector::GetData(vector)[row_idx] = StringVector::AddStringOrBlob(vector, bptr, blob_len); - } - } -} - -static void ArrowToDuckDBMapVerify(Vector &vector, idx_t count) { - auto valid_check = MapVector::CheckMapValidity(vector, count); - switch (valid_check) { - case MapInvalidReason::VALID: - break; - case MapInvalidReason::DUPLICATE_KEY: { - throw InvalidInputException("Arrow map contains duplicate key, which isn't supported by DuckDB map type"); - } - case MapInvalidReason::NULL_KEY: { - throw InvalidInputException("Arrow map contains NULL as map key, which isn't supported by DuckDB map type"); - } - default: { - throw InternalException("MapInvalidReason not implemented"); - } - } -} - -template -static void SetVectorString(Vector &vector, idx_t size, char *cdata, T *offsets) { - auto strings = FlatVector::GetData(vector); - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto cptr = cdata + offsets[row_idx]; - auto str_len = offsets[row_idx + 1] - offsets[row_idx]; - if (str_len > NumericLimits::Maximum()) { // LCOV_EXCL_START - throw ConversionException("DuckDB does not support Strings over 4GB"); - } // LCOV_EXCL_STOP - strings[row_idx] = string_t(cptr, UnsafeNumericCast(str_len)); - } -} - -static void SetVectorStringView(Vector &vector, idx_t size, ArrowArray &array, idx_t current_pos) { - auto strings = FlatVector::GetData(vector); - auto arrow_string = ArrowBufferData(array, 1) + current_pos; - - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - if (FlatVector::IsNull(vector, row_idx)) { - continue; - } - auto length = UnsafeNumericCast(arrow_string[row_idx].Length()); - if (arrow_string[row_idx].IsInline()) { - // This string is inlined - // | Bytes 0-3 | Bytes 4-15 | - // |------------|---------------------------------------| - // | length | data (padded with 0) | - strings[row_idx] = string_t(arrow_string[row_idx].GetInlineData(), length); - } else { - // This string is not inlined, we have to check a different buffer and offsets - // | Bytes 0-3 | Bytes 4-7 | Bytes 8-11 | Bytes 12-15 | - // |------------|------------|------------|-------------| - // | length | prefix | buf. 
index | offset | - auto buffer_index = UnsafeNumericCast(arrow_string[row_idx].GetBufferIndex()); - int32_t offset = arrow_string[row_idx].GetOffset(); - D_ASSERT(array.n_buffers > 2 + buffer_index); - auto c_data = ArrowBufferData(array, 2 + buffer_index); - strings[row_idx] = string_t(&c_data[offset], length); - } - } -} - -static void DirectConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, uint64_t parent_offset) { - auto internal_type = GetTypeIdSize(vector.GetType().InternalType()); - auto data_ptr = - ArrowBufferData(array, 1) + - internal_type * GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - FlatVector::SetData(vector, data_ptr); -} - -template -static void TimeConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = - static_cast(array.buffers[1]) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - if (!validity_mask.RowIsValid(row)) { - continue; - } - if (!TryMultiplyOperator::Operation(static_cast(src_ptr[row]), conversion, tgt_ptr[row].micros)) { - throw ConversionException("Could not convert Time to Microsecond"); - } - } -} - -static void UUIDConversion(Vector &vector, const ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size) { - auto tgt_ptr = FlatVector::GetData(vector); - auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = static_cast(array.buffers[1]) + - GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - if (!validity_mask.RowIsValid(row)) { - continue; - } - tgt_ptr[row].lower = static_cast(BSwap(src_ptr[row].upper)); - // flip Upper MSD - tgt_ptr[row].upper = - static_cast(static_cast(BSwap(src_ptr[row].lower)) ^ (static_cast(1) << 63)); - } -} - -static void TimestampTZConversion(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto &validity_mask = FlatVector::Validity(vector); - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - if (!validity_mask.RowIsValid(row)) { - continue; - } - if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].value)) { - throw ConversionException("Could not convert TimestampTZ to Microsecond"); - } - } -} - -static void IntervalConversionUs(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size, int64_t conversion) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = 0; - tgt_ptr[row].months = 0; - if (!TryMultiplyOperator::Operation(src_ptr[row], conversion, tgt_ptr[row].micros)) { - throw ConversionException("Could not convert Interval to Microsecond"); - } - } -} - -static void IntervalConversionMonths(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t 
nested_offset, int64_t parent_offset, idx_t size) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = 0; - tgt_ptr[row].micros = 0; - tgt_ptr[row].months = src_ptr[row]; - } -} - -static void IntervalConversionMonthDayNanos(Vector &vector, ArrowArray &array, const ArrowScanLocalState &scan_state, - int64_t nested_offset, int64_t parent_offset, idx_t size) { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = - ArrowBufferData(array, 1) + GetEffectiveOffset(array, parent_offset, scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].days = src_ptr[row].days; - tgt_ptr[row].micros = src_ptr[row].nanoseconds / Interval::NANOS_PER_MICRO; - tgt_ptr[row].months = src_ptr[row].months; - } -} - -// Find the index of the first run-end that is strictly greater than the offset. -// count is returned if no such run-end is found. -template -static idx_t FindRunIndex(const RUN_END_TYPE *run_ends, idx_t count, idx_t offset) { - // Binary-search within the [0, count) range. For example: - // [0, 0, 0, 1, 1, 2] encoded as - // run_ends: [3, 5, 6]: - // 0, 1, 2 -> 0 - // 3, 4 -> 1 - // 5 -> 2 - // 6, 7 .. -> 3 (3 == count [not found]) - idx_t begin = 0; - idx_t end = count; - while (begin < end) { - idx_t middle = (begin + end) / 2; - // begin < end implies middle < end - if (offset >= static_cast(run_ends[middle])) { - // keep searching in [middle + 1, end) - begin = middle + 1; - } else { - // offset < run_ends[middle], so keep searching in [begin, middle) - end = middle; - } - } - return begin; -} - -template -static void FlattenRunEnds(Vector &result, ArrowRunEndEncodingState &run_end_encoding, idx_t compressed_size, - idx_t scan_offset, idx_t count) { - auto &runs = *run_end_encoding.run_ends; - auto &values = *run_end_encoding.values; - - UnifiedVectorFormat run_end_format; - UnifiedVectorFormat value_format; - runs.ToUnifiedFormat(compressed_size, run_end_format); - values.ToUnifiedFormat(compressed_size, value_format); - auto run_ends_data = run_end_format.GetData(run_end_format); - auto values_data = value_format.GetData(value_format); - auto result_data = FlatVector::GetData(result); - auto &validity = FlatVector::Validity(result); - - // According to the arrow spec, the 'run_ends' array is always valid - // so we will assume this is true and not check the validity map - - // Now construct the result vector from the run_ends and the values - - auto run = FindRunIndex(run_ends_data, compressed_size, scan_offset); - idx_t logical_index = scan_offset; - idx_t index = 0; - if (value_format.validity.AllValid()) { - // None of the compressed values are NULL - for (; run < compressed_size; ++run) { - auto run_end_index = run_end_format.sel->get_index(run); - auto value_index = value_format.sel->get_index(run); - auto &value = values_data[value_index]; - auto run_end = static_cast(run_ends_data[run_end_index]); - - D_ASSERT(run_end > (logical_index + index)); - auto to_scan = run_end - (logical_index + index); - // Cap the amount to scan so we don't go over size - to_scan = MinValue(to_scan, (count - index)); - - for (idx_t i = 0; i < to_scan; i++) { - result_data[index + i] = value; - } - index += to_scan; - if (index >= count) { - if (logical_index + index >= run_end) { - // The last run was completed, forward the run index - ++run; - } - break; - } - } - } else { - for (; run < 
compressed_size; ++run) { - auto run_end_index = run_end_format.sel->get_index(run); - auto value_index = value_format.sel->get_index(run); - auto run_end = static_cast(run_ends_data[run_end_index]); - - D_ASSERT(run_end > (logical_index + index)); - auto to_scan = run_end - (logical_index + index); - // Cap the amount to scan so we don't go over size - to_scan = MinValue(to_scan, (count - index)); - - if (value_format.validity.RowIsValidUnsafe(value_index)) { - auto &value = values_data[value_index]; - for (idx_t i = 0; i < to_scan; i++) { - result_data[index + i] = value; - validity.SetValid(index + i); - } - } else { - for (idx_t i = 0; i < to_scan; i++) { - validity.SetInvalid(index + i); - } - } - index += to_scan; - if (index >= count) { - if (logical_index + index >= run_end) { - // The last run was completed, forward the run index - ++run; - } - break; - } - } - } -} - -template -static void FlattenRunEndsSwitch(Vector &result, ArrowRunEndEncodingState &run_end_encoding, idx_t compressed_size, - idx_t scan_offset, idx_t size) { - auto &values = *run_end_encoding.values; - auto physical_type = values.GetType().InternalType(); - - switch (physical_type) { - case PhysicalType::INT8: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT16: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT32: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT64: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT128: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::UINT8: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::UINT16: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::UINT32: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::UINT64: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::BOOL: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::FLOAT: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::DOUBLE: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INTERVAL: - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::VARCHAR: { - // Share the string heap, we don't need to allocate new strings, we just reference the existing ones - result.SetAuxiliary(values.GetAuxiliary()); - FlattenRunEnds(result, run_end_encoding, compressed_size, scan_offset, size); - break; - } - default: - throw NotImplementedException("RunEndEncoded value type '%s' not supported yet", TypeIdToString(physical_type)); - } -} - -static void ColumnArrowToDuckDBRunEndEncoded(Vector &vector, const ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset, - ValidityMask *parent_mask, uint64_t parent_offset) { - // Scan the 'run_ends' array - D_ASSERT(array.n_children == 2); - auto &run_ends_array = *array.children[0]; - auto &values_array = *array.children[1]; - - auto &struct_info = 
arrow_type.GetTypeInfo(); - auto &run_ends_type = struct_info.GetChild(0); - auto &values_type = struct_info.GetChild(1); - D_ASSERT(vector.GetType() == values_type.GetDuckType()); - - auto &scan_state = array_state.state; - if (vector.GetBuffer()) { - vector.GetBuffer()->SetAuxiliaryData(make_uniq(array_state.owned_data)); - } - - D_ASSERT(run_ends_array.length == values_array.length); - auto compressed_size = NumericCast(run_ends_array.length); - // Create a vector for the run ends and the values - auto &run_end_encoding = array_state.RunEndEncoding(); - if (!run_end_encoding.run_ends) { - // The run ends and values have not been scanned yet for this array - D_ASSERT(!run_end_encoding.values); - run_end_encoding.run_ends = make_uniq(run_ends_type.GetDuckType(), compressed_size); - run_end_encoding.values = make_uniq(values_type.GetDuckType(), compressed_size); - - ColumnArrowToDuckDB(*run_end_encoding.run_ends, run_ends_array, array_state, compressed_size, run_ends_type); - auto &values = *run_end_encoding.values; - SetValidityMask(values, values_array, scan_state, compressed_size, NumericCast(parent_offset), - nested_offset); - ColumnArrowToDuckDB(values, values_array, array_state, compressed_size, values_type); - } - - idx_t scan_offset = GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - auto physical_type = run_ends_type.GetDuckType().InternalType(); - switch (physical_type) { - case PhysicalType::INT16: - FlattenRunEndsSwitch(vector, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT32: - FlattenRunEndsSwitch(vector, run_end_encoding, compressed_size, scan_offset, size); - break; - case PhysicalType::INT64: - FlattenRunEndsSwitch(vector, run_end_encoding, compressed_size, scan_offset, size); - break; - default: - throw NotImplementedException("Type '%s' not implemented for RunEndEncoding", TypeIdToString(physical_type)); - } -} - -static void ColumnArrowToDuckDB(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, idx_t size, - const ArrowType &arrow_type, int64_t nested_offset, ValidityMask *parent_mask, - uint64_t parent_offset) { - auto &scan_state = array_state.state; - D_ASSERT(!array.dictionary); - - if (vector.GetBuffer()) { - vector.GetBuffer()->SetAuxiliaryData(make_uniq(array_state.owned_data)); - } - switch (vector.GetType().id()) { - case LogicalTypeId::SQLNULL: - vector.Reference(Value()); - break; - case LogicalTypeId::BOOLEAN: { - //! Arrow bit-packs boolean values - //! 
Lets first figure out where we are in the source array - auto effective_offset = - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - auto src_ptr = ArrowBufferData(array, 1) + effective_offset / 8; - auto tgt_ptr = (uint8_t *)FlatVector::GetData(vector); - int src_pos = 0; - idx_t cur_bit = effective_offset % 8; - for (idx_t row = 0; row < size; row++) { - if ((src_ptr[src_pos] & (1 << cur_bit)) == 0) { - tgt_ptr[row] = 0; - } else { - tgt_ptr[row] = 1; - } - cur_bit++; - if (cur_bit == 8) { - src_pos++; - cur_bit = 0; - } - } - break; - } - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::BIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::TIME_TZ: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case LogicalTypeId::UUID: - UUIDConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size); - break; - case LogicalTypeId::VARCHAR: { - auto &string_info = arrow_type.GetTypeInfo(); - auto size_type = string_info.GetSizeType(); - switch (size_type) { - case ArrowVariableSizeType::SUPER_SIZE: { - auto cdata = ArrowBufferData(array, 2); - auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - SetVectorString(vector, size, cdata, offsets); - break; - } - case ArrowVariableSizeType::NORMAL: - case ArrowVariableSizeType::FIXED_SIZE: { - auto cdata = ArrowBufferData(array, 2); - auto offsets = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - SetVectorString(vector, size, cdata, offsets); - break; - } - case ArrowVariableSizeType::VIEW: { - SetVectorStringView( - vector, size, array, - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset)); - break; - } - } - break; - } - case LogicalTypeId::DATE: { - auto &datetime_info = arrow_type.GetTypeInfo(); - auto precision = datetime_info.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::DAYS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - //! 
convert date from milliseconds to days - auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row] = date_t(UnsafeNumericCast(static_cast(src_ptr[row]) / - static_cast(1000 * 60 * 60 * 24))); - } - break; - } - default: - throw NotImplementedException("Unsupported precision for Date Type "); - } - break; - } - case LogicalTypeId::TIME: { - auto &datetime_info = arrow_type.GetTypeInfo(); - auto precision = datetime_info.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000000); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - TimeConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].micros = src_ptr[row] / 1000; - } - break; - } - default: - throw NotImplementedException("Unsupported precision for Time Type "); - } - break; - } - case LogicalTypeId::TIMESTAMP_TZ: { - auto &datetime_info = arrow_type.GetTypeInfo(); - auto precision = datetime_info.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000000); - break; - } - case ArrowDateTimeType::MILLISECONDS: { - TimestampTZConversion(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - DirectConversion(vector, array, scan_state, nested_offset, parent_offset); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].value = src_ptr[row] / 1000; - } - break; - } - default: - throw NotImplementedException("Unsupported precision for TimestampTZ Type "); - } - break; - } - case LogicalTypeId::INTERVAL: { - auto &datetime_info = arrow_type.GetTypeInfo(); - auto precision = datetime_info.GetDateTimeType(); - switch (precision) { - case ArrowDateTimeType::SECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000000); - break; - } - case ArrowDateTimeType::DAYS: - case ArrowDateTimeType::MILLISECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1000); - break; - } - case ArrowDateTimeType::MICROSECONDS: { - IntervalConversionUs(vector, array, scan_state, nested_offset, NumericCast(parent_offset), size, - 1); - break; - } - case ArrowDateTimeType::NANOSECONDS: { - auto tgt_ptr = FlatVector::GetData(vector); - auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - for (idx_t row = 0; row < size; row++) { - tgt_ptr[row].micros = src_ptr[row] / 1000; - 
tgt_ptr[row].days = 0; - tgt_ptr[row].months = 0; - } - break; - } - case ArrowDateTimeType::MONTHS: { - IntervalConversionMonths(vector, array, scan_state, nested_offset, NumericCast(parent_offset), - size); - break; - } - case ArrowDateTimeType::MONTH_DAY_NANO: { - IntervalConversionMonthDayNanos(vector, array, scan_state, nested_offset, - NumericCast(parent_offset), size); - break; - } - default: - throw NotImplementedException("Unsupported precision for Interval/Duration Type "); - } - break; - } - case LogicalTypeId::DECIMAL: { - auto val_mask = FlatVector::Validity(vector); - //! We have to convert from INT128 - auto src_ptr = ArrowBufferData(array, 1) + - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - switch (vector.GetType().InternalType()) { - case PhysicalType::INT16: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT32: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT64: { - auto tgt_ptr = FlatVector::GetData(vector); - for (idx_t row = 0; row < size; row++) { - if (val_mask.RowIsValid(row)) { - auto result = Hugeint::TryCast(src_ptr[row], tgt_ptr[row]); - D_ASSERT(result); - (void)result; - } - } - break; - } - case PhysicalType::INT128: { - FlatVector::SetData(vector, ArrowBufferData(array, 1) + - GetTypeIdSize(vector.GetType().InternalType()) * - GetEffectiveOffset(array, NumericCast(parent_offset), - scan_state, nested_offset)); - break; - } - default: - throw NotImplementedException("Unsupported physical type for Decimal: %s", - TypeIdToString(vector.GetType().InternalType())); - } - break; - } - case LogicalTypeId::BLOB: - case LogicalTypeId::BIT: - case LogicalTypeId::VARINT: { - ArrowToDuckDBBlob(vector, array, scan_state, size, arrow_type, nested_offset, - NumericCast(parent_offset)); - break; - } - case LogicalTypeId::LIST: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, - NumericCast(parent_offset)); - break; - } - case LogicalTypeId::ARRAY: { - ArrowToDuckDBArray(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, - NumericCast(parent_offset)); - break; - } - case LogicalTypeId::MAP: { - ArrowToDuckDBList(vector, array, array_state, size, arrow_type, nested_offset, parent_mask, - NumericCast(parent_offset)); - ArrowToDuckDBMapVerify(vector, size); - break; - } - case LogicalTypeId::STRUCT: { - //! 
Fill the children - auto &struct_info = arrow_type.GetTypeInfo(); - auto &child_entries = StructVector::GetEntries(vector); - auto &struct_validity_mask = FlatVector::Validity(vector); - for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { - auto &child_entry = *child_entries[child_idx]; - auto &child_array = *array.children[child_idx]; - auto &child_type = struct_info.GetChild(child_idx); - auto &child_state = array_state.GetChild(child_idx); - - SetValidityMask(child_entry, child_array, scan_state, size, array.offset, nested_offset); - if (!struct_validity_mask.AllValid()) { - auto &child_validity_mark = FlatVector::Validity(child_entry); - for (idx_t i = 0; i < size; i++) { - if (!struct_validity_mask.RowIsValid(i)) { - child_validity_mark.SetInvalid(i); - } - } - } - - auto array_physical_type = GetArrowArrayPhysicalType(child_type); - switch (array_physical_type) { - case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset)); - break; - case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset)); - break; - case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child_entry, child_array, child_state, size, child_type, nested_offset, - &struct_validity_mask, NumericCast(array.offset)); - break; - default: - throw NotImplementedException("ArrowArrayPhysicalType not recognized"); - } - } - break; - } - case LogicalTypeId::UNION: { - auto type_ids = ArrowBufferData(array, array.n_buffers == 1 ? 0 : 1); - D_ASSERT(type_ids); - auto members = UnionType::CopyMemberTypes(vector.GetType()); - - auto &validity_mask = FlatVector::Validity(vector); - auto &union_info = arrow_type.GetTypeInfo(); - duckdb::vector children; - for (idx_t child_idx = 0; child_idx < NumericCast(array.n_children); child_idx++) { - Vector child(members[child_idx].second, size); - auto &child_array = *array.children[child_idx]; - auto &child_state = array_state.GetChild(child_idx); - auto &child_type = union_info.GetChild(child_idx); - - SetValidityMask(child, child_array, scan_state, size, NumericCast(parent_offset), nested_offset); - auto array_physical_type = GetArrowArrayPhysicalType(child_type); - - switch (array_physical_type) { - case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(child, child_array, child_state, size, child_type); - break; - case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(child, child_array, child_state, size, child_type); - break; - case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(child, child_array, child_state, size, child_type, nested_offset, &validity_mask); - break; - default: - throw NotImplementedException("ArrowArrayPhysicalType not recognized"); - } - - children.push_back(std::move(child)); - } - - for (idx_t row_idx = 0; row_idx < size; row_idx++) { - auto tag = NumericCast(type_ids[row_idx]); - - auto out_of_range = tag >= array.n_children; - if (out_of_range) { - throw InvalidInputException("Arrow union tag out of range: %d", tag); - } - - const Value &value = children[tag].GetValue(row_idx); - vector.SetValue(row_idx, value.IsNull() ? 
Value() : Value::UNION(members, tag, value)); - } - - break; - } - default: - throw NotImplementedException("Unsupported type for arrow conversion: %s", vector.GetType().ToString()); - } -} - -template -static void SetSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - sel.set_index(row, UnsafeNumericCast(indices[row])); - } -} - -template -static void SetSelectionVectorLoopWithChecks(SelectionVector &sel, data_ptr_t indices_p, idx_t size) { - - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - if (indices[row] > NumericLimits::Maximum()) { - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - sel.set_index(row, NumericCast(indices[row])); - } -} - -template -static void SetMaskedSelectionVectorLoop(SelectionVector &sel, data_ptr_t indices_p, idx_t size, ValidityMask &mask, - idx_t last_element_pos) { - auto indices = reinterpret_cast(indices_p); - for (idx_t row = 0; row < size; row++) { - if (mask.RowIsValid(row)) { - sel.set_index(row, UnsafeNumericCast(indices[row])); - } else { - //! Need to point to the last element - sel.set_index(row, last_element_pos); - } - } -} - -static void SetSelectionVector(SelectionVector &sel, data_ptr_t indices_p, const LogicalType &logical_type, idx_t size, - ValidityMask *mask = nullptr, idx_t last_element_pos = 0) { - sel.Initialize(size); - - if (mask) { - switch (logical_type.id()) { - case LogicalTypeId::UTINYINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::TINYINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::USMALLINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::SMALLINT: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::UINTEGER: - if (last_element_pos > NumericLimits::Maximum()) { - //! It's guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::INTEGER: - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::UBIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! It's guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - case LogicalTypeId::BIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //!
It's guaranteed that our indices will point to the last element, so just throw an error - throw ConversionException("DuckDB only supports indices that fit on an uint32"); - } - SetMaskedSelectionVectorLoop(sel, indices_p, size, *mask, last_element_pos); - break; - - default: - throw NotImplementedException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); - } - - } else { - switch (logical_type.id()) { - case LogicalTypeId::UTINYINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::TINYINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::USMALLINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::SMALLINT: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::UINTEGER: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::INTEGER: - SetSelectionVectorLoop(sel, indices_p, size); - break; - case LogicalTypeId::UBIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! We need to check if our indexes fit in a uint32_t - SetSelectionVectorLoopWithChecks(sel, indices_p, size); - } else { - SetSelectionVectorLoop(sel, indices_p, size); - } - break; - case LogicalTypeId::BIGINT: - if (last_element_pos > NumericLimits::Maximum()) { - //! We need to check if our indexes fit in a uint32_t - SetSelectionVectorLoopWithChecks(sel, indices_p, size); - } else { - SetSelectionVectorLoop(sel, indices_p, size); - } - break; - default: - throw ConversionException("(Arrow) Unsupported type for selection vectors %s", logical_type.ToString()); - } - } -} - -static bool CanContainNull(const ArrowArray &array, const ValidityMask *parent_mask) { - if (array.null_count > 0) { - return true; - } - if (!parent_mask) { - return false; - } - return !parent_mask->AllValid(); -} - -static void ColumnArrowToDuckDBDictionary(Vector &vector, ArrowArray &array, ArrowArrayScanState &array_state, - idx_t size, const ArrowType &arrow_type, int64_t nested_offset, - const ValidityMask *parent_mask, uint64_t parent_offset) { - if (vector.GetBuffer()) { - vector.GetBuffer()->SetAuxiliaryData(make_uniq(array_state.owned_data)); - } - D_ASSERT(arrow_type.HasDictionary()); - auto &scan_state = array_state.state; - const bool has_nulls = CanContainNull(array, parent_mask); - if (array_state.CacheOutdated(array.dictionary)) { - //!
We need to set the dictionary data for this column - auto base_vector = make_uniq(vector.GetType(), NumericCast(array.dictionary->length)); - SetValidityMask(*base_vector, *array.dictionary, scan_state, NumericCast(array.dictionary->length), 0, 0, - has_nulls); - auto &dictionary_type = arrow_type.GetDictionary(); - auto arrow_physical_type = GetArrowArrayPhysicalType(dictionary_type); - switch (arrow_physical_type) { - case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(*base_vector, *array.dictionary, array_state, - NumericCast(array.dictionary->length), dictionary_type); - break; - case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(*base_vector, *array.dictionary, array_state, - NumericCast(array.dictionary->length), dictionary_type); - break; - case ArrowArrayPhysicalType::DEFAULT: - ColumnArrowToDuckDB(*base_vector, *array.dictionary, array_state, - NumericCast(array.dictionary->length), dictionary_type); - break; - default: - throw NotImplementedException("ArrowArrayPhysicalType not recognized"); - }; - array_state.AddDictionary(std::move(base_vector), array.dictionary); - } - auto offset_type = arrow_type.GetDuckType(); - //! Get Pointer to Indices of Dictionary - auto indices = ArrowBufferData(array, 1) + - GetTypeIdSize(offset_type.InternalType()) * - GetEffectiveOffset(array, NumericCast(parent_offset), scan_state, nested_offset); - - SelectionVector sel; - if (has_nulls) { - ValidityMask indices_validity; - GetValidityMask(indices_validity, array, scan_state, size, NumericCast(parent_offset)); - if (parent_mask && !parent_mask->AllValid()) { - auto &struct_validity_mask = *parent_mask; - for (idx_t i = 0; i < size; i++) { - if (!struct_validity_mask.RowIsValid(i)) { - indices_validity.SetInvalid(i); - } - } - } - SetSelectionVector(sel, indices, offset_type, size, &indices_validity, - NumericCast(array.dictionary->length)); - } else { - SetSelectionVector(sel, indices, offset_type, size); - } - vector.Slice(array_state.GetDictionary(), sel, size); - vector.Verify(size); -} - -void ArrowTableFunction::ArrowToDuckDB(ArrowScanLocalState &scan_state, const arrow_column_map_t &arrow_convert_data, - DataChunk &output, idx_t start, bool arrow_scan_is_projected, - idx_t rowid_column_index) { - for (idx_t idx = 0; idx < output.ColumnCount(); idx++) { - auto col_idx = scan_state.column_ids.empty() ? idx : scan_state.column_ids[idx]; - - // If projection was not pushed down into the arrow scanner, but projection pushdown is enabled on the - // table function, we need to use original column ids here. - auto arrow_array_idx = arrow_scan_is_projected ? idx : col_idx; - - if (rowid_column_index != COLUMN_IDENTIFIER_ROW_ID) { - if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - arrow_array_idx = rowid_column_index; - } else if (col_idx >= rowid_column_index) { - // Since the rowid column is skipped when the table is bound (it's not a named column), - // we need to shift references forward in the Arrow array by one to match the alignment - // that DuckDB believes the Arrow array is in. - col_idx += 1; - arrow_array_idx += 1; - } - } else { - // If there isn't any defined row_id_index, and we're asked for it, skip the column. - // This is the incumbent behavior. 
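// [Editor's note] A worked example of the remapping above, with hypothetical values not taken from this
// diff: suppose the bound table has columns {a, b} and the Arrow array physically stores {a, rowid, b},
// i.e. rowid_column_index = 1. A request for col_idx = 0 ("a") passes through unchanged; a request for
// col_idx = 1 ("b") satisfies col_idx >= rowid_column_index, so col_idx and arrow_array_idx are both
// shifted to 2, skipping the unnamed rowid child; a request for COLUMN_IDENTIFIER_ROW_ID itself maps
// directly to arrow_array_idx = 1.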
- if (col_idx == COLUMN_IDENTIFIER_ROW_ID) { - continue; - } - } - - auto &parent_array = scan_state.chunk->arrow_array; - auto &array = *scan_state.chunk->arrow_array.children[arrow_array_idx]; - if (!array.release) { - throw InvalidInputException("arrow_scan: released array passed"); - } - if (array.length != scan_state.chunk->arrow_array.length) { - throw InvalidInputException("arrow_scan: array length mismatch"); - } - - D_ASSERT(arrow_convert_data.find(col_idx) != arrow_convert_data.end()); - auto &arrow_type = *arrow_convert_data.at(col_idx); - auto &array_state = scan_state.GetState(col_idx); - - // Make sure this Vector keeps the Arrow chunk alive in case we can zero-copy the data - if (!array_state.owned_data) { - array_state.owned_data = scan_state.chunk; - } - - auto array_physical_type = GetArrowArrayPhysicalType(arrow_type); - - switch (array_physical_type) { - case ArrowArrayPhysicalType::DICTIONARY_ENCODED: - ColumnArrowToDuckDBDictionary(output.data[idx], array, array_state, output.size(), arrow_type); - break; - case ArrowArrayPhysicalType::RUN_END_ENCODED: - ColumnArrowToDuckDBRunEndEncoded(output.data[idx], array, array_state, output.size(), arrow_type); - break; - case ArrowArrayPhysicalType::DEFAULT: - SetValidityMask(output.data[idx], array, scan_state, output.size(), parent_array.offset, -1); - ColumnArrowToDuckDB(output.data[idx], array, array_state, output.size(), arrow_type); - break; - default: - throw NotImplementedException("ArrowArrayPhysicalType not recognized"); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/checkpoint.cpp b/src/duckdb/src/function/table/checkpoint.cpp deleted file mode 100644 index 9fdd19a79..000000000 --- a/src/duckdb/src/function/table/checkpoint.cpp +++ /dev/null @@ -1,69 +0,0 @@ -#include "duckdb/function/table/range.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/transaction/transaction_manager.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/function/function_set.hpp" - -namespace duckdb { - -struct CheckpointBindData : public FunctionData { - explicit CheckpointBindData(optional_ptr db) : db(db) { - } - - optional_ptr db; - -public: - unique_ptr Copy() const override { - return make_uniq(db); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return db == other.db; - } -}; - -static unique_ptr CheckpointBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::BOOLEAN); - names.emplace_back("Success"); - - optional_ptr db; - auto &db_manager = DatabaseManager::Get(context); - if (!input.inputs.empty()) { - if (input.inputs[0].IsNull()) { - throw BinderException("Database cannot be NULL"); - } - auto &db_name = StringValue::Get(input.inputs[0]); - db = db_manager.GetDatabase(context, db_name); - if (!db) { - throw BinderException("Database \"%s\" not found", db_name); - } - } else { - db = db_manager.GetDatabase(context, DatabaseManager::GetDefaultDatabase(context)); - } - return make_uniq(db); -} - -template -static void TemplatedCheckpointFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &transaction_manager = TransactionManager::Get(*bind_data.db.get_mutable()); - transaction_manager.Checkpoint(context, FORCE); -} - -void CheckpointFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet 
checkpoint("checkpoint"); - checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); - checkpoint.AddFunction(TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); - set.AddFunction(checkpoint); - - TableFunctionSet force_checkpoint("force_checkpoint"); - force_checkpoint.AddFunction(TableFunction({}, TemplatedCheckpointFunction, CheckpointBind)); - force_checkpoint.AddFunction( - TableFunction({LogicalType::VARCHAR}, TemplatedCheckpointFunction, CheckpointBind)); - set.AddFunction(force_checkpoint); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/copy_csv.cpp b/src/duckdb/src/function/table/copy_csv.cpp deleted file mode 100644 index 42e3f0610..000000000 --- a/src/duckdb/src/function/table/copy_csv.cpp +++ /dev/null @@ -1,660 +0,0 @@ -#include "duckdb/common/bind_helpers.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/serializer/memory_stream.hpp" -#include "duckdb/common/serializer/write_stream.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/types/column/column_data_collection.hpp" -#include "duckdb/common/types/string_type.hpp" -#include "duckdb/common/vector_operations/vector_operations.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/function/copy_function.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/table/read_csv.hpp" -#include "duckdb/parser/expression/bound_expression.hpp" -#include "duckdb/parser/expression/cast_expression.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/parser/parsed_data/copy_info.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" - -#include - -namespace duckdb { - -void AreOptionsEqual(char str_1, char str_2, const string &name_str_1, const string &name_str_2) { - if (str_1 == '\0' || str_2 == '\0') { - return; - } - if (str_1 == str_2) { - throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); - } -} - -void SubstringDetection(char str_1, string &str_2, const string &name_str_1, const string &name_str_2) { - if (str_1 == '\0' || str_2.empty()) { - return; - } - if (str_2.find(str_1) != string::npos) { - throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); - } -} - -void StringDetection(const string &str_1, const string &str_2, const string &name_str_1, const string &name_str_2) { - if (str_1.empty() || str_2.empty()) { - return; - } - if (str_2.find(str_1) != string::npos) { - throw BinderException("%s must not appear in the %s specification and vice versa", name_str_1, name_str_2); - } -} - -//===--------------------------------------------------------------------===// -// Bind -//===--------------------------------------------------------------------===// -void WriteQuoteOrEscape(WriteStream &writer, char quote_or_escape) { - if (quote_or_escape != '\0') { - writer.Write(quote_or_escape); - } -} - -void BaseCSVData::Finalize() { - auto delimiter_string = options.dialect_options.state_machine_options.delimiter.GetValue(); - - // quote and delimiter must not be substrings of each other - SubstringDetection(options.dialect_options.state_machine_options.quote.GetValue(), delimiter_string, "QUOTE", - "DELIMITER"); - - // escape and delimiter must not be substrings of each 
other - SubstringDetection(options.dialect_options.state_machine_options.escape.GetValue(), delimiter_string, "ESCAPE", - "DELIMITER"); - - // escape and quote must not be substrings of each other (but can be the same) - if (options.dialect_options.state_machine_options.quote != options.dialect_options.state_machine_options.escape) { - AreOptionsEqual(options.dialect_options.state_machine_options.quote.GetValue(), - options.dialect_options.state_machine_options.escape.GetValue(), "QUOTE", "ESCAPE"); - } - - // comment and quote must not be substrings of each other - AreOptionsEqual(options.dialect_options.state_machine_options.comment.GetValue(), - options.dialect_options.state_machine_options.quote.GetValue(), "COMMENT", "QUOTE"); - - // delimiter and comment must not be substrings of each other - SubstringDetection(options.dialect_options.state_machine_options.comment.GetValue(), delimiter_string, "COMMENT", - "DELIMITER"); - - // null string and delimiter must not be substrings of each other - for (auto &null_str : options.null_str) { - if (!null_str.empty()) { - StringDetection(options.dialect_options.state_machine_options.delimiter.GetValue(), null_str, "DELIMITER", - "NULL"); - - // quote and nullstr must not be substrings of each other - SubstringDetection(options.dialect_options.state_machine_options.quote.GetValue(), null_str, "QUOTE", - "NULL"); - - // Validate the nullstr against the escape character - const char escape = options.dialect_options.state_machine_options.escape.GetValue(); - // Allow nullstr to be escape character + some non-special character, e.g., "\N" (MySQL default). - // In this case, only unquoted occurrences of the nullstr will be recognized as null values. - if (options.dialect_options.state_machine_options.rfc_4180 == false && null_str.size() == 2 && - null_str[0] == escape && null_str[1] != '\0') { - continue; - } - SubstringDetection(escape, null_str, "ESCAPE", "NULL"); - } - } - - if (!options.prefix.empty() || !options.suffix.empty()) { - if (options.prefix.empty() || options.suffix.empty()) { - throw BinderException("COPY ... (FORMAT CSV) must have both PREFIX and SUFFIX, or none at all"); - } - if (options.dialect_options.header.GetValue()) { - throw BinderException("COPY ... (FORMAT CSV)'s HEADER cannot be combined with PREFIX/SUFFIX"); - } - } -} - -string TransformNewLine(string new_line) { - new_line = StringUtil::Replace(new_line, "\\r", "\r"); - return StringUtil::Replace(new_line, "\\n", "\n"); - ; -} - -static vector> CreateCastExpressions(WriteCSVData &bind_data, ClientContext &context, - const vector &names, - const vector &sql_types) { - auto &options = bind_data.options; - auto &formats = options.write_date_format; - - bool has_dateformat = !formats[LogicalTypeId::DATE].IsNull(); - bool has_timestampformat = !formats[LogicalTypeId::TIMESTAMP].IsNull(); - - // Create a binder - auto binder = Binder::CreateBinder(context); - - auto &bind_context = binder->bind_context; - auto table_index = binder->GenerateTableIndex(); - bind_context.AddGenericBinding(table_index, "copy_csv", names, sql_types); - - // Create the ParsedExpressions (cast, strftime, etc..) 
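// [Editor's note] Illustrative only, with hypothetical column names: for a schema (d DATE, v INTEGER)
// where a DATEFORMAT was set, the loop below builds roughly
//     strftime(#0, '<dateformat>')   for the DATE column, and
//     CAST(#1 AS VARCHAR)            for every other column,
// where #i stands for a BoundReferenceExpression into the input chunk; the ExpressionBinder further
// down then resolves these against the generic "copy_csv" binding created above.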
- vector> unbound_expressions; - for (idx_t i = 0; i < sql_types.size(); i++) { - auto &type = sql_types[i]; - auto &name = names[i]; - - bool is_timestamp = type.id() == LogicalTypeId::TIMESTAMP || type.id() == LogicalTypeId::TIMESTAMP_TZ; - if (has_dateformat && type.id() == LogicalTypeId::DATE) { - // strftime(, 'format') - vector> children; - children.push_back(make_uniq(make_uniq(name, type, i))); - children.push_back(make_uniq(formats[LogicalTypeId::DATE])); - auto func = make_uniq_base("strftime", std::move(children)); - unbound_expressions.push_back(std::move(func)); - } else if (has_timestampformat && is_timestamp) { - // strftime(, 'format') - vector> children; - children.push_back(make_uniq(make_uniq(name, type, i))); - children.push_back(make_uniq(formats[LogicalTypeId::TIMESTAMP])); - auto func = make_uniq_base("strftime", std::move(children)); - unbound_expressions.push_back(std::move(func)); - } else { - // CAST AS VARCHAR - auto column = make_uniq(make_uniq(name, type, i)); - auto expr = make_uniq_base(LogicalType::VARCHAR, std::move(column)); - unbound_expressions.push_back(std::move(expr)); - } - } - - // Create an ExpressionBinder, bind the Expressions - vector> expressions; - ExpressionBinder expression_binder(*binder, context); - expression_binder.target_type = LogicalType::VARCHAR; - for (auto &expr : unbound_expressions) { - expressions.push_back(expression_binder.Bind(expr)); - } - - return expressions; -} - -static unique_ptr WriteCSVBind(ClientContext &context, CopyFunctionBindInput &input, - const vector &names, const vector &sql_types) { - auto bind_data = make_uniq(input.info.file_path, sql_types, names); - - // check all the options in the copy info - for (auto &option : input.info.options) { - auto loption = StringUtil::Lower(option.first); - auto &set = option.second; - bind_data->options.SetWriteOption(loption, ConvertVectorToValue(set)); - } - // verify the parsed options - if (bind_data->options.force_quote.empty()) { - // no FORCE_QUOTE specified: initialize to false - bind_data->options.force_quote.resize(names.size(), false); - } - bind_data->Finalize(); - - switch (bind_data->options.compression) { - case FileCompressionType::GZIP: - if (!IsFileCompressed(input.file_extension, FileCompressionType::GZIP)) { - input.file_extension += CompressionExtensionFromType(FileCompressionType::GZIP); - } - break; - case FileCompressionType::ZSTD: - if (!IsFileCompressed(input.file_extension, FileCompressionType::ZSTD)) { - input.file_extension += CompressionExtensionFromType(FileCompressionType::ZSTD); - } - break; - default: - break; - } - - auto expressions = CreateCastExpressions(*bind_data, context, names, sql_types); - bind_data->cast_expressions = std::move(expressions); - - bind_data->requires_quotes = make_unsafe_uniq_array(256); - memset(bind_data->requires_quotes.get(), 0, sizeof(bool) * 256); - bind_data->requires_quotes['\n'] = true; - bind_data->requires_quotes['\r'] = true; - bind_data->requires_quotes[NumericCast( - bind_data->options.dialect_options.state_machine_options.delimiter.GetValue()[0])] = true; - bind_data->requires_quotes[NumericCast( - bind_data->options.dialect_options.state_machine_options.quote.GetValue())] = true; - - if (!bind_data->options.write_newline.empty()) { - bind_data->newline = TransformNewLine(bind_data->options.write_newline); - } - return std::move(bind_data); -} - -static unique_ptr ReadCSVBind(ClientContext &context, CopyInfo &info, vector &expected_names, - vector &expected_types) { - auto bind_data = make_uniq(); - 
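// [Editor's note] This bind path backs COPY ... FROM for the CSV format; a minimal usage sketch, with
// hypothetical table and file names:
//     COPY lineitem FROM 'lineitem.csv' (DELIMITER '|', HEADER false);
// Since the expected names and types come from the target table, the sniffer invoked below mainly has
// to detect the dialect rather than a full schema.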
bind_data->csv_types = expected_types; - bind_data->csv_names = expected_names; - bind_data->return_types = expected_types; - bind_data->return_names = expected_names; - - auto multi_file_reader = MultiFileReader::CreateDefault("CSVCopy"); - bind_data->files = multi_file_reader->CreateFileList(context, Value(info.file_path))->GetAllFiles(); - - auto &options = bind_data->options; - - // check all the options in the copy info - for (auto &option : info.options) { - auto loption = StringUtil::Lower(option.first); - auto &set = option.second; - options.SetReadOption(loption, ConvertVectorToValue(set), expected_names); - } - // verify the parsed options - if (options.force_not_null.empty()) { - // no FORCE_NOT_NULL specified: initialize to false - options.force_not_null.resize(expected_types.size(), false); - } - - // Look for rejects table options last - named_parameter_map_t options_map; - for (auto &option : info.options) { - options_map[option.first] = ConvertVectorToValue(std::move(option.second)); - } - options.file_path = bind_data->files[0]; - options.name_list = expected_names; - options.sql_type_list = expected_types; - options.columns_set = true; - for (idx_t i = 0; i < expected_types.size(); i++) { - options.sql_types_per_column[expected_names[i]] = i; - } - - if (options.auto_detect) { - auto buffer_manager = make_shared_ptr(context, options, bind_data->files[0], 0); - CSVSniffer sniffer(options, buffer_manager, CSVStateMachineCache::Get(context)); - sniffer.SniffCSV(); - } - bind_data->FinalizeRead(context); - - return std::move(bind_data); -} - -//===--------------------------------------------------------------------===// -// Helper writing functions -//===--------------------------------------------------------------------===// -static string AddEscapes(char to_be_escaped, const char escape, const string &val) { - idx_t i = 0; - string new_val = ""; - idx_t found = val.find(to_be_escaped); - - while (found != string::npos) { - while (i < found) { - new_val += val[i]; - i++; - } - if (escape != '\0') { - new_val += escape; - found = val.find(to_be_escaped, found + 1); - } - } - while (i < val.length()) { - new_val += val[i]; - i++; - } - return new_val; -} - -static bool RequiresQuotes(WriteCSVData &csv_data, const char *str, idx_t len) { - auto &options = csv_data.options; - // check if the string is equal to the null string - if (len == options.null_str[0].size() && memcmp(str, options.null_str[0].c_str(), len) == 0) { - return true; - } - auto str_data = reinterpret_cast(str); - for (idx_t i = 0; i < len; i++) { - if (csv_data.requires_quotes[str_data[i]]) { - // this byte requires quotes - write a quoted string - return true; - } - } - // no newline, quote or delimiter in the string - // no quoting or escaping necessary - return false; -} - -static void WriteQuotedString(WriteStream &writer, WriteCSVData &csv_data, const char *str, idx_t len, - bool force_quote) { - auto &options = csv_data.options; - if (!force_quote) { - // force quote is disabled: check if we need to add quotes anyway - force_quote = RequiresQuotes(csv_data, str, len); - } - // If a quote is set to none (i.e., null-terminator) we skip the quotation - if (force_quote && options.dialect_options.state_machine_options.quote.GetValue() != '\0') { - // quoting is enabled: we might need to escape things in the string - bool requires_escape = false; - // simple CSV - // do a single loop to check for a quote or escape value - for (idx_t i = 0; i < len; i++) { - if (str[i] == 
options.dialect_options.state_machine_options.quote.GetValue() || - str[i] == options.dialect_options.state_machine_options.escape.GetValue()) { - requires_escape = true; - break; - } - } - - if (!requires_escape) { - // fast path: no need to escape anything - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote.GetValue()); - writer.WriteData(const_data_ptr_cast(str), len); - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote.GetValue()); - return; - } - - // slow path: need to add escapes - string new_val(str, len); - new_val = AddEscapes(options.dialect_options.state_machine_options.escape.GetValue(), - options.dialect_options.state_machine_options.escape.GetValue(), new_val); - if (options.dialect_options.state_machine_options.escape != - options.dialect_options.state_machine_options.quote) { - // need to escape quotes separately - new_val = AddEscapes(options.dialect_options.state_machine_options.quote.GetValue(), - options.dialect_options.state_machine_options.escape.GetValue(), new_val); - } - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote.GetValue()); - writer.WriteData(const_data_ptr_cast(new_val.c_str()), new_val.size()); - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.quote.GetValue()); - } else { - writer.WriteData(const_data_ptr_cast(str), len); - } -} - -//===--------------------------------------------------------------------===// -// Sink -//===--------------------------------------------------------------------===// -struct LocalWriteCSVData : public LocalFunctionData { -public: - LocalWriteCSVData(ClientContext &context, vector> &expressions) - : executor(context, expressions) { - } - -public: - //! Used to execute the expressions that transform input -> string - ExpressionExecutor executor; - //! The thread-local buffer to write data into - MemoryStream stream; - //! A chunk with VARCHAR columns to cast intermediates into - DataChunk cast_chunk; - //! Whether we've written any rows yet; used to place the newline separators between rows correctly - bool written_anything = false; -}; - -struct GlobalWriteCSVData : public GlobalFunctionData { - GlobalWriteCSVData(FileSystem &fs, const string &file_path, FileCompressionType compression) - : fs(fs), written_anything(false) { - handle = fs.OpenFile(file_path, FileFlags::FILE_FLAGS_WRITE | FileFlags::FILE_FLAGS_FILE_CREATE_NEW | - FileLockType::WRITE_LOCK | compression); - } - - //! Write generic data, e.g., CSV header - void WriteData(const_data_ptr_t data, idx_t size) { - lock_guard flock(lock); - handle->Write((void *)data, size); - } - - void WriteData(const char *data, idx_t size) { - WriteData(const_data_ptr_cast(data), size); - } - - //! Write rows - void WriteRows(const_data_ptr_t data, idx_t size, const string &newline) { - lock_guard flock(lock); - if (written_anything) { - handle->Write((void *)newline.c_str(), newline.length()); - } else { - written_anything = true; - } - handle->Write((void *)data, size); - } - - idx_t FileSize() { - lock_guard flock(lock); - return handle->GetFileSize(); - } - - FileSystem &fs; - //! The mutex for writing to the physical file - mutex lock; - //! The file handle to write to - unique_ptr handle; - //!
Whether we've written any rows yet; used to place the newline separators between flushed batches correctly - bool written_anything; -}; - -static unique_ptr WriteCSVInitializeLocal(ExecutionContext &context, FunctionData &bind_data) { - auto &csv_data = bind_data.Cast(); - auto local_data = make_uniq(context.client, csv_data.cast_expressions); - - // create the chunk with VARCHAR types - vector types; - types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); - - local_data->cast_chunk.Initialize(Allocator::Get(context.client), types); - return std::move(local_data); -} - -static unique_ptr WriteCSVInitializeGlobal(ClientContext &context, FunctionData &bind_data, - const string &file_path) { - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - auto global_data = - make_uniq(FileSystem::GetFileSystem(context), file_path, options.compression); - - if (!options.prefix.empty()) { - global_data->WriteData(options.prefix.c_str(), options.prefix.size()); - } - - if (!(options.dialect_options.header.IsSetByUser() && !options.dialect_options.header.GetValue())) { - MemoryStream stream; - // write the header line to the file - for (idx_t i = 0; i < csv_data.options.name_list.size(); i++) { - if (i != 0) { - WriteQuoteOrEscape(stream, options.dialect_options.state_machine_options.delimiter.GetValue()[0]); - } - WriteQuotedString(stream, csv_data, csv_data.options.name_list[i].c_str(), - csv_data.options.name_list[i].size(), false); - } - stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - - global_data->WriteData(stream.GetData(), stream.GetPosition()); - } - - return std::move(global_data); -} - -static void WriteCSVChunkInternal(ClientContext &context, FunctionData &bind_data, DataChunk &cast_chunk, - MemoryStream &writer, DataChunk &input, bool &written_anything, - ExpressionExecutor &executor) { - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - - // first cast the columns of the chunk to varchar - cast_chunk.Reset(); - cast_chunk.SetCardinality(input); - - executor.Execute(input, cast_chunk); - - cast_chunk.Flatten(); - // now loop over the vectors and output the values - for (idx_t row_idx = 0; row_idx < cast_chunk.size(); row_idx++) { - if (row_idx == 0 && !written_anything) { - written_anything = true; - } else { - writer.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - } - // write values - D_ASSERT(options.null_str.size() == 1); - for (idx_t col_idx = 0; col_idx < cast_chunk.ColumnCount(); col_idx++) { - if (col_idx != 0) { - WriteQuoteOrEscape(writer, options.dialect_options.state_machine_options.delimiter.GetValue()[0]); - } - if (FlatVector::IsNull(cast_chunk.data[col_idx], row_idx)) { - // write null value - writer.WriteData(const_data_ptr_cast(options.null_str[0].c_str()), options.null_str[0].size()); - continue; - } - - // non-null value, fetch the string value from the cast chunk - auto str_data = FlatVector::GetData(cast_chunk.data[col_idx]); - // FIXME: we could gain some performance here by checking for certain types if they ever require quotes - // (e.g. integers only require quotes if the delimiter is a number, decimals only require quotes if the - // delimiter is a number or "." 
character) - WriteQuotedString(writer, csv_data, str_data[row_idx].GetData(), str_data[row_idx].GetSize(), - csv_data.options.force_quote[col_idx]); - } - } -} - -static void WriteCSVSink(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - LocalFunctionData &lstate, DataChunk &input) { - auto &csv_data = bind_data.Cast(); - auto &local_data = lstate.Cast(); - auto &global_state = gstate.Cast(); - - // write data into the local buffer - WriteCSVChunkInternal(context.client, bind_data, local_data.cast_chunk, local_data.stream, input, - local_data.written_anything, local_data.executor); - - // check if we should flush what we have currently written - auto &writer = local_data.stream; - if (writer.GetPosition() >= csv_data.flush_size) { - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); - local_data.written_anything = false; - } -} - -//===--------------------------------------------------------------------===// -// Combine -//===--------------------------------------------------------------------===// -static void WriteCSVCombine(ExecutionContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - LocalFunctionData &lstate) { - auto &local_data = lstate.Cast(); - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &writer = local_data.stream; - // flush the local writer - if (local_data.written_anything) { - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); - } -} - -//===--------------------------------------------------------------------===// -// Finalize -//===--------------------------------------------------------------------===// -void WriteCSVFinalize(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate) { - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &options = csv_data.options; - - MemoryStream stream; - if (!options.suffix.empty()) { - stream.WriteData(const_data_ptr_cast(options.suffix.c_str()), options.suffix.size()); - } else if (global_state.written_anything) { - stream.WriteData(const_data_ptr_cast(csv_data.newline.c_str()), csv_data.newline.size()); - } - global_state.WriteData(stream.GetData(), stream.GetPosition()); - - global_state.handle->Close(); - global_state.handle.reset(); -} - -//===--------------------------------------------------------------------===// -// Execution Mode -//===--------------------------------------------------------------------===// -CopyFunctionExecutionMode WriteCSVExecutionMode(bool preserve_insertion_order, bool supports_batch_index) { - if (!preserve_insertion_order) { - return CopyFunctionExecutionMode::PARALLEL_COPY_TO_FILE; - } - if (supports_batch_index) { - return CopyFunctionExecutionMode::BATCH_COPY_TO_FILE; - } - return CopyFunctionExecutionMode::REGULAR_COPY_TO_FILE; -} -//===--------------------------------------------------------------------===// -// Prepare Batch -//===--------------------------------------------------------------------===// -struct WriteCSVBatchData : public PreparedBatchData { - //! 
The thread-local buffer to write data into - MemoryStream stream; -}; - -unique_ptr WriteCSVPrepareBatch(ClientContext &context, FunctionData &bind_data, - GlobalFunctionData &gstate, - unique_ptr collection) { - auto &csv_data = bind_data.Cast(); - - // create the cast chunk with VARCHAR types - vector types; - types.resize(csv_data.options.name_list.size(), LogicalType::VARCHAR); - DataChunk cast_chunk; - cast_chunk.Initialize(Allocator::Get(context), types); - - auto &original_types = collection->Types(); - auto expressions = CreateCastExpressions(csv_data, context, csv_data.options.name_list, original_types); - ExpressionExecutor executor(context, expressions); - - // write CSV chunks to the batch data - bool written_anything = false; - auto batch = make_uniq(); - for (auto &chunk : collection->Chunks()) { - WriteCSVChunkInternal(context, bind_data, cast_chunk, batch->stream, chunk, written_anything, executor); - } - return std::move(batch); -} - -//===--------------------------------------------------------------------===// -// Flush Batch -//===--------------------------------------------------------------------===// -void WriteCSVFlushBatch(ClientContext &context, FunctionData &bind_data, GlobalFunctionData &gstate, - PreparedBatchData &batch) { - auto &csv_batch = batch.Cast(); - auto &global_state = gstate.Cast(); - auto &csv_data = bind_data.Cast(); - auto &writer = csv_batch.stream; - global_state.WriteRows(writer.GetData(), writer.GetPosition(), csv_data.newline); - writer.Rewind(); -} - -//===--------------------------------------------------------------------===// -// File rotation -//===--------------------------------------------------------------------===// -bool WriteCSVRotateFiles(FunctionData &, const optional_idx &file_size_bytes) { - return file_size_bytes.IsValid(); -} - -bool WriteCSVRotateNextFile(GlobalFunctionData &gstate, FunctionData &, const optional_idx &file_size_bytes) { - auto &global_state = gstate.Cast(); - return global_state.FileSize() > file_size_bytes.GetIndex(); -} - -void CSVCopyFunction::RegisterFunction(BuiltinFunctions &set) { - CopyFunction info("csv"); - info.copy_to_bind = WriteCSVBind; - info.copy_to_initialize_local = WriteCSVInitializeLocal; - info.copy_to_initialize_global = WriteCSVInitializeGlobal; - info.copy_to_sink = WriteCSVSink; - info.copy_to_combine = WriteCSVCombine; - info.copy_to_finalize = WriteCSVFinalize; - info.execution_mode = WriteCSVExecutionMode; - info.prepare_batch = WriteCSVPrepareBatch; - info.flush_batch = WriteCSVFlushBatch; - info.rotate_files = WriteCSVRotateFiles; - info.rotate_next_file = WriteCSVRotateNextFile; - - info.copy_from_bind = ReadCSVBind; - info.copy_from_function = ReadCSVTableFunction::GetFunction(); - - info.extension = "csv"; - - set.AddFunction(info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/glob.cpp b/src/duckdb/src/function/table/glob.cpp deleted file mode 100644 index 736c2778e..000000000 --- a/src/duckdb/src/function/table/glob.cpp +++ /dev/null @@ -1,60 +0,0 @@ -#include "duckdb/function/table/range.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/common/file_system.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/common/multi_file_reader.hpp" - -namespace duckdb { - -struct GlobFunctionBindData : public TableFunctionData { - shared_ptr file_list; -}; - -static unique_ptr GlobFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - 
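// [Editor's note] A minimal usage sketch of the glob() table function registered at the bottom of this
// file; the pattern is hypothetical, and the single VARCHAR result column is the "file" column bound below:
//     SELECT file FROM glob('data/*.csv');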
auto result = make_uniq(); - auto multi_file_reader = MultiFileReader::Create(input.table_function); - result->file_list = multi_file_reader->CreateFileList(context, input.inputs[0], FileGlobOptions::ALLOW_EMPTY); - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("file"); - return std::move(result); -} - -struct GlobFunctionState : public GlobalTableFunctionState { - GlobFunctionState() { - } - - MultiFileListScanData file_list_scan; -}; - -static unique_ptr GlobFunctionInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto res = make_uniq(); - - bind_data.file_list->InitializeScan(res->file_list_scan); - - return std::move(res); -} - -static void GlobFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - idx_t count = 0; - while (count < STANDARD_VECTOR_SIZE) { - string file; - if (!bind_data.file_list->Scan(state.file_list_scan, file)) { - break; - } - output.data[0].SetValue(count++, file); - } - output.SetCardinality(count); -} - -void GlobTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction glob_function("glob", {LogicalType::VARCHAR}, GlobFunction, GlobFunctionBind, GlobFunctionInit); - set.AddFunction(MultiFileReader::CreateFunctionSet(glob_function)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/query_function.cpp b/src/duckdb/src/function/table/query_function.cpp deleted file mode 100644 index c44b1919c..000000000 --- a/src/duckdb/src/function/table/query_function.cpp +++ /dev/null @@ -1,85 +0,0 @@ -#include "duckdb/parser/parser.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/function/table/range.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/parser/tableref/subqueryref.hpp" - -namespace duckdb { - -static unique_ptr ParseSubquery(const string &query, const ParserOptions &options, const string &err_msg) { - Parser parser(options); - parser.ParseQuery(query); - if (parser.statements.size() != 1 || parser.statements[0]->type != StatementType::SELECT_STATEMENT) { - throw ParserException(err_msg); - } - auto select_stmt = unique_ptr_cast(std::move(parser.statements[0])); - return duckdb::make_uniq(std::move(select_stmt)); -} - -static string UnionTablesQuery(TableFunctionBindInput &input) { - for (auto &input_val : input.inputs) { - if (input_val.IsNull()) { - throw BinderException("Cannot use NULL as function argument"); - } - } - string result; - string by_name = (input.inputs.size() == 2 && - (input.inputs[1].type().id() == LogicalTypeId::BOOLEAN && input.inputs[1].GetValue())) - ? 
"BY NAME " - : ""; // 'by_name' variable defaults to false - if (input.inputs[0].type().id() == LogicalTypeId::VARCHAR) { - auto from_path = input.inputs[0].ToString(); - auto qualified_name = QualifiedName::Parse(from_path); - result += "FROM " + qualified_name.ToString(); - } else if (input.inputs[0].type() == LogicalType::LIST(LogicalType::VARCHAR)) { - string union_all_clause = " UNION ALL " + by_name + "FROM "; - const auto &children = ListValue::GetChildren(input.inputs[0]); - - if (children.empty()) { - throw InvalidInputException("Input list is empty"); - } - auto qualified_name = QualifiedName::Parse(children[0].ToString()); - result += "FROM " + qualified_name.ToString(); - for (size_t i = 1; i < children.size(); ++i) { - auto child = children[i].ToString(); - auto qualified_name = QualifiedName::Parse(child); - result += union_all_clause + qualified_name.ToString(); - } - } else { - throw InvalidInputException("Expected a table or a list with tables as input"); - } - return result; -} - -static unique_ptr QueryBindReplace(ClientContext &context, TableFunctionBindInput &input) { - auto query = input.inputs[0].ToString(); - auto subquery_ref = ParseSubquery(query, context.GetParserOptions(), "Expected a single SELECT statement"); - return std::move(subquery_ref); -} - -static unique_ptr TableBindReplace(ClientContext &context, TableFunctionBindInput &input) { - auto query = UnionTablesQuery(input); - auto subquery_ref = - ParseSubquery(query, context.GetParserOptions(), "Expected a table or a list with tables as input"); - return std::move(subquery_ref); -} - -void QueryTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction query("query", {LogicalType::VARCHAR}, nullptr, nullptr); - query.bind_replace = QueryBindReplace; - set.AddFunction(query); - - TableFunctionSet query_table("query_table"); - TableFunction query_table_function({LogicalType::VARCHAR}, nullptr, nullptr); - query_table_function.bind_replace = TableBindReplace; - query_table.AddFunction(query_table_function); - - query_table_function.arguments = {LogicalType::LIST(LogicalType::VARCHAR)}; - query_table.AddFunction(query_table_function); - // add by_name option - query_table_function.arguments.emplace_back(LogicalType::BOOLEAN); - query_table.AddFunction(query_table_function); - set.AddFunction(query_table); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/range.cpp b/src/duckdb/src/function/table/range.cpp deleted file mode 100644 index 17bcda81d..000000000 --- a/src/duckdb/src/function/table/range.cpp +++ /dev/null @@ -1,367 +0,0 @@ -#include "duckdb/function/table/range.hpp" -#include "duckdb/function/table/summary.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/types/timestamp.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// Range (integers) -//===--------------------------------------------------------------------===// -static void GetParameters(int64_t values[], idx_t value_count, hugeint_t &start, hugeint_t &end, hugeint_t &increment) { - if (value_count < 2) { - // single argument: only the end is specified - start = 0; - end = values[0]; - } else { - // two arguments: first two arguments are start and end - start = values[0]; - end = values[1]; - } - if (value_count < 3) { - increment = 1; - } else { - increment = values[2]; - } -} - -struct 
RangeFunctionBindData : public TableFunctionData { - explicit RangeFunctionBindData(const vector &inputs) : cardinality(0) { - int64_t values[3]; - for (idx_t i = 0; i < inputs.size(); i++) { - if (inputs[i].IsNull()) { - return; - } - values[i] = inputs[i].GetValue(); - } - hugeint_t start; - hugeint_t end; - hugeint_t increment; - GetParameters(values, inputs.size(), start, end, increment); - cardinality = Hugeint::Cast((end - start) / increment); - } - - idx_t cardinality; -}; - -template -static unique_ptr RangeFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::BIGINT); - if (GENERATE_SERIES) { - names.emplace_back("generate_series"); - } else { - names.emplace_back("range"); - } - if (input.inputs.empty() || input.inputs.size() > 3) { - return nullptr; - } - return make_uniq(input.inputs); -} - -struct RangeFunctionLocalState : public LocalTableFunctionState { - RangeFunctionLocalState() { - } - - bool initialized_row = false; - idx_t current_input_row = 0; - idx_t current_idx = 0; - - hugeint_t start; - hugeint_t end; - hugeint_t increment; -}; - -static unique_ptr RangeFunctionLocalInit(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state) { - return make_uniq(); -} - -template -static void GenerateRangeParameters(DataChunk &input, idx_t row_id, RangeFunctionLocalState &result) { - input.Flatten(); - for (idx_t c = 0; c < input.ColumnCount(); c++) { - if (FlatVector::IsNull(input.data[c], row_id)) { - result.start = GENERATE_SERIES ? 1 : 0; - result.end = 0; - result.increment = 1; - return; - } - } - int64_t values[3]; - for (idx_t c = 0; c < input.ColumnCount(); c++) { - if (c >= 3) { - throw InternalException("Unsupported parameter count for range function"); - } - values[c] = FlatVector::GetValue(input.data[c], row_id); - } - GetParameters(values, input.ColumnCount(), result.start, result.end, result.increment); - if (result.increment == 0) { - throw BinderException("interval cannot be 0!"); - } - if (result.start > result.end && result.increment > 0) { - throw BinderException("start is bigger than end, but increment is positive: cannot generate infinite series"); - } - if (result.start < result.end && result.increment < 0) { - throw BinderException("start is smaller than end, but increment is negative: cannot generate infinite series"); - } - if (GENERATE_SERIES) { - // generate_series has inclusive bounds on the RHS - if (result.increment < 0) { - result.end = result.end - 1; - } else { - result.end = result.end + 1; - } - } -} - -template -static OperatorResultType RangeFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - auto &state = data_p.local_state->Cast(); - while (true) { - if (!state.initialized_row) { - // initialize for the current input row - if (state.current_input_row >= input.size()) { - // ran out of rows - state.current_input_row = 0; - state.initialized_row = false; - return OperatorResultType::NEED_MORE_INPUT; - } - GenerateRangeParameters(input, state.current_input_row, state); - state.initialized_row = true; - state.current_idx = 0; - } - auto increment = state.increment; - auto end = state.end; - hugeint_t current_value = state.start + increment * UnsafeNumericCast(state.current_idx); - int64_t current_value_i64; - if (!Hugeint::TryCast(current_value, current_value_i64)) { - // move to next row - state.current_input_row++; - state.initialized_row = false; - 
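// [Editor's note] The end-bound adjustment made in GenerateRangeParameters above is what separates the
// two variants; an illustrative pair of queries:
//     SELECT * FROM range(0, 5);           -- 0, 1, 2, 3, 4    (exclusive end)
//     SELECT * FROM generate_series(0, 5); -- 0, 1, 2, 3, 4, 5 (inclusive end)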
continue; - } - int64_t offset = increment < 0 ? 1 : -1; - idx_t remaining = MinValue( - Hugeint::Cast((end - current_value + (increment + offset)) / increment), STANDARD_VECTOR_SIZE); - // set the result vector as a sequence vector - output.data[0].Sequence(current_value_i64, Hugeint::Cast(increment), remaining); - // increment the index pointer by the remaining count - state.current_idx += remaining; - output.SetCardinality(remaining); - if (remaining == 0) { - // move to next row - state.current_input_row++; - state.initialized_row = false; - continue; - } - return OperatorResultType::HAVE_MORE_OUTPUT; - } -} - -unique_ptr RangeCardinality(ClientContext &context, const FunctionData *bind_data_p) { - if (!bind_data_p) { - return nullptr; - } - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.cardinality, bind_data.cardinality); -} - -//===--------------------------------------------------------------------===// -// Range (timestamp) -//===--------------------------------------------------------------------===// -template -static unique_ptr RangeDateTimeBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.push_back(LogicalType::TIMESTAMP); - if (GENERATE_SERIES) { - names.emplace_back("generate_series"); - } else { - names.emplace_back("range"); - } - return nullptr; -} - -struct RangeDateTimeLocalState : public LocalTableFunctionState { - RangeDateTimeLocalState() { - } - - bool initialized_row = false; - idx_t current_input_row = 0; - timestamp_t current_state; - - timestamp_t start; - timestamp_t end; - interval_t increment; - bool inclusive_bound; - bool greater_than_check; - - bool Finished(timestamp_t current_value) const { - if (greater_than_check) { - if (inclusive_bound) { - return current_value > end; - } else { - return current_value >= end; - } - } else { - if (inclusive_bound) { - return current_value < end; - } else { - return current_value <= end; - } - } - } -}; - -template -static void GenerateRangeDateTimeParameters(DataChunk &input, idx_t row_id, RangeDateTimeLocalState &result) { - input.Flatten(); - - for (idx_t c = 0; c < input.ColumnCount(); c++) { - if (FlatVector::IsNull(input.data[c], row_id)) { - result.start = timestamp_t(0); - result.end = timestamp_t(0); - result.increment = interval_t(); - result.greater_than_check = true; - result.inclusive_bound = false; - return; - } - } - - result.start = FlatVector::GetValue(input.data[0], row_id); - result.end = FlatVector::GetValue(input.data[1], row_id); - result.increment = FlatVector::GetValue(input.data[2], row_id); - - // Infinities either cause errors or infinite loops, so just ban them - if (!Timestamp::IsFinite(result.start) || !Timestamp::IsFinite(result.end)) { - throw BinderException("RANGE with infinite bounds is not supported"); - } - - if (result.increment.months == 0 && result.increment.days == 0 && result.increment.micros == 0) { - throw BinderException("interval cannot be 0!"); - } - // all elements should point in the same direction - if (result.increment.months > 0 || result.increment.days > 0 || result.increment.micros > 0) { - if (result.increment.months < 0 || result.increment.days < 0 || result.increment.micros < 0) { - throw BinderException("RANGE with composite interval that has mixed signs is not supported"); - } - result.greater_than_check = true; - if (result.start > result.end) { - throw BinderException( - "start is bigger than end, but increment is positive: cannot generate infinite series"); - } - } else { 
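-		// every interval field is <= 0 here (mixed signs were rejected above), so we
-		// iterate downwards and the bound check in Finished() is flipped accordingly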
- result.greater_than_check = false; - if (result.start < result.end) { - throw BinderException( - "start is smaller than end, but increment is negative: cannot generate infinite series"); - } - } - result.inclusive_bound = GENERATE_SERIES; -} - -static unique_ptr RangeDateTimeLocalInit(ExecutionContext &context, - TableFunctionInitInput &input, - GlobalTableFunctionState *global_state) { - return make_uniq(); -} - -template -static OperatorResultType RangeDateTimeFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - auto &state = data_p.local_state->Cast(); - while (true) { - if (!state.initialized_row) { - // initialize for the current input row - if (state.current_input_row >= input.size()) { - // ran out of rows - state.current_input_row = 0; - state.initialized_row = false; - return OperatorResultType::NEED_MORE_INPUT; - } - GenerateRangeDateTimeParameters(input, state.current_input_row, state); - state.initialized_row = true; - state.current_state = state.start; - } - idx_t size = 0; - auto data = FlatVector::GetData(output.data[0]); - while (true) { - if (state.Finished(state.current_state)) { - break; - } - if (size >= STANDARD_VECTOR_SIZE) { - break; - } - data[size++] = state.current_state; - state.current_state = - AddOperator::Operation(state.current_state, state.increment); - } - if (size == 0) { - // move to next row - state.current_input_row++; - state.initialized_row = false; - continue; - } - output.SetCardinality(size); - return OperatorResultType::HAVE_MORE_OUTPUT; - } -} - -void RangeTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet range("range"); - - TableFunction range_function({LogicalType::BIGINT}, nullptr, RangeFunctionBind, nullptr, - RangeFunctionLocalInit); - range_function.in_out_function = RangeFunction; - range_function.cardinality = RangeCardinality; - - // single argument range: (end) - implicit start = 0 and increment = 1 - range.AddFunction(range_function); - // two arguments range: (start, end) - implicit increment = 1 - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; - range.AddFunction(range_function); - // three arguments range: (start, end, increment) - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; - range.AddFunction(range_function); - TableFunction range_in_out({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, nullptr, - RangeDateTimeBind, nullptr, RangeDateTimeLocalInit); - range_in_out.in_out_function = RangeDateTimeFunction; - range.AddFunction(range_in_out); - set.AddFunction(range); - // generate_series: similar to range, but inclusive instead of exclusive bounds on the RHS - TableFunctionSet generate_series("generate_series"); - range_function.bind = RangeFunctionBind; - range_function.in_out_function = RangeFunction; - range_function.arguments = {LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - range_function.arguments = {LogicalType::BIGINT, LogicalType::BIGINT, LogicalType::BIGINT}; - generate_series.AddFunction(range_function); - TableFunction generate_series_in_out({LogicalType::TIMESTAMP, LogicalType::TIMESTAMP, LogicalType::INTERVAL}, - nullptr, RangeDateTimeBind, nullptr, RangeDateTimeLocalInit); - generate_series_in_out.in_out_function = RangeDateTimeFunction; - generate_series.AddFunction(generate_series_in_out); - 
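-	// A minimal usage sketch of the two function sets registered here (output assumes
-	// the implicit start of 0 and increment of 1 of the one-argument form):
-	//   SELECT * FROM range(3);            -- 0, 1, 2    (exclusive upper bound)
-	//   SELECT * FROM generate_series(3);  -- 0, 1, 2, 3 (inclusive upper bound)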
set.AddFunction(generate_series); -} - -void BuiltinFunctions::RegisterTableFunctions() { - CheckpointFunction::RegisterFunction(*this); - GlobTableFunction::RegisterFunction(*this); - RangeTableFunction::RegisterFunction(*this); - RepeatTableFunction::RegisterFunction(*this); - SummaryTableFunction::RegisterFunction(*this); - UnnestTableFunction::RegisterFunction(*this); - RepeatRowTableFunction::RegisterFunction(*this); - CSVSnifferFunction::RegisterFunction(*this); - ReadBlobFunction::RegisterFunction(*this); - ReadTextFunction::RegisterFunction(*this); - QueryTableFunction::RegisterFunction(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/read_csv.cpp b/src/duckdb/src/function/table/read_csv.cpp deleted file mode 100644 index f01934e7b..000000000 --- a/src/duckdb/src/function/table/read_csv.cpp +++ /dev/null @@ -1,465 +0,0 @@ -#include "duckdb/function/table/read_csv.hpp" - -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/union_by_name.hpp" -#include "duckdb/execution/operator/csv_scanner/global_csv_state.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_error.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/persistent/csv_rejects_table.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_file_scanner.hpp" -#include "duckdb/execution/operator/csv_scanner/base_scanner.hpp" - -#include "duckdb/execution/operator/csv_scanner/string_value_scanner.hpp" - -#include -#include "duckdb/execution/operator/csv_scanner/csv_schema.hpp" - -namespace duckdb { - -unique_ptr ReadCSV::OpenCSV(const string &file_path, const CSVReaderOptions &options, - ClientContext &context) { - auto &fs = FileSystem::GetFileSystem(context); - auto &allocator = BufferAllocator::Get(context); - auto &db_config = DBConfig::GetConfig(context); - return CSVFileHandle::OpenFile(db_config, fs, allocator, file_path, options); -} - -ReadCSVData::ReadCSVData() { -} - -void ReadCSVData::FinalizeRead(ClientContext &context) { - BaseCSVData::Finalize(); -} - -//! Function to do schema discovery over one CSV file or a list/glob of CSV files -void SchemaDiscovery(ClientContext &context, ReadCSVData &result, CSVReaderOptions &options, - vector &return_types, vector &names, MultiFileList &multi_file_list) { - vector schemas; - const auto option_og = options; - - const auto file_paths = multi_file_list.GetAllFiles(); - - // Here what we want to do is to sniff a given number of lines, if we have many files, we might go through them - // to reach the number of lines. 
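-	// i.e. we aim to sample sample_size_chunks chunks of sniff_size lines each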
-	const idx_t required_number_of_lines = options.sniff_size * options.sample_size_chunks;
-
-	idx_t total_number_of_rows = 0;
-	idx_t current_file = 0;
-	options.file_path = file_paths[current_file];
-
-	result.buffer_manager = make_shared_ptr<CSVBufferManager>(context, options, options.file_path, 0, false);
-	{
-		CSVSniffer sniffer(options, result.buffer_manager, CSVStateMachineCache::Get(context));
-		auto sniffer_result = sniffer.SniffCSV();
-		idx_t rows_read = sniffer.LinesSniffed() -
-		                  (options.dialect_options.skip_rows.GetValue() + options.dialect_options.header.GetValue());
-
-		schemas.emplace_back(sniffer_result.names, sniffer_result.return_types, file_paths[0], rows_read,
-		                     result.buffer_manager->GetBuffer(0)->actual_size == 0);
-		total_number_of_rows += sniffer.LinesSniffed();
-	}
-
-	// For every subsequent file we sniff on a copy of the options, so that we do not pollute the options of the
-	// first file.
-	constexpr idx_t max_files_to_sniff = 10;
-	idx_t files_to_sniff = file_paths.size() > max_files_to_sniff ? max_files_to_sniff : file_paths.size();
-	while (total_number_of_rows < required_number_of_lines && current_file + 1 < files_to_sniff) {
-		auto option_copy = option_og;
-		current_file++;
-		option_copy.file_path = file_paths[current_file];
-		auto buffer_manager =
-		    make_shared_ptr<CSVBufferManager>(context, option_copy, option_copy.file_path, current_file, false);
-		// TODO: We could cache the sniffer to be reused during scanning. Currently that's an exercise left to the
-		// reader
-		CSVSniffer sniffer(option_copy, buffer_manager, CSVStateMachineCache::Get(context));
-		auto sniffer_result = sniffer.SniffCSV();
-		idx_t rows_read = sniffer.LinesSniffed() - (option_copy.dialect_options.skip_rows.GetValue() +
-		                                            option_copy.dialect_options.header.GetValue());
-		if (buffer_manager->GetBuffer(0)->actual_size == 0) {
-			schemas.emplace_back(true);
-		} else {
-			schemas.emplace_back(sniffer_result.names, sniffer_result.return_types, option_copy.file_path, rows_read);
-		}
-		total_number_of_rows += sniffer.LinesSniffed();
-	}
-
-	// We might now have multiple schemas; we need to go through them to define the one true schema
-	CSVSchema best_schema;
-	for (auto &schema : schemas) {
-		if (best_schema.Empty()) {
-			// A schema is better than no schema
-			best_schema = schema;
-		} else if (best_schema.GetRowsRead() == 0) {
-			// If the best schema has no data rows, that's easy: we just take the new schema
-			best_schema = schema;
-		} else if (schema.GetRowsRead() != 0) {
-			// We might have conflicting schemas, so we must merge them
-			best_schema.MergeSchemas(schema, options.null_padding);
-		}
-	}
-
-	if (names.empty()) {
-		names = best_schema.GetNames();
-		return_types = best_schema.GetTypes();
-	}
-	result.csv_types = return_types;
-	result.csv_names = names;
-}
-
-static unique_ptr<FunctionData> ReadCSVBind(ClientContext &context, TableFunctionBindInput &input,
-                                            vector<LogicalType> &return_types, vector<string> &names) {
-
-	auto result = make_uniq<ReadCSVData>();
-	auto &options = result->options;
-	const auto multi_file_reader = MultiFileReader::Create(input.table_function);
-	const auto multi_file_list = multi_file_reader->CreateFileList(context, input.inputs[0]);
-	if (multi_file_list->GetTotalFileCount() > 1) {
-		options.multi_file_reader = true;
-	}
-	options.FromNamedParameters(input.named_parameters, context);
-
-	options.file_options.AutoDetectHivePartitioning(*multi_file_list, context);
-	options.Verify();
-	if (!options.file_options.union_by_name) {
-		if (options.auto_detect) {
-			SchemaDiscovery(context, *result, options, return_types, names, *multi_file_list);
-		} else {
-			// If we are not running the 
sniffer, the columns must be set! - if (!options.columns_set) { - throw BinderException("read_csv requires columns to be specified through the 'columns' option. Use " - "read_csv_auto or set read_csv(..., " - "AUTO_DETECT=TRUE) to automatically guess columns."); - } - names = options.name_list; - return_types = options.sql_type_list; - } - D_ASSERT(return_types.size() == names.size()); - result->options.dialect_options.num_cols = names.size(); - - multi_file_reader->BindOptions(options.file_options, *multi_file_list, return_types, names, - result->reader_bind); - } else { - result->reader_bind = multi_file_reader->BindUnionReader(context, return_types, names, - *multi_file_list, *result, options); - if (result->union_readers.size() > 1) { - for (idx_t i = 0; i < result->union_readers.size(); i++) { - result->column_info.emplace_back(result->union_readers[i]->names, result->union_readers[i]->types); - } - } - if (!options.sql_types_per_column.empty()) { - const auto exception = CSVError::ColumnTypesError(options.sql_types_per_column, names); - if (!exception.error_message.empty()) { - throw BinderException(exception.error_message); - } - for (idx_t i = 0; i < names.size(); i++) { - auto it = options.sql_types_per_column.find(names[i]); - if (it != options.sql_types_per_column.end()) { - return_types[i] = options.sql_type_list[it->second]; - } - } - } - } - - result->csv_types = return_types; - result->csv_names = names; - result->return_types = return_types; - result->return_names = names; - if (!options.force_not_null_names.empty()) { - // Let's first check all column names match - duckdb::unordered_set column_names; - for (auto &name : names) { - column_names.insert(name); - } - for (auto &force_name : options.force_not_null_names) { - if (column_names.find(force_name) == column_names.end()) { - throw BinderException("\"force_not_null\" expected to find %s, but it was not found in the table", - force_name); - } - } - D_ASSERT(options.force_not_null.empty()); - for (idx_t i = 0; i < names.size(); i++) { - if (options.force_not_null_names.find(names[i]) != options.force_not_null_names.end()) { - options.force_not_null.push_back(true); - } else { - options.force_not_null.push_back(false); - } - } - } - - // TODO: make the CSV reader use MultiFileList throughout, instead of converting to vector - result->files = multi_file_list->GetAllFiles(); - result->Finalize(); - return std::move(result); -} - -//===--------------------------------------------------------------------===// -// Read CSV Local State -//===--------------------------------------------------------------------===// -struct CSVLocalState : public LocalTableFunctionState { -public: - explicit CSVLocalState(unique_ptr csv_reader_p) : csv_reader(std::move(csv_reader_p)) { - } - - //! 
The CSV reader - unique_ptr csv_reader; - bool done = false; -}; - -//===--------------------------------------------------------------------===// -// Read CSV Functions -//===--------------------------------------------------------------------===// -static unique_ptr ReadCSVInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - - // Create the temporary rejects table - if (bind_data.options.store_rejects.GetValue()) { - CSVRejectsTable::GetOrCreate(context, bind_data.options.rejects_scan_name.GetValue(), - bind_data.options.rejects_table_name.GetValue()) - ->InitializeTable(context, bind_data); - } - if (bind_data.files.empty()) { - // This can happen when a filename based filter pushdown has eliminated all possible files for this scan. - return nullptr; - } - return make_uniq(context, bind_data.buffer_manager, bind_data.options, - context.db->NumberOfThreads(), bind_data.files, input.column_indexes, bind_data); -} - -unique_ptr ReadCSVInitLocal(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state_p) { - if (!global_state_p) { - return nullptr; - } - auto &global_state = global_state_p->Cast(); - if (global_state.IsDone()) { - // nothing to do - return nullptr; - } - auto csv_scanner = global_state.Next(nullptr); - if (!csv_scanner) { - global_state.DecrementThread(); - } - return make_uniq(std::move(csv_scanner)); -} - -static void ReadCSVFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - if (!data_p.global_state) { - return; - } - auto &csv_global_state = data_p.global_state->Cast(); - if (!data_p.local_state) { - return; - } - auto &csv_local_state = data_p.local_state->Cast(); - - if (!csv_local_state.csv_reader) { - // no csv_reader was set, this can happen when a filename-based filter has filtered out all possible files - return; - } - do { - if (output.size() != 0) { - MultiFileReader().FinalizeChunk(context, bind_data.reader_bind, - csv_local_state.csv_reader->csv_file_scan->reader_data, output, nullptr); - break; - } - if (csv_local_state.csv_reader->FinishedIterator()) { - csv_local_state.csv_reader = csv_global_state.Next(csv_local_state.csv_reader.get()); - if (!csv_local_state.csv_reader) { - csv_global_state.DecrementThread(); - break; - } - } - csv_local_state.csv_reader->Flush(output); - - } while (true); -} - -static OperatorPartitionData CSVReaderGetPartitionData(ClientContext &context, TableFunctionGetPartitionInput &input) { - if (input.partition_info.RequiresPartitionColumns()) { - throw InternalException("CSVReader::GetPartitionData: partition columns not supported"); - } - auto &data = input.local_state->Cast(); - return OperatorPartitionData(data.csv_reader->scanner_idx); -} - -void ReadCSVTableFunction::ReadCSVAddNamedParameters(TableFunction &table_function) { - table_function.named_parameters["sep"] = LogicalType::VARCHAR; - table_function.named_parameters["delim"] = LogicalType::VARCHAR; - table_function.named_parameters["quote"] = LogicalType::VARCHAR; - table_function.named_parameters["new_line"] = LogicalType::VARCHAR; - table_function.named_parameters["escape"] = LogicalType::VARCHAR; - table_function.named_parameters["nullstr"] = LogicalType::ANY; - table_function.named_parameters["columns"] = LogicalType::ANY; - table_function.named_parameters["auto_type_candidates"] = LogicalType::ANY; - table_function.named_parameters["header"] = LogicalType::BOOLEAN; - 
table_function.named_parameters["auto_detect"] = LogicalType::BOOLEAN; - table_function.named_parameters["sample_size"] = LogicalType::BIGINT; - table_function.named_parameters["all_varchar"] = LogicalType::BOOLEAN; - table_function.named_parameters["dateformat"] = LogicalType::VARCHAR; - table_function.named_parameters["timestampformat"] = LogicalType::VARCHAR; - table_function.named_parameters["normalize_names"] = LogicalType::BOOLEAN; - table_function.named_parameters["compression"] = LogicalType::VARCHAR; - table_function.named_parameters["skip"] = LogicalType::BIGINT; - table_function.named_parameters["max_line_size"] = LogicalType::VARCHAR; - table_function.named_parameters["maximum_line_size"] = LogicalType::VARCHAR; - table_function.named_parameters["ignore_errors"] = LogicalType::BOOLEAN; - table_function.named_parameters["store_rejects"] = LogicalType::BOOLEAN; - table_function.named_parameters["rejects_table"] = LogicalType::VARCHAR; - table_function.named_parameters["rejects_scan"] = LogicalType::VARCHAR; - table_function.named_parameters["rejects_limit"] = LogicalType::BIGINT; - table_function.named_parameters["force_not_null"] = LogicalType::LIST(LogicalType::VARCHAR); - table_function.named_parameters["buffer_size"] = LogicalType::UBIGINT; - table_function.named_parameters["decimal_separator"] = LogicalType::VARCHAR; - table_function.named_parameters["parallel"] = LogicalType::BOOLEAN; - table_function.named_parameters["null_padding"] = LogicalType::BOOLEAN; - table_function.named_parameters["allow_quoted_nulls"] = LogicalType::BOOLEAN; - table_function.named_parameters["column_types"] = LogicalType::ANY; - table_function.named_parameters["dtypes"] = LogicalType::ANY; - table_function.named_parameters["types"] = LogicalType::ANY; - table_function.named_parameters["names"] = LogicalType::LIST(LogicalType::VARCHAR); - table_function.named_parameters["column_names"] = LogicalType::LIST(LogicalType::VARCHAR); - table_function.named_parameters["comment"] = LogicalType::VARCHAR; - table_function.named_parameters["encoding"] = LogicalType::VARCHAR; - table_function.named_parameters["rfc_4180"] = LogicalType::BOOLEAN; - - MultiFileReader::AddParameters(table_function); -} - -double CSVReaderProgress(ClientContext &context, const FunctionData *bind_data_p, - const GlobalTableFunctionState *global_state) { - if (!global_state) { - return 0; - } - auto &bind_data = bind_data_p->Cast(); - auto &data = global_state->Cast(); - return data.GetProgress(bind_data); -} - -void CSVComplexFilterPushdown(ClientContext &context, LogicalGet &get, FunctionData *bind_data_p, - vector> &filters) { - auto &data = bind_data_p->Cast(); - SimpleMultiFileList file_list(data.files); - MultiFilePushdownInfo info(get); - auto filtered_list = - MultiFileReader().ComplexFilterPushdown(context, file_list, data.options.file_options, info, filters); - if (filtered_list) { - data.files = filtered_list->GetAllFiles(); - MultiFileReader::PruneReaders(data, file_list); - } else { - data.files = file_list.GetAllFiles(); - } -} - -unique_ptr CSVReaderCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - // determined through the scientific method as the average amount of rows in a CSV file - idx_t per_file_cardinality = 42; - if (bind_data.buffer_manager && bind_data.buffer_manager->file_handle) { - auto estimated_row_width = (bind_data.csv_types.size() * 5); - per_file_cardinality = bind_data.buffer_manager->file_handle->FileSize() / estimated_row_width; - } 
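-	// extrapolate from the first file and assume every file in the list is roughly the same size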
- return make_uniq(bind_data.files.size() * per_file_cardinality); -} - -static void CSVReaderSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const TableFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "extra_info", function.extra_info); - serializer.WriteProperty(101, "csv_data", &bind_data); -} - -static unique_ptr CSVReaderDeserialize(Deserializer &deserializer, TableFunction &function) { - unique_ptr result; - deserializer.ReadProperty(100, "extra_info", function.extra_info); - deserializer.ReadProperty(101, "csv_data", result); - return std::move(result); -} - -void PushdownTypeToCSVScanner(ClientContext &context, optional_ptr bind_data, - const unordered_map &new_column_types) { - auto &csv_bind = bind_data->Cast(); - for (auto &type : new_column_types) { - csv_bind.csv_types[type.first] = type.second; - csv_bind.return_types[type.first] = type.second; - } -} - -TableFunction ReadCSVTableFunction::GetFunction() { - TableFunction read_csv("read_csv", {LogicalType::VARCHAR}, ReadCSVFunction, ReadCSVBind, ReadCSVInitGlobal, - ReadCSVInitLocal); - read_csv.table_scan_progress = CSVReaderProgress; - read_csv.pushdown_complex_filter = CSVComplexFilterPushdown; - read_csv.serialize = CSVReaderSerialize; - read_csv.deserialize = CSVReaderDeserialize; - read_csv.get_partition_data = CSVReaderGetPartitionData; - read_csv.cardinality = CSVReaderCardinality; - read_csv.projection_pushdown = true; - read_csv.type_pushdown = PushdownTypeToCSVScanner; - ReadCSVAddNamedParameters(read_csv); - return read_csv; -} - -TableFunction ReadCSVTableFunction::GetAutoFunction() { - auto read_csv_auto = ReadCSVTableFunction::GetFunction(); - read_csv_auto.name = "read_csv_auto"; - read_csv_auto.bind = ReadCSVBind; - return read_csv_auto; -} - -void ReadCSVTableFunction::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetFunction())); - set.AddFunction(MultiFileReader::CreateFunctionSet(ReadCSVTableFunction::GetAutoFunction())); -} - -unique_ptr ReadCSVReplacement(ClientContext &context, ReplacementScanInput &input, - optional_ptr data) { - auto table_name = ReplacementScan::GetFullPath(input); - auto lower_name = StringUtil::Lower(table_name); - // remove any compression - if (StringUtil::EndsWith(lower_name, CompressionExtensionFromType(FileCompressionType::GZIP))) { - lower_name = lower_name.substr(0, lower_name.size() - 3); - } else if (StringUtil::EndsWith(lower_name, CompressionExtensionFromType(FileCompressionType::ZSTD))) { - if (!Catalog::TryAutoLoad(context, "parquet")) { - throw MissingExtensionException("parquet extension is required for reading zst compressed file"); - } - lower_name = lower_name.substr(0, lower_name.size() - 4); - } - if (!StringUtil::EndsWith(lower_name, ".csv") && !StringUtil::Contains(lower_name, ".csv?") && - !StringUtil::EndsWith(lower_name, ".tsv") && !StringUtil::Contains(lower_name, ".tsv?")) { - return nullptr; - } - auto table_function = make_uniq(); - vector> children; - children.push_back(make_uniq(Value(table_name))); - table_function->function = make_uniq("read_csv_auto", std::move(children)); - - if (!FileSystem::HasGlob(table_name)) { - auto &fs = FileSystem::GetFileSystem(context); - table_function->alias = fs.ExtractBaseName(table_name); - } - - return std::move(table_function); -} - -void BuiltinFunctions::RegisterReadFunctions() { - CSVCopyFunction::RegisterFunction(*this); - ReadCSVTableFunction::RegisterFunction(*this); - auto 
&config = DBConfig::GetConfig(*transaction.db); - config.replacement_scans.emplace_back(ReadCSVReplacement); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/read_file.cpp b/src/duckdb/src/function/table/read_file.cpp deleted file mode 100644 index 158e89fd0..000000000 --- a/src/duckdb/src/function/table/read_file.cpp +++ /dev/null @@ -1,266 +0,0 @@ -#include "duckdb/common/multi_file_reader.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/function/table/range.hpp" -#include "utf8proc_wrapper.hpp" - -namespace duckdb { - -struct ReadBlobOperation { - static constexpr const char *NAME = "read_blob"; - static constexpr const char *FILE_TYPE = "blob"; - - static inline LogicalType TYPE() { - return LogicalType::BLOB; - } - - static inline void VERIFY(const string &, const string_t &) { - } -}; - -struct ReadTextOperation { - static constexpr const char *NAME = "read_text"; - static constexpr const char *FILE_TYPE = "text"; - - static inline LogicalType TYPE() { - return LogicalType::VARCHAR; - } - - static inline void VERIFY(const string &filename, const string_t &content) { - if (Utf8Proc::Analyze(content.GetData(), content.GetSize()) == UnicodeType::INVALID) { - throw InvalidInputException( - "read_text: could not read content of file '%s' as valid UTF-8 encoded text. You " - "may want to use read_blob instead.", - filename); - } - } -}; - -//------------------------------------------------------------------------------ -// Bind -//------------------------------------------------------------------------------ -struct ReadFileBindData : public TableFunctionData { - vector files; - - static constexpr const idx_t FILE_NAME_COLUMN = 0; - static constexpr const idx_t FILE_CONTENT_COLUMN = 1; - static constexpr const idx_t FILE_SIZE_COLUMN = 2; - static constexpr const idx_t FILE_LAST_MODIFIED_COLUMN = 3; -}; - -template -static unique_ptr ReadFileBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - - auto multi_file_reader = MultiFileReader::Create(input.table_function); - result->files = - multi_file_reader->CreateFileList(context, input.inputs[0], FileGlobOptions::ALLOW_EMPTY)->GetAllFiles(); - - return_types.push_back(LogicalType::VARCHAR); - names.push_back("filename"); - return_types.push_back(OP::TYPE()); - names.push_back("content"); - return_types.push_back(LogicalType::BIGINT); - names.push_back("size"); - return_types.push_back(LogicalType::TIMESTAMP_TZ); - names.push_back("last_modified"); - - return std::move(result); -} - -//------------------------------------------------------------------------------ -// Global state -//------------------------------------------------------------------------------ -struct ReadFileGlobalState : public GlobalTableFunctionState { - ReadFileGlobalState() : current_file_idx(0) { - } - - atomic current_file_idx; - vector files; - vector column_ids; - bool requires_file_open = false; -}; - -static unique_ptr ReadFileInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - - result->files = bind_data.files; - result->current_file_idx = 0; - result->column_ids = input.column_ids; - - for (const auto &column_id : input.column_ids) { - // For everything except the 'file' name column, we need to open the file - if (column_id != ReadFileBindData::FILE_NAME_COLUMN && 
column_id != COLUMN_IDENTIFIER_ROW_ID) { - result->requires_file_open = true; - break; - } - } - - return std::move(result); -} - -//------------------------------------------------------------------------------ -// Execute -//------------------------------------------------------------------------------ -static void AssertMaxFileSize(const string &file_name, idx_t file_size) { - const auto max_file_size = NumericLimits::Maximum(); - if (file_size > max_file_size) { - auto max_byte_size_format = StringUtil::BytesToHumanReadableString(max_file_size); - auto file_byte_size_format = StringUtil::BytesToHumanReadableString(file_size); - auto error_msg = StringUtil::Format("File '%s' size (%s) exceeds maximum allowed file (%s)", file_name.c_str(), - file_byte_size_format, max_byte_size_format); - throw InvalidInputException(error_msg); - } -} - -template -static void ReadFileExecute(ClientContext &context, TableFunctionInput &input, DataChunk &output) { - auto &bind_data = input.bind_data->Cast(); - auto &state = input.global_state->Cast(); - auto &fs = FileSystem::GetFileSystem(context); - - auto output_count = MinValue(STANDARD_VECTOR_SIZE, bind_data.files.size() - state.current_file_idx); - - // We utilize projection pushdown here to only read the file content if the 'data' column is requested - for (idx_t out_idx = 0; out_idx < output_count; out_idx++) { - // Add the file name to the output - auto &file_name = bind_data.files[state.current_file_idx + out_idx]; - - unique_ptr file_handle = nullptr; - - // Given the columns requested, do we even need to open the file? - if (state.requires_file_open) { - file_handle = fs.OpenFile(file_name, FileFlags::FILE_FLAGS_READ); - } - - for (idx_t col_idx = 0; col_idx < state.column_ids.size(); col_idx++) { - // We utilize projection pushdown to avoid potentially expensive fs operations. - auto proj_idx = state.column_ids[col_idx]; - if (proj_idx == COLUMN_IDENTIFIER_ROW_ID) { - continue; - } - try { - switch (proj_idx) { - case ReadFileBindData::FILE_NAME_COLUMN: { - auto &file_name_vector = output.data[col_idx]; - auto file_name_string = StringVector::AddString(file_name_vector, file_name); - FlatVector::GetData(file_name_vector)[out_idx] = file_name_string; - } break; - case ReadFileBindData::FILE_CONTENT_COLUMN: { - auto file_size_raw = file_handle->GetFileSize(); - AssertMaxFileSize(file_name, file_size_raw); - auto file_size = UnsafeNumericCast(file_size_raw); - auto &file_content_vector = output.data[col_idx]; - auto content_string = StringVector::EmptyString(file_content_vector, file_size_raw); - - auto remaining_bytes = UnsafeNumericCast(file_size); - - // Read in batches of 100mb - constexpr auto MAX_READ_SIZE = 100LL * 1024 * 1024; - while (remaining_bytes > 0) { - const auto bytes_to_read = MinValue(remaining_bytes, MAX_READ_SIZE); - const auto content_string_ptr = - content_string.GetDataWriteable() + (file_size - remaining_bytes); - const auto actually_read = - file_handle->Read(content_string_ptr, UnsafeNumericCast(bytes_to_read)); - if (actually_read == 0) { - // Uh oh, random EOF? 
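-						// (e.g. the file shrank while we were reading it, or the
-						// filesystem reported a larger size than it can actually serve)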
- throw IOException("Failed to read file '%s' at offset %lu, unexpected EOF", file_name, - file_size - remaining_bytes); - } - remaining_bytes -= actually_read; - } - - content_string.Finalize(); - - OP::VERIFY(file_name, content_string); - - FlatVector::GetData(file_content_vector)[out_idx] = content_string; - } break; - case ReadFileBindData::FILE_SIZE_COLUMN: { - auto &file_size_vector = output.data[col_idx]; - FlatVector::GetData(file_size_vector)[out_idx] = - NumericCast(file_handle->GetFileSize()); - } break; - case ReadFileBindData::FILE_LAST_MODIFIED_COLUMN: { - auto &last_modified_vector = output.data[col_idx]; - // This can sometimes fail (e.g. httpfs file system cant always parse the last modified time - // correctly) - try { - auto timestamp_seconds = Timestamp::FromEpochSeconds(fs.GetLastModifiedTime(*file_handle)); - FlatVector::GetData(last_modified_vector)[out_idx] = - timestamp_tz_t(timestamp_seconds); - } catch (std::exception &ex) { - ErrorData error(ex); - if (error.Type() == ExceptionType::CONVERSION) { - FlatVector::SetNull(last_modified_vector, out_idx, true); - } else { - throw; - } - } - } break; - default: - throw InternalException("Unsupported column index for read_file"); - } - } - // Filesystems are not required to support all operations, so we just set the column to NULL if not - // implemented - catch (std::exception &ex) { - ErrorData error(ex); - if (error.Type() == ExceptionType::NOT_IMPLEMENTED) { - FlatVector::SetNull(output.data[col_idx], out_idx, true); - } else { - throw; - } - } - } - } - - state.current_file_idx += output_count; - output.SetCardinality(output_count); -} - -//------------------------------------------------------------------------------ -// Misc -//------------------------------------------------------------------------------ - -static double ReadFileProgress(ClientContext &context, const FunctionData *bind_data, - const GlobalTableFunctionState *gstate) { - auto &state = gstate->Cast(); - return static_cast(state.current_file_idx) / static_cast(state.files.size()); -} - -static unique_ptr ReadFileCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - auto result = make_uniq(); - result->has_max_cardinality = true; - result->max_cardinality = bind_data.files.size(); - result->has_estimated_cardinality = true; - result->estimated_cardinality = bind_data.files.size(); - return result; -} - -//------------------------------------------------------------------------------ -// Register -//------------------------------------------------------------------------------ -template -static TableFunction GetFunction() { - TableFunction func(OP::NAME, {LogicalType::VARCHAR}, ReadFileExecute, ReadFileBind, ReadFileInitGlobal); - func.table_scan_progress = ReadFileProgress; - func.cardinality = ReadFileCardinality; - func.projection_pushdown = true; - return func; -} - -void ReadBlobFunction::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(MultiFileReader::CreateFunctionSet(GetFunction())); -} - -void ReadTextFunction::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(MultiFileReader::CreateFunctionSet(GetFunction())); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/repeat.cpp b/src/duckdb/src/function/table/repeat.cpp deleted file mode 100644 index 08fdc0086..000000000 --- a/src/duckdb/src/function/table/repeat.cpp +++ /dev/null @@ -1,61 +0,0 @@ -#include "duckdb/function/table/range.hpp" -#include "duckdb/common/algorithm.hpp" - -namespace 
duckdb {
-
-struct RepeatFunctionData : public TableFunctionData {
-	RepeatFunctionData(Value value, idx_t target_count) : value(std::move(value)), target_count(target_count) {
-	}
-
-	Value value;
-	idx_t target_count;
-};
-
-struct RepeatOperatorData : public GlobalTableFunctionState {
-	RepeatOperatorData() : current_count(0) {
-	}
-	idx_t current_count;
-};
-
-static unique_ptr<FunctionData> RepeatBind(ClientContext &context, TableFunctionBindInput &input,
-                                           vector<LogicalType> &return_types, vector<string> &names) {
-	// the repeat function returns the type of the first argument
-	auto &inputs = input.inputs;
-	return_types.push_back(inputs[0].type());
-	names.push_back(inputs[0].ToString());
-	if (inputs[1].IsNull()) {
-		throw BinderException("Repeat second parameter cannot be NULL");
-	}
-	auto repeat_count = inputs[1].GetValue<int64_t>();
-	if (repeat_count < 0) {
-		throw BinderException("Repeat second parameter cannot be less than 0");
-	}
-	return make_uniq<RepeatFunctionData>(inputs[0], NumericCast<idx_t>(repeat_count));
-}
-
-static unique_ptr<GlobalTableFunctionState> RepeatInit(ClientContext &context, TableFunctionInitInput &input) {
-	return make_uniq<RepeatOperatorData>();
-}
-
-static void RepeatFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &bind_data = data_p.bind_data->Cast<RepeatFunctionData>();
-	auto &state = data_p.global_state->Cast<RepeatOperatorData>();
-
-	idx_t remaining = MinValue<idx_t>(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE);
-	output.data[0].Reference(bind_data.value);
-	output.SetCardinality(remaining);
-	state.current_count += remaining;
-}
-
-static unique_ptr<NodeStatistics> RepeatCardinality(ClientContext &context, const FunctionData *bind_data_p) {
-	auto &bind_data = bind_data_p->Cast<RepeatFunctionData>();
-	return make_uniq<NodeStatistics>(bind_data.target_count, bind_data.target_count);
-}
-
-void RepeatTableFunction::RegisterFunction(BuiltinFunctions &set) {
-	TableFunction repeat("repeat", {LogicalType::ANY, LogicalType::BIGINT}, RepeatFunction, RepeatBind, RepeatInit);
-	repeat.cardinality = RepeatCardinality;
-	set.AddFunction(repeat);
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/table/repeat_row.cpp b/src/duckdb/src/function/table/repeat_row.cpp
deleted file mode 100644
index f68bf7cae..000000000
--- a/src/duckdb/src/function/table/repeat_row.cpp
+++ /dev/null
@@ -1,67 +0,0 @@
-#include "duckdb/function/table/range.hpp"
-#include "duckdb/common/algorithm.hpp"
-
-namespace duckdb {
-
-struct RepeatRowFunctionData : public TableFunctionData {
-	RepeatRowFunctionData(vector<Value> values, idx_t target_count)
-	    : values(std::move(values)), target_count(target_count) {
-	}
-
-	const vector<Value> values;
-	idx_t target_count;
-};
-
-struct RepeatRowOperatorData : public GlobalTableFunctionState {
-	RepeatRowOperatorData() : current_count(0) {
-	}
-	idx_t current_count;
-};
-
-static unique_ptr<FunctionData> RepeatRowBind(ClientContext &context, TableFunctionBindInput &input,
-                                              vector<LogicalType> &return_types, vector<string> &names) {
-	auto &inputs = input.inputs;
-	for (idx_t input_idx = 0; input_idx < inputs.size(); input_idx++) {
-		return_types.push_back(inputs[input_idx].type());
-		names.push_back("column" + std::to_string(input_idx));
-	}
-	auto entry = input.named_parameters.find("num_rows");
-	if (entry == input.named_parameters.end()) {
-		throw BinderException("repeat_row requires num_rows to be specified");
-	}
-	if (inputs.empty()) {
-		throw BinderException("repeat_row requires at least one column to be specified");
-	}
-	return make_uniq<RepeatRowFunctionData>(inputs, NumericCast<idx_t>(entry->second.GetValue<int64_t>()));
-}
-
-static unique_ptr<GlobalTableFunctionState> RepeatRowInit(ClientContext &context, TableFunctionInitInput &input) {
-	return make_uniq<RepeatRowOperatorData>();
-} - -static void RepeatRowFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - - idx_t remaining = MinValue(bind_data.target_count - state.current_count, STANDARD_VECTOR_SIZE); - for (idx_t val_idx = 0; val_idx < bind_data.values.size(); val_idx++) { - output.data[val_idx].Reference(bind_data.values[val_idx]); - } - output.SetCardinality(remaining); - state.current_count += remaining; -} - -static unique_ptr RepeatRowCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return make_uniq(bind_data.target_count, bind_data.target_count); -} - -void RepeatRowTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction repeat_row("repeat_row", {}, RepeatRowFunction, RepeatRowBind, RepeatRowInit); - repeat_row.varargs = LogicalType::ANY; - repeat_row.named_parameters["num_rows"] = LogicalType::BIGINT; - repeat_row.cardinality = RepeatRowCardinality; - set.AddFunction(repeat_row); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/sniff_csv.cpp b/src/duckdb/src/function/table/sniff_csv.cpp deleted file mode 100644 index 5c72aabb9..000000000 --- a/src/duckdb/src/function/table/sniff_csv.cpp +++ /dev/null @@ -1,319 +0,0 @@ -#include "duckdb/function/built_in_functions.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_reader_options.hpp" -#include "duckdb/common/types/data_chunk.hpp" -#include "duckdb/execution/operator/csv_scanner/sniffer/csv_sniffer.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_buffer_manager.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/function/table/range.hpp" -#include "duckdb/execution/operator/csv_scanner/csv_file_handle.hpp" -#include "duckdb/function/table/read_csv.hpp" - -namespace duckdb { - -struct CSVSniffFunctionData : public TableFunctionData { - CSVSniffFunctionData() { - } - string path; - // The CSV reader options - CSVReaderOptions options; - // Return Types of CSV (If given by the user) - vector return_types_csv; - // Column Names of CSV (If given by the user) - vector names_csv; - // If we want to force the match of the sniffer types - bool force_match = true; -}; - -struct CSVSniffGlobalState : public GlobalTableFunctionState { - CSVSniffGlobalState() { - } - bool done = false; -}; - -static unique_ptr CSVSniffInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static unique_ptr CSVSniffBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - if (input.inputs[0].IsNull()) { - throw BinderException("sniff_csv cannot take NULL as a file path parameter"); - } - result->path = input.inputs[0].ToString(); - auto it = input.named_parameters.find("auto_detect"); - if (it != input.named_parameters.end()) { - if (it->second.IsNull()) { - throw BinderException("\"%s\" expects a non-null boolean value (e.g. 
TRUE or 1)", it->first); - } - if (!it->second.GetValue()) { - throw InvalidInputException("sniff_csv function does not accept auto_detect variable set to false"); - } - // otherwise remove it - input.named_parameters.erase("auto_detect"); - } - - // If we want to force the match of the sniffer - it = input.named_parameters.find("force_match"); - if (it != input.named_parameters.end()) { - result->force_match = it->second.GetValue(); - input.named_parameters.erase("force_match"); - } - result->options.FromNamedParameters(input.named_parameters, context); - result->options.Verify(); - - // We want to return the whole CSV Configuration - // 1. Delimiter - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Delimiter"); - // 2. Quote - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Quote"); - // 3. Escape - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Escape"); - // 4. NewLine Delimiter - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("NewLineDelimiter"); - // 5. Comment - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Comment"); - // 6. Skip Rows - return_types.emplace_back(LogicalType::UINTEGER); - names.emplace_back("SkipRows"); - // 7. Has Header - return_types.emplace_back(LogicalType::BOOLEAN); - names.emplace_back("HasHeader"); - // 8. List> - child_list_t struct_children {{"name", LogicalType::VARCHAR}, {"type", LogicalType::VARCHAR}}; - auto list_child = LogicalType::STRUCT(struct_children); - return_types.emplace_back(LogicalType::LIST(list_child)); - names.emplace_back("Columns"); - // 9. Date Format - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("DateFormat"); - // 10. Timestamp Format - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("TimestampFormat"); - // 11. CSV read function with all the options used - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("UserArguments"); - // 12. CSV read function with all the options used - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Prompt"); - return std::move(result); -} - -string FormatOptions(char opt) { - if (opt == '\'') { - return "''"; - } - if (opt == '\0') { - return ""; - } - string result; - result += opt; - return result; -} - -string FormatOptions(string opt) { - if (opt.size() == 1) { - return FormatOptions(opt[0]); - } - return opt; -} - -static void CSVSniffFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &global_state = data_p.global_state->Cast(); - // Are we done? - if (global_state.done) { - return; - } - const CSVSniffFunctionData &data = data_p.bind_data->Cast(); - auto &fs = duckdb::FileSystem::GetFileSystem(context); - - auto paths = fs.GlobFiles(data.path, context, FileGlobOptions::DISALLOW_EMPTY); - if (paths.size() > 1) { - throw NotImplementedException("sniff_csv does not operate on more than one file yet"); - } - - // We must run the sniffer. 
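-	// data.options is part of the (const) bind data, so take a copy that the
-	// sniffer is allowed to mutate while it fills in the detected dialect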
- auto sniffer_options = data.options; - sniffer_options.file_path = paths[0]; - - auto buffer_manager = make_shared_ptr(context, sniffer_options, sniffer_options.file_path, 0); - if (sniffer_options.name_list.empty()) { - sniffer_options.name_list = data.names_csv; - } - - if (sniffer_options.sql_type_list.empty()) { - sniffer_options.sql_type_list = data.return_types_csv; - } - CSVSniffer sniffer(sniffer_options, buffer_manager, CSVStateMachineCache::Get(context)); - auto sniffer_result = sniffer.SniffCSV(data.force_match); - string str_opt; - string separator = ", "; - // Set output - output.SetCardinality(1); - - // 1. Delimiter - str_opt = sniffer_options.dialect_options.state_machine_options.delimiter.GetValue(); - output.SetValue(0, 0, str_opt); - // 2. Quote - str_opt = sniffer_options.dialect_options.state_machine_options.quote.GetValue(); - output.SetValue(1, 0, str_opt); - // 3. Escape - str_opt = sniffer_options.dialect_options.state_machine_options.escape.GetValue(); - output.SetValue(2, 0, str_opt); - // 4. NewLine Delimiter - auto new_line_identifier = sniffer_options.NewLineIdentifierToString(); - output.SetValue(3, 0, new_line_identifier); - // 5. Comment - str_opt = sniffer_options.dialect_options.state_machine_options.comment.GetValue(); - output.SetValue(4, 0, str_opt); - // 6. Skip Rows - output.SetValue(5, 0, Value::UINTEGER(NumericCast(sniffer_options.dialect_options.skip_rows.GetValue()))); - // 7. Has Header - auto has_header = Value::BOOLEAN(sniffer_options.dialect_options.header.GetValue()); - output.SetValue(6, 0, has_header); - // 8. List> {'col1': 'INTEGER', 'col2': 'VARCHAR'} - vector values; - std::ostringstream columns; - columns << "{"; - for (idx_t i = 0; i < sniffer_result.return_types.size(); i++) { - child_list_t struct_children {{"name", sniffer_result.names[i]}, - {"type", {sniffer_result.return_types[i].ToString()}}}; - values.emplace_back(Value::STRUCT(struct_children)); - columns << "'" << sniffer_result.names[i] << "': '" << sniffer_result.return_types[i].ToString() << "'"; - if (i != sniffer_result.return_types.size() - 1) { - columns << separator; - } - } - columns << "}"; - output.SetValue(7, 0, Value::LIST(values)); - // 9. Date Format - auto date_format = sniffer_options.dialect_options.date_format[LogicalType::DATE].GetValue(); - if (!date_format.Empty()) { - output.SetValue(8, 0, date_format.format_specifier); - } else { - bool has_date = false; - for (auto &c_type : sniffer_result.return_types) { - // Must be ISO 8601 - if (c_type.id() == LogicalTypeId::DATE) { - output.SetValue(8, 0, Value("%Y-%m-%d")); - has_date = true; - } - } - if (!has_date) { - output.SetValue(8, 0, Value(nullptr)); - } - } - - // 10. Timestamp Format - auto timestamp_format = sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue(); - if (!timestamp_format.Empty()) { - output.SetValue(9, 0, timestamp_format.format_specifier); - } else { - output.SetValue(9, 0, Value(nullptr)); - } - - // 11. The Extra User Arguments - if (data.options.user_defined_parameters.empty()) { - output.SetValue(10, 0, Value()); - } else { - output.SetValue(10, 0, Value(data.options.user_defined_parameters)); - } - - // 12. csv_read string - std::ostringstream csv_read; - - // Base, Path and auto_detect=false - csv_read << "FROM read_csv('" << paths[0] << "'" << separator << "auto_detect=false" << separator; - // 10.1. 
Delimiter - if (!sniffer_options.dialect_options.state_machine_options.delimiter.IsSetByUser()) { - csv_read << "delim=" - << "'" << FormatOptions(sniffer_options.dialect_options.state_machine_options.delimiter.GetValue()) - << "'" << separator; - } - // 11.2. Quote - if (!sniffer_options.dialect_options.state_machine_options.quote.IsSetByUser()) { - csv_read << "quote=" - << "'" << FormatOptions(sniffer_options.dialect_options.state_machine_options.quote.GetValue()) << "'" - << separator; - } - // 11.3. Escape - if (!sniffer_options.dialect_options.state_machine_options.escape.IsSetByUser()) { - csv_read << "escape=" - << "'" << FormatOptions(sniffer_options.dialect_options.state_machine_options.escape.GetValue()) << "'" - << separator; - } - // 11.4. NewLine Delimiter - if (!sniffer_options.dialect_options.state_machine_options.new_line.IsSetByUser()) { - if (new_line_identifier != "mix") { - csv_read << "new_line=" - << "'" << new_line_identifier << "'" << separator; - } - } - // 11.5. Skip Rows - if (!sniffer_options.dialect_options.skip_rows.IsSetByUser()) { - csv_read << "skip=" << sniffer_options.dialect_options.skip_rows.GetValue() << separator; - } - - // 11.6. Comment - if (!sniffer_options.dialect_options.state_machine_options.comment.IsSetByUser()) { - csv_read << "comment=" - << "'" << FormatOptions(sniffer_options.dialect_options.state_machine_options.comment.GetValue()) - << "'" << separator; - } - - // 11.7. Has Header - if (!sniffer_options.dialect_options.header.IsSetByUser()) { - csv_read << "header=" << has_header << separator; - } - // 11.8. column={'col1': 'INTEGER', 'col2': 'VARCHAR'} - csv_read << "columns=" << columns.str(); - // 11.9. Date Format - if (!sniffer_options.dialect_options.date_format[LogicalType::DATE].IsSetByUser()) { - if (!sniffer_options.dialect_options.date_format[LogicalType::DATE].GetValue().format_specifier.empty()) { - csv_read << separator << "dateformat=" - << "'" - << sniffer_options.dialect_options.date_format[LogicalType::DATE].GetValue().format_specifier - << "'"; - } else { - for (auto &c_type : sniffer_result.return_types) { - // Must be ISO 8601 - if (c_type.id() == LogicalTypeId::DATE) { - csv_read << separator << "dateformat=" - << "'%Y-%m-%d'"; - break; - } - } - } - } - // 11.10. 
Timestamp Format - if (!sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].IsSetByUser()) { - if (!sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue().format_specifier.empty()) { - csv_read << separator << "timestampformat=" - << "'" - << sniffer_options.dialect_options.date_format[LogicalType::TIMESTAMP].GetValue().format_specifier - << "'"; - } - } - // 11.11 User Arguments - if (!data.options.user_defined_parameters.empty()) { - csv_read << separator << data.options.user_defined_parameters; - } - csv_read << ");"; - output.SetValue(11, 0, csv_read.str()); - global_state.done = true; -} - -void CSVSnifferFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction csv_sniffer("sniff_csv", {LogicalType::VARCHAR}, CSVSniffFunction, CSVSniffBind, CSVSniffInitGlobal); - // Accept same options as the actual csv reader - ReadCSVTableFunction::ReadCSVAddNamedParameters(csv_sniffer); - csv_sniffer.named_parameters["force_match"] = LogicalType::BOOLEAN; - set.AddFunction(csv_sniffer); -} -} // namespace duckdb diff --git a/src/duckdb/src/function/table/summary.cpp b/src/duckdb/src/function/table/summary.cpp deleted file mode 100644 index d6c4615e4..000000000 --- a/src/duckdb/src/function/table/summary.cpp +++ /dev/null @@ -1,52 +0,0 @@ -#include "duckdb/function/table/summary.hpp" -#include "duckdb/function/table_function.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/common/file_system.hpp" - -// this function makes not that much sense on its own but is a demo for table-parameter table-producing functions - -namespace duckdb { - -static unique_ptr SummaryFunctionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("summary"); - - for (idx_t i = 0; i < input.input_table_types.size(); i++) { - return_types.push_back(input.input_table_types[i]); - names.emplace_back(input.input_table_names[i]); - } - - return make_uniq(); -} - -static OperatorResultType SummaryFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - output.SetCardinality(input.size()); - - for (idx_t row_idx = 0; row_idx < input.size(); row_idx++) { - string summary_val = "["; - - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - summary_val += input.GetValue(col_idx, row_idx).ToString(); - if (col_idx < input.ColumnCount() - 1) { - summary_val += ", "; - } - } - summary_val += "]"; - output.SetValue(0, row_idx, Value(summary_val)); - } - for (idx_t col_idx = 0; col_idx < input.ColumnCount(); col_idx++) { - output.data[col_idx + 1].Reference(input.data[col_idx]); - } - return OperatorResultType::NEED_MORE_INPUT; -} - -void SummaryTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction summary_function("summary", {LogicalType::TABLE}, nullptr, SummaryFunctionBind); - summary_function.in_out_function = SummaryFunction; - set.AddFunction(summary_function); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_columns.cpp b/src/duckdb/src/function/table/system/duckdb_columns.cpp deleted file mode 100644 index d30ea5300..000000000 --- a/src/duckdb/src/function/table/system/duckdb_columns.cpp +++ /dev/null @@ -1,350 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include 
"duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/parser/constraints/not_null_constraint.hpp" - -#include - -namespace duckdb { - -struct DuckDBColumnsData : public GlobalTableFunctionState { - DuckDBColumnsData() : offset(0), column_offset(0) { - } - - vector> entries; - idx_t offset; - idx_t column_offset; -}; - -static unique_ptr DuckDBColumnsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("column_index"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("column_default"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("is_nullable"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("data_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("data_type_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("character_maximum_length"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_precision"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_precision_radix"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("numeric_scale"); - return_types.emplace_back(LogicalType::INTEGER); - - return nullptr; -} - -unique_ptr DuckDBColumnsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and views and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TABLE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry); }); - } - return std::move(result); -} - -class ColumnHelper { -public: - static unique_ptr Create(CatalogEntry &entry); - - virtual ~ColumnHelper() { - } - - virtual StandardEntry &Entry() = 0; - virtual idx_t NumColumns() = 0; - virtual const string &ColumnName(idx_t col) = 0; - virtual const LogicalType &ColumnType(idx_t col) = 0; - virtual const Value ColumnDefault(idx_t col) = 0; - virtual bool IsNullable(idx_t col) = 0; - virtual const Value ColumnComment(idx_t col) = 0; - - void WriteColumns(idx_t index, idx_t start_col, idx_t end_col, DataChunk &output); -}; - -class TableColumnHelper : public ColumnHelper { -public: - explicit TableColumnHelper(TableCatalogEntry &entry) : entry(entry) { - for (auto &constraint : entry.GetConstraints()) { - if (constraint->type == ConstraintType::NOT_NULL) { - auto ¬_null = constraint->Cast(); - not_null_cols.insert(not_null.index.index); - } - } - } - - 
StandardEntry &Entry() override { - return entry; - } - idx_t NumColumns() override { - return entry.GetColumns().LogicalColumnCount(); - } - const string &ColumnName(idx_t col) override { - return entry.GetColumn(LogicalIndex(col)).Name(); - } - const LogicalType &ColumnType(idx_t col) override { - return entry.GetColumn(LogicalIndex(col)).Type(); - } - const Value ColumnDefault(idx_t col) override { - auto &column = entry.GetColumn(LogicalIndex(col)); - if (column.Generated()) { - return Value(column.GeneratedExpression().ToString()); - } else if (column.HasDefaultValue()) { - return Value(column.DefaultValue().ToString()); - } - return Value(); - } - bool IsNullable(idx_t col) override { - return not_null_cols.find(col) == not_null_cols.end(); - } - const Value ColumnComment(idx_t col) override { - return entry.GetColumn(LogicalIndex(col)).Comment(); - } - -private: - TableCatalogEntry &entry; - std::set not_null_cols; -}; - -class ViewColumnHelper : public ColumnHelper { -public: - explicit ViewColumnHelper(ViewCatalogEntry &entry) : entry(entry) { - } - - StandardEntry &Entry() override { - return entry; - } - idx_t NumColumns() override { - return entry.types.size(); - } - const string &ColumnName(idx_t col) override { - return col < entry.aliases.size() ? entry.aliases[col] : entry.names[col]; - } - const LogicalType &ColumnType(idx_t col) override { - return entry.types[col]; - } - const Value ColumnDefault(idx_t col) override { - return Value(); - } - bool IsNullable(idx_t col) override { - return true; - } - const Value ColumnComment(idx_t col) override { - if (entry.column_comments.empty()) { - return Value(); - } - D_ASSERT(entry.column_comments.size() == entry.types.size()); - return entry.column_comments[col]; - } - -private: - ViewCatalogEntry &entry; -}; - -unique_ptr ColumnHelper::Create(CatalogEntry &entry) { - switch (entry.type) { - case CatalogType::TABLE_ENTRY: - return make_uniq(entry.Cast()); - case CatalogType::VIEW_ENTRY: - return make_uniq(entry.Cast()); - default: - throw NotImplementedException("Unsupported catalog type for duckdb_columns"); - } -} - -void ColumnHelper::WriteColumns(idx_t start_index, idx_t start_col, idx_t end_col, DataChunk &output) { - for (idx_t i = start_col; i < end_col; i++) { - auto index = start_index + (i - start_col); - auto &entry = Entry(); - - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, index, entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.catalog.GetOid()))); - // schema_name, VARCHAR - output.SetValue(col++, index, entry.schema.name); - // schema_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.schema.oid))); - // table_name, VARCHAR - output.SetValue(col++, index, entry.name); - // table_oid, BIGINT - output.SetValue(col++, index, Value::BIGINT(NumericCast(entry.oid))); - // column_name, VARCHAR - output.SetValue(col++, index, Value(ColumnName(i))); - // column_index, INTEGER - output.SetValue(col++, index, Value::INTEGER(UnsafeNumericCast(i + 1))); - // comment, VARCHAR - output.SetValue(col++, index, ColumnComment(i)); - // internal, BOOLEAN - output.SetValue(col++, index, Value::BOOLEAN(entry.internal)); - // column_default, VARCHAR - output.SetValue(col++, index, Value(ColumnDefault(i))); - // is_nullable, BOOLEAN - output.SetValue(col++, index, Value::BOOLEAN(IsNullable(i))); - // data_type, VARCHAR - const LogicalType &type = ColumnType(i); - output.SetValue(col++, index, Value(type.ToString())); - // 
data_type_id, BIGINT - output.SetValue(col++, index, Value::BIGINT(int(type.id()))); - if (type == LogicalType::VARCHAR) { - // FIXME: need check constraints in place to set this correctly - // character_maximum_length, INTEGER - output.SetValue(col++, index, Value()); - } else { - // "character_maximum_length", PhysicalType::INTEGER - output.SetValue(col++, index, Value()); - } - - Value numeric_precision, numeric_scale, numeric_precision_radix; - switch (type.id()) { - case LogicalTypeId::DECIMAL: - numeric_precision = Value::INTEGER(DecimalType::GetWidth(type)); - numeric_scale = Value::INTEGER(DecimalType::GetScale(type)); - numeric_precision_radix = Value::INTEGER(10); - break; - case LogicalTypeId::HUGEINT: - numeric_precision = Value::INTEGER(128); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::BIGINT: - numeric_precision = Value::INTEGER(64); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::INTEGER: - numeric_precision = Value::INTEGER(32); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::SMALLINT: - numeric_precision = Value::INTEGER(16); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::TINYINT: - numeric_precision = Value::INTEGER(8); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::FLOAT: - numeric_precision = Value::INTEGER(24); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - case LogicalTypeId::DOUBLE: - numeric_precision = Value::INTEGER(53); - numeric_scale = Value::INTEGER(0); - numeric_precision_radix = Value::INTEGER(2); - break; - default: - numeric_precision = Value(); - numeric_scale = Value(); - numeric_precision_radix = Value(); - break; - } - - // numeric_precision, INTEGER - output.SetValue(col++, index, numeric_precision); - // numeric_precision_radix, INTEGER - output.SetValue(col++, index, numeric_precision_radix); - // numeric_scale, INTEGER - output.SetValue(col++, index, numeric_scale); - } -} - -void DuckDBColumnsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - - // We need to track the offset of the relation we're writing as well as the last column - // we wrote from that relation (if any); it's possible that we can fill up the output - // with a partial list of columns from a relation and will need to pick up processing the - // next chunk at the same spot. 
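-	// For example, with the default STANDARD_VECTOR_SIZE of 2048, a relation with 3000 columns fills
-	// the first chunk with columns [0, 2048) and the next call resumes here with column_offset == 2048.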
- idx_t next = data.offset; - idx_t column_offset = data.column_offset; - idx_t index = 0; - while (next < data.entries.size() && index < STANDARD_VECTOR_SIZE) { - auto column_helper = ColumnHelper::Create(data.entries[next].get()); - idx_t columns = column_helper->NumColumns(); - - // Check to see if we are going to exceed the maximum index for a DataChunk - if (index + (columns - column_offset) > STANDARD_VECTOR_SIZE) { - idx_t column_limit = column_offset + (STANDARD_VECTOR_SIZE - index); - output.SetCardinality(STANDARD_VECTOR_SIZE); - column_helper->WriteColumns(index, column_offset, column_limit, output); - - // Make the current column limit the column offset when we process the next chunk - column_offset = column_limit; - break; - } else { - // Otherwise, write all of the columns from the current relation and - // then move on to the next one. - output.SetCardinality(index + (columns - column_offset)); - column_helper->WriteColumns(index, column_offset, columns, output); - index += columns - column_offset; - next++; - column_offset = 0; - } - } - data.offset = next; - data.column_offset = column_offset; -} - -void DuckDBColumnsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_columns", {}, DuckDBColumnsFunction, DuckDBColumnsBind, DuckDBColumnsInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_constraints.cpp b/src/duckdb/src/function/table/system/duckdb_constraints.cpp deleted file mode 100644 index 5f436f2c8..000000000 --- a/src/duckdb/src/function/table/system/duckdb_constraints.cpp +++ /dev/null @@ -1,341 +0,0 @@ -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/parser/constraint.hpp" -#include "duckdb/parser/constraints/check_constraint.hpp" -#include "duckdb/parser/constraints/foreign_key_constraint.hpp" -#include "duckdb/parser/constraints/not_null_constraint.hpp" -#include "duckdb/parser/constraints/unique_constraint.hpp" -#include "duckdb/planner/binder.hpp" -#include "duckdb/planner/constraints/bound_check_constraint.hpp" -#include "duckdb/parser/parsed_expression_iterator.hpp" - -namespace duckdb { - -struct ConstraintEntry { - ConstraintEntry(ClientContext &context, TableCatalogEntry &table) : table(table) { - if (!table.IsDuckTable()) { - return; - } - auto binder = Binder::CreateBinder(context); - bound_constraints = binder->BindConstraints(table.GetConstraints(), table.name, table.GetColumns()); - } - - TableCatalogEntry &table; - vector> bound_constraints; -}; - -struct DuckDBConstraintsData : public GlobalTableFunctionState { - DuckDBConstraintsData() : offset(0), constraint_offset(0), unique_constraint_offset(0) { - } - - vector entries; - idx_t offset; - idx_t constraint_offset; - idx_t unique_constraint_offset; - case_insensitive_set_t constraint_names; -}; - -static unique_ptr DuckDBConstraintsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - 
return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("table_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("table_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("constraint_index"); - return_types.emplace_back(LogicalType::BIGINT); - - // CHECK, PRIMARY KEY or UNIQUE - names.emplace_back("constraint_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("constraint_text"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("expression"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("constraint_column_indexes"); - return_types.push_back(LogicalType::LIST(LogicalType::BIGINT)); - - names.emplace_back("constraint_column_names"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("constraint_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - // FOREIGN KEY - names.emplace_back("referenced_table"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("referenced_column_names"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - return nullptr; -} - -unique_ptr DuckDBConstraintsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them - auto schemas = Catalog::GetAllSchemas(context); - - for (auto &schema : schemas) { - vector> entries; - - schema.get().Scan(context, CatalogType::TABLE_ENTRY, [&](CatalogEntry &entry) { - if (entry.type == CatalogType::TABLE_ENTRY) { - entries.push_back(entry); - } - }); - - sort(entries.begin(), entries.end(), [&](CatalogEntry &x, CatalogEntry &y) { return (x.name < y.name); }); - for (auto &entry : entries) { - result->entries.emplace_back(context, entry.get().Cast()); - } - }; - - return std::move(result); -} - -struct ExtraConstraintInfo { - vector column_indexes; - vector column_names; - string referenced_table; - vector referenced_columns; -}; - -void ExtractReferencedColumns(const ParsedExpression &expr, vector &result) { - if (expr.GetExpressionClass() == ExpressionClass::COLUMN_REF) { - auto &colref = expr.Cast(); - result.push_back(colref.GetColumnName()); - } - ParsedExpressionIterator::EnumerateChildren( - expr, [&](const ParsedExpression &child) { ExtractReferencedColumns(child, result); }); -} - -ExtraConstraintInfo GetExtraConstraintInfo(const TableCatalogEntry &table, const Constraint &constraint) { - ExtraConstraintInfo result; - switch (constraint.type) { - case ConstraintType::CHECK: { - auto &check_constraint = constraint.Cast(); - ExtractReferencedColumns(*check_constraint.expression, result.column_names); - break; - } - case ConstraintType::NOT_NULL: { - auto ¬_null_constraint = constraint.Cast(); - result.column_indexes.push_back(not_null_constraint.index); - break; - } - case ConstraintType::UNIQUE: { - auto &unique = constraint.Cast(); - if (unique.HasIndex()) { - result.column_indexes.push_back(unique.GetIndex()); - } else { - result.column_names = unique.GetColumnNames(); - } - break; - } - case ConstraintType::FOREIGN_KEY: { - auto &fk = constraint.Cast(); - result.referenced_columns = fk.pk_columns; - result.referenced_table = fk.info.table; - result.column_names = fk.fk_columns; - break; - } - default: - throw InternalException("Unsupported type for constraint name"); - } - if 
(result.column_indexes.empty()) { - // generate column indexes from names - for (auto &name : result.column_names) { - result.column_indexes.push_back(table.GetColumnIndex(name)); - } - } else { - // generate names from column indexes - for (auto &index : result.column_indexes) { - result.column_names.push_back(table.GetColumn(index).GetName()); - } - } - return result; -} - -string GetConstraintName(const TableCatalogEntry &table, Constraint &constraint, const ExtraConstraintInfo &info) { - string result = table.name + "_"; - for (auto &col : info.column_names) { - result += StringUtil::Lower(col) + "_"; - } - for (auto &col : info.referenced_columns) { - result += StringUtil::Lower(col) + "_"; - } - switch (constraint.type) { - case ConstraintType::CHECK: - result += "check"; - break; - case ConstraintType::NOT_NULL: - result += "not_null"; - break; - case ConstraintType::UNIQUE: { - auto &unique = constraint.Cast(); - result += unique.IsPrimaryKey() ? "pkey" : "key"; - break; - } - case ConstraintType::FOREIGN_KEY: - result += "fkey"; - break; - default: - throw InternalException("Unsupported type for constraint name"); - } - return result; -} - -void DuckDBConstraintsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset]; - - auto &table = entry.table; - auto &constraints = table.GetConstraints(); - for (; data.constraint_offset < constraints.size() && count < STANDARD_VECTOR_SIZE; data.constraint_offset++) { - auto &constraint = constraints[data.constraint_offset]; - // return values: - // constraint_type, VARCHAR - // Processing this first due to shortcut (early continue) - string constraint_type; - switch (constraint->type) { - case ConstraintType::CHECK: - constraint_type = "CHECK"; - break; - case ConstraintType::UNIQUE: { - auto &unique = constraint->Cast(); - constraint_type = unique.IsPrimaryKey() ? 
"PRIMARY KEY" : "UNIQUE"; - break; - } - case ConstraintType::NOT_NULL: - constraint_type = "NOT NULL"; - break; - case ConstraintType::FOREIGN_KEY: { - auto &fk = constraint->Cast(); - if (fk.info.type == ForeignKeyType::FK_TYPE_PRIMARY_KEY_TABLE) { - // Those are already covered by PRIMARY KEY and UNIQUE entries - continue; - } - constraint_type = "FOREIGN KEY"; - break; - } - default: - throw NotImplementedException("Unimplemented constraint for duckdb_constraints"); - } - - idx_t col = 0; - // database_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.schema.catalog.GetName())); - // database_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(table.schema.catalog.GetOid()))); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(table.schema.oid))); - // table_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(table.name)); - // table_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(table.oid))); - - auto info = GetExtraConstraintInfo(table, *constraint); - auto constraint_name = GetConstraintName(table, *constraint, info); - if (data.constraint_names.find(constraint_name) != data.constraint_names.end()) { - // duplicate constraint name - idx_t index = 2; - while (data.constraint_names.find(constraint_name + "_" + to_string(index)) != - data.constraint_names.end()) { - index++; - } - constraint_name += "_" + to_string(index); - } - // constraint_index, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(data.unique_constraint_offset++))); - - // constraint_type, VARCHAR - output.SetValue(col++, count, Value(constraint_type)); - - // constraint_text, VARCHAR - output.SetValue(col++, count, Value(constraint->ToString())); - - // expression, VARCHAR - Value expression_text; - if (constraint->type == ConstraintType::CHECK) { - auto &check = constraint->Cast(); - expression_text = Value(check.expression->ToString()); - } - output.SetValue(col++, count, expression_text); - - vector column_index_list; - vector column_name_list; - vector referenced_column_name_list; - for (auto &col_index : info.column_indexes) { - column_index_list.push_back(Value::UBIGINT(col_index.index)); - } - for (auto &name : info.column_names) { - column_name_list.push_back(Value(std::move(name))); - } - for (auto &name : info.referenced_columns) { - referenced_column_name_list.push_back(Value(std::move(name))); - } - // constraint_column_indexes, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::BIGINT, std::move(column_index_list))); - - // constraint_column_names, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(column_name_list))); - - // constraint_name, VARCHAR - output.SetValue(col++, count, Value(std::move(constraint_name))); - - // referenced_table, VARCHAR - output.SetValue(col++, count, - info.referenced_table.empty() ? 
Value() : Value(std::move(info.referenced_table))); - - // referenced_column_names, LIST - output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, std::move(referenced_column_name_list))); - count++; - } - - if (data.constraint_offset >= constraints.size()) { - data.constraint_offset = 0; - data.offset++; - } - } - output.SetCardinality(count); -} - -void DuckDBConstraintsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_constraints", {}, DuckDBConstraintsFunction, DuckDBConstraintsBind, - DuckDBConstraintsInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_databases.cpp b/src/duckdb/src/function/table/system/duckdb_databases.cpp deleted file mode 100644 index 2c7c9c913..000000000 --- a/src/duckdb/src/function/table/system/duckdb_databases.cpp +++ /dev/null @@ -1,105 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/main/attached_database.hpp" - -namespace duckdb { - -struct DuckDBDatabasesData : public GlobalTableFunctionState { - DuckDBDatabasesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBDatabasesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("tags"); - return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("readonly"); - return_types.emplace_back(LogicalType::BOOLEAN); - - return nullptr; -} - -unique_ptr DuckDBDatabasesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto &db_manager = DatabaseManager::Get(context); - result->entries = db_manager.GetDatabases(context); - return std::move(result); -} - -void DuckDBDatabasesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - - auto &attached = entry.get().Cast(); - // return values: - - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, attached.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(attached.oid))); - bool is_internal = attached.IsSystem() || attached.IsTemporary(); - bool is_readonly = attached.IsReadOnly(); - // path, VARCHAR - Value db_path; - if (!is_internal) { - bool in_memory = attached.GetCatalog().InMemory(); - if (!in_memory) { - db_path = Value(attached.GetCatalog().GetDBPath()); - } - } - output.SetValue(col++, count, db_path); - // comment, VARCHAR - output.SetValue(col++, count, 
Value(attached.comment)); - // tags, MAP - output.SetValue(col++, count, Value::MAP(attached.tags)); - // internal, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(is_internal)); - // type, VARCHAR - output.SetValue(col++, count, Value(attached.GetCatalog().GetCatalogType())); - // readonly, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(is_readonly)); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBDatabasesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_databases", {}, DuckDBDatabasesFunction, DuckDBDatabasesBind, DuckDBDatabasesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_dependencies.cpp b/src/duckdb/src/function/table/system/duckdb_dependencies.cpp deleted file mode 100644 index 720bee4d8..000000000 --- a/src/duckdb/src/function/table/system/duckdb_dependencies.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/duck_catalog.hpp" -#include "duckdb/catalog/dependency_manager.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -struct DependencyInformation { - DependencyInformation(CatalogEntry &object, CatalogEntry &dependent, const DependencyDependentFlags &flags) - : object(object), dependent(dependent), flags(flags) { - } - - CatalogEntry &object; - CatalogEntry &dependent; - DependencyDependentFlags flags; -}; - -struct DuckDBDependenciesData : public GlobalTableFunctionState { - DuckDBDependenciesData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBDependenciesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("classid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("objid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("objsubid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("refclassid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("refobjid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("refobjsubid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("deptype"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBDependenciesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas and collect them - auto &catalog = Catalog::GetCatalog(context, INVALID_CATALOG); - auto dependency_manager = catalog.GetDependencyManager(); - if (dependency_manager) { - dependency_manager->Scan( - context, [&](CatalogEntry &obj, CatalogEntry &dependent, const DependencyDependentFlags &flags) { - result->entries.emplace_back(obj, dependent, flags); - }); - } - - return std::move(result); -} - -void DuckDBDependenciesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset]; - - // return values: - // classid, LogicalType::BIGINT - output.SetValue(0, count, 
Value::BIGINT(0)); - // objid, LogicalType::BIGINT - output.SetValue(1, count, Value::BIGINT(NumericCast(entry.object.oid))); - // objsubid, LogicalType::INTEGER - output.SetValue(2, count, Value::INTEGER(0)); - // refclassid, LogicalType::BIGINT - output.SetValue(3, count, Value::BIGINT(0)); - // refobjid, LogicalType::BIGINT - output.SetValue(4, count, Value::BIGINT(NumericCast(entry.dependent.oid))); - // refobjsubid, LogicalType::INTEGER - output.SetValue(5, count, Value::INTEGER(0)); - // deptype, LogicalType::VARCHAR - string dependency_type_str; - if (entry.flags.IsBlocking()) { - dependency_type_str = "n"; - } else { - dependency_type_str = "a"; - } - output.SetValue(6, count, Value(dependency_type_str)); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBDependenciesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_dependencies", {}, DuckDBDependenciesFunction, DuckDBDependenciesBind, - DuckDBDependenciesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_extensions.cpp b/src/duckdb/src/function/table/system/duckdb_extensions.cpp deleted file mode 100644 index 64d26dea9..000000000 --- a/src/duckdb/src/function/table/system/duckdb_extensions.cpp +++ /dev/null @@ -1,221 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/map.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/extension_helper.hpp" - -#include "duckdb/common/serializer/buffered_file_reader.hpp" -#include "duckdb/common/serializer/binary_deserializer.hpp" -#include "duckdb/main/extension_install_info.hpp" - -namespace duckdb { - -struct ExtensionInformation { - string name; - bool loaded = false; - bool installed = false; - string file_path; - ExtensionInstallMode install_mode; - string installed_from; - string description; - vector aliases; - string extension_version; -}; - -struct DuckDBExtensionsData : public GlobalTableFunctionState { - DuckDBExtensionsData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBExtensionsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("extension_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("loaded"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("installed"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("install_path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("aliases"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("extension_version"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("install_mode"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("installed_from"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBExtensionsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto &fs = FileSystem::GetFileSystem(context); - auto &db = DatabaseInstance::GetDatabase(context); - - // Firstly, we go over all Default Extensions: duckdb_extensions always 
prints those, installed/loaded or not - map installed_extensions; - auto extension_count = ExtensionHelper::DefaultExtensionCount(); - auto alias_count = ExtensionHelper::ExtensionAliasCount(); - for (idx_t i = 0; i < extension_count; i++) { - auto extension = ExtensionHelper::GetDefaultExtension(i); - ExtensionInformation info; - info.name = extension.name; - info.installed = extension.statically_loaded; - info.loaded = false; - info.file_path = extension.statically_loaded ? "(BUILT-IN)" : string(); - info.install_mode = - extension.statically_loaded ? ExtensionInstallMode::STATICALLY_LINKED : ExtensionInstallMode::UNKNOWN; - info.description = extension.description; - for (idx_t k = 0; k < alias_count; k++) { - auto alias = ExtensionHelper::GetExtensionAlias(k); - if (info.name == alias.extension) { - info.aliases.emplace_back(alias.alias); - } - } - installed_extensions[info.name] = std::move(info); - } - - // Secondly we scan all installed extensions and their install info -#ifndef WASM_LOADABLE_EXTENSIONS - auto ext_directory = ExtensionHelper::GetExtensionDirectoryPath(context); - fs.ListFiles(ext_directory, [&](const string &path, bool is_directory) { - if (!StringUtil::EndsWith(path, ".duckdb_extension")) { - return; - } - ExtensionInformation info; - info.name = fs.ExtractBaseName(path); - info.installed = true; - info.loaded = false; - info.file_path = fs.JoinPath(ext_directory, path); - - // Check the info file for its installation source - auto info_file_path = fs.JoinPath(ext_directory, path + ".info"); - - // Read the info file - auto extension_install_info = ExtensionInstallInfo::TryReadInfoFile(fs, info_file_path, info.name); - info.install_mode = extension_install_info->mode; - info.extension_version = extension_install_info->version; - if (extension_install_info->mode == ExtensionInstallMode::REPOSITORY) { - info.installed_from = ExtensionRepository::GetRepository(extension_install_info->repository_url); - } else { - info.installed_from = extension_install_info->full_path; - } - - auto entry = installed_extensions.find(info.name); - if (entry == installed_extensions.end()) { - installed_extensions[info.name] = std::move(info); - } else { - if (entry->second.install_mode != ExtensionInstallMode::STATICALLY_LINKED) { - entry->second.file_path = info.file_path; - entry->second.install_mode = info.install_mode; - entry->second.installed_from = info.installed_from; - entry->second.install_mode = info.install_mode; - entry->second.extension_version = info.extension_version; - } - entry->second.installed = true; - } - }); -#endif - - // Finally, we check the list of currently loaded extensions - auto &extensions = db.GetExtensions(); - for (auto &e : extensions) { - if (!e.second.is_loaded) { - continue; - } - auto &ext_name = e.first; - auto &ext_data = e.second; - if (auto &ext_install_info = ext_data.install_info) { - auto entry = installed_extensions.find(ext_name); - if (entry == installed_extensions.end() || !entry->second.installed) { - ExtensionInformation &info = installed_extensions[ext_name]; - info.name = ext_name; - info.loaded = true; - info.extension_version = ext_install_info->version; - info.installed = ext_install_info->mode == ExtensionInstallMode::STATICALLY_LINKED; - info.install_mode = ext_install_info->mode; - } else { - entry->second.loaded = true; - entry->second.extension_version = ext_install_info->version; - } - } - if (auto &ext_load_info = ext_data.load_info) { - auto entry = installed_extensions.find(ext_name); - if (entry != 
installed_extensions.end()) { - entry->second.description = ext_load_info->description; - } - } - } - - result->entries.reserve(installed_extensions.size()); - for (auto &kv : installed_extensions) { - result->entries.push_back(std::move(kv.second)); - } - return std::move(result); -} - -void DuckDBExtensionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset]; - - // return values: - // extension_name LogicalType::VARCHAR - output.SetValue(0, count, Value(entry.name)); - // loaded LogicalType::BOOLEAN - output.SetValue(1, count, Value::BOOLEAN(entry.loaded)); - // installed LogicalType::BOOLEAN - output.SetValue(2, count, Value::BOOLEAN(entry.installed)); - // install_path LogicalType::VARCHAR - output.SetValue(3, count, Value(entry.file_path)); - // description LogicalType::VARCHAR - output.SetValue(4, count, Value(entry.description)); - // aliases LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(5, count, Value::LIST(LogicalType::VARCHAR, entry.aliases)); - // extension version LogicalType::LIST(LogicalType::VARCHAR) - output.SetValue(6, count, Value(entry.extension_version)); - // installed_mode LogicalType::VARCHAR - output.SetValue(7, count, entry.installed ? Value(EnumUtil::ToString(entry.install_mode)) : Value()); - // installed_source LogicalType::VARCHAR - output.SetValue(8, count, Value(entry.installed_from)); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBExtensionsFun::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet functions("duckdb_extensions"); - functions.AddFunction(TableFunction({}, DuckDBExtensionsFunction, DuckDBExtensionsBind, DuckDBExtensionsInit)); - set.AddFunction(functions); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_functions.cpp b/src/duckdb/src/function/table/system/duckdb_functions.cpp deleted file mode 100644 index 4ed87e524..000000000 --- a/src/duckdb/src/function/table/system/duckdb_functions.cpp +++ /dev/null @@ -1,708 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/aggregate_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/scalar_macro_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_macro_catalog_entry.hpp" -#include "duckdb/function/table_macro_function.hpp" -#include "duckdb/function/scalar_macro_function.hpp" -#include "duckdb/catalog/catalog_entry/table_function_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/pragma_function_catalog_entry.hpp" -#include "duckdb/parser/expression/columnref_expression.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/common/optional_idx.hpp" -#include "duckdb/common/types.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -struct DuckDBFunctionsData : public GlobalTableFunctionState { - DuckDBFunctionsData() : offset(0), offset_in_entry(0) { - } - - vector> entries; - idx_t offset; - idx_t offset_in_entry; -}; - -static 
unique_ptr DuckDBFunctionsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("function_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("function_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("tags"); - return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); - - names.emplace_back("return_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("parameters"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("parameter_types"); - return_types.push_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("varargs"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("macro_definition"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("has_side_effects"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("function_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("examples"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("stability"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -static void ExtractFunctionsFromSchema(ClientContext &context, SchemaCatalogEntry &schema, - DuckDBFunctionsData &result) { - schema.Scan(context, CatalogType::SCALAR_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); - schema.Scan(context, CatalogType::TABLE_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); - schema.Scan(context, CatalogType::PRAGMA_FUNCTION_ENTRY, - [&](CatalogEntry &entry) { result.entries.push_back(entry); }); -} - -unique_ptr DuckDBFunctionsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect them and collect them - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - ExtractFunctionsFromSchema(context, schema.get(), *result); - }; - - std::sort(result->entries.begin(), result->entries.end(), - [&](reference a, reference b) { - return (int32_t)a.get().type < (int32_t)b.get().type; - }); - return std::move(result); -} - -Value FunctionStabilityToValue(FunctionStability stability) { - switch (stability) { - case FunctionStability::VOLATILE: - return Value("VOLATILE"); - case FunctionStability::CONSISTENT: - return Value("CONSISTENT"); - case FunctionStability::CONSISTENT_WITHIN_QUERY: - return Value("CONSISTENT_WITHIN_QUERY"); - default: - throw InternalException("Unsupported FunctionStability"); - } -} - -struct ScalarFunctionExtractor { - static idx_t FunctionCount(ScalarFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("scalar"); - } - - static Value GetReturnType(ScalarFunctionCatalogEntry 
&entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); - } - - static vector GetParameters(ScalarFunctionCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - return results; - } - - static Value GetParameterTypes(ScalarFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static vector GetParameterLogicalTypes(ScalarFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return fun.arguments; - } - - static Value GetVarArgs(ScalarFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value IsVolatile(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE); - } - - static Value ResultType(ScalarFunctionCatalogEntry &entry, idx_t offset) { - return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability); - } -}; - -struct AggregateFunctionExtractor { - static idx_t FunctionCount(AggregateFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("aggregate"); - } - - static Value GetReturnType(AggregateFunctionCatalogEntry &entry, idx_t offset) { - return Value(entry.functions.GetFunctionByOffset(offset).return_type.ToString()); - } - - static vector GetParameters(AggregateFunctionCatalogEntry &entry, idx_t offset) { - vector results; - for (idx_t i = 0; i < entry.functions.GetFunctionByOffset(offset).arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - return results; - } - - static Value GetParameterTypes(AggregateFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static vector GetParameterLogicalTypes(AggregateFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return fun.arguments; - } - - static Value GetVarArgs(AggregateFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? 
Value() : Value(fun.varargs.ToString());
-	}
-
-	static Value GetMacroDefinition(AggregateFunctionCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static Value IsVolatile(AggregateFunctionCatalogEntry &entry, idx_t offset) {
-		return Value::BOOLEAN(entry.functions.GetFunctionByOffset(offset).stability == FunctionStability::VOLATILE);
-	}
-
-	static Value ResultType(AggregateFunctionCatalogEntry &entry, idx_t offset) {
-		return FunctionStabilityToValue(entry.functions.GetFunctionByOffset(offset).stability);
-	}
-};
-
-struct MacroExtractor {
-	static idx_t FunctionCount(ScalarMacroCatalogEntry &entry) {
-		return entry.macros.size();
-	}
-
-	static Value GetFunctionType() {
-		return Value("macro");
-	}
-
-	static Value GetReturnType(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static vector<Value> GetParameters(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		vector<Value> results;
-		auto &macro_entry = *entry.macros[offset];
-		for (auto &param : macro_entry.parameters) {
-			D_ASSERT(param->GetExpressionType() == ExpressionType::COLUMN_REF);
-			auto &colref = param->Cast<ColumnRefExpression>();
-			results.emplace_back(colref.GetColumnName());
-		}
-		for (auto &param_entry : macro_entry.default_parameters) {
-			results.emplace_back(param_entry.first);
-		}
-		return results;
-	}
-
-	static Value GetParameterTypes(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		vector<Value> results;
-		auto &macro_entry = *entry.macros[offset];
-		for (idx_t i = 0; i < macro_entry.parameters.size(); i++) {
-			results.emplace_back(LogicalType::VARCHAR);
-		}
-		for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) {
-			results.emplace_back(LogicalType::VARCHAR);
-		}
-		return Value::LIST(LogicalType::VARCHAR, std::move(results));
-	}
-
-	static vector<LogicalType> GetParameterLogicalTypes(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		vector<LogicalType> results;
-		auto &macro_entry = *entry.macros[offset];
-		for (idx_t i = 0; i < macro_entry.parameters.size(); i++) {
-			results.emplace_back(LogicalType::UNKNOWN);
-		}
-		for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) {
-			results.emplace_back(LogicalType::UNKNOWN);
-		}
-		return results;
-	}
-
-	static Value GetVarArgs(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static Value GetMacroDefinition(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		auto &macro_entry = *entry.macros[offset];
-		D_ASSERT(macro_entry.type == MacroType::SCALAR_MACRO);
-		auto &func = macro_entry.Cast<ScalarMacroFunction>();
-		return func.expression->ToString();
-	}
-
-	static Value IsVolatile(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static Value ResultType(ScalarMacroCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-};
-
-struct TableMacroExtractor {
-	static idx_t FunctionCount(TableMacroCatalogEntry &entry) {
-		return entry.macros.size();
-	}
-
-	static Value GetFunctionType() {
-		return Value("table_macro");
-	}
-
-	static Value GetReturnType(TableMacroCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static vector<Value> GetParameters(TableMacroCatalogEntry &entry, idx_t offset) {
-		vector<Value> results;
-		auto &macro_entry = *entry.macros[offset];
-		for (auto &param : macro_entry.parameters) {
-			D_ASSERT(param->GetExpressionType() == ExpressionType::COLUMN_REF);
-			auto &colref = param->Cast<ColumnRefExpression>();
-			results.emplace_back(colref.GetColumnName());
-		}
-		for (auto &param_entry : macro_entry.default_parameters) {
-			results.emplace_back(param_entry.first);
-		}
-		return results;
-	}
-
-	static Value GetParameterTypes(TableMacroCatalogEntry &entry, idx_t offset) {
-		vector<Value> results;
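-		// macro parameters carry no declared types, so a VARCHAR placeholder is emitted per parameter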
auto ¯o_entry = *entry.macros[offset]; - for (idx_t i = 0; i < macro_entry.parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) { - results.emplace_back(LogicalType::VARCHAR); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static vector GetParameterLogicalTypes(TableMacroCatalogEntry &entry, idx_t offset) { - vector results; - auto ¯o_entry = *entry.macros[offset]; - for (idx_t i = 0; i < macro_entry.parameters.size(); i++) { - results.emplace_back(LogicalType::UNKNOWN); - } - for (idx_t i = 0; i < macro_entry.default_parameters.size(); i++) { - results.emplace_back(LogicalType::UNKNOWN); - } - return results; - } - - static Value GetVarArgs(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value GetMacroDefinition(TableMacroCatalogEntry &entry, idx_t offset) { - auto ¯o_entry = *entry.macros[offset]; - if (macro_entry.type == MacroType::TABLE_MACRO) { - auto &func = macro_entry.Cast(); - return func.query_node->ToString(); - } - return Value(); - } - - static Value IsVolatile(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value ResultType(TableMacroCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -struct TableFunctionExtractor { - static idx_t FunctionCount(TableFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("table"); - } - - static Value GetReturnType(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(TableFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.first); - } - return results; - } - - static Value GetParameterTypes(TableFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.second.ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static vector GetParameterLogicalTypes(TableFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return fun.arguments; - } - - static Value GetVarArgs(TableFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? 
Value() : Value(fun.varargs.ToString()); - } - - static Value GetMacroDefinition(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value IsVolatile(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static Value ResultType(TableFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } -}; - -struct PragmaFunctionExtractor { - static idx_t FunctionCount(PragmaFunctionCatalogEntry &entry) { - return entry.functions.Size(); - } - - static Value GetFunctionType() { - return Value("pragma"); - } - - static Value GetReturnType(PragmaFunctionCatalogEntry &entry, idx_t offset) { - return Value(); - } - - static vector GetParameters(PragmaFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back("col" + to_string(i)); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.first); - } - return results; - } - - static Value GetParameterTypes(PragmaFunctionCatalogEntry &entry, idx_t offset) { - vector results; - auto fun = entry.functions.GetFunctionByOffset(offset); - - for (idx_t i = 0; i < fun.arguments.size(); i++) { - results.emplace_back(fun.arguments[i].ToString()); - } - for (auto ¶m : fun.named_parameters) { - results.emplace_back(param.second.ToString()); - } - return Value::LIST(LogicalType::VARCHAR, std::move(results)); - } - - static vector GetParameterLogicalTypes(PragmaFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return fun.arguments; - } - - static Value GetVarArgs(PragmaFunctionCatalogEntry &entry, idx_t offset) { - auto fun = entry.functions.GetFunctionByOffset(offset); - return !fun.HasVarArgs() ? 
Value() : Value(fun.varargs.ToString());
-	}
-
-	static Value GetMacroDefinition(PragmaFunctionCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static Value IsVolatile(PragmaFunctionCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-
-	static Value ResultType(PragmaFunctionCatalogEntry &entry, idx_t offset) {
-		return Value();
-	}
-};
-
-static vector<Value> ToValueVector(vector<string> &string_vector) {
-	vector<Value> result;
-	for (string &str : string_vector) {
-		result.emplace_back(Value(str));
-	}
-	return result;
-}
-
-template <class T, class OP>
-static Value GetParameterNames(FunctionEntry &entry, idx_t function_idx, FunctionDescription &function_description,
-                               Value &parameter_types) {
-	vector<Value> parameter_names;
-	if (!function_description.parameter_names.empty()) {
-		for (idx_t param_idx = 0; param_idx < ListValue::GetChildren(parameter_types).size(); param_idx++) {
-			if (param_idx < function_description.parameter_names.size()) {
-				parameter_names.emplace_back(function_description.parameter_names[param_idx]);
-			} else {
-				parameter_names.emplace_back("col" + to_string(param_idx));
-			}
-		}
-	} else {
-		// fallback: take the parameter names from the function entry itself
-		auto &function = entry.Cast<T>();
-		parameter_names = OP::GetParameters(function, function_idx);
-	}
-	return Value::LIST(LogicalType::VARCHAR, parameter_names);
-}
-
-// return values:
-// 0: exact type match; N: match using N ANY-typed parameters; Invalid(): no match
-static optional_idx CalcDescriptionSpecificity(FunctionDescription &description,
-                                               const vector<LogicalType> &parameter_types) {
-	if (description.parameter_types.size() != parameter_types.size()) {
-		return optional_idx::Invalid();
-	}
-	idx_t any_count = 0;
-	for (idx_t i = 0; i < description.parameter_types.size(); i++) {
-		if (description.parameter_types[i].id() == LogicalTypeId::ANY) {
-			any_count++;
-		} else if (description.parameter_types[i] != parameter_types[i]) {
-			return optional_idx::Invalid();
-		}
-	}
-	return any_count;
-}
-
-// Find the FunctionDescription with a matching number of arguments and matching types
-static optional_idx GetFunctionDescriptionIndex(vector<FunctionDescription> &function_descriptions,
-                                                vector<LogicalType> &function_parameter_types) {
-	if (function_descriptions.size() == 1) {
-		// a single description: use it even if the number of parameters doesn't match
-		idx_t nr_function_parameters = function_parameter_types.size();
-		for (idx_t i = 0; i < function_descriptions[0].parameter_types.size(); i++) {
-			if (i < nr_function_parameters && function_descriptions[0].parameter_types[i] != LogicalTypeId::ANY &&
-			    function_descriptions[0].parameter_types[i] != function_parameter_types[i]) {
-				return optional_idx::Invalid();
-			}
-		}
-		return optional_idx(0);
-	}
-
-	// multiple descriptions: search for the most specific match
-	optional_idx best_description_idx;
-	// specificity_score: 0: exact type match; N: match using N ANY-typed parameters; Invalid(): no match
-	optional_idx best_specificity_score;
-	optional_idx specificity_score;
-	for (idx_t descr_idx = 0; descr_idx < function_descriptions.size(); descr_idx++) {
-		specificity_score = CalcDescriptionSpecificity(function_descriptions[descr_idx], function_parameter_types);
-		if (specificity_score.IsValid() &&
-		    (!best_specificity_score.IsValid() || specificity_score.GetIndex() < best_specificity_score.GetIndex())) {
-			best_specificity_score = specificity_score;
-			best_description_idx = descr_idx;
-		}
-	}
-	return best_description_idx;
-}
-
-template <class T, class OP>
-bool ExtractFunctionData(FunctionEntry &entry, idx_t function_idx, DataChunk &output, idx_t output_offset) {
-	auto &function = entry.Cast<T>();
-	vector<LogicalType> parameter_types_vector = OP::GetParameterLogicalTypes(function, function_idx);
-	Value parameter_types_value = OP::GetParameterTypes(function, function_idx);
-	optional_idx description_idx = GetFunctionDescriptionIndex(entry.descriptions, parameter_types_vector);
-	FunctionDescription function_description =
-	    description_idx.IsValid() ? entry.descriptions[description_idx.GetIndex()] : FunctionDescription();
-
-	idx_t col = 0;
-
-	// database_name, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, Value(function.schema.catalog.GetName()));
-
-	// database_oid, BIGINT
-	output.SetValue(col++, output_offset, Value::BIGINT(NumericCast<int64_t>(function.schema.catalog.GetOid())));
-
-	// schema_name, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, Value(function.schema.name));
-
-	// function_name, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, Value(function.name));
-
-	// function_type, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, Value(OP::GetFunctionType()));
-
-	// function_description, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset,
-	                (function_description.description.empty()) ? Value() : Value(function_description.description));
-
-	// comment, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, entry.comment);
-
-	// tags, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)
-	output.SetValue(col++, output_offset, Value::MAP(entry.tags));
-
-	// return_type, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, OP::GetReturnType(function, function_idx));
-
-	// parameters, LogicalType::LIST(LogicalType::VARCHAR)
-	output.SetValue(col++, output_offset,
-	                GetParameterNames<T, OP>(function, function_idx, function_description, parameter_types_value));
-
-	// parameter_types, LogicalType::LIST(LogicalType::VARCHAR)
-	output.SetValue(col++, output_offset, parameter_types_value);
-
-	// varargs, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, OP::GetVarArgs(function, function_idx));
-
-	// macro_definition, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, OP::GetMacroDefinition(function, function_idx));
-
-	// has_side_effects, LogicalType::BOOLEAN
-	output.SetValue(col++, output_offset, OP::IsVolatile(function, function_idx));
-
-	// internal, LogicalType::BOOLEAN
-	output.SetValue(col++, output_offset, Value::BOOLEAN(function.internal));
-
-	// function_oid, LogicalType::BIGINT
-	output.SetValue(col++, output_offset, Value::BIGINT(NumericCast<int64_t>(function.oid)));
-
-	// examples, LogicalType::LIST(LogicalType::VARCHAR)
-	output.SetValue(col++, output_offset,
-	                Value::LIST(LogicalType::VARCHAR, ToValueVector(function_description.examples)));
-
-	// stability, LogicalType::VARCHAR
-	output.SetValue(col++, output_offset, OP::ResultType(function, function_idx));
-
-	return function_idx + 1 == OP::FunctionCount(function);
-}
-
-void DuckDBFunctionsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &data = data_p.global_state->Cast<DuckDBFunctionsData>();
-	if (data.offset >= data.entries.size()) {
-		// finished returning values
-		return;
-	}
-	// start returning values
-	// either fill up the chunk or return all the remaining entries
-	idx_t count = 0;
-	while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) {
-		auto &entry = data.entries[data.offset].get().Cast<FunctionEntry>();
-		bool finished;
-
-		switch (entry.type) {
-		case CatalogType::SCALAR_FUNCTION_ENTRY:
-			finished = ExtractFunctionData<ScalarFunctionCatalogEntry, ScalarFunctionExtractor>(
-			    entry, data.offset_in_entry, output, count);
-			break;
-		case CatalogType::AGGREGATE_FUNCTION_ENTRY:
-			finished = ExtractFunctionData<AggregateFunctionCatalogEntry, AggregateFunctionExtractor>(
-			    entry, data.offset_in_entry, output, count);
-			break;
-		case CatalogType::TABLE_MACRO_ENTRY:
-			finished = ExtractFunctionData<TableMacroCatalogEntry, TableMacroExtractor>(entry, data.offset_in_entry,
-			                                                                            output, count);
-			break;
-		case CatalogType::MACRO_ENTRY:
-			finished = ExtractFunctionData<ScalarMacroCatalogEntry, MacroExtractor>(entry, data.offset_in_entry, output,
-			                                                                        count);
-			break;
-		case CatalogType::TABLE_FUNCTION_ENTRY:
-			finished = ExtractFunctionData<TableFunctionCatalogEntry, TableFunctionExtractor>(
-			    entry, data.offset_in_entry, output, count);
-			break;
-		case CatalogType::PRAGMA_FUNCTION_ENTRY:
-			finished = ExtractFunctionData<PragmaFunctionCatalogEntry, PragmaFunctionExtractor>(
-			    entry, data.offset_in_entry, output, count);
-			break;
-		default:
-			throw InternalException("FIXME: unrecognized function type in duckdb_functions");
-		}
-		if (finished) {
-			// finished with this entry, move on to the next one
-			data.offset++;
-			data.offset_in_entry = 0;
-		} else {
-			// more overloads of this function remain
-			data.offset_in_entry++;
-		}
-		count++;
-	}
-	output.SetCardinality(count);
-}
-
-void DuckDBFunctionsFun::RegisterFunction(BuiltinFunctions &set) {
-	set.AddFunction(
-	    TableFunction("duckdb_functions", {}, DuckDBFunctionsFunction, DuckDBFunctionsBind, DuckDBFunctionsInit));
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/table/system/duckdb_indexes.cpp b/src/duckdb/src/function/table/system/duckdb_indexes.cpp
deleted file mode 100644
index 78cc88d81..000000000
--- a/src/duckdb/src/function/table/system/duckdb_indexes.cpp
+++ /dev/null
@@ -1,150 +0,0 @@
-#include "duckdb/catalog/catalog.hpp"
-#include "duckdb/catalog/catalog_entry/index_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/function/table/system_functions.hpp"
-#include "duckdb/main/client_data.hpp"
-
-namespace duckdb {
-
-struct DuckDBIndexesData : public GlobalTableFunctionState {
-	DuckDBIndexesData() : offset(0) {
-	}
-
-	vector<reference<CatalogEntry>> entries;
-	idx_t offset;
-};
-
-static unique_ptr<FunctionData> DuckDBIndexesBind(ClientContext &context, TableFunctionBindInput &input,
-                                                  vector<LogicalType> &return_types, vector<string> &names) {
-	names.emplace_back("database_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("database_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("schema_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("schema_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("index_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("index_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("table_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("table_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("comment");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("tags");
-	return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR));
-
-	names.emplace_back("is_unique");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("is_primary");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("expressions");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("sql");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	return nullptr;
-}
-
-unique_ptr<GlobalTableFunctionState> DuckDBIndexesInit(ClientContext &context, TableFunctionInitInput &input) {
-	auto result = make_uniq<DuckDBIndexesData>();
-
-	// scan all the schemas for indexes and collect them
-	auto schemas = Catalog::GetAllSchemas(context);
-	for (auto &schema : schemas) {
-		schema.get().Scan(context, CatalogType::INDEX_ENTRY,
-		                  [&](CatalogEntry &entry) { result->entries.push_back(entry); });
-	}
-	return std::move(result);
-}
-
-Value GetIndexExpressions(IndexCatalogEntry &index) {
-	auto create_info = index.GetInfo();
-	auto &create_index_info = create_info->Cast<CreateIndexInfo>();
-
-	auto vec = create_index_info.ExpressionsToList();
-
-	vector<Value> content;
-	content.reserve(vec.size());
-	for (auto &item : vec) {
-		content.push_back(Value(item));
-	}
-	return Value::LIST(LogicalType::VARCHAR, std::move(content));
-}
-
-void DuckDBIndexesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &data = data_p.global_state->Cast<DuckDBIndexesData>();
-	if (data.offset >= data.entries.size()) {
-		// finished returning values
-		return;
-	}
-	// start returning values
-	// either fill up the chunk or return all the remaining entries
-	idx_t count = 0;
-	while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) {
-		auto &entry = data.entries[data.offset++].get();
-
-		auto &index = entry.Cast<IndexCatalogEntry>();
-		// return values:
-
-		idx_t col = 0;
-		// database_name, VARCHAR
-		output.SetValue(col++, count, index.catalog.GetName());
-		// database_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(index.catalog.GetOid())));
-		// schema_name, VARCHAR
-		output.SetValue(col++, count, Value(index.schema.name));
-		// schema_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(index.schema.oid)));
-		// index_name, VARCHAR
-		output.SetValue(col++, count, Value(index.name));
-		// index_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(index.oid)));
-		// find the table in the catalog
-		auto &table_entry =
-		    index.schema.catalog.GetEntry<TableCatalogEntry>(context, index.GetSchemaName(), index.GetTableName());
-		// table_name, VARCHAR
-		output.SetValue(col++, count, Value(table_entry.name));
-		// table_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(table_entry.oid)));
-		// comment, VARCHAR
-		output.SetValue(col++, count, Value(index.comment));
-		// tags, MAP
-		output.SetValue(col++, count, Value::MAP(index.tags));
-		// is_unique, BOOLEAN
-		output.SetValue(col++, count, Value::BOOLEAN(index.IsUnique()));
-		// is_primary, BOOLEAN
-		output.SetValue(col++, count, Value::BOOLEAN(index.IsPrimary()));
-		// expressions, VARCHAR
-		output.SetValue(col++, count, GetIndexExpressions(index).ToString());
-		// sql, VARCHAR
-		auto sql = index.ToSQL();
-		output.SetValue(col++, count, sql.empty() ?
Value() : Value(std::move(sql))); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBIndexesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_indexes", {}, DuckDBIndexesFunction, DuckDBIndexesBind, DuckDBIndexesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_keywords.cpp b/src/duckdb/src/function/table/system/duckdb_keywords.cpp deleted file mode 100644 index 35167a8b4..000000000 --- a/src/duckdb/src/function/table/system/duckdb_keywords.cpp +++ /dev/null @@ -1,78 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/parser/parser.hpp" - -namespace duckdb { - -struct DuckDBKeywordsData : public GlobalTableFunctionState { - DuckDBKeywordsData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBKeywordsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("keyword_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("keyword_category"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBKeywordsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->entries = Parser::KeywordList(); - return std::move(result); -} - -void DuckDBKeywordsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - - // keyword_name, VARCHAR - output.SetValue(0, count, Value(entry.name)); - // keyword_category, VARCHAR - string category_name; - switch (entry.category) { - case KeywordCategory::KEYWORD_RESERVED: - category_name = "reserved"; - break; - case KeywordCategory::KEYWORD_UNRESERVED: - category_name = "unreserved"; - break; - case KeywordCategory::KEYWORD_TYPE_FUNC: - category_name = "type_function"; - break; - case KeywordCategory::KEYWORD_COL_NAME: - category_name = "column_name"; - break; - default: - throw InternalException("Unrecognized keyword category"); - } - output.SetValue(1, count, Value(std::move(category_name))); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBKeywordsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_keywords", {}, DuckDBKeywordsFunction, DuckDBKeywordsBind, DuckDBKeywordsInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_memory.cpp b/src/duckdb/src/function/table/system/duckdb_memory.cpp deleted file mode 100644 index a1eb044b7..000000000 --- a/src/duckdb/src/function/table/system/duckdb_memory.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -struct DuckDBMemoryData : public GlobalTableFunctionState { - DuckDBMemoryData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBMemoryBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("tag"); - 
return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("memory_usage_bytes"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("temporary_storage_bytes"); - return_types.emplace_back(LogicalType::BIGINT); - - return nullptr; -} - -unique_ptr DuckDBMemoryInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->entries = BufferManager::GetBufferManager(context).GetMemoryUsageInfo(); - return std::move(result); -} - -void DuckDBMemoryFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - // return values: - idx_t col = 0; - // tag, VARCHAR - output.SetValue(col++, count, EnumUtil::ToString(entry.tag)); - // memory_usage_bytes, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.size))); - // temporary_storage_bytes, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.evicted_data))); - count++; - } - output.SetCardinality(count); -} - -void DuckDBMemoryFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_memory", {}, DuckDBMemoryFunction, DuckDBMemoryBind, DuckDBMemoryInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_optimizers.cpp b/src/duckdb/src/function/table/system/duckdb_optimizers.cpp deleted file mode 100644 index ac531678a..000000000 --- a/src/duckdb/src/function/table/system/duckdb_optimizers.cpp +++ /dev/null @@ -1,57 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/main/config.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/enums/optimizer_type.hpp" - -namespace duckdb { - -struct DuckDBOptimizersData : public GlobalTableFunctionState { - DuckDBOptimizersData() : offset(0) { - } - - vector optimizers; - idx_t offset; -}; - -static unique_ptr DuckDBOptimizersBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBOptimizersInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->optimizers = ListAllOptimizers(); - return std::move(result); -} - -void DuckDBOptimizersFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.optimizers.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.optimizers.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.optimizers[data.offset++]; - - // return values: - // name, LogicalType::VARCHAR - output.SetValue(0, count, Value(entry)); - count++; - } - output.SetCardinality(count); -} - -void DuckDBOptimizersFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_optimizers", {}, DuckDBOptimizersFunction, DuckDBOptimizersBind, DuckDBOptimizersInit)); -} - -} // namespace duckdb 
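All of the system table functions deleted in this diff share one three-callback shape: a Bind that appends column names and their LogicalTypes in lockstep and optionally returns bind data, an Init that snapshots the rows to emit into a GlobalTableFunctionState holding an offset cursor, and a scan function that emits at most STANDARD_VECTOR_SIZE rows per call. Below is a minimal sketch of that pattern, using only the APIs visible in the surrounding files; the duckdb_example name, the DuckDBExample* types, and the hard-coded payload are hypothetical, and the registration struct is assumed to be declared the way the real ones are in system_functions.hpp.

#include "duckdb/function/table/system_functions.hpp"

namespace duckdb {

// hypothetical registration hook, mirroring the declarations in system_functions.hpp
struct DuckDBExampleFun {
	static void RegisterFunction(BuiltinFunctions &set);
};

struct DuckDBExampleData : public GlobalTableFunctionState {
	DuckDBExampleData() : offset(0) {
	}

	vector<string> entries; // snapshot taken once, in Init
	idx_t offset;           // scan cursor, advanced as rows are emitted
};

static unique_ptr<FunctionData> DuckDBExampleBind(ClientContext &context, TableFunctionBindInput &input,
                                                  vector<LogicalType> &return_types, vector<string> &names) {
	// names and return_types must stay in lockstep: one entry per output column
	names.emplace_back("name");
	return_types.emplace_back(LogicalType::VARCHAR);
	return nullptr; // no bind data required
}

static unique_ptr<GlobalTableFunctionState> DuckDBExampleInit(ClientContext &context, TableFunctionInitInput &input) {
	auto result = make_uniq<DuckDBExampleData>();
	result->entries = {"a", "b", "c"}; // hypothetical payload; the real functions scan the catalog here
	return std::move(result);
}

static void DuckDBExampleFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
	auto &data = data_p.global_state->Cast<DuckDBExampleData>();
	// fill at most one vector's worth of rows per call; an empty chunk ends the scan
	idx_t count = 0;
	while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) {
		output.SetValue(0, count, Value(data.entries[data.offset++]));
		count++;
	}
	output.SetCardinality(count);
}

void DuckDBExampleFun::RegisterFunction(BuiltinFunctions &set) {
	set.AddFunction(TableFunction("duckdb_example", {}, DuckDBExampleFunction, DuckDBExampleBind, DuckDBExampleInit));
}

} // namespace duckdb

Because the executor keeps invoking the scan callback until it produces an empty chunk, the early-out on data.offset >= data.entries.size() at the top of every scan function in these files is what terminates the table scan.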
diff --git a/src/duckdb/src/function/table/system/duckdb_schemas.cpp b/src/duckdb/src/function/table/system/duckdb_schemas.cpp deleted file mode 100644 index 295d440b6..000000000 --- a/src/duckdb/src/function/table/system/duckdb_schemas.cpp +++ /dev/null @@ -1,98 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -struct DuckDBSchemasData : public GlobalTableFunctionState { - DuckDBSchemasData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBSchemasBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("tags"); - return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSchemasInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas and collect them - result->entries = Catalog::GetAllSchemas(context); - - return std::move(result); -} - -void DuckDBSchemasFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset].get(); - - // return values: - idx_t col = 0; - // "oid", PhysicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.oid))); - // database_name, VARCHAR - output.SetValue(col++, count, entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.catalog.GetOid()))); - // "schema_name", PhysicalType::VARCHAR - output.SetValue(col++, count, Value(entry.name)); - // "comment", PhysicalType::VARCHAR - output.SetValue(col++, count, Value(entry.comment)); - // "tags", MAP(VARCHAR, VARCHAR) - output.SetValue(col++, count, Value::MAP(entry.tags)); - // "internal", PhysicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(entry.internal)); - // "sql", PhysicalType::VARCHAR - output.SetValue(col++, count, Value()); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBSchemasFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_schemas", {}, DuckDBSchemasFunction, DuckDBSchemasBind, DuckDBSchemasInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_secrets.cpp b/src/duckdb/src/function/table/system/duckdb_secrets.cpp deleted 
file mode 100644 index 6069344bf..000000000 --- a/src/duckdb/src/function/table/system/duckdb_secrets.cpp +++ /dev/null @@ -1,140 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/common/file_system.hpp" -#include "duckdb/common/map.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/main/extension_helper.hpp" -#include "duckdb/main/secret/secret_manager.hpp" - -namespace duckdb { - -struct DuckDBSecretsData : public GlobalTableFunctionState { - DuckDBSecretsData() : offset(0) { - } - idx_t offset; - duckdb::vector secrets; -}; - -struct DuckDBSecretsBindData : public FunctionData { -public: - unique_ptr Copy() const override { - return make_uniq(); - }; - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return redact == other.redact; - } - SecretDisplayType redact = SecretDisplayType::REDACTED; -}; - -static unique_ptr DuckDBSecretsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - - auto entry = input.named_parameters.find("redact"); - if (entry != input.named_parameters.end()) { - if (BooleanValue::Get(entry->second)) { - result->redact = SecretDisplayType::REDACTED; - } else { - result->redact = SecretDisplayType::UNREDACTED; - } - } - - if (!DBConfig::GetConfig(context).options.allow_unredacted_secrets && - result->redact == SecretDisplayType::UNREDACTED) { - throw InvalidInputException("Displaying unredacted secrets is disabled"); - } - - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("provider"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("persistent"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("storage"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("scope"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - names.emplace_back("secret_string"); - return_types.emplace_back(LogicalType::VARCHAR); - - return std::move(result); -} - -unique_ptr DuckDBSecretsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - return std::move(result); -} - -void DuckDBSecretsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - auto &bind_data = data_p.bind_data->Cast(); - - auto &secret_manager = SecretManager::Get(context); - - auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context); - - if (data.secrets.empty()) { - data.secrets = secret_manager.AllSecrets(transaction); - } - auto &secrets = data.secrets; - if (data.offset >= secrets.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < secrets.size() && count < STANDARD_VECTOR_SIZE) { - auto &secret_entry = secrets[data.offset]; - - vector scope_value; - for (const auto &scope_entry : secret_entry.secret->GetScope()) { - scope_value.push_back(scope_entry); - } - - const auto &secret = *secret_entry.secret; - - idx_t i = 0; - // name - output.SetValue(i++, count, secret.GetName()); - // type - output.SetValue(i++, count, 
Value(secret.GetType())); - // provider - output.SetValue(i++, count, Value(secret.GetProvider())); - // persistent - output.SetValue(i++, count, Value(secret_entry.persist_type == SecretPersistType::PERSISTENT)); - // storage - output.SetValue(i++, count, Value(secret_entry.storage_mode)); - // scope - output.SetValue(i++, count, Value::LIST(LogicalType::VARCHAR, scope_value)); - // secret_string - output.SetValue(i++, count, secret.ToString(bind_data.redact)); - - data.offset++; - count++; - } - output.SetCardinality(count); -} - -void DuckDBSecretsFun::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet functions("duckdb_secrets"); - auto fun = TableFunction({}, DuckDBSecretsFunction, DuckDBSecretsBind, DuckDBSecretsInit); - fun.named_parameters["redact"] = LogicalType::BOOLEAN; - functions.AddFunction(fun); - set.AddFunction(functions); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_sequences.cpp b/src/duckdb/src/function/table/system/duckdb_sequences.cpp deleted file mode 100644 index b95d23c2e..000000000 --- a/src/duckdb/src/function/table/system/duckdb_sequences.cpp +++ /dev/null @@ -1,144 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/sequence_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/numeric_utils.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -struct DuckDBSequencesData : public GlobalTableFunctionState { - DuckDBSequencesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBSequencesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sequence_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("sequence_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("tags"); - return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); - - names.emplace_back("temporary"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("start_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("min_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("max_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("increment_by"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("cycle"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("last_value"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("sql"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSequencesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - // scan all the schemas for tables and collect themand collect them - auto schemas = 
Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::SEQUENCE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); - }; - return std::move(result); -} - -void DuckDBSequencesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &seq = data.entries[data.offset++].get(); - auto seq_data = seq.GetData(); - - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, seq.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.catalog.GetOid()))); - // schema_name, VARCHAR - output.SetValue(col++, count, Value(seq.schema.name)); - // schema_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.schema.oid))); - // sequence_name, VARCHAR - output.SetValue(col++, count, Value(seq.name)); - // sequence_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(seq.oid))); - // comment, VARCHAR - output.SetValue(col++, count, Value(seq.comment)); - // tags, MAP(VARCHAR, VARCHAR) - output.SetValue(col++, count, Value::MAP(seq.tags)); - // temporary, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(seq.temporary)); - // start_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq_data.start_value)); - // min_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq_data.min_value)); - // max_value, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq_data.max_value)); - // increment_by, BIGINT - output.SetValue(col++, count, Value::BIGINT(seq_data.increment)); - // cycle, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(seq_data.cycle)); - // last_value, BIGINT - output.SetValue(col++, count, seq_data.usage_count == 0 ? 
Value() : Value::BIGINT(seq_data.last_value)); - // sql, LogicalType::VARCHAR - output.SetValue(col++, count, Value(seq.ToSQL())); - - count++; - } - output.SetCardinality(count); -} - -void DuckDBSequencesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_sequences", {}, DuckDBSequencesFunction, DuckDBSequencesBind, DuckDBSequencesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_settings.cpp b/src/duckdb/src/function/table/system/duckdb_settings.cpp deleted file mode 100644 index b7c3a56b8..000000000 --- a/src/duckdb/src/function/table/system/duckdb_settings.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/main/config.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/common/enum_util.hpp" - -namespace duckdb { - -struct DuckDBSettingValue { - string name; - string value; - string description; - string input_type; - string scope; -}; - -struct DuckDBSettingsData : public GlobalTableFunctionState { - DuckDBSettingsData() : offset(0) { - } - - vector settings; - idx_t offset; -}; - -static unique_ptr DuckDBSettingsBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("description"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("input_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("scope"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBSettingsInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto &config = DBConfig::GetConfig(context); - auto options_count = DBConfig::GetOptionCount(); - for (idx_t i = 0; i < options_count; i++) { - auto option = DBConfig::GetOptionByIndex(i); - D_ASSERT(option); - DuckDBSettingValue value; - auto scope = option->set_global ? 
SettingScope::GLOBAL : SettingScope::LOCAL; - value.name = option->name; - value.value = option->get_setting(context).ToString(); - value.description = option->description; - value.input_type = option->parameter_type; - value.scope = EnumUtil::ToString(scope); - - result->settings.push_back(std::move(value)); - } - for (auto &ext_param : config.extension_parameters) { - Value setting_val; - string setting_str_val; - auto scope = SettingScope::GLOBAL; - auto lookup_result = context.TryGetCurrentSetting(ext_param.first, setting_val); - if (lookup_result) { - setting_str_val = setting_val.ToString(); - scope = lookup_result.GetScope(); - } - DuckDBSettingValue value; - value.name = ext_param.first; - value.value = std::move(setting_str_val); - value.description = ext_param.second.description; - value.input_type = ext_param.second.type.ToString(); - value.scope = EnumUtil::ToString(scope); - - result->settings.push_back(std::move(value)); - } - return std::move(result); -} - -void DuckDBSettingsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.settings.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.settings.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.settings[data.offset++]; - - // return values: - // name, LogicalType::VARCHAR - output.SetValue(0, count, Value(entry.name)); - // value, LogicalType::VARCHAR - output.SetValue(1, count, Value(entry.value)); - // description, LogicalType::VARCHAR - output.SetValue(2, count, Value(entry.description)); - // input_type, LogicalType::VARCHAR - output.SetValue(3, count, Value(entry.input_type)); - // scope, LogicalType::VARCHAR - output.SetValue(4, count, Value(entry.scope)); - count++; - } - output.SetCardinality(count); -} - -void DuckDBSettingsFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("duckdb_settings", {}, DuckDBSettingsFunction, DuckDBSettingsBind, DuckDBSettingsInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_tables.cpp b/src/duckdb/src/function/table/system/duckdb_tables.cpp deleted file mode 100644 index 5e007bbff..000000000 --- a/src/duckdb/src/function/table/system/duckdb_tables.cpp +++ /dev/null @@ -1,165 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/parser/constraint.hpp" -#include "duckdb/parser/constraints/unique_constraint.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table_storage_info.hpp" - -namespace duckdb { - -struct DuckDBTablesData : public GlobalTableFunctionState { - DuckDBTablesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -static unique_ptr DuckDBTablesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - 
return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("schema_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("table_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("table_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("comment");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("tags");
-	return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR));
-
-	names.emplace_back("internal");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("temporary");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("has_primary_key");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("estimated_size");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("column_count");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("index_count");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("check_constraint_count");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("sql");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	return nullptr;
-}
-
-unique_ptr<GlobalTableFunctionState> DuckDBTablesInit(ClientContext &context, TableFunctionInitInput &input) {
-	auto result = make_uniq<DuckDBTablesData>();
-
-	// scan all the schemas for tables and collect them
-	auto schemas = Catalog::GetAllSchemas(context);
-	for (auto &schema : schemas) {
-		schema.get().Scan(context, CatalogType::TABLE_ENTRY,
-		                  [&](CatalogEntry &entry) { result->entries.push_back(entry); });
-	}
-	return std::move(result);
-}
-
-static idx_t CheckConstraintCount(TableCatalogEntry &table) {
-	idx_t check_count = 0;
-	for (auto &constraint : table.GetConstraints()) {
-		if (constraint->type == ConstraintType::CHECK) {
-			check_count++;
-		}
-	}
-	return check_count;
-}
-
-void DuckDBTablesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &data = data_p.global_state->Cast<DuckDBTablesData>();
-	if (data.offset >= data.entries.size()) {
-		// finished returning values
-		return;
-	}
-	// start returning values
-	// either fill up the chunk or return all the remaining entries
-	idx_t count = 0;
-	while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) {
-		auto &entry = data.entries[data.offset++].get();
-
-		if (entry.type != CatalogType::TABLE_ENTRY) {
-			continue;
-		}
-		auto &table = entry.Cast<TableCatalogEntry>();
-		auto storage_info = table.GetStorageInfo(context);
-		// return values:
-		idx_t col = 0;
-		// database_name, VARCHAR
-		output.SetValue(col++, count, table.catalog.GetName());
-		// database_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(table.catalog.GetOid())));
-		// schema_name, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(table.schema.name));
-		// schema_oid, LogicalType::BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(table.schema.oid)));
-		// table_name, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(table.name));
-		// table_oid, LogicalType::BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(table.oid)));
-		// comment, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(table.comment));
-		// tags, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)
-		output.SetValue(col++, count, Value::MAP(table.tags));
-		// internal, LogicalType::BOOLEAN
-		output.SetValue(col++, count, Value::BOOLEAN(table.internal));
-		// temporary,
LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(table.temporary)); - // has_primary_key, LogicalType::BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(table.HasPrimaryKey())); - // estimated_size, LogicalType::BIGINT - - Value card_val = !storage_info.cardinality.IsValid() - ? Value() - : Value::BIGINT(NumericCast(storage_info.cardinality.GetIndex())); - output.SetValue(col++, count, card_val); - // column_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(table.GetColumns().LogicalColumnCount()))); - // index_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(storage_info.index_info.size()))); - // check_constraint_count, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(CheckConstraintCount(table)))); - // sql, LogicalType::VARCHAR - auto table_info = table.GetInfo(); - table_info->catalog.clear(); - output.SetValue(col++, count, Value(table_info->ToString())); - count++; - } - output.SetCardinality(count); -} - -void DuckDBTablesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_tables", {}, DuckDBTablesFunction, DuckDBTablesBind, DuckDBTablesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp b/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp deleted file mode 100644 index f18866394..000000000 --- a/src/duckdb/src/function/table/system/duckdb_temporary_files.cpp +++ /dev/null @@ -1,59 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/storage/buffer_manager.hpp" - -namespace duckdb { - -struct DuckDBTemporaryFilesData : public GlobalTableFunctionState { - DuckDBTemporaryFilesData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr DuckDBTemporaryFilesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("size"); - return_types.emplace_back(LogicalType::BIGINT); - - return nullptr; -} - -unique_ptr DuckDBTemporaryFilesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - result->entries = BufferManager::GetBufferManager(context).GetTemporaryFiles(); - return std::move(result); -} - -void DuckDBTemporaryFilesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = data.entries[data.offset++]; - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, entry.path); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(entry.size))); - count++; - } - output.SetCardinality(count); -} - -void DuckDBTemporaryFilesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_temporary_files", {}, DuckDBTemporaryFilesFunction, DuckDBTemporaryFilesBind, - DuckDBTemporaryFilesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_types.cpp b/src/duckdb/src/function/table/system/duckdb_types.cpp deleted file mode 100644 index 
dd98408b9..000000000 --- a/src/duckdb/src/function/table/system/duckdb_types.cpp +++ /dev/null @@ -1,201 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -struct DuckDBTypesData : public GlobalTableFunctionState { - DuckDBTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; - unordered_set oids; -}; - -static unique_ptr DuckDBTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("schema_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("schema_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("type_oid"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("type_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("logical_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - // NUMERIC, STRING, DATETIME, BOOLEAN, COMPOSITE, USER - names.emplace_back("type_category"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("comment"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("tags"); - return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)); - - names.emplace_back("internal"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("labels"); - return_types.emplace_back(LogicalType::LIST(LogicalType::VARCHAR)); - - return nullptr; -} - -unique_ptr DuckDBTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - auto schemas = Catalog::GetAllSchemas(context); - for (auto &schema : schemas) { - schema.get().Scan(context, CatalogType::TYPE_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.Cast()); }); - }; - return std::move(result); -} - -void DuckDBTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &type_entry = data.entries[data.offset++].get(); - auto &type = type_entry.user_type; - - // return values: - idx_t col = 0; - // database_name, VARCHAR - output.SetValue(col++, count, type_entry.catalog.GetName()); - // database_oid, BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(type_entry.catalog.GetOid()))); - // schema_name, LogicalType::VARCHAR - output.SetValue(col++, count, Value(type_entry.schema.name)); - // schema_oid, LogicalType::BIGINT - output.SetValue(col++, count, Value::BIGINT(NumericCast(type_entry.schema.oid))); - // type_oid, BIGINT - int64_t oid; - if (type_entry.internal) { - oid = NumericCast(type.id()); - } 
else { - oid = NumericCast(type_entry.oid); - } - Value oid_val; - if (data.oids.find(oid) == data.oids.end()) { - data.oids.insert(oid); - oid_val = Value::BIGINT(oid); - } else { - oid_val = Value(); - } - output.SetValue(col++, count, oid_val); - // type_name, VARCHAR - output.SetValue(col++, count, Value(type_entry.name)); - // type_size, BIGINT - auto internal_type = type.InternalType(); - output.SetValue(col++, count, - internal_type == PhysicalType::INVALID - ? Value() - : Value::BIGINT(NumericCast(GetTypeIdSize(internal_type)))); - // logical_type, VARCHAR - output.SetValue(col++, count, Value(EnumUtil::ToString(type.id()))); - // type_category, VARCHAR - string category; - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::DECIMAL: - case LogicalTypeId::FLOAT: - case LogicalTypeId::DOUBLE: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - case LogicalTypeId::HUGEINT: - case LogicalTypeId::UHUGEINT: - category = "NUMERIC"; - break; - case LogicalTypeId::DATE: - case LogicalTypeId::TIME: - case LogicalTypeId::TIMESTAMP_SEC: - case LogicalTypeId::TIMESTAMP_MS: - case LogicalTypeId::TIMESTAMP: - case LogicalTypeId::TIMESTAMP_NS: - case LogicalTypeId::INTERVAL: - case LogicalTypeId::TIME_TZ: - case LogicalTypeId::TIMESTAMP_TZ: - category = "DATETIME"; - break; - case LogicalTypeId::CHAR: - case LogicalTypeId::VARCHAR: - category = "STRING"; - break; - case LogicalTypeId::BOOLEAN: - category = "BOOLEAN"; - break; - case LogicalTypeId::STRUCT: - case LogicalTypeId::LIST: - case LogicalTypeId::MAP: - case LogicalTypeId::UNION: - category = "COMPOSITE"; - break; - default: - break; - } - output.SetValue(col++, count, category.empty() ? 
Value() : Value(category)); - // comment, VARCHAR - output.SetValue(col++, count, Value(type_entry.comment)); - // tags, MAP - output.SetValue(col++, count, Value::MAP(type_entry.tags)); - // internal, BOOLEAN - output.SetValue(col++, count, Value::BOOLEAN(type_entry.internal)); - // labels, VARCHAR[] - if (type.id() == LogicalTypeId::ENUM && type.AuxInfo()) { - auto data = FlatVector::GetData(EnumType::GetValuesInsertOrder(type)); - idx_t size = EnumType::GetSize(type); - - vector labels; - for (idx_t i = 0; i < size; i++) { - labels.emplace_back(data[i]); - } - - output.SetValue(col++, count, Value::LIST(LogicalType::VARCHAR, labels)); - } else { - output.SetValue(col++, count, Value()); - } - - count++; - } - output.SetCardinality(count); -} - -void DuckDBTypesFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_types", {}, DuckDBTypesFunction, DuckDBTypesBind, DuckDBTypesInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/duckdb_variables.cpp b/src/duckdb/src/function/table/system/duckdb_variables.cpp deleted file mode 100644 index 62cfbcb82..000000000 --- a/src/duckdb/src/function/table/system/duckdb_variables.cpp +++ /dev/null @@ -1,84 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/type_catalog_entry.hpp" -#include "duckdb/common/enum_util.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/main/client_context.hpp" -#include "duckdb/main/client_config.hpp" -#include "duckdb/main/client_data.hpp" - -namespace duckdb { - -struct VariableData { - string name; - Value value; -}; - -struct DuckDBVariablesData : public GlobalTableFunctionState { - DuckDBVariablesData() : offset(0) { - } - - vector variables; - idx_t offset; -}; - -static unique_ptr DuckDBVariablesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr DuckDBVariablesInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - auto &config = ClientConfig::GetConfig(context); - - for (auto &entry : config.user_variables) { - VariableData data; - data.name = entry.first; - data.value = entry.second; - result->variables.push_back(std::move(data)); - } - return std::move(result); -} - -void DuckDBVariablesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.variables.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.variables.size() && count < STANDARD_VECTOR_SIZE) { - auto &variable_entry = data.variables[data.offset++]; - - // return values: - idx_t col = 0; - // name, VARCHAR - output.SetValue(col++, count, Value(variable_entry.name)); - // value, BIGINT - output.SetValue(col++, count, Value(variable_entry.value.ToString())); - // type, VARCHAR - output.SetValue(col, count, Value(variable_entry.value.type().ToString())); - count++; - } - output.SetCardinality(count); -} - -void 
DuckDBVariablesFun::RegisterFunction(BuiltinFunctions &set) {
-	set.AddFunction(
-	    TableFunction("duckdb_variables", {}, DuckDBVariablesFunction, DuckDBVariablesBind, DuckDBVariablesInit));
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/table/system/duckdb_views.cpp b/src/duckdb/src/function/table/system/duckdb_views.cpp
deleted file mode 100644
index 1ae6fcce7..000000000
--- a/src/duckdb/src/function/table/system/duckdb_views.cpp
+++ /dev/null
@@ -1,126 +0,0 @@
-#include "duckdb/function/table/system_functions.hpp"
-
-#include "duckdb/catalog/catalog.hpp"
-#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
-#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp"
-#include "duckdb/common/exception.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/client_data.hpp"
-
-namespace duckdb {
-
-struct DuckDBViewsData : public GlobalTableFunctionState {
-	DuckDBViewsData() : offset(0) {
-	}
-
-	vector<reference<CatalogEntry>> entries;
-	idx_t offset;
-};
-
-static unique_ptr<FunctionData> DuckDBViewsBind(ClientContext &context, TableFunctionBindInput &input,
-                                                vector<LogicalType> &return_types, vector<string> &names) {
-	names.emplace_back("database_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("database_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("schema_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("schema_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("view_name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("view_oid");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("comment");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("tags");
-	return_types.emplace_back(LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR));
-
-	names.emplace_back("internal");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("temporary");
-	return_types.emplace_back(LogicalType::BOOLEAN);
-
-	names.emplace_back("column_count");
-	return_types.emplace_back(LogicalType::BIGINT);
-
-	names.emplace_back("sql");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	return nullptr;
-}
-
-unique_ptr<GlobalTableFunctionState> DuckDBViewsInit(ClientContext &context, TableFunctionInitInput &input) {
-	auto result = make_uniq<DuckDBViewsData>();
-
-	// scan all the schemas for views and collect them
-	auto schemas = Catalog::GetAllSchemas(context);
-	for (auto &schema : schemas) {
-		schema.get().Scan(context, CatalogType::VIEW_ENTRY,
-		                  [&](CatalogEntry &entry) { result->entries.push_back(entry); });
-	}
-	return std::move(result);
-}
-
-void DuckDBViewsFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &data = data_p.global_state->Cast<DuckDBViewsData>();
-	if (data.offset >= data.entries.size()) {
-		// finished returning values
-		return;
-	}
-	// start returning values
-	// either fill up the chunk or return all the remaining entries
-	idx_t count = 0;
-	while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) {
-		auto &entry = data.entries[data.offset++].get();
-
-		if (entry.type != CatalogType::VIEW_ENTRY) {
-			continue;
-		}
-		auto &view = entry.Cast<ViewCatalogEntry>();
-
-		// return values:
-		idx_t col = 0;
-		// database_name, VARCHAR
-		output.SetValue(col++, count, view.catalog.GetName());
-		// database_oid, BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(view.catalog.GetOid())));
-		// schema_name, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(view.schema.name));
-		// schema_oid, LogicalType::BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(view.schema.oid)));
-		// view_name, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(view.name));
-		// view_oid, LogicalType::BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(view.oid)));
-		// comment, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(view.comment));
-		// tags, LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR)
-		output.SetValue(col++, count, Value::MAP(view.tags));
-		// internal, LogicalType::BOOLEAN
-		output.SetValue(col++, count, Value::BOOLEAN(view.internal));
-		// temporary, LogicalType::BOOLEAN
-		output.SetValue(col++, count, Value::BOOLEAN(view.temporary));
-		// column_count, LogicalType::BIGINT
-		output.SetValue(col++, count, Value::BIGINT(NumericCast<int64_t>(view.types.size())));
-		// sql, LogicalType::VARCHAR
-		output.SetValue(col++, count, Value(view.ToSQL()));
-
-		count++;
-	}
-	output.SetCardinality(count);
-}
-
-void DuckDBViewsFun::RegisterFunction(BuiltinFunctions &set) {
-	set.AddFunction(TableFunction("duckdb_views", {}, DuckDBViewsFunction, DuckDBViewsBind, DuckDBViewsInit));
-}
-
-} // namespace duckdb
diff --git a/src/duckdb/src/function/table/system/duckdb_which_secret.cpp b/src/duckdb/src/function/table/system/duckdb_which_secret.cpp
deleted file mode 100644
index 3314fee95..000000000
--- a/src/duckdb/src/function/table/system/duckdb_which_secret.cpp
+++ /dev/null
@@ -1,75 +0,0 @@
-#include "duckdb/function/table/system_functions.hpp"
-
-#include "duckdb/common/file_system.hpp"
-#include "duckdb/common/map.hpp"
-#include "duckdb/common/string_util.hpp"
-#include "duckdb/common/multi_file_reader.hpp"
-#include "duckdb/function/function_set.hpp"
-#include "duckdb/main/client_context.hpp"
-#include "duckdb/main/database.hpp"
-#include "duckdb/main/extension_helper.hpp"
-#include "duckdb/main/secret/secret_manager.hpp"
-
-namespace duckdb {
-
-struct DuckDBWhichSecretData : public GlobalTableFunctionState {
-	DuckDBWhichSecretData() : finished(false) {
-	}
-	bool finished;
-};
-
-struct DuckDBWhichSecretBindData : public TableFunctionData {
-	explicit DuckDBWhichSecretBindData(TableFunctionBindInput &tf_input) : inputs(tf_input.inputs) {
-	}
-
-	duckdb::vector<Value> inputs;
-};
-
-static unique_ptr<FunctionData> DuckDBWhichSecretBind(ClientContext &context, TableFunctionBindInput &input,
-                                                      vector<LogicalType> &return_types, vector<string> &names) {
-	names.emplace_back("name");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("persistent");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	names.emplace_back("storage");
-	return_types.emplace_back(LogicalType::VARCHAR);
-
-	return make_uniq<DuckDBWhichSecretBindData>(input);
-}
-
-unique_ptr<GlobalTableFunctionState> DuckDBWhichSecretInit(ClientContext &context, TableFunctionInitInput &input) {
-	return make_uniq<DuckDBWhichSecretData>();
-}
-
-void DuckDBWhichSecretFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) {
-	auto &data = data_p.global_state->Cast<DuckDBWhichSecretData>();
-	if (data.finished) {
-		// finished returning values
-		return;
-	}
-	auto &bind_data = data_p.bind_data->Cast<DuckDBWhichSecretBindData>();
-
-	auto &secret_manager = SecretManager::Get(context);
-	auto transaction = CatalogTransaction::GetSystemCatalogTransaction(context);
-
-	auto &inputs = bind_data.inputs;
-	auto path = inputs[0].ToString();
-	auto type = inputs[1].ToString();
-	auto secret_match = secret_manager.LookupSecret(transaction, path, type);
-	if (secret_match.HasMatch()) {
-		auto &secret_entry = *secret_match.secret_entry;
-		output.SetCardinality(1);
output.SetValue(0, 0, secret_entry.secret->GetName()); - output.SetValue(1, 0, EnumUtil::ToString(secret_entry.persist_type)); - output.SetValue(2, 0, secret_entry.storage_mode); - } - data.finished = true; -} - -void DuckDBWhichSecretFun::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("which_secret", {duckdb::LogicalType::VARCHAR, duckdb::LogicalType::VARCHAR}, - DuckDBWhichSecretFunction, DuckDBWhichSecretBind, DuckDBWhichSecretInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_collations.cpp b/src/duckdb/src/function/table/system/pragma_collations.cpp deleted file mode 100644 index 187ae1745..000000000 --- a/src/duckdb/src/function/table/system/pragma_collations.cpp +++ /dev/null @@ -1,58 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/collate_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp" -#include "duckdb/common/exception.hpp" - -namespace duckdb { - -struct PragmaCollateData : public GlobalTableFunctionState { - PragmaCollateData() : offset(0) { - } - - vector entries; - idx_t offset; -}; - -static unique_ptr PragmaCollateBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("collname"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr PragmaCollateInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - - auto schemas = Catalog::GetAllSchemas(context); - for (auto schema : schemas) { - schema.get().Scan(context, CatalogType::COLLATION_ENTRY, - [&](CatalogEntry &entry) { result->entries.push_back(entry.name); }); - } - return std::move(result); -} - -static void PragmaCollateFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, data.entries.size()); - output.SetCardinality(next - data.offset); - for (idx_t i = data.offset; i < next; i++) { - auto index = i - data.offset; - output.SetValue(0, index, Value(data.entries[i])); - } - - data.offset = next; -} - -void PragmaCollations::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("pragma_collations", {}, PragmaCollateFunction, PragmaCollateBind, PragmaCollateInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_database_size.cpp b/src/duckdb/src/function/table/system/pragma_database_size.cpp deleted file mode 100644 index ba8f2020e..000000000 --- a/src/duckdb/src/function/table/system/pragma_database_size.cpp +++ /dev/null @@ -1,97 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/storage/storage_manager.hpp" -#include "duckdb/storage/block_manager.hpp" -#include "duckdb/storage/storage_info.hpp" -#include "duckdb/common/to_string.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/main/database_manager.hpp" - -namespace duckdb { - -struct PragmaDatabaseSizeData : public GlobalTableFunctionState { - PragmaDatabaseSizeData() : index(0) { - } - - idx_t index; - vector> databases; - Value memory_usage; - Value memory_limit; -}; - -static unique_ptr PragmaDatabaseSizeBind(ClientContext &context, TableFunctionBindInput &input, - 
vector &return_types, vector &names) { - names.emplace_back("database_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("database_size"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("block_size"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("used_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("wal_size"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("memory_usage"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("memory_limit"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr PragmaDatabaseSizeInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - result->databases = DatabaseManager::Get(context).GetDatabases(context); - auto &buffer_manager = BufferManager::GetBufferManager(context); - result->memory_usage = Value(StringUtil::BytesToHumanReadableString(buffer_manager.GetUsedMemory())); - auto max_memory = buffer_manager.GetMaxMemory(); - result->memory_limit = - max_memory == (idx_t)-1 ? Value("Unlimited") : Value(StringUtil::BytesToHumanReadableString(max_memory)); - - return std::move(result); -} - -void PragmaDatabaseSizeFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - idx_t row = 0; - for (; data.index < data.databases.size() && row < STANDARD_VECTOR_SIZE; data.index++) { - auto &db = data.databases[data.index].get(); - if (db.IsSystem() || db.IsTemporary()) { - continue; - } - auto ds = db.GetCatalog().GetDatabaseSize(context); - idx_t col = 0; - output.data[col++].SetValue(row, Value(db.GetName())); - output.data[col++].SetValue(row, Value(StringUtil::BytesToHumanReadableString(ds.bytes))); - output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.block_size))); - output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.total_blocks))); - output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.used_blocks))); - output.data[col++].SetValue(row, Value::BIGINT(NumericCast(ds.free_blocks))); - output.data[col++].SetValue( - row, ds.wal_size == idx_t(-1) ? 
Value() : Value(StringUtil::BytesToHumanReadableString(ds.wal_size))); - output.data[col++].SetValue(row, data.memory_usage); - output.data[col++].SetValue(row, data.memory_limit); - row++; - } - output.SetCardinality(row); -} - -void PragmaDatabaseSize::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_database_size", {}, PragmaDatabaseSizeFunction, PragmaDatabaseSizeBind, - PragmaDatabaseSizeInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_metadata_info.cpp b/src/duckdb/src/function/table/system/pragma_metadata_info.cpp deleted file mode 100644 index 1a10c671e..000000000 --- a/src/duckdb/src/function/table/system/pragma_metadata_info.cpp +++ /dev/null @@ -1,90 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/storage/database_size.hpp" -#include "duckdb/main/database_manager.hpp" -#include "duckdb/function/function_set.hpp" -namespace duckdb { - -struct PragmaMetadataFunctionData : public TableFunctionData { - explicit PragmaMetadataFunctionData() { - } - - vector metadata_info; -}; - -struct PragmaMetadataOperatorData : public GlobalTableFunctionState { - PragmaMetadataOperatorData() : offset(0) { - } - - idx_t offset; -}; - -static unique_ptr PragmaMetadataInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("block_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("total_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_blocks"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("free_list"); - return_types.emplace_back(LogicalType::LIST(LogicalType::BIGINT)); - - string db_name; - if (input.inputs.empty()) { - db_name = DatabaseManager::GetDefaultDatabase(context); - } else { - if (input.inputs[0].IsNull()) { - throw BinderException("Database argument for pragma_metadata_info cannot be NULL"); - } - db_name = StringValue::Get(input.inputs[0]); - } - auto &catalog = Catalog::GetCatalog(context, db_name); - auto result = make_uniq(); - result->metadata_info = catalog.GetMetadataInfo(context); - return std::move(result); -} - -unique_ptr PragmaMetadataInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void PragmaMetadataInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &data = data_p.global_state->Cast(); - idx_t count = 0; - while (data.offset < bind_data.metadata_info.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = bind_data.metadata_info[data.offset++]; - - idx_t col_idx = 0; - // block_id - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); - // total_blocks - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.total_blocks))); - // free_blocks - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.free_list.size()))); - // free_list - vector list_values; - for (auto &free_id : entry.free_list) { - list_values.push_back(Value::BIGINT(NumericCast(free_id))); - } - output.SetValue(col_idx++, count, Value::LIST(LogicalType::BIGINT, std::move(list_values))); - count++; - } - output.SetCardinality(count); -} - -void PragmaMetadataInfo::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet metadata_info("pragma_metadata_info"); - metadata_info.AddFunction( - TableFunction({}, 
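/* pragma_metadata_info is registered twice under one TableFunctionSet so that both
   pragma_metadata_info() and pragma_metadata_info('db_name') resolve to the same
   callbacks; the zero-argument form falls back to the default database inside the
   bind shown above. The same overloading pattern in isolation (my_info and the My*
   callbacks are hypothetical names):

   TableFunctionSet info_set("my_info");
   info_set.AddFunction(TableFunction({}, MyFunction, MyBind, MyInit));
   info_set.AddFunction(TableFunction({LogicalType::VARCHAR}, MyFunction, MyBind, MyInit));
   set.AddFunction(info_set);
*/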
PragmaMetadataInfoFunction, PragmaMetadataInfoBind, PragmaMetadataInfoInit)); - metadata_info.AddFunction(TableFunction({LogicalType::VARCHAR}, PragmaMetadataInfoFunction, PragmaMetadataInfoBind, - PragmaMetadataInfoInit)); - set.AddFunction(metadata_info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_storage_info.cpp b/src/duckdb/src/function/table/system/pragma_storage_info.cpp deleted file mode 100644 index 5500c1c5d..000000000 --- a/src/duckdb/src/function/table/system/pragma_storage_info.cpp +++ /dev/null @@ -1,168 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/parser/qualified_name.hpp" -#include "duckdb/planner/constraints/bound_not_null_constraint.hpp" -#include "duckdb/planner/constraints/bound_unique_constraint.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table_storage_info.hpp" -#include "duckdb/planner/binder.hpp" - -#include - -namespace duckdb { - -struct PragmaStorageFunctionData : public TableFunctionData { - explicit PragmaStorageFunctionData(TableCatalogEntry &table_entry) : table_entry(table_entry) { - } - - TableCatalogEntry &table_entry; - vector column_segments_info; -}; - -struct PragmaStorageOperatorData : public GlobalTableFunctionState { - PragmaStorageOperatorData() : offset(0) { - } - - idx_t offset; -}; - -static unique_ptr PragmaStorageInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("row_group_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("column_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("column_path"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("segment_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("segment_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("start"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("count"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("compression"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("stats"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("has_updates"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("persistent"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("block_id"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("block_offset"); - return_types.emplace_back(LogicalType::BIGINT); - - names.emplace_back("segment_info"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("additional_block_ids"); - return_types.emplace_back(LogicalType::LIST(LogicalTypeId::BIGINT)); - - auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); - - // look up the table name in the catalog - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - auto &table_entry = Catalog::GetEntry(context, qname.catalog, qname.schema, qname.name); - auto result = make_uniq(table_entry); - result->column_segments_info = 
table_entry.GetColumnSegmentInfo(); - return std::move(result); -} - -unique_ptr PragmaStorageInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static Value ValueFromBlockIdList(const vector &block_ids) { - vector blocks; - for (auto &block_id : block_ids) { - blocks.push_back(Value::BIGINT(block_id)); - } - return Value::LIST(LogicalTypeId::BIGINT, blocks); -} - -static void PragmaStorageInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &data = data_p.global_state->Cast(); - idx_t count = 0; - auto &columns = bind_data.table_entry.GetColumns(); - while (data.offset < bind_data.column_segments_info.size() && count < STANDARD_VECTOR_SIZE) { - auto &entry = bind_data.column_segments_info[data.offset++]; - - idx_t col_idx = 0; - // row_group_id - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.row_group_index))); - // column_name - auto &col = columns.GetColumn(PhysicalIndex(entry.column_id)); - output.SetValue(col_idx++, count, Value(col.Name())); - // column_id - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.column_id))); - // column_path - output.SetValue(col_idx++, count, Value(entry.column_path)); - // segment_id - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_idx))); - // segment_type - output.SetValue(col_idx++, count, Value(entry.segment_type)); - // start - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_start))); - // count - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.segment_count))); - // compression - output.SetValue(col_idx++, count, Value(entry.compression_type)); - // stats - output.SetValue(col_idx++, count, Value(entry.segment_stats)); - // has_updates - output.SetValue(col_idx++, count, Value::BOOLEAN(entry.has_updates)); - // persistent - output.SetValue(col_idx++, count, Value::BOOLEAN(entry.persistent)); - // block_id - // block_offset - if (entry.persistent) { - output.SetValue(col_idx++, count, Value::BIGINT(entry.block_id)); - output.SetValue(col_idx++, count, Value::BIGINT(NumericCast(entry.block_offset))); - } else { - output.SetValue(col_idx++, count, Value()); - output.SetValue(col_idx++, count, Value()); - } - // segment_info - output.SetValue(col_idx++, count, Value(entry.segment_info)); - // additional_block_ids - if (entry.persistent) { - output.SetValue(col_idx++, count, ValueFromBlockIdList(entry.additional_blocks)); - } else { - output.SetValue(col_idx++, count, Value()); - } - count++; - } - output.SetCardinality(count); -} - -void PragmaStorageInfo::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_storage_info", {LogicalType::VARCHAR}, PragmaStorageInfoFunction, - PragmaStorageInfoBind, PragmaStorageInfoInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_table_info.cpp b/src/duckdb/src/function/table/system/pragma_table_info.cpp deleted file mode 100644 index 56205d900..000000000 --- a/src/duckdb/src/function/table/system/pragma_table_info.cpp +++ /dev/null @@ -1,297 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/parser/qualified_name.hpp" -#include "duckdb/parser/constraints/not_null_constraint.hpp" -#include 
"duckdb/parser/constraints/unique_constraint.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/planner/binder.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" - -#include - -namespace duckdb { - -struct PragmaTableFunctionData : public TableFunctionData { - explicit PragmaTableFunctionData(CatalogEntry &entry_p, bool is_table_info) - : entry(entry_p), is_table_info(is_table_info) { - } - - CatalogEntry &entry; - bool is_table_info; -}; - -struct PragmaTableOperatorData : public GlobalTableFunctionState { - PragmaTableOperatorData() : offset(0) { - } - idx_t offset; -}; - -struct ColumnConstraintInfo { - bool not_null = false; - bool pk = false; - bool unique = false; -}; - -static Value DefaultValue(const ColumnDefinition &def) { - if (def.Generated()) { - return Value(def.GeneratedExpression().ToString()); - } - if (!def.HasDefaultValue()) { - return Value(); - } - auto &value = def.DefaultValue(); - return Value(value.ToString()); -} - -struct PragmaTableInfoHelper { - static void GetSchema(vector &return_types, vector &names) { - names.emplace_back("cid"); - return_types.emplace_back(LogicalType::INTEGER); - - names.emplace_back("name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("notnull"); - return_types.emplace_back(LogicalType::BOOLEAN); - - names.emplace_back("dflt_value"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("pk"); - return_types.emplace_back(LogicalType::BOOLEAN); - } - - static void GetTableColumns(const ColumnDefinition &column, ColumnConstraintInfo constraint_info, DataChunk &output, - idx_t index) { - // return values: - // "cid", PhysicalType::INT32 - output.SetValue(0, index, Value::INTEGER((int32_t)column.Oid())); - // "name", PhysicalType::VARCHAR - output.SetValue(1, index, Value(column.Name())); - // "type", PhysicalType::VARCHAR - output.SetValue(2, index, Value(column.Type().ToString())); - // "notnull", PhysicalType::BOOL - output.SetValue(3, index, Value::BOOLEAN(constraint_info.not_null)); - // "dflt_value", PhysicalType::VARCHAR - output.SetValue(4, index, DefaultValue(column)); - // "pk", PhysicalType::BOOL - output.SetValue(5, index, Value::BOOLEAN(constraint_info.pk)); - } - - static void GetViewColumns(idx_t i, const string &name, const LogicalType &type, DataChunk &output, idx_t index) { - // return values: - // "cid", PhysicalType::INT32 - output.SetValue(0, index, Value::INTEGER((int32_t)i)); - // "name", PhysicalType::VARCHAR - output.SetValue(1, index, Value(name)); - // "type", PhysicalType::VARCHAR - output.SetValue(2, index, Value(type.ToString())); - // "notnull", PhysicalType::BOOL - output.SetValue(3, index, Value::BOOLEAN(false)); - // "dflt_value", PhysicalType::VARCHAR - output.SetValue(4, index, Value()); - // "pk", PhysicalType::BOOL - output.SetValue(5, index, Value::BOOLEAN(false)); - } -}; - -struct PragmaShowHelper { - static void GetSchema(vector &return_types, vector &names) { - names.emplace_back("column_name"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("column_type"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("null"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("key"); - return_types.emplace_back(LogicalType::VARCHAR); - - names.emplace_back("default"); - return_types.emplace_back(LogicalType::VARCHAR); - - 
names.emplace_back("extra"); - return_types.emplace_back(LogicalType::VARCHAR); - } - - static void GetTableColumns(const ColumnDefinition &column, ColumnConstraintInfo constraint_info, DataChunk &output, - idx_t index) { - // "column_name", PhysicalType::VARCHAR - output.SetValue(0, index, Value(column.Name())); - // "column_type", PhysicalType::VARCHAR - output.SetValue(1, index, Value(column.Type().ToString())); - // "null", PhysicalType::VARCHAR - output.SetValue(2, index, Value(constraint_info.not_null ? "NO" : "YES")); - // "key", PhysicalType::VARCHAR - Value key; - if (constraint_info.pk || constraint_info.unique) { - key = Value(constraint_info.pk ? "PRI" : "UNI"); - } - output.SetValue(3, index, key); - // "default", VARCHAR - output.SetValue(4, index, DefaultValue(column)); - // "extra", VARCHAR - output.SetValue(5, index, Value()); - } - - static void GetViewColumns(idx_t i, const string &name, const LogicalType &type, DataChunk &output, idx_t index) { - // "column_name", PhysicalType::VARCHAR - output.SetValue(0, index, Value(name)); - // "column_type", PhysicalType::VARCHAR - output.SetValue(1, index, Value(type.ToString())); - // "null", PhysicalType::VARCHAR - output.SetValue(2, index, Value("YES")); - // "key", PhysicalType::VARCHAR - output.SetValue(3, index, Value()); - // "default", VARCHAR - output.SetValue(4, index, Value()); - // "extra", VARCHAR - output.SetValue(5, index, Value()); - } -}; - -template -static unique_ptr PragmaTableInfoBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - if (IS_PRAGMA_TABLE_INFO) { - PragmaTableInfoHelper::GetSchema(return_types, names); - } else { - PragmaShowHelper::GetSchema(return_types, names); - } - - auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); - - // look up the table name in the catalog - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - auto &entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog, qname.schema, qname.name); - return make_uniq(entry, IS_PRAGMA_TABLE_INFO); -} - -unique_ptr PragmaTableInfoInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static ColumnConstraintInfo CheckConstraints(TableCatalogEntry &table, const ColumnDefinition &column) { - ColumnConstraintInfo result; - // check all constraints - for (auto &constraint : table.GetConstraints()) { - switch (constraint->type) { - case ConstraintType::NOT_NULL: { - auto ¬_null = constraint->Cast(); - if (not_null.index == column.Logical()) { - result.not_null = true; - } - break; - } - case ConstraintType::UNIQUE: { - auto &unique = constraint->Cast(); - bool &constraint_info = unique.IsPrimaryKey() ? 
result.pk : result.unique; - if (unique.HasIndex()) { - if (unique.GetIndex() == column.Logical()) { - constraint_info = true; - } - } else { - auto &columns = unique.GetColumnNames(); - if (std::find(columns.begin(), columns.end(), column.GetName()) != columns.end()) { - constraint_info = true; - } - } - break; - } - default: - break; - } - } - return result; -} - -void PragmaTableInfo::GetColumnInfo(TableCatalogEntry &table, const ColumnDefinition &column, DataChunk &output, - idx_t index) { - auto constraint_info = CheckConstraints(table, column); - PragmaShowHelper::GetTableColumns(column, constraint_info, output, index); -} - -static void PragmaTableInfoTable(PragmaTableOperatorData &data, TableCatalogEntry &table, DataChunk &output, - bool is_table_info) { - if (data.offset >= table.GetColumns().LogicalColumnCount()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, table.GetColumns().LogicalColumnCount()); - output.SetCardinality(next - data.offset); - - for (idx_t i = data.offset; i < next; i++) { - auto index = i - data.offset; - auto &column = table.GetColumn(LogicalIndex(i)); - D_ASSERT(column.Oid() < (idx_t)NumericLimits::Maximum()); - auto constraint_info = CheckConstraints(table, column); - - if (is_table_info) { - PragmaTableInfoHelper::GetTableColumns(column, constraint_info, output, index); - } else { - PragmaShowHelper::GetTableColumns(column, constraint_info, output, index); - } - } - data.offset = next; -} - -static void PragmaTableInfoView(PragmaTableOperatorData &data, ViewCatalogEntry &view, DataChunk &output, - bool is_table_info) { - if (data.offset >= view.types.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t next = MinValue(data.offset + STANDARD_VECTOR_SIZE, view.types.size()); - output.SetCardinality(next - data.offset); - - for (idx_t i = data.offset; i < next; i++) { - auto index = i - data.offset; - auto type = view.types[i]; - auto &name = i < view.aliases.size() ? 
view.aliases[i] : view.names[i]; - - if (is_table_info) { - PragmaTableInfoHelper::GetViewColumns(i, name, type, output, index); - } else { - PragmaShowHelper::GetViewColumns(i, name, type, output, index); - } - } - data.offset = next; -} - -static void PragmaTableInfoFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - switch (bind_data.entry.type) { - case CatalogType::TABLE_ENTRY: - PragmaTableInfoTable(state, bind_data.entry.Cast(), output, bind_data.is_table_info); - break; - case CatalogType::VIEW_ENTRY: - PragmaTableInfoView(state, bind_data.entry.Cast(), output, bind_data.is_table_info); - break; - default: - throw NotImplementedException("Unimplemented catalog type for pragma_table_info"); - } -} - -void PragmaTableInfo::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("pragma_table_info", {LogicalType::VARCHAR}, PragmaTableInfoFunction, - PragmaTableInfoBind, PragmaTableInfoInit)); - set.AddFunction(TableFunction("pragma_show", {LogicalType::VARCHAR}, PragmaTableInfoFunction, - PragmaTableInfoBind, PragmaTableInfoInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_table_sample.cpp b/src/duckdb/src/function/table/system/pragma_table_sample.cpp deleted file mode 100644 index 7f4122b92..000000000 --- a/src/duckdb/src/function/table/system/pragma_table_sample.cpp +++ /dev/null @@ -1,95 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" - -#include "duckdb/catalog/catalog.hpp" -#include "duckdb/catalog/catalog_entry/table_catalog_entry.hpp" -#include "duckdb/catalog/catalog_entry/view_catalog_entry.hpp" -#include "duckdb/parser/qualified_name.hpp" -#include "duckdb/parser/constraints/not_null_constraint.hpp" -#include "duckdb/parser/constraints/unique_constraint.hpp" -#include "duckdb/planner/expression/bound_parameter_expression.hpp" -#include "duckdb/planner/binder.hpp" - -#include "duckdb/common/exception.hpp" -#include "duckdb/common/limits.hpp" - -#include - -namespace duckdb { - -struct DuckDBTableSampleFunctionData : public TableFunctionData { - explicit DuckDBTableSampleFunctionData(CatalogEntry &entry_p) : entry(entry_p) { - } - CatalogEntry &entry; -}; - -struct DuckDBTableSampleOperatorData : public GlobalTableFunctionState { - DuckDBTableSampleOperatorData() : sample_offset(0) { - sample = nullptr; - } - idx_t sample_offset; - unique_ptr sample; -}; - -static unique_ptr DuckDBTableSampleBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - // look up the table name in the catalog - auto qname = QualifiedName::Parse(input.inputs[0].GetValue()); - Binder::BindSchemaOrCatalog(context, qname.catalog, qname.schema); - - auto &entry = Catalog::GetEntry(context, CatalogType::TABLE_ENTRY, qname.catalog, qname.schema, qname.name); - if (entry.type != CatalogType::TABLE_ENTRY) { - throw NotImplementedException("Invalid Catalog type passed to table_sample()"); - } - auto &table_entry = entry.Cast(); - auto types = table_entry.GetTypes(); - for (auto &type : types) { - return_types.push_back(type); - } - for (idx_t i = 0; i < types.size(); i++) { - auto logical_index = LogicalIndex(i); - auto &col = table_entry.GetColumn(logical_index); - names.push_back(col.GetName()); - } - - return make_uniq(entry); -} - -unique_ptr DuckDBTableSampleInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq(); -} - -static void 
DuckDBTableSampleTable(ClientContext &context, DuckDBTableSampleOperatorData &data, - TableCatalogEntry &table, DataChunk &output) { - // if table has statistics. - // copy the sample of statistics into the output chunk - if (!data.sample) { - data.sample = table.GetSample(); - } - if (data.sample) { - auto sample_chunk = data.sample->GetChunk(); - if (sample_chunk) { - sample_chunk->Copy(output, 0); - data.sample_offset += sample_chunk->size(); - } - } -} - -static void DuckDBTableSampleFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &bind_data = data_p.bind_data->Cast(); - auto &state = data_p.global_state->Cast(); - switch (bind_data.entry.type) { - case CatalogType::TABLE_ENTRY: - DuckDBTableSampleTable(context, state, bind_data.entry.Cast(), output); - break; - default: - throw NotImplementedException("Unimplemented catalog type for pragma_table_sample"); - } -} - -void DuckDBTableSample::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction(TableFunction("duckdb_table_sample", {LogicalType::VARCHAR}, DuckDBTableSampleFunction, - DuckDBTableSampleBind, DuckDBTableSampleInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/pragma_user_agent.cpp b/src/duckdb/src/function/table/system/pragma_user_agent.cpp deleted file mode 100644 index 3803f7195..000000000 --- a/src/duckdb/src/function/table/system/pragma_user_agent.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/main/config.hpp" - -namespace duckdb { - -struct PragmaUserAgentData : public GlobalTableFunctionState { - PragmaUserAgentData() : finished(false) { - } - - std::string user_agent; - bool finished; -}; - -static unique_ptr PragmaUserAgentBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - - names.emplace_back("user_agent"); - return_types.emplace_back(LogicalType::VARCHAR); - - return nullptr; -} - -unique_ptr PragmaUserAgentInit(ClientContext &context, TableFunctionInitInput &input) { - auto result = make_uniq(); - auto &config = DBConfig::GetConfig(context); - result->user_agent = config.UserAgent(); - - return std::move(result); -} - -void PragmaUserAgentFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - - if (data.finished) { - // signal end of output - return; - } - - output.SetCardinality(1); - output.SetValue(0, 0, data.user_agent); - - data.finished = true; -} - -void PragmaUserAgent::RegisterFunction(BuiltinFunctions &set) { - set.AddFunction( - TableFunction("pragma_user_agent", {}, PragmaUserAgentFunction, PragmaUserAgentBind, PragmaUserAgentInit)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/test_all_types.cpp b/src/duckdb/src/function/table/system/test_all_types.cpp deleted file mode 100644 index 0fbcd884b..000000000 --- a/src/duckdb/src/function/table/system/test_all_types.cpp +++ /dev/null @@ -1,352 +0,0 @@ -#include "duckdb/common/operator/cast_operators.hpp" -#include "duckdb/common/pair.hpp" -#include "duckdb/common/types/bit.hpp" -#include "duckdb/common/types/date.hpp" -#include "duckdb/common/types/timestamp.hpp" -#include "duckdb/function/table/system_functions.hpp" - -#include -#include - -namespace duckdb { - -struct TestAllTypesData : public GlobalTableFunctionState { - TestAllTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -vector TestAllTypesFun::GetTestTypes(bool use_large_enum) { 
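/* test_all_types builds one TestType per supported logical type; each entry carries
   a column name plus a min/max value pair, and the init callback further down turns
   every entry into exactly three rows: min, max, and NULL. For the scalar entries
   below that pass no explicit values, the min/max pair presumably comes from the
   TestType constructor picking the type's extreme values; the explicit pairs
   (interval, varchar, blob, ...) override that. Consuming the list looks like:

   for (auto &test_type : TestAllTypesFun::GetTestTypes(false)) {
       // test_type.name      -> column name, e.g. "int"
       // test_type.type      -> LogicalType, e.g. LogicalType::INTEGER
       // test_type.min_value -> Value used for the first row
       // test_type.max_value -> Value used for the second row
   }
*/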
- vector result; - // scalar types/numerics - result.emplace_back(LogicalType::BOOLEAN, "bool"); - result.emplace_back(LogicalType::TINYINT, "tinyint"); - result.emplace_back(LogicalType::SMALLINT, "smallint"); - result.emplace_back(LogicalType::INTEGER, "int"); - result.emplace_back(LogicalType::BIGINT, "bigint"); - result.emplace_back(LogicalType::HUGEINT, "hugeint"); - result.emplace_back(LogicalType::UHUGEINT, "uhugeint"); - result.emplace_back(LogicalType::UTINYINT, "utinyint"); - result.emplace_back(LogicalType::USMALLINT, "usmallint"); - result.emplace_back(LogicalType::UINTEGER, "uint"); - result.emplace_back(LogicalType::UBIGINT, "ubigint"); - result.emplace_back(LogicalType::VARINT, "varint"); - result.emplace_back(LogicalType::DATE, "date"); - result.emplace_back(LogicalType::TIME, "time"); - result.emplace_back(LogicalType::TIMESTAMP, "timestamp"); - result.emplace_back(LogicalType::TIMESTAMP_S, "timestamp_s"); - result.emplace_back(LogicalType::TIMESTAMP_MS, "timestamp_ms"); - result.emplace_back(LogicalType::TIMESTAMP_NS, "timestamp_ns"); - result.emplace_back(LogicalType::TIME_TZ, "time_tz"); - result.emplace_back(LogicalType::TIMESTAMP_TZ, "timestamp_tz"); - result.emplace_back(LogicalType::FLOAT, "float"); - result.emplace_back(LogicalType::DOUBLE, "double"); - result.emplace_back(LogicalType::DECIMAL(4, 1), "dec_4_1"); - result.emplace_back(LogicalType::DECIMAL(9, 4), "dec_9_4"); - result.emplace_back(LogicalType::DECIMAL(18, 6), "dec_18_6"); - result.emplace_back(LogicalType::DECIMAL(38, 10), "dec38_10"); - result.emplace_back(LogicalType::UUID, "uuid"); - - // interval - interval_t min_interval; - min_interval.months = 0; - min_interval.days = 0; - min_interval.micros = 0; - - interval_t max_interval; - max_interval.months = 999; - max_interval.days = 999; - max_interval.micros = 999999999; - result.emplace_back(LogicalType::INTERVAL, "interval", Value::INTERVAL(min_interval), - Value::INTERVAL(max_interval)); - // strings/blobs/bitstrings - result.emplace_back(LogicalType::VARCHAR, "varchar", Value("🦆🦆🦆🦆🦆🦆"), - Value(string("goo\x00se", 6))); - result.emplace_back(LogicalType::BLOB, "blob", Value::BLOB("thisisalongblob\\x00withnullbytes"), - Value::BLOB("\\x00\\x00\\x00a")); - result.emplace_back(LogicalType::BIT, "bit", Value::BIT("0010001001011100010101011010111"), Value::BIT("10101")); - - // enums - Vector small_enum(LogicalType::VARCHAR, 2); - auto small_enum_ptr = FlatVector::GetData(small_enum); - small_enum_ptr[0] = StringVector::AddStringOrBlob(small_enum, "DUCK_DUCK_ENUM"); - small_enum_ptr[1] = StringVector::AddStringOrBlob(small_enum, "GOOSE"); - result.emplace_back(LogicalType::ENUM(small_enum, 2), "small_enum"); - - Vector medium_enum(LogicalType::VARCHAR, 300); - auto medium_enum_ptr = FlatVector::GetData(medium_enum); - for (idx_t i = 0; i < 300; i++) { - medium_enum_ptr[i] = StringVector::AddStringOrBlob(medium_enum, string("enum_") + to_string(i)); - } - result.emplace_back(LogicalType::ENUM(medium_enum, 300), "medium_enum"); - - if (use_large_enum) { - // this is a big one... 
not sure if we should push this one here, but it's required for completeness - Vector large_enum(LogicalType::VARCHAR, 70000); - auto large_enum_ptr = FlatVector::GetData(large_enum); - for (idx_t i = 0; i < 70000; i++) { - large_enum_ptr[i] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(i)); - } - result.emplace_back(LogicalType::ENUM(large_enum, 70000), "large_enum"); - } else { - Vector large_enum(LogicalType::VARCHAR, 2); - auto large_enum_ptr = FlatVector::GetData(large_enum); - large_enum_ptr[0] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(0)); - large_enum_ptr[1] = StringVector::AddStringOrBlob(large_enum, string("enum_") + to_string(69999)); - result.emplace_back(LogicalType::ENUM(large_enum, 2), "large_enum"); - } - - // arrays - auto int_list_type = LogicalType::LIST(LogicalType::INTEGER); - auto empty_int_list = Value::LIST(LogicalType::INTEGER, vector()); - auto int_list = - Value::LIST(LogicalType::INTEGER, {Value::INTEGER(42), Value::INTEGER(999), Value(LogicalType::INTEGER), - Value(LogicalType::INTEGER), Value::INTEGER(-42)}); - result.emplace_back(int_list_type, "int_array", empty_int_list, int_list); - - auto double_list_type = LogicalType::LIST(LogicalType::DOUBLE); - auto empty_double_list = Value::LIST(LogicalType::DOUBLE, vector()); - auto double_list = Value::LIST(LogicalType::DOUBLE, {Value::DOUBLE(42), Value::DOUBLE(NAN), - Value::DOUBLE(std::numeric_limits::infinity()), - Value::DOUBLE(-std::numeric_limits::infinity()), - Value(LogicalType::DOUBLE), Value::DOUBLE(-42)}); - result.emplace_back(double_list_type, "double_array", empty_double_list, double_list); - - auto date_list_type = LogicalType::LIST(LogicalType::DATE); - auto empty_date_list = Value::LIST(LogicalType::DATE, vector()); - auto date_list = Value::LIST(LogicalType::DATE, {Value::DATE(date_t()), Value::DATE(date_t::infinity()), - Value::DATE(date_t::ninfinity()), Value(LogicalType::DATE), - Value::DATE(Date::FromString("2022-05-12"))}); - result.emplace_back(date_list_type, "date_array", empty_date_list, date_list); - - auto timestamp_list_type = LogicalType::LIST(LogicalType::TIMESTAMP); - auto empty_timestamp_list = Value::LIST(LogicalType::TIMESTAMP, vector()); - auto timestamp_list = - Value::LIST(LogicalType::TIMESTAMP, {Value::TIMESTAMP(timestamp_t()), Value::TIMESTAMP(timestamp_t::infinity()), - Value::TIMESTAMP(timestamp_t::ninfinity()), Value(LogicalType::TIMESTAMP), - Value::TIMESTAMP(Timestamp::FromString("2022-05-12 16:23:45"))}); - result.emplace_back(timestamp_list_type, "timestamp_array", empty_timestamp_list, timestamp_list); - - auto timestamptz_list_type = LogicalType::LIST(LogicalType::TIMESTAMP_TZ); - auto empty_timestamptz_list = Value::LIST(LogicalType::TIMESTAMP_TZ, vector()); - auto timestamptz_list = - Value::LIST(LogicalType::TIMESTAMP_TZ, - {Value::TIMESTAMPTZ(timestamp_tz_t()), Value::TIMESTAMPTZ(timestamp_tz_t(timestamp_t::infinity())), - Value::TIMESTAMPTZ(timestamp_tz_t(timestamp_t::ninfinity())), Value(LogicalType::TIMESTAMP_TZ), - Value::TIMESTAMPTZ(timestamp_tz_t(Timestamp::FromString("2022-05-12 16:23:45-07")))}); - result.emplace_back(timestamptz_list_type, "timestamptz_array", empty_timestamptz_list, timestamptz_list); - - auto varchar_list_type = LogicalType::LIST(LogicalType::VARCHAR); - auto empty_varchar_list = Value::LIST(LogicalType::VARCHAR, vector()); - auto varchar_list = Value::LIST(LogicalType::VARCHAR, {Value("🦆🦆🦆🦆🦆🦆"), Value("goose"), - Value(LogicalType::VARCHAR), Value("")}); - 
result.emplace_back(varchar_list_type, "varchar_array", empty_varchar_list, varchar_list); - - // nested arrays - auto nested_list_type = LogicalType::LIST(int_list_type); - auto empty_nested_list = Value::LIST(int_list_type, vector()); - auto nested_int_list = - Value::LIST(int_list_type, {empty_int_list, int_list, Value(int_list_type), empty_int_list, int_list}); - result.emplace_back(nested_list_type, "nested_int_array", empty_nested_list, nested_int_list); - - // structs - child_list_t struct_type_list; - struct_type_list.push_back(make_pair("a", LogicalType::INTEGER)); - struct_type_list.push_back(make_pair("b", LogicalType::VARCHAR)); - auto struct_type = LogicalType::STRUCT(struct_type_list); - - child_list_t min_struct_list; - min_struct_list.push_back(make_pair("a", Value(LogicalType::INTEGER))); - min_struct_list.push_back(make_pair("b", Value(LogicalType::VARCHAR))); - auto min_struct_val = Value::STRUCT(std::move(min_struct_list)); - - child_list_t max_struct_list; - max_struct_list.push_back(make_pair("a", Value::INTEGER(42))); - max_struct_list.push_back(make_pair("b", Value("🦆🦆🦆🦆🦆🦆"))); - auto max_struct_val = Value::STRUCT(std::move(max_struct_list)); - - result.emplace_back(struct_type, "struct", min_struct_val, max_struct_val); - - // structs with lists - child_list_t struct_list_type_list; - struct_list_type_list.push_back(make_pair("a", int_list_type)); - struct_list_type_list.push_back(make_pair("b", varchar_list_type)); - auto struct_list_type = LogicalType::STRUCT(struct_list_type_list); - - child_list_t min_struct_vl_list; - min_struct_vl_list.push_back(make_pair("a", Value(int_list_type))); - min_struct_vl_list.push_back(make_pair("b", Value(varchar_list_type))); - auto min_struct_val_list = Value::STRUCT(std::move(min_struct_vl_list)); - - child_list_t max_struct_vl_list; - max_struct_vl_list.push_back(make_pair("a", int_list)); - max_struct_vl_list.push_back(make_pair("b", varchar_list)); - auto max_struct_val_list = Value::STRUCT(std::move(max_struct_vl_list)); - - result.emplace_back(struct_list_type, "struct_of_arrays", std::move(min_struct_val_list), - std::move(max_struct_val_list)); - - // array of structs - auto array_of_structs_type = LogicalType::LIST(struct_type); - auto min_array_of_struct_val = Value::LIST(struct_type, vector()); - auto max_array_of_struct_val = Value::LIST(struct_type, {min_struct_val, max_struct_val, Value(struct_type)}); - result.emplace_back(array_of_structs_type, "array_of_structs", std::move(min_array_of_struct_val), - std::move(max_array_of_struct_val)); - - // map - auto map_type = LogicalType::MAP(LogicalType::VARCHAR, LogicalType::VARCHAR); - auto min_map_value = Value::MAP(ListType::GetChildType(map_type), vector()); - - child_list_t map_struct1; - map_struct1.push_back(make_pair("key", Value("key1"))); - map_struct1.push_back(make_pair("value", Value("🦆🦆🦆🦆🦆🦆"))); - child_list_t map_struct2; - map_struct2.push_back(make_pair("key", Value("key2"))); - map_struct2.push_back(make_pair("value", Value("goose"))); - - vector map_values; - map_values.push_back(Value::STRUCT(map_struct1)); - map_values.push_back(Value::STRUCT(map_struct2)); - - auto max_map_value = Value::MAP(ListType::GetChildType(map_type), map_values); - result.emplace_back(map_type, "map", std::move(min_map_value), std::move(max_map_value)); - - // union - child_list_t members = {{"name", LogicalType::VARCHAR}, {"age", LogicalType::SMALLINT}}; - auto union_type = LogicalType::UNION(members); - const Value &min = Value::UNION(members, 0, Value("Frank")); - const 
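/* Value::UNION takes the member list, a tag, and the payload; the tag is the index
   into the member list selecting which alternative the value holds. Tag 0 here
   picks {"name", VARCHAR} and tag 1 (next line) picks {"age", SMALLINT}, so the two
   test rows exercise both alternatives. Shape of the call:

   child_list_t<LogicalType> members = {{"name", LogicalType::VARCHAR},
                                        {"age", LogicalType::SMALLINT}};
   Value v = Value::UNION(members, 1, Value::SMALLINT(5)); // holds the 'age' member
*/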
Value &max = Value::UNION(members, 1, Value::SMALLINT(5)); - result.emplace_back(union_type, "union", min, max); - - // fixed int array - auto fixed_int_array_type = LogicalType::ARRAY(LogicalType::INTEGER, 3); - auto fixed_int_min_array_value = Value::ARRAY(LogicalType::INTEGER, {Value(LogicalType::INTEGER), 2, 3}); - auto fixed_int_max_array_value = Value::ARRAY(LogicalType::INTEGER, {4, 5, 6}); - result.emplace_back(fixed_int_array_type, "fixed_int_array", fixed_int_min_array_value, fixed_int_max_array_value); - - // fixed varchar array - auto fixed_varchar_array_type = LogicalType::ARRAY(LogicalType::VARCHAR, 3); - auto fixed_varchar_min_array_value = - Value::ARRAY(LogicalType::VARCHAR, {Value("a"), Value(LogicalType::VARCHAR), Value("c")}); - auto fixed_varchar_max_array_value = Value::ARRAY(LogicalType::VARCHAR, {Value("d"), Value("e"), Value("f")}); - result.emplace_back(fixed_varchar_array_type, "fixed_varchar_array", fixed_varchar_min_array_value, - fixed_varchar_max_array_value); - - // fixed nested int array - auto fixed_nested_int_array_type = LogicalType::ARRAY(fixed_int_array_type, 3); - auto fixed_nested_int_min_array_value = Value::ARRAY( - fixed_int_array_type, {fixed_int_min_array_value, Value(fixed_int_array_type), fixed_int_min_array_value}); - auto fixed_nested_int_max_array_value = Value::ARRAY( - fixed_int_array_type, {fixed_int_max_array_value, fixed_int_min_array_value, fixed_int_max_array_value}); - result.emplace_back(fixed_nested_int_array_type, "fixed_nested_int_array", fixed_nested_int_min_array_value, - fixed_nested_int_max_array_value); - - // fixed nested varchar array - auto fixed_nested_varchar_array_type = LogicalType::ARRAY(fixed_varchar_array_type, 3); - auto fixed_nested_varchar_min_array_value = - Value::ARRAY(fixed_varchar_array_type, - {fixed_varchar_min_array_value, Value(fixed_varchar_array_type), fixed_varchar_min_array_value}); - auto fixed_nested_varchar_max_array_value = - Value::ARRAY(fixed_varchar_array_type, - {fixed_varchar_max_array_value, fixed_varchar_min_array_value, fixed_varchar_max_array_value}); - result.emplace_back(fixed_nested_varchar_array_type, "fixed_nested_varchar_array", - fixed_nested_varchar_min_array_value, fixed_nested_varchar_max_array_value); - - // fixed array of structs - auto fixed_struct_array_type = LogicalType::ARRAY(struct_type, 3); - auto fixed_struct_min_array_value = Value::ARRAY(struct_type, {min_struct_val, max_struct_val, min_struct_val}); - auto fixed_struct_max_array_value = Value::ARRAY(struct_type, {max_struct_val, min_struct_val, max_struct_val}); - result.emplace_back(fixed_struct_array_type, "fixed_struct_array", fixed_struct_min_array_value, - fixed_struct_max_array_value); - - // struct of fixed array - auto struct_of_fixed_array_type = - LogicalType::STRUCT({{"a", fixed_int_array_type}, {"b", fixed_varchar_array_type}}); - auto struct_of_fixed_array_min_value = - Value::STRUCT({{"a", fixed_int_min_array_value}, {"b", fixed_varchar_min_array_value}}); - auto struct_of_fixed_array_max_value = - Value::STRUCT({{"a", fixed_int_max_array_value}, {"b", fixed_varchar_max_array_value}}); - result.emplace_back(struct_of_fixed_array_type, "struct_of_fixed_array", struct_of_fixed_array_min_value, - struct_of_fixed_array_max_value); - - // fixed array of list of int - auto fixed_array_of_list_of_int_type = LogicalType::ARRAY(int_list_type, 3); - auto fixed_array_of_list_of_int_min_value = Value::ARRAY(int_list_type, {empty_int_list, int_list, empty_int_list}); - auto fixed_array_of_list_of_int_max_value = 
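/* The fixed-size cases being assembled here exercise the distinction between the
   two nested collection types: LogicalType::ARRAY(T, n) is fixed-length, so every
   non-NULL value must supply exactly n elements, while LogicalType::LIST(T) is
   variable-length, empty lists included. Nesting them both ways, array-of-list and
   list-of-array, checks that the two layouts compose. In isolation:

   auto arr3 = LogicalType::ARRAY(LogicalType::INTEGER, 3); // INTEGER[3]
   auto lst = LogicalType::LIST(LogicalType::INTEGER);      // INTEGER[]
   Value a = Value::ARRAY(LogicalType::INTEGER, {4, 5, 6}); // exactly three values
   Value l = Value::LIST(LogicalType::INTEGER, {});         // any length is fine
*/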
Value::ARRAY(int_list_type, {int_list, empty_int_list, int_list}); - result.emplace_back(fixed_array_of_list_of_int_type, "fixed_array_of_int_list", - fixed_array_of_list_of_int_min_value, fixed_array_of_list_of_int_max_value); - - // list of fixed array of int - auto list_of_fixed_array_of_int_type = LogicalType::LIST(fixed_int_array_type); - auto list_of_fixed_array_of_int_min_value = Value::LIST( - fixed_int_array_type, {fixed_int_min_array_value, fixed_int_max_array_value, fixed_int_min_array_value}); - auto list_of_fixed_array_of_int_max_value = Value::LIST( - fixed_int_array_type, {fixed_int_max_array_value, fixed_int_min_array_value, fixed_int_max_array_value}); - result.emplace_back(list_of_fixed_array_of_int_type, "list_of_fixed_int_array", - list_of_fixed_array_of_int_min_value, list_of_fixed_array_of_int_max_value); - - return result; -} - -struct TestAllTypesBindData : public TableFunctionData { - vector test_types; -}; - -static unique_ptr TestAllTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - bool use_large_enum = false; - auto entry = input.named_parameters.find("use_large_enum"); - if (entry != input.named_parameters.end()) { - use_large_enum = BooleanValue::Get(entry->second); - } - result->test_types = TestAllTypesFun::GetTestTypes(use_large_enum); - for (auto &test_type : result->test_types) { - return_types.push_back(test_type.type); - names.push_back(test_type.name); - } - return std::move(result); -} - -unique_ptr TestAllTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - // 3 rows: min, max and NULL - result->entries.resize(3); - // initialize the values - for (auto &test_type : bind_data.test_types) { - result->entries[0].push_back(test_type.min_value); - result->entries[1].push_back(test_type.max_value); - result->entries[2].emplace_back(test_type.type); - } - return std::move(result); -} - -void TestAllTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - // start returning values - // either fill up the chunk or return all the remaining columns - idx_t count = 0; - while (data.offset < data.entries.size() && count < STANDARD_VECTOR_SIZE) { - auto &vals = data.entries[data.offset++]; - for (idx_t col_idx = 0; col_idx < vals.size(); col_idx++) { - output.SetValue(col_idx, count, vals[col_idx]); - } - count++; - } - output.SetCardinality(count); -} - -void TestAllTypesFun::RegisterFunction(BuiltinFunctions &set) { - TableFunction test_all_types("test_all_types", {}, TestAllTypesFunction, TestAllTypesBind, TestAllTypesInit); - test_all_types.named_parameters["use_large_enum"] = LogicalType::BOOLEAN; - set.AddFunction(test_all_types); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system/test_vector_types.cpp b/src/duckdb/src/function/table/system/test_vector_types.cpp deleted file mode 100644 index 23dab8758..000000000 --- a/src/duckdb/src/function/table/system/test_vector_types.cpp +++ /dev/null @@ -1,336 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/common/map.hpp" -#include "duckdb/common/pair.hpp" - -namespace duckdb { - -// FLAT, CONSTANT, DICTIONARY, SEQUENCE -struct TestVectorBindData : public TableFunctionData { - vector types; - bool all_flat = false; -}; - 
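/* test_vector_types re-emits its test values under each of the physical vector
   layouts named in the comment that opens this file: FLAT (one value slot per row),
   CONSTANT (a single slot broadcast to every row), DICTIONARY (a selection vector
   indexing into a child vector), and SEQUENCE (integer vectors described by a start
   value and an increment, as in result.Sequence(3, 2, 3) further down). Forcing a
   layout by hand follows the API used in this file:

   Vector vec(LogicalType::INTEGER);
   vec.SetValue(0, Value::INTEGER(42));
   vec.SetVectorType(VectorType::CONSTANT_VECTOR); // row 0 now stands for all rows
*/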
-struct TestVectorTypesData : public GlobalTableFunctionState { - TestVectorTypesData() : offset(0) { - } - - vector> entries; - idx_t offset; -}; - -struct TestVectorInfo { - TestVectorInfo(const vector &types, const map &test_type_map, - vector> &entries) - : types(types), test_type_map(test_type_map), entries(entries) { - } - - const vector &types; - const map &test_type_map; - vector> &entries; -}; - -struct TestGeneratedValues { -public: - void AddColumn(vector values) { - if (!column_values.empty() && column_values[0].size() != values.size()) { - throw InternalException("Size mismatch when adding a column to TestGeneratedValues"); - } - column_values.push_back(std::move(values)); - } - - const Value &GetValue(idx_t row, idx_t column) const { - return column_values[column][row]; - } - - idx_t Rows() const { - return column_values.empty() ? 0 : column_values[0].size(); - } - - idx_t Columns() const { - return column_values.size(); - } - -private: - vector> column_values; -}; - -struct TestVectorFlat { - static constexpr const idx_t TEST_VECTOR_CARDINALITY = 3; - - static vector GenerateValues(TestVectorInfo &info, const LogicalType &type) { - vector result; - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - vector> struct_children; - auto &child_types = StructType::GetChildTypes(type); - - struct_children.resize(TEST_VECTOR_CARDINALITY); - for (auto &child_type : child_types) { - auto child_values = GenerateValues(info, child_type.second); - - for (idx_t i = 0; i < child_values.size(); i++) { - struct_children[i].push_back(make_pair(child_type.first, std::move(child_values[i]))); - } - } - for (auto &struct_child : struct_children) { - result.push_back(Value::STRUCT(std::move(struct_child))); - } - break; - } - case PhysicalType::LIST: { - if (type.id() == LogicalTypeId::MAP) { - auto &child_type = ListType::GetChildType(type); - auto child_values = GenerateValues(info, child_type); - result.push_back(Value::MAP(child_type, {child_values[0]})); - result.push_back(Value(type)); - result.push_back(Value::MAP(child_type, {child_values[1]})); - break; - } - auto &child_type = ListType::GetChildType(type); - auto child_values = GenerateValues(info, child_type); - - result.push_back(Value::LIST(child_type, {child_values[0], child_values[1]})); - result.push_back(Value::LIST(child_type, {})); - result.push_back(Value::LIST(child_type, {child_values[2]})); - break; - } - default: { - auto entry = info.test_type_map.find(type.id()); - if (entry == info.test_type_map.end()) { - throw NotImplementedException("Unimplemented type for test_vector_types %s", type.ToString()); - } - result.push_back(entry->second.min_value); - result.push_back(entry->second.max_value); - result.emplace_back(type); - break; - } - } - return result; - } - - static TestGeneratedValues GenerateValues(TestVectorInfo &info) { - // generate the values for each column - TestGeneratedValues generated_values; - for (auto &type : info.types) { - generated_values.AddColumn(GenerateValues(info, type)); - } - return generated_values; - } - - static void Generate(TestVectorInfo &info) { - auto result_values = GenerateValues(info); - for (idx_t cur_row = 0; cur_row < result_values.Rows(); cur_row += STANDARD_VECTOR_SIZE) { - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); - auto cardinality = MinValue(STANDARD_VECTOR_SIZE, result_values.Rows() - cur_row); - for (idx_t c = 0; c < info.types.size(); c++) { - for (idx_t i = 0; i < cardinality; i++) { - 
result->data[c].SetValue(i, result_values.GetValue(cur_row + i, c)); - } - } - result->SetCardinality(cardinality); - info.entries.push_back(std::move(result)); - } - } -}; - -struct TestVectorConstant { - static void Generate(TestVectorInfo &info) { - auto values = TestVectorFlat::GenerateValues(info); - for (idx_t cur_row = 0; cur_row < TestVectorFlat::TEST_VECTOR_CARDINALITY; cur_row += STANDARD_VECTOR_SIZE) { - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types); - auto cardinality = MinValue(STANDARD_VECTOR_SIZE, TestVectorFlat::TEST_VECTOR_CARDINALITY - cur_row); - for (idx_t c = 0; c < info.types.size(); c++) { - result->data[c].SetValue(0, values.GetValue(0, c)); - result->data[c].SetVectorType(VectorType::CONSTANT_VECTOR); - } - result->SetCardinality(cardinality); - - info.entries.push_back(std::move(result)); - } - } -}; - -struct TestVectorSequence { - static void GenerateVector(TestVectorInfo &info, const LogicalType &type, Vector &result) { - D_ASSERT(type == result.GetType()); - switch (type.id()) { - case LogicalTypeId::TINYINT: - case LogicalTypeId::SMALLINT: - case LogicalTypeId::INTEGER: - case LogicalTypeId::BIGINT: - case LogicalTypeId::UTINYINT: - case LogicalTypeId::USMALLINT: - case LogicalTypeId::UINTEGER: - case LogicalTypeId::UBIGINT: - result.Sequence(3, 2, 3); -#if STANDARD_VECTOR_SIZE <= 2 - result.Flatten(3); -#endif - return; - default: - break; - } - switch (type.InternalType()) { - case PhysicalType::STRUCT: { - auto &child_entries = StructVector::GetEntries(result); - for (auto &child_entry : child_entries) { - GenerateVector(info, child_entry->GetType(), *child_entry); - } - break; - } - case PhysicalType::LIST: { - D_ASSERT(type.id() != LogicalTypeId::MAP); - auto data = FlatVector::GetData(result); - data[0].offset = 0; - data[0].length = 2; - data[1].offset = 2; - data[1].length = 0; - data[2].offset = 2; - data[2].length = 1; - - GenerateVector(info, ListType::GetChildType(type), ListVector::GetEntry(result)); - ListVector::SetListSize(result, 3); - break; - } - default: { - auto entry = info.test_type_map.find(type.id()); - if (entry == info.test_type_map.end()) { - throw NotImplementedException("Unimplemented type for test_vector_types %s", type.ToString()); - } - result.SetValue(0, entry->second.min_value); - result.SetValue(1, entry->second.max_value); - result.SetValue(2, Value(type)); - break; - } - } - } - - static void Generate(TestVectorInfo &info) { - static constexpr const idx_t SEQ_CARDINALITY = 3; - - auto result = make_uniq(); - result->Initialize(Allocator::DefaultAllocator(), info.types, - MaxValue(SEQ_CARDINALITY, STANDARD_VECTOR_SIZE)); - - for (idx_t c = 0; c < info.types.size(); c++) { - if (info.types[c].id() == LogicalTypeId::MAP) { - // FIXME: we don't support MAP in the TestVectorSequence - return; - } - GenerateVector(info, info.types[c], result->data[c]); - } - result->SetCardinality(SEQ_CARDINALITY); -#if STANDARD_VECTOR_SIZE > 2 - info.entries.push_back(std::move(result)); -#else - // vsize = 2, split into two smaller data chunks - for (idx_t offset = 0; offset < SEQ_CARDINALITY; offset += STANDARD_VECTOR_SIZE) { - auto new_result = make_uniq(); - new_result->Initialize(Allocator::DefaultAllocator(), info.types); - - idx_t copy_count = MinValue(STANDARD_VECTOR_SIZE, SEQ_CARDINALITY - offset); - result->Copy(*new_result, *FlatVector::IncrementalSelectionVector(), offset + copy_count, offset); - - info.entries.push_back(std::move(new_result)); - } -#endif - } -}; - -struct 
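/* The generator below produces dictionary vectors by slicing flat chunks: a
   SelectionVector lists which source rows survive, and DataChunk::Slice turns each
   column into a dictionary over the original data instead of copying it. The
   mechanism in isolation, given an already populated DataChunk named chunk:

   SelectionVector sel(STANDARD_VECTOR_SIZE);
   sel.set_index(0, 1); // output row 0 reads input row 1
   sel.set_index(1, 2); // output row 1 reads input row 2
   chunk.Slice(sel, 2); // chunk now exposes two dictionary-encoded rows
*/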
TestVectorDictionary { - static void Generate(TestVectorInfo &info) { - idx_t current_chunk = info.entries.size(); - - unordered_set slice_entries {1, 2}; - - TestVectorFlat::Generate(info); - idx_t current_idx = 0; - for (idx_t i = current_chunk; i < info.entries.size(); i++) { - auto &chunk = *info.entries[i]; - SelectionVector sel(STANDARD_VECTOR_SIZE); - idx_t sel_idx = 0; - for (idx_t k = 0; k < chunk.size(); k++) { - if (slice_entries.count(current_idx + k) > 0) { - sel.set_index(sel_idx++, k); - } - } - chunk.Slice(sel, sel_idx); - current_idx += chunk.size(); - } - } -}; - -static unique_ptr TestVectorTypesBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - auto result = make_uniq(); - for (idx_t i = 0; i < input.inputs.size(); i++) { - string name = "test_vector"; - if (i > 0) { - name += to_string(i + 1); - } - auto &input_val = input.inputs[i]; - names.emplace_back(name); - return_types.push_back(input_val.type()); - result->types.push_back(input_val.type()); - } - for (auto &entry : input.named_parameters) { - if (entry.first == "all_flat") { - result->all_flat = BooleanValue::Get(entry.second); - } else { - throw InternalException("Unrecognized named parameter for test_vector_types"); - } - } - return std::move(result); -} - -unique_ptr TestVectorTypesInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - - auto result = make_uniq(); - - auto test_types = TestAllTypesFun::GetTestTypes(); - - map test_type_map; - for (auto &test_type : test_types) { - test_type_map.insert(make_pair(test_type.type.id(), std::move(test_type))); - } - - TestVectorInfo info(bind_data.types, test_type_map, result->entries); - TestVectorFlat::Generate(info); - TestVectorConstant::Generate(info); - TestVectorDictionary::Generate(info); - TestVectorSequence::Generate(info); - for (auto &entry : result->entries) { - entry->Verify(); - } - if (bind_data.all_flat) { - for (auto &entry : result->entries) { - entry->Flatten(); - entry->Verify(); - } - } - return std::move(result); -} - -void TestVectorTypesFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast(); - if (data.offset >= data.entries.size()) { - // finished returning values - return; - } - output.Reference(*data.entries[data.offset]); - data.offset++; -} - -void TestVectorTypesFun::RegisterFunction(BuiltinFunctions &set) { - TableFunction test_vector_types("test_vector_types", {LogicalType::ANY}, TestVectorTypesFunction, - TestVectorTypesBind, TestVectorTypesInit); - test_vector_types.varargs = LogicalType::ANY; - test_vector_types.named_parameters["all_flat"] = LogicalType::BOOLEAN; - - set.AddFunction(std::move(test_vector_types)); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/system_functions.cpp b/src/duckdb/src/function/table/system_functions.cpp deleted file mode 100644 index 7560221c5..000000000 --- a/src/duckdb/src/function/table/system_functions.cpp +++ /dev/null @@ -1,46 +0,0 @@ -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/parser/parsed_data/create_view_info.hpp" -#include "duckdb/parser/query_node/select_node.hpp" -#include "duckdb/parser/expression/star_expression.hpp" -#include "duckdb/parser/tableref/table_function_ref.hpp" -#include "duckdb/parser/expression/function_expression.hpp" -#include "duckdb/catalog/catalog.hpp" - -namespace duckdb { - -void BuiltinFunctions::RegisterSQLiteFunctions() { - 
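/* Every call below follows the registration shape already seen throughout this
   diff: each RegisterFunction constructs a TableFunction from a name, argument
   types, and the scan/bind/init callbacks, then adds it to this BuiltinFunctions
   set, e.g. (copied from pragma_collations above):

   set.AddFunction(
       TableFunction("pragma_collations", {}, PragmaCollateFunction, PragmaCollateBind, PragmaCollateInit));
*/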
PragmaVersion::RegisterFunction(*this); - PragmaPlatform::RegisterFunction(*this); - PragmaCollations::RegisterFunction(*this); - PragmaTableInfo::RegisterFunction(*this); - PragmaStorageInfo::RegisterFunction(*this); - PragmaMetadataInfo::RegisterFunction(*this); - PragmaDatabaseSize::RegisterFunction(*this); - PragmaUserAgent::RegisterFunction(*this); - - DuckDBColumnsFun::RegisterFunction(*this); - DuckDBConstraintsFun::RegisterFunction(*this); - DuckDBDatabasesFun::RegisterFunction(*this); - DuckDBFunctionsFun::RegisterFunction(*this); - DuckDBKeywordsFun::RegisterFunction(*this); - DuckDBIndexesFun::RegisterFunction(*this); - DuckDBSchemasFun::RegisterFunction(*this); - DuckDBDependenciesFun::RegisterFunction(*this); - DuckDBExtensionsFun::RegisterFunction(*this); - DuckDBMemoryFun::RegisterFunction(*this); - DuckDBOptimizersFun::RegisterFunction(*this); - DuckDBSecretsFun::RegisterFunction(*this); - DuckDBWhichSecretFun::RegisterFunction(*this); - DuckDBSequencesFun::RegisterFunction(*this); - DuckDBSettingsFun::RegisterFunction(*this); - DuckDBTablesFun::RegisterFunction(*this); - DuckDBTableSample::RegisterFunction(*this); - DuckDBTemporaryFilesFun::RegisterFunction(*this); - DuckDBTypesFun::RegisterFunction(*this); - DuckDBVariablesFun::RegisterFunction(*this); - DuckDBViewsFun::RegisterFunction(*this); - TestAllTypesFun::RegisterFunction(*this); - TestVectorTypesFun::RegisterFunction(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/table_scan.cpp b/src/duckdb/src/function/table/table_scan.cpp deleted file mode 100644 index b36abbb9b..000000000 --- a/src/duckdb/src/function/table/table_scan.cpp +++ /dev/null @@ -1,645 +0,0 @@ -#include "duckdb/function/table/table_scan.hpp" - -#include "duckdb/catalog/catalog_entry/duck_table_entry.hpp" -#include "duckdb/catalog/dependency_list.hpp" -#include "duckdb/common/mutex.hpp" -#include "duckdb/common/serializer/deserializer.hpp" -#include "duckdb/common/serializer/serializer.hpp" -#include "duckdb/execution/index/art/art.hpp" -#include "duckdb/function/function_set.hpp" -#include "duckdb/main/attached_database.hpp" -#include "duckdb/main/client_config.hpp" -#include "duckdb/optimizer/matcher/expression_matcher.hpp" -#include "duckdb/planner/expression/bound_between_expression.hpp" -#include "duckdb/planner/expression_iterator.hpp" -#include "duckdb/planner/operator/logical_get.hpp" -#include "duckdb/storage/data_table.hpp" -#include "duckdb/storage/table/scan_state.hpp" -#include "duckdb/transaction/duck_transaction.hpp" -#include "duckdb/transaction/local_storage.hpp" -#include "duckdb/storage/storage_index.hpp" -#include "duckdb/main/client_data.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/planner/filter/optional_filter.hpp" -#include "duckdb/planner/filter/in_filter.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_comparison_expression.hpp" -#include "duckdb/planner/filter/conjunction_filter.hpp" - -namespace duckdb { - -struct TableScanLocalState : public LocalTableFunctionState { - //! The current position in the scan. - TableScanState scan_state; - //! The DataChunk containing all read columns. - //! This includes filter columns, which are immediately removed. - DataChunk all_columns; -}; - -struct IndexScanLocalState : public LocalTableFunctionState { - //! The batch index, which determines the offset in the row ID vector. - idx_t batch_index; - //! The DataChunk containing all read columns. - //! 
This includes filter columns, which are immediately removed. - DataChunk all_columns; -}; - -static StorageIndex TransformStorageIndex(const ColumnIndex &column_id) { - vector result; - for (auto &child_id : column_id.GetChildIndexes()) { - result.push_back(TransformStorageIndex(child_id)); - } - return StorageIndex(column_id.GetPrimaryIndex(), std::move(result)); -} - -static StorageIndex GetStorageIndex(TableCatalogEntry &table, const ColumnIndex &column_id) { - if (column_id.IsRowIdColumn()) { - return StorageIndex(); - } - - // The index of the base ColumnIndex is equal to the physical column index in the table - // for any child indices because the indices are already the physical indices. - // Only the top-level can have generated columns. - auto &col = table.GetColumn(column_id.ToLogical()); - auto result = TransformStorageIndex(column_id); - result.SetIndex(col.StorageOid()); - return result; -} - -class TableScanGlobalState : public GlobalTableFunctionState { -public: - TableScanGlobalState(ClientContext &context, const FunctionData *bind_data_p) { - D_ASSERT(bind_data_p); - auto &bind_data = bind_data_p->Cast(); - max_threads = bind_data.table.GetStorage().MaxThreads(context); - } - - //! The maximum number of threads for this table scan. - idx_t max_threads; - //! The projected columns of this table scan. - vector projection_ids; - //! The types of all scanned columns. - vector scanned_types; - -public: - virtual unique_ptr InitLocalState(ExecutionContext &context, - TableFunctionInitInput &input) = 0; - virtual void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) = 0; - virtual double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p) const = 0; - virtual OperatorPartitionData TableScanGetPartitionData(ClientContext &context, - TableFunctionGetPartitionInput &input) = 0; - - idx_t MaxThreads() const override { - return max_threads; - } - bool CanRemoveFilterColumns() const { - return !projection_ids.empty(); - } -}; - -class DuckIndexScanState : public TableScanGlobalState { -public: - DuckIndexScanState(ClientContext &context, const FunctionData *bind_data_p) - : TableScanGlobalState(context, bind_data_p), next_batch_index(0), finished(false) { - } - - //! The batch index of the next Sink. - //! Also determines the offset of the next chunk. I.e., offset = next_batch_index * STANDARD_VECTOR_SIZE. - idx_t next_batch_index; - //! The total scanned row IDs. - unsafe_vector row_ids; - //! The column IDs of the to-be-scanned columns. - vector column_ids; - //! True, if no more row IDs must be scanned. - bool finished; - //! Synchronize changes to the global index scan state. - mutex index_scan_lock; - - ColumnFetchState fetch_state; - TableScanState table_scan_state; - -public: - unique_ptr InitLocalState(ExecutionContext &context, - TableFunctionInitInput &input) override { - auto l_state = make_uniq(); - if (input.CanRemoveFilterColumns()) { - l_state->all_columns.Initialize(context.client, scanned_types); - } - return std::move(l_state); - } - - void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) override { - auto &bind_data = data_p.bind_data->Cast(); - auto &tx = DuckTransaction::Get(context, bind_data.table.catalog); - auto &storage = bind_data.table.GetStorage(); - auto &l_state = data_p.local_state->Cast(); - - auto row_id_count = row_ids.size(); - idx_t scan_count = 0; - idx_t offset = 0; - - { - // Synchronize changes to the shared global state. 
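- // [editor's note, added in this edit] Each worker claims one STANDARD_VECTOR_SIZE slice of the sorted row ID list under this lock. Illustrative numbers: with 5000 row IDs and a vector size of 2048, successive batches cover offsets [0, 2048), [2048, 4096) and [4096, 5000), and the last batch sets finished.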
- lock_guard<mutex> l(index_scan_lock); - if (!finished) { - l_state.batch_index = next_batch_index; - next_batch_index++; - - offset = l_state.batch_index * STANDARD_VECTOR_SIZE; - auto remaining = row_id_count - offset; - scan_count = remaining < STANDARD_VECTOR_SIZE ? remaining : STANDARD_VECTOR_SIZE; - finished = remaining < STANDARD_VECTOR_SIZE ? true : false; - } - } - - if (scan_count != 0) { - auto row_id_data = (data_ptr_t)&row_ids[0 + offset]; // NOLINT - this is not pretty - Vector local_vector(LogicalType::ROW_TYPE, row_id_data); - - if (CanRemoveFilterColumns()) { - l_state.all_columns.Reset(); - storage.Fetch(tx, l_state.all_columns, column_ids, local_vector, scan_count, fetch_state); - output.ReferenceColumns(l_state.all_columns, projection_ids); - } else { - storage.Fetch(tx, output, column_ids, local_vector, scan_count, fetch_state); - } - } - - if (output.size() == 0) { - auto &local_storage = LocalStorage::Get(tx); - if (CanRemoveFilterColumns()) { - l_state.all_columns.Reset(); - local_storage.Scan(table_scan_state.local_state, column_ids, l_state.all_columns); - output.ReferenceColumns(l_state.all_columns, projection_ids); - } else { - local_storage.Scan(table_scan_state.local_state, column_ids, output); - } - } - } - - double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p) const override { - auto total_rows = row_ids.size(); - if (total_rows == 0) { - return 100; - } - - auto scanned_rows = next_batch_index * STANDARD_VECTOR_SIZE; - auto percentage = 100 * (static_cast<double>(scanned_rows) / static_cast<double>(total_rows)); - return percentage > 100 ? 100 : percentage; - } - - OperatorPartitionData TableScanGetPartitionData(ClientContext &context, - TableFunctionGetPartitionInput &input) override { - auto &l_state = input.local_state->Cast<IndexScanLocalState>(); - return OperatorPartitionData(l_state.batch_index); - } -}; - -class DuckTableScanState : public TableScanGlobalState { -public: - DuckTableScanState(ClientContext &context, const FunctionData *bind_data_p) - : TableScanGlobalState(context, bind_data_p) { - } - - ParallelTableScanState state; - -public: - unique_ptr<LocalTableFunctionState> InitLocalState(ExecutionContext &context, - TableFunctionInitInput &input) override { - auto &bind_data = input.bind_data->Cast<TableScanBindData>(); - auto l_state = make_uniq<TableScanLocalState>(); - - vector<StorageIndex> storage_ids; - for (auto &col : input.column_indexes) { - storage_ids.push_back(GetStorageIndex(bind_data.table, col)); - } - - l_state->scan_state.Initialize(std::move(storage_ids), input.filters.get(), input.sample_options.get()); - - auto &storage = bind_data.table.GetStorage(); - storage.NextParallelScan(context.client, state, l_state->scan_state); - if (input.CanRemoveFilterColumns()) { - l_state->all_columns.Initialize(context.client, scanned_types); - } - - l_state->scan_state.options.force_fetch_row = ClientConfig::GetConfig(context.client).force_fetch_row; - return std::move(l_state); - } - - void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) override { - auto &bind_data = data_p.bind_data->Cast<TableScanBindData>(); - auto &tx = DuckTransaction::Get(context, bind_data.table.catalog); - auto &storage = bind_data.table.GetStorage(); - - auto &l_state = data_p.local_state->Cast<TableScanLocalState>(); - l_state.scan_state.options.force_fetch_row = ClientConfig::GetConfig(context).force_fetch_row; - - do { - if (bind_data.is_create_index) { - storage.CreateIndexScan(l_state.scan_state, output, - TableScanType::TABLE_SCAN_COMMITTED_ROWS_OMIT_PERMANENTLY_DELETED); - } else if (CanRemoveFilterColumns()) { - l_state.all_columns.Reset(); -
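// [editor's note, added] Filter-only columns are scanned into the wider all_columns chunk; the output then references just the projected subset, so filter columns never leave the scan. -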
storage.Scan(tx, l_state.all_columns, l_state.scan_state); - output.ReferenceColumns(l_state.all_columns, projection_ids); - } else { - storage.Scan(tx, output, l_state.scan_state); - } - if (output.size() > 0) { - return; - } - - auto next = storage.NextParallelScan(context, state, l_state.scan_state); - if (!next) { - return; - } - } while (true); - } - - double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p) const override { - auto &bind_data = bind_data_p->Cast(); - auto &storage = bind_data.table.GetStorage(); - auto total_rows = storage.GetTotalRows(); - - // The table is empty or smaller than the standard vector size. - if (total_rows == 0) { - return 100; - } - - idx_t scanned_rows = state.scan_state.processed_rows; - scanned_rows += state.local_state.processed_rows; - auto percentage = 100 * (static_cast(scanned_rows) / static_cast(total_rows)); - if (percentage > 100) { - // If the last chunk has fewer elements than STANDARD_VECTOR_SIZE, and if our percentage is over 100, - // then we finished this table. - return 100; - } - return percentage; - } - - OperatorPartitionData TableScanGetPartitionData(ClientContext &context, - TableFunctionGetPartitionInput &input) override { - auto &l_state = input.local_state->Cast(); - if (l_state.scan_state.table_state.row_group) { - return OperatorPartitionData(l_state.scan_state.table_state.batch_index); - } - if (l_state.scan_state.local_state.row_group) { - return OperatorPartitionData(l_state.scan_state.table_state.batch_index + - l_state.scan_state.local_state.batch_index); - } - return OperatorPartitionData(0); - } -}; - -static unique_ptr TableScanInitLocal(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *g_state) { - auto &cast_g_state = g_state->Cast(); - return cast_g_state.InitLocalState(context, input); -} - -unique_ptr DuckTableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input, - DataTable &storage, const TableScanBindData &bind_data) { - auto g_state = make_uniq(context, input.bind_data.get()); - storage.InitializeParallelScan(context, g_state->state); - if (!input.CanRemoveFilterColumns()) { - return std::move(g_state); - } - - g_state->projection_ids = input.projection_ids; - const auto &columns = bind_data.table.GetColumns(); - for (const auto &col_idx : input.column_indexes) { - if (col_idx.IsRowIdColumn()) { - g_state->scanned_types.emplace_back(LogicalType::ROW_TYPE); - } else { - g_state->scanned_types.push_back(columns.GetColumn(col_idx.ToLogical()).Type()); - } - } - return std::move(g_state); -} - -unique_ptr DuckIndexScanInitGlobal(ClientContext &context, TableFunctionInitInput &input, - DataTable &storage, const TableScanBindData &bind_data, - unsafe_vector &row_ids) { - auto g_state = make_uniq(context, input.bind_data.get()); - if (!row_ids.empty()) { - std::sort(row_ids.begin(), row_ids.end()); - g_state->row_ids = std::move(row_ids); - } - g_state->finished = g_state->row_ids.empty() ? 
true : false; - - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - g_state->table_scan_state.options.force_fetch_row = ClientConfig::GetConfig(context).force_fetch_row; - - if (input.CanRemoveFilterColumns()) { - g_state->projection_ids = input.projection_ids; - } - - const auto &columns = bind_data.table.GetColumns(); - for (const auto &col_idx : input.column_indexes) { - g_state->column_ids.push_back(GetStorageIndex(bind_data.table, col_idx)); - if (col_idx.IsRowIdColumn()) { - g_state->scanned_types.emplace_back(LogicalType::ROW_TYPE); - continue; - } - g_state->scanned_types.push_back(columns.GetColumn(col_idx.ToLogical()).Type()); - } - - g_state->table_scan_state.Initialize(g_state->column_ids, input.filters.get()); - local_storage.InitializeScan(storage, g_state->table_scan_state.local_state, input.filters); - - // Const-cast to indicate an index scan. - // We need this information in the bind data so that we can access it during ANALYZE. - auto &no_const_bind_data = bind_data.CastNoConst(); - no_const_bind_data.is_index_scan = true; - - return std::move(g_state); -} - -void ExtractInFilter(unique_ptr &filter, BoundColumnRefExpression &bound_ref, - unique_ptr>> &filter_expressions) { - // Special-handling of IN filters. - // They are part of a CONJUNCTION_AND. - if (filter->filter_type != TableFilterType::CONJUNCTION_AND) { - return; - } - - auto &and_filter = filter->Cast(); - auto &children = and_filter.child_filters; - if (children.empty()) { - return; - } - if (children[0]->filter_type != TableFilterType::OPTIONAL_FILTER) { - return; - } - - auto &optional_filter = children[0]->Cast(); - auto &child = optional_filter.child_filter; - if (child->filter_type != TableFilterType::IN_FILTER) { - return; - } - - auto &in_filter = child->Cast(); - if (!in_filter.origin_is_hash_join) { - return; - } - - // They are all on the same column, so we can split them. - for (const auto &value : in_filter.values) { - auto bound_constant = make_uniq(value); - auto filter_expr = make_uniq(ExpressionType::COMPARE_EQUAL, bound_ref.Copy(), - std::move(bound_constant)); - filter_expressions->push_back(std::move(filter_expr)); - } -} - -unique_ptr>> ExtractFilters(const ColumnDefinition &col, unique_ptr &filter, - idx_t storage_idx) { - ColumnBinding binding(0, storage_idx); - auto bound_ref = make_uniq(col.Name(), col.Type(), binding); - - auto filter_expressions = make_uniq>>(); - ExtractInFilter(filter, *bound_ref, filter_expressions); - - if (filter_expressions->empty()) { - auto filter_expr = filter->ToExpression(*bound_ref); - filter_expressions->push_back(std::move(filter_expr)); - } - return filter_expressions; -} - -bool TryScanIndex(ART &art, const ColumnList &column_list, TableFunctionInitInput &input, TableFilterSet &filter_set, - idx_t max_count, unsafe_vector &row_ids) { - // FIXME: No support for index scans on compound ARTs. - // See note above on multi-filter support. - if (art.unbound_expressions.size() > 1) { - return false; - } - - auto index_expr = art.unbound_expressions[0]->Copy(); - auto &indexed_columns = art.GetColumnIds(); - - // NOTE: We do not push down multi-column filters, e.g., 42 = a + b. - if (indexed_columns.size() != 1) { - return false; - } - - // Get ART column. - auto &col = column_list.GetColumn(LogicalIndex(indexed_columns[0])); - - // The indexes of the filters match input.column_indexes, which are: i -> column_index. - // Try to find a filter on the ART column. 
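- // [editor's note, illustrative] Filters are keyed by their position in input.column_indexes, not by table column id: a scan of (c2, c0) with a predicate on c0 stores that filter under key 1. The loop below recovers that position for the indexed column.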
- optional_idx storage_index; - for (idx_t i = 0; i < input.column_indexes.size(); i++) { - if (input.column_indexes[i].ToLogical() == col.Logical()) { - storage_index = i; - break; - } - } - - // No filter matches the ART column. - if (!storage_index.IsValid()) { - return false; - } - - // Try to find a matching filter for the column. - auto filter = filter_set.filters.find(storage_index.GetIndex()); - if (filter == filter_set.filters.end()) { - return false; - } - - auto filter_expressions = ExtractFilters(col, filter->second, storage_index.GetIndex()); - for (const auto &filter_expr : *filter_expressions) { - auto scan_state = art.TryInitializeScan(*index_expr, *filter_expr); - if (!scan_state) { - return false; - } - - // Check if we can use an index scan, and already retrieve the matching row ids. - if (!art.Scan(*scan_state, max_count, row_ids)) { - row_ids.clear(); - return false; - } - } - return true; -} - -unique_ptr TableScanInitGlobal(ClientContext &context, TableFunctionInitInput &input) { - D_ASSERT(input.bind_data); - - auto &bind_data = input.bind_data->Cast(); - auto &table = bind_data.table; - auto &storage = table.GetStorage(); - - // Can't index scan without filters. - if (!input.filters) { - return DuckTableScanInitGlobal(context, input, storage, bind_data); - } - auto &filter_set = *input.filters; - - // FIXME: We currently only support scanning one ART with one filter. - // If multiple filters exist, i.e., a = 11 AND b = 24, we need to - // 1. 1.1. Find + scan one ART for a = 11. - // 1.2. Find + scan one ART for b = 24. - // 1.3. Return the intersecting row IDs. - // 2. (Reorder and) scan a single ART with a compound key of (a, b). - if (filter_set.filters.size() != 1) { - return DuckTableScanInitGlobal(context, input, storage, bind_data); - } - - // The checkpoint lock ensures that we do not checkpoint while scanning this table. - auto checkpoint_lock = storage.GetSharedCheckpointLock(); - auto &info = storage.GetDataTableInfo(); - auto &indexes = info->GetIndexes(); - if (indexes.Empty()) { - return DuckTableScanInitGlobal(context, input, storage, bind_data); - } - - auto &db_config = DBConfig::GetConfig(context); - auto scan_percentage = db_config.GetSetting(context); - auto scan_max_count = db_config.GetSetting(context); - - auto total_rows = storage.GetTotalRows(); - auto total_rows_from_percentage = LossyNumericCast(double(total_rows) * scan_percentage); - auto max_count = MaxValue(scan_max_count, total_rows_from_percentage); - - auto &column_list = table.GetColumns(); - bool index_scan = false; - unsafe_vector row_ids; - - info->GetIndexes().BindAndScan(context, *info, [&](ART &art) { - index_scan = TryScanIndex(art, column_list, input, filter_set, max_count, row_ids); - return index_scan; - }); - - if (!index_scan) { - return DuckTableScanInitGlobal(context, input, storage, bind_data); - } - return DuckIndexScanInitGlobal(context, input, storage, bind_data, row_ids); -} - -static unique_ptr TableScanStatistics(ClientContext &context, const FunctionData *bind_data_p, - column_t column_id) { - auto &bind_data = bind_data_p->Cast(); - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - - // Don't emit statistics for tables with outstanding transaction-local data. 
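- // [editor's note, added] Base-table statistics do not reflect rows appended by the current transaction, so returning them here could mislead the optimizer.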
- if (local_storage.Find(bind_data.table.GetStorage())) { - return nullptr; - } - return bind_data.table.GetStatistics(context, column_id); -} - -static void TableScanFunc(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &g_state = data_p.global_state->Cast(); - g_state.TableScanFunc(context, data_p, output); -} - -double TableScanProgress(ClientContext &context, const FunctionData *bind_data_p, - const GlobalTableFunctionState *g_state_p) { - auto &g_state = g_state_p->Cast(); - return g_state.TableScanProgress(context, bind_data_p); -} - -OperatorPartitionData TableScanGetPartitionData(ClientContext &context, TableFunctionGetPartitionInput &input) { - if (input.partition_info.RequiresPartitionColumns()) { - throw InternalException("TableScan::GetPartitionData: partition columns not supported"); - } - - auto &g_state = input.global_state->Cast(); - return g_state.TableScanGetPartitionData(context, input); -} - -vector TableScanGetPartitionStats(ClientContext &context, GetPartitionStatsInput &input) { - auto &bind_data = input.bind_data->Cast(); - vector result; - auto &storage = bind_data.table.GetStorage(); - return storage.GetPartitionStats(context); -} - -BindInfo TableScanGetBindInfo(const optional_ptr bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - return BindInfo(bind_data.table); -} - -void TableScanDependency(LogicalDependencyList &entries, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - entries.AddDependency(bind_data.table); -} - -unique_ptr TableScanCardinality(ClientContext &context, const FunctionData *bind_data_p) { - auto &bind_data = bind_data_p->Cast(); - auto &local_storage = LocalStorage::Get(context, bind_data.table.catalog); - auto &storage = bind_data.table.GetStorage(); - idx_t table_rows = storage.GetTotalRows(); - idx_t estimated_cardinality = table_rows + local_storage.AddedRows(bind_data.table.GetStorage()); - return make_uniq(table_rows, estimated_cardinality); -} - -InsertionOrderPreservingMap TableScanToString(TableFunctionToStringInput &input) { - InsertionOrderPreservingMap result; - auto &bind_data = input.bind_data->Cast(); - result["Table"] = bind_data.table.name; - result["Type"] = bind_data.is_index_scan ? 
"Index Scan" : "Sequence Scan"; - return result; -} - -static void TableScanSerialize(Serializer &serializer, const optional_ptr bind_data_p, - const TableFunction &function) { - auto &bind_data = bind_data_p->Cast(); - serializer.WriteProperty(100, "catalog", bind_data.table.schema.catalog.GetName()); - serializer.WriteProperty(101, "schema", bind_data.table.schema.name); - serializer.WriteProperty(102, "table", bind_data.table.name); - serializer.WriteProperty(103, "is_index_scan", bind_data.is_index_scan); - serializer.WriteProperty(104, "is_create_index", bind_data.is_create_index); - serializer.WritePropertyWithDefault(105, "result_ids", unsafe_vector()); -} - -static unique_ptr TableScanDeserialize(Deserializer &deserializer, TableFunction &function) { - auto catalog = deserializer.ReadProperty(100, "catalog"); - auto schema = deserializer.ReadProperty(101, "schema"); - auto table = deserializer.ReadProperty(102, "table"); - auto &catalog_entry = - Catalog::GetEntry(deserializer.Get(), catalog, schema, table); - if (catalog_entry.type != CatalogType::TABLE_ENTRY) { - throw SerializationException("Cant find table for %s.%s", schema, table); - } - auto result = make_uniq(catalog_entry.Cast()); - deserializer.ReadProperty(103, "is_index_scan", result->is_index_scan); - deserializer.ReadProperty(104, "is_create_index", result->is_create_index); - deserializer.ReadDeletedProperty>(105, "result_ids"); - return std::move(result); -} - -TableFunction TableScanFunction::GetFunction() { - TableFunction scan_function("seq_scan", {}, TableScanFunc); - scan_function.init_local = TableScanInitLocal; - scan_function.init_global = TableScanInitGlobal; - scan_function.statistics = TableScanStatistics; - scan_function.dependency = TableScanDependency; - scan_function.cardinality = TableScanCardinality; - scan_function.pushdown_complex_filter = nullptr; - scan_function.to_string = TableScanToString; - scan_function.table_scan_progress = TableScanProgress; - scan_function.get_partition_data = TableScanGetPartitionData; - scan_function.get_partition_stats = TableScanGetPartitionStats; - scan_function.get_bind_info = TableScanGetBindInfo; - scan_function.projection_pushdown = true; - scan_function.filter_pushdown = true; - scan_function.filter_prune = true; - scan_function.sampling_pushdown = true; - scan_function.serialize = TableScanSerialize; - scan_function.deserialize = TableScanDeserialize; - return scan_function; -} - -void TableScanFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunctionSet table_scan_set("seq_scan"); - table_scan_set.AddFunction(GetFunction()); - set.AddFunction(std::move(table_scan_set)); -} - -void BuiltinFunctions::RegisterTableScanFunctions() { - TableScanFunction::RegisterFunction(*this); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/unnest.cpp b/src/duckdb/src/function/table/unnest.cpp deleted file mode 100644 index 7abdf9df0..000000000 --- a/src/duckdb/src/function/table/unnest.cpp +++ /dev/null @@ -1,86 +0,0 @@ -#include "duckdb/function/table/range.hpp" -#include "duckdb/common/algorithm.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_unnest_expression.hpp" -#include "duckdb/execution/operator/projection/physical_unnest.hpp" - -namespace duckdb { - -struct UnnestBindData : public FunctionData { - explicit UnnestBindData(LogicalType input_type_p) : input_type(std::move(input_type_p)) { - } - - LogicalType input_type; - -public: - unique_ptr Copy() const override { - 
return make_uniq(input_type); - } - - bool Equals(const FunctionData &other_p) const override { - auto &other = other_p.Cast(); - return input_type == other.input_type; - } -}; - -struct UnnestGlobalState : public GlobalTableFunctionState { - UnnestGlobalState() { - } - - vector> select_list; - - idx_t MaxThreads() const override { - return GlobalTableFunctionState::MAX_THREADS; - } -}; - -struct UnnestLocalState : public LocalTableFunctionState { - UnnestLocalState() { - } - - unique_ptr operator_state; -}; - -static unique_ptr UnnestBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - if (input.input_table_types.size() != 1 || input.input_table_types[0].id() != LogicalTypeId::LIST) { - throw BinderException("UNNEST requires a single list as input"); - } - return_types.push_back(ListType::GetChildType(input.input_table_types[0])); - names.push_back("unnest"); - return make_uniq(input.input_table_types[0]); -} - -static unique_ptr UnnestLocalInit(ExecutionContext &context, TableFunctionInitInput &input, - GlobalTableFunctionState *global_state) { - auto &gstate = global_state->Cast(); - - auto result = make_uniq(); - result->operator_state = PhysicalUnnest::GetState(context, gstate.select_list); - return std::move(result); -} - -static unique_ptr UnnestInit(ClientContext &context, TableFunctionInitInput &input) { - auto &bind_data = input.bind_data->Cast(); - auto result = make_uniq(); - auto ref = make_uniq(bind_data.input_type, 0U); - auto bound_unnest = make_uniq(ListType::GetChildType(bind_data.input_type)); - bound_unnest->child = std::move(ref); - result->select_list.push_back(std::move(bound_unnest)); - return std::move(result); -} - -static OperatorResultType UnnestFunction(ExecutionContext &context, TableFunctionInput &data_p, DataChunk &input, - DataChunk &output) { - auto &state = data_p.global_state->Cast(); - auto &lstate = data_p.local_state->Cast(); - return PhysicalUnnest::ExecuteInternal(context, input, output, *lstate.operator_state, state.select_list, false); -} - -void UnnestTableFunction::RegisterFunction(BuiltinFunctions &set) { - TableFunction unnest_function("unnest", {LogicalType::ANY}, nullptr, UnnestBind, UnnestInit, UnnestLocalInit); - unnest_function.in_out_function = UnnestFunction; - set.AddFunction(unnest_function); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table/version/pragma_version.cpp b/src/duckdb/src/function/table/version/pragma_version.cpp deleted file mode 100644 index bc8341084..000000000 --- a/src/duckdb/src/function/table/version/pragma_version.cpp +++ /dev/null @@ -1,116 +0,0 @@ -#ifndef DUCKDB_PATCH_VERSION -#define DUCKDB_PATCH_VERSION "4-dev4329" -#endif -#ifndef DUCKDB_MINOR_VERSION -#define DUCKDB_MINOR_VERSION 1 -#endif -#ifndef DUCKDB_MAJOR_VERSION -#define DUCKDB_MAJOR_VERSION 1 -#endif -#ifndef DUCKDB_VERSION -#define DUCKDB_VERSION "v1.1.4-dev4329" -#endif -#ifndef DUCKDB_SOURCE_ID -#define DUCKDB_SOURCE_ID "a1c7e9b115" -#endif -#include "duckdb/function/table/system_functions.hpp" -#include "duckdb/main/database.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/common/platform.hpp" - -#include - -namespace duckdb { - -struct PragmaVersionData : public GlobalTableFunctionState { - PragmaVersionData() : finished(false) { - } - - bool finished; -}; - -static unique_ptr PragmaVersionBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - names.emplace_back("library_version"); - 
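// [editor's note, illustrative] pragma_version() yields a single row, e.g. library_version = 'v1.1.4-dev4329', source_id = 'a1c7e9b115' (the DUCKDB_VERSION / DUCKDB_SOURCE_ID defines above). -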
return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("source_id"); - return_types.emplace_back(LogicalType::VARCHAR); - return nullptr; -} - -static unique_ptr<GlobalTableFunctionState> PragmaVersionInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq<PragmaVersionData>(); -} - -static void PragmaVersionFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast<PragmaVersionData>(); - if (data.finished) { - // finished returning values - return; - } - output.SetCardinality(1); - output.SetValue(0, 0, DuckDB::LibraryVersion()); - output.SetValue(1, 0, DuckDB::SourceID()); - data.finished = true; -} - -void PragmaVersion::RegisterFunction(BuiltinFunctions &set) { - TableFunction pragma_version("pragma_version", {}, PragmaVersionFunction); - pragma_version.bind = PragmaVersionBind; - pragma_version.init_global = PragmaVersionInit; - set.AddFunction(pragma_version); -} - -idx_t DuckDB::StandardVectorSize() { - return STANDARD_VECTOR_SIZE; -} - -const char *DuckDB::SourceID() { - return DUCKDB_SOURCE_ID; -} - -const char *DuckDB::LibraryVersion() { - return DUCKDB_VERSION; -} - -string DuckDB::Platform() { - return DuckDBPlatform(); -} - -struct PragmaPlatformData : public GlobalTableFunctionState { - PragmaPlatformData() : finished(false) { - } - - bool finished; -}; - -static unique_ptr<FunctionData> PragmaPlatformBind(ClientContext &context, TableFunctionBindInput &input, - vector<LogicalType> &return_types, vector<string> &names) { - names.emplace_back("platform"); - return_types.emplace_back(LogicalType::VARCHAR); - return nullptr; -} - -static unique_ptr<GlobalTableFunctionState> PragmaPlatformInit(ClientContext &context, TableFunctionInitInput &input) { - return make_uniq<PragmaPlatformData>(); -} - -static void PragmaPlatformFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = data_p.global_state->Cast<PragmaPlatformData>(); - if (data.finished) { - // finished returning values - return; - } - output.SetCardinality(1); - output.SetValue(0, 0, DuckDB::Platform()); - data.finished = true; -} - -void PragmaPlatform::RegisterFunction(BuiltinFunctions &set) { - TableFunction pragma_platform("pragma_platform", {}, PragmaPlatformFunction); - pragma_platform.bind = PragmaPlatformBind; - pragma_platform.init_global = PragmaPlatformInit; - set.AddFunction(pragma_platform); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table_function.cpp b/src/duckdb/src/function/table_function.cpp deleted file mode 100644 index 6e9df8194..000000000 --- a/src/duckdb/src/function/table_function.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include "duckdb/function/table_function.hpp" - -namespace duckdb { - -GlobalTableFunctionState::~GlobalTableFunctionState() { -} - -LocalTableFunctionState::~LocalTableFunctionState() { -} - -PartitionStatistics::PartitionStatistics() : row_start(0), count(0), count_type(CountType::COUNT_APPROXIMATE) { -} - -TableFunctionInfo::~TableFunctionInfo() { -} - -TableFunction::TableFunction(string name, vector<LogicalType> arguments, table_function_t function, - table_function_bind_t bind, table_function_init_global_t init_global, - table_function_init_local_t init_local) - : SimpleNamedParameterFunction(std::move(name), std::move(arguments)), bind(bind), bind_replace(nullptr), - init_global(init_global), init_local(init_local), function(function), in_out_function(nullptr), - in_out_function_final(nullptr), statistics(nullptr), dependency(nullptr), cardinality(nullptr), - pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), get_partition_data(nullptr), -
get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), supports_pushdown_type(nullptr), - get_partition_info(nullptr), get_partition_stats(nullptr), serialize(nullptr), deserialize(nullptr), - projection_pushdown(false), filter_pushdown(false), filter_prune(false), sampling_pushdown(false) { -} - -TableFunction::TableFunction(const vector &arguments, table_function_t function, - table_function_bind_t bind, table_function_init_global_t init_global, - table_function_init_local_t init_local) - : TableFunction(string(), arguments, function, bind, init_global, init_local) { -} -TableFunction::TableFunction() - : SimpleNamedParameterFunction("", {}), bind(nullptr), bind_replace(nullptr), init_global(nullptr), - init_local(nullptr), function(nullptr), in_out_function(nullptr), statistics(nullptr), dependency(nullptr), - cardinality(nullptr), pushdown_complex_filter(nullptr), to_string(nullptr), table_scan_progress(nullptr), - get_partition_data(nullptr), get_bind_info(nullptr), type_pushdown(nullptr), get_multi_file_reader(nullptr), - supports_pushdown_type(nullptr), get_partition_info(nullptr), get_partition_stats(nullptr), serialize(nullptr), - deserialize(nullptr), projection_pushdown(false), filter_pushdown(false), filter_prune(false), - sampling_pushdown(false) { -} - -bool TableFunction::Equal(const TableFunction &rhs) const { - // number of types - if (this->arguments.size() != rhs.arguments.size()) { - return false; - } - // argument types - for (idx_t i = 0; i < this->arguments.size(); ++i) { - if (this->arguments[i] != rhs.arguments[i]) { - return false; - } - } - // varargs - if (this->varargs != rhs.varargs) { - return false; - } - - return true; // they are equal -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/table_macro_function.cpp b/src/duckdb/src/function/table_macro_function.cpp deleted file mode 100644 index becb1fe6b..000000000 --- a/src/duckdb/src/function/table_macro_function.cpp +++ /dev/null @@ -1,34 +0,0 @@ -//===----------------------------------------------------------------------===// -// DuckDB -// -// duckdb/function/table_macro_function.hpp -// -// -//===----------------------------------------------------------------------===// -//! 
The SelectStatement of the view -#include "duckdb/function/table_macro_function.hpp" - -#include "duckdb/parser/expression/constant_expression.hpp" -#include "duckdb/parser/query_node.hpp" - -namespace duckdb { - -TableMacroFunction::TableMacroFunction(unique_ptr query_node) - : MacroFunction(MacroType::TABLE_MACRO), query_node(std::move(query_node)) { -} - -TableMacroFunction::TableMacroFunction(void) : MacroFunction(MacroType::TABLE_MACRO) { -} - -unique_ptr TableMacroFunction::Copy() const { - auto result = make_uniq(); - result->query_node = query_node->Copy(); - this->CopyProperties(*result); - return std::move(result); -} - -string TableMacroFunction::ToSQL() const { - return MacroFunction::ToSQL() + StringUtil::Format("TABLE (%s)", query_node->ToString()); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/udf_function.cpp b/src/duckdb/src/function/udf_function.cpp deleted file mode 100644 index 3c03dbbe3..000000000 --- a/src/duckdb/src/function/udf_function.cpp +++ /dev/null @@ -1,27 +0,0 @@ -#include "duckdb/function/udf_function.hpp" - -#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" -#include "duckdb/parser/parsed_data/create_aggregate_function_info.hpp" - -#include "duckdb/main/client_context.hpp" - -namespace duckdb { - -void UDFWrapper::RegisterFunction(string name, vector args, LogicalType ret_type, - scalar_function_t udf_function, ClientContext &context, LogicalType varargs) { - - ScalarFunction scalar_function(std::move(name), std::move(args), std::move(ret_type), std::move(udf_function)); - scalar_function.varargs = std::move(varargs); - scalar_function.null_handling = FunctionNullHandling::SPECIAL_HANDLING; - CreateScalarFunctionInfo info(scalar_function); - info.schema = DEFAULT_SCHEMA; - context.RegisterFunction(info); -} - -void UDFWrapper::RegisterAggrFunction(AggregateFunction aggr_function, ClientContext &context, LogicalType varargs) { - aggr_function.varargs = std::move(varargs); - CreateAggregateFunctionInfo info(std::move(aggr_function)); - context.RegisterFunction(info); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregate_function.cpp b/src/duckdb/src/function/window/window_aggregate_function.cpp deleted file mode 100644 index f08fc5dbc..000000000 --- a/src/duckdb/src/function/window/window_aggregate_function.cpp +++ /dev/null @@ -1,248 +0,0 @@ -#include "duckdb/function/window/window_aggregate_function.hpp" - -#include "duckdb/function/window/window_constant_aggregator.hpp" -#include "duckdb/function/window/window_custom_aggregator.hpp" -#include "duckdb/function/window/window_distinct_aggregator.hpp" -#include "duckdb/function/window/window_naive_aggregator.hpp" -#include "duckdb/function/window/window_segment_tree.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/planner/expression/bound_reference_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowAggregateExecutor -//===--------------------------------------------------------------------===// -class WindowAggregateExecutorGlobalState : public WindowExecutorGlobalState { -public: - WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask); - - // aggregate global state - unique_ptr gsink; - - // the filter reference expression. 
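- // ([editor's note] set from the executor when the window function carries a FILTER clause; nullptr otherwise.)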
- const Expression *filter_ref; -}; - -static BoundWindowExpression &SimplifyWindowedAggregate(BoundWindowExpression &wexpr, ClientContext &context) { - // Remove redundant/irrelevant modifiers (they can be serious performance cliffs) - if (wexpr.aggregate && ClientConfig::GetConfig(context).enable_optimizer) { - const auto &aggr = wexpr.aggregate; - auto &arg_orders = wexpr.arg_orders; - if (aggr->distinct_dependent != AggregateDistinctDependent::DISTINCT_DEPENDENT) { - wexpr.distinct = false; - } - if (aggr->order_dependent != AggregateOrderDependent::ORDER_DEPENDENT) { - arg_orders.clear(); - } else { - // If the argument order is prefix of the partition ordering, - // then we can just use the partition ordering. - if (BoundWindowExpression::GetSharedOrders(wexpr.orders, arg_orders) == arg_orders.size()) { - arg_orders.clear(); - } - } - } - - return wexpr; -} - -WindowAggregateExecutor::WindowAggregateExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared, WindowAggregationMode mode) - : WindowExecutor(SimplifyWindowedAggregate(wexpr, context), context, shared), mode(mode) { - - // Force naive for SEPARATE mode or for (currently!) unsupported functionality - if (!ClientConfig::GetConfig(context).enable_optimizer || mode == WindowAggregationMode::SEPARATE) { - aggregator = make_uniq(*this, shared); - } else if (WindowDistinctAggregator::CanAggregate(wexpr)) { - // build a merge sort tree - // see https://dl.acm.org/doi/pdf/10.1145/3514221.3526184 - aggregator = make_uniq(wexpr, shared, context); - } else if (WindowConstantAggregator::CanAggregate(wexpr)) { - aggregator = make_uniq(wexpr, shared, context); - } else if (WindowCustomAggregator::CanAggregate(wexpr, mode)) { - aggregator = make_uniq(wexpr, shared); - } else if (WindowSegmentTree::CanAggregate(wexpr)) { - // build a segment tree for frame-adhering aggregates - // see http://www.vldb.org/pvldb/vol8/p1058-leis.pdf - aggregator = make_uniq(wexpr, shared); - } else { - // No accelerator can handle this combination, so fall back to naïve. - aggregator = make_uniq(*this, shared); - } - - // Compute the FILTER with the other eval columns. - // Anyone who needs it can then convert it to the form they need. 
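- // [editor's note, illustrative] e.g. SUM(x) FILTER (WHERE y > 0) OVER (PARTITION BY g) registers the y > 0 predicate as a shared sink expression here.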
- if (wexpr.filter_expr) { - const auto filter_idx = shared.RegisterSink(wexpr.filter_expr); - filter_ref = make_uniq(wexpr.filter_expr->return_type, filter_idx); - } -} - -WindowAggregateExecutorGlobalState::WindowAggregateExecutorGlobalState(const WindowAggregateExecutor &executor, - const idx_t group_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, group_count, partition_mask, order_mask), - filter_ref(executor.filter_ref.get()) { - gsink = executor.aggregator->GetGlobalState(executor.context, group_count, partition_mask); -} - -unique_ptr WindowAggregateExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -class WindowAggregateExecutorLocalState : public WindowExecutorBoundsState { -public: - WindowAggregateExecutorLocalState(const WindowExecutorGlobalState &gstate, const WindowAggregator &aggregator) - : WindowExecutorBoundsState(gstate), filter_executor(gstate.executor.context) { - - auto &gastate = gstate.Cast(); - aggregator_state = aggregator.GetLocalState(*gastate.gsink); - - // evaluate the FILTER clause and stuff it into a large mask for compactness and reuse - auto filter_ref = gastate.filter_ref; - if (filter_ref) { - filter_executor.AddExpression(*filter_ref); - filter_sel.Initialize(STANDARD_VECTOR_SIZE); - } - } - -public: - // state of aggregator - unique_ptr aggregator_state; - //! Executor for any filter clause - ExpressionExecutor filter_executor; - //! Result of filtering - SelectionVector filter_sel; -}; - -unique_ptr -WindowAggregateExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate, *aggregator); -} - -void WindowAggregateExecutor::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); - auto &filter_sel = lastate.filter_sel; - auto &filter_executor = lastate.filter_executor; - - idx_t filtered = 0; - SelectionVector *filtering = nullptr; - if (gastate.filter_ref) { - filtering = &filter_sel; - filtered = filter_executor.SelectExpression(sink_chunk, filter_sel); - } - - D_ASSERT(aggregator); - auto &gestate = *gastate.gsink; - auto &lestate = *lastate.aggregator_state; - aggregator->Sink(gestate, lestate, sink_chunk, coll_chunk, input_idx, filtering, filtered); - - WindowExecutor::Sink(sink_chunk, coll_chunk, input_idx, gstate, lstate); -} - -static void ApplyWindowStats(const WindowBoundary &boundary, FrameDelta &delta, BaseStatistics *base, bool is_start) { - // Avoid overflow by clamping to the frame bounds - auto base_stats = delta; - - switch (boundary) { - case WindowBoundary::UNBOUNDED_PRECEDING: - if (is_start) { - delta.end = 0; - return; - } - break; - case WindowBoundary::UNBOUNDED_FOLLOWING: - if (!is_start) { - delta.begin = 0; - return; - } - break; - case WindowBoundary::CURRENT_ROW_ROWS: - delta.begin = delta.end = 0; - return; - case WindowBoundary::EXPR_PRECEDING_ROWS: - if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { - // Preceding so negative offset from current row - base_stats.begin = NumericStats::GetMin(*base); - base_stats.end = NumericStats::GetMax(*base); - if (delta.begin < base_stats.end && base_stats.end < delta.end) { - delta.begin = -base_stats.end; - 
} - if (delta.begin < base_stats.begin && base_stats.begin < delta.end) { - delta.end = -base_stats.begin + 1; - } - } - return; - case WindowBoundary::EXPR_FOLLOWING_ROWS: - if (base && base->GetStatsType() == StatisticsType::NUMERIC_STATS && NumericStats::HasMinMax(*base)) { - base_stats.begin = NumericStats::GetMin(*base); - base_stats.end = NumericStats::GetMax(*base); - if (base_stats.end < delta.end) { - delta.end = base_stats.end + 1; - } - } - return; - - case WindowBoundary::CURRENT_ROW_RANGE: - case WindowBoundary::EXPR_PRECEDING_RANGE: - case WindowBoundary::EXPR_FOLLOWING_RANGE: - return; - default: - break; - } - - if (is_start) { - throw InternalException("Unsupported window start boundary"); - } else { - throw InternalException("Unsupported window end boundary"); - } -} - -void WindowAggregateExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { - WindowExecutor::Finalize(gstate, lstate, collection); - - auto &gastate = gstate.Cast(); - auto &gsink = gastate.gsink; - D_ASSERT(aggregator); - - // Estimate the frame statistics - // Default to the entire partition if we don't know anything - FrameStats stats; - const auto count = NumericCast(gastate.payload_count); - - // First entry is the frame start - stats[0] = FrameDelta(-count, count); - auto base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[0].get(); - ApplyWindowStats(wexpr.start, stats[0], base, true); - - // Second entry is the frame end - stats[1] = FrameDelta(-count, count); - base = wexpr.expr_stats.empty() ? nullptr : wexpr.expr_stats[1].get(); - ApplyWindowStats(wexpr.end, stats[1], base, false); - - auto &lastate = lstate.Cast(); - aggregator->Finalize(*gsink, *lastate.aggregator_state, collection, stats); -} - -void WindowAggregateExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); - auto &gsink = gastate.gsink; - D_ASSERT(aggregator); - - auto &agg_state = *lastate.aggregator_state; - - aggregator->Evaluate(*gsink, agg_state, lastate.bounds, result, count, row_idx); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregate_states.cpp b/src/duckdb/src/function/window/window_aggregate_states.cpp deleted file mode 100644 index 2a673ec54..000000000 --- a/src/duckdb/src/function/window/window_aggregate_states.cpp +++ /dev/null @@ -1,48 +0,0 @@ -#include "duckdb/function/window/window_aggregate_states.hpp" - -namespace duckdb { - -WindowAggregateStates::WindowAggregateStates(const AggregateObject &aggr) - : aggr(aggr), state_size(aggr.function.state_size(aggr.function)), allocator(Allocator::DefaultAllocator()) { -} - -void WindowAggregateStates::Initialize(idx_t count) { - states.resize(count * state_size); - auto state_ptr = states.data(); - - statef = make_uniq(LogicalType::POINTER, count); - auto state_f_data = FlatVector::GetData(*statef); - - for (idx_t i = 0; i < count; ++i, state_ptr += state_size) { - state_f_data[i] = state_ptr; - aggr.function.initialize(aggr.function, state_ptr); - } - - // Prevent conversion of results to constants - statef->SetVectorType(VectorType::FLAT_VECTOR); -} - -void WindowAggregateStates::Combine(WindowAggregateStates &target, AggregateCombineType combine_type) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator, AggregateCombineType::ALLOW_DESTRUCTIVE); - 
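// [editor's note, rationale inferred from the flag name] ALLOW_DESTRUCTIVE permits the combine to move data out of the source states, which are not reused afterwards. -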
aggr.function.combine(*statef, *target.statef, aggr_input_data, GetCount()); -} - -void WindowAggregateStates::Finalize(Vector &result) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(*statef, aggr_input_data, result, GetCount(), 0); -} - -void WindowAggregateStates::Destroy() { - if (states.empty()) { - return; - } - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - if (aggr.function.destructor) { - aggr.function.destructor(*statef, aggr_input_data, GetCount()); - } - - states.clear(); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_aggregator.cpp b/src/duckdb/src/function/window/window_aggregator.cpp deleted file mode 100644 index 197b89e38..000000000 --- a/src/duckdb/src/function/window/window_aggregator.cpp +++ /dev/null @@ -1,88 +0,0 @@ -#include "duckdb/function/window/window_aggregator.hpp" - -#include "duckdb/function/window/window_collection.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowAggregator -//===--------------------------------------------------------------------===// -WindowAggregatorState::WindowAggregatorState() : allocator(Allocator::DefaultAllocator()) { -} - -WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr) - : wexpr(wexpr), aggr(wexpr), result_type(wexpr.return_type), state_size(aggr.function.state_size(aggr.function)), - exclude_mode(wexpr.exclude_clause) { - - for (auto &child : wexpr.children) { - arg_types.emplace_back(child->return_type); - } -} - -WindowAggregator::WindowAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared) - : WindowAggregator(wexpr) { - for (auto &child : wexpr.children) { - child_idx.emplace_back(shared.RegisterCollection(child, false)); - } -} - -WindowAggregator::~WindowAggregator() { -} - -unique_ptr<WindowAggregatorState> WindowAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &) const { - return make_uniq<WindowAggregatorGlobalState>(context, *this, group_count); -} - -void WindowAggregatorLocalState::Sink(WindowAggregatorGlobalState &gastate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx) { -} - -void WindowAggregator::Sink(WindowAggregatorState &gstate, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr<SelectionVector> filter_sel, - idx_t filtered) { - auto &gastate = gstate.Cast<WindowAggregatorGlobalState>(); - auto &lastate = lstate.Cast<WindowAggregatorLocalState>(); - lastate.Sink(gastate, sink_chunk, coll_chunk, input_idx); - if (filter_sel) { - auto &filter_mask = gastate.filter_mask; - for (idx_t f = 0; f < filtered; ++f) { - filter_mask.SetValid(input_idx + filter_sel->get_index(f)); - } - } -} - -void WindowAggregatorLocalState::InitSubFrames(SubFrames &frames, const WindowExcludeMode exclude_mode) { - idx_t nframes = 0; - switch (exclude_mode) { - case WindowExcludeMode::NO_OTHER: - nframes = 1; - break; - case WindowExcludeMode::TIES: - nframes = 3; - break; - case WindowExcludeMode::CURRENT_ROW: - case WindowExcludeMode::GROUP: - nframes = 2; - break; - } - frames.resize(nframes, {0, 0}); -} - -void WindowAggregatorLocalState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - // Prepare to scan - if (!cursor) { - cursor = make_uniq<WindowCursor>(*collection, gastate.aggregator.child_idx); - } -} - -void WindowAggregator::Finalize(WindowAggregatorState &gstate, 
WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) { - auto &gasink = gstate.Cast(); - auto &lastate = lstate.Cast(); - lastate.Finalize(gasink, collection); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_boundaries_state.cpp b/src/duckdb/src/function/window/window_boundaries_state.cpp deleted file mode 100644 index 7727e2a63..000000000 --- a/src/duckdb/src/function/window/window_boundaries_state.cpp +++ /dev/null @@ -1,854 +0,0 @@ -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/function/window/window_boundaries_state.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowBoundariesState -//===--------------------------------------------------------------------===// -idx_t WindowBoundariesState::FindNextStart(const ValidityMask &mask, idx_t l, const idx_t r, idx_t &n) { - if (mask.AllValid()) { - auto start = MinValue(l + n - 1, r); - n -= MinValue(n, r - l); - return start; - } - - while (l < r) { - // If l is aligned with the start of a block, and the block is blank, then skip forward one block. - idx_t entry_idx; - idx_t shift; - mask.GetEntryIndex(l, entry_idx, shift); - - const auto block = mask.GetValidityEntry(entry_idx); - if (mask.NoneValid(block) && !shift) { - l += ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop over the block - for (; shift < ValidityMask::BITS_PER_VALUE && l < r; ++shift, ++l) { - if (mask.RowIsValid(block, shift) && --n == 0) { - return MinValue(l, r); - } - } - } - - // Didn't find a start so return the end of the range - return r; -} - -idx_t WindowBoundariesState::FindPrevStart(const ValidityMask &mask, const idx_t l, idx_t r, idx_t &n) { - if (mask.AllValid()) { - auto start = (r <= l + n) ? l : r - n; - n -= r - start; - return start; - } - - while (l < r) { - // If r is aligned with the start of a block, and the previous block is blank, - // then skip backwards one block. - idx_t entry_idx; - idx_t shift; - mask.GetEntryIndex(r - 1, entry_idx, shift); - - const auto block = mask.GetValidityEntry(entry_idx); - if (mask.NoneValid(block) && (shift + 1 == ValidityMask::BITS_PER_VALUE)) { - // r is nonzero (> l) and word aligned, so this will not underflow. 
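- // ([editor's note] the whole validity word is zero here, so the backward scan skips all BITS_PER_VALUE bits at once instead of testing them individually.)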
- r -= ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop backwards over the block - // shift is probing r-1 >= l >= 0 - for (++shift; shift-- > 0 && l < r; --r) { - // l < r ensures n == 1 if result is supposed to be NULL because of EXCLUDE - if (mask.RowIsValid(block, shift) && --n == 0) { - return MaxValue(l, r - 1); - } - } - } - - // Didn't find a start so return the start of the range - return l; -} - -//===--------------------------------------------------------------------===// -// WindowColumnIterator -//===--------------------------------------------------------------------===// -template -struct WindowColumnIterator { - using iterator = WindowColumnIterator; - using iterator_category = std::random_access_iterator_tag; - using difference_type = std::ptrdiff_t; - using value_type = T; - using reference = T; - using pointer = idx_t; - - explicit WindowColumnIterator(WindowCursor &coll, pointer pos = 0) : coll(&coll), pos(pos) { - } - - // Forward iterator - inline reference operator*() const { - return coll->GetCell(0, pos); - } - inline explicit operator pointer() const { - return pos; - } - - inline iterator &operator++() { - ++pos; - return *this; - } - inline iterator operator++(int) { - auto result = *this; - ++(*this); - return result; - } - - // Bidirectional iterator - inline iterator &operator--() { - --pos; - return *this; - } - inline iterator operator--(int) { - auto result = *this; - --(*this); - return result; - } - - // Random Access - inline iterator &operator+=(difference_type n) { - pos += UnsafeNumericCast(n); - return *this; - } - inline iterator &operator-=(difference_type n) { - pos -= UnsafeNumericCast(n); - return *this; - } - - inline reference operator[](difference_type m) const { - return coll->GetCell(0, pos + m); - } - - friend inline iterator &operator+(const iterator &a, difference_type n) { - return iterator(a.coll, a.pos + n); - } - - friend inline iterator operator-(const iterator &a, difference_type n) { - return iterator(a.coll, a.pos - n); - } - - friend inline iterator operator+(difference_type n, const iterator &a) { - return a + n; - } - friend inline difference_type operator-(const iterator &a, const iterator &b) { - return difference_type(a.pos - b.pos); - } - - friend inline bool operator==(const iterator &a, const iterator &b) { - return a.pos == b.pos; - } - friend inline bool operator!=(const iterator &a, const iterator &b) { - return a.pos != b.pos; - } - friend inline bool operator<(const iterator &a, const iterator &b) { - return a.pos < b.pos; - } - friend inline bool operator<=(const iterator &a, const iterator &b) { - return a.pos <= b.pos; - } - friend inline bool operator>(const iterator &a, const iterator &b) { - return a.pos > b.pos; - } - friend inline bool operator>=(const iterator &a, const iterator &b) { - return a.pos >= b.pos; - } - -private: - // optional_ptr does not allow us to modify this, but the constructor enforces it. 
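- // [editor's note, added] A plain pointer keeps the iterator copy-assignable, which the std::lower_bound/std::upper_bound searches below require.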
- WindowCursor *coll; - pointer pos; -}; - -template -struct OperationCompare : public std::function { - inline bool operator()(const T &lhs, const T &val) const { - return OP::template Operation(lhs, val); - } -}; - -template -static idx_t FindTypedRangeBound(WindowCursor &over, const idx_t order_begin, const idx_t order_end, - const WindowBoundary range, WindowInputExpression &boundary, const idx_t chunk_idx, - const FrameBounds &prev) { - D_ASSERT(!boundary.CellIsNull(chunk_idx)); - const auto val = boundary.GetCell(chunk_idx); - - OperationCompare comp; - - // Check that the value we are searching for is in range. - if (range == WindowBoundary::EXPR_PRECEDING_RANGE) { - // Preceding but value past the current value - const auto cur_val = over.GetCell(0, order_end - 1); - if (comp(cur_val, val)) { - throw OutOfRangeException("Invalid RANGE PRECEDING value"); - } - } else { - // Following but value before the current value - D_ASSERT(range == WindowBoundary::EXPR_FOLLOWING_RANGE); - const auto cur_val = over.GetCell(0, order_begin); - if (comp(val, cur_val)) { - throw OutOfRangeException("Invalid RANGE FOLLOWING value"); - } - } - - // Try to reuse the previous bounds to restrict the search. - // This is only valid if the previous bounds were non-empty - // Only inject the comparisons if the previous bounds are a strict subset. - WindowColumnIterator begin(over, order_begin); - WindowColumnIterator end(over, order_end); - if (prev.start < prev.end) { - if (order_begin < prev.start && prev.start < order_end) { - const auto first = over.GetCell(0, prev.start); - if (!comp(val, first)) { - // prev.first <= val, so we can start further forward - begin += UnsafeNumericCast(prev.start - order_begin); - } - } - if (order_begin < prev.end && prev.end < order_end) { - const auto second = over.GetCell(0, prev.end - 1); - if (!comp(second, val)) { - // val <= prev.second, so we can end further back - // (prev.second is the largest peer) - end -= UnsafeNumericCast(order_end - prev.end - 1); - } - } - } - - if (FROM) { - return idx_t(std::lower_bound(begin, end, val, comp)); - } else { - return idx_t(std::upper_bound(begin, end, val, comp)); - } -} - -template -static idx_t FindRangeBound(WindowCursor &over, const idx_t order_begin, const idx_t order_end, - const WindowBoundary range, WindowInputExpression &boundary, const idx_t chunk_idx, - const FrameBounds &prev) { - switch (boundary.InternalType()) { - case PhysicalType::INT8: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::INT16: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::INT32: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::INT64: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::UINT8: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::UINT16: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::UINT32: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::UINT64: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::INT128: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case 
PhysicalType::UINT128: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, - prev); - case PhysicalType::FLOAT: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::DOUBLE: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case PhysicalType::INTERVAL: - return FindTypedRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, - prev); - default: - throw InternalException("Unsupported column type for RANGE"); - } -} - -template -static idx_t FindOrderedRangeBound(WindowCursor &over, const OrderType range_sense, const idx_t order_begin, - const idx_t order_end, const WindowBoundary range, WindowInputExpression &boundary, - const idx_t chunk_idx, const FrameBounds &prev) { - switch (range_sense) { - case OrderType::ASCENDING: - return FindRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - case OrderType::DESCENDING: - return FindRangeBound(over, order_begin, order_end, range, boundary, chunk_idx, prev); - default: - throw InternalException("Unsupported ORDER BY sense for RANGE"); - } -} - -bool WindowBoundariesState::HasPrecedingRange(const BoundWindowExpression &wexpr) { - return (wexpr.start == WindowBoundary::EXPR_PRECEDING_RANGE || wexpr.end == WindowBoundary::EXPR_PRECEDING_RANGE); -} - -bool WindowBoundariesState::HasFollowingRange(const BoundWindowExpression &wexpr) { - return (wexpr.start == WindowBoundary::EXPR_FOLLOWING_RANGE || wexpr.end == WindowBoundary::EXPR_FOLLOWING_RANGE); -} - -WindowBoundsSet WindowBoundariesState::GetWindowBounds(const BoundWindowExpression &wexpr) { - const auto partition_count = wexpr.partitions.size(); - const auto order_count = wexpr.orders.size(); - - WindowBoundsSet result; - switch (wexpr.GetExpressionType()) { - case ExpressionType::WINDOW_ROW_NUMBER: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - } else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_NTILE: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - } else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_RANK: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - result.insert(PEER_BEGIN); - } else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_RANK_DENSE: - result.insert(PARTITION_BEGIN); - result.insert(PEER_BEGIN); - break; - case ExpressionType::WINDOW_PERCENT_RANK: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - result.insert(PEER_BEGIN); - } else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_CUME_DIST: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - result.insert(PEER_END); - } else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_LEAD: - case ExpressionType::WINDOW_LAG: - if (wexpr.arg_orders.empty()) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - } 
else { - // Secondary orders need to know where the frame is - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - } - break; - case ExpressionType::WINDOW_FIRST_VALUE: - case ExpressionType::WINDOW_LAST_VALUE: - case ExpressionType::WINDOW_NTH_VALUE: - case ExpressionType::WINDOW_AGGREGATE: - result.insert(FRAME_BEGIN); - result.insert(FRAME_END); - break; - default: - throw InternalException("Window aggregate type %s", ExpressionTypeToString(wexpr.GetExpressionType())); - } - - // Internal dependencies - if (result.count(FRAME_BEGIN) || result.count(FRAME_END)) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - - // if we have EXCLUDE GROUP / TIES, we also need peer boundaries - if (wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { - result.insert(PEER_BEGIN); - result.insert(PEER_END); - } - - // If the frames are RANGE, then we need peer boundaries - // If they are preceding or following, we also need to know - // where the valid values begin or end. - switch (wexpr.start) { - case WindowBoundary::CURRENT_ROW_RANGE: - result.insert(PEER_BEGIN); - break; - case WindowBoundary::EXPR_PRECEDING_RANGE: - result.insert(PEER_BEGIN); - result.insert(VALID_BEGIN); - result.insert(VALID_END); - break; - case WindowBoundary::EXPR_FOLLOWING_RANGE: - result.insert(PEER_BEGIN); - result.insert(VALID_END); - break; - default: - break; - } - - switch (wexpr.end) { - case WindowBoundary::CURRENT_ROW_RANGE: - result.insert(PEER_END); - break; - case WindowBoundary::EXPR_PRECEDING_RANGE: - result.insert(PEER_END); - result.insert(VALID_BEGIN); - break; - case WindowBoundary::EXPR_FOLLOWING_RANGE: - result.insert(PEER_END); - result.insert(VALID_BEGIN); - result.insert(VALID_END); - break; - default: - break; - } - } - - if (result.count(VALID_END)) { - result.insert(PARTITION_END); - if (HasFollowingRange(wexpr)) { - result.insert(VALID_BEGIN); - } - } - if (result.count(VALID_BEGIN)) { - result.insert(PARTITION_BEGIN); - result.insert(PARTITION_END); - } - if (result.count(PEER_END)) { - result.insert(PARTITION_END); - if (order_count) { - result.insert(PEER_BEGIN); - } - } - if (result.count(PARTITION_END) && (partition_count + order_count)) { - result.insert(PARTITION_BEGIN); - } - - return result; -} - -WindowBoundariesState::WindowBoundariesState(const BoundWindowExpression &wexpr, const idx_t input_size) - : required(GetWindowBounds(wexpr)), type(wexpr.GetExpressionType()), input_size(input_size), - start_boundary(wexpr.start), end_boundary(wexpr.end), partition_count(wexpr.partitions.size()), - order_count(wexpr.orders.size()), range_sense(wexpr.orders.empty() ? 
OrderType::INVALID : wexpr.orders[0].type), - has_preceding_range(HasPrecedingRange(wexpr)), has_following_range(HasFollowingRange(wexpr)) { -} - -void WindowBoundariesState::Bounds(DataChunk &bounds, idx_t row_idx, optional_ptr range, - const idx_t count, WindowInputExpression &boundary_start, - WindowInputExpression &boundary_end, const ValidityMask &partition_mask, - const ValidityMask &order_mask) { - bounds.Reset(); - D_ASSERT(bounds.ColumnCount() == 8); - - const auto is_jump = (next_pos != row_idx); - if (required.count(PARTITION_BEGIN)) { - PartitionBegin(bounds, row_idx, count, is_jump, partition_mask); - } - if (required.count(PARTITION_END)) { - PartitionEnd(bounds, row_idx, count, is_jump, partition_mask); - } - if (required.count(PEER_BEGIN)) { - PeerBegin(bounds, row_idx, count, is_jump, partition_mask, order_mask); - } - if (required.count(PEER_END)) { - PeerEnd(bounds, row_idx, count, partition_mask, order_mask); - } - if (required.count(VALID_BEGIN)) { - ValidBegin(bounds, row_idx, count, is_jump, partition_mask, order_mask, range); - } - if (required.count(VALID_END)) { - ValidEnd(bounds, row_idx, count, is_jump, partition_mask, order_mask, range); - } - if (required.count(FRAME_BEGIN)) { - FrameBegin(bounds, row_idx, count, boundary_start, range); - } - if (required.count(FRAME_END)) { - FrameEnd(bounds, row_idx, count, boundary_end, range); - } - next_pos += count; - - bounds.SetCardinality(count); -} - -void WindowBoundariesState::PartitionBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, - const ValidityMask &partition_mask) { - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - - // OVER() - if (partition_count + order_count == 0) { - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - partition_begin_data[chunk_idx] = 0; - } - return; - } - - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - // determine partition and peer group boundaries to ultimately figure out window size - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - - // when the partition changes, recompute the boundaries - if (!is_same_partition || is_jump) { - if (is_jump) { - idx_t n = 1; - partition_start = FindPrevStart(partition_mask, 0, row_idx + 1, n); - is_jump = false; - } else { - partition_start = row_idx; - } - } - - partition_begin_data[chunk_idx] = partition_start; - } -} - -void WindowBoundariesState::PartitionEnd(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, - const ValidityMask &partition_mask) { - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - - // OVER() - if (partition_count + order_count == 0) { - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - partition_end_data[chunk_idx] = input_size; - } - return; - } - - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - // determine partition and peer group boundaries to ultimately figure out window size - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - - // when the partition changes, recompute the boundaries - if (!is_same_partition || is_jump) { - // find end of partition - partition_end = input_size; - if (partition_count) { - const auto partition_begin = partition_begin_data[chunk_idx]; - idx_t n = 1; - partition_end = FindNextStart(partition_mask, partition_begin + 1, input_size, n); - } - is_jump = false; - } - - 
partition_end_data[chunk_idx] = partition_end; - } -} - -void WindowBoundariesState::PeerBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, - const ValidityMask &partition_mask, const ValidityMask &order_mask) { - - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); - - // OVER() - if (partition_count + order_count == 0) { - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - peer_begin_data[chunk_idx] = 0; - } - return; - } - - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - // determine partition and peer group boundaries to ultimately figure out window size - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - const auto is_peer = !order_mask.RowIsValidUnsafe(row_idx); - - // when the partition changes, recompute the boundaries - if (!is_same_partition || is_jump) { - // find end of partition - if (is_jump) { - idx_t n = 1; - peer_start = FindPrevStart(order_mask, 0, row_idx + 1, n); - } else { - peer_start = row_idx; - } - is_jump = false; - } else if (!is_peer) { - peer_start = row_idx; - } - - peer_begin_data[chunk_idx] = peer_start; - } -} - -void WindowBoundariesState::PeerEnd(DataChunk &bounds, idx_t row_idx, const idx_t count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) { - // OVER() - if (!order_count) { - bounds.data[PEER_END].Reference(bounds.data[PARTITION_END]); - return; - } - - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); - auto peer_end_data = FlatVector::GetData(bounds.data[PEER_END]); - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - idx_t n = 1; - const auto peer_start = peer_begin_data[chunk_idx]; - const auto partition_end = partition_end_data[chunk_idx]; - peer_end_data[chunk_idx] = FindNextStart(order_mask, peer_start + 1, partition_end, n); - } -} - -void WindowBoundariesState::ValidBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, - const ValidityMask &partition_mask, const ValidityMask &order_mask, - optional_ptr range) { - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto valid_begin_data = FlatVector::GetData(bounds.data[VALID_BEGIN]); - - // OVER() - D_ASSERT(partition_count + order_count != 0); - D_ASSERT(range); - - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - - if (!is_same_partition || is_jump) { - // Find valid ordering values for the new partition - // so we can exclude NULLs from RANGE expression computations - valid_start = partition_begin_data[chunk_idx]; - const auto valid_end = partition_end_data[chunk_idx]; - - if ((valid_start < valid_end) && has_preceding_range) { - // Exclude any leading NULLs - if (range->CellIsNull(0, valid_start)) { - idx_t n = 1; - valid_start = FindNextStart(order_mask, valid_start + 1, valid_end, n); - } - } - } - - valid_begin_data[chunk_idx] = valid_start; - } -} - -void WindowBoundariesState::ValidEnd(DataChunk &bounds, idx_t row_idx, const idx_t count, bool is_jump, - const ValidityMask &partition_mask, const ValidityMask &order_mask, - optional_ptr range) { - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto valid_begin_data = FlatVector::GetData(bounds.data[VALID_BEGIN]); - auto 
valid_end_data = FlatVector::GetData(bounds.data[VALID_END]); - - // OVER() - D_ASSERT(partition_count + order_count != 0); - D_ASSERT(range); - - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - const auto is_same_partition = !partition_mask.RowIsValidUnsafe(row_idx); - - if (!is_same_partition || is_jump) { - // Find valid ordering values for the new partition - // so we can exclude NULLs from RANGE expression computations - valid_end = partition_end_data[chunk_idx]; - - if ((valid_start < valid_end) && has_following_range) { - // Exclude any trailing NULLs - const auto valid_start = valid_begin_data[chunk_idx]; - if (range->CellIsNull(0, valid_end - 1)) { - idx_t n = 1; - valid_end = FindPrevStart(order_mask, valid_start, valid_end, n); - } - - // Reset range hints - prev.start = valid_start; - prev.end = valid_end; - } - } - - valid_end_data[chunk_idx] = valid_end; - } -} - -void WindowBoundariesState::FrameBegin(DataChunk &bounds, idx_t row_idx, const idx_t count, - WindowInputExpression &boundary_begin, optional_ptr range) { - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto peer_begin_data = FlatVector::GetData(bounds.data[PEER_BEGIN]); - auto valid_begin_data = FlatVector::GetData(bounds.data[VALID_BEGIN]); - auto valid_end_data = FlatVector::GetData(bounds.data[VALID_END]); - auto frame_begin_data = FlatVector::GetData(bounds.data[FRAME_BEGIN]); - - idx_t window_start = NumericLimits::Maximum(); - - switch (start_boundary) { - case WindowBoundary::UNBOUNDED_PRECEDING: - bounds.data[FRAME_BEGIN].Reference(bounds.data[PARTITION_BEGIN]); - // No need to clamp - return; - case WindowBoundary::CURRENT_ROW_ROWS: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - frame_begin_data[chunk_idx] = row_idx; - } - break; - case WindowBoundary::CURRENT_ROW_RANGE: - bounds.data[FRAME_BEGIN].Reference(bounds.data[PEER_BEGIN]); - frame_begin_data = peer_begin_data; - break; - case WindowBoundary::EXPR_PRECEDING_ROWS: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - int64_t computed_start; - if (!TrySubtractOperator::Operation(static_cast(row_idx), - boundary_begin.GetCell(chunk_idx), computed_start)) { - window_start = partition_begin_data[chunk_idx]; - } else { - window_start = UnsafeNumericCast(MaxValue(computed_start, 0)); - } - frame_begin_data[chunk_idx] = window_start; - } - break; - case WindowBoundary::EXPR_FOLLOWING_ROWS: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - int64_t computed_start; - if (!TryAddOperator::Operation(static_cast(row_idx), boundary_begin.GetCell(chunk_idx), - computed_start)) { - window_start = partition_begin_data[chunk_idx]; - } else { - window_start = UnsafeNumericCast(MaxValue(computed_start, 0)); - } - frame_begin_data[chunk_idx] = window_start; - } - break; - case WindowBoundary::EXPR_PRECEDING_RANGE: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - if (boundary_begin.CellIsNull(chunk_idx)) { - window_start = peer_begin_data[chunk_idx]; - } else { - const auto valid_start = valid_begin_data[chunk_idx]; - prev.end = valid_end_data[chunk_idx]; - window_start = FindOrderedRangeBound(*range, range_sense, valid_start, row_idx + 1, - start_boundary, boundary_begin, chunk_idx, prev); - prev.start = window_start; - } - frame_begin_data[chunk_idx] = window_start; - } - break; - case WindowBoundary::EXPR_FOLLOWING_RANGE: - for (idx_t 
chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - if (boundary_begin.CellIsNull(chunk_idx)) { - window_start = peer_begin_data[chunk_idx]; - } else { - const auto valid_end = valid_end_data[chunk_idx]; - prev.end = valid_end; - window_start = FindOrderedRangeBound(*range, range_sense, row_idx, valid_end, start_boundary, - boundary_begin, chunk_idx, prev); - prev.start = window_start; - } - frame_begin_data[chunk_idx] = window_start; - } - break; - default: - throw InternalException("Unsupported window start boundary"); - } - - ClampFrame(count, frame_begin_data, partition_begin_data, partition_end_data); -} - -void WindowBoundariesState::FrameEnd(DataChunk &bounds, idx_t row_idx, const idx_t count, - WindowInputExpression &boundary_end, optional_ptr range) { - auto partition_begin_data = FlatVector::GetData(bounds.data[PARTITION_BEGIN]); - auto partition_end_data = FlatVector::GetData(bounds.data[PARTITION_END]); - auto peer_end_data = FlatVector::GetData(bounds.data[PEER_END]); - auto valid_begin_data = FlatVector::GetData(bounds.data[VALID_BEGIN]); - auto valid_end_data = FlatVector::GetData(bounds.data[VALID_END]); - auto frame_end_data = FlatVector::GetData(bounds.data[FRAME_END]); - - idx_t window_end = NumericLimits::Maximum(); - - switch (end_boundary) { - case WindowBoundary::CURRENT_ROW_ROWS: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - frame_end_data[chunk_idx] = row_idx + 1; - } - break; - case WindowBoundary::CURRENT_ROW_RANGE: - bounds.data[FRAME_END].Reference(bounds.data[PEER_END]); - frame_end_data = peer_end_data; - break; - case WindowBoundary::UNBOUNDED_FOLLOWING: - bounds.data[FRAME_END].Reference(bounds.data[PARTITION_END]); - // No need to clamp - return; - case WindowBoundary::EXPR_PRECEDING_ROWS: { - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - int64_t computed_start; - if (!TrySubtractOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), - computed_start)) { - window_end = partition_end_data[chunk_idx]; - } else { - window_end = UnsafeNumericCast(MaxValue(computed_start, 0)); - } - frame_end_data[chunk_idx] = window_end; - } - break; - } - case WindowBoundary::EXPR_FOLLOWING_ROWS: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - int64_t computed_start; - if (!TryAddOperator::Operation(int64_t(row_idx + 1), boundary_end.GetCell(chunk_idx), - computed_start)) { - window_end = partition_end_data[chunk_idx]; - } else { - window_end = UnsafeNumericCast(MaxValue(computed_start, 0)); - } - frame_end_data[chunk_idx] = window_end; - } - break; - case WindowBoundary::EXPR_PRECEDING_RANGE: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end_data[chunk_idx]; - } else { - const auto valid_start = valid_begin_data[chunk_idx]; - prev.start = valid_start; - window_end = FindOrderedRangeBound(*range, range_sense, valid_start, row_idx + 1, end_boundary, - boundary_end, chunk_idx, prev); - prev.end = window_end; - } - frame_end_data[chunk_idx] = window_end; - } - break; - case WindowBoundary::EXPR_FOLLOWING_RANGE: - for (idx_t chunk_idx = 0; chunk_idx < count; ++chunk_idx, ++row_idx) { - if (boundary_end.CellIsNull(chunk_idx)) { - window_end = peer_end_data[chunk_idx]; - } else { - const auto valid_end = valid_end_data[chunk_idx]; - prev.start = valid_begin_data[chunk_idx]; - window_end = FindOrderedRangeBound(*range, range_sense, row_idx, valid_end, end_boundary, - boundary_end, 
chunk_idx, prev); - prev.end = window_end; - } - frame_end_data[chunk_idx] = window_end; - } - break; - default: - throw InternalException("Unsupported window end boundary"); - } - - ClampFrame(count, frame_end_data, partition_begin_data, partition_end_data); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_collection.cpp b/src/duckdb/src/function/window/window_collection.cpp deleted file mode 100644 index 0dee0cc84..000000000 --- a/src/duckdb/src/function/window/window_collection.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "duckdb/function/window/window_collection.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowCollection -//===--------------------------------------------------------------------===// -WindowCollection::WindowCollection(BufferManager &buffer_manager, idx_t count, const vector &types) - : all_valids(types.size()), types(types), count(count), buffer_manager(buffer_manager) { - if (!types.empty()) { - inputs = make_uniq(buffer_manager, types); - } - - validities.resize(types.size()); - - // Atomic vectors can't be constructed with a given value - for (auto &all_valid : all_valids) { - all_valid = true; - } -} - -void WindowCollection::GetCollection(idx_t row_idx, ColumnDataCollectionSpec &spec) { - if (spec.second && row_idx == spec.first + spec.second->Count()) { - return; - } - - lock_guard collection_guard(lock); - - auto collection = make_uniq(buffer_manager, types); - spec = {row_idx, collection.get()}; - Range probe {row_idx, collections.size()}; - auto i = std::upper_bound(ranges.begin(), ranges.end(), probe); - ranges.insert(i, probe); - collections.emplace_back(std::move(collection)); -} - -void WindowCollection::Combine(const ColumnSet &validity_cols) { - lock_guard collection_guard(lock); - - // If there are no columns (COUNT(*)) then this is a NOP - if (types.empty()) { - return; - } - - // Have we already combined? 
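A note on the pattern in GetCollection above: each builder registers its chunk of rows under the global row offset where that chunk starts, and the (start row, collection index) pairs are kept sorted by start row, so the combine step below can stitch the pieces together in row order no matter which thread finished first. A minimal sketch of the trick, with illustrative names (the real method also serializes registration behind the collection lock):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Each writer thread appends rows starting at a known global row offset.
// Registering (start_row, chunk_id) in a vector kept sorted by start_row
// makes the final combine order deterministic. (Illustrative stand-in.)
struct OrderedChunks {
	using Range = std::pair<std::size_t, std::size_t>; // (start_row, chunk_id)
	std::vector<Range> ranges;
	std::vector<std::vector<int>> chunks;

	void Register(std::size_t start_row, std::vector<int> chunk) {
		Range probe {start_row, chunks.size()};
		auto it = std::upper_bound(ranges.begin(), ranges.end(), probe);
		ranges.insert(it, probe);
		chunks.emplace_back(std::move(chunk));
	}

	// Concatenate in row order by walking the sorted ranges.
	std::vector<int> Combine() const {
		std::vector<int> out;
		for (const auto &r : ranges) {
			const auto &c = chunks[r.second];
			out.insert(out.end(), c.begin(), c.end());
		}
		return out;
	}
};

int main() {
	OrderedChunks oc;
	oc.Register(4, {40, 50, 60}); // the second block arrives first
	oc.Register(0, {0, 10, 20, 30});
	auto all = oc.Combine(); // 0 10 20 30 40 50 60
	return all.size() == 7 ? 0 : 1;
}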
- if (inputs->Count()) { - D_ASSERT(collections.empty()); - D_ASSERT(ranges.empty()); - return; - } - - // If there are columns, we should have data - D_ASSERT(!collections.empty()); - D_ASSERT(!ranges.empty()); - - for (auto &range : ranges) { - inputs->Combine(*collections[range.second]); - } - collections.clear(); - ranges.clear(); - - if (validity_cols.empty()) { - return; - } - - D_ASSERT(inputs.get()); - - // Find all columns with NULLs - vector invalid_cols; - for (auto &col_idx : validity_cols) { - if (!all_valids[col_idx]) { - invalid_cols.emplace_back(col_idx); - validities[col_idx].Initialize(inputs->Count()); - } - } - - if (invalid_cols.empty()) { - return; - } - - WindowCursor cursor(*this, invalid_cols); - idx_t target_offset = 0; - while (cursor.Scan()) { - const auto count = cursor.chunk.size(); - for (idx_t i = 0; i < invalid_cols.size(); ++i) { - auto &other = FlatVector::Validity(cursor.chunk.data[i]); - const auto col_idx = invalid_cols[i]; - validities[col_idx].SliceInPlace(other, target_offset, 0, count); - } - target_offset += count; - } -} - -WindowBuilder::WindowBuilder(WindowCollection &collection) : collection(collection) { -} - -void WindowBuilder::Sink(DataChunk &chunk, idx_t input_idx) { - // Check whether we need a new collection - if (!sink.second || input_idx < sink.first || sink.first + sink.second->Count() < input_idx) { - collection.GetCollection(input_idx, sink); - D_ASSERT(sink.second); - sink.second->InitializeAppend(appender); - } - sink.second->Append(appender, chunk); - - // Record NULLs - for (column_t col_idx = 0; col_idx < chunk.ColumnCount(); ++col_idx) { - if (!collection.all_valids[col_idx]) { - continue; - } - - // Column was valid, make sure it still is. - UnifiedVectorFormat data; - chunk.data[col_idx].ToUnifiedFormat(chunk.size(), data); - if (!data.validity.AllValid()) { - collection.all_valids[col_idx] = false; - } - } -} - -WindowCursor::WindowCursor(const WindowCollection &paged, vector column_ids) : paged(paged) { - D_ASSERT(paged.collections.empty()); - D_ASSERT(paged.ranges.empty()); - if (column_ids.empty()) { - // For things like COUNT(*) set the state up to contain the whole range - state.segment_index = 0; - state.chunk_index = 0; - state.current_row_index = 0; - state.next_row_index = paged.size(); - state.properties = ColumnDataScanProperties::ALLOW_ZERO_COPY; - chunk.SetCapacity(state.next_row_index); - chunk.SetCardinality(state.next_row_index); - return; - } else if (chunk.data.empty()) { - auto &inputs = paged.inputs; - D_ASSERT(inputs.get()); - inputs->InitializeScan(state, std::move(column_ids)); - inputs->InitializeScanChunk(state, chunk); - } -} - -WindowCursor::WindowCursor(const WindowCollection &paged, column_t col_idx) - : WindowCursor(paged, vector(1, col_idx)) { -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_constant_aggregator.cpp b/src/duckdb/src/function/window/window_constant_aggregator.cpp deleted file mode 100644 index 0e09c6f99..000000000 --- a/src/duckdb/src/function/window/window_constant_aggregator.cpp +++ /dev/null @@ -1,357 +0,0 @@ -#include "duckdb/function/window/window_constant_aggregator.hpp" - -#include "duckdb/function/function_binder.hpp" -#include "duckdb/function/window/window_aggregate_states.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - 
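The NULL bookkeeping in WindowBuilder::Sink above is a one-way flag per column: every column starts out presumed all-valid and is permanently demoted the first time any chunk contains a NULL, so Combine only builds validity masks for the columns that actually need them. A small sketch of the pattern under the same assumption of concurrent sink threads (illustrative names):

#include <atomic>
#include <cstddef>
#include <optional>
#include <vector>

struct ValidityTracker {
	// Atomic flags: several sink threads may clear them concurrently.
	// As in the original, atomic vectors can't be constructed with a
	// given value, so the flags are set to true in a loop.
	std::vector<std::atomic<bool>> all_valid;

	explicit ValidityTracker(std::size_t columns) : all_valid(columns) {
		for (auto &flag : all_valid) {
			flag = true;
		}
	}

	// Clear the flag the first time a NULL shows up in this column.
	void Sink(std::size_t col, const std::vector<std::optional<int>> &values) {
		if (!all_valid[col]) {
			return; // already known to contain NULLs
		}
		for (const auto &v : values) {
			if (!v.has_value()) {
				all_valid[col] = false;
				break;
			}
		}
	}
};

int main() {
	ValidityTracker t(2);
	t.Sink(0, {1, 2, 3});
	t.Sink(1, {4, std::nullopt});
	return (t.all_valid[0] && !t.all_valid[1]) ? 0 : 1;
}

-//===--------------------------------------------------------------------===// -// 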
WindowConstantAggregatorGlobalState -//===--------------------------------------------------------------------===// - -class WindowConstantAggregatorGlobalState : public WindowAggregatorGlobalState { -public: - WindowConstantAggregatorGlobalState(ClientContext &context, const WindowConstantAggregator &aggregator, idx_t count, - const ValidityMask &partition_mask); - - void Finalize(const FrameStats &stats); - - //! Partition starts - vector partition_offsets; - //! Reused result state container for the window functions - WindowAggregateStates statef; - //! Aggregate results - unique_ptr results; -}; - -WindowConstantAggregatorGlobalState::WindowConstantAggregatorGlobalState(ClientContext &context, - const WindowConstantAggregator &aggregator, - idx_t group_count, - const ValidityMask &partition_mask) - : WindowAggregatorGlobalState(context, aggregator, STANDARD_VECTOR_SIZE), statef(aggr) { - - // Locate the partition boundaries - if (partition_mask.AllValid()) { - partition_offsets.emplace_back(0); - } else { - idx_t entry_idx; - idx_t shift; - for (idx_t start = 0; start < group_count;) { - partition_mask.GetEntryIndex(start, entry_idx, shift); - - // If start is aligned with the start of a block, - // and the block is blank, then skip forward one block. - const auto block = partition_mask.GetValidityEntry(entry_idx); - if (partition_mask.NoneValid(block) && !shift) { - start += ValidityMask::BITS_PER_VALUE; - continue; - } - - // Loop over the block - for (; shift < ValidityMask::BITS_PER_VALUE && start < group_count; ++shift, ++start) { - if (partition_mask.RowIsValid(block, shift)) { - partition_offsets.emplace_back(start); - } - } - } - } - - // Initialise the vector for caching the results - results = make_uniq(aggregator.result_type, partition_offsets.size()); - - // Initialise the final states - statef.Initialize(partition_offsets.size()); - - // Add final guard - partition_offsets.emplace_back(group_count); -} - -//===--------------------------------------------------------------------===// -// WindowConstantAggregatorLocalState -//===--------------------------------------------------------------------===// -class WindowConstantAggregatorLocalState : public WindowAggregatorLocalState { -public: - explicit WindowConstantAggregatorLocalState(const WindowConstantAggregatorGlobalState &gstate); - ~WindowConstantAggregatorLocalState() override { - } - - void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered); - void Combine(WindowConstantAggregatorGlobalState &gstate); - -public: - //! The global state we are sharing - const WindowConstantAggregatorGlobalState &gstate; - //! Reusable chunk for sinking - DataChunk inputs; - //! Chunk for referencing the input columns - DataChunk payload_chunk; - //! A vector of pointers to "state", used for intermediate window segment aggregation - Vector statep; - //! Reused result state container for the window functions - WindowAggregateStates statef; - //! The current result partition being read - idx_t partition; - //! 
Shared SV for evaluation - SelectionVector matches; -}; - -WindowConstantAggregatorLocalState::WindowConstantAggregatorLocalState( - const WindowConstantAggregatorGlobalState &gstate) - : gstate(gstate), statep(Value::POINTER(0)), statef(gstate.statef.aggr), partition(0) { - matches.Initialize(); - - // Start the aggregates - auto &partition_offsets = gstate.partition_offsets; - auto &aggregator = gstate.aggregator; - statef.Initialize(partition_offsets.size() - 1); - - // Set up shared buffer - inputs.Initialize(Allocator::DefaultAllocator(), aggregator.arg_types); - payload_chunk.InitializeEmpty(inputs.GetTypes()); - - gstate.locals++; -} - -//===--------------------------------------------------------------------===// -// WindowConstantAggregator -//===--------------------------------------------------------------------===// -bool WindowConstantAggregator::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { - return false; - } - // window exclusion cannot be handled by constant aggregates - if (wexpr.exclude_clause != WindowExcludeMode::NO_OTHER) { - return false; - } - - // DISTINCT aggregation cannot be handled by constant aggregation - if (wexpr.distinct) { - return false; - } - - // COUNT(*) is already handled efficiently by segment trees. - if (wexpr.children.empty()) { - return false; - } - - /* - The default framing option is RANGE UNBOUNDED PRECEDING, which - is the same as RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT - ROW; it sets the frame to be all rows from the partition start - up through the current row's last peer (a row that the window's - ORDER BY clause considers equivalent to the current row; all - rows are peers if there is no ORDER BY). In general, UNBOUNDED - PRECEDING means that the frame starts with the first row of the - partition, and similarly UNBOUNDED FOLLOWING means that the - frame ends with the last row of the partition, regardless of - RANGE, ROWS or GROUPS mode. In ROWS mode, CURRENT ROW means that - the frame starts or ends with the current row; but in RANGE or - GROUPS mode it means that the frame starts or ends with the - current row's first or last peer in the ORDER BY ordering. The - offset PRECEDING and offset FOLLOWING options vary in meaning - depending on the frame mode. 
- */ - switch (wexpr.start) { - case WindowBoundary::UNBOUNDED_PRECEDING: - break; - case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { - return false; - } - break; - default: - return false; - } - - switch (wexpr.end) { - case WindowBoundary::UNBOUNDED_FOLLOWING: - break; - case WindowBoundary::CURRENT_ROW_RANGE: - if (!wexpr.orders.empty()) { - return false; - } - break; - default: - return false; - } - - return true; -} - -BoundWindowExpression &WindowConstantAggregator::RebindAggregate(ClientContext &context, BoundWindowExpression &wexpr) { - FunctionBinder::BindSortedAggregate(context, wexpr); - - return wexpr; -} - -WindowConstantAggregator::WindowConstantAggregator(BoundWindowExpression &wexpr, WindowSharedExpressions &shared, - ClientContext &context) - : WindowAggregator(RebindAggregate(context, wexpr)) { - - // We only need these values for Sink - for (auto &child : wexpr.children) { - child_idx.emplace_back(shared.RegisterSink(child)); - } -} - -unique_ptr WindowConstantAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { - return make_uniq(context, *this, group_count, partition_mask); -} - -void WindowConstantAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered) { - auto &lastate = lstate.Cast(); - - lastate.Sink(sink_chunk, coll_chunk, input_idx, filter_sel, filtered); -} - -void WindowConstantAggregatorLocalState::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t row, - optional_ptr filter_sel, idx_t filtered) { - auto &partition_offsets = gstate.partition_offsets; - const auto &aggr = gstate.aggr; - const auto chunk_begin = row; - const auto chunk_end = chunk_begin + sink_chunk.size(); - idx_t partition = - idx_t(std::upper_bound(partition_offsets.begin(), partition_offsets.end(), row) - partition_offsets.begin()) - - 1; - - auto state_f_data = statef.GetData(); - auto state_p_data = FlatVector::GetData(statep); - - auto &child_idx = gstate.aggregator.child_idx; - for (column_t c = 0; c < child_idx.size(); ++c) { - payload_chunk.data[c].Reference(sink_chunk.data[child_idx[c]]); - } - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - idx_t begin = 0; - idx_t filter_idx = 0; - auto partition_end = partition_offsets[partition + 1]; - while (row < chunk_end) { - if (row == partition_end) { - ++partition; - partition_end = partition_offsets[partition + 1]; - } - partition_end = MinValue(partition_end, chunk_end); - auto end = partition_end - chunk_begin; - - inputs.Reset(); - if (filter_sel) { - // Slice to any filtered rows in [begin, end) - SelectionVector sel; - - // Find the first value in [begin, end) - for (; filter_idx < filtered; ++filter_idx) { - auto idx = filter_sel->get_index(filter_idx); - if (idx >= begin) { - break; - } - } - - // Find the first value in [end, filtered) - sel.Initialize(filter_sel->data() + filter_idx); - idx_t nsel = 0; - for (; filter_idx < filtered; ++filter_idx, ++nsel) { - auto idx = filter_sel->get_index(filter_idx); - if (idx >= end) { - break; - } - } - - if (nsel != inputs.size()) { - inputs.Slice(payload_chunk, sel, nsel); - } - } else { - // Slice to [begin, end) - if (begin) { - for (idx_t c = 0; c < payload_chunk.ColumnCount(); ++c) { - inputs.data[c].Slice(payload_chunk.data[c], begin, end); - } - } else { - inputs.Reference(payload_chunk); - } - inputs.SetCardinality(end - begin); - } - - // 
Aggregate the filtered rows into a single state - const auto count = inputs.size(); - auto state = state_f_data[partition]; - if (aggr.function.simple_update) { - aggr.function.simple_update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), state, count); - } else { - state_p_data[0] = state_f_data[partition]; - aggr.function.update(inputs.data.data(), aggr_input_data, inputs.ColumnCount(), statep, count); - } - - // Skip filtered rows too! - row += end - begin; - begin = end; - } -} - -void WindowConstantAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { - auto &gastate = gstate.Cast(); - auto &lastate = lstate.Cast(); - - // Single-threaded combine - lock_guard finalize_guard(gastate.lock); - lastate.statef.Combine(gastate.statef); - lastate.statef.Destroy(); - - // Last one out turns off the lights! - if (++gastate.finalized == gastate.locals) { - gastate.statef.Finalize(*gastate.results); - gastate.statef.Destroy(); - } -} - -unique_ptr WindowConstantAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(gstate.Cast()); -} - -void WindowConstantAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - auto &gasink = gsink.Cast(); - const auto &partition_offsets = gasink.partition_offsets; - const auto &results = *gasink.results; - - auto begins = FlatVector::GetData(bounds.data[FRAME_BEGIN]); - // Chunk up the constants and copy them one at a time - auto &lcstate = lstate.Cast(); - idx_t matched = 0; - idx_t target_offset = 0; - for (idx_t i = 0; i < count; ++i) { - const auto begin = begins[i]; - // Find the partition containing [begin, end) - while (partition_offsets[lcstate.partition + 1] <= begin) { - // Flush the previous partition's data - if (matched) { - VectorOperations::Copy(results, result, lcstate.matches, matched, 0, target_offset); - target_offset += matched; - matched = 0; - } - ++lcstate.partition; - } - - lcstate.matches.set_index(matched++, lcstate.partition); - } - - // Flush the last partition - if (matched) { - // Optimize constant result - if (target_offset == 0 && matched == count) { - VectorOperations::Copy(results, result, lcstate.matches, 1, 0, target_offset); - result.SetVectorType(VectorType::CONSTANT_VECTOR); - } else { - VectorOperations::Copy(results, result, lcstate.matches, matched, 0, target_offset); - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_custom_aggregator.cpp b/src/duckdb/src/function/window/window_custom_aggregator.cpp deleted file mode 100644 index 8416e3031..000000000 --- a/src/duckdb/src/function/window/window_custom_aggregator.cpp +++ /dev/null @@ -1,146 +0,0 @@ -#include "duckdb/function/window/window_custom_aggregator.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowCustomAggregator -//===--------------------------------------------------------------------===// -bool WindowCustomAggregator::CanAggregate(const BoundWindowExpression &wexpr, WindowAggregationMode mode) { - if (!wexpr.aggregate) { - return false; - } - - if (!wexpr.aggregate->window) { - return false; - } - - // ORDER BY arguments are not currently supported - if (!wexpr.arg_orders.empty()) { - return false; - } - - return (mode < WindowAggregationMode::COMBINE); 
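For the constant aggregator's Sink above, every frame in a partition collapses to one result, so a sinked row only needs to be attributed to its partition; with the sorted partition_offsets array (whose final guard entry is the total row count), that attribution is a single binary search. A sketch of the lookup, with illustrative values:

#include <algorithm>
#include <cstddef>
#include <vector>

// The partition containing a row is the last start <= row, found by
// binary search over the sorted starts. The final guard entry keeps
// offsets[p + 1] valid for every real partition p.
static std::size_t PartitionOf(const std::vector<std::size_t> &offsets, std::size_t row) {
	auto it = std::upper_bound(offsets.begin(), offsets.end(), row);
	return static_cast<std::size_t>(it - offsets.begin()) - 1;
}

int main() {
	// Three partitions covering rows [0,3), [3,7), [7,10), plus the guard.
	const std::vector<std::size_t> offsets = {0, 3, 7, 10};
	return (PartitionOf(offsets, 2) == 0 && PartitionOf(offsets, 3) == 1 &&
	        PartitionOf(offsets, 9) == 2)
	           ? 0
	           : 1;
}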
-} - -WindowCustomAggregator::WindowCustomAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared) - : WindowAggregator(wexpr, shared) { -} - -WindowCustomAggregator::~WindowCustomAggregator() { -} - -class WindowCustomAggregatorState : public WindowAggregatorLocalState { -public: - WindowCustomAggregatorState(const AggregateObject &aggr, const WindowExcludeMode exclude_mode); - ~WindowCustomAggregatorState() override; - -public: - //! The aggregate function - const AggregateObject aggr; - //! Data pointer that contains a single state, shared by all the custom evaluators - vector state; - //! Reused result state container for the window functions - Vector statef; - //! The frame boundaries, used for the window functions - SubFrames frames; -}; - -class WindowCustomAggregatorGlobalState : public WindowAggregatorGlobalState { -public: - explicit WindowCustomAggregatorGlobalState(ClientContext &context, const WindowCustomAggregator &aggregator, - idx_t group_count) - : WindowAggregatorGlobalState(context, aggregator, group_count), context(context) { - - gcstate = make_uniq(aggr, aggregator.exclude_mode); - } - - //! Buffer manager for paging custom accelerator data - ClientContext &context; - //! Traditional packed filter mask for API - ValidityMask filter_packed; - //! Data pointer that contains a single local state, used for global custom window execution state - unique_ptr gcstate; - //! Partition description for custom window APIs - unique_ptr partition_input; -}; - -WindowCustomAggregatorState::WindowCustomAggregatorState(const AggregateObject &aggr, - const WindowExcludeMode exclude_mode) - : aggr(aggr), state(aggr.function.state_size(aggr.function)), - statef(Value::POINTER(CastPointerToValue(state.data()))), frames(3, {0, 0}) { - // if we have a frame-by-frame method, share the single state - aggr.function.initialize(aggr.function, state.data()); - - InitSubFrames(frames, exclude_mode); -} - -WindowCustomAggregatorState::~WindowCustomAggregatorState() { - if (aggr.function.destructor) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.destructor(statef, aggr_input_data, 1); - } -} - -unique_ptr WindowCustomAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &) const { - return make_uniq(context, *this, group_count); -} - -void WindowCustomAggregator::Finalize(WindowAggregatorState &gstate, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { - // Single threaded Finalize for now - auto &gcsink = gstate.Cast(); - lock_guard gestate_guard(gcsink.lock); - if (gcsink.finalized) { - return; - } - - WindowAggregator::Finalize(gstate, lstate, collection, stats); - - auto inputs = collection->inputs.get(); - const auto count = collection->size(); - vector all_valids; - for (auto col_idx : child_idx) { - all_valids.push_back(collection->all_valids[col_idx]); - } - auto &filter_mask = gcsink.filter_mask; - auto &filter_packed = gcsink.filter_packed; - filter_mask.Pack(filter_packed, filter_mask.Capacity()); - - gcsink.partition_input = - make_uniq(gcsink.context, inputs, count, child_idx, all_valids, filter_packed, stats); - - if (aggr.function.window_init) { - auto &gcstate = *gcsink.gcstate; - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), gcstate.allocator); - aggr.function.window_init(aggr_input_data, *gcsink.partition_input, gcstate.state.data()); - } - - ++gcsink.finalized; -} - -unique_ptr WindowCustomAggregator::GetLocalState(const 
WindowAggregatorState &gstate) const { - return make_uniq(aggr, exclude_mode); -} - -void WindowCustomAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - auto &lcstate = lstate.Cast(); - auto &frames = lcstate.frames; - const_data_ptr_t gstate_p = nullptr; - auto &gcsink = gsink.Cast(); - if (gcsink.gcstate) { - gstate_p = gcsink.gcstate->state.data(); - } - - EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { - // Extract the range - AggregateInputData aggr_input_data(aggr.GetFunctionData(), lstate.allocator); - aggr.function.window(aggr_input_data, *gcsink.partition_input, gstate_p, lcstate.state.data(), frames, result, - i); - }); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_distinct_aggregator.cpp b/src/duckdb/src/function/window/window_distinct_aggregator.cpp deleted file mode 100644 index 1dc940f95..000000000 --- a/src/duckdb/src/function/window/window_distinct_aggregator.cpp +++ /dev/null @@ -1,758 +0,0 @@ -#include "duckdb/function/window/window_distinct_aggregator.hpp" - -#include "duckdb/common/sort/partition_state.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/execution/merge_sort_tree.hpp" -#include "duckdb/function/window/window_aggregate_states.hpp" -#include "duckdb/planner/bound_result_modifier.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -#include -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowDistinctAggregator -//===--------------------------------------------------------------------===// -bool WindowDistinctAggregator::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { - return false; - } - - return wexpr.distinct && wexpr.exclude_clause == WindowExcludeMode::NO_OTHER && wexpr.arg_orders.empty(); -} - -WindowDistinctAggregator::WindowDistinctAggregator(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared, - ClientContext &context) - : WindowAggregator(wexpr, shared), context(context) { -} - -class WindowDistinctAggregatorLocalState; - -class WindowDistinctAggregatorGlobalState; - -class WindowDistinctSortTree : public MergeSortTree { -public: - // prev_idx, input_idx - using ZippedTuple = std::tuple; - using ZippedElements = vector; - - explicit WindowDistinctSortTree(WindowDistinctAggregatorGlobalState &gdastate, idx_t count) : gdastate(gdastate) { - // Set up for parallel build - build_level = 0; - build_complete = 0; - build_run = 0; - build_run_length = 1; - build_num_runs = count; - } - - void Build(WindowDistinctAggregatorLocalState &ldastate); - -protected: - bool TryNextRun(idx_t &level_idx, idx_t &run_idx); - void BuildRun(idx_t level_nr, idx_t i, WindowDistinctAggregatorLocalState &ldastate); - - WindowDistinctAggregatorGlobalState &gdastate; -}; - -class WindowDistinctAggregatorGlobalState : public WindowAggregatorGlobalState { -public: - using GlobalSortStatePtr = unique_ptr; - using LocalSortStatePtr = unique_ptr; - using ZippedTuple = WindowDistinctSortTree::ZippedTuple; - using ZippedElements = WindowDistinctSortTree::ZippedElements; - - WindowDistinctAggregatorGlobalState(ClientContext &context, const WindowDistinctAggregator &aggregator, - idx_t group_count); - - //! Compute the block starts - void MeasurePayloadBlocks(); - //! 
Create a new local sort - optional_ptr InitializeLocalSort() const; - - //! Patch up the previous index block boundaries - void PatchPrevIdcs(); - bool TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate); - - // Single threaded sorting for now - ClientContext &context; - idx_t memory_per_thread; - - //! Finalize guard - mutable mutex lock; - //! Finalize stage - atomic stage; - //! Tasks launched - idx_t total_tasks = 0; - //! Tasks launched - mutable idx_t tasks_assigned; - //! Tasks landed - mutable atomic tasks_completed; - - //! The sorted payload data types (partition index) - vector payload_types; - //! The aggregate arguments + partition index - vector sort_types; - - //! Sorting operations - GlobalSortStatePtr global_sort; - //! Local sort set - mutable vector local_sorts; - //! The block starts (the scanner doesn't know this) plus the total count - vector block_starts; - - //! The block boundary seconds - mutable ZippedElements seconds; - //! The MST with the distinct back pointers - mutable MergeSortTree zipped_tree; - //! The merge sort tree for the aggregate. - WindowDistinctSortTree merge_sort_tree; - - //! The actual window segment tree: an array of aggregate states that represent all the intermediate nodes - WindowAggregateStates levels_flat_native; - //! For each level, the starting location in the levels_flat_native array - vector levels_flat_start; -}; - -WindowDistinctAggregatorGlobalState::WindowDistinctAggregatorGlobalState(ClientContext &context, - const WindowDistinctAggregator &aggregator, - idx_t group_count) - : WindowAggregatorGlobalState(context, aggregator, group_count), context(aggregator.context), - stage(PartitionSortStage::INIT), tasks_assigned(0), tasks_completed(0), merge_sort_tree(*this, group_count), - levels_flat_native(aggr) { - payload_types.emplace_back(LogicalType::UBIGINT); - - // 1: functionComputePrevIdcs(𝑖𝑛) - // 2: sorted ← [] - // We sort the aggregate arguments and use the partition index as a tie-breaker. - // TODO: Use a hash table? - sort_types = aggregator.arg_types; - for (const auto &type : payload_types) { - sort_types.emplace_back(type); - } - - vector orders; - for (const auto &type : sort_types) { - auto expr = make_uniq(Value(type)); - orders.emplace_back(BoundOrderByNode(OrderType::ASCENDING, OrderByNullType::NULLS_FIRST, std::move(expr))); - } - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - - global_sort = make_uniq(BufferManager::GetBufferManager(context), orders, payload_layout); - - memory_per_thread = PhysicalOperator::GetMaxThreadMemory(context); - - // 6: prevIdcs ← [] - // 7: prevIdcs[0] ← “-” - auto &prev_idcs = zipped_tree.Allocate(group_count); - - // To handle FILTER clauses we make the missing elements - // point to themselves so they won't be counted. 
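The numbered comments in this constructor refer to the distinct aggregation algorithm built on the merge sort tree: sort (value, index) pairs, link every row to one past the index of its previous occurrence (0 if there is none), and a row then contributes to a frame's distinct count exactly when its link points in front of the frame. The self-link i + 1 assigned above to FILTERed-out rows can never point in front of a frame containing row i, so such rows are never counted. A small sketch of both steps, with a naive scan standing in for the tree's logarithmic range count (illustrative names):

#include <algorithm>
#include <cstddef>
#include <utility>
#include <vector>

// Compute prev[i]: one past the index of the previous occurrence of
// v[i], or 0 if v[i] has not occurred before. Done by sorting
// (value, index) pairs, as in the pseudocode steps quoted above.
static std::vector<std::size_t> ComputePrev(const std::vector<int> &v) {
	std::vector<std::pair<int, std::size_t>> sorted;
	for (std::size_t i = 0; i < v.size(); ++i) {
		sorted.emplace_back(v[i], i);
	}
	std::sort(sorted.begin(), sorted.end());

	std::vector<std::size_t> prev(v.size(), 0);
	for (std::size_t k = 1; k < sorted.size(); ++k) {
		if (sorted[k].first == sorted[k - 1].first) {
			prev[sorted[k].second] = sorted[k - 1].second + 1;
		}
	}
	return prev;
}

// Number of distinct values in [begin, end): count the rows whose previous
// occurrence falls before the frame. The merge sort tree answers this
// count without the linear scan; the loop just shows the semantics.
static std::size_t CountDistinct(const std::vector<std::size_t> &prev, std::size_t begin, std::size_t end) {
	std::size_t distinct = 0;
	for (std::size_t i = begin; i < end; ++i) {
		distinct += (prev[i] <= begin) ? 1 : 0;
	}
	return distinct;
}

int main() {
	const std::vector<int> v = {1, 2, 1, 3, 2, 1};
	const auto prev = ComputePrev(v);               // {0, 0, 1, 0, 2, 3}
	return CountDistinct(prev, 2, 6) == 3 ? 0 : 1;  // frame {1, 3, 2, 1}
}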
- for (idx_t i = 0; i < group_count; ++i) { - prev_idcs[i] = ZippedTuple(i + 1, i); - } - - // compute space required to store aggregation states of merge sort tree - // this is one aggregate state per entry per level - idx_t internal_nodes = 0; - levels_flat_start.push_back(internal_nodes); - for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { - internal_nodes += zipped_tree.tree[level_nr].first.size(); - levels_flat_start.push_back(internal_nodes); - } - levels_flat_native.Initialize(internal_nodes); - - merge_sort_tree.tree.reserve(zipped_tree.tree.size()); - for (idx_t level_nr = 0; level_nr < zipped_tree.tree.size(); ++level_nr) { - auto &zipped_level = zipped_tree.tree[level_nr].first; - WindowDistinctSortTree::Elements level; - WindowDistinctSortTree::Offsets cascades; - level.resize(zipped_level.size()); - merge_sort_tree.tree.emplace_back(std::move(level), std::move(cascades)); - } -} - -optional_ptr WindowDistinctAggregatorGlobalState::InitializeLocalSort() const { - lock_guard local_sort_guard(lock); - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - ++tasks_assigned; - local_sorts.emplace_back(std::move(local_sort)); - - return local_sorts.back().get(); -} - -class WindowDistinctAggregatorLocalState : public WindowAggregatorLocalState { -public: - explicit WindowDistinctAggregatorLocalState(const WindowDistinctAggregatorGlobalState &aggregator); - - void Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered); - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; - void Sorted(); - void ExecuteTask(); - void Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx); - - //! Thread-local sorting data - optional_ptr local_sort; - //! Finalize stage - PartitionSortStage stage = PartitionSortStage::INIT; - //! Finalize scan block index - idx_t block_idx; - //! Thread-local tree aggregation - Vector update_v; - Vector source_v; - Vector target_v; - DataChunk leaves; - SelectionVector sel; - -protected: - //! Flush the accumulated intermediate states into the result states - void FlushStates(); - - //! The aggregator we are working with - const WindowDistinctAggregatorGlobalState &gastate; - DataChunk sort_chunk; - DataChunk payload_chunk; - //! Reused result state container for the window functions - WindowAggregateStates statef; - //! A vector of pointers to "state", used for buffering intermediate aggregates - Vector statep; - //! Reused state pointers for combining tree elements - Vector statel; - //! Count of buffered values - idx_t flush_count; - //! 
The frame boundaries, used for the window functions - SubFrames frames; -}; - -WindowDistinctAggregatorLocalState::WindowDistinctAggregatorLocalState( - const WindowDistinctAggregatorGlobalState &gastate) - : update_v(LogicalType::POINTER), source_v(LogicalType::POINTER), target_v(LogicalType::POINTER), gastate(gastate), - statef(gastate.aggr), statep(LogicalType::POINTER), statel(LogicalType::POINTER), flush_count(0) { - InitSubFrames(frames, gastate.aggregator.exclude_mode); - payload_chunk.Initialize(Allocator::DefaultAllocator(), gastate.payload_types); - - sort_chunk.Initialize(Allocator::DefaultAllocator(), gastate.sort_types); - sort_chunk.data.back().Reference(payload_chunk.data[0]); - - gastate.locals++; -} - -unique_ptr WindowDistinctAggregator::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { - return make_uniq(context, *this, group_count); -} - -void WindowDistinctAggregator::Sink(WindowAggregatorState &gsink, WindowAggregatorState &lstate, DataChunk &sink_chunk, - DataChunk &coll_chunk, idx_t input_idx, optional_ptr filter_sel, - idx_t filtered) { - WindowAggregator::Sink(gsink, lstate, sink_chunk, coll_chunk, input_idx, filter_sel, filtered); - - auto &ldstate = lstate.Cast(); - ldstate.Sink(sink_chunk, coll_chunk, input_idx, filter_sel, filtered); -} - -void WindowDistinctAggregatorLocalState::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, idx_t input_idx, - optional_ptr filter_sel, idx_t filtered) { - // 3: for i ← 0 to in.size do - // 4: sorted[i] ← (in[i], i) - const auto count = sink_chunk.size(); - payload_chunk.Reset(); - auto &sorted_vec = payload_chunk.data[0]; - auto sorted = FlatVector::GetData(sorted_vec); - std::iota(sorted, sorted + count, input_idx); - - // Our arguments are being fully materialised, - // but we also need them as sort keys. - auto &child_idx = gastate.aggregator.child_idx; - for (column_t c = 0; c < child_idx.size(); ++c) { - sort_chunk.data[c].Reference(coll_chunk.data[child_idx[c]]); - } - sort_chunk.data.back().Reference(sorted_vec); - sort_chunk.SetCardinality(sink_chunk); - payload_chunk.SetCardinality(sort_chunk); - - // Apply FILTER clause, if any - if (filter_sel) { - sort_chunk.Slice(*filter_sel, filtered); - payload_chunk.Slice(*filter_sel, filtered); - } - - if (!local_sort) { - local_sort = gastate.InitializeLocalSort(); - } - - local_sort->SinkChunk(sort_chunk, payload_chunk); - - if (local_sort->SizeInBytes() > gastate.memory_per_thread) { - local_sort->Sort(*gastate.global_sort, true); - } -} - -void WindowDistinctAggregatorLocalState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); - - //! 
Input data chunk, used for leaf segment aggregation - leaves.Initialize(Allocator::DefaultAllocator(), cursor->chunk.GetTypes()); - sel.Initialize(); -} - -void WindowDistinctAggregatorLocalState::ExecuteTask() { - auto &global_sort = *gastate.global_sort; - switch (stage) { - case PartitionSortStage::SCAN: - global_sort.AddLocalState(*gastate.local_sorts[block_idx]); - break; - case PartitionSortStage::MERGE: { - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); - break; - } - case PartitionSortStage::SORTED: - Sorted(); - break; - default: - break; - } - - ++gastate.tasks_completed; -} - -void WindowDistinctAggregatorGlobalState::MeasurePayloadBlocks() { - const auto &blocks = global_sort->sorted_blocks[0]->payload_data->data_blocks; - idx_t count = 0; - for (const auto &block : blocks) { - block_starts.emplace_back(count); - count += block->count; - } - block_starts.emplace_back(count); -} - -bool WindowDistinctAggregatorGlobalState::TryPrepareNextStage(WindowDistinctAggregatorLocalState &lstate) { - lock_guard stage_guard(lock); - - switch (stage.load()) { - case PartitionSortStage::INIT: - // 5: Sort sorted lexicographically increasing - total_tasks = local_sorts.size(); - tasks_assigned = 0; - tasks_completed = 0; - lstate.stage = stage = PartitionSortStage::SCAN; - lstate.block_idx = tasks_assigned++; - return true; - case PartitionSortStage::SCAN: - // Process all the local sorts - if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::SCAN; - lstate.block_idx = tasks_assigned++; - return true; - } else if (tasks_completed < tasks_assigned) { - return false; - } - global_sort->PrepareMergePhase(); - if (!(global_sort->sorted_blocks.size() / 2)) { - if (global_sort->sorted_blocks.empty()) { - lstate.stage = stage = PartitionSortStage::FINISHED; - return true; - } - MeasurePayloadBlocks(); - seconds.resize(block_starts.size() - 1); - total_tasks = seconds.size(); - tasks_completed = 0; - tasks_assigned = 0; - lstate.stage = stage = PartitionSortStage::SORTED; - lstate.block_idx = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.stage = stage = PartitionSortStage::MERGE; - total_tasks = locals; - tasks_assigned = 1; - tasks_completed = 0; - return true; - case PartitionSortStage::MERGE: - if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::MERGE; - ++tasks_assigned; - return true; - } else if (tasks_completed < tasks_assigned) { - return false; - } - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - MeasurePayloadBlocks(); - seconds.resize(block_starts.size() - 1); - total_tasks = seconds.size(); - tasks_completed = 0; - tasks_assigned = 0; - lstate.stage = stage = PartitionSortStage::SORTED; - lstate.block_idx = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.stage = PartitionSortStage::MERGE; - total_tasks = locals; - tasks_assigned = 1; - tasks_completed = 0; - return true; - case PartitionSortStage::SORTED: - if (tasks_assigned < total_tasks) { - lstate.stage = PartitionSortStage::SORTED; - lstate.block_idx = tasks_assigned++; - return true; - } else if (tasks_completed < tasks_assigned) { - lstate.stage = PartitionSortStage::FINISHED; - // Sleep while other tasks finish - return false; - } - // Last task patches the boundaries - PatchPrevIdcs(); - break; - default: - break; - } - - lstate.stage = stage = PartitionSortStage::FINISHED; - - return true; -} - -void 
WindowDistinctAggregator::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, - CollectionPtr collection, const FrameStats &stats) { - auto &gdsink = gsink.Cast<WindowDistinctAggregatorGlobalState>(); - auto &ldstate = lstate.Cast<WindowDistinctAggregatorLocalState>(); - ldstate.Finalize(gdsink, collection); - - // Sort, merge and build the tree in parallel - while (gdsink.stage.load() != PartitionSortStage::FINISHED) { - if (gdsink.TryPrepareNextStage(ldstate)) { - ldstate.ExecuteTask(); - } else { - std::this_thread::yield(); - } - } - - // These are parallel implementations, - // so every thread can call them. - gdsink.zipped_tree.Build(); - gdsink.merge_sort_tree.Build(ldstate); - - ++gdsink.finalized; -} - -void WindowDistinctAggregatorLocalState::Sorted() { - using ZippedTuple = WindowDistinctAggregatorGlobalState::ZippedTuple; - auto &global_sort = gastate.global_sort; - auto &prev_idcs = gastate.zipped_tree.LowestLevel(); - auto &aggregator = gastate.aggregator; - auto &scan_chunk = payload_chunk; - - auto scanner = make_uniq<PayloadScanner>(*global_sort, block_idx); - const auto in_size = gastate.block_starts.at(block_idx + 1); - scanner->Scan(scan_chunk); - idx_t scan_idx = 0; - - auto *input_idx = FlatVector::GetData<idx_t>(scan_chunk.data[0]); - idx_t i = 0; - - SBIterator curr(*global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator prev(*global_sort, ExpressionType::COMPARE_LESSTHAN); - auto prefix_layout = global_sort->sort_layout.GetPrefixComparisonLayout(aggregator.arg_types.size()); - - const auto block_begin = gastate.block_starts.at(block_idx); - if (!block_begin) { - // First block, so set up initial sentinel - i = input_idx[scan_idx++]; - prev_idcs[i] = ZippedTuple(0, i); - std::get<0>(gastate.seconds[block_idx]) = i; - } else { - // Move to the end of the previous block - // so we can record the comparison result for the first row - curr.SetIndex(block_begin - 1); - prev.SetIndex(block_begin - 1); - scan_idx = 0; - std::get<0>(gastate.seconds[block_idx]) = input_idx[scan_idx]; - } - - // 8: for i ← 1 to in.size do - for (++curr; curr.GetIndex() < in_size; ++curr, ++prev) { - // Scan the second values one chunk at a time - // Note the scan is one behind the iterators - if (scan_idx >= scan_chunk.size()) { - scan_chunk.Reset(); - scanner->Scan(scan_chunk); - scan_idx = 0; - input_idx = FlatVector::GetData<idx_t>(scan_chunk.data[0]); - } - auto second = i; - i = input_idx[scan_idx++]; - - int lt = 0; - if (prefix_layout.all_constant) { - lt = FastMemcmp(prev.entry_ptr, curr.entry_ptr, prefix_layout.comparison_size); - } else { - lt = Comparators::CompareTuple(prev.scan, curr.scan, prev.entry_ptr, curr.entry_ptr, prefix_layout, - prev.external); - } - - // 9: if sorted[i].first == sorted[i-1].first then - // 10: prevIdcs[i] ← sorted[i-1].second - // 11: else - // 12: prevIdcs[i] ← “-” - if (!lt) { - prev_idcs[i] = ZippedTuple(second + 1, i); - } else { - prev_idcs[i] = ZippedTuple(0, i); - } - } - - // Save the last value of i for patching up the block boundaries - std::get<1>(gastate.seconds[block_idx]) = i; -} - -void WindowDistinctAggregatorGlobalState::PatchPrevIdcs() { - // 13: return prevIdcs - - // Patch up the indices at block boundaries - // (We don't need to patch block 0.)
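- // Each block was scanned independently, so when the first row of a block compared equal - // to the last row of the previous block, its back link was recorded against a stale local index. - // seconds[b] caches the (first, last) input indices of block b, so the patch below rewrites - // that single link using the true last input index of block b - 1.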
- auto &prev_idcs = zipped_tree.LowestLevel(); - for (idx_t block_idx = 1; block_idx < seconds.size(); ++block_idx) { - // We only need to patch if the first index in the block - // was a back link to the previous block (10:) - auto i = std::get<0>(seconds.at(block_idx)); - if (std::get<0>(prev_idcs[i])) { - auto second = std::get<1>(seconds.at(block_idx - 1)); - prev_idcs[i] = ZippedTuple(second + 1, i); - } - } -} - -bool WindowDistinctSortTree::TryNextRun(idx_t &level_idx, idx_t &run_idx) { - const auto fanout = FANOUT; - - lock_guard stage_guard(build_lock); - - // Verify we are not done - if (build_level >= tree.size()) { - return false; - } - - // Finished with this level? - if (build_complete >= build_num_runs) { - auto &zipped_tree = gdastate.zipped_tree; - std::swap(tree[build_level].second, zipped_tree.tree[build_level].second); - - ++build_level; - if (build_level >= tree.size()) { - zipped_tree.tree.clear(); - return false; - } - - const auto count = LowestLevel().size(); - build_run_length *= fanout; - build_num_runs = (count + build_run_length - 1) / build_run_length; - build_run = 0; - build_complete = 0; - } - - // If all runs are in flight, - // yield until the next level is ready - if (build_run >= build_num_runs) { - return false; - } - - level_idx = build_level; - run_idx = build_run++; - - return true; -} - -void WindowDistinctSortTree::Build(WindowDistinctAggregatorLocalState &ldastate) { - // Fan in parent levels until we are at the top - // Note that we don't build the top layer as that would just be all the data. - while (build_level.load() < tree.size()) { - idx_t level_idx; - idx_t run_idx; - if (TryNextRun(level_idx, run_idx)) { - BuildRun(level_idx, run_idx, ldastate); - } else { - std::this_thread::yield(); - } - } -} - -void WindowDistinctSortTree::BuildRun(idx_t level_nr, idx_t run_idx, WindowDistinctAggregatorLocalState &ldastate) { - auto &aggr = gdastate.aggr; - auto &allocator = gdastate.allocator; - auto &inputs = ldastate.cursor->chunk; - auto &levels_flat_native = gdastate.levels_flat_native; - - //! Input data chunk, used for leaf segment aggregation - auto &leaves = ldastate.leaves; - auto &sel = ldastate.sel; - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - - //! 
The states to update - auto &update_v = ldastate.update_v; - auto updates = FlatVector::GetData(update_v); - - auto &source_v = ldastate.source_v; - auto sources = FlatVector::GetData(source_v); - auto &target_v = ldastate.target_v; - auto targets = FlatVector::GetData(target_v); - - auto &zipped_tree = gdastate.zipped_tree; - auto &zipped_level = zipped_tree.tree[level_nr].first; - auto &level = tree[level_nr].first; - - // Reset the combine state - idx_t nupdate = 0; - idx_t ncombine = 0; - data_ptr_t prev_state = nullptr; - idx_t i = run_idx * build_run_length; - auto next_limit = MinValue(zipped_level.size(), i + build_run_length); - idx_t levels_flat_offset = level_nr * zipped_level.size() + i; - for (auto j = i; j < next_limit; ++j) { - // Initialise the next aggregate - auto curr_state = levels_flat_native.GetStatePtr(levels_flat_offset++); - - // Update this state (if it matches) - const auto prev_idx = std::get<0>(zipped_level[j]); - level[j] = prev_idx; - if (prev_idx < i + 1) { - const auto update_idx = std::get<1>(zipped_level[j]); - if (!ldastate.cursor->RowIsVisible(update_idx)) { - // Flush if we have to move the cursor - // Push the updates first so they propagate - leaves.Reference(inputs); - leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); - nupdate = 0; - - // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); - ncombine = 0; - - // Move the update into range. - ldastate.cursor->Seek(update_idx); - } - - updates[nupdate] = curr_state; - // input_idx - sel[nupdate] = ldastate.cursor->RowOffset(update_idx); - ++nupdate; - } - - // Merge the previous state (if any) - if (prev_state) { - sources[ncombine] = prev_state; - targets[ncombine] = curr_state; - ++ncombine; - } - prev_state = curr_state; - - // Flush the states if one is maxed out. 
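- // The two buffers flush together: the updates go first so that freshly accumulated leaf - // values propagate before the state-to-state combines, and neither buffer may hold more - // than STANDARD_VECTOR_SIZE entries, the width of update_v, source_v and target_v.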
- if (MaxValue(ncombine, nupdate) >= STANDARD_VECTOR_SIZE) { - // Push the updates first so they propagate - leaves.Reference(inputs); - leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); - nupdate = 0; - - // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); - ncombine = 0; - } - } - - // Flush any remaining states - if (ncombine || nupdate) { - // Push the updates - leaves.Reference(inputs); - leaves.Slice(sel, nupdate); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), update_v, nupdate); - nupdate = 0; - - // Combine the states sequentially - aggr.function.combine(source_v, target_v, aggr_input_data, ncombine); - ncombine = 0; - } - - ++build_complete; -} - -void WindowDistinctAggregatorLocalState::FlushStates() { - if (!flush_count) { - return; - } - - const auto &aggr = gastate.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - statel.Verify(flush_count); - aggr.function.combine(statel, statep, aggr_input_data, flush_count); - - flush_count = 0; -} - -void WindowDistinctAggregatorLocalState::Evaluate(const WindowDistinctAggregatorGlobalState &gdstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) { - auto ldata = FlatVector::GetData(statel); - auto pdata = FlatVector::GetData(statep); - - const auto &merge_sort_tree = gdstate.merge_sort_tree; - const auto &levels_flat_native = gdstate.levels_flat_native; - const auto exclude_mode = gdstate.aggregator.exclude_mode; - - // Build the finalise vector that just points to the result states - statef.Initialize(count); - - WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t rid) { - auto agg_state = statef.GetStatePtr(rid); - - // TODO: Extend AggregateLowerBound to handle subframes, just like SelectNth. 
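- // AggregateLowerBound decomposes [lower, upper) into O(log n) runs of the merge sort tree; - // whenever run_pos > run_begin, the run contributes the precomputed running aggregate that - // ends just before run_pos, which is buffered here and combined into the result state in - // STANDARD_VECTOR_SIZE batches.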
- const auto lower = frames[0].start; - const auto upper = frames[0].end; - merge_sort_tree.AggregateLowerBound(lower, upper, lower + 1, - [&](idx_t level, const idx_t run_begin, const idx_t run_pos) { - if (run_pos != run_begin) { - // Find the source aggregate - // Buffer a merge of the indicated state into the current state - const auto agg_idx = gdstate.levels_flat_start[level] + run_pos - 1; - const auto running_agg = levels_flat_native.GetStatePtr(agg_idx); - pdata[flush_count] = agg_state; - ldata[flush_count++] = running_agg; - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(); - } - } - }); - }); - - // Flush the final states - FlushStates(); - - // Finalise the result aggregates and write to the result - statef.Finalize(result); - statef.Destroy(); -} - -unique_ptr WindowDistinctAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(gstate.Cast()); -} - -void WindowDistinctAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - - const auto &gdstate = gsink.Cast(); - auto &ldstate = lstate.Cast(); - ldstate.Evaluate(gdstate, bounds, result, count, row_idx); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_executor.cpp b/src/duckdb/src/function/window/window_executor.cpp deleted file mode 100644 index faa7215e4..000000000 --- a/src/duckdb/src/function/window/window_executor.cpp +++ /dev/null @@ -1,99 +0,0 @@ -#include "duckdb/function/window/window_executor.hpp" - -#include "duckdb/function/window/window_shared_expressions.hpp" - -#include "duckdb/common/array.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowExecutorBoundsState -//===--------------------------------------------------------------------===// -WindowExecutorBoundsState::WindowExecutorBoundsState(const WindowExecutorGlobalState &gstate) - : WindowExecutorLocalState(gstate), partition_mask(gstate.partition_mask), order_mask(gstate.order_mask), - state(gstate.executor.wexpr, gstate.payload_count) { - vector bounds_types(8, LogicalType(LogicalTypeId::UBIGINT)); - bounds.Initialize(Allocator::Get(gstate.executor.context), bounds_types); -} - -void WindowExecutorBoundsState::UpdateBounds(WindowExecutorGlobalState &gstate, idx_t row_idx, DataChunk &eval_chunk, - optional_ptr range) { - // Evaluate the row-level arguments - WindowInputExpression boundary_start(eval_chunk, gstate.executor.boundary_start_idx); - WindowInputExpression boundary_end(eval_chunk, gstate.executor.boundary_end_idx); - - const auto count = eval_chunk.size(); - state.Bounds(bounds, row_idx, range, count, boundary_start, boundary_end, partition_mask, order_mask); -} - -//===--------------------------------------------------------------------===// -// WindowExecutor -//===--------------------------------------------------------------------===// -WindowExecutor::WindowExecutor(BoundWindowExpression &wexpr, ClientContext &context, WindowSharedExpressions &shared) - : wexpr(wexpr), context(context), - range_expr((WindowBoundariesState::HasPrecedingRange(wexpr) || WindowBoundariesState::HasFollowingRange(wexpr)) - ? 
wexpr.orders[0].expression.get() - : nullptr) { - if (range_expr) { - range_idx = shared.RegisterCollection(wexpr.orders[0].expression, false); - } - - boundary_start_idx = shared.RegisterEvaluate(wexpr.start_expr); - boundary_end_idx = shared.RegisterEvaluate(wexpr.end_expr); -} - -void WindowExecutor::Evaluate(idx_t row_idx, DataChunk &eval_chunk, Vector &result, WindowExecutorLocalState &lstate, - WindowExecutorGlobalState &gstate) const { - auto &lbstate = lstate.Cast(); - lbstate.UpdateBounds(gstate, row_idx, eval_chunk, lstate.range_cursor); - - const auto count = eval_chunk.size(); - EvaluateInternal(gstate, lstate, eval_chunk, result, count, row_idx); - - result.Verify(count); -} - -WindowExecutorGlobalState::WindowExecutorGlobalState(const WindowExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : executor(executor), payload_count(payload_count), partition_mask(partition_mask), order_mask(order_mask) { - for (const auto &child : executor.wexpr.children) { - arg_types.emplace_back(child->return_type); - } -} - -WindowExecutorLocalState::WindowExecutorLocalState(const WindowExecutorGlobalState &gstate) { -} - -void WindowExecutorLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { -} - -void WindowExecutorLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - const auto range_idx = gstate.executor.range_idx; - if (range_idx != DConstants::INVALID_INDEX) { - range_cursor = make_uniq(*collection, range_idx); - } -} - -unique_ptr WindowExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -unique_ptr WindowExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate); -} - -void WindowExecutor::Sink(DataChunk &sink_chunk, DataChunk &coll_chunk, const idx_t input_idx, - WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate) const { - lstate.Sink(gstate, sink_chunk, coll_chunk, input_idx); -} - -void WindowExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { - lstate.Finalize(gstate, collection); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_index_tree.cpp b/src/duckdb/src/function/window/window_index_tree.cpp deleted file mode 100644 index 5791b2af7..000000000 --- a/src/duckdb/src/function/window/window_index_tree.cpp +++ /dev/null @@ -1,63 +0,0 @@ -#include "duckdb/function/window/window_index_tree.hpp" - -#include -#include - -namespace duckdb { - -WindowIndexTree::WindowIndexTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count) - : WindowMergeSortTree(context, orders, sort_idx, count) { -} - -WindowIndexTree::WindowIndexTree(ClientContext &context, const BoundOrderModifier &order_bys, - const vector &sort_idx, const idx_t count) - : WindowIndexTree(context, order_bys.orders, sort_idx, count) { -} - -unique_ptr WindowIndexTree::GetLocalState() { - return make_uniq(*this); -} - -WindowIndexTreeLocalState::WindowIndexTreeLocalState(WindowIndexTree &index_tree) - : WindowMergeSortTreeLocalState(index_tree), index_tree(index_tree) { -} - -void WindowIndexTreeLocalState::BuildLeaves() { - auto &global_sort = *index_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { - 
return; - } - - PayloadScanner scanner(global_sort, build_task); - idx_t row_idx = index_tree.block_starts[build_task]; - for (;;) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); - const auto count = payload_chunk.size(); - if (count == 0) { - break; - } - auto &indices = payload_chunk.data[0]; - if (index_tree.mst32) { - auto &sorted = index_tree.mst32->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } else { - auto &sorted = index_tree.mst64->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } - row_idx += count; - } -} - -idx_t WindowIndexTree::SelectNth(const SubFrames &frames, idx_t n) const { - if (mst32) { - return mst32->NthElement(mst32->SelectNth(frames, n)); - } else { - return mst64->NthElement(mst64->SelectNth(frames, n)); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_merge_sort_tree.cpp b/src/duckdb/src/function/window/window_merge_sort_tree.cpp deleted file mode 100644 index ef22694c4..000000000 --- a/src/duckdb/src/function/window/window_merge_sort_tree.cpp +++ /dev/null @@ -1,275 +0,0 @@ -#include "duckdb/function/window/window_merge_sort_tree.hpp" -#include "duckdb/planner/expression/bound_constant_expression.hpp" - -#include -#include - -namespace duckdb { - -WindowMergeSortTree::WindowMergeSortTree(ClientContext &context, const vector &orders, - const vector &sort_idx, const idx_t count, bool unique) - : context(context), memory_per_thread(PhysicalOperator::GetMaxThreadMemory(context)), sort_idx(sort_idx), - build_stage(PartitionSortStage::INIT), tasks_completed(0) { - // Sort the unfiltered indices by the orders - const auto force_external = ClientConfig::GetConfig(context).force_external; - LogicalType index_type; - if (count < std::numeric_limits::max() && !force_external) { - index_type = LogicalType::INTEGER; - mst32 = make_uniq(); - } else { - index_type = LogicalType::BIGINT; - mst64 = make_uniq(); - } - - vector payload_types; - payload_types.emplace_back(index_type); - - RowLayout payload_layout; - payload_layout.Initialize(payload_types); - - auto &buffer_manager = BufferManager::GetBufferManager(context); - if (unique) { - vector unique_orders; - for (const auto &order : orders) { - unique_orders.emplace_back(order.Copy()); - } - auto unique_expr = make_uniq(Value(index_type)); - const auto order_type = OrderType::ASCENDING; - const auto order_by_type = OrderByNullType::NULLS_LAST; - unique_orders.emplace_back(BoundOrderByNode(order_type, order_by_type, std::move(unique_expr))); - global_sort = make_uniq(buffer_manager, unique_orders, payload_layout); - } else { - global_sort = make_uniq(buffer_manager, orders, payload_layout); - } - global_sort->external = force_external; -} - -optional_ptr WindowMergeSortTree::AddLocalSort() { - lock_guard local_sort_guard(lock); - auto local_sort = make_uniq(); - local_sort->Initialize(*global_sort, global_sort->buffer_manager); - local_sorts.emplace_back(std::move(local_sort)); - - return local_sorts.back().get(); -} - -WindowMergeSortTreeLocalState::WindowMergeSortTreeLocalState(WindowMergeSortTree &window_tree) - : window_tree(window_tree) { - sort_chunk.Initialize(window_tree.context, window_tree.global_sort->sort_layout.logical_types); - payload_chunk.Initialize(window_tree.context, window_tree.global_sort->payload_layout.GetTypes()); - local_sort = window_tree.AddLocalSort(); -} - -void 
WindowMergeSortTreeLocalState::SinkChunk(DataChunk &chunk, const idx_t row_idx, - optional_ptr filter_sel, idx_t filtered) { - // Sequence the payload column - auto &indices = payload_chunk.data[0]; - payload_chunk.SetCardinality(chunk); - indices.Sequence(int64_t(row_idx), 1, payload_chunk.size()); - - // Reference the sort columns - auto &sort_idx = window_tree.sort_idx; - for (column_t c = 0; c < sort_idx.size(); ++c) { - sort_chunk.data[c].Reference(chunk.data[sort_idx[c]]); - } - // Add the row numbers if we are uniquifying - if (sort_idx.size() < sort_chunk.ColumnCount()) { - sort_chunk.data[sort_idx.size()].Reference(indices); - } - sort_chunk.SetCardinality(chunk); - - // Apply FILTER clause, if any - if (filter_sel) { - sort_chunk.Slice(*filter_sel, filtered); - payload_chunk.Slice(*filter_sel, filtered); - } - - local_sort->SinkChunk(sort_chunk, payload_chunk); - - // Flush if we have too much data - if (local_sort->SizeInBytes() > window_tree.memory_per_thread) { - local_sort->Sort(*window_tree.global_sort, true); - } -} - -void WindowMergeSortTreeLocalState::ExecuteSortTask() { - switch (build_stage) { - case PartitionSortStage::SCAN: - window_tree.global_sort->AddLocalState(*window_tree.local_sorts[build_task]); - break; - case PartitionSortStage::MERGE: { - auto &global_sort = *window_tree.global_sort; - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); - break; - } - case PartitionSortStage::SORTED: - BuildLeaves(); - break; - default: - break; - } - - ++window_tree.tasks_completed; -} - -idx_t WindowMergeSortTree::MeasurePayloadBlocks() { - const auto &blocks = global_sort->sorted_blocks[0]->payload_data->data_blocks; - idx_t count = 0; - for (const auto &block : blocks) { - block_starts.emplace_back(count); - count += block->count; - } - block_starts.emplace_back(count); - - // Allocate the leaves. 
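- // Only one of the two trees is ever allocated: 32-bit indices are used when the row count - // fits (and external sorting is not forced), halving the size of the sorted payload; - // otherwise the tree falls back to 64-bit indices.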
- if (mst32) { - mst32->Allocate(count); - mst32->LowestLevel().resize(count); - } else if (mst64) { - mst64->Allocate(count); - mst64->LowestLevel().resize(count); - } - - return count; -} - -void WindowMergeSortTreeLocalState::BuildLeaves() { - auto &global_sort = *window_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { - return; - } - - PayloadScanner scanner(global_sort, build_task); - idx_t row_idx = window_tree.block_starts[build_task]; - for (;;) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); - const auto count = payload_chunk.size(); - if (count == 0) { - break; - } - auto &indices = payload_chunk.data[0]; - if (window_tree.mst32) { - auto &sorted = window_tree.mst32->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } else { - auto &sorted = window_tree.mst64->LowestLevel(); - auto data = FlatVector::GetData(indices); - std::copy(data, data + count, sorted.data() + row_idx); - } - row_idx += count; - } -} - -void WindowMergeSortTree::CleanupSort() { - global_sort.reset(); - local_sorts.clear(); -} - -bool WindowMergeSortTree::TryPrepareSortStage(WindowMergeSortTreeLocalState &lstate) { - lock_guard stage_guard(lock); - - switch (build_stage.load()) { - case PartitionSortStage::INIT: - total_tasks = local_sorts.size(); - tasks_assigned = 0; - tasks_completed = 0; - lstate.build_stage = build_stage = PartitionSortStage::SCAN; - lstate.build_task = tasks_assigned++; - return true; - case PartitionSortStage::SCAN: - // Process all the local sorts - if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::SCAN; - lstate.build_task = tasks_assigned++; - return true; - } else if (tasks_completed < tasks_assigned) { - return false; - } - global_sort->PrepareMergePhase(); - if (!(global_sort->sorted_blocks.size() / 2)) { - if (global_sort->sorted_blocks.empty()) { - lstate.build_stage = build_stage = PartitionSortStage::FINISHED; - return true; - } - MeasurePayloadBlocks(); - total_tasks = block_starts.size() - 1; - tasks_completed = 0; - tasks_assigned = 0; - lstate.build_stage = build_stage = PartitionSortStage::SORTED; - lstate.build_task = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.build_stage = build_stage = PartitionSortStage::MERGE; - total_tasks = local_sorts.size(); - tasks_assigned = 1; - tasks_completed = 0; - return true; - case PartitionSortStage::MERGE: - if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::MERGE; - ++tasks_assigned; - return true; - } else if (tasks_completed < tasks_assigned) { - return false; - } - global_sort->CompleteMergeRound(true); - if (!(global_sort->sorted_blocks.size() / 2)) { - MeasurePayloadBlocks(); - total_tasks = block_starts.size() - 1; - tasks_completed = 0; - tasks_assigned = 0; - lstate.build_stage = build_stage = PartitionSortStage::SORTED; - lstate.build_task = tasks_assigned++; - return true; - } - global_sort->InitializeMergeRound(); - lstate.build_stage = PartitionSortStage::MERGE; - total_tasks = local_sorts.size(); - tasks_assigned = 1; - tasks_completed = 0; - return true; - case PartitionSortStage::SORTED: - if (tasks_assigned < total_tasks) { - lstate.build_stage = PartitionSortStage::SORTED; - lstate.build_task = tasks_assigned++; - return true; - } else if (tasks_completed < tasks_assigned) { - lstate.build_stage = PartitionSortStage::FINISHED; - // Sleep while other tasks finish - return false; - } - CleanupSort(); - break; - default: - break; - } 
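- // Control falls out of the switch once the final SORTED task has completed (or the stage - // is already FINISHED), so this thread publishes FINISHED for all workers.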
- - lstate.build_stage = build_stage = PartitionSortStage::FINISHED; - - return true; -} - -void WindowMergeSortTreeLocalState::Sort() { - // Sort, merge and build the tree in parallel - while (window_tree.build_stage.load() != PartitionSortStage::FINISHED) { - if (window_tree.TryPrepareSortStage(*this)) { - ExecuteSortTask(); - } else { - std::this_thread::yield(); - } - } -} - -void WindowMergeSortTree::Build() { - if (mst32) { - mst32->Build(); - } else { - mst64->Build(); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_naive_aggregator.cpp b/src/duckdb/src/function/window/window_naive_aggregator.cpp deleted file mode 100644 index 448639e77..000000000 --- a/src/duckdb/src/function/window/window_naive_aggregator.cpp +++ /dev/null @@ -1,361 +0,0 @@ -#include "duckdb/function/window/window_naive_aggregator.hpp" -#include "duckdb/common/sort/sort.hpp" -#include "duckdb/function/window/window_collection.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" -#include "duckdb/function/window/window_aggregate_function.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowNaiveAggregator -//===--------------------------------------------------------------------===// -WindowNaiveAggregator::WindowNaiveAggregator(const WindowAggregateExecutor &executor, WindowSharedExpressions &shared) - : WindowAggregator(executor.wexpr, shared), executor(executor) { - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterCollection(order.expression, false)); - } -} - -WindowNaiveAggregator::~WindowNaiveAggregator() { -} - -class WindowNaiveState : public WindowAggregatorLocalState { -public: - struct HashRow { - explicit HashRow(WindowNaiveState &state) : state(state) { - } - - inline size_t operator()(const idx_t &i) const { - return state.Hash(i); - } - - WindowNaiveState &state; - }; - - struct EqualRow { - explicit EqualRow(WindowNaiveState &state) : state(state) { - } - - inline bool operator()(const idx_t &lhs, const idx_t &rhs) const { - return state.KeyEqual(lhs, rhs); - } - - WindowNaiveState &state; - }; - - using RowSet = std::unordered_set<idx_t, HashRow, EqualRow>; - - explicit WindowNaiveState(const WindowNaiveAggregator &gsink); - - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; - - void Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx); - -protected: - //! Flush the accumulated intermediate states into the result states - void FlushStates(const WindowAggregatorGlobalState &gsink); - - //! Hashes a value for the hash table - size_t Hash(idx_t rid); - //! Compares two values for the hash table - bool KeyEqual(const idx_t &lhs, const idx_t &rhs); - - //! The global state - const WindowNaiveAggregator &aggregator; - //! Data pointer that contains a vector of states, used for row aggregation - vector<data_t> state; - //! Reused result state container for the aggregate - Vector statef; - //! A vector of pointers to "state", used for buffering intermediate aggregates - Vector statep; - //! Input data chunk, used for leaf segment aggregation - DataChunk leaves; - //! The rows being updated. - SelectionVector update_sel; - //! Count of buffered values - idx_t flush_count; - //! The frame boundaries, used for EXCLUDE - SubFrames frames; - //! The optional hash table used for DISTINCT - Vector hashes; - //!
The state used for comparing the collection across chunk boundaries - unique_ptr comparer; - - //! The state used for scanning ORDER BY values from the collection - unique_ptr arg_orderer; - //! Reusable sort key chunk - DataChunk orderby_sort; - //! Reusable sort payload chunk - DataChunk orderby_payload; - //! Reusable sort key slicer - SelectionVector orderby_sel; - //! Reusable payload layout. - RowLayout payload_layout; -}; - -WindowNaiveState::WindowNaiveState(const WindowNaiveAggregator &aggregator_p) - : aggregator(aggregator_p), state(aggregator.state_size * STANDARD_VECTOR_SIZE), statef(LogicalType::POINTER), - statep((LogicalType::POINTER)), flush_count(0), hashes(LogicalType::HASH) { - InitSubFrames(frames, aggregator.exclude_mode); - - update_sel.Initialize(); - - // Build the finalise vector that just points to the result states - data_ptr_t state_ptr = state.data(); - D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR); - statef.SetVectorType(VectorType::CONSTANT_VECTOR); - statef.Flatten(STANDARD_VECTOR_SIZE); - auto fdata = FlatVector::GetData(statef); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { - fdata[i] = state_ptr; - state_ptr += aggregator.state_size; - } - - // Initialise any ORDER BY data - if (!aggregator.arg_order_idx.empty() && !arg_orderer) { - orderby_payload.Initialize(Allocator::DefaultAllocator(), {LogicalType::UBIGINT}); - payload_layout.Initialize(orderby_payload.GetTypes()); - orderby_sel.Initialize(); - } -} - -void WindowNaiveState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); - - // Set up the comparison scanner just in case - if (!comparer) { - comparer = make_uniq(*collection, aggregator.child_idx); - } - - // Set up the argument ORDER BY scanner if needed - if (!aggregator.arg_order_idx.empty() && !arg_orderer) { - arg_orderer = make_uniq(*collection, aggregator.arg_order_idx); - orderby_sort.Initialize(BufferAllocator::Get(gastate.context), arg_orderer->chunk.GetTypes()); - } - - // Initialise the chunks - const auto types = cursor->chunk.GetTypes(); - if (leaves.ColumnCount() == 0 && !types.empty()) { - leaves.Initialize(BufferAllocator::Get(gastate.context), types); - } -} - -void WindowNaiveState::FlushStates(const WindowAggregatorGlobalState &gsink) { - if (!flush_count) { - return; - } - - auto &scanned = cursor->chunk; - leaves.Slice(scanned, update_sel, flush_count); - - const auto &aggr = gsink.aggr; - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.update(leaves.data.data(), aggr_input_data, leaves.ColumnCount(), statep, flush_count); - - flush_count = 0; -} - -size_t WindowNaiveState::Hash(idx_t rid) { - D_ASSERT(cursor->RowIsVisible(rid)); - auto s = cursor->RowOffset(rid); - auto &scanned = cursor->chunk; - SelectionVector sel(&s); - leaves.Slice(scanned, sel, 1); - leaves.Hash(hashes); - - return *FlatVector::GetData(hashes); -} - -bool WindowNaiveState::KeyEqual(const idx_t &lidx, const idx_t &ridx) { - // One of the indices will be scanned, so make it the left one - auto lhs = lidx; - auto rhs = ridx; - if (!cursor->RowIsVisible(lhs)) { - std::swap(lhs, rhs); - D_ASSERT(cursor->RowIsVisible(lhs)); - } - - auto &scanned = cursor->chunk; - auto l = cursor->RowOffset(lhs); - SelectionVector lsel(&l); - - auto rreader = cursor.get(); - if (!cursor->RowIsVisible(rhs)) { - // Values on different pages! 
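- // The main cursor must stay parked on the visible row, so a second cursor (comparer) - // pages in the other row and the column-wise comparison below reads one row from each chunk.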
- rreader = comparer.get(); - rreader->Seek(rhs); - } - auto rscanned = &rreader->chunk; - auto r = rreader->RowOffset(rhs); - SelectionVector rsel(&r); - - sel_t f = 0; - SelectionVector fsel(&f); - - for (column_t c = 0; c < scanned.ColumnCount(); ++c) { - Vector left(scanned.data[c], lsel, 1); - Vector right(rscanned->data[c], rsel, 1); - if (!VectorOperations::NotDistinctFrom(left, right, nullptr, 1, nullptr, &fsel)) { - return false; - } - } - - return true; -} - -void WindowNaiveState::Evaluate(const WindowAggregatorGlobalState &gsink, const DataChunk &bounds, Vector &result, - idx_t count, idx_t row_idx) { - const auto &aggr = gsink.aggr; - auto &filter_mask = gsink.filter_mask; - const auto types = cursor->chunk.GetTypes(); - - auto fdata = FlatVector::GetData(statef); - auto pdata = FlatVector::GetData(statep); - - HashRow hash_row(*this); - EqualRow equal_row(*this); - RowSet row_set(STANDARD_VECTOR_SIZE, hash_row, equal_row); - - WindowAggregator::EvaluateSubFrames(bounds, aggregator.exclude_mode, count, row_idx, frames, [&](idx_t rid) { - auto agg_state = fdata[rid]; - aggr.function.initialize(aggr.function, agg_state); - - // Reset the DISTINCT hash table - row_set.clear(); - - // Sort the input rows by the argument - if (arg_orderer) { - auto &context = aggregator.executor.context; - auto &orders = aggregator.wexpr.arg_orders; - auto &buffer_manager = BufferManager::GetBufferManager(context); - GlobalSortState global_sort(buffer_manager, orders, payload_layout); - LocalSortState local_sort; - local_sort.Initialize(global_sort, global_sort.buffer_manager); - - idx_t orderby_count = 0; - auto orderby_row = FlatVector::GetData(orderby_payload.data[0]); - for (const auto &frame : frames) { - for (auto f = frame.start; f < frame.end; ++f) { - // FILTER before the ORDER BY - if (!filter_mask.RowIsValid(f)) { - continue; - } - - if (!arg_orderer->RowIsVisible(f) || orderby_count >= STANDARD_VECTOR_SIZE) { - if (orderby_count) { - orderby_sort.Reference(arg_orderer->chunk); - orderby_sort.Slice(orderby_sel, orderby_count); - orderby_payload.SetCardinality(orderby_count); - local_sort.SinkChunk(orderby_sort, orderby_payload); - } - orderby_count = 0; - arg_orderer->Seek(f); - } - orderby_row[orderby_count] = f; - orderby_sel.set_index(orderby_count++, arg_orderer->RowOffset(f)); - } - } - if (orderby_count) { - orderby_sort.Reference(arg_orderer->chunk); - orderby_sort.Slice(orderby_sel, orderby_count); - orderby_payload.SetCardinality(orderby_count); - local_sort.SinkChunk(orderby_sort, orderby_payload); - } - - global_sort.AddLocalState(local_sort); - if (global_sort.sorted_blocks.empty()) { - return; - } - global_sort.PrepareMergePhase(); - while (global_sort.sorted_blocks.size() > 1) { - global_sort.InitializeMergeRound(); - MergeSorter merge_sorter(global_sort, global_sort.buffer_manager); - merge_sorter.PerformInMergeRound(); - global_sort.CompleteMergeRound(false); - } - - PayloadScanner scanner(global_sort); - while (scanner.Remaining()) { - orderby_payload.Reset(); - scanner.Scan(orderby_payload); - orderby_row = FlatVector::GetData(orderby_payload.data[0]); - for (idx_t i = 0; i < orderby_payload.size(); ++i) { - const auto f = orderby_row[i]; - // Seek to the current position - if (!cursor->RowIsVisible(f)) { - // We need to flush when we cross a chunk boundary - FlushStates(gsink); - cursor->Seek(f); - } - - // Filter out duplicates - if (aggr.IsDistinct() && !row_set.insert(f).second) { - continue; - } - - pdata[flush_count] = agg_state; - update_sel[flush_count++] = 
cursor->RowOffset(f); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(gsink); - } - } - } - return; - } - - // Just update the aggregate with the unfiltered input rows - for (const auto &frame : frames) { - for (auto f = frame.start; f < frame.end; ++f) { - if (!filter_mask.RowIsValid(f)) { - continue; - } - - // Seek to the current position - if (!cursor->RowIsVisible(f)) { - // We need to flush when we cross a chunk boundary - FlushStates(gsink); - cursor->Seek(f); - } - - // Filter out duplicates - if (aggr.IsDistinct() && !row_set.insert(f).second) { - continue; - } - - pdata[flush_count] = agg_state; - update_sel[flush_count++] = cursor->RowOffset(f); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(gsink); - } - } - } - }); - - // Flush the final states - FlushStates(gsink); - - // Finalise the result aggregates and write to the result - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); - - // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); - } -} - -unique_ptr WindowNaiveAggregator::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(*this); -} - -void WindowNaiveAggregator::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - const auto &gnstate = gsink.Cast(); - auto &lnstate = lstate.Cast(); - lnstate.Evaluate(gnstate, bounds, result, count, row_idx); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_rank_function.cpp b/src/duckdb/src/function/window/window_rank_function.cpp deleted file mode 100644 index b128ea89a..000000000 --- a/src/duckdb/src/function/window/window_rank_function.cpp +++ /dev/null @@ -1,288 +0,0 @@ -#include "duckdb/function/window/window_rank_function.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/function/window/window_token_tree.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowPeerGlobalState -//===--------------------------------------------------------------------===// -class WindowPeerGlobalState : public WindowExecutorGlobalState { -public: - WindowPeerGlobalState(const WindowPeerExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask) { - if (!executor.arg_order_idx.empty()) { - token_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, - payload_count); - } - } - - //! The token tree for ORDER BY arguments - unique_ptr token_tree; -}; - -//===--------------------------------------------------------------------===// -// WindowPeerLocalState -//===--------------------------------------------------------------------===// -// Base class for non-aggregate functions that use peer boundaries -class WindowPeerLocalState : public WindowExecutorBoundsState { -public: - explicit WindowPeerLocalState(const WindowPeerGlobalState &gpstate) - : WindowExecutorBoundsState(gpstate), gpstate(gpstate) { - if (gpstate.token_tree) { - local_tree = gpstate.token_tree->GetLocalState(); - } - } - - //! 
Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; - //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; - - void NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx); - - uint64_t dense_rank = 1; - uint64_t rank_equal = 0; - uint64_t rank = 1; - - //! The corresponding global peer state - const WindowPeerGlobalState &gpstate; - //! The optional sorting state for secondary sorts - unique_ptr local_tree; -}; - -void WindowPeerLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); - - if (local_tree) { - auto &local_tokens = local_tree->Cast(); - local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); - } -} - -void WindowPeerLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); - - if (local_tree) { - auto &local_tokens = local_tree->Cast(); - local_tokens.Sort(); - local_tokens.window_tree.Build(); - } -} - -void WindowPeerLocalState::NextRank(idx_t partition_begin, idx_t peer_begin, idx_t row_idx) { - if (partition_begin == row_idx) { - dense_rank = 1; - rank = 1; - rank_equal = 0; - } else if (peer_begin == row_idx) { - dense_rank++; - rank += rank_equal; - rank_equal = 0; - } - rank_equal++; -} - -//===--------------------------------------------------------------------===// -// WindowPeerExecutor -//===--------------------------------------------------------------------===// -WindowPeerExecutor::WindowPeerExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); - } -} - -unique_ptr WindowPeerExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -unique_ptr WindowPeerExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate.Cast()); -} - -//===--------------------------------------------------------------------===// -// WindowRankExecutor -//===--------------------------------------------------------------------===// -WindowRankExecutor::WindowRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { -} - -void WindowRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); - auto rdata = FlatVector::GetData(result); - - if (gpeer.token_tree) { - auto frame_begin = FlatVector::GetData(lpeer.bounds.data[FRAME_BEGIN]); - auto frame_end = FlatVector::GetData(lpeer.bounds.data[FRAME_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - rdata[i] = gpeer.token_tree->Rank(frame_begin[i], frame_end[i], row_idx); - } - return; - } - - // Reset to "previous" row - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto peer_begin = 
FlatVector::GetData<const idx_t>(lpeer.bounds.data[PEER_BEGIN]); - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = lpeer.rank; - } -} - -//===--------------------------------------------------------------------===// -// WindowDenseRankExecutor -//===--------------------------------------------------------------------===// -WindowDenseRankExecutor::WindowDenseRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { -} - -void WindowDenseRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &lpeer = lstate.Cast<WindowPeerLocalState>(); - - auto &order_mask = gstate.order_mask; - auto partition_begin = FlatVector::GetData<const idx_t>(lpeer.bounds.data[PARTITION_BEGIN]); - auto peer_begin = FlatVector::GetData<const idx_t>(lpeer.bounds.data[PEER_BEGIN]); - auto rdata = FlatVector::GetData<int64_t>(result); - - // Reset to "previous" row - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - // The previous dense rank is the number of order mask bits in [partition_begin, row_idx) - lpeer.dense_rank = 0; - - auto order_begin = partition_begin[0]; - idx_t begin_idx; - idx_t begin_offset; - order_mask.GetEntryIndex(order_begin, begin_idx, begin_offset); - - auto order_end = row_idx; - idx_t end_idx; - idx_t end_offset; - order_mask.GetEntryIndex(order_end, end_idx, end_offset); - - // If they are in the same entry, just loop - if (begin_idx == end_idx) { - const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < end_offset; ++begin_offset) { - lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - } - } else { - // Count the ragged bits at the start of the partition - if (begin_offset) { - const auto entry = order_mask.GetValidityEntry(begin_idx); - for (; begin_offset < order_mask.BITS_PER_VALUE; ++begin_offset) { - lpeer.dense_rank += order_mask.RowIsValid(entry, begin_offset); - ++order_begin; - } - ++begin_idx; - } - - // Count the aligned bits.
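- // e.g. with 64-bit validity entries, partition_begin = 70 and row_idx = 200: the ragged - // loop above counts the set bits at positions 70..127, and the aligned CountValid below - // counts positions 128..199; each set bit marks the start of a peer group.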
- ValidityMask tail_mask(order_mask.GetData() + begin_idx, end_idx - begin_idx); - lpeer.dense_rank += tail_mask.CountValid(order_end - order_begin); - } - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - rdata[i] = NumericCast(lpeer.dense_rank); - } -} - -//===--------------------------------------------------------------------===// -// WindowPercentRankExecutor -//===--------------------------------------------------------------------===// -WindowPercentRankExecutor::WindowPercentRankExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { -} - -void WindowPercentRankExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); - auto rdata = FlatVector::GetData(result); - - if (gpeer.token_tree) { - auto frame_begin = FlatVector::GetData(lpeer.bounds.data[FRAME_BEGIN]); - auto frame_end = FlatVector::GetData(lpeer.bounds.data[FRAME_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - auto denom = static_cast(NumericCast(frame_end[i] - frame_begin[i] - 1)); - const auto rank = gpeer.token_tree->Rank(frame_begin[i], frame_end[i], row_idx); - double percent_rank = denom > 0 ? ((double)rank - 1) / denom : 0; - rdata[i] = percent_rank; - } - return; - } - - // Reset to "previous" row - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lpeer.bounds.data[PARTITION_END]); - auto peer_begin = FlatVector::GetData(lpeer.bounds.data[PEER_BEGIN]); - lpeer.rank = (peer_begin[0] - partition_begin[0]) + 1; - lpeer.rank_equal = (row_idx - peer_begin[0]); - - for (idx_t i = 0; i < count; ++i, ++row_idx) { - lpeer.NextRank(partition_begin[i], peer_begin[i], row_idx); - auto denom = static_cast(NumericCast(partition_end[i] - partition_begin[i] - 1)); - double percent_rank = denom > 0 ? ((double)lpeer.rank - 1) / denom : 0; - rdata[i] = percent_rank; - } -} - -//===--------------------------------------------------------------------===// -// WindowCumeDistExecutor -//===--------------------------------------------------------------------===// -WindowCumeDistExecutor::WindowCumeDistExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowPeerExecutor(wexpr, context, shared) { -} - -void WindowCumeDistExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &gpeer = gstate.Cast(); - auto &lpeer = lstate.Cast(); - auto rdata = FlatVector::GetData(result); - - if (gpeer.token_tree) { - auto frame_begin = FlatVector::GetData(lpeer.bounds.data[FRAME_BEGIN]); - auto frame_end = FlatVector::GetData(lpeer.bounds.data[FRAME_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - const auto denom = static_cast(NumericCast(frame_end[i] - frame_begin[i])); - const auto peer_end = gpeer.token_tree->PeerEnd(frame_begin[i], frame_end[i], row_idx); - const auto num = static_cast(peer_end - frame_begin[i]); - rdata[i] = denom > 0 ? 
(num / denom) : 0; - } - return; - } - - auto partition_begin = FlatVector::GetData(lpeer.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lpeer.bounds.data[PARTITION_END]); - auto peer_end = FlatVector::GetData(lpeer.bounds.data[PEER_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - const auto denom = static_cast(NumericCast(partition_end[i] - partition_begin[i])); - const auto num = static_cast(peer_end[i] - partition_begin[i]); - rdata[i] = denom > 0 ? (num / denom) : 0; - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_rownumber_function.cpp b/src/duckdb/src/function/window/window_rownumber_function.cpp deleted file mode 100644 index f946583f1..000000000 --- a/src/duckdb/src/function/window/window_rownumber_function.cpp +++ /dev/null @@ -1,191 +0,0 @@ -#include "duckdb/function/window/window_rownumber_function.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/function/window/window_token_tree.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowRowNumberGlobalState -//===--------------------------------------------------------------------===// -class WindowRowNumberGlobalState : public WindowExecutorGlobalState { -public: - WindowRowNumberGlobalState(const WindowRowNumberExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), - ntile_idx(executor.ntile_idx) { - if (!executor.arg_order_idx.empty()) { - // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their position in - // the input data, such that two elements never compare as equal." - token_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, - payload_count, true); - } - } - - //! The token tree for ORDER BY arguments - unique_ptr token_tree; - - //! The evaluation index for NTILE - const column_t ntile_idx; -}; - -//===--------------------------------------------------------------------===// -// WindowRowNumberLocalState -//===--------------------------------------------------------------------===// -class WindowRowNumberLocalState : public WindowExecutorBoundsState { -public: - explicit WindowRowNumberLocalState(const WindowRowNumberGlobalState &grstate) - : WindowExecutorBoundsState(grstate), grstate(grstate) { - if (grstate.token_tree) { - local_tree = grstate.token_tree->GetLocalState(); - } - } - - //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; - //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; - - //! The corresponding global peer state - const WindowRowNumberGlobalState &grstate; - //! 
The optional sorting state for secondary sorts - unique_ptr local_tree; -}; - -void WindowRowNumberLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); - - if (local_tree) { - auto &local_tokens = local_tree->Cast(); - local_tokens.SinkChunk(sink_chunk, input_idx, nullptr, 0); - } -} - -void WindowRowNumberLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); - - if (local_tree) { - auto &local_tokens = local_tree->Cast(); - local_tokens.Sort(); - local_tokens.window_tree.Build(); - } -} - -//===--------------------------------------------------------------------===// -// WindowRowNumberExecutor -//===--------------------------------------------------------------------===// -WindowRowNumberExecutor::WindowRowNumberExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); - } -} - -unique_ptr WindowRowNumberExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -unique_ptr -WindowRowNumberExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - return make_uniq(gstate.Cast()); -} - -void WindowRowNumberExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &grstate = gstate.Cast(); - auto &lrstate = lstate.Cast(); - auto rdata = FlatVector::GetData(result); - - if (grstate.token_tree) { - auto frame_begin = FlatVector::GetData(lrstate.bounds.data[FRAME_BEGIN]); - auto frame_end = FlatVector::GetData(lrstate.bounds.data[FRAME_END]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - // Row numbers are unique ranks - rdata[i] = grstate.token_tree->Rank(frame_begin[i], frame_end[i], row_idx); - } - return; - } - - auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - rdata[i] = row_idx - partition_begin[i] + 1; - } -} - -//===--------------------------------------------------------------------===// -// WindowNtileExecutor -//===--------------------------------------------------------------------===// -WindowNtileExecutor::WindowNtileExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowRowNumberExecutor(wexpr, context, shared) { - - // NTILE has one argument - ntile_idx = shared.RegisterEvaluate(wexpr.children[0]); -} - -void WindowNtileExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &grstate = gstate.Cast(); - auto &lrstate = lstate.Cast(); - auto partition_begin = FlatVector::GetData(lrstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(lrstate.bounds.data[PARTITION_END]); - if (grstate.token_tree) { - // With secondary sorts, we restrict to the frame boundaries, but everything else should compute the same. 
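- // The token tree then supplies the row's rank within the frame in place of its physical - // offset from the partition start, so the bucket arithmetic below is unchanged.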
- partition_begin = FlatVector::GetData(lrstate.bounds.data[FRAME_BEGIN]); - partition_end = FlatVector::GetData(lrstate.bounds.data[FRAME_END]); - } - auto rdata = FlatVector::GetData(result); - WindowInputExpression ntile_col(eval_chunk, ntile_idx); - for (idx_t i = 0; i < count; ++i, ++row_idx) { - if (ntile_col.CellIsNull(i)) { - FlatVector::SetNull(result, i, true); - } else { - auto n_param = ntile_col.GetCell(i); - if (n_param < 1) { - throw InvalidInputException("Argument for ntile must be greater than zero"); - } - // With thanks from SQLite's ntileValueFunc() - auto n_total = NumericCast(partition_end[i] - partition_begin[i]); - if (n_param > n_total) { - // more groups allowed than we have values - // map every entry to a unique group - n_param = n_total; - } - int64_t n_size = (n_total / n_param); - // find the row idx within the group - D_ASSERT(row_idx >= partition_begin[i]); - idx_t partition_idx = 0; - if (grstate.token_tree) { - partition_idx = grstate.token_tree->Rank(partition_begin[i], partition_end[i], row_idx) - 1; - } else { - partition_idx = row_idx - partition_begin[i]; - } - auto adjusted_row_idx = NumericCast(partition_idx); - - // now compute the ntile - int64_t n_large = n_total - n_param * n_size; - int64_t i_small = n_large * (n_size + 1); - int64_t result_ntile; - - D_ASSERT((n_large * (n_size + 1) + (n_param - n_large) * n_size) == n_total); - - if (adjusted_row_idx < i_small) { - result_ntile = 1 + adjusted_row_idx / (n_size + 1); - } else { - result_ntile = 1 + n_large + (adjusted_row_idx - i_small) / n_size; - } - // result has to be between [1, NTILE] - D_ASSERT(result_ntile >= 1 && result_ntile <= n_param); - rdata[i] = result_ntile; - } - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_segment_tree.cpp b/src/duckdb/src/function/window/window_segment_tree.cpp deleted file mode 100644 index 6708ae90b..000000000 --- a/src/duckdb/src/function/window/window_segment_tree.cpp +++ /dev/null @@ -1,594 +0,0 @@ -#include "duckdb/function/window/window_segment_tree.hpp" - -#include "duckdb/function/window/window_aggregate_states.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -#include - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowSegmentTree -//===--------------------------------------------------------------------===// -bool WindowSegmentTree::CanAggregate(const BoundWindowExpression &wexpr) { - if (!wexpr.aggregate) { - return false; - } - - return !wexpr.distinct && wexpr.arg_orders.empty(); -} - -WindowSegmentTree::WindowSegmentTree(const BoundWindowExpression &wexpr, WindowSharedExpressions &shared) - : WindowAggregator(wexpr, shared) { -} - -class WindowSegmentTreeGlobalState : public WindowAggregatorGlobalState { -public: - using AtomicCounters = vector>; - - WindowSegmentTreeGlobalState(ClientContext &context, const WindowSegmentTree &aggregator, idx_t group_count); - - ArenaAllocator &CreateTreeAllocator() { - lock_guard tree_lock(lock); - tree_allocators.emplace_back(make_uniq(Allocator::DefaultAllocator())); - return *tree_allocators.back(); - } - - //! The owning aggregator - const WindowSegmentTree &tree; - //! The actual window segment tree: an array of aggregate states that represent all the intermediate nodes - WindowAggregateStates levels_flat_native; - //! For each level, the starting location in the levels_flat_native array - vector levels_flat_start; - //! The level being built (read) - std::atomic build_level; - //! 
The number of entries started so far at each level - unique_ptr build_started; - //! The number of entries completed so far at each level - unique_ptr build_completed; - //! The tree allocators. - //! We need to hold onto them for the tree lifetime, - //! not the lifetime of the local state that constructed part of the tree - vector> tree_allocators; - - // TREE_FANOUT needs to cleanly divide STANDARD_VECTOR_SIZE - static constexpr idx_t TREE_FANOUT = 16; -}; - -class WindowSegmentTreePart { -public: - //! Right side nodes need to be cached and processed in reverse order - using RightEntry = std::pair; - - enum FramePart : uint8_t { FULL = 0, LEFT = 1, RIGHT = 2 }; - - WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, unique_ptr cursor, - const ValidityArray &filter_mask); - ~WindowSegmentTreePart(); - - unique_ptr Copy() const { - return make_uniq(allocator, aggr, cursor->Copy(), filter_mask); - } - - void FlushStates(bool combining); - void ExtractFrame(idx_t begin, idx_t end, data_ptr_t current_state); - void WindowSegmentValue(const WindowSegmentTreeGlobalState &tree, idx_t l_idx, idx_t begin, idx_t end, - data_ptr_t current_state); - //! Writes result and calls destructors - void Finalize(Vector &result, idx_t count); - - void Combine(WindowSegmentTreePart &other, idx_t count); - - void Evaluate(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, Vector &result, - idx_t count, idx_t row_idx, FramePart frame_part); - -protected: - //! Initialises the accumulation state vector (statef) - void Initialize(idx_t count); - //! Accumulate upper tree levels - void EvaluateUpperLevels(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, - idx_t count, idx_t row_idx, FramePart frame_part); - void EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, idx_t count, - idx_t row_idx, FramePart frame_part, FramePart leaf_part); - -public: - //! Allocator for aggregates - ArenaAllocator &allocator; - //! The aggregate function - const AggregateObject &aggr; - //! Order insensitive aggregate (we can optimise internal combines) - const bool order_insensitive; - //! The filtered rows in inputs - const ValidityArray &filter_mask; - //! The size of a single aggregate state - const idx_t state_size; - //! Data pointer that contains a vector of states, used for intermediate window segment aggregation - vector state; - //! Scanned data state - unique_ptr cursor; - //! Input data chunk, used for leaf segment aggregation - DataChunk leaves; - //! The filtered rows in inputs. - SelectionVector filter_sel; - //! A vector of pointers to "state", used for intermediate window segment aggregation - Vector statep; - //! Reused state pointers for combining segment tree levels - Vector statel; - //! Reused result state container for the window functions - Vector statef; - //! Count of buffered values - idx_t flush_count; - //! Cache of right side tree ranges for ordered aggregates - vector right_stack; -}; - -class WindowSegmentTreeState : public WindowAggregatorLocalState { -public: - WindowSegmentTreeState() { - } - - void Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) override; - void Evaluate(const WindowSegmentTreeGlobalState &gsink, const DataChunk &bounds, Vector &result, idx_t count, - idx_t row_idx); - //! The left (default) segment tree part - unique_ptr part; - //! 
The right segment tree part (for EXCLUDE) - unique_ptr right_part; -}; - -void WindowSegmentTree::Finalize(WindowAggregatorState &gsink, WindowAggregatorState &lstate, CollectionPtr collection, - const FrameStats &stats) { - WindowAggregator::Finalize(gsink, lstate, collection, stats); - - auto &gasink = gsink.Cast(); - ++gasink.finalized; -} - -WindowSegmentTreePart::WindowSegmentTreePart(ArenaAllocator &allocator, const AggregateObject &aggr, - unique_ptr cursor_p, const ValidityArray &filter_mask) - : allocator(allocator), aggr(aggr), - order_insensitive(aggr.function.order_dependent == AggregateOrderDependent::NOT_ORDER_DEPENDENT), - filter_mask(filter_mask), state_size(aggr.function.state_size(aggr.function)), - state(state_size * STANDARD_VECTOR_SIZE), cursor(std::move(cursor_p)), statep(LogicalType::POINTER), - statel(LogicalType::POINTER), statef(LogicalType::POINTER), flush_count(0) { - - auto &inputs = cursor->chunk; - if (inputs.ColumnCount() > 0) { - leaves.Initialize(Allocator::DefaultAllocator(), inputs.GetTypes()); - filter_sel.Initialize(); - } - - // Build the finalise vector that just points to the result states - data_ptr_t state_ptr = state.data(); - D_ASSERT(statef.GetVectorType() == VectorType::FLAT_VECTOR); - statef.SetVectorType(VectorType::CONSTANT_VECTOR); - statef.Flatten(STANDARD_VECTOR_SIZE); - auto fdata = FlatVector::GetData(statef); - for (idx_t i = 0; i < STANDARD_VECTOR_SIZE; ++i) { - fdata[i] = state_ptr; - state_ptr += state_size; - } -} - -WindowSegmentTreePart::~WindowSegmentTreePart() { -} - -unique_ptr WindowSegmentTree::GetGlobalState(ClientContext &context, idx_t group_count, - const ValidityMask &partition_mask) const { - return make_uniq(context, *this, group_count); -} - -unique_ptr WindowSegmentTree::GetLocalState(const WindowAggregatorState &gstate) const { - return make_uniq(); -} - -void WindowSegmentTreePart::FlushStates(bool combining) { - if (!flush_count) { - return; - } - - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - if (combining) { - statel.Verify(flush_count); - aggr.function.combine(statel, statep, aggr_input_data, flush_count); - } else { - auto &scanned = cursor->chunk; - leaves.Slice(scanned, filter_sel, flush_count); - aggr.function.update(&leaves.data[0], aggr_input_data, leaves.ColumnCount(), statep, flush_count); - } - - flush_count = 0; -} - -void WindowSegmentTreePart::Combine(WindowSegmentTreePart &other, idx_t count) { - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.combine(other.statef, statef, aggr_input_data, count); -} - -void WindowSegmentTreePart::ExtractFrame(idx_t begin, idx_t end, data_ptr_t state_ptr) { - const auto count = end - begin; - - // If we are not filtering, - // just update the shared dictionary selection to the range - // Otherwise set it to the input rows that pass the filter - auto states = FlatVector::GetData(statep); - if (filter_mask.AllValid()) { - const auto offset = cursor->RowOffset(begin); - for (idx_t i = 0; i < count; ++i) { - states[flush_count] = state_ptr; - filter_sel.set_index(flush_count++, offset + i); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(false); - } - } - } else { - for (idx_t i = begin; i < end; ++i) { - if (filter_mask.RowIsValid(i)) { - states[flush_count] = state_ptr; - filter_sel.set_index(flush_count++, cursor->RowOffset(i)); - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(false); - } - } - } - } -} - -void WindowSegmentTreePart::WindowSegmentValue(const 
WindowSegmentTreeGlobalState &tree, idx_t l_idx, idx_t begin, - idx_t end, data_ptr_t state_ptr) { - D_ASSERT(begin <= end); - auto &inputs = cursor->chunk; - if (begin == end || inputs.ColumnCount() == 0) { - return; - } - - const auto count = end - begin; - if (l_idx == 0) { - // Check the leaves when they cross chunk boundaries - while (begin < end) { - if (!cursor->RowIsVisible(begin)) { - FlushStates(false); - cursor->Seek(begin); - } - auto next = MinValue(end, cursor->state.next_row_index); - ExtractFrame(begin, next, state_ptr); - begin = next; - } - } else { - // find out where the states begin - auto begin_ptr = tree.levels_flat_native.GetStatePtr(begin + tree.levels_flat_start[l_idx - 1]); - // set up a vector of pointers that point towards the set of states - auto ldata = FlatVector::GetData(statel); - auto pdata = FlatVector::GetData(statep); - for (idx_t i = 0; i < count; i++) { - pdata[flush_count] = state_ptr; - ldata[flush_count++] = begin_ptr; - begin_ptr += state_size; - if (flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(true); - } - } - } -} -void WindowSegmentTreePart::Finalize(Vector &result, idx_t count) { - // Finalise the result aggregates and write to result if write_result is set - AggregateInputData aggr_input_data(aggr.GetFunctionData(), allocator); - aggr.function.finalize(statef, aggr_input_data, result, count, 0); - - // Destruct the result aggregates - if (aggr.function.destructor) { - aggr.function.destructor(statef, aggr_input_data, count); - } -} - -WindowSegmentTreeGlobalState::WindowSegmentTreeGlobalState(ClientContext &context, const WindowSegmentTree &aggregator, - idx_t group_count) - : WindowAggregatorGlobalState(context, aggregator, group_count), tree(aggregator), levels_flat_native(aggr) { - - D_ASSERT(!aggregator.wexpr.children.empty()); - - // compute space required to store internal nodes of segment tree - levels_flat_start.push_back(0); - - idx_t levels_flat_offset = 0; - idx_t level_current = 0; - // level 0 is data itself - idx_t level_size; - // iterate over the levels of the segment tree - while ((level_size = - (level_current == 0 ? 
group_count : levels_flat_offset - levels_flat_start[level_current - 1])) > 1) { - for (idx_t pos = 0; pos < level_size; pos += TREE_FANOUT) { - levels_flat_offset++; - } - - levels_flat_start.push_back(levels_flat_offset); - level_current++; - } - - // Corner case: single element in the window - if (levels_flat_offset == 0) { - ++levels_flat_offset; - } - - levels_flat_native.Initialize(levels_flat_offset); - - // Start by building from the bottom level - build_level = 0; - - build_started = make_uniq(levels_flat_start.size()); - for (auto &counter : *build_started) { - counter = 0; - } - - build_completed = make_uniq(levels_flat_start.size()); - for (auto &counter : *build_completed) { - counter = 0; - } -} - -void WindowSegmentTreeState::Finalize(WindowAggregatorGlobalState &gastate, CollectionPtr collection) { - WindowAggregatorLocalState::Finalize(gastate, collection); - - // Single part for constructing the tree - auto &gstate = gastate.Cast(); - auto cursor = make_uniq(*collection, gastate.aggregator.child_idx); - const auto leaf_count = collection->size(); - auto &filter_mask = gstate.filter_mask; - WindowSegmentTreePart gtstate(gstate.CreateTreeAllocator(), gastate.aggr, std::move(cursor), filter_mask); - - auto &levels_flat_native = gstate.levels_flat_native; - const auto &levels_flat_start = gstate.levels_flat_start; - // iterate over the levels of the segment tree - for (;;) { - const idx_t level_current = gstate.build_level.load(); - if (level_current >= levels_flat_start.size()) { - break; - } - - // level 0 is data itself - const auto level_size = - (level_current == 0 ? leaf_count : levels_flat_start[level_current] - levels_flat_start[level_current - 1]); - if (level_size <= 1) { - break; - } - const idx_t build_count = (level_size + gstate.TREE_FANOUT - 1) / gstate.TREE_FANOUT; - - // Build the next fan-in - const idx_t build_idx = (*gstate.build_started).at(level_current)++; - if (build_idx >= build_count) { - // Nothing left at this level, so wait until other threads are done. - // Since we are only building TREE_FANOUT values at a time, this will be quick. - while (level_current == gstate.build_level.load()) { - std::this_thread::sleep_for(std::chrono::milliseconds(1)); - } - continue; - } - - // compute the aggregate for this entry in the segment tree - const idx_t pos = build_idx * gstate.TREE_FANOUT; - const idx_t levels_flat_offset = levels_flat_start[level_current] + build_idx; - auto state_ptr = levels_flat_native.GetStatePtr(levels_flat_offset); - gtstate.WindowSegmentValue(gstate, level_current, pos, MinValue(level_size, pos + gstate.TREE_FANOUT), - state_ptr); - gtstate.FlushStates(level_current > 0); - - // If that was the last one, mark the level as complete. 
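// (Restating the build protocol for clarity: build_started hands out fan-in work items and
// build_completed counts them back in. The thread that reports the last completion for a level
// advances build_level, which releases the spin-waiting threads above and publishes this
// level's states as input for the next round of fan-ins.)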
- const idx_t build_complete = ++(*gstate.build_completed).at(level_current); - if (build_complete == build_count) { - gstate.build_level++; - continue; - } - } -} - -void WindowSegmentTree::Evaluate(const WindowAggregatorState &gsink, WindowAggregatorState &lstate, - const DataChunk &bounds, Vector &result, idx_t count, idx_t row_idx) const { - const auto >state = gsink.Cast(); - auto <state = lstate.Cast(); - ltstate.Evaluate(gtstate, bounds, result, count, row_idx); -} - -void WindowSegmentTreeState::Evaluate(const WindowSegmentTreeGlobalState >state, const DataChunk &bounds, - Vector &result, idx_t count, idx_t row_idx) { - auto window_begin = FlatVector::GetData(bounds.data[FRAME_BEGIN]); - auto window_end = FlatVector::GetData(bounds.data[FRAME_END]); - auto peer_begin = FlatVector::GetData(bounds.data[PEER_BEGIN]); - auto peer_end = FlatVector::GetData(bounds.data[PEER_END]); - - if (!part) { - part = make_uniq(allocator, gtstate.aggr, cursor->Copy(), gtstate.filter_mask); - } - - if (gtstate.aggregator.exclude_mode != WindowExcludeMode::NO_OTHER) { - // 1. evaluate the tree left of the excluded part - part->Evaluate(gtstate, window_begin, peer_begin, result, count, row_idx, WindowSegmentTreePart::LEFT); - - // 2. set up a second state for the right of the excluded part - if (!right_part) { - right_part = part->Copy(); - } - - // 3. evaluate the tree right of the excluded part - right_part->Evaluate(gtstate, peer_end, window_end, result, count, row_idx, WindowSegmentTreePart::RIGHT); - - // 4. combine the buffer state into the Segment Tree State - part->Combine(*right_part, count); - } else { - part->Evaluate(gtstate, window_begin, window_end, result, count, row_idx, WindowSegmentTreePart::FULL); - } - - part->Finalize(result, count); -} - -void WindowSegmentTreePart::Evaluate(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, const idx_t *ends, - Vector &result, idx_t count, idx_t row_idx, FramePart frame_part) { - Initialize(count); - - if (order_insensitive) { - // First pass: aggregate the segment tree nodes with sharing - EvaluateUpperLevels(tree, begins, ends, count, row_idx, frame_part); - - // Second pass: aggregate the ragged leaves - EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::FULL); - } else { - // Evaluate leaves in order - EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::LEFT); - EvaluateUpperLevels(tree, begins, ends, count, row_idx, frame_part); - EvaluateLeaves(tree, begins, ends, count, row_idx, frame_part, FramePart::RIGHT); - } -} - -void WindowSegmentTreePart::Initialize(idx_t count) { - auto fdata = FlatVector::GetData(statef); - for (idx_t rid = 0; rid < count; ++rid) { - auto state_ptr = fdata[rid]; - aggr.function.initialize(aggr.function, state_ptr); - } -} - -void WindowSegmentTreePart::EvaluateUpperLevels(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, - const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part) { - auto fdata = FlatVector::GetData(statef); - - const auto exclude_mode = tree.tree.exclude_mode; - const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - - const auto max_level = tree.levels_flat_start.size() + 1; - right_stack.resize(max_level, {0, 0}); - - // Share adjacent identical states - // We do this first because we want to share only tree aggregations - idx_t prev_begin = 1; - 
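// Note that (prev_begin, prev_end) start out as the impossible range (1, 0) and prev_state
// starts out null, so the sharing test below cannot fire before the first level-1 result has
// been cached. As a worked example of the decomposition with TREE_FANOUT = 16: the frame
// [5, 70) becomes the ragged leaves [5, 16) and [64, 70) plus the level-1 nodes [1, 4),
// which cover rows [16, 64) with just three combines.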
idx_t prev_end = 0; - auto ldata = FlatVector::GetData(statel); - auto pdata = FlatVector::GetData(statep); - data_ptr_t prev_state = nullptr; - for (idx_t rid = 0, cur_row = row_idx; rid < count; ++rid, ++cur_row) { - auto state_ptr = fdata[rid]; - - auto begin = begin_on_curr_row ? cur_row + 1 : begins[rid]; - auto end = end_on_curr_row ? cur_row : ends[rid]; - if (begin >= end) { - continue; - } - - // Skip level 0 - idx_t l_idx = 0; - idx_t right_max = 0; - for (; l_idx < max_level; l_idx++) { - idx_t parent_begin = begin / tree.TREE_FANOUT; - idx_t parent_end = end / tree.TREE_FANOUT; - if (prev_state && l_idx == 1 && begin == prev_begin && end == prev_end) { - // Just combine the previous top level result - ldata[flush_count] = prev_state; - pdata[flush_count] = state_ptr; - if (++flush_count >= STANDARD_VECTOR_SIZE) { - FlushStates(true); - } - break; - } - - if (order_insensitive && l_idx == 1) { - prev_state = state_ptr; - prev_begin = begin; - prev_end = end; - } - - if (parent_begin == parent_end) { - if (l_idx) { - WindowSegmentValue(tree, l_idx, begin, end, state_ptr); - } - break; - } - idx_t group_begin = parent_begin * tree.TREE_FANOUT; - if (begin != group_begin) { - if (l_idx) { - WindowSegmentValue(tree, l_idx, begin, group_begin + tree.TREE_FANOUT, state_ptr); - } - parent_begin++; - } - idx_t group_end = parent_end * tree.TREE_FANOUT; - if (end != group_end) { - if (l_idx) { - if (order_insensitive) { - WindowSegmentValue(tree, l_idx, group_end, end, state_ptr); - } else { - right_stack[l_idx] = {group_end, end}; - right_max = l_idx; - } - } - } - begin = parent_begin; - end = parent_end; - } - - // Flush the right side values from left to right for order_sensitive aggregates - // As we go up the tree, the right side ranges move left, - // so we just cache them in a fixed size, preallocated array. - // Then we can just reverse scan the array and append the cached ranges. - for (l_idx = right_max; l_idx > 0; --l_idx) { - auto &right_entry = right_stack[l_idx]; - const auto group_end = right_entry.first; - const auto end = right_entry.second; - if (end) { - WindowSegmentValue(tree, l_idx, group_end, end, state_ptr); - right_entry = {0, 0}; - } - } - } - FlushStates(true); -} - -void WindowSegmentTreePart::EvaluateLeaves(const WindowSegmentTreeGlobalState &tree, const idx_t *begins, - const idx_t *ends, idx_t count, idx_t row_idx, FramePart frame_part, - FramePart leaf_part) { - - auto fdata = FlatVector::GetData(statef); - - // For order-sensitive aggregates, we have to process the ragged leaves in two pieces. - // The left side has to be added before the main tree, followed by the ragged right sides. - // The current row is the leftmost value of the right hand side. - const bool compute_left = leaf_part != FramePart::RIGHT; - const bool compute_right = leaf_part != FramePart::LEFT; - const auto exclude_mode = tree.tree.exclude_mode; - const bool begin_on_curr_row = frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - const bool end_on_curr_row = frame_part == FramePart::LEFT && exclude_mode == WindowExcludeMode::CURRENT_ROW; - // with EXCLUDE TIES, in addition to the frame part right of the peer group's end, we also need to consider the - // current row - const bool add_curr_row = compute_left && frame_part == FramePart::RIGHT && exclude_mode == WindowExcludeMode::TIES; - - for (idx_t rid = 0, cur_row = row_idx; rid < count; ++rid, ++cur_row) { - auto state_ptr = fdata[rid]; - - const auto begin = begin_on_curr_row ? 
cur_row + 1 : begins[rid]; - const auto end = end_on_curr_row ? cur_row : ends[rid]; - if (add_curr_row) { - WindowSegmentValue(tree, 0, cur_row, cur_row + 1, state_ptr); - } - if (begin >= end) { - continue; - } - - idx_t parent_begin = begin / tree.TREE_FANOUT; - idx_t parent_end = end / tree.TREE_FANOUT; - if (parent_begin == parent_end) { - if (compute_left) { - WindowSegmentValue(tree, 0, begin, end, state_ptr); - } - continue; - } - - idx_t group_begin = parent_begin * tree.TREE_FANOUT; - if (begin != group_begin && compute_left) { - WindowSegmentValue(tree, 0, begin, group_begin + tree.TREE_FANOUT, state_ptr); - } - idx_t group_end = parent_end * tree.TREE_FANOUT; - if (end != group_end && compute_right) { - WindowSegmentValue(tree, 0, group_end, end, state_ptr); - } - } - FlushStates(false); -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_shared_expressions.cpp b/src/duckdb/src/function/window/window_shared_expressions.cpp deleted file mode 100644 index 811d3ec57..000000000 --- a/src/duckdb/src/function/window/window_shared_expressions.cpp +++ /dev/null @@ -1,50 +0,0 @@ -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/execution/expression_executor.hpp" - -namespace duckdb { - -column_t WindowSharedExpressions::RegisterExpr(const unique_ptr &expr, Shared &shared) { - auto pexpr = expr.get(); - if (!pexpr) { - return DConstants::INVALID_INDEX; - } - - // We need to make separate columns for volatile arguments - const auto is_volatile = expr->IsVolatile(); - auto i = shared.columns.find(*pexpr); - if (i != shared.columns.end() && !is_volatile) { - return i->second.front(); - } - - // New column, find maximum column number - column_t result = shared.size++; - shared.columns[*pexpr].emplace_back(result); - - return result; -} - -vector> WindowSharedExpressions::GetSortedExpressions(Shared &shared) { - vector> sorted(shared.size); - for (auto &col : shared.columns) { - auto &expr = col.first.get(); - for (auto col_idx : col.second) { - sorted[col_idx] = &expr; - } - } - - return sorted; -} -void WindowSharedExpressions::PrepareExecutors(Shared &shared, ExpressionExecutor &exec, DataChunk &chunk) { - const auto sorted = GetSortedExpressions(shared); - vector types; - for (auto expr : sorted) { - exec.AddExpression(*expr); - types.emplace_back(expr->return_type); - } - - if (!types.empty()) { - chunk.Initialize(exec.GetAllocator(), types); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_token_tree.cpp b/src/duckdb/src/function/window/window_token_tree.cpp deleted file mode 100644 index 82b5124e4..000000000 --- a/src/duckdb/src/function/window/window_token_tree.cpp +++ /dev/null @@ -1,142 +0,0 @@ -#include "duckdb/function/window/window_token_tree.hpp" - -namespace duckdb { - -class WindowTokenTreeLocalState : public WindowMergeSortTreeLocalState { -public: - explicit WindowTokenTreeLocalState(WindowTokenTree &token_tree) - : WindowMergeSortTreeLocalState(token_tree), token_tree(token_tree) { - } - //! 
Process sorted leaf data - void BuildLeaves() override; - - WindowTokenTree &token_tree; -}; - -void WindowTokenTreeLocalState::BuildLeaves() { - auto &global_sort = *token_tree.global_sort; - if (global_sort.sorted_blocks.empty()) { - return; - } - - // Scan the sort keys and note deltas - SBIterator curr(global_sort, ExpressionType::COMPARE_LESSTHAN); - SBIterator prev(global_sort, ExpressionType::COMPARE_LESSTHAN); - const auto &sort_layout = global_sort.sort_layout; - - const auto block_begin = token_tree.block_starts.at(build_task); - const auto block_end = token_tree.block_starts.at(build_task + 1); - auto &deltas = token_tree.deltas; - if (!block_begin) { - // First block, so set up initial delta - deltas[0] = 0; - } else { - // Move to the end of the previous block - // so we can record the comparison result for the first row - curr.SetIndex(block_begin - 1); - prev.SetIndex(block_begin - 1); - } - - for (++curr; curr.GetIndex() < block_end; ++curr, ++prev) { - int lt = 0; - if (sort_layout.all_constant) { - lt = FastMemcmp(prev.entry_ptr, curr.entry_ptr, sort_layout.comparison_size); - } else { - lt = Comparators::CompareTuple(prev.scan, curr.scan, prev.entry_ptr, curr.entry_ptr, sort_layout, - prev.external); - } - - deltas[curr.GetIndex()] = (lt != 0); - } -} - -idx_t WindowTokenTree::MeasurePayloadBlocks() { - const auto count = WindowMergeSortTree::MeasurePayloadBlocks(); - - deltas.resize(count); - - return count; -} - -template -static void BuildTokens(WindowTokenTree &token_tree, vector &tokens) { - PayloadScanner scanner(*token_tree.global_sort); - DataChunk payload_chunk; - payload_chunk.Initialize(token_tree.context, token_tree.global_sort->payload_layout.GetTypes()); - const T *row_idx = nullptr; - idx_t i = 0; - - T token = 0; - for (auto &d : token_tree.deltas) { - if (i >= payload_chunk.size()) { - payload_chunk.Reset(); - scanner.Scan(payload_chunk); - if (!payload_chunk.size()) { - break; - } - row_idx = FlatVector::GetData(payload_chunk.data[0]); - i = 0; - } - - token += d; - tokens[row_idx[i++]] = token; - } -} - -unique_ptr WindowTokenTree::GetLocalState() { - return make_uniq(*this); -} - -void WindowTokenTree::CleanupSort() { - // Convert the deltas to tokens - if (mst64) { - BuildTokens(*this, mst64->LowestLevel()); - } else { - BuildTokens(*this, mst32->LowestLevel()); - } - - // Deallocate memory - vector empty; - deltas.swap(empty); - - WindowMergeSortTree::CleanupSort(); -} - -template -static idx_t TokenRank(const TREE &tree, const idx_t lower, const idx_t upper, const idx_t row_idx) { - idx_t rank = 1; - const auto needle = tree.LowestLevel()[row_idx]; - tree.AggregateLowerBound(lower, upper, needle, [&](idx_t level, const idx_t run_begin, const idx_t run_pos) { - rank += run_pos - run_begin; - }); - return rank; -} - -idx_t WindowTokenTree::Rank(const idx_t lower, const idx_t upper, const idx_t row_idx) const { - if (mst64) { - return TokenRank(*mst64, lower, upper, row_idx); - } else { - return TokenRank(*mst32, lower, upper, row_idx); - } -} - -template -static idx_t NextPeer(const TREE &tree, const idx_t lower, const idx_t upper, const idx_t row_idx) { - idx_t rank = 0; - // Because tokens are dense, we can find the next peer by adding 1 to the probed token value - const auto needle = tree.LowestLevel()[row_idx] + 1; - tree.AggregateLowerBound(lower, upper, needle, [&](idx_t level, const idx_t run_begin, const idx_t run_pos) { - rank += run_pos - run_begin; - }); - return rank; -} - -idx_t WindowTokenTree::PeerEnd(const idx_t lower, const idx_t 
upper, const idx_t row_idx) const { - if (mst64) { - return NextPeer(*mst64, lower, upper, row_idx); - } else { - return NextPeer(*mst32, lower, upper, row_idx); - } -} - -} // namespace duckdb diff --git a/src/duckdb/src/function/window/window_value_function.cpp b/src/duckdb/src/function/window/window_value_function.cpp deleted file mode 100644 index 6b8a7038e..000000000 --- a/src/duckdb/src/function/window/window_value_function.cpp +++ /dev/null @@ -1,566 +0,0 @@ -#include "duckdb/common/operator/add.hpp" -#include "duckdb/common/operator/subtract.hpp" -#include "duckdb/function/window/window_aggregator.hpp" -#include "duckdb/function/window/window_collection.hpp" -#include "duckdb/function/window/window_index_tree.hpp" -#include "duckdb/function/window/window_shared_expressions.hpp" -#include "duckdb/function/window/window_token_tree.hpp" -#include "duckdb/function/window/window_value_function.hpp" -#include "duckdb/planner/expression/bound_window_expression.hpp" - -namespace duckdb { - -//===--------------------------------------------------------------------===// -// WindowValueGlobalState -//===--------------------------------------------------------------------===// - -class WindowValueGlobalState : public WindowExecutorGlobalState { -public: - using WindowCollectionPtr = unique_ptr; - WindowValueGlobalState(const WindowValueExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowExecutorGlobalState(executor, payload_count, partition_mask, order_mask), ignore_nulls(&all_valid), - child_idx(executor.child_idx) { - - if (!executor.arg_order_idx.empty()) { - value_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, - payload_count); - } - } - - void Finalize(CollectionPtr collection) { - lock_guard ignore_nulls_guard(lock); - if (child_idx != DConstants::INVALID_INDEX && executor.wexpr.ignore_nulls) { - ignore_nulls = &collection->validities[child_idx]; - } - } - - // IGNORE NULLS - mutex lock; - ValidityMask all_valid; - optional_ptr ignore_nulls; - - //! Copy of the executor child_idx - const column_t child_idx; - - //! Merge sort tree to map unfiltered row number to value - unique_ptr value_tree; -}; - -//===--------------------------------------------------------------------===// -// WindowValueLocalState -//===--------------------------------------------------------------------===// - -//! A class representing the state of the first_value, last_value and nth_value functions -class WindowValueLocalState : public WindowExecutorBoundsState { -public: - explicit WindowValueLocalState(const WindowValueGlobalState &gvstate) - : WindowExecutorBoundsState(gvstate), gvstate(gvstate) { - WindowAggregatorLocalState::InitSubFrames(frames, gvstate.executor.wexpr.exclude_clause); - - if (gvstate.value_tree) { - local_value = gvstate.value_tree->GetLocalState(); - if (gvstate.executor.wexpr.ignore_nulls) { - sort_nulls.Initialize(); - } - } - } - - //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; - //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; - - //! The corresponding global value state - const WindowValueGlobalState &gvstate; - //! The optional sorting state for secondary sorts - unique_ptr local_value; - //! Reusable selection vector for NULLs - SelectionVector sort_nulls; - //! 
The frame boundaries, used for EXCLUDE - SubFrames frames; - - //! The state used for reading the collection - unique_ptr cursor; -}; - -void WindowValueLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowExecutorBoundsState::Sink(gstate, sink_chunk, coll_chunk, input_idx); - - if (local_value) { - idx_t filtered = 0; - optional_ptr filter_sel; - - // If we need to IGNORE NULLS for the child, and there are NULLs, - // then build an SV to hold them - const auto coll_count = coll_chunk.size(); - auto &child = coll_chunk.data[gvstate.child_idx]; - UnifiedVectorFormat child_data; - child.ToUnifiedFormat(coll_count, child_data); - const auto &validity = child_data.validity; - if (gstate.executor.wexpr.ignore_nulls && !validity.AllValid()) { - for (sel_t i = 0; i < coll_count; ++i) { - if (validity.RowIsValidUnsafe(i)) { - sort_nulls[filtered++] = i; - } - } - filter_sel = &sort_nulls; - } - - auto &value_state = local_value->Cast(); - value_state.SinkChunk(sink_chunk, input_idx, filter_sel, filtered); - } -} - -void WindowValueLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowExecutorBoundsState::Finalize(gstate, collection); - - if (local_value) { - auto &value_state = local_value->Cast(); - value_state.Sort(); - value_state.index_tree.Build(); - } - - // Prepare to scan - if (!cursor && gvstate.child_idx != DConstants::INVALID_INDEX) { - cursor = make_uniq(*collection, gvstate.child_idx); - } -} - -//===--------------------------------------------------------------------===// -// WindowValueExecutor -//===--------------------------------------------------------------------===// -WindowValueExecutor::WindowValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowExecutor(wexpr, context, shared) { - - for (const auto &order : wexpr.arg_orders) { - arg_order_idx.emplace_back(shared.RegisterSink(order.expression)); - } - - // The children have to be handled separately because only the first one is global - if (!wexpr.children.empty()) { - child_idx = shared.RegisterCollection(wexpr.children[0], wexpr.ignore_nulls); - - if (wexpr.children.size() > 1) { - nth_idx = shared.RegisterEvaluate(wexpr.children[1]); - } - } - - offset_idx = shared.RegisterEvaluate(wexpr.offset_expr); - default_idx = shared.RegisterEvaluate(wexpr.default_expr); -} - -unique_ptr WindowValueExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -void WindowValueExecutor::Finalize(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - CollectionPtr collection) const { - auto &gvstate = gstate.Cast(); - gvstate.Finalize(collection); - - WindowExecutor::Finalize(gstate, lstate, collection); -} - -unique_ptr WindowValueExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - const auto &gvstate = gstate.Cast(); - return make_uniq(gvstate); -} - -//===--------------------------------------------------------------------===// -// WindowLeadLagGlobalState -//===--------------------------------------------------------------------===// -// The functions LEAD and LAG can be extended to a windowed version with -// two independent ORDER BY clauses just like first_value and other value -// functions. 
-// To evaluate a windowed LEAD/LAG, one has to (1) compute the ROW_NUMBER -// of the own row, (2) adjust the row number by adding or subtracting an -// offset, (3) find the row at that offset, and (4) evaluate the expression -// provided to LEAD/LAG on this row. One can use the algorithm from Section -// 4.4 to determine the row number of the own row (step 1) and the -// algorithm from Section 4.5 to find the row with the adjusted position -// (step 3). Both algorithms are in O(𝑛 log𝑛), so the overall algorithm -// for LEAD/LAG is also O(𝑛 log𝑛). -// -// 4.4: unique WindowTokenTree -// 4.5: WindowIndexTree - -class WindowLeadLagGlobalState : public WindowValueGlobalState { -public: - explicit WindowLeadLagGlobalState(const WindowValueExecutor &executor, const idx_t payload_count, - const ValidityMask &partition_mask, const ValidityMask &order_mask) - : WindowValueGlobalState(executor, payload_count, partition_mask, order_mask) { - - if (value_tree) { - // "The ROW_NUMBER function can be computed by disambiguating duplicate elements based on their position in - // the input data, such that two elements never compare as equal." - // Note: If the user specifies a partial secondary sort, the disambiguation will use the - // partition's row numbers, not the secondary sort's row numbers. - row_tree = make_uniq(executor.context, executor.wexpr.arg_orders, executor.arg_order_idx, - payload_count, true); - } - } - - //! Merge sort tree to map partition offset to row number (algorithm from Section 4.5) - unique_ptr row_tree; -}; - -//===--------------------------------------------------------------------===// -// WindowLeadLagLocalState -//===--------------------------------------------------------------------===// -class WindowLeadLagLocalState : public WindowValueLocalState { -public: - explicit WindowLeadLagLocalState(const WindowLeadLagGlobalState &gstate) : WindowValueLocalState(gstate) { - if (gstate.row_tree) { - local_row = gstate.row_tree->GetLocalState(); - } - } - - //! Accumulate the secondary sort values - void Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) override; - //! Finish the sinking and prepare to scan - void Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) override; - - //! 
The optional sorting state for the secondary sort row mapping - unique_ptr local_row; -}; - -void WindowLeadLagLocalState::Sink(WindowExecutorGlobalState &gstate, DataChunk &sink_chunk, DataChunk &coll_chunk, - idx_t input_idx) { - WindowValueLocalState::Sink(gstate, sink_chunk, coll_chunk, input_idx); - - if (local_row) { - idx_t filtered = 0; - optional_ptr filter_sel; - - auto &row_state = local_row->Cast(); - row_state.SinkChunk(sink_chunk, input_idx, filter_sel, filtered); - } -} - -void WindowLeadLagLocalState::Finalize(WindowExecutorGlobalState &gstate, CollectionPtr collection) { - WindowValueLocalState::Finalize(gstate, collection); - - if (local_row) { - auto &row_state = local_row->Cast(); - row_state.Sort(); - row_state.window_tree.Build(); - } -} - -//===--------------------------------------------------------------------===// -// WindowLeadLagExecutor -//===--------------------------------------------------------------------===// -WindowLeadLagExecutor::WindowLeadLagExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowValueExecutor(wexpr, context, shared) { -} - -unique_ptr WindowLeadLagExecutor::GetGlobalState(const idx_t payload_count, - const ValidityMask &partition_mask, - const ValidityMask &order_mask) const { - return make_uniq(*this, payload_count, partition_mask, order_mask); -} - -unique_ptr -WindowLeadLagExecutor::GetLocalState(const WindowExecutorGlobalState &gstate) const { - const auto &glstate = gstate.Cast(); - return make_uniq(glstate); -} - -void WindowLeadLagExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &glstate = gstate.Cast(); - auto &llstate = lstate.Cast(); - auto &cursor = *llstate.cursor; - - WindowInputExpression leadlag_offset(eval_chunk, offset_idx); - WindowInputExpression leadlag_default(eval_chunk, default_idx); - - if (glstate.row_tree) { - auto frame_begin = FlatVector::GetData(llstate.bounds.data[FRAME_BEGIN]); - auto frame_end = FlatVector::GetData(llstate.bounds.data[FRAME_END]); - // TODO: Handle subframes. - auto &frames = llstate.frames; - frames.resize(1); - auto &frame = frames[0]; - for (idx_t i = 0; i < count; ++i, ++row_idx) { - // (1) compute the ROW_NUMBER of the own row - frame = FrameBounds(frame_begin[i], frame_end[i]); - const auto own_row = glstate.row_tree->Rank(frame.start, frame.end, row_idx) - 1; - // (2) adjust the row number by adding or subtracting an offset - auto val_idx = NumericCast(own_row); - int64_t offset = 1; - if (wexpr.offset_expr) { - offset = leadlag_offset.GetCell(i); - } - if (wexpr.GetExpressionType() == ExpressionType::WINDOW_LEAD) { - val_idx = AddOperatorOverflowCheck::Operation(val_idx, offset); - } else { - val_idx = SubtractOperatorOverflowCheck::Operation(val_idx, offset); - } - const auto frame_width = NumericCast(frame.end - frame.start); - if (val_idx >= 0 && val_idx < frame_width) { - // (3) find the row at that offset - const auto n = NumericCast(val_idx); - const auto nth_index = glstate.value_tree->SelectNth(frames, n); - // (4) evaluate the expression provided to LEAD/LAG on this row. 
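// A worked illustration of steps (1)-(4) with made-up numbers: in a frame of 8 rows where the
// secondary sort ranks the current row 3rd, own_row = Rank(...) - 1 = 2. LEAD(expr, 2) then
// targets val_idx = 4, which lies inside [0, frame_width), so SelectNth(frames, 4) recovers
// the physical row index and the cursor copies that cell below. Both tree probes cost
// O(log n) per row, matching the O(n log n) bound cited above.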
- cursor.CopyCell(0, nth_index, result, i); - } else if (wexpr.default_expr) { - leadlag_default.CopyCell(result, i); - } else { - FlatVector::SetNull(result, i, true); - } - } - return; - } - - auto partition_begin = FlatVector::GetData(llstate.bounds.data[PARTITION_BEGIN]); - auto partition_end = FlatVector::GetData(llstate.bounds.data[PARTITION_END]); - - auto &ignore_nulls = glstate.ignore_nulls; - bool can_shift = ignore_nulls->AllValid(); - if (wexpr.offset_expr) { - can_shift = can_shift && wexpr.offset_expr->IsFoldable(); - } - if (wexpr.default_expr) { - can_shift = can_shift && wexpr.default_expr->IsFoldable(); - } - - const auto row_end = row_idx + count; - for (idx_t i = 0; i < count;) { - int64_t offset = 1; - if (wexpr.offset_expr) { - offset = leadlag_offset.GetCell(i); - } - int64_t val_idx = (int64_t)row_idx; - if (wexpr.GetExpressionType() == ExpressionType::WINDOW_LEAD) { - val_idx = AddOperatorOverflowCheck::Operation(val_idx, offset); - } else { - val_idx = SubtractOperatorOverflowCheck::Operation(val_idx, offset); - } - - idx_t delta = 0; - if (val_idx < (int64_t)row_idx) { - // Count backwards - delta = idx_t(row_idx - idx_t(val_idx)); - val_idx = int64_t(WindowBoundariesState::FindPrevStart(*ignore_nulls, partition_begin[i], row_idx, delta)); - } else if (val_idx > (int64_t)row_idx) { - delta = idx_t(idx_t(val_idx) - row_idx); - val_idx = - int64_t(WindowBoundariesState::FindNextStart(*ignore_nulls, row_idx + 1, partition_end[i], delta)); - } - // else offset is zero, so don't move. - - if (can_shift) { - const auto target_limit = MinValue(partition_end[i], row_end) - row_idx; - if (!delta) { - // Copy source[index:index+width] => result[i:] - auto index = NumericCast(val_idx); - const auto source_limit = partition_end[i] - index; - auto width = MinValue(source_limit, target_limit); - // We may have to scan multiple blocks here, so loop until we have copied everything - const idx_t col_idx = 0; - while (width) { - const auto source_offset = cursor.Seek(index); - auto &source = cursor.chunk.data[col_idx]; - const auto copied = MinValue(cursor.chunk.size() - source_offset, width); - VectorOperations::Copy(source, result, source_offset + copied, source_offset, i); - i += copied; - row_idx += copied; - index += copied; - width -= copied; - } - } else if (wexpr.default_expr) { - const auto width = MinValue(delta, target_limit); - leadlag_default.CopyCell(result, i, width); - i += width; - row_idx += width; - } else { - for (idx_t nulls = MinValue(delta, target_limit); nulls--; ++i, ++row_idx) { - FlatVector::SetNull(result, i, true); - } - } - } else { - if (!delta) { - cursor.CopyCell(0, NumericCast(val_idx), result, i); - } else if (wexpr.default_expr) { - leadlag_default.CopyCell(result, i); - } else { - FlatVector::SetNull(result, i, true); - } - ++i; - ++row_idx; - } - } -} - -WindowFirstValueExecutor::WindowFirstValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowValueExecutor(wexpr, context, shared) { -} - -void WindowFirstValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); - auto &cursor = *lvstate.cursor; - auto &bounds = lvstate.bounds; - auto &frames = lvstate.frames; - auto &ignore_nulls = *gvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; - 
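// The same shape repeats for first_value, last_value and nth_value below: EvaluateSubFrames
// splits each row's frame into subframes around the EXCLUDE clause, and the callback either
// asks the merge sort tree for the n-th row in secondary-sort order (SelectNth) or scans the
// subframes directly, honouring IGNORE NULLS via FindNextStart/FindPrevStart.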
WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { - if (gvstate.value_tree) { - idx_t frame_width = 0; - for (const auto &frame : frames) { - frame_width += frame.end - frame.start; - } - - if (frame_width) { - const auto first_idx = gvstate.value_tree->SelectNth(frames, 0); - cursor.CopyCell(0, first_idx, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - return; - } - - for (const auto &frame : frames) { - if (frame.start >= frame.end) { - continue; - } - - // Same as NTH_VALUE(..., 1) - idx_t n = 1; - const auto first_idx = WindowBoundariesState::FindNextStart(ignore_nulls, frame.start, frame.end, n); - if (!n) { - cursor.CopyCell(0, first_idx, result, i); - return; - } - } - - // Didn't find one - FlatVector::SetNull(result, i, true); - }); -} - -WindowLastValueExecutor::WindowLastValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowValueExecutor(wexpr, context, shared) { -} - -void WindowLastValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, - idx_t row_idx) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); - auto &cursor = *lvstate.cursor; - auto &bounds = lvstate.bounds; - auto &frames = lvstate.frames; - auto &ignore_nulls = *gvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; - WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { - if (gvstate.value_tree) { - idx_t frame_width = 0; - for (const auto &frame : frames) { - frame_width += frame.end - frame.start; - } - - if (frame_width) { - const auto last_idx = gvstate.value_tree->SelectNth(frames, frame_width - 1); - cursor.CopyCell(0, last_idx, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - return; - } - - for (idx_t f = frames.size(); f-- > 0;) { - const auto &frame = frames[f]; - if (frame.start >= frame.end) { - continue; - } - - idx_t n = 1; - const auto last_idx = WindowBoundariesState::FindPrevStart(ignore_nulls, frame.start, frame.end, n); - if (!n) { - cursor.CopyCell(0, last_idx, result, i); - return; - } - } - - // Didn't find one - FlatVector::SetNull(result, i, true); - }); -} - -WindowNthValueExecutor::WindowNthValueExecutor(BoundWindowExpression &wexpr, ClientContext &context, - WindowSharedExpressions &shared) - : WindowValueExecutor(wexpr, context, shared) { -} - -void WindowNthValueExecutor::EvaluateInternal(WindowExecutorGlobalState &gstate, WindowExecutorLocalState &lstate, - DataChunk &eval_chunk, Vector &result, idx_t count, idx_t row_idx) const { - auto &gvstate = gstate.Cast(); - auto &lvstate = lstate.Cast(); - auto &cursor = *lvstate.cursor; - auto &bounds = lvstate.bounds; - auto &frames = lvstate.frames; - auto &ignore_nulls = *gvstate.ignore_nulls; - auto exclude_mode = gvstate.executor.wexpr.exclude_clause; - D_ASSERT(cursor.chunk.ColumnCount() == 1); - WindowInputExpression nth_col(eval_chunk, nth_idx); - WindowAggregator::EvaluateSubFrames(bounds, exclude_mode, count, row_idx, frames, [&](idx_t i) { - // Returns value evaluated at the row that is the n'th row of the window frame (counting from 1); - // returns NULL if there is no such row. 
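// The fallback scan below leans on FindNextStart's contract: n is decremented for each
// qualifying row it passes over, so n == 0 after the call means the n-th row was found at the
// returned index, while a non-zero n carries the remaining count into the next subframe.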
- if (nth_col.CellIsNull(i)) { - FlatVector::SetNull(result, i, true); - return; - } - auto n_param = nth_col.GetCell(i); - if (n_param < 1) { - FlatVector::SetNull(result, i, true); - return; - } - - // Decrement as we go along. - auto n = idx_t(n_param); - - if (gvstate.value_tree) { - idx_t frame_width = 0; - for (const auto &frame : frames) { - frame_width += frame.end - frame.start; - } - - if (n < frame_width) { - const auto nth_index = gvstate.value_tree->SelectNth(frames, n - 1); - cursor.CopyCell(0, nth_index, result, i); - } else { - FlatVector::SetNull(result, i, true); - } - return; - } - - for (const auto &frame : frames) { - if (frame.start >= frame.end) { - continue; - } - - const auto nth_index = WindowBoundariesState::FindNextStart(ignore_nulls, frame.start, frame.end, n); - if (!n) { - cursor.CopyCell(0, nth_index, result, i); - return; - } - } - FlatVector::SetNull(result, i, true); - }); -} - -} // namespace duckdb diff --git a/src/duckdb/src/include/duckdb.h b/src/duckdb/src/include/duckdb.h deleted file mode 100644 index ae52abd33..000000000 --- a/src/duckdb/src/include/duckdb.h +++ /dev/null @@ -1,4545 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// DuckDB -// -// duckdb.h -// -// -//===----------------------------------------------------------------------===// -// -// !!!!!!! -// WARNING: this file is autogenerated by scripts/generate_c_api.py, manual changes will be overwritten -// !!!!!!! - -#pragma once - -//! duplicate of duckdb/main/winapi.hpp -#ifndef DUCKDB_API -#ifdef _WIN32 -#ifdef DUCKDB_STATIC_BUILD -#define DUCKDB_API -#else -#if defined(DUCKDB_BUILD_LIBRARY) && !defined(DUCKDB_BUILD_LOADABLE_EXTENSION) -#define DUCKDB_API __declspec(dllexport) -#else -#define DUCKDB_API __declspec(dllimport) -#endif -#endif -#else -#define DUCKDB_API -#endif -#endif - -//! duplicate of duckdb/main/winapi.hpp -#ifndef DUCKDB_EXTENSION_API -#ifdef _WIN32 -#ifdef DUCKDB_STATIC_BUILD -#define DUCKDB_EXTENSION_API -#else -#ifdef DUCKDB_BUILD_LOADABLE_EXTENSION -#define DUCKDB_EXTENSION_API __declspec(dllexport) -#else -#define DUCKDB_EXTENSION_API -#endif -#endif -#else -#define DUCKDB_EXTENSION_API __attribute__((visibility("default"))) -#endif -#endif - -#include -#include -#include - -#ifdef __cplusplus -extern "C" { -#endif - -//===--------------------------------------------------------------------===// -// Enums -//===--------------------------------------------------------------------===// -// WARNING: the numbers of these enums should not be changed, as changing the numbers breaks ABI compatibility -// Always add enums at the END of the enum -//! An enum over DuckDB's internal types. 
-typedef enum DUCKDB_TYPE { - DUCKDB_TYPE_INVALID = 0, - // bool - DUCKDB_TYPE_BOOLEAN = 1, - // int8_t - DUCKDB_TYPE_TINYINT = 2, - // int16_t - DUCKDB_TYPE_SMALLINT = 3, - // int32_t - DUCKDB_TYPE_INTEGER = 4, - // int64_t - DUCKDB_TYPE_BIGINT = 5, - // uint8_t - DUCKDB_TYPE_UTINYINT = 6, - // uint16_t - DUCKDB_TYPE_USMALLINT = 7, - // uint32_t - DUCKDB_TYPE_UINTEGER = 8, - // uint64_t - DUCKDB_TYPE_UBIGINT = 9, - // float - DUCKDB_TYPE_FLOAT = 10, - // double - DUCKDB_TYPE_DOUBLE = 11, - // duckdb_timestamp (microseconds) - DUCKDB_TYPE_TIMESTAMP = 12, - // duckdb_date - DUCKDB_TYPE_DATE = 13, - // duckdb_time - DUCKDB_TYPE_TIME = 14, - // duckdb_interval - DUCKDB_TYPE_INTERVAL = 15, - // duckdb_hugeint - DUCKDB_TYPE_HUGEINT = 16, - // duckdb_uhugeint - DUCKDB_TYPE_UHUGEINT = 32, - // const char* - DUCKDB_TYPE_VARCHAR = 17, - // duckdb_blob - DUCKDB_TYPE_BLOB = 18, - // duckdb_decimal - DUCKDB_TYPE_DECIMAL = 19, - // duckdb_timestamp_s (seconds) - DUCKDB_TYPE_TIMESTAMP_S = 20, - // duckdb_timestamp_ms (milliseconds) - DUCKDB_TYPE_TIMESTAMP_MS = 21, - // duckdb_timestamp_ns (nanoseconds) - DUCKDB_TYPE_TIMESTAMP_NS = 22, - // enum type, only useful as logical type - DUCKDB_TYPE_ENUM = 23, - // list type, only useful as logical type - DUCKDB_TYPE_LIST = 24, - // struct type, only useful as logical type - DUCKDB_TYPE_STRUCT = 25, - // map type, only useful as logical type - DUCKDB_TYPE_MAP = 26, - // duckdb_array, only useful as logical type - DUCKDB_TYPE_ARRAY = 33, - // duckdb_hugeint - DUCKDB_TYPE_UUID = 27, - // union type, only useful as logical type - DUCKDB_TYPE_UNION = 28, - // duckdb_bit - DUCKDB_TYPE_BIT = 29, - // duckdb_time_tz - DUCKDB_TYPE_TIME_TZ = 30, - // duckdb_timestamp (microseconds) - DUCKDB_TYPE_TIMESTAMP_TZ = 31, - // ANY type - DUCKDB_TYPE_ANY = 34, - // duckdb_varint - DUCKDB_TYPE_VARINT = 35, - // SQLNULL type - DUCKDB_TYPE_SQLNULL = 36, -} duckdb_type; -//! An enum over the returned state of different functions. -typedef enum duckdb_state { DuckDBSuccess = 0, DuckDBError = 1 } duckdb_state; -//! An enum over the pending state of a pending query result. -typedef enum duckdb_pending_state { - DUCKDB_PENDING_RESULT_READY = 0, - DUCKDB_PENDING_RESULT_NOT_READY = 1, - DUCKDB_PENDING_ERROR = 2, - DUCKDB_PENDING_NO_TASKS_AVAILABLE = 3 -} duckdb_pending_state; -//! An enum over DuckDB's different result types. -typedef enum duckdb_result_type { - DUCKDB_RESULT_TYPE_INVALID = 0, - DUCKDB_RESULT_TYPE_CHANGED_ROWS = 1, - DUCKDB_RESULT_TYPE_NOTHING = 2, - DUCKDB_RESULT_TYPE_QUERY_RESULT = 3, -} duckdb_result_type; -//! An enum over DuckDB's different statement types. 
-typedef enum duckdb_statement_type { - DUCKDB_STATEMENT_TYPE_INVALID = 0, - DUCKDB_STATEMENT_TYPE_SELECT = 1, - DUCKDB_STATEMENT_TYPE_INSERT = 2, - DUCKDB_STATEMENT_TYPE_UPDATE = 3, - DUCKDB_STATEMENT_TYPE_EXPLAIN = 4, - DUCKDB_STATEMENT_TYPE_DELETE = 5, - DUCKDB_STATEMENT_TYPE_PREPARE = 6, - DUCKDB_STATEMENT_TYPE_CREATE = 7, - DUCKDB_STATEMENT_TYPE_EXECUTE = 8, - DUCKDB_STATEMENT_TYPE_ALTER = 9, - DUCKDB_STATEMENT_TYPE_TRANSACTION = 10, - DUCKDB_STATEMENT_TYPE_COPY = 11, - DUCKDB_STATEMENT_TYPE_ANALYZE = 12, - DUCKDB_STATEMENT_TYPE_VARIABLE_SET = 13, - DUCKDB_STATEMENT_TYPE_CREATE_FUNC = 14, - DUCKDB_STATEMENT_TYPE_DROP = 15, - DUCKDB_STATEMENT_TYPE_EXPORT = 16, - DUCKDB_STATEMENT_TYPE_PRAGMA = 17, - DUCKDB_STATEMENT_TYPE_VACUUM = 18, - DUCKDB_STATEMENT_TYPE_CALL = 19, - DUCKDB_STATEMENT_TYPE_SET = 20, - DUCKDB_STATEMENT_TYPE_LOAD = 21, - DUCKDB_STATEMENT_TYPE_RELATION = 22, - DUCKDB_STATEMENT_TYPE_EXTENSION = 23, - DUCKDB_STATEMENT_TYPE_LOGICAL_PLAN = 24, - DUCKDB_STATEMENT_TYPE_ATTACH = 25, - DUCKDB_STATEMENT_TYPE_DETACH = 26, - DUCKDB_STATEMENT_TYPE_MULTI = 27, -} duckdb_statement_type; -//! An enum over DuckDB's different error types. -typedef enum duckdb_error_type { - DUCKDB_ERROR_INVALID = 0, - DUCKDB_ERROR_OUT_OF_RANGE = 1, - DUCKDB_ERROR_CONVERSION = 2, - DUCKDB_ERROR_UNKNOWN_TYPE = 3, - DUCKDB_ERROR_DECIMAL = 4, - DUCKDB_ERROR_MISMATCH_TYPE = 5, - DUCKDB_ERROR_DIVIDE_BY_ZERO = 6, - DUCKDB_ERROR_OBJECT_SIZE = 7, - DUCKDB_ERROR_INVALID_TYPE = 8, - DUCKDB_ERROR_SERIALIZATION = 9, - DUCKDB_ERROR_TRANSACTION = 10, - DUCKDB_ERROR_NOT_IMPLEMENTED = 11, - DUCKDB_ERROR_EXPRESSION = 12, - DUCKDB_ERROR_CATALOG = 13, - DUCKDB_ERROR_PARSER = 14, - DUCKDB_ERROR_PLANNER = 15, - DUCKDB_ERROR_SCHEDULER = 16, - DUCKDB_ERROR_EXECUTOR = 17, - DUCKDB_ERROR_CONSTRAINT = 18, - DUCKDB_ERROR_INDEX = 19, - DUCKDB_ERROR_STAT = 20, - DUCKDB_ERROR_CONNECTION = 21, - DUCKDB_ERROR_SYNTAX = 22, - DUCKDB_ERROR_SETTINGS = 23, - DUCKDB_ERROR_BINDER = 24, - DUCKDB_ERROR_NETWORK = 25, - DUCKDB_ERROR_OPTIMIZER = 26, - DUCKDB_ERROR_NULL_POINTER = 27, - DUCKDB_ERROR_IO = 28, - DUCKDB_ERROR_INTERRUPT = 29, - DUCKDB_ERROR_FATAL = 30, - DUCKDB_ERROR_INTERNAL = 31, - DUCKDB_ERROR_INVALID_INPUT = 32, - DUCKDB_ERROR_OUT_OF_MEMORY = 33, - DUCKDB_ERROR_PERMISSION = 34, - DUCKDB_ERROR_PARAMETER_NOT_RESOLVED = 35, - DUCKDB_ERROR_PARAMETER_NOT_ALLOWED = 36, - DUCKDB_ERROR_DEPENDENCY = 37, - DUCKDB_ERROR_HTTP = 38, - DUCKDB_ERROR_MISSING_EXTENSION = 39, - DUCKDB_ERROR_AUTOLOAD = 40, - DUCKDB_ERROR_SEQUENCE = 41, - DUCKDB_INVALID_CONFIGURATION = 42 -} duckdb_error_type; -//! An enum over DuckDB's different cast modes. -typedef enum duckdb_cast_mode { DUCKDB_CAST_NORMAL = 0, DUCKDB_CAST_TRY = 1 } duckdb_cast_mode; - -//===--------------------------------------------------------------------===// -// General type definitions -//===--------------------------------------------------------------------===// - -//! DuckDB's index type. -typedef uint64_t idx_t; - -//! The callback that will be called to destroy data, e.g., -//! bind data (if any), init data (if any), extra data for replacement scans (if any) -typedef void (*duckdb_delete_callback_t)(void *data); - -//! Used for threading, contains a task state. Must be destroyed with `duckdb_destroy_state`. -typedef void *duckdb_task_state; - -//===--------------------------------------------------------------------===// -// Types (no explicit freeing) -//===--------------------------------------------------------------------===// - -//! Days are stored as days since 1970-01-01 -//! 
Use the duckdb_from_date/duckdb_to_date function to extract individual information -typedef struct { - int32_t days; -} duckdb_date; -typedef struct { - int32_t year; - int8_t month; - int8_t day; -} duckdb_date_struct; - -//! Time is stored as microseconds since 00:00:00 -//! Use the duckdb_from_time/duckdb_to_time function to extract individual information -typedef struct { - int64_t micros; -} duckdb_time; -typedef struct { - int8_t hour; - int8_t min; - int8_t sec; - int32_t micros; -} duckdb_time_struct; - -//! TIME_TZ is stored as 40 bits for int64_t micros, and 24 bits for int32_t offset -typedef struct { - uint64_t bits; -} duckdb_time_tz; -typedef struct { - duckdb_time_struct time; - int32_t offset; -} duckdb_time_tz_struct; - -//! TIMESTAMP values are stored as microseconds since 1970-01-01. -//! Use the duckdb_from_timestamp and duckdb_to_timestamp functions to extract individual information. -typedef struct { - int64_t micros; -} duckdb_timestamp; - -//! TIMESTAMP_S values are stored as seconds since 1970-01-01. -typedef struct { - int64_t seconds; -} duckdb_timestamp_s; - -//! TIMESTAMP_MS values are stored as milliseconds since 1970-01-01. -typedef struct { - int64_t millis; -} duckdb_timestamp_ms; - -//! TIMESTAMP_NS values are stored as nanoseconds since 1970-01-01. -typedef struct { - int64_t nanos; -} duckdb_timestamp_ns; - -typedef struct { - duckdb_date_struct date; - duckdb_time_struct time; -} duckdb_timestamp_struct; - -typedef struct { - int32_t months; - int32_t days; - int64_t micros; -} duckdb_interval; - -//! Hugeints are composed of a (lower, upper) component -//! The value of the hugeint is upper * 2^64 + lower -//! For easy usage, the functions duckdb_hugeint_to_double/duckdb_double_to_hugeint are recommended -typedef struct { - uint64_t lower; - int64_t upper; -} duckdb_hugeint; -typedef struct { - uint64_t lower; - uint64_t upper; -} duckdb_uhugeint; - -//! Decimals are composed of a width and a scale, and are stored in a hugeint -typedef struct { - uint8_t width; - uint8_t scale; - duckdb_hugeint value; -} duckdb_decimal; - -//! A type holding information about the query execution progress -typedef struct { - double percentage; - uint64_t rows_processed; - uint64_t total_rows_to_process; -} duckdb_query_progress_type; - -//! The internal representation of a VARCHAR (string_t). If the VARCHAR does not -//! exceed 12 characters, then we inline it. Otherwise, we inline a prefix for faster -//! string comparisons and store a pointer to the remaining characters. This is a non- -//! owning structure, i.e., it does not have to be freed. -typedef struct { - union { - struct { - uint32_t length; - char prefix[4]; - char *ptr; - } pointer; - struct { - uint32_t length; - char inlined[12]; - } inlined; - } value; -} duckdb_string_t; - -//! The internal representation of a list metadata entry contains the list's offset in -//! the child vector, and its length. The parent vector holds these metadata entries, -//! whereas the child vector holds the data -typedef struct { - uint64_t offset; - uint64_t length; -} duckdb_list_entry; - -//! A column consists of a pointer to its internal data. Don't operate on this type directly. -//! Instead, use functions such as duckdb_column_data, duckdb_nullmask_data, -//! duckdb_column_type, and duckdb_column_name, which take the result and the column index -//! 
-//! A column consists of a pointer to its internal data. Don't operate on this type directly.
-//! Instead, use functions such as duckdb_column_data, duckdb_nullmask_data,
-//! duckdb_column_type, and duckdb_column_name, which take the result and the column index
-//! as their parameters
-typedef struct {
-	// deprecated, use duckdb_column_data
-	void *deprecated_data;
-	// deprecated, use duckdb_nullmask_data
-	bool *deprecated_nullmask;
-	// deprecated, use duckdb_column_type
-	duckdb_type deprecated_type;
-	// deprecated, use duckdb_column_name
-	char *deprecated_name;
-	void *internal_data;
-} duckdb_column;
-
-//! A vector to a specified column in a data chunk. Lives as long as the
-//! data chunk lives, i.e., must not be destroyed.
-typedef struct _duckdb_vector {
-	void *internal_ptr;
-} * duckdb_vector;
-
-//===--------------------------------------------------------------------===//
-// Types (explicit freeing/destroying)
-//===--------------------------------------------------------------------===//
-
-//! Strings are composed of a char pointer and a size. You must free string.data
-//! with `duckdb_free`.
-typedef struct {
-	char *data;
-	idx_t size;
-} duckdb_string;
-
-//! BLOBs are composed of a byte pointer and a size. You must free blob.data
-//! with `duckdb_free`.
-typedef struct {
-	void *data;
-	idx_t size;
-} duckdb_blob;
-
-//! BITs are composed of a byte pointer and a size.
-//! BIT byte data has 0 to 7 bits of padding.
-//! The first byte contains the number of padding bits.
-//! That many bits of the second byte are set to 1, starting from the MSB.
-//! You must free `data` with `duckdb_free`.
-typedef struct {
-	uint8_t *data;
-	idx_t size;
-} duckdb_bit;
-
-//! VARINTs are composed of a byte pointer, a size, and an is_negative bool.
-//! The absolute value of the number is stored in `data` in little endian format.
-//! You must free `data` with `duckdb_free`.
-typedef struct {
-	uint8_t *data;
-	idx_t size;
-	bool is_negative;
-} duckdb_varint;
-
-//! A query result consists of a pointer to its internal data.
-//! Must be freed with 'duckdb_destroy_result'.
-typedef struct {
-	// deprecated, use duckdb_column_count
-	idx_t deprecated_column_count;
-	// deprecated, use duckdb_row_count
-	idx_t deprecated_row_count;
-	// deprecated, use duckdb_rows_changed
-	idx_t deprecated_rows_changed;
-	// deprecated, use duckdb_column_*-family of functions
-	duckdb_column *deprecated_columns;
-	// deprecated, use duckdb_result_error
-	char *deprecated_error_message;
-	void *internal_data;
-} duckdb_result;
-
-//! A database instance cache object. Must be destroyed with `duckdb_destroy_instance_cache`.
-typedef struct _duckdb_instance_cache {
-	void *internal_ptr;
-} * duckdb_instance_cache;
-
-//! A database object. Must be closed with `duckdb_close`.
-typedef struct _duckdb_database {
-	void *internal_ptr;
-} * duckdb_database;
-
-//! A connection to a duckdb database. Must be closed with `duckdb_disconnect`.
-typedef struct _duckdb_connection {
-	void *internal_ptr;
-} * duckdb_connection;
-
-//! A prepared statement is a parameterized query that allows you to bind parameters to it.
-//! Must be destroyed with `duckdb_destroy_prepare`.
-typedef struct _duckdb_prepared_statement {
-	void *internal_ptr;
-} * duckdb_prepared_statement;
-
-//! Extracted statements. Must be destroyed with `duckdb_destroy_extracted`.
-typedef struct _duckdb_extracted_statements {
-	void *internal_ptr;
-} * duckdb_extracted_statements;
-
-//! The pending result represents an intermediate structure for a query that is not yet fully executed.
-//! Must be destroyed with `duckdb_destroy_pending`.
-typedef struct _duckdb_pending_result {
-	void *internal_ptr;
-} * duckdb_pending_result;
-
-//! The appender enables fast data loading into DuckDB.
-//! Must be destroyed with `duckdb_appender_destroy`.
-typedef struct _duckdb_appender {
-	void *internal_ptr;
-} * duckdb_appender;
-
-//! The table description allows querying info about the table.
-//! Must be destroyed with `duckdb_table_description_destroy`.
-typedef struct _duckdb_table_description {
-	void *internal_ptr;
-} * duckdb_table_description;
-
-//! Can be used to provide start-up options for the DuckDB instance.
-//! Must be destroyed with `duckdb_destroy_config`.
-typedef struct _duckdb_config {
-	void *internal_ptr;
-} * duckdb_config;
-
-//! Holds an internal logical type.
-//! Must be destroyed with `duckdb_destroy_logical_type`.
-typedef struct _duckdb_logical_type {
-	void *internal_ptr;
-} * duckdb_logical_type;
-
-//! Holds extra information used when registering a custom logical type.
-//! Reserved for future use.
-typedef struct _duckdb_create_type_info {
-	void *internal_ptr;
-} * duckdb_create_type_info;
-
-//! Contains a data chunk from a duckdb_result.
-//! Must be destroyed with `duckdb_destroy_data_chunk`.
-typedef struct _duckdb_data_chunk {
-	void *internal_ptr;
-} * duckdb_data_chunk;
-
-//! Holds a DuckDB value, which wraps a type.
-//! Must be destroyed with `duckdb_destroy_value`.
-typedef struct _duckdb_value {
-	void *internal_ptr;
-} * duckdb_value;
-
-//! Holds a recursive tree that matches the query plan.
-typedef struct _duckdb_profiling_info {
-	void *internal_ptr;
-} * duckdb_profiling_info;
-
-//===--------------------------------------------------------------------===//
-// C API Extension info
-//===--------------------------------------------------------------------===//
-//! Holds state during the C API extension initialization process
-typedef struct _duckdb_extension_info {
-	void *internal_ptr;
-} * duckdb_extension_info;
-
-//===--------------------------------------------------------------------===//
-// Function types
-//===--------------------------------------------------------------------===//
-//! Additional function info. When setting this info, it is necessary to pass a destroy-callback function.
-typedef struct _duckdb_function_info {
-	void *internal_ptr;
-} * duckdb_function_info;
-
-//===--------------------------------------------------------------------===//
-// Scalar function types
-//===--------------------------------------------------------------------===//
-//! A scalar function. Must be destroyed with `duckdb_destroy_scalar_function`.
-typedef struct _duckdb_scalar_function {
-	void *internal_ptr;
-} * duckdb_scalar_function;
-
-//! A scalar function set. Must be destroyed with `duckdb_destroy_scalar_function_set`.
-typedef struct _duckdb_scalar_function_set {
-	void *internal_ptr;
-} * duckdb_scalar_function_set;
-
-//! The main function of the scalar function.
-typedef void (*duckdb_scalar_function_t)(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output);
-
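-/*
-To make the callback shape concrete, a minimal sketch of a `duckdb_scalar_function_t`
-implementation that adds two INTEGER input columns (`add_int32` is a hypothetical name;
-NULL/validity handling and the registration boilerplate are omitted):
-
-```c
-#include <duckdb.h>
-
-// duckdb_scalar_function_t: read both int32 input vectors and write their sum.
-static void add_int32(duckdb_function_info info, duckdb_data_chunk input, duckdb_vector output) {
-	(void)info; // unused in this sketch
-	idx_t count = duckdb_data_chunk_get_size(input);
-	int32_t *a = (int32_t *)duckdb_vector_get_data(duckdb_data_chunk_get_vector(input, 0));
-	int32_t *b = (int32_t *)duckdb_vector_get_data(duckdb_data_chunk_get_vector(input, 1));
-	int32_t *out = (int32_t *)duckdb_vector_get_data(output);
-	for (idx_t row = 0; row < count; row++) {
-		out[row] = a[row] + b[row]; // assumes no NULLs; real code must consult validity masks
-	}
-}
-```
-*/
-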
-//===--------------------------------------------------------------------===//
-// Aggregate function types
-//===--------------------------------------------------------------------===//
-//! An aggregate function. Must be destroyed with `duckdb_destroy_aggregate_function`.
-typedef struct _duckdb_aggregate_function {
-	void *internal_ptr;
-} * duckdb_aggregate_function;
-
-//! An aggregate function set. Must be destroyed with `duckdb_destroy_aggregate_function_set`.
-typedef struct _duckdb_aggregate_function_set {
-	void *internal_ptr;
-} * duckdb_aggregate_function_set;
-
-//! Aggregate state
-typedef struct _duckdb_aggregate_state {
-	void *internal_ptr;
-} * duckdb_aggregate_state;
-
-//! Returns the aggregate state size
-typedef idx_t (*duckdb_aggregate_state_size)(duckdb_function_info info);
-//! Initialize the aggregate state
-typedef void (*duckdb_aggregate_init_t)(duckdb_function_info info, duckdb_aggregate_state state);
-//! Destroy aggregate state (optional)
-typedef void (*duckdb_aggregate_destroy_t)(duckdb_aggregate_state *states, idx_t count);
-//! Update a set of aggregate states with new values
-typedef void (*duckdb_aggregate_update_t)(duckdb_function_info info, duckdb_data_chunk input,
-                                          duckdb_aggregate_state *states);
-//! Combine aggregate states
-typedef void (*duckdb_aggregate_combine_t)(duckdb_function_info info, duckdb_aggregate_state *source,
-                                           duckdb_aggregate_state *target, idx_t count);
-//! Finalize aggregate states into a result vector
-typedef void (*duckdb_aggregate_finalize_t)(duckdb_function_info info, duckdb_aggregate_state *source,
-                                            duckdb_vector result, idx_t count, idx_t offset);
-
-//===--------------------------------------------------------------------===//
-// Table function types
-//===--------------------------------------------------------------------===//
-
-//! A table function. Must be destroyed with `duckdb_destroy_table_function`.
-typedef struct _duckdb_table_function {
-	void *internal_ptr;
-} * duckdb_table_function;
-
-//! The bind info of the function. When setting this info, it is necessary to pass a destroy-callback function.
-typedef struct _duckdb_bind_info {
-	void *internal_ptr;
-} * duckdb_bind_info;
-
-//! Additional function init info. When setting this info, it is necessary to pass a destroy-callback function.
-typedef struct _duckdb_init_info {
-	void *internal_ptr;
-} * duckdb_init_info;
-
-//! The bind function of the table function.
-typedef void (*duckdb_table_function_bind_t)(duckdb_bind_info info);
-
-//! The (possibly thread-local) init function of the table function.
-typedef void (*duckdb_table_function_init_t)(duckdb_init_info info);
-
-//! The main function of the table function.
-typedef void (*duckdb_table_function_t)(duckdb_function_info info, duckdb_data_chunk output);
-
-//===--------------------------------------------------------------------===//
-// Cast types
-//===--------------------------------------------------------------------===//
-
-//! A cast function. Must be destroyed with `duckdb_destroy_cast_function`.
-typedef struct _duckdb_cast_function {
-	void *internal_ptr;
-} * duckdb_cast_function;
-
-typedef bool (*duckdb_cast_function_t)(duckdb_function_info info, idx_t count, duckdb_vector input,
-                                       duckdb_vector output);
-
-//===--------------------------------------------------------------------===//
-// Replacement scan types
-//===--------------------------------------------------------------------===//
-
-//! Additional replacement scan info. When setting this info, it is necessary to pass a destroy-callback function.
-typedef struct _duckdb_replacement_scan_info {
-	void *internal_ptr;
-} * duckdb_replacement_scan_info;
-
-//! A replacement scan function that can be added to a database.
-typedef void (*duckdb_replacement_callback_t)(duckdb_replacement_scan_info info, const char *table_name, void *data);
-
-//===--------------------------------------------------------------------===//
-// Arrow-related types
-//===--------------------------------------------------------------------===//
-
-//! Holds an arrow query result. Must be destroyed with `duckdb_destroy_arrow`.
-typedef struct _duckdb_arrow {
-	void *internal_ptr;
-} * duckdb_arrow;
-
-//! Holds an arrow array stream. Must be destroyed with `duckdb_destroy_arrow_stream`.
-typedef struct _duckdb_arrow_stream {
-	void *internal_ptr;
-} * duckdb_arrow_stream;
-
-//! Holds an arrow schema. Remember to release the respective ArrowSchema object.
-typedef struct _duckdb_arrow_schema {
-	void *internal_ptr;
-} * duckdb_arrow_schema;
-
-//! Holds an arrow array. Remember to release the respective ArrowArray object.
-typedef struct _duckdb_arrow_array {
-	void *internal_ptr;
-} * duckdb_arrow_array;
-
-//===--------------------------------------------------------------------===//
-// DuckDB extension access
-//===--------------------------------------------------------------------===//
-//! Passed to C API extension as a parameter to the entrypoint
-struct duckdb_extension_access {
-	//! Indicate that an error has occurred
-	void (*set_error)(duckdb_extension_info info, const char *error);
-	//! Fetch the database from duckdb to register extensions to
-	duckdb_database *(*get_database)(duckdb_extension_info info);
-	//! Fetch the API
-	const void *(*get_api)(duckdb_extension_info info, const char *version);
-};
-
-#ifndef DUCKDB_API_EXCLUDE_FUNCTIONS
-
-//===--------------------------------------------------------------------===//
-// Functions
-//===--------------------------------------------------------------------===//
-
-//===--------------------------------------------------------------------===//
-// Open Connect
-//===--------------------------------------------------------------------===//
-
-/*!
-Creates a new database instance cache.
-The instance cache is necessary if a client/program (re)opens multiple databases to the same file within the same
-process. Must be destroyed with 'duckdb_destroy_instance_cache'.
-
-* @return The database instance cache.
-*/
-DUCKDB_API duckdb_instance_cache duckdb_create_instance_cache();
-
-/*!
-Creates a new database instance in the instance cache, or retrieves an existing database instance.
-Must be closed with 'duckdb_close'.
-
-* @param instance_cache The instance cache in which to create the database, or from which to take the database.
-* @param path Path to the database file on disk. Both `nullptr` and `:memory:` open or retrieve an in-memory database.
-* @param out_database The resulting cached database.
-* @param config (Optional) configuration used to create the database.
-* @param out_error If set and the function returns `DuckDBError`, this contains the error message.
-Note that the error message must be freed using `duckdb_free`.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_get_or_create_from_cache(duckdb_instance_cache instance_cache, const char *path,
-                                                        duckdb_database *out_database, duckdb_config config,
-                                                        char **out_error);
-
-/*!
-Destroys an existing database instance cache and de-allocates its memory.
-
-* @param instance_cache The instance cache to destroy.
-*/
-DUCKDB_API void duckdb_destroy_instance_cache(duckdb_instance_cache *instance_cache);
-
-/*!
-Creates a new database or opens an existing database file stored at the given path.
-If no path is given a new in-memory database is created instead.
-The database must be closed with 'duckdb_close'.
-
-* @param path Path to the database file on disk. Both `nullptr` and `:memory:` open an in-memory database.
-* @param out_database The result database object.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_open(const char *path, duckdb_database *out_database);
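-
-/*
-For orientation, the typical open/connect/query lifecycle built from the functions in
-this section looks roughly like this (a minimal sketch; error handling omitted):
-
-```c
-#include <duckdb.h>
-
-int main(void) {
-	duckdb_database db;
-	duckdb_connection con;
-	duckdb_result res;
-
-	duckdb_open(NULL, &db);          // NULL path: in-memory database
-	duckdb_connect(db, &con);
-	duckdb_query(con, "SELECT 42", &res);
-	duckdb_destroy_result(&res);     // results must always be destroyed
-	duckdb_disconnect(&con);         // tear down in reverse order
-	duckdb_close(&db);
-	return 0;
-}
-```
-*/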
-
-/*!
-Extended version of duckdb_open. Creates a new database or opens an existing database file stored at the given path.
-The database must be closed with 'duckdb_close'.
-
-* @param path Path to the database file on disk. Both `nullptr` and `:memory:` open an in-memory database.
-* @param out_database The result database object.
-* @param config (Optional) configuration used to start up the database.
-* @param out_error If set and the function returns `DuckDBError`, this contains the error message.
-Note that the error message must be freed using `duckdb_free`.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_open_ext(const char *path, duckdb_database *out_database, duckdb_config config,
-                                        char **out_error);
-
-/*!
-Closes the specified database and de-allocates all memory allocated for that database.
-This should be called after you are done with any database allocated through `duckdb_open` or `duckdb_open_ext`.
-Note that failing to call `duckdb_close` (in case of e.g. a program crash) will not cause data corruption.
-Still, it is recommended to always correctly close a database object after you are done with it.
-
-* @param database The database object to shut down.
-*/
-DUCKDB_API void duckdb_close(duckdb_database *database);
-
-/*!
-Opens a connection to a database. Connections are required to query the database, and store transactional state
-associated with the connection.
-The instantiated connection should be closed using 'duckdb_disconnect'.
-
-* @param database The database object to connect to.
-* @param out_connection The result connection object.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_connect(duckdb_database database, duckdb_connection *out_connection);
-
-/*!
-Interrupts the running query.
-
-* @param connection The connection to interrupt
-*/
-DUCKDB_API void duckdb_interrupt(duckdb_connection connection);
-
-/*!
-Gets the progress of the running query.
-
-* @param connection The working connection
-* @return -1 if no progress or a percentage of the progress
-*/
-DUCKDB_API duckdb_query_progress_type duckdb_query_progress(duckdb_connection connection);
-
-/*!
-Closes the specified connection and de-allocates all memory allocated for that connection.
-
-* @param connection The connection to close.
-*/
-DUCKDB_API void duckdb_disconnect(duckdb_connection *connection);
-
-/*!
-Returns the version of the linked DuckDB, with a version postfix for dev versions
-
-Usually used for developing C extensions that must return this for a compatibility check.
-*/
-DUCKDB_API const char *duckdb_library_version();
-
-//===--------------------------------------------------------------------===//
-// Configuration
-//===--------------------------------------------------------------------===//
-
-/*!
-Initializes an empty configuration object that can be used to provide start-up options for the DuckDB instance
-through `duckdb_open_ext`.
-The duckdb_config must be destroyed using 'duckdb_destroy_config'
-
-This will always succeed unless there is a malloc failure.
-
-Note that `duckdb_destroy_config` should always be called on the resulting config, even if the function returns
-`DuckDBError`.
-
-* @param out_config The result configuration object.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_create_config(duckdb_config *out_config);
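-
-/*
-For example, creating a configuration, setting an option, and opening a database with it
-(a minimal sketch; `open_read_only` is a hypothetical name, `access_mode` is one of the
-documented option names, and error handling is abbreviated):
-
-```c
-#include <duckdb.h>
-#include <stddef.h>
-
-int open_read_only(const char *path, duckdb_database *out_db) {
-	duckdb_config config;
-	char *err = NULL;
-
-	duckdb_create_config(&config);
-	duckdb_set_config(config, "access_mode", "READ_ONLY"); // see duckdb_get_config_flag for options
-	duckdb_state state = duckdb_open_ext(path, out_db, config, &err);
-	duckdb_destroy_config(&config); // the config may be destroyed once the database is open
-	if (state == DuckDBError && err) {
-		duckdb_free(err);           // error messages must be freed with duckdb_free
-	}
-	return state == DuckDBSuccess;
-}
-```
-*/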
-
-/*!
-This returns the total number of configuration options available for usage with `duckdb_get_config_flag`.
-
-This should not be called in a loop as it internally loops over all the options.
-
-* @return The number of config options available.
-*/
-DUCKDB_API size_t duckdb_config_count();
-
-/*!
-Obtains a human-readable name and description of a specific configuration option. This can be used to e.g.
-display configuration options. This will succeed unless `index` is out of range (i.e. `>= duckdb_config_count`).
-
-The result name or description MUST NOT be freed.
-
-* @param index The index of the configuration option (between 0 and `duckdb_config_count`)
-* @param out_name A name of the configuration flag.
-* @param out_description A description of the configuration flag.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_get_config_flag(size_t index, const char **out_name, const char **out_description);
-
-/*!
-Sets the specified option for the specified configuration. The configuration option is indicated by name.
-To obtain a list of config options, see `duckdb_get_config_flag`.
-
-In the source code, configuration options are defined in `config.cpp`.
-
-This can fail if either the name is invalid, or if the value provided for the option is invalid.
-
-* @param config The configuration object to set the option on.
-* @param name The name of the configuration flag to set.
-* @param option The value to set the configuration flag to.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_set_config(duckdb_config config, const char *name, const char *option);
-
-/*!
-Destroys the specified configuration object and de-allocates all memory allocated for the object.
-
-* @param config The configuration object to destroy.
-*/
-DUCKDB_API void duckdb_destroy_config(duckdb_config *config);
-
-//===--------------------------------------------------------------------===//
-// Query Execution
-//===--------------------------------------------------------------------===//
-
-/*!
-Executes a SQL query within a connection and stores the full (materialized) result in the out_result pointer.
-If the query fails to execute, DuckDBError is returned and the error message can be retrieved by calling
-`duckdb_result_error`.
-
-Note that after running `duckdb_query`, `duckdb_destroy_result` must be called on the result object even if the
-query fails, otherwise the error stored within the result will not be freed correctly.
-
-* @param connection The connection to perform the query in.
-* @param query The SQL query to run.
-* @param out_result The query result.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_query(duckdb_connection connection, const char *query, duckdb_result *out_result);
-
-/*!
-Closes the result and de-allocates all memory allocated for that result.
-
-* @param result The result to destroy.
-*/
-DUCKDB_API void duckdb_destroy_result(duckdb_result *result);
-
-/*!
-Returns the column name of the specified column. The result should not need to be freed; the column names will
-automatically be destroyed when the result is destroyed.
-
-Returns `NULL` if the column is out of range.
-
-* @param result The result object to fetch the column name from.
-* @param col The column index.
-* @return The column name of the specified column.
-*/
-DUCKDB_API const char *duckdb_column_name(duckdb_result *result, idx_t col);
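-
-/*
-To illustrate these result accessors together, a sketch that runs a query and prints
-each column's name (`print_column_names` is a hypothetical name; error handling omitted;
-`duckdb_column_count` is declared just below):
-
-```c
-#include <duckdb.h>
-#include <stdio.h>
-
-void print_column_names(duckdb_connection con) {
-	duckdb_result res;
-	if (duckdb_query(con, "SELECT 1 AS id, 'x' AS name", &res) == DuckDBSuccess) {
-		idx_t count = duckdb_column_count(&res);
-		for (idx_t col = 0; col < count; col++) {
-			printf("column %llu: %s\n", (unsigned long long)col, duckdb_column_name(&res, col));
-		}
-	}
-	duckdb_destroy_result(&res); // must be called even if the query failed
-}
-```
-*/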
-
-/*!
-Returns the column type of the specified column.
-
-Returns `DUCKDB_TYPE_INVALID` if the column is out of range.
-
-* @param result The result object to fetch the column type from.
-* @param col The column index.
-* @return The column type of the specified column.
-*/
-DUCKDB_API duckdb_type duckdb_column_type(duckdb_result *result, idx_t col);
-
-/*!
-Returns the statement type of the statement that was executed
-
-* @param result The result object to fetch the statement type from.
-* @return duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID
-*/
-DUCKDB_API duckdb_statement_type duckdb_result_statement_type(duckdb_result result);
-
-/*!
-Returns the logical column type of the specified column.
-
-The return type of this call should be destroyed with `duckdb_destroy_logical_type`.
-
-Returns `NULL` if the column is out of range.
-
-* @param result The result object to fetch the column type from.
-* @param col The column index.
-* @return The logical column type of the specified column.
-*/
-DUCKDB_API duckdb_logical_type duckdb_column_logical_type(duckdb_result *result, idx_t col);
-
-/*!
-Returns the number of columns present in the result object.
-
-* @param result The result object.
-* @return The number of columns present in the result object.
-*/
-DUCKDB_API idx_t duckdb_column_count(duckdb_result *result);
-
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Returns the number of rows present in the result object.
-
-* @param result The result object.
-* @return The number of rows present in the result object.
-*/
-DUCKDB_API idx_t duckdb_row_count(duckdb_result *result);
-
-#endif
-/*!
-Returns the number of rows changed by the query stored in the result. This is relevant only for INSERT/UPDATE/DELETE
-queries. For other queries the rows_changed will be 0.
-
-* @param result The result object.
-* @return The number of rows changed.
-*/
-DUCKDB_API idx_t duckdb_rows_changed(duckdb_result *result);
-
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATED**: Prefer using `duckdb_result_get_chunk` instead.
-
-Returns the data of a specific column of a result in columnar format.
-
-The function returns a dense array which contains the result data. The exact type stored in the array depends on the
-corresponding duckdb_type (as provided by `duckdb_column_type`). For the exact type by which the data should be
-accessed, see the comments in [the types section](types) or the `DUCKDB_TYPE` enum.
-
-For example, for a column of type `DUCKDB_TYPE_INTEGER`, rows can be accessed in the following manner:
-```c
-int32_t *data = (int32_t *) duckdb_column_data(&result, 0);
-printf("Data for row %d: %d\n", row, data[row]);
-```
-
-* @param result The result object to fetch the column data from.
-* @param col The column index.
-* @return The column data of the specified column.
-*/
-DUCKDB_API void *duckdb_column_data(duckdb_result *result, idx_t col);
-
-/*!
-**DEPRECATED**: Prefer using `duckdb_result_get_chunk` instead.
-
-Returns the nullmask of a specific column of a result in columnar format. The nullmask indicates for every row
-whether or not the corresponding row is `NULL`. If a row is `NULL`, the values present in the array provided
-by `duckdb_column_data` are undefined.
-
-```c
-int32_t *data = (int32_t *) duckdb_column_data(&result, 0);
-bool *nullmask = duckdb_nullmask_data(&result, 0);
-if (nullmask[row]) {
-	printf("Data for row %d: NULL\n", row);
-} else {
-	printf("Data for row %d: %d\n", row, data[row]);
-}
-```
-
-* @param result The result object to fetch the nullmask from.
-* @param col The column index.
-* @return The nullmask of the specified column.
-*/
-DUCKDB_API bool *duckdb_nullmask_data(duckdb_result *result, idx_t col);
-
-#endif
-/*!
-Returns the error message contained within the result. The error is only set if `duckdb_query` returns `DuckDBError`.
-
-The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_result` is called.
-
-* @param result The result object to fetch the error from.
-* @return The error of the result.
-*/
-DUCKDB_API const char *duckdb_result_error(duckdb_result *result);
-
-/*!
-Returns the result error type contained within the result. The error is only set if `duckdb_query` returns
-`DuckDBError`.
-
-* @param result The result object to fetch the error from.
-* @return The error type of the result.
-*/
-DUCKDB_API duckdb_error_type duckdb_result_error_type(duckdb_result *result);
-
-//===--------------------------------------------------------------------===//
-// Result Functions
-//===--------------------------------------------------------------------===//
-
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Fetches a data chunk from the duckdb_result. This function should be called repeatedly until the result is exhausted.
-
-The result must be destroyed with `duckdb_destroy_data_chunk`.
-
-This function supersedes all `duckdb_value` functions, as well as the `duckdb_column_data` and `duckdb_nullmask_data`
-functions. It results in significantly better performance, and should be preferred in newer code-bases.
-
-If this function is used, none of the other result functions can be used and vice versa (i.e. this function cannot be
-mixed with the legacy result functions).
-
-Use `duckdb_result_chunk_count` to figure out how many chunks there are in the result.
-
-* @param result The result object to fetch the data chunk from.
-* @param chunk_index The chunk index to fetch from.
-* @return The resulting data chunk. Returns `NULL` if the chunk index is out of bounds.
-*/
-DUCKDB_API duckdb_data_chunk duckdb_result_get_chunk(duckdb_result result, idx_t chunk_index);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Checks if the type of the internal result is StreamQueryResult.
-
-* @param result The result object to check.
-* @return Whether or not the result object is of the type StreamQueryResult
-*/
-DUCKDB_API bool duckdb_result_is_streaming(duckdb_result result);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Returns the number of data chunks present in the result.
-
-* @param result The result object
-* @return Number of data chunks present in the result.
-*/
-DUCKDB_API idx_t duckdb_result_chunk_count(duckdb_result result);
-
-#endif
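-
-/*
-As a usage sketch for the (deprecated) chunk interface above, iterating all chunks of a
-materialized result (`count_rows` is a hypothetical name; `duckdb_data_chunk_get_size`
-is assumed from the data-chunk section of this header; error handling omitted):
-
-```c
-#include <duckdb.h>
-#include <stdio.h>
-
-void count_rows(duckdb_result res) {
-	idx_t total = 0;
-	idx_t chunks = duckdb_result_chunk_count(res);
-	for (idx_t i = 0; i < chunks; i++) {
-		duckdb_data_chunk chunk = duckdb_result_get_chunk(res, i);
-		total += duckdb_data_chunk_get_size(chunk); // number of rows in this chunk
-		duckdb_destroy_data_chunk(&chunk);          // each chunk must be destroyed
-	}
-	printf("%llu rows\n", (unsigned long long)total);
-}
-```
-*/
-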
-/*!
-Returns the return_type of the given result, or DUCKDB_RESULT_TYPE_INVALID on error
-
-* @param result The result object
-* @return The return_type
-*/
-DUCKDB_API duckdb_result_type duckdb_result_return_type(duckdb_result result);
-
-//===--------------------------------------------------------------------===//
-// Safe Fetch Functions
-//===--------------------------------------------------------------------===//
-
-// These functions will perform conversions if necessary.
-// On failure (e.g. if conversion cannot be performed or if the value is NULL) a default value is returned.
-// Note that these functions are slow since they perform bounds checking and conversion
-// For fast access of values prefer using `duckdb_result_get_chunk`
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The boolean value at the specified location, or false if the value cannot be converted.
-*/
-DUCKDB_API bool duckdb_value_boolean(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The int8_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API int8_t duckdb_value_int8(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The int16_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API int16_t duckdb_value_int16(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The int32_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API int32_t duckdb_value_int32(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The int64_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API int64_t duckdb_value_int64(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_hugeint value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_hugeint duckdb_value_hugeint(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_uhugeint value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_uhugeint duckdb_value_uhugeint(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_decimal value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_decimal duckdb_value_decimal(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The uint8_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API uint8_t duckdb_value_uint8(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The uint16_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API uint16_t duckdb_value_uint16(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The uint32_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API uint32_t duckdb_value_uint32(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The uint64_t value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API uint64_t duckdb_value_uint64(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The float value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API float duckdb_value_float(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The double value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API double duckdb_value_double(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_date value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_date duckdb_value_date(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_time value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_time duckdb_value_time(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_timestamp value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_timestamp duckdb_value_timestamp(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_interval value at the specified location, or 0 if the value cannot be converted.
-*/
-DUCKDB_API duckdb_interval duckdb_value_interval(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains null
-bytes.
-
-* @return The text value at the specified location as a null-terminated string, or nullptr if the value cannot be
-converted. The result must be freed with `duckdb_free`.
-*/
-DUCKDB_API char *duckdb_value_varchar(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-No support for nested types or other complex types.
-The resulting field "string.data" must be freed with `duckdb_free`.
-
-* @return The string value at the specified location. Attempts to cast the result value to string.
-*/
-DUCKDB_API duckdb_string duckdb_value_string(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATED**: Use duckdb_value_string_internal instead. This function does not work correctly if the string contains
-null bytes.
-
-* @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast.
-If the column is NOT a VARCHAR column this function will return NULL.
-
-The result must NOT be freed.
-*/
-DUCKDB_API char *duckdb_value_varchar_internal(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATED**: Use duckdb_value_string instead. This function does not work correctly if the string contains
-null bytes.
-* @return The char* value at the specified location. ONLY works on VARCHAR columns and does not auto-cast.
-If the column is NOT a VARCHAR column this function will return NULL.
-
-The result must NOT be freed.
-*/
-DUCKDB_API duckdb_string duckdb_value_string_internal(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return The duckdb_blob value at the specified location. Returns a blob with blob.data set to nullptr if the
-value cannot be converted. The resulting field "blob.data" must be freed with `duckdb_free`.
-*/
-DUCKDB_API duckdb_blob duckdb_value_blob(duckdb_result *result, idx_t col, idx_t row);
-
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-* @return Returns true if the value at the specified index is NULL, and false otherwise.
-*/
-DUCKDB_API bool duckdb_value_is_null(duckdb_result *result, idx_t col, idx_t row);
-
-#endif
-//===--------------------------------------------------------------------===//
-// Helpers
-//===--------------------------------------------------------------------===//
-
-/*!
-Allocate `size` bytes of memory using the duckdb internal malloc function. Any memory allocated in this manner
-should be freed using `duckdb_free`.
-
-* @param size The number of bytes to allocate.
-* @return A pointer to the allocated memory region.
-*/
-DUCKDB_API void *duckdb_malloc(size_t size);
-
-/*!
-Free a value returned from `duckdb_malloc`, `duckdb_value_varchar`, `duckdb_value_blob`, or
-`duckdb_value_string`.
-
-* @param ptr The memory region to de-allocate.
-*/
-DUCKDB_API void duckdb_free(void *ptr);
-
-/*!
-The internal vector size used by DuckDB.
-This is the number of tuples that will fit into a data chunk created by `duckdb_create_data_chunk`.
-
-* @return The vector size.
-*/
-DUCKDB_API idx_t duckdb_vector_size();
-
-/*!
-Whether or not the duckdb_string_t value is inlined.
-This means that the data of the string does not have a separate allocation.
-*/
-DUCKDB_API bool duckdb_string_is_inlined(duckdb_string_t string);
-
-/*!
-Get the string length of a string_t
-
-* @param string The string to get the length of.
-* @return The length.
-*/
-DUCKDB_API uint32_t duckdb_string_t_length(duckdb_string_t string);
-
-/*!
-Get a pointer to the string data of a string_t
-
-* @param string The string to get the pointer to.
-* @return The pointer.
-*/
-DUCKDB_API const char *duckdb_string_t_data(duckdb_string_t *string);
-
-//===--------------------------------------------------------------------===//
-// Date Time Timestamp Helpers
-//===--------------------------------------------------------------------===//
-
-/*!
-Decompose a `duckdb_date` object into year, month and date (stored as `duckdb_date_struct`).
-
-* @param date The date object, as obtained from a `DUCKDB_TYPE_DATE` column.
-* @return The `duckdb_date_struct` with the decomposed elements.
-*/
-DUCKDB_API duckdb_date_struct duckdb_from_date(duckdb_date date);
-
-/*!
-Re-compose a `duckdb_date` from year, month and date (`duckdb_date_struct`).
-
-* @param date The year, month and date stored in a `duckdb_date_struct`.
-* @return The `duckdb_date` element.
-*/
-DUCKDB_API duckdb_date duckdb_to_date(duckdb_date_struct date);
-
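-/*
-For example, shifting a `duckdb_date` to the first day of its month via the
-decompose/recompose pair above (a minimal sketch; `first_of_month` is a hypothetical name):
-
-```c
-#include <duckdb.h>
-
-duckdb_date first_of_month(duckdb_date d) {
-	duckdb_date_struct parts = duckdb_from_date(d); // days since epoch -> year/month/day
-	parts.day = 1;
-	return duckdb_to_date(parts);                   // year/month/day -> days since epoch
-}
-```
-*/
-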
-/*!
-Test a `duckdb_date` to see if it is a finite value.
-
-* @param date The date object, as obtained from a `DUCKDB_TYPE_DATE` column.
-* @return True if the date is finite, false if it is ±infinity.
-*/
-DUCKDB_API bool duckdb_is_finite_date(duckdb_date date);
-
-/*!
-Decompose a `duckdb_time` object into hour, minute, second and microsecond (stored as `duckdb_time_struct`).
-
-* @param time The time object, as obtained from a `DUCKDB_TYPE_TIME` column.
-* @return The `duckdb_time_struct` with the decomposed elements.
-*/
-DUCKDB_API duckdb_time_struct duckdb_from_time(duckdb_time time);
-
-/*!
-Create a `duckdb_time_tz` object from micros and a timezone offset.
-
-* @param micros The microsecond component of the time.
-* @param offset The timezone offset component of the time.
-* @return The `duckdb_time_tz` element.
-*/
-DUCKDB_API duckdb_time_tz duckdb_create_time_tz(int64_t micros, int32_t offset);
-
-/*!
-Decompose a TIME_TZ object into micros and a timezone offset.
-
-Use `duckdb_from_time` to further decompose the micros into hour, minute, second and microsecond.
-
-* @param micros The time object, as obtained from a `DUCKDB_TYPE_TIME_TZ` column.
-*/
-DUCKDB_API duckdb_time_tz_struct duckdb_from_time_tz(duckdb_time_tz micros);
-
-/*!
-Re-compose a `duckdb_time` from hour, minute, second and microsecond (`duckdb_time_struct`).
-
-* @param time The hour, minute, second and microsecond in a `duckdb_time_struct`.
-* @return The `duckdb_time` element.
-*/
-DUCKDB_API duckdb_time duckdb_to_time(duckdb_time_struct time);
-
-/*!
-Decompose a `duckdb_timestamp` object into a `duckdb_timestamp_struct`.
-
-* @param ts The ts object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column.
-* @return The `duckdb_timestamp_struct` with the decomposed elements.
-*/
-DUCKDB_API duckdb_timestamp_struct duckdb_from_timestamp(duckdb_timestamp ts);
-
-/*!
-Re-compose a `duckdb_timestamp` from a duckdb_timestamp_struct.
-
-* @param ts The de-composed elements in a `duckdb_timestamp_struct`.
-* @return The `duckdb_timestamp` element.
-*/
-DUCKDB_API duckdb_timestamp duckdb_to_timestamp(duckdb_timestamp_struct ts);
-
-/*!
-Test a `duckdb_timestamp` to see if it is a finite value.
-
-* @param ts The duckdb_timestamp object, as obtained from a `DUCKDB_TYPE_TIMESTAMP` column.
-* @return True if the timestamp is finite, false if it is ±infinity.
-*/
-DUCKDB_API bool duckdb_is_finite_timestamp(duckdb_timestamp ts);
-
-/*!
-Test a `duckdb_timestamp_s` to see if it is a finite value.
-
-* @param ts The duckdb_timestamp_s object, as obtained from a `DUCKDB_TYPE_TIMESTAMP_S` column.
-* @return True if the timestamp is finite, false if it is ±infinity.
-*/
-DUCKDB_API bool duckdb_is_finite_timestamp_s(duckdb_timestamp_s ts);
-
-/*!
-Test a `duckdb_timestamp_ms` to see if it is a finite value.
-
-* @param ts The duckdb_timestamp_ms object, as obtained from a `DUCKDB_TYPE_TIMESTAMP_MS` column.
-* @return True if the timestamp is finite, false if it is ±infinity.
-*/
-DUCKDB_API bool duckdb_is_finite_timestamp_ms(duckdb_timestamp_ms ts);
-
-/*!
-Test a `duckdb_timestamp_ns` to see if it is a finite value.
-
-* @param ts The duckdb_timestamp_ns object, as obtained from a `DUCKDB_TYPE_TIMESTAMP_NS` column.
-* @return True if the timestamp is finite, false if it is ±infinity.
-*/
-DUCKDB_API bool duckdb_is_finite_timestamp_ns(duckdb_timestamp_ns ts);
-
-//===--------------------------------------------------------------------===//
-// Hugeint Helpers
-//===--------------------------------------------------------------------===//
-
-/*!
-Converts a duckdb_hugeint object (as obtained from a `DUCKDB_TYPE_HUGEINT` column) into a double.
-
-* @param val The hugeint value.
-* @return The converted `double` element.
-*/
-DUCKDB_API double duckdb_hugeint_to_double(duckdb_hugeint val);
-
-/*!
-Converts a double value to a duckdb_hugeint object.
-
-If the conversion fails because the double value is too big, the result will be 0.
-
-* @param val The double value.
-* @return The converted `duckdb_hugeint` element.
-*/
-DUCKDB_API duckdb_hugeint duckdb_double_to_hugeint(double val);
-
-//===--------------------------------------------------------------------===//
-// Unsigned Hugeint Helpers
-//===--------------------------------------------------------------------===//
-
-/*!
-Converts a duckdb_uhugeint object (as obtained from a `DUCKDB_TYPE_UHUGEINT` column) into a double.
-
-* @param val The uhugeint value.
-* @return The converted `double` element.
-*/
-DUCKDB_API double duckdb_uhugeint_to_double(duckdb_uhugeint val);
-
-/*!
-Converts a double value to a duckdb_uhugeint object.
-
-If the conversion fails because the double value is too big, the result will be 0.
-
-* @param val The double value.
-* @return The converted `duckdb_uhugeint` element.
-*/
-DUCKDB_API duckdb_uhugeint duckdb_double_to_uhugeint(double val);
-
-//===--------------------------------------------------------------------===//
-// Decimal Helpers
-//===--------------------------------------------------------------------===//
-
-/*!
-Converts a double value to a duckdb_decimal object.
-
-If the conversion fails because the double value is too big, or the width/scale are invalid, the result will be 0.
-
-* @param val The double value.
-* @return The converted `duckdb_decimal` element.
-*/
-DUCKDB_API duckdb_decimal duckdb_double_to_decimal(double val, uint8_t width, uint8_t scale);
-
-/*!
-Converts a duckdb_decimal object (as obtained from a `DUCKDB_TYPE_DECIMAL` column) into a double.
-
-* @param val The decimal value.
-* @return The converted `double` element.
-*/
-DUCKDB_API double duckdb_decimal_to_double(duckdb_decimal val);
-
-//===--------------------------------------------------------------------===//
-// Prepared Statements
-//===--------------------------------------------------------------------===//
-
-// A prepared statement is a parameterized query that allows you to bind parameters to it.
-// * This is useful to easily supply parameters to functions and avoid SQL injection attacks.
-// * This is useful to speed up queries that you will execute several times with different parameters.
-// Because the query will only be parsed, bound, optimized and planned once during the prepare stage,
-// rather than once per execution.
-// For example:
-// SELECT * FROM tbl WHERE id=?
-// Or a query with multiple parameters:
-// SELECT * FROM tbl WHERE id=$1 OR name=$2
-/*!
-Create a prepared statement object from a query.
-
-Note that after calling `duckdb_prepare`, the prepared statement should always be destroyed using
-`duckdb_destroy_prepare`, even if the prepare fails.
-
-If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed.
-
-* @param connection The connection object
-* @param query The SQL query to prepare
-* @param out_prepared_statement The resulting prepared statement object
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_prepare(duckdb_connection connection, const char *query,
-                                       duckdb_prepared_statement *out_prepared_statement);
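-
-/*
-A typical prepare/bind/execute round-trip with the functions in this section
-(a minimal sketch; `run_prepared` is a hypothetical name and error handling is omitted;
-`tbl` is the table from the doc comment's own example):
-
-```c
-#include <duckdb.h>
-
-void run_prepared(duckdb_connection con) {
-	duckdb_prepared_statement stmt;
-	duckdb_result res;
-
-	duckdb_prepare(con, "SELECT * FROM tbl WHERE id = ?", &stmt);
-	duckdb_bind_int32(stmt, 1, 42);      // parameter indexes are 1-based
-	duckdb_execute_prepared(stmt, &res); // may be re-executed with new bindings
-	duckdb_destroy_result(&res);
-	duckdb_destroy_prepare(&stmt);       // always destroy, even if the prepare failed
-}
-```
-*/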
-
-/*!
-Closes the prepared statement and de-allocates all memory allocated for the statement.
-
-* @param prepared_statement The prepared statement to destroy.
-*/
-DUCKDB_API void duckdb_destroy_prepare(duckdb_prepared_statement *prepared_statement);
-
-/*!
-Returns the error message associated with the given prepared statement.
-If the prepared statement has no error message, this returns `nullptr` instead.
-
-The error message should not be freed. It will be de-allocated when `duckdb_destroy_prepare` is called.
-
-* @param prepared_statement The prepared statement to obtain the error from.
-* @return The error message, or `nullptr` if there is none.
-*/
-DUCKDB_API const char *duckdb_prepare_error(duckdb_prepared_statement prepared_statement);
-
-/*!
-Returns the number of parameters that can be provided to the given prepared statement.
-
-Returns 0 if the query was not successfully prepared.
-
-* @param prepared_statement The prepared statement to obtain the number of parameters for.
-*/
-DUCKDB_API idx_t duckdb_nparams(duckdb_prepared_statement prepared_statement);
-
-/*!
-Returns the name used to identify the parameter.
-The returned string should be freed using `duckdb_free`.
-
-Returns NULL if the index is out of range for the provided prepared statement.
-
-* @param prepared_statement The prepared statement to get the parameter name from.
-*/
-DUCKDB_API const char *duckdb_parameter_name(duckdb_prepared_statement prepared_statement, idx_t index);
-
-/*!
-Returns the parameter type for the parameter at the given index.
-
-Returns `DUCKDB_TYPE_INVALID` if the parameter index is out of range or the statement was not successfully prepared.
-
-* @param prepared_statement The prepared statement.
-* @param param_idx The parameter index.
-* @return The parameter type
-*/
-DUCKDB_API duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_t param_idx);
-
-/*!
-Returns the logical type for the parameter at the given index.
-
-Returns `nullptr` if the parameter index is out of range or the statement was not successfully prepared.
-
-The return type of this call should be destroyed with `duckdb_destroy_logical_type`.
-
-* @param prepared_statement The prepared statement.
-* @param param_idx The parameter index.
-* @return The logical type of the parameter
-*/
-DUCKDB_API duckdb_logical_type duckdb_param_logical_type(duckdb_prepared_statement prepared_statement, idx_t param_idx);
-
-/*!
-Clears the parameters bound to the prepared statement.
-*/
-DUCKDB_API duckdb_state duckdb_clear_bindings(duckdb_prepared_statement prepared_statement);
-
-/*!
-Returns the statement type of the statement to be executed
-
-* @param statement The prepared statement.
-* @return duckdb_statement_type value or DUCKDB_STATEMENT_TYPE_INVALID
-*/
-DUCKDB_API duckdb_statement_type duckdb_prepared_statement_type(duckdb_prepared_statement statement);
-
-//===--------------------------------------------------------------------===//
-// Bind Values To Prepared Statements
-//===--------------------------------------------------------------------===//
-
-/*!
-Binds a value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                          duckdb_value val);
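-
-/*
-For example, binding through the generic value interface instead of a typed bind
-(a minimal sketch; `bind_generic` is a hypothetical name, and it assumes the statement
-keeps its own copy of the value so the caller may destroy it afterwards):
-
-```c
-#include <duckdb.h>
-
-void bind_generic(duckdb_prepared_statement stmt) {
-	duckdb_value v = duckdb_create_int32(42); // see the Value Interface below
-	duckdb_bind_value(stmt, 1, v);
-	duckdb_destroy_value(&v);                 // assumption: the binding retains a copy
-}
-```
-*/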
-
-/*!
-Retrieve the index of the parameter for the prepared statement, identified by name.
-*/
-DUCKDB_API duckdb_state duckdb_bind_parameter_index(duckdb_prepared_statement prepared_statement, idx_t *param_idx_out,
                                                    const char *name);
-
-/*!
-Binds a bool value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_boolean(duckdb_prepared_statement prepared_statement, idx_t param_idx, bool val);
-
-/*!
-Binds an int8_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_int8(duckdb_prepared_statement prepared_statement, idx_t param_idx, int8_t val);
-
-/*!
-Binds an int16_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_int16(duckdb_prepared_statement prepared_statement, idx_t param_idx, int16_t val);
-
-/*!
-Binds an int32_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_int32(duckdb_prepared_statement prepared_statement, idx_t param_idx, int32_t val);
-
-/*!
-Binds an int64_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_int64(duckdb_prepared_statement prepared_statement, idx_t param_idx, int64_t val);
-
-/*!
-Binds a duckdb_hugeint value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_hugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                            duckdb_hugeint val);
-
-/*!
-Binds a duckdb_uhugeint value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_uhugeint(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                             duckdb_uhugeint val);
-
-/*!
-Binds a duckdb_decimal value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_decimal(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                            duckdb_decimal val);
-
-/*!
-Binds a uint8_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_uint8(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint8_t val);
-
-/*!
-Binds a uint16_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_uint16(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint16_t val);
-
-/*!
-Binds a uint32_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_uint32(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint32_t val);
-
-/*!
-Binds a uint64_t value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_uint64(duckdb_prepared_statement prepared_statement, idx_t param_idx, uint64_t val);
-
-/*!
-Binds a float value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_float(duckdb_prepared_statement prepared_statement, idx_t param_idx, float val);
-
-/*!
-Binds a double value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_double(duckdb_prepared_statement prepared_statement, idx_t param_idx, double val);
-
-/*!
-Binds a duckdb_date value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_date(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                         duckdb_date val);
-
-/*!
-Binds a duckdb_time value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_time(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                         duckdb_time val);
-
-/*!
-Binds a duckdb_timestamp value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_timestamp(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                              duckdb_timestamp val);
-
-/*!
-Binds a duckdb_timestamp value, as a TIMESTAMP_TZ, to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_timestamp_tz(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                                 duckdb_timestamp val);
-
-/*!
-Binds a duckdb_interval value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_interval(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                             duckdb_interval val);
-
-/*!
-Binds a null-terminated varchar value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_varchar(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                            const char *val);
-
-/*!
-Binds a varchar value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_varchar_length(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                                   const char *val, idx_t length);
-
-/*!
-Binds a blob value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_blob(duckdb_prepared_statement prepared_statement, idx_t param_idx,
-                                         const void *data, idx_t length);
-
-/*!
-Binds a NULL value to the prepared statement at the specified index.
-*/
-DUCKDB_API duckdb_state duckdb_bind_null(duckdb_prepared_statement prepared_statement, idx_t param_idx);
-
-//===--------------------------------------------------------------------===//
-// Execute Prepared Statements
-//===--------------------------------------------------------------------===//
-
-/*!
-Executes the prepared statement with the given bound parameters, and returns a materialized query result.
-
-This method can be called multiple times for each prepared statement, and the parameters can be modified
-between calls to this function.
-
-Note that the result must be freed with `duckdb_destroy_result`.
-
-* @param prepared_statement The prepared statement to execute.
-* @param out_result The query result.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_execute_prepared(duckdb_prepared_statement prepared_statement,
-                                                duckdb_result *out_result);
-
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Executes the prepared statement with the given bound parameters, and returns an optionally-streaming query result.
-To determine if the resulting query was in fact streamed, use `duckdb_result_is_streaming`.
-
-This method can be called multiple times for each prepared statement, and the parameters can be modified
-between calls to this function.
-
-Note that the result must be freed with `duckdb_destroy_result`.
-
-* @param prepared_statement The prepared statement to execute.
-* @param out_result The query result.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_execute_prepared_streaming(duckdb_prepared_statement prepared_statement,
-                                                          duckdb_result *out_result);
-
-#endif
-//===--------------------------------------------------------------------===//
-// Extract Statements
-//===--------------------------------------------------------------------===//
-
-// A query string can be extracted into multiple SQL statements. Each statement can be prepared and executed separately.
-/*!
-Extract all statements from a query.
-Note that after calling `duckdb_extract_statements`, the extracted statements should always be destroyed using
-`duckdb_destroy_extracted`, even if no statements were extracted.
-
-If the extract fails, `duckdb_extract_statements_error` can be called to obtain the reason why the extract failed.
-
-* @param connection The connection object
-* @param query The SQL query to extract
-* @param out_extracted_statements The resulting extracted statements object
-* @return The number of extracted statements or 0 on failure.
-*/
-DUCKDB_API idx_t duckdb_extract_statements(duckdb_connection connection, const char *query,
-                                           duckdb_extracted_statements *out_extracted_statements);
-
-/*!
-Prepare an extracted statement.
-Note that after calling `duckdb_prepare_extracted_statement`, the prepared statement should always be destroyed using
-`duckdb_destroy_prepare`, even if the prepare fails.
-
-If the prepare fails, `duckdb_prepare_error` can be called to obtain the reason why the prepare failed.
-
-* @param connection The connection object
-* @param extracted_statements The extracted statements object
-* @param index The index of the extracted statement to prepare
-* @param out_prepared_statement The resulting prepared statement object
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_prepare_extracted_statement(duckdb_connection connection,
-                                                           duckdb_extracted_statements extracted_statements,
-                                                           idx_t index,
-                                                           duckdb_prepared_statement *out_prepared_statement);
-
-/*!
-Returns the error message contained within the extracted statements.
-The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_extracted` is called.
-
-* @param extracted_statements The extracted statements to fetch the error from.
-* @return The error of the extracted statements.
-*/
-DUCKDB_API const char *duckdb_extract_statements_error(duckdb_extracted_statements extracted_statements);
-
-/*!
-De-allocates all memory allocated for the extracted statements.
-* @param extracted_statements The extracted statements to destroy.
-*/
-DUCKDB_API void duckdb_destroy_extracted(duckdb_extracted_statements *extracted_statements);
-
-//===--------------------------------------------------------------------===//
-// Pending Result Interface
-//===--------------------------------------------------------------------===//
-
-/*!
-Executes the prepared statement with the given bound parameters, and returns a pending result.
-The pending result represents an intermediate structure for a query that is not yet fully executed.
-The pending result can be used to incrementally execute a query, returning control to the client between tasks.
-
-Note that after calling `duckdb_pending_prepared`, the pending result should always be destroyed using
-`duckdb_destroy_pending`, even if this function returns DuckDBError.
-
-* @param prepared_statement The prepared statement to execute.
-* @param out_result The pending query result.
-//===--------------------------------------------------------------------===//
-// Pending Result Interface
-//===--------------------------------------------------------------------===//
-
-/*!
-Executes the prepared statement with the given bound parameters, and returns a pending result.
-The pending result represents an intermediate structure for a query that is not yet fully executed.
-The pending result can be used to incrementally execute a query, returning control to the client between tasks.
-
-Note that after calling `duckdb_pending_prepared`, the pending result should always be destroyed using
-`duckdb_destroy_pending`, even if this function returns DuckDBError.
-
-* @param prepared_statement The prepared statement to execute.
-* @param out_result The pending query result.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_pending_prepared(duckdb_prepared_statement prepared_statement,
-                                                duckdb_pending_result *out_result);
-
-#ifndef DUCKDB_API_NO_DEPRECATED
-/*!
-**DEPRECATION NOTICE**: This method is scheduled for removal in a future release.
-
-Executes the prepared statement with the given bound parameters, and returns a pending result.
-This pending result will create a streaming duckdb_result when executed.
-The pending result represents an intermediate structure for a query that is not yet fully executed.
-
-Note that after calling `duckdb_pending_prepared_streaming`, the pending result should always be destroyed using
-`duckdb_destroy_pending`, even if this function returns DuckDBError.
-
-* @param prepared_statement The prepared statement to execute.
-* @param out_result The pending query result.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_pending_prepared_streaming(duckdb_prepared_statement prepared_statement,
-                                                          duckdb_pending_result *out_result);
-
-#endif
-/*!
-Closes the pending result and de-allocates all memory allocated for the result.
-
-* @param pending_result The pending result to destroy.
-*/
-DUCKDB_API void duckdb_destroy_pending(duckdb_pending_result *pending_result);
-
-/*!
-Returns the error message contained within the pending result.
-
-The result of this function must not be freed. It will be cleaned up when `duckdb_destroy_pending` is called.
-
-* @param pending_result The pending result to fetch the error from.
-* @return The error of the pending result.
-*/
-DUCKDB_API const char *duckdb_pending_error(duckdb_pending_result pending_result);
-
-/*!
-Executes a single task within the query, returning whether or not the query is ready.
-
-If this returns DUCKDB_PENDING_RESULT_READY, the duckdb_execute_pending function can be called to obtain the result.
-If this returns DUCKDB_PENDING_RESULT_NOT_READY, the duckdb_pending_execute_task function should be called again.
-If this returns DUCKDB_PENDING_ERROR, an error occurred during execution.
-
-The error message can be obtained by calling duckdb_pending_error on the pending_result.
-
-* @param pending_result The pending result to execute a task within.
-* @return The state of the pending result after the execution.
-*/
-DUCKDB_API duckdb_pending_state duckdb_pending_execute_task(duckdb_pending_result pending_result);
-
-/*!
-Checks the state of the pending result, without executing any further tasks.
-
-If this returns DUCKDB_PENDING_RESULT_READY, the duckdb_execute_pending function can be called to obtain the result.
-If this returns DUCKDB_PENDING_RESULT_NOT_READY, the duckdb_pending_execute_check_state function should be called again.
-If this returns DUCKDB_PENDING_ERROR, an error occurred during execution.
-
-The error message can be obtained by calling duckdb_pending_error on the pending_result.
-
-* @param pending_result The pending result.
-* @return The state of the pending result.
-*/
-DUCKDB_API duckdb_pending_state duckdb_pending_execute_check_state(duckdb_pending_result pending_result);
-
-/*!
-Fully execute a pending query result, returning the final query result.
-
-If duckdb_pending_execute_task has been called until DUCKDB_PENDING_RESULT_READY was returned, this will return fast.
-Otherwise, all remaining tasks must be executed first.
-
-Note that the result must be freed with `duckdb_destroy_result`.
-
-* @param pending_result The pending result to execute.
-* @param out_result The result object.
-* @return `DuckDBSuccess` on success or `DuckDBError` on failure.
-*/
-DUCKDB_API duckdb_state duckdb_execute_pending(duckdb_pending_result pending_result, duckdb_result *out_result);
-
-/*!
-Returns whether a duckdb_pending_state is finished executing. For example, if `pending_state` is
-DUCKDB_PENDING_RESULT_READY, this function will return true.
-
-* @param pending_state The pending state on which to decide whether to finish execution.
-* @return Boolean indicating whether pending execution should be considered finished.
-*/
-DUCKDB_API bool duckdb_pending_execution_is_finished(duckdb_pending_state pending_state);
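The pieces above compose into a simple task loop. The sketch below is illustrative only (`run_incrementally` is a hypothetical name; `stmt` is assumed to be a successfully prepared statement); the point between tasks is where a client would yield control.

```c
#include <stdio.h>
#include "duckdb.h"

/* Sketch of incremental execution: run one task at a time so the client can
   interleave its own work between tasks. */
static duckdb_state run_incrementally(duckdb_prepared_statement stmt, duckdb_result *out_result) {
    duckdb_pending_result pending;
    if (duckdb_pending_prepared(stmt, &pending) == DuckDBError) {
        duckdb_destroy_pending(&pending); /* destroy even when this returns DuckDBError */
        return DuckDBError;
    }
    for (;;) {
        duckdb_pending_state state = duckdb_pending_execute_task(pending);
        if (state == DUCKDB_PENDING_ERROR) {
            fprintf(stderr, "pending error: %s\n", duckdb_pending_error(pending));
            break; /* duckdb_execute_pending below will report DuckDBError */
        }
        if (duckdb_pending_execution_is_finished(state)) {
            break; /* DUCKDB_PENDING_RESULT_READY: the result can now be fetched */
        }
        /* DUCKDB_PENDING_RESULT_NOT_READY: yield to the client here if desired */
    }
    duckdb_state rc = duckdb_execute_pending(pending, out_result);
    duckdb_destroy_pending(&pending);
    return rc; /* out_result must later be freed with duckdb_destroy_result */
}
```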
-//===--------------------------------------------------------------------===//
-// Value Interface
-//===--------------------------------------------------------------------===//
-
-/*!
-Destroys the value and de-allocates all memory allocated for that type.
-
-* @param value The value to destroy.
-*/
-DUCKDB_API void duckdb_destroy_value(duckdb_value *value);
-
-/*!
-Creates a value from a null-terminated string
-
-* @param text The null-terminated string
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_varchar(const char *text);
-
-/*!
-Creates a value from a string
-
-* @param text The text
-* @param length The length of the text
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_varchar_length(const char *text, idx_t length);
-
-/*!
-Creates a value from a boolean
-
-* @param input The boolean value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_bool(bool input);
-
-/*!
-Creates a value from an int8_t (a tinyint)
-
-* @param input The tinyint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_int8(int8_t input);
-
-/*!
-Creates a value from a uint8_t (a utinyint)
-
-* @param input The utinyint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uint8(uint8_t input);
-
-/*!
-Creates a value from an int16_t (a smallint)
-
-* @param input The smallint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_int16(int16_t input);
-
-/*!
-Creates a value from a uint16_t (a usmallint)
-
-* @param input The usmallint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uint16(uint16_t input);
-
-/*!
-Creates a value from an int32_t (an integer)
-
-* @param input The integer value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_int32(int32_t input);
-
-/*!
-Creates a value from a uint32_t (a uinteger)
-
-* @param input The uinteger value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uint32(uint32_t input);
-
-/*!
-Creates a value from a uint64_t (a ubigint)
-
-* @param input The ubigint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uint64(uint64_t input);
-
-/*!
-Creates a value from an int64_t (a bigint)
-
-* @param val The bigint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_int64(int64_t val);
-
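Since every `duckdb_create_*` function returns an owning handle, a brief sketch of the create/destroy discipline is worth showing here (the function name is illustrative, not part of the API):

```c
#include "duckdb.h"

/* Sketch: duckdb_value handles are owning and must be released explicitly. */
static void value_lifetime_example(void) {
    duckdb_value v_int = duckdb_create_int64(42);
    duckdb_value v_str = duckdb_create_varchar("hello");
    /* the length-based variant does not require null termination */
    duckdb_value v_sub = duckdb_create_varchar_length("abcdef", 3); /* stores "abc" */
    duckdb_destroy_value(&v_sub);
    duckdb_destroy_value(&v_str);
    duckdb_destroy_value(&v_int);
}
```
-/*!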
-Creates a value from a hugeint
-
-* @param input The hugeint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_hugeint(duckdb_hugeint input);
-
-/*!
-Creates a value from a uhugeint
-
-* @param input The uhugeint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uhugeint(duckdb_uhugeint input);
-
-/*!
-Creates a VARINT value from a duckdb_varint
-
-* @param input The duckdb_varint value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_varint(duckdb_varint input);
-
-/*!
-Creates a DECIMAL value from a duckdb_decimal
-
-* @param input The duckdb_decimal value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_decimal(duckdb_decimal input);
-
-/*!
-Creates a value from a float
-
-* @param input The float value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_float(float input);
-
-/*!
-Creates a value from a double
-
-* @param input The double value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_double(double input);
-
-/*!
-Creates a value from a date
-
-* @param input The date value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_date(duckdb_date input);
-
-/*!
-Creates a value from a time
-
-* @param input The time value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_time(duckdb_time input);
-
-/*!
-Creates a value from a time_tz.
-Not to be confused with `duckdb_create_time_tz`, which creates a duckdb_time_tz_t.
-
-* @param value The time_tz value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_time_tz_value(duckdb_time_tz value);
-
-/*!
-Creates a TIMESTAMP value from a duckdb_timestamp
-
-* @param input The duckdb_timestamp value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_timestamp(duckdb_timestamp input);
-
-/*!
-Creates a TIMESTAMP_TZ value from a duckdb_timestamp
-
-* @param input The duckdb_timestamp value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_timestamp_tz(duckdb_timestamp input);
-
-/*!
-Creates a TIMESTAMP_S value from a duckdb_timestamp_s
-
-* @param input The duckdb_timestamp_s value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_timestamp_s(duckdb_timestamp_s input);
-
-/*!
-Creates a TIMESTAMP_MS value from a duckdb_timestamp_ms
-
-* @param input The duckdb_timestamp_ms value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_timestamp_ms(duckdb_timestamp_ms input);
-
-/*!
-Creates a TIMESTAMP_NS value from a duckdb_timestamp_ns
-
-* @param input The duckdb_timestamp_ns value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_timestamp_ns(duckdb_timestamp_ns input);
-
-/*!
-Creates a value from an interval
-
-* @param input The interval value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_interval(duckdb_interval input);
-
-/*!
-Creates a value from a blob
-
-* @param data The blob data
-* @param length The length of the blob data
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_blob(const uint8_t *data, idx_t length);
-
-/*!
-Creates a BIT value from a duckdb_bit
-
-* @param input The duckdb_bit value
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_bit(duckdb_bit input);
-
-/*!
-Creates a UUID value from a uhugeint
-
-* @param input The duckdb_uhugeint containing the UUID
-* @return The value. This must be destroyed with `duckdb_destroy_value`.
-*/
-DUCKDB_API duckdb_value duckdb_create_uuid(duckdb_uhugeint input);
-
-/*!
-Returns the boolean value of the given value.
-
-* @param val A duckdb_value containing a boolean
-* @return A boolean, or false if the value cannot be converted
-*/
-DUCKDB_API bool duckdb_get_bool(duckdb_value val);
-
-/*!
-Returns the int8_t value of the given value.
-
-* @param val A duckdb_value containing a tinyint
-* @return An int8_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API int8_t duckdb_get_int8(duckdb_value val);
-
-/*!
-Returns the uint8_t value of the given value.
-
-* @param val A duckdb_value containing a utinyint
-* @return A uint8_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API uint8_t duckdb_get_uint8(duckdb_value val);
-
-/*!
-Returns the int16_t value of the given value.
-
-* @param val A duckdb_value containing a smallint
-* @return An int16_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API int16_t duckdb_get_int16(duckdb_value val);
-
-/*!
-Returns the uint16_t value of the given value.
-
-* @param val A duckdb_value containing a usmallint
-* @return A uint16_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API uint16_t duckdb_get_uint16(duckdb_value val);
-
-/*!
-Returns the int32_t value of the given value.
-
-* @param val A duckdb_value containing an integer
-* @return An int32_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API int32_t duckdb_get_int32(duckdb_value val);
-
-/*!
-Returns the uint32_t value of the given value.
-
-* @param val A duckdb_value containing a uinteger
-* @return A uint32_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API uint32_t duckdb_get_uint32(duckdb_value val);
-
-/*!
-Returns the int64_t value of the given value.
-
-* @param val A duckdb_value containing a bigint
-* @return An int64_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API int64_t duckdb_get_int64(duckdb_value val);
-
-/*!
-Returns the uint64_t value of the given value.
-
-* @param val A duckdb_value containing a ubigint
-* @return A uint64_t, or MinValue if the value cannot be converted
-*/
-DUCKDB_API uint64_t duckdb_get_uint64(duckdb_value val);
-
-/*!
-Returns the hugeint value of the given value.
-
-* @param val A duckdb_value containing a hugeint
-* @return A duckdb_hugeint, or MinValue if the value cannot be converted
-*/
-DUCKDB_API duckdb_hugeint duckdb_get_hugeint(duckdb_value val);
-
-/*!
-Returns the uhugeint value of the given value.
-
-* @param val A duckdb_value containing a uhugeint
-* @return A duckdb_uhugeint, or MinValue if the value cannot be converted
-*/
-DUCKDB_API duckdb_uhugeint duckdb_get_uhugeint(duckdb_value val);
-
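The getters mirror the creation functions, so a value can be round-tripped through the interface. A minimal sketch (illustrative name; conversion-failure handling omitted):

```c
#include <stdio.h>
#include "duckdb.h"

/* Sketch: round-trip a C value through a duckdb_value and a duckdb_get_* call. */
static void getter_example(void) {
    duckdb_value v = duckdb_create_int32(-7);
    int32_t x = duckdb_get_int32(v); /* -7; MinValue would signal a failed conversion */
    printf("%d\n", x);
    duckdb_destroy_value(&v);
}
```
-/*!
-Returns the duckdb_varint value of the given value.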
-The `data` field must be destroyed with `duckdb_free`.
-
-* @param val A duckdb_value containing a VARINT
-* @return A duckdb_varint. The `data` field must be destroyed with `duckdb_free`.
-*/
-DUCKDB_API duckdb_varint duckdb_get_varint(duckdb_value val);
-
-/*!
-Returns the duckdb_decimal value of the given value.
-
-* @param val A duckdb_value containing a DECIMAL
-* @return A duckdb_decimal, or MinValue if the value cannot be converted
-*/
-DUCKDB_API duckdb_decimal duckdb_get_decimal(duckdb_value val);
-
-/*!
-Returns the float value of the given value.
-
-* @param val A duckdb_value containing a float
-* @return A float, or NAN if the value cannot be converted
-*/
-DUCKDB_API float duckdb_get_float(duckdb_value val);
-
-/*!
-Returns the double value of the given value.
-
-* @param val A duckdb_value containing a double
-* @return A double, or NAN if the value cannot be converted
-*/
-DUCKDB_API double duckdb_get_double(duckdb_value val);
-
-/*!
-Returns the date value of the given value.
-
-* @param val A duckdb_value containing a date
-* @return A duckdb_date, or MinValue if the value cannot be converted
-*/
-DUCKDB_API duckdb_date duckdb_get_date(duckdb_value val);
-
-/*!
-Returns the time value of the given value.
-
-* @param val A duckdb_value containing a time
-* @return A duckdb_time, or MinValue