[core] Fix that sequence fields are mistakenly aggregated by default aggregator in AggregateMergeFunction (#4977)
yuzelin authored Jan 22, 2025
1 parent 79154d4 commit ed6de3e
Showing 5 changed files with 82 additions and 64 deletions.
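
In short: when 'fields.default-aggregate-function' was set, the default aggregator was also applied to the table's 'sequence.field', so the sequence column itself got aggregated (for example summed) instead of keeping its latest value. The diffs below make the per-field aggregator selection explicit, resolving sequence fields and primary-key fields before falling back to the per-field or default aggregate function; minimal sketches of that selection logic follow each file's diff.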
@@ -24,6 +24,7 @@
 import org.apache.paimon.data.InternalRow;
 import org.apache.paimon.mergetree.compact.aggregate.FieldAggregator;
 import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
+import org.apache.paimon.mergetree.compact.aggregate.factory.FieldPrimaryKeyAggFactory;
 import org.apache.paimon.options.Options;
 import org.apache.paimon.types.DataField;
 import org.apache.paimon.types.DataType;
@@ -54,6 +55,7 @@
 import static org.apache.paimon.CoreOptions.PARTIAL_UPDATE_REMOVE_RECORD_ON_DELETE;
 import static org.apache.paimon.CoreOptions.PARTIAL_UPDATE_REMOVE_RECORD_ON_SEQUENCE_GROUP;
 import static org.apache.paimon.utils.InternalRowUtils.createFieldGetters;
+import static org.apache.paimon.utils.Preconditions.checkArgument;
 
 /**
  * A {@link MergeFunction} where key is primary key (unique) and value is the partial record, update
@@ -352,10 +354,6 @@ private Factory(Options options, RowType rowType, List<String> primaryKeys) {
             this.fieldAggregators =
                     createFieldAggregators(
                             rowType, primaryKeys, allSequenceFields, new CoreOptions(options));
-            if (!fieldAggregators.isEmpty() && fieldSeqComparators.isEmpty()) {
-                throw new IllegalArgumentException(
-                        "Must use sequence group for aggregation functions.");
-            }
 
             removeRecordOnDelete = options.get(PARTIAL_UPDATE_REMOVE_RECORD_ON_DELETE);

@@ -526,41 +524,47 @@ private Map<Integer, Supplier<FieldAggregator>> createFieldAggregators(
         List<String> fieldNames = rowType.getFieldNames();
         List<DataType> fieldTypes = rowType.getFieldTypes();
         Map<Integer, Supplier<FieldAggregator>> fieldAggregators = new HashMap<>();
-        String defaultAggFunc = options.fieldsDefaultFunc();
         for (int i = 0; i < fieldNames.size(); i++) {
             String fieldName = fieldNames.get(i);
             DataType fieldType = fieldTypes.get(i);
-            // aggregate by primary keys, so they do not aggregate
-            boolean isPrimaryKey = primaryKeys.contains(fieldName);
-            String strAggFunc = options.fieldAggFunc(fieldName);
-            boolean ignoreRetract = options.fieldAggIgnoreRetract(fieldName);
 
-            if (strAggFunc != null) {
+            if (allSequenceFields.contains(fieldName)) {
+                // no agg for sequence fields
+                continue;
+            }
+
+            if (primaryKeys.contains(fieldName)) {
+                // aggregate by primary keys, so they do not aggregate
                 fieldAggregators.put(
                         i,
                         () ->
                                 FieldAggregatorFactory.create(
                                         fieldType,
-                                        strAggFunc,
-                                        ignoreRetract,
-                                        isPrimaryKey,
-                                        options,
-                                        fieldName));
-            } else if (defaultAggFunc != null && !allSequenceFields.contains(fieldName)) {
-                // no agg for sequence fields
+                                        fieldName,
+                                        FieldPrimaryKeyAggFactory.NAME,
+                                        options));
+                continue;
+            }
+
+            String aggFuncName = getAggFuncName(options, fieldName);
+            if (aggFuncName != null) {
+                checkArgument(
+                        !fieldSeqComparators.isEmpty(),
+                        "Must use sequence group for aggregation functions.");
                 fieldAggregators.put(
                         i,
                         () ->
                                 FieldAggregatorFactory.create(
-                                        fieldType,
-                                        defaultAggFunc,
-                                        ignoreRetract,
-                                        isPrimaryKey,
-                                        options,
-                                        fieldName));
+                                        fieldType, fieldName, aggFuncName, options));
             }
         }
         return fieldAggregators;
     }
+
+    @Nullable
+    private String getAggFuncName(CoreOptions options, String fieldName) {
+        String aggFunc = options.fieldAggFunc(fieldName);
+        return aggFunc == null ? options.fieldsDefaultFunc() : aggFunc;
+    }
 }
}
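
The rewritten createFieldAggregators resolves each field in a fixed order: sequence fields get no aggregator at all, primary keys get the primary-key aggregator, and only the remaining fields use the per-field or default aggregate function; the old constructor-level IllegalArgumentException becomes a per-field checkArgument that a sequence group is defined. A minimal, self-contained sketch of that selection order follows; the helper name, the standalone setup, and the "primary-key" identifier string are illustrative stand-ins, not Paimon API.

import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

public class PartialUpdateAggSelectionSketch {

    // Illustrative stand-in for the per-field decision in createFieldAggregators.
    static Optional<String> chooseAggFuncName(
            String field,
            List<String> sequenceFields,
            List<String> primaryKeys,
            Function<String, String> perFieldAggFunc, // field -> configured agg function, or null
            String defaultAggFunc) {                  // fields.default-aggregate-function, or null
        if (sequenceFields.contains(field)) {
            return Optional.empty(); // no agg for sequence fields
        }
        if (primaryKeys.contains(field)) {
            return Optional.of("primary-key"); // stand-in for FieldPrimaryKeyAggFactory.NAME
        }
        String configured = perFieldAggFunc.apply(field);
        return Optional.ofNullable(configured != null ? configured : defaultAggFunc);
    }

    public static void main(String[] args) {
        List<String> seqFields = Arrays.asList("seq");
        List<String> pks = Arrays.asList("pk");
        Function<String, String> perField = f -> null; // no fields.<name>.aggregate-function set
        String defaultAgg = "sum";                     // fields.default-aggregate-function

        System.out.println(chooseAggFuncName("seq", seqFields, pks, perField, defaultAgg)); // Optional.empty
        System.out.println(chooseAggFuncName("pk", seqFields, pks, perField, defaultAgg));  // Optional[primary-key]
        System.out.println(chooseAggFuncName("v", seqFields, pks, perField, defaultAgg));   // Optional[sum]
    }
}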
@@ -25,6 +25,8 @@
 import org.apache.paimon.mergetree.compact.MergeFunction;
 import org.apache.paimon.mergetree.compact.MergeFunctionFactory;
 import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
+import org.apache.paimon.mergetree.compact.aggregate.factory.FieldLastNonNullValueAggFactory;
+import org.apache.paimon.mergetree.compact.aggregate.factory.FieldPrimaryKeyAggFactory;
 import org.apache.paimon.options.Options;
 import org.apache.paimon.types.DataType;
 import org.apache.paimon.types.RowKind;
@@ -132,27 +134,39 @@ public MergeFunction<KeyValue> create(@Nullable int[][] projection) {
         }
 
         FieldAggregator[] fieldAggregators = new FieldAggregator[fieldNames.size()];
-        String defaultAggFunc = options.fieldsDefaultFunc();
+        List<String> sequenceFields = options.sequenceField();
         for (int i = 0; i < fieldNames.size(); i++) {
             String fieldName = fieldNames.get(i);
             DataType fieldType = fieldTypes.get(i);
-            // aggregate by primary keys, so they do not aggregate
-            boolean isPrimaryKey = primaryKeys.contains(fieldName);
-            String strAggFunc = options.fieldAggFunc(fieldName);
-            strAggFunc = strAggFunc == null ? defaultAggFunc : strAggFunc;
 
-            boolean ignoreRetract = options.fieldAggIgnoreRetract(fieldName);
+            String aggFuncName = getAggFuncName(fieldName, sequenceFields);
             fieldAggregators[i] =
-                    FieldAggregatorFactory.create(
-                            fieldType,
-                            strAggFunc,
-                            ignoreRetract,
-                            isPrimaryKey,
-                            options,
-                            fieldName);
+                    FieldAggregatorFactory.create(fieldType, fieldName, aggFuncName, options);
         }
 
         return new AggregateMergeFunction(createFieldGetters(fieldTypes), fieldAggregators);
     }
+
+    private String getAggFuncName(String fieldName, List<String> sequenceFields) {
+        if (sequenceFields.contains(fieldName)) {
+            // no agg for sequence fields, use last_non_null_value to do cover
+            return FieldLastNonNullValueAggFactory.NAME;
+        }
+
+        if (primaryKeys.contains(fieldName)) {
+            // aggregate by primary keys, so they do not aggregate
+            return FieldPrimaryKeyAggFactory.NAME;
+        }
+
+        String aggFuncName = options.fieldAggFunc(fieldName);
+        if (aggFuncName == null) {
+            aggFuncName = options.fieldsDefaultFunc();
+        }
+        if (aggFuncName == null) {
+            // final default agg func
+            aggFuncName = FieldLastNonNullValueAggFactory.NAME;
+        }
+        return aggFuncName;
+    }
 }
}
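
For the aggregation merge engine the fix shows up as an explicit precedence in the new getAggFuncName: sequence fields and primary keys are resolved before the per-field and default options, with last_non_null_value as the final fallback. Below is a minimal standalone sketch of that precedence; the identifier strings are assumptions standing in for FieldLastNonNullValueAggFactory.NAME and FieldPrimaryKeyAggFactory.NAME, and the method signature is simplified for illustration.

import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;

public class AggFuncPrecedenceSketch {

    // Assumed identifier strings; the real code uses the factories' NAME constants.
    static final String LAST_NON_NULL_VALUE = "last_non_null_value";
    static final String PRIMARY_KEY = "primary-key";

    static String getAggFuncName(
            String field,
            List<String> sequenceFields,
            List<String> primaryKeys,
            Map<String, String> perFieldAggFunc, // fields.<name>.aggregate-function
            String defaultAggFunc) {             // fields.default-aggregate-function, may be null
        if (sequenceFields.contains(field)) {
            // sequence fields are never aggregated; they just keep the latest non-null value
            return LAST_NON_NULL_VALUE;
        }
        if (primaryKeys.contains(field)) {
            // aggregate by primary keys, so they do not aggregate
            return PRIMARY_KEY;
        }
        String name = perFieldAggFunc.get(field);
        if (name == null) {
            name = defaultAggFunc;
        }
        return name == null ? LAST_NON_NULL_VALUE : name;
    }

    public static void main(String[] args) {
        List<String> seq = Arrays.asList("seq");
        List<String> pks = Arrays.asList("pk");
        Map<String, String> perField = Collections.emptyMap(); // nothing field-specific configured
        System.out.println(getAggFuncName("seq", seq, pks, perField, "sum")); // last_non_null_value
        System.out.println(getAggFuncName("pk", seq, pks, perField, "sum"));  // primary-key
        System.out.println(getAggFuncName("v", seq, pks, perField, "sum"));   // sum
    }
}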
@@ -25,8 +25,6 @@
 import org.apache.paimon.mergetree.compact.aggregate.FieldIgnoreRetractAgg;
 import org.apache.paimon.types.DataType;
 
-import javax.annotation.Nullable;
-
 /** Factory for {@link FieldAggregator}. */
 public interface FieldAggregatorFactory extends Factory {

@@ -35,37 +33,23 @@ public interface FieldAggregatorFactory extends Factory {
     String identifier();
 
     static FieldAggregator create(
-            DataType fieldType,
-            @Nullable String strAgg,
-            boolean ignoreRetract,
-            boolean isPrimaryKey,
-            CoreOptions options,
-            String field) {
-        FieldAggregator fieldAggregator;
-        if (isPrimaryKey) {
-            strAgg = FieldPrimaryKeyAggFactory.NAME;
-        } else if (strAgg == null) {
-            strAgg = FieldLastNonNullValueAggFactory.NAME;
-        }
-
+            DataType fieldType, String fieldName, String aggFuncName, CoreOptions options) {
         FieldAggregatorFactory fieldAggregatorFactory =
                 FactoryUtil.discoverFactory(
                         FieldAggregator.class.getClassLoader(),
                         FieldAggregatorFactory.class,
-                        strAgg);
+                        aggFuncName);
         if (fieldAggregatorFactory == null) {
             throw new RuntimeException(
                     String.format(
                             "Use unsupported aggregation: %s or spell aggregate function incorrectly!",
-                            strAgg));
+                            aggFuncName));
         }
 
-        fieldAggregator = fieldAggregatorFactory.create(fieldType, options, field);
-
-        if (ignoreRetract) {
-            fieldAggregator = new FieldIgnoreRetractAgg(fieldAggregator);
-        }
-
-        return fieldAggregator;
+        FieldAggregator fieldAggregator =
+                fieldAggregatorFactory.create(fieldType, options, fieldName);
+        return options.fieldAggIgnoreRetract(fieldName)
+                ? new FieldIgnoreRetractAgg(fieldAggregator)
+                : fieldAggregator;
     }
 }
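
The factory entry point now takes the aggregate function name directly and derives retraction handling from the options instead of a boolean flag. A rough usage sketch against the new signature, assuming a built-in sum aggregator registered under the identifier "sum" and that an empty options map is acceptable here:

import java.util.HashMap;

import org.apache.paimon.CoreOptions;
import org.apache.paimon.mergetree.compact.aggregate.FieldAggregator;
import org.apache.paimon.mergetree.compact.aggregate.factory.FieldAggregatorFactory;
import org.apache.paimon.types.DataTypes;

public class FieldAggregatorFactoryUsageSketch {
    public static void main(String[] args) {
        // New signature: (fieldType, fieldName, aggFuncName, options).
        // Ignore-retract behaviour now comes from options.fieldAggIgnoreRetract(fieldName)
        // inside create() rather than being passed as a parameter.
        FieldAggregator sum =
                FieldAggregatorFactory.create(
                        DataTypes.INT(), "v", "sum", CoreOptions.fromMap(new HashMap<>()));
        System.out.println(sum.agg(1, 2)); // expected: 3
    }
}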
@@ -936,10 +936,8 @@ public void testCustomAgg() throws IOException {
                 FieldAggregatorFactory.create(
                         DataTypes.STRING(),
                         "custom",
-                        false,
-                        false,
-                        CoreOptions.fromMap(new HashMap<>()),
-                        "custom");
+                        "custom",
+                        CoreOptions.fromMap(new HashMap<>()));
 
         Object agg = fieldAggregator.agg("test", "test");
         assertThat(agg).isEqualTo("test");
@@ -1195,6 +1195,24 @@ public void testMergeRead() {
         assertThat(batchSql("SELECT * FROM T where v = 1"))
                 .containsExactlyInAnyOrder(Row.of(2, 1, 1));
     }
+
+    @Test
+    public void testSequenceFieldWithDefaultAgg() {
+        sql(
+                "CREATE TABLE seq_default_agg ("
+                        + " pk INT PRIMARY KEY NOT ENFORCED,"
+                        + " seq INT,"
+                        + " v INT) WITH ("
+                        + " 'merge-engine'='aggregation',"
+                        + " 'sequence.field'='seq',"
+                        + " 'fields.default-aggregate-function'='sum'"
+                        + ")");
+
+        sql("INSERT INTO seq_default_agg VALUES (0, 1, 1)");
+        sql("INSERT INTO seq_default_agg VALUES (0, 2, 2)");
+
+        assertThat(sql("SELECT * FROM seq_default_agg")).containsExactly(Row.of(0, 2, 3));
+    }
 }
 
 /** ITCase for {@link FieldNestedUpdateAgg}. */
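
The added ITCase pins down the fixed behaviour: with 'fields.default-aggregate-function'='sum' and 'sequence.field'='seq', the two inserts for pk 0 merge into Row.of(0, 2, 3), i.e. v is summed (1 + 2) while seq keeps the latest value 2 instead of also being summed.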
