Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add decrease interface for aggregation #9737

Open
wants to merge 39 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
3f80eed
init
xzhangxian1008 Nov 26, 2024
171fd46
add gtest
xzhangxian1008 Nov 27, 2024
943c578
add ft
xzhangxian1008 Nov 27, 2024
e0b05a2
refine test framework
xzhangxian1008 Nov 27, 2024
75d2f12
fix compilation phase
xzhangxian1008 Nov 27, 2024
7022607
init
xzhangxian1008 Dec 4, 2024
b011bfa
codes done, need tests
xzhangxian1008 Dec 5, 2024
8628a2e
save
xzhangxian1008 Dec 12, 2024
db4e399
add tests
xzhangxian1008 Dec 13, 2024
a9360da
add sum tests
xzhangxian1008 Dec 17, 2024
e3c31bd
refine tests
xzhangxian1008 Dec 17, 2024
e3b9756
format
xzhangxian1008 Dec 17, 2024
a159857
fix bugs
xzhangxian1008 Dec 18, 2024
0d4401d
refine test
xzhangxian1008 Dec 18, 2024
9ff2bd3
fix tests
xzhangxian1008 Dec 18, 2024
a9aa879
add test for string type
xzhangxian1008 Dec 19, 2024
bf652bf
add test for SingleValueDataGeneric type
xzhangxian1008 Dec 19, 2024
61eaef0
tweaking
xzhangxian1008 Dec 20, 2024
6017780
Merge branch 'master' into wagg
xzhangxian1008 Dec 20, 2024
2358efb
remove something
xzhangxian1008 Dec 20, 2024
3227b50
revoke
xzhangxian1008 Dec 20, 2024
b2772ea
add AlignedBuffer
xzhangxian1008 Dec 20, 2024
37f4cc5
remove something
xzhangxian1008 Dec 20, 2024
dd52c9f
tweaking
xzhangxian1008 Dec 20, 2024
6379b0f
remove useless change
xzhangxian1008 Dec 20, 2024
9e41b3a
fix ci
xzhangxian1008 Dec 23, 2024
e6adcbf
fix ci
xzhangxian1008 Dec 23, 2024
b84c911
fix ut
xzhangxian1008 Dec 24, 2024
aa12221
address some comments
xzhangxian1008 Dec 25, 2024
2b9e991
create new class for window
xzhangxian1008 Dec 26, 2024
c54f45f
add ut tests
xzhangxian1008 Dec 27, 2024
ad8f002
add static
xzhangxian1008 Dec 27, 2024
2944543
fix ut
xzhangxian1008 Jan 3, 2025
4a550c7
address comments
xzhangxian1008 Jan 7, 2025
2c7e362
address comments
xzhangxian1008 Jan 7, 2025
917892b
replace queue with tree
xzhangxian1008 Jan 9, 2025
8592277
remove useless codes
xzhangxian1008 Jan 9, 2025
0d6b0d1
address comment
xzhangxian1008 Jan 10, 2025
c7817b5
tweaking
xzhangxian1008 Jan 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions dbms/src/AggregateFunctions/AggregateFunctionArray.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
for (size_t i = 0; i < num_arguments; ++i)
nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData();

const ColumnArray & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const auto & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets = first_array_column.getOffsets();

size_t begin = row_num == 0 ? 0 : offsets[row_num - 1];
Expand All @@ -82,7 +82,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
/// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
for (size_t i = 1; i < num_arguments; ++i)
{
const ColumnArray & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const auto & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const IColumn::Offsets & ith_offsets = ith_column.getOffsets();

if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
Expand All @@ -95,6 +95,7 @@ class AggregateFunctionArray final : public IAggregateFunctionHelper<AggregateFu
nested_func->add(place, nested, i, arena);
}


void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
nested_func->merge(place, rhs, arena);
Expand Down
19 changes: 19 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionAvg.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@ struct AggregateFunctionAvgData
T sum;
UInt64 count;

/// Clears the accumulated state: no rows counted and a zero sum.
void reset()
{
    count = 0;
    sum = T(0);
}

AggregateFunctionAvgData()
: sum(0)
, count(0)
Expand Down Expand Up @@ -78,6 +84,19 @@ class AggregateFunctionAvg final
++this->data(place).count;
}

/// Removes one row from the running average state stored at `place`.
///
/// The value at `row_num` of the first argument column is subtracted from `sum`
/// and `count` is decremented by one.
///
/// @param place    aggregate state to modify.
/// @param columns  argument columns; only columns[0] is read.
/// @param row_num  index of the row whose value is removed.
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
    auto & state = this->data(place);

    // `count` is an unsigned UInt64, so a `>= 0` check after the decrement can never fail.
    // The underflow has to be detected before decrementing.
    assert(state.count > 0);

    if constexpr (IsDecimal<T>)
        state.sum -= static_cast<const ColumnDecimal<T> &>(*columns[0]).getData()[row_num];
    else
        state.sum -= static_cast<const ColumnVector<T> &>(*columns[0]).getData()[row_num];

    --state.count;
}

/// Puts the state stored at `place` back to a zero sum and a zero count.
void reset(AggregateDataPtr __restrict place) const override
{
    auto & state = this->data(place);
    state.reset();
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
this->data(place).sum += this->data(rhs).sum;
Expand Down
28 changes: 28 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionCount.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace DB
/// State used by the count aggregate functions: the number of rows counted so far.
struct AggregateFunctionCountData
{
    UInt64 count = 0;

    /// Puts the counter back to zero.
    inline void reset() noexcept
    {
        count = 0;
    }
};

namespace ErrorCodes
Expand All @@ -52,6 +54,13 @@ class AggregateFunctionCount final
++data(place).count;
}

/// Removes one row from the count stored at `place`; the column arguments are not inspected.
void decrease(AggregateDataPtr __restrict place, const IColumn **, size_t, Arena *) const override
{
    auto & state = data(place);
    state.count -= 1;
}

/// Sets the count stored at `place` back to zero.
void reset(AggregateDataPtr __restrict place) const override
{
    auto & state = data(place);
    state.reset();
}

void addBatchSinglePlace(
size_t start_offset,
size_t batch_size,
Expand Down Expand Up @@ -173,6 +182,13 @@ class AggregateFunctionCountNotNullUnary final
data(place).count += !static_cast<const ColumnNullable &>(*columns[0]).isNullAt(row_num);
}

/// Removes the row `row_num` from the count stored at `place`.
/// The count only goes down when the nullable argument is not NULL at that row.
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
    const auto & nullable_column = static_cast<const ColumnNullable &>(*columns[0]);
    if (!nullable_column.isNullAt(row_num))
        --data(place).count;
}

/// Sets the count stored at `place` back to zero.
void reset(AggregateDataPtr __restrict place) const override
{
    auto & state = data(place);
    state.reset();
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;
Expand Down Expand Up @@ -234,6 +250,18 @@ class AggregateFunctionCountNotNullVariadic final
++data(place).count;
}

/// Removes the row `row_num` from the count stored at `place`.
///
/// The row is ignored (the count is left untouched) when any nullable argument
/// is NULL at that row; otherwise the count is decremented by one.
///
/// @param place    aggregate state to modify.
/// @param columns  argument columns, `number_of_arguments` entries.
/// @param row_num  index of the row that is removed.
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena *) const override
{
    for (size_t i = 0; i < number_of_arguments; ++i)
        if (is_nullable[i] && static_cast<const ColumnNullable &>(*columns[i]).isNullAt(row_num))
            return;

    // `count` is an unsigned UInt64, so a `>= 0` check after the decrement can never fail.
    // The underflow has to be detected before decrementing.
    assert(data(place).count > 0);
    --data(place).count;
}

/// Sets the count stored at `place` back to zero.
void reset(AggregateDataPtr __restrict place) const override
{
    auto & state = data(place);
    state.reset();
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena *) const override
{
data(place).count += data(rhs).count;
Expand Down
20 changes: 20 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionFactory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,26 @@ AggregateFunctionPtr AggregateFunctionFactory::get(
return res;
}

AggregateFunctionPtr AggregateFunctionFactory::getForWindow(
const Context & context,
const String & name,
const DataTypes & argument_types,
const Array & parameters,
int recursion_level) const
{
AggregateFunctionCombinatorPtr combinator
= AggregateFunctionCombinatorFactory::instance().tryFindSuffix("NullForWindow");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the argument is not null, does it still need this combinator?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if the argument is not null, does it still need this combinator?

yes

if (!combinator)
throw Exception(
"Logical error: cannot find aggregate function combinator to apply a function to Nullable for window "
"arguments.",
ErrorCodes::LOGICAL_ERROR);

DataTypes nested_types = combinator->transformArguments(argument_types);
AggregateFunctionPtr nested_function = getImpl(context, name, nested_types, parameters, recursion_level);
return combinator->transformAggregateFunction(nested_function, argument_types, parameters);
}

AggregateFunctionPtr AggregateFunctionFactory::getImpl(
const Context & context,
const String & name,
Expand Down
8 changes: 8 additions & 0 deletions dbms/src/AggregateFunctions/AggregateFunctionFactory.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ class AggregateFunctionFactory final : public ext::Singleton<AggregateFunctionFa
int recursion_level = 0,
bool empty_input_as_null = false) const;

/// Returns the aggregate function `name` wrapped by the "NullForWindow" combinator,
/// for use by window functions. The combinator is applied unconditionally,
/// i.e. also when none of the arguments is nullable.
/// Throws an exception if not found.
AggregateFunctionPtr getForWindow(
const Context & context,
const String & name,
const DataTypes & argument_types,
const Array & parameters = {},
int recursion_level = 0) const;

/// Returns nullptr if not found.
AggregateFunctionPtr tryGet(
const Context & context,
Expand Down
18 changes: 10 additions & 8 deletions dbms/src/AggregateFunctions/AggregateFunctionForEach.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ class AggregateFunctionForEach final
AggregateFunctionForEachData & ensureAggregateData(
AggregateDataPtr __restrict place,
size_t new_size,
Arena & arena) const
Arena * arena) const
{
AggregateFunctionForEachData & state = data(place);

Expand All @@ -75,7 +75,9 @@ class AggregateFunctionForEach final
size_t old_size = state.dynamic_array_size;
if (old_size < new_size)
{
state.array_of_aggregate_datas = arena.realloc(
RUNTIME_CHECK_MSG(arena, "got null arena ptr in ensureAggregateData");

state.array_of_aggregate_datas = arena->realloc(
state.array_of_aggregate_datas,
old_size * nested_size_of_data,
new_size * nested_size_of_data);
Expand Down Expand Up @@ -155,7 +157,7 @@ class AggregateFunctionForEach final
for (size_t i = 0; i < num_arguments; ++i)
nested[i] = &static_cast<const ColumnArray &>(*columns[i]).getData();

const ColumnArray & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const auto & first_array_column = static_cast<const ColumnArray &>(*columns[0]);
const IColumn::Offsets & offsets = first_array_column.getOffsets();

size_t begin = row_num == 0 ? 0 : offsets[row_num - 1];
Expand All @@ -164,7 +166,7 @@ class AggregateFunctionForEach final
/// Sanity check. NOTE We can implement specialization for a case with single argument, if the check will hurt performance.
for (size_t i = 1; i < num_arguments; ++i)
{
const ColumnArray & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const auto & ith_column = static_cast<const ColumnArray &>(*columns[i]);
const IColumn::Offsets & ith_offsets = ith_column.getOffsets();

if (ith_offsets[row_num] != end || (row_num != 0 && ith_offsets[row_num - 1] != begin))
Expand All @@ -173,7 +175,7 @@ class AggregateFunctionForEach final
ErrorCodes::SIZES_OF_ARRAYS_DOESNT_MATCH);
}

AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, *arena);
AggregateFunctionForEachData & state = ensureAggregateData(place, end - begin, arena);

char * nested_state = state.array_of_aggregate_datas;
for (size_t i = begin; i < end; ++i)
Expand All @@ -186,7 +188,7 @@ class AggregateFunctionForEach final
void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
{
const AggregateFunctionForEachData & rhs_state = data(rhs);
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, *arena);
AggregateFunctionForEachData & state = ensureAggregateData(place, rhs_state.dynamic_array_size, arena);

const char * rhs_nested_state = rhs_state.array_of_aggregate_datas;
char * nested_state = state.array_of_aggregate_datas;
Expand Down Expand Up @@ -220,7 +222,7 @@ class AggregateFunctionForEach final
size_t new_size = 0;
readBinary(new_size, buf);

ensureAggregateData(place, new_size, *arena);
ensureAggregateData(place, new_size, arena);

char * nested_state = state.array_of_aggregate_datas;
for (size_t i = 0; i < new_size; ++i)
Expand All @@ -234,7 +236,7 @@ class AggregateFunctionForEach final
{
const AggregateFunctionForEachData & state = data(place);

ColumnArray & arr_to = static_cast<ColumnArray &>(to);
auto & arr_to = static_cast<ColumnArray &>(to);
ColumnArray::Offsets & offsets_to = arr_to.getOffsets();
IColumn & elems_to = arr_to.getData();

Expand Down
32 changes: 25 additions & 7 deletions dbms/src/AggregateFunctions/AggregateFunctionGroupConcat.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,32 +116,36 @@ class AggregateFunctionGroupConcat final

DataTypePtr getReturnType() const override { return result_is_nullable ? makeNullable(ret_type) : ret_type; }

/// reject nulls before add() of nested agg
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
template <bool is_add>
void addOrDecrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const
{
if constexpr (only_one_column)
{
if (is_nullable[0])
{
const ColumnNullable * column = static_cast<const ColumnNullable *>(columns[0]);
const auto * column = static_cast<const ColumnNullable *>(columns[0]);
if (!column->isNullAt(row_num))
{
this->setFlag(place);
const IColumn * nested_column = &column->getNestedColumn();
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);

if constexpr (is_add)
this->nested_function->add(this->nestedPlace(place), &nested_column, row_num, arena);
else
this->nested_function->decrease(this->nestedPlace(place), &nested_column, row_num, arena);
xzhangxian1008 marked this conversation as resolved.
Show resolved Hide resolved
}
return;
}
}
else
{
/// remove the row with null, except for sort columns
const ColumnTuple & tuple = static_cast<const ColumnTuple &>(*columns[0]);
const auto & tuple = static_cast<const ColumnTuple &>(*columns[0]);
for (size_t i = 0; i < number_of_concat_items; ++i)
{
if (is_nullable[i])
{
const ColumnNullable & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i));
const auto & nullable_col = static_cast<const ColumnNullable &>(tuple.getColumn(i));
if (nullable_col.isNullAt(row_num))
{
/// If at least one column has a null value in the current row,
Expand All @@ -152,7 +156,21 @@ class AggregateFunctionGroupConcat final
}
}
this->setFlag(place);
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
if constexpr (is_add)
this->nested_function->add(this->nestedPlace(place), columns, row_num, arena);
else
this->nested_function->decrease(this->nestedPlace(place), columns, row_num, arena);
}

/// Adds the row `row_num` to the state at `place`.
/// Delegates to addOrDecrease<true>, which does the NULL handling and calls the nested function's add().
void add(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
    constexpr bool is_add = true;
    addOrDecrease<is_add>(place, columns, row_num, arena);
}

/// Removes the row `row_num` from the state at `place`.
/// Delegates to addOrDecrease<false>, which does the NULL handling and calls the nested function's decrease().
void decrease(AggregateDataPtr __restrict place, const IColumn ** columns, size_t row_num, Arena * arena) const override
{
    constexpr bool is_add = false;
    addOrDecrease<is_add>(place, columns, row_num, arena);
}

void insertResultInto(ConstAggregateDataPtr __restrict place, IColumn & to, Arena * arena) const override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ class AggregateFunctionIntersectionsMax final
PointType right = static_cast<const ColumnVector<PointType> &>(*columns[1]).getData()[row_num];

if (!isNaN(left))
this->data(place).value.push_back(std::make_pair(left, Int64(1)), arena);
this->data(place).value.push_back(std::make_pair(left, static_cast<Int64>(1)), arena);

if (!isNaN(right))
this->data(place).value.push_back(std::make_pair(right, Int64(-1)), arena);
this->data(place).value.push_back(std::make_pair(right, static_cast<Int64>(-1)), arena);
}

void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, Arena * arena) const override
Expand Down
2 changes: 1 addition & 1 deletion dbms/src/AggregateFunctions/AggregateFunctionMerge.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class AggregateFunctionMerge final : public IAggregateFunctionHelper<AggregateFu
AggregateFunctionMerge(const AggregateFunctionPtr & nested_, const IDataType & argument)
: nested_func(nested_)
{
const DataTypeAggregateFunction * data_type = typeid_cast<const DataTypeAggregateFunction *>(&argument);
const auto * data_type = typeid_cast<const DataTypeAggregateFunction *>(&argument);

if (!data_type || data_type->getFunctionName() != nested_func->getName())
throw Exception(
Expand Down
Loading